Towards Stabilizing Batch Statistics in Backward Propagation of Batch Normalization

01/19/2020 · by Junjie Yan, et al. · FUDAN University

Batch Normalization (BN) is one of the most widely used techniques in deep learning. However, its performance degrades severely with insufficient batch size. This weakness limits the usage of BN in many computer vision tasks like detection or segmentation, where the batch size is usually small due to the constraint of memory consumption. Therefore many modified normalization techniques have been proposed, which either fail to restore the performance of BN completely, or have to introduce additional nonlinear operations in the inference procedure, substantially increasing inference cost. In this paper, we reveal that there are two extra batch statistics involved in the backward propagation of BN which have never been well discussed before. These extra batch statistics, associated with gradients, can also severely affect the training of deep neural networks. Based on our analysis, we propose a novel normalization method, named Moving Average Batch Normalization (MABN). MABN can completely restore the performance of vanilla BN in small batch cases, without introducing any additional nonlinear operations in the inference procedure. We prove the benefits of MABN by both theoretical analysis and experiments. Our experiments demonstrate the effectiveness of MABN on multiple computer vision benchmarks, including ImageNet and COCO. The code has been released at https://github.com/megvii-model/MABN.


1 Introduction

Batch Normalization (BN) (Ioffe and Szegedy, 2015) is one of the most popular techniques for training neural networks. It has been widely proven effective in many applications and has become an indispensable part of many state-of-the-art deep models.

Despite the success of BN, it is still challenging to utilize BN when the batch size is extremely small. (In the context of this paper, we use "batch size" or "normalization batch size" to refer to the number of samples used to compute statistics unless otherwise stated, and "gradient batch size" to refer to the number of samples used to update weights.) Batch statistics computed with a small batch size are highly unstable, leading to slow convergence during training and poor performance during inference. For example, in detection or segmentation tasks, the batch size is often limited to only 1 or 2 samples per GPU due to the high-resolution inputs or the complex structure of the model. Directly computing batch statistics on each GPU without any modification severely degrades the performance of the model.

To address such issues, many modified normalization methods have been proposed. They can be roughly divided into two categories: some try to improve vanilla BN by correcting batch statistics (Ioffe, 2017; Singh and Shrivastava, 2019), but they all fail to completely restore the performance of vanilla BN; others get around the instability of BN by using instance-level normalization (Ulyanov et al., 2016; Ba et al., 2016; Wu and He, 2018), so that models avoid the effect of batch statistics. Methods of this type can restore the performance in small batch cases to some extent. However, instance-level normalization hardly meets industrial or commercial needs so far, because these methods have to compute instance-level statistics both in training and in inference, which introduces additional nonlinear operations in the inference procedure and dramatically increases inference cost (Shao et al., 2019). In contrast, once training is finished, vanilla BN uses statistics computed over the whole training data instead of a batch of samples; BN is thus a linear operator at inference time and can be merged with the preceding convolution layer. Figure 1 shows that with ResNet-50 (He et al., 2016), instance-level normalization almost doubles the inference time compared with vanilla BN. Therefore, it is a tough but necessary task to restore the performance of BN in small batch training without introducing any nonlinear operations in the inference procedure.
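The linearity argument can be made concrete with a small sketch: since a frozen BN layer at test time is an affine map with fixed population statistics, it can be folded into the weights and bias of the preceding convolution. The following is a minimal PyTorch illustration (not the authors' released code; the helper name fold_bn_into_conv is ours, and groups/dilation are assumed to be the defaults):

import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Merge a frozen BatchNorm2d into the preceding Conv2d for inference."""
    w = conv.weight.data.clone()                                   # (C_out, C_in, k, k)
    b = conv.bias.data.clone() if conv.bias is not None else torch.zeros(w.size(0))
    # BN at inference: y = gamma * (x - mean) / sqrt(var + eps) + beta
    gamma, beta = bn.weight.data, bn.bias.data
    mean, var, eps = bn.running_mean, bn.running_var, bn.eps
    scale = gamma / torch.sqrt(var + eps)                          # one factor per output channel
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    fused.weight.data = w * scale.view(-1, 1, 1, 1)
    fused.bias.data = beta + (b - mean) * scale
    return fused

Instance-level methods such as GN or IN cannot be folded this way, because their statistics are recomputed from every input at test time.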

In this paper, we first analyze the formulation of vanilla BN, revealing that there are actually not only 2 but 4 batch statistics involved in normalization, during forward propagation (FP) as well as backward propagation (BP). The two additional batch statistics involved in BP are associated with the gradients of the model and have never been well discussed before. They play an important role in regularizing the gradients of the model during BP. In our experiments (see Figure 2), the variance of the batch statistics associated with gradients in BP, due to small batch size, is even larger than that of the widely-known batch statistics (the mean and variance of feature maps). We believe the instability of the batch statistics associated with gradients is one of the key reasons why BN performs poorly in small batch cases.

Based on our analysis, we propose a novel normalization method named Moving Average Batch Normalization (MABN). MABN can completely get over small batch issues without introducing any nonlinear operations in the inference procedure. The core idea of MABN is to replace batch statistics with moving average statistics. We substitute the batch statistics involved in FP and BP with different types of moving average statistics respectively, and give a theoretical analysis to prove the benefits. However, we observed that directly using moving average statistics as substitutes for batch statistics cannot make training converge in practice. We believe this failure occurs because of occasional large gradients during training, which has been mentioned in Ioffe (2017). To avoid training collapse, we modify the vanilla normalization form by reducing the number of batch statistics, centralizing the weights of the convolution kernels, and adopting a renormalizing strategy. We also theoretically prove that the modified normalization form is more stable than the vanilla form.

MABN shows its effectiveness on multiple public vision datasets and tasks, including ImageNet (Russakovsky et al., 2015) and COCO (Lin et al., 2014). All experimental results show that MABN with a small batch size (such as 2) can achieve performance comparable to BN with a regular batch size (see Figure 1). Besides, it has the same inference cost as vanilla BN (see Figure 1). We also conducted extensive ablation experiments to further verify the effectiveness of MABN.

Figure 1: (a) Throughput (iterations per second) of the inference procedure using different normalization methods. Implementation details are given in Appendix B.2. (b) ImageNet classification validation error vs. batch size.

2 Related Work

Batch Normalization (BN) (Ioffe and Szegedy, 2015) normalizes the internal feature maps of a deep neural network using channel-wise statistics (mean, standard deviation) along the batch dimension. It has been widely proven effective in most tasks, but vanilla BN relies heavily on a sufficient batch size in practice. To restore the performance of BN in small batch cases, many normalization techniques have been proposed. Batch Renormalization (BRN) (Ioffe, 2017) introduces renormalizing parameters into BN to correct the batch statistics during training, where the renormalizing parameters are computed using moving average statistics. Unlike BRN, EvalNorm (Singh and Shrivastava, 2019) corrects the batch statistics during the inference procedure. Both BRN and EvalNorm restore the performance of BN to some extent, but neither gets over small batch issues completely. Instance Normalization (IN) (Ulyanov et al., 2016), Layer Normalization (LN) (Ba et al., 2016), and Group Normalization (GN) (Wu and He, 2018) all try to avoid the effect of batch size by utilizing instance-level statistics. IN uses channel-wise statistics per instance instead of per batch, while LN uses instance-level statistics along the channel dimension, but IN and LN show no superiority to vanilla BN in most cases. GN divides the channels into predefined groups and uses group-wise statistics per instance. It restores the performance of vanilla BN very well in classification and detection tasks, but it has to introduce extra nonlinear operations in the inference procedure and severely increases inference cost, as we pointed out in Section 1. SyncBN (Peng et al., 2018) handles the small batch issue by computing the mean and variance across multiple GPUs. This does not essentially solve the problem and requires considerable resources.

Apart from operating on feature maps, some works normalize the weights of convolutions: Weight Standardization (Qiao et al., 2019) centralizes the weights before dividing them by their standard deviation. It still has to be combined with GN to handle small batch cases.

3 Statistics in Batch Normalization

3.1 Review of Batch Normalization

First of all, let us review the formulation of Batch Normalization (Ioffe and Szegedy, 2015). Assume the input of a BN layer is denoted as $X_t \in \mathbb{R}^{B \times C}$, where $B$ denotes the batch size and $C$ denotes the number of features. In the training procedure, the normalized feature map at iteration $t$ is computed as:

$\hat{X}_t = \dfrac{X_t - \mu_{B_t}}{\sqrt{\sigma^2_{B_t} + \epsilon}}, \qquad (1)$

where the batch statistics $\mu_{B_t}$ and $\sigma^2_{B_t}$ are the sample mean and sample variance computed over the batch of samples at iteration $t$:

$\mu_{B_t} = \dfrac{1}{B}\sum_{i=1}^{B} X_{t,i}, \qquad \sigma^2_{B_t} = \dfrac{1}{B}\sum_{i=1}^{B} (X_{t,i} - \mu_{B_t})^2. \qquad (2)$

Besides, a pair of learnable parameters $\gamma$, $\beta$ are used to scale and shift the normalized value $\hat{X}_t$:

$Y_t = \gamma \hat{X}_t + \beta. \qquad (3)$

The scaling and shifting part is included in every normalization form by default and will be omitted in the following discussion for simplicity.

As Ioffe and Szegedy (2015) demonstrated, the batch statistics are both involved in backward propagation (BP). We can derive the formulation of BP in BN as follows: let $L$ denote the loss and $\Theta_t$ denote the set of all learnable parameters of the model at iteration $t$. Given the partial gradients $\partial L / \partial \hat{X}_t$, the partial gradients $\partial L / \partial X_t$ are computed as

$\dfrac{\partial L}{\partial X_t} = \dfrac{1}{\sqrt{\sigma^2_{B_t} + \epsilon}}\left(\dfrac{\partial L}{\partial \hat{X}_t} - g_{B_t} - \hat{X}_t \odot \Psi_{B_t}\right), \qquad (4)$

where $\odot$ denotes element-wise product, and $g_{B_t}$ and $\Psi_{B_t}$ are computed as

$g_{B_t} = \dfrac{1}{B}\sum_{i=1}^{B} \dfrac{\partial L}{\partial \hat{X}_{t,i}}, \qquad \Psi_{B_t} = \dfrac{1}{B}\sum_{i=1}^{B} \dfrac{\partial L}{\partial \hat{X}_{t,i}} \odot \hat{X}_{t,i}. \qquad (5)$

It can be seen from (5) that $g_{B_t}$ and $\Psi_{B_t}$ are also batch statistics involved in BN during BP, but they have never been well discussed before.
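To make the four batch statistics concrete, here is a minimal sketch of BN's forward pass and the backward pass of (4)-(5) in the notation above (illustrative PyTorch code, not the released implementation):

import torch

def bn_forward(x, eps=1e-5):
    # x: (B, C); batch statistics are computed along the batch dimension
    mu = x.mean(dim=0)                          # mu_B
    var = x.var(dim=0, unbiased=False)          # sigma^2_B
    x_hat = (x - mu) / torch.sqrt(var + eps)
    return x_hat, mu, var

def bn_backward(dL_dxhat, x_hat, var, eps=1e-5):
    # the two extra batch statistics that only appear in backward propagation
    g = dL_dxhat.mean(dim=0)                    # g_B   = mean of dL/dX_hat
    psi = (dL_dxhat * x_hat).mean(dim=0)        # Psi_B = mean of dL/dX_hat * X_hat
    dL_dx = (dL_dxhat - g - x_hat * psi) / torch.sqrt(var + eps)   # equation (4)
    return dL_dx, g, psi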

3.2 Instability of Batch Statistics

According to Ioffe and Szegedy (2015), the ideal normalization is to normalize feature maps using the expectation and variance computed over the whole training data set:

$\hat{X}_t = \dfrac{X_t - \mathrm{E}[X_t]}{\sqrt{\mathrm{Var}[X_t] + \epsilon}}. \qquad (6)$

But this is impractical when using stochastic optimization. Therefore, Ioffe and Szegedy (2015) use mini-batches in stochastic gradient training: each mini-batch produces estimates of the mean and variance of each activation. Such simplification makes it possible to involve the mean and variance in BP. From the derivation in Section 3.1, we can see that the batch statistics $\mu_{B_t}$, $\sigma^2_{B_t}$ are Monte Carlo (MC) estimators of the population statistics $\mathrm{E}[X_t]$, $\mathrm{Var}[X_t]$ at iteration $t$. Similarly, the batch statistics $g_{B_t}$, $\Psi_{B_t}$ are MC estimators of their corresponding population statistics at iteration $t$, computed over the whole data set. They contain the information of how the mean and the variance of the population will change as the model updates, so they play an important role in trading off the change of an individual sample against that of the population. Therefore, it is crucial to estimate the population statistics precisely in order to regularize the gradients of the model properly as the weights update.

It is well known that the variance of an MC estimator is inversely proportional to the number of samples; hence the variance of batch statistics dramatically increases when the batch size is small. Figure 2 shows the change of batch statistics from a specific normalization layer of ResNet-50 during training on ImageNet. Regular batch statistics (orange line) are regarded as a good approximation of the population statistics. We can see that small batch statistics (blue line) are highly unstable and contain notable error compared with regular batch statistics during training. In fact, the bias of $g_{B_t}$ and $\Psi_{B_t}$ in BP is more serious than that of $\mu_{B_t}$ and $\sigma^2_{B_t}$ (see Figures 2(c), 2(d)). The instability of small batch statistics can worsen the capacity of the model in two aspects: first, it makes training unstable, resulting in slow convergence; second, it produces a large difference between batch statistics and population statistics. Since the model is trained using batch statistics but evaluated using population statistics, this difference causes an inconsistency between the training and inference procedures, leading to poor performance of the model on evaluation data.
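The inverse relation between estimator variance and the number of samples is easy to check numerically (a toy simulation, independent of any particular network or layer):

import torch

torch.manual_seed(0)
population = torch.randn(1_000_000) * 2.0 + 1.0      # fixed "population" of activations

for batch_size in (2, 8, 32, 128):
    # draw many mini-batches and measure the spread of the batch-mean estimator
    idx = torch.randint(0, population.numel(), (10_000, batch_size))
    batch_means = population[idx].mean(dim=1)
    # the printed variance shrinks roughly as 1 / batch_size
    print(batch_size, batch_means.var().item())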

Figure 2: Plot of batch statistics from layer1.0.bn1 in ResNet-50 during training; panels (a)-(d) show $\mu_{B_t}$, $\sigma^2_{B_t}$, $g_{B_t}$, $\Psi_{B_t}$ respectively, whose formulations are given in Section 3.1. The blue line represents the small batch statistics (batch size 2), while the orange line represents the regular batch statistics (batch size 32). The x-axis represents iterations, and the y-axis represents the norm of the statistics in each panel. Notice the means of $g_{B_t}$ and $\Psi_{B_t}$ are close to zero, hence their norms essentially represent their standard deviations.

4 Moving Average Batch Normalization

Based on the discussion in Section 3.2, the key to restoring the performance of BN is to solve the instability of small batch statistics. We therefore consider two ways to handle this instability: using moving average statistics to estimate the population statistics, and reducing the number of batch statistics by modifying the formulation of normalization.

4.1 Substituting Batch Statistics with Moving Average Statistics

Moving average statistics seem to be a suitable substitute for batch statistics to estimate population statistics when the batch is small. We consider two types of moving average statistics: simple moving average statistics (SMAS) and exponential moving average statistics (EMAS). (For a series $\{x_i\}$, the simple moving average with buffer size $m$ is $\mathrm{SMA}_t = \frac{1}{m}\sum_{i=t-m+1}^{t} x_i$, and the exponential moving average with momentum $\alpha$ is $\mathrm{EMA}_t = \alpha\,\mathrm{EMA}_{t-1} + (1-\alpha)\,x_t$.) The following theorem shows that under mild conditions, SMAS and EMAS are more stable than batch statistics:

Theorem 1

Assume there exists a sequence of random variables (r.v.) $\{\xi_t\}$ which are independent, uniformly bounded, and have uniformly bounded density. Define the simple moving average and the exponential moving average of the sequence as

$\tilde{\xi}_t = \dfrac{1}{m}\sum_{i=t-m+1}^{t}\xi_i, \qquad \bar{\xi}_t = \alpha\,\bar{\xi}_{t-1} + (1-\alpha)\,\xi_t, \qquad (7)$

where $m$ is the buffer size and $\alpha \in (0,1)$ is the momentum. If the sequence satisfies

(8)

then we have

(9)

If the sequence satisfies

(10)

then we have

(11)

The proof of Theorem 1 can be found in Appendix A.1. Theorem 1 not only proves that moving average statistics have lower variance than batch statistics, but also reveals that with a large momentum $\alpha$, EMAS is better than SMAS, having even lower variance. However, using SMAS and EMAS requires different conditions. Condition (8) means the sequence of the given statistics needs to weakly converge to a specific random variable. The statistics $\mu_{B_t}$, $\sigma^2_{B_t}$ converge to the "final" batch statistics when training finishes, hence condition (8) is satisfied and EMAS can be applied to replace $\mu_{B_t}$, $\sigma^2_{B_t}$. Unfortunately, $g_{B_t}$, $\Psi_{B_t}$ do not share the same property, so EMAS is not suitable to replace them. However, under the assumption that the learning rate is extremely small, the difference between the distributions of these statistics at consecutive iterations is tiny, thus condition (10) is satisfied and we can use SMAS to replace $g_{B_t}$, $\Psi_{B_t}$. In a word, we use EMAS $\bar{\mu}_t$, $\bar{\sigma}^2_t$ to replace $\mu_{B_t}$, $\sigma^2_{B_t}$, and SMAS $\tilde{g}_t$, $\tilde{\Psi}_t$ to replace $g_{B_t}$, $\Psi_{B_t}$ in (1) and (4), where

$\bar{\mu}_t = \alpha\,\bar{\mu}_{t-1} + (1-\alpha)\,\mu_{B_t}, \qquad \bar{\sigma}^2_t = \alpha\,\bar{\sigma}^2_{t-1} + (1-\alpha)\,\sigma^2_{B_t}, \qquad (12)$
$\tilde{g}_t = \dfrac{1}{m}\sum_{i=t-m+1}^{t} g_{B_i}, \qquad \tilde{\Psi}_t = \dfrac{1}{m}\sum_{i=t-m+1}^{t} \Psi_{B_i}. \qquad (13)$

Notice that neither SMAS nor EMAS is an unbiased substitute for the batch statistics, but the bias can be extremely small compared with the expectation and variance of the batch statistics, as proven by Equation (11) in Theorem 1. Our experiments also demonstrate the effectiveness of moving average statistics as substitutes for small batch statistics (see Figures 3 and 4 in Appendix B.1).
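The two kinds of substitutes can be sketched as small statistic trackers (hypothetical helper classes, with buffer size m and momentum alpha as in the notation above):

from collections import deque
import torch

class EMAStatistic:
    """Exponential moving average, used for the FP statistics (mu, sigma^2)."""
    def __init__(self, momentum=0.98):
        self.momentum = momentum
        self.value = None

    def update(self, batch_stat):
        if self.value is None:
            self.value = batch_stat.detach().clone()
        else:
            self.value = self.momentum * self.value + (1 - self.momentum) * batch_stat.detach()
        return self.value

class SMAStatistic:
    """Simple moving average over the last m iterations, used for the BP statistics (g, Psi)."""
    def __init__(self, buffer_size=16):
        self.buffer = deque(maxlen=buffer_size)

    def update(self, batch_stat):
        self.buffer.append(batch_stat.detach().clone())
        return torch.stack(list(self.buffer)).mean(dim=0)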

Relation to Batch Renormalization

Essentially, Batch Renormalization (BRN) (Ioffe, 2017) replaces the batch statistics $\mu_{B_t}$, $\sigma_{B_t}$ with the EMAS $\bar{\mu}_t$, $\bar{\sigma}_t$ both in FP (1) and BP (4). The formulation of BRN during training can be written as:

$\hat{X}_t = \dfrac{X_t - \mu_{B_t}}{\sqrt{\sigma^2_{B_t} + \epsilon}}\cdot r_t + d_t, \qquad (14)$

where $r_t = \mathrm{clip}\big(\sigma_{B_t}/\bar{\sigma}_t\big)$ and $d_t = \mathrm{clip}\big((\mu_{B_t}-\bar{\mu}_t)/\bar{\sigma}_t\big)$ are clipped to fixed ranges and treated as constants during backpropagation. Based on our analysis, BRN successfully eliminates the effect of the small batch statistics $\mu_{B_t}$ and $\sigma^2_{B_t}$ through EMAS, but the small batch statistics associated with gradients, $g_{B_t}$ and $\Psi_{B_t}$, remain in backward propagation, preventing BRN from completely restoring the performance of vanilla BN.

4.2 Stabilizing Normalization by Reducing the Number of Statistics

To further stabilize the training procedure in small batch cases, we consider normalizing the feature maps using the quadratic mean $\chi_{B_t}$ instead of $\mu_{B_t}$ and $\sigma^2_{B_t}$. The formulation of normalization is modified as:

$\hat{X}_t = \dfrac{X_t}{\sqrt{\chi^2_{B_t} + \epsilon}}, \qquad (15)$

where $\chi^2_{B_t} = \frac{1}{B}\sum_{i=1}^{B} X^2_{t,i}$. Given $\partial L / \partial \hat{X}_t$, the backward propagation is:

$\dfrac{\partial L}{\partial X_t} = \dfrac{1}{\sqrt{\chi^2_{B_t} + \epsilon}}\left(\dfrac{\partial L}{\partial \hat{X}_t} - \hat{X}_t \odot \Psi_{B_t}\right). \qquad (16)$
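A direct sketch of this modified forward and backward pass (illustrative code consistent with (15)-(16), not the released implementation):

import torch

def modified_norm_forward(x, eps=1e-5):
    # x: (B, C); normalize by the quadratic mean instead of centering and scaling
    chi2 = (x * x).mean(dim=0)                 # chi^2_B, the only batch statistic in FP
    x_hat = x / torch.sqrt(chi2 + eps)
    return x_hat, chi2

def modified_norm_backward(dL_dxhat, x_hat, chi2, eps=1e-5):
    psi = (dL_dxhat * x_hat).mean(dim=0)       # Psi_B, the only batch statistic in BP
    return (dL_dxhat - x_hat * psi) / torch.sqrt(chi2 + eps)     # equation (16)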

The benefits of the modification seem obvious: there are only two batch statistics left in FP and BP combined ($\chi^2_{B_t}$ in FP and $\Psi_{B_t}$ in BP), which introduces less instability into the normalization layer than the vanilla normalizing form. In fact, we can theoretically prove the benefit of the modification with the following theorem:

Theorem 2

If the following assumptions hold:

  1. the variances of the moving average statistics are negligible, so that they can be regarded as fixed numbers;

  2. the correlation between an individual sample and the batch statistics is close to 0;

  3. the expectation of the normalized feature map, $\mathrm{E}[\hat{X}_t]$, is close to 0;

Then we have:

(17)

The proof can be found in Appendix A.2. According to (17), the variance of the gradient under the vanilla normalization form is larger than that under the modified form, and the gap is mainly caused by the variance of $g_{B_t}$. So the modification essentially reduces the variance of the gradient by eliminating the batch statistic $g_{B_t}$ during BP. Since $g_{B_t}$ is a Monte Carlo estimator, the gap is inversely proportional to the batch size. This also explains why the improvement from the modification is significant in small batch cases, while the modified BN shows no superiority over vanilla BN with a sufficient batch size (see the ablation study in Section 5.1).

Centralizing weights of convolution kernels

Notice that Theorem 2 relies on assumption 3. The vanilla normalization naturally satisfies it by centralizing the feature maps, but the modified normalization does not necessarily satisfy assumption 3. To deal with this, inspired by Qiao et al. (2019), we find that centralizing the weights of the convolution kernels, named Weight Centralization (WC), can serve as a compensation for the absence of feature map centralization in practice:

$Y = (W - \mathrm{E}[W]) \ast X, \qquad (18)$

where $X$ and $Y$ are the input and output of the convolution layer respectively, $W$ is the convolution kernel, $\mathrm{E}[W]$ is the mean of its weights, and $\ast$ denotes convolution. We conduct a further ablation study to clarify the effectiveness of WC (see Table 4 in Appendix B.2). It shows that WC brings little benefit to the vanilla normalization, but it significantly improves the performance of the modified normalization. We emphasize that weight centralization is only a practical remedy for the absence of feature map centralization; the theoretical analysis remains future work.
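In code, Weight Centralization only requires subtracting each filter's mean before the convolution (a minimal sketch assuming a standard 2D convolution; the wrapper name is ours):

import torch
import torch.nn.functional as F

def centralized_conv2d(x, weight, bias=None, stride=1, padding=0):
    # weight: (C_out, C_in, kH, kW); subtract each filter's mean before convolving
    w_mean = weight.mean(dim=(1, 2, 3), keepdim=True)
    return F.conv2d(x, weight - w_mean, bias=bias, stride=stride, padding=padding)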

Clipping and renormalizing strategy.

In practice, we find that directly substituting batch statistics with moving average statistics in the normalization layer leads to collapse during training. Therefore we adopt the clipping and renormalizing strategy from BRN (Ioffe, 2017).

All in all, the formulation of the proposed method MABN during training is:

$\hat{X}_t = \dfrac{X_t}{\sqrt{\chi^2_{B_t} + \epsilon}} \cdot r_t, \qquad (19)$

$\dfrac{\partial L}{\partial X_t} = \dfrac{r_t}{\sqrt{\chi^2_{B_t} + \epsilon}}\left(\dfrac{\partial L}{\partial \hat{X}_t} - \hat{X}_t \odot \tilde{\Psi}_t\right), \qquad (20)$

where the EMAS is computed as $\bar{\chi}^2_t = \alpha\,\bar{\chi}^2_{t-1} + (1-\alpha)\,\chi^2_{B_t}$, the SMAS $\tilde{\Psi}_t$ is defined as in (13), and the renormalizing parameter $r_t$ is set as $\mathrm{clip}\big(\chi_{B_t}/\bar{\chi}_t\big)$, clipped to a fixed range and treated as a constant during backpropagation.
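How the pieces fit together can be sketched with a custom autograd function (an illustrative simplification: it keeps the EMAS of chi^2 and the SMAS of Psi, but omits the clipping/renormalizing strategy of (19)-(20) and the 4D feature map handling of the released code at https://github.com/megvii-model/MABN):

import torch
from collections import deque

class MABNFunction(torch.autograd.Function):
    """Illustrative sketch for a (B, C) input: an EMA buffer tracks chi^2 for
    inference, and an SMA buffer replaces Psi_B in backward propagation."""

    @staticmethod
    def forward(ctx, x, ema_chi2, psi_buffer, momentum, eps):
        chi2 = (x * x).mean(dim=0)                      # batch statistic chi^2_B
        x_hat = x / torch.sqrt(chi2 + eps)
        # ema_chi2 is a buffer (no gradient); it is used at inference time
        ema_chi2.mul_(momentum).add_((1 - momentum) * chi2.detach())
        ctx.save_for_backward(x_hat, chi2)
        ctx.psi_buffer, ctx.eps = psi_buffer, eps
        return x_hat

    @staticmethod
    def backward(ctx, dL_dxhat):
        x_hat, chi2 = ctx.saved_tensors
        psi_b = (dL_dxhat * x_hat).mean(dim=0)          # batch statistic Psi_B
        ctx.psi_buffer.append(psi_b.detach())           # keep the last m iterations
        psi_sma = torch.stack(list(ctx.psi_buffer)).mean(dim=0)   # SMAS substitute
        dL_dx = (dL_dxhat - x_hat * psi_sma) / torch.sqrt(chi2 + ctx.eps)
        return dL_dx, None, None, None, None

# hypothetical usage inside a module:
#   self.register_buffer("ema_chi2", torch.ones(num_features))
#   self.psi_buffer = deque(maxlen=16)
#   y = MABNFunction.apply(x, self.ema_chi2, self.psi_buffer, 0.98, 1e-5)

At inference time such a layer reduces to a per-channel scaling by $1/\sqrt{\bar{\chi}^2 + \epsilon}$, so it stays linear and can be merged into the preceding convolution exactly like vanilla BN.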

5 Experiments

This section presents the main results of MABN on ImageNet (Russakovsky et al., 2015) and COCO (Lin et al., 2014). Further experimental results on ImageNet, COCO, and Cityscapes (Cordts et al., 2016) can be found in Appendices B.2, B.3, and B.4 respectively. We also evaluate the computational overhead and memory footprint of MABN; the results are shown in Appendix B.5.

5.1 Image Classification on ImageNet

We evaluate the proposed method on the ImageNet (Russakovsky et al., 2015) classification dataset with 1000 classes. All classification experiments are conducted with ResNet-50 (He et al., 2016). More implementation details can be found in Appendix B.2.

BN (Regular)    BN (Small)    BRN (Small)    MABN (Small)
val error
(vs. BN (Regular))
Table 1: Comparison of top-1 error rate (%) of ResNet-50 on ImageNet classification. The gradient batch size is 32 per GPU. Regular means the normalization batch size is 32, while Small means the normalization batch size is 2.

Comparison with other normalization methods.

Our baselines are BN with a small (2) or regular (32) batch size, and BRN (Ioffe, 2017) with a small batch size. We do not present the performance of instance-level normalization counterparts on ImageNet, because they are not linear-type methods at inference time, and they also fail to fully restore the performance of BN, according to Wu and He (2018). Table 1 shows that vanilla BN with a small batch size severely worsens the performance of the model; BRN (Ioffe, 2017) alleviates the issue to some extent, but is still far from a complete recovery; MABN almost completely restores the performance of vanilla BN.

We also compare the performance of BN, BRN, and MABN when varying the batch size (see Figure 1). BN and BRN rely heavily on the training batch size, though BRN performs better than vanilla BN. MABN always retains the best capacity of ResNet-50, regardless of the batch size used during training.

Experiment number    Vanilla normalization    Modified normalization    EMAS in FP    SMAS in BP    Top-1 Error (%)
(1) (BN, regular)
(2) (regular)
(3) (BN)
(4) (BRN)
(5) -
(6)
(7)
(8) (MABN)
Table 2: Ablation study on ImageNet classification with ResNet-50. The normalization batch size is 2 in all experiments unless otherwise stated. When using SMAS, the momentum is 0.98; otherwise the momentum is 0.9. "-" means the training did not converge.

Ablation study on ImageNet.

We conduct ablation experiments on ImageNet to clarify the contribution of each part of MABN (see Table 2). With the vanilla normalization form, replacing the batch statistics in FP with EMAS (as in BRN) restores the performance to some extent (comparing (3) and (4)), but there is still a large gap from a complete recovery (comparing (1) and (4)). Directly using SMAS in BP on top of BRN leads to collapse during training ((5)), no matter how we tuned the hyperparameters. We believe this is due to the instability of the vanilla normalization structure in small batch cases, so we modify the formulation of normalization as shown in Section 4.2. The modified normalization even slightly outperforms BRN in small batch cases (comparing (4) and (6)). However, the modified normalization shows no superiority to the vanilla form at regular batch size (comparing (1) and (2)), which can be explained by the result of Theorem 2. With EMAS in FP, the modified normalization reduces the error rate significantly further (comparing (6) and (7)), but still fails to restore the performance completely (comparing (1) and (7)). Applying SMAS in BP finally fills the rest of the gap, almost completely restoring the performance of vanilla BN in small batch cases (comparing (1) and (8)).

5.2 Detection and Segmentation on COCO from Scratch

We conduct experiments on the Mask R-CNN (He et al., 2017) benchmark using a Feature Pyramid Network (FPN) (Lin et al., 2017a), following the basic settings in He et al. (2017). We train the networks from scratch (He et al., 2018). Only the backbone contains normalization layers. More implementation details and experimental results can be found in Appendix B.3.

BN    BRN    SyncBN    MABN
Table 3: Comparison of Average Precision (AP) of Mask R-CNN on COCO detection and segmentation. The gradient batch size is 16. The normalization batch size of SyncBN is 16, while that of BN, BRN, and MABN is 2. The momentum of BRN and MABN is 0.98, while the momentum of BN and SyncBN is 0.9. The buffer size of MABN is 16.

Table 3 shows the results of MABN compared with vanilla BN, BRN, and SyncBN (Peng et al., 2018). MABN outperforms vanilla BN and BRN by a clear margin and achieves performance comparable to SyncBN. Quite different from the ImageNet experiments, we update the parameters every single batch. Even with such a complex pipeline, MABN still achieves performance comparable to SyncBN.

6 Conclusion

This paper reveals the existence of the batch statistics $g_{B_t}$ and $\Psi_{B_t}$ involved in the backward propagation of BN, and analyzes their influence on the training process. This discovery provides a new perspective on why BN fails in small batch cases. Based on our analysis, we propose MABN to deal with the small batch training problem. MABN can completely restore the performance of vanilla BN in small batch cases, and is far more efficient at inference than counterparts such as GN. Our experiments on multiple computer vision tasks (classification, detection, segmentation) demonstrate the remarkable performance of MABN.

Acknowledgement

This research was partially supported by the National Key R&D Program of China (No. 2017YFA0700800), Beijing Academy of Artificial Intelligence (BAAI), and NSFC under Grant No. 61473091.

References

  • J. L. Ba, J. R. Kiros, and G. E. Hinton (2016) Layer normalization. arXiv preprint arXiv:1607.06450. Cited by: §1, §2.
  • L. Chen, G. Papandreou, I. Kokkinos, K. Murphy, and A. L. Yuille (2017) Deeplab: semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected crfs. IEEE transactions on pattern analysis and machine intelligence 40 (4), pp. 834–848. Cited by: §B.4.
  • M. Cordts, M. Omran, S. Ramos, T. Rehfeld, M. Enzweiler, R. Benenson, U. Franke, S. Roth, and B. Schiele (2016) The Cityscapes dataset for semantic urban scene understanding. In Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR). Cited by: §B.4, §5.
  • S. Gross and M. Wilber (2016). Cited by: §B.2.
  • K. He, R. Girshick, and P. Dollár (2018) Rethinking imagenet pre-training. arXiv preprint arXiv:1811.08883. Cited by: §5.2.
  • K. He, G. Gkioxari, P. Dollár, and R. Girshick (2017) Mask r-cnn. In Proceedings of the IEEE international conference on computer vision, pp. 2961–2969. Cited by: §B.3, §5.2.
  • K. He, X. Zhang, S. Ren, and J. Sun (2015) Delving deep into rectifiers: surpassing human-level performance on imagenet classification. In Proceedings of the IEEE international conference on computer vision, pp. 1026–1034. Cited by: §B.2.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778. Cited by: §1, §5.1.
  • S. Ioffe and C. Szegedy (2015) Batch normalization: accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pp. 448–456. Cited by: §1, §2, §3.1, §3.1, §3.2.
  • S. Ioffe (2017) Batch renormalization: towards reducing minibatch dependence in batch-normalized models. In Advances in neural information processing systems, pp. 1945–1953. Cited by: §1, §1, §2, §4.1, §4.2, §5.1.
  • T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, and S. Belongie (2017a) Feature pyramid networks for object detection. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2117–2125. Cited by: §B.3, §5.2.
  • T. Lin, P. Goyal, R. Girshick, K. He, and P. Dollár (2017b) Focal loss for dense object detection. In Proceedings of the IEEE international conference on computer vision, pp. 2980–2988. Cited by: §B.3.
  • T. Lin, M. Maire, S. Belongie, J. Hays, P. Perona, D. Ramanan, P. Dollár, and C. L. Zitnick (2014) Microsoft coco: common objects in context. In European conference on computer vision, pp. 740–755. Cited by: §1, §5.
  • C. Peng, T. Xiao, Z. Li, Y. Jiang, X. Zhang, K. Jia, G. Yu, and J. Sun (2018) Megdet: a large mini-batch object detector. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 6181–6189. Cited by: §2, §5.2.
  • S. Qiao, H. Wang, C. Liu, W. Shen, and A. Yuille (2019) Weight standardization. arXiv preprint arXiv:1903.10520. Cited by: §2, §4.2.
  • O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. Bernstein, et al. (2015) Imagenet large scale visual recognition challenge. International journal of computer vision 115 (3), pp. 211–252. Cited by: §1, §5.1, §5.
  • W. Shao, T. Meng, J. Li, R. Zhang, Y. Li, X. Wang, and P. Luo (2019) Ssn: learning sparse switchable normalization via sparsestmax. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 443–451. Cited by: §1.
  • S. Singh and A. Shrivastava (2019) EvalNorm: estimating batch normalization statistics for evaluation. arXiv preprint arXiv:1904.06031. Cited by: §1, §2.
  • D. Ulyanov, A. Vedaldi, and V. Lempitsky (2016) Instance normalization: the missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022. Cited by: §1, §2.
  • Y. Wu and K. He (2018) Group normalization. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 3–19. Cited by: §B.2, §1, §2, §5.1.
  • H. Zhao, J. Shi, X. Qi, X. Wang, and J. Jia (2017) Pyramid scene parsing network. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2881–2890. Cited by: §B.4, §B.4.

Appendix A Sketch of proof

A.1 Proof of Theorem 1

Suppose condition (8) is satisfied, i.e. $\xi_t$ weakly converges to a random variable $\xi$. Since $\xi$ has uniformly bounded density, we have:

(21)
(22)

Since the $\xi_t$ are independent, we have:

(23)

as $t \to \infty$. Hence (9) has been proven.

Now suppose condition (10) is satisfied. Since the $\xi_t$ are uniformly bounded, as $t \to \infty$ we have

(24)

Similarly, we have

(25)

Therefore, combining (24) and (25), we have

(26)

For a fixed buffer size $m$, as $t \to \infty$, we have

(27)

Therefore, (11) has been proven.

A.2 Proof of Theorem 2

Without loss of generality, consider the backward propagation of the two normalizing forms for a single input within a batch of size $B$:

(28)

where the batch statistics and the moving average statistics are defined as before; we omit the iteration subscript $t$ for simplicity. Then the variance of the partial gradients w.r.t. the inputs can be written as

(29)
(30)
(31)
(32)
(33)

where (30) is satisfied due to assumption 1: the variance of the moving average statistic is so small that it can be regarded as a fixed number; (31) is satisfied because

(34)
(35)

Due to assumption 2, the correlation between an individual sample and the batch statistics is close to 0, hence we have

(36)
(37)
(38)

Besides, $\mathrm{E}[\hat{X}]$ is close to 0 according to assumption 3, hence

(39)

(32) is satisfied due to the definitions of $g_B$ and $\Psi_B$; we have

(40)

Similarly, due to assumption 1, the variance of the other moving average statistic is also small enough to be regarded as a fixed number, so (33) is satisfied.

Appendix B Experiments

B.1 Statistics Analysis

We analyze the difference between small batch statistics (batch size 2) and regular batch statistics (batch size 32) under the modified formulation of normalization (15) shown in Section 4.2.

Figure 3: Plot of batch statistics from layer1.0.bn1 in ResNet-50 with the modified structure during training. The formulations of these batch statistics ($\chi^2_{B_t}$, $\Psi_{B_t}$) are given in Sections 4.2 and 3.1 respectively. The blue line represents the small batch statistics (2), while the orange line represents the regular batch statistics (32). We use the small batch statistics to update the network parameters.

Figure 4: Plot of batch statistics from layer1.0.bn1 in ResNet-50 with MABN. The formulations of these batch statistics are the same as in Figure 3. The blue line represents the SMA batch statistics (2+30), while the orange line represents the regular batch statistics (32). We use the moving average batch statistics to update the network parameters.

Figure 3 illustrates the change of small batch statistics and regular batch statistics in FP and BP respectively. The variance of the small batch statistics is much higher than that of the regular ones. However, when we use SMAS as an approximation of the regular batch statistics, the gap between SMAS and the regular batch statistics is not obvious, as shown in Figure 4.

B.2 Experiments on ImageNet

Implementation details.

All experiments on ImageNet are conducted across 8 GPUs. We train models with a gradient batch size of 32 images per GPU. To simulate small batch training, we split the samples on each GPU into groups, where the group size equals the normalization batch size; the batch statistics are computed within each group individually.

All convolution weights are initialized as in He et al. (2015). We initialize $\gamma$ to 1 and $\beta$ to 0 in all normalization layers. We use weight decay for all weight layers, including $\gamma$ and $\beta$ (following Wu and He (2018)). We train all models with the same iteration budget and decay the learning rate in steps. The data augmentation follows Gross and Wilber (2016). The models are evaluated by top-1 classification error on 224x224 center crops from the validation set. In vanilla BN and BRN, the momentum is 0.9, while in MABN the momentum is 0.98.
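The small-batch simulation described above amounts to computing statistics within groups along the batch dimension (a hypothetical helper, not the exact training code):

import torch

def grouped_batch_norm(x, norm_batch_size, eps=1e-5):
    # x: (N, C, H, W) samples on one GPU; N is assumed divisible by norm_batch_size.
    # Statistics are shared only within groups of norm_batch_size samples,
    # which simulates small-batch training while keeping the gradient batch size fixed.
    n, c, h, w = x.shape
    g = x.view(n // norm_batch_size, norm_batch_size, c, h, w)
    mu = g.mean(dim=(1, 3, 4), keepdim=True)
    var = g.var(dim=(1, 3, 4), unbiased=False, keepdim=True)
    return ((g - mu) / torch.sqrt(var + eps)).view(n, c, h, w)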

Additional ablation studies.

Table 4 shows additional ablation results. We test all possible combinations of the three kinds of statistics (SMAS, EMAS, BS) in FP and BP. The experimental results strongly support our theoretical analysis in Section 4. Besides, we verify the necessity of centralizing weights under the modified normalization form.

Experiment number    w/o centralizing feature maps    Centralizing weights    FP statistics    BP statistics    Top-1 Error (%)
(1)    EMAS    SMAS    23.58 (MABN)
(2)    SMAS    SMAS    26.63
(3)    EMAS    EMAS    24.83
(4)    EMAS    BS    27.03
(5)    BS    BS    29.68
(6)    EMAS    SMAS    25.45
(7)    EMAS    BS    29.57
(8)    BS    BS    32.95
(9)    BS    BS    35.22
(10)    BS    BS    34.27
(11)    BS    BS    23.35 (regular)
Table 4: Further ablation study on ImageNet with ResNet-50. The normalization batch size is 2 in all experiments. When using SMA statistics, the momentum is 0.98; otherwise the momentum is 0.9. BS means vanilla batch statistics.

B.3 Experiments on COCO

Implementation details.

We train the Mask R-CNN pipeline from scratch with MABN on 8 GPUs, with 2 images per GPU. We train our model on the COCO 2014 train and trainval35k datasets and evaluate on the COCO 2014 minival dataset. We set the momentum to 0.98 for all MABN layers. We report the standard COCO metrics AP, AP50, and AP75 for bounding box detection and their mask counterparts for instance segmentation. Other basic settings follow He et al. (2017).

MABN used on heads.

We build the Mask R-CNN baseline using a Feature Pyramid Network (FPN) (Lin et al., 2017a) backbone. The base model is ResNet-50. We use a 4conv1fc box head instead of 2fc. Both the backbone and the heads contain normalization layers, and we replace all normalization layers in each experiment. When training models with MABN, we use batch statistics in the normalization layers on the heads during the first 10,000 iterations. Table 5 shows the results. The momentum is set to 0.98 in BRN and MABN.

BN    BRN    SyncBN    MABN
Table 5: Comparison of Average Precision (AP) of Mask R-CNN on COCO detection and segmentation. The gradient batch size is 16. The normalization batch size of SyncBN is 16, while that of BN, BRN, and MABN is 2; the buffer size of MABN is 32.

Training from pretrained model.

We compare the performance of MABN and SyncBN when training the model from ImageNet pretrained weights with a 2x schedule. The results are shown in Table 6.

SyncBN    MABN
Table 6: Comparison of Average Precision (AP) of Mask R-CNN on COCO detection and segmentation. The gradient batch size is 16. The normalization batch size of SyncBN is 16, while that of BN, BRN, and MABN is 2; the buffer size of MABN is 32.

Training from scratch for one-stage model.

We also compare MABN and SyncBN on a one-stage pipeline, built on the RetinaNet (Lin et al., 2017b) benchmark. We train the model from scratch. The results are shown in Table 7.

SyncBN    MABN
Table 7: Comparison of Average Precision (AP) of RetinaNet on COCO detection. The gradient batch size is 16. The normalization batch size of SyncBN is 16, while that of MABN is 2.

All experimental results show that MABN achieves performance comparable to SyncBN and significantly outperforms BN on COCO.

B.4 Semantic Segmentation on Cityscapes

We evaluate semantic segmentation on Cityscapes (Cordts et al., 2016), which contains 5,000 high-quality, finely annotated images at pixel level, collected from 50 cities in different seasons. We conduct experiments on a PSPNet baseline and follow the basic settings in Zhao et al. (2017).

For a fair comparison, our backbone network is ResNet-101 as in Chen et al. (2017). Since we centralize the weights of the convolution kernels to use MABN, we have to re-pretrain our backbone model on the ImageNet dataset. During the fine-tuning process, we linearly increase the learning rate for 3 epochs (558 iterations) at first, then follow the "poly" learning rate schedule as in Zhao et al. (2017). Table 8 shows the results of MABN compared with vanilla BN, BRN, and SyncBN. The buffer size of MABN is 16, and the momentum of MABN and BRN is 0.98.

pretrain Top-1    mIoU
BN
BRN
SyncBN
MABN
Table 8: Results on the Cityscapes test set.

Since the statistics (mean and variance) are more stable in a pre-trained model than in a randomly initialized one, the gap between vanilla BN and SyncBN is not significant. However, MABN still outperforms vanilla BN by a clear margin, while BRN shows no obvious superiority to vanilla BN on the Cityscapes dataset.

B.5 Computational Overhead

We compare the computational overhead and memory footprint of BN, GN, and MABN. We use Mask R-CNN with ResNet-50 and FPN as the benchmark. We compute the theoretical FLOPs during inference and measure the inference speed when a single image goes through the backbone (ResNet-50 + FPN). BN and MABN can be absorbed into the preceding convolution layers during inference. GN cannot be absorbed into the convolution layers, so its FLOPs are larger than those of BN and MABN; besides, GN involves division and square-root operations during inference, so it is much slower than BN and MABN at inference time.

We also monitor the training process of Mask R-CNN on COCO (8 GPUs, 2 images per GPU) and report its memory footprint and training speed. Note that we have not optimized the implementation of MABN, so its training speed is slightly slower than that of BN and GN.

FLOPs (M)    Memory (GB)    Training Speed (iter/s)    Inference Speed (iter/s)
BN
GN
MABN
Table 9: Computational overhead and memory footprint of BN, GN and MABN.