GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs

03/20/2018 ∙ by Jiani Zhang, et al. ∙ Amazon Microsoft

We propose a new network architecture, Gated Attention Networks (GaAN), for learning on graphs. Unlike the traditional multi-head attention mechanism, which equally consumes all attention heads, GaAN uses a convolutional sub-network to control each attention head's importance. We demonstrate the effectiveness of GaAN on the inductive node classification problem. Moreover, with GaAN as a building block, we construct the Graph Gated Recurrent Unit (GGRU) to address the traffic speed forecasting problem. Extensive experiments on three real-world datasets show that our GaAN framework achieves state-of-the-art results on both tasks.




1 Introduction

Many crucial machine learning tasks involve graph-structured datasets, such as classifying posts in a social network (Hamilton et al., 2017a), predicting interfaces between proteins (Fout et al., 2017) and forecasting the future traffic speed in a road network (Li et al., 2018). The main difficulty in solving these tasks is how to find the right way to express and exploit the graph’s underlying structural information. Traditionally, this is achieved by calculating various graph statistics like degree and centrality, using graph kernels, or extracting human-engineered features (Hamilton et al., 2017b).

Recent research, however, has pivoted to solving these problems by graph convolution (Duvenaud et al., 2015, Atwood and Towsley, 2016, Kipf and Welling, 2017, Fout et al., 2017, Hamilton et al., 2017a, Veličković et al., 2018, Li et al., 2018), which generalizes the standard definition of convolution over a regular grid topology (Gehring et al., 2017, Krizhevsky et al., 2012) to ‘convolution’ over graph structures. The basic idea behind ‘graph convolution’ is to develop a localized parameter-sharing operator on a set of neighboring nodes to aggregate a local set of lower-level features. We refer to such an operator as a graph aggregator (Hamilton et al., 2017a) and the set of local nodes as the receptive field of the aggregator. Then, by stacking multiple graph aggregators, we build a deep neural network (LeCun et al., 2015) model which can be trained end-to-end to extract the local and global features across the graph. Note that we use the spatial definition instead of the spectral definition (Hammond et al., 2011, Bruna et al., 2014) of graph convolution because the full spectral treatment requires eigendecomposition of the Laplacian matrix, which is computationally intractable on large graphs, while the localized versions (Defferrard et al., 2016, Kipf and Welling, 2017) can be interpreted as graph aggregators (Hamilton et al., 2017a).

Graph aggregators are the basic building blocks of graph convolutional neural networks. A model’s ability to capture the structural information of graphs is largely determined by the design of its aggregators. Most existing graph aggregators are based on either pooling over neighborhoods (Kipf and Welling, 2017, Hamilton et al., 2017a) or computing a weighted sum of the neighboring features (Monti et al., 2017). In essence, any function that is permutation invariant and can handle a variable number of inputs is an eligible graph aggregator. One class of such functions is the neural attention network (Bahdanau et al., 2015), which uses a subnetwork to compute the correlation weight of the elements in a set. Among the family of attention models, the multi-head attention model has been shown to be effective for machine translation tasks (Lin et al., 2017, Vaswani et al., 2017). It was later adopted as a graph aggregator to solve the node classification problem (Veličković et al., 2018). A single attention head sums the elements that are similar to the query vector in one representation subspace. Using multiple attention heads allows exploring features in different representation subspaces, which can provide more modeling power. However, treating each attention head equally loses the opportunity to benefit from some attention heads which are inherently more important than others.

To this end, we propose the Gated Attention Networks (GaAN) for learning on graphs. GaAN uses a small convolutional subnetwork to compute a soft gate at each attention head to control its importance. Unlike the traditional multi-head attention that admits all attended contents, the gated attention can modulate the amount of attended content via the introduced gates. Moreover, since only a simple and lightweight subnetwork is introduced in constructing the gates, the computational overhead is negligible and the model is easy to train. We demonstrate the effectiveness of our new aggregator by applying it to the inductive node classification problem. We also improve the sampling strategy introduced in (Hamilton et al., 2017a) to reduce the memory cost and increase the run-time efficiency, so that our model and other graph aggregators can be trained on relatively large graphs. Furthermore, since our proposed aggregator is very general, we extend it to construct a Graph Gated Recurrent Unit (GGRU), which is directly applicable to spatiotemporal forecasting problems. Extensive experiments on two node classification datasets, PPI and Reddit (Hamilton et al., 2017a), and one traffic speed forecasting dataset, METR-LA (Li et al., 2018), show that GaAN consistently outperforms the baseline models and achieves state-of-the-art performance.

In summary, our main contributions include: (a) a new multi-head attention-based aggregator with additional gates on the attention heads; (b) a unified framework for transforming graph aggregators to graph recurrent neural networks; and (c) the state-of-the-art prediction performance on three real-world datasets.

2 Notations

We denote vectors with bold lowercase letters, matrices with bold uppercase letters and sets with calligraphy letters. We denote a single fully-connected layer with a non-linear activation $\alpha(\cdot)$ as $\mathrm{FC}^{\alpha}_{\theta}(\mathbf{x}) = \alpha(\mathbf{W}\mathbf{x} + \mathbf{b})$, where $\theta = \{\mathbf{W}, \mathbf{b}\}$ are the parameters. Also, $\theta$ with different subscripts denotes different transformation parameters. For activation functions, we denote $h(\cdot)$ to be the LeakyReLU activation (Xu et al., 2015a) with a negative slope of 0.1 and $\sigma(\cdot)$ to be the sigmoid activation. $\mathrm{FC}_{\theta}(\mathbf{x})$ means applying no activation function after the linear transform. We denote $\oplus$ as the concatenation operation and $\Vert_{k=1}^{K} \mathbf{x}_k$ as sequentially concatenating $\mathbf{x}_1$ through $\mathbf{x}_K$. We denote the Hadamard product as ‘$\circ$’ and the dot product between two vectors as $\langle \cdot, \cdot \rangle$.

3 Related Work

In this section, we will review relevant research on learning on graphs. Our model is also related to many graph aggregators proposed by previous work. We will discuss these aggregators in Section 4.3.

Neural attention mechanism

The neural attention mechanism is widely adopted in the deep learning literature and many variants have been proposed (Chorowski et al., 2014, Xu et al., 2015b, Seo et al., 2017, Vaswani et al., 2017). Among them, our model takes inspiration from the multi-head attention architecture proposed in (Vaswani et al., 2017). Given a query vector $\mathbf{q}$ and a set of key-value pairs $\{(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_n, \mathbf{v}_n)\}$, a single attention head computes a weighted combination of the value vectors $\sum_{i=1}^{n} w_i \mathbf{v}_i$. The weights are generated by applying softmax to the inner product between the query and keys, i.e., $\mathbf{w} = \mathrm{softmax}(\{\mathbf{q}^{T}\mathbf{k}_1, \ldots, \mathbf{q}^{T}\mathbf{k}_n\})$. In the multi-head case, the outputs of different heads are concatenated to form an output vector with fixed dimensionality. The difference between the proposed model, GaAN, and the multi-head attention mechanism is that we compute additional gates to control the importance of each head’s output.
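The single-head and multi-head computations described above can be sketched in a few lines of NumPy. This is only an illustration under simple assumptions: the function names `attention_head` and `multi_head` are hypothetical, the inputs are assumed to be already projected into query, key and value spaces, and no scaling factor is applied to the inner products.

```python
import numpy as np

def softmax(scores):
    """Numerically stable softmax over a 1-D array."""
    exp = np.exp(scores - np.max(scores))
    return exp / exp.sum()

def attention_head(query, keys, values):
    """One head: softmax over query-key inner products, then a weighted sum of values.

    query: (d_a,), keys: (n, d_a), values: (n, d_v).
    """
    weights = softmax(keys @ query)   # shape (n,)
    return weights @ values           # shape (d_v,)

def multi_head(queries, keys_list, values_list):
    """Concatenate the outputs of several independent heads into one vector."""
    heads = [attention_head(q, k, v)
             for q, k, v in zip(queries, keys_list, values_list)]
    return np.concatenate(heads)
```

Because the output of each head has a fixed size regardless of the number of key-value pairs, the concatenated vector also has a fixed dimensionality, which is what makes the mechanism usable on neighborhoods of varying size.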

Graph convolutional networks on large graphs

Applying graph convolution on large graphs is challenging because the memory complexity is proportional to the total number of nodes, which can reach hundreds of thousands in large graphs (Hamilton et al., 2017a). To reduce memory usage and computational cost, Hamilton et al. (2017a) proposed the GraphSAGE framework, which uses a sampling algorithm to select a small subset of the nodes and edges. On each iteration, GraphSAGE first uniformly samples a mini-batch of nodes. Then, for each node, only a fixed number of neighbors are selected for aggregation. More recently, Chen et al. (2018) proposed a new sampling method that randomly samples two sets of nodes according to a proposed distribution. However, this method is only applicable to one aggregator, i.e., the Graph Convolutional Network (GCN) (Kipf and Welling, 2017).

Graph convolution networks for spatiotemporal forecasting

Recently, researchers have applied graph convolution, which is commonly used for learning on static graphs, to spatiotemporal forecasting. Seo et al. (2016) proposed the Graph Convolutional Recurrent Neural Network (GCRNN), which replaced the fully-connected layers in LSTM (Hochreiter and Schmidhuber, 1997) with the ChebNet operator (Defferrard et al., 2016), and applied it to a synthetic video prediction task. Li et al. (2018) proposed the Diffusion Convolutional Recurrent Neural Network (DCRNN) to address the traffic forecasting problem, where the goal is to predict future traffic speeds in a sensor network given historic traffic speeds and the underlying road graph. DCRNN replaces the fully-connected layers in GRU (Chung et al., 2014) with the diffusion convolution operator (Atwood and Towsley, 2016). Furthermore, DCRNN takes the direction of graph edges into account. The difference between our GGRU and both GCRNN and DCRNN is that we propose a unified method for constructing a recurrent neural network based on an arbitrary graph aggregator rather than proposing a single model.

4 Gated Attention Networks

In this section, we first give a generic formulation of graph aggregators followed by the multi-head attention mechanism. Then, we introduce the proposed gated attention aggregator. Finally, we review the other kinds of graph aggregators proposed by previous work and explain their relationships with ours.

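Generic formulation of graph aggregators

Given a node $i$ and its neighboring nodes $\mathcal{N}_i$, a graph aggregator is a function $\gamma$ in the form of $\mathbf{y}_i = \gamma_{\Theta}(\mathbf{x}_i, \{\mathbf{z}_{\mathcal{N}_i}\})$, where $\mathbf{x}_i$ and $\mathbf{y}_i$ are the input and output vectors of the center node $i$. $\mathbf{z}_{\mathcal{N}_i} = \{\mathbf{z}_j \mid j \in \mathcal{N}_i\}$ is the set of the reference vectors in the neighboring nodes and $\Theta$ denotes the learnable parameters of the aggregator. In this paper, we do not consider aggregators that use edge features. However, it is straightforward to incorporate edges in our definition by defining $\mathbf{z}_j$ to contain the edge feature vectors $\mathbf{e}_{i,j}$.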

4.1 Multi-Head Attention Aggregator

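We linearly project the center node feature $\mathbf{x}_i$ to get the query vector and project the neighboring node features to get the key and value vectors. We then apply the multi-head attention mechanism (Vaswani et al., 2017) to get the final aggregation function. The detailed formulation of the multi-head attention aggregator is as follows:

$$
\begin{aligned}
\mathbf{y}_i &= \mathrm{FC}_{\theta_o}\Big(\mathbf{x}_i \oplus \big\Vert_{k=1}^{K} \sum_{j \in \mathcal{N}_i} w_{i,j}^{(k)} \, \mathrm{FC}^{h}_{\theta_v^{(k)}}(\mathbf{z}_j)\Big), \\
w_{i,j}^{(k)} &= \frac{\exp\big(\phi_w^{(k)}(\mathbf{x}_i, \mathbf{z}_j)\big)}{\sum_{l \in \mathcal{N}_i} \exp\big(\phi_w^{(k)}(\mathbf{x}_i, \mathbf{z}_l)\big)}, \\
\phi_w^{(k)}(\mathbf{x}, \mathbf{z}) &= \big\langle \mathrm{FC}_{\theta_{xa}^{(k)}}(\mathbf{x}), \mathrm{FC}_{\theta_{za}^{(k)}}(\mathbf{z}) \big\rangle.
\end{aligned} \tag{1}
$$

Here, $K$ is the number of attention heads. $w_{i,j}^{(k)}$ is the $k$th attentional weight between the center node $i$ and the neighboring node $j$, which is generated by applying a softmax to the dot product values. $\theta_{xa}^{(k)}$, $\theta_{za}^{(k)}$ and $\theta_v^{(k)}$ are the parameters of the $k$th head for computing the query, key and value vectors, which have dimensions of $d_a$, $d_a$ and $d_v$ respectively. The $K$ attention outputs are concatenated with the input vector and passed to an output fully-connected layer parameterized by $\theta_o$ to get the final output $\mathbf{y}_i$, which has dimension $d_o$. The difference between our aggregator and that in GAT (Veličković et al., 2018) is that we have adopted the key-value attention mechanism and the dot product attention, while GAT does not compute additional value vectors and uses a fully-connected layer to compute $\phi_w^{(k)}$.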

Figure 1: Illustration of a three-head gated attention aggregator with two center nodes in a mini-batch. $|\mathcal{N}_1| = 3$ and $|\mathcal{N}_2| = 2$, respectively. Different colors indicate different attention heads. Gates in darker color stand for larger values. (Best viewed in color)
(a) Attention Aggregator
(b) Gated Attention Aggregator
(c) Pooling Aggregator
(d) Pairwise Sum Aggregator
Figure 2: Comparison of different graph aggregators. The aggregators are drawn for only one aggregation step. The nodes in red are center nodes and the nodes in blue are neighboring nodes. The bold black lines between the center node and neighbor nodes indicate that a learned pairwise relationship is used for calculating the relative importance. The oval in dash line around the neighbors means the interaction among neighbors is utilized when determining the weights. (Best viewed in color)

4.2 Gated Attention Aggregator

While the multi-head attention aggregator has the ability to explore multiple representation subspaces between the center node and its neighborhoods, not all of these subspaces are equally important; some subspaces may not even exist for certain nodes. Feeding the output of an attention head that captures a useless representation can mislead the model’s final prediction.

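Therefore, we compute an additional soft gate between 0 (low importance) and 1 (high importance) to assign different importance to each head. In combination with the multi-head attention aggregator, we get the formulation of the gated attention aggregator:

$$
\begin{aligned}
\mathbf{y}_i &= \mathrm{FC}_{\theta_o}\Big(\mathbf{x}_i \oplus \big\Vert_{k=1}^{K} \Big(g_i^{(k)} \sum_{j \in \mathcal{N}_i} w_{i,j}^{(k)} \, \mathrm{FC}^{h}_{\theta_v^{(k)}}(\mathbf{z}_j)\Big)\Big), \\
\mathbf{g}_i &= \big[g_i^{(1)}, \ldots, g_i^{(K)}\big] = \psi_g(\mathbf{x}_i, \mathbf{z}_{\mathcal{N}_i}),
\end{aligned} \tag{2}
$$

where $g_i^{(k)}$ is a scalar, the gate value of the $k$th head at node $i$. To make sure adding gates will not introduce too many additional parameters, we use a convolutional network $\psi_g$ that takes the center node and neighboring node features to generate the gate values. All the other parameters have the same meanings as in Eqn. (1).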

There are multiple possible designs of the $\psi_g$ network. In this paper, we combine average pooling and max pooling to construct the network. The detailed formula is given below:

$$
\mathbf{g}_i = \mathrm{FC}^{\sigma}_{\theta_g}\Big(\mathbf{x}_i \oplus \max_{j \in \mathcal{N}_i}\big(\{\mathrm{FC}_{\theta_m}(\mathbf{z}_j)\}\big) \oplus \frac{\sum_{j \in \mathcal{N}_i} \mathbf{z}_j}{|\mathcal{N}_i|}\Big). \tag{3}
$$

Here, $\theta_m$ maps the neighbor features to a $d_m$-dimensional vector before taking the element-wise max and $\theta_g$ maps the concatenated features to the final $K$ gates. By setting a small $d_m$, the subnetwork for computing the gate will have negligible computational overhead. A visual illustration of the GaAN aggregator’s structure can be found in Figure 1. Also, we compare the general structures of the multi-head attention aggregator and the gated attention aggregator in Figure 2(a) and Figure 2(b).
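As a rough illustration of how the gate network and the attention heads fit together for one center node, the following NumPy sketch may help. It is not the authors' implementation: biases are omitted, the parameter dictionary and the names `init_params` and `gaan_aggregate` are hypothetical, and weights are random. Setting `use_gates=False` recovers a plain multi-head attention aggregator.

```python
import numpy as np

def leaky_relu(a, slope=0.1):
    return np.where(a > 0, a, slope * a)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

def init_params(rng, d_in, d_a, d_v, d_m, d_o, K):
    """Random weights; biases are omitted to keep the sketch short."""
    w = lambda rows, cols: rng.normal(scale=0.1, size=(rows, cols))
    return {
        "query": [w(d_a, d_in) for _ in range(K)],
        "key": [w(d_a, d_in) for _ in range(K)],
        "value": [w(d_v, d_in) for _ in range(K)],
        "gate_max": w(d_m, d_in),             # projection before the element-wise max
        "gate_out": w(K, d_in + d_m + d_in),  # maps concatenated features to K gates
        "out": w(d_o, d_in + K * d_v),        # output layer
    }

def gaan_aggregate(x, Z, p, use_gates=True):
    """x: (d_in,) center feature; Z: (n, d_in) neighbor features."""
    K = len(p["query"])
    # Gate input: [x, element-wise max of projected neighbors, neighbor mean].
    gate_in = np.concatenate([x, (Z @ p["gate_max"].T).max(axis=0), Z.mean(axis=0)])
    gates = sigmoid(p["gate_out"] @ gate_in) if use_gates else np.ones(K)
    heads = []
    for k in range(K):
        scores = (Z @ p["key"][k].T) @ (p["query"][k] @ x)  # dot-product attention
        values = leaky_relu(Z @ p["value"][k].T)
        heads.append(gates[k] * (softmax(scores) @ values))
    return p["out"] @ np.concatenate([x] + heads), gates
```

The gate network only adds two small matrices on top of the attention parameters, which is consistent with the claim that a small projection size keeps the overhead low. Every operation over neighbors (max, mean, softmax-weighted sum) is permutation invariant, so reordering the rows of `Z` leaves the output unchanged.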

4.3 Other Graph Aggregators

Most previous graph aggregators except attention-based aggregators can be summarized into two general categories: graph pooling aggregators and graph pairwise sum aggregators. In this section, we first describe these two types of aggregators and then explain their relationship with the attention-based aggregator. Finally, we give a list of the baseline aggregators other than the multi-head attention aggregator used in the paper.

Graph pooling aggregators

The main characteristic of graph pooling aggregators is that they do not consider the correlation between neighboring nodes and the center node. Instead, neighboring nodes’ features are directly aggregated and the center node’s feature is simply concatenated or added to the aggregated vector and then passed through an output function $\phi_o$:

$$
\mathbf{y}_i = \phi_o\big(\mathbf{x}_i \oplus \mathrm{pool}_{j \in \mathcal{N}_i}(\phi_v(\mathbf{z}_j))\big). \tag{4}
$$

Here, the projection function $\phi_v$ and the output function $\phi_o$ can be a single fully-connected layer and the pool() operator can be average pooling, max pooling or sum pooling.

The majority of existing graph aggregators are special cases of the graph pooling aggregators. Some models only integrate the node features of neighborhoods (Duvenaud et al., 2015, Kipf and Welling, 2017, Hamilton et al., 2017a), while others integrate edge features as well (Atwood and Towsley, 2016, Fout et al., 2017, Schütt et al., 2017). In Figure 2(c), we illustrate the architecture of the graph pooling aggregators.
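A graph pooling aggregator of this kind can be sketched as follows. The sketch assumes single linear layers without biases for the projection and output functions, a LeakyReLU with slope 0.1 after the projection, and hypothetical names (`pooling_aggregate`, `W_v`, `W_o`).

```python
import numpy as np

# Supported pooling operators, applied across the neighbor axis.
POOLS = {"avg": np.mean, "max": np.max, "sum": np.sum}

def pooling_aggregate(x, Z, W_v, W_o, pool="avg"):
    """x: (d_in,) center feature; Z: (n, d_in) neighbor features.

    Neighbors are projected, pooled without looking at x, and the pooled
    vector is concatenated with x before the output layer.
    """
    a = Z @ W_v.T
    projected = np.where(a > 0, a, 0.1 * a)   # LeakyReLU, slope 0.1
    pooled = POOLS[pool](projected, axis=0)
    return W_o @ np.concatenate([x, pooled])
```

Note that the center feature `x` never influences how the neighbors are combined; it only enters through the final concatenation.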

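Graph pairwise sum aggregators

Like attention-based aggregators, graph pairwise sum aggregators also aggregate the neighborhood features by taking weighted sums. The difference is that the weight between node $i$ and its neighbor $j$ is not related to the other neighbors in $\mathcal{N}_i$. The formula of the graph pairwise sum aggregator is given as follows:

$$
\begin{aligned}
\mathbf{y}_i &= \phi_o\Big(\mathbf{x}_i \oplus \big\Vert_{k=1}^{K} \sum_{j \in \mathcal{N}_i} w_{i,j}^{(k)} \, \phi_v^{(k)}(\mathbf{z}_j)\Big), \\
w_{i,j}^{(k)} &= \phi_w^{(k)}(\mathbf{x}_i, \mathbf{z}_j).
\end{aligned} \tag{5}
$$

Here, $w_{i,j}^{(k)}$ is only related to the pair $\mathbf{x}_i$ and $\mathbf{z}_j$, while in attention-based models $w_{i,j}^{(k)}$ is related to the features of all neighbors $\mathbf{z}_{\mathcal{N}_i}$. Models like the adaptive forget gate strategy in Graph LSTM (Liang et al., 2016) and MoNet (Monti et al., 2017) employed pairwise sum aggregators with a single head or multiple heads. In Figure 2(d), we illustrate the architecture of the graph pairwise sum aggregators.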
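The distinction between pairwise weights and attention weights can be checked numerically. In the sketch below, both use the same dot-product score between a projected center feature and projected neighbor features; the projection matrices `W_q` and `W_k` and the function names are hypothetical, and a sigmoid is assumed for the pairwise case.

```python
import numpy as np

def pair_scores(x, Z, W_q, W_k):
    """Dot-product score between the center node and each neighbor."""
    return (Z @ W_k.T) @ (W_q @ x)

def pairwise_weights(x, Z, W_q, W_k):
    """Each weight depends only on the (center, neighbor) pair."""
    return 1.0 / (1.0 + np.exp(-pair_scores(x, Z, W_q, W_k)))

def attention_weights(x, Z, W_q, W_k):
    """Softmax normalization couples each weight to all other neighbors."""
    s = pair_scores(x, Z, W_q, W_k)
    e = np.exp(s - s.max())
    return e / e.sum()
```

Adding a new neighbor leaves the existing pairwise weights untouched, whereas the softmax renormalizes every attention weight.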

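Baseline aggregators

To fairly evaluate the effectiveness of GaAN against previous work, we choose two representative aggregators in each category as baselines:

  • Avg. pooling: $\mathbf{y}_i = \mathrm{FC}_{\theta_o}\big(\mathbf{x}_i \oplus \frac{1}{|\mathcal{N}_i|} \sum_{j \in \mathcal{N}_i} \mathrm{FC}^{h}_{\theta_v}(\mathbf{z}_j)\big)$.

  • Max pooling: $\mathbf{y}_i = \mathrm{FC}_{\theta_o}\big(\mathbf{x}_i \oplus \max_{j \in \mathcal{N}_i}(\{\mathrm{FC}^{h}_{\theta_v}(\mathbf{z}_j)\})\big)$.

  • Pairwise + sigmoid: $\mathbf{y}_i = \mathrm{FC}_{\theta_o}\big(\mathbf{x}_i \oplus \Vert_{k=1}^{K} \sum_{j \in \mathcal{N}_i} w_{i,j}^{(k)} \, \mathrm{FC}^{h}_{\theta_v^{(k)}}(\mathbf{z}_j)\big)$, with $w_{i,j}^{(k)} = \sigma\big(\phi_w^{(k)}(\mathbf{x}_i, \mathbf{z}_j)\big)$.

  • Pairwise + tanh: Replace the sigmoid activation in Pairwise + sigmoid with tanh.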

5 Inductive Node Classification

5.1 Model

In the inductive node classification setting, every node is assigned one or multiple labels. During training, the validation and testing nodes are not observable and the goal is to predict the labels of the unseen testing nodes. Our approach follows that of (Hamilton et al., 2017a), where a mini-batch of nodes are sampled on each iteration during training and multiple layers of graph aggregators are stacked to compute the predictions.

With a stack of $M$ layers of graph aggregators, we first sample a mini-batch of nodes $\mathcal{B}_0$ and then recursively expand $\mathcal{B}_{\ell}$ to be $\mathcal{B}_{\ell+1}$ by sampling the neighboring nodes of $\mathcal{B}_{\ell}$. After $M$ sampling steps, we get a hierarchy of node batches: $\mathcal{B}_1, \ldots, \mathcal{B}_M$. The node representations, which are initialized to be the node features, are aggregated in reverse order from $\mathcal{B}_M$ to $\mathcal{B}_0$. The representations of the last layer, i.e., the final representations of the nodes in $\mathcal{B}_0$, are projected to get the output. We use the sigmoid activation for multi-label classification and the softmax activation for multi-class classification. Also, we use the cross-entropy loss to train the model.

Strategy / Sample step    |B_0|    |B_1|    |B_2|     |B_3|
Sample without merge      512      7.8K     124.4K    1.9M
Sample and merge          512      7.5K     70.7K     0.2M
Table 1: Effect of the merge operation. Both methods sample a maximum of 15 neighbors without replacement for three recursive steps on the Reddit dataset. We start from 512 seed nodes. The total number of nodes after the $\ell$th sampling step is denoted as $|\mathcal{B}_{\ell}|$. The sampling process is repeated ten times and the mean is reported.

A naive sampling algorithm is to always sample all neighbors. However, it is not practical on large graphs because the memory complexity is $O(|\mathcal{V}|)$ and the time complexity is $O(|\mathcal{E}|)$, where $|\mathcal{V}|$ and $|\mathcal{E}|$ are the total number of nodes and edges. Instead, similar to GraphSAGE, we only sample a subset of the neighbors for each node. In our implementation, at the $\ell$th sampling step, we sample $\min(|\mathcal{N}_i|, S_{\ell})$ neighbors without replacement for the node $i \in \mathcal{B}_{\ell-1}$, where $S_{\ell}$ is a hyperparameter that controls the maximum number of sampled neighbors at the $\ell$th step. Moreover, to improve over GraphSAGE and further reduce memory cost, we merge repeated nodes that are sampled from different seeds’ neighborhoods within each mini-batch. This greatly reduces the size of the $\mathcal{B}_{\ell}$s, as shown in Table 1.

Note that $\min(|\mathcal{N}_i|, S_{\ell})$ is not the same for all the nodes $i \in \mathcal{B}_{\ell-1}$. Rather than padding the sampled neighborhood set to the same size, we implemented new GPU kernels that directly operate on inputs with variable lengths to accelerate computations.
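The recursive sampling with and without the merge step can be sketched in plain Python. This is a simplified illustration, not the authors' GPU implementation: the graph is a hypothetical adjacency dictionary, the function name `sample_hierarchy` is made up, and the sketch assumes that each expanded batch keeps the nodes of the previous batch in addition to their sampled neighbors.

```python
import random

def sample_hierarchy(adj, seeds, max_neighbors, merge=True, rng=None):
    """Build the batch hierarchy [B_0, B_1, ..., B_M].

    adj: dict mapping node -> list of neighbor nodes.
    seeds: the initial mini-batch B_0.
    max_neighbors: per-step caps on the number of sampled neighbors.
    merge: if True, repeated nodes within a batch are merged.
    """
    rng = rng or random.Random(0)
    batches = [list(seeds)]
    for cap in max_neighbors:
        current = batches[-1]
        expanded = list(current)
        for node in current:
            neighbors = adj[node]
            # Sample without replacement, at most `cap` neighbors per node.
            expanded.extend(rng.sample(neighbors, min(len(neighbors), cap)))
        if merge:
            expanded = list(dict.fromkeys(expanded))  # drop repeats, keep order
        batches.append(expanded)
    return batches
```

Without merging, a node reached from several seeds is stored and expanded once per occurrence, so the batch size grows roughly geometrically; with merging, the batch size is bounded by the number of distinct nodes reached.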

5.2 Experimental Setup

We performed a thorough comparison of GaAN with the state-of-the-art models, five aggregator-based models in our framework and a two-layer fully-connected neural network on the PPI and Reddit datasets (Hamilton et al., 2017a). The five baseline aggregators include the multi-head attention aggregator, two pooling-based aggregators and two pairwise-sum-based aggregators mentioned in Section 4.3. We also conducted a comprehensive ablation analysis on these two datasets.

The PPI dataset was collected from the molecular signatures database (Subramanian et al., 2005). Each node represents a protein and edges represent the interaction between proteins. Labels represent the cellular functions of each protein from gene ontology. This dataset contains 24 sub-graphs, with 20 in the training set, two in the validation set, and two in the testing set. Reddit is an online discussion forum where users can post and discuss content on different topics. Each node represents a post and two nodes are connected if the same user has commented on both. The labels indicate which community a post belongs to. Detailed statistics of the datasets are listed in Table 2.

Data      #Nodes    #Edges    #Features    #Classes
PPI       56.9K     806.2K    50           121 (multi)
Reddit    233.0K    114.6M    602          41 (single)
Table 2: Datasets for inductive node classification. ‘multi’ stands for multi-label classification and ‘single’ for single-label classification.

5.3 Model Architectures and Implementation Detail

GaAN and the other five aggregator-based networks each stack two graph aggregators. Each aggregator is followed by the LeakyReLU activation with a negative slope of 0.1 and a dropout layer with the dropout rate set to 0.1. The output dimension $d_o$ of all layers is fixed to 128 except when we compare the relative performance with different output dimensions. To keep the number of parameters comparable for the multi-head models with different numbers of heads, we fix the product of the dimension of the value vector and the number of heads, i.e., $d_v \times K$, when evaluating the effect of varying the number of heads. Also, the hyperparameters of the first and the second layers are the same unless stated otherwise.

In the PPI experiments, both pooling aggregators have , where $d_v$ means the dimensionality of the value vector projected by $\theta_v$. For the pairwise sum aggregators, the dimension of the keys $d_a$ is set to be 24, and . For both GaAN and the multi-head attention based aggregator, $d_a$ is set to be 24 and the product $d_v \times K$ is fixed to be 256. For GaAN, we set $d_m$ to be 64 in the gate-generation network. Also, we use the entire neighborhoods in the mini-batch training algorithm.

In the Reddit experiments, both pooling aggregators have . For the pairwise sum aggregators, , and . For the attention based aggregators, $d_a$ is set to be 32 and $d_v \times K$ is fixed to be 512. We set the gate-generation network in GaAN to have . Also, the number of heads is fixed to 1 in the first layer for both attention-based models. The maximum numbers of sampled neighbors in the first and second sampling steps are denoted as $S_1$ and $S_2$ and are respectively set to be 25 and 10 in the main experiment. In the ablation analysis, we also look at the performance when setting them to be (50, 20), (100, 40) and (200, 80).

To illustrate the effectiveness of incorporating graph structures, we also evaluate a two-layer fully-connected neural network with the hidden dimension of 1024 and ReLU activation.

Models / Datasets                      PPI             Reddit
GraphSAGE (Hamilton et al., 2017a)     (61.2) [1]      95.4
GAT (Veličković et al., 2018)          97.3 ± 0.2      -
FastGCN (Chen et al., 2018)            -               93.7

2-Layer FNN                            54.07 ± 0.06    73.58 ± 0.09
Avg. pooling                           96.85 ± 0.19    95.78 ± 0.07
Max pooling                            98.39 ± 0.05    95.62 ± 0.03
Pairwise+sigmoid                       98.39 ± 0.05    95.86 ± 0.08
Pairwise+tanh                          98.32 ± 0.18    95.80 ± 0.03
Attention-only                         98.46 ± 0.09    96.19 ± 0.07
GaAN                                   98.71 ± 0.02    96.36 ± 0.03
[1] The performance reported in the paper is relatively low because the authors did not train their model to convergence. Also, it is not fair to compare it with the other scores because it uses the sampling strategy while the others do not.
Table 3: Summary of different models’ test micro F1 scores in the inductive node classification task. In the first block, we include the best-reported results in the previous papers. In the second block, we report the results obtained by our models. For the PPI dataset, we do not use any sampling strategies. For the Reddit dataset, we use the maximum number sampling strategy with $S_1$ = 25 and $S_2$ = 10.

                     Reddit                                                                      PPI
Models               #Param   (25, 10)        (50, 20)        (100, 40)       (200, 80)       #Param   (all, all)
2-Layer FNN          1.71M    73.58 ± 0.09    73.58 ± 0.09    73.58 ± 0.09    73.58 ± 0.09    1.23M    54.07 ± 0.06
Avg. pooling         866K     95.78 ± 0.07    96.11 ± 0.07    96.28 ± 0.05    96.35 ± 0.02    274K     96.85 ± 0.19
Max pooling          866K     95.62 ± 0.03    96.06 ± 0.09    96.18 ± 0.11    96.33 ± 0.04    274K     98.39 ± 0.05
Pairwise+sigmoid     965K     95.86 ± 0.08    96.19 ± 0.04    96.33 ± 0.05    96.38 ± 0.08    349K     98.39 ± 0.05
Pairwise+tanh        965K     95.80 ± 0.03    96.11 ± 0.05    96.26 ± 0.03    96.36 ± 0.04    349K     98.32 ± 0.18
Attention-only-K1    562K     96.15 ± 0.06    96.40 ± 0.05    96.48 ± 0.02    96.54 ± 0.07    168K     96.31 ± 0.08
Attention-only-K2    571K     96.19 ± 0.07    96.40 ± 0.04    96.52 ± 0.02    96.57 ± 0.02    178K     97.36 ± 0.08
Attention-only-K4    587K     96.11 ± 0.06    96.40 ± 0.02    96.49 ± 0.03    96.56 ± 0.02    196K     98.09 ± 0.07
Attention-only-K8    620K     96.10 ± 0.03    96.38 ± 0.01    96.50 ± 0.04    96.53 ± 0.02    233K     98.46 ± 0.09
GaAN-K1              620K     96.29 ± 0.05    96.50 ± 0.08    96.67 ± 0.04    96.73 ± 0.05    201K     96.95 ± 0.09
GaAN-K2              629K     96.33 ± 0.02    96.59 ± 0.02    96.71 ± 0.05    96.82 ± 0.05    211K     97.92 ± 0.05
GaAN-K4              645K     96.36 ± 0.03    96.60 ± 0.03    96.73 ± 0.04    96.83 ± 0.03    230K     98.42 ± 0.02
GaAN-K8              678K     96.31 ± 0.13    96.60 ± 0.02    96.75 ± 0.03    96.79 ± 0.08    267K     98.71 ± 0.02
Table 4: Comparison of the test F1 score on the Reddit and PPI datasets with different sampling neighborhood sizes $(S_1, S_2)$ and attention head number $K$. $S_1$ and $S_2$ are the maximum numbers of sampled neighbors in the 1st and 2nd sampling steps. ‘all’ means to sample all the neighbors.

We train all the aggregator-based models with Adam (Kingma and Ba, 2015) and early stopping on the validation set. We also use the validation set to schedule learning rate decay. For Reddit, before training we normalize all the features and project them to a hidden dimension of 256. The initial learning rate is 0.001 and is gradually decayed to 0.0001, with a decay step applied each time the validation F1 score does not improve within a window of 4 epochs; training stops early after 10 epochs without improvement. Gradients are clipped so that their norm is no larger than 1.0. For the PPI dataset, all the input features are projected to a 64-dimensional hidden state before being passed to the aggregators. The learning rate begins at 0.01 and decays to 0.001 with a decay rate of 0.5 if the validation F1 score does not increase for 15 epochs, and training stops after 30 epochs without improvement.

The training batch size is fixed. Also, in all experiments, we use the validation set to select the optimal hyperparameters for training. The training, validation, and testing splits are the same as those in (Hamilton et al., 2017a). The micro-averaged F1 score is used to evaluate the prediction accuracy for both datasets. We repeat the training five times for Reddit and three times for PPI with different random seeds and report the average test F1 score along with the standard deviation.
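For reference, the micro-averaged F1 score pools true positives, false positives and false negatives over all samples and labels before computing a single F1 value. A minimal NumPy sketch, assuming binary indicator matrices as input (one-hot rows in the single-label case) and a hypothetical function name `micro_f1`:

```python
import numpy as np

def micro_f1(y_true, y_pred):
    """Micro-averaged F1 for binary indicator arrays of shape (n_samples, n_labels)."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tp = np.sum(y_true & y_pred)     # predicted and correct
    fp = np.sum(~y_true & y_pred)    # predicted but wrong
    fn = np.sum(y_true & ~y_pred)    # missed labels
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else 0.0
```

With one-hot predictions and one-hot targets, every mistake contributes exactly one false positive and one false negative, so the micro-averaged F1 coincides with accuracy.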

5.4 Main Results

We compare our model with the previous state-of-the-art methods on inductive node classification. This includes GraphSAGE (Hamilton et al., 2017a), GAT (Veličković et al., 2018), and FastGCN (Chen et al., 2018). The GraphSAGE model used a 2-layer sample and aggregate model with a neighborhood size of $S_1 = 25$ and $S_2 = 10$ and without dropout. The 3-layer GAT model consisted of 4, 4 and 6 heads in the first, second and third layer respectively. Each attention head had 256 dimensions. GAT did not use neighborhood sampling, L2 regularization, or dropout. The FastGCN model is a fast version of the 3-layer, 128-dimension GCN with sampled neighborhood sizes of 400, 100, and 400 for each layer, and no sampling is done during testing. Table 3 summarizes all results of the state-of-the-art models as well as the models proposed in this paper. We denote the multi-head attention aggregator as ‘Attention-only’ in the tables and figures. We find that the proposed model, GaAN, achieves the best F1 score on both benchmarks, and the other baseline aggregators are also competitive with the state of the art. We note that aggregator-based models achieve much higher F1 scores than the fully-connected model, which demonstrates the effectiveness of the graph aggregators. Our max pooling and avg. pooling baselines have higher scores on Reddit than those in the original GraphSAGE paper. This is mainly attributable to our use of dropout and the LeakyReLU activation.

5.5 Ablation Analysis

We ran a number of ablation experiments to analyze the performance of different graph aggregators when different hyperparameters were used. We also visualized the gates of the GaAN model.

Effect of the number of attention heads and the sample size

We compare the performance of the aggregators when different numbers of attention heads and sampling strategies are used. Results are shown in Table 4. We find that attention-based models consistently outperform pooling and pairwise-sum-based models with fewer parameters, which demonstrates the effectiveness of the attention mechanism in this task. Moreover, GaAN consistently beats the multi-head attention model with the same number of attention heads $K$. This indicates that adding gates to control the importance of the attention heads is beneficial to the final classification performance. From the last two row blocks of Table 4, we note that increasing the number of attention heads does not always produce better results on Reddit. In contrast, on PPI, the larger the $K$, the better the prediction results.

Also, we can see steady improvement with larger sampling sizes, which is consistent with the observation in (Hamilton et al., 2017a).

(a) Performance of different models with a varying number of output dimensions on PPI.
(b) Visualization of 8 gate values of 5 example nodes on Reddit. Each row represents a learned gate vector for one node.
Figure 3: Ablation analysis on PPI and Reddit

Effect of output dimensions in the PPI dataset

We changed the output dimension to be 64, 96 and 128 in the models trained on the PPI dataset. The test F1 score is shown in Figure 3(a). All multi-head models have $K$ = 8. We find that the performance improves with larger output dimensions and the proposed GaAN consistently outperforms the other models.

Visualization of gate values

In Figure 3(b), we visualize the gate values of five different nodes output by the GaAN-K8 model trained on the Reddit dataset. It illustrates the diversity of the learned gate combinations for different nodes. In most cases, the gates vary across attention heads, which shows that the gate-generation network can learn to assign different importance to different heads.

6 Traffic Speed Forecasting

6.1 Graph GRU

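Following (Lin et al., 2017), we formulate traffic speed forecasting as a spatiotemporal sequence forecasting problem where the input and the target are sequences defined on a fixed spatiotemporal graph, e.g., the road network. To simplify notations, we denote $\Gamma_{\Theta}(\mathbf{X}, \mathbf{Z}; \mathcal{G})$ as applying the $\gamma$ aggregator for all nodes in $\mathcal{G}$, i.e., $\mathbf{Y} = \Gamma_{\Theta}(\mathbf{X}, \mathbf{Z}; \mathcal{G})$ with $\mathbf{y}_i = \gamma_{\Theta}(\mathbf{x}_i, \mathbf{z}_{\mathcal{N}_i})$. Based on a given graph aggregator $\Gamma$, we can construct a GRU-like RNN structure using the following equations:

$$
\begin{aligned}
\mathbf{U}_t &= \sigma\big(\Gamma_{\Theta_{xu}}(\mathbf{X}_t, \mathbf{X}_t; \mathcal{G}) + \Gamma_{\Theta_{hu}}(\mathbf{X}_t \oplus \mathbf{H}_{t-1}, \mathbf{H}_{t-1}; \mathcal{G})\big), \\
\mathbf{R}_t &= \sigma\big(\Gamma_{\Theta_{xr}}(\mathbf{X}_t, \mathbf{X}_t; \mathcal{G}) + \Gamma_{\Theta_{hr}}(\mathbf{X}_t \oplus \mathbf{H}_{t-1}, \mathbf{H}_{t-1}; \mathcal{G})\big), \\
\mathbf{H}'_t &= h\big(\Gamma_{\Theta_{xh}}(\mathbf{X}_t, \mathbf{X}_t; \mathcal{G}) + \mathbf{R}_t \circ \Gamma_{\Theta_{hh}}(\mathbf{X}_t \oplus \mathbf{H}_{t-1}, \mathbf{H}_{t-1}; \mathcal{G})\big), \\
\mathbf{H}_t &= (1 - \mathbf{U}_t) \circ \mathbf{H}'_t + \mathbf{U}_t \circ \mathbf{H}_{t-1}.
\end{aligned} \tag{6}
$$

Here, $\mathbf{X}_t \in \mathbb{R}^{|\mathcal{V}| \times d_i}$ are the input features and $\mathbf{H}_t \in \mathbb{R}^{|\mathcal{V}| \times d_o}$ are the hidden states of the nodes at the $t$th timestamp. $|\mathcal{V}|$ is the total number of nodes, $d_i$ is the dimension of the input and $d_o$ is the dimension of the state. $\mathbf{U}_t$ and $\mathbf{R}_t$ are the update gate and reset gate that control how $\mathbf{H}_t$ is calculated. $\mathcal{G}$ is the graph that defines the connection structure between different nodes.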
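One way to picture the construction is to take an ordinary GRU update and replace every fully-connected transform with a graph aggregator applied to all nodes at once. The NumPy sketch below does this with a simple mean-pooling aggregator as a stand-in; the aggregator, the exact way inputs and states are fed to it, the weight names (`xu`, `hu`, and so on) and the dense adjacency matrix are all assumptions made for illustration, not the authors' implementation.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def leaky_relu(a, slope=0.1):
    return np.where(a > 0, a, slope * a)

def mean_aggregator(X, Z, adj, W):
    """Stand-in aggregator: linear map of [x_i, mean of neighbor rows of Z].

    X: (N, d_x) center features, Z: (N, d_z) reference features,
    adj: (N, N) 0/1 adjacency matrix, W: (d_out, d_x + d_z).
    """
    degree = np.clip(adj.sum(axis=1, keepdims=True), 1, None)
    neighbor_mean = (adj @ Z) / degree
    return np.concatenate([X, neighbor_mean], axis=1) @ W.T

def ggru_step(X_t, H_prev, adj, p):
    """One recurrent step; p maps names such as 'xu' to weight matrices."""
    XH = np.concatenate([X_t, H_prev], axis=1)
    from_x = lambda name: mean_aggregator(X_t, X_t, adj, p[name])
    from_h = lambda name: mean_aggregator(XH, H_prev, adj, p[name])
    U = sigmoid(from_x("xu") + from_h("hu"))               # update gate
    R = sigmoid(from_x("xr") + from_h("hr"))               # reset gate
    H_cand = leaky_relu(from_x("xh") + R * from_h("hh"))   # candidate state
    return (1.0 - U) * H_cand + U * H_prev
```

Swapping `mean_aggregator` for any other aggregator with the same input and output shapes yields a different recurrent unit without changing `ggru_step`, which is the sense in which the construction is generic.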

Figure 4: Illustration of the encoder-decoder structure used in the paper. We use two layers of Graph GRUs to predict a length-3 output sequence based on a length-2 input sequence. ‘SS’ denotes the scheduled sampling step.

                             15 min                  30 min                  60 min                  Average
Models                       MAE   RMSE  MAPE        MAE   RMSE  MAPE        MAE   RMSE  MAPE        MAE   RMSE  MAPE
FC-LSTM (Li et al., 2018)    3.44  6.30  9.6%        3.77  7.23  10.9%       4.37  8.69  13.2%       3.86  7.41  11.2%
GCRNN (Li et al., 2018)      2.80  5.51  7.5%        3.24  6.74  9.0%        3.81  8.16  10.9%       3.28  6.80  9.13%
DCRNN* (Li et al., 2018)     2.77  5.38  7.3%        3.15  6.45  8.8%        3.60  7.60  10.5%       3.17  6.48  8.87%
Avg Pool                     2.79  5.42  7.26%       3.20  6.52  8.84%       3.69  7.69  10.73%      3.22  6.54  8.94%
Max Pool                     2.77  5.36  7.21%       3.18  6.45  8.78%       3.69  7.73  10.80%      3.21  6.51  8.93%
Pairwise + Sigmoid           2.76  5.36  7.14%       3.18  6.46  8.72%       3.70  7.73  10.77%      3.22  6.52  8.88%
Pairwise + Tanh              2.76  5.34  7.14%       3.18  6.46  8.73%       3.70  7.73  10.73%      3.21  6.51  8.87%
Attention-only               2.74  5.33  7.09%       3.16  6.45  8.69%       3.67  7.61  10.77%      3.19  6.49  8.85%
GaAN                         2.71  5.24  6.99%       3.12  6.36  8.56%       3.64  7.65  10.62%      3.16  6.41  8.72%
Table 5: Performance comparison of different models for traffic speed forecasting on the METR-LA dataset. Each horizon reports MAE, RMSE and MAPE. Models marked with ‘*’ treat the sensor map as a directed graph while other models convert it into an undirected graph. Scores under ‘$\tau$ min’ are the scores at the $(\tau/5)$th predicted frame. The last three columns contain the average scores of the 15 min, 30 min, and 60 min forecasting horizons.

We refer to this RNN structure as Graph GRU (GGRU). GGRU can be used as the basic building block of an RNN encoder-decoder structure (Lin et al., 2017) to predict the future K steps of traffic speeds in the sensor network based on the previous steps of observed traffic speeds. In the decoder, we use the scheduled sampling technique described in (Lin et al., 2017). Figure 4 illustrates the encoder-decoder structure used in this paper. When attention-based aggregators are used, i.e., the multi-head attention aggregator or our GaAN aggregator, the connection structure in the recurrent step is also learned through the attention process. This can be viewed as an extension of Trajectory GRU (TrajGRU) (Shi et al., 2017) to irregular, graph-structured data.
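
The scheduled sampling step in the decoder can be sketched as follows. This is a minimal illustration, assuming the common formulation in which each decoder input is the ground-truth frame with probability `p_truth` and the model's own previous prediction otherwise. The names `decode_with_scheduled_sampling`, `toy_step`, and `p_truth` are hypothetical, and `toy_step` is only a stand-in for a stack of GGRU layers with an output projection.

```python
import numpy as np

def decode_with_scheduled_sampling(step_fn, state, first_input, num_steps,
                                   targets=None, p_truth=0.0, rng=None):
    """Roll out num_steps decoder steps and return the stacked predictions.

    step_fn(inp, state) -> (prediction, new_state) is a hypothetical stand-in
    for the recurrent decoder. With targets=None (inference) the decoder is
    always fed its own predictions.
    """
    if rng is None:
        rng = np.random.default_rng()
    preds, inp = [], first_input
    for k in range(num_steps):
        pred, state = step_fn(inp, state)
        preds.append(pred)
        # Feed the ground truth with probability p_truth, else the prediction.
        use_truth = targets is not None and rng.random() < p_truth
        inp = targets[k] if use_truth else pred
    return np.stack(preds)

def toy_step(inp, state):
    """Toy recurrent step: the state is a running average of the inputs."""
    new_state = 0.5 * state + 0.5 * inp
    return new_state, new_state

# Usage: 4 nodes, 1 feature, 3 decoding steps.
zeros = np.zeros((4, 1))
targets = np.ones((3, 4, 1))
teacher_forced = decode_with_scheduled_sampling(
    toy_step, zeros, zeros, 3, targets=targets, p_truth=1.0)
free_running = decode_with_scheduled_sampling(toy_step, zeros, zeros, 3)
```

In training, `p_truth` is typically decayed from 1 towards 0 so that the decoder gradually learns to rely on its own outputs, matching the conditions it faces at test time.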

| Data | #Nodes | #Edges | #Timestamps |
| --- | --- | --- | --- |
| METR-LA | 207 | 1,515 | 34,272 |

Table 6: The dataset used for traffic speed forecasting.

6.2 Experimental Setup

To evaluate the proposed GGRU model on traffic speed forecasting, we use the METR-LA dataset from (Li et al., 2018). The dataset contains traffic information of the highways of Los Angeles County. The nodes in the dataset represent sensors measuring traffic speed and edges denote proximity between sensor pairs measured by road network distance. The sensor speeds are recorded every five minutes. Complete dataset statistics are given in Table 6.

We follow the dataset split of (Li et al., 2018): the first 70% of the sequences are used for training, the middle 10% for validation, and the final 20% for testing. We also use the same evaluation metrics as (Li et al., 2018), namely Mean Absolute Error (MAE), Root Mean Squared Error (RMSE), and Mean Absolute Percentage Error (MAPE). A sequence of length 12 is used as the input to predict the traffic speed over the following hour (12 steps).
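
The windowing, chronological split, and metrics described above can be sketched in a few lines of NumPy. This is a minimal sketch, not the authors' pipeline: the function names are hypothetical, the series is synthetic, and masking zero targets in MAPE is an assumption made here to avoid division by zero.

```python
import numpy as np

def make_windows(series, in_len=12, out_len=12):
    """Slice a (T, N) speed series into (input, target) sequence pairs."""
    n = series.shape[0] - in_len - out_len + 1
    xs = np.stack([series[i:i + in_len] for i in range(n)])
    ys = np.stack([series[i + in_len:i + in_len + out_len] for i in range(n)])
    return xs, ys

def chronological_split(xs, ys, train_pct=70, val_pct=10):
    """Split sequences in time order: first 70% / middle 10% / final 20%."""
    n = len(xs)
    a = n * train_pct // 100
    b = a + n * val_pct // 100
    return (xs[:a], ys[:a]), (xs[a:b], ys[a:b]), (xs[b:], ys[b:])

def mae(y, y_hat):
    return np.mean(np.abs(y - y_hat))

def rmse(y, y_hat):
    return np.sqrt(np.mean((y - y_hat) ** 2))

def mape(y, y_hat):
    """MAPE in percent; zero targets are skipped (an assumption of this sketch)."""
    mask = y != 0
    return np.mean(np.abs((y[mask] - y_hat[mask]) / y[mask])) * 100.0

# Usage on a synthetic series with 123 timestamps and 2 sensors.
series = np.arange(123 * 2, dtype=float).reshape(123, 2)
xs, ys = make_windows(series)
train, val, test = chronological_split(xs, ys)

y = np.array([2.0, 4.0])
y_hat = np.array([1.0, 6.0])
print(mae(y, y_hat), mape(y, y_hat))  # 1.5 50.0
```

With five-minute sampling, 12 input steps and 12 output steps correspond to one hour of history and one hour of forecast.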

6.3 Main Results

We compare six variations of the proposed GGRU architecture with three baseline models, including fully-connected LSTM, GCRNN, and DCRNN (Li et al., 2018). We use the same set of six aggregators as in the inductive node classification experiment to construct the GGRU and we use two layers of GGRUs with the state dimension of 64 both in the encoder and the decoder. For attention based models, we set , and . For GaAN, we set and only use max pooling in the gate-generation network. For pooling based aggregators, we set . For pairwise sum aggregators, we set , , and .

Since the road map is directed and our model does not deal with edge information, we first convert the road map into an undirected graph and use it as the graph in Eqn. (6). All models are trained by minimizing the MAE loss with the Adam optimizer. The initial learning rate is set to 0.001 and the batch size is 64. We use the same scheduled sampling strategy as in (Li et al., 2018). Table 5 shows the comparison of different approaches for 15 minute, 30 minute, and 1 hour ahead forecasting on the METR-LA dataset.
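
The text does not spell out how the directed road map is converted into an undirected graph. One simple option, shown here purely as an assumption, is to connect two sensors whenever a directed edge exists in either direction and to discard the edge weights; `to_undirected` is a hypothetical helper name.

```python
import numpy as np

def to_undirected(adj):
    """Return a symmetric 0/1 adjacency matrix.

    Nodes i and j are connected if the directed input has an edge i->j
    or j->i. Edge weights are dropped, since only connectivity is used.
    """
    adj = np.asarray(adj)
    return ((adj != 0) | (adj.T != 0)).astype(np.float32)

# Usage: a directed 3-node chain 0 -> 1 -> 2 with distance-like weights.
directed = np.array([[0.0, 0.8, 0.0],
                     [0.0, 0.0, 0.3],
                     [0.0, 0.0, 0.0]])
undirected = to_undirected(directed)
```

The resulting matrix can serve as the neighborhood structure for any aggregator that ignores edge features.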

The scores for 15 minute, 30 minute, and 1 hour ahead forecasting, as well as the average scores over the three forecasting horizons, are shown in Table 5. On the average scores, all proposed GGRU models give better results than GCRNN, which also models the traffic network as an undirected graph. Moreover, the GaAN-based GGRU model, which does not use edge information, achieves lower average error than DCRNN (MAE 3.16 vs. 3.17, RMSE 6.41 vs. 6.48, MAPE 8.72% vs. 8.87%), which uses edge information in the road network, although DCRNN retains a small edge at the 60 min horizon.

7 Conclusion and Future Work

We introduced the GaAN model and applied it to two challenging tasks: inductive node classification and traffic speed forecasting. GaAN beats previous state-of-the-art algorithms in both cases. In the future, we plan to extend GaAN by integrating edge features and processing massive graphs with millions or even billions of nodes. Moreover, our model is not restricted to graph learning. A particularly exciting direction for future work is to apply GaAN to natural language processing tasks like machine translation.

References


  • Atwood and Towsley (2016) J. Atwood and D. Towsley. Diffusion-convolutional neural networks. In NIPS, pages 1993–2001, 2016.
  • Bahdanau et al. (2015) D. Bahdanau, K. Cho, and Y. Bengio. Neural machine translation by jointly learning to align and translate. In ICLR, 2015.
  • Bruna et al. (2014) J. Bruna, W. Zaremba, A. Szlam, and Y. LeCun. Spectral networks and locally connected networks on graphs. In ICLR, 2014.
  • Chen et al. (2018) J. Chen, T. Ma, and C. Xiao. FastGCN: Fast learning with graph convolutional networks via importance sampling. In ICLR, 2018.
  • Chorowski et al. (2014) J. Chorowski, D. Bahdanau, K. Cho, and Y. Bengio. End-to-end continuous speech recognition using attention-based recurrent NN: First results. In NIPS, Workshop on Deep Learning and Representation Learning, 2014.
  • Chung et al. (2014) J. Chung, C. Gulcehre, K. Cho, and Y. Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. In NIPS, Workshop on Deep Learning and Representation Learning, 2014.
  • Defferrard et al. (2016) M. Defferrard, X. Bresson, and P. Vandergheynst. Convolutional neural networks on graphs with fast localized spectral filtering. In NIPS, pages 3844–3852, 2016.
  • Duvenaud et al. (2015) D. K. Duvenaud, D. Maclaurin, J. Iparraguirre, R. Bombarell, T. Hirzel, A. Aspuru-Guzik, and R. P. Adams. Convolutional networks on graphs for learning molecular fingerprints. In NIPS, pages 2224–2232, 2015.
  • Fout et al. (2017) A. Fout, J. Byrd, B. Shariat, and A. Ben-Hur. Protein interface prediction using graph convolutional networks. In NIPS, pages 6533–6542, 2017.
  • Gehring et al. (2017) J. Gehring, M. Auli, D. Grangier, D. Yarats, and Y. N. Dauphin. Convolutional sequence to sequence learning. In ICML, pages 1243–1252, 2017.
  • Hamilton et al. (2017a) W. Hamilton, Z. Ying, and J. Leskovec. Inductive representation learning on large graphs. In NIPS, pages 1025–1035, 2017a.
  • Hamilton et al. (2017b) W. L. Hamilton, R. Ying, and J. Leskovec. Representation learning on graphs: Methods and applications. arXiv preprint arXiv:1709.05584, 2017b.
  • Hammond et al. (2011) D. K. Hammond, P. Vandergheynst, and R. Gribonval. Wavelets on graphs via spectral graph theory. Applied and Computational Harmonic Analysis, 30(2):129–150, 2011.
  • Hochreiter and Schmidhuber (1997) S. Hochreiter and J. Schmidhuber. Long short-term memory. Neural computation, 9(8):1735–1780, 1997.
  • Kingma and Ba (2015) D. Kingma and J. Ba. Adam: A method for stochastic optimization. In ICLR, 2015.
  • Kipf and Welling (2017) T. N. Kipf and M. Welling. Semi-supervised classification with graph convolutional networks. In ICLR, 2017.
  • Krizhevsky et al. (2012) A. Krizhevsky, I. Sutskever, and G. E. Hinton. ImageNet classification with deep convolutional neural networks. In NIPS, pages 1097–1105, 2012.
  • LeCun et al. (2015) Y. LeCun, Y. Bengio, and G. Hinton. Deep learning. Nature, 521(7553):436–444, 2015.
  • Li et al. (2018) Y. Li, R. Yu, C. Shahabi, and Y. Liu. Diffusion convolutional recurrent neural network: Data-driven traffic forecasting. In ICLR, 2018.
  • Liang et al. (2016) X. Liang, X. Shen, J. Feng, L. Lin, and S. Yan. Semantic object parsing with graph LSTM. In ECCV, pages 125–143, 2016.
  • Lin et al. (2017) Z. Lin, M. Feng, C. N. d. Santos, M. Yu, B. Xiang, B. Zhou, and Y. Bengio. A structured self-attentive sentence embedding. In ICLR, 2017.
  • Monti et al. (2017) F. Monti, D. Boscaini, J. Masci, E. Rodola, J. Svoboda, and M. M. Bronstein. Geometric deep learning on graphs and manifolds using mixture model CNNs. In CVPR, pages 5115–5124, 2017.
  • Schütt et al. (2017) K. T. Schütt, F. Arbabzadah, S. Chmiela, K. R. Müller, and A. Tkatchenko. Quantum-chemical insights from deep tensor neural networks. Nature communications, 8:13890, 2017.
  • Seo et al. (2017) M. Seo, A. Kembhavi, A. Farhadi, and H. Hajishirzi. Bidirectional attention flow for machine comprehension. In ICLR, 2017.
  • Seo et al. (2016) Y. Seo, M. Defferrard, P. Vandergheynst, and X. Bresson. Structured sequence modeling with graph convolutional recurrent networks. arXiv preprint arXiv:1612.07659, 2016.
  • Shi et al. (2017) X. Shi, Z. Gao, L. Lausen, H. Wang, D.-Y. Yeung, W.-k. Wong, and W.-c. Woo. Deep learning for precipitation nowcasting: A benchmark and a new model. In NIPS, pages 5622–5632, 2017.
  • Subramanian et al. (2005) A. Subramanian, P. Tamayo, V. K. Mootha, S. Mukherjee, B. L. Ebert, M. A. Gillette, A. Paulovich, S. L. Pomeroy, T. R. Golub, E. S. Lander, et al. Gene set enrichment analysis: a knowledge-based approach for interpreting genome-wide expression profiles. Proceedings of the National Academy of Sciences, 102(43):15545–15550, 2005.
  • Vaswani et al. (2017) A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin. Attention is all you need. In NIPS, pages 6000–6010, 2017.
  • Veličković et al. (2018) P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Liò, and Y. Bengio. Graph attention networks. In ICLR, 2018.
  • Xu et al. (2015a) B. Xu, N. Wang, T. Chen, and M. Li. Empirical evaluation of rectified activations in convolutional network. arXiv preprint arXiv:1505.00853, 2015a.
  • Xu et al. (2015b) K. Xu, J. Ba, R. Kiros, K. Cho, A. Courville, R. Salakhudinov, R. Zemel, and Y. Bengio. Show, attend and tell: Neural image caption generation with visual attention. In ICML, pages 2048–2057, 2015b.