In medical image analysis, various representation learning methods have been introduced to disentangle particular anatomical or pathological structures from the rest of the image under different observations (SotosDisentanglementTutorial; fragemann2022reviewDisentanglement). Although disentanglement of observable evidence of disease (e.g. brain lesions or atrophy) from subject-specific anatomy has been shown to be helpful in various downstream tasks (qin2019unsupervisedDisentanglementforRegesteration; DisentangledAAEAgePredictionMissingModality; zhou2021chestCOVIDdisentanglement; zuo2021informationUnsupervisedDomainAdaptationDisentanglement; liu2021semisupervisedDisentanglementdomain-generalised; liu2021unifiedBrainMRTranslationDisentanglement; D2-NetDisentanglementSegmentationMissingModality), maintaining the clinical plausibility of the learned representations is particularly challenging in the context of brain MRI analysis, for a number of reasons. Firstly, pathological and anatomical features in brain MRIs typically exhibit some degree of dependency, but the exact structure is generally unknown a priori or requires extensive domain knowledge to describe accurately. For example, multiple sclerosis (MS) is a chronic neurological disease characterized by T2 hyperintense lesions and gadolinium-enhancing (Gd+) lesions in the brain and spinal cord. MS lesions are typically found in periventricular white matter but cannot occur in certain brain regions (e.g. within the ventricles). The observable pathological features in the brain (e.g. hyper-intense MS lesions) often cannot be easily disentangled from subject-specific anatomical structures (e.g. sulcal pattern, ventricular shape) in a clinically plausible way, due to dependencies between the anatomical and pathological generating factors. Secondly, fine-grained pathological features in brain MRIs may have very high clinical significance.
Although the omission of finer details does not preclude generative models from achieving a good approximation of the overall data distribution (i.e. a good test set log likelihood), such details could be highly meaningful in certain downstream tasks. This is, again, exemplified by MS lesions that can be as small as 3 mm but still represent a significant marker of disease activity (MSLesionSize). A robust and clinically useful representation in the context of brain MRI analysis must therefore: 1) faithfully model the spatial distribution of the lesions and their dependency on patients’ brain anatomical structures, and 2) accurately capture and disentangle lesions of all sizes, along with other potential imaging markers characterized by fine-grained details in the images (e.g. white matter texture).
Many existing methods for learning disentangled representations are based on the variational auto-encoder (VAE) (Kingma2014VAE; burgess2018betaVAE; higgins2016betaVAE). However, a straightforward adoption of VAEs for pathology-anatomy disentanglement in patient brain MRIs is often unsatisfactory due to the aforementioned challenges. There are two common failure modes. Firstly, mean-field variational inference makes the unrealistic assumption that all generating factors are independent (VariationalReview), thereby failing to capture the inherent dependencies that exist between the anatomical and pathological generative factors in patient brain images. This may result in the synthesis of clinically implausible samples, as depicted in fig:IntroInvalidMSLesions, where lesions appear in clinically impossible regions (red arrow). Moreover, the independence assumption may exacerbate another pitfall of VAEs (VariationalReview), namely, the tendency to suffer from latent mode-covering or over-generalisation (hu2018onUnifying). This may lead to lower synthesis quality or an inability to preserve crucial finer details in medical images. fig:IntroBlurredMSLesions depicts such behaviour in a mean-field model where small lesions are obscured in the reconstructed images. Failure to capture fine details in the learned representation could lead to significantly poorer performance in downstream tasks.
In this work, we propose to address these issues by using structured variational inference (StructuredVI) for fine-grained pathology-anatomy disentanglement in brain MRI. Specifically, we model the dependencies that typically exist between pathological and anatomical features via multi-scale VAEs with a hierarchical latent structure (fig:GraphicalModel). We evaluate the effect of different priors on reconstruction quality and the degree of disentanglement, both quantitatively and qualitatively. We find that more expressive structured priors indeed lead to higher reconstruction quality and the preservation of important small pathological details. We verify that the model is capable of pathology disentanglement in an unsupervised setting. With an optional supervision objective, the model is shown to achieve a higher degree of disentanglement and to be capable of capturing latent dependency.
Our model accepts image space observations sampled from a dataset as inputs. The model adheres to the customary setup of VAEs with a structured latent space consisting of disjoint variable groups (layers) that follow a hierarchical structure, as portrayed in fig:GraphicalModel. Each group or layer consists of spatial latent variables (SpatialGMVAE) at various resolution scales, denoted as . The inference and generative models can therefore be expressed as follows:
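To make the hierarchical factorization concrete, the following is a minimal sketch of ancestral sampling through a hierarchy of spatial latent groups, where each group's prior is conditioned on the sample from the group above. The shapes, the nearest-neighbour upsampling, and the linear "prior network" are illustrative placeholders, not the paper's architecture.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_hierarchy(num_groups=3, base_res=4, channels=2):
    """Ancestrally sample a hierarchy of spatial latents z_L, ..., z_1.

    Each group is a spatial map whose resolution doubles as we move down
    the hierarchy, mimicking p(z) = p(z_L) * prod_l p(z_l | z_{>l}).
    The toy prior parameters below stand in for a learned decoder.
    """
    z = rng.normal(size=(channels, base_res, base_res))  # z_L ~ N(0, I)
    groups = [z]
    for _ in range(num_groups - 1):
        context = np.kron(z, np.ones((1, 2, 2)))       # upsample previous group
        mu, log_sigma = 0.1 * context, 0.05 * context  # placeholder prior network
        z = mu + np.exp(log_sigma) * rng.normal(size=context.shape)
        groups.append(z)
    return groups

groups = sample_hierarchy()
```

Each sampled group would then be fed, together with the decoder features, into the next resolution level of the generative network.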
In this paper, we examine four ways to construct the ELBO objective, summarized in tab:ModelParametrizations.
(1) VAE (1): a vanilla multi-scale VAE with mean-field, parameter-free standard Gaussian priors at each layer. This model has a “hierarchy” in the sense of having latent representations at various resolution scales, but is not “hierarchical” in its distributional parametrization as there is no explicit inter-group dependency nor explicit information sharing between the encoder and the decoder. In this parametrization:
(2) nVAE (1): a hierarchical model with residual normal parameterisation proposed by NVAE and LadderRasmusVHBR15. The features that set this model apart from its vanilla counterpart are the explicit information sharing between the encoder and the decoder networks, as well as its partially auto-regressive nature. Firstly, unlike conventional VAEs, decoder parameters in nVAE not only characterize the generative distribution (3b) but are also a part of the inference model and hence play an important role in characterizing the posterior distribution (3a). For latent groups other than the topmost one , the inference model is bidirectional. It estimates the relative variational posteriors (3a) that characterize the deviation from priors obtained from preceding layers of the decoder. With this design, KL optimization is expected to be simpler than when posteriors predict the absolute means and variances at each layer.
Furthermore, nVAE is considered to be partially auto-regressive and hence a more expressive prior than the standard mean-field parametrization. While the prior for each group is dependent on those of the preceding layers , each element within the same latent group still adheres to the independence assumption as . We can hence calculate the relative KLD loss for each element with a simple analytic expression:
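As a hedged numerical illustration of this per-element residual KL (assuming the NVAE-style parameterisation where the posterior is N(mu_p + delta_mu, (sigma_p * delta_sigma)^2) and the decoder prior is N(mu_p, sigma_p^2)), the closed form can be written and checked against the generic Gaussian KL:

```python
import numpy as np

def residual_kl(delta_mu, log_delta_sigma, log_sigma_p):
    """KL( N(mu_p + delta_mu, (sigma_p*delta_sigma)^2) || N(mu_p, sigma_p^2) )
    per latent element. Because the posterior predicts only a deviation
    (delta_mu, delta_sigma) from the decoder-supplied prior, the prior
    mean mu_p cancels out of the expression entirely.
    """
    delta_sigma_sq = np.exp(2.0 * log_delta_sigma)
    return 0.5 * (delta_mu**2 / np.exp(2.0 * log_sigma_p)
                  + delta_sigma_sq - 2.0 * log_delta_sigma - 1.0)

def gaussian_kl(mu1, sigma1, mu2, sigma2):
    """Generic KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ) for reference."""
    return (np.log(sigma2 / sigma1)
            + (sigma1**2 + (mu1 - mu2)**2) / (2.0 * sigma2**2) - 0.5)
```

When the posterior exactly matches the prior (delta_mu = 0, delta_sigma = 1), the KL term vanishes, which is what makes the residual form easy to optimize.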
We propose two extensions to nVAE by integrating a VamPrior (VamPriorTomczakW17) into the hierarchical VAE setup for extra flexibility and expressiveness in the hierarchical latent structure. By incorporating encoder parameters in the trainable prior, we expect the following two models to achieve a greater extent of “coupling” or “collaboration” between the priors and the posteriors.
(3) nVMP (1) replaces the standard-Gaussian prior of the topmost layer of nVAE with a -component multimodal VamPrior characterized by trainable pseudo-inputs and encoder parameters . The subsequent layers retain the residual Gaussian parameterisation of nVAE. The implication is that only is multi-modal, whereas the distributions of all lower-level “residual deviations” are assumed to be Gaussian.
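For intuition, the VamPrior is a mixture of the encoder posteriors evaluated at trainable pseudo-inputs. The sketch below evaluates its log-density under the assumption of diagonal-Gaussian posteriors whose parameters (for each pseudo-input) are given as arrays; in the actual model these would be produced by the shared encoder network.

```python
import numpy as np

def vamp_log_prior(z, pseudo_mus, pseudo_log_sigmas):
    """log p(z) for a K-component VamPrior, p(z) = (1/K) sum_k q(z | u_k),
    where each component is the encoder posterior at pseudo-input u_k.

    pseudo_mus, pseudo_log_sigmas: arrays of shape (K, D) holding the
    encoder outputs for the K pseudo-inputs (placeholders here).
    """
    # per-component diagonal-Gaussian log-density, summed over dimensions
    log_comp = -0.5 * np.sum(
        np.log(2 * np.pi) + 2 * pseudo_log_sigmas
        + (z - pseudo_mus) ** 2 / np.exp(2 * pseudo_log_sigmas),
        axis=1,
    )
    # numerically stable log-mean-exp over the K components
    m = log_comp.max()
    return m + np.log(np.mean(np.exp(log_comp - m)))
```

With K = 1 and a standard-normal component this reduces to the ordinary Gaussian prior; with K > 1 the prior becomes multi-modal, which is the added expressiveness nVMP relies on.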
(4) nVMP+ (1) extends nVMP by imposing one additional KL term between the encoder-driven VamPriors and the decoder priors across the entire hierarchy. In this case, the decoder “priors” are regarded as “intermediate posteriors” and encouraged to imitate the encoder-driven multi-modal distribution throughout the hierarchy. We postulate that this configuration adds an extra layer of information sharing between the encoder and decoder networks, which can potentially lead to further improvements in representation quality.
3 Experiments and Results
We validate our approach on two brain MRI datasets: the publicly available Alzheimer’s Disease Neuroimaging Initiative (ADNI) dataset (mueller2005adni) (), and a proprietary MS dataset from an MS clinical trial (). The central 16 2-D slices of T1-weighted sequences were used for the AD experiments, while the central 24 2-D slices of Fluid Attenuated Inversion Recovery (FLAIR) sequences were used for the MS experiments. Expert T2 lesion segmentation labels for the MS experiments were provided. Both datasets were divided into non-overlapping training (60 %), validation (20 %) and testing (20 %) sets. Additional acquisition and pre-processing details are described in Appendix A.
We first train the model under an unsupervised setting and evaluate the effect of incorporating additional prior structures on synthesis quality. As shown in Table LABEL:tab:MainResults, VAEs with more expressive structured priors indeed outperform their mean-field counterpart at the same model capacity in terms of image reconstruction fidelity.
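Reconstruction fidelity can be quantified with standard image metrics; as an illustration only (the exact metric used in Table LABEL:tab:MainResults is defined in the paper), PSNR on standardized slices could be computed as:

```python
import numpy as np

def psnr(x, x_hat):
    """Peak signal-to-noise ratio between a slice and its reconstruction.

    Uses the dynamic range of the reference image as the peak, which is
    suitable for slices standardized to zero mean and unit variance.
    Illustrative metric only.
    """
    mse = np.mean((x - x_hat) ** 2)
    data_range = x.max() - x.min()
    return 10.0 * np.log10(data_range**2 / mse)
```

Higher PSNR indicates a closer reconstruction; blurring out small hyperintense lesions directly lowers this score because the lesion pixels contribute large squared errors.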
We additionally examine model behaviours in a supervised learning setting depicted in the bottom-left (purple) block in fig:GraphicalModel and fig:ArchDiagram. In this setting, we supplement the MS model with a lesion segmentation objective between a chosen “pathological” latent subset and expert pathology (lesion) segmentation labels . The rest of the latent space (which remains unsupervised) is regarded as the anatomical latent subsets, denoted as .
Firstly, as one might expect, supervision is shown to enhance latent disentanglement. Disease-related features in the synthesized images are noticeably more sensitive (VAEPCA) to perturbations in compared to , as shown in fig:AttributeSensitivity, Appendix B. Such a disparity in attribute sensitivity is appreciable in unsupervised models, but is made much more pronounced by the selective latent supervision.
Secondly and more importantly, supervision helps to verify that the model is indeed actively using the latent structures. In models with autoregressive structures (nVAE, nVMP, nVMP+), knowledge from the supervision is propagated to the unsupervised “anatomical” latent units , i.e., those unsupervised latent units attain a higher linear predictability (Lasso regression scores (eastwood2018frameworkDisentanglement), Table LABEL:tab:Informativeness) with respect to lesion volume. This is in contrast to the behaviour of the baseline mean-field VAE, where information from the supervision task is confined within the supervised group . This observation shows that the model indeed takes advantage of the extra structure introduced by the autoregressive priors and the residual parameterisation, and is hence capable of modelling the dependencies between anatomical and pathological generating factors.
We can qualitatively evaluate pathology-anatomy disentanglement by swapping anatomical and pathological latent features between a pair of subjects in a manner similar to “style-mixing” (StyleGAN). As shown in fig:StyleMixing for representative examples, brain atrophy in AD patients (left), and T2 lesions in MS patients (right), are disentangled from the subject’s anatomical particularities (such as sulcal pattern), thus enabling the mixing of one patient’s pathology with the other’s anatomy.
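Operationally, this swap amounts to recombining the two subjects' latent groups before decoding. A minimal sketch, with hypothetical group names (the paper identifies the pathological subset via supervision or attribute-sensitivity analysis):

```python
def mix_latents(latents_a, latents_b, pathology_keys):
    """'Style-mix' two subjects' latent representations.

    Keeps subject A's anatomical latent groups and substitutes the
    designated pathological subset z_p with subject B's; the decoder
    (not shown) would then synthesise A's anatomy carrying B's pathology.
    """
    mixed = dict(latents_a)  # start from A's full representation
    for key in pathology_keys:
        mixed[key] = latents_b[key]  # overwrite only the pathological subset
    return mixed
```

For example, `mix_latents(a, b, ["z_p"])` decoded through the generative model would show subject A's sulcal pattern and ventricular shape with subject B's lesion load.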
We may also leverage conditional distributions learned by the model to examine subject-specific pathology distributions. For example, based on learned representations of Subject B1 in fig:StyleMixing, we may visualise many possible disease states given this subject’s anatomy (fig:ConditionalResampling, top row) or explore how this subject’s lesions would manifest on other subjects’ brain anatomies (fig:ConditionalResampling, bottom row).
We propose hierarchical VAEs with structured priors for learning pathology-anatomy disentangled representations of brain MRIs. Our model can faithfully capture imaging features, including fine-grained details, while accounting for pathology-anatomy dependencies to ensure sample validity. We additionally examine model behaviours in a supervised learning setting. Supervision is shown to (1) further enhance latent disentanglement; and (2) enable the inspection of information propagation between latent groups for modelling pathology-anatomy interdependencies. Our model allows for robust and controllable brain MRI synthesis rich in high-frequency and pathologically-sound details, which could be meaningful for various downstream tasks.
The authors are grateful to the International Progressive MS Alliance for supporting this work (grant number: PA-1412-02420), and to the companies who generously provided the clinical trial data that made it possible: Biogen, BioMS, MedDay, Novartis, Roche / Genentech, and Teva. Funding was also provided by the Natural Sciences and Engineering Research Council of Canada, the Canadian Institute for Advanced Research (CIFAR) Artificial Intelligence Chairs program, and a technology transfer grant from Mila - Quebec AI Institute. S.A. Tsaftaris acknowledges the support of Canon Medical and the Royal Academy of Engineering and the Research Chairs and Senior Research Fellowships scheme (grant RCSRF1819 / 8 / 25). Supplementary computational resources and technical support were provided by Calcul Québec and the Digital Research Alliance of Canada. Falet, J.-P., was supported by an end MS Personnel Award from the Multiple Sclerosis Society of Canada, by a Canada Graduate Scholarship-Masters Award from the Canadian Institutes of Health Research, and by the Fonds de recherche du Québec - Santé / Ministère de la Santé et des Services sociaux training program for specialty medicine residents with an interest in pursuing a research career, Phase 1. This work was made possible by the end-to-end deep learning experimental pipeline developed in collaboration with our colleagues Justin Szeto, Eric Zimmerman, and Kirill Vasilevski. Additionally, the authors would like to thank Louis Collins and Mahsa Dadar for preprocessing the MRI data.
Appendix A Data Acquisition, Implementation and Training Details
All MRI sequences were acquired at a resolution of 1 mm × 1 mm × 1 mm. Each 2-D slice was downsampled to a resolution of 2 mm × 2 mm. These were standardized to have zero mean and unit variance.
We compare the four parameterisations in Table tab:ModelParametrizations with a 5-layer model () of exactly the same capacity. For each dataset, the latent space capacity is set to . We use the Adam optimizer (kingma2014adam) with a learning rate of 5e-5 and a weight decay of 1e-8.
Two loss re-weighting mechanisms are used in our training procedure: (1) We use a cyclical linear annealing schedule (fu_cyclical_2019_naacl) for KLD losses with a cycle length of 10000 iterations. The initial KLD learning rate is set to 2e-7. (2) To avoid posterior collapse, we use the KL balancing trick suggested by NVAE. We re-scale each KL term of the hierarchy with a coefficient proportional to the size of each latent layer as well as the KLD value of that layer. This mechanism encourages more balanced information attribution to each latent layer (vahdat2018dvae++LayerScalingKL; chen2016variationalLayerScalingKL).
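The two mechanisms can be sketched as follows. The ramp fraction of the annealing cycle and the normalisation of the balancing coefficients are assumptions (one plausible reading of the cited heuristics), not the paper's exact settings.

```python
import numpy as np

def cyclical_beta(iteration, cycle_length=10_000, ramp_fraction=0.5):
    """Cyclical linear annealing weight in [0, 1] for the total KLD term:
    beta ramps linearly over the first `ramp_fraction` of each cycle,
    then stays at 1 for the remainder (assumed schedule shape)."""
    phase = (iteration % cycle_length) / cycle_length
    return min(phase / ramp_fraction, 1.0)

def kl_balancing_coeffs(kl_values, group_sizes):
    """Per-layer rescaling coefficients proportional to each layer's size
    and its current KLD value, normalised here so the coefficients
    average to 1 (assumed normalisation)."""
    raw = np.asarray(kl_values, dtype=float) * np.asarray(group_sizes, dtype=float)
    return raw * len(raw) / raw.sum()
```

Layers that currently carry little KL receive small coefficients, so the optimizer is not pushed to collapse them further, which is the intended balancing effect.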
Appendix B Additional Results
As discussed in Section 3, we evaluate layer-wise latent pathology informativeness of MS models by examining each layer’s linear predictability of a salient pathological attribute, T2 lesion volume. To quantify linear predictability, we train Lasso regressors () with latent representations obtained from each individual latent layer of each model and compute each Lasso regressor’s scores with respect to T2 lesion volume based on expert segmentation labels.
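A self-contained sketch of this probe, with a minimal coordinate-descent Lasso standing in for an off-the-shelf implementation (e.g. sklearn.linear_model.Lasso); the regularisation strength and iteration count here are illustrative, not the paper's settings:

```python
import numpy as np

def lasso_fit(X, y, alpha=0.1, n_iter=200):
    """Minimal coordinate-descent Lasso: fit w minimising
    (1/2)||y - Xw||^2 + alpha * n * ||w||_1 on standardized features."""
    X = (X - X.mean(0)) / X.std(0)
    yc = y - y.mean()
    n, d = X.shape
    w = np.zeros(d)
    col_sq = (X ** 2).sum(0)
    for _ in range(n_iter):
        for j in range(d):
            r_j = yc - X @ w + X[:, j] * w[j]   # residual excluding feature j
            rho = X[:, j] @ r_j
            w[j] = np.sign(rho) * max(abs(rho) - alpha * n, 0.0) / col_sq[j]
    return w, y.mean()

def r2_score(X, y, w, intercept):
    """Coefficient of determination of the fitted probe on (X, y)."""
    Xs = (X - X.mean(0)) / X.std(0)
    pred = Xs @ w + intercept
    return 1.0 - ((y - pred) ** 2).sum() / ((y - y.mean()) ** 2).sum()
```

In the experiments, `X` would be a latent layer's representations (one row per subject) and `y` the expert-derived T2 lesion volumes; a high score for an "anatomical" layer indicates that lesion information has propagated into it.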
In tab:Informativeness, rows 1-4 show the scores of the unsupervised models, which are generally poor; rows 5-8 show the same metrics for supervised models where supervision is provided to () as an additional lesion segmentation objective. Models with autoregressive structures (nVAE, nVMP, nVMP+) benefit more from the supervision: knowledge from the supervision is propagated to the unsupervised “anatomical” latent units, resulting in higher scores even in the unsupervised latent subsets. This shows that the model is indeed actively using the latent structures.
Furthermore, by separately scaling each latent subgroup and observing the changes in the generated images, we can examine the features captured by each individual latent group. Latent disentanglement, as indicated by the marked disparity in layer-wise pathological attribute sensitivity to scaling, is made evident by such visualisation.
In this particular example (fig:AttributeSensitivity), the appearance of the hyper-intense MS lesions in the synthesised images is relatively insensitive to multiplicative perturbation in all but one latent layer, . The layer with the highest pathological attribute sensitivity, , is hence considered to be a disentangled “pathological” latent subset .
We note that even in the unsupervised setting, disease-related features in the synthesized images are noticeably more sensitive to changes in a small subset of latent variables than the rest, which allows us to identify such a subset as and the rest as (anatomical latent subsets) in a post-hoc manner. Such disparity in pathological attribute sensitivity is much more pronounced in the “selective supervision” setting (bottom-left purple block in fig:GraphicalModel and fig:ArchDiagram), where the additional supervision is given to a chosen layer .