Counterfactual Invariance to Spurious Correlations: Why and How to Pass Stress Tests

05/31/2021, by Victor Veitch et al.

Informally, a `spurious correlation' is the dependence of a model on some aspect of the input data that an analyst thinks shouldn't matter. In machine learning, these have a know-it-when-you-see-it character; e.g., changing the gender of a sentence's subject changes a sentiment predictor's output. To check for spurious correlations, we can `stress test' models by perturbing irrelevant parts of input data and seeing if model predictions change. In this paper, we study stress testing using the tools of causal inference. We introduce counterfactual invariance as a formalization of the requirement that changing irrelevant parts of the input shouldn't change model predictions. We connect counterfactual invariance to out-of-domain model performance, and provide practical schemes for learning (approximately) counterfactual invariant predictors (without access to counterfactual examples). It turns out that both the means and implications of counterfactual invariance depend fundamentally on the true underlying causal structure of the data. Distinct causal structures require distinct regularization schemes to induce counterfactual invariance. Similarly, counterfactual invariance implies different domain shift guarantees depending on the underlying causal structure. This theory is supported by empirical results on text classification.


1 Introduction

Our focus in this paper is the sort of spurious correlations revealed by “poke it and see what happens” testing procedures for machine-learning models. For example, we might test a sentiment analysis tool by changing one proper noun for another (“tasty Mexican food” to “tasty Indian food”), with the expectation that the predicted sentiment should not change. This kind of perturbative stress testing is increasingly popular: it is straightforward to understand and offers a natural way to test the behavior of models against the expectations of practitioners [ribeiro2020beyond, Wu:Ribeiro:Heer:Weld:2019, naik2018stress].

Intuitively, models that pass such stress tests are preferable to those that do not. However, fundamental questions about the use and meaning of perturbative stress tests remain open. For instance, what is the connection between passing stress tests and model performance on prediction? Eliminating predictor dependence on a spurious correlation should help with domain shifts that affect the spurious correlation—but how do we make this precise? And, how should we develop models that pass stress tests when our ability to generate perturbed examples is limited? For example, automatically perturbing the sentiment of a document in a general fashion is difficult.

The ad hoc nature of stress testing makes it difficult to give general answers to such questions. In this paper, we will use the tools of causal inference to formalize what it means for models to pass stress tests, and use this formalization to answer the questions above. We will formalize passing stress tests as counterfactual invariance, a condition on how a predictor should behave when given certain (unobserved) counterfactual input data. We will then derive implications of counterfactual invariance that can be measured in the observed data. Regularizing predictors to satisfy these observable implications provides a means for achieving (partial) counterfactual invariance. Then, we will connect counterfactual invariance to robust prediction under certain domain shifts, with the aim of clarifying what counterfactual invariance buys and when it is desirable.

An important insight that emerges from the formalization is that the true underlying causal structure of the data has fundamental implications for both model training and guarantees. Methods for handling ‘spurious correlations’ in data with a given causal structure need not perform well when blindly translated to another causal structure.

Counterfactual Invariance

Consider the problem of learning a predictor f that predicts a label Y from covariates X. In this paper, we’re interested in constructing predictors whose predictions are invariant to certain perturbations on X. Our first task is to formalize the invariance requirement.

To that end, assume that there is an additional variable Z that captures information that should not influence predictions. However, Z may causally influence the covariates X. Using the potential outcomes notation, let X(z) denote the counterfactual X we would have seen had Z been set to z, leaving all else fixed. Informally, we can understand perturbative stress tests as a way of producing particular realizations of counterfactual pairs X(z), X(z′) that differ by an intervention on Z. Then, we formalize the requirement that an arbitrary change to Z does not change predictions:

Definition 1.1.

A predictor f(X) is counterfactually invariant to Z if f(X(z)) = f(X(z′)) almost everywhere, for all z, z′ in the sample space of Z. When Z is clear from context, we’ll just say the predictor is counterfactually invariant.
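To make the definition operational, the following is a minimal stress-test sketch in Python: it measures how much predictions move between inputs and their perturbed counterparts. The predictor and the perturbation function here are hypothetical stand-ins for illustration, not part of the paper.

# Minimal stress-test sketch: empirically probe (approximate) counterfactual
# invariance by comparing predictions on original and perturbed inputs.
# `predict_proba` and `perturb` are hypothetical stand-ins for a real model
# and a real perturbation scheme.

def perturb(text: str) -> str:
    """Produce an approximate counterfactual by changing the Z-part of the input."""
    return text.replace("Mexican", "Indian")  # illustrative perturbation only

def predict_proba(text: str) -> float:
    """Stand-in for a trained sentiment model's estimate of P(Y = 1 | text)."""
    return 0.9 if "tasty" in text else 0.5  # toy model, for illustration only

def stress_test_failure_rate(examples, tol=0.05):
    """Fraction of examples whose predicted probability moves by more than `tol`."""
    failures = sum(
        abs(predict_proba(t) - predict_proba(perturb(t))) > tol for t in examples
    )
    return failures / len(examples)

print(stress_test_failure_rate(["tasty Mexican food", "the service was slow"]))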

2 Causal Structure

Counterfactual invariance is a condition on how the predicted label behaves under interventions on parts of the input data. However, intuitions about stress testing are based on how the true label behaves under interventions on parts of the input data. We will see that the true causal structure fundamentally affects both the implications of counterfactual invariance, and the techniques we use to achieve it. To study this phenomenon, we’ll use two causal structures that are commonly encountered in applications; see fig. 1.

Prediction in the Causal Direction

Figure 1: Causal models for the data generating process. (a) Causal direction. (b) Anti-causal direction. We decompose the observed covariate X into latent parts defined by their causal relationships with Y and Z. Solid arrows denote causal relationships, while dashed lines denote non-causal associations. The differences between these causal structures will turn out to be key for understanding counterfactual invariance.

We begin with the case where X is a cause of Y.

Example 2.1.

We want to automatically classify the quality of product reviews. Each review has a number of “helpful” votes Y (from site users). We predict Y using the text of the product review X. However, we find that interventions on the sentiment Z of the text change our prediction; changing “Great shoes!” to “Bad shoes!” changes the prediction.

In the examples in this paper, the covariate X is text data. Usually, the causal relationships between the text X and Y and Z will be complex—e.g., the relationships may depend on abstract, unlabeled parts of the text such as topic, writing quality, or tone. In principle, we could enumerate all such latent variables, construct a causal graph capturing the relationships between these variables and X, Y, and Z, and use this causal structure to study counterfactual invariance. For instance, if we think that topic causally influences the helpfulness Y, but is not influenced by sentiment Z, then we could build a counterfactually invariant predictor by extracting the topic and predicting using topic alone. However, exhaustively articulating all possible such variables is a herculean task.

Instead, notice that the only thing that’s relevant about these latent variables is their causal relationship with Y and Z. Accordingly, we’ll decompose the observed variable X into parts defined by their causal relationships with Y and Z. We remain agnostic to the semantic interpretation of these parts. Namely, we define X_Z^⊥ as the part of X that is not causally influenced by Z (but may influence Y), X_Y^⊥ as the part that does not causally influence Y (but may be influenced by Z), and X_{Z∧Y} as the remaining part that is both influenced by Z and influences Y. The causal structure is shown in fig. 1(a).

We see there are two paths that can lead to Z and Y being associated. The first is when Z affects X_{Z∧Y} which, in turn, affects Y. For example, a very enthusiastic reviewer might write a longer, more detailed review, which will in turn be more helpful. The second is when a common cause or selection effect in the data generating process induces an association between Z and Y, which we denote with a dashed arrow. For example, if books tend to get more positive reviews, and also people who buy books are more likely to flag reviews as helpful, then the product type would be a common cause of sentiment and helpfulness.

Prediction in the Anti-Causal Direction

We also consider the case where Y causes X.

Example 2.2.

We want to predict the star rating Y of movie reviews from the text X. However, we find that predictions are influenced by the movie genre Z; e.g., changing “Adam Sandler” (a comedy actor) to “Hugh Grant” (a romance actor) changes the predictions.

Figure 1(b) shows the causal structure. Here, the observed X is influenced by both Y and Z. Again, we decompose X into parts defined by their causal relationships with Y and Z. Here, Z and Y can be associated through two paths. First, if X_{Z∧Y} is non-trivial, then conditioning on it induces a dependence between Y and Z (because X_{Z∧Y} is a collider). For example, if Adam Sandler tends to appear in good comedy movies but bad movies of other genres, then seeing “Sandler” in the text induces a dependency between sentiment and genre. Second, Y and Z may be associated due to a common cause, or due to selection effects in the data collection protocol—this is represented by the dashed line between Y and Z. For example, fans of romantic comedies may tend to give higher reviews (to all films) than fans of horror movies.

Non-Causal Associations

Frequently, a predictor trained to predict Y from X will rely on X_Y^⊥, even though there is no causal connection between X_Y^⊥ and Y, and will therefore fail counterfactual invariance. The reason is that X_Y^⊥ serves as a proxy for Z, and Z is predictive of Y due to the non-causal (dashed-line) association.

There are two mechanisms that can induce such associations. First, Z and Y may be confounded: they are both influenced by an unobserved common cause U. For example, people who review books may be more upbeat than people who review clothing. This leads to positive sentiments and high helpfulness votes for books, creating an association between sentiment and helpfulness. Second, Z and Y may be subject to selection: there is some condition (event) S that depends on Z and Y, such that a data point from the population is included in the sample only if S holds. For example, our training data might only include movies with at least 100 reviews. If only excellent horror movies have so many reviews (but most rom-coms get that many), then this selection would induce an association between genre and score. Formally, the dashed lines in the causal graphs mean our sample is distributed according to a distribution in which Z and Y are tied together by such a common cause and/or selection event, while X, Y, and Z are otherwise causally related according to the graph.

In addition to the non-causal dashed-line relationship, there is also a dependency induced between Y and Z by X_{Z∧Y}. Whether or not each of these dependencies is “spurious” is a problem-specific judgement that must be made by each analyst based on their particular use case. E.g., using genre to predict sentiment may or may not be reasonable, depending on the actual application in mind. However, there is a special case that captures a common intuition for purely spurious association.

Definition 2.3.

We say that the association between Y and Z is purely spurious if Y ⊥ X | X_Z^⊥, Z.

That is, if the dashed-line association did not exist (removed by conditioning on Z), then the part of X that is not influenced by Z would suffice to estimate Y.

3 Observable Signatures of Counterfactually Invariant Predictors

We now consider the question of how to achieve counterfactual invariance in practice. The challenge is that counterfactual invariance is defined by the behavior of the predictor on counterfactual data that is never actually observed. This makes checking counterfactual invariance directly impossible. Instead, we’ll derive a signature of counterfactual invariance that actually can be measured—and enforced—using ordinary datasets where Z (or a proxy for it) is measured. For example, we can use the star rating of a review as a proxy for sentiment, or genre labels in the movie-review case.

Intuitively, a predictor is counterfactually invariant if it depends only on X_Z^⊥, the part of X that is not affected by Z. To formalize this, we need to show that such an X_Z^⊥ is well defined:

Lemma 3.1.

Let X_Z^⊥ be an X-measurable random variable such that, for all measurable functions f, we have that f(X) is counterfactually invariant if and only if f(X) is X_Z^⊥-measurable. If Z is discrete, then such an X_Z^⊥ exists.

Accordingly, we’d like to construct a predictor that is a function of X_Z^⊥ only (i.e., is X_Z^⊥-measurable). The key insight is that we can use the causal graphs to read off a set of conditional independence relationships that are satisfied by X_Z^⊥, and hence by any counterfactually invariant predictor. Critically, these conditional independence relationships are testable from the observed data. Thus, they provide a signature of counterfactual invariance:

Theorem 3.2.

If f(X) is a counterfactually invariant predictor:

  1. Under the anti-causal graph, f(X) ⊥ Z | Y.

  2. Under the causal-direction graph, if Y and Z are not subject to selection (but possibly confounded), f(X) ⊥ Z.

  3. Under the causal-direction graph, if the association is purely spurious and Y and Z are not confounded (but possibly selected), f(X) ⊥ Z | Y.

Causal Regularization

Without access to counterfactual examples, we cannot directly enforce counterfactual invariance. However, we can require a trained model to satisfy the counterfactual invariance signature of Theorem 3.2. The hope is that enforcing the signature will lead the model to be counterfactually invariant. To do this, we regularize the model to satisfy the appropriate conditional independence condition. For simplicity of exposition, we restrict to binary Y and Z. The (infinite data) regularization terms are

  marginal regularization: MMD(P(f(X) | Z = 0), P(f(X) | Z = 1))    (3.1)

  conditional regularization: Σ_y MMD(P(f(X) | Z = 0, Y = y), P(f(X) | Z = 1, Y = y))    (3.2)

Maximum mean discrepancy (MMD) is a metric on probability measures. (The choice of MMD is for concreteness; any distance on probability spaces would do.) The marginal independence condition f(X) ⊥ Z is equivalent to eq. 3.1 being equal to 0, and the conditional independence f(X) ⊥ Z | Y is equivalent to eq. 3.2 being equal to 0. In practice, we can estimate the MMD with finite data samples [Gretton:Borgwardt:Rasch:Scholkopf:Smola:2012]. When training with stochastic gradient descent, we compute the penalty on each minibatch.

The procedure is then: if the data has causal-direction structure and the association is due to confounding, add the marginal regularization term to the training objective. If the data has anti-causal structure, or the association is due to selection, add the conditional regularization term instead. In this way, we regularize towards models that satisfy the counterfactual invariance signature.
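For concreteness, here is a minimal sketch of finite-sample versions of the two penalties in Python. The Gaussian RBF kernel follows the experimental setup described in the supplement; the bandwidth value and the simple biased estimator used here are assumptions for illustration, not the authors' implementation.

import numpy as np

def rbf_mmd2(a, b, bandwidth=0.1):
    """Simple biased estimate of the squared MMD between 1-D samples a and b,
    using a Gaussian RBF kernel (the bandwidth is an assumed hyperparameter)."""
    a = np.asarray(a, dtype=float).reshape(-1, 1)
    b = np.asarray(b, dtype=float).reshape(-1, 1)
    k = lambda x, y: np.exp(-((x - y.T) ** 2) / (2 * bandwidth ** 2))
    return k(a, a).mean() + k(b, b).mean() - 2 * k(a, b).mean()

def marginal_penalty(probs, z):
    """Eq. 3.1: distance between P(f(X) | Z = 0) and P(f(X) | Z = 1)."""
    probs, z = np.asarray(probs), np.asarray(z)
    return rbf_mmd2(probs[z == 0], probs[z == 1])

def conditional_penalty(probs, z, y):
    """Eq. 3.2: sum over y of the distance between P(f(X) | Z = 0, Y = y)
    and P(f(X) | Z = 1, Y = y). Empty (y, z) cells are skipped."""
    probs, z, y = np.asarray(probs), np.asarray(z), np.asarray(y)
    total = 0.0
    for label in np.unique(y):
        a, b = probs[(y == label) & (z == 0)], probs[(y == label) & (z == 1)]
        if len(a) and len(b):
            total += rbf_mmd2(a, b)
    return total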

A key point is that the regularizer we must use depends on the true causal structure. The conditional and marginal independence conditions are generally incompatible. Enforcing the condition that is mismatched to the true underlying causal structure will not in general enforce counterfactual invariance, or may throw away more information than is required.

Gap to Counterfactual Invariance

The conditional independence signature of Theorem 3.2 is necessary but not sufficient for counterfactual invariance. This is for two reasons. First, counterfactual invariance applies to individual datapoint realizations, but the signature is distributional. In particular, the weaker requirement that f(X(z)) and f(X(z′)) are equal in distribution for all z, z′ would also imply the conditional independence signature. But this invariance is weaker than counterfactual invariance, since it doesn’t require access to counterfactual realizations. Second, f(X) being (conditionally) independent of Z does not imply, in general, that Z is not a cause of f(X). This (unusual) behavior can happen if, e.g., there are levels of Z that we do not observe in the training data, or there are variables omitted from the causal graph that are a common cause of X and Z.

Unfortunately, the gap between the signature and counterfactual invariance is a fundamental restriction of using observational data. The conditional independence signature is in some sense the closest proxy for counterfactual invariance we can hope for. In section 5, we’ll see that enforcing the signature does a good job of enforcing counterfactual invariance in practice.

4 Performance Out of Domain

Counterfactual invariance is an intuitively desirable property for a predictor to have. However, it’s not immediately clear how it relates to model performance as measured by, e.g., accuracy. Intuitively, eliminating predictor dependence on a spurious Z may help with domain shift, where the data distribution in the target domain differs from the distribution of the training data. We now turn to formalizing this idea.

First, we must articulate the set of domain shifts to be considered. In our setting, the natural thing is to hold the causal relationships fixed across domains, but to allow the non-causal (“spurious”) dependence between Y and Z to vary. Demanding that the causal relationships stay fixed reflects the requirement that the causal structure describes the dynamics of an underlying real-world process—e.g., the author’s sentiment is always a cause (not an effect) of the text in all domains. On the other hand, the dependency between Y and Z induced by either confounding or selection can vary without changing the underlying causal structure. For confounding, the distribution of the confounder may differ between domains—e.g., books are rare in training, but common in deployment. For selection, the selection criterion may differ between domains—e.g., we include only frequently reviewed movies in training, but make predictions for all movies in deployment.

We want to capture spurious domain shifts by considering domain shifts induced by selection or confounding. However, there is an additional nuance. Changes to the marginal distribution of Y will affect the risk of a predictor, even in the absence of any spurious association between Y and Z. Therefore, we restrict to shifts that preserve the marginal distribution of Y.

Definition 4.1.

We say that distributions P and Q are causally compatible if both obey the same causal graph, P(Y) = Q(Y), and there is a confounder and/or a pair of selection conditions such that P and Q differ only in the distribution of the confounder and/or in the selection condition applied to a common underlying causal model.

We can now connect counterfactual invariance and robustness to domain shift.

Theorem 4.2.

Let F be the set of all counterfactually invariant predictors. Let ℓ be either squared error or cross-entropy loss. And, let f*_P be the counterfactually invariant risk minimizer, i.e., the minimizer of the training risk E_P[ℓ(f(X), Y)] over f ∈ F. Suppose that the target distribution Q is causally compatible with the training distribution P. Suppose that any of the following conditions hold:

  1. the data obeys the anti-causal graph,

  2. the data obeys the causal-direction graph, there is no confounding (but possibly selection), and the association is purely spurious, or

  3. the data obeys the causal-direction graph, there is no selection (but possibly confounding), the association is purely spurious, and the causal effect of X on Y is additive, i.e., the true data generating process is

    Y = g(X_Z^⊥) + h(X_{Z∧Y}) + ε    (4.1)

    for some functions g and h.

Then, the training-domain counterfactually invariant risk minimizer is also the target-domain counterfactually invariant risk minimizer; that is, f*_P also minimizes E_Q[ℓ(f(X), Y)] over f ∈ F.

Remark 4.3.

The causal case with confounding requires an additional assumption (additive structure) because, e.g., an interaction between the confounder and X_Z^⊥ can yield a case where X_Z^⊥ and Y have a different relationship in each domain (whence, out-of-domain learning is impossible).

This result gives a recipe for finding a good predictor in the target domain even without access to any target domain examples at training time. Namely, find the counterfactually invariant risk minimizer in the training domain. In practice, we can use the regularization scheme of section 3 to (approximately) achieve this. We’ll see in section 5 that this works well in practice.

Optimality

Theorem 4.2 raises the question: if the only thing we know about the target setting is that it’s causally compatible with the training data, is the best predictor the counterfactually invariant predictor with lowest training risk? A natural way to formalize this question is to study the predictor with the best performance on the worst-case target distribution. We define Q as the set of all distributions causally compatible with the training distribution, and the Q-minimax predictor as the minimizer of the worst-case risk max_{Q ∈ Q} E_Q[ℓ(f(X), Y)]. The question is then: what’s the relationship between the counterfactually invariant risk minimizer and the minimax predictor?

Theorem 4.4.

The counterfactually invariant risk minimizer is not Q-minimax in general. However, under the conditions of Theorem 4.2, if the association is purely spurious and P(Y, Z) satisfies overlap, then the two predictors are the same. By overlap we mean that P(Y, Z) is a discrete distribution such that for all y, z, if P(y, z) > 0 then there is some y′ ≠ y such that P(y′, z) > 0 also.

Conceptually, Theorem 4.4 just says that the counterfactually invariant predictor excludes the parts of X that are affected by Z, even when this information is useful in every domain. In the purely spurious case, those parts carry no additional information about Y, so counterfactual invariance is optimal.

5 Experiments

The main claims of the paper are:

  1. Stress test violations can be reduced by suitable conditional independence regularization.

  2. This reduction will improve out-of-domain prediction performance.

  3. To get the full effect, the imposed penalty must match the causal structure of the data.

Setup

To assess these claims, we’ll examine the behavior of predictors trained with the marginal or conditional regularization on multiple text datasets that have either causal or anti-causal structure. We expect to see that marginal regularization improves stress-test and out-of-domain performance on data with causal-confounded structure, and conditional regularization improves these on data with anti-causal structure.

For each experiment, we use BERT [devlin2019bert] fine-tuned to predict a label Y from the text X as our base model. We train multiple causally-regularized models on each dataset. The training varies by whether we use the conditional or marginal penalty, and by the strength of the regularization term. That is, we train identical architectures using L(f) + λ·R(f) as the objective function, where L is the cross-entropy prediction loss, λ is the regularization strength that we vary, and R is either the marginal penalty (eq. 3.1) or the conditional penalty (eq. 3.2). We compare these models’ predictions on data with causal and anti-causal structure.
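As an illustration of the combined objective, here is a simplified PyTorch-style sketch of one regularized minibatch step, with a small linear model standing in for BERT. The helper names, toy data, and hyperparameters are assumptions for illustration, not the paper's implementation.

import torch
import torch.nn.functional as F

def rbf_mmd2(a, b, bandwidth=0.1):
    """Differentiable biased squared-MMD estimate between 1-D tensors a and b."""
    k = lambda d: torch.exp(-d / (2 * bandwidth ** 2))
    d_aa = (a.unsqueeze(1) - a.unsqueeze(0)) ** 2
    d_bb = (b.unsqueeze(1) - b.unsqueeze(0)) ** 2
    d_ab = (a.unsqueeze(1) - b.unsqueeze(0)) ** 2
    return k(d_aa).mean() + k(d_bb).mean() - 2 * k(d_ab).mean()

def minibatch_objective(model, x, y, z, lam, conditional=True):
    """Cross-entropy loss plus lam times an MMD penalty computed on the predicted
    probability p = P(Y = 1 | x), per eq. 3.1 (marginal) or eq. 3.2 (conditional)."""
    logits = model(x)
    loss = F.cross_entropy(logits, y)
    p = torch.softmax(logits, dim=-1)[:, 1]
    penalty = x.new_zeros(())
    groups = [0, 1] if conditional else [None]
    for label in groups:
        m = torch.ones_like(y, dtype=torch.bool) if label is None else (y == label)
        a, b = p[m & (z == 0)], p[m & (z == 1)]
        if len(a) and len(b):
            penalty = penalty + rbf_mmd2(a, b)
    return loss + lam * penalty

# Toy usage: a linear "encoder" stands in for BERT features.
model = torch.nn.Linear(16, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y, z = torch.randn(64, 16), torch.randint(0, 2, (64,)), torch.randint(0, 2, (64,))
loss = minibatch_objective(model, x, y, z, lam=1.0, conditional=True)
opt.zero_grad(); loss.backward(); opt.step()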

See supplement for experimental details.

5.1 Robustness to Stress Tests

First, we examine whether enforcing the causal regularization actually helps to enforce counterfactual invariance. We create counterfactual (stress test) examples by perturbing the input data and compare predictions on the original and perturbed examples. We build the experimental datasets using Amazon reviews from the product category “Clothing, Shoes, and Jewelry” [Ni:Li:McAuley:2019].

Synthetic

To study the relationship between counterfactual invariance and the distributional signature of Theorem 3.2, we construct a synthetic confound. For each review, we draw a Bernoulli random variable Z, and then perturb the text so that the common words “the” and “a” carry information about Z: for example, we replace “the” with the token “thexxxxx” when Z = 1. We take Y to be the review score, and subsample so Y is balanced. This data has anti-causal structure: the text X is written to explain the score Y. Further, we expect that the association is purely spurious, because “the” and “a” carry little information about the label.

We train the models on data with a fixed level of association between Y and Z. We then create perturbed stress-test datasets by changing each example to its counterfactual (using the synthetic model). By measuring the performance of each model on the perturbed data, we can test whether the distributional properties enforced by the regularizers result in counterfactual invariance at the instance level. Figure 2 shows that conditional regularization (matching the anti-causal structure) reduces checklist failures, as measured by the frequency with which the predicted label changes under perturbation, as well as by the mean absolute difference in predictive probabilities induced by perturbation.

Figure 2: Regularizing conditional MMD improves counterfactual invariance on synthetic anti-causal data. Sufficiently high regularization of marginal MMD also improves invariance, but impairs accuracy. Dashed lines show baseline performance of an unregularized predictor. Left: lower conditional MMD implies that predictive probabilities are invariant to perturbation. Although marginal MMD penalization can result in low conditional MMD and good stress test performance, this comes at the cost of very low in-domain accuracy. Right: MMD regularization reduces the rate of predicted label flips on perturbed data, with little effect on in-domain accuracy. Conditional MMD regularization reduces predicted label flips to a lower rate than the best result for marginal MMD.
Natural

To study the relationship in real data, we use the review data in a different way. We now take Z to be the review score, binarized; we use this as a proxy for sentiment, and consider problems where sentiment should (plausibly) not have a causal effect on Y. For the causal prediction problem, we take Y to be the helpfulness score of the review (binarized as described below). This is causal because readers decide whether the review is helpful based on the text. For the anti-causal prediction problem, we take Y to be whether “Clothing” is included as a category tag for the product under review (e.g., boots typically do not have this tag). This is anti-causal because the product category affects the text.

We control the strength of the spurious association between Y and Z. In the anti-causal case, this is done by selection: we randomly subset the data to enforce a target level of dependence between Y and Z. The causal-direction case with confounding is more complicated.

Figure 3: Penalizing the MMD matching the causal structure improves stress test performance on natural product review data. Note that penalizing the wrong MMD may not help: the marginal MMD hurts on the anticausal dataset. Perturbations are generated by swapping positive and negative sentiment adjectives in examples.

To manipulate confounding strength, we binarize the number of helpfulness votes in a manner determined by the target level of association: we set Y = 1 if the number of helpful votes exceeds a Z-dependent threshold, chosen to induce the target association, and Y = 0 otherwise. We balance Y by subsampling, which also balances Z.
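A minimal sketch of this kind of Z-dependent binarization in Python (the per-group target rates and the quantile-based threshold choice are assumptions for illustration; the paper's exact target values are not reproduced here):

import numpy as np

def z_dependent_binarize(votes, z, target_rate_z0=0.5, target_rate_z1=0.5):
    """Binarize helpful-vote counts with a separate threshold for each value of Z,
    so that roughly the requested fraction of each Z-group gets Y = 1."""
    votes, z = np.asarray(votes, dtype=float), np.asarray(z)
    t0 = np.quantile(votes[z == 0], 1 - target_rate_z0)
    t1 = np.quantile(votes[z == 1], 1 - target_rate_z1)
    y = np.where(z == 0, votes >= t0, votes >= t1).astype(int)
    return y, (t0, t1)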

Now, we create stress test perturbations of these datasets by randomly changing adjectives in the examples. Using predefined lists of positive sentiment adjectives and negative sentiment adjectives, we swap any adjective that shows up on a list with a randomly sampled adjective from the other list. This preserves basic sentence structure, and thus creates a limited set of counterfactual pairs that differ on sentiment.
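A minimal sketch of this perturbation in Python (the adjective lists here are tiny illustrative stand-ins, not the predefined lists used in the paper; the crude whitespace tokenization drops punctuation attached to swapped words):

import random

POSITIVE = ["great", "comfortable", "excellent"]   # illustrative stand-in list
NEGATIVE = ["bad", "uncomfortable", "terrible"]    # illustrative stand-in list

def swap_sentiment_adjectives(text: str, rng: random.Random) -> str:
    """Swap any listed adjective for a random adjective from the opposite list."""
    out = []
    for tok in text.split():
        bare = tok.strip(".,!?").lower()
        if bare in POSITIVE:
            out.append(rng.choice(NEGATIVE))
        elif bare in NEGATIVE:
            out.append(rng.choice(POSITIVE))
        else:
            out.append(tok)
    return " ".join(out)

print(swap_sentiment_adjectives("Great shoes, very comfortable!", random.Random(0)))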

Results for differences in predicted probabilities between original and perturbed data are shown in fig. 3. Each point is a trained model; the models vary in the measured MMD on the test data and in their sensitivity to perturbations. Recall that the conditional independence signature of Theorem 3.2 is necessary but not sufficient for counterfactual invariance, so it’s not certain that regularizing to reduce the MMD will reduce perturbation sensitivity. Happily, we see that regularizing to reduce the MMD that matches the causal structure does indeed reduce sensitivity to perturbations.

Notice that regularizing the causally mismatched MMD can have strange effects. Regularizing marginal MMD in the anti-causal case actually makes the model more sensitive to perturbations!

5.2 Domain Shift

Next, we study the effect of causal regularization on model performance under domain shift.

Natural Product Review

We again use the natural Amazon review data described above. For both the causal and anti-causal data, we create multiple test sets with variable spurious-correlation strength, in the manner described above, by varying the target level of association between Y and Z. The test sets are out-of-domain samples. By design, Y is balanced in each dataset, so these samples are causally compatible with the training data. For both the causal and anti-causal datasets, the training data has the same fixed level of spurious association. We train a classifier for each regularization type and regularization strength, and measure the accuracy on each test domain. The results are shown in fig. 4.

First, the unregularized predictors do indeed learn to rely on the spurious association between sentiment and the label. The accuracy of these predictors decays dramatically as the spurious association moves from negative (0.3) to positive—in the causal case, the unregularized predictor is worse than chance in the 0.8 domain.

(a) Anti-Causal Data: conditional regularization improves domain-shift robustness.
(b) Causal-Direction Data: marginal regularization improves domain-shift robustness.
Figure 4: The best domain-shift robustness is obtained by using the regularizer that matches the underlying causal structure of the data. The plots show out-of-domain accuracy for models trained on the (natural) review data. In each row, the left figure shows out-of-domain accuracies (lines are models), with the x-axis showing the level of spurious correlation in the test data (the training condition is marked); the right figure shows worst out-of-domain accuracy versus in-domain test accuracy (dots are models).

Following section 3, the regularization that matches the underlying causal structure should yield a predictor that is (approximately) counterfactually invariant. Following Theorem 4.2, we expect that good performance of a counterfactually-invariant predictor in the training domain should imply good performance in each of the other domains. Indeed, we see that this is so. Models that are regularized to have small values of the appropriate MMD do indeed have better out-of-domain performance. Such models have somewhat worse in-domain performance, because they no longer exploit the spurious correlation.

MNLI Data
Figure 5: Conditional MMD penalization improves robustness in anti-causal MNLI data. Marginal regularization does not improve over the baseline unregularized model, shown with dashed lines. Left: Conditional regularization improves minimum accuracy across groups. When overregularized, the predictor returns the same prediction for all inputs, yielding a worst-group accuracy of 0. Right: Conditional MMD regularization significantly improves worst-group accuracy (y-axis) while only mildly reducing overall accuracy (x-axis).

For an additional test on naturally-occurring confounds, we use the multi-genre natural language inference (MNLI) dataset [williams-etal-2018-broad]. Instances are concatenations of two sentences, and the label Y describes the semantic relationship between them: entailment, neutral, or contradiction. There is a well-known confound in this dataset: examples where the second sentence contains a negation word (e.g., “not”) are much more likely to be labeled as contradictions [gururangan2018annotation]. Following sagawa2019distributionally, we set Z to indicate whether one of a small set of negation words is present. Although Z is derived from the text X, it can be viewed as a proxy for a latent variable indicating whether the author intended to use negation in the text. This is an anti-causal prediction problem: the annotators were instructed to write text to reflect the desired label [williams-etal-2018-broad].
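For concreteness, a minimal sketch of this group construction in Python (the negation word list is an assumption here, a commonly used set rather than one quoted from the paper):

# Z is an indicator for the presence of a negation word; groups are (label, Z) pairs.
NEGATION_WORDS = {"no", "not", "never", "nothing", "nobody"}  # assumed list

def has_negation(text: str) -> int:
    tokens = {tok.strip(".,!?").lower() for tok in text.split()}
    return int(bool(tokens & NEGATION_WORDS))

def worst_group_accuracy(labels, predictions, texts):
    """Minimum accuracy over the (label, Z) groups."""
    stats = {}
    for y, y_hat, text in zip(labels, predictions, texts):
        group = (y, has_negation(text))
        correct, total = stats.get(group, (0, 0))
        stats[group] = (correct + int(y == y_hat), total + 1)
    return min(correct / total for correct, total in stats.values())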

Following sagawa2019distributionally, we divide the MNLI data into groups by (Y, Z) and compute the “worst group accuracy” across all such groups. Because this is an anti-causal problem, we predict that the conditional MMD is a more appropriate penalty than the marginal MMD. As shown in Figure 5, this prediction holds: conditional MMD regularization dramatically improves performance on the worst group, while only lightly impacting the overall accuracy across groups.

6 Related work

Several papers draw a connection between causality and domain shifts [Subbaswamy:Saria:2018, Subbaswamy:Chen:Saria:2019, Arjovsky:Bottou:Gulrajani:LopezPaz:2020, Meinshausen:2018, Peters:Buhlmann:Meinshausen:2016, RojasCarulla:Scholkopf:Turner:Peters:2018, Zhang:Scholkopf:Muandet:Wang:2013]. Typically, this work considers a prediction setting where the covariates include both causes and effects of Y, and it is unknown which is which. The goal is to learn to predict Y using only its causal parents. Zhang:Scholkopf:Muandet:Wang:2013 considers anti-causal domain shift induced by changing P(Y) and proposes a data reweighting scheme. Counterfactual invariance is not generally the same as invariance to the domain shifts previously considered.

A related body of work focuses on “causal representation learning” [Besserve:Mehrjou:Sun:Scholkopf:2019, Locatello:Poole:Raetsch:Scholkopf:Bachem:Tschannen:2020, Scholkopf:Locatello:Bauer:Ke:Kalchbrenner:Goyal:Bengio:2021, Arjovsky:Bottou:Gulrajani:LopezPaz:2020]. Our approach follows this tradition, but focuses on splitting X into components defined by their causal relationships with the label Y and an additional covariate Z. Rather than attempting to infer the causal relationship between X and Y, we show that domain knowledge of this relationship is essential for obtaining counterfactually-invariant predictors. The role of causal vs anti-causal data generation in semi-supervised learning has also been studied [Scholkopf:Locatello:Bauer:Ke:Kalchbrenner:Goyal:Bengio:2021, Scholkopf:Janzing:Peters:Sgouritsa:Zhang:Mooij:2012]. In this paper we focus on a different implication of the causal vs anti-causal distinction.

Another line of work considers the case where the counterfactuals X(z), X(z′) are observed for at least some data points [wu2021polyjuice, Garg:Perot:Limtiaco:Taly:Chi:Beutel:2019, Mitrovic:McWilliams:Walker:Buesing:Blundell:2020, Wei:Zou:2019, Varun:Choudhary:Cho:2020, Kaushik:Hovy:Lipton:2020, Teney:Abbasnejad:vandenHengel:2020]. Kusner:Loftus:Russell:Silva:2017 and Garg:Perot:Limtiaco:Taly:Chi:Beutel:2019 in particular examine a notion of counterfactual fairness that can be seen as equivalent to counterfactual invariance. In these papers, approximate counterfactuals are produced by direct manipulation of the text (e.g., changing male to female names), generative language models, or crowdsourcing. Then, these counterfactuals can either be used as additional training data, or the predictor can be regularized such that it cannot distinguish between X(z) and X(z′). This strategy can be viewed as enforcing counterfactual invariance directly; an advantage is that it avoids the necessary-but-not-sufficient nuance of Theorem 3.2. However, counterfactual examples can be difficult to obtain for language data in many realistic problem domains, and it may be difficult to learn to generalize from such examples [Huang:Liu:Bowman:2020].

Finally, the marginal and conditional independencies of Theorem 3.2 have appeared in other contexts. If we think of Z as a protected attribute and f(X) as a ‘fair’ classifier, then the marginal independence f(X) ⊥ Z is demographic parity, and the conditional independence f(X) ⊥ Z | Y is equalized odds [Mehrabi:Morstatter:Saxena:Lerman:Galstyan:2019]. We can now understand these conditions as consequences of a single desideratum: the prediction should not change under intervention on a protected attribute. Similarly, an approach to domain adaptation is to seek representations that are distributionally invariant over domains, either marginally [e.g., muandet2013domain, Baktashmotlagh:Harandi:Lovell:Salzmann:2013, Ganin:Ustinova:Ajakan:Germain:Larochelle:Laviolette:Marchand:Lempitsky:2016, Tzeng:Hoffman:Zhang:Saenko:Darrell:2014] or conditionally on the label [e.g., Manders:vanLaarhoven:Marchiori:2019, Yan:Ding:Li:Wang:Xu:Zuo:2017]. If we take Z to be a domain label, these are the marginal and conditional independencies, and can be understood as consequences of the desideratum that the prediction shouldn’t change under domain shift.

7 Discussion

We used the tools of causal inference to formalize and study perturbative stress tests. A main insight of the paper is that counterfactual desiderata can be linked to observationally-testable conditional independence criteria. This requires consideration of the true underlying causal structure of the data. Done correctly, the link yields a simple procedure for enforcing the counterfactual desiderata, and mitigating the effects of domain shift.

The main limitation of the paper is the restrictive causal structures we consider. In particular, we require that X_Z^⊥, the part of X not causally affected by Z, is also statistically independent of Z in the observed data. However, in practice these may be dependent due to a common cause. In this case, the procedure here will be overly conservative, throwing away more information than required. Additionally, it is not obvious how to apply the ideas described here to more complicated causal situations, which can occur in structured prediction (e.g., question answering). Extending the ideas to handle richer causal structures is an important direction for future work. The work described here can provide a template for this research program.

Appendix A Proofs

Proof of Lemma 3.1.

Write X(z) for the potential outcomes. First notice that if f(X) is measurable with respect to the potential outcomes {X(z)}_z, then f(X) is counterfactually invariant. This is essentially by definition—intervention on Z doesn’t change the potential outcomes, so it doesn’t change the value of f(X). Conversely, if f(X) is counterfactually invariant, then f(X) is measurable with respect to the potential outcomes. To see this, notice that X is determined by the potential outcomes and Z, so f(X) = f(X(z)) on the event Z = z. Now, if this value depends only on the potential outcomes we’re done. So suppose that there are z, z′ such that f(X(z)) ≠ f(X(z′)) on a set of positive probability. But then the prediction changes under intervention on Z, contradicting counterfactual invariance.

Now, define σ(X_Z^⊥) as the intersection of the sigma algebra of X and the sigma algebra of the potential outcomes {X(z)}_z. Because it is the intersection of sigma algebras, it is itself a sigma algebra. Because every random variable measurable with respect to this intersection is also measurable with respect to the potential outcomes, Z is not a cause of any such random variable (i.e., there is no arrow from Z to X_Z^⊥). Because, for f(X) counterfactually invariant, f(X) is both X-measurable and potential-outcome-measurable, it is also measurable with respect to the intersection. The intersection is countably generated, as X and the potential outcomes are Borel measurable. Therefore, we can take X_Z^⊥ to be any random variable that generates this sigma algebra. ∎

Proof of Theorem 3.2.

Reading d-separation from the causal graphs, we have X_Z^⊥ ⊥ Z in the causal-direction graph when Y and Z are not selected on, and X_Z^⊥ ⊥ Z | Y for the other cases. By assumption, f(X) is a counterfactually invariant predictor, which means that f(X) is X_Z^⊥-measurable; hence the same (conditional) independencies hold for f(X).

To see that interventional invariance suffices for the conditional independencies, notice that they depend only on the observational distribution. It is not possible to distinguish interventional and counterfactual invariance based only on this distribution, so the conditions must also hold under the weaker notion. ∎

Proof of Theorem 4.2.

First, since counterfactual invariance implies that f(X) is X_Z^⊥-measurable, the counterfactually invariant risk minimization reduces to

  min_{f ∈ F} E_P[ℓ(f(X), Y)] = min_g E_P[ℓ(g(X_Z^⊥), Y)].    (A.1)

It is well known that under squared error or cross-entropy loss the minimizer is E_P[Y | X_Z^⊥]. By the same argument, the counterfactually invariant risk minimizer in the target domain is E_Q[Y | X_Z^⊥]. Thus, our task is to show E_P[Y | X_Z^⊥] = E_Q[Y | X_Z^⊥].

We begin with the anti-causal case. We have that P(Y | X_Z^⊥) ∝ P(X_Z^⊥ | Y) P(Y). By assumption, P(Y) = Q(Y). So, it suffices to show that P(X_Z^⊥ | Y) = Q(X_Z^⊥ | Y). To that end, from the anti-causal graph we have that X_Z^⊥ ⊥ Z | Y. Then,

  P(X_Z^⊥ | Y) = P̃(X_Z^⊥ | Y, selection)    (A.2)
              = P̃(X_Z^⊥ | Y)    (A.3)
              = Q(X_Z^⊥ | Y)    (A.4)

where P̃ denotes the common underlying distribution, the first and third lines are causal compatibility, and the second line is from X_Z^⊥ ⊥ Z | Y (the confounding and selection act only through Y and Z).

The causal-direction case with no confounding follows essentially the same argument.

For the causal-direction case without selection,

  E_P[Y | X_Z^⊥] = g(X_Z^⊥) + E_P[h(X_{Z∧Y}) | X_Z^⊥] + E_P[ε | X_Z^⊥]    (A.5)
                 = g(X_Z^⊥) + E_P[h(X_{Z∧Y})] + E_P[ε].    (A.6)

The first line is the assumed additivity. The second line follows because X_{Z∧Y} ⊥ X_Z^⊥ for all causally compatible distributions (the causal graph doesn’t change), and ε ⊥ X_Z^⊥. Taking an expectation over X_Z^⊥, we have E_P[Y] = E_P[g(X_Z^⊥)] + E_P[h(X_{Z∧Y})] + E_P[ε]. By the same token, E_Q[Y] = E_Q[g(X_Z^⊥)] + E_Q[h(X_{Z∧Y})] + E_Q[ε]. But, E_P[g(X_Z^⊥)] = E_Q[g(X_Z^⊥)], since changes to the confounder don’t change the distribution of X_Z^⊥ (that is, P(X_Z^⊥) = Q(X_Z^⊥)). And, by assumption, E_P[Y] = E_Q[Y]. Together, these imply that E_P[h(X_{Z∧Y})] + E_P[ε] = E_Q[h(X_{Z∧Y})] + E_Q[ε]. Whence, from eq. A.6, we have E_P[Y | X_Z^⊥] = E_Q[Y | X_Z^⊥], as required. ∎

Proof of Theorem 4.4.

The reason that the predictors are not the same in general is that the counterfactually invariant predictor will always exclude the information in X that is affected by Z, even when this information is helpful for predicting Y in all target settings. For example, consider a case where Y and Z are binary and, in the anti-causal direction, the Z-dependent part of X carries information about Y in every causally compatible distribution. With cross-entropy loss, the counterfactually invariant predictor is just a constant, but a decision rule that also uses the Z-dependent part of X is always better. A similar construction applies in the causal case.

Informally, the second claim follows because—in the absence of information from X_Z^⊥—any predictor that’s better than the counterfactually invariant predictor when Y and Z are positively correlated will be worse when Y and Z are negatively correlated.

To formalize this, we begin by considering the case where Y is binary and Y ⊥ X_Z^⊥. So, in particular, the counterfactually invariant risk minimizer is just some constant c. Let f be any predictor that uses the information in X. Our goal is to show that f has higher risk than c on at least one test distribution (so that f is not minimax). To that end, let Q be any causally compatible distribution where f has lower risk than c (this must exist, or we’re done). Then, define S as the collection of points where f did better than the constant predictor. Since f is better than the constant predictor overall, we must have Q(S) > 0. Now, define S̃ as the set constructed by flipping the label for every instance where f did better. By the overlap assumption, Q(S̃) > 0. By construction, f is worse than c on S̃. Further, the indicator of S̃ is a random variable that has the causal structure required by a selection variable (it is a child of observed variables and nothing else). So, the distribution defined by selection on S̃ is causally compatible with the training distribution, and f has higher risk than c under it, as required.

To relax the requirement that Y ⊥ X_Z^⊥, just repeat the same argument conditional on each value of X_Z^⊥. To relax the condition that Y is binary, swap the flipped label for any label with worse risk. ∎

Appendix B Experimental Details

B.1 Model

All experiments use BERT as the base predictor. We use bert_en_uncased_L-12_H-768_A-12 from TensorFlow Hub and do not modify any parameters. Following standard practice, predictions are made using a linear map from the representation layer. We use cross-entropy loss as the training objective. We train with vanilla stochastic gradient descent, batch size 1024, and a fixed learning rate. We use patience-10 early stopping on validation risk. Each model was trained using 2 Tensor Processing Units.

For the MMD regularizer, we use the estimator of Gretton:Borgwardt:Rasch:Scholkopf:Smola:2012 with the Gaussian RBF kernel and a fixed kernel bandwidth. We compute the MMD on p̂, the model’s estimate of P(Y = 1 | X). (Note: this is the scalar predicted probability, not the full softmax output—the latter has an extra, irrelevant, degree of freedom.) We use log-spaced regularization coefficients.

B.2 Data

We don’t do any pre-processing on the MNLI data.

The Amazon review data is from [Ni:Li:McAuley:2019].

B.2.1 Inducing Dependence Between Y and Z in Amazon Product Reviews

To produce the causal-direction data with a target association between Y and Z:

  1. Randomly drop reviews based on their helpful-vote counts, until the target marginal proportions are reached.

  2. Find the smallest Z-dependent vote threshold that achieves the target association.

  3. Set Y = 1 for each example at or above its threshold and Y = 0 for each example below it.

  4. Randomly flip Y in a subset of examples until the target joint distribution of Y and Z is reached.

After data splitting, we have separate training, test, and validation sets.

To produce the anti-causal data with a target association between Y and Z, we choose a random subset with the target association. After data splitting, we have separate training, test, and validation sets.

B.2.2 Synthetic Counterfactuals in Product Review Data

We select product reviews from the Amazon “clothing, shoes, and jewelry” dataset, and assign Y = 1 if the review is 4 or 5 stars, and Y = 0 otherwise. For each review, we use only the first twenty tokens of text. We then assign Z as a Bernoulli random variable. When Z = 1, we replace the tokens “and” and “the” with “andxxxxx” and “thexxxxx” respectively; for Z = 0 we use the suffix “yyyyy” instead. Counterfactuals can then be produced by swapping the suffixes. To induce a dependency between Y and Z, we randomly resample so as to achieve the target association, using the same procedure that was used on the anti-causal version of the “natural” product reviews. After selection we retain separate training and test sets.
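For concreteness, a minimal sketch of this construction and the corresponding counterfactual swap in Python (whitespace tokenization and the hard-coded Bernoulli draw are simplifications for illustration; the actual Bernoulli parameter and resampling targets are not reproduced here):

import random

def add_synthetic_confound(text: str, z: int) -> str:
    """Append a Z-dependent suffix to the tokens "and" and "the" in the first 20 tokens."""
    suffix = "xxxxx" if z == 1 else "yyyyy"
    tokens = text.split()[:20]
    return " ".join(t + suffix if t.lower() in {"and", "the"} else t for t in tokens)

def counterfactual(text_with_confound: str) -> str:
    """Swap the suffixes, producing the text for the other value of Z."""
    out = []
    for tok in text_with_confound.split():
        if tok.endswith("xxxxx"):
            out.append(tok[:-5] + "yyyyy")
        elif tok.endswith("yyyyy"):
            out.append(tok[:-5] + "xxxxx")
        else:
            out.append(tok)
    return " ".join(out)

rng = random.Random(0)
z = int(rng.random() < 0.5)  # illustrative draw; the paper's Bernoulli parameter is not reproduced
confounded = add_synthetic_confound("The shoes fit well and look great", z)
print(confounded, "->", counterfactual(confounded))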