MATCH-Net: Dynamic Prediction in Survival Analysis using Convolutional Neural Networks

11/26/2018 ∙ by Daniel Jarrett, et al. ∙ University of Oxford

Accurate prediction of disease trajectories is critical for early identification and timely treatment of patients at risk. Conventional methods in survival analysis are often constrained by strong parametric assumptions and limited in their ability to learn from high-dimensional data, while existing neural network models are not readily adapted to the longitudinal setting. This paper develops a novel convolutional approach that addresses these drawbacks. We present MATCH-Net: a Missingness-Aware Temporal Convolutional Hitting-time Network, designed to capture temporal dependencies and heterogeneous interactions in covariate trajectories and patterns of missingness. To the best of our knowledge, this is the first investigation of temporal convolutions in the context of dynamic prediction for personalized risk prognosis. Using real-world data from the Alzheimer's Disease Neuroimaging Initiative, we demonstrate state-of-the-art performance without making any assumptions regarding underlying longitudinal or time-to-event processes, attesting to the model's potential utility in clinical decision support.


1 Introduction

In Alzheimer’s disease, the annual cost of which exceeds $800 billion globally Marinescu et al. (2018), the effectiveness of therapeutic treatments is often limited by the challenge of identifying patients at early enough stages of progression for treatments to be of use. As a result, accurate and personalized prognosis in earlier stages of cognitive decline is critical for effective intervention and subject selection in clinical trials. Conventional statistical methods in survival analysis often begin by choosing explicit functions to model the underlying stochastic process Doksum and Høyland (1992); Kleinbaum and Klein (2005); Rodríguez (2005); Lee and Whitmore (2010); Fernández et al. (2016); Alaa and van der Schaar (2017). However, the constraints, such as linearity and proportionality Singh and Mukhopadhyay (2011); Cox (1992) in the popular Cox model, may not be valid or verifiable in practice.

In spite of active research, a conclusive understanding of Alzheimer’s disease progression remains elusive, owing to heterogeneous biological pathways Beckett et al. (2015); Hardy and Selkoe (2002), complex temporal patterns Jedynak et al. (2012); Donohue et al. (2014), and diverse interactions Frölich et al. (2017); Pascoal et al. (2017). Hence Alzheimer’s data is a prime venue for leveraging the potential advantages of deep learning models for survival. Neural networks offer versatile alternatives by virtue of their capacity as general-purpose function approximators, able to learn, without restrictive assumptions, the latent structure between an individual’s prognostic factors and odds of survival.

Contributions. Our goal is to establish a novel convolutional model for survival prediction, using Alzheimer’s disease as a case study for experimental validation. Primary contributions are threefold: First, we formulate a generalized framework for longitudinal survival prediction, laying the foundation for effective cross-model comparison. Second, our proposal is uniquely designed to capitalize on longitudinal data to issue dynamically updated survival predictions, accommodating potentially informative patterns of data missingness, and combining human input with model predictions. Third, we propose methods for deriving clinically meaningful insight into the model’s inference process. Finally, we demonstrate state-of-the-art results in comparison with comprehensive benchmarks.

2 Problem formulation

Let there be patients in a study, indexed . Time is treated as a discrete dimension of fixed resolution . Each longitudinal datum consists of the tuple , where is the vector of covariates recorded at time , and is the binary survival indicator. Let random variable denote the time-to-event, the time of right-censoring, and . Per convention, we assume that censoring is not correlated with the eventual outcome Tsiatis and Davidian (2004); Zheng and Heagerty (2005); van Houwelingen and Putter (2011). Let the complete longitudinal dataset be given by . Then we can define

(1)

to be the set of observations for patient extending from time into a width- window of the past, where parameter depends on the model under consideration. Given longitudinal measurements in , our goal is to issue risk predictions corresponding to length- horizons into the future. Formally, given a backward-looking historical window , we are interested in the failure function for forward-looking prediction intervals ; that is, we want to estimate the probability

(2)

of event occurrence within each prediction interval. Observe that parameterizing the width of the historical window results in a generalized framework—for instance, a Cox model only utilizes the most recent measurement; that is, . At the other extreme, recurrent models may consume the entire history; that is, . As we shall see, the best performance is in fact obtained via a flexible intermediate approach—that is, by learning the optimal width of a sliding window of history.
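As a rough illustration of this formulation, the following Python sketch extracts a backward-looking window of a given width and builds binary labels for a set of forward-looking horizons. The array layout, the function names, the cumulative labelling convention, and the padding of short histories by repeating the earliest row are all assumptions made for illustration; none of them is taken from the paper.

```python
import numpy as np

def history_window(x, t, w):
    """Return the w most recent rows of x ending at index t (inclusive).

    x has shape (T, D): one row of D covariates per discrete time step.
    If fewer than w steps are available, the earliest row is repeated
    (an assumed stand-in for backward extrapolation).
    """
    start = t - w + 1
    idx = np.clip(np.arange(start, t + 1), 0, None)
    return x[idx]

def horizon_labels(event_step, t, k):
    """Binary labels for horizons 1..k steps ahead of t.

    event_step is the index of the event, or None if right-censored.
    Label j is 1 if the event occurs within j steps after t.
    """
    labels = np.zeros(k, dtype=int)
    if event_step is not None and event_step > t:
        first = event_step - t          # number of steps until the event
        if first <= k:
            labels[first - 1:] = 1
    return labels

x = np.arange(12, dtype=float).reshape(6, 2)   # 6 time steps, 2 covariates
window = history_window(x, t=1, w=3)
labels = horizon_labels(event_step=4, t=1, k=5)
```

Setting `w=1` mimics a model that only sees the most recent measurement, while setting `w` to the full history length mimics a model that consumes everything. For censored patients, labels beyond the censoring time are unknown rather than zero, so a real pipeline would need to mask them out of the loss.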

3 MATCH-Net

Figure 1: (a) The Match-Net architecture, with . (b) Illustration of temporal convolutions acting over feature channels. (c) The longitudinal context within which Match-Net operates, as well as the network’s prediction targets in association with the sliding window input mechanism.

We propose Match-Net: a Missingness-Aware Temporal Convolutional Hitting-time Network, innovating on current approaches in three respects. Dynamic prediction: Existing deep learning models issue prognoses on the basis of information from a single time point Faraggi and Simon (1995); Buckley and James (1979); Katzman et al. (2016); Xiang et al. (2000); Mariani et al. (1997); Liestøl et al. (1994); Biganzoli et al. (2009); Yu et al. (2011); Luck et al. (2017); Fotso (2018); Lee et al. (2018), potentially discarding valuable information in the presence of longitudinal data. We investigate temporal convolutions in capturing heterogeneous representations of temporal dependencies among observed covariates, enabling truly dynamic predictions on the basis of historical information. Informative missingness: Current survival methods rely on the common assumption that the timing and frequency of covariate measurements are uninformative van Houwelingen and Putter (2011); Rizopoulos (2012). By contrast, our model is missingness-aware; that is, we explicitly account for informative missingness by learning correlations between patterns of missingness and disease progression. Human input: Instead of issuing predictions solely on the basis of quantitative clinical measurements, we optionally incorporate clinicians’ most recent diagnoses of patient state into model estimates to examine the incremental informativeness of subjective input. Each innovation is a source of gain in performance (see Section 4; further details in the appendix).

Match-Net accepts as input a sliding-window of observed covariates , as well as a parallel binary mask of missing-value indicators taking on values of one to denote missing covariate measurements. In addition, the network optionally accepts a one-hot vector describing the most recent clinician diagnosis of Alzheimer’s disease progression. Starting from the base of the network, the convolutional block first learns representations of longitudinal covariate trajectories by extracting local features from temporal patterns. After each layer, filter activations from the auxiliary branch are concatenated with those in the main branch. Then, the fully-connected block captures more global relationships by combining local information. Finally, the output layers produce failure estimates

(3)

for pre-specified prediction intervals, where is the maximal horizon desired. This convolutional dual-stream architecture explicitly captures representations of temporal dependencies within each stream, as well as between covariate trajectories and missingness patterns in association with disease progression. This accounts for the potential informativeness of both irregular sampling (i.e. intervals between consecutive clinical visits may vary) and asynchronous sampling (i.e. not all features are measured at the same time) Rubin (1976); Bang et al. This also encourages the network to distinguish between actual measurements and imputed values, reducing sensitivity to specific imputation methods chosen. The final architecture uses convolutions of length , with more filters per layer in the main branch.
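To make the dual-stream idea concrete, here is a minimal NumPy sketch of one temporal convolution layer applied to a covariate window and to its missing-value mask, with the auxiliary (mask) activations concatenated onto the main branch. The window width, feature count, kernel length of 3, filter counts, and the ReLU activation are all hypothetical choices for illustration; the paper's actual layer sizes are not reproduced here.

```python
import numpy as np

def temporal_conv(x, kernels):
    """'Valid' 1-D convolution along the time axis, followed by a ReLU.

    x:       (W, C_in) window of W time steps and C_in channels
    kernels: (L, C_in, C_out) filters of temporal length L
    returns: (W - L + 1, C_out)
    """
    L = kernels.shape[0]
    steps = x.shape[0] - L + 1
    out = np.stack([
        np.tensordot(x[s:s + L], kernels, axes=([0, 1], [0, 1]))
        for s in range(steps)
    ])
    return np.maximum(out, 0.0)

rng = np.random.default_rng(0)
W, D = 6, 4                       # hypothetical window width and feature count
covariates = rng.normal(size=(W, D))
mask = (rng.random(size=(W, D)) < 0.3).astype(float)   # 1 = value was missing

# Auxiliary stream: convolve the missingness mask on its own (2 filters).
aux = temporal_conv(mask, rng.normal(size=(3, D, 2)))
# Main stream: convolve the covariates (5 filters), then append the
# auxiliary activations so later layers see both streams.
main = temporal_conv(covariates, rng.normal(size=(3, D, 5)))
main = np.concatenate([main, aux], axis=-1)
```

Each filter spans all feature channels but only a few adjacent time steps, which is what lets it respond to local temporal patterns such as a sudden decline. A trained model would learn the kernels rather than sample them at random.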

Loss function. With the preceding notation, the negative log-likelihood of a single empirical result and model estimate in relation to some input window is given by

(4)

where denotes the parameters of the network. The total loss function is then computed to simultaneously take into account the quality of predictions for all prediction horizons , all times available along each patient’s longitudinal trajectory, and all patients in the survival dataset:

(5)

where accounts for failure or right-censoring. This is a natural generalization of the log-likelihood in Liestøl et al. (1994) to accommodate longitudinal survival. Weight function allows trading off the relative importance of different patients, time steps, and prediction horizons. First, this allows standardizing patient contributions with , thereby counteracting the bias against patients with shorter survival durations. Second, in the context of heavily imbalanced classes, this allows up-weighting positive instances, that is, input windows that correspond to actual failure.
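The structure of this loss can be sketched in a few lines of Python. The sketch assumes that the per-task term is a Bernoulli negative log-likelihood (binary cross-entropy), that censoring is handled by an `observed` indicator that drops tasks whose label is unknown, and that positives are up-weighted by a hypothetical factor of 5. These are illustrative assumptions, not the paper's exact Equations 4 and 5.

```python
import numpy as np

def weighted_nll(y, p, observed, weights, eps=1e-12):
    """Weighted Bernoulli negative log-likelihood summed over prediction tasks.

    y:        binary event labels
    p:        predicted failure probabilities, same shape as y
    observed: 1 where the label is known, 0 where censoring hides it
    weights:  non-negative task weights, same shape as y
    """
    p = np.clip(p, eps, 1.0 - eps)
    nll = -(y * np.log(p) + (1 - y) * np.log(1 - p))
    return float(np.sum(weights * observed * nll))

y = np.array([0, 0, 1])                   # labels for three horizons
p = np.array([0.1, 0.2, 0.7])             # model estimates
observed = np.array([1, 1, 1])
uniform = np.ones(3)
upweighted = np.where(y == 1, 5.0, 1.0)   # hypothetical factor for positives

loss_uniform = weighted_nll(y, p, observed, uniform)
loss_upweighted = weighted_nll(y, p, observed, upweighted)
```

In a full training loop the same sum would also run over all patients and all time steps along each trajectory, with the weights additionally normalizing each patient's contribution.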

4 Experiments

The Alzheimer’s Disease Neuroimaging Initiative (ADNI) study data is a longitudinal survival dataset of per-visit measurements for 1,737 patients Marinescu et al. (2018). The data tracks disease progression through clinical measurements at -year intervals, including quantitative biomarkers, cognitive tests, demographics, and risk factors. Our objective is to predict the first stable occurrence of Alzheimer’s disease for each patient. Further information on the dataset, preparation, and training can be found in the appendix.

Benchmarks. We evaluate Match-Net against both traditional longitudinal methods in survival analysis and recent deep learning approaches; the former include Cox landmarking and joint modeling methods, and the latter include static and dynamic multilayer perceptrons and recurrent neural network models. Performance is evaluated on the basis of the area under the receiver operating characteristic curve (AUROC) and the area under the precision-recall curve (AUPRC), both computed with respect to prediction horizons . Five-fold cross validation metrics are reported in Table 1.
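For readers less familiar with the two metrics, the sketch below computes both from scratch for a single prediction horizon. AUROC is computed as the probability that a randomly chosen positive is scored above a randomly chosen negative, and AUPRC is approximated by average precision, which is one common step-wise estimator. The function names and the toy data are illustrative, and the paper's exact evaluation code may differ.

```python
import numpy as np

def auroc(y_true, scores):
    """Probability that a random positive outranks a random negative.

    Ties between a positive and a negative count as one half.
    """
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

def average_precision(y_true, scores):
    """Step-wise area under the precision-recall curve.

    Averages the precision at the rank of each positive example.
    Tied scores are not treated specially in this simple version.
    """
    y_true = np.asarray(y_true)
    order = np.argsort(-np.asarray(scores, dtype=float), kind="stable")
    hits = y_true[order]
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return float(np.sum(precision_at_k * hits) / np.sum(hits))

y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
roc = auroc(y_true, scores)              # 0.75
ap = average_precision(y_true, scores)   # 5/6, about 0.833
```

With heavily imbalanced labels, AUROC can remain high even when most flagged cases are false alarms, which is why the precision-based metric tends to separate models more clearly.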

Performance. Match-Net produces state-of-the-art results, consistently outperforming both conventional statistical and neural network benchmarks. Gains are especially apparent in AUPRC scores—improving on the MLP by an average of 15% and on joint models by 16% across all horizons, and by 27% and 26% for one-step-ahead predictions. To understand the sources of improvement, we observe a 4% gain in AUPRC from introducing the sliding window mechanism (MLP to S-MLP), a 9% gain from incorporating temporal convolutions (S-MLP to S-TCN), and a further 2% gain from accommodating informative missingness (S-TCN to Match-Net). In addition, including the most recent clinician diagnosis results in a further 17% gain (further details located in the appendix).

Metric  Horizon  Match-Net  S-TCN  S-MLP  FCN    D-Atlas  RNN     MLP     JM      LM
AUROC   0.5      0.962      0.961  0.959  0.954  0.959    0.949*  0.948*  0.913*  0.909*
AUROC   1.0      0.942      0.941  0.932  0.930  0.929    0.930   0.930   0.917*  0.914*
AUROC   1.5      0.902      0.902  0.897  0.895  0.892    0.891   0.890   0.881   0.878
AUROC   2.0      0.909      0.908  0.904  0.903  0.896    0.901   0.895   0.894   0.890
AUROC   2.5      0.886      0.884  0.881  0.883  0.884    0.883   0.874   0.883   0.878
AUPRC   0.5      0.594      0.580  0.500  0.536  0.517    0.464*  0.469*  0.473*  0.469*
AUPRC   1.0      0.513      0.505  0.447  0.453  0.423    0.410*  0.435   0.415*  0.412*
AUPRC   1.5      0.373      0.367  0.354  0.357  0.364    0.340   0.340   0.319   0.325
AUPRC   2.0      0.390      0.380  0.364  0.375  0.352    0.355   0.359   0.362   0.367
AUPRC   2.5      0.384      0.381  0.371  0.365  0.360    0.365   0.356   0.366   0.363
Table 1: Cross validation performance for and years: Match-Net without clinician input, sliding-window temporal convolutional networks (S-TCN) and multilayer perceptrons (S-MLP). Benchmarks include fully-convolutional networks (FCN) Long et al. (2015) adapted for sequence-based survival prediction, Disease Atlas (D-Atlas) Lim and van der Schaar (2018), baseline recurrent neural networks (RNN) including GRUs and LSTMs, static multilayer perceptrons (MLP), as well as conventional statistical methods for survival analysis, including joint modeling (JM) and Cox landmarking (LM). Bold values indicate best performance, and asterisks next to benchmark results indicate a statistically significant difference (p-value < 0.05) from the Match-Net result. A more detailed breakdown of gains is found in the appendix.
Figure 2: Average saliency map indicating feature and temporal influence, computed using slopes of partial dependence on sample of numerical features across sliding window (; years).

Visualization. From the preceding, we observe the largest gains by introducing convolutions. While this is consistent with our motivating hypothesis that convolutions are better able to capture temporal patterns, clinicians often desire a degree of transparency into the prediction process Ribeiro et al. (2016); Avati et al. (2017). We adopt the partial dependence approach in Friedman (2001) to understand the input-output relationship, as well as to examine the utility of convolutions. For each observed covariate , we want to approximate how the estimated failure function varies based on the value of . We define the dependence

(6)

where . By evaluating Equation 6 on the values present in the data, the influence of each covariate can be measured by estimating its slope. For a global picture of what impact each feature and time step has on the model’s predictions, we compute the influence for all features to produce an average saliency map Zeiler and Fergus (2014); Simonyan et al. (2013) highlighting the effect of convolutional layers. All else equal, absent convolutions we see that having worse covariate values at any time step almost invariably has an upward impact on risk (i.e. negative impact on survival). On the other hand, with temporal convolutions we see that having worse covariate values at earlier time steps may result in a downward impact on risk (i.e. positive impact on survival), suggesting that convolutions may better facilitate modeling relative movements (e.g. sudden declines) than simply paying attention to levels. Further visualizations for added perspective on input-output relationships are found in the appendix.
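The slope-of-partial-dependence idea can be sketched as follows. The code forces one input column to each value observed in the data, averages the model output over all samples, and fits a straight line through the resulting curve. The `toy_model` function is a hypothetical stand-in for a trained network, and treating each (feature, time step) pair as one column of a flat input matrix is an assumption made to keep the example short.

```python
import numpy as np

def partial_dependence(model, X, feature, grid):
    """Average model output when column `feature` is forced to each grid value."""
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value
        averages.append(np.mean(model(X_mod)))
    return np.array(averages)

def influence(model, X, feature):
    """Slope of a least-squares line through the partial dependence curve."""
    grid = np.unique(X[:, feature])      # evaluate on values present in the data
    pd_values = partial_dependence(model, X, feature, grid)
    slope, _ = np.polyfit(grid, pd_values, deg=1)
    return float(slope)

# Hypothetical stand-in for a trained risk model: risk rises with column 0
# and falls with column 1.
def toy_model(X):
    return 0.5 + 0.2 * X[:, 0] - 0.1 * X[:, 1]

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
slopes = [influence(toy_model, X, j) for j in range(2)]
```

For this linear toy model the recovered slopes match the coefficients 0.2 and -0.1 up to floating-point error. Arranging such slopes on a feature-by-time grid gives a saliency map of the kind described above.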

References

  • Marinescu et al. [2018] Razvan V Marinescu, Neil P Oxtoby, Alexandra L Young, et al. Tadpole challenge: Prediction of longitudinal evolution in alzheimer’s disease. arXiv preprint arXiv:1805.03909, 2018.
  • Doksum and Høyland [1992] Kjell A Doksum and Arnljot Høyland. Models for variable-stress accelerated life testing experiments based on wiener processes and the inverse gaussian distribution. Technometrics, 34(1):74–82, 1992.
  • Kleinbaum and Klein [2005] D Kleinbaum and M Klein. Survival analysis statistics for biology and health. Survival, 510:91665, 2005.
  • Rodríguez [2005] Germán Rodríguez. Parametric survival models. Lectures Notes, Princeton University, 2005.
  • Lee and Whitmore [2010] Mei-Ling Ting Lee and GA Whitmore. Proportional hazards and threshold regression: their theoretical and practical connections. Lifetime data analysis, 16(2):196–214, 2010.
  • Fernández et al. [2016] Tamara Fernández, Nicolás Rivera, and Yee Whye Teh. Gaussian processes for survival analysis. In Advances in Neural Information Processing Systems, pages 5021–5029, 2016.
  • Alaa and van der Schaar [2017] Ahmed M Alaa and Mihaela van der Schaar. Deep multi-task gaussian processes for survival analysis with competing risks. In Proceedings of the 30th Conference on Neural Information Processing Systems, 2017.
  • Singh and Mukhopadhyay [2011] Ritesh Singh and Keshab Mukhopadhyay. Survival analysis in clinical trials: Basics and must know areas. Perspectives in clinical research, 2(4):145, 2011.
  • Cox [1992] David R Cox. Regression models & life-tables. In Breakthroughs in statistics. Springer, 1992.
  • Beckett et al. [2015] Laurel A Beckett, Michael C Donohue, Cathy Wang, et al. The alzheimer’s disease neuroimaging initiative phase 2: Increasing the length, breadth, and depth of our understanding. Alzheimer’s & Dementia, 11(7):823–831, 2015.
  • Hardy and Selkoe [2002] John Hardy and Dennis J Selkoe. The amyloid hypothesis of alzheimer’s disease: progress and problems on the road to therapeutics. science, 297(5580):353–356, 2002.
  • Jedynak et al. [2012] Bruno M Jedynak, Andrew Lang, Bo Liu, et al. A computational neurodegenerative disease progression score: method and results with the alzheimer’s disease neuroimaging initiative cohort. Neuroimage, 63(3):1478–1486, 2012.
  • Donohue et al. [2014] Michael C Donohue, Hélène Jacqmin-Gadda, Mélanie Le Goff, et al. Estimating long-term multivariate progression from short-term data. Alzheimer’s & Dementia, 10(5):S400–S410, 2014.
  • Frölich et al. [2017] Lutz Frölich, Oliver Peters, Piotr Lewczuk, et al. Incremental value of biomarker combinations to predict progression of mild cognitive impairment to alzheimer’s dementia. Alzheimer’s research & therapy, 9(1):84, 2017.
  • Pascoal et al. [2017] Tharick A Pascoal, Sulantha Mathotaarachchi, Monica Shin, et al. Synergistic interaction between amyloid and tau predicts the progression to dementia. Alzheimer’s & Dementia, 13(6):644–653, 2017.
  • Tsiatis and Davidian [2004] Anastasios A Tsiatis and Marie Davidian. Joint modeling of longitudinal and time-to-event data: an overview. Statistica Sinica, pages 809–834, 2004.
  • Zheng and Heagerty [2005] Yingye Zheng and Patrick J Heagerty. Partly conditional survival models for longitudinal data. Biometrics, 61(2):379–391, 2005.
  • van Houwelingen and Putter [2011] Hans van Houwelingen and Hein Putter. Dynamic prediction in clinical survival analysis. CRC, 2011.
  • Faraggi and Simon [1995] David Faraggi and Richard Simon. A neural network model for survival data. Statistics in medicine, 14(1):73–82, 1995.
  • Buckley and James [1979] Jonathan Buckley and Ian James. Linear regression with censored data. Biometrika, 66(3):429–436, 1979.
  • Katzman et al. [2016] Jared L Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger. Deep survival: A deep cox proportional hazards network. stat, 1050:2, 2016.
  • Xiang et al. [2000] Anny Xiang, Pablo Lapuerta, Alex Ryutov, Jonathan Buckley, and Stanley Azen. Comparison of the performance of neural network methods and cox regression for censored survival data. Computational statistics & data analysis, 34(2):243–257, 2000.
  • Mariani et al. [1997] L Mariani, D Coradini, E Biganzoli, et al. Prognostic factors for metachronous contralateral breast cancer: a comparison of the linear cox regression model and its artificial neural network extension. Breast cancer research and treatment, 44(2):167–178, 1997.
  • Liestøl et al. [1994] Knut Liestøl, Per Kragh Andersen, and Ulrich Andersen. Survival analysis and neural nets. Statistics in medicine, 13(12):1189–1200, 1994.
  • Biganzoli et al. [2009] Elia M Biganzoli, Federico Ambrogi, and Patrizia Boracchi. Partial logistic artificial neural networks (plann) for flexible modeling of censored survival data. In Neural Networks, 2009. IJCNN 2009. International Joint Conference on, pages 340–346. IEEE, 2009.
  • Yu et al. [2011] Chun-Nam Yu, Russell Greiner, Hsiu-Chin Lin, and Vickie Baracos. Learning patient-specific cancer survival distributions as a sequence of dependent regressors. In Advances in Neural Information Processing Systems, pages 1845–1853, 2011.
  • Luck et al. [2017] Margaux Luck, Tristan Sylvain, Héloïse Cardinal, Andrea Lodi, and Yoshua Bengio. Deep learning for patient-specific kidney graft survival analysis. arXiv preprint arXiv:1705.10245, 2017.
  • Fotso [2018] Stephane Fotso. Deep neural networks for survival analysis based on a multi-task framework. arXiv preprint arXiv:1801.05512, 2018.
  • Lee et al. [2018] Changhee Lee, William R Zame, Jinsung Yoon, and Mihaela van der Schaar. Deephit: A deep learning approach to survival analysis with competing risks. AAAI, 2018.
  • Rizopoulos [2012] Dimitris Rizopoulos. Joint models for longitudinal and time-to-event data: With applications in R. Chapman and Hall/CRC, 2012.
  • Rubin [1976] Donald B Rubin. Inference and missing data. Biometrika, 63(3):581–592, 1976.
  • [32] Seo-Jin Bang, Yuchuan Wang, and Yang Yang. Phased-lstm based predictive model for longitudinal ehr data with missing values.
  • Long et al. [2015] Jonathan Long, Evan Shelhamer, and Trevor Darrell. Fully convolutional networks for semantic segmentation. In IEEE conference on computer vision and pattern recognition, pages 3431–3440, 2015.
  • Lim and van der Schaar [2018] Bryan Lim and Mihaela van der Schaar. Forecasting disease trajectories in alzheimer’s disease using deep learning. arXiv preprint arXiv:1807.03159, 2018.
  • Ribeiro et al. [2016] Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. Why should i trust you?: Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining, pages 1135–1144. ACM, 2016.
  • Avati et al. [2017] Anand Avati, Kenneth Jung, Stephanie Harman, Lance Downing, Andrew Ng, and Nigam H Shah. Improving palliative care with deep learning. In Bioinformatics and Biomedicine (BIBM), 2017 IEEE International Conference on, pages 311–316. IEEE, 2017.
  • Friedman [2001] Jerome H Friedman. Greedy function approximation: a gradient boosting machine. Annals of statistics, pages 1189–1232, 2001.
  • Zeiler and Fergus [2014] Matthew D Zeiler and Rob Fergus. Visualizing and understanding convolutional networks. In European conference on computer vision, pages 818–833. Springer, 2014.
  • Simonyan et al. [2013] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Deep inside convolutional networks: Visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034, 2013.
  • Parisot et al. [2018] Sarah Parisot, Sofia Ira Ktena, Enzo Ferrante, et al. Disease prediction using graph convolutional networks: Application to autism spectrum disorder and alzheimer’s disease. Medical image analysis, 2018.
  • Li and Luo [2017] Kan Li and Sheng Luo. Functional joint model for longitudinal and time-to-event data: an application to alzheimer’s disease. Statistics in medicine, 36(22):3560–3572, 2017.
  • Liu et al. [2017] Ke Liu, Kewei Chen, et al. Prediction of mild cognitive impairment conversion using a combination of independent component analysis and the cox model. Frontiers in human neuroscience, 11:33, 2017.
  • Vos et al. [2015] Stephanie JB Vos, Frans Verhey, Lutz Frölich, et al. Prevalence and prognosis of alzheimer’s disease at the mild cognitive impairment stage. Brain, 138(5):1327–1338, 2015.
  • Therneau [2018] Terry M Therneau. A package for survival analysis in s. R package version 2.42, 2018.
  • Hu [2018] FC Hu. Stepwise variable selection procedures for regression analysis. R package version 0.1.0, 2018.
  • Wu [2009] Lang Wu. Mixed effects models for complex data. Chapman and Hall/CRC, 2009.
  • Gal and Ghahramani [2016] Yarin Gal and Zoubin Ghahramani. Dropout as a bayesian approximation: Representing model uncertainty in deep learning. In international conference on machine learning, pages 1050–1059, 2016.
  • Saito and Rehmsmeier [2015] Takaya Saito and Marc Rehmsmeier. The precision-recall plot is more informative than the roc plot when evaluating binary classifiers on imbalanced datasets. PloS one, 10(3):e0118432, 2015.

Appendix

Example use

While various settings may benefit from Match-Net as a matter of clinical decision support, we illustrate one possible application within the context of personalized screening. Figure 3 shows the historical risk trajectory and forward risk estimates for a randomly selected ADNI patient. During the first seven years of bi-annual visits, the patient exhibits cognitively normal behavior, and the Match-Net risk estimates—computed via measured biomarkers and neuropsychological tests—reflect this steady clinical state (see historical trajectory in blue). In fact, as of precisely seven years of follow-up, the predicted 30-month forward risk remains less than 10% (see predicted trajectory in orange).

Figure 3: Example application of Match-Net for personalized risk scoring.

However, two clinical visits later, when the patient returns for a regular checkup and clinical measurements, the projected 30-month forward risk jumps to over 50% (see predicted trajectory in red). In this situation, the clinician is immediately alerted to the sudden increase in risk of dementia, and may decide to advise more frequent checkups, or to administer a wider range of tests and biomarker measurements in the immediate term to better assess the overall risk in light of the recent downturn. In fact, as it turns out in this case, the patient is indeed diagnosed with Alzheimer’s disease at years, shedding light on Match-Net’s potential as an early warning and subject selection system.

Related work

Model  Non-Linearity  Deep Learning  Direct-to-Probability  Time-Variance  Dynamic Prediction
Cox (1972) N/A N/A
Faraggi & Simon (1995)
Katzman et al. (2016)
Luck et al. (2017)
Lee et al. (2018)
(This study) (2018)
Table 2: Summary of primary improvements by related work

The first study to investigate neural networks formally in the context of time-to-event analysis was done by Faraggi and Simon [1995]. By swapping out the linear functional in the Cox model for the topology of a hidden layer, their nonlinear proportional hazards approach was extended to other models for censored data, such as Rodríguez [2005] and Buckley and James [1979]. In 2016, Katzman et al. [2016] were the first to apply modern techniques in deep learning to survival, in particular without prior feature selection or domain expertise. While previous studies following Faraggi and Simon [1995]’s model generally produced mixed results Xiang et al. [2000], Mariani et al. [1997], Katzman et al. [2016] demonstrated comparable or superior performance of multilayer perceptrons in relation to conventional statistical methods.

Instead of predicting the hazard function as an intermediate objective, Liestøl et al. [1994] first proposed, and Biganzoli et al. [2009] further developed, an alternative approach to predict survival directly for grouped time intervals. In 2017, Luck et al. [2017] combined the use of the Cox partial likelihood with the goal of predicting probabilities for pre-specified time intervals. Inspired by the work of Yu et al. [2011] on multi-task logistic regression models for survival, they generalized the idea to deep learning via multi-task multilayer perceptrons.

Recently, Lee et al. [2018] and Fotso [2018] proposed learning the distribution of survival times directly, making no assumptions regarding the underlying stochastic processes, in particular with respect to the time-invariance of hazards. By being process-agnostic, Lee et al. [2018] demonstrated significant improvements over existing statistical, machine learning, and neural network survival models on multiple real and synthetic datasets. In the context of Alzheimer’s disease, Parisot et al. [2018] studied the use of medical image sequences for classifying patient progression. More pertinently, Lim and van der Schaar [2018] used recurrent networks to forecast disease trajectories with ADNI data, but relied on the explicit assumption of exponential distributions. Finally, building on these developments, one of the main contributions of this study is the use of temporal convolutions for dynamic survival prediction, while making no assumptions, and allowing the associations between covariates and risks to evolve over time (see Table 2).

Details on dataset

We are primarily interested in the clinical status for each patient at any given time. An official diagnosis is recorded at each patient’s visit, and consists of two attributes. First, each diagnosis may be either stable or transitive. The former consists of stable diagnoses of normal brain functioning (“NL”), mild cognitive impairment (“MCI”), or Alzheimer’s disease (“AD”), and the latter consists of preliminary diagnoses indicating transitions between these categories, which may take the form of either conversions or reversions. Conversions indicate a forward progression in the disease trajectory, and reversions indicate a regression back towards an earlier stage of the disease.

Figure 4: State space of clinical diagnoses.

Patients may remain in stable or transition diagnosis states for any duration at a time. The average patient who receives a transition diagnosis is observed to persist in that state for one year, while some patients do not exit this state until almost 5 years have elapsed. Patients who receive a transition diagnosis may not actually be confirmed with a subsequent stable diagnosis; in fact, less than half of the transition diagnoses for dementia were confirmed by a stable diagnosis at the next step, and almost one quarter are never followed by a stable diagnosis at any point until right-censoring. In addition, patients often actually undergo reversion transitions back towards earlier stages of the disease; in fact, over 5% of the study population receive reversion diagnoses at some point in time.

Variable             Type         Min       Max      Mean     S.D.     Missing
Event (AD)           Categorical  -         -        -        -        30.1%
Static
  Age                Numeric      5.4E+01   9.1E+01  7.4E+01  7.2E+00  0.0%
  APOE4 (Risk)       Numeric      0.0E+00   2.0E+00  5.4E-01  6.6E-01  0.1%
  Education Level    Numeric      4.0E+00   2.0E+01  1.6E+01  2.9E+00  0.0%
  Ethnicity          Categorical  -         -        -        -        0.0%
  Gender             Categorical  -         -        -        -        0.0%
  Marital Status     Categorical  -         -        -        -        0.0%
  Race               Categorical  -         -        -        -        0.0%
Biomarker
  Entorhinal         Numeric      1.0E+03   6.7E+03  3.4E+03  8.1E+02  49.2%
  Fusiform           Numeric      7.7E+03   3.0E+04  1.7E+04  2.8E+03  49.2%
  Hippocampus        Numeric      2.2E+03   1.1E+04  6.7E+03  1.2E+03  46.6%
  Intracranial       Numeric      2.9E+02   2.1E+06  1.5E+06  1.7E+05  37.6%
  Mid Temp           Numeric      8.0E+03   3.2E+04  1.9E+04  3.1E+03  49.2%
  Ventricles         Numeric      5.7E+03   1.6E+05  4.2E+04  2.3E+04  41.6%
  Whole Brain        Numeric      6.5E+05   1.5E+06  1.0E+06  1.1E+05  39.7%
Cognitive
  ADAS (11-item)     Numeric      0.0E+00   7.0E+01  1.1E+01  8.6E+00  30.1%
  ADAS (13-item)     Numeric      0.0E+00   8.5E+01  1.8E+01  1.2E+01  30.7%
  CDR Sum of Boxes   Numeric      0.0E+00   1.8E+01  2.2E+00  2.8E+00  29.7%
  Mini Mental State  Numeric      0.0E+00   3.0E+01  2.7E+01  4.0E+00  29.9%
  RAVLT Forgetting   Numeric      -1.2E+01  1.5E+01  4.2E+00  2.5E+00  30.9%
  RAVLT Immediate    Numeric      0.0E+00   7.5E+01  3.5E+01  1.4E+01  30.7%
  RAVLT Learning     Numeric      -5.0E+00  1.4E+01  4.0E+00  2.8E+00  30.7%
  RAVLT Percent      Numeric      -5.0E+02  1.0E+02  6.0E+01  3.8E+01  31.4%
Table 3: Summary and description of variables used in ADNI dataset.

Note on class imbalance. The per-patient failure rate is 14% (243 patients out of the total 1,737). However, given the online nature of the sliding window mechanism in training and testing, the effective fraction of observations with positive event labels for any prediction horizon is around 2%.

Data preparation

Since the ADNI dataset is an amalgamation from multiple related studies, most features are sparsely populated. Features with less than half of the entries missing are retained, leaving 18 numeric and 4 categorical features (see Table 3); the latter are represented by one-hot encoding, resulting in 16 binary features. Consistent with existing Alzheimer’s studies, patients are aligned according to time elapsed since baseline measurements
Li and Luo [2017], Liu et al. [2017], Vos et al. [2015]. Timestamps are mapped onto a discrete axis with a fixed resolution of half-year intervals; where multiple measurements qualify for the same destination, the most recent measurement per feature takes precedence. Original measurements were made at roughly half-year intervals, so the average absolute deviation between original and final timestamps amounts to an insignificant 4 days (i.e. less than 2% of each interval).
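The mapping onto a discrete time axis can be sketched as follows. This is an illustrative sketch only: the function name `discretize`, the tuple layout, and the half-year `resolution` default (chosen to match the 0.5-year prediction horizons in Table 5) are assumptions.

```python
import math


def discretize(measurements, resolution=0.5):
    """Snap (time_in_years, feature, value) records onto a fixed grid.

    Returns a dict keyed by (step_index, feature). When several measurements
    of the same feature land on the same step, the most recent one wins.
    """
    grid = {}
    # Process in chronological order so later records overwrite earlier ones.
    for t, feature, value in sorted(measurements, key=lambda m: m[0]):
        step = math.floor(t / resolution + 0.5)  # nearest grid point
        grid[(step, feature)] = value
    return grid


records = [
    (0.02, "MMSE", 29.0),
    (0.48, "MMSE", 27.0),
    (0.55, "MMSE", 26.0),   # same step as the 0.48 record, but more recent
    (0.51, "ADAS13", 12.0),
]
grid = discretize(records)
```

Here the two MMSE records near the half-year mark collide on step 1, and the later one is kept.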

Where measurements are missing, values are reconstructed using zero-order hold interpolation. In addition, due to the fixed-width nature of the sliding window, the input tensor for initial prediction times corresponds to left-truncated information; feature values are therefore extrapolated backwards over the portion of the window that precedes the first measurement. Note that regardless of the imputation mechanism, information on original patterns of missingness (due to truncation, irregular sampling, and asynchronous sampling alike; Rubin [1976], Bang et al.) is preserved in the missing-value mask provided in parallel to the network. Finally, to improve numerical conditioning, all features are normalized with empirical means and standard deviations from the training set data.
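As a rough illustration of the imputation steps described above, the sketch below forward-fills one feature (zero-order hold), back-fills any leading gap, records the original missingness in a binary mask, and standardizes with training-set statistics. The helper names and the use of the population standard deviation are assumptions, not details from the paper.

```python
import statistics


def impute_with_mask(series):
    """Fill one feature's time series and return (filled, mask).

    `series` holds floats, with None marking a missing entry.
    mask[i] is 1 where the value was actually observed, else 0.
    """
    mask = [0 if v is None else 1 for v in series]
    filled, last = [], None
    for v in series:                      # zero-order hold (forward fill)
        last = v if v is not None else last
        filled.append(last)
    first = next((v for v in filled if v is not None), None)
    # Backward extrapolation for entries before the first observation.
    filled = [first if v is None else v for v in filled]
    return filled, mask


def standardize(values, train_values):
    """Normalize with mean and standard deviation from training data only."""
    mu = statistics.mean(train_values)
    sigma = statistics.pstdev(train_values)
    return [(v - mu) / sigma for v in values]


filled, mask = impute_with_mask([None, 2.0, None, None, 5.0, None])
```

The mask is computed before any filling, so the pattern of missingness survives whatever imputation rule is applied afterwards.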

Training procedure

Training begins with the input tuple and terminates with a set of calibrated network weights. The network is trained until convergence, up to a maximum of 50 epochs. Training loss is computed only for event labels corresponding to actual recorded clinical visits (i.e. timestamps with recorded covariate values); neither imputed nor forward-filled labels are included. Analogous to our total loss function definition, the convergence metric is the sum of performance scores across all prediction tasks,

(7)

where the weighting coefficients optionally allow trading off the relative importance between the two measures and across different horizons. In this study we simply use the unweighted sum, although any convex combination would be valid. Empirically, results are not meaningfully improved by favoring one metric over the other. Elastic net regularization is used, and validation performance is computed every 10 iterations. For early stopping, validation scores serve as proxies for generalization error. Positive instances are oversampled to counteract class imbalance. To augment training data, artificial labels for dummy events are generated, and positive labels are filled forward for all horizons beyond the one corresponding to the first failure for that example. Patients diagnosed with Alzheimer’s at baseline (20%) are excluded from testing and validation, since survival is undefined in those cases.
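The composite convergence metric can be sketched as a weighted sum over horizons. This is a sketch under assumptions: the two measures are taken to be AUROC and AUPRC (the metrics reported in Table 5), and the weight arguments `w_roc` and `w_prc` are hypothetical names, since the original symbols are not shown here.

```python
def composite_score(auroc, auprc, w_roc=None, w_prc=None):
    """Weighted sum of AUROC and AUPRC across prediction horizons.

    `auroc` and `auprc` map each horizon (in years) to a validation score.
    With no weights given, every term has weight 1 (the unweighted sum).
    """
    horizons = sorted(auroc)
    w_roc = w_roc or {h: 1.0 for h in horizons}
    w_prc = w_prc or {h: 1.0 for h in horizons}
    return sum(w_roc[h] * auroc[h] + w_prc[h] * auprc[h] for h in horizons)


# Match-Net scores from Table 5 (e), used purely as example inputs.
auroc = {0.5: 0.962, 1.0: 0.942, 1.5: 0.902, 2.0: 0.909, 2.5: 0.886}
auprc = {0.5: 0.594, 1.0: 0.513, 1.5: 0.373, 2.0: 0.390, 2.5: 0.384}
score = composite_score(auroc, auprc)
```

Passing non-uniform weights would shift the emphasis towards a particular metric or horizon, which is the trade-off the text describes.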

Hyperparameter Selection Range
Connected Layers 1, 2, 3, 4, 5
Convolutional Layers 1, 2, 3, 4, 5
Dropout Rate 0.1, 0.2, 0.3, 0.4, 0.5
Epochs for Convergence 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
Learning Rate 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2
L1-Regularization None, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1
L2-Regularization None, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1
Minibatch Size 32, 64, 128, 256, 512
Number of Filters (Covariates) 32, 64, 128, 256, 512
Number of Filters (Masks) 8, 16, 32, 64, 128
Oversample Ratio None, 1, 2, 3, 5, 10
Recurrent Unit State Size 1, 2, 3, 4, 5
Width of Connected Layers 32, 64, 128, 256, 512
Width of Convolutional Filters 3, 4, 5, 6, 7, 8, 9, 10
Width of Sliding Window 3, 4, 5, 6, 7, 8, 9, 10
Table 4: Hyperparameter selection ranges for random search.
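For illustration, random search over the ranges in Table 4 might be set up as below. The dictionary keys and the function name `sample_configs` are hypothetical; only the candidate values come from the table, and uniform independent sampling is an assumption.

```python
import random

# Candidate values transcribed from Table 4 (None means "no regularization"
# or "no oversampling").
SEARCH_SPACE = {
    "connected_layers": [1, 2, 3, 4, 5],
    "convolutional_layers": [1, 2, 3, 4, 5],
    "dropout_rate": [0.1, 0.2, 0.3, 0.4, 0.5],
    "epochs_for_convergence": list(range(1, 11)),
    "learning_rate": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2],
    "l1_regularization": [None, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1],
    "l2_regularization": [None, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1],
    "minibatch_size": [32, 64, 128, 256, 512],
    "filters_covariates": [32, 64, 128, 256, 512],
    "filters_masks": [8, 16, 32, 64, 128],
    "oversample_ratio": [None, 1, 2, 3, 5, 10],
    "recurrent_state_size": [1, 2, 3, 4, 5],
    "connected_layer_width": [32, 64, 128, 256, 512],
    "conv_filter_width": list(range(3, 11)),
    "sliding_window_width": list(range(3, 11)),
}


def sample_configs(space, n_iter=100, seed=0):
    """Draw n_iter configurations, each value chosen uniformly at random."""
    rng = random.Random(seed)
    return [{name: rng.choice(values) for name, values in space.items()}
            for _ in range(n_iter)]


configs = sample_configs(SEARCH_SPACE)
```

Each sampled configuration would then be trained and ranked by its validation score.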

For Cox landmarking, we use the implementation in Therneau [2018] for interval-censored data, fitting a sequence of proportional hazards regression models for observation groups. Optimal groupings are determined by exhaustive search in -year increments. Preliminary feature-selection is performed by stepwise regression for the best candidate model, using the implementation of Hu [2018]. For joint modeling, we adopt the common two-stage method in Wu [2009], first fitting linear mixed effects sub-models for significant variables, then fitting landmarking models based on mean estimates from the sub-models. In both cases, consistent with literature, time is defined as years since initial follow-up Li and Luo [2017], Liu et al. [2017], Vos et al. [2015].

For all neural network models, hyperparameter optimization is carried out via 100 iterations of random search (see Table 4). In addition, activation functions are searched over for gating units in recurrent network cells. Model selection is performed on the basis of final composite scores, as in Equation 7, for each candidate. We use 5-fold cross validation to evaluate performance, stratified at the patient level; that is, patients are randomly selected into datasets for training (60%), validation (20%), and testing (20%), with the ratio of positive patients (i.e. those for which at least one sliding window contains at least one outcome of failure) to negative patients kept uniform across folds.
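A patient-level stratified split of the kind described above can be sketched as follows. This shows a single 60/20/20 split rather than the full 5-fold procedure, and the function name and input format are assumptions for illustration.

```python
import random


def stratified_patient_split(is_positive, fractions=(0.6, 0.2, 0.2), seed=0):
    """Split patient IDs into train/val/test, keeping the positive ratio.

    `is_positive` maps patient ID -> True if any of that patient's sliding
    windows contains a failure. Splitting by patient (not by window) keeps
    all windows of one patient inside a single partition.
    """
    rng = random.Random(seed)
    splits = {"train": [], "val": [], "test": []}
    for flag in (True, False):            # stratify: positives, then negatives
        ids = sorted(p for p, y in is_positive.items() if y == flag)
        rng.shuffle(ids)
        n_train = round(fractions[0] * len(ids))
        n_val = round(fractions[1] * len(ids))
        splits["train"] += ids[:n_train]
        splits["val"] += ids[n_train:n_train + n_val]
        splits["test"] += ids[n_train + n_val:]
    return splits


# Toy cohort of 100 hypothetical patients, 20 of them positive.
cohort = {pid: pid < 20 for pid in range(100)}
splits = stratified_patient_split(cohort)
```

Repeating the split with different seeds (or rotating the partitions) would give the multiple folds used for evaluation.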

Further analysis

In addition, we utilize output scatters to aid our understanding of the input-output relationship. We use MC dropout Gal and Ghahramani [2016] in equation 6 to generate average responses while explicitly varying one input feature at a time. This is done by either (1) varying only the final value of the measurement within the sliding window, or (2) varying all values of the measurement at the same time. While the latter gives insight into how the response changes with respect to a difference in levels, the former sheds light on how the response is affected by a sudden change instead. Interestingly, in the case of temporal convolutions, the response is actually stronger to final-value changes for multiple features (Figure 5 shows the MMSE and CDRSB features as examples). This effect is not present in any feature for multilayer perceptrons. In light of the performance of convolutional models, this is consistent with our hypothesis that they are able to learn more complex temporal patterns than multilayer perceptrons; in this case, these results suggest that a patient who experiences a sudden decline in certain features is expected to fare worse than one who has scored just as poorly all along.

Figure 5: Comparison of output scatters for changes in all values within sliding window (red) versus changes in the final value only (blue). Convolutional models show stronger responses in the latter.
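The two perturbation schemes compared in Figure 5 can be sketched as below. The window representation (a list of per-step feature dictionaries), the helper names, and the `stochastic_predict` callable (standing in for a network evaluated with dropout left on) are all hypothetical.

```python
def perturb_final(window, feature, delta):
    """Shift only the last value of `feature` in the window (sudden change)."""
    out = [dict(step) for step in window]   # copy so the input is untouched
    out[-1][feature] += delta
    return out


def perturb_all(window, feature, delta):
    """Shift every value of `feature` in the window (change in level)."""
    out = [dict(step) for step in window]
    for step in out:
        step[feature] += delta
    return out


def mc_average(stochastic_predict, window, n_samples=100):
    """Average repeated stochastic forward passes, as in MC dropout."""
    return sum(stochastic_predict(window) for _ in range(n_samples)) / n_samples


window = [{"MMSE": 28.0}, {"MMSE": 27.0}, {"MMSE": 27.0}]
sudden = perturb_final(window, "MMSE", -3.0)
level = perturb_all(window, "MMSE", -3.0)
```

Comparing `mc_average` on `sudden` and `level` across a range of `delta` values would yield the two scatters for one feature.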

Sources of gain

While the advantage of using multilayer perceptrons over traditional statistical survival models has been studied (see Lee et al. [2018] for example), here we account for the additional sources of gain from our design choices. Table 5 (a) shows the initial benefit from incorporating longitudinal histories of covariate measurements. Recurrent networks are used as a reasonable starting point; improvements, where positive, are marginal at best. (b) However, experiments on differenced data, which capture relative movements without absolute levels, indicate substantial informativeness (see df-RNN trained on differences vs. joint models trained on levels). (c) Utilizing limited sliding windows instead of recurrent cells shows promising improvements, boosting average AUPRC by 4% over both RNN and MLP models. (d) The addition of temporal convolutions produces incremental AUPRC gains of 9%. (e) Accommodating informative missingness by introducing the dual-stream architecture results in further incremental AUPRC gains of 2%. Compared with the joint modeling baseline, Match-Net achieves average AUPRC improvements of 15% and one-step-ahead improvements of 26%. (f) Furthermore, incorporating the most recent clinician input (the Match-Net variant in the first column of panel (f)) improves AUPRC by an additional 17%. While the distribution of gains skews near-term with our default, unweighted choice of weights, alternate distributions may be obtained via appropriate weight functions depending on deployment context. (g) Finally, using Match-Net as an example, we observe individual and cumulative benefits due to miscellaneous design choices applied to all models.

RNN MLP
auroc 0.5 0.949 0.948
1.0 0.930 0.930
1.5 0.891 0.890
2.0 0.901 0.895
2.5 0.883 0.874
auprc 0.5 0.464 0.469
1.0 0.410 0.435
1.5 0.340 0.340
2.0 0.355 0.359
2.5 0.365 0.356
(a) Gain from Covariate History
df-RNN JM
auroc 0.5 0.913 0.913
1.0 0.830 0.917
1.5 0.700 0.881
2.0 0.712 0.894
2.5 0.708 0.883
auprc 0.5 0.494 0.473
1.0 0.386 0.415
1.5 0.213 0.319
2.0 0.241 0.362
2.5 0.217 0.366
(b) Informativeness of Differences
S-MLP RNN
auroc 0.5 0.950 0.949
1.0 0.932 0.930
1.5 0.897 0.891
2.0 0.904 0.901
2.5 0.881 0.883
auprc 0.5 0.500 0.464
1.0 0.447 0.410
1.5 0.354 0.340
2.0 0.364 0.355
2.5 0.371 0.365
(c) Gain from Limited Window
S-TCN S-MLP
auroc 0.5 0.961 0.950
1.0 0.941 0.932
1.5 0.902 0.897
2.0 0.908 0.904
2.5 0.884 0.881
auprc 0.5 0.580 0.500
1.0 0.505 0.447
1.5 0.367 0.354
2.0 0.380 0.364
2.5 0.381 0.371
(d) Gain from Temporal Convolutions
Match-Net S-TCN
auroc 0.5 0.962 0.961
1.0 0.942 0.941
1.5 0.902 0.902
2.0 0.909 0.908
2.5 0.886 0.884
auprc 0.5 0.594 0.580
1.0 0.513 0.505
1.5 0.373 0.367
2.0 0.390 0.380
2.5 0.384 0.381
(e) Gain from Missingness-Awareness
Match-Net(+diagnosis) Match-Net
auroc 0.5 0.989 0.962
1.0 0.951 0.942
1.5 0.897 0.902
2.0 0.901 0.909
2.5 0.885 0.886
auprc 0.5 0.862 0.594
1.0 0.636 0.513
1.5 0.393 0.373
2.0 0.380 0.390
2.5 0.375 0.384
(f) Gain from Most Recent Diagnosis
O L E | AUROC 0.5 1.0 1.5 2.0 2.5 | AUPRC 0.5 1.0 1.5 2.0 2.5
- - - | 0.946 0.923 0.895 0.902 0.888 | 0.493 0.418 0.348 0.351 0.355
- - ✓ | 0.956 0.938 0.901 0.906 0.886 | 0.534 0.477 0.368 0.381 0.384
- ✓ - | 0.944 0.920 0.888 0.896 0.886 | 0.435 0.413 0.334 0.351 0.357
- ✓ ✓ | 0.961 0.941 0.902 0.908 0.887 | 0.575 0.502 0.370 0.384 0.377
✓ - - | 0.946 0.923 0.894 0.902 0.887 | 0.499 0.425 0.352 0.365 0.365
✓ - ✓ | 0.956 0.938 0.901 0.907 0.887 | 0.533 0.473 0.362 0.379 0.385
✓ ✓ - | 0.942 0.918 0.885 0.894 0.881 | 0.448 0.404 0.332 0.348 0.361
✓ ✓ ✓ | 0.962 0.942 0.902 0.909 0.886 | 0.594 0.513 0.373 0.390 0.384
(g) Gain from oversampling (O), label forwarding (L), and elastic net (E) for Match-Net (✓ enabled, - disabled)
Table 5: Source-of-gain accounting. Bold values indicate best performance. Note that AUROC is much less sensitive than AUPRC in the context of highly imbalanced classes like the ADNI data Saito and Rehmsmeier [2015].
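As a worked check on the incremental AUPRC gains quoted for panels (d) to (f), the sketch below computes the relative change in AUPRC averaged over the five horizons. Taking the ratio of horizon-averaged scores is an assumption about how the percentages were derived, and the label "Match-Net + diagnosis" is a placeholder for the first column of panel (f).

```python
from statistics import mean

# AUPRC at horizons 0.5, 1.0, 1.5, 2.0, 2.5 years, copied from Table 5.
AUPRC = {
    "S-MLP": [0.500, 0.447, 0.354, 0.364, 0.371],
    "S-TCN": [0.580, 0.505, 0.367, 0.380, 0.381],
    "Match-Net": [0.594, 0.513, 0.373, 0.390, 0.384],
    "Match-Net + diagnosis": [0.862, 0.636, 0.393, 0.380, 0.375],
}


def relative_gain(model, baseline):
    """Percentage change in horizon-averaged AUPRC of model over baseline."""
    return 100.0 * (mean(AUPRC[model]) / mean(AUPRC[baseline]) - 1.0)


gain_d = relative_gain("S-TCN", "S-MLP")                   # temporal convolutions
gain_e = relative_gain("Match-Net", "S-TCN")                # missingness-awareness
gain_f = relative_gain("Match-Net + diagnosis", "Match-Net")  # recent diagnosis
```

Under this averaging convention the three gains round to 9%, 2%, and 17%, in line with the figures stated in the text.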