Adversarial Continual Learning for Multi-Domain Hippocampal Segmentation
Deep learning for medical imaging suffers from temporal and privacy-related restrictions on data availability. To still obtain viable models, continual learning aims to train models sequentially, as and when data becomes available. The main challenge that continual learning methods face is preventing catastrophic forgetting, i.e., a decrease in performance on data encountered earlier. This issue makes continuous training of segmentation models for medical applications extremely difficult. Yet data from at least two different domains is often available, which we can exploit to train the model in a way that disregards domain-specific information. We propose an architecture that leverages the simultaneous availability of two or more datasets to learn a disentanglement between content and domain in an adversarial fashion. The domain-invariant content representation then lays the base for continual semantic segmentation. Our approach takes inspiration from domain adaptation and combines it with continual learning for hippocampal segmentation in brain MRI. We showcase that our method reduces catastrophic forgetting and outperforms state-of-the-art continual learning methods.
In medical imaging, privacy regulations and temporal restrictions limit access to data. These limitations inhibit the application of traditional supervised deep learning methods for medical imaging tasks, which require the simultaneous availability of all data during training. Continual learning reframes the problem as a sequential training process, where not all datasets are available at each time step. However, continual learning models still experience a significant drop in performance, caused by catastrophic forgetting, i.e., the model adapting too strongly to particularities of the last training batch.
The leading cause of catastrophic forgetting in medical imaging is data originating from different domains. These domains result from diverse disease patterns among the examined subjects and from divergent technologies and standards used during acquisition. In magnetic resonance imaging (MRI), it is common practice for institutions to operate scanners from various vendors and to employ disparate protocols. Additionally, MRI datasets frequently contain subjects that are either healthy or suffer from various pathological conditions. To sufficiently solve a task with data from multiple domains, models have to adapt and learn in a domain-invariant fashion.
The limited availability of multi-domain data makes developing a general-purpose model for continual hippocampal segmentation difficult. Commonly, however, at least two datasets from different domains are accessible simultaneously due to open access, relaxed restrictions, or access to historical data acquired with older scanners and protocols within the institution. This serendipity allows for learning a disentanglement between the domains and the content needed for segmentation, which continual learning methods do not yet exploit. Inspired by image-to-image translation (I2I), we utilize adversarial training to learn a disentanglement between a domain-invariant content representation sufficient for segmentation and a dataset-specific domain representation [10, 12, 18]. We train an encoder for each representation and share the content encoder with our segmentation module. Finally, we extend our approach to continual learning. In Fig. 1, we describe how our architecture could react to common dataset-availability scenarios; we perform experiments on a subset of those hypothetical scenarios.
We contribute the Adversarial Continual Segmenter (ACS) for continual semantic segmentation of multi-domain data through adversarial disentanglement and latent space regularization that reduces catastrophic forgetting in hippocampal segmentation of brain MRIs.
Several Generative Adversarial Networks (GANs) disentangle the feature space to improve interpretability. Chen et al. take a mutual-information-based approach, while Karras et al. directly modify the generator to achieve automatic separation of high-level attributes. Adversarial disentanglement shows promising results when applied to segmentation in multi-domain and multi-modal settings. In domain adaptation, Kamnitsas et al. utilize domain-invariant features for a segmentation task and learn them with adversarial regularization of numerous layer outputs.
Cross-domain disentanglement: I2I translation extends cross-domain feature disentanglement by splitting the latent space into a content and a style encoding to achieve better translation results. The content encoding is assumed to capture only task-specific information, while the style encoding holds the domain-specific information. Huang et al. and Lee et al. further assume that the complexity of the content outweighs that of the domain, which they reflect in different encoder complexities.
Continual learning: The main problem that continual learning methods face is catastrophic forgetting. To counter it, regularization-based approaches constrain important parameters from changing [3, 28, 16, 19]. With a similar goal, Memory Aware Synapses (MAS) learns importance weights for each parameter and uses those to penalize parameter changes. Knowledge distillation methods try to preserve specific model outputs to retain performance on old data [7, 22]. Keeping a subset of training data is also widely used, e.g., in dynamic memory and rehearsal methods. However, keeping parts of the old data is not feasible in most medical imaging scenarios due to privacy concerns.
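The MAS mechanism described above can be sketched in a few lines of numpy. This is a minimal illustration, not the paper's implementation: the importance estimate (gradient magnitude of the squared output norm) follows the original MAS formulation, while the function names and the single-batch setup are ours.

```python
import numpy as np

def mas_importance(output_grads):
    # MAS estimates a parameter's importance as the (accumulated)
    # magnitude of the gradient of the squared output norm w.r.t. it.
    return {name: np.abs(g) for name, g in output_grads.items()}

def mas_penalty(params, anchor, importance, lam=1.0):
    # Quadratic penalty that discourages important parameters from
    # drifting away from their values after the previous task.
    return lam * sum(
        float(np.sum(importance[n] * (params[n] - anchor[n]) ** 2))
        for n in params
    )
```

Parameters with zero importance can change freely; the penalty grows only where old-task knowledge is concentrated.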
Whereas existing continual learning methods focus solely on a sequential learning process and do not consider the simultaneous availability of datasets and their divergent domains, we specifically exploit these circumstances through adversarial disentanglement.
We first describe how we disentangle input images into content and domain representations, then introduce the adversarial approach. Our domain representation models the heterogeneity of the acquisition modality, e.g., varying protocols and machine vendors, as well as different disease patterns. To learn the domain-invariant content representation, we initially train on two datasets simultaneously, as is common in I2I translation. This representation then acts as the basis for our U-Net segmenter.
Variational autoencoder loss:
We model the domain encoder as a variational autoencoder (VAE), which encodes the domain as a distribution parameterized by mean μ and variance σ². Because the complexity of the domain is assumed to be lower than the content complexity, we limit the dimensionality of the domain representation to a single float value. We use a combination of a reconstruction loss between the input image and the generator output and a Kullback–Leibler regularization term, weighted by a hyperparameter, that draws the encoding close to a standard Gaussian prior N(0, 1).
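A numpy sketch of this VAE objective, assuming the usual closed-form KL term against a standard normal prior; the L1 reconstruction norm and the weight value are our assumptions, not taken from the paper:

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    # Closed-form KL divergence between N(mu, sigma^2) and the
    # N(0, 1) prior, summed over the latent dimensions.
    return -0.5 * np.sum(1 + log_var - mu ** 2 - np.exp(log_var))

def vae_loss(x, x_recon, mu, log_var, kl_weight=0.01):
    # Reconstruction term (L1 here, an assumption) plus weighted KL.
    recon = np.mean(np.abs(x - x_recon))
    return recon + kl_weight * kl_to_standard_normal(mu, log_var)
```

With a one-dimensional domain representation, `mu` and `log_var` are scalars per image, which keeps the domain bottleneck deliberately tight.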
Latent regression loss: To prepare the domain information for the generator input, we first sample from the learned domain distribution and then pass the samples into the latent scale layer proposed by Alharbi et al. to produce a latent domain scale. To give additional information about the domain, we inject a domain code, modeled as a one-hot vector, using central biasing instance normalization [11, 29]. As proposed by Jiang and Veeraraghavan and based on Lee et al. and Huang et al., we also define a latent code regression loss, which constrains the generator to produce unique mappings for each latent code.
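The regression loss can be illustrated with a short sketch: re-encode the generated image and penalize the distance to the injected code, so distinct codes cannot collapse to the same output. The L1 distance and the trivial stand-in for the latent scale layer are assumptions; the real layer from Alharbi et al. is a learned projection.

```python
import numpy as np

def latent_regression_loss(z, z_rec):
    # L1 penalty between the injected latent code and the code
    # recovered from the generated image; forces unique mappings.
    return np.mean(np.abs(z - z_rec))

def latent_scale(sample, scale_weights):
    # Minimal stand-in for a latent scale layer: project the sampled
    # scalar domain value to per-channel scales for the generator.
    return sample * scale_weights
```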
GAN loss: We interpret the encoder of the U-Net as content encoder E_c, with the output of the bottom layer treated as content representation c. Both c and the scaled domain sample are fed into the generator G to reconstruct the input sample. To disentangle content and domain, we introduce adversarial training. We deploy a domain discriminator D that regularizes E_c, the domain encoder E_d, and G by discriminating whether an input image is part of a given domain. D trains on a combination of real and generated images, as described in the corresponding discriminator loss in Eq. 2; G generates these images from the content of real samples and a random domain. E_c, E_d, and G counter the discriminator by minimizing a negative binary cross-entropy loss, shown in Eq. 3. The training forces E_c to produce a domain-invariant output, which we utilize for the segmentation task.
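The two adversarial objectives can be sketched as follows. The text above only specifies binary cross-entropy; the exact non-saturating form used here (generator tries to make fakes look real) is an assumption, and the function names are ours.

```python
import numpy as np

def bce(pred, target, eps=1e-7):
    # Binary cross-entropy on probabilities, clipped for stability.
    pred = np.clip(pred, eps, 1 - eps)
    return -np.mean(target * np.log(pred) + (1 - target) * np.log(1 - pred))

def discriminator_loss(d_real, d_fake):
    # The domain discriminator learns to score real domain images
    # as 1 and generated (cross-domain) images as 0.
    return bce(d_real, np.ones_like(d_real)) + bce(d_fake, np.zeros_like(d_fake))

def adversarial_loss(d_fake):
    # Encoders and generator push the discriminator to label
    # generated images as real, driving the content representation
    # toward domain invariance.
    return bce(d_fake, np.ones_like(d_fake))
```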
Content adversarial loss: Information about the domain can still flow from the content encoder to the segmenter via the skip-connections of the U-Net. To prevent this, we introduce a content discriminator D_c inspired by Jiang et al. and Kamnitsas et al. [11, 12]. D_c regularizes the content representation as well as the skip-connections, because we do not want to leak any domain information to the segmenter. The discriminator design is similar to a reversed U-Net decoder, i.e., it takes the content representation as input together with the skip-connections of the content encoder at each corresponding layer. We train both the domain discriminator and D_c using multi-class cross-entropy losses. In the case of generated images, we represent the domain code as a placeholder class. To ensure that the adversarial training is stable, we train the discriminators and the remaining architecture at separate steps.
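The separate-step training mentioned above can be sketched as a simple schedule. Strict even/odd alternation is an assumption; the text only states that the discriminators and the remaining architecture are updated at separate steps.

```python
def alternating_schedule(num_steps, update_discriminators, update_rest):
    # Alternate between updating the discriminators and updating the
    # remaining architecture (encoders, generator, segmenter), so the
    # two adversarial sides never receive gradients in the same step.
    log = []
    for step in range(num_steps):
        if step % 2 == 0:
            log.append(update_discriminators())
        else:
            log.append(update_rest())
    return log
```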
Segmentation: To produce the segmentation mask, we encode input x into the content representation c through the content encoder. We then pass c and the skip-connections of the content encoder into the segmentation decoder to compute the prediction. We train the U-Net for semantic segmentation through a pixel-wise combination of a Dice and a binary cross-entropy loss between the target mask and the prediction. After initially training the architecture on two or more datasets, the model has sufficiently learned the disentanglement between content and domain. To train on a new dataset, we only fine-tune the last four convolutional layers of the segmenter.
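A numpy sketch of the combined segmentation objective; the equal weighting of the Dice and cross-entropy terms is an assumption, as the paper does not state the mixing coefficient here.

```python
import numpy as np

def dice_loss(pred, target, eps=1e-7):
    # Soft Dice loss over a binary mask; eps avoids division by zero.
    inter = np.sum(pred * target)
    return 1.0 - (2.0 * inter + eps) / (np.sum(pred) + np.sum(target) + eps)

def bce_loss(pred, target, eps=1e-7):
    # Pixel-wise binary cross-entropy on predicted probabilities.
    pred = np.clip(pred, eps, 1 - eps)
    return -np.mean(target * np.log(pred) + (1 - target) * np.log(1 - pred))

def segmentation_loss(pred, target, dice_weight=0.5):
    # Pixel-wise combination of Dice and binary cross-entropy
    # (equal weighting is an assumption).
    return dice_weight * dice_loss(pred, target) + \
        (1 - dice_weight) * bce_loss(pred, target)
```

The Dice term handles the strong class imbalance of small hippocampal masks, while the cross-entropy term provides smooth per-pixel gradients.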
Datasets: All datasets are from different domains and contain T1-weighted MRIs. The first dataset was released as part of the 2018 Medical Segmentation Decathlon challenge and consists of 195 subjects in total, 90 healthy and 105 with non-affective psychotic disorder. The scans were collected using a Philips Achieva scanner, and the mean size of the volumes is . The second dataset was published in Scientific Data and contains 25 healthy subjects, all scanned on 3 Tesla MRI systems. The mean resolution is . Finally, the third dataset is provided by the Alzheimer's Disease Neuroimaging Initiative and consists of 68 subjects that are either part of the control group or suffer from mild cognitive impairment or Alzheimer's disease. The images were acquired with scanners from Siemens, GE, and Philips with 23, 24, and 21 scans, respectively. The mean volume size is . All three datasets provide reference segmentation masks for the hippocampus, annotated manually with the protocols defined in the respective publications. We evaluate our architecture on all three datasets, which we refer to as A, B, and C, respectively.
We split each dataset into 70% train, 20% test, and 10% validation and use the latter to select the hyperparameters. We train slice-by-slice and upsample via bilinear interpolation to achieve uniform slices. We compare ACS with the following baselines: first, just the U-Net block of ACS (U-Net-b) shown in Fig. 2; second, a standard U-Net. Furthermore, we extend the U-Net by knowledge distillation on the output layer (OL-KD), as proposed by Michieli and Zanuttigh, and by Memory Aware Synapses adapted to brain segmentation (BS-MAS) by Özgün et al. As suggested in BS-MAS, we divide the surrogate loss by the number of network parameters and normalize the resulting importance values between zero and one. We report the Intersection over Union (IoU) and Dice coefficient on the hippocampus class of the test set. We use a batch size of 40 and train on four Tesla V100 SXM3 GPUs. Each method is trained for 60 epochs; after 30 epochs, training continues with only the third dataset. We repeat training for every combination of the three datasets (AB-C, AC-B, BC-A), e.g., initial training on datasets A and B, then on C (AB-C). Additionally, we jointly train ACS and the U-Net on all datasets simultaneously (ABC). To justify the necessity of all mechanisms in our method, as described in Sec. 3, we conduct an ablation study in Tab. 3. Implementation details and qualitative results, including the disentanglement, can be found in the supplementary material and code on github.com/MECLabTUDA/ACS.
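The two reported metrics can be computed from binary masks as follows; the convention of returning 1.0 when both masks are empty is our assumption.

```python
import numpy as np

def dice_coefficient(pred, target):
    # Dice = 2|P ∩ T| / (|P| + |T|) on boolean masks.
    inter = np.sum(pred & target)
    denom = np.sum(pred) + np.sum(target)
    return 2.0 * inter / denom if denom else 1.0

def iou(pred, target):
    # IoU = |P ∩ T| / |P ∪ T| on boolean masks.
    inter = np.sum(pred & target)
    union = np.sum(pred | target)
    return inter / union if union else 1.0
```

Both metrics are computed only on the hippocampus class, so background pixels never inflate the score.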
To assess the continual learning performance, we evaluate the results after stages 1 and 2, corresponding to Fig. 1(a). An ideal algorithm should perform equally well or better on the initial training datasets from epoch 30 to 60, while improving on the third dataset added after 30 epochs.
Stage 1: Tab. 1 shows the results for all methods after 30 epochs on the initial two datasets. All baselines achieve the same score because they apply their regularization only in the second training stage, whereas ACS performs disentanglement during the initial training phase. ACS outperforms them by a Dice of (IoU ) averaged over all combinations and datasets.
(Table 1: per-method Dice and IoU on Dataset A, Dataset B, Dataset C, and the Average after stage 1.)
Stage 2: To measure overall continual learning performance, i.e., the combination of learning and forgetting, we inspect the average scores over all datasets after 60 epochs in Tab. 2. While the comparison methods’ results fluctuate, our approach achieves a consistently higher performance across all combinations and datasets. This observation manifests in an increase of the average Dice score by over the U-Net, over the U-Net-b, over BS-MAS, and over OL-KD. On combination AB-C, the U-Net drops by an IoU of (Dice ) on dataset A and by (Dice ) on dataset B. The remaining methods, including ACS, show a significantly lower decline and effectively reduce catastrophic forgetting.
(Table 2: per-method Dice and IoU on Dataset A, Dataset B, Dataset C, and the Average after stage 2.)
Combination AC-B shows the clear advantage of our approach. Dataset A contains four types of disorders recorded by a single scanner, while dataset C holds three disease patterns recorded by three different scanners. The baselines struggle with the diversity of these domains, and our model outperforms them by an IoU of (Dice ). These observations show that our model learns a content representation that can handle diverse cognitive impairments and scans acquired with scanners from various vendors.
We trace the low performance on dataset A in combination BC-A back to A exceeding B and C in variability and number of subjects. The U-Net's high performance thereby originates from overfitting on A, which, through its high variability, still allows it to perform well on B and C. Because ACS is only fine-tuned on A, it cannot fully exploit this anomaly, but it still shows competitive results.
Ablation study: The ablation study in Tab. 3 verifies that all losses contribute to the performance of ACS. Only on AC-B does the combination of all losses underperform slightly, but it remains competitive. For more detailed numbers, we direct the reader to the supplementary material.
The results demonstrate that leveraging the availability of multiple datasets increases multi-domain segmentation performance by sufficiently learning a domain-invariant representation. This assumption is further supported by the joint training results in Tab. 2 showing the superior capability of ACS in comparison to the U-Net. Additionally, our method outperforms the state-of-the-art on most continual learning setups and effectively reduces catastrophic forgetting.
We propose ACS, an architecture for continual semantic segmentation of multi-domain data that leverages the simultaneous availability of datasets. In real clinical practice, multiple datasets are available at the beginning of the continual training process through, among other sources, public or accessible historical data. Unlike current methods, we leverage this serendipity to disentangle MR images into content and domain representations through adversarial training. We then perform multi-domain hippocampal segmentation directly on the domain-invariant content representation. We demonstrate drastic improvements through domain disentanglement of multi-domain data in the first training stage. In the second training stage, the benefits of our proposal for continual learning become clear: using all available data reduces catastrophic forgetting and outperforms current state-of-the-art methods. Our method pushes continual learning closer towards clinical application, where sources of variability such as disease patterns, scanner vendors, and acquisition protocols exist, and further enables the continual usage of deep learning models in clinical practice.
Kirkpatrick et al. Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences 114 (13), pp. 3521–3526. Cited by: §1, §2.
Denoising scanner effects from multimodal MRI data using linked independent component analysis. NeuroImage 208, pp. 116388. Cited by: §1.