Multi-Instance Multi-Scale CNN for Medical Image Classification

07/04/2019 ∙ by Shaohua Li, et al.

Deep learning for medical image classification faces three major challenges: 1) the number of annotated medical images available for training is usually small; 2) regions of interest (ROIs) are relatively small, have unclear boundaries within the whole medical image, and may appear in arbitrary positions across the x, y (and also z in 3D images) dimensions, yet often only labels of whole images are annotated and localized ROIs are unavailable; and 3) ROIs in medical images often appear in varying sizes (scales). We approach these three challenges with a Multi-Instance Multi-Scale (MIMS) CNN: 1) we propose a multi-scale convolutional layer, which extracts patterns of different receptive fields with a shared set of convolutional kernels, so that scale-invariant patterns are captured by this compact set of kernels. As this layer contains only a small number of parameters, training on small datasets becomes feasible; 2) we propose a "top-k pooling" to aggregate the feature maps in varying scales from multiple spatial dimensions, allowing the model to be trained using weak annotations within the multiple instance learning (MIL) framework. Our method is shown to perform well on three classification tasks involving two 3D and two 2D medical image datasets.

1 Introduction

Training a convolutional neural network (CNN) from scratch demands a massive amount of training images. The scarcity of medical images has encouraged transfer learning, i.e., fine-tuning 2D CNN models pretrained on natural images [10]. A key difference between medical images and natural images is that regions of interest (ROIs) in medical images are relatively small, have unclear boundaries, and may appear multiple times in arbitrary positions across the x, y (and also z in 3D images) dimensions. Moreover, annotations for medical images are often "weak": only image-level labels are available, and there are no localized ROIs. In this setting, we can view each ROI as an instance in a bag of all image patches, and image-level classification falls within the Multiple-Instance Learning (MIL) framework [11, 2, 6].

Another challenge with medical images is that ROIs are often scale-invariant, i.e., visually similar patterns often appear in varying sizes (scales). With vanilla CNNs, an excessive number of convolutional kernels with varying receptive fields would be required to fully cover these patterns, which means more parameters and hence more training data. Some previous works have attempted to learn scale-invariant patterns; for example, [8] adopted image pyramids, i.e., resizing input images into different scales, processing them with the same CNN and aggregating the outputs. However, our experiments show that image pyramids perform unstably across different datasets and consume much more computational resources than vanilla CNNs (about twice the computation time and GPU memory).

This paper aims to address all the challenges above in a holistic framework. We propose two novel components: 1) a multi-scale convolutional layer (MSConv) that further processes feature maps extracted from a pretrained CNN, aiming to capture scale-invariant patterns with a shared set of kernels; 2) a top-k pooling scheme that extracts and aggregates the highest activations from feature maps in each convolutional channel (across multiple spatial dimensions in varying scales), so that the model can be trained with image-level labels only.

The MSConv layer consists of a few resizing operators (with different output resolutions), and a shared set of convolutional kernels. First a pretrained CNN extracts feature maps from input images. Then the MSConv layer resizes them to different scales, and processes each scale with the same set of convolutional kernels. Given the varying scales of the feature maps, the convolutional kernels effectively have varying receptive fields, and therefore are able to detect scale-invariant patterns. As feature maps are much smaller than input images, the computation and memory overhead of the MSConv layer is insignificant.

The MSConv layer is inspired by ROI pooling [1] and is closely related to the Trident Network [5], which uses shared convolutional kernels with different dilation rates to capture scale-invariant patterns. Its limitations include: 1) the receptive fields of dilated convolutions can only be integer multiples of the original receptive fields; 2) dilated convolutions may overlook prominent activations that fall within a dilation interval. In contrast, MSConv interpolates input feature maps to any desired size before convolution, so that the scales are more finely graded, and prominent activations are always retained for further convolution.

Kanazawa et al. [3] proposed a similar idea of resizing the input multiple times before convolution and aggregating the resulting feature maps by max-pooling. However, we observed empirically that activations in larger scales tend to dominate those in smaller scales and effectively mask them. MSConv incorporates a batchnorm layer and a learnable weight for each scale to eliminate such biases. In addition, MSConv adopts multiple kernel sizes to capture patterns in a wider range of scales.

A core operation in an MIL framework is to aggregate features or predictions from different instances (pattern occurrences). Intuitively, the most prominent patterns are usually also the most discriminative, and thus the highest activations can summarize a set of feature maps with the same semantics (i.e., in the same channel). Accordingly, we propose a top-k pooling scheme that selects the k highest activations of a group of feature maps and takes their weighted average as the aggregate feature for downstream processing. Top-k pooling extends [9] with learnable pooling weights (instead of weights specified by a hyperparameter as in [9]) and a learnable magnitude-normalization operator.

The MSConv layer and top-k pooling together comprise our Multi-Instance Multi-Scale (MIMS) CNN. To assess its performance, we evaluated 12 methods on three classification tasks: 1) classifying Diabetic Macular Edema (DME) on two 3D retinal Optical Coherence Tomography (OCT) datasets; 2) classifying Myopic Macular Degeneration (MMD) on a 2D fundus image dataset; and 3) classifying microsatellite-instable (MSI) against microsatellite-stable (MSS) tumors of colorectal cancer (CRC) patients on histology images. In most cases, MIMS-CNN achieved a higher AUROC than five baselines and six ablated models. Our experiments also verified that both the MSConv layer and top-k pooling make important contributions.

Figure 1: The Multi-Instance Multi-Scale CNN on a 3D input image. For clarity, only the j-th convolutional channel of the MSConv layer is shown.

2 Multi-Instance Multi-Scale CNN

The architecture of our Multi-Instance Multi-Scale CNN is illustrated in Fig. 1. It consists of: 1) a pretrained 2D CNN to extract primary feature maps, 2) a multi-scale convolutional (MSConv) layer to extract scale-invariant secondary feature maps, 3) a top-k pooling operator to aggregate secondary feature maps, and 4) a classifier.

2.1 Multi-Scale Convolutional Layer

Due to limited training images, a common practice in medical image analysis is to extract image features using 2D CNNs pretrained on natural images. These features are referred to as the primary feature maps. Due to the domain gap between natural and medical images, feeding the primary feature maps directly into a classifier does not always yield good results. To bridge this domain gap, we propose to use an extra convolutional layer to extract more relevant features from the primary feature maps. This layer produces the secondary feature maps.

In order to capture scale-invariant ROIs, we resize the primary feature maps into different scales before convolution. Each scale corresponds to a separate pathway, and the weights of the convolutional kernels in all pathways are tied. In effect, this convolutional layer has multiple receptive fields on the primary feature maps. We call this layer a multi-scale convolutional (MSConv) layer.

More formally, let $X$ denote the primary feature maps, $\{f_1, \ldots, f_N\}$ denote all the output channels of the MSConv layer¹, and $\{(h_1, w_1), \ldots, (h_m, w_m)\}$ denote the scale factors of the heights and widths (typically $h_i = w_i$) adopted by the $m$ resizing operators. The combination of the $i$-th scale and the $j$-th channel yields the $ij$-th secondary feature maps:

$$Y_{ij} = f_j\big(\mathrm{Interpolate}(X, h_i, w_i)\big), \tag{1}$$

where $\mathrm{Interpolate}(\cdot)$ could in theory adopt any type of interpolation; our choice is bilinear interpolation.

¹ Each convolutional kernel yields multiple channels with different semantics, so output channels are indexed separately, regardless of whether they come from the same kernel.
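To make Eq. (1) concrete, here is a minimal pure-Python sketch for a single channel: one kernel (standing in for $f_j$) is applied, with the same weights, to bilinearly resized copies of a toy primary feature map. The helper names, the align-corners interpolation convention, the unpadded ("valid") convolution and the scale factors are illustrative assumptions, not the authors' implementation; a real layer would run on multi-channel tensors in a deep-learning framework.

```python
# Sketch of Eq. (1) for one channel: Y_i = f(Interpolate(X, h_i, w_i)), with the
# SAME kernel f at every scale. All helper names are hypothetical.

def bilinear_resize(x, out_h, out_w):
    """Bilinearly resize a 2D list `x` to (out_h, out_w), align-corners convention."""
    in_h, in_w = len(x), len(x[0])

    def src(i, n_out, n_in):
        # Fractional source coordinate of output index i.
        return 0.0 if n_out == 1 else i * (n_in - 1) / (n_out - 1)

    out = []
    for i in range(out_h):
        y = src(i, out_h, in_h)
        y0 = int(y)
        y1 = min(y0 + 1, in_h - 1)
        dy = y - y0
        row = []
        for j in range(out_w):
            xc = src(j, out_w, in_w)
            x0 = int(xc)
            x1 = min(x0 + 1, in_w - 1)
            dx = xc - x0
            top = x[y0][x0] * (1 - dx) + x[y0][x1] * dx
            bottom = x[y1][x0] * (1 - dx) + x[y1][x1] * dx
            row.append(top * (1 - dy) + bottom * dy)
        out.append(row)
    return out


def conv2d_valid(x, k):
    """Cross-correlate 2D list `x` with kernel `k` (stride 1, no padding)."""
    kh, kw = len(k), len(k[0])
    return [[sum(k[a][b] * x[i + a][j + b] for a in range(kh) for b in range(kw))
             for j in range(len(x[0]) - kw + 1)]
            for i in range(len(x) - kh + 1)]


def msconv_channel(x, kernel, scales):
    """Return [Y_1, ..., Y_m]: the shared kernel applied to x resized by each scale."""
    h, w = len(x), len(x[0])
    outputs = []
    for s in scales:  # typically h_i = w_i = s
        resized = bilinear_resize(x, max(1, round(h * s)), max(1, round(w * s)))
        outputs.append(conv2d_valid(resized, kernel))  # identical weights at every scale
    return outputs


primary = [[float(i + j) for j in range(8)] for i in range(8)]  # toy 8x8 primary map
kernel = [[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]]    # one shared 3x3 kernel
secondary = msconv_channel(primary, kernel, scales=(1.0, 0.75, 0.5))
print([(len(y), len(y[0])) for y in secondary])  # [(6, 6), (4, 4), (2, 2)]
```

Because the kernel weights are shared, adding a scale adds no parameters; only the spatial size of the resized map, and hence the effective receptive field, changes.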

For more flexibility, the convolutional kernels in MSConv can also have different sizes. With $m$ resizing operators and $n$ different kernel sizes, the kernels effectively have at most $m \cdot n$ different receptive fields. The multiple resizing operators and varying kernel sizes complement each other and equip the CNN with scale invariance.
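The "at most $m \cdot n$" bound can be checked with a little arithmetic. Assuming that a kernel of size $s$ applied to maps resized by a factor $h$ covers roughly $s/h$ cells of the primary feature maps (an approximation that ignores interpolation blur), the sketch below enumerates the effective receptive fields for hypothetical scales and kernel sizes. Coincidences such as $3/0.75 = 2/0.5$ are why the bound is not always reached.

```python
from fractions import Fraction

def effective_receptive_fields(scales, kernel_sizes):
    """Approximate extent, in primary-map cells, of each (scale, kernel size) pair.

    A size-s kernel on maps resized by factor h spans about s / h primary cells
    (interpolation blur ignored). Returned as a set, so coincident values merge.
    """
    return {Fraction(s) / Fraction(h) for h in scales for s in kernel_sizes}

scales = ["1", "0.75", "0.5"]  # m = 3 hypothetical resizing factors
kernel_sizes = [2, 3]          # n = 2 hypothetical kernel sizes
fields = effective_receptive_fields(scales, kernel_sizes)
print(len(scales) * len(kernel_sizes), len(fields))  # 6 5: 3/0.75 and 2/0.5 coincide
print(sorted(float(f) for f in fields))
```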

Among the feature maps of the $j$-th channel in different scales, $\{Y_{1j}, \ldots, Y_{mj}\}$, those in larger scales contain more elements and tend to have more top activations; hence they dominate the aggregate feature and effectively mask out the feature maps in smaller scales. To remove such biases, the feature maps in different scales are passed through respective magnitude normalization operators. The magnitude normalization operator consists of a batchnorm operator $\mathrm{BN}_{ij}$ and a learnable scalar multiplier $\alpha_{ij}$, producing $\alpha_{ij}\,\mathrm{BN}_{ij}(Y_{ij})$. The scalar multiplier $\alpha_{ij}$ adjusts the importance of the $j$-th channel in the $i$-th scale and is optimized with back-propagation.
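The masking effect and its remedy can be illustrated on toy numbers. In this simplified sketch, batchnorm is reduced to standardizing each scale's flattened map with its own mean and variance (no learnable affine terms, no running statistics), the multipliers $\alpha_{ij}$ are fixed at a hypothetical initial value of 1, and the activations are made up so that the larger-scale map has both more elements and higher magnitudes.

```python
import math

def batchnorm(values, eps=1e-5):
    """Standardize a flat list of activations (simplified BN: no affine, no running stats)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

def magnitude_normalize(maps_per_scale, alphas):
    """alpha_ij * BN_ij(Y_ij) for the flattened maps of one channel j at scales i = 1..m."""
    return [[a * v for v in batchnorm(fmap)] for fmap, a in zip(maps_per_scale, alphas)]

def topk_sources(maps_per_scale, k):
    """Scale index of each of the k highest activations pooled across all scales."""
    tagged = [(v, i) for i, fmap in enumerate(maps_per_scale) for v in fmap]
    return [i for _, i in sorted(tagged, reverse=True)[:k]]

# Toy channel: the larger scale has 16 elements, the smaller one only 4.
large = [0.0] * 10 + [2.0, 2.1, 2.2, 2.3, 2.4, 2.5]
small = [0.0, 0.0, 0.0, 1.5]
raw = [large, small]
normalized = magnitude_normalize(raw, alphas=[1.0, 1.0])
print(topk_sources(raw, 4))         # [0, 0, 0, 0]: the larger scale takes every slot
print(topk_sources(normalized, 4))  # [1, 0, 0, 0]: the small-scale peak now competes
```

During training, the learnable $\alpha_{ij}$ would further rescale each scale, which this fixed-value sketch does not model.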

The MSConv layer is illustrated in Fig. 1 and the left side of Fig. 2.

Figure 2: The MSConv and top-k pooling (on the j-th channel only) in multiple scales.

2.2 Top-k Pooling

Multiple Instance Learning (MIL) views the whole image as a bag, and each ROI as an instance in the bag. Most existing MIL works [11, 6] are instance-based, i.e., they aggregate label predictions on instances to yield a bag prediction. In contrast, [2] adopted embedding-based MIL, which aggregates features (embeddings) of instances to yield bag features and then classifies the bag features. [2] showed that embedding-based MIL methods outperformed instance-based MIL baselines. Here we propose a simple but effective top-k pooling scheme that aggregates the most prominent features across a few spatial dimensions, as a new embedding-based MIL aggregation scheme.

Top-k pooling works as follows: given a set of feature maps with the same semantics, we find the k highest activation values and take their weighted average as the aggregate feature value. Intuitively, higher activation values are more important than lower ones, and thus the pooling weight should decrease as the rank goes lower. However, it may be sub-optimal to specify the weights manually as done in [9]. Hence, we adopt a data-driven approach to learn these weights automatically. More formally, given a set of feature maps $\mathcal{S}$, top-k pooling aggregates them into a single value:

$$\mathrm{Pool}_k(\mathcal{S}) = \sum_{r=1}^{k} \omega_r a_r, \tag{2}$$

where $a_1 \ge \cdots \ge a_k$ are the $k$ highest activations within $\mathcal{S}$, and $\omega_1, \ldots, \omega_k$ are nonnegative pooling weights to be learned, subject to the normalization constraint $\sum_{r=1}^{k} \omega_r = 1$. In practice, $\omega_1, \ldots, \omega_k$ are initialized with exponentially decayed values and then optimized with back-propagation.
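Eq. (2) translates directly into code. The sketch below assumes a decay factor of 0.5 for the exponential initialization and enforces the nonnegativity and sum-to-one constraint by clipping and renormalizing after each update. The text only states that the weights are initialized with exponentially decayed values and optimized by back-propagation, so both choices (and the function names) are hypothetical.

```python
def init_pooling_weights(k, decay=0.5):
    """Exponentially decayed initial weights, w_r proportional to decay**(r-1), summing to 1."""
    raw = [decay ** r for r in range(k)]
    total = sum(raw)
    return [w / total for w in raw]

def project_weights(weights):
    """Re-impose w_r >= 0 and sum(w) = 1 after a gradient step (clip, then renormalize)."""
    clipped = [max(w, 0.0) for w in weights]
    total = sum(clipped)
    if total == 0.0:
        return [1.0 / len(weights)] * len(weights)
    return [w / total for w in clipped]

def topk_pool(activations, weights):
    """Pool_k(S) = sum_r w_r * a_r over the k = len(weights) highest activations."""
    top = sorted(activations, reverse=True)[:len(weights)]  # a_1 >= ... >= a_k
    return sum(w * a for w, a in zip(weights, top))

w = init_pooling_weights(3)                      # approx. [0.571, 0.286, 0.143]
print(topk_pool([0.1, 0.9, 0.3, 0.7, 0.2], w))   # (4*0.9 + 2*0.7 + 1*0.3) / 7
```

With $k = 1$ the operator reduces to max-pooling, and with $k$ equal to the number of activations and uniform weights it reduces to average pooling, so top-k pooling sits between the two.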

An important design choice in MIL is which spatial dimensions to pool. Similar patterns, regardless of where they appear, contain similar information for classification; correspondingly, features in the same channel can be pooled together. On 2D images, we pool activations across the x and y axes of the secondary feature maps, and on 3D images we pool across the x, y and z (slice) axes. In addition, feature maps in the same channel but different scales (i.e., with different scale index i and the same channel index j) encode the same semantics and should be pooled together. Eventually, all (magnitude-normalized) feature maps in the j-th channel are pooled by top-k pooling into a single value. Then, following an N-channel MSConv layer, all feature maps are pooled into an N-dimensional feature vector that represents the whole image. As N is typically small, the downstream FC layer that classifies this feature vector has only a small number of parameters and is less prone to overfitting.
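The channel-wise pooling over scales and spatial axes can be sketched as follows. Each channel's maps at all scales (and, for 3D inputs, all slices) are flattened together and reduced by top-k pooling, yielding one value per channel; a linear layer then classifies the resulting N-dimensional vector. The nested-list data layout, the fixed pooling weights and the function names are illustrative assumptions.

```python
def flatten(x):
    """Flatten nested lists (2D maps, or stacks of 2D slices for 3D images)."""
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]

def topk_pool(values, weights):
    """Weighted average of the len(weights) highest values."""
    top = sorted(values, reverse=True)[:len(weights)]
    return sum(w * a for w, a in zip(weights, top))

def image_embedding(secondary, weights):
    """secondary[j][i] holds the normalized map of channel j at scale i.

    All scales and all spatial positions of a channel are pooled together,
    giving one value per channel, i.e. an N-dimensional image feature.
    """
    return [topk_pool(flatten(channel_maps), weights) for channel_maps in secondary]

def linear_classifier(features, W, b):
    """FC layer with len(W) * N weights plus len(b) biases."""
    return [sum(w * f for w, f in zip(row, features)) + bias for row, bias in zip(W, b)]

# N = 2 channels, m = 2 scales (a 2x2 map and a 1x1 map per channel).
secondary = [
    [[[0.1, 0.8], [0.3, 0.2]], [[0.9]]],
    [[[0.0, 0.1], [0.2, 0.4]], [[0.5]]],
]
emb = image_embedding(secondary, weights=[0.75, 0.25])  # k = 2
scores = linear_classifier(emb, W=[[1.0, -1.0], [-1.0, 1.0]], b=[0.0, 0.0])
```

The classifier's parameter count depends only on N and the number of classes, not on image size, number of slices or number of scales.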

Fig. 2 illustrates top-k pooling applied to the j-th channel feature maps in multiple scales.

3 Experiments

3.1 Datasets

Three classification tasks involving four datasets were used for evaluation.

DME classification on OCT images. The following two 3D datasets acquired by Singapore Eye Research Institute (SERI) were used:

1) Cirrus dataset: 339 3D OCT images (239 normal, 100 DME). Each image has 128 slices in 512*1024. A 67-33% training/test split was used;

2) Spectralis dataset: 197 3D OCT images (60 normal, 137 DME). Each image has multiple slices in 497*768. A 50-50% training/test split was used;

MMD classification on fundus images:

3) MMD dataset (acquired by SERI): 19,272 2D images (11,924 healthy, 631 MMD) in 900*600. A 70-30% training/test split was used.

MSI/MSS classification on CRC histology images:

4) CRC-MSI dataset [4]: 93,408 2D training images (46,704 MSS, 46,704 MSI) in 224*224. 98,904 test images (70,569 MSS, 28,335 MSI) also in 224*224.

3.2 Compared Methods

MIMS-CNN, 5 baselines and 6 ablated models were compared. Unless otherwise specified, all methods used the ResNet-101 model (without the FC layer) pretrained on ImageNet for feature extraction, and top-k pooling for feature aggregation.

MI-Pre. The ResNet feature maps are pooled by top-k pooling and classified.

Pyramid MI-Pre. Input images are rescaled to several scales relative to their original sizes before being fed into the MI-Pre model.

MI-Pre-Conv. The ResNet feature maps are processed by an extra convolutional layer and aggregated by top-k pooling before classification. It is almost the same as the model in [6], except that [6] does patch-level classification and aggregates patch predictions to obtain image-level classification.

MIMS. The MSConv layer has 3 resizing operators that resize the primary feature maps to three different scales. Two groups of kernels of different sizes were used.

MIMS-NoResizing. It is an ablated MIMS-CNN with all resizing operators removed. This is to evaluate the contribution of the resizing operators.

Pyramid MIMS. It is an ablated MIMS-CNN with all resizing operators removed, where multi-scaledness is instead pursued with input image pyramids of several scales. The MSConv kernels are configured identically to those of MIMS.

MI-Pre-Trident [5]. It extends MI-Pre-Conv with multiple dilation factors.

SI-CNN [3]. It is an ablated MIMS-CNN with the batchnorms and scalar multipliers removed from the MSConv layer.

FeatPyra-4,5. It is a feature pyramid network [7] that extracts features from conv4_x and conv5_x in ResNet-101, processes each set of features with a respective convolutional layer, and classifies the aggregate features.

ResNet34-scratch. It is a ResNet-34 model trained from scratch.

MIMS-patchcls and MI-Pre-Conv-patchcls. They are ablated MIMS and MI-Pre-Conv, respectively, evaluated on 3D OCT datasets. They classify each slice, and average slice predictions to obtain image-level classification.

3.3 Results

Table 1 lists the AUROC scores (averaged over three independent runs) of the 12 methods on the four datasets. All methods that add an extra convolutional layer on top of a pretrained model, except Pyramid MIMS, achieved average AUROCs above 0.94. The benefit of using pretrained models is confirmed by the gap between ResNet34-scratch (average 0.731) and most pretrained methods. The two image-pyramid methods performed substantially worse on some datasets (e.g., Pyramid MI-Pre on Spectralis and Pyramid MIMS on CRC-MSI), even though they consumed about twice as much computational time and GPU memory as the other methods. MIMS-CNN achieved the best or tied-best AUROC on all four datasets and the highest average AUROC (0.965).

Methods Cirrus Spectralis MMD CRC-MSI Avg.
MI-Pre 0.574 0.906 0.956 0.880 0.829
Pyramid MI-Pre 0.638 0.371 0.965 0.855 0.707
MI-Pre-Conv 0.972 0.990 0.961 0.870 0.948
MIMS-NoResizing 0.956 0.975 0.961 0.879 0.942
Pyramid MIMS 0.848 0.881 0.966 0.673 0.842
MI-Pre-Trident 0.930 1.000 0.966 0.897 0.948
SI-CNN 0.983 1.000 0.972 0.880 0.959
FeatPyra-4,5 0.959 0.991 0.970 0.888 0.952
ResNet34-scratch 0.699 0.734 0.824 0.667 0.731
MIMS 0.986 1.000 0.972 0.901 0.965
MIMS-patchcls 0.874 0.722 / / /
MI-Pre-Conv-patchcls 0.764 0.227 / / /
Table 1: Performance (in AUROC) of 12 methods on four image datasets.

The inferior performance of the two -patchcls models demonstrates the advantage of top-k pooling for MIL. To further investigate its effectiveness, we trained MIMS-CNN on Cirrus with seven MIL aggregation schemes: average-pooling (mean), max-pooling (max), top-k pooling with four different values of k, and an instance-based MIL scheme: max-pooling over slice predictions (max-inst).

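As can be seen in Table 2, the other three aggregation schemes fell behind all top-k schemes, and the model tends to perform slightly better as k increases. The instance-based max-inst scheme was outperformed by every top-k (embedding-based) scheme, which is consistent with the finding of [2] that embedding-based MIL outperforms instance-based MIL, although the simpler embedding-based mean and max schemes performed worse than max-inst.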

Scheme            mean    max     max-inst   top-k (k increasing from left to right)
AUROC on Cirrus   0.829   0.960   0.975      0.980   0.980   0.986   0.986
Table 2: Performance of seven MIL aggregation schemes on the Cirrus dataset.
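To clarify how the compared aggregation schemes differ, the sketch below contrasts embedding-based pooling (mean, max or top-k over instance features, followed by a single classification) with the instance-based max-inst scheme (classify each slice, then take the maximum prediction). The logistic classifier, its weights and the toy instance features are hypothetical; the sketch shows only the data flow, not the reported AUROCs.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def classify(feature, w, b):
    """Hypothetical logistic classifier shared by all schemes."""
    return sigmoid(sum(wi * fi for wi, fi in zip(w, feature)) + b)

def embed(instances, scheme, weights=None):
    """Embedding-based MIL: pool each channel over instances into one bag feature."""
    bag = []
    for channel in zip(*instances):              # values of one channel across instances
        vals = sorted(channel, reverse=True)
        if scheme == "mean":
            bag.append(sum(vals) / len(vals))
        elif scheme == "max":
            bag.append(vals[0])
        elif scheme == "top-k":                  # weights: k nonnegative values summing to 1
            bag.append(sum(w * a for w, a in zip(weights, vals)))
        else:
            raise ValueError(f"unknown scheme: {scheme}")
    return bag

def max_inst(instances, w, b):
    """Instance-based MIL: classify every instance, then max-pool the predictions."""
    return max(classify(x, w, b) for x in instances)

slices = [[0.0, 1.0], [0.2, 0.0], [0.9, 0.1], [0.8, 0.0]]  # 4 instances, 2 channels
w, b = [1.0, 1.0], 0.0
for scheme in ("mean", "max"):
    print(scheme, classify(embed(slices, scheme), w, b))
print("top-k", classify(embed(slices, "top-k", weights=[0.5, 0.5]), w, b))
print("max-inst", max_inst(slices, w, b))
```

In the embedding-based schemes the classifier sees one pooled feature per bag, whereas max-inst lets a single instance decide the bag label.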

4 Conclusions

Applying CNNs to medical images faces three challenges: datasets are small, annotations are often weak, and ROIs appear in varying scales. We proposed a framework that addresses all these challenges. It consists of two novel components: 1) a multi-scale convolutional layer on top of a pretrained CNN that captures scale-invariant patterns with only a small number of parameters; and 2) a top-k pooling operator that aggregates feature maps in varying scales across multiple spatial dimensions, enabling training with weak annotations within the Multiple Instance Learning framework. Our method has been validated on three classification tasks involving four image datasets.

References

  • [1] Girshick, R.: Fast R-CNN. In: Proceedings of the 2015 IEEE International Conference on Computer Vision (ICCV). pp. 1440–1448 (2015)
  • [2] Ilse, M., Tomczak, J.M., Welling, M.: Attention-based deep multiple instance learning. In: Proceedings of the 35th International Conference on Machine Learning (ICML). pp. 2132–2141 (2018)
  • [3] Kanazawa, A., Sharma, A., Jacobs, D.W.: Locally scale-invariant convolutional neural networks. NIPS Workshop on Deep Learning and Representation Learning (2014)
  • [4] Kather, J.N.: Histological images for MSI vs. MSS classification in gastrointestinal cancer, FFPE samples, https://doi.org/10.5281/zenodo.2530835
  • [5] Li, Y., Chen, Y., Wang, N., Zhang, Z.: Scale-Aware Trident Networks for Object Detection. arXiv e-prints arXiv:1901.01892 (Jan 2019)
  • [6] Li, Z., Wang, C., Han, M., Xue, Y., Wei, W., Li, L.J., Fei-Fei, L.: Thoracic disease identification and localization with limited supervision. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (June 2018)
  • [7] Lin, T.Y., Dollár, P., Girshick, R., He, K., Hariharan, B., Belongie, S.: Feature pyramid networks for object detection. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 2117–2125 (2017)
  • [8] Rasti, R., Rabbani, H., Mehridehnavi, A., Hajizadeh, F.: Macular OCT classification using a multi-scale convolutional neural network ensemble. IEEE Transactions on Medical Imaging 37(4), 1024–1034 (April 2018)
  • [9] Shi, Z., Ye, Y., Wu, Y.: Rank-based pooling for deep convolutional neural networks. Neural Networks 83, 21–31 (2016)
  • [10] Tajbakhsh, N., Shin, J.Y., Gurudu, S.R., Hurst, R.T., Kendall, C.B., Gotway, M.B., Liang, J.: Convolutional neural networks for medical image analysis: Full training or fine tuning? IEEE Transactions on Medical Imaging 35(5), 1299–1312 (May 2016)
  • [11] Zhu, W., Lou, Q., Vang, Y.S., Xie, X.: Deep multi-instance networks with sparse label assignment for whole mammogram classification. In: Medical Image Computing and Computer Assisted Intervention (MICCAI 2017). pp. 603–611 (2017)

Appendix 0.A Model Configurations and Training Settings

For MIMS-CNN, the convolutional kernels are specified as “kernel size (number of output channels)” pairs, with separate configurations for the two OCT datasets (Cirrus and Spectralis) and for the MMD dataset. On the CRC-MSI dataset, as the images are of lower resolution, smaller convolutional kernels were adopted.

In all models, the underlying ResNet layers were fine-tuned to reduce domain gaps between ImageNet and the training data. The learning rate of the underlying layers was set as half of the top layers to reduce overfitting.

To evaluate on the OCT datasets, all models were first trained on Cirrus for 4500 iterations. On Spectralis, the trained models were then fine-tuned on the training set for 200 iterations and evaluated on the test set. When training on the Cirrus and Spectralis datasets, to increase data diversity, in each iteration a batch was formed from slices randomly chosen among the 30 central slices of the input image.
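The slice-sampling augmentation described above might look like the following. The batch size, sampling without replacement, and the helper names are assumptions (the number of slices per batch is not given here); the sampled slices are simply stacked as a 2D batch, which works because the MSConv layer only uses 2D convolutions.

```python
import random

def sample_central_slices(num_slices, batch_size, n_central=30, rng=random):
    """Pick `batch_size` distinct slice indices among the `n_central` central slices."""
    n_central = min(n_central, num_slices)
    start = (num_slices - n_central) // 2
    return sorted(rng.sample(range(start, start + n_central), batch_size))

def make_batch(volume, batch_size, rng=random):
    """volume: list of 2D slices; returns the sampled slices as one 2D batch."""
    idx = sample_central_slices(len(volume), batch_size, rng=rng)
    return [volume[i] for i in idx]

# A Cirrus volume has 128 slices, so the central 30 are indices 49..78.
print(sample_central_slices(128, batch_size=8, rng=random.Random(0)))
```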

On the CRC-MSI dataset, there is a significant domain gap between the training and test images. Hence 2% of the original test images were moved to the training set (these 2% of images are referred to as the tuning set) to reduce the domain gap. Specifically, all models were trained on the training set for one epoch (LR=0.01), and then fine-tuned on the tuning set for two epochs (LR=0.01, 0.004).

When working on 3D images such as Cirrus and Spectralis, as the MSConv layer only involves 2D convolutions, all slices in a 3D image can be conveniently arranged as a 2D batch for faster processing.

Appendix 0.B A Possible Explanation of Why Image Pyramid Failed

We randomly selected 100 images from the CRC-MSI dataset [4], scaled them with three scale factors (2, 0.75 and 0.5), and compared the resulting ResNet-101 features with those of the original images. The average Pearson correlations are listed in Table 3.

Scale Pearson correlation
2 0.261
0.75 0.451
0.5 0.257

Table 3: Pearson correlations of ResNet-101 features of the scaled images with the original image features.
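The correlation analysis can be reproduced with a short helper. This sketch assumes each image is summarized by a flat ResNet-101 feature vector of the same length at every scale (for example after global pooling); which layer and pooling the authors used is not specified here, so treat that as an assumption.

```python
import math

def pearson(a, b):
    """Pearson correlation between two equal-length feature vectors."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(va * vb)

def average_correlation(original_feats, scaled_feats):
    """Mean Pearson correlation over image pairs (original vs. rescaled features)."""
    rs = [pearson(o, s) for o, s in zip(original_feats, scaled_feats)]
    return sum(rs) / len(rs)
```

Computing `average_correlation` once per scale factor over the 100 sampled images would yield one row of Table 3.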

One can see that the Pearson correlation dropped to around 0.26 when the images were resized to half or double their original sizes. This indicates that ResNet extracts very different features from them, although they are the same images in different scales. This feature decorrelation may prevent the downstream convolutional layer from learning scale-invariant patterns with the same kernels. Instead, the downstream convolutional layer would need more kernels to memorize the different features at each scale, and thus more training data. This observation may explain why image pyramids performed worse in our experiments.

Appendix 0.C Gradient-based Localization and Visualization

In practice, doctors are keen to understand how a computer-aided diagnosis system reaches a certain diagnosis, so that they can make reliable decisions. To this end, we adopt a variant of gradient-based (gradient × input) visualization methods to locate and visualize suspicious slices and regions for doctors to examine.

Suppose we have trained a classifier on medical images. To perform the visualization, we assign a nonzero loss to the class of interest c and a loss of 0 to the other classes, and then do a backward pass from the classifier to obtain the gradient at a pre-specified feature-map layer. When the input image is 3D, the classifier first determines which slices belong to the target class. For each of these slices, the back-propagation-based visualization method generates a heatmap that quantifies the contribution of each pixel in the input 2D slice, and the heatmap is overlaid on the 2D slice as the output visualization image.

Denote the input slice as $I$. The heatmap algorithm proceeds as follows:

  1. Collect the gradient tensor at the $l$-th layer, denoted as $G = (G_{cij})$, where $c$ indexes the feature channel, and $i$ and $j$ index the height and width, respectively. The gradient tensor has the same size as the input feature tensor to the $l$-th layer, which is denoted as $A = (A_{cij})$. Then multiply the gradient tensor and the input feature tensor element-wise and sum out the feature channel dimension to get the input contribution matrix $P$, whose element in the $i$-th row and the $j$-th column is:

    $$P_{ij} = \sum_{c} G_{cij}\, A_{cij}.$$

  2. Set all negative values in the input contribution matrix $P$ to 0, rescale the values into a fixed range, and round them to integers. Accordingly, $P$ is converted to a non-negative contribution matrix, denoted as $P^{+}$, which is then interpolated to a heatmap $H$ of the same size as the original image $I$. $H$ highlights some image areas, and the highlight color intensities are proportional to the contributions of these areas to the final classification confidence of class $c$.

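  3. A weighted sum of $I$ and $H$ is taken as the visualization of the basis on which the classifier decides that the input image is in class $c$:

    $$V = \lambda_1 I + \lambda_2 H,$$

    where $\lambda_1$ and $\lambda_2$ are blending weights.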
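Steps 1 to 3 can be sketched in plain Python for a single slice. The sketch assumes the gradient tensor and the input feature tensor are already available as nested lists indexed [c][i][j]; the integer range [0, 255], nearest-neighbour upsampling, and a convex blend with a weight `lam` = 0.5 are hypothetical choices, since the text does not fix them. A real implementation would also map the heatmap through a color map before blending.

```python
def contribution_matrix(grad, feat):
    """Step 1: P[i][j] = sum_c G[c][i][j] * A[c][i][j] (gradient x input, channels summed)."""
    C, H, W = len(grad), len(grad[0]), len(grad[0][0])
    return [[sum(grad[c][i][j] * feat[c][i][j] for c in range(C)) for j in range(W)]
            for i in range(H)]

def to_nonnegative_int(P, max_value=255):
    """Step 2a: zero out negatives, scale by the maximum into [0, max_value], round."""
    clipped = [[max(v, 0.0) for v in row] for row in P]
    peak = max(max(row) for row in clipped)
    if peak == 0:
        return [[0 for _ in row] for row in clipped]
    return [[round(v / peak * max_value) for v in row] for row in clipped]

def upsample_nearest(M, out_h, out_w):
    """Step 2b: resize P+ to the slice size (nearest neighbour as a stand-in)."""
    h, w = len(M), len(M[0])
    return [[M[i * h // out_h][j * w // out_w] for j in range(out_w)] for i in range(out_h)]

def overlay(image, heatmap, lam=0.5):
    """Step 3: V = (1 - lam) * I + lam * H, a weighted sum of slice and heatmap."""
    return [[(1 - lam) * x + lam * h for x, h in zip(ri, rh)]
            for ri, rh in zip(image, heatmap)]

# Toy example: C = 2 channels on a 2x2 feature grid, 4x4 input slice.
G = [[[1.0, -1.0], [-1.0, 2.0]], [[1.0, 1.0], [0.0, 0.0]]]
A = [[[1.0, 1.0], [1.0, 1.0]], [[2.0, 3.0], [0.0, 0.0]]]
P = contribution_matrix(G, A)            # [[3.0, 2.0], [-1.0, 2.0]]
H = upsample_nearest(to_nonnegative_int(P), 4, 4)
V = overlay([[0.0] * 4 for _ in range(4)], H)
```

The negative contribution at P[1][0] is zeroed in step 2, so that region receives no highlight.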

The data flow of the visualization algorithm is illustrated in Fig. 3. A heatmap of an OCT slice is presented in Fig. 4. We can see that the Diabetic Macular Edema (DME) cyst is precisely localized in the OCT slice.

Figure 3: The flowchart of the gradient-based visualization algorithm.
Figure 4: Visualization result of an OCT slice with DME.
