1 Introduction
Machine learning models are often used at test-time subject to constraints and trade-offs not present at training-time. A computer vision model operating on an embedded device may need to perform real-time inference; a translation model operating on a cell phone may wish to bound its average compute time in order to be power-efficient. In these cases, there is often a tension between satisfying the constraint and achieving acceptable model performance. These constraints need not be restricted to speed and accuracy, but can reflect preferences for model simplicity or other desiderata.
One way to deal with constraints is to build them into models explicitly at training time. This has two major disadvantages: First, it requires manually designing and retraining a new model for each use case. Second, it does not permit adjusting constraints at test-time in an input-dependent way.
In this work, we describe a method to change model behavior at test-time on a per-input basis. This method involves two components: The first is a model that we call a Composer. A Composer adaptively constructs computation graphs from sub-modules on a per-input basis using a controller trained with reinforcement learning to examine intermediate activations. The second is the notion of Policy Preferences, which allow test-time modifications of the controller policy.
This technique has several benefits: First, it allows for dynamically adjusting for constraints at inference time with a single trained model. Second, the Composer model can ‘smartly’ adjust for constraints in the way that is best for model performance (e.g. it can decide to use less computation on simpler inputs). Finally, it provides some interpretability, since we can see which examples used which resources at test-time.
2 The Composer Model
The Composer Model (depicted in Figure 1) consists of a set of modules and a controller network. The modules are neural networks and are organized into ‘metalayers’. At each metalayer in the network, the controller selects which module will be applied to the activations from the previous metalayer (see also Eigen et al. (2013); Bengio et al. (2015); Andreas et al. (2016); Shazeer et al. (2017); Fernando et al. (2017); Denoyer & Gallinari (2014)).
Specifically, let the $i$th metalayer of the network be denoted $f_i$, with $f_0$ a special book-keeping layer, called the ‘stem’. For $i > 0$, the $i$th metalayer is composed of $m_i$ functions, $f_{i,1}, \ldots, f_{i,m_i}$, that represent the individual modules in the metalayer. Note that these modules can differ in terms of architecture, number of parameters, or other characteristics. There are $n$ metalayers, so the depth of the network is $n + 1$ (including the stem). Once a selection of modules is made, it defines a neural network, which is trained with SGD.
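The per-input graph construction described above can be sketched in a few lines. This is only an illustrative toy, not the authors' implementation: `composer_forward`, the scalar "modules", and the deterministic stand-in controllers are all hypothetical names and values chosen so that the selected path is easy to follow.

```python
import random


def composer_forward(x, stem, metalayers, controllers, rng):
    """Run one input through a Composer-style network.

    stem: callable applied to the raw input.
    metalayers: list of metalayers, each a list of module callables.
    controllers: one callable per metalayer, mapping the current
        activations to a list of module probabilities.
    Returns the final activations and the sampled module choices.
    """
    activations = stem(x)
    choices = []
    for modules, controller in zip(metalayers, controllers):
        probs = controller(activations)
        # Sample which module to apply at this metalayer.
        choice = rng.choices(range(len(modules)), weights=probs, k=1)[0]
        activations = modules[choice](activations)
        choices.append(choice)
    return activations, choices


# Toy scalar "network": a stem followed by two metalayers of two modules each.
stem = lambda x: x + 1.0
metalayers = [
    [lambda a: a * 2.0, lambda a: a * 3.0],
    [lambda a: a - 1.0, lambda a: a + 10.0],
]
# Deterministic stand-in controllers (all mass on one module per metalayer).
controllers = [lambda a: [0.0, 1.0], lambda a: [1.0, 0.0]]

output, path = composer_forward(1.0, stem, metalayers, controllers, random.Random(0))
```

Here the stem maps 1.0 to 2.0, the first metalayer applies its second module (giving 6.0), and the second metalayer applies its first module (giving 5.0), so `path` is `[1, 0]`. In a real model the controllers would be learned networks and the modules would be neural networks.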
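The controller is composed of $n$ functions, $g_1, \ldots, g_n$, each of which outputs a policy distribution over the $m_i$ modules in the corresponding metalayer. In equations:
\begin{align}
a_0 &= f_0(x) \tag{1} \\
p_i &= g_i(a_{0:i-1}, c_{1:i-1}) \tag{2} \\
c_i &\sim \mathrm{Categorical}(p_i) \tag{3} \\
a_i &= f_{i,c_i}(a_{i-1}) \tag{4} \\
p(y \mid x, c_{1:n}) &= a_n \tag{5}
\end{align}
where $x$ is the input, $y$ is the ground truth labeling associated with $x$, $a_i$ are the network activations after the $i$th metalayer, $p_i$ are the parameters of the probability distribution output by the controller at step $i$, $c_i$ is the module choice made by sampling from the controller’s distribution at step $i$, $c_{1:i-1}$ denotes $c_1, \ldots, c_{i-1}$, and similarly $a_{0:i-1}$ denotes $a_0, \ldots, a_{i-1}$.
3 Policy Preferences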
We can augment the reward function of the controller to express preferences we have about its policy over computation graphs. Implemented naively, this does not allow for test-time modification of the preferences. Instead, we add a cost $C(c, \gamma)$ to the reward for the controller, where $\gamma$ is a preference value that can be changed on either a per-input or per-mini-batch basis.
Crucially, $\gamma$ is given as an input to the controller. At train-time, we sample $\gamma \sim \Gamma$, where $\Gamma$ is a distribution over preferences. Because of this, the controller learns at training time to modify its policy for different settings of $\gamma$. At test time, $\gamma$ can be changed to correspond to changing preferences. The specific dependence of $C$ on the action $c$ and the preference $\gamma$ can vary: $\gamma$ need not be a single scalar value.
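As a rough illustration of how a preference can be fed to the controller and folded into its reward, the sketch below uses a hypothetical linear controller over two modules. The function names, the hand-picked weights (standing in for a trained controller), the uniform preference distribution, and the specific cost form (a module cost scaled by the preference and subtracted from the log-likelihood) are all assumptions made for illustration, not details taken from the method itself.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def controller_policy(activations, preference, weights):
    """Hypothetical linear controller.

    The preference is appended to the controller's input features, so the
    resulting module distribution can change when the preference changes.
    """
    features = list(activations) + [preference]
    logits = [sum(w * f for w, f in zip(row, features)) for row in weights]
    return softmax(logits)


def sample_module(probs, rng):
    """Sample a module index from the controller's categorical policy."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def controller_reward(log_likelihood, module_cost, preference):
    """Assumed reward: task log-likelihood minus a preference-scaled cost."""
    return log_likelihood - preference * module_cost


# Hand-picked weights standing in for a trained controller with two modules:
# index 0 is a small module, index 1 is a large one. The last column is the
# weight on the preference input.
WEIGHTS = [
    [0.0, 0.0, 1.0],   # small module: favoured as the preference grows
    [0.5, 0.5, -1.0],  # large module: disfavoured as the preference grows
]

rng = random.Random(0)
# Training time: sample a preference (assumed uniform on [0, 1]).
train_preference = rng.uniform(0.0, 1.0)
activations = [1.0, 1.0]
probs = controller_policy(activations, train_preference, WEIGHTS)
choice = sample_module(probs, rng)

# Test time: the same controller, queried with two different preferences.
cheap = controller_policy(activations, 2.0, WEIGHTS)
free = controller_policy(activations, 0.0, WEIGHTS)
```

With these illustrative weights, raising the preference from 0.0 to 2.0 shifts probability mass away from the large module, which mirrors the intended test-time behaviour: the policy changes without retraining because the preference is simply another controller input.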
Below we describe two instances of Policy Preferences applied to the Composer model. Policy Preferences could also be applied in a more general reinforcement learning setting.
3.1 Glimpse Preferences
For a variety of reasons we might want to express test-time preferences about model resource consumption (see also Graves (2016); Figurnov et al. (2016)). One environment where this might be relevant is when the model is allowed to take glimpses of the input that vary both in costliness and usefulness. If taking larger glimpses requires using more parameters, we can achieve this by setting
(6) |
where the two quantities involved are the vector of parameter counts for each module in the $i$th metalayer and the parameter count for the module chosen at the $i$th metalayer. This cost is applied per-element here.
3.2 Entropy Preferences
Because the modules and the controller are trained jointly, there is some risk that the controller will give all of its probability mass to the module that happens to have been trained the most so far. To counteract this tendency we can use another cost
(7) |
where the two quantities involved are the number of examples in a batch and the vector of module probabilities produced by the controller at a given metalayer for a given batch element. Note that this reward is maximized when the controller utilizes all modules equally within a batch. This cost is applied per-batch.
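The difference between a per-batch entropy and a per-example entropy can be made concrete with a small sketch. It assumes the batch-level quantity is the entropy of the batch-averaged module probabilities; this is one plausible form consistent with the property noted above, and the function names are illustrative.

```python
import math


def entropy(probs):
    """Shannon entropy in nats, treating 0 * log(0) as 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def batch_entropy(batch_probs):
    """Entropy of the module distribution averaged over the batch."""
    n = len(batch_probs)
    k = len(batch_probs[0])
    mean = [sum(row[j] for row in batch_probs) / n for j in range(k)]
    return entropy(mean)


def mean_per_example_entropy(batch_probs):
    """Average of the per-example entropies (a standard entropy bonus)."""
    return sum(entropy(row) for row in batch_probs) / len(batch_probs)


# Each example confidently picks a different module.
specialised = [[1.0, 0.0], [0.0, 1.0]]
# Every example confidently picks the same module.
collapsed = [[1.0, 0.0], [1.0, 0.0]]
```

Under this assumed form, the `specialised` batch attains the maximal batch entropy (log 2 for two modules) even though each individual decision is deterministic, while the `collapsed` batch scores zero. A per-example entropy bonus cannot tell the two apart, since it is zero in both cases.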
One could simply augment the controller reward with a similar bonus and anneal the bonus to zero during training, but our method has at least two advantages: First, if load balancing between specialized modules must be done at test time, using Policy Preferences allows it to be done in a ‘smart’ way. Second, this method removes the need to search for and follow an annealing schedule: one can stop training at any time and set the test-time batch-entropy-preference to 0.
4 Experimental Results
We test the claim that one can dynamically adjust the amount of computation at inference time with a single trained model using a Composer and Policy Preferences. We introduce a modified version of MNIST called Wide-MNIST to accomplish this: Images in Wide-MNIST have shape , with a digit appearing in either the left half or the right half of the image. Each image is labeled with one of 20 classes (10 for ‘left’ digits and 10 for ‘right’ digits). We train a Composer model with two modules - a large module that glimpses the whole input and a small module that glimpses only the left side. The Composer is trained using both entropy preferences and glimpse preferences. See Figure 2 for results, which support the above claim. See Appendix B for qualitative experiments.
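To make the dataset construction concrete, here is a minimal sketch of how a Wide-MNIST example might be assembled from a standard digit image. The doubling of the image width, the zero padding, and the label encoding (digit plus 10 for ‘right’ digits) are assumptions for illustration; `make_wide_example` and `left_glimpse` are hypothetical helpers.

```python
def make_wide_example(digit_image, digit_label, side):
    """Embed a digit image in one half of a blank canvas twice as wide.

    digit_image: list of rows (lists of floats).
    digit_label: integer in 0..9.
    side: 0 to place the digit on the left, 1 to place it on the right.
    Returns (wide_image, label) with label in 0..19 (assumed encoding).
    """
    if side not in (0, 1):
        raise ValueError("side must be 0 (left) or 1 (right)")
    wide_image = []
    for row in digit_image:
        blank = [0.0] * len(row)
        pixels = list(row)
        wide_image.append(pixels + blank if side == 0 else blank + pixels)
    return wide_image, digit_label + 10 * side


def left_glimpse(wide_image):
    """Return the left half of a wide image (what a left-only module sees)."""
    half = len(wide_image[0]) // 2
    return [row[:half] for row in wide_image]


# A tiny 2x2 stand-in for a digit image keeps the example easy to inspect.
digit = [[0.2, 0.9], [0.7, 0.1]]
left_image, left_label = make_wide_example(digit, 3, side=0)
right_image, right_label = make_wide_example(digit, 3, side=1)
```

In this sketch, a left-only glimpse of a ‘right’ example contains no digit pixels, which is why a controller would need to route such inputs to the module that sees the whole image.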
5 Acknowledgements
We thank Kevin Swersky, Hugo Larochelle, and Sam Schoenholz for helpful discussions.
References
- Andreas et al. (2016) Jacob Andreas, Marcus Rohrbach, Trevor Darrell, and Dan Klein. Learning to compose neural networks for question answering. CoRR, abs/1601.01705, 2016. URL http://arxiv.org/abs/1601.01705.
- Bengio et al. (2015) Emmanuel Bengio, Pierre-Luc Bacon, Joelle Pineau, and Doina Precup. Conditional computation in neural networks for faster models. arXiv preprint arXiv:1511.06297, 2015.
- Denoyer & Gallinari (2014) Ludovic Denoyer and Patrick Gallinari. Deep sequential neural network. arXiv preprint arXiv:1410.0510, 2014.
- Eigen et al. (2013) David Eigen, Marc’Aurelio Ranzato, and Ilya Sutskever. Learning factored representations in a deep mixture of experts. arXiv preprint arXiv:1312.4314, 2013.
- Fernando et al. (2017) Chrisantha Fernando, Dylan Banarse, Charles Blundell, Yori Zwols, David Ha, Andrei A. Rusu, Alexander Pritzel, and Daan Wierstra. Pathnet: Evolution channels gradient descent in super neural networks. CoRR, abs/1701.08734, 2017. URL http://arxiv.org/abs/1701.08734.
- Figurnov et al. (2016) Michael Figurnov, Maxwell D. Collins, Yukun Zhu, Li Zhang, Jonathan Huang, Dmitry P. Vetrov, and Ruslan Salakhutdinov. Spatially adaptive computation time for residual networks. CoRR, abs/1612.02297, 2016. URL http://arxiv.org/abs/1612.02297.
- Graves (2016) Alex Graves. Adaptive computation time for recurrent neural networks. CoRR, abs/1603.08983, 2016. URL http://arxiv.org/abs/1603.08983.
- Shazeer et al. (2017) Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc V. Le, Geoffrey E. Hinton, and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. CoRR, abs/1701.06538, 2017. URL http://arxiv.org/abs/1701.06538.
- Williams (1992) Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4):229–256, 1992.
Appendix A Policy Gradient Details
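For simplicity let $c = c_{1:n}$ and $a = a_{1:n}$. Then, we use REINFORCE/policy gradients to optimize the objective
$$J(\theta) = \mathbb{E}_{c \sim p(c \mid x, \theta)}\left[\log p(y \mid x, c, \theta)\right]. \tag{8}$$
Jensen’s inequality tells us that this is a lower bound on what we truly seek to optimize,
$$\log p(y \mid x, \theta) = \log \mathbb{E}_{c \sim p(c \mid x, \theta)}\left[p(y \mid x, c, \theta)\right].$$
To maximize this expectation we perform gradient ascent. The gradient of this quantity with respect to the model parameters $\theta$ is
$$\nabla_\theta J(\theta) = \mathbb{E}_{c \sim p(c \mid x, \theta)}\left[\nabla_\theta \log p(y \mid x, c, \theta) + \log p(y \mid x, c, \theta)\, \nabla_\theta \log p(c \mid x, \theta)\right]. \tag{9}$$
To implement the policy preference cost functions we can modify our reward to include $C(c, \gamma)$. In this case our objective becomes
$$J(\theta) = \mathbb{E}_{\gamma \sim \Gamma}\, \mathbb{E}_{c \sim p(c \mid x, \gamma, \theta)}\left[\log p(y \mid x, c, \theta) + C(c, \gamma)\right] \tag{10}$$
and the gradient becomes
$$\nabla_\theta J(\theta) = \mathbb{E}_{\gamma \sim \Gamma}\, \mathbb{E}_{c \sim p(c \mid x, \gamma, \theta)}\left[\nabla_\theta \log p(y \mid x, c, \theta) + \left(\log p(y \mid x, c, \theta) + C(c, \gamma)\right) \nabla_\theta \log p(c \mid x, \gamma, \theta)\right]. \tag{11}$$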
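The score-function (REINFORCE) term of such a gradient can be checked numerically on a toy problem. The sketch below assumes a single categorical decision with a softmax policy over raw logits and a fixed reward per module (standing in for the log-likelihood plus any preference cost); it enumerates the expectation exactly rather than sampling, and compares the result with finite differences. All names and numbers are illustrative.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_reward(logits, rewards):
    """E_c[R(c)] under a softmax policy with the given logits."""
    return sum(p * r for p, r in zip(softmax(logits), rewards))


def score_function_gradient(logits, rewards):
    """E_c[R(c) * d log p(c) / d logit_k], computed by enumeration.

    For a softmax policy, d log p(c) / d logit_k = 1[c == k] - p_k.
    """
    probs = softmax(logits)
    grad = []
    for k in range(len(logits)):
        total = 0.0
        for c, (p_c, r_c) in enumerate(zip(probs, rewards)):
            indicator = 1.0 if c == k else 0.0
            total += p_c * r_c * (indicator - probs[k])
        grad.append(total)
    return grad


def finite_difference_gradient(logits, rewards, eps=1e-5):
    """Central finite-difference estimate of d E_c[R(c)] / d logits."""
    grad = []
    for k in range(len(logits)):
        up = list(logits)
        down = list(logits)
        up[k] += eps
        down[k] -= eps
        diff = expected_reward(up, rewards) - expected_reward(down, rewards)
        grad.append(diff / (2.0 * eps))
    return grad


logits = [0.3, -1.2, 0.8]
rewards = [-0.4, -2.0, -0.1]
analytic = score_function_gradient(logits, rewards)
numeric = finite_difference_gradient(logits, rewards)
```

The two gradients agree to within finite-difference error. In practice the expectation is estimated from sampled module choices rather than enumerated, and the rewards also depend on the module parameters, which contributes the additional pathwise term.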
Appendix B Qualitative Experiments
This section contains various qualitative explorations of Composer models trained with and without Policy Preferences.
B.1 Module Specialization
On MNIST, modules often specialize to digit values of 8, 5, and 3 or 4, 7, and 9, which are the most frequently confused groups of digits. See Figure 3 for an example of this behavior.
Tests on CIFAR-10 yielded similar results: See Figure 4.
B.2 Entropy Preferences
We also conducted an ablation experiment demonstrating that the batch entropy preference results in better module utilization. See Figure 5 for results. In addition, we show the result of modifying the test-time entropy preference of a model that has been trained with a standard entropy penalty (Figure 6). This could be useful in a variety of contexts (e.g. language modeling).
Figure 5: The effect of the batch entropy preference on module selection frequency for a Composer with 1 metalayer and 4 modules trained on MNIST. These are module selection heatmaps similar to those in the above figures. (Left) With the entropy preference, modules are utilized more equally. After 100k steps, the mutual information between module choices and class labels for this run is 0.9 nats. (Right) With a normal entropy penalty, some modules ‘get ahead’ and others never catch up. After 100k steps, the mutual information for this run is only 0.43 nats. This test was conducted with a constant, zero-variance value for the preference, so in this case the non-batch entropy preference is equivalent to the standard entropy penalty (which involves a sum of separate entropy penalties, one per example).