Automap: Towards Ergonomic Automated Parallelism for ML Models

by Michael Schaarschmidt, et al.

The rapid rise in demand for training large neural network architectures has brought into focus the need for partitioning strategies, for example by using data, model, or pipeline parallelism. Implementing these methods is increasingly supported through program primitives, but identifying efficient partitioning strategies requires expensive experimentation and expertise. We present the prototype of an automated partitioner that seamlessly integrates into existing compilers and existing user workflows. Our partitioner enables SPMD-style parallelism that encompasses data parallelism and parameter/activation sharding. Through a combination of inductive tactics and search in a platform-independent partitioning IR, automap can recover expert partitioning strategies such as Megatron sharding for transformer layers.





1 Introduction

Driven by recent progress in the design and evaluation of large language models gpt3_2020, gshard, scaling_laws_2020, codex2021, training techniques for large deep neural networks have become critically important for progress in machine learning systems. These networks are trained by combining multiple parallelism strategies and executing them across many accelerator devices. Identifying an effective combination of approaches such as data, model megatron2019, or pipeline parallelism gpipe, pipedream, Narayanan2021MemoryEfficientPD, pipemare21 depends on the specific model architecture, accelerator characteristics, and distributed device topology. Selecting from a growing array of techniques such as micro-batching, rematerialisation, or parameter offloading DBLP:journals/corr/abs-2101-06840 is further complicated by expensive experimentation, with large models requiring up to thousands of accelerators. Beyond training larger models, organisations managing diverse accelerator fleets can improve hardware utilisation by partitioning models to fit onto older accelerators with less memory.

Tensor programming frameworks like TensorFlow tensorflow2015-whitepaper, PyTorch pytorch2019, and JAX jax2018github increasingly provide end-user primitives or add-in libraries to help define the parallelisation strategy. JAX exposes several function transformations to control parallelism, such as pmap (typically, but not exclusively, used for batch parallelism), pjit or xmap for fine-grained control of model parallelism, and interfacing tools such as the XLA SPMD partitioner gshard, gspmd2021. While these are powerful tools which enable experts to compose advanced parallelism strategies, they still require specialised skill sets to achieve the desired hardware utilisation. For instance, users must come up with expert sharding specifications for parameters and possibly intermediate values of a model to productively use the pjit API, which in turn drives the XLA SPMD partitioning infrastructure.

Motivation and challenges. Though automated partitioning and distribution of ML workloads has been explored before in the research community deviceplacement2017, optgraphcomp, mirhoseini2018a, flexflow2019, roc2020, tofu2019, we present here the unique challenges motivating our work:

  • We need integration into existing backend compilers/runtimes widely used in production and already targeting different accelerator platforms, most notably XLA xla. We explicitly want to avoid having to re-implement kernels for a specific architecture to be partition-aware flexflow2019.

  • We need integration into existing user workflows, in our case arbitrary computations specified in JAX without any user rewriting. Our partitioner may have to deal with XLA programs consisting of hundreds of thousands of operators, with hundreds or thousands of parameters. We cannot use a predefined library of partitioning-aware layers (e.g. Keras layers, as in flexflow2019) because we wish to allow researchers (our prevalent group of users) to freely innovate with arbitrary tensor computations and experimental layers.

  • We aim for a fast solution that allows an effective research development cycle, i.e. one whose overhead is comparable to that of scheduling an experiment: minutes, not hours.

Key ideas we offer for discussion. In this paper, we present our design prototype and preliminary results of a data-driven automatic partitioner, automap. The challenging setting we face has implications (a) on the design of our partitioning infrastructure to make it compatible with existing compilers and user workflows, and (b) on the design of our automation.

To address problem (a) we offer a novel IR (“dialect” in MLIR terms mlir2020) that is layered on top of another tensor dialect (XLA HLO in our prototype, but this could be adapted to any dialect) and allows us to express distribution decisions of computations from the base dialect. The distribution decisions are expressed as rewrite rules controlled by an agent, providing maximal flexibility to incorporate new rewrites (Section 2.1).

Problem (b) is particularly challenging: The most important implication of our setup is the need to manage large, unstructured programs. Contrary to coarse-grained layer-based approaches, it would be impractical to explore decisions for all operations in a program via search. (A decision may be to determine whether a value, e.g. an argument or an intermediate, should be replicated or sharded, and if the latter, upon which dimension and along which devices, i.e. which mesh axis.) On the other hand, ML models can be trained to make parallelisation decisions for all operations optgraphcomp at this scale, but such models typically rely on fine-tuning for new types of programs, and this makes the approach harder to integrate with a researcher’s workflow. Moreover, evaluating the “goodness” of a partitioning solution, e.g. the reduction in peak working memory, requires at least a static analysis (e.g. a liveness analysis) over the result of lowering and optimising a large (50-100k ops) program to an accelerator-local program. This reinforces the need for solutions that get to good performance within few trials and rewriting steps. Hence, we started our investigations by exploring (i) hybrid search-learning techniques, and (ii) incremental partitioning decisions and the use of compiler rewriting passes that propagate these decisions across the program where possible. A key idea we use is to imitate the behaviour of expert users partitioning their models (e.g. users of GSPMD gspmd2021) to design inductive biases that reduce the number of decisions taken by an agent. Early results show how this approach can recover expert partitioning such as the Megatron transformer layer sharding megatron2019. We identify important challenges for scaling search up to many layers and improving its robustness to the details of the model architecture.

Figure 1: End-to-end workflow overview.

Further related work. There is further related work, not discussed in detail in this paper – for instance DistIR distir2021 is also an MLIR-based representation for distributed tensor programs focusing on MPMD (for arguments for/against MPMD see gspmd2021). ROC roc2020 is an extension of flexflow2019 specifically for GraphNets using a learned cost model. Tofu tofu2019 is a hierarchical graph partitioner based on dynamic programming. Other plausible techniques for partitioning such as constraint-based encoding of cost models, often combined with equality saturation equality_saturation21, or BRKGA aditya2020-based options are all insufficient – at least out of the box – for our purposes due to the difficulties in obtaining cost models. Another promising avenue is to design progressively accurate sequences of cost models that do not all require a fully materialised program, as in TVM tvm2018.

2 System design

Our rewrite engine is implemented in MLIR mlir2020 with an XLA backend, and a Python API in JAX. Users interact with our system by designing normal JAX models and then passing their main update function to our partitioner. The JAX functions can already include user-managed parallelism (e.g. batch parallelism), and our system will further partition them along additional mesh axes. JAX functions are converted to XLA computations, and then lowered to our rewriting dialect PartIR partir2020 in MLIR (overview in Figure 1). Rewrites and inductive tactics are exposed to the partitioner via a compiler environment (§2.2). Rewrite sequences executed by the partitioner are lowered to an SPMD variant of our partitioning IR, and evaluated through compiler-internal cost models (estimating peak memory, runtime, and communication). PartIR is platform-independent, but we have implemented an XLA backend to seamlessly support CPU, TPU DBLP:journals/corr/JouppiYPPABBBBB17, and GPU workflows.

func @main(%arg0: tensor<8x16xf32>, %arg1: tensor<16x64xf32>, %arg2: tensor<64xf32>)
    -> tensor<8x64xf32> {
  %0 = mhlo.dot %arg0, %arg1 : tensor<8x64xf32>
  %1 = mhlo.broadcast_in_dim %arg2 {broadcast_dims = 1} : tensor<8x64xf32>
  %2 = mhlo.add %0, %1 : tensor<8x64xf32>
  return %2 : tensor<8x64xf32>
}

func @main(%arg0: tensor<8x16xf32>, %arg1: tensor<16x64xf32>, %arg2: tensor<64xf32>)
    -> tensor<8x64xf32> attributes {mesh_shape = #partir.mesh<"shard"=2>} {
  %0 = partir.tile 1 "shard" (%rshard : !partir.range<2>) {
    %4 = partir.slice 1 %arg1[%rshard] : tensor<16x32xf32>
    partir.yield %4 : tensor<16x32xf32>
  }
  %1 = mhlo.dot %arg0, %0 : tensor<8x64xf32>
  %2 = mhlo.broadcast_in_dim %arg2 {broadcast_dims = 1} : tensor<8x64xf32>
  %3 = mhlo.add %1, %2 : tensor<8x64xf32>
  return %3 : tensor<8x64xf32>
}

func @main(%arg0: tensor<8x16xf32>, %arg1: tensor<16x64xf32>, %arg2: tensor<64xf32>)
    -> tensor<8x64xf32> attributes {mesh_shape = #partir.mesh<"shard"=2>} {
  %0 = partir.atomic "shard" {
    partir.yield %arg0 : tensor<8x16xf32>
  }
  %1 = partir.tile 1 "shard" (%rshard : !partir.range<2>) {
    %2 = partir.slice 1 %arg1[%rshard] : tensor<16x32xf32>
    %3 = mhlo.dot %0, %2 : tensor<8x32xf32>
    %4 = partir.slice 0 %arg2[%rshard] : tensor<32xf32>
    %5 = mhlo.broadcast_in_dim %4 {broadcast_dims = 1} : tensor<8x32xf32>
    %6 = mhlo.add %3, %5 : tensor<8x32xf32>
    partir.yield %6 : tensor<8x32xf32>
  }
  return %1 : tensor<8x64xf32>
}
Figure 2: Top: A small MLIR MHLO program representing a single linear layer in slightly simplified notation. Middle: the same program where %arg1 has been expressed as a tiling loop on dimension 1. Bottom: the final PartIR program after propagation. Note that looping on dimension 1 of %arg1 (of size 64) means that we can also partition the dot product along dimension 1, and essentially pull the whole computation inside the tiling loop by operating on slices of %arg2. In the final program, %arg0 was automatically wrapped inside an “atomic” region to signify that it will remain replicated.

2.1 Partitioning IR

To expose partitioning decisions, we represent tensor programs in PartIR, an MLIR “dialect” layered on top of MHLO, an MLIR encoding of the XLA High Level Operations (HLO). At the core of our IR, which operates on statically shaped multi-dimensional arrays, are tiling loop-like operators which express parallel computations that compute a distributed tensor value. Instead of allowing unrestricted parallel loops, we force users to declare logical mesh axes with fixed sizes, and we ensure that every such loop in a program is associated with an axis and that same-axis loops never occur nested. This guarantees that our programs can compile as a single SPMD kernel. An example of a mesh is given on the left of Figure 1, requiring a total of 8 devices for execution.
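The structural constraint above (every tiling loop bound to a declared mesh axis; same-axis loops never nested) can be checked with a simple traversal. Below is a minimal pure-Python sketch of such a check; the nested-tuple program representation and the `check` helper are our own illustration, not PartIR syntax:

```python
# Sketch: verify that tiling loops reference declared mesh axes and that
# loops over the same axis are never nested -- the property that lets the
# program compile as a single SPMD kernel.
MESH = {"batch": 2, "model": 4}  # logical axes with fixed sizes

def check(node, open_axes=frozenset()):
    """node is ('tile', axis, body) or ('op', name); returns True if legal."""
    if node[0] == "tile":
        _, axis, body = node
        if axis not in MESH:
            return False  # undeclared axis
        if axis in open_axes:
            return False  # same-axis loops must not nest
        return all(check(child, open_axes | {axis}) for child in body)
    return True  # plain ops are always fine

legal = ("tile", "batch", [("tile", "model", [("op", "dot")])])
illegal = ("tile", "model", [("tile", "model", [("op", "dot")])])
```

Nesting distinct axes (as in `legal`) is fine, because each axis corresponds to a different dimension of the device mesh; re-entering the same axis (as in `illegal`) would break the single-kernel SPMD guarantee.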

Rewriting actions include actions to express the distribution of an intermediate variable as well as several flavors of propagation of partitioning information (i) from operands to results; (ii) from results to operands, and (iii) from a subset of operands to the rest. These propagation rules are enabled by a registry containing a declarative specification of this behaviour for each operator in the underlying tensor dialect. Rewrites always preserve semantics, decoupling search policies from correctness.
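The operand-to-result direction of such a registry can be illustrated with a small sketch. The rule encoding below is hypothetical (not the actual PartIR registry format); a sharding is an (axis, dim) pair or None for replicated:

```python
# Sketch of operand->result sharding propagation driven by a per-op registry.
def pointwise_rule(operands):
    # A pointwise op can adopt a sharding only if every operand already
    # agrees on it (equi-partitioned arguments).
    return operands[0] if len(set(operands)) == 1 else "stuck"

RULES = {"add": pointwise_rule, "mul": pointwise_rule}

def propagate(op, operand_shardings):
    """Return the result sharding, or 'stuck' if insufficient information."""
    rule = RULES.get(op)
    return rule(operand_shardings) if rule else "stuck"
```

For example, `propagate("add", [("model", 1), ("model", 1)])` yields `("model", 1)`, while `propagate("add", [("model", 1), None])` is stuck; in our system such stuck nodes resurface on the partitioner's worklist as decisions still to be made.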

Figure 2 illustrates a small MHLO program representing a dense layer, how tiling decisions are expressed in our IR, and the result of propagation. Finally, the tiling loops in our IR lower to a dialect suitable for expressing SPMD computations – Figure 3 shows the result of lowering the final program in Figure 2. Optimising data transfers and reasoning about cost happens at this level of the stack, before we eventually compile back to accelerator-specific HLO code and feed back into the XLA compiler/runtime. A detailed technical exposition of our IR stack is the subject of a different paper.

func @main(%arg0: f32[8,16], %arg1: f32[16,64{"shard"}]>, %arg2: f32[64{"shard"}]>)
    -> f32[8,64{"shard"}]>  attributes {mesh_shape = #partir.mesh<"shard"=2>} {
    %0 = partir.spmd(%arg1, %arg0, %arg2) ["shard"] (%rshard: !partir.range<2>,
                                                     %arg3: tensor<16x32xf32>,
                                                     %arg4: tensor<8x16xf32>,
                                                     %arg5: tensor<32xf32>) {
      %2 = %arg4, %arg3 : tensor<8x32xf32>
      %3 = mhlo.broadcast_in_dim %arg5 { broadcast_dims = 1 } : tensor<8x32xf32>
      %4 = mhlo.add %2, %3 : tensor<8x32xf32>
      partir.yield %4 : tensor<8x32xf32>
    return %0
Figure 3: An SPMD program after partitioning using axis “shard”. A distributed tensor type like f32[16, 64{"shard"}] means that the value has a global shape of [16, 64], but is nevertheless sharded in chunks of shape [16, 32] (since here axis "shard" is of size 2). 

2.2 Automated partitioner

The partitioner interacts with a rewriting environment exposing rewriting tactics to distribute/partition values, and tactics involving patterns that apply globally throughout the module. Prior work optgraphcomp, placeto has highlighted the cost of auto-regressive rewriting which scales unfavourably with model size. Our approach is guided by minimising the number of rewriting decisions. To improve robustness across a wide range of programs, we propose to combine search and learning based on several observations:

  • Users often do not need to solve heterogeneous partitioning problems but typically map programs to rectangular device meshes such as TPU slices. Faster automation is enabled by restricting partitioning to using pre-defined mesh axes as the structure of tiling loops and their allowable nesting is fixed ahead of time.

  • When developing a partitioning strategy, users can often assign some decisions themselves based on knowledge of model and devices, such as selecting a data-parallel axis. This allows the partitioner to focus only on the difficult decisions, such as the model-parallel strategy.

  • Experts do not approach partitioning by investigating individual operations but consider key structural elements such as parameters, certain activations or inputs, optimiser/network state etc. to formulate a high level strategy. They often then hand-craft annotations for a handful of internal nodes that they deem important to guide tools like the XLA GSPMD partitioner.

We leverage all three observations to design our automated partitioner. For a physical set of devices (e.g. 8), users explicitly specify the logical axes for different forms of parallelism, concretely by providing a set of axis names and sizes (e.g. (“batch”, 2), (“model”, 4)). The partitioner only searches over tiling decisions involving axes which it is explicitly instructed to use, while users remain in control of the others.
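Such an axis declaration can be sketched as a list of (name, size) pairs whose product must match the physical device count. The `make_mesh` helper below is hypothetical, not the automap API:

```python
def make_mesh(axes, num_devices):
    """axes: list of (name, size) pairs; sizes must multiply to num_devices."""
    total = 1
    for _, size in axes:
        total *= size
    if total != num_devices:
        raise ValueError(f"mesh {axes} needs {total} devices, have {num_devices}")
    return dict(axes)

# A (batch=2, model=4) mesh over 8 physical devices.
mesh = make_mesh([("batch", 2), ("model", 4)], num_devices=8)
```

The partitioner would then only be allowed to insert tiling loops over the axes it is handed (e.g. just "model"), leaving "batch" under user control.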

def update(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = jax.lax.psum(grads, "batch")
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Update calls now execute in SPMD fashion.
update = jax.pmap(update, axis_name="batch")
Figure 4: Using JAX’s pmap (parallel map) function transform enables the use of multi-host data parallelism over an axis "batch".
def update(params, opt_state, batch):
    ...  # Same as in Fig. 4.

# Specify mapping of axes on devices.
device_layout = np.reshape(np.array(jax.devices()), (2, 4))
mesh = Mesh(device_layout, ("batch", "model"))

# Manual data-, automated model parallelism.
# Axis "batch" is specified for batch, the third input argument.
update, spec = automap(update, mesh, ["model"], (None, None, 0))(*args)
Figure 5: Automap allows combining manual and automated parallelism. Users specify a mesh layout and where manual axes apply.

2.3 Search and learning

The partitioner must select a sequence of rewrites and propagation tactics to optimise some desirable cost function such as the execution cost. Large models frequently reach 50-100k HLO operations, and in our experience even the set of interesting nodes can be impractically large for search. Instead of opting for a fully learned solution, which may require fine-tuning on unseen programs, we first experiment with a hybrid approach: a learner narrows down the most relevant subset of program nodes for partitioning, and search selects the final decisions.

Search. We implemented Monte Carlo Tree Search (MCTS) Browne12asurvey with the upper confidence bound for trees (UCT). Instead of exposing all program operations for assigning mesh axes, we initialise a worklist of ‘interesting operation nodes’ when traversing the program, i.e. the function arguments of the MLIR representation, consisting of weights and biases, optimiser state, and model inputs. The action space exposes actions to insert tiling loops that partition each tensor and dimension by each predefined axis. After applying an action, its consequences are propagated conservatively backward and forward through the program; for instance, if a pointwise operator receives equi-partitioned arguments it may also be executed to produce an equi-partitioned result. Propagation can get stuck at internal nodes for which insufficient information exists (e.g. not enough arguments are partitioned); these internal nodes, which require non-trivial decisions, resurface back onto our worklist. This is a key difference compared to the heuristics-based sharding propagation underlying XLA GSPMD gshard. Another global rewriting decision we expose is a pass that infers the tiling of the remaining arguments from only some of them. This pass allows a few tiling decisions on some parameters/inputs of a model to induce sharding for the other parameters/inputs. If applied at every step, it quickly reduces the number of remaining decisions at the cost of increased wall-clock time, and we are experimenting with different mechanics to expose this to an agent.
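The UCT rule used to select among candidate rewrites balances the average observed reward of an action against an exploration bonus for rarely tried actions. A minimal sketch, with illustrative constants and bookkeeping (not our actual implementation):

```python
import math

def uct_choice(stats, total_visits, c=1.4):
    """stats: {action: (visit_count, total_reward)}; unvisited actions win."""
    best, best_score = None, float("-inf")
    for action, (n, reward) in stats.items():
        if n == 0:
            return action  # always expand unvisited actions first
        # mean reward + exploration bonus shrinking with visit count
        score = reward / n + c * math.sqrt(math.log(total_visits) / n)
        if score > best_score:
            best, best_score = action, score
    return best

# Two rewrite actions: one well-explored and good, one barely explored.
stats = {"shard_dim0": (10, 8.0), "shard_dim1": (1, 0.5)}
```

With these numbers, the exploration bonus dominates and the under-visited `shard_dim1` is selected, even though `shard_dim0` has the higher mean reward so far.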

Learning. Using the worklist still exposes too many nodes for search, but careful analysis of the state-of-the-art sharding of transformers megatron2019 shows that in certain situations only a handful need to be selected to fully partition a model, so we apply learning to rank them. Our compiler featurises operation nodes as a concatenation of operation type, operand shapes, and existing partitioned axes. Edges encode program dataflow and MLIR program structure. We then compute a per-node relevance score using node embeddings: a learned model predicts, for each input to the MLIR program, a ranking corresponding to the importance of partitioning this node, and the top-k (k = 25) most relevant nodes are then passed to MCTS to select the final rewriting sequence. In summary, we combine inductive propagation tactics, search, and learning to deliver automated model parallelism with a relatively small number of explicit decisions.
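The filtering step amounts to scoring each worklist node and keeping the k best. A sketch, where the scoring function is a crude stand-in for the learned ranking model (real scores would come from the node embeddings described above):

```python
def top_k_nodes(nodes, score_fn, k=25):
    """Rank worklist nodes by predicted relevance and keep the k best."""
    return sorted(nodes, key=score_fn, reverse=True)[:k]

# Stand-in scorer: prefer large parameter tensors over small biases.
def fake_score(node):
    _, shape = node
    size = 1
    for dim in shape:
        size *= dim
    return size

# Hypothetical worklist entries: (argument name, shape).
worklist = [("w_qkv", (1024, 3072)), ("bias", (3072,)), ("w_out", (1024, 1024))]
top = top_k_nodes(worklist, fake_score, k=2)
```

Only the surviving nodes are exposed to MCTS, which keeps the search tree small regardless of how many arguments the full program has.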

End-to-end user example. We illustrate how end-users interact with automap in Figure 5, in comparison to using an existing JAX parallel primitive in Figure 4. Automap is instrumented using existing JAX tooling for describing positional axes and meshes (based on xmap/pjit). In addition to a partitioned callable, automap returns a specification of partitioning decisions for inputs and outputs. These specifications can then be used to partition function inputs such as parameters, optimiser state, or network state.

3 Towards an evaluation

Results. We investigate the performance of our prototype on a transformer model and compare to a well-known reference strategy. We implemented a GPT-3 gpt3_2020 style 24-layer transformer model which requires 26 GB of memory at batch size 1 (and so does not fit on a single TPU v3 device with 16 GB of RAM), and which has just over 50k operations and 1150 arguments. We then evaluated our prototype’s ability to discover Megatron-style megatron2019 sharding through search, and when combined with a learner. The search mechanism is guided by multiple cost statistics. First, a peak liveness analysis exposes an approximate memory estimate. This is a conservative estimate, and XLA compilation can further improve required memory through optimisations such as fusion. Second, we minimise the number of bytes communicated through reduction operations. The learned model was trained on a dataset of 20k transformer variants. To generate training data, we selected random model arguments (1000 per model), and exhaustively partitioned all argument dimensions. Our model was trained to imitate the highest-scoring strategy. The model was implemented based on an Interaction Network DBLP:journals/corr/BattagliaPLRK16 using JAX with Haiku, Jraph for GraphNets, and Optax for training jax2018github, deepmind2020jax, jraph2020github, haiku2020github.
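A peak-liveness memory estimate of the kind described above can be sketched over a topologically ordered program: add each result's size when it is produced, and free an operand after its last use. The program encoding and sizes below are illustrative, not our compiler's representation:

```python
def peak_memory(ops):
    """ops: list of (name, size_bytes, operand_names) in topological order.
    Returns a conservative peak of live bytes (values freed after last use)."""
    last_use = {}
    for i, (_, _, operands) in enumerate(ops):
        for o in operands:
            last_use[o] = i
    live, peak = {}, 0
    for i, (name, size, operands) in enumerate(ops):
        live[name] = size
        peak = max(peak, sum(live.values()))
        for o in operands:
            if last_use.get(o) == i and o in live:
                del live[o]  # operand dead after this op
    return peak

# a and b are both live when c is computed, giving the peak.
prog = [("a", 100, []), ("b", 100, []), ("c", 50, ["a", "b"]), ("d", 50, ["c"])]
```

The estimate is conservative in the same sense as in the text: a real backend can lower the true peak further, e.g. by fusing `c` and `d` so the intermediate never materialises.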

Megatron is a highly scalable large_megatron_21 training strategy for transformers DBLP:journals/corr/VaswaniSPUJGKP17 which exploits intra-layer model parallelism to minimise the number of required all-reduces. We view Megatron as representative of a widely used expert strategy. Whether Megatron sharding has been achieved is measured by gathering statistics on the collectives in the partitioned model.

Figure 6: Comparing search progress using MCTS only and combined with a learned filter to discover Megatron-style sharding.
Figure 7: TPU v3 runtimes of the solutions found. Near Megatron solutions only incur a small performance penalty.

In Figure 6 we illustrate the success rate in discovering Megatron (over 50 search attempts) for a number of search budgets. Results show that several thousands of episodes are required to reliably discover expert-level sharding. We then evaluated the search result for each run on TPU v3 (Figure 7). A key insight is that our search at shorter budgets frequently discovers solutions near Megatron (i.e. with few redundant collectives) which are in practice almost as fast, as highlighted by the runtimes of solutions using the learned filter, which are near Megatron from 500 episodes onwards, i.e. requiring a few minutes of search. Solutions typically required 2-20 decisions. While these results are encouraging, more work is needed before a learned system can support interactive compiler workflows and handle a variety of (generally unpredictable) user programs.

Scaling with compiler hints. Our results show that discovering semantically meaningful strategies is possible in principle. However, we found that discovering these strategies for larger models critically relies on propagating sharding information through subtly shared constants and other computations across layers. Such sharing is brittle and cannot be relied on for a general solution towards rewriting deeper networks. As machine learning models commonly consist of repeated blocks (such as attention blocks in Transformers, residual blocks in ConvNets, or message-passing layers in GraphNets), search techniques scale unfavourably when having to explicitly rewrite each layer. We therefore also implemented the ability to exploit model structure in automap by allowing users to group repeated layers together, exposing only a single set of decisions per group. The mechanism used was that of named scopes, which are commonly used in libraries such as Haiku deepmind2020jax. Figures 8 and 9 present the effect of grouping.

Figure 8: Searching Transformer sharding strategies with grouped attention blocks via compiler hints drastically improves results.
Figure 9: Impact of grouping when not relying on propagation of sharding information via shared constants across layers.

Notably, when allowing for compiler hints on layer groups, Megatron can be found reliably in a small number of episodes without requiring propagation through shared dependencies across layers. Without grouping or shared-dependency propagation, Megatron is not found for a 24-layer Transformer. As grouping only requires users to provide the name scope for any relevant group (e.g. "attention-block"), this provides an attractive path for initial real-world use cases.
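Grouping by name scope can be sketched as collapsing the worklist by scope prefix, so that one decision key covers the matching parameter in every instance of a repeated block. The scope naming and key scheme below are illustrative, not the automap implementation:

```python
def group_by_scope(worklist, group_prefixes):
    """Map each node name to a decision key: nodes under a grouped scope share
    one key per parameter (one set of decisions replayed for all instances)."""
    decisions = {}
    for name in worklist:
        key = name  # ungrouped nodes keep their own key
        for prefix in group_prefixes:
            if prefix in name:
                # drop the per-layer path so all instances share a key
                key = (prefix, name.split("/")[-1])
                break
        decisions.setdefault(key, []).append(name)
    return decisions

nodes = ["layer_0/attention-block/w_qkv",
         "layer_1/attention-block/w_qkv",
         "layer_0/mlp/w1"]
groups = group_by_scope(nodes, ["attention-block"])
```

Here both attention-block copies of `w_qkv` collapse into one decision, while the ungrouped MLP weight keeps its own, which is why the number of search decisions stops growing with depth.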

Other models. We were also able to partition other models such as GraphNets where no one-size-fits-all expert strategy exists. Here, the automap prototype in first experiments was able to discover simple manual strategies such as input edge sharding that allow practitioners to begin experimentation with larger graphs and models.

Discussion. Our results illustrate how effective partitioning strategies can be reached with a rewriting system through the combination of data-driven strategies and inductive tactics, optionally exploiting high-level model structure. More work is required to understand the right combination of user-provided structural information, search, and learning in order to balance minimal user effort against time to good solutions. The results presented here were also initially restricted to sharding within the devices of a single host while assuming data parallelism across hosts, which simplifies the communication cost model. More advanced cost models will be required to capture multi-host communication as well as downstream changes to models during device-specific lowering (e.g. through fusion). Finally, further work is also needed to support additional automated partitioning strategies such as pipeline parallelism or ZeRO offloading.