Distill-to-Label: Weakly Supervised Instance Labeling Using Knowledge Distillation

07/26/2019 ∙ by Jayaraman J. Thiagarajan, et al. ∙ Lawrence Livermore National Laboratory, IBM

Weakly supervised instance labeling using only image-level labels, in lieu of expensive fine-grained pixel annotations, is crucial in several applications, including medical image analysis. In contrast to conventional instance segmentation scenarios in computer vision, the problems we consider are characterized by a small number of training images and non-local patterns that lead to the diagnosis. In this paper, we explore the use of multiple instance learning (MIL) to design an instance label generator under this weakly supervised setting. Motivated by the observation that an MIL model can handle bags of varying sizes, we propose to repurpose an MIL model originally trained for bag-level classification to produce reliable predictions for single instances, i.e., bags of size 1. To this end, we introduce a novel regularization strategy based on virtual adversarial training for improving MIL training, and subsequently develop a knowledge distillation technique for repurposing the trained MIL model. Using empirical studies on colon cancer and breast cancer detection from histopathological images, we show that the proposed approach produces high-quality instance-level predictions and significantly outperforms state-of-the-art MIL methods.


I Introduction

The success of supervised learning algorithms, e.g., deep neural networks, in computer vision relies on the assumption that labels provide the level of semantic description required for the task at hand. For example, in object recognition [1], image-level labels are sufficient to identify generalizable patterns for each category, whereas in instance segmentation [2, 3], pixel-level annotations are required to produce detailed segmentation masks. However, in real-world applications, we are often required to design effective models under weakly supervised settings. This is highly prevalent in medical image analysis [4, 5, 6], e.g., computational pathology, wherein an image is typically described by an overall diagnostic label, even though patients with similar diagnoses can present vastly different signatures in their diagnostic images. Furthermore, annotating medical images is very expensive because it requires clinical expertise, which limits the availability of labeled data [7].

Consequently, weakly supervised image segmentation methods have gained significant research interest [8], and the most successful approaches typically view filters in convolutional networks as object detectors and aggregate deep feature maps to infer class-aware visual evidence [9]. However, medical diagnosis presents unique challenges compared to conventional object recognition tasks. More specifically, the factors leading to a diagnosis are not well localized; instead, they are based on observing the co-occurrence of seemingly unrelated patterns in different parts of the image, which makes it challenging to infer generalizable features. As a result, recent approaches have explored the use of multiple instance learning (MIL) [10], wherein each sample is represented as a bag of instances, under the assumption that only bag-level labels are available. At their core, MIL methods aggregate features from the instances, while exploiting correlations between them, to produce an effective statistical descriptor. Given the complex nature of the clinical decision process, more recent MIL methods have also included interpretability as an additional design objective. For example, Ilse et al. include a learnable attention module in MIL to automatically identify the key instances in a bag that are most likely to trigger the observed bag label [11].

Proposed Work. In this paper, we develop a novel approach to generate dense labels for clinical images using only weakly supervised data. To this end, we represent each image as a bag of patches and adopt an MIL formulation for predicting the bag-level diagnostic label. Given the strong dependence of MIL methods on the average number of instances in the bags used for training, we propose to incorporate regularization based on virtual adversarial training (VAT). In general, VAT trains the network to produce predictions that are invariant to arbitrary perturbations of the input data [12, 13]. The form of perturbation we utilize is designed specifically for MIL: in addition to corrupting each of the instances, we perturb the bag itself, namely by pruning instances and by including uncorrelated noise instances. We demonstrate that the proposed VAT-based regularization leads to more robust predictive models than existing approaches, particularly when the training data sizes are small. More importantly, we find this regularization to be crucial for improving the instance-level predictions of an MIL model.

(a) Teacher model
(b) Student model - Instance Label Generator
Fig. 1: An illustration of the proposed approach for weakly supervised instance labeling in image classification tasks. Each image is represented as a bag of patches and a teacher model is trained to predict bag labels. Subsequently, we build a student model that jointly optimizes bag-level prediction with respect to the ground truth, and instance-level prediction based on knowledge distillation from the teacher.

Since we assume no access to the instance-level labels, we propose to utilize knowledge distillation to leverage the pre-trained MIL model for designing an effective instance-level label generator. In general, knowledge distillation (KD) involves transferring knowledge from one machine learning model (teacher) to another (student) [14], such that performance on the task under consideration improves, or the number of parameters required for solving the task is reduced. In its classical formulation, the teacher is a high-capacity, high-performance model, while the student is designed to be more compact. However, several recent works have demonstrated KD to be very effective in transferring to high-capacity student models as well [15]. In this work, we view knowledge distillation from a different perspective, by identifying which parts of the teacher's expertise (bag-level prediction) can be trusted while training a student model that can make robust predictions for single instances, i.e., bags of size 1. We introduce a novel formulation that jointly distills from the bag-level and instance-level predictions of the teacher to produce highly effective instance label generators, using only weakly supervised data. To the best of our knowledge, this is the first approach that uses distillation to repurpose MIL models for instance labeling.

Using empirical studies on colon cancer and breast cancer detection from histopathological images, we demonstrate significant improvements in instance-level labeling quality when compared to the state-of-the-art MIL technique in [11]. Interestingly, we find that the proposed VAT regularization leads to superior results over existing MIL approaches on these challenging benchmark problems, particularly in terms of instance-level prediction. Furthermore, incorporating our knowledge distillation strategy resulted in additional performance gains. Notably, averaged over the cross-validation folds, our method improves the instance labeling accuracy over [11] by 0.10 (from 0.79 to 0.89) and 0.37 (from 0.27 to 0.64) on the colon and breast cancer datasets, respectively (Tables I and II). In summary, we show that distillation can be an effective strategy for automatically generating dense labels, which can make models interpretable for clinical practice and can support subsequent training of dense segmentation models without the need for dense annotations.

II Problem Setup

Following the formulation of multiple instance learning, we consider each sample to be a bag of instances, $X = \{x_1, \ldots, x_K\}$, with no specific dependency or ordering between them. Note that, in our setup, each instance corresponds to an image patch and the number of instances $K$ can vary between bags. The target variable $Y$ corresponds to the label for a bag, whereas each $y_k$ denotes the label for the $k$-th instance. Since our approach is weakly supervised, we assume no access to the instance-level labels during training. The goal is to build a teacher model based on MIL to perform bag-level prediction, and subsequently employ knowledge distillation to train an identical student model that can produce predictions for both $Y$ and $y_k$. Since the predictions in MIL should be invariant to permutations of the instance order in a bag, we consider prediction functions of the following form.

Theorem 1.

A prediction function for a set of instances $X$, $S(X)$, is permutation-invariant to the elements in $X$ if and only if it can be expressed as follows:

$$S(X) = g\Big(\sum_{x \in X} f(x)\Big),$$

where $f$ and $g$ are appropriate transformations.

Note that this universal decomposition can be replaced by approximations, such as using the max operation in lieu of the summation in the above theorem. Since we employ neural networks to process the raw data, the above theorem can be modified to operate on the latent features $\mathbf{h}_k = \phi(x_k)$ instead of $x_k$, where $\phi$ denotes the mapping learned by the network. In that case, the transformation $f$ simplifies to the identity, and the transformation $g$ maps from the $M$-dimensional latent space to the target $Y$.
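To make the permutation-invariance argument concrete, the following is a minimal pure-Python sketch (not the authors' code) of a score of the form $g(\sum_{x} f(x))$, its max-pooling approximation, and a counter-example that weights instances by position. The transformations `f` and `g` below are hypothetical toy choices, picked only for illustration.

```python
import itertools
import math

# Hypothetical transformations used only to illustrate Theorem 1:
# f maps an instance to a feature vector, g maps a pooled vector to a score.
def f(x):
    return [x, x * x]

def g(s):
    return 1.0 / (1.0 + math.exp(-(0.5 * s[0] - 0.1 * s[1])))

def pool(bag, reducer):
    """Apply f to every instance, then reduce each feature dimension."""
    feats = [f(x) for x in bag]
    return [reducer(column) for column in zip(*feats)]

def score_sum(bag):
    return g(pool(bag, sum))   # exact decomposition g(sum_x f(x))

def score_max(bag):
    return g(pool(bag, max))   # max-pooling approximation

def score_positional(bag):
    # Counter-example: weighting by position breaks permutation invariance.
    return g([sum((i + 1) * x for i, x in enumerate(bag)), 0.0])

def is_permutation_invariant(score, bag, tol=1e-12):
    """Check that the score is (numerically) identical for every ordering."""
    values = [score(list(p)) for p in itertools.permutations(bag)]
    return max(values) - min(values) <= tol

bag = [0.3, -1.2, 2.0, 0.7]
print(is_permutation_invariant(score_sum, bag))         # True
print(is_permutation_invariant(score_max, bag))         # True
print(is_permutation_invariant(score_positional, bag))  # False
```

The tolerance only absorbs floating-point reordering error in the summation; sum and max pooling are invariant by construction, while the positional score changes with the instance order.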

III Proposed Approach

In this section, we describe the proposed methodology for building an effective instance label generator using only weakly supervised data. An overview of our approach can be found in Figure 1.

III-A Teacher Model – Multiple Instance Learning

The teacher model is designed to predict the bag-level target $Y$ from the set of instances $X$ using MIL. Our MIL formulation is similar to the one in [11], wherein an embedding-level aggregation strategy is used to perform bag classification. As shown in Figure 1(a), each of the instances in the input bag is processed using a convolutional neural network. The resulting features for the instances are then aggregated using a permutation-invariant pooling strategy before invoking the classifier module. Though the pooling can be carried out using simple strategies such as instance average or instance max, these are not trainable. Instead, we perform a weighted average of the instance features, where the weights are determined by an attention module implemented as a simple neural network. Further, we constrain the weights to sum to 1, in order to ensure that the aggregation can be applied to bags of varying sizes.

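Denoting the bag of features by $H = \{\mathbf{h}_1, \ldots, \mathbf{h}_K\}$, the aggregation strategy can be formally stated as

$$\mathbf{z} = \sum_{k=1}^{K} a_k \mathbf{h}_k, \qquad (1)$$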

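where the computation of the weights $a_k$ is parameterized using a gated attention mechanism:

$$a_k = \frac{\exp\{\mathbf{w}^{\top}(\tanh(\mathbf{V}\mathbf{h}_k) \odot \mathrm{sigm}(\mathbf{U}\mathbf{h}_k))\}}{\sum_{j=1}^{K}\exp\{\mathbf{w}^{\top}(\tanh(\mathbf{V}\mathbf{h}_j) \odot \mathrm{sigm}(\mathbf{U}\mathbf{h}_j))\}}. \qquad (2)$$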

Here, $\mathbf{w} \in \mathbb{R}^{L}$, $\mathbf{V} \in \mathbb{R}^{L \times M}$, and $\mathbf{U} \in \mathbb{R}^{L \times M}$ are learnable parameters of the attention module. Note that the attention computation is posed as a softmax function, so that the constraint of the weights summing to 1 is satisfied. Finally, using a combination of tanh and sigm activations both enables the inference of non-linear relations between instances and, more importantly, removes the troublesome approximate linearity of the tanh function for inputs in [-1, 1] [11]. The teacher model is trained using the cross entropy loss and optimized to maximize prediction performance on the target variable $Y$.
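The aggregation in eqs. (1) and (2) can be sketched in pure Python as below. This is an illustrative sketch with random, untrained parameters, not the trained attention module; the helper names `matvec` and `gated_attention_pool` are hypothetical. It also shows why the formulation suits instance labeling: for a bag of size 1 the softmax weight is exactly 1, so the pooled vector is the instance feature itself.

```python
import math
import random

def matvec(m, v):
    """Multiply an L x M matrix (list of rows) by a length-M vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def sigm(t):
    return 1.0 / (1.0 + math.exp(-t))

def gated_attention_pool(H, V, U, w):
    """Gated attention pooling in the spirit of eqs. (1)-(2).

    H: list of K instance feature vectors h_k (each of length M).
    V, U: L x M weight matrices; w: length-L weight vector.
    Returns the pooled vector z = sum_k a_k h_k and the weights a_k.
    """
    scores = []
    for h in H:
        gated = [math.tanh(v) * sigm(u) for v, u in zip(matvec(V, h), matvec(U, h))]
        scores.append(sum(wi * gi for wi, gi in zip(w, gated)))
    # Softmax over instances; subtracting the max keeps exp() numerically stable.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    a = [e / total for e in exps]
    z = [sum(a_k * h[j] for a_k, h in zip(a, H)) for j in range(len(H[0]))]
    return z, a

# Toy usage with random (untrained) parameters: M = 4 features, L = 3 hidden units.
rng = random.Random(0)
M, L = 4, 3
V = [[rng.gauss(0, 1) for _ in range(M)] for _ in range(L)]
U = [[rng.gauss(0, 1) for _ in range(M)] for _ in range(L)]
w = [rng.gauss(0, 1) for _ in range(L)]
bag = [[rng.gauss(0, 1) for _ in range(M)] for _ in range(5)]

z, a = gated_attention_pool(bag, V, U, w)
print(round(sum(a), 6))   # 1.0: the weights always sum to one

z1, a1 = gated_attention_pool(bag[:1], V, U, w)
print(a1, z1 == bag[0])   # [1.0] True: a bag of size 1 passes h_1 through
```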

III-B Virtual Adversarial Training for MIL

Though attention-based aggregation can be superior to non-trainable pooling strategies, it is more prone to overfitting, particularly when the number of bags $N$ or the number of instances in each bag is low. This limitation can be particularly severe when we utilize a pre-trained MIL model to produce instance-level labels by inputting bags with a single instance. Hence, we propose to incorporate regularization based on virtual adversarial training (VAT) to promote robustness with respect to arbitrary perturbations of the bags. VAT strategies have been successfully utilized in supervised and semi-supervised learning formulations [12], wherein a local Lipschitz constraint is incorporated by enforcing classifier consistency within a norm-ball neighborhood of each sample.

In the context of MIL, we propose to enforce three consistency conditions for the teacher model $\mathcal{T}$: (i) $\mathcal{T}$ must produce consistent predictions even when uncorrelated noise instances are included; (ii) $\mathcal{T}$ must produce consistent predictions when each instance in a bag is perturbed within its norm-ball neighborhood; and (iii) the conditional entropy of the predictions from $\mathcal{T}$ must be low even when a random subset of instances is excluded from a bag. Hence, the overall loss function for the proposed MIL can be written as:

$$\mathcal{L}_{\mathcal{T}} = \ell_{bce}\big(Y, \mathcal{T}(X)\big) + \lambda_1\, \ell_{bce}\big(Y, \mathcal{T}(X^{+})\big) + \lambda_2\, \ell_{bce}\big(Y, \mathcal{T}(\tilde{X})\big) + \lambda_3\, H\big(\mathcal{T}(X^{-})\big). \qquad (3)$$

Here, $X$ is a training bag with $K$ instances, and $X^{+}$ denotes a modified bag with $K + P$ instances that includes $P$ uncorrelated noise instances, i.e.,

$$X^{+} = X \cup \{\boldsymbol{\epsilon}_1, \ldots, \boldsymbol{\epsilon}_P\}.$$

Each $\boldsymbol{\epsilon}_p$ corresponds to an uncorrelated noise instance drawn from a uniform random distribution defined over the same range as the instances in the training data. Next, in eq. (3), $\tilde{X}$ is obtained by perturbing each instance in $X$ within a norm ball of radius $r$:

$$\tilde{X} = \{x_1 + \boldsymbol{\delta}_1, \ldots, x_K + \boldsymbol{\delta}_K\}, \quad \|\boldsymbol{\delta}_k\| \le r,$$

where $\boldsymbol{\delta}_k$ denotes a random perturbation within the norm ball around $x_k$. Finally, $X^{-}$, containing $K - Q$ instances with $Q < K$, is obtained by excluding a random subset of $Q$ instances from $X$:

$$X^{-} = X \setminus \{x_k : k \in \mathcal{I}\},$$

where $\mathcal{I}$ denotes a randomly selected subset of $Q$ indices of instances from the bag $X$.

While the first three terms in eq. (3) are implemented using the binary cross entropy function $\ell_{bce}$, the final term uses the conditional entropy $H(\cdot)$ of the output probabilities. Note that minimizing the conditional entropy ensures that the output probabilities from the model are concentrated around one of the classes. The choice of the hyper-parameters for this optimization is discussed in the experiments section.
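The three bag perturbations and the resulting regularized loss can be sketched as follows. This is a minimal pure-Python illustration under stated assumptions, not the paper's implementation: the norm ball is assumed to be L2, `model` is any hypothetical callable that returns P(Y = 1 | bag), and all function names and default hyper-parameter values are illustrative. The optional `num_realizations` argument averages the entropy term over several random instance-dropping draws.

```python
import math
import random

def add_noise_instances(bag, num_noise, low, high, rng):
    """X^+: append num_noise instances drawn uniformly from [low, high]."""
    dim = len(bag[0])
    noise = [[rng.uniform(low, high) for _ in range(dim)] for _ in range(num_noise)]
    return bag + noise

def perturb_instances(bag, radius, rng):
    """X~: shift every instance to a random point inside an L2 ball of the given radius.

    The shift length is drawn uniformly in [0, radius), so points are not
    volume-uniform inside the ball (a simplifying assumption).
    """
    out = []
    for x in bag:
        d = [rng.gauss(0.0, 1.0) for _ in x]
        norm = math.sqrt(sum(v * v for v in d)) or 1.0
        r = radius * rng.random()
        out.append([xi + r * di / norm for xi, di in zip(x, d)])
    return out

def drop_instances(bag, num_drop, rng):
    """X^-: exclude a random subset of num_drop instances (num_drop < len(bag))."""
    keep = sorted(rng.sample(range(len(bag)), len(bag) - num_drop))
    return [bag[k] for k in keep]

def bce(y, p, eps=1e-12):
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def entropy(p, eps=1e-12):
    """Entropy of the Bernoulli prediction p = P(Y = 1 | bag)."""
    p = min(max(p, eps), 1.0 - eps)
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))

def teacher_loss(model, bag, y, lams, rng, num_noise=2, value_range=(0.0, 1.0),
                 radius=0.1, num_drop=1, num_realizations=1):
    """Eq. (3)-style loss; num_realizations > 1 averages the entropy term."""
    lam1, lam2, lam3 = lams
    low, high = value_range
    ent = sum(entropy(model(drop_instances(bag, num_drop, rng)))
              for _ in range(num_realizations)) / num_realizations
    return (bce(y, model(bag))
            + lam1 * bce(y, model(add_noise_instances(bag, num_noise, low, high, rng)))
            + lam2 * bce(y, model(perturb_instances(bag, radius, rng)))
            + lam3 * ent)

# Hypothetical stand-in for the teacher: squashes the mean of the first feature.
def toy_model(bag):
    mean_first = sum(x[0] for x in bag) / len(bag)
    return 1.0 / (1.0 + math.exp(-4.0 * (mean_first - 0.5)))

rng = random.Random(0)
bag = [[rng.random() for _ in range(3)] for _ in range(6)]
loss = teacher_loss(toy_model, bag, 1, (1.0, 1.0, 0.1), rng, num_realizations=3)
```

In a real training loop the same perturbed bags would be passed through the network and the loss back-propagated; here the scalar `loss` only shows how the four terms combine.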

(a) 10 instances per bag on average
(b) 20 instances per bag on average
Fig. 2: Effect of the proposed virtual adversarial training (VAT) on the performance of the teacher model. This experiment is carried out with the MNIST digits dataset, with a varying number of bags and an average of 10 or 20 instances per bag.

Effect of VAT. In order to understand the usefulness of including VAT in the MIL training process, we set up an empirical study with the MNIST handwritten digits data, similar to [11]. A bag is composed of a random number of images from the MNIST dataset, and is assigned a positive label if it contains one or more images of a designated target digit. The number of instances in a bag is drawn from a Gaussian distribution (rounded to the closest integer) with a fixed mean, 10 or 20 in Fig. 2, and a fixed variance. We study the effect of VAT by varying the number of bags and the average number of instances per bag. We used a fixed set of validation bags to evaluate the models resulting from each of the cases. All experiments were run with the LeNet5 architecture, trained using the RMSprop optimizer, and the area under the ROC (receiver operating characteristic) curve is used as the evaluation metric. Figure 2 shows the bag-level prediction performance on the validation set for the different cases. As can be observed, the proposed VAT regularization leads to significant performance improvements for varying bag sizes, particularly when the number of training bags is low, thus resulting in highly consistent predictive models.
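The bag construction used in this MNIST study can be sketched from digit labels alone. This is an illustrative sketch, not the experiment code: `make_mnist_bags` is a hypothetical helper, and the variance and target digit used in the usage line are arbitrary example values, since only the means (10 or 20) are stated in Fig. 2. Sampling within a bag is assumed to be without replacement.

```python
import random

def make_mnist_bags(digit_labels, num_bags, mean_size, var_size, target_digit, rng):
    """Group image indices into bags with Gaussian-distributed sizes.

    digit_labels: digit label (0-9) of every available image.
    Each bag size is drawn from N(mean_size, var_size), rounded to the closest
    integer and clipped to [1, len(digit_labels)]. A bag is positive (1) if it
    contains at least one image of target_digit.
    """
    std = var_size ** 0.5
    bags, bag_labels = [], []
    for _ in range(num_bags):
        size = round(rng.gauss(mean_size, std))
        size = min(max(size, 1), len(digit_labels))
        idx = rng.sample(range(len(digit_labels)), size)  # no repeats within a bag
        bags.append(idx)
        bag_labels.append(int(any(digit_labels[i] == target_digit for i in idx)))
    return bags, bag_labels

# Example: 1000 placeholder labels, 100 bags of about 10 instances each.
labels = [i % 10 for i in range(1000)]
bags, ys = make_mnist_bags(labels, num_bags=100, mean_size=10, var_size=2,
                           target_digit=9, rng=random.Random(0))
```

Only the bag labels are used for training; the per-image digit labels stay hidden from the model, mirroring the weakly supervised setting.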

III-C Student Model – Instance Label Generation

Since we do not have access to ground truth labels for training the instance-level classifier, we propose to leverage knowledge from the teacher model, i.e., the bag-level classifier, using a knowledge distillation approach. More specifically, we build a student model that has an architecture similar to the teacher, but with the goal of improving its prediction quality for the instances. An interesting aspect of the proposed formulation is that we do not require a separate predictor at the instance level, since a single instance can be represented as a bag of size 1. Following our notation defined earlier, for a given bag comprising the instances $X = \{x_1, \ldots, x_K\}$, we denote the features obtained using the teacher and student models as $H^{t}$ and $H^{s}$, respectively. Given these latent representations, both the teacher and student models perform aggregation using the attention module and employ an output linear layer to produce the logits, followed by the softmax function to obtain the probability of assignment to each of the classes.

Fig. 3: Colon cancer dataset - Convergence behavior of the student model training process. Note that the teacher was pre-trained using MIL along with the proposed VAT regularization. The total loss shown here is obtained using the objective in eq. (5).

(a) Colon cancer dataset
(b) Breast cancer dataset
Fig. 4: Illustration of instances from the two histopathological datasets used in our experiments.
Fold | Accuracy: MIL [11] / Teacher / Student | F1 Score: MIL [11] / Teacher / Student | AUROC: MIL [11] / Teacher / Student
0 | 0.80 / 0.83 / 0.88 | 0.47 / 0.61 / 0.75 | 0.78 / 0.82 / 0.87
1 | 0.79 / 0.84 / 0.90 | 0.71 / 0.79 / 0.81 | 0.89 / 0.92 / 0.92
2 | 0.69 / 0.79 / 0.83 | 0.48 / 0.64 / 0.75 | 0.87 / 0.90 / 0.91
3 | 0.84 / 0.90 / 0.92 | 0.68 / 0.87 / 0.90 | 0.89 / 0.92 / 0.92
4 | 0.83 / 0.88 / 0.90 | 0.71 / 0.75 / 0.78 | 0.88 / 0.91 / 0.92
Average | 0.79 / 0.85 / 0.89 | 0.61 / 0.73 / 0.80 | 0.86 / 0.89 / 0.91
TABLE I: Instance-level label prediction performance for the colon cancer dataset. We show results from 5-fold cross validation obtained using the baseline MIL [11] and the proposed teacher and student models. The average performance across the folds is also included for each case.
Fold | Accuracy: MIL [11] / Teacher / Student | F1 Score: MIL [11] / Teacher / Student | AUROC: MIL [11] / Teacher / Student
0 | 0.29 / 0.59 / 0.64 | 0.19 / 0.47 / 0.53 | 0.51 / 0.71 / 0.77
1 | 0.21 / 0.54 / 0.60 | 0.23 / 0.51 / 0.55 | 0.61 / 0.72 / 0.76
2 | 0.32 / 0.63 / 0.66 | 0.29 / 0.62 / 0.64 | 0.63 / 0.77 / 0.78
3 | 0.27 / 0.58 / 0.65 | 0.31 / 0.63 / 0.68 | 0.67 / 0.73 / 0.74
4 | 0.28 / 0.61 / 0.63 | 0.27 / 0.58 / 0.60 | 0.59 / 0.64 / 0.70
Average | 0.27 / 0.59 / 0.64 | 0.26 / 0.56 / 0.60 | 0.60 / 0.71 / 0.75
TABLE II: Instance-level label prediction performance for the breast cancer dataset. We show results from 5-fold cross validation obtained using the baseline MIL [11] and the proposed teacher and student models. The average performance across the folds is also included for each case.

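Formally, for an input bag $X$, the teacher model produces the output $\mathcal{T}(X) = \mathrm{softmax}(\mathbf{W}^{t}\mathbf{z}^{t})$, where the aggregated representation $\mathbf{z}^{t}$ is obtained from $H^{t}$ using eq. (1). Similarly, the student produces the output $\mathcal{S}(X) = \mathrm{softmax}(\mathbf{W}^{s}\mathbf{z}^{s})$. In the original formulation of knowledge distillation [14], the student is trained such that its output $\mathcal{S}(X)$ is similar to the teacher's output $\mathcal{T}(X)$ as well as to the true labels, in order to improve the classification performance. In practice, the output probabilities are smoothed with a temperature $\tau$, i.e., $\mathcal{T}_{\tau}(X) = \mathrm{softmax}(\mathbf{W}^{t}\mathbf{z}^{t}/\tau)$ and analogously $\mathcal{S}_{\tau}(X)$, to soften the estimate and provide additional information about the challenges during training:

$$\mathcal{L}_{KD} = \ell_{bce}\big(Y, \mathcal{S}(X)\big) + \beta\, \ell_{ce}\big(\mathcal{T}_{\tau}(X), \mathcal{S}_{\tau}(X)\big). \qquad (4)$$

However, in our setup, the goal is not to improve the bag-level classification, but instead the instance-level prediction, without access to any instance-level ground truth labels. To this end, we propose the following optimization objective:

$$\mathcal{L}_{\mathcal{S}} = \gamma_1\, \ell_{bce}\big(Y, \mathcal{S}(X)\big) + \gamma_2\, D_{KL}\big(\mathcal{T}_{\tau}(X), \mathcal{S}_{\tau}(X)\big) + \frac{\gamma_3}{K}\sum_{k=1}^{K} D_{KL}\big(\mathcal{T}(x_k), \mathcal{S}(x_k)\big) + \frac{\gamma_4}{K}\sum_{k=1}^{K} H\big(\mathcal{S}(x_k)\big). \qquad (5)$$

Here, $\mathcal{T}(x_k)$ denotes the output from $\mathcal{T}$ for the $k$-th instance in $X$, treated as a bag of size 1, and $D_{KL}$ is the KL-divergence between two Bernoulli variables:

$$D_{KL}(p, q) = p \log\frac{p}{q} + (1 - p)\log\frac{1 - p}{1 - q}. \qquad (6)$$

In the above objective, the first term denotes the binary cross entropy between the predicted bag-level labels and the ground truth, while the second term performs distillation from $\mathcal{T}$ to $\mathcal{S}$ using the softened softmax probabilities. The last two terms are based on the instance-level predictions obtained by passing each instance ($x_k$) independently through the teacher and the student models. Similar to the formulation in eq. (3), $H(\cdot)$ denotes the conditional entropy. Through this controlled knowledge distillation at the bag and instance levels, our approach achieves significant gains in instance-level prediction without accessing any instance-level supervisory labels.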
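For the binary case, the student objective reduces to a few scalar operations on logits. The pure-Python sketch below assumes each model output is summarized by a single logit difference (class 1 minus class 0), applies the temperature only to the bag-level distillation term, and averages the instance terms over the bag; these choices and all function names are illustrative assumptions rather than the authors' implementation.

```python
import math

EPS = 1e-12

def _clip(p):
    return min(max(p, EPS), 1.0 - EPS)

def prob(logit_diff, tau=1.0):
    """P(class 1) from a 2-class softmax with temperature tau.

    For logits (u0, u1), softmax([u0, u1] / tau)[1] == sigmoid((u1 - u0) / tau),
    so a single logit difference suffices in the binary case.
    """
    return 1.0 / (1.0 + math.exp(-logit_diff / tau))

def bce(y, p):
    p = _clip(p)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def kl_bernoulli(p, q):
    """D_KL(Bernoulli(p) || Bernoulli(q)), as in eq. (6)."""
    p, q = _clip(p), _clip(q)
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))

def entropy(p):
    p = _clip(p)
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))

def student_loss(y, t_bag, s_bag, t_inst, s_inst, tau, gammas):
    """Eq. (5)-style objective computed from logit differences.

    t_bag / s_bag: teacher / student logit difference for the whole bag.
    t_inst / s_inst: per-instance logit differences (each instance as a bag of size 1).
    """
    g1, g2, g3, g4 = gammas
    k = len(s_inst)
    bag_ce = bce(y, prob(s_bag))
    bag_kd = kl_bernoulli(prob(t_bag, tau), prob(s_bag, tau))
    inst_kd = sum(kl_bernoulli(prob(t), prob(s)) for t, s in zip(t_inst, s_inst)) / k
    inst_ent = sum(entropy(prob(s)) for s in s_inst) / k
    return g1 * bag_ce + g2 * bag_kd + g3 * inst_kd + g4 * inst_ent

# Example with made-up logits for a positive bag of three instances.
loss = student_loss(1, 2.0, 1.0, [1.5, -0.5, 0.3], [0.8, -1.0, 0.1],
                    tau=2.0, gammas=(1.0, 1.0, 1.0, 0.1))
```

The teacher logits are treated as constants here; in training, only the student parameters would receive gradients from this loss.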

IV Experiments

In this section, we describe the two applications considered, discuss the experimental setup, and present a detailed performance evaluation. All experiments were carried out using a standard convolutional neural network whose feature extractor comprises a stack of convolution layers with ReLU activation followed by max pooling, and fully connected layers, while the classifier contains a single fully connected layer. For comparison, we consider the state-of-the-art baseline [11], which employs attention-based aggregation to perform MIL (implementations of attention-based MIL were obtained from https://github.com/utayao/Atten_Deep_MIL and https://github.com/AMLab-Amsterdam/AttentionDeepMIL). All models were implemented using PyTorch [16].

IV-A Dataset Description

We tested our algorithm on cancer histopathology datasets of the breast and colon. Histopathology images are especially relevant because identifying the nuclei of diseased cells is a highly specialized and time-consuming task. Consequently, it would be highly beneficial to accurately identify the relevant tumorous patches (instances), which would in turn greatly reduce the burden on clinicians.

IV-A1 Colon Cancer Dataset

This dataset consists of hematoxylin and eosin (H&E) stained histopathology images of colorectal cancer [17]. Of the 29,756 nuclei centers identified, 22,444 nuclei were given an associated class label, i.e., epithelial, inflammatory, fibroblast, or miscellaneous. To simplify the problem for multiple instance learning, we chose to focus on the binary problem of detecting epithelial cell nuclei. This is motivated by the findings in [18] that mutations expressed in the epithelial cells of the colon are often the starting point for cancerous mutation. In this case, each whole slide image (bag) was divided into image patches (instances) (see Fig. 4 for examples). Every bag containing epithelial cells was marked as positive. Further, the ground truth label for each instance, i.e., patch, was determined by whether the instance contained epithelial nuclei. Note that these instance labels were used only for evaluation and not during training.
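The bag and instance labeling described above can be sketched as a simple tiling step. In this illustrative sketch, `tile_image` is a hypothetical helper, patches are assumed to be non-overlapping squares of a chosen size (the patch size used in the experiments is not given here), pixels beyond the last full patch are ignored, and an instance is labeled positive when an annotated epithelial nucleus center falls inside it.

```python
def tile_image(height, width, patch, centers):
    """Split an image into non-overlapping patch x patch tiles (instances).

    centers: list of (row, col, is_epithelial) nucleus annotations.
    Returns the per-tile instance labels (row-major; 1 if the tile contains an
    epithelial nucleus center) and the bag label (1 if any instance is positive).
    Rows/columns beyond the last full tile are ignored (an assumption).
    """
    rows, cols = height // patch, width // patch
    labels = [[0] * cols for _ in range(rows)]
    for r, c, is_epi in centers:
        i, j = r // patch, c // patch
        if is_epi and i < rows and j < cols:
            labels[i][j] = 1
    instance_labels = [labels[i][j] for i in range(rows) for j in range(cols)]
    return instance_labels, int(any(instance_labels))

# Example: a 100 x 100 image cut into 25 x 25 tiles with two annotated nuclei.
inst, bag_label = tile_image(100, 100, 25, [(10, 10, True), (60, 80, False)])
```

As in the text, the instance labels from such a step would be kept for evaluation only; training sees just `bag_label`.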

IV-A2 Breast Cancer Dataset

This dataset consists of H&E stained histopathology images [19], with each of the cells marked as benign or malignant. Each whole slide cell image (bag) is assigned a positive label if it contains malignant cancer cells (instances). Similar to the previous case, the images were divided into patches to form the instances (see Fig. 4 for examples from both datasets). Note that instances with 75% or more white pixels were considered background and discarded from our study.
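The background filter can be sketched as a white-pixel fraction test. The 75% threshold comes from the text; what counts as a "white" pixel is not specified, so the per-channel intensity cutoff `white_threshold=220` below is a hypothetical choice, and `is_background` is an illustrative name.

```python
def is_background(patch, white_threshold=220, max_white_fraction=0.75):
    """Return True if at least max_white_fraction of the pixels are 'white'.

    patch: 2-D list of (R, G, B) tuples with values in 0..255.
    A pixel is treated as white when all three channels are >= white_threshold
    (an assumed definition used only for illustration).
    """
    pixels = [px for row in patch for px in row]
    white = sum(1 for px in pixels if min(px) >= white_threshold)
    return white / len(pixels) >= max_white_fraction

W, P = (250, 250, 250), (200, 120, 180)  # a white pixel and a stained (pinkish) pixel
instances = [[[W, W], [W, W]], [[W, W], [W, P]], [[W, P], [P, W]]]
kept = [p for p in instances if not is_background(p)]
```

Only the last toy patch survives here: the first is entirely white and the second has exactly 75% white pixels, which meets the discard criterion.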

Both experiments were carried out on train/validation/test splits with 5-fold cross validation. The number of convolution layers in the feature extractor part of the network was set separately for the colon cancer and breast cancer datasets. We used the Adam optimizer with a fixed learning rate and batch size. The hyper-parameters for the student model training in eq. (5) were set to the same values for both datasets. On the other hand, the hyper-parameters for the teacher model training in eq. (3) were tuned specifically for each of the two datasets, because the breast cancer dataset presented a severe class imbalance (benign vs. malignant) at the instance level. Furthermore, for the breast cancer case, it was beneficial to include multiple random realizations of the pruned bag (the bag with a random subset of instances excluded) for the VAT regularization in eq. (3). In other words, the entropy loss term was constructed by averaging the conditional entropy obtained using bags containing different subsets of instances.

IV-B Results

Tables I and II present the cross-validated results for instance-level label generation on the two datasets. As mentioned earlier, we report the results for the state-of-the-art MIL method in [11], as well as for the teacher and student models from the proposed approach. For evaluation, we adopted the widely used metrics of accuracy, F1 score, and area under the receiver operating characteristic curve (AUROC). The first striking observation is that the proposed approach consistently produces significant performance gains over the baseline method in terms of all three metrics. For example, on the colon cancer dataset, the student model improves the average accuracy from 0.79 to 0.89 and the average F1 score from 0.61 to 0.80. More interestingly, on the challenging breast cancer dataset, we find that the baseline MIL performs worse than even random predictions (average accuracy of 0.27). In contrast, the proposed strategies lead to more reliable predictions, raising the average accuracy to 0.64. These results demonstrate the effectiveness of the VAT regularization and our novel distillation formulation, in the context of multiple instance learning, for generating dense labels from weak supervision.
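The three instance-level metrics can be computed from per-patch labels and scores as sketched below. This is a plain-Python illustration of the standard definitions, not the evaluation script used for Tables I and II (which is not given); AUROC is computed as the probability that a randomly chosen positive instance scores higher than a randomly chosen negative one, with ties counted as one half, and predictions are assumed to use a 0.5 threshold.

```python
def accuracy(y_true, y_pred):
    return sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)

def f1_score(y_true, y_pred):
    """Harmonic mean of precision and recall for the positive class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

def auroc(y_true, scores):
    """P(score of a random positive > score of a random negative); ties count 1/2."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Toy instance-level labels and student probabilities (made-up values).
y = [1, 0, 1, 1, 0]
scores = [0.9, 0.4, 0.35, 0.8, 0.1]
preds = [int(s >= 0.5) for s in scores]
print(accuracy(y, preds), f1_score(y, preds))  # 0.8 0.8
```

For this toy input, five of the six positive-negative pairs are ranked correctly, so `auroc(y, scores)` is 5/6.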

V Conclusions

In this paper, we presented a method for producing instance-level labels with only weak supervision, i.e., image-level labels, in medical image analysis. Our method relies on a novel virtual adversarial training regularization for MIL and on repurposing a pre-trained MIL model for instance classification via knowledge distillation. Through empirical validation on two challenging histopathology datasets of colon and breast cancer, we showed that the proposed method consistently outperformed the state-of-the-art MIL method. This presents an opportunity to reduce the exhaustive effort of annotating clinical data and, more importantly, to enable weakly supervised data augmentation for data-driven inference. Future work will include extending this methodology to multi-class predictions and developing scalable techniques to produce pixel-level segmentation.

References

  • [1] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton, “Imagenet classification with deep convolutional neural networks,” in Advances in neural information processing systems, 2012, pp. 1097–1105.
  • [2] Anna Khoreva, Rodrigo Benenson, Jan Hosang, Matthias Hein, and Bernt Schiele, “Simple does it: Weakly supervised instance and semantic segmentation,” in Proceedings of the IEEE conference on computer vision and pattern recognition, 2017, pp. 876–885.
  • [3] Guanbin Li, Yuan Xie, Liang Lin, and Yizhou Yu, “Instance-level salient object segmentation,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2017, pp. 2386–2395.
  • [4] Melih Kandemir and Fred A Hamprecht, “Computer-aided diagnosis from weak supervision: A benchmarking study,” Computerized medical imaging and graphics, vol. 42, pp. 44–50, 2015.
  • [5] Yan Xu, Jun-Yan Zhu, Eric I-Chao Chang, Maode Lai, and Zhuowen Tu, “Weakly supervised histopathology cancer image segmentation and classification,” Medical image analysis, vol. 18, no. 3, pp. 591–604, 2014.
  • [6] Rushil Anirudh, Jayaraman J Thiagarajan, Timo Bremer, and Hyojin Kim, “Lung nodule detection using 3d convolutional neural networks trained on weakly labeled data,” in Medical Imaging 2016: Computer-Aided Diagnosis. International Society for Optics and Photonics, 2016, vol. 9785, p. 978532.
  • [7] Alexander Ratner, Stephen H Bach, Henry Ehrenberg, Jason Fries, Sen Wu, and Christopher Ré, “Snorkel: Rapid training data creation with weak supervision,” Proceedings of the VLDB Endowment, vol. 11, no. 3, pp. 269–282, 2017.
  • [8] Yanzhao Zhou, Yi Zhu, Qixiang Ye, Qiang Qiu, and Jianbin Jiao, “Weakly supervised instance segmentation using class peak response,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2018, pp. 3791–3800.
  • [9] Kunpeng Li, Ziyan Wu, Kuan-Chuan Peng, Jan Ernst, and Yun Fu, “Tell me where to look: Guided attention inference network,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2018, pp. 9215–9223.
  • [10] Yan Xu, Tao Mo, Qiwei Feng, Peilin Zhong, Maode Lai, and Eric I-Chao Chang, “Deep learning of feature representation with multiple instance learning for medical image analysis,” in 2014 IEEE international conference on acoustics, speech and signal processing (ICASSP). IEEE, 2014, pp. 1626–1630.
  • [11] Maximilian Ilse, Jakub Tomczak, and Max Welling, “Attention-based deep multiple instance learning,” in International Conference on Machine Learning, 2018, pp. 2132–2141.
  • [12] Rui Shu, Hung H Bui, Hirokazu Narui, and Stefano Ermon, “A dirt-t approach to unsupervised domain adaptation,” in Proc. 6th International Conference on Learning Representations, 2018.
  • [13] Xu Ji, João F Henriques, and Andrea Vedaldi, “Invariant information distillation for unsupervised image segmentation and clustering,” arXiv preprint arXiv:1807.06653, 2018.
  • [14] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean, “Distilling the knowledge in a neural network,” arXiv preprint arXiv:1503.02531, 2015.
  • [15] Tommaso Furlanello, Zachary Lipton, Michael Tschannen, Laurent Itti, and Anima Anandkumar, “Born-again neural networks,” in International Conference on Machine Learning, 2018, pp. 1602–1611.
  • [16] Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer, “Automatic differentiation in pytorch,” 2017.
  • [17] Korsuk Sirinukunwattana, Shan e Ahmed Raza, Yee-Wah Tsang, David RJ Snead, Ian A Cree, and Nasir M Rajpoot, “Locality sensitive deep learning for detection and classification of nuclei in routine colon cancer histology images,” IEEE Trans. Med. Imaging, vol. 35, no. 5, pp. 1196–1206, 2016.
  • [18] Lucia Ricci-Vitiani, Dario G Lombardi, Emanuela Pilozzi, Mauro Biffoni, Matilde Todaro, Cesare Peschle, and Ruggero De Maria, “Identification and expansion of human colon-cancer-initiating cells,” Nature, vol. 445, no. 7123, pp. 111, 2007.
  • [19] Elisa Drelie Gelasca, Jiyun Byun, Boguslaw Obara, and B.S. Manjunath, “Evaluation and benchmark for biological image segmentation,” in IEEE International Conference on Image Processing, Oct 2008.