Despite the huge computational costs associated with training large neural models, the set of optimization algorithms used to train them has largely been restricted to simple update functions mapping from gradients to parameter updates (e.g. stochastic gradient descent (robbins1951stochastic), Adam (kingma2014adam), or RMSProp (tieleman2012lecture)). These algorithms typically depend on a small number of hand-designed features and parameters. However, the last decade of machine learning research has repeatedly seen small, hand-designed models outperformed by parameterized models (such as neural networks) trained to purpose on large amounts of data (lecun2015deep). Thus, a promising direction for improving training performance and reducing costs is to replace hand-designed optimizers with more expressive learned optimizers, trained on problems similar to those encountered in practice.
Learned optimizers specify parameter update rules using a flexible parametric form and learn the parameters of this function from a “dataset” of optimization tasks—a procedure typically referred to as meta-training or meta-learning (andrychowicz2016learning; finn2017model; hochreiter2001learning). Learned optimizers represent a path towards improved optimizer performance, and possess the ability to target different objectives (e.g. test loss (metz2019understanding), or robustness (metz2019using)), as well as the ability to leverage new features useful for optimization. Despite being an active area of research (andrychowicz2016learning; wichrowska2017learned; chen2020training; metz2020using; metz2021training; almeida2021generalizable; zheng2022symbolic), they are not yet commonly used in practice. Several challenges have limited the widespread application of learned optimizers: they are typically difficult to meta-train on a task family of interest, they can require significant memory and compute overhead when applied, and they often generalize less well to novel tasks than hand-designed optimizers.
In this work, we aim to precisely study the limitations of learned optimizers, and address these limitations via a novel learned optimizer architecture. In particular, we explore and quantify the tradeoffs in terms of memory, compute, and generalization across a variety of optimizers, including hand-designed optimizers (bottou2010large; tieleman2012lecture; kingma2014adam), learned hyper-parameter controllers (daniel2016learning; hansen2016using; xu2017reinforcement; xu2019learning; almeida2021generalizable), and neural network based learned optimizers (andrychowicz2016learning; wichrowska2017learned; metz2020tasks), with the goal of understanding how choices in optimizer design affect performance and usability. Our core contributions are:
We present a thorough empirical characterization of the trade-offs inherent in different learned optimizer architectures and features, and a comparison of these learned optimizer architectures against their well-tuned hand-designed counterparts.
We develop a new per-parameter learned optimizer architecture, on the Pareto frontier with regards to performance, computational cost, and memory usage.
In this section we review and formalize the class of optimizers that are commonly used in training neural networks. We then define meta-learned optimizers, and highlight differences with standard optimization approaches. We describe several examples of both common, standard neural network optimizers as well as classes of learned optimizers, all of which are investigated in this paper.
2.1 Gradient Based Optimizers
Most first-order optimizers (i.e. optimizers using only gradient information and not higher-order derivatives) used to train neural networks can be viewed as functions mapping from a history of gradients to parameter updates. We will assume the optimizer acts on an underlying model with parameters w, while maintaining an internal optimizer state s. The parameters may be, for example, neural network weights, whereas the optimizer state includes quantities such as the accumulated momentum values in momentum-accelerated optimizers (polyak1964some; nesterov1983method). The optimizer acts by ingesting gradients g (which arise from a specified loss function and a dataset) and outputting updated parameters.
More precisely, we define an optimizer as a pair of functions. The first, which we call the Update function, computes new parameter values and state from stochastic gradients, the current parameter value, and the current optimizer state. The second, which we refer to as the Init function, initializes the optimizer state. Both functions have hyperparameters θ, such as the learning rate or the initial value of accumulators. Thus, we write the optimizer as:

s_0 = Init(w_0; θ)
(w_{t+1}, s_{t+1}) = Update(w_t, s_t, g_t; θ)
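As a concrete illustration, the (Init, Update) pair for SGD with momentum can be sketched in a few lines (a minimal NumPy sketch; the function and hyperparameter names are illustrative, not from our released code):

```python
import numpy as np

def sgdm_init(params, hparams):
    # State holds one momentum accumulator per parameter.
    return {"momentum": np.zeros_like(params)}

def sgdm_update(params, state, grads, hparams):
    # Momentum accumulation followed by a scaled parameter step.
    m = hparams["beta"] * state["momentum"] + grads
    new_params = params - hparams["lr"] * m
    return new_params, {"momentum": m}

hparams = {"lr": 0.1, "beta": 0.9}
w = np.array([1.0, -2.0])
s = sgdm_init(w, hparams)
g = np.array([0.5, -0.5])
w, s = sgdm_update(w, s, g, hparams)
```

Hand-designed optimizers fix the functional form of Update and expose only the hyperparameters θ; learned optimizers instead parameterize Update itself.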
Optimizers can benefit from problem information beyond stochastic gradients, parameter values, and losses. For instance, methods that utilize line searches (le2011optimization), validation loss (xu2019learning; metz2020tasks), or the structure of the underlying computation graph (martens2015optimizing) all rely on additional information. However, the present work is restricted to optimizers which minimize training loss by mini-batch stochastic gradient descent.
First-order hand-designed optimizers: Hand-designed optimizers typically have a simple form and a small number of hyperparameters θ, which are tuned by random search (bergstra2012random), Bayesian optimization (snoek2012practical), or other low-dimensional black-box optimization techniques (bergstra2011algorithms; golovin2017google; optuna_2019). They mostly have low overhead in terms of compute and memory usage. For instance, Adam (kingma2014adam) has two accumulators, and SGD has none. (Some hand-designed methods, such as Shampoo (gupta2018shampoo; anil2020second), involve considerable compute overhead, but can make more progress per update step.)
In this work, we experiment with four kinds of hand-designed optimizers: SGD (robbins1951stochastic; bottou2010large), SGDM (SGD with momentum) (polyak1964some), Adam (kingma2014adam), and Nesterov accelerated Adam (dozat2016incorporating) with AdamW (loshchilov2017decoupled) style weight decay (NAdamW). For SGD, SGDM, and Adam, we search over learning rates spaced every half order of magnitude. For NAdamW we use random search with many more hyperparameter configurations per task (1000) and a much larger search space over hyperparameters controlling: first and second momentum time scales, weight decays, and learning rate schedules. Past work has shown this to be a powerful search space (metz2020using) and, in our work, this dramatically outperforms learning rate search. See Appendix F for more details. Many other hand-designed optimizer architectures have been proposed (ruder2016overview; zeiler2012adadelta; reddi2018adaptive; you2019large; liu2019variance), but their practical benefits are small in most situations (schmidt2020descending).
Factorized optimizers: In some settings, keeping even one additional copy of the parameters for accumulators is too costly. Recent methods such as AdaFactor (shazeer2018adafactor) and SM3 (anil2019memory) factorize the accumulated statistics, using a sub-linear amount of memory with respect to the number of parameters. This style of accumulator has not been explored in the context of learned optimizers, but we will show it provides an effective way to improve performance without meaningfully increasing memory overhead (§4.2).
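As a sketch of the factored accumulator idea (simplified from AdaFactor; the real method treats higher-rank tensors, epsilons, and decay schedules differently), a 2D tensor's second-moment statistics can be stored as row and column means and reconstructed on the fly:

```python
import numpy as np

def factored_second_moment(v_row, v_col, grad_sq, decay=0.999):
    # Accumulate row/column means of the squared gradient:
    # O(m + n) memory instead of O(m * n) for an (m, n) tensor.
    v_row = decay * v_row + (1 - decay) * grad_sq.mean(axis=1)
    v_col = decay * v_col + (1 - decay) * grad_sq.mean(axis=0)
    # Rank-1 reconstruction of the full second-moment matrix.
    v_hat = np.outer(v_row, v_col) / v_row.mean()
    return v_row, v_col, v_hat

g = np.ones((4, 3))  # toy gradient for a (4, 3) weight matrix
v_row, v_col, v_hat = factored_second_moment(np.zeros(4), np.zeros(3), g**2)
```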
2.2 Meta-Learned Optimizers
The meta-learning problem for optimizers consists of tuning the hyperparameters of a class of parameterized optimizers with respect to some loss function (common choices include the average training loss across inner optimizer iterations, the average validation loss, and the terminal train/validation loss). How is this different from the hyperparameter tuning discussed in the last subsection? While there is no formal difference between the hyperparameter selection problem and training learned optimizers, the learned optimizers we consider in this subsection universally include a black-box component with a (comparatively) large number of parameters (in our case, always parameterized by a neural network). This large number of parameters limits the effectiveness of traditional hyperparameter tuning methods such as random search, and so we focus on local optimization methods (including first-order gradient-based methods as well as zeroth-order methods) which are able to perform better in high dimensional optimization. Below, we outline several types of learned optimizers.
Hyperparameter controllers: Optimizing the hyperparameters of a hand-designed optimizer over a broad set of tasks may limit performance on each specific task. These hand-designed optimizers can be augmented with a meta-learned controller, often parameterized as a neural network, that modulates the hyperparameters of the optimizer over the course of training to yield better performance on each particular problem (daniel2016learning; hansen2016using; xu2017reinforcement; xu2019learning; almeida2021generalizable). This controller takes in summary statistics (e.g. gradient norms, loss values), and can either globally assign identical hyperparameters to all layers, or operate per-layer. One benefit of hyperparameter controllers is that their per-parameter compute overhead is small, as the majority of the computation is performed only once per tensor or per network, rather than scaling with the number of parameters.
We introduce a novel hyperparameter controller architecture, which we refer to as nn_adam. This architecture consists of an LSTM-based (hochreiter1997long) hyperparameter controller, operating on features derived from each tensor independently, and outputting Adam hyperparameters consisting of a per-tensor learning rate, beta1, beta2, and epsilon. For features, this model uses normalized values derived from the first moment of gradients, the second moment, and the tensor shape. See Appendix C for details.
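The following toy sketch shows how a controller can map per-tensor features to valid Adam hyperparameters. We use a single linear layer as a stand-in for the LSTM, and the feature values and weights are placeholders:

```python
import numpy as np

def controller_step(features, W, b):
    # A single linear layer standing in for the LSTM controller; its raw
    # outputs are squashed into valid Adam hyperparameter ranges.
    raw = W @ features + b
    lr = np.exp(raw[0])                     # learning rate > 0
    beta1 = 1.0 / (1.0 + np.exp(-raw[1]))   # in (0, 1)
    beta2 = 1.0 / (1.0 + np.exp(-raw[2]))   # in (0, 1)
    eps = np.exp(raw[3])                    # epsilon > 0
    return lr, beta1, beta2, eps

feats = np.array([0.1, -0.2, 0.3])  # e.g. normalized moment statistics
W = np.zeros((4, 3))                # placeholder weights
b = np.array([np.log(1e-3), 2.0, 4.0, np.log(1e-8)])
lr, b1, b2, eps = controller_step(feats, W, b)
```

The exponential and sigmoid output transforms guarantee that whatever the controller emits, the downstream Adam step receives hyperparameters in valid ranges.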
Per-parameter learned optimizers: Per-parameter learned optimizers (andrychowicz2016learning) learn a function, often parameterized by a neural network, which is applied to each parameter independently, though sometimes with normalization performed across parameters in a tensor (metz2019understanding).
Multi-level approaches: In an effort to add capacity to a learned optimizer while retaining good computational complexity with respect to the number of parameters, hierarchical models have been proposed (wichrowska2017learned; metz2020tasks). These models leverage up to three levels of hierarchy: a global controller, which sends activations to and receives activations from a per-layer (or per-tensor) controller, which in turn exchanges activations with a per-parameter optimizer.
New per-parameter learned optimizer: Finally, we introduce a new learned optimizer architecture (which we call small_fc_lopt) that combines architectural features of per-parameter and factorized optimizers, and outperforms both. This architecture is directly motivated by the trade-offs among compute, memory, performance, and generalization shown in §4. Our learned optimizer incorporates an extremely tiny, per-parameter, MLP-based learned optimizer similar to that used in metz2019understanding. This 197-parameter MLP takes as input 39 features, computed from 4 per-parameter accumulators (3 momenta at different meta-learned timescales, and 1 gradient second moment accumulator) and 3 AdaFactor accumulators, also at different meta-learned timescales. These features are passed into a 1 hidden layer, 4 hidden unit MLP. See Appendix A for additional details.
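The per-parameter computation can be sketched as a tiny MLP applied independently to every parameter's feature vector (shapes follow the description above; the weights here are random placeholders, and the exponential-magnitude output parameterization follows metz2019understanding):

```python
import numpy as np

def per_param_update(features, W1, b1, W2, b2, step_mult=0.001, exp_mult=0.001):
    # features: (num_params, num_features). The same tiny MLP is applied
    # independently to every parameter, so cost scales linearly in params.
    h = np.maximum(features @ W1 + b1, 0.0)   # ReLU hidden layer
    out = h @ W2 + b2                         # (num_params, 2)
    direction, magnitude = out[:, 0], out[:, 1]
    # Small multipliers keep the initial (random) optimizer's steps tiny.
    return direction * np.exp(magnitude * exp_mult) * step_mult

rng = np.random.default_rng(0)
n_params, n_feats, hidden = 5, 39, 4
feats = rng.normal(size=(n_params, n_feats))
W1 = rng.normal(size=(n_feats, hidden)); b1 = np.zeros(hidden)
W2 = rng.normal(size=(hidden, 2)); b2 = np.zeros(2)
delta = per_param_update(feats, W1, b1, W2, b2)
```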
3 Training and Meta-Training
In the previous section we specified possible architectures for standard optimizers (with a small number of hyperparameters) as well as learned optimizers. Both learned and hand-designed optimizers are iteratively applied to some parameterized model, paired with a loss function and (possibly) a dataset. We refer to this collection as a task. We use the loss obtained by an optimizer on these tasks to select hyperparameters (in the case of hand-designed optimizers), and to optimize the learned optimizer weights. In this section, we discuss the tasks used, the measurement of performance by which we compare optimizers (meta-loss), and how the weights of the learned optimizers are computed (which we refer to as meta-optimization).
Throughout this paper, the tasks of interest are neural network training problems. Each task is specified via three quantities. The first is the underlying model architecture and the initial parameter values (or a procedure for initializing the model parameters). The second is a function to generate batches of data, and the third is a loss function. While a more abstract definition of a task could cover more general optimization problems, we aim to address neural network training as a setting and believe generalizations are (in most cases) straightforward. In this work we consider solely supervised learning. We also consider only a single function to generate a batch of data, though this could easily be extended to multiple functions corresponding to, for example, train and validation loss. As discussed in the next subsection, we focus solely on training loss for simplicity.
3.1 Tasks
We consider two tasks: an MLP trained on Fashion MNIST, and a 3-layer convolutional network on CIFAR-10 (krizhevsky2009cifar). See Appendix B for more details and implementations. In Section 4.5, to assess generalization, we additionally evaluate optimizers meta-trained on these two tasks on three additional problems.
The tradeoffs inherent in optimizer design are task dependent (see §4.3), and the per-parameter compute and memory requirements of the optimizer must be balanced against the per-parameter compute and memory requirements of the task. These latter requirements are a function of parameter sharing, sparsity in parameter use, model architecture, and minibatch size (compute overhead per parameter can be made arbitrarily small by increasing the minibatch size). The two problems we consider have different compute and memory requirements and were therefore chosen as reasonable baseline tasks providing insight into optimizer performance at different points in the space of possible tasks. Moreover, these (relatively small) tasks were chosen to enable the large-scale evaluation and comparisons done in this paper.
3.2 Meta-Loss and Meta-Optimization
To evaluate an optimizer, we apply it for 2,000 iterations and compute the average loss obtained over the course of training. In this work, we exclusively focus on training loss performance as opposed to validation loss. This is to decouple optimization performance tradeoffs from the implicit regularization effects of learned optimizers shown in metz2019understanding and metz2020tasks.
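The meta-loss we use can be sketched as the mean training loss over an unroll; here a toy quadratic task and plain SGD stand in for a neural network task and a learned optimizer:

```python
import numpy as np

def meta_loss(opt_update, w0, state0, num_steps=2000):
    # Average training loss over the unroll, on a toy quadratic task.
    w, state, losses = w0, state0, []
    for _ in range(num_steps):
        loss = 0.5 * np.sum(w ** 2)  # quadratic training loss
        grad = w                      # its exact gradient
        w, state = opt_update(w, state, grad)
        losses.append(loss)
    return np.mean(losses)

def sgd_update(w, state, grad, lr=0.1):
    # Stand-in inner optimizer; a learned Update function would go here.
    return w - lr * grad, state

ml = meta_loss(sgd_update, np.array([1.0, -1.0]), None, num_steps=100)
```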
We train optimizers targeting the two tasks described above, using random sampling from a fixed search space for the hand-designed optimizers, and Persistent Evolution Strategies (PES) (pmlr-v139-vicol21a) for the learned optimizers. See Appendix D.2 for details. To minimize confounds, we focus on the scenario where meta-train matches meta-test (i.e. the tasks presented during training and testing are the same), and examine the overhead and performance tradeoffs inherent in a learned optimizer meta-trained to optimize a single task.
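A simplified sketch of the evolution-strategies gradient estimate underlying this kind of meta-training (plain antithetic-sampling ES on a toy meta-loss; PES additionally corrects the bias from truncated unrolls, which we omit here):

```python
import numpy as np

def es_grad(meta_loss_fn, theta, sigma=0.01, n_pairs=256, seed=0):
    # Antithetic-sampling ES estimate of d(meta_loss)/d(theta).
    rng = np.random.default_rng(seed)
    grad = np.zeros_like(theta)
    for _ in range(n_pairs):
        eps = rng.normal(size=theta.shape)
        diff = meta_loss_fn(theta + sigma * eps) - meta_loss_fn(theta - sigma * eps)
        grad += eps * diff
    return grad / (2 * sigma * n_pairs)

# Toy meta-loss whose true gradient at theta is 2 * theta.
g = es_grad(lambda th: np.sum(th ** 2), np.array([1.0, -2.0]))
```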
4 Exploring tradeoffs across optimizer families
In this section we experimentally explore tradeoffs when designing learned optimizers. In §4.1, we show memory and time trade-offs for various hand-designed and learned optimizers. In §4.2, we focus on per-parameter learned optimizers and explore the impact of both feature choice and size of the learned optimizer. In §4.3, we discuss computational costs of running learned optimizers as a function of task features (such as the number of network weights). In §4.4, we tie all these evaluations together and show wall-clock time performance for the different tasks. In §4.5, we explore meta-generalization—applying optimizers to a task different from those in which they were meta-trained.
4.1 Compute, memory, performance tradeoffs for learned and hand-designed optimizers
We characterize the trade-offs between performance, memory overhead, and compute overhead for both hand-designed and learned optimizers. The optimizers examined here consist of the hand-designed optimizers (SGD, Adam, and NAdamW), the MLP optimizer from metz2019understanding (fc_lopt), the hierarchical optimizer from metz2020tasks (rnn_fc_lopt), the hyperparameter controller described in §2.2 (nn_adam), and the per-parameter optimizer proposed in §2.2 (small_fc_lopt).
Meta-training curves are shown in Figure 1a,c. We additionally show final performance of the fully trained learned optimizer as a function of compute time per step (Figure 1b,d) and with respect to memory usage (Figure 1e). We find learned optimizers can achieve lower meta-loss than baselines, but at the cost of more compute time and memory usage. For the MLP task, the cost of the learned optimizer far outstrips the cost of a hand-designed optimizer (taking > 5x more time in the case of rnn_fc_lopt). For the CIFAR-10 ConvNet, however, the compute overhead is small relative to overall compute, due to the much larger per-parameter cost of computing gradients for a ConvNet.
4.2 Design choices for the MLP learned optimizer
To guide learned optimizer design, we explore the memory, time, and performance trade-offs associated with different choices of input features for an MLP learned optimizer. The dominant source of memory overhead is the inclusion of additional per-parameter accumulators. We explore the two kinds of per-parameter accumulators that optimizers use: the exponential moving averages of the gradient's first moment and second moment, as used in momentum and RMSProp (tieleman2012lecture) respectively. Unlike existing optimizers, the learned optimizers we explore accumulate these statistics over multiple timescales (this is similar to the AggMo (lucas2018aggregated) hand-designed optimizer). In addition, we explore preconditioning features based on AdaFactor (shazeer2018adafactor), which use sub-linear memory in the parameter count.
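Maintaining first- and second-moment accumulators at multiple timescales can be sketched as follows (the decay values here are illustrative; in the learned optimizer they are meta-learned):

```python
import numpy as np

def update_accumulators(ms, vs, grad,
                        m_decays=(0.9, 0.99, 0.999), v_decays=(0.999,)):
    # First-moment EMAs at several timescales, plus second-moment EMA(s).
    ms = [d * m + (1 - d) * grad for d, m in zip(m_decays, ms)]
    vs = [d * v + (1 - d) * grad ** 2 for d, v in zip(v_decays, vs)]
    return ms, vs

g = np.array([1.0, 2.0])
ms = [np.zeros(2) for _ in range(3)]
vs = [np.zeros(2)]
ms, vs = update_accumulators(ms, vs, g)
```

Each extra accumulator costs one additional parameter-sized buffer, which is exactly the memory overhead traded off in this section.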
We plot performance vs. compute cost, and performance vs. memory, in Figure 2. We plot baselines with hyperparameters found via grid or random search (in black) (adam, sgd, sgdm, nadamw), and baseline learned optimizers (in gray) (fc_lopt (metz2019understanding), rnn_fc_lopt (metz2020tasks), and nn_adam (§2.2)). All other conditions consist of differently parameterized learned optimizers with different input features. Each configuration is trained with PES (pmlr-v139-vicol21a) for 100k meta-training iterations. We test optimizers using only a single momentum accumulator with different decays (in green), multiple momentum accumulators (in yellow), a single second moment accumulator (in brown), multiple second moment accumulators (in pink), two accumulators with the same decay for first and second moments (in blue), multiple-decay first and second moments (in purple), AdaFactor features with and without additional momentum accumulators (in red), only gradient features (in orange), and finally the union of all features (in gray). See Appendix D.3 for more experimental details.
We find the general trend that providing more features to a learned optimizer leads to better performance. However, including more accumulators increases the computational and memory overhead of using these optimizers. AdaFactor features by themselves (adafact) use very little memory, but do not perform well. Combining a small number of momentum features with AdaFactor features (adafact_m_mult) recovers the performance of using second moment accumulators, without the need for second moment accumulators.
Finally, we explore varying the hidden size of the MLP (Figure 3). Using the same features as in small_fc_lopt (§2.2), we sweep the hidden size of the MLP from 2 to 256 units. For each width we perform a small hyperparameter search over the meta-learning rate, and take the best performing learning rate for each width. Surprisingly, an extremely narrow MLP is sufficient to outperform the best hand-designed baseline (NAdamW). Increasing width boosts performance, but the improvements diminish.
The relationship between learned optimizer width and compute overhead depends heavily on implementation details. TPU and other accelerator hardware often have specialized matrix multiplication units that operate on fixed-dimensional matrices (e.g. TPUv2 has 128x128 systolic arrays (norrie2020google)). For a naive implementation of the learned optimizer using matrix multiplication kernels on TPUv2, there are no significant speedups from shrinking the optimizer width below approximately 64 units. However, if matrix multiplication is expanded explicitly in terms of element-wise operations, then continued speedups can be achieved even for optimizer hidden state vectors of two units, achieving a nearly 2x speedup over the use of matrix multiplication primitives, as different primitives and thus different pieces of hardware are used. Profiling suggests that even greater speedups should be possible using custom kernels (as are frequently written for hand-designed optimizers). In a sense, these tiny learned optimizers, with matrix multiplications expanded, blur the line between hand-designed and learned optimizers, as both implementations reduce to a handful of element-wise floating point operations.
4.3 Overhead of learned optimizers on different tasks
To explore the dependence on task identity, we measure the relative overhead of training with a learned optimizer compared to SGD for different widths and batch sizes of the Fashion MNIST MLP and CIFAR-10 ConvNet. Results are shown in Figure 4. We find that, in all cases, increasing batch size lowers the overhead, as the cost to compute gradients increases but the cost of applying the optimizer remains constant. In the case of the ConvNets, increasing the channel count of the target problem lowers the overhead for small_fc_lopt and nn_adam, but increases it for fc_lopt and rnn_fc_lopt. For MLP target problems, increasing the target MLP's hidden size increases overhead for all optimizers, including Adam. The asymptotic scaling of this behavior is due to both the computational complexity and the memory bandwidth of the underlying hardware.
Next, we explore overheads for some common, large scale models. Table 1 shows results for ResNets and Transformers, trained on a single TPUv2 chip. Distributed training would allow us to split the optimizer computation across devices and thus achieve even lower optimizer overhead. See Appendix D.4 for further details.
| Task | Params | SGD step time (ms) | LOpt step time (ms) | LOpt multiplier |

Table 1: Step timings for tasks with different word sequence lengths (L) and batch sizes (BS). All numbers are medians over 10 timings. Standard error is smaller than the last reported digit.
4.4 Performance with respect to wall clock time
In the previous sections, we measured the performance achieved by optimizers, and the computational overhead required to achieve that performance. In practice, one often cares most about the total wall clock time required to reach a given performance. Further, the optimal meta-parameters change depending on the length of inner-training. To quantify the achievable performance as a function of wall clock time, we compare training trajectories for both our learned and hand-designed optimizers. We apply a learned optimizer meta-trained for length 2k unrolls, and optimize hyperparameters of hand-designed optimizers to perform well on 1k, 2k, 3k, 4k, and 5k length unrolls. In Figure 5 we show the resulting performance for both the Fashion MNIST MLP task and CIFAR-10 ConvNet tasks. Our learned optimizer is always faster with respect to step count. With respect to wall-clock time, on the CIFAR-10 ConvNet task we also see faster training, while for the Fashion MNIST MLP task, where the relative overhead of the learned optimizer is large, NAdamW performs best.
4.5 Meta-generalization: optimizer performance on holdout tasks
One final trade-off is the interaction between optimizer design choice and generalization performance. Generalization performance, in this context, refers to the ability of the optimizer to perform well at training a novel task, different from the task distribution used to meta-train the optimizer. To quantify this, we measure the performance of diverse optimizers meta-trained on one task, but then used as the optimizer for a novel task.
We meta-train each of the different learned optimizers on either the Fashion MNIST MLP or the CIFAR-10 ConvNet tasks. Over the course of meta-training we evaluate performance on 5 tasks:
The Fashion MNIST MLP described in §3.1.
The CIFAR-10 ConvNet described in §3.1, using 16x16 images for computational reasons.
A three-hidden-layer MLP trained on 16x16 ImageNet.
A CIFAR-10 auto-encoder trained with mean squared error loss.
An LSTM (hochreiter1997long) language model trained on byte-level LM1B (DBLP:journals/corr/ChelbaMSGBK13).
See Appendix E for more details and implementations.
First, we train a number of learned optimizers of different types, select the checkpoint which performs best on the meta-training task, and evaluate performance on different held-out tasks. We show transfer performance when meta-training and meta-testing on Fashion MNIST MLP and CIFAR-10 ConvNet in Figure 6, and the remainder of the comparisons in Appendix E. While there is some correlation across learned optimizer architectures, in general we find poor meta-generalization. Additionally, meta-generalization seems to depend more strongly on details of the meta-training process than on learned optimizer architectural choices. This variation poses challenges when reporting results, due to the cost of meta-training.
To explore variance in the meta-training process, we plot performance over the course of meta-training on both the target-task, and the held out tasks. We show these dynamics for two different learned optimizer seeds in Figure 6cd. In both cases, meta-training trajectories exhibit high variability. However, they also both show an initial phase of correlated performance improvement, culminating in better performance than the baselines for both the target and held out task, before the optimizer finally overfits to the target-task.
This type of meta-overfitting is not unique to learned optimizers, and happens even when trying to transfer hand-designed optimizers from task to task. To show this, we simulate meta-training on the Fashion MNIST MLP by randomly sampling subsets of parameters from the original NAdamW search space for different budgets. We then find the best performance on the meta-training task and apply the best hyperparameters to the meta-test task. We show meta-train vs. meta-test performance for different budgets in Figure 7. We see signs of meta-overfitting for some tasks, such as the ImageNet MLP and the MLP autoencoder. For the others, the CIFAR-10 ConvNet and the LSTM, we continue to see correlation between meta-train and meta-test performance.
5 Related Work
Learned optimizers were originally proposed in bengio1992optimization and runarsson2000evolution, and interest in them has recently undergone a revival. Proposed learned optimizer architectures include per-parameter RNNs (andrychowicz2016learning), hierarchical models enabling sharing of information across parameters (wichrowska2017learned; metz2020tasks), and a simplified architecture consisting of just an MLP (metz2019understanding). Optimizer meta-training techniques have included gradient descent (maclaurin2015gradient; andrychowicz2016learning; wichrowska2017learned; li2016learning; li2017learning), and more advanced training procedures (lv2017learning; maheswaranathan2019guided; metz2019understanding; chen2020training; pmlr-v139-vicol21a; metz2021training) leveraging both Evolution Strategies (ES) and gradients. Learned optimizers have been targeted at applications including model robustness (metz2019using), chemistry (learn2hop), min-max optimization (shen2021learning), adversarial training (xiong2020improved), few-shot learning (ravi2016optimization), swarm optimization (cao2019learning; metz2018learning), black box optimization (chen2016learning), and MCMC sampling (levy2017generalizing; wang2017meta; gong2018meta). Other work has analyzed learned optimizer behavior (maheswaranathan2020reverse). Bello17 takes a different approach, and meta-learns symbolic, rather than neural-network-driven, parameter update rules.
In an effort to understand computational costs, we look to Pareto frontiers of computation and memory vs. performance. The concept of Pareto optimality was originally proposed in economics to understand how individuals can prosper with finite resources (newman1998new), and has since become a useful tool in computer science. Studying trade-offs in this way is common in computer vision and natural language processing, where performance as a function of model size is often explored (simonyan2014very; he2016deep; vaswani2017attention). Building efficient frontiers of models has been a target for meta-learning as well (tan2019efficientnet). In the scope of learned optimizers, metz2019understanding explored training wall-clock efficiency, but on limited hardware (CPU) and with respect to a single target problem instance. wichrowska2017learned showed that the relative overhead of computing updates with a learned optimizer shrinks as batch size increases. zheng2022symbolic propose a symbolic distillation meta-training step which converts neural-network-parameterized optimizers to a symbolic form, resulting in both lower memory and compute costs. There has also been significant research exploring trade-offs between optimization techniques outside of deep learning, especially between stochastic and full-batch methods, and between different second-order methods. For example, Newton's method has led to a number of approximations, e.g. diagonal approximations (duchi2011adaptive), block-diagonal approximations (martens2015optimizing; gupta2018shampoo), as well as a large family of quasi-Newton methods (dennis1977quasi) (e.g. BFGS (broyden1970convergence) and its low-memory counterpart, L-BFGS (liu1989limited)).
In this work, we characterized practical trade-offs involved in designing learned optimizers, including those between performance on a target task, the compute and memory overhead associated with the learned optimizer, training time, choice of target task, and generalization to new tasks. Using the lessons learned from our careful exploration, we introduced an architecture that strikes a better balance between memory usage, compute, and performance. We then showed that this learned optimizer architecture can be used to accelerate training on accelerator hardware.
The goal of this work was to provide a thorough investigation of the fundamental tradeoffs associated with learned optimizers. We view this paper as a first step toward the empirical characterization necessary for principled comparisons, but our experiments were limited in several ways. First, to control the number of covariates and perform the required experiments within a reasonable computation budget, we have limited ourselves to (primarily) two tasks, which themselves are simple compared with state-of-the-art neural network models. Moreover, we have provided only a limited investigation of meta-generalization and validation performance. In general, while this paper serves as a first step toward rigorous empirical comparison within this novel class of learned optimizers, further work is required to extend our results.
In order to make it easier for future research to build on our work, and to include better-grounded empirical comparisons, all optimizers, tasks, and training code are open sourced in learned_optimization (https://github.com/google/learned_optimization), an open source library written in JAX for designing, training, and testing learned optimizers.
We would like to thank Chip Huyen, Ben Poole, Amil Merchant, and Wenqing Zheng for their support and comments on this work, as well as the entire Brain team for providing a wonderful research environment. We would also like to thank the authors of the Python scientific computing stack, including NumPy (van2011numpy) and Matplotlib (matplotlib).
Appendix A small_fc_lopt architectural details
We describe the full details of our proposed learned optimizer. The source code for this optimizer can be found at https://github.com/google/learned_optimization/blob/aa15091066aa5b3f45e6b7f4bee1c41fb7d467a0/learned_optimization/learned_optimizers/adafac_mlp_lopt.py.
Our learned optimizer computes a set of per-parameter features, which are concatenated and fed into an MLP. These features contain:
the parameter values
the 3 momentum values (computed at multiple decay rates)
the second moment value
3 values consisting of the momenta normalized by the RMS gradient norm
3 AdaFactor normalized gradient values
the tiled AdaFactor row features (3 features)
the tiled AdaFactor column features (3 features)
the reciprocal square root of these previous 6 features
3 features consisting of AdaFactor-normalized momentum values
11 features computed from the current timestep, t, at a range of timescales
All but the time features are normalized to have a second moment of 1 across the tensor. They are then fed into a 4-hidden-unit, 2-hidden-layer MLP with ReLU activations and projected to 2 outputs representing a magnitude, m, and a scalar direction, d, which are combined to form the predicted step Δ = λ₁ d exp(λ₂ m), where λ₁ and λ₂ are constants set to 0.001 to keep initial step sizes small.
In order to reduce computational overhead, this optimizer is dramatically smaller than past learned optimizers, containing only 197 meta-parameters.
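Putting these pieces together, the per-parameter computation can be sketched as follows. This is a simplified NumPy illustration, not the actual implementation: feature construction is omitted, and the weight names and shapes are our assumptions (see the linked source for the real optimizer).

```python
import numpy as np

def small_fc_lopt_step(features, mlp, step_scale=0.001, exp_scale=0.001):
    """Apply a tiny per-parameter MLP to produce a parameter update.

    features: [num_params, num_features] array of per-parameter inputs
      (parameter value, momenta, second moments, AdaFactor features, ...).
    mlp: weights for a 2-hidden-layer, 4-unit ReLU MLP (names are ours).
    """
    h = features
    for w, b in [(mlp["w0"], mlp["b0"]), (mlp["w1"], mlp["b1"])]:
        h = np.maximum(h @ w + b, 0.0)  # ReLU hidden layers
    out = h @ mlp["w_out"] + mlp["b_out"]  # [num_params, 2]
    magnitude, direction = out[:, 0], out[:, 1]
    # Small constants keep step sizes small at meta-initialization.
    return direction * step_scale * np.exp(magnitude * exp_scale)

# Usage sketch with illustrative sizes: 5 parameters, 8 features each.
rng = np.random.default_rng(0)
n, f = 5, 8
mlp = {"w0": rng.normal(size=(f, 4)), "b0": np.zeros(4),
       "w1": rng.normal(size=(4, 4)), "b1": np.zeros(4),
       "w_out": rng.normal(size=(4, 2)), "b_out": np.zeros(2)}
update = small_fc_lopt_step(rng.normal(size=(n, f)), mlp)
```

The exponential parameterization of the magnitude lets the optimizer span many orders of step size while remaining near zero at initialization.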
Appendix B Tasks we meta-train on
b.1 Fashion MNIST MLP
The MLP we use consists of 2 hidden layers with 128 hidden units each and ReLU activations. It is trained on Fashion MNIST with inputs re-scaled to lie in [0, 1], a batch size of 128, and cross entropy loss. Our network was built in Haiku [haiku2020github] and the implementation can be found at https://github.com/google/learned_optimization/blob/32c4f21ec238a12756afe70e3d699017ea938f5d/learned_optimization/tasks/fixed/image_mlp.py#L35.
b.2 CIFAR-10 ConvNet
The ConvNet consists of 3 hidden layers with ReLU activations. All layers have 3x3 kernels. The first layer has 32 units and stride 2; the following 2 layers have 64 hidden units and stride 1. All convolutions use same padding. We average over the spatial dimensions, then linearly project to 10 classes. We train on batches of 128 CIFAR-10 images re-scaled to [0, 1] and use cross entropy loss. The implementation can be found at https://github.com/google/learned_optimization/blob/ba2b56565fb507368652d2e4a12ab305a6d99ded/learned_optimization/tasks/fixed/conv.py#L96.
Appendix C nn_adam architecture
The nn_adam learned optimizer is a hyperparameter-controller-based learned optimizer. In addition to the description that follows, we provide an implementation at https://github.com/google/learned_optimization/blob/ba2b56565fb507368652d2e4a12ab305a6d99ded/learned_optimization/learned_optimizers/nn_adam.py#L161. For each tensor of the target problem, we compute a set of features (see §C.1), feed them into a 32-unit LSTM, and output 4 values: the log learning rate, beta 1 and beta 2 (both parameterized as log(1 - beta)), and log epsilon. These hyperparameters are then fed into Adam, where the update to each weight and accumulator follows the Adam update equations.
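The use of the 4 controller outputs can be sketched as follows. This shows only how the predicted values parameterize a standard Adam step; the LSTM controller and feature computation are omitted, and the names are ours.

```python
import numpy as np

def adam_update_from_controller(param, grad, m, v, t, ctrl_out):
    """One Adam step with hyperparameters predicted per tensor.

    ctrl_out: 4 controller outputs interpreted as
      [log lr, log(1 - beta1), log(1 - beta2), log epsilon],
    matching the parameterization described above.
    """
    lr = np.exp(ctrl_out[0])
    beta1 = 1.0 - np.exp(ctrl_out[1])
    beta2 = 1.0 - np.exp(ctrl_out[2])
    eps = np.exp(ctrl_out[3])
    m = beta1 * m + (1 - beta1) * grad          # first moment accumulator
    v = beta2 * v + (1 - beta2) * grad ** 2     # second moment accumulator
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)
    new_param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return new_param, m, v

# Usage sketch: controller predicts lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8.
new_param, m, v = adam_update_from_controller(
    param=np.ones(3), grad=np.ones(3), m=np.zeros(3), v=np.zeros(3),
    t=1, ctrl_out=np.log([0.01, 0.1, 0.001, 1e-8]))
```

Because the controller acts per tensor rather than per parameter, its overhead is small relative to the base Adam update.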
c.1 Per-tensor features
We use the same set of per-tensor features as the hierarchical learned optimizer in metz2020tasks. For many features, we employ a simple transformation to obtain feature magnitudes: we compute the log of the absolute value, clip it to [-5, 5], and rescale by 0.5.
For each tensor, we use the following as a feature set:
Transformed mean momentum value
Sign of mean momentum
Transformed variance squared of momentum
Transformed mean of the accumulator of second moment
Sign of the mean of the accumulator of second moment
Transformed mean parameter value
Sign of mean parameter value
Transformed variance squared of parameter value.
Transformed mean gradient value
Sign of mean gradient value
Transformed variance squared gradient
Transformed mean absolute value of gradient.
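The magnitude transformation applied throughout the list above can be sketched as:

```python
import numpy as np

def transform(x):
    """Magnitude transform for per-tensor features: log of the absolute
    value, clipped to [-5, 5], then rescaled by 0.5."""
    return 0.5 * np.clip(np.log(np.abs(x)), -5.0, 5.0)

# e.g. a "transformed mean gradient value" feature would be
# transform(np.mean(grad)), paired with np.sign(np.mean(grad)).
val_e = transform(np.e)      # log(e) = 1, so 0.5
val_one = transform(1.0)     # log(1) = 0, so 0.0
val_huge = transform(1e10)   # clipped to 5, so 2.5
```

The clip bounds the feature range so that very large or very small statistics do not saturate the controller's inputs.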
Appendix D Experiment details
For all experiments in this paper we meta-train with Persistent Evolutionary Strategies [pmlr-v139-vicol21a] with a standard deviation of 0.01 and length-20 truncations over length-2k inner training steps. The meta-objective we target is mean training loss (clipped at the initialization value in the classification problems). We use 4 distributed workers (each using a single TPU accelerator chip) in an async batched fashion. We use a meta-batch size of 4, with each meta-gradient being an average over the workers, each of which is itself an average over 8 tasks. For all models, we have an additional learner job (also a TPU chip) which averages meta-gradients and performs PES updates. We use Adam as the meta-optimizer in all experiments, with gradient clipping at 3.0 applied to each value of the gradient independently. Each training job has an additional 3 machines (15 for meta-generalization experiments), each with a single TPU chip, which perform evaluations by training a task with the current meta-parameters. The meta-training curves we report come from these machines. All meta-training in this work took 2-4 days per experiment.
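The per-value meta-gradient clipping described above, where each element is clipped independently at 3.0 rather than clipping by a global norm, amounts to:

```python
import numpy as np

def clip_meta_gradient(meta_grad, threshold=3.0):
    """Clip each element of the meta-gradient independently at +/- threshold
    (elementwise clipping, not clipping by global norm)."""
    return np.clip(meta_grad, -threshold, threshold)

clipped = clip_meta_gradient(np.array([-10.0, 0.5, 4.0]))
```

Elementwise clipping bounds the influence of any single coordinate while leaving well-behaved coordinates untouched, which can stabilize meta-training when a few PES estimates are noisy.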
d.2 §4.1 details: Optimizer overhead vs optimizer type
For each task and learned optimizer pair, we train 3 random seeds at 5 learning rates: [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], and show the best performing configuration. We use the best rather than the mean because the low dimensional hidden size of small_fc_lopt (4 hidden units) can result in unstable training. We found that a simple learning rate schedule improves meta-training stability, but for a fair comparison to other optimizers we do not use any schedule here.
d.3 §4.2 details: Input features experiment details
For these experiments we use a fixed learning rate, set to the best performing value found for fc_lopt in §D.2. Given the number of variations tried, we could not afford to search over learning rates for each configuration. For each configuration we run 3 random seeds.
grads_time_p: Just using parameter, gradient, and time step features.
m_0.1, m_0.5, m_0.9, m_0.99, m_0.999: Using parameter value, gradient, momentum with the listed decay value, and time step features.
m_all: Same as before but using multiple momentum values. In this case five values: 0.1, 0.5, 0.9, 0.99, 0.999.
m_mid2: Same as before but with 2 momentum values 0.5 and 0.9.
m_mid3: Same as before but with 3 momentum values 0.5, 0.9 and 0.99.
rms_0.1, rms_0.5, rms_0.9, rms_0.99, rms_0.999: Using parameter value, gradient, the second moment accumulator with the listed decay value, 1 over the sqrt of this feature, and time step features.
rms_all: Same as before but using multiple second moment accumulator values. In this case six values: 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999.
rms_mid2: Same as before but with 2 second moment values 0.9 and 0.99.
rms_mid4: Same as before but with 4 second moment values: 0.5, 0.9, 0.99, and 0.999.
m_rms_0.1, m_rms_0.5, m_rms_0.9, m_rms_0.99, m_rms_0.999: Using parameter value, gradient, the second moment accumulator with the listed decay value, 1 over the square root of this feature, momentum with the listed decay, the product of momentum and 1 over the square root of the second moment (similar to the Adam update), and time step features.
m_rms_all: Same as before but using multiple second moment and momentum accumulator values. In this case six values: 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999.
m_rms_mid2: Same as before but with 2 accumulator timescales: 0.9 and 0.99.
m_rms_mid4: Same as before but with 4 accumulator timescales: 0.5, 0.9, 0.99, and 0.999.
adafact: Using parameter values and 6 AdaFactor accumulators with decays (0.1, 0.5, 0.9, 0.99, 0.999, 0.9999), fed to the learned optimizer in the form of 3 transformations: 1 over the square root of the accumulator, this quantity multiplied by the gradient, and tilings of both of the low-rank accumulators.
adafact_m_mul: Same as before, but with 3 AdaFactor accumulators and 3 momentum accumulators with decays (0.5, 0.9, 0.99). In addition to the previous features, we also include the momentum values multiplied by the AdaFactor preconditioner.
union: The union of all features. This includes parameter value, gradient value, time features, all momentum and second moment accumulators (0.1, 0.5, 0.9, 0.99, 0.999, 0.9999), all features from AdaFactor computed with these same timescales, as well as multiplications of AdaFactor and momentum features.
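The factored accumulators behind the AdaFactor features above can be sketched as follows, following the AdaFactor construction. This is a simplified illustration; the epsilon handling and exact feature ordering are our assumptions.

```python
import numpy as np

def adafactor_features(grad, row_acc, col_acc, decay):
    """Update factored second-moment accumulators for a 2D gradient and
    return the AdaFactor-style features described above (a sketch).

    row_acc: [rows] running mean of squared gradients over columns.
    col_acc: [cols] running mean of squared gradients over rows.
    """
    g2 = grad ** 2
    row_acc = decay * row_acc + (1 - decay) * g2.mean(axis=1)
    col_acc = decay * col_acc + (1 - decay) * g2.mean(axis=0)
    # Rank-1 reconstruction of the full second moment estimate.
    v = np.outer(row_acc, col_acc) / row_acc.mean()
    rsqrt_v = 1.0 / np.sqrt(v + 1e-30)
    features = {
        "rsqrt": rsqrt_v,                 # 1 over sqrt of the accumulator
        "precond_grad": grad * rsqrt_v,   # preconditioner times gradient
        "row_tiled": np.broadcast_to(row_acc[:, None], grad.shape),
        "col_tiled": np.broadcast_to(col_acc[None, :], grad.shape),
    }
    return features, row_acc, col_acc

# Usage sketch: one update with decay 0.9 on a uniform gradient.
feats, row_acc, col_acc = adafactor_features(
    np.ones((2, 3)), np.zeros(2), np.zeros(3), decay=0.9)
```

Storing only row and column accumulators makes the memory cost O(rows + cols) rather than O(rows x cols), which is why these features are attractive for low-overhead learned optimizers.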
d.4 §4.3 details: Large scale overhead timings
We use the ResNet18 and ResNet50 implementations from Haiku [haiku2020github] with the V2 flag set to true.
For transformers, we use a vocab size of 256 to emulate byte-level training, a hidden size of 768, 6 layers, and 12 self-attention heads per layer. The dense layers use a 4x widening factor.
Appendix E Extended meta-generalization experiments
e.1 Experimental details
Over the course of training a learned optimizer on a particular task, we monitor performance on a variety of held out tasks described here.
Fashion MNIST MLP: This is a 2-hidden-layer, 128-unit MLP trained on Fashion MNIST. Source code can be found at https://github.com/google/learned_optimization/blob/32c4f21ec238a12756afe70e3d699017ea938f5d/learned_optimization/tasks/fixed/image_mlp.py#L35.
CIFAR-10 ConvNet: This is a ConvNet with 3 hidden layers trained on 16x16 CIFAR-10. It starts with a 32-channel, stride-2 convolution, followed by two 64-channel, stride-1 convolutions. Average pooling is then performed before linearly mapping to the number of output classes. An implementation can be found at https://github.com/google/learned_optimization/blob/78f25e8f1e9c6236a1f559b7b0b36859c59542d2/learned_optimization/tasks/fixed/conv.py#L86
Imagenet MLP: This is an MLP operating on 16x16 resized imagenet images. The network has 3 hidden layers, of size 256. An implementation can be found at https://github.com/google/learned_optimization/blob/aa15091066aa5b3f45e6b7f4bee1c41fb7d467a0/learned_optimization/tasks/fixed/image_mlp.py#L94.
Auto Encoder: This is an autoencoder trained on CIFAR-10 with mean squared error. The network consists of 3 hidden layers with sizes 128, 32, and 128. A full implementation can be found at https://github.com/google/learned_optimization/blob/aa15091066aa5b3f45e6b7f4bee1c41fb7d467a0/learned_optimization/tasks/fixed/image_mlp_ae.py#L101.
LSTM language modeling: This is a language model trained on the One Billion Word Benchmark [DBLP:journals/corr/ChelbaMSGBK13]. The text is tokenized as bytes and sliced into length-32 sequences. The model embeds tokens with a 64-dimensional lookup table, followed by a size-128 LSTM tasked with predicting the next token. See https://github.com/google/learned_optimization/blob/aa15091066aa5b3f45e6b7f4bee1c41fb7d467a0/learned_optimization/tasks/fixed/rnn_lm.py#L142 for the implementation.
e.2 Additional figures
In this section we provide additional experimental results for meta-generalization similar to §4.5. First, in Figure 8 we show additional performance measurements on more held out tasks. As in §4.5, we see poor meta-generalization and high variability.
We plot evaluations over the course of meta-training for each learned optimizer type and multiple random seeds: Figure 9 shows meta-training on the Fashion MNIST MLP, and Figure 10 the CIFAR-10 ConvNet. When meta-training and evaluating on the same distribution, we find stable evaluation loss. When evaluating on other kinds of tasks, we see wide variability in performance across architectures, and even among different initializations of the learned optimizer weights holding all else fixed. In some cases, the learned optimizer alternates between making progress on the target task and diverging, as shown by the rapid spikes in the meta-loss.
Finally, we show an alternative view of the same data, plotting meta-evaluation performance against meta-training performance. Each learning rate and each seed is shown in a separate pane. We show small_fc_lopt in Figure 11, rnn_fc_lopt in Figure 12, fc_lopt in Figure 13, and nn_adam in Figure 14. Once again we find high variability across architecture, learning rate, and random seed. In these figures, meta-overfitting appears as a "c"-shaped curve: meta-training performance continues to improve, but meta-evaluation performance gets worse after some point. Jagged lines and instability suggest that performance is highly sensitive on the evaluation task.
Appendix F NAdamW update equations and search space
For our NAdamW baseline, we use the same implementation and search space described in metz2020using. We repeat the functional form here for convenience.
f.1 Update equations
This optimizer architecture has 10 hyperparameters: the base learning rate, α; the first and second moment momentum decays, β₁ and β₂; the numerical stability term, ε; the ℓ2 regularization strength; the AdamW-style weight decay; and a boolean, b, to switch between NAdam and Adam. The learning rate schedule is based on a single-cycle cosine decay with a warmup. It is controlled by 3 additional parameters: the warmup fraction, the minimum learning rate multiplier, and the fraction of training over which the learning rate is held constant.
The learning rate is defined by:
The update equations of NAdamW follow.
φ₀ = problem specified random initialization   (10)
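These hyperparameters enter the update in the standard NAdam/AdamW form, sketched below in our own notation; this is a sketch only, and the exact presentation (including how the ℓ2 term enters the gradient) is given in metz2020using.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t) \\
u_t &= \begin{cases}
  \left(\beta_1 \hat{m}_t + \dfrac{1 - \beta_1}{1 - \beta_1^t}\, g_t\right) \Big/ \left(\sqrt{\hat{v}_t} + \epsilon\right) & \text{if } b \text{ (NAdam)} \\[4pt]
  \hat{m}_t \Big/ \left(\sqrt{\hat{v}_t} + \epsilon\right) & \text{otherwise (Adam)}
\end{cases} \\
\phi_{t+1} &= \phi_t - \alpha_t \left(u_t + \lambda_{\mathrm{wd}}\, \phi_t\right)
\end{aligned}
```

Here g_t is the gradient (which incorporates any ℓ2 regularization term through the loss), while the weight decay λ_wd is applied in decoupled, AdamW style, and α_t follows the warmup-plus-cosine schedule.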
f.2 Hyperparameter search space
The initial learning rate, α₀, is sampled from log space. The momentum parameters β₁ and β₂ are sampled logarithmically (via log(1 − β)), and ε is sampled logarithmically. We sample using Nesterov momentum (b = 1) 50% of the time. We sample the ℓ2 regularization strength and the weight decay coefficient logarithmically. With equal probability (one third each), we either use both regularization terms, zero out the ℓ2 term, or zero out the weight decay term. With 50% probability we use a nonzero minimum learning rate multiplier, sampled logarithmically. With 50% probability we sample the warmup fraction between 1e-5 and 1e-1; otherwise it is set to zero. Finally, we uniformly sample the fraction of training over which the learning rate is held constant between 0 and 1.
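The sampling procedure can be sketched as follows. Only a subset of the hyperparameters is shown; the numeric ranges passed as defaults below (e.g. lr_range and the minimum learning rate multiplier range) are hypothetical placeholders, while the one-third split, the 50% coin flips, the warmup-fraction range, and the uniform constant fraction follow the description above.

```python
import numpy as np

def log_uniform(rng, low, high):
    """Sample log-uniformly between low and high."""
    return float(np.exp(rng.uniform(np.log(low), np.log(high))))

def sample_nadamw_config(rng, lr_range=(1e-8, 1e1)):
    """Sketch of the NAdamW search-space sampler. lr_range and the
    min-lr-multiplier range are hypothetical placeholders."""
    cfg = {
        "learning_rate": log_uniform(rng, *lr_range),
        "nesterov": rng.random() < 0.5,  # use NAdam 50% of the time
    }
    # One third each: both terms, zero out l2, or zero out weight decay.
    which = rng.integers(3)
    cfg["use_l2"] = which in (0, 1)
    cfg["use_weight_decay"] = which in (0, 2)
    # 50%: nonzero min learning rate multiplier (range hypothetical).
    cfg["min_lr_mult"] = (log_uniform(rng, 1e-4, 1.0)
                          if rng.random() < 0.5 else 0.0)
    # 50%: warmup fraction in [1e-5, 1e-1] (log-uniform assumed), else zero.
    cfg["warmup_frac"] = (log_uniform(rng, 1e-5, 1e-1)
                          if rng.random() < 0.5 else 0.0)
    # Fraction of training the learning rate is held constant: uniform [0, 1].
    cfg["constant_frac"] = float(rng.uniform(0.0, 1.0))
    return cfg

cfg = sample_nadamw_config(np.random.default_rng(0))
```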