
Clinically Plausible Pathology-Anatomy Disentanglement in Patient Brain MRI with Structured Variational Priors

We propose a hierarchically structured variational inference model for accurately disentangling observable evidence of disease (e.g. brain lesions or atrophy) from subject-specific anatomy in brain MRIs. With flexible, partially autoregressive priors, our model (1) addresses the subtle and fine-grained dependencies that typically exist between anatomical and pathological generating factors of an MRI to ensure the clinical validity of generated samples; (2) preserves and disentangles finer pathological details pertaining to a patient's disease state. Additionally, we experiment with an alternative training configuration where we provide supervision to a subset of latent units. It is shown that (1) a partially supervised latent space achieves a higher degree of disentanglement between evidence of disease and subject-specific anatomy; (2) when the prior is formulated with an autoregressive structure, knowledge from the supervision can propagate to the unsupervised latent units, resulting in more informative latent representations capable of modelling anatomy-pathology interdependencies.





1 Introduction

In medical image analysis, various representation learning methods have been introduced to disentangle particular anatomical or pathological structures from the rest of the image under different observations (SotosDisentanglementTutorial; fragemann2022reviewDisentanglement). Although disentanglement of observable evidence of disease (e.g. brain lesions or atrophy) from subject-specific anatomy has been shown to be helpful in various downstream tasks (qin2019unsupervisedDisentanglementforRegesteration; DisentangledAAEAgePredictionMissingModality; zhou2021chestCOVIDdisentanglement; zuo2021informationUnsupervisedDomainAdaptationDisentanglement; liu2021semisupervisedDisentanglementdomain-generalised; liu2021unifiedBrainMRTranslationDisentanglement; D2-NetDisentanglementSegmentationMissingModality), maintaining the clinical plausibility of the learned representations is particularly challenging in the context of brain MRI analysis, for a number of reasons. Firstly, pathological and anatomical features in brain MRIs typically exhibit some degree of dependency, but the exact structure is generally unknown a priori or requires extensive domain knowledge to accurately describe. For example, multiple sclerosis (MS) is a chronic neurological disease characterized by T2 hyperintense lesions and gadolinium-enhancing (Gd+) lesions in the brain and spinal cord. MS lesions are typically found in periventricular white matter but cannot occur in certain brain regions (e.g. within the ventricles). It is often the case that the observable pathological features in the brain (e.g. hyper-intense MS lesions) cannot be easily disentangled from subject-specific anatomical structures (e.g. sulcal pattern, ventricular shape) in a clinically plausible way due to dependencies between the anatomical and pathological generating factors. Secondly, fine-grained pathological features in brain MRIs may have very high clinical significance.
Although the omission of finer details does not preclude generative models from achieving a good approximation of the overall data distribution (i.e. a good test set log likelihood), such details could be highly meaningful in certain downstream tasks. This is, again, exemplified by MS lesions that can be as small as 3 mm but still represent a significant marker of disease activity (MSLesionSize). A robust and clinically useful representation in the context of brain MRI analysis must therefore: 1) faithfully model the spatial distribution of the lesions and their dependency on patients’ brain anatomical structures, and 2) accurately capture and disentangle lesions of all sizes, along with other potential imaging markers characterized by fine-grained details in the images (e.g. white matter texture).

Many existing methods for learning disentangled representations are based on the variational auto-encoder (VAE) (Kingma2014VAE; burgess2018betaVAE; higgins2016betaVAE). However, a straightforward adoption of VAEs for pathology-anatomy disentanglement in patient brain MRIs is often unsatisfactory due to the aforementioned challenges. There are two common failure modes. Firstly, mean-field variational inference imposes the unrealistic assumption that all generating factors are independent (VariationalReview), thereby failing to capture the inherent dependencies that exist between the anatomical and pathological generative factors in patient brain images. This may result in the synthesis of clinically implausible samples, as depicted in fig:IntroInvalidMSLesions, where lesions appear in clinically impossible regions (red arrow). Moreover, the independence assumption may exacerbate another pitfall of VAEs (VariationalReview), namely, the tendency to suffer from latent mode-covering or over-generalisation (hu2018onUnifying). This may lead to lower synthesis quality or an inability to preserve crucial finer details in medical images. fig:IntroBlurredMSLesions depicts such behaviour in a mean-field model where small lesions are obscured in the reconstructed images. Failure to capture fine details in the learned representation could lead to significantly poorer performance in downstream tasks.


Figure 1: Failure modes of the mean-field VAE. (a) A sample drawn from a mean-field model with lesions in clinically invalid locations (red arrow). (b) The mean-field model misses small lesions in the reconstructed image (yellow arrows). Our proposed structured model does not suffer from these issues.

In this work, we propose to address these issues by using structured variational inference (StructuredVI) for fine-grained pathology-anatomy disentanglement in brain MRI. Specifically, we model the dependencies that typically exist between pathological and anatomical features via multi-scale VAEs with a hierarchical latent structure (fig:GraphicalModel). We evaluate the effect of different priors on reconstruction quality and the degree of disentanglement, both quantitatively and qualitatively. We find that more expressive structured priors indeed lead to higher reconstruction quality and the preservation of important small pathological details. We verify that the model is capable of pathology disentanglement in an unsupervised setting. With an optional supervision objective, the model is shown to achieve a higher degree of disentanglement and to be capable of capturing latent dependencies.

2 Model

Our model accepts image space observations sampled from a dataset as inputs. The model adheres to the customary setup of VAEs with a structured latent space consisting of disjoint variable groups (layers) that follow a hierarchical structure, as portrayed in fig:GraphicalModel. Each group or layer consists of spatial latent variables (SpatialGMVAE) at various resolution scales, denoted as . The inference and generative models can therefore be expressed as follows:
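Assuming the standard top-down hierarchical factorisation used by NVAE-style models (a reconstruction under that assumption, not the paper's verbatim equations), the inference and generative models take the form:

```latex
q(\mathbf{z} \mid x) = q(z_1 \mid x) \prod_{l=2}^{L} q(z_l \mid z_{<l}, x),
\qquad
p(x, \mathbf{z}) = p(x \mid \mathbf{z})\, p(z_1) \prod_{l=2}^{L} p(z_l \mid z_{<l}),
```

where \(z_l\) denotes the spatial latent group at scale \(l\) and \(z_{<l}\) the groups above it in the hierarchy.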



Figure 2: The hierarchically structured latent space with disjoint variable groups (layers), encoder distribution parameters () and decoder distribution parameters (). Dashed lines are active during inference. Solid lines are active during generation and, if residual parameterisation is used, also during bidirectional inference.


Figure 3: Network Architecture

tab:ModelParametrizations

Model   ELBO
VAE     (M.1)
nVAE    (M.2)
nVMP    (M.3)
nVMP+   (M.4)

Table 1: ELBOs for all four model parameterisation options examined in this work.

In this paper, we examine four ways to construct the ELBO objective, summarized in tab:ModelParametrizations.

(1) VAE (1): a vanilla multi-scale VAE with mean-field, parameter-free standard Gaussian priors at each layer. This model has a “hierarchy” in the sense of having latent representations at various resolution scales, but is not “hierarchical” in its distributional parametrization as there is no explicit inter-group dependency nor explicit information sharing between the encoder and the decoder. In this parametrization:
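As a sketch of the mean-field case, the per-layer KL term against a parameter-free standard Gaussian has the familiar closed form (function and variable names are illustrative, not from the paper's code):

```python
import numpy as np

def mean_field_kl(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent units.

    This is the per-layer KL term of the vanilla multi-scale VAE (M.1).
    With no inter-group dependency, the total KL is simply the sum of this
    quantity over all layers of the hierarchy.
    """
    return 0.5 * np.sum(np.square(mu) + np.exp(logvar) - 1.0 - logvar)
```

Each layer's KL is computed independently of all others, which is precisely the absence of structure that the following parameterisations address.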



(2) nVAE (1): a hierarchical model with residual normal parameterisation proposed by NVAE and LadderRasmusVHBR15. The features that set this model apart from its vanilla counterpart are the explicit information sharing between the encoder and the decoder networks, as well as its partially auto-regressive nature. Firstly, unlike conventional VAEs, decoder parameters in nVAE not only characterize the generative distribution (3b) but are also a part of the inference model and hence play an important role in characterizing the posterior distribution (3a). For latent groups other than the topmost one, the inference model is bidirectional. It estimates the relative variational posteriors (3a) that characterize the deviation from priors obtained from preceding layers of the decoder. With this design, KL optimization is expected to be simpler than when posteriors predict the absolute means and variances at each layer.



Furthermore, nVAE is partially auto-regressive and hence employs a more expressive prior than the standard mean-field parametrization. While the prior for each group is dependent on those of the preceding layers , each element within the same latent group still adheres to the independence assumption as . We can hence calculate the relative KLD loss for each element with a simple analytic expression:
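Under the residual parameterisation, the posterior at each unit is N(mu + delta_mu, (sigma * delta_sigma)^2) relative to the decoder prior N(mu, sigma^2), and the relative KL has a simple closed form, as in NVAE. A sketch with illustrative names:

```python
import numpy as np

def residual_kl(delta_mu, log_delta_sigma, prior_sigma):
    """Analytic KL between the residual posterior
    N(mu + delta_mu, (sigma * delta_sigma)^2) and the decoder prior
    N(mu, sigma^2), summed over latent units.

    The prior mean cancels out: only the relative shift (delta_mu) and
    relative scale (delta_sigma) appear, which is what makes KL
    optimisation simpler than predicting absolute means and variances.
    """
    delta_sigma_sq = np.exp(2.0 * log_delta_sigma)
    return 0.5 * np.sum(
        np.square(delta_mu) / np.square(prior_sigma)
        + delta_sigma_sq
        - 2.0 * log_delta_sigma
        - 1.0
    )
```

When the posterior exactly matches the prior (zero shift, unit relative scale), the KL vanishes regardless of the prior's own parameters.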


We propose two extensions to nVAE by integrating a VamPrior (VamPriorTomczakW17) into the hierarchical VAE setup for extra flexibility and expressiveness in the hierarchical latent structure. By incorporating encoder parameters in the trainable prior, we expect the following two models to achieve a greater extent of “coupling” or “collaboration” between the priors and the posteriors.

(3) nVMP (1) replaces the standard-Gaussian prior of the topmost layer of nVAE with a -component multimodal VamPrior characterized by trainable pseudo-inputs and encoder parameters . The subsequent layers still adhere to the hierarchical residual Gaussian parameterisation of nVAE. The implication is that only is multi-modal, whereas the distributions of all lower-level “residual deviations” are assumed to be Gaussian.

(4) nVMP+ (1) extends nVMP by imposing an additional KL term between the encoder-driven VamPriors and the decoder priors across the entire hierarchy. In this case, the decoder “priors” are regarded as “intermediate posteriors” and encouraged to imitate the encoder-driven multi-modal distribution throughout the hierarchy. We postulate that this configuration adds an extra layer of information sharing between the encoder and decoder networks, potentially leading to further improvement in representation quality.
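A minimal sketch of evaluating the VamPrior density at the topmost layer, assuming the encoder has already been run on the K trainable pseudo-inputs (the (K, D) arrays of means and log-variances below are an illustrative interface, not the paper's code):

```python
import numpy as np

def vamp_log_prior(z, pseudo_mus, pseudo_logvars):
    """log p(z) under a K-component VamPrior: an equal-weight mixture of
    the diagonal-Gaussian encoder posteriors evaluated at K pseudo-inputs.

    z: (D,) latent sample; pseudo_mus, pseudo_logvars: (K, D) encoder
    outputs at the pseudo-inputs.
    """
    # diagonal-Gaussian log density of z under each mixture component
    log_probs = -0.5 * np.sum(
        pseudo_logvars
        + np.square(z - pseudo_mus) / np.exp(pseudo_logvars)
        + np.log(2.0 * np.pi),
        axis=1,
    )  # shape (K,)
    # log-mean-exp over components, shifted by the max for stability
    m = np.max(log_probs)
    return m + np.log(np.mean(np.exp(log_probs - m)))
```

Because the mixture components are themselves encoder posteriors, training the pseudo-inputs couples the prior to the inference network, which is the “collaboration” between priors and posteriors described above.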

3 Experiments and Results

We validate our approach on two brain MRI datasets: the publicly available Alzheimer’s Disease Neuroimaging Initiative (ADNI) dataset (mueller2005adni) (), and a proprietary dataset from an MS clinical trial (). The central 16 2-D slices of T1-weighted sequences were used for the AD experiments, while the central 24 2-D slices of Fluid-Attenuated Inversion Recovery (FLAIR) sequences were used for the MS experiments. Expert T2 lesion segmentation labels were provided for the MS experiments. Both datasets were divided into non-overlapping training (60%), validation (20%) and testing (20%) sets. Additional acquisition and pre-processing details are described in Appendix A.

We first train the model in an unsupervised setting and evaluate the effect of incorporating additional prior structure on synthesis quality. As shown in Table LABEL:tab:MainResults, VAEs with more expressive structured priors indeed outperform their mean-field counterpart of the same model capacity in terms of image reconstruction fidelity.

tab:MainResults

(MS)
Model   LL     PSNR   SSIM   FID
VAE     2758   25.1   0.72   0.058
nVAE    2458   25.7   0.75   0.023
nVMP    2374   25.9   0.75   0.031
nVMP+   1953   26.6   0.79   0.035

(AD)
Model   LL     PSNR   SSIM   FID
VAE     2386   24.6   0.70   0.030
nVAE    2105   25.2   0.73   0.013
nVMP    1863   25.5   0.75   0.011
nVMP+   842    26.8   0.80   0.007

Table 2: Reconstruction quality metrics. (wang2004SSIM; FID)

We additionally examine model behaviours in a supervised learning setting, depicted in the bottom-left (purple) block in fig:GraphicalModel and fig:ArchDiagram. In this setting, we supplement the MS model with a lesion segmentation objective between a chosen “pathological” latent subset and expert pathology (lesion) segmentation labels . The rest of the latent space (which remains unsupervised) is regarded as the anatomical latent subsets, denoted as .

Firstly, supervision is shown to enhance latent disentanglement, as one might anticipate. Disease-related features in the synthesized images are noticeably more sensitive (VAEPCA) to perturbations in compared to , as shown in fig:AttributeSensitivity, Appendix B. Such a disparity in attribute sensitivity is appreciable in unsupervised models, but is made much more pronounced by the selective latent supervision.

Secondly, and more importantly, supervision helps to verify that the model is indeed actively using the latent structures. In models with autoregressive structures (nVAE, nVMP, nVMP+), knowledge from the supervision propagates to the unsupervised “anatomical” latent units: those unsupervised latent units attain higher linear predictability (Lasso regression scores (eastwood2018frameworkDisentanglement), Table LABEL:tab:Informativeness) with respect to lesion volume. This is in contrast to the behaviour of the baseline mean-field VAE, where information from the supervision task remains confined to the supervised group . This observation shows that the model is indeed taking advantage of the extra structure brought by the autoregressive priors and the residual parameterisation, and is hence capable of modelling the dependencies between anatomical and pathological generating factors.


Figure 4: Pathology-anatomy disentanglement visualised by “style-mixing” between pairs of Alzheimer’s Disease (AD) and Multiple Sclerosis (MS) patient brain MRIs.

We can qualitatively evaluate pathology-anatomy disentanglement by swapping anatomical and pathological latent features between a pair of subjects in a manner similar to “style-mixing” (StyleGAN). As shown in fig:StyleMixing for representative examples, brain atrophy in AD patients (left) and T2 lesions in MS patients (right) are disentangled from the subjects’ anatomical particularities (such as sulcal pattern), thus enabling the mixing of one patient’s pathology with the other’s anatomy.
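The style-mixing operation itself reduces to swapping one latent subset between two subjects before decoding. A sketch, assuming latents are held as per-layer dictionaries and the pathological layer index has already been identified (via supervision or attribute-sensitivity analysis):

```python
import numpy as np

def style_mix(latents_a, latents_b, pathology_layer):
    """Mix Subject A's anatomy with Subject B's pathology.

    latents_a, latents_b: dicts mapping layer index -> latent array
    (a hypothetical container format). All layers are taken from A
    (anatomy) except the chosen pathology layer, taken from B.
    Inputs are left unmodified.
    """
    mixed = {l: z.copy() for l, z in latents_a.items()}      # anatomy from A
    mixed[pathology_layer] = latents_b[pathology_layer].copy()  # pathology from B
    return mixed
```

Decoding `style_mix(latents_A, latents_B, k)` would then yield Subject A's brain anatomy carrying Subject B's lesions or atrophy, as visualised in fig:StyleMixing.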

We may also leverage conditional distributions learned by the model to examine subject-specific pathology distributions. For example, based on learned representations of Subject B1 in fig:StyleMixing, we may visualise many possible disease states given this subject’s anatomy (fig:ConditionalResampling, top row) or explore how this subject’s lesions would manifest on other subjects’ brain anatomies (fig:ConditionalResampling, bottom row).


Figure 5: Visualising conditional distributions. Images in the top row are generated by fixing to that of Subject B1 and resampling the layer corresponding to . Images in the bottom row are generated by obtaining from other real samples and fixing to that of Subject B1.

4 Conclusions

We propose hierarchical VAEs with structured priors for learning pathology-anatomy disentangled representations of brain MRIs. Our model can faithfully capture imaging features, including fine-grained details, while accounting for pathology-anatomy dependencies to ensure sample validity. We additionally examine model behaviours in a supervised learning setting. Supervision is shown to (1) further enhance latent disentanglement; and (2) enable the inspection of information propagation between latent groups for modelling pathology-anatomy interdependencies. Our model allows for robust and controllable brain MRI synthesis rich in high-frequency and pathologically sound details, which could be meaningful for various downstream tasks.

The authors are grateful to the International Progressive MS Alliance for supporting this work (grant number: PA-1412-02420), and to the companies who generously provided the clinical trial data that made it possible: Biogen, BioMS, MedDay, Novartis, Roche / Genentech, and Teva. Funding was also provided by the Natural Sciences and Engineering Research Council of Canada, the Canadian Institute for Advanced Research (CIFAR) Artificial Intelligence Chairs program, and a technology transfer grant from Mila - Quebec AI Institute. S.A. Tsaftaris acknowledges the support of Canon Medical and the Royal Academy of Engineering and the Research Chairs and Senior Research Fellowships scheme (grant RCSRF1819 / 8 / 25). Supplementary computational resources and technical support were provided by Calcul Québec and the Digital Research Alliance of Canada. Falet, J.-P., was supported by an end MS Personnel Award from the Multiple Sclerosis Society of Canada, by a Canada Graduate Scholarship-Masters Award from the Canadian Institutes of Health Research, and by the Fonds de recherche du Québec - Santé / Ministère de la Santé et des Services sociaux training program for specialty medicine residents with an interest in pursuing a research career, Phase 1. This work was made possible by the end-to-end deep learning experimental pipeline developed in collaboration with our colleagues Justin Szeto, Eric Zimmerman, and Kirill Vasilevski. Additionally, the authors would like to thank Louis Collins and Mahsa Dadar for preprocessing the MRI data.


Appendix A Data Acquisition, Implementation and Training Details

All MRI sequences were acquired at a resolution of 1 mm × 1 mm × 1 mm. Each 2-D slice was downsampled to an in-plane resolution of 2 mm × 2 mm. Image intensities were standardized to have zero mean and unit variance.

We compare the four parameterisations in Table tab:ModelParametrizations with a 5-layer model () of the exact same capacity. For each dataset, the latent space capacity is set to . We use the Adam optimizer (kingma2014adam) with a learning rate of 5e-5 and a weight decay of 1e-8.

Two loss re-weighting mechanisms are used in our training procedure: (1) We use a linear annealing schedule (fu_cyclical_2019_naacl) for KLD losses with a cycle length of 10000 iterations. The initial KLD learning rate is set to 2e-7. (2) To avoid posterior collapse, we use a KL Balancing trick suggested by (NVAE). We re-scale each KL term of the hierarchy with a coefficient proportional to the size of each latent layer as well as the KLD value of that layer. This mechanism encourages more balanced information attribution to each latent layer (vahdat2018dvae++LayerScalingKL; chen2016variationalLayerScalingKL).
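Both mechanisms can be sketched as follows. The exact ramp shape and normalisation are assumptions; only the 10000-iteration cycle length and the size-and-magnitude-proportional balancing idea are taken from the text:

```python
import numpy as np

def cyclical_beta(iteration, cycle_length=10000, max_beta=1.0):
    """Linear cyclical KL annealing (cf. Fu et al., 2019): beta ramps
    linearly from 0 to max_beta over the first half of each cycle, stays
    flat for the second half, then resets. The 50/50 split is an
    illustrative choice."""
    return max_beta * min(1.0, (iteration % cycle_length) / (0.5 * cycle_length))

def kl_balancing_weights(kl_values, layer_sizes):
    """Per-layer KL coefficients proportional to each layer's size and its
    current KL magnitude (the NVAE-style balancing trick), normalised so
    the coefficients average to 1."""
    w = np.asarray(kl_values, dtype=float) * np.asarray(layer_sizes, dtype=float)
    return w * len(w) / np.sum(w)
```

In training, the total KL term would be weighted by `cyclical_beta(step)` globally and by `kl_balancing_weights(...)` per layer, so that no single latent layer absorbs all of the information.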

Appendix B Additional Results

As discussed in Section 3, we evaluate layer-wise latent pathology informativeness of MS models by examining each layer’s linear predictability of a salient pathological attribute, T2 lesion volume. To quantify linear predictability, we train Lasso regressors () on latent representations obtained from each individual latent layer of each model and compute each regressor’s score with respect to T2 lesion volume based on expert segmentation labels.
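A sketch of this informativeness metric using scikit-learn. The regularisation strength `alpha` is a hypothetical stand-in, since the paper's value is elided, and for simplicity the R^2 score is computed on the fitting data:

```python
import numpy as np
from sklearn.linear_model import Lasso

def layerwise_informativeness(layer_latents, lesion_volumes, alpha=0.01):
    """For each latent layer, fit a Lasso regressor from that layer's
    (flattened) latents to T2 lesion volume and report the R^2 score,
    in the spirit of the informativeness metric of Eastwood & Williams
    (2018).

    layer_latents: list of arrays, one per layer, each (n_subjects, n_units)
    lesion_volumes: (n_subjects,) target values
    """
    scores = []
    for Z in layer_latents:
        model = Lasso(alpha=alpha).fit(Z, lesion_volumes)
        scores.append(model.score(Z, lesion_volumes))  # R^2
    return scores
```

A layer whose score is near zero carries essentially no linearly decodable lesion-volume information; a supervised pathological layer is expected to score much higher.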

tab:Informativeness

Model     Layer 1   Layer 2   Layer 3   Layer 4   Layer 5
VAE       0.20      0.08      0.28      0.01      0.00
nVAE      0.11      0.12      0.00      0.00      0.00
nVMP      0.06      0.23      0.01      0.01      0.00
nVMP+     0.03      0.00      0.20      0.00      0.00
psVAE     0.00      0.00      0.62      0.08      0.01
psNVAE    0.31      0.20      0.65      0.02      0.00
psNVMP    0.54      0.23      0.63      0.02      0.01
psNVMP+   0.21      0.37      0.56      0.09      0.00

Table 3: (MS) Layer-wise latent informativeness with respect to T2 lesion volume. Models with prefix “ps-” have partially supervised latent spaces (i.e. are -supervised).

In tab:Informativeness, rows 1-4 show the scores of the unsupervised models, which are generally poor; rows 5-8 show the same metrics for supervised models, where supervision is provided to () as an additional lesion segmentation objective. Models with autoregressive structures (nVAE, nVMP, nVMP+) benefit more from the supervision: knowledge from it propagates to the unsupervised “anatomical” latent units, resulting in higher scores even in the unsupervised latent subsets. This shows that the model is indeed actively using the latent structures.

Furthermore, by separately scaling each latent subgroup and observing the changes in the generated images, we can examine the features captured by each individual latent group. The marked disparity in layer-wise pathological attribute sensitivity to such scaling makes latent disentanglement evident in this visualisation.
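A sketch of this layer-scaling probe, with `decode` and `measure` (e.g. total lesion hyperintensity in the decoded image) as caller-supplied, hypothetical interfaces:

```python
import numpy as np

def layer_sensitivity(decode, measure, latents, layer, factors=(0.5, 2.0)):
    """Sensitivity of an image attribute to multiplicative scaling of one
    latent layer, holding all other layers fixed.

    decode:  latents dict -> image (hypothetical interface)
    measure: image -> scalar attribute (e.g. lesion hyperintensity)
    Returns the mean absolute change in the attribute over the factors.
    """
    baseline = measure(decode(latents))
    deltas = []
    for f in factors:
        perturbed = {l: z.copy() for l, z in latents.items()}
        perturbed[layer] = perturbed[layer] * f
        deltas.append(abs(measure(decode(perturbed)) - baseline))
    return float(np.mean(deltas))
```

Computing this score for every layer identifies the one (or few) layers to which a pathological attribute is sensitive, i.e. the disentangled pathological subset.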


Figure 6: (MS) Layer-wise pathological attribute sensitivity visualised by individually scaling each layer in the latent hierarchy, from (top row) to (bottom row).

In this particular example (fig:AttributeSensitivity), the appearance of the hyper-intense MS lesions in the synthesised images is relatively insensitive to multiplicative perturbation in all but one latent layer, . The layer with the highest pathological attribute sensitivity, , is hence considered to be a disentangled “pathological” latent subset .
We note that even in the unsupervised setting, disease-related features in the synthesized images are noticeably more sensitive to changes in a small subset of latent variables than to the rest, which allows us to identify such a subset as and the rest as (anatomical latent subsets) in a post-hoc manner. Such a disparity in pathological attribute sensitivity is much more pronounced in the “selective supervision” setting (bottom-left purple block in fig:GraphicalModel and fig:ArchDiagram), where the additional supervision is given to a chosen layer .


Figure 7: (AD) Variations captured by each layer of the model. Images at the top row are fully resampled at each level of the hierarchy. On each subsequent row , we show the residual variation of layer by fixing latent codes at the top layers.


Figure 8: (MS) Clusters discovered by VamPrior.