Function Driven Diffusion for Personalized Counterfactual Inference

10/31/2016 ∙ by Alexander Cloninger, et al. ∙ Yale University

We consider the problem of constructing diffusion operators on high dimensional data X to address counterfactual functions F, such as individualized treatment effectiveness. We propose and construct a new diffusion metric K_F that captures both the local geometry of X and the directions of variance of F. The resulting diffusion metric is then used to define a localized filtration of F and answer counterfactual questions pointwise, particularly in situations such as drug trials where an individual patient's outcomes cannot be studied long term both taking and not taking a medication. We validate the model on synthetic and real world clinical trials, and create individualized notions of benefit from treatment.




1 Introduction

We address the problem of building a metric on a high dimensional data set that is smooth with respect to an external nonlinear function F. The types of functions we consider arise from counterfactual questions, such as “would this patient benefit or be hurt from medication, given their history and baseline health?”. In medical studies, these functions have the interesting feature that they cannot be evaluated pointwise, since an individual patient’s outcomes cannot be studied long term both taking and not taking a medication.

1.1 Mathematical Approach

In the case of a treatment study, we denote the risk on treatment as and treatment as . The quantity is known as the individual treatment effectiveness. The treatment group is denoted if is in treatment , otherwise. A current method of dealing with treatment effectiveness is the Cox proportional hazard model from [4]. In it, we let be the common baseline hazard function which describes the risk of an outcome at each time step independent of treatment. Within treatment groups, the hazard function for the Cox proportional hazard model takes the form

where is an indicator function for which treatment patient was in. This makes the associated survival distribution

Also assume there is a random censoring model, which means that people leave the trial at random times throughout the process. This means we don’t observe the true outcome time of each patient, but instead we observe the leave time , where and are independent and represents time to censorship. The indicator function of whether is known, as well.

A patient personalized version of this model would be

which now allows the benefit or detriment of the drug to be patient specific.

In this model of personalized risk, and are unknowable pointwise in a drug trial since each only takes one of the drugs in . So we estimate

in a neighborhood by assuming that locally, patients satisfy a proportional hazard model, with

for some metric , which we discuss further in Section 3. Thus we can assume that everyone in the neighborhood shares a common baseline risk, and the relative benefit of treatment is a constant multiple of that risk. This allows us to run a Cox proportional hazard model on by fitting to

Estimation of for each neighborhood can be done in several ways. If we only observe (i.e. whether or not the patient had an outcome before leaving the trial), then is estimated through method of moments between the two treatment groups. If we observe the actual outcome time along with , is estimated through partial likelihood maximization. As a note, while partial likelihood maximization uses more information and thus should result in a better estimate, convergence guarantees are more difficult to derive. We present certain guarantees for both approaches in Section 4.

This means reflects the amount a patient is positively or negatively affected by a drug, and can be used to approximate . The problem turns into a metric discovery problem of determining a metric that learns the level sets of . This is akin to finding pockets of people, based only on baseline information , that are at much higher risk (or lower risk) on drug than they are on drug . Discovery of the metric then allows for an analysis of “types” of responders and non-responders.

We view these types of functions as functions that can only be evaluated on large subsets of the data. In other words, for is only computable when . We build an algorithm which constructs a metric such that

for a small constant , where . In other words, the metric does not only consider the geometry of the space , but also the geometry and the properties of the function being studied.

The purpose of computing is two-fold:

  1. this discovers the intrinsic organization of which dictates changes in . This makes any subsequent clustering or analysis done using reflect the level sets of , as well as the intrinsic structure of . The reason for doing this is that may not be smooth with respect to the intrinsic geometry of the space, but has structure that is well described by a subset of the features.

  2. this allows for simple estimation of at a finer scale than is naively reliable. Using , we are able to construct an estimate of , which we call , which can be evaluated pointwise via

    One can also define a multi-scale decomposition of via

    The key in both these approximations is that, provided is smooth with respect to , the approximating neighborhood will have a large radius about level sets of . This increases the number of points in for a fixed , making the approximations more accurate than those generated by taking an isotropic ball of radius about .

1.2 Main Contributions

The study of individualized treatment effects has recently been considered with linear lasso models of [12], and linear logistic models with AdaBoost of [8]. A number of models have been built to predict outcomes from a single treatment, but high risk for an outcome does not necessarily imply treatment benefit, as seen in [7, 6]. While these models provide useful treatment recommendations, they project to a one dimensional function space and interpretability is limited to the non-zero coefficients of the model.

While we are interested in determining a treatment recommendation, we are also interested in the question of characterizing the level sets of a treatment effect. Diffusion embeddings provide a non-linear framework to map out the data into a continuum of varying treatment effectiveness. Using a diffusion metric, one can determine variability of types of patients that similarly benefit from treatment (or lack of treatment).

Function regularized diffusion has been considered when can be evaluated pointwise by [14]. We have also previously examined building non-linear features of functions that cannot be evaluated pointwise and subsequently organizing these features in [2], as well as regression of non-linear Cox proportional hazard functions in [9].

The main contributions of this work are:

  • the ability to build a function regularized diffusion metric without the ability to evaluate pointwise,

  • interpolation of a function regularized diffusion metric to new points where is unknown, and

  • the use of function regularized diffusion to define pockets of responders and non-responders to a given treatment.

This paper is organized as follows. Section 2 gives background descriptions of the tools we reference throughout the paper, including diffusion maps and hierarchical cluster trees. Section 3 details the function weighted trees used to generate , and introduces the notion of estimating a data point’s personalized function estimate. Section 4 discusses the guarantees that can be given for personalized treatment effect, as well as convergence rates. Section 5 applies and validates our algorithm on several datasets of synthetic patients, and discovers the original ground truth metric. We also examine the algorithm on real world patient data and examine validation schemes.

2 Background

In this section, we discuss previous research that considers organization of points. This considers as a data matrix of points and features per point. Denote the rows of by (the set of points), and the set of columns by (the set of features or questions). For this section, there is no external function being considered.

2.1 Diffusion Geometry

Diffusion maps is a manifold learning technique based on solving the heat equation on a data graph, as in [3]. It has been used successfully in a number of signal processing, machine learning, and data organization applications. We will briefly review the diffusion maps construction.

Let be a high dimensional dataset with . A data graph is constructed with each point as a node and edges between two nodes with weights . The affinity matrix is required to be symmetric and non-negative. Common choices of kernel are the Gaussian

or positive correlation

can be computed using only nearest neighbors of such that .

Let . We normalize the kernel to create a Markov chain probability transition matrix

The eigendecomposition of yields a sequence of eigenpairs such that

The diffusion distance measures the distance between two points as the probability of points transitioning to a common neighborhood in some time . This gives

Retaining only the first eigenvectors creates an embedding such that

Figure 1 shows a two dimensional example dataset and the data graph generated on the points. We also see the low frequency eigenfunctions on the data graph, and the diffusion embedding.


(a) Data Graph colored by x-coordinate (b) Data graph colored by
(c) Data graph colored by (d) colored by x-coordinate
Figure 1: Toy example to demonstrate the relationship between the geometry of a dataset and the eigenfunctions of its graph. The dataset is only 2D for ease of visualization; the algorithm is equally valid on high-dimensional datasets.

Remark: The diffusion time is a continuous variable, which can be thought of as the degree to which is a low-pass filter. For small , more of the high-frequency eigenfunctions are given non-trivial weight. For large , the embedding is mostly concentrated on the low-frequency eigenfunctions that vary slowly across the data.
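The construction reviewed in this section can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the Gaussian bandwidth `sigma`, the function name `diffusion_map`, and the random test data are all hypothetical, and the eigenproblem is solved on the symmetric conjugate of the Markov matrix for numerical stability.

```python
import numpy as np

def diffusion_map(X, sigma=1.0, n_components=2, t=1.0):
    """Embed the rows of X with a Gaussian-kernel diffusion map (sketch)."""
    # Pairwise squared Euclidean distances between points.
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    # Symmetric, non-negative affinity matrix (Gaussian kernel; assumed form).
    K = np.exp(-sq_dists / sigma ** 2)
    # Row sums give the degree of each node in the data graph.
    d = K.sum(axis=1)
    # Symmetric conjugate of the Markov matrix P = D^{-1} K (same eigenvalues).
    A = K / np.sqrt(np.outer(d, d))
    eigvals, eigvecs = np.linalg.eigh(A)
    # Sort eigenpairs by decreasing eigenvalue.
    order = np.argsort(eigvals)[::-1]
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]
    # Right eigenvectors of P are D^{-1/2} times the eigenvectors of A.
    psi = eigvecs / np.sqrt(d)[:, None]
    # Drop the trivial constant eigenvector; weight the rest by eigenvalue^t.
    embedding = psi[:, 1:n_components + 1] * eigvals[1:n_components + 1] ** t
    return eigvals, embedding

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))          # hypothetical toy data
eigvals, emb = diffusion_map(X, sigma=2.0, n_components=2, t=1.0)
```

Euclidean distance between rows of `emb` then approximates the diffusion distance at time `t`, truncated to the retained eigenvectors. Raising `t` shrinks the contribution of the smaller eigenvalues, which matches the low-pass interpretation in the remark above.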

2.2 Hierarchical Tree From Diffusion Distance

The main idea behind bigeometric organization is to construct a coupled geometry via a partition tree on both the data points and the features. A partition tree is effectively a set of increasingly refined partitions, in which finer child partitions (i.e. lower levels of the tree) are splits of the parent folder which attempt to minimize the inter-folder variability.

Let be a dataset of points, and be a diffusion embedding with corresponding diffusion distance . A partition tree on is a sequence of tree levels , . Each level consists of disjoint sets such that

Also, we define subfolders (or children) of a set to be the indices such that

For notation, and . See Figure 2 for a visual breakdown of .

This tree can be constructed in two ways:

  1. Top-down: Taking the embedded points , the initial split divides the data into 2 (or k) clusters via k-means or some clustering algorithm. Each subsequent folder is then split into 2 (or k) clusters in a similar way, until each folder contains a singleton point.

  2. Bottom-up: Taking the embedded points , the bottom folders are determined by choosing a fixed radius and covering with balls of radius . Each subsequent level of the tree is then generated as combinations of the children nodes that are “closest” together under the distance .

Figure 2: Breakdown of into folders.

Remark: It is important to note that whether one chooses a top-down or bottom-up approach, the fact that the clustering occurs on the diffusion embedding makes the resulting tree, by definition, a “bottom-up” geometry. This is because the embedding and diffusion distance are built off of local similarities alone, meaning that the resultant geometry and partition tree are based on properties of the underlying dataset and manifold rather than the ambient dimension and a naive Euclidean distance in .
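As an illustration of the top-down construction, the sketch below recursively splits embedded points into two folders. It assumes a small deterministic 2-means as the clustering step; the text allows any clustering algorithm, and the names `two_means`, `top_down_tree`, and `leaves` are hypothetical.

```python
import numpy as np

def two_means(points, n_iter=20):
    """Split points into two clusters with a small deterministic 2-means."""
    # Initialize at the point farthest from the mean, and the point farthest from it.
    c0 = points[np.argmax(np.linalg.norm(points - points.mean(axis=0), axis=1))]
    c1 = points[np.argmax(np.linalg.norm(points - c0, axis=1))]
    centers = np.stack([c0, c1])
    for _ in range(n_iter):
        dists = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=-1)
        labels = np.argmin(dists, axis=1)
        for k in range(2):
            if np.any(labels == k):
                centers[k] = points[labels == k].mean(axis=0)
    return labels

def top_down_tree(points, indices=None, min_size=1):
    """Return a nested dict; each folder stores its point indices and children."""
    points = np.asarray(points, dtype=float)
    if indices is None:
        indices = np.arange(len(points))
    node = {"indices": indices.tolist(), "children": []}
    if len(indices) <= min_size:
        return node
    labels = two_means(points[indices])
    if labels.min() == labels.max():      # degenerate split, e.g. duplicate points
        return node
    for k in range(2):
        node["children"].append(top_down_tree(points, indices[labels == k], min_size))
    return node

def leaves(node):
    """Collect the index lists of the finest folders."""
    if not node["children"]:
        return [node["indices"]]
    return [leaf for child in node["children"] for leaf in leaves(child)]

rng = np.random.default_rng(1)
embedded = rng.normal(size=(20, 2))       # stand-in for diffusion coordinates
tree = top_down_tree(embedded)
finest = leaves(tree)
```

Every level of the resulting tree is a partition of the points, so the leaves together cover each index exactly once.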

3 Weighted and Directional Trees Without Pointwise Function Evaluation

Let us denote our data space . In its most general form, we have an external function which cannot be evaluated pointwise. can only be evaluated on large subsets . We define the pointwise estimate of to be , as in (2). This is done by defining a type of locally weighted distance to incorporate estimates of a feature’s power to discriminate in different half spaces. The details on this method are in Section 3.1. An overview of the approach is in Algorithm 1.

Required: Training points
Function to be evaluated on
Result: such that range is large, where
  1. Build a diffusion embedding of the points and a hierarchical tree

  2. Build a tree that determines the local coordinate feature weights (see Section 3.1)

  3. Build a new diffusion embedding of the points and a hierarchical tree based on the kernel in (2)

  4. Iterate between the points and the features until embedding and tree are stable

  5. Define pointwise neighborhood and function estimate

Algorithm 1 Calculate Function Weighted Metric

3.1 Weighted Trees

Let be the data matrix and be the integral operator of interest. Denote the rows of by (the set of points), and the set of columns by (the set of features or questions). We wish to build feature weights on each folder of that maximally separate . The algorithm is as follows:

  1. Assume the tree is known and separates into hierarchical nodes. Fix a node .

  2. For each element , we split into intervals and bin the elements of such that

    Question is then assigned a local weight for its ability to discriminate by


    where is the weighted mean across all bins.

  3. Now that every node of the tree has local feature weights, we calculate the local weights at a point by


    These weights create a diagonal matrix where for a small positive constant .

  4. The kernel function is then


    The normalization in the denominator is needed to guarantee is positive semi-definite.
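The binning step can be illustrated with a short sketch. The weight formula itself did not survive in this copy of the text, so the version below is an assumption: it bins one feature into equal-count bins, evaluates a set function on each bin, and scores the feature by the size-weighted variance of the bin values around their weighted mean. The names `bin_feature_weight` and `outcome_rate` are hypothetical.

```python
def bin_feature_weight(feature_values, evaluate_on_subset, n_bins=4):
    """Hypothetical local weight for one feature inside one tree folder."""
    n = len(feature_values)
    # Sort point indices by feature value (stable sort keeps ties in order).
    order = sorted(range(n), key=lambda i: feature_values[i])
    # Equal-count bins over the sorted indices (an assumption; the text bins by interval).
    bins = [order[b * n // n_bins:(b + 1) * n // n_bins] for b in range(n_bins)]
    bins = [b for b in bins if b]
    # The function is only evaluated on subsets, never pointwise.
    values = [evaluate_on_subset(b) for b in bins]
    sizes = [len(b) for b in bins]
    mean = sum(s * v for s, v in zip(sizes, values)) / n
    # Size-weighted spread of the bin values around the weighted mean.
    return sum(s * (v - mean) ** 2 for s, v in zip(sizes, values)) / n

outcomes = [0, 0, 0, 0, 1, 1, 1, 1]

def outcome_rate(indices):
    """Toy set function: fraction of outcomes in the subset."""
    return sum(outcomes[i] for i in indices) / len(indices)

informative = [0, 1, 2, 3, 4, 5, 6, 7]    # tracks the outcome
uninformative = [0, 1, 0, 1, 0, 1, 0, 1]  # unrelated to the outcome
w_informative = bin_feature_weight(informative, outcome_rate, n_bins=2)
w_uninformative = bin_feature_weight(uninformative, outcome_rate, n_bins=2)
```

In this toy case the informative feature receives weight 0.25 and the uninformative one receives weight 0, which is the qualitative behavior the local weights are meant to have.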

Theorem 1.

The kernel from (2) is positive semi-definite.

The proof of Theorem 1 is in Appendix A.

Because is positive semi-definite, we can compute the embedding of the data , and induce a new diffusion metric on the data,

This, in turn, allows us to define as an estimate to , where

3.2 Interpolation and Leave Out Validation

The metric and function estimate can easily be extended to new points not in the training data. This is done by building an asymmetric affinity matrix to the training data, which can be thought of as a reference set. The approach is an application of [10], which we briefly outline here.

Let be training data, and defined on subsets of . Let be testing points on which is not defined a priori. Define to be

With the normalization matrices and , we set

and take the eigendecomposition of the small matrix . This gives an embedding of the reference points . The eigendecomposition of the entire set of points is then estimated by . The details of the extension algorithm can be found in Algorithm 2.

Required: Training points
Function to be evaluated on
Testing points
Result: Pointwise function estimate on testing points
  1. Build stable function weighted diffusion embedding of the features and a hierarchical tree via Algorithm 1 using the training points

  2. Build a function weighted diffusion embedding using the weighted embedding via reference set algorithm (see Section 3.2)

  3. For each testing point , define the pointwise training neighborhood and function estimate

Algorithm 2 Nearest Neighbor Function Estimation

Algorithm 2 can be thought of as generating an optimized metric for a k-nearest neighbor search. There could be better mechanisms of classification and regression for predicting , ranging from support vector machines in [1], to various types of linear regression, such as Elastic Net from [16]. These choices are application and function specific, which is why we remain with a simple nearest neighbor interpolation. The key is that the metric agrees with the intrinsic geometry of the data.
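The nearest neighbor step of Algorithm 2 can be sketched as follows, assuming the embedding coordinates have already been computed. The set function `outcome_rate_gap` is a hypothetical stand-in for the neighborhood-level function evaluation, and all data values are illustrative.

```python
import math

def knn_neighborhood(train_coords, query, k):
    """Indices of the k training points closest to `query` in the embedding."""
    ranked = sorted(range(len(train_coords)),
                    key=lambda i: math.dist(train_coords[i], query))
    return ranked[:k]

def outcome_rate_gap(indices, treated, outcome):
    """Hypothetical set function: untreated minus treated outcome rate."""
    on = [outcome[i] for i in indices if treated[i]]
    off = [outcome[i] for i in indices if not treated[i]]
    if not on or not off:
        return float("nan")   # undefined when one treatment group is missing
    return sum(off) / len(off) - sum(on) / len(on)

# Toy embedding coordinates for six training patients.
train_coords = [(0.0, 0.0), (0.1, 0.0), (0.2, 0.0),
                (5.0, 0.0), (5.1, 0.0), (5.2, 0.0)]
treated = [1, 0, 1, 0, 1, 0]
outcome = [0, 1, 0, 1, 1, 1]

# A new test point is assigned the estimate computed on its training neighborhood.
neighbors = knn_neighborhood(train_coords, (0.1, 0.0), k=3)
estimate = outcome_rate_gap(neighbors, treated, outcome)
```

The test point never contributes its own outcome; only the training neighborhood is used, which is what makes the estimate usable for patients whose outcomes are unknown.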

It is also important to run leave out validation of the algorithm to ensure no overfitting of the data. As the algorithm is semi-supervised and weights variables according to their discriminatory power, it is possible to give high weight to features which are spuriously correlated with the function. This makes it crucial to run N-fold cross validation of the data to ensure that the predicted are good estimates to the true function.

4 Localized Hazard Ratio Estimation

The aim of the construction of the metric from Section 3 is to construct a metric that differentiates level sets of the treatment risk. This implies that, in a neighborhood , we can assume that . This implies we can estimate the local treatment effect from the model


for . This approximation is because does not vary more than in .

Now assume we fit the false model to simply estimate treatment effectiveness, which is necessary given no knowledge of the model assumptions for locally. What we can assume is that, given , cannot vary too much within a small neighborhood.

There are two regimes in which we study this question of estimating . In either situation, we have a censoring model and observe an outcome only if . When that’s the case, we denote , with otherwise.

If we only observe (i.e. whether or not the patient had an outcome before leaving the trial), then is estimated through method of moments between the two treatment groups. If we observe the actual outcome time along with , is estimated through partial likelihood maximization. We provide results for both, with the stronger and more concrete results coming for observation only of the binary outcome variable .

In both settings, we assume that the patient risk over time is dictated by (4). We also assume that there is some censoring model for each patient, which is a random variable independent of (4) that dictates when a patient decides to leave the trial, if they are still alive.

4.1 Binary Observation of Outcome with Censoring

In this scenario, we only observe for each patient. This means we know whether they had an outcome before they left the trial, but not the time at which the outcome occurred.

To create an estimate , we use a method of moments approach. That is, within the neighborhood , we look at the empirical estimate


We borrow and modify results from [5] about small variation of misspecified models. For notation, let

Theorem 2.

Let the survival model satisfy (4), and the for . Assume we use a method of moments estimation of the misspecified model

Let us further assume we only observe an indicator of outcome .

Then the method of moments estimate converges at a rate of to for some finite constant that depends on , , , the non-treated risk model, and the censoring model. The estimate converges to , which satisfies

Moreover, if is well approximated locally by its first order Taylor expansion for small , then we can reduce the term to further show



is a constant that depends only on the size of and the censoring model. Note also that this implies , , and .

The proof of Theorem 2 is in Appendix B.
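To make the binary-outcome setting concrete, here is a minimal moment-matching sketch. It is not the estimator (5), whose exact form is not reproduced in this copy. It assumes a simplified setting in which both groups are followed for the same fixed time and hazards are proportional, so that the survival fractions satisfy S_B = S_A^γ and the hazard ratio follows from the two observed event fractions. The function name is hypothetical.

```python
import math

def moment_hazard_ratio(events_a, events_b):
    """Hazard ratio of group B relative to group A from binary outcomes.

    Assumes proportional hazards and a common fixed follow-up time,
    so that 1 - p_B = (1 - p_A) ** gamma.
    """
    p_a = sum(events_a) / len(events_a)
    p_b = sum(events_b) / len(events_b)
    if not (0 < p_a < 1 and 0 < p_b < 1):
        raise ValueError("event fractions must be strictly between 0 and 1")
    return math.log(1 - p_b) / math.log(1 - p_a)

# Toy neighborhood: half of group A and three quarters of group B have an outcome.
events_a = [1, 0, 1, 0]
events_b = [1, 1, 1, 0]
gamma_hat = moment_hazard_ratio(events_a, events_b)
```

The estimate is undefined when either group has no events or only events, which is one reason neighborhoods need enough patients from both treatment groups.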

4.2 Continuous Time to Outcome with Censoring

In this scenario, we observe for each patient, as well as the actual outcome and/or censoring time . This means we know whether they had an outcome before they left the trial, as well as the time that the outcome occurred. That time is an additional source of information, given that we can now attempt to partially order all patients that had an outcome, and ensure that people who were censored at time are estimated to live at least that long (if not longer).

To create an estimate , we use partial likelihood maximization. That is, we construct the log likelihood function


where , and is the argument of the exponential evaluated for patient . In the case of the misspecified model, the argument used is , and the true model is .
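For readers who want to see partial likelihood maximization in code, the sketch below uses the standard Cox partial log-likelihood for a single binary treatment indicator (Breslow form under ties) and maximizes it with a golden-section search, which is valid because the function is concave in the coefficient. The function names, search bounds, and toy data are assumptions for illustration.

```python
import math

def cox_partial_loglik(beta, times, events, z):
    """Standard Cox partial log-likelihood for one covariate z."""
    ll = 0.0
    for i, (t_i, d_i) in enumerate(zip(times, events)):
        if not d_i:
            continue                      # censored patients only enter risk sets
        # Risk set: everyone still under observation at time t_i.
        risk = sum(math.exp(beta * z[j]) for j, t_j in enumerate(times) if t_j >= t_i)
        ll += beta * z[i] - math.log(risk)
    return ll

def fit_beta(times, events, z, lo=-5.0, hi=5.0, tol=1e-8):
    """Maximize the concave partial log-likelihood by golden-section search."""
    phi = (math.sqrt(5) - 1) / 2
    a, b = lo, hi
    while b - a > tol:
        c = b - phi * (b - a)
        d = a + phi * (b - a)
        if cox_partial_loglik(c, times, events, z) < cox_partial_loglik(d, times, events, z):
            a = c
        else:
            b = d
    return (a + b) / 2

# Toy neighborhood: observed times, outcome indicators, treatment indicators.
times = [1, 2, 3, 4, 5, 6]
events = [1, 1, 1, 1, 0, 1]
z = [1, 0, 1, 1, 0, 0]
beta_hat = fit_beta(times, events, z)
```

Here `math.exp(beta_hat)` plays the role of the local hazard ratio for the neighborhood. If one treatment group has all the earliest events, the maximizer runs to the search boundary, so the bounds `lo` and `hi` act as a crude safeguard.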

We borrow results from [5] about small variation of misspecified models, and [15] and [13] about convergence rates. For notation, let

Theorem 3 ([5]; restated).

Let the survival model satisfy (4), and the for . Assume we use a partial likelihood estimation of the misspecified model


Let us further assume we observe both an indicator of the outcome and an outcome/censoring time . Also, let be the final time at which patients are observed (i.e. our censoring model censors anyone that lives past time ).

Then the partial likelihood estimate converges at a rate of to for some finite constant that depends on , , , the non-treated risk model, and the censoring model, as shown by [15] and [13]. The estimate converges to , which satisfies

as shown by [5].

Moreover, if is well approximated locally by its first order Taylor expansion for small , then we can reduce the log term further to show

where is the derivative of the log likelihood (6) with respect to .


The only part of Theorem 3 that is not restated from [5] is the convergence rate. [15] shows that, given a correct model for as in (4), the partial likelihood maximization estimate satisfies

where convergence is in distribution. Furthermore, [13] shows that the same rate of convergence applies to a misspecified model (7), with the exception that converges to an a priori unknown value . The rest of the proof focuses on characterizing in terms of known quantities, as done by [5]. ∎

5 Examples

5.1 Synthetic Randomized Drug Trial

We create a model of synthetic patients in a drug trial. The patient baseline model consists of 9 dimensions of correlated information, where

We note that this model choice is arbitrary, and was solely chosen to model a dependence between patient features. The patients are randomly split into treatment A and treatment B.

The baseline hazard function for the patients is a Weibull distribution of the type

for and . If a patient is in treatment A, their outcome time is sampled from the Weibull distribution. If a patient is in treatment B, their outcome time is sampled from the Weibull distribution and then adjusted by . Any patient is censored if for a fixed .

The key behind this model is that is patient specific, and depends only on a subset of the patient’s baseline information. Specifically,

where is displayed in Figure 3. Setting and , about of patients have an outcome.
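A synthetic trial of this kind can be simulated with inverse-transform sampling. The sketch below is illustrative only: the Weibull shape and scale, the censoring time, the three standard-normal baseline features, and the hazard-ratio function are all hypothetical placeholders, since the specific values are not recoverable from this copy. It assumes proportional hazards, so a patient with hazard ratio γ has survival function exp(-γ (t/scale)^shape).

```python
import math
import random

def simulate_trial(n, shape, scale, censor_time, hazard_ratio, seed=0):
    """Simulate (features, in_treatment_b, observed_time, had_outcome) records."""
    rng = random.Random(seed)
    records = []
    for _ in range(n):
        x = [rng.gauss(0, 1) for _ in range(3)]   # hypothetical baseline features
        in_b = rng.random() < 0.5                 # randomized treatment assignment
        gamma = hazard_ratio(x) if in_b else 1.0  # treatment A keeps the baseline hazard
        u = 1.0 - rng.random()                    # uniform on (0, 1]
        # Invert S(t) = exp(-gamma * (t / scale) ** shape) at level u.
        t = scale * (-math.log(u) / gamma) ** (1.0 / shape)
        records.append((x, in_b, min(t, censor_time), t <= censor_time))
    return records

# Hypothetical personalized hazard ratio depending on a single feature.
trial = simulate_trial(n=500, shape=1.5, scale=10.0, censor_time=8.0,
                       hazard_ratio=lambda x: math.exp(x[0]))
event_fraction = sum(1 for _, _, _, had in trial if had) / len(trial)
```

Changing `censor_time` controls the fraction of patients with an observed outcome, which is the quantity tuned in the experiments.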

Figure 3:

A linear Cox proportional hazard model, by definition, is unable to recover the full spread of . But beyond that, in this case a linear model fails to even recover the treatment group as a significant factor in risk. See Table 1 for the regression coefficients.

Variable Name Treatment
Coefficient 0.0597 -1.8517 -1.6500 -1.8440 -0.1797 -0.3245 -0.1619 -0.1127 0.0566 0.0954
p-value 0.3155 0.0001 0.0001 0.0001 0.3766 0.1087 0.4295 0.5826 0.7854 0.6450
Table 1:

Our algorithm on the other hand is able to recover both the geometry of the patients and an estimate of the personalized hazard ratio . Figure 4 shows the recovered embedding of the patients, and is colored by the estimate of .

(a) Predicted HR (b) Ground Truth HR
Figure 4:

The model works for predictive personalized hazards on new testing data, as well. We run repeated random sub-sampling validation on the toy data by retaining of the patients for training, and testing on the remaining . This was iterated times. The results are shown in Figure 5.

Figure 5:

5.2 Models with Treatment Propensity

Let , with

The baseline hazard function is the same as in Section 5.1, and the personal hazard ratio is

However, unlike in the previous examples, is not randomized across the population. Instead,

where , , and .

Fraction Treated in Neighborhood Ground Truth Estimated
Figure 6: Black corresponds to points where of the points were from the same treatment group, and thus removed due to lack of estimate precision.

We also consider a random model, in which is a random symmetric positive definite matrix with condition number less than 10, and the personal hazard ratio follows the form

where and are sparse standard normal random variables which are non-zero with probability . Also, (resp. ) are standard normal random variables which are non-zero if and only if and (resp. and ) are non-zero. Also, the probability that a patient is treated is determined by

We run this model across 100 iterations, where we generate patients whose baseline hazard function is drawn from a Weibull distribution as in previous examples. The patients are then censored such that fraction of the patients have an outcome, where is a uniform random variable drawn from . We calculate the correlation between the predicted personalized treatment effect and the ground truth treatment effect . Because of the propensity for treatment, we only estimate in neighborhoods such that of the patients are in the same treatment group. The histogram of correlations is in Figure 7.

Figure 7: Histogram of Correlations between predicted personalized treatment effect and ground truth across 100 iterations.

5.3 Real World Data

We examine our algorithm on breast cancer data from the Rotterdam Tumor Bank. The Rotterdam Tumor Bank dataset contains records for 1,546 patients with node-positive breast cancer, and nearly 90 percent of the patients have an observed outcome. Because this data has no ground truth, we must use leave out cross-validation to validate the recommendations. We train the model on a random of the patients, and test on , and we iterate this process 100 times. We then split the testing data into three groups, where is the standard deviation of treatment recommendations:

  • Recommended: People with recommendation such that and or and . These are people whose actions followed the recommendation.

  • Neutral: People with recommendation . These are people without a strong recommendation.

  • Anti-Recommended: People with recommendation such that and or and . These are people whose actions went against the recommendation.

We then plot the survival curves of the Recommended and Anti-Recommended groups in Figure 8. Again, these were all testing samples in order to avoid over-fitting to the outcomes. The group of testing data patients that followed the recommendations lived significantly longer than those that did not follow the recommendation, with a p-value of for the log-rank test of whether these curves are significantly different.
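The log-rank comparison of two survival curves can be sketched with the standard two-group statistic, using the fact that the survival function of a chi-square variable with one degree of freedom is erfc(sqrt(x/2)). The data below are toy values, not the Rotterdam data, and the function name is hypothetical.

```python
import math

def logrank_test(times_1, events_1, times_2, events_2):
    """Two-group log-rank test; returns (chi-square statistic, p-value)."""
    data = [(t, e, 0) for t, e in zip(times_1, events_1)]
    data += [(t, e, 1) for t, e in zip(times_2, events_2)]
    event_times = sorted({t for t, e, _ in data if e})
    obs_minus_exp, var = 0.0, 0.0
    for t in event_times:
        # Patients still at risk in each group just before time t.
        n1 = sum(1 for s, _, g in data if s >= t and g == 0)
        n2 = sum(1 for s, _, g in data if s >= t and g == 1)
        # Outcomes occurring exactly at time t.
        d1 = sum(1 for s, e, g in data if s == t and e and g == 0)
        d2 = sum(1 for s, e, g in data if s == t and e and g == 1)
        n, d = n1 + n2, d1 + d2
        obs_minus_exp += d1 - d * n1 / n
        if n > 1:
            var += d * (n1 / n) * (n2 / n) * (n - d) / (n - 1)
    chi2 = obs_minus_exp ** 2 / var if var > 0 else 0.0
    return chi2, math.erfc(math.sqrt(chi2 / 2))

# Toy groups: the first group has uniformly earlier outcomes than the second.
early = [1, 2, 3, 4, 5]
late = [6, 7, 8, 9, 10]
chi2, p_value = logrank_test(early, [1] * 5, late, [1] * 5)
same_chi2, same_p = logrank_test(early, [1] * 5, early, [1] * 5)
```

In practice the two inputs would be the held-out Recommended and Anti-Recommended groups, with censored patients passed with an event indicator of 0.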

Treatment survival curves
Recommendation survival curves
Figure 8: Treatment recommendation on Rotterdam breast cancer testing data for .

6 Conclusions and Future Work

This paper develops a method for building a data dependent metric that simultaneously learns the level sets of a function . The method only needs to evaluate on various half spaces of the data, making it useful when cannot be evaluated pointwise. We develop a weighted tree distance to accomplish this, and develop feature weights at multiple scales and locations in the data. Once has been discovered, we can do k nearest neighbor prediction for new points added to the data without knowledge of at the point.

This algorithm was designed with medical applications in mind, specifically building a local Cox proportional hazard model for patients in a dataset. The embedding created by can be used to characterize types of people that are hurt or helped by a drug, and even assign a personalized treatment hazard score to new patients whose outcomes are unknown.


The author would like to thank Raphy Coifman and Jonathan Bates for many discussions about the problem, and Harlan Krumholz and Shu-Xia Li for introducing the issues associated with treatment effectiveness and Cox proportional hazard models. The author is supported by NSF Award No. DMS-1402254.


  • [1] Bernhard E. Boser, Isabelle M. Guyon, and Vladimir N. Vapnik. A training algorithm for optimal margin classifiers. Proceedings of the Fifth Annual Workshop on Computational Learning Theory, COLT ’92, pages 144–152, New York, NY, USA, 1992. ACM.
  • [2] Alexander Cloninger, Ronald R Coifman, Nicholas Downing, and Harlan M Krumholz. Bigeometric organization of deep nets. Applied and Computational Harmonic Analysis, 2016.
  • [3] Ronald R Coifman and Stéphane Lafon. Diffusion maps. Applied and computational harmonic analysis, 21(1):5–30, 2006.
  • [4] D.R. Cox. Regression models and life-tables. Journal of the Royal Statistical Society. Series B, 1972.
  • [5] Mitchell H Gail, S Wieand, and Steven Piantadosi. Biased estimates of treatment effect in randomized experiments with nonlinear regressions and omitted covariates. Biometrika, pages 431–444, 1984.
  • [6] Holly Janes, Marshall D Brown, Ying Huang, and Margaret S Pepe. An approach to evaluating and comparing biomarkers for patient treatment selection. The international journal of biostatistics, 10(1):99–121, 2014.
  • [7] Holly Janes, Margaret S Pepe, Patrick M Bossuyt, and William E Barlow. Measuring the performance of markers for guiding treatment decisions. Annals of internal medicine, 154(4):253–259, 2011.
  • [8] Chaeryon Kang, Holly Janes, and Ying Huang. Combining biomarkers to optimize patient treatment recommendations. Biometrics, 70(3):695–707, 2014.
  • [9] Jared Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger. Deep survival: A deep Cox proportional hazards network. 2016.
  • [10] Dan Kushnir, Ali Haddad, and Ronald R Coifman. Anisotropic diffusion on sub-manifolds with application to earth structure classification. Applied and Computational Harmonic Analysis, 32(2):280–294, 2012.
  • [11] Kaare Brandt Petersen, Michael Syskind Pedersen, et al. The matrix cookbook. Technical University of Denmark, 7:15, 2008.
  • [12] Min Qian and Susan A Murphy. Performance guarantees for individualized treatment rules. Annals of statistics, 39(2):1180, 2011.
  • [13] Cynthia A Struthers and John D Kalbfleisch. Misspecified proportional hazard models. Biometrika, 73(2):363–369, 1986.
  • [14] Arthur D Szlam, Mauro Maggioni, and Ronald R Coifman. Regularization on graphs with function-adapted diffusion processes. The Journal of Machine Learning Research, 9:1711–1739, 2008.
  • [15] Anastasios A Tsiatis. A large sample study of cox’s regression model. The Annals of Statistics, pages 93–108, 1981.
  • [16] Hui Zou and Trevor Hastie. Regularization and variable selection via the elastic net. J. R. Statist. Soc. B, 2005.

Appendix A Proof of Theorem 1

We show can be expressed as an integral over all ambient space via

We then use identities 8.1.7 and 8.1.8 from [11], which show the product of two Gaussians gives

where and are combinations of , , and . Their exact forms are irrelevant, as the right hand term is simply a normalized Gaussian that can be integrated out with respect to . Thus, after evaluating the integral, we are left with

Now, for any we can compute

Appendix B Proof of Theorem 2

Let the survival model satisfy

and the for . Assume we use likelihood estimation of the misspecified model

Also, assume the trial is observed until time .

The rate of convergence of for the method of moments calculation (5) is a simple application of the central limit theorem, as is a Bernoulli random variable whose probability is a function of the number of samples , and the rate of outcomes prior to censoring, which is dictated by , , and the censoring model . The rest of the proof focuses on characterizing in terms of known quantities.

[5] show that the method of moments limit point satisfies

under the false model. Rearranging these equations and noting that

we arrive at the equations