Moment-Matching Graph-Networks for Causal Inference

07/20/2020 ∙ by Michael Park, et al.

In this note we explore a fully unsupervised deep-learning framework for simulating non-linear structural equation models from observational training data. The main contribution of this note is an architecture for applying moment-matching loss functions to the edges of a causal Bayesian graph, resulting in a generative conditional-moment-matching graph-neural-network. This framework thus enables automated sampling of latent space conditional probability distributions for various graphical interventions, and is capable of generating out-of-sample interventional probabilities that are often faithful to the ground truth distributions well beyond the range contained in the training set. These methods could in principle be used in conjunction with any existing autoencoder that produces a latent space representation containing causal graph structures.




I Introduction

Recently there have been many efforts to imbue deep-learning models with the ability to perform causal inference. This has been motivated primarily by the inability of traditional correlative models to make predictions on interventional and counterfactual questions Spirtes et al. (2000); Pearl (2000), as well as by the explainability of causal graphical models. These efforts have largely run in parallel to the developing trend of exploiting the non-local properties of graph neural networks Wang et al. (2017) to generate powerful and efficient representations of high-dimensional data.

In this note we dichotomize the task of causal inference into a two-step process, illustrated in Figure 1. The first step involves inferring the graphical structure of a causal model associated with a given observational data set as a directed acyclic graph (DAG). Inferring the structure of causal DAGs from observational data has a long history, and many techniques have been proposed, including constraint-based Spirtes et al. (2000); Pearl (2000); Zhang (2008); Meek (1995) and score-based methods Bouckaert (1993); Chickering (2002); Chickering and Heckerman (2013); Heckerman et al. (1995), recently developed masked-gradient methods Zheng et al. (2018, 2019); Yu et al. (2019); Ng et al. (2019a, b); Fang et al. (2020); Ng et al. (2020), as well as hybrid methods Lachapelle et al. (2019). Notable novel alternatives also include methods based on reinforcement learning Zhu and Chen (2019), adversarial networks Kalainathan et al. (2018) and restricted Boltzmann machines Sokolovska et al. (2020). Since the task of causal structure discovery is merely a means to an end for this work, we (rather arbitrarily) adopt the masked-gradient approach due to its parsimonious integration with the neural-network-based architectures for SEM-learning that are the subject of this note.¹

¹ codebase:

Figure 1: The causal inference steps in this note begin with existing DAG structure-learning algorithms to infer causal structures in latent representations of data. Using the learned DAG, neural networks are then used to estimate the response of conditional probabilities under various graphical interventions.

For the second step of causal inference, we develop a novel autoencoding architecture that applies generative moment-matching neural networks Zhao et al. (2017); Ren et al. (2016) to the edges of the learned causal graph, in order to estimate the functional dependence of the causally related observables as a structural equation model (SEM). Since their inception, generative moment-matching networks have been used for various tasks Bouchacourt (2017); Gao and Huang (2018); Briol et al. (2019); Lotfollahi et al. (2019) related to the estimation of joint and conditional probability distributions, but to our knowledge this is the first time they have been applied to an explicit causal graph structure. Our aim is to develop a fully unsupervised formalism that starts from purely observational tabular data and ends with a robust automated sampling procedure that generates an accurate functional estimate of conditional probability distributions for the associated SEM. Existing techniques for Bayesian sampling on the latent space of generative models are also numerous, including Monte Carlo and gradient-optimization-based methods Ahn et al. (2012); Hanson (2001); Park et al. (2018).

Much of this work has been inspired by several recent efforts to develop generative models that encode causal structure. For example, in Kocaoglu et al. (2017) the authors develop specific conditional adversarial loss functions for learning multi-step causal relations. Their goals are similar to those described in this note, with a focus on linear relations within high-dimensional image vectors. In Yang et al. (2020) the authors use supervised learning to endow the latent space distributions of a variational autoencoder with a causal graphical structure, with the aim of intervening on this latent space to control specific properties of their feature maps. In this note we perform experiments on simple low-dimensional feature maps, and examine the performance of our autoencoder in generating accurate conditional probability distributions from complex non-linear multi-step causal structures. These causal structures are assumed to exist as relations among dimensions in the latent representation of the data. Thus, in principle, the methods described here should also be applicable to more complex feature maps such as those generated by image and language data. However, experimentation on these high-dimensional data types is beyond the scope of this note.

In Section II we give a brief review of causal graphs and describe a vectorized formulation for structural equation models that is suited for deep-learning applications. In Section III we give the results of our experiments on causal structure learning using existing masked gradient methods. We then describe our algorithm for SEM-learning and provide results on its performance. In Section IV we conclude with a discussion on possible applications and future directions for this work.

II Background

II.1 Causal Graphs

The identification of a causal effect between two variables is equivalent to measuring the response of some endogenous variable $Y$ with respect to a controlled change in some exogenous variable $X$. If all of the variables are controlled, then the causal effect can be directly inferred via the conditional probability distribution $P(Y \mid X)$. Inferring causal effects from uncontrolled observational data is challenging due to the existence of confounding variables, which generate spurious correlations whose effects on the conditional probability may be statistically indistinguishable from true causal effects. This is illustrated diagrammatically in Figure 2. Here we adopt the formalism of Pearl, in which the effect of a controlled change in variable $X$ is represented on a causal graph by mutilating all of the arrows going into node $X$, as shown in Figure 3. The result is referred to as the intervened conditional probability distribution $P(Y \mid \not{X})$.²

² For notational simplicity we use slashes to indicate graph-mutilated variables in conditional probabilities, rather than Pearl's original notation $P(Y \mid do(X))$.

Figure 2: Integrating out a confounding common-cause variable $Z$ generates a spurious correlation via a correction to the conditional probability distribution $P(Y \mid X)$.
Figure 3: Observing a controlled change to some variable $X$ requires removing the effects of any possible external influences. This is represented graphically by mutilating all in-going arrows into node $X$.

There exists a rich literature describing the necessary and sufficient conditions for statistical distinguishability between causal and correlative effects, as well as methods for estimating causal responses when these conditions are met Spirtes et al. (2000); Pearl (2000). Although the necessary conditions are beyond the scope of this brief review, the sufficient conditions amount to a requirement that the subset of measured confounding variables be sufficiently complete to provide adequate control over the causal effects. In particular, the requirement of sufficient completeness can be succinctly dichotomized into two cases known as the back-door and front-door criteria. The back-door criterion can be used to estimate the causal response on a pair of nodes $(X, Y)$, given an observation of a set of confounding variables $Z$, as shown in Figure 4. The intervened conditional probability can then be computed via the back-door adjustment formula given in Equation 1.

$$P(Y \mid \not{X}) = \sum_{z} P(Y \mid X, Z = z)\, P(Z = z) \qquad (1)$$

The front-door criterion can be used to estimate the causal response on a pair of nodes $(X, Y)$ in situations where there exists a chain of causal influences $X \to V \to Y$, as shown in Figure 4. The intervened conditional probability can then be computed via the front-door adjustment formula given in Equation 2.
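
$$P(Y \mid \not{X}) = \sum_{v} P(V = v \mid X) \sum_{x'} P(Y \mid X = x', V = v)\, P(X = x') \qquad (2)$$
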
Figure 4: (Left) Given the sufficiently complete set of measured confounding variables $Z$, the back-door adjustment formula estimates the causal effect of $X$ on $Y$. A measurement of only a subset of these variables would be insufficient if it left an unblocked “back-door” path between the observables. (Right) If there exists a causal chain $X \to V \to Y$, the front-door adjustment formula can be used to disentangle the causal effect of $X$ on $Y$ from any measured or unmeasured confounding variables.
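To make the two adjustment formulas concrete, the sketch below estimates $P(Y{=}1 \mid do(X{=}1))$ from samples of two small binary SEMs. The SEMs, their probabilities, and the helper names `backdoor_adjust` and `frontdoor_adjust` are hypothetical, chosen only so that the ground-truth interventional probability can be computed by hand (0.70 and 0.65 respectively). The naive conditional $P(Y{=}1 \mid X{=}1)$ is printed alongside to show the bias that confounding introduces.

```python
import numpy as np


def backdoor_adjust(x, y, z, x_val):
    """Estimate P(Y=1 | do(X=x_val)) = sum_z P(Y=1 | X=x_val, Z=z) P(Z=z)."""
    total = 0.0
    for z_val in np.unique(z):
        cell = (x == x_val) & (z == z_val)
        total += y[cell].mean() * np.mean(z == z_val)
    return total


def frontdoor_adjust(x, v, y, x_val):
    """Estimate P(Y=1 | do(X=x_val)) = sum_v P(v | x_val) sum_x' P(Y=1 | x', v) P(x')."""
    total = 0.0
    for v_val in np.unique(v):
        p_v_given_x = np.mean(v[x == x_val] == v_val)
        inner = 0.0
        for x_prime in np.unique(x):
            cell = (x == x_prime) & (v == v_val)
            inner += y[cell].mean() * np.mean(x == x_prime)
        total += p_v_given_x * inner
    return total


rng = np.random.default_rng(0)
n = 200_000

# Back-door toy SEM (hypothetical): Z -> X, Z -> Y, X -> Y, with Z observed.
z = rng.random(n) < 0.5
x = rng.random(n) < np.where(z, 0.8, 0.2)
y = rng.random(n) < 0.2 + 0.3 * x + 0.4 * z
bd_naive = y[x].mean()                        # P(Y=1 | X=1), about 0.82
bd_adjusted = backdoor_adjust(x, y, z, True)  # truth: 0.2 + 0.3 + 0.4 * 0.5 = 0.70

# Front-door toy SEM (hypothetical): unobserved U -> X, U -> Y; mediator X -> V -> Y.
u = rng.random(n) < 0.5
xf = rng.random(n) < np.where(u, 0.8, 0.2)
vf = rng.random(n) < 0.1 + 0.7 * xf
yf = rng.random(n) < 0.1 + 0.5 * vf + 0.3 * u
fd_naive = yf[xf].mean()                          # P(Y=1 | X=1), about 0.74
fd_adjusted = frontdoor_adjust(xf, vf, yf, True)  # truth: 0.1 + 0.5 * 0.8 + 0.3 * 0.5 = 0.65

print(f"back-door:  naive {bd_naive:.3f}, adjusted {bd_adjusted:.3f}")
print(f"front-door: naive {fd_naive:.3f}, adjusted {fd_adjusted:.3f}")
```

Note that the front-door estimate never touches `u`: the mediator alone is enough to remove the unobserved confounding, which is the point of that criterion.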

II.2 Structural Equation Models

Structural equation models (SEMs) are a functional extension of causal graphical models in which the value of each node variable $x_\mu$ is determined as a function of its parent node variables and noise $\epsilon_\mu$. Here we adopt a notation where each node in a causal graph with $n$ nodes is specified by a spacetime index $\mu \in \{1, \dots, n\}$, and Einstein summation is assumed. The set of parent (child) nodes corresponding to $x_\mu$ is given by $\mathrm{pa}(\mu)$ ($\mathrm{ch}(\mu)$), as illustrated in Figure 5. The generic form for an SEM can then be expressed as shown in Equation 3.

$$x_\mu = f_\mu\big(\{x_\nu\}_{\nu \in \mathrm{pa}(\mu)},\ \epsilon_\mu\big) \qquad (3)$$

Figure 5: Given some node $x_\mu$ in a causal graph $\mathcal{G}$, we use $\mathrm{pa}(\mu)$ to refer to the set of all nodes that are parents of node $\mu$ and $\mathrm{ch}(\mu)$ to refer to the set of all nodes that are children of node $\mu$.

If the contribution from noise is assumed to be additive, then each node variable can be expressed simply as a polynomial (or other) expansion in its parent nodes, as shown in Equation 4. The leading-order term in this expansion describes a linearized SEM, which is typically expressed in terms of a weighted graph adjacency matrix $W$ in the form shown in Equation 5.
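
$$x_\mu = \epsilon_\mu + \sum_{\nu \in \mathrm{pa}(\mu)} a_{\nu\mu}\, x_\nu + \sum_{\nu, \rho \in \mathrm{pa}(\mu)} b_{\nu\rho\mu}\, x_\nu x_\rho + \cdots \qquad (4)$$

$$x_\mu = W_{\nu\mu}\, x_\nu + \epsilon_\mu, \qquad W_{\nu\mu} \neq 0 \ \text{only if}\ \nu \in \mathrm{pa}(\mu) \qquad (5)$$
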
The linear SEM of Equation 5 has the special property that its exact solution describes a generative model that predicts each variable from pure noise, as shown in Equation 6. The inverse operator $(I - W^\top)^{-1}$ can be expressed in closed form as a degree-$(n-1)$ polynomial in $W^\top$ with Cayley-Hamilton coefficients $c_k$, which describe the propagation of ancestral noise through the causal graph. Thus each node variable can be expressed as a linear combination of its own noise $\epsilon_\mu$ and the noise of its ancestors $\mathrm{an}(\mu)$, as shown in Equation 7.
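
$$\mathbf{x} = (I - W^\top)^{-1}\, \boldsymbol{\epsilon} \qquad (6)$$

$$x_\mu = \sum_{k=0}^{n-1} c_k \big[(W^\top)^k\big]_{\mu\nu}\, \epsilon_\nu = c_0\, \epsilon_\mu + \sum_{\nu \in \mathrm{an}(\mu)} \Big(\sum_{k=1}^{n-1} c_k \big[(W^\top)^k\big]_{\mu\nu}\Big)\, \epsilon_\nu \qquad (7)$$
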
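As a numerical check on the linear SEM and its closed-form solution, the sketch below samples a random acyclic linear SEM and confirms two facts: the closed-form samples satisfy the structural equations exactly, and for an acyclic $W$ (which is nilpotent) the inverse $(I - W^\top)^{-1}$ equals the finite series $\sum_{k=0}^{n-1} (W^\top)^k$, a special case of the degree-$(n-1)$ polynomial. The helper names, the strictly upper-triangular construction of $W$, and the convention that $W_{\nu\mu} \neq 0$ means $\nu \to \mu$ are assumptions made for illustration.

```python
import numpy as np


def random_dag_weights(n, edge_prob=0.5, seed=0):
    """Hypothetical helper: a random strictly upper-triangular W (hence acyclic).

    W[nu, mu] != 0 means nu is a parent of mu, matching x_mu = W_{nu mu} x_nu + eps_mu.
    """
    rng = np.random.default_rng(seed)
    mask = np.triu(rng.random((n, n)) < edge_prob, k=1)
    weights = rng.uniform(0.5, 2.0, size=(n, n)) * rng.choice([-1.0, 1.0], size=(n, n))
    return mask * weights


def sample_linear_sem(W, num_samples, seed=1):
    """Draw samples from x = W^T x + eps via the closed form x = (I - W^T)^{-1} eps.

    Rows are samples, so in matrix form X = X W + E, i.e. X = E (I - W)^{-1}.
    """
    n = W.shape[0]
    eps = np.random.default_rng(seed).standard_normal((num_samples, n))
    return eps @ np.linalg.inv(np.eye(n) - W), eps


n = 5
W = random_dag_weights(n)
X, E = sample_linear_sem(W, num_samples=1000)

# The samples satisfy the structural equations (up to round-off).
print(np.allclose(X, X @ W + E))

# For a DAG, W is nilpotent, so the inverse is the finite Neumann series
# sum_{k=0}^{n-1} (W^T)^k: every Cayley-Hamilton coefficient equals one.
inverse = np.linalg.inv(np.eye(n) - W.T)
series = sum(np.linalg.matrix_power(W.T, k) for k in range(n))
print(np.allclose(inverse, series))
```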
The weighted adjacency matrix serves the dual purpose of masking each node variable from its non-parent nodes through its zero entries, while the non-zero entries define the strength of linear correlations between each pair of nodes in the causal graph. Unfortunately, there is no standardized generalization to non-linear SEMs. One natural possibility is to define a separate weighted adjacency matrix for each order in a functional expansion, like the polynomial example in Equation 4. While this interpretation nicely generalizes the linear approximation, its computational complexity grows without bound as the order of the expansion increases, and various other interpretations have been suggested for the adjacency-matrix weights, such as ones related to the mutual information between parent-child node variables Fang et al. (2020).

In this note we develop an alternative formalism for describing non-linear SEMs that is agnostic to the interpretation of the weights in the adjacency matrix. We thus define a causal mask matrix $M$, which is just the unweighted adjacency matrix, as shown in Equation 8, where $\circ$ refers to element-wise multiplication.
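
$$M_{\nu\mu} = \begin{cases} 1, & W_{\nu\mu} \neq 0, \\ 0, & W_{\nu\mu} = 0, \end{cases} \qquad W = M \circ W \qquad (8)$$
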
We then define a procedure for extracting the data for the parents of each node in the following way. We first lift each node variable into an auxiliary dimension $a$, $x_\mu \to x_\mu \delta_{\mu a}$. Index contraction of the spacetime index with the mask matrix then produces a vector $x^{\mathrm{pa}}_{\nu a}$ for each node $\nu$ whose components in the auxiliary dimension contain its parent-node data, as shown in Equation 9. This vectorized parental masking procedure is suitable for expressing functions of sets of parent nodes in a generalized SEM as $x_\nu = f_\nu(x^{\mathrm{pa}}_{\nu a}, \epsilon_\nu)$.
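
$$x^{\mathrm{pa}}_{\nu a} = M_{\mu\nu}\, x_\mu\, \delta_{\mu a} = M_{a\nu}\, x_a \quad (\text{no sum over } a) \qquad (9)$$
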
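The lifting and contraction of the parental masking procedure map directly onto `numpy.einsum`. This sketch, on a hypothetical three-node graph with an added batch index `b`, builds the unweighted mask from a weighted adjacency matrix and checks that the contraction reproduces the direct form, in which row $\nu$ holds $x_a$ wherever $a$ is a parent of $\nu$ and zero elsewhere. The convention `W[nu, mu] != 0` for an edge $\nu \to \mu$ is an assumption.

```python
import numpy as np

# Hypothetical graph: 0 -> 1, 1 -> 2, 0 -> 2 (W[nu, mu] != 0 means nu -> mu).
W = np.array([[0.0, 1.5, -0.7],
              [0.0, 0.0, 2.0],
              [0.0, 0.0, 0.0]])
M = (W != 0).astype(float)           # unweighted causal mask
assert np.allclose(W, M * W)         # W = M o W (element-wise)

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))      # 4 samples of (x_0, x_1, x_2)

# Lift x_mu -> x_mu delta_{mu a}, then contract the node index mu with the mask.
lifted = np.einsum("bm,ma->bma", X, np.eye(3))   # lifted[b, mu, a] = x_mu delta_{mu a}
parents = np.einsum("bma,mn->bna", lifted, M)    # parents[b, nu, a] = M[a, nu] x_a

# Equivalent direct form without the auxiliary identity matrix.
direct = X[:, None, :] * M.T[None, :, :]
print(np.allclose(parents, direct))
print(parents[0])  # row nu holds the parent data of node nu; non-parents are zero
```

The einsum form mirrors the index notation; in practice the broadcasted product `direct` does the same job without materializing the lifted tensor.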
III Experiments

III.1 Causal Structure Learning

The algorithms for SEM-learning described in this note rely on first inferring the correct causal graph structure for a given data set. Fortunately, the last two years have seen exciting progress in applications of neural networks to the problem of causal graph structure-learning, particularly in the area of masked-gradient methods Zheng et al. (2018); Yu et al. (2019); Ng et al. (2019a); Fang et al. (2020); Ng et al. (2020). These methods center around an identity for acyclic weighted adjacency matrices, which was first derived in Zheng et al. (2018) and is shown in Equation 10. This identity enables a reformulation of acyclic graph-learning as a continuous optimization problem. Here $\circ$ again denotes element-wise multiplication.

$$W \ \text{is acyclic} \iff \mathrm{tr}\big(e^{W \circ W}\big) - n = 0 \qquad (10)$$

The graph-learning network can then be constructed using an encoder/decoder framework with an objective function that attempts to minimize some reconstruction loss, subject to an acyclicity constraint $h(W) = 0$, where $h$ is a function of the weighted adjacency matrix given in Equation 11.

$$h(W) = \mathrm{tr}\big(e^{W \circ W}\big) - n \qquad (11)$$

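The NO-TEARS acyclicity function of Zheng et al. (2018), $h(W) = \mathrm{tr}(e^{W \circ W}) - n$, can be evaluated in a few lines. The sketch below uses a small numpy-only matrix exponential (scaling and squaring of a truncated Taylor series; `scipy.linalg.expm` would be the usual choice) and checks that $h$ vanishes for a DAG and becomes positive once a cycle is added. The example matrices are hypothetical.

```python
import numpy as np


def expm(A, terms=20):
    """Matrix exponential by scaling-and-squaring a truncated Taylor series (numpy only)."""
    norm = np.linalg.norm(A, ord=1)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    A_scaled = A / 2.0**s
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for k in range(1, terms):
        term = term @ A_scaled / k      # A_scaled^k / k!
        result = result + term
    for _ in range(s):                  # undo the scaling: exp(A) = exp(A / 2^s)^(2^s)
        result = result @ result
    return result


def h(W):
    """Acyclicity function h(W) = tr(exp(W o W)) - n; zero iff W encodes a DAG."""
    return np.trace(expm(W * W)) - W.shape[0]


W_dag = np.array([[0.0, 1.2, 0.0],
                  [0.0, 0.0, -0.8],
                  [0.0, 0.0, 0.0]])
W_cyclic = W_dag.copy()
W_cyclic[2, 0] = 0.5                    # adds 2 -> 0, closing the cycle 0 -> 1 -> 2 -> 0

print(h(W_dag))     # zero for the DAG
print(h(W_cyclic))  # strictly positive once a cycle is present
```

Because every term of the trace counts weighted closed walks, $h$ is smooth in the entries of $W$, which is what lets the constraint be enforced with gradient-based optimization.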
The original formulation for this continuous optimization, referred to as NO-TEARS Zheng et al. (2018), uses a reconstruction loss inspired directly by the form of the linear SEM in Equation 5. As illustrated in the first line of Table 1, the encoder is just the identity function, while the decoder is an MLP that takes as input a weighted, masked latent-space vector.

Encoder Decoder
Table 1: A comparison of functional structures for three well-known masked-gradient-based algorithms for causal structure learning.

In this note we focus our tests on two non-linear generalizations of the NO-TEARS algorithm, referred to as GNN and GAE. The encoder/decoder architectures are given in Table 1, where the non-linear functions denote generic MLP-based function-learners. Both the GNN and GAE frameworks generalize the well-known closed-form solution for linear SEMs. However, the salient difference between them is the presence of a residual connection in GNN, represented by the identity term in the second line of Table 1. The reconstruction loss function for GNN is given by the usual evidence lower bound (ELBO) for variational autoencoders, while the reconstruction loss for GAE is simply the mean squared error (MSE). The above optimization can be implemented using the method of Lagrange multipliers with the Lagrangian defined in Equation 12.


Following the work in Yu et al. (2019); Ng et al. (2019a) we perform tests on four different toy data sets generated by structural equation models of increasing non-linear complexity, as shown in Equations 13-16.

linear: (13)
non-linear 1: (14)
non-linear 2: (15)
non-linear 3: (16)
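Toy data sets of this kind are generated by ancestral sampling, visiting nodes in topological order so that every parent is sampled before its children. The sketch below assumes an additive-noise form $x_\mu = \sum_\nu W_{\nu\mu} f(x_\nu) + \epsilon_\mu$; the function `sample_additive_sem`, the weights, and the cosine non-linearity are hypothetical placeholders and are not claimed to reproduce Equations 14-16. In the linear case the sampler recovers the closed-form solution $\mathbf{x} = (I - W^\top)^{-1}\boldsymbol{\epsilon}$.

```python
import numpy as np


def sample_additive_sem(W, f, num_samples, seed=0):
    """Ancestral sampling of x_mu = sum_nu W[nu, mu] f(x_nu) + eps_mu.

    Assumes W is strictly upper triangular, so node order 0..n-1 is topological.
    """
    n = W.shape[0]
    eps = np.random.default_rng(seed).standard_normal((num_samples, n))
    X = np.zeros((num_samples, n))
    for mu in range(n):
        # Only rows nu < mu of column mu are non-zero, i.e. only already-sampled parents count.
        X[:, mu] = f(X) @ W[:, mu] + eps[:, mu]
    return X, eps


W = np.array([[0.0, 1.0, 0.5],
              [0.0, 0.0, -1.5],
              [0.0, 0.0, 0.0]])

X_lin, E = sample_additive_sem(W, lambda v: v, 1000)
# The linear case reproduces the closed-form solution X = E (I - W)^{-1}.
print(np.allclose(X_lin, E @ np.linalg.inv(np.eye(3) - W)))

# Hypothetical non-linearity, for illustration only.
X_nl, _ = sample_additive_sem(W, lambda v: np.cos(v + 1.0), 1000)
print(X_nl.mean(axis=0), X_nl.std(axis=0))
```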

In the original papers, both GNN and GAE were tested using randomly generated Erdős-Rényi graphs. For graphs with nodes, the authors of GNN reported structural Hamming distance (SHD) errors ranging from (for non-linear 2) and for (non-linear 1). Impressively, the performance of the GAE algorithm exhibits a scaling that is roughly independent of the number of nodes in the graph for the Erdős-Rényi case, which we have verified in our own experiments. We attribute the difference in performance on large graphs primarily to the presence of the residual connection in GNN, which enables an extremely accurate reconstruction of the data despite an incorrect causal graph structure.
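SHD counts the edge insertions, deletions, and reversals needed to turn an estimated graph into the true one. A minimal sketch for binary adjacency matrices follows; it uses the convention in which a reversed edge counts as a single error, which may differ from the convention used in the cited papers (some count a reversal as two). The function name `shd` and the example graphs are hypothetical.

```python
import numpy as np


def shd(true_adj, est_adj):
    """Structural Hamming distance between two directed graphs.

    Counts missing edges, extra edges, and reversed edges; a reversed edge
    counts once (conventions differ between papers).
    """
    t = np.asarray(true_adj) != 0
    e = np.asarray(est_adj) != 0
    diff = t != e
    # A reversal i->j vs j->i produces two mismatched entries; keep only one of them.
    reversed_edges = np.triu(diff & diff.T & (t | t.T) & (e | e.T), k=1)
    return int(diff.sum() - reversed_edges.sum())


true_graph = np.array([[0, 1, 0],
                       [0, 0, 1],
                       [0, 0, 0]])      # 0 -> 1 -> 2
estimated = np.array([[0, 0, 1],
                      [1, 0, 1],
                      [0, 0, 0]])       # 1 -> 0 (reversed), 1 -> 2, 0 -> 2 (extra)
print(shd(true_graph, estimated))  # 2
```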

Figure 6: Two graph structures used for the experiments in this note, which we refer to as Graph A (left) and Graph B (right). Causal estimation for Graph A requires mutilating two edges, independent of the number of confounders, while causal estimation for Graph B requires mutilating a number of edges equal to the number of confounders.

In this note we perform tests on the GNN and GAE algorithms using the two graph structures shown in Figure 6, referred to as Graph A and Graph B. These two graph structures form the baseline cases for the structural equation model tests described in the next section, and represent different configurations with an increasing number of confounding variables. The results of our structure-learning experiments, shown in Figure 7, indicate that the explicit presence of numerous confounding variables presents a significant obstacle to the recovery of correct causal structures relative to the Erdős-Rényi case, even for simple graphs with nodes as few as .

Figure 7: Structural Hamming distances (SHD) for GNN and GAE as a function of the total number of nodes. Results are shown for Graph A (top row) and Graph B (bottom row) as defined in Figure 6. For each number of nodes we generate two graphs with different weights from different random seeds and perform 3 runs for each graph. The error bars indicate variations between the 3 runs on each seed.

III.2 Structural Equation Modeling

The network architecture for SEM-learning proposed in this note is illustrated in Figure 8 and can be factorized into two components. The first component is just a generic variational autoencoder that encodes each node feature into its latent representation before decoding it back to the target representation. The second component introduces a “causal block” that performs ancestral sampling on the latent representation and produces a latent representation for each child node that is a function of only its parent nodes.

Figure 8: The proposed network architecture is an extension of a generic variational autoencoder (blue). The generator for the latent space is augmented with an additional causal network block (orange) that uses a causal mask as defined in Equation 8 to generate a latent space distribution for each child node that is a function of only its parent nodes. The child node of a latent variable can thus be generated by cycling the inputs repeatedly through the causal block.

For SEM-learning on a graph with $n$ nodes, the causal block is correspondingly composed of $n$ neural networks, as illustrated diagrammatically in Figure 9. Restricting the functional dependence of each node to only its parent nodes is crucial for the automated generation of intervened conditional probability distributions. This is achieved simply through the use of the causal mask in the causal block, together with the absence of any residual connection except for nodes that have no parents. These parentless nodes include the nodes chosen for intervention, as well as root nodes, which can be viewed as being intervened on by the environment. Ancestral sampling of an intervened distribution can then be performed simply by generating data for the intervened node from a random-normal distribution and cycling the data repeatedly through the causal block in order to obtain the data for its child nodes, as illustrated in Figure 8.

Figure 9: The causal block takes inputs from the latent node variables. A single neural network for each latent dimension generates means and variances for the child nodes. Nodes with no parents, including the intervened node, contain a residual connection, and all nodes with parents are functions of only their parents.

A functional expression for the causal block can be written as a sum of three terms, as shown in Equation 17. The first term describes the contribution from noise and is computed via the usual reparameterization trick Kingma and Welling (2013) from neural-network-generated variances. The second term provides a residual connection only for node variables that have no parents. We thus define a delta function whose argument, for a given node, is the number of parents belonging to that node, normalized as shown in Equation 18 so that it equals one for a node with no parents and vanishes otherwise.


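The following numpy sketch illustrates one possible reading of the causal block and of ancestral sampling under an intervention. Several details are assumptions rather than the authors' implementation: the per-node networks are random stand-ins (`make_mlp`, hypothetical) with 64 hidden units and a (mean, log-variance) output; parentless nodes are assumed to pass through unchanged rather than also receiving noise; and the intervened node is fixed to a given value (a scalar, or an array of draws from a normal distribution) before the batch is cycled $n$ times through the block.

```python
import numpy as np


def make_mlp(d_in, d_hidden, d_out, rng):
    """Hypothetical stand-in for a trained per-node network: one tanh hidden layer, random weights."""
    W1 = rng.standard_normal((d_in, d_hidden)) / np.sqrt(d_in)
    W2 = rng.standard_normal((d_hidden, d_out)) / np.sqrt(d_hidden)
    return lambda v: np.tanh(v @ W1) @ W2


def causal_block(z, M, nets, rng, intervened=()):
    """One pass of an assumed causal block on latent values z of shape (batch, n).

    M[nu, mu] = 1 if nu -> mu; intervened nodes have their in-going edges removed.
    Parentless nodes keep their value (residual only); every other node is
    mean + sigma * noise, predicted by its own network from its masked parent vector.
    """
    batch, n = z.shape
    M = M.copy()
    M[:, list(intervened)] = 0.0                    # graph mutilation
    no_parents = M.sum(axis=0) == 0                 # delta(|pa(mu)|) = 1
    z_pa = z[:, None, :] * M.T[None, :, :]          # z_pa[b, mu, a] = M[a, mu] z[b, a]
    out = np.empty_like(z)
    for mu in range(n):
        mean, log_var = nets[mu](z_pa[:, mu, :]).T  # each network outputs (mean, log-variance)
        noise = np.exp(0.5 * log_var) * rng.standard_normal(batch)  # reparameterization
        out[:, mu] = z[:, mu] if no_parents[mu] else mean + noise
    return out


def sample_intervention(M, nets, node, value, batch, rng):
    """Ancestral sampling under an intervention on `node`, cycling n times through the block."""
    n = M.shape[0]
    z = rng.standard_normal((batch, n))
    z[:, node] = value
    for _ in range(n):
        z = causal_block(z, M, nets, rng, intervened=(node,))
    return z


n = 3
M = np.array([[0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [0.0, 0.0, 0.0]])                     # chain 0 -> 1 -> 2
rng = np.random.default_rng(0)
nets = [make_mlp(n, 64, 2, rng) for _ in range(n)]
samples = sample_intervention(M, nets, node=1, value=2.0, batch=8, rng=rng)
print(samples[:, 1])   # the intervened node stays fixed at 2.0
```

Because the mask zeroes every non-parent input, node 2 in this example responds to the intervened node 1 but is blind to changes in node 0 once the edge 0 -> 1 is mutilated.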
The third and final term is generated by the set of neural networks whose input is the vector containing the latent representation of the node's parent data, as constructed according to Equation 9. The loss function used is a combination of the joint Zhao et al. (2017) and conditional Ren et al. (2016) maximum mean discrepancies (MMD and CMMD), as shown in Equation 19, with . The set of networks thus together form a generative conditional moment-matching graph-neural-network.


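A minimal sketch of the joint MMD term follows, using the biased (V-statistic) estimator of the squared MMD with a sum of Gaussian kernels. The bandwidths are arbitrary choices and the function names are hypothetical; the conditional (CMMD) term of Ren et al. (2016), which requires conditional kernel mean embeddings, is not shown.

```python
import numpy as np


def gaussian_kernel(a, b, sigma):
    """k(a, b) = exp(-||a - b||^2 / (2 sigma^2)) for all pairs of rows of a and b."""
    sq = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * sigma**2))


def mmd2(x, y, sigmas=(0.5, 1.0, 2.0)):
    """Biased estimate of squared MMD between samples x and y, summed over bandwidths."""
    total = 0.0
    for s in sigmas:
        total += (gaussian_kernel(x, x, s).mean()
                  + gaussian_kernel(y, y, s).mean()
                  - 2.0 * gaussian_kernel(x, y, s).mean())
    return total


rng = np.random.default_rng(0)
x = rng.standard_normal((500, 1))
y_same = rng.standard_normal((500, 1))
y_shift = rng.standard_normal((500, 1)) + 3.0
print(mmd2(x, y_same), mmd2(x, y_shift))  # small for matching distributions, large for shifted
```

Minimizing this quantity with respect to generated samples matches all moments captured by the kernels, which is what makes it usable as a generative training loss.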
To measure the performance of interventional sampling we perform tests using an MLP-based encoder and decoder, each consisting of a single hidden layer with 16 neurons. The causal block is composed of neural networks, each with input dimension and output dimension , and each consisting of a single hidden layer containing 64 neurons. For the loss function we choose (rather arbitrarily) and , and each trial is run on 8000 data points. The performance metric used is the relative entropy (KL divergence) between the conditional probability distributions generated by the intervened and unintervened ground truth SEMs. We then compare it with the relative entropy between the intervened SEM and the one predicted by the causal autoencoder at different standard deviations away from the distribution means, as illustrated in Figure 10. The autoencoder predictions for these results have been smoothed using a kernel density estimator with a normal reference bandwidth.

Figure 10: The performance metric adopted in this note is the relative entropy between the conditional probability distribution for the predicted intervened SEM (top right) and the ground truth SEM (top middle). The KL divergence is computed along slices corresponding to points at various standard deviations away from the mean (bottom right). As a baseline we compare this against the KL divergence with respect to the unintervened conditional probability distribution (bottom left).
Figure 11: Performance metrics for experiments on Graph A. KL divergences are shown along contours of varying standard deviation for the probability distributions (top row) and (bottom row). The solid and dashed lines represent averages for 4 randomly generated adjacency matrices.
Figure 12: Performance metrics for Graph B along contours of varying standard deviation . Results are shown for the probability distributions (top row) and (bottom row). The solid and dashed lines represent averages for 4 randomly generated adjacency matrices.

IV Discussion

The results of our experiments indicate that the proposed framework for simulating structural equation models is capable of capturing complex non-linear relationships among variables in a way that is amenable to multi-step counterfactual interventions. Importantly, the generated probability distributions appear faithful to the ground truth intervened SEMs, even when the intervened variables are fixed to values that are outside the range of values contained in the training data distributions. This capability implies a predictive ability that is manifestly beyond what is possible through analytical calculations via the back-door and front-door adjustment formulas, which can only be applied to intervened variables that take on values for which observable data exists.

With 8000 data points in each of the training sets, the maximum and minimum values for the node variable typically fall within the range of from the distribution mean, never exceeding . From Figures 11 and 12 we observe that the linearly correlated data sets are faithful to the ground truth well beyond the mark. On the other hand, those data sets with strong non-linear components vary in their predictive performance beyond , but are reliably closer to the ground truth relative to the unintervened distributions. This is unsurprising upon closer inspection of the predicted conditional (intervened) probabilities, which demonstrate a clear tendency for our generative model to perform simple linear extrapolations of the distributions in regimes outside those contained in the training data.

Although the experiments performed in this note were restricted to scalar-valued node variables, we expect that a simple extension of these methods could make them applicable to complex high-dimensional image and language data. For example, in CausalVAE Yang et al. (2020), the authors use supervised learning to encode specific image labels into a single dimension of the latent space . In one example, they use the CelebA data set of facial images to encode causal relationships between features like , thus allowing them to intervene on the latent space to produce images of unnaturally young bearded faces. Augmenting this procedure with the causal block described in this note would in principle enable synthetic generation of image populations with features that accurately represent conditional probabilities under multiple steps of causal influence. For example, it could yield an accurate distribution of hair colors if the graph structure contained . Unfortunately, a detailed exploration of these high-dimensional data types is beyond the scope of this note.

Another potential application of these methods could be for use with model-based reinforcement learning. In Dasgupta et al. (2019) the authors performed several experiments in a model-free RL framework in which they trained agents to make causal predictions in simple one-step-querying scenarios. In these experiments, the agents were directed to sample points from joint and conditional probability distributions of SEM-generated data, as well as the corresponding distributions from arbitrarily mutilated SEM graphs. These experiments showed evidence that their agents learned to exploit interventional and counterfactual reasoning to accumulate significantly higher rewards compared to the relevant baselines.

In Nair et al. (2019) the authors expand on the previous work by successfully training RL agents to perform causal reasoning in a more complex multi-step relational scenario, with the ability to generalize to unseen causal structures that were held out during training. Their experiments involved two separate RL agents: one that used supervised learning to generate a causal graph model from ground-truth graphs, and another that was directed to take “goal-oriented” actions based on models learned by the first agent. The authors strongly hypothesized that the impressive level of generalizability displayed by their algorithm was a direct result of the explicit model-based approach. We find the possibility of performing such experiments using graphical models learned via the fully unsupervised approach described in this note to be both very intriguing and plausibly practical as a future area of exploration.

V Acknowledgements

We thank Vincent Tang, Jiheum Park, Ignavier Ng, Jungwoo Lee, and Tim Lou for useful discussions.