Improving the Long-Range Performance of Gated Graph Neural Networks

07/19/2020 · Denis Lukovnikov et al. · University of Bonn, Ruhr University Bochum

Many popular variants of graph neural networks (GNNs) that are capable of handling multi-relational graphs may suffer from vanishing gradients. In this work, we propose a novel GNN architecture based on the Gated Graph Neural Network with an improved ability to handle long-range dependencies in multi-relational graphs. An experimental analysis on different synthetic tasks demonstrates that the proposed architecture outperforms several popular GNN models.


1 Introduction

Graph Neural Networks (GNNs) form a class of neural network architectures specifically designed to work with graph-structured data. In our work, we focus on multi-relational graphs, where edges are labeled with different edge types. While different GNN variants have been proposed in recent literature, to the best of our knowledge, their ability to capture long-term dependencies in graph data has not been thoroughly investigated. Due to their local aggregation nature, many layers of a GNN must be used to capture long-range patterns (i.e., at least $k$ GNN layers are needed to incorporate any information from a node that is $k$ hops away in the graph). However, GNNs suffer from decreasing performance when the number of layers is increased. Zhao et al. [pairnorm] point out that this may be due to (1) over-fitting, (2) oversmoothing, and (3) vanishing gradients. Oversmoothing refers to the phenomenon that node representations become less distinguishable from each other as more and more layers are used in a GNN. Even though oversmoothing (and over-squashing [alon2020bottleneck]) might compound learning difficulties with deep GNNs in general, problems with learning long-range dependencies caused by vanishing gradients could already be an issue in very simple graphs that should suffer less from oversmoothing. Nevertheless, while several works have investigated over-fitting (as addressed, for example, by CompGCN [compgcn]) and oversmoothing [li2018deeper, chen2019measuring, pairnorm, dropedge, yang2020revisiting], possible vanishing gradient problems [hochreiter1997long, pascanu2013difficulty] in GNNs have so far received less attention.

The vanishing and exploding gradient problems have been studied extensively for recurrent neural networks (RNNs), resulting in the development of gated update functions such as those employed in the Long Short-Term Memory (LSTM [hochreiter1997long]) and the Gated Recurrent Unit (GRU [gru]), as well as other methods, such as gradient clipping [pascanu2013difficulty]. Popular relational GNN architectures (such as the RGCN [rgcn], the CompGCN [compgcn], and the Gated GNN (GGNN) [ggnn]) may suffer from vanishing gradients when trying to learn long-range patterns in multi-relational graphs. Using gated update functions, such as the GRU in the popular GGNN, can improve the learning of deep networks, as it avoids vanishing gradients in the depth (i.e., vertically, between layers). However, the way the gated functions are currently employed in the GGNN architecture means it may still suffer from vanishing gradients with respect to distant nodes (i.e., horizontally), because every backpropagation path between distant nodes consists of a series of matrix multiplications and tanh non-linearities.

In this work, we focus on improving the learning of long-range dependencies in multi-relational graphs by tackling the aforementioned vanishing gradient problem and developing a novel GGNN architecture which we show to be advantageous in experiments on different synthetic tasks.

2 Message Passing Networks

Given a graph $G = (V, E)$, a GNN maps each node $v \in V$ onto representation vectors $h_v^{(l)}$ by repeatedly aggregating the representation vectors of the immediate neighbourhood of each node and updating the node vectors in every step of the encoding process, each step associated with one of the $L$ layers of the GNN. Relational GNNs must also take into account the edge types between nodes in the graph. Following a message passing framework for GNNs (similar to [gilmer2017neural]), a single layer/step of many GNNs can be decomposed into a three-step process:

(1)  $m_{u\to v}^{(l)} = \phi_m\big(h_u^{(l-1)}, h_v^{(l-1)}, r_{uv}\big), \qquad a_v^{(l)} = \phi_a\big(\{\, m_{u\to v}^{(l)} : u \in \mathcal{N}(v) \,\}\big), \qquad h_v^{(l)} = \phi_u\big(h_v^{(l-1)}, a_v^{(l)}\big)$

where $\phi_m$ is a function that computes a “message” along a graph edge, $\phi_a$ aggregates the incoming messages into a single vector, and $\phi_u$ computes a new representation for node $v$. In the equation, $\mathcal{N}(v)$ are the immediate neighbours of node $v$ in the given graph and $r_{uv}$ denotes features associated with the edge from $u$ to $v$. After applying Eq. (1) $L$ times (optionally with different parameters for $\phi_m$, $\phi_a$ and/or $\phi_u$ in every step) on each node of the graph, the final node representations can be used for different tasks, such as graph or node classification.

A standard message function is a simple matrix multiplication, which is used by the RGCN [rgcn] and the GGNN [ggnn]:

(2)  $m_{u\to v}^{(l)} = W_{r_{uv}}\, h_u^{(l-1)}$

where $W_{r_{uv}}$ is a matrix of parameters associated with the edge type $r_{uv}$ of the edge from $u$ to $v$.

GGNNs implement the update function based on Gated Recurrent Units (GRUs) [gru]:

(3)  $z_v^{(l)} = \sigma\big(W_z\, a_v^{(l)} + U_z\, h_v^{(l-1)}\big)$
(4)  $r_v^{(l)} = \sigma\big(W_r\, a_v^{(l)} + U_r\, h_v^{(l-1)}\big)$
(5)  $\tilde{h}_v^{(l)} = \tanh\big(W\, a_v^{(l)} + U\,\big(r_v^{(l)} \odot h_v^{(l-1)}\big)\big)$
(6)  $h_v^{(l)} = \big(1 - z_v^{(l)}\big) \odot h_v^{(l-1)} + z_v^{(l)} \odot \tilde{h}_v^{(l)}$

where $a_v^{(l)}$ is the vector representing the aggregated neighbourhood of node $v$. A simple choice for computing $a_v^{(l)}$ would be $a_v^{(l)} = \sum_{u \in \mathcal{N}(v)} W_{r_{uv}} h_u^{(l-1)}$, a sum of edge type-dependent linear transformations of neighbouring node vectors.
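For concreteness, the following is a minimal PyTorch sketch of such a layer: relation-specific linear messages as in Eq. (2), sum aggregation, and a GRU update as in Eqs. (3)-(6). This is not the authors' implementation; the class name, argument names, tensor layout, and initialization are our own.

```python
import torch
import torch.nn as nn

class GGNNLayer(nn.Module):
    """One GGNN-style layer: relation-specific linear messages, sum aggregation, GRU update."""
    def __init__(self, num_relations: int, dim: int):
        super().__init__()
        # One weight matrix W_r per edge type (Eq. 2).
        self.rel_weights = nn.Parameter(torch.randn(num_relations, dim, dim) * 0.01)
        self.gru = nn.GRUCell(dim, dim)  # update function phi_u (Eqs. 3-6)

    def forward(self, h, src, dst, rel):
        # h: [num_nodes, dim] node states; src, dst, rel: [num_edges] per-edge indices
        W = self.rel_weights[rel]                                   # [num_edges, dim, dim]
        msgs = torch.bmm(W, h[src].unsqueeze(-1)).squeeze(-1)       # m_{u->v} = W_r h_u
        # Sum aggregation phi_a: a_v is the sum of incoming messages.
        a = torch.zeros_like(h).index_add_(0, dst, msgs)
        # GRU update phi_u: new state from aggregated neighbourhood and previous state.
        return self.gru(a, h)

# usage: 5 nodes in a chain with "next" (0) and "previous" (1) edges
h = torch.randn(5, 16)
src = torch.tensor([0, 1, 2, 3, 1, 2, 3, 4])
dst = torch.tensor([1, 2, 3, 4, 0, 1, 2, 3])
rel = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
layer = GGNNLayer(num_relations=2, dim=16)
h = layer(h, src, dst, rel)
```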

3 GNN modeling and long-term dependencies

We develop our method starting from the GGNN [ggnn] architecture, which has been used for different NLP tasks [ggnn, bogin2019representing, beck2018graph, marcheggiani2017encoding]. While the GRU used in GGNNs enables easy backpropagation over a large number of layers “vertically” (from the top-level state $h_v^{(L)}$ of a node to its initial state $h_v^{(0)}$), it may suffer from vanishing gradients “horizontally” (w.r.t. distant nodes) with increasing depth of the network. In fact, considering the GGNN equations (Eq. (2) and Eqs. (3)-(6)), it becomes clear that every backpropagation path to a distant node can suffer from vanishing or exploding gradients, similarly to a simple (non-gated) RNN, because of repeated matrix multiplications (both in the message function $\phi_m$ and in the GRU-based update $\phi_u$) and non-linearities.

To achieve well-behaved gradients over a large number of hops, we make extensive use of gated functions that implement additive updates to the node states, both for integrating information from previous states and from neighbours (Section 3.1) and for incorporating edge type information (Section 3.2). We also integrate attention-based aggregation into our model (Section 3.3) and use a vector-based parameterization of relations (Section 3.2) to avoid overfitting and reduce memory consumption. These changes are elaborated in the following sections. We call the resulting architecture the Symmetrically Gated Graph Neural Network with Relational Vectors and Graph ATtention (SGGNN-RV-GAT).

3.1 Symmetrically Gated Graph Neural Network

For the update function $\phi_u$, following the ideas behind the LSTM, we propose to change the GRU equations such that both inputs of the update function are gated in the same way. We call the proposed update function, which is described by the following formulas, a Symmetrically Gated Recurrent Unit (SGRU):

(7)
(8)
(9)
(10)
(11)
(12)
(13)
(14)

The SGRU equations differ from the GRU equations of the GGNN (Eqs. (3)-(6)) in (i) introducing an additional reset gate that is applied to the aggregated neighbour states and (ii) computing the output state as a three-way mixture between the previous node state, the aggregated neighbour states, and the candidate state, instead of a two-way mixture between the previous node state and the candidate state. Note that the mixing gates in our SGRU are similar to those in the Sentence State LSTM [zhang2018sentence]; however, the original formulation of the Sentence State LSTM is not applicable to graphs in general.
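Since Eqs. (7)-(14) are not reproduced above, the sketch below shows only one plausible reading of this description: symmetric reset gates for the previous state and for the aggregated neighbour states, and three mixing gates normalized to sum to one. The exact gating form, parameterization, and names are assumptions, not the paper's equations.

```python
import torch
import torch.nn as nn

class SGRUCell(nn.Module):
    """A plausible Symmetrically Gated Recurrent Unit (SGRU) sketch.

    Assumed form: reset gates for both the previous state h and the aggregated
    neighbourhood a, a candidate state computed from both reset inputs, and a
    three-way softmax mixture between h, a and the candidate.
    """
    def __init__(self, dim: int):
        super().__init__()
        self.reset_h = nn.Linear(2 * dim, dim)   # reset gate on previous state h
        self.reset_a = nn.Linear(2 * dim, dim)   # extra reset gate on neighbourhood a
        self.cand = nn.Linear(2 * dim, dim)      # candidate state
        self.mix = nn.Linear(2 * dim, 3 * dim)   # three-way mixing gates

    def forward(self, a, h):
        x = torch.cat([a, h], dim=-1)
        r_h = torch.sigmoid(self.reset_h(x))
        r_a = torch.sigmoid(self.reset_a(x))
        cand = torch.tanh(self.cand(torch.cat([r_a * a, r_h * h], dim=-1)))
        # Mixing gates normalized across the three sources so they sum to 1.
        g = self.mix(x).view(*x.shape[:-1], 3, -1).softmax(dim=-2)
        return g[..., 0, :] * h + g[..., 1, :] * a + g[..., 2, :] * cand
```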

3.2 Incorporating edge type information

Using the simple message function of Eq. (2) for adding edge type information to a node state can be problematic because it may over-parameterize the edge type update, leading to easier over-fitting and larger parameter and memory requirements. A possible solution to reduce the number of parameters is to use a limited number of basis matrices from which a relation matrix is composed, or to use other decomposition techniques that provide a more parameter-efficient edge type tensor decomposition [rgat, rgcn]. Alternatively, CompGCN [compgcn] makes use of a vector-based edge type parameterization. We share this idea and propose the following gated message function:

(15)
(16)
(17)

where $r_{uv}$ is a vector associated with the type of the edge from node $u$ to node $v$, and the three weight matrices in Eqs. (15)-(17) are trainable parameters that are shared between different edge types. The function implements a gate that mixes between the original node state and the relation-aware update.
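The sketch below illustrates one way such a gated, vector-parameterized message function could look: a single vector per edge type, shared weight matrices, and a gate mixing the sender state with a relation-aware update. It is not the exact form of Eqs. (15)-(17); all names are ours.

```python
import torch
import torch.nn as nn

class GatedRelVecMessage(nn.Module):
    """A plausible sketch of a gated, vector-parameterized message function.

    Each edge type has only a vector r; the weight matrices are shared across
    edge types. A gate mixes the sender state with a relation-aware update.
    """
    def __init__(self, num_relations: int, dim: int):
        super().__init__()
        self.rel_vecs = nn.Embedding(num_relations, dim)    # one vector per edge type
        self.upd = nn.Linear(2 * dim, dim)                  # relation-aware update (shared)
        self.gate = nn.Linear(2 * dim, dim)                 # mixing gate (shared)

    def forward(self, h_src, rel):
        r = self.rel_vecs(rel)                              # [num_edges, dim]
        x = torch.cat([h_src, r], dim=-1)
        u = torch.tanh(self.upd(x))                         # relation-aware update
        z = torch.sigmoid(self.gate(x))                     # gate
        return z * u + (1.0 - z) * h_src                    # mix update and sender state
```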

3.3 Attention-based Neighbourhood Aggregation

To implement the aggregation function $\phi_a$, we adapt the scaled multi-head multiplicative attention mechanism [vaswani2017attention] for aggregation. Attention-based neighbourhood aggregation [gat, rgat] has been shown to be a useful alternative to the (scaled) sum aggregation of the GGNN or RGCN. The per-head attention distributions are computed as described for Transformers [vaswani2017attention], with the change that the key vectors are computed using both the incoming messages and the edge vectors $r_{uv}$. Also, we do not transform the states to obtain value vectors.

The query and key vectors for each attention head $k$ (of $K$ heads in total), $q_v^{(k)}$ and $k_{u\to v}^{(k)}$, are computed using head-specific linear transformations parameterized by $Q^{(k)}$ and $K^{(k)}$:

(18)  $q_v^{(k)} = Q^{(k)}\, h_v^{(l-1)}$
(19)  $k_{u\to v}^{(k)} = K^{(k)}\, \big[\, m_{u\to v} \,;\, r_{uv} \,\big]$

where $[\,\cdot\,;\,\cdot\,]$ denotes the (vertical) concatenation of two (column) vectors and $r_{uv}$ is the same edge type-specific parameter vector that is used in Eqs. (15)-(17). Note that Eq. (19) is slightly different from a standard attention mechanism, which would compute the key vectors as $K^{(k)} m_{u\to v}$.

The unnormalized attention weight $w_{u\to v}^{(k)}$ for node $u$ and head $k$ is computed by a scaled dot product and normalized to the attention score $\alpha_{u\to v}^{(k)}$:

(20)  $w_{u\to v}^{(k)} = \dfrac{ q_v^{(k)\top}\, k_{u\to v}^{(k)} }{ \sqrt{d_k} }$
(21)  $\alpha_{u\to v}^{(k)} = \dfrac{ \exp\!\big(w_{u\to v}^{(k)}\big) }{ \sum_{u' \in \mathcal{N}(v)} \exp\!\big(w_{u'\to v}^{(k)}\big) }$

Unlike multi-head attention in Transformers and previous work [gat, rgat], we do not perform linear transformations to obtain the value vectors (in our preliminary experiments, using (head-specific) linear transformations to obtain value vectors as in previous work, combined with the SGRU-based update, resulted in a significant decrease in performance). Instead, for every head $k$, the value vector $v_{u\to v}^{(k)}$ is obtained by splitting the message vector $m_{u\to v}$ into $K$ equally sized parts and taking the $k$-th part. The attention scores and value vectors are then used to compute the neighbourhood aggregation vector $a_v^{(k)}$ for head $k$ as a weighted sum:

(22)  $a_v^{(k)} = \sum_{u \in \mathcal{N}(v)} \alpha_{u\to v}^{(k)} \, v_{u\to v}^{(k)}$

Then, the full neighbourhood aggregation vector $a_v$ is computed as a concatenation over all heads $k$:

(23)  $a_v = \big[\, a_v^{(1)} \,;\, a_v^{(2)} \,;\, \ldots \,;\, a_v^{(K)} \,\big]$
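A compact sketch of this aggregation scheme (our reading of Eqs. (18)-(23)) is given below. It assumes the values are the messages split into equal head-sized parts, fuses the head-specific query/key matrices into single linear layers, and omits the numerical max-subtraction trick in the softmax for brevity.

```python
import math
import torch
import torch.nn as nn

class AttnAggregate(nn.Module):
    """Sketch of the multi-head attention aggregation: queries from the receiving
    node state, keys from [message ; edge-type vector], values = split messages."""
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads, self.d_k = n_heads, dim // n_heads
        self.Q = nn.Linear(dim, dim, bias=False)       # head-specific queries (fused)
        self.K = nn.Linear(2 * dim, dim, bias=False)   # keys from [message ; rel vector]

    def forward(self, h, msgs, rel_vecs, dst):
        E, H, dk = msgs.size(0), self.n_heads, self.d_k
        q = self.Q(h)[dst].view(E, H, dk)                           # query of receiving node
        k = self.K(torch.cat([msgs, rel_vecs], dim=-1)).view(E, H, dk)
        v = msgs.view(E, H, dk)                                     # values = split messages
        scores = (q * k).sum(-1) / math.sqrt(dk)                    # Eq. (20), [E, H]
        # Softmax over the incoming edges of each target node (Eq. 21).
        w = scores.exp()                                            # no max-trick, for brevity
        denom = torch.zeros(h.size(0), H).index_add_(0, dst, w)
        alpha = (w / denom[dst].clamp_min(1e-12)).unsqueeze(-1)     # [E, H, 1]
        out = torch.zeros(h.size(0), H, dk).index_add_(0, dst, alpha * v)
        return out.reshape(h.size(0), -1)                           # concat heads (Eq. 23)
```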

4 Experimental analysis

In this section, we present our experiments on synthetic tasks that are aimed at assessing the ability of GNNs to model long-term dependencies in multi-relational graphs. Experimenting on synthetic tasks with generated data allows us to evaluate specific aspects of the models in setups where we can fully control the complexity of the problem and can freely generate training examples.

We experiment with a simple sequence classification task (Section 4.2) and a node classification task (Section 4.3). Each experiment is presented in a separate subsection that contains a description of the task, experimental results in comparison to other models, and an ablation study. We consider the following ablations of our full method (SGGNN-RV-GAT) in our experiments: (1) GGNN-RV-GAT replaces the SGRU with a normal GRU, (2) SGGNN-RV-mean replaces the attention-based aggregation function with a simple mean (equivalent to uniform attention), and (3) SGGNN-RM-GAT replaces the vector-based message function of Eqs. (15)-(17) with the matrix-based message function of Eq. (2).

We compare against the RGCN, the GGNN, and a version of the RGAT described in [rgat]; see Section 4.1 for more details. When using the original formulations of these models in our experiments, we quickly ran out of memory for larger hidden dimensions of the GNN and larger numbers of layers, and gradient accumulation significantly slows down training. To maintain reasonable training speed, we therefore replace the message function $W_r h_u$ (see Eq. (2)) in our RGCN and GGNN with $A\,\hat{W}_r\,B\,h_u$, where $\hat{W}_r$ is an edge type-specific square matrix of lower dimensionality than $W_r$ from Eq. (2), and $A$ and $B$ are projection matrices into and out of $\hat{W}_r$'s dimensionality that are shared across all edge types.
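The bottleneck replacement described above could look roughly as follows; the factorization order and the names ($A$, $B$, $\hat{W}_r$) follow our reading and are not taken from the authors' code.

```python
import torch
import torch.nn as nn

class BottleneckRelMessage(nn.Module):
    """Replace W_r (dim x dim per relation) with A @ W_hat_r @ B, where W_hat_r
    is a small per-relation matrix and A, B are shared projections."""
    def __init__(self, num_relations: int, dim: int, small: int):
        super().__init__()
        self.B = nn.Linear(dim, small, bias=False)                  # shared: project down
        self.rel = nn.Parameter(torch.randn(num_relations, small, small) * 0.01)
        self.A = nn.Linear(small, dim, bias=False)                  # shared: project up

    def forward(self, h_src, rel):
        z = self.B(h_src)                                           # [num_edges, small]
        z = torch.bmm(self.rel[rel], z.unsqueeze(-1)).squeeze(-1)   # per-relation mixing
        return self.A(z)                                            # back to dim
```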

4.1 Baselines

4.1.1 RGCN

For our RGCN [rgcn] baseline, we rely on the implementation provided by the DGL framework (https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/relgraphconv.py). We experimented with weight sharing between RGCN layers and found it to perform better.

4.1.2 GGNN

We write our own implementation of the GGNN, basing it on the code provided by the DGL framework (https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatedgraphconv.py). We use variational dropout on the GRU.

4.1.3 RGAT

We adapt [rgat] for our attention-based baseline (RGAT). As in [rgat], we use relation-specific transformation matrices $W_r$ and relation- and head-specific query and key matrices $Q_r^{(k)}$ and $K_r^{(k)}$.

First, we define relation-dependent representations $g_v^{(r)}$ for a node $v$, which are computed based on its current state $h_v$:

(24)  $g_v^{(r)} = W_r\, h_v$

Subsequently, for every head $k$, we define the relation-specific query, key and value projections as:

(25)  $q_v^{(r,k)} = Q_r^{(k)}\, g_v^{(r)}$
(26)  $\kappa_v^{(r,k)} = K_r^{(k)}\, g_v^{(r)}$
(27)  $\nu_v^{(r,k)} = V_r^{(k)}\, g_v^{(r)}$

We compute attention between messages $m_i$, where $i$ indexes over all edges in the input graph. First, we compute the unnormalized attention score $w_i^{(k)}$ for a message as in [vaswani2017attention]:

(28)  $w_i^{(k)} = \dfrac{ q_{t_i}^{(r_i,k)\top}\, \kappa_{s_i}^{(r_i,k)} }{ \sqrt{d_k} }$

where $m_i$ is the message sent along an edge labeled by the relation $r_i$. Please note that a node $v$ may receive multiple messages from a node $u$, and that those messages could carry the same relation.

The scores are normalized over all messages that node $v$ receives:

(29)  $\alpha_i^{(k)} = \dfrac{ \exp\!\big(w_i^{(k)}\big) }{ \sum_{j \,:\, t_j = v} \exp\!\big(w_j^{(k)}\big) }$

The normalized scores are used to compute a summary:

(30)  $z_v^{(k)} = \sum_{i \,:\, t_i = v} \alpha_i^{(k)} \, \nu_{s_i}^{(r_i,k)}$

where $r_i$ is the relation of message $i$, and $s_i$ and $t_i$ are the source node id and the target node id of message $i$, respectively.

The updated representation for node $v$ is then the concatenation over all heads, fed through the activation function $\sigma$:

(31)  $h_v' = \sigma\!\big(\big[\, z_v^{(1)} \,;\, z_v^{(2)} \,;\, \ldots \,;\, z_v^{(K)} \,\big]\big)$

We experiment with a ReLU as well as with a linear (identity) $\sigma$.

4.2 Conditional Recall

In this experiment, we aim to evaluate the ability of the different GNNs to (i) remember node labels over a large number of hops and (ii) learn simple reasoning rules for sequences.

4.2.1 Task Setup:

We define a sequence classification task where, given a sequence of characters $x_1, \ldots, x_N$, the model is asked to predict the correct class based on the computed representation of the last node (corresponding to input $x_N$). The input sequences consist of strings of letters and numbers of a given length (which is varied between different experiments). The class of the sequence is determined by the following rules: (i) if there is a digit in the sequence, the first digit is the class label; (ii) otherwise, if there is an upper-case character, the first upper-case character is the class label; (iii) otherwise, the class is given by the first character in the sequence. Some examples are: “abcdefg” → “a”, “abcDefg” → “D”, “abcd3Fg” → “3”, “abCd3fg” → “3”. We generated 20 examples per output class for a total of 1220 examples and split the data into 80/10/10 train/validation/test splits.
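The labeling rule can be stated in a few lines of Python; this is a sketch of the rule as we read it, not the authors' data generator.

```python
def conditional_recall_label(seq: str) -> str:
    """Class label: first digit if any, else first upper-case letter if any,
    else the first character."""
    for ch in seq:
        if ch.isdigit():
            return ch
    for ch in seq:
        if ch.isupper():
            return ch
    return seq[0]

assert conditional_recall_label("abcdefg") == "a"
assert conditional_recall_label("abcDefg") == "D"
assert conditional_recall_label("abcd3Fg") == "3"
assert conditional_recall_label("abCd3fg") == "3"
```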

We transform each input sequence into a graph by (i) creating a node for every character of the sequence and (ii) adding next and previous edges between every pair of adjacent elements in the sequence. We also add self-edges from each node to itself. Given these edges, the GNN has to use at least $N-1$ layers/steps in order to propagate information from the first node to the last node of a length-$N$ sequence. Readout for prediction is done by taking the representation of the last node.
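A minimal sketch of this sequence-to-graph conversion (the edge-type numbering is our own):

```python
def sequence_to_graph(seq: str):
    """Build the edge list for the Conditional Recall graph: one node per
    character, next/previous edges between adjacent positions, and self-edges.
    Edge types: 0 = next, 1 = previous, 2 = self (our numbering)."""
    src, dst, rel = [], [], []
    for i in range(len(seq) - 1):
        src += [i, i + 1]
        dst += [i + 1, i]
        rel += [0, 1]            # next, previous
    for i in range(len(seq)):
        src.append(i); dst.append(i); rel.append(2)   # self-edge
    return src, dst, rel
```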

Figure 1: Example of input graph for the Conditional Recall task. The top-level state of the blue node is used for prediction. The red node specifies the desired output.

An example of the conditional recall task and its graph representation are given in Figure  1. In this example, the input sequence is “a8cDe”. Each of the characters gets its own node and the nodes are connected using next and previous edges. According to the rules, the answer for this input is “8”. The node corresponding to the output class is highlighted in red in the figure. The readout happens at the node corresponding to the last element in the input sequence, in this case the node with label “e”, which is marked in blue in the figure.

While a simple GRU or LSTM would solve the task well, and different graph representations and readout methods would perform better (for example, adding an artificial “master” node that is connected to every node in the original graph is a common trick used in graph tasks to bypass the long-range reasoning limitations of baseline GNNs, and readout by pooling over all nodes is commonly used as well), we deliberately choose the above formulation to measure how well the different GNNs can learn over a larger number of hops in the simplest possible graph with a large diameter.

4.2.2 Experimental details:

For our method, SGGNN-RV-GAT, we vary the dimensionality of the node states depending on the length of the input sequence (100, 120, or 200). For our baselines, we run a hyperparameter search, varying the node state dimensionality between 100 and 200 and the dropout rate between 0 and 0.5. We generally keep the batch size at 20, except for higher numbers of examples per class, where we set it to 50. Throughout most of our experiments, unless otherwise indicated, we use the Adam [adam] optimizer with a learning rate of 0.001. We also use label smoothing with a factor of 0.1.

The initial node states are initialized by embedding the node type using a low-dimensional embedding matrix (dimensionality 20; for a vocabulary of 62 characters) and projecting the low-dimensional embeddings to the node state dimension.
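A sketch of this initialization with the stated sizes (62-character vocabulary, 20-dimensional embeddings); the node state dimension of 100 is an assumed example value.

```python
import torch.nn as nn

class NodeInit(nn.Module):
    def __init__(self, vocab_size: int = 62, emb_dim: int = 20, state_dim: int = 100):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)    # low-dimensional character embedding
        self.proj = nn.Linear(emb_dim, state_dim)       # project to node state dimension

    def forward(self, node_types):
        return self.proj(self.emb(node_types))          # initial node states h^(0)
```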

4.2.3 Results and Discussion:

Sequence length    3       5       7       10
RGCN               96.7    78.7    38.5    6.6
GGNN               96.7    74.6    17.2    2.5
RGAT               96.7    90.2    0.0     0.0
SGGNN-RV-GAT       98.4    96.7    91.8    96.7
GGNN-RV-GAT        100.0   18.9    13.1    4.1
SGGNN-RV-mean      95.9    94.3    33.6    11.5
SGGNN-RM-GAT       98.4    98.4    84.4    8.2
Table 1: Conditional recall accuracy (%) for shorter sequences.

We show the performance of the different GNNs for different sequence lengths in Table 1. The proposed model retains a high accuracy for longer sequences, demonstrating the ability to (i) learn simple rules over a large number of hops in the graph and (ii) be trained with a large number of layers (e.g., we employed a GNN with depth 31 for sequences of length 30).

The results for the RGCN, GGNN, and RGAT in the first half of Table 1 reveal that these previously proposed methods are able to solve the task to a satisfactory degree only up to length 5. For task length 7, their performance degrades drastically, and it falls under 10% for task length 10. Note that even though the GGNN uses a GRU, as explained before, the GRU is applied “vertically” and only prevents vanishing gradients in the depth of the network (towards the initial states), not in the width (towards the neighbours).

The results for the different ablations of our model in the bottom half of Table 1 show that all three components are necessary to achieve the best performance, with the change from the SGRU to the GRU resulting in the largest performance decrease.

4.3 Tree Max

In this experiment, we aim to evaluate the ability of a GNN to retrieve node labels over a large number of hops for many nodes simultaneously.

4.3.1 Task Setup:

We define a node classification task on tree-shaped graphs as follows. The input graphs are trees whose nodes are labeled with random integers between 1 and 100. The target label of a node is the largest value among the node itself and all of its descendants. The graphs contain edges from a parent to its children and from children to their parent, as well as self-edges. We use numbered child edges and numbered child-of edges, for example :CHILD-1-OF for the edge going from the first child of a node to its parent.

As an example of this task, consider the following input tree:

(1 (2 (3 ) (4 )) (5 (6 ) (7 (8 ) (9 ) (10 ))))

This tree has a depth of four. The nodes in the tree should be labeled with the value of their highest valued descendant. Thus, the tree with its output labels will become:

(10 (4 (3 ) (4 )) (10 (6 ) (10 (8 ) (9 ) (10 ))))

In this case, to correctly predict the output label of the root node (1), the GNN must handle 3 hops. If the tree had been labeled differently and the maximum value were not in the deepest leaves, the GNN would still have to handle 3 hops to verify that the deepest leaves are not larger. However, a failure to reach the deepest leaves from the root would not be noticeable in that case, since the prediction could be made with fewer than 3 hops. See Figure 2 for an illustration of the graph representation of this example.
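The target labels can be computed with a simple post-order traversal; the nested-tuple tree representation below is our own, chosen to mirror the bracketed notation above.

```python
def tree_max_labels(tree):
    """tree = (value, [children]); returns (labeled_tree, subtree_max) where each
    node's label is the max over the node itself and all of its descendants."""
    value, children = tree
    labeled_children, child_maxes = [], []
    for child in children:
        lc, m = tree_max_labels(child)
        labeled_children.append(lc)
        child_maxes.append(m)
    subtree_max = max([value] + child_maxes)
    return (subtree_max, labeled_children), subtree_max

# the example tree from the text:
tree = (1, [(2, [(3, []), (4, [])]),
            (5, [(6, []), (7, [(8, []), (9, []), (10, [])])])])
labeled, root_max = tree_max_labels(tree)
assert root_max == 10
```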

Figure 2: Example of input graph for the Tree Max task. A double-ended arrow between two nodes represents two edges, one in each direction. Not every arrow is labeled, and self-edges are omitted, for clearer presentation. The labels on the arrows indicate edge labels: one label corresponds to the edge going from the parent to the child and the other to the edge going from the child to the parent. The input labels are the black numbers in the node circles and the output labels are the blue numbers in the node circles.

During data generation, we first randomly pick a tree depth between 5 and 15. Then, we generate a tree, choosing between 0, 2 or 3 children for each node until we reach the chosen tree depth. The largest trees in our generated data contained over 200 nodes. The generated dataset contained 800 examples; we used a 50/25/25 train/validation/test split.
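A sketch of a generator matching this description (random target depth between 5 and 15, child counts drawn from {0, 2, 3}); the exact sampling scheme, e.g. how the target depth is enforced, may differ from the authors'.

```python
import random

def generate_tree(max_depth: int, depth: int = 1):
    """Node = (random label 1..100, children). Child counts drawn from {0, 2, 3};
    expansion stops once the target depth is reached."""
    label = random.randint(1, 100)
    if depth >= max_depth:
        return (label, [])
    n_children = random.choice([0, 2, 3])
    return (label, [generate_tree(max_depth, depth + 1) for _ in range(n_children)])

tree = generate_tree(max_depth=random.randint(5, 15))
```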

For every model tested, we perform a hyperparameter search using a fixed random seed. For all models except our RGAT baseline, we used 17 layers. We then run the best hyperparameter setting with 5 different random seeds and report the test results in Table 2. Note that the same 5 seeds are re-used for all models, and that every seed results in a different dataset being generated. We used early stopping and reloaded the best model. For more details, please see Section 4.3.2.

We evaluated the trained models based on node-wise and graph-wise accuracy (the graph-level accuracy for an example is 100% only if all nodes in the graph have been classified correctly, and 0% otherwise).

4.3.2 Experimental details:

We generate a dataset of 800 examples and split it into training, validation and test sets using 50%, 25% and 25% of the data.

We use a 17-layer network in all experiments unless otherwise indicated. In our experiments, we use early stopping based on node-level classification accuracy on the validation set, with the patience set to 10 and the minimum number of training epochs set to 20. We reload the best model weights after training finishes and evaluate on the test set.

In our experiments, we randomly explore different combinations of hyperparameter settings for the dropout rate, the dimensionality of the node vectors, and the learning rate. Since our SGGNN-RM-GAT ablation model showed unstable training behaviour, we used a learning rate of 0.00025 for it.

For the RGCN, we explore the following numbers of layers/steps: [4, 7, 12, 17]. For our RGAT baseline, we use [10, 14, 17]. Note that when the number of layers/steps is lower than 14, we cannot achieve 100% accuracy on the task.

We take the best hyperparameter setting for every model and run each with five different manually pre-defined random seeds. The same seeds are re-used across different models.

4.3.3 Results and Discussion:

node entire graph
RGCN
GGNN
RGAT
SGGNN-RV-GAT
GGNN-RV-GAT
SGGNN-RV-mean
SGGNN-RM-GAT
Table 2: Node and graph-level accuracies over the test set of the Tree Max task.

The results reported in Table 2 show that the RGCN and GGNN reach a high accuracy for node classification. For the RGAT, we obtained the best results with 5 layers, which is insufficient to capture the longest necessary dependencies. We also quickly ran into memory problems with the RGAT, which required using gradient accumulation and made training much slower than for the other models.

The best node-level and graph-level accuracies were obtained using the proposed SGGNN-RV-GAT model. The ablation study indicates all components are essential for achieving the best performance.

Note that for correct node classification, most nodes do not require long-range information propagation: only a small percentage of all nodes across all examples in the test portion of our dataset require handling 10 hops or more, and only a somewhat larger percentage require handling 5 hops or more. So even though the graph-level accuracy may exaggerate the node-level errors, we believe both should be considered, in particular with regard to long-range performance.

5 Discussion and Conclusion

In this work, we proposed a novel GNN architecture that incorporates three changes into the GGNN to improve the learning of long-term dependencies in multi-relational graphs. We performed experiments on two synthetic tasks and showed that the proposed architecture outperforms several popular GNN variants on these tasks. The proposed model beats the baselines by a significant margin and is also more parameter- and memory-efficient, but it can be more computationally expensive because of the use of attention and gated updates. The ablation study shows that all three changes of the proposed model w.r.t. the GGNN are essential.

We plan to run further experiments on real datasets to ensure the transferability of our observations to real-world settings. Major improvements from using our proposed model instead of the baselines only start to appear with a larger number of layers. Even though this might not be significantly advantageous for all real tasks, we expect our approach to improve upon existing architectures for tasks that require learning long-term dependencies, especially in more sparsely connected graphs. This may include tasks that use parse trees of natural language sentences, code syntax trees, or knowledge graphs.

References