disentangled_vae
Replicating "Understanding disentangling in β-VAE"
We present new intuitions and theoretical assessments of the emergence of disentangled representations in variational autoencoders. Taking a rate-distortion theory perspective, we show the circumstances under which representations aligned with the underlying generative factors of variation of data emerge when optimising the modified ELBO in β-VAE, as training progresses. From these insights, we propose a modification to the training regime of β-VAE that progressively increases the information capacity of the latent code during training. This modification facilitates the robust learning of disentangled representations in β-VAE, without the previous trade-off in reconstruction accuracy.
Representation learning lies at the core of machine learning research. From the hand-crafted feature engineering prevalent in the past (Domingos_2012) to the implicit representation learning of modern deep learning approaches (Krizhevsky_etal_2012; He_etal_2016; Szegedy_etal_2015), it is a common theme that the performance of algorithms is critically dependent on the nature of their input representations. Despite the recent successes of deep learning approaches (He_etal_2016; Szegedy_etal_2015; Gregor_etal_2015; Oord_etal_2016; Oord_etal_2016b; Mnih_etal_2015; Mnih_etal_2016; Jaderberg_etal_2017; Silver_etal_2016), they are still far from the generality and robustness of biological intelligence (Lake_etal_2016). Hence, the implicit representations learnt by these approaches through supervised or reward-based signals appear to overfit to the training task and lack the properties necessary for knowledge transfer and generalisation outside of the training data distribution.
Different ways to overcome these shortcomings have been proposed in the past, such as auxiliary tasks (Jaderberg_etal_2017) and data augmentation (Tobin_etal_2017). Another, less explored but potentially more promising approach might be to use task-agnostic unsupervised learning to learn features that capture properties necessary for good performance on a variety of tasks (Bengio_etal_2013; LuCun_YouTube_2016). In particular, it has been argued that disentangled representations might be helpful (Bengio_etal_2013; Ridgeway2016-nj).
A disentangled representation can be defined as one where single latent units are sensitive to changes in single generative factors, while being relatively invariant to changes in other factors (Bengio_etal_2013). For example, a model trained on a dataset of 3D objects might learn independent latent units sensitive to single independent data generative factors, such as object identity, position, scale, lighting or colour, similar to an inverse graphics model (Kulkarni_etal_2015). A disentangled representation is therefore factorised and often interpretable, whereby different independent latent units learn to encode different independent ground-truth generative factors of variation in the data.
Most initial attempts to learn disentangled representations required supervised knowledge of the data generative factors (Hinton_etal_2011; Rippel_Adams_2013; Reed_etal_2014; Zhu14; Yang_etal_2015; Goroshin_etal_2015; Kulkarni_etal_2015; Cheung15; Whitney_etal_2016; Karaletsos_etal_2016). This, however, is unrealistic in most real-world scenarios. A number of purely unsupervised approaches to disentangled factor learning have been proposed (Schmidhuber_1992; Desjardins_etal_2012; Tang13; Cohen_Welling_2014; Cohen_Welling_2015; Chen_etal_2016; Higgins_etal_2017), including β-VAE (Higgins_etal_2017), the focus of this text.
β-VAE is a state-of-the-art model for unsupervised visual disentangled representation learning. It is a modification of the Variational Autoencoder (VAE) objective (Kingma_Welling_2014; Rezende_etal_2014), a generative approach that aims to learn the joint distribution of images $\mathbf{x}$ and their latent generative factors $\mathbf{z}$. β-VAE adds an extra hyperparameter β to the VAE objective, which constricts the effective encoding capacity of the latent bottleneck and encourages the latent representation to be more factorised. The disentangled representations learnt by β-VAE have been shown to be important for learning a hierarchy of abstract visual concepts conducive to imagination (Higgins_etal_2017c) and for improving transfer performance of reinforcement learning policies, including simulation-to-reality transfer in robotics (Higgins_etal_2017b). Given the promising results demonstrating the general usefulness of disentangled representations, it is desirable to gain a better theoretical understanding of how β-VAE works, as it may help to scale disentangled factor learning to more complex datasets. In particular, it is currently unknown what causes the factorised representations learnt by β-VAE to be axis-aligned with the human intuition of the data generative factors, compared to the standard VAE (Kingma_Welling_2014; Rezende_etal_2014). Furthermore, β-VAE has other limitations, such as worse reconstruction fidelity compared to the standard VAE. This is caused by a trade-off introduced by the modified training objective, which punishes reconstruction quality in order to encourage disentanglement within the latent representations. This paper attempts to shed light on why β-VAE disentangles, and to use the new insights to suggest practical improvements to the β-VAE framework that overcome the reconstruction-disentanglement trade-off.
We first discuss the VAE and β-VAE frameworks in more detail, before introducing our insights into why reducing the capacity of the information bottleneck using the hyperparameter β in the β-VAE objective might be conducive to learning disentangled representations. We then propose an extension to β-VAE, motivated by these insights, that relaxes the information bottleneck during training, enabling it to achieve more robust disentangling and better reconstruction accuracy.
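Suppose we have a dataset $\mathbf{x}$ of samples from a distribution parametrised by ground truth generative factors $\mathbf{z}$. The variational autoencoder (VAE) (Kingma_Welling_2014; Rezende_etal_2014) aims to learn the marginal likelihood of the data in such a generative process:
$$\max_{\phi,\theta}\ \mathbb{E}_{p(\mathbf{z})}\left[p_\theta(\mathbf{x}\mid\mathbf{z})\right] \tag{1}$$
where $\phi$, $\theta$ parametrise the distributions of the VAE encoder and the decoder respectively. This can be re-written as:
$$\log p_\theta(\mathbf{x}) = D_{KL}\big(q_\phi(\mathbf{z}\mid\mathbf{x}) \,\|\, p_\theta(\mathbf{z}\mid\mathbf{x})\big) + \mathcal{L}(\theta,\phi;\mathbf{x},\mathbf{z}) \tag{2}$$
where $D_{KL}$ stands for the non-negative Kullback–Leibler divergence between the approximate posterior $q_\phi(\mathbf{z}\mid\mathbf{x})$ and the true posterior $p_\theta(\mathbf{z}\mid\mathbf{x})$. Because this term is non-negative, $\mathcal{L}(\theta,\phi;\mathbf{x},\mathbf{z})$ is a lower bound on $\log p_\theta(\mathbf{x})$, so maximising it maximises a lower bound on the logarithm of the true objective in Eq. 1:
$$\log p_\theta(\mathbf{x}) \geq \mathcal{L}(\theta,\phi;\mathbf{x},\mathbf{z}) = \mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}\left[\log p_\theta(\mathbf{x}\mid\mathbf{z})\right] - D_{KL}\big(q_\phi(\mathbf{z}\mid\mathbf{x}) \,\|\, p(\mathbf{z})\big) \tag{3}$$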
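To make Eq. 3 concrete, here is a minimal plain-Python sketch of a single-sample ELBO estimate. It assumes the diagonal Gaussian posterior and unit Gaussian prior introduced in the next paragraph, for which the KL term has the closed form $\tfrac{1}{2}\sum_i(\mu_i^2 + \sigma_i^2 - 1 - \log\sigma_i^2)$, and a Bernoulli decoder over binary pixels. The function names are hypothetical, and `decoder_probs` stands in for whatever a decoder outputs for one posterior sample.

```python
import math


def gaussian_kl_to_unit(mu, log_sigma):
    """KL(N(mu, diag(sigma^2)) || N(0, I)) in nats, summed over latent dimensions."""
    return sum(
        0.5 * (m * m + math.exp(2.0 * ls) - 1.0 - 2.0 * ls)
        for m, ls in zip(mu, log_sigma)
    )


def bernoulli_log_likelihood(x, probs, eps=1e-7):
    """log p(x | z) for binary pixels x under per-pixel Bernoulli probabilities."""
    total = 0.0
    for xi, p in zip(x, probs):
        p = min(max(p, eps), 1.0 - eps)  # clamp so log() never sees 0
        total += xi * math.log(p) + (1.0 - xi) * math.log(1.0 - p)
    return total


def elbo(x, decoder_probs, mu, log_sigma):
    """Single-sample estimate of Eq. 3: reconstruction term minus KL to the prior.

    decoder_probs is assumed to be the decoder output for one sample z ~ q(z|x).
    """
    return bernoulli_log_likelihood(x, decoder_probs) - gaussian_kl_to_unit(mu, log_sigma)


# A posterior equal to the prior (mu = 0, sigma = 1) costs nothing in KL.
print(gaussian_kl_to_unit([0.0, 0.0], [0.0, 0.0]))  # 0.0
print(elbo([1, 0], [0.9, 0.2], [0.5], [-0.5]))
```

Parametrising the encoder output as $\log\sigma$ rather than $\sigma$ (an assumption of this sketch) keeps the standard deviation positive without an explicit constraint.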
In order to make the optimisation of the objective in Eq. 3 tractable in practice, assumptions are commonly made. The prior and posterior distributions are parametrised as Gaussians with a diagonal covariance matrix; the prior is typically set to the isotropic unit Gaussian $\mathcal{N}(\mathbf{0}, \mathbf{I})$. Parametrising the distributions in this way allows for use of the “reparametrisation trick” to estimate gradients of the lower bound with respect to the parameters $\phi$, where each random variable $z_i \sim q_\phi(z_i\mid\mathbf{x}) = \mathcal{N}(\mu_i, \sigma_i^2)$ is parametrised as a differentiable transformation of a noise variable $\epsilon \sim \mathcal{N}(0, 1)$:
$$z_i = \mu_i + \sigma_i\,\epsilon \tag{4}$$
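A minimal sketch of Eq. 4, assuming plain Python lists for the per-dimension means and standard deviations and `random.Random` as the noise source (both illustrative choices, not the paper's implementation):

```python
import random
import statistics


def reparameterise(mu, sigma, rng):
    """Sample z_i = mu_i + sigma_i * eps_i with eps_i ~ N(0, 1) (Eq. 4).

    The noise eps is returned too: for a fixed eps, z is a differentiable function of
    (mu, sigma) with dz_i/dmu_i = 1 and dz_i/dsigma_i = eps_i, which is what lets
    gradients of the lower bound flow back into the encoder.
    """
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    z = [m + s * e for m, s, e in zip(mu, sigma, eps)]
    return z, eps


# Empirically, many draws with mu = 2.0 and sigma = 0.5 have mean near 2.0 and std near 0.5.
rng = random.Random(0)
draws = [reparameterise([2.0], [0.5], rng)[0][0] for _ in range(20_000)]
print(round(statistics.mean(draws), 2), round(statistics.stdev(draws), 2))
```

All randomness sits in `eps`, so the sampling step itself has no parameters and the path from $(\mu_i, \sigma_i)$ to $z_i$ is an ordinary affine map.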
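β-VAE is a modification of the variational autoencoder (VAE) framework (Kingma_Welling_2014; Rezende_etal_2014) that introduces an adjustable hyperparameter β to the original VAE objective:
$$\mathcal{L}(\theta,\phi;\mathbf{x},\mathbf{z},\beta) = \mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}\left[\log p_\theta(\mathbf{x}\mid\mathbf{z})\right] - \beta\, D_{KL}\big(q_\phi(\mathbf{z}\mid\mathbf{x}) \,\|\, p(\mathbf{z})\big) \tag{5}$$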
Well-chosen values of β (usually β > 1) result in more disentangled latent representations $\mathbf{z}$. When β = 1, β-VAE becomes equivalent to the original VAE framework. It was suggested that the stronger pressure for the posterior to match the factorised unit Gaussian prior introduced by the β-VAE objective puts extra constraints on the implicit capacity of the latent bottleneck $\mathbf{z}$ and extra pressures for it to be factorised, while still being sufficient to reconstruct the data $\mathbf{x}$ (Higgins_etal_2017). Higher values of β, necessary to encourage disentangling, often lead to a trade-off between the fidelity of β-VAE reconstructions and the disentangled nature of its latent code (see Fig. 6 in Higgins_etal_2017). This is due to the loss of information as it passes through the restricted-capacity latent bottleneck $\mathbf{z}$.
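The trade-off controlled by β can be seen with toy numbers. The sketch below assumes two hypothetical encoders summarised only by a reconstruction log-likelihood and a KL value (the numbers are made up for illustration, not results from the paper) and asks which one Eq. 5 prefers as β grows.

```python
def beta_vae_objective(recon_log_lik, kl, beta):
    # Eq. 5: reconstruction term minus beta-weighted KL; beta = 1 gives the plain VAE ELBO.
    return recon_log_lik - beta * kl


# Hypothetical summary statistics for two candidate encoders (illustrative only):
# "detailed" transmits more information (higher KL) and reconstructs better,
# "compressed" transmits less and reconstructs worse.
CANDIDATES = {
    "detailed": {"recon_log_lik": -100.0, "kl": 20.0},
    "compressed": {"recon_log_lik": -115.0, "kl": 8.0},
}


def preferred_encoder(beta):
    """Name of the candidate with the highest Eq. 5 objective at this beta."""
    return max(CANDIDATES, key=lambda name: beta_vae_objective(beta=beta, **CANDIDATES[name]))


for beta in (1.0, 4.0):
    print(beta, preferred_encoder(beta))
# 1.0 detailed
# 4.0 compressed
```

With these numbers the break-even point is β = 15/12 = 1.25: above it, the β-weighted saving of 12 nats of KL outweighs the 15 nats lost in reconstruction.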
The β-VAE objective is closely related to the information bottleneck principle (Tishby_etal_2000; Chechik_etal_2005; Achille_Soatto_2016; Alemi_etal_2017):
$$\max\left[I(Z;Y) - \beta\, I(X;Z)\right] \tag{6}$$
where $I(\cdot\,;\cdot)$ stands for mutual information and β is a Lagrange multiplier. The information bottleneck describes a constrained optimisation objective where the goal is to maximise the mutual information between the latent bottleneck $Z$ and the task $Y$ while discarding all the information in the input $X$ that is irrelevant to $Y$. In the information bottleneck literature, $Y$ would typically stand for a classification task; however, the formulation can be related to the auto-encoding objective too (Alemi_etal_2017).
We can gain insight into the pressures shaping the learning of the latent representation in β-VAE by considering the posterior distribution $q(\mathbf{z}\mid\mathbf{x})$ as an information bottleneck for the reconstruction task (Alemi_etal_2017). The β-VAE training objective (Eq. 5) encourages the latent distribution $q(\mathbf{z}\mid\mathbf{x})$ to efficiently transmit information about the data points $\mathbf{x}$ by jointly minimising the β-weighted KL term and maximising the data log likelihood.
In β-VAE, the posterior $q(\mathbf{z}\mid\mathbf{x})$ is encouraged to match the unit Gaussian prior $p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$. Since the posterior and the prior are factorised (i.e. have diagonal covariance matrices), and posterior samples are obtained using the reparametrisation (Eq. 4) of adding scaled independent Gaussian noise $\sigma_i\epsilon_i$ to a deterministic encoder mean $\mu_i$ for each latent unit $z_i$, we can take an information-theoretic perspective and think of $q(\mathbf{z}\mid\mathbf{x})$ as a set of independent additive white Gaussian noise channels $z_i$, each noisily transmitting information about the data inputs $\mathbf{x}$. In this perspective, the KL divergence term of the β-VAE objective (see Eq. 5) can be seen as an upper bound on the amount of information that can be transmitted through the latent channels per data sample (since it is taken in expectation across the data). The KL divergence is zero when $q(z_i\mid\mathbf{x}) = p(z_i)$, i.e. when $\mu_i$ is always zero and $\sigma_i$ is always 1, meaning the latent channels $z_i$ have zero capacity. The capacity of the latent channels can only be increased by dispersing the posterior means across the data points, or by decreasing the posterior variances, both of which increase the KL divergence term.
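The claim that the KL term upper-bounds the information transmitted per sample can be checked in closed form for a single channel. Assume, purely for illustration, that the encoder means are distributed as $\mu \sim \mathcal{N}(0, s^2)$ across the dataset and that the posterior standard deviation $\sigma$ is the same for every data point. Then $z = \mu + \sigma\epsilon$ is a Gaussian channel with $I(\mathbf{x}; z) = \tfrac{1}{2}\log(1 + s^2/\sigma^2)$, while the KL averaged over the data is $\tfrac{1}{2}(s^2 + \sigma^2 - 1 - \log\sigma^2)$. Their difference is $\tfrac{1}{2}(t - 1 - \log t)$ with $t = s^2 + \sigma^2$, which is non-negative and vanishes only when $t = 1$, i.e. when the aggregate posterior matches the prior. The function names below are hypothetical.

```python
import math


def avg_kl_unit_prior(mean_spread, sigma):
    """Data-averaged KL(N(mu, sigma^2) || N(0, 1)) when mu ~ N(0, mean_spread^2)."""
    return 0.5 * (mean_spread**2 + sigma**2 - 1.0 - math.log(sigma**2))


def gaussian_channel_info(mean_spread, sigma):
    """I(x; z) in nats for z = mu(x) + sigma * eps, with mu ~ N(0, mean_spread^2)."""
    return 0.5 * math.log(1.0 + mean_spread**2 / sigma**2)


# Spreading the means or shrinking sigma raises both the information and its KL bound.
for s, sig in [(0.0, 1.0), (1.0, 1.0), (1.0, 0.5), (2.0, 0.5)]:
    print(s, sig, round(avg_kl_unit_prior(s, sig), 4), round(gaussian_channel_info(s, sig), 4))
```

The first row (means at zero, unit variance) gives zero for both quantities, matching the zero-capacity case described above.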
Reconstructing under this bottleneck encourages embedding the data points on a set of representational axes where nearby points on the axes are also close in data space. To see this, following the above, note that the KL can be minimised by reducing the spread of the posterior means, or broadening the posterior variances, i.e. by squeezing the posterior distributions into a shared coding space. Intuitively, we can think about this in terms of the degree of overlap between the posterior distributions across the dataset (Fig. 1). The more they overlap, the broader the posterior distributions will be on average (relative to the coding space), and the smaller the KL divergence can be. However, a greater degree of overlap between posterior distributions will tend to result in a cost in terms of log likelihood due to their reduced average discriminability. A sample drawn from the posterior given one data point may have a higher probability under the posterior of a different data point, an increasingly frequent occurrence as overlap between the distributions is increased. For example, in Figure 1, the sample indicated by the red star might be drawn from the (green) posterior $q(\mathbf{z}\mid\mathbf{x}_1)$, even though it would occur more frequently under the overlapping (blue) posterior $q(\mathbf{z}\mid\mathbf{x}_2)$, and so (assuming $\mathbf{x}_1$ and $\mathbf{x}_2$ were equally probable) an optimal decoder would assign a higher log likelihood to $\mathbf{x}_2$ for that sample. Nonetheless, under a constraint of maximising such overlap, the smallest cost in the log likelihood can be achieved by arranging nearby points in data space close together in the latent space. By doing so, when samples from a given posterior $q(\mathbf{z}\mid\mathbf{x}_1)$ are more likely under another data point such as $\mathbf{x}_2$, the log likelihood cost will be smaller if $\mathbf{x}_1$ is close to $\mathbf{x}_2$ in data space.
A representation learned under a weak bottleneck pressure (as in a standard VAE) can exhibit this locality property in an incomplete, fragmented way. To illustrate this, we trained a standard VAE (i.e. with β = 1) and a β-VAE on a simple dataset with two generative factors of variation: the x and y position of a Gaussian blob (Fig. 2). The standard VAE learns to represent these two factors across four latent dimensions, whereas β-VAE represents them in two. We examine the nature of the learnt latent space by plotting its traversals in Fig. 2: we first infer the posterior $q(\mathbf{z}\mid\mathbf{x})$, then plot the reconstructions resulting from modifying the value of each latent unit $z_i$ one at a time over a fixed range while keeping all the other latents fixed to their inferred values. We can see that the β-VAE representation exhibits the locality property described in Sec. 4.2, since small steps in each of the two learnt directions in the latent space result in small changes in the reconstructions. The VAE representation, however, exhibits fragmentation in this locality property. Across much of the latent space, small traversals produce reconstructions with small, consistent offsets in the position of the sprite, similar to β-VAE.
However, there are noticeable representational discontinuities, at which small latent perturbations produce reconstructions with large or inconsistent position offsets. Reconstructions near these boundaries are often of poor quality or have artefacts such as two sprites in the scene.
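The overlap argument earlier in this section can be made quantitative in one dimension. Assuming two posteriors with equal standard deviation $\sigma$ whose means are a distance $d$ apart, a sample from the first has a higher density under the second exactly when it falls past the midpoint, which happens with probability $\Phi(-d/(2\sigma))$. This approaches 1/2 as the posteriors broaden relative to their spacing and shrinks as they separate. The sketch computes this in closed form and checks it by Monte Carlo; the equal-variance 1-D setting and the function names are illustrative assumptions.

```python
import math
import random


def confusion_prob(mean_gap, sigma):
    """P(a sample from N(m1, sigma^2) is more likely under N(m2, sigma^2)), |m1 - m2| = mean_gap.

    With equal variances the two densities cross at the midpoint, so this is the
    probability of landing beyond it: Phi(-mean_gap / (2 * sigma)).
    """
    x = -mean_gap / (2.0 * sigma)
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def confusion_prob_mc(mean_gap, sigma, n=100_000, seed=0):
    """Monte Carlo estimate of the same quantity with m1 = 0 and m2 = mean_gap > 0."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(n):
        z = rng.gauss(0.0, sigma)
        # Equal variances: the density is higher under whichever mean is closer.
        if (z - mean_gap) ** 2 < z ** 2:
            hits += 1
    return hits / n


# Broader posteriors (larger sigma) at a fixed spacing are confused more often.
for sigma in (0.25, 0.5, 1.0, 2.0):
    print(sigma, round(confusion_prob(1.0, sigma), 4))
```

The Monte Carlo version assumes a strictly positive gap; at a gap of zero the two densities tie everywhere and the closed form's 1/2 is the natural convention.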
We have seen how a strong pressure for overlapping posteriors encourages β-VAE to find a representation space preserving as much as possible the locality of points on the data manifold. However, why would it find representational axes that are aligned with the generative factors of variation in the data? Our key hypothesis is that β-VAE finds latent components which make different contributions to the log-likelihood term of the cost function (Eq. 5). These latent components tend to correspond to features in the data that are intuitively qualitatively different, and therefore may align with the generative factors in the data.
For example, consider optimising the β-VAE objective shown in Eq. 5 under an almost complete information bottleneck constraint (i.e. a very large β). The optimal thing to do in this scenario is to only encode information about the data points that yields the most significant improvement in data log-likelihood (i.e. in $\mathbb{E}_{q(\mathbf{z}\mid\mathbf{x})}[\log p(\mathbf{x}\mid\mathbf{z})]$). For example, in the dSprites dataset (dsprites17), which consists of white 2D sprites varying in position, rotation, scale and shape rendered onto a black background, the model might only encode the sprite position under such a constraint. Intuitively, when optimising a pixel-wise decoder log likelihood, information about position will result in the most gains compared to information about any of the other factors of variation in the data, since the likelihood will vanish if the reconstructed position is off by just a few pixels. Continuing this intuitive picture, we can imagine that if the capacity of the information bottleneck were gradually increased, the model would continue to utilise those extra bits for an increasingly precise encoding of position, until some point of diminishing returns is reached for position information, where a larger improvement can be obtained by encoding and reconstructing another factor of variation in the dataset, such as sprite scale.
At this point we can ask what pressures could encourage this new factor of variation to be encoded into a distinct latent dimension. We hypothesise that two properties of β-VAE encourage this. Firstly, embedding this new axis of variation of the data into a distinct latent dimension is a natural way to satisfy the data locality pressure described in Sec. 4.2. A smooth representation of the new factor will allow an optimal packing of the posteriors in the new latent dimension, without affecting the other latent dimensions. We note that this pressure alone would not discourage the representational axes from rotating relative to the factors. However, given the differing contributions each factor makes to the reconstruction log-likelihood, the model will try to allocate appropriately differing average capacities to the encoding axes of each factor (e.g. by optimising the posterior variances). The diagonal covariance of the posterior distribution restricts the model to doing this in different latent dimensions, which gives the second pressure: it encourages the latent dimensions to align with the factors.
We tested these intuitions by training a simplified model to generate dSprites conditioned on the ground-truth factors, $\mathbf{f}$, with a controllable information bottleneck (Fig. 3). In particular, we wanted to evaluate how much information the model would choose to retain about each factor in order to best reconstruct the corresponding images given a total capacity constraint. In this model, the factors are each independently scaled by a learnable parameter, and are subject to independently scaled additive noise (also learned), similar to the reparametrised latent distribution in β-VAE. This enables us to form a KL divergence of this factor distribution with a unit Gaussian prior. We trained the model to reconstruct the images with samples from the factor distribution, but with a range of different target encoding capacities, by pressuring the KL divergence to be at a controllable value, $C$. The training objective combined maximising the log likelihood and minimising the absolute deviation from $C$ (with a hyperparameter γ controlling how heavily to penalise the deviation; see Sec. A.2):
$$\mathcal{L}(\theta,\phi;\mathbf{x},\mathbf{f},C) = \mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{f})}\left[\log p_\theta(\mathbf{x}\mid\mathbf{z})\right] - \gamma\,\left|D_{KL}\big(q_\phi(\mathbf{z}\mid\mathbf{f}) \,\|\, p(\mathbf{z})\big) - C\right| \tag{7}$$
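The objective in Eq. 7 differs from Eq. 5 in how the KL term is penalised. In the minimal sketch below, with hypothetical function names and scalar inputs standing in for batch averages, the absolute-deviation penalty rewards raising the KL while it is below the target $C$ and penalises it above, whereas the β-weighted term in Eq. 5 always pushes the KL down with slope β.

```python
def capacity_objective(recon_log_lik, kl, capacity, gamma):
    """Eq. 7 (to be maximised): log-likelihood minus gamma * |KL - C|."""
    return recon_log_lik - gamma * abs(kl - capacity)


def kl_slope_of_objective(kl, capacity, gamma):
    """(Sub)gradient of the objective with respect to the KL value.

    +gamma below the target (more KL is rewarded), -gamma above it (more KL is
    penalised), and 0 exactly at the target, where |.| is not differentiable.
    """
    if kl < capacity:
        return gamma
    if kl > capacity:
        return -gamma
    return 0.0


# With a large gamma, any gap between the actual KL and C dominates the objective.
print(capacity_objective(-50.0, 3.0, 5.0, 1000.0))
print(capacity_objective(-50.0, 5.0, 5.0, 1000.0))
```

Because the penalty is two-sided, a model whose KL lags behind $C$ is actively pushed to transmit more information, which is what allows the target to act as a controllable capacity.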
In practice, a single model was trained across a range of $C$ values by linearly increasing $C$ from a low value (0.5 nats) to a high value (25.0 nats) over the course of training (see top left panel in Fig. 3). Consistent with the intuition outlined above, at very low capacities the KLs for all the factors except the X and Y position factors are zero, with the KL always shared equally between X and Y. As expected, the model reconstructions in this range are blurry, only capturing the position of the original input shapes (see the bottom row of the lower panel in Fig. 3). However, as $C$ is increased, the KLs of other factors start to increase from zero, at distinct points for each factor. For example, beyond a certain capacity the KL for the scale factor begins to climb from zero, and the model reconstructions become scaled (see the 7.3 nats row in the lower panel of Fig. 3). This pattern continues until all factors have a non-zero KL, and eventually the reconstructions look almost identical to the input samples.
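The linear ramp of $C$ described above can be sketched as a schedule function. The 0.5 and 25.0 nat endpoints come from the text; the number of training steps is not stated for this experiment, so `total_steps` is a free parameter (the 100,000 below is only illustrative), and holding $C$ at its final value once the ramp ends is an assumption of this sketch.

```python
def linear_capacity(step, total_steps, c_start, c_end):
    """Target KL C (in nats) at a training step.

    Ramps linearly from c_start to c_end over total_steps, then stays at c_end.
    """
    if total_steps <= 0:
        return c_end
    frac = min(max(step / total_steps, 0.0), 1.0)  # clamp progress to [0, 1]
    return c_start + frac * (c_end - c_start)


print([linear_capacity(s, 100_000, 0.5, 25.0) for s in (0, 50_000, 100_000, 150_000)])
# [0.5, 12.75, 25.0, 25.0]
```

At each step the returned value would be passed as `capacity` to an objective of the form in Eq. 7.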
The intuitive picture we have developed of gradually adding more latent encoding capacity, enabling progressively more factors of variation to be represented whilst retaining disentangling in previously learned factors, motivated us to extend β-VAE with this algorithmic principle. We applied the capacity control objective from the ground-truth generator in the previous section (Eq. 7) to β-VAE, allowing control of the encoding capacity (again, via a target KL, $C$) of the VAE’s latent bottleneck, to obtain the modified training objective:
$$\mathcal{L}(\theta,\phi;\mathbf{x},\mathbf{z},C) = \mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}\left[\log p_\theta(\mathbf{x}\mid\mathbf{z})\right] - \gamma\,\left|D_{KL}\big(q_\phi(\mathbf{z}\mid\mathbf{x}) \,\|\, p(\mathbf{z})\big) - C\right| \tag{8}$$
Similar to the generator model, $C$ is gradually increased from zero to a value large enough to produce good quality reconstructions (see Sec. A.2 for more details).
Results from training with controlled capacity increase on coloured dSprites can be seen in Figure 3(a), which demonstrate very robust disentangling of all the factors of variation in the dataset and high quality reconstructions.
Single traversals of each latent dimension show changes in the output samples isolated to single data generative factors (second row onwards, with the traversed latent dimensions ordered by their average KL divergence with the prior, from high to low). For example, we can see that traversal of the latent with the largest KL produces smooth changes in the Y position of the reconstructed shape without changes in other factors. The picture is similar with traversals of the subsequent latents, with changes isolated to X position, scale, shape, rotation, then a set of three colour axes (the last two latent dimensions have an effectively zero KL and produce no effect on the outputs).
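The ordering used in these traversal figures (latents sorted by their average KL with the prior) and the traversal procedure itself can be sketched as follows. This assumes the encoder returns per-dimension means and log standard deviations for a batch of inputs; the decoder call that turns each code into an image is omitted, the function names are hypothetical, and the traversal values are left to the caller because the range used for the figures is not given here.

```python
import math


def per_dim_avg_kl(mus, log_sigmas):
    """Batch-averaged KL(N(mu_i, sigma_i^2) || N(0, 1)) for each latent dimension i."""
    n, d = len(mus), len(mus[0])
    totals = [0.0] * d
    for mu, ls in zip(mus, log_sigmas):
        for i in range(d):
            totals[i] += 0.5 * (mu[i] ** 2 + math.exp(2.0 * ls[i]) - 1.0 - 2.0 * ls[i])
    return [t / n for t in totals]


def traversal_codes(z, avg_kl, values):
    """One (dimension, codes) row per latent, ordered by average KL from high to low.

    Each code copies the inferred latent vector z and overwrites a single
    dimension with a traversal value, keeping every other latent fixed.
    """
    order = sorted(range(len(z)), key=lambda i: avg_kl[i], reverse=True)
    rows = []
    for i in order:
        codes = []
        for v in values:
            code = list(z)
            code[i] = v
            codes.append(code)
        rows.append((i, codes))
    return rows


kl = per_dim_avg_kl([[0.0, 2.0], [0.0, -2.0]], [[0.0, 0.0], [0.0, 0.0]])
for dim, codes in traversal_codes([0.1, 0.2], kl, [-1.0, 0.0, 1.0]):
    print(dim, round(kl[dim], 3), codes)
```

Dimensions whose average KL is close to zero carry almost no information about the input, which is consistent with the observation that the lowest-KL latents have no visible effect when traversed.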
Furthermore, the quality of the traversal images is high, and by eye, the model reconstructions (second row) are quite difficult to distinguish from the corresponding data samples used to generate them (top row). This contrasts with the results previously obtained with the fixed β-modulated KL objective (Higgins_etal_2017).
We also trained the same model on the 3D Chairs dataset (Aubry_etal_2014), with latent traversals shown in Figure 3(b). We can see that reconstructions are of high quality, and traversals of the latent dimensions produce smooth changes in the output samples, with reasonable-looking chairs in all cases. With this richer dataset it is unclear exactly what the disentangled axes should correspond to; however, each traversal appears to generate changes isolated to one or a few qualitative features that we might identify intuitively, such as viewing angle, size, and chair leg and back styles.
We have developed new insights into why β-VAE learns an axis-aligned disentangled representation of the generative factors of visual data compared to the standard VAE objective. In particular, we identified pressures which encourage β-VAE to find a set of representational axes which best preserve the locality of the data points, and which are aligned with factors of variation that make distinct contributions to improving the data log likelihood. We have demonstrated that these insights produce an actionable modification to the β-VAE training regime. We proposed controlling the increase of the encoding capacity of the latent posterior during training, by allowing the average KL divergence with the prior to gradually increase from zero, rather than using the fixed β-weighted KL term in the original β-VAE objective. We show that this promotes robust learning of disentangled representations combined with better reconstruction fidelity, compared to the results achieved with the original formulation (Higgins_etal_2017).
DRAW: A recurrent neural network for image generation. ICML, 37:1462–1471, 2015.
Stochastic backpropagation and approximate inference in deep generative models. ICML, 32(2):1278–1286, 2014.
Multi-view perceptron: a deep model for learning face identity and view representations. Advances in Neural Information Processing Systems 27, 2014.
The neural network models used for experiments in this paper all utilised the same basic architecture. The encoder for the VAEs consisted of 4 convolutional layers, each with 32 channels, 4x4 kernels, and a stride of 2. This was followed by 2 fully connected layers, each of 256 units. The latent distribution consisted of one fully connected layer of 20 units parametrising the mean and log standard deviation of 10 Gaussian random variables (or 32 for the CelebA experiment). The decoder architecture was simply the transpose of the encoder, but with the output parametrising Bernoulli distributions over the pixels. ReLU activations were used throughout. The optimiser used was Adam [21] with a learning rate of 5e-4. The deviation penalty hyperparameter γ used was 1000, which was chosen to be large enough to ensure the actual KL was always close to the target KL, C. For dSprites, C was linearly increased from 0 to 25 nats over the course of 100,000 training iterations; for CelebA, it was increased to 50 nats.
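The encoder described above can be sanity-checked by tracking tensor shapes. The sketch assumes 64x64 inputs (the size of dSprites images) and a padding of 1 for each 4x4, stride-2 convolution, neither of which is stated in the text; under those assumptions the spatial resolution halves at every layer. It also collects the stated training hyperparameters in a plain dictionary whose key names, like the function names and the means-then-log-stds layout of the latent layer, are hypothetical.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    # Standard convolution output-size formula; padding=1 is an assumption.
    return (size + 2 * padding - kernel) // stride + 1


def encoder_shapes(input_size=64, n_conv=4, channels=32, fc_units=(256, 256), n_latents=10):
    """Spatial sizes through the conv stack and widths of the layers that follow."""
    sizes = [input_size]
    for _ in range(n_conv):
        sizes.append(conv_out(sizes[-1]))
    return {
        "spatial_sizes": sizes,
        "flattened": channels * sizes[-1] * sizes[-1],
        "fc_units": list(fc_units),
        "latent_layer_units": 2 * n_latents,  # one mean and one log std per latent
    }


def split_latent_params(h, n_latents=10):
    # Assumed layout: first half are means, second half are log standard deviations.
    return h[:n_latents], h[n_latents:2 * n_latents]


# Hyperparameters stated in the text (dSprites settings).
CONFIG = {
    "learning_rate": 5e-4,               # Adam
    "gamma": 1000.0,                     # deviation penalty in Eq. 8
    "capacity_nats": (0.0, 25.0),        # C ramps linearly from 0 to 25 nats
    "capacity_ramp_steps": 100_000,
}

print(encoder_shapes())
print(encoder_shapes(n_latents=32)["latent_layer_units"])  # 64
```

Under these assumptions the last convolution leaves a 4x4x32 feature map, so the first fully connected layer sees 512 inputs.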