The learning rate is often regarded as the single most important hyper-parameter to tune, and it strongly influences model training with gradient descent algorithms Bengio (2012); Goodfellow et al. (2016). Researchers have developed several learning rate schedules, such as linear decay, cosine decay, exponential decay, and inverse square root decay, sometimes with warm-up steps, for different optimization problems Schaul et al. (2013); Zeiler (2012). However, there is limited intuition about which learning rate schedule best suits a given problem. In practice, researchers adopt a trial-and-error approach over different learning rate schedules and other hyper-parameters, which is very time consuming Bergstra and Bengio (2012). In this paper, we aim to automatically learn a controller that adapts the learning rate schedule by incorporating information from past training dynamics. In addition, current learning rate schedules assume predefined parametric learning rate changes, which are fixed regardless of the actual training dynamics. The optimization landscape can be very complex Li et al. (2018), so these parametric schedules have limited flexibility and may not be well suited to the training dynamics of different high-dimensional, non-convex optimization problems. In comparison, our framework offers an auto-learned (meta-learned) learning rate schedule that adapts dynamically to the current training dynamics.
There are several related works proposing better update schedules for gradient descent algorithms. Andrychowicz et al. (2016) propose to directly learn the gradient descent updates using a long short-term memory (LSTM) network. Our work learns only the learning rate and is therefore more efficient. Hypergradient descent takes the derivative of the objective with respect to the learning rate and updates the learning rate based on this gradient Baydin et al. (2017). In addition to the current state, our approach also considers the entire training history and therefore has a more comprehensive view. Daniel et al. (2016) propose to use reinforcement learning (RL) to adapt the learning rate. In comparison, we use validation loss as the reward signal and a learning rate scaling function as the action. These choices improve generalization capability and stability. There is also a family of widely used optimizers that dynamically adapt the learning rate on a per-parameter basis. For example, Adagrad adapts the learning rate per weight based on the sums of the squares of the gradients Duchi et al. (2011), while Adam uses exponentially decayed averages of past gradients and squared gradients Kingma and Ba (2015). However, these optimizers still require a global learning rate, which remains important to tune. Our work is complementary to these works.
This paper makes three main contributions. First, we propose a reinforcement learning based framework to automatically learn an adaptive learning rate schedule based on past training histories. This schedule can adjust the learning rate dynamically to adapt to current training dynamics. Second, we present an effective set of state observation features, reward functions, and actions for the learning rate decision problem. Specifically, unlike previous work, we use validation loss as the reward signal and a learning rate scaling function as the action. Third, we conduct experiments on the Fashion MNIST and CIFAR10 datasets with convolutional neural networks (CNN) LeCun et al. (1998) and residual networks (ResNet) He et al. (2016a, b) to show the effectiveness and generalization capability of our framework. The auto-learned learning rate schedule achieves better results and generalizes to different datasets.
2 An Auto-learned Adaptive Learning Rate
2.1 Controller and Trainee Network
In our framework, we use RL to train a learning rate controller, which proposes learning rates using features from the training dynamics of the trainee network. The trainee network is trained for a certain number of steps using the proposed learning rate and then reports its observed training dynamics to the controller, which returns a new learning rate. This process repeats until a stopping criterion is reached.
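The interaction described above can be sketched as a simple loop. This is only an illustration under stated assumptions: `ToyTrainee`, `decay_controller`, and the one-dimensional quadratic objective are hypothetical stand-ins, not the networks or the policy used in the paper, and the stopping criterion is assumed to be a fixed step budget.

```python
import random


class ToyTrainee:
    """Hypothetical stand-in for a trainee network: minimizes f(w) = w^2 with noisy SGD."""

    def __init__(self, w0=5.0, seed=0):
        self.w = w0
        self.rng = random.Random(seed)

    def train(self, lr, num_steps):
        # Run a fixed number of noisy gradient steps with the proposed learning rate.
        for _ in range(num_steps):
            grad = 2.0 * self.w + self.rng.gauss(0.0, 0.1)
            self.w -= lr * grad

    def observe(self):
        # Report training dynamics; only the loss here (the paper uses a richer feature set).
        return {"loss": self.w ** 2}


def decay_controller(obs, prev_lr):
    # Placeholder policy (assumption): a trained RL policy would be called here.
    return prev_lr * 0.9


def run_episode(controller, trainee, initial_lr, total_steps, steps_per_action):
    """Alternate trainee training and controller decisions until the step budget is used."""
    lr = initial_lr
    history = []
    for _ in range(total_steps // steps_per_action):
        trainee.train(lr, steps_per_action)
        obs = trainee.observe()
        history.append((lr, obs["loss"]))
        lr = controller(obs, lr)  # the controller returns the next learning rate
    return history


history = run_episode(decay_controller, ToyTrainee(), initial_lr=0.1,
                      total_steps=100, steps_per_action=10)
```

Each entry of `history` pairs the learning rate used for a block of steps with the loss observed afterwards, which is the kind of trajectory a controller would be trained on.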
2.2 State Observation, Reward, Action
Observation. In order to characterize the training dynamics, we design a set of state observation features, which extends the features used in Daniel et al. (2016). The proposed features include the current train loss, validation loss, variance of network predictions, variance of network prediction changes, mean and variance of the weight matrix of the final dense layer, and the previous step learning rate. As a general principle, we want our features to be easy to compute and generalizable to different trainee model architectures. This is why we only use the moments of the weight matrix of the final dense layer instead of all layers.
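A minimal sketch of how such an observation vector might be assembled is given below. Several details are assumptions not stated in the text: predictions and weights are passed as flat lists of floats, "prediction changes" are taken as the element-wise difference between two consecutive observations on the same inputs, and population variance is used. The function name `build_observation` is hypothetical.

```python
from statistics import fmean, pvariance


def build_observation(train_loss, val_loss, preds, prev_preds, final_weights, prev_lr):
    """Assemble the seven-dimensional state vector from flat lists of floats."""
    # Assumed definition: change of each prediction since the previous observation.
    changes = [p - q for p, q in zip(preds, prev_preds)]
    return [
        train_loss,                # current train loss
        val_loss,                  # current validation loss
        pvariance(preds),          # variance of network predictions
        pvariance(changes),        # variance of prediction changes
        fmean(final_weights),      # mean of the final dense layer weights
        pvariance(final_weights),  # variance of the final dense layer weights
        prev_lr,                   # previous step learning rate
    ]


observation = build_observation(0.5, 0.6, [0.0, 1.0], [0.0, 0.0], [1.0, 3.0], 0.01)
```

Because only scalar summaries are used, the vector has the same length for any trainee architecture, which matches the stated goal of generalizable features.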
Reward. We use the per-step validation loss as the reward. Even though the final validation loss is what we really care about, we find empirically that providing a reward at each training step achieves better results than using only the final reward. These intermediate rewards provide more direct feedback and make credit assignment much easier.
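The difference between per-step and final-only rewards can be illustrated with discounted returns. The sketch below rests on assumptions that are not given in the text: the reward is taken as the negative validation loss (so lower loss means higher reward), and the discount factor `gamma` is an arbitrary illustrative value. Both function names are hypothetical.

```python
def rewards_from_val_losses(val_losses, per_step=True):
    """Turn validation losses into rewards (assumed sign convention: reward = -loss)."""
    if per_step:
        return [-loss for loss in val_losses]
    # Final-only variant: zero reward everywhere except the last decision.
    return [0.0] * (len(val_losses) - 1) + [-val_losses[-1]]


def discounted_returns(rewards, gamma=0.99):
    """Compute G_t = r_t + gamma * G_{t+1} by scanning the rewards backwards."""
    returns = []
    running = 0.0
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    return returns[::-1]


val_losses = [1.0, 0.5, 0.25]
dense = discounted_returns(rewards_from_val_losses(val_losses), gamma=0.5)
sparse = discounted_returns(rewards_from_val_losses(val_losses, per_step=False), gamma=0.5)
```

With per-step rewards, every decision's return reflects the losses that immediately follow it, whereas in the final-only variant early decisions see only a heavily discounted copy of the last loss.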
Action. The most direct action for the controller would be to propose a new learning rate. However, the learning rate is very sensitive and can be very small in magnitude, so directly using the network output as the learning rate would be very unstable. Another choice of action would be to propose the log of the learning rate. However, we want our action to generalize across data sets, which may require different learning rate scales. Instead, we propose a learning rate scaling action. At the first step, we provide a default learning rate. In the following steps, we use the network output as a scaling factor for the previous step's learning rate, which can scale it up or down. In this way, the controller can provide both warm-up and decay capabilities in a stable way. This action provides a better inductive bias, keeping learning consistent across steps.
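The scaling action amounts to a multiplicative update of the previous learning rate. The sketch below is illustrative: clipping the scaling factor to `[min_scale, max_scale]` is an added assumption for stability, and the bounds, the default learning rate, and the sequence of scales are hypothetical values rather than settings from the paper.

```python
def next_learning_rate(prev_lr, scale, min_scale=0.5, max_scale=2.0):
    """Multiply the previous learning rate by a clipped scaling factor."""
    clipped = min(max(scale, min_scale), max_scale)  # assumed safeguard, not from the paper
    return prev_lr * clipped


def rollout_schedule(default_lr, scales):
    """Build a schedule from a default learning rate and a sequence of scaling actions."""
    learning_rates = [default_lr]
    for scale in scales:
        learning_rates.append(next_learning_rate(learning_rates[-1], scale))
    return learning_rates


# Scales above 1 act as warm-up, scales below 1 act as decay.
schedule = rollout_schedule(0.01, [2.0, 2.0, 1.0, 0.5, 0.5, 0.5])
```

Since the action is relative to the previous value, the same sequence of scales produces the same schedule shape for any default learning rate, which is what makes the action transferable across learning rate scales.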
2.3 Auto-learned Adaptive Learning Rate Schedule with Proximal Policy Optimization
In this section, we present the RL-based learning rate controller, which is trained to propose a better learning rate schedule in order to reduce the validation loss in the model training process. Since the reward is the validation loss and it is non-differentiable, we use policy gradients, and in particular the proximal policy optimization (PPO) algorithm Schulman et al. (2017, 2015), to learn the controller parameters, as it has better sample complexity. PPO optimizes a clipped surrogate objective function:
$$L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right]$$
to penalize large policy updates. Here $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)$ is the importance weighting probability ratio, $\hat{A}_t$ is the advantage function at time $t$, and $\epsilon$ is a hyper-parameter.
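The clipped surrogate objective can be evaluated directly from log-probabilities and advantage estimates. The following is a minimal pure-Python sketch of the standard PPO objective; the value `epsilon=0.2` is a commonly used default and an assumption here, since the text does not state the value used, and the function name is hypothetical.

```python
import math


def clipped_surrogate(new_log_probs, old_log_probs, advantages, epsilon=0.2):
    """Average clipped surrogate objective over a batch of time steps."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # importance weighting probability ratio
        clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
        # Taking the minimum removes any incentive to move the ratio outside the clip range.
        total += min(ratio * adv, clipped_ratio * adv)
    return total / len(advantages)
```

For a positive advantage the objective stops growing once the ratio exceeds `1 + epsilon`, and for a negative advantage it stops growing once the ratio falls below `1 - epsilon`, which is how large policy updates are penalized.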
3 Experiment Setting
3.1 Data Set and Model Architecture
In our experiments, we use two data sets: Fashion MNIST Xiao et al. (2017) and CIFAR10 Krizhevsky and Hinton (2010). For Fashion MNIST, we use 50k, 10k, and 10k images as the training, validation, and test sets. For CIFAR10, we use 40k, 10k, and 10k images as the training, validation, and test sets. We test on two model architectures: CNN and ResNet. The CNN follows the structure of LeNet LeCun et al. (1998). The ResNet follows the structures from He et al. (2016a, b). We use the open source TensorFlow implementation Tensorflow (2019).
3.2 Baseline Learning Rate Schedule
We use a popular step decay learning rate schedule as our baseline Ge et al. (2019), which is used in the open source ResNet implementation Tensorflow (2019). It contains three components: initial learning rate, discount step, and discount factor. The baseline schedule starts from the initial learning rate, which is then multiplied by the discount factor once every discount step. In the baseline experiments, we test all combinations of the initial learning rate, the discount step, and the discount factor over predefined sets of values. After choosing the best baseline schedule, we run it 10 times with the same set of hyper-parameters and report the mean and standard deviation of test loss and accuracy.
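The baseline schedule and the exhaustive search over its three components can be sketched as follows. The candidate values in the grid are placeholders chosen purely for illustration; they are not the values used in the experiments, and the function names are hypothetical.

```python
from itertools import product


def step_decay_lr(step, initial_lr, discount_step, discount_factor):
    """Learning rate at a given training step under step decay."""
    # The rate is multiplied by the discount factor once per completed discount_step block.
    return initial_lr * discount_factor ** (step // discount_step)


def baseline_grid(initial_lrs, discount_steps, discount_factors):
    """Enumerate every combination of the three schedule components."""
    return list(product(initial_lrs, discount_steps, discount_factors))


# Placeholder candidate values (assumption), for illustration only.
grid = baseline_grid([0.1, 0.01], [100, 200], [0.5, 0.9])
example_lr = step_decay_lr(250, initial_lr=0.1, discount_step=100, discount_factor=0.5)
```

Each tuple in `grid` defines one baseline run; the best configuration would then be selected and repeated, as described above.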
| Dataset | Model | Baseline Test Loss | Baseline Test Accuracy | Auto-learned Test Loss | Auto-learned Test Accuracy |
|---|---|---|---|---|---|
| Fa. MNIST | CNN | 0.2497 (0.0042) | 0.9102 (0.0019) | 0.2351 (0.0038) | 0.9201 (0.0022) |
| Fa. MNIST | ResNet | 0.2346 (0.0074) | 0.9188 (0.0029) | 0.2296 (0.0069) | 0.9192 (0.0028) |
| CIFAR10 | CNN | 0.9539 (0.0140) | 0.6759 (0.0048) | 0.9361 (0.0104) | 0.6787 (0.0041) |
| CIFAR10 | ResNet | 0.8317 (0.0155) | 0.7395 (0.0206) | 0.6288 (0.0196) | 0.8181 (0.0069) |
Table 1: Performance comparison between baseline and auto-learned learning rate schedules. We report the mean and standard deviation of the test set loss and accuracy derived from 10 runs with the same set of hyper-parameters. Parentheses denote standard deviations. We bold the best numbers for each model/dataset pair. An asterisk denotes that the improvement is statistically significant under an independent two-sample t-test with p-value threshold. Fa. stands for Fashion.
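For readers who want to relate the reported means and standard deviations to the significance test, the t statistic of an independent two-sample test can be computed from summary statistics alone. The sketch below assumes that the values in parentheses are sample standard deviations over 10 runs per schedule; with equal sample sizes the pooled-variance and Welch statistics coincide, so only the degrees of freedom depend on the variant. The function name is hypothetical, and the p-value threshold is not reproduced here because it is not given above.

```python
import math


def two_sample_t(mean_a, std_a, mean_b, std_b, n):
    """t statistic for two independent samples of equal size n."""
    # Standard error of the difference of means for equal sample sizes.
    standard_error = math.sqrt((std_a ** 2 + std_b ** 2) / n)
    return (mean_a - mean_b) / standard_error


# CIFAR10 / ResNet test loss from the table: 0.8317 (0.0155) vs 0.6288 (0.0196), 10 runs each.
t_stat = two_sample_t(0.8317, 0.0155, 0.6288, 0.0196, n=10)
degrees_of_freedom = 2 * 10 - 2  # for the pooled-variance (Student) variant
```

The resulting statistic would then be compared against the t distribution with the stated degrees of freedom at whichever threshold is chosen.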
3.3 Training Details
We use validation loss as the reward signal and compare the hold-out test set loss and accuracy between baseline learning rate schedules and our auto-learned schedules. For trainee network training, the initial learning rate is chosen from a predefined set of values, the training batch size is 1k, and we run 1k training steps for all model architectures on all data sets. With the 50k-image Fashion MNIST training set, 50 training steps equal one epoch; with the 40k-image CIFAR10 training set, 40 training steps equal one epoch. In total, we therefore train for 20 epochs on Fashion MNIST and 25 epochs on CIFAR10. We report the number of training steps in subsequent plots. In our experiments, the controller proposes a new learning rate every 10 training steps. We refer to the whole 1k training steps of the trainee network as one training episode. The controller network is trained after every training episode of the trainee network. The actor of the controller is a multilayer perceptron (MLP) that contains one hidden layer of size 32. We also test an LSTM actor with hidden size 32. The critic is an MLP that contains one hidden layer of size 32. The learning rate is 0.001 for the actor and 0.005 for the critic. Note that the main goal of our experiments is to show the effectiveness of the auto-learned learning rate schedule rather than to outperform the state-of-the-art classification accuracy on the target task. In our experiments, we restrict training to at most 25 epochs and use a ResNet with 18 total layers for computational reasons.
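The relationship between training steps, epochs, and controller decisions follows from simple arithmetic on the numbers above. The short sketch below just makes that bookkeeping explicit; the function and variable names are hypothetical.

```python
def steps_per_epoch(train_size, batch_size):
    """Number of training steps needed to see the training set once."""
    return train_size // batch_size


def epochs_trained(total_steps, train_size, batch_size):
    """Number of epochs covered by a fixed budget of training steps."""
    return total_steps / steps_per_epoch(train_size, batch_size)


BATCH_SIZE = 1_000
TOTAL_STEPS = 1_000
STEPS_PER_ACTION = 10

fashion_epochs = epochs_trained(TOTAL_STEPS, 50_000, BATCH_SIZE)  # Fashion MNIST, 50k training images
cifar_epochs = epochs_trained(TOTAL_STEPS, 40_000, BATCH_SIZE)    # CIFAR10, 40k training images
decisions_per_episode = TOTAL_STEPS // STEPS_PER_ACTION           # controller actions per episode
```

Under these settings the controller makes 100 learning rate decisions per training episode on either data set.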
4 Experiment Results
4.1 Test Set Results And Training Trajectories
Table 1 shows the performance comparison between baseline and auto-learned learning rate schedules on test loss and accuracy. For each run, we select the checkpoint with the best validation loss and evaluate it on the test set once to obtain that run's test loss and accuracy. From this table, we can see that the auto-learned learning rate schedule achieves better results on all tasks. We hypothesize that this is because it does not follow predefined parametric learning rate changes and has greater flexibility to adapt the learning rate based on the training dynamics. Figure 1 shows the training dynamics comparison between baseline and auto-learned learning rate schedules in terms of the log of validation loss. For the CNN on Fashion MNIST and the ResNet on CIFAR10, shown in Figure 1(a) and Figure 1(d), the auto-learned schedule reaches a lower validation loss faster.
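Selecting a checkpoint by validation loss, as described above, reduces to an argmin over the recorded validation losses. The sketch below is a minimal illustration; the recorded losses are made-up numbers and `select_checkpoint` is a hypothetical helper name.

```python
def select_checkpoint(val_loss_by_step):
    """Return the training step whose checkpoint has the lowest validation loss."""
    return min(val_loss_by_step, key=val_loss_by_step.get)


# Made-up validation losses keyed by training step, for illustration only.
val_loss_by_step = {100: 0.41, 200: 0.33, 300: 0.29, 400: 0.31}
best_step = select_checkpoint(val_loss_by_step)
```

Only the selected checkpoint is evaluated on the test set, so the test data never influences which checkpoint is chosen.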
4.2 Auto-learned Learning Rate Schedule
In this section, we present the auto-learned learning rate schedules shown in Figure 2. For the CNN model on Fashion MNIST in Figure 2(a) and the ResNet model on CIFAR10 in Figure 2(d), the controller learns to warm up first and then decay the learning rate. This matches the human intuition behind some existing learning rate schedules, but it is learned automatically and does not follow a predefined trajectory. For the ResNet model on Fashion MNIST in Figure 2(b), the auto-learned learning rate schedule is similar to an exponential decay learning rate schedule. For the CNN model on CIFAR10 in Figure 2(c), the auto-learned learning rate stays flat at first, then warms up and later decays.
We also investigate the transferability of a trained controller network when applied to different datasets. In these experiments, we load the controller checkpoints trained with the CNN and ResNet models on CIFAR10 and let them propose the learning rate schedule for the CNN and ResNet models on Fashion MNIST, respectively, without any further training of the learning rate controller. For comparison, we also apply the best baseline learning rate schedules of the CIFAR10 models to Fashion MNIST. From Table 2, we can see that the trained controller is transferable between the two data sets. Our learning rate controller does not simply memorize a schedule, but learns a transferable procedure for tuning the learning rate that generalizes.
| Model | Transferred Baseline Test Loss | Transferred Baseline Test Accuracy | Transferred Controller Network Test Loss | Transferred Controller Network Test Accuracy |
|---|---|---|---|---|
| CNN | 0.2730 (0.0031) | 0.9021 (0.0013) | 0.2598 (0.0071) | 0.9074 (0.0030) |
| ResNet | 0.2443 (0.0040) | 0.9166 (0.0025) | 0.2315 (0.0072) | 0.9212 (0.0034) |
5 Conclusion
In this paper, we propose a reinforcement learning based framework that can auto-learn an adaptive learning rate schedule based on information from past training histories. To achieve this goal, we introduce an effective set of features to characterize the dynamic training process, a meaningful reward function, an action space, and a sample-efficient RL algorithm to adapt the learning rate dynamically. Experimental results on the Fashion MNIST and CIFAR10 data sets with CNN and ResNet models show that our framework can learn a better learning rate schedule than step decay baseline schedules and can be transferred to new datasets never seen during meta-training.
References
- Learning to learn by gradient descent by gradient descent. In Advances in Neural Information Processing Systems, pp. 3981–3989. Cited by: §1.
- Online learning rate adaptation with hypergradient descent. arXiv preprint arXiv:1703.04782. Cited by: §1.
- Practical recommendations for gradient-based training of deep architectures. Neural networks: Tricks of the trade, pp. 437–478. Cited by: §1.
- Random search for hyper-parameter optimization. Journal of Machine Learning Research 13 (Feb), pp. 281–305. Cited by: §1.
- Learning step size controllers for robust neural network training. In Thirtieth AAAI Conference on Artificial Intelligence, Cited by: §1, §2.2.
- Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research 12 (Jul), pp. 2121–2159. Cited by: §1.
- The step decay schedule: a near optimal, geometrically decaying learning rate procedure. arXiv preprint arXiv:1904.12838. Cited by: §3.2.
- Deep learning. MIT press. Cited by: §1.
- Deep residual learning for image recognition. In , pp. 770–778. Cited by: §1, §3.1.
- Identity mappings in deep residual networks. In European conference on computer vision, pp. 630–645. Cited by: §1, §3.1.
- Adam: a method for stochastic optimization. In The International Conference on Learning Representations (ICLR), Cited by: §1.
- Convolutional deep belief networks on cifar-10. Unpublished manuscript 40 (7). Cited by: §3.1.
- Gradient-based learning applied to document recognition. Proceedings of the IEEE 86 (11), pp. 2278–2324. Cited by: §1, §3.1.
- Visualizing the loss landscape of neural nets. In Advances in Neural Information Processing Systems, Cited by: §1.
- No more pesky learning rates. In International Conference on Machine Learning, pp. 343–351. Cited by: §1.
- Trust region policy optimization. In International Conference on Machine Learning, pp. 1889–1897. Cited by: §2.3.
- Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347. Cited by: §2.3.
- ResNet in tensorflow. Note: From https://github.com/tensorflow/models/tree/master/official/resnet Cited by: §3.1, §3.2.
- Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747. Cited by: §3.1.
- ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701. Cited by: §1.