Exploring the Limits of Large Scale Pre-training

10/05/2021
by   Samira Abnar, et al.
Google

Recent developments in large-scale machine learning suggest that by scaling up data, model size, and training time properly, one might observe that improvements in pre-training would transfer favorably to most downstream tasks. In this work, we systematically study this phenomenon and establish that, as we increase the upstream accuracy, the performance of downstream tasks saturates. In particular, we investigate more than 4800 experiments on Vision Transformers, MLP-Mixers, and ResNets with numbers of parameters ranging from ten million to ten billion, trained on the largest scale of available image data (JFT, ImageNet21K) and evaluated on more than 20 downstream image recognition tasks. We propose a model for downstream performance that reflects the saturation phenomenon and captures the nonlinear relationship between the performance of upstream and downstream tasks. Delving deeper to understand the reasons that give rise to these phenomena, we show that the saturation behavior we observe is closely related to the way representations evolve through the layers of the models. We also showcase an even more extreme scenario where performance on upstream and downstream tasks are at odds with each other; that is, to obtain better downstream performance, we need to hurt upstream accuracy.


1 Introduction

Recent impressive progress on transfer and few-shot learning suggests an emerging direction: that scaling up models and training them on a huge corpus of data is the main path towards better performance on downstream tasks with less or no data. One prominent example is brown2020language, where they show that GPT-3 (brown2020language), a large transformer model (dosovitskiy2020) trained on a large corpus of data, achieves substantial performance on many natural language processing (NLP) tasks and benchmarks in few-shot settings. On image recognition tasks, training on Instagram images and JFT-300M (sun2017-jft) has proven to be very effective in transfer and few-shot settings (mahajan2018exploring; goyal2021self; kolesnikov2019big; pham2020meta; dosovitskiy2020). Even when no example is provided (zero-shot), CLIP (radford2021learning), a deep model trained with a contrastive loss on 400 million image-text pairs from the internet, can achieve remarkable performance (radford2021learning).

All of the above developments implicitly encourage two consistent views: 1) scaling up the model and data size improves the performance significantly; 2) the performance improvement transfers to downstream tasks in a desirable way. A more focused empirical study in support of the first view (kaplan2020scaling) shows that scaling up the model size, data, and compute appropriately in the language modeling task results in a non-saturating return in performance. (bello2021revisiting; tan2019efficientnet) show that favorable scaling can be achieved in image recognition tasks as well. The second view has also been the subject of recent focused studies. (hernandez2021scaling) showed that favorable scaling laws similar to those of (kaplan2020scaling) hold in transfer and few-shot settings in NLP tasks. In perhaps the closest prior work to ours, kornblith2019better shows a linear relationship (achieved after proper logit scaling of the accuracy values) between performance on ImageNet (russakovsky2015imagenet) and downstream image recognition tasks.

Adopting the above views has major implications moving forward. These views suggest that spending compute, money, and research effort on scaling up one corpus would pay off, because that would enable us to solve many downstream tasks almost for free. It also means that, while improving our upstream performance, we do not need to worry about downstream tasks, as their improvement is predictable based on the linear trend. The above works provide a compelling story, and there are tempting practical motivations to adopt these views. However, the aforementioned studies all suffer from a major shortcoming: due to compute limitations, performance for different choices of hyper-parameter values is not reported. Scaling plots seem more favorable if the hyper-parameters chosen for each scale are fixed or determined by a simple scaling function. Moreover, such plots show more promising scaling if most of the effort in hyper-parameter selection has gone into the higher scales. This might naturally happen because most researchers are focused on improving state-of-the-art results and the computational budget is limited. However, when studying scaling, we are concerned with the best performance of models given all possible values of the hyper-parameters.

One also needs to be careful that the aforementioned works study a scaling behavior within a limited range, and simply extrapolating that scaling without further understanding of the dynamics of scaling can be detrimental as there is no reason, a priori, for the scaling to hold outside of the studied range.

In this paper, we investigate the transferability of improvements on a large-scale upstream task to a large number of downstream tasks. To address the above shortcomings, part of our work is a meta-study of more than 5500 Vision Transformer (ViT) (dosovitskiy2020) models trained on either JFT (sun2017-jft), with 303M images and 18k classes, or ImageNet21k (deng2009-imagenet), with 14M images and 21k classes, evaluated on a variety of downstream datasets in few-shot and transfer learning settings. Our downstream tasks cover a wide range of standard datasets included in benchmarks such as VTAB (zhai2019large), MetaDataset (triantafillou2019meta), Wilds (koh2020wilds), and a medical imaging benchmark.

Figure 1: Performance of upstream vs downstream (8 different tasks) based on more than 3K different ViT models with different configurations, pre-trained on JFT and evaluated in the few-shot setting (25 shots). We connect the Pareto frontiers via a line to highlight the models with the best downstream performance compared to all others with similar upstream performance (maximum positive transfer).
Contributions

Our main contributions in this paper are as follows:

  • Our main observation is that as we improve the performance of the upstream (US) task, either by scaling up or through hyper-parameter and architectural choices, the performance of many downstream (DS) tasks starts to saturate, and this saturating behaviour is typical among DS tasks (Section 2).

  • We study how scaling up model size, data size, and compute affects the relationship between US and DS performance (Section 2.1).

  • We investigate reasons behind the DS performance saturation and show that this behavior is highly related to the usefulness of feature representation in different layers (Section 2.2).

  • To further understand the discrepancy between upstream and downstream tasks, we showcase how the optimal head hyper-parameters are different for US and DS and uncover the reason behind this discrepancy (Section 2.2).

  • Finally, we show that our observations are robust to several choices, such as the size of the DS dataset and the choice of common scalings of accuracy (Section 4).

1.1 Related Work

Large scale transfer learning by pre-training on JFT (kolesnikov2019big; dosovitskiy2020; mustafa2021supervised; puigcerver2020scalable; ngiam2018domain; tay2021omninet) or ImageNet21k (dosovitskiy2020; kolesnikov2019big; mustafa2021supervised; puigcerver2020scalable; zhai2019large; arnab2021vivit) has been studied extensively. mensink2021factors considers a two-step transfer chain, where the model is pre-trained on ImageNet, fine-tuned on the source task, and then transferred to the target task. They then look into the effect of different hyper-parameters on this transfer chain and conclude that the effect of transfer learning vanishes as the target domain size increases.

raghu2019transfusion investigates the performance of models pre-trained on ImageNet when they are transferred to medical images. neyshabur2020being also studies transfer learning from models trained on ImageNet and notes that, to achieve improved accuracy from pre-training, one does not need to train too long on ImageNet. The closest work to ours is that of (kornblith2019better), which claims that performance on ImageNet (russakovsky2015imagenet) linearly translates to performance on DS. We emphasize that in most of these studies, the conclusions were drawn from experiments that are not extensive enough to capture the big picture needed to understand scaling. Without a complete sweep of all the hyper-parameters, it is easy to miss the Pareto frontier and focus on a limited range of accuracies.

1.2 Setup

Discussions and analyses in this paper are based on a study of an exhaustive number of large-scale experiments on image recognition tasks, as well as a set of controlled experiments we conducted to ablate our setup and deepen our understanding of the studied phenomena. We investigate more than 5500 experiments with Vision Transformers (ViT), pre-trained on a large amount of data in a supervised fashion and evaluated on several downstream image recognition tasks through few-shot learning and fine-tuning. These experiments vary in terms of the upstream dataset (either JFT-300M with 300M images (sun2017-jft) or ImageNet21k (deng2009-imagenet) with 14M images), model size and shape (different hyper-parameters of the architecture), optimization (e.g., different learning rate values and schedules, different weight decays, different optimizers), compute (e.g., number of epochs), and other knobs that researchers changed while developing models to chase state-of-the-art results on vision tasks.

We emphasize that the large set of experiments we investigate was not trained for the purpose of this paper; rather, we aggregated different ViT models trained by different researchers for different purposes and conduct a meta-study on them. This, in fact, positions this meta-study in a unique spot: first, it may not be feasible, either financially or in terms of environmental impact, to run such a number of large-scale trials just to study particular phenomena. Moreover, no implicit or explicit assumption was made in these experiments with respect to the type of analysis we conduct on them afterward, which minimizes systematic biases of the analysis process in the findings.

In the experiments we run ourselves, we mainly use ViT-B/32, i.e., the base model with a patch size of 32 (we also use tiny, small, and large models for the controlled scaling experiments). We pre-train our models on JFT for 7 epochs and evaluate on more than 20 tasks. For the downstream evaluation, we mainly focus on the few-shot learning setup (1, 5, 10, and 20 shots), as well as fine-tuning for some of the ablations. In both aggregated and controlled experiments, in the few-shot setup, a linear classifier is trained on top of the representations from the frozen pre-trained model, given only a fixed number of training examples per class. In the fine-tuning setup, we use the whole training set of the downstream task and update all the parameters of the model as well as the downstream head. Details on datasets and training appear in Appendix D.
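
As a concrete illustration of this few-shot protocol, the sketch below trains a linear probe on frozen features with a fixed number of examples per class. It assumes the frozen pre-trained model has already been used to extract feature vectors for the DS train and test splits; the probe, its regularization, and the helper names are illustrative rather than the paper's exact implementation.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def few_shot_accuracy(train_feats, train_labels, test_feats, test_labels,
                      shots_per_class=25, seed=0):
    """Train a linear probe on `shots_per_class` frozen features per class and
    report accuracy on the downstream test set."""
    train_feats, train_labels = np.asarray(train_feats), np.asarray(train_labels)
    rng = np.random.default_rng(seed)
    picked = []
    for c in np.unique(train_labels):
        idx = np.flatnonzero(train_labels == c)
        picked.append(rng.choice(idx, size=min(shots_per_class, len(idx)), replace=False))
    picked = np.concatenate(picked)

    clf = LogisticRegression(max_iter=1000)   # linear head on top of frozen representations
    clf.fit(train_feats[picked], train_labels[picked])
    return clf.score(test_feats, test_labels)
```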

To save space, in the main paper we report results over eight datasets; results and plots that include more than 20 tasks are provided in the supplementary material.

2 The diminishing benefit of scaling up in transfer learning

First, we examine the transferability of improvements on US task to a variety of DS tasks. To do so, we consider the plot of DS vs US performance for the large set of experiments we discussed earlier. Next, we look into the effects of scaling up in the three axes of model size, US data size, and compute, as well as varying the number of shots on DS performance. Finally, we investigate the observed phenomena and provide insights.

2.1 Observations

Figure 2: Effect of controlled scale-up with respect to the model size (number of parameters), data size (portion of the pre-training data), and compute (epochs) on different downstream tasks.

Figure 1 shows DS vs US performance for more than 3K experiments in which Vision Transformer (ViT) models are trained on JFT and evaluated on a set of DS tasks in the few-shot setting (25 shots). (A similar plot with all 5500 experiments, trained on JFT or ImageNet21K, for both 1 and 25 shots, can be found in the Appendix.) In each plot in Figure 1, we draw a line connecting the set of Pareto efficient points that form the frontier of the plot. The Pareto frontier is a widely used tool for investigating trade-offs between two competing performance metrics. Here, we refer to the experiments with the highest DS accuracy among all experiments with similar US accuracy as the Pareto efficient experiments.

In Figure 1, we observe that the DS vs US accuracy plots saturate for many DS tasks, and better US performance does not transfer well to better DS performance at higher US accuracies. We note that, contrary to common belief, the saturating behaviour is not an exception but typical among DS tasks. Essentially, all DS tasks display saturating behaviour, and whether it happens earlier or later depends on the similarity of the DS task to the US task. To see this, it is important to look at the Pareto frontier; otherwise, we might miss the saturation behaviour. As mentioned above, despite the common belief, improvements in US performance do not always lead to performance improvements on DS. Essentially, the goal in transfer learning is to reach the best possible performance on DS for a given US accuracy. Since we are conditioning DS performance on US performance, the Pareto frontier is an appropriate tool for this analysis.
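
For readers who want to reproduce this style of analysis on their own sweep of experiments, the sketch below shows one way to extract such a Pareto frontier from a set of (US accuracy, DS accuracy) pairs: keep the experiments for which no other experiment has both higher US and higher DS accuracy. This is an illustrative reconstruction, not the authors' plotting code.

```python
import numpy as np

def pareto_frontier(us_acc, ds_acc):
    """Return indices of Pareto-efficient (US, DS) accuracy pairs, ordered by US accuracy."""
    us_acc, ds_acc = np.asarray(us_acc), np.asarray(ds_acc)
    order = np.argsort(-us_acc)           # scan from the highest US accuracy downwards
    best_ds, keep = -np.inf, []
    for i in order:
        if ds_acc[i] > best_ds:           # strictly better DS than anything with higher US accuracy
            keep.append(i)
            best_ds = ds_acc[i]
    return np.array(keep[::-1])           # ascending in US accuracy
```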

If we only look at the areas with higher density in the DS vs US accuracy plots in Figure 1, they seem to follow linear trends, whereas the areas with lower density above those lines tell us that it is possible to achieve better DS performance at the same US performance. This means that if we perfectly tune our models and training setup for a given DS task, the DS performance will either saturate or the slope of the curve will be much smaller than we would otherwise expect.

Additionally, we compared the DS vs US accuracies for the same set of DS tasks for models trained with different US datasets (ImageNet21k vs JFT), and when different numbers of shots are used for transfer. We find that the DS accuracy at saturation can depend on the US dataset. We also observe that the DS accuracy saturation happens at the same US accuracy independently of the number of shots. Another observation is that more data is not always better: for example, JFT has more images than ImageNet21K, but it does not transfer better on most DS tasks (see Appendix C).

Figure 2 depicts how DS vs US accuracy changes as we increase the US dataset size (different portions of JFT), the number of parameters of the model (ViT-Tiny, ViT-Small, ViT-Base, ViT-Large), and the number of epochs (7 and 14 epochs). Since we are in the under-parametrized regime and far from saturating on the JFT dataset, the effect of increasing data size is equivalent to increasing training time, and the performance on the US keeps increasing as we increase the training time (nakkiran2020deep).

We note that the DS vs US accuracy has different trends for different DS tasks when scaling up dataset size, model size, and number of epochs (Figure 2). For some DS tasks the performance saturates quickly, and beyond that point improving performance on the US does not lead to significant improvement on DS; for instance, the colorectal histology (col_hist) dataset (https://www.kaggle.com/kmader/colorectal-histology-mnist/) and the UC-Merced land use dataset (https://usdahsi.ucmerced.edu/datasets/landuse.html).

Furthermore, similar to what we saw in Figure 1, for some of the DS tasks, the benefit of scaling up diminishes gradually, for instance for Cars KrauseStarkDengFei-Fei_3DRR2013 or Caltech101 fei2004learning datasets.

Figure 3: Relation between the quality of representations from different layers for the downstream tasks and the effect of scaling (model, data, and compute) on downstream performance. The red triangles in the plots are the performance on the downstream task when the representation used in few-shot learning is taken from different layers of the model. The green circles in the plots overlay the US versus DS performance of different experiments from Figure 2 on each task.

2.2 Investigations

Performance saturation can be caused by non-overlapping distributions, where the US distribution does not cover the DS distribution and increasing the US data size does not ensure diversity. As a simple case, consider the scenario where the US task has a uniform distribution over one interval and the DS task has a uniform distribution over a different, partially overlapping interval. In this scenario, no matter how much we increase the data size (or compute and model capacity) for the US, there are parts of the DS distribution that will never be covered, and some part of the US data does not help DS performance at all. Therefore, we expect the DS vs US performance plot to saturate. On the other hand, if the US and DS distributions were two Gaussians with means close to each other, both distributions would cover the whole real line, and one would expect increasing the number of US samples to improve performance on DS. Similarly, when the support of the DS distribution is a subset of the support of the US distribution, we expect the same trend. We would not expect perfect performance on DS, due to the loss caused by distribution shift; however, we would expect a linear relationship in the DS vs US performance plot.
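
The following toy simulation illustrates this thought experiment. All specifics here are assumptions made for illustration: the US inputs are Uniform(0, 1), the DS inputs are Uniform(0.5, 1.5), the shared ground-truth labels are alternating bands, and the "model" is a simple histogram classifier trained only on US samples. As the US sample size grows, US accuracy keeps improving while DS accuracy saturates, because half of the DS support is never covered.

```python
import numpy as np

def labels(x):
    # Shared ground-truth rule: alternating label bands of width 0.02.
    return (np.floor(x / 0.02) % 2).astype(int)

def fit_histogram(x, y, n_bins=150, lo=0.0, hi=1.5):
    """Majority label per bin; bins never seen during training default to label 0."""
    to_bin = lambda v: np.clip(((v - lo) / (hi - lo) * n_bins).astype(int), 0, n_bins - 1)
    bins = to_bin(x)
    pred = np.zeros(n_bins, dtype=int)
    for b in np.unique(bins):
        pred[b] = int(round(y[bins == b].mean()))
    return lambda xq: pred[to_bin(xq)]

rng = np.random.default_rng(0)
x_us_test = rng.uniform(0.0, 1.0, 20000)   # upstream test set
x_ds_test = rng.uniform(0.5, 1.5, 20000)   # downstream test set: half lies outside US support
for n in [50, 500, 5000, 50000]:           # "scale up" the upstream training data
    x_train = rng.uniform(0.0, 1.0, n)
    model = fit_histogram(x_train, labels(x_train))
    us_acc = (model(x_us_test) == labels(x_us_test)).mean()
    ds_acc = (model(x_ds_test) == labels(x_ds_test)).mean()
    print(f"n={n:6d}  US acc={us_acc:.3f}  DS acc={ds_acc:.3f}")
```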

To empirically investigate whether early saturation is an indicator of the DS distribution not being fully covered by the US (or, simply put, of distribution differences), we measure the performance of few-shot classifiers applied on top of representations from different layers of the pre-trained model. We hypothesize that the depth of the earliest layer that leads to the best performance for a given DS task is a proxy for the difference between US and DS, and an indicator of how much the DS task will benefit from scaling up the compute.

We observe that, for those DS tasks that are similar to US, such as ImageNet, the higher the representation layer, the better the performance on DS. On the contrary, for those DS tasks that saturate fast, i.e., that do not follow performance improvements on US, such as the UC-Merced land use dataset and colorectal histology (col_hist), the optimal layer is not the last one. For instance, for col_hist, if we choose the head at layer 5 or 6, we achieve better performance than at the pre-logit layer. Figure 3 presents this result.
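
A sketch of this layer-wise probing procedure is given below. It assumes a hypothetical layer_representations(model, images) helper that returns a list of pooled activations, one per block of the frozen network, and reuses the few_shot_accuracy probe sketched in Section 1.2; the tolerance and pooling choices are illustrative.

```python
import numpy as np

def best_probe_layer(model, train_images, train_labels, test_images, test_labels,
                     shots_per_class=25):
    # layer_representations is a hypothetical helper returning one pooled feature
    # array per block of the frozen network; few_shot_accuracy is the probe
    # sketched in Section 1.2.
    train_feats = layer_representations(model, train_images)
    test_feats = layer_representations(model, test_images)
    accs = np.array([
        few_shot_accuracy(tr, train_labels, te, test_labels, shots_per_class)
        for tr, te in zip(train_feats, test_feats)
    ])
    earliest_best = int(np.argmax(accs >= accs.max() - 1e-6))  # earliest layer matching the best accuracy
    return earliest_best, accs
```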

Bringing the two discussions together, performance saturation on DS happens when the pre-trained network lacks the fine-grained features required to perform well on DS due to non-overlapping distributions. As discussed in (yosinski2014transferable; neyshabur2020being), lower layers capture lower-level features that are more common across different datasets and tasks, whereas fine-grained features reside in the top layers of the network. In addition, examples that are learned in higher layers are learned later in training, with lower confidence and higher uncertainty (exampledifficultytemp). Therefore, one can get similar performance on such DS tasks when cutting the top layers of the pre-trained model, as seen in Figure 3. The most interesting point about the plots in Figure 3 is that when we overlay the DS vs US accuracy curves on the DS accuracy vs layer depth curves, they follow almost exactly the same pattern, which suggests that both are good proxies for capturing the relation between the US and DS datasets and both confirm our hypothesis from completely different angles.

Figure 4: The effect of increasing the head weight decay on the performance of upstream versus downstream (all shots). Note that not only is the optimum value of the head WD different for upstream and downstream, but the optimum value also changes across downstream tasks.

3 Discrepancies between US and DS performances: a case study

In the last section, we observed that there are cases where an increase in US accuracy does not translate to a performance improvement on DS.

It is important to understand the hyper-parameters that drive the models towards the Pareto frontier (the dashed lines in Figure 1), i.e., higher DS performance for the same or worse US performance.

In this section, we take a closer look at the effect of the head (the projection layer). We present cases where there are discrepancies between US and DS performances when we change head hyper-parameters. Specifically, we show that by changing weight decay and learning rate of the head at pre-training, we can drive the models towards the Pareto frontier line by trading US accuracy for a better DS accuracy on some tasks.

Figure 5: The effect of increasing the head learning rate on the performance of upstream versus downstream (20 shots).
Figure 6: Optimal head weight decay for each DS task for different numbers of shots. The optimum value of the head weight decay differs across DS tasks.

3.1 Effect of head weight decay

Figure 4 shows the performance on DS as we increase the US head weight decay. In this experiment, the weight decay for the rest of the network is kept at 0.01. We observe the following:

  1. For US, increasing the head weight decay up to a threshold (the optimum head WD) improves the performance on US; increasing it beyond that threshold leads to over-regularization and worse performance.

  2. The optimum value of the head WD is different for US and for different DS tasks. This means there are cases where increasing the WD on the US head deteriorates performance on US but improves it on some DS tasks. Therefore, head weight decay is an important hyper-parameter and should be optimized for DS.

  3. The optimal head weight decay for different DS tasks can be very different, i.e., if we take different DS tasks into account when tuning this hyper-parameter, we will end up with different optimal values. This is illustrated in Figure 6. In other words, there are cases where increasing or decreasing the US head WD improves performance on one DS task and degrades performance on another. Therefore, one cannot simply save a checkpoint of a model pre-trained on an upstream task and use it for all downstream tasks.

  4. The optimal weight decay for DS is usually higher than the optimal one for US.

  5. The impact of increasing the weight decay on the head is more prominent when the number of shots is lower. For example, the effect is more prominent on 1-shot performance than on 20-shot performance across all DS datasets.

Note that this phenomenon is robust to the number of training steps on US: if we train for more epochs, the trends of US and DS with respect to increasing the head WD remain the same. See Appendix C.
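
Operationally, this amounts to giving the pre-training head its own weight decay (and, analogously, its own learning rate, as studied in the next subsection). The PyTorch-style sketch below shows one way to do this with optimizer parameter groups; the attribute name model.head, the optimizer choice, and the example values are assumptions, not the paper's exact configuration.

```python
import torch

def make_optimizer(model, base_lr=0.008, base_wd=0.01, head_lr=None, head_wd=3.0):
    """Adam with a separate weight decay (and optionally learning rate) for the head.
    `model.head` is assumed to be the final projection layer; 0.008 and 0.01 follow
    the body learning rate and weight decay quoted in the text, while head_wd=3.0 is
    just an illustrative value (Table 5, for instance, uses 0.0 and 5.0 as extremes)."""
    head_params = list(model.head.parameters())
    head_ids = {id(p) for p in head_params}
    body_params = [p for p in model.parameters() if id(p) not in head_ids]
    return torch.optim.Adam([
        {"params": body_params, "lr": base_lr, "weight_decay": base_wd},
        {"params": head_params,
         "lr": base_lr if head_lr is None else head_lr,
         "weight_decay": head_wd},
    ])
```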

3.2 Effect of learning rate

Next, we look into the effect of the learning rate instead of the weight decay. We change the learning rate of the head relative to the learning rate of the rest of the network. In this experiment, the learning rate for the rest of the network is kept at 0.008. We expect decreasing the head learning rate to produce patterns similar to increasing the head weight decay, and the observations confirm this intuition. Figure 33 shows the discrepancy between DS (ImageNet and Caltech101) and US (JFT) when we change the head learning rate.

3.3 Investigations

We investigate the L2-norm of the layers as a proxy for the amount of information stored in them as we change the head WD. In this experiment, the WD for the rest of the network is kept at 0.01. We observe that as we increase the head WD on the upstream task, the norm of the weights in the higher layers increases, while it does not change much in the lower layers (see Appendix C). Figure 7 shows the sum of the norms of all layers before the head as we increase the head weight decay. We observed a similar pattern for the distance to initialization: as we increase the head WD, we do not see a change in the lower layers, but the distance to initialization increases for the higher layers.

It has been widely discussed that a network's margin (also called the prediction margin) can predict its generalization performance well (neyshabur2017exploring; bartlett2017spectrally; jiang2018predicting). We refer to the margin for a single data point as the difference between the score of the correct label and the maximum score of the other labels, and we report the average margin value over the training data. The classical notion of margin refers to the scores at the head. More recently, (jiang2018predicting) proposed the notion of margin at different layers, which normalizes the score difference by the norm of gradient differences at that layer. Because the margin correlates with how well the model separates the data at each layer, we investigate how the head margin and the pre-logit (penultimate) layer margin change as we increase the head WD. We observe that as we increase the weight decay, the pre-logit layer margin increases, while the head layer margin decreases; see Figure 7.
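
For concreteness, the head-level margin defined above can be computed as sketched below; the layer-wise, gradient-normalized margin of (jiang2018predicting) is not reproduced here.

```python
import numpy as np

def mean_margin(scores, labels):
    """Average over examples of: score of the correct class minus the highest
    score among the other classes. `scores` is (n, num_classes), `labels` is (n,)."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels)
    n = scores.shape[0]
    correct = scores[np.arange(n), labels]
    masked = scores.copy()
    masked[np.arange(n), labels] = -np.inf      # exclude the correct class from the max
    return float(np.mean(correct - masked.max(axis=1)))
```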

Figure 7: Layer norm and layer margin for US as a function of head weight decay. As we increase the head weight decay, the sum of norms of all the layers up to the head, as well as the pre-logit layer margin increases, while the head’s norm and margin decrease. These are both indicators that by increasing head weight decay we push more information down to the layers below, similar to the effect of decreasing the head learning rate.
Figure 8: Optimal weight decay as a function of rank correlation between the performance of US and DS for different DS tasks.

For US, although the head margin decreases with increasing head weight decay, which is also reflected in the performance drop on US (see Figure 4), the margin at the pre-logit layer improves. This shows that information is being pushed down from the head to the pre-logit layer.

The above two investigations show that as we increase the head weight decay, information is pushed down to the layers below the head. Moreover, these are still the top layers of the network; the effect does not propagate down to the early layers.

Next, we look into the margin on DS datasets. We note that the margin trend (calculated on training data) completely reflects the accuracy trend on DS test data. Although this is expected in classical machine learning, it is still intriguing that we observe this pattern for a large-scale deep learning model, where the margin has occasionally failed to capture generalization performance. We note that for datasets such as ImageNet that are close to JFT, the margin increases, while for datasets that saturate fast, such as Caltech101 and Cars, the margin does not change. See Appendix C for DS margin plots.

We observe that as we decrease the head learning rate, the norm of the head decreases while the sum of the norms of the other layers increases (see Appendix C). We also observe a similar pattern in the US margin and norm plots when decreasing the head learning rate as when increasing the head weight decay (see Appendix C). We note that the effects of these two interventions (increasing the head WD, decreasing the head LR) are the same: when we increase the head weight decay, as discussed above, we push the key information compressed in the network down to lower layers; when we decrease the head learning rate, we encourage the lower layers to be more active and learn more. Both lead to the same effect.

Figure 8 shows the optimal WD as a function of the rank correlation between the performance on US and DS, i.e., we take a list of checkpoints, make two ranked lists based on US and DS performance, and then calculate the correlation between the two lists. We note that for a DS task the optimal WD is high when there is a high correlation between performance on US and DS. The reason is that when the correlation is high, one wants to move all the information in the head to the lower layers (since the head is discarded for DS) and not lose any information. Since the head is removed for few-shot transfer, storing more information in the rest of the network leads to better performance on DS. But when US and DS are different, and hence uncorrelated, we do not need a high WD, as there is no information in the head that would help DS performance; one can even remove the head and some of the top layers, since there is no DS-related information in them.
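
The rank-correlation statistic can be computed as sketched below from per-checkpoint US and DS accuracies. Spearman correlation is used here because the appendix figures report Spearman correlations; whether Figure 8 uses exactly this rank correlation is an assumption.

```python
from scipy.stats import spearmanr

def us_ds_rank_correlation(us_accs, ds_accs):
    """Rank the same checkpoints by US and by DS accuracy and correlate the two rankings."""
    rho, _ = spearmanr(us_accs, ds_accs)
    return rho
```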

4 On the generalization of observed phenomena

The phenomena described in this paper are not limited to the setting reported above. In this section, we show that the observations are robust to several changes in the setup.

Number of shots:

The Pareto frontier phenomenon and the effect of the head are robust to the number of shots in the downstream task. This can be seen, for example, in Figure 6. For additional plots, see Appendix C.

Transfer vs. few-shot:

In addition to robustness to the number of shots in the few-shot setting, the phenomena reported in Sections 2 and 2.2 are consistent across both the few-shot and fine-tuning settings. Note that this is not a direct implication of the previous assertion: in the few-shot setting, we keep the network the same and only replace the head and train it for the downstream task; in the fine-tuning setting, however, the weights of the rest of the network are also updated given the training data for the downstream task. See Appendix C for additional observations in the fine-tuning setting.

Scaling of plots:

Many of the works that consider transfer accuracy or how model accuracy changes with scale (kornblith2019better; kaplan2020scaling; hernandez2021scaling) scale the accuracy by passing it through a logit transformation, logit(acc) = log(acc) − log(1 − acc), i.e., instead of plotting accuracy, they plot the logit of accuracy. The logit function (the inverse of the sigmoid function) has the drawback of being sensitive to low values: if we plot a range of values that includes values close to zero, the logit plot is mainly influenced by values between 0 and 0.15, and the larger values mostly collapse on top of each other. To mitigate this sensitivity, one can instead plot only the second term, −log(1 − acc). We considered both of these scaling options, as well as not scaling the accuracies, and observed that both phenomena presented in the paper are robust to the choice of scaling. The only difference between the scaling functions is that for some datasets the trend changes, while for many others the trend remains the same and the overall phenomena are robust. For the corresponding plots with the logit and −log(1 − acc) scalings, see Appendix C.
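
Written out explicitly, the two transformations are the logit and its second term alone:

```python
import numpy as np

def logit(acc):
    acc = np.asarray(acc, dtype=float)
    return np.log(acc) - np.log1p(-acc)     # log(acc / (1 - acc))

def neg_log_error(acc):
    acc = np.asarray(acc, dtype=float)
    return -np.log1p(-acc)                  # the second term alone: -log(1 - acc)
```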

Architecture:

It has been widely observed that in the large-data regime the role of inductive biases and architecture-specific parameters diminishes. This is also observed in (kornblith2019better), where the effect of architecture manifests only through performance on US. Therefore, we expect our results to generalize to other large-scale architectures such as ResNet 151 and EfficientNet (tan2019efficientnet), which are built from CNN blocks.

5 Discussion and Conclusion

We investigated the role of scale in few-shot and transfer learning performance on image recognition tasks and provided strong empirical evidence that scaling does not lead to a one-model-fits-all solution. One cannot hope to find one pre-trained checkpoint that performs well on all possible downstream datasets. We emphasize the importance of the observation that, when extensively considering downstream performance for a given upstream model, we hit a Pareto frontier. This helps correct an earlier belief in the community that missed the big picture.

We assert that we should refrain from focusing on the performance of only one downstream task, which usually ends up being close to the upstream task. Instead, we should make design choices that improve performance on a breadth of downstream tasks. Moreover, scaling has both monetary and environmental costs (patterson2021carbon). Here we provide guidelines on what to consider in this regard. When investing in scaling in terms of data, model parameters, and compute, we should think of an additional axis, namely data diversity, and be more observant about covering different possible distributions. One cannot guarantee that mindlessly increasing upstream data leads to the upstream data covering the downstream distribution. We need a closed loop when working on improving downstream performance: when we get close to saturation, we should focus on our data gathering practices and on how to improve data diversity.

We emphasize that our paper focuses on the image recognition task. Extending our results to the natural language domain is the subject of future work. Moreover, we investigate supervised pre-training. We hope to investigate unsupervised pre-training and the role of scale and possible discrepancies between upstream and downstream tasks in that scenario in the future.

References

Checklist

  1. For all authors…

    1. Do the main claims made in the abstract and introduction accurately reflect the paper’s contributions and scope?

    2. Did you describe the limitations of your work?

    3. Did you discuss any potential negative societal impacts of your work?

    4. Have you read the ethics review guidelines and ensured that your paper conforms to them?

  2. If you are including theoretical results…

    1. Did you state the full set of assumptions of all theoretical results?

    2. Did you include complete proofs of all theoretical results?

  3. If you ran experiments…

    1. Did you include the code, data, and instructions needed to reproduce the main experimental results (either in the supplemental material or as a URL)?

    2. Did you specify all the training details (e.g., data splits, hyperparameters, how they were chosen)?

    3. Did you report error bars (e.g., with respect to the random seed after running experiments multiple times)?

    4. Did you include the total amount of compute and the type of resources used (e.g., type of GPUs, internal cluster, or cloud provider)?

  4. If you are using existing assets (e.g., code, data, models) or curating/releasing new assets…

    1. If your work uses existing assets, did you cite the creators?

    2. Did you mention the license of the assets?

    3. Did you include any new assets either in the supplemental material or as a URL?

    4. Did you discuss whether and how consent was obtained from people whose data you’re using/curating?

    5. Did you discuss whether the data you are using/curating contains personally identifiable information or offensive content?

  5. If you used crowdsourcing or conducted research with human subjects…

    1. Did you include the full text of instructions given to participants and screenshots, if applicable? We did not use crowdsourcing and we did not conduct research with human subjects.

    2. Did you describe any potential participant risks, with links to Institutional Review Board (IRB) approvals, if applicable?

    3. Did you include the estimated hourly wage paid to participants and the total amount spent on participant compensation?

Appendix A Additional Related Work

Large scale transfer learning by pre-training on JFT [kolesnikov2019big, dosovitskiy2020, ryoo2021tokenlearner, mustafa2021supervised, tay2021omninet, puigcerver2020scalable, ngiam2018domain] or ImageNet21K [dosovitskiy2020, kolesnikov2019big, mustafa2021supervised, arnab2021vivit, puigcerver2020scalable, zhai2019large] has been done extensively. mensink2021factors considers a two-step transfer chain, where the model is pre-trained on ImageNet, fine-tuned on the source task and then transferred to the target task. Then they look into the effect of different hyper-parameters on this transfer chain. They conclude that the effect of transfer learning vanishes as the target domain size increases. This is very different from the setting we consider, that is when the size of the target domain is very small (the few-shot setting).

raghu2019transfusion investigate the performance of models pre-trained on ImageNet when they are used to transfer to medical images. They conclude that the family of smaller lightweight convolutional networks performs comparably to standard ImageNet models, despite having significantly worse accuracy on ImageNet. Hence, ImageNet performance is not predictive of medical performance. neyshabur2020being also studies transfer learning from models trained on ImageNet. They note that improved accuracy from pre-training can be achieved in fewer steps of fine-tuning than what is done in practice.

Appendix B Proof of Lemma LABEL:lemma:convex_hull

Proof.

Since

are probability values we have

for all j, . The proof follows the definition of accuracy and simple counting, as follows. Accuracy captures total number of correct predictions over total number of predictions. Let refer to accuracy of , i.e., , let refer to total number of predictions for upstream and downstream respectively. That is

where (1), (4) are due to the definition of accuracy, (2) is achieved by the construction of the randomized classifier and (3) is due to commutative property of addition. Similarly,

Putting these two together gives us

Note that this is the definition of the convex hull.

Appendix C Additional Figures

c.1 Additional Figures for Section LABEL:sec:powerlaw

Figure 9 presents a scaled version of Figure LABEL:fig:convex_hull, given the scaling of downstream accuracies, discussed in Section 4.

Figure 9: Performance of upstream vs downstream (8 different tasks) based on more than 3K different ViT models with different configurations, pre-trained on JFT and evaluated on few-shot (25 shots), where downstream accuracies are scaled using logit.
Figure 10: The performance of downstream (8 different tasks) vs upstream based on more than 1.4k different Vision Transformers, 90 MLP-Mixers, and 233 ResNets, with different configurations. The models are pre-trained on ImageNet21K and evaluated in few-shot settings (25 shots). As the upstream performance improves, the downstream performance saturates. Even if US accuracy reaches 100%, the DS accuracy may not reach 100% and saturates at a lower value. We observe a non-linear relationship between upstream and downstream accuracy and model the relationship with a power law function to predict the downstream performance given the upstream performance. The plot also shows a horizontal line, which is the predicted downstream accuracy if upstream accuracy reaches 100%.
Figure 11: Effect of the number of shots and the DS task on the values of the parameters of the power law curves, when the upstream task is JFT. We note that the DS task affects all parameters, while the number of shots mostly impacts a subset of them.
Figure 12: Effect of the number of shots and the DS task on the values of the parameters of the power law curves, when the upstream task is ImageNet21k. We note that the DS task affects all parameters, while the number of shots mostly impacts a subset of them.
c.1.1 Details of the experimental setup for fitting Equation LABEL:eqn:power_law

Figures 13 and 14 illustrate the fitted curves for the convex hull and for all data points in the US vs DS accuracy plots, respectively. We use the points with lower US accuracies (0.0–0.45) as fitting data and those with higher US accuracies (0.45–0.50) as held-out data to fit Equation LABEL:eqn:power_law. For the convex hull fit, we first compute the convex hull of the given data points and then fit the curve to it. In Figures 15 and 16, we compare the fitted curves when we fit Equation LABEL:eqn:power_law to all data points or to the convex hull of all data points, for 1 and 25 shots.
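
The sketch below mirrors this fitting procedure: fit on the low-US-accuracy points and evaluate the prediction on the held-out high-US-accuracy points. Since Equation LABEL:eqn:power_law is not reproduced in this extract, a generic saturating power law, acc_DS ≈ alpha − beta * (1 − acc_US)^gamma, is assumed here purely for illustration.

```python
import numpy as np
from scipy.optimize import curve_fit

def saturating_power_law(us_acc, alpha, beta, gamma):
    # Assumed illustrative form: DS accuracy approaches alpha as US accuracy approaches 1.
    return alpha - beta * (1.0 - us_acc) ** gamma

def fit_and_evaluate(us_acc, ds_acc, threshold=0.45):
    us_acc, ds_acc = np.asarray(us_acc), np.asarray(ds_acc)
    fit_mask = us_acc < threshold                         # low-US-accuracy points are used for fitting
    params, _ = curve_fit(saturating_power_law, us_acc[fit_mask], ds_acc[fit_mask],
                          p0=(0.9, 0.5, 1.0), maxfev=10000)
    fit_err = np.mean(np.abs(saturating_power_law(us_acc[fit_mask], *params) - ds_acc[fit_mask]))
    pred_err = np.mean(np.abs(saturating_power_law(us_acc[~fit_mask], *params) - ds_acc[~fit_mask]))
    return params, fit_err, pred_err
```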

To measure the sensitivity of the predictive power of the fitted equation to the number of samples, we conduct the experiment with different numbers of data points sampled randomly (uniformly across all data points), and for each sample size we repeat the experiment 10 times (taking a new sample for each trial). We use the points with the higher US accuracies as held-out data. The prediction error captures the difference between the power law prediction and the observed DS accuracy. The fitting error captures the difference of the power law values from the points used to calculate the power law parameters. We plot the fitting error and the prediction error as the number of samples changes. Figures 17, 18, 19, 20, 21 and 22 depict the mean prediction error and mean fitting error for each sample size, as well as their standard deviation across the 10 trials.

Figure 13: Power law curves fitted to the points on the convex hull corresponding to the experiment results from Figure LABEL:fig:pareto_complete. We plot the predictions from the power law curve at the higher US accuracies against the ground truth (prediction target) and observe that the power law curve closely predicts the performance on DS.
Figure 14: Power law curves fitted to all points corresponding to the experiment results from Figure LABEL:fig:pareto_complete. We plot the predictions from the power law curve at the higher US accuracies against the ground truth (prediction target) and observe that the power law curve closely predicts the performance on DS.
Figure 15: Comparing fitted curves when we use the convex hull (Figure 13) vs when we use all samples (Figure 14), when the number of shots is 1.
Figure 16: Comparing fitted curves when we use the convex hull (Figure 13) vs when we use all samples (Figure 14), when the number of shots is 25.
Figure 17: Effect of sample size when fitting the power law to the convex hull of the samples on the average fitting error.
Figure 18: Effect of sample size when fitting the power law to the convex hull of the samples on the average prediction error.
Figure 19: Effect of sample size when fitting the power law to the convex hull of the samples on the average fitting error.
Figure 20: Effect of sample size when fitting the power law to the convex hull of the samples on the average prediction error.
Figure 21: Effect of sample size when fitting the power law to all samples on the average fitting error.
Figure 22: Effect of sample size when fitting the power law to all samples on the average prediction error.

DS         US           Parameter  Correlation with number of shots
caltech    ImageNet21K  K          -0.777892
caltech    ImageNet21K  –          -0.582066
caltech    ImageNet21K  –          -0.845368
caltech    JFT          K          -0.620526
caltech    JFT          –           0.259305
caltech    JFT          –          -0.762856
cars       ImageNet21K  K           0.720391
cars       ImageNet21K  –           0.960490
cars       ImageNet21K  –          -0.737273
cars       JFT          K          -0.976599
cars       JFT          –          -0.034033
cars       JFT          –          -0.809016
cifar100   ImageNet21K  K          -0.918914
cifar100   ImageNet21K  –           0.683485
cifar100   ImageNet21K  –          -0.587304
cifar100   JFT          K          -0.934455
cifar100   JFT          –           0.707966
cifar100   JFT          –          -0.754030
col_hist   ImageNet21K  K          -0.756297
col_hist   ImageNet21K  –           0.947101
col_hist   ImageNet21K  –          -0.104776
col_hist   JFT          K          -0.534724
col_hist   JFT          –           0.466138
col_hist   JFT          –          -0.848960
dtd        ImageNet21K  K          -0.892400
dtd        ImageNet21K  –           0.810935
dtd        ImageNet21K  –          -0.532797
dtd        JFT          K           0.392218
dtd        JFT          –          -0.751290
dtd        JFT          –          -0.806674
imagenet   ImageNet21K  K          -0.923350
imagenet   ImageNet21K  –           0.464193
imagenet   ImageNet21K  –          -0.590325
imagenet   JFT          K           0.618935
imagenet   JFT          –          -0.866692
imagenet   JFT          –          -0.847294
pets       ImageNet21K  K          -0.895292
pets       ImageNet21K  –           0.707198
pets       ImageNet21K  –           0.936508
pets       JFT          K           0.398171
pets       JFT          –           0.937076
pets       JFT          –          -0.003738
uc_merced  ImageNet21K  K          -0.986538
uc_merced  ImageNet21K  –           0.942120
uc_merced  ImageNet21K  –          -0.724245
uc_merced  JFT          K          -0.821492
uc_merced  JFT          –           0.743757
uc_merced  JFT          –           0.019906

Table 1: Correlation of each parameter of the power law curves with the number of shots.

c.2 Additional Figures for Section LABEL:sec:controlled_exp

Figure 23 shows the effect of scaling model, data, and compute on all downstream tasks. This is a complete version of Figure 2 in the main paper that includes all 25 different downstream tasks.

Figure 23: Effect of controlled scale up with respect to the model size (number of parameters), data size (the portion of the pre-trained data), and compute (epochs) on 25 different downstream tasks in the few-shot setup (20-shots).
Figure 24: Fitting the scaling law to the points plotted in Figure 23 and depicting the error incurred in predicting DS accuracy.
DS Root squared error
birds 0.154270
caltech 0.102052
camelyon 0.402138
cars 0.197948
cifar10 0.235078
cifar100 0.242331
clevr_count 0.093481
clevr_distance 0.093481
col_hist 0.155221
dmlab 0.126028
dsprites_location 0.059326
dsprites_orientation 0.059326
dtd 0.088551
eurosat 0.258027
flowers 0.141492
imagenet 0.188222
kitti 0.438465
pets 0.141252
resisc45 0.188155
retinopathy 0.446441
smallnorb_azimuth 0.049473
smallnorb_elevation 0.049473
sun397 0.085017
svhn 0.082023
uc_merced 0.158118
Table 2: Root squared error of predicted DS accuracy when fitting the points in Figure C.2 with Equation LABEL:eqn:power_law (Table 2 provides the results for all downstream datasets).

c.3 Additional Figures for Section 2.2

Figure 25: The overlay of the convex hull of ImageNet DS-vs-US plot on the DS-vs-US plots of all DS tasks from Figure LABEL:fig:convex_hull. The US task is ImageNet21K. We observe that the best-performing ImageNet models perform very similarly to the best-performing models in several DS tasks but not all DS tasks. Moreover, as the US performance increases, the gap between best performing ImageNet models and best performing DS task models reduces significantly.

Figure 26 depicts Spearman correlation between accuracies on different downstream tasks. Figure 27 shows Spearman correlation between accuracies on different downstream tasks and the upstream task.

Figure 26: Spearman correlation between accuracies on different downstream tasks.
Figure 27: Spearman correlation between accuracies on different downstream tasks and the upstream task, based on more than 3K different ViT models with different configurations.

Figure 28 illustrates the quality of representations from different layers on all downstream tasks. This is a complete version of Figure 3 in Section 2.2 that includes all 25 different downstream tasks.

Figure 28: Investigating the effect of choosing representations from different layers on downstream task performance, overlaid with the effect of scaling (model, data, and compute) on downstream performance, when the upstream task is JFT. The red triangles in the plots are the performance on the downstream task when the representation used in few-shot learning is taken from different layers of the model. The green circles in the plots overlay the US versus DS performance of different experiments from Figure 23 on each task. Here we sketch the plots for 25 downstream tasks.
Figure 29: Investigating the effect of choosing representations from different layers on downstream task performance, overlaid with US vs DS performance, when the upstream task is ImageNet21K. The red triangles in the plots are the performance on downstream tasks when the representation used in few-shot learning is taken from different layers of the model. The green circles in the plots overlay the DS versus US performance of different experiments from Figure 10 on each task. Red triangles use the x-axis on the bottom and the green circles use the x-axis on the top. We note that for DS tasks similar to the US, such as ImageNet, the higher the representation layer, the better the performance on DS. On the contrary, for DS tasks that saturate fast, such as UC-Merced and col_hist, the optimal layer is not the last one and the model can be cut at lower layers, leading to better performance.

c.4 Additional Figures for Section LABEL:sec:headWDLR

Figure 30 illustrates the effect of increasing head weight decay on all downstream tasks. It is the complete version of Figure 4 in the main paper that includes all downstream tasks.

Figure 30: The effect of increasing head weight decay in the performance of DS-vs-US (all shots, all datasets).

Figure 31 shows the best head weight decay for all downstream tasks. This figure is a complete version of Figure 6.

Figure 31: Optimum head weight decay for different downstream tasks and for different number of shots.

Figure 32 illustrates the effect of changing the head weight decay on all downstream tasks when we train longer (for 14 epochs instead of the 7 reported in Figure 4). The trends are consistent across different numbers of epochs as well as different numbers of shots.

Figure 32: The effect of changing head weight decay when trained for 7 or 14 epochs for different number of shots.
Figure 33: The effect of increasing head learning rate in performance of upstream (JFT) versus performance of downstream (ImageNet1k and Caltech101).
Figure 34: Optimum head learning rate for different downstream tasks and for different number of shots.
Figure 35: L2 Norm of different layers of the ViT model for different values of head weight decay.

Figure 36 illustrates the effect of increasing head weight decay on pre-logit layer margin for all downstream tasks.

Figure 36: The effect of increasing head weight decay on the pre-logit layer margin for downstream (all shots, all datasets). In this plot, the L2 term for the downstream few-shot classifiers is set to 4096.
Figure 37: L2 Norm of different layers of the ViT model for different values of head learning rate.

Appendix D Experiment setup

d.1 Training details

For the controlled experiments, we train all models using Adam [kingma2014adam]. The batch size is fixed across experiments. The default weight decay is 0.01, unless a different value is mentioned in the description of an experiment. For the learning rate, we use 0.008 (with a smaller value for the large models), with a linear decay schedule and a linear warmup.
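
An illustrative linear warmup plus linear decay schedule matching this description is sketched below; the warmup length and total step count are placeholders, since the exact values are not given here.

```python
def linear_warmup_linear_decay(step, base_lr=0.008, warmup_steps=10_000, total_steps=100_000):
    """Learning rate at `step`: ramp up linearly, then decay linearly to zero.
    The warmup and total step counts here are placeholders, not the paper's values."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 1.0 - progress)
```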

d.2 Datasets

Table 3 summarizes the datasets used in our experiments.

Dataset Description Reference
ImageNet 1.28M labelled natural images. [deng2009-imagenet]
Caltech101 The task consists in classifying pictures of objects (101 classes plus a background clutter class), including animals, airplanes, chairs, or scissors. The image size varies, but it typically ranges from 200-300 pixels per edge. http://www.vision.caltech.edu/Image_Datasets/Caltech101/
CIFAR-10 The task consists in classifying natural images (10 classes, with 6000 training images each). Some examples include apples, bottles, dinosaurs, and bicycles. The image size is 32x32. https://www.cs.toronto.edu/~kriz/cifar.html
CIFAR-100 The task consists in classifying natural images (100 classes, with 500 training images each). Some examples include apples, bottles, dinosaurs, and bicycles. The image size is 32x32. https://www.cs.toronto.edu/~kriz/cifar.html
DTD The task consists in classifying images of textural patterns (47 classes, with 120 training images each). Some of the textures are banded, bubbly, meshed, lined, or porous. The image size ranges between 300x300 and 640x640 pixels. [cimpoi2014describing]
Pets The task consists in classifying pictures of cat and dog breeds (37 classes with around 200 images each), including Persian cat, Chihuahua dog, English Setter dog, or Bengal cat. Images dimensions are typically 200 pixels or larger. https://www.robots.ox.ac.uk/~vgg/data/pets/
Sun397 The Sun397 task is a scenery benchmark with 397 classes and, at least, 100 images per class. Classes have a hierarchy structure and include cathedral, staircase, shelter, river, or archipelago. The images are (colour) 200x200 pixels or larger. https://vision.princeton.edu/projects/2010/SUN/
Flowers102 The task consists in classifying images of flowers present in the UK (102 classes, with between 40 and 248 training images per class). Azalea, Californian Poppy, Sunflower, or Petunia are some examples. Each image dimension has at least 500 pixels. https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
SVHN This task consists in classifying images of Google’s street-view house numbers (10 classes, with more than 1000 training images each). The image size is 32x32 pixels. http://ufldl.stanford.edu/housenumbers/
CLEVR/count CLEVR is a visual question and answer dataset designed to evaluate algorithmic visual reasoning. We use just the images from this dataset, and create a synthetic task by setting the label equal to the number of objects in the images. [johnson2017clevr]
CLEVR/distance Another synthetic task we create from CLEVR consists of predicting the depth of the closest object in the image from the camera. The depths are bucketed into size bins. [johnson2017clevr]
Retinopathy The Diabetic Retinopathy dataset consists of image-label pairs with high-resolution retina images, and labels that indicate the presence of Diabetic Retinopathy (DR) in a 0-4 scale (No DR, Mild, Moderate, Severe, or Proliferative DR). https://www.kaggle.com/c/diabetic-retinopathy-detection/data
birds image dataset with photos of 200 bird species (mostly North American). http://www.vision.caltech.edu/visipedia/CUB-200.html
Table 3: Summary of datasets used in our experiments, part I
Dataset Description Reference
Patch Camelyon The Patch Camelyon dataset contains 327,680 images of histopathologic scans of lymph node sections. The classification task consists in predicting the presence of metastatic tissue in a given image (i.e., two classes). All images are 96x96 pixels. [teh2019metric]
Resisc45 The Remote Sensing Image Scene Classification (RESISC) dataset is a scene classification task from remote sensing images. There are 45 classes, containing 700 images each, including tennis court, ship, island, lake, parking lot, sparse residential, or stadium. The image size is RGB 256x256 pixels. [cheng2017remote]
EuroSAT The task consists in classifying Sentinel-2 satellite images into 10 different types of land use (Residential, Industrial, River, Highway, etc). The spatial resolution corresponds to 10 meters per pixel, and the image size is 64x64 pixels. [helber2019eurosat]
dSprites/location The dSprites dataset was originally designed to assess disentanglement properties of unsupervised learning algorithms. In particular, each image is a 2D shape where six factors are controlled: color, shape, scale, rotation, and (x,y) center coordinates. Images have 64x64 black-and-white pixels. This task consists in predicting the x (horizontal) coordinate of the object. The locations are bucketed into 16 bins. https://github.com/deepmind/dsprites-dataset/
dSprites/orientation We create another task from dSprites consisting in predicting the orientation of each object, bucketed into 16 bins. https://github.com/deepmind/dsprites-dataset/
SmallNORB/azimuth The Small NORB dataset contains images of 3D-toys from 50 classes, including animals, human figures, airplanes, trucks, and cars. The image size is 640x480 pixels. In this case, we define labels depending on the azimuth (angle of horizontal deviation), in intervals of 20 degrees (18 classes). [lecun2004learning]
SmallNORB/elevation Another synthetic task we create from Small NORB consists in predicting the elevation in the image. There are 9 classes, corresponding to 9 different elevations ranging from 30 to 70 degrees, in intervals of 5 degrees [lecun2004learning]
DMLab The DMLab (DeepMind Lab) is a set of control environments focused on 3D navigation and puzzle-solving tasks. The DMLab dataset contains frames observed by the agent acting in the DeepMind Lab environment, which are annotated by the distance between the agent and various objects present in the environment. The goal is to evaluate the ability of a visual model to reason about distances from the visual input in 3D environments. The DMLab dataset consists of 360x480 color images in 6 classes. The classes are {close, far, very far} × {positive reward, negative reward}. [beattie2016deepmind]
KITTI The KITTI task consists in predicting the (binned) depth to the vehicle (car, van, or truck) in the image. There are 4 bins / classes. [geiger2013vision]
ColHist Classification of textures in colorectal cancer histology. Each example is a 150 x 150 x 3 RGB image of one of 8 classes. https://www.tensorflow.org/datasets/catalog/colorectal_histology
UC Merced A 21-class land use image dataset. https://usdahsi.ucmerced.edu/datasets/landuse.html
cars The Cars dataset contains 16,185 images of 196 classes of cars. The data is split into 8,144 training images and 8,041 testing images, where each class has been split roughly in a 50-50 split. Classes are typically at the level of Make, Model, Year, e.g. 2012 Tesla Model S or 2012 BMW M3 coupe. http://ai.stanford.edu/~jkrause/cars/car_dataset.html
Table 4: Summary of datasets used in our experiments, part II

Appendix E Transfer to VTAB

In this section, we provide additional experiments for the transfer learning scenario, using VTAB as the downstream benchmark. Figure 38 shows the effect of the controlled experiments, scaling up the model size, data size, and compute, in the transfer learning setting on VTAB. Note that these experiments are based on the standard VTAB setup [zhai2019large], which uses only 1000 examples for each dataset to reflect the performance of transfer learning under a reasonable labelling budget in downstream tasks. We use the same objective function for both upstream and downstream (sigmoid cross-entropy) and update all of the pre-trained parameters during fine-tuning. Table 5 presents results of models pre-trained with different head weight decays in the transfer setup on the VTAB test set. In this setup, we use SGD with momentum for fine-tuning all the parameters of the model using the training set of the downstream task.

Figure 38: Effect of controlled scale up with respect to the model size (number of parameters), data size (the portion of the pre-trained data), and compute (epochs) on tasks in VTAB-1K benchmark (1000 training example per task) in the transfer setup.
Dataset HWD=0.0 HWD=5.0
caltech101 0.89 0.91
cifar100 0.51 0.79
clevr-count 0.72 0.42
clevr-distance 0.65 0.49
diabetic-retinopathy-detection 0.74 0.72
dmlab 0.42 0.36
dsprites-location 0.68 0.56
dsprites-orientation 0.58 0.58
dtd 0.66 0.72
eurosat 0.94 0.95
kitti 0.76 0.70
oxford-flowers102 0.98 0.99
oxford-iiit-pet 0.93 0.94
patch-camelyon 0.78 0.77
resisc45 0.82 0.83
smallnorb-azimuth 0.27 0.22
smallnorb-elevation 0.47 0.36
sun397 0.42 0.65
svhn-cropped 0.72 0.60
VTAB-Natural 0.69 0.78
VTAB-Specialized 0.82 0.82
VTAB-Structured 0.57 0.46
VTAB-ALL 0.69 0.68
Table 5: Results of a ViT-B/32 in the fine-tuning (transfer) setup on the VTAB-1K benchmark, when pre-trained with different head weight decays. Note that the head WD values selected for these experiments are 0.0 and 5.0, which are rather extreme values, to highlight the effect on different datasets.