MetaFormer is Actually What You Need for Vision

11/22/2021 ∙ by Weihao Yu, et al. ∙ National University of Singapore

Transformers have shown great potential in computer vision tasks. A common belief is that their attention-based token mixer module contributes most to their competence. However, recent works show that the attention-based module in transformers can be replaced by spatial MLPs and the resulting models still perform quite well. Based on this observation, we hypothesize that the general architecture of transformers, instead of the specific token mixer module, is more essential to the model's performance. To verify this, we deliberately replace the attention module in transformers with an embarrassingly simple spatial pooling operator to conduct only the most basic token mixing. Surprisingly, we observe that the derived model, termed PoolFormer, achieves competitive performance on multiple computer vision tasks. For example, on ImageNet-1K, PoolFormer achieves 82.1% top-1 accuracy, surpassing well-tuned vision transformer/MLP-like baselines DeiT-B/ResMLP-B24 by 0.3%/1.1% accuracy with 35%/52% fewer parameters and 48%/60% fewer MACs. The effectiveness of PoolFormer verifies our hypothesis and urges us to initiate the concept of "MetaFormer", a general architecture abstracted from transformers without specifying the token mixer. Based on extensive experiments, we argue that MetaFormer is the key player in achieving superior results for recent transformer and MLP-like models on vision tasks. This work calls for more future research dedicated to improving MetaFormer instead of focusing on the token mixer modules. Additionally, our proposed PoolFormer can serve as a starting baseline for future MetaFormer architecture design. Code is available at https://github.com/sail-sg/poolformer


1 Introduction

Transformers have gained much interest and success in the computer vision field [double_attention, stand_alone_attention, vaswani2021scaling, detr]. Since the seminal work of vision transformer (ViT) [vit], which adapts pure transformers to image classification tasks, many follow-up models have been developed to make further improvements and achieve promising performance in various computer vision tasks [deit, t2t, swin].

The transformer encoder, as shown in Figure 1(a), consists of two components. One is the attention module for mixing information among tokens, which we term the token mixer. The other component contains the remaining modules, such as channel MLPs and residual connections. By regarding the attention module as a specific token mixer, we further abstract the overall transformer into a general architecture, MetaFormer, where the token mixer is not specified, as shown in Figure 1(a).

The success of transformers has been long attributed to the attention-based token mixer [transformer]. Based on this common belief, many variants of the attention modules [convit, pvt, refiner, tnt] have been developed to improve the vision transformer. However, a very recent work [mlp-mixer] replaces the attention module completely with spatial MLPs as token mixers, and finds the derived MLP-like model can readily attain competitive performance on image classification benchmarks. The follow-up works [resmlp, gmlp, vip] further improve MLP-like models by data-efficient training and specific MLP module design, gradually narrowing the performance gap to ViT and challenging the dominance of attention as token mixers.

Some recent approaches [fnet, infinite_former, gfnet, continus_attention] explore other types of token mixers within the MetaFormer architecture and have demonstrated encouraging performance. For example, [fnet] replaces attention with the Fourier transform and still achieves around 97% of the accuracy of vanilla transformers. Taking all these results together, it seems that as long as a model adopts MetaFormer as the general architecture, promising results can be attained. We thus hypothesize that, compared with specific token mixers, MetaFormer is more essential for the model to achieve competitive performance.

To verify this hypothesis, we apply an extremely simple non-parametric operator, pooling, as the token mixer to conduct only the most basic token mixing. Astonishingly, the derived model, termed PoolFormer, achieves competitive performance and even consistently outperforms well-tuned transformer and MLP-like models, including DeiT [deit] and ResMLP [resmlp], as shown in Figure 1(b). More specifically, PoolFormer-M36 achieves 82.1% top-1 accuracy on the ImageNet-1K classification benchmark, surpassing well-tuned vision transformer/MLP-like baselines DeiT-B/ResMLP-B24 by 0.3%/1.1% accuracy with 35%/52% fewer parameters and 48%/60% fewer MACs. These results demonstrate that MetaFormer, even with a naive token mixer, can still deliver promising performance. We thus argue that MetaFormer is what we actually need for vision models, and it deserves more future research.

The contributions of our paper are two-fold. First, we abstract transformers into a general architecture, MetaFormer, and empirically demonstrate that the success of transformer/MLP-like models is largely attributed to the MetaFormer architecture. Specifically, by employing only a simple non-parametric operator, pooling, as an extremely weak token mixer for MetaFormer, we build a simple model named PoolFormer and find that it still achieves highly competitive performance. We hope our findings inspire more future research dedicated to improving MetaFormer instead of focusing on the token mixer modules. Second, we evaluate the proposed PoolFormer on multiple vision tasks, including image classification [imagenet], object detection [coco], instance segmentation [coco], and semantic segmentation [ade20k], and find that it achieves competitive performance compared with SOTA models using sophisticated token mixer designs. PoolFormer can readily serve as a good starting baseline for future MetaFormer architecture design.

2 Related work

Transformers were first proposed by [transformer] for translation tasks and then rapidly became popular in various NLP tasks. In language pre-training, transformers are trained on large-scale unlabeled text corpora and achieve remarkable performance [bert, gpt3]. Inspired by the success of transformers in NLP, many researchers have applied the attention mechanism and transformers to vision tasks [double_attention, stand_alone_attention, vaswani2021scaling, detr]. Notably, Chen et al. introduce iGPT [igpt], where the transformer is trained to auto-regressively predict pixels for self-supervised learning. Dosovitskiy et al. propose the vision transformer (ViT) with hard patch embedding as input [vit]. They show that, on supervised image classification tasks, a ViT pre-trained on a large proprietary dataset (the JFT dataset with 300 million images) can achieve excellent performance. DeiT [deit] and T2T-ViT [t2t] further demonstrate that a ViT trained from scratch on only ImageNet-1K (about 1.3 million images) can achieve promising performance. Many works have focused on improving the token mixing approach of transformers via shifted windows [swin], relative position encoding [wu2021rethinking], refined attention maps [refiner], or incorporating convolution [guo2021cmt, wu2021cvt, d2021convit], etc. In addition to attention-like token mixers, [mlp-mixer, resmlp] surprisingly find that merely adopting MLPs as token mixers can still achieve competitive performance. This discovery challenges the dominance of attention-based token mixers and has triggered a heated discussion in the research community about which token mixer is better [vip, chen2021cyclemlp]. However, the goal of this work is neither to engage in this debate nor to design new complicated token mixers to achieve a new state of the art. Instead, we examine a fundamental question: what is truly responsible for the success of transformers and their variants? Our answer is the general architecture, i.e., MetaFormer. We simply utilize pooling as a basic token mixer to probe the power of MetaFormer.

Concurrently, some works contribute to answering the same question. Dong et al. prove that, without residual connections or MLPs, the output of self-attention converges doubly exponentially to a rank-1 matrix [dong2021attention]. Raghu et al. [raghu2021vision] compare the feature differences between ViT and CNNs, finding that self-attention enables early aggregation of global information while residual connections strongly propagate features from lower to higher layers. Unfortunately, they neither abstract transformers into a general architecture nor give an explicit explanation of the origin of transformers' power.

3 Method

Figure 2: (a) The overall framework of PoolFormer. Similar to [resnet, pvt, swin], PoolFormer adopts a hierarchical architecture with 4 stages. For a model with L PoolFormer blocks, stages 1, 2, 3, and 4 contain L/6, L/6, L/2, and L/6 blocks, respectively. The feature dimension of each stage is shown in the figure. (b) The architecture of the PoolFormer block. Compared with the transformer block, it replaces attention with an extremely simple non-parametric operator, pooling, to conduct only basic token mixing.
import torch.nn as nn
class Pooling(nn.Module):
    def __init__(self, pool_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(
            pool_size, stride=1,
            padding=pool_size//2,
            count_include_pad=False,
        )
    def forward(self, x):
        # [B, C, H, W] = x.shape
        return self.pool(x) - x
Algorithm 1 Pooling for PoolFormer, PyTorch-like Code

3.1 MetaFormer

We first present the core concept of this work, "MetaFormer". As shown in Figure 1, MetaFormer is a general architecture abstracted from transformers [transformer], where the token mixer is not specified while the other components are kept the same as in transformers. The input $I$ is first processed by an input embedding, such as patch embedding for ViTs [vit]:

$$ X = \mathrm{InputEmb}(I), \tag{1} $$

where $X \in \mathbb{R}^{N \times C}$ denotes the embedding tokens with sequence length $N$ and embedding dimension $C$.

Then, the embedding tokens are fed into repeated MetaFormer blocks, each of which includes two residual sub-blocks. Specifically, the first sub-block mainly contains a token mixer to communicate information among tokens, and it can be expressed as

$$ Y = \mathrm{TokenMixer}(\mathrm{Norm}(X)) + X, \tag{2} $$

where $\mathrm{Norm}(\cdot)$ denotes a normalization such as Layer Normalization [layer_norm] or Batch Normalization [batch_norm], and $\mathrm{TokenMixer}(\cdot)$ is a module mainly working for mixing token information. It is implemented by various attention mechanisms in recent vision transformer models [vit, refiner, t2t] or by spatial MLPs in MLP-like models [mlp-mixer, resmlp]. Note that the main function of the token mixer is to propagate token information, although some token mixers, like attention, can also mix channels.

The second sub-block primarily consists of a two-layer MLP with non-linear activation:

$$ Z = \sigma\big(\mathrm{Norm}(Y)\, W_1\big)\, W_2 + Y, \tag{3} $$

where $W_1 \in \mathbb{R}^{C \times rC}$ and $W_2 \in \mathbb{R}^{rC \times C}$ are learnable parameters and $r$ is the MLP expansion ratio; $\sigma(\cdot)$ is a non-linear activation function, such as GELU [gelu] or ReLU [relu].

Instantiations of MetaFormer. MetaFormer describes a general architecture with which different models can be obtained immediately by specifying the concrete design of the token mixers. As shown in Figure 1(a), if the token mixer is specified as attention or spatial MLP, MetaFormer then becomes a transformer or MLP-like model respectively.
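To make the abstraction concrete, below is a minimal PyTorch-style sketch of a MetaFormer block with a pluggable token mixer (our own illustration; the class and argument names such as `MetaFormerBlock` and `token_mixer` are not from the official code release). Plugging in self-attention yields a transformer block, a spatial MLP yields an MLP-like block, and pooling yields the PoolFormer block of Section 3.2.

import torch.nn as nn

class MetaFormerBlock(nn.Module):
    """Sketch of one MetaFormer block (Eqs. 2-3) acting on tokens of shape [B, N, C];
    the token mixer is left unspecified and passed in as a module."""
    def __init__(self, dim, token_mixer, mlp_ratio=4, act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer            # attention, spatial MLP, pooling, ...
        self.norm2 = norm_layer(dim)
        self.mlp = nn.Sequential(                 # channel MLP of Eq. (3)
            nn.Linear(dim, dim * mlp_ratio),
            act_layer(),
            nn.Linear(dim * mlp_ratio, dim),
        )
    def forward(self, x):                         # x: [B, N, C]
        x = x + self.token_mixer(self.norm1(x))   # token-mixing sub-block, Eq. (2)
        x = x + self.mlp(self.norm2(x))           # channel-MLP sub-block, Eq. (3)
        return x

class Attention(nn.Module):
    """Illustrative attention token mixer; with it, the block becomes a transformer block."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
    def forward(self, x):
        return self.attn(x, x, x, need_weights=False)[0]

transformer_block = MetaFormerBlock(dim=384, token_mixer=Attention(384))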

3.2 PoolFormer

Since the introduction of transformers [transformer], many works have attached much importance to attention and focused on designing various attention-based token mixer components. In contrast, these works pay little attention to the general architecture, i.e., the MetaFormer.

In this work, we argue that this general MetaFormer architecture contributes most to the success of recent transformer and MLP-like models. To demonstrate this, we deliberately employ an embarrassingly simple operator, pooling, as the token mixer. This operator has no learnable parameters and simply makes each token aggregate the features of its nearby tokens by averaging.

Since this work targets vision tasks, we assume the input is in channel-first format, i.e., $T \in \mathbb{R}^{C \times H \times W}$. The pooling operator can be expressed as

$$ T'_{:, i, j} = \frac{1}{K \times K} \sum_{p, q = 1}^{K} T_{:,\, i + p - \frac{K+1}{2},\; j + q - \frac{K+1}{2}} \;-\; T_{:, i, j}, \tag{4} $$

where $K$ is the pooling size. Since the MetaFormer block already has a residual connection, subtraction of the input itself is included in Equation (4). The PyTorch-like code of the pooling operator is shown in Algorithm 1.
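To see the role of the subtraction, here is a short sanity-check sketch (ours; it assumes the `Pooling` module of Algorithm 1 is in scope and omits the normalization of Eq. (2)): once the block's residual connection adds the input back, the pooling token mixer reduces to plain local averaging.

import torch
import torch.nn as nn

pool = Pooling(pool_size=3)                 # module from Algorithm 1
x = torch.randn(2, 4, 8, 8)                 # [B, C, H, W]

# Token-mixing sub-block of Eq. (2) with the normalization left out:
mixed = x + pool(x)                         # the residual adds back the subtracted input

avg = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)(x)
assert torch.allclose(mixed, avg)           # pooling token mixer == local average pooling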

As is well known, self-attention and spatial MLPs have computational complexity quadratic in the number of tokens to mix. Even worse, spatial MLPs bring many more parameters when handling longer sequences. As a result, self-attention and spatial MLPs usually can only process hundreds of tokens. In contrast, pooling has computational complexity linear in the sequence length and no learnable parameters. Thus, we take advantage of pooling by adopting a hierarchical structure similar to traditional CNNs [alexnet, vgg, resnet] and recent hierarchical transformer variants [swin, pvt]. Figure 2 shows the overall framework of PoolFormer. Specifically, PoolFormer has 4 stages with $\frac{H}{4} \times \frac{W}{4}$, $\frac{H}{8} \times \frac{W}{8}$, $\frac{H}{16} \times \frac{W}{16}$, and $\frac{H}{32} \times \frac{W}{32}$ tokens respectively, where $H$ and $W$ represent the height and width of the input image. There are two groups of embedding sizes: 1) small-sized models with embedding dimensions of 64, 128, 320, and 512 for the four stages; 2) medium-sized models with embedding dimensions of 96, 192, 384, and 768. Assuming there are $L$ PoolFormer blocks in total, stages 1, 2, 3, and 4 contain $L/6$, $L/6$, $L/2$, and $L/6$ PoolFormer blocks, respectively. The MLP expansion ratio is set to 4. Following this simple model scaling rule, we obtain 5 PoolFormer model sizes, whose hyper-parameters are shown in Table 1.

Stage 1: H/4 × W/4 tokens
  Patch Embedding: patch size 7×7, stride 4; Embed. Dim. 64 (S) / 96 (M)
  PoolFormer Block: Pooling Size 3×3, stride 1; MLP Ratio 4
  # Blocks (S12 / S24 / S36 / M36 / M48): 2 / 4 / 6 / 6 / 8

Stage 2: H/8 × W/8 tokens
  Patch Embedding: patch size 3×3, stride 2; Embed. Dim. 128 (S) / 192 (M)
  PoolFormer Block: Pooling Size 3×3, stride 1; MLP Ratio 4
  # Blocks (S12 / S24 / S36 / M36 / M48): 2 / 4 / 6 / 6 / 8

Stage 3: H/16 × W/16 tokens
  Patch Embedding: patch size 3×3, stride 2; Embed. Dim. 320 (S) / 384 (M)
  PoolFormer Block: Pooling Size 3×3, stride 1; MLP Ratio 4
  # Blocks (S12 / S24 / S36 / M36 / M48): 6 / 12 / 18 / 18 / 24

Stage 4: H/32 × W/32 tokens
  Patch Embedding: patch size 3×3, stride 2; Embed. Dim. 512 (S) / 768 (M)
  PoolFormer Block: Pooling Size 3×3, stride 1; MLP Ratio 4
  # Blocks (S12 / S24 / S36 / M36 / M48): 2 / 4 / 6 / 6 / 8

Parameters (M), S12 / S24 / S36 / M36 / M48: 11.9 / 21.4 / 30.8 / 56.1 / 73.4
MACs (G), S12 / S24 / S36 / M36 / M48: 2.0 / 3.6 / 5.2 / 9.1 / 11.9
Table 1: Configurations of different PoolFormer models. There are two groups of embedding dimensions, i.e., small size with [64, 128, 320, 512] dimensions and medium size with [96, 192, 384, 768]. The notation "S24" means the model uses the small embedding dimensions with 24 PoolFormer blocks in total.
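The scaling rule can also be written down directly; the helper below is our own illustration (the function name is hypothetical) and simply reproduces the per-stage depths and embedding dimensions of Table 1 from the total block count L.

def poolformer_config(total_blocks, size="small"):
    """Sketch of the scaling rule: stages 1-4 receive L/6, L/6, L/2, L/6 blocks."""
    L = total_blocks
    depths = [L // 6, L // 6, L // 2, L // 6]
    dims = {"small": [64, 128, 320, 512],
            "medium": [96, 192, 384, 768]}[size]
    return depths, dims, 4                      # MLP expansion ratio is fixed to 4

# PoolFormer-S24 -> depths [4, 4, 12, 4] with dims [64, 128, 320, 512]
print(poolformer_config(24, "small"))
# PoolFormer-M48 -> depths [8, 8, 24, 8] with dims [96, 192, 384, 768]
print(poolformer_config(48, "medium"))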

4 Experiments

Figure 3: ImageNet-1K validation accuracy vs. MACs/Model Size.
General Arch. Token Mixer Outcome Model Image Size Params (M) MACs (G) Top-1 (%)
Convolutional Neural Networks
ResNet-50 [resnet] 224 26 4.1 76.2
ResNet-101 [resnet] 224 45 7.9 77.4
ResNet-152 [resnet] 224 60 11.6 78.3
RegNetY-4GF [regnet] 224 21 4.0 80.0
RegNetY-8GF [regnet] 224 39 8.0 81.7
MetaFormer Attention ViT-B/16* [vit] 224 86 17.6 79.7
ViT-L/16* [vit] 224 307 63.6 76.1
DeiT-S [deit] 224 22 4.6 79.8
DeiT-B [deit] 224 86 17.5 81.8
PVT-Tiny [pvt] 224 13 1.9 75.1
PVT-Small [pvt] 224 25 3.8 79.8
PVT-Medium [pvt] 224 44 6.7 81.2
PVT-Large [pvt] 224 61 9.8 81.7
Spatial MLP MLP-Mixer-B/16 [mlp-mixer] 224 59 12.7 76.4
ResMLP-S12 [resmlp] 224 15 3.0 76.6
ResMLP-S24 [resmlp] 224 30 6.0 79.4
ResMLP-B24 [resmlp] 224 116 23.0 81.0
Swin-Mixer-T/D24 [swin] 256 20 4.0 79.4
Swin-Mixer-T/D6 [swin] 256 23 4.0 79.7
Swin-Mixer-B/D24 [swin] 224 61 10.4 81.3
gMLP-S [gmlp] 224 20 4.5 79.6
gMLP-B [gmlp] 224 73 15.8 81.6
Pooling PoolFormer-S12 224 12 2.0 77.2
PoolFormer-S24 224 21 3.6 80.3
PoolFormer-S36 224 31 5.2 81.4
PoolFormer-M36 224 56 9.1 82.1
PoolFormer-M48 224 73 11.9 82.5
Table 2: Performance of different types of models on ImageNet-1K classification. All models are trained only on the ImageNet-1K training set, and the accuracy on the validation set is reported. * denotes results of ViT trained with extra regularization from [mlp-mixer].

4.1 Image classification

Setup. ImageNet-1K [imagenet] is one of the most widely used datasets in computer vision. It contains about 1.3M training images and 50K validation images, covering 1K common classes. Our training scheme mainly follows [deit] and [cait]. Specifically, MixUp [mixup], CutMix [cutmix], CutOut [cutout], and RandAugment [randaugment] are used for data augmentation. The models are trained for 300 epochs using the AdamW optimizer [adam, adamw] with weight decay 0.05 and a peak learning rate of lr = 1e-3 · batch size / 1024 (batch size 4096 and learning rate 4e-3 are used in this paper). The number of warmup epochs is 5, and a cosine schedule is used to decay the learning rate. Label Smoothing [label_smoothing] is set to 0.1. Dropout is disabled, but stochastic depth [stochastic_depth] and LayerScale [cait] are used to help train deep models. Group Normalization (with the group number set to 1 for simplicity) is adopted since it is preferred by PoolFormer, as shown in Section 4.4. See the appendix for more details on the hyper-parameters. Our implementation is based on the Timm codebase [timm] and the experiments are run on TPUs.
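For concreteness, the following is a minimal sketch (ours, not the Timm training script actually used) of the optimizer and schedule described above: AdamW with weight decay 0.05, a peak learning rate scaled with batch size, 5 linear warmup epochs, and cosine decay over 300 epochs.

import math
import torch

def build_optimizer_and_schedule(model, steps_per_epoch, epochs=300, warmup_epochs=5,
                                 batch_size=4096, weight_decay=0.05):
    # Peak LR scaled with batch size: lr = 1e-3 * batch_size / 1024 (see Appendix A).
    peak_lr = 1e-3 * batch_size / 1024
    optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr,
                                  weight_decay=weight_decay)
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = epochs * steps_per_epoch

    def lr_lambda(step):
        if step < warmup_steps:                                   # linear warmup
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))         # cosine decay to 0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler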

Results. Table 2 shows the performance of PoolFormers on ImageNet classification. Surprisingly, despite the simple pooling token mixer, PoolFormers still achieve highly competitive performance compared with CNNs and other MetaFormer-like models. For example, PoolFormer-S24 reaches a top-1 accuracy of more than 80% while only requiring 21M parameters and 3.6G MACs. Comparatively, the well-established ViT baseline DeiT-S [deit] attains a slightly worse accuracy of 79.8% but requires 28% more MACs (4.6G). To obtain similar accuracy, the MLP-like model ResMLP-S24 [resmlp] needs 43% more parameters (30M) as well as 67% more computation (6.0G) while only attaining 79.4% accuracy. Even compared with improved ViT and MLP-like variants [pvt, gmlp], PoolFormer still shows better performance. Specifically, the pyramid transformer PVT-Medium obtains 81.2% top-1 accuracy with 44M parameters and 6.7G MACs, while PoolFormer-S36 reaches 81.4% with 30% fewer parameters (31M) and 22% fewer MACs (5.2G). Besides, compared with the strong convolutional neural network RegNet [regnet], PoolFormer still performs favorably: with about the same 21M parameters, RegNetY-4GF [regnet] obtains 80.0% accuracy with 4.0G MACs, while PoolFormer-S24 obtains 80.3% accuracy with only 3.6G MACs.

With the pooling operator, each token aggregates the features of its nearby tokens by averaging, making it the most basic token mixing operation. Nevertheless, the experimental results show that even with this extremely simple token mixer, MetaFormer still obtains highly competitive performance. Figure 3 clearly shows that PoolFormer surpasses other models with fewer MACs and parameters. This finding conveys that the general architecture MetaFormer is actually what we need when designing vision models: by adopting MetaFormer, derived models are guaranteed to have the potential to achieve reasonable performance.

Model Params (M) AP AP_50 AP_75 AP_S AP_M AP_L
ResNet-18 [resnet] 21.3 31.8 49.6 33.6 16.3 34.3 43.2
PoolFormer-S12 21.7 36.2 56.2 38.2 20.8 39.1 48.0
ResNet-50 [resnet] 37.7 36.3 55.3 38.6 19.3 40.0 48.8
PoolFormer-S24 31.1 38.9 59.7 41.3 23.3 42.1 51.8
ResNet-101 [resnet] 56.7 38.5 57.8 41.2 21.4 42.6 51.1
PoolFormer-S36 40.6 39.5 60.5 41.8 22.5 42.9 52.4
Table 3: Performance of object detection on COCO val2017 [coco]. All models are based on RetinaNet and trained with the 1× training schedule (i.e., 12 epochs).
Model Params (M) AP^b AP^b_50 AP^b_75 AP^m AP^m_50 AP^m_75
ResNet-18 [resnet] 31.2 34.0 54.0 36.7 31.2 51.0 32.7
PoolFormer-S12 31.6 37.3 59.0 40.1 34.6 55.8 36.9
ResNet-50 [resnet] 44.2 38.0 58.6 41.4 34.4 55.1 36.7
PoolFormer-S24 41.0 40.1 62.2 43.4 37.0 59.1 39.6
ResNet-101 [resnet] 63.2 40.4 61.1 44.2 36.4 57.7 38.8
PoolFormer-S36 50.5 41.0 63.1 44.8 37.7 60.1 40.0
Table 4: Performance of object detection and instance segmentation on COCO val2017 [coco]. AP^b and AP^m represent bounding box AP and mask AP, respectively. All models are based on Mask R-CNN and trained with the 1× training schedule (i.e., 12 epochs).
Model Params (M) mIoU (%)
ResNet-18 [resnet] 15.5 32.9
PVT-Tiny [pvt] 17.0 35.7
PoolFormer-S12 15.7 37.2
ResNet-50 [resnet] 28.5 36.7
PVT-Small [pvt] 28.2 39.8
PoolFormer-S24 23.2 40.3
ResNet-101 [resnet] 47.5 38.8
ResNeXt-101-32x4d [xie2017aggregated] 47.1 39.7
PVT-Medium [pvt] 48.0 41.6
PoolFormer-S36 34.6 42.0
PVT-Large [pvt] 65.1 42.1
PoolFormer-M36 59.8 42.4
ResNeXt-101-64x4d [xie2017aggregated] 86.4 40.2
PoolFormer-M48 77.1 42.7
Table 5: Performance of semantic segmentation on the ADE20K [ade20k] validation set. All models are equipped with Semantic FPN [fpn].

4.2 Object detection and instance segmentation

Setup. We evaluate PoolFormer on the challenging COCO benchmark [coco], which includes 118K training images (train2017) and 5K validation images (val2017). The models are trained on the training set and the performance on the validation set is reported. PoolFormer is employed as the backbone for two standard detectors, i.e., RetinaNet [retinanet] and Mask R-CNN [mask_rcnn]. ImageNet pre-trained weights are utilized to initialize the backbones and Xavier initialization [glorot2010understanding] to initialize the added layers. AdamW [adam, adamw] is adopted for training with an initial learning rate of 1e-4 and a batch size of 16. Following [retinanet, mask_rcnn], we employ the 1× training schedule, i.e., training the detection models for 12 epochs. The training images are resized so that the shorter side is 800 pixels and the longer side is no more than 1,333 pixels. For testing, the shorter side of the images is also resized to 800 pixels. The implementation is based on the mmdetection [mmdetection] codebase and the experiments are run on 8 NVIDIA A100 GPUs.
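As a rough sketch of what such a setup looks like as an mmdetection config override (ours; the backbone registry name poolformer_s12_feat, the checkpoint path, and the weight-decay value are hypothetical placeholders rather than the exact config files released with the paper):

# retinanet_poolformer_s12_fpn_1x_coco.py (hypothetical config file)
_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py',     # the 1x schedule, i.e., 12 epochs
    '../_base_/default_runtime.py',
]

model = dict(
    backbone=dict(
        _delete_=True,
        type='poolformer_s12_feat',           # hypothetical registry name of the backbone
        init_cfg=dict(type='Pretrained',
                      checkpoint='path/to/imagenet_pretrained.pth'),
    ),
    neck=dict(in_channels=[64, 128, 320, 512]),  # PoolFormer-S per-stage feature dims
)

# AdamW with the initial learning rate described above; the weight decay is illustrative.
optimizer = dict(_delete_=True, type='AdamW', lr=1e-4, weight_decay=0.0001)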

Results. Equipped with RetinaNet for object detection, PoolFormer-based models consistently outperform their comparable ResNet counterparts, as shown in Table 3. For instance, PoolFormer-S12 achieves 36.2 AP, largely surpassing that of ResNet-18 (31.8 AP). Similar results are observed for the models based on Mask R-CNN for object detection and instance segmentation. For example, PoolFormer-S12 largely surpasses ResNet-18 (bounding box AP 37.3 vs. 34.0, and mask AP 34.6 vs. 31.2). Overall, for COCO object detection and instance segmentation, PoolFormers achieve competitive performance, consistently outperforming their ResNet counterparts.

Ablation Variant Params (M) MACs (G) Top-1 (%)
Baseline None (PoolFormer-S12) 11.9 2.0 77.2
Pooling Pooling → Identity mapping 11.9 2.0 74.3
Pooling size 3 → 5 11.9 2.0 77.2
Pooling size 3 → 7 11.9 2.0 77.1
Pooling size 3 → 9 11.9 2.0 76.8
Normalization Group Normalization [group_norm] → Layer Normalization [layer_norm] 11.9 2.0 76.5
Group Normalization [group_norm] → Batch Normalization [batch_norm] 11.9 2.0 76.4
Activation GELU [gelu] → ReLU [relu] 11.9 2.0 76.4
GELU → SiLU [silu] 11.9 2.0 77.2
Hybrid Stages [Pool, Pool, Pool, Pool] → [Pool, Pool, Pool, Attention] 14.0 2.1 78.3
[Pool, Pool, Pool, Pool] → [Pool, Pool, Attention, Attention] 16.5 2.7 81.0
[Pool, Pool, Pool, Pool] → [Pool, Pool, Pool, SpatialFC] 11.9 2.0 77.5
[Pool, Pool, Pool, Pool] → [Pool, Pool, SpatialFC, SpatialFC] 12.2 2.1 77.9
Table 6: Ablation for PoolFormer on ImageNet-1K classification benchmark. PoolFormer-S12 is utilized as the baseline to conduct ablation study. The top-1 accuracy on the validation set is reported.

4.3 Semantic segmentation

Setup. ADE20K [ade20k], a challenging scene parsing benchmark, is selected to evaluate the models for semantic segmentation. The dataset includes 20K and 2K images in the training and validation sets, respectively, covering 150 fine-grained semantic categories. PoolFormers are evaluated as backbones equipped with Semantic FPN [fpn]. ImageNet-1K trained checkpoints are used to initialize the backbones, while Xavier initialization [glorot2010understanding] is utilized for the other newly added layers. Common practice [fpn, chen2017deeplab] trains models for 80K iterations with a batch size of 16; to speed up training, we double the batch size to 32 and decrease the number of iterations to 40K. The AdamW optimizer [adam, adamw] is employed with an initial learning rate that is decayed with a polynomial schedule with a power of 0.9. Images are resized and cropped to 512×512 for training and resized so that the shorter side is 512 pixels for testing. Our implementation is based on the mmsegmentation [mmseg2020] codebase and the experiments are conducted on 8 NVIDIA A100 GPUs.
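The polynomial schedule is easy to state in code; below is a minimal PyTorch sketch (ours, not the mmsegmentation implementation) of a power-0.9 polynomial decay over the 40K-iteration budget. The initial learning rate used here is only an illustrative placeholder.

import torch

def poly_lr_lambda(total_iters=40_000, power=0.9):
    """Polynomial decay factor (1 - iter/total_iters)^power, as described above."""
    def lr_lambda(it):
        return max(0.0, 1.0 - it / total_iters) ** power
    return lr_lambda

# Attach the schedule to AdamW; lr below is illustrative, not the paper's exact value.
params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for model.parameters()
optimizer = torch.optim.AdamW(params, lr=2e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, poly_lr_lambda())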

Results. Table 5 shows the ADE20K semantic segmentation performance of different backbones using Semantic FPN [fpn]. PoolFormer-based models consistently outperform models with CNN-based ResNet [resnet] and ResNeXt [xie2017aggregated] backbones as well as the transformer-based PVT. For instance, PoolFormer-S12 achieves an mIoU of 37.2, which is 4.3 and 1.5 higher than those of ResNet-18 and PVT-Tiny, respectively.

These results demonstrate that PoolFormer, serving as a backbone, can attain competitive performance on semantic segmentation even though it only utilizes pooling for basic communication among tokens. This further indicates the great potential of MetaFormer and supports our claim that MetaFormer is actually what we need.

4.4 Ablation studies

The ablation studies are conducted on ImageNet-1K [imagenet]. Table 6 reports the ablation results for PoolFormer, which we discuss below according to the following aspects.

Pooling. Compared with transformers, the main change made by PoolFormer is using simple pooling as the token mixer. We first ablate this operator by directly replacing pooling with identity mapping. Surprisingly, MetaFormer with identity mapping still achieves 74.3% top-1 accuracy, supporting the claim that MetaFormer is actually what we need to guarantee reasonable performance.

We then test the effect of the pooling size on PoolFormer. We observe similar performance for pooling sizes of 3, 5, and 7. However, when the pooling size increases to 9, there is an obvious performance drop of 0.5%. Thus, we adopt the default pooling size of 3 for PoolFormer.

Normalization. Three types of normalization are examined for PoolFormer, i.e., Group Normalization [group_norm] (with the group number set to 1 for simplicity), Layer Normalization [layer_norm], and Batch Normalization [batch_norm]. We find that PoolFormer prefers Group Normalization, which yields 0.7% or 0.8% higher accuracy than Layer Normalization or Batch Normalization, respectively. Thus, Group Normalization is set as the default for PoolFormer.

Activation. We change GELU [gelu] to ReLU [relu] or SiLU [silu]. When ReLU is adopted as the activation, an obvious performance drop of 0.8% is observed. For SiLU, the performance is almost the same as that of GELU. Thus, we keep GELU as the default activation.

Hybrid stages. Among token mixers based on pooling, attention, and spatial MLP, the pooling-based one can handle much longer input sequences, while attention and spatial MLP are good at capturing global information. Therefore, it is intuitive to stack MetaFormer blocks with pooling in the bottom stages to handle the long sequences and use attention- or spatial MLP-based mixers in the top stages, where the sequences have already been largely shortened. Thus, we replace the pooling token mixer with attention or spatial FC (following [resmlp], we use only one spatial fully connected layer as the token mixer, so we call it FC) in the top one or two stages of PoolFormer. From Table 6, the hybrid models perform quite well. The variant with pooling in the bottom two stages and attention in the top two stages delivers highly competitive performance: it achieves 81.0% accuracy with only 16.5M parameters and 2.7G MACs. As a comparison, ResMLP-B24 needs 7.0× the parameters (116M) and 8.5× the MACs (23.0G) to achieve the same accuracy. These results indicate that combining pooling with other token mixers in MetaFormer may be a promising direction to further improve performance; a sketch of such a hybrid assembly is given below.
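The sketch below is our own illustration of such a hybrid assembly: it reuses the Pooling module of Algorithm 1, and the AttentionMixer wrapper around nn.MultiheadAttention is hypothetical rather than the exact module used for Table 6.

import torch
import torch.nn as nn

class AttentionMixer(nn.Module):
    """Attention token mixer on [B, C, H, W] features: flatten to tokens,
    run multi-head self-attention, and reshape back."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
    def forward(self, x):                        # x: [B, C, H, W]
        B, C, H, W = x.shape
        tokens = x.flatten(2).transpose(1, 2)    # [B, H*W, C]
        out = self.attn(tokens, tokens, tokens, need_weights=False)[0]
        return out.transpose(1, 2).reshape(B, C, H, W)

def build_stage_mixers(dims, kinds=('pool', 'pool', 'attention', 'attention')):
    """Hybrid variant of Table 6: pooling in the bottom stages (long sequences),
    attention in the top stages (much shorter sequences)."""
    return nn.ModuleList(
        Pooling(pool_size=3) if kind == 'pool' else AttentionMixer(dim)
        for dim, kind in zip(dims, kinds))

# e.g., with the small per-stage embedding dimensions [64, 128, 320, 512]
mixers = build_stage_mixers([64, 128, 320, 512])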

5 Conclusion and future work

In this work, we abstracted the attention module in transformers as a token mixer, and the overall transformer as a general architecture termed MetaFormer, where the token mixer is not specified. Instead of focusing on specific token mixers, we pointed out that MetaFormer is actually what we need to guarantee reasonable performance. To verify this, we deliberately specified the token mixer as extremely simple pooling for MetaFormer. We found that the derived PoolFormer model achieves competitive performance on different vision tasks, which well supports the claim that "MetaFormer is actually what you need for vision".

In the future, we will further evaluate PoolFormer under more learning settings, such as self-supervised learning and transfer learning. Moreover, it would be interesting to see whether PoolFormer also works on NLP tasks, which would further support the claim that "MetaFormer is actually what you need" in the NLP domain. We hope that this work inspires more future research devoted to improving the fundamental MetaFormer architecture instead of paying too much attention to the token mixer modules.

Acknowledgement

The authors would like to thank Quanhong Fu at Sea AI Lab for help in improving the technical writing of this paper. Weihao Yu would like to thank the TPU Research Cloud (TRC) program for providing partial computational resources.

References

Appendix A Detailed hyper-parameters on ImageNet-1K

PoolFormer. On the ImageNet-1K classification benchmark, we utilize the hyper-parameters shown in Table 7 to train most models in our paper. According to the relation between batch size and learning rate in Table 7, we set the batch size to 4096 and the learning rate to 4e-3 to run the experiments in this paper. For stochastic depth, following the original paper [stochastic_depth], we linearly increase the probability of dropping a layer from 0.0 for the bottom block to the peak drop rate (see Table 7) for the top block.

Hybrid Models. We use the hyper-parameters in Table 7 for most models, except for the hybrid models whose token mixers are pooling and attention. For these hybrid models, we find that much better performance can be achieved by setting the batch size to 1024, the learning rate to 1e-3 (following the relation in Table 7), and the normalization to Layer Normalization [layer_norm].

Appendix B Model definition in PyTorch

We provide the PyTorch-like code in Algorithm 2 associated with the modules used in the PoolFormer block. Algorithm 3 further shows the PoolFormer block built with these modules.

PoolFormer
S12 S24 S36 M36 M48
Peak drop rate of stoch. depth 0.1 0.1 0.2 0.3 0.4
LayerScale initialization 1e-5
Data augmentation RandAugment
Repeated Augmentation off
Input resolution 224
Epochs 300
Warmup epochs 5
Hidden dropout 0
GeLU dropout 0
Attention dropout (if applicable) 0
Classification dropout 0
Random erasing prob 0.25
EMA decay 0
Cutmix 1.0
Mixup 0.8
Cutmix-Mixup switch prob 0.5
Label smoothing 0.1
Relation between peak learning rate and batch size lr = 1e-3 · batch size / 1024
Batch size used in the paper 4096
Peak learning rate used in the paper 4e-3
Learning rate decay cosine
Optimizer AdamW
Adam ε 1e-8
Adam (β1, β2) (0.9, 0.999)
Weight decay 0.05
Gradient clipping None
Table 7: Hyper-parameters for Image classification on ImageNet-1K
import torch.nn as nn
from timm.models.layers import trunc_normal_

class GroupNorm(nn.GroupNorm):
    """
    Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)

class Pooling(nn.Module):
    """
    Implementation of pooling for PoolFormer
    --pool_size: pooling size
    """
    def __init__(self, pool_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size//2, count_include_pad=False)
    def forward(self, x):
        # Subtract the input; the block's residual connection adds it back (Eq. 4).
        return self.pool(x) - x

class Mlp(nn.Module):
    """
    Implementation of MLP with 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    """
    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
Algorithm 2 Modules for PoolFormer block, PyTorch-like Code
import torch
import torch.nn as nn
from timm.models.layers import DropPath    # stochastic depth

class PoolFormerBlock(nn.Module):
    """
    Implementation of one PoolFormer block.
    --dim: embedding dim
    --pool_size: pooling size
    --mlp_ratio: mlp expansion ratio
    --act_layer: activation
    --norm_layer: normalization
    --drop: dropout rate
    --drop_path: stochastic depth,
        refer to https://arxiv.org/abs/1603.09382
    --use_layer_scale, --layer_scale_init_value: LayerScale,
        refer to https://arxiv.org/abs/2012.12877
    """
    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU, norm_layer=GroupNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.token_mixer = Pooling(pool_size=pool_size)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)
        # The following two techniques are useful to train deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True)
    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                * self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
Algorithm 3 PoolFormer block, PyTorch-like Code
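A quick usage sketch (ours) of the block above on a stage-3 feature map of the small models:

import torch

# One PoolFormer block with the stage-3 embedding dimension of the small models (320),
# applied to a dummy [B, C, H, W] feature map (224/16 = 14 tokens per side at stage 3).
block = PoolFormerBlock(dim=320, pool_size=3, mlp_ratio=4.)
x = torch.randn(2, 320, 14, 14)
y = block(x)
print(y.shape)    # torch.Size([2, 320, 14, 14])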