Tensorflow implementation of our paper: Few-shot 3D Multi-modal Medical Image Segmentation using Generative Adversarial Learning
We address the problem of segmenting 3D multi-modal medical images in scenarios where very few labeled examples are available for training. Leveraging the recent success of adversarial learning for semi-supervised segmentation, we propose a novel method based on Generative Adversarial Networks (GANs) to train a segmentation model with both labeled and unlabeled images. The proposed method prevents over-fitting by learning to discriminate between true and fake patches obtained by a generator network. Our work extends current adversarial learning approaches, which focus on 2D single-modality images, to the more challenging context of 3D volumes of multiple modalities. The proposed method is evaluated on the problem of segmenting brain MRI from the iSEG-2017 and MRBrainS 2013 datasets. Significant performance improvement is reported, compared to state-of-the-art segmentation networks trained in a fully-supervised manner. In addition, our work presents a comprehensive analysis of different GAN architectures for semi-supervised segmentation, showing recent techniques like feature matching to yield a higher performance than conventional adversarial training approaches. Our code is publicly available at https://github.com/arnab39/FewShot_GAN-Unet3D
Semantic segmentation is commonly used in medical imaging to identify the precise location and shape of structures in the body, and is essential to the proper assessment of medical disorders and their treatment. Although extensively studied, this problem remains quite challenging due to the noise and low contrast in medical scans, as well as the high variability of anatomical structures. Recently, deep convolutional neural networks (CNNs) have led to substantial improvements for numerous computer vision tasks like object detection, image classification [2, 3] and semantic segmentation [4, 5], often achieving human-level performance. Yet, a major limitation of CNNs is their requirement for large amounts of annotated data. This limitation is particularly important in medical image segmentation, where the annotation process is time-consuming, and prone to errors or intra-observer variability.
Semi-supervised learning approaches alleviate the need for large sets of labeled samples by exploiting available non-annotated data. In such approaches, only a limited number of samples with strong annotations are provided. A good generalization can however be achieved by considering unlabeled samples, or samples with weak annotations like image-level tags [6, 7, 8, 9], bounding boxes [10, 11] or scribbles [12, 8], during training. Recently, approaches based on adversarial training, and in particular Generative Adversarial Networks (GANs), have shown great potential for improving semantic segmentation in a semi-supervised setting [14, 15]. Nevertheless, their application to the segmentation of 3D medical images with multiple modalities remains limited to this day [16, 17].
Our work addresses the problem of segmenting 3D multi-modal medical images from a few-shot learning perspective. We leverage the recent success of GANs to train a deep model with a highly-limited training set of labeled images, without sacrificing the performance of full supervision. The main contributions of this paper can be summarized as:
A first approach to apply GANs for the semi-supervised segmentation of 3D multi-modal images. We demonstrate this approach to significantly outperform state-of-art segmentation networks like 3D-UNet when very few training samples are available, and to achieve an accuracy close to that of full supervision.
A comprehensive analysis of different GAN architectures for semi-supervised segmentation, where we show more recent techniques like feature matching to have a higher performance than conventional adversarial training approaches.
The rest of this paper is organized as follows. In Section II, we give a brief overview of relevant work on semantic segmentation with a focus on semi-supervised learning. Section III then presents our 3D multi-modal segmentation approach based on adversarial learning, which is evaluated on the challenging task of brain segmentation in Section IV. Finally, we conclude with a summary of our main contributions and results.
Our method draws on recent successes of deep learning methods for semantic segmentation, in particular semi-supervised and few-shot learning approaches based on adversarial training.
Several semi-supervised deep learning methods have been proposed for image segmentation [18, 19, 20, 21]. A common strategy, based on the principle of self-training, involves alternately updating the network parameters and the segmentation until convergence. However, if initial class priors given by the network are inaccurate, segmentation errors can occur and be propagated back to the network, which then re-amplifies these errors. Various techniques can be used to alleviate this problem, including model-based or data-based [23, 21] distillation, which aggregate the prediction of multiple teacher models or a single teacher trained with multiple transformed versions of the data to learn a student model, and employing attention modules. Yet, these approaches are relatively complex, as they require training multiple networks, and are thus not suitable when very few training samples are available. Another popular approach consists in embedding the network’s output or internal representation in a manifold space, such that images having similar characteristics are near to each other. An important limitation of this approach is its requirement for an explicit matching function, which may be hard to define in practice.
Adversarial learning has also shown great promise for training deep segmentation models with few strongly-annotated images [14, 15, 24, 25]. An interesting approach to include unlabeled images during training is to add an adversarial network in the model, which must determine whether the output of the segmentation network corresponds to a labeled or unlabeled image [24, 25]. This encourages the segmentation network to have a similar distribution of outputs for images with and without annotations, thereby helping generalization. A potential issue with this approach is that the adversarial network can have a reverse effect, where the output for annotated images becomes increasingly similar to the incorrect segmentations obtained for unlabeled images. A related strategy uses the discriminator to predict a confidence map for the segmentation, enforcing this output to be maximum for annotated images. For unlabeled images, areas of high confidence are used to update the segmentation network in a self-teaching manner. The main limitation of this approach is that a confidence threshold must be provided, the value of which can affect the performance. To date, only a single work has applied Generative Adversarial Networks (GANs) for semi-supervised segmentation. However, it focused on 2D natural images, whereas the current work targets 3D multi-modal medical volumes. Generating and segmenting 3D volumes brings additional challenges, such as computational complexity and over-fitting.
Few-shot learning methods seek good generalization on problems with a very limited labeled dataset, typically containing just a few training samples of the target classes [26, 27]. Even though interest in such methods is increasing, most works focus on classification, where there is no need to generate a structured output. Shaban et al. pioneered one-shot learning for semantic segmentation, which required only a single image and its corresponding pixel-level annotation per class, for learning. More recently, Rakelly et al. proposed to train guided  and conditional  networks with few support samples and sparse annotations, achieving similar performance to . A common feature of these methods is that a network pre-trained on a similar task is employed as initial model, which is then refined by the one/few shot support queries. However, adopting this strategy to segment medical images might be challenging as the gap between source and target tasks is often large. As shown in our experiments, the approach proposed in this work can achieve competitive performance using as few as two training samples, without the need for a pre-trained model.
In a standard segmentation approach, fully-annotated images are typically employed to train the network using a pixel-wise loss function like cross-entropy. As mentioned above, this is not possible in our case since the number of annotated images for training is highly limited. As in other semi-supervised segmentation approaches, we alleviate this problem by also incorporating unlabeled images in the training process. However, unlike these methods, we also make use of synthetic (i.e., fake) images generated by a GAN.
To include labeled, unlabeled and fake images during training, we extend the classification approach of Dai et al. to segmentation. In this GAN-based approach, a generator network is used to produce realistic fake examples and a discriminator to distinguish these fake examples from true data. Instead of predicting $K$ classes, as in standard methods, the model predicts $K+1$ classes, where the additional class corresponds to fake examples. However, as we show in the following subsections, this formulation can be recast back into a $K$-class problem using a simple re-parametrization trick. This overall strategy helps the model give plausible predictions for unlabeled true data by restricting its output for fake examples.
For adapting this model to the segmentation of 3D multi-modal images, several changes must be made. During training, 3D images must be processed in smaller sub-regions (i.e., patches) to deal with the much greater memory and computational requirements compared to 2D images. While training patches have a fixed size, test images may have arbitrary size. To address this issue, we make the segmentation network fully convolutional. Another challenge comes from the generation of fake patches. Standard techniques for training GANs may lead to instability and poor results, especially in the case of semi-supervised learning. This problem is even more significant in the case of 3D multi-modal patches, whose distribution is harder to estimate with a parametric model. In addition, although generated patches must be realistic-looking, they should be sufficiently different from true unlabeled patches, otherwise the wrong information will be learned.
The following subsections provide a more detailed description of the proposed method. We start by giving a general formulation of generative adversarial networks (GANs). Then, we show how GANs can be used to include unlabeled and fake images in a semi-supervised segmentation setting. Finally, we explain how the standard GAN model is modified to fit our problem setting.
Standard GANs model a min-max game between two parametrized neural nets: a generator $G$ and a discriminator $D$. The idea is to train the generator to generate images from a learned data distribution, while simultaneously training the discriminator to differentiate between these generated images and true examples. Specifically, $G$ is trained to map a random (i.e., noise) vector $z$ to a synthetic image vector $\tilde{x} = G(z)$. A common choice for sampling the noise is to use a uniform distribution. On the other hand, $D$ is trained to distinguish between real examples $x \sim p_{data}(x)$ and synthetic examples $\tilde{x} \sim p_G(x)$. Here, $D(x)$ represents the probability that a sample $x$ belongs to the original data distribution. Networks $G$ and $D$ play a two-player min-max game with value function $V(D, G)$:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim noise}\big[\log\big(1 - D(G(z))\big)\big] \tag{1}$$
Consider a standard CNN-based model for segmenting a 3D image $x$ into regions defined by $y$. This model takes $x$ as input and outputs, for each voxel, a $K$-dimensional vector of logits $[l_{i,1}, \ldots, l_{i,K}]$, where $K$ is the number of class labels and $i$ is the index of image voxels. This output can be turned into class probabilities by applying the softmax function:

$$p_{model}(y_i = k \mid x) = \frac{\exp(l_{i,k})}{\sum_{k'=1}^{K} \exp(l_{i,k'})} \tag{2}$$

In a fully-supervised setting, the model is typically trained by minimizing a segmentation loss function, for instance, the cross-entropy between the true labels and the model’s predicted probabilities.
As shown in Fig. 1, the proposed model extends the standard full-supervision approach by incorporating unlabeled data and samples from the generator during training. Toward this goal, we label generated samples with a new class $K+1$ and thus increase the dimension of the segmentation model’s output to $K+1$. In this new formulation, $p_{model}(y_i = K+1 \mid x)$ is the probability that voxel $i$ of input $x$ is fake. Moreover, to learn the basic structure of images from unlabeled data, we constrain the output to correspond to one of the $K$ classes of real data, which can be done by maximizing

$$\mathbb{E}_{x \sim p_{data}(x)} \sum_{i} \log p_{model}\big(y_i \in \{1, \ldots, K\} \mid x\big) \tag{3}$$
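To make the $(K+1)$-class output concrete, the following minimal sketch applies a softmax to the logits of a single voxel and reads off the fake-class probability. The logit values and the choice of three real classes are illustrative assumptions, not values from our experiments, and the helper name `softmax` is ours.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for one voxel: K = 3 real classes, then the fake class.
voxel_logits = [2.0, 0.5, -1.0, 0.0]
probs = softmax(voxel_logits)

p_fake = probs[-1]        # probability that the voxel is fake (class K+1)
p_real = sum(probs[:-1])  # probability that it belongs to one of the K real classes

print(round(p_fake, 4), round(p_real, 4))
```

Maximizing the log of `p_real` over unlabeled voxels corresponds to the constraint described above, since `p_real` is exactly one minus the fake-class probability.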
With this, we can now define the loss functions used for training the discriminator and generator networks.
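Suppose we have a similar number of labeled, unlabeled and fake images, so that each type of image has equal importance in training. Our discriminator loss function can be defined as the sum of three terms:

$$L_{discriminator} = L_{labeled} + L_{unlabeled} + L_{fake} \tag{4}$$

The loss for labeled images is the same as in standard segmentation networks. In this work, we consider the mean cross-entropy:

$$L_{labeled} = -\mathbb{E}_{x, y \sim p_{data}(x, y)} \sum_{i} \log p_{model}\big(y_i \mid x, \, y_i < K+1\big) \tag{5}$$

In the case of unlabeled images, we maximize the term in Eq. (3), which is the same as minimizing

$$L_{unlabeled} = -\mathbb{E}_{x \sim p_{data}(x)} \sum_{i} \log\big[1 - p_{model}(y_i = K+1 \mid x)\big] \tag{6}$$

Finally, for generated images, we impose each pixel of an input patch to be predicted as fake, and define the loss as

$$L_{fake} = -\mathbb{E}_{z \sim noise} \sum_{i} \log p_{model}\big(y_i = K+1 \mid G(z)\big) \tag{7}$$

It was previously shown that the optimal strategy for minimizing Eq. (4) is to have $\exp[l_{i,k}(x)] = c_i(x)\, p(y_i = k, x)$ for all $k < K+1$ and $\exp[l_{i,K+1}(x)] = c_i(x)\, p_G(x)$, where $c_i(x)$ is an undetermined scaling function for the $i$-th pixel. It was also found that having $K+1$ outputs is an over-parameterized formulation, since subtracting a general function from each logit does not change the output of the softmax. By using the logit of the fake class as the subtracted function, we obtain $l_{i,K+1}(x) = 0$, and thus have only $K$ effective (i.e., non-zero) outputs. Employing these “normalized” logits in the softmax of Eq. (2) then leads to the following modified loss functions:

$$L_{unlabeled} = -\mathbb{E}_{x \sim p_{data}(x)} \sum_{i} \log\left[\frac{Z_i(x)}{Z_i(x) + 1}\right], \qquad L_{fake} = -\mathbb{E}_{z \sim noise} \sum_{i} \log\left[\frac{1}{Z_i(G(z)) + 1}\right]$$

where $Z_i(x) = \sum_{k=1}^{K} \exp[l_{i,k}(x)]$.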
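As an illustration of the re-parametrization, the sketch below computes the per-voxel unlabeled and fake losses from only the $K$ real-class logits, with the fake logit fixed at zero. Writing $Z$ for the sum of exponentiated real logits, the fake probability is $1/(Z+1)$, so the two losses reduce to softplus functions of $\log Z$. The function names and logit values are hypothetical, and this is a sketch of the idea rather than the code used in our experiments.

```python
import math

def log_sum_exp(values):
    """Stable log(sum(exp(v))) over a list of values, i.e. log Z."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def softplus(x):
    """Stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def unlabeled_voxel_loss(real_logits):
    # -log(Z / (Z + 1)) = log(1 + 1/Z) = softplus(-log Z)
    return softplus(-log_sum_exp(real_logits))

def fake_voxel_loss(real_logits):
    # -log(1 / (Z + 1)) = log(1 + Z) = softplus(log Z)
    return softplus(log_sum_exp(real_logits))

# Hypothetical real-class logits for one voxel (K = 3).
logits = [2.0, 0.5, -1.0]
print(unlabeled_voxel_loss(logits), fake_voxel_loss(logits))
```

Because the real and fake probabilities sum to one, `exp(-unlabeled_voxel_loss)` and `exp(-fake_voxel_loss)` always add up to 1 for the same logits, which is a convenient sanity check.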
In summary, the idea is to plug a standard state-of-the-art segmentation model in the discriminator of the proposed network, where the labeled component of the loss remains unchanged (i.e., cross-entropy), and introduce two extra terms, the unlabeled term $L_{unlabeled}$ and the fake term $L_{fake}$, which are analogous to the two components of a discriminator loss in standard GANs.
The most common strategy for training the generator consists in maximizing the loss of Eq. (7). However, as previously demonstrated, this can lead to instability and poor performance in the case of semi-supervised learning. Following these results, we instead adopt the Feature Matching (FM) loss for the generator, which is more suited to our problem. In FM, the goal of the generator is to match the expected value of features in an intermediate layer of the discriminator:

$$L_{generator} = \left\| \mathbb{E}_{x \sim p_{data}(x)} f(x) - \mathbb{E}_{z \sim noise} f(G(z)) \right\|_2^2$$

In this work, $f(x)$ contains the activations of the second-last layer of the encoding path in our model. In preliminary experiments, we found this choice to give slightly higher performance than using the encoder’s last layer.
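The feature matching objective can be sketched on small feature vectors as follows: the expectation over each distribution is approximated by a batch mean, and the loss is the squared Euclidean distance between the two means. The feature values and function names below are hypothetical, and real features would be flattened activations of the discriminator.

```python
def mean_features(batch):
    """Average a batch of equally sized feature vectors, dimension by dimension."""
    n = len(batch)
    dim = len(batch[0])
    return [sum(feat[d] for feat in batch) / n for d in range(dim)]

def feature_matching_loss(real_feats, fake_feats):
    """Squared L2 distance between the mean real and mean fake features."""
    mu_real = mean_features(real_feats)
    mu_fake = mean_features(fake_feats)
    return sum((r - f) ** 2 for r, f in zip(mu_real, mu_fake))

# Hypothetical 2-dimensional features for batches of two patches.
real = [[1.0, 2.0], [3.0, 4.0]]   # batch mean is [2.0, 3.0]
fake = [[0.0, 0.0], [2.0, 2.0]]   # batch mean is [1.0, 1.0]
print(feature_matching_loss(real, fake))  # 5.0

# A generator whose samples all sit on the real mean also reaches zero loss.
collapsed = [[2.0, 3.0], [2.0, 3.0]]
print(feature_matching_loss(real, collapsed))  # 0.0
```

The second call shows that only first-order statistics are matched: a batch with no diversity at all can still minimize the loss.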
In semi-supervised learning, having a good generator can actually deteriorate performances since in this case unlabeled and fake images cannot be separated. It is therefore desirable to have a generator that can generate samples outside the true data manifold, which is called a complement (or bad) generator .
The FM generator loss, described in the previous section, works better than standard training approaches in a semi-supervised setting because it performs distribution matching in a weak manner. However, it may still face two significant problems. First, since an FM-based generator can assign a significant amount of probability mass inside the support, an optimal discriminator will incorrectly predict samples in that region as fake. Second, as FM only matches first-order statistics, the generator might end up with a trivial solution; for example, it can collapse to the mean of unlabeled features. The collapsed generator will then fail to cover some areas between manifolds. Since the discriminator is only well-defined on the union of the data supports of $p_{data}$ and $p_G$, the prediction result in such gaps is under-determined.
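The first problem is less likely in our case, since multi-modal 3D patches are complex structures to generate and, thus, it is more probable for the FM generator to sample images outside the true data manifold. To deal with the second problem, we can increase the entropy of the generated distribution by minimizing a modified loss for the generator:

$$L_{generator} = -H(p_G) + \left\| \mathbb{E}_{x \sim p_{data}(x)} f(x) - \mathbb{E}_{z \sim noise} f(G(z)) \right\|_2^2$$

where the negative entropy is handled through the variational upper bound $-H(p_G) \leq -\mathbb{E}_{x, z \sim p_G} \log q(z \mid x)$. In this formulation, $q$ is defined as a diagonal Gaussian with bounded variance, i.e., $q(z \mid x) = \mathcal{N}(\mu(x), \sigma^2(x))$ with $0 < \sigma(x) < \theta$, where $\mu(\cdot)$ and $\sigma(\cdot)$ are neural networks.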
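A diagonal Gaussian with bounded standard deviation can be evaluated as in the sketch below. The scaled sigmoid used to keep each standard deviation in $(0, \theta)$ is an assumption made for illustration (any bounded positive mapping would do), and the vectors stand in for encoder outputs rather than actual network activations.

```python
import math

def bounded_sigma(raw, theta):
    """Map an unconstrained value to the interval (0, theta) with a scaled sigmoid (assumed choice)."""
    return theta / (1.0 + math.exp(-raw))

def diag_gaussian_log_prob(z, mu, sigma):
    """Log-density of z under a Gaussian with independent dimensions."""
    return sum(
        -0.5 * math.log(2.0 * math.pi) - math.log(s) - 0.5 * ((zi - m) / s) ** 2
        for zi, m, s in zip(z, mu, sigma)
    )

# Hypothetical encoder outputs for a 2-dimensional noise vector.
theta = 2.0
mu = [0.1, -0.3]
sigma = [bounded_sigma(r, theta) for r in (0.0, 0.0)]  # both equal theta / 2 = 1.0
z = [0.1, -0.3]

# The entropy-related term is the negative log-density of the noise
# that produced the generated patch.
print(-diag_gaussian_log_prob(z, mu, sigma))
```

Bounding the standard deviation prevents the density from becoming arbitrarily flat, which would otherwise make the bound uninformative.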
The overall architecture of the complement generator is illustrated in Fig. 1. As presented above, the FM loss uses features from the second-last layer of the encoding path of the discriminator (i.e., the U-Net segmentation network). Moreover, the fake image generator is paired with an encoder which learns a reverse mapping from generated images to corresponding noise vectors. All components of the architecture are trained simultaneously in an end-to-end manner.
The proposed model is evaluated on the challenging tasks of segmenting infant and adult brain tissue from multi-modal 3D magnetic resonance images (MRI). The goal of our experiments is two-fold. First, we assess our GAN-based model in a few-shot learning scenario, where only a few training subjects are provided. Our objective is to provide performance similar to that of full-supervision, while using only 1 or 2 training subjects. Second, since the application of GANs to semi-supervised learning, and particularly to segmentation, is a new topic, we conduct experiments to measure the impact of various GAN techniques (e.g., feature matching, complementary generator, etc.) on segmentation accuracy. Before presenting results, we give details on the dataset, evaluation metrics and implementation used in the experiments.
We first used data from the iSEG-2017 Challenge (see http://iseg2017.web.unc.edu/reference/) on infant brain MRI segmentation. The goal of this challenge is to compare (semi-) automatic algorithms for the segmentation of infant (6 months) T1- and T2-weighted brain MRI scans into three tissue classes: white matter (WM), gray matter (GM) and cerebrospinal fluid (CSF). This dataset was chosen to substantiate our proposed method: it contains the 3D multi-modal brain MRI data of only 10 labeled subjects, each one requiring about a week to annotate manually. Additionally, 13 unlabeled testing subjects are also provided. To further validate results, we also tested our method on segmenting adult brain tissues (i.e., WM, GM, CSF) from the MRBrainS 2013 Challenge dataset (see http://mrbrains13.isi.uu.nl), which contains the T1, T1-IR and FLAIR scans of 20 adult subjects. Ground truth labels are provided for only 5 subjects, which form the training set. The test set contains the unlabeled scans of the 15 remaining subjects.
Segmentation accuracy is assessed using two well-known metrics, respectively measuring spatial overlap and surface distance :
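Dice similarity coefficient (DSC): This widely-used metric compares segmented volumes based on their overlap. Given a reference segmentation $S_{ref}$, the DSC of a predicted segmentation $S_{pred}$ is defined as

$$\mathrm{DSC} = \frac{2\, |S_{pred} \cap S_{ref}|}{|S_{pred}| + |S_{ref}|}$$

DSC values range between 0 and 1, with 1 corresponding to a perfect overlap.
Average Symmetric Surface Distance (ASD): This metric computes an average of distances from points on a surface to the nearest point on another surface. Let $B_{ref}$ and $B_{pred}$ be the reference and predicted segmentation boundaries, it can be defined as

$$\mathrm{ASD} = \frac{1}{|B_{pred}| + |B_{ref}|} \left( \sum_{x \in B_{pred}} D(B_{ref}, x) + \sum_{x' \in B_{ref}} D(B_{pred}, x') \right)$$

where $D(B, x)$ is the minimum Euclidean distance from voxel $x$ to a point on boundary $B$.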
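Both metrics can be sketched on tiny voxel sets as follows. Segmentations are represented as collections of voxel coordinates, and the ASD is computed by pooling the nearest-boundary distances in both directions and dividing by the total number of boundary points, which is one common convention and an assumption here. Boundary extraction and voxel spacing are left out, and the function names are ours.

```python
import math

def dice(pred, ref):
    """Dice similarity coefficient between two collections of voxel coordinates."""
    pred, ref = set(pred), set(ref)
    if not pred and not ref:
        return 1.0  # two empty segmentations are treated as a perfect match
    return 2.0 * len(pred & ref) / (len(pred) + len(ref))

def min_dist(point, boundary):
    """Euclidean distance from a point to the closest point of a boundary."""
    return min(math.dist(point, b) for b in boundary)

def asd(b_pred, b_ref):
    """Average symmetric surface distance between two non-empty boundaries."""
    total = sum(min_dist(p, b_ref) for p in b_pred)
    total += sum(min_dist(r, b_pred) for r in b_ref)
    return total / (len(b_pred) + len(b_ref))

# Hypothetical voxel sets sharing one voxel out of two each.
pred = [(0, 0, 0), (0, 0, 1)]
ref = [(0, 0, 1), (0, 0, 2)]
print(dice(pred, ref))                 # 0.5
print(asd([(0, 0, 0)], [(0, 0, 2)]))   # 2.0
```

A higher DSC and a lower ASD both indicate a better segmentation, so the two metrics should be read in opposite directions in the result tables.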
The state-of-art 3D U-Net  model was chosen as segmentation network in our architecture. In order to use this model in the proposed GAN framework, the following changes were made:
As suggested in , ReLUs were changed to leaky ReLUs, allowing a small gradient for non-active units (i.e., units whose output is below zero).
Max pooling was replaced by average pooling, since max pooling leads to sparse gradients, which were shown to hamper GAN training.
These modifications to 3D U-Net have helped make the training more stable and improve the performance. Other elements of the discriminator’s architecture are the same as in the original U-Net.
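The effect of these two changes on gradient flow can be illustrated with a one-dimensional pooling window. In this sketch, the leaky slope of 0.2 and the window values are assumptions chosen for illustration, not settings reported in this work.

```python
def leaky_relu(x, alpha=0.2):
    """Leaky ReLU: identity for positive inputs, small slope alpha otherwise."""
    return x if x > 0 else alpha * x

def leaky_relu_grad(x, alpha=0.2):
    """Derivative of the leaky ReLU; never exactly zero."""
    return 1.0 if x > 0 else alpha

def max_pool_grad(window):
    """Gradient of max pooling w.r.t. its inputs: only the maximum receives gradient."""
    top = window.index(max(window))
    return [1.0 if i == top else 0.0 for i in range(len(window))]

def avg_pool_grad(window):
    """Gradient of average pooling w.r.t. its inputs: every input receives an equal share."""
    return [1.0 / len(window)] * len(window)

window = [0.3, -1.2, 0.8, 0.1]
print(max_pool_grad(window))   # [0.0, 0.0, 1.0, 0.0]
print(avg_pool_grad(window))   # [0.25, 0.25, 0.25, 0.25]
print(leaky_relu_grad(-1.2))   # 0.2
```

With max pooling, three of the four inputs receive no gradient at all, whereas average pooling and the leaky activation keep a non-zero signal flowing to every unit.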
For generating 3D patches, we chose the volume generator proposed by Wu et al., which was shown to provide good results for various types of 3D objects. This model leverages the power of both generative-adversarial modeling and volumetric convolutional networks to generate realistic 3D shapes. For implementing the encoder, we used a standard three-layer 3D CNN architecture, whose output vector is twice the size of the generator’s input noise vector. This network estimates the mean and standard deviation of the noise vector from which the given image is generated. It was found during preliminary experiments that using batch normalization in the generator and encoder gives the best results. Therefore, this normalization setting was used for our GAN-based model.
To train the proposed GAN based model, the 10 labeled subjects data (i.e., examples) of the iSEG-2017 dataset were split into training (1 or 2 examples), validation (1 example) and testing (7 fixed examples). The 13 unlabeled examples of the testing dataset were instead used to train the GAN.
Similarly, for the MR Brains 2013 dataset, the 5 labeled examples were split into 1 training, 1 validation and 3 testing examples, respectively. As before, the 15 unlabeled subject data were used as unlabeled data for training the GAN.
As preprocessing, N4 bias field correction was applied to images, followed by intensity normalization. To train the model, $32 \times 32 \times 32$ patches were extracted from 3D scans with a step size of 8 voxels in each dimension. This serves two purposes: reduce computational requirements compared to employing whole 3D images, and increase the number and diversity of training examples. No other data augmentation was used, as our goal is to compare the performance of the two models in a few-shot learning scenario, not to achieve state-of-the-art performance on the tested datasets. The Adam optimizer was employed for mini-batch stochastic gradient descent (SGD), with a batch size of 30. For all networks (i.e., U-Net based discriminator, generator and encoder), we used a learning rate of 0.0001 and a momentum of 0.5.
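The overlapping patch sampling can be sketched by enumerating patch start indices along each axis. The $64 \times 64 \times 64$ volume below is a hypothetical size used only to keep the numbers small, and patches that would extend past the volume border are simply skipped, which is an assumption about border handling.

```python
def patch_starts(size, patch=32, step=8):
    """Start indices of all patches that fit entirely along one axis."""
    return list(range(0, size - patch + 1, step))

def count_patches(shape, patch=32, step=8):
    """Number of overlapping cubic patches extracted from a volume of the given shape."""
    count = 1
    for size in shape:
        count *= len(patch_starts(size, patch, step))
    return count

# Hypothetical 64 x 64 x 64 volume (not the size of the actual scans).
print(patch_starts(64))               # [0, 8, 16, 24, 32]
print(count_patches((64, 64, 64)))    # 125
```

Even for this small volume, a step of 8 voxels yields 125 overlapping patches instead of the 8 obtained without overlap, which is why a single annotated subject already provides a large number of training patches.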
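|Method||1 image: WM DSC||1 image: WM ASD||1 image: GM DSC||1 image: GM ASD||1 image: CSF DSC||1 image: CSF ASD||2 images: WM DSC||2 images: WM ASD||2 images: GM DSC||2 images: GM ASD||2 images: CSF DSC||2 images: CSF ASD|
|Ours (normal GAN)||0.66||1.75||0.62||1.91||0.81||0.62||0.71||0.96||0.72||0.89||0.82||0.51|
|Ours (FM GAN)||0.74||0.82||0.72||0.85||0.89||0.27||0.80||0.54||0.80||0.58||0.88||0.25|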
To validate the proposed model in a few-shot learning scenario, we trained it end-to-end with only 1 or 2 training examples. The objective is to show that, when training with few labeled examples, our model outperforms U-Net and gives performance close to full-supervision without data augmentation. While training, the model is validated with a single labeled example, thus making the total number of labeled examples no greater than 3. To reduce bias while estimating performance, we repeated this process with 3 different combinations of training and validation examples, while keeping the 7 test examples fixed, and report the average result.
Table I gives the mean DSC and ASD obtained by the 3D U-Net modified as described in Section IV-B1 (Basic U-Net), and our proposed model with standard adversarial loss (Normal GAN), feature matching (FM GAN), or the complementary GAN model of Section III-C (bad-GAN). Results are reported for 1 and 2 labeled training examples. We see that the proposed GAN-based method significantly outperforms basic U-Net when a single labeled example is available, with DSC improvements of 5-8% for WM, 13-23% for GM, and 1-9% for CSF. Important improvements are also observed for 2 labeled examples, with a DSC increase of 3-12%, 10-18% and 1-6% for WM, GM and CSF, respectively. Similarly, we see a significant reduction in ASD for both cases.
Comparing the different GAN models, we find that feature matching (without entropy term) yields the best performance, for all tissue classes and test cases. Compared to bad-GAN, it provides DSC improvements of 5% for WM, 4% for GM and 3% for CSF, in the case of 1 labeled example, and improvements of 6% for WM, 4% for GM and 4% for CSF, when 2 labeled examples are employed. In the next section, we analyze in greater detail the behavior of these two GAN models to better understand these results.
Next, we evaluate the impact of supervision on the performance of 3D U-Net and FM GAN by increasing the number of labeled images in training from 1 to 5. Results of these experiments are plotted in Figure 2. In this experiment, we used a single validation example and a fixed set of 4 test examples. It can be seen that, compared to U-Net, FM GAN gives a higher or equal DSC in all cases, and that the accuracy of models is comparable for 5 labeled examples. Although 5 examples seems like a relatively small number, one should remember that networks are trained using patches sampled over these images, and thus these networks see thousands of training patches.
To visually appreciate the performance of the proposed model, Figure 3 shows the segmentation output of Basic U-Net and FM GAN for two different subjects, when training with 1, 2 or 5 labeled examples. If 1 or 2 labeled examples are used, standard U-Net gives poor results, showing the inability of this model to work in a few shot learning scenario. In contrast, FM GAN can better learn the structure of brain tissues by using unlabeled images. Moreover, following the results of Fig. 2, we see that the segmentation of FM GAN is visually similar to U-Net when 5 labeled images are employed in training.
Figure 4: Loss of U-Net and our FM GAN model at different training epochs, measured on a random subset of validation patches.
Results of the previous experiment showed the proposed model to outperform standard U-Net when very few labeled images are provided in training. In this section, we try to explain how the unlabeled and fake components of the loss function enable such improvements. Moreover, we analyze the tested GAN models to determine which elements contribute to having accurate segmentations.
Figure 4 plots the training losses of U-Net and our FM GAN model, at different training epochs, when using a single labeled example. For U-Net, we show the cross-entropy loss of Eq. (5) and validation error (i.e., mean percentage of incorrectly predicted voxels in randomly selected patches of validation images). In the case of FM GAN, we also report the unlabeled image loss of Eq. (6) and the fake image loss of Eq. (7). These plots clearly show how U-Net, being a high-capacity model, quickly overfits the data. In contrast, our FM GAN model also learns from unlabeled and generated data and, hence, generalizes better to the validation data.
|Method||Unlabeled loss||Fake loss|
|Ours (normal GAN)||0.0015||0.0060|
|Ours (FM GAN)||0.0014||0.0020|
To better assess the impact on segmentation of adding unlabeled and generated images, Table II gives the mean unlabeled and fake loss of the discriminator computed over test data. For this experiment, we extracted labeled patches from test images and generated an equal number of fake patches with the different GAN models. The high fake loss value of simple U-Net confirms that this model cannot discriminate between real and fake data. This limitation of U-Net can also be seen in Fig. 5, which gives the predicted probabilities of U-Net and our FM GAN model for a fake input patch. Unlike U-Net, the proposed model gives a fake class probability near 1 (i.e., white color) for all voxels of the patch.
|Method||WM DSC||WM ASD||GM DSC||GM ASD||CSF DSC||CSF ASD|
|Ours (FM GAN)||0.75||0.96||0.72||1.10||0.55||2.04|
Our results indicate that FM GAN outperforms the more complex bad-GAN model (see Table I and Fig. 6), which also adds an entropy term to have a more diverse distribution of generated examples. For bad-GAN, we only incorporated the variational inference (VI) loss, as the low density enforcement term was not relevant in our setting given the poor sample quality. It was found that adding the VI term  does not improve the performance of FM GAN for semi-supervised 3D image segmentation. One possible explanation for this is the poor sample quality, which is further aggravated when increasing the entropy.
Figure 6 plots the feature matching loss for both the FM GAN and bad-GAN models. It can be seen that the feature matching loss of FM GAN converges quickly and remains lower than that of bad-GAN, indicating a better sample generation. Patches generated by bad-GAN have a higher chance of being far from the true distribution and, hence, we may fail to learn a discriminator with a tight boundary of the true manifold. For example, there might be generated patches which are outside the true manifold but classified as true by the discriminator. This can also be seen in Table II, where the average fake loss of bad-GAN is greater than that of FM GAN. Overall, the fake loss has an important contribution to performance in semi-supervised segmentation. The generator should produce samples that are different from true unlabeled images, while remaining close enough so that the discriminator learns useful information.
To validate our results, we also ran similar experiments on the MRBrainS dataset using just 1 training example, the results of which are listed in Table III. As in previous experiments, we see that the proposed technique outperforms standard U-Net, with DSC improvements of 13.6% for WM, 7.5% for GM, and 34% for CSF. Likewise, our technique also yields a significant reduction in ASD: 46% for WM, 37% for GM, and 38% for CSF. These results suggest the usefulness of our method across different 3D multi-modal segmentation tasks.
We presented a method for segmenting 3D multi-modal images, which can achieve performances comparable to full-supervision with only a few training samples. We showed how the method uses unlabeled data to prevent over-fitting, by learning to discriminate between true and generated fake patches. The proposed model can be employed to enhance any segmentation network in a low data setting, where the network fails to produce a good segmentation output. It also provides a new technique for few-shot learning, obviating the need for an initial pre-trained network by leveraging the semi-supervised learning ability of GANs. Moreover, results on the iSEG-2017 and MRBrains 2013 datasets showed our method’s potential for reducing the burden of acquiring annotated medical data.
Our experiments explored different generator losses and their impact on segmentation performance. We showed empirically that FM GAN performs better than bad-GAN for segmenting 3D multi-modal brain MRI images. Our method can be extended to other 3D multi-modal image segmentation tasks with any state-of-the-art segmentation network as discriminator.
Proceedings of the IEEE conference on computer vision and pattern recognition, 2016, pp. 770–778.
M. Rezaei, K. Harmuth, W. Gierke, T. Kellermeier, M. Fischer, H. Yang, and C. Meinel, “A conditional adversarial network for semantic segmentation of brain tumor,” in International MICCAI Brainlesion Workshop. Springer, 2017, pp. 241–252.