Neural network scaling has been critical for improving the model quality in many real-world machine learning applications with vast amounts of training data and compute. Although this trend of scaling is affirmed to be a surefire approach for better model quality, there are challenges on the path such as the computation cost, ease of programming, and efficient implementation on parallel devices. GShard is a module composed of a set of lightweight annotation APIs and an extension to the XLA compiler. It provides an elegant way to express a wide range of parallel computation patterns with minimal changes to the existing model code. GShard enabled us to scale up a multilingual neural machine translation Transformer model with Sparsely-Gated Mixture-of-Experts beyond 600 billion parameters using automatic sharding. We demonstrate that such a giant model can be efficiently trained on 2048 TPU v3 accelerators in 4 days to achieve far superior quality for translation from 100 languages to English compared to the prior art.
Scaling neural networks brings dramatic quality gains over a wide array of machine learning problems [arora2018optimization, frankle2018lottery, kaplan2020scaling, devlin2018bert, mahajan2018exploring, gpt32020]. For computer vision, increasing the model capacity has led to better image classification and detection accuracy for various computer vision architectures [he2016deep, he2016identity, ghiasi2019fpn]. Similarly, in natural language processing, scaling Transformers [vaswani2017attention] yielded consistent gains on language understanding tasks [devlin2018bert, raffel2019exploring, brown2020language], cross-lingual downstream transfer [devlin2018bert, conneau2019unsupervised] and (massively) multilingual neural machine translation [arivazhagan2019massively, gpipe19, shazeer2017outrageously]. This general tendency motivated recent studies to scrutinize the factors playing a critical role in the success of scaling [advani2017highdimensional, hestness2017deep, Hestness_2019, Geiger_2020, kaplan2020scaling], including the amount of training data, the model size, and the computation being utilized. While the final model quality was found to have a power-law relationship with the amount of data, compute and model size [hestness2017deep, kaplan2020scaling], the significant quality gains brought by larger models also come with various practical challenges. Training efficiency, one of the most important of these, is often left out. We define it as the amount of compute and training time used to achieve a model quality superior to the best existing system.
Here we enumerate major practical challenges faced especially when training massive-scale models that are orders of magnitude larger than the memory capacity of a single accelerator (e.g., GPUs or TPUs).
There is a lack of support for efficient model parallelism algorithms under commonly used deep learning frameworks such as TensorFlow [abadi2016tensorflow] and PyTorch [pytorch2017]. Naive model parallelism with graph partitioning is supported, but it leads to severe underutilization due to the sequential dependency of the network and gradient-based optimization. In order to scale up existing models efficiently, users typically need to invest a lot of engineering work, for example, migrating the model code to special frameworks [shazeer2018mesh, gpipe19].
Straightforward scaling of the model size by increasing the depth or width [gpt32020, gpipe19] generally results in at least a linear increase of training step time. Model parallelism by splitting layer weights and computation across multiple devices generally becomes necessary, leading to network communication overhead and device underutilization. Device underutilization stems from imbalanced assignment and sequential dependencies of the underlying neural network. This superlinear relationship between the computation cost and the model size cannot be resolved by simply using more devices, making training massive models impractical.
A naive graph representation for a massive-scale model distributed across thousands of devices may become a bottleneck for both deep learning frameworks and their optimizing compilers. For example, adding $D$ times more layers with inter-op partitioning, or increasing model dimensions with intra-op partitioning across $D$ devices, may result in a graph with $O(D)$ nodes. Communication channels between devices could further increase the graph size by up to $O(D^2)$ (e.g., when partitioning gather or transpose). Such an increase in the graph size would result in an infeasible amount of graph building and compilation time for massive-scale models.
Partitioning a model to run on many devices efficiently is challenging, as it requires coordinating communications across devices. For graph-level partitioning, sophisticated algorithms [gpipe19, harlap2018pipedream] are needed to reduce the overhead introduced by the sequential dependencies between different partitions of graphs allocated on different devices. For operator-level parallelism, different partitioned operators have different communication patterns depending on their semantics, e.g., whether they need to accumulate partial results or to rearrange data shards. In our experience, manually handling these issues in the model requires a substantial amount of effort, given that frameworks like TensorFlow have a large set of operators with ad hoc semantics. In all cases, implementing model partitioning is a particular burden for practitioners, as changing the model architecture requires changing the underlying device communications, causing a ripple effect.
In this paper, we demonstrate how to overcome these challenges by building a 600 billion parameter sequence-to-sequence Transformer model with Sparsely-Gated Mixture-of-Experts layers, which enjoys sublinear computation cost and constant compilation time. We trained this model with 2048 TPU v3 devices for 4 days on a multilingual machine translation task and achieved far superior translation quality compared to prior art when translating 100 languages to English with a single non-ensemble model. We conducted experiments with various model sizes and found that the translation quality increases as the model gets bigger, yet the total wall-clock time to train only increases sublinearly with respect to the model size, as illustrated in Figure 1. To build such an extremely large model, we made the following key design choices.
First, the model architecture should be designed to keep the computation and communication requirements sublinear in the model capacity. Conditional computation [bengio2015conditional, shazeer2017outrageously, Elbayad2020DepthAdaptiveT, bapna2020controlling] enables us to satisfy training and inference efficiency by activating a sub-network on a per-input basis. Scaling the capacity of RNN-based machine translation and language models by adding Position-wise Sparsely-Gated Mixture-of-Experts (MoE) layers [shazeer2017outrageously] made it possible to achieve state-of-the-art results with sublinear computation cost. We therefore present our approach to extend the Transformer architecture with MoE layers in Section 2.
Second, the model description should be separated from the partitioning implementation and optimization. This separation of concerns lets model developers focus on the network architecture and flexibly change the partitioning strategy, while the underlying system applies semantics-preserving transformations and implements efficient parallel execution. To this end we propose a module, GShard, which only requires the user to annotate a few critical tensors in the model with partitioning policies. It consists of a set of simple APIs for annotations, and a compiler extension in XLA [xla] for automatic parallelization. Model developers write models as if there is a single device with huge memory and computation capacity, and the compiler automatically partitions the computation for the target based on the annotations and its own heuristics. We provide more annotation examples in Section 3.2.
Third, the system infrastructure, including the computation representation and compilation, must scale with thousands of devices for parallel execution. For example, Figure 2 illustrates two different ways of partitioning a dot-product operation across 4 devices (color-coded). Notice that with the usual MPMD (Multiple Program Multiple Data) approach in Figure 2(a), scaling becomes more challenging since the number of nodes in the graph increases linearly with the number of devices. Instead, we developed a compiler technique for SPMD (Single Program Multiple Data) transformation that generates a single program to run on all devices, keeping the compilation time constant independent of the number of devices, as illustrated in Figure 2(b). We discuss our SPMD framework in more detail in Section 3.3.
The rest of the paper is organized as follows. Section 2 describes our Transformer architecture with the Sparsely-Gated MoE layer in more detail. Section 3 introduces our development module GShard. Section 4 demonstrates the application of our mixture-of-experts models on the multilingual machine translation task over 100 language pairs. Section 5 presents performance and memory measurements of our implementation. Section 6 discusses related work.
The Transformer [vaswani2017attention] architecture has been widely used for natural language processing. It has become the de facto standard for many sequence-to-sequence tasks, such as machine translation. The Transformer uses two computational blocks, an encoder and a decoder, both implemented by stacking multiple Transformer layers. A Transformer encoder layer consists of two consecutive layers, namely a self-attention layer followed by a position-wise feed-forward layer. The decoder adds a third, cross-attention layer, which attends over the encoder output. We sparsely scale the Transformer with conditional computation by replacing every other feed-forward layer with a Position-wise Mixture of Experts (MoE) layer [shazeer2017outrageously] with a variant of top-2 gating in both the encoder and the decoder (Figure 3). We vary the number of Transformer layers and the number of experts per MoE layer in order to scale the model capacity.
Each training example consists of a pair of sequences of subword tokens. Each token activates a sub-network of the MoE Transformer during both training and inference. The size of the sub-network is roughly independent of the number of experts per MoE layer, allowing sublinear scaling of the computation cost as described in the previous section. Computational complexity is further analyzed in Section 3.1 and training performance in Section 5.
The Mixture-of-Experts (MoE) layer used in our model is based on [shazeer2017outrageously] with variations in the sparse gating function and the auxiliary loss being used. An MoE layer for the Transformer consists of $E$ feed-forward networks $\mathrm{FFN}_1, \ldots, \mathrm{FFN}_E$:
$$\mathcal{G}_{s,E} = \mathrm{GATE}(x_s) \tag{1}$$
$$\mathrm{FFN}_e(x_s) = wo_e \cdot \mathrm{ReLU}(wi_e \cdot x_s) \tag{2}$$
$$y_s = \sum_{e=1}^{E} \mathcal{G}_{s,e} \cdot \mathrm{FFN}_e(x_s) \tag{3}$$
where $x_s$ is the input token to the MoE layer, and $wi_e$ and $wo_e$ are the input and output projection matrices for the feed-forward layer (an expert). The vector $\mathcal{G}_{s,E}$ is computed by a gating network. $\mathcal{G}_{s,E}$ has one non-negative entry for each expert, most of which are zeros, meaning the token is not dispatched to that expert. The token is dispatched to a very small number of experts; we choose to dispatch each token to at most two experts. The corresponding entries in $\mathcal{G}_{s,E}$ are non-zero, representing how much an expert contributes to the final network output. Every expert $\mathrm{FFN}_e$ applies a fully-connected two-layer network using the ReLU [Nair2010RectifiedLU] activation function to $x_s$. The output of the MoE layer, $y_s$, is the weighted average of the outputs of all the selected experts.
The gating function $\mathrm{GATE}(\cdot)$ is critical to the MoE layer. It is modeled by a softmax activation function that indicates the weight of each expert in processing incoming tokens, in other words, how good an expert is at processing the incoming token. Furthermore, the gating function must satisfy two goals:
Balanced load. It is desirable that the MoE layer sparsely activate the experts for a given token. A naive solution would be just to choose the top-$k$ experts according to the softmax probability distribution. However, it is known that this approach leads to a load imbalance problem for training [shazeer2017outrageously]: most tokens seen during training would be dispatched to a small number of experts, amassing very large input buffers for only a few (busy) experts while many other experts do not get sufficiently trained at all, which slows down training. A better design of the gating function would distribute the processing burden more evenly across all experts.
Efficiency at scale. It would be rather trivial to achieve a balanced load if the gating function were computed sequentially. The computation cost of the gating function alone is at least $O(NE)$ for all $N$ tokens in the input batch given $E$ experts. However, in our study $N$ is on the order of millions and $E$ on the order of thousands, so a sequential implementation of the gating function would keep most of the computational resources idle most of the time. Therefore, we need an efficient parallel implementation of the gating function to leverage many devices.
We designed the following mechanisms in the gating function to meet the above requirements (details illustrated in Algorithm 1):
Expert capacity. To ensure the load is balanced, we enforce that the number of tokens processed by one expert is below some uniform threshold, which we define as the expert capacity. Assuming that the total number of tokens in a training batch is $N$, and each token is dispatched to at most two experts, the expert capacity is set to be $O(N/E)$. $\mathrm{GATE}(\cdot)$ keeps a running counter $c_e$ of how many tokens are dispatched to expert $e$. When both experts selected by a token already exceed their capacity, the token is considered an overflowed token, and $\mathcal{G}_{s,E}$ degenerates into a zero vector. Such tokens have their representation passed on to the next layer via residual connections.
Local group dispatching. $\mathrm{GATE}(\cdot)$ partitions all tokens in a training batch evenly into $G$ groups, i.e., each group contains $S = N/G$ tokens. All groups are processed independently in parallel. Each group is given a fractional capacity of each expert, $2N/(G \cdot E)$, and ensures that at most this many tokens are dispatched to an expert. In this way, the expert capacity is still enforced and the overall load is balanced.
Auxiliary loss. It is important that the gating function does not always choose the same few experts, as this would lead to a capacity overflow for only a few experts and underutilization for the remaining ones. Following [shazeer2017outrageously], we define an auxiliary loss term $\ell_{aux}$ to enforce this constraint. It is added to the overall loss function of the model, $\mathcal{L} = \ell_{nll} + k \cdot \ell_{aux}$, with a constant multiplier $k$. The particular form of the auxiliary loss term in line (13) of Algorithm 1 is motivated by the following consideration: the term $c_e/S$ represents the fraction of input routed to each expert, and we want to minimize the mean square of $c_e/S$. But because $c_e$ is derived from the top-2 operation and is not differentiable, we use the mean gate per expert $m_e$ as a differentiable approximation and replace $(c_e/S)^2$ with $m_e (c_e/S)$, which can now be optimized with gradient descent.
Random routing. Intuitively, because $y_s$ is a weighted average of what the selected experts return, if the weight for the 2nd expert is very small, we can simply ignore the 2nd expert to conserve the overall expert capacity. Hence, in addition to respecting the expert capacity constraint, $\mathrm{GATE}(\cdot)$ dispatches to the 2nd-best expert with probability proportional to its weight $g_2$.
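The four mechanisms above can be sketched for a single group in NumPy. This is a minimal, sequential sketch loosely following the description of Algorithm 1, not the vectorized TPU implementation. It rests on several assumptions. The name `top2_gating_group` is hypothetical. The auxiliary loss is computed from first-choice counts only, before the second pass. Random routing keeps the 2nd-best expert when `2 * g2 > uniform(0, 1)`, which is one way to make the probability proportional to $g_2$, since the renormalized $g_2 \le 0.5$.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)         # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def top2_gating_group(x, wg, capacity, rng):
    """Top-2 gating for one group of S tokens (a sketch, not the paper's code).

    x: [S, M] tokens, wg: [M, E] gating weights, capacity: per-group expert capacity.
    Returns combine weights [S, E, capacity] and the auxiliary loss.
    """
    S, E = x.shape[0], wg.shape[1]
    gates = softmax(x @ wg)                       # g_{s,e}, shape [S, E]
    mean_gates = gates.mean(axis=0)               # m_e, shape [E]
    top2 = np.argsort(-gates, axis=1)[:, :2]      # best and 2nd-best expert per token
    e1, e2 = top2[:, 0], top2[:, 1]
    g1 = gates[np.arange(S), e1]
    g2 = gates[np.arange(S), e2]
    g1, g2 = g1 / (g1 + g2), g2 / (g1 + g2)       # renormalize the two selected gates

    combine = np.zeros((S, E, capacity))
    counts = np.zeros(E, dtype=int)               # running counters c_e
    for s in range(S):                            # pass 1: best expert of each token
        pos = counts[e1[s]]                       # position in the expert's buffer
        if pos < capacity:                        # otherwise the token overflows
            combine[s, e1[s], pos] = g1[s]
        counts[e1[s]] += 1
    aux_loss = np.mean(counts / S * mean_gates)   # (1/E) * sum_e (c_e / S) * m_e
    for s in range(S):                            # pass 2: random routing to 2nd-best
        pos = counts[e2[s]]
        if pos < capacity and 2 * g2[s] > rng.uniform():
            combine[s, e2[s], pos] = g2[s]
            counts[e2[s]] += 1
    return combine, aux_loss


N, E, G = 64, 4, 4                    # toy sizes: tokens per batch, experts, groups
S = N // G                            # tokens per group, S = N / G
capacity = 2 * N // (G * E)           # fractional per-group capacity 2N / (G * E)
rng = np.random.default_rng(0)
wg = rng.normal(size=(8, E))
tokens = rng.normal(size=(G, S, 8))
results = [top2_gating_group(tokens[g], wg, capacity, rng) for g in range(G)]
```

Each group runs independently, which is what makes the local group dispatching parallel. Summed over all $G$ groups, the per-group capacity gives $2N/E$ slots per expert.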
This section describes the implementation of the model in Section 2 that runs efficiently on a cluster of TPU devices.
The first step is to express the model in terms of linear algebra operations, for which our software stack (TensorFlow [abadi2016tensorflow]) and the hardware platform (TPU) are highly tailored and optimized. It is straightforward to code up most of the model in terms of linear algebra in the same way as the original Transformer. However, it requires some effort to express the MoE layer, in particular the $\mathrm{GATE}(\cdot)$ function presented in Algorithm 1, due to its sequential nature; we describe the details in Section 3.1.
Next, we annotate the linear algebra computation to express parallelism. Each tensor in the computation can be annotated for replication or distribution across a cluster of devices using the sharding APIs in Section 3.2. Using sharding annotations enables separation of concerns between the model description and the efficient parallel implementation, and allows users to flexibly express diverse parallelization strategies. For example, (1) the attention layer is parallelized by splitting along the batch dimension and replicating its weights to all devices. On the other hand, (2) experts in the MoE layer cannot feasibly be replicated on all devices due to their sheer size, and the only viable strategy is to shard the experts across many devices. Furthermore, the whole model alternates between these two modes (1) and (2). Using annotations frees model developers from system optimization efforts and avoids baking the parallel implementation and low-level details into the model code.
Finally, the compiler infrastructure takes a (partially) annotated linear algebra computation and produces an efficient parallel program that scales to thousands of devices. As will be described in Section 3.3, the compiler applies the SPMD (Single Program Multiple Data) partitioning transformation to express per-device computation, inserts necessary cross-device communication, handles irregular patterns such as uneven partitions, and finally generates a single program to be launched on all devices for parallel execution.
Our model implementation (Algorithm 2) views the whole accelerator cluster as a single device and expresses its core mathematical algorithm in a few tensor operations independent of the concrete setup of the cluster. Einstein summation notation [einstein1923grundlage] (i.e., tf.einsum) is a powerful construct to concisely express the model, and we use it extensively in our implementation. The softmax gates computation is trivially expressed by one einsum followed by the softmax function. Dispatching of inputs to selected experts is expressed by a single einsum between the dispatching mask and the input. All $\mathrm{FFN}_e$ weights are combined into single 3D tensors wi and wo, and the computation by $\mathrm{FFN}_1, \ldots, \mathrm{FFN}_E$ is expressed using 3 operators (two einsums and one relu). Finally, taking the weighted average of all experts' outputs into the final output is expressed in another einsum.
Top2Gating in Algorithm 2 computes the union of all group-local $\mathcal{G}_{S,E}$ described in Algorithm 1. combine_weights is a 4D tensor with shape [G, S, E, C]. The value combine_weights[g, s, e, c] is non-zero when the input token $s$ in group $g$ is sent to the input buffer of expert $e$ at buffer position $c$. For a specific g and s, the slice combine_weights[g, s, :, :] contains at most two non-zero values. The binary dispatch_mask is produced from combine_weights by simply setting all non-zero values to 1.
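The einsum pipeline described above can be sketched in NumPy as follows. It assumes G groups, S tokens per group, E experts, capacity C, model dimension M and hidden dimension H. The einsum strings are our reading of those shapes, not a copy of Algorithm 2. To keep the sketch short, Top2Gating is replaced by a hypothetical top-1 router `top1_gating` with the same [G, S, E, C] output shape.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def top1_gating(gates, capacity):
    """Hypothetical stand-in for Top2Gating: route each token to its best expert only."""
    G, S, E = gates.shape
    expert = gates.argmax(axis=-1)                             # [G, S]
    onehot = np.eye(E)[expert]                                 # [G, S, E]
    # Buffer position = number of earlier tokens in the group routed to the same expert.
    position = ((np.cumsum(onehot, axis=1) - 1) * onehot).sum(axis=-1).astype(int)
    combine = np.zeros((G, S, E, capacity))
    g, s = np.nonzero(position < capacity)                     # overflowed tokens are dropped
    combine[g, s, expert[g, s], position[g, s]] = gates[g, s, expert[g, s]]
    return combine


def moe_layer(inputs, wg, wi, wo, capacity):
    """inputs [G, S, M], wg [M, E], wi [E, M, H], wo [E, H, M] -> outputs [G, S, M]."""
    gates = softmax(np.einsum("GSM,ME->GSE", inputs, wg))
    combine_weights = top1_gating(gates, capacity)             # [G, S, E, C]
    dispatch_mask = (combine_weights != 0).astype(inputs.dtype)
    expert_inputs = np.einsum("GSEC,GSM->EGCM", dispatch_mask, inputs)
    h = np.maximum(np.einsum("EGCM,EMH->EGCH", expert_inputs, wi), 0.0)   # ReLU
    expert_outputs = np.einsum("EGCH,EHM->GECM", h, wo)
    return np.einsum("GSEC,GECM->GSM", combine_weights, expert_outputs)
```

When the capacity is large enough that no token is dropped, each output row equals the gate value times $\mathrm{FFN}_e$ of the chosen expert. With a small capacity, overflowed tokens produce zero rows, and the residual connection would carry them.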
We need to choose the number of groups $G$ and the number of experts $E$ properly so that the algorithm can scale to a cluster with $D$ devices. It is worthwhile to analyze its overall computational complexity (the total number of floating point operations) for a training step given a training batch of $N$ tokens.
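We analyze how the computational complexity of Algorithm 2 scales with the number of devices $D$ under the following assumptions: a) the number of tokens per device $N/D = O(1)$ is constant (this is often necessary in practice to avoid overflowing device memory); b) $G = O(D)$, $S = O(1)$ and $N = O(GS) = O(D)$; c) the model dimension $M = O(1)$ and the feed-forward hidden dimension $H = O(1)$; d) $E = O(D)$; and e) $C = O(2S/E) = O(1/D)$, with $D < S$ and $C$ a positive integer (scaling to $D > S$ would require a different use of fractional expert capacity).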
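The total number of floating point operations $FLOPS$ in Algorithm 2 is
$$\begin{aligned} FLOPS &= FLOPS_{Softmax} + FLOPS_{Top2Gating} + FLOPS_{Dispatch|Combine} + FLOPS_{FFN} \\ &= O(GSME) + O(GSEC) + O(GSMEC) + O(EGCHM) \\ &= O(D \cdot 1 \cdot 1 \cdot D) + O(D \cdot 1 \cdot D \cdot \tfrac{1}{D}) + O(D \cdot 1 \cdot 1 \cdot D \cdot \tfrac{1}{D}) + O(D \cdot D \cdot \tfrac{1}{D} \cdot 1 \cdot 1) \\ &= O(D^2) + O(D) + O(D) + O(D) \end{aligned}$$
and consequently the per-device $FLOPS/D = O(D) + O(1) + O(1) + O(1)$. The per-device softmax complexity $FLOPS_{Softmax}/D = O(D)$ is linear in the number of devices, but in practice it is dominated by the other terms since $D \ll H$ and $D < S$. As a result, $FLOPS/D$ could be considered $O(1)$, satisfying the sublinear scaling design requirements. Section 5 verifies this analysis empirically.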
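The scaling argument can be checked numerically by plugging assumptions a) through e) into the four terms. This is a counting sketch up to constant factors. The sizes S = 1024, M = 1024 and H = 4096 are illustrative assumptions, not the paper's configuration, and `algorithm2_flops` is a hypothetical helper.

```python
def algorithm2_flops(D, S=1024, M=1024, H=4096):
    """Per-term operation counts (up to constant factors) under assumptions a)-e)."""
    G, E = D, D                       # G = O(D) groups and E = O(D) experts
    C = 2 * S // E                    # per-group expert capacity, O(1/D)
    return {
        "softmax": G * S * M * E,
        "top2_gating": G * S * E * C,
        "dispatch_combine": G * S * M * E * C,
        "ffn": E * G * C * H * M,
    }


def per_device(D, **sizes):
    return {name: flops / D for name, flops in algorithm2_flops(D, **sizes).items()}


for D in (64, 128, 256, 512):
    print(D, per_device(D))
```

Doubling D doubles only the per-device softmax term ($S \cdot M \cdot D$), while the other three stay constant. The softmax term stays small next to the FFN term ($2S \cdot H \cdot M$) as long as $D \ll H$.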
In addition to the computation cost, we have a non-constant cross-device communication cost, but it grows at a modest rate when we increase $D$ (Section 5).
Due to the daunting size and computation demand of tensors in Algorithm 1, we have to parallelize the algorithm over many devices. An immediate solution of how to shard each tensor in the algorithm is illustrated by underscored letters in Algorithm 2. The sharding API in GShard allows us to annotate tensors in the program to selectively specify how they should be partitioned. This information is propagated to the compiler so that the compiler can automatically apply transformations for parallel execution. We use the following APIs in TensorFlow/Lingvo [shen2019lingvo] in our work.
replicate(tensor) annotates tensor to be replicated across partitions, and returns the annotated tensor. This is often used for the non-MoE layers in our model to replicate the weights.
split(tensor, split_dimension, num_partitions) annotates tensor to be partitioned along split_dimension, and returns the annotated tensor. Partition $i$ is placed on the $i$'th device, and num_partitions must not exceed the number of devices on the system.
shard(tensor, device_assignment) generalizes split() to allow partitioning multiple dimensions and specifying the placement of each partition. Appendix A.3 describes this API with more details.
Note that invocations of split or shard only add annotations and do not change the logical shape in the user program. The user still works with full shapes and does not need to worry about issues like uneven partitioning.
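The paper's APIs live in TensorFlow/Lingvo. The toy Python stand-in below illustrates the point above: an annotation records a partitioning policy without changing the logical shape. `Annotated`, `device_view` and `_unwrap` are hypothetical names. `replicate` and `split` only mimic the signatures listed above. The zero padding for uneven sizes is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class Annotated:
    """Hypothetical stand-in: a full logical tensor plus a sharding annotation."""
    value: np.ndarray
    split_dimension: Optional[int] = None   # None means "replicated"
    num_partitions: int = 1


def _unwrap(tensor):
    return tensor.value if isinstance(tensor, Annotated) else tensor


def replicate(tensor):
    return Annotated(_unwrap(tensor))


def split(tensor, split_dimension, num_partitions):
    return Annotated(_unwrap(tensor), split_dimension, num_partitions)


def device_view(t, device):
    """The partition that `device` would hold; uneven sizes are zero-padded."""
    if t.split_dimension is None:
        return t.value                                     # replicated: full copy
    size = t.value.shape[t.split_dimension]
    part = -(-size // t.num_partitions)                    # ceiling division
    pad = [(0, 0)] * t.value.ndim
    pad[t.split_dimension] = (0, part * t.num_partitions - size)
    padded = np.pad(t.value, pad)                          # constant (zero) padding
    rows = np.arange(device * part, (device + 1) * part)
    return np.take(padded, rows, axis=t.split_dimension)


x = np.arange(15.0).reshape(5, 3)
xs = split(x, 0, 2)
print(xs.value.shape)        # (5, 3): the logical shape is unchanged
print(device_view(xs, 1))    # rows 3 and 4, plus one padded row of zeros
```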
GShard is general in the sense that the simple APIs apply to all dimensions in the same way. The sharded dimensions could include batch (data parallelism), feature, expert, and even spatial dimensions in image models, depending on the use case. Also, since the sharding annotation is per tensor, different parts of the model can be partitioned in different ways. This flexibility enables us to partition the giant MoE weights and switch partition modes between MoE and non-MoE layers, as well as to support use cases beyond this paper, e.g., spatial partitioning of large images [spatialpartitioning] (Appendix A.4).
With the above sharding APIs, we can express the sharding strategy shown in Algorithm 2 with a few annotations. The input tensor is split along the first dimension and the gating weight tensor is replicated. After computing the dispatched expert inputs, we apply split to change the sharding from the group ($G$) dimension to the expert ($E$) dimension. $D$ is the device count.
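The first two annotations imply that the gating einsum needs no communication. The softmax runs over the expert dimension of each token independently, so each device can compute the gates for its own groups using the replicated weights. The sketch below simulates D devices as list entries. `gates_full` and `gates_sharded` are hypothetical helpers, and `np.split` stands in for the effect of `split(inputs, 0, D)`.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def gates_full(inputs, wg):
    return softmax(np.einsum("GSM,ME->GSE", inputs, wg))


def gates_sharded(inputs, wg, D):
    """Simulate D devices: inputs split along G, wg replicated, no communication."""
    input_shards = np.split(inputs, D, axis=0)                  # effect of split(inputs, 0, D)
    return [gates_full(shard, wg) for shard in input_shards]    # every device uses the full wg


rng = np.random.default_rng(0)
D, G, S, M, E = 4, 8, 5, 6, 3
inputs, wg = rng.normal(size=(G, S, M)), rng.normal(size=(M, E))
local = gates_sharded(inputs, wg, D)
print([g.shape for g in local])     # each device holds G / D = 2 groups of gates
```

Concatenating the local results along G reproduces the unpartitioned gates. This is why a compiler can infer a G-split sharding for the einsum output from a G-split input and a replicated weight.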
As this example shows, users are not required to annotate every tensor in the program. Annotations are typically only required on a few important operators like the Einsums in our model, and the compiler uses its own heuristics to infer sharding for the rest of the tensors. (It is also important for the compiler to infer missing shardings, since the backpropagation computation is often automatically generated by the frontend framework and users don't have access to those tensors.) For example, since the input tensor is partitioned along $G$ and the weight tensor is replicated, the compiler chooses to partition the gating einsum output along the same dimension (Line 5). Similarly, since both inputs of the input dispatch einsum (Line 7) are partitioned along the $G$ dimension, the output sharding is inferred to be split along the $G$ dimension, and we then add the split annotation on the output to reshard along the $E$ dimension. Some annotations in this example could also be determined by the compiler (e.g., replicate(wg)), but it is recommended to annotate the initial input and final output tensors of the computation.
The compiler currently uses an iterative dataflow analysis to propagate sharding information from an operator to its neighbors (operands and users), starting from the user-annotated operators. The analysis tries to minimize the chance of resharding by aligning the sharding decisions of adjacent operators. There could be other approaches such as integer programming or machine-learning methods, but improving the automatic sharding assignment is not the focus of this paper and we leave it as future work.
Automatic partitioning with sharding annotations is often enough for common cases, but GShard also has the flexibility to allow mixing manually partitioned operators with auto-partitioned operators. This gives users more control over how operators are partitioned, for example when the user has runtime knowledge beyond the operators' semantics. For instance, neither XLA's nor TensorFlow's Gather operator definition conveys information about the index bounds for different ranges in the input, but the user might know that a specific Gather operator shuffles data only within each partition. In this case, the user can trivially partition the operator by simply shrinking the dimension size and performing a local Gather; otherwise, the compiler would need to be conservative about the index range and add unnecessary communication overhead. For example, the dispatching Einsum (Line 3) in Algorithm 2, which uses a one-hot matrix to dispatch inputs, can alternatively be implemented with a Gather operator using trivial manual partitioning, while the rest of the model is partitioned automatically.
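The equivalence between the one-hot dispatch einsum and a partition-local gather can be sketched in NumPy. NumPy fancy indexing stands in for XLA's Gather. The routing rule in `build_dispatch` is hypothetical and only produces a valid one-hot mask. Because each group only gathers its own tokens, a device holding G/D groups can gather locally with no communication.

```python
import numpy as np


def build_dispatch(G, S, E, C):
    """Hypothetical toy routing: token s of group g -> expert (s + g) % E, slot s // E."""
    mask = np.zeros((G, S, E, C))
    index = np.full((G, E, C), -1)            # token index per (group, expert, slot); -1 = empty
    for g in range(G):
        for s in range(S):
            e, c = (s + g) % E, s // E
            if c < C:                          # tokens beyond capacity are dropped
                mask[g, s, e, c] = 1.0
                index[g, e, c] = s
    return mask, index


def einsum_dispatch(mask, inputs):
    return np.einsum("GSEC,GSM->EGCM", mask, inputs)


def gather_dispatch(inputs, index):
    """Gather only within each group, so a G-sharded device can run it locally."""
    G, S, M = inputs.shape
    padded = np.concatenate([inputs, np.zeros((G, 1, M))], axis=1)  # row S: zeros for empty slots
    rows = np.where(index < 0, S, index)
    gathered = padded[np.arange(G)[:, None, None], rows]            # [G, E, C, M]
    return gathered.transpose(1, 0, 2, 3)                           # [E, G, C, M]


G, S, E, C, M, D = 4, 5, 3, 2, 5, 2
rng = np.random.default_rng(0)
inputs = rng.normal(size=(G, S, M))
mask, index = build_dispatch(G, S, E, C)
reference = einsum_dispatch(mask, inputs)
# Manual partitioning: each of D devices gathers from its own G / D groups only.
local = [gather_dispatch(x, i) for x, i in zip(np.split(inputs, D), np.split(index, D))]
manual = np.concatenate(local, axis=1)        # the device shards are split along G
```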
This section describes the compiler infrastructure that automatically partitions a computation graph based on sharding annotations. Sharding annotations inform the compiler about how each tensor should be distributed across devices. The SPMD (Single Program Multiple Data) partitioner (or “partitioner” for simplicity) is a compiler component that transforms a computation graph into a single program to be executed on all devices in parallel. This makes the compilation time near constant regardless of the number of partitions, which allows us to scale to thousands of partitions. (An alternative is MPMD (Multiple Program Multiple Data), which does not scale, as shown in Figure 2.)
We implemented the partitioner in the XLA compiler [xla]. Multiple frontend frameworks including TensorFlow, JAX, PyTorch and Julia already have lowering logic to transform their graph representation to an XLA HLO graph. XLA also has a much smaller set of operators compared to popular frontend frameworks like TensorFlow, which reduces the burden of implementing a partitioner without harming generality, because the existing lowering from frontends does the heavy lifting to make it expressive. Although we developed the infrastructure in XLA, the techniques we describe here can be applied to intermediate representations in other machine learning frameworks (e.g., ONNX [onnx], TVM Relay [roesch2018relay], Glow IR [rotem2018glow]).
XLA models a computation as a dataflow graph where nodes are operators and edges are tensors flowing between operators. The core of the partitioner is per-operation handling that transforms a full-sized operator into a partition-sized operator according to the sharding specified on the input and output. When a computation is partitioned, various patterns of cross-device data transfers are introduced. In order to maximize the performance at large scale, it is essential to define a core set of communication primitives and optimize those for the target platform.
Since the partitioner forces all the devices to run the same program, the communication patterns are also regular, and XLA defines a set of collective operators that perform MPI-style communications [mpi2.2]. We list the common communication primitives we use in the SPMD partitioner below.
CollectivePermute. This operator specifies a list of source-destination pairs, and the input data of a source is sent to the corresponding destination. It is used in two places: changing a sharded tensor's device order among partitions, and halo exchange as discussed later in this section.
AllGather. This operator concatenates tensors from all participants following a specified order. It is used to change a sharded tensor to a replicated tensor.
AllReduce. This operator performs element-wise reduction (e.g., summation) over the inputs from all participants. It is used to combine partially reduced intermediate tensors from different partitions. In a TPU device network, AllReduce has a constant cost when the number of partitions grows (Section 5.2). It is also a commonly used primitive with efficient implementations in other types of network topology [cho2019blueconnect].
AllToAll. This operator logically splits the input of each participant along one dimension, then sends each piece to a different participant. On receiving data pieces from others, each participant concatenates the pieces to produce its result. It is used to reshard a sharded tensor from one dimension to another dimension. AllToAll is an efficient way to perform such resharding in a TPU device network, where its cost increases sublinearly when the number of partitions grows (Section 5.2).
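The semantics of the four primitives can be simulated in-process, with a Python list standing in for the devices and the list index playing the role of the device ID. This is a sketch of the data movement only, not XLA's API or cost model. In this sketch, devices that receive nothing in a CollectivePermute get zeros.

```python
import numpy as np


def collective_permute(shards, pairs):
    """Send shards[src] to device dst for each (src, dst) pair; other devices get zeros."""
    out = [np.zeros_like(s) for s in shards]
    for src, dst in pairs:
        out[dst] = shards[src]
    return out


def all_gather(shards, axis):
    full = np.concatenate(shards, axis=axis)        # concatenate in device order
    return [full.copy() for _ in shards]            # every device ends up with the full tensor


def all_reduce(shards):
    total = sum(shards)                              # element-wise summation
    return [total.copy() for _ in shards]


def all_to_all(shards, split_axis, concat_axis):
    D = len(shards)
    pieces = [np.split(s, D, axis=split_axis) for s in shards]   # pieces[src][dst]
    return [np.concatenate([pieces[src][dst] for src in range(D)], axis=concat_axis)
            for dst in range(D)]


x = np.arange(32.0).reshape(4, 8)
rows = np.split(x, 2, axis=0)                        # sharded along dimension 0
cols = all_to_all(rows, split_axis=1, concat_axis=0) # resharded along dimension 1
```

The final line is the resharding pattern used for MoE dispatch. Each device starts with a slice along one dimension and ends with a slice along another, and no device ever holds the full tensor.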
The core of the partitioner is the per-operator transformation from a full-sized operator into a partition-sized operator according to the specified sharding. While some operators (e.g., element-wise) are trivial to support, we discuss several common cases where cross-partition communications are required.
There are a few important technical challenges in general cases, which we will cover in Section 3.3.3. To keep the discussion relevant to the MoE model, this section focuses on Einsum partitioning to illustrate a few communication patterns. To keep it simple for now, we assume that all tensors are evenly partitioned, which means the size of the dimension to partition is a multiple of the partition count.
Einsum is the most critical operator in implementing the MoE model. Einsums are represented as a Dot operation in XLA HLO, where each operand (LHS or RHS) consists of three types of dimensions:
Batch dimensions are the embarrassingly parallel dimensions. The same set of batch dimensions must exist in all of LHS, RHS and the output, and each element in the output only depends on the corresponding batch in LHS and RHS.
Contracting dimensions only exist in the operands. LHS and RHS must have the same set of contracting dimensions, and they are summed up and collapsed in the output.
Non-contracting dimensions are also parallel dimensions that exist in one of the operands and the output. Each of LHS and RHS has its own set of non-contracting dimensions, which are inherited by the output.
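The three dimension types can be read directly off an einsum spec string. The sketch below assumes a two-operand einsum in which every operand dimension appears either in the other operand or in the output. `classify_einsum_dims` is a hypothetical helper. The example specs use the G/S/E/C/M/H letters from the MoE layer.

```python
def classify_einsum_dims(spec):
    """Split a two-operand einsum spec like 'GSEC,GSM->EGCM' into the three dimension types."""
    operands, output = spec.replace(" ", "").split("->")
    lhs, rhs = operands.split(",")
    return {
        "batch": [d for d in lhs if d in rhs and d in output],            # in LHS, RHS and output
        "contracting": [d for d in lhs if d in rhs and d not in output],  # summed away
        "lhs_non_contracting": [d for d in lhs if d not in rhs],
        "rhs_non_contracting": [d for d in rhs if d not in lhs],
    }


for spec in ("GSEC,GSM->EGCM", "EGCM,EMH->EGCH"):
    print(spec, classify_einsum_dims(spec))
```

For the dispatch einsum, G is a batch dimension. Sharding both operands along G therefore keeps the einsum local, and only the subsequent switch to the E dimension needs communication.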
Sharding propagation prioritizes choosing the same sharding on batch dimensions of LHS, RHS and output, because that would avoid any cross-partition communication. However, that is not always possible, and we need cross-partition communication in the following three cases.
Resharding. In the MoE model we built, the expert dispatching logic (Line 3 in Algorithm 2) requires switching the partitioned dimension after an Einsum. Since resharding is efficient with AllToAll (Section 5.2), we first execute the Einsum locally, then reshard it to the desired dimension, as shown in Figure (a).
Accumulating partial results. If the inputs are partitioned along contracting dimensions, the local result is partial and we need to use an AllReduce to combine the partial results and produce the final result, as shown in Figure (b).
Slicing in a loop. For certain scenarios, we also implemented an algorithm similar to Cannon's algorithm [cannon1969], in order to limit the size of tensors on each partition. For example, if both operands are partitioned on a non-contracting dimension, we cannot compute the local Einsum directly since the operands have different non-contracting dimensions. Replicating one of the operands would not cause redundant computation, but it requires the replicated operand to fit in device memory. Therefore, if the size of the operand is too large, we instead keep both operands partitioned, use a loop to iterate over each slice of the result, and use CollectivePermute to communicate the input slices (Figure (c)).
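The second and third cases can be sketched with plain matrix products. For case 2, local products over a split contracting dimension are summed, which is the AllReduce step. Case 3 is simplified here to a 1-D ring: it is in the spirit of, but not identical to, Cannon's 2-D algorithm. Each device keeps its row block of A, computes one output block per step, and then passes its current B block to its neighbor, which is the CollectivePermute step. The function names are hypothetical.

```python
import numpy as np


def contracting_split_matmul(a, b, D):
    """Case 2: split the contracting dim K; local matmuls are partial sums."""
    a_shards = np.split(a, D, axis=1)           # [M, K/D] per device
    b_shards = np.split(b, D, axis=0)           # [K/D, N] per device
    partials = [x @ y for x, y in zip(a_shards, b_shards)]
    return sum(partials)                        # the AllReduce (summation) step


def ring_matmul(a, b, D):
    """Case 3: A split on M and B split on N (both non-contracting); loop and rotate B."""
    a_shards = np.split(a, D, axis=0)           # device i holds row block i of A
    b_cur = np.split(b, D, axis=1)              # device i starts with column block i of B
    blocks = [[None] * D for _ in range(D)]
    for step in range(D):
        for i in range(D):
            j = (i + step) % D                  # index of the B block device i holds now
            blocks[i][j] = a_shards[i] @ b_cur[i]
        b_cur = [b_cur[(i + 1) % D] for i in range(D)]   # CollectivePermute: i + 1 -> i
    return [np.concatenate(row, axis=1) for row in blocks]  # result stays split along M
```

In the ring variant, each device holds only one B block at a time. This bounds per-device memory, at the cost of D communication steps.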
We solved several additional challenges to enable the SPMD partitioner to support a complete set of operators without extra constraints on tensor shapes or operator configurations. These challenges often involve asymmetric compute or communication patterns between partitions, which are particularly hard to express in SPMD, since the single program needs to be general enough for all partitions. We cannot simply create many branches in the single program based on the runtime device ID, because that would lead to an explosion in program size.
XLA requires tensor shapes to be static. (The limited dynamism in the intermediate representation is often necessary to efficiently target accelerators.)
However, when a computation is partitioned, not all partitions necessarily have the same input/output shapes, because dimensions may not be evenly divisible by the number of partitions. In those cases, the dimension is rounded up to the next multiple of the partition count, so that every partition has the same padded shape, and the data in the padded region can be arbitrary.
When computing an operator, we may need to fill the padded region with a known value for correctness. For example, if we need to partition a ReduceAdd operator, the padded region must hold the identity value, zero. Consider an example where the partitioned dimension (15) cannot be evenly divided by 2 (the partition count), so each partition holds 8 columns and Partition 1 has one more column than needed. We create an Iota operator of range [0, 8), add the partition offset (calculated as PartitionId × 8), and compare the result with the full dimension size (15). Based on the predicate value, we select either from the operand or from zero, and the result is the masked operand.
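The masking step above can be sketched in plain Python for the 15-element, 2-partition example. The helper name `mask_padded_shard` is hypothetical; `range` plays the role of Iota and the conditional expression plays the role of Select.

```python
import math
from typing import List


def mask_padded_shard(local: List[int], partition_id: int, full_size: int,
                      identity: int = 0) -> List[int]:
    """Replace the padded tail of a shard with the reduction's identity value.

    `local` has the padded per-partition size ceil(full_size / D); anything it
    holds beyond the logical end of the dimension is treated as garbage.
    """
    shard_size = len(local)
    offset = partition_id * shard_size                 # PartitionId x shard size
    iota = range(shard_size)                           # Iota over [0, shard_size)
    return [x if offset + i < full_size else identity  # predicate, then select
            for i, x in zip(iota, local)]


data = list(range(1, 16))                        # 15 elements, 2 partitions
shard = math.ceil(len(data) / 2)                 # 8 elements per partition
shards = [data[:shard], data[shard:] + [999]]    # 999 = arbitrary padding
masked = [mask_padded_shard(s, p, len(data)) for p, s in enumerate(shards)]
print(sum(sum(m) for m in masked) == sum(data))  # True: ReduceAdd is correct
```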
XLA operators have static configurations, like the padding, stride, and dilation defined in Convolution. However, different partitions may not execute with the same operator configuration. E.g., for a Convolution, the leftmost partition applies padding to its left while the rightmost partition applies padding to its right. In such cases, the partitioner may choose configurations that make some partitions produce slightly more data than needed, then slice out the irrelevant parts. Appendix A.4 discusses examples for Convolution and similar operators.
Certain operators have a communication pattern which involves partial data exchange with neighboring partitions, which we call halo exchange. We use the CollectivePermute operator to exchange halo data between partitions.
The most typical use case of halo exchange is partitioning window-based operators (e.g., Convolution, ReduceWindow), because neighboring partitions may require overlapping input data (Figure (a)). In practice, halo exchange for these operators often needs to be coupled with proper padding, slicing, and masking due to advanced use of window configurations (dilation, stride, and padding), as well as uneven halo sizes. We describe various scenarios in Appendix A.4.
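A minimal sketch of halo exchange for a 1D window sum, assuming even partitioning, stride 1, no dilation, and zero padding at the edges (the simple case; the advanced configurations mentioned above are not covered). The two list comprehensions model two CollectivePermutes, one shifting right and one shifting left; the helper names are hypothetical.

```python
from typing import List


def halo_exchange(shards: List[List[int]], halo: int = 1, pad: int = 0) -> List[List[int]]:
    """Extend each partition with `halo` elements from each neighbor.

    Edge partitions receive the operator's padding value instead of data.
    Assumes halo <= shard size.
    """
    d = len(shards)
    # shift right: partition p receives the tail of partition p - 1
    from_left = [shards[p - 1][-halo:] if p > 0 else [pad] * halo for p in range(d)]
    # shift left: partition p receives the head of partition p + 1
    from_right = [shards[p + 1][:halo] if p < d - 1 else [pad] * halo for p in range(d)]
    return [left + mid + right for left, mid, right in zip(from_left, shards, from_right)]


def window_sum(xs: List[int], width: int = 3) -> List[int]:
    """ReduceWindow-style sum over windows of `width`, stride 1, no implicit padding."""
    return [sum(xs[i:i + width]) for i in range(len(xs) - width + 1)]


data = list(range(1, 13))
reference = window_sum([0] + data + [0])           # unpartitioned, zero-padded
shards = [data[i:i + 4] for i in range(0, 12, 4)]  # 3 even partitions
partitioned = [window_sum(s) for s in halo_exchange(shards)]
print(sum(partitioned, []) == reference)           # True
```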
Another use of halo exchange is for data formatting operators that change the size of the shape. For example, after a Slice or Pad operator, the shape of the tensor changes, and so do the boundaries between partitions. This requires us to realign the data on different partitions, which can be handled as a form of halo exchange (Figure (b)).
Other data formatting operators, although logically not changing the size of the shape, may also need halo exchange, specifically due to the static shape constraint and uneven partitioning. For example, the Reverse operator reverses the order of elements in a tensor, but if it is partitioned unevenly, we need to shift data across partitions to keep the padding logically to the right of the result tensor. Another example is Reshape. Consider reshaping a tensor from [3, 2] to [6], where the input is unevenly partitioned in 2 ways on the first dimension (partition shape [2, 2]), and the output is also partitioned in 2 ways (partition shape [3]). There is padding on the input due to uneven partitioning, but after Reshape, the output tensor no longer has padding; as a result, halo exchange is required in a similar way to Slice (Figure (c)).
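To see why the Reshape example needs data movement, the sketch below (hypothetical helpers, row-major flattening assumed) computes which partition holds each element before and after the Reshape. It only identifies the elements that must cross partitions; it does not perform the exchange.

```python
import math
from typing import List, Tuple


def shard_of(index: int, size: int, parts: int) -> int:
    """Partition that holds `index` when a dimension of length `size` is split
    into `parts` shards padded to ceil(size / parts) elements each."""
    return index // math.ceil(size / parts)


def reshape_moves(in_shape: Tuple[int, int], parts: int) -> List[Tuple[int, int, int]]:
    """For a [rows, cols] -> [rows * cols] Reshape, list (element, source
    partition, destination partition) for every element that must cross
    partitions. The input is partitioned on dim 0, the output on its only dim."""
    rows, cols = in_shape
    total = rows * cols
    moves = []
    for flat in range(total):
        src = shard_of(flat // cols, rows, parts)  # input row of this element
        dst = shard_of(flat, total, parts)
        if src != dst:
            moves.append((flat, src, dst))
    return moves


print(reshape_moves((3, 2), 2))  # [(3, 0, 1)]: element 3 must move to partition 1
print(reshape_moves((4, 2), 2))  # []: even partitioning needs no exchange
```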
The SPMD partitioner creates various data formatting operators in order to perform slicing, padding, concatenation, masking and halo exchange, and these operators can add runtime overhead. To address this issue, we leverage XLA’s fusion capabilities on TPU, as well as code motion optimizations for slicing and padding, to largely hide the overhead of data formatting. As a result, the runtime overhead is typically negligible, even for convolutional networks where masking and padding are heavily used.
We chose multilingual neural machine translation (MT) [Firat_2016, Johnson_2017, DBLP:journals/corr/abs190300089] to validate our design for efficient training with GShard. Multilingual MT, which is an inherently multi-task learning problem, aims at building a single neural network that translates multiple language pairs simultaneously. This extends our line of work [gpipe19, arivazhagan2019massively, shazeer2017outrageously] towards a universal machine translation model [translate2019m4], i.e. a single model that can translate between more than a hundred languages, in all domains. Such massively multilingual translation models are not only convenient for stress-testing models at scale, but have also been shown to be practically impactful in real-world production systems [translate2020quality].
In massively multilingual MT, there are two criteria that define success in terms of model quality: 1) improvements attained on languages that have large amounts of training data (high-resource), and 2) improvements for languages with limited data (low-resource). As the number of language pairs (tasks) to be modeled within a single translation model increases, positive language transfer [baldwin1988transfer] starts to deliver large gains for low-resource languages. Given the number of languages considered, M4 has a clear advantage in improving the low-resource tasks. On the contrary, for high-resource languages the increased number of tasks limits per-task capacity within the model, resulting in lower translation quality compared to models trained on a single language pair. This capacity bottleneck for high-resource languages can be relaxed by increasing the model size to massive scale in order to satisfy the need for additional capacity [arivazhagan2019massively, gpipe19].
Massively multilingual, massive MT consequently aims at striking a balance between increasing positive transfer through massive multilinguality and mitigating the capacity bottleneck through massive scaling. While doing so, scaling the model size and the number of languages considered has to be coupled with a suitable neural network architecture. In order to amplify the positive transfer and reduce the negative transfer (negative transfer is the notion of sharing the model capacity among unrelated tasks, which in turn hurts the quality of such interfering tasks), one can naturally design a model architecture that harbours shared components across languages (shared subnetworks), along with some language-specific ones (unshared, language-specific subnetworks). However, the search space in model design (deciding what to share) grows rapidly as the number of languages increases, making heuristic-based search for a suitable architecture impractical. Hence, approaches that learn the wiring pattern of the neural network from the data emerge as a scalable and practical way forward.
In this section, we argue that conditional computation [bengio2013estimating, davis2013lowrank] with sparsely gated mixture of experts [shazeer2017outrageously] fits the desiderata detailed above, and show its efficacy by scaling neural machine translation models beyond 1 trillion parameters while keeping the training time of such massive networks practical. For example, a 600B GShard model for M4 can process 1T tokens (source-side tokens after subword segmentation) in 250k training steps in under 4 days. We experiment with increasing the model capacity by adding more and more experts into the model, and study the factors playing a role in convergence, model quality and training efficiency. Further, we demonstrate how conditional computation can speed up training [bengio2015conditional] and how sparsely gating/routing each token through the network can be learned efficiently without any prior knowledge of task or language relatedness, exemplifying the capability of learning the routing decision directly from the data.
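A quick arithmetic check of these figures, using only the numbers stated above: 1T tokens over 250k steps implies 4M tokens per step, and finishing within 4 days requires at least about 0.72 steps per second, which is consistent with the batch size and step rate listed for model (1) in Table 3.

```python
tokens = 1e12          # tokens processed (source side, after subword segmentation)
steps = 250_000        # training steps
days = 4               # upper bound on wall-clock time

tokens_per_step = tokens / steps                    # 4,000,000 tokens per batch
min_steps_per_sec = steps / (days * 24 * 60 * 60)   # about 0.72 steps per second
print(f"{tokens_per_step:,.0f} tokens/step, >= {min_steps_per_sec:.2f} steps/s")
```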
The premise that progressively larger models attain greater quality necessitates large amounts of training data to begin with [kaplan2020scaling]. Following prior work on dense scaling for multilingual machine translation [gpipe19, arivazhagan2019massively], we committed to the realistic test bed of MT in the wild and use a web-scale in-house dataset. The training corpus, mined from the web [10.5555/1873781.1873905], contains parallel documents for 100 languages, to and from English, adding up to a total of 25 billion training examples. A few characteristics of the training set are worth mentioning. Having been mined from the web, the joint corpus is considerably noisy while covering a diverse set of domains and languages. Such large coverage comes with a heavy imbalance between languages in terms of the number of examples per language pair. This imbalance follows a sharp power law, ranging from billions of examples for high-resource languages to tens of thousands of examples for low-resource ones. While these characteristics constitute a challenge for our study, they also make the overall attempt as realistic as possible. We refer the reader to [gpipe19, arivazhagan2019massively] for additional details of the dataset.
We focus on improving the translation quality (measured in terms of BLEU score [papineni2002bleu]) from all 100 languages to English. This resulted in approximately 13 billion training examples to be used for model training. (Compared to prior work using the same dataset, the Kazakh and Latin to English language pairs were excluded from evaluation.) In order to form our baselines, we trained a separate bilingual Neural Machine Translation model for each language pair (e.g. a single model for German-to-English), tuned depending on the available training data per language (we tuned the batch size and different values of regularization methods, e.g. dropout, in a Transformer-Big or Transformer-Base layout for high- or low-resource languages respectively). Rather than displaying individual BLEU scores for each language pair, we follow the convention of placing the baselines along the x-axis at zero, and report the trendline of BLEU differences relative to the baselines (ΔBLEU) for each massively multilingual model trained with GShard (see Figure 6). The x-axis in Figure 6 is sorted from left to right in decreasing order of the amount of available training data, so that high-resource languages are on the leftmost side and low-resource languages on the rightmost side. To reiterate, our ultimate goal in universal machine translation is to raise the ΔBLEU trendline of a single multilingual model above the baselines for all languages considered. We also include a dense 96-layer Transformer encoder-decoder network, T(96L), trained with GPipe pipeline parallelism on the same dataset as another baseline (dashed trendline in Figure 6).
Training T(96L) to convergence took over 6 weeks on 2048 TPU v3 cores (it processed over 1 trillion tokens in 300k steps, around 4M tokens per step, for a total budget of 235.5 TPU v3 core years). T(96L) outperforms the original GPipe T(128L) (64 encoder and 64 decoder layers, 16384 hidden dimension, 32 attention heads [gpipe19]) and is the strongest single dense model baseline we use in our comparisons.
Scaling the Transformer architecture has recently been an active exploratory research track [Bapna_2018, Irie_2019, Wang_2019]. Broadly, emerging approaches scale the Transformer by stacking more and more layers [Bapna_2018, gpipe19], by widening the governing dimensions of the network (i.e. model dimension, hidden dimension or number of attention heads) [devlin2018bert, raffel2019exploring], and more recently by learning the wiring structure with architecture search [so2019evolved] (since the approaches utilizing architecture search are compute-intensive, they are not considered within the scope of this work). For massively multilingual machine translation, [gpipe19] demonstrated the best practices of scaling using GPipe pipeline parallelism, in which a 128-layer Transformer model with 6 billion parameters is shown to be effective at improving high-resource languages while exhibiting the highest positive transfer towards low-resource languages. Although very promising, and satisfying our desiderata for universal translation, dense scaling of the Transformer architecture has practical limitations, which we referred to in Section 1 under training efficiency.
We aim for practical training time and seek architectures that warrant training efficiency. Our strategy has three pillars: increase the depth of the network by stacking more layers, similar to GPipe [gpipe19]; increase the width of the network by introducing multiple replicas of the feed-forward networks (experts), as described in Section 2.2; and make use of learned routing modules to (sparsely) assign tokens to experts, as described in Section 2.1. With these three constituents, we obtain an easy-to-scale, efficient-to-train and highly expressive architecture, which we call the Sparsely-Gated Mixture-of-Experts Transformer, or MoE Transformer for short.
To detail the model specifics, each expert is designed to have the same shape as a regular Transformer feed-forward network, and experts (MoE layers) are placed in every other Transformer layer. For simplicity, we tied the number of devices used for training to the number of experts per MoE layer, although this is not a requirement. During training, we use float32 for both model weights and activations in order to ensure training stability. We ran additional scalability experiments with MoE(2048E, 60L) using bfloat16 [bfloat16] activations and a total of 1 trillion model weights. Although the model was trainable with careful, manual diagnostics, we encountered several trainability issues related to numerical stability with this deep 1-trillion-parameter model, and hence did not include its results, for the sake of reproducibility. For more model and training details, please see Appendix A.2.
Id | Model | BLEU avg. | ΔBLEU avg. | Weights
(1) | MoE(2048E, 36L) | 44.3 | 13.5 | 600B
(2) | MoE(2048E, 12L) | 41.3 | 10.5 | 200B
(3) | MoE(512E, 36L) | 43.7 | 12.9 | 150B
(4) | MoE(512E, 12L) | 40.0 | 9.2 | 50B
(5) | MoE(128E, 36L) | 39.0 | 8.2 | 37B
(6) | MoE(128E, 12L) | 36.7 | 5.9 | 12.5B
* | T(96L) | 36.9 | 6.1 | 2.3B
* | Baselines | 30.8 | - | 100 × 0.4B
Before going into the details of training efficiency, we first investigate the effect of various design choices on building the MoE Transformer. In order to prune the search space, we varied two variables: the number of layers in the Transformer encoder-decoder stack (L) and the number of experts in each MoE layer, which replaces every other feed-forward layer (E). For depth, we tested three options: 12 (the original Transformer depth, which consists of 6 encoder and 6 decoder layers), 36 and 60 layers. For the number of experts, we also tested three options, namely 128, 512 and 2048 experts. Note that the number of devices used for training is fixed to be equal to the number of experts per layer, using 128, 512 and 2048 cores respectively, independent of the depth being tested. Please also see the detailed description in Table 1 for model configurations.
For each experiment (rows of Table 1), we trained the corresponding MoE Transformer model until it had seen 1 trillion ($10^{12}$) tokens. The model checkpoint at this point is used in the model evaluation. We did not observe any overfitting patterns by this point in any experiment; instead, we observed that the training loss continued to improve if we kept training longer. We evaluated the BLEU scores that the models achieved for all language pairs on a held-out test set. Figure 6 reports all our results.
Here we share a qualitative analysis of each experiment and discuss the implications of each setup for high- and low-resource languages in order to track our progress towards universal translation. To ground the forthcoming analysis, it is worth restating the expected behavior of the underlying quality gains. In order to improve the quality for both high- and low-resource languages simultaneously within a single model, scaled models must mitigate the capacity bottleneck by allocating enough capacity to high-resource tasks, while amplifying the positive transfer towards low-resource tasks by facilitating sufficient parameter sharing. We loosely relate the expected learning dynamics of such systems to the long-standing memorization and generalization dilemma, which has recently been studied along the lines of width vs. depth scaling efforts [Cheng_2016]. Not only do we expect our models to generalize better to the held-out test sets, we also expect them to exhibit high transfer capability across languages as another manifestation of generalization performance [lampinen2018analytic].
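Id | Model | Experts per layer | Experts total | TPU v3 cores | Enc+Dec layers | Weights
(1) | MoE(2048E, 36L) | 2048 | 36864 | 2048 | 36 | 600B
(2) | MoE(2048E, 12L) | 2048 | 12288 | 2048 | 12 | 200B
(3) | MoE(512E, 36L) | 512 | 9216 | 512 | 36 | 150B
(4) | MoE(512E, 12L) | 512 | 3072 | 512 | 12 | 50B
(5) | MoE(128E, 36L) | 128 | 2304 | 128 | 36 | 37B
(6) | MoE(128E, 12L) | 128 | 768 | 128 | 12 | 12.5B
* | MoE(2048E, 60L) | 2048 | 61440 | 2048 | 60 | 1T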
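The total expert count in this configuration table follows from the experts per layer and the depth, given that MoE layers are placed in every other Transformer layer. A small sketch (the helper name is hypothetical) reproduces it:

```python
def total_experts(experts_per_layer: int, layers: int) -> int:
    """Experts summed over all MoE layers, assuming an MoE layer in every
    other Transformer layer (so layers // 2 MoE layers in total)."""
    return experts_per_layer * (layers // 2)


configs = [(2048, 36), (2048, 12), (512, 36), (512, 12), (128, 36), (128, 12), (2048, 60)]
for experts, depth in configs:
    print(f"MoE({experts}E, {depth}L): {total_experts(experts, depth)} experts in total")
```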
We first investigate the relationship between model depth and model quality for both high- and low-resource languages. Three different experiments are conducted in order to test the generalization performance while keeping the number of experts per layer fixed. For each of the three per-layer expert counts (128, 512 and 2048), we tripled the depth of the network from 12 to 36 layers. This resulted in three groups in which the number of experts per layer is fixed and the depth is tripled within each group: (6) vs. (5), (4) vs. (3), and (2) vs. (1).
For each configuration shown in Fig. 6, we observed that increasing the depth (L) while keeping the experts per layer (E) fixed brings consistent gains for both low- and high-resource languages (an upward shift along the y-axis), with an almost constant additive factor every time we scale the depth from 12L to 36L (2-to-3 BLEU points on average, as shown in the last column of Table 3).
Earlier in Section 4.1 we highlighted the influence of the capacity bottleneck on task interference, resulting in degraded quality especially for high-resource languages. Later we alleviated this complication by increasing the number of experts per layer, which in turn resulted in a dramatic increase in the number of parameters (weights) of the models studied. Here we investigate whether this so-called capacity bottleneck is distinctly observable and explore the impact on model quality and efficiency once it is relaxed. To that end, we first consider three models with identical depth (12L) and an increasing number of experts per layer: 128, 512 and 2048. As we increase the number of experts per layer by a factor of four, from 128 to 512, we notice a large jump in model quality: +3.3 average BLEU across 100 languages. However, scaling the number of experts per layer by another factor of four, from 512 to 2048, yields only +1.3 average BLEU. Despite the significant quality improvement, this drop in gains hints at the emergence of diminishing returns.
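The gains quoted in the last two paragraphs can be recomputed directly from the average BLEU scores in the results table; the snippet below does so for both depths. At 36 layers the per-expert gains shrink even more sharply (4.7, then 0.6), consistent with the diminishing returns described above.

```python
# Average BLEU from the results table, keyed by (experts per layer, layers).
bleu = {
    (128, 12): 36.7, (512, 12): 40.0, (2048, 12): 41.3,
    (128, 36): 39.0, (512, 36): 43.7, (2048, 36): 44.3,
}

# Gain from tripling depth (12L -> 36L) at a fixed number of experts per layer.
depth_gain = {e: round(bleu[(e, 36)] - bleu[(e, 12)], 1) for e in (128, 512, 2048)}

# Gain from each 4x increase in experts per layer (128 -> 512 -> 2048) at fixed depth.
expert_gain = {
    depth: (round(bleu[(512, depth)] - bleu[(128, depth)], 1),
            round(bleu[(2048, depth)] - bleu[(512, depth)], 1))
    for depth in (12, 36)
}
print(depth_gain)   # {128: 2.3, 512: 3.7, 2048: 3.0}
print(expert_gain)  # {12: (3.3, 1.3), 36: (4.7, 0.6)}
```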
Speculatively, the capacity bottleneck is expected to reside between 128 and 512 experts, for the particular parametrization, number of languages and amount of training data used in our experimental setup. Once the bottleneck is relaxed, models enjoy successive scaling of the depth, which can be seen by comparing 12 versus 36 layer models both with 128 experts. Interestingly, increasing the depth does not help as much if the capacity bottleneck is not relaxed.
Another dimension that could shed light on the quality gains of scaling in multi-task models is the contrast between high- and low-resource language improvements. As mentioned before, low-resource languages benefit from transfer while high-resource languages seek added capacity. Next we examine the effect of increasing the experts per layer while fixing the depth.
As can be seen in Figure 6, for 12-layer models, an increase in the number of experts yields larger gains for high-resource languages, as opposed to the diminishing returns revealed earlier for low-resource languages. A similar pattern is also observed for 36-layer models. While adding more experts relaxes the capacity bottleneck, it at the same time reduces the amount of transfer due to a reduction of the shared subnetworks.
Lastly, we look into the impact of depth on low-resource tasks as a loose corollary to our previous experiment. To do so, we include in our analysis a dense 96-layer model, T(96L), trained with GPipe on the same data. We compare T(96L) with the shallow MoE(128E, 12L) model. While the gap between the two models is almost constant for the majority of the high-to-mid-resource languages, the gap grows in favor of the dense, deep T(96L) model as we get into the low-resource regime. Following our previous statement, as the proportion of subnetworks shared across tasks increases (100% for the dense T(96L)), the bandwidth for transfer is maximized, resulting in comparably better quality than its shallow counterpart. Also notice that the same transfer quality for the low-resource languages can be achieved with MoE(128E, 36L), which contains 37 billion parameters.
We conjecture that increasing the depth might increase the extent of transfer to low-resource tasks, and hence improve generalization along that axis. But we also want to highlight that the models in comparison have disproportionate training resource requirements. We again want to stress the importance of training efficiency, which is the very topic we study next.
In this section we focus on the training efficiency of MoE Transformer models. So far, we have seen empirical evidence of how scaling the models along various axes brings dramatic quality gains, and studied the factors affecting the extent of the improvements. In order to measure training efficiency, we first keep track of the number of tokens processed to reach a certain training loss, and second, the wall-clock time for a model to process a certain number of tokens. Note that we focus on training time and training loss (the training loss reported in this section corresponds to the cross-entropy loss and excludes the auxiliary loss term introduced in Section 2.2) while varying other factors, as opposed to the test error, which we analyzed in the previous section.
It has been shown that deeper models are better at sample efficiency, reaching better training/test error given the same amount of training examples [gpipe19, shoeybi2019megatron], which is commonly attributed to the acceleration effect of overparametrization [arora2018optimization]. We empirically test this hypothesis again using GShard with MoE Transformers, and share trade-offs for models that are not only deep but also sparsely activated.
For this purpose, we compare the number of tokens processed by each model to reach a preset training loss. A general trend we observe from Table 2 is that MoE Transformer models with 3 times the depth need 2 to 3 times fewer tokens to reach the preset training loss thresholds. For example, MoE(128E, 12L) takes 3 times the number of tokens to reach a training cross-entropy of 0.7 compared to MoE(128E, 36L), (6) vs. (5). We observe a similar trend for models with 512 and 2048 experts, (4) vs. (3) and (2) vs. (1).
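Id | Model | Cores | Tokens to reach training loss 0.7 | Tokens to reach training loss 0.6 | Tokens to reach training loss 0.5
(1) | MoE(2048E, 36L) | 2048 | 82 | 175 | 542
(2) | MoE(2048E, 12L) | 2048 | 176 | 484 | 1780
(3) | MoE(512E, 36L) | 512 | 66 | 170 | 567
(4) | MoE(512E, 12L) | 512 | 141 | 486 | -
(5) | MoE(128E, 36L) | 128 | 321 | 1074 | -
(6) | MoE(128E, 12L) | 128 | 995 | - | -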
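The depth and capacity comparisons discussed around Table 2 reduce to simple ratios of the token counts in its 0.7 column. The sketch below computes them; the dictionary keys are the model ids from the table.

```python
# Tokens processed to reach a training loss of 0.7 (values from Table 2), by model id.
tokens_to_07 = {1: 82, 2: 176, 3: 66, 4: 141, 5: 321, 6: 995}

# Shallow (12L) vs. deep (36L) model with the same experts per layer.
pairs = {"128E": (6, 5), "512E": (4, 3), "2048E": (2, 1)}
ratios = {}
for name, (shallow, deep) in pairs.items():
    ratios[name] = tokens_to_07[shallow] / tokens_to_07[deep]
    print(f"{name}: 12L needs {ratios[name]:.1f}x the tokens of 36L")

# Same depth (36L), 128 -> 512 experts per layer: the drop attributed to the
# capacity bottleneck.
print(f"(5) vs (3): {tokens_to_07[5] / tokens_to_07[3]:.1f}x fewer tokens")
```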
Another intriguing observation from Table 2 is again related to the presence of the capacity bottleneck. Comparing the models with the same depth, (5), (3) and (1), we notice a significant drop in the number of tokens required to reach a training loss of 0.7 as we transition from 128 to 512 experts. Practically, that is where we observed the capacity bottleneck to reside, aligning with the hypothesis in Section 4.4. After this phase shift, models with ample capacity tend to exhibit similar sample-efficiency characteristics, as in models (3) and (1).
Next we delve deeper into the interaction between model size and the wall-clock time spent on training. We monitor the number of TPU cores used, training steps per second, total number of tokens per batch, TPU core years (simply the product of the number of cores and the wall-clock time in years), and the actual wall-clock time spent on training in days (see the respective columns of Table 3).
We start by investigating one of the largest models we trained, MoE(2048E, 36L) with 600 billion parameters, the model with id (1). Having used 2048 TPU cores for 4 days, this model achieves the best translation quality in terms of average BLEU, but also takes a total of 22.4 TPU core years to train. While we have not seen any signs that the quality improvements plateau as we scale up our models, we strive to find cost-effective solutions for scaling.
Results in Table 3 again validate that scaling with conditional computation is far more practical than dense scaling. Given the same number of TPU cores used by (1), the dense scaling variant, T(96L), takes more than ten times as long to train (235.5 TPU core years), while trailing behind the models trained with GShard in terms of model quality.
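Id | Model | Cores | Steps per sec. | Batch size (tokens) | TPU core years | Training time (days) | BLEU avg.
(1) | MoE(2048E, 36L) | 2048 | 0.72 | 4M | 22.4 | 4.0 | 44.3
(2) | MoE(2048E, 12L) | 2048 | 2.15 | 4M | 7.5 | 1.4 | 41.3
(3) | MoE(512E, 36L) | 512 | 1.05 | 1M | 15.5 | 11.0 | 43.7
(4) | MoE(512E, 12L) | 512 | 3.28 | 1M | 4.9 | 3.5 | 40.0
(5) | MoE(128E, 36L) | 128 | 0.67 | 1M | 6.1 | 17.3 | 39.0
(6) | MoE(128E, 12L) | 128 | 2.16 | 1M | 1.9 | 5.4 | 36.7
* | T(96L) | 2048 | - | 4M | 235.5 | 42 | 36.9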
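The time columns of Table 3 can be cross-checked from the other columns. This sketch assumes every MoE run processed exactly 1T tokens (as stated for the quality experiments) and uses 365-day years. Its estimates match the reported days and TPU core years to within about 2%; small differences are expected from rounding in the reported step rates. The last line reproduces the roughly tenfold cost of T(96L).

```python
TOTAL_TOKENS = 1e12          # assumption: every run is trained on exactly 1T tokens
SECONDS_PER_DAY = 24 * 60 * 60
DAYS_PER_YEAR = 365          # assumption: plain 365-day years

# id: (cores, steps per second, tokens per batch), from Table 3
runs = {
    1: (2048, 0.72, 4e6), 2: (2048, 2.15, 4e6),
    3: (512, 1.05, 1e6), 4: (512, 3.28, 1e6),
    5: (128, 0.67, 1e6), 6: (128, 2.16, 1e6),
}


def training_cost(cores, steps_per_sec, tokens_per_batch):
    """Return (wall-clock days, TPU core years) implied by the step rate."""
    steps = TOTAL_TOKENS / tokens_per_batch
    days = steps / steps_per_sec / SECONDS_PER_DAY
    core_years = cores * days / DAYS_PER_YEAR  # cores x wall-clock time in years
    return days, core_years


for run_id, cfg in runs.items():
    days, core_years = training_cost(*cfg)
    print(f"({run_id}) {days:.1f} days, {core_years:.1f} TPU core years")

# T(96L): 2048 cores for 42 days
print(f"T(96L): {2048 * 42 / DAYS_PER_YEAR:.1f} TPU core years")
```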
In this section, we benchmarked GShard with MoE Transformer applications to multilingual machine translation (in particular to M4). We identified variables that affect the end result, such as the capacity bottleneck, positive transfer and training efficiency, and provided experimental results that reveal the interplay between them. Next we delve into performance-related topics of GShard, such as memory and runtime efficiency and communication benchmarks.
This section discusses how well GShard achieves computation and memory efficiency on the TPU platform. Our measurements and analysis show that the device memory consumption is roughly constant when we increase the number of devices and experts, and that the step time grows sublinearly, i.e., a 1.7x execution time increase when we scale the model by 16x, from 128 devices to 2048 devices. We also provide microbenchmarks and analyses for a variety of partitioned operators, which could guide use cases beyond this paper.
In the GShard model, there are mainly three types of memory usage, all of which have constant per-device sizes after SPMD partitioning when the number of experts increases.
Replicated weights (e.g. Transformer feed-forward layers).
Distributed weights (MoE feed-forward layers; the gate projection weights also grow with the number of experts and could be partitioned, but in practice they are small enough to be replicated and have only a negligible effect on peak memory usage).
Activations (the output of each layer, which is used in both the forward and backward pass).
The memory scaling is demonstrated in Figure 7, which shows the per-device memory usage distribution for different models. With a fixed number of layers, both weight memory and activation memory stay constant when the number of experts increases.
On the other hand, weight memory and activation memory both scale linearly with the number of layers. When the memory requirement exceeds the available memory on each device, compiler-based rematerialization automatically recomputes part of the activations in the backward pass in order to reduce peak activation memory. This is why the activation size for MoE(2048E, 60L) is smaller than that for MoE(2048E, 36L). The overhead of rematerialization is also optimized: only 28% and 34% of the total cycles are spent on recomputation for the 36L and 60L models respectively, and 0% for 12L and 24L, since they fit in device memory without rematerialization.
Figure 8 shows the breakdown of execution time for an MoE layer and its adjacent Transformer layer. It also compares the achieved performance to a roofline, which is estimated by assuming that compute-, memory-, or communication-bound operations can achieve 100% of the peak FLOPS, memory bandwidth, or interconnect bandwidth, respectively. This is a very optimistic estimate, as many operators are bounded by a mixed set of resources. At a smaller scale (128 experts), our model can achieve > 70% of the roofline performance. The device time increases by 1.7x when we scale the model to 16x larger (2048 experts), and the model can still achieve 48% of the roofline performance.
Before analyzing performance scalability, we recall the size scaling of relevant tensor dimensions as discussed in Section 3.1. With $D$ devices, the number of experts $E$ and the group count $G$ are both set to $D$. The fractional per-group expert capacity $C$ is set to $O(1/D)$. This setup cannot scale indefinitely, since $C$ needs to be at least 1, but it is good enough to scale to thousands of experts.
These are the dense parts of the model, which are designed to achieve peak TPU utilization. On each device, these computations also have a constant cost when we scale to more experts. Feed-forward layers and Transformer projections are mainly large matrix multiplications that utilize the TPU’s matrix unit well. These operations achieved > 85% of peak FLOPS in our experiment. The attention operations are composed mainly of batch matmuls, which are bounded by memory bandwidth when sequence lengths are small. As a result, in our experiments attention operations only achieved > 30% of peak FLOPS.
In Figure 8, “Gate Einsum” represents the first two and the last Einsums in Algorithm 2. The first Einsum is the projection that calculates the per-expert input to softmax. It has an $O(D)$ cost, but it is a very small part of the layer. The other two Einsums dispatch tokens and combine expert results. They effectively implement Gather with one-hot matrices, which is more expensive, but has a constant cost that is independent of the number of experts. The execution time of these Einsums increases by around 2x when we scale from 128 to 2048 experts (16x).
The remaining per-device gating computation involves many general-purpose computations like ArgMax and Cumsum, which are either memory-bound or even sequential in nature, and thus not designed to utilize TPUs well. The majority of the time is spent on sequential Cumsum operations that invert one-hot matrices representing the selected experts for each token into one-hot matrices representing the selected tokens for each expert. The linear complexity of Cumsum is demonstrated in Figure 8. This part of the gating computation also has an $O(D)$ cost, but fortunately, similar to the Einsum before softmax, it has a very small constant factor. It has negligible execution time with 128 experts, and takes less than 10% of the total time spent in the MoE and Transformer layers with 2048 experts.
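The ArgMax, one-hot, and Cumsum steps can be sketched in plain Python. For brevity this sketch uses top-1 routing (an assumption; the exact gating in Algorithm 2 may differ) and a hypothetical `dispatch_slots` helper. An inclusive Cumsum over the token dimension of the one-hot mask gives each token its slot in its expert's buffer, and tokens past `capacity` are dropped. The resulting (expert, slot) pairs correspond to the nonzero entries of a one-hot dispatch matrix.

```python
from itertools import accumulate
from typing import List, Optional, Tuple


def dispatch_slots(gate_logits: List[List[float]], capacity: int
                   ) -> List[Optional[Tuple[int, int]]]:
    """Top-1 routing sketch: (expert, slot) for each token, or None if dropped."""
    num_experts = len(gate_logits[0])
    # ArgMax per token (the first index wins ties)
    expert = [max(range(num_experts), key=row.__getitem__) for row in gate_logits]
    # One-hot [tokens, experts] mask of the selected expert
    onehot = [[int(e == x) for e in range(num_experts)] for x in expert]
    # Inclusive Cumsum along tokens: cumsum[s][e] = tokens up to s routed to e
    cumsum = list(accumulate(onehot, lambda acc, row: [a + b for a, b in zip(acc, row)]))
    slots = []
    for s, e in enumerate(expert):
        pos = cumsum[s][e] - 1  # 0-based position within expert e's buffer
        slots.append((e, pos) if pos < capacity else None)
    return slots


logits = [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.1, 0.9]]
print(dispatch_slots(logits, capacity=2))  # [(0, 0), (0, 1), None, (1, 0)]
```

The Cumsum is a sequential scan over tokens, which is why it maps poorly onto the TPU’s matrix unit.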
The most significant part of gating is communication, shown as “MoE dispatch and combine” in Figure 8. These are AllToAll operators, and as we will discuss in Section 5.3, their cost is $O(\sqrt{D})$. When the number of experts grows 16x from 128 to 2048, the execution time increases by about 3.75x, and their proportion of execution time in the MoE and Transformer layers increases from 16% to 36%.
In this section, we measure and analyze the performance scalability of the SPMD partitioner for basic operators, which can be used to guide use cases beyond the MoE model presented in this paper.
Two critical collective communication operators in the MoE model are AllReduce and AllToAll. AllReduce is used to accumulate partial results, and AllToAll is used in resharding (Section 3.3.2). Figure 9 shows their performance scalability from 16 to 2048 partitions. AllReduce on TPU has an execution time independent of the number of devices. The variance in Figure 9 is due to the specifics of each topology, e.g., whether it is a square or a rectangle, and whether it is a torus or a mesh.
AllToAll, on the other hand, gets more expensive as the number of partitions grows, but in a sublinear manner. On our 2D TPU cluster, AllToAll cost is roughly $O(\sqrt{D})$, where $D$ is the number of partitions. This is because, with a fixed amount of data sent by each partition (8MB or 32MB in Figure 9), the total amount of data that all partitions send is $d = O(D)$. Meanwhile, each data piece needs to travel $h = O(\sqrt{D})$ hops on average, and there are $l = O(D)$ device-to-device links in the network overall. Therefore, if it is bandwidth-bound, the execution time of an AllToAll is
$$t = \frac{d \cdot h}{l} = O\left(\frac{D\sqrt{D}}{D}\right) = O(\sqrt{D}).$$
Even if it is latency-bound, the execution time will still be $O(h) = O(\sqrt{D})$. Comparing 2048 partitions with 16 partitions, while $D$ grows by 128 times, the execution time of AllToAll increases by only 9 times. This enables us to use resharding to efficiently implement cross-partition dispatching (Figure (a)a).
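To see where the $O(\sqrt{D})$ hop count comes from, the sketch below computes the exact mean hop distance on an idealized 2D torus and plugs it into the bandwidth-bound estimate $t = dh/l$. The torus shapes (4×4 for 16 partitions, 32×64 for 2048) and the link accounting (one +x and one +y link per node, each treated as a single channel) are assumptions for illustration; real TPU topologies may be meshes or rectangles, and contention and latency are ignored.

```python
def mean_ring_distance(n):
    """Average shortest-path hops from a node to a uniformly random node on an n-node ring."""
    return sum(min(k, n - k) for k in range(n)) / n


def mean_torus_hops(shape):
    """Average hop count on a 2D torus: by linearity, the sum of the per-axis averages."""
    rows, cols = shape
    return mean_ring_distance(rows) + mean_ring_distance(cols)


def alltoall_time_bandwidth_bound(shape, bytes_per_partition, link_bandwidth):
    """Idealized t = d * h / l for an AllToAll on a 2D torus."""
    num_partitions = shape[0] * shape[1]
    total_bytes = num_partitions * bytes_per_partition   # d = O(D)
    hops = mean_torus_hops(shape)                        # h = O(sqrt(D))
    links = 2 * num_partitions                           # l = O(D)
    return total_bytes * hops / (links * link_bandwidth)
```

With unit data and bandwidth, the model gives 1.0 for 4×4 and 12.0 for 32×64, a 12x increase that tracks $\sqrt{128} \approx 11.3$. The measured increase reported above is about 9x, so the model captures the trend rather than the constant.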
AllGather and CollectivePermute are easier to analyze. AllGather’s output is larger than its input, and if we fix the input size, its communication cost is $O(D)$. CollectivePermute has a one-to-one communication pattern, and with a reasonable device arrangement where the source-destination pairs are close, its cost is $O(1)$ for a fixed input size.
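A back-of-the-envelope comparison of the data each partition receives makes the difference concrete. The sketch assumes every partition contributes a shard of the same fixed size and counts only bytes received, not hops or latency; the function names are hypothetical.

```python
def allgather_bytes_received(bytes_per_partition, num_partitions):
    # Every partition ends up holding all D shards, so it must receive the D - 1 it lacks.
    return (num_partitions - 1) * bytes_per_partition


def collective_permute_bytes_received(bytes_per_partition, num_partitions):
    # One source per destination: exactly one shard arrives, whatever the value of D.
    del num_partitions  # kept only to mirror the AllGather signature
    return bytes_per_partition
```

With 8 MB shards, AllGather grows from 120 MB received per partition at 16 partitions to 16,376 MB at 2048, while CollectivePermute stays at 8 MB; with nearby source-destination pairs its hop count also stays constant.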
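Table 4: Scalability of partitioned operators. "Dimensions" lists the dimensions of size $O(D)$ that are partitioned. Communication primitives: AR = AllReduce, AG = AllGather, CP = CollectivePermute, AA = AllToAll.
Operator | Dimensions | Total compute | Per-partition compute | Per-partition communication
Add(A,A->A) | A | O(D) | O(1) | 0
Matmul(AB,BC->AC) | B | O(D) | O(1) | O(1) AR
Matmul(AB,BC->AC) | A | O(D) | O(1) | 0
Matmul(AB,BC->AC) | A,B | O(D^2) | O(D) | O(D) AG or CP
Matmul(AB,BC->AC) | A,C | O(D^2) | O(D) | O(D) AG or CP
Reduce(AB->A) | A | O(D) | O(1) | 0
Reduce(AB->B) | A | O(D) | O(1) | O(1) AR
Einsum(GSEC,GSM->EGCM) | G,E * | O(D) | O(1) | O(√D) AA
Convolution(BIXY,xyIO->BOXY) | X ** | O(D) | O(1) | O(1) CP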
We summarize the performance scalability of common operators partitioned with GShard in Table 4. It contains the Einsum/Matmul examples from Section 3.3.2, as well as other common operators like Convolution and Reduce. The table lists the local compute on each partition, as well as the communication required according to our analysis above.
Most operators in Table 4 have sublinear scalability in terms of both compute and communication, which is consistent with our performance measurement of the MoE model. The scaling of spatially partitioned convolutions also demonstrates the efficiency of GShard for image partitioning (Appendix A.4).
However, the last two Matmul operators in Table 4 have $O(D)$ scaling of per-partition compute and communication, because their operands have unmatched sharding. This is not due to inefficiency in the partitioning algorithm, but because the total compute of the full operator is very large ($O(D^2)$). Different partitioning strategies can be used for these cases, producing different communication primitives: replicating one operand results in AllGather (requiring the replicated operand to fit in device memory), while slicing in a loop (Figure (c)c) results in CollectivePermute.
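One way to realize slicing in a loop for the unmatched case (LHS split on A, RHS split on B) is to keep each LHS shard in place and rotate the RHS shards around a ring with CollectivePermute, accumulating a partial product at each step. The sketch below simulates this on nested lists; `collective_permute` is a hypothetical stand-in that just reindexes a Python list, and the ring schedule is an assumption for illustration rather than the exact schedule GShard emits. Each partition performs $D$ steps and $D - 1$ transfers, so its compute and communication grow linearly with $D$.

```python
def matmul(x, w):
    """Dense matmul on nested lists: [A][B] x [B][C] -> [A][C]."""
    return [[sum(x[r][k] * w[k][c] for k in range(len(w))) for c in range(len(w[0]))]
            for r in range(len(x))]


def collective_permute(shards, shift):
    """Simulated CollectivePermute on a ring: partition i receives partition (i + shift)'s shard."""
    n = len(shards)
    return [shards[(i + shift) % n] for i in range(n)]


def looped_matmul(x, w, num_parts):
    """Matmul(AB,BC->AC) with the LHS split on A and the RHS split on B.

    Instead of AllGathering W, each partition multiplies by the W block it currently
    holds, then passes that block to its ring neighbour. Assumes A and B divide evenly.
    """
    rows_a, rows_b = len(x) // num_parts, len(w) // num_parts
    cols_c = len(w[0])
    x_shards = [x[i * rows_a:(i + 1) * rows_a] for i in range(num_parts)]
    w_shards = [w[i * rows_b:(i + 1) * rows_b] for i in range(num_parts)]
    held = list(range(num_parts))  # index of the W block each partition currently holds
    acc = [[[0] * cols_c for _ in range(rows_a)] for _ in range(num_parts)]
    for step in range(num_parts):
        for i in range(num_parts):
            j = held[i]
            # Columns of the local LHS shard that match the B-range of block j.
            x_slice = [row[j * rows_b:(j + 1) * rows_b] for row in x_shards[i]]
            partial = matmul(x_slice, w_shards[i])
            for r in range(rows_a):
                for c in range(cols_c):
                    acc[i][r][c] += partial[r][c]
        if step < num_parts - 1:  # no transfer is needed after the last step
            w_shards = collective_permute(w_shards, 1)
            held = collective_permute(held, 1)
    # Concatenating the per-partition results along A gives the full [A][C] output.
    return [row for shard in acc for row in shard]
```

Replicating W instead would replace the loop with a single AllGather, trading the per-step transfers for the memory needed to hold all of W on every device.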
Neural networks
Deep learning models have been very successful in advancing subfields of artificial intelligence. For years, these fields have continuously reported new state-of-the-art results using a variety of model architectures for computer vision tasks [krizhevsky2012imagenet, szegedy2015going, he2016deep], natural language understanding tasks [sutskever2014sequence, bahdanau2014neural, wu2016google], and speech recognition and synthesis tasks [hinton2012deep, chan2016listen, chiu2018state, oord2016wavenet, shen2018natural]. More recently, attention-based Transformer models have further advanced the state of the art in these fields [vaswani2017attention, devlin2018bert].
Model scaling Both academic research and industry applications have observed that larger neural networks tend to perform better on large enough datasets and for complex tasks. Within a single model family, simply making the network wider or deeper often improves model quality empirically. For example, deeper ResNets performed better [he2016identity], bigger Transformer models achieved better translation quality [vaswani2017attention], and models with larger vocabularies, embeddings, or feature crosses work better, too [arivazhagan2019massively, conneau2019unsupervised]. Across different model families, it has also been observed that bigger models with larger capacities not only fit the training data better but also generalize better at test time [45820, neyshabur2017exploring, gpipe19]. This observation motivated many research efforts to build much bigger neural networks than those typically used in deep learning research or production models. Shazeer et al. showed that a recurrent language model with 69 billion parameters using mixture-of-experts layers achieved much lower test perplexity on the One Billion Word (LM1B) benchmark [shazeer2017outrageously]. Brown et al. showed that a non-sparse model with 175 billion parameters is capable of highly accurate few-shot performance on several downstream NLP tasks.
Hardware Neural networks demand non-negligible amounts of computation power. To address this demand, special hardware (chips and networked machines) built for neural network training and inference dates back 25 years [ienne1996special]. Since the late 2000s, researchers have leveraged GPUs to accelerate neural nets [raina2009large, krizhevsky2012imagenet, cirecsan2010deep]. More recently, the industry has also invested heavily in building more dedicated hardware systems in pursuit of more cost-effective neural network hardware [jouppi2017datacenter]. Because the core computations of neural networks (various forms of sums of products: convolution, matrix multiplication, einsum) are highly parallelizable numerical calculations, these chips are equipped with a huge number of floating-point units (FPUs). Hence, the compute power of this specially designed hardware has grown dramatically. It is reported that GPU price per flop dropped by a factor of ten in just the last 4 years [gpu2019price] and that flops per watt increased by two orders of magnitude over the past 12 years [sun2019summarizing]. Widely available, low-cost computation power is a major enabler of the success of neural networks.
Software Software systems supporting neural networks evolved together with the advancement of the underlying hardware [dean2012large, bastien2012theano, abadi2016tensorflow, paszke2017automatic]. While accelerators are highly parallel compute machines, they are significantly more difficult to program directly. Frameworks made building neural networks easier and abstracted away many hardware-specific details from practitioners. They in turn rely on lower-level libraries to drive the special hardware (accelerators) efficiently, e.g., CUDA [nickolls2008scalable] for Nvidia’s GPUs, or XLA for Google’s TPUs [xla]. These lower-level libraries are critical for achieving high efficiency on such special hardware.
Parallelism in model training and inference Modern neural networks make extensive use of clusters of machines for training and inference, each equipped with several accelerators. Data parallelism [krizhevsky2012imagenet] is the most commonly used approach and is supported by major frameworks (TensorFlow [abadi2016tensorflow], PyTorch [pytorch2017], JAX [jax2018github, frostig2018mlsys]): devices run the same program on different input data and combine their local gradients before the weight updates. Model parallelism, on the other hand, partitions computation beyond the input batch, which is needed to build very large models. For example, pipelining [gpipe19, harlap2018pipedream] splits a large model’s layers into multiple stages, while operator-level partitioning [shazeer2018mesh, jia2018beyond] splits individual operators into smaller parallel operators. GShard uses a type of operator-level partitioning to scale our model to a large number of parallel experts.
Automated parallelism Because programming in a distributed heterogeneous environment is challenging, particularly for high-level practitioners, deep learning frameworks attempt to relieve their users of the burden of specifying how the distributed computation is done. For example, TensorFlow [abadi2016tensorflow] supports data parallelism, and basic model parallelism with graph partitioning by per-node device assignment. Mesh TensorFlow [shazeer2018mesh] helps the user build large models with SPMD-style per-operator partitioning, by rewriting the computation in a Python library on top of TensorFlow; in comparison, our approach partitions the graph in the compiler based on lightweight annotations without requiring the user to rewrite the model. FlexFlow [jia2018beyond] uses automated search to discover the optimal partition of operators in a graph for better performance; while it focuses on determining the partitioning policy, our SPMD partitioner focuses on the mechanisms to transform an annotated graph. Weight-update sharding [xu2020automatic] is another automatic parallelization transformation based on XLA, which mostly focuses on performance optimizations for TPU clusters, and conceptually can be viewed as a special case of GShard. ZeRO [rajbhandari2019zero] presents a set of optimizations to reduce memory redundancy across parallel training devices by partitioning weights, activations, and optimizer state separately, and it is able to scale models to 170 billion parameters; in comparison, GShard is more general in the sense that it does not distinguish between these tensors, and all of those specific partitioning techniques can be supported by simply annotating the corresponding tensors, allowing us to scale to over 1 trillion parameters and explore more design choices.
Conditional Computation and Machine Translation Conditional computation [bengio2015conditional, shazeer2017outrageously, Elbayad2020DepthAdaptiveT, bapna2020controlling] is premised on routing examples within the network by activating an input-dependent sub-network. The routing depends (or conditions) on certain criteria and, without loss of generality, can be any of the following: the estimated difficulty of the example [lugosch2020surprisaltriggered], the available computation budget [Elbayad2020DepthAdaptiveT, bapna2020controlling], or, more generally, a learned criterion with a sparsity-induced mixture of experts [shazeer2017outrageously]. Because of its flexibility and ease of scaling, we extend the sparsely gated mixture of experts [shazeer2017outrageously] to state-of-the-art neural sequence models, Transformers [vaswani2017attention], to achieve training efficiency.
In this paper, we introduced GShard, a deep learning module that partitions computation at scale automatically. GShard requires only lightweight sharding annotations in the user model code and delivers an easy-to-use and flexible API for scaling giant neural networks. We applied GShard to scale up the Transformer architecture with Sparsely-Gated Mixture-of-Experts layers (MoE Transformer) and demonstrated that a 600B-parameter multilingual neural machine translation model can be trained efficiently in 4 days, achieving superior performance and quality compared to prior art when translating 100 languages to English with a single model. In addition to far better translation quality, MoE Transformer models trained with GShard also excel at training efficiency, with a training cost of 22 TPU v3 core-years compared to 29 TPU years used for training all 100 bilingual Transformer baseline models. Empirical results presented in this paper confirm that scaling models by utilizing conditional computation not only improves the quality of real-world machine learning applications but also remains practical and sample-efficient during training. Our proposed method presents a favorable scalability/cost trade-off and alleviates the need for model-specific frameworks or tools for scaling giant neural networks. Together, our results help to elucidate a realistic and practical way forward for neural network scaling to achieve better model quality.
We have learned several lessons from our study. Our results suggest that progressive scaling of neural networks yields consistent quality gains, validating that the quality improvements have not yet plateaued as we scale up our models. While the results in this paper establish model scaling as a must-have in the deep learning practitioner’s toolbox, we also urge practitioners to strive for training efficiency. To this end, we identified factors that affect training efficiency and showed their implications on downstream task quality. We demonstrated how neural networks built with conditional computation yield a favorable trade-off between scale and computational cost. In practice, such critical design decisions allowed us to enjoy experimental cycles of days, rather than weeks or months, when training models on the order of a trillion parameters.
Further, having a proper abstraction layer that separates model description from parallelization implementation allows model developers to focus on network implementation, leaving GShard to partition the computation graphs automatically and generate programs that run on all devices in parallel. We found that generating a single program that is general enough to express computation on all underlying parallel devices is the key to compiling scalably. The traditional way of generating multiple dedicated programs for different partitions results in explosive compilation time when scaling to thousands of partitions. To address this complexity, we introduced various compiler improvements based on SPMD sharding that allow any tensor dimension to be partitioned. As a takeaway, we emphasize that model scaling and training efficiency should go hand-in-hand, and that algorithmic improvements such as conditional computation, when coupled with easy-to-use interfaces, can effectively utilize large computational power.
Lastly, our experimental results empirically support that mere parameter counting does not always correlate with the effective capacity of models at scale [li2018measuring, maddox2020rethinking]. Comparisons of models should also account for the nature of the problem, i.e., a massively multi-task setting with heavy training data imbalance across tasks as in our case, and control for the factors affecting different operation modes of the networks, i.e., capacity bottleneck vs. positive transfer.
We would like to thank the Google Brain and Translate teams for their useful input and insightful discussions, and the entire XLA and Lingvo development teams for their foundational contributions to this project, in particular Youlong Cheng, Naveen Arivazhagan, Ankur Bapna, Ruoming Pang, Yonghui Wu, Yuan Cao, David Majnemer, James Molloy, Peter Hawkins, Blake Hechtman, Mark Heffernan, Dimitris Vardoulakis, Tamas Berghammer, Marco Cornero, Cong Liu, Tong Shen, Hongjun Choi, Jianwei Xie, Sneha Kudugunta, and Macduff Hughes.
During decoding, we use beam search with length normalization similar to [wu2016google]. Decoding is autoregressive and generates the target sequence one token at a time, so for an output of length $u$ the decoder layer stack is executed $u$ times, sequentially. In particular, for each decoder MoE layer there are $u$ dispatch/combine operations, which require cross-device communication. Inference uses the same cluster with the same number of devices as training.
During beam search, we flatten the beam hypotheses into a single sequence that contains all underlying tokens interleaved, and we modify the decoder self-attention mask so that each hypothesis only attends to the appropriate positions in the joint flat sequence. We apply the same transformation to the key/value tensors maintained by each decoder self-attention layer. This allows us to avoid reordering previously computed attention keys/values after each beam expansion; instead, we only reorder the mask representing the currently active hypotheses. However, attention becomes $k$ times longer, where $k$ is the beam size.
This tradeoff can be positive or negative depending on implementation details. As explained in [shazeer2019fast], memory bandwidth limits are important for incremental decoding with Transformer models. From this point of view, by flattening the beam we replace two operations with low compute/memory ratio (attention dot product and key/value reordering) with a single operation with a slightly higher compute/memory ratio (attention dot product over a longer sequence with more keys), but with the same total amount of memory it has to access.
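The mask bookkeeping for the flattened beam can be sketched as follows. This is an illustrative pure-Python model, not GShard or Lingvo code: it assumes the token emitted by hypothesis `b` at step `t` is stored at flat position `t * k + b` (one reading of "interleaved"), and the helper names are hypothetical. Keys and values never move; only each hypothesis's list of visible positions is reordered when the beam is expanded.

```python
def flat_position(step, beam, beam_size):
    """Flat index of the token written by hypothesis `beam` at decode `step`."""
    return step * beam_size + beam


def start_beams(beam_size):
    """After step 0, hypothesis b can see only its own first token."""
    return [[flat_position(0, b, beam_size)] for b in range(beam_size)]


def expand_beams(visible, parents, step, beam_size):
    """New hypothesis b continues parents[b]: copy the parent's visible positions
    (a mask reorder, no key/value copy) and add the position written this step."""
    return [visible[p] + [flat_position(step, b, beam_size)] for b, p in enumerate(parents)]


def attention_mask(visible, length):
    """Boolean [hypothesis][flat position] self-attention mask."""
    return [[pos in allowed for pos in range(length)] for allowed in map(set, visible)]
```

With a beam of 2, if both new hypotheses at step 1 extend old hypothesis 1, the token at flat position 0 simply becomes invisible to everyone, while its keys and values stay in place.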
In our machine translation experiments, MoE Transformer models shared a) a Transformer model dimension of 1024; b) a feed-forward and MoE hidden dimension of 8192; c) 16 heads in multi-head attention; d) an attention key and value dimension of 128; and e) an input, residual and attention dropout rate of 0.1.
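The shared hyperparameters above can be collected into a small config object. The class and field names below are hypothetical (they are not Lingvo's parameter names), and the weight count assumes each expert is a two-layer feed-forward block (model dimension to hidden dimension and back) without biases.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MoETransformerConfig:
    """Hyperparameters shared by the MoE Transformer models (hypothetical field names)."""
    model_dim: int = 1024        # Transformer model dimension
    hidden_dim: int = 8192       # feed-forward and MoE expert hidden dimension
    num_heads: int = 16          # heads in multi-head attention
    head_dim: int = 128          # attention key and value dimension
    dropout_rate: float = 0.1    # input, residual and attention dropout

    def expert_weight_count(self) -> int:
        """Weights in one expert, assuming two dense layers (model->hidden->model), no biases."""
        return 2 * self.model_dim * self.hidden_dim
```

Under that assumption, each expert holds 2 × 1024 × 8192 = 16,777,216 weights per MoE layer.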
We used the Adafactor [Shazeer2018AdafactorAL] optimizer with a) factored second-moment estimation; b) first moment decay ; c) second moment decay with schedule; d) an update clipping threshold of 1.0; and e) a 1.0 learning rate with square root decay after 10k training steps.
We used the SentencePiece [Kudo2018SentencePieceAS] subword tokenizer with a single multilingual source-side vocabulary of size 64000 spanning 102 languages, and an English-only target-side vocabulary of size 32000.
In addition to the two common APIs (replicate() and split()) for sharding listed in Section 3.2, users or the compiler may use a more advanced sharding strategy to minimize data transfers.
shard(tensor, device_assignment) annotates tensor to be partitioned with the provided device assignment, and returns the annotated tensor. We use device assignment, a multidimensional integer array, to represent how the split is done. device_assignment has the same rank as the data tensor; its element count is the total number of partitions, and each element is the ID of the device that occupies the corresponding data slice. For example, a 3D tensor with shape with device assignment shape will have partition shape , and the order of elements in device assignment determines which slice each partition occupies.
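The following sketch shows how a device assignment could be interpreted: the partition shape is the tensor shape divided elementwise by the assignment shape, and the position of a device ID inside the assignment array decides which slice it owns. The function names, the nested-list representation, and the rounding-up for uneven sizes are assumptions for illustration, not the GShard API; the shapes in the usage example are made up.

```python
from itertools import product


def nested_shape(nested):
    """Shape of a rectangular nested list, e.g. [[[0, 1, 2], [3, 4, 5]]] -> (1, 2, 3)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def partition_shape(tensor_shape, assignment_shape):
    """Per-partition shape; uneven dimensions are rounded up (an assumption)."""
    if len(tensor_shape) != len(assignment_shape):
        raise ValueError("device_assignment must have the same rank as the tensor")
    return tuple(-(-size // parts) for size, parts in zip(tensor_shape, assignment_shape))


def device_slices(tensor_shape, device_assignment):
    """Map each device ID to the (start, stop) range it owns in every dimension."""
    assignment_shape = nested_shape(device_assignment)
    part = partition_shape(tensor_shape, assignment_shape)
    slices = {}
    for index in product(*(range(n) for n in assignment_shape)):
        device = device_assignment
        for k in index:  # walk down the nested list to the device ID at this position
            device = device[k]
        slices[device] = tuple(
            (k * p, min((k + 1) * p, size)) for k, p, size in zip(index, part, tensor_shape)
        )
    return slices
```

For a hypothetical tensor of shape (4, 8, 6) and an assignment of shape (1, 2, 3), every partition has shape (4, 4, 2). With the assignment `[[[0, 1, 2], [3, 4, 5]]]`, device 4 owns rows 4 to 8 of the second dimension; reordering it to `[[[0, 2, 4], [1, 3, 5]]]` moves device 4 to the first half of that dimension and the last third of the third.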
Since data movement across devices critically affects the parallel execution performance, it is important to consider the target device topology as well as the communication between partitions of the tensor when assigning device ids in the device assignment for maximum performance. Figure 10 shows two different device assignments based on the device topology and the rowwise communication pattern on the tensor.
GShard is able to partition spatial dimensions in convolutions, and general enough to support use cases like giant images [spatialpartitioning]. To spatially shard a convolutional layer, we can use the sharding API in the following way.
GShard will then propagate the sharding on the spatial dimension to other layers and the backward pass. The rest of this section discusses the specific complexity of partitioning Convolution and similar operators. There are several window-based operations (e.g., Convolution, ReduceWindow), and they all require some type of halo exchange since data may be shared between windows. We use the CollectivePermute operator to exchange halo data between partitions, but one complication is that the halo size may differ across partitions, whereas CollectivePermute needs to be statically shaped.
We first introduce the window configurations that the SPMD partitioner has to consider. Each spatial dimension in the convolution has the following set of configurations.
Stride is the distance (in number of elements) that the window moves to produce the next output element.
Low/high padding is the number of elements padded to the low/high end of the dimension in LHS (base).
Base dilation is the dilation factor of the LHS, i.e., one plus the number of elements padded between every element (excluding low/high padding). No base dilation means the value is set to 1.
Window dilation is one plus the number of elements padded between every element in the RHS (window).
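Taken together, these four settings determine how many windows fit along a spatial dimension. A minimal sketch of that arithmetic, assuming low/high padding is applied after base dilation (as noted for base dilation in this section) and that every window must lie entirely inside the padded base; the function name is hypothetical.

```python
def windowed_output_size(base_size, window_size, stride=1, low_pad=0, high_pad=0,
                         base_dilation=1, window_dilation=1):
    """Output length along one spatial dimension of a windowed operation."""
    # Base dilation inserts (base_dilation - 1) holes between neighbouring elements;
    # low/high padding is then added to the ends of the dilated base.
    padded_base = (base_size - 1) * base_dilation + 1 + low_pad + high_pad
    # Window dilation stretches the window the same way.
    effective_window = (window_size - 1) * window_dilation + 1
    if padded_base < effective_window:
        return 0
    # Windows start every `stride` elements and must fit inside the padded base.
    return (padded_base - effective_window) // stride + 1
```

For example, 12 elements with a window of 3 give 10 outputs with default settings, and 8 outputs with stride 2, low padding 1 and high padding 4.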
We demonstrate that non-constant halo sizes are common using a simple example without dilation. Figure 11 shows a 4-way partitioned convolution, where the right halo sizes for the partitions are (1, 2, 3, 4) and can be expressed as a linear function of the partition ID: $1 + i$ for Partition $i$. Partition 1 is in charge of generating 2 output elements (red cells), which means that it needs to get 0 elements from Partition 0 and 2 elements from Partition 2 (the area covered by the two dotted red windows).
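The halo sizes can be computed directly from the window configuration. This sketch assumes a single spatial dimension, no dilation, an evenly split input, and an even split of outputs across partitions; each partition's required input range is derived from its first and last output window. The parameters in the usage example (12 inputs over 4 partitions, stride 2, window 3, low padding 1) are hypothetical, chosen so the right halos come out as (1, 2, 3, 4) and Partition 1 needs 0 elements from Partition 0 and 2 from Partition 2; they are not necessarily the configuration drawn in Figure 11.

```python
def halo_sizes(shard_size, outputs_per_part, num_parts, stride, window_size, low_pad):
    """Left/right halo (in input elements) each partition needs beyond its own shard.

    Partition i owns inputs [i*shard_size, (i+1)*shard_size) and produces outputs
    [i*outputs_per_part, (i+1)*outputs_per_part). A negative halo means the partition
    already holds more data on that side than it needs.
    """
    lefts, rights = [], []
    for i in range(num_parts):
        first_out = i * outputs_per_part
        last_out = (i + 1) * outputs_per_part - 1
        need_start = first_out * stride - low_pad              # start of first window
        need_end = last_out * stride - low_pad + window_size   # end of last window
        lefts.append(i * shard_size - need_start)
        rights.append(need_end - (i + 1) * shard_size)
    return lefts, rights
```

With the parameters above, the right halos are `[1, 2, 3, 4]` and the left halos are `[1, 0, -1, -2]`; Partition 0's single left element and Partition 3's right halo fall entirely in padding.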
Figure 12 describes the sequence of operations for a general halo exchange. First, we calculate the maximum size of the left and right halos across partitions and perform a halo exchange of that maximum size (Steps 1 and 2). Since some partitions may receive larger halos than they need, we use DynamicSlice (based on the partition ID) to slice off the valid region for the current partition (Step 3). Finally, some partitions may include garbage values (e.g., halos from out-of-range input data), so we apply masking as described in Section 3.3.3.
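The four steps can be simulated end to end on a 1D input and checked against an unpartitioned convolution. In this sketch (hypothetical names, no dilation, evenly sized shards and output ranges), the exchange of Steps 1 and 2 is simulated by reading neighbouring data directly from the full input list rather than issuing real CollectivePermutes; out-of-range positions arrive as `None` to stand in for garbage, which the final masking step replaces with the padding value 0.

```python
def reference_conv1d(x, w, stride, low_pad, high_pad):
    """Unpartitioned 1D convolution (cross-correlation) with zero padding."""
    padded = [0] * low_pad + list(x) + [0] * high_pad
    return [sum(padded[start + k] * w[k] for k in range(len(w)))
            for start in range(0, len(padded) - len(w) + 1, stride)]


def partitioned_conv1d(x, w, stride, low_pad, num_parts, outputs_per_part):
    shard, window = len(x) // num_parts, len(w)
    # Per-partition left/right halo sizes (these differ across partitions).
    lefts, rights = [], []
    for i in range(num_parts):
        need_start = i * outputs_per_part * stride - low_pad
        need_end = ((i + 1) * outputs_per_part - 1) * stride - low_pad + window
        lefts.append(i * shard - need_start)
        rights.append(need_end - (i + 1) * shard)
    # Steps 1-2: every partition exchanges halos of the same, maximum size.
    max_left, max_right = max(max(lefts), 0), max(max(rights), 0)
    needed = (outputs_per_part - 1) * stride + window
    outputs = []
    for i in range(num_parts):
        lo, hi = i * shard - max_left, (i + 1) * shard + max_right
        # Simulated exchange: positions outside the input arrive as garbage (None).
        extended = [x[p] if 0 <= p < len(x) else None for p in range(lo, hi)]
        # Step 3: DynamicSlice the region this partition actually needs.
        offset = max_left - lefts[i]
        region = extended[offset:offset + needed]
        # Final step: mask garbage to the padding value before computing.
        region = [0 if v is None else v for v in region]
        for k in range(outputs_per_part):
            start = k * stride
            outputs.append(sum(region[start + j] * w[j] for j in range(window)))
    return outputs, lefts, rights
```

With 12 inputs split 4 ways, stride 2, a window of 3 and low padding 1, the concatenated partition outputs match `reference_conv1d` with high padding 4, even though the halo sizes differ on every partition.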
Base dilation adds additional complexities to halo exchange, since the offset of each partition may be positioned at the dilation holes, and also low/high padding is applied after dilation, which makes the edges have different behavior than the interior elements. We handle base dilation in 3 cases (Figure 13).
$sm$ is divisible by $d$, where $s$ is the stride, $d$ is the base dilation, and $m$ is the number of windows to be processed by each partition (i.e., the number of output elements for each partition). This condition guarantees that all partitions start with the same number of (interior or low) padding elements before the first data element in the LHS, so that we can use the same low padding. Halo exchange occurs on the non-dilated/non-padded base region, and the limit index of required data for Partition $i$ can be calculated as below, where $w$ is the (effective) window size and $p_{\text{lo}}$ is the low padding:
$$\ell_i = \left\lceil \frac{s\,m\,(i+1) - s + w - p_{\text{lo}}}{d} \right\rceil,$$
which determines the right halo size. Because $sm$ is divisible by $d$, it can be simplified as $a \cdot i + b$, where $a = sm/d$ and $b$ are both constants.
$s = 1$ but $m$ is not divisible by $d$. In this case, the low padding differs across partitions, but low padding is a static configuration of windowed operations, which can’t be specialized for each partition in SPMD execution. Using Pad and DynamicSlice on the operand would not work either, because those operators would be applied before dilation, so everything would be multiplied by the dilation factor. Fortunately, with $s = 1$, all positions on the padded and dilated base region are valid window starts, and we can use the maximum low padding across all partitions to ensure that each partition calculates all required windows, then apply a DynamicSlice on the output of the partitioned windowed operator to remove unnecessary data. The limit index of required data on the non-padded base region for Partition $i$ is the same as before (with $s = 1$),
$$\ell_i = \left\lceil \frac{m\,(i+1) - 1 + w - p_{\text{lo}}}{d} \right\rceil,$$
but it cannot be simplified to $a \cdot i + b$.
$s \neq 1$ and $sm$ is not divisible by $d$. If neither of the above conditions holds, different partitions could start with different numbers of padding elements, and not all offsets are valid window starts. Consider the last example in Figure 13. Whatever low padding we choose, some partition will be invalid, because valid windows could be skipped since $s \neq 1$. A solution to this problem is to pad the window in addition to padding the base area. We can use the maximum low padding required by the partitions on the base area, then increase the window size by that low padding amount. However, the low and high padding amounts on the window vary across partitions, which can be implemented by a Pad followed by a DynamicSlice. The window padding is used to mask off the unaligned elements in the base area, so that the start of the non-padding window element is aligned with the desired start in the base area for each partition.
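The difference between the first two base-dilation cases can be checked numerically. The sketch below evaluates a limit index derived from the window geometry, assuming no window dilation, with stride $s$, $m$ windows per partition, window size $w$, low padding $p_{\text{lo}}$ and base dilation $d$ (names chosen here), and tests whether the sequence over partitions is affine in $i$. The parameter values are arbitrary examples.

```python
def ceil_div(a, b):
    """Ceiling division that also works for negative numerators."""
    return -(-a // b)


def limit_index(i, stride, windows_per_part, window, low_pad, dilation):
    """Exclusive end (in non-dilated, non-padded base elements) of the data partition i needs."""
    # End of partition i's last window, in dilated coordinates with low padding removed.
    end = stride * windows_per_part * (i + 1) - stride + window - low_pad
    # Base element j sits at dilated position j * dilation, so every j with
    # j * dilation < end is required.
    return ceil_div(end, dilation)


def is_affine(values):
    """True if consecutive differences are constant, i.e. values[i] == a * i + b."""
    diffs = {b - a for a, b in zip(values, values[1:])}
    return len(diffs) <= 1
```

With $s = 2$, $m = 3$, $d = 3$ ($sm$ divisible by $d$), the limits for four partitions are 2, 4, 6, 8, an affine sequence with slope $sm/d = 2$. With $s = 1$, $m = 2$, $d = 3$ they are 1, 2, 3, 3, which no single $a \cdot i + b$ can describe.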
If the RHS is replicated, window dilation only affects the effective window size when partitioning the operator based on its LHS. If the dilated RHS is also partitioned, which typically occurs in the gradient computation of strided convolutions, handling window dilation is still simpler than handling base dilation, because there is no low/high padding on the RHS. We skip the details of the implementation.