Boosting segmentation with weak supervision from image-to-image translation

04/02/2019 · by Eugene Vorontsov, et al.

In many cases, especially with medical images, it is prohibitively challenging to produce a sufficiently large training sample of pixel-level annotations to train deep neural networks for semantic image segmentation. On the other hand, some information is often known about the contents of images. We leverage information on whether an image presents the segmentation target or whether it is absent from the image to improve segmentation performance by augmenting the amount of data usable for model training. Specifically, we propose a semi-supervised framework that employs image-to-image translation between weak labels (e.g., presence vs. absence of cancer), in addition to fully supervised segmentation on some examples. We conjecture that this translation objective is well aligned with the segmentation objective as both require the same disentangling of image variations. Building on prior image-to-image translation work, we re-use the encoder and decoders for translating in either direction between two domains, employing a strategy of selectively decoding domain-specific variations. For presence vs. absence domains, the encoder produces variations that are common to both and those unique to the presence domain. Furthermore, we successfully re-use one of the decoders used in translation for segmentation. We validate the proposed method on synthetic tasks of varying difficulty as well as on the real task of brain tumor segmentation in magnetic resonance images, where we show significant improvements over standard semi-supervised training with autoencoding.


1 Introduction

Deep neural networks perform well at semantic object segmentation of natural images, but they require a large quantity of pixel-level annotations. Obtaining a sufficient quantity of annotations is difficult and sometimes impractical; on the other hand, unlabeled or weakly categorized data is easier to obtain. This motivates the need for weakly or semi-supervised models that can leverage unlabeled or weakly labeled data.

Figure 1: Left: Images presenting digits transformed to images with only the background clutter, a residual image that isolates the digit, and a segmentation of the digit. Right: Images presenting cancer lesions in the brain are transformed to healthy images, a residual image that isolates the lesion, and a segmentation of the lesion.

Many works have explored the use of generative adversarial networks (GANs) to improve semantic segmentation of medical images. However, while these methods make better use of the training data by either improving the training objective [23, 9, 42, 30, 39, 47, 43, 35, 25] or performing data augmentation within the training set [14, 31, 20], they do not augment the training set to better cover the variations in the data population. On the other hand, some works have explored unsupervised anomaly localization using autoencoding [5] or GANs [37, 7] to learn a generative model of healthy cases. Another GAN-based approach is to train an error model that can be used for updates on unlabeled data [45]. These approaches are approximate and do not make full use of the available weak labels (healthy and sick domain labels). Making better use of the available data, some recent approaches relied on image-to-image translation between sick and healthy cases [4, 2], but these were unsupervised and either approximate or not validated against baselines or on multiple tasks.

We focus on the common scenario in medical imaging, where a large number of images lack segmentation labels but are known to be either healthy or sick cases. This knowledge can be considered as weak proxy labels that identify whether there is something to be segmented in an image. For example, when segmenting cancerous lesions, images marked ‘healthy’ do not contain cancer while images marked ‘sick’ do. We argue that the objective of translating from sick to healthy images is a good unsupervised surrogate for segmentation. Consequently, we develop a semi-supervised segmentation method with image-to-image translation, trained on unpaired images from sick and healthy domains. (Code for model and experiments: [hidden during review period].)

Considering the sick domain as a superset of the variations in the healthy domain, we encode images into two latent codes: variations that are common to both and variations that are unique to the sick domain. This allows us to split decoding into two parts: (1) a ‘healthy’ image decoder that interprets the common latent code and (2) a residual decoder that additionally considers the unique code in order to compute a residual change to the ‘healthy’ output image, making it ‘sick’.

Because the output of the residual decoder is highly correlated with the segmentation output, we can re-use this decoder for segmentation. In doing so, we maximize the proportion of model parameters that receive updates even when no pixel-level annotations are available to guide image segmentation during training. Examples of these mappings, including both decoders and the segmentation output, are shown in Figure 1. Furthermore, whereas image-to-image translation models typically do not use long skip connections from the encoder to the decoder, we propose a long skip connection variant for our method. Long skip connections are common in supervised encoder-decoder models [10], where they help preserve spatial detail in the decoder even when the encoding is very deep. Overall, we summarize our contributions as follows:

  • We propose a semi-supervised segmentation method leveraging image-to-image translation.

  • We propose a new variant of long skip connections, from the encoder to the decoders, for image-to-image translation.

  • We propose a dual-function decoder (translation, segmentation), thus maximizing the number of parameters updated in the absence of pixel-level annotations.

  • We validate our method on challenging synthetic data and real brain tumor MR images, significantly improving over well-tuned baselines.

2 Related works

Image-to-image translation. Image-to-image translation was most prominently demonstrated with CycleGAN [46], which performs bidirectional translation between two domains. UNIT [26] proposed a similar approach but with a common latent space, shared by both domains, from which latent codes could be sampled. Augmented CycleGAN [1] and Multimodal UNIT [19] respectively extended these two methods from one-to-one mappings to many-to-many.

Disentangling domain-specific variations. Both [19] and [24] present methods that learn shared and domain-specific latent codes. These differ from the proposed method in that they do not segment and do not assume (and benefit from) an absence domain that is a subset of a presence domain. In addition, their domain-specific “style” codes are encoded with a shallow network, which may bias the model to indeed learn domain-specific styles; the proposed method, in contrast, uses deep encodings for all codes. Explicit disentangling of variations between these codes has recently been proposed in [13] by way of a gradient reversal layer [12].

Data augmentation. GANs are used to augment liver lesion examples for classification in [11]. [22] synthesize data to cover uncommon cases such as peripheral nodules touching the lung boundary. [14] and [31] introduce a segmentation mask generator to augment small training datasets.

Anomaly localization. Generative models have been used to fit the distribution of healthy images in order to find anomalies. To localize lesions in brain MR images that are known to be either healthy or sick, [5] fit the healthy data distribution with an autoencoder. Given an image presenting a lesion, the lesion is localized via the residual of its reconstructed image, which is likely to appear healthy. Similarly, [37] and [7] employ GANs to locate anomalies in retinal images and brain MR images, respectively. While these models require that weak ‘sick’ or ‘healthy’ labels are known, they are trained only on the latter. Furthermore, they allow only rough unsupervised localization.

Image-to-image translation for segmentation. By translating from sick to healthy images, [4] trains a network to localize Alzheimer’s derived brain morphological changes using the output residual. [2] further proposes a multi-modal variant of CycleGAN [46] to translate in both directions, applied to brain MR images with cancer. Sick images that are translated to healthy images are translated back to the original sick image via a residual inpainting of the lesions. Lesions are localized and segmented by predicting a minimal region to which to apply inpainting. Segmentation is unsupervised, with a prior that minimizes the inpainting region. This method has not been compared to other unsupervised methods, has been tested on a single dataset, and has not been extended to a weakly- or semi-supervised setting. Our work differs in that we develop a semi-supervised architecture that uses fewer parameters by reusing mappings, we skip information from the encoder to decoders, we propose a decoder that is trained with both translation and segmentation objectives, and we validate the method on multiple tasks.

Adversarial semi-supervised segmentation. A semi-supervised segmentation method for medical images was proposed by [45], where a discriminator learns a segmentation error signal on the annotated dataset which can be applied on unannotated data. This method may be limited in how well it could scale with the proportion of unannotated data since the discriminator’s behaviour may not generalize well beyond the annotated dataset on which it is trained. Because this method can be applied to the output of any segmentation model, we consider it complementary to our proposed method.

3 Methods

Segmentation labels are typically available for an insufficiently representative sample of data. We propose a semi-supervised method that extends supervised segmentation to weakly labeled data using a domain translation objective. In addition to a segmentation objective, the method attempts to translate between the distribution of images presenting the segmentation target (P) and the distribution of images where this target is absent (A).

3.1 Translation, segmentation, and autoencoding

Translating between images where the segmentation target object is present or absent requires a model to localize the target. It follows then that in order to add, remove, or modify the target in an image, the variations caused by it should be disentangled from everything else (Figure 2, left). We conjecture that segmentation relies on the same disentangling and that this is the most difficult part of both objectives. Thus, we identify domain translation as an unsupervised surrogate loss for segmentation. We propose an encoder-decoder model that extends segmentation with image-to-image translation. In addition, we leverage the similarity between these two objectives to employ a decoder that is shared by both.

Although domain translation aligns well with segmentation, the canonical objective for unsupervised feature learning is autoencoding of the model input. A deep autoencoder may disentangle causal features of an image; that is, encoding the image may yield information about the features that produce it (Figure 2, right). When labels could be considered to cause the image, one would expect autoencoding to learn features that are useful for classification or segmentation [38]. Indeed, [32] recently won the Brain Tumor Segmentation challenge (2018) by augmenting a fully convolutional segmentation network with an autoencoding objective. This objective is easy to set up and train for. Unlike with domain translation, no knowledge about the images’ domain is required. On the other hand, information about presence (P) or absence (A) of the segmentation target in the image may guide a domain translation objective to more specifically isolate the variations that are important for segmentation [13].

Figure 2: Left: Translating images from a domain presenting the segmentation target object (Presence) to one in which it is absent (Absent) involves disentangling the object’s variations from the rest. The former is useful for segmentation, the latter for producing an image without the object. Right: Autoencoding may produce disentangled features (F) that are useful but not optimal for segmentation.

3.2 Our method

Figure 3: Framework overview of simultaneous segmentation, image translation and reconstruction. Images are transformed from the presence domain into the absence domain. Transformations are evaluated by a discriminator (not shown). The encoder and each decoder share skip connections for higher quality image generation.
Figure 4: Image-to-image translation from the absence domain to the presence domain. The common code extracted by the encoder is used to reconstruct the input image. The unique code is sampled from a Normal distribution and concatenated to the common code to produce a residual image which, when added to the reconstructed image, yields a new image in the presence domain. We cycle the image back through the encoder and the common decoder to ensure that the reconstructed image remains unchanged.

The proposed model builds on an encoder-decoder fully convolutional network (FCN) segmentation setup by introducing translation between a domain of images presenting the segmentation target (P) and a domain where it is absent (A), as in Figure 3. The encoder separates variations into those that are common to both A and P and those that are unique to P; essentially, P is a superset of the variations in A. For example, in the case of medical images of cancer, both A and P contain the same organs but P additionally contains cancerous lesions.

Latent code decomposition. Starting with images $x_A$ and $x_P$ in domains A and P, the encoder $E$ yields common ($c$) and unique ($u$) codes:

$[c_A, u_A] = E(x_A), \qquad [c_P, u_P] = E(x_P).$   (1)

This decomposition of the latent codes is reminiscent of the style and content decomposition in [19] or the domain-specific codes in [24].

Presence to absence translation. Translation is achieved by selectively decoding from the latent codes $c$ and $u$. A common decoder $g_{com}$ uses only the common variations, $c$, to generate images in A:

$x_{AA} = g_{com}(c_A), \qquad x_{PA} = g_{com}(c_P),$   (2)

where $x_{AA}$ is essentially an autoencoding of $x_A$, whereas $x_{PA}$ is a translation of $x_P$ to the A domain, where the segmentation target is removed. With this translation, the target variations can be recovered separately by computing a residual change $\Delta_P$ to $x_{PA}$ that reconstructs $x_P$ as $x_{PP}$. This is done with a second, residual decoder $g_{res}$ which uses both the common variations and those unique to P (see Figure 3):

$\Delta_P = g_{res}(c_P, u_P), \qquad x_{PP} = x_{PA} + \Delta_P.$   (3)

The residual decoder requires all latent codes, $[c_P, u_P]$, as its input because the manifestation of unique variations in the image space is dependent on the common variations. For example, the way cancer manifests in a brain scan depends on the location and structure of the brain in the scan. Note also that because the common decoder only uses the common latent code, the encoder must learn to disentangle common and unique variations.
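For illustration, a minimal PyTorch sketch of the presence-to-absence pass in Equations (1)-(3) could look as follows. The stand-in modules, channel sizes, and the channel-wise split of the bottleneck are assumptions made for the sketch (and the long skip connections between the encoder and the decoders are omitted); they are not the exact implementation.

```python
import torch
import torch.nn as nn

# Stand-in modules only; the real encoder and decoders are the FCNs detailed in the
# appendix and also exchange long skip connections, which this sketch omits.
encoder = nn.Conv2d(1, 512, 3, stride=4, padding=1)         # x -> latent map z
common_decoder = nn.ConvTranspose2d(384, 1, 4, stride=4)    # c -> image in domain A
residual_decoder = nn.ConvTranspose2d(512, 1, 4, stride=4)  # [c, u] -> residual image

def encode(x, n_unique=128):
    """Eq. (1): encode and split the bottleneck into common and unique codes."""
    z = encoder(x)
    c, u = z[:, :-n_unique], z[:, -n_unique:]
    return c, u

def translate_p_to_a(x_p):
    """Eqs. (2)-(3): remove the target, then recover it as a residual."""
    c_p, u_p = encode(x_p)
    x_pa = common_decoder(c_p)                                  # Eq. (2): x_P -> A domain
    delta_p = residual_decoder(torch.cat([c_p, u_p], dim=1))    # Eq. (3): residual change
    x_pp = x_pa + delta_p                                       # reconstruction of x_P
    return x_pa, x_pp, delta_p

x_p = torch.randn(2, 1, 48, 48)             # a small batch of "presence" images
x_pa, x_pp, delta_p = translate_p_to_a(x_p)
```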

Segmentation. The $c_P$ and $u_P$ codes, or the residual $\Delta_P$, contain sufficient information for segmentation. Indeed, we reuse the residual decoder, used with $[c_P, u_P]$, for segmentation. We parameterize a segmentation decoder in terms of the residual decoder $g_{res}$, with segmentation-specific per-layer instance normalization [40] parameters $\sigma$:

$\hat{y} = f\big(\tilde{g}_{res}(c_P, u_P;\, \sigma)\big),$   (4)

where $f$ is a pixelwise classification layer and $\tilde{g}_{res}$ is a subset of the $g_{res}$ network that contains all but the last layer, using normalization parameters $\sigma$ produced from the latent code by a multi-layer perceptron (MLP):

$\sigma = \mathrm{MLP}(c_P, u_P).$   (5)
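A minimal sketch of this mechanism, in the spirit of adaptive instance normalization [18]: an MLP maps the (pooled) latent code to per-layer scale and shift parameters that replace the decoder's own normalization parameters during segmentation. Layer counts, hidden sizes, and the pooling of the latent code are assumptions of the sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveInstanceNorm2d(nn.Module):
    """Instance norm whose scale and shift are supplied externally, per sample."""
    def forward(self, h, gamma, beta):
        h = F.instance_norm(h)                                  # normalize each channel
        return gamma[:, :, None, None] * h + beta[:, :, None, None]

class NormParamMLP(nn.Module):
    """Eq. (5): map the latent code to (gamma, beta) pairs for each decoder layer."""
    def __init__(self, code_dim, channels_per_layer, hidden=256):
        super().__init__()
        self.channels_per_layer = channels_per_layer
        n_out = 2 * sum(channels_per_layer)                     # a gamma and a beta per channel
        self.net = nn.Sequential(
            nn.Linear(code_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_out),
        )

    def forward(self, code):
        p = self.net(code)
        params, i = [], 0
        for ch in self.channels_per_layer:
            gamma, beta = p[:, i:i + ch], p[:, i + ch:i + 2 * ch]
            params.append((1 + gamma, beta))                    # center the scale around 1
            i += 2 * ch
        return params

# Usage sketch: pool the latent map [c, u], predict the parameters, and apply them
# in place of the first decoder layer's own normalization.
z = torch.randn(2, 512, 12, 12)                 # latent map from the encoder
code = z.mean(dim=(2, 3))                       # pooled code fed to the MLP
mlp = NormParamMLP(code_dim=512, channels_per_layer=[256, 128, 64, 32])
adain = AdaptiveInstanceNorm2d()
gamma, beta = mlp(code)[0]
h = adain(torch.randn(2, 256, 12, 12), gamma, beta)
```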

Absence to presence translation. Finally, we conclude the set of autoencoding and translation equations with the translation of images in A to images in P (Figure 4). We note that although these translations are not useful for segmentation, they are useful during training since they effectively augment the training updates that our encoders and decoders can receive. Since P contains variations additional to those found in A, we must either take these variations from another image or sample them from a prior distribution:

$x_{AP} = x_{AA} + g_{res}(c_A, \tilde{u}), \qquad \tilde{u} \sim \mathcal{N}(0, I).$   (6)

Here, $x_{AP}$ requires a sample $\tilde{u}$ from a zero-mean, unit-variance prior over the unique variations, $\mathcal{N}(0, I)$. Note that unlike in a variational autoencoder, the encoder does not parameterize a conditional distribution over the unique variations but rather encodes a sample directly. We ensure that the distribution of encoded samples matches the prior by making the unique code re-encoded from $x_{AP}$ match $\tilde{u}$, as detailed further in the description of our training objective. The translation of $x_{AP}$ back to $x_{APA}$ completes a cycle as in [46]. When $x_{APA}$ must match $x_{AA}$, this ensures that the translations retain information about their source images, ensuring that the encoder and decoders do not learn trivial functions. As shall be seen below, this is already achieved by other objectives, making the cycle optional.

Total loss. The training objective consists of a segmentation loss combined with four translation losses, each weighted by some scalar $\lambda$:

$\mathcal{L} = \lambda_{seg}\mathcal{L}_{seg} + \lambda_{rec}\mathcal{L}_{rec} + \lambda_{lat}\mathcal{L}_{lat} + \lambda_{cyc}\mathcal{L}_{cyc} + \lambda_{adv}\mathcal{L}_{adv}.$   (7)

Segmentation loss. We use the Dice loss for segmentation, as in [10, 28], which measures the overlap between the predicted segmentation $\hat{y}$ and the reference segmentation $y$:

$\mathcal{L}_{seg} = 1 - \frac{2\sum_i \hat{y}_i\, y_i}{\sum_i \hat{y}_i + \sum_i y_i}.$   (8)
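A soft Dice loss along the lines of Equation (8) can be written in a few lines of PyTorch; the smoothing constant added to the numerator and denominator is an implementation convenience, not something specified above.

```python
import torch

def soft_dice_loss(pred, target, smooth=1e-6):
    """Soft Dice loss over a batch of per-pixel foreground probabilities.

    pred, target: tensors of shape (B, 1, H, W) with values in [0, 1].
    Returns 1 - Dice, so that 0 means perfect overlap.
    """
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    intersection = (pred * target).sum(dim=1)
    dice = (2 * intersection + smooth) / (pred.sum(dim=1) + target.sum(dim=1) + smooth)
    return 1 - dice.mean()

# Example: a random prediction scored against a random binary mask.
loss = soft_dice_loss(torch.rand(4, 1, 48, 48), (torch.rand(4, 1, 48, 48) > 0.5).float())
```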

Reconstruction losses. To ensure that the encoder and decoders can cover the distribution of images, we reconstruct input images:

$\mathcal{L}_{rec} = \|x_{AA} - x_A\|_1 + \|x_{PP} - x_P\|_1.$   (9)

Similarly, we reconstruct the latent codes so as to ensure that their distributions match across domains A and P or, in the case of unique codes, match the prior:

$\mathcal{L}_{lat} = \|\hat{c}_{PA} - c_P\|_1 + \|\hat{c}_{AP} - c_A\|_1 + \|\hat{u}_{AP} - \tilde{u}\|_1,$   (10)

where $\hat{c}_{PA}$, $\hat{c}_{AP}$, and $\hat{u}_{AP}$ are the codes obtained by re-encoding the translated images $x_{PA}$ and $x_{AP}$.

We define a cycle consistency loss for the A→P→A cycle:

$\mathcal{L}_{cyc} = \|x_{APA} - x_{AA}\|_1,$   (11)

where $x_{APA} = g_{com}(\hat{c}_{AP})$ is the translation of $x_{AP}$ back to the A domain.

Note that there is no P→A→P cycle since, in the proposed method, this is equivalent to the P→P reconstruction, as can be seen in Figure 3. Because both the images and their latent codes are reconstructed, the cycle consistency loss is optional.

We use the $L_1$ distance for all reconstruction losses.
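For illustration, the non-adversarial translation terms of Equation (7) reduce to $L_1$ distances between tensors produced during the forward pass; a minimal sketch, with our own variable names and with the loss weights left as placeholders:

```python
import torch.nn.functional as F

def translation_losses(out, lambdas):
    """Non-adversarial translation terms of Eq. (7), all using L1 distances (Eqs. 9-11).

    `out` is a dict of tensors from one forward pass; `lambdas` holds the loss
    weights (their tuned values are not reproduced here).
    """
    l1 = F.l1_loss
    loss_rec = l1(out["x_aa"], out["x_a"]) + l1(out["x_pp"], out["x_p"])      # Eq. (9)
    loss_lat = (l1(out["c_hat_pa"], out["c_p"])                               # Eq. (10)
                + l1(out["c_hat_ap"], out["c_a"])
                + l1(out["u_hat_ap"], out["u_tilde"]))
    loss_cyc = l1(out["x_apa"], out["x_aa"])                                  # Eq. (11), optional
    return (lambdas["rec"] * loss_rec
            + lambdas["lat"] * loss_lat
            + lambdas["cyc"] * loss_cyc)
```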

Adversarial loss. Finally, we use the hinge loss for the adversarial objective, together with spectral norm on the encoder and decoders, as in [44]:

$\mathcal{L}_{adv}^{D_t} = \mathbb{E}_{x_t}\big[\max(0,\, 1 - D_t(x_t))\big] + \mathbb{E}_{\hat{x}_t}\big[\max(0,\, 1 + D_t(\hat{x}_t))\big], \qquad \mathcal{L}_{adv}^{G_t} = -\,\mathbb{E}_{\hat{x}_t}\big[D_t(\hat{x}_t)\big],$   (12)

where, for each domain $t \in \{A, P\}$, $G_t$ is the generator network for some generated image $\hat{x}_t$ and $D_t$ is a discriminator network which discriminates between real data $x_t$ and generated data $\hat{x}_t$.
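A short PyTorch sketch of this hinge formulation for one translation direction; attaching spectral normalization to a layer is shown at the end for completeness:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def discriminator_hinge_loss(d_real, d_fake):
    """Hinge loss for the discriminator: push real scores above +1, fake scores below -1."""
    return F.relu(1 - d_real).mean() + F.relu(1 + d_fake).mean()

def generator_hinge_loss(d_fake):
    """Hinge (non-saturating) loss for the generator: raise the scores of generated images."""
    return -d_fake.mean()

# Example with dummy discriminator outputs (one scalar score per image).
d_real, d_fake = torch.randn(8), torch.randn(8)
loss_d = discriminator_hinge_loss(d_real, d_fake)
loss_g = generator_hinge_loss(d_fake)

# Spectral normalization can be attached to a convolution like so:
sn_conv = nn.utils.spectral_norm(nn.Conv2d(64, 64, 3, padding=1))
```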

3.3 Baseline methods

The proposed method is compared against two baseline approaches: a fully supervised, fully convolutional network (FCN) and the same network augmented with a reconstruction objective for semi-supervised training. To ease comparison, all models (baselines and proposed) share the same encoder and decoder architectures. The fully supervised method uses an encoder with a single decoder (“Segmentation only” in Table 1). This is equivalent to the proposed method with only the segmentation loss, using only the residual decoder. The semi-supervised method (“AE baseline” in Table 1) adds an additional decoder that reconstructs the input.

3.4 Compressed long skip connections

Figure 5: Compressed skip connection as a way to limit information bypass while preserving spatial detail.

In all models, including baselines, every decoder accepts long skip connections from the encoder, as in [10]. These connections skip features from each layer in the encoder to the corresponding layer in the decoder, except for the first and the last layers. Because long skip connections make autoencoding trivial, they are not used with the reconstruction decoder in the semi-supervised baseline method; however, skip connections are used between the encoder and the segmentation decoder.

Typically, feature maps from the encoder are either directly summed with [10] or concatenated to [36] those in the decoder. We propose a modified variant of long skip connections in which any stack of feature maps is first compressed (via convolution) to a single map before concatenation (see Figure 5). We note that concatenating all feature maps is computationally costly and appears to increase training time for image translation, whereas summing feature maps makes the image translation task very difficult to learn. To further stabilize training, all features skipped from the encoder are normalized with instance normalization. We find that these long skip connections help train the model faster and help produce higher quality image outputs even with a deep encoder.
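A sketch of one way this compressed long skip connection could be implemented: the encoder feature stack is reduced to a single channel with a 1×1 convolution, instance-normalized, and concatenated to the decoder features. Module names and sizes are illustrative.

```python
import torch
import torch.nn as nn

class CompressedSkip(nn.Module):
    """Compress an encoder feature stack to one map before concatenating it to the decoder."""

    def __init__(self, enc_channels):
        super().__init__()
        self.compress = nn.Conv2d(enc_channels, 1, kernel_size=1)  # stack -> single map
        self.norm = nn.InstanceNorm2d(1)                           # stabilize skipped features

    def forward(self, enc_feat, dec_feat):
        skip = self.norm(self.compress(enc_feat))
        return torch.cat([dec_feat, skip], dim=1)    # decoder layer sees one extra channel

# Example: skip a 64-channel encoder map into a 64-channel decoder map of the same resolution.
skip = CompressedSkip(enc_channels=64)
out = skip(torch.randn(2, 64, 24, 24), torch.randn(2, 64, 24, 24))  # -> (2, 65, 24, 24)
```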

Figure 6: Examples of images from the synthetic MNIST datasets. Samples from the presence domain and corresponding ground truth segmentations are in the first and second rows; unrelated samples from the absence domain are in the third row.
Figure 7: Example of image translation and segmentation for cluttered MNIST.

4 Experiments

In this section, we evaluate our proposed semi-supervised segmentation method on both synthetic and real data. We confirm that this method outperforms the fully supervised and semi-supervised segmentation baselines. We present examples of image translation results to illustrate the correlation between segmentation and image translation tasks. We begin with a synthetic MNIST-based task that simulates the common situation, where some images are known to present the segmentation target (P), while in others it is known to be absent (A). With this data, we experiment with increasing the difficulty of the segmentation task. Then, we proceed to evaluate the proposed method on brain tumour segmentation on real MRI data (BraTS).

                    MNIST                                       BraTS
                    48×48 simple   48×48 hard    128×128        240×120
Segmentation only   0.61 (0.01)    0.36 (0.01)   0.15 (0.01)    0.69 (0.04)
AE baseline         0.75 (0.01)    0.49 (0.02)   0.57 (0.02)    0.73 (0.02)
Proposed            0.79 (0.01)    0.57 (0.00)   0.65 (0.01)    0.79 (0.02)

Table 1: Segmentation Dice scores for synthetic MNIST and real BraTS segmentation tasks: mean (standard deviation).

Figure 8: Example of image segmentation and translation from Presence to Absence domains for BraTS. Different MRI sequences (image channels) are arranged in columns.

4.1 Cluttered MNIST

We construct a synthetic task for digit segmentation using MNIST digits, similar to the cluttered MNIST dataset in [21]. Each image in P contains a complete, randomly positioned digit placed on a background of clutter. The clutter is produced from randomly cropped digits within the same data fold (training, validation, or test set). In all experiments, we used crops of 10×10 pixels. We dither regions where MNIST digits or clutter components overlap, so as to prevent models from identifying these boundaries.
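A rough NumPy sketch of how such a cluttered example could be assembled; the dithering of overlap regions described above is omitted, and compositing by taking the pixelwise maximum is an assumption of the sketch.

```python
import numpy as np

def make_cluttered_example(digits, labels, size=48, n_clutter=8, crop=10, rng=None):
    """Place one full random digit on a canvas and scatter random 10x10 digit crops as clutter.

    digits: array (N, 28, 28) of MNIST images scaled to [0, 1]; labels: array (N,).
    Returns the image, its digit segmentation mask, and the digit label.
    """
    rng = rng or np.random.default_rng()
    canvas = np.zeros((size, size), dtype=np.float32)

    # Scatter clutter: small random crops taken from random digits.
    for _ in range(n_clutter):
        d = digits[rng.integers(len(digits))]
        y0, x0 = rng.integers(0, 28 - crop, size=2)
        patch = d[y0:y0 + crop, x0:x0 + crop]
        y, x = rng.integers(0, size - crop, size=2)
        canvas[y:y + crop, x:x + crop] = np.maximum(canvas[y:y + crop, x:x + crop], patch)

    # Place one complete digit and record its segmentation mask.
    i = rng.integers(len(digits))
    y, x = rng.integers(0, size - 28, size=2)
    canvas[y:y + 28, x:x + 28] = np.maximum(canvas[y:y + 28, x:x + 28], digits[i])
    mask = np.zeros_like(canvas)
    mask[y:y + 28, x:x + 28] = (digits[i] > 0.5).astype(np.float32)
    return canvas, mask, labels[i]
```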

We tested the proposed model and the baseline methods on three variants of the cluttered MNIST task, at two resolutions: 48×48 simple, with 8 pieces of clutter; 48×48 hard, with 24 pieces of clutter; and 128×128, with 80 pieces of clutter. Samples from these generated datasets are shown in Figure 6; all datasets were generated prior to training. In all experiments, we provided reference segmentations for 1% of the training examples. In addition, to mimic the issue of small training datasets where the training set fails to cover all modes of variation of the data population, we trained on reference segmentations only for the digit 9.

As shown in Table 1, the proposed model significantly outperforms both the semi-supervised and the fully supervised baselines. The improvement is greater for the harder variants of the task: 48×48 hard and 128×128. We suspect that the greater improvement on the 48×48 hard task may be due to the greater difficulty in separating digits from clutter as compared to 48×48 simple, in which case the translation objective that seeks to specifically disentangle digit variations from clutter variations should be particularly helpful.

Examples of image translation and segmentation are shown in Fig. 7. We first discuss presence to absence translation. For the MNIST 48×48 simple case, almost the entire digit is removed from the image during translation and the residual is similar to the segmentation result, which supports our conjecture that translation is a good surrogate for the segmentation task. The MNIST 48×48 hard dataset has very challenging images, in which localizing the true digit is very difficult even for people. The model learned to partially remove the digit in order to fulfill the GAN objective. Therefore, the residual does not contain the entire digit; however, it attends to the correct location in the image, which may guide segmentation. We note that a digit does not need to be completely removed in order for the image to appear to contain only clutter because any remaining digit parts could appear as clutter. During absence to presence translation, the model learns the distribution of correct digits and is able to insert them into the image, as shown for MNIST 48×48 simple. With more clutter (MNIST 48×48 hard) this becomes challenging; generated residuals have less variety and many look like variations of the digit 0.

These experiments demonstrate that semi-supervised segmentation benefits from image-to-image translation. We observed significant improvements over supervised and semi-supervised segmentation baselines.

4.2 BraTS

Moving beyond synthetic data, we evaluated the proposed method on brain tumour segmentation challenge 2017 data (BraTS). Because this dataset contains only magnetic resonance imaging (MRI) volumes presenting cancer, we artificially split the data along 2D axial slices into P and A domains as a proof of concept. Because lesions tend to be closer to the center of the brain than the top or bottom, slices from the center of the brain, containing more brain matter, tend to be over-represented in P as compared to A. We require that the P and A domains differ only in that P contains the segmentation target; therefore, in order to better balance the slice distributions between the two domains, we additionally split the brains into hemispheres and select only those half-slices that contain at least 25% brain pixels. In P, we also limit the minimal number of lesion pixels to 1% of brain pixels. We pre-process every volume by mean-centering the brain pixels (ignoring background) and dividing them by their standard deviation. Finally, we use half-slices extracted from the processed volumes as model inputs. Each input has four channels, corresponding to four registered MRI sequences: T1, T2, T1C, and FLAIR.
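A rough NumPy sketch of this half-slice extraction and normalization for one multi-sequence volume; taking the non-zero voxels as the brain mask, normalizing per channel, and discarding half-slices with lesions smaller than the threshold are assumptions of the sketch.

```python
import numpy as np

def extract_half_slices(volume, lesion_mask, min_brain_frac=0.25, min_lesion_frac=0.01):
    """Split a (C, Z, H, W) volume into axial half-slices and sort them into A and P sets.

    volume: registered MRI sequences (e.g. C=4 for T1, T2, T1C, FLAIR).
    lesion_mask: (Z, H, W) binary lesion annotation, used here only to assign domains.
    """
    brain = volume.sum(axis=0) > 0                      # non-zero voxels approximate the brain
    vals = volume[:, brain]                             # (C, n_brain_voxels)
    mu = vals.mean(axis=1).reshape(-1, 1, 1, 1)
    sd = vals.std(axis=1).reshape(-1, 1, 1, 1)
    volume = (volume - mu) / (sd + 1e-8)                # mean-center and scale brain intensities

    absent, present = [], []
    w = volume.shape[3]
    for z in range(volume.shape[1]):
        for half in (slice(0, w // 2), slice(w // 2, w)):    # split into hemispheres
            b = brain[z, :, half]
            if b.mean() < min_brain_frac:                    # keep only half-slices with enough brain
                continue
            lesion_frac = lesion_mask[z, :, half][b].mean()
            img = volume[:, z, :, half]
            if lesion_frac >= min_lesion_frac:
                present.append(img)
            elif lesion_frac == 0:
                absent.append(img)
            # half-slices with very small lesions are assigned to neither set in this sketch
    return absent, present
```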

We trained the proposed model and baselines with reference labels available for 1% of the training data. As shown in Table 1, the proposed model achieves a Dice score of 0.79, significantly outperforming both the segmentation baseline, at 0.69, and the semi-supervised autoencoding baseline, at 0.73. Image translation and segmentation examples are shown in Figure 8. As evident in the figure, lesions were well removed by image-to-image translation. Unlike with the cluttered MNIST data, some of the sequences (T1, T1C) result in fairly complicated residuals that are nonetheless correctly reinterpreted as segmentations via the residual decoder.

The first column in Figure 8 reveals an artifact of distribution imbalance where a rare truncated input slice is transformed into a common non-truncated slice. Artifacts of this sort are particularly common when there is an imbalance in the distribution of slice sizes between P and A (which we try to avoid). Ideally, entire brain volumes would be used as inputs instead of slices.

5 Extensions and applications

Although we present work on two domains, P and A, we note that the proposed method can be easily extended to any greater number of domains. For example, if different types of pathology are known to be present in a medical image dataset, a domain-specific code (with a corresponding residual decoder) could be encoded for each pathology in addition to a neutral code with any pathology absent. Most interestingly, our image-to-image translation approach would allow any number of pathologies to be present in an image at a time, unlike for example the StarGAN multi-domain image-to-image translation architecture [8].

Finally, we note that there are many kinds of data outside of medical imaging that can be split into P and A domains. For example, any material fault analysis, such as rust detection, microchip defect inspection, or the assessment of decay in building facades, can be expressed this way. Another interesting application may be the surveying of flood damage by learning the difference between pre-flood and post-flood urban aerial images. Extending the proposed method to more than two domains, one could explore such multi-domain problems as shadow segmentation, where different times of day constitute different domains (with noon in A).

6 Conclusion

We propose a semi-supervised segmentation method that makes use of image-to-image translation in order to leverage unsegmented training data with cases presenting the object (P) of interest and cases in which it is absent (A). We argue that this objective is a good unsupervised surrogate for segmentation because it should similarly rely on disentangling of object variations from other variations. Indeed, we validate our method on both synthetic cluttered MNIST segmentation tasks and brain tumour segmentation in MR images, where we achieve significant improvement over supervised segmentation and a semi-supervised baseline.

References

Appendix

Appendix A Model and training details

This section details model architectures, parameter initializations, and optimization hyperparameters. Network layers are described in Tables 2-9. Here, conv block refers to a residual layer (as in the ResNet [16]) that chains together a normalization operation, an activation function, and a convolution, with a short skip connection from the input to the output, as shown in Figure 9. For all experiments we use PyTorch [33].

Figure 9: The conv block chains a normalization operation (norm), a rectified linear unit (ReLU), and a convolution (conv). When used in a decoder, 2× upsampling is performed prior to convolution by simple repetition of pixel rows and columns. The input is summed to the output via a short skip connection.
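For illustration, a PyTorch sketch of such a conv block; the 1×1 projection on the skip path (needed when the number of channels or the resolution changes) and the choice of nearest-neighbour interpolation for the 2× upsampling are assumptions of the sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Norm -> ReLU -> (optional 2x upsample) -> Conv, with a short skip connection."""

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, upsample=False,
                 norm=nn.InstanceNorm2d):
        super().__init__()
        self.norm = norm(in_ch)
        self.upsample = upsample
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride,
                              padding=kernel_size // 2, padding_mode='reflect')
        self.proj = nn.Conv2d(in_ch, out_ch, 1, stride=stride)   # match shapes on the skip path

    def forward(self, x):
        h = F.relu(self.norm(x))
        skip = x
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')      # repeat rows and columns
            skip = F.interpolate(skip, scale_factor=2, mode='nearest')
        return self.conv(h) + self.proj(skip)                         # short skip connection

# Example: a downsampling encoder block as in Table 2 (3x3 kernel, stride 2).
block = ConvBlock(32, 64, kernel_size=3, stride=2)
out = block(torch.randn(2, 32, 48, 48))      # -> (2, 64, 24, 24)
```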

General model structure

The proposed model has one encoder and two decoders: common and residual. Additionally, it uses two discriminators, one for each direction of translation. The autoencoding baseline has one encoder and two decoders: segmentation and reconstruction. The segmentation baseline has one encoder and one segmentation decoder.

Reusing encoders and decoders.

To compare the effect of different training objectives, we try to reduce the confounding effect of differing architectures between the proposed model and baseline models. For each task, we use the same encoder for all models; likewise, the common decoder in the proposed model and all decoders in the baseline models are the same. The residual decoder in the proposed model is similar, differing in that it lacks short skip connections and uses slightly larger convolution kernels. All encoders and decoders are initialized with the Kaiming Normal approach [15]. Convolutions are applied to inputs with reflection padding. All activation functions are rectified linear units (ReLU).

Skip connections.

We use long skip connections from the encoder to every decoder except the reconstruction decoder of the autoencoding baseline. Long skip connections bridge representations of the same resolution (these have the same number of channels). Specifically, the representation in the encoder is compressed to a single channel with a 1×1 convolution and then concatenated to the corresponding decoder representation. The encoder and all decoders have short skip connections (as in the ResNet), except for the residual decoders of the proposed model.

Latent code split.

All latent bottleneck representations of every model have 512 channels. In the proposed model, 128 of these channels are specified as the residual latent code and the rest as the common latent code.

Normalization.

All encoders use instance normalization [40]. All decoders use layer normalization [3]. The residual decoder of the proposed model performs segmentation by adopting a segmentation-specific normalization approach that differs from the layer normalization used for translation.

Segmentation via residual decoder.

In the proposed method, the residual decoder is used both in translation and in segmentation. For segmentation, all but the last layer are used and a classification layer is appended: a 1×1 convolution with as many channels as there are classes. In order to adapt the features learned via translation to the segmentation task, inference is modified by using a different normalization approach during segmentation than during translation. For the MNIST tasks, a four-layer multi-layer perceptron with 256 units per layer is used to map the latent code (both common and unique) to parameters for adaptive instance normalization [18] in the residual decoder. For BraTS segmentation, the residual decoder uses separate layer normalization parameters for segmentation.

Discriminator.

Two discriminators are used with the proposed method, one for each direction of translation. We use multi-scale discriminators as proposed in [41, 26]. The discriminator architectures shown in Table 8 and Table 9 describe the network that is applied at each of three scales. At some scales, discriminators output a map of values per image instead of a single value; all pixels in this map are first averaged, and the resulting discriminator values are then averaged across all scales. All discriminators use leaky ReLU [27].

Optimization.

For all experiments, we used the AMSGrad optimizer [34]. Discriminators were trained with a separate learning rate from the other networks, following [17]. We used a batch size of 20 images. For MNIST experiments, we ran training for 300 epochs; for BraTS, 500 epochs. In the proposed method, we used the hinge loss for the adversarial objective, with spectral normalization [29] applied to all networks, as in [44, 6].

Weighted objectives.

We tuned the objective weights $\lambda_{seg}$, $\lambda_{rec}$, $\lambda_{lat}$, $\lambda_{cyc}$, and $\lambda_{adv}$ (and the reconstruction weight of the AE baseline) for the best overall performance.

Data augmentation

We applied data augmentation on the fly during training for BraTS but not for the MNIST tasks, since a large amount of data is generated for the latter. Data augmentation involved random rotations of up to 3 degrees, random zooms of up to 10% in or out, random intensity shifts of up to 10%, random horizontal and/or vertical flips, and spline warping. Spline warping used a 3×3 grid of control points, with each point displaced according to a Normal distribution. In those cases where data augmentation created new pixels along image edges or corners, these were filled by reflecting the image outward toward the edges and corners.
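A simplified SciPy/NumPy sketch of some of these augmentations (rotation, intensity shift, and flips); the zoom and spline warping components are omitted, and the exact definition of the intensity shift and the use of mirror-mode padding are assumptions of the sketch.

```python
import numpy as np
from scipy.ndimage import rotate

def augment(image, mask, rng=None):
    """Random rotation, intensity shift, and flips for an (image, mask) pair.

    image: (C, H, W) float array; mask: (H, W) array.
    """
    rng = rng or np.random.default_rng()

    angle = rng.uniform(-3, 3)                                      # rotations up to 3 degrees
    image = rotate(image, angle, axes=(1, 2), reshape=False, mode='mirror')
    mask = rotate(mask, angle, reshape=False, order=0, mode='mirror')

    image = image + rng.uniform(-0.1, 0.1) * np.abs(image).max()    # intensity shift up to 10%

    if rng.random() < 0.5:                                          # random horizontal flip
        image, mask = image[..., ::-1], mask[..., ::-1]
    if rng.random() < 0.5:                                          # random vertical flip
        image, mask = image[:, ::-1, :], mask[::-1, :]
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)
```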

Encoder (MNIST 48×48)
Layer         Channels   Kernel   Stride
Convolution   32         3×3      1
Conv block    64         3×3      2
Conv block    128        3×3      2
Conv block    256        3×3      2
Conv block    512        3×3      2
Norm, ReLU
Table 2: The encoder used for all models with MNIST 48×48.
Decoder (MNIST 48×48)
Layer              Channels   Kernel   Stride
Convolution        256        3×3      1
Conv block         128        3×3      1
Conv block         64         3×3      1
Conv block         32         3×3      1
Norm, ReLU, Conv   1          3×3      1
Table 3: The decoder used for all models (common but not residual decoder in the proposed method) with MNIST 48×48.
Residual decoder (MNIST 48×48)
Layer                        Channels   Kernel   Stride
Convolution                  256        5×5      1
Conv block (no short skip)   128        5×5      1
Conv block (no short skip)   64         5×5      1
Conv block (no short skip)   32         5×5      1
Norm, ReLU, Conv             1          5×5      1
Table 4: The residual decoder used in the proposed method with MNIST 48×48.
Encoder (MNIST 128×128 and BraTS)
Layer         Channels   Kernel   Stride
Convolution   16         3×3      1
Conv block    32         3×3      2
Conv block    64         3×3      2
Conv block    128        3×3      2
Conv block    256        3×3      2
Conv block    512        3×3      2
Norm, ReLU
Table 5: The encoder used for all models with MNIST 128×128 and BraTS.
Decoder (MNIST 128×128 and BraTS)
Layer              Channels   Kernel   Stride
Convolution        256        3×3      1
Conv block         128        3×3      1
Conv block         64         3×3      1
Conv block         32         3×3      1
Conv block         16         3×3      1
Norm, ReLU, Conv   1          3×3      1
Table 6: The decoder used for all models (common but not residual decoder in the proposed method) with MNIST 128×128 and BraTS.
Residual decoder (MNIST 128×128 and BraTS)
Layer                        Channels   Kernel   Stride
Convolution                  256        5×5      1
Conv block (no short skip)   128        5×5      1
Conv block (no short skip)   64         5×5      1
Conv block (no short skip)   32         5×5      1
Conv block (no short skip)   16         5×5      1
Norm, ReLU, Conv             1          5×5      1
Table 7: The residual decoder used in the proposed method with MNIST 128×128 and BraTS.
Discriminator (MNIST)
Layer              Channels   Kernel   Stride
Convolution        128        4×4      1
Norm, ReLU, Conv   128        4×4      2
Norm, ReLU, Conv   256        4×4      2
Norm, ReLU, Conv   512        4×4      2
Convolution        1          1×1      1
Table 8: The discriminator used in the proposed method with MNIST 48×48 and 128×128.
Discriminator (BraTS)
Layer              Channels   Kernel   Stride
Convolution        64         4×4      1
Norm, ReLU, Conv   64         4×4      2
Norm, ReLU, Conv   128        4×4      2
Norm, ReLU, Conv   256        4×4      2
Norm, ReLU, Conv   512        4×4      2
Convolution        1          1×1      1
Table 9: The discriminator used in the proposed method with BraTS.
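As a rough illustration of how these layer listings map to code, the MNIST 48×48 encoder of Table 2 could be assembled as follows, using a simplified norm-ReLU-conv block (the short skip connection of Figure 9 is omitted for brevity):

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch, stride):
    """Simplified conv block: norm -> ReLU -> 3x3 conv (short skip connection omitted)."""
    return nn.Sequential(
        nn.InstanceNorm2d(in_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, padding_mode='reflect'),
    )

# Encoder for MNIST 48x48 (Table 2): one plain convolution, four strided conv blocks,
# then a final norm and ReLU on the 512-channel bottleneck.
encoder = nn.Sequential(
    nn.Conv2d(1, 32, 3, stride=1, padding=1, padding_mode='reflect'),
    conv_block(32, 64, stride=2),
    conv_block(64, 128, stride=2),
    conv_block(128, 256, stride=2),
    conv_block(256, 512, stride=2),
    nn.InstanceNorm2d(512),
    nn.ReLU(inplace=True),
)

z = encoder(torch.randn(2, 1, 48, 48))   # bottleneck of shape (2, 512, 3, 3)
```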