Multitask Learning on Graph Neural Networks - Learning Multiple Graph Centrality Measures with a Unified Network

by Pedro H. C. Avelar, et al.

The application of deep learning to symbolic domains remains an active research endeavour. Graph neural networks (GNN), consisting of trained neural modules which can be arranged in different topologies at run time, are sound alternatives to tackle relational problems which lend themselves to graph representations. In this paper, we show that GNNs are capable of multitask learning, which can be naturally enforced by training the model to refine a single set of multidimensional embeddings ∈ R^d and decode them into multiple outputs by connecting MLPs at the end of the pipeline. We demonstrate the multitask learning capability of the model on the relevant relational problem of estimating network centrality measures, i.e. answering the question "is vertex v_1 more central than vertex v_2 given centrality c?". We then show that a GNN can be trained to develop a lingua franca of vertex embeddings from which all relevant information about any of the trained centrality measures can be decoded. The proposed model achieves 89% accuracy on a test dataset of random instances with up to 128 vertices and is shown to generalise to larger problem sizes. The model is also shown to obtain reasonable accuracy on a dataset of real world instances with up to 4k vertices, vastly surpassing the sizes of the largest instances with which the model was trained (n = 128). Finally, we believe that our contributions attest to the potential of GNNs in symbolic domains in general and in relational learning in particular.






Deep learning is rapidly pushing the state of the art in artificial intelligence, from the huge successes of convolutional neural networks in image recognition [Krizhevsky, Sutskever, and Hinton2012, Simonyan and Zisserman2014, Li et al.2015] to the myriad of applications of recurrent neural networks for natural language processing [Cho et al.2014b, Cho et al.2014a, Bahdanau, Cho, and Bengio2014]. Deep learning has also played a fundamental role in unveiling the capabilities of machine learning in mastering a number of involved tasks, such as classic Atari games and the Chinese board game Go, by means of deep reinforcement learning [Mnih et al.2015, Silver et al.2017]. Nevertheless, limited attention has been given to the application of deep learning models in the symbolic domain. It is our belief that such inquiries are of the utmost importance, as they strive towards a unification of two long-separated branches of AI. Furthermore, the accumulating body of evidence in other fields is a strong invitation to evaluate whether symbolic problems, which are numerous and of central importance to computer science, can benefit from deep learning.

Graph neural networks have recently become a promising model in deep learning applications, see e.g. [Battaglia et al.2018]. In this sense, we will show that GNNs can be very naturally coupled with multitask learning applied to centrality measures. A promising technique for building neural networks on symbolic domains is to enforce permutation invariance by connecting adjacent elements of the domain of discourse through neural modules with shared weights which are themselves subject to training. By assembling these modules in different configurations one can reproduce each graph’s structure, in effect training neural components to compute the appropriate messages to send between elements. The resulting architecture can be seen as a message-passing algorithm where the messages and state updates are computed by trained neural networks. This model and its variants are the basis for several architectures such as message-passing neural networks [Gilmer et al.2017], recurrent relational networks [Palm, Paquet, and Winther2017], graph networks [Battaglia et al.2018] and graph neural networks [Scarselli et al.2009] whose terminology we adopt.

Graph Neural Networks (GNN) have been successfully employed on combinatorial domains, with [Palm, Paquet, and Winther2017] showing how they can tackle Sudoku puzzles and, most importantly, with [Selsam et al.2018] developing a GNN which is able to predict the satisfiability of CNF boolean formulas (corresponding to the NP-complete problem SAT) with high accuracy, and showing how constructive solutions in the format of boolean assignments can be extracted from the inner workings of the network. Both approaches have shown that these networks can generalise their computation over a larger number of time steps than they were trained on, showing that GNNs can not only learn from examples, but also reason about what they learned in an iterative fashion.

The remainder of the paper is structured as follows. First, we present the basic concepts of centrality measures used in this paper. We then introduce a GNN-based model for approximating and learning the relations between centralities in graphs, describe our experimental evaluation, and verify the model's generalisation and interpretability. Finally, we conclude and point out directions for further research.

On Centrality Measures

Recent studies have suggested that advancing combinatorial generalisation is a key step forward in modern AI [Battaglia et al.2018]. The results presented in this paper can be seen as a natural step towards this goal, presenting, to the best of our knowledge, the first application of GNNs to network centrality, a combinatorial problem with very relevant applications in our highly connected world, including the detection of power grid vulnerabilities [Wang, Scaglione, and Thomas2010, Liu et al.2018], influence inside interorganisational and collaboration networks [Chen et al.2017, Dong, McCarthy, and Schoenmakers2017], social network analysis [Morelli et al.2017, Kim and Hastak2018], and pattern recognition on biological networks [Tang et al.2015, Estrada and Ross2018], among others.

In general, node-level centralities summarise a node's contribution to the network cohesion. Several types of centralities have been proposed, and many models and interpretations of these centralities have been suggested, namely: autonomy, control, risk, exposure, influence, etc. [Borgatti and Everett2006]. Despite their myriad applications and interpretations, computing some of these centralities entails both high time and space complexity, making them costly to obtain on large networks. Although some studies pointed out a high degree of correlation between some of the most common centralities [Lee2006, Batool and Niazi2014], it has also been noted that these correlations are attached to the underlying network structure and thus may vary across different network distributions [Schoch et al.2017]. Therefore, techniques to allow faster centrality computation are topics of active research. We select four well-known node centralities to investigate in our study:

  1. Degree - First proposed by [Shaw1954], it simply counts how many neighbours a node is connected to. Computing it for all nodes has time complexity O(|V| + |E|).

  2. Betweenness - It calculates the number of shortest paths which pass through the given node. High betweenness nodes are more important to the graph's cohesion, i.e., their removal may disconnect the graph. The fast algorithm introduced by [Brandes2001] computes it in O(|V||E|) time on unweighted graphs.

  3. Closeness - As defined by [Beauchamp1965], it is also a distance-based centrality, with the same O(|V||E|) time complexity as betweenness, which measures the average geodesic distance between a given node and all other reachable nodes.

  4. Eigenvector - This centrality uses the eigenvector associated with the largest eigenvalue of the adjacency matrix [Bonacich1987] and assigns to each node a score based upon the scores of the nodes to which it is connected (assumption: a powerful node is connected to nodes that are themselves powerful [Wąs and Skibski2018]). It is computed via a power iteration method with no convergence guarantee, which stops after a given number of iterations or when the change between two iterations falls below a minimum delta.
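As a concrete illustration of these measures, the sketch below computes degree, closeness, and eigenvector centrality (the latter via the power iteration described above) on a toy four-vertex path graph; Brandes' betweenness algorithm is omitted for brevity. This is a minimal pure-Python sketch for exposition only: the experiments in this paper rely on NetworkX's implementations, and all function names here are our own.

```python
# Toy graph: an undirected path 0-1-2-3, given as an adjacency list.
from collections import deque

graph = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}

def degree(g):
    # 1. Degree: how many neighbours each node is connected to.
    return {v: len(nbrs) for v, nbrs in g.items()}

def bfs_dists(g, s):
    # Geodesic (shortest-path) distances from s via breadth-first search.
    dist, queue = {s: 0}, deque([s])
    while queue:
        u = queue.popleft()
        for w in g[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist

def closeness(g):
    # 3. Closeness: inverse of the average geodesic distance to all
    # reachable nodes (the usual normalised formulation).
    result = {}
    for v in g:
        d = bfs_dists(g, v)
        result[v] = (len(d) - 1) / sum(d.values())
    return result

def eigenvector(g, iters=50):
    # 4. Eigenvector: power iteration on the adjacency structure, stopping
    # after a fixed number of iterations (no convergence guarantee).
    x = {v: 1.0 for v in g}
    for _ in range(iters):
        x = {v: sum(x[w] for w in g[v]) for v in g}
        norm = max(x.values())
        x = {v: s / norm for v, s in x.items()}
    return x
```

On this path graph the two interior vertices come out most central under all three measures, matching the intuition behind the definitions above.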

A GNN Model for Learning Relations Between Centrality Measures

On a conceptual level, our model assigns multidimensional embeddings to each vertex in the input graph. These embeddings are refined through iterations of message-passing. At each iteration, each vertex adds up all the messages received along its incoming edges and all the messages received along its outgoing edges, obtaining two tensors, which are concatenated and fed to a Recurrent Neural Network (RNN) that updates the embedding of the vertex in question. Note that a “message” sent by a vertex embedding in this sense is the output of a Multilayer Perceptron (MLP) which is fed with the embedding in question.

In summary, our model can be seen as a message-passing algorithm in which the update and message-computing modules are trained neural networks. In addition, we train an MLP for each centrality c, which is assigned with computing the probability that v_1 >_c v_2 given their embeddings (here >_c denotes the total ordering imposed by the centrality measure c, that is, the node on the left has a strictly higher c-centrality than the one on the right). A complete description of our algorithm is presented in Algorithm 1.

1: procedure GNN-Centrality(G)
2:     Compute the adjacency matrix of G
3:     Initialise all vertex embeddings with the same initial embedding (this initial embedding is a parameter learned by the model)
4:     Run t_max message-passing iterations:
5:     for each iteration do: refine each vertex embedding with the messages received along its edges, either as a source or as a target vertex
6:     for each centrality c do: compute a fuzzy comparison matrix of pairwise comparison probabilities
7:     for each centrality c do: compute a strict comparison matrix from the fuzzy one
Algorithm 1 Graph Neural Network Centrality Predictor
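The message-passing loop of Algorithm 1 can be sketched numerically as below: a hedged toy version in which random weight matrices stand in for the trained message MLPs and the RNN update (the real model learns these parameters, uses 64-dimensional embeddings, and an LSTM update); all names are illustrative.

```python
# Hedged numerical sketch of Algorithm 1's message-passing loop. Random
# matrices stand in for the trained message MLPs and the RNN update.
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 8                                   # vertices, embedding size
M = np.zeros((n, n))                          # adjacency matrix (step 2)
for i, j in [(0, 1), (1, 2), (2, 3)]:
    M[i, j] = 1.0

x = np.tile(rng.normal(size=d), (n, 1))       # shared initial embedding (step 3)
W_src = rng.normal(size=(d, d)) / np.sqrt(d)      # stand-in "source message" MLP
W_tgt = rng.normal(size=(d, d)) / np.sqrt(d)      # stand-in "target message" MLP
W_upd = rng.normal(size=(2 * d, d)) / np.sqrt(2 * d)  # stand-in RNN update

for _ in range(3):                            # t_max message-passing iterations
    msg_src = np.maximum(x @ W_src, 0.0)      # messages each vertex sends as source
    msg_tgt = np.maximum(x @ W_tgt, 0.0)      # messages each vertex sends as target
    inbox = np.concatenate([M.T @ msg_src, M @ msg_tgt], axis=1)  # n x 2d
    x = np.tanh(inbox @ W_upd)                # embedding refinement (step 5)
```

After the loop, each row of `x` is a refined vertex embedding from which the comparison MLPs of steps 6-7 would decode pairwise centrality comparisons.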

For each pair of vertices and for each centrality c, our network guesses the probability that the first vertex is strictly more c-central than the second. A straightforward way to train such a network is to perform Stochastic Gradient Descent (SGD), more specifically using TensorFlow's Adam [Kingma and Ba2014] implementation, on the binary cross entropy loss between the probabilities computed by the network and the binary “labels” obtained from the total ordering provided by each centrality. This process can be made simple by organising the network outputs for each centrality, as well as the corresponding labels, into matrices, as Figure 1 exemplifies.

Figure 1: Example of a predicted fuzzy comparison matrix at the left and the training label given by an upper triangular matrix at the right, for a graph with three vertices sorted in ascending centrality order as given by the centrality measure c. The binary cross entropy is computed as the mean of −(y log p + (1 − y) log(1 − p)) over all matrix entries, where p is the predicted probability and y the corresponding binary label.
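A minimal sketch of this training signal, assuming the Figure 1 setup (three vertices sorted in ascending centrality, strictly upper-triangular labels) and made-up predicted probabilities:

```python
# Toy predicted fuzzy comparison matrix for three vertices sorted in
# ascending c-centrality; the values are made up for illustration.
import math

pred = [[0.1, 0.8, 0.9],
        [0.2, 0.1, 0.7],
        [0.1, 0.3, 0.2]]

# Training label: strictly upper-triangular, i.e. entry (i, j) is 1 exactly
# when vertex i is less c-central than vertex j.
label = [[1.0 if i < j else 0.0 for j in range(3)] for i in range(3)]

# Mean binary cross entropy over all matrix entries.
bce = -sum(y * math.log(p) + (1 - y) * math.log(1 - p)
           for p_row, y_row in zip(pred, label)
           for p, y in zip(p_row, y_row)) / 9
```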


We instantiate our model with 64-dimensional vertex embeddings and three-layered (64,64,64) MLPs for the message-computing and comparison modules (one comparison MLP per trained centrality measure), with Rectified Linear Units (ReLU) for all hidden non-linearities and a linear output layer.

We generate a training dataset by producing graphs with between 32 and 128 vertices for each of the four following random graph distributions: 1) Erdős-Rényi [Batagelj and Brandes2005]; 2) Random power law tree (a tree with a power law degree distribution specified by an exponent parameter); 3) Connected Watts-Strogatz small-world model [Watts and Strogatz1998]; 4) Holme-Kim model [Holme and Kim2002]. Further details are reported in Table 1. All graphs were generated with the Python NetworkX package [Hagberg, Swart, and S Chult2008]. Examples sampled from each distribution are shown in Figure 2.
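For illustration, here is a hypothetical pure-Python G(n, p) sampler in the spirit of the first distribution above; the actual datasets were produced with NetworkX's generators (e.g. erdos_renyi_graph, random_powerlaw_tree, connected_watts_strogatz_graph, powerlaw_cluster_graph), and this stand-in is only a sketch.

```python
# Hypothetical minimal Erdos-Renyi G(n, p) sampler (a sketch; the paper
# uses NetworkX's erdos_renyi_graph and related generators).
import random

def erdos_renyi(n, p, seed=None):
    """Return (n, edge_list) where each of the n*(n-1)/2 possible
    undirected edges is included independently with probability p."""
    rng = random.Random(seed)
    edges = [(i, j) for i in range(n) for j in range(i + 1, n)
             if rng.random() < p]
    return n, edges

n, edges = erdos_renyi(32, 0.25, seed=0)
```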

Table 1: Training instance generation parameters for each graph distribution.
Figure 2: Examples of training instances for each graph distribution, clockwise from the top left: Erdős-Rényi in red, Random power law tree in green, Holme-Kim in blue and Watts-Strogatz in yellow.

After 32 training epochs, the model was able to compute centrality comparisons (i.e. is vertex v_1 more central than vertex v_2?) with high accuracy (averaged over all centralities) for the problems it was trained on (32-128 vertices), 89% accuracy for a test dataset of the same size, somewhat lower accuracy for a test dataset of the same size composed of unforeseen distributions, and a modest decrease in accuracy on a test dataset of far larger test problems with up to four times more vertices than the largest training instances (128-512 vertices). The training was halted thereupon to prevent overfitting.

Experimental Analysis

In this section, we report the experiments we carried out to validate our model. The loss and accuracy of the training process for each centrality metric are reported in Figure 3, which also compares these values with those obtained by a model trained without multitasking (that is, trained to predict only the centrality metric in question).

Figure 3: Evolution of training loss and training accuracy per batch for all four centrality metrics throughout the training process. The loss is plotted in orange and blue and the accuracy in red and green for training with and without multitasking, respectively.

Performance metrics were computed for a test dataset matching the training one in instance size and quantity, i.e., a dataset composed of instances the model had never seen before (distributed evenly among all four graph distributions and generated as described in Table 1). Also, in order to verify the feasibility of multitasking in the centrality computation context, we compared the test performance of both types of trained models (one model with multitask learning versus four basic models, each trained to predict only one centrality). After training, our model can predict centrality comparisons with high performance, as reported in Table 2, obtaining its worst result in the closeness recall for both the models with and without multitask learning. The average accuracy, computed among all centralities, is the same for both models.

Although the multitasking model is outperformed by the basic model in many cases, the overall accuracy is unchanged (see Table 2) and the model has roughly half the number of parameters compared with having a separate model for each centrality. In this context, recall that the multitask learning model is required to develop a “lingua franca” of vertex embeddings from which information about any of the studied centralities can be easily extracted, so in a sense it is solving a harder problem. We also computed performance metrics for a test dataset with far larger instances, each with between one and four times the number of vertices of the largest training instances, and obtained good overall accuracy. This result shows that the model is able to generalise to larger problem sizes than those it was trained on, with only a slight decrease in accuracy.

Figure 4: Evolution of the 1D projection of vertex embeddings plotted against the corresponding eigenvector centralities through time for a graph sampled from the Watts-Strogatz small world distribution. Vertices are coloured according to their eigenvector centrality rank (the most central elements are blue and the least central ones are red)
Centrality P (%) R (%) TN (%) Acc (%)
Table 2: Performance metrics (Precision, Recall, True Negative rate, Accuracy) computed for the trained models on the test dataset (with/without multitasking). The dataset contains instances with between 32 and 128 vertices.

Generalising to Other Distributions

Having obtained good performance for the distributions the network was trained on, we wanted to assess the possibility of accurately predicting centrality comparisons for graph distributions it has never seen. This was done by computing performance metrics for two new random graph distributions, the Barabási-Albert model [Albert and Barabási2002] and shell graphs (generated here with the number of points on each shell proportional to the “radius” of that shell) [Sethuraman and Dhavamani2000], for which the results are reported in Table 3. Although its accuracy is reduced in comparison, the model can still predict centrality comparisons with high performance, obtaining its worst result at recall for the degree centrality. Again, the model without multitasking outperforms the multitasking one only by a narrow margin (2% in overall accuracy).
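For reference, below is a hypothetical minimal preferential-attachment sampler in the spirit of the Barabási-Albert model used in this test set; NetworkX's barabasi_albert_graph is the standard implementation, and the name and details of this pure-Python stand-in are our own.

```python
# Hypothetical sketch of Barabasi-Albert preferential attachment: each new
# vertex attaches to m existing vertices with probability proportional to
# their degree (NetworkX's barabasi_albert_graph is the real generator).
import random

def barabasi_albert(n, m, seed=None):
    rng = random.Random(seed)
    targets = list(range(m))   # the first m vertices seed the process
    repeated = []              # one copy of each vertex per unit of degree
    edges = []
    for v in range(m, n):
        edges.extend((v, t) for t in targets)   # attach v to m targets
        repeated.extend(targets)
        repeated.extend([v] * m)
        targets = []
        while len(targets) < m:                 # m distinct, degree-biased picks
            t = rng.choice(repeated)
            if t not in targets:
                targets.append(t)
    return n, edges

n_nodes, ba_edges = barabasi_albert(20, 2, seed=0)
```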

Centrality P (%) R (%) TN (%) Acc (%)
Table 3: Performance metrics (Precision, Recall, True Negative rate, Accuracy) computed for the trained model on a test dataset composed of Barabási-Albert networks and shell graphs between and vertices.

We also wanted to assess the model's performance on real world instances. We ran it on power-eris1176, a power grid network; econ-mahindas, an economic network; socfb-haverford76 and ego-Facebook, Facebook networks; bio-SC-GT, a biological network; and ca-GrQc, a scientific collaboration network. All networks were obtained from the Network Repository [Rossi and Ahmed], and the results are reported in Table 4. The trained model obtained its best accuracies (on both betweenness and degree) and its best average accuracy on socfb-haverford76, and its worst accuracy (on closeness) and worst average accuracy on ca-GrQc.

Note that these networks significantly surpass the size range that the network has been trained on, exceeding the size of the largest (n = 128) networks it has seen during training by a factor of 9 to 31, while also pertaining to entirely different graph distributions than those described in Table 1. In this context, we found it impressive that the model can predict betweenness centrality with good accuracy (both with and without multitasking) on a large graph such as ca-GrQc, a network with four thousand vertices and over fourteen thousand edges. It is also notable that one of the worst performances occurs on the smallest real network (power-eris1176): an overall accuracy below 70% for both models. This can perhaps be explained by [Hines and Blumsack2008, Hines et al.2010], who highlighted the significant topological differences between power grid networks and the Erdős-Rényi and Watts-Strogatz small-world models (two of the models used to train the network).

In short, our multitask model's accuracy presents an expected decay with increasing problem sizes (see Figure 5). However, this decay is not a free fall towards 50%: even for instances with almost twice the size of the ones used to train the model, the overall accuracy remains around 77%, which implies that some level of generalisation (to larger problem sizes) is achievable.

Figure 5: The overall accuracy decays with increasing problem sizes, although it still does not approach the baseline of 50% (equivalent to randomly guessing each vertex-to-vertex centrality comparison) for the largest instances tested here. The dotted lines delimit the range of problem sizes used to train the network (32-128 vertices).
Figure 6: Evolution of the 2D projection of vertex embeddings through time for a graph sampled from the power law cluster distribution. Vertices are coloured according to their eigenvector centrality rank (the most central elements are blue and the least central ones are red)


Graph Centrality Accuracy (%)
power-eris1176 (n=1.2K, m=8.7K) Betweenness
econ-mahindas (n=1.3K, m=7.6K) Betweenness
socfb-haverford76 (n=1.4K, m=59.6K) Betweenness
bio-SC-GT (n=1.7K, m=34K) Betweenness
ca-GrQc (n=4K, m=14.4K) Betweenness
ego-Facebook (n=4K, m=88.2K) Betweenness
Table 4: Accuracy performance during test on real world graphs (Model with/without multitasking)

Machine learning has achieved impressive feats in recent years, but the interpretability of the computation that takes place within trained models is still limited [Breiman and others2001, Lipton2016]. [Selsam et al.2018] have shown that it is possible to extract useful information from the embeddings of CNF literals, which they manipulate to obtain satisfying assignments from the model (trained only as a classifier). This allowed them to deduce that NeuroSAT works by guessing UNSAT as a default and changing its prediction to SAT only upon finding a satisfying assignment.

In our case, we can obtain insights about the learned algorithm by projecting the refined set of embeddings onto one-dimensional space by means of Principal Component Analysis (PCA) [Jolliffe2011] and plotting the projections against the centrality values of the corresponding vertices. Figure 4 shows the evolution of this plot through embedding-refining iterations, from which we can infer some aspects of the learned algorithm. First of all, the zeroth step is suppressed for space limitations, but because all embeddings start out the same way, it corresponds to a single vertical line. At the second step, the network is able to sort the vertices into five distinct classes, placing low rank vertices at one extreme and high rank vertices at the opposite. This is not sufficient to yield satisfactory accuracy, though, as each vertical line corresponds to a wide range of vertices whose embeddings (being very similar) the network cannot compare. As the solution process progresses, the network progressively manipulates each individual embedding to produce a correlation between the centrality values and the projected values, which can be visualised here as a reordering of data points along the horizontal axis. Further insight can be gained by projecting embeddings onto bidimensional space and plotting them alongside vertex connections, as shown in Figure 6. Here one can see that progress in accuracy is accompanied by a successful separation of high centrality vertices (at the bottom right) from low centrality vertices (top left).
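The 1D projection underlying this analysis can be sketched as follows, assuming standard PCA; here we reimplement it with a NumPy SVD on random stand-in embeddings rather than the trained model's refined ones.

```python
# Hedged sketch of the 1D PCA projection of vertex embeddings, computed via
# SVD on mean-centred data (random stand-ins replace trained embeddings).
import numpy as np

rng = np.random.default_rng(1)
emb = rng.normal(size=(10, 64))       # 10 vertices, 64-dimensional embeddings

centred = emb - emb.mean(axis=0)      # PCA operates on mean-centred data
_, _, vt = np.linalg.svd(centred, full_matrices=False)
proj_1d = centred @ vt[0]             # scores along the first principal axis
```

Plotting `proj_1d` against each vertex's centrality value reproduces the kind of scatter shown in Figure 4.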

The cases shown here, however, are not universal, and vary somewhat depending on the distribution from which the graph was drawn. Graphs sampled from the power law tree distribution, for example, seem closer to exponential in nature when comparing the log-centrality value and the normalised 1-dimensional PCA value, but most of the distributions trained on behaved similarly, producing a roughly linear relationship between the logarithm of the centrality and the normalised PCA values. However, even in the cases where the centrality model did not achieve a high accuracy, we can still look at the PCA values and see whether they yield a somewhat sensible answer to the problem. Thus, the embeddings generated by the network can be seen as the GNN trying to create a centrality measure of its own, with parts of, or the whole, embedding being correlated with those centralities with which the network was trained.

Reproducibility and Implementation Notes

Reproducibility in the field of machine learning may be difficult to achieve due to the plethora of hyperparameters and random initialisation values involved. Thus, we aim to facilitate the reproducibility of our paper by offering implementation notes in which we do our best to report all non-intuitive parametric and architectural decisions needed to produce a functioning model.

The embedding size was chosen as d = 64; all message-passing MLPs are three-layered with layer sizes (64,64,64) and ReLU non-linearities as the activation of all layers except the last one, which has a linear activation. The kernel weights are initialised with TensorFlow's Xavier initialisation method described in [Glorot and Bengio2010] and the biases are initialised with zeroes. The recurrent unit assigned with updating embeddings is a layer-norm LSTM [Ba, Kiros, and Hinton2016] with ReLU as its activation and with both kernel weights and biases initialised with TensorFlow's Glorot Uniform Initialiser [Glorot and Bengio2010], with the addition that the forget gate bias had 1 added to it. The number of message-passing timesteps is set to a fixed t_max. The comparison MLPs are likewise three-layered and were initialised in the same way as the message MLPs. We tried regressing the centrality measures directly, but found that producing comparisons yielded better performance.

Each training epoch is composed of SGD operations on randomly sampled batches from the training dataset (the sampling inhibited duplicates within an epoch, but duplicates were allowed across epochs). We produced graphs with between 32 and 128 nodes for every training distribution. If any error occurred during graph generation or centrality calculation, the graph was discarded and its generation was restarted. We also generated new instances for the same parameters and kept these as a validation set. The test sets for bigger sizes and for different distributions were generated analogously, the former with between 128 and 512 vertices. Instances were batched together by performing a disjoint union on the graphs, producing a single graph with every instance as a disjoint subgraph of the batch-graph; in this way messages from one graph are not passed to another, effectively separating them.
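The disjoint-union batching described above can be sketched as follows (a hypothetical helper over plain edge lists, not the paper's actual implementation):

```python
# Hypothetical helper illustrating disjoint-union batching: vertex indices
# are offset so the batch is one block-diagonal graph, and no message can
# cross between instances during message-passing.
def disjoint_union(graphs):
    """graphs: list of (num_vertices, edge_list) pairs."""
    offset, edges = 0, []
    for n, es in graphs:
        edges.extend((u + offset, v + offset) for u, v in es)
        offset += n
    return offset, edges

batch_n, batch_edges = disjoint_union([(3, [(0, 1), (1, 2)]),
                                       (2, [(0, 1)])])
```

Here a 3-vertex path and a single edge become one 5-vertex batch-graph whose second component is relabelled to vertices 3 and 4.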


The application of deep learning to symbolic domains remains a challenging endeavour. In this paper, we demonstrated how to train a neural network to predict graph centrality measures while feeding it with only the raw network structure. In order to do so, we enforced permutation invariance among graph elements by engineering a message-passing algorithm composed of neural modules with shared weights. These modules can be assembled in different configurations to reflect the network structure of each problem instance. We show that the proposed model can be trained to predict centrality comparisons (i.e. is vertex v_1 more central than vertex v_2 given the centrality measure c?) with high accuracy, and further that this performance generalises reasonably well to other problem distributions and larger problem sizes. We also show that the model shows promising performance for very large real world instances, which exceed the largest instances known at training time by a factor of 9 to 31 (4,000 as opposed to 128 vertices).

We also show that although our model can be instantiated separately for each centrality measure, it can also be trained to predict all centralities simultaneously, with minimal effect on the overall accuracy. In a nutshell, this means that upon training, the model is able to encode all useful information about any trained centrality into the multidimensional vertex embeddings which are iteratively refined by the message-passing process. We then use a different MLP to decode them into predictions for each such centrality. To shed light on the behaviour of the algorithm learned by the network, we interpret the low-dimensional PCA projections of each vertex embedding, and argue that the model iteratively reorders them in multidimensional space to enforce a correlation with the corresponding centrality values.

In summary, this work presents, to the best of our knowledge, the first application of Graph Neural Networks to centrality measures. We obtain an effective model and provide ways to have such a model work with various centralities at once, in a more memory-efficient way than having a different model for every centrality, with minimal loss in performance. Finally, our work attests to the power of relational inductive bias in neural networks, allowing them to tackle graph-based problems, and also shows how the proposed model can be used to provide a network that condenses multiple kinds of information about a graph into a single embedding.


This study was financed in part by the Coordenação de Aperfeiçoamento de Pessoal de Nível Superior - Brasil (CAPES) - Finance Code 001 and the Conselho Nacional de Desenvolvimento Científico e Tecnológico (CNPq).