Theory and Experiments on Vector Quantized Autoencoders

05/28/2018 · Aurko Roy et al. · Google

Deep neural networks with discrete latent variables offer the promise of better symbolic reasoning and of learning abstractions that transfer more readily to new tasks. There has been a surge of interest in discrete latent variable models; however, despite several recent improvements, training them has remained challenging and their performance has mostly failed to match that of their continuous counterparts. Recent work on vector quantized autoencoders (VQ-VAE) has made substantial progress in this direction, with its perplexity almost matching that of a VAE on datasets such as CIFAR-10. In this work, we investigate an alternate training technique for VQ-VAE, inspired by its connection to the Expectation Maximization (EM) algorithm. Training the discrete bottleneck with EM helps us achieve better image generation results on CIFAR-10, and together with knowledge distillation, allows us to develop a non-autoregressive machine translation model whose accuracy almost matches a strong greedy autoregressive baseline Transformer, while being 3.3 times faster at inference.


1 Introduction

Unsupervised learning of meaningful representations is a fundamental problem in machine learning, since obtaining labeled data can often be very expensive. Continuous representations have largely been the workhorse of unsupervised deep learning models of images goodfellow2014generative ; van2016conditional ; kingma2016improved ; salimans2017pixelcnn++ ; imagetrans , audio van2016wavenet ; reed2017parallel , and video kalchbrenner2016video . However, it is often the case that datasets are more naturally modeled as a sequence of discrete symbols rather than continuous ones. For example, language and speech are inherently discrete in nature, and images are often concisely described by language, see e.g., vinyals2015show . Improved discrete latent variable models could also prove useful for learning novel data compression algorithms theis2017lossy , while providing far more interpretable representations of the data.

We build on the Vector Quantized Variational Autoencoder (VQ-VAE) vqvae , a recently proposed training technique for learning discrete latent variables. The method uses a learned code-book combined with nearest neighbor search to train the discrete latent variable model. The nearest neighbor search is performed between the encoder output and the embedding of the latent code using the $\ell_2$ distance metric. The generative process begins by sampling a sequence of discrete latent codes from an autoregressive model fitted on the encoder latents, acting as a learned prior. The discrete latent sequence is then consumed by the decoder to generate data. The resulting discrete autoencoder obtains impressive results on unconditional image, speech, and video generation. In particular, on image generation the performance is almost on par with continuous VAEs on datasets such as CIFAR-10 vqvae . An extension of this method to conditional supervised generation outperforms continuous autoencoders on the WMT English-German translation task kaiser2018fast .

kaiser2018fast introduced the Latent Transformer, which achieved impressive results using discrete autoencoders for fast neural machine translation. However, additional training heuristics, namely exponential moving averages (EMA) of cluster assignment counts and product quantization norouzi2013cartesian , were essential for VQ-VAE to achieve competitive results. In this work, we show that tuning the code-book size can significantly outperform the results presented in kaiser2018fast . We also exploit VQ-VAE's connection with the expectation maximization (EM) algorithm dempster1977maximum , yielding additional improvements. With both improvements, we achieve a BLEU score of 22.4 on English-German translation, outperforming kaiser2018fast by 2.6 BLEU. Knowledge distillation hinton2015distilling ; seq-d provides significant gains with our best models and EM, achieving 26.7 BLEU, which almost matches the autoregressive transformer model with no beam search at 27.0 BLEU, while being 3.3 times faster.

Our contributions can be summarized as follows:

  1. We show that VQ-VAE from vqvae can outperform the previous state of the art without product quantization.

  2. Inspired by the EM algorithm, we introduce a new training algorithm for discrete variational autoencoders that outperforms the previous best result with discrete latent autoencoders for neural machine translation.

  3. Using EM training, we achieve better image generation results on CIFAR-10; together with knowledge distillation, it allows us to develop a non-autoregressive machine translation model whose accuracy almost matches a strong greedy autoregressive baseline Transformer, while being 3.3 times faster at inference.

2 VQ-VAE and the Hard EM Algorithm

Figure 1: VQ-VAE model as described in vqvae . We use $x$ to denote the input image, with the output of the encoder $z_e(x)$ being used to perform nearest neighbor search to select the (sequence of) discrete latent variables. The selected discrete latent is used to train the latent predictor model, while the embedding of the selected discrete latent is passed as input to the decoder.

The connection between $K$-means and hard EM, or the Viterbi EM algorithm, is well known bottou1995convergence , where the former can be seen as a special case of a hard-EM style algorithm applied to a mixture-of-Gaussians model with identity covariance and a uniform prior over cluster probabilities. In the following sections we briefly explain the VQ-VAE discretization algorithm for completeness, along with its connection to classical EM.

2.1 VQ-VAE discretization algorithm

VQ-VAE models the joint distribution $P_\theta(x, z)$, where $\theta$ are the model parameters, $x$ is the data point and $z$ is the sequence of discrete latent variables or codes. Each position in the encoded sequence has its own set of latent codes. Given a data point, the discrete latent code in each position is selected independently using the encoder output. For simplicity, we describe the procedure for selecting the discrete latent code ($z$) in one position given the data point ($x$). The encoder output $z_e(x) \in \mathbb{R}^D$ is passed through a discretization bottleneck using a nearest-neighbor lookup on embedding vectors $e \in \mathbb{R}^{K \times D}$. Here $K$ is the number of latent codes (in a particular position of the discrete latent sequence) in the model. More specifically, the discrete latent variable assignment is given by,

$$z(x) = \arg\min_{1 \le i \le K} \lVert z_e(x) - e_i \rVert_2 \qquad (1)$$

The selected latent variable's embedding $e_{z(x)}$ is passed as input to the decoder.

The model is trained to minimize:

$$L = l_r\big(x, d(e_{z(x)})\big) + \beta \left\lVert z_e(x) - \mathrm{sg}\!\left(e_{z(x)}\right) \right\rVert_2^2 \qquad (2)$$

where $l_r$ is the reconstruction loss of the decoder $d(\cdot)$ (e.g., the cross entropy loss), and $\mathrm{sg}(\cdot)$ is the stop-gradient operator, which acts as the identity in the forward pass and has zero partial derivatives in the backward pass:

$$\mathrm{sg}(x) = x, \qquad \nabla\, \mathrm{sg}(x) = 0.$$
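To make the discretization bottleneck concrete, below is a minimal NumPy sketch of the nearest-neighbor lookup of Equation (1) and the regularization term of Equation (2). The function names, the batched layout, and the weight `beta = 0.25` are illustrative assumptions rather than the paper's implementation; in an autodiff framework the stop-gradient would wrap the code embedding, whereas here we only compute values.

```python
import numpy as np

def quantize(z_e, codebook):
    """Nearest-neighbor lookup of Eq. (1): assign each encoder output to a code.

    z_e:      [n, d] encoder outputs (one row per latent position)
    codebook: [K, d] embedding vectors e_1, ..., e_K
    returns:  code indices [n] and the selected embeddings [n, d]
    """
    # Squared L2 distance between every encoder output and every code.
    dists = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # [n, K]
    idx = dists.argmin(axis=1)
    return idx, codebook[idx]

def commitment_term(z_e, z_q, beta=0.25):
    """Second term of Eq. (2): pull the encoder output toward the selected code."""
    return beta * ((z_e - z_q) ** 2).sum(-1).mean()

# Toy usage with made-up sizes.
rng = np.random.default_rng(0)
z_e = rng.normal(size=(16, 8))       # 16 latent positions, dimension d = 8
codebook = rng.normal(size=(4, 8))   # K = 4 codes
idx, z_q = quantize(z_e, codebook)
print(idx[:5], commitment_term(z_e, z_q))
```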

It was observed in kaiser2018fast that an exponential moving average (EMA) update of the latent embeddings and code-book assignments results in more stable training than using gradient-based methods.

Specifically, they maintain an EMA of the following two quantities: 1) the embeddings $e_j$ for every $j \in \{1, \dots, K\}$ and 2) the counts $c_j$ measuring the number of encoder hidden states that have $e_j$ as their nearest neighbor. The counts are updated over a mini-batch of targets as:

$$c_j \leftarrow \lambda c_j + (1 - \lambda) \sum_{i} \mathbb{1}\left[ z(x_i) = j \right] \qquad (3)$$

with the embedding being subsequently updated as:

$$e_j \leftarrow \lambda e_j + (1 - \lambda) \sum_{i} \frac{\mathbb{1}\left[ z(x_i) = j \right] \, z_e(x_i)}{c_j} \qquad (4)$$

where $\mathbb{1}[\cdot]$ is the indicator function and $\lambda$ is a decay parameter which we keep fixed in our experiments. This amounts to doing stochastic gradient descent in the space of both code-book embeddings and cluster assignments. These techniques have also been successfully used in minibatch $K$-means sculley2010web and online EM liang2009online ; sato2000line .
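The following sketch mirrors the EMA updates of Equations (3) and (4) for a mini-batch, written in NumPy. The decay value and the small epsilon guarding against empty clusters are placeholder choices, not the settings used in the paper.

```python
import numpy as np

def ema_codebook_update(codebook, counts, z_e, assignments, decay=0.99, eps=1e-6):
    """EMA update of cluster counts (Eq. 3) and code embeddings (Eq. 4).

    codebook:    [K, d] current embeddings e_j
    counts:      [K]    EMA counts c_j
    z_e:         [n, d] encoder outputs in the mini-batch
    assignments: [n]    nearest-neighbor indices from Eq. (1)
    """
    K = codebook.shape[0]
    one_hot = np.eye(K)[assignments]        # indicator 1[z(x_i) = j], shape [n, K]
    batch_counts = one_hot.sum(axis=0)      # how many encoder states chose each code
    batch_sums = one_hot.T @ z_e            # [K, d] sums of the assigned encoder outputs

    counts = decay * counts + (1.0 - decay) * batch_counts                              # Eq. (3)
    codebook = decay * codebook + (1.0 - decay) * batch_sums / (counts[:, None] + eps)  # Eq. (4)
    return codebook, counts
```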

The generative process begins by sampling a sequence of discrete latent codes from an autoregressive model, which we refer to as the Latent Predictor model. The decoder then consumes this sequence of discrete latent variables to generate the data. The autoregressive model which acts as a learned prior is fitted on the discrete latent variables produced by the encoder. The architecture of the encoder, the decoder, and the latent predictor model are described in further detail in the experiments section.

2.2 Hard EM and the $K$-means algorithm

In this section we briefly recall the hard Expectation Maximization (EM) algorithm dempster1977maximum . Given a set of data points $\{x_1, \dots, x_N\}$, the hard EM algorithm approximately solves the following optimization problem:

$$\max_{\theta} \; \max_{z_1, \dots, z_N} \; P_\theta(x_1, \dots, x_N, z_1, \dots, z_N) \qquad (5)$$

Hard EM performs coordinate descent over the following two coordinates: the model parameters $\theta$, and the hidden variables $z_1, \dots, z_N$. In other words, hard EM consists of repeating the following two steps until convergence:

  1. E step: $z_i \leftarrow \arg\max_{z} P_\theta(z \mid x_i) \quad \forall i$,

  2. M step: $\theta \leftarrow \arg\max_{\theta} \sum_{i=1}^{N} \log P_\theta(x_i, z_i)$

A special case of the hard EM algorithm is $K$-means clustering macqueen1967some ; bottou1995convergence , where the likelihood is modelled by a Gaussian with identity covariance matrix. Here, the means of the $K$ Gaussians are the parameters to be estimated, i.e., $\theta = \{\mu_1, \dots, \mu_K\}$ with $P_\theta(x_i \mid z_i) = \mathcal{N}(x_i; \mu_{z_i}, I)$.

With a uniform prior over the hidden variables ($P_\theta(z_i) = \tfrac{1}{K}$), the joint likelihood is given by $P_\theta(x_i, z_i) \propto e^{-\lVert x_i - \mu_{z_i} \rVert_2^2 / 2}$. In this case, equation (5) is equivalent to:

$$\min_{\mu_1, \dots, \mu_K} \; \min_{z_1, \dots, z_N} \; \sum_{i=1}^{N} \lVert x_i - \mu_{z_i} \rVert_2^2 \qquad (6)$$

Note that optimizing equation (6) is NP-hard; however, one can find a local optimum by applying coordinate descent until convergence (a code sketch follows the two steps below):

  1. E step: The cluster assignment is given by

    $$z_i \leftarrow \arg\min_{1 \le j \le K} \lVert x_i - \mu_j \rVert_2^2 \qquad (7)$$
  2. M step: The means of the clusters are updated as

    $$\mu_j \leftarrow \frac{1}{\lvert \{ i : z_i = j \} \rvert} \sum_{i : z_i = j} x_i \qquad (8)$$
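For completeness, here is a compact NumPy version of this coordinate descent (Lloyd's algorithm); initializing the means from randomly chosen data points and leaving empty clusters unchanged are illustrative choices.

```python
import numpy as np

def kmeans_hard_em(x, K, n_iters=50, seed=0):
    """Hard EM / Lloyd's algorithm: alternate the E step (Eq. 7) and M step (Eq. 8)."""
    rng = np.random.default_rng(seed)
    mu = x[rng.choice(len(x), size=K, replace=False)].astype(float)  # initial means
    z = np.zeros(len(x), dtype=int)
    for _ in range(n_iters):
        # E step: assign every point to its nearest mean.
        dists = ((x[:, None, :] - mu[None, :, :]) ** 2).sum(-1)
        z = dists.argmin(axis=1)
        # M step: recompute each mean as the average of its assigned points.
        for j in range(K):
            members = x[z == j]
            if len(members) > 0:            # leave empty clusters unchanged
                mu[j] = members.mean(axis=0)
    return mu, z
```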

We can now easily see the connection between the training updates of VQ-VAE and $K$-means clustering. The encoder output $z_e(x)$ corresponds to the data point, while the discrete latent variables correspond to clusters. Given this, Equation 1 is equivalent to the E-step (Equation 7), and the EMA updates in Equation 3 and Equation 4 converge to the M-step (Equation 8) in the limit. The M-step in $K$-means overwrites the old values, while the EMA updates interpolate between the old values and the M-step update.

3 VQ-VAE training with EM

In this section, we investigate a new training strategy for VQ-VAE using the soft EM algorithm.

3.1 Soft EM

First, we briefly describe the soft EM algorithm. While the hard EM procedure selects one cluster or latent variable assignment for a data point, here the data point is assigned to a mixture of clusters. The optimization objective is now given by

$$\max_{\theta} \; \log \sum_{z_1, \dots, z_N} P_\theta(x_1, \dots, x_N, z_1, \dots, z_N)$$

Coordinate descent is again used to approximately solve the above optimization problem. The E and M steps are given below, followed by a short code sketch:

  1. E step:

    $$\gamma_i \leftarrow P_\theta(z_i \mid x_i) \qquad (9)$$
  2. M step:

    $$\theta \leftarrow \arg\max_{\theta} \; \sum_{i=1}^{N} \mathbb{E}_{z_i \sim \gamma_i}\left[ \log P_\theta(x_i, z_i) \right] \qquad (10)$$
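For the same identity-covariance, uniform-prior Gaussian mixture as above, the soft E step reduces to a softmax over negative squared distances and the M step to a responsibility-weighted average. A minimal sketch (variable names are our own):

```python
import numpy as np

def soft_em_step(x, mu):
    """One soft EM iteration for a uniform-prior, identity-covariance mixture.

    E step (Eq. 9): responsibilities gamma[i, j] = P(z_i = j | x_i)
    M step (Eq. 10): means re-estimated as responsibility-weighted averages.
    """
    logits = -0.5 * ((x[:, None, :] - mu[None, :, :]) ** 2).sum(-1)  # log P(x_i | z_i = j) + const
    logits -= logits.max(axis=1, keepdims=True)                      # numerical stability
    gamma = np.exp(logits)
    gamma /= gamma.sum(axis=1, keepdims=True)                        # E step
    mu_new = (gamma.T @ x) / gamma.sum(axis=0)[:, None]              # M step
    return mu_new, gamma
```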

3.2 Vector Quantized Autoencoders trained with EM

Now, we describe vector quantized autoencoder training using the soft EM algorithm. As discussed in the previous section, the encoder output $z_e(x)$ corresponds to the data point, while the discrete latent variables correspond to clusters. The E step, instead of a hard assignment, now produces a probability distribution over the set of discrete latent variables (Equation 9). Following VQ-VAE, we continue to assume a uniform prior over clusters, since we observe that training the cluster priors seemed to cause the cluster assignments to collapse to only a few clusters. The probability distribution is modeled as a Gaussian with identity covariance matrix:

$$P_\theta\!\left(z(x) = k \mid x\right) = \frac{e^{-\lVert z_e(x) - e_k \rVert_2^2}}{\sum_{j=1}^{K} e^{-\lVert z_e(x) - e_j \rVert_2^2}}$$

Since computing the expectation in the M step (Equation 10) is computationally infeasible in our case, we instead perform Monte-Carlo Expectation Maximization wei1990monte by drawing $m$ samples $z^1, \dots, z^m \sim \mathrm{Multinomial}(l_1, \dots, l_K)$, where $\mathrm{Multinomial}(l_1, \dots, l_K)$ refers to the $K$-way multinomial distribution with logits $l_j = -\lVert z_e(x) - e_j \rVert_2^2$. Thus, the E step can finally be written as:

$$z^1, \dots, z^m \sim \mathrm{Multinomial}\!\left(-\lVert z_e(x) - e_1 \rVert_2^2, \;\dots,\; -\lVert z_e(x) - e_K \rVert_2^2\right)$$

The model parameters $\theta$ are then updated to maximize this Monte-Carlo estimate in the M step, which for the code-book embeddings takes the form

$$e_j \leftarrow \frac{\sum_{i} \sum_{l=1}^{m} \mathbb{1}\left[ z_i^l = j \right] \, z_e(x_i)}{\sum_{i} \sum_{l=1}^{m} \mathbb{1}\left[ z_i^l = j \right]},$$

where $z_i^1, \dots, z_i^m$ are the samples drawn for the $i$-th data point in the mini-batch.

Instead of exactly following the above M step update, we use the EMA version of this update similar to the one described in Section 2.1.

When sending the embedding of the discrete latent to the decoder, instead of sending the embedding of the posterior mode, $\arg\max_k P_\theta(z(x) = k \mid x)$, as in hard EM and $K$-means, we send the average of the embeddings of the sampled latents:

$$e_{z(x)} = \frac{1}{m} \sum_{l=1}^{m} e_{z^l} \qquad (11)$$

Since $m$ latent code embeddings are sent to the decoder in the forward pass, all of them are updated in the backward pass for a single training example, whereas in hard EM training only one of them is updated. Sending averaged embeddings also results in more stable training with the soft EM algorithm compared to hard EM, as shown in Section 5.

To train the latent predictor model (Section 2.1) in this case, we use an approach similar to label smoothing pereyra2017regularizing : the latent predictor model is trained to minimize the cross entropy loss with the labels being the average of the one-hot labels of the sampled latents $z^1, \dots, z^m$ (see the sketch below).
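Putting these pieces together, below is a small NumPy sketch of the soft-EM bottleneck for a single latent position: it samples $m$ codes from the multinomial over negative squared distances, averages their embeddings as in Equation (11) for the decoder input, and forms the smoothed Latent Predictor target from the samples. The function and variable names are our own illustrative choices, and gradient and EMA bookkeeping are omitted.

```python
import numpy as np

def soft_em_bottleneck(z_e, codebook, m=10, rng=None):
    """Soft-EM discretization for a single latent position.

    z_e:      [d]    encoder output
    codebook: [K, d] embeddings e_1, ..., e_K
    m:        number of Monte-Carlo samples
    returns:  decoder input (Eq. 11) and the smoothed target used to train
              the Latent Predictor.
    """
    rng = rng or np.random.default_rng()
    K = codebook.shape[0]
    logits = -((z_e[None, :] - codebook) ** 2).sum(-1)        # multinomial logits
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    samples = rng.choice(K, size=m, p=probs)                  # E step: z^1, ..., z^m
    decoder_input = codebook[samples].mean(axis=0)            # Eq. (11): averaged embeddings
    smoothed_target = np.bincount(samples, minlength=K) / m   # averaged one-hot labels
    return decoder_input, smoothed_target
```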

4 Other Related Work

Variational autoencoders were first introduced by kingma2016improved ; rezende2014stochastic for training continuous representations; unfortunately, training them for discrete latent variable models has proved challenging. One promising approach has been to use various gradient estimators for discrete latent variable models, starting with the REINFORCE estimator of williams1992simple , an unbiased but high-variance gradient estimator. An alternate approach towards gradient estimators is to use continuous relaxations of categorical distributions, e.g., the Gumbel-Softmax reparametrization trick gs1 ; gs2 . These methods provide biased but low-variance gradients for training.

Machine translation using deep neural networks has been shown to achieve impressive results sutskever14 ; bahdanau2014neural ; cho2014learning ; transformer . The state-of-the-art models in neural machine translation are all autoregressive, which means that during decoding, the model consumes all previously generated tokens to predict the next one. Very recently, there have been multiple efforts to speed up machine translation decoding. nonautoregnmt attempts to address this issue by using the Transformer model transformer together with the REINFORCE algorithm williams1992simple to model the fertilities of words. The main drawbacks of the approach of nonautoregnmt are the need for extensive fine-tuning to make policy gradients work, as well as the non-generic nature of the solution. lee2018deterministic propose a non-autoregressive model using iterative refinement: instead of decoding the target sentence in one shot, the output is successively refined to produce the final output. While the output at each refinement step is produced in parallel, the refinement steps themselves happen sequentially.

5 Experiments

We evaluate our proposed methods on unconditional image generation on the CIFAR-10 dataset and on supervised conditional language generation on the WMT English-to-German translation task. Our models and generative process follow the architecture proposed in vqvae for unconditional image generation, and in kaiser2018fast for neural machine translation. For all our experiments, we use the Adam kingma2014adam optimizer and decay the learning rate exponentially after initial warm-up steps. Unless otherwise stated, the dimension of the hidden states of the encoder and the decoder is 512; see Table 5 for a comparison of models with lower dimension. The code to reproduce our experiments will be released with the next version of the paper.
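As a rough illustration of such a schedule (linear warm-up followed by exponential decay), consider the helper below; the base rate, warm-up length, and decay constants are placeholders, not the values used in our experiments.

```python
def learning_rate(step, base_lr=1e-3, warmup_steps=10_000, decay_rate=0.5, decay_steps=50_000):
    """Linear warm-up to base_lr, then exponential decay (illustrative constants)."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr * decay_rate ** ((step - warmup_steps) / decay_steps)
```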

5.1 Machine Translation

Figure 2: VQ-VAE model adapted to conditional supervised translation as described in kaiser2018fast . We use $x$ and $y$ to denote the source and target sentence respectively. The encoder, the decoder and the latent predictor now additionally condition on the source sentence $x$.

In neural machine translation with latent variables, we model $P_\theta(y \mid x)$, where $y$ and $x$ are the target and source sentence respectively. Our model architecture, depicted in Figure 2, is similar to the one in kaiser2018fast . The encoder function is a series of strided convolutional layers with residual convolutional layers in between, and takes the target sentence $y$ as input. The source sentence $x$ is converted to a sequence of hidden states through multiple causal self-attention layers. In kaiser2018fast , the encoder of the autoencoder additionally attends to this sequence of continuous representations of the source sentence. We use VQ-VAE as the discretization algorithm. The decoder, applied after the bottleneck layer, uses transposed convolution layers whose continuous output is fed to a transformer decoder with causal attention, which generates the output.

The results are summarized in Table 1. Our implementation of VQ-VAE achieves a significantly better BLEU score and faster decoding speed compared to kaiser2018fast . We found that tuning the code-book size (number of clusters) is important: the best accuracy is achieved with a code-book that is 16 times smaller than the one used in kaiser2018fast . Additionally, we see a large improvement in the performance of the model by using sequence-level distillation seq-d , as has been observed previously in non-autoregressive models nonautoregnmt ; lee2018deterministic . Our teacher model is a base Transformer transformer that achieves a BLEU score of 28.1 and 27.0 on the WMT'14 test set using beam search decoding and greedy decoding respectively. For distillation purposes, we use the beam-search-decoded Transformer. Our VQ-VAE model trained with soft EM and distillation achieves a BLEU score of 26.7, without noisy parallel decoding nonautoregnmt . This performance is 1.4 BLEU points lower than the autoregressive model decoded with a beam size of 4, while being substantially faster. Importantly, we nearly match the same autoregressive model decoded greedily (beam size 1), with a 3.3x speedup.

The length of the sequence of discrete latent variables is shorter than that of the target sentence $y$. Specifically, at each compression step of the encoder we reduce its length by half. We denote by $n_c$ the compression factor for the latents, i.e., the number of steps for which we do this compression. In almost all our experiments we use $n_c = 3$, reducing the length by a factor of 8 (illustrated in the sketch below). We can decrease the decoding time further by increasing the number of compression steps. As shown in Table 1, by setting $n_c$ to 4, the decoding time drops to 58 milliseconds, achieving 25.4 BLEU, while a NAT model with similar decoding speed achieves only 18.7 BLEU. Note that all NAT models also train with sequence-level knowledge distillation from an autoregressive teacher.
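As a small illustration of the compression factor, the helper below computes the latent sequence length after $n_c$ halving steps; rounding up for odd lengths is an assumption about padding.

```python
def latent_length(target_len, n_c=3):
    """Length of the discrete latent sequence after n_c compression steps,
    each strided convolution halving the sequence length (rounding up)."""
    length = target_len
    for _ in range(n_c):
        length = (length + 1) // 2
    return length

print(latent_length(28, n_c=3))  # a 28-token target compresses to 4 latent positions
print(latent_length(28, n_c=4))  # one more compression step halves it again to 2
```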

5.1.1 Analysis

Attention to Source Sentence Encoder:

While the encoder of the discrete autoencoder in kaiser2018fast attends to the output of the encoder of the source sentence, we find this to be unnecessary, with both models achieving the same BLEU score. Moreover, removing this attention step results in more stable training, particularly for large code-book sizes; see e.g., Figure 3.

VQ-VAE vs Other Discretization Techniques:

We compare VQ-VAE with the Gumbel-Softmax of gs1 ; gs2 and the improved semantic hashing discretization technique proposed in kaiser2018fast . When trained with sequence-level knowledge distillation, the VQ-VAE model reached 26.4 BLEU on WMT'14 English-German, outperforming both the Gumbel-Softmax and the improved semantic hashing models.

Size of Discrete Latent Variable code-book:

Table 3 in the Appendix shows the BLEU score for different code-book sizes for models trained using hard EM without distillation. While kaiser2018fast use a much larger code-book, we find that a code-book 16 times smaller gives the best performance.

Robustness of EM to Hyperparameters:

While the soft EM training gives a small performance improvement, we find that it also leads to more robust training (Figure 3).

Figure 3: Comparison of hard EM (green curve) vs soft EM with different numbers of samples (yellow and blue curves) on the WMT'14 English-German translation dataset with a large code-book, with the encoder of the discrete autoencoder attending to the output of the encoder of the source sentence as in kaiser2018fast . The $y$-axis denotes the teacher-forced BLEU score on the test set. Notice that the hard EM ($K$-means) run collapsed, while the soft EM runs exhibit more stability.
Model Size:

The effect of model size on BLEU score for models trained with soft EM and distillation is shown in Table 5 in the Appendix.

Number of Samples in the Monte-Carlo EM Update:

While training with soft EM, we perform a Monte-Carlo update with a small number of samples (Section 3.2). Table 4 in the Appendix shows the impact of the number of samples on the final BLEU score.

Model                                      n_c   n_s   BLEU   Latency   Speedup
Autoregressive Model (beam size=4)          -     -    28.1      -         -
Autoregressive Baseline (no beam-search)    -     -    27.0   265 ms       -
NAT + distillation                          -     -    17.7    39 ms*      -
NAT + distillation + NPD=10                 -     -    18.7    79 ms*      -
NAT + distillation + NPD=100                -     -    19.2   257 ms*      -
LT + Semhash                                -     -    19.8   105 ms       -
Our Results
VQ-VAE                                      3     -    21.4    81 ms     4.08x
VQ-VAE with EM                              3     5    22.4    81 ms     4.08x
VQ-VAE + distillation                       3     -    26.4    81 ms     4.08x
VQ-VAE with EM + distillation               3    10    26.7    81 ms     4.08x
VQ-VAE with EM + distillation               4    10    25.4    58 ms     5.71x
Table 1: BLEU score and decoding times for different models on the WMT'14 English-German translation dataset. The baseline is the autoregressive Transformer of transformer with no beam search, NAT denotes the Non-Autoregressive Transformer of nonautoregnmt , and LT + Semhash denotes the Latent Transformer of kaiser2018fast using the improved semantic hashing discretization technique of isemhash . NPD is noisy parallel decoding as described in nonautoregnmt . We use $n_c$ to denote the compression factor for the latents and $n_s$ to denote the number of samples used to perform the Monte-Carlo approximation of the EM algorithm. Distillation refers to sequence-level knowledge distillation from seq-d . Decoding is performed on a single machine with an NVIDIA GeForce GTX 1080 GPU; speedup values for our models are reproduced from Table 4.
  • Latencies marked with * are those reported in nonautoregnmt , whose speedups are measured against the decode time of their own autoregressive Transformer.

5.2 Image Generation

Figure 4: Samples of original and reconstructed images from CIFAR-10 using VQ-VAE trained with EM.
Model                     Log perplexity (bits/dim)
ImageTransformer                    -
VAE                                 -
VQ-VAE vqvae                        -
VQ-VAE (Ours)                       -
VQ-VAE with EM (Ours)               -
Table 2: Log perplexity on CIFAR-10 measured in bits/dim. We train our VQ-VAE models on a field of latents with a fixed code-book size, while VQ-VAE vqvae refers to the results reported in vqvae , which used its own latent field and code-book configuration. Note that VQ-VAE vqvae uses an independent unigram prior for each latent in the sequence, rather than the log-perplexity from a Latent Predictor model.

We train the unconditional VQ-VAE model on the CIFAR-10 dataset, modeling the joint probability $P_\theta(x, z)$, where $x$ is the image and $z$ are the discrete latent codes. We use a field of latents with a fixed code-book size and embedding dimension. We maintain the same encoder and decoder as used in machine translation. Our Latent Predictor uses an Image Transformer imagetrans autoregressive decoder with layers of local self-attention. For the encoder, we use strided convolutional layers followed by residual layers and a single dense layer. For the decoder, we use a single dense layer, residual layers, and deconvolutional layers.

We compute a lower bound on the negative log-likelihood from the Latent Predictor loss and the negative log-perplexity of the autoencoder. Let $D$ be the total number of positions in the image and $L$ the number of latent codes; the bound in bits/dim then combines the Latent Predictor loss, amortized over all $D$ image dimensions, with the reconstruction term of the autoencoder. Note that for CIFAR-10, $D = 32 \times 32 \times 3 = 3072$, while $L$ is much smaller. We report the results in Table 2 and show reconstructions from the autoencoder in Figure 4. As seen from the results, our VQ-VAE model with EM achieves a better (lower) negative log-likelihood than the baseline VQ-VAE.
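As an illustration of how such a bound can be assembled, the sketch below converts a Latent Predictor loss (assumed to be measured in nats per latent code) to bits, amortizes it over all image dimensions, and adds the autoencoder's reconstruction term (assumed to already be in bits per dimension); these conventions are our illustrative assumptions rather than the exact bookkeeping used for Table 2.

```python
import math

def bits_per_dim(lp_loss_nats_per_latent, ae_bits_per_dim, num_latents, num_image_dims):
    """Illustrative bound: convert the Latent Predictor loss to bits, spread it
    over every image dimension, and add the autoencoder's reconstruction term."""
    lp_bits_per_dim = lp_loss_nats_per_latent * num_latents / (num_image_dims * math.log(2))
    return lp_bits_per_dim + ae_bits_per_dim

# For CIFAR-10 an image has 32 * 32 * 3 = 3072 dimensions; num_latents is the
# (much smaller) number of discrete codes produced per image.
```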

6 Conclusion

We investigate an alternate training technique for VQ-VAE inspired by its connection to the EM algorithm. Training the discrete bottleneck with EM helps us achieve better image generation results on CIFAR-10, and together with knowledge distillation, allows us to develop a non-autoregressive machine translation model whose accuracy almost matches the greedy autoregressive baseline, while being 3.3 times faster at inference.

References

Appendix A Ablation Tables

Model     Code-book size   BLEU
VQ-VAE          -          20.8
VQ-VAE          -          21.6
VQ-VAE          -          21.0
VQ-VAE          -          21.8
Table 3: Results showing the impact of the code-book size on BLEU score.
Model                            n_c   n_s   BLEU   Latency   Speedup
VQ-VAE + distillation             3     -    26.4    81 ms     4.08x
VQ-VAE with EM + distillation     3     5    26.4    81 ms     4.08x
VQ-VAE with EM + distillation     3    10    26.7    81 ms     4.08x
VQ-VAE with EM + distillation     3    25    26.6    81 ms     4.08x
VQ-VAE with EM + distillation     3    50    26.5    81 ms     4.08x
VQ-VAE + distillation             4     -    22.4    58 ms     5.71x
VQ-VAE with EM + distillation     4     5    22.3    58 ms     5.71x
VQ-VAE with EM + distillation     4    10    25.4    58 ms     5.71x
VQ-VAE with EM + distillation     4    25    25.1    58 ms     5.71x
VQ-VAE with EM + distillation     4    50    23.6    58 ms     5.71x
Table 4: Results showing the impact of the number of samples $n_s$ used to perform the Monte-Carlo EM update on the BLEU score.
Model                            Hidden dimension   n_s   BLEU   Latency
VQ-VAE + distillation                  256           -    24.5   76 ms
VQ-VAE with EM + distillation          256          10    21.9   76 ms
VQ-VAE with EM + distillation          256          25    25.8   76 ms
VQ-VAE + distillation                  384           -    25.6   80 ms
VQ-VAE with EM + distillation          384          10    22.2   80 ms
VQ-VAE with EM + distillation          384          25    26.2   80 ms
Table 5: Results showing the impact of the dimension of the word embeddings and the hidden layers of the model.