Introduction

The growing commercialization of large language models (LLMs) for myriad tasks, such as text generation and retrieval-augmented search, has led to exponential growth in the training of new language models. As model and dataset sizes scale, reducing the prohibitive cost of training becomes a fundamental enabler. At Cerebras, we believe unstructured sparsity is the answer to lowering the compute required to train foundation models.

Training with sparsity involves masking certain learnable weights in a layer’s weight matrix, as shown in Figure 1. As an earlier blog post explained, training with sparse weights lets us skip floating-point operations (FLOPs) in both the forward and backward passes, yielding speedups on hardware that can accelerate sparsity, such as the Cerebras CS-2 system.

Figure 1: Applying weight sparsity to a dense neural network by masking weights effectively prunes neuron connections within the network. The light blue connections in the sparse network indicate the masked weights.
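To make the masking idea concrete, here is a minimal PyTorch sketch (not the Cerebras implementation) that applies a random binary mask to one linear layer's weights. It shows that masked weights contribute nothing to the output and receive exactly zero gradient, which is why hardware that skips zeros can avoid those FLOPs in both passes. The layer sizes and the random mask are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes; 80% sparsity matches the level used in the benchmarks later.
batch, in_features, out_features, sparsity = 8, 64, 32, 0.8

weight = torch.randn(out_features, in_features, requires_grad=True)

# Binary mask that keeps exactly (1 - sparsity) of the weights, chosen at random.
num_keep = int(round(weight.numel() * (1 - sparsity)))
mask = torch.zeros(weight.numel())
mask[torch.randperm(weight.numel())[:num_keep]] = 1.0
mask = mask.view_as(weight)

# Forward pass with the masked weight: pruned connections contribute nothing.
x = torch.randn(batch, in_features)
y = torch.nn.functional.linear(x, weight * mask)

# Backward pass: masked weights receive exactly zero gradient.
y.sum().backward()

# Multiply-accumulates (MACs) per forward pass when zeros are computed vs. skipped.
dense_macs = batch * weight.numel()
sparse_macs = batch * int(mask.sum().item())
print(f"dense MACs: {dense_macs}, sparse MACs: {sparse_macs}")
# prints: dense MACs: 16384, sparse MACs: 3280
```

On dense hardware the zero entries are still multiplied; only sparsity-aware hardware can actually skip the masked MACs.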

The Cerebras CS-2, enabled by its on-chip memory bandwidth and fine-grained dataflow scheduling, is the only production hardware capable of accelerating training with unstructured weight sparsity. Making this capability accessible requires software with an easy interface. Most deep learning libraries, such as PyTorch, are optimized for dense computations and provide minimal support for sparsity; moreover, the support they do have is not a first-class citizen and is optimized for GPU computing. A good user interface lets ML users write dense models and take advantage of sparsity whenever the underlying hardware supports it, without complex rewrites. With this in mind, we are releasing our fully integrated PyTorch-based library for sparse training. The library is built on modular APIs shared among different sparsity algorithm implementations, allowing for fast research ideation and easy extension to new algorithms. It is co-designed from the ground up with our software stack, does not require changing complex code deep in the framework, and introduces minimal overhead for ML users.

In the rest of this blog, we introduce the core API that forms the foundation of our sparsity library. We release implementations of several popular weight sparsity algorithms for static and dynamic sparse training. We then show how easy it is to enable sparse training with our Model Zoo and present a few benchmark results. Finally, we demonstrate the API's extensibility by implementing a new dynamic sparsity algorithm and training a model with it.

Cerebras PyTorch Sparsity Library

Our library is designed to support unstructured sparsity. It is hardware agnostic; however, when used with the Cerebras CS-2, it harnesses the system's unique ability to accelerate unstructured sparsity. The library abstracts away all low-level complexity by providing reusable APIs in PyTorch, enabling the ML user to focus on research ideation. It introduces minimal overhead and is easy to integrate into most training workflows.

API Design

In this section, we will give an overview of the design of our sparsity API. We have four key abstractions that enable us to flexibly design most weight-sparsity algorithms for training.

Optimizer Integration: Most state-of-the-art sparsity algorithms require stateful variables such as masks, counters, initialization states, etc., to enable iterative adjustments during training for better model quality. This is similar to PyTorch optimizers such as AdamW [3], which track momentum buffers for adaptive optimization. Relying on this framework of optimizers for handling states and enabling per-parameter options, we design our sparsity algorithms as special optimizers, which operate on a sparse view of the model parameters.
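A minimal sketch of this idea using only the standard torch.optim.Optimizer base class: the per-parameter mask lives in self.state, so state_dict() saves it exactly like a momentum buffer, and the target sparsity becomes a per-parameter-group option. The class name MaskOptimizer and its magnitude-based initialization are hypothetical illustrations, not the library's API.

```python
import torch


class MaskOptimizer(torch.optim.Optimizer):
    """Hypothetical sparsity 'optimizer' whose state is one binary mask per parameter."""

    def __init__(self, params, sparsity: float):
        # `sparsity` becomes a per-parameter-group option, like `lr` in SGD.
        super().__init__(params, defaults={"sparsity": sparsity})
        for group in self.param_groups:
            for p in group["params"]:
                num_keep = int(round(p.numel() * (1 - group["sparsity"])))
                # Keep the largest-magnitude weights (one possible initialization).
                keep = torch.topk(p.detach().abs().flatten(), num_keep).indices
                mask = torch.zeros(p.numel(), dtype=torch.bool, device=p.device)
                mask[keep] = True
                self.state[p]["mask"] = mask.view_as(p)
                with torch.no_grad():
                    p.mul_(self.state[p]["mask"])

    @torch.no_grad()
    def step(self, closure=None):
        # Static sparsity: the mask never changes. A dynamic algorithm would
        # update self.state[p]["mask"] here according to its schedule.
        return None


model = torch.nn.Linear(16, 8)
sparse_opt = MaskOptimizer([model.weight], sparsity=0.75)
checkpoint = sparse_opt.state_dict()  # masks are included, keyed by parameter index
```

Because the masks are ordinary optimizer state, they are checkpointed and restored with the same machinery PyTorch already uses for optimizer buffers.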

We wrap our sparse optimizers around the existing PyTorch optimizer used for training (e.g., SGD or AdamW). This enables efficient parallelization, saves the sparsity states in checkpoints, and conditionally handles dynamic gradient loss scaling. Figure 2 shows how our wrapper introduces minimal changes to enable training with sparsity.

# Imports shared by both workflows
import torch
import cerebras.pytorch as cstorch

# Dense training (left)
# Construct model and optimizer
model = torch.nn.Module(...)
optimizer = cstorch.optim.SGD(...)

# Model forward and backward
loss = model(...)
loss.backward()

# Update weights.
optimizer.step()

# Sparse training (right)
# Construct model and optimizer as usual
model = torch.nn.Module(...)
optimizer = cstorch.optim.SGD(...)

# Construct a sparsity optimizer, and use the returned wrapper as a drop-in replacement for the original optimizer
optimizer = cstorch.sparse.configure_sparsity_wrapper(
    model,
    optimizer,
    sparsity_type=...,
    sparsity=...,
    init_method=...,
)

# Model forward and backward as usual. Sparsity is automatically applied.
loss = model(...)
loss.backward()

# Update weights. If using dynamic sparsity, the sparsity pattern is also updated according to its schedule.
optimizer.step()

Figure 2: We compare CS-2 workflows for dense (left) and sparse (right) training. Our sparsity wrapper handles all the changes needed for sparse training internally. For the ML developer, it is as simple as calling the wrapper configuration function with the sparsity arguments.
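The wrapper pattern itself can be sketched in plain PyTorch. The class below is a hypothetical stand-in, not what cstorch.sparse.configure_sparsity_wrapper returns: it exposes the same zero_grad()/step() interface as the optimizer it wraps, masks gradients before the dense update, and re-applies the masks afterward, so the training loop is unchanged.

```python
import torch


class SparseOptimizerWrapper:
    """Hypothetical drop-in wrapper around a dense optimizer (not the cstorch class)."""

    def __init__(self, optimizer: torch.optim.Optimizer, masks: dict):
        self.optimizer = optimizer
        self.masks = masks  # maps parameter -> boolean mask of the same shape

    def zero_grad(self, set_to_none: bool = True):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    @torch.no_grad()
    def step(self):
        # Zero the gradients of pruned weights before the dense update.
        for p, mask in self.masks.items():
            if p.grad is not None:
                p.grad.mul_(mask)
        self.optimizer.step()
        # Re-apply masks so that no optimizer-side effect can revive pruned weights.
        for p, mask in self.masks.items():
            p.mul_(mask)


torch.manual_seed(0)
model = torch.nn.Linear(16, 8)
mask = torch.rand_like(model.weight) > 0.5
with torch.no_grad():
    model.weight.mul_(mask)

dense_opt = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=0.1)
optimizer = SparseOptimizerWrapper(dense_opt, {model.weight: mask})

for _ in range(3):  # the training loop is identical to the dense one
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
```

Keeping the dense optimizer intact and only intervening around its step() is what makes the wrapper a drop-in replacement.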

Base Optimizer: Most state-of-the-art sparsity and pruning algorithms share similar routines for applying and updating masks. We consolidate these under a BaseSparsityOptimizer implementation, which also handles other standard functions such as defining the checkpoint states for sparsity, initializing masks with custom distributions, and handling the sparse views of the model parameters and optimizer states. This lets users define new optimizers relatively easily without worrying about the control flow of sparse training and checkpointing.
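A sketch of how such a base class can separate shared control flow from algorithm-specific logic, assuming each algorithm only needs to supply a mask-update rule. BaseSparsitySketch and StaticSparsity are hypothetical names for illustration; the real BaseSparsityOptimizer has its own interface, documented in the developer documentation.

```python
import abc

import torch


class BaseSparsitySketch(abc.ABC):
    """Hypothetical base: owns masks, init, and checkpoint state; subclasses define updates."""

    def __init__(self, params, sparsity: float, seed: int = 0):
        self.params = list(params)
        self.sparsity = sparsity
        generator = torch.Generator().manual_seed(seed)
        # Shared routine: a random initial mask with an exact number of kept weights.
        self.masks = [self._random_mask(p, generator) for p in self.params]
        self.apply_masks()

    def _random_mask(self, p, generator):
        num_keep = int(round(p.numel() * (1 - self.sparsity)))
        idx = torch.randperm(p.numel(), generator=generator)[:num_keep]
        mask = torch.zeros(p.numel(), dtype=torch.bool)
        mask[idx] = True
        return mask.view_as(p)

    @torch.no_grad()
    def apply_masks(self):
        for p, m in zip(self.params, self.masks):
            p.mul_(m)

    def state_dict(self):
        # Shared routine: the masks are what a checkpoint must preserve.
        return {"masks": [m.clone() for m in self.masks]}

    @torch.no_grad()
    def step(self):
        # Shared control flow; only `update_mask` differs between algorithms.
        self.masks = [self.update_mask(p, m) for p, m in zip(self.params, self.masks)]
        self.apply_masks()

    @abc.abstractmethod
    def update_mask(self, param, mask):
        ...


class StaticSparsity(BaseSparsitySketch):
    def update_mask(self, param, mask):
        return mask  # static sparsity: the mask never changes


w = torch.nn.Parameter(torch.randn(10, 10))
static = StaticSparsity([w], sparsity=0.8)
static.step()
```

A dynamic algorithm would subclass the same base and override only update_mask.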

Update Schedules: Sparse training algorithms often change the mask patterns and the sparsity level at different frequencies and rely on scheduling functions to control them. We provide a BaseHyperParameter class that can be used to easily define custom schedules and pass them to existing algorithms. We also implement standard schedules such as cosine, polynomial, and periodic.
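As an illustration of what a schedule object does, here is a sketch of two of the standard schedules written against a hypothetical ScheduleSketch base class that maps a step tensor to a value. It mirrors the idea of BaseHyperParameter but not its exact interface (the library's schedules in Figure 7 also receive an is_update_step argument).

```python
import math

import torch


class ScheduleSketch:
    """Hypothetical stand-in for a schedule base class: maps a training step to a value."""

    def __call__(self, step: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class CosineDecay(ScheduleSketch):
    """Cosine decay from `init` to `final` over `total_steps`, then hold at `final`."""

    def __init__(self, init: float, final: float, total_steps: int):
        self.init, self.final, self.total_steps = init, final, total_steps

    def __call__(self, step):
        progress = torch.clamp(step.float() / self.total_steps, max=1.0)
        cosine = 0.5 * (1 + torch.cos(math.pi * progress))
        return self.final + (self.init - self.final) * cosine


class Periodic(ScheduleSketch):
    """True every `period` steps, e.g., when a dynamic algorithm should update its masks."""

    def __init__(self, period: int):
        self.period = period

    def __call__(self, step):
        return step % self.period == 0
```

Working on tensors rather than Python numbers keeps the schedule traceable as part of the compiled training step.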

Tensor utilities: We provide a few base utilities for handling mask updates in sparsity algorithms, enabled by an efficient TopK implementation for the CS-2. For fine-grained tensor handling, we also provide two utilities for developers:

    1. ScoreShaper: enables parameter reshaping, allowing for grouped structures during training.
    2. ScoreTieBreaker: enables breaking ties between individual elements within a tensor when writing custom sparsification logic to ensure determinism.
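The two ideas can be illustrated with hypothetical helpers (topk_mask, grouped_topk_mask) that are not the library's ScoreTieBreaker or ScoreShaper. Instead of a TopK kernel, this sketch uses a stable descending sort, which makes tie-breaking deterministic (earlier elements win ties); the grouped variant reshapes the score so that whole blocks of weights are kept or dropped together.

```python
import torch


def topk_mask(score: torch.Tensor, density: float) -> torch.Tensor:
    """Keep the `density` fraction of entries with the highest score (hypothetical helper)."""
    num_keep = int(round(score.numel() * density))
    flat = score.flatten()
    # A stable descending sort keeps equal scores in index order, so ties are
    # always broken the same way (earlier elements win).
    order = torch.sort(flat, descending=True, stable=True).indices
    mask = torch.zeros(flat.shape, dtype=torch.bool, device=score.device)
    mask[order[:num_keep]] = True
    return mask.view_as(score)


def grouped_topk_mask(score: torch.Tensor, density: float, group_size: int) -> torch.Tensor:
    """Keep or drop contiguous groups of `group_size` entries along each row together."""
    rows, cols = score.shape
    grouped = score.reshape(rows, cols // group_size, group_size)
    group_mask = topk_mask(grouped.sum(dim=-1), density)  # one score per group
    return group_mask.unsqueeze(-1).expand_as(grouped).reshape(rows, cols)


score = torch.tensor([[1.0, 3.0, 3.0, 0.0], [2.0, 3.0, 1.0, 1.0]])
print(topk_mask(score, density=0.25))  # three entries tie at 3.0; the first two win
print(grouped_topk_mask(score, density=0.5, group_size=2))
```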

We rely on the above-defined abstractions of our API to implement the following sparse training algorithms as a representative set of baselines:

    1. Static Sparse Training
    2. Gradual Magnitude Pruning (GMP) [4]
    3. Sparse Evolutionary Training (SET) [1]
    4. Rigging the Lottery (RigL) [2]
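SET and RigL share the same prune-and-regrow step and differ mainly in how they choose connections to regrow: SET picks them at random, while RigL picks those with the largest gradient magnitude [1, 2]. The hypothetical helper below sketches one such update on a boolean mask; it is not the library's implementation. It assumes grad is the dense gradient (including at masked positions), and, as in RigL, regrowth candidates include the just-dropped connections, so the number of active weights stays constant.

```python
import torch


def drop_and_grow(weight, grad, mask, drop_fraction, grow="rigl", generator=None):
    """Hypothetical prune-and-regrow update; returns a new boolean mask."""
    w = weight.detach().flatten().abs()
    g = grad.detach().flatten().abs()
    new_mask = mask.flatten().clone()
    num_drop = int(round(drop_fraction * int(new_mask.sum())))
    if num_drop == 0:
        return mask.clone()

    # Drop: the active connections with the smallest weight magnitude.
    inf = torch.full_like(w, float("inf"))
    drop_idx = torch.topk(torch.where(new_mask, w, inf), num_drop, largest=False).indices
    new_mask[drop_idx] = False

    # Grow: among inactive connections (including those just dropped),
    # RigL picks the largest gradient magnitudes and SET picks at random.
    if grow == "rigl":
        grow_score = g
    elif grow == "set":
        grow_score = torch.rand(w.shape, generator=generator)
    else:
        raise ValueError(f"unknown grow rule: {grow}")
    neg = torch.full_like(grow_score, -1.0)
    grow_idx = torch.topk(torch.where(new_mask, neg, grow_score), num_drop).indices
    new_mask[grow_idx] = True
    return new_mask.view_as(mask)


weight = torch.tensor([[0.1, -2.0, 0.0, 0.5], [0.0, 3.0, 0.0, 0.05]])
grad = torch.tensor([[0.0, 0.0, 0.9, 0.0], [0.1, 0.0, 0.0, 0.0]])
mask = weight != 0
print(drop_and_grow(weight, grad, mask, drop_fraction=0.4))
```

After such an update, the weights of newly grown connections would typically be zero-initialized, as RigL does.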

Our developer documentation contains more details on function arguments and on support for update schedules, initializations, etc. We also provide a detailed guide on how to set up a sparse training workflow from scratch for training on the Cerebras CS-2.

Why support dynamic sparsity?

The Lottery Ticket Hypothesis [5] demonstrated that iterative pruning can find a sparse network that, when trained from scratch starting from the original initialization, achieves accuracy comparable to the dense network. In practice, this relies on finding the “winning ticket,” which is compute-intensive and often challenging. Previous works, such as SNIP [6] and GraSP [7], have tried to find this winning ticket at initialization to reduce compute costs, but they lose accuracy compared to training a dense model. As an orthogonal approach, works such as SET [1] and RigL [2] employ dynamic updates to efficiently identify optimal sparse networks within a single training run, bypassing the need to find the winning ticket. Figure 3 illustrates the general workflow of dynamic updates during training. Recent state-of-the-art research on sparse training relies on dynamic sparse methods by default. In our recent work, Sparse-IFT [8], we also benchmark dynamic sparse training against static sparse training and show consistent wins at all sparsity levels.

Figure 3: Dynamic sparsity algorithms improve the optimization of sparse neural networks by leveraging updates during training. For example, RigL [2] utilizes weight and gradient magnitudes to jointly optimize model parameters and connectivity. Figure sourced from the RigL paper.

Push Button Software for Sparse Training

The Cerebras Software Platform makes it extremely simple to pre-train models using unstructured sparsity. Any existing PyTorch model in the Cerebras Model Zoo can be made sparse by changing just a few lines of its configuration file, as shown in Figure 4.

Figure 4: Example configuration changes to enable 80% sparsity for training a 1.3B Llama2 model using RigL. In this example, we start with a random mask on all linear layers in the network. For the drop fraction, we follow a cosine decay schedule to 0.

Benchmarks

We benchmark our sparse training algorithms on LLM training to demonstrate the new API and the benefits of dynamic sparsity over static sparsity. We use the same architecture as the Llama2 [9] family of models but do not adopt its Grouped-Query Attention (GQA) or long context lengths. We train a 1.3-billion (B) parameter model on 112 billion tokens of SlimPajama [10] data. Table 1 shows the architectural details of the model. We use the AdamW [3] optimizer with betas of (0.9, 0.95) and epsilon of 10^-8. Gradients are clipped to a global norm of 1.0, and a weight decay of 0.1 is used. The learning rate warms up over the first 2,000 steps to a peak value of 2 × 10^-4, followed by a cosine decay to 4.5 × 10^-6. We train on packed sequences with a context length of 2048 for computational efficiency.
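The optimizer settings above map onto standard PyTorch as in the sketch below. It uses torch.optim.AdamW and LambdaLR as stand-ins for the CS-2 training setup, a toy linear model, a linear warm-up shape, and a hypothetical TOTAL_STEPS value, since the post does not state the warm-up shape or total step count.

```python
import math

import torch

PEAK_LR, FINAL_LR, WARMUP_STEPS = 2e-4, 4.5e-6, 2000
TOTAL_STEPS = 100_000  # hypothetical: the post does not state the total step count


def lr_at(step: int) -> float:
    """Linear warm-up (assumed shape) to PEAK_LR, then cosine decay to FINAL_LR."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    progress = min((step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS), 1.0)
    return FINAL_LR + 0.5 * (PEAK_LR - FINAL_LR) * (1 + math.cos(math.pi * progress))


model = torch.nn.Linear(16, 16)  # toy stand-in for the Llama2-style model
optimizer = torch.optim.AdamW(
    model.parameters(), lr=PEAK_LR, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1
)
# LambdaLR multiplies the base learning rate (PEAK_LR) by the returned factor.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: lr_at(s) / PEAK_LR)

# One training step.
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # global-norm clipping
optimizer.step()
scheduler.step()
```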

Table 1: Size and architecture of the trained Llama2 model.

To train the sparse models, we uniformly prune 80% of the weights of all linear layers in the decoder blocks (5x compression). The normalization, embedding, and output linear layers are not pruned, to promote training stability for the sparse models. Following the findings of the In-Time Over-Parameterization [11] paper, we halve the batch size relative to the dense model and train for twice as many steps. We also increase the drop fraction to 0.5, decay it with a cosine schedule, and update the masks less frequently, allowing algorithms such as SET and RigL to find better masks. The hyper-parameters for all models are shown in Table 2.
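The layer selection and the 5x figure can be checked with a small sketch. ToyDecoderLM and prunable_weights are hypothetical stand-ins for a decoder-only model and its selection rule: only linear weights inside the decoder blocks receive a random 80% mask, while the embedding, normalization, and output head stay dense.

```python
import torch
from torch import nn


class ToyDecoderLM(nn.Module):
    """Hypothetical stand-in with the same kinds of layers as a decoder-only LM."""

    def __init__(self, vocab: int = 100, d: int = 40, n_layers: int = 2):
        super().__init__()
        self.embed = nn.Embedding(vocab, d)
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(d, 4 * d), nn.Linear(4 * d, d)) for _ in range(n_layers)
        )
        self.norm = nn.LayerNorm(d)
        self.lm_head = nn.Linear(d, vocab)


def prunable_weights(model: ToyDecoderLM):
    # Only Linear weights inside the decoder blocks are pruned; the embedding,
    # normalization, and output head stay dense.
    return [m.weight for m in model.blocks.modules() if isinstance(m, nn.Linear)]


sparsity = 0.8
torch.manual_seed(0)
model = ToyDecoderLM()
with torch.no_grad():
    for w in prunable_weights(model):
        # Uniform: every pruned layer gets the same sparsity, with a random mask.
        num_keep = int(round(w.numel() * (1 - sparsity)))
        mask = torch.zeros(w.numel(), dtype=torch.bool)
        mask[torch.randperm(w.numel())[:num_keep]] = True
        w.mul_(mask.view_as(w))

pruned = prunable_weights(model)
total = sum(w.numel() for w in pruned)
nonzero = sum(int((w != 0).sum()) for w in pruned)
print(f"compression of pruned layers: {total / nonzero:.1f}x")  # prints 5.0x
```

The compression factor follows directly from the sparsity: keeping 1 - 0.8 = 20% of the weights stores one in five of them.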

Table 2: Learning hyper-parameters (batch size and sparsity) of the models we trained. All models are trained on 112B tokens, and for both dynamic sparsity runs, the drop fraction is decayed to 0 over the first 75% of the training run, following the recommendations of the RigL paper.
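The drop-fraction schedule can be written out as the cosine decay from the RigL paper, f(t) = (α/2)(1 + cos(πt / T_end)), with α = 0.5 from our setup and T_end at 75% of training. The helper below is a plain-Python sketch of that formula, not the library's schedule class; the total step count in the usage loop is illustrative.

```python
import math


def drop_fraction(step: int, total_steps: int, initial: float = 0.5, end_frac: float = 0.75) -> float:
    """Cosine-decayed drop fraction: initial/2 * (1 + cos(pi * t / T_end)), then 0."""
    t_end = end_frac * total_steps  # no mask changes after 75% of training
    if step >= t_end:
        return 0.0
    return 0.5 * initial * (1 + math.cos(math.pi * step / t_end))


for step in (0, 25_000, 50_000, 75_000, 90_000):
    print(step, round(drop_fraction(step, total_steps=100_000), 3))
```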

While a single CS-2 system can seamlessly pre-train GPT models with up to 175 billion parameters, we use a Wafer-Scale Cluster of four CS-2 systems to speed up our experiments. Figure 5 shows how easy it is to scale. A more detailed discussion of the CS-2’s scaling properties can be found in this blog post.

Figure 5: Distributing training across multiple CS-2 systems in a Wafer-Scale Cluster is as easy as specifying the number of systems in the run command. No programming for distributed training is necessary.

Table 3 shows the models’ results on the SlimPajama validation subset and on downstream tasks from the evaluation harness, following the Open LLM Leaderboard. Using dynamic sparsity algorithms during training yields better model quality than static sparse training on both upstream validation perplexity and downstream few-shot evaluation tasks.

Table 3: Evaluation of the dense and sparse trained models. We report the validation perplexity (↓, lower is better) and the average downstream few-shot accuracy (↑, higher is better) on the Open LLM Leaderboard tasks. We do not report scores for the GSM8K task, as none of the models score strongly on it (less than 0.5).

While we do not run GMP baselines here, the recent paper on scaling laws for sparsely-connected foundation models [13] shows examples of training transformer models with this algorithm.

Ease of Integrating New Algorithms

Our library’s modular and extensible design makes building new algorithms seamless. We showcase this flexibility by adding support for a new state-of-the-art dynamic sparse algorithm, GraNet [12]. GraNet builds on top of the pruning-and-regeneration design of RigL [2]. The critical difference is that RigL is a constant sparse-to-sparse algorithm (i.e., the sparsity level does not change throughout training), whereas GraNet follows a gradual, non-decreasing sparsity schedule. This unlocks both dense-to-sparse (i.e., start dense and end sparse) and sparse-to-sparse (i.e., start at a lower sparsity level and end at a higher one) training.

Figure 6 shows the changes to enable GraNet, given a RigL configuration file.

Figure 6: Example configuration changes to enable GraNet training. In this example, we start with the RigL configuration defined in Figure 4 and add changes to the sparsity schedule to allow dynamic changes in both the mask update and the sparsity level through training. Note that beyond adding the sparsity schedule, all other hyper-parameters are the same between RigL and GraNet.

To enable the gradual sparsity schedule of GraNet, we implement a simple cubic schedule using our BaseHyperParameter abstraction for schedules (described in API design). We compare this with the implementation of constant sparsity level for RigL in Figure 7.

import torch
from typing import Tuple

# BaseHyperParameter is the schedule base class provided by the sparsity library.


# RigL (left): constant sparsity level
class Constant(BaseHyperParameter):
    """
    Constant at every step.
    """

    TYPE = "constant"

    def __init__(self, value):
        self.value = torch.tensor(value)

    def __call__(self, step: torch.Tensor, is_update_step: torch.Tensor):
        # The sparsity level is the same at every step.
        return self.value


# GraNet (right): cubic, non-decreasing sparsity level
class Cubic(BaseHyperParameter):
    """
    Cubic sparsity function.

    :math:`s_t = s_f + (s_i - s_f) * (1 - (t - t0) / (n * t_delta))**3`
    """

    TYPE = "cubic"

    def __init__(
        self,
        init_sparsity,
        end_sparsity,
        sparsity_start_step,
        sparsity_end_step,
        prune_every_k_steps,
    ):
        self.s_init = init_sparsity
        self.s_end = end_sparsity
        self.update_iter = prune_every_k_steps
        # Schedule boundaries, counted in mask-update iterations rather than steps.
        self.init_iter = int(sparsity_start_step / prune_every_k_steps)
        self.final_iter = int(sparsity_end_step / prune_every_k_steps)
        self.total_iters = self.final_iter - self.init_iter

    def __call__(self, step: torch.Tensor, is_update_step: torch.Tensor):
        curr_iter = (step / self.update_iter).int()
        prune_decay = (1 - ((curr_iter - self.init_iter) / self.total_iters)) ** 3
        current_prune_rate = self.s_end + (self.s_init - self.s_end) * prune_decay
        # Hold at s_init before the schedule starts and at s_end after it ends.
        return torch.clamp(current_prune_rate, min=self.s_init, max=self.s_end)

    def get_min_max_end(
        self, begin: int, end: int
    ) -> Tuple[float, float, float]:
        # (min, max, end) sparsity of the schedule
        return (self.s_init, self.s_end, self.s_end)

Figure 7: Defining the schedulers used by RigL (left) and GraNet (right) for sparsity. RigL keeps the sparsity level constant throughout training, whereas GraNet uses a cubic, non-decreasing schedule to enable dense-to-sparse or sparse-to-sparse training. No other changes are required to the base dynamic update logic for the masks.
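As a worked example of the cubic schedule, here is the same formula in plain Python (a hypothetical helper, not the library class), evaluated for the GraNet setting of Figure 8, which starts at 50% and ends at 80% sparsity. The step counts (mask updates every 100 steps, schedule ending at step 1,000) are illustrative assumptions.

```python
def cubic_sparsity(step, s_init, s_end, start_step, end_step, every_k):
    """Cubic sparsity schedule counted in mask-update iterations (plain-Python sketch)."""
    curr_iter = step // every_k
    init_iter, final_iter = start_step // every_k, end_step // every_k
    decay = (1 - (curr_iter - init_iter) / (final_iter - init_iter)) ** 3
    sparsity = s_end + (s_init - s_end) * decay
    # Hold at s_init before the schedule starts and at s_end after it ends.
    return min(max(sparsity, s_init), s_end)


for step in (0, 250, 500, 750, 1000, 1500):
    print(step, round(cubic_sparsity(step, 0.5, 0.8, 0, 1000, 100), 4))
```

The cubic shape raises sparsity quickly early on, while the network can still recover from pruning, and flattens out as it approaches the final level.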

We train a Llama2 1.3B model (from our benchmarks) following the same training configurations for 2.6B tokens using our GraNet implementation and show the training curves below in Figure 8.

Figure 8: Loss curves for a Llama2 1.3B model trained with RigL (in blue) and GraNet (in orange). We compare a model trained at 80% sparsity with RigL to one trained with GraNet (start at 50% and end at 80% sparsity). We observe that the gradual increase in sparsity leads to a lower loss (i.e., better) than RigL.

Conclusion

In this blog, we introduce our PyTorch-based library for training models with weight sparsity and show results for training some large models with it. We also show how easy it is to integrate new algorithms and to enable sparse training for models in the Cerebras Model Zoo. In parallel with our work, libraries such as JaxPruner [16] and STen [17] have also been released to enable sparsity research.

The Cerebras CS-2’s specialized architecture enables unprecedented efficiency and performance for sparse neural network models. Our co-designed ML/software solution allows users to access this performance through a research-friendly API. This library is already pivotal in supporting our in-house research on sparsity, demonstrated through the works in Sparse Pre-training and Dense Fine-tuning [14, 15] and Sparse Iso-FLOP Transformations [8].

We are actively exploring new methods and directions to optimize performance and the quality of sparse models. Contact us to learn more about this study or how the Cerebras CS-2 and our software platform can empower your sparsity research.

References

      1. Mocanu, Decebal Constantin, et al. “Scalable training of artificial neural networks with adaptive sparse connectivity inspired by network science.” Nature Communications 9.1 (2018): 2383.
      2. Evci, Utku, et al. “Rigging the lottery: Making all tickets winners.” International Conference on Machine Learning. PMLR, 2020.
      3. Loshchilov, Ilya, and Frank Hutter. “Decoupled weight decay regularization.” arXiv preprint arXiv:1711.05101 (2017).
      4. Zhu, Michael, and Suyog Gupta. “To prune, or not to prune: exploring the efficacy of pruning for model compression.” arXiv preprint arXiv:1710.01878 (2017).
      5. Frankle, Jonathan, and Michael Carbin. “The lottery ticket hypothesis: Finding sparse, trainable neural networks.” arXiv preprint arXiv:1803.03635 (2018).
      6. Lee, Namhoon, Thalaiyasingam Ajanthan, and Philip HS Torr. “Snip: Single-shot network pruning based on connection sensitivity.” arXiv preprint arXiv:1810.02340 (2018).
      7. Wang, Chaoqi, Guodong Zhang, and Roger Grosse. “Picking winning tickets before training by preserving gradient flow.” arXiv preprint arXiv:2002.07376 (2020).
      8. Saxena, Shreyas, et al. “Sparse Iso-FLOP Transformations for Maximizing Training Efficiency.” arXiv e-prints (2023): arXiv-2303.
      9. Touvron, Hugo, et al. “Llama 2: Open foundation and fine-tuned chat models.” arXiv preprint arXiv:2307.09288 (2023).
      10. Soboleva, Daria, et al. “SlimPajama: A 627B token cleaned and deduplicated version of RedPajama.” (2023).
      11. Liu, Shiwei, et al. “Do we actually need dense over-parameterization? in-time over-parameterization in sparse training.” International Conference on Machine Learning. PMLR, 2021.
      12. Liu, Shiwei, et al. “Sparse training via boosting pruning plasticity with neuroregeneration.” Advances in Neural Information Processing Systems 34 (2021): 9908-9922.
      13. Frantar, Elias, et al. “Scaling laws for sparsely-connected foundation models.” arXiv preprint arXiv:2309.08520 (2023).
      14. Thangarasa, Vithursan, et al. “SPDF: Sparse Pre-training and Dense Fine-tuning for Large Language Models.” arXiv preprint arXiv:2303.10464 (2023).
      15. Gupta, Abhay, et al. “Accelerating Large Language Model Training with Variable Sparse Pre-training and Dense Fine-tuning.” Cerebras Blog (2023)
      16. Lee, Joo Hyung, et al. “JaxPruner: A concise library for sparsity research.” Conference on Parsimony and Learning. PMLR, 2024.
      17. Ivanov, Andrei, et al. “STen: Productive and Efficient Sparsity in PyTorch.” arXiv preprint arXiv:2304.07613 (2023).

Contributions

Abhay Gupta led the design of the sparsity API along with Mark Browning, tested the algorithms during their internal bring-up, ran all benchmarks, and contributed to the writing of this blog. Mark Browning is the primary developer of the sparsity library, enabling framework support for the CS-2. Claire Zhang helped implement the GraNet algorithm, ran the associated experiments, and contributed to writing the blog. Sean Lie is the architect of sparsity on the Cerebras CS-2 and has guided the bring-up of the hardware, software, and training infrastructure used for the training runs in this blog. We also acknowledge the software, machine learning, and performance teams, who have played an instrumental role in developing sparsity support on the CS-2 hardware.