Deep learning yields strong performance in many vision tasks. However, success usually requires a large amount of task-specific labeled training data. Training also requires a large amount of compute. These per-task data and compute requirements can make solving new tasks prohibitively expensive.
Transfer learning offers a solution: some task-specific data is replaced by a large amount of generic data. To perform transfer, one first pre-trains a network on a large generic dataset. Transfer of knowledge from pre-training to new tasks then usually entails initializing networks for subsequent tasks using the weights learned during pre-training. Good performance can then be attained with fewer downstream labeled examples. Transfer learning can also shift computational requirements upstream. Although the pre-training phase is expensive, downstream tasks often require fewer training iterations [raghu2019, he2019rethinking].
In this paper, we revisit this simple paradigm: pre-train on a large supervised source domain, and fine-tune the weights on the target domain. Many new techniques have been introduced to improve deep network training. Recent examples include: architectures such as EfficientNet [tan2019efficientnet] or ResNeXt [xie2017aggregated]; optimization strategies such as Adam [kingma2014adam] and Cosine Annealing [loshchilov2016sgdr]; Stochastic Weight Averaging [izmailov2018averaging, fast_swa]; regularizers such as CutMix [yun2019cutmix], MixUp [mixup], stochastic depth, and label smoothing [szegedy2016rethinking]; and new normalization layers, such as Instance Normalization [ulyanov2016instance], Layer Normalization [ba2016layer], and Group Normalization [wu2018group]. We distill from these the most effective techniques required to create a single model that can transfer to many tasks. We call such models “Big Transfer” (BiT).
We pre-train BiT-Large (BiT-L) on the JFT-300M dataset [sun2017revisiting], which contains 300M noisily labelled images. BiT-L attains state-of-the-art performance on many visual classification tasks with training sets ranging from 10 to 1M examples (Figure 1). These tasks include ImageNet’s ILSVRC-2012 [deng2009imagenet], CIFAR-10/100 [cifar10], Oxford-IIIT Pet [parkhi12a], Oxford Flowers-102 [Nilsback08] (including low-data variants), and the 1000-sample VTAB-1k benchmark [zhai2019visual], which itself is composed of 19 diverse datasets. We also train BiT-M on the public ImageNet-21k dataset, and demonstrate large improvements compared to the conventional ILSVRC-2012 pre-training.
Importantly, BiT only needs to be pre-trained once. Subsequent fine-tuning to downstream tasks is cheap. This is in contrast to other state-of-the-art methods that require extensive training on support data conditioned on the task at hand [dat, noisystudent, yalniz2019billion]. Not only does BiT require a short fine-tuning protocol for each new task, but also BiT does not require extensive hyperparameter tuning on new tasks. We present a heuristic for setting the hyperparameters for transfer, which works well on many new tasks.
The goal of this paper is not to introduce a new component, or add further complexity to state-of-the-art deep learning pipelines. Instead, we aim to simplify, and create a pre-train/transfer pipeline which uses a minimal number of necessary tricks to attain very strong performance across a broad spectrum of popular classification tasks. We also study the effect of different components, and provide insight into the interplay between scale, architecture, and training hyperparameters. For practitioners, we plan to release a performant BiT-M model trained on the public ImageNet-21k dataset.
2 The Components of BiT
We review the components that we found necessary to build a pre-trained network that attains good performance across many tasks. These are divided into two groups: upstream — those used during pre-training, and downstream — those used for fine-tuning on a new task.
2.1 Upstream Pre-Training
The first component is scale. It is widely known that large architectures require large datasets to exhibit their benefits, and vice versa. We find that long training schedules are crucial for training with very large data. We train three BiT models on three large datasets: ILSVRC-2012 [imagenet] which contains 1.3M images (BiT-S), ImageNet-21k [deng2009imagenet] which contains 14M images (BiT-M), and JFT [sun2017revisiting] which contains 300M images (BiT-L). We show that a long schedule is crucial to harness the benefits of larger datasets and models, see Section 4.2.
The second component is Group Normalization (GN) [wu2018group]. Batch Normalization (BN) [ioffe2015batch] is a popular technique used to stabilize training, and is used in most state-of-the-art vision models. However, BN can hurt transfer learning, likely due to the requirement to update running statistics. We study this empirically in Section 4.3. BN can also be detrimental when training with few images per chip, since batch statistics become too noisy. In that regime, GN, when combined with Weight Standardization (WS), has been shown to improve performance on ImageNet and COCO [lin2014microsoft]. We demonstrate that both GN and WS are effective at a larger batch size, and have a significant impact on transfer learning.
2.2 Transfer to Downstream Tasks
|Optimizer||SGD, learning rate: 0.003, momentum: 0.9, batch size: 512|
|Resolution||small images: resize(160), crop(128)||larger images: resize(448), crop(384)|
|MixUp||< 20k samples: no||≥ 20k samples: yes|
|Training steps||< 20k samples: 500||< 500k samples: 10k||≥ 500k samples: 20k|
We propose a cheap fine-tuning protocol that applies to many diverse tasks, with training set sizes spanning many orders of magnitude. In particular, we avoid expensive hyperparameter search for every new task and dataset size. Our heuristic hyperparameter configuration—which we call BiT-hyperparam—selects the resolution, the use of MixUp [mixup], and the training schedule based on dataset characteristics; see Table 1 for details. With this strategy, BiT attains strong performance across over 20 tasks and training regimes ranging from 1 example per class to large datasets. We give a high-level overview of our choices here and more detailed exposition in Section 3.3.
During fine-tuning, we use limited data pre-processing: we resize the image to a fixed square size, crop out a smaller random square, and randomly horizontally flip the image at training time. At test time, we only resize the image to a fixed size. The only per-task heuristic we apply is that we do not perform random horizontal flipping or cropping for tasks where doing so would destroy the label semantics, such as when predicting object orientations or coordinates in pixel space.
Recent work has shown that existing augmentation methods induce inconsistency between training and test resolutions for ConvNets [fixres]. A common heuristic is scaling up the resolution by a small factor at test time. A better solution, proposed by [fixres], is to introduce an additional step at which the trained model is fine-tuned to the test resolution. This fits well with transfer learning: we include the resolution change during our fine-tuning step.
MixUp [mixup] linearly interpolates between two image samples. The ground truth label of the new sample is given by the linear interpolation of the one-hot labels. This technique has been used in many recent works, and we similarly found that it can help during transfer.
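The interpolation described above can be sketched in a few lines. This is a minimal illustration and not the code used for BiT: images are represented as flat lists of floats, the mixing weight is drawn from a Beta(alpha, alpha) distribution as in the original MixUp formulation, and `alpha` is left as a caller-supplied parameter rather than a value taken from this paper.

```python
import random


def mixup(image_a, label_a, image_b, label_b, alpha, rng=random):
    """Blend two flat images and their one-hot labels with one shared weight.

    `alpha` parameterizes the Beta distribution; its value is an assumption
    supplied by the caller, not a setting quoted from the text.
    """
    lam = rng.betavariate(alpha, alpha)  # mixing weight in [0, 1]
    image = [lam * a + (1.0 - lam) * b for a, b in zip(image_a, image_b)]
    label = [lam * a + (1.0 - lam) * b for a, b in zip(label_a, label_b)]
    return image, label, lam


# Toy usage: two 2-pixel "images" belonging to two different classes.
rng = random.Random(0)
img, lab, lam = mixup([0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0],
                      alpha=0.2, rng=rng)
```

The mixed label remains a valid probability vector because the same weight is applied to both one-hot labels.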
Finally, we note that we do not use weight decay, neither towards zero, nor towards the pre-trained weights. We also do not use dropout. Despite the fact that the network is very large—BiT-L has about 1 billion parameters—the performance is surprisingly good without needing these techniques and their respective hyperparameters, even when transferring to very small datasets. We find that an appropriately chosen schedule is sufficient. The schedule is automatically chosen based on the number of samples, with larger datasets having longer schedules (see Table 1).
We train three upstream models using three datasets of different scale: BiT-S, BiT-M, BiT-L. We evaluate these models on a wide range of downstream tasks that span high and low data regimes.
3.1 Data for upstream pre-training
BiT-S is trained on the popular ILSVRC-2012 variant of the ImageNet dataset. This dataset contains 1.28 million images and 1000 classes. Each image has a single label, and the labels are organized according to the WordNet hierarchy. BiT-M is trained on the full ImageNet-21k dataset [deng2009imagenet], a publicly available dataset with 14.2 million images and 21k classes also organized using WordNet.
BiT-L is trained on the JFT-300M dataset, as in [sun2017revisiting, dat, noisystudent]. This dataset is a new version of that used in [hinton2015distilling, chollet2017xception]. JFT-300M consists of around 300 million images with 1.26 labels per image on average. The labels are organized into a hierarchy of classes. The labels are produced by an automatic annotation pipeline and are therefore imperfect; approximately 20% of the labels are noisy.
3.2 Downstream tasks
We evaluate BiT on standard computer vision benchmarks: ILSVRC-2012 [deng2009imagenet], CIFAR-10/100 [cifar10], Oxford-IIIT Pet [parkhi12a] and Oxford Flowers-102 [Nilsback08]. These datasets have a long history and differ in the total number of images, input resolution and nature of their categories, from general object categories in ImageNet and CIFAR to fine-grained ones in Pets and Flowers. We always fine-tune BiT on the official training split and report results on the official test split if publicly available. Otherwise, we use the val split.
To further assess the generality of representations learned by BiT models, we leverage the recently introduced Visual Task Adaptation Benchmark (VTAB) [zhai2019visual] that consists of 19 visual tasks. For each task in VTAB, we have access to 1000 training samples (VTAB-1k variant). These tasks are organized into three groups: natural, specialized and structured. The VTAB-1k score is computed as the top-1 recognition performance averaged over these 19 tasks. The natural group of tasks represents classical tasks that contain natural images captured using standard cameras. The specialized group also contains images captured in the real world, but through specialist equipment, such as satellite or medical images. Finally, the structured tasks are mostly generated from simulated environments and assess understanding of the structure of a scene. Example tasks are object counting and 3D depth estimation.
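As a small illustration of how the VTAB-1k score is aggregated, the sketch below averages per-task top-1 accuracy over all tasks and within each group. The task names and accuracies are hypothetical placeholders, not results from this paper.

```python
def vtab_scores(accuracy_by_task, group_by_task):
    """Mean top-1 accuracy over all tasks and within each task group."""
    overall = sum(accuracy_by_task.values()) / len(accuracy_by_task)
    groups = {}
    for task, acc in accuracy_by_task.items():
        groups.setdefault(group_by_task[task], []).append(acc)
    per_group = {g: sum(v) / len(v) for g, v in groups.items()}
    return overall, per_group


# Hypothetical task names and accuracies, for illustration only.
acc = {"task_a": 0.9, "task_b": 0.7, "task_c": 0.5}
grp = {"task_a": "natural", "task_b": "natural", "task_c": "structured"}
overall, per_group = vtab_scores(acc, grp)
```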
3.3 Hyperparameter Details
3.3.1 Upstream Pre-Training
For all of our models, we use a vanilla ResNet-v2 architecture [he2016identity], except that we replace all Batch Normalization [ioffe2015batch] layers with Group Normalization [wu2018group] and use Weight Standardization [qiao2019weight] in all convolutional layers. This change is analyzed in Section 4.3. The BiT-S and BiT-M models use the ResNet-101 architecture, where every hidden layer is widened by a factor of three (ResNet-101x3). To benefit from the larger dataset, BiT-L uses a ResNet-152x4 model, which has 0.93 billion trainable parameters. We explore the coupling between the datasets and the size of the model in Section 4.1.
We train both BiT-S and BiT-M for 90 epochs and decay the learning rate by a factor of 10 at 30, 60 and 80 epochs. For BiT-L, we train for 40 epochs with an initial learning rate of 0.03, with 5000 linear warmup steps, and decay the learning rate after 10, 23, 30 and 37 epochs. We use a global batch size of 4096 and train on a Cloud TPUv3-512 [jouppi2017datacenter], resulting in 8 images per chip. For optimization with large batch sizes we employ recipes from [goyal2017accurate]. In particular, we use linear learning rate warmup for 5000 optimization steps and scale the learning rate multiplicatively with the batch size.
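The BiT-L schedule described above can be written down as a short function. This is a sketch under stated assumptions rather than the training code: the decay factor of 10 is carried over from the BiT-S/BiT-M schedule (the text does not state it for BiT-L), the dataset size is rounded to 300M images, and the batch-size scaling of the learning rate is not applied here.

```python
def upstream_learning_rate(step, steps_per_epoch, base_lr, decay_epochs,
                           warmup_steps=5000, decay_factor=10.0):
    """Linear warmup followed by step-wise decay at the given epochs."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    epoch = step / steps_per_epoch
    num_decays = sum(1 for e in decay_epochs if epoch >= e)
    return base_lr / (decay_factor ** num_decays)


# Settings quoted in the text; IMAGES is a rounded figure for JFT-300M.
IMAGES, BATCH, DEVICES = 300_000_000, 4096, 512
BITL_DECAY_EPOCHS = (10, 23, 30, 37)

steps_per_epoch = IMAGES // BATCH      # optimizer steps in one pass over the data
images_per_device = BATCH // DEVICES   # 4096 / 512, the per-chip batch in the text

lr_start = upstream_learning_rate(0, steps_per_epoch, 0.03, BITL_DECAY_EPOCHS)
lr_epoch5 = upstream_learning_rate(5 * steps_per_epoch, steps_per_epoch,
                                   0.03, BITL_DECAY_EPOCHS)
lr_epoch11 = upstream_learning_rate(11 * steps_per_epoch, steps_per_epoch,
                                    0.03, BITL_DECAY_EPOCHS)
```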
|Dataset||BiT-L||Generalist SOTA||Specialist SOTA|
|ILSVRC-2012||87.76 ± 0.09||86.4 [fixres]||87.4 [noisystudent]|
|CIFAR-10||99.35 ± 0.03||99.0 [gpipe]||-|
|CIFAR-100||93.60 ± 0.18||91.7 [tan2019efficientnet]||-|
|Pets||96.76 ± 0.18||95.9 [gpipe]||97.1 [dat]|
|Flowers||99.69 ± 0.01||98.8 [tan2019efficientnet]||97.7 [dat]|
|VTAB (19 tasks)||76.65 ± 0.11||71.7 [vivi]||-|
3.3.2 Downstream Fine-Tuning
We desire a low per-task adaptation cost. We therefore run a single hyperparameter setting for each downstream task. However, due to different resolutions and dataset sizes, identical hyperparameters will not work well across all tasks. To address this we provide a heuristic setting, BiT-hyperparam, to determine all downstream hyperparameters. Of the hundreds of hyperparameter choices, BiT-hyperparam selects the most important based on the task’s image resolution and training set size.
The logic for BiT-hyperparam is summarized in Table 1. For all the tasks, we set the initial learning rate to 0.003 and batch size to 512. We resize low-resolution input images to 160 × 160 pixels, and then take a random crop of 128 × 128 pixels. We resize larger images to 448 × 448 and take a 384 × 384 crop for BiT-S and BiT-M. For BiT-L, we use different resize and crop sizes. See Appendix 0.B for more details about how we apply horizontal flips and random crops.
We define three task regimes: we call small tasks those with fewer than 20k labeled examples, medium those with fewer than 500k, and any larger dataset is a large task. We fine-tune BiT for 500 steps on small tasks, for 10k steps on medium tasks, and for 20k steps on large tasks. During fine-tuning, we decay the learning rate by a factor of 10 at 30%, 60% and 90% of the training steps. Finally, we use MixUp [mixup] for medium and large datasets.
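The rules above are simple enough to express as a lookup function. The following is a sketch of the heuristic as described in the text and Table 1, not a released implementation. The `low_resolution` flag is a hypothetical caller-supplied input that stands in for the resolution threshold, and the resize and crop values are those given for BiT-S and BiT-M.

```python
def bit_hyperparam(num_train_examples, low_resolution):
    """Return fine-tuning settings following the rules described in the text.

    `low_resolution` is a hypothetical flag set by the caller for tasks whose
    images fall below the resolution threshold.
    """
    if num_train_examples < 20_000:
        regime, steps = "small", 500
    elif num_train_examples < 500_000:
        regime, steps = "medium", 10_000
    else:
        regime, steps = "large", 20_000
    resize, crop = (160, 128) if low_resolution else (448, 384)
    return {
        "regime": regime,
        "learning_rate": 0.003,
        "momentum": 0.9,
        "batch_size": 512,
        "steps": steps,
        # Decay by a factor of 10 at 30%, 60% and 90% of training.
        "lr_decay_steps": [steps * p // 100 for p in (30, 60, 90)],
        "mixup": regime != "small",
        "resize": resize,
        "crop": crop,
    }


small = bit_hyperparam(1_000, low_resolution=True)
medium = bit_hyperparam(50_000, low_resolution=True)
large = bit_hyperparam(1_281_167, low_resolution=False)
```

Keeping the whole configuration in one pure function makes it easy to confirm that only the training set size and the image resolution influence the chosen settings.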
|ILSVRC-2012||CIFAR-10||CIFAR-100||Pets||Flowers||VTAB-1k (19 tasks)|
3.4 Evaluation on Standard Computer Vision Benchmarks
We evaluate BiT-L on standard benchmarks and compare its performance to the current state-of-the-art results (Table 2). Here we separate models that perform task-independent large-scale pre-training (“general” representations), from those that perform task-dependent large-scale pre-training (“specialist” representations). The specialist methods condition on a particular task, for example ILSVRC-2012, then train using a large support dataset, such as JFT-300M [dat] or Instagram-1B [yalniz2019billion]. Further details are discussed in Section 5. Specialist representations are highly effective, but require a large training cost per task. By contrast, generalized representations require large-scale training only once, followed by a cheaper adaptation phase.
BiT-L outperforms previously reported generalist SOTA models, and in almost all cases, specialist models. BiT-L achieves these results without extensive hyperparameter tuning for each individual dataset: we use BiT-hyperparam for all our runs (including the 19 VTAB tasks) and do not perform costly tuning.
Inspired by the strong results of BiT-L trained on the in-house JFT-300M dataset, we turn our attention to the public ImageNet-21k dataset. This dataset is more than 10 times bigger than the widely used ILSVRC-2012, but is mostly overlooked by the research community. In Table 3 we demonstrate that BiT-M trained on ImageNet-21k leads to substantially improved visual representations compared to the same model trained on ILSVRC-2012 (BiT-S), as measured by all our benchmarks.
In our detailed analysis, in particular in Section 4.2, we discuss pitfalls that may have hindered wide adoption of ImageNet-21k as a dataset for pre-training, and highlight the crucial components of BiT that enabled success on this large dataset.
3.5 Evaluation On Low-data Regime
In this section we study how many labeled samples are needed to effectively transfer BiT-L to a new downstream task. To this end, we fine-tune our model on small subsets of downstream training data. We test BiT-L in a low data regime on a broad set of tasks, including ILSVRC-2012, CIFAR-10, and CIFAR-100 datasets and the recently introduced VTAB-1k that consists of 19 different downstream tasks.
Note that the goal here is similar to that in semi-supervised learning — in both cases, we want to attain high performance using fewer examples per class. Importantly, we use extra labelled out-of-domain data, whereas many of these methods leverage extra unlabelled in-domain data, so the results are not directly comparable. Nevertheless, it is interesting to compare relative benefits of leveraging generic labelled out-of-domain data versus unlabelled in-domain data.
Figure 2 (left half) shows how performance of BiT-L on ILSVRC-2012, CIFAR-10, and CIFAR-100 depends on the number of available labelled samples per class. Multiple points with the same amount of training data correspond to different random data subsamples (we evaluate 5 random subsamples for each examples-per-class configuration). Surprisingly, even with very few samples per class, BiT-L demonstrates strong performance and quickly approaches performance of the full-data regime. In particular, with just 1 labeled sample per class it achieves top-1 accuracy of 74.3% on ILSVRC-2012 and with 25 samples the top-1 accuracy goes to 86.4%. On the CIFAR-100 dataset, we achieve 85.0% with just 10 samples per class.
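The subsampling protocol described above (a fixed number of examples per class, repeated for 5 random subsets) could be sketched as follows. The label list is a synthetic stand-in, and the helper name `sample_per_class` is hypothetical.

```python
import random
from collections import defaultdict


def sample_per_class(labels, examples_per_class, seed):
    """Return indices of a random subset with k examples from each class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)
    subset = []
    for label in sorted(by_class):
        subset.extend(rng.sample(by_class[label], examples_per_class))
    return sorted(subset)


# Synthetic labels: 1000 examples spread evenly over 10 classes.
labels = [i % 10 for i in range(1000)]
# Five random subsamples with 5 examples per class, one per seed.
subsets = [sample_per_class(labels, 5, seed) for seed in range(5)]
```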
For reference, on the right side of Figure 2, we show the results from the semi-supervised learning community. Even though the results are not directly comparable, one can assess relative benefits of both approaches.
Further, Figure 3 shows the performance of BiT-L on VTAB-1k, which consists of 19 downstream tasks with only 1000 training samples for each task. Overall, BiT-L with BiT-hyperparam outperforms the previously reported state-of-the-art on VTAB-1k [vivi]. When looking into performance on VTAB-1k task subsets, our model is the best on natural and specialized tasks. However, the recently proposed VIVI-Ex-100% [vivi] model, which employs video data during upstream training, shows better results on the structured tasks.
4 Detailed Analysis
In this section we perform detailed analysis of various components of BiT. In particular, we demonstrate the importance of model capacity, discuss practical optimization caveats, choice of normalization layer, and hyperparameter selection. We also analyze the effect of potential image duplication between upstream datasets used for pre-training and downstream test sets used for evaluation.
4.1 Big Models and Big Data
The general consensus is that larger models result in better upstream and downstream performance [kornblith2018better]. To investigate the effects of model capacity and upstream dataset size on downstream performance, we train different ResNet architectures on three upstream datasets: ILSVRC-2012, ImageNet-21k and JFT-300M, and evaluate them on four downstream benchmarks (Figure 4). We opt for training the following models that have different capacity: ResNet-50x1, ResNet-50x3, ResNet-101x1, and ResNet-101x3. For the largest dataset, JFT-300M, we also train an extra large ResNet-152x4 model.
The gain from using large models is much more pronounced when pre-training on larger datasets. When pre-training on ILSVRC-2012, the benefit from larger models is significant, but quickly diminishes. However, improvements from scaling up the architecture are much more pronounced for larger datasets, such as ImageNet-21k or JFT-300M. A similar effect is observed when training on Instagram hashtags [mahajan2018exploring].
There is a second effect: not only is there limited benefit of training a large model on a small dataset, but there is also limited (or even negative) benefit from training a small model on a large dataset. Interestingly, the ResNet-50x1 model trained on the JFT-300M dataset performs worse than or similarly to the same architecture trained on the much smaller ImageNet-21k dataset. Thus, if one employs only architectures with usual capacity, one may conclude that scaling up data size does not bring any additional benefits. However, with larger architectures, such as ResNet-101x3, models pre-trained on JFT-300M significantly outperform those pre-trained on ILSVRC-2012 or ImageNet-21k.
Crucially, we also observe that large pre-trained models lead to improved results even on small downstream datasets. On VTAB-1k, which measures the average accuracy across 19 tasks with only 1000 training samples per task (Figure 4, lower right corner), the largest models also result in the best performance. It seems remarkable that it is not only possible to fine-tune such large models using comparatively little data, but that it leads to strong results.
4.2 Optimization on Large Datasets
For standard computer vision datasets such as ILSVRC-2012, there are well-known training procedures that are robust and lead to good performance. Progress in high-performance computing has made it feasible to learn from much larger datasets, such as ImageNet-21k, which has 14.2M images compared to ILSVRC-2012’s 1.28M. However, there are no established procedures for training from such large datasets. This section aims to tackle this shortcoming and provide pointers for training models on large datasets.
We first note that sufficient computational budget is crucial for training performant models on large datasets. The standard ILSVRC-2012 training schedule processes roughly 100 million images (1.28M images × 90 epochs). However, if the same computational budget is applied to ImageNet-21k, the resulting model leads to worse performance on the ILSVRC-2012 validation set (Figure 5, bottom-left section of the leftmost plot). Nevertheless, as shown in the same figure, by increasing the computational budget, we not only recover ILSVRC-2012 performance, but significantly outperform it (we increased the computational budget by factors of 3 and 10 in the plot). Training with very large datasets such as JFT-300M may require extra patience. The validation error may not improve over a long time (Figure 5 middle plot, “8 GPU weeks” zoom-in), even though the model is still improving as evidenced by looking at a 4x longer time window.
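The budget arithmetic above can be made concrete. The short sketch below uses only the dataset sizes quoted in the text and shows how few passes over ImageNet-21k the standard ILSVRC-2012 budget buys; the function name is illustrative.

```python
ILSVRC_IMAGES = 1_280_000         # 1.28M images
IMAGENET21K_IMAGES = 14_200_000   # 14.2M images
ILSVRC_EPOCHS = 90


def equivalent_epochs(budget_images, dataset_size):
    """Number of passes over a dataset that a fixed image budget allows."""
    return budget_images / dataset_size


# Images processed by the standard ILSVRC-2012 schedule.
budget = ILSVRC_IMAGES * ILSVRC_EPOCHS

# The same budget, and 3x and 10x that budget, expressed in ImageNet-21k epochs.
epochs_1x = equivalent_epochs(budget, IMAGENET21K_IMAGES)
epochs_3x = equivalent_epochs(3 * budget, IMAGENET21K_IMAGES)
epochs_10x = equivalent_epochs(10 * budget, IMAGENET21K_IMAGES)
```

Under the standard budget the model sees each ImageNet-21k image only about eight times, which is consistent with the observation that a longer schedule is needed on the larger dataset.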
Another important aspect of training with large datasets is the weight decay parameter. Lower weight decay can result in an apparent acceleration of convergence (Figure 5 rightmost plot, weight decay 1e-5). However, this setting eventually results in an under-performing final model. This counter-intuitive behavior stems from the interaction of weight decay and normalization layers [laarhoven17b, li2019exponential]. Low weight decay results in growing weight norms, which in turn results in a diminishing effective learning rate. Initially this effect creates an impression of faster convergence, but it eventually prevents further progress. A sufficiently large weight decay is required to avoid this effect. Throughout the paper, for upstream training, we use the standard weight decay used in [he2016deep, he2016identity].
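The link between weight norm and effective learning rate can be made explicit with a short derivation. It assumes the idealized setting studied in the cited works, where the loss is exactly invariant to the scale of the weights $w$ that feed a normalization layer; the symbols $\eta$, $\hat{w}$ and $\eta_{\text{eff}}$ are introduced here for illustration.

```latex
% Scale invariance induced by a normalization layer (idealized assumption):
L(\alpha w) = L(w) \quad \text{for all } \alpha > 0
\;\;\Longrightarrow\;\;
\nabla L(\alpha w) = \frac{1}{\alpha}\, \nabla L(w).

% One SGD step with learning rate \eta, written in terms of
% the direction \hat{w} = w / \lVert w \rVert:
w - \eta\, \nabla L(w)
= \lVert w \rVert \left( \hat{w}
  - \frac{\eta}{\lVert w \rVert^{2}}\, \nabla L(\hat{w}) \right),
\qquad
\eta_{\text{eff}} = \frac{\eta}{\lVert w \rVert^{2}}.
```

Under this assumption, a weight norm that grows because weight decay is too small shrinks $\eta_{\text{eff}}$ quadratically, which matches the stalled progress described above.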
Finally, we note that in all of our experiments we use stochastic gradient descent with momentum without any modifications. This is because, in our preliminary experiments, we did not observe clear benefits from using more involved adaptive gradient methods during upstream training.
4.3 Large Batches, Group Normalization and Weight Standardization
|Plain Conv||Weight Std.|
|Plain Conv||Weight Std.|
Currently, training on large datasets is only feasible using many hardware accelerators. Data parallelism is the most popular distribution strategy, and this naturally entails large batch sizes. Many known algorithms for training with large batch sizes use Batch Normalization (BN) [ioffe2015batch] as a component [Goyal2017AccurateLM] or even highlight it as the key instrument required for large batch training [de2019bn].
We also want to train large models, see Section 4.1. This significantly increases the memory requirement for any single accelerator chip, which necessitates small per-device batch sizes. However, it is known that models using BN perform worse when the number of images on each accelerator is too low [ioffe2017renorm]. An alternative strategy is to accumulate BN statistics across all of the accelerators. However, this strategy has two major drawbacks. First, computing BN statistics across large batches has been shown to harm generalization [de2019bn]. Second, using global BN requires many aggregations across accelerators which incurs significant latency.
We therefore investigated alternatives to BN in ResNets, specifically Group Normalization (GN) [wu2018group] and Weight Standardization (WS) [qiao2019weight]. In our experiments we observe that combining GN and WS recovers BN generalization performance and is stable when used for training with large input batches.
We investigated how these methods scale when using 128 accelerator chips and a batch size of 4096. We find that GN alone does not scale well to such large batches: ILSVRC-2012 top-1 accuracy drops when using GN instead of BN with a ResNet-50x1. However, the addition of WS enables GN to scale to such large batches, even outperforming BN. Table 5 summarizes these results.
We are not only interested in upstream performance, but also how models trained with GN and WS transfer. We transferred models with different combinations of BN, GN, and WS pre-trained on ILSVRC-2012 to the 19 tasks defined by VTAB-1k. Table 5 summarizes our results, which indicate that the GN and WS combination transfers significantly better than BN. We therefore use GN and WS in all of our BiT models.
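To make the two operations concrete, the sketch below implements Weight Standardization and Group Normalization on plain Python lists. It illustrates the general techniques from the cited works rather than the BiT implementation: learnable scale and shift parameters are omitted, the `eps` value is an assumption, and the number of channels is assumed to be divisible by the number of groups.

```python
import math


def standardize_weights(filters, eps=1e-5):
    """Weight Standardization: zero mean, unit variance per output filter.

    `filters` is a list of flattened filters, one per output channel.
    """
    out = []
    for w in filters:
        mean = sum(w) / len(w)
        var = sum((x - mean) ** 2 for x in w) / len(w)
        out.append([(x - mean) / math.sqrt(var + eps) for x in w])
    return out


def group_norm(channels, num_groups, eps=1e-5):
    """Group Normalization for a single example.

    `channels` is a list of per-channel activation lists. Statistics are
    computed within each group of channels, so they do not depend on the
    other examples in the batch.
    """
    per_group = len(channels) // num_groups
    out = []
    for g in range(num_groups):
        group = channels[g * per_group:(g + 1) * per_group]
        values = [x for ch in group for x in ch]
        mean = sum(values) / len(values)
        var = sum((x - mean) ** 2 for x in values) / len(values)
        scale = 1.0 / math.sqrt(var + eps)
        out.extend([[(x - mean) * scale for x in ch] for ch in group])
    return out


w = standardize_weights([[1.0, 2.0, 3.0, 4.0]])[0]
y = group_norm([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]],
               num_groups=2)
```

Neither function uses batch statistics, which is the property that makes this combination attractive when only a few images fit on each accelerator.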
4.4 Tuning hyperparameters for transfer
Throughout the paper we evaluate BiT using BiT-hyperparam. Here, we investigate whether BiT-L would benefit from additional computational budget for selecting fine-tuning hyperparameters.
For this investigation we use VTAB-1k as it contains a diverse set of 19 tasks. For each task we fine-tune BiT-L 50 times. Each trial uses randomly sampled hyperparameters, such as learning rate, number of updates and dropout rate for the penultimate layer. The full search space is provided in Appendix 0.A. We select the best model for each dataset using the validation set and report results on the test set.
Overall, our random search improves performance over BiT-hyperparam on VTAB-1k by 4.5%. Figure 6 shows how VTAB performance improves with the number of random hyperparameter trials. We show mean accuracy across all 19 tasks from VTAB and the mean accuracy across each group of VTAB tasks. Since the order in which we pick the hyperparameters matters, we generate 100 random orderings of our hyperparameter trials and report average performance over these orderings. The standard error is shown as a shaded blue region around the mean.
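The averaging over random orderings can be sketched as follows. This is an illustration of the procedure described above, not the evaluation code: `trials` holds hypothetical (validation score, test score) pairs, and the function reports the test score of the best-on-validation trial after each number of trials, averaged over shuffled orderings.

```python
import random


def expected_best_so_far(trials, num_orderings=100, seed=0):
    """Average test score of the best-on-validation trial after k trials.

    `trials` is a list of (validation_score, test_score) pairs.
    """
    rng = random.Random(seed)
    totals = [0.0] * len(trials)
    for _ in range(num_orderings):
        order = trials[:]
        rng.shuffle(order)
        best_val, best_test = float("-inf"), None
        for k, (val, test) in enumerate(order):
            if val > best_val:
                best_val, best_test = val, test
            totals[k] += best_test
    return [t / num_orderings for t in totals]


# Hypothetical scores for four hyperparameter trials on one task.
trials = [(0.5, 0.48), (0.7, 0.69), (0.6, 0.61), (0.9, 0.88)]
curve = expected_best_so_far(trials)
```

Selecting on the validation score and reporting the test score mirrors the model selection protocol stated in the text.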
Surprisingly, we conclude that with 10 to 20 hyperparameter trials, nearly optimal results can already be achieved. This means that if a practitioner wants to achieve the best possible score on their dataset, they will likely require only a modest computational budget for hyperparameter search.
4.5 Duplicates and near-duplicates
In order to make sure that our results are not inflated due to overlap between upstream training and downstream test data, we run extensive de-duplication experiments. For each upstream training dataset (JFT-300M, ImageNet-21k, and ILSVRC-2012) we remove all near-duplicates from our evaluation test sets and re-evaluate the best model on the de-duplicated test sets. The results are shown in Table 6: “Full” is the accuracy on the original test set, “Dedup” is the accuracy on the test set without near-duplicates, and “Dups” is the number of near-duplicates that have been removed. We observe that near-duplicates barely affect our results, so we report the results on the full test sets throughout the paper for comparability to previously published results. Note that near-duplicates between training and test sets have previously been reported by [sun2017revisiting] for ILSVRC-2012, and by [barz2019duplicates] for CIFAR.
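The “Full”, “Dedup” and “Dups” quantities can be computed from per-example results as in the sketch below. The near-duplicate detection itself is not shown; `is_near_duplicate` is a hypothetical precomputed flag per test example.

```python
def full_and_dedup_accuracy(correct, is_near_duplicate):
    """Accuracy on the full test set and on the set without near-duplicates.

    Returns (full accuracy, de-duplicated accuracy, number of duplicates).
    """
    full = sum(correct) / len(correct)
    kept = [c for c, dup in zip(correct, is_near_duplicate) if not dup]
    dedup = sum(kept) / len(kept)
    return full, dedup, len(correct) - len(kept)


# Toy example: four test images, one of which is flagged as a near-duplicate.
full, dedup, dups = full_and_dedup_accuracy(
    correct=[True, True, False, True],
    is_near_duplicate=[False, True, False, False],
)
```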
In Appendix 0.D we present a few duplicates found between the ILSVRC-2012 training set and our downstream datasets.
|From JFT||From ImageNet21k||From ILSVRC-2012|
5 Related Work
5.0.1 Large-scale Weakly Supervised Learning of Representations
A number of prior works use large supervised datasets for pre-training visual representations [joulin2016learning, sun2017revisiting, li2017learning, mahajan2018exploring]. In [joulin2016learning, li2017learning] the authors use a dataset containing 100M Flickr images [thomee2015yfcc100m]. This dataset appears to transfer less well than JFT-300M, which could be due to the limited domain of the data, noise in the text annotations, or the architectures used. In [sun2017revisiting], the authors train on JFT-300M. That paper focuses on the effect of dataset size, and shows that transfer performance increases when using this large dataset, despite reporting a large degree of noise (20% precision errors) in the JFT-300M labels. An even larger labelled dataset of 3.5B Instagram images is used in [mahajan2018exploring]. The labels consist of noisy user-generated hashtags. Stronger ILSVRC-2012 performance is reported in [mahajan2018exploring] than in [sun2017revisiting]. The authors claim that the improvement is due to the larger dataset and better architecture (ResNeXt [xie2017aggregated]). We show that we can attain better performance again with ResNet on the JFT-300M dataset using the appropriate adjustments presented in Section 2. These papers focus on transfer to ImageNet classification, and COCO or VOC detection and segmentation. We show that transfer is also highly effective in the low data regime, and also works well on the broader set of 19 datasets in VTAB [zhai2019visual].
5.0.2 Specialized Representations
Rather than pre-train generic representations, recent works have shown strong performance by training task-specific representations [yalniz2019billion, dat, noisystudent]. These papers condition on a particular task when training on a large support dataset. [yalniz2019billion, noisystudent] train student networks on a large unlabelled support dataset using the predictions of a teacher network trained on the target task. [dat] compute importance weights on a labelled support dataset by conditioning on the target dataset. They then train the representations on the re-weighted source data. Even though these approaches may lead to superior results, they require knowing the downstream dataset in advance and substantial computational resources for each downstream dataset.
5.0.3 Unsupervised and Semi-Supervised Representation learning
Self-supervised methods have shown the ability to leverage unsupervised datasets to transfer to labelled tasks. For example, [he2019momentum] show that unsupervised representations trained on 1B unlabelled Instagram images transfer comparably or better than supervised ImageNet features to COCO, VOC, LVIS, and Cityscapes detection and segmentation. Semi-supervised learning exploits unsupervised data drawn from the same domain as the labelled data. [berthelot2019remixmatch] used self-supervised learning to attain strong performance on CIFAR-10 and SVHN using only 250 labels. Recent works combine self-supervised and semi-supervised learning to attain good performance with fewer labels on ImageNet [zhai2019s4l, henaff2019data]. [zhai2019visual] study many representation learning algorithms — unsupervised, semi-supervised, and supervised — and evaluate their representation’s ability to generalize to novel tasks. This paper shows that a combination of supervised and self-supervised representations works best. However, all models evaluated in that paper were trained on ILSVRC-2012. We show that supervised pre-training on larger datasets continues to be effective on diverse tasks.
5.0.4 Few-shot Learning
Many strategies have been proposed to attain good performance on novel classes using only a few examples per class. Meta-learning or metric-learning techniques have been proposed to learn with few or no labels, such as [vinyals2016matching, snell2017prototypical, sung2018learning]. However, recent papers have shown that a simple linear classifier on top of pre-trained representations or fine-tuning can attain similar or better performance [chen2019closerlook, nakamura2019revisiting]. The upstream pre-training and downstream few-shot learning are usually performed on the same domain, with disjoint class labels. Our goal is to find a generalist representation which works well when transferring to many downstream tasks. Thus we do not force classes to be disjoint during train and test, but rather focus on the effectiveness of transferring general representations to many downstream tasks from multiple domains.
We have revisited classical transfer learning, where a large pre-trained generalist model is fine-tuned to downstream tasks. We provide a simple recipe which exploits large scale pre-training to yield good performance on all of these tasks. BiT uses a clean training and fine-tuning setup, with a small number of carefully selected components, to balance complexity and performance.
Figure 7 shows all of BiT-L’s mistakes on CIFAR-10, and some examples from ILSVRC-2012. Visualizing these mistakes, we can see that many of these label/prediction mismatches are not true ‘mistakes’. In many cases, the model’s classification is valid but does not match the label. For example, the model may identify another prominent object when there are multiple objects in the image, or may provide a valid classification when the main entity has multiple attributes. There are also some cases of label noise, where the model’s prediction is a better fit than the ground-truth label. In Figure 7 we can see that around half of the model’s mistakes on CIFAR-10 are due to ambiguity or label noise. We illustrate mistakes for more downstream datasets in Appendix 0.C. Overall, by inspecting the mistakes we observe that performance on the standard vision benchmarks seems to have approached a saturation point.
We therefore also explore the effectiveness of transfer to two classes of more challenging tasks: classical tasks, but with very few labels to adapt to the new domain, and VTAB, which contains more diverse tasks, such as spatial localization, tasks from simulated environments, and medical imaging tasks. These benchmarks are much further from saturation, and BiT-L yields strong performance on both.
In the future we plan to further investigate transfer learning in the low data regime and look deeper into non-standard computer vision tasks, which pose new challenges and require holistic image understanding.
Appendix 0.A Hyperparameters for random search
In Section 4.4 we use random hyperparameter search to analyze performance headroom. Our random search includes the following hyperparameters, with the following ranges and sampling strategies:
Initial learning rate is sampled log-uniformly from the range .
Total number of updates is sampled from the set .
Dropout rate for the penultimate layer is uniformly sampled from the range .
Weight decay to the initial weight values is sampled log-uniformly from the range .
MixUp parameter is sampled from the set .
Input image resolution is sampled from the set .
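The sampling strategies listed above can be sketched as follows. This is a minimal illustration, not the authors' search code: the numeric ranges and sets are not available in this text, so every bound and candidate value in `SEARCH_SPACE` is a hypothetical placeholder, and the names `SEARCH_SPACE`, `sample_log_uniform` and `sample_config` are introduced here for illustration only.

```python
import math
import random

# HYPOTHETICAL placeholder search space. The actual ranges and sets are not
# given in this text, so these values only illustrate the sampling strategies.
SEARCH_SPACE = {
    "learning_rate": ("log_uniform", 1e-4, 1e-1),
    "num_updates": ("choice", [500, 1000, 2000]),
    "dropout_rate": ("uniform", 0.0, 0.5),
    "weight_decay": ("log_uniform", 1e-6, 1e-3),
    "mixup_alpha": ("choice", [0.0, 0.1, 0.2]),
    "resolution": ("choice", [128, 192, 256]),
}


def sample_log_uniform(rng, low, high):
    """Draw a value whose logarithm is uniformly distributed on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_config(rng):
    """Draw one hyperparameter configuration from SEARCH_SPACE."""
    config = {}
    for name, (kind, *args) in SEARCH_SPACE.items():
        if kind == "log_uniform":
            config[name] = sample_log_uniform(rng, *args)
        elif kind == "uniform":
            config[name] = rng.uniform(*args)
        elif kind == "choice":
            config[name] = rng.choice(args[0])
        else:
            raise ValueError(f"unknown sampling strategy: {kind}")
    return config


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so the drawn trials are reproducible
    for trial in range(3):
        print(trial, sample_config(rng))
```

Log-uniform sampling spreads trials evenly across orders of magnitude, which is the usual motivation for using it on scale-like quantities such as the learning rate and the weight decay strength.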
Appendix 0.B Horizontal flipping and cropping for VTAB-1k tasks
When fine-tuning BiT models, we apply random horizontal flipping and cropping as image augmentations. However, these operations are not reasonable for certain VTAB tasks, where the semantic label (e.g. angle, location or object count) is not invariant to these operations.
Thus, we disable random horizontal flipping as preprocessing for dSprites/orientation, SmallNORB/azimuth and dSprites/location tasks. Random cropping preprocessing is disabled for Clevr/count, Clevr/distance, DMLab, KITTI/distance and dSprites/location tasks.
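The per-task exceptions above amount to a small lookup table. The sketch below is one possible encoding, assuming tasks are identified by the names used in the text. The constants `NO_FLIP_TASKS` and `NO_CROP_TASKS`, the function `augmentation_flags`, and the lower-cased task keys are hypothetical conventions introduced here, not part of any released BiT code.

```python
# Tasks whose labels are not invariant to horizontal flipping (from the text above).
NO_FLIP_TASKS = frozenset({
    "dsprites/orientation",
    "smallnorb/azimuth",
    "dsprites/location",
})

# Tasks whose labels are not invariant to random cropping (from the text above).
NO_CROP_TASKS = frozenset({
    "clevr/count",
    "clevr/distance",
    "dmlab",
    "kitti/distance",
    "dsprites/location",
})


def augmentation_flags(task_name):
    """Return which augmentations stay enabled when fine-tuning on a task.

    Task names are compared case-insensitively; any task not listed above
    keeps both random horizontal flipping and random cropping.
    """
    key = task_name.lower()
    return {
        "random_flip": key not in NO_FLIP_TASKS,
        "random_crop": key not in NO_CROP_TASKS,
    }


if __name__ == "__main__":
    for task in ("dSprites/location", "Clevr/count", "SmallNORB/azimuth"):
        print(task, augmentation_flags(task))
```

Keeping the exceptions in explicit sets makes the default behaviour (both augmentations on) obvious and makes it easy to audit which tasks deviate from it.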
Appendix 0.C All of BiT-L’s Mistakes
Here we show all mistakes made by BiT-L on Pets and Flowers. The CIFAR-10 mistakes are shown in the main paper and are thus omitted here. As in the main paper, the upper word shows the model’s prediction, while the lower word shows the ground-truth label. The larger panels are best viewed on screen, where they can be magnified.