rnn.wgan
Code for training and evaluation of the model from "Language Generation with Recurrent Generative Adversarial Networks without Pretraining"
Generative Adversarial Networks (GANs) have shown great promise recently in image generation. Training GANs for language generation has proven to be more difficult, because of the non-differentiable nature of generating text with recurrent neural networks. Consequently, past work has either resorted to pretraining with maximum likelihood or used convolutional networks for generation. In this work, we show that recurrent neural networks can be trained to generate text with GANs from scratch using curriculum learning, by slowly teaching the model to generate sequences of increasing and variable length. We empirically show that our approach vastly improves the quality of generated sequences compared to a convolutional baseline.
Generative adversarial networks (Goodfellow et al., 2014) have achieved state-of-the-art results in image generation (Goodfellow et al., 2014; Radford et al., 2015; Arjovsky et al., 2017; Gulrajani et al., 2017). For text generation, training GANs with recurrent neural networks (RNNs) has been more challenging, mostly due to the non-differentiable nature of generating discrete symbols. Consequently, past work on using GANs for text generation has been based on pretraining (Yu et al., 2016; Li et al., 2017; Yang et al., 2017; Wu et al., 2017; Liang et al., 2017; Zhang et al., 2016; Shetty et al., 2017) or joint training (Lamb et al., 2016; Che et al., 2017) of the generator and discriminator with a supervised maximum-likelihood loss.

Recently, two initial attempts to generate text using purely generative adversarial training were conducted by Gulrajani et al. (2017) and Hjelm et al. (2017). In these works, a convolutional neural network (CNN) was trained to produce sequences of 32 characters. This CNN architecture is fully differentiable, and the authors demonstrated that it generates text at a reasonable level. However, the generated text was still filled with spelling errors and had little coherence. RNNs are a more natural architecture for language generation, since they condition each generated character on the entire history, and are not constrained to generating a fixed number of characters.
In this paper, we extend the setup of Gulrajani et al. (2017) and present a method for generating text with GANs. Our main contribution is a model that employs an RNN for both the generator and discriminator, similar to current state-of-the-art approaches for language generation (Sutskever et al., 2011; Mikolov, 2012; Jozefowicz et al., 2016). We succeed in training the model by using curriculum learning (Elman, 1993; Bengio et al., 2009; Ranzato et al., 2015): at each stage we increase the maximal length of generated sequences, and train over sequences of variable length that are shorter than that maximal length. In addition, we aid the model by feeding it with ground truth characters before generation. We show that these methods vastly improve the quality of generated sequences. Sequences contain substantially more n-grams from a development set compared to those generated by a CNN, and generation generalizes to sequences that are longer than the sequences the model was trained on.
While models trained with a maximum-likelihood objective (ML) have shown success in language generation (Sutskever et al., 2011; Mikolov, 2012; Jozefowicz et al., 2016), there are drawbacks to using ML that suggest training with GANs. First, ML suffers from "exposure bias": at training time the model is exposed to gold data only, but at test time it observes its own predictions, so wrong predictions quickly accumulate, resulting in bad text generation.

Secondly, the ML loss function is very stringent. When training with ML, the model aims to allocate all probability mass to the $t$-th character of the training set given the previous characters, and considers any deviation from the gold sequence as incorrect, although there are many possible sequences given a certain prefix. GANs suffer less from this problem, because the objective is to fool the discriminator, and thus the objective evolves dynamically as training unfolds. While at the beginning the generator might only generate sequences of random letters with spaces, as the discriminator learns to better discriminate, the generator will evolve to generate words and after that it may advance to longer, more coherent sequences of text. This interplay between the discriminator and generator helps incremental learning of text generation.

Gulrajani et al. (2017) and Hjelm et al. (2017) trained a purely generative adversarial model (without pretraining) for character-level sentence generation. We briefly review the setup of Gulrajani et al. (2017), who use the Improved Wasserstein GAN objective (Arjovsky et al., 2017; Gulrajani et al., 2017), which we employ as well. Hjelm et al. (2017) have a similar setup, but employ the Boundary-Seeking GAN objective.
The generator $G$ in Gulrajani et al. (2017) is a CNN that transforms a noise vector $z$ into a matrix $M \in \mathbb{R}^{32 \times V}$, where $V$ is the size of the character vocabulary, and 32 is the length of the generated text. In this matrix the $i$-th row is a probability distribution over characters that represents a prediction for the $i$-th output in the character sequence. To decode a sequence, they choose the highest probability character in each row. The discriminator $D$ is another CNN that receives a matrix as input and needs to determine if this matrix is the output of the generator or sampled from the real data (where each row in the matrix now is a one-hot vector). The loss of the Improved WGAN generator is:

$$L_G = -\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[D(\tilde{x})\right]$$

and the loss of the discriminator is:

$$L_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[D(\tilde{x})\right] - \mathbb{E}_{x \sim \mathbb{P}_r}\left[D(x)\right] + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]$$

where $\mathbb{P}_r$ is the data distribution and $\mathbb{P}_g$ is the generator distribution implicitly defined by $\tilde{x} = G(z)$. The last term of the objective controls the complexity of the discriminator function and penalizes functions that have high gradient norm, that is, change too rapidly. $\mathbb{P}_{\hat{x}}$ is defined by sampling uniformly along a straight line between a point sampled from the data distribution and a point sampled from the generator distribution.
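As an illustration only, the scalar part of this objective can be sketched in plain Python, assuming the standard Improved WGAN form: the generator loss is the negative mean discriminator score on generated samples, and the discriminator loss is the score gap plus a penalty on gradient norms that deviate from 1. The function names, the `lam=10.0` default, and the choice to pass gradient norms in as precomputed numbers (a real implementation would obtain them from an autodiff library) are assumptions of this sketch, not details taken from the paper's code.

```python
def interpolate(real, fake, eps):
    """Point on the straight line between a real and a generated sample.

    eps is expected to be drawn uniformly from [0, 1] by the caller.
    """
    return [eps * r + (1.0 - eps) * f for r, f in zip(real, fake)]


def generator_loss(fake_scores):
    """Negative mean discriminator score on generated samples."""
    return -sum(fake_scores) / len(fake_scores)


def discriminator_loss(real_scores, fake_scores, grad_norms, lam=10.0):
    """Score gap plus gradient penalty.

    grad_norms: L2 norms of the discriminator's gradient at the
    interpolated points (assumed to be computed elsewhere).
    lam: penalty weight; 10.0 is an assumed default.
    """
    mean_fake = sum(fake_scores) / len(fake_scores)
    mean_real = sum(real_scores) / len(real_scores)
    penalty = sum((g - 1.0) ** 2 for g in grad_norms) / len(grad_norms)
    return mean_fake - mean_real + lam * penalty
```

The penalty vanishes when every gradient norm equals 1, so in that case the discriminator loss reduces to the plain difference of mean scores.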
A disadvantage of the generators in Gulrajani et al. (2017) and Hjelm et al. (2017) is that they use CNNs for generation, and thus the $i$-th generated character is not directly conditioned on the entire history of generated characters. This might be a factor in the frequent spelling mistakes and lack of coherence in the output of these models. We now present a model for language generation with GANs that utilizes RNNs, which are state-of-the-art in language generation.
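Table 1: %INTEST of word n-grams for the baseline model of Gulrajani et al. (2017).

%INTEST-1   %INTEST-2   %INTEST-3   %INTEST-4
64.4        25.9        5.1         0.4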
We employ a GRU (Cho et al., 2014) based RNN for our generator and discriminator. The generator is initialized by feeding it with a noise vector $z$ as the hidden state, and an embedded start-of-sequence symbol as input. The generator then generates a sequence of distributions over characters, using a softmax layer over the hidden state at each time step.

Because we want to have a fully differentiable generator, the input to the RNN generator at each time step is not the most probable character from the previous time step. Instead we employ a continuous relaxation, and provide at time step $t$ the weighted average representation given by the output distribution of step $t-1$. More formally, let $p_{t-1}(c)$ be the probability of generating the character $c$ computed at time step $t-1$, and let $E(c)$ be the embedding of the character $c$; then the input to the GRU at time step $t$ is $\sum_c p_{t-1}(c) \, E(c)$. This is fully differentiable compared to feeding the embedding of the most probable character, $E(\arg\max_c p_{t-1}(c))$. We empirically observe that the RNN quickly learns to output very skewed distributions.
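The relaxation described above can be sketched in a few lines of plain Python. This is a toy illustration, not the paper's implementation: the three-character vocabulary, the two-dimensional embedding values, and the function names are all made up for the example.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_embedding(probs, embeddings):
    """Probability-weighted average of the character embeddings."""
    dim = len(embeddings[0])
    return [sum(p * emb[d] for p, emb in zip(probs, embeddings))
            for d in range(dim)]


def hard_embedding(probs, embeddings):
    """Embedding of the most probable character (argmax, not differentiable)."""
    best = max(range(len(probs)), key=lambda c: probs[c])
    return embeddings[best]


# Toy vocabulary of 3 characters with 2-dimensional embeddings (made-up values).
E = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
probs = softmax([2.0, 0.5, 0.1])
next_input = soft_embedding(probs, E)  # would be fed to the GRU at the next step
```

When the distribution is one-hot, the weighted average coincides with the embedding of that character, which is consistent with the observation that skewed output distributions make the relaxed input close to a discrete one.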
The discriminator is another GRU that receives a sequence of character distributions as input, either one-hot vectors (for real data) or softer distributions (for generated data). Character embeddings are computed from the distributions and fed into the GRU. The discriminator then takes the final hidden state and feeds it into a fully connected layer which outputs a single number, representing the score that the discriminator assigns to the input. The models are trained with the aforementioned Improved WGAN objective (Section 3).

An advantage of a recurrent generator compared to the convolutional generator of Gulrajani et al. (2017) and Hjelm et al. (2017) is that it can output sequences of varying lengths, as we empirically show in Section 5.
Our baseline model trains the generator and discriminator over sequences of length 32, similar to how CNNs were trained in Gulrajani et al. (2017). We found that training this baseline was difficult and resulted in nonsensical text. We now present three extensions that stabilize the training process.
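Table 2: %INTEST of word n-grams for our models with different combinations of extensions (CL: Curriculum Learning, VL: Variable Length, TH: Teacher Helping).

CL   VL   TH   %INTEST-1   %INTEST-2   %INTEST-3   %INTEST-4
✗    ✗    ✗    28.8        3.7         0.0         0.0
✗    ✓    ✗    80.6        8.6         0.0         0.0
✓    ✗    ✗    27.0        7.9         2.0         0.0
✓    ✓    ✗    68.1        24.5        4.4         0.5
✗    ✓    ✓    79.4        44.6        11.5        0.7
✓    ✓    ✓    87.7        54.1        19.2        3.8
✓    ✓    ✓    87.5        51.3        15.1        1.7

The last row reports sequences of length 64. Examples in Table 3.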
Curriculum Learning (CL). In this extension, we start by training on short sequences and then slowly increase sequence length. In the first training stage, the generator generates sequences of length 1, and the discriminator receives real and generated sequences of length 1 as input. Then, the generator generates sequences of length 2 and the discriminator receives sequences of length 2. We increase sequence length in this manner until the maximum length of 32 characters.
Variable Length (VL). Here, we define a maximum length $l$, and generate during training sequences of every length $\leq l$ in every batch. Without curriculum learning, this amounts to training $G$ and $D$ in every batch with sequences of length $i$, $1 \leq i \leq 32$. With curriculum learning, we generate at each step sequences of length $i$, $1 \leq i \leq l$, and slowly increase $l$ throughout training.
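How curriculum learning and variable-length training combine can be sketched as a schedule of sequence lengths per training step. The sketch below assumes the maximum length starts at 1 and grows by one character per stage; `steps_per_stage` and both function names are hypothetical and chosen only for illustration.

```python
def batch_lengths(max_len, variable_length):
    """Sequence lengths trained on during one step.

    With variable length, every length from 1 up to max_len is used;
    otherwise only sequences of exactly max_len are used.
    """
    if variable_length:
        return list(range(1, max_len + 1))
    return [max_len]


def curriculum(final_len=32, steps_per_stage=2, variable_length=True):
    """Yield (stage_max_len, lengths) for every training step.

    Assumes the stage maximum starts at 1 and increases by one per stage;
    steps_per_stage is a hypothetical knob, not a value from the paper.
    """
    for stage_max in range(1, final_len + 1):
        for _ in range(steps_per_stage):
            yield stage_max, batch_lengths(stage_max, variable_length)


# Small example: maximum length 3, one step per stage.
schedule = list(curriculum(final_len=3, steps_per_stage=1))
```

Setting `variable_length=False` gives curriculum learning alone, while calling `batch_lengths(32, True)` at every step corresponds to variable-length training without a curriculum.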
Teacher Helping (TH). Finally, we propose a procedure where we help the generator learn to generate long sequences by conditioning on shorter ground truth sequences. Recall that in our baseline, the generator generates an entire sequence of characters that are fed as input to the discriminator. Here, when generating sequences of length $i$, we feed the generator a sequence of $i-1$ characters, sampled from the real data. Then, the generator generates a distribution over characters for the final character, which we concatenate to the real characters and feed as input to the discriminator. The discriminator observes a sequence of length $i$ composed of $i-1$ real characters and one character that is either real or generated. This could be viewed as a conditional GAN (Mirza and Osindero, 2014), where the first $i-1$ characters are the input and the final character is the output. Note that this extension may suffer from exposure bias, similar to the ML objective, and we plan to address this problem in future work.
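A minimal sketch of how the discriminator input might be assembled under this procedure: the ground-truth prefix is encoded as one-hot rows and the generator's distribution for the final character is appended as the last row. Character indices, the vocabulary size, and the function names are illustrative assumptions.

```python
def one_hot(index, vocab_size):
    """One-hot row for a character index."""
    return [1.0 if i == index else 0.0 for i in range(vocab_size)]


def teacher_helping_input(real_prefix, last_char_probs, vocab_size):
    """Discriminator input for a generated sequence.

    real_prefix: indices of the ground-truth characters fed to the generator.
    last_char_probs: the generator's distribution over the final character.
    """
    rows = [one_hot(c, vocab_size) for c in real_prefix]
    rows.append(list(last_char_probs))
    return rows


def real_input(real_sequence, vocab_size):
    """Discriminator input for a fully real sequence (all rows one-hot)."""
    return [one_hot(c, vocab_size) for c in real_sequence]


# Toy example with a 3-character vocabulary.
fake = teacher_helping_input([0, 2], [0.1, 0.7, 0.2], vocab_size=3)
real = real_input([0, 2, 1], vocab_size=3)
```

Here `fake` and `real` differ only in their last row, so the discriminator has to judge the final character given a real prefix.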
To directly compare to Gulrajani et al. (2017), we follow their setup and train our models on the Billion Word dataset (Chelba et al., 2013). We evaluate by generating 640 sequences from each model and measuring %INTEST-$n$, that is, the proportion of word $n$-grams from generated sequences that also appear in a held-out test set. We evaluate these metrics for $n \in \{1, 2, 3, 4\}$. Our goal is to measure the extent to which the generator is able to generate real words with local coherence.
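One possible reading of this metric is sketched below. Whitespace tokenization and counting every generated n-gram occurrence (rather than unique n-grams) are assumptions of the sketch, since the text does not specify them; the function names are likewise hypothetical.

```python
def word_ngrams(text, n):
    """All word n-grams of a whitespace-tokenized string."""
    words = text.split()
    return [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]


def percent_in_test(generated, test_set, n):
    """Percentage of generated word n-grams that also occur in the test set."""
    test_ngrams = set()
    for line in test_set:
        test_ngrams.update(word_ngrams(line, n))
    produced = [g for line in generated for g in word_ngrams(line, n)]
    if not produced:
        return 0.0  # no n-grams of this order were generated
    hits = sum(1 for g in produced if g in test_ngrams)
    return 100.0 * hits / len(produced)


# Toy data, not taken from the Billion Word dataset.
generated = ["the cat sat", "a dog"]
held_out = ["the cat ran", "a dog sat"]
scores = {n: percent_in_test(generated, held_out, n) for n in (1, 2, 3, 4)}
```

In this toy example all generated words appear in the held-out lines, two of the three generated bigrams do, and no trigram does, which mirrors how the metric drops as $n$ grows.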
In contrast to Arjovsky et al. (2017) and Gulrajani et al. (2017), where the generator is trained once for every training iterations of the discriminator, we found that training the generator for iterations every training iterations of the discriminator resulted in superior performance. In addition, instead of using noise vectors sampled from the distribution as in Gulrajani et al. (2017), we sample noise vectors from the distribution, since we found this leads to a greater variance in the generated samples when using RNNs.
In all our experiments, we used single layer GRUs for both the discriminator and generator. The embedding dimension and hidden state dimension are both of size .
Following Gulrajani et al. (2017), we train all our models on sequences whose maximum length is 32 characters. Table 1 shows results of the baseline model of Gulrajani et al. (2017), and Table 2 presents results of our models with various combinations of extensions (Curriculum Learning, Variable Length, and Teacher Helping). Our best model combines all of the extensions and outperforms the baseline by a wide margin on all metrics.
The samples show that models that used both the Variable Length and Teacher Helping extensions performed better than those that did not. This is also backed by the empirical evaluation, which shows that 3.8% of the word 4-grams generated by the CL+VL+TH model also appear in the held-out test set. The weak performance of the curriculum learning model without the other extensions shows that curriculum learning by itself does not lead to better performance, and that training on variable lengths and with Teacher Helping is important. We note that curriculum learning did not perform well at generating sequences of length 32, but did perform well at generating sequences of shorter lengths earlier in the training process. For example, the model that used only curriculum learning had a %INTEST-1 of when it was trained on sequences of length . This decreased to when the model reached sequences of length , and continued decreasing until training stopped. This also shows the importance of Variable Length and Teacher Helping.
Finally, to check the ability of our models to generalize to longer sequences, we generated sequences of length 64 with our CL+VL+TH model, which was trained on sequences of up to 32 characters (Table 3). We then evaluated the generated text, and this evaluation shows that there is a small degradation in performance (Table 2).
We show for the first time an RNN trained with a GAN objective that learns to generate natural language from scratch. Moreover, we demonstrate that our model generalizes to sequences longer than the ones seen during training. In future work, we plan to apply these models to tasks such as image captioning and translation, comparing them to models trained with maximum likelihood.
Bengio, Y., Louradour, J., Collobert, R., and Weston, J. (2009). Curriculum learning. In Proceedings of the 26th Annual International Conference on Machine Learning. ACM, pages 41–48.