
Few-shot segmentation of medical images based on meta-learning with implicit gradients

Classical supervised methods commonly require an abundant number of training samples and are unable to generalize to unseen datasets. As a result, the broader application of any trained model is very limited in clinical settings. Few-shot approaches, however, can minimize the need for enormous numbers of reliable ground-truth labels, which are both labor-intensive and expensive to obtain. To this end, we propose to exploit an optimization-based implicit model-agnostic meta-learning (iMAML) algorithm in a few-shot setting for medical image segmentation. Our approach can leverage the learned weights from a diverse set of training samples and can be deployed on a new, unseen dataset. We show that, unlike classical few-shot learning approaches, our method has improved generalization capability. To our knowledge, this is the first work that exploits iMAML for medical image segmentation. Our quantitative results on publicly available skin and polyp datasets show that the proposed method outperforms the naive supervised baseline model and two recent few-shot segmentation approaches by large margins.


1 Introduction

Automated lesion segmentation can help in accurate quantification, precise surgical removal or treatment, and planning of therapeutic procedures in the clinic. Unlike manual processes, which are usually subjective and sub-optimal, automatic methods can provide a more objective analysis of the lesions and their risks. There has been significant progress in the application of deep learning (DL)-based methods for the semantic segmentation of clinically relevant anomalies in medical imaging data. However, the two main hindrances in transferring the success of a trained model from the lab to real-world clinical settings are data scarcity, i.e., the lack of high-quality annotated training data, and data mismatch, i.e., the inability of the trained model to generalize well to clinical data during deployment. Data mismatch can arise due to population shift (varying demographics), acquisition shift (varying devices and protocols), prevalence shift (varying environmental factors), and selection bias (varying inclusion criteria for a study) (castro2020causality).

The state-of-the-art DL models require a large amount of high-quality, diverse data with pixel-wise ground-truth masks, which are difficult to generate and to obtain freely in the medical imaging domain. With the available data, it is still possible to build a model by leveraging semi-supervised or few-shot learning methods (feyjie2020semi), but such a model may not cover all lesion categories or data from multiple sources, for example, rare disease cases, patient variability, and multi-center data. Therefore, it is challenging to design a scalable system for clinical deployment. Possible solutions to the data mismatch problem are domain adaptation (ganin2016domain) and domain generalization (li2017deeper). Domain adaptation utilizes a labeled source training dataset and an unlabeled target (test) dataset to develop a classifier that can be adapted to the target domain configuration. Domain generalization, on the other hand, capitalizes on multiple source training datasets to design a classifier that generalizes well on unseen target (test) datasets.

To mitigate the problems of data scarcity and domain generalization, meta-learning under few-shot settings has emerged as a potential solution (ravi2016optimization; finn2017model). Meta-learning is the notion of learning to learn by leveraging prior knowledge from various tasks (thrun2012learning). It has been popular in few-shot image classification (ali2020additive; mahajan2020meta) and, more recently, in few-shot image segmentation. Few-shot learning uses a few annotated examples (support set) to make predictions on unlabeled examples (query set). It has been mostly explored in natural image segmentation (zhang2020sg; zhang2019canet) and is now also gaining attention in medical image segmentation (khandelwal2020domain; feyjie2020semi; rutter2019convolutional; liu2020shape; zhang2021domain; xiao2021prior). Recent work by (feyjie2020semi) used a semi-supervised few-shot learning approach to perform skin lesion segmentation by feeding the learner with unlabeled surrogate tasks. Roy et al. (roy2020squeeze) applied a few-shot technique with a squeeze-and-excite block architecture to perform volumetric segmentation of multiple organs in medical images. In the work proposed by Ouyang et al. (ouyang2020self), few-shot segmentation with a self-supervised method was used to eliminate the need for annotated medical images; they used an adaptive local pooling module in conjunction with prototypical networks to perform segmentation. Despite showing promising results within a few-shot setting, these methods have not demonstrated domain generalization capacity. Furthermore, they do not account for the fact that, during deployment, prior information about the test data is generally unavailable.

A recent study also suggests that supervised transfer learning with fine-tuning handles data mismatch better than semi-supervised methods (oliver2018realistic). Thus, to overcome the shortcomings of few-shot learning in this regard, meta-learning has been adopted for domain generalization. Recent work by (dou2019domain) uses the gradient-based meta-learning algorithm MAML, where the idea is to operate in a semantic feature space and learn semantically invariant features across training domains. They evaluated their method on brain Magnetic Resonance Imaging (MRI) images from different datasets that inherit domain shifts and showed consistent results across all the datasets. However, the approach has not been tested under few-shot settings, and the training and test sets contain instances from the same anatomy. Moreover, the Model Agnostic Meta Learning (MAML) algorithm has caveats related to computational and memory efficiency and becomes difficult to scale when many optimization steps are required (rajeswaran2019meta). Thus, for quick and better optimization during meta-learning, we use the Implicit Model Agnostic Meta Learning (iMAML) algorithm (rajeswaran2019meta).

This work addresses the generalization problem of supervised learning methods on unseen datasets by exploiting a bi-level optimization procedure in a meta-learning framework under a few-shot setting. To this end, we are the first to explore the efficacy of the iMAML algorithm for medical image segmentation. Our contributions include: (i) incorporation of the attention U-Net (oktay2018attention) mechanism for inner optimization of the weights using segmentation tasks on two different datasets during episodic meta-training, (ii) use of an analytical solution (conjugate gradient) for computing meta-gradients to achieve optimized weights, and (iii) a comprehensive analysis of the efficacy of the method on publicly available skin and polyp datasets in different unseen settings.

2 Methodology

Dataset # of Images Input Size Imaging Type
Kvasir-SEG (jha2020kvasir) 1000 Variable Colonoscopy
CVC-ClinicDB (bernal2015wm) 612 Colonoscopy
ISIC-2018 (codella2019skin; tschandl2018ham10000) 2596 Dermoscopy
PH2 (mendoncya2013dermoscopic) 200 Dermoscopy
Table 1: Publicly available medical imaging datasets used in our experiments.

2.1 Dataset.

We use four widely used, publicly available datasets, namely ISIC-2018 (codella2019skin; tschandl2018ham10000), PH2 (mendoncya2013dermoscopic), Kvasir-SEG (jha2020kvasir) and CVC-612 (bernal2015wm). A combination of these datasets has been used for the meta-training stage, and a held-out dataset is used for testing to evaluate our proposed iMAML segmentation approach. Table 1 presents information on each dataset used. The ISIC-2018 and PH2 datasets include benign and malignant skin lesion images, while Kvasir-SEG and CVC-612 contain protruded polyp images, often seen as precancerous precursors, from colorectal endoscopy. Each dataset contains acquired images with their corresponding expert-annotated masks of the lesion.

2.2 Implicit Model Agnostic Meta Learning (iMAML) algorithm

Figure 1: Meta-learning with implicit gradient optimization on medical imaging datasets. Meta-training is done as episodic tasks on two public datasets. In the first stage, a few-shot learning framework for each task is trained on the support set and validated on the query set. During the meta-testing stage, an unseen task from a third dataset is provided with the optimized weights obtained from the first stage (#1), and the gradient of the computed loss is used to readjust the final weights on only a few samples of this dataset. Finally, the fine-tuned weights are used for inference on the test samples. In all these settings, we use attention U-Net (oktay2018attention) to obtain segmentation maps.

In general, MAML approaches are trained through a meta-learning objective function (finn2017model). However, due to the back-propagation of high-order meta-gradients required during model training, MAML can suffer from vanishing gradients. To eliminate this problem, Rajeswaran et al. (rajeswaran2019meta) suggested a bi-level optimization, where: 1) the inner optimization computes the task-specific weights through the CNN model, and 2) an analytic solution is used for the outer meta-gradient estimation (see Eq. (1)).

$$\theta^{*}_{ML} = \arg\min_{\theta} F(\theta), \quad F(\theta) = \frac{1}{M}\sum_{i=1}^{M} \mathcal{L}\big(\mathcal{A}_i(\theta),\, \mathcal{D}^{val}_i\big), \quad \mathcal{A}_i(\theta) = \arg\min_{\phi_i} \mathcal{L}\big(\phi_i,\, \mathcal{D}^{tr}_i\big) + \frac{\lambda}{2}\lVert \phi_i - \theta \rVert^{2} \qquad (1)$$

In Eq. (1), $\mathcal{D}^{tr}_i$ and $\mathcal{D}^{val}_i$ represent the training (support) set and the validation (query) set of the $i$-th task in the meta-training phase. The task-specific parameters at the inner optimization level are represented by $\phi_i$, while the weights being optimized during meta-training, i.e., the meta-parameters, are represented by $\theta$. The final optimized meta-parameters are denoted $\theta^{*}_{ML}$. To avoid overfitting and to help anchor the task parameters $\phi_i$ to the meta-parameters $\theta$, an L2 regularization term with strength $\lambda$ is used during model training.
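
To make the inner level of Eq. (1) concrete, the following is a minimal PyTorch-style sketch of the task-specific adaptation with the proximal L2 term anchoring $\phi_i$ to $\theta$. The function name, the generic `model`/`seg_loss` arguments, and the hyper-parameter values (`lam`, `inner_steps`, `inner_lr`) are illustrative assumptions rather than the authors' exact implementation.

```python
import copy
import torch


def inner_adapt(model, seg_loss, support_batch, lam=1.0, inner_steps=5, inner_lr=1e-3):
    """Task-specific (inner-level) adaptation sketch for Eq. (1).

    Starting from the current meta-parameters theta, minimize the support-set
    segmentation loss plus the proximal term (lam/2)*||phi - theta||^2 that
    anchors the task parameters phi to theta.
    """
    images, masks = support_batch
    meta_params = [p.detach().clone() for p in model.parameters()]  # theta (kept fixed)
    task_model = copy.deepcopy(model)                               # phi (trainable copy)
    opt = torch.optim.SGD(task_model.parameters(), lr=inner_lr)

    for _ in range(inner_steps):
        opt.zero_grad()
        loss = seg_loss(task_model(images), masks)
        # Proximal L2 regularization toward the meta-parameters (Eq. (1)).
        prox = sum(((p - t) ** 2).sum()
                   for p, t in zip(task_model.parameters(), meta_params))
        (loss + 0.5 * lam * prox).backward()
        opt.step()
    return task_model  # adapted task parameters phi_i
```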

The meta-training and meta-testing stages are shown in Fig. 1. During the meta-training stage, tasks are generated, each containing a support set (train) and a query set (validation) with few-shot instances. This means that only a few samples are chosen, such as 5 for 5-shot and 10 for 10-shot. We initialize our attention U-Net segmentation model with random weights for each task and compute the loss between the predicted mask and the ground-truth mask on the support set with L2 regularization. The validation loss on the query data completes the task, after which the optimized task parameters $\phi_i$ are fed to the meta-learner, where meta-gradients are analytically computed and the meta-parameters updated as in Eq. (2). These are then fed back into the model weights of the attention U-Net architecture for further backpropagation and optimization. This two-level optimization scheme is iterative and, in our case, is carried out on two different datasets (see Fig. 1, top). The meta-training stage is completed once the set number of tasks has been processed, yielding the final meta-learned parameters $\theta^{*}_{ML}$.

$$\nabla_{\theta} F(\theta) = \frac{1}{M}\sum_{i=1}^{M} \Big(I + \tfrac{1}{\lambda}\nabla^{2}_{\phi}\hat{\mathcal{L}}_i(\phi_i)\Big)^{-1} \nabla_{\phi}\,\mathcal{L}\big(\phi_i,\, \mathcal{D}^{val}_i\big) \qquad (2)$$

where $\hat{\mathcal{L}}_i$ denotes the support-set (training) loss of task $i$ evaluated at the adapted parameters $\phi_i$.
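
Eq. (2) is what makes iMAML practical: instead of back-propagating through the inner loop, the meta-gradient requires only an (approximate) solve of a linear system built from Hessian-vector products, which conjugate gradient handles with first-order memory cost. Below is a hedged sketch of that computation, assuming `params` are the adapted task parameters $\phi_i$, `support_loss` is the support-set training loss, `query_loss` is the query-set loss (both computed with a live autograd graph), and `lam`, `cg_steps` are illustrative values.

```python
import torch
from torch.autograd import grad


def hvp(loss, params, vec):
    """Hessian-vector product of `loss` w.r.t. `params` with flat vector `vec`."""
    grads = grad(loss, params, create_graph=True)
    flat = torch.cat([g.reshape(-1) for g in grads])
    return torch.cat([h.reshape(-1) for h in
                      grad(flat @ vec, params, retain_graph=True)])


def implicit_meta_grad(query_loss, support_loss, params, lam=1.0, cg_steps=5):
    """Approximate (I + (1/lam) H)^{-1} * dL_query/dphi with conjugate gradient,
    where H is the Hessian of the support loss at the adapted parameters (Eq. (2))."""
    v = torch.cat([g.reshape(-1)
                   for g in grad(query_loss, params, retain_graph=True)])
    A = lambda x: x + hvp(support_loss, params, x) / lam  # matrix-free operator

    x = torch.zeros_like(v)
    r, p = v.clone(), v.clone()       # residual and search direction (x0 = 0)
    for _ in range(cg_steps):
        Ap = A(p)
        alpha = (r @ r) / (p @ Ap + 1e-12)
        x = x + alpha * p
        r_new = r - alpha * Ap
        beta = (r_new @ r_new) / (r @ r + 1e-12)
        p, r = r_new + beta * p, r_new
    return x  # flattened per-task meta-gradient; average over tasks to update theta
```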

The second stage consists of a simple fine-tuning step on the unseen data, where the optimized weights $\theta^{*}_{ML}$ are used to initialize the model and the loss function is minimized in a few-shot setting. The final weights are then used for inference, i.e., direct segmentation map prediction, as shown in Fig. 1 (bottom).
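
A minimal sketch of this second stage, assuming the meta-learned weights $\theta^{*}_{ML}$ have already been loaded into `meta_model` and using illustrative `steps` and `lr` values:

```python
import torch


def meta_test(meta_model, seg_loss, support_batch, query_images, steps=10, lr=1e-3):
    """Fine-tune the meta-learned weights on a few labelled samples of the unseen
    dataset, then predict segmentation maps for the remaining (query) images."""
    images, masks = support_batch
    opt = torch.optim.SGD(meta_model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        seg_loss(meta_model(images), masks).backward()
        opt.step()
    with torch.no_grad():
        return torch.sigmoid(meta_model(query_images)) > 0.5  # binary masks
```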

2.3 Loss function

A compound loss was used during training, comprising both a log-cosh-Dice loss and a binary cross-entropy loss. The final loss function is devised as:

$$\mathcal{L} = \mathcal{L}_{BCE} + \log\big(\cosh(\mathcal{L}_{Dice})\big) + \lambda_{wd}\lVert w \rVert^{2}_{2} \qquad (3)$$

$\mathcal{L}_{Dice}$ and $\mathcal{L}_{BCE}$ have their usual meanings as the Dice loss and binary cross-entropy loss classically used in segmentation approaches (cosh). Unlike the classical Dice loss, $\mathcal{L}_{Dice}$ here follows the Lovász extension (Lovsz), which tackles the non-convex nature of the Dice loss by smoothing it, making the function tractable and easy to differentiate. Additionally, we add a weight decay term as an L2 regularization, with $\lambda_{wd}$ as the regularization hyper-parameter and $w$ the model weights. This encourages better generalizability on test samples.
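
For illustration, a hedged PyTorch sketch of the $\mathcal{L}_{BCE} + \log(\cosh(\mathcal{L}_{Dice}))$ part of Eq. (3) follows. Note that it uses the standard soft-Dice surrogate rather than the Lovász extension, and assumes the L2 weight-decay term $\lambda_{wd}\lVert w\rVert_2^2$ is applied by the optimizer.

```python
import torch
import torch.nn.functional as F


def log_cosh_dice_bce_loss(logits, targets, eps=1.0):
    """Sketch of the compound loss: BCE + log(cosh(soft-Dice loss)).

    `logits` are raw network outputs of shape (B, 1, H, W) and `targets` are
    binary ground-truth masks of the same shape.
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    probs = torch.sigmoid(logits)
    inter = (probs * targets).sum(dim=(1, 2, 3))
    union = probs.sum(dim=(1, 2, 3)) + targets.sum(dim=(1, 2, 3))
    dice_loss = 1.0 - (2.0 * inter + eps) / (union + eps)   # soft-Dice loss per image
    return bce + torch.log(torch.cosh(dice_loss)).mean()
```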

2.4 Network Architecture

Our proposed model architecture is shown in Fig. 1. The network consists of a sampler that creates the support and query sets for the few-shot setting of our experiments and for specific tasks. These are fed to an attention U-Net (oktay2018attention) architecture for inner-level parameter optimization of each task. Finally, a meta-gradient optimizer computes the optimized weights, which are fed back to the attention U-Net.
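
As an illustration of the attention mechanism, a simplified additive attention gate in the spirit of attention U-Net (oktay2018attention) is sketched below; the actual architecture follows the original paper, and the module and parameter names here are assumptions for readability.

```python
import torch
import torch.nn as nn


class AttentionGate(nn.Module):
    """Simplified additive attention gate: the decoder's gating signal `g`
    re-weights the skip-connection features `x` before concatenation."""

    def __init__(self, g_ch, x_ch, inter_ch):
        super().__init__()
        self.w_g = nn.Conv2d(g_ch, inter_ch, kernel_size=1)
        self.w_x = nn.Conv2d(x_ch, inter_ch, kernel_size=1)
        self.psi = nn.Conv2d(inter_ch, 1, kernel_size=1)

    def forward(self, g, x):
        # g and x are assumed to share spatial resolution (upsample g otherwise).
        a = torch.relu(self.w_g(g) + self.w_x(x))
        alpha = torch.sigmoid(self.psi(a))   # per-pixel attention coefficients
        return x * alpha
```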

3 Experiments and Results

3.1 Setup

Experimental design.

All experiments in this work use few-shot supervised settings, for which N-way, K-shot tasks are randomly generated from two publicly available datasets. In this context, N refers to the number of classes and K to the number of samples from each class. The number of classes N corresponds to the number of different data pools, making our experiments 2-way K-shot tasks. Finally, for meta-testing, the learned parameters were fine-tuned on an entirely new task drawn from the held-out data pool. We present three sets of experiments: (i) tasks comprising samples drawn exclusively from the Kvasir-SEG (polyp) or the PH2 (skin) dataset, (ii) tasks comprising mixed samples from both datasets, and (iii) tasks trained on datasets of the same class and tested on a completely different class, such as meta-training on skin datasets and meta-testing on a polyp dataset.
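
A minimal sketch of how such 2-way K-shot episodes can be drawn, assuming each "way" corresponds to one data pool; the function name, arguments, and `q_size` default are illustrative.

```python
import random


def sample_task(data_pools, k_shot, q_size=5):
    """Draw one episode: K support and q_size query image/mask pairs per data pool.

    `data_pools` maps a pool name (e.g. 'PH2', 'Kvasir-SEG') to a list of
    (image_path, mask_path) pairs.
    """
    support, query = [], []
    for pool in data_pools.values():
        picks = random.sample(pool, k_shot + q_size)
        support.extend(picks[:k_shot])   # labelled support samples
        query.extend(picks[k_shot:])     # held-out query samples
    return support, query
```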

Implementation details.

The meta-parameters were initialized with pre-trained weights from a U-Net trained on brain MRI scans (pedano2016radiology). The meta-gradient is computed by applying conjugate gradient (CG), and the meta-parameters are updated using the Adam optimizer (Adam) with fixed learning rate and weight decay. For the regularization of the computed learned weights, a fixed value of the regularization strength was used. The images and their corresponding ground truth were normalized to the range [-1, 1] and resized to a fixed resolution. All implementations were done using the PyTorch framework, and experiments were conducted on an NVIDIA Tesla V100-SXM3 GPU.
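
A small preprocessing sketch consistent with the normalization described above; the target resolution is a placeholder, since the exact resize value is not reproduced here.

```python
from torchvision import transforms

# Hypothetical target size; the paper's exact resize resolution is not reproduced here.
TARGET_SIZE = (256, 256)

preprocess = transforms.Compose([
    transforms.Resize(TARGET_SIZE),
    transforms.ToTensor(),                                  # scales pixels to [0, 1]
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),    # maps RGB images to [-1, 1]
])
```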

3.2 Results

Algorithm K-shots # Tasks Target Dataset DSC
Naive Baseline 1000 - ISIC 58.10
Semi-supv. Baseline (feyjie2020semi) 5 - ISIC 61.38
10 - ISIC 61.40
20 - ISIC 60.79
PMG. Baseline (xiao2021prior) 5 - ISIC 67.00
Meta-learned 5 50 ISIC 77.39
10 50 ISIC 79.17
20 50 ISIC 83.26
Table 2: Quantitative results for the first experimental setup. Episodic training is done independently, first on PH2 (skin) and then on Kvasir-SEG (polyp).
Algorithm K-shots # Tasks Target Dataset DSC
Naive Baseline 1000 - ISIC 58.10
Semi-supv. Baseline (feyjie2020semi) 5 - ISIC 61.38
10 - ISIC 61.40
20 - ISIC 60.79
PMG. Baseline (xiao2021prior) 5 - ISIC 67.00
Meta-learned 5 50 ISIC 70.15
10 50 ISIC 71.69
20 50 ISIC 72.48
Table 3: Results from the second experimental setup. Episodic training on tasks comprising both PH2 (skin) and Kvasir-SEG (polyp) instances.
Algorithm K-shots # Tasks Target Dataset DSC
Naive Baseline 1000 - ISIC 58.10
Semi-supv. Baseline (feyjie2020semi) 5 - ISIC 61.38
10 - ISIC 61.40
20 - ISIC 60.79
PMG. Baseline (xiao2021prior) 5 - ISIC 67.00
Meta-learned 5 50 ISIC 63.56
10 50 ISIC 65.09
20 50 ISIC 66.71
Table 4: Quantitative results from the third experimental setup. Episodic training on the CVC-612 (polyp) and Kvasir-SEG (polyp) datasets.

We present results for three different experimental setups to illustrate the efficacy of the model compared to a naive supervised attention U-Net and two recent state-of-the-art (SOTA) few-shot methods for medical image segmentation.

1. Meta-training with samples drawn exclusively from two unique datasets and unique categories:

Table 2 presents results for episodic training of our meta-learning approach on the PH2 and Kvasir-SEG datasets, consisting of skin and polyp categories, respectively. On the unseen ISIC dataset, our proposed iMAML-based segmentation outperformed the naive baseline U-Net by a very large margin, and improved the Dice coefficient by nearly 23% and 16% over the baseline semi-supervised method and the recent mask-guided few-shot segmentation approach, respectively. The qualitative results (Figure 2) also indicate that our method segments different skin lesion types well. The proposed meta-learning-based segmentation obtained the highest Dice coefficient for all K-shot settings (5, 10, and 20 samples in the meta-training stage).

2. Tasks comprising mixed samples of two unique datasets:

Table 3 shows a different setting where the samples are mixed from the two datasets. There is a clear performance drop for our meta-learning method; still, the proposed algorithm outperformed the baseline methods. The best Dice score of 72.48% is obtained on the ISIC (skin) dataset under the 2-way 20-shot setting.

Figure 2: Qualitative results of the proposed method on ISIC-2018 (codella2019skin; tschandl2018ham10000) from Table 2 and on Kvasir-SEG (jha2019resunet++) from Supplementary Table 1.

3. Tasks comprising samples from two unique datasets of the same class:

Table 4 and Supplementary Table 1 present meta-training on two unique datasets of the same category, tested on a dataset of a different class. For episodic training conducted on the polyp datasets (CVC-612 and Kvasir-SEG) and tested on the skin dataset (see Table 4), our method still generalizes better than the naive baseline approach trained on 1,000 samples and the recent semi-supervised approach. Similar observations hold when the method is trained on the skin datasets (ISIC-2018 and PH2) and tested on the Kvasir-SEG polyp segmentation dataset (see Figure 2 and Supplementary Table 1).

Additionally, as an ablation study, we investigated the effect of the Lovász extension against the standard Dice loss function in the meta-learning setting. Based on the experimental results (see Supplementary Table 2), the Lovász extension was chosen.

4 Conclusion

We proposed a novel model-agnostic meta-learning segmentation method in a few-shot setting that uses an implicit gradient-based optimization technique for improved model parameter estimation. The proposed method showed improved performance and generalization capability compared to both naive supervised techniques and recent few-shot segmentation approaches. Such a method allows the exploitation of available medical imaging datasets for training and can be effectively used on unseen datasets without requiring ample ground-truth labels. Thus, our method removes the dependency on abundant data for each specialized medical imaging category, which can substantially broaden the clinical usability of deep learning-based techniques. In future work, we will explore the model's generalization capability across diverse tasks without catastrophic forgetting.

References

Supplementary material

Algorithm K-shots # Tasks Target Dataset DSC
Naive Baseline 1000 - Kvasir-SEG 60.53
Meta-learned 5 50 Kvasir-SEG 62.00
10 50 Kvasir-SEG 65.10
20 50 Kvasir-SEG 66.58
Table 1: Episodic meta-training on the ISIC (skin) and PH2 (skin) datasets from experimental setup 3. Meta-testing is done on the Kvasir-SEG dataset, which consists of the polyp class.
Algorithm K-shots # Tasks Target Dataset DSC
Dice Loss 5 20 ISIC 73.90
Log(cosh(Dice Loss)) 5 20 ISIC 76.85
Table 2: Effect of Lovász extension compared to the standard dice loss function in meta-learning setting. Meta-training was done on CVC-612 and PH2 and tested on ISIC.