distilling-policy-distillation
Presentation and code notebook on policy distillation
The transfer of knowledge from one policy to another is an important tool in Deep Reinforcement Learning. This process, referred to as distillation, has been used to great success, for example by enhancing the optimisation of agents, leading to stronger performance, faster, on harder domains [26, 32, 5, 8]. Despite the widespread use and conceptual simplicity of distillation, many different formulations are used in practice, and the subtle variations between them can often drastically change the performance and the resulting objective that is being optimised. In this work, we rigorously explore the entire landscape of policy distillation, comparing the motivations and strengths of each variant through theoretical and empirical analysis. Our results point to three distillation techniques that are preferred depending on the specifics of the task. Specifically, a newly proposed expected entropy regularised distillation allows for quicker learning in a wide range of situations, while still guaranteeing convergence.
The MDPs used in this study are grid worlds, meaning that the state space is .
is a special state, to which an agent is moved with probability 0.01 after each action, ensuring finite length of the experiments considered. There is one initial state placed in the centre of the grid,
. There are four possible actions , each of them has an associated desired effect, namely , , , . Some transitions are invalid, as they would lead to leaving the state space, thus we defineThe transition dynamics are defined as:
where is the transition noise.
Rewards are associated with some states, and are fully deterministic.
Some states are terminal, which cause the episode to end, and bring the agent back to the initial state.
We considered partial, and fully observable versions of these environments. In the fully observable environments, the agent is given the state index as an observation, while in the partially observable environments a concatenated sequence of objects, namely is represented as
where is wall if and a pair otherwise. For example, if the state provides reward 10 and is terminating, then it will be observed as (10, True).
In all partially observable experiments, we use observations which are concatenations of squares of vision, centred on the agent's position. We experimented with visual extents ranging from to full observability and found that this does not affect the qualitative results of the paper, thus the choice of the particular visual extent is not crucial.
In all experiments where we sample multiple MDPs we use the following procedure:
1. We create as described in the previous section.
2. For each , starting in the upper left corner and traversing first horizontally and then vertically, we try the following modifications in order; if we modified a state, we go back to step 2 and continue the loop:
   (a) With probability we remove from , which we call putting a wall in.
   (b) With probability we put a reward of +10 in and make it terminal.
   (c) With probability we put a reward of +5 in and make it terminal.
   (d) With probability we put a reward of -1 in .
   (e) With probability we put a reward of -5 in and make it terminal.
   (f) With probability we put a reward of -10 in and make it terminal.
3. We check if there exists a path between the initial state and the state, and if this is not true, we repeat the process.
Unless otherwise stated in the text, we use , .
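The sampling procedure above can be sketched in a few lines of Python. This is a minimal illustration, not the authors' code: the probabilities in `P_WALL` and `REWARD_SPECS` are hypothetical placeholders (the values used in the paper did not survive in this text), the function names are ours, and we assume that the path check targets a +10 state and that the initial state is never modified.

```python
import random
from collections import deque

# Hypothetical probabilities: placeholders, not the values used in the paper.
P_WALL = 0.1
REWARD_SPECS = [  # (probability, reward, is_terminal), tried in this order
    (0.01, 10.0, True),
    (0.01, 5.0, True),
    (0.05, -1.0, False),
    (0.01, -5.0, True),
    (0.01, -10.0, True),
]


def sample_grid(n, rng):
    """Sample walls, rewards and terminal cells for one candidate n x n grid."""
    start = (n // 2, n // 2)
    walls, rewards, terminals = set(), {}, set()
    for row in range(n):          # top row first ...
        for col in range(n):      # ... left to right within each row
            cell = (row, col)
            if cell == start:
                continue          # assumption: the initial state stays untouched
            if rng.random() < P_WALL:
                walls.add(cell)
                continue          # a modified cell is not modified again
            for prob, reward, terminal in REWARD_SPECS:
                if rng.random() < prob:
                    rewards[cell] = reward
                    if terminal:
                        terminals.add(cell)
                    break
    return walls, rewards, terminals


def goal_reachable(n, walls, terminals, start, goals):
    """Breadth-first search over 4-connected, non-wall cells."""
    seen, queue = {start}, deque([start])
    while queue:
        cell = queue.popleft()
        if cell in goals:
            return True
        if cell in terminals:
            continue              # episodes end here, so paths cannot pass through
        row, col = cell
        for d_row, d_col in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nxt = (row + d_row, col + d_col)
            inside = 0 <= nxt[0] < n and 0 <= nxt[1] < n
            if inside and nxt not in walls and nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return False


def sample_mdp(n=20, seed=0):
    """Resample grids until a +10 state is reachable from the centre."""
    rng = random.Random(seed)
    start = (n // 2, n // 2)
    while True:
        walls, rewards, terminals = sample_grid(n, rng)
        goals = {cell for cell, r in rewards.items() if r == 10.0}
        if goals and goal_reachable(n, walls, terminals, start, goals):
            return walls, rewards, terminals
```

Rejection sampling keeps the per-cell procedure simple at the cost of occasionally discarding a whole grid; with sparse walls this should rarely happen.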
We use a basic actor critic method, where we sample one full episode under the student policy,
, and then update the parameters according to either the single sample Monte Carlo estimated return:
or, in the TD(1) case, with bootstrapped estimates
In all experiments we used , but we obtained qualitatively similar results with other values too ( and ).
After each update we use the same Monte Carlo or TD value to fit the baseline function, using the L2 loss:
or
in the case of TD learning, where is treated as a constant. All Vs are initialised to 0s. The learning rate used is .
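We use the standard Q-Learning update rule of
$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left( r_t + \gamma \max_{a} Q(s_{t+1}, a) - Q(s_t, a_t) \right)$$
applied after each visited state. All Qs are initialised to 0s. The learning rate was set to . The policy was trained for 30k iterations.
When treating the Q-Learned policy as a teacher, depending on the temperature reported (by default 0), it was either a greedy policy (if the temperature is 0)
$$\pi(a \mid s) = 1 \quad \text{iff} \quad a = \arg\max_{b} Q(s, b)$$
or a Boltzmann policy with temperature $T$ computed as:
$$\pi(a \mid s) = \frac{\exp\left(Q(s, a) / T\right)}{\sum_{b} \exp\left(Q(s, b) / T\right)}$$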
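As an illustration of the teacher construction described above, here is a minimal tabular sketch using NumPy. The function names and the default learning rate and discount are our own hypothetical choices, not values from the paper.

```python
import numpy as np


def q_learning_update(Q, s, a, r, s_next, done, lr=0.1, gamma=0.99):
    """One tabular Q-learning step; lr and gamma defaults are placeholders."""
    target = r if done else r + gamma * np.max(Q[s_next])
    Q[s, a] += lr * (target - Q[s, a])
    return Q


def teacher_policy(Q, s, temperature=0.0):
    """Greedy policy at temperature 0, Boltzmann policy otherwise."""
    q = Q[s]
    if temperature == 0.0:
        probs = np.zeros_like(q)
        probs[np.argmax(q)] = 1.0
        return probs
    z = (q - np.max(q)) / temperature  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()
```

A typical use would be to run `q_learning_update` on every visited transition and then call `teacher_policy(Q, s, temperature)` whenever the teacher's action distribution is needed during distillation.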
Policies are represented as logits of each action, for each unique observation. Consequently for each observation , and for action space the policy for Actor Critic is parameterised as . Similarly, value functions are represented simply as one float per observation: , and Q-values .
We include extended versions of various figures. Fig. 2 is an extended version of Fig. 9 including experiments with an A2C teacher.
Distilling Policy Distillation
Wojciech Marian Czarnecki, Razvan Pascanu, Simon Osindero, Siddhant Jayakumar, Grzegorz Świrszcz, Max Jaderberg
DeepMind
Reinforcement Learning (RL), and in particular Deep Reinforcement Learning (DRL), has shown great success in recent years, allowing us to train agents capable of playing Atari games from raw-pixel inputs [16, 7], beating professional players in the game of Go [30] or learning dexterous robotic manipulation [19]. However, obtaining these levels of performance can often require an almost prohibitively large amount of experience to be acquired by the agent in order to learn [25]. Consequently, there is great interest in techniques that can allow for knowledge transfer; that is, enabling the training of agents based on already trained policies [24] or human examples [1]. One of the most successful techniques for knowledge transfer is that of distillation [10, 24, 23], where an agent is trained to match the state-dependent probability distribution over actions provided by a teacher. Some examples include allowing us to: train otherwise untrainable agent architectures [5]; speed up learning [26]; build stronger policies [23]; drive multi-task learning [2, 24]. Analogous techniques have been widely used in supervised learning problems to achieve model compression [10, 11, 22], reparametrisation for inference speedup [18], and joint co-training of multiple networks [34].
Although the high-level formulation of distillation in RL is simple, one can find dozens of different mathematical formulations used in practice. For example: sometimes trajectories are sampled from a teacher [24], sometimes from the student [13] or a mixture [23]; some authors use a KL divergence between teacher and student distribution [9] while others look at KL between entire trajectory probabilities [32]. A primary goal of this paper is to provide a roadmap of these different ideas and approaches, and then to perform a step-by-step comparison of them, both mathematically and empirically. This allows us to construct a set of useful guidelines to follow when trying to decide which specific distillation approach might best fit a particular problem.
The main contributions of this paper can be summarised as follows: In Section 5.1 we provide a proof that commonly used distillation with trajectories sampled from the student policy does not form a gradient vector field, and while it has convergence guarantees in simple tabular cases, it can oscillate as soon as one introduces rewards to the system. We show simple methods of recovering the gradient vector field property. In Section 5.2 we perform empirical evaluation of different control policies, showing when and why it is beneficial to use student-driven distillation. In Section 6 we analyse the actor-critic setup, in which one has access to a teacher’s value function, , in addition to its policy, and discuss how may also be used for distillation. We empirically evaluate all the above techniques in thousands of random MDPs. Finally, based on all the results combined from our mathematical analyses and empirical evaluation, we propose effective new distillation variants and provide a rule of thumb decision tree, Fig. 5.
Throughout the paper we assume that we are working with Markov Decision Processes, and will now outline our notation. We have a finite set of states,
, a finite set of possible actions, , an agent policy which outputs distribution over actions in each state. Agents interact with an environment at time by sampling actions , and the environment transitions to a new state according to unknown transition dynamics and produces rewards . Each is the state encountered at time in the trajectory . We use to denote an expectation over the distribution of trajectories, , generated by agent when interacting with the environment using policy . Under this notation the typical goal of reinforcement learning is to find , where is discount factor. For simplicity we use in most of our theoretical results, though the proofs can trivially be extended to arbitrary . For the empirical results we generally use , see the Appendix A. We consider the general problem of extracting knowledge from a teacher policy, , and transferring it to a different student policy, , using trajectories, , sampled from interactions between a control policy, , and the unknown environment.All proofs are provided in the Appendices C and D.
name | is ? | Loss | |||
---|---|---|---|---|---|
Teacher distill | 0 | yes [1] | |||
On-policy distill | 0 | no | does not exist | ||
Entropy regularised | 0 | yes [4] | |||
N-distill | - | yes | |||
Exp. entropy regularised | yes | ||||
Teacher V reward | 0 | yes |
Through the rest of this paper we consider update rules for (parameters of the student policy ) which are proportional to:
(1) |
for and a choice of and that define a specific instance of a distillation technique (see Table 1 for a list of examples). In this equation, can be seen as a form of auxiliary loss [12] responsible for policy alignment at the current step, while can be viewed as a reward term that combines extrinsic and intrinsic components [21] and thus is responsible for long-term alignment. Note, we assume undiscounted objectives and episodic RL, but analogous analysis can be performed for the discounted case.
Focusing on update rules rather than simply losses may seem to add unnecessary complexity; however, one of the crucial outcomes of our work is to show that the update rules involved in certain distillation methods do not have corresponding loss functions. Consequently we must make an explicit distinction between update rules, and losses which may be used to derive update rules.
Many RL distillation frameworks set up knowledge transfer as a supervised learning problem [23, 24, 9, 33], by following updates in the direction of: with Monte Carlo estimates for the expectation based on trajectories derived from the teacher policy, . However, since then, several publications [20, 5, 26] have reported better empirical results when trajectories are sampled from the student instead, i.e. by following updates in the direction of: Note that in this form of update the gradient operator is under an expectation wrt. the same set of variables that it operates upon. Consequently it is not clear if this process will even converge, and thus the benefits of using such updates are also unclear.
In this section we analyse and prove the following properties: (i) For tabular policies, provided guarantees a non-zero probability of sampling each state visited by the teacher, the dynamics will converge. In particular with a softmax policy satisfies this property; (ii) In general, updates like this do not form gradient vector fields; (iii) If one adds reward optimisation to the system, the dynamics can cycle and never converge; (iv) A reward-based correction term can be added to ensure convergence (and with such a correction, the updates do correspond to proper gradient vector field); (v) There is a trade-off between the speed of convergence and the fidelity of the behaviour replication, which can be controlled by .
We begin by proving a general theorem about on-policy non-gradients. In principle, it is very similar to the notion of compatibility of a value function and the policy [31], or can be seen as a generalisation of incompatibility towards other possible trajectory level losses (e.g. ).
Let us assume that is differentiable and there does not exist such that almost everywhere. Then is not a gradient vector field of any function. If gradient of some is differentiable then ’s Hessian exists and is a symmetric matrix:
Consequently, if some function is a gradient vector field, then its Jacobian has to be symmetric. We will show that for this is not true in general, by focusing on two arbitrary indices and . We use notation to denote the th output of the multivariate function . Using the log derivative trick we obtain that equals
thus equals
In general this term is zero iff almost everywhere, which can not be true due to assumptions. Consequently, is not a gradient vector field of any function.
The assumption about the existence of is equivalent to the compatibility criterion [31], and thus it shows that incompatible value functions do not create valid gradient vector fields. This is complementary to the result that compatible value functions provide convergence to the optimal policy.
If we choose we recover the on-policy distillation updates used in techniques such as kickstarting [26] and Mix&Match [5]. In this setting there is no corresponding that simply rescales policy logits, and thus as a consequence of Theorem 5.1 we see that naive distillation with student-generated trajectories does not form a gradient vector field. We also note that exactly the same proof shows that the entropy penalty [15] , commonly used in actor critic algorithms, also results in updates that do not correspond to a valid gradient vector field.
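The non-gradient property can be checked numerically on a toy problem. The sketch below is our own construction, not the paper's experiment: a three-state episodic MDP (from `s0`, action 0 leads to `s1` and action 1 leads to `s2`, after which the episode ends), a tabular student with one logit per state, and hypothetical teacher probabilities in `TEACHER`. A gradient vector field must have a symmetric Jacobian, so we compare the Jacobians of the student-driven and teacher-driven expected updates.

```python
import numpy as np

# Hypothetical teacher: probability of action 0 in states s0, s1, s2.
TEACHER = np.array([0.9, 0.2, 0.7])


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def student_driven_update(theta):
    """Expected 1-step cross-entropy update with states sampled by the student."""
    p = sigmoid(theta)
    visit = np.array([1.0, p[0], 1.0 - p[0]])  # visitation under the student
    return visit * (p - TEACHER)               # d/dtheta_s CE(teacher_s, p_s) = p_s - teacher_s


def teacher_driven_update(theta):
    """Same per-state gradients, but states sampled by the (fixed) teacher."""
    p = sigmoid(theta)
    visit = np.array([1.0, TEACHER[0], 1.0 - TEACHER[0]])
    return visit * (p - TEACHER)


def jacobian(field, theta, eps=1e-6):
    """Central finite differences; J[i, j] = d field_i / d theta_j."""
    cols = []
    for j in range(theta.size):
        step = np.zeros_like(theta)
        step[j] = eps
        cols.append((field(theta + step) - field(theta - step)) / (2 * eps))
    return np.stack(cols, axis=1)


theta = np.array([0.3, -0.5, 1.2])
J_student = jacobian(student_driven_update, theta)
J_teacher = jacobian(teacher_driven_update, theta)
student_asymmetry = np.abs(J_student - J_student.T).max()
teacher_asymmetry = np.abs(J_teacher - J_teacher.T).max()
```

In this toy case `teacher_asymmetry` vanishes up to finite-difference error, while `student_asymmetry` does not: the update for `s1` depends on the logit of `s0` through the visitation probability, but not the other way round.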
Having seen that these commonly used updates do not correspond to a valid gradient vector field, a natural question to ask is whether this is necessarily problematic. For example, the updates used in Q-learning are not gradient steps either, but Q-learning still provides a convergent iterative scheme. We address the question of what can be said about the dynamical system emerging from this sort of distillation, and for a simple tabular setup we can show that indeed this is not an issue: Using an update rule of the form for a strongly stochastic student policy [Footnote 1: meaning that every action has non-zero probability in every state, for all parameter values], with episodic finite state-space MDPs and tabular policies, provides convergence to the teacher policy over all reachable states for the loss function , provided the optimiser used can minimise wrt. , for any in the domain of , and reaches minimum at .
Because of strong stochasticity of , the distribution of states visited under this policy covers entire state space reachable from the initial state. We use notation . Consequently the update
can be rewritten as
where is the probability of agent being in state when following policy and we use the independence of parametrisation of the policy in each state (which comes from the tabular assumption – is the parametrisation of policy in state ).
Let us denote by the gradient of the expected loss under the teacher policy
where again is the probability of sampling state under .
It is easy to notice that these two update directions have a non-negative cosine:
Furthermore, because for all , , the cosine is zero if and only if for each state either (teacher and student policies match) or (state is not reachable by ). This means that for every state reachable by , the corresponding update rule coming from is guaranteed to be strictly descending as long as it is not in the minimum.
Due to assumptions about having a unique minimum and optimiser being able to find it, we obtain that will converge to for each where .
Consequently we have shown that the update direction is a strict descent direction wrt. the expected loss under the teacher policy, and thus the student policy converges to the teacher's over all reachable states.
Using Monte Carlo estimates for the estimation can be analysed analogously to how Stochastic Gradient Descent generalises Gradient Descent.
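The descent argument of the proof can be written out explicitly. The following is a sketch in notation we assume for illustration, since the original equations are not recoverable from this text: $d_q(s)$ is the probability of visiting state $s$ under policy $q$, $\theta_s$ are the student's parameters for state $s$ (disjoint across states by the tabular assumption), $\ell_s(\theta_s)$ is the per-state distillation loss, $\pi$ is the teacher and $\pi_\theta$ the student.

```latex
\begin{aligned}
g_{\text{student}}(\theta) &= \sum_{s} d_{\pi_\theta}(s)\, \nabla_{\theta_s} \ell_s(\theta_s),
\qquad
g_{\text{teacher}}(\theta) = \sum_{s} d_{\pi}(s)\, \nabla_{\theta_s} \ell_s(\theta_s), \\
\langle g_{\text{student}}(\theta),\, g_{\text{teacher}}(\theta) \rangle
&= \sum_{s} d_{\pi_\theta}(s)\, d_{\pi}(s)\, \lVert \nabla_{\theta_s} \ell_s(\theta_s) \rVert^{2} \;\ge\; 0 .
\end{aligned}
```

The cross terms between different states vanish because each $\nabla_{\theta_s} \ell_s$ is non-zero only in the coordinates of $\theta_s$. Under strong stochasticity $d_{\pi_\theta}(s) > 0$ for every reachable state, so the inner product is zero only when every state visited by the teacher already has a vanishing per-state gradient.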
[Figure caption, partially recovered: ... (columns) and returns when following teacher policy. Shaded region represents 0.95 confidence intervals, estimated as 1.96 × standard error of the mean. Right: Relation between student-driven distillation speed-up (measured as ratio of areas under reward/KL curves) and teacher determinism.]
However, despite this positive result, we can also show that even in the episodic MDP case one can break convergence if we introduce rewards. The counterexample, visualised in Fig. 6 and described in detail in the Appendix C, relies on a teacher that can discriminate between some states that the student cannot. It leads to an oscillation: the student policy will never converge even after infinitely many steps.
Consider a game with seven states, . We start at and in the first step we decide whether to go to or . If we choose to go to , in step 2 we choose whether to go to or to . Similarly, if we are in after round 1, in step 2 we have a choice whether to go to or . The only rewards are , , and . In the game we use a policy depending on two parameters and as follows. In the first step we go to with probability and to with probability . In step 2 we have two branchings again: if we are in with probability we go to , and with probability we go to . Similarly, if we are in we go with probability to , and with probability we go to . We choose a penalty function , living in the state , when we are in in step , is zero. Equivalently one can think of it being a distillation cost with an information potential loss, where the teacher . We have an update rule
This system of equations has a first integral (with integrating factor ). Note, that , therefore the fixed point is a center. Therefore, with each policy update the values stay on the same closed curve and they keep changing in a cyclic manner, never converging.
There are multiple possible ways to construct on-student-policy distillation learning methods similar to the ones used in practice, but which do provide update rules that are gradient vector fields [Footnote 2: It is worth noting that the typical trick of importance sampling is not viable here. First, it is unclear what sampling distribution to correct with respect to: one can choose any distribution that is independent of , thus it could be the teacher policy, but also, say, a uniform one. Second, mathematically this simply leads to degeneration to optimisation of the corresponding loss, such as teacher distill, rather than an on-policy method.]. One such way is to start from the objective suggested by the update rule component, namely: Then we compute its gradient using the log-derivative trick, analogously to how the KL-regularised RL objective is derived [27] or how Stochastic Computation Graphs are obtained [28]; doing so gives the update direction:
As we can see, the gradient vector field is composed of two expectation terms. The first term corresponds to the 1-step on-policy distillation setup discussed so far. The second term corresponds to the standard RL objective if plays the role of the reward function. This simple derivation allows us to prove the following: In order to recover the gradient vector field property for 1-step on-policy distillation updates with any loss , one can add an extra reward term . Analogously if the loss is of the form then the correction is of form Consider the following loss and its gradient:
using the log-derivative trick and the above equation we get
Consequently, we obtain that the valid gradient of the loss considered is composed of two expectations, one being the equivalent of a RL target, but with being a negation of the reward, and one which is exactly the auxiliary cost of interest. Consequently if we add the reward at time equal to minus loss at time we will recover proper gradient vector field.
For the case of a loss of the form this proof is analogous – simply the correction is not on a state-action pair level, rather a pure state level.
Since this is a gradient vector field, it can be safely composed with reward based updates without losing any convergence properties. As one can see on the right of Fig. 6 applying this correction to On-policy distill+R (and thus creating the N-distill+R), leads to convergence and minimisation of the loss, as expected.
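The effect of the correction can be illustrated on a toy problem. The sketch below is our own construction with hypothetical numbers, not the paper's experiment: a three-state episodic MDP (from `s0`, action 0 leads to `s1` and action 1 leads to `s2`), a tabular student with one logit per state, and a per-state cross-entropy loss. It checks that the 1-step update plus a score-function term, in which the loss at the next state plays the role of a negative reward, equals the true gradient of the expected trajectory loss.

```python
import numpy as np

# Hypothetical teacher: probability of action 0 in states s0, s1, s2.
TEACHER = np.array([0.9, 0.2, 0.7])


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def per_state_loss(theta):
    """Cross-entropy between teacher and student action distributions, per state."""
    p = sigmoid(theta)
    return -(TEACHER * np.log(p) + (1.0 - TEACHER) * np.log(1.0 - p))


def expected_trajectory_loss(theta):
    """Expected sum of per-state losses along a student-generated trajectory."""
    p0 = sigmoid(theta[0])
    loss = per_state_loss(theta)
    return loss[0] + p0 * loss[1] + (1.0 - p0) * loss[2]


def one_step_update(theta):
    """Naive on-policy distillation update (no correction)."""
    p = sigmoid(theta)
    visit = np.array([1.0, p[0], 1.0 - p[0]])
    return visit * (p - TEACHER)


def reward_correction(theta):
    """Score-function term: E_a[ grad log pi(a|s0) * loss(next state) ]."""
    p0 = sigmoid(theta[0])
    loss = per_state_loss(theta)
    # action 0: probability p0, d/dtheta0 log p0 = 1 - p0, next state s1
    # action 1: probability 1 - p0, d/dtheta0 log(1 - p0) = -p0, next state s2
    g0 = p0 * (1.0 - p0) * loss[1] + (1.0 - p0) * (-p0) * loss[2]
    return np.array([g0, 0.0, 0.0])  # s1 and s2 have no successors


def numerical_gradient(f, theta, eps=1e-6):
    grad = np.zeros_like(theta)
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = eps
        grad[i] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return grad


theta = np.array([0.3, -0.5, 1.2])
corrected = one_step_update(theta) + reward_correction(theta)
true_gradient = numerical_gradient(expected_trajectory_loss, theta)
```

Here `corrected` matches `true_gradient`, whereas `one_step_update(theta)` alone differs in the coordinate of `s0`, which is the only state whose action changes what is visited later.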
Given the potential convergence issues with the naive updates from Equation 1, particularly when also considering reward from the environment as highlighted by our counterexample, a natural question arises: why does following the student’s policy when performing distillation typically lead to better empirical results? Our main hypothesis is that, if convergent [Footnote 3: In practice researchers often force convergence by learning rate annealing, early stopping etc., so the issues highlighted here may often be masked.], it provides more robust policies wrt. trajectories sampled from the student.
This follows the general machine learning principle of training in the same regime as that which we expect to encounter during test time. In particular, after distillation, the goal is usually to either evaluate a student agent when it is generating its own actions, or to allow the student agent to continue training on its own. Therefore what matters is an expectation wrt. , and not wrt. . Consequently performing distillation “on-policy” with respect to student trajectories leads to less of a distribution-shift between training and testing phases.
Another motivating argument is that if the teacher is almost deterministic, then it visits a relatively small fraction of the state space, even though during training it might have built a policy to deal with other situations too. When using during distillation, the student will not have the opportunity to replicate the teacher’s behaviour in these states, since it will not visit them. In general, after distillation, and they diverge quickly in complex environments or over long trajectories. Again, on-student-policy distillation avoids these issues, especially initially, since many states will be visited that are not normally encountered under the teacher policy. Consequently one can expect better replication of the teacher, when measured in the entire state space.
The main observations of this section are: (i) on-policy 1-step distillation updates do not form gradient fields, and when mixed with environment rewards can lead to non-convergent behaviour; (ii) distillation using student-generated trajectories replicates the teacher policy in more states that are relevant under the student’s behaviour distribution. The following empirical section highlights the effects that different choices for the control-policy have in practice.
We consider teacher driven [24], student driven [20] and fixed (uniform) control policies. We define a distribution over grid world tasks, where we randomly place walls, terminating and rewarding states in 20 × 20 2D grid worlds (see Appendix A.2 for a detailed description of the MDP generating procedure), and agents are capable of moving in 4 directions. There is a fixed probability of terminating each episode, such that we end up with bounded (undiscounted) returns. We sample 1k MDPs like this, and distill for 30k optimisation steps using various control policies . Since we have proven that in the tabular case we can use distillation based on per-step cross-entropy, H, this is the loss we are minimising, using a gradient based update to the underlying logits.
We use teachers trained with Q-Learning and ε-greedy policies (ε set to 0.1, full details provided in the Appendix A.4) and observe how different types of control policy affect the distillation loss. We measure for various choices of , which can be different from used for distillation.
As predicted by our theoretical analysis, student-driven distillation brings benefits wrt. the loss computed over student trajectories. In other words, if one cares about how closely the student behaviour matches the teacher behaviour when the student agent is allowed to generate experience on its own, student-driven distillation optimises this quantity well, with the gap disappearing as the teacher becomes more and more uniform (entropic), see Fig. 7. Similarly, matching of the teacher policy outside of typically visited states (measured by being uniform) is also much better when following a student-driven policy. The best result, in terms of the whole state space, is achieved when using a uniform distribution for , but this is a control setting which does not scale to larger state or action spaces. And when we compare this control setting to the student-based setting we see that it converges extremely slowly even in these scenarios.
Similarly, in terms of the returns obtained by the agent, student driven distillation sees the fastest learning progress, and needs on average 3× fewer steps than teacher driven distillation (and around 10× fewer than uniform driven) to recover full teacher performance. This is of crucial importance, since distillation in RL is typically a first step in a larger training procedure in which the rewards obtained by the student policy are to be maximised with potentially further training. In order to be useful for these applications, one typically seeks to obtain highly rewarding policies as rapidly as possible. The only criterion under which teacher driven distillation works more effectively is, somewhat obviously, the expected KL under trajectories generated from the teacher distribution. However this is an artificial scenario, which is rarely encountered in practice.
To summarise, student-driven distillation provides significant improvements in terms of empirical results over teacher-driven distillation. While one could heuristically drive the switch between the two [23], the pure student-driven method seems to be strong enough to use solely, as long as a proper loss is being used, which we discuss in the next section. Therefore, in the remainder of this paper we focus on student driven distillation.
An important choice is the selection of which method we use to update the student policy given the trajectory and actions suggested by the teacher policy. There are two popular approaches here: one is to try to maximise the probability of the trajectory generated by the student under the teacher policy [32, 27], and the other is to frame the learning as a per-timestep supervised learning problem, defining the loss at each timestep to be the cross entropy between the teacher’s and student’s distributions over actions [2, 26, 5].
There are two aspects worth discussing. Firstly, in the entropy regularised setup, where we use the student would be considered a prior, while the teacher a posterior, however in the distillation setup with the teacher is the prior and the student the posterior. Secondly, the cross entropy regularised approach (which tries to minimise the cross entropy between whole trajectories distributions ) can be absorbed in the reward channel, without any being used. Using only the reward signal can suffer from very high variance in the gradient estimator, for example when the action space is large. As one can see from Fig. 8, as we increase the size of the action space from four actions up to 4k actions (by adding many actions that do not move the agent) the speed of entropy regularised distillation drastically collapses, while traditional distillation still works well.
One can reduce the variance by splitting the entropy term into a 1-step update expressed by , and incorporate the remaining updates through . This technique, denoted as expected entropy regularised, indeed recovers performance of traditional distillation. While this is not a new objective as such, but rather a different estimator, it is, to the best of our knowledge, a novel method, which strictly dominates popular alternatives.
The direction chosen for the cross entropy has a very simple intuitive explanation. If one uses then one tries to replicate , however if one uses then one tries to find a deterministic policy, which puts all probability mass on the most probable action of (see the Appendix C for a proof). If the cost is changed to be the KL divergence instead, this issue is eliminated for finite action spaces, however for continuous control it is a matter of mean vs median seeking techniques [14] (see the Appendix C for more detail). Depending on the MDP, both methods can be beneficial, and of course for almost deterministic teachers – they actually are equivalent.
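The difference between the two directions of the cross entropy can be seen in a small numerical experiment. This is a minimal sketch with a hypothetical three-action teacher distribution and a plain finite-difference gradient descent on the student's logits; none of the names or values come from the paper.

```python
import numpy as np


def cross_entropy(p, q, eps=1e-12):
    """H(p, q) = -sum_a p(a) log q(a)."""
    return -np.sum(p * np.log(q + eps))


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def minimise(loss, n_actions, steps=2000, lr=0.5, fd=1e-5):
    """Gradient descent on logits, with finite-difference gradients."""
    logits = np.zeros(n_actions)
    for _ in range(steps):
        grad = np.zeros(n_actions)
        for i in range(n_actions):
            step = np.zeros(n_actions)
            step[i] = fd
            grad[i] = (loss(softmax(logits + step)) - loss(softmax(logits - step))) / (2 * fd)
        logits -= lr * grad
    return softmax(logits)


teacher = np.array([0.5, 0.3, 0.2])  # hypothetical teacher distribution

# Teacher as the first argument: the minimiser replicates the teacher.
replicating = minimise(lambda q: cross_entropy(teacher, q), 3)

# Student as the first argument: mass concentrates on the teacher's mode.
mode_seeking = minimise(lambda q: cross_entropy(q, teacher), 3)
```

With the teacher as the first argument the student recovers the teacher's distribution, while with the student as the first argument the objective is linear in the student's probabilities, so the optimum sits at a vertex of the simplex, namely the teacher's most probable action.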
To summarise, these experiments demonstrate that across the space of possible student-driven distillation approaches the most reliable method, both mathematically and empirically, is our proposed expected entropy regularised distillation. It has three key benefits: (i) it creates a valid gradient vector field; (ii) it does not suffer from high-variance of the estimate typically used in similar methods [32]; and (iii) it directly maximises the probability of the student produced trajectories under the teacher policy, as opposed to the n-distill method, which looks at maximising the probability of being in states where the student and teacher agree. Consequently, it combines the best elements of various similar techniques in a single method, and avoids their respective drawbacks.
Let us fix a distribution , and consider a minima of and wrt. . It is easy to see that the minimum of is given by , as by the definition of divergence, the minimum of is given by , and , but for a fixed , is a constant, thus it does not affect the minima. For we will show that the minimum is given by the dirac delta distribution in the most probable action in , denoted as . For simplicity, assuming that this is a unique action, meaning that , then for any
While both and have the same minimum in the space of all distributions, they differ once one constrains the space we are looking over. To be more precise we have that
but at the same time there exists where is the space of all distributions such that
The simplest example is the mixture of multiple Gaussians, which we try to fit with just a single Gaussian. The typical cost of will match the mean of the distribution (thus the name of mean seeking), while will cover one of the Gaussians from the mixture, while ignoring the others (thus mode seeking), see Fig. 11 and Fig. 12.
In practice, we are often in this regime, since the teacher and student policies can have different capacities, architectures and priors, thus making perfect replication impossible. Therefore, the choice of direction of KL will affect if the agent prefers to just match one, very probable mode (action/behaviour), or if we prefer the agent to look for an averaged action/behaviour.
In practice, while pure policy based methods remain useful [23] in RL, the Actor-Critic framework [15, 29, 7] has risen to prominence in recent years. Consequently we now shift our attention towards distillation strategies that make use of value critics, denoted by . The availability of this additional knowledge, in the form of , allows us to better leverage imperfect teachers, as we can begin to estimate how much to trust them.
For example, let us assume that we have ground-truth access to and , and consider a loss of the form: where iff . This can be seen as an action-independent version of the Generalised Policy Improvement technique [4]. As a result we can easily prove that after converging, our agent will be at least as strong as its teacher, independent of the initial returns of . For being a distribution over initial states, if we have then . Let us assume that the inequality does not hold, meaning that following the teacher’s policy gives higher return. This means that there exists a state , where but the policies differ, meaning that . However, if then , and due to the assumption for every state, leads to (as cross entropy is equal to the entropy of the first argument only when the arguments are the same), which is a contradiction.
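A value-gated distillation loss of this kind can be sketched as follows. The gating condition is missing from this text, so the sketch assumes the gate is 1 exactly where the teacher's value exceeds the student's, which is what the contradiction argument above suggests; the function name and array layout are hypothetical.

```python
import numpy as np


def gated_distillation_loss(teacher_probs, student_probs, v_teacher, v_student, eps=1e-12):
    """Per-state cross-entropy H(teacher, student), kept only where the teacher is better.

    teacher_probs, student_probs: arrays of shape (n_states, n_actions).
    v_teacher, v_student: arrays of shape (n_states,).
    Assumption: the gate is 1 iff v_teacher > v_student.
    """
    cross_entropy = -np.sum(teacher_probs * np.log(student_probs + eps), axis=-1)
    gate = (v_teacher > v_student).astype(float)
    return gate * cross_entropy


teacher_probs = np.array([[0.9, 0.1], [0.2, 0.8]])
student_probs = np.array([[0.5, 0.5], [0.5, 0.5]])
loss = gated_distillation_loss(
    teacher_probs, student_probs,
    v_teacher=np.array([1.0, 0.0]),
    v_student=np.array([0.5, 0.3]),
)
```

In this example the student is only pulled towards the teacher in the first state, where the teacher's value is higher; in the second state the loss is switched off and the student is free to keep its own behaviour.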
With techniques like this we can use teachers which do make mistakes. One such experiment, where we randomly flip the most probable action of the teacher policy in a given percentage of states it visits, is illustrated in Fig. 9. As before we generate 1k random MDPs, and empirically estimate by sampling 100 trajectories after a given fraction of state-action pairs have been modified. As one can see, methods that fully replicate the teacher (in this case a Q-Learning based one, with 25% noise) end up saturating in sub-optimal solutions – around 2-4 points (Fig. 9 blue and cyan curves). If we add true rewards to the learning system, we improve, but still saturate around 6.5 points (Fig. 9 dark green curve). With the value function based gating described in Proposition 6 we recover the full performance of the teacher with approximately 7 points (Fig. 9 lime curve). As expected, the usefulness of this approach depends on the quality of the teacher and the accuracy of the value function estimators.
Yet another way of using the teacher’s knowledge is to use its value functions, , instead of the student’s, , for bootstrapping – so that the usual actor-critic TD(1) update direction: