
PyTorch has become the leading machine learning (ML) framework because it is easy to use, easy to debug and because it can express a wide range of ideas. It also has a comprehensive and rapidly growing ecosystem.

In version 1.2 of the Cerebras Software Platform (CSoft), we have vastly expanded our support for PyTorch. With that in mind, I thought it would be interesting to share the Cerebras approach to supporting PyTorch models. For an accelerator to support an ML framework well, it must:

  1. Adhere to the core design principles of the framework wherever possible,
  2. Minimize surprises and complexity for the user, and
  3. Integrate seamlessly into the framework’s existing ecosystem.

There are many ways to support PyTorch, and each has its pros and cons. In this blog, we discuss the Cerebras implementation and our future directions, and we provide a high-level overview of the Cerebras ML Backend.

The Cerebras Stack

The key difference between our CS-2 system and conventional processors is the sheer scale of our solution. At the heart of the system lies the Wafer-Scale Engine (WSE-2), the world’s largest and fastest AI processor. The WSE-2 contains an astonishing 850,000 AI-optimized compute cores and more than 40 gigabytes of high-performance on-chip memory. The scale of these computational resources drove our PyTorch implementation.

| Metric | WSE-2 | A100 | Cerebras Advantage |
|---|---|---|---|
| Chip size | 46,225 mm² | 826 mm² | 56X |
| Cores | 850,000 | 6,912 + 432 | 123X |
| On-chip memory | 40 gigabytes | 40 megabytes | 1,000X |
| Memory bandwidth | 20 petabytes/sec | 1.6 terabytes/sec | 12,733X |
| Fabric bandwidth | 220 petabits/sec | 4.8 terabits/sec | 45,833X |

PyTorch on conventional processors must work around the weaknesses of those processors, including limited on-chip memory and limited memory bandwidth. The problem is compounded when models don’t fit on a single processor. Algorithms tend to become less efficient as they are split, or “sharded”, across many chips, because moving data between chips is much slower than moving data within a single chip. Writing code for massively parallel systems is also difficult and time-consuming.

The Cerebras hardware avoids these issues. We wrote our PyTorch implementation to take full advantage of the Wafer-Scale Engine’s size and enormous compute resources.

The Cerebras CS-2 is easy to program and uses a straightforward execution model. A host system loads programs onto the CS-2 system for execution, and it can target either an actual CS-2 or a software simulator for testing and debugging. The Cerebras compiler, running on the host, takes the PyTorch representation of a model and converts it into optimized low-level instructions for the CS-2 system. The result is sent to the CS-2 as a configuration file and run on the system. The execution model handles models of different sizes in different ways:

  • Small models such as BERT-Base can be replicated several times across the WSE and accelerated with data parallelism (Multi-Replica mode).
  • Medium-sized models such as BERT-Large, Transformer and GPT-2 fit entirely on the WSE and are trained with input data streamed to and from the wafer (Pipelined execution mode).
  • Extremely large models such as GPT-3 and GPT-J, which do not fit on the WSE, are trained by storing activations on the wafer and streaming the weights layer by layer to and from the WSE (Weight Streaming execution mode [7]).
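The choice between these three modes can be pictured as a rule of thumb about how much of the model fits on the wafer. The sketch below is a deliberately simplified, hypothetical model: the `ModelFootprint` type, the byte counts and the single-memory-budget rule are assumptions for illustration only, and the real Cerebras compiler also weighs compute, bandwidth and layer shapes.

```python
from dataclasses import dataclass

# Hypothetical, simplified model of the three execution modes.
# The single memory budget below is an illustrative assumption; it is not
# how the Cerebras compiler actually selects a mode.
WSE2_ON_CHIP_BYTES = 40 * 10**9  # "more than 40 gigabytes" of on-chip memory


@dataclass
class ModelFootprint:
    name: str
    weight_bytes: int      # parameters plus optimizer state (assumed on chip)
    activation_bytes: int  # activations held on chip during a training step


def choose_execution_mode(m: ModelFootprint, budget: int = WSE2_ON_CHIP_BYTES) -> str:
    total = m.weight_bytes + m.activation_bytes
    if total <= 0:
        raise ValueError("footprint must be positive")
    if total > budget:
        # Model does not fit: keep activations on the wafer and stream the
        # weights layer by layer.
        return "weight-streaming"
    if budget // total >= 2:
        # Room for at least two copies: replicate for data parallelism.
        return "multi-replica"
    # Fits exactly once: keep the whole model on the wafer, stream the data.
    return "pipelined"


footprints = [
    ModelFootprint("small", 2 * 10**9, 1 * 10**9),
    ModelFootprint("medium", 20 * 10**9, 10 * 10**9),
    ModelFootprint("huge", 700 * 10**9, 10 * 10**9),
]
for fp in footprints:
    print(fp.name, choose_execution_mode(fp))
```

With these made-up sizes, the three footprints map to multi-replica, pipelined and weight-streaming, respectively.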

In the sections below, we take a closer look at how we convert, or “lower”, the full PyTorch model to our intermediate representation (IR), which we rather creatively call the Cerebras Intermediate Representation – High (CIRH).

PyTorch Integration

There are many ways to integrate PyTorch, depending on the user’s goals: inference versus training, and support for part of the graph versus the entire graph. In this blog we focus on integration for training. For inference, we can use TorchScript or export to the open ONNX format, but that will be the topic of another blog.

Our first step is to extract the full graph. Once captured, the graph is compiled and then run on the Cerebras hardware.

The “lazy tensor” approach [1] gives us a way to capture the full graph of most models and run them at high performance. The approach does have limitations: control flow and loops are flattened into the captured graph, and dynamic shapes are difficult to handle. We will address these and other limitations in the future, but for now, most models can be supported at very high performance via the lazy tensor approach.
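To make the idea concrete, here is a toy, pure-Python model of lazy tracing. `LazyTensor`, `trace` and `materialize` are hypothetical names; the real PyTorch/XLA lazy tensor machinery [1] works at the level of ATen dispatch, not Python operator overloading. Note how the Python `for` loop disappears: only its unrolled sequence of operations is recorded, which is the control-flow flattening limitation mentioned above.

```python
# Toy model of lazy tensors: ops are recorded, not executed, until a value
# is actually needed. Hypothetical names; this is not PyTorch/XLA code.
from __future__ import annotations

import operator


class LazyTensor:
    trace: list[str] = []  # global record of every op, in the order issued

    def __init__(self, op: str, inputs: tuple = (), value: float | None = None):
        self.op, self.inputs, self.value = op, inputs, value
        LazyTensor.trace.append(op)

    @staticmethod
    def constant(v: float) -> LazyTensor:
        return LazyTensor("const", value=v)

    def __add__(self, other: LazyTensor) -> LazyTensor:
        return LazyTensor("add", (self, other))  # recorded, not computed

    def __mul__(self, other: LazyTensor) -> LazyTensor:
        return LazyTensor("mul", (self, other))

    def materialize(self) -> float:
        # Evaluate the recorded graph on demand; a real backend would first
        # compile the whole captured graph for the accelerator.
        if self.op == "const":
            return self.value
        fn = {"add": operator.add, "mul": operator.mul}[self.op]
        a, b = (t.materialize() for t in self.inputs)
        return fn(a, b)


x = LazyTensor.constant(2.0)
y = x
for _ in range(3):  # the Python loop itself is never recorded...
    y = y * x + x   # ...only its unrolled ops: mul, add, mul, add, mul, add
print(LazyTensor.trace)  # ['const', 'mul', 'add', 'mul', 'add', 'mul', 'add']
print(y.materialize())   # 30.0
```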

Despite some known weaknesses in PyTorch/XLA, we started there because it is a mature solution that is used in production on TPU-based systems [6]. Rather than implementing our own lazy tensor backend, we used a modified version of the XLA backend.

To minimize our dependency on XLA, we wrap each ATen operator in an XLA custom call [8] and import the resulting graph into MLIR in HLO [9], the XLA input language. We then lower each wrapped operator to the ATen dialect (by simply unwrapping it) and, in turn, lower the ATen dialect to CIRH.

Put simply, our approach is to piggyback on the XLA lazy tensor backend for PyTorch, and modify it to pass through the actual ATen graph without breaking it into more primitive operators. This enables us to switch easily to PyTorch lazy tensor [3] and torch-MLIR [2] when they are ready.
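The wrap, unwrap and lower steps can be sketched with plain data structures. In this hypothetical model, `Op`, the `cirh.*` op names and the `ATEN_TO_CIRH` table are invented for illustration; the only real detail borrowed is that an XLA custom call identifies its target by a `call_target_name`. The key property is that each ATen operator survives the trip through HLO as a single opaque node instead of being decomposed.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    dialect: str  # "hlo", "aten" or "cirh" in this toy model
    name: str
    attrs: dict = field(default_factory=dict)


def wrap_in_custom_call(aten_name: str) -> Op:
    # Lazy-tensor side: each ATen op becomes one opaque HLO custom call, so it
    # is never decomposed into more primitive HLO operators.
    return Op("hlo", "custom_call", {"call_target_name": aten_name})


def unwrap(op: Op) -> Op:
    # HLO custom call -> ATen dialect, simply by reading back the wrapped name.
    if op.dialect != "hlo" or op.name != "custom_call":
        raise ValueError(f"not a wrapped ATen op: {op}")
    return Op("aten", op.attrs["call_target_name"])


# Hypothetical one-to-one table; CIRH sits close to ATen in abstraction level.
ATEN_TO_CIRH = {"aten.matmul": "cirh.matmul", "aten.add": "cirh.add", "aten.gelu": "cirh.gelu"}


def lower_to_cirh(op: Op) -> Op:
    # ATen dialect -> CIRH.
    return Op("cirh", ATEN_TO_CIRH[op.name])


hlo_graph = [wrap_in_custom_call(n) for n in ("aten.matmul", "aten.add", "aten.gelu")]
cirh_graph = [lower_to_cirh(unwrap(op)) for op in hlo_graph]
print([op.name for op in cirh_graph])  # ['cirh.matmul', 'cirh.add', 'cirh.gelu']
```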

Supporting ATen directly also helps us support other entry points such as TorchScript. We are also exploring other approaches such as AOTAutograd, which was added as part of functorch, a set of JAX-like composable transforms for PyTorch [10].

Our long-term goal is to work with the community on the torch-MLIR project [2], which aims to bridge the PyTorch and MLIR ecosystems using ATen-based lazy tensors [3]. Our IR, CIRH, is MLIR-based and very close to ATen in its level of abstraction, so we prefer to map ATen directly to CIRH.

Cerebras Backend

Now that we have shown how we integrate with PyTorch, we can examine the Cerebras ML backend, which is shared by PyTorch and TensorFlow. It takes our CIRH dialect as input and runs the following optimization and compilation stages to produce the final kernel graph, ready to be executed on the CS-2 system:


  1. Graph rewrite: In this phase we run many optimization passes at the CIRH level, such as operator fusion and constant folding. The fused operators are replaced with our high-performance custom kernels. Rewriting the graph is complex, especially for training, so we use PDL [5], an MLIR dialect for graph rewrites, and we actively contribute larger features to PDL.
  2. If there are “left-over” operators, i.e. operators that are not matched to kernels, we lower them to our low-level IR, called LAIR (Linear Algebra IR), and proceed to the compiler stage.
  3. Automatic kernel generation: The compiler takes any left-over operators as input, runs optimization passes such as fusion, constant folding and common subexpression elimination, and then uses polyhedral techniques to auto-generate kernels that implement the semantics of the left-over ops.
  4. At this point we have a complete graph that can be placed on the wafer.
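The first two stages can be pictured with a toy, pure-Python model. It treats the graph as a linear sequence of op names (real graphs are DAGs), fuses known op sequences into kernels with greedy longest-match rewriting, and separates out the left-over ops that would be lowered to LAIR. The pattern table, kernel names and the `erfinv` example are hypothetical; PDL [5] expresses rewrite patterns declaratively inside MLIR rather than in Python.

```python
# Toy backend flow: (1) pattern-based fusion into known kernels,
# (2) collection of left-over ops that no hand-written kernel covers.
# Patterns and kernel names are hypothetical.

FUSION_PATTERNS = {
    ("matmul", "add", "gelu"): "fused_linear_gelu",
    ("matmul", "add"): "fused_linear",
}
KERNEL_LIBRARY = {"fused_linear_gelu", "fused_linear", "softmax"}


def rewrite(ops):
    """Greedy longest-match fusion over a linear sequence of op names."""
    patterns = sorted(FUSION_PATTERNS, key=len, reverse=True)
    out, i = [], 0
    while i < len(ops):
        for pat in patterns:
            if tuple(ops[i:i + len(pat)]) == pat:
                out.append(FUSION_PATTERNS[pat])
                i += len(pat)
                break
        else:  # no pattern matched at position i
            out.append(ops[i])
            i += 1
    return out


def partition(ops):
    """Split ops into those with a custom kernel and left-overs for LAIR."""
    matched = [op for op in ops if op in KERNEL_LIBRARY]
    leftover = [op for op in ops if op not in KERNEL_LIBRARY]
    return matched, leftover


graph = ["matmul", "add", "gelu", "matmul", "add", "softmax", "erfinv"]
fused = rewrite(graph)
kernels, leftovers = partition(fused)
print(fused)      # ['fused_linear_gelu', 'fused_linear', 'softmax', 'erfinv']
print(leftovers)  # ['erfinv']
```

Only `erfinv` falls through here; in the flow described above, such an op would go on to automatic kernel generation.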


Cerebras Runtime

Once the graph is compiled and sent to the wafer, we are ready to run the training workload. As shown below, the CS-2 runs only the neural network portion of the workload. Model compilation and training-loop management are handled by an external server called the “chief worker”, while input data preprocessing and streaming are handled by multiple “input workers”. To keep the WSE busy, we multiplex the data streams from all of the input workers.
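A generic way to picture this multiplexing is a fan-in queue: each input worker produces batches independently, and the consumer takes whichever batch is ready first. This is a hypothetical sketch using Python threads and the standard-library `queue` module, not the Cerebras runtime; a production version would also need to propagate worker errors, which this sketch does not do.

```python
import queue
import threading


def multiplex(workers, maxsize=8):
    """Merge batches from several input workers into a single stream.

    Each worker is an iterable of batches running in its own thread; whichever
    worker has a batch ready first gets it onto the shared queue. The bounded
    queue provides backpressure so fast workers cannot run far ahead.
    """
    q = queue.Queue(maxsize=maxsize)
    done = object()  # per-worker end-of-stream marker

    def pump(batches):
        for batch in batches:
            q.put(batch)
        q.put(done)

    for w in workers:
        threading.Thread(target=pump, args=(w,), daemon=True).start()

    remaining = len(workers)
    while remaining:
        item = q.get()
        if item is done:
            remaining -= 1
        else:
            yield item


merged = list(multiplex([range(0, 5), range(100, 103)]))
print(sorted(merged))  # [0, 1, 2, 3, 4, 100, 101, 102]
```

The interleaving between workers depends on thread timing, but each worker’s own batches keep their order.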

Metrics such as loss are streamed back from the CS-2 to the chief worker, and summary operators can be used to extract any value from the fabric. Model states are saved as standard checkpoints, and outputs are saved in standard formats that TensorBoard can visualize.

User Experience

A good API should stay out of the way. PyTorch has a well-designed, well-liked API, so our goal is to add only the minimum needed to extract the information we require, without imposing changes on the structure of the user’s code. The Cerebras PyTorch APIs live under the cerebras_torch namespace, and the changes to user code are minimal.

Below are the needed additions:

1. Specify the IP address of the Cerebras System

```python
import cerebras_torch as cbtorch
```

2. Wrap your PyTorch custom module

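```python
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Define layers here and implement forward() as usual.

# Wrap the standard nn.Module instance; the model definition is unchanged.
model = cbtorch.module(Model())
```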

3. Wrap your PyTorch dataloader

```python
# get_train_dataloader() is your existing function returning a PyTorch DataLoader.
train_dataloader = cbtorch.dataloader(get_train_dataloader())
```

Note that the input to cbtorch.dataloader(...) is the standard PyTorch data loader.

4. Wrap your training loop with cbtorch.Session

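```python
# model, optimizer, criterion and num_epochs are defined as in any PyTorch script.
global_step = 0
with cbtorch.Session(model, optimizer, criterion, mode="train") as session:
    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            # Standard forward pass, loss, backward pass and optimizer step go here.
            global_step += 1
```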

The code within the Session is the standard PyTorch train and eval loop.
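To illustrate why these additions can stay small, here is a hypothetical sketch of the two wrapping patterns in plain Python: a thin loader wrapper that forwards iteration, `len()` and attribute access to the standard loader it wraps, and a context manager that does setup before the loop and teardown after it. `WrappedLoader`, `session` and the log messages are invented for this sketch and do not describe the actual cerebras_torch implementation.

```python
from contextlib import contextmanager


class WrappedLoader:
    """Hypothetical thin wrapper in the spirit of cbtorch.dataloader(...)."""

    def __init__(self, loader):
        self._loader = loader

    def __iter__(self):
        # A real wrapper might route batches to the system's input pipeline here.
        yield from self._loader

    def __len__(self):
        return len(self._loader)

    def __getattr__(self, name):
        # Forward anything not defined here (e.g. batch_size) to the wrapped loader.
        return getattr(self._loader, name)


@contextmanager
def session(mode="train", log=print):
    """Hypothetical Session-like context: setup before the loop, teardown after."""
    if mode not in ("train", "eval"):
        raise ValueError(f"unknown mode: {mode!r}")
    log(f"setup: compile and load the {mode} graph")
    try:
        yield
    finally:
        log("teardown: save checkpoint and release the system")


# Usage: the loop body is ordinary Python code, unchanged by the wrappers.
events = []
loader = WrappedLoader([[1, 2], [3, 4], [5, 6]])
with session(mode="train", log=events.append):
    global_step = 0
    for step, batch in enumerate(loader):
        events.append(f"step {step}: {sum(batch)}")
        global_step += 1
print(events)
```

The design choice being illustrated is delegation: because the wrapper exposes the same interface as the object it wraps, the training loop does not need to know it is there.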

Emad Barsoum, Senior Director, AI Frameworks | April 13, 2022

Learn More
  1. Alex Suhan, Davide Libenzi, Ailing Zhang, Parker Schuh, Brennan Saeta, Jie Young Sohn, Denys Shabalin, “LazyTensor: combining eager execution with domain-specific compilers”, arXiv, 2021.
  2. Torch-MLIR project code repository.
  3. Lazy tensor staging code repository.
  4. Tutorial, “Extending dispatcher for a new backend in C++”, PyTorch, 2021.
  5. “PDL Dialect”, MLIR documentation.
  6. “Training PyTorch models on Cloud TPU Pods”, Google Cloud, 2022.
  7. Stewart Hall, Rob Schreiber, Sean Lie, “Training Giant Neural Networks Using Weight Streaming on Cerebras Wafer-Scale Systems”, Cerebras Systems, 2021, Booth Docs/CS Weight Streaming White Paper 111521.pdf
  8. “XLA Custom Calls”, TensorFlow, 2021.
  9. “XLA Architecture”, TensorFlow, 2021.
  10. “Functorch”, PyTorch.

(PyTorch, the PyTorch logo and any related marks are trademarks of Facebook, Inc.)