Causally-motivated Shortcut Removal Using Auxiliary Labels

by Maggie Makar, et al.

Robustness to certain distribution shifts is a key requirement in many ML applications. Often, relevant distribution shifts can be formulated in terms of interventions on the process that generates the input data. Here, we consider the problem of learning a predictor whose risk is invariant across such shifts. A key challenge to learning such risk-invariant predictors is shortcut learning: the tendency of models to rely on spurious correlations in practice, even when a predictor based on shift-invariant features could, in principle, achieve optimal i.i.d. generalization. We propose a flexible, causally-motivated approach to address this challenge. Specifically, we propose a regularization scheme that makes use of auxiliary labels for potential shortcut features, which are often available at training time. Drawing on the causal structure of the problem, we enforce a conditional independence between the representation used to predict the main label and the auxiliary labels. We show both theoretically and empirically that this causally-motivated regularization scheme yields robust predictors that generalize well both in-distribution and under distribution shifts, and that it does so with better sample efficiency than standard regularization or weighting approaches.
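To make the regularization idea concrete, the following is a minimal sketch of one way such a penalty could be computed. The abstract does not state which dependence measure is used, so the choice here is an assumption: for each value of the main label, the sketch measures the squared maximum mean discrepancy (MMD) with an RBF kernel between the representations of examples whose binary auxiliary label is 0 and those whose auxiliary label is 1. All function names, the binary auxiliary label, the bandwidth, and the synthetic data are hypothetical and serve only as illustration.

```python
import numpy as np

def rbf_kernel(a, b, bandwidth=1.0):
    """RBF kernel matrix between the rows of a and the rows of b."""
    sq_dist = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    # Clip tiny negative values caused by floating-point round-off.
    return np.exp(-np.maximum(sq_dist, 0.0) / (2.0 * bandwidth**2))

def mmd2(a, b, bandwidth=1.0):
    """Biased (V-statistic) estimate of the squared MMD between two samples."""
    return (
        rbf_kernel(a, a, bandwidth).mean()
        + rbf_kernel(b, b, bandwidth).mean()
        - 2.0 * rbf_kernel(a, b, bandwidth).mean()
    )

def conditional_independence_penalty(rep, y, v, bandwidth=1.0):
    """Sum over main-label values of the MMD^2 between auxiliary-label groups.

    rep: (n, d) array of representations
    y:   (n,) array of main labels
    v:   (n,) array of binary auxiliary (shortcut) labels, assumed in {0, 1}
    """
    penalty = 0.0
    for label in np.unique(y):
        in_class = y == label
        group0 = rep[in_class & (v == 0)]
        group1 = rep[in_class & (v == 1)]
        if len(group0) == 0 or len(group1) == 0:
            continue  # nothing to compare within this class
        penalty += mmd2(group0, group1, bandwidth)
    return penalty

# Synthetic illustration (assumed data, not from the paper).
rng = np.random.default_rng(0)
n = 400
y = rng.integers(0, 2, size=n)
v = rng.integers(0, 2, size=n)
rep_invariant = rng.normal(size=(n, 2)) + y[:, None]   # depends on y only
rep_shortcut = rep_invariant + 2.0 * v[:, None]        # also encodes v

penalty_invariant = conditional_independence_penalty(rep_invariant, y, v)
penalty_shortcut = conditional_independence_penalty(rep_shortcut, y, v)
```

In a training loop, a term of this kind would typically be added to the prediction loss with a tunable weight, so that representations encoding the auxiliary label beyond what the main label explains are penalized. A representation that encodes the shortcut, such as `rep_shortcut` above, should receive a larger penalty than one that depends on the main label alone.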

