1 Introduction
Scale matters. We are in an unprecedented era in AI research history in which the increasing data and model scale is rapidly improving accuracy in computer vision [22, 41, 34, 35, 36, 16], speech [17, 40], and natural language processing [7, 38]. Take the profound impact in computer vision as an example: visual representations learned by deep convolutional neural networks [23, 22] show excellent performance on previously challenging tasks like ImageNet classification [33] and can be transferred to difficult perception problems such as object detection and segmentation [8, 10, 28]. Moreover, this pattern generalizes: larger datasets and neural network architectures consistently yield improved accuracy across all tasks that benefit from pre-training [22, 41, 34, 35, 36, 16]. But as model and data scale grow, so does training time; discovering the potential and limits of large-scale deep learning requires developing novel techniques to keep training time manageable.
The goal of this report is to demonstrate the feasibility of, and to communicate a practical guide to, large-scale training with distributed synchronous stochastic gradient descent (SGD). As an example, we scale ResNet-50 [16] training, originally performed with a minibatch size of 256 images (using 8 Tesla P100 GPUs, training time is 29 hours), to larger minibatches (see Figure 1). In particular, we show that with a large minibatch size of 8192, we can train ResNet-50 in 1 hour using 256 GPUs while maintaining the same level of accuracy as the 256 minibatch baseline. While distributed synchronous SGD is now commonplace, no existing results show that generalization accuracy can be maintained with minibatches as large as 8192 or that such high-accuracy models can be trained in such a short time.
To tackle this unusually large minibatch size, we employ a simple and hyper-parameter-free linear scaling rule to adjust the learning rate. While this guideline is found in earlier work [21, 4], its empirical limits are not well understood and, informally, we have found that it is not widely known to the research community. To successfully apply this rule, we present a new warmup strategy, i.e., a strategy of using lower learning rates at the start of training [16], to overcome early optimization difficulties. Importantly, not only does our approach match the baseline validation error, but it also yields training error curves that closely match the small minibatch baseline. Details are presented in §2.
Our comprehensive experiments in §5 show that optimization difficulty is the main issue with large minibatches, rather than poor generalization (at least on ImageNet), in contrast to some recent studies [20]. Additionally, we show that the linear scaling rule and warmup generalize to more complex tasks including object detection and instance segmentation [9, 31, 14, 28], which we demonstrate via the recently developed Mask R-CNN [14]. We note that a robust and successful guideline for addressing a wide range of minibatch sizes has not been presented in previous work.
While the strategy we deliver is simple, its successful application requires correct implementation with respect to seemingly minor and often not well understood implementation details within deep learning libraries. Subtleties in the implementation of SGD can lead to incorrect solutions that are difficult to discover. To provide more helpful guidance we describe common pitfalls and the relevant implementation details that can trigger these traps in §3.
Our strategy applies regardless of framework, but achieving efficient linear scaling requires nontrivial communication algorithms. We use the open-source Caffe2¹ deep learning framework and Big Basin GPU servers [24]; this setup operates efficiently using standard Ethernet networking (as opposed to specialized network interfaces). We describe the systems algorithms that enable our approach to operate near its full potential in §4.
¹ http://www.caffe2.ai
The practical advances described in this report are helpful across a range of domains. In an industrial domain, our system unleashes the potential of training visual models from internet-scale data, enabling training with billions of images per day. Of equal importance, in a research domain, we have found it to simplify migrating algorithms from a single-GPU to a multi-GPU implementation without requiring hyper-parameter search, e.g., in our experience migrating Faster R-CNN [31] and ResNets [16] from 1 to 8 GPUs.
2 Large Minibatch SGD
We start by reviewing the formulation of Stochastic Gradient Descent (SGD), which will be the foundation of our discussions in the following sections. We consider supervised learning by minimizing a loss $L(w)$ of the form:
$$L(w) = \frac{1}{|X|}\sum_{x \in X} l(x, w). \qquad (1)$$
Here $w$ are the weights of a network, $X$ is a labeled training set, and $l(x, w)$ is the loss computed from samples $x \in X$ and their labels $y$. Typically $l$ is the sum of a classification loss (e.g., cross-entropy) and a regularization loss on $w$.
Minibatch Stochastic Gradient Descent [32], usually referred to simply as SGD in recent literature even though it operates on minibatches, performs the following update:
$$w_{t+1} = w_t - \eta\frac{1}{n}\sum_{x \in \mathcal{B}} \nabla l(x, w_t). \qquad (2)$$
Here $\mathcal{B}$ is a minibatch sampled from $X$ and $n = |\mathcal{B}|$ is the minibatch size, $\eta$ is the learning rate, and $t$ is the iteration index. Note that in practice we use momentum SGD; we return to a discussion of momentum in §3.
2.1 Learning Rates for Large Minibatches
Our goal is to use large minibatches in place of small minibatches while maintaining training and generalization accuracy. This is of particular interest in distributed learning, because it can allow us to scale to multiple workers² using simple data parallelism without reducing the per-worker workload and without sacrificing model accuracy.
² We use the terms ‘worker’ and ‘GPU’ interchangeably in this work, although other implementations of a ‘worker’ are possible. ‘Server’ denotes a set of 8 GPUs that does not require communication over a network.
As we will show in comprehensive experiments, the following learning rate scaling rule is surprisingly effective for a broad range of minibatch sizes:
Linear Scaling Rule: When the minibatch size is multiplied by $k$, multiply the learning rate by $k$.
All other hyper-parameters (weight decay, etc.) are kept unchanged. As we will show in §5, the linear scaling rule can help us not only to match the accuracy between small and large minibatches but, equally importantly, to largely match their training curves, which enables rapid debugging and comparison of experiments prior to convergence.
Interpretation.
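We present an informal discussion of the linear scaling rule and why it may be effective. Consider a network at iteration $t$ with weights $w_t$, and a sequence of $k$ minibatches $\mathcal{B}_j$ for $0 \le j < k$, each of size $n$. We compare the effect of executing $k$ SGD iterations with small minibatches $\mathcal{B}_j$ and learning rate $\eta$ versus a single iteration with a large minibatch $\cup_j \mathcal{B}_j$ of size $kn$ and learning rate $\hat{\eta}$.
According to (2), after $k$ iterations of SGD with learning rate $\eta$ and a minibatch size of $n$ we have:
$$w_{t+k} = w_t - \eta\frac{1}{n}\sum_{j<k}\sum_{x \in \mathcal{B}_j} \nabla l(x, w_{t+j}). \qquad (3)$$
On the other hand, taking a single step with the large minibatch $\cup_j \mathcal{B}_j$ of size $kn$ and learning rate $\hat{\eta}$ yields:
$$\hat{w}_{t+1} = w_t - \hat{\eta}\frac{1}{kn}\sum_{j<k}\sum_{x \in \mathcal{B}_j} \nabla l(x, w_t). \qquad (4)$$
As expected, the updates differ, and it is unlikely that $\hat{w}_{t+1} = w_{t+k}$. However, if we could assume $\nabla l(x, w_t) \approx \nabla l(x, w_{t+j})$ for $j < k$, then setting $\hat{\eta} = k\eta$ would yield $\hat{w}_{t+1} \approx w_{t+k}$, and the updates from small and large minibatch SGD would be similar. Although this is a strong assumption, we emphasize that if it were true the two updates are similar only if we set $\hat{\eta} = k\eta$.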
The above interpretation gives intuition for one case where we may hope for the linear scaling rule to apply. In our experiments with $\hat{\eta} = k\eta$ (and warmup), small and large minibatch SGD not only result in models with the same final accuracy, but their training curves also match closely. Our empirical results suggest that the above approximation might be valid for large-scale, real-world data.
However, there are at least two cases when the condition $\nabla l(x, w_t) \approx \nabla l(x, w_{t+j})$ will clearly not hold. First, in initial training when the network is changing rapidly, it does not hold. We address this by using a warmup phase, discussed in §2.2. Second, minibatch size cannot be scaled indefinitely: while results are stable for a large range of sizes, beyond a certain point accuracy degrades rapidly. Interestingly, this point is as large as ∼8k in ImageNet experiments.
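The approximation behind (3) and (4) can be checked numerically on a toy problem. The sketch below illustrates the argument under our own assumptions and is not part of the training setup. It uses a hypothetical one-dimensional quadratic per-sample loss $l(x, w) = \frac{1}{2}(w - x)^2$ (so $\nabla l(x, w) = w - x$) and made-up data split into $k$ minibatches of size $n$. It compares $k$ small steps with learning rate $\eta$ against one large step with $k\eta$. When $\eta$ is small, the weights barely move over the $k$ steps and the two results nearly coincide. A large $\eta$ breaks $\nabla l(x, w_t) \approx \nabla l(x, w_{t+j})$, and the two results come apart.

```python
def grad(w, batch):
    """Mean gradient of the toy loss 0.5 * (w - x)^2 over a minibatch."""
    return sum(w - x for x in batch) / len(batch)


def k_small_steps(w, minibatches, lr):
    """Eq. (3): k sequential SGD steps, each on its own minibatch of size n."""
    for batch in minibatches:
        w = w - lr * grad(w, batch)
    return w


def one_large_step(w, minibatches, lr):
    """Eq. (4): a single SGD step on the union of the k minibatches (size kn)."""
    union = [x for batch in minibatches for x in batch]
    return w - lr * grad(w, union)


def relative_gap(eta, k=8, n=4, w0=10.0):
    """|w_small - w_large| relative to the size of the large-minibatch update."""
    data = [float(i % 7) for i in range(k * n)]  # made-up toy samples
    minibatches = [data[j * n:(j + 1) * n] for j in range(k)]
    w_small = k_small_steps(w0, minibatches, eta)
    w_large = one_large_step(w0, minibatches, k * eta)  # linear scaling rule
    return abs(w_small - w_large) / abs(w_large - w0)


print(relative_gap(0.001))  # small eta: the two updates nearly coincide
print(relative_gap(0.2))    # large eta: the approximation breaks down
```

The quadratic loss was chosen only because its gradient is linear in $w$, which makes the drift of $w_{t+j}$ across the $k$ small steps easy to see. It says nothing about how close the approximation is for a deep network.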
Discussion.
The above linear scaling rule was adopted by Krizhevsky [21], if not earlier. However, Krizhevsky reported a 1% increase of error when increasing the minibatch size from 128 to 1024, whereas we show how to maintain accuracy across a much broader regime of minibatch sizes. Chen et al. [5] presented a comparison of numerous distributed SGD variants, and although their work also employed the linear scaling rule, it did not establish a small minibatch baseline. Li [25] (§4.6) showed distributed ImageNet training with minibatches up to 5120 without a loss in accuracy after convergence. However, their work did not demonstrate a hyper-parameter-search-free rule for adjusting the learning rate as a function of minibatch size, which is a central contribution of our work.
In recent work, Bottou et al. [4] (§4.2) review theoretical tradeoffs of minibatching and show that with the linear scaling rule, solvers follow the same training curve as a function of number of examples seen, and suggest the learning rate should not exceed a maximum rate independent of minibatch size (which justifies warmup). Our work empirically tests these theories with unprecedented minibatch sizes.
2.2 Warmup
As we discussed, for large minibatches (e.g., 8k) the linear scaling rule breaks down when the network is changing rapidly, which commonly occurs in early stages of training. We find that this issue can be alleviated by a properly designed warmup [16], namely, a strategy of using less aggressive learning rates at the start of training.
Constant warmup.
The warmup strategy presented in [16] uses a low constant learning rate for the first few epochs of training. As we will show in §5, we have found constant warmup particularly helpful for prototyping object detection and segmentation methods [9, 31, 26, 14] that fine-tune pre-trained layers together with newly initialized layers.
In our ImageNet experiments with a large minibatch of size $kn$, we have tried to train with the low learning rate of $\eta$ for the first 5 epochs and then return to the target learning rate of $\hat{\eta} = k\eta$. However, given a large $k$, we find that this constant warmup is not sufficient to solve the optimization problem, and a transition out of the low learning rate warmup phase can cause the training error to spike. This leads us to propose the following gradual warmup.
Gradual warmup.
We present an alternative warmup that gradually ramps up the learning rate from a small to a large value. This ramp avoids a sudden increase of the learning rate, allowing healthy convergence at the start of training. In practice, with a large minibatch of size $kn$, we start from a learning rate of $\eta$ and increment it by a constant amount at each iteration such that it reaches $\hat{\eta} = k\eta$ after 5 epochs (results are robust to the exact duration of warmup). After the warmup, we go back to the original learning rate schedule.
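Here is a minimal sketch of the gradual warmup schedule. It assumes the learning rate is updated once per iteration and that the warmup length is given as epochs times a hypothetical `iters_per_epoch`. For $kn = 8$k on the 1.28 million-image ImageNet training set, that is roughly 156 iterations per epoch. The function returns $\eta$ at iteration 0 and grows by a constant increment per iteration until it reaches $k\eta$ after 5 epochs. After that it simply returns $k\eta$, as a placeholder for the original schedule that training resumes.

```python
def gradual_warmup_lr(iteration, base_lr, k, iters_per_epoch, warmup_epochs=5):
    """Learning rate at `iteration` under gradual warmup.

    Starts at base_lr (eta) and increases linearly, by the same amount every
    iteration, until it reaches k * base_lr (eta-hat = k * eta) after
    `warmup_epochs` epochs. Afterwards the regular schedule takes over; here we
    just return the target rate as a placeholder for it.
    """
    target_lr = k * base_lr
    warmup_iters = warmup_epochs * iters_per_epoch
    if iteration >= warmup_iters:
        return target_lr
    step = (target_lr - base_lr) / warmup_iters  # constant per-iteration increment
    return base_lr + step * iteration


# kn = 8k is k = 32 times the 256-image baseline with eta = 0.1.
for it in (0, 390, 780):
    print(it, gradual_warmup_lr(it, base_lr=0.1, k=32, iters_per_epoch=156))
```

The per-iteration (rather than per-epoch) increment is what avoids the jump that the constant warmup produces at the end of its low-rate phase.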
2.3 Batch Normalization with Large Minibatches
Batch Normalization (BN) [19] computes statistics along the minibatch dimension: this breaks the independence of each sample’s loss, and changes in minibatch size change the underlying definition of the loss function being optimized. In the following we will show that a commonly used ‘shortcut’, which may appear to be a practical consideration to avoid communication overhead, is actually necessary for preserving the loss function when changing minibatch size.
We note that (1) and (2) assume the per-sample loss $l(x, w)$ is independent of all other samples. This is not the case when BN is performed and activations are computed across samples. We write $l_{\mathcal{B}}(x, w)$ to denote that the loss of a single sample $x$ depends on the statistics of all samples in its minibatch $\mathcal{B}$. We denote the loss over a single minibatch of size $n$ as $L(\mathcal{B}, w) = \frac{1}{n}\sum_{x \in \mathcal{B}} l_{\mathcal{B}}(x, w)$. With BN, the training set can be thought of as containing all distinct subsets of size $n$ drawn from the original training set $X$, which we denote as $X^n$. The training loss $L(w)$ then becomes:
$$L(w) = \frac{1}{|X^n|}\sum_{\mathcal{B} \in X^n} L(\mathcal{B}, w). \qquad (5)$$
If we view $\mathcal{B}$ as a ‘single sample’ in $X^n$, then the loss of each single sample $\mathcal{B}$ in $X^n$ is computed independently.
Note that the minibatch size $n$ over which the BN statistics are computed is a key component of the loss: if the per-worker minibatch sample size $n$ is changed, it changes the underlying loss function $L$ that is optimized. More specifically, the mean/variance statistics computed by BN with different $n$ exhibit different levels of random variation.
In the case of distributed (and multi-GPU) training, if the per-worker sample size $n$ is kept fixed and the total minibatch size is $kn$, it can be viewed as a minibatch of $k$ samples with each sample $\mathcal{B}_j$ independently selected from $X^n$, so the underlying loss function is unchanged and is still defined in $X^n$. Under this point of view, in the BN setting after seeing $k$ minibatches $\mathcal{B}_j$, (3) and (4) become:
$$w_{t+k} = w_t - \eta\sum_{j<k} \nabla L(\mathcal{B}_j, w_{t+j}), \qquad (6)$$
$$\hat{w}_{t+1} = w_t - \hat{\eta}\frac{1}{k}\sum_{j<k} \nabla L(\mathcal{B}_j, w_t). \qquad (7)$$
Following similar logic as in §2.1, we set $\hat{\eta} = k\eta$ and we keep the per-worker sample size $n$ constant when we change the number of workers $k$.
In this work, we use $n = 32$, which has performed well for a wide range of datasets and networks [19, 16]. If $n$ is adjusted, it should be viewed as a hyper-parameter of BN, not of distributed training. We also note that the BN statistics should not be computed across all workers, not only for the sake of reducing communication, but also for maintaining the same underlying loss function being optimized.
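The reason for keeping BN statistics per-worker can be shown with a plain-Python sketch. The helper names are hypothetical, and the BN transform here leaves out the learnable $\gamma$ and $\beta$. Each worker normalizes its own $n$ samples, so the output for a given worker's samples does not depend on how many other workers ($k$) exist. Computing the statistics over all $kn$ samples would change that output, and with it the loss being optimized.

```python
import math


def batch_norm(values, eps=1e-5):
    """Normalize one minibatch with its own mean/variance (no gamma/beta)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def per_worker_batch_norm(global_batch, n):
    """Split a kn-sample batch into k worker shards of n and normalize each shard."""
    assert len(global_batch) % n == 0, "global batch must be k * n samples"
    out = []
    for start in range(0, len(global_batch), n):
        out.extend(batch_norm(global_batch[start:start + n]))
    return out


data = [float((i * 37) % 101) for i in range(128)]  # made-up activations
n = 32
with_2_workers = per_worker_batch_norm(data[:64], n)   # k = 2
with_4_workers = per_worker_batch_norm(data, n)        # k = 4
global_stats = batch_norm(data)                        # statistics over all kn

# Worker 0's normalized values do not depend on k ...
print(with_2_workers[:n] == with_4_workers[:n])
# ... whereas statistics shared across workers would change them.
print(with_4_workers[:n] == global_stats[:n])
```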
3 Subtleties and Pitfalls of Distributed SGD
In practice a distributed implementation has many subtleties. Many common implementation errors change the definitions of hyper-parameters, leading to models that train but whose error may be higher than expected, and such issues can be difficult to discover. While the remarks below are straightforward, they are important to consider explicitly to faithfully implement the underlying solver.
Weight decay.
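Weight decay is actually the outcome of the gradient of an L2-regularization term in the loss function. More formally, the per-sample loss in (1) can be written as $l(x, w) = \frac{\lambda}{2}\|w\|^2 + \varepsilon(x, w)$. Here $\frac{\lambda}{2}\|w\|^2$ is the sample-independent L2 regularization on the weights and $\varepsilon(x, w)$ is a sample-dependent term such as the cross-entropy loss. The SGD update in (2) can be written as:
$$w_{t+1} = w_t - \eta\lambda w_t - \eta\frac{1}{n}\sum_{x \in \mathcal{B}} \nabla\varepsilon(x, w_t). \qquad (8)$$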
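Equation (8) makes it easy to check numerically that multiplying the sample-dependent term $\varepsilon$ by a factor $s$ is not the same as multiplying $\eta$ by $s$ once weight decay is present. The sketch below uses a scalar weight and made-up per-sample gradients of $\varepsilon$. The function name and values are ours, chosen only for illustration.

```python
def sgd_step(w, eps_grads, lr, weight_decay, eps_scale=1.0):
    """One update of Eq. (8) on a scalar weight.

    eps_grads: per-sample gradients of the sample-dependent term epsilon(x, w).
    eps_scale: optional factor applied to epsilon only (e.g. a rescaled loss).
    """
    n = len(eps_grads)
    return w - lr * weight_decay * w - lr * eps_scale * sum(eps_grads) / n


w, grads, lr, s = 1.0, [0.5, 0.3], 0.1, 32.0

# Without weight decay, scaling the loss by s and scaling eta by s coincide ...
print(sgd_step(w, grads, s * lr, 0.0), sgd_step(w, grads, lr, 0.0, eps_scale=s))
# ... with weight decay they differ: the lambda * w_t term is not rescaled.
print(sgd_step(w, grads, s * lr, 1e-4), sgd_step(w, grads, lr, 1e-4, eps_scale=s))
```

The gap in the second line is exactly $(s - 1)\,\eta\lambda w_t$, the weight decay contribution that rescaling the loss leaves untouched.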
In practice, usually only the sample-dependent term $\sum \nabla\varepsilon(x, w_t)$ is computed by backprop; the term $\lambda w_t$ is computed separately and added to the aggregated gradients contributed by $\varepsilon(x, w_t)$. If there is no weight decay term, there are many equivalent ways of scaling the learning rate, including scaling the term $\varepsilon(x, w_t)$. However, as can be seen from (8), in general this is not the case. We summarize these observations in the following remark:
Remark 1: Scaling the cross-entropy loss is not equivalent to scaling the learning rate.
Momentum correction.
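Momentum SGD is a commonly adopted modification to the vanilla SGD in (2). A reference implementation of momentum SGD has the following form:
$$u_{t+1} = m u_t + \frac{1}{n}\sum_{x \in \mathcal{B}} \nabla l(x, w_t), \qquad w_{t+1} = w_t - \eta u_{t+1}. \qquad (9)$$
Here $m$ is the momentum decay factor and $u$ is the update tensor. A popular variant absorbs the learning rate $\eta$ into the update tensor. Substituting $v_t$ for $\eta u_t$ in (9) yields:
$$v_{t+1} = m v_t + \eta\frac{1}{n}\sum_{x \in \mathcal{B}} \nabla l(x, w_t), \qquad w_{t+1} = w_t - v_{t+1}. \qquad (10)$$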
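A short numerical comparison of the two momentum forms, as a sketch under simplifying assumptions. The gradients are a fixed, made-up sequence rather than recomputed from $w$, which is enough to check the update algebra. The learning rate jumps from 0.1 to 3.2 midway, as it might at the end of a warmup. Rescaling $v$ by $\eta_{t+1}/\eta_t$ preserves $v_t = \eta_t u_t$, so (10) tracks (9). Without that factor the two trajectories separate as soon as $\eta$ changes.

```python
def reference_momentum(grads, lrs, m):
    """Eq. (9): u accumulates raw gradients; eta is applied when updating w."""
    w, u = 0.0, 0.0
    for g, lr in zip(grads, lrs):
        u = m * u + g
        w = w - lr * u
    return w


def absorbed_momentum(grads, lrs, m, correct=True):
    """Eq. (10): v already contains eta. With correct=True, v is rescaled by
    eta_{t+1} / eta_t whenever the learning rate changes (momentum correction)."""
    w, v = 0.0, 0.0
    prev_lr = lrs[0]
    for g, lr in zip(grads, lrs):
        factor = lr / prev_lr if correct else 1.0
        v = m * factor * v + lr * g
        w = w - v
        prev_lr = lr
    return w


grads = [1.0, 0.5, -0.2, 0.3, 0.8, 0.1]     # made-up gradient sequence
lrs = [0.1, 0.1, 0.1, 3.2, 3.2, 3.2]        # learning rate jumps after step 3
print(reference_momentum(grads, lrs, m=0.9))
print(absorbed_momentum(grads, lrs, m=0.9, correct=True))   # matches (9)
print(absorbed_momentum(grads, lrs, m=0.9, correct=False))  # drifts away
```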
For a fixed $\eta$, the two are equivalent. However, we note that while $u$ only depends on the gradients and is independent of $\eta$, $v$ is entangled with $\eta$. When $\eta$ changes, to maintain equivalence with the reference variant in (9), the update for $v$ should be: $v_{t+1} = m\frac{\eta_{t+1}}{\eta_t} v_t + \eta_{t+1}\frac{1}{n}\sum \nabla l(x, w_t)$. We refer to the factor $\frac{\eta_{t+1}}{\eta_t}$ as the momentum correction. We found that this is especially important for stabilizing training when $\eta_{t+1} \gg \eta_t$; otherwise the history term $v_t$ is too small, which leads to instability (for $\eta_{t+1} < \eta_t$ momentum correction is less critical). This leads to our second remark:
Remark 2: Apply momentum correction after changing learning rate if using (10).
Gradient aggregation.
For $k$ workers each with a per-worker minibatch of size $n$, following (4), gradient aggregation must be performed over the entire set of $kn$ examples according to $\frac{1}{kn}\sum_j \sum_{x \in \mathcal{B}_j} l(x, w_t)$. Loss layers are typically implemented to compute an average loss over their local input, which amounts to computing a per-worker loss of $\sum l(x, w_t)/n$. Given this, correct aggregation requires averaging the $k$ gradients in order to recover the missing $1/k$ factor. However, standard communication primitives like allreduce [11] perform summing, not averaging. Therefore, it is more efficient to absorb the $1/k$ scaling into the loss, in which case only the loss’s gradient with respect to its input needs to be scaled, removing the need to scale the entire gradient vector. We summarize this as follows:
Remark 3: Normalize the per-worker loss by total minibatch size $kn$, not per-worker size $n$.
We also note that it may be incorrect to ‘cancel $k$’ by setting $\hat{\eta} = \eta$ (not $k\eta$) and normalizing the loss by $1/n$ (not $1/kn$), which can lead to incorrect weight decay (see Remark 1).
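Here is a sketch of Remark 3 with a hypothetical scalar model. It uses the per-sample loss $l(x, w) = \frac{1}{2}(w - x)^2$ and a toy `allreduce_sum` standing in for a summing allreduce (not a real communication library). If each worker normalizes its loss by $kn$, the summed result equals the gradient of (1) over all $kn$ samples. If each worker normalizes by $n$, the result is $k$ times too large.

```python
def local_grad(w, samples, normalizer):
    """Gradient of sum_x 0.5 * (w - x)^2 / normalizer on one worker."""
    return sum(w - x for x in samples) / normalizer


def allreduce_sum(values):
    """Toy allreduce: every worker ends up with the sum of all contributions."""
    total = sum(values)
    return [total] * len(values)


k, n, w = 4, 8, 1.0
data = [0.25 * i for i in range(k * n)]               # made-up samples
shards = [data[j * n:(j + 1) * n] for j in range(k)]  # one shard per worker

target = sum(w - x for x in data) / (k * n)  # full minibatch gradient
by_kn = allreduce_sum([local_grad(w, s, k * n) for s in shards])
by_n = allreduce_sum([local_grad(w, s, n) for s in shards])
print(target, by_kn[0], by_n[0])  # by_n[0] is k times too large
```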
Data shuffling.
SGD is typically analyzed as a process that samples data randomly with replacement. In practice, common SGD implementations apply random shuffling of the training set during each SGD epoch, which can give better results [3, 13]. To provide fair comparisons with baselines that use shuffling (e.g., [16]), we ensure the samples in one epoch done by $k$ workers are from a single consistent random shuffling of the training set. To achieve this, for each epoch we use a random shuffling that is partitioned into $k$ parts, each of which is processed by one of the $k$ workers. Failing to correctly implement random shuffling in multiple workers may lead to noticeably different behavior, which may contaminate results and conclusions. In summary:
Remark 4: Use a single random shuffling of the training data (per epoch) that is divided amongst all $k$ workers.
4 Communication
In order to scale beyond the 8 GPUs in a single Big Basin server [24], gradient aggregation has to span across servers on a network. To allow for near perfect linear scaling, the aggregation must be performed in parallel with backprop. This is possible because there is no data dependency between gradients across layers. Therefore, as soon as the gradient for a layer is computed, it is aggregated across workers, while gradient computation for the next layer continues (as discussed in [5]). We give full details next.
4.1 Gradient Aggregation
For every gradient, aggregation is done using an allreduce operation (similar to the MPI collective operation MPI_Allreduce [11]). Before allreduce starts, every GPU has its locally computed gradients, and after allreduce completes, every GPU has the sum of all gradients. As the number of parameters grows and compute performance of GPUs increases, it becomes harder to hide the cost of aggregation in the backprop phase. Training techniques to overcome these effects are beyond the scope of this work (e.g., quantized gradients [18], Block-Momentum SGD [6]). However, at the scale of this work, collective communication was not a bottleneck, as we were able to achieve near-linear SGD scaling by using an optimized allreduce implementation.
Our implementation of allreduce consists of three phases for communication within and across servers: (1) buffers from the 8 GPUs within a server are summed into a single buffer for each server, (2) the results buffers are shared and summed across all servers, and finally (3) the results are broadcast onto each GPU. For the local reduction and broadcast in phases (1) and (3) we used NVIDIA Collective Communication Library (NCCL)³ for buffers of size 256 KB or more, and otherwise a simple implementation consisting of a number of GPU-to-host memory copies and a CPU reduction. NCCL uses GPU kernels to accelerate intra-server collectives, so this approach dedicates more time on the GPU to backprop while using the CPU resources that would otherwise have been idle to improve throughput.
³ https://developer.nvidia.com/nccl
For inter-server allreduce, we implemented two of the best algorithms for bandwidth-limited scenarios: the recursive halving and doubling algorithm [30, 37] and the bucket algorithm (also known as the ring algorithm) [2]. For both, each server sends and receives $2\frac{p-1}{p}b$ bytes of data, where $b$ is the buffer size in bytes and $p$ is the number of servers. While the halving/doubling algorithm consists of $2\log_2(p)$ communication steps, the ring algorithm consists of $2(p-1)$ steps. This generally makes the halving/doubling algorithm faster in latency-limited scenarios (i.e., for small buffer sizes and/or large server counts). In practice, we found the halving/doubling algorithm to perform much better than the ring algorithm for buffer sizes up to a million elements (and even higher on large server counts). On 32 servers (256 GPUs), using halving/doubling led to a speedup of 3× over the ring algorithm.
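The step and byte counts above can be written as a small cost model. This sketch restates only the formulas quoted in the text and ignores per-message latency and link bandwidth constants. It assumes $p$ is a power of two so that $\log_2 p$ is an integer.

```python
def allreduce_cost(p, b):
    """Per-server traffic and step counts for an allreduce of b bytes on p servers."""
    assert p >= 1 and (p & (p - 1)) == 0, "p must be a power of two"
    bytes_per_server = 2 * (p - 1) / p * b   # sent (and equally received), both algorithms
    halving_doubling_steps = 2 * (p.bit_length() - 1)  # 2 * log2(p)
    ring_steps = 2 * (p - 1)
    return bytes_per_server, halving_doubling_steps, ring_steps


print(allreduce_cost(32, 100 * 2**20))  # 32 servers, a ~100 MB buffer
```

Both algorithms move the same volume of data, so the step count decides which one wins when per-step latency dominates. For 32 servers that is 10 steps for halving/doubling versus 62 for the ring.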
The halving/doubling algorithm consists of a reduce-scatter collective followed by an all-gather. In the first step of reduce-scatter, servers communicate in pairs (rank 0 with 1, 2 with 3, etc.), sending and receiving for different halves of their input buffers. For example, rank 0 sends the second half of its buffer to 1 and receives the first half of the buffer from 1. A reduction over the received data is performed before proceeding to the next step, where the distance to the destination rank is doubled while the data sent and received is halved. After the reduce-scatter phase is finished, each server has a portion of the final reduced vector.
This is followed by the all-gather phase, which retraces the communication pattern from the reduce-scatter in reverse, this time simply concatenating portions of the final reduced vector. At each server, the portion of the buffer that was being sent in the reduce-scatter is received in the all-gather, and the portion that was being received is now sent.
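The reduce-scatter and all-gather phases can be simulated in one process to check the communication pattern. This is a sketch, not the Gloo implementation. Servers are list entries. Each "message" is a slice copied before any rank applies what it received, to mimic a synchronous exchange. The number of servers must be a power of two; the non-power-of-two binary blocks variant is not covered.

```python
def halving_doubling_allreduce(buffers):
    """Simulate recursive halving/doubling allreduce over p = len(buffers) ranks.

    Returns the per-rank buffers after the allreduce and the number of
    communication steps taken (2 * log2(p)).
    """
    p = len(buffers)
    assert p >= 1 and (p & (p - 1)) == 0, "p must be a power of two"
    bufs = [list(b) for b in buffers]
    seg = [(0, len(bufs[0]))] * p  # [lo, hi) range each rank is responsible for
    steps = 0

    # Reduce-scatter: pair with rank r ^ d, keep one half, send the other half.
    d = 1
    while d < p:
        msgs, new_seg = [], []
        for r in range(p):
            lo, hi = seg[r]
            mid = (lo + hi) // 2
            if (r & d) == 0:  # lower rank of the pair keeps the first half
                keep, send = (lo, mid), (mid, hi)
            else:
                keep, send = (mid, hi), (lo, mid)
            new_seg.append(keep)
            msgs.append((r ^ d, send, bufs[r][send[0]:send[1]]))
        for dst, (a, b), data in msgs:  # sender's sent half == partner's kept half
            for i, x in zip(range(a, b), data):
                bufs[dst][i] += x
        seg, d, steps = new_seg, d * 2, steps + 1

    # All-gather: retrace the pattern in reverse, concatenating reduced pieces.
    d = p // 2
    while d >= 1:
        msgs = [(r ^ d, seg[r], bufs[r][seg[r][0]:seg[r][1]]) for r in range(p)]
        for dst, (a, b), data in msgs:
            bufs[dst][a:b] = data
        seg = [(min(seg[r][0], seg[r ^ d][0]), max(seg[r][1], seg[r ^ d][1]))
               for r in range(p)]
        d, steps = d // 2, steps + 1
    return bufs, steps


inputs = [[float(r * 10 + i) for i in range(8)] for r in range(4)]  # 4 servers
result, steps = halving_doubling_allreduce(inputs)
print(result[0], steps)  # every rank holds the element-wise sum; 4 steps
```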
To support a non-power-of-two number of servers, we used the binary blocks algorithm [30]. This is a generalized version of the halving/doubling algorithm where servers are partitioned into power-of-two blocks and two additional communication steps are used, one immediately after the intra-block reduce-scatter and one before the intra-block all-gather. Non-power-of-two cases have some degree of load imbalance compared to power-of-two, though in our runs we did not see significant performance degradation.
4.2 Software
The allreduce algorithms described are implemented in Gloo⁴, a library for collective communication. It supports multiple communication contexts, which means no additional synchronization is needed to execute multiple allreduce instances in parallel. Local reduction and broadcast (described as phases (1) and (3)) are pipelined with inter-server allreduce where possible.
⁴ https://github.com/facebookincubator/gloo
Caffe2 supports multi-threaded execution of the compute graph that represents a training iteration. Whenever there is no data dependency between subgraphs, multiple threads can execute those subgraphs in parallel. Applying this to backprop, local gradients can be computed in sequence, without dealing with allreduce or weight updates. This means that during backprop, the set of runnable subgraphs may grow faster than we can execute them. For subgraphs that contain an allreduce run, all servers must choose to execute the same subgraph from the set of runnable subgraphs. Otherwise, we risk distributed deadlock where servers are attempting to execute non-intersecting sets of subgraphs. With allreduce being a collective operation, servers would time out waiting. To ensure correct execution we impose a partial order on these subgraphs. This is implemented using a cyclical control input, where completion of the $c$th allreduce unblocks execution of the $(c+C)$th allreduce, with $C$ being the maximum number of concurrent allreduce runs. Note that this number should be chosen to be lower than the number of threads used to execute the full compute graph.
4.3 Hardware
We used Facebook’s Big Basin [24] GPU servers for our experiments. Each server contains 8 NVIDIA Tesla P100 GPUs that are interconnected with NVIDIA NVLink. For local storage, each server has 3.2TB of NVMe SSDs. For network connectivity, the servers have a Mellanox ConnectX-4 50Gbit Ethernet network card and are connected to Wedge100 [1] Ethernet switches.
We have found 50Gbit of network bandwidth sufficient for distributed synchronous SGD for ResNet-50, per the following analysis. ResNet-50 has approximately 25 million parameters. This means the total size of parameters is $25 \cdot 10^6 \cdot \text{sizeof(float)} = 100\,\text{MB}$. Backprop for ResNet-50 on a single NVIDIA Tesla P100 GPU takes 120 ms. Given that allreduce requires 2× bytes on the network compared to the value it operates on, this leads to a peak bandwidth requirement of $200\,\text{MB}/0.125\,\text{s} = 1600\,\text{MB/s}$, or 12.8 Gbit/s, not taking into account communication overhead. When we add a smudge factor for network overhead, we reach a peak bandwidth requirement for ResNet-50 of 15 Gbit/s.
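The bandwidth estimate is simple arithmetic, and the helper below restates it. Its parameters are ours: 4-byte floats, a 2× allreduce traffic factor, and decimal units (1 MB = 10^6 bytes, 1 Gbit = 10^9 bits). A backprop time of 0.125 s reproduces 1600 MB/s, i.e., 12.8 Gbit/s. Plugging in 0.12 s instead gives about 13.3 Gbit/s, still well under the 50 Gbit link.

```python
def peak_allreduce_gbit_per_s(num_params, backprop_s, bytes_per_param=4, traffic_factor=2.0):
    """Peak bandwidth (Gbit/s) needed to allreduce all gradients within one backprop."""
    param_bytes = num_params * bytes_per_param   # 25e6 * 4 B = 100 MB
    wire_bytes = traffic_factor * param_bytes    # allreduce moves ~2x the data
    bytes_per_s = wire_bytes / backprop_s
    return bytes_per_s * 8 / 1e9                 # bytes/s -> Gbit/s


print(peak_allreduce_gbit_per_s(25e6, 0.125))  # ResNet-50, 125 ms
print(peak_allreduce_gbit_per_s(25e6, 0.120))  # ResNet-50, 120 ms
```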
As this peak bandwidth requirement only holds during backprop, the network is free to be used during the forward pass for other tasks that are less latency-sensitive than aggregation (e.g., reading data or saving network snapshots).
5 Main Results and Analysis
Our main result is that we can train ResNet-50 [16] on ImageNet [33] using 256 workers in one hour, while matching the accuracy of small minibatch training. Applying the linear scaling rule along with a warmup strategy allows us to seamlessly scale between small and large minibatches (up to 8k images) without tuning additional hyper-parameters or impacting accuracy. In the following subsections we: (1) describe experimental settings, (2) establish the effectiveness of large minibatch training, (3) perform a deeper experimental analysis, (4) show our findings generalize to object detection/segmentation, and (5) provide timings.
5.1 Experimental Settings
The 1000-way ImageNet classification task [33] serves as our main experimental benchmark. Models are trained on the ∼1.28 million training images and evaluated by top-1 error on the 50,000 validation images.
We use the ResNet-50 [16] variant from [12], noting that the stride-2 convolutions are on 3×3 layers instead of on 1×1 layers as in [16]. We use Nesterov momentum [29] with $m$ of 0.9 following [12] but note that standard momentum as was used in [16] is equally effective. We use a weight decay $\lambda$ of 0.0001 and, following [16], we do not apply weight decay on the learnable BN coefficients (namely, $\gamma$ and $\beta$ in [19]). In order to keep the training objective fixed, which depends on the BN batch size $n$ as described in §2.3, we use $n = 32$ throughout, regardless of the overall minibatch size. As in [12], we compute the BN statistics using running average (with momentum 0.9).
All models are trained for 90 epochs regardless of minibatch sizes. We apply the linear scaling rule from §2.1 and use a learning rate of $\eta = 0.1 \cdot \frac{kn}{256}$ that is linear in the minibatch size $kn$. With $k = 8$ workers (GPUs) and $n = 32$ samples per worker, $\eta = 0.1$ as in [16]. We call this number ($0.1 \cdot \frac{kn}{256}$) the reference learning rate, and reduce it by $1/10$ at the 30th, 60th, and 80th epoch, similar to [16].
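The ImageNet learning rate rule just described (reference rate $0.1 \cdot kn/256$, divided by 10 at epochs 30, 60 and 80) can be sketched as a function of the epoch. The function name and the fractional-epoch interface are our assumptions, and the gradual warmup of §2.2 is deliberately left out.

```python
def reference_lr_schedule(epoch, total_batch, base_lr=0.1, base_batch=256,
                          milestones=(30, 60, 80)):
    """Learning rate at a (possibly fractional) epoch, without warmup.

    The reference rate follows the linear scaling rule, base_lr * kn / 256, and is
    multiplied by 1/10 at each milestone epoch that has been reached.
    """
    lr = base_lr * total_batch / base_batch
    for m in milestones:
        if epoch >= m:
            lr *= 0.1
    return lr


for kn in (256, 8192):
    print(kn, [reference_lr_schedule(e, kn) for e in (0, 30, 60, 80)])
```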
We adopt the initialization of [15] for all convolutional layers. The 1000-way fully-connected layer is initialized by drawing weights from a zero-mean Gaussian with standard deviation of 0.01. We have found that although SGD with a small minibatch is not sensitive to initialization due to BN, this is not the case for a substantially large minibatch. Additionally, we require an appropriate warmup strategy to avoid optimization difficulties in early training.
For BN layers, the learnable scaling coefficient $\gamma$ is initialized to be 1, except for each residual block’s last BN where $\gamma$ is initialized to be 0. Setting $\gamma = 0$ in the last BN of each residual block causes the forward/backward signal initially to propagate through the identity shortcut of ResNets, which we found to ease optimization at the start of training. This initialization improves all models but is particularly helpful for large minibatch training as we will show.
We use scale and aspect ratio data augmentation [36] as in [12]. The network input image is a 224×224 pixel random crop from an augmented image or its horizontal flip. The input image is normalized by the per-color mean and standard deviation, as in [12].
Handling random variation.
As models are subject to random variation in training, we compute a model’s error rate as the median error of the final 5 epochs. Moreover, we report the mean and standard deviation (std) of the error from 5 independent runs. This gives us more confidence in our results and also provides a measure of model stability.
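The error-reporting protocol takes a few lines with Python's `statistics` module. The inputs below are made-up numbers, not results from this report. Using the sample standard deviation (`statistics.stdev`) rather than the population one is our assumption, since the text does not say which is used.

```python
import statistics


def run_error(epoch_errors, last=5):
    """A run's error: the median of its final `last` epochs."""
    return statistics.median(epoch_errors[-last:])


def summarize(runs):
    """Mean and (sample) std of per-run errors across independent runs."""
    errors = [run_error(r) for r in runs]
    return statistics.mean(errors), statistics.stdev(errors)


# Hypothetical per-epoch top-1 errors (%) for 5 runs; only the tail matters.
runs = [
    [30.0, 24.0, 23.5, 23.7, 23.6, 23.9],
    [30.0, 23.8, 23.6, 23.4, 23.5, 23.7],
    [30.0, 23.9, 23.8, 23.7, 23.6, 23.5],
    [30.0, 23.6, 23.7, 23.8, 23.9, 23.7],
    [30.0, 23.5, 23.4, 23.6, 23.5, 23.6],
]
mean, std = summarize(runs)
print(f"{mean:.2f} ± {std:.2f}")
```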
The random variation of ImageNet models has generally not been reported in previous work (largely due to resource limitations). We emphasize that ignoring random variation may cause unreliable conclusions, especially if results are from a single trial, or the best of many.
Baseline.
Under these settings, we establish a ResNet-50 baseline using $k = 8$ (8 GPUs in one server) and $n = 32$ images per worker (minibatch size of $kn = 256$), as in [16]. Our baseline has a top-1 validation error of 23.60% ±0.12. As a reference, ResNet-50 from fb.resnet.torch [12] has 24.01% error, and that of the original ResNet paper [16] has 24.7% under weaker data augmentation.
5.2 Optimization or Generalization Issues?
We establish our main results on large minibatch training by exploring optimization and generalization behaviors. We will demonstrate that with a proper warmup strategy, large minibatch SGD can both match the training curves of small minibatch SGD and also match the validation error. In other words, in our experiments both optimization and generalization of large minibatch training match those of small minibatch training. Moreover, in §5.4 we will show that these models exhibit good generalization behavior to the object detection/segmentation transfer tasks, matching the transfer quality of small minibatch models.
For the following results, we use $k = 256$ and $n = 32$, which results in a minibatch size $kn = 8\text{k}$ (we use ‘1k’ to denote 1024). As discussed, our baseline has a minibatch size of $kn = 256$ and a reference learning rate of $\eta = 0.1$. Applying the linear scaling rule gives $\eta = 3.2$ as the reference learning rate for our large minibatch runs. We test three warmup strategies as discussed in §2.2: no warmup, constant warmup with $\eta = 0.1$ for 5 epochs, and gradual warmup which starts with $\eta = 0.1$ and is linearly increased to $\eta = 3.2$ over 5 epochs. All models are trained from scratch and all other hyper-parameters are kept fixed. We emphasize that while better results for any particular minibatch size could be obtained by optimizing hyper-parameters for that case, our goal is to match errors across minibatch sizes by using a general strategy that avoids hyper-parameter tuning for each minibatch size.
Training error.
Training curves are shown in Figure 2. With no warmup (Figure 2a), the training curve for a large minibatch of $kn = 8\text{k}$ is inferior to training with a small minibatch of $kn = 256$ across all epochs. A constant warmup strategy (Figure 2b) actually degrades results: although the small constant learning rate can decrease error during warmup, the error spikes immediately after and training never fully recovers.
Our main result is that with gradual warmup, large minibatch training error matches the baseline training curve obtained with small minibatches; see Figure 2c. Although the large minibatch curve starts higher due to the low $\eta$ in the warmup phase, it catches up shortly thereafter. After about 20 epochs, the small and large minibatch training curves match closely. The comparison between no warmup and gradual warmup suggests that large minibatch sizes are challenged by optimization difficulties in early training and, if these difficulties are addressed, the training error and its curve can match a small minibatch baseline closely.
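Table 1: ImageNet top-1 validation error of ResNet-50 for the three warmup strategies (mean ± std over 5 runs).
| | $k$ | $n$ | $kn$ | $\eta$ | top-1 error (%) |
|---|---|---|---|---|---|
| baseline (single server) | 8 | 32 | 256 | 0.1 | 23.60 ±0.12 |
| no warmup, Figure 2a | 256 | 32 | 8k | 3.2 | 24.84 ±0.37 |
| constant warmup, Figure 2b | 256 | 32 | 8k | 3.2 | 25.88 ±0.56 |
| gradual warmup, Figure 2c | 256 | 32 | 8k | 3.2 | 23.74 ±0.09 |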
Validation error.
Table 1 shows the validation error for the three warmup strategies. The no-warmup variant has roughly 1.2% higher validation error than the baseline, which is likely caused by the increase in training error (Figure 2a) rather than by overfitting or other causes of poor generalization. This argument is further supported by our gradual warmup experiment. The gradual warmup variant has a validation error within 0.14% of the baseline (noting that the std of these estimates is about 0.1%). Given that the final training errors (Figure 2c) match nicely in this case, this shows that if the optimization issues are addressed, there is no apparent generalization degradation with large minibatch training, even when the minibatch size goes from 256 to 8k.

Finally, Figure 4 shows both the training and validation curves for the large minibatch training with gradual warmup. As can be seen, validation error starts to match the baseline closely after the second learning rate drop; in fact, the validation curves can match earlier if BN statistics are recomputed prior to evaluating the error instead of using the running average (see also the caption of Figure 4).
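Why recomputed BN statistics can differ from the running average is easiest to see on a single channel whose activations drift while the weights change: an exponential moving average lags behind the current weights. The NumPy sketch below is illustrative only; the drift, batch sizes, and momentum convention (here the weight kept on the old estimate) are hypothetical assumptions, not the settings of our experiments.

```python
import numpy as np


def running_average_stats(batches, momentum=0.9):
    """EMA of per-batch mean/var, as BN layers typically track during training.

    `momentum` is the weight kept on the previous estimate (an assumption;
    frameworks differ in this convention).
    """
    mean, var = 0.0, 1.0
    for b in batches:
        mean = momentum * mean + (1.0 - momentum) * b.mean()
        var = momentum * var + (1.0 - momentum) * b.var()
    return mean, var


def recomputed_stats(batches):
    """Population mean/var recomputed from activations of the current (frozen) weights."""
    x = np.concatenate(batches)
    return x.mean(), x.var()


rng = np.random.default_rng(0)
steps, batch = 50, 64
# Hypothetical drift: this channel's mean activation moves from 0 to 2 during training.
train_batches = [rng.normal(2.0 * t / (steps - 1), 1.0, batch) for t in range(steps)]
# Activations of the same channel produced by the final weights.
final_batches = [rng.normal(2.0, 1.0, batch) for _ in range(20)]

run_mean, _ = running_average_stats(train_batches)
rec_mean, _ = recomputed_stats(final_batches)
print(f"running-average mean: {run_mean:.2f}, recomputed mean: {rec_mean:.2f}")
```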
5.3 Analysis Experiments
Minibatch size vs. error.
Figure 1 (page 1) shows top-1 validation error for models trained with minibatch sizes ranging from 64 to 65536 (64k). For all models we used the linear scaling rule and set the reference learning rate as η = 0.1·kn/256. For models with kn > 256, we used the gradual warmup strategy, always starting with η = 0.1 and increasing linearly to the reference learning rate after 5 epochs. Figure 1 illustrates that validation error remains stable across a broad range of minibatch sizes, from 64 to 8k, after which it begins to increase. Beyond 64k, training diverges when using the linear learning rate scaling rule.⁵

⁵ We note that, because of the availability of hardware, we simulated distributed training of very large minibatches (≥12k) on a single server by using multiple gradient accumulation steps between SGD updates. We have thoroughly verified that gradient accumulation on a single server yields equivalent results relative to distributed training.
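Gradient accumulation reproduces a single large-minibatch step when the loss is a mean over samples and the sub-batches have equal size, because the average of the sub-batch mean gradients equals the mean gradient over the whole minibatch. The sketch below checks this on a toy least-squares model in NumPy; the model, sizes, and function names are hypothetical, and it contains no layer (such as BN) whose output depends on which samples share a sub-batch.

```python
import numpy as np


def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    return X.T @ (X @ w - y) / len(y)


def sgd_step_full(w, X, y, lr):
    """One SGD update computed on the whole minibatch at once."""
    return w - lr * mse_grad(w, X, y)


def sgd_step_accumulated(w, X, y, lr, sub_batch):
    """One SGD update whose gradient is accumulated over equal-sized sub-batches."""
    if len(y) % sub_batch != 0:
        raise ValueError("minibatch size must be a multiple of sub_batch")
    num_chunks = len(y) // sub_batch
    grad = np.zeros_like(w)
    for i in range(num_chunks):
        rows = slice(i * sub_batch, (i + 1) * sub_batch)
        grad += mse_grad(w, X[rows], y[rows])  # mean gradient of one sub-batch
    # Averaging equal-sized sub-batch means gives the mean over the full minibatch.
    return w - lr * grad / num_chunks


rng = np.random.default_rng(0)
X = rng.normal(size=(1024, 8))
y = rng.normal(size=1024)
w0 = rng.normal(size=8)
lr = 0.1 * 1024 / 256  # linear scaling rule for a 1k minibatch

w_full = sgd_step_full(w0, X, y, lr)
w_acc = sgd_step_accumulated(w0, X, y, lr, sub_batch=32)
print(np.allclose(w_full, w_acc))
```

With batch-dependent layers such as BN, the statistics depend on the sub-batch, so an accumulation setup would presumably keep the sub-batch equal to the per-worker n it stands in for.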
Training curves for various minibatch sizes.
Each of the nine plots in Figure 3 shows the top-1 training error curve for the 256 minibatch baseline (orange) and a second curve corresponding to a different minibatch size (blue). Validation errors are shown in the plot legends. As minibatch size increases, all training curves show some divergence from the baseline at the start of training. However, in the cases where the final validation error closely matches the baseline (kn ≤ 8k), the training curves also closely match after the initial epochs. When the validation errors do not match (kn ≥ 16k), there is a noticeable gap in the training curves for all epochs. This suggests that when comparing a new setting, the training curves can be used as a reliable proxy for success well before training finishes.
Alternative learning rate rules.
Table 2a shows results for multiple learning rates. For small minibatches (kn = 256), η = 0.1 gives the best error, but slightly smaller or larger η also work well. When applying the linear scaling rule with a minibatch of 8k images, the optimum error is also achieved with η = 0.1·32, showing the successful application of the linear scaling rule. However, in this case results are more sensitive to changing η. In practice, we suggest using a minibatch size that is not close to the breaking point.
Figure 5 shows the training curves of a 256 minibatch using η = 0.1 or 0.2. It shows that changing the learning rate in general changes the overall shape of the training curves, even if the final error is similar. Contrasting this result with the success of the linear scaling rule (which can match both the final error and the training curves when minibatch sizes change) may reveal some underlying invariance maintained between small and large minibatches.
We also show two alternative strategies: keeping η fixed at 0.1, or using η = 0.1·√32 according to the square root scaling rule that was justified theoretically in [21] on the grounds that it scales η by the inverse of the reduction in the gradient estimator's standard deviation. For fair comparisons we also use gradual warmup for η = 0.1·√32. Both policies work poorly in practice, as the results show.
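For kn = 8k (32× the baseline minibatch), the three rules give very different reference rates. The sketch below computes them and also illustrates the argument behind the square root rule: with simulated per-sample gradients drawn from a hypothetical N(1, 1) distribution, the standard deviation of a minibatch-mean gradient shrinks roughly as 1/√(kn). Function names and the simulated distribution are assumptions for illustration.

```python
import math

import numpy as np


def lr_fixed(kn, base_lr=0.1, base_kn=256):
    """Keep eta at the small-minibatch value regardless of kn."""
    return base_lr


def lr_sqrt(kn, base_lr=0.1, base_kn=256):
    """Square root scaling: eta grows with sqrt(kn / base_kn)."""
    return base_lr * math.sqrt(kn / base_kn)


def lr_linear(kn, base_lr=0.1, base_kn=256):
    """Linear scaling rule: eta grows with kn / base_kn."""
    return base_lr * kn / base_kn


for rule in (lr_fixed, lr_sqrt, lr_linear):
    print(f"{rule.__name__}: eta = {rule(8192):.3f} for kn = 8k")

# Hypothetical per-sample gradients of one scalar parameter, drawn from N(1, 1).
rng = np.random.default_rng(0)
per_sample = rng.normal(1.0, 1.0, size=(2000, 1024))
std_256 = per_sample[:, :256].mean(axis=1).std()  # 2000 minibatches of 256
std_1024 = per_sample.mean(axis=1).std()          # 2000 minibatches of 1024
print(f"std(256) / std(1024) = {std_256 / std_1024:.2f}; sqrt(1024 / 256) = 2")
```

The variance argument alone motivates √(kn) scaling; the results above indicate that, in this setting, the linear rule is the one that works in practice.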
Batch Normalization initialization.
Table 2b controls for the impact of the new BN γ initialization introduced in §5.1. We show results for minibatch sizes 256 and 8k with the standard BN initialization (γ = 1 for all BN layers) and with our initialization (γ = 0 for the final BN layer of each residual block). The results show improved performance with γ = 0 for both minibatch sizes, and the improvement is slightly larger for the 8k minibatch size. This behavior also suggests that large minibatches are more easily affected by optimization difficulties. We expect that improved optimization and initialization methods will help push the boundary of large minibatch training.
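The effect of γ = 0 can be seen on a simplified residual block: with the scale of the last BN set to zero, the residual branch outputs zero, so the block initially passes its (non-negative, post-ReLU) input through unchanged. The NumPy sketch below uses a hypothetical two-layer block with dense weights in place of convolutions and omits the inner BN layers; it illustrates the initialization only and is not our ResNet implementation.

```python
import numpy as np


def batch_norm(x, gamma, beta, eps=1e-5):
    """Per-channel BN using the statistics of the current batch (training mode)."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


def residual_block(x, w1, w2, gamma_last):
    """Hypothetical simplified block: relu(x + BN(relu(x @ w1) @ w2))."""
    h = np.maximum(x @ w1, 0.0)
    branch = batch_norm(h @ w2, gamma=gamma_last, beta=0.0)
    return np.maximum(x + branch, 0.0)


rng = np.random.default_rng(0)
channels = 16
x = np.maximum(rng.normal(size=(32, channels)), 0.0)  # non-negative, like a ReLU output
w1 = 0.1 * rng.normal(size=(channels, channels))
w2 = 0.1 * rng.normal(size=(channels, channels))

y_zero = residual_block(x, w1, w2, gamma_last=np.zeros(channels))  # gamma = 0 init
y_one = residual_block(x, w1, w2, gamma_last=np.ones(channels))    # standard gamma = 1
print(np.array_equal(y_zero, x), np.array_equal(y_one, x))
```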
ResNet-101.
Results for ResNet-101 [16] are shown in Table 2c. Training ResNet-101 with a batch size of kn = 8k and a linearly scaled η = 3.2 results in an error of 22.36% vs. the baseline, which achieves 22.08% with kn = 256 and η = 0.1. In other words, ResNet-101 trained with minibatch 8k has a small 0.28% increase in error vs. the baseline. It is likely that the minibatch size of 8k lies on the edge of the useful minibatch training regime for ResNet-101, similarly to ResNet-50 (see Figure 1).
The training time of ResNet-101 is 92.5 minutes in our implementation using 256 Tesla P100 GPUs and a minibatch size of 8k. We believe this is a compelling result if the speed-accuracy tradeoff of ResNet-101 is preferred.
ImageNet-5k.
Observing the sharp increase in validation error between minibatch sizes of 8k and 16k on ImageNet-1k (Figure 1), a natural question is whether the position of this 'elbow' in the error curve is a function of dataset information content. To investigate this question, we adopt the ImageNet-5k dataset suggested by Xie et al. [39] that extends ImageNet-1k to 6.8 million images (roughly 5× larger) by adding 4k additional categories from ImageNet-22k [33]. We evaluate the 1k-way classification error on the original ImageNet-1k validation set as in [39].
The minibatch size vs. validation error curve for ImageNet-5k is shown in Figure 6. Qualitatively, the curve is very similar to the ImageNet-1k curve, showing that for practitioners it is unlikely that even a 5× increase in dataset size will automatically lead to a meaningful increase in usable minibatch size. Quantitatively, using an 8k minibatch increases the validation error by 0.26%, from 25.83% for a 256 minibatch to 26.09%. An understanding of the precise relationship between generalization error, minibatch size, and dataset information content is open for future work.
5.4 Generalization to Detection and Segmentation
A low error rate on ImageNet is not typically an end goal. Instead, the utility of ImageNet training lies in learning good features that transfer, or generalize well, to related tasks. A question of key importance is whether the features learned with large minibatches generalize as well as the features learned with small minibatches.
To test this, we adopt the object detection and instance segmentation tasks on COCO [27], as these advanced perception tasks benefit substantially from ImageNet pre-training [10]. We use the recently developed Mask R-CNN [14] system that is capable of learning to detect and segment object instances. We follow all of the hyperparameter settings used in [14] and only change the ResNet-50 model used to initialize Mask R-CNN training. We train Mask R-CNN on the COCO trainval35k split and report results on the 5k image minival split used in [14].
It is interesting to note that the concept of minibatch size in Mask R-CNN is different from the classification setting. As an extension of the image-centric Fast/Faster R-CNN [9, 31], Mask R-CNN exhibits different minibatch sizes for different layers: the network backbone uses two images (per GPU), but each image contributes 512 Regions-of-Interest for computing classification (multinomial cross-entropy), bounding-box regression (smooth-L1/Huber), and pixel-wise mask (28×28 binomial cross-entropy) losses. This diverse set of minibatch sizes and loss functions provides a good test case for the robustness of our approach.
Transfer learning from large minibatch pretraining.
To test how large minibatch pre-training affects Mask R-CNN, we take ResNet-50 models trained on ImageNet-1k with 256 to 16k minibatches and use them to initialize Mask R-CNN training. For each minibatch size we pre-train 5 models and then train Mask R-CNN using all 5 models on COCO (35 models total). We report the mean box and mask APs, averaged over the 5 trials, in Table 3a. The results show that as long as ImageNet validation error is kept low, which is true up to 8k batch size, generalization to object detection matches the AP of the small minibatch baseline. We emphasize that we observed no generalization issues when transferring across datasets (from ImageNet to COCO) and across tasks (from classification to detection/segmentation) using models trained with large minibatches.
Linear scaling rule applied to Mask R-CNN.
We also show evidence of the generality of the linear scaling rule using Mask R-CNN. In fact, this rule was already used without explicit discussion in [16] and was applied effectively as the default Mask R-CNN training scheme when using 8 GPUs. Table 3b provides experimental results showing that when training with 1, 2, 4, or 8 GPUs the linear learning rate rule results in constant box and mask AP. For these experiments, we initialize Mask R-CNN from the released MSRA ResNet-50 model, as was done in [14].
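The bookkeeping behind the 1/2/4/8-GPU comparison can be sketched as follows, assuming two images per GPU and 512 sampled RoIs per image as described above. The reference 8-GPU learning rate is left as a hypothetical parameter (`ref_lr_8gpu`), the class and function names are illustrative, and the sketch does not touch the iteration schedule.

```python
from dataclasses import dataclass


@dataclass
class MaskRCNNSetting:
    gpus: int
    images: int  # backbone minibatch: images per iteration
    rois: int    # sampled RoIs per iteration feeding the head losses
    lr: float


def linear_scaled_setting(gpus, ref_lr_8gpu, images_per_gpu=2, rois_per_image=512):
    """Scale the learning rate linearly with images per iteration, relative to 8 GPUs.

    `ref_lr_8gpu` is a placeholder for the 8-GPU reference rate (hypothetical here).
    """
    images = gpus * images_per_gpu
    ref_images = 8 * images_per_gpu
    return MaskRCNNSetting(
        gpus=gpus,
        images=images,
        rois=images * rois_per_image,
        lr=ref_lr_8gpu * images / ref_images,
    )


for g in (1, 2, 4, 8):
    print(linear_scaled_setting(g, ref_lr_8gpu=1.0))  # lr shown relative to 8 GPUs
```

Both the image-level and RoI-level minibatches shrink by the same factor as the GPU count, so a single linear factor applies to the learning rate in this sketch.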
5.5 Run Time
Figure 7 shows two visualizations of the run time characteristics of our system. The blue curve is the time per iteration as minibatch size varies from 256 to 11264 (11k). Notably, this curve is relatively flat: the time per iteration increases only 12% while the minibatch size is scaled by 44×. Visualized another way, the orange curve shows the approximately linear decrease in time per epoch from over 16 minutes to just 30 seconds. Run time performance can also be viewed in terms of throughput (images/second), as shown in Figure 8. Relative to a perfectly efficient extrapolation of the 8-GPU baseline, our implementation achieves 90% scaling efficiency.
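The run time numbers rest on two simple relations: time per epoch is the number of iterations per epoch times the time per iteration, and scaling efficiency is measured throughput divided by a perfectly linear extrapolation of the 8-GPU baseline. The sketch below encodes both; the per-iteration time and throughput values in the example are hypothetical placeholders, not measurements from Figures 7 or 8.

```python
def time_per_epoch(num_images: int, minibatch: int, sec_per_iter: float) -> float:
    """Seconds per epoch = (iterations per epoch) x (seconds per iteration)."""
    iters = -(-num_images // minibatch)  # ceiling division (assumes a final partial batch is kept)
    return iters * sec_per_iter


def scaling_efficiency(throughput: float, gpus: int,
                       base_throughput: float, base_gpus: int = 8) -> float:
    """Measured images/s divided by a perfectly linear extrapolation of the baseline."""
    ideal = base_throughput * gpus / base_gpus
    return throughput / ideal


# Hypothetical numbers, for illustration only.
imagenet_train_images = 1_281_167  # size of the ImageNet-1k training set
print(time_per_epoch(imagenet_train_images, 256, sec_per_iter=0.2))
print(scaling_efficiency(throughput=28_800, gpus=256, base_throughput=1_000))
```

When time per iteration stays nearly flat as the minibatch grows, the iteration count per epoch falls in proportion to the minibatch, which is why the time-per-epoch curve decreases approximately linearly.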
Acknowledgements.
We would like to thank Leon Bottou for helpful discussions on theoretical background, Jerry Pan and Christian Puhrsch for discussions on efficient data loading, Andrew Dye for help with debugging distributed training, and Kevin Lee, Brian Dodds, Jia Ning, Koh Yew Thoon, Micah Harris, and John Volk for Big Basin and hardware support.
References
 [1] J. Bagga, H. Morsy, and Z. Yao. Opening designs for 6-pack and Wedge 100. https://code.facebook.com/posts/203733993317833/openingdesignsfor6packandwedge100, 2016.
 [2] M. Barnett, L. Shuler, R. van De Geijn, S. Gupta, D. G. Payne, and J. Watts. Interprocessor collective communication library (intercom). In Scalable High-Performance Computing Conference, 1994.
 [3] L. Bottou. Curiously fast convergence of some stochastic gradient descent algorithms. Unpublished open problem offered to the attendance of the SLDS 2009 conference, 2009.
 [4] L. Bottou, F. E. Curtis, and J. Nocedal. Opt. methods for large-scale machine learning. arXiv:1606.04838, 2016.
 [5] J. Chen, X. Pan, R. Monga, S. Bengio, and R. Jozefowicz. Revisiting Distributed Synchronous SGD. arXiv:1604.00981, 2016.
 [6] K. Chen and Q. Huo. Scalable training of deep learning machines by incremental block training with intra-block parallel optimization and blockwise model-update filtering. In ICASSP, 2016.
 [7] R. Collobert, J. Weston, L. Bottou, M. Karlen, K. Kavukcuoglu, and P. Kuksa. Natural language processing (almost) from scratch. JMLR, 2011.
 [8] J. Donahue, Y. Jia, O. Vinyals, J. Hoffman, N. Zhang, E. Tzeng, and T. Darrell. Decaf: A deep convolutional activation feature for generic visual recognition. In ICML, 2014.
 [9] R. Girshick. Fast R-CNN. In ICCV, 2015.
 [10] R. Girshick, J. Donahue, T. Darrell, and J. Malik. Rich feature hierarchies for accurate object detection and semantic segmentation. In CVPR, 2014.
 [11] W. Gropp, E. Lusk, and A. Skjellum. Using MPI: Portable Parallel Programming with the Message-Passing Interface. MIT Press, Cambridge, MA, 1999.
 [12] S. Gross and M. Wilber. Training and investigating Residual Nets. https://github.com/facebook/fb.resnet.torch, 2016.
 [13] M. Gürbüzbalaban, A. Ozdaglar, and P. Parrilo. Why random reshuffling beats stochastic gradient descent. arXiv:1510.08560, 2015.
 [14] K. He, G. Gkioxari, P. Dollár, and R. Girshick. Mask R-CNN. arXiv:1703.06870, 2017.
 [15] K. He, X. Zhang, S. Ren, and J. Sun. Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification. In ICCV, 2015.
 [16] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
 [17] G. Hinton, L. Deng, D. Yu, G. E. Dahl, A.-r. Mohamed, N. Jaitly, A. Senior, V. Vanhoucke, P. Nguyen, T. N. Sainath, et al. Deep neural networks for acoustic modeling in speech recognition: The shared views of four research groups. IEEE Signal Processing Magazine, 2012.
 [18] I. Hubara, M. Courbariaux, D. Soudry, R. El-Yaniv, and Y. Bengio. Quantized neural networks: Training neural networks with low precision weights and activations. arXiv:1510.08560, 2016.
 [19] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In ICML, 2015.
 [20] N. S. Keskar, D. Mudigere, J. Nocedal, M. Smelyanskiy, and P. T. P. Tang. On large-batch training for deep learning: Generalization gap and sharp minima. ICLR, 2017.
 [21] A. Krizhevsky. One weird trick for parallelizing convolutional neural networks. arXiv:1404.5997, 2014.
 [22] A. Krizhevsky, I. Sutskever, and G. Hinton. ImageNet classification with deep convolutional neural nets. In NIPS, 2012.
 [23] Y. LeCun, B. Boser, J. S. Denker, D. Henderson, R. E. Howard, W. Hubbard, and L. D. Jackel. Backpropagation applied to handwritten zip code recognition. Neural computation, 1989.
 [24] K. Lee. Introducing Big Basin: Our next-generation AI hardware. https://code.facebook.com/posts/1835166200089399/introducingbigbasin, 2017.
 [25] M. Li. Scaling Distributed Machine Learning with System and Algorithm Co-design. PhD thesis, Carnegie Mellon University, 2017.
 [26] T.-Y. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, and S. Belongie. Feature pyramid networks for object detection. In CVPR, 2017.
 [27] T.-Y. Lin, M. Maire, S. Belongie, J. Hays, P. Perona, D. Ramanan, P. Dollár, and C. L. Zitnick. Microsoft COCO: Common objects in context. In ECCV, 2014.
 [28] J. Long, E. Shelhamer, and T. Darrell. Fully convolutional networks for semantic segmentation. In CVPR, 2015.
 [29] Y. Nesterov. Introductory lectures on convex optimization: A basic course. Springer, 2004.
 [30] R. Rabenseifner. Optimization of collective reduction operations. In ICCS. Springer, 2004.
 [31] S. Ren, K. He, R. Girshick, and J. Sun. Faster R-CNN: Towards real-time object detection with region proposal networks. In NIPS, 2015.
 [32] H. Robbins and S. Monro. A stochastic approximation method. The annals of mathematical statistics, 1951.
 [33] O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. Bernstein, A. C. Berg, and L. Fei-Fei. ImageNet Large Scale Visual Recognition Challenge. IJCV, 2015.
 [34] P. Sermanet, D. Eigen, X. Zhang, M. Mathieu, R. Fergus, and Y. LeCun. Overfeat: Integrated recognition, localization and detection using convolutional networks. In ICLR, 2014.
 [35] K. Simonyan and A. Zisserman. Very deep convolutional networks for large-scale image recognition. In ICLR, 2015.
 [36] C. Szegedy, W. Liu, Y. Jia, P. Sermanet, S. Reed, D. Anguelov, D. Erhan, V. Vanhoucke, and A. Rabinovich. Going deeper with convolutions. In CVPR, 2015.
 [37] R. Thakur, R. Rabenseifner, and W. Gropp. Optimization of collective comm. operations in MPICH. IJHPCA, 2005.
 [38] Y. Wu, M. Schuster, Z. Chen, Q. V. Le, M. Norouzi, W. Macherey, M. Krikun, Y. Cao, Q. Gao, K. Macherey, et al. Google’s neural machine translation system: Bridging the gap between human and machine translation. arXiv:1609.08144, 2016.
 [39] S. Xie, R. Girshick, P. Dollár, Z. Tu, and K. He. Aggregated residual transformations for deep neural networks. In CVPR, 2017.
 [40] W. Xiong, J. Droppo, X. Huang, F. Seide, M. Seltzer, A. Stolcke, D. Yu, and G. Zweig. The Microsoft 2016 Conversational Speech Recognition System. arXiv:1609.03528, 2016.
 [41] M. D. Zeiler and R. Fergus. Visualizing and understanding convolutional neural networks. In ECCV, 2014.