Following the influential work of Kingma & Welling (2013) and Rezende et al. (2014), deep generative models with latent variables have been widely used to model data such as natural images (Rezende & Mohamed, 2015; Kingma et al., 2016; Chen et al., 2016; Gulrajani et al., 2016), speech and music time-series (Chung et al., 2015; Fraccaro et al., 2016; Krishnan et al., 2015), and video (Babaeizadeh et al., 2017; Ha & Schmidhuber, 2018; Denton & Fergus, 2018). The power of these models lies in combining learned nonlinear function approximators with a principled probabilistic approach, resulting in expressive models that can capture complex distributions. Unfortunately, the nonlinearities that empower these models also make marginalizing the latent variables intractable, rendering direct maximum likelihood training inapplicable. Instead of directly maximizing the marginal likelihood, a common approach is to maximize a tractable lower bound on the likelihood such as the variational evidence lower bound (ELBO) (Jordan et al., 1999; Blei et al., 2017). The tightness of the bound is determined by the expressiveness of the variational family. For tractability, a factorized variational family is commonly used, which can cause the learned model to be overly simplistic.
Burda et al. (2015) introduced a multi-sample bound, IWAE, that is at least as tight as the ELBO and becomes increasingly tight as the number of samples increases. Counterintuitively, although the bound is tighter, Rainforth et al. (2018) theoretically and empirically showed that the standard inference network gradient estimator for the IWAE bound performs poorly as the number of samples increases due to a diminishing signal-to-noise ratio (SNR). This motivates the search for novel gradient estimators.
Roeder et al. (2017) proposed a lower-variance estimator of the gradient of the IWAE bound. They speculated that their estimator was unbiased but were unable to prove the claim. We show that it is in fact biased, but that it is possible to construct an unbiased estimator with a second application of the reparameterization trick, which we call the IWAE doubly reparameterized gradient (DReG) estimator. Our estimator is an unbiased, computationally efficient drop-in replacement, and does not suffer as the number of samples increases, resolving the counterintuitive behavior from previous work (Rainforth et al., 2018). Furthermore, our insight is applicable to alternative multi-sample training techniques for latent variable models: reweighted wake-sleep (RWS) (Bornschein & Bengio, 2014) and jackknife variational inference (JVI) (Nowozin, 2018).
In this work, we derive DReG estimators for IWAE, RWS, and JVI and demonstrate improved scaling with the number of samples on a simple example. Then, we evaluate DReG estimators on MNIST generative modeling, Omniglot generative modeling, and MNIST structured prediction tasks. In all cases, we demonstrate substantial unbiased variance reduction, which translates to improved performance over the original estimators.
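Our goal is to learn a latent variable generative model $p_\theta(x, z)$ where $x$ are observed data and $z$ are continuous latent variables. The marginal likelihood of the observed data, $p_\theta(x) = \int p_\theta(x, z)\, dz$, is generally intractable. Instead, we maximize a variational lower bound on $\log p_\theta(x)$ such as the ELBO

$$\log p_\theta(x) = \log \mathbb{E}_{q(z|x)}\left[\frac{p_\theta(x, z)}{q(z|x)}\right] \geq \mathbb{E}_{q(z|x)}\left[\log \frac{p_\theta(x, z)}{q(z|x)}\right], \tag{1}$$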
where $q(z|x)$ is a variational distribution. Following the influential work of Kingma & Welling (2013) and Rezende et al. (2014), we consider the amortized inference setting where $q_\phi(z|x)$, referred to as the inference network, is a learnable function parameterized by $\phi$ that maps from $x$ to a distribution over $z$. The tightness of the bound is coupled to the expressiveness of the variational family (i.e., $\{q_\phi\}_\phi$). As a result, limited expressivity of $q_\phi$ can negatively affect the learned model.
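Burda et al. (2015) introduced the importance weighted autoencoder (IWAE) bound which alleviates this coupling

$$\mathbb{E}_{z_{1:K}}\left[\log \frac{1}{K}\sum_{i=1}^{K} \frac{p_\theta(x, z_i)}{q_\phi(z_i|x)}\right] \leq \log p_\theta(x) \tag{2}$$

with $z_{1:K} \sim \prod_i q_\phi(z_i|x)$. The IWAE bound reduces to the ELBO when $K = 1$, is non-decreasing as $K$ increases, and converges to $\log p_\theta(x)$ as $K \to \infty$ under mild conditions (Burda et al., 2015). When $q_\phi$ is reparameterizable (meaning that we can express $z_i \sim q_\phi(z_i|x)$ as $z_i = z(\epsilon_i, \phi)$, where $z$ is a deterministic, differentiable function and $p(\epsilon_i)$ does not depend on $\phi$ or $\theta$; this allows gradients to be estimated using the reparameterization trick (Kingma & Welling, 2013; Rezende et al., 2014)), the standard gradient estimator of the IWAE bound is

$$\nabla_{\theta,\phi}\, \mathbb{E}_{z_{1:K}}\left[\log \frac{1}{K}\sum_{i=1}^{K} w_i\right] = \mathbb{E}_{\epsilon_{1:K}}\left[\sum_{i=1}^{K} \frac{w_i}{\sum_j w_j} \nabla_{\theta,\phi} \log w_i\right]$$

where $w_i = p_\theta(x, z_i)/q_\phi(z_i|x)$. A single sample estimator of this expectation is typically used as the gradient estimator.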
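The IWAE bound is straightforward to estimate once the log weights are available. The sketch below illustrates it under assumptions of our own choosing: a one-dimensional model with $p(z) = \mathcal{N}(\mu, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$, whose marginal $p(x) = \mathcal{N}(\mu, 2)$ is known exactly, and a proposal $q(z|x) = \mathcal{N}(m, 1/2)$ whose mean is shifted away from the true posterior mean $(x + \mu)/2$. It uses a log-mean-exp for numerical stability, and with $K = 1$ it returns an ELBO sample. All names are hypothetical.

```python
import math
import random


def log_normal(v, mean, var):
    """Log density of N(mean, var) evaluated at v."""
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * (v - mean) ** 2 / var


def log_mean_exp(log_ws):
    """Numerically stable log((1/K) * sum_i exp(log_ws[i]))."""
    top = max(log_ws)
    return top + math.log(sum(math.exp(lw - top) for lw in log_ws) / len(log_ws))


def iwae_estimate(x, mu, m, s2, K, rng):
    """One draw of the K-sample IWAE bound; K = 1 gives an ELBO sample."""
    log_ws = []
    for _ in range(K):
        z = m + math.sqrt(s2) * rng.gauss(0.0, 1.0)  # reparameterized z = z(eps, phi)
        log_joint = log_normal(z, mu, 1.0) + log_normal(x, z, 1.0)
        log_ws.append(log_joint - log_normal(z, m, s2))  # log w_i
    return log_mean_exp(log_ws)


def average_bound(K, n_rep=2000, seed=0, x=1.0, mu=0.0, shift=0.5):
    """Monte Carlo average of the IWAE estimate over n_rep independent draws."""
    rng = random.Random(seed)
    m = 0.5 * (x + mu) + shift  # the exact posterior is N((x + mu) / 2, 1/2)
    return sum(iwae_estimate(x, mu, m, 0.5, K, rng) for _ in range(n_rep)) / n_rep


exact = log_normal(1.0, 0.0, 2.0)  # log p(x) for x = 1, mu = 0
for K in (1, 10, 100):
    print(f"K={K:3d}  bound~{average_bound(K):.4f}  log p(x)={exact:.4f}")
```

In this setting the ELBO gap is the KL divergence from $q$ to the posterior, which equals $\text{shift}^2 = 0.25$; the averaged bound should approach $\log p(x)$ from below as $K$ grows.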
As $K$ increases, the bound becomes increasingly tight; however, Rainforth et al. (2018) show that the signal-to-noise ratio (SNR) of the inference network gradient estimator goes to $0$. This does not happen for the model parameters ($\theta$). Following up on this work, Le et al. (2018) demonstrate that this deteriorates the performance of learned models on practical problems.
Because the IWAE bound converges to $\log p_\theta(x)$ (as $K \to \infty$) regardless of $q_\phi$, $\phi$'s effect on the bound must diminish as $K$ increases. It may be tempting to conclude that the SNR of the inference network gradient estimator must also decrease as $K \to \infty$. However, low SNR is a limitation of the gradient estimator, not necessarily of the bound. Although the magnitude of the gradient converges to $0$, if the variance of the gradient estimator decreases more quickly, then the SNR of the gradient estimator need not degrade. This motivates the search for lower variance inference network gradient estimators.
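To derive improved gradient estimators for $\phi$, it is informative to expand the total derivative of the IWAE bound with respect to $\phi$ ($w_i$ depends on $\phi$ in two ways: directly through $q_\phi(z_i|x)$ and indirectly through $z_i = z(\epsilon_i, \phi)$; the total derivative accounts for both sources of dependence, while the partial derivative $\frac{\partial}{\partial\phi}$ treats $z_i$ as a constant)

$$\nabla_\phi\, \mathbb{E}_{z_{1:K}}\left[\log \frac{1}{K}\sum_{i=1}^{K} w_i\right] = \mathbb{E}_{\epsilon_{1:K}}\left[\sum_{i=1}^{K} \frac{w_i}{\sum_j w_j}\left(-\frac{\partial}{\partial\phi}\log q_\phi(z_i|x) + \frac{\partial \log w_i}{\partial z_i}\frac{\partial z_i}{\partial\phi}\right)\right]. \tag{3}$$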
Previously, Roeder et al. (2017) found that the first term within the parentheses of Eq. 3 can contribute significant variance to the gradient estimator. When $K = 1$, this term analytically vanishes in expectation, and they suggested dropping it when $K > 1$ as well. Below, we abbreviate this estimator as STL. As we show in Section 6.1, the STL estimator introduces bias when $K > 1$.
3 Doubly Reparameterized Gradient Estimators (DReGs)
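Our insight is that we can estimate the first term within the parentheses of Eq. 3 efficiently with a second application of the reparameterization trick. To see this, first note that

$$\mathbb{E}_{\epsilon_{1:K}}\left[\sum_{i=1}^{K} \frac{w_i}{\sum_j w_j}\frac{\partial}{\partial\phi}\log q_\phi(z_i|x)\right] = \sum_{i=1}^{K} \mathbb{E}_{\epsilon_{1:K}}\left[\frac{w_i}{\sum_j w_j}\frac{\partial}{\partial\phi}\log q_\phi(z_i|x)\right],$$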
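so it suffices to focus on one of the terms. Because the derivative is a partial derivative $\frac{\partial}{\partial\phi}$, it treats $z_i = z(\epsilon_i, \phi)$ as a constant, so we can freely change the random variable that the expectation is over to $z_{1:K}$. Now,

$$\mathbb{E}_{z_{1:K}}\left[\frac{w_i}{\sum_j w_j}\frac{\partial}{\partial\phi}\log q_\phi(z_i|x)\right] = \mathbb{E}_{z_{-i}}\left[\mathbb{E}_{z_i}\left[\frac{w_i}{\sum_j w_j}\frac{\partial}{\partial\phi}\log q_\phi(z_i|x)\right]\right], \tag{4}$$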
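where $z_{-i}$ is the set of $z_{1:K}$ without $z_i$. The inner expectation resembles a REINFORCE gradient term (Williams, 1992), where we interpret $w_i/\sum_j w_j$ as the “reward”. Now, we can use the following well-known equivalence between the REINFORCE gradient and the reparameterization trick gradient (see Appendix 8.1 for a derivation)

$$\mathbb{E}_{q_\phi(z_i|x)}\left[f(z_i)\frac{\partial}{\partial\phi}\log q_\phi(z_i|x)\right] = \mathbb{E}_{\epsilon_i}\left[\frac{\partial f(z_i)}{\partial z_i}\frac{\partial z_i(\epsilon_i, \phi)}{\partial\phi}\right]. \tag{5}$$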
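This holds even when $f$ depends on $\phi$. Typically, the reparameterization gradient estimator has lower variance than the REINFORCE gradient estimator, which suggests using it will improve performance. Applying the identity from Eq. 5 to the right-hand side of Eq. 4 gives

$$\mathbb{E}_{z_{1:K}}\left[\frac{w_i}{\sum_j w_j}\frac{\partial}{\partial\phi}\log q_\phi(z_i|x)\right] = \mathbb{E}_{\epsilon_{1:K}}\left[\left(\frac{w_i}{\sum_j w_j} - \frac{w_i^2}{(\sum_j w_j)^2}\right)\frac{\partial \log w_i}{\partial z_i}\frac{\partial z_i}{\partial\phi}\right]. \tag{6}$$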
This last expression can be efficiently estimated with a single Monte Carlo sample. When $z_i$ is not reparameterizable (e.g., the models in Mnih & Rezende (2016)), we can use a control variate instead. In both cases, when $K = 1$, this term vanishes exactly and we recover the estimator proposed by Roeder et al. (2017) for the ELBO. Substituting Eq. 6 into Eq. 3, the terms partially cancel and we obtain

$$\nabla_\phi\, \mathbb{E}_{z_{1:K}}\left[\log \frac{1}{K}\sum_{i=1}^{K} w_i\right] = \mathbb{E}_{\epsilon_{1:K}}\left[\sum_{i=1}^{K} \left(\frac{w_i}{\sum_j w_j}\right)^2 \frac{\partial \log w_i}{\partial z_i}\frac{\partial z_i}{\partial\phi}\right]. \tag{7}$$

We call the algorithm that uses the single sample Monte Carlo estimator of this expression for the inference network gradient the IWAE doubly reparameterized gradient estimator (IWAE-DReG). This estimator has the property that when $q_\phi(z|x)$ is optimal (i.e., $q_\phi(z|x) = p_\theta(z|x)$), the estimator vanishes exactly and has zero variance, because then $w_i = p_\theta(x)$ does not depend on $z_i$ and $\frac{\partial \log w_i}{\partial z_i} = 0$; this does not hold for the standard IWAE gradient estimator. We provide an asymptotic analysis of the IWAE-DReG estimator in Appendix 8.2. The conclusion of that analysis is that, in contrast to the standard IWAE gradient estimator, the SNR of the IWAE-DReG estimator exhibits the same scaling behaviour of $O(\sqrt{K})$ for both the generation and inference network gradients (i.e., improving in $K$).
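To make Eq. 3, STL, and Eq. 7 concrete, the following sketch computes one draw of each estimator for a single inference-network parameter $b$. It assumes a one-dimensional Gaussian model chosen only for illustration ($p(z) = \mathcal{N}(\mu, 1)$, $p(x|z) = \mathcal{N}(z, 1)$, $q_\phi(z|x) = \mathcal{N}(x/2 + b, s^2)$). In this model $\partial \log w_i/\partial z_i$, $\partial z_i/\partial b = 1$, and the score are all available in closed form, so no automatic differentiation library is needed. Function names are hypothetical.

```python
import math
import random


def particle_terms(z, x, mu, m, s2):
    """(log w, d log w/dz with q's parameters fixed, d/db log q with z fixed)."""
    # Constants shared by all particles are dropped; they cancel after normalizing.
    log_w = -0.5 * (z - mu) ** 2 - 0.5 * (x - z) ** 2 + 0.5 * (z - m) ** 2 / s2
    dlogw_dz = -(z - mu) + (x - z) + (z - m) / s2
    score_b = (z - m) / s2
    return log_w, dlogw_dz, score_b


def grad_b(x, mu, b, s2, K, rng):
    """One draw of (standard IWAE, STL, IWAE-DReG) estimates of d/db of the bound."""
    m = 0.5 * x + b
    zs = [m + math.sqrt(s2) * rng.gauss(0.0, 1.0) for _ in range(K)]  # dz_i/db = 1
    terms = [particle_terms(z, x, mu, m, s2) for z in zs]
    top = max(t[0] for t in terms)
    ws = [math.exp(t[0] - top) for t in terms]
    total = sum(ws)
    wt = [w / total for w in ws]  # normalized weights w_i / sum_j w_j
    iwae = sum(v * (g - sc) for v, (_, g, sc) in zip(wt, terms))  # Eq. 3
    stl = sum(v * g for v, (_, g, _) in zip(wt, terms))  # score term dropped
    dreg = sum(v * v * g for v, (_, g, _) in zip(wt, terms))  # Eq. 7
    return iwae, stl, dreg


def moments(values):
    """Sample mean and unbiased sample variance."""
    mean = sum(values) / len(values)
    return mean, sum((v - mean) ** 2 for v in values) / (len(values) - 1)


rng = random.Random(0)
draws = [grad_b(1.0, 0.0, 0.3, 0.5, 10, rng) for _ in range(5000)]
(iwae_m, iwae_v), (stl_m, stl_v), (dreg_m, dreg_v) = (
    moments([d[j] for d in draws]) for j in range(3))
print(iwae_m, dreg_m, stl_m)  # IWAE and IWAE-DReG agree in mean; STL does not
print(iwae_v, dreg_v)
```

With $s^2 = 1/2$ the proposal has the exact posterior variance, so $\partial \log w_i/\partial z_i$ is the same constant ($-0.6$ here) for every particle. STL therefore always returns $-0.6$, which is far from the mean of the two unbiased estimators. IWAE-DReG matches the mean of the standard estimator with much lower variance, and it is exactly zero when $b$ is at its optimum ($b = \mu/2$).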
4 Alternative training algorithms
Now, we review alternative training algorithms for deep generative models and derive their doubly reparameterized versions.
4.1 Reweighted Wake Sleep (RWS)
Bornschein & Bengio (2014) introduced RWS, an alternative multi-sample update for latent variable models that uses importance sampling. Computing the gradient of the log marginal likelihood

$$\nabla_\theta \log p_\theta(x) = \mathbb{E}_{p_\theta(z|x)}\left[\nabla_\theta \log p_\theta(x, z)\right]$$

requires samples from $p_\theta(z|x)$, which is generally intractable. We can approximate the gradient with a self-normalized importance sampling estimator

$$\mathbb{E}_{z_{1:K}}\left[\sum_{i=1}^{K} \frac{w_i}{\sum_j w_j}\nabla_\theta \log p_\theta(x, z_i)\right]$$

where $z_{1:K} \sim \prod_i q_\phi(z_i|x)$. Interestingly, this is precisely the same as the IWAE gradient of $\theta$, so the RWS update for $\theta$ can be interpreted as maximizing the IWAE lower bound in terms of $\theta$. Instead of optimizing a joint objective for $p$ and $q$, RWS optimizes a separate objective for the inference network. Bornschein & Bengio (2014) propose a “wake” update and a “sleep” update for the inference network. Le et al. (2018) provide empirical support for solely using the wake update for the inference network, so we focus on that update.
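The wake update approximately minimizes the KL divergence from $p_\theta(z|x)$ to $q_\phi(z|x)$. The gradient of the KL term is

$$\nabla_\phi\, \mathrm{KL}\left(p_\theta(z|x) \,\|\, q_\phi(z|x)\right) = -\mathbb{E}_{p_\theta(z|x)}\left[\nabla_\phi \log q_\phi(z|x)\right].$$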
The wake update of the inference network approximates the intractable expectation by self-normalized importance sampling

$$-\mathbb{E}_{z_{1:K}}\left[\sum_{i=1}^{K} \frac{w_i}{\sum_j w_j}\frac{\partial}{\partial\phi}\log q_\phi(z_i|x)\right] \tag{8}$$

with $z_i \sim q_\phi(z_i|x)$. Le et al. (2018) note that this update does not suffer from diminishing SNR as $K$ increases. However, a downside is that the updates for $p$ and $q$ are not gradients of a unified objective, which could potentially lead to instability or divergence.
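Both RWS estimators can be checked against closed forms in an assumed one-dimensional Gaussian model ($p(z) = \mathcal{N}(\mu, 1)$, $p(x|z) = \mathcal{N}(z, 1)$, $q(z|x) = \mathcal{N}(m, s^2)$). In this model $\nabla_\mu \log p(x) = (x - \mu)/2$ and $\frac{\partial}{\partial m}\mathrm{KL}(p(z|x) \,\|\, q(z|x)) = (m - (x + \mu)/2)/s^2$. In this sketch $\theta$ is taken to be the prior mean $\mu$ and $\phi$ the proposal mean $m$. Both choices and the function name are illustrative assumptions.

```python
import math
import random


def rws_grads(x, mu, m, s2, K, rng):
    """One draw of the RWS estimates (theta update, wake update for phi).

    Assumed toy model: p(z) = N(mu, 1), p(x|z) = N(z, 1), q(z|x) = N(m, s2);
    theta is the prior mean mu and phi is the proposal mean m.
    """
    zs = [m + math.sqrt(s2) * rng.gauss(0.0, 1.0) for _ in range(K)]
    log_ws = [-0.5 * (z - mu) ** 2 - 0.5 * (x - z) ** 2 + 0.5 * (z - m) ** 2 / s2
              for z in zs]  # log w_i up to a constant shared by all particles
    top = max(log_ws)
    ws = [math.exp(lw - top) for lw in log_ws]
    total = sum(ws)
    wt = [w / total for w in ws]
    # Self-normalized estimate of E_{p(z|x)}[d/dmu log p(x, z)], where d/dmu log p = z - mu.
    theta_grad = sum(v * (z - mu) for v, z in zip(wt, zs))
    # Eq. 8: minus the weighted score, with d/dm log q(z|x) = (z - m) / s2.
    wake_grad = -sum(v * (z - m) / s2 for v, z in zip(wt, zs))
    return theta_grad, wake_grad


x, mu, m, s2 = 1.0, 0.0, 0.8, 0.5
rng = random.Random(1)
draws = [rws_grads(x, mu, m, s2, 100, rng) for _ in range(500)]
theta_avg = sum(d[0] for d in draws) / len(draws)
wake_avg = sum(d[1] for d in draws) / len(draws)
print(theta_avg, (x - mu) / 2.0)  # closed-form d/dmu log p(x)
print(wake_avg, (m - 0.5 * (x + mu)) / s2)  # closed-form d/dm KL(p(z|x) || q)
```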
Doubly Reparameterized Reweighted Wake update
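The wake update gradient for the inference network (Eq. 8) can be reparameterized

$$-\mathbb{E}_{\epsilon_{1:K}}\left[\sum_{i=1}^{K}\left(\frac{w_i}{\sum_j w_j} - \frac{w_i^2}{(\sum_j w_j)^2}\right)\frac{\partial \log w_i}{\partial z_i}\frac{\partial z_i}{\partial\phi}\right]. \tag{9}$$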
We call the algorithm that uses the single sample Monte Carlo estimator of this expression as the wake update for the inference network RWS-DReG.
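The minimal sketch below compares the standard wake update (Eq. 8) with its doubly reparameterized form (Eq. 9). It assumes a one-dimensional Gaussian model chosen only for illustration: $p(z) = \mathcal{N}(\mu, 1)$, $p(x|z) = \mathcal{N}(z, 1)$, and $q(z|x) = \mathcal{N}(x/2 + b, s^2)$ with a deliberately mis-specified variance $s^2 = 2/3$. Both estimate the gradient of the KL divergence with respect to $b$, so they should agree in expectation. The names are hypothetical.

```python
import math
import random


def wake_grads(x, mu, b, s2, K, rng):
    """One draw of (standard wake, RWS-DReG) estimates of d/db KL(p(z|x) || q(z|x)).

    Assumed toy model: p(z) = N(mu, 1), p(x|z) = N(z, 1) and
    q(z|x) = N(x/2 + b, s2), reparameterized as z = m + sqrt(s2) * eps.
    """
    m = 0.5 * x + b
    zs = [m + math.sqrt(s2) * rng.gauss(0.0, 1.0) for _ in range(K)]  # dz/db = 1
    log_ws = [-0.5 * (z - mu) ** 2 - 0.5 * (x - z) ** 2 + 0.5 * (z - m) ** 2 / s2
              for z in zs]
    top = max(log_ws)
    ws = [math.exp(lw - top) for lw in log_ws]
    total = sum(ws)
    wake = rws_dreg = 0.0
    for w, z in zip(ws, zs):
        v = w / total
        score = (z - m) / s2                    # partial d/db log q, z held fixed
        g = -(z - mu) + (x - z) + (z - m) / s2  # d log w / dz, q's parameters fixed
        wake -= v * score                       # Eq. 8
        rws_dreg -= (v - v * v) * g             # Eq. 9
    return wake, rws_dreg


def moments(values):
    """Sample mean and unbiased sample variance."""
    mean = sum(values) / len(values)
    return mean, sum((v - mean) ** 2 for v in values) / (len(values) - 1)


rng = random.Random(2)
draws = [wake_grads(1.0, 0.0, 0.3, 2.0 / 3.0, 10, rng) for _ in range(5000)]
wake_m, wake_v = moments([d[0] for d in draws])
dreg_m, dreg_v = moments([d[1] for d in draws])
print(wake_m, dreg_m)  # same expectation
print(wake_v, dreg_v)  # RWS-DReG has lower variance in this setting
```

With $K = 1$ the RWS-DReG coefficient $w_i/\sum_j w_j - (w_i/\sum_j w_j)^2$ is exactly zero, which matches the fact that the wake gradient has zero expectation when only one sample is drawn.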
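Interestingly, the inference network gradient estimator from Roeder et al. (2017) can be seen as the sum of the IWAE-DReG estimator and the wake update of the inference network (as the wake update minimizes, we add the negative of Eq. 9): the squared-weight terms cancel, leaving $\sum_i \frac{w_i}{\sum_j w_j}\frac{\partial \log w_i}{\partial z_i}\frac{\partial z_i}{\partial\phi}$. Their positive results motivate further exploration of convex combinations of IWAE-DReG and RWS-DReG

$$\mathbb{E}_{\epsilon_{1:K}}\left[\sum_{i=1}^{K}\left(\alpha\frac{w_i}{\sum_j w_j} + (1 - 2\alpha)\left(\frac{w_i}{\sum_j w_j}\right)^2\right)\frac{\partial \log w_i}{\partial z_i}\frac{\partial z_i}{\partial\phi}\right], \quad \alpha \in [0, 1]. \tag{10}$$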
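We refer to the algorithm that uses the single sample Monte Carlo estimator of this expression as DReG($\alpha$). When $\alpha = 1$, this reduces to RWS-DReG (negated, so that it is an ascent direction); when $\alpha = 0$, it reduces to IWAE-DReG; and when $\alpha = 1/2$, it reduces to STL (up to a factor of $1/2$).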
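Given normalized weights $w_i/\sum_j w_j$ and per-particle pathwise terms $\frac{\partial \log w_i}{\partial z_i}\frac{\partial z_i}{\partial\phi}$, DReG($\alpha$) simply reweights the pathwise terms. The hypothetical helper below uses the per-particle coefficient $\alpha \frac{w_i}{\sum_j w_j} + (1 - 2\alpha)\left(\frac{w_i}{\sum_j w_j}\right)^2$. This is $\alpha$ times the negated RWS-DReG term plus $(1 - \alpha)$ times the IWAE-DReG term. The code then checks the special cases on made-up numbers.

```python
def dreg_alpha(wt, pathwise, alpha):
    """DReG(alpha) for one draw, written as an ascent direction.

    wt[i] = w_i / sum_j w_j and pathwise[i] = (d log w_i / dz_i)(dz_i / dphi).
    The per-particle coefficient is alpha * wt + (1 - 2 * alpha) * wt**2.
    """
    return sum((alpha * v + (1.0 - 2.0 * alpha) * v * v) * g
               for v, g in zip(wt, pathwise))


wt = [0.5, 0.3, 0.2]         # made-up normalized weights
pathwise = [1.0, -2.0, 0.5]  # made-up pathwise gradient terms
iwae_dreg = sum(v * v * g for v, g in zip(wt, pathwise))
neg_rws_dreg = sum((v - v * v) * g for v, g in zip(wt, pathwise))
stl = sum(v * g for v, g in zip(wt, pathwise))
for alpha in (0.0, 0.25, 0.5, 1.0):
    print(alpha, dreg_alpha(wt, pathwise, alpha))
```

Because the combination is linear in $\alpha$, sweeping $\alpha$ only requires the two sums over particles to be computed once.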
4.2 Jackknife Variational Inference (JVI)
Alternatively, Nowozin (2018) reinterprets the IWAE lower bound as a biased estimator for the log marginal likelihood. He analyzes the bias and introduces a novel family of estimators, Jackknife Variational Inference (JVI), which trade reduced bias for increased variance. This additional flexibility comes at the cost of no longer being a stochastic lower bound on the log marginal likelihood. The first-order JVI has significantly reduced bias compared to IWAE, which empirically results in a better estimate of the log marginal likelihood with fewer samples (Nowozin, 2018). For simplicity, we focus on the first-order JVI estimator

$$K \log\left(\frac{1}{K}\sum_{i=1}^{K} w_i\right) - \frac{K-1}{K}\sum_{i=1}^{K} \log\left(\frac{1}{K-1}\sum_{j \neq i} w_j\right).$$

It is straightforward to apply our approach to higher-order JVI estimators.
Doubly Reparameterized Jackknife Variational Inference (JVI)
The JVI estimator is a linear combination of $K$- and $(K-1)$-sample IWAE estimators, so we can use the doubly reparameterized gradient estimator (Eq. 7) for each term.
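The sketch below is an illustration rather than a reference implementation. It computes the first-order JVI estimate from $K$ log weights. It then applies the same linear combination of $K$- and $(K-1)$-sample terms to IWAE-DReG integrands (Eq. 7), computed on the full set and on each leave-one-out subset. It assumes the per-particle pathwise terms $\frac{\partial \log w_i}{\partial z_i}\frac{\partial z_i}{\partial\phi}$ are supplied by the caller. All names are hypothetical.

```python
import math


def log_mean_exp(log_ws):
    """Numerically stable log of the mean of exp(log_ws)."""
    top = max(log_ws)
    return top + math.log(sum(math.exp(lw - top) for lw in log_ws) / len(log_ws))


def jvi_first_order(log_ws):
    """K * L_K - (K - 1) / K * sum_i L_{K-1}^{(-i)} from K >= 2 log weights."""
    K = len(log_ws)
    loo = [log_mean_exp(log_ws[:i] + log_ws[i + 1:]) for i in range(K)]
    return K * log_mean_exp(log_ws) - (K - 1) / K * sum(loo)


def dreg_term(log_ws, pathwise):
    """IWAE-DReG integrand: sum_i (w_i / sum_j w_j)**2 * pathwise_i."""
    top = max(log_ws)
    ws = [math.exp(lw - top) for lw in log_ws]
    total = sum(ws)
    return sum((w / total) ** 2 * g for w, g in zip(ws, pathwise))


def jvi_dreg_grad(log_ws, pathwise):
    """Apply the JVI linear combination to DReG terms on the full and leave-one-out sets."""
    K = len(log_ws)
    loo = [dreg_term(log_ws[:i] + log_ws[i + 1:], pathwise[:i] + pathwise[i + 1:])
           for i in range(K)]
    return K * dreg_term(log_ws, pathwise) - (K - 1) / K * sum(loo)


print(jvi_first_order([0.0, 1.0]), jvi_dreg_grad([0.0, 1.0], [2.0, -1.0]))
```

The leave-one-out loop costs $O(K^2)$, which is fine for small $K$. For large $K$, the leave-one-out sums could instead be formed by subtracting each weight from the total.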
5 Related Work
Mnih & Rezende (2016) introduced a generalized framework of Monte Carlo objectives (MCO). The log of an unbiased marginal likelihood estimator is a lower bound on the log marginal likelihood by Jensen’s inequality. In this view, the ELBO can be seen as the MCO corresponding to a single importance sample estimator of the marginal likelihood with $q$ as the proposal distribution. Similarly, IWAE corresponds to the $K$-sample estimator. Maddison et al. (2017) show that the tightness of an MCO is directly related to the variance of the underlying estimator of the marginal likelihood.
However, Rainforth et al. (2018) point out issues with gradient estimators of multi-sample lower bounds. In particular, they show that although the IWAE bound is tighter, the standard IWAE gradient estimator’s SNR scales poorly with large numbers of samples, leading to degraded performance. Le et al. (2018) experimentally investigate this phenomenon and provide empirical evidence of this degradation across multiple tasks. They find that RWS (Bornschein & Bengio, 2014) does not suffer from this issue and that it can outperform models trained with the IWAE bound. We conclude that it is not sufficient to just tighten the bound; it is important to understand the gradient estimators of the tighter bound as well.
Wake-sleep is an alternative approach to fitting deep generative models, first introduced by Hinton et al. (1995) as a method for training Helmholtz machines. It was extended to the multi-sample setting by Bornschein & Bengio (2014) and to the sequential setting by Gu et al. (2015). It has been applied to generative modeling of images (Ba et al., 2015).
To evaluate DReG estimators, we first measure variance and signal-to-noise ratio (SNR) of gradient estimators on a toy example which we can carefully control. Then, we evaluate gradient variance and model learning on MNIST generative modeling, Omniglot generative modeling, and MNIST structured prediction tasks.
6.1 Toy Gaussian
For this toy Gaussian example, we sample a set of parameters for the model and inference network close to the optimal parameters (perturbed by zero-mean Gaussian noise), then estimate the gradient of the inference network parameters for an increasing number of samples ($K$).
In addition to signal-to-noise ratio (SNR), we plot the squared bias and variance of the gradient estimators in Fig. 1 (all dimensions of $\phi$ behaved qualitatively similarly, so for clarity, we show curves for a single randomly chosen dimension of $\phi$). The bias is computed relative to the expected value of the IWAE gradient estimator. Importantly, SNR does not penalize estimators that are biased, so trivial constant estimators can have infinite SNR. Thus, it is important to consider additional evaluation measures as well. As $K$ increases, the SNR of the IWAE-DReG estimator increases, whereas the SNR of the standard gradient estimator of IWAE goes to $0$, as previously reported. Furthermore, we can see the bias present in the STL estimator. As a check of our implementation, we verified that the observed “bias” for IWAE-DReG was statistically indistinguishable from $0$.
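A small version of this experiment can be sketched in pure Python. The setup below is our own assumption, not the configuration behind Fig. 1. It uses a one-dimensional Gaussian model ($p(z) = \mathcal{N}(\mu, 1)$, $p(x|z) = \mathcal{N}(z, 1)$) and $q(z|x) = \mathcal{N}(x/2 + b, 1/2)$, with the inference-network offset $b$ set $0.3$ away from its optimum. The sketch estimates the SNR, $|\text{mean}|/\text{standard deviation}$, of the standard IWAE and IWAE-DReG gradients with respect to $b$ for several values of $K$.

```python
import math
import random


def grads(K, rng, x=1.0, mu=0.0, b=0.3, s2=0.5):
    """One draw of (standard IWAE, IWAE-DReG) estimates of d/db in an assumed
    1-D Gaussian model: p(z) = N(mu, 1), p(x|z) = N(z, 1), q = N(x/2 + b, s2)."""
    m = 0.5 * x + b
    zs = [m + math.sqrt(s2) * rng.gauss(0.0, 1.0) for _ in range(K)]
    log_ws = [-0.5 * (z - mu) ** 2 - 0.5 * (x - z) ** 2 + 0.5 * (z - m) ** 2 / s2
              for z in zs]
    top = max(log_ws)
    ws = [math.exp(lw - top) for lw in log_ws]
    total = sum(ws)
    iwae = dreg = 0.0
    for w, z in zip(ws, zs):
        v = w / total
        g = -(z - mu) + (x - z) + (z - m) / s2  # d log w / dz, with dz/db = 1
        iwae += v * (g - (z - m) / s2)          # subtract the score term (Eq. 3)
        dreg += v * v * g                       # squared normalized weights (Eq. 7)
    return iwae, dreg


def snr(values):
    """|mean| / standard deviation over repeated draws."""
    mean = sum(values) / len(values)
    sd = math.sqrt(sum((v - mean) ** 2 for v in values) / (len(values) - 1))
    return abs(mean) / sd if sd > 0 else math.inf


rng = random.Random(0)
results = {}
for K in (2, 10, 100):
    draws = [grads(K, rng) for _ in range(2000)]
    results[K] = (snr([d[0] for d in draws]), snr([d[1] for d in draws]))
    print(f"K={K:3d}  SNR IWAE={results[K][0]:.3f}  SNR IWAE-DReG={results[K][1]:.3f}")
```

Because the proposal variance here matches the posterior, the STL estimate in this setting is the constant $-0.6$. Its SNR is effectively infinite even though it is biased, which is why bias and variance are worth reporting alongside SNR.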
6.2 Generative modeling
Training generative models of the binarized MNIST digits dataset is a standard benchmark task for latent variable models. For this evaluation, we used the single latent layer architecture from Burda et al. (2015). The generative model used 50 Gaussian latent variables with an isotropic prior, which were passed through two deterministic layers of 200 tanh units to parameterize factorized Bernoulli outputs. The inference network passed $x$ through two deterministic layers of 200 tanh units to parameterize a factorized Gaussian distribution over $z$. Because our interest was in improved gradient estimators and optimization performance, we used the dynamically binarized MNIST dataset, which suffers minimally from overfitting. We used the standard split of MNIST into train, validation, and test sets.
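The layer sizes described above imply the parameter counts below. This sketch rests on two assumptions not stated in the text. First, the encoder outputs a mean and a scale for each of the 50 latent dimensions (100 outputs). Second, every layer is fully connected with a bias. The tanh nonlinearities are omitted because they add no parameters, and the helper names are hypothetical. The binarization helper follows the usual convention for dynamic binarization: each pixel is drawn from a Bernoulli distribution with its grey-level intensity as the probability, and is redrawn each time an image is used.

```python
import random

LATENT_DIM = 50
PIXELS = 28 * 28
# Assumption: the encoder emits a mean and a scale parameter per latent dimension.
ENCODER_SIZES = [PIXELS, 200, 200, 2 * LATENT_DIM]
# The decoder emits one Bernoulli logit per pixel.
DECODER_SIZES = [LATENT_DIM, 200, 200, PIXELS]


def mlp_param_count(sizes):
    """Weights plus biases of a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


def dynamically_binarize(intensities, rng):
    """Draw a fresh binary image: each pixel ~ Bernoulli(its grey-level intensity)."""
    return [1 if rng.random() < p else 0 for p in intensities]


print(mlp_param_count(ENCODER_SIZES), mlp_param_count(DECODER_SIZES))  # 217300 207984
```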
We trained models with the IWAE gradient, the RWS wake update, and with the JVI estimator. In all three cases, the doubly reparameterized gradient estimator reduced variance and as a result substantially improved performance (Fig. 2).
We found similar behavior with different numbers of samples (Fig. 3). Interestingly, the biased gradient estimators STL and RWS-DReG perform best on this task, with RWS-DReG slightly outperforming STL. As observed by Le et al. (2018), RWS increasingly outperforms IWAE as $K$ increases. Finally, we experimented with convex combinations of IWAE-DReG and RWS-DReG (Fig. 3, right). On this dataset, convex combinations that heavily weighted RWS-DReG had the best performance. However, as we show below, this is task dependent.
6.3 Structured prediction on MNIST
Structured prediction is another common benchmark task for latent variable models. In this task, our goal is to model a complex observation $y$ given a context $x$ (i.e., model the conditional distribution $p(y|x)$). We can use a conditional latent variable model $p_\theta(y, z|x) = p_\theta(y|z, x)\,p_\theta(z|x)$; however, as before, computing the marginal likelihood $p_\theta(y|x)$ is generally intractable. It is straightforward to adapt the bounds and techniques from the previous section to this problem.
To evaluate our method in this context, we use the standard task of modeling the bottom half of a binarized MNIST digit from the top half. We use a similar architecture, but now learn a conditional prior distribution $p_\theta(z|x)$ where $x$ is the top half of the MNIST digit. The conditional prior feeds $x$ to two deterministic layers of 200 tanh units to parameterize a factorized Gaussian distribution over $z$. To model the conditional distribution $p_\theta(y|z, x)$, we concatenate $z$ with $x$ and feed it to two deterministic layers of 200 tanh units to parameterize factorized Bernoulli outputs.
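The input and output sizes of the conditional model follow from splitting a 28 × 28 image at row 14. The sketch below assumes row-major flattening. As in the unconditional case, it also assumes that the conditional prior outputs a mean and a scale per latent dimension. The inference network is not specified in the text, so it is omitted here, and all names are hypothetical.

```python
def split_halves(image, height=28, width=28):
    """Split a row-major flattened image into (top half, bottom half)."""
    cut = (height // 2) * width
    return image[:cut], image[cut:]


def decoder_input(z, top):
    """Concatenate the latent sample z with the observed top half."""
    return list(z) + list(top)


LATENT_DIM = 50
top, bottom = split_halves(list(range(28 * 28)))  # pixel indices stand in for values
# Assumption: the conditional prior emits a mean and a scale per latent dimension.
PRIOR_SIZES = [len(top), 200, 200, 2 * LATENT_DIM]
DECODER_SIZES = [LATENT_DIM + len(top), 200, 200, len(bottom)]
print(PRIOR_SIZES, DECODER_SIZES)  # [392, 200, 200, 100] [442, 200, 200, 392]
```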
As in the previous tasks, the doubly reparameterized gradient estimator improves across all three updates (IWAE, RWS, and JVI; Appendix Fig. 7). However, on this task, the biased estimators (STL and RWS) underperform unbiased IWAE gradient estimators (Fig. 4). In particular, RWS becomes unstable later in training. We suspect that this is because RWS does not directly optimize a consistent objective.
In this work, we introduce doubly reparameterized estimators for the updates in IWAE, RWS, and JVI. We demonstrate that across tasks they provide unbiased variance reduction, which leads to improved performance. As a result, we recommend that DReG estimators be used instead of the typical gradient estimators.
Variational Sequential Monte Carlo (Maddison et al., 2017; Naesseth et al., 2018; Le et al., 2018) and Neural Adaptive Sequential Monte Carlo (Gu et al., 2015) extend IWAE and RWS to sequential latent variable models, respectively. It would be interesting to develop DReG estimators for these approaches as well.
We found that a convex combination of IWAE-DReG and RWS-DReG performed best; however, the weighting was task dependent. In future work, we intend to apply ideas from Baydin et al. (2017) to automatically adapt the weighting based on the data.
Finally, the form of the IWAE-DReG estimator (Eq. 7) is surprisingly simple and suggests that there may be a more direct derivation that is applicable to general MCOs.
- Ba et al. (2015) Jimmy Ba, Ruslan R Salakhutdinov, Roger B Grosse, and Brendan J Frey. Learning Wake-Sleep Recurrent Attention Models. In C Cortes, N D Lawrence, D D Lee, M Sugiyama, and R Garnett (eds.), Advances in Neural Information Processing Systems 28, pp. 2593–2601, 2015.
- Babaeizadeh et al. (2017) Mohammad Babaeizadeh, Chelsea Finn, Dumitru Erhan, Roy H Campbell, and Sergey Levine. Stochastic variational video prediction. International Conference on Learning Representations, 2017.
- Baydin et al. (2017) Atilim Gunes Baydin, Robert Cornish, David Martinez Rubio, Mark Schmidt, and Frank Wood. Online learning rate adaptation with hypergradient descent. arXiv preprint arXiv:1703.04782, 2017.
- Blei et al. (2017) David M Blei, Alp Kucukelbir, and Jon D McAuliffe. Variational inference: A review for statisticians. Journal of the American Statistical Association, 2017.
- Bornschein & Bengio (2014) Jörg Bornschein and Yoshua Bengio. Reweighted wake-sleep. International Conference on Learning Representations, 2014.
- Burda et al. (2015) Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance weighted autoencoders. International Conference on Learning Representations, 2015.
- Casella & Berger George Casella and Roger L Berger. Statistical inference, volume 2. Duxbury Pacific Grove, CA.
- Chen et al. (2016) Xi Chen, Diederik P Kingma, Tim Salimans, Yan Duan, Prafulla Dhariwal, John Schulman, Ilya Sutskever, and Pieter Abbeel. Variational lossy autoencoder. International Conference on Learning Representations, 2016.
- Chung et al. (2015) Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron C Courville, and Yoshua Bengio. A recurrent latent variable model for sequential data. In Advances in neural information processing systems, pp. 2980–2988, 2015.
- Denton & Fergus (2018) Emily Denton and Rob Fergus. Stochastic video generation with a learned prior. International Conference on Machine Learning, 2018.
- Fraccaro et al. (2016) Marco Fraccaro, Søren Kaae Sønderby, Ulrich Paquet, and Ole Winther. Sequential neural models with stochastic layers. In Advances in neural information processing systems, pp. 2199–2207, 2016.
- Gu et al. (2015) Shixiang Gu, Zoubin Ghahramani, and Richard E Turner. Neural adaptive sequential monte carlo. In Advances in Neural Information Processing Systems, pp. 2629–2637, 2015.
- Gulrajani et al. (2016) Ishaan Gulrajani, Kundan Kumar, Faruk Ahmed, Adrien Ali Taiga, Francesco Visin, David Vazquez, and Aaron Courville. Pixelvae: A latent variable model for natural images. International Conference on Learning Representations, 2016.
- Ha & Schmidhuber (2018) David Ha and Jürgen Schmidhuber. World models. Advances in neural information processing systems, 2018.
- Hinton et al. (1995) Geoffrey E Hinton, Peter Dayan, Brendan J Frey, and Radford M Neal. The “wake-sleep” algorithm for unsupervised neural networks. Science, 268(5214):1158–1161, 1995.
- Jordan et al. (1999) Michael I Jordan, Zoubin Ghahramani, Tommi S Jaakkola, and Lawrence K Saul. An introduction to variational methods for graphical models. Machine learning, 37(2):183–233, 1999.
- Kingma & Welling (2013) Diederik P Kingma and Max Welling. Auto-encoding variational bayes. International Conference on Learning Representations, 2013.
- Kingma et al. (2016) Diederik P Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, and Max Welling. Improved variational inference with inverse autoregressive flow. In Advances in Neural Information Processing Systems, pp. 4743–4751, 2016.
- Krishnan et al. (2015) Rahul G Krishnan, Uri Shalit, and David Sontag. Deep kalman filters. arXiv preprint arXiv:1511.05121, 2015.
- Le et al. (2018) Tuan Anh Le, Adam R Kosiorek, N Siddharth, Yee Whye Teh, and Frank Wood. Revisiting reweighted wake-sleep. arXiv preprint arXiv:1805.10469, 2018.
- Maddison et al. (2017) Chris J Maddison, John Lawson, George Tucker, Nicolas Heess, Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, and Yee Teh. Filtering variational objectives. In Advances in Neural Information Processing Systems, pp. 6573–6583, 2017.
- Mnih & Rezende (2016) Andriy Mnih and Danilo J Rezende. Variational inference for monte carlo objectives. International Conference on Machine Learning, 2016.
- Naesseth et al. (2018) Christian Naesseth, Scott Linderman, Rajesh Ranganath, and David Blei. Variational sequential monte carlo. International Conference on Artificial Intelligence and Statistics, pp. 968–977, 2018.
- Nowozin (2018) Sebastian Nowozin. Debiasing evidence approximations: On importance-weighted autoencoders and jackknife variational inference. International Conference on Learning Representations, 2018.
- Rainforth et al. (2018) Tom Rainforth, Adam R Kosiorek, Tuan Anh Le, Chris J Maddison, Maximilian Igl, Frank Wood, and Yee Whye Teh. Tighter variational bounds are not necessarily better. International Conference on Machine Learning, 2018.
- Rezende & Mohamed (2015) Danilo Rezende and Shakir Mohamed. Variational inference with normalizing flows. In International Conference on Machine Learning, pp. 1530–1538, 2015.
- Rezende et al. (2014) Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, pp. 1278–1286, 2014.
- Roeder et al. (2017) Geoffrey Roeder, Yuhuai Wu, and David Duvenaud. Sticking the landing: An asymptotically zero-variance gradient estimator for variational inference. Advances in Neural Information Processing Systems, 2017.
- Williams (1992) Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4):229–256, 1992.
8.1 Equivalence between REINFORCE gradient and reparameterization trick gradient
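Given a function $f(z, \phi)$, we have

$$\mathbb{E}_{q_\phi(z)}\left[f(z, \phi)\frac{\partial}{\partial\phi}\log q_\phi(z)\right] = \mathbb{E}_{\epsilon}\left[\frac{\partial f(z, \phi)}{\partial z}\frac{\partial z(\epsilon, \phi)}{\partial\phi}\right]$$

for a reparameterizable distribution $q_\phi(z)$. To see this, consider the following function of $\phi$ and $\phi'$

$$g(\phi, \phi') = \mathbb{E}_{q_\phi(z)}\left[f(z, \phi')\right].$$

On one hand,

$$\frac{\partial}{\partial\phi} g(\phi, \phi') = \mathbb{E}_{q_\phi(z)}\left[f(z, \phi')\frac{\partial}{\partial\phi}\log q_\phi(z)\right]$$

via the REINFORCE gradient. On the other hand,

$$\frac{\partial}{\partial\phi} g(\phi, \phi') = \frac{\partial}{\partial\phi}\mathbb{E}_{\epsilon}\left[f(z(\epsilon, \phi), \phi')\right] = \mathbb{E}_{\epsilon}\left[\frac{\partial f(z, \phi')}{\partial z}\frac{\partial z(\epsilon, \phi)}{\partial\phi}\right]$$

via the reparameterization trick. Thus, we conclude that

$$\mathbb{E}_{q_\phi(z)}\left[f(z, \phi')\frac{\partial}{\partial\phi}\log q_\phi(z)\right] = \mathbb{E}_{\epsilon}\left[\frac{\partial f(z, \phi')}{\partial z}\frac{\partial z(\epsilon, \phi)}{\partial\phi}\right]$$

for any choice of $\phi'$, in particular for $\phi' = \phi$.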
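The identity can be checked numerically. In the sketch below, which is an illustrative choice and not part of the derivation, $q_\phi(z) = \mathcal{N}(\mu, \sigma^2)$ with $\phi = \mu$ and $f(z) = z^2$. Both sides then equal $\frac{\partial}{\partial\mu}\mathbb{E}[z^2] = 2\mu$.

```python
import random


def compare_estimators(mu=1.0, sigma=1.0, n=100_000, seed=0):
    """Monte Carlo estimates of d/dmu E_{N(mu, sigma^2)}[z**2], computed two ways.

    REINFORCE: f(z) * d/dmu log q(z) = z**2 * (z - mu) / sigma**2.
    Reparameterization: z = mu + sigma * eps, so f'(z) * dz/dmu = 2 * z.
    Both have expectation 2 * mu.
    """
    rng = random.Random(seed)
    reinforce = reparam = 0.0
    for _ in range(n):
        z = mu + sigma * rng.gauss(0.0, 1.0)
        reinforce += z * z * (z - mu) / sigma ** 2
        reparam += 2.0 * z
    return reinforce / n, reparam / n


print(compare_estimators())  # both entries should be close to 2 * mu = 2.0
```

For $\mu = \sigma = 1$ the per-sample variance of the REINFORCE term is 30, versus 4 for the reparameterized term. This reflects the variance advantage of the reparameterization estimator that motivates its use in Eq. 6.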
8.2 Informal asymptotic analysis
At a high level, Rainforth et al. (2018) show that the expected value of the IWAE gradient of the inference network collapses to zero with rate $O(1/K)$, while its standard deviation is only shrinking at a rate of $O(1/\sqrt{K})$. This is the essence of the problem that results in the SNR (expectation divided by standard deviation) of the inference network gradients going to zero at a rate $O(1/\sqrt{K})$, worsening with $K$. In contrast, Rainforth et al. (2018) show that the SNR of the generation network gradients scales like $O(\sqrt{K})$, improving with $K$.
Because the IWAE-DReG estimator is unbiased, we cannot hope to change the scaling of the expected value in $K$, but we can hope to change the scaling of the variance. In particular, in this subsection, we provide an informal argument, via the delta method, that the standard deviation of IWAE-DReG scales like $O(K^{-3/2})$, which results in an overall scaling of $O(\sqrt{K})$ for the inference network gradient’s SNR (i.e., increasing with $K$). Thus, the SNR of the IWAE-DReG estimator improves similarly in $K$ for both inference and generation networks.
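We will appeal to the delta method on a two-variable function $f(X, Y)$. Define the following notation for the partials of $f$ evaluated at the mean of random variables $X, Y$,

$$f_X = \frac{\partial f}{\partial X}\bigg|_{(\mathbb{E}X,\, \mathbb{E}Y)}, \qquad f_Y = \frac{\partial f}{\partial Y}\bigg|_{(\mathbb{E}X,\, \mathbb{E}Y)}.$$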
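The delta method approximation of $\mathrm{Var}(f(X, Y))$ is given by (Section 5.5 of Casella & Berger)

$$\mathrm{Var}(f(X, Y)) \approx f_X^2\,\mathrm{Var}(X) + 2 f_X f_Y\,\mathrm{Cov}(X, Y) + f_Y^2\,\mathrm{Var}(Y).$$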
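Now, assume without loss of generality that $\phi$ is a single real-valued parameter. Let $g_i = \frac{\partial \log w_i}{\partial z_i}\frac{\partial z_i}{\partial\phi}$, $X = \sum_{i=1}^{K} w_i^2 g_i$, and $Y = \sum_{i=1}^{K} w_i$. Let $f(X, Y) = X/Y^2$; then $f(X, Y)$ is the IWAE-DReG estimator whose variance we seek to understand. Letting $\mu_X = \mathbb{E}X$ and $\mu_Y = \mathbb{E}Y$, we have $f_X = 1/\mu_Y^2$ and $f_Y = -2\mu_X/\mu_Y^3$, so in this case

$$\mathrm{Var}(f(X, Y)) \approx \frac{\mathrm{Var}(X)}{\mu_Y^4} - \frac{4\mu_X\,\mathrm{Cov}(X, Y)}{\mu_Y^5} + \frac{4\mu_X^2\,\mathrm{Var}(Y)}{\mu_Y^6}.$$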
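Because the pairs $(w_i, g_i)$ are mutually independent across $i$, we get $\mathrm{Var}(X) = K\,\mathrm{Var}(w_1^2 g_1)$. Similarly, $\mathrm{Var}(Y) = K\,\mathrm{Var}(w_1)$, $\mu_X = K\,\mathbb{E}[w_1^2 g_1]$, and $\mu_Y = K\,\mathbb{E}[w_1]$. Because $(w_i, g_i)$ and $w_j$ are identically distributed and independent for $i \neq j$, we have $\mathrm{Cov}(X, Y) = K\,\mathrm{Cov}(w_1^2 g_1, w_1)$. All together, the three terms scale like $K/K^4$, $K \cdot K/K^5$, and $K^2 \cdot K/K^6$, so $\mathrm{Var}(f(X, Y))$ scales like $O(K^{-3})$. Thus, the standard deviation scales like $O(K^{-3/2})$.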
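The $O(K^{-3})$ variance scaling can be probed numerically. The sketch below assumes the model $p(z) = \mathcal{N}(\mu, 1)$, $p(x|z) = \mathcal{N}(z, 1)$ with a proposal that has the exact posterior variance $1/2$ but a mean offset by $\delta$. In that setting, $\log w_i = -2\delta e_i$ up to a constant, with $e_i \sim \mathcal{N}(0, 1/2)$, and every $g_i = -2\delta$. The sketch forms $f(X, Y) = X/Y^2$ directly and compares the empirical variance at $K = 10$ and $K = 40$. The delta-method argument predicts a ratio near $4^3 = 64$ for large $K$. All names are hypothetical.

```python
import math
import random


def dreg_draws(K, n_rep, delta=0.3, seed=0):
    """Draws of the IWAE-DReG estimator f(X, Y) = X / Y**2 in the assumed setting
    log w_i = -2 * delta * e_i with e_i ~ N(0, 1/2) and g_i = -2 * delta."""
    rng = random.Random(seed)
    g = -2.0 * delta
    out = []
    for _ in range(n_rep):
        ws = [math.exp(g * rng.gauss(0.0, math.sqrt(0.5))) for _ in range(K)]
        X = sum(w * w * g for w in ws)  # X = sum_i w_i^2 g_i
        Y = sum(ws)                     # Y = sum_i w_i
        out.append(X / Y ** 2)
    return out


def variance(values):
    """Unbiased sample variance."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / (len(values) - 1)


v10 = variance(dreg_draws(10, 3000))
v40 = variance(dreg_draws(40, 3000, seed=1))
print(v10 / v40)  # O(K^-3) scaling predicts roughly (40 / 10) ** 3 = 64
```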
8.3 Unified Surrogate Objectives for Estimators
In the main text, we assumed that $\theta$ and $\phi$ were disjoint; however, it can be helpful to share parameters between $p$ and $q$ (e.g., Fraccaro et al., 2016). With the IWAE bound, we differentiate a single objective with respect to both the $p$ and $q$ parameters. Thus it is straightforward to adapt IWAE and IWAE-DReG to the shared parameter setting. In this section, we discuss how to deal with shared parameters in RWS.
Suppose that both and are parameterized by . If we denote the unshared parameters of by , then we can restrict the RWS wake update to only . Alternatively, with a modified RWS wake update, we can derive a single surrogate objective for each scenario such that taking the gradient with respect to results in the proper update. For clarity, we introduce the following modifier notation for , , and which are functions of and . We use to mean with stopped gradients with respect to , to mean with stopped gradients with respect to (but not is not stopped in ), and to mean with stopped gradients for all variables. Then, we can use the following surrogate objectives:
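The only subtle difference is that DReG($\alpha = 1/2$) does not correspond exactly to STL due to the scaling between terms: with $\alpha = 1/2$ the coefficient on each pathwise term is $\frac{1}{2}\frac{w_i}{\sum_j w_j}$, so DReG($1/2$) equals one half of the STL estimator.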