Estimating Counterfactual Treatment Outcomes over Time Through Adversarially Balanced Representations

02/10/2020 ∙ by Ioana Bica, et al. ∙ University of Cambridge, University of Oxford

Identifying when to give treatments to patients and how to select among multiple treatments over time are important medical problems with few existing solutions. In this paper, we introduce the Counterfactual Recurrent Network (CRN), a novel sequence-to-sequence model that leverages the increasingly available patient observational data to estimate treatment effects over time and answer such medical questions. To handle the bias from time-varying confounders, covariates affecting the treatment assignment policy in the observational data, CRN uses domain adversarial training to build balancing representations of the patient history. At each timestep, CRN constructs a treatment invariant representation which removes the association between patient history and treatment assignments and thus can be reliably used for making counterfactual predictions. On a simulated model of tumour growth, with varying degrees of time-dependent confounding, we show how our model achieves lower error than current state-of-the-art methods in estimating counterfactuals and in choosing the correct treatment and timing of treatment.


1 Introduction

As clinical decision-makers are often faced with the problem of choosing between treatment alternatives for patients, reliably estimating their effects is paramount. While clinical trials represent the gold standard for causal inference, they are expensive, involve small numbers of patients and have narrow inclusion criteria (booth2014randomised). Leveraging the increasingly available observational data about patients, such as electronic health records, represents a more viable alternative for estimating treatment effects.

A large number of methods have been proposed for performing causal inference using observational data in the static setting (johansson2016learning; shalit2017estimating; alaa2017bayesian; li2017matching; yoon2018ganite; alaa2018limits; yao2018representation), and only a few methods address the longitudinal setting (xu2016bayesian; roy2016bayesian; soleimani2017treatment; schulam2017reliable; lim2018forecasting). However, estimating the effects of treatments over time offers unique opportunities: understanding how diseases evolve under different treatment plans, how individual patients respond to medication over time, and which treatment timings are optimal, thus providing new tools to improve clinical decision support systems.

Figure 1: Applicability of CRN in cancer treatment planning. We illustrate three patients with different covariate and treatment histories. From a current time $t$, CRN can predict counterfactual trajectories (the coloured dashed branches) for planned treatments in the future. Through the counterfactual predictions, we can decide which treatment plan results in the best patient outcome (in this case, the lowest tumour volume). This way, CRN can be used to perform all of the following: choose optimal treatments (a), find the timing when treatment is most effective (b), and decide when to stop treatment (c).

The biggest challenge when estimating the effects of time-dependent treatments from observational data is correctly handling the time-dependent confounders: patient covariates that are affected by past treatments and then influence future treatments and outcomes (platt2009time). For instance, consider that treatment A is given when a certain patient covariate (e.g. white blood cell count) has been outside its normal range for several consecutive timesteps. Suppose also that this covariate was itself affected by the past administration of treatment B. If these patients are also more likely to die, then without adjusting for the time-dependent confounding (e.g. the changes in the white blood cell count over time), we would incorrectly conclude that treatment A is harmful to patients. Moreover, estimating the effect of a different sequence of treatments on the patient outcome would require adjusting not only for the bias at the current step (in treatment A), but also for the bias introduced by the previous administration of treatment B.

Existing methods for causal inference in the static setting cannot be applied in this longitudinal setting since they are designed to handle the cross-sectional set-up, where the treatment and outcome depend only on a static value of the patient covariates. If we consider again the above example, these methods would not be able to model how the changes in patient covariates over time affect the assignment of treatments and they would also not be able to estimate the effect of a sequence of treatments on the patient outcome (e.g. sequential application of treatment A followed by treatment B). Different models that can handle these temporal dependencies in the observational data and varying-length patient histories are needed for estimating treatment effects over time.

Time-dependent confounders are present in observational data because doctors follow policies: the history of the patients' covariates and the patients' response to past treatments are used to decide future treatments (mansournia2012effect). The direct use of supervised learning methods will be biased by the treatment policies present in the observational data and will not be able to correctly estimate counterfactuals under different treatment assignment policies.

Standard methods for adjusting for time-varying confounding and estimating the effects of time-varying exposures are based on ideas from epidemiology. The most widely used among these are Marginal Structural Models (MSMs) (robins2000marginal; mansournia2012effect), which use the inverse probability of treatment weighting (IPTW) to adjust for the time-dependent confounding bias. Through IPTW, MSMs create a pseudo-population where the probability of treatment does not depend on the time-varying confounders. However, MSMs are not robust to model misspecification in computing the IPTWs. MSMs can also give high-variance estimates due to extreme weights: computing the IPTW involves dividing by the probability of assigning a treatment conditional on the patient history, which can be numerically unstable when that probability is small.
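To make the variance problem concrete, the following sketch (with purely illustrative propensity values) computes unstabilized and stabilized IPTWs for a single patient history; note how one or two small propensities inflate the product of inverse probabilities:

```python
import numpy as np

# Propensities of the observed treatments at each of 5 timesteps for one
# patient; a few are small, as happens when the policy rarely assigns a
# treatment given the history. Values are illustrative only.
propensities = np.array([0.5, 0.4, 0.05, 0.6, 0.02])

# Unstabilized IPTW: product over time of 1 / P(A_t | history).
w_unstab = np.prod(1.0 / propensities)

# Stabilized IPTW: the numerator uses the marginal P(A_t), which tempers
# (but does not eliminate) the variance of the weights.
marginals = np.full(5, 0.5)
w_stab = np.prod(marginals / propensities)

print(w_unstab)  # ~8333.3: a single patient dominates the pseudo-population
print(w_stab)
```

Even the stabilized weight remains in the hundreds here, which is why RMSNs and MSMs can suffer from high-variance estimates.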

We introduce the Counterfactual Recurrent Network (CRN), a novel sequence-to-sequence architecture for estimating treatment effects over time. CRN leverages recent advances in representation learning (bengio2012representation) and domain adversarial training (ganin2016domain) to overcome the problems of existing methods for causal inference over time. Our main contributions are as follows.

Treatment invariant representations over time. CRN constructs treatment invariant representations at each timestep in order to break the association between patient history and treatment assignment, thus removing the bias from time-dependent confounders. For this, CRN uses domain adversarial training (ganin2016domain; li2018deep; sebag2019multi) to trade off between building this balancing representation and predicting patient outcomes. We show that these representations remove the bias from time-varying confounders and can be reliably used for estimating counterfactual outcomes. This is the first work to introduce ideas from domain adaptation to the estimation of treatment effects over time. In addition, by building balancing representations, we propose a novel way of removing the bias introduced by time-varying confounders.

Counterfactual estimation of future outcomes. To estimate counterfactual outcomes for treatment plans (and not just single treatments), we integrate the domain adversarial training procedure as part of a sequence-to-sequence architecture. CRN consists of an encoder network which builds treatment invariant representations of the patient history that are used to initialize the decoder. The decoder network estimates outcomes under an intended sequence of future treatments, while also updating the balanced representation. By performing counterfactual estimation of future treatment outcomes, CRN can be used to answer critical medical questions such as deciding when to give treatments to patients, when to start and stop treatment regimes, and also how to select from multiple treatments over time. We illustrate in Figure 1 the applicability of our method in choosing optimal cancer treatments.

In our experiments, we evaluate CRN in a realistic set-up using a model of tumour growth (geng2017prediction). We show that CRN achieves better performance than current state-of-the-art methods, both in predicting counterfactual outcomes and in choosing the right treatment and timing of treatment.

2 Related work

We focus on methods for estimating treatment effects over time and for building balancing representations for causal inference. A more in-depth review of related work is in Appendix A.

Treatment effects over time. Standard methods for estimating the effects of time-varying exposures were first developed in the epidemiology literature and include the g-computation formula, Structural Nested Models and Marginal Structural Models (MSMs) (robins1986new; robins1994correcting; robins2000marginal; robins2008estimation). Originally, these methods used logistic/linear regression as predictors, which makes them unsuitable for handling complex time-dependencies (hernan2001marginal; mansournia2012effect; mortimer2005application). To address these limitations, methods that use Bayesian non-parametrics or recurrent neural networks within these frameworks have been proposed (xu2016bayesian; roy2016bayesian; lim2018forecasting).

To begin with, xu2016bayesian use Gaussian processes to model discrete patient outcomes as a generalized mixed-effects model and use the g-computation method to handle time-varying confounders. soleimani2017treatment extend the approach in xu2016bayesian to the continuous-time setting and model treatment responses using linear time-invariant dynamical systems. roy2016bayesian use Dirichlet and Gaussian processes to model the observational data and estimate the IPTW in Marginal Structural Models. schulam2017reliable build upon work from lok2008statistical and arjas2004causal and use marked point processes and Gaussian processes to learn causal effects in continuous-time data. These Bayesian non-parametric methods make strong assumptions about model structure and consequently cannot handle well heterogeneous treatment effects arising from baseline variables (soleimani2017treatment; schulam2017reliable) or multiple treatment outcomes (xu2016bayesian; schulam2017reliable).

The work most closely related to ours is that of lim2018forecasting, which improves on standard MSMs by using recurrent neural networks to estimate the inverse probability of treatment weights (IPTWs). lim2018forecasting introduce Recurrent Marginal Structural Networks (RMSNs), which also use a sequence-to-sequence deep learning architecture to forecast treatment responses in a similar fashion to our model. However, RMSNs require training additional RNNs to estimate the propensity weights and do not overcome the fundamental problems with IPTWs, such as the high variance of the weights. Conversely, CRN takes advantage of recent advances in machine learning, in particular representation learning, to propose a novel way of handling time-varying confounders.

Balancing representations for treatment effect estimation. Balancing the distribution of control and treated groups has been used for counterfactual estimation in the static setting. The methods proposed in the static setting for balancing representations are based on using discrepancy measures in the representation space between treated and untreated patients, which do not generalize to multiple treatments (johansson2016learning; shalit2017estimating; li2017matching; yao2018representation). Moreover, due to the sequential assignment of treatments in the longitudinal setting, and due to the change of patient covariates over time according to previous treatments, the methods for the static setting are not directly applicable to the time-varying setting (hernan2000marginal; mansournia2012effect).

3 Problem formulation

Consider an observational dataset consisting of information about $N$ independent patients. For each patient $i$, we observe time-dependent covariates $\mathbf{X}_t^{(i)}$, treatment received $A_t^{(i)}$ and outcomes $\mathbf{Y}_{t+1}^{(i)}$ for discrete timesteps $t$. The patient can also have static baseline covariates $\mathbf{V}^{(i)}$ such as gender and genetic information. Note that the outcome $\mathbf{Y}_{t+1}$ will be part of the observed covariates $\mathbf{X}_{t+1}$. For simplicity, the patient superscript $(i)$ will be omitted unless explicitly needed.

We adopt the potential outcomes framework proposed by neyman1923applications and rubin1978bayesian, and extended by robins2008estimation to account for time-varying treatments. Let $\mathbf{Y}[\bar{a}]$ be the potential outcomes, either factual or counterfactual, for each possible course of treatment $\bar{a}$. Let $\bar{\mathbf{H}}_t = (\bar{\mathbf{X}}_t, \bar{\mathbf{A}}_{t-1}, \mathbf{V})$ represent the history of the patient covariates $\bar{\mathbf{X}}_t = (\mathbf{X}_1, \dots, \mathbf{X}_t)$, treatment assignments $\bar{\mathbf{A}}_{t-1} = (A_1, \dots, A_{t-1})$ and static features $\mathbf{V}$. We want to estimate:

$$\mathbb{E}\left[\mathbf{Y}_{t+\tau}\left[\bar{a}(t, t+\tau-1)\right] \,\middle|\, \bar{\mathbf{H}}_t\right] \tag{1}$$

where $\bar{a}(t, t+\tau-1) = (a_t, \dots, a_{t+\tau-1})$ represents a possible sequence of treatments from timestep $t$ until just before the potential outcome $\mathbf{Y}_{t+\tau}$ is observed. We make the standard assumptions (robins2000marginal; lim2018forecasting) needed to identify the treatment effects: consistency, positivity and no hidden confounders (sequential strong ignorability). See Appendix B for more details.

4 Counterfactual Recurrent Network

The observational data can be used to train a supervised learning model to forecast $\mathbb{E}[\mathbf{Y}_{t+\tau} \mid \bar{\mathbf{H}}_t, \bar{\mathbf{A}}(t, t+\tau-1)]$. However, without adjusting for the bias introduced by time-varying confounders, this model cannot be reliably used for making causal predictions (robins2000marginal; robins2008estimation; schulam2017reliable). The Counterfactual Recurrent Network (CRN) removes this bias through domain adversarial training and estimates the counterfactual outcomes $\mathbb{E}[\mathbf{Y}_{t+\tau}[\bar{a}(t, t+\tau-1)] \mid \bar{\mathbf{H}}_t]$ for any intended future treatment assignment $\bar{a}(t, t+\tau-1)$.

Balancing representations. The history $\bar{\mathbf{H}}_t$ of the patient contains the time-varying confounders which bias the treatment assignment in the observational dataset. Inverse probability of treatment weighting, as performed by MSMs, creates a pseudo-population where the probability of treatment does not depend on the time-varying confounders (robins2000marginal). In this paper, we instead propose building a representation of the history $\bar{\mathbf{H}}_t$ that is not predictive of the treatment $A_t$. This way, we remove the association between the history, containing the time-varying confounders, and the current treatment $A_t$. robins1999association shows that in this case, the estimation of counterfactual treatment outcomes is unbiased. See Appendix C for details and for an example of a causal graph with time-dependent confounders.

Let $\Phi$ be the representation function that maps the patient history $\bar{\mathbf{H}}_t$ to a representation space. To obtain unbiased treatment effects, $\Phi$ needs to construct treatment invariant representations, such that the distribution of $\Phi(\bar{\mathbf{H}}_t)$ is the same across all treatments $A_t$. To achieve this, and to estimate counterfactual outcomes under a planned sequence of treatments, we integrate the domain adversarial training framework proposed by ganin2016domain, and extended by sebag2019multi to the multi-domain learning setting, into a sequence-to-sequence architecture. In our case, the different treatments at each timestep are considered the different domains. Note that the novelty here comes from the use of domain adversarial training to handle the bias from the time-dependent confounders, rather than from the use of sequence-to-sequence models, which have already been applied to forecast treatment responses (lim2018forecasting). Figure 2 illustrates our model architecture.

Figure 2: CRN architecture. The encoder builds a representation $\Phi(\bar{\mathbf{H}}_t)$ that maximizes the loss of the treatment classifier $G_a$ and minimizes the loss of the outcome predictor $G_y$. $\Phi(\bar{\mathbf{H}}_t)$ is used to initialize the decoder, which continues to update it to predict counterfactual outcomes of a sequence of future treatments.

Encoder. The encoder network uses an RNN with LSTM units (hochreiter1997long) to process the history of treatments, covariates and baseline features, both to build a treatment invariant representation $\Phi(\bar{\mathbf{H}}_t)$ and to predict one-step-ahead outcomes. To achieve this, the encoder network aims to maximize the loss of the treatment classifier $G_a$ and to minimize the loss of the outcome predictor network $G_y$. This way, the balanced representation is not predictive of the assigned treatment $A_t$, but is discriminative enough to estimate the outcome $\mathbf{Y}_{t+1}$. To train this model using gradient descent, we use the Gradient Reversal Layer (ganin2016domain).
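The gradient-reversal mechanism can be summarised in a few lines. The following framework-agnostic NumPy sketch shows its semantics (in practice it is implemented inside the autodiff graph, as in ganin2016domain):

```python
import numpy as np

def grl_forward(x):
    # The GRL is the identity in the forward pass.
    return x

def grl_backward(grad_output, lambd):
    # In the backward pass, gradients flowing from the treatment
    # classifier to the representation are reversed and scaled, so that
    # gradient descent on the classifier's loss performs gradient
    # *ascent* with respect to the representation parameters.
    return -lambd * grad_output

x = np.array([1.0, -2.0, 3.0])
print(grl_forward(x))                                  # unchanged activations
print(grl_backward(np.array([0.1, 0.2, -0.3]), 0.5))   # reversed, scaled grads
```

This is what lets a single backpropagation pass simultaneously train $G_a$ to predict the treatment and train $\Phi$ to defeat it.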

Decoder. The decoder network uses the balanced representation computed by the encoder to initialize the state of an RNN that predicts the counterfactual outcomes for a sequence of future treatments. During training, the decoder uses as input the outcomes from the observational data, the static patient features and the intended sequence of treatments. The decoder is trained similarly to the encoder, updating the balanced representation and estimating the outcomes. During testing, we do not have access to ground-truth outcomes; the outcomes predicted by the decoder are therefore fed back auto-regressively as inputs. By running the decoder with different treatment settings, and by auto-regressively feeding back the outcomes, we can determine when to start and stop different treatments, the optimal time to give a treatment, and which treatments to give over time to obtain the best patient outcomes.
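A minimal sketch of the decoder's auto-regressive rollout, with a hypothetical `decoder_step` standing in for one RNN step and a toy linear dynamic used purely for illustration:

```python
import numpy as np

def decoder_step(state, prev_outcome, treatment):
    # Stand-in for one decoder RNN step: returns (new_state, predicted
    # outcome). The dynamics below are a toy example, not the trained model.
    new_state = 0.9 * state + 0.1 * prev_outcome
    y_hat = new_state - 0.5 * treatment
    return new_state, y_hat

def rollout(init_state, last_observed_y, planned_treatments):
    """Auto-regressively predict outcomes under a planned treatment
    sequence: each predicted outcome is fed back as the next input."""
    state, y = init_state, last_observed_y
    trajectory = []
    for a in planned_treatments:
        state, y = decoder_step(state, y, a)
        trajectory.append(y)
    return trajectory

# Compare two counterfactual plans starting from the same encoder state.
traj_treated = rollout(1.0, 1.0, planned_treatments=[1, 1, 1])
traj_untreated = rollout(1.0, 1.0, planned_treatments=[0, 0, 0])
print(traj_treated, traj_untreated)
```

Running the rollout once per candidate plan and comparing the predicted trajectories is exactly how the treatment-planning questions in Figure 1 are answered.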

The representation is built by applying a fully connected layer with Exponential Linear Unit (ELU) activation to the output of the LSTM. The treatment classifier $G_a$ and the predictor network $G_y$ consist of one hidden layer each, also with ELU activation. The output layer of $G_a$ uses softmax activation, while the output layer of $G_y$ uses linear activation for continuous predictions; for categorical outcomes, softmax activation can be used. We follow an approach similar to lim2018forecasting and split the encoder and decoder training into separate steps. See Appendix E for details.

The encoder and decoder networks use variational dropout (gal2016theoretically) so that the CRN can also give uncertainty intervals for the treatment outcomes. This is particularly important in the estimation of treatment effects, since the model predictions should only be used when they have high confidence. Our model can also be modified to allow for irregular sampling of observations by using a PhasedLSTM (neil2016phased).

5 Adversarially balanced representation over time

At each timestep $t$, let the $K$ possible treatments represent our domains. As described in Section 4, to remove the bias from time-dependent confounders, we build a representation of the history that is invariant across treatments, i.e. whose conditional distribution $P(\Phi(\bar{\mathbf{H}}_t) \mid A_t = a_j)$ is the same for all treatments $a_j$.

This requirement can be enforced by minimizing the distance between the distributions of $\Phi(\bar{\mathbf{H}}_t)$ for any pair of treatments. kifer2004detecting and ben2007analysis propose measuring the disparity between distributions based on their separability by a discriminatively trained classifier. Let the symmetric hypothesis class $\mathcal{H}$ consist of a set of symmetric multiclass classifiers, such as neural network architectures. The $\mathcal{H}$-divergence between the distributions is defined in terms of the capacity of the hypothesis class to discriminate between examples drawn from them. Empirically, minimizing this divergence involves building a representation where examples from the multiple domains are as indistinguishable as possible (ben2007analysis; li2018deep; sebag2019multi). ganin2016domain use this idea to propose an adversarial framework for domain adaptation, building a representation that achieves maximum error on a domain classifier and minimum error on an outcome predictor. Similarly, in our case, we use domain adversarial training to build a representation of the patient history that is both invariant to the treatment given at timestep $t$ and achieves low error in estimating the outcome $\mathbf{Y}_{t+1}$.
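In practice, this separability-based divergence can be estimated with the proxy A-distance of ben2007analysis, computed from the held-out error of a discriminatively trained domain classifier; a minimal sketch:

```python
import numpy as np

def proxy_a_distance(domain_clf_error):
    """Proxy A-distance from the held-out error of a binary domain
    classifier: 0 when the domains are indistinguishable (error 0.5),
    2 when they are perfectly separable (error 0)."""
    return 2.0 * (1.0 - 2.0 * domain_clf_error)

# Balanced representation: the classifier is at chance, divergence ~ 0.
print(proxy_a_distance(0.5))   # 0.0
# Unbalanced representation: treatments are easily told apart.
print(proxy_a_distance(0.05))  # 1.8
```

Training a classifier on top of $\Phi(\bar{\mathbf{H}}_t)$ and checking that its accuracy is near chance is a simple diagnostic for how well balancing has succeeded.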

Figure 3: Training procedure for building balancing representation.

Let $G_a$ be the treatment classifier with parameters $\theta_a$ and let $G_a^{(j)}$ be its output corresponding to treatment $a_j$. Let $G_y$ be the predictor network with parameters $\theta_y$. The representation function $\Phi$ is parameterized by the parameters $\theta_r$ of the RNN. Figure 3 shows the adversarial training procedure used.

For timestep $t$ and patient $i$, let $\mathcal{L}^{(i)}_{a,t}$ be the treatment (domain) loss and let $\mathcal{L}^{(i)}_{y,t}$ be the outcome loss, defined as follows:

$$\mathcal{L}^{(i)}_{a,t}(\theta_r, \theta_a) = -\sum_{j=1}^{K} \mathbb{1}_{\{a^{(i)}_t = a_j\}} \log G_a^{(j)}\big(\Phi(\bar{\mathbf{H}}^{(i)}_t)\big) \tag{2}$$
$$\mathcal{L}^{(i)}_{y,t}(\theta_r, \theta_y) = \big\| \mathbf{Y}^{(i)}_{t+1} - G_y\big(\Phi(\bar{\mathbf{H}}^{(i)}_t), a^{(i)}_t\big) \big\|^2 \tag{3}$$

If the outcome is binary, the cross-entropy loss can be used instead for $\mathcal{L}^{(i)}_{y,t}$. To build treatment invariant representations and to also estimate patient outcomes, we aim to maximize the treatment loss and minimize the outcome loss.
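The losses in equations (2) and (3), and their adversarial combination in equation (4), can be sketched in NumPy for a single sample (the function names are hypothetical, for illustration only):

```python
import numpy as np

def treatment_loss(logits, a_onehot):
    # Cross-entropy of the treatment classifier (equation 2), computed
    # from a numerically stable softmax over the treatment logits.
    z = np.exp(logits - logits.max())
    p = z / z.sum()
    return -float(np.sum(a_onehot * np.log(p)))

def outcome_loss(y_hat, y):
    # Squared-error loss of the outcome predictor (equation 3).
    return float(np.sum((y_hat - y) ** 2))

def combined_loss(logits, a_onehot, y_hat, y, lambd):
    # Equation 4: minimise the outcome loss while *maximising* the
    # treatment loss, traded off by lambda.
    return outcome_loss(y_hat, y) - lambd * treatment_loss(logits, a_onehot)

# A representation that leaves the classifier at chance (uniform logits)
# maximises the treatment loss at log K for the one-hot target.
print(treatment_loss(np.zeros(2), np.array([1.0, 0.0])))  # log 2 ~ 0.693
print(combined_loss(np.zeros(2), np.array([1.0, 0.0]),
                    np.array([2.0]), np.array([1.0]), lambd=1.0))
```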

Thus, the overall loss at timestep $t$ is given by:

$$\mathcal{L}_t(\theta_r, \theta_y, \theta_a) = \sum_{i} \Big( \mathcal{L}^{(i)}_{y,t}(\theta_r, \theta_y) - \lambda\, \mathcal{L}^{(i)}_{a,t}(\theta_r, \theta_a) \Big) \tag{4}$$

where the hyperparameter $\lambda$ controls the trade-off between domain discrimination and outcome prediction. We use the standard procedure for training domain adversarial networks from ganin2016domain: we start with an initial value for $\lambda$ and increase it according to an exponential schedule during training. To train the model using backpropagation, we use the Gradient Reversal Layer (GRL) (ganin2016domain). For more details about the training procedure, see Appendix E.
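As an illustration, the schedule proposed by ganin2016domain ramps the adversarial weight from 0 towards 1 over the course of training, so early training focuses on outcome prediction before the adversary takes effect:

```python
import numpy as np

def lambda_schedule(progress, gamma=10.0):
    """Adversarial weight schedule from Ganin et al. (2016): as training
    progress goes from 0 to 1, lambda increases from 0 towards 1 along
    a sigmoid-shaped (exponentially saturating) curve."""
    return 2.0 / (1.0 + np.exp(-gamma * progress)) - 1.0

print(lambda_schedule(0.0))  # 0.0: no adversarial pressure at the start
print(lambda_schedule(0.5))
print(lambda_schedule(1.0))  # close to 1.0 by the end of training
```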

By optimizing the objective $\mathcal{L}_t$, we reach the saddle point that achieves the equilibrium between domain discrimination and outcome estimation:

$$(\hat{\theta}_r, \hat{\theta}_y) = \arg\min_{\theta_r, \theta_y} \mathcal{L}_t(\theta_r, \theta_y, \hat{\theta}_a), \qquad \hat{\theta}_a = \arg\max_{\theta_a} \mathcal{L}_t(\hat{\theta}_r, \hat{\theta}_y, \theta_a) \tag{5}$$

The result stated in Theorem 1 shows that, optimized adversarially, the treatment (domain) loss part of our objective (from equation 2) removes the time-dependent confounding bias.

Theorem 1.

Let $\{a_1, \dots, a_K\}$ be the set of possible treatments. For each $j \in \{1, \dots, K\}$, let $P_j$ denote the distribution of the representation $\Phi(\bar{\mathbf{H}}_t)$ conditional on $A_t = a_j$. Let $G_a^{(j)}$ denote the output of $G_a$ corresponding to treatment $a_j$. Then the minimax game defined by

$$\min_{\Phi} \max_{G_a} \; \sum_{j=1}^{K} \mathbb{E}_{\Phi \sim P_j}\left[\log G_a^{(j)}(\Phi)\right] \tag{6}$$

has a global minimum which is attained if and only if $P_1 = \dots = P_K$, i.e. when the learned representations are invariant across all treatments.

Proof.

This result is a restatement of the one in li2018deep. For details, see Appendix D. ∎

A good representation $\Phi$ allows us to obtain a low error in estimating counterfactuals for all treatments, while at the same time minimizing the $\mathcal{H}$-divergence between the induced marginal distributions of all the domains. We use an algorithm that directly minimizes a combination of the divergence and the empirical training margin.

6 Experiments

In real datasets, counterfactual outcomes and the degree of time-dependent confounding are not known (schulam2017reliable; lim2018forecasting). To validate CRN (implementations are available at https://bitbucket.org/mvdschaar/mlforhealthlabpub/src/master/alg/counterfactual_recurrent_network/ and at https://github.com/ioanabica/Counterfactual-Recurrent-Network), we evaluate it on a Pharmacokinetic-Pharmacodynamic model of tumour growth (geng2017prediction), which uses a state-of-the-art bio-mathematical model to simulate the combined effects of chemotherapy and radiotherapy in lung cancer patients. The same model was used by lim2018forecasting to evaluate RMSNs.

Model of tumour growth. The volume $V(t)$ of the tumour $t$ days after diagnosis is modelled as follows:

$$V(t+1) = \left(1 + \rho \log\frac{K}{V(t)} - \beta_c C(t) - \left(\alpha_r d(t) + \beta_r d(t)^2\right) + e_t\right) V(t) \tag{7}$$

where the growth parameter $\rho$, carrying capacity $K$, chemotherapy cell-kill parameter $\beta_c$, radiotherapy cell-kill parameters $\alpha_r, \beta_r$ and noise term $e_t$ are sampled as described in geng2017prediction. To incorporate heterogeneity in patient responses, the prior means of the chemotherapy and radiotherapy response parameters are adjusted to create patient subgroups, which are used as baseline features. The chemotherapy concentration $C(t)$ and radiotherapy dose $d(t)$ are modelled as described in Appendix F.

Time-varying confounding is introduced by modelling chemotherapy and radiotherapy assignment as Bernoulli random variables, with probabilities $p_c(t)$ and $p_r(t)$ that depend on the tumour diameter: $p_c(t) = \sigma\left(\frac{\gamma_c}{D_{\max}}(\bar{D}(t) - \delta_c)\right)$ and $p_r(t) = \sigma\left(\frac{\gamma_r}{D_{\max}}(\bar{D}(t) - \delta_r)\right)$, where $\bar{D}(t)$ is the average tumour diameter over the preceding days, $D_{\max}$ is the maximum tumour diameter, $\sigma(\cdot)$ is the sigmoid function and $\delta_c = \delta_r = D_{\max}/2$. The amount of time-dependent confounding is controlled through $\gamma_c$ and $\gamma_r$; the higher these are, the more important the patient history is in assigning treatments. At each timestep, there are four treatment options: no treatment, chemotherapy, radiotherapy, and combined chemotherapy and radiotherapy. For details about the data simulation, see Appendix F.
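A simplified sketch of this data-generating process is given below; the parameter values are illustrative stand-ins, not the calibrated priors of geng2017prediction:

```python
import numpy as np

rng = np.random.default_rng(42)

def simulate_patient(T=30, gamma=5.0, d_max=13.0, rho=7e-5, K=30.0,
                     beta_c=0.03, alpha_r=0.07, beta_r=0.007):
    """Sketch of the tumour-growth simulation (equation 7) with
    confounded treatment assignment. All numeric values here are
    illustrative assumptions."""
    v = 10.0  # initial tumour volume (cm^3), illustrative
    volumes = []
    for t in range(T):
        # Diameter of a sphere with volume v.
        diameter = 2.0 * (3.0 * v / (4.0 * np.pi)) ** (1.0 / 3.0)
        # Confounded assignment: treatment probability is a sigmoid of
        # the tumour diameter; gamma scales the degree of confounding.
        p_treat = 1.0 / (1.0 + np.exp(-gamma / d_max * (diameter - d_max / 2.0)))
        c = 5.0 if rng.random() < p_treat else 0.0  # chemo concentration
        d = 2.0 if rng.random() < p_treat else 0.0  # radiotherapy dose
        v = (1.0 + rho * np.log(K / v) - beta_c * c
             - (alpha_r * d + beta_r * d ** 2)) * v
        v = max(v, 1e-3)  # keep the volume positive
        volumes.append(v)
    return np.array(volumes)

vols = simulate_patient()
print(vols[:5])
```

Because larger tumours are treated more often, naive regression on data generated this way under-estimates treatment benefit, which is exactly the bias the experiments vary via $\gamma$.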

Benchmarks. We used the following benchmarks for performance comparison: Marginal Structural Models (MSMs) (robins2000marginal), which use logistic regression for estimating the IPTWs and linear regression for prediction (see Appendix G for details). We also compare against Recurrent Marginal Structural Networks (RMSNs) (lim2018forecasting), the current state-of-the-art model for estimating treatment responses; RMSNs use RNNs to estimate the IPTWs and the patient outcomes (details in Appendix H). To show that standard supervised learning models do not handle the time-varying confounders, we compare against an RNN and a linear regression model, which receive treatments and covariates as input to predict the outcome (see Appendix I for details). Our model architecture follows the description in Sections 4 and 5, with full training details and hyperparameter optimization in Appendix J. To show the importance of adversarial training, we also benchmark against CRN ($\lambda = 0$), a model with the same architecture but with $\lambda = 0$, i.e. our architecture without adversarial training.

6.1 Evaluating models on counterfactual predictions

Previous methods focused on evaluating the error only for factual outcomes, i.e. the observed patient outcomes (lim2018forecasting). However, to build decision support systems, we need to evaluate how well the models estimate counterfactual outcomes, i.e. patient outcomes under alternative treatment options. The parameters $\gamma_c$ and $\gamma_r$ control the treatment assignment policy, i.e. the degree of time-dependent confounding present in the data. We evaluate the benchmarks under different degrees of time-dependent confounding by varying these parameters. For each setting, we simulate 10,000 patients for training, 1,000 for validation (hyperparameter tuning) and 1,000 for out-of-sample testing. For the patients in the test set, at each time $t$ we also simulate the counterfactual outcomes, represented by the tumour volume $V(t+1)$, under all possible treatment options.

(a) One-step ahead prediction
(b) Five-step ahead prediction
Figure 4: Results for prediction of patient counterfactuals.

Figure 4 (a) shows the normalized root mean squared error (RMSE) for one-step-ahead estimation of counterfactuals with varying degrees of time-dependent confounding. The RMSE is normalized by the maximum tumour volume. The linear regression and MSM models provide a baseline for performance, as they achieve the highest RMSE: while the use of IPTWs in MSMs helps as the confounding increases, linear modelling has severe limitations. When there is no time-dependent confounding, the machine learning methods achieve similar performance, close to 0.6% RMSE. As the bias in the dataset increases, it becomes harder for the RNN and for CRN ($\lambda = 0$) to generalize to outcomes of treatments that do not match the training policy. For the highest degree of confounding, CRN improves markedly on the same model architecture without domain adversarial training, CRN ($\lambda = 0$).

Our proposed model achieves the lowest RMSE across all degrees of confounding, improving on RMSNs most at the higher settings. To highlight the gains of our method even for smaller amounts of confounding, Figure 4 (b) shows the RMSE for five-step-ahead prediction (with counterfactuals generated as described in Section 6.2 and Appendix L). RMSNs also use a decoder for sequence prediction. However, RMSNs require training additional RNNs to estimate the IPTWs, which are used to weight each sample during decoder training. For $\tau$-step-ahead prediction, IPTW involves multiplying $\tau$ weights, which can result in high variance. The results in Figure 4 (b) illustrate these problems with using IPTW to handle the time-dependent confounding bias. See Appendix K for more results on multi-step-ahead prediction.
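For reference, the normalized RMSE metric reported here can be computed as follows; the normalizing constant is an assumption, taken as the volume of a spherical tumour at the maximum diameter of 13 cm:

```python
import numpy as np

# Assumed normalising constant: volume of a sphere with the maximum
# tumour diameter of 13 cm, roughly 1150 cm^3.
V_MAX = (4.0 / 3.0) * np.pi * (13.0 / 2.0) ** 3

def normalized_rmse(y_true, y_pred, v_max=V_MAX):
    """RMSE normalised by the maximum tumour volume, as a percentage."""
    err = np.asarray(y_true) - np.asarray(y_pred)
    return 100.0 * np.sqrt(np.mean(err ** 2)) / v_max

print(normalized_rmse([100.0, 200.0], [100.0, 200.0]))  # 0.0
print(normalized_rmse([100.0, 200.0], [110.0, 190.0]))
```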

Balancing representation: To evaluate whether the CRN has indeed learnt treatment invariant representations, we illustrate in Figure 5 the t-SNE embeddings of the balancing representations $\Phi(\bar{\mathbf{H}}_t)$ built by the CRN encoder for test patients. We colour each point by the treatment received at timestep $t$ to highlight the invariance of the representation across the different treatments. In Figure 5(b), we show the embeddings only for chemotherapy and radiotherapy for clarity.

(a)
(b)
Figure 5: t-SNE embedding of the balancing representation $\Phi(\bar{\mathbf{H}}_t)$ learnt by the CRN encoder at different timesteps $t$. Notice that $\Phi(\bar{\mathbf{H}}_t)$ is not predictive of the treatment given at timestep $t$.

6.2 Evaluating recommendation of the right treatment and timing of treatment

Evaluating the models only in terms of RMSE on counterfactual estimation is not enough to assess their reliability when used as part of decision support systems. In this section we assess how well the models can select the correct treatment and timing of treatment for several forecasting horizons $\tau$. We generate test sets of 1,000 patients where, for each horizon $\tau$ and each time $t$ in a patient's trajectory, there are options for giving chemotherapy at one of the future timesteps within the horizon and options for giving radiotherapy at one of them; at the rest of the future timesteps, no treatment is applied. These treatment plans are assessed in terms of the tumour volume outcome $V(t+\tau)$. We select the treatment (chemotherapy or radiotherapy) that achieves the lowest $V(t+\tau)$ and, within the selected treatment, the timing with the lowest $V(t+\tau)$. We also compute the normalized RMSE for predicting $V(t+\tau)$. See Appendix L for more details about the test set. The models are evaluated for 3 settings of $\gamma_c$ and $\gamma_r$.
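Given predicted counterfactual volumes for each candidate plan, selecting the treatment and its timing reduces to nested argmins; a sketch with hypothetical predicted values:

```python
import numpy as np

# Hypothetical predicted final tumour volumes for one patient and one
# horizon: rows = {chemotherapy, radiotherapy}, columns = candidate
# treatment timings within the horizon.
pred_volumes = np.array([
    [5.1, 4.2, 4.8],   # chemotherapy at the 1st, 2nd, 3rd future timestep
    [4.5, 4.6, 5.0],   # radiotherapy at the 1st, 2nd, 3rd future timestep
])

# Pick the treatment whose best timing yields the lowest volume, then
# the best timing within that treatment.
best_treatment = int(np.argmin(pred_volumes.min(axis=1)))
best_timing = int(np.argmin(pred_volumes[best_treatment]))

print(best_treatment, best_timing)  # 0 1: chemotherapy at the 2nd timestep
```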

| Metric | τ | CRN (1) | RMSN (1) | MSM (1) | CRN (2) | RMSN (2) | MSM (2) | CRN (3) | RMSN (3) | MSM (3) |
|---|---|---|---|---|---|---|---|---|---|---|
| Normalized RMSE | 3 | 2.43% | 3.16% | 6.75% | 1.08% | 1.35% | 3.68% | 1.54% | 1.59% | 3.23% |
| | 4 | 2.83% | 3.95% | 7.65% | 1.21% | 1.81% | 3.84% | 1.81% | 2.25% | 3.52% |
| | 5 | 3.18% | 4.37% | 7.95% | 1.33% | 2.13% | 3.91% | 2.03% | 2.71% | 3.63% |
| | 6 | 3.51% | 5.61% | 8.19% | 1.42% | 2.41% | 3.97% | 2.23% | 2.73% | 3.71% |
| | 7 | 3.93% | 6.21% | 8.52% | 1.53% | 2.43% | 4.04% | 2.43% | 2.88% | 3.79% |
| Treatment accuracy | 3 | 83.1% | 75.3% | 73.9% | 83.2% | 78.6% | 77.1% | 92.9% | 87.3% | 74.9% |
| | 4 | 82.5% | 74.1% | 68.5% | 81.3% | 77.7% | 73.9% | 85.7% | 83.8% | 74.1% |
| | 5 | 73.5% | 72.7% | 63.2% | 78.3% | 77.2% | 72.3% | 83.8% | 82.1% | 72.8% |
| | 6 | 69.4% | 66.7% | 62.7% | 79.5% | 76.3% | 71.8% | 78.6% | 69.7% | 64.5% |
| | 7 | 71.2% | 68.8% | 62.4% | 72.7% | 71.8% | 71.6% | 71.9% | 69.3% | 61.2% |
| Treatment timing accuracy | 3 | 79.6% | 78.1% | 67.6% | 80.5% | 76.8% | 77.5% | 79.8% | 75.7% | 60.6% |
| | 4 | 73.9% | 70.3% | 63.1% | 79.0% | 77.2% | 73.4% | 75.4% | 71.4% | 58.2% |
| | 5 | 69.8% | 68.6% | 62.4% | 78.3% | 73.3% | 63.6% | 66.9% | 31.3% | 29.5% |
| | 6 | 66.9% | 66.2% | 62.6% | 73.5% | 72.1% | 63.9% | 65.8% | 24.2% | 15.5% |
| | 7 | 64.5% | 63.6% | 62.2% | 70.6% | 57.4% | 44.2% | 63.9% | 25.6% | 12.5% |

Table 1: Results for recommending the correct treatment and timing of treatment. Column groups (1), (2) and (3) correspond to the three settings of ($\gamma_c$, $\gamma_r$).

Table 1 shows the results for this evaluation set-up. The treatment accuracy denotes the percentage of patients for which the correct treatment was selected, while the treatment timing accuracy is the percentage for which the correct timing was selected. Note that, in the third setting, RMSN and MSM select the wrong treatment timing for most patients at the longer projection horizons. CRN performs consistently across the different policies present in the observational data and achieves the lowest RMSE and the highest accuracy in selecting the correct treatment and timing of treatment.

In Appendix M we also show the applicability of the CRN in more complex medical scenarios involving real data. We provide experimental results based on the Medical Information Mart for Intensive Care (MIMIC III) database (johnson2016mimic) consisting of electronic health records from patients in the ICU.

7 Conclusion

Despite its wide applicability, the problem of causal inference for time-dependent treatments has been studied relatively less than causal inference in the static setting. Both new methods and new theory are needed to harness the full potential of observational data for learning individualized effects of complex treatment scenarios. Further work in this direction is needed to propose alternative methods for handling time-dependent confounders, to model combinations of treatments assigned over time, and to estimate the individualized effects of time-dependent treatments with associated dosage.

In this paper, we introduced the Counterfactual Recurrent Network (CRN), a model that estimates individualized effects of treatments over time using a novel way of handling the bias from time-dependent confounders through adversarial training. Using a model of tumour growth, we validated CRN in realistic medical scenarios and showed improvements over existing state-of-the-art methods. We also demonstrated the applicability of CRN on a real dataset consisting of patient electronic health records. The counterfactual predictions of CRN have the potential to be used as part of clinical decision support systems to address relevant medical challenges, such as selecting the best treatments for patients over time, identifying optimal treatment timings, and determining when a treatment is no longer needed. In future work, we aim to build better balancing representations and to provide theoretical guarantees for the expected error on the counterfactuals.

Acknowledgments

We would like to thank the reviewers for their valuable feedback. The research presented in this paper was supported by The Alan Turing Institute, under the EPSRC grant EP/N510129/1 and by the US Office of Naval Research (ONR).

References

Appendix A Extended related work

Causal inference in the static setting: A large number of methods have been proposed for learning treatment effects from observational data in the static setting. In this case, one needs to adjust for selection bias: the bias caused by the fact that, in the observational dataset, the treatment assignments depend on the patient features. Approaches for handling selection bias include propensity matching (austin2011introduction; imai2014covariate; abadie2016matching), building representations in which the treated and untreated populations have similar distributions (johansson2016learning; shalit2017estimating; li2017matching; yao2018representation), and performing propensity-aware hyperparameter tuning (alaa2017bayesian; alaa2018limits). However, these methods for the static setting cannot be extended directly to time-varying treatments (hernan2000marginal; schisterman2009overadjustment).
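To make the static selection-bias problem concrete, the following toy example (entirely our own, with made-up propensities and outcomes, not from any of the cited works) shows how a naive comparison of treated and untreated patients is biased by a confounder, and how inverse propensity weighting corrects for it:

```python
import numpy as np

# Toy illustration of selection bias and its IPW correction.
# A single binary confounder x drives both treatment assignment and outcome.
rng = np.random.default_rng(1)

n = 100000
x = rng.binomial(1, 0.5, n)                      # confounder
e = np.where(x == 1, 0.8, 0.2)                   # propensity P(T=1 | x)
t = rng.binomial(1, e)                           # confounded treatment assignment
y = 2.0 * t + 3.0 * x + rng.normal(0, 0.1, n)    # true treatment effect = 2

# Naive difference of means is biased upward by the confounder.
naive = y[t == 1].mean() - y[t == 0].mean()

# Inverse propensity weighting recovers the true effect.
ipw = np.mean(t * y / e) - np.mean((1 - t) * y / (1 - e))
print(naive, ipw)
```

The naive estimate is close to 3.8 here because treated patients disproportionately have x = 1, while the IPW estimate recovers the true effect of 2.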

Learning optimal policies: A problem related to ours is learning optimal treatment policies from logged data (swaminathan2015batch; swaminathan2015self; atan2018learning), that is, learning the treatment option that yields the best reward. Note the difference from the causal inference setting considered in this paper, where the aim is to learn the counterfactual patient outcomes under all possible treatment options. Learning all of the counterfactual outcomes is a harder problem, and its solution can also be used to find the optimal treatment.

One method for learning optimal policies, proposed by atan2018learning, uses domain adversarial training to build a representation that is invariant across two domains: observational data and simulated randomized clinical trial data in which the treatments have equal probabilities. However, atan2018learning only considers the static setting and aims to choose the optimal treatment rather than estimate all of the counterfactual outcomes. In our paper, the aim is to eliminate the bias from time-dependent confounders and reliably estimate all of the potential outcomes; thus, at each timestep we build a representation that is invariant to the treatment.

Off-policy evaluation in reinforcement learning: In reinforcement learning, a problem similar to ours is off-policy evaluation, which uses retrospective observational data, also known as logged bandit feedback (hoiles2016non; puaduraru2013empirical; doroudi2017importance). In this case, the retrospective observational data consists of sequences of states, actions and rewards generated by an agent operating under an unknown policy. Off-policy evaluation methods aim to use this data to estimate the expected reward of a target policy. These methods use algorithms based on importance sampling (precup2000eligibility; thomas2015high; guo2017using), action-value function approximation (model-based) (hallak2015off), or a doubly robust combination of both approaches (jiang2015doubly). Nevertheless, these methods focus on obtaining the average rewards of policies, whereas our aim is to estimate individualized patient outcomes for future treatments.
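As a concrete illustration of the importance-sampling estimators mentioned above, the following sketch (a toy bandit example of our own, not from the cited works) re-weights rewards logged under a behaviour policy to estimate the value of a different target policy:

```python
import numpy as np

# Ordinary importance sampling for off-policy evaluation: actions in the
# log were drawn from behaviour policy mu; we estimate the value of target
# policy pi by re-weighting each observed reward by pi(a) / mu(a).
rng = np.random.default_rng(0)

n_actions = 2
mu = np.array([0.7, 0.3])    # behaviour policy that generated the log
pi = np.array([0.2, 0.8])    # target policy to evaluate

# Simulate logged bandit feedback (action, reward) pairs.
actions = rng.choice(n_actions, size=10000, p=mu)
true_reward = np.array([1.0, 2.0])            # assumed mean reward per action
rewards = true_reward[actions] + rng.normal(0, 0.1, size=actions.shape)

# Importance-sampling estimate of the target policy's expected reward.
weights = pi[actions] / mu[actions]
v_hat = np.mean(weights * rewards)

v_true = np.dot(pi, true_reward)              # ground-truth value of pi
print(v_hat, v_true)
```

The estimator is unbiased but its variance grows with the mismatch between the two policies, which motivates the doubly robust variants cited above.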

Appendix B Assumptions

The standard assumptions needed for identifying the treatment effects are (robins2008estimation; lim2018forecasting; schulam2017reliable):

Assumption 1: Consistency. If for a given patient, then the potential outcome for treatment is the same as the observed (factual) outcome: .

Assumption 2: Positivity (Overlap) (imai2004causal): If then for all .

Assumption 3: Sequential strong ignorability.

Assumption 2 means that, at each timestep, each treatment has a non-zero probability of being assigned. Assumption 3 means that there are no hidden confounders, that is, all of the covariates affecting both the treatment assignment and the outcomes are present in the observational dataset. Note that while Assumption 3 is standard across methods for estimating treatment effects, it is not testable in practice (robins2000marginal; pearl2009causal).

Appendix C Time-dependent Confounding

Figure 6 illustrates the causal graphs for time-varying exposures over two timesteps (robins2000marginal). In Figure 6 (a), the covariate is a time-dependent confounder because it affects the treatment assignments while, at the same time, its value is changed by past treatments (mansournia2017handling), as illustrated by the red arrows. Thus, the treatment probabilities at each timestep depend on the history of covariates and past treatments. Note that and are hidden variables which only affect the covariates, i.e. they do not have arrows into the treatments; thus, the no-hidden-confounders assumption (Assumption 3) is satisfied.

Figure 6 (b) and (c) illustrate the two cases where there is no bias from time-dependent confounding. In Figure 6 (b) the treatment probabilities are independent, while in Figure 6 (c) they depend only on past treatments.

Figure 6: Causal graphs for two-step time-varying exposures (robins2000marginal). are the patient covariates, are the treatments, and are unobserved variables, and is the outcome.

Marginal Structural Models (robins2000marginal). To remove the association between time-dependent confounders and time-varying treatments, Marginal Structural Models (MSMs) use inverse probability of treatment weighting (IPTW). Without loss of generality, consider the use of MSMs with univariate treatments, baseline variables and outcomes. The outcome after timesteps is parametrized as , where is usually a linear function with parameters . To remove the bias from the time-dependent confounders present in the observational dataset, in the regression model MSMs weight each patient using either stabilized weights:

(8)

or unstabilized weights:

(9)

where represents the conditional probability mass function for discrete treatments.

Inverse probability of treatment weighting (IPTW) creates a pseudo-population in which each member consists of themselves and (or ) copies added through weighting. In this pseudo-population, Robins (robins1999association) shows that does not predict treatment , thus removing the bias from time-dependent confounders.

When using unstabilized weights , the causal graph in the pseudo-population is the one in Figure 6 (b) where . On the other hand, when using stabilized weights , the causal graph in the pseudo-population is the one in Figure 6 (c) where .
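The two kinds of weights can be sketched as follows for a single patient trajectory with discrete treatments. The propensity values are made up for illustration: `p_cond` stands for the estimated probability of the observed treatment given the full history of past treatments and covariates (the denominator), and `p_marg` for the probability given past treatments only (the stabilizing numerator):

```python
import numpy as np

# Illustrative IPTW weights for one trajectory of three timesteps
# (toy propensity values, not estimated from data).
p_cond = np.array([0.8, 0.6, 0.7])   # P(A_t = a_t | past treatments, covariates)
p_marg = np.array([0.5, 0.55, 0.6])  # P(A_t = a_t | past treatments only)

w_unstab = np.prod(1.0 / p_cond)     # unstabilized weight: product of 1 / p_cond
w_stab = np.prod(p_marg / p_cond)    # stabilized weight: product of p_marg / p_cond
print(w_unstab, w_stab)
```

Stabilized weights typically have much smaller variance than unstabilized ones, which is why they are preferred in practice when fitting MSMs.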

Counterfactual Recurrent Networks. Instead of using IPTW, we propose building a representation of the history that is not predictive of treatment . At timestep , there are different possible treatments . We build a representation of the history of covariates and treatments that has the same distribution across the different possible treatments: . By breaking the association between past exposure and current treatments , we satisfy the causal graph in Figure 6 (b) and thus remove the bias from time-dependent confounders.
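Domain adversarial training of this kind is typically implemented with a gradient reversal layer. The minimal sketch below (our illustration, not the paper's TensorFlow implementation) shows its defining forward/backward behaviour: identity on the forward pass, gradient multiplied by a negative constant on the backward pass, so the representation is updated to worsen the treatment classifier's predictions:

```python
import numpy as np

# Minimal gradient reversal layer: the standard mechanism behind
# domain adversarial training of balancing representations.
class GradientReversal:
    def __init__(self, lam=1.0):
        self.lam = lam

    def forward(self, x):
        return x  # identity: representations pass through unchanged

    def backward(self, grad_output):
        return -self.lam * grad_output  # reversed, scaled gradient

grl = GradientReversal(lam=1.0)
x = np.array([1.0, -2.0, 3.0])
y = grl.forward(x)                 # identical to x
g = grl.backward(np.ones_like(x))  # each gradient component flipped in sign
print(y, g)
```

In a full model, this layer sits between the representation network and the treatment classifier, so minimizing the classifier loss simultaneously pushes the representation toward treatment invariance.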

Appendix D Proof of Theorem 1

We first prove the following proposition.

Proposition 1.

For fixed , let . Then the optimal prediction probabilities of are given by

(10)
Proof.

For fixed , the optimal prediction probabilities are given by

(11)

Maximising the value function pointwise and applying Lagrange multipliers, we get

(12)

Setting the derivative (w.r.t. ) to and solving for we get

(13)

where can now be solved for using the constraint to be . This gives the result. ∎

Proof.

(of Theorem 1) By substituting the expression from Proposition 1 into the minimax game defined in Eq. 6, the objective for becomes

(14)

We then note that

(15)
(16)
(17)
(18)

where is the Kullback-Leibler divergence and is the multi-distribution Jensen-Shannon divergence (li2018deep). Since is a constant, and the multi-distribution JSD is non-negative and zero if and only if all distributions are equal, we have that . ∎
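For intuition, the multi-distribution JSD with uniform weights can be computed for discrete distributions as the entropy of the mixture minus the mean of the individual entropies; a small numeric sketch (helper names are our own, and we assume uniform mixture weights for simplicity):

```python
import numpy as np

def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution."""
    p = np.asarray(p, dtype=float)
    return -np.sum(np.where(p > 0, p * np.log(p), 0.0))

def multi_jsd(dists):
    """Multi-distribution JSD with uniform weights:
    H(mean of the distributions) - mean of the entropies."""
    dists = np.asarray(dists, dtype=float)
    mixture = dists.mean(axis=0)
    return entropy(mixture) - np.mean([entropy(p) for p in dists])

equal = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]   # identical distributions
diff = [[0.9, 0.1], [0.5, 0.5], [0.1, 0.9]]    # distinct distributions
print(multi_jsd(equal), multi_jsd(diff))
```

The first call returns zero, matching the property used in the proof: the divergence vanishes exactly when all distributions are equal, which is the condition for a balanced representation.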

Appendix E Training procedure for CRN

Let be an observational dataset consisting of information about independent patients, which we use to train CRN. The encoder and decoder networks of CRN are trained in two separate steps.

First, the encoder is trained to build treatment-invariant representations of the patient history and to perform one-step-ahead prediction. After the encoder is optimized, we use it to compute the balancing representation for each timestep in the trajectory of patient . To train the decoder, we modify the training dataset as follows. For each patient , we split their trajectory into shorter sequences of timesteps of the form:

(19)

for . Thus, each patient contributes examples to the dataset for training the decoder. The sequences obtained for all patients are randomly grouped into minibatches and used for training.
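The splitting step can be sketched as follows (function and variable names are ours, and the sequence elements in the paper also carry treatments, outcomes and balancing representations rather than plain tokens):

```python
# Split one patient trajectory of length T into shorter training sequences
# for the decoder: one sequence per starting timestep, each at most
# max_horizon steps long, so a patient contributes several examples.
def split_trajectory(trajectory, max_horizon):
    sequences = []
    for t in range(len(trajectory)):
        sequences.append(trajectory[t:t + max_horizon])
    return sequences

traj = ["x1", "x2", "x3", "x4"]
print(split_trajectory(traj, 2))
```

For a trajectory of four timesteps and a horizon of two, this yields four overlapping sequences, with the final one shorter because it runs off the end of the trajectory.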

The pseudocode in Algorithm 1 shows the training procedure used for the encoder and decoder networks of CRN. The model was implemented in TensorFlow and trained on an NVIDIA Tesla K80 GPU. The Adam optimizer (kingma2014adam) was used for training, and both the encoder and the decoder are trained for 100 epochs.

  Input: Training data:
  
  (1) Encoder optimization: parameters .
  Learning rate:
  for  do
     
     for Batch in epoch do
        Compute
        Compute
        
        
        
     end for
  end for
  
  (2) Compute the encoder balanced representation and use it to initialize the decoder hidden state.
  for  do
     for  do
        
     end for
  end for
  
  (3) Split dataset in sequences of timesteps:
  
  (4) Optimize decoder: parameters
  Learning rate:
  for p = 1, …, max epochs do
     
     for Batch in epoch do
        Compute
        Compute
        
        
        
     end for
  end for
  
  Output: Trained CRN encoder (parameters ) and trained CRN decoder (parameters ).
Algorithm 1 Pseudo-code for training CRN

Appendix F Pharmacokinetic-Pharmacodynamic Model of tumour growth

To evaluate CRN on counterfactual estimation, we need access to the data-generating mechanism to build a test set that contains patient outcomes under all possible treatment options. For this purpose, we use the state-of-the-art pharmacokinetic-pharmacodynamic (PK-PD) model of tumour growth proposed by geng2017prediction and also used by lim2018forecasting for evaluating RMSNs. The PK-PD model characterizes patients suffering from non-small cell lung cancer and models the evolution of their tumour under the combined effects of chemotherapy and radiotherapy. In addition, the model includes different distributions of tumour sizes based on the cancer stage at diagnosis.

Model of tumour growth. The volume of the tumour days after diagnosis is modelled as follows:

(20)

where the parameters are sampled from the prior distributions described in (geng2017prediction) and is a noise term that accounts for randomness in the tumour growth.

To incorporate heterogeneity in patient responses due to, for instance, gender or genetic factors (bartsch2007genetic), the prior means for and are adjusted to create three patient subgroups, as described in lim2018forecasting. This way, we incorporate into the model of tumour growth specific characteristics that affect each patient's individualized response to treatments. Thus, the prior mean of and the prior mean of are adjusted as follows.

(21)

where and are the mean parameters from geng2017prediction, and and are the parameters used in the data simulation. The patient subgroup is used as a baseline feature.

The chemotherapy drug concentration follows an exponential decay with a half-life of 1 day:

(22)

where of Vinblastine if chemotherapy is given at time . fractions of radiotherapy if the radiotherapy treatment is applied at timestep .
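A one-day half-life at a daily timestep amounts to halving the concentration at each step before adding any new dose. A sketch of this recurrence, with an illustrative dose value of our own choosing (not necessarily the amount used in the simulation):

```python
# Exponential decay of drug concentration with a half-life of one day,
# stepped daily: c(t) = 0.5 * c(t-1) + dose(t).  The dose value below is
# purely illustrative.
def update_concentration(c_prev, dose):
    return 0.5 * c_prev + dose

c = 0.0
doses = [5.0, 0.0, 0.0, 5.0]   # chemotherapy given on days 1 and 4
for d in doses:
    c = update_concentration(c, d)
print(c)
```

After four days the residue of the first dose has been halved three times (5.0 → 0.625), and the second dose is added on top of it.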

Time-varying confounding is introduced by modelling chemotherapy and radiotherapy assignment as Bernoulli random variables, with probabilities and depending on the tumour diameter:

(23)

where is the average tumour diameter over the last days, is the maximum tumour diameter and is the sigmoid activation function. The parameters and are set to such that there is probability of receiving treatment when the tumour is half of its maximum size. controls the amount of time-dependent confounding; the higher is, the more important the history of the tumour diameter is in assigning treatments. Thus, at each timestep, there are four treatment options: no treatment (), chemotherapy (), radiotherapy (), and combined chemotherapy and radiotherapy ().
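The assignment mechanism can be sketched as follows. The function and parameter names are ours, `d_max` is an illustrative value, and the centring constant is chosen, as described above, so that in this sketch a tumour at half its maximum diameter receives treatment with probability 0.5:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def treatment_prob(avg_diameter, d_max, gamma):
    """Probability of assigning treatment, as a sigmoid of the normalized
    average tumour diameter.  gamma controls the amount of time-dependent
    confounding; the offset centres the probability at 0.5 when the tumour
    is half its maximum size."""
    return sigmoid(gamma * (avg_diameter - d_max / 2.0) / d_max)

d_max = 13.0  # illustrative maximum tumour diameter
print(treatment_prob(d_max / 2, d_max, gamma=5.0))  # 0.5 at half of max size
print(treatment_prob(0.0, d_max, gamma=0.0))        # 0.5 when gamma = 0 (random assignment)
```

Note how gamma = 0 makes the probability 0.5 regardless of tumour size (no confounding), while larger gamma values push the probability toward 0 or 1 depending on the diameter history.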

Since the work most relevant to ours is that of lim2018forecasting, we used the same data simulation and the same settings for as in their case. When , there is no time-dependent confounding and the treatments are randomly assigned. By increasing , we increase the influence of the volume size history (encoded in ) on the treatment probability. For example, assume . From equation (23), the probability of chemotherapy in this case is , where is the sigmoid function. When , ; when , ; and when , , in this example. can be increased further to increase the bias; however, the values used in the experiments evaluate the model on a wide range of settings for the time-dependent confounding bias.

Appendix G Marginal Structural Models

Marginal Structural Models (robins2000marginal; hernan2001marginal) have been widely used in epidemiology, particularly in follow-up studies. In our case, we would like to estimate the effects of a sequence of treatments in the future, given the current patient history: