Globally-Aware Multiple Instance Classifier for Breast Cancer Screening

06/07/2019 ∙ by Yiqiu Shen, et al.

Deep learning models designed for visual classification tasks on natural images have become prevalent in medical image analysis. However, medical images differ from typical natural images in many ways, such as significantly higher resolutions and smaller regions of interest. Moreover, both the global structure and local details play important roles in medical image analysis tasks. To address these unique properties of medical images, we propose a neural network that is able to classify breast cancer lesions utilizing information from both a global saliency map and multiple local patches. The proposed model outperforms the ResNet-based baseline and achieves radiologist-level performance in the interpretation of screening mammography. Although our model is trained only with image-level labels, it is able to generate pixel-level saliency maps that provide localization of possible malignant findings.


1 Introduction

As the second leading cause of cancer death among women in the US, breast cancer has been studied for decades. While studies have shown that screening mammography has significantly reduced breast cancer mortality, it is an imperfect tool [10]. To address its limitations, convolutional neural networks (CNNs) designed for computer vision tasks on natural images have been applied. For instance, VGGNet [13], designed for object classification on ImageNet [2], has been applied to breast density classification [16], and Faster R-CNN [11] has been adapted to localize suspicious findings in mammograms [12]. We refer the readers to [5] for a comprehensive review of prior work on machine learning for mammography.

The compatibility between models designed for natural images and the distinct properties of medical images remains an open question. First, medical images usually have a much higher resolution than natural images, so deep CNNs that work well for natural images may not be applicable to medical images due to GPU memory constraints. Moreover, for many applications, regions of interest (ROIs) in medical images, such as lesions and calcifications, occupy a proportionally smaller fraction of the image than those in natural images. Fine details, often only a few pixels in size, along with global information such as tissue density, determine the labels. In addition, while natural images can be aggressively downsampled while preserving the information necessary for classification, downsampling medical images can discard significant amounts of information, making a correct diagnosis unattainable.

Contributions.  In this work, we address the aforementioned issues by proposing a novel model for the classification of medical images. The proposed model preserves inputs in high resolution while being able to focus on small ROIs. Unlike existing approaches that rely on pixel-level lesion annotations [17, 12], our model only requires image-level supervision and is able to generate pixel-level saliency maps that highlight suspicious lesions. In addition, our model is equipped with an attention-based Multiple Instance Learning (MIL) network that enables it to select informative image patches, making the classification process interpretable. When trained and evaluated on a large dataset (1 million images) of high-resolution breast cancer screening exams, our model outperforms a ResNet-based baseline [17] and achieves radiologist-level performance.

Related Work.  Existing methods have approached the breast cancer detection problem using techniques such as MIL [20], generative adversarial networks [14], and 3D CNNs [15]. Our model is inspired by work on weakly supervised object detection. Recent progress demonstrates that CNN classifiers trained with image-level labels are able to perform semantic segmentation at the pixel level [3, 4, 19]. This is achieved in two steps: first, a backbone CNN converts the input image into a saliency map (SM) that highlights the discriminative regions; a global pooling operator then collapses the SM into scalar predictions, which makes the entire model trainable end-to-end. To make an image-level prediction, most existing models rely on the SM alone, which often neglects fine-grained detail. In contrast, our model also leverages local information from ROI proposals using a dedicated patch-level classifier. In Section 3.2, we empirically demonstrate that the ability to focus on fine visual detail is important for classification.

Figure 1: Overall architecture of GMIC. The input image is annotated with true ROIs (red). The patch map indicates positions of ROI patches (blue squares) on the input.

2 Methods

We formulate our task as multi-label classification. Given a grayscale high-resolution image $\mathbf{x}$, we would like to predict the label $\mathbf{y}$, where $y^c$ denotes whether class $c$ is present in $\mathbf{x}$. As shown in Figure 1, the Globally-Aware Multiple Instance Classifier (GMIC) consists of three modules: (i) the localization module processes $\mathbf{x}$ to generate a SM, denoted by $\mathbf{A}$, which indicates approximate locations of ROIs; (ii) the detection module uses $\mathbf{A}$ to retrieve $K$ patches from $\mathbf{x}$ as refined proposals for ROIs; (iii) an MIL framework aggregates information from the retrieved patches and generates the final prediction.

2.1 Localization Module

As illustrated in Figure 1, the localization module first uses a CNN $f_l$ to extract relevant features from $\mathbf{x}$. Due to memory constraints, input images are usually down-sampled before being fed to a CNN [19]. For mammograms, however, down-sampling distorts important visual details such as lesion margins and blurs small ROIs. In order to retain the original resolution, we parameterize $f_l$ as a ResNet-22 [17] and remove its global average pooling and fully connected layers. This model has fewer filters in each layer than the original ResNet architectures in order to process the image at full resolution while keeping GPU memory consumption manageable. The feature maps obtained after the last residual block are transformed into the SM $\mathbf{A}$ using a convolution with sigmoid non-linearity. Each element $\mathbf{A}^c_{i,j}$ denotes a score that indicates the contribution of spatial location $(i,j)$ towards classifying the input as class $c$.
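To make this design concrete, here is a minimal PyTorch sketch of such a localization module. It is not the paper's implementation: torchvision's ResNet-18 stands in for the thinner ResNet-22 backbone and the input size is arbitrary; only the overall structure (backbone without the pooling/FC head, followed by a convolution with a sigmoid producing per-class saliency channels) mirrors the description above.

import torch
import torch.nn as nn
from torchvision.models import resnet18

class LocalizationModule(nn.Module):
    """Backbone CNN + 1x1 conv producing a per-class saliency map.

    Sketch only: the paper uses a thinner ResNet-22 trained on full-resolution
    mammograms; torchvision's resnet18 stands in here as the feature extractor.
    """
    def __init__(self, num_classes: int = 2):
        super().__init__()
        backbone = resnet18(weights=None)
        # Keep everything up to the last residual block, dropping global
        # average pooling and the fully connected head.
        self.features = nn.Sequential(*list(backbone.children())[:-2])
        # A 1x1 convolution maps the feature maps to |C| saliency channels.
        self.saliency_conv = nn.Conv2d(512, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.features(x)                          # (B, 512, h, w)
        return torch.sigmoid(self.saliency_conv(h))   # saliency map A in (0, 1)

# Example: a grayscale "mammogram" repeated to 3 channels only because the
# stand-in backbone expects RGB input.
if __name__ == "__main__":
    x = torch.randn(1, 1, 512, 384).repeat(1, 3, 1, 1)
    A = LocalizationModule(num_classes=2)(x)
    print(A.shape)  # e.g. torch.Size([1, 2, 16, 12])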

2.2 Detection Module

Due to its limited width, $f_l$ is only able to provide coarse localization. We therefore propose using image patches as ROI proposals to complement the localization module with fine-grained detail. We design a greedy algorithm (Algorithm 1) to retrieve $K$ proposals for ROIs, $\tilde{\mathbf{x}}_k$, from the input $\mathbf{x}$. The number of proposals $K$ and the patch size $h_c \times w_c$ are fixed across all experiments. The reset rule in line 12 explicitly ensures that extracted ROI proposals do not significantly overlap with each other.

Require: saliency map $\mathbf{A}$, patch size $h_c \times w_c$, number of proposals $K$
Ensure: set of ROI proposals $O$
1:  $O = \emptyset$
2:  for each class $c$ do
3:      min-max normalize $\mathbf{A}^c$
4:  end for
5:  $\tilde{\mathbf{A}} = \sum_c \mathbf{A}^c$
6:  $l$ denotes an arbitrary $h_c \times w_c$ rectangular patch on $\tilde{\mathbf{A}}$
7:  define the selection criterion $f(l) = \sum_{(i,j) \in l} \tilde{\mathbf{A}}_{i,j}$
8:  for each $k \in \{1, \ldots, K\}$ do
9:      $l^* = \arg\max_l f(l)$
10:     $L$ = position of $l^*$ in $\mathbf{x}$
11:     $O = O \cup \{\text{patch of } \mathbf{x} \text{ at } L\}$
12:     $\tilde{\mathbf{A}}_{i,j} = 0$ for all $(i,j) \in l^*$
13: end for
14: return $O$
Algorithm 1 Retrieve the ROIs
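A possible realization of Algorithm 1 in PyTorch is sketched below. It follows the description above (sum the class saliency maps, greedily pick the patch with the largest total saliency, then zero out the selected region so later proposals do not overlap it), but the box-filter scoring and the coordinate mapping are implementation choices of this sketch rather than details taken from the paper.

import torch
import torch.nn.functional as F

def retrieve_rois(x, A, K, patch_size):
    """Greedily pick K patch locations from the saliency map A.

    x: input image, shape (H, W); A: saliency map, shape (|C|, h, w), on a
    coarser grid than x. Returns a list of (row, col) top-left corners in
    image coordinates. Sketch under stated assumptions only.
    """
    crit = A.sum(dim=0, keepdim=True).unsqueeze(0)       # (1, 1, h, w) criterion map
    h, w = crit.shape[-2:]
    ph = max(1, round(patch_size[0] * h / x.shape[0]))   # patch size on the SM grid
    pw = max(1, round(patch_size[1] * w / x.shape[1]))
    corners = []
    for _ in range(K):
        # Score every candidate patch with a sliding-window (box-filter) sum.
        scores = F.avg_pool2d(crit, (ph, pw), stride=1) * (ph * pw)
        idx = torch.argmax(scores)
        r, c = divmod(idx.item(), scores.shape[-1])
        # Map the winning corner back to input-image coordinates.
        corners.append((round(r * x.shape[0] / h), round(c * x.shape[1] / w)))
        crit[..., r:r + ph, c:c + pw] = 0.0              # reset rule: suppress overlap
    return corners

# Example usage with random tensors.
if __name__ == "__main__":
    x = torch.randn(368, 240)        # downsized stand-in image
    A = torch.rand(2, 46, 30)        # coarse 2-class saliency map
    print(retrieve_rois(x, A, K=3, patch_size=(64, 64)))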

2.3 Multiple Instance Learning Module

Since ROI patches are retrieved using a coarse saliency map, the amount of information relevant for classification carried in each patch varies significantly. To address this, we apply an MIL framework to aggregate information from the ROI patches. A detection network $f_d$ is first applied to every instance $\tilde{\mathbf{x}}_k$ and converts it into a feature vector $\tilde{\mathbf{h}}_k \in \mathbb{R}^L$; the dimensionality $L$ is the same in all experiments. We parameterize $f_d$ as a ResNet-18 [7] (pretrained on ImageNet [2]). Since not all ROI patches are relevant to the prediction, we use the Gated Attention Mechanism [8] to let the model select informative patches. The selection process yields an attention-weighted representation $\mathbf{z} = \sum_{k=1}^{K} \alpha_k \tilde{\mathbf{h}}_k$, where the attention score $\alpha_k$ indicates the relevance of each patch $\tilde{\mathbf{x}}_k$. The representation $\mathbf{z}$ is then passed to a fully connected layer with sigmoid activation to generate a prediction $\hat{y}^c_{mil}$, whose weights and bias are learnable parameters.
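For illustration, a minimal sketch of the gated attention mechanism of [8] over a bag of patch feature vectors is given below; the feature and attention dimensions are placeholders, not the values used in the paper.

import torch
import torch.nn as nn

class GatedAttentionMIL(nn.Module):
    """Gated attention pooling over a bag of patch features (Ilse et al. [8]).

    Sketch only: feat_dim and attn_dim are placeholder dimensions.
    """
    def __init__(self, feat_dim: int = 512, attn_dim: int = 128, num_classes: int = 2):
        super().__init__()
        self.V = nn.Linear(feat_dim, attn_dim)   # tanh branch
        self.U = nn.Linear(feat_dim, attn_dim)   # sigmoid gate branch
        self.w = nn.Linear(attn_dim, 1)          # attention logits
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, h: torch.Tensor):
        # h: (K, feat_dim), one feature vector per ROI patch.
        logits = self.w(torch.tanh(self.V(h)) * torch.sigmoid(self.U(h)))  # (K, 1)
        alpha = torch.softmax(logits, dim=0)          # attention scores over the bag
        z = (alpha * h).sum(dim=0)                    # attention-weighted representation
        y_hat = torch.sigmoid(self.classifier(z))     # per-class probabilities
        return y_hat, alpha.squeeze(-1)

if __name__ == "__main__":
    patch_features = torch.randn(6, 512)              # illustrative bag of 6 patches
    y_hat, alpha = GatedAttentionMIL()(patch_features)
    print(y_hat, alpha)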

2.4 Training

It is difficult to make this model trainable end-to-end: since the detection module is not differentiable, the gradient from the training loss will not flow into the localization module. Inspired by [3], we circumvent this problem with a scheme that simultaneously trains the localization module and the MIL module. An aggregation function $f_{agg}$ is designed to map the SM $\mathbf{A}^c$ for each class into a prediction $\hat{y}^c_{loc}$. The design of $f_{agg}$ has been extensively studied [4]. Global Average Pooling (GAP) would dilute the prediction, as most of the spatial locations in $\mathbf{A}^c$ correspond to background and provide little training signal. On the other hand, Global Max Pooling (GMP) only backpropagates gradient into a single spatial location, which makes the learning process slow and unstable. In our work, we use a soft balance between GAP and GMP:

$\hat{y}^c_{loc} = f_{agg}(\mathbf{A}^c) = \frac{1}{|H^+|} \sum_{(i,j) \in H^+} \mathbf{A}^c_{i,j}$,

where $H^+$ denotes the set containing the locations of the top $t\%$ values in $\mathbf{A}^c$, and $t$ is a hyper-parameter. The prediction $\hat{y}^c_{loc}$ is a valid probability as $\mathbf{A}^c_{i,j} \in (0, 1)$. To fine-tune the SM and prevent the localization module from highlighting irrelevant areas, we impose the following regularization on $\mathbf{A}^c$: $L_{reg}(\mathbf{A}^c) = \sum_{i,j} |\mathbf{A}^c_{i,j}|^{\beta}$, where $\beta$ is a hyper-parameter. In summary, the loss function used to train the entire model is:

$L(\mathbf{y}, \hat{\mathbf{y}}) = \sum_{c} \left[ \mathrm{BCE}(y^c, \hat{y}^c_{loc}) + \mathrm{BCE}(y^c, \hat{y}^c_{mil}) + \lambda L_{reg}(\mathbf{A}^c) \right]$,   (1)

where BCE is the binary cross-entropy and $\lambda$ is a hyper-parameter. In the inference stage, the prediction is computed as $\hat{y}^c = \frac{1}{2}\left(\hat{y}^c_{loc} + \hat{y}^c_{mil}\right)$.
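The aggregation function and the loss in Eq. (1) can be written compactly in code; the sketch below assumes saliency values in (0, 1) and uses illustrative values for t, lambda and beta rather than the ones selected by the hyper-parameter search.

import torch
import torch.nn.functional as F

def top_t_pooling(A_c: torch.Tensor, t: float) -> torch.Tensor:
    """Average of the top t fraction of values of a single-class saliency map."""
    flat = A_c.reshape(-1)
    k = max(1, int(t * flat.numel()))
    return flat.topk(k).values.mean()

def gmic_loss(A, y, y_loc, y_mil, lam=1e-4, beta=1.0):
    """Eq. (1) as sketched: BCE on both heads plus an |A|^beta penalty.

    A: (|C|, h, w) saliency maps; y, y_loc, y_mil: (|C|,) labels / predictions.
    lam and beta are illustrative values, not the paper's.
    """
    bce = F.binary_cross_entropy(y_loc, y) + F.binary_cross_entropy(y_mil, y)
    reg = A.abs().pow(beta).sum()
    return bce + lam * reg

if __name__ == "__main__":
    A = torch.rand(2, 46, 30)
    y = torch.tensor([1.0, 0.0])
    y_loc = torch.stack([top_t_pooling(A[c], t=0.02) for c in range(2)])
    y_mil = torch.rand(2)
    print(gmic_loss(A, y, y_loc, y_mil))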

3 Experiments

The proposed model is evaluated on the task of predicting whether any benign or malignant findings are present in a mammography exam. The dataset includes 229,426 exams (1,001,093 images) [18]. Across the entire dataset, malignant findings were present in 985 breasts and benign findings in 5,556 breasts. As shown in Figure 2, each exam contains four grayscale images representing the two standard views (CC and MLO) of the left and right breasts. A label $\mathbf{y}$ is associated with each breast, where $y^c$ denotes the presence or absence of a benign/malignant finding in that breast. All findings are confirmed by biopsy. In each exam, the two views of the same breast share the same label. A small fraction of the data is associated with pixel-level segmentations $\mathbf{S}^c$, where $\mathbf{S}^c_{i,j} = 1$ if pixel $(i,j)$ belongs to a finding of class $c$. In all experiments, segmentations are only used for evaluation.

Figure 2: Example exam for a patient, showing the four standard views (R-CC, L-CC, R-MLO, L-MLO). Benign findings are highlighted in green.

3.1 Experimental Set-up and Evaluation Metrics

We adopt the same pre-processing as [17]. The dataset is divided into disjoint training (186,816), validation (28,462) and test (14,148) sets. In each iteration, we train the model using all exams that contain at least one benign or malignant finding and an equal number of randomly sampled negative exams. All images are cropped to a fixed size and normalized. The training loss is optimized using Adam [9]. We optimize the hyper-parameters using random search [1]. Specifically, we search on a logarithmic scale over the learning rate, the regularization weight $\lambda$, the regularization exponent $\beta$, and the pooling threshold $t$. We train 100 separate models, each for 40 epochs. We pretrain the localization network $f_l$ on BI-RADS labels as described in [6]. For classification performance, we report the area under the ROC curve (AUC) at the breast level. As our model generates a prediction for each image and each breast is associated with two images (CC and MLO), we define the breast-level prediction as the average of the two image-level predictions. To quantitatively evaluate our model's localization ability, we use the continuous F1 score, where precision (P) and recall (R) are defined as

$P_c = \frac{\sum_{i,j} \mathbf{A}^c_{i,j} \mathbf{S}^c_{i,j}}{\sum_{i,j} \mathbf{A}^c_{i,j}}$  and  $R_c = \frac{\sum_{i,j} \mathbf{A}^c_{i,j} \mathbf{S}^c_{i,j}}{\sum_{i,j} \mathbf{S}^c_{i,j}}$,

where $\mathbf{S}^c$ denotes the segmentation label and $\mathbf{A}^c$ is the SM for class $c$. On the test set, these metrics are averaged over images for which segmentation labels are available.
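For concreteness, the evaluation quantities described above could be computed as follows; this sketch assumes the SM has already been upsampled to the resolution of the segmentation label, and the breast-level score is simply the mean of the CC and MLO predictions.

import torch

def continuous_f1(A_c: torch.Tensor, S_c: torch.Tensor, eps: float = 1e-8):
    """Continuous precision/recall/F1 between a saliency map and a binary mask.

    A_c: saliency map for class c, upsampled to the mask resolution.
    S_c: pixel-level segmentation label (0/1) for class c.
    """
    overlap = (A_c * S_c).sum()
    precision = overlap / (A_c.sum() + eps)
    recall = overlap / (S_c.sum() + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision.item(), recall.item(), f1.item()

def breast_level_prediction(y_cc: torch.Tensor, y_mlo: torch.Tensor) -> torch.Tensor:
    """Breast-level prediction = average of the CC and MLO image-level predictions."""
    return 0.5 * (y_cc + y_mlo)

if __name__ == "__main__":
    A = torch.rand(512, 384)
    S = (torch.rand(512, 384) > 0.99).float()
    print(continuous_f1(A, S))
    print(breast_level_prediction(torch.tensor([0.8, 0.1]), torch.tensor([0.6, 0.2])))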

Figure 3: Reader study

Table 1: AUCs of the baseline model and a few variations of GMIC

Model             Malignant  Benign
ResNet-22 [17]    0.827      0.731
GMIC-loc          0.885      0.777
GMIC-mil          0.878      0.766
GMIC-noattn       0.823      0.726
GMIC-random       0.757      0.692
GMIC-loc-random   0.889      0.776
GMIC              0.900      0.784

3.2 Classification Performance

In this section, we report the average test performance of the 5 models from the hyper-parameter search that achieved the highest validation AUC on malignant classification (referred to as top-5). In order to understand the impact of each module, we evaluate GMIC under a number of settings. GMIC-loc uses $\hat{y}^c_{loc}$ as its prediction and GMIC-mil uses $\hat{y}^c_{mil}$. As shown in Table 1, both variants of GMIC outperform the baseline, especially in predicting malignancy. The full model, GMIC, which uses the aggregated prediction $\hat{y}^c$, attains a higher AUC than GMIC-loc and GMIC-mil. We attribute this improvement to the synergy of local and global information. To empirically validate this conjecture, we test three additional models: GMIC-noattn assigns equal attention to each ROI patch; GMIC-random generates predictions by applying the MIL module to patches randomly selected from the input image; GMIC-loc-random combines the predictions of GMIC-loc and GMIC-random. As Table 1 shows, GMIC-noattn is less accurate than GMIC-mil, suggesting that the attention mechanism in the MIL module is essential for classification. Moreover, GMIC-random is weaker than GMIC-mil, and GMIC-loc-random does not demonstrate any performance gain over GMIC-loc. These observations confirm our hypothesis that applying the MIL module to high-resolution ROI patches supplements the global information extracted by the SM and refines the predictions.

To evaluate the clinical value of our model, we compare the performance of GMIC with that of radiologists using data from the reader study described in [17]. This reader study includes 14 radiologists, each providing a probability estimate of malignancy for 720 screening exams (1,440 breasts). The radiologists were shown only the images for each exam, with no other data. To further improve our predictions, we ensemble the predictions of the top-5 models. As shown in Figure 3, the ensembled GMIC achieves a higher AUC (0.876) than both the average (0.778) and the most accurate (0.860) of the 14 readers. GMIC obtains marginally worse performance in the reader study than on the test set because the reader study contains a much larger proportion of positive samples.

We also assess the efficacy of a human-machine hybrid, whose predictions are simply the average of predictions from the radiologists and the model. The human-machine hybrid achieves an AUC of 0.883. These results suggest that our model captures different aspects of the task compared to radiologists and can be used as a tool to assist in interpreting breast cancer screening exams.

Figure 4: Visualization of three examples. Input images are annotated with segmentation labels (green=benign, red=malignant). ROI patches are shown with their attention scores.

3.3 Localization Performance

We select the model with the highest validation F1 for malignancy localization. At the inference stage, we upsample the SMs using nearest-neighbour interpolation to match the resolution of the segmentation labels. The average continuous F1/precision/recall on the test set is 0.207/0.288/0.254 for malignant findings and 0.133/0.135/0.224 for benign findings. In addition, the best localization model achieves a classification AUC of 0.886/0.78 for the malignant/benign classes.

To better understand our model's behavior, we visualize the SMs of three samples selected from the test set in Figure 4. In the first two examples, the SMs are highly activated on the true lesions, suggesting that our model is able to detect suspicious lesions without pixel-level supervision. Moreover, the attention is highly concentrated on the ROI patches that overlap with the annotated lesions. In the third example, the malignant SM highlights only part of a large malignant lesion. This behavior is related to the design of $f_{agg}$: a fixed pooling threshold $t$ cannot be optimal for all ROI sizes. Furthermore, this observation also illustrates that while human experts are asked to annotate the entire lesion, CNNs tend to emphasize only the most informative parts.

4 Conclusion

We present a novel model for the classification of breast cancer screening exams. The proposed method uses the input at its original resolution while being able to focus on fine details. Moreover, our model also generates saliency maps that provide additional interpretability. Evaluated on a large mammography dataset, GMIC outperforms the ResNet-based baseline and generates predictions that are as accurate as those of radiologists. Given its generic design, the proposed model is widely applicable to other image classification tasks. Our future research will focus on designing joint training mechanisms that would enable GMIC to improve its localization using error signals from the MIL module.

Acknowledgments

The authors would like to thank Catriona C. Geras for correcting earlier versions of this manuscript and Joe Katsnelson and Mario Videna for supporting our computing environment. We also gratefully acknowledge the support of Nvidia Corporation with the donation of some of the GPUs used in this research. This work was supported in part by grants from the National Institutes of Health (R21CA225175 and P41EB017183).

References

Appendix 0.A Additional Visualizations

Figure 5: Additional visualizations of benign examples. Input images are annotated with segmentation labels (green=benign, red=malignant). ROI patches are shown with their attention scores.
Figure 6: Additional visualizations of malignant examples. Input images are annotated with segmentation labels (green=benign, red=malignant). ROI patches are shown with their attention scores.