The Usual Suspects? Reassessing Blame for VAE Posterior Collapse

12/23/2019 ∙ by Bin Dai, et al. ∙ Tsinghua University

In narrow asymptotic settings, Gaussian VAE models of continuous data have been shown to possess global optima aligned with ground-truth distributions. Even so, it is well known that poor solutions whereby the latent posterior collapses to an uninformative prior are sometimes obtained in practice. However, contrary to conventional wisdom that largely attributes this phenomenon to the undue influence of KL-divergence regularization, we will argue that posterior collapse is, at least in part, a direct consequence of bad local minima inherent to the loss surface of deep autoencoder networks. In particular, we prove that even small nonlinear perturbations of affine VAE decoder models can produce such minima, and in deeper models, analogous minima can force the VAE to behave like an aggressive truncation operator, provably discarding information along all latent dimensions in certain circumstances. Regardless, the underlying message here is not meant to undercut valuable existing explanations of posterior collapse, but rather, to refine the discussion and elucidate alternative risk factors that may have been previously underappreciated.

1 Introduction

The variational autoencoder (VAE) (Kingma & Welling, 2014; Rezende et al., 2014) represents a powerful generative model of data points that are assumed to possess some complex yet unknown latent structure. This assumption is instantiated via the marginalized distribution

$$p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz, \tag{1}$$

which forms the basis of prevailing VAE models. Here $z \in \mathbb{R}^{\kappa}$ is a collection of unobservable latent factors of variation that, when drawn from the prior $p(z)$, are colloquially said to generate an observed data point $x \in \mathbb{R}^{d}$ through the conditional distribution $p_\theta(x \mid z)$. The latter is controlled by parameters $\theta$ that can, at least conceptually speaking, be optimized by maximum likelihood over $p_\theta(x)$ given available training examples.

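In particular, assuming $n$ training points $X = [x^{(1)}, \ldots, x^{(n)}]$, maximum likelihood estimation is tantamount to minimizing the negative log-likelihood expression $\frac{1}{n}\sum_{i} -\log p_\theta(x^{(i)})$. Proceeding further, because the marginalization over $z$ in (1) is often intractable, the VAE instead minimizes a convenient variational upper bound given by

$$\mathcal{L}(\theta,\phi) \triangleq \frac{1}{n}\sum_{i=1}^{n}\Big\{ -\mathbb{E}_{q_\phi(z \mid x^{(i)})}\big[\log p_\theta(x^{(i)} \mid z)\big] + \mathbb{KL}\big[q_\phi(z \mid x^{(i)}) \,\|\, p(z)\big] \Big\} \;\geq\; \frac{1}{n}\sum_{i=1}^{n} -\log p_\theta(x^{(i)}), \tag{2}$$

with equality iff $q_\phi(z \mid x^{(i)}) = p_\theta(z \mid x^{(i)})$ for all $i$. The additional parameters $\phi$ govern the shape of the variational distribution $q_\phi(z \mid x)$ that is designed to approximate the true but often intractable latent posterior $p_\theta(z \mid x)$.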

The VAE energy from (2) is composed of two terms, a data-fitting loss that borrows the basic structure of an autoencoder (AE), and a KL-divergence-based regularization factor. The former incentivizes assigning high probability to latent codes $z$ that facilitate accurate reconstructions of each $x^{(i)}$. In fact, if $q_\phi(z \mid x)$ is a Dirac delta function, this term is exactly equivalent to a deterministic AE with data reconstruction loss defined by $-\log p_\theta(x \mid z)$. Overall, it is because of this association that $q_\phi(z \mid x)$ is generally referred to as the encoder distribution, while $p_\theta(x \mid z)$ denotes the decoder distribution. Additionally, the KL regularizer $\mathbb{KL}[q_\phi(z \mid x) \,\|\, p(z)]$ pushes the encoder distribution towards the prior without violating the variational bound.

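For continuous data, which will be our primary focus herein, it is typical to assume that

$$p(z) = \mathcal{N}(z \mid 0, I), \quad p_\theta(x \mid z) = \mathcal{N}(x \mid \mu_x, \gamma I), \quad \text{and} \quad q_\phi(z \mid x) = \mathcal{N}(z \mid \mu_z, \Sigma_z), \tag{3}$$

where $\gamma > 0$ is a scalar variance parameter, while the Gaussian moments $\mu_x \equiv \mu_x(z;\theta)$, $\mu_z \equiv \mu_z(x;\phi)$, and $\Sigma_z \equiv \mathrm{diag}[\sigma_z(x;\phi)]^2$ are computed via feedforward neural network layers. The encoder network parameterized by $\phi$ takes $x$ as an input and outputs $\mu_z$ and $\Sigma_z$. Similarly the decoder network parameterized by $\theta$ converts a latent code $z$ into $\mu_x$. Given these assumptions, the generic VAE objective from (2) can be refined to

$$\mathcal{L}(\theta,\phi) = \frac{1}{n}\sum_{i=1}^{n}\Big\{ \mathbb{E}_{q_\phi(z \mid x^{(i)})}\Big[\tfrac{1}{\gamma}\big\|x^{(i)} - \mu_x(z;\theta)\big\|_2^2\Big] + d\log\gamma + \big\|\sigma_z(x^{(i)};\phi)\big\|_2^2 - \log\Big|\mathrm{diag}\big[\sigma_z(x^{(i)};\phi)\big]^2\Big| + \big\|\mu_z(x^{(i)};\phi)\big\|_2^2 \Big\}, \tag{4}$$

excluding an inconsequential factor of $1/2$, where $d$ is the dimension of each $x^{(i)}$. This expression can be optimized over $\{\theta,\phi\}$ using SGD and a simple reparameterization strategy (Kingma & Welling, 2014; Rezende et al., 2014) to produce parameter estimates $\{\theta^*,\phi^*\}$. Among other things, new samples approximating the training data can then be generated via the ancestral process $z^{new} \sim \mathcal{N}(z \mid 0, I)$ and $x^{new} \sim p_{\theta^*}(x \mid z^{new})$.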
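As a rough illustration of how the Gaussian VAE energy and the reparameterization strategy fit together, the following sketch estimates the energy by Monte Carlo sampling. It assumes the form of the objective with a $\frac{1}{\gamma}$-weighted squared error, a $d\log\gamma$ term, and the Gaussian KL terms, with additive constants dropped. The function name `gaussian_vae_energy`, the `decode_mean` callable, and the `num_samples` argument are hypothetical and are not taken from the paper.

```python
import numpy as np


def gaussian_vae_energy(x, mu_z, sigma_z, decode_mean, gamma, rng, num_samples=8):
    """Monte Carlo estimate of the Gaussian VAE energy, up to additive constants.

    x:           (n, d) array of data points (one per row; a layout assumption).
    mu_z:        (n, kappa) encoder means.
    sigma_z:     (n, kappa) encoder standard deviations (all positive).
    decode_mean: hypothetical callable mapping (n, kappa) codes to (n, d) means.
    gamma:       scalar decoder variance.
    """
    n, d = x.shape
    sq_err = np.zeros(n)
    for _ in range(num_samples):
        # Reparameterization: z = mu_z + sigma_z * eps with eps ~ N(0, I).
        eps = rng.standard_normal(mu_z.shape)
        z = mu_z + sigma_z * eps
        sq_err += np.sum((x - decode_mean(z)) ** 2, axis=1)
    sq_err /= num_samples

    # KL-derived terms per data point (the constant -kappa is omitted).
    reg = np.sum(sigma_z ** 2 - np.log(sigma_z ** 2) + mu_z ** 2, axis=1)
    return float(np.mean(sq_err / gamma + d * np.log(gamma) + reg))
```

Because constants are dropped, a fully collapsed encoder (zero mean, unit variance) contributes $\kappa$ rather than zero through the regularization terms, so only differences between energies are meaningful in this sketch.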

Although it has been argued that global minima of (4) may correspond with the optimal recovery of ground-truth distributions in certain asymptotic settings (Dai & Wipf, 2019), it is well known that in practice, VAE models are at risk of converging to degenerate solutions where, for example, it may be that $q_\phi(z \mid x) = p(z)$. This phenomenon, commonly referred to as VAE posterior collapse (He et al., 2019; Razavi et al., 2019), has been acknowledged and analyzed from a variety of different perspectives as we detail in Section 2. That being said, we would argue that there remains lingering ambiguity regarding the different types and respective causes of posterior collapse. Consequently, Section 3 provides a useful taxonomy that will serve to contextualize our main technical contributions. These include the following:

  • Building upon existing analysis of affine VAE decoder models, in Section 4 we prove that even arbitrarily small nonlinear activations can introduce suboptimal local minima exhibiting posterior collapse.

  • We demonstrate in Section 5 that if the encoder/decoder networks are incapable of sufficiently reducing the VAE reconstruction errors, even in a deterministic setting with no KL-divergence regularizer, there will exist an implicit lower bound on the optimal value of $\gamma$. Moreover, we prove that if this $\gamma$ is sufficiently large, the VAE will behave like an aggressive thresholding operator, enforcing exact posterior collapse, i.e., $q_\phi(z \mid x) = p(z)$.

  • Based on these observations, we present experiments in Section 6 establishing that as network depth/capacity is increased, even for deterministic AE models with no regularization, reconstruction errors become worse. This bounds the effective VAE trade-off parameter such that posterior collapse is essentially inevitable. Collectively then, we provide convincing evidence that posterior collapse is, at least in certain settings, the fault of deep AE local minima, and need not be exclusively a consequence of usual suspects such as the KL-divergence term.

We conclude in Section 7 with practical take-home messages, and motivate the search for improved AE architectures and training regimes that might be leveraged by analogous VAE models.

2 Recent Work and the Usual Suspects for Instigating Collapse

Posterior collapse under various guises is one of the most frequently addressed topics related to VAE performance. Depending on the context, arguably the most common and seemingly transparent suspect for causing collapse is the KL regularization factor that is obviously minimized by $q_\phi(z \mid x) = p(z)$. This perception has inspired various countermeasures, including heuristic annealing of the KL penalty or KL warm-start (Bowman et al., 2015; Huang et al., 2018; Sønderby et al., 2016), tighter bounds on the log-likelihood (Burda et al., 2015; Rezende & Mohamed, 2015), more complex priors (Bauer & Mnih, 2018; Tomczak & Welling, 2018), modified decoder architectures (Cai et al., 2017; Dieng et al., 2018; Yeung et al., 2017), or efforts to explicitly disallow the prior from ever equaling the variational distribution (Razavi et al., 2019). Thus far though, most published results do not indicate success generating high-resolution images, and in the majority of cases, evaluations are limited to small images and/or relatively shallow networks. This suggests that there may be more nuance involved in pinpointing the causes and potential remedies of posterior collapse. One notable exception though is the BIVA model from Maaløe et al. (2019), which employs a bidirectional hierarchy of latent variables, in part to combat posterior collapse. While improvements in NLL scores have been demonstrated with BIVA using relatively deep encoder/decoders, this model is significantly more complex and difficult to analyze.

On the analysis side, there have been various efforts to explicitly characterize posterior collapse in restricted settings. For example, Lucas et al. (2019) demonstrate that if $\gamma$ is fixed to a sufficiently large value, then a VAE energy function with an affine decoder mean will have minima that overprune latent dimensions. A related linearized approximation to the VAE objective is analyzed by Rolinek et al. (2019); however, collapsed latent dimensions are excluded and it remains somewhat unclear how the surrogate objective relates to the original. Posterior collapse has also been associated with data-dependent decoder covariance networks (Mattei & Frellsen, 2018), which allow for degenerate solutions assigning infinite density to a single data point and a diffuse, collapsed density everywhere else. Finally, from the perspective of training dynamics, He et al. (2019) argue that a lagging inference network can also lead to posterior collapse.

3 Taxonomy of Posterior Collapse

Although there is now a vast literature on the various potential causes of posterior collapse, there remains ambiguity as to exactly what this phenomenon refers to. In this regard, we believe that it is critical to differentiate five subtle yet quite distinct scenarios that could reasonably fall under the generic rubric of posterior collapse:

  (i) Latent dimensions of $z$ that are not needed for providing good reconstructions of the training data are set to the prior, meaning $q_\phi(z_j \mid x) \approx p(z_j) = \mathcal{N}(0,1)$ at any superfluous dimension $j$. Along other dimensions $\sigma_z^2$ will be near zero and $\mu_z$ will provide a usable predictive signal leading to accurate reconstructions of the training data. This case can actually be viewed as a desirable form of selective posterior collapse that, as argued in (Dai & Wipf, 2019), is a necessary (albeit not sufficient) condition for generating good samples.

  (ii) The decoder variance $\gamma$ is not learned but fixed to a large value (or equivalently, a KL scaling parameter such as the $\beta$ used by the $\beta$-VAE (Higgins et al., 2017) is set too large) such that the KL term from (2) is overly dominant, forcing most or all dimensions of $z$ to follow the prior $\mathcal{N}(0, I)$. In this scenario, the actual global optimum of the VAE energy (conditioned on $\gamma$ being fixed) will lead to deleterious posterior collapse and the model reconstructions of the training data will be poor. In fact, even the original marginal log-likelihood can potentially default to a trivial/useless solution if $\gamma$ is fixed too large, assigning a small marginal likelihood to the training data, provably so in the affine case (Lucas et al., 2019).

  (iii) As mentioned previously, if the Gaussian decoder covariance is learned as a separate network structure (instead of simply $\gamma I$), there can exist degenerate solutions that assign infinite density to a single data point and a diffuse, isotropic Gaussian elsewhere (Mattei & Frellsen, 2018). This implies that (4) can be unbounded from below at what amounts to a posterior collapsed solution and bad reconstructions almost everywhere.

  (iv) When powerful non-Gaussian decoders are used, and in particular those that can parameterize complex distributions regardless of the value of $z$ (e.g., PixelCNN-based (Van den Oord et al., 2016)), it is possible for the VAE to assign high probability to the training data even if $q_\phi(z \mid x) = p(z)$ (Alemi et al., 2017; Bowman et al., 2015; Chen et al., 2016). This category of posterior collapse is quite distinct from categories (ii) and (iii) above in that, although the reconstructions are similarly poor, the associated NLL scores can still be good.

  (v) The previous four categories of posterior collapse can all be directly associated with emergent properties of the VAE global minimum under various modeling conditions. In contrast, a fifth type of collapse exists that is the explicit progeny of bad VAE local minima. More specifically, as we will argue shortly, when deeper encoder/decoder networks are used, the risk of converging to bad, overregularized solutions increases.

The remainder of this paper will primarily focus on category (v), with brief mention of the other types for comparison purposes where appropriate. Our rationale for this selection bias is that, unlike the others, category (i) collapse is actually advantageous and hence need not be mitigated. In contrast, while category (ii) is undesirable, it can be avoided by learning $\gamma$. As for category (iii), this represents an unavoidable consequence of models with flexible decoder covariances capable of detecting outliers (Dai et al., 2019). In fact, even simpler inlier/outlier decomposition models such as robust PCA are inevitably at risk for this phenomenon (Candès et al., 2011). Regardless, when the decoder covariance is restricted to $\gamma I$ this problem goes away. And finally, we do not address category (iv) in depth simply because it is unrelated to the canonical Gaussian VAE models of continuous data that we have chosen to examine herein. Regardless, it is still worthwhile to explicitly differentiate these five types and bear them in mind when considering attempts to both explain and improve VAE models.

4 Insights from Simplified Cases

Because different categories of posterior collapse can be impacted by different global/local minima structures, a useful starting point is a restricted setting whereby we can comprehensively characterize all such minima. For this purpose, we first consider a VAE model with the decoder network set to an affine function. As is often assumed in practice, we choose $\Sigma_x = \gamma I$, where $\gamma > 0$ is a scalar parameter within the parameter set $\theta$. In contrast, for the mean function we choose $\mu_x = W_x z + b_x$ for some weight matrix $W_x$ and bias vector $b_x$. The encoder can be arbitrarily complex (although the optimal structure can be shown to be affine as well).

Given these simplifications, and assuming the training data has $r \geq \kappa$ nonzero singular values, it has been demonstrated that at any global optimum, the columns of $W_x$ will correspond with the first $\kappa$ principal components of $X$ provided that we simultaneously learn $\gamma$ or set it to the optimal value (which is available in closed form) (Dai et al., 2019; Lucas et al., 2019; Tipping & Bishop, 1999). Additionally, it has also been shown that no spurious, suboptimal local minima will exist. Note also that if $r < \kappa$ the same basic conclusions still apply; however, $W_x$ will only have $r$ nonzero columns, each corresponding with a different principal component of the data. The unused latent dimensions will satisfy $q_\phi(z_j \mid x) = \mathcal{N}(0, 1)$, which represents the canonical form of the benign category (i) posterior collapse. Collectively, these results imply that if we converge to any local minimum of the VAE energy, we will obtain the best possible linear approximation to the data using a minimal number of latent dimensions, and malignant posterior collapse is not an issue, i.e., categories (ii)-(v) will not arise.

Even so, if instead of learning $\gamma$, we choose a fixed value that is larger than any of the significant singular values of the sample covariance of $X$, then category (ii) posterior collapse can be inadvertently introduced. More specifically, let $\tilde{r}_\gamma$ denote the number of such singular values that are smaller than the fixed $\gamma$ value. Then along the associated latent dimensions $q_\phi(z_j \mid x) = \mathcal{N}(0, 1)$, and the corresponding columns of $W_x$ will be set to zero at the global optimum (conditioned on this fixed $\gamma$), regardless of whether or not these dimensions are necessary for accurately reconstructing the data. And it has been argued that the risk of this type of posterior collapse at a conditionally-optimal global minimum will likely be inherited by deeper models as well (Lucas et al., 2019), although learning $\gamma$ can ameliorate this problem.
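The pruning effect of a fixed $\gamma$ in the affine case can be sketched numerically. The snippet below assumes the classical probabilistic PCA closed form of Tipping & Bishop (1999), in which the $j$-th decoder column has scale $\sqrt{\max(\lambda_j - \gamma, 0)}$ for sample-covariance eigenvalue $\lambda_j$. The function name `affine_decoder_column_scales` and the row-per-sample data layout are illustrative assumptions, not notation from the paper.

```python
import numpy as np


def affine_decoder_column_scales(x, kappa, gamma):
    """Column scales of an affine decoder under a fixed noise variance gamma.

    Assumes the probabilistic PCA closed form: scale_j = sqrt(max(lambda_j - gamma, 0)),
    where lambda_j are the top-kappa eigenvalues of the sample covariance.

    x: (n, d) array with one data point per row (layout is an assumption).
    """
    centered = x - x.mean(axis=0)
    cov = centered.T @ centered / x.shape[0]
    # eigvalsh returns ascending eigenvalues; keep the kappa largest.
    eigvals = np.sort(np.linalg.eigvalsh(cov))[::-1][:kappa]
    scales = np.sqrt(np.maximum(eigvals - gamma, 0.0))
    return eigvals, scales


def num_pruned_dimensions(x, kappa, gamma):
    """Count latent dimensions whose decoder column is driven exactly to zero."""
    _, scales = affine_decoder_column_scales(x, kappa, gamma)
    return int(np.sum(scales == 0.0))
```

Under this assumed closed form, any direction whose eigenvalue falls below the fixed $\gamma$ is switched off exactly, regardless of how much reconstruction accuracy it would have contributed.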

Of course when we move to more complex architectures, the risk of bad local minima or other suboptimal stationary points becomes a new potential concern, and it is not clear that the affine case described above contributes to reliable, predictive intuitions. To illustrate this point, we will now demonstrate that the introduction of an arbitrarily small nonlinearity can nonetheless produce a pernicious local minimum that exhibits category (v) posterior collapse. For this purpose, we assume the decoder mean function

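$$\mu_x = \pi_\alpha(W_x z) + b_x, \quad \text{with} \quad \pi_\alpha(u) \triangleq \mathrm{sign}(u)\,(|u| - \alpha)_+, \quad \alpha \geq 0. \tag{5}$$

The function $\pi_\alpha$ is nothing more than a soft-threshold operator as is commonly used in neural network architectures designed to reflect unfolded iterative algorithms for representation learning (Gregor & LeCun, 2010; Sprechmann et al., 2015). In the present context though, we choose this nonlinearity largely because it allows (5) to reflect arbitrarily small perturbations away from a strictly affine model, and indeed if $\alpha = 0$ the exact affine model is recovered. Collectively, these specifications lead to the parameterization $\theta = \{W_x, b_x, \gamma\}$ and $\phi = \{\mu_z^{(i)}, \sigma_z^{(i)}\}_{i=1}^{n}$ and energy (excluding irrelevant scale factors and constants) given by

$$\mathcal{L}(\theta,\phi) = \sum_{i=1}^{n}\Big\{ \mathbb{E}_{q_\phi(z \mid x^{(i)})}\Big[\tfrac{1}{\gamma}\big\|x^{(i)} - \pi_\alpha(W_x z) - b_x\big\|_2^2\Big] + d\log\gamma + \big\|\sigma_z^{(i)}\big\|_2^2 - \log\Big|\mathrm{diag}\big[\sigma_z^{(i)}\big]^2\Big| + \big\|\mu_z^{(i)}\big\|_2^2 \Big\}, \tag{6}$$

where $\mu_z^{(i)}$ and $\sigma_z^{(i)}$ denote arbitrary encoder moments for data point $i$ (this is consistent with the assumption of an arbitrarily complex encoder as used in previous analysis of affine decoder models). Now define $\bar{\gamma} \triangleq \frac{1}{nd}\sum_{i} \|x^{(i)} - \bar{x}\|_2^2$, with $\bar{x} \triangleq \frac{1}{n}\sum_{i} x^{(i)}$. We then have the following result (all proofs are deferred to the appendices):

Proposition 4. For any $\alpha > 0$, however small, there will always exist data sets $X$ such that (6) has a global minimum that perfectly reconstructs the training data, but also a bad local minimum characterized by

$$q_{\hat{\phi}}(z \mid x) = \mathcal{N}(z \mid 0, I) \quad \text{and} \quad p_{\hat{\theta}}(x) = \mathcal{N}(x \mid \bar{x}, \bar{\gamma} I). \tag{7}$$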

Hence the moment we allow for nonlinear (or more precisely, non-affine) decoders there can exist a poor local minimum, across all parameters including a learnable $\gamma$, that exhibits category (v) posterior collapse. (This result mirrors related efforts examining linear DNNs, where it has been previously demonstrated that under certain conditions, all local minima are globally optimal (Kawaguchi, 2016), while small nonlinearities can induce bad local optima (Yun et al., 2019). However, the loss surface of these models is completely different from a VAE, and hence we view Proposition 4 as a complementary result.) In other words, no predictive information about $x$ passes through the latent space, and a useless/non-informative distribution $p_{\hat{\theta}}(x)$ emerges that is incapable of assigning high probability to the data (except obviously in the trivial degenerate case where all the data points are equal to the empirical mean $\bar{x}$). We will next investigate the degree to which such concerns can influence behavior in arbitrarily deep architectures.
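For concreteness, a soft-threshold decoder mean and the moments of the collapsed solution can be sketched as follows. This assumes the standard soft-threshold definition $\mathrm{sign}(u)\max(|u| - \alpha, 0)$ applied elementwise, and a collapsed marginal that is an isotropic Gaussian centered at the empirical mean. The function names and the row-per-sample layout are hypothetical choices made only for illustration.

```python
import numpy as np


def soft_threshold(u, alpha):
    """Elementwise soft-threshold: sign(u) * max(|u| - alpha, 0)."""
    return np.sign(u) * np.maximum(np.abs(u) - alpha, 0.0)


def decoder_mean(z, W, b, alpha):
    """Decoder mean with a soft-threshold nonlinearity.

    z: (n, kappa) latent codes, W: (d, kappa) weights, b: (d,) bias.
    Setting alpha = 0 recovers the strictly affine decoder z @ W.T + b.
    """
    return soft_threshold(z @ W.T, alpha) + b


def collapsed_solution_moments(x):
    """Mean and isotropic variance of the collapsed marginal.

    x: (n, d) data. Returns the empirical mean and the average squared
    deviation per data dimension.
    """
    x_bar = x.mean(axis=0)
    gamma_bar = float(np.mean((x - x_bar) ** 2))
    return x_bar, gamma_bar
```

At the collapsed solution the decoder output no longer depends on $z$, so the best it can do is report the empirical mean with a variance equal to the average squared deviation from that mean.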

5 Extrapolating to Practical Deep Architectures

Previously we have demonstrated the possibility of local minima aligned with category (v) posterior collapse the moment we allow for decoders that deviate ever so slightly from an affine model. But nuanced counterexamples designed for proving technical results notwithstanding, it is reasonable to examine what realistic factors are largely responsible for leading optimization trajectories towards such potential bad local solutions. For example, is it merely the strength of the KL regularization term, and if so, why can we not just use KL warm-start to navigate around such points? In this section we will elucidate a deceptively simple, alternative risk factor that will be corroborated empirically in Section 6.

From the outset, we should mention that with deep encoder/decoder architectures commonly used in practice, a stationary point can more-or-less always exist at solutions exhibiting posterior collapse. As a representative and ubiquitous example, please see Appendix D. But of course without further details, this type of stationary point could conceivably manifest as a saddle point (stable or unstable), a local maximum, or a local minimum. For the strictly affine decoder model mentioned in Section 4, there will only be a harmless unstable saddle point at any collapsed solution (the Hessian has negative eigenvalues). In contrast, for the special nonlinear case elucidated via Proposition 4 we can instead have a bad local minimum. We will now argue that as the depth of common feedforward architectures increases, the risk of converging to category (v)-like solutions with most or all latent dimensions stuck at bad stationary points can also increase.

Somewhat orthogonal to existing explanations of posterior collapse, our basis for this argument is not directly related to the VAE KL-divergence term. Instead, we consider a deceptively simple yet potentially influential alternative: Unregularized, deterministic AE models can have bad local solutions with high reconstruction errors when sufficiently deep. This in turn can directly translate to category (v) posterior collapse when training a corresponding VAE model with a matching deep architecture. Moreover, to the extent that this is true, KL warm-start or related countermeasures will likely be ineffective in avoiding such suboptimal minima. We will next examine these claims in greater depth followed by a discussion of practical implications.

5.1 From Deeper Architectures to Inevitable Posterior Collapse

Consider the deterministic AE model formed by composing the encoder mean and decoder mean networks from a VAE model, i.e., reconstructions are computed via $\hat{x} = \mu_x[\mu_z(x;\phi);\theta]$. We then train this AE to minimize the squared-error loss $\frac{1}{nd}\sum_{i=1}^{n}\|x^{(i)} - \hat{x}^{(i)}\|_2^2$, producing parameters $\{\theta_{ae}, \phi_{ae}\}$. Analogously, the corresponding VAE trained to minimize (4) arrives at a parameter set denoted $\{\theta_{vae}, \phi_{vae}\}$. In this scenario, it will typically follow that

$$\frac{1}{nd}\sum_{i=1}^{n}\Big\|x^{(i)} - \mu_x\big[\mu_z(x^{(i)};\phi_{ae});\theta_{ae}\big]\Big\|_2^2 \;\leq\; \frac{1}{nd}\sum_{i=1}^{n}\mathbb{E}_{q_{\phi_{vae}}(z \mid x^{(i)})}\Big[\big\|x^{(i)} - \mu_x(z;\theta_{vae})\big\|_2^2\Big], \tag{8}$$

meaning that the deterministic AE reconstruction error will generally be smaller than the stochastic VAE version. Note that if $\sigma_z \to 0$, the VAE defaults to the same deterministic encoder as the AE and hence will have identical representational capacity; however, the KL regularization prevents this from happening, and any $\sigma_z > 0$ can only make the reconstructions worse (except potentially in certain contrived adversarial conditions that do not represent practical regimes). Likewise, the KL penalty factor $\|\mu_z\|_2^2$ can further restrict the effective capacity and increase the reconstruction error of the training data. Beyond these intuitive arguments, we have never empirically found a case where (8) does not hold (see Section 6 for examples).

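We next define the set

$$\mathcal{S}_\epsilon \triangleq \Big\{ \theta, \phi \;:\; \frac{1}{nd}\sum_{i=1}^{n}\Big\|x^{(i)} - \mu_x\big[\mu_z(x^{(i)};\phi);\theta\big]\Big\|_2^2 \geq \epsilon \Big\} \tag{9}$$

for any $\epsilon > 0$. Now suppose that the chosen encoder/decoder architecture is such that with high probability, achievable optimization trajectories (e.g., via SGD or related) lead to parameters $\{\theta_{ae}, \phi_{ae}\} \in \mathcal{S}_\epsilon$, i.e., $\mathbb{P}(\{\theta_{ae}, \phi_{ae}\} \in \mathcal{S}_\epsilon) \approx 1$. It then follows that the optimal VAE noise variance denoted $\gamma^*$, when conditioned on practically-achievable values for other network parameters, will satisfy

$$\gamma^* = \frac{1}{nd}\sum_{i=1}^{n}\mathbb{E}_{q_{\phi_{vae}}(z \mid x^{(i)})}\Big[\big\|x^{(i)} - \mu_x(z;\theta_{vae})\big\|_2^2\Big] \;\geq\; \epsilon. \tag{10}$$

The equality in (10) can be confirmed by simply differentiating the VAE cost w.r.t. $\gamma$ and equating to zero, while the inequality comes from (8) and the fact that $\{\theta_{ae}, \phi_{ae}\} \in \mathcal{S}_\epsilon$.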
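The stationarity argument for the optimal noise variance can be checked numerically. The sketch below assumes that the only terms of the Gaussian VAE energy depending on the noise variance are the scaled expected squared error and the log-variance term weighted by the data dimension $d$, so that setting the derivative to zero yields the average expected squared error per data dimension. The helper names and the example numbers are hypothetical.

```python
import numpy as np


def optimal_gamma(expected_sq_errors, d):
    """Closed-form minimizer: mean expected squared error divided by d.

    expected_sq_errors: length-n array, one expected squared reconstruction
    error per data point; d: data dimension.
    """
    return float(np.mean(expected_sq_errors) / d)


def gamma_dependent_energy(gamma, expected_sq_errors, d):
    """Terms of the Gaussian VAE energy that depend on gamma."""
    return float(np.mean(expected_sq_errors) / gamma + d * np.log(gamma))


# Illustrative (made-up) per-point errors for a d = 2 problem.
errors = np.array([2.0, 4.0, 6.0])
gamma_star = optimal_gamma(errors, d=2)
```

Evaluating `gamma_dependent_energy` slightly above and below `gamma_star` gives larger values in both directions, which is consistent with the closed form being a minimizer. It also makes plain that a floor on the reconstruction error translates directly into a floor on the optimal noise variance.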

From inspection of the VAE energy from (4), it is readily apparent that larger values of $\gamma$ will discount the data-fitting term and therefore place greater emphasis on the KL divergence. Since the latter is minimized when the latent posterior equals the prior, we might expect that whenever $\epsilon$ and therefore $\gamma^*$ is increased per (10), we are at a greater risk of nearing collapsed solutions. But the manner in which collapse is approached is not at all transparent, and yet this subtlety has important implications for understanding the VAE loss surface in regions at risk of posterior collapse.

For example, one plausible hypothesis is that only as $\gamma \to \infty$ do we risk full category (v) collapse. If this were the case, we might have less cause for alarm since the reconstruction error and by association $\gamma^*$ will typically be bounded from above at any local minimizer. However, we will now demonstrate that even finite values can exactly collapse the posterior. In formally showing this, it is helpful to introduce a slightly narrower but nonetheless representative class of VAE models.

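Specifically, let $g(\mu_z, \sigma_z; x, \theta) \triangleq \mathbb{E}_{\mathcal{N}(z \mid \mu_z, \Sigma_z)}\big[\|x - \mu_x(z;\theta)\|_2^2\big]$, i.e., the VAE data term evaluated at a single data point without the $1/\gamma$ scale factor. We then define a well-behaved VAE as a model with energy function (4) designed such that the gradients $\nabla_{\mu_z} g$ and $\nabla_{\sigma_z} g$ are Lipschitz continuous for all $\mu_z$ and $\sigma_z$. Furthermore, we specify a non-degenerate decoder as any $\mu_x(z;\theta = \tilde{\theta})$ with $\theta$ set to a value $\tilde{\theta}$ such that $\nabla_{\sigma_z} g \geq c$ for some constant $c > 0$ that can be arbitrarily small. This ensures that $g$ is an increasing function of $\sigma_z$, a quite natural stipulation given that increasing the encoder variance will generally only serve to corrupt the reconstruction, unless of course the decoder is completely blocking the signal from the encoder. In the latter degenerate situation, it would follow that $\nabla_{\mu_z} g = \nabla_{\sigma_z} g = 0$, which is more-or-less tantamount to category (v) posterior collapse.

Based on these definitions, we can now present the following:

Proposition 5.1. For any well-behaved VAE with arbitrary, non-degenerate decoder $\mu_x(z;\theta = \tilde{\theta})$, there will always exist a $\gamma' < \infty$ such that the trivial solution $\mu_x(z;\theta) = \bar{x}$ and $q_\phi(z \mid x) = p(z)$ will have lower cost.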

Around any evaluation point, the sufficient condition we applied to demonstrate posterior collapse (see proof details) can also be achieved with some $\gamma'' < \gamma'$ if we allow for partial collapse, i.e., $q_\phi(z_j \mid x) = p(z_j)$ along some but not all latent dimensions $j \in \{1, \ldots, \kappa\}$. Overall, the analysis loosely suggests that the number of dimensions vulnerable to exact collapse will increase monotonically with $\gamma$.

Proposition 5.1 also provides evidence that the VAE behaves like a strict thresholding operator, completely shutting off latent dimensions using a finite value for $\gamma$. This is analogous to the distinction between using the $\ell_1$ versus $\ell_2$ norm for solving regularized regression problems of the standard form $\min_u \|x - Au\|_2^2 + \gamma\,\eta(u)$, where $A$ is a design matrix and $\eta$ is a penalty function. When $\eta$ is the $\ell_1$ norm, some or all elements of $u$ can be pruned to exactly zero with a sufficiently large but finite $\gamma$ (Zhao & Yu, 2006). In contrast, when the squared $\ell_2$ norm is applied, the coefficients will be shrunk to smaller values but never pushed all the way to zero unless $\gamma \to \infty$.
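The contrast between thresholding and mere shrinkage can be seen in the simplest special case of the regularized regression problem, where the design matrix is the identity so that the solution decouples across coordinates. The sketch below assumes that special case and takes the second penalty to be the squared $\ell_2$ norm (ridge-style); both choices, and the function names, are illustrative assumptions.

```python
import numpy as np


def l1_penalized_solution(x, gamma):
    """Minimizer of ||x - u||^2 + gamma * ||u||_1 (identity design matrix).

    The per-coordinate optimality condition gives a soft threshold at gamma / 2.
    """
    return np.sign(x) * np.maximum(np.abs(x) - gamma / 2.0, 0.0)


def squared_l2_penalized_solution(x, gamma):
    """Minimizer of ||x - u||^2 + gamma * ||u||_2^2 (identity design matrix).

    Setting the gradient 2(u - x) + 2 * gamma * u to zero gives u = x / (1 + gamma).
    """
    return x / (1.0 + gamma)


x = np.array([3.0, -1.0, 0.5])
gamma = 6.0
u_l1 = l1_penalized_solution(x, gamma)          # every coordinate is exactly zero
u_l2 = squared_l2_penalized_solution(x, gamma)  # shrunk, but all nonzero
```

With the finite value `gamma = 6.0`, the $\ell_1$ solution is pruned entirely while the squared $\ell_2$ solution is only scaled down, mirroring the distinction between exact collapse and gradual over-regularization.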

5.2 Practical Implications

In aggregate then, if the AE base model displays unavoidably high reconstruction errors, this implicitly constrains the corresponding VAE model to have a large optimal $\gamma$ value, which can potentially lead to undesirable posterior collapse per Proposition 5.1. In Section 6 we will demonstrate empirically that training unregularized AE models can become increasingly difficult and prone to bad local minima (or at least bad stable stationary points) as the depth increases; and this difficulty can persist even with counter-measures such as skip connections. Therefore, from this vantage point we would argue that it is the AE base architecture that is effectively the guilty party when it comes to category (v) posterior collapse.

The perspective described above also helps to explain why heuristics like KL warm-start are not always useful for improving VAE performance. With the standard Gaussian model (4) considered herein, KL warm-start amounts to adopting a pre-defined schedule for incrementally increasing $\gamma$ starting from a small initial value, the motivation being that a small $\gamma$ will steer optimization trajectories away from overregularized solutions and posterior collapse.

However, regardless of how arbitrarily small $\gamma$ may be fixed at any point during this process, the VAE reconstructions are not likely to be better than the analogous deterministic AE (which is roughly equivalent to forcing $\gamma \to 0$ within the present context). This implies that there can exist an implicit $\gamma^*$ as computed by (10) that can be significantly larger such that, even if KL warm-start is used, the optimization trajectory may well lead to a collapsed posterior stationary point that has this $\gamma^*$ as the optimal value in terms of minimizing the VAE cost with other parameters fixed. Note that if full posterior collapse does occur, the gradient from the KL term will equal zero and hence, to be at a stationary point it must be that the data term gradient is also zero. In such situations, varying $\gamma$ manually will not impact the gradient balance anyway.
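A KL warm-start schedule of the kind discussed here can be sketched as a simple function of the training step. The linear ramp, the function name `warm_start_gamma`, and its arguments are assumptions chosen for illustration; the text does not specify a particular schedule shape.

```python
def warm_start_gamma(step, warmup_steps, gamma_init, gamma_final):
    """Hypothetical linear KL warm-start schedule for the decoder variance.

    Ramps from gamma_init (weak effective KL weight) to gamma_final over
    warmup_steps, then holds gamma_final. A linear ramp is only one of many
    possible choices.
    """
    if warmup_steps <= 0:
        return gamma_final
    frac = min(max(step / warmup_steps, 0.0), 1.0)
    return gamma_init + frac * (gamma_final - gamma_init)


# Example: ramp from 0.01 to 1.0 over the first 100 steps.
schedule = [warm_start_gamma(s, 100, 0.01, 1.0) for s in (0, 50, 100, 500)]
```

Such a schedule only controls the value imposed during training. It does not change the implicit optimal variance determined by the achievable reconstruction error, which is why the argument above suggests warm-start may not prevent collapse.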

6 Empirical Assessments

In this section we empirically demonstrate the existence of bad AE local minima with high reconstruction errors at increasing depth, as well as the association between these bad minima and imminent VAE posterior collapse. For this purpose, we first train fully connected AE and VAE models at five different depths (numbers of hidden layers) on the Fashion-MNIST dataset (Xiao et al., 2017). Each hidden layer has a fixed width and is followed by ReLU activations (see Appendix A for further details). The reconstruction error is shown in Figure 1 (left). As the depth of the network increases, the reconstruction error of the AE model first decreases because of the increased capacity. However, when the network becomes too deep, the error starts to increase, indicating convergence to a bad local minimum (or at least a stable stationary point/plateau) that is unrelated to KL-divergence regularization. The reconstruction error of a VAE model is always worse than that of the corresponding AE model as expected. Moreover, while KL warm-start/annealing can help to improve the VAE reconstructions to some extent, performance is still worse than the AE as expected.

We next train AE and VAE models using a more complex convolutional network on Cifar100 data (Krizhevsky & Hinton, 2009). At each spatial scale, we use a varying number of convolution layers followed by ReLU activations. We also apply max pooling to downsample the feature maps to a smaller spatial scale in the encoder and use a transposed convolution layer to upscale the feature map in the decoder. The reconstruction errors are shown in Figure 1 (middle). Again, the trend is similar to the fully-connected network results. See Appendix A for an additional ImageNet example.

It has been argued in the past that skip connections can increase the mutual information between observations and the inferred latent variables (Dieng et al., 2018), reducing the risk of posterior collapse. And it is well-known that ResNet architectures based on skip connections can improve performance on numerous recognition tasks (He et al., 2016). To this end, we train a number of AE models using ResNet-inspired encoder/decoder architectures on multiple datasets including Cifar10, Cifar100, SVHN and CelebA. Similar to the convolution network structure from above, we use three different numbers of residual blocks within each spatial scale. Inside each block, we apply a varying number of convolution layers. For aggregate comparison purposes, we normalize the reconstruction error obtained on each dataset by dividing it by the corresponding error produced by the shallowest network structure (the fewest residual blocks and convolution layers). We then average the normalized reconstruction errors over all four datasets. The average normalized errors are shown in Figure 1 (right), where we observe that adding more convolution layers inside each residual block can increase the reconstruction error when the network is too deep. Moreover, adding more residual blocks can also lead to higher reconstruction errors. And empirical results obtained using different datasets and network architectures, beyond the conditions of Figure 1, also show a general trend of increased reconstruction error once the effective depth is sufficiently large.

Figure 1: Reconstruction errors for various encoder/decoder models of varying complexity. Left: Fully connected networks with different depths trained on Fashion-MNIST. Middle: Convolution networks with increasing depth/# of spatial scales trained on Cifar100. Right: Averaged AE results from residual networks with varying number of residual blocks and block depth trained on SVHN, Cifar10, Cifar100 and CelebA. In all plots, once the encoder/decoder complexity is sufficiently high, the reconstruction errors begin to increase.

We emphasize that in all these models, as the network complexity/depth increases, the simpler models are always contained within the capacity of the larger ones. Therefore, because the reconstruction error on the training data is becoming worse, it must be the case that the AE is becoming stuck at bad local minima or plateaus. Again since the AE reconstruction error serves as a probable lower bound for that of the VAE model, a deeper VAE model will likely suffer the same problem, only exacerbated by the KL-divergence term in the form of posterior collapse. This implies that there will be more $\sigma_z$ values moving closer to 1 as the VAE model becomes deeper; similarly $\mu_z$ values will push towards 0. The corresponding dimensions will encode no information and become completely useless.

To help corroborate this association between bad AE local minima and VAE posterior collapse, we plot histograms of VAE values as network depth is varied in Figure 2. The models are trained on CelebA and the number of convolution layers in each spatial scale is , and from left to right. As the depth increases, the reconstruction error becomes larger and there are more near .

Figure 2: Histogram of values as VAE encoder/decoder network depth is varied. There are , and convolution layers in each spatial scale from left to right. As depth increases, the reconstruction error grows and more values are near , indicative of impending posterior collapse.
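One common way to flag collapsed latent dimensions in a trained Gaussian VAE is to check which dimensions have a posterior that nearly matches the standard normal prior. The sketch below assumes a diagonal Gaussian encoder and uses the per-dimension KL divergence as the diagnostic; this is a generic check and not necessarily the quantity plotted in Figure 2. The function names, the tolerance `tol`, and the toy inputs are hypothetical.

```python
import math


def kl_per_dimension(mu, sigma2):
    """KL( N(mu_j, sigma2_j) || N(0, 1) ) for each latent dimension j."""
    return [0.5 * (m * m + s - math.log(s) - 1.0) for m, s in zip(mu, sigma2)]


def collapsed_dimensions(mus, sigma2s, tol=1e-2):
    """Indices of dimensions whose KL, averaged over samples, is below tol.

    mus, sigma2s: one list of per-dimension encoder means / variances
    for each data point (assumed diagonal Gaussian posterior).
    """
    n, dims = len(mus), len(mus[0])
    avg = [0.0] * dims
    for mu, s2 in zip(mus, sigma2s):
        for j, kl in enumerate(kl_per_dimension(mu, s2)):
            avg[j] += kl / n
    return [j for j, v in enumerate(avg) if v < tol]


# Toy encoder outputs for two data points and two latent dimensions.
mus = [[1.5, 0.0], [-1.2, 0.01]]
sigma2s = [[0.1, 1.0], [0.2, 0.99]]
collapsed = collapsed_dimensions(mus, sigma2s)
print(collapsed)
```

In this toy example the second dimension has mean near 0 and variance near 1 for every data point, so it carries essentially no information about the input and is flagged.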

7 Discussion

In this work we have emphasized the previously underappreciated role of bad local minima in trapping VAE models at posterior collapsed solutions. Unlike affine decoder models whereby all local minima are provably global, Proposition 4 stipulates that even infinitesimal nonlinear perturbations can introduce suboptimal local minima characterized by deleterious posterior collapse. Furthermore, we have demonstrated that the risk of converging to such a suboptimal minimum increases with decoder depth. In particular, we outline the following practically likely pathway to posterior collapse:

  1. Deeper AE architectures are essential for modeling high-fidelity images or similar, and yet counter-intuitively, increasing AE depth can actually produce larger reconstruction errors on the training data because of bad local minima (with or without skip connections). An analogous VAE model with the same architecture will likely produce even worse reconstructions because of the additional KL regularization term, which is not designed to steer optimization trajectories away from poor reconstructions.

  2. At any such bad local minimum, the value of will necessarily be large, i.e., if it is not large, we cannot be at a local minimum.

  3. But because of the thresholding behavior of the VAE as quantified by Proposition 5.1, as becomes larger there is an increased risk of exact posterior collapse along excessive latent dimensions. And complete collapse along all dimensions will occur for some finite sufficiently large. Furthermore, explicitly forcing to be small does not fix this problem, since in some sense the implicit is still large as discussed in Section 5.2.

While we believe that this message is interesting in and of itself, there are nonetheless several practically-relevant implications. For example, complex hierarchical VAEs like BIVA notwithstanding, skip connections and KL warm-start have modest ability to steer optimization trajectories towards good solutions; however, this underappreciated limitation will not generally manifest until networks are sufficiently deep as we have considered. Fortunately, any advances or insights gleaned from developing deeper unregularized AEs, e.g., better AE architectures, training procedures, or initializations (Li & Nguyen, 2019), could likely be adapted to reduce the risk of posterior collapse in corresponding VAE models.

In closing, we should also mention that, although this work has focused on Gaussian VAE models, many of the insights translate into broader non-Gaussian regimes. For example, a variety of recent VAE enhancements involve replacing the fixed Gaussian latent-space prior with a parameterized non-Gaussian alternative (Bauer & Mnih, 2019; Tomczak & Welling, 2018). This type of modification provides greater flexibility in modeling the aggregated posterior in the latent space, which is useful for generating better samples (Makhzani et al., 2016). However, it does not immunize VAEs against the bad local minima introduced by deep decoders, and good reconstructions are required by models using Gaussian or non-Gaussian priors alike. Therefore, our analysis herein still applies in much the same way.

Appendix A Network Structure, Experimental Settings, and Additional ImageNet Results

Three different kinds of network structures are used in the experiments: fully connected networks, convolution networks, and residual networks. For all these structures, we set the dimension of the latent variable to . We describe the details of each structure below.

Fully Connected Network: This experiment is only applied on the simple Fashion-MNIST dataset, which contains black-and-white images. These images are first flattened to a dimensional vector. Both the encoder and decoder have multiple -dimensional hidden layers, each followed by ReLU activations.

Convolution Network: The original images are either (Cifar10, Cifar100 and SVHN) or (CelebA and ImageNet). In the encoder, we use a number (denoted as ) of convolution layers for each spatial scale. Each convolution layer is followed by a ReLU activation. Then we use a max pooling to downsample the feature map to a smaller spatial scale. The number of channels is doubled when the spatial scale is halved. We use channels when the spatial scale is . When the spatial scale reaches (there should be channels in this feature map), we use an average pooling to transform the feature map to a vector, which is then transformed into the latent variable using a fully connected layer. In the decoder, the latent variable is first transformed to a -dimensional vector using a fully connected layer and then reshaped to . Again in each spatial scale, we use a transpose convolution layer to upscale the feature map and halve the number of channels, followed by convolution layers. Each convolution and transpose convolution layer is followed by a ReLU activation layer. When the spatial scale reaches that of the original image, we use a convolution layer to transform the feature map to channels.
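The rule that channels double whenever the spatial scale is halved can be sketched as a simple schedule. This is an illustration under assumed values only: the function name `encoder_schedule` and the arguments 32, 16 and 4 are hypothetical placeholders, not the settings used in the experiments.

```python
def encoder_schedule(image_size, base_channels, final_scale):
    """List (spatial_scale, channels) pairs for the encoder.

    Starts at the input resolution with base_channels feature maps,
    then halves the scale and doubles the channels after each pooling
    step until final_scale is reached.
    """
    schedule = []
    scale, channels = image_size, base_channels
    while scale >= final_scale:
        schedule.append((scale, channels))
        scale //= 2      # max pooling halves the spatial scale
        channels *= 2    # channel count doubles at the next scale
    return schedule


# Hypothetical values, chosen only to show the pattern.
schedule = encoder_schedule(image_size=32, base_channels=16, final_scale=4)
print(schedule)  # [(32, 16), (16, 32), (8, 64), (4, 128)]
```

The decoder described above mirrors this schedule in reverse, with each transpose convolution doubling the spatial scale and halving the channel count.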

Residual Network: The network structure of the residual network is similar to that of the convolution network described above. We simply replace each convolution layer with a residual block. Inside the residual block, we use different numbers of convolution layers. (The typical number of convolution layers inside a residual block is or . In our experiments, we try , , and .)

Training Details: All experiments, across the different network structures and datasets, use the same training procedure. We use the Adam optimization method with the default optimizer hyperparameters in TensorFlow. The batch size is and we train the model for iterations. The initial learning rate is and it is halved every iterations.
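A step schedule that halves the learning rate at fixed intervals can be written as below. This is a minimal sketch: the function name `learning_rate` and the example values (initial rate 1e-3, halving every 1000 iterations) are hypothetical and do not reflect the actual experimental settings.

```python
import math


def learning_rate(step, initial_lr, halve_every):
    """Learning rate at a given iteration under a halve-every-N step schedule."""
    return initial_lr * 0.5 ** (step // halve_every)


# Hypothetical settings, for illustration only.
lr_start = learning_rate(0, initial_lr=1e-3, halve_every=1000)
lr_later = learning_rate(2500, initial_lr=1e-3, halve_every=1000)
print(lr_start, lr_later)
assert math.isclose(lr_later, lr_start / 4)  # two halvings by step 2500
```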

Additional Results on ImageNet: We also show the reconstruction error for convolution networks with increasing depth trained on ImageNet in Figure 3. The trend is the same as that in Figure 1.

Figure 3: Reconstruction error for Convolution networks with increasing depth/# of spatial scales trained on ImageNet.

Appendix B Proof of Proposition 4

While the following analysis could in principle be extended to more complex datasets, for our purposes it is sufficient to consider the following simplified case for ease of exposition. Specifically, we assume that , set , and .

Additionally, we will use the following basic facts about the Gaussian tail. Note that (12)-(13) below follow from integration by parts; see Orjebin (2014). Let ;

be the pdf and cdf of the standard normal distribution, respectively. Then

(11)
(12)
(13)
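Standard Gaussian tail facts of this kind can be checked numerically. The sketch below verifies three well-known identities, which may differ in form from (11)-(13): the Mills-ratio bound 1 - cdf(a) <= pdf(a)/a for a > 0, and the two integration-by-parts results for the first and second tail moments. The helper names, the evaluation point a = 1.3, and the integration settings are arbitrary choices made for illustration.

```python
import math


def pdf(x):
    """Standard normal density."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def cdf(x):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def tail_integral(f, a, upper=12.0, n=200_000):
    """Midpoint-rule approximation of the integral of f over [a, upper]."""
    h = (upper - a) / n
    return h * sum(f(a + (i + 0.5) * h) for i in range(n))


a = 1.3
first_moment = tail_integral(lambda x: x * pdf(x), a)
second_moment = tail_integral(lambda x: x * x * pdf(x), a)

# Integral of x * pdf(x) over [a, inf) equals pdf(a).
assert abs(first_moment - pdf(a)) < 1e-6
# Integral of x^2 * pdf(x) over [a, inf) equals a * pdf(a) + 1 - cdf(a).
assert abs(second_moment - (a * pdf(a) + 1.0 - cdf(a))) < 1e-6
# Mills-ratio upper bound on the tail probability.
assert 1.0 - cdf(a) <= pdf(a) / a
```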

B.1 Suboptimality of (7)

Under the specified conditions, the energy from (7) has a value of . Thus to show that it is not the global minimum, it suffices to show that the following VAE, parameterized by , has energy as :

This follows because, given the stated parameters, we have that

(i) holds when ; to see this, denote . Then

In the RHS above ; using (11)-(13) we then have

when . Thus

and

or

and we can see (i) holds.

B.2 Local Optimality of (7)

We will now show that at (7), the Hessian of the energy has structure

where p.d. means the corresponding submatrix is positive definite and independent of other parameters. While the Hessian is in the subspace of , we can show that for VAEs that differ from (7) only by , the gradient always points back to (7). Thus (7) is a strict local minimum.

First we compute the Hessian matrix block-wise. We will identify with the vector , and use the shorthand notations , , , where (recall that is a scalar in this proof).

  1. The second-order derivatives involving can be expressed as

    (14)

    and therefore all second-order derivatives involving will have the form

    (15)

    where are some arbitrary functions that are finite at (7). Since , the above always evaluates to at .

  2. For second-order derivatives involving , we have

    and

    and and will also have the form of (15), thus both equal at .

  3. Next consider second-order derivatives involving or . Since the KL part of the energy, , only depends on and , and has a p.d. Hessian at (7) independent of other parameters, it suffices to calculate the derivatives of the reconstruction error part, denoted as . Since

    all second-order derivatives will have the form of (15), and equal at .

  4. For , we can calculate that at (7).

Now, consider VAE parameters that differ from (7) only in . Plugging into (14), we have

As always holds, we can see that the gradient points back to (7). This concludes our proof that (7) is a strict local minimum.

Appendix C Proof of Proposition 5.1

We begin by assuming an arbitrarily complex encoder for convenience. This allows us to remove the encoder-sponsored amortized inference and instead optimize independent parameters and separately for each data point. Later we will show that this capacity assumption can be dropped and the main result still holds.

We next define

(16)

which are nothing more than the concatenation of all of the decoder means and variances from each data point into the respective column vectors. It is also useful to decompose the assumed non-degenerate decoder parameters via

(17)

where is a scalar such that . Note that we can always reparameterize an existing deep architecture to extract such a latent scaling factor which we can then hypothetically optimize separately while holding the remaining parameters fixed. Finally, with slight abuse of notation, we may then define the function

This is basically just the original function summed over all training points, with fixed at the corresponding values extracted from while serves as a free scaling parameter on the decoder.

Based on the assumption of Lipschitz continuous gradients, we can always create the upper bound

where is the Lipschitz constant of the gradients and we have adopted and to simplify notation. Equality occurs at the evaluation point . However, this bound does not account for the fact that we know (i.e., is increasing w.r.t. ) and that . Given these assumptions, we can produce the refined upper bound

(20)

where