A variational autoencoder (VAE) [kingma2013auto, rezende2014stochastic] is a type of generative model that has been widely used in a variety of applications, such as image generation [yan2016attribute2image, liu2017unsupervised], dialog generation [zhang2020dive, wang2019topic, hu2017toward], and disentangled representation learning [higgins2017beta, kim2018disentangling]. A VAE is composed of an encoder that maps input data to a distribution over latent variables, and a decoder that maps a sampled latent code back to the data space. The encoder and decoder are trained jointly by minimizing the reconstruction loss between the input data and the output of the decoder, plus the KL-divergence between the latent distribution and a pre-defined prior, such as the unit Gaussian.
The two terms in the objective have contrasting effects. The reconstruction term improves reconstruction quality while neglecting the structure of the latent space. The KL-divergence term regularizes the latent space, possibly at the cost of some overlap between latent variables, which results in a noisier encoding. In practice, depending on the task, we often want to find a desirable position in this trade-off. For example, in text or image generation, the goal is to generate diverse and new text or images rather than reproduce one of the training samples. If the KL-divergence is too low, the output samples have very limited diversity (known as the KL-vanishing or posterior collapse problem [bowman2015generating]). To increase output diversity, it becomes advantageous to artificially increase the KL-divergence. Conversely, in disentangled representation learning [denton2017unsupervised], we want to ensure independence among latent variables. In this context, artificially decreasing the KL-divergence is desirable (e.g., by increasing its weight in a VAE's objective function, which is known as the β-VAE). This imposes a stricter information bottleneck that forces the learned latent factors to be more independent (i.e., non-redundant), leading to better disentangling.
The above examples suggest that a useful extension of VAEs is one that allows users to explicitly control the KL-divergence term in the objective. Accordingly, considerable theoretical analysis has been devoted to justifying a greater emphasis on one of the two terms in the VAE's objective function, rather than treating them equally [rezende2018taming, klushyn2019learning, alemi2017fixing, zhao2019infovae, burgess2018understanding]. In this paper, we develop a systematic method to control the KL-divergence. Previous solutions mainly assign a fixed or learnable weight to the KL term [higgins2017beta, burgess2018understanding, alemi2017fixing, dai2019diagnosing, asperti2020balancing] to manipulate the value of the KL-divergence. However, they cannot accurately control the value of the KL-divergence or achieve a good trade-off with the reconstruction error. To address this issue, we propose a novel controllable variational autoencoder, ControlVAE, that leverages automatic control to tune the trade-off between reconstruction accuracy and KL-divergence, as shown in Fig. 1. Specifically, we design a nonlinear PI controller that stabilizes the value of the KL-divergence by dynamically tuning the weight of the KL term during training. We adopt PI control rather than machine learning methods such as Bayesian optimization because our method is an online, dynamic tuning approach that needs only a single training run. Bayesian optimization, by contrast, is an offline profiling method that requires training the model multiple times with different weights.
This paper is an extension of work originally presented at ICML 2020 [shao2020controlvae]. It differs from the prior paper in several respects. First, we offer an analytic proof of how to choose the KL-divergence set point so that ControlVAE improves the ELBO over the regular VAE. We then verify these analytical results empirically through a new set of experiments on image generation. In addition, we present the connection between ControlVAE and existing VAE models, as well as the constrained optimization of the KL term using a Lagrange multiplier. To further improve disentanglement, we explore a new variant of ControlVAE, called Control-FactorVAE, for learning disentangled representations. The metric used in the prior paper, mutual information gap (MIG), only measures an overall disentanglement score. We therefore also adopt robust MIG (RMIG) [do2020theory] to measure the disentanglement score of each generative factor. Finally, we conduct ablation studies to explore the effect of hyperparameters on the performance of ControlVAE.
We apply our proposed methods to three different tasks: language modeling, disentangling, and image generation. Evaluation results on benchmark datasets demonstrate that ControlVAE is able to achieve an adjustable trade-off between reconstruction error and KL-divergence. It can significantly reduce the reconstruction error while achieving comparable disentanglement. We also show that ControlVAE is able to improve the ELBO and reconstruction quality on the task of image generation. For language modeling, it completely avoids the posterior collapse (KL vanishing) and improves the diversity of generated data.
The objective function of VAEs consists of two terms: log-likelihood and KL-divergence. The first term tries to reconstruct the input data, while KL-divergence has the desirable effect of keeping the representation of input data sufficiently diverse. In particular, KL-divergence can affect both the reconstruction quality and diversity of generated data. If the KL-divergence is too high, it would affect the accuracy of generated samples. If it is too low, output diversity is reduced, which may be a problem in some applications such as language modeling [bowman2015generating] (where it is known as the KL-vanishing problem).
To mitigate KL vanishing, one promising approach is to add an extra hyperparameter $\beta$ ($0 \le \beta \le 1$) to the VAE objective function and control the KL-divergence by increasing $\beta$ from 0 to 1 with a sigmoid function or a cyclic function [liu2019cyclical]. These methods, however, change $\beta$ blindly, without sampling the actual KL-divergence during model training. Using a similar methodology, researchers recently developed β-VAE ($\beta > 1$) [higgins2017beta, burgess2018understanding] to learn disentangled representations by controlling the value of the KL-divergence. However, β-VAE suffers from high reconstruction errors [kim2018disentangling]. It places a very large $\beta$ in the VAE objective, so the model tends to focus disproportionately on optimizing the KL term. In addition, its hyperparameter is fixed during model training, which misses the chance to balance the reconstruction error and the KL-divergence.
The core technical challenge behind the above application problems is the difficulty of tuning the weight of the KL-divergence term during model training. Inspired by control systems, we address this problem using feedback control. Our controllable variational autoencoder is illustrated in Fig. 1. It samples the output KL-divergence at each training step $t$ and feeds it into an algorithm that tunes the hyperparameter, $\beta(t)$, accordingly. The aim is to stabilize the KL-divergence at a desired value, called the set point. We further design a nonlinear PI controller, a variant of the PID control algorithm [aastrom2006advanced], to tune $\beta(t)$. The advantage of PID control over Bayesian optimization is that it is an online dynamic tuning method that needs only a single training run, and thus has very low computational cost. Next, we introduce the background of the PID algorithm in detail.
PID control is the basic and most prevalent form of feedback control in a large variety of industrial [aastrom2006advanced] and software performance control [hellerstein2004feedback] applications. The general model of a PID controller is defined by
$$v(t) = K_p\, e(t) + K_i \int_{0}^{t} e(\tau)\, d\tau + K_d \frac{d e(t)}{dt}, \tag{1}$$
where $v(t)$ is the output of the controller, $e(t)$ is the error between the actual value and the desired value at time $t$, and $K_p$, $K_i$, and $K_d$ denote the coefficients of the P term, I term, and D term, respectively.
The basic idea of the PID algorithm is to calculate an error, $e(t)$, between a set point (in this case, the desired KL-divergence) and the current value of the controlled variable (in this case, the actual KL-divergence). The controller then applies a correction in a direction that reduces that error. The correction is applied to some intermediate, directly accessible variable (in our case, $\beta(t)$) that influences the value of the variable we ultimately want to control (the KL-divergence). In general, the correction computed by the controller is the weighted sum of three terms: one changes with the error (P), one with the integral of the error (I), and one with the derivative of the error (D). In a nonlinear controller, these changes can be described by nonlinear functions. Derivatives essentially compute the slope of a signal, so when the signal is noisy, the slope often responds more to variations induced by noise. Hence, following established best practices in the control of noisy systems, we do not use the derivative (D) term in our controller. Next, we introduce VAEs and our objective in more detail.
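The general PID model above can be made concrete with a minimal textbook discretization. In this form, the integral becomes a running sum and the derivative becomes a backward difference. The class `DiscretePID` and its fields are illustrative names, not taken from the paper's code. Setting `kd = 0` yields the PI form adopted here, which avoids amplifying noisy KL measurements through the derivative.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DiscretePID:
    """Textbook discrete-time PID controller (illustrative sketch)."""

    kp: float
    ki: float
    kd: float = 0.0          # kd = 0 gives a PI controller (no derivative term)
    dt: float = 1.0          # sampling period, e.g. one training step
    integral: float = 0.0    # running approximation of the error integral
    prev_error: Optional[float] = None

    def update(self, setpoint: float, measured: float) -> float:
        error = setpoint - measured                  # e(t): desired minus actual
        self.integral += error * self.dt             # rectangle-rule integral
        derivative = 0.0
        if self.prev_error is not None:              # backward difference for de/dt
            derivative = (error - self.prev_error) / self.dt
        self.prev_error = error
        return self.kp * error + self.ki * self.integral + self.kd * derivative


controller = DiscretePID(kp=2.0, ki=0.5)
print(controller.update(setpoint=10.0, measured=7.0))  # 7.5  (2*3 + 0.5*3)
print(controller.update(setpoint=10.0, measured=9.0))  # 4.0  (2*1 + 0.5*(3+1))
```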
2.1 The Variational Autoencoder (VAE)
The VAE [kingma2013auto, rezende2014stochastic] has been one of the most popular types of generative models. It assumes a latent variable $z$ with a prior $p(z)$ and a conditional distribution $p_\theta(x|z)$ to model the observed variable $x$. The generative model, denoted by $p_\theta(x)$, can be expressed as $p_\theta(x) = \int_z p_\theta(x|z)\, p(z)\, dz$. However, direct computation of this integral is often intractable, so variational inference is used to derive a lower bound on $\log p_\theta(x)$. This leads to the evidence lower bound (ELBO):
$$\log p_\theta(x) \ge \mathcal{L}_{vae} = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - D_{KL}\left(q_\phi(z|x) \,\|\, p(z)\right), \tag{2}$$
where $p_\theta(x|z)$ is a probabilistic decoder parameterized by a neural network to generate data $x$ given the latent variable $z$. The posterior distribution of the latent variable $z$ given data $x$ is approximated by the variational posterior, $q_\phi(z|x)$, which is parameterized by an encoder network. The VAE is trained by maximizing $\mathcal{L}_{vae}$, which consists of a reconstruction term and a KL term, over the training data.
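Eq. (2) can be made concrete with a short sketch. It assumes a diagonal Gaussian encoder $q_\phi(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ and the unit Gaussian prior, for which the KL term has a closed form. The reconstruction term is estimated with one reparameterized sample. The `log_likelihood` callable stands in for a decoder and is a hypothetical placeholder. Pure Python is used only to keep the example dependency-free.

```python
import math
import random
from typing import Callable, Sequence


def kl_to_unit_gaussian(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) )."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def elbo_one_sample(
    x,
    mu: Sequence[float],
    logvar: Sequence[float],
    log_likelihood: Callable[[object, list], float],
    rng: random.Random,
) -> float:
    """One-sample Monte Carlo estimate of Eq. (2): E_q[log p(x|z)] - KL(q || p)."""
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]
    reconstruction = log_likelihood(x, z)  # hypothetical decoder log-likelihood
    return reconstruction - kl_to_unit_gaussian(mu, logvar)


# Toy usage with a dummy "decoder" that always returns log p(x|z) = -3.0.
rng = random.Random(0)
print(kl_to_unit_gaussian([1.0], [0.0]))                             # 0.5
print(elbo_one_sample(None, [1.0], [0.0], lambda x, z: -3.0, rng))  # -3.5
```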
However, plain VAEs cannot balance the reconstruction error and the KL-divergence, which makes them unsuitable for some specific tasks. For example, VAEs often suffer from KL vanishing in language modeling [bowman2015generating, liu2019cyclical], meaning that the KL-divergence becomes nearly zero during optimization.
The β-VAE [higgins2017beta, chen2018isolating] is an extension of the basic VAE framework, often used as an unsupervised method for learning a disentangled representation of the data's generative factors. According to the literature [bengio2013representation], a disentangled representation is one where single latent units are sensitive to changes in single generative factors, while being relatively invariant to changes in other factors. Compared to the original VAE, β-VAE adds an extra hyperparameter $\beta$ as a weight on the KL-divergence in the original VAE objective (2). It can be expressed by
$$\mathcal{L}_{\beta} = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \beta\, D_{KL}\left(q_\phi(z|x) \,\|\, p(z)\right). \tag{3}$$
To discover more disentangled factors, researchers further put a constraint on the total information capacity, $C$, to control the capacity of the information bottleneck (KL-divergence) [burgess2018understanding]. A Lagrangian method is then adopted to solve the following optimization problem:
$$\mathcal{L} = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \gamma \left| D_{KL}\left(q_\phi(z|x) \,\|\, p(z)\right) - C \right|, \tag{4}$$
where $\gamma$ is a large hyperparameter (e.g., 100).
However, one drawback of β-VAE is that it obtains good disentangling at the cost of reconstruction quality. When the weight $\gamma$ is large, the optimization algorithm tends to focus on the second term in (4), leading to a high reconstruction error.
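The difference between the β-VAE objective and the capacity-constrained objective (4) can be seen with plain arithmetic on per-batch quantities. In this sketch, `recon` stands for an estimate of $\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]$ and `kl` for the KL term. The numeric values are made up for illustration, and both functions return objectives to be maximized.

```python
def beta_vae_objective(recon: float, kl: float, beta: float) -> float:
    """beta-VAE objective: reconstruction term minus beta times the KL term."""
    return recon - beta * kl


def capacity_objective(recon: float, kl: float, capacity: float, gamma: float = 100.0) -> float:
    """Capacity objective (4): penalize any deviation of the KL term from C."""
    return recon - gamma * abs(kl - capacity)


# Hypothetical batch statistics.
recon, kl = -100.0, 12.5
print(beta_vae_objective(recon, kl, beta=1.0))       # -112.5 (plain VAE)
print(beta_vae_objective(recon, kl, beta=4.0))       # -150.0
print(capacity_objective(recon, kl, capacity=12.0))  # -150.0, since 100 * |12.5 - 12| = 50
```

With a large $\gamma$, even a small deviation from $C$ dominates the objective. This is consistent with the reconstruction penalty discussed above.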
FactorVAE augments the VAE objective with an extra term that directly encourages independence in the latent code distribution. The objective of FactorVAE is given by
$$\mathcal{L}_{F} = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - D_{KL}\left(q_\phi(z|x) \,\|\, p(z)\right) - \gamma\, D_{KL}\left(q(z) \,\|\, \bar{q}(z)\right), \tag{5}$$
where $q(z)$ is the distribution of the representation for the entire input data, and $\bar{q}(z) = \prod_{j} q(z_j)$ is the product of the marginal distributions of the latent variables $z_j$. The third term is called Total Correlation (TC) [watanabe1960information], which measures the dependence among different random variables. Note that $D_{KL}(q(z) \,\|\, \bar{q}(z)) = 0$ if and only if the $z_j$ are mutually independent under $q(z)$. In addition, $\gamma$ is a hyperparameter that represents the weight on the total correlation penalty. FactorVAE directly penalizes the total correlation term without constraining the mutual information between $x$ and $z$, so it has better reconstruction quality than β-VAE. However, because the weight $\gamma$ is fixed, we still cannot explicitly control the value of the TC term to ensure it is small enough after training.
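FactorVAE estimates the TC term with an auxiliary discriminator (the density-ratio trick). A simpler, self-contained illustration of what TC measures uses a different assumption: $q(z)$ is a zero-mean Gaussian with covariance `cov`. In that case $D_{KL}(q(z)\,\|\,\bar{q}(z)) = \tfrac{1}{2}\left(\sum_j \log \Sigma_{jj} - \log\det\Sigma\right)$. This is zero for a diagonal covariance and grows with the correlation between latent dimensions. The Gaussian closed form is an assumption made for illustration only. It is not FactorVAE's estimator.

```python
import math


def gaussian_total_correlation(cov):
    """TC of a zero-mean Gaussian q(z) = N(0, cov): KL(q(z) || prod_j q(z_j)).

    Uses TC = 0.5 * (sum_j log cov[j][j] - log det cov).
    """
    n = len(cov)
    # Cholesky factorization cov = L L^T, so log det cov = 2 * sum_i log L[i][i].
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = cov[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                if s <= 0.0:
                    raise ValueError("covariance must be positive definite")
                L[i][i] = math.sqrt(s)
            else:
                L[i][j] = s / L[j][j]
    log_det = 2.0 * sum(math.log(L[i][i]) for i in range(n))
    return 0.5 * (sum(math.log(cov[j][j]) for j in range(n)) - log_det)


print(gaussian_total_correlation([[1.0, 0.0], [0.0, 1.0]]))  # 0.0: independent latents
print(gaussian_total_correlation([[1.0, 0.5], [0.5, 1.0]]))  # about 0.144: correlated latents
```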
The above background suggests that a common challenge in applying VAEs (and their extensions) lies in allocating weight appropriately between reconstruction accuracy and KL-divergence in the VAE objective function. As mentioned earlier, we solve this using a nonlinear PI controller that manipulates the value of the non-negative hyperparameter, $\beta(t)$. This algorithm is described next.
3 The ControlVAE Algorithm
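During model training, we sample the output KL-divergence, which we denote by $\hat{v}_{kl}(t)$, at training step $t$. The sampled KL-divergence is then compared to the set point, $v_{kl}$. The difference, $e(t) = v_{kl} - \hat{v}_{kl}(t)$, is used as the feedback to a controller to calculate the hyperparameter $\beta(t)$. ControlVAE can be expressed by the following variational lower bound:
$$\mathcal{L}_{c} = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \beta(t)\, D_{KL}\left(q_\phi(z|x) \,\|\, p(z)\right). \tag{6}$$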
When the KL-divergence drops below the set point, the controller counteracts this change by reducing the hyperparameter $\beta(t)$, which reduces the penalty for KL-divergence in the objective function (6). The reduced weight, $\beta(t)$, allows the KL-divergence to grow, so it approaches the set point again. Conversely, when the KL-divergence grows above the set point, the controller increases $\beta(t)$ (up to a certain value). This increases the penalty for KL-divergence and forces it to decrease. The effect is achieved by computing $\beta(t)$ using Equation (7), below, which is an instance of nonlinear PI control:
$$\beta(t) = \frac{K_p}{1 + \exp\left(e(t)\right)} - K_i \sum_{j=0}^{t} e(j) + \beta_{min}, \tag{7}$$
where $K_p$ and $K_i$ are constants. The first term on the right-hand side ranges between $0$ and $K_p$ thanks to the exponential function $\exp(\cdot)$. When the error is large and positive (the KL-divergence is below the set point), the first term approaches 0, leading to a lower $\beta(t)$ that encourages the KL-divergence to grow. Conversely, when the error is large and negative (the KL-divergence is above the set point), the first term approaches its maximum, $K_p$, leading to a higher $\beta(t)$ that encourages the KL-divergence to shrink.
The second term of the controller sums (integrates) past errors with a sampling period of one training step. This creates a progressively stronger correction until the sign of the error changes. The negative sign ensures that while errors remain positive (i.e., the KL-divergence is below the set point), this term continues to decrease. While errors remain negative (i.e., the KL-divergence is above the set point), this term continues to increase. In both cases, the change pushes $\beta(t)$ in a direction that helps the KL-divergence approach the set point. In particular, when the error becomes zero, the second term (and thus the entire right-hand side) stops changing. The controller output, $\beta(t)$, then stays at the same value that hopefully caused the zero error in the first place. This allows the controller to “lock in” the value of $\beta(t)$ that meets the KL-divergence set point. Finally, $\beta_{min}$ is an application-specific constant that shifts the range within which $\beta(t)$ is allowed to vary. This PI controller is illustrated in Fig. 2.
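Eq. (7) can be transcribed directly into Python. The sketch below is illustrative only. The gains are arbitrary example values, not tuned ones. The running error sum stands for the discrete integral, with one sample per training step. The P term is rewritten in an algebraically equivalent form that cannot overflow for large errors. No clipping or anti-windup is applied, so on its own this controller can return values below `beta_min`. The class name is hypothetical.

```python
import math


class NonlinearPI:
    """Sketch of Eq. (7): beta(t) = Kp / (1 + exp(e(t))) - Ki * sum_j e(j) + beta_min."""

    def __init__(self, kp: float, ki: float, beta_min: float = 0.0):
        self.kp, self.ki, self.beta_min = kp, ki, beta_min
        self.error_sum = 0.0  # running sum of e(j), one sample per training step

    def p_term(self, error: float) -> float:
        # Kp / (1 + exp(e)) in an overflow-safe form; lies between 0 and Kp.
        if error >= 0:
            z = math.exp(-error)
            return self.kp * z / (1.0 + z)
        return self.kp / (1.0 + math.exp(error))

    def step(self, setpoint_kl: float, sampled_kl: float) -> float:
        error = setpoint_kl - sampled_kl  # e(t) > 0 means the KL is below the set point
        self.error_sum += error
        return self.p_term(error) - self.ki * self.error_sum + self.beta_min


# Illustrative gains, not tuned values.
ctrl = NonlinearPI(kp=0.01, ki=0.001)
print(ctrl.step(setpoint_kl=3.0, sampled_kl=3.0))  # 0.005: zero error gives P term = Kp / 2
```

A persistently negative error (KL above the set point) makes the output grow step after step. A persistently positive error makes it shrink, which matches the behavior described above.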
3.1 The Controllable FactorVAE Algorithm
For disentangled representation learning, the total correlation (TC) term of FactorVAE may collapse to 0, leading to poor disentanglement. To deal with this issue, we propose a novel controllable FactorVAE, Control-FactorVAE, that stabilizes the TC at a small value based on the actual value of the TC during model training. Similar to ControlVAE, the objective function of Control-FactorVAE is expressed as
$$\mathcal{L}_{cf} = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - D_{KL}\left(q_\phi(z|x) \,\|\, p(z)\right) - \beta(t)\, D_{KL}\left(q(z) \,\|\, \bar{q}(z)\right), \tag{8}$$
where $\beta(t)$ is the output of the PI controller designed above, using the output TC as feedback.
3.2 PI Parameter Tuning for ControlVAE
One challenge of applying the PI control algorithm is how to tune its parameters, $K_p$ and $K_i$, effectively. While optimal tuning of nonlinear controllers is non-trivial, in this paper we follow a very simple rule: tune these constants to ensure that reactions to errors are smooth enough to allow gradual convergence. Let us first consider the coefficient $K_p$. The maximum (positive) error occurs when the actual KL-divergence is close to zero. In this case, if $v_{kl}$ is the set point on the KL-divergence, then the error, $e(t)$, is approximated by $e(t) \approx v_{kl} - 0 = v_{kl}$. When the KL-divergence is too small, the VAE does not learn useful information from the input data [liu2019cyclical]. We need to assign $\beta(t)$ a very small non-negative value, so that the KL-divergence is encouraged to grow when the resulting objective function is optimized. In other words, temporarily ignoring the other terms in Equation (7), the contribution of the first term alone should be sufficiently small:
$$\frac{K_p}{1 + \exp(v_{kl})} \le \epsilon, \tag{9}$$
where $\epsilon$ is a small constant (e.g., in our implementation). The above (9) can also be rewritten as $K_p \le \left(1 + \exp(v_{kl})\right) \epsilon$. Empirically, we find that leads to good performance and satisfies the above constraint.
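Constraint (9) is easy to evaluate for a candidate set point. The values of `epsilon` and the set points below are hypothetical. They show that the admissible $K_p$ grows exponentially with $v_{kl}$, so small set points impose the tightest limit.

```python
import math


def kp_upper_bound(setpoint_kl: float, epsilon: float) -> float:
    """Largest Kp satisfying Eq. (9): Kp / (1 + exp(v_kl)) <= epsilon."""
    return (1.0 + math.exp(setpoint_kl)) * epsilon


for v_kl in (1.0, 3.0, 10.0):  # hypothetical set points
    print(v_kl, kp_upper_bound(v_kl, epsilon=1e-3))
```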
Conversely, when the actual KL-divergence is much larger than the desired value $v_{kl}$, the error $e(t)$ becomes a large negative value. As a result, the first term in (7) becomes close to a constant, $K_p$. If the resulting larger value of $\beta(t)$ is not enough to make the KL-divergence shrink, $\beta(t)$ must keep increasing gradually. This is the job of the second term. The negative sign in front of that term ensures that when negative errors continue to accumulate, the positive output $\beta(t)$ continues to increase. Training deep VAE models takes many steps, so the increase per step should be very small, which favors smaller values of $K_i$. Empirically, we found that a value of $K_i$ between and stabilizes the training. However, $K_i$ should not be too small either, because that would unnecessarily slow down convergence.
3.3 Set Point Guidelines for ControlVAE
The choice of the desired KL-divergence (set point) is largely application specific. In general, when $\beta(t) \in [\beta_{min}, \beta_{max}]$, the upper bound of the expected KL-divergence, denoted by $V_{max}$, is the KL-divergence ControlVAE converges to when $\beta(t) = \beta_{min}$. Similarly, its lower bound, $V_{min}$, is the KL-divergence produced by ControlVAE when $\beta(t) = \beta_{max}$. For feedback control to be most effective (i.e., not run against these limits), the KL-divergence set point should lie in the range $[V_{min}, V_{max}]$. Since ControlVAE is an end-to-end learning model, users can customize the desired KL-divergence to meet their needs in different applications, using the KL-divergence of the original VAE as a reference. For instance, users who want to improve the diversity of text or image generation can slightly increase the KL-divergence relative to that produced by the original VAE. Users who want to improve generation accuracy can instead reduce the KL-divergence.
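In this paper, we provide a simplified theoretical analysis of how to choose the set point when the goal is to improve the ELBO or the reconstruction accuracy over the basic VAE. Let $\delta$ ($\delta > 0$) denote the optimal range of KL-divergence between ControlVAE and the basic VAE, and let $v_{kl}^{vae}$ denote the KL-divergence of the basic VAE at convergence. Then the KL-divergence of ControlVAE can be written as
$$KL_c = v_{kl}^{vae} + \delta. \tag{10}$$
Therefore, the ELBO of ControlVAE can be written as
$$\mathcal{L}_{c} = \mathbb{E}_{q_\phi(z_c|x)}\left[\log p_\theta(x|z_c)\right] - \left(v_{kl}^{vae} + \delta\right). \tag{11}$$
To achieve a higher ELBO for ControlVAE, we want
$$\mathbb{E}_{q_\phi(z_c|x)}\left[\log p_\theta(x|z_c)\right] - \left(v_{kl}^{vae} + \delta\right) \ge \mathbb{E}_{q_\phi(z_v|x)}\left[\log p_\theta(x|z_v)\right] - v_{kl}^{vae}, \tag{12}$$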
where $z_c$ and $z_v$ denote the latent variables of ControlVAE and the original VAE, respectively. Since the KL term mainly affects the encoder parameters, to simplify analysis, we may assume that .
To obtain the range of $\delta$, our next step is to bound the difference in reconstruction accuracy between ControlVAE and the basic VAE, $\mathbb{E}_{q_\phi(z_c|x)}[\log p_\theta(x|z_c)] - \mathbb{E}_{q_\phi(z_v|x)}[\log p_\theta(x|z_v)]$. Assuming the decoder is Lipschitz continuous [virmaux2018lipschitz], we have the following theorem.
Given the KL-divergence, , of the original VAE as it converges, we have
Proof: See Appendix A.
3.4 Summary of the PI Control Algorithm
We summarize the proposed PI control algorithm in Algorithm 1. Our PI algorithm updates the hyperparameter, $\beta(t)$, using feedback from the sampled KL-divergence at training step $t$. Line computes the error between the desired KL-divergence, $v_{kl}$, and the sampled $\hat{v}_{kl}(t)$. Line to calculate the P term and I term of the PI algorithm, respectively. Lines 10 and 11 implement a popular constraint in PID/PI design called anti-windup [azar2015design, peng1996anti]. It effectively disables the integral term of the controller when the controller output goes out of range, so that the out-of-range deviation is not exacerbated. Line is the hyperparameter $\beta(t)$ calculated by the PI algorithm in (7). Finally, Line to aim to limit $\beta(t)$ to a certain range, $[\beta_{min}, \beta_{max}]$.
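Algorithm 1 itself is not reproduced in this section. The sketch below is one plausible reading of its description: the P and I terms of Eq. (7), an anti-windup rule, and clipping of $\beta(t)$ to $[\beta_{min}, \beta_{max}]$. As an assumption, anti-windup is implemented by integrator clamping, where the integral is pulled back whenever the output saturates. This is one standard variant and may differ in detail from Lines 10 and 11 of Algorithm 1. The class and parameter names are hypothetical.

```python
import math


class ClampedPIController:
    """PI controller for beta(t) with anti-windup and output clipping (sketch)."""

    def __init__(self, kp: float, ki: float, beta_min: float, beta_max: float):
        if beta_min > beta_max:
            raise ValueError("beta_min must not exceed beta_max")
        self.kp, self.ki = kp, ki
        self.beta_min, self.beta_max = beta_min, beta_max
        self.i_term = 0.0  # accumulated -Ki * sum(e)

    def _p_term(self, error: float) -> float:
        # Kp / (1 + exp(e)), computed without overflow for large |e|.
        if error >= 0:
            z = math.exp(-error)
            return self.kp * z / (1.0 + z)
        return self.kp / (1.0 + math.exp(error))

    def step(self, setpoint_kl: float, sampled_kl: float) -> float:
        error = setpoint_kl - sampled_kl          # e(t) = v_kl - sampled KL
        p_term = self._p_term(error)
        self.i_term -= self.ki * error            # I term of Eq. (7)
        raw = p_term + self.i_term + self.beta_min
        # Anti-windup (integrator clamping): if the output saturates, pull the
        # integral back so the output sits exactly at the limit instead of
        # accumulating error that would later have to be unwound.
        if raw > self.beta_max:
            self.i_term -= raw - self.beta_max
        elif raw < self.beta_min:
            self.i_term += self.beta_min - raw
        return min(max(raw, self.beta_min), self.beta_max)
```

In a training loop, one would call `beta = controller.step(v_kl, kl_value)` once per step and use `beta` as the weight on the KL term in Eq. (6). Here `kl_value` is a hypothetical name for the sampled KL-divergence of the current batch.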
3.5 Connection to Other Models
We now illustrate the connection between ControlVAE and some existing VAE models.
3.5.1 Connection to Lagrange Multiplier
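The goal of ControlVAE is to dynamically tune the weight on the KL term to stabilize the KL-divergence at a desired value, $v_{kl}$. Inspired by the literature [rezende2018taming], we can also formulate this as a constrained optimization problem:
$$\max_{\theta, \phi}\; \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] \quad \text{s.t.} \quad D_{KL}\left(q_\phi(z|x) \,\|\, p(z)\right) = v_{kl}.$$
The Lagrangian of this optimization problem is
$$\mathcal{L}(\theta, \phi, \lambda) = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \lambda \left( D_{KL}\left(q_\phi(z|x) \,\|\, p(z)\right) - v_{kl} \right),$$
where $\lambda$ is the Lagrange multiplier.
The above problem can be solved by classic gradient descent-ascent. In particular, the gradient with respect to the Lagrange multiplier at the $t$-th iteration is given by
$$\frac{\partial \mathcal{L}}{\partial \lambda}\bigg|_{t} = v_{kl} - \hat{v}_{kl}(t) = e(t),$$
where $e(t)$ is the error between the desired KL-divergence and the actual one, as defined in Section 3. At training step $t$, we can update $\lambda$ as follows:
$$\lambda(t+1) = \lambda(t) - \eta\, e(t),$$
where $\eta$ is the learning rate.
After training the above model for $T$ iterations, the hyperparameter $\lambda(T)$ can be expressed by
$$\lambda(T) = \lambda(0) - \eta \sum_{t=0}^{T-1} e(t). \tag{20}$$
When $\lambda$ is initialized to 0, the above Eq. (20) becomes
$$\lambda(T) = -\eta \sum_{t=0}^{T-1} e(t). \tag{21}$$
The above formula shows that $\lambda(T)$ has the same form as the I term of the PI algorithm in Eq. (7) when $\eta = K_i$. Thus, the optimization problem using a Lagrange multiplier can be seen as a special case of ControlVAE. We also conduct experiments to compare the performance of ControlVAE with the Lagrange multiplier method in Section 4.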
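The equivalence with the I term can be checked numerically. Assume the multiplier update $\lambda(t+1) = \lambda(t) - \eta\, e(t)$ is run on an arbitrary error sequence with $\lambda(0) = 0$. The result then equals $-K_i \sum e(t)$, the I term of Eq. (7), when $\eta = K_i$. The error values are made up, and the function names are illustrative.

```python
def lagrange_multiplier(errors, eta, lam0=0.0):
    """Descent on lambda: lambda(t+1) = lambda(t) - eta * e(t), since dL/dlambda = e(t)."""
    lam = lam0
    for e in errors:
        lam -= eta * e
    return lam


def pi_integral_term(errors, ki):
    """I term of Eq. (7): -Ki times the sum of all errors so far."""
    return -ki * sum(errors)


errors = [0.5, -1.0, 2.0, -0.25]  # made-up error sequence e(0), ..., e(3)
print(lagrange_multiplier(errors, eta=0.01), pi_integral_term(errors, ki=0.01))  # both about -0.0125
```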
3.5.2 Connection to VAE and β-VAE
For the basic VAE, the weight on the KL term in the VAE objective is $\beta = 1$. After model training, the KL-divergence of the basic VAE converges to a value, $v_{kl}^{vae}$. When we set the target KL-divergence, $v_{kl}$, to $v_{kl}^{vae}$, ControlVAE becomes the basic VAE, because the weight $\beta(t)$ converges to 1 at the end of model training. β-VAE, in turn, assigns a large and fixed weight $\beta$ to the KL term. If we fix the output of the PI controller to a large value, ControlVAE becomes β-VAE.
3.6 Applications of ControlVAE
As a preliminary demonstration of the general applicability of the above approach and as an illustration of its customizability, we apply ControlVAE to three different applications stated below.
Language modeling: We first apply ControlVAE to solve the KL vanishing problem while improving the diversity of generated data. As mentioned in Section 2.1, VAE models often suffer from KL vanishing in language modeling. Existing methods cannot completely solve the KL vanishing problem or explicitly manipulate the value of the KL-divergence. In this paper, we adopt ControlVAE to avoid KL vanishing by controlling the KL-divergence to a specified value, using the output KL-divergence as feedback. Following the PI tuning strategy in Section 3.2, we set $K_p$ and $K_i$ of the PI algorithm in (7) to and , respectively. In addition, $\beta_{min}$ is set to and the maximum value of $\beta(t)$ is limited to .
Disentangling: We then apply the ControlVAE model to achieve a better trade-off between reconstruction quality and disentangling. As mentioned in Section 2.2, β-VAE assigns a large hyperparameter to the objective function to control the KL-divergence (information bottleneck), which, however, leads to a large reconstruction error. To mitigate this issue, we adopt ControlVAE to automatically adjust the hyperparameter $\beta(t)$ based on the output KL-divergence during model training. Using a methodology similar to [burgess2018understanding], we train a single model by gradually increasing the KL-divergence from to a desired value with a step function for every training steps. Since , we set $\beta_{min}$ to for the PI algorithm in (7). Following the PI tuning method above, the coefficients $K_p$ and $K_i$ are set to and , respectively.
Image generation: In this paper, we leverage ControlVAE to manipulate (slightly increase) the value of the KL-divergence to improve the ELBO and reconstruction quality over the basic VAE for image generation. Unlike the original VAE ($\beta = 1$), our ControlVAE model extends the range of the hyperparameter, $\beta(t)$, from to . Given a desired KL-divergence, ControlVAE can automatically tune $\beta(t)$ within that range. For this task, we use the same PI control algorithm and hyperparameters as for the above language modeling.
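In the disentangling experiments, the set point is not fixed from the start. It is raised with a step function every fixed number of training steps. A minimal sketch of such a schedule follows. The start value, increment, interval, and target are hypothetical placeholders, not the values used in the experiments.

```python
def stepwise_setpoint(step: int, start_kl: float, target_kl: float,
                      increment: float, every: int) -> float:
    """KL set point that rises by `increment` every `every` training steps, capped at target_kl."""
    if every <= 0:
        raise ValueError("every must be a positive number of steps")
    return min(target_kl, start_kl + increment * (step // every))


# Hypothetical schedule: start at 0.5, add 0.15 every 5000 steps, stop at 18.
for step in (0, 4999, 5000, 200_000, 1_000_000):
    print(step, stepwise_setpoint(step, start_kl=0.5, target_kl=18.0, increment=0.15, every=5000))
```

At each training step, the value returned here would serve as $v_{kl}$ for the PI controller in (7).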
Fig. 3: (a) Reconstruction error and (b) hyperparameter $\beta(t)$ of different models on 2D Shapes over random seeds. ControlVAE (KL=16, 18) and Control-FactorVAE have lower reconstruction errors and variance than the other methods. (c) An example of the disentangled factors in the latent variable as the total KL-divergence increases for ControlVAE (KL=18). Each curve with positive KL-divergence (except the black one) represents one factor disentangled by ControlVAE.
We evaluate the performance of ControlVAE on benchmark datasets in the three different applications mentioned above.
The datasets used for our experiments are introduced below.
Disentangling: 1) 2D Shapes [matthey2017dsprites]: it has 737,280 binary images of 2D shapes with five ground-truth factors (number of values): shape (3), scale (6), orientation (40), x-position (32), y-position (32) [kim2018disentangling].
Image generation: 1) CelebA (cropped version) [liu2015deep]: It has RGB images of celebrity faces. The data is split into and images for training and testing.
Language modeling: 1) Penn Tree Bank (PTB) [marcus1993building]: it consists of training sentences, validation sentences and testing sentences. 2) Switchboard (SW) [godfrey1997switchboard]: it has two-sided telephone conversations with manually transcribed speech and alignment. The data is randomly split into , and dialogs for training, validation and testing.
4.2 Model Configurations
The detailed model configurations and hyperparameter settings for each model are presented in Appendix B.
4.3 Evaluation on Disentangled Representations
First, we evaluate the performance of ControlVAE and Control-FactorVAE on learning disentangled representations using the 2D Shapes data. We compare them with two baselines: FactorVAE [kim2018disentangling] and β-VAE [burgess2018understanding].
Fig. 3 (a) and (b) show the comparison of reconstruction error and the hyperparameter $\beta(t)$ (using random seeds) for different models. Fig. 3 (a) shows that ControlVAE (KL=16, 18) has lower reconstruction error and variance than the other methods. This is because ControlVAE automatically adjusts the hyperparameter, $\beta(t)$, to stabilize the KL-divergence, while β-VAE and FactorVAE keep their hyperparameters unchanged during model training. In addition, the newly proposed Control-FactorVAE has slightly lower reconstruction error than the two baselines. Specifically, for ControlVAE (KL=18), the hyperparameter $\beta(t)$ is large at the beginning to obtain good disentangling. It then gradually drops to around to improve reconstruction quality as training converges, as shown in Fig. 3(b). In contrast, β-VAE and FactorVAE have a large and fixed weight on the KL-divergence in the objective. Their optimization therefore tends to focus on the KL-divergence term, leading to a large reconstruction error. Moreover, Fig. 3(c) illustrates an example of the KL-divergence per factor in the latent code as training progresses and the total information capacity (KL-divergence) increases from to . ControlVAE disentangles all five generative factors, starting with the positional latents ($x$ and $y$), then scale, followed by orientation and then shape.
Next, we use two disentanglement metrics, mutual information gap (MIG) [chen2018isolating] and robust MIG (RMIG) [do2020theory], to evaluate the disentanglement of different models. Table I shows the comparison of MIG scores for different methods. ControlVAE (KL=16) has a comparable MIG but lower variance than FactorVAE and Control-FactorVAE. Note that FactorVAE and Control-FactorVAE add a Total Correlation (TC) term to the objective, while ControlVAE does not. Moreover, Control-FactorVAE (TC=0.3) and FactorVAE have comparable MIG scores, because their weights on the TC term are approximately equal after the models converge. We use MIG to measure the average disentanglement score over four disentangled factors (scale, rotation, $x$ position, and $y$ position). We adopt RMIG, a robust version of MIG, to measure the score of each disentangled factor, as shown in Table II. Control-FactorVAE has the highest average RMIG score among them, because it better disentangles the shape and scale factors. In addition, Control-FactorVAE, ControlVAE, and FactorVAE have comparable average RMIG scores. For individual factors, ControlVAE (KL=16) disentangles the positional factors $x$ and $y$ better than the other methods, while FactorVAE disentangles orientation well.
|Metric||ControlVAE (KL=16)||ControlVAE (KL=18)||Control-FactorVAE||β-VAE||FactorVAE|
|MIG||0.5628 ± 0.0222||0.5432 ± 0.0281||0.5620 ± 0.0348||0.5138 ± 0.0371||0.5625 ± 0.0443|
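For reference, MIG [chen2018isolating] averages a per-factor quantity over the ground-truth factors. For each factor, it takes the gap in mutual information between the two latent dimensions most informative about that factor and normalizes it by the factor's entropy. The sketch below assumes the mutual information values and factor entropies have already been estimated. That estimation step is the expensive part and is omitted, and the toy matrix is invented. RMIG is not covered by this sketch.

```python
def mutual_information_gap(mi, entropies):
    """MIG from a precomputed mutual-information matrix.

    mi[j][k] is I(z_j; v_k) between latent dimension j and ground-truth factor k,
    and entropies[k] is H(v_k). For each factor, take the gap between the two
    most informative latents, normalize by H(v_k), then average over factors.
    """
    n_latents = len(mi)
    if n_latents < 2:
        raise ValueError("MIG needs at least two latent dimensions")
    gaps = []
    for k, h in enumerate(entropies):
        column = sorted((mi[j][k] for j in range(n_latents)), reverse=True)
        gaps.append((column[0] - column[1]) / h)
    return sum(gaps) / len(gaps)


# Invented MI matrix: 2 latent dimensions x 2 factors, both factor entropies = 1.
print(mutual_information_gap([[0.9, 0.1], [0.2, 0.6]], [1.0, 1.0]))  # approximately 0.6
```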
Since no metric measures disentanglement exactly, we also show qualitative results for the different models in Fig. 4. ControlVAE and Control-FactorVAE discover all five generative factors: the positional latents ($x$ and $y$), scale, orientation, and shape. However, β-VAE disentangles only four generative factors because it entangles scale and shape (third row), while FactorVAE does not disentangle position very well in Fig. 4. Based on these experimental results, we conclude that ControlVAE and Control-FactorVAE achieve better reconstruction quality than the baselines with comparable disentanglement.
4.4 Evaluation on ELBO and Reconstruction
Fig. 5: Performance comparison of different methods on CIFAR-10, averaged over random seeds. (a)(b) ControlVAE has a higher ELBO and lower reconstruction loss than the other methods given the desired KL-divergence (KL=145). (c) ControlVAE is able to stabilize the KL-divergence at the target value, 145, while the Lagrange multiplier (LM) method has a bias and cannot stabilize the KL-divergence.
We also demonstrate that ControlVAE can change the optimization trajectory to improve the ELBO and the reconstruction quality over the basic VAE for image generation task. Here we follow the set point guidelines in Section 3.3 to choose the set point of KL-divergence, KL=145, to conduct experiments. We compare our method with the following baselines.
Lagrange multiplier (LM): it formulates the KL-divergence in the VAE objective as a constrained optimization problem solved with a Lagrange multiplier, as described in Section 3.5.
β-VAE [burgess2018understanding]: it assigns a large weight to the KL term to force the KL-divergence to stabilize at a specified value.
VAE: It is a basic VAE model which consists of reconstruction term and KL-divergence term without any weight or constraint.
Fig. 5 (a) and (b) show the comparison of reconstruction error and ELBO under different set points of KL-divergence on the CIFAR-10 dataset, using different random seeds during model training. ControlVAE-KL (KL=145) has the highest ELBO and the lowest reconstruction error among them. This is because ControlVAE can achieve a good trade-off between reconstruction quality and KL-divergence to improve the optimization trajectory. The ELBO of the Lagrange multiplier (LM) method is slightly lower than that of ControlVAE, because LM suffers from local minima caused by the nonlinear term. Moreover, ControlVAE outperforms β-VAE ($\gamma = 30$) [burgess2018understanding], because ControlVAE uses dynamic tuning to stabilize the KL-divergence, while β-VAE assigns a large and fixed hyperparameter to the KL term. We also compare the stability of the different methods, as illustrated in Fig. 5 (c). ControlVAE and β-VAE can stabilize the KL-divergence at the target value, while LM is biased away from its target. In other words, LM is unable to precisely control the KL-divergence to a specified value.
We further use FID [lucic2018gans] and ELBO to evaluate ControlVAE on the test sets of CIFAR-10 and CelebA, as shown in Tables III and IV. ControlVAE outperforms the other methods in terms of both FID and ELBO on test data. Therefore, ControlVAE can improve the optimization trajectory and the reconstruction quality for image generation when the desired KL-divergence is chosen from our derived set points.
Table III: ELBO and FID on CIFAR-10 test data.
|Method||ELBO||FID|
|VAE (KL=118)||-372.10 ± 0.64||135.25 ± 0.31|
|ControlVAE (KL=145)||-365.18 ± 0.37||122.36 ± 0.33|
|LM (KL=145)||-366.87 ± 0.90||124.76 ± 0.58|
|β-VAE (KL=145)||-388.82 ± 0.76||134.05 ± 0.54|
Table IV: ELBO and FID on CelebA test data.
|Method||ELBO||FID|
|VAE (KL=127)||-472.21 ± 0.52||89.33 ± 0.51|
|ControlVAE (KL=155)||-468.93 ± 0.71||86.91 ± 0.49|
|LM (KL=155)||-469.09 ± 0.85||87.01 ± 0.37|
|β-VAE (KL=155)||-494.06 ± 0.80||90.27 ± 0.40|
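FID, reported in Tables III and IV, is the Fréchet distance between Gaussians fitted to features of real and generated images (typically Inception-network features): ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). The NumPy sketch below computes only this distance from given means and covariances; feature extraction is omitted, and the helper names are hypothetical.

```python
import numpy as np


def psd_sqrt(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # remove tiny negative eigenvalues from round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.asarray(mu1, dtype=float), np.asarray(mu2, dtype=float)
    sigma1 = np.asarray(sigma1, dtype=float)
    sigma2 = np.asarray(sigma2, dtype=float)
    root1 = psd_sqrt(sigma1)
    # Tr((sigma1 sigma2)^{1/2}) equals Tr((root1 sigma2 root1)^{1/2}),
    # and root1 sigma2 root1 is symmetric PSD, so eigh can be used.
    middle = root1 @ sigma2 @ root1
    middle = 0.5 * (middle + middle.T)  # symmetrize against round-off
    tr_covmean = np.trace(psd_sqrt(middle))
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

The symmetric rewrite works because Σ_r Σ_g and Σ_r^{1/2} Σ_g Σ_r^{1/2} share the same eigenvalues, so their square roots have the same trace; this avoids a general (non-symmetric) matrix square root. Lower FID means the generated feature distribution is closer to the real one.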
4.5 Evaluation on Language Modeling
Finally, we compare the performance of ControlVAE with the following baselines for mitigating KL vanishing in text generation [bowman2015generating].
Cost annealing [bowman2015generating]: gradually increases the weight β on the KL-divergence from 0 to 1 over a fixed number of training steps using a sigmoid function.
Cyclical annealing [liu2019cyclical]: splits the training process into several cycles and, within each cycle, increases β linearly from 0 to 1.
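The two annealing baselines differ only in how β is scheduled against the training step. The sketch below is an illustration under assumptions: the sigmoid's `midpoint` and `steepness`, and the fraction `ratio` of each cycle spent ramping up (with β then held at 1 until the cycle restarts, the commonly used form of [liu2019cyclical]), are hypothetical parameters, not the settings used in these experiments.

```python
import math


def sigmoid_anneal(step: int, midpoint: float, steepness: float) -> float:
    """Sigmoid KL-weight schedule rising from ~0 to ~1; beta = 0.5 at step == midpoint.
    Written in an overflow-safe form for very negative arguments."""
    x = steepness * (step - midpoint)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def cyclical_anneal(step: int, total_steps: int, n_cycles: int, ratio: float = 0.5) -> float:
    """Cyclical schedule: in each cycle, beta ramps linearly from 0 to 1 over the first
    `ratio` fraction of the cycle, then stays at 1 until the next cycle restarts at 0."""
    period = total_steps / n_cycles
    pos = (step % period) / period  # position within the current cycle, in [0, 1)
    return min(1.0, pos / ratio)
```

Both functions depend only on the step counter, never on the observed KL-divergence; this open-loop property is exactly what the related-work discussion contrasts with ControlVAE's feedback control.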
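Table V: Diversity and perplexity of dialog-response generation on the Switchboard dataset.
|Method||Dis-1||Dis-2||self-BLEU-2||self-BLEU-3||PPL|
|ControlVAE-KL-35||6.27K ± 41||95.86K ± 1.02K||0.663 ± 0.012||0.447 ± 0.013||8.81 ± 0.05|
|ControlVAE-KL-25||6.10K ± 60||83.15K ± 4.00K||0.698 ± 0.006||0.495 ± 0.014||12.47 ± 0.07|
|Cost anneal-KL-17||5.71K ± 87||69.60K ± 1.53K||0.721 ± 0.010||0.536 ± 0.008||16.82 ± 0.11|
|Cyclical (KL = 21.5)||5.79K ± 81||71.63K ± 2.04K||0.710 ± 0.007||0.524 ± 0.008||17.81 ± 0.33|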
Fig. 6 compares the KL-divergence, reconstruction loss, ELBO, and hyperparameter β of the different methods on the PTB dataset. Here, ControlVAE-KL-v denotes setting the desired KL-divergence of our PI controller to v (e.g., 3), following the set point guidelines in Section 3.3, and Cost-annealing-v denotes gradually increasing β from 0 to 1 over v steps using a sigmoid function. Fig. 6 (a) shows that ControlVAE (KL=1.5, 3) and cyclical annealing both avert KL vanishing. However, ControlVAE stabilizes the KL-divergence while cyclical annealing does not. Moreover, our method has a lower reconstruction loss than cyclical annealing, as shown in Fig. 6 (b). The cost annealing method still suffers from KL vanishing because we use the Transformer [vaswani2017attention] as the decoder, which can predict the current token from the previous ground-truth tokens. In addition, Fig. 6 (c) shows that ControlVAE improves the ELBO over the baselines, which means it can change the optimization trajectory. Fig. 6 (d) shows the β tuned by ControlVAE compared with the other methods: our β gradually converges to a certain value. Note that β of ControlVAE does not converge to 1, because we set the target KL-divergence slightly higher than the value produced by the original VAE in order to improve the diversity of the generated data.
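ControlVAE produces β with a feedback controller rather than a fixed schedule. The sketch below assumes the non-linear PI form β(t) = K_p / (1 + exp(e(t))) − K_i Σ_j e(j) + β_min, with error e(t) = target KL − observed KL; the clamping range [β_min, β_max], the simple anti-windup rule, and the class name `PIController` are choices made for this illustration, and the paper's method section remains the reference for the actual controller and its gains.

```python
import math
from dataclasses import dataclass


@dataclass
class PIController:
    """Hypothetical sketch of a non-linear PI controller for the KL weight beta.

    Assumed control law (e = target_kl - observed_kl):
        beta(t) = k_p / (1 + exp(e(t))) - k_i * sum_j e(j) + beta_min
    The output is clamped to [beta_min, beta_max].
    """
    target_kl: float
    k_p: float
    k_i: float
    beta_min: float = 0.0
    beta_max: float = 1.0
    integral: float = 0.0  # running sum of errors for the I term

    def step(self, observed_kl: float) -> float:
        error = self.target_kl - observed_kl
        # P term k_p / (1 + exp(error)), computed in an overflow-safe way:
        # near 0 when KL is far below target, near k_p when KL overshoots.
        if error >= 0:
            z = math.exp(-error)
            p_term = self.k_p * z / (1.0 + z)
        else:
            p_term = self.k_p / (1.0 + math.exp(error))
        candidate = self.integral + error
        beta = p_term - self.k_i * candidate + self.beta_min
        # Simple anti-windup (a choice of this sketch): keep the new integral
        # only if it does not push beta outside the allowed range.
        if self.beta_min <= beta <= self.beta_max:
            self.integral = candidate
        return min(max(beta, self.beta_min), self.beta_max)
```

In a training loop, one would call `step()` with the (possibly smoothed) KL-divergence of each minibatch and use the returned β to weight the KL term of the loss. When the observed KL stays below the target, the integral term pushes β down so the KL can rise; when the KL overshoots, β increases. If the target exceeds the KL that the basic VAE reaches at β = 1, the controller settles at a β below 1.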
To further demonstrate that ControlVAE can improve the diversity of generated text, we apply it to dialog-response generation on the Switchboard (SW) dataset. Following [zhao2017learning], we adopt a conditional VAE that generates a dialog response conditioned on the previous response. We use the Dis-n metric [xu2018dp] and self-BLEU [zhu2018texygen] (computed on 1000 sampled results) to measure the diversity of the generated data, and perplexity (PPL) [jelinek1977perplexity] to measure how well the learned probability distribution predicts a sample. Table V shows the results of the different approaches. ControlVAE produces more distinct n-grams and a lower self-BLEU than the baselines when the desired KL-divergence is set to 25 and 35, and it also achieves a lower PPL than the other methods. We conclude that ControlVAE improves both the diversity of the generated data and the generation performance. Appendix C shows examples of dialog generated by ControlVAE.
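Distinct-n counts the unique n-grams across a set of generated responses; higher values indicate more diverse output. The sketch below returns both the raw count (Table V appears to report counts, given the "K" suffixes) and the distinct-to-total ratio. Pre-tokenized input and the function name `distinct_ngrams` are assumptions; this is not the evaluation script of [xu2018dp].

```python
from typing import Iterable, List, Tuple


def distinct_ngrams(sentences: Iterable[List[str]], n: int) -> Tuple[int, float]:
    """Return (number of distinct n-grams, distinct / total n-grams) over all sentences."""
    seen = set()
    total = 0
    for tokens in sentences:
        # Slide a window of length n over each tokenized sentence.
        for i in range(len(tokens) - n + 1):
            seen.add(tuple(tokens[i:i + n]))
            total += 1
    ratio = len(seen) / total if total else 0.0
    return len(seen), ratio
```

For instance, the responses "um - hum", "yeah", "um - hum" contain 7 unigrams but only 4 distinct ones, and their 4 bigrams reduce to 2 distinct ones: repeated responses lower the score, which is the kind of redundancy the metric is meant to expose.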
5 Ablation Studies
In this section, we conduct ablation studies on how the hyperparameters affect the performance of ControlVAE.
5.1 Effect of Set Points of KL-divergence on Disentanglement
We first study how the target value of the KL-divergence influences disentangled representation learning. We vary the target KL-divergence from 16 to 19 in steps of 1 while keeping the other parameters unchanged. Table VI reports the RMIG score of each of the five factors and the overall score. ControlVAE achieves its highest average RMIG score at the best-performing target value in Table VI, where it disentangles three factors (position x, position y, and shape) better than at the other KL values. In addition, a target KL-divergence that is too large may hurt disentanglement. This is because a larger KL-divergence allows multiple latent factors to pass through the information channels together.
5.2 Effect of Step Value on Disentanglement
For disentangled representation learning, ControlVAE uses an annealing method that increases the target KL-divergence by a fixed step every fixed number of iterations. We therefore study how the step value affects disentanglement. Table VII compares RMIG under different step values. ControlVAE achieves its best overall disentanglement at one particular step value, while specific step values better disentangle the position and orientation factors, respectively. We also find that the RMIG score of ControlVAE may decrease if the step value is too large, because multiple factors become entangled when the KL-divergence (information bottleneck) grows large.
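The staircase schedule for the target KL-divergence described above fits in one line. In this sketch, `start`, `final`, `step`, and `every` are hypothetical parameters, not the settings used for the 2D Shape experiments.

```python
def target_kl_schedule(iteration: int, start: float, final: float,
                       step: float, every: int) -> float:
    """Staircase set point: raise the target KL by `step` every `every` iterations,
    capped at `final` (hypothetical parameter names)."""
    return min(final, start + step * (iteration // every))
```

At each iteration, the returned value would serve as the PI controller's set point. A larger `step` raises the set point faster, so the information bottleneck widens earlier in training.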
5.3 Effect of Batch Size on ELBO
Next, we study how the training batch size influences the ELBO on the image generation task. We vary the batch size from 50 to 150 and evaluate ControlVAE on the CIFAR-10 dataset. Table VIII reports the ELBO of ControlVAE under the different batch sizes. The ELBO is higher with a larger batch size (100 or 150) than with a small batch size of 50. The main reason is that the KL-divergence is less stable during training when the batch size is small. Hence, we use a sufficiently large batch size, such as 100, to train our model.
Table VIII: ELBO of ControlVAE on CIFAR-10 under different batch sizes.
|Metric||batch = 50||batch = 100||batch = 150|
|ELBO||-375.99 ± 1.82||-368.18 ± 1.00||-370.23 ± 0.82|
5.4 Effect of Embedding Size on ELBO
We also study how the embedding size of the latent space affects the ELBO on the image generation task. We run experiments with embedding sizes of 100, 200, and 500 on the CIFAR-10 dataset. Table IX shows that ControlVAE achieves the highest ELBO when the latent embedding size is set to 200.
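Table IX: ELBO of ControlVAE on CIFAR-10 under different latent embedding sizes.
|Metric||size = 100||size = 200||size = 500|
|ELBO||-365.18 ± 0.37||-361.80 ± 0.86||-364.17 ± 1.25|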
6 Related Work
Many works involve a trade-off between reconstruction and KL-divergence in VAE applications. For disentangled representation learning, researchers proposed β-VAE (β > 1) [higgins2017beta, burgess2018understanding], which assigns a large, fixed hyperparameter β to put more emphasis on the KL-divergence and thereby encourage disentangled latent representations. However, it sacrifices reconstruction quality to obtain better disentangling. Follow-up works [chen2018isolating, kim2018disentangling] further factorize the KL-divergence term to improve the reconstruction quality. However, these methods still assign a large, fixed hyperparameter to the decomposed terms in the objective, resulting in high reconstruction error. In contrast, ControlVAE dynamically tunes β during optimization to achieve better disentangling and reconstruction quality.
To improve the sample quality of VAEs [dai2019diagnosing, xiao2019generative, ghosh2019variational, alemi2017fixing, zhao2019infovae], some researchers reduced the weight of the KL-divergence so that the decoder produces sharper outputs. Although these methods achieve impressive sample quality, they suffer severely from the trade-off: the latent distribution drifts far away from the prior. Recent studies adopted constrained optimization on the reconstruction error [rezende2018taming, klushyn2019learning] to balance reconstruction error and KL-divergence. These methods may suffer from posterior collapse if the inference network fails to cover the latent space, whereas ours can completely avert posterior collapse. Moreover, unlike their work, we treat the KL-divergence (information bottleneck) as the constrained quantity. Our method and theirs are complementary and suit different applications.
In language modeling, VAEs often suffer from KL vanishing due to a powerful decoder, such as a Transformer [vaswani2017attention] or an LSTM. A popular remedy is to add a hyperparameter β on the KL term [bowman2015generating, liu2019cyclical] and gradually increase it from 0 to 1. However, existing methods [yang2017improved, bowman2015generating, liu2019cyclical], such as KL cost annealing and cyclical annealing, cannot completely solve KL vanishing or explicitly control the value of the KL-divergence, since they change β blindly without observing the actual KL-divergence during training. In contrast, our approach averts KL vanishing and stabilizes the KL-divergence at a desired value.
In this paper, we proposed a general controllable VAE framework, ControlVAE, that combines automatic control with the basic VAE framework to improve the performance of VAE models. We designed a novel non-linear PI controller to control the value of the KL-divergence during model training. We also developed a new variant, Control-FactorVAE, to improve disentanglement. In addition, we provided a simplified theoretical analysis to help choose the set points of the KL-divergence, and we presented the connections between ControlVAE and other models. The evaluation results showed that ControlVAE improves the ELBO over the basic VAE by changing its optimization trajectory on the task of image generation. For disentangled representation learning, it significantly improves reconstruction quality while achieving disentanglement comparable to the best baselines. We also demonstrated that ControlVAE completely averts the KL-vanishing (posterior collapse) problem and controls the diversity of generated data in language modeling.
Research reported in this paper was sponsored in part by DARPA award W911NF-17-C-0099, DTRA award HDTRA1-18-1-0026, and the Army Research Laboratory under Cooperative Agreements W911NF-09-2-0053 and W911NF-17-2-0196.
Appendix A Proof of Theorem 1
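We now prove Theorem 1.
Assuming the decoder $g$ to be $K$-Lipschitz continuous, for any two latent codes $\mathbf{z}_1$ and $\mathbf{z}_2$ we have
$$\|g(\mathbf{z}_1) - g(\mathbf{z}_2)\| \le K\,\|\mathbf{z}_1 - \mathbf{z}_2\|,$$
where $K$ is the Lipschitz constant. We can assume $K = 1$ for simplicity, since this is achievable if spectral normalization [miyato2018spectral] is used.
Our next step is to bound the distance between the latent codes. The approximate posteriors of the basic VAE and of ControlVAE both follow Gaussian distributions with diagonal covariance, $\mathcal{N}(\boldsymbol{\mu}, \mathrm{diag}(\boldsymbol{\sigma}^2))$, so a latent code can be written as $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$ with $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. The prior distribution is always the unit Gaussian, $p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$. According to [kingma2013auto], the KL-divergence term of the basic VAE with an $n$-dimensional latent space can be expressed as
$$D_{KL} = \frac{1}{2}\sum_{i=1}^{n}\left(\mu_i^2 + \sigma_i^2 - \log\sigma_i^2 - 1\right).$$
It is easy to prove that
$$\sigma_i^2 - \log\sigma_i^2 - 1 \ge 0.$$
Proof. Let $x = \sigma_i^2 > 0$; then we need to prove $f(x) = x - \log x - 1 \ge 0$. Since $f'(x) = 1 - 1/x$ is negative for $x < 1$ and positive for $x > 1$, $f$ attains its minimum value $f(1) = 0$ at $x = 1$. Thus, $f(x) \ge 0$. ∎
Therefore, Eq. (23) can be bounded as
$$D_{KL} \ge \frac{1}{2}\sum_{i=1}^{n}\mu_i^2 = \frac{1}{2}\|\boldsymbol{\mu}\|^2.$$
Similarly, for the KL-divergence of ControlVAE, whose posterior mean we denote by $\boldsymbol{\mu}^{c}$, we have
$$D_{KL}^{c} \ge \frac{1}{2}\|\boldsymbol{\mu}^{c}\|^2.$$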
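The closed-form KL-divergence from [kingma2013auto] and the inequality x − log x − 1 ≥ 0 used in this proof together give KL ≥ ½‖μ‖², with equality when every σ_i = 1. The short Python check below evaluates the closed form and tests the bound on random diagonal Gaussians; it is a numerical sanity check rather than part of the proof, and the function names are hypothetical.

```python
import math
import random


def kl_diag_gaussian_to_unit(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ) = 0.5 * sum(mu^2 + sigma^2 - log sigma^2 - 1)."""
    return 0.5 * sum(m * m + s * s - math.log(s * s) - 1.0 for m, s in zip(mu, sigma))


def check_lower_bound(trials: int = 1000, dim: int = 8, seed: int = 0) -> bool:
    """Empirically check KL >= 0.5 * ||mu||^2 on random diagonal Gaussians."""
    rng = random.Random(seed)
    for _ in range(trials):
        mu = [rng.gauss(0.0, 2.0) for _ in range(dim)]
        sigma = [rng.uniform(0.05, 3.0) for _ in range(dim)]  # strictly positive std devs
        kl = kl_diag_gaussian_to_unit(mu, sigma)
        if kl < 0.5 * sum(m * m for m in mu) - 1e-12:  # small tolerance for round-off
            return False
    return True
```

For example, with μ = (3, 4) and σ = (1, 1), the σ-dependent terms vanish and the KL-divergence equals ½‖μ‖² = 12.5, so the bound is tight.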
Appendix B Model Configurations and Hyperparameter Settings
We summarize the detailed model configurations and hyperparameter settings of ControlVAE for three applications: language modeling, disentangled representation learning, and image generation.
B.1 Experimental Details for Disentangling
Following the model architecture of β-VAE [higgins2017beta], we adopt convolutional layers for the encoder and deconvolutional layers for the decoder. We use the Adam optimizer and tune the learning rate. The PI algorithm uses fixed coefficients K_p and K_i. For the step function, we increase the information capacity (desired KL-divergence) by a fixed step at regular intervals of training steps until it reaches its final value on the 2D Shape data. ControlVAE uses the same encoder and decoder architecture as β-VAE except for the added PI control algorithm, as shown in Table X.
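Table X: Encoder and decoder architecture for disentangled representation learning.
|Encoder||Decoder|
|Input binary image||Input|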
|conv. ReLU. stride 2||FC. 256 ReLU.|
|conv. ReLU. stride 2||upconv. ReLU.|
|conv. ReLU. stride 2||upconv. ReLU. stride 2.|
|conv. ReLU. stride 2||upconv. ReLU. stride 2|
|conv. ReLU.||upconv. ReLU. stride 2|
|FC . FC.||upconv. ReLU. stride 2|
B.2 Experimental Details for Image Generation
Similar to the architecture of β-VAE, we use convolutional layers with batch normalization as the encoder and deconvolutional layers with batch normalization as the decoder. We use the Adam optimizer for CelebA data. We choose the size of the latent variable that gave better reconstruction quality than the other sizes we tried. In addition, we set the desired KL-divergence to three values, one of which equals the KL-divergence of the original VAE. The PI control algorithm uses fixed coefficients K_p and K_i. We also use the same encoder and decoder architecture as β-VAE above, except that we add batch normalization to improve the stability of model training, as shown in Table XI.
Table XI: Encoder and decoder architecture for image generation.
|Encoder||Decoder|
|Input RGB image||Input|
|conv. ReLU. stride 2||FC. 256 ReLU.|
|conv. ReLU. stride 2||upconv. ReLU. stride 2|
|conv. ReLU. stride 2||upconv. ReLU. stride 2.|
|conv. ReLU. stride 2||upconv. ReLU. stride 2|
|conv. ReLU. stride 2||upconv. ReLU. stride 2|
|FC . FC.||upconv. ReLU. stride 2|
B.3 Experimental Details for Language Modeling
For text generation on PTB data, we build ControlVAE on the basic VAE model, as in [bowman2015generating]. We use a one-layer LSTM as the encoder, a three-layer Transformer with eight heads as the decoder, and a Multi-Layer Perceptron (MLP) to learn the latent variable. We fix the maximum sequence lengths of the LSTM and Transformer, the size of the latent variable, the word-embedding dimension, and the batch size, and we apply dropout to both the LSTM and the Transformer. We train with the Adam optimizer. Following the tuning guidelines above, we set the coefficients K_p and K_i of the P term and I term. Finally, we implement the experiments with the source code on the Texar platform [hu2019texar].
For dialog-response generation, we follow the model architecture and hyperparameters of the basic conditional VAE in [zhao2017learning]. We use a one-layer bidirectional GRU as the encoder, a one-layer GRU as the decoder, and two fully connected layers to learn the latent variable. The latent variable and the word embeddings have the same size, and the maximum input/output sequence length and the batch size for the GRU are fixed. We use Adam with a fixed initial learning rate, and we use the same PI coefficients K_p and K_i as in text generation above. Tables XII and XIII show the model architectures of ControlVAE for these two NLP tasks.
|Input words||Input ,|
|FC||3-layer Transformer 8 heads|
Appendix C Examples of Generated Dialog by ControlVAE
In this section, we use one example to compare the diversity and relevance of the dialog generated by different methods, as shown in Table XIV. Alice opens an open-ended conversation about choosing a college, and the model predicts Bob's response. The ground-truth response is “um - hum”. Table XIV shows that ControlVAE (KL=25, 35) generates diverse responses that are relevant to the ground truth. In contrast, while cyclical annealing also generates diverse text, some of its responses are not very relevant to the ground-truth response.
|Context: (Alice) and a lot of the students in that home town sometimes unk the idea of staying and going to school across the street so to speak|
|Topic: Choosing a college Target: (Bob) um - hum|
|yeah||uh - huh|
|um - hum||yeah|
|oh that’s right um - hum||oh yeah oh absolutely|
|right||um - hum|
|Cost annealing (KL=17)||Cyclical anneal (KL=21.5)|
|oh yeah||yeah that’s true do you do you do it|
|uh - huh||yeah|
|right||um - hum|
|uh - huh and i think we have to be together||yeah that’s a good idea|
|oh well that’s neat yeah well||yeah i see it too,it’s a neat place|