gigaGPT is Cerebras’ implementation of Andrei Karpathy’s nanoGPT – the simplest and most compact code base to train and fine-tune GPT models. Whereas nanoGPT can train models in the 100M parameter range, gigaGPT trains models well over 100B parameters. We do this without introducing additional code or relying on third-party frameworks – the entire repo is just 565 lines of code. Instead, gigaGPT uses the large memory and compute capacity of Cerebras hardware to enable large-scale training on vanilla torch.nn code. With no modifications, gigaGPT supports long context lengths and works with a variety of optimizers.

Why gigaGPT

While the transformer architecture is simple, training a large transformer on a large number of GPUs is hard. Beyond a few billion parameters, vanilla GPT models run out of memory on even the latest GPUs. Training larger models requires breaking up models into smaller pieces, distributing them to multiple GPUs, coordinating the workload among the workers, and assembling the results. This is typically done via LLM scaling frameworks such as Megatron, DeepSpeed, NeoX, Fairscale, and Mosaic Foundry. Though powerful, these frameworks introduce significant complexity.

A small model such as nanoGPT requires just 639 lines of code to implement. Implementing a 20B parameter model using Nvidia Megatron requires 20,507 lines of code – a 32x increase in complexity. Even though this code doesn’t need to be written from scratch, implementing, debugging, and maintaining such a project is a major undertaking. Many ML teams struggle to get these frameworks to work, and few manage to converge models with decent utilization. gigaGPT shows that on Cerebras hardware you can have the best of both worlds – a compact, hackable codebase and the ability to train GPT-3 sized models with long context.

The Models

gigaGPT implements the basic GPT-2 architecture in a way that matches nanoGPT. In particular, we use learned position embeddings, standard attention, and biases throughout the model. These choices were made primarily to stick closely to nanoGPT and can easily be changed. We validate gigaGPT by training four models with 111M, 13B, 70B, and 175B parameters. All of the models were trained on the OpenWebText dataset using the GPT-2 tokenizer, with preprocessing code taken from nanoGPT. Since the goal of this project was to create a clean, performant, and usable code base for others rather than to train state of the art models ourselves, our validation was geared towards functional correctness rather than convergence, downstream performance, or other similar metrics. To the best of our knowledge, this is the only GPT implementation that scales from millions to hundreds of billions of parameters without specialized parallelization techniques.
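The preprocessing follows nanoGPT: raw text is encoded with the GPT-2 BPE tokenizer and written out as a flat binary file of token ids. Below is a rough sketch in the spirit of nanoGPT’s data preparation script, not the repo’s exact code; dataset download and sharding are omitted.

    import numpy as np
    import tiktoken

    enc = tiktoken.get_encoding("gpt2")

    def tokenize_documents(docs, out_path="train.bin"):
        """Encode an iterable of text documents into a flat uint16 token file."""
        ids = []
        for doc in docs:
            tokens = enc.encode_ordinary(doc)  # GPT-2 BPE, no special tokens
            tokens.append(enc.eot_token)       # delimit documents with <|endoftext|>
            ids.extend(tokens)
        # GPT-2's 50257-token vocab fits in uint16, matching nanoGPT's .bin format
        np.array(ids, dtype=np.uint16).tofile(out_path)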

gigaGPT-111M

The 111M config was inspired by Cerebras-GPT. It uses the same model dimensions, learning rate, batch size, and training schedule, and differs from Cerebras-GPT primarily in its dataset. The loss trend looks good and roughly matches what we observed for the corresponding Cerebras-GPT configuration despite the different choice of dataset. While we were writing gigaGPT, we performed thorough side-by-side numerical checks against a trusted reference code base, so this loss trend isn’t surprising, but it’s always nice to get the confirmation of a converged model.

gigaGPT-13B

As with the 111M configuration, the 13B configuration closely matches the model of that size from Cerebras-GPT in model dimensions, learning rate, batch size, and training schedule. Over the first hundred steps we see a few minor loss spikes, but nothing the model can’t recover from. A 13B training run takes a substantial amount of compute, so we stop after about 100 training steps. At this point we’re well past the scale that nanoGPT can accommodate, but gigaGPT handles it without a problem.

gigaGPT-70B

The 70B configuration is loosely inspired by Llama-2 70B. It takes its model dimensions from that work and trains on the same 2T tokens as Llama-2 70B with a similar batch size. The learning rate is slightly more conservative, as previous groups have noticed more instabilities when using learned position embeddings compared to RoPE. Since our goal wasn’t to converge the model, we didn’t perform rigorous hyperparameter selection. Even so, the loss appears to be decreasing steadily and training is fairly stable. At 70B, gigaGPT continues to show great performance and scaling. Even though we grew the model size by two orders of magnitude and made no effort to optimize for throughput, utilization remains equal to or better than in previous runs; we just wrote the model, ran it, and immediately saw fast results. The model code is also trivial to scale out despite being written as a single monolithic model – gigaGPT-70B runs on anywhere from 1 to 64 systems (the full size of a cluster like Condor Galaxy 1) just by editing a single flag in the configuration file.

gigaGPT-175B

After validating the 70B model, we were curious to further probe the limits of model scale that gigaGPT could accommodate. We changed the model dimensions of the 70B config to match what was reported in the original GPT-3 paper, scaled the learning rate and initialization using common heuristics, and launched a run. The few steps we trained weren’t of much interest from a convergence perspective, but the model ran without any issues at the same utilization as the 70B model. What’s most notable here for ML practitioners is that going large does not cause Cerebras hardware to run out of memory. Based on the results we believe gigaGPT can scale to models in excess of 1 trillion parameters.
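To make the “common heuristics” mentioned above concrete: a GPT-2/GPT-3-style recipe scales the residual-projection initialization down with depth and shrinks the peak learning rate as the model grows wider. The sketch below is purely illustrative – the dimensions come from the GPT-3 paper, but the reference values and the exact scaling rule are assumptions, not the numbers used in this run.

    import math

    # GPT-3 175B dimensions from the original paper (Brown et al., 2020)
    n_layers, d_model = 96, 12288

    # GPT-2-style init: base std 0.02, with residual projections scaled down by
    # 1/sqrt(2 * n_layers) so activations don't grow with depth
    base_std = 0.02
    resid_std = base_std / math.sqrt(2 * n_layers)

    # A common width-based heuristic: shrink the peak learning rate roughly in
    # proportion to 1/d_model relative to a smaller reference config
    # (reference values here are illustrative placeholders)
    ref_d_model, ref_lr = 8192, 1.5e-4
    peak_lr = ref_lr * ref_d_model / d_model

    print(f"residual init std ~= {resid_std:.4f}, peak lr ~= {peak_lr:.2e}")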

How gigaGPT Works

The gigaGPT model does not use any sharding or pipelining techniques because it fits entirely into the system memory of Cerebras hardware. To briefly recap: Cerebras Wafer Scale Clusters comprise 1 to 192 Cerebras CS-2 systems supported by CPU server nodes that store parameters (MemoryX) and stream data, all connected by an interconnect fabric (SwarmX). Unlike in GPU-based clusters, compute and memory are decoupled. The entire model – upward of trillions of parameters – is stored in a dedicated MemoryX appliance, and the weights are streamed to the wafer one layer at a time during training. By storing models in large, unified memory appliances, we obviate the need to break models apart with complex frameworks. All model training, from 1 to 192 systems, is done with standard mini-batches, i.e. pure data parallelism.

gigaGPT consists mainly of `model.py` and `train.py`. Looking more closely at the model code, we see that it looks quite similar to concise GPT implementations written for GPUs. It is built from `torch.nn` components without any fancy external libraries like xFormers or DeepSpeed. In fact, the model code is quite boring. Compared to nanoGPT, we rewrote the attention layer to use primitive torch ops instead of fused attention algorithms and to expose the attention mask as an argument to the attention layer for greater flexibility. Other than that, the differences compared to nanoGPT are mainly cosmetic.
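For concreteness, an attention computation built from primitive torch ops with the mask exposed as an argument looks roughly like the following sketch; the repo’s actual implementation may differ in details.

    import math
    import torch

    def attention(q, k, v, mask):
        """Scaled dot-product attention from primitive ops.

        q, k, v have shape (batch, heads, seq, head_dim); mask is broadcastable
        to the (batch, heads, seq, seq) score matrix, e.g. a causal mask of
        zeros and -inf supplied by the caller.
        """
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        scores = scores + mask                 # apply the caller-provided mask
        probs = torch.softmax(scores, dim=-1)
        return probs @ v                       # (batch, heads, seq, head_dim)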

Likewise, the main training loop is very simple. It uses `cerebras_pytorch` (a custom PyTorch wrapper specialized for CS system execution) as a drop-in replacement for calls to standard torch APIs. Only a couple of parts of this file will look new to someone already familiar with PyTorch: a few decorators that demarcate different pieces of functionality, and a `cerebras_pytorch.backend` scope used for model creation. Overall the code is easy to understand, familiar to PyTorch users, and easy to modify and customize.

Training a large model across a huge cluster requires careful orchestration of multiple independent jobs running across heterogeneous hardware – on its face, a very daunting challenge. The `cerebras_pytorch` package is the solution to this problem and is the crux of what allows the gigaGPT code to be so simple. `cerebras_pytorch` wraps PyTorch functionality that users will already be familiar with and adds a small number of new APIs that simplify the distributed computing needs of the problem. In this section we will walk through the code of gigaGPT’s `train.py` to better illuminate these APIs.

We’ll start with the end of the `main` function which contains the high-level logic for the training loop and work backwards through the definitions of each of the components it uses.

    for step, batch in enumerate(executor, start=global_step + 1):
        if step > config.num_steps:
            break
        loss = training_step(batch)
        log_loss(loss, step)
        save_checkpoint(step)

This is a simple starting point, so we don’t need to spend too much time here. Let’s first take a closer look at the `executor` used above.

    dataloader = cstorch.utils.data.DataLoader(
        get_dataloader,
        data_path,
        config.sequence_length,
        config.batch_size,
        config.seed,
    )
    executor = cstorch.utils.data.DataExecutor(
        dataloader,
        num_steps=config.num_steps - global_step,
        checkpoint_steps=config.checkpoint_steps,
        cs_config=cs_config,
        writer=writer,
    )

For Cerebras system runs, dedicated CPU nodes handle loading data and feeding it to the model. `cstorch.utils.data.DataLoader` defines the dataloader instance that will run on each of these worker nodes. Rather than a dataloader object, it takes a function that returns one, which makes it easy to set up independent, properly sharded dataloaders on every worker. The resulting `cstorch.utils.data.DataLoader` object is then fed into a `DataExecutor`, which is responsible for the top-level coordination of all the independent tasks required for the run.
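The function passed to `cstorch.utils.data.DataLoader` is itself ordinary PyTorch. Here is a hedged sketch, with a hypothetical memory-mapped dataset class standing in for whatever the repo actually uses; its signature matches the extra arguments passed to the `DataLoader` above.

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Dataset

    class MemmapDataset(Dataset):
        """Illustrative dataset over a flat token file (name is hypothetical)."""

        def __init__(self, data_path, sequence_length):
            self.data = np.memmap(data_path, dtype=np.uint16, mode="r")
            self.sequence_length = sequence_length

        def __len__(self):
            return (len(self.data) - 1) // self.sequence_length

        def __getitem__(self, i):
            start = i * self.sequence_length
            chunk = self.data[start : start + self.sequence_length + 1].astype(np.int64)
            # return (input_ids, labels), shifted by one token
            return torch.from_numpy(chunk[:-1]), torch.from_numpy(chunk[1:])

    def get_dataloader(data_path, sequence_length, batch_size, seed):
        # Called once on each worker node; per-worker sharding is elided here.
        torch.manual_seed(seed)
        dataset = MemmapDataset(data_path, sequence_length)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)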

Next, let’s look more closely at the components involved in defining a single training step. Internally, `cerebras_pytorch` uses PyTorch LTC to trace the compute graph associated with the training job and maps this compute graph down to operations that can be run on the Cerebras Wafer Scale Engine (WSE). Accordingly, the first step towards defining the training logic is to create the model instance in a way that enables it to be traced later. This is accomplished by the following code at the top of `train.py::main`:

    backend = cstorch.backend(config.backend, use_cs_grad_accum=True)
    …

    with backend.device:
        model = GPTModel(model_config)

    compiled_model = cstorch.compile(model, backend)

With this model definition, plus an optimizer and learning rate scheduler created through APIs that directly mirror their PyTorch counterparts, we are ready to define the logic of a basic training step.
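For reference, the optimizer setup looks roughly like the sketch below. `cstorch.optim.AdamW` is assumed here to mirror `torch.optim.AdamW`, and config fields such as `config.weight_decay` are illustrative placeholders rather than names taken from the repo.

    # Hedged sketch; cstorch.optim.AdamW is assumed to mirror torch.optim.AdamW,
    # and the config field names are illustrative placeholders.
    all_params = list(model.parameters())
    optimizer = cstorch.optim.AdamW(
        all_params,
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.95),
    )
    # A warmup-then-decay learning rate schedule (the usual GPT recipe) would be
    # built from the corresponding classes in cstorch.optim.lr_scheduler.

The training step itself then looks like this: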

    @cstorch.trace
    def training_step(batch):
        input_ids, labels = batch
        loss = compiled_model(input_ids, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(list(all_params), config.max_gradient_norm)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        return loss

The body of this function is fairly standard training code. The only interesting part is the `@cstorch.trace` decorator. This signals to the framework that the code in the function is intended to be traced and run on the CS system. No tensors can be eagerly executed within this scope, which means the code here can’t include any logging functionality or Python conditionals. For that, we need a different decorator:

    @cstorch.step_closure
    def log_loss(loss, step):
        rate = executor.profiler.rate()
        global_rate = executor.profiler.global_rate()

        logger.info(
            f"| Step={step}, "
            f"Loss={loss.item():.5f}, "
            f"Rate={rate:.2f} samples/sec, "
            f"GlobalRate={global_rate:.2f} samples/sec"
        )
        writer.add_scalar("loss", loss.item(), step)
        writer.add_scalar("samples_per_second", global_rate, step)

This logging code requires eagerly executing tensor values and doesn’t need to run on the WSE, so we wrap it in a `@cstorch.step_closure` decorator. Checkpointing code works similarly, except that we want to make sure that it only runs every `checkpoint_steps` steps for whatever value of `checkpoint_steps` we pass into the `DataExecutor` above. For this we have the `@cstorch.checkpoint_closure` decorator. Functions wrapped in this decorator can be called at any time but will only execute if the current step is a checkpoint step.

    @cstorch.checkpoint_closure
    def save_checkpoint(step):
        checkpoint_path = out_dir.joinpath(f"checkpoint_{step}.mdl")
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "global_step": step,
            "model_config": asdict(model_config),
        }
        cstorch.save(state_dict, checkpoint_path)
        logger.info(f"Saved checkpoint to {checkpoint_path}")

With that, we’re done defining the functions used in the main training loop we started with. After a few more lines of code to handle checkpoint loading, config parsing, and the like, we end up with a `train.py` that, in just 156 lines of code, seamlessly coordinates training jobs across huge distributed clusters.
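Checkpoint loading, for instance, is essentially the mirror image of `save_checkpoint` above. Here is a hedged sketch using `cstorch.load` as the counterpart of the `cstorch.save` call shown earlier; the repo’s actual restore logic may differ in details.

    def load_checkpoint(checkpoint_path):
        # Mirror image of save_checkpoint above; error handling and config
        # validation are omitted from this sketch.
        state_dict = cstorch.load(checkpoint_path)
        model.load_state_dict(state_dict["model"])
        optimizer.load_state_dict(state_dict["optimizer"])
        lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        return state_dict["global_step"]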

Conclusion

We’d like to thank Andrei Karpathy for creating nanoGPT and inspiring this work. We believe simple, hackable, and performant code is essential for advancing machine learning research. By combining the benefits of a compact code base and the ability to train GPT-3 scale models, gigaGPT on Cerebras hardware represents a significant leap towards more accessible, scalable, and efficient AI model training. If you’re working with the Cerebras platform, we encourage you to experiment with gigaGPT and share your feedback.

Find at: https://github.com/Cerebras/gigaGPT
Author: William Marshall
Contributors: James Wang, Gavia Gray