1 Introduction
The demand for methods evaluating the effect of treatments, policies and interventions on individuals is rising as interest moves from estimating population effects to understanding effect heterogeneity in fields ranging from economics to medicine. Motivated by this, the literature proposing machine learning (ML) methods for estimating the effects of treatments on continuous (or binary) endpoints has grown rapidly, most prominently using tree-based methods hill2011bayesian; athey2016recursive; wager2018estimation; athey2019generalized; hahn2017bayesian, Gaussian processes alaa2017bayesian; alaa2018limits, and, in particular, neural networks (NNs) johansson2016learning; Shalit:16; johansson2018learning; shi2019adapting; hassanpour2019counterfactual; hassanpour2020learning; assaad2021counterfactual; curth2020. In comparison, the ML literature on heterogeneous treatment effect (HTE) estimation with time-to-event outcomes is rather sparse. This is despite the immense practical relevance of this problem – e.g. many clinical studies consider time-to-event outcomes; this could be the time to onset or progression of disease, the time to occurrence of an adverse event such as a stroke or heart attack, or the time until death of a patient.

In part, the scarcity of HTE methods may be due to time-to-event outcomes being inherently more challenging to model, which is attributable to two factors tutz2016modeling: (i) time-to-event outcomes differ from standard regression targets, as the main objects of interest are usually not only expected survival times but the dynamics of the underlying stochastic process, captured by hazard and survival functions, and (ii) the presence of censoring. This has led to the development of a rich literature on survival analysis, particularly in (bio)statistics, see e.g. tutz2016modeling; klein2003survival. Classically, the effects of treatments in clinical studies with time-to-event outcomes are assessed by examining the coefficient of a treatment indicator in a (semi-)parametric model, e.g. the Cox proportional hazards model Cox:72, which relies on the often unrealistic assumption that the model is correctly specified. We therefore instead adopt the nonparametric viewpoint of van der Laan and colleagues van2011targeted; stitelman2010collaborative; stitelman2011targeted; cai2019targeted, who have developed tools to incorporate ML methods into the estimation of treatment-specific population average parameters. Nonparametric investigation of treatment effect heterogeneity, however, has been studied in much less detail in the survival context. While a number of tree-based methods have been proposed recently tabib2020non; henderson2020individualized; zhang2017mining; cui2020estimating, NN-based methods lack extensions to the time-to-event setting despite their successful adoption for estimating the effects of treatments on other outcomes – the only exception being chapfuwa2021enabling, who directly model event times under different treatments with generative models.

Instead of modeling event times directly as in chapfuwa2021enabling, we consider adapting machine learning methods, with special focus on NNs, for estimation of (discrete-time) treatment-specific hazard functions. We do so because many target parameters of interest in studies with time-to-event outcomes are functions of the underlying temporal dynamics; that is, hazard functions can be used to directly compute (differences in) survival functions, (restricted) mean survival time, and hazard ratios. We begin by exploring and characterizing the unique features of the survival treatment effect problem within the context of empirical risk minimization (ERM); to the best of our knowledge, such an investigation is lacking in previous work.
In particular, we show that learning treatment-specific hazard functions is a challenging problem due to the potential presence of multiple sources of covariate shift: (i) non-randomized treatment assignment (confounding), (ii) informative censoring and (iii) a form of shift we term event-induced covariate shift, all of which can impact the quality of hazard function estimates. We then theoretically analyze the effects of these shifts on ERM, and use our insights to propose a new NN-based model for treatment effect estimation in the survival context.
Contributions. (i) We identify and formalize key challenges of heterogeneous treatment effect estimation from time-to-event data within the framework of ERM. In particular, as discussed above, we show that multiple sources of covariate shift arise when estimating treatment-specific hazard functions. (ii) We theoretically analyze their effects by adapting recent generalization bounds from domain adaptation and treatment effect estimation to our setting, and discuss implications for model design. This analysis provides new insights that are also of independent interest for hazard estimation in the absence of treatments. (iii) Based on these insights, we propose a new model (SurvITE) relying on balanced representations that allows for estimation of treatment-specific target parameters (hazard and survival functions) in the survival context, as well as a sister model (SurvIHE), which can be used for individualized hazard estimation in standard survival settings (without treatments). We investigate performance across a range of experimental settings and empirically confirm that SurvITE outperforms a range of natural baselines by addressing covariate shifts from various sources.
2 Problem Definition
In this section, we discuss the problem setup of heterogeneous treatment effect estimation from time-to-event data, our target parameters and our assumptions. In Appendix A, we present a self-contained introduction to, and comparison with, heterogeneous treatment effect estimation with standard (binary/continuous) outcomes.
Problem setup. Assume we observe a time-to-event dataset $\mathcal{D} = \{(a_i, x_i, \tilde{\tau}_i, \delta_i)\}_{i=1}^{n}$ comprising realizations of the tuple $(A, X, \tilde{T}, \Delta)$ for $n$ patients. Here, $X \in \mathcal{X}$ and $A \in \{0, 1\}$ are random variables for a covariate vector describing patient characteristics and an indicator of whether a binary treatment was administered at baseline, respectively. Let $T$ and $C$ denote random variables for the time-to-event and the time-to-censoring; here, events are usually adverse, e.g. progression/onset of disease or even death, and censoring indicates loss of follow-up for a patient. Then, the observed time-to-event outcomes of each patient are described by $\tilde{T} = \min(T, C)$ and $\Delta = \mathbb{1}(T \leq C)$, which indicate the time elapsed until either an event or censoring occurs and whether the event was observed or not, respectively. Throughout, we treat survival time as discrete¹ and the time horizon as finite with predefined maximum $t_{\max}$, so that the set of possible survival times is $\mathcal{T} = \{1, \dots, t_{\max}\}$.

¹ Where necessary, discretization can be performed by transforming continuous-valued times into a set of contiguous time intervals, i.e., $\tau \in [t_k, t_k + \delta t)$ implies $T = t_k$, where $\delta t$ denotes the temporal resolution.

We transform the short data structure outlined above into a so-called long data structure, which can be used to directly estimate conditional hazard functions using standard machine learning methods stitelman2010collaborative. We define two counting processes $N_T(t)$ and $N_C(t)$ which track events and censoring, i.e. $N_T(t) = \mathbb{1}(\tilde{T} \leq t, \Delta = 1)$ and $N_C(t) = \mathbb{1}(\tilde{T} \leq t, \Delta = 0)$ for $t \in \mathcal{T}$; both are zero until either an event or censoring occurs. By convention, we let $N_T(0) = N_C(0) = 0$. Further, let $Y(t) = N_T(t) - N_T(t-1)$ be the indicator for an event occurring exactly at time $t$; thus, for an individual with $\tilde{T} = \tau$ and $\Delta = 1$, $Y(t) = 0$ for all $t \neq \tau$, and $Y(\tau) = 1$ at the event time. The conditional hazard $\lambda(t \mid a, x)$ is the probability that an event occurs at time $t$ given that it does not occur before time $t$; hence it can be defined as cai2019targeted

$$\lambda(t \mid a, x) = P\big(Y(t) = 1 \mid N_T(t-1) = N_C(t-1) = 0, A = a, X = x\big). \tag{1}$$
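It is easy to see from (1) that, given data in long format, $\lambda(t \mid a, x)$ can be estimated for any $t \in \mathcal{T}$ by solving a standard classification problem with $Y(t)$ as target variable, considering only the samples at risk at time $t$ in each treatment arm (individuals for which neither event nor censoring has occurred until that time point, i.e. the set $\mathcal{R}(t, a) = \{i : N_{T,i}(t-1) = N_{C,i}(t-1) = 0, a_i = a\} = \{i : \tilde{\tau}_i \geq t, a_i = a\}$). Finally, given the hazard, the associated survival function can be computed as $S(t \mid a, x) = \prod_{k=1}^{t} \big(1 - \lambda(k \mid a, x)\big)$. The censoring hazard $\lambda_C(t \mid a, x)$ and survival function $S_C(t \mid a, x)$ can be defined analogously.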
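As a minimal illustration of the long data format, the NumPy sketch below expands short-format records $(\tilde{\tau}_i, \delta_i, a_i)$ into one row per individual and per time step at which that individual is still at risk ($\tilde{\tau}_i \geq t$). It then computes a covariate-free, per-arm empirical hazard on each risk set and converts a hazard vector into a survival curve. The function names are hypothetical. In practice, a classifier taking $x$ as input would replace the per-arm averaging, trained on exactly the same at-risk rows.

```python
import numpy as np


def to_long_format(tau, delta, a, t_max):
    """Expand short-format records into long format.

    tau:   observed discrete times (event or censoring) in {1, ..., t_max}
    delta: 1 if the event was observed at tau, 0 if censored
    a:     binary treatment indicator
    Returns (idx, t, arm, y): one row per individual i and time step t with
    tau_i >= t, i.e. every step at which i is still at risk.
    """
    rows = []
    for i, (ti, di, ai) in enumerate(zip(tau, delta, a)):
        for t in range(1, min(int(ti), t_max) + 1):
            # Y_i(t) = 1 only at the event time of an uncensored individual
            rows.append((i, t, int(ai), int(di == 1 and t == ti)))
    idx, t, arm, y = np.array(rows, dtype=int).reshape(-1, 4).T
    return idx, t, arm, y


def empirical_hazard(t_long, arm_long, y_long, t, arm):
    """Covariate-free estimate of lambda(t | a): mean of Y(t) over the risk set."""
    in_risk_set = (t_long == t) & (arm_long == arm)
    return float(y_long[in_risk_set].mean()) if in_risk_set.any() else float("nan")


def survival_from_hazard(hazard):
    """S(t) = prod_{k <= t} (1 - lambda(k)) for a hazard vector over t = 1..t_max."""
    return np.cumprod(1.0 - np.asarray(hazard, dtype=float))


idx, t, arm, y = to_long_format(tau=[2, 3, 1], delta=[1, 0, 1], a=[0, 0, 1], t_max=3)
print(t, y)                                     # [1 2 1 2 3 1] [0 1 0 0 0 1]
print(empirical_hazard(t, arm, y, t=2, arm=0))  # 0.5
```

The censored individual (second record) contributes $Y(t) = 0$ rows up to and including its censoring time. This matches the risk-set definition $\tilde{\tau}_i \geq t$.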
Target parameters. While the main interest in the standard treatment effect estimation setup with continuous outcomes usually lies in estimating only the (difference between) conditional outcome means under different treatments, there is a broader range of target parameters of interest in the time-to-event context. These include both treatment-specific target functions and contrasts that represent some form of heterogeneous treatment effect (HTE). We define the treatment-specific (conditional) hazard and survival functions as
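$$\lambda^a(t \mid x) = P\big(T = t \mid T \geq t, do(A = a), X = x\big), \qquad S^a(t \mid x) = P\big(T > t \mid do(A = a), X = x\big) = \prod_{k=1}^{t}\big(1 - \lambda^a(k \mid x)\big). \tag{2}$$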
Here, $do(A = a)$ denotes Pearl:09’s do-operator, which indicates an intervention in which every individual is assigned treatment $a$. Below, we discuss the assumptions that are necessary to identify such interventional quantities from observational datasets in the presence of censoring.
Given $\lambda^a(t \mid x)$ and $S^a(t \mid x)$, possible HTEs of interest² include the difference in treatment-specific survival probabilities at time $t$, i.e. $S^1(t \mid x) - S^0(t \mid x)$; the difference in restricted mean survival time (RMST) up to time $L$, i.e. $\mathrm{RMST}^1(L \mid x) - \mathrm{RMST}^0(L \mid x)$ with $\mathrm{RMST}^a(L \mid x) = \mathbb{E}[\min(T, L) \mid do(A = a), X = x]$; and hazard ratios $\lambda^1(t \mid x) / \lambda^0(t \mid x)$. In the following, we focus on estimation of the treatment-specific hazard functions, as these can be used to compute survival functions and causal contrasts.

² Note: all parameters of interest to us are heterogeneous (also sometimes referred to as individualized), i.e. a function of the covariates $x$. By contrast, the majority of the existing literature in (bio)statistics considers population average parameters that are functions of quantities such as $\mathbb{E}_X[S^a(t \mid X)]$, which average over all $x$.
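The contrasts above can be computed directly from an individual's two treatment-specific hazard vectors. The sketch below assumes hazards are given for $t = 1, \dots, t_{\max}$. For the RMST it uses the discrete-time identity $\mathbb{E}[\min(T, L)] = \sum_{t=0}^{L-1} S(t)$ with $S(0) = 1$, which holds because $T \geq 1$. The helper names are illustrative only.

```python
import numpy as np


def survival(hazard):
    # S(t) = prod_{k<=t} (1 - lambda(k)) for t = 1..t_max
    return np.cumprod(1.0 - np.asarray(hazard, dtype=float))


def rmst(hazard, horizon):
    # E[min(T, L)] = sum_{t=0}^{L-1} S(t), using S(0) = 1 since T >= 1 (requires L <= t_max)
    s = np.concatenate(([1.0], survival(hazard)))
    return float(s[:horizon].sum())


def treatment_contrasts(hazard_1, hazard_0, horizon):
    """HTEs for one individual from lambda^1(.|x) and lambda^0(.|x)."""
    return {
        "survival_difference": survival(hazard_1) - survival(hazard_0),  # S^1(t|x) - S^0(t|x)
        "rmst_difference": rmst(hazard_1, horizon) - rmst(hazard_0, horizon),
        "hazard_ratio": np.asarray(hazard_1, float) / np.asarray(hazard_0, float),
    }


effects = treatment_contrasts([0.1, 0.1, 0.1], [0.2, 0.2, 0.2], horizon=3)
print(effects["survival_difference"])  # approximately [0.1, 0.17, 0.217]
print(effects["rmst_difference"])      # approximately 0.27
```

All three contrasts are deterministic functions of the two hazard vectors. This is why accurate hazard estimation is the central task.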
Assumptions. (1. Identification) To identify interventional quantities from observational data, it is necessary to make a number of untestable assumptions on the underlying data-generating process (DGP) Pearl:09 – this generally limits the ability to make causal claims to settings where sufficient domain knowledge is available. Here, as in stitelman2010collaborative; stitelman2011targeted, we assume the data was generated from the fairly general directed acyclic graph (DAG) presented in Fig. 1. As there are no arrows originating in hidden nodes entering treatment or censoring nodes, this graph formalizes (1.a) the ‘No Hidden Confounders’ assumption in static treatment effect estimation and (1.b) the ‘Censoring At Random’ assumption in survival analysis stitelman2010collaborative. The latter is necessary here because estimating an effect of treatment on event time implicitly requires that censoring can be ‘switched off’, i.e. intervened on. This graph also implicitly formalizes (1.c) the Consistency assumption, i.e. that observed outcomes are ‘potential’ outcomes under the observed intervention, as each node in a DAG is defined as a function of its ancestors and exogenous noise stitelman2010collaborative. Under these assumptions, $\lambda^a(t \mid x) = \lambda(t \mid a, x)$.
(2. Estimation) To enable nonparametric estimation of $\lambda^a(t \mid x)$ for some $t \in \mathcal{T}$, we additionally need to assume that the interventions of interest are observed with nonzero probability. In different literatures, these assumptions are known under the label of ‘overlap’ or ‘positivity’ Shalit:16; van2011targeted. In particular, for $\lambda^a(t \mid x)$ we need that (2.a) $P(A = a \mid X = x) > 0$, i.e. treatment assignment is non-deterministic, and that (2.b) $\lambda_C(k \mid a, x) < 1$ for all $k < t$, i.e. no individual will be deterministically censored before $t$. Finally, because $\lambda^a(t \mid x)$ is a probability defined conditional on survival up to time $t$, we need to assume that (2.c) $S^a(t-1 \mid x) > 0$ for it to be well-defined. We formally state and discuss all assumptions in more detail in Appendix C.
3 Challenges in Learning Treatment-Specific Hazard Functions using ERM
Preliminaries: ERM under Covariate Shift.
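Recall that in problems with covariate shift, the training distribution $P_S$ used for ERM and the target distribution $P_T$ are mismatched. One assumes that the marginals do not match, i.e. $P_S(X) \neq P_T(X)$, while the conditionals remain the same, i.e. $P_S(Y \mid X) = P_T(Y \mid X)$ kouw2018introduction. If the hypothesis class $\mathcal{H}$ used in ERM does not contain the truth (or in the presence of heavy regularization), this can lead to a suboptimal choice of hypothesis, as in general $\arg\min_{h \in \mathcal{H}} \mathbb{E}_{P_S}[\ell(Y, h(X))] \neq \arg\min_{h \in \mathcal{H}} \mathbb{E}_{P_T}[\ell(Y, h(X))]$.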
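To make the last point concrete, the following toy example (not from the paper) uses squared loss and the noiseless conditional mean $\mathbb{E}[Y \mid X = x] = x^2$ under two different input distributions. It fits a misspecified linear model and a well-specified quadratic model to each. The linear risk minimizer changes with the input distribution, while the quadratic one does not.

```python
import numpy as np


def conditional_mean(x):
    # E[Y | X = x] is identical under source and target (covariate shift)
    return x ** 2


x_source = np.linspace(0.0, 1.0, 101)   # inputs seen during training, P_S(X)
x_target = np.linspace(1.0, 2.0, 101)   # inputs of the target population, P_T(X)

# Misspecified class (straight lines): the squared-loss minimizer depends on P(X)
lin_source = np.polyfit(x_source, conditional_mean(x_source), deg=1)
lin_target = np.polyfit(x_target, conditional_mean(x_target), deg=1)

# Well-specified class (quadratics): both fits recover the truth
quad_source = np.polyfit(x_source, conditional_mean(x_source), deg=2)
quad_target = np.polyfit(x_target, conditional_mean(x_target), deg=2)

print("linear slopes (source, target):", lin_source[0], lin_target[0])  # approx. 1.0 and 3.0
print("quadratic fits agree:", np.allclose(quad_source, quad_target))   # True
```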
3.1 Sources of Covariate Shift in Learning Treatment-Specific Hazard Functions
We now consider how to learn a treatment-specific hazard function $\lambda^a(t \mid x)$ from observational data using ERM. As detailed in Section 2, we exploit the long data format by realizing that $\lambda^a(t \mid x) = \lambda(t \mid a, x)$ can be estimated by solving a standard classification problem. This problem uses $Y(t)$ as dependent variable, $X$ as covariates, and only the samples at risk with treatment status $a$, i.e. $\mathcal{R}(t, a)$. It corresponds to solving the empirical analogue of the problem

$$\min_{h \in \mathcal{H}} \; \mathbb{E}_{(X, Y(t)) \sim P^a_t}\big[\ell\big(Y(t), h(X)\big)\big], \tag{3}$$

where we use $P^a_t$ to refer to the observational (at-risk) distribution with $P^a_t(X, Y(t)) = P\big(X, Y(t) \mid N_T(t-1) = N_C(t-1) = 0, A = a\big)$. If the loss function $\ell$ is chosen to be the log-loss, this corresponds to optimizing the likelihood of the hazard.

The observational (at-risk) covariate distribution $P^a_t(X)$, however, is not our target distribution. To obtain reliable treatment effect estimates for the whole population, we instead wish to optimize the fit over the population at baseline. This is the marginal distribution $P(X)$, which we will refer to as $P_0(X)$ below to emphasize it being the baseline at-risk distribution. (With slight abuse of notation, we will use $P^a_t$ and $P_0$ also to refer to densities of continuous $X$.) Here, differences between $P_0(X)$ and the population at risk $P^a_t(X)$ can arise due to three distinct sources of covariate shift:


(Shift 1) Confounding/treatment selection bias: if treatment is not assigned completely at random, then $P(A = a \mid X) \neq P(A = a)$, and the distribution of characteristics across the treatment arms differs already at baseline. Thus, $P^a_1(X) = P(X \mid A = a) \neq P_0(X)$ in general.

(Shift 2) Censoring bias: regardless of the presence of confounding, if the censoring hazard is not independent of covariates, i.e. $\lambda_C(t \mid a, x) \neq \lambda_C(t \mid a)$, then the population at risk changes over time such that $P^a_t(X) \neq P^a_{t'}(X)$ for $t \neq t'$ in general. If, in addition, the treatment-specific censoring hazards differ, then the at-risk distribution will also differ across treatment arms at any given time point, i.e. $P^a_t(X) \neq P^{1-a}_t(X)$ for $t > 1$ in general.

(Shift 3) Event-induced shifts: counterintuitively, even in the absence of both confounding and censoring, there will be covariate shift in the at-risk population if the event hazard depends on covariates. That is, if $\lambda(t \mid a, x) \neq \lambda(t \mid a)$, then $P^a_t(X) \neq P_0(X)$ for $t > 1$ in general. Further, if there are heterogeneous treatment effects, then $P^a_t(X) \neq P^{1-a}_t(X)$ for $t > 1$ in general.
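Shift 3 can be reproduced with a few lines of simulation. The sketch below is an illustrative assumption rather than one of the paper's DGPs. It uses a single standard-normal covariate, no treatment, no censoring, and a time-constant hazard $\sigma(x - 2)$. Because individuals with high $x$ experience events earlier, the mean covariate value among those still at risk drifts downwards over time. In other words, the at-risk distribution moves away from $P_0(X)$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, t_max = 20_000, 5
x = rng.normal(size=n)                      # single covariate, P_0(X) = N(0, 1)
hazard = 1.0 / (1.0 + np.exp(-(x - 2.0)))   # lambda(t|x): depends on x, constant in t

# Draw discrete event times: at each step, survivors have an event w.p. hazard
event_time = np.full(n, t_max + 1)          # t_max + 1 encodes "no event within horizon"
alive = np.ones(n, dtype=bool)
for t in range(1, t_max + 1):
    event_now = alive & (rng.random(n) < hazard)
    event_time[event_now] = t
    alive &= ~event_now

# Without treatment and censoring, the at-risk set at t is {i : T_i >= t}
at_risk_means = [x[event_time >= t].mean() for t in range(1, t_max + 1)]
for t, m in enumerate(at_risk_means, start=1):
    print(f"t={t}: mean of X among at-risk = {m:+.3f}")
```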
What makes the survival treatment effect estimation problem unique? While Shift 1 also arises in the standard treatment effect estimation setting, Shifts 2 and 3 arise uniquely due to the nature of time-to-event data. Estimating treatment effects from time-to-event data is therefore inherently more involved than in the standard static setup, as covariate shift at time horizon $t > 1$ can arise even in a randomized controlled trial (RCT). In addition to the overall at-risk population changing over time, both treatment effect heterogeneity and treatment-dependent censoring can lead to differences in the composition of the population at risk in each treatment arm. Further, Shifts 1, 2 and 3 can also interact to create more extreme shifts. For example, if treatment selection is based on the same covariates as the event process (i.e. if the covariates driving treatment assignment in Fig. 1 also drive the event hazard), then event-induced shift can amplify the selection effect over time (refer to Appendix E for a synthetic example of this).
Interestingly, changes of the at-risk population over time also arise in standard survival problems (without treatments), yet in the context of prediction these do not matter. Because the at-risk population at any time step is also the population that will be encountered at test time, this shift in population over time is not problematic, unless it is caused by censoring. If, however, our goal is estimation of a target parameter over the whole population, the ideal evaluation is performed on a ‘counterfactual’ population (i.e. the population that would result if all individuals had survived until time $t$). This population is never encountered in test sets and hence requires careful consideration of the consequences of the covariate shifts discussed above. To see why fixing one target population is necessary, note that estimating the difference in survival curves, i.e. $S^1(t \mid x) - S^0(t \mid x)$, requires estimating the hazard functions $\lambda^a(k \mid x)$ for both arms $a \in \{0, 1\}$ and all $k \leq t$. If each of them were optimized for a different target population, the resulting survival curves and their differences would be hard to interpret.
Finally, we note that Shifts 2 and (particularly) 3 seemingly appear only because we chose to represent the data in long format. However, as we show in Appendix B, many ML-based discrete-time models that target the hazard or survival function directly implicitly rely on the long data format (or similar transformations), which makes these shifts problematic for them too. Representation in long format and the classification approach thus only help to make these shifts explicit. Survival models that treat (log) time as a regression target do not suffer from Shift 3. However, as we show in the experiments using the model of chapfuwa2021enabling, their performance on estimating survival functions can be poor.
3.2 Possible Remedies and Theoretical Analysis
A natural solution to tackle bias in ERM caused by covariate shift is to use importance weighting shimodaira2000improving, i.e. to reweight the empirical risk by the density ratio of target and observed distribution, $P_0(x) / P^a_t(x)$. In our context, for any $t$ and $a$, optimal importance weights are given by $w^*_{a,t}(x) = P_0(x)/P^a_t(x) = P\big(A = a, N_T(t-1) = N_C(t-1) = 0\big) / \big(\pi_a(x)\, r^a_t(x)\big)$. Here, $\pi_a(x) = P(A = a \mid X = x)$ is the propensity score, and $r^a_t(x) = P\big(N_T(t-1) = N_C(t-1) = 0 \mid A = a, X = x\big)$ is the probability of being at risk, i.e. the probability that neither event nor censoring occurred before time $t$. These weights are well-defined due to the overlap assumptions detailed in Sec. 2. However, they are in general unknown, as they depend on the unknown target parameters through $r^a_t(x)$. Further, especially for large $t$, these weights might be very extreme even if known, which can lead to highly unstable results cortes2010learning. This makes biased yet stabilized weighting schemes, e.g. truncation, a good alternative. Therefore, we only assume access to some (possibly imperfect) weights $w^a_t(x)$ s.t. $\mathbb{E}_{X \sim P^a_t}[w^a_t(X)] = 1$, so that we can create a weighted distribution $P^{a,w}_t(x) = w^a_t(x) P^a_t(x)$. (Note: $P_0$ can be recovered by using $w^a_t = w^*_{a,t}$.)
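Such weights could be formed from estimated nuisance functions as sketched below. The sketch assumes the event and censoring hazards are available as arrays of shape (n, t_max). It also assumes $T \perp C \mid A, X$, so that $r^a_t(x) = \prod_{k<t}\big(1-\lambda(k \mid a, x)\big)\big(1-\lambda_C(k \mid a, x)\big)$. The optional quantile-based truncation is one hypothetical stabilization rule among many. After truncation, the weights are self-normalized to have mean one over the at-risk sample, mirroring the constraint $\mathbb{E}_{X \sim P^a_t}[w^a_t(X)] = 1$.

```python
import numpy as np


def at_risk_probability(event_hazard, censor_hazard):
    """r_t(x) = prod_{k<t} (1 - lambda(k|a,x)) (1 - lambda_C(k|a,x)) for t = 1..t_max.

    Both inputs have shape (n, t_max). Assumes T and C are independent given
    (A, X), so that remaining event- and censoring-free factorizes.
    """
    stay = (1.0 - event_hazard) * (1.0 - censor_hazard)
    r = np.ones_like(stay)
    r[:, 1:] = np.cumprod(stay, axis=1)[:, :-1]   # r_1 = 1: everyone is at risk at t = 1
    return r


def importance_weights(propensity_a, r_t, truncation_quantile=None):
    """Self-normalized weights proportional to 1 / (pi_a(x) r_t^a(x)) for at-risk samples.

    truncation_quantile (hypothetical option): clip weights at this empirical
    quantile, trading a little bias for lower variance.
    """
    w = 1.0 / (propensity_a * r_t)
    if truncation_quantile is not None:
        w = np.minimum(w, np.quantile(w, truncation_quantile))
    return w / w.mean()   # empirical analogue of E_{X ~ P_t^a}[w(X)] = 1
```

The normalizing constant $P(A = a, N_T(t-1) = N_C(t-1) = 0)$ does not need to be estimated separately, because self-normalization absorbs it.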
Either instead of johansson2016learning; Shalit:16 or in addition to weighting johansson2018learning; hassanpour2019counterfactual; assaad2021counterfactual; johansson2020generalization, the literature on learning balanced representations for static treatment effect estimation has focused on a different remedy for distributional differences between treatment arms. Motivated by generalization bounds, it creates representations whose (weighted) distributions are similar across arms, as measured by an integral probability metric (IPM). As we show below, we can exploit a similar feature in our context by finding a representation that minimizes the IPM term not between treatment arms, but between the covariate distribution at baseline $P_0$ and $P^{a,w}_t$. The proposition below bounds the target risk of a hazard estimator relying on any representation. The proof is stated in Appendix C. It relies on the concept of excess target information loss, recently proposed to analyze domain-adversarial training johansson2019support, and on the standard IPM arguments made in e.g. johansson2020generalization.
Proposition 1.
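For fixed $t$, $a$ and representation $\Phi: \mathcal{X} \to \mathcal{Z}$, let $P^{\Phi}_0$, $P^{a,\Phi}_t$ and $P^{a,w,\Phi}_t$ denote the target, observational, and weighted observational distribution of the representation $\Phi(X)$. Define the pointwise losses

$$\ell_{h,\Phi}(x) = \mathbb{E}_{Y(t) \sim P^a_t(\cdot \mid X = x)}\big[\ell\big(Y(t), h(\Phi(x))\big)\big], \qquad \ell^{\Phi}_{h}(z) = \mathbb{E}_{X \sim P^{a,w}_t}\big[\ell_{h,\Phi}(X) \mid \Phi(X) = z\big] \tag{4}$$

of the (hazard) hypothesis $h$ w.r.t. distributions in covariate and representation space, respectively. Assume there exists a constant $B_\Phi > 0$ s.t. $\ell^{\Phi}_h / B_\Phi \in \mathcal{G}$ for some family of functions $\mathcal{G}$. Then we have that

$$\mathbb{E}_{X \sim P_0}\big[\ell_{h,\Phi}(X)\big] \leq \mathbb{E}_{X \sim P^{a,w}_t}\big[\ell_{h,\Phi}(X)\big] + B_\Phi\, \mathrm{IPM}_{\mathcal{G}}\big(P^{\Phi}_0, P^{a,w,\Phi}_t\big) + \eta^{\Phi}(h), \tag{5}$$

where $\mathrm{IPM}_{\mathcal{G}}(P, Q) = \sup_{g \in \mathcal{G}} \big|\mathbb{E}_{P}[g] - \mathbb{E}_{Q}[g]\big|$. Analogously to johansson2019support, we define the excess target information loss as $\eta^{\Phi}(h) = \mathbb{E}_{X \sim P_0}[\ell_{h,\Phi}(X)] - \mathbb{E}_{Z \sim P^{\Phi}_0}[\ell^{\Phi}_h(Z)]$, with $\ell^{\Phi}_h$ as defined in (4). For invertible $\Phi$, $\eta^{\Phi}(h) = 0$.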
Unlike the bounds provided in Shalit:16; johansson2018learning; johansson2020generalization; assaad2021counterfactual; chapfuwa2021enabling, this bound does not require representations to be invertible. We consider this feature important, as none of the works listed actually enforced invertibility in their proposed algorithms. Given bound (5), it is easy to see why non-invertibility can be useful. For any (possibly non-invertible) representation $\Phi$ for which $Y(t) \perp X \mid \Phi(X)$, i.e. for which the hazard depends on $x$ only through $\Phi(x)$, it also holds that $\eta^{\Phi}(h) = 0$, and the causally identifying restrictions continue to hold. A simple representation with this property is a selection mechanism that chooses only the causal parents of $Y(t)$ from within $X$. Suppose $X$ can be partitioned into variables affecting the instantaneous risk (the covariates with arrows into the event nodes in Fig. 1) and variables affecting only treatment assignment and/or the censoring mechanism. Then the IPM term can be reduced, without affecting $\eta^{\Phi}(h)$, by a representation that drops the latter two sets of variables (or irrelevant variables correlated with any such variables). By contrast, the variables driving event-induced covariate shift are by definition predictive of $Y(t)$. This shift can therefore generally not be fully corrected for using such non-invertible representations (unless the variables affecting event time are different at every time step). Further, given perfect importance weights $w^*_{a,t}$, both $\eta^{\Phi}(h)$ and the IPM term are zero.
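The effect of a non-invertible selection representation on the IPM term can be illustrated numerically. The sketch below assumes two independent covariates: one hypothetically driving only the event hazard and one driving only treatment assignment. It looks at $t = 1$, where the treated at-risk population is simply the treated arm. As a simplification, it uses the IPM over unit-norm linear test functions (which reduces to the distance between means) instead of a Wasserstein distance.

```python
import numpy as np


def mean_difference_ipm(p, q):
    # IPM over unit-norm linear test functions g(z) = v^T z equals ||E_P[z] - E_Q[z]||_2
    return float(np.linalg.norm(p.mean(axis=0) - q.mean(axis=0)))


def phi_identity(x):
    return x             # invertible representation


def phi_select(x):
    return x[:, [0]]     # non-invertible: keeps only the event-relevant covariate


rng = np.random.default_rng(1)
n = 10_000
x_event = rng.normal(size=n)   # hypothetical covariate affecting only the event hazard
x_treat = rng.normal(size=n)   # hypothetical covariate affecting only treatment
x = np.column_stack([x_event, x_treat])
propensity = 1.0 / (1.0 + np.exp(-2.0 * x_treat))
treated = rng.random(n) < propensity

# At t = 1 no events or censoring have occurred, so P_1^1 is the treated-arm distribution
ipm_identity = mean_difference_ipm(phi_identity(x), phi_identity(x[treated]))
ipm_select = mean_difference_ipm(phi_select(x), phi_select(x[treated]))
print(f"IPM, identity: {ipm_identity:.3f}; IPM, selection: {ipm_select:.3f}")
```

Dropping the treatment-only covariate removes almost all of the baseline imbalance. For $t > 1$, event-induced shift in the retained covariate would remain.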
Except for the dependence on $\eta^{\Phi}(h)$, this bound differs from the regression-based bound for survival treatment effects stated in chapfuwa2021enabling (which is identical to the original treatment effect bound in Shalit:16) in that the IPM term depends on $t$. Among other things, this dependence explicitly captures the effect of censoring. Our bound suggests that, instead of finding representations that balance treatment and control group at baseline (or at each time step), we should find representations that balance each at-risk distribution towards the baseline distribution at each time step. This motivates our method detailed below. If we instead applied the IPM term only to make the arm-specific at-risk distributions similar at each time step, this would correct only for shifts due to (i) confounding at baseline, (ii) treatment-induced differences in censoring and (iii) treatment-induced differences in events. It would, however, not handle the event- and censoring-induced shifts that occur regardless of treatment status. This bound therefore also motivates the use of balanced representations for modeling time-to-event outcomes in the presence of informative censoring, even in the standard prediction setting. This finding could be of independent interest for the ML survival analysis literature.
4 SurvITE: Estimating HTEs from Time-to-Event Data
Based on the theoretical analysis above, we propose a novel deep learning approach to HTE estimation from observed time-to-event data, which we refer to as SurvITE (Individualized Treatment Effect estimator for Survival analysis).⁴ The network architecture is illustrated in Figure 2. Even in the absence of treatments, we can use this architecture for estimation of hazards and survival functions by using only a single treatment arm. As we show in the experiments, this version of our method, SurvIHE (Individualized Hazard Estimator for Survival analysis), is of independent interest in the standard survival setting, as it corrects for Shifts 2 & 3. Below, we describe the empirical loss functions we use to find the representation $\Phi$ and hypotheses $h_{a,t}$.

⁴ The source code for SurvITE is available at https://github.com/chl8856/survITE.
Let $\Phi: \mathcal{X} \to \mathcal{Z}$ denote the representation (parameterized by $\theta_\Phi$), and let $h_{a,t}: \mathcal{Z} \to [0, 1]$ denote the hazard estimator for treatment $a$ and time $t$ (parameterized by $\theta_{h_{a,t}}$). Each is implemented as a fully-connected neural network. While the output heads are thus unique to each treatment-group/time-step combination, we allow hazard estimators to share information by using one shared representation for all hazard functions. This allows for borrowing of information across different $(a, t)$ combinations and significantly reduces the number of parameters of the network. Then, given the time-to-event data $\mathcal{D}$, we use the following empirical loss functions for the observational risk and the IPM term:
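$$\mathcal{L}_{\text{risk}}(\theta_\Phi, \theta_h) = \sum_{a \in \{0,1\}} \sum_{t \in \mathcal{T}} \frac{1}{n_{t,a}} \sum_{i \in \mathcal{R}(t, a)} \ell\big(y_i(t), h_{a,t}(\Phi(x_i))\big),$$
$$\mathcal{L}_{\text{IPM}}(\theta_\Phi) = \sum_{a \in \{0,1\}} \sum_{t \in \mathcal{T}} \mathrm{Wass}\big(\{\Phi(x_i)\}_{i=1}^{n}, \{\Phi(x_i)\}_{i \in \mathcal{R}(t,a)}\big),$$

where $\mathrm{Wass}(\cdot, \cdot)$ is the finite-sample Wasserstein distance cuturi2014wass; refer to Appendix D for further detail. $\mathcal{L}_{\text{IPM}}$ penalizes the discrepancy between the baseline distribution $\{\Phi(x_i)\}_{i=1}^{n}$ and each at-risk distribution $\{\Phi(x_i)\}_{i \in \mathcal{R}(t,a)}$, and thereby simultaneously tackles all three sources of shift. Further, $n_{t,a} = |\mathcal{R}(t, a)|$ is the number of samples at risk in treatment arm $a$ at time $t$; its presence ensures that each $(t, a)$ combination contributes equally to the loss. Overall, as suggested by the generalization bound (5), we can find $\Phi$ and $h_{a,t}$'s that optimally trade off balance and predictive power by minimizing the following loss: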
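$$\min_{\theta} \; \mathcal{L}_{\text{risk}}(\theta_\Phi, \theta_h) + \beta\, \mathcal{L}_{\text{IPM}}(\theta_\Phi), \tag{6}$$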
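where $\theta = (\theta_\Phi, \theta_h)$ with $\theta_h = \{\theta_{h_{a,t}}\}_{a \in \{0,1\}, t \in \mathcal{T}}$, and $\beta > 0$ is a hyperparameter. The pseudocode of SurvITE and further implementation details, including how we set $\beta$, can be found in Appendix D.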
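The two empirical loss terms can be sketched in NumPy as follows. This is not the authors' implementation. It only evaluates the losses for given representations and predicted hazards, without the network or gradient-based training. It approximates $\mathrm{Wass}(\cdot,\cdot)$ with a log-domain Sinkhorn iteration under squared Euclidean cost, in the spirit of cuturi2014wass. The regularization strength `eps`, the iteration count and all function names are assumptions.

```python
import numpy as np


def _logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along one axis
    m_max = m.max(axis=axis, keepdims=True)
    out = m_max + np.log(np.exp(m - m_max).sum(axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def sinkhorn_distance(x, y, eps=0.1, n_iter=200):
    """Entropy-regularized OT cost between empirical samples x (n, d) and y (m, d),
    with uniform weights and squared Euclidean ground cost (log-domain Sinkhorn)."""
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    log_a = np.full(len(x), -np.log(len(x)))
    log_b = np.full(len(y), -np.log(len(y)))
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(n_iter):
        # Alternate dual-potential updates so that the plan's marginals match a and b
        f = eps * (log_a - _logsumexp((g[None, :] - cost) / eps, axis=1))
        g = eps * (log_b - _logsumexp((f[:, None] - cost) / eps, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / eps)
    return float((plan * cost).sum())


def survite_style_loss(phi, hazard_pred, y_long, idx_long, t_long, arm_long, beta=1.0):
    """Risk term + beta * IPM term, evaluated on long-format rows.

    phi:              (n, d) representations Phi(x_i) of all n baseline samples
    hazard_pred:      predicted h_{a,t}(Phi(x_i)) for every long-format row
    y_long, idx_long: Y_i(t) and sample index i for every row
    t_long, arm_long: time step t and treatment arm a for every row
    """
    risk, ipm = 0.0, 0.0
    for a in np.unique(arm_long):
        for t in np.unique(t_long):
            rows = (arm_long == a) & (t_long == t)     # the risk set R(t, a)
            if not rows.any():
                continue
            p = np.clip(hazard_pred[rows], 1e-7, 1 - 1e-7)
            y = y_long[rows]
            # Log-loss averaged over the n_{t,a} at-risk samples of this cell
            risk += -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))
            # Discrepancy between the baseline sample and the at-risk sample in Phi-space
            ipm += sinkhorn_distance(phi, phi[idx_long[rows]])
    return risk + beta * ipm
```

In a training loop, the same computation would be written in an automatic-differentiation framework so that gradients flow into $\Phi$ through both terms.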
Uniform vs. non-uniform weighting. In (6), all samples are weighted uniformly (within each $(t, a)$ combination). We also tested non-uniform, estimated importance weights $\hat{w}^a_t(x)$ and, in synthetic experiments, even considered ‘oracle’ weights $w^*_{a,t}(x)$. Across both weighting strategies and different truncation thresholds, we found that non-uniform weighting did not improve the performance of SurvITE. This is in line with recent empirical byrd2019effect and theoretical xu2021understanding findings indicating that weighting may have little impact in deep learning. Over-parametrized NNs have sufficient capacity not to have to trade off between classifying different training points byrd2019effect, which is the problem faced by low-capacity misspecified models (e.g. linear models) in this context. We conjecture that the IPM term, on the other hand, does help, as it fulfills a slightly different purpose than weighting. It forces $\Phi$ to act similarly to a variable selection mechanism (making the subsequent learning problem easier) and encourages ‘shift-invariant’ representations that generalize better in the presence of different shifts.

5 Related work
Heterogeneous treatment effect estimation (non-survival) has been studied in great detail in the recent ML literature. While early work built mainly on tree-based methods hill2011bayesian; athey2016recursive; wager2018estimation; athey2019generalized, many other methods, such as Gaussian processes alaa2017bayesian; alaa2018limits and GANs yoon2018ganite, have been adapted to estimate HTEs. Arguably the largest stream of work johansson2016learning; Shalit:16; johansson2018learning; shi2019adapting; hassanpour2019counterfactual; hassanpour2020learning; assaad2021counterfactual; curth2020 built on NNs due to their flexibility and the ease of manipulating their loss functions. This allows for easy incorporation of balanced representation learning, as proposed in johansson2016learning; Shalit:16, which also motivated the approach taken in this paper. Another popular approach has been to consider model-agnostic (or ‘meta-learner’ kunzel2019metalearners) strategies, which provide a ‘recipe’ for estimating HTEs using any predictive ML method kunzel2019metalearners; kennedy2020optimal; nie2021quasi; curth2020. Because of their simplicity, two such strategies kunzel2019metalearners can be directly applied to the survival setting by relying on a standard survival (prediction) method as base learner. The single-model (S-learner) strategy uses the treatment indicator as an additional covariate in otherwise standard model fitting. The two-model (T-learner) strategy splits the sample by treatment status and fits two separate models.
ML methods for survival prediction continue to multiply. Here we focus on the most related class of methods, namely those nonparametrically modeling conditional hazard or survival functions. We do not focus on methods relying on flexible implementations of the Cox proportional hazards model (e.g. faraggi1995neural; Katzman:16; Luck:17) or on methods treating (log) time as a regression target (e.g. Hothorn:06; ranganath:16; chapfuwa2018adversarial; steingrimsson2019censoring; steingrimsson2020deep; avati2020countdown). One popular nonparametric estimator of survival functions is Ishwaran:08's random survival forest, which relies on the Nelson-Aalen estimator to nonparametrically estimate the cumulative hazard within tree leaves. The idea of modeling discrete-time hazards directly using an arbitrary classifier and long data structures goes back to at least brown1975use, with implementations using NN-based methods presented in e.g. biganzoli1998feed; gensheimer2019scalable; ren2017dsra; kvamme2019continuous. Changhee:AAAI18 models the probability mass function instead of the hazard. yu2011learning use binary labels indicating whether the event has occurred by each time point to estimate the survival function directly using multi-task logistic regression. For a more detailed overview of different strategies for estimating survival functions, refer to Appendix B.
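For reference, a discrete-time version of the Nelson-Aalen cumulative hazard estimator, $\hat{H}(t) = \sum_{k \le t} d_k / n_k$, takes only a few lines. Here, $d_k$ is the number of observed events and $n_k$ the number of individuals at risk at time $k$. The sketch estimates it on a whole sample; a random survival forest would apply the same computation within each terminal leaf.

```python
import numpy as np


def nelson_aalen(tau, delta, t_max):
    """Discrete-time Nelson-Aalen cumulative hazard H(t) = sum_{k<=t} d_k / n_k,
    with d_k = #observed events at k and n_k = #individuals at risk at k (tau >= k)."""
    tau, delta = np.asarray(tau), np.asarray(delta)
    increments = np.zeros(t_max)
    for k in range(1, t_max + 1):
        n_k = np.sum(tau >= k)
        d_k = np.sum((tau == k) & (delta == 1))
        increments[k - 1] = d_k / n_k if n_k > 0 else 0.0
    return np.cumsum(increments)


print(nelson_aalen(tau=[1, 2, 2, 3], delta=[1, 1, 0, 1], t_max=3))
```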
Estimating HTEs from time-to-event data has been studied in much less detail. tabib2020non; zhang2017mining use tree-based nearest-neighbor estimates to estimate expected differences in survival time directly, and henderson2020individualized use a BART-based S-learner to output expected differences in log-survival time. hu2020estimating performed a simulation study using different survival prediction models as base learners for a two-model approach to estimating the difference in median survival time. Based on ideas from the semiparametric efficiency literature, cui2020estimating and diaz2018targeted propose estimators that target the (restricted) mean survival time directly. Consequently, these estimators do not output estimates of the treatment-specific hazard or survival functions. We consider the ability to output treatment-specific predictions an important feature of a model if the goal is to use model output for decision support, as it allows the decision-maker to trade off relative improvement against the baseline risk of a patient. Finally, chapfuwa2021enabling recently proposed a generative model for treatment-specific event times which relies on balanced representations to balance only the treatment groups at baseline. This model does not output hazard or survival functions, but can provide approximations by performing Monte Carlo sampling.
6 Experiments
Unfortunately, when the goal is estimating (differences of) survival functions (instead of predicting survival), evaluation on real data will not reflect performance w.r.t. the intended baseline population. Therefore, we conduct a range of synthetic experiments with known ground truth. We evaluate the effects of different shifts separately, starting with survival estimation without treatments and then introducing treatments. Finally, we use the real-world Twins dataset louizos2017causal, which has uncensored survival outcomes for twins (where the treatment is ‘being born heavier’) and is hence free of Shifts 1 & 2.
Baselines.
Throughout, we use the following natural baselines: Cox regression (Cox); a model using a separate logistic regression to solve the hazard classification problem at each time step (LR-sep); random survival forest (RSF); and the deep learning-based time-to-event method of Changhee:AAAI18 (DeepHit). When there are treatments, we use them in a two-model (T-learner) approach. In settings with treatments, we additionally use the CSA-INFO model of chapfuwa2021enabling (CSA), where we use its generative capabilities to approximate target quantities via Monte Carlo sampling. Finally, we consider ablations of SurvITE (and SurvIHE). In addition to removing the IPM term (SurvITE (no IPM)), we consider two variants of SurvITE based on Shalit:16's CFRNet balancing term. SurvITE (CFR-1) creates a representation balancing treatment groups at baseline only, and SurvITE (CFR-2) creates a representation optimizing for balance of treatment groups at each time step (i.e. no balancing towards $P_0$). We discuss implementation details in Appendix D.
Synthetic Experiments.
We consider a range of synthetic simulation setups (S1-S4) to highlight and isolate the effects of the different types of covariate shift. The event and censoring processes are specified through discrete-time hazards, and treatment is assigned through a propensity score defined via the sigmoid function. Additionally, we assume administrative censoring at a fixed final time step throughout, marking e.g. the end of a hypothetical clinical study. Covariates are generated from a 10-dimensional multivariate normal distribution with correlated components. We use 5000 independently generated samples each for training and testing.

In S1, we begin with the simplest case, with no treatments and no censoring: events are generated from the event hazard only, so that only event-induced shift (Shift 3) is present. In S2, we introduce informative censoring through a covariate-dependent censoring hazard (Shifts 2+3). In S3, we use treatments and consider biased treatment assignment without censoring (Shifts 1+3). In S4, we consider the most difficult case with all three types of shift (Shifts 1+2+3). In the latter two settings, we vary treatment selection by changing (i) whether the set of covariates driving treatment selection overlaps with the event-inducing covariates or not and (ii) the selection strength. We present exploratory plots of these DGPs in Appendix E.
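The following is a hypothetical DGP with the same qualitative structure as S1-S4, not the specification used in the experiments. It has equicorrelated 10-dimensional normal covariates and sigmoid-based treatment selection whose covariates may or may not overlap with the event-inducing ones. It also has discrete-time event and censoring hazards, and administrative censoring at the horizon. All coefficients, the correlation `rho`, the horizon `t_max`, the selection strength `xi` and the choice of covariate subsets are illustrative assumptions.

```python
import numpy as np


def simulate(n=5000, t_max=20, rho=0.2, xi=2.0, overlap=True,
             treatment=True, censoring=True, seed=0):
    """Hypothetical DGP mimicking the structure of S1-S4 (all numbers are assumptions).

    Returns covariates x (n, 10), treatment a, observed discrete time tau in
    {1, ..., t_max} and event indicator delta (0 = censored).
    """
    rng = np.random.default_rng(seed)
    d = 10
    cov = (1 - rho) * np.eye(d) + rho * np.ones((d, d))   # equicorrelated covariates
    x = rng.multivariate_normal(np.zeros(d), cov, size=n)
    x_event = x[:, :3]                                     # event-inducing covariates
    x_sel = x[:, :3] if overlap else x[:, 3:6]             # treatment-selection covariates

    if treatment:
        propensity = 1.0 / (1.0 + np.exp(-xi * x_sel.mean(axis=1)))
        a = (rng.random(n) < propensity).astype(int)
    else:
        a = np.zeros(n, dtype=int)

    # Time-constant discrete-time hazards; the treatment effect varies with x[:, 0]
    lam = 1.0 / (1.0 + np.exp(-(-3.0 + x_event.sum(axis=1) - a * (0.5 + 0.5 * x[:, 0]))))
    lam_c = 1.0 / (1.0 + np.exp(-(-3.5 + x[:, 6]))) if censoring else np.zeros(n)

    tau = np.full(n, t_max)
    delta = np.zeros(n, dtype=int)
    active = np.ones(n, dtype=bool)
    for t in range(1, t_max + 1):
        event = active & (rng.random(n) < lam)       # events are checked first ...
        tau[event], delta[event] = t, 1
        active &= ~event
        cens = active & (rng.random(n) < lam_c)      # ... then censoring among survivors
        tau[cens] = t
        active &= ~cens
    # Individuals still active are administratively censored at t_max (delta = 0)
    return x, a, tau, delta
```

For example, `simulate(treatment=False, censoring=False)` gives an S1-like setting, and `simulate(overlap=False)` an S4-like setting in which selection and event-inducing covariates do not overlap.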
Fig. 3 (top) shows performance on estimating the survival function for all scenarios and methods, while Fig. 3 (bottom) shows performance on estimating the difference in survival functions for a selection of methods (for readability; full results are in Appendix F). In Table 1, we further evaluate the estimation of differences in RMST. Additional performance metrics for survival prediction are reported in Appendix F. We observe that SurvITE (or SurvIHE) performs best throughout, and that introducing the IPM term leads to substantial improvements across all scenarios. In S1, with only event-induced covariate shift, and in S3/S4 when treatment selection and event-inducing covariates overlap, balancing cannot remove all shift, as the shift-inducing covariates are predictive of the outcome; even here, however, the IPM term helps, as it encourages dropping other covariates (which appear imbalanced due to correlations among the covariates). As expected, neither Cox nor LRsep performs well, as both are misspecified, while the nonparametric RSF is sufficiently flexible to capture the underlying DGP and usually performs similarly to SurvITE without the IPM term (i.e. the architecture alone), but is outperformed once the IPM term is added. For readability, we did not include DeepHit in Fig. 3; from Table F.1 in Appendix F, we observe that DeepHit performs worse than the SurvITE architecture without the IPM term, indicating that our architecture alone is better suited to estimating treatment-specific survival functions (note that Changhee:AAAI18 focused mainly on discriminative (predictive) performance, not on estimating the survival function itself). Upon adding the IPM terms, the performance gap between SurvITE and DeepHit only widens.
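The RMST comparison can be made concrete with a short sketch. It assumes discrete time steps $t = 1, 2, \dots$, a survival curve stored as $S(t) = P(T > t)$, and the identity $\mathbb{E}[\min(T, L)] = \sum_{k=0}^{L-1} S(k)$ with $S(0) = 1$. The helpers `rmst` and `rmst_difference` are hypothetical names, not taken from the SurvITE code.

```python
from typing import Sequence


def rmst(surv: Sequence[float], horizon: int) -> float:
    """Discrete-time restricted mean survival time E[min(T, horizon)].

    Hypothetical helper: surv[k] holds S(k + 1) = P(T > k + 1) for event
    times T in {1, 2, ...}; uses E[min(T, L)] = sum_{k=0}^{L-1} S(k), S(0) = 1.
    """
    if horizon < 1 or horizon > len(surv) + 1:
        raise ValueError("horizon must lie in 1..len(surv) + 1")
    return 1.0 + sum(surv[: horizon - 1])


def rmst_difference(surv_treated: Sequence[float],
                    surv_control: Sequence[float], horizon: int) -> float:
    """Difference in RMST between two treatment-specific survival curves."""
    return rmst(surv_treated, horizon) - rmst(surv_control, horizon)


# Constant hazards of 0.2 (treated) and 0.5 (control) over three steps.
treated = [0.8, 0.64, 0.512]
control = [0.5, 0.25, 0.125]
print(rmst(control, 3))                       # 1.75
print(rmst_difference(treated, control, 3))   # approximately 0.69
```

Working with $S(t)$ keeps the RMST a simple partial sum, so any method that outputs survival curves can be compared on the same footing.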
A comparison with ablated versions highlights the effect of using the appropriate baseline population to define balance: naive balancing across treatment arms (either at baseline, SurvITE (CFR1), or over time, SurvITE (CFR2)) is less effective than using the baseline population as a target, especially at later time steps where the effects of bias worsen. While SurvITE (CFR2) almost matches the performance of the full SurvITE in S3, it performs considerably worse in S4, suggesting that this form of balancing suffers mainly from ignoring censoring. Finally, a comparison with CSA highlights the value of modeling hazard functions directly: we found that Monte Carlo approximation of the survival function using the generated event times gives very poorly calibrated survival curves, as event times generated by CSA were concentrated in a very narrow interval, leading to survival estimates of 0 and 1 elsewhere. Its performance on estimating RMST was likewise poor; we conjecture that this is due to (i) CSA modeling continuous time, while the outcomes were generated using a coarse discrete-time model, and (ii) the substantial presence of administrative censoring.
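The calibration problem described for CSA can be reproduced with a minimal sketch: approximating $S(t) = P(T > t)$ by the fraction of sampled event times exceeding $t$. The sampled times below are made-up illustrative values and `mc_survival` is a hypothetical helper, not part of CSA. When all draws fall into a narrow interval, the estimate jumps from 1 to 0 almost immediately.

```python
from typing import List, Sequence


def mc_survival(samples: Sequence[float], grid: Sequence[float]) -> List[float]:
    """Monte Carlo estimate of S(t) = P(T > t) from sampled event times."""
    n = len(samples)
    return [sum(s > t for s in samples) / n for t in grid]


# Draws concentrated around t = 5 (illustrative values only).
draws = [4.9, 5.0, 5.05, 5.1]
print(mc_survival(draws, [1, 2, 3, 4, 5, 6, 7]))
# [1.0, 1.0, 1.0, 1.0, 0.5, 0.0, 0.0]
```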
Methods | S3 (no overlap) | | S4 (no overlap) | | Twins (no censoring) | | Twins (censoring) |
Cox | 0.434 ± 0.03 | 1.073 ± 0.05 | 0.424 ± 0.02 | 1.047 ± 0.04 | 2.85 ± 0.10 | 20.33 ± 0.50 | 2.88 ± 0.09 | 20.60 ± 0.50
RSF | 0.328 ± 0.02 | 1.027 ± 0.03 | 0.332 ± 0.02 | 1.058 ± 0.03 | 3.15 ± 0.07 | 22.42 ± 0.36 | 3.18 ± 0.08 | 22.62 ± 0.46
LRsep | 0.412 ± 0.02 | 1.111 ± 0.07 | 0.418 ± 0.02 | 1.149 ± 0.04 | 2.94 ± 0.10 | 20.60 ± 0.53 | 2.94 ± 0.10 | 20.66 ± 0.52
DeepHit | 0.347 ± 0.03 | 0.821 ± 0.07 | 0.361 ± 0.08 | 0.830 ± 0.15 | 2.95 ± 0.28 | 20.89 ± 1.91 | 2.86 ± 0.09 | 20.69 ± 0.52
CSA | 0.421 ± 0.01 | 2.098 ± 0.26 | 0.406 ± 0.01 | 1.932 ± 0.12 | 3.42 ± 0.12 | 26.20 ± 1.21 | 4.41 ± 0.54 | 47.79 ± 1.55
SurvITE (no IPM) | 0.275 ± 0.04 | 0.843 ± 0.11 | 0.310 ± 0.05 | 0.930 ± 0.11 | 2.80 ± 0.10 | 19.80 ± 1.01 | 2.85 ± 0.22 | 20.00 ± 1.07
SurvITE (CFR1) | 0.269 ± 0.04 | 0.825 ± 0.09 | 0.341 ± 0.02 | 1.016 ± 0.10 | 2.68 ± 0.06 | 19.16 ± 0.37 | 2.67 ± 0.15 | 19.10 ± 0.85
SurvITE (CFR2) | 0.236 ± 0.04 | 0.691 ± 0.08 | 0.294 ± 0.07 | 0.815 ± 0.15 | 2.61 ± 0.12 | 18.69 ± 0.64 | 2.69 ± 0.22 | 19.20 ± 1.44
SurvITE | 0.225 ± 0.03 | 0.687 ± 0.08 | 0.237 ± 0.03 | 0.703 ± 0.06 | 2.53 ± 0.09 | 18.34 ± 0.70 | 2.63 ± 0.10 | 18.76 ± 0.56
Real data: Twins.
Finally, we consider the Twins benchmark dataset, containing survival times (in days, administratively censored at t=365) of 11400 pairs of twins, which has been used in louizos2017causal; yoon2018ganite to measure HTEs of birthweight on infant mortality. We split the data 50/50 (by twin pairs) for training and testing and, following yoon2018ganite, use a covariate-based sampling mechanism to select only one twin of each pair for training to emulate selection bias. Further, we consider a second setting in which we additionally introduce covariate-dependent censoring. For all discrete-time models, we use a non-uniform discretization to construct the classification tasks, because most events are concentrated in the first weeks. A more detailed description of the data and experimental setup can be found in Appendix E. As the data are real and ground-truth probabilities are unknown, the difference in RMST is best suited to evaluate performance on estimating effect heterogeneity. The results presented in Table 1 largely confirm our findings on relative performance from the synthetic experiments; only RSF performs relatively worse on this dataset.
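The specific grid used for Twins is not given in this section. As a hypothetical illustration only, one simple way to obtain finer time steps where events are dense is to place cut points at empirical quantiles of the observed event times. The helpers `quantile_cuts` and `to_step` and the example event days are assumptions made for this sketch.

```python
import bisect
from typing import List, Sequence


def quantile_cuts(event_days: Sequence[float], n_steps: int) -> List[float]:
    """Hypothetical rule: interior cut points at empirical quantiles of event times."""
    xs = sorted(event_days)
    cuts = [xs[int(len(xs) * k / n_steps)] for k in range(1, n_steps)]
    return sorted(set(cuts))  # merge duplicate cut points (ties in event days)


def to_step(days: float, cuts: Sequence[float]) -> int:
    """1-based step: step 1 covers days <= cuts[0], step k covers
    (cuts[k-2], cuts[k-1]], and the last step covers days > cuts[-1]."""
    return bisect.bisect_left(cuts, days) + 1


# Made-up event days, most of them in the first weeks.
days = [0, 1, 1, 2, 3, 5, 8, 20, 60, 300]
cuts = quantile_cuts(days, 5)
print(cuts)                                        # [1, 3, 8, 60]
print([to_step(d, cuts) for d in (0, 2, 8, 365)])  # [1, 2, 3, 5]
```

Quantile-based cut points yield steps that are a few days wide early on and months wide later, which keeps each hazard classification task from being dominated by empty steps.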
7 Conclusion
We studied the problem of inferring heterogeneous treatment effects from time-to-event data by focusing on the challenges inherent to treatment-specific hazard estimation. We found that a variety of covariate shifts play a role in this context, theoretically analysed their impact, and demonstrated across a range of experiments that our proposed method SurvITE successfully mitigates them.
Limitations. Like all methods for inferring causal effects from observational data, SurvITE relies on a set of strong assumptions, which should be evaluated by a domain expert prior to deployment in practice. Here, the time-to-event nature of our problem adds a further assumption (‘random censoring’) to the standard ‘no hidden confounders’ assumption of classical treatment effect estimation. If such assumptions are not properly assessed in practice, any causal conclusions may be misleading.
We thank anonymous reviewers as well as members of the vanderschaarlab for many insightful comments and suggestions. AC gratefully acknowledges funding from AstraZeneca. CL was supported through the IITP grant funded by the Korea government (MSIT) (No. 2021001341, AI Graduate School Program, CAU). Additionally, MvdS received funding from the Office of Naval Research (ONR) and the National Science Foundation (NSF, grant number 1722516).
References
Appendix
This appendix is organized as follows. We first present an extended overview of the standard treatment effect estimation setup and discuss differences with the time-to-event setting (Appendix A). Then, we give an extended review of strategies for nonparametric estimation of survival dynamics (Appendix B). In Appendix C, we discuss technical details (assumptions and proofs), and in Appendix D we discuss implementation. Appendix E contains additional descriptions of datasets and experimental setup, and Appendix F presents additional results.
Appendix A Preliminaries on treatment effect estimation
In the standard treatment effect estimation setup with binary or continuous outcomes (see e.g. [Shalit:16, alaa2018limits, curth2020]), one usually observes a dataset comprising realizations of the tuple $(X, A, Y)$, where $X$ and $A$ represent patient characteristics and treatment assignment as in the main text and $Y$ is usually a binary ($Y \in \{0, 1\}$) or continuous ($Y \in \mathbb{R}$) outcome. The target parameter of interest is often the conditional average treatment effect (CATE)
$$\tau(x) = \mathbb{E}[Y(1) - Y(0) \mid X = x], \qquad (7)$$
which cannot be estimated from observational data without further assumptions because, due to the fundamental problem of causal inference [holland1986statistics], every individual is only ever observed under one of the two possible interventions. Nonparametric estimation of CATE therefore requires untestable assumptions; here we rely on the standard ignorability assumptions [rosenbaum1983central] of No hidden confounders (1.a), Consistency (1.c) and Positivity/Overlap in treatment assignment (2.a).
A.1 Comparison with the time-to-event treatment effects setup
The time-to-event setting is made more involved by (i) the presence of censoring and (ii) the interest in the dynamics of the underlying survival process.
Censoring (the removal of some individuals from the sample before their event time has been observed) further complicates the treatment effect estimation problem, because every individual’s outcome (time-to-event) is now observed under at most one intervention. The presence of censoring introduces an additional source of covariate shift and requires the assumptions of Censoring at random (1.b) and Positivity in censoring (2.b). Censoring is, however, different from complete missingness of the outcome, as the censoring time provides some information on the outcome: the individual has survived at least until the censoring time.
While the difference in expected survival time (the time-to-event equivalent of CATE) can be the treatment effect of interest in a study, many survival analysis problems are concerned with target parameters that capture differences in the dynamics of the underlying survival process across treatments, e.g. hazard ratios or differences in survival functions, which substantially increases the number of possible target parameters to model (beyond ‘only’ CATE). Instead of only modeling expected outcomes (as in the standard setup discussed above), it is therefore often of interest to model survival dynamics through, e.g., the treatment-specific hazard function. Nonparametrically modeling hazard functions requires the additional assumption of Positivity of events (2.c).
Appendix B Strategies for loss-based discrete-time hazard and survival function estimation
In this section, we review strategies for nonparametric (or machine learning-based) estimation of the dynamics underlying discrete-time event processes. Here, we consider the standard case without treatments to highlight how a dependence on different populations arises in different modeling strategies, and closely follow the exposition of different strategies in [kvamme2019continuous]. We focus on loss functions that can be used for implementation, to highlight that these approaches can be used with any classifier, and then briefly mention specific instantiations of such approaches from related work.
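Preliminaries. In addition to the hazard and survival functions defined in the main text, define the probability mass function (PMF) as
$$f(t \mid x) = \mathbb{P}(T = t \mid X = x). \qquad (B.1)$$
Note that the hazard can then also be written as $\lambda(t \mid x) = \mathbb{P}(T = t \mid T \geq t, X = x) = f(t \mid x) / S(t-1 \mid x)$. Further, recall that the survival function is $S(t \mid x) = \mathbb{P}(T > t \mid X = x) = \prod_{k=1}^{t} \big(1 - \lambda(k \mid x)\big)$, so that the PMF can be rewritten as $f(t \mid x) = \lambda(t \mid x) \prod_{k=1}^{t-1} \big(1 - \lambda(k \mid x)\big)$.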
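These identities between hazard, survival function and PMF can be checked numerically with a minimal pure-Python sketch for a single covariate value, assuming $S(t) = P(T > t)$, $S(0) = 1$, and time steps starting at 1. The function names are hypothetical; `hazard_from_pmf` sets the hazard to 0 where $S(t-1) = 0$, which is one possible convention and not necessarily the paper's.

```python
from typing import List, Sequence


def survival_from_hazard(hazard: Sequence[float]) -> List[float]:
    """S(t) = prod_{k<=t} (1 - lambda(k)), with hazard[t-1] = lambda(t)."""
    surv, s = [], 1.0
    for h in hazard:
        s *= 1.0 - h
        surv.append(s)
    return surv


def pmf_from_hazard(hazard: Sequence[float]) -> List[float]:
    """f(t) = lambda(t) * S(t-1) = lambda(t) * prod_{k<t} (1 - lambda(k))."""
    pmf, s_prev = [], 1.0
    for h in hazard:
        pmf.append(h * s_prev)
        s_prev *= 1.0 - h
    return pmf


def hazard_from_pmf(pmf: Sequence[float]) -> List[float]:
    """lambda(t) = f(t) / S(t-1), where S(t-1) = 1 - sum_{k<t} f(k)."""
    hazard, s_prev = [], 1.0
    for f in pmf:
        hazard.append(f / s_prev if s_prev > 0 else 0.0)  # convention if S(t-1) = 0
        s_prev -= f
    return hazard


lam = [0.5, 0.5, 0.5]
print(survival_from_hazard(lam))              # [0.5, 0.25, 0.125]
print(pmf_from_hazard(lam))                   # [0.5, 0.25, 0.125]
print(hazard_from_pmf(pmf_from_hazard(lam)))  # [0.5, 0.5, 0.5]
```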
B.1 Likelihood-based hazard estimation
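Under the assumption of random censoring (which is discussed further in Appendix C.1), the likelihood function of the observed (short) data factorizes; i.e.
$$\mathcal{L} = \prod_{i=1}^{n} \Big[ f(\tilde{t}_i \mid x_i)^{\delta_i}\, S(\tilde{t}_i \mid x_i)^{1-\delta_i} \Big] \times \prod_{i=1}^{n} \Big[ G(\tilde{t}_i - 1 \mid x_i)^{\delta_i}\, g(\tilde{t}_i \mid x_i)^{1-\delta_i} \Big],$$
where $\tilde{t}_i$ is the observed time, $\delta_i$ the event indicator, and $g(t \mid x) = \mathbb{P}(C = t \mid X = x)$ and $G(t \mid x) = \mathbb{P}(C > t \mid X = x)$ denote the PMF and survival function of the censoring time.
By the likelihood principle, the parts pertaining to censoring are ignorable, hence we can consider censoring and event likelihoods separately [rubin1976inference]. The contribution of observation $i$ to the time-to-event likelihood can then be written as
$$\mathcal{L}_i = f(\tilde{t}_i \mid x_i)^{\delta_i}\, S(\tilde{t}_i \mid x_i)^{1-\delta_i} = \lambda(\tilde{t}_i \mid x_i)^{\delta_i} \big(1 - \lambda(\tilde{t}_i \mid x_i)\big)^{1-\delta_i} \prod_{k=1}^{\tilde{t}_i - 1} \big(1 - \lambda(k \mid x_i)\big), \qquad (B.2)$$
so that, after taking the logarithm and summing over all $i$, we have that
$$-\sum_{i=1}^{n} \log \mathcal{L}_i = -\sum_{i=1}^{n} \sum_{t=1}^{\tilde{t}_i} \Big[ y_{i,t} \log \lambda(t \mid x_i) + (1 - y_{i,t}) \log\big(1 - \lambda(t \mid x_i)\big) \Big], \qquad (B.3)$$
with $y_{i,t} = \mathbb{1}\{\tilde{t}_i = t, \delta_i = 1\}$ the hazard classification label as in the main text. Thus, the classification approach with log-loss is equivalent to maximizing the likelihood of the hazard. Optimizing the likelihood of the hazard therefore suffers from exactly the same shifts as the classification approach, namely the shifts induced by focusing on the ‘at-risk’ population at each time step: the log-loss also depends on the at-risk indicator $\mathbb{1}\{\tilde{t}_i \geq t\}$. Note that, as we illustrate in Section B.1.1, under such shifts, optimizing the likelihood is only problematic if the model for $\lambda(t \mid x)$ is misspecified, a well-established fact in the literature on covariate shift [shimodaira2000improving].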
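The equivalence in (B.3) between the summed log-loss over at-risk person-period rows and the negative log-likelihood can be verified numerically. This is a minimal illustration that drops covariates (one hazard vector shared by all observations) and assumes that an observation censored at $\tilde t$ is known to survive through $\tilde t$; `person_period`, `logloss_nll` and `direct_nll` are hypothetical names.

```python
import math
from typing import List, Sequence, Tuple


def person_period(t_obs: int, delta: int) -> List[Tuple[int, int]]:
    """Long-format rows (t, y) for t = 1..t_obs, i.e. the at-risk time steps.

    y = 1 only at t = t_obs when the event was observed (delta = 1).
    """
    return [(t, int(delta == 1 and t == t_obs)) for t in range(1, t_obs + 1)]


def logloss_nll(hazard: Sequence[float], data) -> float:
    """Sum of binary log-losses of the hazard classifier over all at-risk rows."""
    total = 0.0
    for t_obs, delta in data:
        for t, y in person_period(t_obs, delta):
            h = hazard[t - 1]
            total -= y * math.log(h) + (1 - y) * math.log(1 - h)
    return total


def direct_nll(hazard: Sequence[float], data) -> float:
    """-sum_i [delta_i log f(t_i) + (1 - delta_i) log S(t_i)]."""
    total = 0.0
    for t_obs, delta in data:
        s_prev = math.prod(1 - h for h in hazard[: t_obs - 1])  # S(t_obs - 1)
        h = hazard[t_obs - 1]
        total -= math.log(h * s_prev) if delta == 1 else math.log((1 - h) * s_prev)
    return total


lam = [0.2, 0.3, 0.4]
obs = [(1, 1), (2, 0), (3, 1), (3, 0)]   # (observed time, event indicator)
print(math.isclose(logloss_nll(lam, obs), direct_nll(lam, obs)))  # True
```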
Depending on how $\lambda(t \mid x)$ is parameterized, different models proposed in related work arise. The idea to use a classification approach dates back at least to the logistic-hazard model of [brown1975use], and is reviewed in more detail in [tutz2016modeling]. The first NN-based implementation that we are aware of is [biganzoli1998feed], which parameterizes $\lambda(t \mid x)$ using one network shared across all $t$, with the time indicator passed as an additional covariate. [gensheimer2019scalable] instead propose a network with some shared layers and time-specific output layers (resulting in a model similar to the SurvIHE base model). Finally, [ren2017dsra]’s DSRA parameterizes $\lambda(t \mid x)$ using a recurrent network that encodes the structure shared across time.
B.1.1 Illustration: Why (mis)specification matters
To briefly illustrate when event-induced at-risk population shift matters, we consider two simple toy examples: we rely on event processes with covariate-dependent but time-constant hazards, i.e. $\lambda(t \mid x) = \lambda(x)$, and there are 5 correlated multivariate normal covariates, of which only one determines the hazard. We parameterize hazard estimators using a separate logistic regression at each time step $t$. We consider one process for which this logistic regression is correctly specified for the underlying hazard function, and another for which it is misspecified (i.e. the hazard contains a nonlinearity that cannot be perfectly captured by a simple logistic regression).
As can be seen in Fig. B.1, both processes lead to event-induced covariate shift. However, this shift has no effect on hazard estimator performance over time when the model is correctly specified. When the model is misspecified, by contrast, the estimator has to trade off making errors in different regions of the covariate space. The optimal trade-off w.r.t. the baseline distribution is made by the hazard classifier at the first time step, where the at-risk distribution corresponds to the marginal distribution of covariates. Due to event-induced covariate shift, hazard estimates become increasingly biased towards the survivor population at later time steps.
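A small simulation in the spirit of this toy example can be sketched in pure Python. Unlike the setup described above, it assumes a single standard-normal covariate (rather than five correlated ones), no censoring, and two hypothetical hazards: $\sigma(-1.5 + x)$, for which a per-step logistic regression in $x$ is correctly specified, and $\sigma(-3 + 2x^2)$, for which it is not. Each step's classifier is fit by Newton's method on the at-risk population only and then scored against the true hazard over the full (baseline) sample.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fit_logreg(xs: Sequence[float], ys: Sequence[int], iters: int = 15) -> Tuple[float, float]:
    """Logistic regression y ~ sigmoid(b0 + b1 * x), fit by Newton's method."""
    b0 = b1 = 0.0
    for _ in range(iters):
        g0 = g1 = h00 = h01 = h11 = 0.0
        for x, y in zip(xs, ys):
            p = sigmoid(b0 + b1 * x)
            w = p * (1.0 - p)
            g0 += y - p
            g1 += (y - p) * x
            h00 += w
            h01 += w * x
            h11 += w * x * x
        det = h00 * h11 - h01 * h01
        b0 += (h11 * g0 - h01 * g1) / det   # Newton step: beta += H^{-1} g
        b1 += (h00 * g1 - h01 * g0) / det
    return b0, b1


def baseline_mse_per_step(hazard: Callable[[float], float], n: int = 4000,
                          n_steps: int = 4, seed: int = 0) -> List[float]:
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    # Time-constant hazard: the event occurs at the first step whose draw succeeds.
    times = [next((t for t in range(1, n_steps + 1) if rng.random() < hazard(x)), None)
             for x in xs]
    mses = []
    for t in range(1, n_steps + 1):
        # At-risk population at step t: no event before t.
        rows = [(x, int(T == t)) for x, T in zip(xs, times) if T is None or T >= t]
        b0, b1 = fit_logreg([x for x, _ in rows], [y for _, y in rows])
        # Score against the true hazard over the baseline (full) covariate sample.
        mses.append(sum((sigmoid(b0 + b1 * x) - hazard(x)) ** 2 for x in xs) / n)
    return mses


correct = baseline_mse_per_step(lambda x: sigmoid(-1.5 + x))
misspecified = baseline_mse_per_step(lambda x: sigmoid(-3.0 + 2.0 * x * x))
print([round(m, 4) for m in correct])        # small at every step
print([round(m, 4) for m in misspecified])   # larger at later steps than at t = 1
```

In the misspecified case, survivors concentrate where the hazard is low, so later classifiers are pulled towards that region and their error over the baseline population increases.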
B.2 Survival-based estimation
An alternative to targeting the likelihood of the hazard is to target the survival function directly by realizing that $S(t \mid x) = \mathbb{P}(T > t \mid X = x) = \mathbb{E}[\mathbb{1}\{T > t\} \mid X = x]$, so that the survival function can be estimated directly by solving classification problems with targets $\mathbb{1}\{T > t\}$. This considers a loss function
$$\mathcal{L}_S = -\sum_{t} \sum_{i=1}^{n} \Big[ \mathbb{1}\{\tilde{t}_i > t\} \log S(t \mid x_i) + \delta_i \mathbb{1}\{\tilde{t}_i \leq t\} \log\big(1 - S(t \mid x_i)\big) \Big], \qquad (B.4)$$
which suffers from censoring-induced covariate shift due to the interaction of $\delta_i$ and $\mathbb{1}\{\tilde{t}_i \leq t\}$; i.e. only non-censored individuals contribute to the ‘negative class’ $\{T \leq t\}$, an effect that grows with $t$. The multi-task logistic regression approach proposed in [yu2011learning] is a variant of the more general approach described above; it uses a modeling approach based on conditional random fields [lafferty2001conditional] and jointly models all survival functions by accounting for the sequential nature of the targets and the existence of a restricted set of ‘legal’ values.
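The censoring interaction in the survival-based loss (B.4) can be made visible by building the per-step classification rows directly. This is a minimal sketch under an assumed convention: an individual with $\tilde t_i > t$ is labeled as surviving past $t$, one with an observed event at $\tilde t_i \le t$ is labeled negative, and censored individuals with $\tilde t_i \le t$ are dropped. The helper name `survival_rows` is hypothetical.

```python
from typing import List, Sequence, Tuple


def survival_rows(data: Sequence[Tuple[int, int]], t: int) -> List[Tuple[int, int]]:
    """Rows (i, label) for the target 1{T > t}; data holds (observed time, delta)."""
    rows = []
    for i, (t_obs, delta) in enumerate(data):
        if t_obs > t:
            rows.append((i, 1))   # known to survive past t, censored or not
        elif delta == 1:
            rows.append((i, 0))   # event observed at or before t
        # Otherwise: censored at or before t and dropped (a simple convention that
        # also discards t_obs == t), so the negative class holds uncensored rows only.
    return rows


obs = [(1, 1), (2, 0), (3, 1), (4, 0)]
for t in (1, 2, 3, 4):
    print(t, survival_rows(obs, t))
# 1 [(0, 0), (1, 1), (2, 1), (3, 1)]
# 2 [(0, 0), (2, 1), (3, 1)]
# 3 [(0, 0), (2, 0), (3, 1)]
# 4 [(0, 0), (2, 0)]
```

As $t$ grows, more censored individuals disappear from the training set for that step, which is the growing censoring-induced shift described above.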
B.3 PMF-based estimation
Finally, instead of focusing on the hazard or survival function, one could also estimate the PMF; the PMF can be transformed into hazard or survival functions by realizing that $S(t \mid x) = 1 - \sum_{k=1}^{t} f(k \mid x)$ and $\lambda(t \mid x) = f(t \mid x) / S(t-1 \mid x)$. This can be done by treating the survival problem as a multi-class classification problem with one class per time step and one-hot encoded labels, leading to the loss
$$\mathcal{L}_f = -\sum_{i=1}^{n} \delta_i \sum_{t} \mathbb{1}\{\tilde{t}_i = t\} \log f(t \mid x_i), \qquad (B.5)$$
so that each uncensored observation contributes mainly to the estimate of $f(\tilde{t}_i \mid x_i)$ at its event time step [ren2017dsra] (instead of at multiple time steps as in the previous two subsections). Due to the presence of the censoring indicator $\delta_i$, this suffers from censoring-induced covariate shift. As in [Changhee:AAAI18]’s DeepHit, a likelihood contribution marginalizing over possible outcomes can be added for all censored observations, such that they contribute to $S(\tilde{t}_i \mid x_i) = 1 - \sum_{k=1}^{\tilde{t}_i} f(k \mid x_i)$ by signalling that their event times are larger. For correctly specified models, this corresponds to optimizing the likelihood of the PMF and hence suffices to correct for censoring; otherwise, however, it does not exactly correct for censoring-induced covariate shift.
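A minimal sketch of the PMF-based loss (B.5), with and without a DeepHit-style term for censored observations, for a single covariate value (one PMF shared by all observations). It assumes that the PMF over the modeled steps may leave some mass for events beyond the last step and that a censored observation is known to survive through its censoring step; `pmf_loss` is a hypothetical name, not DeepHit's implementation.

```python
import math
from typing import Sequence, Tuple


def pmf_loss(pmf: Sequence[float], data: Sequence[Tuple[int, int]],
             censored_term: bool = True) -> float:
    """Cross-entropy at the event step; optionally -log S(t_obs) for censored rows.

    pmf[t-1] = f(t); data holds (observed time, event indicator delta).
    """
    loss = 0.0
    for t_obs, delta in data:
        if delta == 1:
            loss -= math.log(pmf[t_obs - 1])   # one-hot label at the event step
        elif censored_term:
            surv = 1.0 - sum(pmf[:t_obs])      # S(t_obs) = 1 - sum_{k<=t_obs} f(k)
            loss -= math.log(surv)
    return loss


f = [0.5, 0.25, 0.125]          # remaining mass 0.125 on T > 3
obs = [(1, 1), (2, 0), (3, 1)]
print(pmf_loss(f, obs, censored_term=False))  # ignores the censored observation
print(pmf_loss(f, obs))                        # full negative log-likelihood
```

This PMF corresponds to a constant hazard of 0.5, and with the censored term the loss equals the hazard-based negative log-likelihood of the same data, $-\log(0.5 \cdot 0.25 \cdot 0.125)$.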
Appendix C Technical details: Assumptions and Proofs
C.1 Assumptions
In this section, we discuss and formally state the assumptions made in Section 2. As in, e.g., [stitelman2010collaborative, stitelman2011targeted, cai2019targeted], we assume the fairly general causal structure encoded in the DAG in Figure 1. By assuming that the observed data were generated from this DAG, the classical identifying assumptions (No Hidden Confounders, Censoring At Random, and Consistency) are implicitly formalized [stitelman2010collaborative].
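Equivalently, we can restate the assumptions using potential outcomes [rubin2005causal] notation. As in e.g. [diaz2018targeted], we let $T(a)$ denote the potential event time that would have been observed had treatment $a$ been assigned and censoring been externally set. Then, the following assumptions are implied by the DAG:
Assumption 1 (1.a No Hidden Confounders/Unconfoundedness).
Treatment assignment is random conditional on covariates, i.e. $T(a) \perp A \mid X$ for $a \in \{0, 1\}$.
Assumption 2 (1.b Censoring at random).
Censoring and outcome are conditionally independent, i.e. $T(a) \perp C \mid X, A$.
Assumption 3 (1.c Consistency).
The observed outcomes are the potential outcomes under the observed intervention, i.e. if $A = a$ then $T = T(a)$.
Then, we can write
$$\begin{aligned}
\lambda^a(t \mid x) &= \mathbb{P}(T(a) = t \mid T(a) \geq t, X = x) \\
&= \frac{\mathbb{P}(T(a) = t \mid X = x)}{\mathbb{P}(T(a) \geq t \mid X = x)} \\
&= \mathbb{P}(T(a) = t \mid T(a) \geq t, A = a, X = x) \\
&= \mathbb{P}(T(a) = t \mid T(a) \geq t, C \geq t, A = a, X = x) \\
&= \mathbb{P}(T = t \mid T \geq t, C \geq t, A = a, X = x) \\
&= \mathbb{P}(\tilde{T} = t, \delta = 1 \mid \tilde{T} \geq t, A = a, X = x)
\end{aligned}$$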
Here, the equalities in line one and two follow by definition, line three follows by assumption 1.a, line four follows by assumption 1.b, the equality in line five follows by assumption 1.c, and the final line follows by definition.
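Under Assumptions 1.a to 1.c (and the positivity conditions below), the treatment-specific hazard is identified as $\mathbb{P}(\tilde T = t, \delta = 1 \mid \tilde T \ge t, A = a, X = x)$, with $\tilde T = \min(T, C)$ and $\delta = \mathbb{1}\{T \le C\}$. For a discrete covariate, its empirical analogue is a ratio of counts within a stratum. The sketch below is a hypothetical illustration of that plug-in idea, not the SurvITE estimator; `plugin_hazard` returns `None` when no one in the stratum is at risk, which is exactly where positivity fails.

```python
from typing import Hashable, Iterable, Optional, Tuple

Record = Tuple[Hashable, int, int, int]  # (x, a, observed time, delta)


def plugin_hazard(records: Iterable[Record], a: int, x: Hashable, t: int) -> Optional[float]:
    """Empirical P(T~ = t, delta = 1 | T~ >= t, A = a, X = x) for discrete x."""
    at_risk = [(t_obs, d) for xi, ai, t_obs, d in records
               if xi == x and ai == a and t_obs >= t]
    if not at_risk:
        return None   # no support: the conditional probability is undefined here
    events = sum(1 for t_obs, d in at_risk if t_obs == t and d == 1)
    return events / len(at_risk)


data = [(0, 1, 1, 1), (0, 1, 2, 0), (0, 1, 2, 1), (0, 1, 3, 1),
        (1, 1, 1, 1), (0, 0, 1, 1)]
print(plugin_hazard(data, a=1, x=0, t=1))  # 0.25
print(plugin_hazard(data, a=1, x=0, t=2))  # 0.3333333333333333
print(plugin_hazard(data, a=0, x=1, t=1))  # None
```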
To enable nonparametric estimation of for some fixed , we additionally consider a number of conditions on the likelihood of observing certain events.
Assumption 4 (2.a Overlap/positivity (treatment assignment)).
Treatment assignment is nondeterministic, i.e. for some , we have that
Assumption 5 (2.b Positivity (censoring)).
Censoring is nondeterministic, i.e. for some , we have that for all .
Assumption 6 (2.c Positivity (events)).
Not all events deterministically occur before time , i.e.
Assumptions 1.a, 1.c and 2.a are standard within the treatment effect estimation literature [alaa2018limits, Shalit:16]; assumptions 1.b and 2.b are standard within the literature with survival outcomes [diaz2018targeted, cui2020estimating]. Assumption 2.c is needed only if we aim to estimate for all , otherwise it would suffice to follow a convention such as setting whenever .
C.2 Proof of Proposition 1
In this section, we state the proof of Proposition 1 and restate two lemmas from [johansson2019support] that we use in the proof.
Notation and definitions (restated)
For fixed and representation , let , and denote the baseline, observational and weighted observational distribution w.r.t. the representation . Define the pointwise losses