Representation Disentanglement for Multi-task Learning with application to Fetal Ultrasound

08/21/2019 ∙ by Qingjie Meng, et al. ∙ Imperial College London

One of the biggest challenges for deep learning algorithms in medical image analysis is the indiscriminate mixing of image properties, e.g. artifacts and anatomy. These entangled image properties lead to a semantically redundant feature encoding for the relevant task, and thus to poor generalization of deep learning algorithms. In this paper we propose a novel representation disentanglement method to extract semantically meaningful and generalizable features for different tasks within a multi-task learning framework. Deep neural networks are used to ensure that the encoded features are maximally informative with respect to relevant tasks, while an adversarial regularization encourages these features to be disentangled and minimally informative about irrelevant tasks. We aim to use the disentangled representations to improve the generalization of deep neural networks. We demonstrate the advantages of the proposed method on synthetic data as well as on fetal ultrasound images. Our experiments show that our method is capable of learning disentangled internal representations and that it outperforms baseline methods in multiple tasks, especially on images with new properties, e.g. previously unseen artifacts in fetal ultrasound.


1 Introduction

Image interpretation using convolutional neural networks (CNNs) has been widely and successfully applied to medical image analysis in recent years. However, in contrast to human observers, CNNs generalize poorly to previously unseen combinations of entangled image properties (e.g. shape and texture) [1]. In ultrasound (US), image property entanglement can be observed when acquisition-related artifacts (e.g. shadows) obscure the underlying anatomy (see Fig. 1). A CNN simultaneously learns anatomical features and artifact features, whether it is trained for anatomy classification or for artifact detection [2]. As a result, a model trained on images with certain entangled properties (e.g. images without acoustic shadows) can hardly handle images with new entangled properties that were unseen during training (e.g. images with shadows).

Figure 1: Examples of fetal US data. Green framed images are shadow-free and red framed images contain acoustic shadows.

Approaches for representation disentanglement have been proposed to learn semantically disjoint internal representations and thereby improve image interpretation [3]. These methods pave the way for improving the generalization of CNNs in a wide range of medical image analysis problems. Specifically, for the practical application in this work, we want to disentangle anatomical features from shadow features in order to generalize anatomical standard plane analysis for better detection of abnormalities in early pregnancy.

Contribution: In this paper, we propose a novel, end-to-end trainable representation disentanglement model that can learn distinct and generalizable features through a multi-task architecture with adversarial training. The obtained disjoint features are able to improve the performance of multi-task networks, especially on data with previously unseen properties. We evaluate the proposed model on specific multi-task problems, including shape/background-color classification tasks on synthetic data and standard-plane/shadow-artifacts classification tasks on fetal US data. Our experiments show that our model is able to disentangle latent representations and, in a practical application, improves the performance for anatomy analysis in US imaging.

Related work:

Representation disentanglement has been widely studied in the machine learning literature, ranging from traditional models such as Independent Component Analysis (ICA) [4] and bilinear models [5] to recent deep learning-based models such as InfoGAN [6] and β-VAE [7, 8]. Disentangled representations can be used to interpret complex interactions of underlying factors within data [9, 10] and enable deep learning models to manipulate relevant information for specific tasks [11, 12, 13]. Particularly related to our work is that of Mathieu et al. [14], who proposed a conditional generative model with adversarial networks to disentangle specific and unspecific factors of variation in deep representations without strong supervision. Compared to [14], Hadad et al. [13] proposed a simpler two-step method with the same aim. Their network directly uses the encoded latent space without assuming an underlying distribution, which can be more efficient for learning various unspecified features. Whereas they aim to disentangle one specific representation from unspecific factors, our work focuses on disentangling several specific factors. Also related to our research question is learning only unspecific invariant features, for example for domain adaptation [15]. However, unlike learning invariant features, which ignores task-irrelevant information [9], our method aims to preserve information for multiple tasks while enhancing feature generalizability.

In the medical image analysis community, few approaches have focused on disentangling internal factors of representations in discriminative tasks. Ben-Cohen et al. [16] proposed a method to disentangle lesion type from image appearance and used the disentangled features to generate additional training samples for data augmentation, which improves liver lesion classification. In contrast, our work aims to use disentangled features to improve the generalization of deep neural networks in medical image analysis.

2 Method

Our goal is to disentangle latent representations of the data into distinct feature sets ($f_a$, $f_b$) that separately contain the relevant information for the corresponding tasks ($T_a$, $T_b$). The main motivation of the proposed method is to learn feature sets that are maximally informative about their corresponding task (e.g. a high mutual information $I(f_a; y_a)$) but minimally informative about irrelevant tasks (e.g. a low $I(f_a; y_b)$). While our approach scales to any number of classification tasks, in this work we focus on two tasks as a proof of concept. The proposed method consists of two classification tasks ($T_a$, $T_b$) with an adversarial regularization. The classification maps the encoded features to their relevant class identities and is trained to maximize $I(f_a; y_a)$ and $I(f_b; y_b)$. The adversarial regularization penalizes the mutual information between the encoded features and their irrelevant class identities; in other words, it minimizes $I(f_a; y_b)$ and $I(f_b; y_a)$. The training architecture of our method is shown in Fig. 2.
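One way to see why a softmax cross-entropy loss can stand in for these mutual-information goals is the standard variational bound below. It is not derived in the paper; it holds for any classifier $q(y \mid f)$ and true conditional $p(y \mid f)$:

```latex
\begin{aligned}
I(f; y) &= H(y) - H(y \mid f), \\
\mathcal{L}_{\mathrm{CE}}(q) &= \mathbb{E}_{(f, y)}\!\left[-\log q(y \mid f)\right]
  = H(y \mid f) + \mathbb{E}_{f}\!\left[\mathrm{KL}\!\left(p(y \mid f) \,\|\, q(y \mid f)\right)\right] \ge H(y \mid f), \\
\Rightarrow \quad I(f; y) &\ge H(y) - \mathcal{L}_{\mathrm{CE}}(q).
\end{aligned}
```

Minimizing the cross-entropy of a relevant classifier therefore raises a lower bound on $I(f_a; y_a)$. For the adversarial terms the bound only runs one way. Suppose even a well-trained adversary cannot push its cross-entropy much below $H(y_b)$. Then the lower bound on $I(f_a; y_b)$ is close to zero. This is consistent with $f_a$ carrying little information about $y_b$, but it does not prove it.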

Figure 2: Training framework for the proposed method. Res-Blk refers to residual blocks. Example 1/2 are the two data set examples used in Sect. 3. The classifications enable the encoded features to be maximally informative about related tasks, while the adversarial regularization encourages these features to be less informative about irrelevant tasks.

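Classification is used to learn encoded features that enable high prediction performance for the class identity of the relevant task. Each of the two classification networks is composed of an encoder and a classifier for a defined task. Given data $X = \{x_i\}_{i=1}^{N}$, the matching labels are $Y_a = \{y_a^i\}_{i=1}^{N}$ with $y_a^i \in \{1, \dots, K_a\}$ for $T_a$ and $Y_b = \{y_b^i\}_{i=1}^{N}$ with $y_b^i \in \{1, \dots, K_b\}$ for $T_b$. $N$ is the number of images, and $K_a$ and $K_b$ are the numbers of class identities in each task. Two independent encoders $E_a$ and $E_b$ with parameters $\theta_{E_a}$ and $\theta_{E_b}$ map $x$ to $f_a = E_a(x; \theta_{E_a})$ and $f_b = E_b(x; \theta_{E_b})$, respectively. Two classifiers predict the class identity for the corresponding task, $\hat{y}_a = C_a(f_a; \theta_{C_a})$ and $\hat{y}_b = C_b(f_b; \theta_{C_b})$, where $\theta_{C_a}$ and $\theta_{C_b}$ are the parameters of the classifiers. We define the cost functions $\mathcal{L}_{C_a}$ and $\mathcal{L}_{C_b}$ as the softmax cross-entropy between $\hat{y}_a$ and $y_a$ and between $\hat{y}_b$ and $y_b$, respectively. The classification loss $\mathcal{L}_C = \mathcal{L}_{C_a} + \mathcal{L}_{C_b}$ is minimized to train the two encoders and the two classifiers ($\theta_{E_a}, \theta_{E_b}, \theta_{C_a}, \theta_{C_b}$), yielding $f_a$ and $f_b$ that are maximally related to their relevant tasks.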
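The classification branch can be sketched in NumPy as shown below. This is a minimal illustration under stated assumptions, not the authors' implementation. A single dense layer stands in for each residual encoder, and each classifier is one dense layer. All sizes (`N`, `D`, `H`, `K_a`, `K_b`) and helper names (`dense`, `init`, `softmax_cross_entropy`) are hypothetical.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Mean softmax cross-entropy; labels are integer class indices."""
    # Subtract the row-wise max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def dense(x, w, b, relu=True):
    """Fully connected layer with optional ReLU."""
    y = x @ w + b
    return np.maximum(y, 0.0) if relu else y

def init(rng, n_in, n_out):
    """Small random weights (hypothetical initialisation scheme)."""
    return rng.normal(0.0, 0.1, size=(n_in, n_out)), np.zeros(n_out)

rng = np.random.default_rng(0)
N, D, H = 8, 16, 32   # images, flattened input size, feature size (assumed)
K_a, K_b = 3, 2       # number of classes for T_a and T_b (assumed)
x = rng.normal(size=(N, D))
y_a = rng.integers(0, K_a, size=N)
y_b = rng.integers(0, K_b, size=N)

# Two independent encoders E_a and E_b (one dense layer stands in for the residual blocks).
W_Ea, b_Ea = init(rng, D, H)
W_Eb, b_Eb = init(rng, D, H)
f_a = dense(x, W_Ea, b_Ea)
f_b = dense(x, W_Eb, b_Eb)

# Two classifiers C_a and C_b, each mapping its own features to logits for its own task.
W_Ca, b_Ca = init(rng, H, K_a)
W_Cb, b_Cb = init(rng, H, K_b)
logits_a = dense(f_a, W_Ca, b_Ca, relu=False)
logits_b = dense(f_b, W_Cb, b_Cb, relu=False)

# L_C = L_Ca + L_Cb, minimised with respect to encoder and classifier parameters.
L_C = softmax_cross_entropy(logits_a, y_a) + softmax_cross_entropy(logits_b, y_b)
print(f"L_C = {L_C:.4f}")
```

The encoders are deliberately separate parameter sets. This matches the text's "two independent encoders": gradients from $\mathcal{L}_{C_a}$ never reach $E_b$, and vice versa.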

Adversarial regularization is used to force the encoded features to be minimally informative about irrelevant tasks, which results in disentanglement of internal representations. It is implemented with an adversarial network for each task, as shown in Fig. 2. These adversarial networks map the encoded features to the class identities of the irrelevant task, yielding $\tilde{y}_b = D_a(f_a; \theta_{D_a})$ and $\tilde{y}_a = D_b(f_b; \theta_{D_b})$, where $\theta_{D_a}$ and $\theta_{D_b}$ are the parameters of the adversarial networks. With $\mathcal{L}_{D_a}$ and $\mathcal{L}_{D_b}$ denoting the softmax cross-entropy between $\tilde{y}_b$ and $y_b$ and between $\tilde{y}_a$ and $y_a$, respectively, the adversarial loss is defined as $\mathcal{L}_D = \mathcal{L}_{D_a} + \mathcal{L}_{D_b}$. During training, the adversarial networks are trained to minimize $\mathcal{L}_D$, while the two encoders and the two classifiers ($\theta_{E_a}, \theta_{E_b}, \theta_{C_a}, \theta_{C_b}$) are trained to maximize it. This competition between the encoders/classifiers and the adversarial networks encourages the encoded features to be uninformative for irrelevant tasks.
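The cross-wiring of the adversarial loss can be sketched as follows: each adversary sees one task's features but is scored against the *other* task's labels. This is an illustrative sketch, not the paper's code. The linear adversaries (`logits = f @ W`), the sizes and the function names are assumptions.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Mean softmax cross-entropy; labels are integer class indices."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def adversarial_loss(f_a, f_b, y_a, y_b, W_Da, W_Db):
    """L_D = L_Da + L_Db with cross-wired targets.

    D_a sees f_a but must predict the irrelevant label y_b;
    D_b sees f_b but must predict the irrelevant label y_a.
    Linear adversaries are an assumption made for brevity.
    """
    L_Da = softmax_cross_entropy(f_a @ W_Da, y_b)
    L_Db = softmax_cross_entropy(f_b @ W_Db, y_a)
    return L_Da + L_Db

rng = np.random.default_rng(1)
N, H, K_a, K_b = 8, 32, 3, 2                  # hypothetical sizes
f_a, f_b = rng.normal(size=(N, H)), rng.normal(size=(N, H))
y_a, y_b = rng.integers(0, K_a, size=N), rng.integers(0, K_b, size=N)
W_Da = rng.normal(0.0, 0.1, size=(H, K_b))    # D_a outputs K_b logits
W_Db = rng.normal(0.0, 0.1, size=(H, K_a))    # D_b outputs K_a logits

# The adversaries minimise L_D; the encoders are pushed to maximise it.
L_D = adversarial_loss(f_a, f_b, y_a, y_b, W_Da, W_Db)
print(f"L_D = {L_D:.4f}")
```

An adversary that outputs uniform predictions (e.g. all-zero weights) gives exactly $\log K_b + \log K_a$. For balanced labels this is the chance-level value that the encoders try to drive the adversaries toward.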

By combining the two classifications with the adversarial regularization, the whole model is optimized iteratively during training. The training objective for optimizing the two encoders and the two classifiers can be written as

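$$\min_{\theta_{E_a}, \theta_{E_b}, \theta_{C_a}, \theta_{C_b}} \; \mathcal{L}_C - \lambda \mathcal{L}_D \qquad (1)$$

Here, $\lambda$ is the trade-off parameter of the adversarial regularization. The training objective for the optimization of the adversarial networks thus follows as

$$\min_{\theta_{D_a}, \theta_{D_b}} \; \mathcal{L}_D \qquad (2)$$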
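Eqs. (1) and (2) reduce to two scalar objectives built from the four cross-entropy terms. The sketch below is illustrative only. The function names are hypothetical, and `lam` plays the role of the trade-off parameter $\lambda$.

```python
def encoder_classifier_objective(L_Ca, L_Cb, L_Da, L_Db, lam):
    """Eq. (1): minimised over the encoder and classifier parameters."""
    L_C = L_Ca + L_Cb
    L_D = L_Da + L_Db
    # The minus sign turns "maximise L_D" into part of a minimisation.
    return L_C - lam * L_D

def adversary_objective(L_Da, L_Db):
    """Eq. (2): minimised over the adversarial network parameters."""
    return L_Da + L_Db

# Example with made-up loss values.
print(encoder_classifier_objective(0.2, 0.3, 1.0, 0.7, lam=0.1))
print(adversary_objective(1.0, 0.7))
```

Because $\mathcal{L}_D$ enters Eq. (1) with a negative sign, a gradient step on Eq. (1) moves the encoders toward features on which the adversaries' cross-entropy is higher. A step on Eq. (2) moves the adversaries the opposite way. $\mathcal{L}_D$ depends on $f_a$, $f_b$ and the adversaries only, so among the parameters in Eq. (1) only the encoders actually receive the adversarial gradient.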

Network architectures: $E_a$ and $E_b$ both consist of six residual blocks implemented as proposed in [17] to reduce the training error and to support easier network optimization. $C_a$ and $C_b$ both contain two dense layers with hidden units. The adversarial networks $D_a$ and $D_b$ have the same architecture as $C_a$ and $C_b$, respectively.
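A residual block adds a learned correction to an identity shortcut. The sketch below shows a fully connected simplification of this idea and stacks six such blocks into an encoder, mirroring the six blocks of $E_a$ and $E_b$. It is an assumption for illustration only. It has no convolutions, batch normalization or biases, and it does not reproduce the exact blocks of [17].

```python
import numpy as np

def residual_block(x, W1, W2):
    """Dense residual block: relu(x + relu(x @ W1) @ W2).

    Simplified stand-in (assumption) for the residual blocks of [17]:
    identity shortcut, no normalisation, no biases.
    """
    h = np.maximum(x @ W1, 0.0)
    return np.maximum(x + h @ W2, 0.0)

def encoder(x, blocks):
    """Apply a stack of residual blocks in order."""
    for W1, W2 in blocks:
        x = residual_block(x, W1, W2)
    return x

rng = np.random.default_rng(0)
H = 16  # constant feature width so the identity shortcut matches shapes (assumed)
blocks = [(rng.normal(0, 0.1, (H, H)), rng.normal(0, 0.1, (H, H))) for _ in range(6)]
x = rng.normal(size=(4, H))
f = encoder(x, blocks)
print(f.shape)  # (4, 16)
```

With all-zero weights, every block reduces to a ReLU of its input. The shortcut alone carries the signal, which illustrates why residual stacks are easier to optimize.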

Training: Our model is optimized for a fixed number of epochs, and $\lambda$ is chosen heuristically and independently for each data set using validation data. For more stable optimization [13], in each iteration we train the encoders and classifiers once, followed by five training steps of the adversarial networks. Similar to [13], we use the Adam optimizer to train the encoders and classifiers based on Eq. 1, and Stochastic Gradient Descent (SGD) with momentum to update the parameters of the adversarial networks in Eq. 2. We apply L2 regularization to all weights during training to prevent over-fitting. The batch size is 50, and the images in each batch are randomly flipped as data augmentation. Our model is trained on an Nvidia Titan X GPU with 12 GB of memory.
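The alternating schedule (one Adam step on Eq. 1, then five SGD-with-momentum steps on Eq. 2) can be sketched generically in NumPy. This is a sketch under stated assumptions, not the authors' training code. The optimizer hyperparameters are common defaults chosen for illustration, because the paper's values are not given here. The toy quadratic gradients are hypothetical stand-ins for Eqs. (1) and (2). Mini-batching and L2 regularization are omitted.

```python
import numpy as np

class Adam:
    """Standard Adam update; lr/beta1/beta2/eps are assumed defaults."""
    def __init__(self, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, beta1, beta2, eps
        self.m = self.v = None
        self.t = 0

    def step(self, params, grad):
        if self.m is None:
            self.m, self.v = np.zeros_like(params), np.zeros_like(params)
        self.t += 1
        self.m = self.b1 * self.m + (1 - self.b1) * grad
        self.v = self.b2 * self.v + (1 - self.b2) * grad ** 2
        m_hat = self.m / (1 - self.b1 ** self.t)   # bias-corrected first moment
        v_hat = self.v / (1 - self.b2 ** self.t)   # bias-corrected second moment
        return params - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

class SGDMomentum:
    """Heavy-ball SGD with momentum; lr and momentum are assumed values."""
    def __init__(self, lr=1e-2, momentum=0.9):
        self.lr, self.mu, self.vel = lr, momentum, None

    def step(self, params, grad):
        if self.vel is None:
            self.vel = np.zeros_like(params)
        self.vel = self.mu * self.vel - self.lr * grad
        return params + self.vel

def train(theta, phi, grad_eq1, grad_eq2, n_iters, n_adv_steps=5):
    """Per iteration: one Adam step on Eq. (1) for encoder/classifier params
    theta, then n_adv_steps SGD-momentum steps on Eq. (2) for adversary params phi."""
    opt_ec, opt_adv = Adam(), SGDMomentum()
    counts = {"enc_cls": 0, "adv": 0}
    for _ in range(n_iters):
        theta = opt_ec.step(theta, grad_eq1(theta, phi))
        counts["enc_cls"] += 1
        for _ in range(n_adv_steps):
            phi = opt_adv.step(phi, grad_eq2(theta, phi))
            counts["adv"] += 1
    return theta, phi, counts

# Hypothetical toy objectives: Eq.1 ~ (theta-1)^2 - lam*(phi-theta)^2, Eq.2 ~ (phi-theta)^2.
lam = 0.1
g1 = lambda th, ph: 2 * (th - 1.0) + 2 * lam * (ph - th)
g2 = lambda th, ph: 2 * (ph - th)
theta, phi, counts = train(np.array([0.0]), np.array([0.0]), g1, g2, n_iters=3)
print(counts)
```

Giving the adversaries several steps per encoder step keeps them close to their best response. The encoders are then pushed against a reasonably strong adversary, which is the stability argument the text attributes to [13].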

3 Evaluation and Results

Evaluation on synthetic data: We use synthetic data as a proof-of-concept example to verify our model. This data set contains a randomly located gray circle or rectangle on a black or white background. We split the data into train/validation/test sets, and these images consist of circles on white background, rectangles on black background and rectangles on white background. To keep the balance between image properties in the training split, we use circle:rectangle = 1:1 and black:white = 7:5. In this case, one task is background color classification and the other is shape classification. We implement our model as outlined in Sec. 2 and evaluate it on the test data. The experiments show that the encoded features successfully identify the class identities of the relevant task but fail on the irrelevant task. To show the utility of the proposed method on images with previously unseen entangled properties, we additionally compare the shape classification performance of our model and a baseline (our model without the adversarial regularization) on images with a previously unseen combination of properties (circles on black background). The proposed model outperforms the baseline on these images. We use PCA to examine the learned embedding space at the penultimate dense layer of the classifiers. The top row of Fig. 11 illustrates that the extracted features are able to identify class identities for relevant tasks (see (a, c)) but unable to predict correct class identities for irrelevant tasks (see (b, d)).
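A data set of this kind can be generated in a few lines. The sketch below is hypothetical. The image size (64×64), gray level (0.5), object sizes and helper names are assumptions not given in the paper. It draws an equal number of images per combination, so it does not reproduce the class ratios stated above. Circles on black are held out as the unseen combination.

```python
import numpy as np

def make_image(shape, background, rng, size=64, gray=0.5):
    """One image: a gray circle or rectangle at a random location on a black (0.0)
    or white (1.0) background. Size, gray level and object sizes are assumptions."""
    img = np.full((size, size), 0.0 if background == "black" else 1.0)
    yy, xx = np.mgrid[:size, :size]
    if shape == "circle":
        r = rng.integers(6, 12)
        cy, cx = rng.integers(r, size - r, size=2)   # keeps the circle inside the image
        mask = (yy - cy) ** 2 + (xx - cx) ** 2 <= r ** 2
    else:  # rectangle
        h, w = rng.integers(8, 20, size=2)
        top, left = rng.integers(0, size - h), rng.integers(0, size - w)
        mask = (yy >= top) & (yy < top + h) & (xx >= left) & (xx < left + w)
    img[mask] = gray
    return img

# Combinations present in train/validation/test; circles on black are held out.
SEEN = [("circle", "white"), ("rectangle", "black"), ("rectangle", "white")]
UNSEEN = [("circle", "black")]

def make_set(combos, n_per_combo, rng):
    """Images plus labels for the two tasks (shape: 0=circle, background: 0=black)."""
    images, shape_labels, bg_labels = [], [], []
    for shape, bg in combos:
        for _ in range(n_per_combo):
            images.append(make_image(shape, bg, rng))
            shape_labels.append(0 if shape == "circle" else 1)
            bg_labels.append(0 if bg == "black" else 1)
    return np.stack(images), np.array(shape_labels), np.array(bg_labels)

rng = np.random.default_rng(0)
X, y_shape, y_bg = make_set(SEEN, n_per_combo=4, rng=rng)
print(X.shape)  # (12, 64, 64)
```

Calling `make_set(UNSEEN, ...)` yields the held-out "circles on black" test images. Shape and background are fully confounded in that combination, relative to what the training data contains.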

Evaluation on fetal US data: We verify the applicability of our method on fetal US data. Here, we refer to anatomical standard plane classification as $T_a$ and acoustic shadow artifact classification as $T_b$. We want to learn disentangled features $f_a$ that capture all anatomical information, separated from features $f_b$ that contain only information about shadow artifacts. $y_a$ is the label for the different anatomical standard planes, while $y_b$ distinguishes the shadow-free class from the shadow-containing class.

Data set: The fetal US data set contains images sampled from 4120 2D US fetal anomaly screening examinations with gestational ages between 18 and 22 weeks. These sequences cover eight standard planes defined in the UK FASP handbook [18], namely three vessel view (3VV), left ventricular outflow tract (LVOT), abdominal (Abd.), four chamber view (4CH), femur, kidneys, lips and right ventricular outflow tract (RVOT). The images are classified by expert observers as shadow-containing (W S) or shadow-free (W/O S) (Fig. 1). We split the data as shown in Table 1. Train, Validation and Test seen are separate data sets. Test seen contains the same combinations of entangled properties (but different images) as the training data set, while LVOT(W S) and Artifacts(OTHS) contain new combinations of entangled properties.

Train Validation Test seen LVOT(W S) Artifacts(OTHS)
3VV W/O S (W S) 180 (320) 50 (50) 334 (41) - (-) - (-)
LVOT W/O S (W S) 500 (-) 50 (-) 79 (-) - (418) - (-)
Abd. W/O S (W S) 125 (375) 50 (50) 190 (220) - (-) - (-)
Others W/O S (W S) - (-) - (-) - (-) - (-) 3159 (2211)
Table 1: Data split. “Others” contains standard planes 4CH, femur, kidneys, lips and RVOT. Test seen, LVOT(W S) and Artifacts(OTHS) are used for testing.

Evaluation approach: We refer to Std plane only as the network for standard plane classification only (consisting of $E_a$ and $C_a$), and to Artifacts only as the network for shadow artifact classification only (consisting of $E_b$ and $C_b$). We also compare against the proposed method without the adversarial regularization, and Proposed denotes our full method in Fig. 2.

The proposed method is implemented as outlined in Sec. 2. $C_a$ contains three dense layers with hidden units, while $C_b$ contains two dense layers with hidden units. We choose a bigger network capacity for $C_a$, assuming that anatomies have more complex structures to be learned than shadows.

Table 2 shows that, on Test seen, our method improves the overall accuracy of standard plane classification by 16.08% and 13.19% compared with Std plane only and the method without adversarial regularization, respectively (see OA in Col.5). It achieves only a minimal improvement for shadow artifact classification (0.35% over Artifacts only and 0.31% over the method without adversarial regularization; see OA in Col.8). We also demonstrate the utility of the proposed method on images with previously unseen entangled properties. Table 2 shows that the proposed method achieves 73.68% standard plane classification accuracy on LVOT(W S), at least 36.12% higher than the other comparison methods. On Artifacts(OTHS), it performs similarly to the other methods for shadow artifact classification.
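The improvements quoted above are differences of percentage points between rows of Table 2. They can be recomputed directly from the table values. In the snippet, the row label "without adversarial reg." is a hypothetical name for the method without adversarial regularization.

```python
# Accuracies (%) copied from Table 2.
oa_std_plane = {"Std plane only": 78.36, "without adversarial reg.": 81.25, "Proposed": 94.44}  # Col.5
oa_artifacts = {"Artifacts only": 78.70, "without adversarial reg.": 78.74, "Proposed": 79.05}  # Col.8
lvot_ws = {"Std plane only": 34.93, "without adversarial reg.": 37.56, "Proposed": 73.68}       # Col.9

def gains(table):
    """Percentage-point gain of 'Proposed' over every other row, rounded to 2 decimals."""
    best = table["Proposed"]
    return {name: round(best - acc, 2) for name, acc in table.items() if name != "Proposed"}

print(gains(oa_std_plane))  # {'Std plane only': 16.08, 'without adversarial reg.': 13.19}
print(gains(oa_artifacts))  # {'Artifacts only': 0.35, 'without adversarial reg.': 0.31}
print(gains(lvot_ws))       # {'Std plane only': 38.75, 'without adversarial reg.': 36.12}
```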

Col.1 Col.2 Col.3 Col.4 Col.5 Col.6 Col.7 Col.8 Col.9 Col.10
Methods | Test seen ($T_a$): 3VV, LVOT, Abd., OA | Test seen ($T_b$): W/O S, W S, OA | LVOT(W S) | Artifacts(OTHS)
Std plane only 60.80 96.59 67.09 78.36 - - - 34.93 -
Artifacts only - - - - 77.94 80.46 78.70 - 69.26
63.73 97.80 78.48 81.25 78.77 77.78 78.74 37.56 69.50
Proposed 93.87 97.56 81.01 94.44 87.89 58.62 79.05 73.68 68.49
39.20 83.90 82.28 64.35 68.49 81.99 72.57 - -
Table 2: The classification accuracy (%) of different methods for standard plane classification ($T_a$) and shadow artifact classification ($T_b$), on the Test seen data set and on data sets with unseen entangled properties (LVOT(W S) and Artifacts(OTHS)). “Proposed” uses the encoded features for relevant tasks, namely $f_a$ for $T_a$ and $f_b$ for $T_b$. The last row uses the encoded features of the proposed method for irrelevant tasks, namely $f_b$ for $T_a$ and $f_a$ for $T_b$. OA is the overall accuracy.
Figure 11: Visualization of the embedded data on the penultimate dense layer. The top row shows embedded synthetic test data, while the bottom row shows embedded fetal US Test seen data. (a, c) are the results of using encoded features for relevant tasks, i.e. $f_a$ for $T_a$ and $f_b$ for $T_b$; separated clusters are desirable here. (b, d) are the results of using encoded features for irrelevant tasks, i.e. $f_b$ for $T_a$ and $f_a$ for $T_b$; mixed clusters are desirable in this case.

We evaluate the disentanglement by using the encoded features of the proposed method for the irrelevant task on Test seen, i.e. $f_b$ for $T_a$ and $f_a$ for $T_b$. The last row of Table 2 indicates that $f_b$ contains much less anatomical information for standard plane classification (OA of 64.35% using $f_b$ vs. 94.44% using $f_a$). Likewise, $f_a$ contains less information about shadow features (OA of 72.57% using $f_a$ vs. 79.05% using $f_b$). We additionally use PCA to show the embedded test data on the penultimate dense layer. The bottom row of Fig. 11 shows that the encoded features are more capable of classifying class identities in the relevant task than in the irrelevant task (e.g. (a) vs. (d)).
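A PCA projection like those in Fig. 11 can be computed with an SVD of the centered activations. This sketch assumes the penultimate-layer activations have already been extracted into an array. The function name and the two-cluster toy data are hypothetical. Well-separated clusters along the leading components correspond to the "separated clusters" desired in panels (a, c).

```python
import numpy as np

def pca_2d(features):
    """Project features onto their first two principal components via SVD.

    Returns the (n, 2) projection and the variance along each component."""
    centered = features - features.mean(axis=0)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T, s[:2] ** 2 / (len(features) - 1)

# Toy "activations": two classes separated along one of 8 feature dimensions.
rng = np.random.default_rng(0)
labels = np.repeat([0, 1], 50)
feats = rng.normal(size=(100, 8))
feats[:, 0] += np.where(labels == 0, -5.0, 5.0)

proj, var = pca_2d(feats)
print(proj.shape, var)
```

For an irrelevant task, such as $f_a$ coloured by $y_b$, a successful disentanglement would instead show the class colours mixed in this projection.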

Discussion: Acoustic shadows are caused by anatomical structures that block the propagation of sound waves, or by destructive interference. Because of this dependency between anatomy and artifacts, separating shadow features from anatomical features may decrease the performance of artifact classification (Table 2, Col.7, Proposed). However, this separation enables feature generalization, so that the model is less tied to a particular image formation and is able to tackle new combinations of entangled properties (Table 2, Col.9, Proposed). Generalization of supervised neural networks can also be achieved by extensive data collection across domains and, to a limited extent, by artificial data augmentation. Here, we propose an alternative through feature disentanglement, which requires less data collection and training effort. Fig. 11 shows PCA plots for the penultimate dense layer. Inspecting earlier layers shows that their features are still entangled, i.e. disentanglement occurs only in this last layer. This is due to the definition of our loss functions and is partly influenced by the dense layers interpreting the latent representation for classification. Finally, perfect representation disentanglement is likely infeasible because image features are rarely completely isolated in reality. In this paper we have shown that even imperfect disentanglement can provide substantial benefits for artifact-prone image classification in medical image analysis.

4 Conclusion

In this paper, we propose a novel disentanglement method to extract generalizable features within a multi-task framework. In the proposed method, classification tasks lead to encoded features that are maximally informative with respect to these tasks while the adversarial regularization forces these features to be minimally informative about irrelevant tasks, which disentangles internal representations. Experimental results on synthetic and fetal US data show that our method outperforms baseline methods for multiple tasks, especially on images with entangled properties that are unseen during training. Future work will explore the extension of this framework to multiple tasks beyond classification.

Acknowledgments.

We thank the Wellcome Trust IEH Award [102431], Nvidia (GPU donations) and Intel.

References