Categorical Relation-Preserving Contrastive Knowledge Distillation for Medical Image Classification

07/07/2021 ∙ by Xiaohan Xing, et al. ∙ City University of Hong Kong

The amount of medical images available for training deep classification models is typically scarce, making these models prone to overfitting the training data. Studies have shown that knowledge distillation (KD), especially the mean-teacher framework, which is more robust to perturbations, can help mitigate over-fitting. However, directly transferring KD from computer vision to medical image classification yields inferior performance because medical images suffer from higher intra-class variance and class imbalance. To address these issues, we propose a novel Categorical Relation-preserving Contrastive Knowledge Distillation (CRCKD) algorithm, which takes the commonly used mean-teacher model as the supervisor. Specifically, we propose a novel Class-guided Contrastive Distillation (CCD) module that pulls closer positive image pairs from the same class in the teacher and student models, while pushing apart negative image pairs from different classes. With this regularization, the feature distribution of the student model shows higher intra-class similarity and inter-class variance. Besides, we propose a Categorical Relation Preserving (CRP) loss to distill the teacher's relational knowledge in a robust and class-balanced manner. With the contributions of CCD and CRP, our CRCKD algorithm distills relational knowledge more comprehensively. Extensive experiments on the HAM10000 and APTOS datasets demonstrate the superiority of the proposed CRCKD method.


1 Introduction

With the recent progress of deep learning techniques, computer-aided diagnosis has shown human-level performance for some diseases and reduced the workload of human screening [8]. However, the amount of training data for most diseases is limited, making deep models prone to overfit the training data [26, 30]. To tackle the over-fitting issue, many learning schemes have been proposed, such as transfer learning [3, 14], dropout [15], and label-smoothing regularization [16, 10]. Another effective solution is knowledge distillation (KD), where a trained teacher model provides soft labels that supply secondary information to the student model, thus relieving the over-fitting problem [5].

Among existing KD frameworks, the self-ensembling mean-teacher [17] is widely studied in medical image classification. Updated by the temporal moving average of the student model, the mean-teacher produces feature distributions and predictions that are robust to different perturbations, thus showing higher generalizability even with a limited amount of data. Therefore, to train a student with high accuracy and generalizability, it is crucial to maximally distill knowledge from the mean-teacher. Some researchers distilled individual-sample knowledge from the teacher, such as output logits [18] and feature maps [23]. Recently, Liu et al. [9] took the relations among mini-batch samples as distillation targets and demonstrated their superiority over individual KD counterparts.

However, most existing KD methods [18, 23, 22, 9, 1, 12, 6] are directly transferred from the computer vision field without fully considering the following challenges in the medical domain. First, intra-class variation and inter-class similarity in medical datasets are more severe than in the natural-image domain. Specifically, two types of diseases may exhibit extremely similar color, shape, and texture, making them less distinguishable than two classes of natural images (e.g., dogs vs. cats). Second, medical image datasets usually suffer from severe class imbalance, since some diseases are common while others are rare. As a result, the knowledge distilled by current KD methods may be biased towards the majority class and under-represent the minority classes.

To tackle the above-mentioned challenges, we propose a novel distillation approach, termed Categorical Relation-preserving Contrastive Knowledge Distillation (CRCKD), for medical image classification. Built upon the mean-teacher framework, we propose two novel KD paradigms, i.e., Class-guided Contrastive Distillation (CCD) and Categorical Relation Preserving (CRP), to distill the rich structural knowledge from the mean-teacher model. The main contributions are summarized as follows: (1) We propose the CCD module to pull closer positive image pairs from the same class in the teacher and student models, while pushing apart negative image pairs from different classes. With this regularization, the feature distribution of the student model exhibits higher intra-class similarity and inter-class variance. (2) To distill more robust and fine-grained relational knowledge, we propose the CRP loss, which utilizes category centroids as anchors to regulate each sample's relation with different categories. Compared with previous relational KD [21, 9] that adopts images in a mini-batch as anchors, the category centroids in our method serve as more reliable anchors and naturally mitigate the class imbalance problem. (3) Experimental results on the HAM10000 and APTOS datasets demonstrate the efficacy of our proposed CRCKD method, as well as the superiority of CCD and CRP over existing relational KD paradigms.

2 Method

Figure 1: Overview of the proposed Categorical Relation-preserving Contrastive Knowledge Distillation (CRCKD) framework. The student model is supervised by the weighted cross-entropy loss and the knowledge distillation losses (L_kd, L_ccd, and L_crp). The dashed orange lines indicate the back-propagation paths of the gradients. [Best viewed in color]

Fig. 1 illustrates our proposed CRCKD framework. It consists of a student model and a mean-teacher model. The student model is optimized by stochastic gradient descent, while the teacher weights are updated as the exponential moving average (EMA) of the student weights. Given an image x, it is augmented twice by adding different perturbations (i.e., random flipping and affine transformation), producing two different images x^s and x^t. Taking the corresponding augmented image as input, the student and teacher models extract feature representations f^s and f^t, and predict output probabilities p^s and p^t, respectively. The student's prediction p^s is supervised by the weighted cross-entropy loss and by the KL divergence L_kd with p^t. To constrain the consistency between the student's and teacher's structural information, we propose the L_ccd loss, which pulls together positive feature pairs from the same class while pushing away negative feature pairs from different classes. Furthermore, in each model, we construct a novel relation graph (Eq. 4) between the sample feature and the category centroids. The L_crp loss is proposed to regularize the consistency between the relation graphs of the teacher and student models.
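The EMA update of the teacher weights described above can be sketched as follows. This is a minimal NumPy sketch; the function name and the per-tensor list representation are illustrative choices, not the authors' code:

```python
import numpy as np

def update_ema(student_params, teacher_params, alpha=0.999):
    """EMA update: theta' <- alpha * theta' + (1 - alpha) * theta,
    applied elementwise to each weight tensor."""
    return [alpha * t + (1.0 - alpha) * s
            for s, t in zip(student_params, teacher_params)]

# Toy check with a single one-element "weight tensor".
student = [np.array([1.0])]
teacher = [np.array([0.0])]
teacher = update_ema(student, teacher, alpha=0.9)
print(round(float(teacher[0][0]), 6))  # 0.1
```

A larger `alpha` makes the teacher a longer-horizon temporal average of the student, which is what makes its predictions robust to per-step perturbations.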

2.1 Class-guided Contrastive Distillation (CCD)

Recently, Contrastive Representation Distillation (CRD) [19] has achieved impressive distillation performance by incorporating contrastive learning into the conventional KD paradigm. Despite its appealing results, one major shortcoming of CRD is that it mistakenly pushes apart images from the same class in the feature space, thus unavoidably enlarging the intra-class variance.

To tackle this dilemma, we propose a novel Class-guided Contrastive Distillation (CCD) which utilizes class-label information to guide the CRD. Specifically, CCD regards samples from the same class as positive pairs and pulls their representations closer, while taking images from different classes as negative pairs and pushing their representations apart. As depicted in Fig. 1, two different augmentations of an image are processed by the teacher and student models to generate feature embeddings f^t and f^s. The embeddings are then projected to z^t and z^s through projection layers (instantiated as linear transformations) with trainable parameters. The projected embeddings z^t and z^s are mapped to the unit hypersphere through L2 normalization, so their similarity can be measured by the inner product. Inspired by [19], for each sample in the student model, we define the CCD loss as

$$\mathcal{L}_{ccd}^{s} = -\sum_{p=1}^{P} \log h\big(z^{s}, z_{p}^{t}\big) - \sum_{n=1}^{N} \log\Big(1 - h\big(z^{s}, z_{n}^{t}\big)\Big), \qquad h(z, z') = \frac{\exp(z \cdot z' / \tau)}{\exp(z \cdot z' / \tau) + N/M} \quad (1)$$

where τ is the temperature that controls the concentration level, P and N denote the numbers of positive and negative samples, respectively, and M is the cardinality of the dataset. By minimizing L_ccd^s, the student model is optimized to produce feature representations that are more similar to the positive pairs while differing from the negative samples in the teacher model. Similarly, the CCD loss for the teacher model is defined as

$$\mathcal{L}_{ccd}^{t} = -\sum_{p=1}^{P} \log h\big(z^{t}, z_{p}^{s}\big) - \sum_{n=1}^{N} \log\Big(1 - h\big(z^{t}, z_{n}^{s}\big)\Big) \quad (2)$$

It is noteworthy that the L_ccd^t loss merely updates the projection head of the teacher model. The CCD loss regularizes the consistency of the teacher's and student's inter-sample structural knowledge by enlarging intra-class similarity and inter-class divergence between the two models, thus yielding performance gains.

As suggested in [19, 13], a large number of negative samples is required to ensure the performance of contrastive learning. To access a large number of negative samples while avoiding a large batch size, we follow Wu et al. [24] and construct a memory bank that stores the projected embeddings of all training images. We denote the memory banks for the student and teacher models as B^s and B^t, respectively. As shown in Fig. 1, in each forward propagation, only the features of the query samples in the mini-batch are updated, while all other samples retain their embeddings from the previous step. For each query sample in the student model, the positive and negative samples in Eq. 1 are randomly selected from the teacher's memory bank B^t; Eq. 2 is computed analogously with samples drawn from the student's memory bank B^s.
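As one concrete reading of the CCD objective, the sketch below assumes the CRD-style critic of [19], h(z, z') = exp(z·z'/τ) / (exp(z·z'/τ) + N/M), applied to class-guided positives and negatives. The function names and the hard-coded dataset size are our own illustrative choices, not the authors' implementation:

```python
import numpy as np

def l2_normalize(x, eps=1e-12):
    # Map embeddings onto the unit hypersphere so the inner product
    # equals cosine similarity.
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def ccd_loss(z_query, z_pos, z_neg, tau=0.07, dataset_size=10015):
    """Class-guided contrastive loss for a single student query.

    z_query : (d,)   projected student embedding
    z_pos   : (P, d) same-class teacher embeddings (from the memory bank)
    z_neg   : (N, d) different-class teacher embeddings (from the memory bank)
    """
    q, zp, zn = l2_normalize(z_query), l2_normalize(z_pos), l2_normalize(z_neg)
    n_neg, m = len(z_neg), dataset_size

    def critic(sim):  # h(z, z') = e^{sim/tau} / (e^{sim/tau} + N/M)
        e = np.exp(sim / tau)
        return e / (e + n_neg / m)

    pos_term = -np.log(critic(zp @ q)).sum()        # pull positives closer
    neg_term = -np.log(1.0 - critic(zn @ q)).sum()  # push negatives apart
    return pos_term + neg_term

# Aligned positives should cost less than misaligned ones.
q = np.array([1.0, 0.0])
loss_good = ccd_loss(q, np.array([[1.0, 0.1]]), np.array([[-1.0, 0.0]]))
loss_bad = ccd_loss(q, np.array([[-1.0, 0.1]]), np.array([[-1.0, 0.0]]))
print(loss_good < loss_bad)  # True
```

In training, the positives and negatives would be drawn from the teacher's memory bank B^t by class label rather than constructed by hand as here.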

2.2 Categorical Relation Preserving (CRP)

Although the proposed CCD can regularize the structural consistency of the teacher and student, the regularization is relatively coarse since each sample pushes apart negative image pairs from different classes without differentiation. However, some categories of diseases are much more similar than other categories, thus their distributions should be closer in the embedding space. To capture fine-grained relational knowledge, [11, 21, 9] proposed to distill the pair-wise relations between data samples in a mini-batch. However, for the dataset with severe class imbalance, most samples in a mini-batch belong to the majority class, thus the constructed relation graphs may suffer from class bias.

To settle the above issues and capture class-balanced relational knowledge, we propose a novel Categorical Relation Preserving (CRP) loss that utilizes category centroids to construct the relation graph. Specifically, in the student (teacher) model, we compute the centroid c_k of the k-th category by averaging the features of all samples in the k-th class (retrieved from the memory bank B^s (B^t)):

$$c_{k} = \frac{1}{N_{k}} \sum_{i:\, y_{i}=k} f_{i} \quad (3)$$

where N_k denotes the number of samples in the k-th class. Then, for each query sample f_i^s (f_i^t) in the mini-batch, we compute its cosine similarity with all category centroids in the student (teacher) model. After a softmax over all classes, we obtain the categorical relation a_{i,k}^s (a_{i,k}^t) between the sample and the k-th category:

$$a_{i,k}^{s} = \frac{\exp\big(\cos(f_{i}^{s}, c_{k}^{s})\big)}{\sum_{j=1}^{C} \exp\big(\cos(f_{i}^{s}, c_{j}^{s})\big)}, \qquad a_{i,k}^{t} = \frac{\exp\big(\cos(f_{i}^{t}, c_{k}^{t})\big)}{\sum_{j=1}^{C} \exp\big(\cos(f_{i}^{t}, c_{j}^{t})\big)} \quad (4)$$

where C is the total number of classes in the dataset, and f_i^s and f_i^t denote the sample's representations extracted by the student and teacher, respectively. Then, we propose the CRP loss to minimize the KL divergence between the teacher's and student's categorical relation graphs:

$$\mathcal{L}_{crp} = \frac{1}{B} \sum_{i=1}^{B} \mathrm{KL}\big(a_{i}^{t} \,\|\, a_{i}^{s}\big) \quad (5)$$

where B denotes the mini-batch size.

Compared with existing relational KD [11, 21, 9], using category centroids as anchors in our method has two advantages: 1) one anchor (the class centroid) represents each class regardless of the number of images in that class, so the constructed relation graphs naturally mitigate the bias caused by class imbalance; 2) the anchors in CRP are retrieved from the memory bank (a momentum aggregation over temporal steps) and are therefore more robust than anchors that rely on the current step alone. In this way, the CRP loss regularizes reliable categorical relation graphs built with more robust and representative anchors, and is thus expected to better mimic the relational knowledge in the teacher model.
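The centroid, relation-graph, and KL steps above (Eqs. 3 to 5) can be sketched in NumPy as follows; the function names and the mini-batch averaging of the KL term are assumptions consistent with the definitions in the text:

```python
import numpy as np

def class_centroids(bank, labels, num_classes):
    # Eq. (3): centroid of each class, averaged over all its samples
    # retrieved from the memory bank.
    return np.stack([bank[labels == k].mean(axis=0) for k in range(num_classes)])

def categorical_relation(features, centroids):
    # Eq. (4): softmax over cosine similarities between each sample
    # and the C category centroids. features: (B, d), centroids: (C, d).
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    c = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)
    logits = f @ c.T
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def crp_loss(feat_s, cent_s, feat_t, cent_t, eps=1e-12):
    # Eq. (5): KL(a^t || a^s), averaged over the mini-batch.
    a_s = categorical_relation(feat_s, cent_s)
    a_t = categorical_relation(feat_t, cent_t)
    return (a_t * (np.log(a_t + eps) - np.log(a_s + eps))).sum(axis=1).mean()

# When student and teacher agree exactly, the CRP loss vanishes.
rng = np.random.default_rng(0)
feats = rng.normal(size=(4, 8))
cents = rng.normal(size=(3, 8))
print(round(crp_loss(feats, cents, feats, cents), 8))  # 0.0
```

Because each class contributes exactly one centroid regardless of its sample count, the relation graph has the same shape (B × C) whether a class is common or rare.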

2.3 Training and Testing

In the proposed framework, the weights of the student model and of the projection layers are jointly optimized by the overall loss function defined as:

$$\mathcal{L} = \mathcal{L}_{wce} + \lambda_{1} \mathcal{L}_{kd} + \lambda_{2} \mathcal{L}_{ccd} + \lambda_{3} \mathcal{L}_{crp} \quad (6)$$

where L_wce denotes the weighted cross-entropy loss supervised by the ground-truth labels; the weight for each class in L_wce is inversely proportional to the number of images in that class. L_kd, L_ccd, and L_crp are used to distill the individual and structural knowledge from the teacher model. During the first ramp-up epochs, the trade-off weights λ_1, λ_2, and λ_3 gradually ramp up from 0 to their final values according to a Gaussian warming-up function, and are held fixed thereafter. The teacher weights are updated as the EMA of the student weights.
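The Gaussian warming-up function is not spelled out in the text; a common choice in mean-teacher training [17] is w(t) = w_max · exp(−5(1 − t/T)²) over the first T ramp-up epochs. The sketch below assumes that form; the function name and signature are ours:

```python
import math

def gaussian_rampup(epoch, rampup_epochs, max_weight=1.0):
    """Ramp a trade-off weight from near 0 up to max_weight over the
    first `rampup_epochs` epochs, then hold it fixed."""
    if rampup_epochs <= 0 or epoch >= rampup_epochs:
        return max_weight
    phase = 1.0 - epoch / rampup_epochs
    return max_weight * math.exp(-5.0 * phase * phase)

# The weight grows monotonically and saturates at max_weight.
weights = [gaussian_rampup(e, 30) for e in (0, 10, 20, 30, 80)]
print([round(w, 4) for w in weights])
```

Ramping the distillation weights up slowly keeps the (initially noisy) teacher from dominating the supervised signal early in training.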

At the testing stage, we discard the mean teacher and the projection heads, so the inference time is the same as the vanilla student model.

3 Experiments

3.1 Dataset and Implementation Details

3.1.1 Dataset:

We evaluated our proposed CRCKD framework on the HAM10000 [20, 4] and APTOS [2] datasets. HAM10000 consists of 10015 dermoscopy images labeled with 7 types of skin lesions. APTOS contains 3662 fundus images for grading diabetic retinopathy into five categories. Both datasets suffer from severe class imbalance. A detailed description of the two datasets is provided in the supplementary material. For both datasets, we performed five-fold cross-validation and report the average testing performance over the five folds. We evaluated classification performance by overall accuracy (ACC), average precision (AP), balanced multi-class accuracy (BMA), and F1 score. Due to the class imbalance, BMA is considered the most important metric in this task.

3.1.2 Implementation:

Our method was implemented in Python with the PyTorch library. We employed a pre-trained DenseNet121 [7] as the backbone of both the teacher and student models. The network was trained on two P40 GPUs in parallel with a batch size of 64. The Adam optimizer was used for network optimization. We trained the network for 80 epochs, with the ramp-up period set to 30 epochs. The learning rate was decayed by the one-cycle schedule. The temperature τ in Eq. 1 and Eq. 2 was held fixed throughout training. For each query sample, the numbers of positive pairs and negative pairs were empirically set to 20 and 4096, respectively.

3.2 Experimental Results

 

Methods              |          HAM10000          |            APTOS
                     |  ACC    AP    BMA    F1    |  ACC    AP    BMA    F1
---------------------+----------------------------+----------------------------
DenseNet121 (B1)     | 84.30  74.16  72.19  72.53 | 83.83  71.85  67.51  69.14
B1 + MT (B2)         | 85.01  74.19  76.07  74.38 | 83.77  71.66  68.79  69.89
B2 + CCD             | 85.52  74.87  77.64  75.45 | 84.42  72.79  70.42  71.23
B2 + CRP             | 85.32  75.06  77.06  75.37 | 84.47  73.07  69.87  71.07
Our method           | 85.66  76.35  78.07  76.45 | 84.87  73.18  71.90  72.22
B2 + CRD [19]        | 85.33  74.41  76.44  74.77 | 84.09  71.82  69.38  70.27
B2 + SP [21]         | 85.13  74.92  76.06  74.48 | 83.16  70.51  69.15  69.54
B2 + FitNet [27]     | 84.13  72.85  76.38  73.98 | 83.76  71.97  69.88  70.52

 

Table 1: Five-fold cross-validation results on the HAM10000 and APTOS datasets. The highest rankings are highlighted in bold. Our method: B2 + CCD + CRP. Detailed mean and standard deviation results for each algorithm are provided in the supplementary material.

3.2.1 Quantitative Results:

Table 1 summarizes the performance of our CRCKD and the baseline algorithms on the HAM10000 and APTOS datasets. Compared with the vanilla student model "DenseNet121 (B1)", "B1 + MT (B2)" achieves much better classification performance, showing the benefit of introducing the mean-teacher guidance. Moreover, Table 1 indicates the effectiveness of the proposed CCD and CRP, since involving either of them leads to better performance than "B2". We conjecture that the performance gains brought by CCD and CRP are attributed to the distillation of structural knowledge. Further, the combination of CCD and CRP in our method achieves the best performance, with a BMA of 78.07% (2.00% higher than "B2") on the HAM10000 dataset and a BMA of 71.90% (3.11% higher than "B2") on the APTOS dataset, demonstrating the efficacy of the proposed method.

To further validate the effectiveness of the proposed CCD and CRP, we performed a comparison with the other three KD paradigms (see the last three rows in Table1). “B2 + CCD” achieves better performance than “B2 + CRD”[19], indicating the contribution of introducing the class-label guidance, which is consistent with our analysis in Section 2.1. Besides, the superiority over “B2 + SP” [21] suggests that our proposed CRP can better distill relational knowledge by utilizing more robust and representative class centroids as anchors. Finally, both “B2 + CCD” and “B2 + CRP” outperform “B2 + FitNet” [27] (a classical KD method that distills the intermediate features of individual samples), suggesting the necessity and superiority of relation-preserving KD studied in this work.

Figure 2: Visualization of the relation matrices between mini-batch samples (batch size=256) in the HAM10000 dataset. Input samples are grouped by ground truth class along each axis. The color from dark blue to light green indicates increased similarity.

3.2.2 Qualitative Analysis:

Taking the HAM10000 dataset as an example, we visualize the inter-sample relations produced by the features of different methods in Fig. 2. The diagonal blocks in Fig. 2 (a) denote intra-class similarity, while the other parts represent inter-class relations. As shown in Fig. 2 (b), "B2 + CRD" exhibits low intra-class similarity (especially for the 1st, 2nd, and 5th diagonal blocks). In contrast, "B2 + SP" and "B2 + FitNet" (Fig. 2 (c, d)) yield high inter-class relations. With the regularization of the proposed CCD and CRP, our method (Fig. 2 (e)) exhibits lower inter-class similarity and higher intra-class similarity. To quantitatively compare the relation matrices, we resort to the ratio R = S_intra / S_inter, where S_intra denotes the average pair-wise similarity between samples from the same class and S_inter is the average pair-wise similarity between samples from different classes. A larger R indicates higher intra-class similarity and lower inter-class similarity. Our method achieves the highest R, outperforming "B2 + CRD" (1.36), "B2 + SP" (1.35), and "B2 + FitNet" (1.22). These results demonstrate that our proposed method can effectively alleviate the issue of high intra-class variance and inter-class similarity in the medical domain.
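The intra/inter similarity ratio can be computed from a feature matrix in a few lines of NumPy; the function name and the cosine-similarity choice are our assumptions about how such a ratio would be measured:

```python
import numpy as np

def intra_inter_ratio(features, labels):
    """R = S_intra / S_inter: mean pairwise cosine similarity within
    classes divided by the mean pairwise similarity across classes."""
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    sim = f @ f.T
    same = labels[:, None] == labels[None, :]
    off_diag = ~np.eye(len(labels), dtype=bool)
    s_intra = sim[same & off_diag].mean()  # same class, excluding self-pairs
    s_inter = sim[~same].mean()            # different classes
    return s_intra / s_inter

# Two tight clusters: intra-class similarity dwarfs inter-class similarity.
feats = np.array([[1.0, 0.0], [0.9, 0.1], [0.1, 0.9], [0.0, 1.0]])
labels = np.array([0, 0, 1, 1])
print(intra_inter_ratio(feats, labels) > 1.0)  # True
```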

Methods            | ACC (%)      | AP (%)       | BMA (%)      | F1 (%)
-------------------+--------------+--------------+--------------+-------------
Yan et al. [25]    | 77.16±0.80   | 62.71±1.66   | 73.37±1.24   | 66.70±1.33
Zhang et al. [28]  | 82.34±0.73   | 72.32±1.38   | 74.01±1.76   | 72.28±1.04
Zhang et al. [29]  | 81.61±0.94   | 71.44±1.22   | 73.34±1.86   | 71.65±0.85
Liu et al. [9]     | 84.73±1.00   | 73.88±1.24   | 76.55±1.32   | 74.63±0.99
Our method         | 85.66±0.97   | 76.35±0.99   | 78.07±1.28   | 76.45±0.66
Table 2: Comparison with state-of-the-art methods on the HAM10000 dataset.

3.2.3 Comparison with Contemporary Methods:

On the HAM10000 dataset, we further compared our proposed method with four state-of-the-art methods for skin lesion classification: attention-based methods [25, 29], synergic deep learning [28], and a KD method based on sample relations [9]. For a fair comparison, we changed [9] to full supervision, used the implementations of [25, 28, 29, 9] suggested by their authors, and evaluated all methods on the same dataset as ours. As shown in Table 2, the proposed method outperforms the best existing method [9] by 0.93%, 2.47%, 1.52%, and 1.82% in ACC, AP, BMA, and F1, respectively, further validating the effectiveness of our proposed method.

4 Conclusion

In this paper, we presented a novel Categorical Relation-preserving Contrastive Knowledge Distillation (CRCKD) framework for medical image classification. To address the unique challenges of high inter-class similarity and class imbalance in the medical domain, we proposed two novel KD paradigms, i.e., CCD and CRP, to distill rich structural knowledge from the mean-teacher model. Experimental results on the HAM10000 and APTOS datasets demonstrate the effectiveness of the proposed CCD and CRP over other KD paradigms. On the HAM10000 dataset, our CRCKD method also outperforms several state-of-the-art methods.

4.0.1 Acknowledgements.

The work described in this paper was supported by National Key R&D program of China with Grant No. 2019YFB1312400, Hong Kong RGC CRF grant C4063-18G, and Hong Kong RGC GRF grant #14211420.

References

  • [1] S. Abbasi, M. Hajabdollahi, P. Khadivi, N. Karimi, R. Roshandel, S. Shirani, and S. Samavi (2020) Classification of diabetic retinopathy using unlabeled data and knowledge distillation. arXiv preprint arXiv:2009.00982. Cited by: §1.
  • [2] APTOS 2019 blindness detection. Note: https://www.kaggle.com/c/aptos2019-blindness-detection/data Cited by: §3.1.1.
  • [3] V. Cheplygina, M. de Bruijne, and J. P. Pluim (2019) Not-so-supervised: a survey of semi-supervised, multi-instance, and transfer learning in medical image analysis. Medical image analysis 54, pp. 280–296. Cited by: §1.
  • [4] N. C. Codella, D. Gutman, M. E. Celebi, B. Helba, M. A. Marchetti, S. W. Dusza, A. Kalloo, K. Liopyris, N. Mishra, H. Kittler, and A. Halpern (2018) Skin lesion analysis toward melanoma detection: a challenge at the 2017 international symposium on biomedical imaging (isbi), hosted by the international skin imaging collaboration (isic). In Proc. ISBI, pp. 168–172. Cited by: §3.1.1.
  • [5] G. Hinton, O. Vinyals, and J. Dean (2015) Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531. Cited by: §1.
  • [6] Y. Hou, Z. Ma, C. Liu, and C. C. Loy (2019) Learning lightweight lane detection cnns by self attention distillation. In Proc. ICCV, pp. 1013–1021. Cited by: §1.
  • [7] G. Huang, Z. Liu, L. Van Der Maaten, and K. Q. Weinberger (2017) Densely connected convolutional networks. In Proc. CVPR, pp. 4700–4708. Cited by: §3.1.2.
  • [8] G. Litjens, T. Kooi, B. E. Bejnordi, A. A. A. Setio, F. Ciompi, M. Ghafoorian, J. A. Van Der Laak, B. Van Ginneken, and C. I. Sánchez (2017) A survey on deep learning in medical image analysis. Medical image analysis 42, pp. 60–88. Cited by: §1.
  • [9] Q. Liu, L. Yu, L. Luo, Q. Dou, and P. A. Heng (2020) Semi-supervised medical image classification with relation-driven self-ensembling model. IEEE Trans. Med. Imaging. Cited by: §1, §1, §1, §2.2, §2.2, §3.2.3, Table 2.
  • [10] R. Müller, S. Kornblith, and G. Hinton (2019) When does label smoothing help?. arXiv preprint arXiv:1906.02629. Cited by: §1.
  • [11] W. Park, D. Kim, Y. Lu, and M. Cho (2019) Relational knowledge distillation. In Proc. CVPR, pp. 3967–3976. Cited by: §2.2, §2.2.
  • [12] A. Patra, Y. Cai, P. Chatelain, H. Sharma, L. Drukker, A. T. Papageorghiou, and J. A. Noble (2019) Efficient ultrasound image analysis models with sonographer gaze assisted distillation. In Proc. MICCAI, pp. 394–402. Cited by: §1.
  • [13] N. Saunshi, O. Plevrakis, S. Arora, M. Khodak, and H. Khandeparkar (2019) A theoretical analysis of contrastive unsupervised representation learning. In International Conference on Machine Learning, pp. 5628–5637. Cited by: §2.1.
  • [14] H. Shang, Z. Sun, W. Yang, X. Fu, H. Zheng, J. Chang, and J. Huang (2019) Leveraging other datasets for medical imaging classification: evaluation of transfer, multi-task and semi-supervised learning. In Proc. MICCAI, pp. 431–439. Cited by: §1.
  • [15] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov (2014) Dropout: a simple way to prevent neural networks from overfitting. The journal of machine learning research 15 (1), pp. 1929–1958. Cited by: §1.
  • [16] C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, and Z. Wojna (2016) Rethinking the inception architecture for computer vision. In Proc. CVPR, pp. 2818–2826. Cited by: §1.
  • [17] A. Tarvainen and H. Valpola (2017) Mean teachers are better role models: weight-averaged consistency targets improve semi-supervised deep learning results. In Adv. Neural Inf. Process. Syst., pp. 1195–1204. Cited by: §1.
  • [18] J. J. Thiagarajan, S. Kashyap, and A. Karargyris (2019) Distill-to-label: weakly supervised instance labeling using knowledge distillation. In 2019 18th IEEE International Conference On Machine Learning And Applications (ICMLA), pp. 902–907. Cited by: §1, §1.
  • [19] Y. Tian, D. Krishnan, and P. Isola (2019) Contrastive representation distillation. arXiv preprint arXiv:1910.10699. Cited by: §2.1, §2.1, §3.2.1, Table 1.
  • [20] P. Tschandl, C. Rosendahl, and H. Kittler (2018) The ham10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. Scientific data 5, pp. 180161. Cited by: §3.1.1.
  • [21] F. Tung and G. Mori (2019) Similarity-preserving knowledge distillation. In Proc. ICCV, pp. 1365–1374. Cited by: §1, §2.2, §2.2, §3.2.1, Table 1.
  • [22] B. Unnikrishnan, C. M. Nguyen, S. Balaram, C. S. Foo, and P. Krishnaswamy (2020) Semi-supervised classification of diagnostic radiographs with noteacher: a teacher that is not mean. In Proc. MICCAI, pp. 624–634. Cited by: §1.
  • [23] J. Wu, S. Yu, W. Chen, K. Ma, R. Fu, H. Liu, X. Di, and Y. Zheng (2020) Leveraging undiagnosed data for glaucoma classification with teacher-student learning. In Proc. MICCAI, pp. 731–740. Cited by: §1, §1.
  • [24] Z. Wu, Y. Xiong, S. X. Yu, and D. Lin (2018) Unsupervised feature learning via non-parametric instance discrimination. In Proc. CVPR, pp. 3733–3742. Cited by: §2.1.
  • [25] Y. Yan, J. Kawahara, and G. Hamarneh (2019) Melanoma recognition via visual attention. In Inf Process Med Imaging, pp. 793–804. Cited by: §3.2.3, Table 2.
  • [26] C. Yang, L. Xie, C. Su, and A. L. Yuille (2019) Snapshot distillation: teacher-student optimization in one generation. In Proc. CVPR, pp. 2859–2868. Cited by: §1.
  • [27] J. Yim, D. Joo, J. Bae, and J. Kim (2017) A gift from knowledge distillation: fast optimization, network minimization and transfer learning. In Proc. CVPR, pp. 4133–4141. Cited by: §3.2.1, Table 1.
  • [28] J. Zhang, Y. Xie, Q. Wu, and Y. Xia (2018) Skin lesion classification in dermoscopy images using synergic deep learning. In Proc. MICCAI, pp. 12–20. Cited by: §3.2.3, Table 2.
  • [29] J. Zhang, Y. Xie, Y. Xia, and C. Shen (2019) Attention residual learning for skin lesion classification. IEEE Trans. Med. Imaging 38 (9), pp. 2092–2103. Cited by: §3.2.3, Table 2.
  • [30] J. Zhuang, J. Cai, R. Wang, J. Zhang, and W. Zheng (2020) Deep kNN for medical image classification. In Proc. MICCAI, pp. 127–136. Cited by: §1.