Counterfactual confounding adjustment for feature representations learned by deep models: with an application to image classification tasks

04/20/2020
by   Elias Chaibub Neto, et al.

Causal modeling has been recognized as a potential solution to many challenging problems in machine learning (ML). While counterfactual thinking has been leveraged in ML tasks that aim to predict the consequences of actions/interventions, it has not yet been applied to more traditional/static supervised learning tasks, such as the prediction of labels in image classification. Here, we propose a counterfactual approach to remove/reduce the influence of confounders from the predictions generated by a deep neural network (DNN). The idea is to remove confounding from the feature representations learned by DNNs in anticausal prediction tasks. By training an accurate DNN with softmax activation at the classification layer, and then adopting the representation learned by the last layer prior to the output layer as our features, we have that, by construction, the learned features will be well fitted by a (multi-class) logistic regression model and will be linearly associated with the labels. Then, in order to generate classifiers that are free from the influence of the observed confounders, we: (i) use linear models to regress each learned feature on the labels and on the confounders, and estimate the respective regression coefficients and model residuals; (ii) generate new counterfactual features by adding the estimated residuals back to a linear predictor that no longer includes the confounder variables; and (iii) train and evaluate a logistic classifier using the counterfactual features as inputs. We validate the proposed methodology using colored versions of the MNIST and fashion-MNIST datasets, and show how the approach can effectively combat confounding and improve generalization in the context of dataset shift. Finally, we also describe how to use conditional independence tests to evaluate whether the counterfactual approach has effectively removed the confounder signals from the predictions.
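Steps (i)–(iii) can be sketched in a few lines of numpy/scikit-learn. The snippet below is an illustrative toy, not the paper's implementation: the simulated label, confounder, and "learned features" (and all variable names) are assumptions standing in for a binary label, a continuous observed confounder, and features extracted from a trained DNN's penultimate layer.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression

rng = np.random.default_rng(0)
n, p = 500, 4

# Toy anticausal setup: the label y and the confounder c both
# influence the "learned" features X (stand-ins for penultimate-layer
# activations of a trained DNN).
y = rng.integers(0, 2, size=n).astype(float)
c = y + rng.normal(0.0, 1.0, size=n)  # observed confounder, associated with y
X = np.column_stack(
    [1.5 * y + 0.8 * c + rng.normal(0.0, 1.0, size=n) for _ in range(p)]
)

# Step (i): regress each feature on the label and the confounder,
# keeping the fitted coefficients and the residuals.
design = np.column_stack([y, c])
X_cf = np.empty_like(X)
for j in range(p):
    reg = LinearRegression().fit(design, X[:, j])
    resid = X[:, j] - reg.predict(design)
    # Step (ii): counterfactual feature = linear predictor with the
    # confounder term dropped, plus the residual added back.
    X_cf[:, j] = reg.intercept_ + reg.coef_[0] * y + resid

# Step (iii): train a logistic classifier on the counterfactual features.
clf = LogisticRegression().fit(X_cf, y)
```

In this simulation the marginal correlation between each counterfactual feature and the confounder drops relative to the raw feature (any residual association flows only through the label), while the label signal is preserved for the downstream logistic classifier.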

