We introduce a general approach, called test-time training, for improving the performance of predictive models when test and training data come from different distributions. Test-time training turns a single unlabeled test instance into a self-supervised learning problem, on which we update the model parameters before making a prediction on the test sample. We show that this simple idea leads to surprising improvements on diverse image classification benchmarks aimed at evaluating robustness to distribution shifts. Theoretical investigations on a convex model reveal helpful intuitions for when we can expect our approach to help.
Supervised learning remains notoriously weak at out-of-distribution generalization. Unless training data and test data are drawn from the same distribution, predictive models tend to fail easily, and even seemingly minor natural variations in distribution turn out to defeat state-of-the-art models. Adversarial robustness, transfer learning and domain adaptation are but a few existing paradigms that anticipate some form of difference between training and test distributions. In this work, we explore a new take on the problem of out-of-distribution generalization, without any mathematical structure or data available at training-time about the distributional differences.
We start from a simple observation. When presented with an unlabeled test instance $x$, this instance itself gives us a hint about the distribution from which it was drawn. Our approach, called test-time training, allows the model parameters $\theta$ to depend on the test case $x$, but not on its unknown label $y$. While such a sample-specific model is powerful in principle, seizing this information-theoretic advantage in practice raises new challenges that we begin to address in this work.
The route we take is to create a self-supervised learning problem based only on this single test case $x$, which we use to update $\theta$ at test-time before we make a prediction on $x$. Self-supervised learning uses an auxiliary task that automatically creates labels from unlabeled data. For the visual data we work with, the task rotates an image by a multiple of 90 degrees and assigns the angle as the label (gidaris2018unsupervised).
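To make the rotation task concrete, here is a minimal pure-Python sketch, assuming an image is stored as a nested list of pixel rows (a real pipeline would operate on tensors). The helper names `rotate90` and `rotation_task` are hypothetical and not taken from the released code; each unlabeled image yields four examples whose label is the number of quarter turns.

```python
from typing import List, Tuple

Image = List[List[int]]


def rotate90(image: Image, k: int) -> Image:
    """Rotate a 2-D image counter-clockwise by k * 90 degrees."""
    out = [list(row) for row in image]
    for _ in range(k % 4):
        # One counter-clockwise quarter turn: transpose, then reverse the row order.
        out = [list(col) for col in zip(*out)][::-1]
    return out


def rotation_task(image: Image) -> List[Tuple[Image, int]]:
    """Turn one unlabeled image into four (rotated image, label) pairs.

    Label k in {0, 1, 2, 3} means the image was rotated by k * 90 degrees,
    so the self-supervised task is a four-way classification problem.
    """
    return [(rotate90(image, k), k) for k in range(4)]
```

For the 2-by-2 input `[[1, 2], [3, 4]]`, label 1 is paired with `[[2, 4], [1, 3]]`, i.e. the image turned a quarter turn counter-clockwise.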
Our approach can also be easily modified to work outside the standard supervised learning setting. If several test instances arrive in a batch, we can use the entire batch for test-time training. If test instances arrive sequentially in an online fashion, we obtain further improvements by keeping the state of the parameters from one instance to the next. After all, prediction is rarely a single event. The online setting can be the natural mode of deployment in practice, and it shows the strongest improvements.
We experiment with test-time training for out-of-distribution generalization in the context of object recognition on several benchmarks. These include images with diverse types of corruption at various levels (hendrycks2019benchmarking), video frames of moving objects (shankar2019systematic), and a new test set of unknown distribution shifts collected by recht2018cifar. Our algorithm does not hurt on the original distribution, but makes substantial improvements under distribution shifts.
In all our experiments, we compare with a strong baseline (referred to as joint training) that uses both supervised and self-supervised learning at training-time, but keeps the model fixed at test-time. Very recent work shows that additional training-time self-supervision improves out-of-distribution generalization (hendrycks2019using). The joint training baseline we use corresponds to an improved implementation of their work. A comprehensive review of related work follows in Section 5.
In Section 4, we complement the empirical results with a theoretical investigation of when test-time training is expected to help on a convex model, and establish an intuitive sufficient condition: roughly speaking, the gradients of the main-task loss and of the self-supervised-task loss should be positively correlated.
Next we describe the algorithmic details of our method. To set up notation, consider a standard $K$-layer neural network with parameters $\theta_k$ for layer $k$. The stacked parameter vector $\theta = (\theta_1, \dots, \theta_K)$ specifies the entire model for a classification task with loss function $l_m(x, y; \theta)$ on the test instance $(x, y)$. We call this the main task, as indicated by the subscript of the loss function. We assume training data $(x_1, y_1), \dots, (x_n, y_n)$ drawn i.i.d. from a distribution $P$. Standard empirical risk minimization corresponds to solving the optimization problem:
$$\min_{\theta} \; \frac{1}{n} \sum_{i=1}^{n} l_m(x_i, y_i; \theta). \tag{1}$$
Our method requires a self-supervised auxiliary task with loss function $l_s(x)$. In this paper, we choose the rotation prediction task (gidaris2018unsupervised), which has been demonstrated to be simple and effective at feature learning for neural networks. The task simply rotates $x$ in the image plane by one of 0, 90, 180 and 270 degrees and has the model predict the angle of rotation as a four-way classification problem. Other self-supervised tasks discussed in Section 5 might also be used for our method.
The auxiliary task shares some of the model parameters $\theta_e = (\theta_1, \dots, \theta_\kappa)$ up to a certain $\kappa \in \{1, \dots, K\}$. We think of those $\kappa$ layers as a shared feature extractor. The auxiliary task uses its own task-specific parameters $\theta_s = (\theta'_{\kappa+1}, \dots, \theta'_K)$. We call the unshared parameters $\theta_s$ the self-supervised task branch, and $\theta_m = (\theta_{\kappa+1}, \dots, \theta_K)$ the main task branch. Pictorially, the joint architecture is a Y-shaped structure with a shared bottom and two branches. For our experiments, the self-supervised task branch has the exact same architecture as the main branch, except for the output dimensionality of the last layer, due to the different number of classes in the two tasks.
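Training is done in the fashion of multi-task learning (caruana1997multitask): the model is trained on both tasks on the same data drawn from $P$. Losses for both tasks are added together, and gradients are taken for the collection of all parameters. The joint training problem is therefore

$$\min_{\theta_e, \theta_m, \theta_s} \; \frac{1}{n} \sum_{i=1}^{n} \Big[\, l_m(x_i, y_i; \theta_m, \theta_e) + l_s(x_i; \theta_s, \theta_e) \,\Big]. \tag{2}$$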
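As a hedged illustration of the Y-shaped structure behind Equation 2, the sketch below evaluates the joint objective for the two-layer linear model used later in Section 4: a shared matrix `A` plays the role of $\theta_e$, and the vectors `v_m` and `v_s` play the roles of the two branches. Replacing the classification and rotation losses with square losses, and supplying a scalar self-supervised target `y_s` per example, are simplifying assumptions made only for this example.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def matvec(A: Matrix, x: Sequence[float]) -> List[float]:
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def joint_objective(A: Matrix, v_m: List[float], v_s: List[float],
                    data: List[Tuple[List[float], float, float]]) -> float:
    """Average of main-task plus self-supervised square losses (cf. Equation 2).

    Each example is (x, y_m, y_s). Both heads read the same shared features A @ x,
    so gradients of both losses flow into A during joint training.
    """
    total = 0.0
    for x, y_m, y_s in data:
        z = matvec(A, x)                            # shared feature extractor
        total += 0.5 * (y_m - dot(v_m, z)) ** 2     # main task branch
        total += 0.5 * (y_s - dot(v_s, z)) ** 2     # self-supervised branch
    return total / len(data)
```

With `A` the 2-by-2 identity, heads `[1, 0]` and `[0, 1]`, and the two examples `([1, 2], 1, 2)` and `([3, 4], 0, 0)`, the first example contributes 0 and the second 12.5, so the objective is 6.25.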
Now we describe the standard version of test-time training on a single test input $x$. Simply put, test-time training fine-tunes the shared feature extractor $\theta_e$ by minimizing the auxiliary task loss on $x$. This can be formulated as

$$\min_{\theta_e} \; l_s(x; \theta_s, \theta_e). \tag{3}$$
Denote by $\theta_e^*$ the (approximate) minimizer of Equation 3. The model then makes a prediction using the updated parameters $\theta(x) = (\theta_e^*, \theta_m)$. Empirically, the difference between minimizing Equation 3 over $\theta_e$ alone and over both $\theta_e$ and $\theta_s$ is negligible. Theoretically, a difference exists only when optimization is done with more than one step of gradient descent.
In the standard version of our method, the optimization problem in Equation 3 is always initialized with the parameters $\theta$ obtained by minimizing Equation 2 on data from $P$. After making a prediction on $x$, $\theta_e^*$ is discarded. Outside of the standard supervised learning setting, when the test instances arrive sequentially online, the online version of test-time training solves the same optimization problem as in Equation 3 to update the shared feature extractor $\theta_e$. However, on test input $x_t$, $\theta_e$ is instead initialized with $\theta(x_{t-1})$, the parameters updated on the previous test input $x_{t-1}$. This allows $\theta(x_t)$ to take advantage of the distributional information available in $x_1, \dots, x_{t-1}$ as well as in $x_t$.
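The two versions differ only in how $\theta_e$ is initialized. The sketch below assumes hypothetical callables `ssl_grad(theta_e, x)` (the gradient of $l_s$ with respect to the shared parameters) and `predict(theta_e, theta_m, x)`, and represents $\theta_e$ as a flat list of floats; the defaults follow the hyperparameters reported in the experiments section (learning rate 0.001, ten steps for the standard version, one step for the online version). It is an illustration of the control flow, not the released implementation.

```python
from typing import Any, Callable, Iterable, List, Tuple

Params = List[float]
GradFn = Callable[[Params, Any], Params]
PredictFn = Callable[[Params, Any, Any], Any]


def sgd_steps(theta: Params, x: Any, ssl_grad: GradFn, lr: float, steps: int) -> Params:
    """Plain SGD on the self-supervised loss only (no momentum, no weight decay)."""
    for _ in range(steps):
        g = ssl_grad(theta, x)
        theta = [t - lr * gi for t, gi in zip(theta, g)]
    return theta


def ttt_standard(theta_e: Params, theta_m: Any, x: Any, ssl_grad: GradFn,
                 predict: PredictFn, lr: float = 0.001, steps: int = 10) -> Any:
    """Standard version: always restart from the jointly trained theta_e."""
    adapted = sgd_steps(list(theta_e), x, ssl_grad, lr, steps)
    return predict(adapted, theta_m, x)   # `adapted` is discarded afterwards


def ttt_online(theta_e: Params, theta_m: Any, stream: Iterable[Any], ssl_grad: GradFn,
               predict: PredictFn, lr: float = 0.001,
               steps: int = 1) -> Tuple[List[Any], Params]:
    """Online version: theta(x_{t-1}) initializes the update on x_t."""
    theta = list(theta_e)
    predictions = []
    for x in stream:
        theta = sgd_steps(theta, x, ssl_grad, lr, steps)
        predictions.append(predict(theta, theta_m, x))
    return predictions, theta
```

With a toy quadratic self-supervised loss (gradient `theta - x`), the standard version gives the same answer for the same input every time, whereas the online version keeps moving the shared parameters as more inputs arrive.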
Test-time training naturally benefits from standard data augmentation techniques. On each test input $x$, we apply the exact same set of random transformations used for data augmentation during training to form a batch for test-time training.
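A minimal sketch of forming a test-time training batch from one image with the CIFAR-10 transformations mentioned in the experiments (random crop with padding, random horizontal flip). Images are nested lists here, and the padding of 4 pixels as well as the helper names are assumptions for illustration; a real implementation would reuse the training pipeline's own transforms.

```python
import random
from typing import List, Optional

Image = List[List[int]]


def pad(image: Image, p: int, fill: int = 0) -> Image:
    """Pad an image with `fill` by p pixels on every side."""
    width = len(image[0]) + 2 * p
    top = [[fill] * width for _ in range(p)]
    middle = [[fill] * p + list(row) + [fill] * p for row in image]
    bottom = [[fill] * width for _ in range(p)]
    return top + middle + bottom


def random_crop_flip(image: Image, p: int, rng: random.Random) -> Image:
    """Random crop (back to the original size) from the padded image, then maybe flip."""
    h, w = len(image), len(image[0])
    padded = pad(image, p)
    top, left = rng.randint(0, 2 * p), rng.randint(0, 2 * p)
    crop = [row[left:left + w] for row in padded[top:top + h]]
    if rng.random() < 0.5:
        crop = [row[::-1] for row in crop]   # horizontal flip
    return crop


def ttt_batch(image: Image, batch_size: int, p: int = 4,
              seed: Optional[int] = None) -> List[Image]:
    """A batch of independently augmented copies of a single test image."""
    rng = random.Random(seed)
    return [random_crop_flip(image, p, rng) for _ in range(batch_size)]
```

Every element of the batch has the original image size, so the batch can be fed to the self-supervised branch exactly like a training batch.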
We experiment with both versions of our method (standard and online) on three kinds of out-of-distribution benchmarks, presented here in order from visually low-level to high-level, which is roughly also the order from easy to hard. Our code is available at https://github.com/yueatsprograms/ttt_cifar_release and https://github.com/yueatsprograms/ttt_imagenet_release.
Our architecture and hyperparameters are consistent across all experiments. We use Residual Networks (ResNets) (he2016identity), which are constructed differently for CIFAR-10 [1] (26-layer) and ImageNet [2] (18-layer). ResNets on ImageNet have four groups, each containing convolutional layers with the same number of channels and size of feature maps; our splitting point is the end of the third group. ResNets on CIFAR-10 have three groups; our splitting point is the end of the second group. In addition, Batch Normalization (BN), a popular module in deep networks, is ineffective when training on small batches, for which the estimated batch statistics are less accurate (ioffe2015batch). This technicality hurts test-time training, since each batch only contains (augmented) copies of a single image. Therefore our networks instead use Group Normalization (GN) (wu2018group), which performs similarly to BN on large batches without hurting on small ones. Results with BN are shown in Appendix E for completeness.

[1] CIFAR-10 (krizhevsky2009learning) is a standard object recognition dataset with 10 classes of objects in natural scenes. The standard train / test split has 50,000 / 10,000 images, each of size 32-by-32 pixels.
[2] The ImageNet 2012 classification dataset (ILSVRC15) for object recognition has images from 1,000 classes, 1.2 million for training and 50,000 for validation. Following standard practice (he2016deep; he2016identity; huang2016deep), the validation set is used as the test set.
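To see why GN is unaffected by the batch size, note that it computes statistics within a single sample. The following is a minimal pure-Python sketch of the normalization step of GN for one sample (the learnable per-channel scale and shift are omitted); the input layout, a list of channels each holding its flattened spatial activations, is an assumption for illustration.

```python
import math
from typing import List


def group_norm(features: List[List[float]], num_groups: int,
               eps: float = 1e-5) -> List[List[float]]:
    """Normalize one sample's channels group by group.

    features[c] holds the (flattened) spatial activations of channel c. Mean and
    variance are computed over all channels and positions inside each group of
    this sample only, so nothing depends on other samples in the batch.
    """
    num_channels = len(features)
    if num_channels % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    size = num_channels // num_groups
    out: List[List[float]] = []
    for g in range(num_groups):
        group = features[g * size:(g + 1) * size]
        values = [v for channel in group for v in channel]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        scale = 1.0 / math.sqrt(var + eps)
        out.extend([(v - mean) * scale for v in channel] for channel in group)
    return out
```

Because each group is normalized with one sample's own statistics, a test-time batch made of augmented copies of a single image poses no problem, unlike BN, whose batch statistics become unreliable on such batches.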


For Equation 2, optimization hyperparameters are set to the defaults in standard practice (huang2016deep; he2016deep) [3]. For Equation 3, we use stochastic gradient descent (SGD) with the learning rate set to that of the last epoch during training, which is 0.001 in all our experiments. Following standard practice (he2018rethinking; liu2018rethinking) known to improve performance when fine-tuning, we do not use weight decay or momentum. For the standard version, we take ten gradient steps, using batches independently generated from the same image. For the online version, we take only one step. The computational aspects of our method are discussed in Appendix C. Following standard practice, the transformations used for data augmentation are random crop with padding and random horizontal flip for CIFAR-10 (guo2017calibration; huang2016deep), and random resized crop and random horizontal flip for ImageNet (ioffe2015batch; he2016deep). In particular, these transformations do not contain information about the distribution shifts.

[3] Namely, we use SGD with weight decay and momentum; the learning rate starts at 0.1 and is dropped by a factor of ten at two scheduled milestones, to 0.01 and 0.001.

In all the tables and figures, baseline refers to the plain ResNet model (using GN, unless otherwise specified); joint training refers to the model jointly trained on both the main task and the self-supervised task, fixed at test-time as in hendrycks2019using; test-time training refers to the standard version described in Section 2; and test-time training online refers to the online version that does not discard $\theta(x)$ for $x$ arriving sequentially from the same distribution. Performance for test-time training online is calculated, just like the others, as the average over the entire test set; we always shuffle the test set before test-time training online to avoid ordering artifacts.
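The optimization settings above can be summarized in a small configuration sketch. The numbers come from the text; the epoch milestones at which the training learning rate drops are not given there, so the values used below are hypothetical placeholders, as are the names `TRAIN`, `TTT`, `train_lr` and `ttt_sgd_step`.

```python
from typing import List, Sequence

# Joint training (Equation 2): SGD with weight decay and momentum,
# learning rate 0.1, dropped by 10x at two scheduled milestones.
TRAIN = {"learning_rates": (0.1, 0.01, 0.001)}

# Test-time training (Equation 3): plain SGD at the last training learning rate.
TTT = {"lr": 0.001, "momentum": 0.0, "weight_decay": 0.0,
       "steps_standard": 10, "steps_online": 1}


def train_lr(epoch: int, milestones: Sequence[int] = (75, 125)) -> float:
    """Step schedule; the milestone epochs are hypothetical placeholders."""
    passed = sum(epoch >= m for m in milestones)
    return TRAIN["learning_rates"][passed]


def ttt_sgd_step(theta: List[float], grad: List[float]) -> List[float]:
    """One test-time update: no momentum buffer and no weight-decay term."""
    return [t - TTT["lr"] * g for t, g in zip(theta, grad)]
```

The design choice the text emphasizes is visible in the last line of the schedule: the test-time learning rate equals the final training learning rate.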
hendrycks2019benchmarking propose to benchmark the robustness of neural networks on 15 types of corruptions from four broad categories: noise, blur, weather and digital. Each corruption type comes in five levels of severity, with level 5 the most severe (details and sample images in Appendix D). The corruptions are algorithmically simulated to mimic real-world corruptions as closely as possible on copies of the test set for both CIFAR-10 and ImageNet. According to the authors, training should be done on the original training set, and the diversity of corruption types should make it difficult for any method to work well across the board if it relies too much on corruption-specific knowledge.
Our results on the level 5 corruptions (most severe) are shown in Figure 9. Due to space constraints, our results on levels 1-4 are shown in Appendix E. Across all five levels and 15 corruption types, both versions of test-time training always improve over the baseline by a large margin. The standard version of test-time training always improves over joint training, and the online version often improves very significantly (10%) over joint training and never hurts by more than 0.2%. Specifically, test-time training online contributes 24% on the three noise types and 38% on pixelation. For the seemingly unstable setup of a learning problem built from a single image, this kind of consistency is rather surprising.
The baseline ResNet-26 has an error of 8.9% on the original CIFAR-10 test set. The joint training baseline actually improves performance on the original to 8.1%. Most surprisingly, unlike many other methods that trade off original performance for robustness, test-time training further improves on the original test set by 0.2%, consistently over many independent trials. This indicates that our method does not have to choose between specificity and generality.
Separate from our method, it is interesting to note that joint training consistently improves over the baseline, as discovered by hendrycks2019using. hendrycks2019benchmarking have also experimented with various other training methods on this benchmark, and point to Adversarial Logit Pairing (ALP) (kannan2018adversarial) as the most effective. Results of this additional baseline on all levels of CIFAR-10-C are shown in Appendix E, along with its implementation details. While surprisingly robust under some of the most severe corruptions (especially the three noise types), ALP incurs a much larger error (by a factor of two) on the original distribution and on some corruptions (e.g. all levels of contrast and fog), and hurts performance significantly when the corruptions are less severe (especially on levels 1-3). This kind of trade-off is to be expected for methods based on adversarial training, but not for test-time training.

Our results on the level 5 ImageNet-C corruptions (most severe) are shown in Figure 2. We use accuracy instead of error for this dataset because the baseline severely underperforms with most corruptions. The general trend is roughly the same as on CIFAR-10-C. Test-time training (standard version) always improves over the baseline and joint training, while the online version hurts on the original by only 0.1% relative to the baseline, but dramatically improves (by a factor of more than three) on many of the corruption types.
In the lower panel of Figure 2, we visualize how the accuracy of the online version (averaged over a sliding window) changes as more images are tested. Due to space constraints, we show this plot on the original test set, as well as for every third corruption type, following the same order as in the original paper. On the original test set, there is no visible change in performance after updating on the 50,000 samples. With corruptions, accuracy has already risen significantly after 10,000 samples, but is still rising towards the end of the 50,000 samples, indicating room for additional improvements if more samples were available. Judging from the appearance of these plots, test-time training online behaves as if we were training on the test set, without ever looking at a single label.
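The curve in the lower panel of Figure 2 is a moving average of per-sample correctness. A minimal sketch of such a sliding-window average follows; the window size used for the figure is not stated, so it is left as a parameter, and the function name is hypothetical.

```python
from collections import deque
from typing import Deque, Iterable, List


def sliding_window_accuracy(correct: Iterable[int], window: int) -> List[float]:
    """Accuracy over the most recent `window` predictions (1 = correct, 0 = wrong).

    Early points average over however many predictions have been seen so far.
    """
    buf: Deque[int] = deque()
    total = 0
    curve: List[float] = []
    for c in correct:
        buf.append(c)
        total += c
        if len(buf) > window:
            total -= buf.popleft()
        curve.append(total / len(buf))
    return curve
```

For example, `sliding_window_accuracy([1, 0, 1, 1], window=2)` returns `[1.0, 0.5, 0.5, 1.0]`.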
Table 1. Accuracy (%) on the VID frames of the seven classes that overlap with CIFAR-10, rescaled to CIFAR-10 size.

Method              Airplane  Bird  Car   Dog   Cat   Horse  Ship  Average
Baseline            67.9      35.8  42.6  14.7  52.0  42.0   66.7  41.4
Joint training      70.2      36.7  42.6  15.5  52.0  44.0   66.7  42.4
Test-time training  70.2      39.2  42.6  21.6  54.7  46.0   77.8  45.2
TTT online          70.2      39.2  42.6  22.4  54.7  46.0   77.8  45.4
Table 2. Accuracy (%) on the ImageNet VID anchor frames.

Method              Accuracy (%)
Baseline            62.7
Joint training      63.5
Test-time training  63.8
TTT online          64.3
The ImageNet Video Classification (VID) dataset was developed by shankar2019systematic from the Video Detection dataset of ILSVRC2015 (ILSVRC15), to demonstrate how deep learning models for object recognition trained on ImageNet (still images) fail to adapt well to video frames [4]. Without any modification for videos, we apply our method to VID on top of the same ImageNet model as in the previous subsection. Our results are shown in Table 2. Again, we use accuracy instead of error because the baseline performance is poor.

[4] The VID dataset contains 1109 sets of video frames; each set forms a short video clip where all the frames are similar to an anchor frame. Our results are reported on the anchor frames. To map the 1000 ImageNet classes to the 30 VID classes, we use the max-conversion function in shankar2019systematic.

In addition, we take the seven classes in VID that overlap with CIFAR-10, and rescale those video frames to the size of CIFAR-10 images, as a new test set for the model trained on CIFAR-10 in the previous subsection. Again, we apply our method to this dataset without any modification. Our results are shown in Table 1, with a breakdown for each class. Noticing that test-time training does not improve on the airplane class, we inspected some airplane samples and observed that most of them have black margins on the sides, which provide a trivial hint for the rotation prediction task. In addition, for airplanes captured in the sky, it is often impossible even for humans to tell whether an image is rotated. This shows that our method requires the self-supervised task to be both well defined and non-trivial on the new domain.
CIFAR-10.1 (recht2018cifar) is a new test set of size 2,000 modeled after CIFAR-10, with the exact same classes and image dimensionality, following the dataset creation process documented by the original CIFAR-10 paper as closely as possible. The purpose is to investigate the distribution shifts present between the two test sets, and their effect on object recognition. All models tested by the authors suffer a large performance drop on CIFAR-10.1 compared to CIFAR-10, even though there is no humanly noticeable difference and both have the same human accuracy. This demonstrates how insidious and ubiquitous distribution shifts are, even when researchers strive to minimize them.
Table 3. Error (%) on CIFAR-10.1.

Method              Error (%)
Baseline            17.4
Joint training      16.7
Test-time training  15.9
The distribution shifts from CIFAR-10 to CIFAR-10.1 pose an extremely difficult problem, and no one has yet succeeded in improving the performance of an existing model on this new test set, probably because 1) researchers cannot even identify the distribution shifts, let alone describe them mathematically; and 2) the samples in CIFAR-10.1 are only revealed at test-time, and even if they were revealed during training, the distribution shifts are too subtle, and the sample size is too small, for domain adaptation algorithms (recht2018cifar).

On the original CIFAR-10 test set, our baseline has an error of 8.9%, and 8.1% with joint training; compared to the first two rows of Table 3, both suffer the typical performance drop (by a factor of two). Test-time training yields an improvement of 0.8% (a relative improvement of 4.8%) over joint training. We recognize that this improvement is still small compared to the performance drop, but see it as an encouraging first step for this very difficult problem.
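The factors quoted above can be checked directly from the numbers in the text and in Table 3; the short computation below does only that and introduces no new data.

```python
# Error rates (%) quoted in the text and in Table 3.
cifar10 = {"baseline": 8.9, "joint": 8.1}
cifar10_1 = {"baseline": 17.4, "joint": 16.7, "ttt": 15.9}

# Performance drop from CIFAR-10 to CIFAR-10.1 (roughly a factor of two).
drop_baseline = cifar10_1["baseline"] / cifar10["baseline"]   # about 1.96
drop_joint = cifar10_1["joint"] / cifar10["joint"]            # about 2.06

# Absolute and relative improvement of test-time training over joint training.
abs_gain = cifar10_1["joint"] - cifar10_1["ttt"]              # 0.8 points
rel_gain = 100 * abs_gain / cifar10_1["joint"]                # about 4.8 %

print(f"drop x{drop_baseline:.2f} / x{drop_joint:.2f}, "
      f"gain {abs_gain:.1f} points ({rel_gain:.1f}% relative)")
```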
This section contains our preliminary study of when and why test-time training is expected to work. For convex models, we prove that positive gradient correlation between the loss functions leads to better performance on the main task after test-time training. Equipped with this insight, we then empirically demonstrate that gradient correlation governs the success of test-time training on the deep learning model discussed in Section 3.
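Before stating our main theoretical result, we first illustrate the general intuition with a toy model. Consider a regression problem where $x \in \mathbb{R}^d$ denotes the input, $y_m \in \mathbb{R}$ denotes the label, and the objective is the square loss $(\hat{y} - y_m)^2/2$ for a prediction $\hat{y}$. Consider a two-layer linear network parametrized by $A \in \mathbb{R}^{h \times d}$ and $v_m \in \mathbb{R}^h$ (where $h$ stands for the hidden dimension). The prediction according to this model is $\hat{y} = v_m^\top A x$, and the main task loss is

$$l_m(x, y_m; A, v_m) = \frac{1}{2}\left(y_m - v_m^\top A x\right)^2. \tag{4}$$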
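In addition, consider a self-supervised regression task that also uses the square loss and automatically generates a label $y_s$ for $x$. Let the self-supervised head be parametrized by $v_s \in \mathbb{R}^h$. Then the self-supervised task loss is

$$l_s(x, y_s; A, v_s) = \frac{1}{2}\left(y_s - v_s^\top A x\right)^2. \tag{5}$$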
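Now we apply test-time training to update the shared feature extractor $A$ by one step of gradient descent on $l_s$, which we can compute with $y_s$ known. This gives us

$$A' \leftarrow A - \eta \left(y_s - v_s^\top A x\right)\left(-v_s x^\top\right), \tag{6}$$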
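where $A'$ is the updated matrix and $\eta$ is the learning rate. Writing $\hat{y}_m = v_m^\top A x$ and $\hat{y}_s = v_s^\top A x$ for the predictions before the update, if we set $\eta = \eta^*$ where

$$\eta^* = \frac{\hat{y}_m - y_m}{\left(\hat{y}_s - y_s\right) v_s^\top v_m \, x^\top x}, \tag{7}$$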
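then, since the updated main prediction is $v_m^\top A' x = \hat{y}_m - \eta \left(\hat{y}_s - y_s\right) v_m^\top v_s \, x^\top x$, simple algebra shows that it equals $y_m$, so the main task loss $l_m(x, y_m; A', v_m) = 0$. Concretely, test-time training drives the main task loss down to zero with a single gradient step for a carefully chosen learning rate. In practice, this learning rate is unknown since it depends on the unknown $y_m$. However, since our model is convex, as long as $\eta^*$ is positive, it suffices to set $\eta$ to a small positive constant (see Lemma 1). If $x \neq 0$, one sufficient condition for $\eta^*$ to be positive (when neither loss is zero) is to have

$$\operatorname{sign}\left(\hat{y}_m - y_m\right) = \operatorname{sign}\left(\hat{y}_s - y_s\right) \quad \text{and} \quad v_m^\top v_s > 0. \tag{8}$$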
For our toy model, both parts of Equation 8 have an intuitive interpretation. The first part says that the mistakes should be correlated, in the sense that predictions from both tasks are mistaken in the same direction. The second part, $v_m^\top v_s > 0$, says that the decision boundaries on the feature space should be correlated. In fact, $\eta^*$ always has the same sign as the gradient inner product $\langle \nabla l_m(A), \nabla l_s(A) \rangle$, so $\eta^* > 0$ iff $\langle \nabla l_m(A), \nabla l_s(A) \rangle > 0$ (see Lemma 2). To summarize, if the gradients have positive correlation, test-time training with a small enough learning rate is guaranteed to reduce the main task loss. Our main theoretical result extends this to general smooth and convex loss functions.
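The toy-model argument can be checked numerically. The sketch below, in plain Python with hypothetical helper names and made-up inputs, computes $\eta^*$ from Equation 7, applies the single gradient step of Equation 6 on $l_s$, and confirms that the main task loss becomes (numerically) zero, while $\eta^*$ and the gradient inner product share the same sign.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matvec(A: Matrix, x: List[float]) -> List[float]:
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def toy_ttt_step(A: Matrix, v_m: List[float], v_s: List[float], x: List[float],
                 y_m: float, y_s: float) -> Tuple[float, float, float]:
    """Return (eta_star, main loss after one step with eta_star, <grad l_m, grad l_s>)."""
    z = matvec(A, x)
    yhat_m, yhat_s = dot(v_m, z), dot(v_s, z)
    xx = dot(x, x)
    eta_star = (yhat_m - y_m) / ((yhat_s - y_s) * dot(v_s, v_m) * xx)   # Equation 7
    # Equation 6: A' = A - eta * (yhat_s - y_s) * v_s x^T
    A_new = [[a - eta_star * (yhat_s - y_s) * vs * xj for a, xj in zip(row, x)]
             for row, vs in zip(A, v_s)]
    loss_after = 0.5 * (y_m - dot(v_m, matvec(A_new, x))) ** 2
    # Frobenius inner product of the two gradients with respect to A
    grad_inner = (yhat_m - y_m) * (yhat_s - y_s) * dot(v_m, v_s) * xx
    return eta_star, loss_after, grad_inner
```

For `A` the 2-by-2 identity, `v_m = [1, 1]`, `v_s = [1, 0.5]`, `x = [1, 2]`, `y_m = 1` and `y_s = 0`, this gives $\eta^* = 2/15$, a main loss of essentially zero after the step, and a gradient inner product of 30.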
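Theorem 1. Let $l_m(x, y; \theta)$ denote the main task loss on test instance $x, y$ with parameters $\theta$, and $l_s(x; \theta)$ the self-supervised task loss that only depends on $x$ [5]. Assume that for all $x, y$, $l_m(x, y; \theta)$ is differentiable, convex and $\beta$-smooth in $\theta$, and both $\|\nabla l_m(x, y; \theta)\|, \|\nabla l_s(x; \theta)\| \le G$ for all $\theta$. With a fixed learning rate $\eta = \frac{\epsilon}{\beta G^2}$ for some $\epsilon > 0$, for every $x, y$ such that

$$\left\langle \nabla l_m(x, y; \theta), \nabla l_s(x; \theta) \right\rangle > \epsilon, \tag{9}$$

we have

$$l_m(x, y; \theta) > l_m\left(x, y; \theta(x)\right), \tag{10}$$

where $\theta(x) = \theta - \eta \nabla l_s(x; \theta)$, i.e. test-time training with one step of gradient descent.

[5] Because the main task branch and the self-supervised branch are kept fixed at test-time, we do not explicitly describe their parameters, and include them implicitly in the loss functions.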
[Figure 3: Gradient correlation versus improvement of test-time training over the baseline, across corruptions. The blue lines and bands are the best linear fits and the 99% confidence intervals; the linear correlation coefficients indicate strong positive correlation between the two quantities, as suggested by Theorem 1.]

The proof is in Section B.2. Theorem 1 reveals gradient correlation as a determining factor of the success of test-time training in the smooth and convex case. In Figure 3, we empirically show that our insight also holds for non-convex loss functions, on the deep learning model and across the diverse set of corruptions considered in Section 3: stronger gradient correlation clearly indicates larger performance improvements over the baseline.
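Theorem 1 can also be illustrated on a simple convex, smooth problem. In this hedged sketch, $l_m(\theta) = \tfrac{1}{2}\|\theta - a\|^2$ and $l_s(\theta) = \tfrac{1}{2}\|\theta - b\|^2$ (so $\beta = 1$) are stand-ins chosen only for illustration; $G$ is taken as the larger gradient norm at the current $\theta$, which is all the one-step argument needs, and $\epsilon$ is set to half the gradient inner product. The function name is hypothetical.

```python
import math
from typing import List, Tuple


def theorem1_demo(theta: List[float], a: List[float], b: List[float],
                  beta: float = 1.0) -> Tuple[float, float, float]:
    """One step of test-time training with eta = eps / (beta * G**2).

    Main loss: 0.5 * ||theta - a||^2; self-supervised loss: 0.5 * ||theta - b||^2.
    Returns (main loss before, main loss after, eta).
    """
    grad_m = [t - ai for t, ai in zip(theta, a)]
    grad_s = [t - bi for t, bi in zip(theta, b)]
    inner = sum(g * h for g, h in zip(grad_m, grad_s))
    if inner <= 0:
        raise ValueError("Theorem 1 requires a positive gradient inner product")
    eps = inner / 2                                   # any 0 < eps < inner works
    G = max(math.sqrt(sum(g * g for g in grad_m)), math.sqrt(sum(h * h for h in grad_s)))
    eta = eps / (beta * G ** 2)
    theta_new = [t - eta * h for t, h in zip(theta, grad_s)]

    def main_loss(th: List[float]) -> float:
        return 0.5 * sum((t - ai) ** 2 for t, ai in zip(th, a))

    return main_loss(theta), main_loss(theta_new), eta
```

With $\theta = (0, 0)$, $a = (1, 1)$ and $b = (2, 0)$, the inner product is 2, $G = 2$, $\eta = 0.25$, and the main loss drops from 1.0 to 0.625; with negatively correlated gradients the function refuses to apply the theorem.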
Our work has been influenced by the successes and limitations of many related fields. Each of these fields contains maybe hundreds of interesting works, which we unfortunately do not have enough time and space to acknowledge in this draft. We apologize for the omissions and are happy to include additional citations upon request.
Our idea is inspired by glasner2009super and shocher2018zero, which learn to perform super-resolution on a single image only, by learning to recover the original image itself from its downsampled version. In addition, jain2011online improves Viola-Jones face detection by using instances of faces detected with high confidence in an image to bootstrap the more difficult ones in the same image. The online version of our algorithm is inspired by mullapudi2018online, which makes video segmentation more efficient by using a student model that learns online from a teacher model. The idea of online updates has also been used in kalal2011tracking for tracking and detection. zhu2019neural, a concurrent work in echocardiography, improves the deep learning model that tracks myocardial motion and cardiac blood flow with sequential updates. Lastly, we share the philosophy of transductive learning (vapnik2013nature; gammerman1998learning), but have little in common with their classical algorithms, e.g. nearest neighbors.

Adversarial robustness studies the robust risk $R_{P,\Delta}(\theta) = \mathbb{E}_{(x, y) \sim P}\left[\max_{\delta \in \Delta} l(x + \delta, y; \theta)\right]$, where $l$ is some loss function and $\Delta$ is the set of perturbations; one popular choice of $\Delta$ is the $\ell_p$ ball. Works in this area rely heavily on the mathematical structure of $\Delta$, which might not realistically model perturbations in the real world. The most popular algorithms use either robust optimization (sinha2017certifying; raghunathan2018certified; wong2017provable; croce2018provable) or input transformations (guo2017countering; song2017pixeldefend), both of which can be seen as smoothing the decision boundary. This leads to a theoretical trade-off between accuracy and robustness (zhang2019theoretically); intuitively, the more diverse $\Delta$ is, the less effective this one-boundary-fits-all approach can be for a particular subset of $\Delta$. kang2019transfer show empirically that robustness for one $\ell_p$ ball might not transfer to another, and that training on one $\ell_p$ ball can actually hurt robustness on another.
Non-adversarial robustness studies the effect of corruptions, perturbations, out-of-distribution examples, and real-world distribution shifts (hendrycks2019improving; hendrycks2019using; hendrycks2018using; hendrycks2016baseline). geirhos2018generalisation show that training on images corrupted by Gaussian noise makes deep learning models recover super-human performance on this particular noise type, but does not improve performance on images corrupted by another noise type, e.g. salt-and-pepper noise.
Unsupervised domain adaptation (a.k.a. transfer learning) studies the problem of distribution shift (from $P$ to $Q$) when unlabeled data from $Q$ is available at training-time (tzeng2017adversarial; ganin2016domain; gong2012geodesic; long2016unsupervised; chen2018adversarial; chen2011co; hoffman2017cycada; csurka2017domain; long2015learning). We are inspired by this very active and successful community, especially (sun2019uda), and believe that progress in this community can motivate new algorithms in the framework of test-time learning. Our update rule can be viewed as performing one-sample unsupervised domain adaptation on the fly [6]. On the other hand, test-time learning comes from realizing the limitations of the unsupervised domain adaptation setting: outside of the specific target distribution where data is available for training, generalization is still elusive. Previous works make the source and target distributions broader with multiple and evolving sources and targets, without fundamentally addressing this problem (hoffman2018algorithms; hoffman2012discovering; hoffman2014continuous).

[6] Note that typical unsupervised domain adaptation algorithms, such as those based on distributional discrepancy, adversarial learning, co-training and generative modeling, might not work in our framework, because the concept of a target distribution, which has been so deeply rooted and heavily relied on, becomes ill-defined when there is only one sample from the target domain.
Self-supervised learning studies how to create labels from the data, by designing ingenious tasks that contain semantic information without human annotations, such as context prediction (doersch2015unsupervised), solving jigsaw puzzles (noroozi2016unsupervised), colorization (larsson2017colorproxy; zhang2016colorful), noise prediction (bojanowski2017unsupervised), and feature clustering (caron2018deep). Self-supervision has also been used on videos (Wang_UnsupICCV2015; CVPR2019_CycleTime). hendrycks2019using propose that jointly training a main task and a self-supervised task (our joint training baseline in Section 3) can improve the robustness of the main task. The same idea is used in few-shot learning (su2019boosting) and domain generalization (carlucci2019domain).

Domain generalization studies the setting where a meta distribution generates multiple environment distributions, some of which are available during training (source), while others are used for testing (target) (li2018deep; shankar2018generalizing; muandet2013domain; balaji2018metareg; ghifary2015domain; motiian2017unified; li2017deeper; gan2016learning). With only a few environments, information on the meta distribution is often too scarce to be helpful, and with many environments, we are back to the i.i.d. setting where each environment can be seen as a sample, and a strong baseline is to simply train on all the environments (li2019episodic). The setting of domain generalization is limited by the inherent trade-off between specificity and generality of a fixed decision boundary, and by the fact that generalization is again elusive outside of the meta distribution itself.
Continual learning (a.k.a. learning without forgetting) studies the setting where a model is made to learn a sequence of tasks without forgetting the tasks at the beginning (li2017learning; lopez2017gradient; kirkpatrick2017overcoming; santoro2016meta). In comparison, test-time learning does not care at all about forgetting (and might even encourage it). Few-shot (one-shot) learning studies extremely small training sets, possibly with a single example for some categories (snell2017prototypical; vinyals2016matching; fei2006one; ravi2016optimization; li2017meta; finn2017model; gidaris2018dynamic). Our update rule can be viewed as performing one-shot self-supervised learning and can potentially be improved by progress in few-shot learning.
Online learning (a.k.a. online optimization) is a well-studied area of learning theory (shalev2012online; hazan2016introduction). The basic setting repeats the following: receive $x_t$, predict $\hat{y}_t$, receive $y_t$ from a worst-case oracle, and learn. Final performance is evaluated using the regret, colloquially how much worse the learner performs than the best fixed model in hindsight. It is easy to see how our setting differs, even for the online version: we learn on $x_t$ before predicting $\hat{y}_t$, but never receive the $y_t$ that $\hat{y}_t$ is evaluated against, and thus do not need to consider the worst-case oracle or the regret.
This paper took a long time to develop, and benefited from conversations with many of our colleagues. We would especially like to thank Ben Recht and his students Ludwig Schmidt, Vaishaal Shankar and Becca Roelofs; Deva Ramanan and his student Achal Dave; and Armin Askari, Allan Jabri, Ashish Kumar, Angjoo Kanazawa and Jitendra Malik.
In the introduction, we claim that $\theta$ in traditional supervised learning gives a fixed decision boundary, while our $\theta(x)$ gives a variable decision boundary. Here we explain this claim from a theoretical perspective.
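Denote the input space $\mathcal{X}$ and the output space $\mathcal{Y}$. A decision boundary is simply a mapping $f: \mathcal{X} \to \mathcal{Y}$. Let $\Theta$ be a model class, e.g. $\mathbb{R}^d$. Now consider a family of parametrized functions $g_\theta: \mathcal{X} \to \mathcal{Y}$, where $\theta \in \Theta$. In the context of deep learning, $g$ is the neural network architecture and $\theta$ contains the parameters. We say that $f$ is a fixed decision boundary w.r.t. $g$ and $\Theta$ if there exists $\theta \in \Theta$ s.t. $f(x) = g_\theta(x)$ for every $x \in \mathcal{X}$, and a variable decision boundary if for every $x \in \mathcal{X}$, there exists $\theta \in \Theta$ s.t. $f(x) = g_\theta(x)$. Note how the selection of $\theta$ can depend on $x$ for a variable decision boundary, and cannot for a fixed one. It is then trivial to verify that our claim is true under those definitions.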
A critical reader might ask: with an arbitrarily large model class, can't every decision boundary be fixed? Yes, but this is not the end of the story. Let $\Theta' = \prod_{x \in \mathcal{X}} \Theta$, and consider the enormous model class $\Theta'$, which is capable of representing all possible mappings between $\mathcal{X}$ and $\Theta$. Let $g'_{\theta'}(x) = g_{\theta'(x)}(x)$ simply be the mapping represented by $\theta' \in \Theta'$. A variable decision boundary w.r.t. $g$ and $\Theta$ then indeed must be a fixed decision boundary w.r.t. $g'$ and $\Theta'$, but we would like to note two things. First, without any prior knowledge, generalization in $\Theta'$ is impossible with any finite amount of training data; reasoning about $g'$ and $\Theta'$ is most likely not productive from an algorithmic point of view, and the concept of a variable decision boundary is meant to avoid such reasoning. Second, selecting $\theta$ based on $x$ for a variable decision boundary can be thought of as “training” on all points $x \in \mathcal{X}$; however, “training” only happens when necessary, for the $x$ that it actually encounters.
Altogether, the concept of a variable decision boundary is different from what can be described by traditional learning theory. A more in-depth discussion is beyond the scope of this paper.
Here we prove the theoretical results covered in Section 4.
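The following setting applies to the two lemmas; this is simply the setting of our toy problem, reproduced here for ease of reference. Consider a two-layer linear network parametrized by $A \in \mathbb{R}^{h \times d}$ (shared) and $v_m, v_s \in \mathbb{R}^h$ (fixed) for the two heads, respectively. Denote $x \in \mathbb{R}^d$ the input and $y_m, y_s \in \mathbb{R}$ the labels for the two tasks, respectively. For the main task loss
$$l_m(A; v_m) = \frac{1}{2}\left(y_m - v_m^\top A x\right)^2 \tag{11}$$
and the self-supervised task loss
$$l_s(A; v_s) = \frac{1}{2}\left(y_s - v_s^\top A x\right)^2, \tag{12}$$
test-time learning yields an updated matrix
$$A' = A - \eta \left(y_s - v_s^\top A x\right)\left(-v_s x^\top\right), \tag{13}$$
where $\eta$ is the learning rate.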
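The toy setting can be sketched in a few lines of plain Python. The helper names (`residual`, `loss`, `ttt_step`) and the numbers at the bottom are our own illustrative choices; only the model, the two squared losses, and the single gradient step on the self-supervised loss follow the setting described above.

```python
# Toy two-head linear network and one test-time update (pure Python).
# A is h x d (shared); v_m, v_s are length-h heads (fixed); x has length d.
# The numbers at the bottom are illustrative, not taken from the paper.


def matvec(A, x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def residual(A, v, x, y):
    """y - v^T A x."""
    return y - dot(v, matvec(A, x))


def loss(A, v, x, y):
    """1/2 (y - v^T A x)^2, i.e. l_m with (v_m, y_m) or l_s with (v_s, y_s)."""
    return 0.5 * residual(A, v, x, y) ** 2


def ttt_step(A, v_s, x, y_s, eta):
    """A' = A - eta * (y_s - v_s^T A x) * (-v_s x^T): one step on l_s only."""
    r_s = residual(A, v_s, x, y_s)
    return [[A[i][j] + eta * r_s * v_s[i] * x[j] for j in range(len(x))]
            for i in range(len(A))]


A = [[1.0, 0.0], [0.0, 1.0]]
x = [1.0, 2.0]
v_m, y_m = [1.0, 1.0], 4.0   # main head and label (label unseen at test time)
v_s, y_s = [1.0, 0.5], 3.0   # self-supervised head and label
eta = 0.05

A_new = ttt_step(A, v_s, x, y_s, eta)
print(loss(A, v_s, x, y_s), "->", loss(A_new, v_s, x, y_s))  # l_s: 0.5 -> ~0.236
print(loss(A, v_m, x, y_m), "->", loss(A_new, v_m, x, y_m))  # l_m: 0.5 -> ~0.195
```

The update never touches $y_m$, yet the main-task loss also drops here because both residuals are positive and $v_m^\top v_s > 0$; this sign condition is what the lemmas that follow make precise.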
Following the exposition of section 4, denote
(14) 
Assume for some . Then for any , we are guaranteed an improvement on the main loss i.e. .
From the exposition of section 4, we know that
which can also be derived from simple algebra. Then by convexity, we have
(15)  
(16)  
(17)  
(18) 
where the last inequality uses the assumption that , which holds because .
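Define $\langle U, V \rangle = \operatorname{tr}(U^\top V)$, i.e. the Frobenius inner product, then
$$\langle \nabla l_m(A), \nabla l_s(A) \rangle = \left\langle \left(y_m - v_m^\top A x\right)\left(-v_m x^\top\right),\ \left(y_s - v_s^\top A x\right)\left(-v_s x^\top\right) \right\rangle. \tag{19}$$
By simple algebra,
$$\begin{aligned} \langle \nabla l_m(A), \nabla l_s(A) \rangle &= \left(y_m - v_m^\top A x\right)\left(y_s - v_s^\top A x\right) \operatorname{tr}\left(x\, v_m^\top v_s\, x^\top\right) \\ &= \left(y_m - v_m^\top A x\right)\left(y_s - v_s^\top A x\right)\left(v_m^\top v_s\right)\|x\|^2, \end{aligned} \tag{20-22}$$
which has the same sign as $\left(y_m - v_m^\top A x\right)\left(y_s - v_s^\top A x\right) v_m^\top v_s$ (for $x \neq 0$).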
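The identity can be spot-checked numerically. The sketch below is our own check, not part of the proof: it draws random instances (random shapes $h = 3$, $d = 4$ and Gaussian entries are arbitrary choices), forms both gradient matrices explicitly, and compares their Frobenius inner product with the closed form and its sign.

```python
import random

# Numerical check of the identity above on random instances (pure Python).
# Gradient of 1/2 (y - v^T A x)^2 w.r.t. A is (y - v^T A x) * (-v x^T).


def dot(u, w):
    return sum(a * b for a, b in zip(u, w))


def residual(A, v, x, y):
    return y - dot(v, [dot(row, x) for row in A])


def grad(A, v, x, y):
    r = residual(A, v, x, y)
    return [[-r * vi * xj for xj in x] for vi in v]


def frobenius(U, V):
    """<U, V> = tr(U^T V), the sum of elementwise products."""
    return sum(a * b for ru, rv in zip(U, V) for a, b in zip(ru, rv))


random.seed(0)
h, d = 3, 4
sign_mismatches = 0
for _ in range(1000):
    A = [[random.gauss(0, 1) for _ in range(d)] for _ in range(h)]
    x = [random.gauss(0, 1) for _ in range(d)]
    v_m = [random.gauss(0, 1) for _ in range(h)]
    v_s = [random.gauss(0, 1) for _ in range(h)]
    y_m, y_s = random.gauss(0, 1), random.gauss(0, 1)

    lhs = frobenius(grad(A, v_m, x, y_m), grad(A, v_s, x, y_s))
    r_m, r_s = residual(A, v_m, x, y_m), residual(A, v_s, x, y_s)
    rhs = r_m * r_s * dot(v_m, v_s) * dot(x, x)
    assert abs(lhs - rhs) <= 1e-9 * max(1.0, abs(rhs))
    sign_mismatches += (lhs > 0) != (r_m * r_s * dot(v_m, v_s) > 0)

print("sign mismatches:", sign_mismatches)
```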
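For any $\eta$, by smoothness and convexity (writing $l_m(\theta)$ for $l_m(x, y; \theta)$ and $l_s(\theta)$ for $l_s(x; \theta)$),
$$\begin{aligned} l_m\big(\theta - \eta \nabla l_s(\theta)\big) &\le l_m(\theta) + \big\langle \nabla l_m(\theta), -\eta \nabla l_s(\theta) \big\rangle + \frac{\beta}{2}\big\|\eta \nabla l_s(\theta)\big\|^2 \\ &= l_m(\theta) - \eta \big\langle \nabla l_m(\theta), \nabla l_s(\theta) \big\rangle + \frac{\eta^2 \beta}{2}\big\|\nabla l_s(\theta)\big\|^2. \end{aligned} \tag{23}$$
Denote
$$\eta^* = \frac{\langle \nabla l_m(\theta), \nabla l_s(\theta) \rangle}{\beta \|\nabla l_s(\theta)\|^2}.$$
Then Equation 23 becomes
$$l_m\big(\theta - \eta^* \nabla l_s(\theta)\big) \le l_m(\theta) - \frac{\langle \nabla l_m(\theta), \nabla l_s(\theta) \rangle^2}{2\beta \|\nabla l_s(\theta)\|^2}. \tag{25}$$
And by our assumptions on the gradient norm and gradient inner product,
$$l_m(\theta) - l_m\big(\theta - \eta^* \nabla l_s(\theta)\big) > \frac{\epsilon^2}{2\beta G^2}. \tag{26}$$
Because we cannot observe $\eta^*$ in practice (it depends on $\nabla l_m$, and hence on the unknown label $y$), we instead use a fixed learning rate $\eta = \frac{\epsilon}{\beta G^2}$, as stated in Theorem 1. Now we argue that this fixed learning rate still improves performance on the main task.
By our assumptions, $\|\nabla l_s(\theta)\|^2 \le G^2$ and $\langle \nabla l_m(\theta), \nabla l_s(\theta) \rangle > \epsilon$, so $0 < \eta \le \eta^*$. Denote $\alpha = \eta / \eta^* \in (0, 1]$, then by convexity of $l_m$,
$$\begin{aligned} l_m\big(\theta(x)\big) &= l_m\big(\theta - \eta \nabla l_s(\theta)\big) \\ &= l_m\big((1 - \alpha)\,\theta + \alpha\,(\theta - \eta^* \nabla l_s(\theta))\big) \\ &\le (1 - \alpha)\, l_m(\theta) + \alpha\, l_m\big(\theta - \eta^* \nabla l_s(\theta)\big). \end{aligned} \tag{27-29}$$
Combining with Equation 26, we have
$$\begin{aligned} l_m\big(\theta(x)\big) &< (1 - \alpha)\, l_m(\theta) + \alpha \left( l_m(\theta) - \frac{\epsilon^2}{2\beta G^2} \right) \\ &= l_m(\theta) - \frac{\alpha\, \epsilon^2}{2\beta G^2}. \end{aligned} \tag{30-31}$$
Since $\alpha > 0$, we have shown that
$$l_m(x, y; \theta) - l_m\big(x, y; \theta(x)\big) > 0. \tag{32}$$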
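The fixed-step argument can be illustrated on a case where every quantity is computable by hand. The sketch assumes quadratic losses $l_m(\theta) = \frac{1}{2}\|\theta - a\|^2$ and $l_s(\theta) = \frac{1}{2}\|\theta - b\|^2$ (convex and $\beta$-smooth with $\beta = 1$), with $a$, $b$, $\theta$ and $\epsilon$ chosen by us; $G$ and $\epsilon$ are checked only at the current $\theta$, which is all the proof uses.

```python
# Numerical illustration of the fixed-step argument above.
# Assumptions (ours): quadratic losses l_m(theta) = 1/2 ||theta - a||^2 and
# l_s(theta) = 1/2 ||theta - b||^2, convex and beta-smooth with beta = 1;
# G and epsilon are checked only at the current theta.

BETA = 1.0


def dot(u, v):
    return sum(p * q for p, q in zip(u, v))


def loss(theta, c):
    return 0.5 * sum((t - ci) ** 2 for t, ci in zip(theta, c))


def grad(theta, c):
    return [t - ci for t, ci in zip(theta, c)]


theta = [0.0, 0.0]
a = [1.0, 1.0]   # minimizer of the main-task loss (unknown at test time)
b = [2.0, 0.5]   # minimizer of the self-supervised loss

g_m, g_s = grad(theta, a), grad(theta, b)
inner = dot(g_m, g_s)                    # 2.5
eps = 2.0                                # any 0 < eps < inner works
G2 = max(dot(g_m, g_m), dot(g_s, g_s))   # G^2 = 4.25
assert 0 < eps < inner

eta = eps / (BETA * G2)                  # fixed learning rate eps / (beta G^2)
theta_x = [t - eta * g for t, g in zip(theta, g_s)]

eta_star = inner / (BETA * dot(g_s, g_s))
alpha = eta / eta_star                   # lies in (0, 1]
bound = loss(theta, a) - alpha * eps ** 2 / (2 * BETA * G2)

print(loss(theta, a), loss(theta_x, a), bound)
```

The updated loss lands well below the guaranteed bound of Equations 30-31, as expected: the bound only uses the worst case allowed by $\epsilon$ and $G$.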
At test time, our method is batch_size × number_of_iterations times slower than regular testing, which only performs a single forward pass for each sample. As the first work on test-time learning, this paper is less concerned with computational efficiency than with improving robustness, but here we provide two potential solutions that might be useful, although they have not been thoroughly verified. The first is to use the thresholding trick introduced as a solution for the small batches problem in section 2. For the models considered in our experiments, a large fraction of the test instances fall below the threshold, so test-time learning need only be performed on the remaining instances without much effect on performance, because those contain most of the samples with wrong predictions. The second is to reduce the number_of_iterations of test-time updates. For the online version, the number_of_iterations is already 1, so there is nothing to do. For the standard version, we have done some preliminary experiments setting number_of_iterations to 1 (instead of 10) and the learning rate to 0.01 (instead of 0.001), and observed results almost as good as with the standard hyperparameter setting. A more in-depth discussion of efficiency is left for future work, which might, during training, explicitly make the model amenable to fast updates.
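The two remedies are easiest to compare with a back-of-the-envelope cost model. The sketch below is an assumption on our part: it counts per-sample passes, treats a sample skipped by the threshold as costing a single pass, and uses a hypothetical batch size and fraction of updated samples (neither value is given in this section); only the 10-versus-1 iteration counts come from the text.

```python
def relative_test_cost(batch_size, num_iters, frac_updated=1.0):
    """Rough test-time cost relative to plain testing, in per-sample passes.

    Assumed cost model (ours, not from the paper): a sample that receives
    test-time updates costs batch_size * num_iters passes, and a sample
    skipped by a threshold costs a single pass.
    """
    if not 0.0 <= frac_updated <= 1.0:
        raise ValueError("frac_updated must lie in [0, 1]")
    return frac_updated * batch_size * num_iters + (1.0 - frac_updated)


BATCH_SIZE = 32  # hypothetical value, for illustration only

print(relative_test_cost(BATCH_SIZE, 10))       # standard setting: 320.0
print(relative_test_cost(BATCH_SIZE, 1))        # one iteration: 32.0
print(relative_test_cost(BATCH_SIZE, 10, 0.2))  # hypothetical 20% updated: about 64.8
```

Under this model the two remedies multiply: thresholding scales the cost by roughly the fraction of updated samples, and cutting iterations scales it by the iteration ratio.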
For table aesthetics, we use the following abbreviations: B for baseline, JT for joint training, TTT for test-time training (standard version), and TTTO for test-time training (online version).
We have also abbreviated the names of the corruptions. The full names are, in order: original test set, Gaussian noise, shot noise, impulse noise, defocus blur, glass blur, motion blur, zoom blur, snow, frost, fog, brightness, contrast, elastic transformation, pixelation, and JPEG compression.
As discussed in section 3, Batch Normalization (BN) is ineffective for small batches, which are the inputs for test-time training (both standard and online versions), since there is only one sample available when forming each batch; therefore, our main results are based on a ResNet using Group Normalization (GN). Here we provide results of our method on CIFAR-10-C level 5 with a ResNet using BN. These results are meant merely as a point of reference for curious readers, rather than as part of our technical contributions.
In the early stages of this project, we experimented with two potential solutions to the small batches problem with BN. The naive solution is to fix the BN layers during test-time training, but this diminishes the performance gains since there are fewer shared parameters. The better solution, adopted for the results below, is hard example mining: instead of updating on all inputs, we only update on inputs that incur a large self-supervised task loss, where the large improvements might counter the negative effects of inaccurate statistics.
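The hard-example-mining rule amounts to a gate in front of the usual per-sample update. In the sketch below the one-parameter "model", the squared-distance self-supervised loss, and the names `ssl_loss`, `ssl_grad` and `ttt_hard_examples` are hypothetical stand-ins; only the control flow reflects the text. As in the standard version, we assume every update starts from the trained parameters and the adapted copy is discarded after each input.

```python
import copy

# Toy stand-ins (hypothetical, for illustration only): a one-parameter "model"
# whose self-supervised loss is the squared distance between input x and w.


def ssl_loss(params, x):
    return (x - params["w"]) ** 2


def ssl_grad(params, x):
    return -2.0 * (x - params["w"])  # d/dw of (x - w)^2


def ttt_hard_examples(params, xs, threshold, lr=0.1, num_iters=10):
    """Update only on inputs whose self-supervised loss exceeds `threshold`.

    Each update starts from a fresh copy of the trained parameters, so the
    trained model itself never changes. Returns (x, w_used, updated) tuples.
    """
    results = []
    for x in xs:
        if ssl_loss(params, x) <= threshold:
            results.append((x, params["w"], False))  # easy input: no update
            continue
        adapted = copy.deepcopy(params)
        for _ in range(num_iters):
            adapted["w"] -= lr * ssl_grad(adapted, x)
        results.append((x, adapted["w"], True))
    return results


trained = {"w": 0.0}
for x, w_used, updated in ttt_hard_examples(trained, [0.1, 2.0, -0.2, 3.0], 1.0):
    print(f"x={x:+.1f} updated={updated} w={w_used:.3f}")
```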
CIFAR-10-C, level 5, ResNet with BN (test error, %):
Method   orig  gauss   shot  impul  defoc  glass   motn   zoom   snow  frost    fog   brit  contr  elast  pixel   jpeg
B         7.9   63.9   58.8   64.3   46.3   54.6   41.6   45.9   31.9   44.0   37.5   13.0   69.2   33.8   61.4   31.7
JT        7.5   70.7   65.6   67.2   43.1   55.4   40.9   42.7   30.3   44.5   42.5   12.7   58.6   30.7   62.6   31.9
TTT       7.9   47.9   45.2   54.8   27.6   50.4   31.5   30.9   28.7   34.3   26.9   12.6   35.2   30.6   51.2   31.3
Test-time training (standard version) is still very effective with BN. In fact, some of the improvements are quite dramatic, such as on contrast (34 percentage points lower error), defocus blur (18 points) and Gaussian noise (22 points compared to joint training, and 16 points compared to the baseline). Performance on the original distribution is still almost the same; the original error with BN is in fact slightly lower than with GN, and the BN model takes half as many epochs to converge.
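The improvements quoted above can be recomputed from the BN table. The short check below reads them as absolute drops in test error (percentage points), which is our interpretation; the quoted figures appear to be truncated rather than rounded (e.g. 18.7 is quoted as 18).

```python
# Test error (%) on CIFAR-10-C level 5 with BN, copied from the table above
# for three corruption columns.
ERR = {
    "B":   {"gauss": 63.9, "defoc": 46.3, "contr": 69.2},
    "JT":  {"gauss": 70.7, "defoc": 43.1, "contr": 58.6},
    "TTT": {"gauss": 47.9, "defoc": 27.6, "contr": 35.2},
}


def drop_in_points(reference, column):
    """Absolute reduction in error (percentage points) from `reference` to TTT."""
    return ERR[reference][column] - ERR["TTT"][column]


for column in ("gauss", "defoc", "contr"):
    print(column,
          round(drop_in_points("B", column), 1),
          round(drop_in_points("JT", column), 1))
# gauss 16.0 22.8
# defoc 18.7 15.5
# contr 34.0 23.4
```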
We did not experiment further with BN for two reasons: 1) The online version does not work with BN, because the problem with inaccurate batch statistics is exacerbated when training online for many (e.g. 10000) steps. 2) The baseline error for almost every corruption type is significantly higher with BN than with GN. Although unrelated to the main idea of our paper, we note with interest that GN significantly improves model robustness.
As discussed in subsection 3.1, Hendrycks & Dietterich (2019) point to Adversarial Logit Pairing (ALP) (Kannan et al., 2018) as an effective method for improving model robustness to corruptions and perturbations, even though it was designed to defend against adversarial attacks. We take ALP as an additional baseline on all benchmarks based on CIFAR-10 (using GN), following the training procedure in Kannan et al. (2018) and their recommended hyperparameters. The implementation of the adversarial attack comes from the AdverTorch codebase of Ding et al. (2019). We did not run ALP on ImageNet because the two papers we reference for this method, Kannan et al. (2018) and Hendrycks & Dietterich (2019), did not run it on ImageNet or make any claim or recommendation about it.
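For readers unfamiliar with ALP, the sketch below shows the general shape of a logit-pairing objective on precomputed logits. It is not the implementation used for the ALP baseline: the equal weighting of clean and adversarial cross-entropy, the squared-L2 pairing term, and the names `alp_loss` and `lam` are our assumptions, so Kannan et al. (2018) should be consulted for the exact form. Generating the adversarial inputs is out of scope here.

```python
import math

# Sketch of an ALP-style objective on precomputed logits (pure Python).
# Assumptions (ours): cross-entropy averaged over clean and adversarial
# examples, plus lam times the mean squared L2 distance between paired logits.


def cross_entropy(logits, label):
    """Numerically stable softmax cross-entropy for a single example."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]


def alp_loss(clean_logits, adv_logits, labels, lam):
    n = len(labels)
    ce = sum(cross_entropy(c, y) + cross_entropy(a, y)
             for c, a, y in zip(clean_logits, adv_logits, labels)) / (2 * n)
    pairing = sum(sum((ci - ai) ** 2 for ci, ai in zip(c, a))
                  for c, a in zip(clean_logits, adv_logits)) / n
    return ce + lam * pairing


clean = [[2.0, 0.5, -1.0]]
adv = [[1.0, 1.2, -0.5]]
print(alp_loss(clean, adv, [0], lam=0.0), alp_loss(clean, adv, [0], lam=0.5))
```

The pairing term is what distinguishes ALP from plain adversarial training: it penalizes any difference between the logits of a clean input and its adversarial counterpart, independently of the label.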
The following two tables correspond to the bar plots in section 3.
CIFAR-10-C, level 5 (test error, %):
Method   orig  gauss   shot  impul  defoc  glass   motn   zoom   snow  frost    fog   brit  contr  elast  pixel   jpeg
B         8.9   50.5   47.2   56.1   23.7   51.7   24.3   26.3   25.6   34.4   28.1   13.5   25.0   27.4   55.8   29.8
JT        8.1   49.4   45.3   53.4   24.2   48.5   24.8   26.4   25.0   32.5   27.5   12.6   25.3   24.0   51.6   28.7
TTT       7.9   45.6   41.8   50.0   21.8   46.1   23.0   23.9   23.9   30.0   25.1   12.2   23.9   22.6   47.2   27.2
TTTO      8.2   25.8   22.6   30.6   14.6   34.4   18.3   17.1   20.0   18.0   16.9   11.2   15.6   21.6   18.1   21.2
ALP      16.5   22.7   22.9   28.3   25.0   25.6   27.4   23.1   25.2   27.2   64.8   21.7   73.6   23.0   20.2   18.9
ImageNet-C, level 5 (test accuracy, %):
Method   orig  gauss   shot  impul  defoc  glass   motn   zoom   snow  frost    fog   brit  contr  elast  pixel   jpeg
B        68.9    1.3    2.0    1.3    7.5    6.6   11.8   16.2   15.7   14.9   15.3   43.9    9.7   16.5   15.3   23.4
JT       69.1    2.1    3.1    2.1    8.7    6.7   12.3   16.0   15.3   15.8   17.0   45.3   11.0   18.4   19.7   22.9
TTT      69.0    3.1    4.5    3.5   10.1    6.8   13.5   18.5   17.1   17.9   20.0   47.0   14.4   20.9   22.8   25.3
TTTO     68.8   26.3   28.6   26.9   23.7    6.6   28.7   33.4   35.6   18.7   47.6   58.3   35.3   44.3   47.8   44.3
The following bar plots and tables are on levels 1-4 of CIFAR-10-C. The original distribution is the same for all levels, and so are our results on the original distribution.
CIFAR-10-C, level 4 (test error, %):
Method   orig  gauss   shot  impul  defoc  glass   motn   zoom   snow  frost    fog   brit  contr  elast  pixel   jpeg
B         8.9   46.4   39.2   44.8   15.3   52.5   19.1   20.5   21.3   26.9   13.3   10.5   13.7   20.8   35.3   26.9
JT        8.1   45.0   38.3   42.2   16.4   50.2   20.7   20.5   21.1   25.4   14.1   10.0   14.7   19.0   33.2   25.1
TTT       7.9   41.5   35.4   39.8   15.0   47.8   19.1   18.4   20.1   24.0   13.5   10.0   14.1   17.7   29.4   24.5
TTTO      8.2   22.9   20.0   23.9   11.2   35.1   15.6   13.8   18.6   15.9   12.3    9.7   11.9   16.7   13.6   19.8
ALP      16.5   21.3   20.5   24.5   20.7   25.9   23.7   21.4   24.2   23.9   42.2   17.5   53.7   22.1   19.1   18.5
CIFAR-10-C, level 3 (test error, %):
Method   orig  gauss   shot  impul  defoc  glass   motn   zoom   snow  frost    fog   brit  contr  elast  pixel   jpeg
B         8.9   42.2   35.1   30.7   12.2   41.7   18.6   17.5   19.0   25.3   10.8    9.7   11.6   15.3   21.7   24.6
JT        8.1   40.2   34.4   29.9   12.2   37.9   20.8   17.3   18.4   25.0   11.4    9.2   12.0   15.2   20.8   22.8
TTT       7.9   37.2   31.6   28.6   11.5   35.8   19.1   15.8   17.8   23.3   11.0    9.1   11.6   14.3   18.9   22.3
TTTO      8.2   21.3   17.7   17.9    9.0   23.4   15.3   12.5   16.4   15.8   10.9    9.0   10.7   12.8   12.2   18.7
ALP      16.5   20.0   19.3   20.5   19.2   21.2   24.0   20.5   20.9   24.2   30.1   16.6   39.6   20.9   17.8   18.0
CIFAR-10-C, level 2 (test error, %):
Method   orig  gauss   shot  impul  defoc  glass   motn   zoom   snow  frost    fog   brit  contr  elast  pixel   jpeg
B         8.9   31.7   22.6   24.3    9.9   42.6   14.9   14.7   21.7   18.4    9.8    9.1   10.0   13.1   17.1   22.4
JT        8.1   31.0   22.6   23.4    9.1   39.2   16.4   14.2   21.2   17.5    9.4    8.3   10.6   12.8   15.9   20.5
TTT       7.9   28.8   20.7   23.0    9.0   36.6   15.4   13.1   20.2   16.9    9.2    8.3   10.2   12.5   14.8   19.7
TTTO      8.2   16.8   13.8   15.5    8.5   23.4   13.3   11.5   16.8   12.7    9.4    8.4    9.7   12.4   11.5   17.0
ALP      16.5   18.0   17.2   19.0   17.8   20.7   21.2   19.3   19.0   20.1   22.4   16.3   29.2   20.3   17.4   17.8
CIFAR-10-C, level 1 (test error, %):
Method   orig  gauss   shot  impul  defoc  glass   motn   zoom   snow  frost    fog   brit  contr  elast  pixel   jpeg
B         8.9   21.7   17.1   17.0    9.0   44.0   12.1   13.9   14.3   13.4    9.2    8.9    9.0   13.2   12.0   17.3
JT        8.1   20.4   16.6   16.9    8.2   40.5   12.2   13.0   13.1   12.3    8.4    8.1    8.5   12.9   11.3   15.9
TTT       7.9   19.1   15.8   16.5    8.0   37.9   11.7   12.2   12.8   11.9    8.2    8.0    8.3   12.6   11.1   15.5
TTTO      8.2   13.8   11.9   12.2    8.5   24.4   10.5   11.5   12.4   10.7    8.5    8.3    8.6   12.4   10.7   14.4
ALP      16.5   17.0   16.8   17.6   16.8   20.9   18.7   19.0   17.3   17.5   17.4   16.1   18.4   20.4   17.0   17.2