TandemNet: Distilling Knowledge from Medical Images Using Diagnostic Reports as Optional Semantic References

08/10/2017 ∙ by Zizhao Zhang, et al. ∙ 0

In this paper, we use the semantic knowledge in the diagnostic reports of medical images to guide network training and to provide an interpretable prediction mechanism through our proposed multimodal neural network, TandemNet. Inside TandemNet, a language model represents the report text and cooperates with the image model in a tandem scheme. We propose a novel dual-attention model that facilitates high-level interactions between visual and semantic information and effectively distills useful features for prediction. In the testing stage, TandemNet can make accurate image predictions with an optional report text input. It also interprets its predictions by producing attention on informative image regions and text pieces, and by further generating diagnostic report paragraphs. On a dataset of pathological bladder cancer images and their diagnostic reports (BCIDR), extensive experiments demonstrate that our method effectively learns and integrates knowledge from multiple modalities and significantly outperforms the compared baselines.


1 Introduction

In medical image understanding, convolutional neural networks (CNNs) have gradually become the paradigm for various problems [1]. Training CNNs to diagnose medical images primarily follows pure engineering trends in an end-to-end fashion. However, the behavior of CNNs during training and testing is difficult to interpret and justify. In clinical practice, domain experts teach learners by explaining the findings and observations that lead to a disease decision, rather than leaving learners to find clues from images themselves.

Inspired by this fact, in this paper, we explore the use of semantic knowledge of medical images from their diagnostic reports to provide explanatory support for CNN-based image understanding. The proposed network learns to provide interpretable diagnostic predictions in the form of attention and natural language descriptions. The diagnostic report is a common type of medical record in clinics, which comprises semantic descriptions of the observed biological features. Recently, we have witnessed rapid development in multimodal deep learning research [2, 3]. We believe the joint study of multimodal data is essential for intelligent computer-aided diagnosis. However, little related work exists [4, 5].

To take advantage of the language modality, we propose a multimodal network that jointly learns from medical images and their diagnostic reports. Semantic information interacts with visual information to improve image understanding by teaching the network to distill informative features. We propose a novel dual-attention model to facilitate such high-level interaction. The training stage uses both images and texts. In the testing stage, our network can take an image and provide an accurate prediction with an optional (i.e. with or without) text input. The language and image models inside our network therefore cooperate with one another in a tandem scheme, so that the prediction process is driven either by a single modality (images) or by two (images and text). We refer to our proposed network as TandemNet. Figure 1 illustrates the overall framework.

To validate our method, we collaborated with a pathologist to collect the BCIDR dataset. Extensive experimental studies on BCIDR demonstrate the advantages of TandemNet. Furthermore, by coupling visual features with the language model and fine-tuning the network using backpropagation through time (BPTT), TandemNet learns to automatically generate diagnostic reports. The rich outputs (i.e. attention and reports) of TandemNet are valuable: they provide explanations and justifications for its diagnostic prediction and make this process interpretable to pathologists.

Figure 1: The illustration of the TandemNet.

2 Method

CNN for image modeling We adopt the (new pre-activated) residual network (ResNet) [6] as our image model. The identity mapping in ResNet significantly improves the network generalization ability. There are many architecture variants of ResNet. We adopt the wide ResNet (WRN) [7], which has shown better performance and higher efficiency with far fewer layers. It also makes the network scalable (in number of parameters) by adjusting a widen factor (i.e. the number of feature-map channels) and depth. We extract the output of the layer before average pooling as our image representation, denoted as . The input image size is , so . depends on the widen factor.

LSTM for language modeling We adopt Long Short-Term Memory (LSTM) [8] to model diagnostic report sentences. LSTM improves on vanilla recurrent neural networks (RNNs) for natural language processing and is also widely used for multimodal applications such as image captioning [9, 2]. It has a sophisticated unit design, which enables long-term dependency and greatly reduces the gradient vanishing problem in RNNs [10]. Given a sequence of words, the LSTM reads the words one at a time and maintains a memory state and a hidden state. At each time step, the LSTM updates them by

(1)

where the input word is computed by first encoding it as a one-hot vector and then multiplying it by a learned word embedding matrix.
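The update described above can be illustrated with the standard LSTM formulation. The following is a minimal NumPy sketch, not the exact form of Eq. (1): the gate ordering, the stacked weight matrix `W`, and all sizes (`vocab_size`, `embed_dim`, `hidden_dim`) are assumptions chosen for illustration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One standard LSTM update (assumed formulation, not the paper's Eq. (1)).

    W has shape (4*H, D+H) and stacks the input, forget, output and
    candidate transforms; b has shape (4*H,).
    """
    H = h_prev.shape[0]
    z = W @ np.concatenate([x_t, h_prev]) + b
    i = sigmoid(z[0:H])          # input gate
    f = sigmoid(z[H:2 * H])      # forget gate
    o = sigmoid(z[2 * H:3 * H])  # output gate
    g = np.tanh(z[3 * H:4 * H])  # candidate memory
    c_t = f * c_prev + i * g     # new memory state
    h_t = o * np.tanh(c_t)       # new hidden state
    return h_t, c_t

# Hypothetical sizes, for illustration only.
rng = np.random.default_rng(0)
vocab_size, embed_dim, hidden_dim = 10, 8, 6
E = rng.normal(scale=0.1, size=(embed_dim, vocab_size))  # word embedding matrix
W = rng.normal(scale=0.1, size=(4 * hidden_dim, embed_dim + hidden_dim))
b = np.zeros(4 * hidden_dim)

word_ids = [3, 1, 7]  # a toy word sequence
h = np.zeros(hidden_dim)
c = np.zeros(hidden_dim)
for w in word_ids:
    one_hot = np.zeros(vocab_size)
    one_hot[w] = 1.0
    x = E @ one_hot  # one-hot times embedding matrix selects column w of E
    h, c = lstm_step(x, h, c, W, b)
```

Multiplying the embedding matrix by a one-hot vector is equivalent to a column lookup, which is how embedding layers are usually implemented in practice.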

The hidden state is a vector encoding of sentences. How it is used varies across problems. For example, in image captioning, a multilayer perceptron (MLP) is used to decode it as a predicted word at each time step. In machine translation [11], all hidden states could be used. A medical report is more formal than a natural image caption. It usually describes multiple types of biological features, structured as a series of sentences. It is important to represent all feature descriptions while maintaining the variety and independence among them. To this end, we extract the hidden state of every feature description (in our implementation, this is achieved by adding a special token at the end of each sentence beforehand and extracting the hidden states at all the placed tokens). In this way, we obtain a text representation matrix with one hidden-state vector per type of feature description. This strategy has a further advantage: it enables the network to adaptively select useful semantic features and determine the importance of each feature to the disease labels (as shown in experiments).
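The token-placement step described above can be sketched in plain Python. The marker string `<eos>`, the helper names, and the example report are all made up for illustration; the stand-in `hidden_states` list simply takes the place of real per-token LSTM outputs.

```python
EOS = "<eos>"  # hypothetical special token; the source does not name it

def add_sentence_markers(report):
    """Tokenize a report and append the special token after each sentence."""
    tokens = []
    for sentence in report.split("."):
        words = sentence.split()
        if words:  # skip the empty piece after the final period
            tokens.extend(words + [EOS])
    return tokens

def marker_positions(tokens):
    """Indices at which hidden states would be read out."""
    return [i for i, t in enumerate(tokens) if t == EOS]

# Made-up example report with three feature descriptions.
report = "Mild pleomorphism is present. Cell crowding is moderate. Polarity is lost."
tokens = add_sentence_markers(report)
positions = marker_positions(tokens)

# Stand-in for per-token LSTM hidden states (one vector per token).
hidden_states = [[float(i)] for i in range(len(tokens))]
text_representation = [hidden_states[i] for i in positions]
```

Reading out one hidden state per sentence keeps each feature description as a separate piece, so a later attention step can weight the descriptions individually.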

Dual-attention model The attention mechanism [12, 11] is an active topic in both the computer vision and natural language communities. Briefly, it gives networks the ability to generate attention on parts of the inputs (like visual attention in the brain cortex), which is achieved by computing a context vector with attended information preserved.

Different from most existing approaches that study attention on images or text, given the image representation and the report representation (the two matrices are first embedded through a convolutional layer with Tanh), our dual-attention model can generate attention on important image regions and sentence parts simultaneously. Specifically, we define the attention function to compute a piece-wise weight vector as

(2)

where the weight vector has individual weights for visual and semantic features. The attention function is specifically defined as follows:

(3)

where and are parameters to be learned to compute , and . and are vectors with all elements to be one. denotes the global average-pooling operator on the last dimension of and . denotes the concatenation operator. Finally, we obtain a context vector by

(4)

In our formulation, the computation of image and text attention is mutually dependent and involves high-level interactions. The image attention is conditioned on the global text vector, and the text attention is conditioned on the global image vector. When computing the weight vector, both sources of information contribute. We also considered alternative configurations: computing two sets of attention scores with two separate attention functions, and then either concatenating them under a single softmax or applying two separate softmax functions. Both configurations underperform ours. We conclude that, among the configurations we tried, ours is the most effective for letting the visual and semantic information interact with each other.

Intuitively, our dual-attention mechanism encourages better alignment of visual information with semantic information piecewise, which thereby improves the ability of TandemNet to discriminate useful features for attention computation. We will validate this experimentally.
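The interaction described in this section can be sketched as follows. This is one plausible reading of the prose, not a reproduction of Eqs. (2)-(4): the symbol names (`V`, `S`, `W_v`, `W_vs`, `W_s`, `W_sv`, `w`), the additive conditioning on the pooled vector of the other modality, and all sizes are assumptions. What the sketch does take from the text is the global average pooling, the concatenation of image and text scores, and the single softmax over all pieces.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def dual_attention(V, S, p):
    """V: (d, G) image pieces, S: (d, N) text pieces, both already embedded."""
    v_bar = V.mean(axis=1)  # global image vector (average pooling)
    s_bar = S.mean(axis=1)  # global text vector (average pooling)
    # Image scores are conditioned on the global text vector ...
    e_v = p["w"] @ np.tanh(p["W_v"] @ V + (p["W_vs"] @ s_bar)[:, None])
    # ... and text scores on the global image vector.
    e_s = p["w"] @ np.tanh(p["W_s"] @ S + (p["W_sv"] @ v_bar)[:, None])
    # One softmax over the concatenated scores couples the two modalities.
    alpha = softmax(np.concatenate([e_v, e_s]))
    G = V.shape[1]
    alpha_v, alpha_s = alpha[:G], alpha[G:]
    context = V @ alpha_v + S @ alpha_s  # attended context vector, shape (d,)
    return context, alpha_v, alpha_s

# Hypothetical sizes, for illustration only.
rng = np.random.default_rng(0)
d, k, G, N = 5, 4, 9, 3
params = {name: rng.normal(scale=0.5, size=(k, d))
          for name in ("W_v", "W_vs", "W_s", "W_sv")}
params["w"] = rng.normal(size=k)
V = rng.normal(size=(d, G))
S = rng.normal(size=(d, N))
context, alpha_v, alpha_s = dual_attention(V, S, params)
```

Because the softmax runs over image and text pieces jointly, increasing the weight of a text piece necessarily decreases the total weight available to image pieces, which is one way the two modalities can compete for attention.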

Prediction module To improve the model generalization, we propose two effective techniques for the prediction module of the dual-attention model.

1) Visual skip-connection The probability of a disease label is computed as

(5)

The image feature skips the dual-attention model and is directly added onto (see Figure 1). During backpropagation, this skip-connection directly passes gradients from the loss layer to the CNN, which prevents possible gradient vanishing in the dual-attention model from obstructing CNN training.

2) Stochastic modality adaptation We propose to stochastically “abandon” text information during training. This strategy generalizes TandemNet to make accurate predictions when text is absent. Our proposed strategy is inspired by Dropout and the stochastic depth network [13], which are effective for model generalization. Specifically, we define a drop rate as the probability to remove (zero-out) the text part during the entire network training stage. Thus, based on the principle of Dropout, will be scaled by if text is given in testing.

The effects of these two techniques are discussed in experiments.
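The two techniques can be sketched together in NumPy. Several details here are assumptions rather than statements of the method: that the skip-connection adds the average-pooled image feature to the context vector, that the text representation is the quantity zeroed out and rescaled, that the test-time scale is one minus the drop rate (the usual Dropout convention), and that a single linear layer stands in for the MLP. All names and sizes are hypothetical.

```python
import numpy as np

def adapt_text(S, drop_rate, training, rng):
    """Stochastically zero out the text input during training.

    At test time with text present, scale by the keep probability
    (assumed to be 1 - drop_rate, following the Dropout convention).
    """
    if training:
        if rng.random() < drop_rate:
            return np.zeros_like(S)
        return S
    return S * (1.0 - drop_rate)

def predict_proba(context, V, W_out, b_out):
    """Visual skip-connection: pooled image feature bypasses the attention."""
    fused = context + V.mean(axis=1)  # assumed form of the skip-connection
    logits = W_out @ fused + b_out    # a linear layer stands in for the MLP
    e = np.exp(logits - logits.max())
    return e / e.sum()

# Hypothetical sizes, for illustration only.
rng = np.random.default_rng(0)
d, G, N, num_classes = 5, 9, 3, 4
V = rng.normal(size=(d, G))
S = rng.normal(size=(d, N))
W_out = rng.normal(size=(num_classes, d))
b_out = np.zeros(num_classes)

S_dropped = adapt_text(S, drop_rate=1.0, training=True, rng=rng)
S_kept = adapt_text(S, drop_rate=0.0, training=True, rng=rng)
S_test = adapt_text(S, drop_rate=0.5, training=False, rng=rng)
proba = predict_proba(np.zeros(d), V, W_out, b_out)  # context of zeros: image only
```

With the skip-connection in place, the classifier still receives a non-trivial input when the attended context is zero, which is what makes prediction without text possible in this sketch.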

Method          Accuracy (%)
                w/o text   w/ text

WRN16-4         75.4       -
ResNet18-TL     79.4       -

TandemNet-WVS   79.4       85.6
TandemNet       82.4       89.9
TandemNet-TL    84.9       88.6

Table 1: The quantitative evaluation (averaged on 3 trials). The first block shows standard CNNs, so text is irrelevant.

Figure 2: The confusion matrices of two compared methods ResNet18-TL and TandemNet-TL (w/o text) in Table 1.

3 Experiments

Dataset To collect the BCIDR dataset, whole-slide images were taken using a 20X objective from hematoxylin and eosin (HE) stained sections of bladder tissue extracted from a cohort of 32 patients at risk of a papillary urothelial neoplasm. From these slides, 1,000 RGB images were extracted randomly close to urothelial regions (each patient’s slide yields a slightly different number of images). For each of these images, the pathologist then provided a paragraph describing the disease state. Each paragraph addresses five types of cell appearance features, namely the state of nuclear pleomorphism, cell crowding, cell polarity, mitosis, and prominence of nucleoli (thus ). A conclusion is then decided for each image-text pair from four classes, i.e. normal tissue, low-grade (papillary urothelial neoplasm of low malignant potential) carcinoma, high-grade carcinoma, and insufficient information. Following the same procedure, four doctors (not experts in bladder cancer) wrote four additional descriptions for each image. They also referred to the pathologist’s description to ensure the accuracy of their annotations. Thus there are five ground-truth reports per image and image-text pairs in total. Each report varies in length between 30 and 59 words. We randomly split (6/32) of patients including samples as the testing set and the remaining of patients including samples ( as the validation set for model selection) for training. We subtract the data RGB mean and augment through clip, mirror and rotation.
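Splitting at the patient level, as described above, keeps all images from one patient on the same side of the split. A minimal sketch using only the standard library follows; the function name, the seed, and the toy sample list (three images per patient) are hypothetical, while the 6-of-32 patient ratio is taken from the text.

```python
import random

def split_by_patient(samples, num_test_patients, seed=0):
    """Split (patient_id, item) pairs so that no patient spans both sets."""
    patients = sorted({pid for pid, _ in samples})
    rng = random.Random(seed)
    test_patients = set(rng.sample(patients, num_test_patients))
    train = [s for s in samples if s[0] not in test_patients]
    test = [s for s in samples if s[0] in test_patients]
    return train, test, test_patients

# Toy data: 32 patients with 3 placeholder images each.
samples = [(pid, f"img_{pid}_{k}") for pid in range(32) for k in range(3)]
train, test, test_patients = split_by_patient(samples, num_test_patients=6)
```

A per-image random split would instead let near-duplicate tissue from one slide appear in both training and testing, which tends to inflate test accuracy.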

Implementation details Our implementation is based on Torch7. We use a small WRN with and (denoted as WRN16-4), resulting in M parameters and . We use dropout with after each convolution. We use for LSTM, , and . We use SGD with a learning rate for the CNN (used likewise for standard CNN training for comparison) and Adam with for the dual-attention model, which are multiplied by per epoch. We also limit the gradient magnitude of the dual-attention model to by normalization [10].
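Limiting the gradient magnitude by normalization is commonly implemented as global-norm clipping, in the spirit of [10]. The sketch below assumes that interpretation; the helper name and the threshold value are hypothetical, since the actual threshold is not recoverable from the text above.

```python
import numpy as np

def clip_grad_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their global L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float((g ** 2).sum()) for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

# Toy gradients with global norm 5; max_norm=1.0 is an illustrative value.
grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]
clipped, norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(float((g ** 2).sum()) for g in clipped))
```

Rescaling all gradients by the same factor preserves the update direction and only shortens the step.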

Diagnostic prediction evaluation Table 1 and Figure 2 show the quantitative evaluation of TandemNet. For comparison with CNNs, we train a WRN16-4 and also a ResNet18 (11M parameters) pre-trained on ImageNet (provided by https://github.com/facebook/fb.resnet.torch). We found that transfer learning is beneficial. To test this effect in TandemNet, we replace WRN16-4 with a pre-trained ResNet18 (TandemNet-TL). As can be observed, TandemNet and TandemNet-TL significantly improve over WRN16-4 and ResNet18-TL, respectively, when only images are provided. Over multiple trials, we observe that TandemNet-TL slightly underperforms TandemNet when text is provided. We hypothesize that a model pre-trained on a completely different natural image domain is relatively hard to align with medical reports in the dual-attention model during fine-tuning. From Figure 2, high grade (label id 3) is more likely to be misclassified as low grade (2), and some insufficient information (4) is confused with normal (1).

Figure 3: Left: The accuracy with varying drop rates. Right: The averaged text attention per feature type (and overall) to each disease label. The feature type is specified in the text of dataset introduction (in order).
Figure 4: The t-SNE visualization of the MLP input. Each point is a test sample. The embeddings with text (right) result in a better distribution.

We analyze the text drop rate in Figure 3 (left). When the drop rate is low, the model relies excessively on text information, so it achieves low accuracy without text. When the drop rate is high, the text cannot be well adapted, resulting in decreased accuracy with or without text. The drop rate of performs best and is therefore used in this paper. As illustrated in Figure 3, we found that the classification of text is easier than that of images, so its accuracy is much higher. However, please note that the primary aim of this paper is to use text information only at the training stage; at the testing stage, the goal is to accurately classify images without text.

In Eq. (5), one question that may arise is whether, when testing without text, the useful features come merely from the CNN rather than from the dual-attention model (since the removal (zero-out) of could possibly destroy the attention ability). To validate the actual role of , we remove the visual skip-connection and train the model (denoted as TandemNet-WVS in Table 1), and it improves WRN16-4 by without text. The qualitative evaluation below also validates the effectiveness of the dual-attention model. Additionally, we use the t-SNE (t-distributed Stochastic Neighbor Embedding) dimensionality reduction technique to examine the input of the MLP in Figure 4.

Figure 5: From left to right: Test images (the bottom shows disease labels), pathologist’s annotations, visual attention w/o text, visual attention and corresponding text attention (the bottom shows text inputs). Best viewed in color.

Attention analysis We visualize the attention weights to show how TandemNet captures image and text information to support its prediction (the image attention map is computed by upsampling the weights of to the image space). To validate the visual attention, we asked the pathologist to highlight the regions of some test images they consider important, without showing them our results beforehand. Figure 5 illustrates the performance. Our attention maps show surprisingly high consistency with the pathologist’s annotations. The attention without text is also fairly promising, although it is less accurate than the results with text. Therefore, we can conclude that TandemNet effectively uses semantic information to improve visual attention and largely maintains this attention capability even when the semantic information is not provided. The text attention is shown in the last column of Figure 5. We can see that our text attention is quite selective, picking up only useful semantic features.
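The upsampling step mentioned above can be sketched with NumPy. The interpolation method is not stated in the text, so nearest-neighbor block replication is an assumption here, as are the helper name, the grid size, and the image size used in the example.

```python
import numpy as np

def upsample_attention(alpha_v, grid_hw, image_hw):
    """Expand per-region attention weights into an image-sized heat map.

    Uses nearest-neighbor block replication (an assumed choice) and
    rescales the result to [0, 1] for display.
    """
    gh, gw = grid_hw
    ih, iw = image_hw
    if ih % gh != 0 or iw % gw != 0:
        raise ValueError("image size must be a multiple of the grid size")
    grid = np.asarray(alpha_v, dtype=float).reshape(gh, gw)
    heat = np.kron(grid, np.ones((ih // gh, iw // gw)))
    return (heat - heat.min()) / (heat.max() - heat.min() + 1e-12)

# Toy example: a 2x2 attention grid expanded to a 4x4 map.
alpha_v = np.array([0.0, 1.0, 2.0, 3.0]) / 6.0
heat = upsample_attention(alpha_v, grid_hw=(2, 2), image_hw=(4, 4))
```

The resulting map can be overlaid on the input image; a smoother interpolation such as bilinear would give softer region boundaries.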

Furthermore, the text attention statistics over the dataset provide particular insights into the pathologists’ diagnosis. We can investigate which feature contributes the most to which disease label (see Figure 3 (right)). For example, nuclear pleomorphism (feature type 1) shows small effects on the low-grade disease label, while cell crowding (2) has large effects on high-grade. We can explain this text attention by looking closely at the images in Figure 5: high-grade images have an obviously high degree of cell crowding. Moreover, this result strongly demonstrates the successful image-text alignment of our dual-attention model.

Image report generation We fine-tune TandemNet using BPTT as an extra supervision and use the visual feature as the input of LSTM at the first time step (we freeze the CNN for the whole training and the dual-attention model for the first epochs, and then fine-tune with a smaller learning rate, ). We direct readers to [9] for details on LSTM training for image captioning. Figure 6 shows our promising results compared with the pathologist’s descriptions. We leave the full report generation task as a future study [5].

Figure 6: The pathologist’s annotations are in black and the automatic results of TandemNet are in green, which accurately describe the semantic concepts.

4 Conclusion

This paper proposes a novel multimodal network, TandemNet, which can jointly learn from medical images and diagnostic reports and predict in an interpretable scheme through a novel dual-attention mechanism. Sufficient and comprehensive experiments on BCIDR demonstrate that TandemNet is favorable for more intelligent computer-aided medical image diagnosis.

References

  • [1] Greenspan, H., van Ginneken, B., Summers, R.M.: Guest editorial deep learning in medical imaging: Overview and future promise of an exciting new technique. TMI 35(5) (2016) 1153–1159
  • [2] Vinyals, O., Toshev, A., Bengio, S., Erhan, D.: Show and tell: A neural image caption generator. In: CVPR. (2015) 3156–3164
  • [3] Xu, T., Zhang, H., Huang, X., Zhang, S., Metaxas, D.N.: Multimodal deep learning for cervical dysplasia diagnosis. In: MICCAI. (2016) 115–123
  • [4] Shin, H.C., Roberts, K., Lu, L., Demner-Fushman, D., Yao, J., Summers, R.M.: Learning to read chest x-rays: Recurrent neural cascade model for automated image annotation. In: CVPR. (2016) 2497–2506
  • [5] Zhang, Z., Xie, Y., Xing, F., Mcgough, M., Yang, L.: Mdnet: A semantically and visually interpretable medical image diagnosis network. In: CVPR. (2017)
  • [6] He, K., Zhang, X., Ren, S., Sun, J.: Identity mappings in deep residual networks. In: ECCV. (2016) 630–645
  • [7] Zagoruyko, S., Komodakis, N.: Wide residual networks. In: BMVC. (2016)
  • [8] Hochreiter, S., Schmidhuber, J.: Long short-term memory. Neural computation 9(8) (1997) 1735–1780
  • [9] Karpathy, A., Fei-Fei, L.: Deep visual-semantic alignments for generating image descriptions. In: CVPR. (2015) 3128–3137
  • [10] Pascanu, R., Mikolov, T., Bengio, Y.: On the difficulty of training recurrent neural networks. In: ICML. (2013) 1310–1318
  • [11] Luong, M.T., Pham, H., Manning, C.D.: Effective approaches to attention-based neural machine translation. In: EMNLP. (2015) 1412–1421
  • [12] Xu, K., Ba, J., Kiros, R., Cho, K., Courville, A., Salakhutdinov, R., Zemel, R.S., Bengio, Y.: Show, attend and tell: Neural image caption generation with visual attention. In: ICML. (2015) 2048–2057
  • [13] Huang, G., Sun, Y., Liu, Z., Sedra, D., Weinberger, K.: Deep networks with stochastic depth. In: ECCV. (2016) 646–661