Better plain ViT baselines for ImageNet-1k

05/03/2022
by Lucas Beyer, et al.

It is commonly accepted that the Vision Transformer model requires sophisticated regularization techniques to excel at ImageNet-1k scale data. Surprisingly, we find this is not the case and standard data augmentation is sufficient. This note presents a few minor modifications to the original Vision Transformer (ViT) vanilla training setting that dramatically improve the performance of plain ViT models. Notably, 90 epochs of training surpass 76% top-1 accuracy in under seven hours on a TPUv3-8, similar to the classic ResNet50 baseline, and 300 epochs of training reach 80% in less than one day.


1 Introduction

The ViT paper [4] focused solely on the aspect of large-scale pre-training, where ViT models outshine well-tuned ResNet [6] (BiT [8]) models. The addition of results when pre-training only on ImageNet-1k was an afterthought, mostly to ablate the effect of data scale. Nevertheless, ImageNet-1k remains a key testbed in computer vision research and it is highly beneficial to have as simple and effective a baseline as possible.

Thus, coupled with the release of the big_vision codebase used to develop ViT [4], MLP-Mixer [14], ViT-G [19], LiT [20], and a variety of other research projects, we now provide a new baseline that stays true to the original ViT’s simplicity while reaching results competitive with similar approaches [15, 17] and with the concurrent [16], which also strives for simplification.

Figure 1: Comparison of the ViT model from this note to state-of-the-art ViT and ResNet models. The left plot shows how performance depends on the total number of training epochs, while the right plot uses TPUv3-8 wallclock time as the measure of compute. We observe that our simple setting is highly competitive, even with the canonical ResNet-50 setups.

2 Experimental setup

We focus entirely on the ImageNet-1k dataset (ILSVRC-2012) for both (pre)training and evaluation. We stick to the original ViT model architecture due to its widespread acceptance [15, 2, 5, 1, 9], simplicity, and scalability, and revisit only a few very minor details, none of which are novel. We choose to focus on the smaller ViT-S/16 variant introduced by [15], as we believe it provides a good tradeoff between iteration velocity on commonly available hardware and final accuracy. However, when more compute and data are available, we highly recommend iterating with ViT-B/32 or ViT-B/16 instead [12, 19], and note that increasing the patch size is almost equivalent to reducing the image resolution.
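
For concreteness, the ViT-S/16 dimensions from [15] are sketched below; the dictionary and its field names are purely illustrative and not part of the big_vision config format.

# Illustrative sketch of the ViT-S/16 shape (dimensions follow [15]; field names are ours).
vit_s16 = dict(
    patch_size=16,   # a 224x224 image becomes a 14x14 grid of 196 patch tokens
    width=384,       # token embedding dimension
    depth=12,        # number of Transformer encoder blocks
    num_heads=6,     # self-attention heads (64 dimensions per head)
    mlp_dim=1536,    # hidden size of the MLP block (4x width)
)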

All experiments use “inception crop” [13] at 224px² resolution (i.e. 224×224 pixels), random horizontal flips, RandAugment [3], and Mixup augmentations. We train on the first 99% of the training data, and keep 1% for minival to encourage the community to stop selecting design choices on the validation (de-facto test) set. The full setup is shown in Appendix A.
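
As a minimal sketch of this split (assuming TFDS-style split strings, which is what the Appendix A config uses), the train/minival/validation datasets can be obtained as follows:

import tensorflow_datasets as tfds

# Data splits using TFDS slicing syntax, mirroring the Appendix A config.
train_ds = tfds.load('imagenet2012', split='train[:99%]')     # training data
minival_ds = tfds.load('imagenet2012', split='train[99%:]')   # held-out 1% "minival"
val_ds = tfds.load('imagenet2012', split='validation')        # de-facto test set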

3 Results

The results for our improved setup are shown in Figure 1, along with a few important related baselines. It is clear that a simple, standard ViT trained this way can match both the seminal 90-epoch ResNet-50 baseline and more modern ResNet [17] and ViT [16] training setups. Furthermore, on a small TPUv3-8 node, the 90-epoch run takes only about 6.5 hours, and one can reach 80% accuracy in less than a day when training for 300 epochs (roughly 6.5 h × 300/90 ≈ 22 h, assuming near-linear scaling).

The main differences from [4, 12] are a batch size of 1024 instead of 4096, the use of global average pooling (GAP) instead of a class token [2, 11], fixed 2D sin-cos position embeddings [2], and the introduction of a small amount of RandAugment [3] and Mixup [21] (level 10 and probability 0.2 respectively, which is less than in [12]). These small changes lead to significantly better performance than that originally reported in [4].
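
To illustrate the fixed 2D sin-cos position embeddings, here is a minimal NumPy sketch; the function name and the temperature default are ours, and big_vision contains the actual implementation:

import numpy as np

def posemb_sincos_2d(h, w, dim, temperature=10_000.0):
  """Fixed (non-learned) 2D sin-cos position embedding for an h x w patch grid.

  Minimal sketch; `dim` must be divisible by 4 (e.g. 384 for ViT-S).
  """
  y, x = np.mgrid[:h, :w]                       # patch-grid coordinates
  omega = np.arange(dim // 4) / (dim // 4 - 1)  # geometric ladder of frequencies
  omega = 1.0 / (temperature ** omega)
  y = y.reshape(-1)[:, None] * omega[None, :]   # (h*w, dim/4)
  x = x.reshape(-1)[:, None] * omega[None, :]   # (h*w, dim/4)
  return np.concatenate([np.sin(x), np.cos(x), np.sin(y), np.cos(y)], axis=1)  # (h*w, dim)

# ViT-S/16 at 224px: a 14x14 patch grid with 384-dimensional embeddings.
pe = posemb_sincos_2d(14, 14, 384)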

Notably absent from this baseline are further architectural changes, regularizers such as dropout or stochastic depth [7], advanced optimization schemes such as SAM [10], extra augmentations such as CutMix [18], repeated augmentations [15], or blurring, “tricks” such as high-resolution fine-tuning or checkpoint averaging, as well as supervision from a strong teacher via knowledge distillation.

Table 1 shows an ablation of the various minor changes we propose. It exemplifies how a collection of almost trivial changes can accumulate to an important overall improvement. The only change which makes no significant difference in classification accuracy is whether the classification head is a single linear layer, or an MLP with one hidden layer as in the original Transformer formulation.

                                  90ep   150ep   300ep
Our improvements                  76.5   78.5    80.0
  no RandAug + MixUp              73.6   73.7    73.7
  Posemb: sincos2d → learned      75.0   78.0    79.6
  Batch size: 1024 → 4096         74.7   77.3    78.6
  Pooling: GAP → [cls] token      75.0   76.9    78.2
  Head: MLP → linear              76.7   78.6    79.8
Original + RandAug + MixUp        71.6   74.8    76.1
Original                          66.8   67.2    67.1

Table 1: Ablation of our trivial modifications (ImageNet-1k top-1 accuracy, %). Each indented row changes a single setting of our improved setup from our default (left of the arrow) to the listed alternative (right of the arrow).
                           Top-1   ReaL   v2
Original (90ep)            66.8    72.8   52.2
Our improvements (90ep)    76.5    83.1   64.2
Our improvements (150ep)   78.5    84.5   66.4
Our improvements (300ep)   80.0    85.4   68.3

Table 2: A few more standard metrics: ImageNet-1k top-1, ImageNet-ReaL, and ImageNet-v2 accuracy (%).

4 Conclusion

It is always worth striving for simplicity.

Acknowledgements. We thank Daniel Suo and Naman Agarwal for nudging for 90 epochs and feedback on the report, as well as the Google Brain team for a supportive research environment.

References

Appendix A big_vision experiment configuration

import ml_collections as mlc


def get_config():
  config = mlc.ConfigDict()

  # Data section
  config.dataset = 'imagenet2012'
  config.train_split = 'train[:99%]'
  config.cache_raw = True
  config.shuffle_buffer_size = 250_000
  config.num_classes = 1000
  config.loss = 'softmax_xent'
  config.batch_size = 1024
  config.num_epochs = 90

  pp_common = (
      '|value_range(-1, 1)'
      '|onehot(1000, key="{lbl}", key_result="labels")'
      '|keep("image", "labels")'
  )
  config.pp_train = (
      'decode_jpeg_and_inception_crop(224)' +
      '|flip_lr|randaug(2,10)' +
      pp_common.format(lbl='label')
  )
  pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common

  config.log_training_steps = 50
  config.log_eval_steps = 1000
  config.checkpoint_steps = 1000

  # Model section
  config.model_name = 'vit'
  config.model = dict(
      variant='S/16',
      rep_size=True,
      pool_type='gap',      # global average pooling instead of a [cls] token
      posemb='sincos2d',    # fixed 2D sin-cos position embeddings
  )

  # Optimizer section
  config.grad_clip_norm = 1.0
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='bfloat16')
  config.lr = 0.001
  config.wd = 0.0001
  config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
  config.mixup = dict(p=0.2, fold_in=None)

  # Eval section
  config.evals = [
      ('minival', 'classification'),
      ('val', 'classification'),
      ('real', 'classification'),
      ('v2', 'classification'),
  ]
  eval_common = dict(
      pp_fn=pp_eval.format(lbl='label'),
      loss_name=config.loss,
      log_steps=1000,
  )

  config.minival = dict(**eval_common)
  config.minival.dataset = 'imagenet2012'
  config.minival.split = 'train[99%:]'   # the held-out 1% of train
  config.minival.prefix = 'minival_'

  config.val = dict(**eval_common)
  config.val.dataset = 'imagenet2012'
  config.val.split = 'validation'
  config.val.prefix = 'val_'

  config.real = dict(**eval_common)
  config.real.dataset = 'imagenet2012_real'
  config.real.split = 'validation'
  config.real.pp_fn = pp_eval.format(lbl='real_label')
  config.real.prefix = 'real_'

  config.v2 = dict(**eval_common)
  config.v2.dataset = 'imagenet_v2'
  config.v2.split = 'test'
  config.v2.prefix = 'v2_'

  return config
Listing 1: Full recommended config
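
For illustration, the config can be instantiated and inspected directly; the snippet below is a minimal usage sketch and assumes ml_collections is installed and get_config() from Listing 1 is in scope.

# Minimal usage sketch of the config from Listing 1.
config = get_config()
print(config.model.variant, config.model.pool_type, config.model.posemb)  # S/16 gap sincos2d
print(config.batch_size, config.num_epochs, config.lr)                    # 1024 90 0.001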