
Learning Interpretable Disentangled Representations using Adversarial VAEs

by Mhd Hasan Sarhan, et al.

Learning interpretable representations in medical applications is becoming essential for adopting data-driven models into clinical practice. It has recently been shown that learning a disentangled feature representation is important for a more compact and explainable representation of the data. In this paper, we introduce a novel adversarial variational autoencoder with a total correlation constraint to enforce independence on the latent representation while preserving the reconstruction fidelity. Our proposed method is validated on a publicly available dataset, showing that the learned disentangled representation is not only interpretable, but also superior to the state-of-the-art methods. We report a relative improvement of 81.50% in disentanglement and 11.60% in clustering, as well as a gain in supervised classification with a small amount of labeled data.


1 Introduction

Data-driven models based on Deep Learning (DL) are affecting wide areas of scientific research, and the medical domain is no exception. In healthcare, however, developing a machine learning algorithm with expert-level performance is important but not sufficient for its adoption when the issues of trust and explainability are not taken into consideration [12]. Explainability of a model is approached either 1) by explicitly learning it through model design or 2) after model design, for example by using gradient-based localization [13].

Approaching explainability by model design can be facilitated in a supervised manner, as in decision trees and rule-based systems, or in an unsupervised manner, as in the Variational Autoencoder (VAE) [9] or β-VAE [6]. In the latter, a lower-dimensional representation of the data is learned and utilized for analyzing the data. The rest of the paper discusses this type of explainability. Deep learning models extract features from data in order to represent it in a compressed high-level representation that suits the application. The quality of this representation is crucial for the model performance, and it is argued that disentangled representations would be helpful for having better control and interpretability over the data [1, 6]. A disentangled representation can be defined as a representation where one latent unit represents one generative factor of variation in the data while being invariant to other generative factors [1]. For example, a model trained on a dataset of faces would learn disentangled latent units that represent independent ground truth generative factors such as hair color, pose, lighting or skin color. Disentangling as many explanatory factors as possible is important for a more compact, explainable, transferable, abstract representation of the data [1].

Figure 1: Comparison of our model to VAE on examples of traversal over the representation components. Traversal is done over [-3, 3]. (a) Examples of traversal for three images from ISIC 2018. Each row shows reconstructions of latent traversals across one latent dimension. (b) Example of a smooth transition over the manifold by changing multiple latent dimensions to go from a small lesion on pale skin (top left image) to a bigger horizontal lesion on red skin (bottom right image). Each column represents one dimension of change. The colored squares represent the image of the previous column from which the traversal has started on the current dimension.

Most of the previous work on disentanglement relied on information about the number or nature of the ground truth generative factors [7, 10]. In medical applications, the data is complex and a priori knowledge about the generative factors is mostly unavailable. Recently, multiple models for unsupervised disentangled feature learning were proposed [3, 6, 8, 2]. β-VAE [6] is proposed as a modification of VAE [9] where the parameter β is used to put more emphasis on the KL-divergence part of the VAE objective. This forces the posterior to match the factorized Gaussian prior, which constrains the bottleneck representation to be factorized while still reconstructing the data. Higher β values encourage more disentangled representations with a trade-off on the reconstruction error. In β-Total Correlation VAE (β-TCVAE) [2], the training is focused on the total correlation part of the KL term, which is responsible for the factorized representation. This lowers the trade-off on the reconstruction fidelity introduced by β-VAE. β-TCVAE is validated on examples from a controlled environment with clear factors of generation. This does not represent the complexity of medical data and should be addressed.


In this work, we propose a framework for learning disentangled representations in medical imaging in an unsupervised manner. To our knowledge, this is the first work that analyzes the strength of unsupervised disentangled feature representations in medical imaging and proposes a framework that is well suited to medical applications. We propose a novel residual adversarial VAE with a total correlation constraint. This enhances the fidelity of the reconstruction and captures more details that better describe the underlying generative factors.

2 Methodology

We utilize deep generative disentangled representation learning to learn the distribution of a medical imaging dataset. We then use the learned representation to generate images while controlling some generative factors. We first show how disentanglement is approached with β-VAE as a motivation for incorporating β-TCVAE. We then present our contributions to the disentanglement framework by utilizing an adversarial loss with residual blocks to enhance the disentanglement and reduce the compromise on the reconstruction. We hypothesize that using an adversarial loss with residual blocks in a disentanglement framework would result in higher quality representations with more disentanglement in the feature space.

Let $X = \{x_i\}_{i=1}^{N}$ be a set of images generated by combinations of $K$ ground-truth generative factors $V = \{v_k\}_{k=1}^{K}$. Our aim is to build an unsupervised generative model that utilizes only the images in $X$ to learn the joint distribution of the images and the set of latent generative factors $z$, allowing us to have better control and interpretability of the latent space. It is worth mentioning that the latent generative factors capture both disentangled and entangled factors. To realize our aim, we follow the concept of β-VAE in learning a posterior distribution that can be used to generate images from $X$. The posterior representation is approximated by $q_\phi(z|x)$. The model is built such that the generative factors are represented by the posterior bottleneck in a disentangled fashion.

In β-VAE, implicit independence is enforced on the posterior to encourage a disentangled representation. This is done by constraining the posterior $q_\phi(z|x)$ to match a prior $p(z)$. The prior is set to be an isotropic unit Gaussian ($p(z) = \mathcal{N}(0, I)$). Adding extra pressure on the posterior to match $p(z)$ constrains the capacity of the bottleneck and pushes it to be factorized [6]. Thus, the objective function for β-VAE is as follows

$$\mathcal{L}(\theta, \phi; x, z, \beta) = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \beta \, D_{KL}\left(q_\phi(z|x) \,\|\, p(z)\right), \tag{1}$$

where $\phi$ and $\theta$ are the trainable weights of the encoder and decoder, respectively, and $D_{KL}$ is the Kullback-Leibler divergence. When $\beta = 1$, we get the original VAE loss [9]. For disentanglement, values of $\beta > 1$ are typically chosen. Using this formula enhances the disentanglement at the cost of reconstruction fidelity. It is suggested by [2] that the total correlation term within $D_{KL}$ is responsible for the factorized representation. Hence, focusing the training on the total correlation would result in better disentanglement while having less effect on the reconstruction. The objective function changes such that $D_{KL}$ is decomposed and $\beta$ is now only multiplied by the total correlation term as follows

$$\mathcal{L}_{\beta\text{-TC}} = \mathbb{E}_{q(z|x)p(x)}\left[\log p(x|z)\right] - I_q(z; x) - \beta \, D_{KL}\Big(q(z) \,\Big\|\, \prod_j q(z_j)\Big) - \sum_j D_{KL}\left(q(z_j) \,\|\, p(z_j)\right). \tag{2}$$
The second term is the mutual information between the data and the latent variable. Penalizing this term reduces the amount of information about $x$ that is represented in $z$, which in turn could decrease the reconstruction performance. The third term is the total correlation (TC), which is a generalization of mutual information to more than two variables. Penalizing TC forces independence in the represented factors. The last term is referred to as the dimension-wise KL and is applied to individual latent dimensions. We use β-TCVAE for its good disentanglement results on various datasets, for its better reconstruction than other disentanglement models, and for the parameter-less approximation of $q(z)$. For more details about the decomposition and the approximation of $q(z)$, the reader is referred to [2].
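To make the total correlation term more concrete, the following minimal sketch evaluates it in closed form for a zero-mean Gaussian latent distribution, where TC equals half the difference between the sum of the log marginal variances and the log-determinant of the covariance. The Gaussian setting and the function names are assumptions chosen for illustration only; the model itself estimates TC from minibatches as described in [2].

```python
import math


def cholesky_logdet(cov):
    """Log-determinant of a symmetric positive-definite matrix via Cholesky."""
    n = len(cov)
    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(lower[i][k] * lower[j][k] for k in range(j))
            if i == j:
                lower[i][j] = math.sqrt(cov[i][i] - s)
            else:
                lower[i][j] = (cov[i][j] - s) / lower[j][j]
    return 2.0 * sum(math.log(lower[i][i]) for i in range(n))


def gaussian_total_correlation(cov):
    """TC = KL(q(z) || prod_j q(z_j)) for q(z) = N(0, cov) (illustrative assumption)."""
    marginal_term = sum(math.log(cov[i][i]) for i in range(len(cov)))
    return 0.5 * (marginal_term - cholesky_logdet(cov))


# Independent latent dimensions: TC vanishes.
independent = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
# Two correlated dimensions: TC reduces to their mutual information.
correlated = [[1.0, 0.5], [0.5, 1.0]]

tc_independent = gaussian_total_correlation(independent)
tc_correlated = gaussian_total_correlation(correlated)
```

For two variables the value coincides with the mutual information of a bivariate Gaussian, which illustrates why TC is described as a generalization of mutual information, and why driving it toward zero pushes the latent dimensions toward independence.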

To enhance the fidelity of the reconstructions and improve the generative factors captured by $z$, we add a discriminator network on top of the β-TCVAE model. The discriminator is trained to decide whether an input image is generated synthetically or sampled from the real data distribution. We employ an adversarial loss scheme for the training. The discriminator in this scenario has to implicitly learn a rich similarity metric based on features extracted from the images rather than relying only on pixel-wise similarity. This not only improves the generated images visually, but also yields a richer representation in the code $z$ [11]. This is because the pixel-wise loss acts as a content loss while the discriminator loss acts as a style loss [4]. Moreover, we incorporate residual blocks rather than the convolutional layers applied in [2], because residual blocks have shown a better flow of the gradients. This limits the problems related to vanishing/exploding gradients [5] and is used in the state-of-the-art Generative Adversarial Nets (GANs) literature [15] for more stable training. We denote by $D_\psi$ the discriminator network with trainable parameters $\psi$; $x$ is a real image sampled from $X$ and $\hat{x}$ is the reconstructed image from $z$. The final objective is


The model is trained by alternating between and optimization. We use pixel-wise -distance between and as .
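The role of the discriminator in the alternating scheme can be sketched with the standard binary cross-entropy GAN losses. This is an assumption made for illustration: the exact final objective of the paper is not reproduced here, the non-saturating generator term is a common substitute rather than a confirmed choice, and all function names are hypothetical. Inputs are raw discriminator outputs (logits) for real images and for reconstructions.

```python
import math


def _softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def discriminator_loss(real_logits, fake_logits):
    """Binary cross-entropy: real images labelled 1, reconstructions labelled 0."""
    real_term = sum(_softplus(-r) for r in real_logits) / len(real_logits)
    fake_term = sum(_softplus(f) for f in fake_logits) / len(fake_logits)
    return real_term + fake_term


def generator_adversarial_loss(fake_logits):
    """Non-saturating term: the encoder-decoder is rewarded when reconstructions look real."""
    return sum(_softplus(-f) for f in fake_logits) / len(fake_logits)


# An undecided discriminator (logit 0 everywhere) pays log(2) per term.
undecided = discriminator_loss([0.0, 0.0], [0.0, 0.0])
# A confident, correct discriminator has a loss close to zero.
confident = discriminator_loss([10.0, 10.0], [-10.0, -10.0])
```

In an alternating schedule, one step would lower `discriminator_loss` with the encoder and decoder held fixed, and the next would lower the reconstruction and regularization terms together with `generator_adversarial_loss` while the discriminator is held fixed.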

3 Experiments

Experimental validation evaluates the proposed framework in two main experiments. First, we compare the disentanglement performance of our proposed method to state-of-the-art methods in learning both entangled and disentangled representations. We also utilize the learned representations in two use-cases, namely unsupervised clustering and supervised classification with a small number of labels. In the second experiment, we evaluate the results visually and analyze the interpretable learned representation.


We opt for the publicly available Skin Lesion dataset from ISIC 2018 Challenge [14] to perform our validations. To train our model, we utilize the dataset of Task 3 which consists of RGB images with 7 types of skin lesions capturing 7 pathological generative factors. To evaluate the model against ground-truth generative factors, i.e. eccentricity, orientation, and size, we utilize the dataset of Task 2 which consists of images with pixel-wise segmentation. Note that all images are down-sampled to .

Evaluation metrics:

To quantitatively evaluate the disentanglement quality, we report the Mutual Information Gap (MIG) metric as proposed in [2]. As opposed to the disentanglement metric in [6], MIG takes axis-alignment (one $v_k$ is captured by one $z_j$) into consideration, and it is unbiased with respect to hyper-parameters, in contrast to [6, 8]. MIG measures the mutual information (MI) between each latent dimension $z_j$ and the known generative factor $v_k$; the difference between the two highest MIs of a generative factor is then calculated and normalized by the entropy of $v_k$. The average MIG is then computed as

$$\mathrm{MIG} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{H(v_k)} \Big( I(z_{j^{(k)}}; v_k) - \max_{j \neq j^{(k)}} I(z_j; v_k) \Big),$$

where $j^{(k)} = \arg\max_j I(z_j; v_k)$, $H$ is the entropy, and $I$ is the mutual information. For our experiments, we set the generative factors as follows:

  1. MIG Pathologies: The ground truth classes are used as generative factors in a one-vs-all fashion. For instance, there are 7 such factors for the Skin Lesion dataset. Each generative factor has two possible values in this scenario.

  2. MIG Handcrafted Factors: In addition, we handcrafted a few generative factors that are easily visible in the image space, e.g. geometric and morphological changes. To do so, the segmentation masks given in Task 2 are utilized. The handcrafted factors are eccentricity, orientation, and size (i.e. three factors). Each generative factor has two possible values.
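The MIG computation described above can be sketched as follows for latent dimensions that have already been discretized. This is an illustrative sketch rather than the evaluation code of [2]: the function names are hypothetical, natural logarithms are used, and continuous latent codes are assumed to have been binned beforehand.

```python
import math
from collections import Counter


def entropy(values):
    """Empirical entropy (in nats) of a discrete sequence."""
    n = len(values)
    return -sum((c / n) * math.log(c / n) for c in Counter(values).values())


def mutual_information(a, b):
    """I(a; b) = H(a) + H(b) - H(a, b) for two discrete sequences."""
    return entropy(a) + entropy(b) - entropy(list(zip(a, b)))


def mutual_information_gap(latents, factors):
    """Average MIG over factors.

    latents: list of discretized latent dimensions, each a list over samples
             (at least two dimensions are required).
    factors: list of discrete generative factors, each a list over samples.
    """
    gaps = []
    for v in factors:
        mis = sorted((mutual_information(z, v) for z in latents), reverse=True)
        gaps.append((mis[0] - mis[1]) / entropy(v))
    return sum(gaps) / len(gaps)


# One binary factor captured by exactly one latent dimension: MIG is 1.
factor = [0, 1, 0, 1]
aligned = mutual_information_gap([[0, 1, 0, 1], [0, 0, 0, 0]], [factor])
# The same factor duplicated in two latent dimensions: MIG drops to 0.
duplicated = mutual_information_gap([[0, 1, 0, 1], [0, 1, 0, 1]], [factor])
```

The second case shows why the gap rewards axis-alignment: a factor spread over several latent dimensions has a small difference between its two highest mutual information values, even though each dimension is informative on its own.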

In addition, we report the Peak Signal-to-Noise Ratio (PSNR), Normalized Mutual Information (NMI), and Accuracy (ACC) to evaluate the reconstruction error, clustering, and classification, respectively.
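As a reference for the reconstruction metric, PSNR can be computed from the mean squared error between an image and its reconstruction. The sketch below assumes images flattened into lists of pixel intensities scaled to [0, 1]; the function name and the `max_value` parameter are illustrative choices, not taken from the paper.

```python
import math


def psnr(original, reconstruction, max_value=1.0):
    """Peak signal-to-noise ratio in dB for two equally sized pixel lists."""
    mse = sum((a - b) ** 2 for a, b in zip(original, reconstruction)) / len(original)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


# A uniform error of 0.1 on a [0, 1] intensity scale gives an MSE of 0.01.
score = psnr([0.0, 0.0, 0.0, 0.0], [0.1, 0.1, 0.1, 0.1])
```

Higher values indicate a reconstruction closer to the input, which is the sense in which PSNR is compared across models.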


We compare the proposed model to two representation learning models. The first is the VAE [9] model, which does not take disentanglement into account explicitly. The second is β-TCVAE [2], which adds constraints on the representation to disentangle the components. Further, we employ two variations of our proposed method with bottleneck residual blocks [5]: 1) without the adversarial loss in Equation 3, denoted as Ours-resnet; and 2) with the adversarial loss, denoted as Ours-adv.

Implementation details:

We implement the same architecture that appeared in the CelebA experiments in [2] for both VAE and β-TCVAE. For our proposed method, we replace the convolutional layers with bottleneck residual blocks for both Ours-resnet and Ours-adv, while the additional discriminator network in Ours-adv has the same architecture as the encoder except for the last layer, which has a single output. All models are trained using Adam optimizer for iterations with a minibatch size of 256, and a learning rate of . and are set to 6 and 32, respectively. Note that we employ leakyReLU in Ours-adv, which has been successfully applied in the adversarial training literature.

Comparison with state-of-the-art:

We compare our method with recent state-of-the-art methods by reporting the evaluation metrics (cf. Table 1). We notice improvements over β-TCVAE in terms of disentanglement on both MIG Pathologies and MIG Handcrafted Factors. For the reconstruction error, it is expected that VAE would be superior to the other models because there is no extra focus on the prior-constraining part of the loss function, which allows the reconstruction error to be optimized better. However, we notice an improvement in PSNR compared to the β-TCVAE model, which compromises reconstruction error for disentanglement. This experiment shows that adding the bottleneck residual blocks together with adversarial training not only improves the disentanglement, but also improves the reconstruction quality.


To show that the disentangled representation captures meaningful generative factors that might be relevant to the task at hand, we design two use-cases in both unsupervised and supervised paradigms. For the clustering use-case, we utilize the learned representations to fit a Gaussian Mixture Model (GMM) with 7 components and assign a label to each data point. NMI is then calculated between the assigned labels and the ground-truth labels. We report an average of 10 realizations. Regarding the classification use-case, we utilize the learned representations of a small amount of labeled data to train a multi-layer perceptron (MLP) on a fraction of the data and evaluate it on the remaining data. 10-fold stratified cross-validation is performed.

The model gives a relative improvement on both NMI and ACC. This could be attributed to the quality of the learned representation, where features responsible for the pathologies are captured by disentanglement models as generative factors.
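The clustering score used in this use-case can be sketched as follows. The snippet only covers the NMI between assigned cluster labels and ground-truth labels; fitting the GMM itself is omitted. Normalizing by the arithmetic mean of the two entropies is an assumption, since the paper does not state which normalization is used, and the function names are hypothetical.

```python
import math
from collections import Counter


def _entropy(labels):
    """Empirical entropy (in nats) of a label sequence."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())


def normalized_mutual_information(assigned, truth):
    """NMI between two labelings, normalized by the mean of their entropies."""
    h_assigned, h_truth = _entropy(assigned), _entropy(truth)
    mi = h_assigned + h_truth - _entropy(list(zip(assigned, truth)))
    denom = 0.5 * (h_assigned + h_truth)
    return mi / denom if denom > 0 else 1.0


# Cluster ids are arbitrary: a relabeled but otherwise perfect clustering scores 1.
perfect = normalized_mutual_information([1, 1, 0, 0], [0, 0, 1, 1])
# A clustering that is independent of the true classes scores 0.
unrelated = normalized_mutual_information([0, 1, 0, 1], [0, 0, 1, 1])
```

Because the score is invariant to how the cluster ids are numbered, it suits the unsupervised setting, where the GMM components carry no inherent correspondence to the pathology classes.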

Model         MIG Pathologies   MIG Handcrafted   PSNR    NMI     ACC
VAE           5.23              2.74              22.91   9.12    67.88
β-TCVAE       6.92              3.53              20.79   10.66   68.61
Ours-resnet   11.61             5.89              19.42   9.89    69.19
Ours-adv      12.57             9.24              21.18   11.86   70.02
Table 1: Comparison of various representation learning models.


We qualitatively examine the interpretability of the learned representations by manipulating the latent code. For instance, Fig. 1(a) shows a comparison of the traversal between the proposed model and VAE. We notice that the dimension responsible for changing skin color has some entanglement with eccentricity and size in the case of VAE. In contrast, we can see in our proposed model that the size and eccentricity are barely changed when the skin color dimension is changed. For eccentricity, we notice in the case of VAE that fewer variations are captured, such as the absence of the horizontal elliptic lesions that are captured with the proposed approach.

In Fig. 1(b), we show the possibility of generating images with specific features by smoothly moving over the manifold of the representations. We show the transition of a small lesion on pale skin to a big horizontal lesion on reddish skin by changing multiple latent dimensions responsible for each feature. Having this control over the representation not only gives the ability to generate images with specific known features, but also gives an interpretable representation of the data which can be utilized in many applications.
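The latent traversals discussed in this section can be sketched as generating a sequence of latent codes in which a single dimension is swept over [-3, 3] while the others stay fixed. The helper below is hypothetical, including its name and the number of steps, and it only produces the codes; each code would then be passed through the trained decoder to obtain the corresponding image.

```python
def latent_traversal(z, dim, low=-3.0, high=3.0, steps=7):
    """Return copies of z with dimension `dim` swept evenly over [low, high]."""
    codes = []
    for i in range(steps):
        value = low + (high - low) * i / (steps - 1)
        code = list(z)  # copy so the input code is left untouched
        code[dim] = value
        codes.append(code)
    return codes


# Sweep the second latent dimension of a 3-dimensional code.
codes = latent_traversal([0.5, -1.0, 2.0], dim=1)
```

Chaining several such sweeps, each starting from the last code of the previous one, yields a path over multiple dimensions of the kind used for the smooth transition described above.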

4 Discussion

In this paper, we introduce a novel adversarial VAE with a total correlation constraint to enforce disentanglement on the latent representation while preserving the reconstruction fidelity. The proposed framework is evaluated on a skin lesion dataset and shows improvements over other state-of-the-art methods in terms of disentanglement. The disentangled representations learned by the proposed method have shown remarkable performance in both unsupervised clustering and supervised classification. We believe that our work would pave the way for other researchers to further investigate this interesting direction of research. One potential direction is utilizing the control over the generative factors for data augmentation.