Neural ODEs for Multi-State Survival Analysis

06/08/2020 ∙ by Stefan Groha, et al. ∙ Harvard University

Survival models are a popular tool for the analysis of time-to-event data, with applications in medicine, engineering, economics and many other fields. Advances like the Cox proportional hazards model have enabled researchers to better describe hazard rates for the occurrence of single fatal events, but are limited by modeling assumptions such as proportionality of hazard rates and linear effects. Moreover, common phenomena are often better described through multiple states: for example, the progress of a disease might be modeled as healthy, sick and dead instead of healthy and dead, where the competing nature of death and disease has to be taken into account. Also, individual characteristics can vary significantly between observational units, like patients, resulting in idiosyncratic hazard rates and different disease trajectories. These considerations require flexible modeling assumptions. Current standard models, however, are often ill-suited for such an analysis. To overcome these issues, we propose the use of neural ordinary differential equations as a flexible and general method for estimating multi-state survival models by directly solving the Kolmogorov forward equations. To quantify the uncertainty in the resulting individual cause-specific hazard rates, we further introduce a variational latent variable model. We show that our model exhibits state-of-the-art performance on popular survival data sets and demonstrate its efficacy in a multi-state setting.


1 Introduction

Time-to-event analysis is of fundamental importance in many fields where there is interest in the time until the occurrence of particular events, while allowing for time-dependent missing outcomes (i.e. “censored” data). Examples include time-to-death analysis in medicine vigano2000survival, failure of mechanical systems in engineering samaniego2007system and financial risk dirick2017time. If the events are additionally fatal, we speak of survival analysis. For simplicity, most survival models only consider the binary case, where one is interested in the transition from a non-fatal to a fatal state. The aim of many such models is to relate the arrival of events to observed characteristics, e.g. to model a patient’s survival probability given a set of individual covariates. To date, the standard tool for survival analysis is the proportional hazards model, introduced in the seminal paper by cox1972regression, which assumes proportionality between the hazards for different values of the covariates.

A first generalization of standard survival analysis considers multiple competing events, where all possible state transitions are fatal. For example, in the medical setting a patient can have multiple causes of death. When estimating the incidence of one of these events, treating the other events as censored leads to bias due to misspecification of the at-risk population fine1999proportional, so the competing nature of the events has to be modeled explicitly.

In recent years, with growing data availability and the advent of precision medicine, there has been an increasing interest in a more refined modeling approach in clinical applications, taking into account multiple non-fatal states and more complicated relationships between all states rueda2019dynamics,gerstung2017precision,grinfeld2018classification,duffy1997,nicora2020continuous,Longini1989. For example, in the case of acute myeloid leukemia, individualized genetic prediction based on a more complicated multi-stage model was used to tailor personalized treatment within first complete remission gerstung2017precision. In general, knowing which transitions and end-points are most likely to occur enhances the clinician’s ability to make decisions.

The inclusion of covariates in common approaches for multi-state models usually requires strong assumptions regarding the stochastic process and the dependence between model parameters and covariates. We propose a general alternative approach, based on modeling the Kolmogorov forward equation of the underlying process using neural ordinary differential equations chen2018neural. The use of neural networks provides considerably more model flexibility than previous approaches, allowing the learning of expressive covariate relationships without placing restrictive modeling assumptions on the states. A state augmentation akin to a memory process further enables us to move beyond the common Markov assumption in the state transition probabilities. To the best of the authors’ knowledge, the method presented in this paper is the first neural network approach designed to explicitly handle multi-state survival models without simplifying assumptions.

2 Background and related work

2.1 Survival analysis

Survival analysis is one of the simplest approaches for the study of time-to-event data. It categorizes the underlying states of interest as a dichotomous pair of a non-fatal and a fatal event, e.g. alive/dead for patients or functioning/failure for mechanical devices. In those models, interest lies in the transition from the non-fatal to the fatal (absorbing) state. Let $T$ denote a random variable describing the time of arrival of the fatal event. $T$ can be flexibly modeled as the first jump of an inhomogeneous Poisson process with density function

$$f(t) = \lambda(t)\,S(t), \qquad S(t) = \exp\Big(-\int_0^t \lambda(u)\,du\Big),$$

where $\lambda(t)$ denotes the hazard function and $S(t) = P(T > t)$ is the survival function. In many cases, e.g. for patient data in clinical trials or for observational data, some participants drop out before the conclusion of the study. This gives an ambiguous meaning to the observed time point $t$, which is either a fatal event if there is no censoring ($\delta = 1$) or a drop-out ($\delta = 0$). In the latter case, the only information available is that $T > t$, which has probability $S(t)$. This is a case of right-censoring. Assuming independence of the censoring process, the likelihood contribution of an individual is

$$\mathcal{L} = f(t)^{\delta}\,S(t)^{1-\delta} = \lambda(t)^{\delta}\,S(t). \qquad (1)$$
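As a small numerical illustration of Eq. (1), the sketch below evaluates the log-likelihood contribution $\delta\log\lambda(t) - \int_0^t \lambda(u)\,du$ of one individual for an arbitrary hazard function, approximating the cumulative hazard with the trapezoidal rule. The function names and the constant-hazard example are assumptions made for illustration; they are not part of the paper's implementation.

```python
import math
from typing import Callable


def cumulative_hazard(hazard: Callable[[float], float], t: float, steps: int = 1000) -> float:
    """Approximate Lambda(t) = int_0^t hazard(u) du with the trapezoidal rule."""
    if t <= 0:
        return 0.0
    h = t / steps
    total = 0.5 * (hazard(0.0) + hazard(t))
    for k in range(1, steps):
        total += hazard(k * h)
    return total * h


def censored_log_likelihood(hazard: Callable[[float], float], t: float, delta: int) -> float:
    """log[lambda(t)^delta * S(t)] with S(t) = exp(-Lambda(t)), cf. Eq. (1)."""
    log_hazard = math.log(hazard(t)) if delta == 1 else 0.0
    return log_hazard - cumulative_hazard(hazard, t)


# Hypothetical constant hazard 0.5 (exponential model): delta * log(0.5) - 0.5 * t.
event_ll = censored_log_likelihood(lambda u: 0.5, t=2.0, delta=1)     # death observed at t = 2
censored_ll = censored_log_likelihood(lambda u: 0.5, t=2.0, delta=0)  # dropped out at t = 2
print(event_ll, censored_ll)
```

A censored individual contributes only through the survival term, which is why the drop-out carries less information than the observed event.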

The most widely used tool to estimate the influence of covariates on the survival function is the Cox proportional hazards model. This semi-parametric method models the hazard function as $\lambda(t \mid x) = \lambda_0(t)\exp(\beta^\top x)$, where $\beta$ are coefficients for the covariates $x$ and $\lambda_0(t)$ is a baseline hazard estimated directly from the data. Both the linear nature of the model and the proportional hazards assumption are often violated in practice.
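The Cox model estimates $\beta$ without specifying $\lambda_0$ by maximizing the partial likelihood. As a minimal sketch of that standard objective (not code from this paper), the function below computes the negative log partial likelihood $-\sum_{i:\delta_i=1}\big[\beta^\top x_i - \log\sum_{j:t_j \ge t_i}\exp(\beta^\top x_j)\big]$, assuming there are no tied event times (Breslow or Efron corrections would be needed otherwise).

```python
import math
from typing import List, Sequence


def cox_neg_log_partial_likelihood(beta: Sequence[float], x: List[Sequence[float]],
                                   times: Sequence[float], events: Sequence[int]) -> float:
    """Negative log partial likelihood of the Cox model (assumes no tied event times)."""
    scores = [sum(b * xi for b, xi in zip(beta, row)) for row in x]  # linear predictors beta^T x
    nll = 0.0
    for i, (t_i, d_i) in enumerate(zip(times, events)):
        if d_i != 1:
            continue  # censored individuals only contribute through the risk sets
        risk_set = [scores[j] for j, t_j in enumerate(times) if t_j >= t_i]
        m = max(risk_set)  # log-sum-exp for numerical stability
        log_denominator = m + math.log(sum(math.exp(s - m) for s in risk_set))
        nll -= scores[i] - log_denominator
    return nll


# Hypothetical toy data: one covariate, three individuals, the last one censored.
x = [[0.0], [1.0], [0.5]]
times = [1.0, 2.0, 3.0]
events = [1, 1, 0]
print(cox_neg_log_partial_likelihood([0.0], x, times, events))
```

With $\beta = 0$ every individual in a risk set is equally likely to fail, so the value reduces to $\log 3 + \log 2$.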

Many extensions of the Cox proportional hazards model have been proposed, aiming to relax one or both of those assumptions. These include models that keep the Cox model structure but extend it to non-linear features or non-proportional hazards, e.g. by modeling $\lambda(t \mid x) = \lambda_0(t)\exp(g(x))$ with $g$ a deep neural network, or $\lambda(t \mid x) = \lambda_0(t)\exp(g(t, x))$ (continuous time models) katzman2018deepsurv,kvamme2019time; approaches using MLPs lee2018deephit or recurrent neural networks giunchiglia2018rnn,ren2019deep for every time step (discrete time models); Gaussian process models alaa2017deep,fernandez2016gaussian; or generative adversarial networks (GANs) chapfuwa2018adversarial.

2.2 A progressive three-state survival model

The aim of multi-state models is a more granular analysis of time-to-event phenomena, where common binary outcomes (e.g. health/death) cannot adequately describe real observations. A simple extension of a traditional binary survival model observing the time to a fatal event is the addition of an intermediate state, illness, which could denote the appearance of symptoms or, more generally, some non-fatal disease progression. Such models and their state space can be described by a directed graph, as shown in Figure 1(b). For processes that are assumed to evolve continuously over time, where observational units such as patients move between states, this suggests models using continuous time, finite state space Markov processes. The Markov process is completely characterized by the state transition probabilities

$$P_{ij}(s,t) = P\big(X(t) = j \mid X(s) = i\big)$$

for all tuples of states $(i,j)$ and all tuples of time points $s \le t$, where $X(t)$ denotes the state of an individual at time $t$.

[Figure 1. (a) Competing hazards model, with states 1 Health, 2 Cause A, 3 Cause B. (b) General multi-state model, with states 1 Health, 2 Illness, 3 Death.]

Figure 1: Example graphs corresponding to the competing hazards model with two competing causes (a) and the illness-death model (b), a popular multi-state model. Competing hazards models are a special case of multi-state models with only one non-absorbing state, whereas general multi-state models can have arbitrary, even cyclical, connections.

Describing the transition probabilities, and hence the likelihood, with a model that allows for flexible use of covariates while permitting non-homogeneous state evolution is challenging. The standard tool is a Markov multi-state model, where a Cox proportional hazards model is fit to each transition separately. The transition probabilities are then obtained by assuming a Markov model for the transitions through the states mstate. This approach inherits the disadvantages of the Cox proportional hazards model at each transition and additionally imposes a Markov assumption for each state, together with the assumption that event times for different events are independent of each other, which rarely holds in practice. In the case of partially observed progressive models, there have been extensions to hidden Markov models msm and attentive state space models alaa2018forecasting, with hidden, unsupervised state evolution. While these models are interesting in their own right, they aim to solve a somewhat different problem. However, applying our approach to model the hidden state dynamics of continuous time processes would be an interesting avenue for further research.

The conceptually more appealing approach of solving the Kolmogorov forward equation was introduced in titman2011kfespline. However, the proposed B-spline basis for the hazard function does not generalize well to the inclusion of covariates, as a separate Kolmogorov forward equation has to be fit for every realization of the covariates. Recently, generalizations to the special case of competing hazards models using Gaussian processes alaa2017deep and deep neural networks lee2018deephit were proposed; however, we are not aware of any literature extending such flexible methods to general multi-state models.

3 Multi-state survival models

Mathematically, we define a continuous time stochastic process $X(t)$ taking values in a finite state space $\mathcal{S} = \{1, \dots, N\}$ over a known time horizon $[0, T]$. Such processes are often called (Markov) jump processes. In the following, we describe the likelihood function and its relation to the Kolmogorov forward equations kolmogoroff1931analytischen, feller1949.

3.1 Multi-state likelihood and Kolmogorov equations

For each individual, we observe the process in the form of discrete jumps over the relevant time interval $[0, T]$.

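In this setting, a single observation consists of a set of time-indexed states $\{(t_0, s_0), \dots, (t_n, s_n)\}$. The likelihood is given by

$$\mathcal{L}(\theta) = P\big(X(t_0) = s_0\big) \prod_{i=1}^{n} P\big(X(t_i) = s_i \mid X(t_{i-1}) = s_{i-1}, \dots, X(t_0) = s_0;\, \theta\big),$$

where $\theta$ denotes all free model parameters and $P(X(t_0) = s_0)$ the probability to be in the initial state. To avoid confusion, we define $P_{ij}(s,t) := P\big(X(t) = j \mid X(s) = i\big)$. The full likelihood for a set of observations is thus given by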

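$$\mathcal{L}(\theta) = \prod_{k=1}^{K} \mathcal{L}_k(\theta),$$

with $\mathcal{L}_k$ the likelihood of the $k$-th of $K$ individuals. Under the Markov assumption, the evolution of the transition probabilities in the likelihood, collected in the matrix $P(s,t)$ with entries $P_{ij}(s,t) = P(X(t) = j \mid X(s) = i)$, is governed by the Kolmogorov forward equation

$$\frac{\partial}{\partial t} P(s,t) = P(s,t)\,Q(t), \qquad P(s,s) = I, \qquad (2)$$

where $Q(t) = \big(q_{ij}(t)\big)_{i,j}$ is the matrix of instantaneous transition rates. The Markov property is evident from the fact that $Q$ depends only on the time $t$.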
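To make Eq. (2) concrete, the following plain-Python sketch integrates $\partial_t P(s,t) = P(s,t)\,Q(t)$ from $P(s,s) = I$ with a fixed-step fourth-order Runge–Kutta scheme. The constant illness-death rates are hypothetical values chosen so that the result can be checked against closed forms such as $P_{11}(0,t) = e^{-(q_{12}+q_{13})t}$ (states are 0-indexed in the code); the paper itself relies on the adaptive solvers in torchdiffeq.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add_scaled(y: Matrix, alpha: float, x: Matrix) -> Matrix:
    """Return y + alpha * x element-wise."""
    return [[yv + alpha * xv for yv, xv in zip(yr, xr)] for yr, xr in zip(y, x)]


def solve_kfe(q: Callable[[float], Matrix], s: float, t: float, steps: int = 200) -> Matrix:
    """Integrate dP(s, u)/du = P(s, u) Q(u) from P(s, s) = I up to u = t with classical RK4."""
    n = len(q(s))
    p = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    h = (t - s) / steps
    u = s
    for _ in range(steps):
        k1 = matmul(p, q(u))
        k2 = matmul(add_scaled(p, h / 2, k1), q(u + h / 2))
        k3 = matmul(add_scaled(p, h / 2, k2), q(u + h / 2))
        k4 = matmul(add_scaled(p, h, k3), q(u + h))
        p = [[p[i][j] + h / 6 * (k1[i][j] + 2 * k2[i][j] + 2 * k3[i][j] + k4[i][j])
              for j in range(n)] for i in range(n)]
        u += h
    return p


def q_illness_death(u: float) -> Matrix:
    """Hypothetical constant rates: Health->Illness 0.3, Health->Death 0.1, Illness->Death 0.5."""
    return [[-0.4, 0.3, 0.1],
            [0.0, -0.5, 0.5],
            [0.0, 0.0, 0.0]]


P = solve_kfe(q_illness_death, 0.0, 2.0)
print([round(v, 4) for v in P[0]])  # occupation probabilities at t = 2 when starting in Health
```

Because every row of $Q$ sums to zero, each Runge–Kutta increment also has zero row sums, so the rows of the numerical $P$ keep summing to one up to rounding.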

3.2 Right-censoring

As alluded to earlier, censoring is common in survival models and requires an adjustment of the likelihood function. Assuming independence of the censoring process, we observe $\big(x, n, \{(t_i, s_i)\}_{i=0}^{n}, \delta\big)$, where $x$ are individual covariates or regressors, $n$ is the number of transitions the individual goes through, and $(t_i, s_i)$ are as above or, for the last entry, the state at the time of last contact (censoring time). Censoring is indicated by $\delta = 0$, whereas we write $\delta = 1$ if the event is observed. The corresponding likelihood can then be written as

Remark (Left- and interval-censoring).

Left-censoring, where participants are unobserved at the beginning, or interval censoring, which is a combination of left- and right-censoring, can be dealt with similarly, see e.g. van2016multi, thereby enabling the modeling of partially observed Markov processes.

We note that the above likelihood already accommodates left-truncation, where a patient enters the study at a later time but is known to be in a certain state up until that point, for example to control for immortal time bias.
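The observed tuple $(x, n, \{(t_i, s_i)\}_{i=0}^{n}, \delta)$ can be held in a small container. The class below is a hypothetical sketch of that bookkeeping; in particular, the convention that the last entry records the state at last contact when $\delta = 0$ follows our reading of the text, and all names are illustrative.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class MultiStateObservation:
    """Hypothetical container for (x, n, {(t_i, s_i)}_{i=0..n}, delta)."""
    covariates: Sequence[float]      # x
    path: List[Tuple[float, int]]    # [(t_0, s_0), ..., (t_n, s_n)]
    delta: int                       # 1: last entry is an observed event, 0: state at last contact

    def __post_init__(self) -> None:
        times = [t for t, _ in self.path]
        if any(later <= earlier for earlier, later in zip(times, times[1:])):
            raise ValueError("observation times must be strictly increasing")
        if self.delta not in (0, 1):
            raise ValueError("delta must be 0 (censored) or 1 (observed)")

    @property
    def n(self) -> int:
        """Recorded steps after the initial state (the last is a censoring record if delta == 0)."""
        return len(self.path) - 1

    def intervals(self) -> List[Tuple[float, float, int, int]]:
        """Consecutive (t_{i-1}, t_i, s_{i-1}, s_i) tuples between records."""
        return [(t0, t1, s0, s1) for (t0, s0), (t1, s1) in zip(self.path, self.path[1:])]


# Healthy at time 0, ill at 1.5, still ill when last seen at 4.0 (right-censored).
obs = MultiStateObservation(covariates=[0.2, 1.0], path=[(0.0, 1), (1.5, 2), (4.0, 2)], delta=0)
print(obs.n, obs.intervals())
```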

4 survNode: neural ODEs for multi-state modelling

We model the instantaneous transition rate matrix $Q(t)$ with a neural network. Conservation of probability means that the elements of the transition rate matrix need to fulfill

$$\sum_{j} q_{ij}(t) = 0 \quad \text{for all } i.$$

This restriction can be implemented by modeling $q_{ij}$ for $i \ne j$ with the neural network and setting $q_{ii} = -\sum_{j \ne i} q_{ij}$. As we need the off-diagonal transition rates to be larger than $0$, we use a softplus activation on the last layer of the network.
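A minimal sketch of this parametrization: raw network outputs for the permitted off-diagonal entries are passed through softplus, placed into $Q$, and each diagonal entry is set to minus the sum of the off-diagonal entries in its row. The list of permitted transitions and the raw values are hypothetical stand-ins for the last layer of the network.

```python
import math
from typing import List, Sequence, Tuple


def softplus(v: float) -> float:
    """Numerically stable log(1 + exp(v)), always > 0."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def build_rate_matrix(raw: Sequence[float], allowed: Sequence[Tuple[int, int]],
                      n_states: int) -> List[List[float]]:
    """Set q_ij = softplus(raw) for allowed (i, j), i != j, and q_ii = -sum_{j != i} q_ij."""
    if len(raw) != len(allowed):
        raise ValueError("need one raw output per allowed transition")
    q = [[0.0] * n_states for _ in range(n_states)]
    for value, (i, j) in zip(raw, allowed):
        if i == j:
            raise ValueError("only off-diagonal entries are modeled")
        q[i][j] = softplus(value)
    for i in range(n_states):
        q[i][i] = -sum(q[i][j] for j in range(n_states) if j != i)
    return q


# Illness-death structure (0-indexed): Health->Illness, Health->Death, Illness->Death.
allowed = [(0, 1), (0, 2), (1, 2)]
Q = build_rate_matrix([0.0, -1.0, 2.0], allowed, n_states=3)
print([[round(v, 4) for v in row] for row in Q])
```

Entries that are not listed stay exactly zero, so structurally impossible transitions (for example out of the absorbing Death state) never receive probability mass.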

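To incorporate the covariates, we use the following approach. Instead of only modeling the Markovian transition rate $Q(t)$, we incorporate the history of the evolution and the covariates of individual $k$ as $Q(t, m_k(t))$. For this, we introduce auxiliary memory states $m(t) \in \mathbb{R}^{M}$, governed by the differential equation

$$\frac{d}{dt} m(t) = g\big(t, m(t)\big).$$

The initial conditions are encoded by the covariates $x$ of the patient, $m(0) = h(x)$, where $h$ is given by a neural network. We can then obtain the system of coupled ODEs

$$\begin{aligned}
\frac{d}{dt} P(0,t) &= P(0,t)\,Q\big(t, m(t)\big),\\
\frac{d}{dt} P(t,0) &= -\,Q\big(t, m(t)\big)\,P(t,0),\\
\frac{d}{dt} m(t) &= g\big(t, m(t)\big),
\end{aligned}$$

with $P(0,0) = I$ and $m(0) = h(x)$, where the second line is the Kolmogorov backward equation for $P(t,0) := P(0,t)^{-1}$. Using that $P(0,t) = P(0,s)\,P(s,t)$ and therefore $P(s,t) = P(s,0)\,P(0,t)$, we obtain $P(s,t)$ at any $s$ and $t$.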
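The text states that the coupled system contains the Kolmogorov backward equation and that composability yields $P(s,t)$ for any $s$ and $t$. Assuming the backward component is written for $P(t,0) := P(0,t)^{-1}$ (our reading; $P(0,t)$ is invertible because, by Liouville's formula, $\det P(0,t) = \exp\big(\int_0^t \operatorname{tr} Q(u)\,du\big) > 0$), the following short derivation from Eq. (2) and the Chapman–Kolmogorov identity shows why both statements hold.

```latex
\begin{align*}
0 = \frac{d}{dt}\Big[P(0,t)\,P(0,t)^{-1}\Big]
  &= \underbrace{P(0,t)\,Q(t)}_{\text{Eq. (2)}}\,P(0,t)^{-1} + P(0,t)\,\frac{d}{dt}P(0,t)^{-1} \\
\Longrightarrow\quad \frac{d}{dt}P(t,0) &= -\,Q(t)\,P(t,0), \qquad P(t,0)\big|_{t=0} = I, \\
P(0,t) = P(0,s)\,P(s,t) \;\Longrightarrow\; P(s,t) &= P(s,0)\,P(0,t), \qquad 0 \le s \le t .
\end{align*}
```

The second line is the backward equation $\partial_s P(s,t) = -Q(s)\,P(s,t)$ evaluated at $t = 0$, so a single forward pass over $[0, T]$ yields both factors needed for any interval.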

We model both $Q$ and $g$ with one neural network $f$, where the first $N_Q$ (the number of non-zero off-diagonal elements of $Q$) outputs of the last layer are passed through a softplus non-linearity. The neural network can also explicitly depend on the covariates $x$; however, we did not find this to improve the results. This generalizes the approach in chen2018neural and shares some conceptual ideas with jia2019neural. The memory states can also be interpreted as an augmentation of the neural ODE with additional states, as in dupont2019augmented. The procedure is summarized in Algorithm 1.

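Algorithm 1: Obtain $P(t_{i-1}, t_i)$ and $Q(t_i)$ in survNode
Input: covariates $x$, time interval $[t_{i-1}, t_i]$.
$m(0) = h(x)$, $P(0,0) = I$    # Get initial values.
def KFE_KBE($t$, $P(0,t)$, $P(t,0)$, $m(t)$):    # Kolmogorov forward and backward equation.
      $q_{ij}$ ($i \ne j$), $g$ = $f(t, m(t))$ from NN with softplus for $q_{ij}$
      $q_{ii} = -\sum_{j \ne i} q_{ij}$    # Enforce constraints.
      $\dot P(0,t) = P(0,t)\,Q$    # Calculate gradient for Kolmogorov forward equation.
      $\dot P(t,0) = -\,Q\,P(t,0)$    # Calculate gradient for Kolmogorov backward equation.
      $\dot m(t) = g$    # Calculate gradient for augmented evolution.
      return $\dot P(0,t)$, $\dot P(t,0)$, $\dot m(t)$    # Return derivatives.
$P(0,t)$, $P(t,0)$, $m(t)$ = ODEsolve(KFE_KBE, initial values) for $t \in \{t_{i-1}, t_i\}$
$Q(t_i)$ from $f(t_i, m(t_i))$ as above    # Get the instantaneous transition rate.
$P(t_{i-1}, t_i) = P(t_{i-1}, 0)\,P(0, t_i)$    # Use the composability to get $P(t_{i-1}, t_i)$.
Output: $P(t_{i-1}, t_i)$, $Q(t_i)$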
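The sketch below mirrors the structure of Algorithm 1 in plain Python: an augmented state holding $P(0,t)$, $P(t,0)$ and the memory states $m(t)$ is integrated along one trajectory, and $P(t_{i-1}, t_i)$ is recovered by composability. The encoder `h`, the network `f`, the illness-death structure and all numeric constants are hypothetical stand-ins for trained MLPs, and fixed-step RK4 replaces the adaptive dopri5 solver, so this illustrates the data flow rather than reproducing survNode.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]
State = Tuple[Matrix, Matrix, List[float]]   # (P(0,t), P(t,0), m(t))
ALLOWED = [(0, 1), (0, 2), (1, 2)]            # hypothetical illness-death transitions, 0-indexed
N_STATES = 3

def softplus(v: float) -> float:
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def h(x: Sequence[float]) -> List[float]:
    """Hypothetical encoder: covariates x -> initial memory states m(0)."""
    return [math.tanh(x[0]), math.tanh(x[0] - x[1])]

def f(t: float, m: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Hypothetical network: raw (pre-softplus) rates for ALLOWED and dm/dt."""
    raw = [m[0] - 1.0, -2.0 + 0.5 * t, m[1]]
    dm = [-0.5 * m[0], 0.1 * m[0] - 0.2 * m[1]]
    return raw, dm

def rate_matrix(t: float, m: Sequence[float]) -> Matrix:
    raw, _ = f(t, m)
    q = [[0.0] * N_STATES for _ in range(N_STATES)]
    for value, (i, j) in zip(raw, ALLOWED):
        q[i][j] = softplus(value)                                   # positive off-diagonal rates
    for i in range(N_STATES):
        q[i][i] = -sum(q[i][j] for j in range(N_STATES) if j != i)  # enforce zero row sums
    return q

def kfe_kbe(t: float, state: State) -> State:
    """Derivatives of the augmented state (forward, backward, memory)."""
    p_fwd, p_bwd, m = state
    q = rate_matrix(t, m)
    d_fwd = matmul(p_fwd, q)                                  # dP(0,t)/dt = P(0,t) Q
    d_bwd = [[-v for v in row] for row in matmul(q, p_bwd)]   # dP(t,0)/dt = -Q P(t,0)
    return d_fwd, d_bwd, f(t, m)[1]                           # dm/dt from the network

def axpy(state: State, deriv: State, scale: float) -> State:
    """Component-wise state + scale * deriv."""
    def mat(x: Matrix, dx: Matrix) -> Matrix:
        return [[v + scale * dv for v, dv in zip(r, dr)] for r, dr in zip(x, dx)]
    return (mat(state[0], deriv[0]), mat(state[1], deriv[1]),
            [v + scale * dv for v, dv in zip(state[2], deriv[2])])

def integrate(state: State, t0: float, t1: float, steps: int = 200) -> State:
    """Fixed-step RK4 from t0 to t1 (stand-in for an adaptive solver)."""
    dt, t = (t1 - t0) / steps, t0
    for _ in range(steps):
        k1 = kfe_kbe(t, state)
        k2 = kfe_kbe(t + dt / 2, axpy(state, k1, dt / 2))
        k3 = kfe_kbe(t + dt / 2, axpy(state, k2, dt / 2))
        k4 = kfe_kbe(t + dt, axpy(state, k3, dt))
        for k, w in ((k1, 1.0), (k2, 2.0), (k3, 2.0), (k4, 1.0)):
            state = axpy(state, k, w * dt / 6)
        t += dt
    return state

def initial_state(x: Sequence[float]) -> State:
    eye = [[1.0 if i == j else 0.0 for j in range(N_STATES)] for i in range(N_STATES)]
    return [r[:] for r in eye], [r[:] for r in eye], h(x)

def survnode_step(x: Sequence[float], t_prev: float, t_next: float) -> Tuple[Matrix, Matrix]:
    """P(t_prev, t_next) = P(t_prev, 0) P(0, t_next), plus Q(t_next)."""
    s_prev = integrate(initial_state(x), 0.0, t_prev)
    s_next = integrate(s_prev, t_prev, t_next)          # continue the same trajectory
    return matmul(s_prev[1], s_next[0]), rate_matrix(t_next, s_next[2])

P_step, Q_next = survnode_step([0.3, -0.4], 1.0, 2.5)
print([[round(v, 4) for v in row] for row in P_step])
```

Integrating the backward component alongside the forward one avoids an explicit matrix inversion: $P(t_{i-1}, 0)$ is simply read off the same trajectory.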

Following massaroli2020dissecting, we furthermore add a loss term on the time-evolved memory states at the maximum time of the training batch, which can be seen as a modification of minimizing a Lyapunov exponent, such that comparable initial values produce comparable survival.

With this model, we also have direct access to the hazard rates over time. By predicting the hazard rates over time for each possible value of, e.g., a binary feature and taking their ratio, we obtain a personalized predictive score for the influence of that feature on the transition rates between states, which is important, for example, for predicting treatment effects or identifying biomarkers in a clinical setting.
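A sketch of this hazard-ratio score: evaluate the predicted rate with the binary feature switched on and off on a time grid, and take the ratio at each time. Here `predicted_rate` is a hypothetical stand-in for a trained model's $q_{ij}(t \mid x)$, chosen with a feature effect that fades over time so that the ratio visibly departs from proportional hazards.

```python
import math
from typing import Callable, List, Sequence


def hazard_ratio_curve(rate: Callable[[float, Sequence[float]], float], x: Sequence[float],
                       feature: int, times: Sequence[float]) -> List[float]:
    """q(t | x with feature = 1) / q(t | x with feature = 0) at each time point."""
    x_on, x_off = list(x), list(x)
    x_on[feature], x_off[feature] = 1.0, 0.0
    return [rate(t, x_on) / rate(t, x_off) for t in times]


def predicted_rate(t: float, x: Sequence[float]) -> float:
    """Hypothetical rate whose feature-0 effect fades over time (non-proportional)."""
    return 0.1 * math.exp(0.8 * x[0] * math.exp(-t) + 0.3 * x[1])


ratios = hazard_ratio_curve(predicted_rate, x=[0.0, 1.2], feature=0, times=[0.0, 1.0, 3.0])
print([round(r, 3) for r in ratios])
```

Under a Cox model this curve would be flat; a time-varying ratio is exactly the kind of individual effect the text refers to.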

Because the covariates are encoded into the initial values of the memory states, this model naturally extends to features based on longitudinal data, text data or imaging data, by encoding the initial values with recurrent neural network layers, natural language processing layers or convolutional layers, respectively, and training those jointly. The model is implemented in PyTorch pytorch using the torchdiffeq chen2018neural package.

5 Variational survNode: modeling uncertainty

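To quantify model uncertainty, we further extend the model to a variational setting by introducing latent variables. Instead of maximum likelihood estimation, the objective is the variational free energy, or evidence lower bound (ELBO) elbo. The variational model assumes the existence of a latent state $z$, which replaces the role of the memory state $m$ above, such that $Q$ does not depend on the covariates $x$ given $z$. The objective is then

$$\mathcal{L}_{\mathrm{ELBO}} = \mathbb{E}_{q(z \mid x, \mathcal{D})}\big[\log p(\mathcal{D} \mid z)\big] - \beta\,\mathrm{KL}\big(q(z \mid x, \mathcal{D}) \,\big\|\, p(z \mid x)\big),$$

where $\mathcal{D}$ denotes the observed (possibly censored) trajectory, $\beta$ is the ELBO parameter weighting the two terms (Appendix A), and we model the variational distribution and the prior as

$$q(z \mid x, \mathcal{D}) = \mathcal{N}\big(\mu_q, \operatorname{diag}(\sigma_q^2)\big), \qquad p(z \mid x) = \mathcal{N}\big(\mu_p(x), \operatorname{diag}(\sigma_p^2(x))\big),$$

with neural networks for $\mu_q$, $\sigma_q$, $\mu_p$ and $\sigma_p$, encoding the covariates into the latent space.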
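If both the variational distribution and the prior are diagonal Gaussians whose means and standard deviations come from neural networks (an assumption about the parametrization), the KL term of the ELBO has a closed form and $z$ can be drawn with the reparameterization trick. The sketch below shows both pieces; `beta` stands for the ELBO weight between log-likelihood and KL divergence discussed in Appendix A, and the log-likelihood value passed to `negative_elbo` is a placeholder for the survNode likelihood evaluated at the sampled $z$.

```python
import math
import random
from typing import List, Sequence


def kl_diag_gaussians(mu_q: Sequence[float], sigma_q: Sequence[float],
                      mu_p: Sequence[float], sigma_p: Sequence[float]) -> float:
    """KL( N(mu_q, diag sigma_q^2) || N(mu_p, diag sigma_p^2) ), summed over dimensions."""
    return sum(math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
               for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p))


def sample_z(mu: Sequence[float], sigma: Sequence[float], rng: random.Random) -> List[float]:
    """Reparameterized draw z = mu + sigma * eps with eps ~ N(0, I)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


def negative_elbo(log_likelihood: float, kl: float, beta: float = 1.0) -> float:
    """Loss to minimize: -(E_q[log-likelihood] - beta * KL)."""
    return -(log_likelihood - beta * kl)


rng = random.Random(0)
mu_q, sigma_q = [0.5, -0.2], [0.8, 1.1]   # would come from the posterior networks
mu_p, sigma_p = [0.0, 0.0], [1.0, 1.0]    # would come from the prior networks
z = sample_z(mu_q, sigma_q, rng)          # would replace m(0) as the ODE's latent input
kl = kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p)
print(round(kl, 4), negative_elbo(log_likelihood=-3.2, kl=kl, beta=1.0))
```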

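For prediction, we obtain realizations of the transition matrix by repeated sampling $z^{(l)} \sim p(z \mid x)$, $l = 1, \dots, L$, from the prior and taking the mean as well as the credible interval of

$$P^{(l)}(s,t) = \mathrm{ODEsolve}\big(z^{(l)}, s, t\big), \qquad \bar P(s,t) = \frac{1}{L}\sum_{l=1}^{L} P^{(l)}(s,t),$$

with the ODEsolve term specified in Algorithm 1.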
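The prediction step reduces to Monte Carlo summaries over prior draws. A minimal sketch, assuming each draw has already been turned into a scalar quantity such as a state occupation probability at a fixed time; the synthetic draws, the equal-tailed 95% interval and the linear interpolation between order statistics are our choices and not necessarily the paper's.

```python
import math
import random
from typing import Sequence, Tuple


def quantile(sorted_vals: Sequence[float], q: float) -> float:
    """Linear interpolation between order statistics (0 <= q <= 1)."""
    pos = q * (len(sorted_vals) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    return sorted_vals[lo] + (pos - lo) * (sorted_vals[hi] - sorted_vals[lo])


def mc_summary(draws: Sequence[float], level: float = 0.95) -> Tuple[float, float, float]:
    """Mean and equal-tailed credible interval of Monte Carlo draws."""
    vals = sorted(draws)
    alpha = (1.0 - level) / 2
    return sum(vals) / len(vals), quantile(vals, alpha), quantile(vals, 1.0 - alpha)


# Synthetic stand-in for z ~ p(z | x) followed by an ODE solve (a death probability at t = 2).
rng = random.Random(1)
draws = [1.0 - math.exp(-0.4 * math.exp(rng.gauss(0.0, 0.3)) * 2.0) for _ in range(2000)]
mean, lower, upper = mc_summary(draws)
print(round(mean, 3), round(lower, 3), round(upper, 3))
```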

6 Experiments

6.1 Survival: benchmark of model

To benchmark our proposed model against various survival frameworks, we examine the performance of survNode on the metabric breast cancer data set curtis2012genomic, pereira2016somatic, as well as the support data set knaus1995support. In Table 1, we present the concordance (c) antolini2005time, the integrated Brier score (ibs) brier1951verification, which measures the squared deviation between the predicted survival probabilities and the observed outcomes, and the integrated binomial log-likelihood (ibll), each estimated with five-fold cross-validation.

Model                          metabric                    support
                               c       ibs     ibll        c       ibs     ibll
Cox-PH cox1972regression       0.628   0.183   -0.538      0.598   0.217   -0.623
DeepSurv katzman2018deepsurv   0.636   0.176   -0.532      0.611   0.214   -0.619
Cox-Time kvamme2019time        0.662   0.172   -0.515      0.629   0.212   -0.613
DeepHit lee2018deephit         0.675   0.184   -0.539      0.642   0.223   -0.637
RSF ishwaran2008random         0.649   0.175   -0.515      0.634   0.212   -0.610
survNode                       0.667   0.157   -0.477      0.622   0.198   -0.580
Table 1: Benchmark of survNode (five-fold cross-validation; higher c and ibll and lower ibs are better). The results for the other models are taken from kvamme2019time. Note that the results of the competing methods benefit from more extensive hyperparameter optimization.

While the concordance solely evaluates a method’s discriminative performance, the Brier score and binomial log-likelihood also evaluate the calibration of the survival estimates. In the example of a diagnostic test in a clinical setting, a low integrated Brier score corresponds to a better predictive value of the diagnosis, meaning that the predicted probabilities of a positive or negative diagnosis are closer to the true underlying probabilities. A higher concordance, on the other hand, corresponds to a better classification into positive or negative diagnoses graf1999ibsvsconc. As can be seen in Table 1, our method outperforms all competitors in terms of calibration (ibs, ibll) on both data sets, while its discriminative performance, as measured by the concordance index, is competitive with the state of the art (second best on metabric and fourth best on support). An illustration of this is shown in the supplemental material. Although concordance is a common figure of merit, for clinical applications of predictive models in precision medicine it is arguably more important to have a well-calibrated probability for the event to provide the clinician with unbiased decision support graf1999ibsvsconc,hand1997construction,hilden1978probs.
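For intuition about the discrimination metric, the sketch below computes Harrell's concordance index for right-censored data: among comparable pairs (the earlier time is an observed event and the other time is strictly later), it counts how often the individual with the earlier event has the higher risk score. Table 1 reports the time-dependent concordance of antolini2005time, which compares predicted survival curves instead, so this simpler variant is shown for illustration only, on made-up data.

```python
from typing import Sequence


def harrell_c_index(times: Sequence[float], events: Sequence[int], risk: Sequence[float]) -> float:
    """Fraction of comparable pairs ordered correctly by risk (ties in risk count 1/2)."""
    concordant, comparable = 0.0, 0
    n = len(times)
    for i in range(n):
        if events[i] != 1:
            continue  # a pair is comparable only if the earlier time is an observed event
        for j in range(n):
            if times[j] > times[i]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


times = [2.0, 5.0, 3.0, 8.0]
events = [1, 0, 1, 1]
risk = [0.9, 0.3, 0.6, 0.1]   # higher score = higher predicted risk
print(harrell_c_index(times, events, risk))
```

The index only depends on the ordering of the scores, which is why a model can have a high concordance and still be poorly calibrated.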

6.2 Multi-state survival: competing hazards and the illness death model

To show the efficacy of our model in the multi-state setting, we first turn to the simplest extension of survival models, the competing hazards model (Figure 1(a)). In this setting, all states the model can transition to are absorbing, and hence there are no intermediate states. In this specific multi-state case, we can benchmark our model against the standard tools for competing hazards analysis: cause-specific Cox models, where a Cox proportional hazards model is fit for each transition, treating all other transitions as censored; the Fine–Gray model fine1999proportional; DeepHit lee2018deephit; and DeepHit with an additional loss term that specifically improves concordance at the cost of worse calibration kvamme2019time. We benchmark on the synthetic data set provided by lee2018deephit with two possible competing outcomes. Our results for cause-specific concordance are compared with those from lee2018deephit in Table 2.

Model                                          Concordance (cause 1)    Concordance (cause 2)
Cause-specific Cox model cox1972regression
Fine–Gray fine1999proportional
DeepHit lee2018deephit
DeepHit (with ranking loss) lee2018deephit
survNode (this paper)
Table 2: Benchmark in the competing hazards case with five-fold cross-validation.

We note that we use an ad-hoc hyperparameter setting and do not perform hyperparameter tuning using the validation set, in contrast to the analyses in lee2018deephit. Nevertheless, our model outperforms the conventional Cox models and is competitive with DeepHit without the ranking loss, thus demonstrating the robustness of our method.
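In the competing hazards model of Figure 1(a), the natural per-cause quantity is the cumulative incidence $F_k(t) = \int_0^t \lambda_k(u)\,S(u)\,du$ with $S(u) = \exp\big(-\int_0^u \sum_l \lambda_l(v)\,dv\big)$, which coincides with the Health-to-cause-$k$ entry of the Kolmogorov forward solution. The sketch below evaluates it numerically for hypothetical constant cause-specific hazards and compares the result with the closed form $\frac{\lambda_k}{\lambda_1+\lambda_2}\big(1 - e^{-(\lambda_1+\lambda_2)t}\big)$.

```python
import math
from typing import Callable, List, Sequence


def cumulative_incidence(hazards: Sequence[Callable[[float], float]], t: float,
                         steps: int = 2000) -> List[float]:
    """F_k(t) = int_0^t lambda_k(u) S(u) du via the trapezoidal rule on a uniform grid."""
    h = t / steps
    grid = [i * h for i in range(steps + 1)]
    total = [sum(lam(u) for lam in hazards) for u in grid]
    cum = [0.0]  # cumulative all-cause hazard on the grid
    for i in range(1, steps + 1):
        cum.append(cum[-1] + 0.5 * h * (total[i - 1] + total[i]))
    surv = [math.exp(-c) for c in cum]  # overall survival S(u), i.e. still in Health
    result = []
    for lam in hazards:
        integrand = [lam(u) * s for u, s in zip(grid, surv)]
        result.append(h * (sum(integrand) - 0.5 * (integrand[0] + integrand[-1])))
    return result


lam1, lam2 = 0.2, 0.05   # hypothetical constant cause-specific hazards
cif = cumulative_incidence([lambda u: lam1, lambda u: lam2], t=4.0)
closed = [lam / (lam1 + lam2) * (1 - math.exp(-(lam1 + lam2) * 4.0)) for lam in (lam1, lam2)]
print([round(v, 5) for v in cif], [round(v, 5) for v in closed])
```

Each cumulative incidence is damped by the survival term, which depends on all causes; treating the other cause as censored would drop that damping and overstate the incidence.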

To visualize the advantage of our model over the standard tools in the multi-state setting, we compare it with a non-parametric estimator for an illness-death model (Figure 1(b)). In the multi-state setting, a population mean for the probabilities of each state can be obtained with the Aalen–Johansen estimator aalen1978empirical. We simulate a data set with a proportional hazards violation using the coxed R package harden_kropko_2019 (see supplementary material). We compare our model with the standard tool in the multi-state survival literature, which fits a Cox proportional hazards model to each transition, treating the other events as censored. Importantly, this assumes independence between the occurring events, which, while true for the simulated data set, is often not the case in real-world scenarios. We use the R package mstate mstate to obtain the state probabilities at each time. The comparison is shown in Figure 2(a), and we see a clear advantage of our model over the cause-specific Cox model.
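For reference, the Aalen–Johansen estimator is the product integral $\hat P(0,t) = \prod_{u \le t}\big(I + \Delta\hat A(u)\big)$, where $\Delta\hat A_{ij}(u)$ is the number of observed $i \to j$ transitions at time $u$ divided by the number of individuals at risk in state $i$ just before $u$, and $\Delta\hat A_{ii} = -\sum_{j \ne i}\Delta\hat A_{ij}$. The plain-Python sketch below assumes that everyone starts in state 0 at time 0, that transition times are observed exactly and are strictly positive, and that there is no left-truncation; it is an illustration, not the implementation used for the figures.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

# One individual: ([(time, from_state, to_state), ...], end_of_follow_up); all start in state 0.
History = Tuple[List[Tuple[float, int, int]], float]


def state_before(history: History, u: float) -> int:
    """State occupied just before time u (state 0 if there is no earlier transition)."""
    current = 0
    for v, _, j in history[0]:
        if v < u:
            current = j
    return current


def aalen_johansen(histories: Sequence[History], n_states: int, t: float) -> List[float]:
    """Aalen-Johansen estimate of P(X(t) = j), j = 0..n_states-1, from exact transition times."""
    counts: Dict[float, Dict[Tuple[int, int], int]] = defaultdict(lambda: defaultdict(int))
    for transitions, _ in histories:
        for u, i, j in transitions:
            counts[u][(i, j)] += 1
    p = [1.0] + [0.0] * (n_states - 1)
    for u in sorted(counts):
        if u > t:
            break
        at_risk = [0] * n_states
        for hist in histories:
            if u <= hist[1]:                  # still under observation at u
                at_risk[state_before(hist, u)] += 1
        # Nelson-Aalen increments: i -> j transitions at u / number at risk in i just before u
        d_a = [[0.0] * n_states for _ in range(n_states)]
        for (i, j), d in counts[u].items():
            d_a[i][j] += d / at_risk[i]
        for i in range(n_states):
            d_a[i][i] = -sum(d_a[i][j] for j in range(n_states) if j != i)
        # product-integral step: p <- p (I + dA)
        p = [p[j] + sum(p[i] * d_a[i][j] for i in range(n_states)) for j in range(n_states)]
    return p


# Two-state check (0 = alive, 1 = dead) with one censored individual: reduces to Kaplan-Meier.
km_data = [([(1.0, 0, 1)], 1.0), ([], 1.5), ([(2.0, 0, 1)], 2.0), ([(3.0, 0, 1)], 3.0)]
print(aalen_johansen(km_data, n_states=2, t=2.5))
```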

Figure 2: Probabilities of being in the different states of the illness-death model, indicated by color: blue corresponds to “Health”, red to “Illness” and green to “Death”. (a) survNode; (b) variational survNode. We compare to the cause-specific multi-state Cox model in addition to the non-parametric Aalen–Johansen estimator on the test data.

6.3 Latent multi state survival model

Benchmark

As a first step in analyzing the latent multi-state survival model, we benchmark it against other models for the special case of survival analysis on the metabric data set. Without any hyperparameter tuning and with an ad-hoc parameter choice, we obtain a concordance of , integrated Brier score of and integrated binomial log-likelihood of . As such, the variational survNode is better calibrated than all other available models (not including plain survNode), with competitive discrimination performance. We can visualize the prediction and credible interval by again comparing to the Aalen–Johansen estimator in the simulated illness-death model in Figure 2(b). For this, we trained the variational survNode on a training set with early stopping on a validation set and compared the predictions for the possible covariate values with the Aalen–Johansen estimators obtained on a test set.

Calibration of the credible intervals

While our model captures the non-parametric estimator by visual inspection, we seek to quantify its calibration performance. Again using the R package coxed, we simulate a survival data set with three covariates. From the coxed package, we also extract the underlying individual survival probabilities. To estimate the calibration of the error intervals, we then calculate the average fraction of times the true survival probabilities lie within the 95% credible interval. We compare the calibration of our model to the prediction from a Cox proportional hazards model using the R survival package survival-package, which implements the calculation of standard errors. Results of a five-fold cross-validation on one random realization of the simulated data are shown in Table 3.

Model calibration concordance
Cox proportional hazards model cox1972regression
survNode (this paper)
Table 3: Comparison of survNode and the Cox proportional hazards model in terms of calibration and concordance.

We find that our model produces more conservative error intervals and is better calibrated.
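The calibration number in Table 3 is an empirical coverage: the fraction of evaluation points at which the true survival probability from the simulation lies inside the model's 95% interval. A minimal sketch with made-up values standing in for the simulated truths and model intervals:

```python
from typing import Sequence


def empirical_coverage(truth: Sequence[float], lower: Sequence[float],
                       upper: Sequence[float]) -> float:
    """Fraction of points whose true value lies in the closed interval [lower, upper]."""
    if not (len(truth) == len(lower) == len(upper)) or not truth:
        raise ValueError("inputs must be non-empty and of equal length")
    hits = sum(1 for s, lo, hi in zip(truth, lower, upper) if lo <= s <= hi)
    return hits / len(truth)


# Hypothetical true survival probabilities and 95% intervals at a few evaluation points.
truth = [0.91, 0.74, 0.52, 0.33, 0.12]
lower = [0.88, 0.70, 0.55, 0.25, 0.05]
upper = [0.95, 0.80, 0.65, 0.40, 0.20]
print(empirical_coverage(truth, lower, upper))
```

For well-calibrated 95% intervals this fraction should be close to 0.95 on large samples; values above it indicate conservative intervals.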

Clustering of the latent space

Figure 3: The latent space of the variational survNode model exhibits unsupervised clustering into meaningful clusters. The patients in each of the clusters shown on the left are used in a non-parametric Aalen–Johansen estimator to show clear differences in the state probabilities depending on the cluster. State 1 is the initial "Health" state, 2 corresponds to "Illness" and 3 to the absorbing "Death" state.

On top of the estimation of credible intervals, an additional useful feature of the latent variable model emerges from inspecting the model's latent space. We again simulate an illness-death data set with coxed, using three covariates, each of which we bin into three quantile groups. We again train the variational survNode model with early stopping on a validation set and then inspect the latent space for the validation data. Using UMAP mcinnes2018umap-software, we identify five clusters (Figure 3). We examine the probabilities of being in each of the "Health", "Illness" and "Death" states for each cluster in the validation data set using the non-parametric Aalen–Johansen estimator. As can be seen in Figure 3, the clusters provide a meaningful unsupervised differentiation between patients and capture differences in survival as well as in transitioning to the "Illness" state. We can additionally obtain covariate effects associated with each cluster by using logistic regression. This feature has useful applications in a clinical setting, where identifying extreme survivors under a treatment, while modeling other state transitions, is of particular interest.

Our approach is directly applicable to survival analysis, where methods based, for example, on LDA Chapfuwa_2020 were recently proposed to cluster the latent space, and it generalizes such methods to the multi-state setting.

7 Conclusion

We have introduced a general and flexible method for multi-state survival analysis based on neural ODEs and shown state-of-the-art performance in the special cases of survival and competing risk analysis. By introducing latent states, we further obtain credible intervals for the model predictions as well as a degree of interpretability.

Broader Impact

We were motivated by the explicit desire to solve a problem in the field of cancer immunotherapy, where treatment outcomes are highly heterogeneous and the proportional hazard assumption is often violated checkmate. We are interested in multi-state modeling to predict the onset of immune related adverse events in competition with mortality, survival after developing adverse events, and decision support for further re-challenge with immunotherapy. It is our hope that the prediction aspect of the model presented here can be used to support clinical decision making by anticipating potential interventions prior to the onset of harm. Reliable modeling of uncertainty in the prediction is thus crucial for this practical application. We further plan to use the latent space clustering, in conjunction with biological covariates, for sub-phenotyping, and biomarker discovery for adverse events. We believe our model is broadly applicable in clinical settings with dynamic events and latent sub-populations, and hope to have a meaningful impact on AI-based precision medicine.

Clinical decision support also comes with risk, as model failure can be hazardous for the patient, either by ending a successful treatment needlessly early or by continuing a treatment that has failed and leads to harmful side effects. Actions based on model predictions must thus consider uncertainty in the prediction itself as well as the broader harm/benefit trade-offs. In the specific case of immunotherapy-related adverse events, interventions such as increased patient monitoring/contact may pose limited risk of harm. However, much uncertainty remains over the ethical use of survival prediction: as immunotherapy offers the potential of permanent elimination of disease (in contrast to many treatments of advanced cancers), patients may be more inclined to continue treatment against the guidance of a predictive model, in the hope of a cure.

As with any statistical method, the output relies on the quality of the underlying data, e.g. electronic health records (EHR). Due to systemic biases, there is often a difference in the availability and accuracy of EHR between different communities, and special care has to be taken that computational models are not themselves biased by data quality, thereby further exacerbating existing disparities. The algorithm may exploit biases in the data, leading to decisions that would be considered discriminatory behavior by a human counterpart. Identifying such biases necessitates model interpretability through clustering and feature importance kovalev2020survlime, as well as assessment in diverse populations.

S.G. and A.G. are supported by the National Cancer Institute and the Louis B. Mayer Foundation. S.M.S. is partially supported by the Engineering and Physical Sciences Research Council (EPSRC) grant EP/K503113/1. This material is based upon work supported by Google Cloud.

Supplementary Material: Neural ODEs for Multi-State Survival Analysis

Appendix A Implementation details

All models are implemented in PyTorch v1.4 pytorch using the torchdiffeq package chen2018neural. As our example networks are sufficiently small, we use backpropagation through the ODE solver to obtain gradients; however, using the adjoint method is also possible. We use the dopri5 method for the ODE solver with an absolute and relative tolerance of . To include the accuracy of the solution as a hyperparameter, we scale the event times to have the maximum value , which we choose to be of . To specify the non-zero elements of the transition rate matrix, a matrix with 1 indicators for non-zero off-diagonal elements and NaN indicators for all other elements is needed.
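One way to read the mask convention: a square matrix with 1 at each permitted off-diagonal transition and NaN everywhere else fixes both the positions of the network outputs and their number $N_Q$. The parser below is a hypothetical sketch of that bookkeeping, not the authors' code.

```python
import math
from typing import List, Sequence, Tuple


def allowed_transitions(mask: Sequence[Sequence[float]]) -> List[Tuple[int, int]]:
    """Row-major list of (i, j) with mask[i][j] == 1; NaN marks structurally zero rates."""
    n = len(mask)
    pairs = []
    for i, row in enumerate(mask):
        if len(row) != n:
            raise ValueError("mask must be square")
        for j, value in enumerate(row):
            if math.isnan(value):
                continue
            if value != 1 or i == j:
                raise ValueError(f"unexpected entry {value!r} at ({i}, {j})")
            pairs.append((i, j))
    return pairs


nan = float("nan")
illness_death_mask = [[nan, 1.0, 1.0],
                      [nan, nan, 1.0],
                      [nan, nan, nan]]
pairs = allowed_transitions(illness_death_mask)
print(pairs, "N_Q =", len(pairs))
```

The diagonal is always NaN in this convention because it is never a network output; it is filled afterwards from the row sums.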

For training the model minimizing the negative log likelihood, the hyperparameters are:

  • Number of layers and number of neurons per layer with dropout srivastava2014dropout for multilayer perceptron encoding the covariates into memory states;

  • Number of layers and number of neurons per layer for multilayer perceptron modeling ;

  • Number of memory states ;

  • Coefficient of Lyapunov style loss term ;

  • Scaling coefficient for event times ;

  • Learning rate of the Adam optimizer kingma2014adam;

  • Weight decay .

For the variational approach minimizing the ELBO, we have the hyperparameters:

  • Number of layers and number of neurons per layer with dropout for multilayer perceptron for prior ;

  • Number of layers and number of neurons per layer with dropout for multilayer perceptron for variational posterior ;

  • Number of layers and number of neurons per layer for multilayer perceptron modeling ;

  • Number of latent states ;

  • Coefficient of Lyapunov style loss term ;

  • ELBO parameter ;

  • Scaling coefficient for event times ;

  • Learning rate of the Adam optimizer;

  • Weight decay ,

where the ELBO parameter characterizes the relative weight between log-likelihood and Kullback–Leibler divergence, which we set to be throughout the paper. Closer investigation of the clustering property with respect to this parameter would be of interest.

Because the Kolmogorov backward equation is solved together with the Kolmogorov forward equation, we only need to solve the ODE once per batch, from $0$ to the maximum time $T$ in the batch, to obtain $P(t_{i-1}, t_i)$ for all patients in the batch: we store the solution at the union of all $t_{i-1}$ and $t_i$ and then compute

$$P(t_{i-1}, t_i) = P(t_{i-1}, 0)\,P(0, t_i).$$
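The batching idea can be sketched as follows: collect the union of all $t_{i-1}$ and $t_i$ in a batch, evaluate $P(0,t)$ and $P(t,0)$ once on that sorted grid, and read off $P(t_{i-1}, t_i) = P(t_{i-1}, 0)\,P(0, t_i)$ for every interval. To keep the example short and checkable, the closed-form solution of a hypothetical two-state model with constant rate $\lambda$ stands in for the stored ODE solution; in survNode each patient has its own trajectory, since the rates depend on the memory states, but all patients share the same time grid.

```python
import math
from typing import Dict, List, Sequence, Tuple

Matrix = List[List[float]]
LAM = 0.3  # hypothetical constant rate for the transition 1 -> 2


def p_forward(t: float) -> Matrix:
    """P(0, t) for Q = [[-LAM, LAM], [0, 0]] (stands in for the stored forward solution)."""
    a = math.exp(-LAM * t)
    return [[a, 1.0 - a], [0.0, 1.0]]


def p_backward(t: float) -> Matrix:
    """P(t, 0) = P(0, t)^{-1} (stands in for the stored backward solution)."""
    b = math.exp(LAM * t)
    return [[b, 1.0 - b], [0.0, 1.0]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def batch_transitions(intervals: Sequence[Tuple[float, float]]) -> List[Matrix]:
    """Evaluate the solution once on the union of endpoints, then compose per interval."""
    grid = sorted({t for pair in intervals for t in pair})
    fwd: Dict[float, Matrix] = {t: p_forward(t) for t in grid}
    bwd: Dict[float, Matrix] = {t: p_backward(t) for t in grid}
    return [matmul(bwd[s], fwd[t]) for s, t in intervals]


batch = [(0.0, 1.0), (1.0, 2.5), (0.5, 2.5)]   # (t_{i-1}, t_i) for several patients
for (s, t), p in zip(batch, batch_transitions(batch)):
    print(s, t, [round(v, 4) for v in p[0]])
```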

We note that, because the differential equation solver is implemented in native Python code, the computational speed of the method could be improved. As shown in benchmarks rackauckas, the Julia programming language bezanson2017julia and the DiffEqFlux diffeqflux package might offer a good alternative, and we are currently working on a Julia implementation.

Appendix B Experiments

B.1 Simulation of data

The simulated data in the publication are generated with the R package coxed. In the survival cases, we choose three covariates, one of which has a time-varying coefficient to model a violation of the proportional hazards assumption. We choose all coefficients to be of , with a saw-tooth time dependence for the time-dependent covariate. We sample patients for the training set and patients each for the validation and test sets, with event times between and .
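To illustrate what a time-varying coefficient does to the data, the sketch below samples event times by inverse-transform sampling of the cumulative hazard on a time grid, with one covariate whose coefficient follows a saw-tooth in time, so that its hazard ratio is not constant. This is not the algorithm used by coxed; the baseline rate, coefficients, period, grid and maximum time are assumed values chosen for illustration, and the function names are hypothetical.

```python
import numpy as np

def sawtooth(t, period=10.0):
    """Hypothetical saw-tooth profile rising from 0 to 1 within each period."""
    return (t % period) / period

def simulate_non_ph(n, beta=(0.5, -0.3, 1.0), base_rate=0.02, t_max=100.0, seed=0):
    """Sample event times for three covariates, the third with a saw-tooth time-varying coefficient."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, 3))
    grid = np.linspace(0.0, t_max, 2_001)
    dt = grid[1] - grid[0]
    # Time-constant effects for covariates 0 and 1.
    lin = X[:, :2] @ np.asarray(beta[:2])
    # Covariate 2 enters with coefficient beta[2] * sawtooth(t): its hazard ratio changes over time.
    log_hazard = lin[:, None] + X[:, 2:3] * (beta[2] * sawtooth(grid))[None, :]
    hazard = base_rate * np.exp(log_hazard)
    # Cumulative hazard on the grid (left Riemann sum), starting at 0.
    cumhaz = np.concatenate([np.zeros((n, 1)), np.cumsum(hazard[:, :-1] * dt, axis=1)], axis=1)
    # Inverse transform: the event occurs when the cumulative hazard first exceeds -log(U).
    target = -np.log(rng.uniform(size=n))
    idx = np.array([np.searchsorted(row, e) for row, e in zip(cumhaz, target)])
    event = idx < grid.size
    # Patients whose cumulative hazard never reaches the target are censored at t_max.
    time = np.where(event, grid[np.minimum(idx, grid.size - 1)], t_max)
    return X, time, event.astype(int)
```

A Cox model fitted to such data assumes a single constant coefficient for the third covariate, which is exactly the misspecification the simulation is meant to create.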

In the case of the illness-death model, we sample every transition with the coxed package, assuming independence between transitions. We extract the covariates from the first sampled model and reuse them for the other two survival realizations, with different coefficients. Due to a limitation of the coxed package, only the first sampled model can have time-varying coefficients; the other transitions are then effectively sampled from a Cox model. For the competing transitions out of the "Health" state, to "Illness" and to "Death", we take the earlier of the two sampled times, regardless of whether it is censored. The maximum time for the generated data in the competing case is , whereas we choose for the transition from "Illness" to "Death".
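The construction of the illness-death data from independently sampled transitions can be sketched as follows. Each input is a list of hypothetical (time, event) pairs per patient, such as those returned by a single-transition simulator. The function name and the choice to measure the Illness to Death time from illness onset are assumptions, since the text does not state the time scale of that transition.

```python
def compose_illness_death(health_illness, health_death, illness_death):
    """Combine three independently sampled (time, event) realizations into multi-state records.

    Each argument is a list of (time, event) pairs, one per patient; event is 1 for an
    observed transition and 0 for censoring. Returns tuples
    (patient, from_state, to_state, start, stop, status).
    """
    records = []
    for i, ((t_hi, e_hi), (t_hd, e_hd), (t_id, e_id)) in enumerate(
            zip(health_illness, health_death, illness_death)):
        # Competing exits from "Health": keep the earlier time, censored or not.
        if t_hi <= t_hd:
            target, stop, status = "Illness", t_hi, e_hi
        else:
            target, stop, status = "Death", t_hd, e_hd
        records.append((i, "Health", target, 0.0, stop, status))
        # Assumption: the Illness -> Death time is measured from illness onset.
        if target == "Illness" and status == 1:
            records.append((i, "Illness", "Death", stop, stop + t_id, e_id))
    return records
```

Only patients who are observed to become ill contribute an "Illness" to "Death" record, so the third realization is used for a subset of patients only.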

B.2 Data sets and hyperparameters

The metabric and support data sets are standard benchmarks for survival analysis. Their characteristics, as reported in kvamme2019time, are shown in Table 4; both data sets are obtained from the pycox Python package kvamme2019time.

Data set    Size    Covariates    Unique durations    Prop. censored
support     8873    14            1714                0.32
metabric    1904    9             1686                0.42
Table 4: Characteristics of the metabric and support data sets.
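The columns of Table 4 can be computed from a data frame with one row per patient. The sketch assumes the layout used by pycox, with covariate columns plus a `duration` and an `event` column (event = 1 observed, 0 censored); the function name is ours.

```python
import pandas as pd

def characteristics(df: pd.DataFrame, duration_col="duration", event_col="event"):
    """Summarise a survival data frame with the columns of Table 4."""
    covariates = [c for c in df.columns if c not in (duration_col, event_col)]
    return {
        "Size": len(df),
        "Covariates": len(covariates),
        "Unique durations": df[duration_col].nunique(),
        # Censored patients are those with event == 0.
        "Prop. censored": round(1.0 - df[event_col].mean(), 2),
    }
```

With pycox installed, `from pycox.datasets import metabric` followed by `characteristics(metabric.read_df())` is expected to reproduce the metabric row, provided the column names match the assumption above.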

The synthetic data set in the competing hazards case is taken from lee2018deephit and is available on GitHub. It contains patients and two outcomes, where of the patients experience an event and the remaining are censored.

For all benchmark experiments we perform five-fold cross-validation: we split the data in an split into test data and remaining data, and split the remaining data again in an split into training and validation data.
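The nested split can be sketched as follows: each fold serves once as the test set, and the remaining data are shuffled and split into training and validation sets. Because the split ratios are not recoverable here, `val_fraction` is an assumed value (0.2), and the function name is hypothetical.

```python
import numpy as np

def cv_splits(n, n_folds=5, val_fraction=0.2, seed=0):
    """Yield (train, val, test) index arrays; val_fraction is an assumed value."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n), n_folds)
    for k in range(n_folds):
        test = folds[k]
        # Everything outside the test fold is split again into training and validation data.
        rest = rng.permutation(np.concatenate([folds[j] for j in range(n_folds) if j != k]))
        n_val = int(round(val_fraction * len(rest)))
        yield rest[n_val:], rest[:n_val], test
```

Early stopping and model selection then use only the validation indices, so the test fold stays untouched until evaluation.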

The hyperparameter space used in the benchmarks on metabric and support is

  • with and ;

  • with

  • ;

  • ;

  • ;

  • ;

  • .

We use random sampling from the hyperparameter space to get realizations of the hyperparameters. The batch size is taken to be either or the length of the data set, whichever is smaller.
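A minimal sketch of random search over a discrete grid, together with the batch-size rule above. The search space shown is a hypothetical placeholder rather than the grid listed for metabric and support, and `cap` stands in for the batch-size value that is not recoverable here.

```python
import random

# Hypothetical search space; parameter names and values are placeholders.
SPACE = {
    "n_layers": [1, 2, 3, 4],
    "n_neurons": [16, 32, 64, 128],
    "dropout": [0.0, 0.1, 0.2, 0.5],
    "lr": [1e-4, 1e-3, 1e-2],
    "weight_decay": [0.0, 1e-5, 1e-4],
}

def sample_configs(space, n_samples, seed=0):
    """Draw n_samples configurations, choosing each hyperparameter independently and uniformly."""
    rng = random.Random(seed)
    return [{name: rng.choice(values) for name, values in space.items()}
            for _ in range(n_samples)]

def batch_size(n_data, cap):
    """Use the fixed batch size `cap` unless the data set is smaller."""
    return min(cap, n_data)
```

Each sampled configuration is then trained and scored on the validation data, and the best-scoring one is kept.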

For the competing hazards experiment we use the hyperparameters

  • with and ;

  • with

  • ;

  • ;

  • ;

  • ;

  • .

For the comparison with the non-parametric Aalen–Johansen estimator, the hyperparameters used for the model minimizing the negative log likelihood were

  • with and ;

  • with

  • ;

  • ;

  • ;

  • ;

  • .

In the case of the latent model minimizing the ELBO we used

  • with and ;

  • with and ;

  • with

  • ;

  • ;

  • ;

  • ;

  • ;

  • ,

and, finally, for clustering the latent space we use the hyperparameter setting

  • with and ;

  • with and ;

  • with

  • ;

  • ;

  • ;

  • ;

  • ;

  • .

Appendix C Visualisation of calibration in the survival setting

We can examine the calibration of the model in the simple case of one binary covariate. In this case we can use the population-level non-parametric Kaplan–Meier estimator kaplan1958nonparametric to obtain the survival function . We use the R package coxed harden_kropko_2019 to simulate survival data with a proportional hazards violation and one binary variable var. We split the data set into training, validation and test sets and obtain the Kaplan–Meier estimator for both variable and on the test data. The survival model is trained on the training data with early stopping on the validation data, and its predictions for and are compared to the Kaplan–Meier estimator on the test data. We compare our model (SurvNODE) with a Cox proportional hazards model, a fully parametric accelerated failure time model based on the Weibull distribution collett2003modelling, as well as DeepHit lee2018deephit and Cox-Time kvamme2019time, a discrete-time and a continuous-time machine learning model, respectively. The visual comparison can be seen in Figure 4.
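For reference, the product-limit (Kaplan–Meier) estimator used as the non-parametric benchmark can be sketched in a few lines of NumPy. This is a minimal illustration rather than the implementation used for Figure 4, and the function name is ours; for the binary-covariate check it would be applied separately to the test-set patients with each value of var.

```python
import numpy as np

def kaplan_meier(durations, events):
    """Return (event_times, survival) for the product-limit estimator."""
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=bool)
    times = np.unique(durations[events])  # distinct observed event times
    surv, s = [], 1.0
    for t in times:
        at_risk = np.sum(durations >= t)           # still under observation just before t
        d = np.sum((durations == t) & events)      # events at t
        s *= 1.0 - d / at_risk
        surv.append(s)
    return times, np.array(surv)
```

Censored patients leave the risk set without producing a drop in the curve, which is what makes the estimator valid under non-informative censoring.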

Figure 4: Comparison of different tools to fit survival distributions in the population-level case of one binary covariate. On the left, our model is compared with the standard tools of survival analysis and the non-parametric Kaplan–Meier estimator of the test set. On the right, we compare our model on the same data with two state-of-the-art machine learning approaches, again with the Kaplan–Meier estimator as a non-parametric reference.

Due to the proportional hazards violation, neither the Cox model nor the model based on the parametric Weibull distribution captures the survival function well, whereas the SurvNODE model does. Among the other machine learning based frameworks, DeepHit does not reproduce the survival function well.

Appendix D Clustering: Covariates and survival strata

To further examine the clustering of the latent space, we superimpose the nine binary covariates in the model on the UMAP projection, as shown in Figure 5. Some of the clusters clearly reflect the covariates. For example, covariate one, which indicates the lowest third of the covariate with the largest effect size for one of the transitions in the simulation, has almost all of its values in one of the clusters. By characterizing the effect of the covariates on these clusters, which have specific survival properties, we can infer the influence of each covariate on survival.
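The visual impression from Figure 5 can be complemented by a simple numerical summary. The sketch below assumes that cluster labels have already been assigned to the points of the UMAP projection (how the clusters are obtained is not specified here) and, for each binary covariate, computes the fraction of patients with value 1 that fall into each cluster. A value close to 1 for a single cluster corresponds to the pattern described for covariate one. The function name is hypothetical.

```python
import numpy as np

def covariate_cluster_share(X_binary, labels):
    """For each binary covariate, the fraction of its positive cases that fall in each cluster."""
    X_binary = np.asarray(X_binary, dtype=bool)
    labels = np.asarray(labels)
    clusters = np.unique(labels)
    share = np.zeros((X_binary.shape[1], len(clusters)))
    for j in range(X_binary.shape[1]):
        pos = X_binary[:, j]
        if pos.any():  # rows stay zero for covariates that are never 1
            for c_idx, c in enumerate(clusters):
                share[j, c_idx] = np.mean(labels[pos] == c)
    return clusters, share
```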

Figure 5: Possible values of the nine binary covariates in the model. We see that some of the clusters clearly reflect the covariates.
Figure 6: Predicted survival function vs real underlying survival function from the simulation. We see that the credible intervals cover the underlying survival function well.

Appendix E Calibration of the credible intervals

A visual way to show the calibration of the credible intervals is to predict individual survival over time and plot it together with the true underlying survival function obtained from the coxed R package. This is shown in Figure 6: the credible intervals contain the true survival function in most cases.
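The visual check can be turned into a number by computing, for each individual, the fraction of time points at which the true survival curve lies inside a pointwise credible interval. The sketch assumes posterior survival samples of shape (samples, time points), for example from repeated draws of the latent variable, and equal-tailed intervals at an assumed 90% level; the function names are hypothetical.

```python
import numpy as np

def credible_interval(samples, level=0.9):
    """Pointwise equal-tailed interval from posterior samples of shape (S, T)."""
    alpha = (1.0 - level) / 2.0
    lower = np.quantile(samples, alpha, axis=0)
    upper = np.quantile(samples, 1.0 - alpha, axis=0)
    return lower, upper

def coverage(true_surv, lower, upper):
    """Fraction of time points at which the true survival lies inside the interval."""
    return np.mean((true_surv >= lower) & (true_surv <= upper))
```

Averaged over many simulated individuals, a well-calibrated model would give a coverage close to the nominal level.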