Triplet Contrastive Learning for Brain Tumor Classification

08/08/2021 ∙ by Tian Yu Liu, et al. ∙ National University of Singapore

Brain tumors are a common and fatal form of cancer that affects both adults and children. The classification of brain tumors into different types is hence a crucial task, as it greatly influences the treatment that physicians will prescribe. In light of this, medical imaging techniques, especially those applying deep convolutional networks followed by a classification layer, have been developed to enable computer-aided classification of brain tumor types. In this paper, we present a novel approach that directly learns deep embeddings for brain tumor types, which can be used for downstream tasks such as classification. Along with triplet loss variants, our approach applies contrastive learning to perform unsupervised pre-training, combined with a rare-case data augmentation module, to effectively ameliorate the lack of data in brain tumor imaging analysis. We evaluate our method on an extensive brain tumor dataset consisting of 27 different tumor classes, 13 of which are defined as rare. Using a common encoder in all experiments, we compare our approach with a baseline classification-layer-based model, and the results demonstrate the effectiveness of our approach across all measured metrics.


1 Introduction

Brain tumors are among the most common forms of cancer [1]. They can be classified into various types [16], such as glioma and meningioma, and correctly classifying the type largely determines the treatment physicians will prescribe. Conventionally, brain tumor type classification is done by experienced physicians through manual inspection of MRI scans, which is tedious and prone to error. In the computer vision community, many medical imaging techniques have been developed to tackle this problem [6]. A general process involves segmentation of the tumor [10], followed by learning features for classification [1, 21, 18]. In our work, we focus on learning robust deep embeddings for brain tumor type classification given a segmented tumor image. To achieve this, we propose to leverage three techniques: contrastive learning for pre-training, data augmentation over rare cases, and triplet loss for learning efficient embeddings.

One of our key inspirations comes from recent progress in learning efficient embeddings for face recognition and retrieval. In [19], triplet loss is used to directly learn optimized embeddings for face images, usually of low dimension, instead of training an explicit classification layer. These embeddings can be used for tasks ranging from face recognition to clustering. Here we apply this technique to brain tumor type classification.

Unlike face recognition datasets such as [13], brain tumor MRI datasets pose the unique challenge of scarce labelled data [4], in terms of both ground truth segmentation masks and tumor diagnoses. Because such supervision is difficult to produce, unlabelled MRI scans are generally more readily available. We therefore exploit contrastive learning frameworks [7, 11], which can make use of unlabelled data for effective unsupervised pre-training, to address tumor type classification. In addition, we leverage extensive data augmentation [18] to ameliorate the scarcity of annotated data when training brain tumor type classification models. In particular, we employ a data augmentation module that generates labelled, augmented examples. However, it has been shown that generating new data through data augmentation can be error-prone [9]. We therefore only generate augmented data for rare tumor classes, which have very few labelled examples in the training dataset.

We integrate these three designs into a general deep 3D CNN-based encoder commonly used for processing brain tumor images [5], and highlight the advantage of triplet-loss-based deep embeddings over classical cross-entropy methods for classifying brain tumor images. Specifically, we first pre-train the model using a contrastive learning module adapted for MRI scans, leveraging the more readily available unlabelled data. Next, we artificially increase the size of the labelled dataset with a rare-case data augmentation module that generates new data for rare tumor classes. Finally, we train the model with triplet loss to learn efficient embeddings, which we then apply to the downstream brain tumor classification task.

2 Previous Work

2.0.1 Brain Tumor Classification

The classification of brain tumors into types such as glioma and meningioma is an important and active research area in the medical field. This problem is traditionally tackled through manual examination of MRI scans by physicians. Many medical image analysis works have addressed related tasks, such as tumor segmentation [10] and tumor classification [1, 21, 18]. Most of them adopt data preprocessing techniques and deep learning approaches [17]. For example, [3] and [2] address the binary classification problem of detecting whether an MRI image contains a brain tumor. The recent work [18] proposes a deep convolutional neural network based multi-grade brain tumor classification approach, which uses extensive data augmentation to generate new training data and relieve the data scarcity problem.

2.0.2 Learning Embeddings using Triplet Loss

Previous deep convolutional network approaches for brain tumor classification employ a classification layer trained on labelled tumor images [6]. For example, [15] trains a classification network using a loss function that combines interval loss and margin loss to increase the penalty for misclassification.

[19] argues that such an explicit classification approach is indirect and inefficient at generalizing beyond the training data.

To address these limitations of the classification layer in learning embeddings, [19] presents a comprehensive approach to face recognition that learns unified embeddings using triplet loss. Each input to the network is structured as a triplet containing an anchor example, a positive example (same class as the anchor), and a negative example (different class), and the triplet loss aims to minimize the embedding distance between the anchor and the positive while maximizing the distance between the anchor and the negative.

The work [12] further advances triplet loss approaches by showing that, contrary to the prevailing opinion at the time, triplet loss can yield state-of-the-art results. The authors propose two triplet selection strategies, Batch-Hard and Batch-All Triplet Loss, which further improve results.

2.0.3 Contrastive Learning

Contrastive learning is an effective technique for self-supervised learning [14]. Detailed approaches and conclusive results are presented by [7], which develops the SimCLR framework, and by [11], which develops Momentum Contrast (MoCo).

Through self-supervised learning, contrastive learning can greatly benefit existing training pipelines by enabling unlabelled data to be used for model pre-training. We find this especially useful in medical imaging. In our case, classifying brain tumors in MRI images, accurately labelled data in the form of diagnoses and ground truth segmentation masks are scarce. In this paper, we adapt the SimCLR approach to medical images and further refine it for our triplet-loss-based pipeline.

3 Approach

Our proposed approach consists of three main steps: 1) pre-training through contrastive learning, 2) generating new labelled data through an augmentation module, and 3) using triplet loss in the actual model training.

3.0.1 Loss

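We use the Batch-Hard Triplet Loss and Batch-All Triplet Loss as implemented by [12]. For $P$ distinct tumor classes and $K$ examples sampled from each class, the Batch-Hard Triplet Loss and Batch-All Triplet Loss are defined respectively as

$$\mathcal{L}_{\mathrm{BH}} = \sum_{i=1}^{P} \sum_{a=1}^{K} \Big[ m + \max_{p=1,\dots,K} \lVert f_a^i - f_p^i \rVert_2 - \min_{\substack{j=1,\dots,P,\; j \neq i \\ n=1,\dots,K}} \lVert f_a^i - f_n^j \rVert_2 \Big]_+ \tag{1}$$

$$\mathcal{L}_{\mathrm{BA}} = \sum_{i=1}^{P} \sum_{a=1}^{K} \sum_{\substack{p=1 \\ p \neq a}}^{K} \sum_{\substack{j=1 \\ j \neq i}}^{P} \sum_{n=1}^{K} \Big[ m + \lVert f_a^i - f_p^i \rVert_2 - \lVert f_a^i - f_n^j \rVert_2 \Big]_+ \tag{2}$$

where $f_a^i$ represents the embedding of the $a$-th example in the $i$-th class, $m$ represents the margin parameter, and $[x]_+ = \max(x, 0)$.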

In general, training with triplet loss encourages the network to learn embeddings in which examples with the same label are close together, while examples with different labels are further apart. Batch-Hard Triplet Loss refines this by selecting, for each anchor, only the hardest positive and the hardest negative within the batch; because these are the hardest triplets of a small subset of the data, they are moderate triplets, which [12] argues are best for learning. Batch-All Triplet Loss instead uses all possible triplet combinations within a single batch [12].
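To make Eqs. (1) and (2) concrete, here is a minimal NumPy sketch written for illustration, not taken from the authors' code. Embeddings are the rows of an array, labels is an integer array, and distances are plain Euclidean distances as in [12]. The function names are hypothetical. The Batch-All loss is returned as a plain sum to match Eq. (2); practical implementations often average over the triplets with non-zero loss instead.

```python
import numpy as np


def pairwise_distances(emb):
    """Euclidean distance matrix between all rows of emb (shape [N, d])."""
    sq = np.sum(emb ** 2, axis=1)
    d2 = sq[:, None] - 2.0 * emb @ emb.T + sq[None, :]
    return np.sqrt(np.maximum(d2, 0.0))  # clip tiny negatives from rounding


def batch_hard_triplet_loss(emb, labels, margin):
    """Eq. (1): each anchor uses its hardest positive and hardest negative."""
    dist = pairwise_distances(emb)
    same = labels[:, None] == labels[None, :]
    # Hardest positive: farthest sample with the same label (self has distance 0).
    hardest_pos = np.where(same, dist, -np.inf).max(axis=1)
    # Hardest negative: closest sample with a different label.
    hardest_neg = np.where(same, np.inf, dist).min(axis=1)
    return np.maximum(margin + hardest_pos - hardest_neg, 0.0).sum()


def batch_all_triplet_loss(emb, labels, margin):
    """Eq. (2): hinge summed over every valid (anchor, positive, negative)."""
    dist = pairwise_distances(emb)
    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)
    pos_mask = same & not_self  # [a, p]: same class, p != a
    neg_mask = ~same            # [a, n]: different class
    # loss[a, p, n] = m + d(a, p) - d(a, n)
    loss = margin + dist[:, :, None] - dist[:, None, :]
    valid = pos_mask[:, :, None] & neg_mask[:, None, :]
    return np.maximum(loss, 0.0)[valid].sum()
```

Both functions expect a batch in which each class appears more than once, such as the $P \times K$ batches assumed in the definitions above.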

3.0.2 Data Augmentation Module

To increase the amount of training data for rare cases, we propose a rare-case data augmentation module. For each rare-case training example, we apply a number of independent random augmentation sequences, each yielding one additional augmented example. A single augmentation sequence consists of a random rotation, a random flip (horizontal and vertical), Gaussian noise, and a random crop followed by resizing back to the original resolution.
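The augmentation sequence can be sketched in NumPy as follows. This illustration rests on stated assumptions and is not the authors' implementation. Rotations are restricted to multiples of 90 degrees, each flip is applied with probability 0.5, the crop is resized with nearest-neighbour sampling, and `noise_std`, `min_crop` and `n_copies` are hypothetical parameters whose values the paper does not report. Volumes are assumed to be shaped (depth, height, width), e.g. (12, 128, 128).

```python
import numpy as np


def augment_once(volume, rng, noise_std=0.05, min_crop=0.8):
    """Apply one random augmentation sequence to a (depth, height, width) volume."""
    _, h, w = volume.shape
    # Random in-plane rotation (assumption: multiples of 90 degrees only).
    out = np.rot90(volume, k=int(rng.integers(4)), axes=(1, 2))
    # Random horizontal and vertical flips, each with probability 0.5.
    if rng.random() < 0.5:
        out = out[:, :, ::-1]
    if rng.random() < 0.5:
        out = out[:, ::-1, :]
    # Additive Gaussian noise (noise_std is a hypothetical value).
    out = out + rng.normal(0.0, noise_std, size=out.shape)
    # Random crop keeping between min_crop and 100% of each side.
    oh, ow = out.shape[1:]
    ch = max(1, int(oh * rng.uniform(min_crop, 1.0)))
    cw = max(1, int(ow * rng.uniform(min_crop, 1.0)))
    top = int(rng.integers(0, oh - ch + 1))
    left = int(rng.integers(0, ow - cw + 1))
    crop = out[:, top:top + ch, left:left + cw]
    # Nearest-neighbour resize back to the original (height, width).
    rows = np.arange(h) * ch // h
    cols = np.arange(w) * cw // w
    return crop[:, rows[:, None], cols[None, :]]


def augment_rare_cases(volumes, labels, rare_classes, n_copies, rng):
    """Append n_copies augmented versions of every example from a rare class."""
    new_x, new_y = [], []
    for vol, y in zip(volumes, labels):
        if y in rare_classes:
            for _ in range(n_copies):
                new_x.append(augment_once(vol, rng))
                new_y.append(y)
    return list(volumes) + new_x, list(labels) + new_y
```

Only rare-class examples are multiplied, which mirrors the design choice above of limiting potentially error-prone synthetic data to the classes that need it most.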

3.0.3 Contrastive Learning

We use the contrastive learning approach to pre-train our models. We follow the SimCLR technique described by [7], but adapt it to medical images with three key differences.

Firstly, some of the main augmentation techniques used in SimCLR, such as color distortion, are unsuitable for MRI images, since these images do not contain standard RGB channels. We therefore use the data augmentation module described above for the random augmentation step.

Secondly, to effectively localize the target tumor sites and learn meaningful features associated with them, we generate a pseudo ground truth segmentation mask using a pre-trained tumor segmentation model. Our experiments show that, despite lacking fully accurate ground truth data, pre-training with our contrastive learning approach still brings significant improvements to the final results.

Lastly, in addition to the NT-Xent (Normalized Temperature-scaled Cross Entropy) loss described and used in [7] and [11], we separately use two additional loss functions for contrastive learning: the Batch-All and Batch-Hard Triplet Losses defined above. We found empirically, as shown in our experimental results below, that pre-training with these triplet losses generally yields better performance for final models that are also trained with triplet loss. We illustrate our process in Figure 1.

Figure 1: Contrastive pre-training process for a single target image in a minibatch. This process is repeated for all images in the minibatch.
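For reference, a minimal NumPy sketch of the NT-Xent loss as formulated in SimCLR [7] follows; it is not the authors' code, and the temperature of 0.5 is an illustrative assumption. Here `z1[i]` and `z2[i]` are the embeddings of two augmented views of the same image, and each view must pick out its partner among the other 2N - 1 views in the batch. One plausible way to use the triplet losses for pre-training instead (our assumption, not described in the paper) is to give both views of an image the same pseudo-label, so that each image in the batch acts as its own class.

```python
import numpy as np


def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent over N positive pairs (z1[i], z2[i]); z1 and z2 have shape [N, d]."""
    z = np.concatenate([z1, z2], axis=0)
    # Unit-normalize so that dot products are cosine similarities.
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    sim = z @ z.T / temperature
    n2 = z.shape[0]
    n = n2 // 2
    # Exclude each view's similarity with itself from the denominator.
    np.fill_diagonal(sim, -np.inf)
    # Row i's positive is the other view of the same image.
    pos_idx = np.concatenate([np.arange(n, n2), np.arange(0, n)])
    # Numerically stable log-sum-exp over each row.
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(n2), pos_idx] - log_denom
    return -log_prob_pos.mean()
```

With a batch size of 15 images, as used for pre-training in this work, each view is contrasted against 29 others, far fewer negatives than the large batches SimCLR relies on.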

4 Experiments

4.1 Dataset

We are unable to evaluate our approach on public datasets due to a lack of suitable annotations. For reference, the most relevant public dataset is [8], but it is unsuitable for our experiments as it only contains 3 classes: Glioma, Meningioma, and Pituitary Tumor.

Instead, we evaluate on a labelled dataset acquired by (Anonymous Company), which contains T2-weighted brain MRI scans from 27 different classes. The tumor types include those in [8], along with more diverse types such as Neurilemmoma. In total, we have 4,962 MRI scans, split into 70% training, 10% validation, and 20% testing. For evaluation on our proposed metrics, we also designate the classes with the fewest images as rare cases, corresponding to 13 of our 27 class labels (marked with * in Table 1).

Each MRI scan in our labelled dataset is resized to 128x128 pixels with a depth of 12. A labelled example also contains its tumor type and ground truth segmentation mask. Table 1 shows the statistics of the labelled dataset. In addition to this labelled dataset, we use a separate unlabelled dataset (lacking both ground truth tumor classes and segmentation masks) for evaluating our contrastive pre-training approach. It consists of around 22K randomly selected MRI tumor images with unknown labels. We use a separate pre-trained model to generate pseudo segmentation masks for these images.

Type  Train  Val  Test
0     662    94   189
1*    37     5    10
2     162    23   46
3     342    48   97
4     163    23   46
5     231    32   65
6     140    19   39
7*    32     4    8
8*    7      1    2
9     120    17   34
10*   38     5    10
11*   32     4    8
12*   32     4    9
13*   21     3    6
14    61     8    17
15*   8      1    2
16*   5      1    1
17    101    14   28
18*   35     5    10
19*   27     3    7
20*   29     3    7
21*   16     2    4
22    64     8    17
23    642    91   183
24    161    23   46
25    201    28   57
26    124    17   35
Total 3493   486  983
Table 1: Breakdown of the train, validation, and test splits, with each tumor type representing a distinct class.
(*) denotes classes defined as rare.
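As a quick consistency check on Table 1, the plain-Python sketch below recomputes the column totals, the realised split ratios, and the number of images in the rare classes. The dictionary is our transcription of the table and the function name is hypothetical; counting rare images this way is not necessarily the criterion the authors used to define rarity.

```python
# (train, val, test, is_rare) per tumor type, transcribed from Table 1.
TABLE_1 = {
    0: (662, 94, 189, False), 1: (37, 5, 10, True), 2: (162, 23, 46, False),
    3: (342, 48, 97, False), 4: (163, 23, 46, False), 5: (231, 32, 65, False),
    6: (140, 19, 39, False), 7: (32, 4, 8, True), 8: (7, 1, 2, True),
    9: (120, 17, 34, False), 10: (38, 5, 10, True), 11: (32, 4, 8, True),
    12: (32, 4, 9, True), 13: (21, 3, 6, True), 14: (61, 8, 17, False),
    15: (8, 1, 2, True), 16: (5, 1, 1, True), 17: (101, 14, 28, False),
    18: (35, 5, 10, True), 19: (27, 3, 7, True), 20: (29, 3, 7, True),
    21: (16, 2, 4, True), 22: (64, 8, 17, False), 23: (642, 91, 183, False),
    24: (161, 23, 46, False), 25: (201, 28, 57, False), 26: (124, 17, 35, False),
}


def split_summary(table):
    """Column totals, overall total, split fractions, rare-image count, rare-class count."""
    totals = [sum(row[k] for row in table.values()) for k in range(3)]
    grand = sum(totals)
    fractions = [t / grand for t in totals]
    rare_images = sum(sum(row[:3]) for row in table.values() if row[3])
    n_rare = sum(1 for row in table.values() if row[3])
    return totals, grand, fractions, rare_images, n_rare


print(split_summary(TABLE_1))
```

On these counts the realised split is about 70.4% / 9.8% / 19.8%, consistent with the stated 70/10/20 split, and the 13 rare classes together contain 444 of the 4,962 scans.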

4.2 Implementation

In all our experiments, we use a common convolutional encoder architecture with around 6.9M trainable parameters, followed by one dense layer of size 128 and a final dense embedding layer. The embedding size is set to 27 (the number of classes) for the cross-entropy models and empirically set to 6 for the triplet models. We apply softmax as the output function of the embedding layer for the cross-entropy models, and L2 normalization for the triplet loss models.
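The two output variants can be illustrated with the hypothetical head below. The ReLU activation, the absence of bias terms, and the weight shapes are assumptions made for the sketch, and `features` stands in for the encoder output. The point is that softmax yields a probability vector over the 27 classes, while L2 normalization places triplet embeddings on the unit hypersphere, so Euclidean distances between them always lie in [0, 2].

```python
import numpy as np


def embedding_head(features, w_hidden, w_out, mode):
    """Hypothetical head: dense(128) -> dense(embedding size) with softmax or L2 norm."""
    h = np.maximum(features @ w_hidden, 0.0)  # ReLU is an assumption
    out = h @ w_out
    if mode == "softmax":  # cross-entropy models: class probabilities
        out = np.exp(out - out.max(axis=1, keepdims=True))
        return out / out.sum(axis=1, keepdims=True)
    if mode == "l2":  # triplet models: unit-length embeddings
        norms = np.linalg.norm(out, axis=1, keepdims=True)
        return out / np.maximum(norms, 1e-12)
    raise ValueError(f"unknown mode: {mode}")
```

For example, `embedding_head(x, w1, w2, "l2")` with `w2` of shape (128, 6) gives the 6-dimensional triplet embedding described above.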

The works [7] and [11] observe that larger batch sizes and longer training produce better results for contrastive learning. However, due to memory limitations, we use a small batch size of 15 for contrastive learning. Hence, instead of the LARS optimizer, which is designed for large-batch training, we use the SGD optimizer for the self-supervised pre-training step. Contrary to the observations of [7] and [11], we find that training for more epochs does not necessarily improve results (Figures 2 and 3). Hence, we only pre-train models for up to 20 epochs in our experiments.

Figure 2: Contrastive learning pre-train: Effect of training time (in epochs) on Recall for various contrastive loss functions
Figure 3: Contrastive learning pre-train: Effect of training time (in epochs) on Recall for various contrastive loss functions

We train all pre-trained and non-pretrained models using mini-batches of 30 examples selected by stratified random sampling. Each final model is trained for 400 epochs, which we found sufficient for convergence. During evaluation, we combine our training and validation sets and classify each test image using K-Nearest Neighbors (KNN) on this combined set.
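The evaluation step can be sketched as a plain KNN majority vote over the combined training and validation embeddings. This is an illustration, not the authors' code: the value of k is not recoverable from the text here, so it is left as a parameter, and ties are broken in favour of the closest tied neighbour, which is our own choice.

```python
from collections import Counter

import numpy as np


def knn_predict(query_emb, ref_emb, ref_labels, k):
    """Predict each query's label by majority vote of its k nearest reference rows."""
    dists = np.linalg.norm(query_emb[:, None, :] - ref_emb[None, :, :], axis=2)
    nearest = np.argsort(dists, axis=1, kind="stable")[:, :k]
    labels = np.asarray(ref_labels).tolist()
    preds = []
    for row in nearest:
        neighbour_labels = [labels[j] for j in row]
        votes = Counter(neighbour_labels)
        top = max(votes.values())
        # Neighbours are ordered by distance, so this picks the closest neighbour
        # whose class received the maximum number of votes (tie-break assumption).
        preds.append(next(c for c in neighbour_labels if votes[c] == top))
    return np.array(preds)
```

In the paper's setting, `ref_emb` would hold the 3,979 training and validation embeddings (3,493 + 486) and `query_emb` the 983 test embeddings.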

Lastly, the performance of each model is measured using the following metrics: micro-averaged recall (Recall_micro), macro-averaged recall (Recall_macro), macro-averaged recall over rare classes (Recall_rare), and Rank-5 accuracy (as used by [12, 22]). Micro- and macro-averaged recall are computed as defined in [20].
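The recall metrics and Rank-5 accuracy can be computed as in the sketch below; function names are hypothetical. Micro-averaged recall pools all test images, which for single-label classification equals plain accuracy, whereas macro-averaged recall averages per-class recall so that every class counts equally. Rank-5 accuracy is implemented here, as an assumption, as the fraction of test images whose true label appears among the labels of their 5 nearest reference embeddings.

```python
import numpy as np


def recall_metrics(y_true, y_pred, rare_classes):
    """Return (micro recall, macro recall, macro recall over rare classes)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    # Micro-average: pool all test images (equals accuracy for single-label data).
    micro = float(np.mean(y_true == y_pred))
    # Per-class recall for every class present in the test labels.
    per_class = {
        c: float(np.mean(y_pred[y_true == c] == c))
        for c in np.unique(y_true).tolist()
    }
    macro = float(np.mean(list(per_class.values())))
    rare = float(np.mean([per_class[c] for c in rare_classes if c in per_class]))
    return micro, macro, rare


def rank_k_accuracy(query_emb, query_labels, ref_emb, ref_labels, k=5):
    """Fraction of queries whose label occurs among their k nearest references."""
    dists = np.linalg.norm(query_emb[:, None, :] - ref_emb[None, :, :], axis=2)
    top = np.argsort(dists, axis=1, kind="stable")[:, :k]
    hits = np.asarray(ref_labels)[top] == np.asarray(query_labels)[:, None]
    return float(hits.any(axis=1).mean())
```

Because micro recall is dominated by the large classes, while macro and rare-class recall weight every class equally, the three numbers can move in different directions, as the results below show.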

4.3 Results

Contrastive Loss  Augment  Training Loss   Recall_micro  Recall_macro  Recall_rare  Rank-5  Acc
-                 -        CrossEntropy    0.435         0.241         0.0588       0.561   0.440
-                 -        Interval [15]   0.458         0.233         0.022        0.590   0.463
NT-Xent           -        CrossEntropy    0.511         0.291         0.0849       0.610   0.505
-                 Yes      CrossEntropy    0.391         0.244         0.0901       0.649   0.398
NT-Xent           Yes      CrossEntropy    0.502         0.284         0.0821       0.574   0.501
BATriplet         -        CrossEntropy    0.433         0.245         0.0495       0.555   0.417
BHTriplet         -        CrossEntropy    0.498         0.269         0.103        0.583   0.502
-                 -        BHTriplet       0.414         0.260         0.101        0.667   -
-                 Yes      BHTriplet       0.491         0.338         0.168        0.705   -
NT-Xent           -        BHTriplet       0.177         0.0562        0.011        0.427   -
BATriplet         -        BHTriplet       0.465         0.304         0.142        0.681   -
BHTriplet         -        BHTriplet       0.465         0.341         0.203        0.696   -
BATriplet         Yes      BHTriplet       0.484         0.295         0.0978       0.704   -
BHTriplet         Yes      BHTriplet       0.478         0.299         0.118        0.695   -
Table 2: Evaluation of various methods on the unseen test set with a nearest neighbor classifier. Classification accuracy (Acc) is only reported for methods using an explicit classification layer.

To evaluate the effectiveness of contrastive pre-training, rare-case augmentation, and triplet loss training, we start from a baseline model trained with traditional cross-entropy loss. We incrementally add contrastive pre-training (using NT-Xent, Batch-All Triplet, or Batch-Hard Triplet loss) and rare-case data augmentation, and swap cross-entropy loss for Batch-Hard Triplet Loss. Our results are shown in Table 2. For all experiments, we select the best performing model based on validation performance on the sensitivity (recall) metric. To avoid ambiguity, we refer to our models using the following format:

ContrastivePretrain-Augment-Loss

Our results show that contrastive pre-training brings large performance improvements. For instance, BHTriplet-None-BHTriplet achieves around a 100% increase in Recall_rare (0.203 vs. 0.101) and a 30% increase in Recall_macro (0.341 vs. 0.260) compared to its non-pretrained counterpart None-None-BHTriplet. However, applying the right contrastive learning approach is important for good performance: in NTXent-None-BHTriplet, NT-Xent-based contrastive pre-training results in extremely poor performance for the triplet loss model. In general, the best performance is obtained when triplet models are pre-trained with triplet-based contrastive learning approaches, and when cross-entropy models are pre-trained with NT-Xent.
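The relative gains quoted above follow directly from Table 2; the short snippet below recomputes them for all four nearest-neighbor metrics, using the scores copied from the table.

```python
# Test-set scores copied from Table 2.
pretrained = {"micro": 0.465, "macro": 0.341, "rare": 0.203, "rank5": 0.696}  # BHTriplet-None-BHTriplet
baseline = {"micro": 0.414, "macro": 0.260, "rare": 0.101, "rank5": 0.667}    # None-None-BHTriplet


def relative_gain(new, old):
    """Relative improvement of new over old, e.g. 0.3 for a 30% increase."""
    return (new - old) / old


gains = {k: relative_gain(pretrained[k], baseline[k]) for k in baseline}
for metric, gain in gains.items():
    print(f"{metric}: {gain:+.0%}")
```

This prints roughly +12%, +31%, +101% and +4% for micro recall, macro recall, rare-class recall and Rank-5 accuracy, respectively, so the pre-training gain is concentrated in the class-balanced metrics.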

Our results further reveal that final models trained with triplet loss generally perform better on Recall_macro and Recall_rare, and significantly outperform cross-entropy models on Rank-5 accuracy. This highlights the effectiveness of triplet loss for learning more meaningful deep relative embeddings of MRI images. In terms of Recall_micro, cross-entropy generally retains an advantage. Since micro-averaged recall is dominated by the largest classes, we reason that cross-entropy models focus on, and perform better on, majority classes with sufficient training data, as opposed to rare classes.

When we include the rare-case data augmentation module, significantly stronger performance is attained across all metrics for the triplet loss model (None-Aug-BHTriplet vs. None-None-BHTriplet). However, when augmentation is combined with contrastive pre-training, apart from a slight gain in Recall_micro, we see an obvious decrease in performance (BHTriplet-None-BHTriplet vs. BHTriplet-Aug-BHTriplet). This is an interesting observation whose cause requires further exploration. Possibly, such models need more epochs to fully converge, or the two approaches are not as orthogonal as they intuitively seem.

Note that while nearest-neighbor search accuracy performs on par with classification accuracy in our experiments, such results can be highly dataset dependent, since search retrieval and classification are different problems. Hence, this part of the results is inconclusive regarding its ability to generalize to other similar datasets, and is included only for completeness.

5 Conclusion

We conclude from our results that training with triplet loss can effectively learn deep embeddings from a small, imbalanced brain tumor dataset. In addition, contrastive pre-training and/or a rare-case data augmentation module can significantly improve final results by ameliorating the lack of data in this domain. We also highlight that our approach learns efficient embeddings for brain tumor images rather than class predictions, and these embeddings can be applied more generally to other downstream tasks. Despite this, our results are comparable to, and in some cases outperform, more direct classification approaches. Furthermore, given the large improvement in Rank-5 accuracy, one possible medical application of our approach is image retrieval, in which the top-ranked labelled matches of a target image are retrieved from a database to assist physicians in classification. Also, since our choice of encoder was arbitrary, future work could improve performance by using suitable state-of-the-art encoders instead. Lastly, given the generality of our approach, we believe it can be applied to other areas of medical imaging with similarly structured datasets.

References

A Supplementary Material

A.1 Dimension of Embedding Layer

Figure A.1: We note that the performance of triplet loss models is highly dependent on the size of the embedding space. Here, we use Batch-Hard triplet loss with a fixed margin and plot recall against the embedding space dimension. Experiments for parameter selection in this section start from a set of weights pre-trained on a much larger supervised dataset, to ensure better convergence and more accurate parameter choices.

A.2 Triplet Loss Margin Size

Figure A.2: We investigate the impact of the margin hyper-parameter on two recall metrics (left and right panels). The results show that while Batch-Hard triplet loss performs poorly at lower margins, it significantly outperforms Batch-All in terms of recall at larger margins. This contradicts the finding of [12], where the Batch-Hard variant consistently outperforms Batch-All across all margins.