William Marshall, Machine Learning Engineer | June 1, 2022

An exciting new feature released in version 1.3 of the Cerebras Software Platform (CSoft) is Variable Tensor Shape (VTS) computation. This technology accelerates the training of Transformer natural language processing models by computing efficiently on data samples with heterogeneous shapes, rather than wasting time on the padding that traditional compute architectures require. In this article, I’d like to explain how VTS works and how to adjust your code to unlock this speedup.

Over the past few years, large Transformer-style models have become the mainstay of natural language processing (NLP). Since the publication of BERT [1] in 2018, the amount of compute and memory needed to train state-of-the-art language models has increased by several orders of magnitude (Figure 1). As models become more compute-intensive, it is increasingly important not to waste any computation.

Figure 1. Exponential growth of NLP model size and computational intensity for training (Cerebras)

Datasets used for training NLP models often contain samples of varied lengths, or “shapes”. In such situations, it is common practice to pad shorter sequences (tensors) to the maximum allowed sequence length so that samples can be batched together (Figure 2).

Figure 2. An example of padding input samples to be the same length. Text underlined in orange is used to compute the loss function. Padding text (underlined in gray) is ignored when computing the loss, but in many traditional settings is still passed through the entire model.

In practice, this means that padding makes up a large fraction of many datasets. For example, an average sample in the WMT 2014 English-to-German dataset [4], as set up in [3], is 87% padding and only 13% valid tokens. Figure 3 shows the breakdown between padding and real tokens for several other common datasets. In all of these applications, naively passing the padded data through the model wastes a significant amount of computation; conversely, handling padding more thoughtfully offers a large opportunity to speed up training.

Figure 3. Breakdown between padding and real tokens in common public datasets.
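The padding fractions above follow from simple arithmetic on sequence lengths. When every sample is padded to the same maximum length, the padding fraction is one minus the total number of real tokens divided by the total number of token slots. The following is a minimal sketch, assuming each sample is summarized only by its count of real tokens; the function name and the toy lengths are hypothetical and not drawn from any of the datasets above.

```python
def padding_fraction(lengths, max_len):
    """Fraction of token slots that are padding when every sample is padded to max_len.

    `lengths` holds the number of real tokens per sample (hypothetical toy input).
    """
    if any(n > max_len for n in lengths):
        raise ValueError("a sample is longer than max_len; it would have to be truncated first")
    total_slots = len(lengths) * max_len  # slots a padded model actually processes
    real_tokens = sum(lengths)            # slots that carry real tokens
    return 1.0 - real_tokens / total_slots


# Three hypothetical samples padded to 128 tokens: 48 + 16 + 64 = 128 real tokens
# out of 3 * 128 = 384 slots, so two thirds of the slots are padding.
print(padding_fraction([48, 16, 64], max_len=128))
```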

At Cerebras Systems, we have developed the capability to perform computation on data sequences that have heterogeneous shapes. This variable tensor shape (VTS) technique, sometimes referred to as “variable sequence length”, allows users to train models on Cerebras systems without wasting valuable FLOPs on padding tokens, resulting in a significant performance boost. To enable this feature, released in version R1.3 of the Cerebras Software Platform (CSoft), the user simply needs to add a few lines to the model code, as we show below.

Understanding Variable Tensor Shape (VTS) on Cerebras

In order to understand the VTS feature and its associated speedups, it is useful to first understand the pipeline execution mode used for training small- to medium-sized ML models on the Wafer-Scale Engine, which is the heart of our CS-2 system. The reader is likely familiar with the batch-parallel computation paradigm used to run ML models on traditional hardware like GPUs or TPUs. In this paradigm, samples are batched together, and a single GPU performs each operation on multiple samples simultaneously (Figure 4).

Figure 4. Illustration of batch-parallel execution. A batch of tensors is loaded together onto the GPU, and each operation is performed on the entire batch at the same time.

In contrast, in pipeline execution mode the Cerebras Software Platform (CSoft) uses pipeline parallelism: different areas of the Wafer-Scale Engine (WSE) are responsible for different sub-computations of the model. Samples are fed one by one through this pipeline, so at any given moment several sub-computations of the model are running at once, each acting on a different sample in a different physical area of the fabric (Figure 5).

Figure 5. Illustration of pipeline-parallel execution. Input tensors are fed one by one onto the WSE, and each area of the WSE is responsible for performing a single computation on a tensor before passing it on to the next area.

When variable tensor shape computations are enabled, the padding on the end of a tensor simply gets stripped away soon after the tensor enters the fabric (Figure 6). All subsequent computations on that tensor are performed on its shorter version, unless the user explicitly specifies that the padding should be added back onto the end of the tensor. This is possible because pipeline parallelism processes samples individually on the fabric: there is no need to represent groups of samples together, and the unique ability of the Cerebras dataflow architecture to achieve full utilization on matrix-vector multiplications makes it efficient to process samples one at a time. (Michael James, our Advanced Technologies Chief Architect, gives a great explanation of this in a blog post.)

Figure 6. Illustration of pipeline-parallel execution with VTS enabled. The padding on the end of a tensor simply gets stripped away soon after it enters the fabric.

Typically, if an application involves padding out data, the associated model has explicit code to ensure that the padding tokens are not used by any of the model's computations. For example, padding tokens are usually excluded from attention by adding a mask to the attention scores before the attention softmax, and loss terms corresponding to padding tokens are masked out. As a result, training a model with VTS differs only in its higher throughput: all the same operations are still performed, just on shorter tensors. If a user wants the model's computation to depend on the padding tokens, they can either not use VTS or selectively turn off VTS for parts of the model. In either case, the user can still rely on faithful execution of whatever model they have defined, and the model behaves identically apart from the increased throughput.
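The claim that masking makes padding irrelevant can be checked numerically. The sketch below, written in plain Python so it runs without any accelerator software, implements single-head scaled dot-product attention that sets the scores of padded keys to negative infinity before the softmax, which is one common masking convention. The function names and toy vectors are hypothetical. At the real positions, the outputs are identical whether the padding is kept and masked or stripped off entirely.

```python
import math


def softmax(scores):
    """Numerically stable softmax; a score of -inf receives exactly zero weight."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values, key_mask=None):
    """Single-head scaled dot-product attention over lists of vectors (hypothetical helper).

    key_mask[j] truthy marks a real token, falsy marks padding; padded keys get a
    score of -inf so the softmax assigns them zero weight.
    """
    scale = 1.0 / math.sqrt(len(queries[0]))
    outputs = []
    for q in queries:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        if key_mask is not None:
            scores = [s if m else -math.inf for s, m in zip(scores, key_mask)]
        weights = softmax(scores)
        outputs.append(
            [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]
        )
    return outputs


# Three real tokens followed by two padding tokens (toy 2-d embeddings).
padded = [[1.0, 0.5], [0.2, -0.3], [0.7, 0.1], [0.0, 0.0], [0.0, 0.0]]
mask = [1, 1, 1, 0, 0]

with_padding = attention(padded, padded, padded, key_mask=mask)[:3]  # keep real positions only
stripped = padded[:3]
without_padding = attention(stripped, stripped, stripped)

for row_a, row_b in zip(with_padding, without_padding):
    assert all(math.isclose(a, b) for a, b in zip(row_a, row_b))
print("masked attention on padded input matches attention on stripped input")
```

The same reasoning applies to a masked loss: terms with zero weight contribute nothing, so dropping them does not change the result.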

Performance Benefits

The performance benefits of VTS are highly dependent on the dataset being used, so looking at its impact in a controlled setup with synthetic data is a good way to get a sense of its overall usefulness. Figure 7 shows the relative throughput increase for the BERT-Large [1] model with synthetic data generated so that every sample has the same number of real (non-padding) tokens. As we reduce the fraction of each sample made up of non-padding tokens, throughput increases in inverse proportion to that fraction: halving the fraction of real tokens doubles throughput. In other words, the time a sample spends in the slowest pipeline stage of the model is directly proportional to the length of that sample.

Figure 7. Relative throughput increase for the BERT-Large [1] model with synthetic data.
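The trend in Figure 7 can be captured by an idealized cost model. If the time a sample spends in the slowest stage is proportional to its real length, then a dataset in which a fraction f of the token slots are real tokens should train about 1/f times faster with VTS. The sketch below encodes only that idealization; it is an assumption that ignores fixed per-sample overheads and the pipeline effects discussed later, and the function name is hypothetical.

```python
def ideal_vts_speedup(real_fraction):
    """Idealized throughput multiplier (hypothetical model): time per sample is assumed
    proportional to its real (non-padding) length, so throughput scales as 1 / real_fraction."""
    if not 0.0 < real_fraction <= 1.0:
        raise ValueError("real_fraction must be in (0, 1]")
    return 1.0 / real_fraction


for f in (1.0, 0.5, 0.25, 0.13):
    print(f"{f:.0%} real tokens -> up to {ideal_vts_speedup(f):.1f}x throughput")
```

For the WMT 2014 English-German setup quoted earlier (13% real tokens), this idealized model gives about 7.7x, compared with the measured 5x in Table 1.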

This performance transfers to real applications. For BERT-Large with a maximum sequence length of 512 tokens, training on a CS-2 system with the OpenWebText dataset [5] is 20% faster when using variable tensor shape computations. The performance benefits are greater for more heavily padded datasets. For example, Transformer-large [3] trained on the WMT 2014 English-to-German translation dataset gets an impressive 5x speedup from VTS. VTS also brings benefits beyond raw throughput. In [2], the authors of the T5 model use several dataset-preparation tricks to make each sample close to fully packed. In this setting, VTS lets us avoid packing sequences to the maximum length without a meaningful increase in total training time. This simplifies the input pipeline and removes the machine learning subtleties that surround data packing.

Model and dataset | VTS throughput multiplier
Transformer-large [3] trained on WMT 2014 English-German | 5x
BERT-Large [1] trained on OpenWebText | 1.2x
BERT-Large [1] trained on OpenWebText with bucketing | 1.4x
T5-base [2] trained on C4 | Similar throughput, easier implementation

Table 1. VTS performance results

One notable detail about these VTS-related throughput gains is that, unlike the results from the synthetic data setup in Figure 7, throughput increases on real-world datasets fall somewhat short of scaling as 1/f, where f is the fraction of the dataset that is not padding. Both the reason for this phenomenon and its remedy can be understood by diving deeper into the performance implications of pipeline parallelism. For concreteness, suppose that in the batch currently being processed, the first sample has a very long sequence length while the second sample has a very short one. Even though the short sample might take much less time to complete its current pipeline stage, it can’t move on to the next stage until the long sample ahead of it has been processed and has moved on to make room for the short sample.
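This blocking effect can be reproduced with a toy model. The sketch below simulates a pipeline in which every stage takes time proportional to a sample's length and, as in the scenario just described, a finished sample cannot advance until the sample ahead of it has vacated the next stage. This is a deliberately simplified, hypothetical model for illustration, not a description of how the WSE actually schedules work; the stage count and sample lengths are arbitrary.

```python
def pipeline_finish_time(lengths, num_stages):
    """Time at which the last sample leaves a toy blocking pipeline (hypothetical model).

    Toy assumptions, not the WSE's actual scheduler:
    * each of the num_stages stages spends `length` time units on a sample;
    * there is no buffering between stages, so a sample that has finished a
      stage waits there until the sample ahead of it vacates the next stage.
    """
    prev_depart = [0.0] * num_stages  # when the previous sample left each stage
    for length in lengths:
        depart = [0.0] * num_stages
        entry = prev_depart[0]  # stage 0 frees up once the previous sample leaves it
        for stage in range(num_stages):
            finished = entry + length
            if stage + 1 < num_stages:
                # Blocked until the sample ahead has moved out of the next stage.
                depart[stage] = max(finished, prev_depart[stage + 1])
            else:
                depart[stage] = finished
            entry = depart[stage]  # leaving one stage means entering the next
        prev_depart = depart
    return prev_depart[-1]


interleaved = [8, 1, 8, 1]  # short samples stuck behind long ones
grouped = [8, 8, 1, 1]      # the same samples, with similar lengths kept together
print(pipeline_finish_time(interleaved, num_stages=3))  # 41.0
print(pipeline_finish_time(grouped, num_stages=3))      # 34.0
```

Both orders contain exactly the same work, yet the interleaved order finishes later because each short sample sits idle behind a long one.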

This scenario shows that even a few long samples in a batch can slow down overall throughput; more generally, throughput can suffer whenever nearby samples have significantly different lengths. One corollary is that if we arrange for all samples in a batch to have similar lengths, we can further boost performance on real-world datasets. We implement this for BERT by first bucketing samples into groups of similar sequence length and then generating batches from within individual buckets. This technique yields an additional 13% throughput increase for BERT-Large trained on the OpenWebText dataset with a maximum sequence length of 512 tokens, for a total of 1.4x the throughput of the same model trained without VTS.
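The bucketing remedy can be sketched as follows: assign each sample to a length bucket, then form batches only from samples within the same bucket. This is a minimal, hypothetical sketch, not the Cerebras input pipeline; the bucket edges, the batch size, and the choice to shuffle within buckets and across batches are illustrative assumptions.

```python
import bisect
import random
from collections import defaultdict


def bucketed_batches(lengths, boundaries, batch_size, seed=0):
    """Group sample indices into batches whose lengths fall in the same bucket (hypothetical helper).

    boundaries: sorted bucket edges; a length n lands in bucket
    bisect_right(boundaries, n), so [128, 256] gives buckets
    n < 128, 128 <= n < 256, and n >= 256.
    """
    buckets = defaultdict(list)
    for index, n in enumerate(lengths):
        buckets[bisect.bisect_right(boundaries, n)].append(index)

    rng = random.Random(seed)
    batches = []
    for key in sorted(buckets):
        members = buckets[key]
        rng.shuffle(members)  # randomize order within a bucket
        for start in range(0, len(members), batch_size):
            batches.append(members[start:start + batch_size])  # last batch may be short
    rng.shuffle(batches)  # interleave buckets so training order is not sorted by length
    return batches


lengths = [500, 40, 130, 60, 480, 300, 20, 140]
for batch in bucketed_batches(lengths, boundaries=[128, 256], batch_size=2):
    print(batch, [lengths[i] for i in batch])
```

Shuffling the finished batches keeps each batch homogeneous in length while avoiding a training order that runs from short to long samples.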

Adapting Your Code for VTS

The VTS feature is very simple to add to an existing PyTorch model that has been ported to Cerebras. To learn how to port PyTorch models, refer to the blog post “Getting Started with PyTorch BERT Models on the Cerebras CS-2 System” by my colleagues Antonio Kim and Ryan Reece.

To enable VTS, add a call to cerebras.framework.torch.nn.StripPadding for each padded tensor. This custom PyTorch operation maps to a Cerebras kernel that removes the padding from the end of the tensor provided to it as input. It has the following signature:

    StripPadding(
        input,  # a torch.Tensor to strip the padding from
        mask,   # a binary mask defining which locations of `input` contain
                # valid tokens and which contain padding
    )

Using this function, we can easily adapt models to use VTS. For example, a BERT model similar to the one in our Reference Implementations can be adapted for VTS as follows:

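import torch
from cerebras.framework.torch.nn import StripPadding

class BertModel(torch.nn.Module):
    # model initialization code

    def forward(self, input_ids, attention_mask, masked_lm_positions,
                masked_lm_weights, mlm_labels):
        # Remove trailing padding: attention_mask marks real input tokens and
        # masked_lm_weights marks real masked-LM predictions.
        input_ids = StripPadding(input_ids, attention_mask)
        masked_lm_positions = StripPadding(masked_lm_positions, masked_lm_weights)
        labels = StripPadding(mlm_labels, masked_lm_weights)

        # remaining model code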
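To make the semantics of StripPadding concrete, here is a CPU reference in plain Python that mimics what the text describes for a single sample: dropping everything after the last position the mask marks as valid. It is a hypothetical stand-in for illustration only; the real operation is a Cerebras kernel that works on tensors, and the token ids below are toy values.

```python
def strip_trailing_padding(values, mask):
    """Hypothetical CPU stand-in, not the Cerebras kernel: drop the padded tail of
    one sample, keeping everything up to the last position where mask is nonzero."""
    if len(values) != len(mask):
        raise ValueError("values and mask must have the same length")
    last_valid = max((i for i, m in enumerate(mask) if m), default=-1)
    return values[: last_valid + 1]


# One padded sample: four real (toy) token ids followed by padding id 0.
input_ids = [101, 2023, 2003, 102, 0, 0, 0, 0]
attention_mask = [1, 1, 1, 1, 0, 0, 0, 0]
print(strip_trailing_padding(input_ids, attention_mask))  # [101, 2023, 2003, 102]

# MLM targets padded to a fixed number of predictions; the weights mark which are real.
masked_lm_positions = [2, 0, 0]
masked_lm_weights = [1.0, 0.0, 0.0]
print(strip_trailing_padding(masked_lm_positions, masked_lm_weights))  # [2]
```

This is why each padded input is paired with its own mask in the model code: the masked-LM tensors are padded to a different length than the token ids, so they are stripped according to masked_lm_weights rather than attention_mask.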


Real-world datasets are messy: they often contain samples of heterogeneous length. While this poses a challenge for traditional hardware, at Cerebras we see it as an opportunity. Leveraging our unique pipeline parallelism, we can handle this heterogeneity directly through our variable tensor shape feature. By changing just a few lines of code, users can enable VTS and cut their total training time substantially, unlocking easier model experimentation and a quicker path to their ultimate solution.





[1] Devlin et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, 2018. https://arxiv.org/abs/1810.04805
[2] Raffel et al. Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. Journal of Machine Learning Research, 2020. https://arxiv.org/abs/1910.10683
[3] Vaswani et al. Attention is All You Need. NeurIPS, 2017. https://arxiv.org/abs/1706.03762
[4] WMT 2014 English-German dataset. https://nlp.stanford.edu/projects/nmt/
[5] OpenWebText. https://github.com/jcpeterson/openwebtext
[6] English Wikipedia dumps. https://dumps.wikimedia.org/enwiki/