1 Introduction
Despite the huge computational costs associated with training large neural models, the set of optimization algorithms used to train them has largely been restricted to simple update functions mapping from gradients to parameter updates (e.g. stochastic gradient descent (robbins1951stochastic), Adam (kingma2014adam), or RMSProp (tieleman2012lecture)). These algorithms typically depend on a small number of hand-designed features and parameters. However, the last decade in machine learning research has repeatedly seen small, hand-designed models outperformed by parameterized models (such as neural networks) trained to purpose on large amounts of data (lecun2015deep). Thus, a promising direction to improve training performance and reduce costs is to replace hand-designed optimizers with more expressive learned optimizers, trained on problems similar to those encountered in practice.

Learned optimizers specify parameter update rules using a flexible parametric form and learn the parameters of this function from a “dataset” of optimization tasks—a procedure typically referred to as meta-training or meta-learning (andrychowicz2016learning; finn2017model; hochreiter2001learning). Learned optimizers represent a path towards improved optimizer performance, and possess the ability to target different objectives (e.g. test loss (metz2019understanding), or robustness (metz2019using)), as well as the ability to leverage new features useful for optimization. Despite being an active area of research (andrychowicz2016learning; wichrowska2017learned; chen2020training; metz2020using; metz2021training; almeida2021generalizable; zheng2022symbolic), they are not yet commonly used in practice. Several challenges have limited their widespread application: they are typically difficult to meta-train on a task family of interest, they can require significant memory and compute overhead when applied, and they often generalize less well to novel tasks than hand-designed optimizers.
In this work, we aim to precisely study the limitations of learned optimizers, and address these limitations via a novel learned optimizer architecture. In particular, we explore and quantify the tradeoffs in terms of memory, compute, and generalization across a variety of optimizers, including hand-designed optimizers (bottou2010large; tieleman2012lecture; kingma2014adam), learned hyperparameter controllers (daniel2016learning; hansen2016using; xu2017reinforcement; xu2019learning; almeida2021generalizable), and neural network based learned optimizers (andrychowicz2016learning; wichrowska2017learned; metz2020tasks), with the goal of understanding how choices in optimizer design affect performance and usability. Our core contributions are:

We present a thorough empirical characterization of the tradeoffs inherent in different learned optimizer architectures and features, and a comparison of these learned optimizer architectures against their well-tuned hand-designed counterparts.

We develop a new per-parameter learned optimizer architecture that is on the Pareto frontier with respect to performance, computational cost, and memory usage.

We provide an open source implementation written in JAX (jax2018github) to enable future research and reliable benchmarking.^1

^1 http://github.com/google/learned_optimization
2 Optimizers
In this section we review and formalize the class of optimizers that are commonly used in training neural networks. We then define meta-learned optimizers, and highlight differences with standard optimization approaches. We describe several examples of both standard, commonly used neural network optimizers and classes of learned optimizers, all of which are investigated in this paper.
2.1 Gradient-Based Optimizers
Most first-order optimizers (i.e. optimizers using only gradient information and not higher-order derivatives) used to train neural networks can be viewed as functions mapping from a history of gradients to parameter updates. We will assume the optimizer acts on an underlying model with parameters $\phi$, while maintaining an internal optimizer state $s$. The parameters $\phi$ may be, for example, neural network weights, whereas the optimizer state includes quantities such as the accumulated momentum values in momentum-accelerated optimizers (polyak1964some; nesterov1983method). The optimizer acts by ingesting gradients $g$ (which arise from a specified loss function and a dataset) and outputting updated parameters.

More precisely, we define an optimizer as a pair of functions. The first, which we call the Update function, computes new parameter values and state from stochastic gradients, the current parameter value, and the current optimizer state. The second, which we refer to as the Init function, initializes the optimizer state. Both functions have hyperparameters $\theta$, such as the learning rate or the initial value of accumulators. Thus, we write the optimizer as:
(1)  $\mathrm{Opt}(\theta) = (\mathrm{Init}, \mathrm{Update})$
(2)  $s_0 = \mathrm{Init}(\phi_0; \theta)$
(3)  $(\phi_{t+1}, s_{t+1}) = \mathrm{Update}(\phi_t, s_t, g_t; \theta)$
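To make this decomposition concrete, below is a minimal sketch instantiating Init and Update for SGD with momentum. The dictionary-based state and hyperparameter names are illustrative, not the interface of any particular library.

```python
import jax.numpy as jnp

def init(params, hparams):
    # Optimizer state: one momentum accumulator, shaped like the parameters.
    return {"momentum": jnp.zeros_like(params)}

def update(params, state, grads, hparams):
    # Compute the new state and new parameters from the current gradient.
    momentum = hparams["decay"] * state["momentum"] + grads
    new_params = params - hparams["lr"] * momentum
    return new_params, {"momentum": momentum}

# Usage, mirroring Eqs. (2) and (3):
#   state = init(params, hparams)
#   params, state = update(params, state, grads, hparams)
```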
Optimizers can benefit from problem information beyond stochastic gradients, parameter values, and losses. For instance, methods that utilize line searches (le2011optimization), validation loss (xu2019learning; metz2020tasks), or the structure of the underlying computation graph (martens2015optimizing) all rely on additional information. However, the present work is restricted to optimizers which minimize training loss by minibatch stochastic gradient descent.
First-order hand-designed optimizers: Hand-designed optimizers typically have a simple form and a small number of hyperparameters $\theta$, which are tuned by random search (bergstra2012random), Bayesian optimization (snoek2012practical), or other low-dimensional black-box optimization techniques (bergstra2011algorithms; golovin2017google; optuna_2019). They mostly have low overhead in terms of compute and memory usage. For instance, Adam (kingma2014adam) has two accumulators, and SGD has none (some hand-designed methods, such as Shampoo (gupta2018shampoo; anil2020second), involve considerable compute overhead, but can make more progress per update step).
In this work, we experiment with four kinds of hand-designed optimizers: SGD (robbins1951stochastic; bottou2010large), SGDM (SGD with momentum) (polyak1964some), Adam (kingma2014adam), and Nesterov accelerated Adam (dozat2016incorporating) with AdamW (loshchilov2017decoupled) style weight decay (NAdamW). For SGD, SGDM, and Adam, we grid search over learning rates spaced every half order of magnitude. For NAdamW we use random search with many more hyperparameter configurations per task (1000) and a much larger search space over hyperparameters controlling: first and second momentum time scales, weight decays, and learning rate schedules. Past work has shown this to be a powerful search space (metz2020using) and, in our work, it dramatically outperforms learning rate search. See Appendix F for more details. Many other hand-designed optimizer architectures have been proposed (ruder2016overview; zeiler2012adadelta; reddi2018adaptive; you2019large; liu2019variance), but their practical benefits are small in most situations (schmidt2020descending).
Factorized optimizers: In some settings, having even one additional copy of the parameters for use as accumulators is too costly. Recent methods such as AdaFactor (shazeer2018adafactor) and SM3 (anil2019memory) instead factorize the accumulated statistics across tensor dimensions, using an amount of memory sublinear in the number of parameters. This style of accumulator has not previously been explored in the context of learned optimizers, but we will show it provides an effective way to improve performance without meaningfully increasing memory overhead (§4.2).
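As a sketch of why factorization saves memory, the snippet below maintains AdaFactor-style row and column statistics for a 2D weight matrix, storing n + m values instead of n * m. The decay constant and the rank-1 reconstruction used here are illustrative simplifications of AdaFactor, not a faithful reimplementation.

```python
import jax.numpy as jnp

def init_factored(param):
    n, m = param.shape
    return {"row": jnp.zeros(n), "col": jnp.zeros(m)}

def update_factored(state, grad, decay=0.999, eps=1e-30):
    g2 = grad**2 + eps
    row = decay * state["row"] + (1 - decay) * g2.mean(axis=1)  # shape [n]
    col = decay * state["col"] + (1 - decay) * g2.mean(axis=0)  # shape [m]
    # Rank-1 reconstruction of the full second-moment matrix.
    v_hat = jnp.outer(row, col) / row.mean()
    precond_grad = grad / jnp.sqrt(v_hat)
    return {"row": row, "col": col}, precond_grad
```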
2.2 Meta-Learned Optimizers
The meta-learning problem for optimizers consists of tuning the hyperparameters of a class of parameterized optimizers with respect to some loss function (common choices of loss function for the meta-optimization problem include the average training loss across inner optimizer iterations, the average validation loss, and the terminal train/validation loss). How is this different from the hyperparameter tuning discussed in the last subsection? While there is no formal difference between the hyperparameter selection problem and training learned optimizers, the learned optimizers we consider in this subsection universally include a black-box component with a (comparatively) large number of parameters (in our case, always parameterized by a neural network). This large number of parameters limits the effectiveness of traditional hyperparameter tuning methods such as random search, and so we focus on local optimization methods (including first-order gradient-based methods as well as zeroth-order methods) which are able to perform better in high dimensional optimization. Below, we outline several types of learned optimizers.
Hyperparameter controllers: Optimizing the hyperparameters of a hand-designed optimizer over a broad set of tasks may limit performance on each specific task. These hand-designed optimizers can be augmented with a meta-learned controller, often parameterized as a neural network, that modulates the hyperparameters of the optimizer over the course of training to yield better performance on each particular problem (daniel2016learning; hansen2016using; xu2017reinforcement; xu2019learning; almeida2021generalizable). This controller takes in summary statistics (e.g. gradient norms, loss values), and can either globally assign identical hyperparameters to all layers, or operate per-layer. One benefit of hyperparameter controllers is that their per-parameter compute overhead is small, as the majority of the computation only needs to be performed once per tensor or per network, rather than scaling with the number of parameters.
We introduce a novel hyperparameter controller architecture which we refer to as nn_adam. This architecture consists of an LSTM-based (hochreiter1997long) hyperparameter controller, operating on features derived from each tensor independently, and outputting Adam hyperparameters consisting of a per-tensor learning rate, $\beta_1$, $\beta_2$, and $\epsilon$. For features, this model uses normalized values derived from the first moment of gradients, the second moment, and the tensor shape. See Appendix C for details.

Per-parameter learned optimizers: Per-parameter learned optimizers (andrychowicz2016learning) learn a function, often parameterized by a neural network, which is applied to each parameter independently, though sometimes with normalization performed across parameters in a tensor (metz2019understanding).
Multi-level approaches: In an effort to add capacity to a learned optimizer while retaining good computational complexity with respect to the number of parameters, hierarchical models have been proposed (wichrowska2017learned; metz2020tasks). These models leverage up to three levels of hierarchy: a global controller, which sends and receives activations from a per-layer (or per-tensor) controller, which in turn sends and receives activations from a per-parameter optimizer.
New per-parameter learned optimizer: Finally, we introduce a new learned optimizer architecture (which we call small_fc_lopt) that combines architectural features of per-parameter and factorized optimizers, and outperforms both. This architecture is directly motivated by the tradeoffs among compute, memory, performance, and generalization shown in §4. Our learned optimizer incorporates an extremely tiny, per-parameter, MLP-based learned optimizer similar to that used in metz2019understanding. This 197 parameter MLP takes as input 39 features computed from 4 per-parameter accumulators (3 momenta at different meta-learned timescales, and 1 gradient second moment accumulator) and 3 AdaFactor accumulators, also at 3 different meta-learned timescales. These features are passed into a 1 hidden layer, 4 hidden unit MLP. See Appendix A for additional details.
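A simplified sketch of the resulting update computation is below: each parameter's features form one row of a [num_params, 39] matrix, and the tiny MLP weights theta are shared across all parameters. The 0.001 output multipliers follow Appendix A; feature construction and normalization are omitted.

```python
import jax
import jax.numpy as jnp

def per_parameter_update(theta, features):
    """theta: tiny MLP weights; features: [num_params, 39] array."""
    h = jax.nn.relu(features @ theta["w0"] + theta["b0"])  # [num_params, 4]
    out = h @ theta["w1"] + theta["b1"]                    # [num_params, 2]
    magnitude, direction = out[:, 0], out[:, 1]
    # Small constants keep the initial steps tiny (see Appendix A).
    return direction * jnp.exp(magnitude * 0.001) * 0.001
```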
3 Training and Meta-Training
In the previous section we specified possible architectures for standard optimizers (with a small number of hyperparameters) as well as learned optimizers. Both learned and hand-designed optimizers are iteratively applied to some parameterized model, paired with a loss function and (possibly) a dataset. We refer to this collection as a task. We use the loss obtained by an optimizer on these tasks to select hyperparameters (in the case of hand-designed optimizers), and to optimize the learned optimizer weights. In this section, we discuss the tasks used, the measurement of performance by which we can compare optimizers (meta-loss), and how the weights of the learned optimizers are computed (which we refer to as meta-optimization).
3.1 Tasks
Throughout this paper, the tasks of interest are neural network training problems. Each task is specified via three quantities. The first is the underlying model architecture and the initial parameter values (or a procedure for initializing the model parameters). The second is a function to generate batches of data, and the third is a loss function. While a more abstract definition of a task could cover more general optimization problems, we aim to address neural network training as a setting and believe generalizations are (in most cases) straightforward. In this work we consider solely supervised learning. We also consider only a single function to generate a batch of data, though this could easily be extended to multiple functions corresponding to, for example, train and validation loss. As discussed in the next subsection, we focus solely on training loss for simplicity.
We primarily consider two tasks in this paper: a 2 hidden layer MLP with 128 hidden units and ReLU activations on Fashion MNIST (xiao2017/online), and a 3 layer convolutional network on CIFAR10 (krizhevsky2009cifar). See Appendix B for more details and implementations. In Section 4.5, to assess generalization, we additionally evaluate optimizers meta-trained on these two tasks on three additional problems.

The tradeoffs inherent in optimizer design are task dependent (see §4.3), and the per-parameter compute and memory requirements of the optimizer must be balanced against the per-parameter compute and memory requirements of the task. These latter requirements are a function of parameter sharing, sparsity in parameter use, model architecture, and minibatch size (compute overhead per parameter can be made arbitrarily small by increasing the minibatch size). The two problems we consider have different compute and memory requirements and were therefore chosen as reasonable baseline tasks, providing insight into optimizer performance at different points in the space of possible tasks. Moreover, these (relatively small) tasks were chosen to enable the large-scale evaluation and comparisons done in this paper.
3.2 Meta-Loss and Meta-Optimization
To evaluate an optimizer, we apply it for 2,000 iterations and measure the average loss obtained over the course of training. In this work, we exclusively focus on training loss performance as opposed to validation loss. This is to decouple optimization performance tradeoffs from the implicit regularization effects of learned optimizers shown in metz2019understanding; metz2020tasks.
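A sketch of this meta-loss is below: unroll inner training for 2,000 steps and average the losses. Here task, lopt_init, and lopt_update are assumed stand-ins for a task's loss/data functions and the learned optimizer's Init/Update functions from §2.1.

```python
import jax
import jax.numpy as jnp

def meta_loss(theta, task, lopt_init, lopt_update, key, num_steps=2000):
    """Mean training loss over one unrolled inner-training run."""
    params = task.init_params(key)
    opt_state = lopt_init(theta, params)

    def inner_step(carry, step_key):
        params, opt_state = carry
        batch = task.sample_batch(step_key)
        loss, grads = jax.value_and_grad(task.loss)(params, batch)
        params, opt_state = lopt_update(theta, params, opt_state, grads)
        return (params, opt_state), loss

    keys = jax.random.split(key, num_steps)
    _, losses = jax.lax.scan(inner_step, (params, opt_state), keys)
    return jnp.mean(losses)
```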
We train optimizers targeting the two tasks described above, using random sampling from a fixed search space for hand-designed optimizers, and Persistent Evolution Strategies (PES) (pmlrv139vicol21a) to train the learned optimizers. See Appendix D.2 for details. To minimize confounds, we focus on the scenario where meta-train matches meta-test (i.e. the tasks presented during training and testing are the same), and examine the overhead and performance tradeoffs inherent in a learned optimizer meta-trained to optimize a single task.
4 Exploring tradeoffs across optimizer families
In this section we experimentally explore tradeoffs when designing learned optimizers. In §4.1, we show memory and time tradeoffs for various hand-designed and learned optimizers. In §4.2, we focus on per-parameter learned optimizers and explore the impact of both feature choice and size of the learned optimizer. In §4.3, we discuss computational costs of running learned optimizers as a function of task features (such as the number of network weights). In §4.4, we tie all these evaluations together and show wall-clock time performance for the different tasks. In §4.5, we explore meta-generalization—applying optimizers to tasks different from those on which they were meta-trained.
4.1 Compute, memory, performance tradeoffs for learned and hand-designed optimizers
We characterize the tradeoffs between performance, memory overhead, and compute overhead for both hand-designed and learned optimizers. The optimizers examined here consist of the hand-designed optimizers (SGD, Adam, and NAdamW), the MLP optimizer from metz2019understanding (fc_lopt), the hierarchical optimizer from metz2020tasks (rnn_fc_lopt), the hyperparameter controller described in §2.2 (nn_adam), and the per-parameter optimizer proposed in §2.2 (small_fc_lopt).
Meta-training curves are shown in Figure 1a,c. We additionally show final performance of the fully trained learned optimizer as a function of compute time per step (Figure 1b,d) and with respect to memory usage (Figure 1e). We find learned optimizers can achieve lower meta-loss than baselines, but at the cost of more compute time and memory usage. For the MLP task, the cost of the learned optimizer far outstrips the cost of a hand-designed optimizer (taking > 5x more time in the case of rnn_fc_lopt). For the CIFAR10 ConvNet, however, the compute overhead is small relative to overall compute, due to the much larger per-parameter cost of computing gradients for a ConvNet.
4.2 Design choices for the MLP learned optimizer
To guide learned optimizer design, we explore the memory, time, and performance tradeoffs associated with different choices of input features for an MLP learned optimizer. The dominant source of memory overhead is the inclusion of additional per-parameter accumulators. We explore the two kinds of per-parameter accumulators optimizers commonly use—exponential moving averages of the gradient's first moment and second moment, as used in momentum and RMSProp (tieleman2012lecture) respectively. Unlike existing optimizers, the learned optimizers we explore accumulate these statistics over multiple timescales (similar to what is done in the AggMo (lucas2018aggregated) hand-designed optimizer). In addition, we explore preconditioning features based on AdaFactor (shazeer2018adafactor), which use memory sublinear in parameter count.
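A sketch of such multi-timescale accumulators is below; the three decay values are illustrative stand-ins for the meta-learned ones.

```python
import jax.numpy as jnp

decays = jnp.array([0.5, 0.9, 0.99])  # illustrative; meta-learned in practice

def init_acc(param):
    # One first and second moment accumulator per timescale, stacked on a
    # trailing axis.
    shape = param.shape + (decays.size,)
    return {"m": jnp.zeros(shape), "v": jnp.zeros(shape)}

def update_acc(state, grad):
    g = grad[..., None]  # broadcast against the trailing timescale axis
    m = decays * state["m"] + (1 - decays) * g
    v = decays * state["v"] + (1 - decays) * g**2
    return {"m": m, "v": v}
```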
We plot performance vs. compute cost, and performance vs. memory, in Figure 2. We plot baselines with hyperparameters found via grid or random search (in black) (adam, sgd, sgdm, nadamw), and baseline learned optimizers (in gray) (fc_lopt (metz2019understanding), rnn_fc_lopt (metz2020tasks), nn_adam (§2.2)). All other conditions consist of differently parameterized learned optimizers with different input features. Each configuration is trained with PES (pmlrv139vicol21a) for 100k meta-training iterations. We test optimizers using only a single momentum accumulator with different decays (in green), multiple momentum accumulators (in yellow), a single second moment accumulator (in brown), multiple second moment accumulators (in pink), two accumulators with the same decay for first and second moments (in blue), multiple decay first and second moments (in purple), AdaFactor features with and without additional momentum accumulators (in red), only gradient features (in orange), and finally the union of all features (in gray). See Appendix D.3 for more experimental details.
We find a general trend: providing more features to a learned optimizer leads to better performance. However, including more accumulators increases the computational and memory overhead of using these optimizers. AdaFactor features by themselves (adafact) use very little memory, but do not perform well. Combining a small number of momentum features with AdaFactor features (adafact_m_mult) recovers the performance of using second moment accumulators, without their memory cost.
Finally, we explore varying the hidden size of the MLP (Figure 3). Using the same features as in small_fc_lopt (§2.2), we sweep the hidden size of the MLP from 2 to 256 units. For each width we perform a small hyperparameter search over the meta-learning rate and take the best performing learning rate for that width. Surprisingly, an extremely narrow MLP is sufficient to outperform the best hand-designed baseline (NAdamW). Increasing width boosts performance further, but with diminishing returns.
The relationship between learned optimizer width and compute overhead depends heavily on implementation details. TPUs and other accelerator hardware often have specialized matrix multiplication units that operate on fixed dimensional matrices (e.g. TPUv2 has 128x128 systolic arrays (norrie2020google)). For a naive implementation of the learned optimizer using matrix multiplication kernels on TPUv2, there are no significant speedups from shrinking the optimizer width below approximately 64 units. However, if the matrix multiplication is expanded explicitly in terms of element-wise operations, then continued speedups can be achieved even for optimizer hidden sizes of two units, achieving a nearly 2x speedup over the use of matrix multiplication primitives, as different primitives and thus different pieces of hardware are used. Profiling suggests that even greater speedups should be possible using custom kernels (as are frequently written for hand-designed optimizers). In a sense, these tiny learned optimizers with matrix multiplications expanded blur the line between hand-designed and learned optimizers, as both implementations reduce to a handful of element-wise floating point operations.
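The sketch below illustrates the two strategies for small_fc_lopt's first layer (a [39, 4] weight matrix, as assumed here): both functions compute the same result, but the expanded form lowers to element-wise multiply-adds rather than a matrix multiplication primitive.

```python
import jax.numpy as jnp

def first_layer_matmul(features, w):
    # One [N, 39] x [39, 4] matrix multiplication primitive.
    return features @ w

def first_layer_expanded(features, w):
    # The same contraction written as 39 element-wise multiply-adds.
    out = jnp.zeros((features.shape[0], w.shape[1]))
    for i in range(features.shape[1]):
        out = out + features[:, i:i + 1] * w[i]  # [N, 1] * [4] -> [N, 4]
    return out
```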
4.3 Overhead of learned optimizers on different tasks
To explore the dependence on task identity, we measure the relative overhead of training with a learned optimizer compared to SGD for different widths and batch sizes of the Fashion MNIST MLP and CIFAR10 ConvNet. Results are shown in Figure 4. We find that, in all cases, increasing batch size lowers the overhead, as the cost to compute gradients increases while the cost of applying the optimizer remains constant. For the ConvNets, increasing the channel count of the target problem lowers the overhead for small_fc_lopt and nn_adam, but increases it for fc_lopt and rnn_fc_lopt. For MLP target problems, increasing the target MLP's hidden size increases overhead for all optimizers, including Adam. The asymptotic scaling of this behavior is due to both the computational complexity and the memory bandwidth of the underlying hardware.
Next, we explore overheads for some common, large scale models. Table 1 shows results for ResNets and Transformers, trained on a single TPUv2 chip. Distributed training would allow us to split the optimizer computation across devices and thus achieve even lower optimizer overhead. See Appendix D.4 for further details.
Task | Params | SGD step time (ms) | LOpt step time (ms) | LOpt multiplier
ResNet18 (BS=128) | 11.7M | 159 | 180 | 1.13
ResNet50 (BS=32) | 25.6M | 99.5 | 137.9 | 1.39
Transformer (L=256, BS=16) | 43.1M | 91.2 | 132.4 | 1.45
Transformer (L=512, BS=2) | 43.1M | 29.7 | 70.9 | 2.39

Table 1: Step times for SGD and our learned optimizer (LOpt) on large scale models, with different word sequence lengths (L) and batch sizes (BS). All numbers are medians over 10 timings. Standard error is under the reported number of digits.
4.4 Performance with respect to wall clock time
In the previous sections, we measured the performance achieved by optimizers, and the computational overhead required to achieve that performance. In practice, one often cares most about the total wall clock time required to reach a given performance. Further, the optimal meta-parameters change depending on the length of inner-training. To quantify the achievable performance as a function of wall clock time, we compare training trajectories for both our learned and hand-designed optimizers. We apply a learned optimizer meta-trained on length 2k unrolls, and optimize hyperparameters of hand-designed optimizers to perform well on 1k, 2k, 3k, 4k, and 5k length unrolls. In Figure 5 we show the resulting performance for both the Fashion MNIST MLP and CIFAR10 ConvNet tasks. Our learned optimizer is always faster with respect to step count. With respect to wall-clock time, on the CIFAR10 ConvNet task we also see faster training, while on the Fashion MNIST MLP task, where the relative overhead of the learned optimizer is large, NAdamW performs best.
4.5 Meta-generalization: optimizer performance on held-out tasks
One final tradeoff is the interaction between optimizer design choices and generalization performance. Generalization performance, in this context, refers to the ability of the optimizer to perform well at training a novel task, different from the task distribution used to meta-train the optimizer. To quantify this, we measure the performance of diverse optimizers meta-trained on one task, but then used as the optimizer for a novel task.

We meta-train each of the different learned optimizers on either the Fashion MNIST MLP or the CIFAR10 ConvNet task. Over the course of meta-training we evaluate performance on 5 tasks:

The Fashion MNIST MLP described in §3.1.

The CIFAR10 ConvNet described in §3.1, using 16x16 images for computational reasons.

A three hidden layer MLP trained on 16x16 ImageNet.

A CIFAR10 autoencoder trained with mean squared error loss.

An LSTM (hochreiter1997long) language model trained on byte-level LM1B (DBLP:journals/corr/ChelbaMSGBK13).
See Appendix E for more details and implementations.
First, we train a number of learned optimizers of different types, select the checkpoint which performs best on the meta-training task, and evaluate performance on the different held out tasks. We show transfer performance when meta-training and meta-testing on the Fashion MNIST MLP and CIFAR10 ConvNet in Figure 6, and the remainder of the comparisons in Appendix E. While there is some correlation across learned optimizer architectures, in general we find poor meta-generalization. Additionally, meta-generalization seems to depend more strongly on details of the meta-training process than on learned optimizer architectural choices. This variation poses challenges when reporting results, due to the cost of meta-training.
To explore variance in the meta-training process, we plot performance over the course of meta-training on both the target task and the held out tasks. We show these dynamics for two different learned optimizer seeds in Figure 6c,d. In both cases, meta-training trajectories exhibit high variability. However, both also show an initial phase of correlated performance improvement, culminating in better performance than the baselines on both the target and held out task, before the optimizer finally overfits to the target task.
This type of meta-overfitting is not unique to learned optimizers, and happens even when trying to transfer hand-designed optimizers from task to task. To show this, we simulate meta-training on the Fashion MNIST MLP by randomly sampling subsets of parameters from the original NAdamW search space for different budgets. We then find the best performance on the meta-training task and apply the best hyperparameters to the meta-test task. We show meta-train vs meta-test performance for different budgets in Figure 7. We see signs of meta-overfitting for some tasks, such as the ImageNet MLP and the MLP autoencoder. For the others, the CIFAR10 ConvNet and the LSTM, we continue to see correlation between meta-train and meta-test performance.
5 Related Work
Originally proposed in bengio1992optimization; runarsson2000evolution, interest in learned optimizers has undergone a recent revival. Proposed learned optimizer architectures have included per-parameter RNNs (andrychowicz2016learning), hierarchical models enabling sharing of information across parameters (wichrowska2017learned; metz2020tasks), and a simplified architecture consisting of just an MLP (metz2019understanding). Optimizer meta-training techniques have included gradient descent (maclaurin2015gradient; andrychowicz2016learning; wichrowska2017learned), reinforcement learning (li2016learning; li2017learning), and more advanced training procedures (lv2017learning; maheswaranathan2019guided; metz2019understanding; chen2020training; pmlrv139vicol21a; metz2021training) leveraging both Evolution Strategies (ES) and gradients. Learned optimizers have been targeted at applications including model robustness (metz2019using), chemistry (learn2hop), min-max optimization (shen2021learning), adversarial training (xiong2020improved), few-shot learning (ravi2016optimization), swarm optimization (cao2019learning), learned update rules (metz2018learning), black box optimization (chen2016learning), and MCMC sampling (levy2017generalizing; wang2017meta; gong2018meta). Other work has analyzed learned optimizer behavior (maheswaranathan2020reverse). Bello17 takes a different approach, and meta-learns symbolic rather than neural-network driven parameter update rules.

In an effort to understand computational costs, we look to Pareto frontiers of computation and memory vs performance. The concept of Pareto optimality was originally proposed in economics to understand how individuals can prosper with finite resources (newman1998new), and has since become a useful tool in computer science. Studying tradeoffs in this way is common in computer vision and natural language processing, where performance as a function of model size is often explored (simonyan2014very; he2016deep; vaswani2017attention). Building efficient frontiers of models has been a target for meta-learning as well (tan2019efficientnet). In the scope of learned optimizers, metz2019understanding explored training wall-clock efficiency, but on limited hardware (CPU) and with respect to a single target problem instance. wichrowska2017learned showed that the relative overhead of computing updates with a learned optimizer shrinks as batch size increases. zheng2022symbolic propose a symbolic distillation meta-training step which converts neural network parameterized optimizers to a symbolic form, resulting in both lower memory and compute costs. There has also been significant research exploring the tradeoffs between different optimization techniques outside of deep learning, especially between stochastic and full batch methods, and between different second order methods. For example, Newton's method has led to a number of approximations – e.g. diagonal approximations (duchi2011adaptive), block diagonal approximations (martens2015optimizing; gupta2018shampoo), and a large family of quasi-Newton methods (dennis1977quasi) (e.g. BFGS (broyden1970convergence) and its low memory counterpart, L-BFGS (liu1989limited)).

6 Conclusion
In this work, we characterized practical tradeoffs involved in designing learned optimizers, including those between performance on a target task, the compute and memory overhead associated with the learned optimizer, training time, choice of target task, and generalization to new tasks. Using the lessons learned from this careful exploration, we introduced an architecture that strikes a better balance between memory usage, compute, and performance. We then showed that this learned optimizer architecture can be used to accelerate training on accelerator hardware.
The goal of this work was to provide a thorough investigation of the fundamental tradeoffs associated with learned optimizers. We view this paper as a first step toward the empirical characterization necessary for principled comparisons, but our experiments were limited in several ways. First, to control the number of covariates and perform the required experiments within a reasonable computation budget, we limited ourselves to (primarily) two tasks, which are themselves simple compared with state-of-the-art neural network models. Moreover, we have provided only a limited investigation of meta-generalization and validation performance. In general, while this paper serves as a first step toward rigorous empirical comparison within this novel class of learned optimizers, further work is required to extend our results.
In order to make it easier for future research to build on our work, and to include better grounded empirical comparisons, all optimizers, tasks, and training code are open sourced in learned_optimization,^6 an open source library written in JAX for designing, training, and testing learned optimizers.

^6 https://github.com/google/learned_optimization
Acknowledgements
We would like to thank Chip Huyen, Ben Poole, Amil Merchant, and Wenqing Zheng for their support and comments on this work, as well as the entire Brain team for providing a wonderful research environment. We would also like to thank the authors of the Python scientific computing stack, including NumPy (van2011numpy) and Matplotlib (matplotlib).
Appendix A small_fc_lopt architectural details
We describe the full details of our proposed learned optimizer. The source code for this optimizer can be found in https://github.com/google/learned_optimization/blob/aa15091066aa5b3f45e6b7f4bee1c41fb7d467a0/learned_optimization/learned_optimizers/adafac_mlp_lopt.py.
Our learned optimizer computes a set of per-parameter features which are concatenated and fed into an MLP. These features are:

the parameter values

the 3 momentum values (each at a different meta-learned decay)

the second moment value

3 values consisting of the momenta normalized by the rms gradient norm ($m / \sqrt{v}$)

the $1 / \sqrt{v}$ value

3 AdaFactor normalized gradient values

the tiled AdaFactor row features (3 features)

the tiled AdaFactor column features (3 features)

the reciprocal square roots of these previous 6 features

3 features consisting of AdaFactor normalized momentum values

11 features formed by taking the current timestep, $t$, and computing $\tanh(t / \lambda_i)$, where the $\lambda_i$ are log-spaced timescales
All but the time features are normalized to have a second moment of 1 across the tensor, and are fed into a 4 hidden unit, 2 hidden layer MLP with a ReLU activation function, then projected to 2 outputs representing a magnitude, $m$, and a scalar direction, $d$, which are combined to form the predicted step $\Delta\phi = \lambda_1\, d\, \exp(\lambda_2 m)$, where $\lambda_1$ and $\lambda_2$ are constants set to 0.001 to keep initial step sizes small.

In order to reduce computational overhead, this optimizer is dramatically smaller than past learned optimizers, containing only 197 meta-parameters.
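As a sketch of the 11 time features listed above, assuming tanh embeddings at log-spaced timescales (the exact values live in the linked implementation):

```python
import jax.numpy as jnp

# Illustrative log-spaced timescales for the 11 time features.
timescales = jnp.array([1, 3, 10, 30, 100, 300, 1e3, 3e3, 1e4, 3e4, 1e5])

def time_features(t):
    # Each feature saturates smoothly once t exceeds its timescale.
    return jnp.tanh(t / timescales)
```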
Appendix B Tasks we meta-train on
b.1 Fashion MNIST MLP
The MLP we use consists of 2 hidden layers with 128 hidden units and ReLU activations. It is trained on batches of 128 Fashion MNIST images, rescaled to lie in [0, 1], with cross entropy loss. Our network was built in Haiku (haiku2020github) and the implementation can be found at https://github.com/google/learned_optimization/blob/32c4f21ec238a12756afe70e3d699017ea938f5d/learned_optimization/tasks/fixed/image_mlp.py#L35.
b.2 CIFAR10 ConvNet
The ConvNet consists of 3 hidden layers with ReLU activations. All layers have kernel sizes of 3x3. The first layer has 32 units and stride 2. The following 2 layers have 64 hidden units and stride 1. All convolutions use same padding. We average over the spatial dimensions, then linearly project to 10 classes. We train on batches of 128 CIFAR10 images rescaled to lie in [0, 1] and use cross entropy loss. The implementation can be found at https://github.com/google/learned_optimization/blob/ba2b56565fb507368652d2e4a12ab305a6d99ded/learned_optimization/tasks/fixed/conv.py#L96.

Appendix C nn_adam architecture
The nn_adam learned optimizer is a hyperparameter-controller based learned optimizer. In addition to the description that follows, we provide an implementation at https://github.com/google/learned_optimization/blob/ba2b56565fb507368652d2e4a12ab305a6d99ded/learned_optimization/learned_optimizers/nn_adam.py#L161. For each tensor of the target problem, we compute a set of features (see §C.1), feed them into a 32 unit LSTM, and output 4 values – log learning rate, beta 1, beta 2 (both parameterized as log(1 - beta)), and log epsilon. These hyperparameters are then fed into Adam, where the update to each weight and accumulator follows the Adam update equations.
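A sketch of this output parameterization is below; any offsets or scalings applied to the raw LSTM outputs are omitted, so treat the mapping as illustrative.

```python
import jax.numpy as jnp

def controller_outputs_to_adam_hparams(out):
    """Map the 4 LSTM outputs to Adam hyperparameters."""
    return {
        "learning_rate": jnp.exp(out[0]),   # log learning rate
        "beta1": 1.0 - jnp.exp(out[1]),     # parameterized as log(1 - beta)
        "beta2": 1.0 - jnp.exp(out[2]),
        "epsilon": jnp.exp(out[3]),         # log epsilon
    }
```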
c.1 per-tensor features
We use the same set of per-tensor features as used by the hierarchical learned optimizer in metz2020tasks. For many features, we employ a simple transformation to obtain magnitudes: computing the log of the absolute value, clipping between -5 and 5, and rescaling by 0.5.
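A sketch of this transformation; the epsilon guarding log(0) is our assumption.

```python
import jax.numpy as jnp

def transform(x, eps=1e-8):
    return 0.5 * jnp.clip(jnp.log(jnp.abs(x) + eps), -5.0, 5.0)
```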
For each tensor, we use the following as a feature set:

Transformed mean momentum value

Sign of mean momentum

Transformed variance squared of momentum

Transformed mean of the accumulator of second moment

Sign of the mean of the accumulator of second moment

Transformed mean parameter value

Sign of mean parameter value

Transformed variance squared of parameter value.

Transformed mean gradient value

Sign of mean gradient value

Transformed variance squared gradient

Transformed mean absolute value of gradient.
Appendix D Experiment details
d.1 Common
For all experiments in this paper we meta-train with Persistent Evolution Strategies (pmlrv139vicol21a) with a standard deviation of 0.01 and length 20 truncations, with length 2k inner training steps. The meta-objective we target is mean training loss, clipped at the loss value at initialization. We use 4 distributed workers (each using a single TPU accelerator chip) in an async batched fashion. We use a meta-batch size of 4, with each meta-gradient being an average over workers, each of which is itself an average over 8 tasks. For all models, we have an additional learner job (also a TPU chip) which averages meta-gradients and performs PES updates. We use Adam as the meta-optimizer in all experiments, with gradient clipping of 3.0 applied to each value of the gradient independently. Each training job has an additional 3 machines (15 for meta-generalization experiments), each with a single TPU chip, to perform evaluations by training a task with the current meta-parameters. The meta-training curves we report are from these machines. All meta-training in this work took 2-4 days per experiment.
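For reference, a heavily simplified, single-machine sketch of the PES estimator is below. The assumed unroll(theta, state) function runs one length-20 truncation and returns (mean loss, new inner state); real PES additionally resets the accumulators and inner states at episode boundaries and distributes particles across workers.

```python
import jax
import jax.numpy as jnp

def pes_grad(theta, unroll, pos_states, neg_states, accums, key, sigma=0.01):
    """Antithetic PES gradient estimate over one truncated unroll."""
    n = len(pos_states)
    grad = jnp.zeros_like(theta)
    for i in range(n):
        key, sub = jax.random.split(key)
        eps = sigma * jax.random.normal(sub, theta.shape)
        accums[i] = accums[i] + eps  # perturbations accumulated across truncations
        loss_pos, pos_states[i] = unroll(theta + eps, pos_states[i])
        loss_neg, neg_states[i] = unroll(theta - eps, neg_states[i])
        grad = grad + accums[i] * (loss_pos - loss_neg)
    return grad / (2.0 * n * sigma**2), pos_states, neg_states, accums
```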
d.2 §4.1 details: Optimizer overhead vs optimizer type
For each task and learned optimizer pair, we train 3 random seeds for 5 meta-learning rates: [1e-5, 3e-5, 1e-4, 3e-4, 1e-3]. We show the best performing optimizer rather than the mean, as the low dimensional hidden size of small_fc_lopt (4 hidden units) can result in unstable training. We found using a simple learning rate schedule improves meta-training stability, but for a fair comparison to other optimizers we do not use any schedule here.
d.3 §4.2 details: Input features experiment details
For these experiments we use a fixed meta-learning rate, set to the value found to perform best for fc_lopt in §D.2. Given the number of variations tried, we could not afford to search over the learning rate for each configuration. For each configuration we compute 3 random seeds.
grads_time_p: Just using parameter, gradient, and time step features.
m_0.1, m_0.5, m_0.9, m_0.99, m_0.999: Using parameter value, gradient, momentum with the listed decay value, and time step features.
m_all: Same as before but using multiple momentum values. In this case five values: 0.1, 0.5, 0.9, 0.99, 0.999.
m_mid2: Same as before but with 2 momentum values 0.5 and 0.9.
m_mid3: Same as before but with 3 momentum values 0.5, 0.9 and 0.99.
rms_0.1, rms_0.5, rms_0.9, rms_0.99, rms_0.999: Using parameter value, gradient, the second moment accumulator with the listed decay value, 1 over the sqrt of this feature, and time step features.
rms_all: Same as before but using multiple second moment accumulator values. In this case six values: 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999.
rms_mid2: Same as before but with 2 second moment values 0.9 and 0.99.
rms_mid4: Same as before but with 4 second moment values: 0.5, 0.9, 0.99, and 0.999.
m_rms_0.1, m_rms_0.5, m_rms_0.9, m_rms_0.99, m_rms_0.999: Using parameter value, gradient, the second moment accumulator with the listed decay value, 1 over the sqrt of this feature, momentum with the listed decay, the product of momentum and 1 over the square root of the second moment (similar to the Adam update), and time step features.
m_rms_all: Same as before but using multiple second moment and momentum accumulator values. In this case six values: 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999.
m_rms_mid2: Same as before but with 2 accumulator timescales: 0.9 and 0.99.
m_rms_mid4: Same as before but with 4 accumulator timescales: 0.5, 0.9, 0.99, and 0.999.
adafact: Using parameter values and 6 AdaFactor accumulator decays (0.1, 0.5, 0.9, 0.99, 0.999, 0.9999), which are fed to the learned optimizer in the form of 3 kinds of features: 1 over the sqrt of the accumulator, this value multiplied by the gradient, and tilings of both of the low rank accumulators.

adafact_m_mult: Same as before, but with 3 AdaFactor accumulators and 3 momentum accumulators of decays (0.5, 0.9, 0.99). In addition to the previous features, we also include the multiplication of the momentum value by the preconditioner from AdaFactor.

union: The union of all features. This includes parameter value, gradient value, time features, all momentum and second moment accumulators (0.1, 0.5, 0.9, 0.99, 0.999, 0.9999), all features from AdaFactor computed with these same timescales, as well as multiplications of AdaFactor and momentum features.
d.4 §4.3 details: Large scale overhead timings
We use the ResNet18 and ResNet50 implementations from Haiku (haiku2020github) with the V2 flag set to true.
For Transformers, we use a vocab size of 256 to emulate byte level training, a hidden size of 768, 6 layers, and 12 self attention heads per layer. When applying dense layers we use a 4x widening factor.
Appendix E Extended meta-generalization experiments
e.1 Experimental details
Over the course of training a learned optimizer on a particular task, we monitor performance on a variety of held out tasks described here.
Fashion MNIST MLP: This is a 2 hidden layer, 128 unit MLP trained on Fashion MNIST. Source code can be found at https://github.com/google/learned_optimization/blob/32c4f21ec238a12756afe70e3d699017ea938f5d/learned_optimization/tasks/fixed/image_mlp.py#L35.
CIFAR10 ConvNet: This is a ConvNet with 3 hidden layers trained on 16x16 CIFAR10, starting with a 32 channel, stride 2 convolution, followed by two 64 channel, stride 1 convolutions. Average pooling is then performed before linearly mapping to the number of output classes. An implementation can be found at https://github.com/google/learned_optimization/blob/78f25e8f1e9c6236a1f559b7b0b36859c59542d2/learned_optimization/tasks/fixed/conv.py#L86
ImageNet MLP: This is an MLP operating on 16x16 resized ImageNet images. The network has 3 hidden layers of size 256. An implementation can be found at https://github.com/google/learned_optimization/blob/aa15091066aa5b3f45e6b7f4bee1c41fb7d467a0/learned_optimization/tasks/fixed/image_mlp.py#L94.
Autoencoder: This is an autoencoder trained on CIFAR10 with mean squared error. The network consists of 3 hidden layers with sizes 128, 32, and 128. A full implementation can be found at https://github.com/google/learned_optimization/blob/aa15091066aa5b3f45e6b7f4bee1c41fb7d467a0/learned_optimization/tasks/fixed/image_mlp_ae.py#L101.
LSTM language modeling: This is a language model trained on LM1B (DBLP:journals/corr/ChelbaMSGBK13). The text is tokenized as bytes and sliced into length 32 sequences. The model embeds the tokens with a 64 dimensional lookup table, followed by a size 128 LSTM tasked with predicting the next token. See https://github.com/google/learned_optimization/blob/aa15091066aa5b3f45e6b7f4bee1c41fb7d467a0/learned_optimization/tasks/fixed/rnn_lm.py#L142 for the implementation.
e.2 Additional figures
In this section we provide additional experimental results for meta-generalization, similar to §4.5. First, in Figure 8 we show additional performance measurements on more held out tasks. As in §4.5, we see poor meta-generalization and high variability.
We plot evaluations over the course of meta-training for each learned optimizer type and multiple random seeds in Figure 9 when meta-training on the Fashion MNIST MLP, and in Figure 10 for the CIFAR10 ConvNet. When meta-training and evaluating on the same distribution, we find stable evaluation loss. When evaluating on other kinds of tasks, we see wide variability in performance across architectures, and even among different initializations of the learned optimizer weights holding all else fixed. In some cases, the learned optimizer switches between making optimization progress on the target task and diverging, as shown by the rapid spikes in the meta-loss.
Finally, we show an alternative plot of the same data discussed in the previous paragraph. This time, we plot meta-evaluation performance against meta-train performance. For each figure we show each learning rate and each seed in a separate pane. We show the small_fc_lopt optimizer in Figure 11, rnn_fc_lopt in Figure 12, fc_lopt in Figure 13, and nn_adam in Figure 14. Once again we find high variability across architecture, learning rate, and random seed. In these figures, meta-overfitting is highlighted by a "c" shaped curve – meta-training performance continues to improve, but meta-evaluation performance gets worse after some point. Jagged lines / instability suggest high sensitivity of performance on the evaluation task.
Appendix F NAdamW update equations and search space
For our NAdamW baseline, we use the same implementation and search space described in metz2020using. We repeat the functional form here for convenience.
f.1 Update equations
This optimizer architecture has 10 hyperparameters: the base learning rate, $\alpha_{base}$; first and second moment momentum, $\beta_1$, $\beta_2$; the numerical stability term, $\epsilon$; $\ell_1$ and $\ell_2$ regularization strengths; AdamW style weight decay, $\lambda_{wd}$; and a boolean to switch between NAdam and Adam, $b_{nesterov}$. The learning rate schedule is based off of a single cycle cosine decay with a warmup. It is controlled by 3 additional parameters – the warmup fraction, $c_{warmup}$, the fraction of training over which the learning rate is held constant, $c_{constant}$, and the minimum learning rate multiplier, $c_{min}$.
The learning rate is defined by:
(4)  $\alpha_{min} = c_{min}\, \alpha_{base}$
(5)  $t_{const} = c_{constant}\, T$
(6)  $t_{warm} = \max(c_{warmup}, c_{constant})\, T$
(7)  $\alpha_{warm}(t) = \alpha_{base}\, t / t_{warm}$
(8)  $\alpha_{decay}(t) = (\alpha_{base} - \alpha_{min}) \left( \frac{1}{2} \cos\left( \pi\, \frac{\max(t, t_{const}) - t_{const}}{T - t_{const}} \right) + \frac{1}{2} \right) + \alpha_{min}$
(9)  $\alpha_t = \alpha_{warm}(t)$ if $t < t_{warm}$, else $\alpha_{decay}(t)$

where $T$ is the total number of training steps and $t$ is the current step.
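A sketch of this schedule in code, under the same notation; the guard against a zero-length warmup is our addition.

```python
import jax.numpy as jnp

def learning_rate(t, T, alpha_base, c_warmup, c_constant, c_min):
    alpha_min = c_min * alpha_base
    t_const = c_constant * T
    t_warm = jnp.maximum(c_warmup, c_constant) * T
    # Linear warmup toward the base learning rate.
    warm = alpha_base * t / jnp.maximum(t_warm, 1.0)
    # Hold constant until t_const, then single-cycle cosine decay to alpha_min.
    frac = (jnp.maximum(t, t_const) - t_const) / (T - t_const)
    decay = (alpha_base - alpha_min) * (0.5 * jnp.cos(jnp.pi * frac) + 0.5) + alpha_min
    return jnp.where(t < t_warm, warm, decay)
```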
The update equations of NAdamW follow.
(10)  $\phi_0 \sim$ problem specified random initialization
(11)  $m_0 = 0$
(12)  $v_0 = 0$
(13)  $g_t = \nabla_{\phi} \left[ \mathcal{L}(\phi_{t-1}) + \ell_2 \lVert \phi_{t-1} \rVert_2^2 + \ell_1 \lVert \phi_{t-1} \rVert_1 \right]$
(14)  $m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$
(15)  $v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$
(16)  $\hat{m}_t = m_t / (1 - \beta_1^t)$
(17)  $\hat{v}_t = v_t / (1 - \beta_2^t)$
(18)  $\hat{g}_t = g_t / (1 - \beta_1^t)$
(19)  $u_t = \beta_1 \hat{m}_t + (1 - \beta_1)\, \hat{g}_t$ if $b_{nesterov}$, else $u_t = \hat{m}_t$
(20)  $u_t \leftarrow u_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda_{wd}\, \phi_{t-1}$
(21)  $\phi_t = \phi_{t-1} - \alpha_t\, u_t$
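For concreteness, a sketch of one NAdamW step implementing Eqs. (14)-(21); lr_t is the scheduled learning rate $\alpha_t$, t starts at 1, and grad is assumed to already include the $\ell_1$/$\ell_2$ terms from Eq. (13).

```python
import jax.numpy as jnp

def nadamw_step(phi, m, v, grad, t, lr_t, h):
    m = h["beta1"] * m + (1 - h["beta1"]) * grad
    v = h["beta2"] * v + (1 - h["beta2"]) * grad**2
    m_hat = m / (1 - h["beta1"] ** t)
    v_hat = v / (1 - h["beta2"] ** t)
    if h["use_nesterov"]:
        g_hat = grad / (1 - h["beta1"] ** t)
        u = h["beta1"] * m_hat + (1 - h["beta1"]) * g_hat
    else:
        u = m_hat
    # AdamW-style weight decay acts directly on the parameters.
    u = u / (jnp.sqrt(v_hat) + h["eps"]) + h["wd"] * phi
    return phi - lr_t * u, m, v
```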
f.2 hyperparameter search space
The initial learning rate, $\alpha_{base}$, is sampled from log space. The momentum parameters, $\beta_1$ and $\beta_2$, and the numerical stability term, $\epsilon$, are likewise sampled over wide ranges (the exact sampling ranges for all hyperparameters match those in metz2020using). We sample using nesterov ($b_{nesterov}$) 50% of the time. We sample $\ell_1$ and $\ell_2$ logarithmically; with equal probability (one third each) we use both terms, zero out $\ell_1$, or zero out $\ell_2$. With 50% probability we use a nonzero minimum learning rate multiplier, $c_{min}$, sampled logarithmically. With 50% probability we sample the warmup fraction, $c_{warmup}$, between 1e-5 and 1e-1; otherwise it is set to zero. Finally, we uniformly sample the amount of time the learning rate is held constant ($c_{constant}$) between 0 and 1.