BA^2M: A Batch Aware Attention Module for Image Classification

03/28/2021 ∙ by Qishang Cheng, et al.

The attention mechanisms have been employed in Convolutional Neural Network (CNN) to enhance the feature representation. However, existing attention mechanisms only concentrate on refining the features inside each sample and neglect the discrimination between different samples. In this paper, we propose a batch aware attention module (BA2M) for feature enrichment from a distinctive perspective. More specifically, we first get the sample-wise attention representation (SAR) by fusing the channel, local spatial and global spatial attention maps within each sample. Then, we feed the SARs of the whole batch to a normalization function to get the weights for each sample. The weights serve to distinguish the features' importance between samples in a training batch with different complexity of content. The BA2M could be embedded into different parts of CNN and optimized with the network in an end-to-end manner. The design of BA2M is lightweight with few extra parameters and calculations. We validate BA2M through extensive experiments on CIFAR-100 and ImageNet-1K for the image recognition task. The results show that BA2M can boost the performance of various network architectures and outperforms many classical attention methods. Besides, BA2M exceeds traditional methods of re-weighting samples based on the loss value.


I Introduction

Due to its powerful ability to enhance features, the attention mechanism has been widely used in designing network architectures for the image classification task. The most popular attention modules can be divided into three types: channel attention [13, 18, 12], local spatial attention [20, 19, 25, 38] and global spatial attention [2, 1]. Despite their great successes, these efforts mainly focus on exploring sample-wise attention, whose enhancements are limited to each individual sample, covering the spatial and channel dimensions. However, the optimization of networks is based on batches of images, and all images contribute equally during the optimization process. In Fig. 1 we show t-SNE [24] visualizations of features refined by different attention mechanisms, for images selected from 14 similar categories in the validation set of ImageNet. When the features of two samples from separate categories are approximately similar, the existing attention mechanisms (SE [13], BAM [25], CBAM [38], GC [2]) cannot effectively separate them in the feature space and suffer from overlapping feature clusters. For example, the red rectangles in Fig. 1 cover the feature clusters of "Greater Swiss Mountain Dog" and "Bernese Mountain Dog".

Fig. 1: t-SNE [24] visualization of the features of 700 images refined by different attention mechanisms. The feature clusters refined by ResNet-50+BA2M are more discriminative than those of ResNet-50 [9], SE [13], BAM [25], CBAM [38], and GE [12].
Fig. 2: Illustration of the proposed batch aware attention module (BA2M). The top row illustrates BA2Net. The bottom-left is the sub-module of BA2M that calculates the sample-wise attention representation (SAR) of each image. The bottom-right shows the detailed components of the three attention modules (CA, LSA, GSA).

To address the above issues, in this paper we propose a batch aware attention module (BA2M) that adaptively re-weights the CNN features of different samples within a batch to improve their discrimination in the image classification task. As shown in Fig. 1, the feature clusters of ResNet-50+BA2M are more separated and more compact than those of ResNet-50 [9], SE [13], BAM [25], CBAM [38], and GE [12], which demonstrates that BA2M helps the CNN obtain more discriminative features. Concretely, we first construct a sample-wise attention representation (SAR) for each sample based on three kinds of attention: channel attention (CA), local spatial attention (LSA), and global spatial attention (GSA). CA captures the inter-channel information of the given feature map, LSA gathers local information over a limited neighbourhood in the spatial dimension, and GSA explores long-range interactions between pixels of the feature map. Then we feed the SARs of the whole batch to a softmax layer to obtain normalized weights (Normalized-SAR), which are assigned to the corresponding samples. The weights serve to increase the difference between the features of samples from different categories.

We embed our BA2M into various network architectures and test their performance on the image classification task, including ResNet (ResNet-18/34/50/101 [9, 11], Wide-ResNet-28 [40], ResNeXt-29 [39]), DenseNet-100 [14], ShuffleNetV2 [23], MobileNetV2 [30], and EfficientNet-B0 [35]. Extensive experiments on CIFAR-100 and ImageNet-1K show that BA2M yields significant improvements with a negligible increase in complexity: it achieves a +2.93% Top-1 accuracy improvement on ImageNet-1K with ResNet-50. In particular, BA2M significantly outperforms many popular sample-wise attention works such as SE [13], SRM [18], SK [20], SGE [19], GE [12], BAM [25], CBAM [38], GC [2], AA [1] and SASA [26] in our experiments.

II Related Work

Deep feature learning. Convolutional Neural Networks (CNNs) are potent tools for extracting features in computer vision. With the rise of CNNs, many sub-fields of computer vision have been greatly advanced by the design of new network structures. From VGGNets [33] to the Inception models [34] and ResNet [9], the powerful feature extraction capability of CNNs has been demonstrated. On the one hand, many works have steadily increased the depth [11], width [40], and cardinality [39] of networks, and reformulated the information flow between layers [14]. On the other hand, many works design new modular components and activation functions, such as depth-wise convolution [30], octave convolution [4], ELU [5, 31], and PReLU [8]. These designs further improve the learning and representational properties of CNNs.

Attention for Feature Enhancement. Since not all extracted features are equally important, multiple attention mechanisms have been developed to enhance deep feature learning in the image recognition task [36, 13, 12, 18, 20, 19, 25, 38, 2, 1]; they focus on highlighting the semantic-aware features within each individual sample. In 2017, Wang et al. [36] proposed an attention residual module to boost the ability of the residual unit to generate attention-aware features. Furthermore, Li et al. [20] designed a dynamic selection mechanism that allows each neuron to adaptively attend to kernels of different sizes. GE [12] redistributes information gathered from long-range feature interactions to a local spatial extent. Hu et al. [13] proposed the SE module, which focuses on channel-wise attention by re-calibrating feature maps on channels using signals aggregated from entire feature maps. Lee et al. [18] extract statistical information from each channel and re-weight the channels via channel-independent information. BAM [25] and CBAM [38] refine features beyond the channel dimension by introducing spatial attention. Beyond the above works, Cao et al. [2] explored the combination of SE [13] and Non-Local [37] to further obtain global context information. All these attention mechanisms focus on intra-sample feature enhancement and ignore inter-sample enhancement.

III Approach

Fig. 3: Illustration of some feature maps in the second block of ResNet-50.

III-A Motivation

Network optimization is carried out in batch mode. For a batch of N images, we define the input feature of the i-th image as x_i with label y_i (i = 1, ..., N). The original softmax loss for the batch can be described as:

L = \frac{1}{N} \sum_{i=1}^{N} L_i   (1)
L_i = -\log \frac{e^{z_{y_i}}}{\sum_{j=1}^{C} e^{z_j}}   (2)

where N is the number of images in a batch, the weight of each sample in the original L is 1/N, and L_i is the loss for sample x_i. z_j denotes the j-th element of the vector of class scores z, and C is the number of classes. Equ. 1 means that all images contribute equally during the optimization process. However, due to the different complexity of image content, images should have diverse importance when calculating the loss. Focal Loss (FL) [21] and online hard example mining (OHEM) [32] show that adaptively adjusting sample weights can effectively improve the optimization process. OHEM ranks the loss values of the samples in descending order and up-weights the hard samples with high loss. FL modifies the cross-entropy loss function to pay more attention to hard samples, so that samples with large loss values are assigned larger weights, and vice versa. However, these works act on the loss directly and do not take the image content into account when determining importance.
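As a concrete reference point, the standard Focal Loss modulating factor from [21] can be sketched as follows (γ = 2 is that paper's default; the probabilities below are illustrative values only, not numbers from our experiments):

```python
import math

def focal_loss(p_t: float, gamma: float = 2.0) -> float:
    """Focal Loss for one sample: FL(p_t) = -(1 - p_t)^gamma * log(p_t).

    p_t is the model's probability for the true class; the modulating
    factor (1 - p_t)^gamma down-weights easy samples (p_t close to 1).
    """
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

def cross_entropy(p_t: float) -> float:
    """Plain cross-entropy for comparison."""
    return -math.log(p_t)

easy, hard = 0.95, 0.30
# The easy sample's loss is suppressed far more than the hard sample's.
assert focal_loss(easy) / cross_entropy(easy) < focal_loss(hard) / cross_entropy(hard)
```

The key point for our discussion is that the weighting depends only on the loss (via p_t), never on the image content itself.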

Combining with Equ. 2, Equ. 1 can be modified as follows:

L = \frac{1}{N} \sum_{i=1}^{N} -\log \frac{e^{W_{y_i}^{\top} f_i}}{\sum_{j=1}^{C} e^{W_j^{\top} f_i}}   (3)

where z_j = W_j^{\top} f_i. W is the weight matrix of the last fully connected layer and W_j is its j-th column. The bias is omitted to simplify the analysis. f_i is the generated feature fed to the classifier. Scaling the feature of each sample in Equ. 3 by a weight \alpha_i, we have the following relationship:

L \le \frac{1}{N} \sum_{i=1}^{N} -\log \frac{e^{\alpha_i W_{y_i}^{\top} f_i}}{\sum_{j=1}^{C} e^{\alpha_i W_j^{\top} f_i}} = L_{\alpha}   (4)

where inequality (32) holds if 0 < \alpha_i \le 1 and z_{y_i} is the largest class score; the proof of inequality (32) is in the supplementary material. Here \alpha_i weights the features directly instead of the loss value of each sample. Besides, L_{\alpha} is an upper bound of the original softmax loss L: if L_{\alpha} goes to 0, then L has to go to 0. That is, weighting the features can obtain better results than directly weighting the loss values. The choice of \alpha_i is also very important, and it should be related to the content of the features. Since the existing attention mechanisms generate weights based on the content of features, we generate \alpha_i with attention mechanisms.
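The direction of this inequality can be checked numerically: when the true class holds the largest logit, shrinking the logits by a factor α in (0, 1] can only increase the softmax cross-entropy, so the weighted loss upper-bounds the original one. This is a sketch with random logits, not the formal proof (which is in the supplementary material):

```python
import numpy as np

def softmax_ce(z: np.ndarray, y: int) -> float:
    """Softmax cross-entropy -log(e^{z_y} / sum_j e^{z_j})."""
    z = z - z.max()                      # shift for numerical stability
    return float(-np.log(np.exp(z[y]) / np.exp(z).sum()))

rng = np.random.default_rng(0)
for _ in range(1000):
    z = rng.normal(size=5)               # random class scores
    y = int(np.argmax(z))                # true class holds the max logit
    alpha = rng.uniform(0.05, 1.0)       # a per-sample weight in (0, 1]
    # shrinking the logits never decreases the loss
    assert softmax_ce(alpha * z, y) >= softmax_ce(z, y) - 1e-12
```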

III-B Batch Aware Attention Module (BA2M)

The attention mechanism strengthens information related to the optimization goal and suppresses irrelevant information. Existing sample-wise attention is limited to feature enrichment in the spatial or channel dimensions of the feature map. We extend attention to the batch dimension: sample-wise attention generates a content-related weight for each sample, and these weights collaboratively rectify the features of all samples:

A(F_i) = Fus( A_C(F_i), A_L(F_i), A_G(F_i) )   (5)

where F_i \in R^{C \times H \times W} is the feature map. The content of the feature is arranged in three dimensions; therefore, we decompose the attention into channel, local spatial and global spatial parts. A_C denotes the channel attention, A_L represents the local spatial attention, A_G expresses the global spatial attention, and Fus denotes the fusion described in Sec. III-B4. The three modules are calculated in parallel branches, which are shown in the bottom-left of Fig. 2. We introduce each module in detail in the following sections.

III-B1 Channel Attention Module (A_C)

First, we introduce the channel attention module. The channels contain rich feature information because the feature map is arranged as feature planes, and, as shown in Fig. 3, each plane carries different semantic information [13]. To exploit the inter-channel information, we follow the practice of SE [13] to build a channel attention module. First, we aggregate each channel of the feature map by global average pooling (GAP), producing a channel vector v \in R^C. To estimate attention across channels from v, we use two fully connected layers (W_1 and W_2). To save parameter overhead, we set the output dimension of W_1 to C/R, where the reduction ratio R controls the overhead of parameters. After W_2, we use a batch normalization (BN) layer [15] to regulate the output scale. In summary, we compute the channel attention A_C as:

A_C(F) = BN( W_2 \, \delta( W_1 \, GAP(F) ) )   (6)

where W_1 \in R^{(C/R) \times C} and W_2 \in R^{C \times (C/R)} are trainable parameters, \delta denotes the ReLU activation, and BN denotes a batch normalization operation.
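A minimal NumPy sketch of Eq. 6 follows. The BN layer and the exact initialization are omitted for brevity; w1 and w2 stand in for the trainable W_1 and W_2:

```python
import numpy as np

def channel_attention(x: np.ndarray, w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """SE-style channel attention (Eq. 6 sketch, BatchNorm omitted).

    x: (N, C, H, W) feature maps; w1: (C//R, C); w2: (C, C//R).
    Returns a per-sample, per-channel attention map of shape (N, C, 1, 1).
    """
    v = x.mean(axis=(2, 3))              # GAP over spatial dims -> (N, C)
    h = np.maximum(v @ w1.T, 0.0)        # W1 + ReLU squeeze -> (N, C//R)
    a = h @ w2.T                         # W2 excitation -> (N, C)
    return a[:, :, None, None]           # broadcastable attention map

# shape check with C = 64 and reduction R = 32 (the paper's default)
x = np.random.default_rng(0).normal(size=(4, 64, 8, 8))
w1, w2 = np.zeros((2, 64)), np.zeros((64, 2))
assert channel_attention(x, w1, w2).shape == (4, 64, 1, 1)
```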

III-B2 Local Spatial Attention Module (A_L)

Then, we introduce the local spatial attention module. The spatial dimensions usually reflect the details of the image content, and constructing a local spatial attention module can effectively highlight the areas of interest. We cascade one 3×3 convolution layer between two 1×1 convolution layers to perform local semantic spatial signal aggregation, and we again utilize a batch normalization layer to adjust the output scale. In short, we compute the local spatial attention response A_L as:

A_L(F) = BN( f_3^{1 \times 1}( f_2^{3 \times 3}( f_1^{1 \times 1}(F) ) ) )   (7)

where f denotes a convolution operation and the superscripts denote the convolutional filter sizes. We divide f_2^{3 \times 3} into channel groups to save memory. Through this spindle-like structural design, we can effectively aggregate local spatial information with a small parameter overhead.
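The 1×1 → 3×3 → 1×1 spindle can be sketched in NumPy as below. BN, grouping and the exact channel widths are simplified, and the final 1×1 projecting to a single-channel spatial map is our assumption:

```python
import numpy as np

def conv1x1(x, w):
    """1x1 convolution as a channel-mixing matmul: (N,C,H,W), (D,C) -> (N,D,H,W)."""
    return np.einsum('nchw,dc->ndhw', x, w)

def conv3x3_same(x, w):
    """Naive 3x3 convolution with zero 'same' padding; w is (D, C, 3, 3)."""
    n, c, h, wd = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)))
    out = np.zeros((n, w.shape[0], h, wd))
    for i in range(3):
        for j in range(3):
            out += np.einsum('nchw,dc->ndhw', xp[:, :, i:i + h, j:j + wd], w[:, :, i, j])
    return out

def local_spatial_attention(x, w_in, w_mid, w_out):
    """Spindle-shaped local spatial attention (Eq. 7 sketch, BN omitted):
    1x1 reduce -> 3x3 aggregate -> 1x1 project to a spatial map."""
    h = np.maximum(conv1x1(x, w_in), 0.0)       # C -> C/R
    h = np.maximum(conv3x3_same(h, w_mid), 0.0) # local context on C/R channels
    return conv1x1(h, w_out)                    # C/R -> output map

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 16, 6, 6))
a = local_spatial_attention(x, rng.normal(size=(4, 16)),    # C -> C/R
                            rng.normal(size=(4, 4, 3, 3)),  # 3x3 on C/R
                            rng.normal(size=(1, 4)))        # C/R -> 1 map
assert a.shape == (2, 1, 6, 6)
```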

III-B3 Global Spatial Attention Module (A_G)

Finally, we introduce the global spatial attention module. Global spatial attention emphasizes the long-range relationships between all pixels in the spatial extent, and it serves as a supplement to local spatial attention. Many efforts [1, 41, 2] claim that long-range interactions can make features more powerful. Inspired by the manner of extracting long-range interactions in [1], we generate the global spatial attention map A_G as follows:

A_G(F) = Softmax( Q(F) \, K(F) ) \, V(F)   (8)

where Q, K and V are convolutional operations with 1×1 kernels. After reshaping, Q \in R^{HW \times (C/R)}, K \in R^{(C/R) \times HW} and V \in R^{HW \times C}. The reshaping operations are not shown in Eqn. 8 explicitly; the specific structure can be found in the bottom-right of Fig. 2. The product QK denotes matrix multiplication, and V is processed in groups to save memory. Finally, A_G captures long-range interactions to compensate for the locality of A_L.
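A Non-Local-style NumPy sketch of Eq. 8 follows. Grouping and BN are omitted, the 1×1 convolutions are written as channel matmuls, and the weight shapes are illustrative assumptions:

```python
import numpy as np

def global_spatial_attention(x, wq, wk, wv):
    """Global spatial attention (Eq. 8 sketch).

    The softmax over all H*W positions models long-range pixel
    interactions. x: (N, C, H, W); wq, wk: (C//R, C); wv: (C, C).
    """
    n, c, h, w = x.shape
    q = np.einsum('nchw,dc->ndhw', x, wq).reshape(n, -1, h * w)  # (N, C/R, HW)
    k = np.einsum('nchw,dc->ndhw', x, wk).reshape(n, -1, h * w)  # (N, C/R, HW)
    v = np.einsum('nchw,dc->ndhw', x, wv).reshape(n, -1, h * w)  # (N, C,   HW)
    att = np.einsum('ndp,ndq->npq', q, k)            # (N, HW, HW) pairwise scores
    att = np.exp(att - att.max(axis=-1, keepdims=True))
    att = att / att.sum(axis=-1, keepdims=True)      # softmax over positions
    out = np.einsum('ncq,npq->ncp', v, att)          # aggregate values
    return out.reshape(n, c, h, w)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 4, 4))
a = global_spatial_attention(x, rng.normal(size=(2, 8)),
                             rng.normal(size=(2, 8)), rng.normal(size=(8, 8)))
assert a.shape == (2, 8, 4, 4)
```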

III-B4 Combination and Excitation for the Batch

In this section, we combine the results generated by the above sample-wise attention modules. First, we normalize the scale of each result. Then, we choose the result with the strongest activation via an element-wise maximum. Finally, we perform a dimension-reduction operation on the output of the previous step along the channel dimension. In general, the sample-wise attention representation SAR is formulated as:

SAR_i = Mean_C( Max( Norm(A_C(F_i)), Norm(A_L(F_i)), Norm(A_G(F_i)) ) )   (9)

where A_C(F_i) \in R^{C \times 1 \times 1} is broadcast to the shape of A_L(F_i), A_G(F_i) \in R^{C \times H \times W}. Max is the element-wise maximum operation, Mean_C is the dimension-reduction operation along the channel (C), and Norm normalizes the scale of each branch. Thus, each sample i obtains a compact representation SAR_i whose value is related to the complexity of its content.

Based on the above operations, we get the sample-wise attention representation (SAR) of each sample, and then we use the SARs of the whole batch to rectify each sample. The batch-wise representation is:

S = [ SAR_1, SAR_2, \ldots, SAR_N ]   (10)

where N represents the number of images in the current batch. To make inequality (32) hold, we normalize the vector of SARs in a batch, which also makes the SARs more distinguishable. We use the softmax function, as shown below; after applying it, each value lies in the interval (0, 1). Then, we multiply the resulting weights with the batch of sample feature maps:

w = Softmax(S)   (11)
w_i = \frac{e^{SAR_i}}{\sum_{j=1}^{N} e^{SAR_j}}   (12)

where w = (w_1, \ldots, w_N) is the vector of weights for the current training batch. The batch of refined outputs can be computed as:

\hat{F}_i = w_i \otimes F_i   (13)

where \otimes denotes element-wise multiplication. It means that each sample F_i is multiplied by its normalized batch-wise attention weight w_i.
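Putting Eqs. 9–13 together, the batch aware re-weighting can be sketched as follows. The per-branch normalization is simplified, and the assumption that each SAR collapses to one scalar per sample is ours:

```python
import numpy as np

def batch_aware_reweight(feats, a_c, a_l, a_g):
    """Sketch of Eqs. 9-13.

    feats: (N, C, H, W) batch of feature maps.
    a_c: (N, C, 1, 1) channel attention; a_l, a_g: (N, C, H, W) spatial
    attention maps. Branch normalization is omitted, and the SAR is
    reduced to a single scalar per sample (assumption).
    """
    fused = np.maximum(np.maximum(np.broadcast_to(a_c, a_l.shape), a_l), a_g)
    sar = fused.mean(axis=1)                          # reduce channels (Eq. 9)
    sar = sar.reshape(sar.shape[0], -1).mean(axis=1)  # one scalar per sample
    e = np.exp(sar - sar.max())
    w = e / e.sum()                                   # batch softmax (Eqs. 11-12)
    return w[:, None, None, None] * feats             # F_hat_i = w_i * F_i (Eq. 13)

rng = np.random.default_rng(0)
f = rng.normal(size=(4, 8, 5, 5))
out = batch_aware_reweight(f, rng.normal(size=(4, 8, 1, 1)),
                           rng.normal(size=(4, 8, 5, 5)),
                           rng.normal(size=(4, 8, 5, 5)))
assert out.shape == f.shape
```

Because the weights come from a softmax over the batch, a sample with complex content can receive a larger share of the weight than its neighbours.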

III-C Batch Aware Attention Network (BA2Net)

Given the variation of features across different layers in a CNN, we embed BA2M into different blocks of existing CNNs to build BA2Net. Most classical ConvNets can be represented as a list of composed blocks:

Net(x_i) = Cls( B_K( \cdots B_2( B_1( x_i ) ) ) )   (14)

where B_k (k = 1, \ldots, K, with K the number of blocks) represents a basic or residual block, as shown in Fig. 4 (a) and (b), x_i is the i-th image in the batch, and Cls represents the classifier. Activation units (e.g., ReLU [17]) are omitted in Equ. 14. We place BA2M between blocks to build BA2Net:

BA2Net(x_i) = Cls( BA2M_K( B_K( \cdots BA2M_1( B_1( x_i ) ) ) ) )   (15)

where BA2M_k applies the batch aware re-weighting after the k-th block; K is four for ResNet-50.
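The composition in Eq. 15 can be sketched with plain callables; `blocks`, `ba2m_modules` and `classifier` are hypothetical stand-ins for the network stages:

```python
def ba2net_forward(x, blocks, ba2m_modules, classifier):
    """Eq. 15 sketch: apply a BA2M module after every block/stage.

    For ResNet-50, K = len(blocks) would be 4 (one BA2M per stage).
    """
    for block, ba2m in zip(blocks, ba2m_modules):
        x = ba2m(block(x))
    return classifier(x)

# toy check with scalar "features" and identity attention modules
out = ba2net_forward(0,
                     [lambda v: v + 1, lambda v: v * 2],
                     [lambda v: v, lambda v: v],
                     lambda v: v)
assert out == 2
```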

Fig. 4: Illustration of the proposed block in BA2Net, the basic block and the residual block.

III-D Inference and Complexity Analysis

In this section, we discuss the inference behaviour and complexity of BA2M. BA2M is an adjunct to the standard training process. During training, BA2M scales the value of the output vector by weighting samples and thus influences the loss value. During inference, however, the batch aware weight only affects the absolute values of the output vector; the rank of the elements does not change, so the final prediction is unchanged. To verify this, we activate BA2M in ResNet-50 on ImageNet with different test batch sizes, where BA2M generates the SAR value for each test image using Eqn. 9 and refines the feature by multiplying it with the corresponding weight; all results are identical. In common practice we infer one image at a time, i.e., the batch size is 1, in which case exciting BA2M over the batch would only add computation. We therefore deactivate BA2M in the inference phase.
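The claim that a positive per-sample weight cannot change the predicted class follows from W(w·f) = w·(Wf); a quick check with hypothetical classifier weights:

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(10, 64))        # hypothetical 10-class classifier weights
f = rng.normal(size=64)              # one sample's feature vector
base = int(np.argmax(W @ f))
for w in (0.3, 1.0, 2.7):            # any positive batch aware weight
    # scaling all logits by w > 0 preserves their rank, hence the argmax
    assert int(np.argmax(W @ (w * f))) == base
```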

The complexity of BA2M can be divided into three parts, which we measure in FLOPs and Params, respectively. Based on [7], we have the following equations:

Params(BA2M) = Params(A_C) + Params(A_L) + Params(A_G)   (16)
FLOPs(BA2M) = FLOPs(A_C) + FLOPs(A_L) + FLOPs(A_G)   (17)

where the input feature map is of size C \times H \times W. Based on Eqns. 16 and 17, R plays an important role in controlling the complexity of BA2M; we discuss the influence of R in the ablation experiments. The FLOPs of A_G in Eqn. 17 have two parts: the time complexity of the three convolutional operations and that of the two matrix multiplications. However, the matrix multiplications introduce no extra parameters, so the parameter count of A_G comes from its convolutions only.
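To see how R trades capacity for overhead, one can count the parameters of the channel-attention branch alone; this is a back-of-the-envelope sketch in which biases and BN parameters are omitted:

```python
def ca_params(c: int, r: int) -> int:
    """Parameters of the two FC layers in Eq. 6:
    C x (C/R) for the squeeze plus (C/R) x C for the excitation."""
    return 2 * c * (c // r)

# parameter count shrinks as the reduction ratio R grows (C = 256 here)
counts = [ca_params(256, r) for r in (2, 4, 8, 16, 32)]
assert counts == sorted(counts, reverse=True)   # strictly decreasing overhead
```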

IV Experiments

In order to evaluate the effect of the proposed BA2M, we conduct experiments on two widely used image classification datasets: CIFAR-100 [16] and ImageNet-1K [29]. The PyTorch library [27] is used to implement all experiments on NVIDIA TITAN X GPUs; CIFAR-related and ImageNet-related experiments run on a single card and eight cards, respectively. For each configuration, we repeat the experiment five times with different random seeds and report the median result and the best (per run) validation performance of the architectures over time. The hyper-parameter settings are the same as in [25].

IV-A Experiments on CIFAR-100

CIFAR-100 [16] is a tiny natural-image dataset containing 100 different classes; each image is 32×32 in size, with 500 training images and 100 validation images per class. We adopt simple data augmentation strategies such as random crops, random horizontal flips, and mean-variance normalization. We perform image classification with a range of ResNet architectures and variants: ResNet-50 [9], ResNet-101 [9], PreResNet-110 [10], Wide-ResNet-28 (WRN-28) [40] and ResNeXt-29 [39]. We report the classification error on the validation set as well as model size (Params) and computational complexity (FLOPs).

Results are shown in Table I. They show that BA2M consistently improves the classification performance of the networks regardless of the network structure, with almost no increase in the amount of calculation. From these results, we conclude that BA2M is an effective method.

Architecture Params(M) FLOPs(G) Error (%)
ResNet-50 [9] 23.71 1.22 21.49
+BA2M (ours) 24.21 1.33 17.60(-3.89)
ResNet-101 [9] 42.70 2.44 20.00
+BA2M (ours) 42.32 2.55 17.88(-2.12)
PreResNet-110 [10] 1.73 0.25 22.22
+BA2M (ours) 1.86 0.25 21.41(-0.81)
WRN-28(w=8) [40] 23.40 3.36 20.40
+BA2M (ours) 23.44 3.46 17.95(-2.45)
WRN-28(w=10) [40] 36.54 5.24 18.89
+BA2M (ours) 36.59 5.39 17.30(-1.59)
ResNeXt-29 [39] 34.52 4.99 18.18
+BA2M (ours) 34.78 5.53 16.10(-2.08)
TABLE I: Image classification results on CIFAR-100 [16] across a range of ResNet architectures and variants. The results of the compared methods can be found in the corresponding papers.

IV-B Experiments on ImageNet-1K

In order to further validate the effectiveness of BA2M, in this section we perform image classification experiments on the more challenging 1000-class ImageNet dataset [29], which contains about 1.3 million training color images as well as 50,000 validation images. We use the same data augmentation strategy as [9, 11] for training and single-crop evaluation at 224×224 in the testing phase. We report FLOPs and Params for each model, as well as the top-1 and top-5 classification errors on the validation set. We use a range of ResNet architectures and their variants: ResNet-18 [9], ResNet-34 [9], Wide-ResNet-18 (WRN-18) [40], ResNeXt-50 [39] and ResNeXt-101 [39].

The results are shown in Table II. The networks with BA2M outperform the baseline networks, which is strong evidence that BA2M improves the discrimination of features on such a challenging dataset. It is worth noting that the overhead in both parameters and computation is negligible, showing that BA2M improves model capacity efficiently. When BA2M is combined with WRN, the parameter increment is somewhat larger; this is mainly because the width multiplier significantly increases the number of channels in the convolutional layers.

Architecture Params.(M) FLOPs(G) Top1-Error(%)
ResNet-18 [9] 11.69 1.81 29.60
+BA2M (ours) 12.28 1.92 28.62(-0.98)
ResNet-34 [9] 21.80 3.66 26.69
+BA2M (ours) 21.92 3.77 25.15(-1.54)
WRN-18(w=1.5) [40] 25.88 3.87 26.85
+BA2M (ours) 28.43 4.69 23.60(-3.25)
WRN-18(w=2) [40] 45.62 6.70 25.63
+BA2M (ours) 47.88 7.56 23.60(-2.03)
ResNeXt-50 [39] 25.03 3.77 22.85
+BA2M (ours) 27.25 4.47 20.98(-1.87)
ResNeXt-101 [39] 44.18 7.51 21.54
+BA2M (ours) 45.51 8.23 20.05(-1.49)
TABLE II: Image classification results on the ImageNet-1K dataset [29] across a range of ResNet architectures and variants.

IV-C Comparison with Focal Loss and OHEM

In this section, we compare BA2M with Focal Loss [21] and online hard example mining (OHEM) [32]. BA2M re-weights samples in a batch by calculating attention values for the samples; this could also mitigate the class-imbalance problem in the object detection field to some extent. Focal Loss and OHEM are specifically designed to solve the class-imbalance problem: OHEM selects candidate ROIs with large losses, while Focal Loss down-weights easily classified samples so that the model focuses on difficult samples during training. We perform experiments on the union of the PASCAL VOC 2007 trainval and PASCAL VOC 2012 trainval sets (VOC0712) and evaluate on the PASCAL VOC 2007 test set.

For Focal Loss, we adopt RetinaNet [21] as the detection method and an ImageNet pre-trained ResNet-50 as the baseline network. We then replace Focal Loss with the softmax loss and change ResNet-50 to ResNet-50 with BA2M. As shown in Table III, Focal Loss is inferior to BA2M in both Recall and mAP.

Detector Method Recall (%) mAP (%)
RetinaNet Focal Loss 96.27 79.10
RetinaNet BA2M (ours) 97.12(+0.85) 79.60(+0.50)
TABLE III: Object detection results compared with Focal Loss. The baseline detector is RetinaNet; the backbone is ResNet-50 pre-trained on ImageNet.

For OHEM, we adopt Faster R-CNN [28] as the detection method and an ImageNet pre-trained ResNet-50 as the baseline network. From the results summarized in Table IV, we can see that the Recall and mAP of BA2M are both optimal. Compared with traditional methods that re-weight samples based on loss values, the weights generated by BA2M are better.

Detector Method Recall (%) mAP (%)
Faster R-CNN - 92.97 80.10
Faster R-CNN OHEM [32] 90.57 80.50
Faster R-CNN BA2M (ours) 93.99(+1.02) 81.00(+0.90)
TABLE IV: Object detection results compared with OHEM. The baseline detector is Faster R-CNN; the backbone is ResNet-50 pre-trained on ImageNet.

IV-D Ablation Experiments

IV-D1 Comparison with Sample-wise Attention

In this section, we systematically compare BA2M with other attention works through image classification on ImageNet-1K [29]. We choose SE [13], SRM [18], SK [20], SGE [19], GE [12], BAM [25], CBAM [38], GC [2], AA [1] and SASA [26] as compared methods. These works mainly explore sample-wise attention, whose enhancements are limited within each individual sample, including the spatial and channel dimensions. We choose ResNet-50 and ResNet-101 as baseline networks and embed the attention modules at each block. All models use the same parameter settings during training.

ResNet-50 ResNet-101
Params.(M) FLOPs(G) Top1-Error(%) Top5-Error(%) Params.(M) FLOPs(G) Top1-Error(%) Top5-Error(%)
Baseline 25.56 3.86 24.56 7.50 44.55 7.57 23.38 6.88
+SE [13] 28.09 3.86 23.14 6.70 49.33 7.58 22.35 6.19
+SRM [18] 25.62 3.88 22.87 6.49 44.68 7.62 21.53 5.80
+SK [20] 26.15 4.19 22.46 6.30 45.68 7.98 21.21 5.73
+SGE [19] 25.56 4.13 22.42 6.34 44.55 7.86 21.20 5.63
+GE [12] 31.20 3.87 22.00 5.87 33.70 3.87 21.88 5.80
+BAM [25] 25.92 3.94 24.02 7.18 44.91 7.65 22.44 6.29
+CBAM [38] 28.09 3.86 22.66 6.31 49.33 7.58 21.51 5.69
+GC [2] 28.08 3.87 22.30 6.34 49.36 7.86 25.36 7.93
+AA [1] 25.80 4.15 22.30 6.20 45.40 8.05 21.30 5.60
+SASA [26] 18.00 7.20 22.40 - - - - -
+BA2M (ours) 26.21(+0.65) 4.32(+0.46) 21.63(-2.93) 5.80(-1.70) 45.87(+1.32) 8.05(+0.48) 20.85(-2.53) 5.58(-1.30)
TABLE V: Image classification results on the ImageNet-1K dataset [29] with ResNet-50 and ResNet-101 across a range of attention mechanisms: SE [13], SRM [18], GE [12], SK [20], SGE [19], BAM [25], CBAM [38], GC [2], AA [1] and SASA [26].

The results are shown in Table V. BA2M significantly improves the performance of the baselines and is superior to the sample-wise attention mechanisms, with little overhead in parameters and computation. Besides, combining Table II and Table V, we have experiments with ResNet-18/34/50/101 on ImageNet: the gains grow from 0.98% for ResNet-18 to 2.53% for ResNet-101 as the network architecture becomes deeper.

IV-D2 Cooperation with Sample-wise Attention

In this section, we perform experiments combining BA2M with sample-wise attention methods, assuming that such combinations should further improve model performance. We choose SE [13], CBAM [38] and BAM [25] as instance methods and ResNet-50 as the baseline network. The results are shown in Table VI: combining BA2M with other attention methods improves performance even more. In particular, the combination of BAM and BA2M even reaches a classification error of 17.00%.

Architecture Params.(M) FLOPs(G) Error (%)
ResNet-50 [9] 23.71 1.22 21.49
+SE [13] 26.24 1.23 20.72
+CBAM [38] 27.44 1.34 21.01
+BAM [25] 24.07 1.25 20.00
+BA2M 24.21 1.33 17.60
+SE+BA2M 26.74 1.34 17.52
+CBAM+BA2M 27.94 1.45 17.43
+BAM+BA2M 24.57 1.36 17.00
TABLE VI: Image classification results with other attention modules on CIFAR-100 [16].

IV-D3 The Position of BA2M

There are two choices for the location of BA2M in a residual-like network: embedding BA2M inside the block, as shown in Fig. 4 (c), or between blocks, as shown in Fig. 4 (d). In this section, we conduct experiments to determine the embedding mode. The results are exhibited in Table VII; the average error and standard deviation of five random runs are recorded, and the best results are in bold. We found that placing BA2M between blocks is better than inside blocks in both performance (17.60% vs 17.71%) and stability (std 0.15 vs 0.40). Thus, we place BA2M between blocks to build the different BA2Nets.

Position Error(%) Std
Inside (Fig. 4(c)) 17.71 0.40
Between (Fig. 4(d)) 17.60 0.15
TABLE VII: Results for the position of BA2M.

IV-D4 Sensitivity of Hyper-Parameters

In this section, we empirically validate our design choices. For this ablation study, we use ResNet-50 as the baseline architecture and train it on the CIFAR-100 dataset.

The size of the batch (N). We perform experiments with different batch sizes N under the memory limitation of the GPU (12 GB). The learning rate is set for the default batch size and scaled linearly for other N. Each setting runs five times, and we report the mean test error. As shown in Fig. 5, classification performance continuously improves with larger N, which indicates that BA2M is more effective when more samples are involved in the re-weighting. In our experiments, we set N to 256 to balance computing efficiency and hardware overhead.

Fig. 5: Test error vs the size of the batch. All models are trained on one GPU and the learning rate is scaled linearly.

Reduction ratio (R). We conduct experiments to determine the reduction ratio R in BA2M, which controls the number of channels in the three modules and hence the capacity and overhead of BA2M. The minimum number of channels in this ablation is 32. In Table VIII, we compare the mean performance of five random runs for R in {2, 4, 8, 16, 32}, together with the corresponding model size (Params) and computational complexity (FLOPs). Interestingly, as R increases, the computation and model size continue to decrease, while the corresponding performance first drops and then rises; when R is 32, the performance is the best. To balance model complexity and performance, we choose R = 32.

R Params(M) FLOPs(G) Error(%) Std
2 70.80 2.86 17.71 0.40
4 40.65 1.78 18.04 0.19
8 30.12 1.47 18.13 0.10
16 25.99 1.37 18.01 0.13
32 24.21 1.33 17.60 0.15
TABLE VIII: Results with different reduction ratios R in BA2M. We train ResNet-50 as the base network on CIFAR-100.

Design choices of BA2M. We perform experiments to determine the design choices of BA2M. There are three attention modules in BA2M, giving seven combinations of attention mechanisms. The results are shown in Table IX. From the second row to the fourth row, we find that LSA (A_L) plays the most crucial role in BA2M. Besides, the combination of LSA and GSA performs better than LSA alone, which demonstrates that long-range interactions can compensate for local spatial attention. Furthermore, the full version of BA2M achieves the best performance.

CA LSA GSA Error (%) Std
✓ - - 18.06 0.20
- ✓ - 17.86 0.30
- - ✓ 18.18 0.29
✓ ✓ - 17.71 0.30
✓ - ✓ 18.00 0.23
- ✓ ✓ 17.78 0.28
✓ ✓ ✓ 17.60 0.15
TABLE IX: Results for different design choices of BA2M. The last row is the result of the full BA2M.

V Conclusion

We have presented a batch aware attention module (BA2M). It provides new insight into how the attention mechanism can enhance the discriminative power of features. BA2M rectifies the features via sample-wise attention across a batch. It is a lightweight module and can be embedded between the blocks of a CNN. To verify BA2M's efficacy, we conducted extensive experiments with various state-of-the-art models. The results show that BA2M can boost the performance of the baselines and outperform other methods.

A Supplementary Material

In this supplementary material, we give the proof of inequality (4), present more results, and conduct object detection experiments to verify the effectiveness of BA2M.

B Proof of Inequality (4)

Lemma 1. For a score vector z \in R^{C} with z_y = \max_j z_j and a scalar 0 < \alpha \le 1, it holds that

\ell(z, y) \le \ell(\alpha z, y),   (18)

where

\ell(z, y) = -\log \frac{e^{z_y}}{\sum_{j=1}^{C} e^{z_j}}.   (19)

Proof. Let us define a function of \alpha as follows:

g(\alpha) = \ell(\alpha z, y) = -\alpha z_y + \log \sum_{j=1}^{C} e^{\alpha z_j},   (20)

where the softmax probabilities at scale \alpha are

p_j(\alpha) = \frac{e^{\alpha z_j}}{\sum_{k=1}^{C} e^{\alpha z_k}}.   (21)

The derivative of g is:

g'(\alpha) = -z_y + \sum_{j=1}^{C} p_j(\alpha) \, z_j.   (22)

For each element z_j, we have the following relation:

z_j \le z_y, \quad j = 1, \ldots, C.   (23)

Therefore,

\sum_{j=1}^{C} p_j(\alpha) \, z_j \le z_y \sum_{j=1}^{C} p_j(\alpha) = z_y,   (24)

where

\sum_{j=1}^{C} p_j(\alpha) = 1.   (25)

Combining Equ. 22 and Equ. 24, we get the following conclusion:

g'(\alpha) \le 0, \quad \alpha > 0,   (26)

so g is a monotonically decreasing function of \alpha. Due to 0 < \alpha \le 1 and the monotonicity from Equ. 26, we therefore get

g(1) \le g(\alpha), \quad \text{i.e.,} \quad \ell(z, y) \le \ell(\alpha z, y).   (27)

Lemma 2. For variable hold

(28)

where

(29)

proof. Let’s employ mathematical induction with the number .

STEP1

, When , The Equ. 28 gets the equal sign.

STEP2

, When , The Equ. 28 could be hold by Lemma 1.

STEP3

, if ( and ), The Equ. 28 is true. We get

(30)

When ,

(31)

Therefore, Equ. 28 holds when .

CONCLUSION

Therefore, Equ. 28 is true when .

Therefore, when we apply Lemma 2 to the loss function , we obtain the following relationship.

(32)

From Equ. 32, we know that we can re-weight the feature maps instead of the loss values.
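The stripped equations above cannot be recovered here, so the following is only a hedged numeric illustration of one plausible form of this claim: assuming the loss is softmax cross-entropy and the re-weighting acts multiplicatively on the features (logits) with a weight w ≥ 1, re-weighting the features is bounded by re-weighting the loss value, i.e. CE(w·z) ≤ w·CE(z).

```python
import numpy as np

def cross_entropy(logits, target):
    """Softmax cross-entropy for a single sample (numerically stable)."""
    z = logits - logits.max()
    return np.log(np.exp(z).sum()) - z[target]

# Spot-check CE(w*z) <= w*CE(z) for random logits and weights w >= 1.
rng = np.random.default_rng(1)
ok = all(
    cross_entropy(w * z, 0) <= w * cross_entropy(z, 0) + 1e-9
    for z, w in ((rng.normal(size=5), 1.0 + 4.0 * rng.random())
                 for _ in range(1000))
)
```

The bound is tight at w = 1, and for uniform logits `cross_entropy` returns log(n) for any target.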

-C Experiments with Lightweight Networks

We present more results with classical lightweight networks (DenseNet-100 [14], ShuffleNetv2 [23], MobileNetv2 [30], and EfficientNet-B0 [35]) on CIFAR-100. In particular, EfficientNet-B0 is obtained by Neural Architecture Search (NAS). Results are shown in Table X.

Architecture Params(M) FLOPs(G) Err (%)
DenseNet-100 [14] 0.76 0.29 21.95
+BA2M (ours) 0.78 0.30 20.63(-1.32)
ShuffleNetv2 [23] 1.3 0.05 26.8
+BA2M (ours) 1.92 0.06 25.49(-1.31)
MobileNetv2 [30] 2.26 0.07 30.17
+BA2M (ours) 2.31 0.07 26.77(-3.40)
EfficientNet-B0 [35] 2.80 0.03 36.01
+BA2M (ours) 4.18 0.03 33.98(-2.03)
TABLE X: Image classification results on CIFAR-100 with lightweight networks and a NAS network.

-D Experiments on CIFAR-100

We present more results with other attention mechanisms on CIFAR-100 in Table XI.

Architecture Params(M) FLOPs(G) Err (%)
ResNet-50 [9] 23.71 1.22 21.49
+SE [13] 26.24 1.23 20.72
+CBAM [38] 27.44 1.34 21.01
+BAM [25] 24.07 1.25 20.00
+BA2M (ours) 24.21 1.33 17.60(-3.89)
ResNet-101 [9] 42.70 2.44 20.00
+SE [13] 45.56 2.54 20.89
+CBAM [38] 49.84 2.54 20.30
+BAM [25] 43.06 2.46 19.61
+BA2M (ours) 42.32 2.55 17.88(-2.12)
PreResNet-110 [10] 1.73 0.25 22.22
+BAM [25] 1.73 0.25 21.96
+SE [13] 1.93 0.25 21.85
+BA2M (ours) 1.86 0.25 21.41(-0.81)
WRN-28(w=8) [40] 23.40 3.36 20.40
+GE [12] - - 19.74
+SE [13] 23.58 3.36 19.85
+BAM [25] 23.42 3.37 19.06
+BA2M (ours) 23.44 3.46 17.95(-2.45)
WRN-28(w=10) [40] 36.54 5.24 18.89
+GE [12] 36.30 5.20 20.20
+SE [13] 36.50 5.20 19.00
+BAM [25] 36.57 5.25 18.56
+AA [1] 36.20 5.45 18.40
+BA2M (ours) 36.59 5.39 17.30(-1.59)
ResNeXt-29 [39] 34.52 4.99 18.18
+BAM [25] 34.61 5.00 16.71
+BA2M (ours) 34.78 5.53 16.10(-2.08)
TABLE XI: Image classification results on CIFAR-100 [16] across a range of ResNet architectures and their variants. The results of the compared methods can be found in the corresponding papers.

-E Experiments on ImageNet

We present more results with other attention mechanisms on ImageNet in Table XII.

Architecture Params.(M) FLOPs(G) Top1Err(%) Top5Err(%)
ResNet-18 [9] 11.69 1.81 29.60 10.55
+SE [13] 11.78 1.81 29.41 10.22
+BAM [25] 11.71 1.82 28.88 10.10
+CBAM [38] 11.78 1.82 29.27 10.09
+BA2M (ours) 12.28 1.92 28.62(-0.98) 10.04(-0.51)
ResNet-34 [9] 21.80 3.66 26.69 8.60
+SE [13] 21.96 3.66 26.13 8.35
+BAM [25] - - 25.71 8.21
+CBAM [38] 21.96 3.67 25.99 8.24
+BA2M (ours) 21.92 3.77 25.15(-1.54) 8.13(-0.47)
WRN-18(w=1.5) [40] 25.88 3.87 26.85 8.88
+BAM [25] 25.93 3.88 26.67 8.69
+SE [13] 26.07 3.87 26.21 8.47
+CBAM [38] 26.08 3.87 26.10 8.43
+BA2M (ours) 28.43 4.69 23.60(-3.25) 7.60(-1.28)
WRN-18(w=2) [40] 45.62 6.70 25.63 8.20
+BAM [25] 45.71 6.72 25.00 7.81
+SE [13] 45.97 6.70 24.93 7.65
+CBAM [38] 45.97 6.70 24.84 7.63
+BA2M (ours) 47.88 7.56 23.60(-2.03) 7.23(-0.97)
ResNeXt-50 [39] 25.03 3.77 22.85 6.48
+BAM [25] 25.39 3.85 22.56 6.40
+SE [13] 27.56 3.77 21.91 6.04
+CBAM [38] 27.56 3.77 21.92 5.91
SKNet-50 [20] 27.50 4.47 20.79 -
+BA2M (ours) 27.25 4.47 20.98(-1.87) 5.85(-0.63)
ResNeXt-101 [39] 44.18 7.51 21.54 5.75
+SE [13] 48.96 7.51 21.17 5.66
+CBAM [38] 48.96 7.52 21.07 5.59
SKNet-101 [20] 48.90 8.46 20.19 -
+BA2M (ours) 45.51 8.23 20.05(-1.49) 5.22(-0.53)
TABLE XII: Image classification results on the ImageNet-1K dataset [29] across a range of ResNet architectures and their variants.

-F BA2M in the Network

We perform experiments to determine the placement of BA2M within a network. There are four blocks in ResNet-50, so we have 15 combinations. The results are shown in Table XIII. When inserting a single BA2M into ResNet-50, placement at the second block performs best and placement at the first block performs worst. Embedding BA2M at the fourth block pushes the performance forward further. Furthermore, ResNet-50 reaches its highest performance (17.60% error) when BA2M is embedded at all four blocks. The result reveals the rationality of calculating SARs at different layers.
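For concreteness, the 15 placement choices above come from the non-empty subsets of the four blocks; a minimal enumeration (the block names are our own labels):

```python
from itertools import combinations

blocks = ["block1", "block2", "block3", "block4"]
# Every non-empty subset of the four ResNet-50 stages is a candidate
# set of insertion points for the module: 2**4 - 1 = 15 combinations.
placements = [c for r in range(1, len(blocks) + 1)
              for c in combinations(blocks, r)]
```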

No BA2M One BA2M Two BA2Ms Three BA2Ms Four BA2Ms
Block1
Block2
Block3
Block4
Mean Err (%) 21.49 17.96 17.87 17.94 17.88 18.05 18.33 17.99 17.91 17.78 17.95 17.70 18.02 17.90 17.93 17.60
Std 0.13 0.14 0.12 0.24 0.22 0.36 0.30 0.36 0.17 0.21 0.32 0.11 0.18 0.17 0.21 0.15
TABLE XIII: Results of different combinations of BA2M within a network. We use ResNet-50 as the instance network. The average error and standard deviation over five random runs are reported, and the best results are in bold.

-G Object Detection and Segmentation

In order to further verify the generalization performance of BA2M, we further carried out object detection and instance-level segmentation on MS COCO [22]. The dataset has 80 classes. We train with the union of 118k train images and report ablations on the remaining 5k val images (minival). We train all models for 24 epochs (Lr schd=2x) using synchronized SGD with a weight decay of 0.0001 and a momentum of 0.9. The learning rate is initialized as 0.01 and decays by a factor of 10 at the 16th and 22nd epochs. We report bounding box AP at IoU thresholds from 0.5 to 0.95 and across object sizes (small (S), medium (M), large (L)). We adopt Faster R-CNN [28] and the Mask R-CNN framework [6] as our detection methods and ImageNet pre-trained ResNet-50/101 as our baseline networks. As shown in Table XIV and Table XV, the detectors embedded with BA2M are better than the baseline methods under all evaluation metrics.
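The step schedule described above (initial learning rate 0.01, divided by 10 at epochs 16 and 22 of the 24-epoch 2x schedule) can be sketched as a small helper; the function name and signature are ours, not MMDetection's API:

```python
def step_lr(epoch, base_lr=0.01, milestones=(16, 22), gamma=0.1):
    """Learning rate at a given epoch under the 2x (24-epoch) step schedule."""
    drops = sum(epoch >= m for m in milestones)   # milestones already passed
    return base_lr * gamma ** drops

# Full schedule over the 24 training epochs: 0.01 -> 0.001 -> 0.0001.
schedule = [step_lr(e) for e in range(24)]
```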

For the instance-level segmentation experiments, we used the classic Mask R-CNN framework [6] as our instance-level segmentation method and ImageNet pre-trained ResNet-50/101 as baseline networks. We report the standard COCO metrics, including AP (averaged over IoU thresholds), AP50, AP75, and APS, APM, APL (AP at different scales). Unless noted, AP is evaluated using mask IoU. The results are shown in Table XVI. Regardless of the evaluation metric, Mask R-CNN embedded with BA2M is consistently better than the baseline method. We conclude that BA2M can boost the performance of the detector through its improved feature extraction, even under the condition of a small batch size.

For the object detection and instance segmentation experiments, we utilize MMDetection [3] as the platform. All related hyper-parameters are left unchanged. Here we study the performance improvement after inserting BA2M into the backbone network. Because we use the same detection method, the performance gain can only be attributed to the feature enhancement capabilities brought by BA2M.

Detector Backbone AP AP50 AP75 APS APM APL
Faster R-CNN ResNet-50 37.60 58.90 40.90 21.80 41.50 48.80
Faster R-CNN ResNet-50 w/ BA2M 38.50(+0.9) 60.00(+1.1) 41.90(+1.0) 22.20(+0.4) 42.10(+0.6) 49.80(+1.0)
Faster R-CNN ResNet-101 39.40 60.60 43.00 22.10 43.60 52.10
Faster R-CNN ResNet-101 w/ BA2M 40.60(+1.2) 62.10(+1.5) 44.40(+1.4) 22.70(+0.6) 44.30(+0.7) 53.60(+1.5)
TABLE XIV: Object detection single-model results on COCO minival (bounding box AP). The baseline detector is Faster R-CNN with ResNet-50/101.
Detector Backbone AP AP50 AP75 APS APM APL
Mask R-CNN ResNet-50 38.30 59.70 41.30 22.50 41.70 50.40
Mask R-CNN ResNet-50 w/ BA2M 39.80(+1.50) 60.90(+1.20) 43.70(+2.40) 23.00(+0.50) 43.10(+1.40) 52.10(+1.70)
Mask R-CNN ResNet-101 40.40 61.50 44.10 22.20 44.80 52.90
Mask R-CNN ResNet-101 w/ BA2M 42.20(+1.80) 62.90(+1.40) 46.60(+2.50) 24.00(+1.80) 46.70(+1.90) 55.30(+2.40)
TABLE XV: Object detection single-model results on COCO minival (bounding box AP). The baseline is Mask R-CNN with ResNet-50/101.
Method Backbone AP AP50 AP75 APS APM APL
Mask R-CNN ResNet-50 34.80 56.10 36.90 16.10 37.20 52.90
Mask R-CNN ResNet-50 w/ BA2M 36.10(+1.30) 57.80(+1.70) 38.80(+1.90) 17.00(+0.90) 38.70(+1.50) 52.90(+0.00)
Mask R-CNN ResNet-101 36.50 58.10 39.10 18.40 40.20 50.40
Mask R-CNN ResNet-101 w/ BA2M 37.70(+1.20) 59.70(+1.60) 40.90(+1.80) 18.60(+0.20) 41.00(+0.80) 55.90(+5.50)
TABLE XVI: Instance segmentation single-model results on COCO minival (mask AP). The baseline is Mask R-CNN with ResNet-50/101.

References

  • [1] I. Bello, B. Zoph, A. Vaswani, J. Shlens, and Q. V. Le (2019) Attention augmented convolutional networks. CoRR abs/1904.09925. Cited by: TABLE XI, §I, §I, §II, §III-B3, §IV-D1, TABLE V.
  • [2] Y. Cao, J. Xu, S. Lin, F. Wei, and H. Hu (2019) GCNet: non-local networks meet squeeze-excitation networks and beyond. CoRR abs/1904.11492. Cited by: §I, §I, §II, §III-B3, §IV-D1, TABLE V.
  • [3] K. Chen, J. Wang, J. Pang, Y. Cao, Y. Xiong, X. Li, S. Sun, W. Feng, Z. Liu, J. Xu, Z. Zhang, D. Cheng, C. Zhu, T. Cheng, Q. Zhao, B. Li, X. Lu, R. Zhu, Y. Wu, J. Dai, J. Wang, J. Shi, W. Ouyang, C. C. Loy, and D. Lin (2019) MMDetection: open mmlab detection toolbox and benchmark. CoRR abs/1906.07155. Cited by: §-G.
  • [4] Y. Chen, H. Fan, B. Xu, Z. Yan, Y. Kalantidis, M. Rohrbach, S. Yan, and J. Feng (2019) Drop an octave: reducing spatial redundancy in convolutional neural networks with octave convolution. CoRR abs/1904.05049. Cited by: §II.
  • [5] D. Clevert, T. Unterthiner, and S. Hochreiter (2015) Fast and accurate deep network learning by exponential linear units (elus). CoRR abs/1511.07289. Cited by: §II.
  • [6] K. He, G. Gkioxari, P. Dollár, and R. B. Girshick (2017) Mask R-CNN. In IEEE International Conference on Computer Vision ICCV, pp. 2980–2988. Cited by: §-G, §-G.
  • [7] K. He and J. Sun (2015) Convolutional neural networks at constrained time cost. In IEEE Conference on Computer Vision and Pattern Recognition, CVPR, pp. 5353–5360. Cited by: §III-D.
  • [8] K. He, X. Zhang, S. Ren, and J. Sun (2015) Delving deep into rectifiers: surpassing human-level performance on imagenet classification. In IEEE International Conference on Computer Vision (ICCV), pp. 1026–1034. Cited by: §II.
  • [9] K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778. Cited by: TABLE XI, TABLE XII, Fig. 1, §I, §I, §II, §IV-A, §IV-B, TABLE I, TABLE II, TABLE VI.
  • [10] K. He, X. Zhang, S. Ren, and J. Sun (2016) Identity mappings in deep residual networks. In IEEE European Conference on Computer Vision (ECCV), pp. 630–645. Cited by: TABLE XI, §IV-A, TABLE I.
  • [11] K. He, X. Zhang, S. Ren, and J. Sun (2016) Identity mappings in deep residual networks. In IEEE European Conference on Computer Vision (ECCV), pp. 630–645. Cited by: §I, §II, §IV-B.
  • [12] J. Hu, L. Shen, S. Albanie, G. Sun, and A. Vedaldi (2018) Gather-excite: exploiting feature context in convolutional neural networks. In Advances in Neural Information Processing Systems (NeurIPS), pp. 9423–9433. Cited by: TABLE XI, Fig. 1, §I, §I, §I, §II, §IV-D1, TABLE V.
  • [13] J. Hu, L. Shen, and G. Sun (2018) Squeeze-and-excitation networks. In IEEE Conference on Computer Vision and Pattern Recognition CVPR, pp. 7132–7141. Cited by: TABLE XI, TABLE XII, Fig. 1, §I, §I, §I, §II, §III-B1, §IV-D1, §IV-D2, TABLE V, TABLE VI.
  • [14] G. Huang, Z. Liu, L. van der Maaten, and K. Q. Weinberger (2017) Densely connected convolutional networks. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 2261–2269. Cited by: §-C, TABLE X, §I, §II.
  • [15] S. Ioffe and C. Szegedy (2015) Batch normalization: accelerating deep network training by reducing internal covariate shift. In Proceedings of the 32nd International Conference on Machine Learning (ICML), pp. 448–456. Cited by: §III-B1.
  • [16] A. Krizhevsky and G. Hinton (2009) Learning multiple layers of features from tiny images. Master’s thesis, Department of Computer Science, University of Toronto. Cited by: TABLE XI, §IV-A, TABLE I, TABLE VI, §IV.
  • [17] A. Krizhevsky, I. Sutskever, and G. E. Hinton (2012) ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems (NIPS), pp. 1106–1114. Cited by: §III-C.
  • [18] H. Lee, H. Kim, and H. Nam (2019) SRM: a style-based recalibration module for convolutional neural networks. CoRR abs/1903.10829. Cited by: §I, §I, §II, §IV-D1, TABLE V.
  • [19] X. Li, X. Hu, and J. Yang (2019) Spatial group-wise enhance: improving semantic feature learning in convolutional networks. CoRR abs/1905.09646. Cited by: §I, §I, §II, §IV-D1, TABLE V.
  • [20] X. Li, W. Wang, X. Hu, and J. Yang (2019) Selective kernel networks. CoRR abs/1903.06586. Cited by: TABLE XII, §I, §I, §II, §IV-D1, TABLE V.
  • [21] T. Lin, P. Goyal, R. B. Girshick, K. He, and P. Dollár (2017) Focal loss for dense object detection. In IEEE International Conference on Computer Vision ICCV, pp. 2999–3007. Cited by: §III-A, §IV-C, §IV-C.
  • [22] T. Lin, M. Maire, S. J. Belongie, J. Hays, P. Perona, D. Ramanan, P. Dollár, and C. L. Zitnick (2014) Microsoft COCO: common objects in context. In IEEE European Conference on Computer Vision (ECCV), pp. 740–755. Cited by: §-G.
  • [23] N. Ma, X. Zhang, H. Zheng, and J. Sun (2018) ShuffleNet V2: practical guidelines for efficient CNN architecture design. In IEEE European Conference on Computer Vision (ECCV), pp. 122–138. Cited by: §-C, TABLE X, §I.
  • [24] L. van der Maaten and G. Hinton (2008) Visualizing data using t-SNE. Journal of Machine Learning Research, JMLR, pp. 2579–2605. Cited by: Fig. 1, §I.
  • [25] J. Park, S. Woo, J. Lee, and I. S. Kweon (2018) BAM: bottleneck attention module. In Proceedings of the British Machine Vision Conference (BMVC), pp. 147. Cited by: TABLE XI, TABLE XII, Fig. 1, §I, §I, §I, §II, §IV-D1, §IV-D2, TABLE V, TABLE VI, §IV.
  • [26] N. Parmar, P. Ramachandran, A. Vaswani, I. Bello, A. Levskaya, and J. Shlens (2019) Stand-alone self-attention in vision models. In Advances in Neural Information Processing Systems (NeurIPS), pp. 68–80. Cited by: §I, §IV-D1, TABLE V.
  • [27] A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison, L. Antiga, and A. Lerer (2017) Automatic differentiation in pytorch. Cited by: §IV.
  • [28] S. Ren, K. He, R. B. Girshick, and J. Sun (2017) Faster R-CNN: towards real-time object detection with region proposal networks. IEEE Trans. Pattern Anal. Mach. Intell., pp. 1137–1149. Cited by: §-G, §IV-C.
  • [29] O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. S. Bernstein, A. C. Berg, and F. Li (2015) ImageNet large scale visual recognition challenge. International Journal of Computer Vision (IJCV), pp. 211–252. Cited by: TABLE XII, §IV-B, §IV-D1, TABLE II, TABLE V, §IV.
  • [30] M. Sandler, A. G. Howard, M. Zhu, A. Zhmoginov, and L. Chen (2018) MobileNetV2: inverted residuals and linear bottlenecks. In IEEE Conference on Computer Vision and Pattern Recognition CVPR, pp. 4510–4520. Cited by: §-C, TABLE X, §I, §II.
  • [31] A. Shah, E. Kadam, H. Shah, and S. Shinde (2016) Deep residual networks with exponential linear unit. CoRR abs/1604.04112. Cited by: §II.
  • [32] A. Shrivastava, A. Gupta, and R. B. Girshick (2016) Training region-based object detectors with online hard example mining. In IEEE Conference on Computer Vision and Pattern Recognition CVPR, pp. 761–769. Cited by: §III-A, §IV-C, TABLE IV.
  • [33] K. Simonyan and A. Zisserman (2015) Very deep convolutional networks for large-scale image recognition. In International Conference on Learning Representations, ICLR, Cited by: §II.
  • [34] C. Szegedy, W. Liu, Y. Jia, P. Sermanet, S. E. Reed, D. Anguelov, D. Erhan, V. Vanhoucke, and A. Rabinovich (2015) Going deeper with convolutions. In IEEE Conference on Computer Vision and Pattern Recognition, CVPR, Cited by: §II.
  • [35] M. Tan and Q. V. Le (2019) EfficientNet: rethinking model scaling for convolutional neural networks. In Proceedings of the 36th International Conference on Machine Learning ICML, pp. 6105–6114. Cited by: §-C, TABLE X, §I.
  • [36] F. Wang, M. Jiang, C. Qian, S. Yang, C. Li, H. Zhang, X. Wang, and X. Tang (2017) Residual attention network for image classification. In IEEE Conference on Computer Vision and Pattern Recognition, CVPR, pp. 6450–6458. Cited by: §II.
  • [37] X. Wang, R. B. Girshick, A. Gupta, and K. He (2018) Non-local neural networks. In IEEE Conference on Computer Vision and Pattern Recognition,CVPR, pp. 7794–7803. Cited by: §II.
  • [38] S. Woo, J. Park, J. Lee, and I. S. Kweon (2018) CBAM: convolutional block attention module. In IEEE European Conference on Computer Vision (ECCV), pp. 3–19. Cited by: TABLE XI, TABLE XII, Fig. 1, §I, §I, §I, §II, §IV-D1, §IV-D2, TABLE V, TABLE VI.
  • [39] S. Xie, R. B. Girshick, P. Dollár, Z. Tu, and K. He (2017) Aggregated residual transformations for deep neural networks. In IEEE Conference on Computer Vision and Pattern Recognition CVPR, pp. 5987–5995. Cited by: TABLE XI, TABLE XII, §I, §II, §IV-A, §IV-B, TABLE I, TABLE II.
  • [40] S. Zagoruyko and N. Komodakis (2016) Wide residual networks. In Proceedings of the British Machine Vision Conference (BMVC), Cited by: TABLE XI, TABLE XII, §I, §II, §IV-A, §IV-B, TABLE I, TABLE II.
  • [41] H. Zhang, I. J. Goodfellow, D. N. Metaxas, and A. Odena (2019) Self-attention generative adversarial networks. In Proceedings of the 36th International Conference on Machine Learning ICML, pp. 7354–7363. Cited by: §III-B3.