GAN Q-learning

05/13/2018 ∙ by Thang Doan, et al. ∙ McGill University

Distributional reinforcement learning (distributional RL) has seen empirical success in complex Markov Decision Processes (MDPs) in the setting of nonlinear function approximation. However, there are many different ways in which one can leverage the distributional approach to reinforcement learning. In this paper, we propose GAN Q-learning, a novel distributional RL method based on generative adversarial networks (GANs) and analyze its performance in simple tabular environments, as well as OpenAI Gym. We empirically show that our algorithm leverages the flexibility and blackbox approach of deep learning models while providing a viable alternative to other state-of-the-art methods.




1 Introduction

Reinforcement learning (RL) has recently had great success in solving complex continuous-control tasks trpo ; ppo . However, as these methods often produce high-variance results when dealing with unstable environments, distributional perspectives on the state-value function in RL have begun to gain popularity bellemare2017distributional . Note that the distributional perspective is distinct from Bayesian approaches to the RL problem: the former models the inherent variability of the returns from a state, not the agent's confidence in its prediction of the average return.

Until now, deep learning methods in RL have used multiple function approximators (typically a network with shared hidden layers) to fit a state value or state-action value distribution. For instance, bootstrappedDQN used K heads on the state-action value function for every available action and used them to model a distribution. In bayesianpol , a Bayesian framework was applied to the actor-critic architecture by fitting a Gaussian Process (GP) in place of the critic, hence allowing for a closed-form derivation of the update rules. More recently, bellemare2017distributional introduced C51, a distributional algorithm which aims to solve the RL problem by learning a categorical probability vector over returns. Unlike GANRL , which uses a generative network to learn the underlying transition model of the environment, we utilize a generative network to model the distribution of the Bellman updates.
In this work, we build on top of the aforementioned distributional RL methods and introduce a novel way to learn the state-value distribution. Inspired by the analogy between the actor-critic architecture and generative adversarial networks (GANs) connection_gan_actor_critic , we leverage the latter to implicitly represent the distribution of the Bellman target update through a generator/discriminator architecture. We show that, although sometimes volatile, our proposed algorithm is a viable alternative to what are now considered classical deep Q-networks (DQN). We aim to provide a unifying view of distributional RL through the minimization of the Earth-Mover distance, without explicitly using the distributional Bellman projection operator on the support of the values.

2 Related Work

2.1 Background

Multiple tasks in machine learning require finding an optimal behaviour in a given setting, i.e. solving the reinforcement learning problem. We proceed to formulate the task as follows.

Let $(\mathcal{X}, \mathcal{A}, R, P, \mathcal{X}_0)$ be a 5-tuple representing a Markov decision process (MDP), where $\mathcal{X}$ is the set of states, $\mathcal{A}$ the set of allowed actions, $R$ the (deterministic or stochastic) reward function, $P$ the environment transition probabilities and $\mathcal{X}_0$ the set of initial states. At a given time step $t$, an agent acts according to a policy $\pi(\cdot \mid x_t)$. The environment is characterized by its set of initial states $\mathcal{X}_0$, in which the agent starts, as well as by the transition model $P$, which encodes the mechanics of moving from one state to another. In order to compare states, we introduce the state value function $V^\pi$, which gives the expected sum of discounted rewards from a state, with discount factor $\gamma \in [0, 1)$. That is, $V^\pi(x) = \mathbb{E}_\pi\left[\sum_{t=0}^{\infty} \gamma^t R(x_t, a_t) \mid x_0 = x\right]$.
The reinforcement learning problem is two-fold: (1) given a fixed policy $\pi$, we would like to obtain the correct state value function $V^\pi$, and (2) we wish to find the optimal policy $\pi^*$, which yields the highest $V^\pi(x)$ for all states $x$ of the MDP or, equivalently, $\pi^* = \arg\max_\pi V^\pi$. The first task is known in the reinforcement learning literature as prediction and the second as control.
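In order to find the value function for each state, we need to solve the Bellman equations bellman1954theory :

$$V^\pi(x) = \sum_{a \in \mathcal{A}} \pi(a \mid x) \Big[ R(x, a) + \gamma \sum_{x' \in \mathcal{X}} P(x' \mid x, a)\, V^\pi(x') \Big] \qquad (1)$$

for $x \in \mathcal{X}$ and $\gamma \in [0, 1)$. If we define a state-action value function as $Q^\pi(x, a) = \mathbb{E}_\pi\left[\sum_{t=0}^{\infty} \gamma^t R(x_t, a_t) \mid x_0 = x, a_0 = a\right]$, then we can rewrite Eq.1 as:

$$Q^\pi(x, a) = R(x, a) + \gamma \sum_{x' \in \mathcal{X}} P(x' \mid x, a) \sum_{a' \in \mathcal{A}} \pi(a' \mid x')\, Q^\pi(x', a') \qquad (2)$$

for $(x, a) \in \mathcal{X} \times \mathcal{A}$.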
Both Eq.1 and Eq.2 admit a direct solution obtained by matrix inversion, e.g. $V^\pi = (I - \gamma P^\pi)^{-1} R^\pi$, where $P^\pi$ is the state-to-state transition matrix and $R^\pi$ the vector of expected rewards under $\pi$. However, this is only well-defined for finite-state MDPs and, as this computation has complexity $O(|\mathcal{X}|^{2.376})$ coppersmith1987matrix , it is only computationally feasible for small MDPs. Therefore, sample-based algorithms such as Temporal Difference (TD) learning sutton1988learning for prediction and SARSA or Q-learning rummery1994line ; watkins1992q for control are preferred for more general classes of environments.
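For a small finite MDP, the matrix-inversion solution and repeated application of Eq.1 can be compared directly. The sketch below is a minimal NumPy illustration, assuming tabular arrays `P[x, a, x']`, `R[x, a]` and `pi[x, a]` (a hypothetical layout of our own, with made-up numbers); it uses `np.linalg.solve` rather than forming the inverse explicitly.

```python
import numpy as np


def evaluate_policy_exact(P, R, pi, gamma):
    """Solve V = R_pi + gamma * P_pi V directly (the matrix-inversion solution).

    P[x, a, y] = Pr(y | x, a), R[x, a] = expected reward, pi[x, a] = Pr(a | x).
    """
    P_pi = np.einsum("xa,xay->xy", pi, P)  # state-to-state transitions under pi
    R_pi = np.sum(pi * R, axis=1)          # expected one-step reward under pi
    n_states = P.shape[0]
    # Solving the linear system is cheaper and more stable than inverting (I - gamma P_pi).
    return np.linalg.solve(np.eye(n_states) - gamma * P_pi, R_pi)


def evaluate_policy_iterative(P, R, pi, gamma, tol=1e-10):
    """Apply the Bellman operator of Eq.1 until successive iterates stop changing."""
    V = np.zeros(P.shape[0])
    while True:
        Q = R + gamma * P @ V            # Q[x, a] = R(x, a) + gamma * sum_y P(y|x,a) V(y)
        V_new = np.sum(pi * Q, axis=1)   # average over actions with pi(a|x)
        if np.max(np.abs(V_new - V)) < tol:
            return V_new
        V = V_new


# Hypothetical 2-state, 2-action MDP; the numbers are for illustration only.
P = np.array([[[0.9, 0.1], [0.2, 0.8]],
              [[0.5, 0.5], [0.0, 1.0]]])
R = np.array([[1.0, 0.0], [0.0, 2.0]])
pi = np.array([[0.5, 0.5], [0.3, 0.7]])
gamma = 0.9

V_exact = evaluate_policy_exact(P, R, pi, gamma)
V_iter = evaluate_policy_iterative(P, R, pi, gamma)
print(np.allclose(V_exact, V_iter))  # True
```

Both routes need the full model $P$ and $R$, and the direct solve scales poorly with the number of states, which is why the sample-based methods mentioned above are preferred beyond small tabular problems.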

2.2 Distributional Reinforcement Learning

In the setting of distributional reinforcement learning bellemare2017distributional , we seek to learn the distribution of returns from a state, rather than the mean of the returns. We translate the Bellman operator on points to an operator on distributions. The vector of mean rewards therefore becomes a function of reward distributions $R(x, a)$. We can thus represent $Q^\pi(x, a)$ as $Z^\pi(x, a)$, a random variable whose law is the distribution of returns following the state-action pair $(x, a)$. Expected and distributional quantities are linked through the relation $Q^\pi(x, a) = \mathbb{E}\left[Z^\pi(x, a)\right]$. We finally arrive at the distributional Bellman equations:

$$Z^\pi(x, a) \overset{D}{=} \mathcal{T}^\pi Z^\pi(x, a) := R(x, a) + \gamma Z^\pi(X', A'), \quad X' \sim P(\cdot \mid x, a),\ A' \sim \pi(\cdot \mid X') \qquad (3)$$

where $\mathcal{T}^\pi$ denotes an analogue of the well-known Bellman operator, now defined over distributions. Eq.3 is the distributional counterpart of Eq.2, where $\overset{D}{=}$ denotes equality in distribution between random variables.
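Eq.3 can be illustrated by sampling: if we can draw rewards for a fixed $(x, a)$ and returns $Z^\pi(X', A')$ at the next step, then $R + \gamma Z^\pi(X', A')$ is a sample of $\mathcal{T}^\pi Z^\pi(x, a)$. The NumPy sketch below uses a made-up reward law and a made-up next-step return law (both hypothetical) to check that the target's mean matches the expected backup of Eq.2 while its spread is preserved.

```python
import numpy as np

rng = np.random.default_rng(0)


def distributional_target_samples(reward_sampler, next_return_samples, gamma):
    """Samples of T^pi Z(x, a) = R(x, a) + gamma * Z(X', A') for one fixed (x, a).

    reward_sampler(n) draws n rewards R(x, a); next_return_samples holds draws of
    Z(X', A') with X' ~ P(.|x, a) and A' ~ pi(.|X').
    """
    n = len(next_return_samples)
    return reward_sampler(n) + gamma * next_return_samples


gamma, n = 0.9, 200_000
# Hypothetical next-step return law: a 50/50 mixture of N(2, 1) and N(-1, 1).
goes_left = rng.random(n) < 0.5
z_next = np.where(goes_left, rng.normal(2.0, 1.0, n), rng.normal(-1.0, 1.0, n))
target = distributional_target_samples(lambda m: rng.normal(1.0, 0.5, m), z_next, gamma)

# E[T Z] = E[R] + gamma * E[Z'] = 1 + 0.9 * 0.5 = 1.45, i.e. the expected backup of Eq.2,
# but the samples also carry the variance that an expected-value method averages away.
print(target.mean(), target.std())
```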

In traditional reinforcement learning algorithms, we use experience in the MDP to improve an estimated state value function in order to minimize the expected distance between the value function’s output for a state and the actual returns from that state. In distributional reinforcement learning, we still aim to minimize the distance between our output and the true distribution, but now we have more freedom in how we choose our definition of "distance", as there are many metrics on probability distributions which have subtly different properties.

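The $p$-Wasserstein metric $d_p$ between two real-valued random variables $U$ and $V$ with cumulative distribution functions $F_U$ and $F_V$ is given by

$$d_p(U, V) = \left( \int_0^1 \left| F_U^{-1}(u) - F_V^{-1}(u) \right|^p \, du \right)^{1/p} \qquad (4)$$

For $p = 1$ we recover the widely known Earth-Mover distance. More generally, for any $1 \le p \le q$, $d_p \le d_q$ holds givens1984class , which is useful in contraction arguments. The maximal Wasserstein metric, $\bar{d}_p(Z_1, Z_2) = \sup_{x, a} d_p\big(Z_1(x, a), Z_2(x, a)\big)$, is defined over state-action tuples for any two value distributions $Z_1$ and $Z_2$.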
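For two empirical distributions with the same number of atoms, the quantile functions in Eq.4 are step functions, so $d_p$ reduces to comparing sorted samples. The NumPy sketch below (a helper of our own, not from the paper) computes this quantity on synthetic data and checks the ordering $d_1 \le d_2$.

```python
import numpy as np


def wasserstein_p(u_samples, v_samples, p=1.0):
    """Empirical d_p between two equally sized 1-D samples (Eq.4).

    With n atoms of mass 1/n each, F^{-1} is a step function, so the integral over
    [0, 1] becomes an average over matched order statistics.
    """
    u = np.sort(np.asarray(u_samples, dtype=float))
    v = np.sort(np.asarray(v_samples, dtype=float))
    if u.shape != v.shape:
        raise ValueError("this sketch assumes equally sized samples")
    return np.mean(np.abs(u - v) ** p) ** (1.0 / p)


rng = np.random.default_rng(1)
a = rng.normal(0.0, 1.0, 10_000)
b = rng.normal(0.5, 2.0, 10_000)

print(wasserstein_p(a, a + 3.0))        # approximately 3.0: a pure shift moves every atom by 3
d1, d2 = wasserstein_p(a, b, p=1), wasserstein_p(a, b, p=2)
print(d1 <= d2)                         # True, as d_p is nondecreasing in p
```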

It has been shown that while $\mathcal{T}^\pi$ is a $\gamma$-contraction under the maximal Wasserstein metric $\bar{d}_p$ ruschendorf1985wasserstein , the Bellman optimality operator $\mathcal{T}$ is not necessarily a contraction. This result implies that the control setting requires a different treatment from prediction, which bellemare2017distributional provide through the C51 algorithm in order to guarantee proper convergence.

2.3 Generative Models

Generative models such as hidden Markov models (HMMs), restricted Boltzmann machines (RBMs) salakhutdinov2007restricted , variational auto-encoders (VAEs) kingma2013auto and generative adversarial networks (GANs) goodfellow2014generative learn the distribution of the data for all classes. Moreover, they provide a mechanism for sampling new observations from the learned distribution.

In this work, we make use of generative adversarial networks, which consist of two neural networks playing a zero-sum game against each other. The generator network $G$ is a mapping from a high-dimensional noise space onto the input space on which a target distribution is defined. The generator's task is to fit the underlying distribution of the observed data as closely as possible. The discriminator network $D$ scores each input with the probability that it comes from the real data distribution $\mathbb{P}_r$ rather than from the generator $G$. Both networks are gradually improved through alternating or simultaneous gradient descent updates.

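The classical GAN algorithm minimizes the Jensen-Shannon (JS) divergence between the real and generated distributions. More recently, arjovsky2017wasserstein suggested replacing the JS divergence with the Wasserstein-1, or Earth-Mover, distance. We make use of an improved version of this algorithm, the Wasserstein GAN with Gradient Penalty (WGAN-GP) gulrajani2017improved . Its objective is given below:

$$L = \mathbb{E}_{\tilde{y} \sim \mathbb{P}_g}\left[D(\tilde{y})\right] - \mathbb{E}_{y \sim \mathbb{P}_r}\left[D(y)\right] + \lambda\, \mathbb{E}_{\hat{y} \sim \mathbb{P}_{\hat{y}}}\left[\left(\|\nabla_{\hat{y}} D(\hat{y})\|_2 - 1\right)^2\right] \qquad (5)$$

where $\tilde{y} = G(z)$ with $z \sim p(z)$, $\hat{y} = \epsilon y + (1 - \epsilon)\tilde{y}$, and $\epsilon \sim U[0, 1]$. The discriminator minimizes $L$, while the generator minimizes $-\mathbb{E}_{z \sim p(z)}[D(G(z))]$. Setting $\lambda = 0$ recovers the original WGAN objective.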
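To make Eq.5 concrete for scalar samples such as returns, the sketch below evaluates the WGAN-GP critic loss with a tiny hypothetical one-hidden-layer critic whose input gradient is computed analytically, so that only NumPy is needed. A practical implementation would rely on an autodiff framework; the critic architecture and all function names here are our own assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)


def critic(y, params):
    """Hypothetical critic D(y) = w2 . tanh(w1 * y + b1) + b2 for a batch of scalars y."""
    w1, b1, w2, b2 = params
    h = np.tanh(np.outer(y, w1) + b1)      # (batch, hidden)
    return h @ w2 + b2                     # (batch,)


def critic_input_grad(y, params):
    """Analytic dD/dy; for scalar inputs its absolute value is the gradient norm."""
    w1, b1, w2, _ = params
    h = np.tanh(np.outer(y, w1) + b1)
    return ((1.0 - h ** 2) * w1) @ w2      # (batch,)


def wgan_gp_critic_loss(real, fake, params, lam):
    """Eq.5: E[D(fake)] - E[D(real)] + lam * E[(|dD/dy(y_hat)| - 1)^2]."""
    eps = rng.uniform(0.0, 1.0, size=real.shape)   # epsilon ~ U[0, 1]
    y_hat = eps * real + (1.0 - eps) * fake        # random interpolates
    penalty = np.mean((np.abs(critic_input_grad(y_hat, params)) - 1.0) ** 2)
    return np.mean(critic(fake, params)) - np.mean(critic(real, params)) + lam * penalty


def generator_loss(fake, params):
    """The generator minimizes -E[D(G(z))]."""
    return -np.mean(critic(fake, params))


hidden = 16
params = (rng.normal(size=hidden), rng.normal(size=hidden),
          rng.normal(size=hidden) / hidden, 0.0)
real = rng.normal(1.0, 0.5, 256)   # stand-in for samples of the real distribution
fake = rng.normal(0.0, 0.5, 256)   # stand-in for generator samples
loss_gp = wgan_gp_critic_loss(real, fake, params, lam=10.0)
loss_wgan = wgan_gp_critic_loss(real, fake, params, lam=0.0)  # lambda = 0: plain WGAN critic loss
```

Because the penalty term is a mean of squares, the WGAN-GP loss can only exceed the plain WGAN loss; `lam` is left as an argument since the penalty strength is a tuning choice.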

3 GAN Q-learning

3.1 Motivation

We borrow the two-player game analogy from the original GAN paper goodfellow2014generative : the purpose of the generator network $G$ is to produce realistic samples of the optimal state-action value distribution ($G$ is an estimate of $Z^*$). On the other hand, the discriminator network $D$ aims to distinguish real samples of $\mathcal{T}Z$ from the samples output by $G$. The generator network improves its performance through the signal from the discriminator, which is reminiscent of the actor-critic architecture connection_gan_actor_critic .

3.2 Algorithm

At each timestep, $G$ receives stochastic noise $z \sim p(z)$ and a state $x_t$ as input and returns a sample for every action from the current estimate of $Z(x_t, \cdot)$. We then select the action $a_t = \arg\max_{a} G(z \mid x_t)_a$. The agent then applies the chosen action $a_t$, receives a reward $r_t$ and transitions to state $x_{t+1}$. The tuple $(x_t, a_t, r_t, x_{t+1})$ is saved into a replay buffer, as done in dqn . Each update step consists of sampling tuples uniformly from the buffer and updating the generator and discriminator according to Eq 5.
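The interaction loop described above needs two components: a replay buffer and greedy action selection over one generator sample per action. The sketch below assumes a generator callable `generator(z, x)` returning one value sample per action; the linear generator, the dimensions and the placeholder dynamics are hypothetical stand-ins, not the paper's network or environment, and the stored `done` flag is our own addition.

```python
from collections import deque

import numpy as np


class ReplayBuffer:
    """Fixed-capacity buffer of (x, a, r, x_next, done) tuples; oldest entries drop out first."""

    def __init__(self, capacity, rng):
        self.buffer = deque(maxlen=capacity)
        self.rng = rng

    def store(self, x, a, r, x_next, done):
        self.buffer.append((x, a, r, x_next, done))

    def sample(self, batch_size):
        # Uniform sampling without replacement, as in DQN-style experience replay.
        idx = self.rng.choice(len(self.buffer), size=batch_size, replace=False)
        return [self.buffer[int(i)] for i in idx]

    def __len__(self):
        return len(self.buffer)


def select_action(generator, x, noise_dim, rng):
    """a = argmax_a G(z | x)_a for one fresh noise draw z ~ N(0, I)."""
    z = rng.standard_normal(noise_dim)
    return int(np.argmax(generator(z, x)))   # generator returns one sample per action


rng = np.random.default_rng(0)
noise_dim, state_dim, n_actions = 4, 3, 2
W = rng.normal(size=(n_actions, noise_dim + state_dim))   # hypothetical linear "generator"


def generator(z, x):
    return W @ np.concatenate([z, x])


buffer = ReplayBuffer(capacity=40, rng=rng)
x = rng.normal(size=state_dim)
for t in range(50):
    a = select_action(generator, x, noise_dim, rng)
    r, x_next = float(a), rng.normal(size=state_dim)     # placeholder dynamics and reward
    buffer.store(x, a, r, x_next, done=False)
    x = x_next
batch = buffer.sample(32)
```

Because a fresh $z$ is drawn at every step, the same state can yield different greedy actions while the per-action distributions still overlap.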

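Values obtained from the Bellman backup operator are considered as coming from the real distribution. The discriminator's goal is to differentiate between the Bellman target and the output produced by $G$. We obtain the following updates for $D$ and $G$, respectively:

$$\theta_D \leftarrow \theta_D - \alpha \nabla_{\theta_D} \Big[ D\big(G(z \mid x)_a \mid x, a\big) - D\big(r + \gamma \max_{a'} G(z' \mid x')_{a'} \mid x, a\big) + \lambda \big(\|\nabla_{\hat{y}} D(\hat{y} \mid x, a)\|_2 - 1\big)^2 \Big]$$
$$\theta_G \leftarrow \theta_G + \alpha \nabla_{\theta_G} D\big(G(z \mid x)_a \mid x, a\big) \qquad (6)$$

where $\hat{y} = \epsilon \big(r + \gamma \max_{a'} G(z' \mid x')_{a'}\big) + (1 - \epsilon)\, G(z \mid x)_a$ with $\epsilon \sim U[0, 1]$, $\alpha$ is the learning rate, and $\theta_D$ and $\theta_G$ are the weights of the discriminator and generator networks, respectively, with respect to which the gradient updates are taken.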
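The two inputs to the discriminator in Eq.6 can be built from a replay minibatch as follows. This sketch assumes a generator callable `generator(z, x)` returning one sample per action (a hypothetical signature); the bootstrap term may instead be computed with a separate target network by passing that network as `target_generator`. Treating terminal transitions as bootstrapping from zero is our own assumption.

```python
import numpy as np


def bellman_targets(target_generator, rewards, next_states, dones, gamma, noise_dim, rng):
    """'Real' samples for the discriminator: y = r + gamma * max_a' G(z' | x')_a'.

    Terminal transitions bootstrap from 0, which is an assumption of this sketch.
    """
    ys = []
    for r, x_next, done in zip(rewards, next_states, dones):
        z_next = rng.standard_normal(noise_dim)               # z' ~ N(0, I)
        bootstrap = 0.0 if done else np.max(target_generator(z_next, x_next))
        ys.append(r + gamma * bootstrap)
    return np.array(ys)


def generated_samples(generator, states, actions, noise_dim, rng):
    """'Fake' samples for the discriminator: G(z | x)_a for the actions actually taken."""
    return np.array([generator(rng.standard_normal(noise_dim), x)[a]
                     for x, a in zip(states, actions)])


# Noise-free hypothetical generator so that the numbers can be checked by hand.
def toy_generator(z, x):
    return np.array([x.sum(), -x.sum()])


rng = np.random.default_rng(0)
states = [np.array([1.0, 2.0]), np.array([-1.0, 0.5])]
actions = [0, 1]
rewards = [1.0, 0.0]
next_states = [np.array([0.5, 0.5]), np.array([2.0, -4.0])]
dones = [False, True]

y_real = bellman_targets(toy_generator, rewards, next_states, dones, 0.9, 2, rng)
y_fake = generated_samples(toy_generator, states, actions, 2, rng)
# y_real should equal [1.9, 0.0] and y_fake [3.0, 0.5], up to floating-point rounding.
```

Note that the targets are treated as fixed data: gradients of the discriminator loss flow only into $\theta_D$, which is why a slowly updated target generator can be swapped in without changing the rest of the update.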

Note that to further stabilize the training process, one can use a target network updated periodically, as in dqn . Due to the nature of GANs, we advise training the model in a batch setting using experience replay and such a second (target) network.

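  Input: MDP $\mathcal{M}$, discriminator and generator networks $D$ and $G$, learning rate $\alpha$, the number $n_D$ of updates of the discriminator, the number $n_G$ of updates of the generator, gradient penalty coefficient $\lambda$, batch size $m$.
  Initialize replay buffer $\mathcal{B}$ to capacity $N$, and $D$ and $G$ with random weights $\theta_D$, $\theta_G$.
  for episode $= 1, 2, \dots$ do
     for $t = 1, \dots, T$ do
        sample a batch $z \sim p(z)$
        select $a_t = \arg\max_a G(z \mid x_t)_a$, execute it and observe $r_t$, $x_{t+1}$
        Store transition $(x_t, a_t, r_t, x_{t+1})$ in $\mathcal{B}$
        {Updating Discriminator}
        for $i = 1, \dots, n_D$ do
           Sample minibatch $\{(x_j, a_j, r_j, x'_j)\}_{j=1}^{m}$ from $\mathcal{B}$
           sample batch $\{z_j\}_{j=1}^{m}, \{z'_j\}_{j=1}^{m} \sim p(z)$
           sample a batch $\{\epsilon_j\}_{j=1}^{m} \sim U[0, 1]$
           $\theta_D \leftarrow$ RMSProp($\nabla_{\theta_D} L_D$, $\theta_D$, $\alpha$)
        end for
        {Updating Generator}
        for $i = 1, \dots, n_G$ do
           sample a batch of $\{z_j\}_{j=1}^{m} \sim p(z)$
           $\theta_G \leftarrow$ RMSProp($\nabla_{\theta_G} L_G$, $\theta_G$, $\alpha$)
        end for
     end for
  end for
Algorithm 1 GAN Q-learning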

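Here, the discriminator objective $L_D$ is identical to Eq.5, where $\mathbb{P}_r$ is taken to be the distribution of the Bellman target $r + \gamma \max_{a'} G(z' \mid x')_{a'}$, and $L_G = -D\big(G(z \mid x)_a \mid x, a\big)$ is the corresponding generator loss.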

Using a generative model to represent the state-action value distribution allows for an alternative to explicit exploration strategies. Indeed, at the beginning of the training process, taking the action $a = \arg\max_{a} G(z \mid x)_a$ is analogous to using a decaying exploration strategy, since $G$ has not yet clearly separated the distributions $Z(x, a)$ for every $(x, a)$ pair. Fig. 1 demonstrates how the distribution of the greedy action, $Z(x, \pi(x))$, gradually separates from those of the other actions; until this separation occurs, suboptimal actions are sampled, which acts as implicit exploration.
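The implicit-exploration argument can be checked with a toy model. If the generated values for each action are modelled as independent Gaussians (an assumption made only for this illustration, with hypothetical means and spreads), greedy selection on a single draw picks the better action about half the time while the distributions overlap, and almost always once they have separated.

```python
import numpy as np


def greedy_choice_frequency(means, stds, n_draws=100_000, seed=0):
    """Fraction of draws on which each action wins argmax_a G(z|x)_a,
    modelling each action's generated value as an independent Gaussian."""
    rng = np.random.default_rng(seed)
    samples = rng.normal(means, stds, size=(n_draws, len(means)))
    winners = np.argmax(samples, axis=1)
    return np.bincount(winners, minlength=len(means)) / n_draws


means = np.array([1.0, 0.8])  # hypothetical: action 0 has the higher expected return
early = greedy_choice_frequency(means, np.array([2.0, 2.0]))    # overlapping distributions
late = greedy_choice_frequency(means, np.array([0.05, 0.05]))   # well-separated distributions
print(early, late)  # roughly [0.53 0.47] and [0.998 0.002]
```

Shrinking the spread plays the role of a decaying exploration rate, without any explicit schedule.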

Figure 1: Evolution of the state value distribution learned by GAN Q-learning on the Two State environment. Here, $\pi$ denotes the greedy policy.

4 Convergence

It is well-known that Q-learning can exhibit divergent behaviour in the case of nonlinear function approximation tsitsiklis1997analysis . Because nonlinear value function approximation can be viewed as a special case of the GAN framework where the latent random variable is sampled from a degenerate distribution, we can see that the class of problems for which GAN Q-learning can fail to converge contains those for which vanilla Q-learning does not converge.

Further, as shown in mescheder2018convergence , convergence to the target distribution is not guaranteed for many popular GAN architectures, and oscillatory behaviour can be observed. This is a double blow in the reinforcement learning setting, as we must guarantee both that a stationary distribution exists and that the GAN architecture can converge to this stationary distribution.

We also note that, although in an idealized Wasserstein GAN setting the generator should be able to represent the target distribution and the discriminator should be able to learn any 1-Lipschitz function in order to produce the true Wasserstein distance, this is unlikely to be the case in practice. It is thus possible for the optimal generator-discriminator equilibrium to correspond to a generated distribution that has a different expected value from the target distribution. Consider, for example, a generator which produces a Dirac distribution $\delta_c$ and a discriminator restricted to a class $\mathcal{F}$ of quadratic functions. Then the discriminator attempts to approximate the Wasserstein distance by computing

$$\sup_{f \in \mathcal{F}} \; \mathbb{E}_{Y \sim \mathbb{P}_r}\left[f(Y)\right] - f(c).$$
Suppose we are in a 2-armed bandit setting, where arm A always returns a reward of for some small and arm B gives rewards distributed as a . Then the optimal generator (constrained to the class of Dirac distributions) will predict a Dirac distribution with support for arm A, and a Dirac distribution with support for arm B. Consequently, an agent which has reached an equilibrium will incorrectly estimate arm B as being the optimal arm.
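A numerical version of this failure mode can be built under assumptions of our own: the critic class is taken to be $\{f(y) = a y^2 : |a| \le 1\}$, arm A pays a deterministic 0.9 and arm B pays $\mathcal{N}(0, 1)$ rewards (all hypothetical choices, not the values of the example above). Against such a critic, the best Dirac location $c$ matches the second moment, $c^2 = \mathbb{E}[Y^2]$, rather than the mean.

```python
import numpy as np


def best_dirac_support(reward_samples, grid):
    """Dirac location c minimizing the restricted critic's estimate
    sup_{|a| <= 1} |a * (E[Y^2] - c^2)| = |E[Y^2] - c^2|."""
    second_moment = np.mean(np.asarray(reward_samples, dtype=float) ** 2)
    return grid[np.argmin(np.abs(second_moment - grid ** 2))]


rng = np.random.default_rng(0)
grid = np.linspace(0.0, 2.0, 2001)        # c >= 0 only; -c is an equally good fit
arm_a = np.full(10_000, 0.9)              # hypothetical deterministic arm
arm_b = rng.normal(0.0, 1.0, 10_000)      # hypothetical noisy arm with mean 0

c_a, c_b = best_dirac_support(arm_a, grid), best_dirac_support(arm_b, grid)
greedy_arm = "B" if c_b > c_a else "A"
true_best_arm = "B" if arm_b.mean() > arm_a.mean() else "A"
print(c_a, c_b, greedy_arm, true_best_arm)
```

Arm B's fitted Dirac sits near 1.0 although its mean is near 0, so the greedy agent prefers B; the symmetric fit $c = -1$ would instead undervalue B, and in neither case does the equilibrium recover the true expected value.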

Empirical results reported in the next section demonstrate the ability of our algorithm to solve complex reinforcement learning tasks. However, providing convergence results for nonlinear $G$ and $D$ is hindered by complex environment dynamics and the unstable nature of GANs. For instance, proving convergence of the generator-discriminator pair to a saddle point would require an argument similar to that of mescheder2018convergence .

5 Experiments

In order to compare the performance of the distributional approaches to traditional algorithms such as Q-learning, we conducted a series of experiments on tabular and continuous state space environments.

5.1 Environments

We considered the following environments:

  1. 10-state chain with two goal states (2G Chain) ( is in the middle). Two deterministic actions (left, right) are allowed. A reward of is given when we stay in the goal state for one step, otherwise. The discount factor and the maximum episode length is 50;

  2. Deterministic gridworld (Gridworld) with start and goal states in opposing corners and walls along the perimeter. A reward of 0 is given in the goal state, otherwise. The agent must reach the goal tile in the least number of steps while avoiding being stuck against walls. The discount factor and the maximum episode length is 100;

  3. The simple two state MDP (2 States) presented in Fig.2. The discount factor and the maximum episode length is 25.

    Figure 2: Simple two-state MDP.
  4. OpenAI Gym brockman2016openai environments Cartpole-v0 and Acrobot-v1.
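A minimal sketch of the 2G Chain environment from item 1, assuming states indexed 0 to 9 with goals at both ends and a start state at index 5. The reward values are exposed as hypothetical parameters `goal_reward` and `step_reward`, and "staying in the goal state for one step" is read as remaining in a goal state after an action; both are our assumptions.

```python
import numpy as np


class TwoGoalChain:
    """Sketch of the 10-state 2G Chain: goals at both ends, start in the middle."""
    LEFT, RIGHT = 0, 1

    def __init__(self, n_states=10, goal_reward=1.0, step_reward=0.0, max_steps=50):
        self.n_states, self.max_steps = n_states, max_steps
        self.goal_reward, self.step_reward = goal_reward, step_reward  # hypothetical values
        self.goals = {0, n_states - 1}

    def reset(self):
        self.x, self.t = self.n_states // 2, 0    # middle state (index 5 of 0..9)
        return self.x

    def step(self, action):
        move = -1 if action == self.LEFT else 1
        x_next = int(np.clip(self.x + move, 0, self.n_states - 1))
        # Reward only when the agent remains in a goal state for one step.
        stayed_in_goal = self.x in self.goals and x_next == self.x
        r = self.goal_reward if stayed_in_goal else self.step_reward
        self.x, self.t = x_next, self.t + 1
        return x_next, r, self.t >= self.max_steps


env = TwoGoalChain()
x = env.reset()
rewards = []
for _ in range(6):              # walk left into the goal, then stay there for one step
    x, r, done = env.step(TwoGoalChain.LEFT)
    rewards.append(r)
print(x, rewards)   # 0 [0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
```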

All experiments were conducted with a similar, one-hidden-layer architecture for GAN Q-learning and DQN: a total of 3 dense layers of 64 units each for tabular environments and 128 units each for OpenAI environments, together with nonlinear activations. Note that a Convolutional Neural Network (CNN) krizhevsky2012imagenet can be used to learn the rewards, similarly to dqn .

5.2 Results

In this section, we present empirical results obtained by our proposed algorithm. For shorthand notation, we abbreviate the tabular version of distributional Q-learning to dQ-learning, and GAN Q-learning to GAN-DQN. dQ-learning performs a mixture update between the predicted and target distributions for each state-action pair it visits, analogous to TD updates.
All tabular experiments were run with 10 random seeds, and the reported scores are averaged across 300 episodes. Gym environments were tested with 5 random seeds over 1000 episodes each, with no restriction on the maximum number of steps.
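One way to read the dQ-learning baseline is sketched below: each tabular return distribution is stored as a dictionary of atoms and probabilities, and a visit mixes the current distribution with the distributional Bellman target using step size $\alpha$. This representation, the greedy next-action choice by mean and the deterministic reward are our assumptions; the final check shows that the mean of the mixture follows the ordinary Q-learning update exactly, which is the sense in which the update is analogous to TD.

```python
def mean(dist):
    """Expected value of a discrete distribution stored as {atom: probability}."""
    return sum(v * p for v, p in dist.items())


def mixture(dist_a, dist_b, alpha):
    """(1 - alpha) * dist_a + alpha * dist_b, merging atoms that coincide."""
    out = {}
    for v, p in dist_a.items():
        out[v] = out.get(v, 0.0) + (1.0 - alpha) * p
    for v, p in dist_b.items():
        out[v] = out.get(v, 0.0) + alpha * p
    return out


def shifted_scaled(dist, r, gamma):
    """Law of r + gamma * Z' for a deterministic reward r."""
    out = {}
    for v, p in dist.items():
        out[r + gamma * v] = out.get(r + gamma * v, 0.0) + p
    return out


def dq_update(Z, x, a, r, x_next, gamma, alpha):
    """Mix Z(x, a) towards r + gamma * Z(x', a*), with a* greedy w.r.t. the means."""
    a_star = max(Z[x_next], key=lambda b: mean(Z[x_next][b]))
    Z[x][a] = mixture(Z[x][a], shifted_scaled(Z[x_next][a_star], r, gamma), alpha)


# Hypothetical two-state table Z[x][a] = {return atom: probability}.
Z = {0: {0: {0.0: 1.0}, 1: {0.0: 1.0}},
     1: {0: {2.0: 0.5, -1.0: 0.5}, 1: {1.0: 1.0}}}
gamma, alpha, r = 0.9, 0.1, 1.0
old_mean = mean(Z[0][1])
next_mean = mean(Z[1][1])          # action 1 is greedy in state 1 (mean 1.0 vs 0.5)
dq_update(Z, 0, 1, r, 1, gamma, alpha)
td_mean = (1 - alpha) * old_mean + alpha * (r + gamma * next_mean)
print(Z[0][1], mean(Z[0][1]), td_mean)
```

Unlike a projected categorical update, the number of atoms grows with each visit here, which is acceptable for a small tabular sketch but not for large problems.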

Figure 3: Average cumulative rewards obtained by distributional and expected agents in tabular environments.

In general, our GAN Q-learning results are on par with those obtained by tabular baseline algorithms such as Q-learning (see Fig. 3 and Table 1). The high variance in the first few episodes corresponds to the time needed for the generator network to separate the real and generated distributions for each action. Note that although the Gridworld environment yields sparse rewards, similarly to Acrobot, GAN Q-learning eventually finds the optimal path to the goal. Fig. 4 shows that our method can efficiently use the greedy policy in order to learn the state value function: lower state values are associated with the start state, while higher state values are attributed to both goal states.

Just like DQN, the proposed method demonstrates the ability to learn an optimal policy in both OpenAI environments. In the CartPole-v0 environment, for instance, increasing the number of updates for $D$ and $G$ stabilized the algorithm and ensured the proper convergence of $G$ to the Bellman target. Fig. 5 demonstrates that, although sometimes unstable, control with GAN Q-learning has the capacity to learn an optimal policy and outperform DQN. Due to sparse rewards in Acrobot-v1, we had to rely on a target network, as in dqn , to stabilize the training process. Using GAN Q-learning without a second network in such environments would lead to increased variance in the agent's predictions and is hence discouraged.

Algorithm   | 2 States | 2G Chain | Gridworld
Q-learning  | 426.613  | 0.953    | -8.059
dQ-learning | 427.328  | 0.950    | -8.190
GAN-DQN     | 398.918  | 0.978    | -11.720
Table 1: Performance of Q-learning algorithms in tabular environments (average reward per episode).

In addition to the common variance reduction practices mentioned above, we used a learning rate scheduler as a safeguard against instability during training. We found that a schedule depending on the timestep $t$, an initial learning rate $\alpha_0$ and a parameter varying between environments yielded the best results. Even though our model has an implicit exploration strategy induced by the generator, we used an $\epsilon$-greedy exploration policy like most traditional algorithms. We noticed that introducing $\epsilon$-greedy exploration helps the generator separate the state value distributions for all actions. Note that for environments with less complex dynamics (e.g. tabular MDPs), our method does not require an explicit exploration strategy and has shown viable performance without it.

Unlike in the original WGAN-GP paper, where a strong Lipschitz constraint is enforced via a gradient penalty coefficient ($\lambda = 10$), we observed empirically that relaxing this constraint with $\lambda = 0.1$ gave better results.

Figure 4: Expected state value function found by GAN Q-learning after 50 episodes in the Two Goal Chain environment for the greedy policy $\pi$.
Figure 5: Average cumulative rewards obtained by distributional and expected agents in OpenAI Gym environments.

6 Discussion

We introduced a novel framework based on techniques borrowed from the deep learning literature in order to successfully learn the state-action value distribution via an actor-critic-like architecture. Our experiments indicate that GAN Q-learning can be a viable alternative to classical algorithms such as DQN, while having the appealing characteristics of a typical deep learning black-box model. However, the flexibility gained by parametrizing the return distribution with a neural network is countered by the approach's volatility in environments with particularly sparse rewards. We believe that a thorough understanding of the nonlinear dynamics of generative networks and of the convergence properties of MDPs is necessary for successfully improving the algorithm. Recent work in the field suggests that a saddle-point analysis of the objective function is a valid way to approach such problems mescheder2018convergence .

Future work should prioritize the stability of the training procedure. Moreover, using a CNN on top of screen frames to encode the state could provide a meaningful approximation of the reward distribution. Our proposed algorithm opens possibilities to integrate the GAN architecture into more complex algorithms such as DDPG ddpg and TRPO trpo , which is a potential topic for future investigation.

7 Acknowledgments

We would like to thank Marc G. Bellemare from Google Brain for helpful advice throughout this paper.