Self-Training with Improved Regularization for Few-Shot Chest X-Ray Classification

05/03/2020 · Deepta Rajan, et al.

Automated diagnostic assistants in healthcare necessitate accurate AI models that can be trained with limited labeled data, can cope with severe class imbalance, and can support simultaneous prediction of multiple disease conditions. To this end, we present a novel few-shot learning approach that utilizes a number of key components to enable robust modeling in such challenging scenarios. Using an important use-case in chest X-ray classification, we provide several key insights on the effective use of data augmentation, self-training via distillation, and confidence tempering for few-shot learning in medical imaging. Our results show that, using only 10% of the labeled data, we can build predictive models that match the performance of classifiers trained in a large-scale data setting.


1 Motivation

An increasing need for automated diagnostic assistants in healthcare places a growing demand for developing accurate AI models that are resilient to biases stemming from data sources and demographics [thiagarajan2018can]. In this paper, we consider an important class of diagnosis problems in medical imaging characterized by three crucial real-world challenges: (i) limited access to labeled data, (ii) severe class imbalance, and (iii) the need to associate each sample with multiple disease conditions (multi-label). Learning with limited labeled data, often referred to as few-shot learning, leads to severe overfitting in practice when combined with class imbalance. Though recent advances in few-shot learning can help with this challenge to an extent, e.g., novel augmentation techniques [chaitanya2019semi, eaton2018improving], customized loss functions [ge2018chest] and sophisticated regularization strategies [zhang2017mixup], the class imbalance and multi-label nature of diagnosis problems make them insufficient in practice. Another popular approach to deal with small-data problems is to leverage additional unlabeled datasets, if available, to build more robust models [berthelot2019mixmatch, berthelot2019remixmatch, arazo2019pseudo]. However, their effectiveness on challenging few-shot, multi-label diagnosis problems has not been studied so far.

Use-case. To illustrate the aforementioned challenges, we consider a chest X-ray (CXR) classification problem, and show that existing deep learning-based solutions developed for large-scale data perform poorly with few-shot data [rajpurkar2018deep, pham2019interpreting]. More specifically, we use the public CXR repository developed by Stanford [irvin2019chexpert]. The choice of this use-case was driven by the prevalence of X-rays as a diagnostic modality [mettler2009radiologic], the impact of robustly detecting lung conditions [xu2020pathological], and the difficulty of obtaining expert annotations at scale [irvin2019chexpert].

Proposed Work. In this paper, we develop a novel learning approach, particularly suited for medical imaging problems, that enables the design of robust models in very low sample regimes. More specifically, our approach is comprised of four crucial components: (i) weak image augmentation; (ii) mixup training; (iii) confidence tempering regularization; and (iv) self-training with a noisy student. While image augmentation is routinely used in several recent solutions for CXR classification [bressem2020comparing], we make the surprising finding that, with few-shot data, it is insufficient to achieve good generalization. Hence, we propose to employ mixup training, a recent approach for Vicinal Risk Minimization [zhang2017mixup], and a novel regularization strategy to handle the class imbalance challenge. Finally, we also explore the use of a self-training protocol to evolve a student model with improved generalization, without the need for any additional data, and we extend this to the case where we have access to additional unlabeled data. We make a crucial finding, similar to [xie2019self], that the student model should be noised while training. Our results show that a ResNet-18 model trained using only a fraction of the labeled data outperforms another ResNet-18 trained on the full labeled set. Furthermore, with roughly 10% of the labeled data, our approach achieves performance comparable to an over-parameterized DenseNet-121 architecture.

2 Problem Setup

In this work, we consider the problem of chest X-ray classification to find evidence for any combination of five different conditions, namely: (a) Atelectasis (AT), (b) Cardiomegaly (CA), (c) Consolidation (CO), (d) Edema (ED), and (e) Pleural Effusion (PE). In our setup, we assume that we can only access few-shot labeled data and that the label distribution is characterized by severe imbalance.

Dataset Patients Images CA ED CO AT PE
Train 43,393 138,655 17,572 36,983 10,040 23,810 58,141
Validation 10,000 20,674 1,849 3,543 1,016 3,337 5,632
Test 200 234 68 45 33 80 67
Table 1: Description of the chest X-ray classification dataset used in our study.

Dataset Description. We use CheXpert [irvin2019chexpert], a large public dataset for chest radiograph interpretation. The images were curated by Stanford from both in-patient and out-patient centers between October 2002 and July 2017. It consists of 224,316 X-rays from 65,240 patients, where images can correspond to Frontal, Lateral, Anteroposterior or Posteroanterior views. In our study, we used the subset of the train set that contained an actual prediction for the classes we considered (some of the samples carry the label uncertain). Subsequently, we randomly split the dataset into train and validation sets with no patient overlap between them, and the test set was designed using the additional 200-patient set released publicly by Stanford for evaluation. The sample sizes used in our experiments along with their class distributions are summarized in Table 1.

Setup. We denote a labeled dataset by the tuple $\mathcal{D}_\ell = (\mathcal{X}, \mathcal{Y})$, which is a collection of $n$ examples (also referred to as shots) and a label matrix of size $n \times C$, where $C$ indicates the total number of classes (set to 5). We denote an unlabeled dataset as $\mathcal{D}_u$, which does not have the corresponding annotations. In our experiments, we randomly draw both labeled and unlabeled sets from the full train set (see Table 1) with no overlap between them. Note, we assume that the marginal distributions of the classes in the few-shot dataset are the same as in the original training set. We expect the classification task to become significantly more challenging as the number of labeled examples $n$ becomes smaller. In order to use models pre-trained on ImageNet for initialization, we pre-processed the raw gray-scale images by resizing them using linear interpolation while maintaining the aspect ratio with border padding. The images were then normalized using a fixed pixel mean and standard deviation, in addition to being contrast adjusted using histogram equalization.
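For illustration, a minimal sketch of this preprocessing pipeline is given below using PIL and torchvision; the target resolution and the normalization statistics are placeholder values, since the exact settings are not reproduced here.

```python
from PIL import Image, ImageOps
import torchvision.transforms as T

# NOTE: target size, mean and std are illustrative placeholders, not the paper's values.
TARGET_SIZE = 224
PIXEL_MEAN, PIXEL_STD = 0.5, 0.25

def preprocess_xray(path):
    """Histogram-equalize, resize (aspect-preserving with padding) and normalize a gray-scale X-ray."""
    img = Image.open(path).convert('L')                   # raw gray-scale image
    img = ImageOps.equalize(img)                          # contrast adjustment via histogram equalization
    img = ImageOps.pad(img, (TARGET_SIZE, TARGET_SIZE),   # resize with aspect ratio kept, border padding
                       method=Image.BILINEAR)
    img = img.convert('RGB')                              # replicate channels for ImageNet-pretrained models
    return T.Compose([
        T.ToTensor(),
        T.Normalize([PIXEL_MEAN] * 3, [PIXEL_STD] * 3),
    ])(img)
```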

3 Method

In this section, we present the proposed methodology for building accurate classifiers using few-shot data in healthcare problems. Our approach is comprised of four crucial components, namely (i) weak image augmentation, (ii) mixup training, (iii) confidence tempering, and (iv) self-training with noisy students, to produce highly effective models.

(i) Weak image augmentation. In accordance with one of the CheXpert [irvin2019chexpert] competition's top-ranked submissions [jfhealthcare], we perform weak augmentation on X-rays to improve the robustness of the trained models. In particular, we apply random affine transformations, namely rotations, horizontal/vertical translations and scaling, within small ranges. Though weak data augmentation is widely adopted to avoid overfitting, we find it to be insufficient in problems with limited training data, corroborating the results in [chaitanya2019semi], where other augmentation techniques were also explored. In the rest of the paper, we refer to the augmented images by the notation $\bar{\mathcal{X}}$.
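A minimal torchvision sketch of such a weak-augmentation pipeline is shown below; the specific rotation, translation and scaling ranges are illustrative assumptions rather than the settings used in the paper.

```python
import torchvision.transforms as T

# Weak augmentation: small random affine perturbations applied to each X-ray.
# NOTE: the ranges below are illustrative placeholders, not the paper's settings.
weak_augment = T.Compose([
    T.RandomAffine(
        degrees=10,              # small random rotation
        translate=(0.05, 0.05),  # small horizontal/vertical shifts
        scale=(0.95, 1.05),      # mild zoom in/out
    ),
    T.ToTensor(),
])
```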

(ii) Mixup training. Mixup is a recent technique for training deep neural networks [zhang2017mixup], wherein we generate additional samples by convexly combining random pairs of input images and their corresponding labels. It is based on the principle of Vicinal Risk Minimization [chapelle2001vicinal], where the goal is to train classifiers not only on the training samples, but also in the vicinity of each sample. It has also been found in [thulasidasan2019mixup] that mixup training leads to networks whose confidences are well-calibrated, i.e., the predictive scores are actually indicative of the likelihood of correctness. Hence, in our approach, we utilize mixup training to improve the robustness of classifiers. For mixup, we create virtual image-label pairs by convexly interpolating between two random samples $(x_i, y_i)$ and $(x_j, y_j)$,

$$\tilde{x} = \lambda x_i + (1-\lambda)\, x_j, \qquad \tilde{y} = \lambda y_i + (1-\lambda)\, y_j, \qquad (1)$$

and enforce the consistency that predictions for $\tilde{x}$ should agree with the interpolated labels $\tilde{y}$. The amount of interpolation is controlled by the parameter $\lambda \in [0, 1]$, where $\lambda \sim \text{Beta}(\alpha, \alpha)$, and $\text{Beta}$ denotes the beta distribution. In practice, given the predictions from a model $\mathcal{F}$ with parameters $\Theta$, we define the loss function for mixup training:

$$\mathcal{L}_{\text{mix}}(\Theta) = \sum_{(i,j)} \ell_{\text{BCE}}\big(\mathcal{F}(\tilde{x}; \Theta),\ \tilde{y}\big), \qquad (2)$$

where $\ell_{\text{BCE}}$ denotes the binary cross entropy loss between the predictions and the interpolated labels $\tilde{y}$, and the summation is over multiple random pairs.
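To make the mixup step concrete, here is a small PyTorch sketch for a multi-label classifier with sigmoid outputs; the function and variable names (model, alpha) are illustrative, and the default alpha is an assumed value rather than the paper's setting.

```python
import numpy as np
import torch
import torch.nn.functional as F

def mixup_bce_loss(model, images, labels, alpha=0.4):
    """One mixup training step for a multi-label classifier (sketch of eqns. (1)-(2)).

    images: (B, 3, H, W) float tensor; labels: (B, C) multi-hot float tensor.
    NOTE: alpha=0.4 is an illustrative value, not the paper's setting.
    """
    lam = np.random.beta(alpha, alpha)            # lambda ~ Beta(alpha, alpha)
    perm = torch.randperm(images.size(0))         # random pairing of samples in the batch

    mixed_images = lam * images + (1.0 - lam) * images[perm]   # eqn. (1)
    mixed_labels = lam * labels + (1.0 - lam) * labels[perm]

    logits = model(mixed_images)                  # raw scores, one per class
    # Binary cross entropy against the interpolated multi-hot labels, eqn. (2)
    return F.binary_cross_entropy_with_logits(logits, mixed_labels)
```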

Input: Labeled data $\mathcal{D}_\ell = (\mathcal{X}, \mathcal{Y})$ and unlabeled data $\mathcal{D}_u$, mixup parameter $\alpha$, confidence tempering constants $\tau_l$ and $\tau_h$, sharpening parameter $T$, hyper-parameters $\gamma_1, \gamma_2$.
Output: Teacher model with parameters $\Theta_t$ and Student model with parameters $\Theta_s$
Initialization: initialize model parameters $\Theta_t$;
for $m$ epochs do
    Perform weak image augmentation on the labeled data $\mathcal{X}$ to obtain $\bar{\mathcal{X}}$;
    Generate mixup parameter $\lambda \sim \text{Beta}(\alpha, \alpha)$;
    Mixup training: convexly combine random sample pairs in $\bar{\mathcal{X}}$ using eqn. (1) and compute $\mathcal{L}_{\text{mix}}$ using eqn. (2);
    Confidence tempering: for each class $c$, estimate $\mathcal{R}_c$ using eqn. (3);
    Compute loss function $\mathcal{L} = \mathcal{L}_{\text{mix}} + \gamma_2 \sum_c \mathcal{R}_c$;
    Update parameters $\Theta_t$;
end for
/*Self-training*/
Initialize a student model with parameters $\Theta_s$;
for $m$ epochs do
    Perform weak image augmentation on labeled and unlabeled data to obtain $\bar{\mathcal{X}}$ and $\bar{\mathcal{X}}_u$;
    Estimate pseudo-labels $\hat{\mathcal{Y}}_u = \mathcal{F}(\bar{\mathcal{X}}_u; \Theta_t)$ for the unlabeled data;
    Perform sharpening of $\hat{\mathcal{Y}}_u$ with temperature $T$ using eqn. (4) to obtain $\tilde{\mathcal{Y}}_u$;
    Generate mixup parameter $\lambda \sim \text{Beta}(\alpha, \alpha)$;
    Mixup training: convexly combine random pairs in $\bar{\mathcal{X}}_u$ using eqn. (1) and compute the mixup loss using eqn. (2), with the binary cross entropy replaced by a KL-divergence;
    Confidence tempering: for each class $c$, estimate $\mathcal{R}_c$ using eqn. (3) on $\bar{\mathcal{X}}_u$;
    Compute loss function $\mathcal{L}_s$ using eqn. (5);
    Update parameters $\Theta_s$;
end for
return $\Theta_t, \Theta_s$;
Algorithm 1: Proposed approach with few-shot labeled data and an additional unlabeled set.

(iii) Confidence tempering regularization.

Though mixup training helps in avoiding model overfitting, the inherent class imbalance can make it ineffective, particularly for classes with a smaller number of examples. A naïve way to handle this is to alter the probability distribution with which we choose the random pairs in mixup (which is uniform by default); however, it is not clear how to estimate marginal distributions from limited data such that they effectively reflect the unseen test cases. Hence, we propose a novel regularization strategy, referred to as confidence tempering (CT). A common observation in imbalanced multi-label problems is that a model compounds more evidence for assigning every image to the most prevalent class, while providing little to no likelihood for classes with very few examples. We avoid this by introducing a regularization term for every class $c$:

$$\mathcal{R}_c = \max\big(0,\ \bar{p}_c - \tau_h\big) + \max\big(0,\ \tau_l - \bar{p}_c\big), \quad \text{where } \bar{p}_c = \frac{1}{B} \sum_{i=1}^{B} p_{ic}. \qquad (3)$$

Here, $p_{ic}$ indicates the likelihood of assigning sample $x_i$ to class $c$, and $\bar{p}_c$ is the average evidence for class $c$. In practice, we evaluate this for each mini-batch of size $B$. The hyper-parameters $\tau_l$ and $\tau_h$ are low and high thresholds for tempering confidences. In other words, this regularization penalizes a model that assigns overwhelmingly high evidence to any class, or that fails to provide any non-trivial evidence for a class. As we will show in our results, this regularization provides significant performance gains for classes with very few examples in the train set.
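As a rough illustration, the sketch below computes such a per-mini-batch penalty following the hinge-style form written in eqn. (3); the thresholds tau_low and tau_high are placeholder values, not the settings used in the paper.

```python
import torch

def confidence_tempering(probs, tau_low=0.05, tau_high=0.95):
    """Confidence-tempering penalty over a mini-batch (sketch of eqn. (3)).

    probs: (B, C) tensor of per-class sigmoid probabilities.
    NOTE: tau_low / tau_high are illustrative placeholders, not the paper's values.
    """
    p_bar = probs.mean(dim=0)                            # average evidence per class
    too_high = torch.clamp(p_bar - tau_high, min=0.0)    # penalize overwhelming evidence
    too_low = torch.clamp(tau_low - p_bar, min=0.0)      # penalize vanishing evidence
    return (too_high + too_low).sum()                    # summed over the C classes
```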

(iv) Self-training with noisy students. Finally, we propose to employ a self-training protocol, wherein we distill knowledge from a trained model (Teacher, with parameters $\Theta_t$) to evolve a Student model with parameters $\Theta_s$ that achieves improved generalization. This can be carried out using only the labeled data or with an additional unlabeled set $\mathcal{D}_u$. In either setting, we follow the empirical evidence in the recent work by Xie et al. [xie2019self] and use a student model that is noised during training (via mixup). We will now explain the protocol for the case where we have access to an additional unlabeled dataset $\mathcal{D}_u$, which we refer to as Self-Training (Unlabeled), or in short ST(U). We can also derive the protocol for the case where we do not have an additional unlabeled set, i.e., Self-Training (Labeled) or ST(L), by treating the labeled images themselves as the unlabeled set. Given the teacher model $\mathcal{F}(\cdot; \Theta_t)$, trained with mixup and confidence tempering, we first estimate pseudo-labels for the weakly augmented unlabeled data, $\hat{\mathcal{Y}}_u = \mathcal{F}(\bar{\mathcal{X}}_u; \Theta_t)$. Similar to [xie2019self], to reduce the effect of uncertainties in the teacher model, we perform sharpening of the predictions as follows:

$$\tilde{y}_{ic} = \mathbb{1}[\hat{y}_{ic} \geq 0.5]\ \hat{y}_{ic}^{\,1/T} + \big(1 - \mathbb{1}[\hat{y}_{ic} \geq 0.5]\big)\big(1 - (1 - \hat{y}_{ic})^{1/T}\big), \qquad (4)$$

where $\mathbb{1}[\cdot]$ denotes the indicator function and $T$ is a hyper-parameter. In practice, to make it differentiable, we implement the indicator function as a steep sigmoid. This sharpening pushes the prediction probability for each label closer to 1 when it is greater than 0.5, and closer to 0 when it is less than 0.5. This formulation for multi-label predictions is akin to temperature scaling for multi-class problems, and we fix $T$ in our experiments. Using the true labels for the labeled set $\mathcal{X}$ and the sharpened pseudo-labels $\tilde{\mathcal{Y}}_u$ for the set $\mathcal{X}_u$, we update the student model parameters $\Theta_s$.
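A small PyTorch sketch of this sharpening step is given below; the soft-indicator slope and the temperature value are illustrative assumptions, not the exact settings used in the paper.

```python
import torch

def sharpen_multilabel(pseudo_probs, temperature=2.0, slope=50.0):
    """Sharpen teacher pseudo-labels toward 0/1 (sketch of eqn. (4)).

    pseudo_probs: (N, C) tensor of teacher sigmoid probabilities.
    NOTE: temperature and slope are illustrative placeholders.
    """
    # Differentiable surrogate for the indicator 1[p >= 0.5]
    indicator = torch.sigmoid(slope * (pseudo_probs - 0.5))
    pushed_up = pseudo_probs ** (1.0 / temperature)                   # toward 1
    pushed_down = 1.0 - (1.0 - pseudo_probs) ** (1.0 / temperature)   # toward 0
    return indicator * pushed_up + (1.0 - indicator) * pushed_down
```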

More specifically, we use the following loss function:

$$\mathcal{L}_s(\Theta_s) = \ell_{\text{BCE}}\big(\mathcal{F}(\bar{\mathcal{X}}; \Theta_s),\ \mathcal{Y}\big) + \gamma_1\, \mathcal{L}_{\text{mix-KL}}\big(\bar{\mathcal{X}}_u, \tilde{\mathcal{Y}}_u; \Theta_s\big) + \gamma_2 \sum_{c} \mathcal{R}_c\big(\bar{\mathcal{X}}_u; \Theta_s\big). \qquad (5)$$

While we use the standard binary cross entropy loss on the labeled set $\bar{\mathcal{X}}$ without mixup, we make the student noised by performing mixup on the unlabeled set $\bar{\mathcal{X}}_u$. The second term, $\mathcal{L}_{\text{mix-KL}}$, uses a mixup loss similar to eqn. (2), with the key difference that the binary cross entropy is replaced with a KL-divergence, since the pseudo-labels are soft. Note that we perform confidence tempering only on the unlabeled set during student training. A detailed listing of our approach can be found in Algorithm 1.
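Putting the pieces together, the sketch below shows how one student update might be assembled from the components above, reusing the sharpening and confidence-tempering sketches from earlier in this section; the model names, loss weights and the default mixup alpha are hypothetical, and the exact weighting scheme is an assumption rather than the authors' released implementation.

```python
import torch
import torch.nn.functional as F

def student_step(student, teacher, x_lab, y_lab, x_unlab,
                 gamma_mix=1.0, gamma_ct=1.0, alpha=0.4, eps=1e-6):
    """One noised student update combining eqns. (1)-(5); weights are placeholders."""
    with torch.no_grad():
        pseudo = sharpen_multilabel(torch.sigmoid(teacher(x_unlab)))   # eqn. (4)

    # Supervised term: plain BCE on weakly augmented labeled data (no mixup).
    sup_loss = F.binary_cross_entropy_with_logits(student(x_lab), y_lab)

    # Noised term: mixup on unlabeled images and their soft pseudo-labels,
    # scored with a per-label Bernoulli KL-divergence since the targets are soft.
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    perm = torch.randperm(x_unlab.size(0))
    x_mix = lam * x_unlab + (1.0 - lam) * x_unlab[perm]
    y_mix = lam * pseudo + (1.0 - lam) * pseudo[perm]
    p = torch.sigmoid(student(x_mix)).clamp(eps, 1.0 - eps)
    q = y_mix.clamp(eps, 1.0 - eps)
    kl_loss = (q * torch.log(q / p) + (1.0 - q) * torch.log((1.0 - q) / (1.0 - p))).mean()

    # Confidence tempering applied only to the unlabeled predictions (eqn. (3) sketch).
    ct_loss = confidence_tempering(p)

    return sup_loss + gamma_mix * kl_loss + gamma_ct * ct_loss
```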

4 Results and Findings

Method W-AUC W-PRC W-AUC W-PRC W-AUC W-PRC
Baseline (W-Aug.) 0.724 0.478 0.814 0.615 0.831 0.670
W-Aug. + Mixup 0.733 0.502 0.819 0.640 0.842 0.684
W-Aug. + Mixup + CT 0.738 0.507 0.833 0.673 0.841 0.691
W-Aug. + Mixup + CT + ST(L) 0.741 0.542 0.839 0.670 0.838 0.688
W-Aug. + Mixup + CT + ST(U) 0.75 0.538 0.844 0.684 0.846 0.688
Table 2: Performance comparison between ResNet-18 models trained using the following methods: Weak Augmentation (W-Aug.), Mixup, Confidence Tempering (CT), and Self-Training with labeled (ST(L)) or additional unlabeled (ST(U)) data. The three column groups correspond to three few-shot labeled set sizes, increasing from left to right.

Model Design. Since the release of CheXpert [irvin2019chexpert], a plethora of approaches have been published for X-ray classification, including CheXNeXt [rajpurkar2018deep]. While most successful solutions use over-parameterized, deep models such as DenseNet-121 [pham2019interpreting], more recently, even shallow network architectures have been shown to produce comparable performance [bressem2020comparing]. Given our few-shot learning setup, we find ResNet-18 to be effective in avoiding overfitting without trading off performance [ge2018chest]. We refer to the case where we fine-tune ResNet-18 with only weak augmentation (W-Aug.) as the baseline solution. For the proposed approach, we create variants by ablating different components in Algorithm 1.
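For concreteness, a minimal sketch of how such an ImageNet-pretrained ResNet-18 could be adapted into a 5-way multi-label classifier is shown below; this reflects standard torchvision usage rather than the authors' exact model code.

```python
import torch.nn as nn
from torchvision import models

def build_resnet18_classifier(num_classes=5):
    """ImageNet-pretrained ResNet-18 with a multi-label head (one logit per condition)."""
    model = models.resnet18(pretrained=True)                   # initialize from ImageNet weights
    model.fc = nn.Linear(model.fc.in_features, num_classes)    # 5 disease logits
    return model  # train with BCE-with-logits; apply a per-class sigmoid at inference
```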

Training. All models in our study were implemented using PyTorch and trained for a fixed number of epochs using the following setup: a step-decayed learning rate, mini-batch training, and the Adam optimizer with weight decay and momentum. For the mixup parameter $\alpha$ in Eq. (1), we chose the best value for each setting, while a higher $\alpha$ works better for self-training. The loss weights in Eq. (5) and the tempering thresholds were chosen based on validation performance. Note, we varied the labeled set size $n$, with correspondingly larger unlabeled sets. We plan to release our codes after the review process. Evaluation Metrics. To evaluate performance, we use the widely-adopted metrics, namely the area under the ROC curve (AUC) and the precision-recall curve (PRC). Due to the inherent class imbalance, we used weighted averages of these metrics with class-specific weights, which we refer to as W-AUC and W-PRC, respectively.
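As a sketch of how such class-weighted metrics could be computed with scikit-learn (the exact weighting scheme is not spelled out here, so support-based class weights are an assumption):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

def weighted_metrics(y_true, y_score):
    """Class-weighted AUC (W-AUC) and PRC (W-PRC) for multi-label predictions.

    y_true: (N, C) binary matrix; y_score: (N, C) predicted probabilities.
    NOTE: weighting by class support is an assumed choice for the class-specific weights.
    """
    weights = y_true.sum(axis=0) / y_true.sum()    # per-class support weights
    aucs = [roc_auc_score(y_true[:, c], y_score[:, c]) for c in range(y_true.shape[1])]
    prcs = [average_precision_score(y_true[:, c], y_score[:, c]) for c in range(y_true.shape[1])]
    return float(np.dot(weights, aucs)), float(np.dot(weights, prcs))
```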

Figure 1: Class-specific AUC achieved using different approaches for two few-shot scenarios, (a) with a smaller and (b) with a larger labeled set. W-Aug.: weak augmentation, CT: confidence tempering, ST(L): self-training with labeled data, ST(U): self-training with additional unlabeled data.

4.1 Key Findings

Mixup leads to better models with few-shot data. As shown in Table 2, adding mixup regularization leads to better performance (in both metrics) over the baseline at all labeled set sizes considered. Though mixup helps avoid overfitting on prevalent classes, it is less effective in tackling class imbalance at the smallest labeled set sizes. For example, in Fig. 1(a), the AUC scores for Cardiomegaly and Consolidation are lower than the baseline. However, this improves with larger labeled sets (Fig. 1(b)).

Figure 2: Class-activation maps for two test cases: true positive and true negative.
Figure 3: Our ResNet-18 model trained with only a fraction of the labeled set matches the over-parameterized models trained on the full labeled data.

Confidence tempering provides significant gains. We also find that the CT regularization, when combined with mixup, produces crucial performance gains (see Table 2); for example, in the intermediate few-shot setting, W-AUC increases from 0.819 to 0.833 and W-PRC from 0.640 to 0.673. More importantly, CT improves the AUC for classes with low support without compromising on those with high support. This is evidenced by the improvements for the Cardiomegaly and Consolidation classes in Figure 1 over plain mixup, while still performing well on the more prevalent Pleural Effusion and Edema. From the saliency maps (generated using Grad-CAM [selvaraju2017grad]) for detecting different conditions, we find that the CT regularization leads to non-trivial probabilities even for negative findings; however, the evidence comes from irrelevant parts of the image (e.g., organ boundaries, background pixels), thereby avoiding spurious correlations. Self-training with few shots matches full-shot training. Finally, including the self-training strategy, either with only labeled data (ST(L)) or with additional unlabeled data (ST(U)), boosts the performance even further. From Table 2, the best-performing variants are those that include self-training. Surprisingly, using only a fraction of the labeled data, our approach outperforms a ResNet-18 trained on the full set (Fig. 3). Further, the best-performing ResNet-18 model, obtained with roughly 10% of the total labeled data, is comparable to the over-parameterized DenseNet-121 model trained on the full data, which clearly emphasizes the effectiveness of our approach.

References