
Debiasing Graph Neural Networks via Learning Disentangled Causal Substructure

by   Shaohua Fan, et al.
HEC Montréal

Most Graph Neural Networks (GNNs) predict the labels of unseen graphs by learning the correlation between the input graphs and labels. However, by presenting a graph classification investigation on training graphs with severe bias, we surprisingly discover that GNNs always tend to explore spurious correlations to make decisions, even when the causal correlation always exists. This implies that existing GNNs trained on such biased datasets will suffer from poor generalization capability. By analyzing this problem from a causal view, we find that disentangling and decorrelating the causal and bias latent variables from the biased graphs are both crucial for debiasing. Inspired by this, we propose a general disentangled GNN framework to learn the causal substructure and bias substructure, respectively. Particularly, we design a parameterized edge mask generator to explicitly split the input graph into causal and bias subgraphs. Then two GNN modules, supervised by causal- and bias-aware loss functions respectively, are trained to encode the causal and bias subgraphs into their corresponding representations. With the disentangled representations, we synthesize counterfactual unbiased training samples to further decorrelate the causal and bias variables. Moreover, to better benchmark the severe bias problem, we construct three new graph datasets, which have controllable bias degrees and are easy to visualize and explain. Experimental results demonstrate that our approach achieves superior generalization performance over existing baselines. Furthermore, owing to the learned edge mask, the proposed model has appealing interpretability and transferability. Code and data are available at:





1 Introduction

Graph Neural Networks (GNNs) have exhibited powerful performance on graph data with various applications kipf2016semi; velivckovic2017graph; hamilton2017inductive; fan2019metapath; fan2020one2multi. One major category of applications is the graph classification task, such as molecular graph property prediction hu2020open; lee2018graph; ying2018hierarchical, superpixel graph classification hendrycks2019benchmarking, and social network category classification zhang2018end; ying2018hierarchical. It is well known that graph classification is usually determined by a relevant substructure, not the whole graph structure ying2019gnnexplainer; lucic2021cf; yuan2021explainability. For example, in the MNIST superpixel graph classification task, the digit subgraphs are causal (i.e., deterministic) for the labels vu2020pgm. The mutagenic property of a molecular graph depends on its functional groups (e.g., nitrogen dioxide, NO2), rather than irrelevant patterns (e.g., carbon rings) luo2020parameterized. Therefore, it is a fundamental requirement for GNNs to identify causal substructures, so as to make correct predictions.

Ideally, when the graphs are unbiased, i.e., only the causal substructures are related with the graph labels, GNNs are able to utilize such substructures to predict the labels. However, due to the uncontrollable data collection process, graphs are inevitably biased, i.e., meaningless substructures spuriously correlate with labels. Taking the colored MNIST superpixel graph dataset in Sec. 3.1 as an example (illustrated in Fig. 1(a)), each category of digit subgraphs mainly corresponds to one kind of colored background subgraph, e.g., the digit-0 subgraph is related with the red background subgraph. Therefore, the colored background subgraph will be treated as bias information, which highly correlates with the labels in the training set but does not determine them. Under this situation, will GNNs still stably utilize the causal substructure to make decisions?

To investigate the impact of bias on GNNs, we conduct an experimental study of the generalization capability of GNNs, especially in severe bias scenarios (Sec. 3.1). We find that GNNs actually utilize both bias and causal substructures to make predictions. However, as the bias correlation becomes more severe, GNNs mainly utilize the bias substructure as a shortcut for prediction, even though, unlike the causal substructure, the bias substructure cannot exactly determine the labels; this causes a large degradation in generalization performance. Why does this happen? We analyze the data-generating process and the model prediction mechanism behind graph classification using a causal graph (Sec. 3.2). The causal graph illustrates that the observed graphs are generated by the causal and bias latent variables, and that existing GNNs cannot distinguish the causal substructure from the entangled graphs. How can we disentangle the causal and bias substructures from observed graphs, so that GNNs utilize only the causal substructures to make stable predictions when severe bias appears?

To address this question, two challenges need to be faced. 1) How to identify the causal substructure and bias substructure in severely biased graphs? In severe bias scenarios, the bias substructure is "easier to learn" for GNNs and finally dominates the prediction. Using a normal cross-entropy loss, as in DIR wu2022discovering, cannot fully capture this aggressive property of bias. 2) How to extract the causal substructure from an entangled graph? The statistically causal substructure is usually determined by the global property of the entire graph population, rather than by a single graph. When extracting the causal substructure from a graph, we need to establish relations among all the graphs.

In this paper, we propose a novel debiasing framework for GNNs via learning a Disentangled Causal substructure, called DisC. Given an input biased graph, we explicitly split its edges into causal and bias subgraphs with a parameterized edge mask generator, whose parameters are shared across the entire graph population. As a result, the edge masker is naturally capable of specifying the importance of each edge and extracting the causal and bias subgraphs from a global view of the entire set of observations. Then a "causal"-aware loss (weighted cross-entropy) and a "bias"-aware loss (generalized cross-entropy) are utilized to supervise two functional GNN modules, respectively. Under this supervision, the edge mask generator generates the corresponding subgraphs and the GNNs encode them into disentangled embeddings. With the disentangled embeddings, we randomly permute the bias latent vectors across different graphs to generate more unbiased counterfactual samples in the embedding space. The newly generated samples still contain both causal and bias information, but their correlation has been removed. Then, only the causal variables correlate with the labels, so that the model can concentrate on the true correlation between the causal subgraphs and the labels. Our major contributions are as follows:


  • To our knowledge, we are the first to study the generalization problem of GNNs in a more challenging yet practical scenario, i.e., graphs with severe bias. We systematically analyze the impact of bias on GNNs through both an experimental study and a causal analysis. We find that the bias substructure, compared with the causal substructure, much more easily dominates the training of GNNs.

  • To debias GNNs, we develop a novel GNN framework for disentangling the causal substructure, which can be flexibly built upon various GNNs to improve their generalization ability while enjoying inherent interpretability, robustness and transferability.

  • We construct three new datasets with various properties and controllable bias degrees, which better benchmark the new problem. Our model outperforms the corresponding base models by a large margin (from 4.47% to 169.17% average improvement). Various investigation studies demonstrate that our model can discover and leverage the causal substructure for prediction.

2 Related Works

Generalization of GNNs in wild environments. Most existing GNN methods are proposed under the IID hypothesis, i.e., the training and testing sets are independently sampled from an identical distribution scarselli2008graph; kipf2016semi; velivckovic2017graph; hamilton2017inductive; liao2020pac. However, in reality, this ideal assumption is hard to satisfy. Recently, several methods have been proposed to improve the generalization ability of GNNs in wild OOD environments. Several works ma2021subgroup; fan2022debiased; wu2022handling study the OOD problem in node classification. For the OOD graph classification task, StableGNN fan2021generalizing proposes to learn stable causal relationships in graphs, OOD-GNN li2021ood proposes to constrain each dimension of the learned embedding to be independent, and DIR wu2022discovering discovers invariant rationales for generalizing GNNs. Although they achieve better OOD performance, these methods are not designed for datasets with severe bias, which makes guaranteeing the generalization ability of GNNs more challenging.

Disentangled graph neural networks. Recently, a couple of methods have studied disentangled GNNs. DisenGCN ma2019disentangled utilizes a neighbourhood routing mechanism to divide the neighbours of a node into several mutually exclusive parts. IPGDN liu2020independence improves DisenGCN by constraining the different parts of the embedding to be independent. DisenGCN and IPGDN perform node-level disentanglement; FactorGCN yang2020factorizable therefore considers whole-graph information and disentangles the target graph into several factorized graphs. Despite the results of these works, they do not consider disentangling the causal and bias information in graphs.

General debiasing methods. Recently, the debiasing problem has drawn much attention in the machine learning community kim2019learning; li2019repair; sagawa2019distributionally; bahng2020learning; cadene2019rubi; geirhos2018imagenet. One category of these methods explicitly pre-defines a certain bias type to mitigate kim2019learning; li2019repair; sagawa2019distributionally; bahng2020learning; wang2019learning. For example, Wang et al. wang2019learning and Bahng et al. bahng2020learning design texture- and color-guided models to adversarially train a debiased neural network against a biased one. Instead of defining certain types of bias, recent approaches nam2020learning; darlow2020latent; lee2021learning rely on the straightforward assumption that models are prone to exploit the bias as a shortcut to make predictions geirhos2020shortcut. In line with these recent studies, our work belongs to the second category. However, most existing methods are designed for image datasets and cannot effectively extract causal substructures from graph data. Distinctly, we are the first to study the severe bias problem on graph data, and our method can effectively extract the causal substructure from graphs.

(a) Examples of graphs in CMNIST-75sp.
(b) Performance of GNNs.
Figure 1: Example graphs of CMNIST-75sp and the performance of GNNs on this dataset.

3 Preliminary Study and Analysis

In this section, we first illustrate through a motivating experiment that existing GNNs tend to exploit the bias substructure as a shortcut for prediction. Then we analyze the prediction process of GNNs from a causal view, which motivates our solution for relieving the impact of bias.

3.1 Motivating Example

To measure the generalization ability of GNNs under the effect of bias, we construct a graph classification dataset with controllable bias degrees, called CMNIST-75sp. We first construct a biased MNIST image dataset as in bahng2020learning, where each digit category highly correlates with a pre-defined color in its background. For example, in the training set, 90% of the 0 digits have a red background (i.e., biased samples), while the remaining 10% of images have a random background color (i.e., unbiased samples); the bias degree is 0.9 in this situation. We consider four bias degrees. For testing, we construct both a biased testing set and an unbiased testing set. The biased testing set has the same bias degree as the training set, aiming to measure the extent to which models rely on bias. The unbiased testing set, where the digit labels are uncorrelated with the background colors, tests whether a model can utilize the inherent digit signals for prediction. Note that the training and testing sets share the same pre-defined color set. Then, we convert the biased MNIST images into superpixel graphs with at most 75 nodes per graph following knyazev2019understanding, where the edges are constructed by the KNN method based on the 2D coordinates of the superpixels, and the node features are the concatenation of the coordinates and the average color of each superpixel. Each graph is labeled by its digit class, so that its digit subgraph is deterministic for the label while the background subgraph is spuriously correlated with the label but not deterministic. Example graphs are illustrated in Fig. 1(a).
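The color-assignment protocol described above can be sketched in a few lines of NumPy. This is an illustrative reconstruction, not the authors' released code; the helper name `assign_background_colors` and the color-index convention (color i pre-assigned to digit i) are our own assumptions.

```python
import numpy as np

def assign_background_colors(labels, bias_degree, n_colors=10, rng=None):
    """For each sample, pick the class-aligned background color with
    probability `bias_degree`; otherwise pick a uniformly random color.
    Returns an array of color indices (color i is pre-assigned to digit i)."""
    rng = np.random.default_rng(rng)
    labels = np.asarray(labels)
    biased = rng.random(len(labels)) < bias_degree        # biased samples
    random_colors = rng.integers(0, n_colors, len(labels))  # unbiased samples
    return np.where(biased, labels % n_colors, random_colors)

labels = np.repeat(np.arange(10), 100)   # 1000 samples, 10 digit classes
colors = assign_background_colors(labels, bias_degree=0.9, rng=0)
# Roughly 90% of each class now shares its class-aligned background color
# (slightly more, since a random color can coincide with the aligned one).
```

The superpixel-graph step (SLIC segmentation followed by a KNN graph over superpixel centroids) would then consume these colored images.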


We run three popular GNN methods, GCN kipf2016semi, GIN xu2018how, and GCNII chen2020simple, on CMNIST-75sp; the results are shown in Fig. 1(b). Dashed and solid lines of the same color represent the results of the corresponding method on the biased testing set and the unbiased testing set, respectively. Overall, the GNNs achieve much better performance on the biased testing set than on the unbiased one. This phenomenon indicates that although GNNs can still learn some causal signals for prediction, the unexpected bias information is also being utilized. More specifically, as the bias degree grows, the performance of GNNs on the biased testing set increases, with accuracy values nearly in line with the bias degree, while the performance on the unbiased testing set drops dramatically. Hence, although the causal substructure could determine the labels perfectly, in severe bias scenarios GNNs lean toward the easier-to-learn bias information rather than the inherent causal signals, and the bias substructure finally dominates the prediction.

(a) SCM of the union of the data generation and the existing GNNs’ prediction process.
(b) SCM of our debiasing GNN method.
Figure 2: SCMs. Grey and white variables represent unobserved and observed variables, respectively.

3.2 Problem Analysis

Debiasing GNNs for unbiased prediction requires understanding the natural mechanisms of the graph classification task. We present a causal view of the union of the data-generating process and the model prediction process behind the task. We formalize this causal view as a Structural Causal Model (SCM) or causal graph glymour2016causal; pearl2009causality by inspecting the causalities among five variables: the unobserved causal variable C, the unobserved bias variable B, the observed graph G, the graph embedding Z, and the ground-truth label / prediction Y (we use the variable Y for both the ground-truth label and the prediction, as they are optimized to be the same). Fig. 2(a) illustrates the SCM, where each link denotes a causal relationship.


  • C → G ← B. The observed graph data G is generated by two unobserved latent variables: the causal variable C and the bias variable B, such as the digit subgraphs and background subgraphs in the CMNIST-75sp dataset. All the relations below are illustrated with CMNIST-75sp.

  • C → Y. This link means that the causal variable C is the only endogenous parent determining the generation of the ground-truth label Y. For example, C is the oracle digit subgraph, which exactly explains why the graph is assigned its label.

  • C ⟷ B. This link indicates the spurious correlation between C and B. Such a probabilistic dependency is usually caused by a direct cause or an unobserved confounder reichenbach1991direction. Here we do not distinguish these scenarios and only observe the spurious correlation between C and B, such as that between the colored background subgraphs and the digit subgraphs.

  • G → Z → Y. Existing GNNs usually learn the graph embedding Z based on the observed graph G and make the prediction based on the learned embedding Z.

According to the SCM, GNNs will utilize both kinds of information to make predictions. As the bias substructure (e.g., the background subgraph) usually has a simpler structure than the meaningful causal substructure (e.g., the digit subgraph), a GNN that utilizes this simple substructure can achieve low loss very fast. Hence, GNNs incline to utilize bias information when most graphs are biased. Based on the SCM in Fig. 2(a) and according to the d-connection theory pearl2009causality (see App. A), which states that two variables are dependent if they are connected by at least one unblocked path, we can find two paths that induce the spurious correlation between the bias variable B and the label Y: (1) B → G → Z → Y and (2) B ⟷ C → Y. To make the prediction uncorrelated with the bias B, we need to intercept these two unblocked paths. For this purpose, we propose to debias GNNs from the causal view, as in Fig. 2(b).


  • G → C and G → B. To intercept path (1), we should disentangle the latent variables C and B from the observed graph G and make the prediction only based on the causal variable C.

  • C ⊥ B. To intercept path (2), as we cannot change the link between C and B in the data-generating process, one possible solution is to make C and B uncorrelated.
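The spurious dependence along path (2) and the effect of decorrelating B from C can be illustrated with a toy NumPy simulation. This is purely illustrative; the linear data-generating equations and coefficients are our own assumptions, not the paper's model.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
c = rng.normal(size=n)                   # causal variable C
b = 0.8 * c + 0.6 * rng.normal(size=n)   # bias B, spuriously correlated with C
y = c + 0.1 * rng.normal(size=n)         # label Y generated from C only

# B and Y are strongly correlated, purely through the path B <-> C -> Y.
corr_by = np.corrcoef(b, y)[0, 1]

# Decorrelating B from C (here: permuting B across samples, mimicking the
# counterfactual sample generation of Sec. 4.3) removes the dependence.
b_perm = rng.permutation(b)
corr_perm = np.corrcoef(b_perm, y)[0, 1]
```

Even though Y never depends on B directly, any model regressing Y on (B, C) jointly can exploit B as a shortcut; after permutation, B carries no information about Y.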

4 Methodology

Motivated by the above causal analysis, in this section we present our proposed debiasing GNN framework, DisC, which removes the spurious correlation. The overall framework is shown in Fig. 3. First, an edge mask generator is learned to split the edges of each input graph into a causal subgraph and a bias subgraph. Second, two separate GNN modules are trained on their corresponding masked subgraphs to encode the causal and bias substructures into disentangled representations, respectively. Last, after the disentangled representations are well trained, we permute the bias representations among the training graphs to generate counterfactual unbiased samples, removing the correlation between the causal and bias representations.

Figure 3: The overall framework of DisC.

4.1 Causal and Bias Substructure Generator

Given a mini-batch of biased graphs, our idea is to take a collection of graph instances and design a generative probabilistic model that learns to mask the edges into a causal subgraph or a bias subgraph. Particularly, given a graph G = (A, X), where A is the adjacency matrix and X is the node feature matrix, we utilize a multi-layer perceptron (MLP) upon the concatenation of the features x_i of node i and x_j of node j to measure the importance α_ij of edge (i, j) for the causal subgraph:

α_ij = MLP([x_i ; x_j]).

Then a sigmoid function σ is employed to project α_ij into the range (0, 1), which indicates the probability of edge (i, j) being an edge in the causal subgraph:

m_ij = σ(α_ij).

Naturally, the probability of edge (i, j) being an edge in the bias subgraph is 1 − m_ij. We can now construct the causal edge mask M_c and the bias edge mask M_b. Finally, we decompose the original graph into the causal subgraph G_c = (A ⊙ M_c, X) and the bias subgraph G_b = (A ⊙ M_b, X). Intuitively, the edge masks highlight different parts of the structural information of the original graph, so GNNs built on the different subgraphs can encode different parts of the graph information. Moreover, the mask generator has two advantages. (1) Global view: At the individual graph level, the mask generator (i.e., the MLP), whose parameters are shared by all the edges in a graph, takes a global view of all the edges, which enables us to identify communities in the graph. It is well known that the effect of an edge cannot be judged independently, because edges usually collaborate with each other, forming a community, to make a prediction; it is therefore critical to evaluate an edge from a global view. At the whole-population level, the mask generator takes a global view of all the graphs in the training set, which enables us to identify the causal/bias subgraphs. Particularly, as causality/bias is statistical information at the population level, it is necessary to view all the graphs to identify the causal/bias substructure. Considering both the coalition effects and the population-level statistics, the generator can measure the importance of edges more accurately. (2) Generalization: The mask generator can apply its mask-generation mechanism to new graphs without retraining, so it can efficiently prune unseen graphs.
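A minimal NumPy sketch of this masker follows. It is an illustrative reconstruction under assumptions: the one-hidden-layer shapes, tanh activation, and dense-matrix interface are ours, and the paper's exact MLP architecture may differ.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class EdgeMaskGenerator:
    """Shared edge masker: an MLP scores each edge from the concatenated
    features of its endpoints; sigmoid turns the score into the probability
    m_ij of the edge belonging to the causal subgraph."""
    def __init__(self, feat_dim, hidden=16, rng=None):
        rng = np.random.default_rng(rng)
        self.w1 = rng.normal(scale=0.1, size=(2 * feat_dim, hidden))
        self.w2 = rng.normal(scale=0.1, size=(hidden, 1))

    def __call__(self, adj, x):
        src, dst = np.nonzero(adj)                       # existing edges (i, j)
        h = np.tanh(np.concatenate([x[src], x[dst]], axis=1) @ self.w1)
        alpha = (h @ self.w2).ravel()                    # edge importance alpha_ij
        m = sigmoid(alpha)                               # P(edge in causal subgraph)
        mask_c = np.zeros(adj.shape)
        mask_c[src, dst] = m
        mask_b = np.zeros(adj.shape)
        mask_b[src, dst] = 1.0 - m                       # complementary bias mask
        return adj * mask_c, adj * mask_b                # weighted G_c and G_b

adj = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]])
x = np.random.default_rng(1).normal(size=(3, 4))
a_c, a_b = EdgeMaskGenerator(feat_dim=4, rng=0)(adj, x)
# On every existing edge, the causal and bias weights sum back to the
# original adjacency entry.
```

Because the weights `w1`/`w2` are shared across all edges and all graphs, the same object can mask unseen graphs without retraining, matching the generalization property above.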

4.2 Learning Disentangled Graph Representations

Given G_c and G_b, how can we ensure they are the causal subgraph and bias subgraph, respectively? Inspired by lee2021learning, our approach simultaneously trains a pair of GNNs with linear classifiers as follows: (1) motivated by the observation in Sec. 3.1 that the bias substructure is easier to learn, we utilize a bias-aware loss to train a bias GNN and a bias classifier; and (2) in contrast, we train a causal GNN and a causal classifier on the training graphs that the bias GNN struggles to learn. Next, we present each component in detail.

As shown in Fig. 3, the causal GNN g_c and the bias GNN g_b embed the corresponding subgraphs into the causal embedding z_c = g_c(G_c; θ_c) and the bias embedding z_b = g_b(G_b; θ_b), respectively, where θ_c and θ_b are the GNN parameters. Subsequently, the concatenated vector z = [z_c ; z_b] is fed into linear classifiers f_c and f_b to predict the target label y. To train g_b and f_b as a bias extractor, we utilize the generalized cross-entropy (GCE) zhang2018generalized loss to amplify the bias of the bias GNN and classifier:

GCE(p(z; φ_b), y) = (1 − p_y(z; φ_b)^q) / q,

where p and p_y are the softmax output of the bias classifier f_b and its probability for the target category y, respectively, and φ_b denotes the classifier parameters. Here q ∈ (0, 1] is a hyperparameter that controls the degree of amplifying bias. Given p_y, the gradient of the GCE loss up-weights the gradient of the standard cross-entropy (CE) loss for samples with a high confidence of predicting the correct target category:

∂GCE(p, y)/∂φ_b = p_y^q · ∂CE(p, y)/∂φ_b.

Therefore, compared with the CE loss, the GCE loss amplifies the gradients on samples by the confidence score p_y^q. Based on our observation that bias information is usually easier to learn, biased graphs will have a higher p_y than unbiased graphs. Therefore, the modules g_b and f_b trained with the GCE loss will focus on the bias information and finally yield the bias subgraph. Note that, to ensure f_b predicts the target labels mainly based on z_b, the bias loss is not backpropagated to z_c but only updates the bias branch, and vice versa.
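The gradient relation between GCE and CE can be verified numerically. The sketch below implements the GCE loss of zhang2018generalized for a single sample and checks, by central finite differences, that its logit gradient equals p_y^q times the CE gradient (an exercise we add for clarity; not the paper's code).

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # numerically stable softmax
    e = np.exp(z)
    return e / e.sum()

def ce_loss(logits, y):
    """Standard cross entropy: -log p_y."""
    return -np.log(softmax(logits)[y])

def gce_loss(logits, y, q=0.7):
    """Generalized cross entropy: (1 - p_y^q) / q; reduces to CE as q -> 0."""
    return (1.0 - softmax(logits)[y] ** q) / q

# Finite-difference gradients w.r.t. the logits for one sample.
logits, y, q, eps = np.array([2.0, 0.5, -1.0]), 0, 0.7, 1e-6
grad_gce = np.array([
    (gce_loss(logits + eps * np.eye(3)[i], y, q)
     - gce_loss(logits - eps * np.eye(3)[i], y, q)) / (2 * eps)
    for i in range(3)])
grad_ce = np.array([
    (ce_loss(logits + eps * np.eye(3)[i], y)
     - ce_loss(logits - eps * np.eye(3)[i], y)) / (2 * eps)
    for i in range(3)])
p_y = softmax(logits)[y]
# grad_gce matches p_y**q * grad_ce: confident (biased) samples are up-weighted.
```

Because p_y^q is largest for high-confidence samples, the GCE branch concentrates on the easy-to-fit, bias-aligned graphs, exactly the amplification behavior described above.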

Meanwhile, we also train the causal GNN simultaneously with a weighted CE loss. Graphs with a high CE loss under the bias branch can be regarded as unbiased samples compared with graphs with a low CE loss. In this regard, we can obtain an unbias score for each graph G as

w(G) = CE(f_b(z), y) / (CE(f_c(z), y) + CE(f_b(z), y)).

A large value of w(G) implies the graph is an unbiased sample; hence we use these weights to reweight the loss of the graphs when training g_c and f_c, enforcing them to learn the unbiased information. Thus, the objective function for learning disentangled representations is

L_dis = w(G) · CE(f_c(z), y) + GCE(f_b(z), y).
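The reweighting scheme above (following lee2021learning) can be sketched as a small helper. The epsilon term and the concrete loss values are our own illustrative assumptions.

```python
import numpy as np

def unbias_score(ce_bias, ce_causal, eps=1e-12):
    """Relative difficulty score w(G) = CE_b / (CE_c + CE_b): close to 1
    for graphs the bias branch finds hard (likely unbiased samples),
    close to 0 for graphs it fits easily (likely bias-aligned)."""
    return ce_bias / (ce_causal + ce_bias + eps)

# A bias-aligned graph: the bias branch fits it easily (low CE_b) -> small weight,
# so it barely contributes to the causal branch's weighted CE loss.
w_aligned = unbias_score(ce_bias=0.05, ce_causal=1.2)

# A bias-conflicting graph: the bias branch struggles (high CE_b) -> large weight,
# pushing the causal branch to learn from exactly these samples.
w_conflict = unbias_score(ce_bias=2.5, ce_causal=1.2)

# Per-sample disentanglement loss would then be: w * CE_causal + GCE_bias.
```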


4.3 Counterfactual Unbiased Sample Generation

Until now, we have achieved the first goal analyzed in Sec. 3.2, namely the disentanglement of the causal and bias substructures. Next, we show how to achieve the second goal, making the causal variable and bias variable uncorrelated. Although we have disentangled the causal and bias information, they were disentangled from the biased observed graphs; hence statistical correlation between the causal and bias variables, inherited from the biased observations, still exists. To further decorrelate them, following the causal relation of the data-generating process, we propose to generate counterfactual unbiased samples in the embedding space by swapping the bias vectors. More specifically, we randomly permute the bias vectors z_b in each mini-batch and obtain z̃ = [z_c ; z̃_b], where z̃_b represents the randomly permuted bias vectors. As z_c and z̃_b in z̃ are randomly combined from different graphs, they have much less correlation than in z = [z_c ; z_b], where both come from the same graph. To keep g_b and f_b focused on the bias information, we also swap the labels along with z̃_b, so that the spurious correlation between z̃_b and its (permuted) label ỹ still exists. With the generated unbiased samples, we train the two GNN modules with the loss

L_gen = w(G) · CE(f_c(z̃), y) + GCE(f_b(z̃), ỹ).

Together with the disentanglement loss, the total loss function is

L = L_dis + λ L_gen,

where λ is a hyperparameter weighting the importance of the generation component. Moreover, training with more diverse samples also benefits generalization on unseen testing scenarios. Our approach is summarized in App. B. Note that, as we need well-disentangled representations to generate high-quality unbiased samples, in the early stage of training we only train the model with L_dis; after a certain number of epochs, we train the model with the full loss L.
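The swap operation is a one-line permutation per mini-batch. The sketch below is illustrative (embedding sizes, the synthetic correlation between z_c and z_b, and the helper name are our assumptions); it also checks that the permutation does collapse the cross-correlation.

```python
import numpy as np

def counterfactual_batch(z_c, z_b, y, rng=None):
    """Swap bias vectors within a mini-batch: pair each causal vector with
    another graph's bias vector (and that graph's label, used only by the
    bias branch), breaking the z_c/z_b correlation while keeping both
    kinds of signal present in the synthesized sample."""
    rng = np.random.default_rng(rng)
    perm = rng.permutation(len(z_b))
    z_swap = np.concatenate([z_c, z_b[perm]], axis=1)  # [z_c ; permuted z_b]
    return z_swap, y[perm]

rng = np.random.default_rng(0)
z_c = rng.normal(size=(128, 8))
z_b = 0.9 * z_c + 0.1 * rng.normal(size=(128, 8))  # strongly tied to z_c
z_swap, y_perm = counterfactual_batch(z_c, z_b, np.arange(128), rng=1)
# Dimension-wise correlation between z_c and the permuted bias half of
# z_swap drops to roughly chance level.
```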


5 Experiment

Datasets. We construct three datasets with various properties and bias ratios to benchmark this new problem; each dataset has clear causal subgraphs, making the results explainable. Following CMNIST-75sp introduced in Sec. 3.1, we construct the CFashion-75sp and CKuzushiji-75sp datasets in a similar way based on Fashion-MNIST xiao2017fashion and Kuzushiji-MNIST clanuwat2018deep. As the causal subgraphs of these two datasets are more complicated (fashion products and hiragana characters), they are more challenging. Due to page limits, here we set the bias degrees to 0.8, 0.9 and 0.95. We report the main results on the unbiased test sets. Details are in App. C.1.

Baselines and experimental setup. As DisC is a general framework that can be built on various base GNN models, we select three popular GNNs: GCN kipf2016semi, GIN xu2018how, and GCNII chen2020simple, and term the corresponding models DisC_GCN, DisC_GIN and DisC_GCNII, respectively. Hence, the base models are the most straightforward baselines. Another kind of baseline is the causal-inspired GNN methods DIR wu2022discovering and StableGNN fan2021generalizing. We also compare against a general debiasing method, LDD lee2021learning, by replacing its encoder with GNNs. The graph pooling method DiffPool ying2018hierarchical and the graph disentangling method FactorGCN yang2020factorizable are also compared. For a fair comparison, our model uses the same GNN architecture and hyperparameters as the corresponding base model. All experiments are run 4 times with different random seeds, and we report the accuracy and standard error. More details are in the Appendix.


5.1 Quantitative Evaluation

Main results. The overall results are summarized in Table 1, and we have following observations:


  • DisC has much better generalization ability than the base models. DisC consistently outperforms the corresponding base model by a large margin, and with heavier biases our model achieves larger improvements. Specifically, for CMNIST-75sp, CFashion-75sp and CKuzushiji-75sp with the smaller bias degree (i.e., 0.8), our model achieves 40.02%, 4.47% and 29.82% average improvements over the corresponding base models, respectively. Surprisingly, with severer biases (0.9 and 0.95), DisC achieves 169.17%, 14.67% and 49.35% average improvements over the base models on the three datasets, respectively. This indicates that the proposed method is a general framework that helps existing GNNs against the negative impact of bias.

  • DisC significantly outperforms existing debiasing methods. We notice that DIR does not achieve satisfying results. The reason is that DIR utilizes the CE loss to extract bias information, which cannot fully capture the property of bias in severe bias scenarios, and DIR sets one fixed threshold to split subgraphs, which is suboptimal. StableGNN outperforms its base model DiffPool and achieves competitive results, indicating the effectiveness of its causal variable distinguishing regularizer. However, its framework adjusts the data distribution based on the original dataset, making it hard to generate an unbiased distribution when unbiased samples are scarce; DisC instead generates more unbiased samples based on the disentangled representations. Moreover, LDD is a general debiasing method not designed for graph data. DisC outperforms the corresponding LDD variants by 23.15% on average, indicating that the seamless combination of a global-population-aware edge masker with the debiasing disentanglement framework is very effective.

Table 1: Graph classification accuracy evaluated on unbiased testing sets, which share the same color (bias) set with the training set. Columns: CMNIST-75sp, CFashion-75sp and CKuzushiji-75sp at bias degrees 0.8 / 0.9 / 0.95; rows: FactorGCN yang2020factorizable, DiffPool ying2018hierarchical, DIR wu2022discovering, StableGNN fan2021generalizing, GCN kipf2016semi, GIN xu2018how, GCNII chen2020simple and the DisC variants. The best performance within each base model variant is in bold. (Numeric entries omitted.)

Ablation studies. To validate the importance of each module in our method, we conduct ablation studies on our variants (w.o. G means without the sample generation module) and the related variants of LDD in Fig. 4. The major difference between DisC/w.o. G and LDD/w.o. G is the edge mask module. In most cases, DisC/w.o. G significantly outperforms LDD/w.o. G, indicating the necessity of learning edge masks for graph data. DisC, with its counterfactual sample generation module, further boosts performance over the disentangled embeddings of DisC/w.o. G. However, LDD seldom outperforms LDD/w.o. G, and sometimes performs worse. That is, generating high-quality counterfactual samples requires well-disentangled causal and bias embeddings; if the embeddings are not well disentangled, the counterfactual samples may act as noisy samples, preventing further improvement. The edge masker helps the model generate well-disentangled embeddings, which is crucial for the overall performance.

(a) CMNIST-75sp
(b) CFashion-75sp
(c) CKuzushiji-75sp
Figure 4: Ablation studies of the DisC vs. LDD average over three bias degrees of each dataset.

Robustness to unseen bias. Table 2 reports the results of DisC compared with its corresponding base models on a testing set with unseen bias, i.e., the pre-defined color (bias) sets of the training and testing sets are disjoint. The performance of the base models drops further compared with the seen-bias results in Table 1. However, our model still achieves very stable performance, fully demonstrating its generalization ability in the agnostic bias scenario.

Table 2: Results on unseen unbiased testing sets, i.e., the colors have not been seen in the training set (CMNIST-75sp, CFashion-75sp and CKuzushiji-75sp at bias degrees 0.8 / 0.9 / 0.95; numeric entries omitted).

Hyperparameter experiments. Fig. 5 shows the hyperparameter experiments for the degree q of amplifying bias in the GCE loss and the weight of the generation component. For each hyperparameter, we fix the other one and vary it over a range of values. From the results, we can see that our model achieves stable performance across different values of both hyperparameters. As q approaches 0, the GCE loss nearly reduces to the normal CE loss; the performance in this case is worse than in the other scenarios, demonstrating the effectiveness of the GCE loss.

(a) q
Figure 5: Hyperparameter experiments on q and the importance of the generation component.

5.2 Qualitative Evaluation

Visualization of edge mask. To better illustrate the causal and bias subgraphs extracted by DisC, we visualize the original images, the original graph, and the corresponding causal and bias subgraphs of CMNIST-75sp with 0.9 bias degree in Fig. 6, where the width of an edge represents its learned weight. Fig. 6(a) shows the visualization results for testing graphs whose bias (color) has been seen in the training set. As we can see, our model discovers causal subgraphs in which the most salient edges lie within the digit subgraph. With these causal subgraphs highlighting the structural information of the digit, GNNs can more easily extract this causal information. Fig. 6(b) shows the visualization results for testing graphs with unseen bias. Our model still discovers the causal subgraph outline, indicating that it recognizes causal subgraphs whether the bias is seen or unseen. The visualization results for CFashion-75sp and CKuzushiji-75sp are shown in App. D.

Figure 6: Visualization of subgraphs extracted by DisC. The width of an edge represents its learned weight.
(a) labeled by digit.
(b) labeled by color.
(c) labeled by digit.
(d) labeled by color.
Figure 7: Visualization of the causal and bias representations, colored by the ground-truth (digit) labels and the bias (color) labels. We observe that the causal and bias representations are well clustered according to the ground-truth labels and the bias labels, respectively.

Projection of disentangled representation. Fig. 7 shows the t-SNE [t-sne] projections of the latent vectors extracted from the causal GNN and the bias GNN of DisC on CMNIST-75sp. Fig. 7(a-b) are the projections of the causal representations labeled by the target labels (digit) and the bias labels (color), respectively. Fig. 7(c-d) are the projections of the bias representations labeled by the target labels and the bias labels, respectively. We observe that the causal representations are clustered according to the target labels while the bias representations are clustered according to the bias labels; conversely, the causal representations are mixed with respect to the bias labels and the bias representations are mixed with respect to the target labels. These results indicate that DisC successfully learns disentangled causal and bias representations.

Figure 8: Performance of GIN and GCNII on the weighted pruned graphs found by DisC.

Transferability of the learned mask. Since our model extracts GNN-independent subgraphs, the learned edge weights can be used to purify the original biased graphs. These sparse subgraphs carry the significant semantic information and can be universally transferred to any GNN. To validate this point, we learn the edge mask with DisC, prune the edges with the smallest weights, and keep the weights of the remaining edges. We then train vanilla GIN and GCNII on these weighted pruned datasets. Fig. 8 compares the results, where the dashed lines represent the base models on the original biased graphs and the solid lines represent the same GNNs on the weighted pruned datasets. The GNNs trained on the pruned datasets achieve better performance, indicating that our learned edge mask has considerable transferability.
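The pruning step described above can be sketched as follows, assuming the mask generator has already produced one weight per edge; the function name and the keep ratio are illustrative, not taken from the released code:

```python
import numpy as np

def prune_edges(edge_index, edge_weight, keep_ratio=0.7):
    """Drop the edges with the smallest learned mask weights and keep
    the remaining edges together with their weights, so that any GNN
    can consume the weighted pruned graph."""
    k = int(np.ceil(keep_ratio * edge_weight.shape[0]))
    keep = np.argsort(edge_weight)[-k:]   # indices of the k largest weights
    return edge_index[:, keep], edge_weight[keep]
```

Here `edge_index` is the usual 2 x E COO connectivity matrix; the returned edge weights can then be fed to a weighted message-passing layer of, e.g., GIN or GCNII.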

6 Conclusion

In this paper, we are the first to study the generalization problem of GNNs on severely biased datasets, which is crucial for understanding the knowledge learning mechanism of GNNs transparently. We analyze the problem from a causal view: the generalization of GNNs is hindered both by entangled representations and by the correlation between causal and bias variables. To remove both effects, we propose a general disentangling framework, DisC, which extracts the causal substructure and the bias substructure with two functionally different GNNs. Once the representations are well disentangled, we proliferate counterfactual unbiased samples by randomly swapping the disentangled vectors. On the newly constructed benchmarks, we clearly validate the effectiveness, robustness, interpretability, and transferability of our method.

This work is supported in part by the National Natural Science Foundation of China (No. U20B2045, 62192784, 62172052, 62002029, U1936014). This work is also partially supported by the Natural Sciences and Engineering Research Council (NSERC) Discovery Grant, the Canada CIFAR AI Chair Program, collaboration grants between Microsoft Research and Mila, Samsung Electronics Co., Ltd., Amazon Faculty Research Award, Tencent AI Lab Rhino-Bird Gift Fund and a NRC Collaborative R&D Project (AI4D-CORE-06). This project was also partially funded by IVADO Fundamental Research Project grant PRF-2019-3583139727. The work of Shaohua Fan is supported by the China Scholarship Council (No. 202006470078). The computation resource of this project is supported by Compute Canada.



  1. For all authors…

    1. Do the main claims made in the abstract and introduction accurately reflect the paper’s contributions and scope?

    2. Did you describe the limitations of your work? See App. E.

    3. Did you discuss any potential negative societal impacts of your work? We could not foresee any potential negative societal impacts of our work.

    4. Have you read the ethics review guidelines and ensured that your paper conforms to them?

  2. If you are including theoretical results…

    1. Did you state the full set of assumptions of all theoretical results? See Sec. 3.2.

    2. Did you include complete proofs of all theoretical results?

  3. If you ran experiments…

    1. Did you include the code, data, and instructions needed to reproduce the main experimental results (either in the supplemental material or as a URL)? We provide the URL of code and data for reproducing the main results.

    2. Did you specify all the training details (e.g., data splits, hyperparameters, how they were chosen)? See Sec. 5 and App. C.

    3. Did you report error bars (e.g., with respect to the random seed after running experiments multiple times)? See Sec. 5.

    4. Did you include the total amount of compute and the type of resources used (e.g., type of GPUs, internal cluster, or cloud provider)? See Sec. 5.

  4. If you are using existing assets (e.g., code, data, models) or curating/releasing new assets…

    1. If your work uses existing assets, did you cite the creators? We construct new assets based on existing assets. We have cited them in Sec. 5.

    2. Did you mention the license of the assets? See App. C.

    3. Did you include any new assets either in the supplemental material or as a URL? We release new assets through a URL in Sec. 5.

    4. Did you discuss whether and how consent was obtained from people whose data you’re using/curating? The source data for generating our data is publicly available.

    5. Did you discuss whether the data you are using/curating contains personally identifiable information or offensive content? We do not have personally identifiable information or offensive content.

  5. If you used crowdsourcing or conducted research with human subjects…

    1. Did you include the full text of instructions given to participants and screenshots, if applicable?

    2. Did you describe any potential participant risks, with links to Institutional Review Board (IRB) approvals, if applicable?

    3. Did you include the estimated hourly wage paid to participants and the total amount spent on participant compensation?

Appendix A Preliminaries of Causal Inference

a.1 Structural Causal Models

In order to rigorously formalize our causal assumption behind the dataset, we resort to the Structural Causal Models, or SCM. SCM is a way of describing the relevant features (variables) of a particular problem and how they interact with each other. In particular, an SCM describes how the system assigns values to variables of interest.

Formally, an SCM consists of a set of exogenous variables U, a set of endogenous variables V, and a set of functions F that determine the values of the variables in V based on the other variables in the model. Causally, a variable X is a direct cause of a variable Y if X appears in the function that determines the value of Y. If X is a direct cause of Y or of any cause of Y, X is a cause of Y. Exogenous variables are, roughly speaking, external to the model; hence, in most scenarios, we choose not to explain how they are caused. Every endogenous variable is a descendant of at least one exogenous variable, and exogenous variables can only be root variables. If we know the value of every exogenous variable, then with the functions in F we can perfectly determine the value of every endogenous variable. In many cases, we assume that all exogenous variables are unobserved, noise-like, and independently distributed with expected value zero, so we are only interested in the interactions among the endogenous variables. Every SCM is associated with a graphical causal model, or simply a "causal graph". The causal graph consists of nodes representing the variables in U and V, and directed edges between the nodes representing the functions in F. Note that in our SCM in Section 3.2, we only show the endogenous variables we are interested in.
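These definitions can be made concrete with a toy numeric SCM; the variable names and coefficients below are purely illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000

# Exogenous variables U = {u_x, u_y}: external, independent, mean-zero noise.
u_x = rng.normal(size=n)
u_y = rng.normal(size=n)

# Endogenous variables V = {X, Y}, each determined by a function in F.
x = u_x              # X := u_x           (X depends only on its exogenous noise)
y = 2.0 * x + u_y    # Y := f(X, u_y)     (X is a direct cause of Y)
```

As stated above, fixing the value of every exogenous variable (`u_x`, `u_y`) determines every endogenous variable (`x`, `y`) exactly through the functions in F.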

a.2 d-separation/connection

Given an SCM, we are particularly interested in the (conditional) dependence information embedded in the model. There are three basic relationships of variables in an SCM: chains, forks and colliders, as shown in Fig. 9. For chains and forks, the two end variables are dependent if the middle variable is not in the conditioning set (the path is unblocked), and vice versa. For colliders, the two end variables are independent if the collision variable is not in the conditioning set (the path is blocked). Built upon these rules, d-separation is a criterion that can be applied to causal graphs of any complexity in order to predict dependencies that are shared by all datasets generated by the graph [glymour2016causal]. Two nodes X and Y are d-separated if every path between them is blocked. If even one path between X and Y is unblocked, X and Y are d-connected. Formally, we have the following definition of d-separation:

Definition 1 (d-separation [glymour2016causal])

A path p is blocked by a set of nodes Z if and only if

1. p contains a chain of nodes A → B → C or a fork A ← B → C such that the middle node B is in Z (i.e., B is conditioned on), or

2. p contains a collider A → B ← C such that the collision node B is not in Z, and no descendant of B is in Z.
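The collider rule can be checked numerically: in a simulated collider X → Z ← Y, X and Y are marginally independent, but conditioning on Z (approximated here by restricting to a narrow slice of Z) unblocks the path. This is a generic illustration, not tied to the paper's datasets:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Collider structure: X -> Z <- Y
x = rng.normal(size=n)
y = rng.normal(size=n)
z = x + y + 0.1 * rng.normal(size=n)

# The path X -> Z <- Y is blocked when Z is NOT conditioned on,
# so the marginal correlation is near zero.
marginal_corr = np.corrcoef(x, y)[0, 1]

# Conditioning on Z (a narrow slice around 0) unblocks the path:
# z ~ 0 forces y ~ -x, inducing a strong negative correlation.
sel = np.abs(z) < 0.1
conditional_corr = np.corrcoef(x[sel], y[sel])[0, 1]
```

The same simulation style covers chains and forks by swapping which variable is generated from which, with the dependence pattern reversed as the rules above predict.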

With this principle, we can see that paths (1) and (2) in Section 3.2 are unblocked paths, which induce an unexpected correlation between the bias variable and the prediction.

Figure 9: Three basic relations in causal graph.

For more background on causal inference, please refer to [glymour2016causal, pearl2009causality].

Appendix B Algorithm

Input : training graphs with labels, max iteration, generation iteration
Output : Learned edge mask generator MLP, two GNN networks, and two classifiers
Initialization : iteration counter; Initialize the MLP, the GNNs and the classifiers
1 while not converged and the max iteration is not reached do
2       Extract the causal and bias subgraphs based on the edge mask generator;
3       Encode the causal and bias subgraphs into their representations via the two GNNs;
4       Concatenate the two representations;
5       Update the networks via the loss in Eq. (6);
6       if the generation iteration is reached:
7        Randomly swap the disentangled embeddings across samples to form counterfactual embeddings;
8        Update the networks via the loss in Eq. (8);
9       Increment the iteration counter;
10 end while
Algorithm 1 Disentangled Causal Substructure Learning
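Line 7 of Algorithm 1, the random swap that synthesizes counterfactual unbiased samples, can be sketched as below. This is a numpy sketch with illustrative names; the idea is simply to pair each causal embedding with a bias embedding drawn from a random other sample in the batch:

```python
import numpy as np

def swap_bias(z_causal, z_bias, rng):
    """Form counterfactual embeddings by pairing each causal embedding
    with a randomly permuted bias embedding from the same batch,
    breaking the spurious correlation between the two factors."""
    perm = rng.permutation(z_bias.shape[0])
    return np.concatenate([z_causal, z_bias[perm]], axis=1)
```

The swapped concatenation plays the role of the counterfactual sample fed to the loss in Eq. (8), while the unswapped concatenation (line 4) feeds the loss in Eq. (6).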

Appendix C Experimental Details

c.1 Datasets details

Dataset Causal subgraph type Bias subgraph type #Graphs(train/val/test) #Classes #Avg. Nodes #Avg. Edges Node feat (dim.) Bias degree Difficulties
CMNIST-75sp Digit subgraph Color background subgraph 10K/5K/10K 10 61.09 488.78 Pixel+Coord (5) 0.8/0.9/0.95 Easy
CFashion-75sp Fashion product subgraph Color background subgraph 10K/5K/10K 10 61.03 488.26 Pixel+Coord (5) 0.8/0.9/0.95 Medium
CKuzushiji-75sp Hiragana subgraph Color background subgraph 10K/5K/10K 10 52.87 423.0 Pixel+Coord (5) 0.8/0.9/0.95 Hard
Table 3: Statistics of Biased Graph Classification Datasets.

We summarize the statistics of the datasets constructed in this paper in Table 3. Note that the bias degree of the validation set is 0.5; we use it to adjust the learning rate during training. Without loss of generality, we subsample the original 60K training samples down to 10K training samples to make the training process more efficient. One could easily construct the full dataset with our method. Each graph of CFashion-75sp is labeled by the category of the fashion product it depicts, and each graph of CKuzushiji-75sp is labeled by one of 10 Hiragana characters. Moreover, we list the mapping between each label and its predefined correlated color for all datasets in Table 4. The links for the source image datasets are as follows:

  1. Kuzushiji-MNIST: CC BY-SA 4.0 License.

Label Color (RGB) Label Color (RGB)
0 (255, 0, 0) 5 (0, 255, 255)
1 (0, 255, 0) 6 (255, 128, 0)
2 (0, 0, 255) 7 (255, 0, 128)
3 (225, 225, 0) 8 (128, 0, 255)
4 (225, 0, 225) 9 (128, 128, 128)
Table 4: Mapping between label and color.
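Table 4's mapping can be written directly as a lookup, and the bias degree then acts as the probability that a sample receives its label-correlated color. The sketch below illustrates this construction; the function is hypothetical and the actual dataset-generation pipeline is in the released code:

```python
import numpy as np

# Label -> RGB color mapping, copied from Table 4.
LABEL_TO_COLOR = {
    0: (255, 0, 0),   5: (0, 255, 255),
    1: (0, 255, 0),   6: (255, 128, 0),
    2: (0, 0, 255),   7: (255, 0, 128),
    3: (225, 225, 0), 8: (128, 0, 255),
    4: (225, 0, 225), 9: (128, 128, 128),
}

def biased_color(label, bias_degree, rng):
    """With probability `bias_degree`, color a sample with the color
    correlated to its label; otherwise pick a random other color."""
    if rng.random() < bias_degree:
        return LABEL_TO_COLOR[label]
    others = [c for l, c in LABEL_TO_COLOR.items() if l != label]
    return others[rng.integers(len(others))]
```

With bias degree 0.9, for example, roughly 90% of the training graphs of a class share its predefined background color, producing the spurious label-color correlation studied in the paper.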

For the unbiased testing dataset with unseen bias used in Table 2, the RGB values of the predefined color set are {(199, 21, 133), (255, 140, 105), (255, 127, 36), (139, 71, 38), (107, 142, 35), (173, 255, 47), (60, 179, 113), (0, 255, 255), (64, 224, 208), (0, 191, 255)}.

c.2 Experimental setup

For GCN and GIN, we use the same model architectures as [hendrycks2019benchmarking], which have 4 layers, with a hidden dimension of 146 for GCN and 110 for GIN; GIN uses its GIN0 variant. GCNII has 4 layers and a hidden dimension of 146. DIR uses the default parameters from the original paper for the MNIST-75sp dataset. The causal GNN and the bias GNN in our model have the same architecture as the base model. We optimize all models with the Adam [kingma2014adam] optimizer and a learning rate of 0.01 for all experiments. The batch size for all methods is 256. We train all models for 200 epochs and set the generation iteration of our method to 100. For our model, we set q of the GCE loss and the importance of the generation component (the latter to 10) for all experiments. Our substructure generator is a two-layer MLP whose activation function is the sigmoid function. For StableGNN, we use its GraphSAGE variant. For the other baselines, we use their default hyperparameters. LDD uses the same hyperparameters as our model. To better reflect the performance of unbiased sample generation, we take the performance at the last step as the final result. All experiments are conducted on a single NVIDIA V100 GPU.

Appendix D Visualization of CFashion-75sp and CKuzushiji-75sp

Figure 10 and Figure 11 show the visualization results on the CFashion-75sp and CKuzushiji-75sp datasets. As we can see, our model also discovers reasonable causal subgraphs for these more challenging datasets.

Figure 10: Visualization of subgraphs extracted by the mask generator from CFashion-75sp.
Figure 11: Visualization of subgraphs extracted by the mask generator from CKuzushiji-75sp.

Appendix E Limitations and societal impacts

Our method assumes that a graph consists of a causal subgraph and a bias subgraph. In reality, there may also exist non-informative subgraphs, which are neither causal nor biased with respect to the label. We would like to consider more fine-grained splittings of graphs in the future. When deploying GNNs to real-world applications, especially in safety-critical fields, whether the results of GNNs are stable is an important factor. The demands for a stable model are universal and extensive, such as in the fields of disease prediction [sun2020disease], traffic state prediction [cui2019traffic], and financial applications [wang2021review], where utilizing human-understandable causal knowledge for prediction is necessary.