Data-driven models built with Deep Learning (DL) are affecting wide areas of scientific research, and the medical domain is no exception [12]. However, in healthcare, developing a machine learning algorithm with expert-level performance is necessary but not sufficient for adoption when the issues of trust and explainability are not taken into consideration. Explainability of a model is approached either 1) explicitly, by learning it through model design, or 2) post hoc, e.g., using gradient-based localization [13].
Most of the previous work on disentanglement relied on information about the number or nature of the ground-truth generative factors [7, 10]. In medical applications, the data is complex and a priori knowledge about the generative factors is mostly unavailable. Recently, multiple models for unsupervised disentangled feature learning were proposed [3, 6, 8, 2]. β-VAE [6] is a modification of the VAE [9] in which a hyperparameter β places extra emphasis on the KL-divergence part of the VAE objective. This forces the posterior to match the factorized Gaussian prior, which constrains the bottleneck representation to be factorized while still reconstructing the data. Higher β values encourage more disentangled representations at the cost of a higher reconstruction error. In β-Total Correlation VAE (β-TCVAE) [2], the training focuses on the total correlation part of the KL term, which is the part responsible for the factorized representation. This reduces the reconstruction trade-off introduced by β-VAE. However, β-TCVAE is validated on examples from controlled environments with clear factors of generation; this does not represent the complexity of medical data and should be addressed.
In this work, we propose a framework for learning disentangled representations in medical imaging in an unsupervised manner. To our knowledge, this is the first work that analyzes the strength of unsupervised disentangled feature representations in medical imaging and proposes a framework that is well suited to medical applications. We propose a novel residual adversarial VAE with a total correlation constraint. This enhances the fidelity of the reconstruction and captures more of the details that describe the underlying generative factors.
We utilize deep generative disentangled representation learning to learn the distribution of a medical imaging dataset. We then use the learned representation to generate images while controlling some generative factors. We first show how disentanglement is approached with β-VAE as a motivation for incorporating β-TCVAE. We then present our contributions to the disentanglement framework, utilizing an adversarial loss with residual blocks to enhance the disentanglement and reduce the compromise on the reconstruction. We hypothesize that using an adversarial loss with residual blocks in a disentanglement framework results in higher-quality representations with more disentanglement in the feature space.
Let $\mathcal{X} = \{x_1, \dots, x_N\}$ be a set of images generated by combinations of ground-truth generative factors $v$. Our aim is to build an unsupervised generative model that utilizes only the images in $\mathcal{X}$ to learn the joint distribution of the images and a set of latent generative factors $z$, allowing us to have better control over, and interpretability of, the latent space. It is worth mentioning that the latent generative factors capture both disentangled and entangled factors. To realize our aim, we follow the concept of β-VAE in learning a posterior distribution that can be used to generate images from $\mathcal{X}$. The posterior is approximated by $q_\phi(z|x)$. The model is built such that the generative factors are represented by the posterior bottleneck in a disentangled fashion.
In β-VAE, implicit independence is enforced on the posterior to encourage a disentangled representation. This is done by constraining the posterior $q_\phi(z|x)$ to match a prior $p(z)$. The prior is set to be an isotropic unit Gaussian, $p(z) = \mathcal{N}(0, I)$. Adding extra pressure on the posterior to match $p(z)$ constrains the capacity of the bottleneck and pushes it to be factorized. Thus, the objective function for β-VAE is as follows:

$$\mathcal{L}(\theta, \phi; x, z, \beta) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta\, D_{KL}\big(q_\phi(z|x)\,\|\,p(z)\big)$$
where $\phi$ and $\theta$ are the trainable weights of the encoder and decoder, respectively, and $D_{KL}$ is the Kullback-Leibler divergence. When $\beta = 1$, we recover the original VAE loss [9]. For disentanglement, values of $\beta > 1$ are typically chosen. Using this formula enhances the disentanglement at the cost of reconstruction fidelity. It is suggested by [2] that the total correlation term within $D_{KL}$ is responsible for the factorized representation. Hence, focusing the training on the total correlation results in better disentanglement while having less effect on the reconstruction. The objective function changes such that $D_{KL}$ is decomposed and $\beta$ now multiplies only the total correlation term, as follows:

$$\mathcal{L} = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - I_q(z; x) - \beta\, D_{KL}\Big(q(z)\,\Big\|\,\textstyle\prod_j q(z_j)\Big) - \sum_j D_{KL}\big(q(z_j)\,\|\,p(z_j)\big)$$
The second term is the mutual information $I_q(z; x)$ between the data and the latent variable. Penalizing this term reduces the amount of information about $x$ that is represented in $z$, which in turn could decrease the reconstruction performance. The third term is the total correlation (TC), a generalization of mutual information to more than two variables. Penalizing TC forces independence among the represented factors. The last term, referred to as the dimension-wise KL, is applied to the individual latent dimensions. We use β-TCVAE for its good disentanglement results on various datasets, its better reconstruction than other disentanglement models, and its parameter-free approximation of $q(z)$. For more details about the decomposition and the approximation of $q(z)$, the reader is referred to [2].
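To make the role of β concrete, the following sketch (hypothetical NumPy code, not the authors' implementation) computes the closed-form KL term of a diagonal Gaussian posterior against the unit Gaussian prior and weights it by β, as in β-VAE; estimating the full total-correlation decomposition would additionally require the minibatch-based approximation of $q(z)$ from [2].

```python
import numpy as np

def kl_diag_gaussian(mu, logvar):
    """Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior,
    summed over latent dimensions and averaged over the batch."""
    return np.mean(0.5 * np.sum(mu**2 + np.exp(logvar) - logvar - 1.0, axis=1))

def beta_vae_loss(recon_error, mu, logvar, beta=6.0):
    """beta-VAE objective (to minimize): reconstruction error plus the
    KL term weighted by beta; beta = 1 recovers the plain VAE loss."""
    return recon_error + beta * kl_diag_gaussian(mu, logvar)

# A posterior that already matches the prior contributes zero KL,
# so only the reconstruction error remains.
mu = np.zeros((4, 32)); logvar = np.zeros((4, 32))
print(beta_vae_loss(0.25, mu, logvar, beta=6.0))  # -> 0.25
```

Raising β penalizes any posterior that deviates from the factorized prior more heavily, which is exactly the pressure the text describes on the bottleneck capacity.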
To enhance the fidelity of the reconstructions and improve the generative factors captured by $z$, we add a discriminator network on top of the β-TCVAE model. The discriminator is trained to decide whether an input image is generated synthetically or sampled from the real data distribution. We employ an adversarial loss scheme for the training. The discriminator in this scenario has to implicitly learn a rich similarity metric based on features extracted from the images, rather than relying only on pixel-wise similarity. This not only improves the generated images visually, but also yields a richer representation in the latent code, because the pixel-wise loss acts as a content loss while the discriminator loss acts as a style loss [4, 11]. Moreover, we incorporate residual blocks rather than the plain convolutional layers applied in [2], since residual blocks allow a better flow of the gradients. This limits problems related to vanishing/exploding gradients [5] and is used in the state-of-the-art Generative Adversarial Nets (GANs) literature [15] for more stable training. Let $D_\psi$ denote the discriminator network with trainable parameters $\psi$, let $x$ be a real image sampled from $\mathcal{X}$, and let $\hat{x}$ be the reconstructed image from $p_\theta(x|z)$. The final objective is

$$\min_{\theta, \phi}\, \max_{\psi}\; -\mathcal{L} + \mathbb{E}_{x}\big[\log D_\psi(x)\big] + \mathbb{E}_{\hat{x}}\big[\log\big(1 - D_\psi(\hat{x})\big)\big]$$
The model is trained by alternating between the encoder-decoder and the discriminator optimization steps. We use the pixel-wise distance between $x$ and $\hat{x}$ as the reconstruction loss.
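The two adversarial terms of the alternating scheme can be written as losses to minimize; the sketch below is a generic non-saturating formulation for illustration only and does not reproduce the exact weighting used in the final objective.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def discriminator_loss(logit_real, logit_fake):
    """Discriminator step: maximize log D(x) + log(1 - D(x_hat)),
    written here as a loss to minimize over the logits."""
    return -np.mean(np.log(sigmoid(logit_real))) \
           - np.mean(np.log(1.0 - sigmoid(logit_fake)))

def generator_adv_loss(logit_fake):
    """Encoder-decoder step (non-saturating variant): push the
    reconstructions x_hat toward being classified as 'real'."""
    return -np.mean(np.log(sigmoid(logit_fake)))

# A maximally uncertain discriminator (logit 0 -> D = 0.5) pays 2*log(2).
print(discriminator_loss(np.zeros(8), np.zeros(8)))
```

During training, one gradient step on `discriminator_loss` updates $\psi$ while a step on the reconstruction plus `generator_adv_loss` updates $\theta$ and $\phi$, matching the alternation described above.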
Our experimental validation evaluates the proposed framework in two main experiments. First, we compare the disentanglement performance of our proposed method against state-of-the-art methods for learning both entangled and disentangled representations. We also utilize the learned representations in two use-cases, namely, unsupervised clustering and supervised classification with a small amount of labels. In the second experiment, we evaluate the results visually and analyze the interpretability of the learned representation.
We opt for the publicly available Skin Lesion dataset from the ISIC 2018 Challenge [14] to perform our validation. To train our model, we utilize the dataset of Task 3, which consists of RGB images with 7 types of skin lesions capturing 7 pathological generative factors. To evaluate the model against handcrafted ground-truth generative factors, i.e., eccentricity, orientation, and size, we utilize the dataset of Task 2, which provides images with pixel-wise segmentations. Note that all images are down-sampled to a fixed resolution.
To quantitatively evaluate the disentanglement quality, we report the Mutual Information Gap (MIG) metric as proposed and suggested in [2]. As opposed to the disentanglement metric in [6], MIG takes axis-alignment (one $v_k$ being captured by one $z_j$) into consideration, and it is unbiased with respect to hyperparameters, unlike [6, 8]. MIG measures the mutual information (MI) between each $z_j$ and a known generative factor $v_k$; the difference between the two highest MIs for that factor is then computed and normalized by the entropy of $v_k$. The average MIG is then computed as

$$\mathrm{MIG} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{H(v_k)} \Big( I(z_{j^{(k)}}; v_k) - \max_{j \neq j^{(k)}} I(z_j; v_k) \Big), \qquad j^{(k)} = \arg\max_j I(z_j; v_k)$$

where $H$ is the entropy and $I$ is the mutual information. For our experiments, we set the generative factors as follows:
MIG Pathologies ($\mathrm{MIG}_p$): The ground-truth classes are used as generative factors in a one-vs.-all fashion, i.e., $K = 7$ for the Skin Lesion dataset. Each generative factor has two possible values in this scenario.
MIG Handcrafted Factors ($\mathrm{MIG}_h$): In addition, we handcraft a few generative factors that are easily visible in the image space, e.g., geometric and morphological changes. To do so, the segmentation masks given in Task 2 are utilized. The handcrafted factors are eccentricity, orientation, and size (i.e., $K = 3$). Each generative factor has two possible values.
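A minimal sketch of the MIG computation for discretized latents and factors is given below; the function names are our own, and in practice [2] estimates the MI of continuous latent codes rather than binning them as done here.

```python
import numpy as np

def entropy(v):
    """Empirical entropy H(v) of a nonnegative-integer factor."""
    p = np.bincount(v) / len(v)
    p = p[p > 0]
    return -np.sum(p * np.log(p))

def mutual_info(z, v):
    """Discrete mutual information I(z; v) from the empirical joint."""
    joint = np.zeros((z.max() + 1, v.max() + 1))
    for zi, vi in zip(z, v):
        joint[zi, vi] += 1
    joint /= joint.sum()
    pz = joint.sum(axis=1, keepdims=True)
    pv = joint.sum(axis=0, keepdims=True)
    nz = joint > 0
    return np.sum(joint[nz] * np.log(joint[nz] / (pz @ pv)[nz]))

def mig(latents, factor):
    """Gap between the two largest I(z_j; v), normalized by H(v)."""
    mis = sorted((mutual_info(z, factor) for z in latents), reverse=True)
    return (mis[0] - mis[1]) / entropy(factor)
```

A latent dimension that copies the factor while all others are uninformative yields a MIG of 1 for that factor, which is the axis-alignment property the metric rewards.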
In addition, we report the Peak Signal-to-Noise Ratio (PSNR), Normalized Mutual Information (NMI), and Accuracy (ACC) to evaluate the reconstruction error, the clustering, and the classification, respectively.
We compare the proposed model to two representation learning models. The first is the VAE [9], which does not explicitly take disentanglement into account. The second is β-TCVAE [2], which adds constraints on the representation to disentangle its components. Further, we employ two variations of our proposed method with bottleneck residual blocks [5]: 1) without the adversarial loss in Equation 3, denoted Ours-resnet; and 2) with the adversarial loss, denoted Ours-adv.
We implement the same architecture used in the CelebA experiments of [2] for both the VAE and β-TCVAE. For our proposed method, we replace the convolutional layers with bottleneck residual blocks for both Ours-resnet and Ours-adv, while the additional discriminator network in Ours-adv has the same architecture as the encoder except for the last layer, which has a single output. All models are trained using the Adam optimizer with a minibatch size of 256. β and the latent dimension are set to 6 and 32, respectively. Note that we employ leakyReLU in Ours-adv, as it has been successfully applied in the adversarial training literature.
Comparison with state-of-the-art:
We compare our method with recent state-of-the-art methods by reporting the evaluation metrics (cf. Table 1). We notice improvements over β-TCVAE in terms of disentanglement on both $\mathrm{MIG}_p$ and $\mathrm{MIG}_h$. For the reconstruction error, it is expected that the VAE is superior to the other models because there is no extra focus on the prior-constraining part of its loss function, which allows the reconstruction error to be optimized more freely. However, we notice an improvement in PSNR compared to the β-TCVAE model, which compromises reconstruction error for disentanglement. This experiment shows that adding the bottleneck residual blocks together with adversarial training improves not only the disentanglement but also the reconstruction quality.
In order to show that the disentangled representation captures meaningful generative factors, which might be relevant to the task at hand, we design two use-cases in the unsupervised and supervised paradigms. For the clustering use-case, we utilize the learned representations to fit a Gaussian Mixture Model (GMM) with 7 components and assign a label to each data point. The NMI is then calculated between the assigned labels and the ground-truth labels; we report an average over 10 realizations. For the classification use-case, we utilize the learned representations of a small amount of labeled data to train a multi-layer perceptron (MLP) on a fraction of the data and evaluate it on the remainder, using 10-fold stratified cross-validation.
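The stratified splitting underlying the classification use-case can be sketched as follows (a hypothetical helper, not the paper's code): each class is distributed evenly across the folds so that every fold preserves the label proportions.

```python
import numpy as np

def stratified_folds(labels, k=10, seed=0):
    """Indices for k stratified folds: shuffle each class separately
    and deal its samples round-robin across the folds."""
    rng = np.random.default_rng(seed)
    folds = [[] for _ in range(k)]
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        for i, j in enumerate(idx):
            folds[i % k].append(j)
    return [np.array(sorted(f)) for f in folds]

# 7 balanced pathology classes, 20 samples each -> 10 folds of 14,
# with 2 samples of every class per fold.
labels = np.repeat(np.arange(7), 20)
folds = stratified_folds(labels)
print([len(f) for f in folds])
```

In each cross-validation round, one fold serves as the evaluation set and the MLP is trained on the representations of the remaining folds.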
The model gives relative improvements on both NMI and ACC. This could be attributed to the quality of the learned representation, where features responsible for the pathologies are captured by the disentanglement models as generative factors.
We qualitatively examine the interpretability of the learned representations by manipulating the latent code. For instance, Fig. (a) shows a comparison of the latent traversals between the proposed model and the VAE. We notice that the dimension responsible for changing the skin color has some entanglement with eccentricity and size in the case of the VAE. In contrast, in our proposed model the size and eccentricity barely change when the skin-color dimension is altered. For eccentricity, we notice that the VAE captures fewer variations, e.g., the absence of the horizontal elliptic lesions that are captured by the proposed approach.
In Fig. (b), we show the possibility of generating images with specific features by moving smoothly over the manifold of the representations. We show the transition of a small lesion on pale skin to a big horizontal lesion on reddish skin by changing the multiple latent dimensions responsible for each feature. Having this control over the representation not only gives the ability to generate images with specific known features, but also provides an interpretable representation of the data which can be utilized in many applications.
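The traversal procedure above amounts to decoding copies of a latent code while sweeping one dimension and fixing the rest. The sketch below uses a stub in place of the trained decoder $p_\theta(x|z)$; the function names and the stub are purely illustrative.

```python
import numpy as np

def traverse(decode, z, dim, values):
    """Decode copies of latent code z with dimension `dim` swept over
    `values`, keeping all other dimensions fixed."""
    images = []
    for v in values:
        z_mod = z.copy()
        z_mod[dim] = v
        images.append(decode(z_mod))
    return np.stack(images)

# Stub standing in for the trained decoder: maps a 32-d code to a
# 64x64 "image" whose intensity depends on the code.
decode = lambda z: np.tanh(np.ones((64, 64)) * z.sum())
grid = traverse(decode, np.zeros(32), dim=5, values=np.linspace(-3, 3, 7))
print(grid.shape)  # (7, 64, 64)
```

With the real decoder, plotting the rows of `grid` for each latent dimension produces exactly the traversal figures discussed above.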
In this paper, we introduced a novel adversarial VAE with a total correlation constraint to enforce disentanglement on the latent representation while preserving the reconstruction fidelity. The proposed framework is evaluated on a skin lesion dataset and shows improvements over other state-of-the-art methods in terms of disentanglement. The disentangled representations learned by the proposed method have shown remarkable performance in both unsupervised clustering and supervised classification. We believe that our work paves the way for other researchers to further investigate this interesting direction of research; one potential direction is utilizing the control over the generative factors for data augmentation.
-  Bengio, Y., Courville, A., Vincent, P.: Representation learning: A review and new perspectives. IEEE transactions on pattern analysis and machine intelligence 35(8), 1798–1828 (2013)
-  Chen, T.Q., Li, X., Grosse, R.B., Duvenaud, D.K.: Isolating sources of disentanglement in variational autoencoders. In: Advances in Neural Information Processing Systems. pp. 2615–2625 (2018)
-  Chen, X., Duan, Y., Houthooft, R., Schulman, J., Sutskever, I., Abbeel, P.: Infogan: Interpretable representation learning by information maximizing generative adversarial nets. In: Advances in neural information processing systems. pp. 2172–2180 (2016)
-  Gatys, L.A., Ecker, A.S., Bethge, M.: A neural algorithm of artistic style. arXiv preprint arXiv:1508.06576 (2015)
-  He, K., Zhang, X., Ren, S., Sun, J.: Identity mappings in deep residual networks. In: European conference on computer vision. pp. 630–645. Springer (2016)
-  Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., Mohamed, S., Lerchner, A.: beta-vae: Learning basic visual concepts with a constrained variational framework. In: International Conference on Learning Representations (2017)
-  Hinton, G.E., Krizhevsky, A., Wang, S.D.: Transforming auto-encoders. In: International Conference on Artificial Neural Networks. pp. 44–51. Springer (2011)
-  Kim, H., Mnih, A.: Disentangling by factorising. arXiv preprint arXiv:1802.05983 (2018)
-  Kingma, D.P., Welling, M.: Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114 (2013)
-  Kulkarni, T.D., Whitney, W.F., Kohli, P., Tenenbaum, J.: Deep convolutional inverse graphics network. In: Advances in neural information processing systems. pp. 2539–2547 (2015)
-  Larsen, A.B.L., Sønderby, S.K., Larochelle, H., Winther, O.: Autoencoding beyond pixels using a learned similarity metric. arXiv preprint arXiv:1512.09300 (2015)
-  Miotto, R., Wang, F., Wang, S., Jiang, X., Dudley, J.T.: Deep learning for healthcare: review, opportunities and challenges. Briefings in bioinformatics 19(6), 1236–1246 (2017)
-  Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., Batra, D.: Grad-cam: Visual explanations from deep networks via gradient-based localization. In: Proceedings of the IEEE International Conference on Computer Vision. pp. 618–626 (2017)
-  Tschandl, P., Rosendahl, C., Kittler, H.: The ham10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. Scientific data 5, 180161 (2018)
-  Zhang, H., Goodfellow, I., Metaxas, D., Odena, A.: Self-attention generative adversarial networks. arXiv preprint arXiv:1805.08318 (2018)