Rethink Transfer Learning in Medical Image Classification

06/09/2021 ∙ by Le Peng, et al. ∙ University of Minnesota

Transfer learning (TL) with deep convolutional neural networks (DCNNs) has proved successful in medical image classification (MIC). However, the current practice is puzzling, as MIC typically relies only on low- and/or mid-level features that are learned in the bottom layers of DCNNs. Following this intuition, we question the current strategies of TL in MIC. In this paper, we perform careful experimental comparisons between shallow and deep networks for classification on two chest x-ray datasets, using different TL strategies. We find that deep models are not always favorable, and finetuning truncated deep models almost always yields the best performance, especially in data-poor regimes.

Project webpage: https://sun-umn.github.io/Transfer-Learning-in-Medical-Imaging/

Keywords: Transfer learning, Medical image classification, Feature hierarchy, Medical imaging, Evaluation metrics, Imbalanced data


1 Introduction

Transfer learning (TL) has become the norm for medical image classification (MIC) and segmentation using deep learning. In these tasks, deep convolutional neural networks (DCNNs) pretrained on large-scale source tasks (e.g., image classification on ImageNet [5], or 3D medical segmentation and classification [4]) are often adopted and finetuned as backbone models for the target tasks; see [23, 25, 1, 10, 2, 6, 12] for successes.

Figure 1: (left) The feature hierarchy learned by typical DCNNs on computer vision datasets; (right) examples of medical pathologies in a chest x-ray. While high-level shapes and patterns are needed to identify the dog on the left, only low-level textures and/or mid-level blobs are necessary for detecting the medical pathologies on the right. (The left visualized features are adapted from [29]; the right chest x-ray is adapted from [17].)

The key to TL is feature reuse across the source and target tasks, which brings practical benefits such as fast convergence in training and good performance even when the training data in the target task are limited [28]. Pretrained DCNNs extract increasingly complex visual features, from low-level textures and edges learned at the bottom layers to high-level shapes and patterns learned at the top layers [29]; the latter are crucial for object recognition and segmentation (see Fig. 1 (left)). However, medical pathologies often take the form of abnormal textures or blobs, which correspond to low-level and/or at most mid-level visual features (see Fig. 1 (right)). So, intuitively, for most medical imaging tasks, we only need to finetune a reasonable number of bottom layers and do not need the top layers at all. Puzzlingly, the current common practice goes against this intuition: in most medical imaging tasks, all layers of pretrained DCNNs are retained and transferred. It is thus natural to ask how to bridge the gap between our intuition and the practice, and whether the full-transfer strategy is the best possible for MIC.

The pioneering work [18] partially addresses this for MIC (chest x-ray classification and diabetic retinopathy detection), and empirically shows that, in terms of classification performance, i) shallow models can perform on par with deep models, and ii) finetuning deep and shallow models from pretrained weights is no better than training from scratch. However, their experimentation has serious limitations: 1) Poor evaluation metric. In their tasks, the validation sets have positive:negative ratios far below 1:1 (in the diabetic-retinopathy task, positives make up only a small fraction of the two validation sets [7]; in CheXpert, positive samples are likewise a small minority [11]), i.e., they are highly imbalanced, and the standard AUC (i.e., area-under-ROC, AUROC) metric that they use is well known to be a poor indicator of performance in such scenarios; see Section 2, [15, 21], and Section 5.1 of [16]. 2) Abundant training data. They focus on two data-rich tasks in which each class has at least thousands of training data points, but most medical problems suffer from data scarcity, which is exactly where TL could have a unique performance edge [23, 20, 14, 13]. 3) Rigid TL methods. They only evaluate the full-transfer (all layers finetuned from pretrained weights) and hybrid-transfer (i.e., Transfusion: bottom layers finetuned from pretrained weights while top layers finetuned from random initialization) strategies, missing strong alternatives, as we show below.

In this paper, we challenge the conclusions of [18], and depict a picture of TL for MIC that is consistent with our intuition of transferring only low- to mid-level features. Specifically, measuring performance using both AUROC and AUPRC (area-under-precision-recall-curve, which reflects performance well even under class imbalance; see Section 2), we find the following:

  • For any model, shallow or deep, TL most often outperforms training from scratch, especially when the training set is small, reaffirming the performance benefit of TL;

  • The best TL strategy is finetuning pretrained models truncated (i.e., with top layers removed) to a depth commensurate with the level of features needed. This not only confirms that high-level features are probably not relevant for typical MIC, but also leads to compact models that are favorable for both training and inference.

2 Highlights of Key Technical Points

2.1 Evaluation metrics: AUPRC vs. AUROC

Figure 2: (left) ROC and PRC curves of two classifiers on the same binary classification problem. While the ROC curves are uniformly high and very close, the PRC curves are far apart. (right) Confusion table and definitions of the relevant terms for ROC and PRC curves.

For binary classification with a dominant negative class, which is common in MIC, the precision-recall curve (PRC) is well known to be more informative and indicative of true performance than the ROC [15, 21]. Fig. 2 (right) provides a quick reminder of the relevant notation. Even for mediocre classifiers that can only score the positives above the majority of the negatives, the TPR quickly climbs to high values at low FPRs, since the large TN counts dominate the early cut-off points; this leads to uniformly high and close AUROCs. So AUROC cannot tell good classifiers from mediocre ones. In contrast, the PRC captures performance at a finer granularity by including precision, which is sensitive to the ranking of positives against all negatives, rendering AUPRC a favorable metric in the presence of class imbalance.

To see the gap quickly, consider a dataset with a small number of positives and a large number of negatives for a rare disease. Assume classifier A (CA) scores the positives uniformly over a small top fraction of its ranking, while classifier B (CB) scores them uniformly over a much larger top fraction. Intuitively, CA performs much better than CB, since the two detect each TP case at the cost of few vs. many FP cases, respectively; this is exactly what precision captures. Fig. 2 (left) presents the ROC and PRC curves for CA and CB: judging from the ROCs, both perform almost perfectly, but the PRCs clearly reveal the substantial performance gap. While CA is good, CB is probably unacceptable for medical diagnosis: the two AUROCs are high and close, whereas CA's AUPRC is far above CB's.
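To make this concrete, here is a minimal scikit-learn sketch with illustrative numbers of our own choosing: a 1:100 imbalance, CA placing its positives within roughly the top 1% of scores, and CB within the top 10%.

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

rng = np.random.default_rng(0)
n_pos, n_neg = 100, 10_000                   # illustrative 1:100 imbalance
y = np.r_[np.ones(n_pos), np.zeros(n_neg)]
neg = rng.uniform(0.0, 1.0, n_neg)           # negative scores: uniform

# CA places positives uniformly within the top ~1% of the score range,
# CB within the top ~10% (both cutoffs are assumptions for illustration).
ca = np.r_[rng.uniform(0.99, 1.0, n_pos), neg]
cb = np.r_[rng.uniform(0.90, 1.0, n_pos), neg]

for name, s in (("CA", ca), ("CB", cb)):
    print(name, f"AUROC={roc_auc_score(y, s):.3f}",
          f"AUPRC={average_precision_score(y, s):.3f}")
# Both AUROCs come out high and close (about 0.99 vs. 0.95), while the
# AUPRCs differ dramatically (about 0.5 vs. 0.1), mirroring Fig. 2 (left).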

In our experiments, the medical datasets are all imbalanced, some severely, and hence we report both AUROC and AUPRC, whereas the prior work [18] reports only AUROC. As we show in Section 3, the use of AUPRC helps us to reaffirm the performance benefit of TL, contrary to the conclusion made in [18].

2.2 Network models and training strategies

We study three types of network models: (1) Shallow networks: the CBR (convolution-batchnorm-relu) family proposed in [18], i.e., CBR-LargeT, CBR-LargeW, CBR-Small, and CBR-Tiny. These are basic convolutional networks with only a few convolution layers; details are in Appendix A of [18]. (2) Deep networks: ResNet50 [8] and DenseNet121 [9], both popular in TL for MIC [3, 11, 19]. (3) Truncated networks: ResNet50 [8] and DenseNet121 [9] truncated at different levels, including i) Res-T1, Res-T2, Res-T3, Res-T4, and Res-T5, which are ResNet50 truncated at the first 25 (50%), 37 (74%), 43 (86%), 46 (92%), and 49 (98%) convolutional layers, respectively (the percentages indicate the fraction of retained layers out of all layers in the respective deep model); and ii) Dens-T1, Dens-T2, and Dens-T3, which are DenseNet121 truncated at the first 1 (12%), 2 (32%), and 3 (88%) dense blocks, respectively. For all these models, the final fully-connected layers are adjusted or appended whenever necessary.
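As one concrete way to build such truncated models, the PyTorch/torchvision sketch below keeps the DenseNet121 feature extractor up to the k-th dense block and appends a fresh pooling and fully-connected head. This is a minimal reconstruction under our own assumptions, not the authors' released code.

import torch
import torch.nn as nn
from torchvision import models

def truncated_densenet121(k_blocks: int, num_classes: int, pretrained: bool = True) -> nn.Module:
    backbone = models.densenet121(
        weights=models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None)
    layers, seen = [], 0
    for name, module in backbone.features.named_children():
        layers.append(module)
        if name.startswith("denseblock"):
            seen += 1
            if seen == k_blocks:          # stop after the k-th dense block
                break
    features = nn.Sequential(*layers)
    with torch.no_grad():                 # infer the channel width at the cut
        c = features(torch.zeros(1, 3, 224, 224)).shape[1]
    return nn.Sequential(features,
                         nn.ReLU(inplace=True),
                         nn.AdaptiveAvgPool2d(1),
                         nn.Flatten(),
                         nn.Linear(c, num_classes))

dens_t2 = truncated_densenet121(k_blocks=2, num_classes=1)   # e.g., Dens-T2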

We compare two training strategies: (1) Random initialization (RI): the network is trained with all weights randomly initialized; (2) Transfer learning (TL): pretrained network weights are finetuned. All pretrained models are trained on ImageNet, a standard image classification dataset in computer vision [5].
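In torchvision, for example, the two strategies differ only in the initial weights; the snippet below is a sketch under the assumption of standard torchvision models, as the paper does not publish its exact training code.

import torch.nn as nn
from torchvision import models

# (1) RI: all weights randomly initialized by the default scheme.
ri_model = models.resnet50(weights=None)
# (2) TL: start from ImageNet-pretrained weights, then finetune all layers.
tl_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# In both cases, the final fully-connected layer is replaced to match the
# target task (here: a single logit for binary classification).
for m in (ri_model, tl_model):
    m.fc = nn.Linear(m.fc.in_features, 1)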

2.3 Datasets and setups

The prior work [18] performs its experiments on CheXpert [11] and Retina [7], but Retina is private. We evaluate our models and training strategies on CheXpert [11] and a private COVID-19 chest x-ray dataset.

CheXpert consists of 224,316 chest x-rays of 65,240 patients, with a binary label for each of the following 13 pathologies per x-ray: No Finding (NF), Enlarged Cardiomediastinum (EC), Cardiomegaly (Ca.), Lung Opacity (LO), Lung Lesion (LL), Edema (Ed.), Consolidation (Co.), Pneumonia (Pn.), Atelectasis (At.), Pneumothorax (Pt.), Pleural Effusion (PE), Pleural Other (PO), and Fracture (Fr.). We map all "Uncertain" labels to "0" (negative, i.e., the U-Zeros model in [11]), and omit the "Support Devices" class as it differs from typical medical pathologies. We randomly divide the dataset into training, validation, and test sets. Our setting differs from that of [18], and we do not directly compare with their results, because: 1) they test on the official validation set, which is very small and also misses several pathologies, whereas we prefer the much larger test set constructed above for stable and comprehensive evaluation; 2) our purpose is not to obtain or surpass the state-of-the-art results on the dataset, but to evaluate the models and training strategies listed above in a consistent manner; and 3) they do not provide sufficient training details or code for easy reproducibility.
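A minimal pandas sketch of this preprocessing follows; the column names match the public CheXpert CSVs, while the file path and split fractions are placeholders of our choosing.

import pandas as pd

df = pd.read_csv("CheXpert-v1.0/train.csv")
df = df.drop(columns=["Support Devices"])          # omit the non-pathology class
meta_cols = ["Path", "Sex", "Age", "Frontal/Lateral", "AP/PA"]
label_cols = [c for c in df.columns if c not in meta_cols]
# U-Zeros: map uncertain (-1) and blank labels to 0 (negative).
df[label_cols] = df[label_cols].replace(-1, 0).fillna(0)

# Random train/validation/test split; the fractions below are placeholders.
df = df.sample(frac=1.0, random_state=0)           # shuffle
n, f_tr, f_va = len(df), 0.8, 0.1
train = df.iloc[: int(f_tr * n)]
val = df.iloc[int(f_tr * n) : int((f_tr + f_va) * n)]
test = df.iloc[int((f_tr + f_va) * n) :]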

Our COVID-19 dataset comprises COVID-positive and COVID-negative frontal chest x-rays collected from M Health Fairview, Minnesota, USA; details can be found in [24]. We focus on prospective evaluation: training and validation are performed on x-rays dated before June 30th, 2020, and testing is performed on x-rays dated after. Under this splitting strategy, the positive:negative imbalance ratio of the test set differs from that of the training set; such a drifting imbalance ratio is a salient feature of pandemic data such as COVID-19.
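Such a prospective split can be implemented along the following lines; the file and field names are hypothetical, since the dataset is private, and the cutoff boundary handling is our assumption.

import pandas as pd

meta = pd.read_csv("covid_xray_metadata.csv", parse_dates=["study_date"])
cutoff = pd.Timestamp("2020-06-30")
train_val = meta[meta["study_date"] < cutoff]    # before June 30th, 2020
test = meta[meta["study_date"] >= cutoff]        # prospective test set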

To account for randomness, each experiment is independently repeated three times, and the mean and standard deviation of all performance scores are reported.

Figure 3: TL vs. RI on deep and shallow models for CheXpert. In most cases, RI performs on par with TL, but in a number of cases, TL outperforms RI by visible gaps, especially when measured by AUPRC.

3 Transfer Learning (TL) vs. Random Initialization (RI)

In this section, we compare TL and RI on deep and shallow models, in data-rich and data-poor regimes respectively. Both CheXpert and COVID-19 contain relatively many samples per class for medical tasks, so we consider experiments on the original datasets to be in the data-rich regime. To simulate data-poor regimes, we focus on the COVID-19 dataset and subsample small fractions of the original training set.

3.1 Data-rich regime

The scores on CheXpert are presented in Fig. 3. While TL and RI perform comparably on most pathologies and models, in a number of cases on deep models, TL outperforms RI by significant gaps; e.g., for ResNet50 on Pt., TL is well above RI as measured by AUPRC. Overall, on deep models, TL outperforms RI marginally on most pathologies and considerably on several, measured by both AUROC and AUPRC. On shallow models, TL and RI are very close by both metrics.

The scores on COVID-19 can be read off from Fig. 4 (the full-data setting). We observe a trend similar to that of CheXpert: TL leads to substantial performance gains on DenseNet121, and to marginal gains or occasional losses on shallow models.

Overall, on both datasets and in all settings, the best-performing combination most often involves TL rather than RI, and AUPRC is a crucial tie-breaker when AUROCs are close.

3.2 Data-poor regime

Figure 4: TL vs. RI on COVID-19. With the full dataset, TL wins over RI on DenseNet121, and they are close in performance on shallow models. With the subsampled data, TL outperforms RI on all models but CBR-LargeT and CBR-LargeW.

As alluded to in Section 1, TL is expected to benefit data-limited learning. To verify this, we simulate two small COVID-19 training sets: 1) a smaller subset containing 1,539 cases (88 positives and 1,451 negatives); and 2) a larger subset containing 3,080 cases (177 positives and 2,903 negatives). Our training protocol and test data are exactly the same as in Section 3.1. The results are included in Fig. 4. We make two observations: 1) on all models except CBR-LargeT and CBR-LargeW, TL outperforms RI as measured by both AUROC and AUPRC; 2) CBR-Small+TL is the clear winner on the smaller subset, whereas on the larger subset, AUROC and AUPRC point to CBR-LargeW and DenseNet121, respectively. This disparity highlights the need for reporting and comparing both metrics.
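Note that the two subsets share nearly the same positive fraction (about 5.7%), so one natural way to draw them is stratified subsampling, sketched below with synthetic stand-ins for the private data.

import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
# Synthetic stand-ins for the private training set (sizes are illustrative).
y_full = (rng.uniform(size=15_000) < 0.057).astype(int)   # ~5.7% positives
x_full = np.arange(len(y_full))                           # stand-in for image ids

x_small, _, y_small, _ = train_test_split(
    x_full, y_full,
    train_size=1539,        # e.g., the 1,539-case subset above
    stratify=y_full,        # preserve the positive:negative ratio
    random_state=0,
)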

4 Truncated Transfer Learning

Figure 5: TL AUROC and AUPRC on COVID-19. Here we include all shallow models, and Dens-F denotes the full DenseNet121. Dens-T1 and Dens-T2 are among the best models.
Figure 6: TL AUROC and AUPRC on CheXpert. Here CBR denotes CBR-Small, which performs the best among all shallow models, and Dens-F denotes the full DenseNet121. Dens-T3 performs the best on all pathologies.

We have confirmed the advantage of TL over RI, especially on deep models. It is natural to ask whether there are better ways of performing TL. Since we probably do not need high-level features, it is puzzling why TL on deep models should keep the top layers at all. This motivates the truncated networks that we proposed in Section 2.2. In this section, we compare the performance of TL on these truncated models against their parent deep models, as well as selected CBR shallow models. We omit the Transfusion strategy proposed in [18] because, according to the evaluation there, it does not improve performance over TL on full models, only convergence speed, which is not the focus of the current paper. The training and test protocols follow exactly those of Section 3.

Fig. 5 summarizes the results on our COVID-19 dataset. It is evident that Dens-T1 and Dens-T2, which are heavily truncated versions of DenseNet121 (retaining only 12% and 32% of the original depth, respectively), are the two best-performing models when combined with TL. In contrast, Dens-T3 and Dens-F, with less aggressive truncation, can be substantially worse, sometimes even worse than the shallow models. From medical studies of COVID-19, it is known that the salient radiomic features are opacities and consolidation in the lung area, which concern only low-level textures and perhaps mid-level blobs [27, 22]. This strongly confirms that only a reasonable number of bottom layers is needed for efficient TL.

Fig. 6 presents the results on CheXpert. Dens-T3 is the best model, although in most cases it is comparable to Dens-F. Note that, compared to Dens-T1 and Dens-T2, which excelled on the COVID-19 dataset, Dens-T3 is far deeper (88% vs. 12% and 32% of the original depth). This disparity can again be explained by the feature hierarchy: in CheXpert, pathologies such as atelectasis and pneumothorax need relatively high-level features, as they start to concern shapes, in contrast to the low- and mid-level features used for COVID-19. Another observation is that on EC, LL, PO, and Fr., the AUPRCs are very low even though the corresponding AUROCs are all high. These are rare diseases in CheXpert with high imbalance ratios between positives and negatives (EC: 1:19.8, LL: 1:26.4, PO: 1:75.9, Fr.: 1:24.8). Even for the best models here, the AUROCs may be considered decent, but their actual performance, once precision is taken into account via AUPRC, is very poor. (Note, however, that it does not make sense to compare AUPRCs across different imbalance ratios; see the discussions in Section 5.1 of [16] and [26].) This reinforces our claim that AUPRC needs to be reported when evaluating classifiers on data with class imbalance.
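To see why cross-ratio comparisons are unfair, recall that a random classifier attains an expected precision equal to the positive prevalence at every recall, so its chance-level AUPRC is roughly P/(P+N). A quick computation for the rare classes above:

# Chance-level AUPRC from the imbalance ratios 1:r quoted in the text.
for name, r in (("EC", 19.8), ("LL", 26.4), ("PO", 75.9), ("Fr.", 24.8)):
    print(f"{name}: chance AUPRC = 1/(1+{r}) = {1/(1+r):.3f}")
# EC: 0.048, LL: 0.037, PO: 0.013, Fr.: 0.039 -- the AUPRC baseline shifts
# with the imbalance ratio, which is why cross-ratio comparisons are unfair.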

5 Discussion

In this paper, we revisit transfer learning (TL) for medical image classification (MIC) on chest x-rays, taking into account characteristics of typical medical datasets such as class imbalance and small data. By evaluating different TL strategies on a number of shallow and deep convolutional neural network models, we find that 1) TL does benefit classification performance, especially in data-poor scenarios; and 2) transferring pretrained deep models truncated to a depth commensurate with the level of features needed for the classification task leads to superior performance compared to conventional TL on full deep models and on shallow models. Through our experimentation, we have also highlighted the role of AUPRC in distinguishing good classifiers from mediocre ones under class imbalance. Our results support the view that low- and mid-level visual features are probably sufficient for typical MIC, and high-level features are needed only occasionally.

Potential future directions include: 1) experimenting with other imaging modalities, such as CT and MRI; if a similar conclusion holds, the truncated TL strategy can lead to profound savings of computing resources for model training and inference on 3D medical data; 2) validating the conclusion on segmentation, another major family of imaging tasks, and on other diseases; 3) investigating TL on models directly pretrained on medical datasets, e.g., [4] and the pretrained models released in the Nvidia Clara Imaging package (https://developer.nvidia.com/clara), rather than on computer vision datasets.

Acknowledgements

The authors acknowledge the Minnesota Supercomputing Institute (MSI) at the University of Minnesota and Microsoft Azure for providing computing resources. LP was partly supported by an OVPR COVID-19 Rapid Response Grant from the University of Minnesota. We also thank M Health Fairview of Minnesota for providing the private COVID-19 data.

References

  • [1] Anthimopoulos, M., Christodoulidis, S., Ebner, L., Christe, A., Mougiakakou, S.: Lung pattern classification for interstitial lung diseases using a deep convolutional neural network. IEEE transactions on medical imaging 35(5), 1207–1216 (2016)
  • [2] Antropova, N., Huynh, B.Q., Giger, M.L.: A deep feature fusion methodology for breast cancer diagnosis demonstrated on three imaging modality datasets. Medical physics 44(10), 5162–5171 (2017)
  • [3] Bressem, K.K., Adams, L.C., Erxleben, C., Hamm, B., Niehues, S.M., Vahldiek, J.L.: Comparing different deep learning architectures for classification of chest radiographs. Scientific reports 10(1), 1–16 (2020)
  • [4] Chen, S., Ma, K., Zheng, Y.: Med3d: Transfer learning for 3d medical image analysis. arXiv preprint arXiv:1904.00625 (2019)
  • [5] Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., Fei-Fei, L.: Imagenet: A large-scale hierarchical image database. In: 2009 IEEE Conference on Computer Vision and Pattern Recognition. pp. 248–255 (2009). https://doi.org/10.1109/CVPR.2009.5206848
  • [6] Ghafoorian, M., Mehrtash, A., Kapur, T., Karssemeijer, N., Marchiori, E., Pesteie, M., Guttmann, C.R., de Leeuw, F.E., Tempany, C.M., Van Ginneken, B., et al.: Transfer learning for domain adaptation in mri: Application in brain lesion segmentation. In: International conference on medical image computing and computer-assisted intervention. pp. 516–524. Springer (2017)
  • [7] Gulshan, V., Peng, L., Coram, M., Stumpe, M.C., Wu, D., Narayanaswamy, A., Venugopalan, S., Widner, K., Madams, T., Cuadros, J., Kim, R., Raman, R., Nelson, P.C., Mega, J.L., Webster, D.R.: Development and validation of a deep learning algorithm for detection of diabetic retinopathy in retinal fundus photographs. JAMA 316(22), 2402 (Dec 2016). https://doi.org/10.1001/jama.2016.17216
  • [8] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE conference on computer vision and pattern recognition. pp. 770–778 (2016)
  • [9] Huang, G., Liu, Z., Van Der Maaten, L., Weinberger, K.Q.: Densely connected convolutional networks. In: Proceedings of the IEEE conference on computer vision and pattern recognition. pp. 4700–4708 (2017)
  • [10] Huynh, B.Q., Li, H., Giger, M.L.: Digital mammographic tumor classification using transfer learning from deep convolutional neural networks. Journal of Medical Imaging 3(3), 034501 (2016)
  • [11] Irvin, J., Rajpurkar, P., Ko, M., Yu, Y., Ciurea-Ilcus, S., Chute, C., Marklund, H., Haghgoo, B., Ball, R., Shpanskaya, K., et al.: Chexpert: A large chest radiograph dataset with uncertainty labels and expert comparison. In: Proceedings of the AAAI Conference on Artificial Intelligence. vol. 33, pp. 590–597 (2019)
  • [12] Jiang, Z., Zhang, H., Wang, Y., Ko, S.B.: Retinal blood vessel segmentation using fully convolutional network with transfer learning. Computerized Medical Imaging and Graphics 68, 1–15 (2018)
  • [13] Karimi, D., Warfield, S.K., Gholipour, A.: Critical assessment of transfer learning for medical image segmentation with fully convolutional neural networks. arXiv preprint arXiv:2006.00356 (2020)
  • [14] Khalifa, N.E.M., Loey, M., Taha, M.H.N., Mohamed, H.N.E.T.: Deep transfer learning models for medical diabetic retinopathy detection. Acta Informatica Medica 27(5),  327 (2019)
  • [15] Lever, J., Krzywinski, M., Altman, N.: Classification evaluation. Nature Methods 13(8), 603–604 (Jul 2016). https://doi.org/10.1038/nmeth.3945
  • [16] Murphy, K.P.: Probabilistic Machine Learning: An Introduction. MIT Press (2021), probml.ai
  • [17] Nguyen, H.Q., Pham, H.H., Nguyen, N.T., Nguyen, D.B., Dao, M., Vu, V., Lam, K., Le, L.T.: VinBigData chest x-ray abnormalities detection. https://www.kaggle.com/c/vinbigdata-chest-xray-abnormalities-detection (2021)
  • [18] Raghu, M., Zhang, C., Kleinberg, J., Bengio, S.: Transfusion: Understanding transfer learning for medical imaging. In: Advances in Neural Information Processing Systems. vol. 32, pp. 3347–3357. Curran Associates, Inc. (2019)
  • [19] Rajpurkar, P., Irvin, J., Zhu, K., Yang, B., Mehta, H., Duan, T., Ding, D., Bagul, A., Langlotz, C., Shpanskaya, K., et al.: Chexnet: Radiologist-level pneumonia detection on chest x-rays with deep learning. arXiv preprint arXiv:1711.05225 (2017)
  • [20] Ravishankar, H., Sudhakar, P., Venkataramani, R., Thiruvenkadam, S., Annangi, P., Babu, N., Vaidya, V.: Understanding the mechanisms of deep transfer learning for medical images. In: Deep learning and data labeling for medical applications, pp. 188–196. Springer (2016)
  • [21] Saito, T., Rehmsmeier, M.: The precision-recall plot is more informative than the roc plot when evaluating binary classifiers on imbalanced datasets. PloS one 10(3), e0118432 (2015)
  • [22] Shi, H., Han, X., Jiang, N., Cao, Y., Alwalid, O., Gu, J., Fan, Y., Zheng, C.: Radiological findings from 81 patients with covid-19 pneumonia in wuhan, china: a descriptive study. The Lancet infectious diseases 20(4), 425–434 (2020)
  • [23] Shin, H.C., Roth, H.R., Gao, M., Lu, L., Xu, Z., Nogues, I., Yao, J., Mollura, D., Summers, R.M.: Deep convolutional neural networks for computer-aided detection: Cnn architectures, dataset characteristics and transfer learning. IEEE transactions on medical imaging 35(5), 1285–1298 (2016)
  • [24] Sun, J., Peng, L., Li, T., Adila, D., Zaiman, Z., Melton, G.B., Ingraham, N., Murray, E., Boley, D., Switzer, S., Burns, J.L., Huang, K., Allen, T., Steenburg, S.D., Gichoya, J.W., Kummerfeld, E., Tignanelli, C.: A prospective observational study to investigate performance of a chest x-ray artificial intelligence diagnostic support tool across 12 u.s. hospitals. medRxiv (2021). https://doi.org/10.1101/2021.06.04.21258316, https://www.medrxiv.org/content/early/2021/06/04/2021.06.04.21258316
  • [25] van Tulder, G., de Bruijne, M.: Combining generative and discriminative representation learning for lung CT analysis with convolutional restricted Boltzmann machines. IEEE transactions on medical imaging 35(5), 1262–1272 (2016)
  • [26] Williams, C.K.: The effect of class imbalance on precision-recall curves. Neural Computation 33(4), 853–857 (2021)
  • [27] Wong, H.Y.F., Lam, H.Y.S., Fong, A.H.T., Leung, S.T., Chin, T.W.Y., Lo, C.S.Y., Lui, M.M.S., Lee, J.C.Y., Chiu, K.W.H., Chung, T.W.H., et al.: Frequency and distribution of chest radiographic findings in patients positive for covid-19. Radiology 296(2), E72–E78 (2020)
  • [28] Yosinski, J., Clune, J., Bengio, Y., Lipson, H.: How transferable are features in deep neural networks? In: Proceedings of the 27th International Conference on Neural Information Processing Systems-Volume 2. pp. 3320–3328 (2014)
  • [29] Zeiler, M.D., Fergus, R.: Visualizing and understanding convolutional networks. In: European conference on computer vision. pp. 818–833. Springer (2014)