Antonio Kim and Ryan Reece, Machine Learning Engineering | April 14, 2022

Natural language processing has taken the machine learning community by storm over the past few years. Transformer architectures such as BERT and its many variants have gained wide recognition and adoption as a preferred model for sequence data in a variety of domains.

The PyTorch machine learning framework has also gained immense popularity within the machine learning community because it is intuitive and easy to use.

Naturally, this means that there is a huge demand for running transformer models such as BERT using PyTorch. We at Cerebras are constantly expanding our support for PyTorch models to provide a simple and easy way to port existing PyTorch models to run at high performance on Cerebras systems with just a few extra lines of code.

For a more detailed look at how the Cerebras Software Platform (CSoft) works, please see our Software Documentation.

In this blog, we cover the practical steps to get a PyTorch transformer model like BERT running on the CS-2. For details on how we support the PyTorch framework on our architecture, please read the blog “Supporting PyTorch on the Wafer-Scale Engine.”

Running PyTorch on a Cerebras System

Running any PyTorch model on a Cerebras system is straightforward. Many convenient wrappers are exposed in our API to adapt existing PyTorch training scripts and run models on a Cerebras system.

To start, import the PyTorch module from the Cerebras Python environment, where the convenience wrappers are housed.

import torch
import cerebras.framework.torch as cbtorch

Next, to establish a connection to the system, we need to call initialize. All that this function requires is the IP address of the Cerebras system you want to connect to.


Once the connection to the system is established, we need to prepare the model, dataloader and optimizer to be loaded onto the system. Let’s assume that we have these objects predefined as follows:

model: torch.nn.Module = ...            # your existing model
dataloader = ...                        # your existing dataloader
optimizer: torch.optim.Optimizer = ...  # your existing optimizer

Loading the model’s weights, the dataloader’s data and the optimizer’s state onto a Cerebras system can be done with the following calls:

model = cbtorch.module(model)
dataloader = cbtorch.dataloader(dataloader)
optimizer = cbtorch.optimizer(optimizer)

This is virtually all that is needed to enable training any PyTorch model on a Cerebras system. From here, the model can be trained using a typical training script/workflow.

Running PyTorch Models Performantly on a Cerebras System

While the above can be used to run a PyTorch model on a Cerebras system, there are some performance implications that can arise with poorly designed training scripts.

Multiple Workers

A Cerebras system is a network-attached accelerator that coexists with a CPU cluster consisting of one chief node and one or more worker nodes. To make full use of the computing power of the wafer-scale engine, we highly recommend using multiple workers to send in data so that there is enough input to keep the wafer fully utilized. This means PyTorch training code needs to be adapted to enable multi-worker coordination, which is easily done with the cbtorch.Session context manager, wrapped around the “training” portion of the code as shown below:

# initialization and loading onto system happens above

def train(...):
    # train the model for N epochs
    ...

with cbtorch.Session(dataloader, mode="train"):
    train(...)

This allows multiple workers to call the exact same training script, but only the chief worker is given control of the training loop, while the other workers are dedicated solely to sending data from the dataloader to the system.
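The division of roles can be pictured with a small, self-contained sketch. It is not Cerebras code: worker threads stand in for worker nodes, an in-process queue stands in for the input stream to the system, and the names feed, run_training_loop, NUM_WORKERS and BATCHES_PER_WORKER are hypothetical.

```python
import queue
import threading
from typing import List

NUM_WORKERS = 3          # hypothetical worker count
BATCHES_PER_WORKER = 4   # hypothetical number of batches each worker supplies

# Stand-in for the input stream that feeds the accelerator.
stream: "queue.Queue[str]" = queue.Queue()


def feed(worker_id: int) -> None:
    # Worker role: only push input batches, never drive the training loop.
    for i in range(BATCHES_PER_WORKER):
        stream.put(f"worker{worker_id}-batch{i}")


def run_training_loop(expected: int) -> List[str]:
    # Chief role: owns the loop; each iteration consumes one streamed batch.
    return [stream.get() for _ in range(expected)]


workers = [threading.Thread(target=feed, args=(w,)) for w in range(NUM_WORKERS)]
for t in workers:
    t.start()
consumed = run_training_loop(NUM_WORKERS * BATCHES_PER_WORKER)
for t in workers:
    t.join()
```

Every process runs the same script, but only one of them iterates over steps. The others exist purely to keep the input stream full.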

Step Closures

A step is everything that happens inside a single iteration of the training loop. Retrieving data from the system must happen between iterations so that it does not significantly impact training performance. To make this easier and less intrusive, we expose a step closure API that queues up functions to be run between iterations.

We first must import the cb_model library which contains a number of useful utility functions:

import cerebras.framework.torch.core.cb_model as cm

An arbitrary function and its arguments can then be added to the queue for execution at a step boundary as follows:

cm.add_step_closure(closure, args=(...))

For added convenience, we also introduce a function decorator that can be used as follows:

def closure(...):

This automatically wraps any calls to closure in a step closure and queues it up for execution between training iterations.
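To make the idea of deferring work to a step boundary concrete, here is a minimal pure-Python sketch of the pattern. It is a toy stand-in rather than the Cerebras implementation: the class StepClosureQueue, its methods, and the report function are all hypothetical names chosen for illustration.

```python
from typing import Any, Callable, List, Tuple


class StepClosureQueue:
    """Toy step-closure queue (hypothetical, not the Cerebras API)."""

    def __init__(self) -> None:
        self._pending: List[Tuple[Callable[..., Any], tuple]] = []

    def add(self, fn: Callable[..., Any], args: tuple = ()) -> None:
        # Record the call; nothing runs yet.
        self._pending.append((fn, args))

    def run_all(self) -> None:
        # Called once at the step boundary: run queued calls in FIFO order.
        pending, self._pending = self._pending, []
        for fn, args in pending:
            fn(*args)

    def step_closure(self, fn: Callable[..., Any]) -> Callable[..., None]:
        # Decorator form: calling the wrapped function only queues it.
        def wrapper(*args: Any) -> None:
            self.add(fn, args)
        return wrapper


closures = StepClosureQueue()
log: List[str] = []


@closures.step_closure
def report(step: int, loss: float) -> None:
    log.append(f"step={step} loss={loss:.2f}")


for step, loss in enumerate([0.9, 0.7, 0.5]):
    report(step, loss)        # queued, not executed
    assert len(log) == step   # nothing from this step has run yet
    closures.run_all()        # step boundary: queued closures execute
```

The point of the pattern is that the body of the loop never blocks on reporting. The queued functions only run once the iteration has finished.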

Running BERT models on a Cerebras System

Now that we’ve looked at the general steps for running PyTorch models on a Cerebras system, we can easily apply them to a model such as BERT.

Cerebras Reference Implementations

We share implementations of BERT and a few other models for our users in a public GitHub repository, Cerebras Reference Implementations. The reference implementations make use of the APIs mentioned above and follow best practices for maximizing performance on the CS-2. All the code needed to train and evaluate the model is provided, including the model, the data loaders and the data preparation scripts. We include model references in both TensorFlow and PyTorch.

Here we will describe the recently added BERT reference model implemented in PyTorch. Each model has a directory structure like the one for BERT shown below:

$ cd cerebras_reference_implementations/bert/pytorch
$ ls -1

The top level model implementation and dataloader is in the and modules, respectively. The script is mainly a high level wrapper that imports the data loader specified in the configuration file and generates an instance of the data loader with the appropriate parameters. A set of standard model configurations, or variants, are included in the configs directory. We provide the script for running training and eval modes on the Cerebras Wafer Scale Engine. The same scripts can also be used to run on other platforms – CPU and GPU.

Model Implementation

BERT is a multi-purpose sequence model based on the encoder of the Transformer architecture. BERT is pre-trained with two final head layers that calculate terms in the loss, one that does Masked Language Modeling (MLM), and one that does Next Sentence Prediction (NSP).

We base our implementation of BERT on the popular open-source Transformers library by Hugging Face. For the most part the implementation we include is very close to what is distributed by Hugging Face. We made some small changes to the implementation to make it more performant on Cerebras’ unique architecture.

Our implementation of BERT can be found at bert/pytorch/

Dataset and Dataloaders

The dataset used in our reference implementation for BERT is a pre-processed version of the OpenWebText dataset, in which we have tokenized the data into word pieces and saved them to CSV files. Both the dataset and the dataloaders are configured using the configuration YAML files found at bert/pytorch/configs/.

There are a number of configurations under that directory for the BERT-Base and BERT-Large variants with maximum sequence lengths of 128 and 512 tokens. To minimize training time, we perform the initial 90% of training at the smaller maximum sequence length of 128 and the final 10% at a maximum sequence length of 512.
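As a quick arithmetic illustration of the 90/10 split, the sketch below divides a step budget between the two sequence lengths. The helper split_steps and the budget of 1,000,000 total steps are assumptions made for this example only.

```python
from typing import Tuple


def split_steps(total_steps: int, short_percent: int = 90) -> Tuple[int, int]:
    """Return (steps at the short sequence length, steps at the long one)."""
    # Integer arithmetic avoids floating-point rounding in the step counts.
    short_steps = total_steps * short_percent // 100
    return short_steps, total_steps - short_steps


# Hypothetical budget of 1,000,000 total steps, chosen only for illustration.
msl128_steps, msl512_steps = split_steps(1_000_000)
print(msl128_steps, msl512_steps)
```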

Our implementation of the BERT dataloader can be found at bert/pytorch/input/

Configuration YAMLs

The configuration files mentioned above also contain the parameters used to configure both the model and the run itself. Please see the configuration for BERT-Base MSL128 for an example of the available configuration parameters. Below are just some examples of parameters that can be configured.

Train/Eval Input

Parameters that can be configured include the maximum sequence length, the maximum predictions per sequence, the batch size, etc.

max_sequence_length: 128
max_predictions_per_seq: 20
batch_size: 256


Model

Parameters that can be configured include the number of hidden layers, the dropout rate, the maximum position embeddings, etc.

num_hidden_layers: 12
dropout_rate: 0.1
max_position_embeddings: 512
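The input and model parameters constrain each other, and a few sanity checks can catch mismatches before a run starts. The sketch below is illustrative only: check_config is a hypothetical helper, the dictionaries simply mirror the values shown above, and the 15% masking rate is an assumption taken from the BERT paper rather than from these configuration files.

```python
import math
from typing import List

# Values copied from the input and model snippets above.
train_input = {"max_sequence_length": 128, "max_predictions_per_seq": 20, "batch_size": 256}
model = {"num_hidden_layers": 12, "dropout_rate": 0.1, "max_position_embeddings": 512}

MASKED_LM_PROB = 0.15  # assumption: masking rate from the BERT paper


def check_config(train_input: dict, model: dict) -> List[str]:
    """Return a list of human-readable problems (empty if none are found)."""
    problems = []
    # Every token position needs an entry in the position embedding table.
    if train_input["max_sequence_length"] > model["max_position_embeddings"]:
        problems.append("sequence length exceeds the position embedding table")
    # Roughly 15% of tokens are masked, so the prediction cap should cover that.
    expected = math.ceil(train_input["max_sequence_length"] * MASKED_LM_PROB)
    if train_input["max_predictions_per_seq"] < expected:
        problems.append(f"max_predictions_per_seq is below {expected}")
    if not 0.0 <= model["dropout_rate"] < 1.0:
        problems.append("dropout_rate must be in [0, 1)")
    return problems


print(check_config(train_input, model))
```

With the values shown, a sequence length of 128 fits within the 512 position embeddings, and 20 predictions per sequence covers about 15% of 128 tokens.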


Optimizer

Parameters that can be configured include the learning rate (scheduler), the loss scaling factor, the maximum gradient norm, etc.

    - scheduler: "Linear"
      initial_learning_rate: 0.0
      end_learning_rate: 0.0001
      steps: 10000
    - scheduler: "Linear"
      initial_learning_rate: 0.0001
      end_learning_rate: 0.0
      steps: 1000000
loss_scaling_factor: "dynamic"
max_gradient_norm: 1.0

In the above example, we are “warming up” the learning rate for the first 10,000 steps and then “cooling it down” for the next million steps.
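The shape of this schedule can be reproduced with a few lines of plain Python. This is a sketch under the assumption that the two "Linear" phases run back to back and that the rate holds its final value afterwards. The function learning_rate and the SCHEDULE constant are hypothetical names, not part of the Cerebras API.

```python
from typing import List, Tuple

# (initial_learning_rate, end_learning_rate, steps) for each phase,
# copied from the YAML above.
SCHEDULE: List[Tuple[float, float, int]] = [
    (0.0, 0.0001, 10_000),       # warm-up
    (0.0001, 0.0, 1_000_000),    # cool-down
]


def learning_rate(step: int, schedule: List[Tuple[float, float, int]] = SCHEDULE) -> float:
    """Linearly interpolate within whichever phase contains `step`."""
    for initial_lr, end_lr, steps in schedule:
        if step < steps:
            fraction = step / steps
            return initial_lr + (end_lr - initial_lr) * fraction
        step -= steps  # move on to the next phase
    # Past the last phase: hold the final value (an assumption of this sketch).
    return schedule[-1][1]


for s in (0, 5_000, 10_000, 510_000):
    print(s, learning_rate(s))
```

The rate peaks at 0.0001 at step 10,000, exactly where the first phase hands over to the second.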

Run Configuration

Parameters that can be configured include the maximum number of steps, the step frequency at which to take checkpoints, etc.

max_steps: 900000
checkpoint_steps: 5000

Training execution

We provide a convenient run script that can be used to configure and start a training run on a Cerebras system. The common implementation can be found at common/pytorch/

This run script implements all the steps that are required to train the BERT model on a Cerebras system:

  1. The initialization can be found at common/pytorch/
  2. The model is initialized at common/pytorch/
  3. The model is loaded onto the system at common/pytorch/
  4. The dataloader is initialized at common/pytorch/
  5. The data is loaded onto the system at common/pytorch/

How to run

Test your model implementation first on a CPU

For debugging purposes, it is helpful to run training of the model on a CPU before compiling for a CS system. We provide the Slurm scheduler script csrun_cpu to submit a CPU training job within our CS Support cluster, like this:

csrun_cpu python-pt --mode train \
    --params configs/bert_base_MSL128.yaml

Compiling the model for CS

When you have validated that your model works on CPU, you are ready to compile the model to run on a Cerebras WSE. Compilation for WSE also happens on the CPU chief node of our CS Support cluster. It uses the same command as running on CPU, but with a --compile_only option as shown below:

csrun_cpu python-pt --mode train \
    --params configs/bert_base_MSL128.yaml \
    --compile_only

Training/Evaluating the model on a CS system

Training is the process of learning the weights via forward and backward passes of the model over the given input data. Evaluation is typically run after training, using the learned weights and a test dataset to assess whether the model has learned well. The evaluation process returns metrics such as loss, accuracy, perplexity or F1 score (a statistical measure of the accuracy of a test or model).
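For readers less familiar with these metrics, the sketch below shows the standard definitions of perplexity and F1 score in plain Python. The function names are hypothetical and this is not how the reference implementation computes them. It assumes natural-log cross-entropy losses and at least one true positive.

```python
import math
from typing import Sequence


def perplexity(token_losses: Sequence[float]) -> float:
    """Exponential of the mean per-token cross-entropy (natural-log losses)."""
    return math.exp(sum(token_losses) / len(token_losses))


def f1_score(true_pos: int, false_pos: int, false_neg: int) -> float:
    """Harmonic mean of precision and recall (assumes true_pos > 0)."""
    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    return 2 * precision * recall / (precision + recall)


print(perplexity([math.log(4)] * 3))
print(f1_score(true_pos=8, false_pos=2, false_neg=2))
```

A lower perplexity means the model assigns higher probability to the held-out tokens. An F1 score closer to 1 means precision and recall are both high.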

When you are ready to train or evaluate on a CS system, we provide a Slurm scheduler script, csrun_wse, that encapsulates the call to the Python script. To run the script, you must specify the path to your configuration YAML, the IP and port of your system, and the mode to be run (train/eval). For example:

csrun_wse python-pt \
    --params configs/bert_base_MSL128.yaml \
    --cs_ip <IP:port> \
    --mode train


Now we have covered the steps needed to get a PyTorch transformer model like BERT running on the CS-2. As you’ve seen, it’s a straightforward process, using the convenient wrappers in our API to adapt existing PyTorch training scripts for our systems.
We hope you found this walkthrough useful. If you have comments or questions, please reach out to

Release 1.2 of the Cerebras Software Platform (CSoft) expands PyTorch BERT support and unlocks Cerebras Weight-Streaming for extreme-scale models. Learn more in this blog.



  1. Devlin, J. et al. (2018). “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding” (arXiv:1810.04805)
  2. Cerebras Software Documentation
  3. Cerebras Reference Implementations code repository (GitHub)
  4. Vaswani, A. et al. (2017). “Attention Is All You Need” (arXiv:1706.03762)
  5. Wolf, T. et al. (2019). “Hugging Face’s Transformers: State-of-the-art Natural Language Processing” (arXiv:1910.03771)
  6. Hugging Face transformers code repository (GitHub)
  7. OpenWebTextCorpus