FiT: Parameter Efficient Few-shot Transfer Learning for Personalized and Federated Image Classification

Modern deep learning systems are increasingly deployed in situations such as personalization and federated learning where it is necessary to support i) learning on small amounts of data, and ii) communication efficient distributed training protocols. In this work we develop FiLM Transfer (FiT) which fulfills these requirements in the image classification setting. FiT uses an automatically configured Naive Bayes classifier on top of a fixed backbone that has been pretrained on large image datasets. Parameter efficient FiLM layers are used to modulate the backbone, shaping the representation for the downstream task. The network is trained via an episodic fine-tuning protocol. The approach is parameter efficient which is key for enabling few-shot learning, inexpensive model updates for personalization, and communication efficient federated learning. We experiment with FiT on a wide range of downstream datasets and show that it achieves better classification accuracy than the state-of-the-art Big Transfer (BiT) algorithm at low-shot and on the challenging VTAB-1k benchmark, with fewer than 1% of the updateable parameters. Finally, we demonstrate the parameter efficiency of FiT in distributed low-shot applications including model personalization and federated learning where model update size is an important performance metric.


1 Introduction

With the success of the commercial application of deep learning in many fields such as computer vision (Schroff et al., 2015), natural language processing (Brown et al., 2020), speech recognition (Xiong et al., 2018), and language translation (Wu et al., 2016), an increasing number of models are being trained on central servers and then deployed on remote devices, often to personalize a model to a specific user's needs. Personalization requires models that can be updated inexpensively by minimizing the number of parameters that need to be stored and/or transmitted, and it frequently calls for few-shot learning methods as the amount of training data from an individual user may be small (Massiceti et al., 2021). At the same time, for privacy, security, and performance reasons, it can be advantageous to use federated learning, where a model is trained on an array of remote devices, each with different data, which share gradient or parameter updates instead of training data with a central server (McMahan et al., 2017). In the federated learning setting, in order to minimize communication cost with the server, it is beneficial to have models with a small number of parameters that need to be updated in each training round conducted by remote clients. The amount of training data available to the clients is often small, again necessitating few-shot learning approaches.

Few-shot learning approaches fall broadly into two camps – meta-learning (Hospedales et al., 2020) and transfer learning (fine-tuning) (Yosinski et al., 2014). It is useful to characterise these approaches in terms of shared and updateable parameters. Shared parameters do not change as the model is retrained or updated, while updateable parameters are those that are either recomputed or learned as the model is updated or retrained. From a statistical perspective, shared parameters capture similarities between datasets, while updateable parameters efficiently capture the differences. In general, meta-learning approaches (Hospedales et al., 2020) to few-shot learning are trained in a multi-task manner and as a result share a large number of parameters and only update a small subset when adapting to a new task (Requeima et al., 2019). While meta-learners can achieve good accuracy on datasets that are similar to what they are meta-trained on, their accuracy suffers when tested on datasets that are significantly different (Dumoulin et al., 2021). Transfer learning algorithms can often outperform meta-learners, especially on diverse datasets and even at low-shot (Dumoulin et al., 2021; Tian et al., 2020). While some transfer learning algorithms are able to minimize the number of updateable parameters by only fine-tuning the final or a small subset of layers, the state of the art Big Transfer (BiT) (Dumoulin et al., 2021; Kolesnikov et al., 2019) algorithm requires every parameter in a large network to be updated.

In this work, we focus on designing deep learning network architectures and associated training protocols that allow image classification models to be updated with only a small subset of the total model parameters, without sacrificing prediction performance when there is only a small number of training examples available. This leads to reduced storage and transmission costs for updating personalized models on remote devices (Massiceti et al., 2021), distributed training in federated learning (McMahan et al., 2017), and efficient ensemble realization (Havasi et al., 2020), among other applications. To realize our goal of small model updates, we pursue a transfer learning approach that takes advantage of image classification backbones that have been pretrained on large upstream datasets (Kolesnikov et al., 2019). We freeze the backbone parameters such that they are shared when fine-tuning on a downstream dataset. Parameter efficient FiLM (Perez et al., 2018) adapter layers are used to modulate the backbone, shaping the representation for the downstream task. For a ResNet50 (He et al., 2016a), the FiLM layer parameter count is less than 0.05% of the overall model size, yet allows expressive adaptation. The last, novel piece of the proposed system is the use of an automatically configured Naive Bayes final layer classifier which outperforms the usual linear layer head. The system is trained end-to-end with an episodic fine-tuning protocol. We call this approach FiLM Transfer or FiT. We experiment with FiT on a wide range of downstream datasets and show that it achieves better classification accuracy at low-shot with two orders of magnitude fewer updateable parameters when compared to BiT (Kolesnikov et al., 2019) and competitive accuracy when more data is available. We also demonstrate the benefits of FiT on a low-shot real-world model personalization application and in a demanding few-shot federated learning scenario. Our contributions:

  • A parameter and data efficient network architecture for low-shot transfer learning consisting of an automatically configured Naive Bayes final layer classifier and parameter efficient FiLM layers that are used to adapt a fixed, pretrained backbone to a downstream dataset;

  • A meta-learning inspired episodic training protocol for low-shot fine-tuning requiring no data augmentation, no regularization, and a minimal set of hyper-parameters;

  • Superior classification accuracy at low-shot on standard downstream datasets and on the challenging VTAB-1k benchmark while using fewer than 1% of the updateable parameters when compared to the leading transfer learning method BiT;

  • Demonstration of superior parameter efficiency without sacrificing classification accuracy in distributed low-shot applications including model personalization and federated learning where model update size is an important performance metric.

2 FiLM Transfer (FiT)

In this section we detail the FiT algorithm focusing on the few-shot image classification scenario.

Preliminaries

We denote input images $x \in \mathbb{R}^{w \times h \times c}$, where $w$ is the width, $h$ the height, and $c$ the number of channels, and image labels $y \in \{1, \dots, C\}$, where $C$ is the number of image classes. Assume that we have access to a model $f(x)$ that outputs $p(y \mid x)$ and is comprised of a feature extractor backbone $b_\theta : \mathbb{R}^{w \times h \times c} \rightarrow \mathbb{R}^{d}$ with parameters $\theta$ that has been pretrained on a large upstream dataset such as ImageNet (Russakovsky et al., 2015), where $d$ is the output feature dimension, and a final layer classifier or head with weights $\psi$. Let $\mathcal{D} = \{(x_n, y_n)\}_{n=1}^{N}$ be the downstream dataset that we wish to fine-tune the model to.

FiT Backbone

For the backbone, we freeze the parameters $\theta$ to the values learned during upstream pretraining and add Feature-wise Linear Modulation (FiLM) (Perez et al., 2018) layers with parameters $\phi$ at strategic points within $b_\theta$. A FiLM layer scales and shifts the activations $a_{ij}$ arising from the $j$th channel of a convolutional layer in the $i$th block of the backbone as $\mathrm{FiLM}(a_{ij}; \gamma_{ij}, \beta_{ij}) = \gamma_{ij} a_{ij} + \beta_{ij}$, where $\gamma_{ij}$ and $\beta_{ij}$ are scalars. The set of FiLM parameters $\phi = \{\gamma_{ij}, \beta_{ij}\}$ is learned during fine-tuning. We add FiLM layers following the middle 3×3 convolutional layer in each ResNetV2 (He et al., 2016b) block and also one at the end of the backbone prior to the head. Fig. A.1 (left) illustrates a FiLM layer operating on a convolutional layer, and Fig. A.1 (right) illustrates how a FiLM layer can be added to a ResNetV2 network block. FiLM layers can be similarly added to EfficientNet based backbones. A key advantage of FiLM layers is that they enable expressive feature adaptation while adding only a small number of parameters (Perez et al., 2018). For example, in a ResNet50 with a FiLM layer in every block, the set of FiLM parameters accounts for only 11,648 parameters, which is fewer than 0.05% of the parameters in $b_\theta$. We show in Section 4 that FiLM layers allow the model to adapt to a broad class of datasets.
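To make the FiLM mechanism concrete, the following minimal PyTorch-style sketch implements a per-channel scale-and-shift layer; the class and parameter names are illustrative rather than taken from the released FiT code.

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Per-channel scale-and-shift: FiLM(a_ij) = gamma_j * a_ij + beta_j.

    A minimal sketch: only gamma and beta are trainable, so a frozen
    backbone gains just 2 * num_channels extra parameters per FiLM layer.
    """

    def __init__(self, num_channels: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_channels))   # scale, initialized to 1
        self.beta = nn.Parameter(torch.zeros(num_channels))   # shift, initialized to 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, channels, height, width)
        return self.gamma.view(1, -1, 1, 1) * x + self.beta.view(1, -1, 1, 1)
```

Because $\gamma$ is initialized to 1 and $\beta$ to 0, an untrained FiLM layer acts as the identity, which matches the 1-shot behaviour described later in this section where the FiLM parameters are left at their initial values.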

FiT Head

For the head, we use a Gaussian Naive Bayes classifier which can be automatically configured directly from data. If the data is normally distributed, Naive Bayes is generally more effective than logistic regression (Hastie et al., 2009; Efron, 1975), and especially so in the small data setting (Pohar et al., 2004). Also, when compared to the usual linear layer head, Naive Bayes has fewer free parameters to learn, and in Section 4 we show that it yields superior results. The probability of classifying a test point $x^*$ is:

$$p(y^* = c \mid x^*) = \frac{\pi_c\, \mathcal{N}\!\left(b_{\theta,\phi}(x^*); \mu_c, \Sigma_c\right)}{\sum_{c'=1}^{C} \pi_{c'}\, \mathcal{N}\!\left(b_{\theta,\phi}(x^*); \mu_{c'}, \Sigma_{c'}\right)}, \qquad (1)$$

where $\pi_c = N_c / N$, $\mu_c$, and $\Sigma_c$ are the maximum likelihood estimates, $N_c$ is the number of examples of class $c$ in $\mathcal{D}$, and $\mathcal{N}(\cdot\,; \mu_c, \Sigma_c)$ is a multivariate Gaussian over $b_{\theta,\phi}(x)$ with mean $\mu_c$ and covariance $\Sigma_c$.
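As an illustration of how the head is configured directly from data, the NumPy sketch below estimates the class priors and means from embedded support examples and evaluates Eq. 1; the function names and the use of a single shared covariance are assumptions made for the sketch, not the exact released implementation.

```python
import numpy as np
from scipy.stats import multivariate_normal

def configure_head(features, labels, num_classes):
    """Maximum likelihood class priors and means for the Gaussian Naive Bayes head."""
    priors = np.array([(labels == c).mean() for c in range(num_classes)])
    means = np.stack([features[labels == c].mean(axis=0) for c in range(num_classes)])
    return priors, means

def predict_proba(query_features, priors, means, cov):
    """Eq. (1) with one shared covariance matrix (the LDA-style case, for simplicity)."""
    log_joint = np.stack(
        [np.log(priors[c]) + multivariate_normal.logpdf(query_features, means[c], cov)
         for c in range(len(priors))],
        axis=-1,
    )
    log_joint -= log_joint.max(axis=-1, keepdims=True)   # numerical stability
    unnormalized = np.exp(log_joint)
    return unnormalized / unnormalized.sum(axis=-1, keepdims=True)
```

For FiT-QDA, the shared `cov` argument would instead be a per-class covariance inside the loop; the three covariance choices are described next.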

Estimating the mean $\mu_c$ for each class is straightforward and incurs a total storage cost of $Cd$. However, estimating the covariance $\Sigma_c$ for each class is challenging when the number of examples per class is small and the embedding dimension $d$ of the backbone is large. In addition, the storage cost for the covariance matrices may be prohibitively high if $C$ is large. In this work, we use three different approximations to the covariance in place of $\Sigma_c$ in Eq. 1:

  • Quadratic Discriminant Analysis (QDA) (Fisher, 1936; Duda et al., 2012): $\Sigma_c = \lambda_1 \hat{\Sigma}_c + \lambda_2 I$;

  • Linear Discriminant Analysis (LDA) (Fisher, 1936; Duda et al., 2012): $\Sigma_c = \Sigma = \lambda_1 \hat{\Sigma}_{\mathcal{D}} + \lambda_2 I$;

  • ProtoNets (Snell et al., 2017): $\Sigma_c = I$; or equivalently, there is no covariance, the class representation is parameterized only by $\mu_c$, and the classifier logits are formed by computing the squared Euclidean distance between the feature representation of a test point and each of the class means.

In the above, $\hat{\Sigma}_c$ is the computed covariance of the examples in class $c$ in $\mathcal{D}$, $\hat{\Sigma}_{\mathcal{D}}$ is the computed covariance of all the examples in $\mathcal{D}$ assuming they arise from a single Gaussian with a single mean, the $\lambda$ are weights learned during training, and the identity matrix $I$ is used as a regularizer. The primary difference between QDA and LDA is that QDA computes and stores a covariance matrix for each class in the dataset, while LDA shares a single covariance matrix across all classes. The number of shared and updateable model parameters for the three FiT variants as well as the BiT algorithm are detailed in Table A.1. Using the BiT-M-R50x1 (Kolesnikov et al., 2019) backbone on a 10-way classification task, BiT shares no parameters and all 23.52M parameters are updateable. FiT shares the 23.50M backbone parameters, and FiT-QDA, FiT-LDA, and FiT-ProtoNets have 21.01M, 32,140, and 32,128 updateable parameters, respectively. The backbone parameters $\theta$ are known and fixed from pretraining, but we must learn the FiLM parameters $\phi$ and the covariance weights $\lambda$ as outlined in the next section.

FiT Training

We learn $\phi$ and $\lambda$ via fine-tuning, but employ an approach often used in meta-learning referred to as episodic training (Vinyals et al., 2016) to form the training batches. Note that we require 'training' data to compute $\pi_c$, $\mu_c$, and $\Sigma_c$ to configure the head, and a 'test' set to optimize $\phi$ and $\lambda$ via gradient ascent. Thus, from the downstream dataset $\mathcal{D}$, we derive two sets – $\mathcal{D}_S$ and $\mathcal{D}_Q$. If $\mathcal{D}$ is sufficiently large ($N \geq 1000$), we set $\mathcal{D}_S = \mathcal{D}_Q = \mathcal{D}$. Otherwise, in the few-shot scenario, we randomly split $\mathcal{D}$ into $\mathcal{D}_S$ and $\mathcal{D}_Q$ such that the number of examples or shots in each class is roughly equal in both partitions and there is at least one example of each class in both. Refer to Algorithm A.1 for details. For each training iteration, we sample a task consisting of a support set $\tau_S$ drawn from $\mathcal{D}_S$ and a query set $\tau_Q$ drawn from $\mathcal{D}_Q$. First, $\tau_S$ is formed by randomly choosing a subset of classes from the range of available classes in $\mathcal{D}_S$. Second, the number of shots to use for each selected class is randomly selected from the range of available examples in each class of $\mathcal{D}_S$, with the goal of keeping the examples per class as equal as possible. Third, $\tau_Q$ is formed by using the classes selected for $\tau_S$ and all available examples from $\mathcal{D}_Q$ in those classes, up to a limit of 2000 examples. Refer to Algorithm A.2 for details. Episodic task sampling is crucial to achieving the best classification accuracy. If all of $\mathcal{D}_S$ and $\mathcal{D}_Q$ are used for every iteration, overfitting occurs, limiting accuracy (see Table A.3 and Table A.4). The support set $\tau_S$ is then used to compute $\pi_c$, $\mu_c$, and $\Sigma_c$, and we then use $\tau_Q$ to train $\phi$ and $\lambda$ with maximum likelihood. We optimize the following objective:

$$\mathcal{L}(\phi, \lambda) = \sum_{(x, y) \in \tau_Q} \log p\!\left(y \mid x; \phi, \lambda, \tau_S\right). \qquad (2)$$

FiT training hyperparameters include a learning rate and the number of training iterations. For the transfer learning experiments in Section 4 these are set to constant values across all datasets and do not need to be tuned on a validation set. We do not augment the training data, even in the 1-shot case. When there is only one shot per class, we leave the FiLM layer parameters at their initial values of $\gamma = 1$ and $\beta = 0$ and predict as described in the next paragraph. In Section 4 we show this can yield results that surpass those obtained when augmentation and training steps are applied to the 1-shot data.
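Putting the pieces together, a minimal training-loop sketch (PyTorch-style) might look like the following; `split_dataset`, `sample_task`, `configure_naive_bayes`, and `naive_bayes_log_probs` are hypothetical helpers standing in for Algorithms A.1, A.2, and the head of Eq. 1, and the optimizer choice and hyperparameter values are placeholders rather than the paper's settings.

```python
import torch

def fit_episodic_finetune(backbone, film_params, cov_weights, dataset,
                          num_iterations=400, lr=3.5e-3):
    """Sketch of FiT fine-tuning: only the FiLM parameters and the covariance
    weights are optimized; the pretrained backbone weights stay frozen."""
    support_pool, query_pool = split_dataset(dataset)            # Algorithm A.1 (few-shot case)
    optimizer = torch.optim.Adam(list(film_params) + list(cov_weights), lr=lr)

    for _ in range(num_iterations):
        support, query = sample_task(support_pool, query_pool)   # Algorithm A.2

        # Configure the Naive Bayes head from the support set (Eq. 1).
        support_feats = backbone(support.images)                 # FiLM-modulated features
        priors, means, covs = configure_naive_bayes(support_feats, support.labels, cov_weights)

        # Maximize the query-set log-likelihood (Eq. 2).
        query_feats = backbone(query.images)
        log_probs = naive_bayes_log_probs(query_feats, priors, means, covs)
        loss = -log_probs[torch.arange(len(query.labels)), query.labels].mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return film_params, cov_weights
```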

FiT Prediction

Once the FiLM parameters $\phi$ and covariance weights $\lambda$ have been learned, we use all of $\mathcal{D}$ as the support set to compute $\pi_c$, $\mu_c$, and $\Sigma_c$ for each class, and then Eq. 1 can be used to make a prediction for any unseen test input.
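Continuing with the hypothetical helpers from the training sketch above, deployment-time prediction reduces to one head-configuration pass over the full downstream dataset followed by Eq. 1:

```python
# Deployment: configure the head once from all of D, then classify new images.
features = backbone(downstream_images)                 # frozen backbone + learned FiLM layers
priors, means, covs = configure_naive_bayes(features, downstream_labels, cov_weights)

test_feats = backbone(test_images)
predictions = naive_bayes_log_probs(test_feats, priors, means, covs).argmax(dim=-1)
```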

3 Related Work

We take inspiration from residual adapters (Rebuffi et al., 2017, 2018), where parameter efficient adapters are inserted into a ResNet with frozen pretrained weights. The adapter parameters and the final layer linear classifier are then learned via fine-tuning. FiT differs from residual adapters by focusing on the few-shot scenario, and by using FiLM layers that have fewer parameters than a residual adapter, a Naive Bayes head, and an episodic training protocol. CNAPs (Requeima et al., 2019) also uses a frozen, pretrained backbone with FiLM layers added. The FiLM parameters and a linear head are generated by meta-trained hypernetworks. Simple CNAPs (Bateni et al., 2020) improves on CNAPs by using a Mahalanobis distance based classifier that uses a blend of class specific and task specific covariance matrices. FiT-LDA greatly improves over Simple CNAPs (Bronskill et al., 2021) in terms of classification accuracy and updateable parameters on the VTAB-1k benchmark as a result of using fine-tuning instead of amortization networks and using a single dataset specific covariance matrix as opposed to a class specific one.

FLUTE (Triantafillou et al., 2021), during multi-task training, jointly learns backbone parameters along with a set of dataset specific FiLM layer parameters. When asked to classify a novel input at test time, the backbone parameters are frozen and a Blender network generates initial values for the FiLM parameters using a combination of those that it has learned in the training phase and the test task. The initial FiLM parameters are then improved via fine-tuning through a nearest centroid final layer classifier. FiT differs from FLUTE in that (i) there is no initial multi-task learning phase; (ii) it uses more fine-tuning iterations; and (iii) it uses a Naive Bayes head. The work of Mudrakarta et al. (2019) has the same aim as this work. Instead of fine-tuning FiLM layers that are added to a pretrained network, batch normalization weights are tuned. While this work is similar in spirit, unlike FiT, it does not focus on the few-shot regime, employs a linear head, and uses a more conventional fine-tuning protocol.

4 Experiments

In this section, we evaluate the classification accuracy and updateable parameter efficiency of FiT in a series of challenging benchmarks and application scenarios. First, we compare different variations of FiT to Big Transfer (BiT) (Kolesnikov et al., 2019), a state-of-the-art transfer learning algorithm, on several standard downstream datasets in the few-shot regime. Second, we evaluate FiT against BiT on the challenging VTAB-1k benchmark (Zhai et al., 2019), where BiT has been shown to outperform all meta-learners (Dumoulin et al., 2021; Bronskill et al., 2021). Third, we show how FiT can be used in a personalization scenario on the ORBIT (Massiceti et al., 2021) dataset, where a smaller updateable model is an important evaluation metric. Finally, we apply FiT to a few-shot federated learning scenario where minimizing the number and size of parameter updates is a key requirement. In addition, we introduce a metric, Relative Model Update Size (RMUS), which is the ratio of the number of updateable parameters in one model to that of another. Training and evaluation details are in Section A.5. Source code for the experiments can be found at: https://github.com/cambridge-mlg/fit.
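Concretely, for a method $m$ evaluated against a reference method $r$ (BiT in the experiments below), RMUS is simply the ratio of updateable parameter counts:

```latex
\mathrm{RMUS}(m, r) = \frac{\#\,\text{updateable parameters of } m}{\#\,\text{updateable parameters of } r},
\qquad \mathrm{RMUS} < 1 \;\Rightarrow\; \text{a smaller model update than the reference.}
```

For example, FiT-LDA on a 10-way task updates 32,140 parameters against BiT's 23,520,832, giving RMUS ≈ 0.0014, which is the value reported in Table A.2.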

4.1 Few-shot Results

Figure 1: FiT-LDA outperforms BiT at low-shot. Classification accuracy as a function of Relative Model Update Size (RMUS – lower is better) and shots per class for FiT-LDA and BiT on four downstream datasets. Classification accuracy is on the vertical axis and is the average of 3 runs with different data sampling seeds. RMUS is relative to the number of updateable parameters for BiT and uses a base 10 log scale. The dot size from smallest to largest indicates the number of shots per class - 1, 2, 5, 10, and All. A tabular version that includes results for additional variants is in Table A.2.

Fig. 1 shows the classification accuracy as a function of RMUS for FiT-LDA and BiT on four downstream datasets (CIFAR10, CIFAR100 (Krizhevsky et al., 2009), Flowers (Nilsback and Zisserman, 2008), and Pets (Parkhi et al., 2012)) that were used to evaluate the performance of BiT (Kolesnikov et al., 2019). Table A.2 contains complete tabular results with additional variants of FiT and BiT. All methods use the BiT-M-R50x1 (Kolesnikov et al., 2019) backbone that has been pretrained on the ImageNet-21K (Russakovsky et al., 2015) dataset. The key observations from Fig. 1 are:

  • For 10 or fewer shots per class (except 1-shot on CIFAR100), FiT-LDA outperforms BiT, often by a large margin.

  • On 3 out of 4 datasets, FiT-LDA outperforms BiT even when all of $\mathcal{D}$ is used for fine-tuning.

  • FiT-LDA outperforms BiT despite BiT having more than 100 times as many updateable parameters.

  • To avoid overfitting when $\mathcal{D}$ is small, Table A.3 indicates that it is better to split $\mathcal{D}$ into two disjoint partitions $\mathcal{D}_S$ and $\mathcal{D}_Q$, and that $\tau_S$ and $\tau_Q$ should be randomly sub-sampled as opposed to using all of the data in each training iteration.

We also evaluate BiT-FiLM, a variant of BiT that uses the same training protocol as the standard version of BiT, but the backbone weights are frozen and FiLM layers are added in the same manner as FiT. The FiLM parameters and the linear head weights are learned during training. The results are shown in Table A.2 and the key observations are:

  • In general, at low-shot, the standard version of BiT outperforms BiT-FiLM. However, as the number of shots increases, especially when training on all of $\mathcal{D}$, BiT-FiLM matches the classification accuracy of standard BiT.

  • The above implies that FiLM layers have sufficient capacity to accurately fine-tune to downstream datasets, but the FiT head and training protocol are needed to achieve superior results.

  • While the accuracy of FiT-QDA and FiT-LDA is similar, the storage requirements for a covariance matrix for each class makes QDA impractical if model update size is an important consideration.

  • The accuracy of FiT-ProtoNets is slightly lower than FiT-LDA, but often betters BiT, despite BiT having more than 100 times as many updateable parameters.

The datasets used in this section were similar in content to the dataset used for pretraining and the performance of FiT-QDA and FiT-LDA was similar, indicating that the covariance per class was not that useful for these datasets. In the next section, we test on a wider variety of datasets, many of which differ greatly from the upstream data.

4.2 VTAB-1k Results

The VTAB-1k benchmark (Zhai et al., 2019) is a low to medium-shot transfer learning benchmark that consists of 19 datasets grouped into three distinct categories (natural, specialized, and structured). From each dataset, 1000 examples are drawn at random from the training split to use for the downstream dataset $\mathcal{D}$. After fine-tuning, the entire test split is used to evaluate classification performance. Table 1 shows the classification accuracy of the three variants of FiT and BiT using the BiT-M-R50x1 backbone for all variants. The key observations from our results are:

  • Both FiT-QDA and FiT-LDA outperform BiT on VTAB-1k.

  • The FiT-QDA variant has the best overall performance, showing that the class covariance is important to achieve superior results on datasets that differ from those used in upstream pretraining (e.g. the structured category of datasets). However, the updateable parameter cost is high.

  • FiT-LDA utilizes two orders of magnitude fewer updateable parameters compared to BiT, making it the preferred approach.

  • Table A.4 indicates that it is best to use all of $\mathcal{D}$ for each of $\mathcal{D}_S$ and $\mathcal{D}_Q$ (i.e. no split) and that $\tau_S$ and $\tau_Q$ should be randomly sub-sampled as opposed to using all of the data in each iteration.

| Dataset | Classes | BiT Acc.↑ | BiT RMUS↓ | FiT-QDA Acc.↑ | FiT-QDA RMUS↓ | FiT-LDA Acc.↑ | FiT-LDA RMUS↓ | FiT-ProtoNets Acc.↑ | FiT-ProtoNets RMUS↓ |
|---|---|---|---|---|---|---|---|---|---|
| Caltech101 (Fei-Fei et al., 2006) | 102 | 88.0±0.2 | 1.0 | 90.3±0.8 | 9.04 | 90.4±0.8 | 0.0093 | 89.6±0.2 | 0.0093 |
| CIFAR100 (Krizhevsky et al., 2009) | 100 | 70.1±0.1 | 1.0 | 74.1±0.1 | 8.86 | 74.2±0.5 | 0.0091 | 73.9±0.3 | 0.0091 |
| Flowers102 (Nilsback and Zisserman, 2008) | 102 | 98.6±0.0 | 1.0 | 99.1±0.1 | 9.04 | 99.0±0.1 | 0.0093 | 98.6±0.0 | 0.0093 |
| Pets (Parkhi et al., 2012) | 37 | 88.4±0.2 | 1.0 | 91.0±0.3 | 3.30 | 90.5±0.0 | 0.0037 | 90.8±0.2 | 0.0037 |
| Sun397 (Xiao et al., 2010) | 397 | 48.0±0.1 | 1.0 | 51.1±0.7 | 34.29 | 51.6±0.5 | 0.0339 | 51.5±1.4 | 0.0339 |
| SVHN (Netzer et al., 2011) | 10 | 73.0±0.2 | 1.0 | 75.1±1.3 | 0.89 | 74.2±0.9 | 0.0014 | 50.1±2.2 | 0.0014 |
| DTD (Cimpoi et al., 2014) | 47 | 72.7±0.3 | 1.0 | 70.9±0.1 | 4.18 | 70.9±0.1 | 0.0046 | 68.2±1.1 | 0.0046 |
| EuroSAT (Helber et al., 2019) | 10 | 95.3±0.1 | 1.0 | 95.6±0.1 | 0.89 | 95.1±0.1 | 0.0014 | 93.8±0.1 | 0.0014 |
| Resics45 (Cheng et al., 2017a) | 45 | 85.9±0.0 | 1.0 | 82.6±0.1 | 4.01 | 82.5±0.2 | 0.0044 | 77.0±0.0 | 0.0044 |
| Patch Camelyon (Veeling et al., 2018) | 2 | 69.3±0.8 | 1.0 | 80.7±1.2 | 0.18 | 82.5±0.7 | 0.0007 | 79.9±0.2 | 0.0007 |
| Retinopathy (Kaggle and EyePacs, 2015) | 5 | 77.2±0.6 | 1.0 | 70.4±0.1 | 0.45 | 66.2±0.5 | 0.0009 | 57.9±0.3 | 0.0009 |
| CLEVR-count (Johnson et al., 2017) | 8 | 54.6±7.1 | 1.0 | 87.1±0.3 | 0.71 | 85.6±0.9 | 0.0012 | 88.7±0.3 | 0.0012 |
| CLEVR-dist (Johnson et al., 2017) | 6 | 47.9±0.8 | 1.0 | 58.1±0.8 | 0.54 | 56.1±0.8 | 0.0010 | 58.3±0.6 | 0.0010 |
| dSprites-loc (Matthey et al., 2017) | 16 | 91.6±1.1 | 1.0 | 77.1±2.0 | 1.43 | 74.8±1.4 | 0.0019 | 68.6±2.4 | 0.0019 |
| dSprites-ori (Matthey et al., 2017) | 16 | 65.9±0.3 | 1.0 | 56.7±0.3 | 1.43 | 51.3±0.7 | 0.0019 | 34.2±0.8 | 0.0019 |
| SmallNORB-azi (LeCun et al., 2004) | 18 | 18.7±0.2 | 1.0 | 18.9±0.6 | 1.61 | 16.2±0.1 | 0.0021 | 13.5±0.1 | 0.0021 |
| SmallNORB-elev (LeCun et al., 2004) | 9 | 25.8±0.9 | 1.0 | 40.4±0.2 | 0.80 | 37.0±0.6 | 0.0013 | 35.0±0.6 | 0.0013 |
| DMLab (Beattie et al., 2016) | 6 | 47.1±0.1 | 1.0 | 43.8±0.3 | 0.54 | 41.6±0.6 | 0.0010 | 39.3±0.3 | 0.0010 |
| KITTI-dist (Geiger et al., 2013) | 4 | 80.1±0.9 | 1.0 | 77.5±0.7 | 0.36 | 77.7±0.8 | 0.0008 | 75.3±0.2 | 0.0008 |
| All | | 68.3 | | 70.6 | | 69.3 | | 65.5 | |
| Natural | | 77.0 | | 78.8 | | 78.7 | | 74.7 | |
| Specialized | | 81.9 | | 82.3 | | 81.5 | | 77.1 | |
| Structured | | 54.0 | | 57.5 | | 55.0 | | 51.6 | |

Table 1: FiT outperforms BiT on VTAB-1k. Classification accuracy and Relative Model Update Size (RMUS) for all three variants of FiT and BiT on the VTAB-1k benchmark. The backbone is BiT-M-R50x1. Accuracy figures are percentages and the ± sign indicates the 95% confidence interval over 3 runs. Bold type indicates the highest scores (within the confidence interval).

4.3 Personalization

In our experiments, we use ORBIT (Massiceti et al., 2021), a real-world few-shot video dataset recorded by people who are blind/low-vision. A blind or vision-impaired user collects a series of short videos on their smartphone of objects that they would like to recognize. The collected videos and associated labels are then uploaded to a central service to train a personalized classification model for that user. Once trained, the personalized model is downloaded to the user’s smartphone. The initial model download would include all the model parameters, both shared and updateable. However, models with a smaller number of updateable parameters are preferred in order to save model storage space on the central server and in transmitting any updated models to a user. The goal is to take a backbone pretrained on ImageNet (Deng et al., 2009) or other large-scale dataset and construct a personalized model for a user using only their individual data. We follow the object recognition benchmark task proposed by the authors, which tests a personalized model on two different video types: clean where only a single object is present and clutter where that object appears within a realistic, multi-object scene.

In Table 2, we compare FiT-LDA to several competitive transfer learning and meta-learning methods. We use the LDA variant of FiT, as it achieves higher accuracy than the ProtoNets variant while using far fewer updateable parameters than QDA. For transfer learning, we include FineTuner (Yosinski et al., 2014), which freezes the weights in the backbone and fine-tunes only the linear classifier head on an individual's data. For meta-learning approaches, we include ProtoNets (Snell et al., 2017) and Simple CNAPs (Bateni et al., 2020), which are meta-trained on Meta-Dataset (Dumoulin et al., 2021). Training and evaluation details are in Section A.5.2. For this comparison, we show frame and video accuracy, averaged over all the videos from all tasks across all test users. We also report the number of shared and individual updateable parameters required to be stored or transmitted. The key observations from our results are:

  • FiT-LDA outperforms competitive meta-learning methods, Simple CNAPs and ProtoNets.

  • FiT-LDA also outperforms FineTuner in terms of the video accuracy and performs within error bars of it in terms of the frame accuracy.

  • The number of individual parameters for FiT-LDA is far fewer than in Simple CNAPs and is of the same order of magnitude as FineTuner and ProtoNets.

  • Compared to a linear head, Naive Bayes reduces the size of the optimization space as there are fewer parameters to learn (only the FiLM parameters $\phi$ and the covariance weights $\lambda$), making FiT-LDA more suitable for the few-shot setting.

| Model | Clean frame acc. | Clean video acc. | Clutter frame acc. | Clutter video acc. | Shared params | Per-user params (average) |
|---|---|---|---|---|---|---|
| FineTuner (Yosinski et al., 2014) | 78.1 (2.0) | 85.9 (2.3) | 63.1 (1.8) | 66.9 (2.4) | 4.01M | 0.01M |
| Simple CNAPs (Bateni et al., 2020) | 73.1 (2.1) | 80.7 (2.6) | 61.6 (1.8) | 67.1 (2.4) | 5.67M | 7.63M |
| ProtoNets (Snell et al., 2017) | 71.6 (2.2) | 78.9 (2.7) | 63.0 (1.8) | 67.3 (2.4) | 4.01M | 0.01M |
| FiT-LDA | 81.8 (1.8) | 90.6 (1.9) | 65.7 (1.8) | 70.7 (2.3) | 4.01M | 0.03M |

Table 2: FiT outperforms competitive methods on ORBIT. Average accuracy (95% confidence interval in parentheses) over test tasks. Shared is the number of parameters shared among all users. Per-user params (average) is the mean number of individual per-user parameters over the ORBIT dataset.

4.4 Few-shot Federated Learning

In this section, we show how FiT can be used in the few-shot federated learning setting, where training data are split between client nodes, e.g. mobile phones or personal laptops, and each client has only a handful of samples. Model training is performed via numerous communication rounds between a server and clients. In each round, the server selects a fraction of clients making updates and then sends the current model parameters to these clients. For data privacy reasons, clients update models locally using only their personal data and then send their parameter updates back to the server. Finally, the server aggregates information from all the clients, updates the shared model parameters, and proceeds to the next round until convergence. In this setting, models with a smaller number of updatable parameters are preferred in order to reduce the client-server communication cost which is typically bandwidth-limited.

Experiments

For our experiments, we use CIFAR100 (Krizhevsky et al., 2009), a relatively large-scale dataset compared to those commonly used to benchmark federated learning methods (Reddi et al., 2021; Shamsian et al., 2021). We employ the basic FedAvg (McMahan et al., 2016) algorithm. We train all models for a fixed number of communication rounds, with a fixed number of clients participating per round and a fixed number of local update steps per client; each client holds a randomly sampled subset of classes, fixed before the start of training. During an update, a client initializes their local model with the most recent FiLM parameters received from the server and then performs several training steps as described in Section 2, using only their data. The Naive Bayes head allows a client to construct a local classifier at each update step, eliminating the need to initialize a shared classifier and transmit its parameters at each training round. In contrast, using a linear head in this setting would entail additional communication costs, as it would be passed at each client-server interaction. We use a ProtoNets head for simplicity, although QDA and LDA heads could also be used. After training, the global FiLM parameters are transmitted to all clients. More specific training and evaluation details are in Section A.5.3.
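A minimal sketch of one FedAvg round restricted to the FiLM parameters is shown below; the client selection, the local FiT update (`local_fit_update`), and the parameter containers are simplified stand-ins for the actual protocol.

```python
import copy
import random
import torch

def fedavg_film_round(server_film_state, clients, clients_per_round, local_steps):
    """One communication round of FedAvg in which only FiLM parameters are exchanged.

    server_film_state: dict of FiLM tensors (the only state sent over the network).
    clients: objects exposing local_fit_update(film_state, steps) -> new FiLM state.
    """
    selected = random.sample(clients, clients_per_round)
    client_states = []
    for client in selected:
        # Each client starts from the current global FiLM parameters and runs
        # a few FiT fine-tuning steps on its private data only.
        local_state = client.local_fit_update(copy.deepcopy(server_film_state), local_steps)
        client_states.append(local_state)

    # Server: average the returned FiLM parameters (plain FedAvg, equal weights).
    new_state = {}
    for name in server_film_state:
        new_state[name] = torch.stack([s[name] for s in client_states], dim=0).mean(dim=0)
    return new_state
```

Because only the FiLM tensors (roughly 10⁴ values for a ResNet50 backbone) cross the network, the per-round communication cost is orders of magnitude smaller than exchanging full BiT model updates.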

We evaluate FiT in two scenarios, global and personalized. In the global setting, the aim is to construct a global classifier and report accuracy on the CIFAR100 (Krizhevsky et al., 2009) test set. We assume that the server knows which classes belong to each client, and constructs a shared classifier by taking a mean over prototypes produced by clients for a particular class. In the personalized scenario, we test a personalized model on test classes present in the individual’s training set and then report the mean accuracy over all clients. As opposed to the personalization experiments on ORBIT, where a personalized model is trained using only the client’s local data, in this experiment we initialize a personalized model with the learned global FiLM parameters and then construct a ProtoNets classifier with individual’s data. Thus, the goal of the personalized setting is to estimate how advantageous distributed learning can be for training FiLM layers to build personalized few-shot models.
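For the global evaluation, the server-side construction of the shared ProtoNets classifier can be sketched as follows, assuming each client reports per-class prototypes computed with the final global FiLM parameters; the function name and data layout are illustrative.

```python
import torch
from collections import defaultdict

def build_global_prototypes(client_prototypes):
    """client_prototypes: list of dicts mapping class id -> prototype tensor.

    The server averages prototypes of the same class across the clients that
    hold that class, yielding one global class mean for the shared classifier.
    """
    buckets = defaultdict(list)
    for protos in client_prototypes:
        for class_id, proto in protos.items():
            buckets[class_id].append(proto)
    return {class_id: torch.stack(plist).mean(dim=0) for class_id, plist in buckets.items()}
```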

To evaluate the FiT model in a distributed learning setup, we define baselines which form upper and lower bounds on model performance. For the global scenario, we take a FiT model simultaneously trained on all available data as the upper bound baseline. To get the lower bound baseline, we train a FiT model for each client with their local data, then average the FiLM parameters of these individual models and construct a global ProtoNets classifier using the resulting FiLM parameters. The upper bound is therefore standard batch training, the performance of which we hope federated learning can approach. The lower bound is a simplistic version of federated learning with a single communication round, which federated averaging should improve over. For the personalized setting, the upper bound baseline is as in the global scenario, from which we form a personalized classifier by taking the subset of classes belonging to a client from the global 100-way classifier. The lower bound baseline is set to a FiT model trained for each client individually. The upper bound is again standard batch training and the lower bound is derived from locally trained models which do not share information and therefore should be improved upon by federated learning.

Results

Fig. 2 shows global and personalized classification accuracy as a function of communication cost for different numbers of clients and shots per client. By communication cost we mean the number of parameters transmitted during training. The key observations from our results are:

  • In the global setting, the federated learning model is only slightly worse than the upper bound baseline, while outperforming the lower bound model, often by a large margin. This shows that FiT can be efficiently used in distributed learning settings with different configurations.

  • In the personalized scenario, for a sufficient number of clients, the gap between the federated learning model and the upper bound model is significantly reduced as the number of shots increases. Distributed training strongly outperforms the lower bound baseline, surprisingly even in the case of clients with disjoint classes. This provides empirical evidence that collaborative distributed training can be helpful for improving personalized models in the few-shot data regime.

  • The low communication cost per round and relatively fast empirical convergence of FiT result in a parameter efficient training protocol, requiring orders of magnitude fewer parameters to be transferred during the whole training phase than full-model updates would. In contrast, if BiT were used for federated learning, around 23.5M parameters would be sent at each communication round, yielding an enormous communication cost.

In Section A.3.4, we show that distributed training of a FiT model can be efficiently used to learn from more extreme, non-natural image datasets like Quickdraw (Jongejan et al., 2016). Although the number of communication rounds must be increased for efficient transfer to Quickdraw, FiT still has orders of magnitude lower communication cost than BiT.

Figure 2: Global and personalized classification accuracy as a function of communication cost over training rounds for different numbers of clients and shots per client on CIFAR100. Classification accuracy is on the vertical axis and is averaged over runs with different data sampling seeds. The color of the line indicates the number of shots per class. The solid line shows the federated learning model, while dashed and dotted lines indicate the upper and lower bound baselines, respectively.

5 Discussion

In this work, we proposed FiT, a parameter and data efficient few-shot transfer learning system that allows image classification models to be updated with only a small subset of the total model parameters. We demonstrated that FiT can outperform BiT, a state-of-the-art transfer learning method at low shot while using one hundred times fewer updateable parameters. We also showed the efficiency benefits of employing FiT in model personalization and federated learning applications. There has been considerable work on compressing models (Cheng et al., 2017b) and designing more parameter efficient networks (Tan and Le, 2019, 2021; Sandler et al., 2018) to reduce the model parameter count. These lines of research are complementary to FiT. Model compression can be used in conjunction with our work by compressing the subset of updateable parameters. Similarly, parameter efficient networks can serve as the backbones of our classification systems. We leave the combination of these technologies and FiT for future work.

Limitations

The main limitation of this work is that it is computationally expensive and much slower to adapt to new tasks compared to meta-learning methods that can adapt with a single forward pass through a network (Requeima et al., 2019) or a small number of gradient steps (Finn et al., 2017). Thus FiT may be inappropriate for certain time critical applications and potentially use more energy than competitive approaches.

Societal Impact

Image classification methods, including the work presented here, have the potential to be beneficial to society as they are readily applicable to medical image analysis, remote sensing for environmental work, and as we demonstrated in Section 4.3, helping blind users find their personal items. Conversely, the same methods could be deployed in an adverse manner such as in military or police surveillance systems. Our system for improving the parameter efficiency of models has the potential benefits of lowering energy, bandwidth, and storage costs and our federated learning approach may provide benefits in protecting user privacy. Despite the improvements presented in this work, image classification methods remain imperfect and can potentially capture biases that violate fairness principles. As a result, our method should be applied judiciously.

6 Acknowledgements

We thank Aristeidis Panos and Siddharth Swaroop for providing helpful and insightful comments. This work has been performed using resources provided by the Cambridge Tier-2 system operated by the University of Cambridge Research Computing Service https://www.hpc.cam.ac.uk funded by EPSRC Tier-2 capital grant EP/P020259/1.

Funding Transparency Statement

Funding in direct support of this work: Aliaksandra Shysheya, John Bronskill, Massimiliano Patacchiola and Richard E. Turner are supported by an EPSRC Prosperity Partnership EP/T005386/1 between the EPSRC, Microsoft Research and the University of Cambridge.

References

  • P. Bateni, R. Goyal, V. Masrani, F. Wood, and L. Sigal (2020) Improved few-shot visual classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 14493–14502. Cited by: §A.5.2, §A.5.2, §3, §4.3, Table 2.
  • C. Beattie, J. Z. Leibo, D. Teplyashin, T. Ward, M. Wainwright, H. Küttler, A. Lefrancq, S. Green, V. Valdés, A. Sadik, et al. (2016) Deepmind lab. arXiv preprint arXiv:1612.03801. Cited by: Table A.4, Table 1.
  • J. F. Bronskill, D. Massiceti, M. Patacchiola, K. Hofmann, S. Nowozin, and R. E. Turner (2021) Memory efficient meta-learning with large images. In Thirty-Fifth Conference on Neural Information Processing Systems, External Links: Link Cited by: §A.5.2, §A.5.2, §3, §4.
  • T. Brown, B. Mann, N. Ryder, M. Subbiah, J. D. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, et al. (2020) Language models are few-shot learners. Advances in neural information processing systems 33, pp. 1877–1901. Cited by: §1.
  • G. Cheng, J. Han, and X. Lu (2017a) Remote sensing image scene classification: benchmark and state of the art. Proceedings of the IEEE 105 (10), pp. 1865–1883. Cited by: Table A.4, Table 1.
  • Y. Cheng, D. Wang, P. Zhou, and T. Zhang (2017b) A survey of model compression and acceleration for deep neural networks. arXiv preprint arXiv:1710.09282. Cited by: §5.
  • M. Cimpoi, S. Maji, I. Kokkinos, S. Mohamed, and A. Vedaldi (2014) Describing textures in the wild. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 3606–3613. Cited by: Table A.4, Table 1.
  • J. Deng, W. Dong, R. Socher, L. Li, K. Li, and L. Fei-Fei (2009) ImageNet: a large-scale hierarchical image database. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Cited by: §A.5.2, §4.3.
  • R. O. Duda, P. E. Hart, and D. G. Stork (2012) Pattern classification. John Wiley & Sons. Cited by: 1st item, 2nd item.
  • V. Dumoulin, N. Houlsby, U. Evci, X. Zhai, R. Goroshin, S. Gelly, and H. Larochelle (2021) Comparing transfer and meta learning approaches on a unified few-shot classification benchmark. arXiv preprint arXiv:2104.02638. Cited by: §A.5.2, §1, §4.3, §4.
  • B. Efron (1975) The efficiency of logistic regression compared to normal discriminant analysis. Journal of the American Statistical Association 70 (352), pp. 892–898. Cited by: §2.
  • L. Fei-Fei, R. Fergus, and P. Perona (2006) One-shot learning of object categories. IEEE transactions on pattern analysis and machine intelligence 28 (4), pp. 594–611. Cited by: Table A.4, Table 1.
  • C. Finn, P. Abbeel, and S. Levine (2017) Model-agnostic meta-learning for fast adaptation of deep networks. In 34th, pp. 1126–1135. Cited by: §5.
  • R. A. Fisher (1936) The use of multiple measurements in taxonomic problems. Annals of eugenics 7 (2), pp. 179–188. Cited by: 1st item, 2nd item.
  • A. Geiger, P. Lenz, C. Stiller, and R. Urtasun (2013) Vision meets robotics: the kitti dataset. The International Journal of Robotics Research 32 (11), pp. 1231–1237. Cited by: Table A.4, Table 1.
  • T. Hastie, R. Tibshirani, and J. Friedman (2009) The elements of statistical learning: data mining, inference, and prediction. Springer Science & Business Media. Cited by: §2.
  • M. Havasi, R. Jenatton, S. Fort, J. Z. Liu, J. Snoek, B. Lakshminarayanan, A. M. Dai, and D. Tran (2020) Training independent subnetworks for robust prediction. arXiv preprint arXiv:2010.06610. Cited by: §1.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016a) Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778. Cited by: Figure A.1, §1.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016b) Identity mappings in deep residual networks. In European conference on computer vision, pp. 630–645. Cited by: §2.
  • P. Helber, B. Bischke, A. Dengel, and D. Borth (2019) Eurosat: a novel dataset and deep learning benchmark for land use and land cover classification. IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing 12 (7), pp. 2217–2226. Cited by: Table A.4, Table 1.
  • T. Hospedales, A. Antoniou, P. Micaelli, and A. Storkey (2020) Meta-learning in neural networks: a survey. arXiv preprint arXiv:2004.05439. Cited by: §1.
  • J. Johnson, B. Hariharan, L. Van Der Maaten, L. Fei-Fei, C. Lawrence Zitnick, and R. Girshick (2017) Clevr: a diagnostic dataset for compositional language and elementary visual reasoning. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2901–2910. Cited by: Table A.4, Table 1.
  • J. Jongejan, H. Rowley, T. Kawashima, J. Kim, and N. Fox-Gieg (2016) The quick, draw! - a.i. experiment. Note: https://quickdraw.withgoogle.com/data Cited by: §4.4.
  • Kaggle and EyePacs (2015) Kaggle diabetic retinopathy detection. Note: https://www.kaggle.com/c/diabetic-retinopathy-detection/data Cited by: Table A.4, Table 1.
  • D. Kingma and J. Ba (2015) Adam: a method for stochastic optimization. In 3rd, Cited by: §A.5.1.
  • A. Kolesnikov, L. Beyer, X. Zhai, J. Puigcerver, J. Yung, S. Gelly, and N. Houlsby (2019) Big transfer (BiT): general visual representation learning. arXiv preprint arXiv:1912.11370. Cited by: §A.5.1, §A.5.3, §1, §1, §2, §4.1, §4.
  • A. Kolesnikov, L. Beyer, X. Zhai, J. Puigcerver, J. Yung, S. Gelly, and N. Houlsby (2020) Official repository for the "Big Transfer (BiT): General Visual Representation Learning" paper. GitHub. Note: https://github.com/google-research/big_transfer Cited by: §A.5.1, §A.5.1.
  • A. Krizhevsky, G. Hinton, et al. (2009) Learning multiple layers of features from tiny images. Cited by: Table A.2, Table A.3, Table A.4, §4.1, §4.4, §4.4, Table 1.
  • Y. LeCun, F. J. Huang, and L. Bottou (2004) Learning methods for generic object recognition with invariance to pose and lighting. In Proceedings of the 2004 IEEE Computer Society Conference on Computer Vision and Pattern Recognition, 2004. CVPR 2004., Vol. 2, pp. II–104. Cited by: Table A.4, Table 1.
  • D. Massiceti, L. Zintgraf, J. Bronskill, L. Theodorou, M. T. Harris, E. Cutrell, C. Morrison, K. Hofmann, and S. Stumpf (2021) ORBIT: a real-world few-shot dataset for teachable object recognition. In Proceedings of the IEEE/CVF International Conference on Computer Vision, Cited by: §A.5.2, §A.5.2, §1, §1, §4.3, §4.
  • L. Matthey, I. Higgins, D. Hassabis, and A. Lerchner (2017) Dsprites: disentanglement testing sprites dataset. Cited by: Table A.4, Table 1.
  • B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas (2017) Communication-efficient learning of deep networks from decentralized data. In Artificial intelligence and statistics, pp. 1273–1282. Cited by: §1, §1.
  • H. B. McMahan, E. Moore, D. Ramage, and B. A. y Arcas (2016) Federated learning of deep networks using model averaging. CoRR abs/1602.05629. External Links: Link, 1602.05629 Cited by: §4.4.
  • P. K. Mudrakarta, M. Sandler, A. Zhmoginov, and A. Howard (2019) K for the price of 1: parameter efficient multi-task and transfer learning. In International Conference on Learning Representations, External Links: Link Cited by: §3.
  • Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng (2011) Reading digits in natural images with unsupervised feature learning. Cited by: Table A.4, Table 1.
  • M. Nilsback and A. Zisserman (2008) Automated flower classification over a large number of classes. In 2008 Sixth Indian Conference on Computer Vision, Graphics & Image Processing, pp. 722–729. Cited by: Table A.2, Table A.3, Table A.4, §4.1, Table 1.
  • O. M. Parkhi, A. Vedaldi, A. Zisserman, and C. Jawahar (2012) Cats and dogs. In 2012 IEEE conference on computer vision and pattern recognition, pp. 3498–3505. Cited by: Table A.2, Table A.3, Table A.4, §4.1, Table 1.
  • E. Perez, F. Strub, H. De Vries, V. Dumoulin, and A. Courville (2018) FiLM: visual reasoning with a general conditioning layer. In 32nd, Cited by: §1, §2.
  • M. Pohar, M. Blas, and S. Turk (2004) Comparison of logistic regression and linear discriminant analysis: a simulation study. Metodoloski zvezki 1 (1), pp. 143. Cited by: §2.
  • S. Rebuffi, H. Bilen, and A. Vedaldi (2017) Learning multiple visual domains with residual adapters. In Advances in Neural Information Processing Systems, pp. 506–516. Cited by: §3.
  • S. Rebuffi, H. Bilen, and A. Vedaldi (2018) Efficient parametrization of multi-domain deep neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 8119–8127. Cited by: §3.
  • S. J. Reddi, Z. Charles, M. Zaheer, Z. Garrett, K. Rush, J. Konečný, S. Kumar, and H. B. McMahan (2021) Adaptive federated optimization. In International Conference on Learning Representations, External Links: Link Cited by: §4.4.
  • J. Requeima, J. Gordon, J. Bronskill, S. Nowozin, and R. E. Turner (2019) Fast and flexible multi-task classification using conditional neural adaptive processes. In 33rd, pp. 7957–7968. Cited by: §1, §3, §5.
  • O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. Bernstein, A. C. Berg, and L. Fei-Fei (2015) ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV) 115 (3), pp. 211–252. External Links: Document Cited by: §A.5.3, §2, §4.1.
  • M. Sandler, A. Howard, M. Zhu, A. Zhmoginov, and L. Chen (2018) MobileNetV2: inverted residuals and linear bottlenecks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4510–4520. Cited by: §5.
  • F. Schroff, D. Kalenichenko, and J. Philbin (2015) FaceNet: a unified embedding for face recognition and clustering. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 815–823. Cited by: §1.
  • A. Shamsian, A. Navon, E. Fetaya, and G. Chechik (2021) Personalized federated learning using hypernetworks. In ICML, Cited by: §4.4.
  • J. Snell, K. Swersky, and R. Zemel (2017) Prototypical networks for few-shot learning. In 31st, pp. 4077–4087. Cited by: §A.5.2, §A.5.2, 3rd item, §4.3, Table 2.
  • M. Tan and Q. Le (2019) EfficientNet: rethinking model scaling for convolutional neural networks. In 36th International Conference on Machine Learning, pp. 6105–6114. Cited by: §5.
  • M. Tan and Q. Le (2021) EfficientNetV2: smaller models and faster training. In International Conference on Machine Learning, pp. 10096–10106. Cited by: §5.
  • Y. Tian, Y. Wang, D. Krishnan, J. B. Tenenbaum, and P. Isola (2020) Rethinking few-shot image classification: a good embedding is all you need?. arXiv preprint arXiv:2003.11539. Cited by: §1.
  • E. Triantafillou, H. Larochelle, R. Zemel, and V. Dumoulin (2021) Learning a universal template for few-shot dataset generalization. In International Conference on Machine Learning, pp. 10424–10433. Cited by: §3.
  • B. S. Veeling, J. Linmans, J. Winkens, T. Cohen, and M. Welling (2018) Rotation equivariant cnns for digital pathology. In International Conference on Medical image computing and computer-assisted intervention, pp. 210–218. Cited by: Table A.4, Table 1.
  • O. Vinyals, C. Blundell, T. Lillicrap, K. Kavukcuoglu, and D. Wierstra (2016) Matching networks for one shot learning. In 30th, pp. 3630–3638. Cited by: §2.
  • Y. Wu, M. Schuster, Z. Chen, Q. V. Le, M. Norouzi, W. Macherey, M. Krikun, Y. Cao, Q. Gao, K. Macherey, et al. (2016) Google's neural machine translation system: bridging the gap between human and machine translation. arXiv preprint arXiv:1609.08144. Cited by: §1.
  • J. Xiao, J. Hays, K. A. Ehinger, A. Oliva, and A. Torralba (2010) Sun database: large-scale scene recognition from abbey to zoo. In 2010 IEEE computer society conference on computer vision and pattern recognition, pp. 3485–3492. Cited by: Table A.4, Table 1.
  • W. Xiong, L. Wu, F. Alleva, J. Droppo, X. Huang, and A. Stolcke (2018) The microsoft 2017 conversational speech recognition system. In 2018 IEEE international conference on acoustics, speech and signal processing (ICASSP), pp. 5934–5938. Cited by: §1.
  • J. Yosinski, J. Clune, Y. Bengio, and H. Lipson (2014) How transferable are features in deep neural networks?. In 28th, pp. 3320–3328. Cited by: §A.5.2, §A.5.2, §1, §4.3, Table 2.
  • X. Zhai, J. Puigcerver, A. Kolesnikov, P. Ruyssen, C. Riquelme, M. Lucic, J. Djolonga, A. S. Pinto, M. Neumann, A. Dosovitskiy, et al. (2019) A large-scale study of representation learning with the visual task adaptation benchmark. arXiv preprint arXiv:1910.04867. Cited by: §4.2, §4.

Appendix A Appendix

a.1 FiLM Layer Placement

Fig. A.1 (left) illustrates a FiLM layer operating on a convolutional layer, and Fig. A.1 (right) illustrates how a FiLM layer can be added to a ResNetV2 network block. FiLM layers can be similarly added to EfficientNet based backbones, amongst others.

Figure A.1: (Left) A FiLM layer operating on the convolutional feature maps of a given layer and channel. (Right) How a FiLM layer is placed within a basic residual network block [He et al., 2016a]. GN is a Group Normalization layer, ReLU is a Rectified Linear Unit, and 1×1 and 3×3 denote 2D convolutional layers with the stated kernel size.

a.2 Model Parameters

Table A.1 depicts the shared and updateable parameters for BiT and the three variants of FiT as a function of the backbone size $|\theta|$, the FiLM parameter count $|\phi|$, the number of classes $C$, and the feature dimension $d$. The rightmost column provides the example updateable parameter count for all models in the case of a BiT-M-R50x1 backbone with $|\theta| \approx 23.5$M, $|\phi| = 11{,}648$, $C = 10$, and $d = 2048$.

For FiT-QDA and FiT-LDA, the means and covariances contribute to the updateable parameter count. We use a mean for every class, which contributes $Cd$ updateable parameters. A covariance matrix has $d^2$ values; however, it can be represented in Cholesky factorized form, which results in a lower triangular matrix that can be represented with $d(d+1)/2$ values, with the rest being zeros.
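As a worked example with the ResNet50 feature dimension $d = 2048$ used throughout the paper:

```latex
\underbrace{d^2}_{\text{full matrix}} = 2048^2 = 4{,}194{,}304
\quad\text{vs.}\quad
\underbrace{\tfrac{d(d+1)}{2}}_{\text{Cholesky factor}} = \frac{2048 \cdot 2049}{2} = 2{,}098{,}176
\;\;\text{values per stored covariance.}
```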

In the case of FiT-LDA, where the covariance matrix $\Sigma$ is shared across all classes, a more compact representation is possible, resulting in considerable savings in updateable parameters. Writing $z = b_{\theta,\phi}(x^*)$ for the embedded test point, the class-conditional log density in Eq. 1 expands as

$$\log\!\big(\pi_c\,\mathcal{N}(z; \mu_c, \Sigma)\big) = \underbrace{-\tfrac{1}{2} z^{\top}\Sigma^{-1} z - \tfrac{d}{2}\log(2\pi) - \tfrac{1}{2}\log|\Sigma|}_{\text{independent of } c} \;+\; \underbrace{(\Sigma^{-1}\mu_c)^{\top}}_{w_c^{\top}} z \;\underbrace{-\,\tfrac{1}{2}\mu_c^{\top}\Sigma^{-1}\mu_c + \log \pi_c}_{b_c}. \qquad (A.1)$$

From Eq. A.1, it follows that to compute the probability of classifying a test point $x^*$, we need to store $w_c = \Sigma^{-1}\mu_c$, which has dimensionality $d$, and $b_c$, which has dimensionality 1, for each class $c$ (the terms that do not depend on $c$ cancel in the normalization of Eq. 1), resulting in only $C(d+1)$ parameters required for the FiT-LDA head. Since the covariance matrix is not shared in the case of FiT-QDA, no additional savings are possible in that case.
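A small sketch of the resulting compact LDA head, storing only the per-class $w_c$ and $b_c$ of Eq. A.1 (function names are ours):

```python
import numpy as np

def compact_lda_head(means, priors, shared_cov):
    """Precompute per-class weights w_c = Sigma^{-1} mu_c and biases b_c (Eq. A.1)."""
    cov_inv = np.linalg.inv(shared_cov)
    weights = means @ cov_inv                                                 # (C, d): rows are w_c
    biases = -0.5 * np.einsum('cd,cd->c', weights, means) + np.log(priors)    # (C,)
    return weights, biases

def lda_logits(query_features, weights, biases):
    """Class logits for Eq. 1; the class-independent terms cancel in the softmax."""
    return query_features @ weights.T + biases                                # (n, C)
```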

| Method | Shared parameters | Updateable parameters | Example |
|---|---|---|---|
| BiT | 0 | backbone + linear head | 23,520,832 |
| FiT-QDA | backbone | FiLM parameters + class means + per-class covariances (Cholesky factors) + covariance weights | 21,013,891 |
| FiT-LDA | backbone | FiLM parameters + compact LDA head ($w_c$, $b_c$ per class) + covariance weights | 32,140 |
| FiT-ProtoNets | backbone | FiLM parameters + class means | 32,128 |

Table A.1: Shared and updateable parameters for the transfer learning methods considered. The Example column contains the updateable parameter count for all methods using a BiT-M-R50x1 backbone with $|\theta| \approx 23.5$M, $|\phi| = 11{,}648$, $C = 10$, and $d = 2048$.

a.3 Additional Results

This section contains additional results that would not fit into the main paper, including tabular versions of figures.

a.3.1 Additional Few-shot Results

Table A.2 shows the tabular version of Fig. 1. In addition, Table A.2 includes results for an additional variant of BiT (BiT-FiLM), and two additional variants of FiT (FiT-QDA and FiT-ProtoNets). Refer to Section 4.1 for analysis.

| Dataset | Classes | Shot | BiT Acc. | BiT RMUS | BiT-FiLM Acc. | BiT-FiLM RMUS | FiT-QDA Acc. | FiT-QDA RMUS | FiT-LDA Acc. | FiT-LDA RMUS | FiT-ProtoNets Acc. | FiT-ProtoNets RMUS |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| CIFAR10 [Krizhevsky et al., 2009] | 10 | 1 | 52.7±3.9 | 1.0 | 47.0±14.2 | 0.0014 | 54.0±5.7 | 0.893 | 54.0±5.7 | 0.0014 | 52.9±5.0 | 0.0014 |
| | 10 | 2 | 69.9±2.6 | 1.0 | 63.4±3.6 | 0.0014 | 73.0±8.8 | 0.893 | 74.2±8.8 | 0.0014 | 68.9±9.4 | 0.0014 |
| | 10 | 5 | 83.5±3.0 | 1.0 | 82.8±1.9 | 0.0014 | 85.4±3.9 | 0.893 | 86.4±4.2 | 0.0014 | 81.8±5.1 | 0.0014 |
| | 10 | 10 | 88.1±2.0 | 1.0 | 46.1±39.7 | 0.0014 | 89.5±1.3 | 0.893 | 90.6±1.0 | 0.0014 | 87.3±2.3 | 0.0014 |
| | 10 | All | 95.9±0.2 | 1.0 | 96.0±0.2 | 0.0014 | 96.4±0.0 | 0.893 | 96.3±0.1 | 0.0014 | 96.0±0.1 | 0.0014 |
| CIFAR100 [Krizhevsky et al., 2009] | 100 | 1 | 36.4±1.3 | 1.0 | 30.2±2.7 | 0.0091 | 33.8±0.8 | 8.860 | 33.8±0.8 | 0.0091 | 30.7±0.7 | 0.0091 |
| | 100 | 2 | 48.8±0.2 | 1.0 | 43.0±2.9 | 0.0091 | 55.1±1.1 | 8.860 | 54.5±1.0 | 0.0091 | 50.6±1.8 | 0.0091 |
| | 100 | 5 | 65.3±1.5 | 1.0 | 39.6±0.6 | 0.0091 | 69.0±0.7 | 8.860 | 69.5±1.3 | 0.0091 | 67.9±0.7 | 0.0091 |
| | 100 | 10 | 72.3±0.9 | 1.0 | 50.1±0.2 | 0.0091 | 75.6±0.6 | 8.860 | 75.3±0.6 | 0.0091 | 74.7±0.3 | 0.0091 |
| | 100 | All | 86.1±0.1 | 1.0 | 82.6±0.2 | 0.0091 | 82.1±0.1 | 8.860 | 82.6±0.2 | 0.0091 | 80.3±0.4 | 0.0091 |
| Flowers102 [Nilsback and Zisserman, 2008] | 102 | 1 | 79.8±0.7 | 1.0 | 79.3±2.5 | 0.0093 | 89.6±0.9 | 9.036 | 89.6±0.9 | 0.0093 | 85.1±0.9 | 0.0093 |
| | 102 | 2 | 89.6±1.4 | 1.0 | 91.7±0.8 | 0.0093 | 95.6±0.5 | 9.036 | 95.6±0.7 | 0.0093 | 93.0±0.9 | 0.0093 |
| | 102 | 5 | 96.8±0.7 | 1.0 | 96.3±0.2 | 0.0093 | 98.3±0.3 | 9.036 | 98.4±0.3 | 0.0093 | 98.0±0.3 | 0.0093 |
| | 102 | 10 | 97.3±0.7 | 1.0 | 96.9±0.7 | 0.0093 | 99.0±0.1 | 9.036 | 98.8±0.1 | 0.0093 | 98.6±0.1 | 0.0093 |
| | 102 | All | 96.6±0.4 | 1.0 | 96.9±0.4 | 0.0093 | 99.0±0.1 | 9.036 | 98.9±0.1 | 0.0093 | 98.6±0.2 | 0.0093 |
| Pets [Parkhi et al., 2012] | 37 | 1 | 38.3±2.3 | 1.0 | 36.8±5.5 | 0.0037 | 50.2±2.8 | 3.297 | 50.1±3.1 | 0.0037 | 46.6±2.0 | 0.0037 |
| | 37 | 2 | 61.2±2.6 | 1.0 | 60.1±3.0 | 0.0037 | 74.8±1.4 | 3.297 | 76.8±1.8 | 0.0037 | 72.8±2.7 | 0.0037 |
| | 37 | 5 | 76.7±2.0 | 1.0 | 76.8±1.5 | 0.0037 | 80.0±0.2 | 3.297 | 82.5±0.4 | 0.0037 | 78.4±0.7 | 0.0037 |
| | 37 | 10 | 78.6±5.3 | 1.0 | 79.2±4.4 | 0.0037 | 86.2±1.0 | 3.297 | 86.9±1.5 | 0.0037 | 85.8±0.4 | 0.0037 |
| | 37 | All | 90.1±0.6 | 1.0 | 90.1±0.2 | 0.0037 | 91.7±0.3 | 3.297 | 91.9±0.3 | 0.0037 | 91.2±0.1 | 0.0037 |

Table A.2: Classification accuracy and Relative Model Update Size (RMUS) for all three variants of FiT and two variants of BiT on standard downstream datasets as a function of shots per class. The backbone is BiT-M-R50x1. Accuracy figures are percentages and the ± sign indicates the 95% confidence interval over 3 runs. Bold type indicates the highest scores (within the confidence interval).

a.3.2 Few-shot Task Ablations

Table A.3 shows the few-shot results for all three variants of FiT with different ablations on how the downstream dataset $\mathcal{D}$ is allocated during training. No Split indicates that $\mathcal{D}$ is not split into two disjoint partitions $\mathcal{D}_S$ and $\mathcal{D}_Q$; however, $\tau_S$ and $\tau_Q$ are sampled to form episodic training tasks as detailed in Algorithm A.2. Split indicates that $\mathcal{D}$ is split into two disjoint partitions as detailed in Algorithm A.1 and then sampled into tasks as described in Algorithm A.2. Use All indicates that $\mathcal{D}_S = \mathcal{D}_Q = \mathcal{D}$ (i.e. $\mathcal{D}$ is not split) and that $\tau_S$ and $\tau_Q$ are not sampled, so that $\tau_S = \tau_Q = \mathcal{D}$ for all tasks.

Table A.3 shows that Use All is consistently the worst option. In general, in the few-shot case, Split either outperforms No Split (CIFAR10, Pets) or achieves the same level of performance (CIFAR100, Flowers102). As a result, we use the Split option when reporting the few-shot results.

Dataset | Shot | QDA No Split | QDA Split | QDA Use All | LDA No Split | LDA Split | LDA Use All | ProtoNets No Split | ProtoNets Split | ProtoNets Use All
CIFAR10 [Krizhevsky et al., 2009] | 1 | 42.5±2.2 | 54.0±5.7 | 37.2±3.9 | 36.7±4.3 | 54.0±5.7 | 28.5±2.3 | 53.0±5.1 | 52.9±5.0 | 53.0±5.1
 | 2 | 62.8±4.8 | 73.0±8.8 | 68.2±5.4 | 60.4±1.6 | 74.2±8.8 | 40.8±5.2 | 65.2±4.5 | 68.9±9.4 | 65.2±4.5
 | 5 | 79.7±0.9 | 85.4±3.9 | 79.6±1.0 | 86.6±3.5 | 86.4±4.2 | 69.4±3.5 | 76.1±4.2 | 81.8±5.1 | 75.6±9.3
 | 10 | 84.2±0.1 | 89.5±1.3 | 84.0±0.3 | 92.3±0.3 | 90.6±1.0 | 87.3±0.9 | 84.5±3.8 | 87.3±2.3 | 81.9±4.5
 | All | 96.6±0.1 | 96.4±0.0 | 96.2±0.0 | 96.6±0.0 | 96.3±0.1 | 96.6±0.0 | 96.1±0.1 | 96.0±0.1 | 95.4±0.0
CIFAR100 [Krizhevsky et al., 2009] | 1 | 33.0±3.9 | 33.8±0.8 | 33.0±0.8 | 32.8±2.7 | 33.8±0.8 | 33.8±0.8 | 30.7±0.7 | 30.7±0.7 | 30.7±0.7
 | 2 | 54.5±2.0 | 55.1±1.1 | 45.6±2.0 | 55.6±1.2 | 54.5±1.0 | 45.1±3.1 | 52.3±2.0 | 50.6±1.8 | 40.2±6.1
 | 5 | 69.7±1.0 | 69.0±0.7 | 57.9±0.4 | 69.8±1.0 | 69.5±1.3 | 57.6±1.8 | 68.7±1.5 | 67.9±0.7 | 56.6±4.2
 | 10 | 75.5±0.3 | 75.6±0.6 | 67.6±7.9 | 75.6±0.2 | 75.3±0.6 | 67.0±3.2 | 74.8±0.2 | 74.7±0.3 | 65.4±2.8
 | All | 82.4±0.1 | 82.1±0.1 | 77.2±0.1 | 82.6±0.2 | 82.6±0.2 | 81.1±0.1 | 80.6±0.2 | 80.3±0.4 | 78.2±0.1
Flowers102 [Nilsback and Zisserman, 2008] | 1 | 86.1±0.5 | 89.6±0.9 | 89.1±0.8 | 82.1±0.8 | 89.6±0.9 | 89.6±0.9 | 85.1±0.9 | 85.1±0.9 | 85.1±0.9
 | 2 | 95.2±0.6 | 95.6±0.5 | 94.4±0.6 | 95.6±0.5 | 95.6±0.7 | 94.9±0.5 | 93.9±1.0 | 93.0±0.9 | 91.9±1.2
 | 5 | 98.4±0.2 | 98.3±0.3 | 98.2±0.4 | 98.5±0.2 | 98.4±0.3 | 97.4±0.4 | 98.1±0.4 | 98.0±0.3 | 96.6±0.6
 | 10 | 99.0±0.0 | 99.0±0.1 | 98.9±0.1 | 98.9±0.1 | 98.8±0.1 | 98.5±0.0 | 98.6±0.0 | 98.6±0.1 | 96.7±0.0
 | All | 99.1±0.1 | 99.0±0.1 | 98.8±0.0 | 99.0±0.1 | 98.9±0.1 | 98.4±0.0 | 98.8±0.1 | 98.6±0.2 | 96.7±0.0
Pets [Parkhi et al., 2012] | 1 | 29.4±3.1 | 50.2±2.8 | 50.2±2.8 | 17.2±6.3 | 50.1±3.1 | 50.0±3.2 | 46.6±2.0 | 46.6±2.0 | 46.6±2.0
 | 2 | 53.0±2.1 | 74.8±1.4 | 64.2±0.7 | 49.4±5.7 | 76.8±1.8 | 53.4±2.0 | 60.1±0.4 | 72.8±2.7 | 60.1±0.4
 | 5 | 81.2±2.3 | 80.0±0.2 | 73.6±1.1 | 82.1±2.2 | 82.5±0.4 | 71.0±3.1 | 83.0±2.4 | 78.4±0.7 | 67.4±6.2
 | 10 | 87.3±1.0 | 86.2±1.0 | 80.0±6.2 | 87.1±0.8 | 86.9±1.5 | 81.6±1.6 | 87.1±1.2 | 85.8±0.4 | 79.2±2.0
 | All | 92.1±0.2 | 91.7±0.3 | 77.8±0.0 | 91.8±0.2 | 91.9±0.3 | 88.3±0.1 | 91.8±0.2 | 91.2±0.1 | 82.4±0.4

Table A.3: Classification accuracy for all three variants of FiT as a function of shots per class and how the downstream dataset is utilized during training on standard datasets. The backbone is BiT-M-R50x1. Accuracy figures are percentages and the ± sign indicates the 95% confidence interval over 3 runs.

A.3.3 VTAB-1k Task Ablations

Table A.4 shows the VTAB-1k results for all three variants of FiT under different ablations of how the downstream dataset is allocated during training. Refer to Section A.3.2 for the meanings of No Split, Split, and Use All. With some minor exceptions, the Use All case performs the worst. The performance of the No Split and Split options is very close, with No Split being slightly better when averaged over all of the datasets. As a result, we use the No Split option when reporting the VTAB-1k results.

Dataset | QDA No Split | QDA Split | QDA Use All | LDA No Split | LDA Split | LDA Use All | ProtoNets No Split | ProtoNets Split | ProtoNets Use All
Caltech101 [Fei-Fei et al., 2006] | 90.3±0.8 | 90.0±0.7 | 85.5±0.0 | 90.4±0.8 | 90.7±0.7 | 87.3±0.0 | 89.6±0.2 | 89.7±0.3 | 82.3±0.0
CIFAR100 [Krizhevsky et al., 2009] | 74.1±0.1 | 74.8±0.3 | 63.4±0.0 | 74.2±0.5 | 74.1±0.4 | 69.0±0.0 | 73.9±0.3 | 73.7±0.6 | 65.5±0.2
Flowers102 [Nilsback and Zisserman, 2008] | 99.1±0.1 | 99.0±0.1 | 98.8±0.0 | 99.0±0.1 | 98.9±0.1 | 98.5±0.0 | 98.6±0.0 | 98.6±0.1 | 96.7±0.0
Pets [Parkhi et al., 2012] | 91.0±0.3 | 90.4±0.4 | 75.6±0.0 | 90.5±0.0 | 90.5±0.5 | 87.2±0.1 | 90.8±0.2 | 90.4±0.2 | 85.6±0.0
Sun397 [Xiao et al., 2010] | 51.1±0.7 | 52.1±0.1 | 49.3±0.7 | 51.6±0.5 | 50.8±0.7 | 42.7±3.4 | 51.5±1.4 | 50.4±0.8 | 42.4±4.2
SVHN [Netzer et al., 2011] | 75.1±1.3 | 73.3±1.4 | 26.2±0.1 | 74.2±0.9 | 71.5±0.2 | 67.4±0.0 | 50.1±2.2 | 47.4±1.4 | 35.1±0.1
DTD [Cimpoi et al., 2014] | 70.9±0.1 | 70.2±0.3 | 72.8±0.0 | 70.9±0.1 | 70.8±0.3 | 66.5±0.0 | 68.2±1.1 | 68.4±1.0 | 61.3±0.2
EuroSAT [Helber et al., 2019] | 95.6±0.1 | 94.7±0.6 | 93.5±0.0 | 95.1±0.1 | 94.3±0.6 | 94.3±0.0 | 93.8±0.1 | 92.7±0.2 | 89.4±0.1
Resics45 [Cheng et al., 2017a] | 82.6±0.1 | 82.0±0.3 | 77.5±0.0 | 82.5±0.2 | 80.8±0.2 | 78.3±0.1 | 77.0±0.0 | 76.4±0.8 | 71.9±0.1
Patch Camelyon [Veeling et al., 2018] | 80.7±1.2 | 81.5±1.0 | 65.7±0.0 | 82.5±0.7 | 80.5±0.8 | 82.4±0.5 | 79.9±0.2 | 78.5±2.1 | 69.0±0.1
Retinopathy [Kaggle and EyePacs, 2015] | 70.4±0.1 | 67.5±1.0 | 25.5±0.0 | 66.2±0.5 | 63.4±0.3 | 25.0±0.1 | 57.9±0.3 | 58.4±0.6 | 17.0±0.0
CLEVR-count [Johnson et al., 2017] | 87.1±0.3 | 84.9±1.1 | 40.6±0.1 | 85.6±0.9 | 84.3±0.5 | 82.0±0.1 | 88.7±0.3 | 85.4±0.7 | 87.4±0.1
CLEVR-dist [Johnson et al., 2017] | 58.1±0.8 | 58.4±0.6 | 39.1±0.1 | 56.1±0.8 | 55.7±1.7 | 57.7±0.4 | 58.3±0.6 | 55.0±1.2 | 33.7±0.2
dSprites-loc [Matthey et al., 2017] | 77.1±2.0 | 75.1±1.2 | 13.5±0.3 | 74.8±1.4 | 71.1±1.1 | 62.4±1.1 | 68.6±2.4 | 66.8±0.8 | 74.8±1.4
dSprites-ori [Matthey et al., 2017] | 56.7±0.3 | 55.6±0.7 | 39.6±1.8 | 51.3±0.7 | 48.0±0.1 | 53.8±0.0 | 34.2±0.8 | 32.4±1.2 | 36.7±0.3
SmallNORB-azi [LeCun et al., 2004] | 18.9±0.6 | 19.8±0.6 | 14.0±0.1 | 16.2±0.1 | 17.1±1.3 | 15.8±0.3 | 13.5±0.1 | 13.0±0.6 | 13.1±0.0
SmallNORB-elev [LeCun et al., 2004] | 40.4±0.2 | 40.3±1.4 | 28.2±0.1 | 37.0±0.6 | 38.5±1.6 | 36.1±0.3 | 35.0±0.6 | 34.0±0.9 | 26.5±0.1
DMLab [Beattie et al., 2016] | 43.8±0.3 | 41.2±1.3 | 33.7±0.1 | 41.6±0.6 | 39.5±1.3 | 38.9±0.4 | 39.3±0.3 | 38.6±0.2 | 28.2±0.1
KITTI-dist [Geiger et al., 2013] | 77.5±0.7 | 77.2±2.2 | 73.6±0.1 | 77.7±0.8 | 77.5±1.3 | 73.1±0.1 | 75.3±0.2 | 74.3±2.9 | 69.0±0.9
All | 70.6 | 69.9 | 53.5 | 69.3 | 68.3 | 64.1 | 65.5 | 64.4 | 57.1
Natural | 78.8 | 78.5 | 67.4 | 78.7 | 78.2 | 74.1 | 74.7 | 74.1 | 67.0
Specialized | 82.3 | 81.4 | 65.6 | 81.5 | 79.7 | 70.0 | 77.1 | 76.5 | 61.8
Structured | 57.5 | 56.6 | 35.3 | 55.0 | 54.0 | 52.5 | 51.6 | 49.9 | 46.2

Table A.4: Classification accuracy for all three variants of FiT on the VTAB-1k benchmark as a function of how the downstream dataset is utilized during training. The backbone is BiT-M-R50x1. Accuracy figures are percentages and the ± sign indicates the 95% confidence interval over 3 runs. Bold type indicates the highest scores.

A.3.4 Additional Few-shot Federated Learning Results

Clients | Shots | Global Lower Bound | Global FL | Global Upper Bound | Personalized Lower Bound | Personalized FL | Personalized Upper Bound
10 | 2 | 42.6±1.9 | 49.3±1.3 | 52.5±0.9 | 73.2±1.4 | 80.1±0.8 | 81.7±1.2
10 | 5 | 55.6±1.6 | 64.1±1.6 | 68.7±0.7 | 82.6±0.4 | 88.8±0.2 | 90.6±0.4
10 | 10 | 60.7±1.2 | 71.5±0.8 | 74.7±0.3 | 86.6±0.2 | 91.8±0.5 | 92.6±0.6
50 | 2 | 59.9±0.6 | 68.5±0.3 | 73.2±0.2 | 75.4±0.9 | 84.6±0.5 | 93.4±0.2
50 | 5 | 63.0±0.8 | 73.8±0.7 | 78.0±0.1 | 83.5±0.5 | 91.0±0.4 | 94.9±0.2
50 | 10 | 65.8±1.1 | 77.1±0.3 | 79.2±0.3 | 87.4±0.6 | 93.4±0.3 | 95.2±0.1
100 | 2 | 65.6±0.6 | 74.4±0.3 | 77.5±0.2 | 75.7±0.9 | 85.3±0.2 | 94.6±0.2
100 | 5 | 66.0±0.3 | 76.5±0.2 | 79.5±0.2 | 83.1±0.9 | 91.3±0.2 | 95.3±0.1
100 | 10 | 66.8±0.3 | 78.4±0.1 | 80.2±0.1 | 87.4±0.5 | 93.6±0.2 | 95.5±0.1
500 | 2 | 68.8±0.1 | 77.9±0.4 | 80.1±0.3 | 75.5±0.2 | 85.8±0.2 | 95.6±0.1
500 | 5 | 67.6±0.1 | 77.6±0.1 | 80.6±0.2 | 83.2±0.3 | 91.0±0.2 | 95.7±0.1
500 | 10 | 67.7±0.1 | 79.2±0.3 | 80.8±0.2 | 87.7±0.1 | 94.3±0.1 | 96.2±0.1

Table A.5: Few-shot federated learning results on CIFAR100 for different numbers of clients and shots per client. Accuracy figures are percentages and the ± sign indicates the confidence interval across runs. Global refers to the global setting, while Personalized refers to the personalized scenario. FL indicates federated learning training.

Table A.5 is the tabular version of Fig. 2. For the federated learning results, it reports only the final accuracy reached after the full course of communication rounds. Refer to Section 4.4 for analysis.

To test distributed training of a FiT model on a more extreme, non-natural image dataset, we also include results for federated training of FiT on the Quickdraw dataset. As there is no pre-defined train/test split for Quickdraw, we randomly hold out a fixed number of samples from each class for testing. We train all federated learning models for a set number of communication rounds, with a random subset of clients participating in each round and a fixed number of update steps per client. Since Quickdraw is a more difficult dataset than CIFAR100, it requires more communication rounds for training. Each client holds a fixed number of classes, sampled randomly at the start of training. In our experiments, we omit the 10-client case, as the overall amount of data in the system is not sufficient to train even a robust global upper-bound baseline model.
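The random class-to-client assignment described above can be sketched as follows; the number of classes per client and the other constants are illustrative placeholders rather than the values used in the experiments.

```python
import random

# Hedged sketch of the non-IID client setup: each client is assigned a random
# subset of classes at the start of training and only sees examples from those
# classes. All constants here are placeholders for illustration.

def assign_client_classes(num_clients: int, num_classes: int,
                          classes_per_client: int, seed: int = 0):
    rng = random.Random(seed)
    all_classes = list(range(num_classes))
    return {
        client_id: rng.sample(all_classes, classes_per_client)
        for client_id in range(num_clients)
    }

# Example: 50 clients drawing from the 345 Quickdraw categories.
client_classes = assign_client_classes(num_clients=50, num_classes=345,
                                        classes_per_client=10)
```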

Fig. A.2 shows global and personalized classification accuracy as a function of communication cost for different numbers of clients and shots per client on Quickdraw, while Table A.6 is the tabular version of this figure.

Figure A.2: Global and personalized classification accuracy as a function of communication cost over training rounds for different numbers of clients and shots per client on Quickdraw. Classification accuracy is on the vertical axis and is averaged over runs with different data-sampling seeds. The color of each line indicates the number of shots per class. The solid lines show the federated learning model, while the dashed and dotted lines indicate the upper- and lower-bound baselines, respectively.

Clients | Shots | Global Lower Bound | Global FL | Global Upper Bound | Personalized Lower Bound | Personalized FL | Personalized Upper Bound
50 | 2 | 26.3±0.4 | 29.0±0.5 | 32.9±0.6 | 36.9±0.2 | 43.4±0.6 | 64.2±0.8
50 | 5 | 30.6±0.3 | 40.7±0.6 | 43.9±0.5 | 48.8±0.4 | 63.6±0.3 | 74.3±0.4
50 | 10 | 34.3±0.7 | 44.9±0.1 | 47.2±0.2 | 56.5±0.8 | 71.1±0.1 | 76.6±0.2
100 | 2 | 28.5±0.1 | 32.4±0.8 | 43.0±1.0 | 35.9±0.3 | 43.9±1.0 | 73.1±0.4
100 | 5 | 32.8±0.4 | 44.1±0.4 | 47.4±0.1 | 48.6±0.9 | 64.9±0.8 | 76.4±0.3
100 | 10 | 35.3±0.1 | 46.8±0.1 | 48.7±0.2 | 55.5±0.4 | 71.5±0.5 | 77.2±0.3
500 | 2 | 32.0±0.3 | 40.3±2.4 | 49.1±0.3 | 36.2±0.1 | 49.2±2.4 | 77.6±0.3
500 | 5 | 34.0±0.1 | 46.1±0.1 | 48.6±2.2 | 48.7±0.2 | 65.8±0.3 | 77.0±1.7
500 | 10 | 36.1±0.1 | 48.1±0.4 | 48.2±1.2 | 55.7±0.1 | 72.2±0.2 | 76.7±0.9

Table A.6: Few-shot federated learning results on Quickdraw for different numbers of clients and shots per client. Accuracy figures are percentages and the ± sign indicates the confidence interval across runs. Global refers to the global setting, while Personalized refers to the personalized scenario. FL indicates federated learning training.

A.4 FiT Training Algorithms

Algorithm A.1 details how the downstream dataset is split, and Algorithm A.2 details how episodic tasks are sampled, for use in the FiT training protocol.

Require: D: downstream dataset
Require: unique(D): function that returns a list of unique classes and a list of counts of each class
Require: select_by_class(D, c): function that extracts the samples of a specified class c from a dataset D
procedure SPLIT(D)
    D_train ← empty list                      ▷ Create an empty list to hold D_train
    D_test ← empty list                       ▷ Create an empty list to hold D_test
    classes, class_counts ← unique(D)
    for all c in classes do
        assert class_counts[c] ≥ 2            ▷ Require a minimum of 2 shots per class
        train_count ← number of examples of class c to assign to D_train
        D_c ← select_by_class(D, c)           ▷ Select the examples of class c from D
        D_train ← D_train + first train_count examples of D_c   ▷ Add train_count examples to D_train
        D_test ← D_test + remaining examples of D_c              ▷ Add the remaining examples to D_test
    end for
    return D_train, D_test
end procedure
Algorithm A.1 Splitting the downstream dataset
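A minimal Python sketch of this split is given below; the dataset is assumed to be a list of (example, label) pairs, and the per-class train fraction is an illustrative placeholder rather than necessarily the ratio used by FiT.

```python
from collections import defaultdict
import math

# Hedged sketch of Algorithm A.1. TRAIN_FRACTION is a placeholder value.
TRAIN_FRACTION = 0.7

def split(dataset):
    by_class = defaultdict(list)
    for example, label in dataset:
        by_class[label].append(example)

    d_train, d_test = [], []
    for label, examples in by_class.items():
        assert len(examples) >= 2, "Require a minimum of 2 shots per class."
        train_count = max(1, math.floor(TRAIN_FRACTION * len(examples)))
        train_count = min(train_count, len(examples) - 1)  # keep at least one test example
        d_train += [(x, label) for x in examples[:train_count]]
        d_test += [(x, label) for x in examples[train_count:]]
    return d_train, d_test
```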
Require: D_train: train portion of the downstream dataset
Require: D_test: test portion of the downstream dataset
Require: support_set_size: size of the support set
Require: unique(D): function that returns a list of unique classes and a list of counts of each class
Require: randint(a, b): function that returns a random integer between a and b
Require: choice(n, k): function that returns a random list of k integers from {1, …, n}
procedure SAMPLE_TASK(D_train, D_test, support_set_size)
    S ← empty list                            ▷ Create an empty list to hold the support set S
    Q ← empty list                            ▷ Create an empty list to hold the query set Q
    train_classes, train_class_counts ← unique(D_train)
    test_classes, test_class_counts ← unique(D_test)
    min_way ← minimum classification way
    max_way ← maximum classification way
    way ← randint(min_way, max_way)           ▷ Classification way to use for this task
    selected_classes ← choice(len(train_classes), way)   ▷ List of classes to use in this task
    balanced_shots ← max(round(support_set_size / len(selected_classes)), 1)
    max_test_shots ← maximum number of query examples per class
    for all c in selected_classes do
        class_shots ← train_class_counts(c)
        shots_to_use ← min(class_shots, balanced_shots)
        selected_shots ← choice(class_shots, shots_to_use)    ▷ Support shot list
        S ← S + the selected_shots examples of class c from D_train    ▷ Add the examples to S
        class_shots ← test_class_counts(c)
        shots_to_use ← min(class_shots, max_test_shots)
        selected_shots ← choice(class_shots, shots_to_use)    ▷ Query shot list
        Q ← Q + the selected_shots examples of class c from D_test     ▷ Add the examples to Q
    end for
    return S, Q
end procedure
Algorithm A.2 Sampling a task
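A corresponding Python sketch of the task sampler follows; the way range and the maximum number of query shots per class are illustrative placeholders, not the paper's settings.

```python
import random
from collections import defaultdict

# Hedged sketch of Algorithm A.2. MIN_WAY, MAX_WAY and MAX_TEST_SHOTS are placeholders.
MIN_WAY, MAX_WAY, MAX_TEST_SHOTS = 5, 50, 10

def _group_by_class(dataset):
    by_class = defaultdict(list)
    for example, label in dataset:
        by_class[label].append(example)
    return by_class

def sample_task(d_train, d_test, support_set_size, rng=random):
    train_by_class = _group_by_class(d_train)
    test_by_class = _group_by_class(d_test)

    classes = list(train_by_class)
    max_way = min(MAX_WAY, len(classes))
    way = rng.randint(min(MIN_WAY, max_way), max_way)   # classification way for this task
    selected_classes = rng.sample(classes, way)
    balanced_shots = max(round(support_set_size / len(selected_classes)), 1)

    support, query = [], []
    for c in selected_classes:
        pool = train_by_class[c]
        shots = min(len(pool), balanced_shots)
        support += [(x, c) for x in rng.sample(pool, shots)]   # support shots
        pool = test_by_class.get(c, [])
        shots = min(len(pool), MAX_TEST_SHOTS)
        query += [(x, c) for x in rng.sample(pool, shots)]     # query shots
    return support, query
```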

A.5 Training and Evaluation Details

In this section, we provide implementation details for all of the experiments in Section 4.

A.5.1 Few-shot and VTAB-1k Transfer Learning Experiments

FiT

All of the FiT few-shot and VTAB-1k transfer learning experiments were carried out on a single NVIDIA A100 GPU with 80GB of memory. The Adam optimizer [Kingma and Ba, 2015] with a constant learning rate of 0.0035 for 400 iterations was used throughout. No data augmentation was used, and images were scaled to 384×384 pixels unless the image size was 32×32 pixels or smaller, in which case they were scaled to 224×224 pixels. These hyper-parameters were derived empirically from a small number of runs.
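The optimizer settings and image-scaling rule above amount to the following PyTorch-style sketch; the construction of the FiLM-adapted backbone and classifier head is elided, and `film_and_head_params` is a placeholder for the parameters FiT actually updates.

```python
import torch

# Sketch of the optimizer and image-scaling rule described above.
NUM_ITERATIONS = 400

def target_image_size(native_size: int) -> int:
    # Scale to 384x384 unless the native resolution is 32x32 or smaller,
    # in which case scale to 224x224 instead.
    return 224 if native_size <= 32 else 384

def make_optimizer(film_and_head_params):
    # Adam with a constant learning rate of 0.0035, run for NUM_ITERATIONS steps.
    return torch.optim.Adam(film_and_head_params, lr=0.0035)
```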

FiT-QDA, FiT-LDA, and FiT-ProtoNets take approximately 12, 10, and 9 hours, respectively, to fine-tune on all 19 VTAB datasets, and 5, 3, and 3 hours, respectively, to fine-tune all shots on the 4 low-shot datasets.

BiT

For the BiT few-shot experiments, we used the code supplied by the authors [Kolesnikov et al., 2020] with minor modifications to read additional datasets. The BiT few-shot experiments were run on a single NVIDIA V100 GPU with 16GB of memory.

For the BiT VTAB-1k experiments, we used the three fine-tuned models for each of the datasets that were provided by the authors [Kolesnikov et al., 2020]. We evaluated all of the models on the respective test splits for each dataset and averaged the results of the three models. The BiT-HyperRule [Kolesnikov et al., 2019] was respected in all runs. These experiments were executed on a single NVIDIA GeForce RTX 3090 with 24GB of memory.

A.5.2 Personalization on ORBIT Experiments

The personalization experiments were carried out on a single NVIDIA GeForce RTX 3090 with 24GB of memory. Training FiT-LDA personalization models for all of the ORBIT [Massiceti et al., 2021] test tasks takes on the order of hours. We derived all hyperparameters empirically from a small number of runs. We used the ORBIT codebase (https://github.com/microsoft/ORBIT-Dataset) in our experiments, only adding code for splitting the test user tasks and slightly modifying the main training loop to make it suitable for FiT training.

For the personalization experiments, all methods use an EfficientNet-B0 feature extractor, as it has previously shown superior performance on the ORBIT dataset [Bronskill et al., 2021], and all methods use the same input image size. FiT-LDA, FineTuner [Yosinski et al., 2014] and Simple CNAPs [Bateni et al., 2020] use a backbone pretrained on ImageNet [Deng et al., 2009], while for ProtoNets [Snell et al., 2017] the feature extractor weights were meta-trained on Meta-Dataset [Dumoulin et al., 2021].

The FineTuner [Yosinski et al., 2014] results are from [Bronskill et al., 2021]. Meta-trained weights for Simple CNAPs [Bateni et al., 2020] and ProtoNets [Snell et al., 2017] are also taken from [Bronskill et al., 2021]. Using these weights, we test these models on the ORBIT test set and report the results.

FiLM layers are added to the FiT-LDA feature extractor as described in Section 2, adding only a small number of parameters to the backbone.

We follow the task-sampling protocols described in [Massiceti et al., 2021] and train the FiT model for a fixed number of optimization steps using the Adam optimizer. The ORBIT test tasks have a slightly different structure from standard few-shot classification tasks, so in Algorithm A.3 we provide a modified version of the data splitting used for classifier head construction. In particular, each test user has a number of objects (classes) they want to recognize, with several videos recorded per object. Each video is split into clips, i.e. consecutive fixed-length segments of the video. A user test task comprises these clips, randomly sampled from different videos of the user's objects, together with the associated labels. Since clips sampled from the same video can be semantically similar, we split the test task so that clips from the same video can only appear in either the support or the query set, except when only one video of an object is available.

Require: T: test task data; T = {T_1, …, T_N}, where N is the number of classes in the test task and T_i is the data of class i; T_i = {V_1, …, V_{M_i}}, where M_i is the number of videos in class i and V_j is the set of clips from the j-th video of class i; V_j = {c_1, …, c_{K_j}}, where K_j is the number of clips in the j-th video of class i and c_k is the k-th clip from video V_j
Require: batch_size: size of the context (support) split
Require: choose(S, n): function that randomly samples n different integers from a set S
Require: select_by_index(D, I): function that extracts the samples with indices I from a dataset D
Require: diff(A, B): function that computes the set difference between sets A and B
Require: range(n): function that returns the set of values {1, …, n}
procedure SPLIT_ORBIT_TASK(T, batch_size)
    S ← empty list                            ▷ support (context) set
    Q ← empty list                            ▷ query set
    for i = 1 to N do
        ▷ Gather the videos and clips available for class i
        if M_i = 1 then
            ▷ Only one video is available: sample batch_size clip indices from it with choose() for S
            ▷ and assign the remaining clips (obtained via diff() and select_by_index()) to Q
        else
            ▷ Assign each video of class i wholly to either S or Q,
            ▷ so that clips from the same video never appear in both sets
        end if
    end for
    return S, Q
end procedure
Algorithm A.3 Splitting a test task for ORBIT personalization experiments
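A hedged Python sketch of this split is shown below. The task representation, the per-class support budget, and the rule for how many videos feed the support set are illustrative assumptions rather than the exact choices in Algorithm A.3.

```python
import random

# Hedged sketch of the ORBIT test-task split described above. A task is assumed
# to be a dict mapping class label -> list of videos, where each video is a list
# of clips. batch_size is the per-class support (context) budget.

def split_orbit_task(task, batch_size, rng=random):
    support, query = [], []
    for label, videos in task.items():
        if len(videos) == 1:
            # Only one video: fall back to splitting its clips between the sets.
            clips = list(videos[0])
            rng.shuffle(clips)
            support += [(c, label) for c in clips[:batch_size]]
            query += [(c, label) for c in clips[batch_size:]]
        else:
            # Several videos: each video's clips go wholly to support or query,
            # so semantically similar clips never straddle the two sets.
            videos = list(videos)
            rng.shuffle(videos)
            n_support_videos = max(1, len(videos) // 2)  # assumed split rule
            for video in videos[:n_support_videos]:
                support += [(c, label) for c in video]
            for video in videos[n_support_videos:]:
                query += [(c, label) for c in video]
    return support, query
```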

A.5.3 Federated Learning Experiments

For each local update, a new Adam optimizer is initialized. In each communication round, a random subset of clients is chosen to make model updates. All of the federated learning experiments were carried out on a single NVIDIA A100 GPU with 80GB of memory. In all experiments we use FiT with the BiT-M-R50x1 [Kolesnikov et al., 2019] backbone pretrained on the ImageNet-21K [Russakovsky et al., 2015] dataset and a ProtoNets head. We derive all hyperparameters empirically from a small number of runs.
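A single communication round of this setup can be sketched as follows, assuming a FedAvg-style averaging of the client-updated FiT parameters (FiLM layers plus head); the aggregation rule and the `client.compute_loss` API are illustrative assumptions, not the paper's exact protocol.

```python
import copy
import random
import torch

# Hedged sketch of one communication round: a random subset of clients receives
# the current FiT parameters, runs local Adam steps on its own data, and the
# server averages the returned parameters (an assumed FedAvg-style aggregation).
# global_params: dict mapping parameter name -> trainable torch.Tensor.

def run_round(global_params, clients, clients_per_round, local_steps, lr, rng=random):
    selected = rng.sample(clients, clients_per_round)
    client_params = []
    for client in selected:
        params = copy.deepcopy(global_params)                  # start from the global model
        optimizer = torch.optim.Adam(params.values(), lr=lr)   # fresh optimizer per update
        for _ in range(local_steps):
            loss = client.compute_loss(params)                 # hypothetical client API
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        client_params.append(params)
    with torch.no_grad():                                      # server-side aggregation
        return {name: torch.stack([p[name] for p in client_params]).mean(dim=0)
                for name in global_params}
```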

CIFAR100

We train all federated learning models, for each number of clients and shots per client, for a fixed number of communication rounds. The learning rate is set at the start of training and decayed by a fixed factor at regular intervals of communication rounds. The upper- and lower-bound baselines for both the global and personalized scenarios were trained for a fixed number of epochs using the Adam optimizer with a constant learning rate. Training a federated learning model takes on the order of minutes, with slightly more training time required for the models with a larger number of shots.

Quickdraw

We train all federated learning models, for each number of clients and shots per client, for a fixed number of communication rounds. A constant learning rate is used for training all federated learning models, except for one client/shot configuration, where the learning rate is decayed at regular intervals of communication rounds. Upper-bound baseline models, which require training a global model on all available data, were trained for a fixed number of steps using the Adam optimizer with a constant learning rate. Lower-bound baseline models, which require training a personalized model for each individual client, were likewise trained for a fixed number of steps using the Adam optimizer. As there are only a few samples per class per client, the personalized models are trained in a few-shot regime and overfit if trained for longer. Training the federated learning models takes on the order of hours, with slightly more training time required for the models with a larger number of shots.