
GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution

by Matt J. Kusner, et al.

Generative Adversarial Networks (GANs) have limitations when the goal is to generate sequences of discrete elements. The reason for this is that samples from a distribution on discrete objects, such as the multinomial, are not differentiable with respect to the distribution parameters. This problem can be avoided by using the Gumbel-softmax distribution, which is a continuous approximation to a multinomial distribution parameterized in terms of the softmax function. In this work, we evaluate the performance of GANs based on recurrent neural networks with Gumbel-softmax output distributions in the task of generating sequences of discrete elements.





1 Introduction

Generative adversarial networks (GANs) are methods for generating synthetic data with statistical properties similar to those of the real data (goodfellow2014generative). In the GAN methodology, a discriminative neural network D is trained to distinguish whether a given data instance is synthetic or real, while a generative network G is jointly trained to confuse D by generating high-quality data. This approach has been very successful in computer vision tasks for generating samples of natural images (denton2015deep; dosovitskiy2016generating; radford2016).

GANs work by propagating gradients back from the discriminator D through the generated samples to the generator G. This is perfectly feasible when the generated data is continuous, such as in the image examples mentioned above. However, a lot of data exists in the form of sequences of discrete items: for example, text sentences (Bowman2016) or molecules encoded in the SMILES language (gomez2016automatic). In these cases, the discrete data is not differentiable and the backpropagated gradients are always zero.

Discrete data, encoded using a one-hot representation, can be sampled from a multinomial distribution with probabilities given by the output of a softmax function. The resulting sampling process is not differentiable. However, we can obtain a differentiable approximation by sampling from the Gumbel-softmax distribution. This distribution has been previously used to train variational autoencoders with discrete latent variables (jang2016categorical). Here, we propose to use it to train GANs on sequences of discrete tokens, and we evaluate its performance in this setting.

An alternative approach to training GANs on discrete sequences is described in yu2016seqgan. This method models the generation of the discrete sequence as a stochastic policy in reinforcement learning and bypasses the generator differentiation problem by directly performing policy gradient updates.
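The policy-gradient idea can be illustrated with a score-function (REINFORCE) estimator, which estimates the gradient of an expected reward without differentiating through the discrete samples. The sketch below is a generic single-token illustration; the reward function and sample count are illustrative choices of ours, not taken from yu2016seqgan:

```python
import numpy as np

def softmax(h):
    e = np.exp(h - h.max())
    return e / e.sum()

def reinforce_gradient(h, reward_fn, n_samples=2000, seed=0):
    """Estimate d E[reward(i)] / d h for i ~ Categorical(softmax(h)) with the
    score-function (REINFORCE) estimator: E[reward(i) * d log p(i)/d h].
    No gradient is propagated through the discrete sample i itself."""
    rng = np.random.default_rng(seed)
    p = softmax(h)
    grad = np.zeros_like(h)
    for _ in range(n_samples):
        i = rng.choice(len(p), p=p)
        score = -p.copy()          # d log p(i)/d h = one_hot(i) - p
        score[i] += 1.0
        grad += reward_fn(i) * score
    return grad / n_samples

# Toy reward: token 2 is "good"; the estimate should push h[2] upward.
grad_est = reinforce_gradient(np.zeros(4), lambda i: 1.0 if i == 2 else 0.0)
```

Note that the estimator only needs the reward of each sampled token, which is why it sidesteps the non-differentiability problem, at the cost of higher gradient variance.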

2 Gumbel-softmax distribution

The softmax function can be used to parameterize a multinomial distribution on a one-hot-encoded k-dimensional vector y in terms of a continuous k-dimensional vector h. Let p = (p_1, ..., p_k) be a k-dimensional vector of probabilities specifying the multinomial distribution on y, with p(y_i = 1) = p_i for i = 1, ..., k. Then

p = softmax(h),   (1)

where softmax returns here a k-dimensional vector with the output of the softmax function:

softmax(h)_i = exp(h_i) / [ exp(h_1) + ... + exp(h_k) ],   i = 1, ..., k.   (2)

It can be shown that sampling y according to the previous multinomial distribution with probability vector given by (1) is the same as sampling according to

y = one_hot( argmax_i (h_i + g_i) ),   (3)

where the g_i are independent and follow a Gumbel distribution with zero location and unit scale.

The sample y generated in (3) has gradient zero with respect to h because the one_hot(argmax(.)) operator is not differentiable. We propose to approximate this operator with a differentiable function based on the softmax transformation (jang2016categorical). In particular, we approximate y with

y = softmax( (h + g) / tau ),   (4)

where tau is a temperature parameter. When tau -> 0, the samples generated by (4) have the same distribution as those generated by (3), and when tau -> infinity, the samples are always the uniform probability vector. For positive and finite values of tau, the samples generated by (4) are smooth and differentiable with respect to h.

The probability distribution for the samples in (4), which is parameterized by h and tau, is called the Gumbel-softmax distribution (jang2016categorical). A GAN on discrete data can then be trained by using (4), starting with some relatively large tau and then annealing it towards zero during training.
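As a concrete illustration, the exact sampling in (3) and its relaxation in (4) take only a few lines of numpy; this is a sketch of the technique, not the authors' code:

```python
import numpy as np

def sample_gumbel(shape, rng):
    """Gumbel(0, 1) samples via inverse transform: g = -log(-log(u))."""
    u = rng.uniform(1e-12, 1.0, size=shape)
    return -np.log(-np.log(u))

def gumbel_softmax(h, tau, rng):
    """The relaxation in (4): softmax((h + g) / tau), differentiable in h."""
    z = (h + sample_gumbel(h.shape, rng)) / tau
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
h = np.array([1.0, 2.0, 0.5])
y_soft = gumbel_softmax(h, tau=1e-3, rng=rng)  # near one-hot for small tau
```

For a small temperature the output concentrates on the coordinate that (3) would select, while a large temperature flattens it toward the uniform probability vector.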

3 A recurrent neural network for discrete sequences

In this section we describe how to construct a generative adversarial network (GAN) that is able to generate text from random noise samples. We also give a simple algorithm to train our model, inspired by recent work in adversarial modeling.

Figure 1: Models to generate simple one-variable arithmetic sequences. (Top): The classic LSTM model during the prediction phase. Each LSTM unit (shown as a blue box) makes a prediction based on the input it has seen in the past. This prediction is then used as input to the next unit, which makes its own prediction, and so on. (Bottom): Our generative model for discrete sequences. At the beginning we draw a pair of samples, which are fed into the network in place of the initial cell state and hidden state. Our trained network takes these samples and uses them to generate an initial character; this generated character is fed to the next cell in the LSTM as input, and so on.

An example

Consider the problem of learning to generate simple one-variable arithmetic sequences that can be described by the following context-free grammar:

S -> x | S + S | S - S | S * S | S / S

where | divides the possible productions of the grammar. The above grammar generates sequences of characters such as x+x and x/x*x.
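For concreteness, strings from this kind of one-variable arithmetic grammar can be sampled with a small recursive generator. The sketch below is ours; the expansion probability and depth limit are arbitrary choices that merely guarantee termination:

```python
import random

# Assumed grammar: S -> x | S+S | S-S | S*S | S/S
def sample_expr(depth=0, max_depth=4, rng=random):
    """Recursively expand S, stopping at the terminal x with probability
    0.5 (or always, once the depth limit is reached)."""
    if depth >= max_depth or rng.random() < 0.5:
        return "x"
    op = rng.choice("+-*/")
    left = sample_expr(depth + 1, max_depth, rng)
    right = sample_expr(depth + 1, max_depth, rng)
    return left + op + right

random.seed(0)
samples = [sample_expr() for _ in range(5)]
```

Every generated string starts and ends with the terminal x and alternates variables with operators, which is the structure the generative model must learn to imitate.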

Our generative model is based on a Long Short-Term Memory (LSTM) recurrent neural network (hochreiter1997long), shown in the top of Figure 1. The LSTM is trained to predict a hidden-state vector at every time-step (i.e., for every character). The softmax operator is then applied to this vector as in equations (2) and (1), which gives a distribution over all possible generated characters. After training, the network generates data by sampling from the softmax distribution at each time-step.

One way to train the LSTM model to predict future characters is by matching the softmax distribution to a one-hot encoding of the input data via maximum likelihood estimation (MLE). In this work, we are interested in constructing a generative model for discrete sequences, which we will accomplish by sampling through the LSTM, as shown in the bottom of Figure 1. Our generative model takes as input a sample-pair which effectively replaces the initial cell and hidden states. From this sample our generator constructs a sequence by successively feeding its predictions as input to the following LSTM unit. Our primary contribution is designing a method to train this generator to produce real-looking discrete sequences.
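The feedback loop at the bottom of Figure 1 can be sketched generically as follows. The recurrent `step` here is a hypothetical untrained stand-in for an LSTM cell, and the vocabulary is illustrative; only the feed-the-prediction-back structure is the point:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def generate(step, c0, h0, vocab, length, rng):
    """Autoregressive sampling: each predicted token is fed back as the
    next input, starting from sampled initial states (c0, h0)."""
    c, h = c0, h0
    x = np.zeros(len(vocab))            # empty initial input token
    out = []
    for _ in range(length):
        c, h, logits = step(x, c, h)    # one recurrent step
        i = rng.choice(len(vocab), p=softmax(logits))
        x = np.zeros(len(vocab))
        x[i] = 1.0                      # feed the prediction back as input
        out.append(vocab[i])
    return "".join(out)

# Stand-in recurrent step: an untrained random linear cell (not an LSTM).
rng = np.random.default_rng(0)
d, vocab = 8, "x+-*/ "
W = rng.normal(size=(d, d + len(vocab)))
V = rng.normal(size=(len(vocab), d))

def step(x, c, h):
    h_new = np.tanh(W @ np.concatenate([h, x]))
    return c, h_new, V @ h_new

text = generate(step, np.zeros(d), rng.normal(size=d), vocab, length=12, rng=rng)
```

Replacing the stand-in `step` with a trained LSTM cell yields exactly the generation procedure in the figure.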

Generative adversarial modeling

Given a set of data points independently and identically drawn from a d-dimensional distribution (in our case each is a one-hot encoding of a character), the goal of generative modeling is to learn a distribution that accurately approximates the data distribution. The framework of generative adversarial modeling has been shown to yield models that generate remarkably realistic data points. The adversarial training idea is straightforward. First, we learn a so-called generator G that transforms samples from a simple, known distribution (e.g., a uniform or Gaussian distribution) into samples that approximate those drawn from the data distribution. Specifically, we define generated points as G(z), where z is drawn from a d-dimensional uniform distribution on the unit interval. Second, to learn G we introduce a classifier we call the discriminator D. The discriminator takes as input any real-valued d-dimensional vector (this could be a generated input or a real one) and predicts the probability that the input is actually drawn from the real data distribution. It is trained to take generated samples and real inputs and accurately distinguish them. At the same time, the generator is trained so that it can fool the discriminator into thinking that a fake point it generated is real with high probability. Initially, the discriminator can easily tell the fake points from the real ones, and the generator is poor. However, as training progresses the generator uses this signal from the discriminator to determine how to generate more realistic samples. Eventually, the generator will generate samples so realistic that the discriminator has only a random chance of guessing whether a generated point is real.

Using the Gumbel-softmax distribution

In our case G and D are both LSTMs, each with its own parameters. Our aim is to learn both networks by sampling inputs and generated points, and minimizing differentiable loss functions for G and D to update their parameters. Unfortunately, sampling generated points from the softmax distribution given by the LSTM, eq. (1), is not differentiable with respect to the hidden states. However, the Gumbel-softmax distribution, eq. (4), is. Equipped with this trick we can take any differentiable loss function and optimize both networks using gradient-based techniques. We describe our adversarial training procedure in Algorithm 1, inspired by recent work on GANs (sonderby2016amortised). This algorithm can be shown, in expectation, to minimize the KL-divergence between the model distribution and the data distribution.

1:  data: ,
2:  Generative LSTM network
3:  Discriminative LSTM network
4:  while  loop until convergence  do
5:     Sample mini-batch of inputs
6:     Sample noise
7:     Update discriminator
8:     Update generator
9:  end while
Algorithm 1 Generative Adversarial Network sonderby2016amortised
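The alternating updates of Algorithm 1 have the following generic structure. The sketch below uses the standard GAN binary cross-entropy losses and placeholder callables for G, D, and the parameter updates; it is a structural illustration, not the authors' implementation:

```python
import numpy as np

def gan_training_loop(sample_real, sample_noise, G, D, update_D, update_G,
                      n_iters, batch_size):
    """Structure of Algorithm 1: on each iteration, draw a mini-batch of
    real inputs and noise, then update the discriminator and generator."""
    d_hist, g_hist = [], []
    for _ in range(n_iters):
        x = sample_real(batch_size)          # mini-batch of real inputs
        z = sample_noise(batch_size)         # mini-batch of noise
        x_fake = G(z)
        # Standard GAN binary cross-entropy losses.
        d_loss = -np.mean(np.log(D(x) + 1e-9) + np.log(1 - D(x_fake) + 1e-9))
        g_loss = -np.mean(np.log(D(x_fake) + 1e-9))
        update_D(d_loss)                     # placeholder gradient step on D
        update_G(g_loss)                     # placeholder gradient step on G
        d_hist.append(d_loss)
        g_hist.append(g_loss)
    return d_hist, g_hist

# Structural check with trivial stand-ins (not real networks):
d_hist, g_hist = gan_training_loop(
    sample_real=lambda n: np.zeros((n, 2)),
    sample_noise=lambda n: np.ones((n, 2)),
    G=lambda z: z,                           # placeholder generator
    D=lambda x: np.full(len(x), 0.5),        # placeholder discriminator
    update_D=lambda loss: None,
    update_G=lambda loss: None,
    n_iters=3, batch_size=4)
```

In the paper's setting, G and D are the LSTMs described above and the fake mini-batch is produced through the Gumbel-softmax relaxation of eq. (4), so both losses are differentiable end to end.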

Figure 2 shows a schematic of the adversarial training procedure for discrete sequences.

Figure 2: The adversarial training procedure. Our generative model first generates a full-length sequence. This sequence is fed to the discriminator (also an LSTM), which predicts the probability that it is a real sequence. Additionally (not shown), the discriminator is fed real discrete sequence data, for which it again predicts the probability of being real. The weights of the networks are modified to make the discriminator better at distinguishing real from fake data, and to make the generator better at fooling the discriminator.

4 Experiments

We now show the power of our adversarial modeling framework for generating discrete sequences. To illustrate this we consider modeling the context-free grammar introduced in Section 3. We generate samples up to a fixed maximum number of characters from the context-free grammar (CFG) for our training set, and pad all shorter sequences with spaces.

Optimization details

We train both the discriminator and generator using ADAM (kingma2014adam) with a fixed learning rate and mini-batch size. Inspired by the work of sonderby2016amortised, who use input noise to stabilize GAN training, for every input we form a vector such that its softmax (instead of being one-hot) places most of its probability mass on the correct character and spreads the small remainder over the other characters. We then apply the Gumbel-softmax trick to generate a vector as in equation (4), and use this vector instead of the one-hot input throughout training. We train the generator and discriminator for a fixed number of mini-batch iterations. During training we linearly anneal the temperature of the Gumbel-softmax distribution from a large value (i.e., a very flat distribution) to a smaller one (a more peaked distribution), after which it is kept fixed until training ends.
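The linear annealing schedule can be written as a small helper. The endpoint values below are illustrative placeholders, since the exact temperatures are elided in the text:

```python
def temperature(t, t_anneal, tau_start=5.0, tau_end=1.0):
    """Linearly anneal tau from tau_start to tau_end over t_anneal
    iterations, then hold it fixed at tau_end for the rest of training."""
    if t >= t_anneal:
        return tau_end
    frac = t / t_anneal
    return tau_start + frac * (tau_end - tau_start)
```

At each mini-batch iteration t, the Gumbel-softmax samples in eq. (4) are drawn with `temperature(t, t_anneal)`, so early training sees flat, easy-to-differentiate distributions and later training sees nearly discrete ones.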

Figure 3: The generative and discriminative losses throughout training. Ideally the loss of the discriminator should increase while that of the generator should decrease as the generator becomes better at mimicking the real data. (a) The default network with Gumbel-softmax temperature annealing. (b) The same setting as (a) but with an increased number of generated samples. (c) Only varying the input vector temperature. (d) Only introducing random noise into the hidden state and not the cell state.

Learning a CFG

Figure 3 (a) shows the generator and discriminator losses throughout training for this setting. We experimented with increasing the number of generated samples, as this has been reported to improve GAN modeling (huszar2015not), shown in Figure 3 (b). We also experimented with varying only the temperature for the input vectors while fixing the generator temperature (Figure 3 (c)). Finally, we also tried introducing random noise only into the hidden state while allowing the network to learn an initial cell state (Figure 3 (d)).

Figure 4: The generated text for MLE and GAN models. The plots (a)-(d) correspond to the models of Figure 3.

Figure 4 shows the text generated by the MLE and GAN models. Each row is a sample from one of the models, each consisting of a fixed number of characters (we have included the blank space character, as some training inputs are padded with spaces when shorter than the maximum length). While the MLE LSTM is not strictly a generative model in the sense of drawing a discrete sequence from a distribution, we include it for reference. We can see that our GAN models are learning to generate alternating sequences of variables and operators, similar to the MLE result. Specifically, the 4th, 10th, and 17th rows of plot (a) show samples that are very close to the training data, and many such examples exist for the remaining plots as well.

We believe that these results, as a proof of concept, show strong promise for training GANs to generate discrete sequence data. Further, we believe that incorporating recent advances in GANs, such as training GANs using variational divergence minimization (nowozin2016f) or via density ratio estimation (uehara2016generative), could yield further improvements. We aim to experiment with these in future work.