Precision medicine involves choosing the treatment that best balances efficacy against side effects and personal preference for the individual patient. In many clinical contexts, delays in finding an effective treatment can lead to significant morbidity and irreversible disability accrual. Such is the case for multiple sclerosis (MS), a chronic neurological disease of the central nervous system. Although numerous treatments are available, each has a different efficacy and risk profile, complicating the task of choosing the optimal treatment for a particular patient. One hallmark of MS is the appearance of lesions visible on T2-weighted MRI sequences of the brain and spinal cord [RRMSactivity]. The appearance of new or enlarging T2 (NE-T2) lesions on sequential MRI indicates new disease activity, and suppression of NE-T2 lesions is a surrogate outcome used to measure treatment efficacy. Predicting the future effect of each treatment on NE-T2 lesion counts from brain MRI acquired before treatment initiation could therefore provide an early, non-invasive mechanism for improving patient outcomes.
Predicting future treatment effects first requires accurate prognostic models of future disease evolution. Deep learning has been used to predict prognostic outcomes in a variety of medical imaging domains [SmokingPrognosis, StrokePrognosis, AlzeheimersPrognosis, BRATSSurvivalprog]. In the context of MS, research has mainly focused on the related tasks of lesion segmentation [VALVERDE2017159MSseg, MSSEGFULLYCONV, TanyaUncMSseg, nichyporuk2021cohort] and NE-T2 lesion detection [DoyleLesionDetection, nazactivitypred]. Recently, deep learning models have been developed for the binary prediction of future disability progression [pmlr-v102-tousignant19a] and for the binary prediction of future lesion activity [NazFirstActivityPaper], defined as the presence of more than one NE-T2 or gadolinium-enhancing lesion. The prediction of more granular outcomes, such as future NE-T2 lesion counts, remains an open research topic. Furthermore, models are typically built as prognostic models for untreated patients. Predicting prognosis on treatment poses the additional challenge of learning, from a patient's MRI, the effect each treatment will have on that particular patient, which may depend on subtle MRI markers of future treatment response. Machine learning models devised to predict treatment response when it is directly measurable on the image (e.g. a shrinking tumour) [XuLung58, TreatmentEffectBreastTumor] are insufficient for MS and for other diseases where treatment response must be evaluated relative to placebo or other treatments. Previous work [Doyle2017PredictingFD] examined the ability of classical machine learning models to perform binary activity prediction for patients on MS treatments and to identify potential treatment responders.
Several machine learning methods have been developed to estimate treatment effects for single treatment-control comparisons [louizos2017causal, shalit2017TARNET, shi2019adaptingDRAGONNET], with extensions to multiple treatments [ZhaoCTS, zhao2020mutli-treatment]. [zhao2020mutli-treatment] also integrate the notion of value and cost (or risk) associated with a treatment, which is crucial for making sound recommendations, particularly when higher-efficacy medications may be associated with more severe side effects. However, applications to precision medicine have largely focused on clinical data as input [DeepSurv, Fotso2018DeepNN, Ching2018CoxnnetAA, Jaroszewicz2014UpliftMW]. Existing MS models [MPScoringCoxPH, RioScore, OneYearMSINF-B] are also limited to clinical features (e.g. demographics) and established group-level MRI-derived features (e.g. contrast-enhancing lesion counts, brain volume). Deep learning models would permit learning individual, data-driven features of treatment effect directly from MRI sequences and could improve on existing strategies.
This paper introduces the first image-based treatment recommendation framework for MS that combines prognosis prediction, treatment effect estimation, and treatment-associated risk evaluation (fig:block_diagram). Our model takes multi-sequence MRI at baseline, along with available clinical information, as input to a multi-head deep neural network that learns shared latent features in a common ResNet encoder [he2015deepResNet]. It then learns treatment-specific latent features in each output head to predict future potential outcomes on multiple treatments. Predictions, effect estimates, and treatment risk are then supplied to a Clinical Decision Support Tool that outputs a treatment recommendation. The framework is evaluated on a proprietary multi-trial, multi-scanner dataset of MS patients exposed to five different treatment options. The multi-head model not only accurately predicts, from baseline, the future NE-T2 lesion counts that will develop 1-2 years ahead on all treatments, but also reliably identifies subgroups with heterogeneous treatment effects (groups for which the treatment is more or less effective), as measured by causal inference metrics. Finally, we show that improved lesion suppression can be achieved using the support tool, especially when treatment risk is considered.
2.1 Estimating Treatment Effect
Let $x$ be the input features (multi-sequence MRI and available clinical data), $y$ be the outcome of interest, and $t \in \{0, 1, \dots, m\}$ be the treatment allocation, where $t = 0$ is a control (e.g. placebo) and the remaining $m$ values are treatment options. Given an observational dataset $\mathcal{D} = \{(x_i, t_i, y_i)\}_{i=1}^{n}$, the individual treatment effect (ITE) for patient $i$ can be defined using the Neyman/Rubin Potential Outcome Framework [Rubin1974EstimatingCE] as $\tau_i^{t} = y_i(t) - y_i(0)$, where $y_i(t)$ and $y_i(0)$ represent potential outcomes on treatment $t$ and control, respectively. The ITE is a fundamentally unobservable causal quantity because only one of these potential outcomes is realized. Treatment effect estimation in machine learning therefore relies on a related causal estimand, the conditional average treatment effect (CATE):

$$\tau_t(x) = \mathbb{E}\left[\, y(t) - y(0) \mid x \,\right].$$

The causal expectations can be recovered from the observational data as follows:

$$\tau_t(x) = \mathbb{E}\left[\, y \mid x, t \,\right] - \mathbb{E}\left[\, y \mid x, t = 0 \,\right],$$

which can be estimated in an unbiased fashion using randomized control trial data (as in our case), where $\{y(0), y(t)\} \perp\!\!\!\perp t \mid x$ [Gutierrez2017]. Further assumptions are needed in the context of non-randomized data [PropensityScoreUsage].
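As a hedged illustration of why randomization makes the difference of conditional means an unbiased estimate, the synthetic simulation below (not the trial data) assigns treatment independently of a covariate and recovers a known, constant effect with a simple difference in group means. The function name `difference_in_means`, the effect size, and the data-generating model are all assumptions made for this sketch; with a constant effect the conditional and average effects coincide.

```python
import numpy as np


def difference_in_means(y: np.ndarray, t: np.ndarray, treatment: int, control: int = 0) -> float:
    """Estimate E[y | t=treatment] - E[y | t=control] from randomized data."""
    return float(y[t == treatment].mean() - y[t == control].mean())


rng = np.random.default_rng(0)
n = 20_000
x = rng.normal(size=n)                      # one synthetic covariate
t = rng.integers(0, 2, size=n)              # randomized assignment, independent of x
true_effect = -2.0                          # treatment lowers the outcome on average
y = 5.0 + 1.5 * x + true_effect * t + rng.normal(size=n)

est = difference_in_means(y, t, treatment=1)
print(round(est, 2))
```

If `t` were instead drawn with a probability that depends on `x`, the same estimator would be biased, which is why non-randomized data require further assumptions.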
2.2 Network Architecture
Our network is based on TARNET [shalit2017TARNET] and its multi-treatment extension [zhao2020mutli-treatment]. Specifically, we employ a single multi-head neural network composed of $m$ different CATE estimators,

$$\hat{\tau}_t(x) = \hat{f}_t(x) - \hat{f}_0(x), \qquad t \in \{1, \dots, m\},$$

where each $\hat{f}_t$ is parametrized by a neural network trained on the corresponding treatment distribution $p(x \mid t)$, and all $\hat{f}_t$ share parameters in the earlier layers. A ResNet encoder is used as the shared trunk, and after a global max pooling layer, the encoded features are concatenated with any available clinical information before being processed by treatment-specific multilayer perceptrons (MLPs). The model architecture is depicted in fig:NetworkDiagram.
During training, mini-batches are randomly sampled from $\mathcal{D}$ and fed through the network, which outputs a prediction for each treatment head. At each head $t$, losses are computed only for the set of prediction-target pairs for which ground truth is available for that treatment, $\{(\hat{f}_t(x_i), y_i) : t_i = t\}$. Shared parameters are learned in the common layers, which receive gradients from every sample irrespective of treatment allocation, while treatment-specific parameters are learned in each treatment head from the samples allocated to the corresponding treatment. At inference, predictions from all output heads are used for every patient. Full implementation details are given in sec:ImpDets.
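The per-head masking described above can be sketched in NumPy as follows. This is a minimal illustration of the loss values only (a deep learning framework would backpropagate through the same masked sum); the function name `masked_multihead_mse` and the use of plain MSE are assumptions of this sketch, not the authors' code.

```python
import numpy as np


def masked_multihead_mse(preds: np.ndarray, y: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Per-head MSE using only samples factually allocated to each head.

    preds: (batch, n_heads) predictions, one column per treatment head.
    y:     (batch,) observed outcomes.
    t:     (batch,) integer treatment allocation in [0, n_heads).
    Returns per-head losses; a head with no allocated samples in the batch gets 0.
    """
    n_heads = preds.shape[1]
    losses = np.zeros(n_heads)
    for head in range(n_heads):
        mask = t == head                      # samples with ground truth for this head
        if mask.any():
            losses[head] = np.mean((preds[mask, head] - y[mask]) ** 2)
    return losses


# Every sample produces a prediction at every head, but only its factual head is scored.
preds = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
y = np.array([1.0, 1.0, 5.0])
t = np.array([0, 1, 0])
losses = masked_multihead_mse(preds, y, t)
total_loss = losses.sum()                     # one scalar for the shared trunk
```

Because each sample contributes to exactly one head's loss, the shared trunk sees every sample while each head only learns from its own arm.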
We examine both regression and classification tasks. Regressing future NE-T2 lesion counts offers the most intuitive interpretation of treatment effect (i.e. differences in lesion count), but is sensitive to outliers in the count distribution (e.g. patients with 50 lesions). On the other hand, MS guidelines [MSguidelines] report a cutoff $c$ on the number of new/enlarging T2 lesions, at or beyond which a treatment should be changed to a more effective one. We therefore also consider the binary classification task of predicting minimal evidence of disease activity on future T2 sequences, referred to as MEDA-T2, defined as having fewer than $c$ NE-T2 lesions. Unfortunately, treatment effects on the binary scale do not capture the full range of effects, and using the softmax outputs to compute $\hat{\tau}_t(x)$ gives a less informative interpretation than regressed counts do. For the regression loss, we use Mean Squared Error (MSE) on the log-transformed count, $\log(y + 1)$, to reduce the weight of outliers. For the classification loss, we use binary cross entropy (BCE) on the binary MEDA-T2 outcome, $\mathbb{1}[y < c]$, where $\mathbb{1}[\cdot]$ is the indicator function.
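The two targets and losses can be sketched as below. The exact log transform ($\log(y+1)$ via `np.log1p`) and the cutoff value used in the usage line are assumptions for illustration; the guideline cutoff is passed in as a parameter rather than hard-coded. The example shows how the log transform compresses an outlier: a count of 50 becomes a target of roughly 3.93.

```python
import numpy as np


def log_count_target(counts):
    """Regression target log(y + 1); log1p keeps zero counts finite."""
    return np.log1p(np.asarray(counts, dtype=float))


def meda_t2_label(counts, cutoff):
    """Binary MEDA-T2 label: 1 if fewer than `cutoff` NE-T2 lesions, else 0."""
    return (np.asarray(counts) < cutoff).astype(float)


def mse(pred, target):
    return float(np.mean((np.asarray(pred) - np.asarray(target)) ** 2))


def bce(prob, label, eps=1e-7):
    """Binary cross entropy on probabilities, clipped to avoid log(0)."""
    p = np.clip(np.asarray(prob, dtype=float), eps, 1 - eps)
    label = np.asarray(label, dtype=float)
    return float(-np.mean(label * np.log(p) + (1 - label) * np.log(1 - p)))


counts = [0, 2, 50]
print(log_count_target(counts))               # outlier 50 is compressed to ~3.93
print(meda_t2_label(counts, cutoff=3))        # cutoff=3 is illustrative only
```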
2.3 Clinical Decision Support Tool
Based on [zhao2020mutli-treatment], we define $r_t$ to be the risk associated with treatment $t$. This can be set by a clinician and patient based on their experience and preferences, or could be extrapolated from long-term drug safety data. In the case of MS, drugs can be grouped into lower efficacy (LE), moderate efficacy (ME), and high efficacy (HE). An escalation strategy (starting with LE and escalating if necessary) is often used to avoid unnecessarily exposing patients to the side effects attributed to higher-efficacy drugs [msescalationguidelines]. We therefore set $r_t = \lambda \ell_t$, where $\lambda$ is the constant incremental risk associated with moving up the ladder of efficacy (set by the user), and $\ell_t$ takes on a value of 0 for placebo, 1 for LE, 2 for ME, and 3 for HE. We define the risk-adjusted CATE, $\tau_t^{r}(x)$, as

$$\tau_t^{r}(x) = \hat{\tau}_t(x) + r_t.$$

Assuming negative CATE indicates benefit (here, a reduction in NE-T2 lesions), the tool then recommends the treatment $t^{*}$ such that

$$t^{*} = \underset{t \in \{0, \dots, m\}}{\arg\min}\; \tau_t^{r}(x),$$

where $\tau_0^{r}(x) = 0$ for the control.
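A minimal sketch of the recommendation rule, assuming the network's per-head predictions are available as an array with the control in column 0. The function name `recommend` and the ladder positions passed in `levels` (here 0, 1, 2, 3 for placebo, LE, ME, HE) are assumptions of this illustration; how a no-efficacy arm would be placed on the ladder is left to the user.

```python
import numpy as np


def recommend(potential_outcomes: np.ndarray, levels: np.ndarray, lam: float) -> np.ndarray:
    """Return the index of the recommended treatment for each patient.

    potential_outcomes: (n_patients, n_treatments) predicted outcomes, column 0 = control.
    levels:             (n_treatments,) efficacy-ladder position of each treatment.
    lam:                constant incremental risk per ladder step.
    """
    cate = potential_outcomes - potential_outcomes[:, :1]  # effect vs control; column 0 is 0
    adjusted = cate + lam * levels                         # risk-adjusted CATE, broadcast
    return np.argmin(adjusted, axis=1)                     # most negative adjusted effect wins


preds = np.array([[4.0, 2.0, 1.5, 1.2]])   # hypothetical predicted NE-T2 outcomes for one patient
levels = np.array([0, 1, 2, 3])            # assumed ladder: placebo, LE, ME, HE
for lam in (0.0, 0.4, 1.0, 5.0):
    print(lam, recommend(preds, levels, lam))
```

As the incremental risk grows, the recommendation steps down the efficacy ladder: the most effective arm wins when risk is ignored, and placebo wins when risk dominates.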
3 Experiments and Results
The dataset is composed of patients from four randomized clinical trials: BRAVO BRAVO, OPERA 1 OPERA, OPERA 2 OPERA, and DEFINE DEFINE. Each trial enrolled patients with relapsing-remitting MS (the most common form) and had similar recruitment criteria. We excluded patients who did not complete all required MRI timepoints, or were missing MRI sequences/clinical features at baseline, resulting in a dataset with . Treatments for these trials are categorized based on their efficacy at the group level: placebo (), no efficacy (NE, ), lower efficacy (LE, ), moderate efficacy (ME, ), and high efficacy (HE, ) with each level representing one treatment. Pre-trial statistics and treatment distributions can be seen in Appendix F.
All trials acquired MRIs at 1 x 1 x 3 mm resolution at the following timepoints: baseline (prior to treatment initiation), one year, and two years. Each timepoint contains 5 sequences: T1-weighted, T1-weighted with gadolinium contrast agent, T2-weighted, Fluid-Attenuated Inversion Recovery, and Proton Density-weighted. In addition, expert-annotated gadolinium-enhancing (Gad) lesion masks and T2 lesion labels are provided. The baseline MRIs and lesion masks were used as input to our model, while the NE-T2 lesion counts occurring between years one and two were used to compute the count target and the binarized MEDA-T2 outcome. Patients who did not complete all the required MRIs were excluded, as they would not have an NE-T2 count. The percentages of patients with MEDA-T2 in our dataset for placebo, NE, LE, ME, and HE are 45.7%, 54.4%, 63.8%, 77.4%, and 99.6%, respectively. In addition, baseline age, sex, and Expanded Disability Status Scale (EDSS) score [Kurtzke1444EDSS], a clinical disability score, were used as additional clinical inputs to our model. The dataset was divided according to a 4x4 nested cross-validation scheme for model evaluation [NestedCrossVal]. Following Soltys2014EnsembleMF's use of ensembling, the 4 inner-fold models are used as members of an ensemble whose prediction on the outer fold's test set is the average of its members' predictions.
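The nested cross-validation with inner-fold ensembling can be sketched as follows. The training routine `fit` is a hypothetical user-supplied callable, and the splits here are simple random partitions; whether the original folds were stratified by trial or treatment arm is not stated, so this sketch does not stratify.

```python
import numpy as np


def nested_cv_ensemble(x, y, fit, n_outer=4, n_inner=4, seed=0):
    """Out-of-fold ensemble predictions from a nested cross-validation scheme.

    fit(x_train, y_train, x_val, y_val) -> predict_fn is a hypothetical training
    routine; the inner validation split could be used for early stopping.
    """
    rng = np.random.default_rng(seed)
    outer_folds = np.array_split(rng.permutation(len(y)), n_outer)
    preds = np.full(len(y), np.nan)
    for k, test_idx in enumerate(outer_folds):
        dev_idx = np.concatenate([f for j, f in enumerate(outer_folds) if j != k])
        inner_folds = np.array_split(dev_idx, n_inner)
        members = []
        for i, val_idx in enumerate(inner_folds):
            train_idx = np.concatenate([f for j, f in enumerate(inner_folds) if j != i])
            members.append(fit(x[train_idx], y[train_idx], x[val_idx], y[val_idx]))
        # Ensemble prediction on the outer test fold = mean of the inner-fold members.
        preds[test_idx] = np.mean([m(x[test_idx]) for m in members], axis=0)
    return preds


def fit_mean(x_tr, y_tr, x_val, y_val):
    """Toy stand-in model that always predicts the training mean."""
    mu = y_tr.mean()
    return lambda x_new: np.full(len(x_new), mu)


x = np.arange(40, dtype=float).reshape(-1, 1)
y = np.arange(40, dtype=float)
preds = nested_cv_ensemble(x, y, fit_mean)
```

Aggregating the outer-fold predictions into one vector, as in `preds`, is what allows metrics to be computed once over all test patients.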
3.2 Predicting Future Lesion Suppression
We conduct three experiments to determine the best-performing framework for predicting the observed future MEDA-T2 outcome given different combinations of inputs, targets, and loss functions. The first compares the performance of the proposed single multi-head architecture with that of independently trained networks. The second assesses the benefit of using both imaging and clinical features. The third compares binary classification of MEDA-T2 with binarization of the output of a regression model trained directly on the NE-T2 lesion counts.
Model performance is evaluated using average precision (AP) due to class imbalances in some of the treatment arms, particularly on HE. The random baseline reflects the positive MEDA-T2 label fraction on each arm. For an improved estimate of the generalization error, metrics are computed from the aggregated outer fold test set predictions. Results are shown in tab:factualPR. The multi-head architecture improves APs across most treatment arms, and the concatenation of clinical features provides an additional boost in performance. Finally, the multi-head binarized regression model with clinical data concatenation outperformed the binary classification equivalent.
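For reference, average precision can be computed as the mean of precision@k over the ranks of the positive (MEDA-T2) samples; without tied scores this matches scikit-learn's `average_precision_score`. The helper names below are illustrative, and the sketch assumes each treatment arm contains at least one positive. The random baseline for an arm is approximately its positive-label fraction.

```python
import numpy as np


def average_precision(y_true, scores):
    """AP = mean of precision@k at the ranks of the positives (no tie handling)."""
    y_true = np.asarray(y_true, dtype=float)
    order = np.argsort(-np.asarray(scores, dtype=float), kind="stable")
    hits = y_true[order]                                  # labels sorted by descending score
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return float(np.sum(precision_at_k * hits) / hits.sum())


def random_baseline(y_true):
    """Approximate AP of an uninformative classifier: the positive-label fraction."""
    return float(np.mean(y_true))


y_true = [1, 0, 1, 0]
scores = [0.9, 0.8, 0.7, 0.1]
print(average_precision(y_true, scores), random_baseline(y_true))
```

Per-arm results follow by masking the aggregated outer-fold predictions to the patients who factually received each treatment.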
Given its strong performance, we used the regression model for the following evaluations. We evaluated the MSE on the non-binarized output of the regression model (the log lesion count), which improves on the random baseline (the mean log lesion count) for all treatments except HE (see tab:factualRegressionTarget). The failure to regress lesion counts on HE can be explained by the extremely small variance in the target distribution, with only 5% of all test patients having future NE-T2 lesion counts.
3.3 Estimating Treatment Effects
Given that the regression model outperforms alternatives on MEDA-T2 classification, and because it provides added granularity and a more intuitive interpretation, we used this model for CATE estimation. CATE estimates are computed for each treatment arm relative to placebo.
To evaluate the quality of the CATE estimates, we report uplift bins [AscarzeQuartileAD] at three thresholds of predicted effect. Predicted response ($\hat{\tau}_t(x)$) values are binned into tertiles, and within each bin the average difference between the ground-truth lesion counts of patients who factually received treatment $t$ and those who factually received placebo is computed for each treatment $t$. The result, shown in subfig:countAD, demonstrates that individuals predicted to respond most (top 33%) have a significantly greater reduction in lesion count than the entire group, and those predicted to respond least (bottom 33%) have a smaller reduction than the entire group. This suggests that the model correctly identifies heterogeneous treatment effects. Uplift bins at different resolutions are shown in Appendix D.
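A minimal sketch of the uplift-bin computation on synthetic randomized data (not the trial data). Patients on a treatment arm and on placebo are pooled, ranked by predicted effect (most negative, i.e. largest predicted benefit, first), split into equal-size bins, and the observed treated-minus-control difference is computed within each bin. The function name `uplift_bins` and the synthetic heterogeneous effect are assumptions of this illustration.

```python
import numpy as np


def uplift_bins(pred_effect, t, y, treatment, control=0, n_bins=3):
    """Observed effect (mean y on treatment minus mean y on control) per predicted-effect bin.

    Bins are ordered from strongest predicted responders (most negative predicted
    effect) to weakest. Each bin must contain patients from both arms.
    """
    pred_effect, t, y = map(np.asarray, (pred_effect, t, y))
    keep = (t == treatment) | (t == control)         # only the two arms being compared
    pe, tt, yy = pred_effect[keep], t[keep], y[keep]
    order = np.argsort(pe, kind="stable")            # most negative predicted effect first
    diffs = []
    for idx in np.array_split(order, n_bins):
        on_t = yy[idx][tt[idx] == treatment]
        on_c = yy[idx][tt[idx] == control]
        diffs.append(float(on_t.mean() - on_c.mean()))
    return diffs


rng = np.random.default_rng(1)
n = 3000
x = rng.normal(size=n)
true_tau = np.where(x > 0, -3.0, 0.0)                # synthetic heterogeneous effect
t = rng.integers(0, 2, size=n)                       # randomized: 0 = placebo, 1 = treatment
y = 5.0 + true_tau * t + rng.normal(size=n)
pred = true_tau + 0.1 * rng.normal(size=n)           # stand-in for model CATE estimates
diffs = uplift_bins(pred, t, y, treatment=1)
```

A model that ranks patients well produces a monotone pattern: the top tertile shows a larger observed reduction than the group as a whole, and the bottom tertile a smaller one.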
3.4 Clinical Decision Support Tool In Action
We now illustrate how the tool could be used in practice. Assuming each drug is associated with a different risk profile (see sec:ptp), fig:Examplepatients illustrates potential outcomes for two example patients. Patient (a) might opt for an HE option if they are not worried about the greater risk of side effects or cost, or might select an ME option if they are more risk-averse. Patient (b), in turn, might opt for a drug that is NE at the group level but is predicted to be of comparable efficacy to the other options in their particular case.
Individual potential outcome predictions cannot be evaluated due to the lack of ground truth, but we can evaluate the group outcomes for those who received the recommended treatment. To do so, we adjust the ground-truth future NE-T2 lesion count for each individual who received the recommended treatment by adding the risk associated with that treatment, $r_t$, and compare their average risk-adjusted lesion count to that of the group who received a non-recommended treatment (subfig:countCDST). Patients who were factually assigned the treatment recommended by the system had a lower expected risk-adjusted lesion count for every value of the incremental risk (varied along the $x$-axis), which indicates that the tool provides better treatment recommendations when minimizing treatment-associated risk.
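The group-level comparison can be sketched as below. This illustration assumes that every patient's observed count is adjusted by the linear risk of the treatment they actually received, and that the recommendations have already been computed for the chosen incremental risk; the function name `recommendation_gap` and the toy arrays are hypothetical.

```python
import numpy as np


def recommendation_gap(y, t, recommended, levels, lam):
    """Mean risk-adjusted outcome of patients whose factual treatment matched the
    recommendation, minus that of patients whose did not (negative = better)."""
    y, t, recommended = map(np.asarray, (y, t, recommended))
    adjusted = y + lam * np.asarray(levels)[t]   # add the risk of the treatment received
    followed = t == recommended
    return float(adjusted[followed].mean() - adjusted[~followed].mean())


y = np.array([1.0, 1.0, 4.0, 4.0])               # toy observed NE-T2 counts
t = np.array([1, 0, 0, 1])                       # factual treatment (0 = placebo)
recommended = np.array([1, 1, 1, 0])             # toy recommendations at this lam
gap = recommendation_gap(y, t, recommended, levels=[0, 1], lam=0.5)
```

Sweeping `lam` (and recomputing `recommended` at each value) traces out the curve plotted against the incremental risk.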
In this paper, we introduce the first medical imaging-based deep learning model for recommending optimal treatments in MS. The model predicts future NE-T2 counts and MEDA-T2 with high precision on 5 different treatments and finds subgroups with heterogeneous treatment effects. However, highly effective suppression of new lesion formation may have only a modest effect on long-term disability progression. Current work is focused on predicting stronger markers of disability progression, so as to improve the value of the decision support tool. Additionally, the model's recommendations have the potential to balance efficacy against treatment-associated risks and patient preference. However, our current support tool uses a linear scaling of risk between treatments. A comprehensive risk adjustment model that incorporates patient preferences, side effects, cost, and other inconveniences would provide a more holistic clinical support tool but is beyond the scope of this paper. Future improvements could also be made by estimating treatment effect uncertainty [JessonCausalFailure] and by explicitly optimizing the risk-adjusted CATE [zhao2020mutli-treatment].
This investigation was supported (in part) by an award from the International Progressive Multiple Sclerosis Alliance (award reference number PA-1412-02420), the Canadian Institute for Advanced Research (CIFAR) Artificial Intelligence Chairs program (Arbel), the Natural Sciences and Engineering Research Council of Canada (Arbel), an endMS Personnel Award from the Multiple Sclerosis Society of Canada (Falet), a Canada Graduate Scholarship-Masters Award from the Canadian Institutes of Health Research (Falet), and the Fonds de recherche Santé / Ministère de la Santé et des Services sociaux training program for specialty medicine residents with an interest in pursuing a research career, Phase 1 (Falet). Supplementary computational resources and technical support were provided by Calcul Québec, WestGrid, and Compute Canada. Additionally, the authors would like to thank Louis Collins and Mahsa Dadar for preprocessing the MRI data, Zografos Caramanos, Alfredo Morales Pinzon, Charles Guttmann and István Mórocz for collating the clinical data, Sridar Narayanan and Maria-Pia Sormani for their MS expertise, and Behrooz Mahasseni for many helpful discussions during model development.
Appendix A Implementation Details
The MRI sequences are first clipped between standard deviations and then normalized to per sequence. The MRI sequences are then resampled to 2x2x2 resolution and cropped for a final dimension of 72x76x52. The clinical data is normalized to .
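The per-sequence intensity clipping, normalization, and cropping can be sketched as follows. The clipping width, target range, and whether statistics are computed over the whole volume or a brain mask are not recoverable from the text, so they are parameters here and the values in the usage lines are purely illustrative. Resampling to 2 mm is not shown, and the center crop is an assumption about where the 72x76x52 crop is taken.

```python
import numpy as np


def clip_and_normalize(volume, n_std, out_range=(0.0, 1.0)):
    """Clip intensities to mean +/- n_std * std, then rescale linearly to out_range."""
    v = np.asarray(volume, dtype=float)
    mu, sd = v.mean(), v.std()
    v = np.clip(v, mu - n_std * sd, mu + n_std * sd)
    lo, hi = out_range
    span = v.max() - v.min()
    if span == 0:
        return np.full_like(v, lo)
    return lo + (v - v.min()) * (hi - lo) / span


def center_crop(volume, shape=(72, 76, 52)):
    """Crop the central region of a 3D array to `shape` (each dim must be <= input)."""
    starts = [(s - c) // 2 for s, c in zip(volume.shape, shape)]
    slices = tuple(slice(st, st + c) for st, c in zip(starts, shape))
    return volume[slices]


rng = np.random.default_rng(0)
vol = rng.normal(100.0, 20.0, size=(90, 100, 80))   # stand-in for one resampled sequence
out = center_crop(clip_and_normalize(vol, n_std=3.0))
```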
As described in the Network Architecture section, the model consists of a shared trunk of three ResNet blocks followed by treatment-specific MLP heads. Each ResNet block contains two convolutional blocks followed by a residual addition. Each convolutional block contains a convolution (kernel size 3, stride 1), Instance Normalization [ulyanov2017instance], a dropout layer [JMLR:v15:Dropout], and a LeakyReLU activation [Maas2013RectifierNI]. Each ResNet block, with the exception of the last, is followed by a max pooling operation with kernel size 2. In the three ResNet blocks, the number of kernels for each convolution is [32, 64, 128], respectively. After the three ResNet blocks, the latents are flattened using a global average pool, the resulting features are concatenated with the clinical information, and the combined latent vector is input to the MLPs. Each of the 5 MLPs in the network consists of three hidden layers with dimensions [128, 32, 16] and ReLU activations [Relu], with no dropout. For training, we used the AdamW optimizer [loshchilov2019AdamW] with a learning rate of 0.0001 and a batch size of 8.
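As a worked check of the sizes implied above, the sketch below traces feature-map shapes through the trunk for the 72x76x52 input. It assumes "same" padding for the 3x3x3 stride-1 convolutions and floor division for the kernel-2 max pools (neither is stated explicitly), and takes three clinical features (age, gender, baseline EDSS) for the imaging-plus-clinical models. The helper `trunk_shapes` is hypothetical.

```python
def trunk_shapes(input_shape=(72, 76, 52), channels=(32, 64, 128), n_clinical=3):
    """Trace (channels, D, H, W) after each ResNet block and the head input width.

    Assumes 'same' padding for the convolutions and floor division for pooling.
    """
    shape = tuple(input_shape)
    trace = []
    for i, c in enumerate(channels):
        trace.append((c,) + shape)                 # output of ResNet block i
        if i < len(channels) - 1:                  # no pooling after the last block
            shape = tuple(s // 2 for s in shape)
    head_input = channels[-1] + n_clinical         # global pool -> vector, then concat clinical
    return trace, head_input


trace, head_input = trunk_shapes()
print(trace)       # block outputs: 72x76x52, then 36x38x26, then 18x19x13
print(head_input)  # width of the vector fed to each MLP head
```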
For models using imaging data and clinical data, the clinical data included age, gender and baseline EDSS. For the models using clinical data only, the clinical data included age, gender, baseline EDSS, baseline T2 lesion volume, and baseline Gad lesion count.
Appendix B Lesion Counts
Appendix C Treatment Effect Analysis with the binary MEDA-T2 outcome
Appendix D Additional Uplift Bins
Appendix E Additional Results
Appendix F Pretrial Patient Statistics.
Appendix G Significance Values
Appendix H MRI Preprocessing
Scans were first denoised [Denoising], corrected for intensity heterogeneity [intensitynonuniformcorrection], and normalized into the range 0-100. Second, for each patient, the T2w, PD, and FLAIR scans were co-registered to the structural T1w scan using a 6-parameter rigid registration and a mutual information objective function [collinsregistration]. The T1w scans were then registered to an average template defining stereotaxic (stx) space [collinsANIMAL, stxspace]. All volumes were then resampled onto a 1 mm isotropic grid using the T1-to-stx transformation (for the T1w data) or the transformation resulting from concatenating the contrast-to-T1 and T1-to-stx transformations (for the other contrasts).
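Concatenating the two transformations lets each non-T1 contrast be resampled once, rather than interpolated twice. With homogeneous 4x4 matrices, the concatenation is a single matrix product in which the contrast-to-T1 transform is applied first. The helper names and the toy rotation and translation below are assumptions for illustration; the actual registration parameters come from the registration tools cited above.

```python
import numpy as np


def translation(dx, dy, dz):
    """4x4 homogeneous translation matrix."""
    m = np.eye(4)
    m[:3, 3] = [dx, dy, dz]
    return m


def rotation_z(theta):
    """4x4 homogeneous rotation about the z-axis (one rigid component)."""
    c, s = np.cos(theta), np.sin(theta)
    m = np.eye(4)
    m[:2, :2] = [[c, -s], [s, c]]
    return m


def compose(contrast_to_t1, t1_to_stx):
    """Single transform from contrast space to stx space (contrast->T1 applied first)."""
    return t1_to_stx @ contrast_to_t1


contrast_to_t1 = rotation_z(np.pi / 2)        # toy rigid transform
t1_to_stx = translation(10.0, 0.0, 0.0)       # toy T1-to-stx transform
point = np.array([1.0, 0.0, 0.0, 1.0])        # homogeneous coordinates
mapped = compose(contrast_to_t1, t1_to_stx) @ point
```

Note the order: matrix multiplication is not commutative, so swapping the operands would translate before rotating and give a different result.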