Sparse expert neural networks showcase the advantage of sheer scale and offer an efficient alternative to the static neural network architectures commonly used today (raffel2019exploring; brown2020language; rae2021scaling). Rather than applying the same parameters to all inputs, sparse expert networks dynamically select which parameters to use for each input (shazeer2017outrageously). This allows networks to vastly expand their number of parameters, while keeping the FLOPs per token roughly constant. These approaches have yielded state-of-the-art translation models (lepikhin2020gshard), 4-7x pre-training speed-ups (fedus2021switch; artetxe2021efficient), and GPT-3 level one-shot performance using 1/3 the energy training cost (du2021glam). And despite a shocking number of parameters, sparse models reduce the carbon footprint for training large neural networks by an order of magnitude (patterson2021carbon). However, difficulties remain.
fedus2021switch observed that a sparse 1.6T parameter model achieved a 4x pre-training speed-up over the prior state-of-the-art (raffel2019exploring), but lagged smaller models when fine-tuned on common benchmarks like SuperGLUE. Similar gaps were observed in artetxe2021efficient when MoE language models were fine-tuned on out-of-domain data. In response, Switch-XXL, a model with fewer parameters but an 8x-larger computational footprint (FLOPs approximately equal to the largest T5 model), was proposed and improved quality on natural language understanding tasks. However, the necessary pre-training was hampered by training instabilities previously undetected during smaller scale studies. These instabilities were later identified in other sparse models (du2021glam). These results revealed a necessary balance of parameters and computation, but left an open question on how to reliably train these types of models.
Our aim in this paper is to increase the practicality and reliability of sparse models. We study these two issues and pre-train a 269B sparse model that achieves state-of-the-art results when fine-tuned across many competitive NLP benchmarks, including SuperGLUE. We also put forth additional analysis and a design guide (or at least, our heuristics) for sparse expert models. Furthermore, this work emphasizes jointly optimizing both the upstream pre-training and the downstream fine-tuning metrics to avoid discrepancies (tay2021scale).
Contributions
A large-scale study of the quality-stability trade-offs of stability techniques.
An introduction of the router z-loss that resolves instability issues, while slightly improving model quality.
A fine-tuning analysis of sparse and dense models highlighting different hyperparameter sensitivity to the batch size and learning rate. We show bad hyperparameters result in virtually no fine-tuning gain over dense models, despite large pre-training speed-ups.
Architectural, routing and model design principles for designing Pareto efficient sparse models in a distributed setting.
A qualitative analysis tracing token routing decisions across expert layers.
A 269B sparse model (the Stable Transferable Mixture-of-Experts or ST-MoE-32B) which achieves state-of-the-art performance across a diverse set of natural language benchmarks.
Sparse expert models typically substitute a neural network layer with a set of experts, each having unique weights (jacobs1991adaptive; jordan1994hierarchical). Typically all the experts within a layer are of the same type and shape (homogeneous), although varied (heterogeneous) expert types are possible. Inputs are only processed by a subset of the experts to save computation, so a mechanism must be added to determine where to send each input. Usually a router or gating network determines where to send inputs (i.e. words, sentences, image patches, etc.), but alternative schemes have been proposed (lewis2021base; roller2021hash; zuo2021taming; clark2022unified).
Specifically, in natural language processing, shazeer2017outrageously proposed a Mixture-of-Experts (MoE) layer which takes a token representation $x$ as input and routes it to the best matched top-$k$ experts selected out of a set $\{E_i(x)\}_{i=1}^{N}$ of $N$ experts. The router variable $W_r$ produces logits $h(x) = W_r \cdot x$ which are normalized via a softmax distribution over the available $N$ experts at that layer. The gate-value for expert $i$ is given by
$$p_i(x) = \frac{e^{h(x)_i}}{\sum_{j=1}^{N} e^{h(x)_j}}$$
and the token $x$ is routed to the experts with the highest top-$k$ gate values (set of indices $\mathcal{T}$). The output $y$ of the layer is the weighted sum of each expert’s computation by the gate value
$$y = \sum_{i \in \mathcal{T}} p_i(x) E_i(x)$$
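As a rough illustration of the routing just described, the following pure-Python sketch computes router logits, softmax gate values, the top-k selection, and the gate-weighted sum of expert outputs. The function names, the toy router weights, and the scaling "experts" are all hypothetical and chosen only to make the example self-contained; it is not the implementation used in this work.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def moe_layer(x, router_weights, experts, k):
    """Route token vector x to its top-k experts and mix their outputs.

    router_weights: one weight vector per expert (the rows of the router).
    experts: list of callables, each mapping a vector to a same-sized vector.
    """
    # One router logit per expert: dot product of its router row with x.
    logits = [sum(w * v for w, v in zip(row, x)) for row in router_weights]
    gates = softmax(logits)
    # Indices of the k largest gate values.
    top = sorted(range(len(gates)), key=lambda i: gates[i], reverse=True)[:k]
    # Output is the gate-weighted sum of the selected experts' outputs.
    y = [0.0] * len(x)
    for i in top:
        out = experts[i](x)
        y = [acc + gates[i] * o for acc, o in zip(y, out)]
    return y, top, gates

# Toy usage: three hypothetical "experts" that simply scale the input.
toy_experts = [lambda v, s=s: [s * t for t in v] for s in (1.0, 2.0, 3.0)]
toy_router = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
y, chosen, gates = moe_layer([2.0, 1.0], toy_router, toy_experts, k=2)
```

Note that the gate values of the selected experts are used as-is and are not renormalized over the selected subset, which matches the weighted sum as stated in the text.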
Originally proposed in LSTMs (hochreiter1997long), expert layers were later used in the Transformer (vaswani2017attention) by shazeer2018mesh and lepikhin2020gshard. Follow-on work by fedus2021switch simplified the MoE further to route tokens to a single expert (top-1) and reduced other costs to improve training efficiency.
To improve hardware utilization, most implementations of sparse models have static batch sizes for each expert (shazeer2017outrageously; shazeer2018mesh; lepikhin2020gshard; fedus2021switch). The expert capacity refers to the number of tokens that can be routed to each expert. If this capacity is exceeded (the router sends too many inputs to that expert) then the overflowed tokens have no computation applied to them and are passed to the next layer through a residual connection.
|Expert||An independently-learned neural network with unique weights.|
|Router||A network that computes the probability of each token getting sent to each expert.|
|Top-$k$ Routing||Routing algorithm where each token is routed to $k$ experts.|
|Load Balancing Loss||An auxiliary (aux) loss to encourage each group of tokens to evenly distribute across experts.|
|Group Size||The global batch size is split into smaller groups, each of size Group Size. Each group is considered separately for load balancing across experts. Increasing it increases memory, computation, and communication.|
|Capacity Factor (CF)||Each expert can only process up to a fixed number of tokens, which is often set by evenly dividing across experts, $\text{tokens} / \text{experts}$. The capacity factor can expand or contract this amount to $\text{CF} \cdot \text{tokens} / \text{experts}$.|
|FFN||Acronym of Feed Forward Network (FFN) layer of Transformer consisting of linear, activation, linear.|
|Encoder-Decoder||A Transformer architectural variant that all of our models are based on. Consists of an encoder that does all-to-all attention on the inputs and a decoder that attends to the encoder and to its own inputs in an autoregressive manner.|
|allreduce||Communication primitive which sums a subset of tensors on different devices, then broadcasts the summed value to all devices. This is used in distributed training for gradient accumulation and model parallelism.|
|all2all||Communication primitive where each device sends to every other device a part of its tensor. Used in sparse Transformer models for token routing.|
|($\uparrow$/$\downarrow$)||Indicates whether higher/lower values are better (e.g. accuracy/train loss).|
The batch $B$ of input tokens is broken into $G$ unique groups across the data-parallelism dimension, each with size $B/G$. [Footnote 2: Our implementation relies on einsums with one-hot tensors for dispatching and combining tensors to/from experts. The size of this one-hot tensor grows quadratically with the number of tokens being routed as a group, which motivates breaking the batch into smaller groups. This may be avoided with sparse lookup operations.] The expert capacity is equal to $\text{CF} \cdot \text{tokens} / \text{experts}$, where CF represents the capacity factor hyperparameter, experts is the number of experts and tokens is the group size. If the capacity factor is increased, it creates extra buffer so that fewer tokens will be dropped in case of load imbalance. However, increasing the capacity factor also increases the memory and computational costs, so there exists a trade-off. [Footnote 3: See fedus2021switch for a graphical illustration of how the capacity factor works.]
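To make the capacity arithmetic concrete, the sketch below computes the expert capacity as the capacity factor times tokens divided by experts, and shows which tokens overflow for a toy, deliberately imbalanced routing. Rounding the capacity up to a whole token and processing tokens in positional order are assumptions of this sketch, and the helper names are hypothetical.

```python
import math

def expert_capacity(tokens_per_group, num_experts, capacity_factor):
    """Capacity = CF * tokens / experts, rounded up (rounding is an assumption)."""
    return math.ceil(capacity_factor * tokens_per_group / num_experts)

def dispatch(assignments, num_experts, capacity):
    """Fill experts in token order; tokens beyond an expert's capacity are dropped.

    assignments[t] is the expert chosen for token t.
    Returns (kept, dropped) lists of token positions.
    """
    load = [0] * num_experts
    kept, dropped = [], []
    for t, e in enumerate(assignments):
        if load[e] < capacity:
            load[e] += 1
            kept.append(t)
        else:
            dropped.append(t)  # only reaches the next layer via the residual
    return kept, dropped

# Toy group of 8 tokens and 4 experts, with expert 0 over-subscribed.
assignments = [0, 0, 0, 0, 1, 2, 3, 0]
for cf in (1.0, 1.25, 2.0):
    cap = expert_capacity(len(assignments), 4, cf)
    kept, dropped = dispatch(assignments, 4, cap)
    print(f"CF={cf}: capacity={cap}, dropped tokens={dropped}")
```

In this toy case a larger capacity factor drops fewer tokens, at the price of larger per-expert buffers.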
Finally, an auxiliary load balancing loss encourages tokens to be roughly evenly distributed across the experts (shazeer2017outrageously). This improves the hardware efficiency by ensuring that all accelerators are processing significant chunks of data in parallel as mentioned above. The details of the loss are presented in Appendix A. However, alternatives exist: lewis2021base and clark2022unified treat balanced token allocation as an assignment problem and remove the auxiliary loss entirely.
3 Stabilizing Training of Sparse Models
Sparse models often suffer from training instabilities (Figure 1) worse than those observed in standard densely-activated Transformers.
It’s straightforward to find changes that improve the stability; however, these often come at an untenable expense to model quality (for instance, using an arbitrarily small learning rate or using tight gradient clipping). We categorize and examine several approaches to improve stability. The stability techniques span generic fixes to Transformers as well as those specific to sparse models: (1) remove multiplicative interactions, (2) inject model noise, (3) constrain activations and gradients. We conclude with our recommendation: a new auxiliary loss, the router z-loss, which significantly improves training stability with no quality degradation. This is an adaptation of the z-loss used for final softmax logits in the Mesh Tensorflow codebase (shazeer2018mesh).
Stabilizing Sparse Models
Many methods stabilize sparse models, but at the expense of worse quality.
The router z-loss stabilizes models without quality degradation.
Transformer modifications with more multiplicative components (GEGLU, RMS normalization) worsen stability, but boost quality.
Designing a large-scale stability study.
We design a large-scale stability study of sparse models FLOP-matched to the T5-XL version (raffel2019exploring) pre-trained on the multilingual corpus mC4 (xue2020mt5). Each sparse model has 32 experts and we introduce a sparse MoE layer for every fourth FFN. The train capacity factor is 1.25 and the eval capacity factor is 2.0. See Table 11 for a more detailed description of models used throughout this paper. For each stability technique, we record the fraction of runs that are stable, the mean quality (negative log perplexity on English), and the standard deviation over seeds.
The primary issue in constructing this study is that small models are rarely unstable but large unstable models are too costly to run for sufficient steps and seeds. We found a sparse model FLOP-matched to T5-XL to be a good object of study because it was unstable in roughly 1/3 of the runs, but was still relatively cheap to train. Furthermore, we run our instability experiments on multilingual data since we find this exacerbates model instabilities, allowing us to experiment on slightly smaller models. See Section 9 for more details. Our baseline configuration is trained using six random seeds and each configuration with a stability technique uses three random seeds. We use six seeds for the baseline to better characterize the instability rate and three seeds for the variants to save compute. Each model is pre-trained for 20k steps on mC4 using a masked language modeling objective (fedus2018maskgan; devlin2018bert).
3.1 Stability and Quality Tradeoffs when Removing Multiplicative Interactions
Some architectural improvements involve more multiplications than additions or do not sum many items at once. For example, a matrix multiplication has one multiplication for each addition and hence we do not refer to it as a “multiplicative” operation. We present and analyze the impact of two instances of multiplicative interactions in Transformers here.
GELU Gated Linear Units (GEGLU).
Our first example is the Gated Linear Unit (dauphin2017language), which is a component-wise product of two linear projections, one of which is first passed through a sigmoid function. shazeer2020glu extends this to other variants and presents a GELU-Linear (hendrycks2016gaussian) FFN layer as a replacement for the usual ReLU (nair2010rectified) FFN in the Transformer:
$$\text{FFN}_{\text{GEGLU}}(x, W, V, b, c) = \text{GELU}(xW + b) \odot (xV + c)$$
The quality gain from this layer was corroborated in later work (narang2021transformer).
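A minimal sketch of the gated hidden layer may help show where the extra multiplicative interaction enters. It assumes the GEGLU form of shazeer2020glu, that is, a GELU-activated projection multiplied element-wise with a second linear projection, and omits the output projection of the FFN. The helper names and the toy weights are hypothetical.

```python
import math

def gelu(v):
    """Exact GELU: v times the standard normal CDF evaluated at v."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))

def linear(x, weights, bias):
    """Compute xW + b, where weights holds one row per input feature."""
    out = list(bias)
    for xi, row in zip(x, weights):
        for j, w in enumerate(row):
            out[j] += xi * w
    return out

def geglu(x, W, b, V, c):
    """GELU(xW + b) multiplied element-wise with (xV + c)."""
    gate = [gelu(h) for h in linear(x, W, b)]
    value = linear(x, V, c)
    return [g * v for g, v in zip(gate, value)]

# Toy example: 2 input features projected to 3 hidden units.
x = [1.0, -2.0]
W = [[0.5, -1.0, 0.0], [0.25, 0.5, 1.0]]
V = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
hidden = geglu(x, W, [0.0] * 3, V, [0.0] * 3)
```

Compared with a plain ReLU FFN, each hidden unit here is a product of two learned projections of the same input, which is the kind of multiplicative interaction this section examines.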
Root Mean Square Scale Parameters.
Our second example is the scale parameter in root mean square (RMS) normalization (zhang2019root). Within the Transformer, rather than calling layers back-to-back, there is an internal structure (referred to as sublayer calls) which improves gradient propagation and training dynamics. Our sublayer calls match those of raffel2019exploring and consist of: (1) RMS normalization, (2) layer call (e.g. Self Attention), (3) dropout (srivastava2014dropout), (4) add residual (he2015deep). RMS normalization scales the input vector $x \in \mathbb{R}^{d}$ element-wise per the root-mean-square. It then rescales the output element-wise by multiplying with a learned scale parameter $g$:
$$y_i = \frac{x_i}{\sqrt{\frac{1}{d} \sum_{j=1}^{d} x_j^2}} \cdot g_i$$
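The normalization and its scale parameter can be sketched in a few lines. This is only an illustration of RMS normalization followed by an element-wise learned scale; the `eps` term is an assumption added for numerical safety and the example values are made up.

```python
import math

def rms_norm(x, scale, eps=1e-6):
    """Divide x by its root-mean-square, then multiply by a per-element scale.

    eps is a small stabilizing constant (an assumption of this sketch).
    """
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * g for v, g in zip(x, scale)]

x = [3.0, -4.0, 0.0, 0.0]
# "Removing" the scale parameter corresponds to fixing it at 1.
unit = rms_norm(x, scale=[1.0] * 4)
# A learned scale multiplies each normalized element independently.
scaled = rms_norm(x, scale=[2.0, 1.0, 1.0, 1.0])
```

The only difference between the two calls is the multiplicative scale, which is the component ablated in the stability study.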
Table 2 shows that removing either the GEGLU layers or the RMS scale parameter improves stability, but at a significant loss to model quality. We note that these scale parameters ($g$) have a disproportionate gain to model quality versus parameters elsewhere (e.g. FFN). In line with our findings, shleifer2021normformer found that adding a learned multiplicative scalar to the residual connection in Transformers made them much more unstable.
|Method||Fraction Stable||Quality ($\uparrow$)|
|Remove RMS Norm. Scale Param||-2.020|
In Appendix C, we further study the quality impact of adding new multiplicative interactions in expert layers. We find that this operation yields quality improvements with virtually no slow-down in model step time.
3.2 Stability and Quality Tradeoffs when Adding Noise
We next explore a hypothesis that adding noise into the model can improve training stability (neelakantan2015adding). taleb2012antifragile argues that certain systems exhibit the property of anti-fragility, where they improve through noise. Inspired by the concept and by our observation that fine-tuning (which injects noise via dropout) was rarely unstable, we examined whether training noise might improve the stability of sparse models. Table 3 shows a stability improvement versus the baseline, but at the expense of lower quality. We also find that input-jitter, introduced by fedus2021switch, diminishes quality at XL-scale, hence we ablate it in our models. Input-jitter multiplies the input logits to the router by a uniform random variable drawn from $[1 - 10^{-2}, 1 + 10^{-2}]$. Dropout in our ablation is applied throughout the Transformer. As seen previously, improvements in small-scale settings may fail to generalize when scaled up and therefore trends should always be monitored and re-assessed at increasing scale (kaplan2020scaling).
|Method||Fraction Stable||Quality ($\uparrow$)|
|Input jitter ()||-1.777|
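For reference, the input-jitter ablated above amounts to a multiplicative perturbation of the router input. The sketch below is a hedged illustration: the function name is hypothetical, and the interval half-width `eps` is a placeholder argument rather than a value taken from the table.

```python
import random

def jitter_router_input(values, eps=1e-2, rng=random):
    """Multiply each router input by an independent U[1 - eps, 1 + eps] draw."""
    return [v * rng.uniform(1.0 - eps, 1.0 + eps) for v in values]

rng = random.Random(0)  # fixed seed so the sketch is reproducible
clean = [4.0, -2.0, 0.5]
noisy = jitter_router_input(clean, eps=1e-2, rng=rng)
```

Because the noise is multiplicative, larger-magnitude inputs receive proportionally larger perturbations.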
3.3 Stability and Quality Tradeoffs when Constraining Activations and Gradients
One of the most successful approaches to stabilizing neural networks is to constrain activations and gradients (pascanu2013difficulty; ioffe2015batch; salimans2016weight; ba2016layer). A popular approach consists of clipping gradient norms to remedy exploding gradients while backpropagating through deep networks (pascanu2013difficulty).
In this work, we use the Adafactor optimizer due to its memory efficiency (though recently introduced 8-bit optimizers (dettmers20218bit) may offer better trade-offs). Instead of gradient clipping, Adafactor uses update clipping, where the changes to the weights are constrained to be below a certain norm. We experiment with tightening the update clipping to a smaller value.
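As an illustration of update clipping, the sketch below rescales a proposed weight update so that its root-mean-square does not exceed a threshold, following the clipping rule described in the Adafactor paper. The threshold values are arbitrary examples, not the settings used in this study, and the function is a stand-alone toy rather than the optimizer itself.

```python
import math

def clip_update(update, threshold):
    """Scale an update so that its root-mean-square is at most `threshold`."""
    rms = math.sqrt(sum(u * u for u in update) / len(update))
    return [u / max(1.0, rms / threshold) for u in update]

update = [3.0, -4.0]                          # RMS is sqrt(12.5), about 3.54
loose = clip_update(update, threshold=10.0)   # below threshold: unchanged
tight = clip_update(update, threshold=0.1)    # rescaled down to RMS 0.1
```

Tightening the threshold shrinks every large update by the same factor, which limits how far the weights can move in a single step regardless of the gradient scale.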
Next, we study constraints on the logits going into the router. The router computes the probability distribution over the experts in float32 precision (i.e. selective precision) (fedus2021switch). However, at the largest scales, we find this is insufficient to yield reliable training. To fix this, we introduce the router z-loss,
$$L_z(x) = \frac{1}{B} \sum_{i=1}^{B} \left( \log \sum_{j=1}^{N} e^{x_j^{(i)}} \right)^2$$
where $B$ is the number of tokens, $N$ is the number of experts, and $x \in \mathbb{R}^{B \times N}$ are the logits going into the router. This penalizes large logits into the gating network; Section 3.4 contains a more detailed explanation of why the z-loss before the router is useful.
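A small sketch can show how this penalty behaves. It assumes the router z-loss is the mean, over tokens, of the squared log-sum-exp of each token's router logits; the function names and the toy logits are hypothetical.

```python
import math

def logsumexp(row):
    """Stable log(sum(exp(row)))."""
    m = max(row)
    return m + math.log(sum(math.exp(v - m) for v in row))

def router_z_loss(logits):
    """Mean over tokens of the squared log-sum-exp of the router logits.

    logits: list of B rows (tokens), each holding N expert logits.
    """
    return sum(logsumexp(row) ** 2 for row in logits) / len(logits)

small = [[0.1, -0.2, 0.0, 0.3], [0.0, 0.0, 0.0, 0.0]]
large = [[v * 50 for v in row] for row in small]
print(router_z_loss(small), router_z_loss(large))
```

Scaling the logits up increases the penalty sharply, so minimizing it pushes the router toward small-magnitude logits while leaving their relative ordering free.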
Table 4 shows that both update clipping and the router z-loss stabilize the model in all 3 runs, but the update clipping significantly hurts the model quality. Therefore we use the z-loss method for fixing our model stability due to improved quality and stability. [Footnote 4: We also experimented with adding z-losses onto the attention logits, which also improves model stability without hurting model quality.]
|Method||Fraction Stable||Quality ($\uparrow$)|
|Update clipping ()||-4.206|
The router z-loss introduces another hyperparameter ($c_z$), which is the coefficient to weight this as part of the total loss optimized. The total loss is a linearly weighted combination of the cross entropy loss ($L_{CE}$), the auxiliary load balance loss ($L_B$), and the router z-loss ($L_z$), yielding a total loss
$$L_{tot} = L_{CE} + c_B L_B + c_z L_z$$
We choose a value of $c_z = 0.001$ based on the best model quality after pre-training with a hyperparameter sweep. Appendix B logs the resulting losses over the course of pre-training.
3.4 Selecting a Precision Format: Trading Efficiency and Stability
As in most modern distributed Transformers we train with mixed precision (micikevicius2017mixed). [Footnote 5: See Mesh Tensorflow for implementation details: https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/] Weights are stored in float32 for gradient updates and then converted to bfloat16 when doing matrix multiplications in the forward and backward pass. [Footnote 6: Matrix multiplications on TPUs perform multiplications in bfloat16 and accumulations in float32.] Furthermore, all activations are stored and operated on in bfloat16 and allreduce communications can be done in either bfloat16 or float32 numerical precision. For the largest model explored in this work (ST-MoE-32B, presented later) we find speed-ups from halving the numerical precision of the allreduce; however, this can also destabilize training, so we keep it as float32 throughout this work.
A lower precision format enables more efficient models by reducing (a) communication costs between processors and memory, (b) computation costs, (c) memory for storing tensors (e.g. activations). However, lower precision formats come at the expense of larger roundoff errors which can lead to irrecoverable training instabilities.
Understanding precision format and roundoff errors.
Figure 2 reviews the properties of different precision formats and their corresponding roundoff errors for different number ranges. Numbers in any range of two consecutive powers of 2 (e.g. [2, 4) and [1024, 2048)) are represented by a fixed number of mantissa bits (7 for bfloat16, 23 for float32). As a result, (1) bfloat16 will have about 65,536x (i.e. $23 - 7 = 16$ additional bits and $2^{16} = 65{,}536$) as large roundoff errors as float32 and (2) larger numbers have larger roundoff errors. Due to the 8 exponent bits, numbers can get as large as $\approx 3 \times 10^{38}$, which leads to even float32 having some issues with roundoff errors.
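The effect of the 7-bit mantissa can be checked directly by emulating bfloat16 in software. The sketch below treats bfloat16 as the upper 16 bits of a float32 and rounds to nearest with ties to even; the helper name is hypothetical and special values (NaN, infinity) are deliberately not handled.

```python
import struct

def to_bfloat16(value):
    """Round a Python float to the nearest bfloat16 value (ties to even).

    bfloat16 keeps the top 16 bits of the float32 bit pattern, so we round
    that pattern to 16 bits and zero the rest. NaN/inf handling is omitted.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Half of the discarded range, plus the lowest kept bit for ties-to-even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

for v in (2.5, 128.5, 1027.0):
    print(v, "->", to_bfloat16(v))
```

In the range [128, 256) adjacent bfloat16 values are 1 apart, and in [1024, 2048) they are 8 apart, so the absolute roundoff error grows with the magnitude of the number.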
Sparse expert models are sensitive to roundoff errors because they have more exponential functions due to the routers.
Sparse expert models introduce additional exponential functions, through the router, which can exacerbate roundoff errors and lead to training instabilities. [Footnote 7: Exponential functions have the property that a small input perturbation can lead to a large difference in the output. As an example, consider inputting 10 logits to a softmax function with values of 128 and one logit with a value 128.5. A roundoff error of 0.5 in bfloat16 will alter the softmax output by 36% and incorrectly make all logits equal. The calculation goes from $\frac{e^{0}}{e^{0} + 10 \cdot e^{-0.5}} \approx 0.1415$ to $\frac{e^{0}}{e^{0} + 10 \cdot e^{0}} \approx 0.0909$. This occurs because the max is subtracted from all logits (for numerical stability) in softmax operations and the roundoff error changes the number from 128.5 to 128. This example was in bfloat16, but analogous situations occur in float32 with larger logit values.] While a roundoff error does not change the ordering of probabilities within a softmax operation, it does impact the routing of the second token in MoE due to relative thresholding (e.g. a token is only routed to its second place expert if the gating probability for the second expert is 1/5 as large as that of the first expert). Additionally, roundoff errors can drastically change the probability that scales the expert output, which we have found to be important. Finally, we conjecture that the higher stability we observed for decoder-only models (not shown here) was because they had fewer exponential functions. Section 9 contains a more detailed discussion.
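The softmax example with ten logits of 128 and one logit of 128.5 can be reproduced numerically. The sketch below compares the probability of the largest logit before and after the 128.5 entry is rounded down to 128; the variable names are illustrative only.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

exact = [128.0] * 10 + [128.5]
rounded = [128.0] * 11  # the 128.5 logit after rounding down to 128

p_exact = softmax(exact)[-1]      # 1 / (1 + 10 * exp(-0.5))
p_rounded = softmax(rounded)[-1]  # 1 / 11
relative_change = (p_exact - p_rounded) / p_exact
print(f"{p_exact:.4f} -> {p_rounded:.4f} ({relative_change:.0%} change)")
```

The relative change comes out at roughly 36%, and after rounding every expert receives the same probability, so the router's preference for the last expert is lost entirely.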
An aside on the router z-loss.
One might think that the router z-loss is a convoluted method replaceable by clipping logits (wu2016google). We explain why this is not the case. The goal is to minimize large roundoff errors going into exponential functions. Clipping the logits occurs after any roundoff errors – resulting in even larger discontinuities. In one view, clipping in itself is a roundoff error; conversely, the z-loss naturally encourages the model to produce logits that are small in value and thus more accurately modeled. Due to these dynamics, we ensure all exponentiated tensors are cast to float32. This hints at the possibility of better number formats for neural networks because of the unused exponent bits when z-losses are added throughout the network (see Section 9).
4 Fine-Tuning Performance of Sparse Models
The best performing language models are usually obtained by (1) pre-training on large amounts of data (e.g. the internet) followed by (2) fine-tuning on a task of interest (e.g. SuperGLUE). Promising new techniques have emerged as an alternative, including few-shot inference (brown2020language), prefix tuning (li2021prefix), prompt tuning (lester2021power), and adapter modules (houlsby2019parameter); however, a quality gap still persists compared to fine-tuning. Because of this, we focus on fine-tuning in this work, but highlight recent successes of sparse models in few-shot settings from du2021glam and artetxe2021efficient. Further, we leave as future work techniques that adapt large language models through reinforcement learning (ouyang2022training).
4.1 Hypothesis: A Generalization Problem
Sparse models have performed remarkably well in the regime of large datasets, but have sometimes performed poorly when fine-tuning (fedus2021switch; artetxe2021efficient). We present evidence for a (not so surprising) hypothesis that sparse models are prone to overfitting. We illustrate this problem through two tasks in SuperGLUE (wang2019superglue) – Commitment Bank (de2019commitmentbank) and ReCORD (zhang2018record). Commitment Bank (CB) has 250 training examples while ReCORD has over 100,000. This significant size discrepancy facilitates a natural study for overfitting on two tasks selected as part of the same benchmark.
In Figure 3, we compare the fine-tuning characteristics of the Dense L and the ST-MoE-L model. Each model was pre-trained on 500B tokens from the C4 corpus (raffel2019exploring). The models are designed to be roughly FLOP matched variants of the T5-Large encoder-decoder models from raffel2019exploring with 770M parameters. The ST-MoE models have 32 experts with an expert layer frequency of 1/4 (every fourth FFN layer is replaced by an MoE layer). The pre-training and fine-tuning train capacity factor is 1.25 and the eval capacity factor is 2.0. We evaluate performance on the held-out validation and train dataset partitions.
Across both tasks, the sparse model converges faster to 100% train set accuracy supporting that sparse models optimize effectively under a data distribution shift. On the larger task, ReCORD, the validation quality of the sparse model follows the boost in training and significantly exceeds the dense model. However, on the smaller task, CB, the sparse model lags its dense counterpart on held-out data. As per the recommendation of fedus2021switch, we consider increasing the dropout within the expert hidden state (i.e. expert dropout), but find that at this scale, higher values only moderately improve quality (Figure 4). We study further improvements to fine-tuning in Section 4.2 and hyperparameter sensitivity in Section 4.3.
4.2 Fine-Tuning a Subset of Model Parameters to Improve Generalization
To combat overfitting we experiment with updating only a subset of model parameters during fine-tuning. Figure 5 measures quality for updating 5 different subsets of parameters: all parameters (All), only non MoE parameters (Non MoE), only MoE parameters (MoE), only the self-attention and enc-dec attention parameters (Attention) and only the non MoE FFN parameters (FFN).
We observe that updating the non MoE parameters works about as well as updating all the parameters and updating only the FFN parameters works a bit better. Updating only the MoE parameters significantly degrades fine-tuning performance, even though this is where 80% of the model parameters are. Only updating the non MoE parameters can be an effective way to speed up fine-tuning and reduce its memory use.
We hypothesize that fine-tuning only the MoE parameters leads to bad performance since expert layers only occur in every fourth layer and a token will see at most two experts per layer. Therefore, updating the MoE parameters will affect far fewer layers and FLOPs than updating any other subset of the parameters we tried. Updating only the MoE parameters resulted in a much larger training loss than updating the non MoE parameters, even though there are significantly more parameters. We further observe that updating all the non-MoE parameters results in a higher training loss than updating all the parameters, but unfortunately this regularization effect didn’t translate to better validation performance.
Further, one regularizer we tried was a dropout variant where entire experts were masked out stochastically during training. However, this failed to improve generalization in our preliminary studies. Appendix J expands on this experiment and contains other negative results.
4.3 Sparse and Dense Models Require Different Fine-Tuning Protocols
How sensitive are sparse and dense models to the fine-tuning protocol? We study two hyperparameters: the batch size and the learning rate. We pretrain a Dense-L and ST-MoE-L on 500B tokens of C4 and then fine-tune on SuperGLUE. Figure 6 summarizes our experiments with the full data presented in Table 20 (Appendix F). Across all hyperparameter settings, the sparse models (orange) outperform the dense (blue) counterparts – however, the best setting for each can materially change results. Sparse and dense models have vastly different performance across different batch sizes and learning rates. Sparse models benefit from smaller batch sizes and a higher learning rate. Consistent with the overfitting hypothesis (Section 4.1), both these changes might improve generalization through higher noise in the fine-tuning process. Finally, we point out the importance of correctly tuning the batch size and learning rate during fine-tuning. Simply using the same fine-tuning hyperparameters that worked well for the dense model can mask any pre-training improvements obtained by the sparse model.
4.4 Sparse Models Are Robust to Dropped Tokens During Fine-Tuning
Sparse models route tokens to one or more experts at each layer. To make these models efficient in the SPMD paradigm with modern hardware, the expert capacity (the number of tokens each expert processes) needs to be fixed ahead of time (see Section 2 for more details). When an expert receives more tokens than its capacity, the extra tokens are dropped: no computation is applied to those tokens. We again try to prevent this by (1) pre-training with an auxiliary loss that promotes equal amounts of tokens getting sent to each expert and (2) a capacity factor (a hyperparameter) that adds room for extra tokens at each expert. We experiment with turning off the auxiliary loss during fine-tuning and using different capacity factors. Table 5 reveals a surprising result: fine-tuning quality is not materially impacted by dropping up to 10-15% of tokens. [Footnote 8: Token dropping may be a form of regularization and a more extensive study may be an interesting direction for future work.] Studies on ST-MoE-32B corroborate that high capacity factors do not improve fine-tuning quality. This is in line with findings of yang2021m6t that unequal load balance may not significantly impact model quality.
|Model||Train CF||Eval CF||Aux Loss||Percent Tokens Dropped||SuperGLUE ($\uparrow$)|
4.5 Inserting Sentinel Tokens During Fine-Tuning
Sentinel tokens denote masked sequences in the span-corruption objective (fedus2018maskgan; devlin2018bert). This differs from any fine-tuning task we would likely encounter, leading to a domain mismatch between pre-training and fine-tuning. Table 6 illustrates the difference. We examine whether modifying the fine-tuning task to look more like the pre-training task affects results.
|Span Corruption||I like <X> the pool <Y> day .||<X> going to <Y> on a sunny|
|Fine-Tuning||What is the capital of Illinois ?||Springfield|
|Fine-Tuning + Sentinels||What is the capital of Illinois ? <X>||<X> Springfield|
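The "Fine-Tuning + Sentinels" row of the table can be produced with a one-line transformation of each example. The helper below is a hypothetical sketch that appends a single sentinel to the input and prepends it to the target; it does not cover the multiple-sentinel variant.

```python
def add_sentinels(inputs, targets, sentinel="<X>"):
    """Append a sentinel to the input text and prepend it to the target text."""
    return f"{inputs} {sentinel}", f"{sentinel} {targets}"

example = add_sentinels("What is the capital of Illinois ?", "Springfield")
print(example)
```

This keeps the task content unchanged and only alters the surface format so that it resembles the span-corruption examples seen during pre-training.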
In Table 7 we find that adding sentinel tokens while fine-tuning only improves Grammar Error Correction (GEC) (rothe2021simple), but not SuperGLUE. We tried to further reduce the data distribution shift by inserting multiple sentinel tokens (as would be encountered by the model while pre-training), but again found no universal benefit. However, despite no consistent benefit on held-out data, we find that training convergence is accelerated for both dense and sparse models.
|Model||Insert Sentinel Tokens||SuperGLUE ($\uparrow$)||GEC ($\uparrow$)|
|Dense|| ||84.9 ± 0.33||22.3 ± 0.25|
|Dense||✓||85.1 ± 0.25||22.1 ± 0.42|
|Sparse|| ||86.6 ± 0.18||22.2 ± 0.04|
|Sparse||✓||86.6 ± 0.24||22.9 ± 0.09|
5 Designing Sparse Models
The design of dense models has been guided by the foundational work of kaplan2020scaling. But sparse models pose a myriad of additional questions: (1) How many experts to use? (2) Which routing algorithm? (3) What value for the capacity factor? (4) How does hardware change these decisions? In this section, we comment on these and offer recommendations for building Pareto efficient sparse models. Concurrently, clark2022unified provides additional design recommendations including higher layer frequency and top-1 routing as per fedus2021switch.
Designing Sparse Models
In our setup, we recommend top-2 routing with 1.25 capacity factor and at most one expert per core.
The capacity factor can be changed during evaluation to adjust to new memory/compute requirements.
Dense layer stacking and a multiplicative bias can boost quality (Appendix C).
5.1 Setting the Number of Experts
One of the first questions is the number of experts to use. fedus2021switch presented the scaling properties of Switch Transformer, which yielded monotonic pre-training benefits (on a step basis) on C4 up to 512 experts, kim2021scalable up to 64 experts and clark2022unified up to 512 experts. But the incremental benefit quickly diminishes with many experts (>256) or equivalently, with very sparse models (<1% of experts activated).
However, reflecting on the specific hardware system can further guide this choice. The compute-to-memory ratio (operational intensity) can serve as an estimate of the efficiency of different operations (williams2009roofline; shazeer2019fast). A model is memory bound if the time to load tensors to the computing core (e.g. ALU/MMU) greatly exceeds the time required to do the computation on the tensors. On modern GPUs and TPUs, increasing this compute-to-memory ratio improves the efficiency.
Returning to sparse expert models, using more than one expert per core increases memory transfer, potentially hurting efficiency. Increasing the number of experts does not change the computation done (sparse models apply a fixed amount of computation to each input), but increases the memory transfer requirement (additional expert variables must be loaded from device memory). This decreases the compute-to-memory ratio. [Footnote 9: As an exercise to the reader, verify the operational intensity of the first expert computation is with batch size, hidden dimension, number of experts.]
On our TPU system, we recommend one expert (or fewer) per core. Our largest models use both data and model parallelism where data parallelism is over “rows” and model-parallelism over “columns” of the logical mesh. We use 1 expert per data parallelism row to ensure the compute-to-memory ratio is high and to reduce the cores needed for evaluation and inference. Furthermore, using fewer experts lets us allocate more cores to the model parallelism “column” to have more FLOPs in our model. Appendix H explains our mesh layouts for when we have fewer experts than data parallelism rows.
5.2 Choosing the Capacity Factor and Routing Algorithm
We generalize top-1 routing (fedus2021switch; roller2021hash) and top-2 (shazeer2017outrageously; lepikhin2020gshard) to study top-$n$ routing where each token is processed by at most $n$ experts. In this study, all models are pre-trained for 100k steps with 1M tokens per batch and sparse models have 32 experts and are FLOP matched to T5-Large (raffel2019exploring). We draw two key conclusions.
First, increasing both the train and eval capacity factors (CF) improves quality, as seen by comparing across the segmented blocks of Table 8. For instance, top-1 routing improves by +0.011 neg. log perp. when increasing from 1.0 → 1.25 train CF, and top-2 routing improves by +0.009 when increasing from 1.25 → 2.0 train CF. To provide context for these numbers: tripling the size of a dense model (Dense-L to Dense-XL) yields a +0.090 neg. log perp. boost. Therefore, these CF boosts are roughly 1/10th of that magnitude. But this comes at a cost. Increasing the capacity factor linearly increases the einsum costs, memory for activations, all2all communication costs, and model-parallelism allreduce communication costs for expert layers. (Footnote 10: all2all and allreduce costs depend on the number of devices, batch size, and capacity factor, but not on the number of experts.)
Second, there are small gains of top-$(n+1)$ over top-$n$ routing given a fixed capacity factor (Table 8). For instance, top-2 routing improves +0.004 over top-1 at train CF of 1.25, or about 1/20th of the boost of a dense model tripling. This revises an earlier recommendation from fedus2021switch. The primary difference between these experimental setups was scale of compute. fedus2021switch trained 220M-FLOP matched models for 50B tokens. We find at an 8x larger scale of training (1B-FLOP matched models for 100B tokens) there is instead a small gain to route to more than one expert. Furthermore, at the larger experimental scale, the speed difference of top-$n$ versus top-$(n+1)$ routing is negligible. Speed differences were observed in fedus2021switch because the router computation was a larger fraction of the total model computation.
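The interaction between top-$n$ routing and the capacity factor discussed above can be illustrated with a small sketch. It assumes expert capacity is defined as in Switch Transformer (tokens per group divided by the number of experts, times the capacity factor) and that all first choices are placed before any second choices, in sequence order. Both are assumptions of this illustration, and the function names are hypothetical rather than the Mesh Tensorflow implementation.

```python
def expert_capacity(tokens_per_group, num_experts, capacity_factor):
    """Maximum number of tokens each expert may process within one group."""
    return int(tokens_per_group / num_experts * capacity_factor)


def top_n_route(router_probs, n, capacity):
    """Assign each token to up to n experts without exceeding expert capacity.

    router_probs holds one list of per-expert probabilities for each token.
    """
    num_experts = len(router_probs[0])
    # Rank experts for every token from most to least probable and keep the top n.
    ranked = [
        sorted(range(num_experts), key=lambda e: -probs[e])[:n]
        for probs in router_probs
    ]
    load = [0] * num_experts
    assignments = [[] for _ in router_probs]
    # Assumption: all first choices are placed before any second choices.
    for choice in range(n):
        for token, experts in enumerate(ranked):
            expert = experts[choice]
            if load[expert] < capacity:
                load[expert] += 1
                assignments[token].append(expert)
            # Otherwise this choice overflows and is dropped.
    return assignments, load


# Toy group: 4 tokens, 2 experts, top-2 routing, at most 2 tokens per expert.
probs = [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.4, 0.6]]
assignments, load = top_n_route(probs, n=2, capacity=2)
```

In the toy group the third token is dropped entirely because its preferred expert is already full, which is the kind of overflow that a larger capacity factor reduces at the cost of more compute and communication.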
|Algorithm||Train CF||Eval CF||Neg. Log Perp. (↑)|
The specific hardware-software system will determine the optimal $n$ and capacity factor. For instance, if the system supports fast all2all and allreduce communications, larger capacity factors and larger $n$ in top-$n$ routing may be optimal. However, if the all2all and/or allreduce communications are slow, smaller capacity factors may dominate. In our case, the hardware-software stack is the TPU and Mesh Tensorflow. We record the training speed of both our ST-MoE-L and ST-MoE-32B models in Table 9 as we increase the train capacity factor. As the models scale, a higher capacity factor makes the models increasingly slower. The ST-MoE-L does not require model parallelism (it fits within the accelerator's memory, which implies no additional allreduce communications), making it better suited for high capacity factors than our ST-MoE-32B model. For our largest model, we therefore continue to use the smaller train capacity factor of 1.25 advocated by fedus2021switch for Pareto efficiency, differing from other work which uses a larger and more expensive 2.0 capacity factor (lepikhin2020gshard; du2021glam).
|Model||Train CF||Step Time (s) (↓)|
Our results in this section focus on top-$n$ routing, but we also experimented with a variety of other routing techniques in Appendix J. We found most performed similarly or worse compared to top-$n$ routing. However, we found that Batch Prioritized Routing (BPR), introduced in riquelme2021scaling, significantly helps performance for capacity factors less than one (Appendix D). We recommend BPR for larger models where all2all and allreduce are more expensive and lower capacity factors are optimal.
6 Experimental Results
Given our improvements to training stability, fine-tuning and model design, we start by validating a sparse model approximately FLOP-matched to T5-Large (raffel2019exploring). We conclude this section by designing and training a 269B sparse parameter model (FLOP matched to a 32B dense model) which achieves state-of-the-art quality across a wide set of NLP tasks.
We studied the SuperGLUE (wang2019superglue) benchmark throughout this work, which consists of tasks including sentiment analysis (SST-2), word sense disambiguation (WIC), sentence similarity (MRPC, STS-B, QQP), natural language inference (MNLI, QNLI, RTE, CB), question answering (MultiRC, RECORD, BoolQ), coreference resolution (WNLI, WSC), sentence completion (COPA) and sentence acceptability (CoLA). We often observe good performance on SuperGLUE to correlate with (but not guarantee) performance across many NLP tasks. We also include a diverse set of additional benchmarks. The CNN-DM (cnn2015moritz) and BBC XSum (narayan2018don) datasets are used to measure the ability to summarize articles. Question answering is probed with the SQuAD dataset (rajpurkar2016squad) as well as on grade-school science questions in ARC Easy and ARC Reasoning Challenge (clark2018think). And as in roberts2020much, we evaluate the knowledge of our models by fine-tuning on three closed-book question answer datasets: Natural Questions (kwiatkowski2019natural), Web Questions (berant2013semantic) and Trivia QA (joshi2017triviaqa). Closed-book simply refers to questions posed with no supplemental reference or context material. To gauge the model’s common sense reasoning we evaluate it on the Winogrande Schema Challenge (sakaguchi2020winogrande). And finally, we test our model’s natural language inference capabilities on the Adversarial NLI Benchmark (nie2019adversarial).
For simplicity and to cover dozens of tasks easily, we train on mixtures of the tasks listed rather than separately fine-tuning a model on each task. However, because the tasks vary in size considerably, equally sampling per the number of examples would over-sample large tasks and under-sample small ones. We therefore mix each task in proportion to the number of examples in its ‘train’ split (up to some max_num_examples=65536) as in raffel2019exploring. This means that tasks containing more than 65536 training examples are weighted as if they only contain max_num_examples.
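The capped, examples-proportional mixing described above can be sketched in a few lines. The function name is hypothetical, and the size given for `large_task` is an illustrative placeholder. The CB and WSC sizes are the training-set sizes quoted in this section.

```python
def mixing_rates(train_sizes, max_num_examples=65536):
    """Return per-task sampling rates proportional to capped train-split sizes."""
    # Tasks larger than the cap are treated as if they had exactly the cap.
    capped = {name: min(size, max_num_examples) for name, size in train_sizes.items()}
    total = sum(capped.values())
    return {name: size / total for name, size in capped.items()}


# CB and WSC sizes come from the text; "large_task" is a made-up example.
rates = mixing_rates({"CB": 250, "WSC": 259, "large_task": 400000})
```

Without the cap the large task would receive over 99.8% of the samples. With it, the small tasks are still rare but are sampled several times more often.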
|Name||Metric||Split||Dense-L (↑)||ST-MoE-L (↑)||Gain (%)|
|Closed Book TriviaQA||acc||dev||28.1||33.8|
|Closed Book NatQA||acc||dev||27.2||29.5|
|Closed Book WebQA||acc||dev||30.5||33.2|
Table 10 summarizes the quality of a dense T5-Large (L) model and sparse model with approximately the same number of FLOPs pre-trained for 500k steps with a 1M batch size (524B tokens) on the C4 dataset (raffel2019exploring). The sequence length for the encoder was 512 and 114 for the decoder. We observe improvements on the validation (dev) sets across a wide array of tasks examining natural language understanding, question answering, and summarization. As seen in fedus2021switch, striking gains are observed in closed book question answering (roberts2020much).
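As a quick consistency check on the quoted token count, the figure of 524B tokens matches 500k steps if "1M" is read as $2^{20}$ tokens per batch. That reading is an inference from the numbers, not something stated in the text.

```latex
\[
  500{,}000 \ \text{steps} \times 2^{20} \ \tfrac{\text{tokens}}{\text{step}}
  = 500{,}000 \times 1{,}048{,}576
  = 524{,}288{,}000{,}000 \approx 524\text{B tokens}
\]
```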
Also, in support of the overfitting hypothesis presented in Section 4.1, we observe that two of the smallest tasks, CB and WSC (250 and 259 training examples, respectively), are the only ones where the sparse model does not yield gains over its dense counterpart. This again suggests that improved forms of regularization for sparse models may unleash greater performance.
With quality validated at the scale of T5-Large, we seek to push the capabilities of sparse models through the ST-MoE-32B. When designing this, we sought a balance between FLOPs and parameters. High-FLOP sparse models were previously unstable in fedus2021switch in our setting (i.e. encoder-decoder models, Adafactor optimizer), but the router z-loss enabled us to proceed. For computational efficiency, we expanded the hidden size of the experts ($d_{ff}$ in Table 11 below). (Footnote 11: allreduce activation communications introduced through model parallelism are independent of the hidden size, but not the model dimension, making it a good choice to increase.) Finally, we increased the $d_{kv}$ to 128 for better performance on our hardware. The most salient changes are fewer overall parameters and more FLOPs per token relative to both Switch-C and Switch-XXL. Our ST-MoE-32B has “only” 269B parameters and is approximately FLOP-matched to a dense Transformer with 32B parameters. The reduced parameter count from Switch-C and Switch-XXL eases the burden for both serving and fine-tuning. Finally, we use the sparse-dense stacking described in Appendix C.
|Model||Num. Heads||Num. Layers||Num. Experts||Expert Layer Freq.||Sparse-Dense|
We pre-train for 1.5T tokens on a mixture of the English-only C4 dataset (raffel2019exploring) and the dataset from GLaM (du2021glam) summarized in Appendix E. We use 1M tokens per batch, the Adafactor optimizer with default hyperparameters, and a learning rate warm-up of 10k steps followed by inverse square root decay. Our model follows the initialization scheme proposed in fedus2021switch.
|Previous Best (↑)||Ours (↑)|
Table 12 evaluates our ST-MoE-32B model against previous state-of-the-art approaches using inference-only (zero-shot, one-shot) as well as fine-tuning. On SuperGLUE, our model improves upon the prior state-of-the-art model, achieving an average score of 91.2 on the test server (93.2 validation accuracy), which is over one percentage point beyond estimated human capability. For both summarization datasets, XSum and CNN-DM, our model achieves state-of-the-art without additional changes to training or fine-tuning (raffel2019exploring; liang2021rdrop). ST-MoE-32B improves the current state-of-the-art on the test server submissions for both ARC Easy (92.7 → 94.8) and ARC Challenge (81.4 → 86.5). On two of the three closed book QA tasks, we improve over the prior state-of-the-art. Closed book WebQA achieves a 47.4 accuracy (prior best of 42.8 from roberts2020much) and exceeds the zero-shot performance of the ERNIE 3.0 Titan 260B dense parameter model (wang2021ernie). Closed book NatQA improves to 41.9 accuracy (prior best of 41.5 from karpukhin2020dense). We find significant improvements on adversarially constructed datasets (ANLI R3 and WinoGrande XL). ANLI R3 (nie2019adversarial) improves the state-of-the-art to 74.7 (prior best of 53.4).
We note some weaknesses in our model. ST-MoE-32B has lackluster performance on the small SQuAD dataset, with an exact match score of 90.8, which falls short of the older benchmark of 91.3 set by the T5-XXL. Furthermore, while setting a new state-of-the-art for SuperGLUE in aggregate, certain tasks, including small ones like CB and WSC, fail to improve. Finally, on closed book Trivia QA, our model improves over the fine-tuned baseline with SSM from roberts2020much, but fails to produce gains over both GPT-3 and GLaM.
While not the focus of this paper, we present the quality differential between recent advances in inference-only techniques like few-shot learning and fine-tuning on these tasks (GPT-3 (brown2020language), GLaM (du2021glam) and Gopher (rae2021scaling)). As expected and observed previously, fine-tuning outperforms zero/one-shot learning, but has the disadvantage of requiring additional training and different models for each task.
7 Tracing Tokens Through the Model
Thus far we have presented quantitative measures and performance metrics. We change tack to explore qualitative features by visualizing how tokens are routed among the experts. We do so by passing a batch of tokens to the model and manually inspecting token assignment at each layer. We consider our ST-MoE-L model pre-trained either on the monolingual C4 corpus (raffel2019exploring) or on the multilingual mC4 corpus (xue2020mt5). On both the encoder and the decoder, the model has six sparse layers, each with 32 experts.
Preliminaries
The span corruption objective is to recover spans of variable-length contiguous segments masked out in the inputs.
This is formatted as:
Inputs: I went to <extra_id_0> to buy <extra_id_1>
Targets: <extra_id_0> the store <extra_id_1> milk
In our encoder-decoder architecture, the inputs will be passed to the encoder and targets to the decoder.
Each group of tokens is routed jointly with load balancing across experts incentivized by an auxiliary loss as proposed in shazeer2017outrageously (see Appendix A for details). Tokens compete for expert assignment against other tokens in their group, rather than the entire batch, and expert specialization is heavily influenced by the distribution of tokens in each group. The notion of groups is introduced to limit the cost of dispatching and gathering the correct tokens to the correct experts.
7.1 Encoder Experts Exhibit Semantic Specialization
Our first observation is that, at each layer, at least one expert specializes in sentinel tokens (mask tokens that represent blanks to fill in). Additionally, some encoder experts exhibit clear semantic specialization, with some experts primarily operating on punctuation, verbs, proper names, counting, etc. Table 13 presents a few notable examples of semantic specialization across encoder experts. And while we find many instances of specialization, these have been specifically extracted from many examples without a clear semantic specialization.
|Semantic specialization||Expert position||Routed tokens|
|Sentinel tokens||Layer 1||been <extra_id_4><extra_id_7>floral to|
|Punctuation||Layer 2||, , , , , , , , , - , , , , , ). )|
|Layer 6||, , , , , : . : , & , & & ? & - , , ? , , , . <extra_id_27>|
|Conjunctions and articles||Layer 3||The the the the the the the the the The the the|
|the the the The the the the|
|Layer 6||a and and and and and and and or and a and .|
|the the if ? a designed does been is not|
|Verbs||Layer 1||died falling identified fell closed left posted lost felt|
|left said read miss place struggling falling signed died|
|falling designed based disagree submitted develop|
|Visual descriptions||Layer 0||her over her know dark upper dark outer|
|color, spatial position||center upper blue inner yellow raw mama|
|bright bright over open your dark blue|
|Proper names||Layer 1||A Mart Gr Mart Kent Med Cor Tri Ca Mart|
|R Mart Lorraine Colin Ken Sam Ken Gr Angel A|
|Dou Now Ga GT Q Ga C Ko C Ko Ga G|
|Counting and numbers||Layer 1||after 37 19. 6. 27 I I Seven 25 4, 54 I two dead we|
|written and numerical forms||Some 2012 who we few lower each|
7.2 Decoder Experts Lack Semantic Specialization
In contrast, expert specialization is far less noticeable in the decoder. Not only are sentinel tokens routed somewhat uniformly across decoder experts (see Table 14), but we also do not observe semantically meaningful specialization in decoder experts.
We hypothesize that this lack of semantically meaningful expert specialization is caused by the distribution of target tokens induced by the span corruption objective. In particular, (a) a smaller number of tokens are routed jointly in the decoder due to longer sequence lengths in the encoder (e.g. group size is 2048 in the encoder vs 456 in the decoder in our setup) and (b) a higher proportion of tokens are sentinel tokens in the decoder. As a result, target tokens in each group typically cover a smaller semantic space (compared to the encoder), perhaps explaining the lack of semantic expert specialization in the decoder. This intricate interplay between the architecture and the training objective invites further research on better leveraging sparsity and expert specialization in the decoder. Alternatively, future work could study simply removing the experts in the decoder layer, which also confers benefits during autoregressive decoding (kudugunta2020exploring).
|Layer 1||Layer 2||Layer 3||Layer 4||Layer 5||Layer 6||Uniform (32-experts)|
We support our qualitative observation that encoder experts specialize, but decoder experts do not, by computing the entropy over the routing for sentinel tokens. The encoder routing entropy is low, but the decoder routing entropy is high and nearly equal to that of uniform routing. Because each layer has 32 experts, a completely uniform distribution has an entropy of 3.5.
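The uniform reference value can be reproduced with a short entropy calculation. The sketch below assumes the entropy is measured in nats (natural logarithm), which is the reading consistent with the quoted value of 3.5 for 32 experts. The function name and the example counts are hypothetical.

```python
import math


def routing_entropy(expert_counts):
    """Entropy (in nats) of the empirical distribution of tokens over experts."""
    total = sum(expert_counts)
    # Experts that received no tokens contribute nothing to the sum.
    return -sum(
        (count / total) * math.log(count / total)
        for count in expert_counts
        if count > 0
    )


# Sentinel tokens spread evenly over 32 experts versus all sent to one expert.
uniform_entropy = routing_entropy([10] * 32)
specialized_entropy = routing_entropy([320] + [0] * 31)
```

A router that sends every sentinel token to a single expert has zero entropy, while perfectly even routing over 32 experts reaches the maximum of $\ln 32$.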
7.3 Multilingual Experts Specialize Semantically, But Not by Language
We next consider a multilingual sparse model pretrained on a mixture of different languages and inspect the expert specialization in the encoder. As in the monolingual case, we find strong evidence of semantically meaningful expert specialization. Table 15 presents some examples of experts specializing in sentinel tokens, numbers, conjunctions & articles and proper names.
|Semantic specialization||Routed tokens|
|Sentinel tokens||to <extra_id_6>to til <extra_id_9>|
|Numbers||$50 comment .10.2016 ! 20 20 3 ! 5 1. ! 91 ? né ?|
|2 17 4 17 11 17 8 & 11 & 22:30 02 2016. ) iOS|
|Conjunctions & Articles||of of of their their of any this this your your am von|
|this of Do of of This these our 的 的 于 的 在 的 在 的|
|le les Le la di la sur sur 136 sur の の する の という の し|
|Prepositions & Conjunctions||For for or for for or for from because https during https|
|并 与 和 par c Pour à a par trè pour pour pour pour pour c と や のに|
|で で で なので - and and + c between and and|
|Proper names||Life Apple iOS A IGT 众 莫 HB|
|F HB A K A OPP OK HB A Gia C Gia C P Scand Wi|
|G H Z PC G Z ハイ PC G Ti CPU PC PC A キット OS|
One might expect experts to specialize in languages, which appears to be a natural criterion for divvying up batches of data among experts. However, we find no evidence of language specialization (see Table 15). Routers instead pass tokens from English, Japanese, French and Chinese indiscriminately and the experts appear to be multilingual. But this lack of language specialization is less surprising when considering the mechanism of token routing and load balancing. Since each group of tokens may contain only one, or at most a few, languages (a group usually consists of 2-4 sequences in our setup), all experts are encouraged to handle tokens from all languages. We experimented with a global load balance loss; however, this usually results in worse load balance and worse model performance, so we leave further improving multilingual expert models as an area of open work (Section 9).
Our visualization reveals apparent specialization learned in our models (Tables 13, 15) for the encoder layers. Other expert specializations were also observed in the appendix of shazeer2017outrageously. However, this leads to an interesting question of how architectures that eliminate learned routing (roller2021hash; zuo2021taming) appear to perform well. An extensive study of the scaling properties of learned versus random routing could prove helpful as future work and help guide us to a better understanding of routing behavior.
8 Related Work
Mixture-of-Experts (MoE) models date back at least three decades to the work of jacobs1991adaptive; jordan1994hierarchical. In initial concepts, the MoE defined the entire neural network, akin to ensemble methods. But later eigen2013learning extended the idea by including MoE as a component of deeper networks. shazeer2017outrageously then scaled this idea to a 137B parameter model to achieve state-of-the-art in machine translation. Most of the later work (including ours) follows this MoE-as-a-component approach.
Scale in natural language processing. The remarkable success of scale in natural language processing (kaplan2020scaling; brown2020language) has reinvigorated MoE research, as evidenced by a surge of recent work (lepikhin2020gshard; fedus2021switch; yang2021m6t; kim2021scalable; du2021glam; artetxe2021efficient; zuo2021taming; clark2022unified). Sparse expert models have been proposed as a method to achieve the results of large-scale dense models more efficiently. fedus2021switch showed a 4x pre-train speed-up over T5-XXL (raffel2019exploring) and du2021glam matched the quality of GPT-3 (brown2020language) using only 1/3 of the energy. And in the span of the last twelve months, a milestone of efficiently training trillion parameter deep neural networks has been achieved by multiple groups (fedus2021switch; yang2021m6t; du2021glam), and most recently, lin2021m610t introduced techniques to train a 10T parameter model. One side note is that the recent significant successes of sparse expert models have often been in settings with a lot of data and no distribution shift, two examples being language modeling/span corruption and machine translation (shazeer2017outrageously; lepikhin2020gshard; kim2021scalable; fedus2021switch). In contrast, discrepancies between strong pre-training quality and poor fine-tuning quality for sparse models have been observed in fedus2021switch; narang2021transformer; artetxe2021efficient, but we expect advances in regularization techniques to continue to improve downstream quality.
Towards better routing algorithms. BASE layers (lewis2021base) recast token routing as a linear assignment problem, removing the need for load balancing auxiliary losses. This work also demonstrated the efficacy of a single expert layer. clark2022unified studies in depth the scaling properties of a few different routing algorithms and proposes their own variant of BASE layers that uses an optimal transport formulation. yang2021m6t introduces the M6-T architecture and expert prototyping, which splits experts into $k$ different groups and applies $k$ top-1 routing procedures (contrasting with the top-$k$ routing commonly used elsewhere). hazimeh2021dselectk proposed a continuously differentiable sparse gate with demonstrated improvements over vanilla top-$k$ gating. Other work (bengio2016conditional) considered casting the routing selection as a reinforcement learning problem. More radical versions remove learning the routing entirely. Hash layers (roller2021hash) show that random fixed routing (per hash functions) leads to competitive performance with learned routing. zuo2021taming also proposed an algorithm which randomly selects experts during training and inference and found gains of 2 BLEU points over Switch Transformers and competitive scores with the larger models of kim2021scalable. Finally, fan2021beyond designs an architecture with explicit language-specific sublayers (rather than allowing arbitrary routing as done in lepikhin2020gshard) to yield gains of +1 BLEU.
Sparse expert models in other modalities. MoE and sparse expert models have also advanced results in modalities aside from language. riquelme2021scaling designed a 15B parameter V-MoE to match state-of-the-art ImageNet (deng2009imagenet) models with fewer computational resources. lou2021sparsemlp similarly showed a benefit over dense vision models by using MoE layers across both image patch and channel dimensions. Additionally, Automatic Speech Recognition has been improved by the SpeechMoE variants (you2021speechmoe; you2021speechmoe2). kumatani2021building reduced word error rates using MoE models in Sequence-to-Sequence Transformer and Transformer Transducer.
Improving deployment of sparse models. Initial expert designs (including this work) route each token separately to experts at that layer. One issue is that these types of architectures may be burdensome to serve since they require sufficient memory for storing the parameters. Distillation was shown in fedus2021switch to be moderately effective, but recent approaches modified the routing to instead route full sentences or tasks (kudugunta2021beyond; zuo2021taming), which then permits extraction of sub-networks at time of serving (e.g. deploy only the network associated with the new task). As an alternative to distillation, kim2021scalable considers directly pruning away experts not essential to the task of interest.
Multitask learning with MoE. We conclude our tour of recent MoE research with successes in multitask settings. ma2018modeling recommended using a separate gating or router network for each task, an idea that may soon be revisited for Transformer architectures. Finally, gururangan2021demix recommends even greater modularity of language models and conditionally activates experts based on the domain/task label or by an inferred label.
While this work is on sparse models, these models intersect with many other interesting topics in machine learning such as adaptive computation, low-precision training, scaling principles, and neural network architecture advances. Our discussion therefore covers a broader range of topics surfaced during this research.
Unpredictable dynamics when pre-training on multilingual data.
We often observe that the same model pre-trained on multilingual data will yield smaller pre-training speed-ups and be more unstable. One hypothesis is that this is due to the variance of sequences per group across batches. As a reminder, we encourage tokens in a group to be load-balanced. There are usually only 2-8 sequences per group (higher becomes expensive) where each sequence is written in a single language. Therefore, at most 2-8 languages must be balanced across experts, even when training with over 100 languages. This leads to high variance across groups and batches, resulting in chaotic and unpredictable routing. In a follow-up experiment (just highlighted for brevity), we pre-trained on a mixture of English C4 plus a small fraction of a fine-tuning task, which similarly resulted in an unstable model.
The robustness of sparse models.
Although this paper focuses on the particulars of sparse models, zooming out we find them to be robust to a wide set of hyperparameters and architectural changes. Sparse models obtain great performance under a variety of routing algorithms, dropping high fractions of tokens, and different hyperparameters. While we did point out the importance of tuning the batch size and learning rate for fine-tuning, our intuition, in line with kaplan2020scaling, is that the real winner is scale. For instance, Table 8 shows larger gains to be had by simply increasing the capacity factor (i.e. FLOPs) rather than by more sophisticated routing (i.e. algorithms).
Sparse models are a subclass of adaptive computation models since each input gets different computation applied to it. In sparse models a token is routed to the expert(s) of its choosing. When capacity factors are less than one, the model learns to not apply computation to certain tokens. This has shown promise in computer vision (riquelme2021scaling) and our language experiments (Appendix D). We envision future models expanding this through heterogeneous experts (e.g. each expert applies differing computation). Intuitively, different input examples will likely require different amounts of processing depending on difficulty. Future models in this direction will be efficiently enabled through emerging computing infrastructures (dean2021pathways).
Generalizing findings from small to large scale.
A key issue we faced throughout our work was identifying small scale models and training setups that reflect larger scale experiments. This was evident in our stability studies in Section 3 where experiments had to be run with XL sized models to surface relevant dynamics. For our architecture and routing algorithm experiments, we often find improvements vanish, or even reverse, when models are trained for longer or made larger. As one example, the top-$n$ findings of fedus2021switch were reversed in our 8x larger-scale experiments presented here, which revealed small boosts of top-$(n+1)$ routing over top-$n$ routing (see Table 8).
Training models with even lower precision.
The best method we found to stabilize our models without hurting (and sometimes improving) quality was the router z-loss. This is an auxiliary loss that encourages the model logits to have values smaller in absolute magnitude. Given the max range of numbers float32 and bfloat16 can support (roughly $3 \times 10^{38}$), this leads us to believe most of this range is not needed, and compressing it actually might improve model training dynamics. Therefore, future precision formats might take into account more compressed exponential ranges to train certain classes of models.
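To make the effect on logit magnitudes concrete, the following is a minimal sketch of a router z-loss. It assumes the loss takes the form of the squared log-sum-exp of each token's router logits, averaged over tokens. The function name and the example logits are hypothetical, and any scaling coefficient is omitted.

```python
import math


def router_z_loss(router_logits):
    """Mean over tokens of the squared log-sum-exp of that token's router logits."""
    total = 0.0
    for logits in router_logits:
        # Subtract the max before exponentiating for numerical stability.
        largest = max(logits)
        log_sum_exp = largest + math.log(sum(math.exp(x - largest) for x in logits))
        total += log_sum_exp ** 2
    return total / len(router_logits)


# Two experts with identical preferences but different logit magnitudes.
small_logits_loss = router_z_loss([[0.0, 0.0]])
large_logits_loss = router_z_loss([[10.0, 10.0]])
```

Both inputs give the same routing probabilities after the softmax, but the larger logits incur a much larger penalty, so minimizing the loss pushes the router toward smaller-magnitude logits.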
Designing new operations with more multiplicative interactions.
Section 3.1 shows that operations with more multiplicative interactions than additions, or those that don’t accumulate over many numbers, improve model performance. We test this further by injecting more multiplicative interactions into expert layers, which speeds up pre-training by 4% without any change to step-time (Appendix C). We think this hints at promising architectural improvements for models and could be a good design principle. Recently depthwise convolutions, which only accumulate 3-5 elements, have also been shown to greatly improve Transformer performance (so2021primer). These operations are especially exciting as elementwise multiplications typically do not introduce any communication overhead when using model parallelism (which makes operations like depthwise convolutions and our multiplicative interactions very efficient). While we did note these methods to increase model instabilities in Section 3.1, using the router z-loss in our models prevented any further instabilities.
Constrain activations to alleviate other undesirable model scaling dynamics.
We observed two additional sources of training instability. (1) Encoder-decoder models are more unstable than decoder-only models (for a fixed amount of FLOPs). Encoder-decoder models have a higher ratio of attention layers (e.g. more exponential functions) due to having both self-attention and enc-dec attention layers for each FFN on the decoder. (2) Deeper models are more unstable than shallower models for a fixed amount of FLOPs. Deeper models also introduce more exponential functions through additional attention layers. We hypothesize that a contributing factor to both of these observations is simply the increased number of exponential functions found in the network. Future work could look at resolving these training dynamics by adding z-loss penalties to the attention softmaxes for non-sparse models, especially since we observed adding them didn’t change model quality.
Dense and sparse models depend differently on hyperparameters.
Our fine-tuning analysis in Section 4.3 shows optimal fine-tuning hyperparameters differ significantly between dense and sparse models. In certain settings, fine-tuning hyperparameters that worked well for the dense model masked any improvements from the sparse model (despite large pre-training speedups). For new model classes, we recommend that researchers and practitioners extensively test key hyperparameters before prematurely abandoning a method.
We temper the over-exuberance for scale in fedus2021switch by showing how a model with 1/5th the size, but with a better balance of computation (FLOPs) to parameters, is a more effective sparse learner. Furthermore, this improves the usability of sparse models since they can be deployed with less memory overhead. Using our sparse model variant, we achieve SOTA across a wide range of the most competitive public benchmarks. We hope this work shows the power of model sparsity and accelerates the adoption of such models.
We would like to thank Alex Passos, Ekin Cubuk, Margaret Li, Noah Constant, Oriol Vinyals, Basil Mustafa, Joan Puigcerver, Diego de Las Casas, Mike Lewis, and Ryan Sepassi for detailed comments and feedback on early versions of the draft. We also thank the Google Brain Team for useful discussions throughout the course of this work.
Appendix A Token Load Balance Description
The auxiliary load balancing loss from shazeer2017outrageously is also used here to balance tokens across experts. Assume we have $N$ experts indexed by $i = 1$ to $N$ and a batch $\mathcal{B}$ with $T$ tokens. The auxiliary loss is computed as the scaled dot-product between vectors $f$ and $P$,

$$\text{loss} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i \tag{7}$$

where $f_i$ is the fraction of tokens dispatched to expert $i$,

$$f_i = \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbb{1}\{\operatorname{argmax}\, p(x) = i\}$$

and $P_i$ is the fraction of the router probability allocated for expert $i$,

$$P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x).$$

(A potential source of confusion: $p_i(x)$ is the probability of routing token $x$ to expert $i$. $P_i$ is the probability fraction to expert $i$ across all tokens in the batch $\mathcal{B}$.)
Since we seek uniform routing of the batch of tokens across the $N$ experts, we desire both vectors to have values of $1/N$. The auxiliary loss of Equation 7 encourages uniform routing since it is minimized under a uniform distribution. The objective remains differentiable because the $P$-vector is differentiable, even though the $f$-vector is not. The final loss is multiplied by the expert count $N$ to keep the loss constant as the number of experts varies, since under uniform routing $\sum_{i=1}^{N} (f_i \cdot P_i) = \sum_{i=1}^{N} \left(\frac{1}{N} \cdot \frac{1}{N}\right) = \frac{1}{N}$. Finally, a hyperparameter $\alpha$ is a multiplicative coefficient for these auxiliary losses; throughout this work we use a value of $\alpha$ that was sufficiently large to ensure load balancing while small enough not to overwhelm the primary cross-entropy objective.
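To make the loss concrete, the following is a minimal pure-Python sketch of the computation described above: a coefficient times the expert count times the dot product of the dispatch fractions and the mean router probabilities. The function name `load_balance_loss`, the list-of-lists input layout, and the default coefficient are assumptions made for illustration and are not taken from our training code.

```python
def load_balance_loss(router_probs, alpha=0.01):
    """Auxiliary loss alpha * N * sum_i f_i * P_i for one batch.

    router_probs: list of T rows, one per token; each row holds N routing
                  probabilities (hypothetical layout for this sketch).
    alpha: illustrative coefficient, an assumption of this sketch.
    """
    num_tokens = len(router_probs)
    num_experts = len(router_probs[0])

    # f_i: fraction of tokens whose top-1 expert is i (argmax, not differentiable).
    f = [0.0] * num_experts
    for probs in router_probs:
        top1 = max(range(num_experts), key=lambda i: probs[i])
        f[top1] += 1.0 / num_tokens

    # P_i: mean router probability assigned to expert i (differentiable).
    P = [sum(probs[i] for probs in router_probs) / num_tokens
         for i in range(num_experts)]

    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, P))
```

With `alpha=1.0`, perfectly uniform routing gives a loss of 1 regardless of the number of experts, while a batch in which every token prefers the same expert gives a larger value.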
Appendix B Router Z-Loss Training Dynamics
Appendix C Improved Architectural Modifications
We consider a few small architecture variations here. The first modification was adding additional FFN layers (feed-forward network, see Table 1 for more details) immediately before or after each MoE layer (referred to as Sparse-Dense). Table 16 reveals the effectiveness of an FFN layer immediately preceding or following each sparse layer and that these extra FFN layers help less when added elsewhere in the network. Guaranteeing all tokens have at least one FFN applied to them between each attention layer appears useful.
|Model||Neg. Log Perp. (↑)||Δ|
|Dense model (baseline)||-1.474||-|
|Dense model w/ extra FFN layers||-1.452||0.022|
|Sparse model (baseline)||-1.383||-|
|Sparse model w/ extra FFN layer after each sparse layer||-1.369||0.014|
|Sparse model w/ extra FFN layer before each sparse layer||-1.369||0.014|
|Sparse model w/ extra FFN layers placed randomly in the network||-1.376||0.007|
Second, we introduce an additional bias in the expert layers. All our models use the GELU-Linear FFN [shazeer2020glu], rather than the ReLU FFN:

$$\mathrm{FFN}_{\mathrm{GEGLU}}(x) = \left(\mathrm{GELU}(xW_{11}) \odot xW_{12}\right) W_2$$

The additive bias is a learned weight ($B$) added after the first matrix multiplication in the FFN layer. The multiplicative bias (also referred to as a scale parameter) is a learned weight of the same shape as the additive bias, but is applied through elementwise multiplication. We initialize the additive bias to zeros and the multiplicative bias to ones.
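As an illustration of where such biases could enter a GELU-Linear FFN, here is a small pure-Python sketch for a single token. It assumes the bias is a per-feature vector applied to the gated hidden activation just before the output projection; that placement, the function names, and the weight names `w11`, `w12`, `w2` are assumptions of this sketch rather than a statement of the exact implementation.

```python
import math


def gelu(v):
    """Exact GELU: v * Phi(v), with Phi the standard normal CDF."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def vecmat(x, w):
    """Row vector x (length d_in) times matrix w (d_in rows, d_out columns)."""
    return [sum(xi * row[j] for xi, row in zip(x, w)) for j in range(len(w[0]))]


def ffn_geglu(x, w11, w12, w2, add_bias=None, mult_bias=None):
    """GELU-Linear FFN for one token, with optional hidden-layer biases.

    add_bias / mult_bias: optional vectors of length d_ff (assumed shape).
    """
    gate = [gelu(v) for v in vecmat(x, w11)]        # GELU branch
    linear = vecmat(x, w12)                         # linear branch
    hidden = [g * l for g, l in zip(gate, linear)]  # elementwise product
    if add_bias is not None:
        hidden = [h + b for h, b in zip(hidden, add_bias)]
    if mult_bias is not None:
        hidden = [h * s for h, s in zip(hidden, mult_bias)]
    return vecmat(hidden, w2)                       # output projection
```

Initializing `add_bias` to zeros and `mult_bias` to ones leaves the output identical to the bias-free layer, which is why those initial values do not perturb the model at the start of training.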
Table 17 shows the results of our different methods. Both the additive and multiplicative biases are essentially free: they are cheap to compute, add few new parameters, and incur no additional communication costs with model and expert parallelism. When using our router z-loss from Section 3.1, we observe no instabilities from the multiplicative bias. The multiplicative interactions improve performance, achieving a 4% speedup in convergence time over our strong sparse baseline. This hints that a promising avenue for future architectural research is finding new ways of adding more multiplicative interactions into networks.
|Model||Neg. Log. Perp. (↑)||Δ|
|Sparse Additive Bias||-1.371||-0.002|
|Sparse Multiplicative Bias||-1.361||0.008|
Finally, motivated by the work of roller2021hash, we explored similar methods, but did not find improvements in our setting. We tried routing on the word embedding exclusively, as well as using it as an additional input alongside the layer's hidden activation for routing decisions. We toggled between stopping the gradient through the word embedding and allowing gradients to propagate to it from the router. Using only the word embedding hurt quality, while using it in addition to the normal layer hidden activation was initially positive, but after pre-training for 50B+ tokens on models of scale 1B+ dense parameters it had a neutral effect. Appendix J has further details on the experiments with negative results.
Appendix D Batch Prioritized Routing for Lower Capacity Factors
Surprisingly, top-1 and top-2 routing work well with CF less than 1.0 despite token routing being done in a left-to-right order over the sequence. If $N$ tokens are sent to an expert with only $M$ spaces, where $N > M$, then $N - M$ tokens will be dropped. The ordering of the dropping is important: we drop tokens going left to right (e.g. tokens earlier in the sentence will be routed first over the end tokens). This is done to avoid the model cheating. If we dropped tokens in another ordering, the model would get information about which tokens occur later in the sequence based on whether tokens are being dropped or not.
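The left-to-right dropping rule can be sketched in a few lines of pure Python. The helper names are hypothetical, and computing capacity as the floor of tokens per expert times the capacity factor is an assumption of this sketch (the rounding rule in particular).

```python
import math


def expert_capacity(num_tokens, num_experts, capacity_factor):
    """Slots per expert; the floor rounding is an assumption of this sketch."""
    return math.floor(num_tokens / num_experts * capacity_factor)


def route_left_to_right(assignments, capacity):
    """Keep tokens in sequence order until each expert's capacity is full.

    assignments[t] is the top-1 expert chosen for token t.
    Returns a list of booleans: True if token t is routed, False if dropped.
    """
    used = {}  # expert index -> slots already taken
    keep = []
    for expert in assignments:
        count = used.get(expert, 0)
        if count < capacity:
            used[expert] = count + 1
            keep.append(True)
        else:
            keep.append(False)  # expert is full: later tokens are dropped
    return keep
```

For example, with six tokens assigned to experts `[0, 0, 1, 0, 1, 0]` and a capacity of 2, the third and fourth tokens sent to expert 0 are dropped because they arrive after its two slots are taken.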
Batch Prioritized Routing (BPR) from riquelme2021scaling was introduced in Vision Transformers [dosovitskiy2020image] for image classification. Our work explores BPR with top-1 routing in the context of language modeling. BPR aims to have a global view of all tokens to determine which tokens should be dropped instead of the left-to-right ordering. The algorithm works by looking at all $N$ tokens getting sent to expert $i$ and then only routing the $M$ ones with the highest probabilities from the router, where $M$ is the capacity of the expert. Table 18 shows that BPR top-1 routing improves performance over top-2 routing, especially when capacity factors are less than 1.0. We leave it to future work to try top-$n$ BPR routing, which will hopefully yield larger improvements for higher capacity factors.
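A minimal sketch of BPR with top-1 routing, under the assumption that each token carries a single top-1 expert index and its router probability, might look as follows. The function name and input layout are hypothetical; ties are broken by sequence order, which is a choice made only for this illustration.

```python
def route_batch_prioritized(assignments, top1_probs, capacity):
    """Per expert, keep only the `capacity` tokens with the highest router prob.

    assignments[t]: top-1 expert of token t.
    top1_probs[t]:  router probability of that expert for token t.
    Returns a list of booleans: True if token t is routed, False if dropped.
    """
    keep = [False] * len(assignments)

    # Group token positions by the expert they were assigned to.
    by_expert = {}
    for t, expert in enumerate(assignments):
        by_expert.setdefault(expert, []).append(t)

    for tokens in by_expert.values():
        # Stable sort by descending probability; ties keep sequence order.
        ranked = sorted(tokens, key=lambda t: -top1_probs[t])
        for t in ranked[:capacity]:
            keep[t] = True
    return keep
```

Unlike left-to-right dropping, the decision for an early token here depends on the router probabilities of later tokens, which is why this global view is only safe where all tokens may see each other.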
Importantly, BPR routing can only be done on the encoder side of the encoder-decoder model. On the encoder side there are no autoregressive predictions and all tokens can see each other. If BPR is used on the decoder, the model learns to cheat by using future token information to improve current token predictions.
|Algorithm||Train CF||Eval CF||Neg. Log. Perp. (↑)|
Appendix E Pre-Training Dataset Details
The pre-training dataset used to train our Sparse 32B model is a mix of C4 [raffel2019exploring] and the dataset introduced in GLaM [du2021glam].
|Dataset||Tokens (B)||Weight in Mixture|
Appendix F Full Fine-tuning Sensitivity Data
Table 20 contains the raw data for Figure 6 measuring the fine-tuning protocol sensitivity. Dense and Sparse are encoder-decoder models FLOP matched to T5-Large that were pre-trained for 500k steps with a batch size of 1M tokens on the C4 corpus.
|Model||Learning Rate||Batch Size||Reset Optimizer Slot Vars||SuperGLUE (↑)|
Appendix G Optimally Setting the Routing Threshold
Top-$n$ Routing Algorithm
1. Route each token $x$ to the expert with the highest router probability ($g_1(x)$).
2. Normalize the top-$n$ expert router scores for each token $x$, so $\sum_{i=1}^{n} g_i(x) = 1$.
3. Route the token to the other $n-1$ experts (indexed by $i$) with probability $\min(1, g_i(x) / \text{threshold})$. Threshold is a predefined hyperparameter that is typically set to 0.2.
We describe the MoE hyperparameters and how they should change as the routing algorithm changes. The MoE top-2 routing algorithm [shazeer2017outrageously, shazeer2018mesh, lepikhin2020gshard] works as follows: first the router finds the expert that is assigned the highest router score ($g_1$) and always sends the token to that expert. The token is also sent to its second highest expert with probability $\min(1, g_2 / \text{threshold})$. The threshold is a hyperparameter that is typically set to 0.2, and $g_2$ is the token’s router probability for the second highest expert. Note that $g_1$ and $g_2$ get normalized by the sum of their two scores, so they sum to one.
We trivially extend the top-2 algorithm to work for top-$n$ routing here. Take the scores of the top-$n$ experts per token and sum them, then renormalize each expert router score based on that sum. If a renormalized expert score $g_i$ is higher than the threshold (e.g. 0.2), the token is routed to that expert; otherwise it is routed with probability $g_i / \text{threshold}$. At a high level this only routes the token to the next $n-1$ experts if their scores are not too much lower than the highest scored expert.
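The top-$n$ procedure above can be sketched for a single token as follows. This is an illustrative pure-Python version, not the Mesh-Tensorflow implementation: the function name, the return format, and the injectable `rng` argument (anything with a `random()` method returning a value in [0, 1)) are assumptions of this sketch.

```python
import random


def top_n_route(router_probs, n=2, threshold=0.2, rng=random):
    """Return [(expert_index, renormalized_score), ...] for one token.

    The top expert is always used. Each of the next n-1 experts is used for
    certain if its renormalized score reaches `threshold`, and otherwise
    with probability score / threshold.
    """
    # Indices of the n highest-scoring experts, best first (ties keep index order).
    ranked = sorted(range(len(router_probs)), key=lambda i: -router_probs[i])[:n]

    # Renormalize the top-n scores so they sum to one.
    total = sum(router_probs[i] for i in ranked)
    scores = [router_probs[i] / total for i in ranked]

    routed = [(ranked[0], scores[0])]
    for expert, score in zip(ranked[1:], scores[1:]):
        if score >= threshold or rng.random() < score / threshold:
            routed.append((expert, score))
    return routed
```

For router probabilities `[0.5, 0.3, 0.15, 0.05]` with `n=2`, the renormalized scores are 0.625 and 0.375, so the second expert clears the 0.2 threshold and is always used. For a peaked distribution such as `[0.9, 0.05, 0.05]`, the second expert's renormalized score is about 0.05 and it is only used roughly a quarter of the time.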
For top-3 routing vs top-2, the sum that the expert scores are normalized by is larger, therefore we experimented with decreasing the threshold. Our experimental results are shown in Table 21. Interestingly, we do observe the top-3 routing to slightly benefit from the lower threshold, while the opposite is true for top-2 routing.
We also experimented with an absolute threshold policy instead of a relative one. Under this policy, the token is routed to the next $n-1$ experts only if their router score is greater than some pre-defined value (e.g. 0.2). We found it can achieve comparable performance if the threshold value is tuned.
|Algorithm||Train CF||Threshold||Neg. Log. Perp. (↑)|
Appendix H Mesh Layout for Data, Model and Expert Parallelism with Few Experts
We use data and model parallelism partitioning with Mesh-Tensorflow [shazeer2018mesh]. The partitioning strategy works by first forming a logical 2D mesh of size $d \times m$, with the rows corresponding to the data dimension ($d$) and the columns as the model dimension ($m$) and the product equal to the total number of cores, $n = d \times m$. This mesh is only an abstraction. Each logical core must be mapped to a physical core, which is optimized through performance tuning.
As a refresher, each row in the mesh will have its own unique slice of the data and each column will have a unique slice of the model weights. The final gradient allreduce communication occurs across each individual column. The model parallelism allreduce communications occur across each row in the mesh. One constraint from this approach is that the number of rows must evenly divide the number of data sequences and the number of columns must evenly divide the model dimensions being partitioned.
But if we have fewer than $d$ experts then this layout will not work. To allow for fewer experts than data parallelism rows in our mesh, we factorize the data dimension into two new dimensions: inner ($i$) and outer ($o$) where $i \times o = d$ and the number of experts equals $i$. This transforms the logical 2D mesh of shape $d \times m$ into a 3D mesh of shape $o \times i \times m$. See Figure 8 for a visualization of both meshes. (See Mesh Tensorflow for more details on the inner and outer batch: https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py.)
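The factorization of the data dimension can be illustrated with a small helper that derives a logical mesh shape from the core count, the model-parallel width, and the number of experts. This is a sketch under stated assumptions: the function name is hypothetical, it does not call Mesh-Tensorflow, and it assumes the inner dimension is set equal to the number of experts while the outer dimension takes the remaining data-parallel factor.

```python
def logical_mesh_shape(num_cores, model_dim, num_experts):
    """Return (data, model) or (outer, inner, model) for the logical mesh.

    num_cores:   total number of cores.
    model_dim:   number of mesh columns used for model parallelism.
    num_experts: number of experts in each sparse layer.
    """
    if num_cores % model_dim != 0:
        raise ValueError("model_dim must evenly divide num_cores")
    data_dim = num_cores // model_dim

    # Enough experts to cover every data-parallel row: the 2D mesh suffices.
    if num_experts >= data_dim:
        return (data_dim, model_dim)

    # Fewer experts than rows: split the data dimension into outer x inner,
    # with the inner dimension equal to the number of experts.
    if data_dim % num_experts != 0:
        raise ValueError("num_experts must evenly divide the data dimension")
    inner = num_experts
    outer = data_dim // inner
    return (outer, inner, model_dim)
```

For instance, 64 cores with a model dimension of 4 give 16 data-parallel rows; with only 8 experts, the sketch returns a 3D mesh of shape (2, 8, 4).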
Appendix I Note on Communication Costs for Distributed Models
Communication operations (allreduce and all2all) can significantly impact sparse model training throughput (see Table 1 for a description of the communication operations). allreduce calls are executed along model and batch dimensions, typically dominated by the model dimension allreduce calls that sum results of partial matrix multiplication operations from the workers. These calls are needed when matrix multiplications are partitioned across multiple cores (e.g. model parallelism). The gradient summation allreduce calls can be amortized away by training models with larger batch sizes since the gradient accumulation allreduce communication cost is independent of the batch size. To alleviate the memory issues of larger batch sizes, microbatches can be used. Microbatches do this by splitting the batch into evenly divisible chunks and computing gradients on each sequentially, then summing.
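The microbatching idea can be checked on a toy problem. The sketch below uses a one-parameter linear model with a summed squared-error loss (both chosen purely for illustration, with hypothetical function names) and shows that computing gradients on each microbatch sequentially and summing them reproduces the full-batch gradient, so only one gradient summation across workers would be needed per full batch.

```python
def grad_sum_squared_error(w, xs, ys):
    """d/dw of sum_k (w * x_k - y_k)^2 for a scalar weight w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))


def accumulate_microbatch_grads(w, xs, ys, num_microbatches):
    """Split the batch into equal chunks, compute each gradient, then sum."""
    if len(xs) % num_microbatches != 0:
        raise ValueError("batch must split evenly into microbatches")
    size = len(xs) // num_microbatches
    total = 0.0
    for start in range(0, len(xs), size):
        # Only one microbatch's activations need to be live at a time.
        total += grad_sum_squared_error(
            w, xs[start:start + size], ys[start:start + size])
    return total
```

Because the loss is a sum over examples, the accumulated gradient matches the full-batch gradient up to floating-point rounding; the memory saving comes from processing one chunk at a time.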
To increase the allreduce throughput, more workers may need to be assigned to the model dimension (instead of batch dimension). However, increasing the number of workers may reduce compute per worker resulting in higher communication overheads that cancel some of the gains from higher communication throughput from allreduce. For the results in this paper, first we explored various model partitioning strategies. Next the shapes of the pre-training jobs were allocated based on performance benchmarking which showed the lowest cumulative communication overheads in allreduce and all2all.
Appendix J Negative Results
We conclude with some ideas that yielded negative results in our setting.
Adding information about whether tokens were dropped to the router.
We experimented with having the expert layer have information of whether the token was routed or dropped in the previous expert layers. We implemented this through counting the number of times a token was routed in all previous expert layers, having embeddings for each possible value and then adding this to the router embedding. We found that this made no difference in performance.
Adding explicit expert positional information.
We experimented with adding explicit positional information into the outputs of the expert layer. We wanted to see if it either improved performance or sped up convergence during the beginning of training when expert layers were drastically changing. We did this by adding an embedding corresponding to which expert each token was sent to (including an embedding if the token was dropped), but this did not improve performance.
Adding pre-training noise to fix pre-training and fine-tuning discrepancies.
To help fix the pre-training perplexity and fine-tuning gap we tried pre-training the sparse models with a variety of different types of noise. The goal was to help pre-training match the fine-tuning conditions where dropout is used and more tokens can be dropped. Some of the noise types we tried adding during pre-training were dropout, dropping out full experts for a batch of tokens, and adding an entropy maximization auxiliary loss to the router. Unfortunately, all of the methods either hurt the pre-training quality too much or didn’t end up helping the fine-tuning.
Load balancing in top-n routing over lower n-1 experts.
In the standard top-$n$ MoE formulation there is only load balancing over the top expert a token is sent to. We experimented with adding an auxiliary load balancing term to the other experts in top-$n$ routing, but found this to provide minimal benefits.
Mixing pre-training and fine-tuning data to prevent overfitting.
To help combat the overfitting of sparse models during fine-tuning, we tried mixing in pre-training span corruption data at varying amounts (e.g. 1%, 5%, 25%, …) during fine-tuning. This ended up not helping the fine-tuning performance, but did increase the training loss.