Large language models are costly to set up, train, and deploy. Weight sparsity, when coupled with hardware that accelerates unstructured sparsity, is a promising way to cut inference time, speed up training, and reduce memory requirements. We trained extremely sparse GPT-3 1.3B-parameter models via iterative pruning with unstructured weight sparsity on the Cerebras CS-2 system using the Pile dataset, including an 83.8% sparse model with a 3x reduction in inference FLOPs¹, a 4.3x reduction in parameters, and no degradation in loss.

In an industry in which inference costs are prohibitively high, we show that large-scale language models can be pruned to high levels of sparsity while maintaining accuracy competitive with their dense counterparts. This is the first of many sparse training experiments possible on Cerebras hardware. We will discuss future directions and the promise of directly training extremely sparse models at a fraction of the FLOPs.

Introduction & Motivation

Large-scale Transformer-based language models have led to major advances in a variety of tasks, including question answering, summarization, code generation, and protein structure prediction. These models – the largest of which have hundreds of billions of parameters – require massive amounts of compute and memory to train and deploy. OpenAI’s GPT-3 175B-parameter model, for example, was trained with 3.14×10²³ FLOPs (floating point operations), requiring 10,000 NVIDIA V100 GPUs for 15 days, consuming an estimated 1,287 MWh of energy, and accounting for 552 tons of CO₂e emissions [Patterson et al.]. Cluster setup aside, the compute required for a single training run of this model would cost more than $3M on AWS² today. In addition, the trained model would require hundreds of gigabytes of memory [Lambda], making it challenging to use directly in resource-constrained edge applications.

These compute and memory requirements, along with the difficulties of setting up and training on highly distributed GPU clusters, put large language models out of reach for many AI groups looking to make breakthroughs in their respective domains.

Weight sparse training methods set subsets of weights to zero during training – often the ones which are already close to zero in magnitude. The resulting sparse model (see Figure 1) requires far fewer FLOPs to train and fewer parameters to store, as multiplies with zeros can be skipped on both forward and backward passes through the network. Coupled with the Cerebras CS-2, which automatically takes advantage of sparsity, this reduction in FLOPs could significantly accelerate training and reduce inference time.
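
To make the idea concrete, below is a minimal NumPy sketch of magnitude pruning: zero out the fraction of weights in a layer that are smallest in absolute value. This is an illustration only, not the CS-2 implementation; the function name and interface are our own.

```python
import numpy as np

def magnitude_prune(weights: np.ndarray, sparsity: float) -> np.ndarray:
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    k = int(sparsity * weights.size)              # number of weights to zero
    if k == 0:
        return weights.copy()
    threshold = np.sort(np.abs(weights), axis=None)[k - 1]
    mask = np.abs(weights) > threshold            # keep only larger-magnitude weights
    return weights * mask

# Example: prune a random 2048x2048 projection matrix to ~83.8% sparsity.
w = np.random.randn(2048, 2048)
w_sparse = magnitude_prune(w, sparsity=0.838)
print(f"actual sparsity: {np.mean(w_sparse == 0):.3f}")
```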

This is an important result for those looking to deploy large language models, which could otherwise have higher than desired latencies. With inference comprising a large portion of industry machine learning workloads, even methods which spend more compute during training are of interest if they lead to energy savings during inference [Patterson et al.].

Figure 1. Applying weight sparsity to a dense neural network by zeroing weights (“pruning” connections within a network).

Finding and training sparse models that match the accuracy of their original “dense” (i.e., non-sparse) configurations is an exciting and open area of research. In “Comparing Rewinding and Fine-Tuning in Neural Network Pruning,” Renda et al. extend previous work on pruning and weight rewinding to find a sparse model over multiple iterations of training. In the first iteration, the dense model is fully trained to target accuracy. Then the 20% of its weights closest to zero are “pruned” (set to zero), and the model is retrained for a full training run with the learning rate schedule rewound back to step 0. This is followed by pruning 20% of the remaining weights, then rewinding and retraining, and so on. Each iteration yields a model sparser than the previous one. This simple method produces sparse models that match the accuracy of their dense counterparts for models including ResNet-56 (852K parameters) on CIFAR-10, ResNet-50 (25.5M parameters) on ImageNet, and GNMT (165M parameters) on WMT16.
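
In code, one round of this prune-rewind-retrain loop might look like the simplified sketch below. It is our own rendering of the procedure, not the training code used in this work: the toy “model” is a single weight matrix and the `train` function is a placeholder for a full training run (a real implementation would also keep pruned weights fixed at zero by re-applying the sparsity mask).

```python
import numpy as np

def prune_fraction_of_remaining(w: np.ndarray, fraction: float) -> np.ndarray:
    """Zero out `fraction` of the currently nonzero weights, smallest magnitudes first."""
    nonzero_magnitudes = np.abs(w[w != 0])
    threshold = np.quantile(nonzero_magnitudes, fraction)
    return np.where(np.abs(w) > threshold, w, 0.0)

def train(weights: np.ndarray, lr_schedule: np.ndarray) -> np.ndarray:
    """Placeholder for a full training run over the (rewound) learning rate schedule."""
    return weights

weights = np.random.randn(512, 512)              # toy stand-in for the model's weights
lr_schedule = np.linspace(1e-3, 0.0, num=1000)   # hypothetical decaying schedule

weights = train(weights, lr_schedule)            # 1. fully train the dense model
for round_idx in range(8):                       # 2. repeatedly prune, rewind, retrain
    weights = prune_fraction_of_remaining(weights, fraction=0.20)
    weights = train(weights, lr_schedule)        # learning rate rewound to step 0
    print(f"round {round_idx + 1}: sparsity = {np.mean(weights == 0):.3f}")
```

Each round removes 20% of what is left, so eight rounds leave roughly 0.8⁸ ≈ 17% of the original weights, i.e. about 83% sparsity.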

In this work, we applied pruning and learning rate rewinding to a GPT-3 1.3 billion parameter model. We first trained a dense model from scratch on the Pile, an 800 GB language modeling dataset, and then iteratively pruned and retrained it to find sparse models. Our 83.8% sparse model represents a 3x reduction in inference FLOPs without any degradation in validation loss. All models were trained on a single CS-2 system, without requiring user-level distributed code or setup, and are competitive with the original dense model on the Pile, LAMBADA, and WikiText103.

Model, Data, and Method

We trained a GPT-3 1.3B-style model with 24 hidden layers, 16 heads, a hidden size of 2048, a filter size of 8192, standard attention, and a vocabulary of 50,267. We trained it on the Pile, a large language modeling dataset compiled from various sources including PubMed, OpenWebText2, Wikipedia, GitHub, and BookCorpus2. We trained for ~26B tokens, based on the ~20 tokens-per-parameter rule of thumb from DeepMind’s “Training Compute-Optimal Large Language Models.” Figure 2 shows our training loss curve from TensorBoard.
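
For reference, the token budget and a rough training-compute estimate follow from standard back-of-envelope rules: the ~20 tokens-per-parameter heuristic from Hoffmann et al. and the common C ≈ 6·N·D approximation for training FLOPs. The sketch below is our own arithmetic, not an exact accounting of this training run.

```python
n_params = 1.3e9                      # GPT-3 1.3B
tokens = 20 * n_params                # ~20 tokens per parameter (Chinchilla heuristic)
train_flops = 6 * n_params * tokens   # common C ~= 6*N*D approximation

print(f"token budget:   {tokens / 1e9:.0f}B tokens")   # ~26B tokens
print(f"training FLOPs: {train_flops:.1e}")            # ~2.0e+20
```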

Figure 2. TensorBoard training cross entropy loss for the baseline dense GPT-3 1.3B model on the Pile. Trained on a single CS-2 system.

In each iteration of training, we prune the model by removing low-magnitude weights and then train it to recover performance. Unlike Renda et al., where all model weights are sparsified together, we sparsify each layer separately. We only sparsify weights in the projection layers (QKV projections, output attention projections) and feedforward network layers, and do not sparsify other variables such as embeddings or biases. For a model with roughly 83.8% sparsity per projection and feedforward network, this results in a 3x FLOPs reduction and a 4.3x reduction in total parameters. At every iteration, we sparsify 20–33% of the remaining dense weights. Following sparsification, we shuffle the dataset and continue to train the model. Rather than always rewinding the learning rate to step 0 and training for the full dense training budget, we often rewind to a learning rate later in the schedule and then train the model for as few as ~20K steps (for the least sparse model) up to ~120K steps (for the most sparse models).
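
To see roughly where the 3x and 4.3x figures come from, here is a back-of-envelope calculation under the configuration above, assuming a 2048-token context, uniform 83.8% sparsity across the sparsified layers, a dense final logit layer, and ignoring biases and layer norm parameters. It is our own approximation in the spirit of footnote 1, not an exact accounting.

```python
d_model, d_ff, n_layers, n_heads, vocab, seq_len = 2048, 8192, 24, 16, 50267, 2048
sparsity = 0.838

# Parameters that get sparsified: QKV + output projections and the feedforward network.
sparsifiable = n_layers * (4 * d_model * d_model + 2 * d_model * d_ff)   # ~1.21B
# Parameters left dense: token and position embeddings.
dense_only = vocab * d_model + seq_len * d_model                         # ~0.11B

params_dense = sparsifiable + dense_only
params_sparse = (1 - sparsity) * sparsifiable + dense_only
print(f"parameter reduction:       {params_dense / params_sparse:.1f}x")  # ~4.3x

# Per-token forward-pass FLOPs (2 FLOPs per multiply-add), excluding embedding
# lookups; attention-score math and the final logit layer stay dense.
flops_sparsifiable = n_layers * (2 * 4 * d_model**2 + 2 * 2 * d_model * d_ff)
flops_dense_only = (n_layers * (2 * 2 * seq_len * d_model + 3 * n_heads * seq_len)
                    + 2 * d_model * vocab)
flops_base = flops_sparsifiable + flops_dense_only
flops_sparse = (1 - sparsity) * flops_sparsifiable + flops_dense_only
print(f"inference FLOPs reduction: {flops_base / flops_sparse:.1f}x")     # ~3x
```

The FLOPs reduction is smaller than the parameter reduction because the attention score computation and the final logit layer are not sparsified.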

Results

We measured the effectiveness of iterative pruning by evaluating the resulting sparse models using several metrics: Pile loss, zero-shot LAMBADA accuracy, and zero-shot WikiText103 perplexity. We show that the GPT-3 1.3B dense model can be iteratively pruned to high degrees of sparsity without significant reduction in quality across these metrics.
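
As an illustration of what the zero-shot LAMBADA metric measures (predicting the final word of a passage from the preceding context), here is a simplified sketch using an off-the-shelf Hugging Face GPT-2 model. It is not the Megatron harness used for the numbers reported below, and details such as detokenization and multi-token word handling differ between harnesses.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

def last_word_correct(passage: str) -> bool:
    """Greedily predict the final word of the passage from its preceding context."""
    context, target = passage.rsplit(" ", 1)
    target_ids = tokenizer.encode(" " + target)      # last word may span several tokens
    ids = tokenizer.encode(context) + target_ids
    with torch.no_grad():
        logits = model(torch.tensor([ids])).logits[0]
    preds = logits[:-1].argmax(dim=-1)               # logits at position i predict token i+1
    return preds[-len(target_ids):].tolist() == target_ids

# Toy example; the real benchmark averages this over the LAMBADA test set.
examples = ["The quick brown fox jumps over the lazy dog"]
accuracy = sum(last_word_correct(p) for p in examples) / len(examples)
print(f"zero-shot last-word accuracy: {accuracy:.2%}")
```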

We first evaluated these models on the validation set of the Pile. The horizontal line in Figure 3 corresponds to the validation loss reached by the baseline dense model. Each color represents a different level of sparsity (in the projection matrices) and has three validation points – one at the start of training, one in the middle, and one at the end. As we trained each sparse model, the loss decreased, recovering and sometimes surpassing the baseline loss. This results in a sawtooth pattern, where each valley corresponds to a trained sparse model that can be used for inference! Remarkably, even the model with 83.8% sparsity matches its dense counterpart (which has ~6.2x more trainable parameters in its projection and feedforward layers).

Figure 3. Cross entropy on the Pile validation set vs. training steps (lower is better). Each vertical line represents the increase in loss when a model is sparsified (this loss is recovered as the model is trained). Sparsity levels refer to the average sparsity percentage across projection and feedforward layers in GPT-3. Evaluation is done on 380.1M tokens, roughly the full validation set of the Pile. Training and evaluation were done on a Cerebras CS-2 system.

Table 1 below shows inference FLOPs reductions for a variety of the sparse models trained and the evaluation loss they achieve on the Pile. We also report zero-shot downstream results of the sparse models on LAMBADA and WikiText103, obtained without any additional fine-tuning of the dense or sparse models. LAMBADA and WikiText103 evaluation results were gathered via Megatron (and are in the process of being validated and compared against other evaluation harnesses).

When coupled with the Cerebras CS-2 system, which can take advantage of such sparsity, these models could significantly reduce inference latencies. Note that while sparse model training was stopped after crossing the baseline loss on the Pile, these results could improve with further training.

Table 1. Pile cross entropy on the validation set, accuracy on LAMBADA, and adjusted perplexity on WikiText103 for the trained sparse models, along with inference FLOPs reductions relative to the baseline dense model. The arrow next to each metric indicates which direction is better (e.g., an up arrow means higher is better). Training was done on a Cerebras CS-2 system.

It is worth noting that we trained our models on a shorter training schedule (26B tokens) than other existing published models, because we based our training on the compute-optimal training recommendations from Hoffmann et al. (see their Table 3). While neither of the following published models is compute-optimal, Megatron’s GPT-2 355M model trained for 157B tokens achieves 45.18% accuracy on LAMBADA and an adjusted perplexity of 19.31 on WikiText103, and GPT-Neo 1.3B trained for 380B tokens achieves 1.818 cross entropy³ on the Pile and 57.23% accuracy on LAMBADA. We believe this gap is largely due to those models being trained on an order of magnitude more data. Furthermore, the sparse models in iterative pruning see more training data than the baseline, so part of their gains could come from the additional training. Even so, under a constrained compute budget where one must both train and prepare models for deployment, iterative pruning is one way of discovering models with extreme sparsity and high accuracy.

Where to Go from Here

In their work on the Lottery Ticket Hypothesis, Frankle and Carbin demonstrate that for certain models, sparse subnetworks exist that can be trained, from initialization, to similar accuracy as their original dense networks. Demonstrating this for non-trivial sparsity levels for GPT-3 style models could be a promising step towards showing the existence of sparse subnetworks within large language models.

Along with exploring the Lottery Ticket Hypothesis, finding better ways of choosing which weights to prune could improve accuracies of sparse models or help find sparse models sooner. Designing sparse training methods to find optimal sparse weight sets without iterative pruning could produce competitive models while also reducing the time to train large language models.

Conclusion

As language models grow in size, cost, and energy consumption, and become harder to set up and use, they move out of reach for many AI groups looking to train and deploy them. With the Cerebras CS-2 enabling unstructured sparsity research and language model training at scale, we have taken our first steps toward exploring sparse models and the impact they could have on industry applications. We successfully trained unstructured sparse 1.3 billion parameter GPT-3 models on Cerebras CS-2 systems and demonstrated that these models achieve competitive results at a fraction of the inference FLOPs: our 83.8% sparse model achieves a 3x reduction in FLOPs while matching performance on the Pile. This sets the stage for future work in sparse training, pruning, and inference.

Footnotes

¹ FLOPs are calculated following the methodology in Appendix F of “Training Compute-Optimal Large Language Models,” excluding embedding FLOPs and adjusting for the amount of sparsity in each layer.

² NVIDIA’s 2021 post on Megatron estimates that “the GPT-3 model with 175 billion parameters can be trained in just over a month” with 1,024 A100s. Using AWS pricing for p4d.24xlarge instances (8 A100s each), this works out to 1,024 GPUs / 8 GPUs per instance × 30 days × 24 hours/day × $32.77/hour ≈ $3.02M (per AWS, on-demand and spot pricing were the same at the time of writing). According to NVIDIA, for customers who only have access to a single DGX A100, training GPT-3 would take almost 15 years.

³ GPT-Neo’s repo reports a perplexity of 6.159 on the Pile; ln(6.159) ≈ 1.818 cross entropy. Note that perplexity can be affected by tokenization, context window length, padding, and the validation subset chosen.

Contributors

Anshul Samar wrote this post, performed the analysis, and conducted model evaluations.
Eugene Vecharynski designed the study and conducted iterative pruning training runs.
Jay Jagtap and Kevin Leong supported training infrastructure, testing, and running experiments.
Joel Hestness managed the project and trained the initial dense model baseline on the CS-2.
Dennis DeCoste advised the project.
Sean Lie advised and led engineering efforts to make sparse training possible.

The contributors would like to thank the Cerebras software and hardware teams who have made large language model training and sparse research a reality on the CS-2.
