Disentangle, align and fuse for multimodal and zero-shot image segmentation

11/11/2019 ∙ by Agisilaos Chartsias, et al. ∙ Cedars-Sinai

Magnetic resonance (MR) protocols rely on several sequences to properly assess pathology and organ status. Yet, despite advances in image analysis we tend to treat each sequence, here termed modality, in isolation. Taking advantage of the information shared between modalities (largely an organ's anatomy) is beneficial for multi-modality multi-input processing and learning. However, we must overcome inherent anatomical misregistrations and disparities in signal intensity across the modalities to claim this benefit. We present a method that offers improved segmentation accuracy of the modality of interest (over a single input model), by learning to leverage information present in other modalities, enabling semi-supervised and zero shot learning. Core to our method is learning a disentangled decomposition into anatomical and imaging factors. Shared anatomical factors from the different inputs are jointly processed and fused to extract more accurate segmentation masks. Image misregistrations are corrected with a Spatial Transformer Network, that non-linearly aligns the anatomical factors. The imaging factor captures signal intensity characteristics across different modality data, and is used for image reconstruction, enabling semi-supervised learning. Temporal and slice pairing between inputs are learned dynamically. We demonstrate applications in Late Gadolinium Enhanced (LGE) and Blood Oxygenation Level Dependent (BOLD) cardiac segmentation, as well as in T2 abdominal segmentation.




I Introduction

In medical imaging multiple acquisitions for the same subject are often used to capture complementary information. Specifically, within Magnetic Resonance (MR), different pulse sequences attenuate different tissue characteristics for identifying anatomical and functional information, and produce images of different contrasts (modalities).

Automatic segmentation of such multimodal data remains important, yet most methods focus on single-modality images. We propose a method based on disentangled representations, designed to address challenges posed by multimodal data. These include differences in signal intensities, lack of annotated data and anatomical and temporal misalignments due to varying spatial resolutions or due to moving organs as in the case of dynamic imaging of the heart and the abdomen.

Multimodal learning permits the capture of information present in one modality (e.g. the anatomy) for use in another modality that has higher pathological contrast. As a motivating example, myocardial segmentation in LGE is challenging, since LGE mutes myocardial signal to accentuate signal originating from myocardial infarction. In fact, in clinical practice, analysis of LGE is typically combined with cine-MR [kim2009cardiovascular].

A naive way to propagate knowledge between modalities would be co-registration. This has been successful in the brain (see Section II-B). But precise multimodal registration remains challenging, due to the need for modality independent metrics [sotiras2013]. Critically, the brain remains static within an imaging session, whereas the heart is moving. Also, multimodal data are often inconsistent both in the number of images (different slices, cardiac phases, and perhaps more penalising resolution differences, e.g. slice thickness), as well as in the number of annotations. In addition, some sequences are static (LGE) and others dynamic (cine-MR). This necessitates solutions that alleviate misregistrations but also can pair input images.

Figure 1: MMSDNet schematic in a LGE segmentation exemplar task using LGE and cine-MR inputs. Firstly, disentangled anatomical factors of the LGE and cine-MR image are extracted. Then, they are aligned and combined to a fused anatomical factor, used to infer the final segmentation mask.

Our approach can use multi-input (multimodal) data at training and inference. The latter is extremely useful not only in zero-shot learning but also in removing outliers. We demonstrate this in several cardiac and abdominal datasets.

I-A Overview of the proposed approach

We propose a mechanism to represent data that is suitable for learning how to propagate knowledge for segmentation. We learn both with and without annotations using a reconstruction objective as self-supervision. More excitingly, our approach co-registers data within an anatomical representation space, becoming thus robust to variations in imaging contrast.

Our 2D approach, Multimodal Spatial Disentanglement Network (MMSDNet), see Figure 1, achieves the above by mapping multimodal images of the same subject into disentangled anatomical and modality factors. (In computer vision these are typically referred to as content and style factors, respectively [huang2018multimodal]. However, as we detail in Section II-A, in medical applications such disentanglement has more stringent requirements.) Anatomical factors are represented as categorical feature maps. Each category corresponds to input pixels that are, ideally, spatially similar, and hence belong to the same anatomical part. This promotes semantic consistency, and also the learning of spatial correspondences between modalities. Furthermore, this semantic anatomical space is essential for zero-shot segmentation when no labels, but only images, are available for one of the modalities. As in transductive zero-shot learning [fu2015], the task is achieved by projecting images to a common semantic space. Modality factors encode pixel intensities in a smooth multivariate Gaussian manifold as per the Variational Autoencoder (VAE) [kingma2013auto]. Anatomical factors are used to obtain segmentation masks, whereas their re-entanglement with the modality factors can be used for image reconstruction. When learning with multiple modalities, anatomical factors obtained from multimodal images are co-registered with a Spatial Transformer Network (STN) [jaderberg2015spatial], fused with feature arithmetics, and also decoded in different modalities as defined by the modality factors. When input data are not paired (e.g. due to temporal or slice position differences) we introduce a loss term in the cost function that selects the most “informative” multimodal pairs by comparing anatomical factors.

We improve our preliminary work [chartsias2019multimodal] as follows: (1) we reduce model parameters, and encourage consistent multimodal anatomical representations, by employing weight sharing in the anatomy encoders; (2) we introduce a cost that reduces the need for expert pairing of multimodal inputs; (3) we design and evaluate another decoder; (4) we propose the use of distance correlation to assess disentanglement; (5) we generalise to non-cardiac datasets.

Our contributions are the following:

  1. We propose a 2D method for learning disentangled representations of anatomical and modality factors in multimodal medical images for segmentation.

  2. We demonstrate the importance of semantic anatomical factors, which are achieved through the model design, because they allow learning registration and fusion operators for the purpose of multimodal learning.

  3. We introduce a loss term in the cost function that learns to select the most “informative” multimodal pairs.

  4. We demonstrate our method’s robustness over other approaches with extensive experiments on several datasets, in cardiac MRI and abdominal segmentation.

  5. We show that our model can work both on single modal and multimodal inference, and that it outperforms other variants when trained with different amounts of annotations or in a zero-shot setting for one of the modalities.

  6. We discuss decoder design using FiLM [perez2017film] and SPADE [park2019semantic] respectively, and evaluate their disentanglement properties by estimating the dependence between the anatomical and modality factor with distance correlation.
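The distance correlation mentioned in the last contribution can be illustrated with a minimal NumPy sketch of the standard empirical estimator. The function names are hypothetical, and the sketch assumes that each factor has been flattened to one row vector per sample; it is not taken from the authors' code.

```python
import numpy as np

def _centred_distances(x):
    """Pairwise Euclidean distances between the rows of x, double-centred."""
    d = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    return d - d.mean(axis=0, keepdims=True) - d.mean(axis=1, keepdims=True) + d.mean()

def distance_correlation(x, y):
    """Empirical distance correlation between samples x (n, p) and y (n, q)."""
    a, b = _centred_distances(x), _centred_distances(y)
    dcov2 = (a * b).mean()                       # squared distance covariance
    dvar_x, dvar_y = (a * a).mean(), (b * b).mean()
    denom = np.sqrt(dvar_x * dvar_y)
    return 0.0 if denom == 0 else float(np.sqrt(max(dcov2, 0.0) / denom))

# Illustrative data only: stand-ins for flattened factors of 200 samples.
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 1))
y = rng.normal(size=(200, 1))
dependent = distance_correlation(x, 2.0 * x + 1.0)   # fully dependent pair
independent = distance_correlation(x, y)             # independently drawn pair
```

Values close to zero indicate weak dependence between the two sets of samples, whereas values close to one indicate strong dependence, which is why a lower score between anatomical and modality factors suggests better disentanglement.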

II Related work

Multimodal machine learning is an active research area that involves diverse sources of information. While in computer vision modalities might refer to any heterogeneous source of information, such as text and images, here, as common in the medical domain, we restrict to different image acquisitions. We consider multimodal learning as combining information of different images, present at training and/or inference time.

Using image translation or domain adaptation as an augmentation strategy to compensate for a lack of annotated data is related, but does not address this setting. For completeness, we mention here a few recent methods. Image translation has been proposed with cycle consistency [campello2019combining, huo2018synseg, zhang2018translating, tao2019segmentation], and disentanglement [chen2019unsupervised] losses. In domain adaptation, multimodal images are related with different augmentations [ly2019style], histogram matching [liu2019automatic] or adversarial losses [Chen2019SynergisticIA].

We review work on disentangled representations, the main focus of our method, and prior art on multimodal medical imaging, split in spatially registered or unregistered inputs. We highlight though that currently no work exists that is able to simultaneously achieve multimodal fusion from unregistered data for image segmentation, be robust to the number of training annotations, and be applied to single or multimodal inference. These are made possible by the careful design of disentangled and semantic anatomical representations.

II-A Disentangled representation learning

Our approach leverages learning disentangled anatomical and modality factors. Disentangling content from style for style transfer is gaining popularity in computer vision, with many examples such as in [huang2018multimodal]. In medical imaging, disentanglement has been used for semi-supervised cardiac segmentation [chartsias2018factorised, chartsias2019disentangled], multi-task learning [chartsias2019disentangled, meng2019representation], lung nodule synthesis [liu2018decompose], and registration [qin2019unsupervised]. Multiple modalities have been used for liver segmentation with domain adaptation [yang2019domain], albeit without information fusion. This has been proposed for brain tumour segmentation [chen2019robust], using though registered images.

For anatomical features to be useful in clinical tasks, they are required to be semantic and quantifiable [chartsias2019disentangled]. This is not guaranteed in disentanglement techniques used for style transfer [huang2018multimodal], or recent medical segmentation works [chen2019robust, yang2019domain] that do not impose restrictions on the content features. Semantic representations have recently been pursued in computer vision in the form of feature masks [ma2018exemplar] or by learning geometry with landmarks [wu2019transgaga]. Differently from others, we disentangle quantifiable anatomical features, such that they are useful for segmentation, whereas interpretability is promoted with explicit design constraints (Section III-A), which in addition enable registration and simple fusion operators.

II-B Multimodal learning with registered images

Early work on multimodal deep learning concatenated co-registered multimodal images in different input channels, in order to improve MR brain segmentation [havaei2017brain]. Robustness to missing modalities [joyce2017robust] was achieved with different encoders (per modality), that mapped images to modality invariant features. Common feature representations with multiple encoders have also been proposed for cross-modal classification [van2019learning].

Another aspect of multimodal learning is information fusion, used to combine complementary information. Most commonly, fusion is performed on the latent features [havaei2017brain, joyce2017robust], although, fusion at multiple levels can be achieved with densely connected layers [dolz2018hyperdense] to exploit multi-scale correlations. Furthermore, cross-modal convolutions are used as a way to weigh each modality’s contribution [tseng2017joint]. Finally, attention modules and residual learning focus on specific regions for brain MRI segmentation [chen2018mmfnet]. In contrast to the above, we take advantage of the strictly defined anatomical factors and use a fusion operator, which is able to combine all distinct features.

Figure 2: A cine and a LGE image are encoded to two anatomical factors. A STN deforms the cine anatomical factor to match the LGE anatomical factor. Their fusion produces the final LGE segmentation mask. Both anatomical factors (and their corresponding modality factors) further reconstruct the inputs.

II-C Multimodal learning with unregistered images

Misalignment though is common in multimodal data. In the brain, registration can be reliable but in the heart and other moving organs performance cannot be guaranteed. Correcting small misalignments is possible with a STN applied on features [joyce2017robust]. Alternatively, encoder-decoder setups can learn shared features among modalities. An exploration of different setups [valindria2018multi] showed that separate encoders and decoders that share the last and first layer achieve the highest performance.

In the heart the approaches are limited. Multiple inputs can be combined by directly adapting the segmentation masks with contour models [liu2019multi]. Alternatively, reducing the field of view (to the patch level), and ensembling (using and fusing results from several atlases) can alleviate the effect of committed errors [zhuang2016multi]. A recent work [zhuang2018multivariate] proposes simultaneous segmentation and registration of multimodal cardiac MR by modelling the joint distribution with Multivariate Mixture Models.

Multimodal images can be used as different samples of the same data distribution to form an expanded dataset [wang2019skunet]. Finally, multimodal registration, although susceptible to errors, can create “noisy” labels [roth2019cardiac].

Our method is different: it is the first to jointly learn a suitable representation, co-registration, and information fusion for segmentation, and to do so in a semi-supervised setting.

III Methodology

Here, we describe MMSDNet, a multi-component 2D model for multimodal and semi-supervised learning that is robust to input misalignments. Training consists of three stages, shown in Figure 2. Firstly, encoders map images to anatomical and modality factors. Then, the anatomical factors are spatially aligned and fused to produce segmentations. Finally, the two factors are decoded to reconstruct the input. We now detail the individual components, as well as the employed cost functions. The supplement shows the network architectures (Figures S1-S5).

III-A Encoding

Given modality $i$ with samples $x_i \in X_i \subset \mathbb{R}^{H \times W \times 1}$, where $X_i$ is the set of images, and $H$ and $W$ are the height and width respectively, the encoding process achieves a disentanglement of anatomical and modality factors.

Anatomical factors $s_i$ are tensors produced by encoders dedicated to each modality $i$: $s_i = f_{anat}(x_i \mid \theta_i)$, where $\theta_i$ are the encoder parameters. The encoders are modelled after the U-Net [ronneberger2015u] (see Figure S1a). To reduce model parameters, and encourage a common anatomical representation among the multimodal data, we employ weight sharing in the decoder of each U-Net. Thus, the parameters $\theta_i$ are split into the unique parameters $\psi_i$ of the encoding path, and the shared parameters $\omega$ of the decoding path: $\theta_i = \{\psi_i, \omega\}$.

An anatomical factor is represented as a binary tensor that is also a one-hot encoding at the channel dimension. A binary anatomical factor is discouraged from storing imaging information, which promotes the factorisation process. A one-hot encoding of the anatomical factors enforces that a particular image region uniquely appears in a single channel. More formally, $s_i \in \{0, 1\}^{H \times W \times C}$, s.t. $\sum_{c=1}^{C} s_i(h, w, c) = 1$ for every pixel $(h, w)$, where $C$ is the number of channels. Two anatomical factors produced by a cine and an LGE image can be seen in Figure 3.
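The binary, one-hot constraint on an anatomical factor can be illustrated with a short NumPy sketch. It covers only the forward constraint (each pixel is assigned to exactly one channel); how gradients pass through the binarisation during training is not covered, and the function name is a hypothetical one chosen for illustration.

```python
import numpy as np

def binarise_anatomy(scores):
    """Turn (H, W, C) channel scores into a binary, one-hot anatomical factor."""
    winners = np.argmax(scores, axis=-1)                    # (H, W) winning channel per pixel
    n_channels = scores.shape[-1]
    return np.eye(n_channels, dtype=np.float32)[winners]   # (H, W, C), one 1 per pixel

# Illustrative input: random scores for a 4x4 image with 8 channels.
rng = np.random.default_rng(0)
s = binarise_anatomy(rng.normal(size=(4, 4, 8)))
```

Summing the result over the channel axis gives one at every pixel, which is the constraint stated above.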

Divergence loss $L_{KL}$: The modality factors $z_i$ are vectors produced by a single stochastic encoder that, given an image sample $x_i$ and its anatomy factor $s_i$, learns a probability distribution $q(z_i \mid x_i, s_i)$. In order to encourage a smooth space, this posterior distribution is optimised to follow a multivariate Gaussian prior, $p(z) = \mathcal{N}(0, I)$, by minimising the KL divergence with the re-parameterisation trick [kingma2013auto]:

$$L_{KL} = D_{KL}\big(q(z_i \mid x_i, s_i) \,\|\, p(z)\big).$$

The modality encoder is shown in Figure S1b. It predicts the mean and standard deviation of a Gaussian, which are used to draw the random sample vector $z_i$.
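As a minimal sketch of the two ingredients above, the closed-form KL divergence between a diagonal Gaussian and a unit Gaussian, and a re-parameterised sample, can be written in plain Python. The diagonal covariance and the function names are assumptions made for illustration.

```python
import math
import random

def kl_to_unit_gaussian(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ) in closed form."""
    return 0.5 * sum(m * m + s * s - 1.0 - 2.0 * math.log(s) for m, s in zip(mu, sigma))

def sample_z(mu, sigma, rng):
    """Re-parameterisation trick: z = mu + sigma * eps, with eps ~ N(0, 1)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]

# Illustrative encoder outputs for a 3-dimensional modality factor.
rng = random.Random(0)
mu, sigma = [0.5, -1.0, 0.0], [1.0, 0.5, 2.0]
z = sample_z(mu, sigma, rng)
kl = kl_to_unit_gaussian(mu, sigma)
```

Because the noise is drawn separately from the predicted mean and standard deviation, the sample remains a differentiable function of both, which is what allows the encoder to be trained by backpropagation.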


Figure 3: Anatomical factors from a cine and a LGE. Observe how the same anatomical regions appear in the same channels.

III-B Alignment and fusion of the anatomical factors

Two anatomical factors $s_i$ and $s_j$ of modalities $i$ and $j$ respectively, are aligned using non-linear registration with thin plate spline interpolation. Given $s_i$ and $s_j$, a STN (architecture in Figure S2) produces a matrix of control points, which defines the interpolated surface passing through the control points of $s_i$ and registers it with $s_j$. The result of the alignment step, $s_i^{\text{deformed}}$, is a deformed anatomical factor corresponding to $s_j$, and vice versa ($s_j^{\text{deformed}}$ corresponds to $s_i$).

During inference, the deformed anatomies are combined to produce a fused representation containing all unique and shared features that are present in the constituent anatomical factors. Since they are spatially aligned, a pixel-wise operation such as the pixel-wise max is able to preserve all encoded features. More formally, $s_i^{\text{fused}} = \max(s_i, s_j^{\text{deformed}})$ and $s_j^{\text{fused}} = \max(s_j, s_i^{\text{deformed}})$. One benefit of max-fusion is that it is invariant to the number of inputs, and is therefore directly applicable in cases with more than two modalities and enables inference when only one modality is present.
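The pixel-wise max fusion can be sketched in a few lines of NumPy. The sketch assumes that the factors are already spatially aligned (the STN step is not shown), and the function and variable names are hypothetical.

```python
import numpy as np

def fuse_max(*anatomies):
    """Pixel-wise max over any number of aligned (H, W, C) anatomical factors."""
    return np.maximum.reduce(anatomies)

# Target factor: every pixel lies in channel 0.
s_i = np.zeros((2, 2, 3))
s_i[..., 0] = 1

# Aligned factor of the other modality: pixel (0, 0) lies in channel 2 instead.
s_j_deformed = np.zeros((2, 2, 3))
s_j_deformed[..., 0] = 1
s_j_deformed[0, 0, 0] = 0
s_j_deformed[0, 0, 2] = 1

fused = fuse_max(s_i, s_j_deformed)   # keeps the features of both inputs
single = fuse_max(s_i)                # one input: fusion returns it unchanged
```

The same function accepts one, two or more factors, which mirrors the invariance to the number of inputs noted above.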

III-C Segmentation

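Given an anatomical factor $s_i$, a simple neural network (architecture in Figure S3) infers a corresponding segmentation mask $m_i$, s.t. $m_i \in M_i \subset \{0, 1\}^{H \times W \times L}$, where $M_i$ is the set of masks of modality $i$ and $L$ is the number of segmentation classes. The segmentation network is common for all modalities, and is also applied to the deformed and fused anatomies of Section III-B.

Supervised loss $L_{sup}$: Given a set of images $x_i$ paired with masks $m_i$, a supervised cost is defined as a weighted sum of the differentiable Dice loss and Cross Entropy (CE):

$$L_{sup} = \lambda_{D} L_{Dice} + \lambda_{CE} L_{CE},$$

where $\lambda_{D}$ and $\lambda_{CE}$ control the contribution of each loss. The cross entropy and differentiable Dice are respectively defined as:

$$L_{CE} = -\frac{1}{HW} \sum_{h, w, c} m_{h,w,c} \log p_{h,w,c}, \qquad L_{Dice} = 1 - \frac{2 \sum_{h, w, c} m_{h,w,c} \, p_{h,w,c}}{\sum_{h, w, c} m_{h,w,c} + \sum_{h, w, c} p_{h,w,c}},$$

where $h$, $w$, and $c$ refer to the height, width and channel, $m_{h,w,c}$ is the corresponding entry of the real mask, and $p_{h,w,c}$ is the probability for a pixel belonging to class $c$.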
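A weighted sum of a soft Dice loss and cross entropy can be sketched in NumPy as below. These are common textbook forms of both losses; the normalisation, the smoothing constant `eps` and the default weights are assumptions, not values taken from the paper.

```python
import numpy as np

def cross_entropy(mask, prob, eps=1e-7):
    """Mean per-pixel cross entropy for one-hot (H, W, L) masks and probabilities."""
    return float(-np.mean(np.sum(mask * np.log(prob + eps), axis=-1)))

def dice_loss(mask, prob, eps=1e-7):
    """Soft (differentiable) Dice loss: one minus the Dice overlap."""
    overlap = np.sum(mask * prob)
    return float(1.0 - (2.0 * overlap + eps) / (np.sum(mask) + np.sum(prob) + eps))

def supervised_loss(mask, prob, lambda_dice=1.0, lambda_ce=1.0):
    """Weighted sum of the two terms; the weights here are placeholders."""
    return lambda_dice * dice_loss(mask, prob) + lambda_ce * cross_entropy(mask, prob)

# Illustrative data: a random 3-class mask, a perfect and a uniform prediction.
rng = np.random.default_rng(0)
mask = np.eye(3)[rng.integers(0, 3, size=(8, 8))]
perfect = supervised_loss(mask, mask)
uniform = supervised_loss(mask, np.full_like(mask, 1.0 / 3.0))
```

A perfect prediction drives both terms towards zero, while the uniform prediction is penalised by both.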

Adversarial loss: An unsupervised segmentation cost is defined with a mask discriminator, modelled after LS-GAN [mao2017effectiveness]. The adversarial objective is computed against real masks sampled from all modalities, and the discriminator is adversarially trained against the segmentation network. The discriminator’s architecture consists of 4 convolutional layers followed by LeakyReLU and a final single neuron layer, and uses Spectral Normalisation [miyato2018spectral] to stabilise training. In both segmentation costs, the anatomical factors come either from the input images directly, or are the result of the alignment step of a secondary modality. In the latter case, the gradients produced by the segmentation cost are back-propagated to the STN module to learn its parameters. (We omit the use of the fused anatomical factor as input to the segmentation network to avoid backpropagating gradients both to the STN and the anatomy encoder, which might result in the STN not achieving a good convergence.)

III-D Decoding

The anatomical factors are further decoded into an output image of a style dictated by a modality factor $z$: $\hat{x} = g(s, z)$, where $g$ denotes the decoder and $s$ the anatomical factor. This entanglement can be performed with different decoders, which indirectly influence the type of disentanglement, or in other words the type of information captured by the anatomical and modality factors. We investigate two decoder architectures based on FiLM [perez2017film] and SPADE [park2019semantic].

The input of the FiLM-based decoder (Figure S4a) is the anatomical factors, which, after a series of convolutions, are conditioned by $z$ samples. These are used to predict a scale and an offset parameter, $\gamma$ and $\beta$, which modulate each intermediate feature map $F \in \mathbb{R}^{H \times W \times C}$, where $H$, $W$ and $C$ are the height, width and number of channels respectively: $\mathrm{FiLM}(F) = \gamma \odot F + \beta$.

We also consider a SPADE-based decoder (Figure S4b), which has been demonstrated to generate texture details on synthetic images given segmentation masks. The input to this decoder is a $z$ sample, that is processed by a series of convolutional layers, conditioned by the anatomical factor, defining the output “shape”. An Instance Normalisation layer with parameters $\mu$ and $\sigma$ is firstly applied to a feature map $F$, which is then modulated by tensors $\gamma$ and $\beta$ (same size as $F$): $\gamma \odot \frac{F - \mu}{\sigma} + \beta$.
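The two modulation schemes can be contrasted with a short NumPy sketch. Here the scale and offset are passed in directly; in a decoder they would be predicted by small sub-networks, which are not shown. The function names and the stabilising constant `eps` are assumptions.

```python
import numpy as np

def film(feature, gamma, beta):
    """FiLM: one scale and one offset per channel, broadcast over all pixels."""
    return gamma * feature + beta                             # gamma, beta: shape (C,)

def spade(feature, gamma, beta, eps=1e-5):
    """SPADE-style: instance-normalise each channel, then modulate per pixel."""
    mu = feature.mean(axis=(0, 1), keepdims=True)             # (1, 1, C)
    sigma = feature.std(axis=(0, 1), keepdims=True)           # (1, 1, C)
    return gamma * (feature - mu) / (sigma + eps) + beta      # gamma, beta: (H, W, C)

# Illustrative feature map with 16x16 pixels and 4 channels.
rng = np.random.default_rng(0)
F = rng.normal(loc=3.0, scale=2.0, size=(16, 16, 4))
film_out = film(F, gamma=np.ones(4), beta=np.zeros(4))               # identity modulation
spade_out = spade(F, gamma=np.ones_like(F), beta=np.zeros_like(F))   # normalisation only
```

The practical difference is the shape of the modulation parameters: FiLM applies the same scale and offset at every pixel of a channel, whereas the SPADE-style parameters vary spatially.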

Self-supervised costs: The decoders are trained to reconstruct the input with a reconstruction loss between each input image and its decoded counterpart. In addition, realism of synthetic images is further encouraged with an image discriminator for each modality $i$, modelled after LS-GAN, which defines an adversarial loss on the synthetic images.

As in the segmentation case, the anatomical factor $s$ is an encoding $s_i$ of image $x_i$ or a deformed encoding $s_j^{\text{deformed}}$ of another image $x_j$: $s \in \{s_i, s_j^{\text{deformed}}\}$. When $s = s_i$, the model acts as an auto-encoder. This is critical to allow the use of non-annotated images, and thus enable semi-supervised learning. In the case where $s = s_j^{\text{deformed}}$, the backpropagated gradients are used to train the STN module, and also aid the factorisation process since the modality factor learns to encode the “style” of the input modalities.

III-E Reconstruction of the modality factor loss

In order to encourage disentanglement, and also avoid posterior collapse of the modality factor, we maximise the mutual information between the synthetic images and the corresponding $z$-factors. This prevents the decoder from ignoring the $z$-factors and using only the anatomy factors. To this end, we minimise the reconstruction error of the modality factor, measured between $z$, a random sample from a unit Gaussian, and the modality factor re-encoded from $\hat{x}$, the synthetic image produced by this sample. The mutual information between the synthetic images and the modality factors is further encouraged by cross-reconstructing the same anatomy in different modalities with different $z$-factors.

III-F Non-expert pairing

Better multimodal fusion and STN registration will be achieved by multimodal image pairs that are more similar in terms of their spatial and temporal positions. In cases where the multimodal images are not expertly paired, MMSDNet can automatically measure anatomical similarities with an optional cost, that directly compares the anatomical factors, and “selects” only the most informative image pairs.

During training, and given an image $x_i$ and a set of $K$ candidate images from modality $j$, $\{x_j^1, \dots, x_j^K\}$, the multimodal segmentation and reconstruction losses for a sample are weighted by weights $w_k$, s.t. $\sum_{k=1}^{K} w_k = 1$: the losses computed for each candidate pair $(x_i, x_j^k)$ contribute to the total cost in proportion to $w_k$. By weighting the loss functions, the STN module does not need to learn deformations for all pairs, and is prevented from trying to match slices with different anatomical content.

Due to the semantics of the anatomical factors, and the fact that they are categorical, we can directly evaluate their overlap in terms of the Dice score. The Dice for each pair becomes the input to a small neural network of two fully connected layers that outputs the weights, and is similar to the temperature scaling technique proposed for calibrating classification outputs [guo2017calibration]. At inference time, the most accurate segmentation is produced from the weighted sum, with weights $w_k$, of the fusion with the different candidate slices.

This optional weighting of the cost function is only used in unpaired data, and as shown in Section V-C converges to the same result as manual pairing.
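The idea of turning anatomical overlap into pairing weights can be sketched as follows. Note the simplification: instead of a learned two-layer network, this sketch applies a softmax with a fixed, hypothetical temperature directly to the Dice scores, so it illustrates the principle rather than the trained module.

```python
import numpy as np

def factor_dice(s_a, s_b):
    """Dice overlap between two binary (H, W, C) anatomical factors."""
    return 2.0 * np.sum(s_a * s_b) / (np.sum(s_a) + np.sum(s_b))

def pairing_weights(s_i, candidates, temperature=0.1):
    """Softmax over temperature-scaled Dice scores; the weights sum to one."""
    scores = np.array([factor_dice(s_i, s_k) for s_k in candidates]) / temperature
    scores -= scores.max()                        # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum()

# Illustrative candidates: the same factor shifted by 0, 1 and 2 rows.
rng = np.random.default_rng(0)
s_i = np.eye(4)[rng.integers(0, 4, size=(8, 8))]              # one-hot (8, 8, 4)
candidates = [np.roll(s_i, shift, axis=0) for shift in (0, 1, 2)]
w = pairing_weights(s_i, candidates)
```

The unshifted candidate overlaps perfectly with the reference and therefore receives the largest weight; lowering the temperature concentrates the weights further on the closest pair.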

IV Experimental setup

IV-A Training details

The model is trained with a multi-component loss function, We set . Higher weight on encourages separation of segmentation classes. A reduced weight prevents posterior collapse, in which the factor is ignored by the decoder; however, an even lower , would not promote a Gaussian prior approximation, leading to a non-smooth intensity manifold. Number of channels and dimensions are set to and respectively, as in [chartsias2019disentangled].

The code is written in Keras (https://keras.io) and will be made available online upon acceptance. We train with Adam (learning rate of ), and evaluate using Stochastic Weight Averaging [izmailov2018averaging] to reliably compare between different methods. Quantitative evaluation is performed on 3-fold cross-validation, where the training, validation and test sets correspond to 70%, 15% and 15% of the data volumes.

IV-B Data

Experiments use three multimodal datasets of a source and a target modality, that have been rescaled to .

  1. For LGE segmentation, we use cine-MR and LGE data of 28 patients [stirrat2017ferumoxytol], acquired at Edinburgh Royal Infirmary (ERI), with spatial resolution mm/pixel, and slice thickness . End diastolic myocardial contours are provided. The image size is pixels. The number of segmented images is 358 (for each of cine-MR, LGE).

  2. To evaluate robustness on different medical data, we use abdominal T1-dual inphase and T2-SPIR data, shorthanded as CHAOS (https://chaos.grand-challenge.org/Results_CHAOS/), for T2 segmentation. Images of 20 subjects with liver, kidneys and spleen segmentations are acquired by a 1.5T Philips MRI scanner, which produces 12 bit DICOM images of resolution. We resample to an x-y spacing of 1.89mm, and crop to pixels. In total, there are 1594 images.

  3. Finally, we evaluate BOLD segmentation with a dataset (shorthand BOLD) of cine-MR and CP-BOLD images of 10 canines, with an in-plane resolution of , acquired at baseline and severe ischemia (inflicted as controllable stenosis of the left-anterior descending coronary artery (LAD)) on a 1.5T Espree (Siemens Healthcare). The image acquisition is at short axis view, covering the mid-ventricle, and is performed using cine-MR and a flow and motion compensated CP-BOLD acquisition. The pixel resolution is [tsaftaris2013detecting]. In total there are 129 cine-MR and 264 CP-BOLD images with expert segmentations from all cardiac phases.

IV-C Baseline and benchmark methods

The following baseline and benchmark methods are used:

  1. A lower bound computes the Dice score between real masks of two modalities, and is also a measure of misalignment of the multimodal data. This is referred to as “copy”, and can be used for segmenting a target modality without annotations from the source modality.

  2. This lower bound can be improved after registering the multimodal images and applying the registration field to the source masks. The deformation field is calculated by affine registration using mutual information, followed by symmetric diffeomorphic registration using cross-correlation [Tustison2015]. This is referred to as the “register” baseline, and can also be used without annotations of the source modality. “Copy” and “register” are common in clinical evaluation.

  3. As a supervised benchmark, we train a UNet on annotated data of the target modality, and refer to it as UNet-single.

  4. We further re-train a UNet on mixed training data of all modalities to evaluate its capability of concurrently handling multimodal data, and refer to it as UNet-both.

  5. Finally, we implement DualStream [valindria2018multi], the most recent Deep Learning based method for handling multimodal data which does not require registered data.
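The “copy” baseline from the list above amounts to scoring the source-modality mask directly against the target-modality mask. A minimal NumPy sketch, using two hypothetical toy masks that differ by a one-row shift:

```python
import numpy as np

def dice_score(mask_a, mask_b):
    """Dice overlap between two binary masks of the same shape."""
    a, b = mask_a.astype(bool), mask_b.astype(bool)
    total = a.sum() + b.sum()
    return 1.0 if total == 0 else 2.0 * np.logical_and(a, b).sum() / total

# Toy masks: a 4x4 square, and the same square shifted down by one row.
source_mask = np.zeros((8, 8), dtype=bool)
source_mask[2:6, 2:6] = True
target_mask = np.zeros((8, 8), dtype=bool)
target_mask[3:7, 2:6] = True

copy_dice = dice_score(source_mask, target_mask)   # 12 shared pixels out of 16 + 16
```

Since no model is involved, the score reflects only how well the two acquisitions are aligned, which is why it also serves as a measure of misalignment.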

V Experiments and discussion

Sections V-A and V-B present segmentation results of a target modality, assuming a source modality that always contains annotations during training. The source modality is cine-MR for ERI and BOLD datasets, and T1 for CHAOS. The target modality is LGE, BOLD and T2 for ERI, BOLD and CHAOS, respectively. Unless explicitly specified, MMSDNet uses a FiLM-based decoder. We evaluate the effects of: input pairing in Section V-C; registration in Section V-D; and a SPADE-based decoder in Section V-E. Section V-F evaluates disentanglement of each decoder design. Where appropriate, we use bold font for the best (on average) method and * to denote statistical significance of paired t-tests (assessed via permutations) comparing with the second best (to avoid multiple comparisons).
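A paired test assessed via permutations can be sketched as a sign-flip test on the per-sample differences, with the paired t statistic as the test statistic. The number of permutations, the seed and the toy Dice scores below are assumptions for illustration only.

```python
import numpy as np

def paired_permutation_test(scores_a, scores_b, n_permutations=10000, seed=0):
    """Two-sided paired test: the t statistic against its sign-flip null."""
    diff = np.asarray(scores_a, dtype=float) - np.asarray(scores_b, dtype=float)

    def t_stat(d):
        return d.mean(axis=-1) / (d.std(axis=-1, ddof=1) / np.sqrt(d.shape[-1]))

    rng = np.random.default_rng(seed)
    signs = rng.choice([-1.0, 1.0], size=(n_permutations, diff.size))
    null = np.abs(t_stat(signs * diff))            # statistic under random sign flips
    observed = abs(t_stat(diff))
    return (1 + np.sum(null >= observed)) / (1 + n_permutations)

# Toy per-subject Dice scores of two methods on the same eight test subjects.
best = [0.80, 0.82, 0.79, 0.85, 0.81, 0.83, 0.84, 0.80]
second = [0.75, 0.78, 0.73, 0.82, 0.76, 0.79, 0.78, 0.75]
p_value = paired_permutation_test(best, second)
```

Under the null hypothesis the sign of each paired difference is arbitrary, so the p-value is the fraction of random sign assignments whose statistic is at least as extreme as the observed one.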

V-A Multimodal segmentation and zero-shot learning

A primary contribution of our work is to perform zero-shot learning. But first we should show that we can learn and infer in a multimodal setting. Thus, we first demonstrate that multiple inputs at training and inference time benefit segmentation. Table I presents test Dice scores on three datasets for MMSDNet and the benchmarks of Section IV-C. Two setups are evaluated, assuming either that annotations are available for the target modality or not (zero-shot learning).

In the 100% case, training with multiple inputs always improves accuracy, even when multimodal data simply constitute an augmented dataset. The usage of multiple inputs at inference time by MMSDNet obtains a similar mean Dice as other benchmarks, but considerably reduces the standard deviation, such as in the CHAOS case, where it is reduced from 15% to 5%.

In the 0% case (zero-shot), the (learned) benchmark methods fail to produce accurate segmentations for all datasets. MMSDNet on the other hand, is able to consistently maintain a better average and smaller variance by leveraging information from the source modality. This is due to the aligning of the multimodal representations in the anatomy space, which allows learning the target modality segmentation with “zero” supervised examples.

Method        Train     Inference
100% target annotations
copy                    multi
register                multi
UNet          single    single
UNet          multi     single
DualStream    multi     single
MMSDNet       multi     single
MMSDNet       multi     multi
0% target annotations
copy                    multi
register                multi
UNet          single    single
UNet          multi     single
DualStream    multi     single
MMSDNet       multi     single
MMSDNet       multi     multi

Table I: Segmentation results on three datasets with full (100%) and zero (0%) supervision on the target modality.

V-B Semi-supervised segmentation

Here we evaluate the sensitivity of all methods on different amounts of ground truth annotations available during training. Table II presents the average (across all labels) cross-validation test set Dice score. Exemplar test results are shown in Figure 4. The number of images for both source and target modalities are fixed, but the amount of target annotations varies. Sampling the amount of annotations is performed on a subject-level, to avoid having a mixture of annotated and non-annotated images of the same subject in the training set. The MMSDNet results correspond to using multiple inputs at inference time.

Average Dice for all methods is comparable when the number of annotations is high, although MMSDNet achieves the lowest variance. With a reducing number of annotations, the performance of the competing methods also reduces with a simultaneous increase in the variance. MMSDNet maintains good results and robustness to edge cases, as evidenced by the small variance achieved throughout all setups.

Table II: Segmentations of LGE, BOLD and T2 (Dice per method), when training with a varying amount of annotations (50%, 25% and 12.5%) for ERI, BOLD, and CHAOS datasets respectively.

Figure 4: Panel of LGE segmentation examples from ERI dataset, obtained with different amount of LGE annotations.

V-C Effect of pair matching

The results of Sections V-A and V-B correspond to expertly paired multimodal inputs. Here, we evaluate the sensitivity of MMSDNet on unpaired multimodal images, as well as the effect of the automated pairing cost proposed in Section III-F.

We randomly shuffle the multimodal pairs by two positions, with the shuffled pairs differing up to two spatial slices within a 3D volume. (Similar results can be obtained by shuffling the different cardiac phases in the cine-MR temporal stack.) We measure the LGE segmentation Dice score on ERI data when using 100% and 0% LGE annotations. We thus compare our automated method with expert pairing (upper bound) and a random shuffle (lower bound). Table III presents the results of the copy method, as well as of MMSDNet evaluated with both cine-MR and LGE inputs.

Shuffling the multimodal pairs decreases the copy performance considerably. In both cases, automated matching of candidate pairs based on the semantics of the anatomical factors proves effective in ignoring distant slices (in the volume), with results very closely approaching the ones achieved by expert pairing. To show how our model learns appropriate weights, the evolution of the weights across training epochs is shown in Figure 5, in which the weight that corresponds to the closest pair converges to a probability of one early on in training.

During inference, a “soft” segmentation mask is produced as the weighted sum of the candidate masks, each multiplied by its corresponding weight. However, this converges to using the prediction of the “closest” pair, as evidenced by Figure 5.
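The weighted fusion of candidate masks can be sketched as follows. This is a dependency-free illustration on nested lists; the function name `fuse_masks`, the toy masks and the two weight vectors are hypothetical and only show how the sum collapses onto one prediction once the weights have converged.

```python
def fuse_masks(weights, masks):
    """Weighted sum of candidate masks.

    weights: list of K non-negative pair weights (assumed to sum to one).
    masks:   list of K masks, each a 2D list of per-pixel probabilities.
    """
    if len(weights) != len(masks):
        raise ValueError("one weight per candidate mask is required")
    height, width = len(masks[0]), len(masks[0][0])
    fused = [[0.0] * width for _ in range(height)]
    for w, mask in zip(weights, masks):
        for i in range(height):
            for j in range(width):
                fused[i][j] += w * mask[i][j]
    return fused


# Three hypothetical candidate pairs for a 2x2 image.
masks = [
    [[1.0, 0.0], [1.0, 1.0]],  # mask predicted from the closest pair
    [[0.0, 0.0], [1.0, 0.0]],
    [[0.0, 1.0], [0.0, 0.0]],
]
early = fuse_masks([0.5, 0.25, 0.25], masks)  # weights still spread out
late = fuse_masks([1.0, 0.0, 0.0], masks)     # weights have converged
```

In the `late` case the fused mask equals the mask of the first candidate, which mirrors the behaviour described for the converged weights.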

Table III: LGE segmentation results when the multimodal images are not expertly paired. [Table body not recovered. Columns: Pair matching, copy, MMSDNet 0%, MMSDNet 100%; the copy entry of the automated row is n/a.]
Figure 5: Evolution of weights across epochs. Weights are used as a measure of similarity between each candidate multimodal pair. For more details see text.

V-D Effect of STN

We assess the need for a registration module with an ablated model. We compare the accuracy of a fused segmentation that is obtained with and without the STN module. Two MMSDNet models are compared, trained on ERI data with 100% and 0% LGE annotations. In both settings, the mean Dice obtained without the STN is lower than the Dice of MMSDNet with the STN. Furthermore, in the 100% case the difference is statistically significant at the 1% level. Thus, the registration module improves the fused segmentation.

V-E Effect of decoder design on segmentation accuracy

The modular design of MMSDNet permits incorporation of components with different designs. We evaluate segmentation accuracy achieved by two decoder architectures: FiLM and SPADE. Specifically, we train a SPADE-based MMSDNet on ERI and CHAOS and compare with the FiLM-based MMSDNet for 100% and 0% annotations.

With 100% annotations, the SPADE-based MMSDNet achieves, on both ERI and CHAOS, a Dice identical to that achieved by the FiLM-based model. With 0% annotations, the SPADE-based and FiLM-based results on ERI and CHAOS remain similar.

We can conclude that the regularising effect that the reconstruction process has on extracting segmentations is similar for both decoder variants. However, different decoder designs influence the way the anatomical and modality factors interact to produce a synthetic image. We explore this next.

Figure 6: Reconstructions with two decoders. The FiLM synthetic image is flatter and lacks texture, in contrast to the SPADE synthetic image. Images taken from the CHAOS dataset.

V-F Evaluating disentanglement

Even though FiLM and SPADE decoders do not result in evident differences in segmentation accuracy, they produce synthetic images of different quality (Figure 6). Since the anatomical factors contain flat regions, FiLM-based conditioning with scalar parameters tends to produce images with less texture details than SPADE-based conditioning.

Here, we aim to assess the information retained in the modality factors, and characterise the achieved disentanglement. This is a challenging problem not addressed in existing literature: existing metrics all assume vector latent variables (e.g. BetaVAE score [kingma2013auto]). In MMSDNet, and typically in content/style disentanglement, the factors of variation are not of the same dimensionality, with the anatomy being spatial. For the experiments below, we use models trained on CHAOS with 100% T2 annotations to assess (dis)entanglement using classification tests, factor arithmetics, and a proposed metric of independence of random variables.

Figure 7: (a) FiLM and (b) SPADE based reconstructions. Images per row correspond to interpolating a single dimension. The last two columns (correlation and difference image) indicate the regions most affected by each dimension.

V-F1 Modality classification

On the premise that the common modality encoder correctly extracts modality features, a classifier should detect the modality type, given just the modality factor. We assess this hypothesis by training a logistic regression classifier to predict whether different modality factors correspond to T1 or T2 images. The classifier’s accuracy is 99% and 97% for FiLM and SPADE, respectively, on a test set of three subjects.

We further evaluate whether specific dimensions of the modality factor capture the modality type by repeating the experiment for each dimension separately. In the FiLM model, the 2nd dimension achieves 100% accuracy, whereas the rest vary between 54% and 64%. Similarly, in the SPADE model, the 7th dimension achieves 97% accuracy, whereas the others vary between 42% and 63%.
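The per-dimension test can be sketched as below. This is a minimal illustration, not the authors' code: the logistic regression is a plain gradient-descent fit on a single feature, and the 8-dimensional factors are synthetic, with one dimension (index 1, an arbitrary choice) made to carry the label by construction. All function names and hyperparameters are assumptions.

```python
import math
import random


def fit_logistic_1d(xs, ys, lr=0.5, epochs=500):
    """Fit p(y=1|x) = sigmoid(a*x + b) with batch gradient descent."""
    a, b = 0.0, 0.0
    n = len(xs)
    for _ in range(epochs):
        grad_a = grad_b = 0.0
        for x, y in zip(xs, ys):
            p = 1.0 / (1.0 + math.exp(-(a * x + b)))
            grad_a += (p - y) * x
            grad_b += p - y
        a -= lr * grad_a / n
        b -= lr * grad_b / n
    return a, b


def accuracy(a, b, xs, ys):
    """Fraction of samples whose thresholded prediction matches the label."""
    preds = [1 if a * x + b > 0 else 0 for x in xs]
    return sum(p == y for p, y in zip(preds, ys)) / len(ys)


def per_dimension_accuracy(train, test):
    """Train one classifier per factor dimension; return the test accuracies."""
    n_dims = len(train[0][0])
    scores = []
    for d in range(n_dims):
        a, b = fit_logistic_1d([z[d] for z, _ in train], [y for _, y in train])
        scores.append(accuracy(a, b, [z[d] for z, _ in test], [y for _, y in test]))
    return scores


def make_toy_factors(n, n_dims=8, informative_dim=1, seed=0):
    """Synthetic stand-in for encoder outputs: only one dimension carries the label."""
    rng = random.Random(seed)
    data = []
    for i in range(n):
        label = i % 2  # 0 and 1 play the role of the two modalities
        z = [rng.gauss(0.0, 1.0) for _ in range(n_dims)]
        z[informative_dim] = (2.0 if label else -2.0) + rng.gauss(0.0, 0.3)
        data.append((z, label))
    return data


train = make_toy_factors(200, seed=0)
test = make_toy_factors(40, seed=1)
scores = per_dimension_accuracy(train, test)
```

On this toy data the informative dimension is classified perfectly, while the remaining dimensions are pure noise, which is the pattern the experiment above looks for in the learned factors.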

V-F2 Modality factor arithmetics

We qualitatively examine the information retained in each dimension of the modality factor with latent space arithmetics. The likelihood of the modality factor approximates a Gaussian prior, and therefore interpolating over a fixed range of values covers the probability space. Figure 7 shows synthetic images arranged in a grid; images of each row are produced by interpolating the values of a single dimension of the modality factor, with the remaining ones fixed. The final two columns highlight affected regions by calculating the per-pixel Pearson correlation, as well as the difference between the synthetic images at the two extreme values of the range.
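A minimal sketch of this traversal analysis is given below. The decoder is a hypothetical two-dimensional stand-in (`toy_decode`), and the traversal values are an assumed range rather than the one used in the paper; only the computation of the per-pixel correlation and difference images is meant to be illustrative.

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences (0.0 if either is constant)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    if vx == 0.0 or vy == 0.0:
        return 0.0
    return cov / math.sqrt(vx * vy)


def traverse_dimension(decode, z, dim, values):
    """Decode one image per value, changing only z[dim]."""
    images = []
    for v in values:
        z_new = list(z)
        z_new[dim] = v
        images.append(decode(z_new))
    return images


def correlation_and_difference(images, values):
    """Per-pixel correlation with the traversed value, and last-minus-first difference."""
    height, width = len(images[0]), len(images[0][0])
    corr = [[pearson(values, [img[i][j] for img in images]) for j in range(width)]
            for i in range(height)]
    diff = [[images[-1][i][j] - images[0][i][j] for j in range(width)]
            for i in range(height)]
    return corr, diff


def toy_decode(z):
    """Hypothetical decoder: z[0] shifts the whole image, z[1] only one pixel."""
    return [[z[0] + z[1], z[0]], [z[0], z[0]]]


values = [-2.0, -1.0, 0.0, 1.0, 2.0]  # assumed traversal range, not from the paper
images = traverse_dimension(toy_decode, [0.0, 0.0], dim=1, values=values)
corr, diff = correlation_and_difference(images, values)
```

Traversing the second toy dimension yields a correlation image that is non-zero only at the single pixel it controls, which is the kind of localised effect the correlation column is meant to reveal; traversing the first dimension would instead light up every pixel.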

Both decoders have one dimension of the modality factor that has a global effect on the image and controls the “modality” type. This finding is in line with the classification results above. Furthermore, some dimensions of the FiLM decoder, e.g. the 8th, appear to be focused on specific anatomical regions, whereas the dimensions of the SPADE decoder produce more diffused correlation images. The latter is likely related to the fact that SPADE can generate texture better.

V-F3 Disentanglement metric

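We propose the use of distance correlation [szekely2007measuring] as a metric of factor independence (and disentanglement), which is invariant to the dimensionality of the input variables and can detect both linear and non-linear associations. While distance correlation has been used before for reducing data leakage [vepakomma2019reducing], we use it here for measuring (dis)entanglement. Distance correlation is defined as

$$\mathrm{dCor}(X, Y) = \frac{\mathrm{dCov}(X, Y)}{\sqrt{\mathrm{dVar}(X)\,\mathrm{dVar}(Y)}},$$

where $\mathrm{dCov}(X, Y)$ is the distance covariance of $X$ and $Y$, and $\mathrm{dVar}(\cdot)$ is the distance variance. Given $n$ paired random samples $(X_j, Y_j)$ with $j = 1, \dots, n$, the squared distance covariance is the element-wise product of two distance matrices (one for each variable) averaged over their $n^2$ entries, $\mathrm{dCov}^2(X, Y) = \frac{1}{n^2} \sum_{j,k} A_{j,k} B_{j,k}$, where each distance matrix is double centred by subtracting the row mean and the column mean from each element and adding the overall mean: $A_{j,k} = a_{j,k} - \bar{a}_{j\cdot} - \bar{a}_{\cdot k} + \bar{a}_{\cdot\cdot}$, with $a_{j,k} = \lVert X_j - X_k \rVert$ (and similarly for $B_{j,k}$). The distance variance is then $\mathrm{dVar}(X) = \mathrm{dCov}(X, X)$, and $\mathrm{dVar}(Y) = \mathrm{dCov}(Y, Y)$.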
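The sample distance correlation described above can be computed with a short, dependency-free sketch. It follows the standard definition of [szekely2007measuring]; the function names are illustrative, and representing each factor as a flat vector (for instance a flattened anatomical map) is an assumption made here so that two variables of different dimensionality can be compared.

```python
import math


def _centred_distances(samples):
    """Pairwise Euclidean distance matrix, double centred."""
    n = len(samples)
    d = [[math.dist(a, b) for b in samples] for a in samples]
    row_mean = [sum(row) / n for row in d]
    col_mean = [sum(d[i][j] for i in range(n)) / n for j in range(n)]
    grand_mean = sum(row_mean) / n
    return [[d[i][j] - row_mean[i] - col_mean[j] + grand_mean for j in range(n)]
            for i in range(n)]


def distance_correlation(xs, ys):
    """Sample distance correlation between paired samples xs and ys.

    xs and ys are equal-length lists of flat vectors; the vectors in xs may
    have a different dimensionality from the vectors in ys.
    """
    if len(xs) != len(ys):
        raise ValueError("xs and ys must contain the same number of samples")
    n = len(xs)
    a, b = _centred_distances(xs), _centred_distances(ys)

    def mean_product(p, q):
        return sum(p[i][j] * q[i][j] for i in range(n) for j in range(n)) / (n * n)

    dcov2 = mean_product(a, b)           # squared distance covariance
    dvar2_x = mean_product(a, a)         # squared distance variance of xs
    dvar2_y = mean_product(b, b)         # squared distance variance of ys
    if dvar2_x * dvar2_y == 0.0:
        return 0.0                       # a constant variable carries no dependence
    return math.sqrt(max(dcov2, 0.0) / math.sqrt(dvar2_x * dvar2_y))


# Toy check: ys is a scaled copy of xs embedded in 2D, so the dependence is exact.
xs = [[0.0], [1.0], [2.0], [4.0]]
ys = [[2.0 * x[0], 0.0] for x in xs]
dcor_dependent = distance_correlation(xs, ys)
dcor_constant = distance_correlation(xs, [[1.0, 1.0]] * 4)
```

A value close to one indicates strong dependence and a value close to zero indicates independence, so a lower score between anatomical and modality factors would point to better disentanglement. The double loop is quadratic in the number of samples, which is acceptable for a sketch but would normally be vectorised.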

The distance correlation between the anatomical and modality factors is lower for the FiLM-based model than for the SPADE-based model. This suggests that the anatomical and modality factors obtained by a FiLM decoder are more independent, and therefore the FiLM-based model is more disentangled. Although distance correlation cannot explicitly evaluate the type of information in each variable, this result can be explained intuitively by the decoder design. The SPADE decoder allows more flexibility to the factors, and this is evident both in the synthetic images, which contain more texture, and in the diffused correlation images of Figure 7b, implying a higher correlation (and higher entanglement) between the anatomical and modality factors.

VI Conclusion

We have presented a method for multimodal learning, and specifically multimodal segmentation, that is robust to misregistered and imperfectly paired input images. This has been made possible by disentangling images into semantic anatomical factors that are consistently represented across modalities, and modality factors that model the intensity variability of the multimodal inputs in a smooth latent space.

We have proposed MMSDNet, which, to the best of our knowledge, is the first work that enables multimodal and zero-shot segmentation by aligning disentangled anatomical representations, and can be trained with zero annotations for one of the modalities. We presented the benefit of multimodal (over single-modal) learning in cardiac and abdominal segmentation, where we achieve high accuracy and low variance through the fusion of anatomical information of different modalities. We further demonstrated robustness to misalignments in the multimodal data (achieved by a spatial transformer network), and robustness to the quality of the multimodal pair matching (with an optional weighting of the multimodal pairs), both made possible by comparing the semantic anatomical factors. Finally, we made a first step in evaluating the quality of the content/style disentanglement using the distance correlation.

The significance of our work lies in the potential for the use of disentangled representations in other challenging problems of medical research. Future directions include the learning of further factorisations suitable for medical data, for instance to capture pathological information and specific artefacts. They also include a theoretical characterisation of the disentangling process and a precise quantification of the type of information that is captured by each factor, which admittedly is more complex in content/style disentanglement than in vectorised latent spaces, for which metrics have been recently suggested [do2019theory].