Learning Discrete Structures for Graph Neural Networks

03/28/2019 ∙ by Luca Franceschi, et al. ∙ Istituto Italiano di Tecnologia, NEC Corp.

Graph neural networks (GNNs) are a popular class of machine learning models whose major advantage is their ability to incorporate a sparse and discrete dependency structure between data points. Unfortunately, GNNs can only be used when such a graph-structure is available. In practice, however, real-world graphs are often noisy and incomplete or might not be available at all. With this work, we propose to jointly learn the graph structure and the parameters of graph convolutional networks (GCNs) by approximately solving a bilevel program that learns a discrete probability distribution on the edges of the graph. This allows one to apply GCNs not only in scenarios where the given graph is incomplete or corrupted but also in those where a graph is not available. We conduct a series of experiments that analyze the behavior of the proposed method and demonstrate that it outperforms related methods by a significant margin.




1 Introduction

Relational learning is concerned with methods that can leverage not only the attributes of data points but also their relationships. Diagnosing a patient, for example, not only depends on the patient’s vitals and demographic information but also on the same information about their relatives, the information about the hospitals they have visited, and so on. Relational learning, therefore, does not make the assumption of independence between data points but models their dependency explicitly. Graphs are a natural way to represent relational information and there is a large number of machine learning algorithms leveraging graph structure. Graph neural networks (GNNs) (Scarselli et al., 2009) are one such class of algorithms that are able to incorporate sparse and discrete dependency structures between data points.

While a graph structure is available in some domains, in others it has to be inferred or constructed. A possible approach is to first create a $k$-nearest neighbor (kNN) graph based on some measure of similarity between data points. This is a common strategy used by several learning methods such as LLE (Roweis & Saul, 2000) and Isomap (Tenenbaum et al., 2000). A major shortcoming of this approach, however, is that the efficacy of the resulting models hinges on the choice of $k$ and, more importantly, on the choice of a suitable similarity measure over the input features. In any case, the graph creation and parameter learning steps are independent and require heuristics and trial and error. Alternatively, one could simply use a kernel matrix to model the similarity of examples implicitly at the cost, however, of introducing a dense dependency structure which may be problematic from a computational point of view.

With this paper, we follow a different route with the aim of learning discrete and sparse dependencies between data points while simultaneously training the parameters of graph convolutional networks (GCN), a class of GNNs. Intuitively, GCNs learn node representations by passing and aggregating messages between neighboring nodes (Kipf & Welling, 2017; Monti et al., 2017; Gilmer et al., 2017; Hamilton et al., 2017; Duran & Niepert, 2017; Velickovic et al., 2018). We propose to learn a generative probabilistic model for graphs, samples from which are used both during training and at prediction time. Edges are modelled with random variables whose parameters are treated as hyperparameters in a bilevel learning framework (Franceschi et al., 2018). We iteratively sample the structure while minimizing an inner objective (a training error) and optimize the edge distribution parameters by minimizing an outer objective (a validation error).

To the best of our knowledge, this is the first method that simultaneously learns the graph and the parameters of a GNN for semi-supervised classification. Moreover, and this might be of independent interest, we adapt gradient-based hyperparameter optimization to work for a class of discrete hyperparameters (edges, in this work). The proposed approach makes GNNs applicable to problems where the graph is incomplete or entirely missing. We conduct a series of experiments and show that the proposed method is competitive with and often outperforms existing approaches. We also verify that the resulting graph generative models have meaningful edge probabilities.

2 Background

We first provide some background on graph theory, graph neural networks, and bilevel programming.

2.1 Graph Theory Basics

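A graph $G$ is a pair $(V, E)$ with $V = \{v_1, \dots, v_N\}$ the set of vertices and $E \subseteq V \times V$ the set of edges. Let $N$ be the number of vertices and $M$ the number of edges. Each graph can be represented by an adjacency matrix $A$ of size $N \times N$, where $A_{i,j} = 1$ if there is an edge from vertex $v_i$ to vertex $v_j$, and $A_{i,j} = 0$ otherwise. The graph Laplacian is defined by $L = D - A$ where $D_{i,i} = \sum_j A_{i,j}$ and $D_{i,j} = 0$ if $i \neq j$. We denote the set of all $N \times N$ adjacency matrices by $\mathcal{H}_N$.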
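As a small illustration of these definitions, the following NumPy sketch builds an adjacency matrix from an edge list and forms the Laplacian $L = D - A$. The function names and the three-node path graph are hypothetical and chosen only for this example.

```python
import numpy as np

def adjacency_from_edges(num_nodes, edges):
    """Build a dense 0/1 adjacency matrix from (source, target) index pairs."""
    A = np.zeros((num_nodes, num_nodes), dtype=float)
    for i, j in edges:
        A[i, j] = 1.0
    return A

def graph_laplacian(A):
    """Unnormalized Laplacian L = D - A, with D the diagonal matrix of row sums."""
    D = np.diag(A.sum(axis=1))
    return D - A

# Hypothetical undirected path 0 - 1 - 2; each edge is listed in both directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
A = adjacency_from_edges(3, edges)
L = graph_laplacian(A)
```

Each row of `L` sums to zero by construction, which is a quick sanity check when assembling Laplacians by hand.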

2.2 Graph Neural Networks

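Graph neural networks are a popular class of machine learning models for graph-structured data. We focus specifically on graph convolutional networks (GCNs) and their application to semi-supervised learning. All GNNs have the same two inputs. First, a feature matrix $X \in \mathcal{X}_N \subset \mathbb{R}^{N \times n}$, where $n$ is the number of different node features; second, a graph $G = (V, E)$ with adjacency matrix $A \in \mathcal{H}_N$. Given a set of class labels $\mathcal{Y}$ and a labeling function $y : V \to \mathcal{Y}$ that maps (a subset of) the nodes to their true class label, the objective is, given a set of training nodes $V_{\mathrm{Train}}$, to learn a function $f_w : \mathcal{X}_N \times \mathcal{H}_N \to \mathcal{Y}^N$ by minimizing some regularized empirical loss

$$ L(w, A) = \sum_{v \in V_{\mathrm{Train}}} \ell\big(f_w(X, A)_v, y_v\big) + \Omega(w), \tag{1} $$

where $w \in \mathbb{R}^d$ are the parameters of $f_w$, $f_w(X, A)_v$ is the output of $f_w$ for node $v$, $\ell : \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}^+$ is a point-wise loss function, and $\Omega$ is a regularizer. An example of the function $f_w$ proposed by Kipf & Welling (2017) is the following two-layer GCN that computes the class probabilities as

$$ f_w(X, A) = \mathrm{Softmax}\big(\hat{A}\, \mathrm{ReLU}(\hat{A} X W_1)\, W_2\big), \tag{2} $$

where $w = (W_1, W_2)$ are the parameters of the GCN and $\hat{A}$ is the normalized adjacency matrix, given by $\hat{A} = \tilde{D}^{-1/2} (A + I) \tilde{D}^{-1/2}$ with $\tilde{D}_{i,i} = 1 + \sum_j A_{i,j}$.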
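The two-layer GCN of Kipf & Welling (2017) can be sketched in a few lines of NumPy. This is a minimal dense forward pass only, assuming the standard normalization with self-loops; the function names, shapes, and random weights are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def normalize_adjacency(A):
    """Return D~^{-1/2} (A + I) D~^{-1/2} with D~_ii = 1 + sum_j A_ij."""
    A_tilde = A + np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))
    return A_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def softmax(Z):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)

def gcn_forward(X, A, W1, W2):
    """Two-layer GCN: Softmax(A_hat ReLU(A_hat X W1) W2)."""
    A_hat = normalize_adjacency(A)
    H = np.maximum(A_hat @ X @ W1, 0.0)  # ReLU hidden layer
    return softmax(A_hat @ H @ W2)

# Hypothetical sizes: 4 nodes, 3 features, 5 hidden units, 2 classes.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))
A = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
W1 = rng.normal(size=(3, 5))
W2 = rng.normal(size=(5, 2))
probs = gcn_forward(X, A, W1, W2)  # one row of class probabilities per node
```

With an empty graph the normalized adjacency reduces to the identity, so the model becomes a plain feed-forward network applied to each node independently.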

2.3 Bilevel Programming in Machine Learning

Bilevel programs are optimization problems where a set of variables occurring in the objective function are constrained to be an optimal solution of another optimization problem (see Colson et al., 2007, for an overview). Formally, given two objective functions $F$ and $L$, the outer and inner objectives, and two sets of variables, $\theta \in \mathbb{R}^m$ and $w \in \mathbb{R}^d$, the outer and inner variables, a bilevel program is given by

$$ \min_{\theta, w_\theta} F(w_\theta, \theta) \quad \text{such that} \quad w_\theta \in \arg\min_w L(w, \theta). \tag{3} $$

Bilevel programs arise in numerous situations such as hyperparameter optimization, adversarial, multi-task, and meta-learning (Bennett et al., 2006; Flamary et al., 2014; Muñoz-González et al., 2017; Franceschi et al., 2018). For instance, in hyperparameter optimization, the inner objective can be a regularized training error while the outer objective can be the corresponding unregularized validation error. $\theta$ would then be the coefficient (hyperparameter) of the regularizer and $w$ the parameters of the model.

Solving Problem (3) is challenging since the solution sets of the inner problem are usually not available in closed form. A standard approach involves replacing the minimization of $L$ with the repeated application of an iterative optimization dynamics $\Phi$ such as (stochastic) gradient descent (Domke, 2012; Maclaurin et al., 2015; Franceschi et al., 2017). Let $w_{\theta,T}$ denote the inner variables after $T$ iterations of the dynamics $\Phi$, that is, $w_{\theta,T} = \Phi(w_{\theta,T-1}, \theta) = \Phi(\Phi(w_{\theta,T-2}, \theta), \theta)$, and so on. Now, if $\theta$ and $w$ are real-valued and the objectives and dynamics smooth, we can compute the gradient of the function $F(w_{\theta,T}, \theta)$ w.r.t. $\theta$, denoted throughout as the hypergradient $\nabla_\theta F(w_{\theta,T}, \theta)$, as

$$ \nabla_\theta F(w_{\theta,T}, \theta) = \partial_w F(w_{\theta,T}, \theta)\, \nabla_\theta w_{\theta,T} + \partial_\theta F(w_{\theta,T}, \theta), \tag{4} $$

where the symbol $\partial$ denotes the partial derivative (the Jacobian) and $\nabla$ either the gradient (for scalar functions) or the total derivative. The first term can be computed efficiently in time $O(T(d + m))$ with reverse-mode algorithmic differentiation (Griewank & Walther, 2008) by unrolling the optimization dynamics, repeatedly substituting $w_{\theta,t+1} = \Phi(w_{\theta,t}, \theta)$ and applying the chain rule. This technique makes it possible to optimize a number of hyperparameters several orders of magnitude greater than classic methods for hyperparameter optimization (Feurer & Hutter, 2018).
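To make the unrolling concrete, the sketch below computes a hypergradient by reverse-mode differentiation through $T$ gradient-descent steps. It is a toy setting assumed purely for illustration: the inner objective is ridge regression, the single hyperparameter `theta` is the regularization coefficient, and the outer objective is an unregularized validation error (so the second term of Eq. (4) vanishes). All names and data are hypothetical.

```python
import numpy as np

def inner_step(w, theta, X, y, lr):
    """One gradient-descent step Phi(w, theta) on 0.5*||Xw - y||^2 + 0.5*theta*||w||^2."""
    return w - lr * (X.T @ (X @ w - y) + theta * w)

def outer_loss(w, X_val, y_val):
    """Unregularized validation error F(w)."""
    return 0.5 * np.sum((X_val @ w - y_val) ** 2)

def unroll(theta, X, y, lr, T):
    """Run T inner steps from w_0 = 0 and keep the whole trajectory."""
    ws = [np.zeros(X.shape[1])]
    for _ in range(T):
        ws.append(inner_step(ws[-1], theta, X, y, lr))
    return ws

def hypergradient(theta, X, y, X_val, y_val, lr, T):
    """Reverse-mode hypergradient dF(w_T)/dtheta via the unrolled dynamics."""
    ws = unroll(theta, X, y, lr, T)
    d = X.shape[1]
    alpha = X_val.T @ (X_val @ ws[-1] - y_val)          # dF/dw at w_T
    J_w = np.eye(d) - lr * (X.T @ X + theta * np.eye(d))  # dPhi/dw (constant here)
    grad = 0.0
    for t in range(T - 1, -1, -1):
        grad += alpha @ (-lr * ws[t])  # alpha^T dPhi/dtheta evaluated at w_t
        alpha = J_w.T @ alpha          # propagate the adjoint one step back
    return grad

rng = np.random.default_rng(0)
X, y = rng.normal(size=(8, 3)), rng.normal(size=8)
X_val, y_val = rng.normal(size=(5, 3)), rng.normal(size=5)
g = hypergradient(0.5, X, y, X_val, y_val, lr=0.01, T=50)
```

A finite-difference check on `outer_loss(unroll(theta, ...)[-1], ...)` is a cheap way to validate such a routine before applying it to many hyperparameters.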

Figure 1: Schematic representation of our approach for learning discrete graph structures for GNNs.

3 Learning Discrete Graph Structures

With this paper we address the challenging scenarios where a graph structure is either completely missing, incomplete, or noisy. To this end, we learn a discrete and sparse dependency structure between data points while simultaneously training the parameters of a GCN. We frame this as a bilevel programming problem whose outer variables are the parameters of a generative probabilistic model for graphs. The proposed approach, therefore, optimizes both the parameters of a GCN and the parameters of a graph generator so as to minimize the classification error on a given dataset. We developed a practical algorithm based on truncated reverse-mode algorithmic differentiation (Williams & Peng, 1990) and hypergradient estimation to approximately solve the bilevel problem. A schematic illustration of the resulting method is presented in Figure 1.

3.1 Jointly Learning the Structure and Parameters

Let us suppose that information about the true adjacency matrix $A$ is missing or incomplete. Since, ultimately, we are interested in finding a model that minimizes the generalization error, we assume the existence of a second subset of instances with known target, $V_{\mathrm{Val}}$ (the validation set), from which we can estimate the generalization error. Hence, we propose to find $A \in \mathcal{H}_N$ that minimizes the function

$$ F(w_A, A) = \sum_{v \in V_{\mathrm{Val}}} \ell\big(f_{w_A}(X, A)_v, y_v\big), \tag{5} $$

where $w_A$ is the minimizer, assumed unique, of $L$ (see Eq. (1) and Sec. 2.3) for a fixed adjacency matrix $A$. We can then consider Equations (1) and (5) as the inner and outer objective of a mixed-integer bilevel programming problem where the outer objective aims to find an optimal discrete graph structure and the inner objective the optimal parameters of a GCN given a graph.

The resulting bilevel problem is intractable to solve exactly even for small graphs. Moreover, it contains both continuous and discrete-valued variables, which prevents us from directly applying Eq. (4). A possible solution is to construct a continuous relaxation (see e.g. Frecon et al., 2018); another is to work with parameters of a probability distribution over graphs. The latter is the approach we follow in this paper. We maintain a generative model for the graph structure and reformulate the bilevel program in terms of the (continuous) parameters of the resulting distribution over discrete graphs. Specifically, we propose to model each edge with a Bernoulli random variable. Let $\overline{\mathcal{H}}_N = \mathrm{Conv}(\mathcal{H}_N)$ be the convex hull of the set of all adjacency matrices for $N$ nodes. By modeling all possible edges as a set of mutually independent Bernoulli random variables with parameter matrix $\theta \in \overline{\mathcal{H}}_N$ we can sample graphs as $\mathcal{H}_N \ni A \sim \mathrm{Ber}(\theta)$. Eqs. (1) and (5) can then be replaced by using the expectation over graph structures. The resulting bilevel problem can be written as

$$ \min_{\theta \in \overline{\mathcal{H}}_N} \mathbb{E}_{A \sim \mathrm{Ber}(\theta)}\big[F(w_\theta, A)\big] \tag{6} $$

$$ \text{such that} \quad w_\theta = \arg\min_w \mathbb{E}_{A \sim \mathrm{Ber}(\theta)}\big[L(w, A)\big]. \tag{7} $$

By taking the expectation, both the inner and the outer objectives become continuous (and possibly smooth) functions of the Bernoulli parameters. The bilevel problem given by Eqs. (6)-(7) is still challenging to solve efficiently. This is because the solution of the inner problem is not available in closed form for GCNs (the objective is non-convex), and the expectations are intractable to compute exactly. (This differs from, e.g., model-free reinforcement learning, where the objective function is usually unknown and depends in an unknown way on the action and the environment.) An efficient algorithm, therefore, will only be able to find approximate stochastic solutions, that is, $\theta_{i,j} \in (0, 1)$.

Before describing a method to solve the optimization problem given by Eqs. (6)-(7) approximately with hypergradient descent, we first turn to the question of obtaining a final GCN model that we can use for prediction. For a given distribution $P_\theta$ over graphs with $N$ nodes and with parameters $\theta$, the expected output of a GCN is

$$ f_w^{\mathrm{exp}}(X) = \mathbb{E}_A\big[f_w(X, A)\big] = \sum_{A \in \mathcal{H}_N} P_\theta(A)\, f_w(X, A). \tag{8} $$

Unfortunately, computing this expectation is intractable even for small graphs; we can, however, compute an empirical estimate of the output as

$$ \hat{f}_w(X) = \frac{1}{S} \sum_{i=1}^{S} f_w(X, A_i), \tag{9} $$

where $S > 0$ is the number of samples we wish to draw. Note that $\hat{f}_w$ is an unbiased estimator of $f_w^{\mathrm{exp}}$. Hence, to use a GCN learned with the bilevel formulation for prediction, we sample $S$ graphs from the distribution $P_\theta$ and compute the prediction as the empirical mean of the values of $f_w$.

Given the parametrization of the graph generator with Bernoulli variables ($P_\theta = \mathrm{Ber}(\theta)$), one can sample a new graph in $O(N^2)$. Sampling from a large number of Bernoulli variables, however, is highly efficient, trivially parallelizable, and possible at a rate of millions per second. Other sampling strategies such as MCMC sampling are possible in constant time. Given a set of sampled graphs, it is more efficient to evaluate a sparse GCN $S$ times than to use the Bernoulli parameters as weights of the GCN’s adjacency matrix. (Note also that $\mathbb{E}\, f_w(X, A) \neq f_w(X, \mathbb{E} A) = f_w(X, \theta)$, as the model is, in general, nonlinear.) Indeed, for GCN models, computing $\hat{f}_w$ has a cost of $O(SCd)$, rather than $O(N^2 d)$ for a fully connected graph, where $C = \sum_{i,j} \theta_{i,j}$ is the expected number of edges for a directed graph, and $d$ is the dimension of the weights. Another advantage of using a graph-generative model is that we can interpret it probabilistically, which is not the case when learning a dense adjacency matrix.
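The sampling and averaging steps can be sketched as follows. This is a minimal illustration, assuming a dense parameter matrix and an arbitrary callable `model(X, A)`; the `toy_model` used here is a hypothetical nonlinear stand-in for a trained GCN, not the model used in the paper.

```python
import numpy as np

def sample_graph(theta, rng):
    """Draw A ~ Ber(theta): entry (i, j) is 1 with probability theta[i, j]."""
    return (rng.random(theta.shape) < theta).astype(float)

def empirical_output(model, X, theta, num_samples, rng):
    """Average model(X, A) over num_samples graphs sampled from Ber(theta)."""
    outputs = [model(X, sample_graph(theta, rng)) for _ in range(num_samples)]
    return np.mean(outputs, axis=0)

def toy_model(X, A):
    """Hypothetical nonlinear model: one propagation step followed by tanh."""
    return np.tanh(A @ X)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))
theta = np.full((5, 5), 0.3)

averaged = empirical_output(toy_model, X, theta, num_samples=16, rng=rng)
plugged_in = toy_model(X, theta)  # uses theta as dense edge weights instead
```

Because the model is nonlinear, `averaged` and `plugged_in` generally differ, which is the distinction between the expected output and the output at the expected graph.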

3.2 Structure Learning via Hypergradient Descent

The bilevel programming formalism is a natural fit for the problem of learning both a graph generative model and the parameters of a GNN for a specific downstream task. Here, the outer variables are the parameters of the graph generative model and the inner variables are the parameters of the GCN.

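We now discuss a practical algorithm to approach the bilevel problem defined by Eqs. (6) and (7). Regarding the inner problem, we note that the expectation

$$ \mathbb{E}_{A \sim \mathrm{Ber}(\theta)}\big[L(w, A)\big] = \sum_{A \in \mathcal{H}_N} P_\theta(A)\, L(w, A) \tag{10} $$

is composed of a sum of $2^{N^2}$ terms, which is intractable even for relatively small graphs. We can, however, choose a tractable approximate learning dynamics $\Phi$ such as stochastic gradient descent (SGD),

$$ w_{\theta,t+1} = \Phi(w_{\theta,t}, A_t) = w_{\theta,t} - \gamma_t \nabla L(w_{\theta,t}, A_t), \tag{11} $$

where $\gamma_t$ is a learning rate and $A_t \sim \mathrm{Ber}(\theta)$ is drawn at each iteration. Under appropriate assumptions and for $t \to \infty$, SGD converges to a weight vector $w_\theta$ that depends on the edges’ probability distribution (Bottou, 2010).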

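Let $w_{\theta,T}$ be an approximate minimizer of $\mathbb{E}[L]$ (where $T$ may depend on $\theta$). We now need to compute an estimator for the hypergradient $\nabla_\theta \mathbb{E}_{A \sim \mathrm{Ber}(\theta)}[F]$. Recalling Eq. (4), we have

$$ \nabla_\theta \mathbb{E}\big[F(w_{\theta,T}, A)\big] = \mathbb{E}\big[\nabla_\theta F(w_{\theta,T}, A)\big] = \mathbb{E}\big[\partial_w F(w_{\theta,T}, A)\, \nabla_\theta w_{\theta,T} + \partial_A F(w_{\theta,T}, A)\, \nabla_\theta A\big], \tag{12} $$

where we can swap the gradient and expectation operators since the expectation is over a finite random variable, assuming that the loss function is bounded. We use the so-called straight-through estimator (Bengio et al., 2013) and set $\nabla_\theta A := I$ (which would be $0$ a.e. otherwise); $\nabla_\theta A$ appears both explicitly in (12) and in the computation of $\nabla_\theta w_{\theta,T}$, through $\partial_\theta \Phi(w_{\theta,t}, A_t)$ (see Sec. 2.3 and Franceschi et al., 2017, for details). Finally, we take the single-sample Monte Carlo estimator of (12) to update the parameters $\theta$ with projected gradient descent on the unit hypercube.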
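The combination of a single-sample estimate, the straight-through estimator, and projection onto the unit hypercube can be sketched on a toy problem. In this illustration the outer objective is assumed to be $F(A) = \frac{1}{2}\lVert A - A^\star \rVert^2$ for a hypothetical target graph `A_target`, so that the gradient w.r.t. the sampled graph is available in closed form; in LDS the corresponding gradient would instead come from differentiating through the GCN.

```python
import numpy as np

def project_unit_hypercube(theta):
    """Euclidean projection onto [0, 1]^(N x N), i.e. element-wise clipping."""
    return np.clip(theta, 0.0, 1.0)

def straight_through_update(theta, grad_wrt_A, step_size):
    """Treat the gradient w.r.t. the sampled A as the gradient w.r.t. theta
    (straight-through estimator), then take a projected gradient step."""
    return project_unit_hypercube(theta - step_size * grad_wrt_A)

rng = np.random.default_rng(0)
A_target = np.array([[0.0, 1.0], [1.0, 0.0]])  # hypothetical "good" graph
theta = np.full((2, 2), 0.5)                   # uninformative initial distribution

for _ in range(2000):
    A = (rng.random(theta.shape) < theta).astype(float)  # A ~ Ber(theta)
    grad_wrt_A = A - A_target                            # gradient of the toy F at A
    theta = straight_through_update(theta, grad_wrt_A, step_size=0.1)
```

On this toy objective the edge probabilities drift toward the target graph, which illustrates why a biased but cheap estimator can still make useful progress.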

Computing the hypergradient by fully unrolling the dynamics may be too expensive both in time and memory. (Moreover, since we rely on biased estimations of the gradients, we do not expect to gain too much from a full computation.) We propose to truncate the computation and estimate the hypergradient every $\tau$ iterations, where $\tau$ is a parameter of the algorithm. This is essentially an adaptation of truncated back-propagation through time (Werbos, 1990; Williams & Peng, 1990) and can be seen as a short-horizon optimization procedure with warm restart on $w$. A sketch of the method is presented in Algorithm 1, while a more complete version that includes details on the hypergradient computation can be found in the appendix. Inputs and operations in square brackets are optional.

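1:  Input data: $X$, $Y$, $Y'$ [, $A$]
2:  Input parameters: $\eta$, $\tau$ [, $k$]
3:  [$A \leftarrow$ kNN($X$, $k$)] {Init. $A$ to kNN graph if $A = 0$}
4:  $\theta \leftarrow A$ {Initialize $P_\theta$ as a deterministic distribution}
5:  while Stopping condition is not met do
6:     $t \leftarrow 0$
7:     while Inner objective decreases do
8:        $A_t \sim \mathrm{Ber}(\theta)$ {Sample structure}
9:        $w_{\theta,t+1} \leftarrow \Phi_t(w_{\theta,t}, A_t)$ {Optimize inner objective}
10:       $t \leftarrow t + 1$
11:       if $t = 0 \pmod{\tau}$ or $\tau = 0$ then
12:          $G \leftarrow$ computeHG($F$, $Y$, $\theta$, $(w_{\theta,i})_{i=t-\tau}^{t}$)
13:          $\theta \leftarrow \mathrm{Proj}_{\overline{\mathcal{H}}_N}[\theta - \eta G]$ {Optimize outer objective}
14:       end if
15:    end while
16: end while
17: return $w$, $P_\theta$ {Best found weights and prob. distribution}
Algorithm 1 LDS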

The algorithm contains stopping conditions at the outer and at the inner level. While it is natural to implement the latter with a decrease condition on the inner objective (we continue optimizing until the decrease of the inner objective falls below a small tolerance; since the inner objective is non-convex, we also use a patience window of several steps), we find it useful to implement the former with a simple early stopping criterion. A fraction of the examples in the validation set is held out to compute, in each outer iteration, the accuracy using the predictions of the empirically expected model (9). The optimization procedure terminates if there is no improvement for some consecutive outer loops. This helps avoid overfitting the outer objective (6), which may be a concrete risk in this context given the quantity of (hyper)parameters being optimized and the relatively small size of the validation sets.
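The outer early stopping criterion can be expressed as a simple patience check on the history of held-out accuracies. The helper below is a hypothetical sketch of such a rule, assuming that higher values are better and that one value is recorded per outer iteration; it is not taken from the authors' code.

```python
def should_stop(history, patience):
    """Return True if the last `patience` entries of `history` (higher is
    better) contain no improvement over the best value seen before them."""
    if len(history) <= patience:
        return False
    best_before = max(history[:-patience])
    return max(history[-patience:]) <= best_before

# Hypothetical held-out accuracies, one per outer iteration.
accuracies = [0.50, 0.60, 0.60, 0.59]
stop_now = should_stop(accuracies, patience=2)
```

The same helper can serve as an inner-loop patience window by passing the negated inner objective as `history`.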

The hypergradients estimated with Algorithm 1 at each outer iteration are biased. The bias stems from both the straight-through estimator and the truncation procedure introduced in lines 11-13 (Tallec & Ollivier, 2017). Nevertheless, we find empirically that the algorithm is able to make reasonable progress, finding configurations in the distribution space that are beneficial for the tasks at hand.

4 Experiments

We conducted a series of experiments with three main objectives. First, we evaluated LDS on node classification problems where a graph structure is available but where a certain fraction of edges is missing. Here, we compared LDS with graph-based learning algorithms including vanilla GCNs. Second, we wanted to validate our hypothesis that LDS can achieve competitive results on semi-supervised classification problems for which a graph is not available. To this end, we compared LDS with a number of existing semi-supervised classification approaches. We also compared LDS with algorithms that first create kNN affinity graphs on the data set. Third, we analyzed the learned graph generative model to understand to what extent LDS is able to learn meaningful edge probability distributions even when a large fraction of edges is missing.

4.1 Datasets

Cora and Citeseer are two benchmark datasets that are commonly used to evaluate relational learners in general and GCNs in particular (Sen et al., 2008). The input features are bag-of-words vectors and the task is node classification. We use the same dataset split and experimental setup as previous work (Yang et al., 2016; Kipf & Welling, 2017). To evaluate the robustness of LDS on incomplete graphs, we construct graphs with missing edges by randomly sampling 25%, 50%, and 75% of the edges. In addition to Cora and Citeseer where we removed all edges, we evaluate LDS on benchmark datasets that are available in scikit-learn (Pedregosa et al., 2011) such as Wine, Breast Cancer (Cancer), Digits, and 20 Newsgroup (20news). We take 10 classes from 20 Newsgroup and use as features the words (TFIDF) whose frequency exceeds a fixed threshold. We also use FMA, a dataset where audio features are extracted from 7,994 music tracks and where the problem is genre classification (Defferrard et al., 2017). The statistics of the datasets are listed in Table 1.

Figure 2: Mean accuracy ± standard deviation on validation (early stopping; dashed lines) and test (solid lines) sets for edge deletion scenarios on Cora (left) and Citeseer (center). (Right) Validation of the number of steps $\tau$ used to compute the hypergradient (Citeseer); $\tau = 0$ corresponds to alternating minimization. All results are obtained from five runs with different random seeds.

4.2 Setup and Baselines

For the experiments on graphs with missing edges, we compare LDS to vanilla GCNs. In addition, we also conceived a method (GCN-RND) where we add randomly sampled edges at each optimization step of a vanilla GCN. With this method we intend to show that simply adding random edges to the standard training procedure of a GCN model (perhaps acting as a regularization technique) is not enough to improve the generalization.

When a graph is completely missing, GCNs boil down to feed-forward neural networks. Therefore, we evaluate different strategies to induce a graph on both labeled and unlabeled samples by creating (1) a sparse Erdős-Rényi random graph (Erdos & Rényi, 1960) (Sparse-GCN); (2) a dense graph with equal edge probabilities (Dense-GCN); (3) a dense RBF kernel on the input features (RBF-GCN); and (4) a sparse $k$-nearest neighbor graph on the input features (kNN-GCN). For LDS we initialize the edge probabilities using the kNN graph (kNN-LDS). We further include a dense version of LDS where we learn a dense similarity matrix (kNN-LDS (dense)).

In this setting, we compare LDS to popular semi-supervised learning methods such as label propagation (LP) (Zhu et al., 2003), manifold regularization (ManiReg) (Belkin et al., 2006), and semi-supervised embedding (SemiEmb) (Weston et al., 2012). ManiReg and SemiEmb are given a kNN graph as input for the Laplacian regularization. We also compare LDS to baselines that do not leverage a graph structure such as logistic regression (LogReg), support vector machines (Linear and RBF SVM), random forests (RF), and feed-forward neural networks (FFNN). For comparison methods that need a kNN graph, $k$ and the metric (Euclidean or Cosine) are tuned using validation accuracy. For kNN-LDS, $k$ is tuned over two candidate values.
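For reference, a kNN affinity graph of the kind used to initialize these methods can be built directly from the feature matrix. The sketch below uses the Euclidean metric and symmetrizes the result; both choices, as well as the function name, are assumptions made for illustration rather than details confirmed by the experiments.

```python
import numpy as np

def knn_graph(X, k):
    """Symmetric 0/1 adjacency linking each point to its k nearest
    neighbors under the Euclidean metric (self-loops excluded)."""
    n = X.shape[0]
    sq_norms = np.sum(X ** 2, axis=1)
    dist2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (X @ X.T)
    np.fill_diagonal(dist2, np.inf)               # a point is not its own neighbor
    neighbors = np.argsort(dist2, axis=1)[:, :k]  # indices of the k closest points
    A = np.zeros((n, n))
    rows = np.repeat(np.arange(n), k)
    A[rows, neighbors.ravel()] = 1.0
    return np.maximum(A, A.T)                     # keep an edge if either end chose it

# Hypothetical one-dimensional data forming two well-separated pairs.
X = np.array([[0.0], [1.0], [10.0], [11.0]])
A = knn_graph(X, k=1)
```

Such a matrix can be used directly as the initial edge-probability matrix, since its entries are already in {0, 1}.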

Name       Samples   Features   Classes   Train/Valid/Test
Wine           178         13         3   10 / 20 / 158
Cancer         569         30         2   10 / 20 / 539
Digits       1,797         64        10   50 / 100 / 1,647
Citeseer     3,327      3,703         6   120 / 500 / 1,000
Cora         2,708      1,433         7   140 / 500 / 1,000
20news       9,607        236        10   100 / 200 / 9,307
FMA          7,994        140         8   160 / 320 / 7,514
Table 1: Summary statistics of the datasets.

Method        Wine         Cancer       Digits        Citeseer     Cora         20news       FMA
LogReg        92.1 (1.3)   93.3 (0.5)   85.5 (1.5)    62.2 (0.0)   60.8 (0.0)   42.7 (1.7)   37.3 (0.7)
Linear SVM    93.9 (1.6)   90.6 (4.5)   87.1 (1.8)    58.3 (0.0)   58.9 (0.0)   40.3 (1.4)   35.7 (1.5)
RBF SVM       94.1 (2.9)   91.7 (3.1)   86.9 (3.2)    60.2 (0.0)   59.7 (0.0)   41.0 (1.1)   38.3 (1.0)
RF            93.7 (1.6)   92.1 (1.7)   83.1 (2.6)    60.7 (0.7)   58.7 (0.4)   40.0 (1.1)   37.9 (0.6)
FFNN          89.7 (1.9)   92.9 (1.2)   36.3 (10.3)   56.7 (1.7)   56.1 (1.6)   38.6 (1.4)   33.2 (1.3)
LP            89.8 (3.7)   76.6 (0.5)   91.9 (3.1)    23.2 (6.7)   37.8 (0.2)   35.3 (0.9)   14.1 (2.1)
ManiReg       90.5 (0.1)   81.8 (0.1)   83.9 (0.1)    67.7 (1.6)   62.3 (0.9)   46.6 (1.5)   34.2 (1.1)
SemiEmb       91.9 (0.1)   89.7 (0.1)   90.9 (0.1)    68.1 (0.1)   63.1 (0.1)   46.9 (0.1)   34.1 (1.9)
Sparse-GCN    63.5 (6.6)   72.5 (2.9)   13.4 (1.5)    33.1 (0.9)   30.6 (2.1)   24.7 (1.2)   23.4 (1.4)
Dense-GCN     90.6 (2.8)   90.5 (2.7)   35.6 (21.8)   58.4 (1.1)   59.1 (0.6)   40.1 (1.5)   34.5 (0.9)
RBF-GCN       90.6 (2.3)   92.6 (2.2)   70.8 (5.5)    58.1 (1.2)   57.1 (1.9)   39.3 (1.4)   33.7 (1.4)
kNN-GCN       93.2 (3.1)   93.8 (1.4)   91.3 (0.5)    68.3 (1.3)   66.5 (0.4)   41.3 (0.6)   37.8 (0.9)
kNN-LDS (d)   97.5 (1.2)   94.9 (0.5)   92.1 (0.7)    70.9 (1.3)   70.9 (1.1)   45.6 (2.2)   38.6 (0.6)
kNN-LDS       97.3 (0.4)   94.4 (1.9)   92.5 (0.7)    71.5 (1.1)   71.5 (0.8)   46.4 (1.6)   39.7 (1.4)
Table 2: Test accuracy (± standard deviation) in percentage on various classification datasets. The best results and the statistically competitive ones (by paired t-test) are in bold. All experiments have been repeated with 5 different random seeds. We compare kNN-LDS to several supervised baselines and semi-supervised learning methods. No graph is provided as input. kNN-LDS achieves high accuracy results on most of the datasets and yields the highest gains on datasets with underlying graphs (Citeseer, Cora). kNN-LDS (d) is the dense version of LDS.

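We use the two-layer GCN given by Eq. (2) with a fixed number of hidden neurons and ReLU activation. Given a set of labelled training instances $V_{\mathrm{Train}}$ (nodes or examples) we use the regularized cross-entropy loss

$$ L(w, A) = -\sum_{v \in V_{\mathrm{Train}}} \mathbf{1}^\top \big( y_v \circ \log\big[f_w(X, A)_v\big] \big) + \rho\, \lVert W_1 \rVert_2, \tag{13} $$

where $y_v$ is the one-hot encoded target vector for the $v$-th instance, $\circ$ denotes the element-wise multiplication (summed over the classes by the all-ones vector $\mathbf{1}$) and $\rho$ is a non-negative coefficient. As an additional regularization technique we apply dropout (Srivastava et al., 2014) as in previous work. We use Adam (Kingma & Ba, 2015) for optimizing $L$, tuning the learning rate $\gamma$ over three candidate values. The same number of hidden neurons and the same activation are used for SemiEmb and FFNN.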

For LDS, we set the initial edge parameters $\theta_{i,j}$ to $0$ except for the known edges (or those found by kNN), which we set to $1$. We then let all the parameters (including those initially set to $1$) be optimized by the algorithm. We further split the validation set evenly to form the validation (A) and early stopping (B) sets. As outer objective we use the unregularized cross-entropy loss on (A) and optimize it with stochastic gradient descent with an exponentially decreasing learning rate. Initial experiments showed that accelerated optimization methods such as Adam or SGD with momentum underperform in this setting. We tune the step size $\eta$ of the outer optimization loop and the number of updates $\tau$ used to compute the truncated hypergradient. Finally, we draw $S$ samples to compute the output predictions (see Eq. (9)). For LDS and GCN, we apply early stopping with a fixed window size. An important factor for the successful optimization of the outer objective is to use vanilla SGD with an exponentially decreasing step size.

LDS was implemented in TensorFlow (Abadi et al., 2015); the source code will be released soon. The implementations of the supervised baselines and LP are those from the scikit-learn Python package (Pedregosa et al., 2011). GCN, ManiReg, and SemiEmb are implemented in TensorFlow. The hyperparameters for all the methods are selected through the validation accuracy.

4.3 Results

% Edges             25%      50%      75%      100%
Cora Initial        1357     2714     4071     5429
Cora Learned        3635.6   4513.9   5476.9   6276.4
Citeseer Initial    1183     2366     3549     4732
Citeseer Learned    3457.4   4474.2   7842.5   6745.2
Table 3: Initial number of edges and expected number of sampled edges of the graphs learned by LDS.

Figure 3: Mean edge probabilities to nodes aggregated w.r.t. four groups during LDS optimization, in log scale, for three example nodes. For each example node, all other nodes are grouped by the following criteria: (a) adjacent in the ground truth graph; (b) same class membership; (c) different class membership; and (d) unknown class membership. Probabilities are computed with LDS on Cora with only part of the edges retained. From left to right, the example nodes belong to the training, validation, and test set, respectively. The vertical gray lines indicate when the inner optimization dynamics restarts, that is, when the weights of the GCN are reinitialized.

The results on the incomplete graphs are shown in Figure 2 for Cora (left) and Citeseer (center). For each percentage of retained edges the accuracy on the validation set (used for early stopping) and on the test set is plotted. LDS achieves competitive results in all scenarios, with accuracy gains of up to several percentage points. Notably, LDS improves the generalization accuracy of GCN models also when the given graph is that of the respective dataset (100% of edges retained), by learning additional helpful edges. The accuracies of 84.08% and 75.04% for Cora and Citeseer, respectively, exceed all previous state-of-the-art results. Conversely, adding random edges does not help decrease the generalization error: GCN and GCN-RND perform similarly.

Figure 2 (right) depicts the impact of the number of iterations $\tau$ used to compute the hypergradients. Taking multiple steps strongly outperforms alternating optimization (i.e. $\tau = 0$) in all settings. For $\tau = 0$, one step of optimization of $L$ w.r.t. $w$, fixing $\theta$, is interleaved with one step of minimization of $F$ w.r.t. $\theta$, fixing $w$. Even if computationally lighter, this approach disregards the nested structure of (6)-(7), not computing the first term of Eq. (4). Increasing $\tau$ further, however, does not yield significant benefits, while increasing the computational cost.

In Table 3 we report the expected number of edges in a sampled graph for Cora and Citeseer, to analyze the properties of the graphs sampled from the learned graph generator. The expected number of edges for LDS is higher than the original number, which is to be expected since LDS has better accuracy results than the vanilla GCN in Figure 2. Nevertheless, the learned graphs are still very sparse (e.g. for Cora, on average, only a small fraction of all possible edges is present). This facilitates efficient learning of the GCN in the inner learning loop of LDS.
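Under independent Bernoulli edges, the expected number of edges in a sampled graph is simply the sum of the edge probabilities. The snippet below sketches this computation together with the resulting expected density; the parameter matrix is a hypothetical toy example, and whether Table 3 counts directed or undirected edges is not assumed here.

```python
import numpy as np

def expected_num_edges(theta):
    """Expected number of (directed) edges of A ~ Ber(theta)."""
    return float(theta.sum())

def expected_density(theta):
    """Expected fraction of the N * N possible directed edges that is present."""
    return expected_num_edges(theta) / theta.size

# Hypothetical learned edge probabilities for a 3-node graph.
theta = np.array([[0.0, 0.9, 0.1],
                  [0.9, 0.0, 0.0],
                  [0.1, 0.0, 0.0]])
num_edges = expected_num_edges(theta)
density = expected_density(theta)
```

Comparing `num_edges` with the empirical mean over sampled graphs is a quick consistency check for a sampling routine.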

Figure 4: Normalized histograms of edges’ probabilities for the same nodes of Figure 3.

Table 2 lists the results for semi-supervised classification problems. The supervised learning baselines work well on some datasets such as Wine and Cancer but fail to provide competitive results on others such as Digits, Citeseer, Cora, and 20news. The semi-supervised learning baselines LP, ManiReg and SemiEmb can only improve the supervised learning baselines on one, three and four datasets, respectively. The results for the GCN with different input graphs show that kNN-GCN works well and provides competitive results compared to the supervised baselines on all datasets. kNN-LDS significantly outperforms kNN-GCN on four out of seven datasets. In addition, kNN-LDS is among the most competitive methods on all datasets and yields the highest gains on datasets that have an underlying graph. kNN-LDS performs slightly better than its dense counterpart, where we learn a dense adjacency matrix. The advantage of the sparse graph representation, however, lies in the potential to scale to larger datasets.

In Figure 3, we show the evolution of mean edge probabilities during optimization on three types of nodes (train, validation, test) on the Cora dataset. LDS is able to learn a graph generative model that, on average, attributes many times more probability to edges between samples sharing the same class label. LDS often attributes a higher probability to edges that are present in the true held-out adjacency matrix (green lines in the plots). In Figure 4 we report the normalized histograms of the optimized edge probabilities for the same nodes of Figure 3, sorted into six bins in log scale. Edges are divided into two groups: edges between nodes of the same class (blue) and between nodes of unknown or different classes (orange). LDS is able to learn highly non-uniform edge probabilities that reflect the class membership of the nodes.

Figure 5: Histograms for three Citeseer test nodes, misclassified by kNN-GCN and correctly classified by kNN-LDS.

Figure 5 shows qualitative results similar to those of Figure 4, this time for three Citeseer test nodes that are misclassified by kNN-GCN and correctly classified by kNN-LDS. Again, the learned edge probabilities linking to nodes of the same class are significantly different from those linking to nodes of different classes, but in this case the densities are more skewed toward the first bin. On the datasets we considered, what seems to matter is to capture a useful distribution (i.e. higher probability for links between nodes of the same class) rather than to pick exact links; of course, for other datasets this may vary.

5 Related work

Semi-supervised Learning. Early works on graph-based semi-supervised learning use graph Laplacian regularization and include label propagation (LP) (Zhu et al., 2003), manifold regularization (ManiReg) (Belkin et al., 2006), and semi-supervised embedding (SemiEmb) (Weston et al., 2012). These methods assume a given graph structure whose edges represent some similarity between nodes. Later, Yang et al. (2016) proposed a method that uses graphs not for regularization but rather for embedding learning by joint classification and graph context prediction. Kipf & Welling (2017) presented the first GCN for semi-supervised learning. There are now numerous GCN variants, all of which assume a given graph structure. Contrary to all existing graph-based semi-supervised learning approaches, LDS simultaneously learns the graph structure and the parameters of a GCN for node classification and is, therefore, able to work even when the graph is incomplete or entirely missing.

Graph synthesis and generation. LDS learns a probabilistic generative model for graphs. The earliest probabilistic generative model for graphs was the Erdős-Rényi random graph model (Erdos & Rényi, 1960), where edges are modelled as identically distributed and mutually independent Bernoulli random variables. Several network models have been proposed to capture particular graph properties such as degree distribution (Leskovec et al., 2005) or network diameter (Watts & Strogatz, 1998). Leskovec et al. (2010) proposed a generative model based on the Kronecker product that takes a real graph as input and generates graphs that have similar properties. Recently, deep learning based approaches have been proposed for graph generation (You et al., 2018; Li et al., 2018; Grover et al., 2018; De Cao & Kipf, 2018). The goal of these methods, however, is to learn a sophisticated generative model that reflects the properties of the training graphs. LDS, on the other hand, learns graph generative models as a means to perform well on classification problems, and its input is not a collection of graphs. More recent work proposed an unsupervised model that learns to infer interactions between entities while simultaneously learning the dynamics of physical systems such as spring systems (Kipf et al., 2018). Contrary to LDS, the method is specific to dynamical interacting systems, is unsupervised, and uses a variational encoder-decoder. Finally, we note that Johnson (2017) proposed a fully differentiable neural model able to process and produce graph structures at the input, representation, and output levels; training the model requires, however, supervision in terms of ground truth graphs.
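To make the notion of mutually independent Bernoulli edges concrete, the following sketch samples an undirected graph from a matrix of per-edge probabilities. It is an illustration only: the function name `sample_graph` and the list-of-lists representation are assumptions, and setting every entry to the same value recovers the Erdős-Rényi case.

```python
import random

def sample_graph(edge_probs, rng):
    """Sample a symmetric 0/1 adjacency matrix with independent Bernoulli edges.

    edge_probs[i][j] (for i < j) is the probability of the edge {i, j};
    self-loops are never sampled.
    """
    n = len(edge_probs)
    adjacency = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            if rng.random() < edge_probs[i][j]:
                adjacency[i][j] = adjacency[j][i] = 1
    return adjacency

rng = random.Random(0)                    # fixed seed for reproducibility
n_nodes = 5
uniform = [[0.3] * n_nodes for _ in range(n_nodes)]  # Erdős-Rényi: one shared p
graph = sample_graph(uniform, rng)
```

A learned, non-uniform probability matrix can be passed in exactly the same way; only the entries of `edge_probs` change.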

Link prediction. Link prediction is a decades-old problem (Liben-Nowell & Kleinberg, 2007). Several survey papers cover the large body of work ranging from link prediction in social networks to knowledge base completion (Lü & Zhou, 2011; Nickel et al., 2016). While the majority of the methods are based on some similarity measure between node pairs, there have been a number of neural-network-based methods (Zhang & Chen, 2017, 2018). The problem we study in this paper is related to link prediction, as we also want to learn or extend a graph. However, existing link prediction methods do not simultaneously learn a GNN node classifier. Moreover, LDS learns a generative model over edges and can not only add new edges but also remove incorrect ones. Statistical relational learning (SRL) (Getoor & Taskar, 2007) models often perform both link prediction and node classification through the existence of binary and unary predicates. However, SRL models are inherently intractable, and the structure and parameter learning steps are independent. LDS, on the other hand, learns the graph structure and the GCN’s parameters simultaneously.

Gradient estimation for discrete random variables. Due to the intractable nature of the two bilevel objectives, LDS needs to estimate the hypergradients through a stochastic computational graph (Schulman et al., 2015). Using the score function estimator, also known as REINFORCE (Williams, 1992), would treat the outer objective as a black-box function and would not exploit the fact that it is differentiable w.r.t. the sampled adjacency matrices and the inner optimization dynamics. Conversely, the path-wise estimator is not readily applicable, since the random variables are discrete. LDS borrows from a solution proposed before (Bengio et al., 2013), at the cost of having biased estimates. Recently, Jang et al. (2017) and Maddison et al. (2017) presented an approach based on continuous relaxations to reduce variance, which Tucker et al. (2017) combined with REINFORCE to obtain an unbiased estimator. Grathwohl et al. (2018) further introduced surrogate models to construct control variates for black-box functions. Unfortunately, these latter methods require computing the function in the interior of the hypercube, possibly at multiple points (Tucker et al., 2017). This would introduce additional computational overhead, since the function can be computed only after (approximately) solving the inner optimization problem.
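The trade-off between the two families of estimators can be illustrated on a single Bernoulli variable. The sketch below assumes that the biased solution borrowed from Bengio et al. (2013) is the straight-through estimator described in that reference; the test function `f`, its derivative `df`, and all helper names are hypothetical. Expectations are computed exactly by enumerating the two outcomes, so no sampling noise is involved.

```python
def f(a):
    """Toy differentiable objective evaluated at a sample a in {0, 1}."""
    return a ** 2

def df(a):
    """Derivative of f with respect to its (relaxed, real-valued) argument."""
    return 2.0 * a

def score_function_estimate(a, theta):
    """REINFORCE: f(a) * d/dtheta log p(a; theta) for a ~ Bernoulli(theta)."""
    return f(a) * (a - theta) / (theta * (1.0 - theta))

def straight_through_estimate(a):
    """Straight-through: backpropagate through the sample as if da/dtheta = 1."""
    return df(a)

def expectation(estimator, theta):
    """Exact expectation of an estimator under a ~ Bernoulli(theta)."""
    return theta * estimator(1.0) + (1.0 - theta) * estimator(0.0)

theta = 0.7
# E[f(a)] = theta * f(1) + (1 - theta) * f(0), so its derivative is f(1) - f(0).
exact_gradient = f(1.0) - f(0.0)
score_mean = expectation(lambda a: score_function_estimate(a, theta), theta)
straight_through_mean = expectation(straight_through_estimate, theta)
```

Here the score function estimator matches the exact gradient in expectation while using only the values of `f`, whereas the straight-through estimator uses the derivative `df` but its expectation differs from the exact gradient, which is the bias mentioned above.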

6 Conclusion

We propose LDS, a framework that simultaneously learns the graph structure and the parameters of a GNN. While we have used a specific GCN variant (Kipf & Welling, 2017) in the experiments, the method is more generally applicable to other GNNs. The strengths of LDS are its high accuracy gains on typical semi-supervised classification datasets at a reasonable computational cost. Moreover, due to the graph generative model LDS learns, the edge parameters have a probabilistic interpretation.

The method has its limitations. While relatively efficient, it cannot currently scale to large datasets: this would require an implementation that works with mini-batches of nodes. We evaluated LDS only in the transductive setting, in which all data points (nodes) are available during training. Adding nodes after training (the inductive setting) would currently require retraining the entire model from scratch. When sampling graphs, we do not currently enforce the graphs to be connected. We anticipate that doing so would improve the results, but it would require a more sophisticated sampling strategy. All of these shortcomings motivate future work. In addition, we hope that suitable variants of the LDS algorithm will also be applied to other problems, such as neural architecture search or the tuning of other discrete hyperparameters.


  • Abadi et al. (2015) Abadi, M., Agarwal, A., Barham, P., Brevdo, E., Chen, Z., Citro, C., Corrado, G. S., Davis, A., Dean, J., Devin, M., Ghemawat, S., Goodfellow, I., Harp, A., Irving, G., Isard, M., Jia, Y., Jozefowicz, R., Kaiser, L., Kudlur, M., Levenberg, J., Mané, D., Monga, R., Moore, S., Murray, D., Olah, C., Schuster, M., Shlens, J., Steiner, B., Sutskever, I., Talwar, K., Tucker, P., Vanhoucke, V., Vasudevan, V., Viégas, F., Vinyals, O., Warden, P., Wattenberg, M., Wicke, M., Yu, Y., and Zheng, X. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015. URL http://tensorflow.org/. Software available from tensorflow.org.
  • Belkin et al. (2006) Belkin, M., Niyogi, P., and Sindhwani, V. Manifold regularization: A geometric framework for learning from labeled and unlabeled examples. Journal of Machine Learning Research, 7:2399–2434, 2006.
  • Bengio et al. (2013) Bengio, Y., Léonard, N., and Courville, A. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
  • Bennett et al. (2006) Bennett, K. P., Hu, J., Ji, X., Kunapuli, G., and Pang, J.-S. Model selection via bilevel optimization. In Neural Networks, 2006. IJCNN’06. International Joint Conference on, pp. 1922–1929. IEEE, 2006.
  • Bottou (2010) Bottou, L. Large-scale machine learning with stochastic gradient descent. In Proceedings of COMPSTAT’2010, pp. 177–186. Springer, 2010.
  • Colson et al. (2007) Colson, B., Marcotte, P., and Savard, G. An overview of bilevel optimization. Annals of operations research, 153(1):235–256, 2007.
  • De Cao & Kipf (2018) De Cao, N. and Kipf, T. Molgan: An implicit generative model for small molecular graphs. arXiv preprint arXiv:1805.11973, 2018.
  • Defferrard et al. (2017) Defferrard, M., Benzi, K., Vandergheynst, P., and Bresson, X. Fma: A dataset for music analysis. In 18th International Society for Music Information Retrieval Conference, 2017. URL https://arxiv.org/abs/1612.01840.
  • Domke (2012) Domke, J. Generic methods for optimization-based modeling. In Artificial Intelligence and Statistics, pp. 318–326, 2012.
  • Duran & Niepert (2017) Duran, A. G. and Niepert, M. Learning graph representations with embedding propagation. In Advances in Neural Information Processing Systems, pp. 5119–5130, 2017.
  • Erdos & Rényi (1960) Erdos, P. and Rényi, A. On the evolution of random graphs. Publ. Math. Inst. Hung. Acad. Sci, 5(1):17–60, 1960.
  • Feurer & Hutter (2018) Feurer, M. and Hutter, F. Hyperparameter optimization. In Hutter, F., Kotthoff, L., and Vanschoren, J. (eds.), Automatic Machine Learning: Methods, Systems, Challenges, pp. 3–38. Springer, 2018. In press, available at http://automl.org/book.
  • Flamary et al. (2014) Flamary, R., Rakotomamonjy, A., and Gasso, G. Learning constrained task similarities in graph-regularized multi-task learning. Regularization, Optimization, Kernels, and Support Vector Machines, pp. 103, 2014.
  • Franceschi et al. (2017) Franceschi, L., Donini, M., Frasconi, P., and Pontil, M. Forward and reverse gradient-based hyperparameter optimization. ICML, 2017.
  • Franceschi et al. (2018) Franceschi, L., Frasconi, P., Salzo, S., Grazzi, R., and Pontil, M. Bilevel programming for hyperparameter optimization and meta-learning. ICML, 2018.
  • Frecon et al. (2018) Frecon, J., Salzo, S., and Pontil, M. Bilevel learning of the group lasso structure. In Advances in Neural Information Processing Systems 31, pp. 8311–8321, 2018.
  • Getoor & Taskar (2007) Getoor, L. and Taskar, B. Introduction to statistical relational learning, volume 1. MIT press Cambridge, 2007.
  • Gilmer et al. (2017) Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., and Dahl, G. E. Neural message passing for quantum chemistry. arXiv preprint arXiv:1704.01212, 2017.
  • Grathwohl et al. (2018) Grathwohl, W., Choi, D., Wu, Y., Roeder, G., and Duvenaud, D. Backpropagation through the void: Optimizing control variates for black-box gradient estimation. ICLR, 2018.
  • Griewank & Walther (2008) Griewank, A. and Walther, A. Evaluating derivatives: principles and techniques of algorithmic differentiation, volume 105. Siam, 2008.
  • Grover et al. (2018) Grover, A., Zweig, A., and Ermon, S. Graphite: Iterative generative modeling of graphs. arXiv preprint arXiv:1803.10459, 2018.
  • Hamilton et al. (2017) Hamilton, W., Ying, Z., and Leskovec, J. Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems, pp. 1024–1034, 2017.
  • Jang et al. (2017) Jang, E., Gu, S., and Poole, B. Categorical reparameterization with gumbel-softmax. ICLR, 2017.
  • Johnson (2017) Johnson, D. D. Learning graphical state transitions. ICLR, 2017.
  • Kingma & Ba (2015) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. ICLR, 2015.
  • Kipf et al. (2018) Kipf, T., Fetaya, E., Wang, K.-C., Welling, M., and Zemel, R. Neural relational inference for interacting systems. arXiv preprint arXiv:1802.04687, 2018.
  • Kipf & Welling (2017) Kipf, T. N. and Welling, M. Semi-supervised classification with graph convolutional networks. ICLR, 2017.
  • Leskovec et al. (2005) Leskovec, J., Kleinberg, J., and Faloutsos, C. Graphs over time: densification laws, shrinking diameters and possible explanations. In Proceedings of the eleventh ACM SIGKDD international conference on Knowledge discovery in data mining, pp. 177–187. ACM, 2005.
  • Leskovec et al. (2010) Leskovec, J., Chakrabarti, D., Kleinberg, J., Faloutsos, C., and Ghahramani, Z. Kronecker graphs: An approach to modeling networks. Journal of Machine Learning Research, 11(Feb):985–1042, 2010.
  • Li et al. (2018) Li, Y., Vinyals, O., Dyer, C., Pascanu, R., and Battaglia, P. Learning deep generative models of graphs. arXiv preprint arXiv:1803.03324, 2018.
  • Liben-Nowell & Kleinberg (2007) Liben-Nowell, D. and Kleinberg, J. The link-prediction problem for social networks. Journal of the American society for information science and technology, 58(7):1019–1031, 2007.
  • Lü & Zhou (2011) Lü, L. and Zhou, T. Link prediction in complex networks: A survey. Physica A: statistical mechanics and its applications, 390(6):1150–1170, 2011.
  • Maaten & Hinton (2008) Maaten, L. v. d. and Hinton, G. Visualizing data using t-sne. Journal of machine learning research, 9(Nov):2579–2605, 2008.
  • Maclaurin et al. (2015) Maclaurin, D., Duvenaud, D., and Adams, R. Gradient-based hyperparameter optimization through reversible learning. In International Conference on Machine Learning, pp. 2113–2122, 2015.
  • Maddison et al. (2017) Maddison, C. J., Mnih, A., and Teh, Y. W. The concrete distribution: A continuous relaxation of discrete random variables. ICLR, 2017.
  • Monti et al. (2017) Monti, F., Boscaini, D., Masci, J., Rodola, E., Svoboda, J., and Bronstein, M. M. Geometric deep learning on graphs and manifolds using mixture model cnns. In Proc. CVPR, volume 1, pp.  3, 2017.
  • Muñoz-González et al. (2017) Muñoz-González, L., Biggio, B., Demontis, A., Paudice, A., Wongrassamee, V., Lupu, E. C., and Roli, F. Towards poisoning of deep learning algorithms with back-gradient optimization. In Proceedings of the 10th ACM Workshop on Artificial Intelligence and Security, pp. 27–38. ACM, 2017.
  • Nickel et al. (2016) Nickel, M., Murphy, K., Tresp, V., and Gabrilovich, E. A review of relational machine learning for knowledge graphs. Proceedings of the IEEE, 104(1):11–33, 2016.
  • Pedregosa et al. (2011) Pedregosa, F., Varoquaux, G., Gramfort, A., Michel, V., Thirion, B., Grisel, O., Blondel, M., Prettenhofer, P., Weiss, R., Dubourg, V., Vanderplas, J., Passos, A., Cournapeau, D., Brucher, M., Perrot, M., and Duchesnay, E. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research, 12:2825–2830, 2011.
  • Roweis & Saul (2000) Roweis, S. T. and Saul, L. K. Nonlinear dimensionality reduction by locally linear embedding. Science, 290(5500):2323–2326, 2000.
  • Scarselli et al. (2009) Scarselli, F., Gori, M., Tsoi, A. C., Hagenbuchner, M., and Monfardini, G. The graph neural network model. IEEE Transactions on Neural Networks, 20(1):61–80, 2009.
  • Schulman et al. (2015) Schulman, J., Heess, N., Weber, T., and Abbeel, P. Gradient estimation using stochastic computation graphs. In Advances in Neural Information Processing Systems, pp. 3528–3536, 2015.
  • Sen et al. (2008) Sen, P., Namata, G., Bilgic, M., Getoor, L., Galligher, B., and Eliassi-Rad, T. Collective classification in network data. AI magazine, 29(3):93, 2008.
  • Srivastava et al. (2014) Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. Dropout: a simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research, 15(1):1929–1958, 2014.
  • Tallec & Ollivier (2017) Tallec, C. and Ollivier, Y. Unbiasing truncated backpropagation through time. arXiv preprint arXiv:1705.08209, 2017.
  • Tenenbaum et al. (2000) Tenenbaum, J. B., De Silva, V., and Langford, J. C. A global geometric framework for nonlinear dimensionality reduction. Science, 290(5500):2319–2323, 2000.
  • Tucker et al. (2017) Tucker, G., Mnih, A., Maddison, C. J., Lawson, J., and Sohl-Dickstein, J. Rebar: Low-variance, unbiased gradient estimates for discrete latent variable models. In Advances in Neural Information Processing Systems, pp. 2627–2636, 2017.
  • Velickovic et al. (2018) Velickovic, P., Cucurull, G., Casanova, A., Romero, A., Lio, P., and Bengio, Y. Graph attention networks. ICLR, 2018.
  • Watts & Strogatz (1998) Watts, D. J. and Strogatz, S. H. Collective dynamics of ‘small-world’ networks. Nature, 393(6684):440, 1998.
  • Werbos (1990) Werbos, P. J. Backpropagation through time: What it does and how to do it. Proceedings of the IEEE, 78(10):1550–1560, 1990.
  • Weston et al. (2012) Weston, J., Ratle, F., Mobahi, H., and Collobert, R. Deep learning via semi-supervised embedding. In Neural Networks: Tricks of the Trade, pp. 639–655. Springer, 2012.
  • Williams (1992) Williams, R. J. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4):229–256, 1992.
  • Williams & Peng (1990) Williams, R. J. and Peng, J. An efficient gradient-based algorithm for on-line training of recurrent network trajectories. Neural computation, 2(4):490–501, 1990.
  • Yang et al. (2016) Yang, Z., Cohen, W. W., and Salakhutdinov, R. Revisiting semi-supervised learning with graph embeddings. In Proceedings of the 33rd International Conference on Machine Learning, ICML 2016, New York City, NY, USA, June 19-24, 2016, pp. 40–48, 2016.
  • You et al. (2018) You, J., Ying, R., Ren, X., Hamilton, W. L., and Leskovec, J. Graphrnn: A deep generative model for graphs. ICML, 2018.
  • Zhang & Chen (2017) Zhang, M. and Chen, Y. Weisfeiler-lehman neural machine for link prediction. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 575–583. ACM, 2017.
  • Zhang & Chen (2018) Zhang, M. and Chen, Y. Link prediction based on graph neural networks. arXiv preprint arXiv:1802.09691, 2018.
  • Zhu et al. (2003) Zhu, X., Ghahramani, Z., and Lafferty, J. D. Semi-supervised learning using gaussian fields and harmonic functions. In Proceedings of the 20th International conference on Machine learning (ICML-03), pp. 912–919, 2003.

Appendix A Extended algorithm

In this section we provide an extended version of Algorithm 1 that includes the explicit computation of the hypergradient by truncated reverse-mode algorithmic differentiation. Recall that the inner objective is replaced by iterative dynamics such as stochastic gradient descent. Hence, starting from an initial point , the iterates are computed as

Let and denote the Jacobians of the dynamics:

and recall that we set so that . We report the pseudocode in Algorithm 2, where the letter is used to indicate the adjoint variables (Lagrange multipliers). Note that for the algorithm does not enter the loop at line 15. Finally, note also that at line 16 we re-sample the adjacency matrices instead of reusing those computed in the forward pass (lines 8-10).

Algorithm 2 was implemented in TensorFlow as an extension of the software package Far-HO, freely available at https://github.com/lucfra/FAR-HO.

1:  Input data: , , [, ]
2:  Input parameters: , [, ]
3:   {Init. to NN graph if }
4:   {Initialize as a deterministic distribution}
5:  while Stopping condition is not met do
7:     while Inner objective decreases do
8:         {Sample structure}
9:         {Optimize inner objective}
11:        if   then
15:           for  downto  do
19:           end for
20:            {Optimize outer objective}
21:        end if
22:     end while
23:  end while
24:  return , {Best found weights and prob. distribution}
Algorithm 2 LDS (extended)
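The reverse pass over adjoint variables can be illustrated on a scalar toy problem. The sketch below is not the Far-HO implementation and does not reproduce Algorithm 2 line by line; it assumes a one-dimensional weight `w`, a single Bernoulli "edge" with parameter `theta`, a quadratic inner loss, a quadratic outer objective, and a straight-through treatment of the sample (its derivative with respect to `theta` is taken to be 1). All names (`inner_step`, `truncated_hypergradient`, `tau`, `lr`) are hypothetical. In this toy the Jacobians do not depend on the sampled edge, so the re-sampling step of the reverse pass has no counterpart here.

```python
import random

def inner_step(w, a, x, lr):
    """One gradient-descent step on the toy inner loss 0.5 * (w - a * x) ** 2."""
    return w - lr * (w - a * x)

def truncated_hypergradient(theta, x, y, w0=0.0, lr=0.1, T=20, tau=5, seed=0):
    """Estimate d/dtheta of the outer loss 0.5 * (w_T - y) ** 2.

    The reverse pass is truncated to the last `tau` inner steps.
    Returns (hypergradient estimate, final weight w_T, sampled edges).
    """
    rng = random.Random(seed)
    # Forward pass: sample a_t ~ Bernoulli(theta) and unroll the dynamics.
    w, samples = w0, []
    for _ in range(T):
        a = 1.0 if rng.random() < theta else 0.0
        samples.append(a)
        w = inner_step(w, a, x, lr)
    # Reverse pass: p is the adjoint of the weights, initialized to dF/dw_T.
    p = w - y
    grad = 0.0
    for _ in range(T - 1, max(T - tau, 0) - 1, -1):
        jac_a = lr * x     # d inner_step / d a (straight-through: da/dtheta := 1)
        jac_w = 1.0 - lr   # d inner_step / d w
        grad += p * jac_a  # accumulate this step's contribution
        p *= jac_w         # propagate the adjoint one step back
    return grad, w, samples

grad_full, w_T, samples = truncated_hypergradient(0.5, x=2.0, y=1.0, T=10, tau=10, seed=1)
grad_short, _, _ = truncated_hypergradient(0.5, x=2.0, y=1.0, T=10, tau=3, seed=1)
```

With `tau` equal to `T` the reverse pass covers the whole unrolled trajectory; smaller values of `tau` trade accuracy of the estimate for less computation, and `tau = 0` skips the loop entirely and returns a zero estimate.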

Appendix B Visualization of Embeddings

We further visualize the embeddings learned by GCN and LDS using t-SNE (Maaten & Hinton, 2008). Figure 6 depicts the t-SNE visualizations of the embeddings learned on Citeseer with Dense-GCN (left), NN-GCN (center), and NN-LDS (right). As can be seen, the embeddings learned by NN-LDS provide the best separation among different classes.

Figure 6: t-SNE visualization of the output activations (before the classification layer) on the Citeseer dataset. Left: Dense-GCN. Center: NN-GCN. Right: NN-LDS.