Semi-supervised Medical Image Classification with Global Latent Mixing

05/22/2020 ∙ by Prashnna Kumar Gyawali, et al. ∙ Rochester Institute of Technology

Computer-aided diagnosis via deep learning relies on large-scale annotated data sets, which can be costly when involving expert knowledge. Semi-supervised learning (SSL) mitigates this challenge by leveraging unlabeled data. One effective SSL approach is to regularize the local smoothness of neural functions via perturbations around single data points. In this work, we argue that regularizing the global smoothness of neural functions by filling the void in between data points can further improve SSL. We present a novel SSL approach that trains the neural network on linear mixing of labeled and unlabeled data, at both the input and latent space in order to regularize different portions of the network. We evaluated the presented model on two distinct medical image data sets for semi-supervised classification of thoracic disease and skin lesion, demonstrating its improved performance over SSL with local perturbations and SSL with global mixing but at the input space only. Our code is available at https://github.com/Prasanna1991/LatentMixing.


1 Introduction

Medical image analysis via deep learning has achieved strong performance when supervised with a large labeled data set. Collecting such data sets is however costly in the medical domain since it involves expert knowledge. Semi-supervised learning (SSL) mitigates this challenge by leveraging unlabeled data.

An important goal in SSL is to avoid over-fitting the network function to the small labeled data set. A common inductive bias to guide this is the assumption of smoothness or consistency of the network function, i.e., nearby points and points on the same manifold should receive the same label predictions. For instance, self-ensembling [6] penalizes inconsistent predictions on unlabeled data under local perturbations, and virtual adversarial training [8] maintains consistency by forcing the predictions of adversarially-perturbed inputs to match those of the original inputs.

By considering perturbations around single data points, these approaches regularize only the local smoothness of the network function in the vicinity of available data points: no constraint is imposed on the global behavior of the network function in between data points [7]. To better exploit the structure of unlabeled data, we consider the mixup strategy, recently proposed to train a deep network on linear mixings of pairs of input data and their corresponding labels [12]. By filling the void between input samples, this strategy regularizes the global smoothness of the function and was shown to improve the generalization of state-of-the-art neural architectures in both supervised [12] and semi-supervised learning [1]. The mixup strategy was recently extended to the latent space, showing improvement over mixing in the input space only, albeit in a supervised setting.

We argue that the mixup strategy, training a network on linear mixings of input data and their labels, can be interpreted as regularizing the network to approximate a linear interpolation function in between data points. The performance gain brought by mixing in the latent space is therefore partly owed to relaxing this linearity constraint to only the portion of the network between the selected latent space and the output space. We also hypothesize that, since high-level representations in deep networks encode important information for discriminative tasks, mixing in the latent space may provide novel training signals for SSL.

Therefore, we propose to extend this regularization, i.e., regularizing different portions of the network between the latent space and output space, for SSL and demonstrate its first application in medical image classification. In this approach, we perform linear mixing of pairs of labeled and unlabeled data – both in the input and latent space – along with their corresponding labels: for the latter, the label is guessed and continuously updated from an average of predictions of augmented samples for each unlabeled data point. We evaluate the presented SSL model on two distinct medical image classification tasks: multi-label classification of thoracic disease using Chexpert lung X-ray images [5], and skin disease classification using Skin Lesion images [3, 10]. We compare the performance of the presented method with both a supervised baseline, and several SSL methods including mixup at the input space [12], standard self-ensembling in the input space [6], and recently-introduced self-ensembling at the latent space [4]. We further provide ablation studies and analyze the effect of function smoothing achieved by the presented method.

2 Related Work

2.0.1 SSL in Medical Image Analysis:

Many recent semi-supervised works in medical image analysis have focused on explicitly regularizing the local smoothness of the neural function [2, 9, 4]. For instance, in [2], a siamese architecture for both labeled and unlabeled data points was proposed to encourage consistent segmentation under a given class of transformations. In [9], ensemble diversity was enforced with the use of adversarial samples to improve semi-supervised semantic image segmentation. In [4], the disentangled stochastic latent space was learned to improve self-ensembling for semi-supervised classification of chest X-ray images. In these works, each data point was subjected to local perturbations, e.g., elastic deformations [2], virtual adversarial direction [9], or sampling from latent posterior distributions [4], for local smoothness regularization.

In [7], the idea of promoting global smoothness in SSL was explored by constructing a teacher graph network. Similar approaches that exploit the global smoothness of neural functions, however, have not been studied for medical images.

2.0.2 Regularization with the Mixup Strategy:

The mixup strategy was first presented in [12] to improve the generalization of supervised models by mixing data pairs at the input space. It was recently extended to a semi-supervised setting in which the mixing involves both labeled and unlabeled data points [1]. In the meantime, a similar idea was extended to the mixing of hidden representations [11], demonstrating improvement over mixing at the input space, although only in supervised learning.

To our knowledge, this is the first semi-supervised classification network that employs the mixup strategy in the latent space, and the first time this type of approach has been applied to semi-supervised medical image classification.

Figure 1: Schematic diagram of the presented SSL method. During training, we continuously guess labels for the unlabeled data points (left) and then perform SSL via mixing at the input and latent space (bottom right). On the top right, we demonstrate the layers in the deep network where latent representations can be mixed.

3 Methodology

We consider a set of labeled training examples $X_L$ with corresponding labels $Y_L$, and a set of unlabeled training samples $X_U$. We aim to learn the parameters $\theta$ of the mapping function $f_\theta$, approximated by a deep neural network. Over the course of training, we first guess and continuously update the labels for the unlabeled data points (section 3.1). We then perform linear mixing between labeled and unlabeled data points, both in the input and latent space, along with their corresponding actual or guessed labels (section 3.2). Finally, the SSL model is trained on the mixed data using different losses depending on whether each mixed data point is closer to labeled or unlabeled data (section 3.3). Fig. 1 summarizes the key components of this semi-supervised learning process.

3.1 Guessing Labels

We guess the labels for an unlabeled data batch $x_u$ by generating $K$ augmented copies $\hat{x}_{u,m}$, $m = 1, \dots, K$, and computing the average of the model's predictions:

$\bar{y}_u = \frac{1}{K} \sum_{m=1}^{K} f_\theta(\hat{x}_{u,m})$   (1)

Label guessing in this manner implicitly works as a consistency regularization, as the input transformations are assumed to leave the class semantics unaffected. The guessed labels are continuously updated as the neural function changes over the course of training.
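As a concrete illustration, below is a minimal PyTorch-style sketch of this label-guessing step; `model` and `augment` are hypothetical placeholders rather than the authors' released implementation, and a softmax output is assumed (a sigmoid would replace it in a multi-label setting such as Chexpert).

```python
# Sketch of label guessing (Eq. 1): average predictions over K augmented copies.
import torch

def guess_labels(model, x_unlabeled, augment, K=2):
    """Return soft guessed labels for a batch of unlabeled images."""
    with torch.no_grad():
        probs = [torch.softmax(model(augment(x_unlabeled)), dim=1) for _ in range(K)]
    # Averaging across augmentations implicitly enforces consistency.
    return torch.stack(probs, dim=0).mean(dim=0)  # shape: (batch, num_classes)
```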

3.2 Input and Latent Mixup

Since the mapping function is approximated by a deep neural network, we can decompose it as $f_\theta = f_{d,k} \circ f_{e,k}$, where $f_{e,k}$ represents the part of the network that encodes the input data into a latent representation $h$ at layer $k$, and $f_{d,k}$ denotes the part of the network that decodes this latent representation into the output $\hat{y}$. Inspired by [11], we determine a set of eligible layers $S$ in the neural network, from which we randomly select a layer $k$ and apply mixup in that layer (schematics in top-right; Fig. 1). For each batch, we combine and shuffle labeled and unlabeled data points to obtain a pair of random mini-batches $(x_1, y_1)$ and $(x_2, y_2)$. We pass these pairs through $f_{e,k}$ to obtain the latent pairs $(h_1, y_1)$ and $(h_2, y_2)$, and then perform mixup at this latent layer to produce the mixed mini-batch $(\tilde{h}, \tilde{y})$ as:

$\tilde{h} = \lambda' h_1 + (1 - \lambda') h_2, \quad \tilde{y} = \lambda' y_1 + (1 - \lambda') y_2, \quad \lambda \sim \mathrm{Beta}(\alpha, \alpha)$   (2)

where $\alpha$ is the positive shape parameter of the Beta distribution, treated as a hyperparameter in this work. Because the mixing can occur between labeled and unlabeled data, it is important to ensure that the mixed data fairly represent the distribution of both labeled and unlabeled data. Furthermore, as will be described in section 3.3, different losses are used to reflect the different reliability of actual and guessed labels. It is thus also important to know whether each mixed data point is closer to labeled or unlabeled data. To do so, we use $\lambda' = \max(\lambda, 1 - \lambda)$ instead of $\lambda$ in equation (2), which ensures that $(\tilde{h}, \tilde{y})$ is always closer to $(h_1, y_1)$ than to $(h_2, y_2)$, allowing us to rely on the origin of $(h_1, y_1)$ to determine which loss to apply to the mixed data point.

Depending upon $S$, we obtain different mixup strategies. For example, when $S = \{0\}$ (with layer 0 denoting the input space), we mix only at the input space; when $S = \{0, 1\}$, we mix at the input space and latent layer 1; and when $S = \{1\}$, we mix only at latent layer 1.
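The following sketch illustrates this procedure under the assumption of two hypothetical helpers, `forward_to_layer` and `forward_from_layer`, that split the network into $f_{e,k}$ and $f_{d,k}$ at layer $k$ (layer 0 being the input space); it is an illustration of the strategy, not the released implementation.

```python
# Sketch of mixup at a randomly selected eligible layer (Eq. 2).
import random
import numpy as np

def mixup_at_random_layer(model, x1, y1, x2, y2, eligible_layers=(0, 2, 4), alpha=2.0):
    k = random.choice(eligible_layers)           # sample a layer k from S
    h1 = model.forward_to_layer(x1, k)           # f_{e,k}(x1); identity when k == 0
    h2 = model.forward_to_layer(x2, k)           # f_{e,k}(x2)
    lam = np.random.beta(alpha, alpha)           # lambda ~ Beta(alpha, alpha)
    lam = max(lam, 1.0 - lam)                    # lambda' keeps the mix closer to (h1, y1)
    h_mix = lam * h1 + (1.0 - lam) * h2          # mixed latent representation
    y_mix = lam * y1 + (1.0 - lam) * y2          # mixed actual / guessed labels
    logits = model.forward_from_layer(h_mix, k)  # f_{d,k}(h_mix)
    return logits, y_mix
```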

3.3 Supervised and Unsupervised Loss

To treat the actual and guessed labels differently, because the latter are less reliable, we use different losses for mixed data points that are closer to labeled versus unlabeled data. For the mixed data points $\tilde{\mathcal{B}}_l$ in a batch that are closer to labeled data, the loss term is the cross-entropy loss:

$\mathcal{L}_s = -\frac{1}{|\tilde{\mathcal{B}}_l|} \sum_{(\tilde{h}, \tilde{y}) \in \tilde{\mathcal{B}}_l} \tilde{y}^{\top} \log f_{d,k}(\tilde{h})$   (3)

For the mixed data points $\tilde{\mathcal{B}}_u$ that are closer to unlabeled data, the loss function is defined as a squared $L_2$ loss, as it is considered to be less sensitive to incorrect predictions:

$\mathcal{L}_u = \frac{1}{|\tilde{\mathcal{B}}_u|} \sum_{(\tilde{h}, \tilde{y}) \in \tilde{\mathcal{B}}_u} \| \tilde{y} - f_{d,k}(\tilde{h}) \|_2^2$   (4)

After obtaining the mixed representations, the network is optimized by minimizing the sum of these two losses:

$\mathcal{L} = \mathcal{L}_s + w \, \mathcal{L}_u$   (5)

where $w$ is the weight of the unsupervised loss.
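A possible implementation of this combined objective is sketched below, assuming the mixed batch has already been split into a labeled-closer part and an unlabeled-closer part; the soft-target cross-entropy form is an assumption consistent with the mixed (soft) labels.

```python
# Sketch of the loss in Eqs. (3)-(5).
import torch
import torch.nn.functional as F

def ssl_loss(logits_labeled, targets_labeled, logits_unlabeled, targets_unlabeled, w=75.0):
    # Supervised term (Eq. 3): cross-entropy against mixed (soft) labels.
    loss_sup = -(targets_labeled * F.log_softmax(logits_labeled, dim=1)).sum(dim=1).mean()
    # Unsupervised term (Eq. 4): squared L2 between predicted probabilities and guessed labels.
    loss_unsup = F.mse_loss(torch.softmax(logits_unlabeled, dim=1), targets_unlabeled)
    # Total objective (Eq. 5): weighted sum with unsupervised weight w.
    return loss_sup + w * loss_unsup
```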

4 Experiments

We first test the effectiveness of the presented SSL approach on two distinct benchmark data sets for medical image classification, in comparison to a supervised baseline and alternative SSL models. We then analyze the effect of mixing at different latent layers, and perform ablation studies to assess the impact of different hyperparameters and of the depth of latent mixing on the presented method. Finally, we discuss the effect of function smoothing achieved by the presented SSL strategy.

4.1 Data sets

We evaluate the presented model on two open-source, large-scale medical data sets: Chexpert [5] and ISIC 2018 Skin Lesion Analysis [3, 10].

4.1.1 Chexpert X-ray image classification:

Chexpert comprises 224,316 chest radiographs from more than 60,000 patients, with labels for 14 pathology categories. For pre-processing, we removed all uncertain and lateral-view samples from the data set and resized the images to 128×128. To ensure a fair comparison, we used the publicly available data splits for the labeled training set (ranging from 100 to 500 samples), unlabeled set, validation set, and test set [4]. For data augmentation, we rotated each image within a range of (-10, 10) degrees and shifted it horizontally and vertically by up to 0.1 of the image dimension.
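One way to express this augmentation with torchvision is sketched below; it reflects the stated ranges but is not necessarily the authors' exact pipeline.

```python
from torchvision import transforms

# Resize to 128x128, then apply random rotation within (-10, 10) degrees and
# horizontal/vertical shifts of up to 0.1 of the image dimensions.
augment = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    transforms.ToTensor(),
])
```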

4.1.2 Skin image classification:

The ISIC 2018 skin data set comprises 10,015 dermoscopic images with labels for seven disease categories. Three labeled training sets (350, 600, and 1200 samples) were created with class balance taken into account. The same resizing and data augmentation strategies as for the X-ray images were applied here.

4.2 Implementation details

In our experiments, we use the AlexNet-inspired network from [4] to closely match their model implementation and training procedure. The network consists of five convolution blocks followed by three fully-connected layers. All models were trained for up to 256 epochs with a learning rate of 1e-4, decayed by a factor of 10 at the 50th and 125th epochs. For label guessing, we used $K = 2$ augmented copies of the unlabeled data. For Chexpert, unless mentioned otherwise, we used a set of eligible layers $S = \{0, 2, 4\}$, mixing parameter $\alpha = 1.0$ for input mixup and $\alpha = 2.0$ for latent mixup, and weight $w = 75$ for the unsupervised loss. For the skin data set, we used $S = \{0, 1\}$, $\alpha = 1.0$ for both input and latent mixup, and $w = 50$. We used a separately held-out validation set to select the best model over the course of training, and report results on the test set. The code used in the experiments will be made publicly available.
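The stated schedule can be written as follows; the placeholder model, the choice of the Adam optimizer, and the commented-out training routine are assumptions for illustration, since only the learning rate, decay points, and epoch budget are specified above.

```python
import torch

model = torch.nn.Linear(128 * 128, 14)  # placeholder for the AlexNet-inspired network
# Learning rate 1e-4, decayed by a factor of 10 at epochs 50 and 125, up to 256 epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # optimizer choice assumed
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 125], gamma=0.1)

for epoch in range(256):
    # train_one_epoch(model, optimizer)  # hypothetical per-epoch training routine
    scheduler.step()
```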

Model                | Chexpert ($N_l$)                             | Skin ($N_l$)
                     | 100    | 200    | 300    | 400    | 500    | 350    | 600    | 1200
Supervised baseline  | 0.5576 | 0.6166 | 0.6208 | 0.6343 | 0.6353 | 0.7707 | 0.7991 | 0.8538
Input Mixup          | 0.6491 | 0.6627 | 0.6731 | 0.6779 | 0.6823 | 0.8504 | 0.8609 | 0.9040
Latent Mixup         | 0.6523 | 0.6632 | 0.6747 | 0.6795 | 0.6836 | 0.8536 | 0.8736 | 0.9036
Input + Latent Mixup | 0.6512 | 0.6641 | 0.6739 | 0.6796 | 0.6847 | 0.8666 | 0.8768 | 0.9073
Table 1: Mean AUROC over the 14 categories in the Chexpert data and the seven categories in the skin data, for different numbers of labeled samples $N_l$. The reported values are averages over five random-seed runs.
Model                                      | Chexpert ($N_l$)
                                           | 100    | 200    | 300    | 400    | 500
Image-space self-ensembling (noise)        | 0.6012 | 0.6277 | 0.6444 | 0.6550 | 0.6626
Image-space self-ensembling (augmentation) | 0.6089 | 0.6301 | 0.6423 | 0.6530 | 0.6617
Latent-space self-ensembling               | 0.6200 | 0.6386 | 0.6484 | 0.6637 | 0.6697
Input + Latent Mixup (ours)                | 0.6512 | 0.6641 | 0.6739 | 0.6796 | 0.6847
Table 2: Mean AUROC for the classification of 14 categories in the Chexpert data. For the presented method, the average over five random-seed runs is reported; for the other methods, the best results based on [4] are reported.

4.3 Results

4.3.1 Comparison studies:

In both data sets, we first evaluate the SSL performance of the presented model in comparison with two baselines: a fully-supervised baseline where we train the network with a supervised cross-entropy loss without mixing, and input mixup where SSL is performed with mixing at the input space only. The results are presented in Table 1. For the presented approach, we present two versions: mixing only at the latent space (latent mixup SSL), and combining both input and latent mixing (input + latent mixup SSL). As shown, mixing in the latent space in general improved the SSL performance over the baseline methods. Among the alternatives involving latent mixup, combined input and latent mixing yielded the best performance in three out of five cases in the Chexpert data set, and in all cases in the skin dataset.

Using the Chexpert data set, we further compared the presented model with existing SSL methods that focus on regularizing the local smoothness of the network function via perturbations around single data points: self-ensembling at the input space [6] using Gaussian noise perturbations (std = 0.15; image-space self-ensembling (noise)) or augmentation with random translation and rotation (image-space self-ensembling (augmentation)), and self-ensembling in the disentangled latent space (latent-space self-ensembling) [4]. The results, presented in Table 2, showed a clear improvement of the presented method, supporting the advantage of regularizing the global in addition to the local smoothness of neural functions.

Ablation  | Latent mixup   | Input + Latent mixup
Presented | 0.6747 ± 0.23  | 0.6739 ± 0.20
Noise     | 0.6508 ± 0.13  | 0.6512 ± 0.06
α = 1.0   | 0.6736 ± 0.17  | 0.6743 ± 0.11
          | 0.6722 ± 0.10  | 0.6719 ± 0.20
Table 3: Effect of hyperparameters (left, $N_l$ = 300) and of the latent depth for mixing (right).

4.3.2 Ablation studies:

We study the effect of different hyperparameters and elements of the presented SSL method, using a labeled set of 300 samples. The results are shown in Table 3 (left). While each had some effect on model performance, the most notable difference came from the data augmentation strategy used in the presented SSL method: replacing the presented data augmentation with image-level noise notably reduced the model performance, although still to a level higher than the ensembling baselines presented in Table 2.

In Table 3 (right), we show how the model performance was affected by the depth of the latent space at which the mixing was performed, in comparison to a fixed baseline (green dashed) of mixing at the input space only. As shown, mixing at the deeper layers of the network appeared to be more beneficial in general. This suggests that the linearity constraint, given its limited function capacity, may be more appropriately applied to the later portion of a deep neural network. It may also suggest that, since higher-level representations are more task-related, mixing in such a space could help generalization.

4.4 The effect of function smoothing

Finally, we explore the effect of function smoothing brought by the presented SSL method. Starting with a two-moon toy data set, we observed in Fig. 2 that mixing in the latent space increases the smoothness of the decision boundary in comparison to mixing at the input space only, an observation similar to [11] for supervised learning. In addition, it also provided a broader range of uncertainty (broader region of low confidence) compared to mixing in input space only.

Figure 2: Decision boundary of SSL learning on two-moon toy data, where yellow dots represent the labeled data and the rest are unlabeled data.
Figure 3: Reliability diagrams of the networks for classifying two class labels from X-ray images, trained with $N_l$ = 300 labeled data points. Perfect calibration is indicated by the diagonal line representing the identity function.

While it is not feasible to visualize the decision boundary of the deep neural network in the presented medical image classification tasks, we instead investigated whether a similarly smoothed confidence measure, as observed in the toy data, emerges. To do so, we consider the calibration of the model via reliability diagrams. Fig. 3 shows examples of the network classifying two class labels: in general, the mixup strategy improves the calibration of the network compared to the supervised baseline, while mixing at the latent space tends to further marginally improve the calibration compared to mixing at the input space alone.
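For reference, a reliability diagram of this kind can be computed by binning the predicted probabilities and comparing the mean confidence with the empirical accuracy in each bin; the generic sketch below follows this recipe and is not the authors' plotting code.

```python
# Compute per-bin confidence and accuracy for a reliability diagram.
import numpy as np

def reliability_bins(probs, labels, n_bins=10):
    """probs: predicted probability of the positive class; labels: binary 0/1 targets."""
    bin_ids = np.minimum((probs * n_bins).astype(int), n_bins - 1)  # assign each prediction to a bin
    confidences, accuracies = [], []
    for b in range(n_bins):
        mask = bin_ids == b
        if mask.any():
            confidences.append(probs[mask].mean())   # mean predicted confidence in the bin
            accuracies.append(labels[mask].mean())   # empirical accuracy (positive rate) in the bin
    # Plotting accuracy against confidence and comparing with the identity line
    # yields the reliability diagram; deviations indicate miscalibration.
    return np.array(confidences), np.array(accuracies)
```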

5 Conclusion

We presented a novel semi-supervised learning method that regularizes the global smoothness of neural functions under the combination of input and latent mixing of labeled and unlabeled data. The evaluation on public chest X-ray data and skin disease data showed that the presented method improved the classification performance over SSL focusing on local smoothness of neural functions, as well as SSL regularizing global smoothness of the entire network between the input and output space. In future work, we are interested in extending the presented method for semi-supervised medical image segmentation.

Acknowledgement. This work is supported by NSF CAREER ACI-1350374 and NIH NHLBI R15HL140500

References

  • [1] D. Berthelot, N. Carlini, I. Goodfellow, N. Papernot, A. Oliver, and C. A. Raffel (2019) Mixmatch: a holistic approach to semi-supervised learning. In Advances in Neural Information Processing Systems, pp. 5050–5060. Cited by: §1, §2.0.2.
  • [2] G. Bortsova, F. Dubost, L. Hogeweg, I. Katramados, and M. de Bruijne (2019) Semi-supervised medical image segmentation via learning consistency under transformations. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 810–818. Cited by: §2.0.1.
  • [3] N. Codella, V. Rotemberg, P. Tschandl, M. E. Celebi, S. Dusza, D. Gutman, B. Helba, A. Kalloo, K. Liopyris, M. Marchetti, et al. (2019) Skin lesion analysis toward melanoma detection 2018: a challenge hosted by the international skin imaging collaboration (isic). arXiv preprint arXiv:1902.03368. Cited by: §1, §4.1.
  • [4] P. K. Gyawali, Z. Li, S. Ghimire, and L. Wang (2019) Semi-supervised learning by disentangling and self-ensembling over stochastic latent space. In International Conference on Medical Image Computing and Computer-Assisted Intervention, Cited by: §1, §2.0.1, §4.1.1, §4.2, §4.3.1, Table 2.
  • [5] J. Irvin, P. Rajpurkar, M. Ko, Y. Yu, S. Ciurea-Ilcus, C. Chute, H. Marklund, B. Haghgoo, R. Ball, K. Shpanskaya, et al. (2019) Chexpert: a large chest radiograph dataset with uncertainty labels and expert comparison. In AAAI, Cited by: §1, §4.1.
  • [6] S. Laine and T. Aila (2017) Temporal ensembling for semi-supervised learning. In ICLR, Cited by: §1, §1, §4.3.1.
  • [7] Y. Luo, J. Zhu, M. Li, Y. Ren, and B. Zhang (2018) Smooth neighbors on teacher graphs for semi-supervised learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 8896–8905. Cited by: §1, §2.0.1.
  • [8] T. Miyato, S. Maeda, M. Koyama, and S. Ishii (2018) Virtual adversarial training: a regularization method for supervised and semi-supervised learning. IEEE transactions on pattern analysis and machine intelligence 41 (8), pp. 1979–1993. Cited by: §1.
  • [9] J. Peng, G. Estrada, M. Pedersoli, and C. Desrosiers (2020) Deep co-training for semi-supervised image segmentation. Pattern Recognition, pp. 107269. Cited by: §2.0.1.
  • [10] P. Tschandl, C. Rosendahl, and H. Kittler (2018) The ham10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. Scientific data 5, pp. 180161. Cited by: §1, §4.1.
  • [11] V. Verma, A. Lamb, C. Beckham, A. Najafi, I. Mitliagkas, D. Lopez-Paz, and Y. Bengio (2019) Manifold mixup: better representations by interpolating hidden states. In International Conference on Machine Learning, pp. 6438–6447. Cited by: §2.0.2, §3.2, §4.4.
  • [12] H. Zhang, M. Cisse, Y. N. Dauphin, and D. Lopez-Paz (2017) Mixup: beyond empirical risk minimization. arXiv preprint arXiv:1710.09412. Cited by: §1, §1, §2.0.2.