clusternet
Code release for NeurIPS 2019 paper "End to End Learning and Optimization on Graphs"
Real-world applications often combine learning and optimization problems on graphs. For instance, our objective may be to cluster the graph in order to detect meaningful communities (or solve other common graph optimization problems such as facility location, maxcut, and so on). However, graphs or related attributes are often only partially observed, introducing learning problems such as link prediction which must be solved prior to optimization. We propose an approach to integrate a differentiable proxy for common graph optimization problems into training of machine learning models for tasks such as link prediction. This allows the model to focus specifically on the downstream task that its predictions will be used for. Experimental results show that our end-to-end system obtains better performance on example optimization tasks than can be obtained by combining state of the art link prediction methods with expert-designed graph optimization algorithms.
This is a modified clone repository for the code of the paper "End to end learning and optimization on graphs"
While deep learning has proven enormously successful at a range of tasks, an expanding area of interest concerns systems that can flexibly combine learning with optimization. Examples include recent attempts to solve combinatorial optimization problems using neural architectures [40, 25, 6, 27], as well as work which incorporates explicit optimization algorithms into larger differentiable systems [2, 16, 42]. The ability to combine learning and optimization promises improved performance for real-world problems which require decisions to be made on the basis of machine learning predictions by enabling end-to-end training which focuses the learned model on the decision problem at hand.

We focus on graph optimization problems, an expansive subclass of combinatorial optimization. While graph optimization is ubiquitous across domains, complete applications must also solve machine learning challenges. For instance, the input graph is usually incomplete; some edges may be unobserved or nodes may have attributes that are only partially known. Recent work has introduced sophisticated methods for tasks such as link prediction and semi-supervised classification [34, 26, 35, 22, 47], but these methods are developed in isolation from downstream optimization tasks. Most current solutions use a two-stage approach which first trains a model using a standard loss and then plugs the model's predictions into an optimization algorithm [45, 8, 4, 7, 38]. However, predictions which minimize a standard loss function (e.g., cross-entropy) may be suboptimal for specific optimization tasks, especially in difficult settings where even the best model is imperfect.
A preferable approach is to directly incorporate the downstream optimization problem into the training of the machine learning model [16, 42], which requires a differentiable layer that produces a solution to the optimization problem. To date, there are two main approaches to differentiable optimization. The first trains a generic neural network to directly output a solution to the optimization problem. This approach often requires a large amount of data and results in suboptimal optimization performance because the network needs to discover algorithmic structure entirely from scratch. The second hand-crafts a differentiable solver for the particular optimization problem (using, e.g., the LP relaxation of an integral problem [42]) and includes this solver as a layer in the network. This approach requires manual effort to develop a differentiable solver for each particular problem and often results in cumbersome systems that must, e.g., call an LP solver in every forward pass.

We propose a new approach that gets the best of both worlds by incorporating a more generic form of algorithmic structure into the network, which can then be automatically fine-tuned to solve a range of optimization tasks. In particular, we use a differentiable version of $K$-means clustering. Clustering is motivated by the observation that graph neural networks embed nodes into a continuous space, allowing us to approximate optimization over the discrete graph with optimization in continuous embedding space. We then interpret the cluster assignments as a solution to the discrete problem. We instantiate this approach for two classes of optimization problems: those that require partitioning the graph (e.g., community detection or maxcut), and those that require selecting a subset of nodes (facility location, influence maximization, immunization, etc.). We do not claim that clustering is the right algorithmic structure for all tasks, but it is sufficient for many problems as shown in this paper.
In short, we make three contributions. First, we introduce a general framework for integrating graph learning and optimization, with optimization in continuous space as a proxy for the discrete problem. Second, we show how to differentiate through the clustering layer, allowing it to be used in deep learning systems. Third, we show experimental improvements over both two-stage baselines as well as alternate end-to-end approaches on a range of example datasets and optimization problems.
We build on recent work on decision-focused learning [16, 42, 13], which incorporates a solver for an optimization problem into training in order to improve performance on a downstream decision problem. Some work in structured prediction also integrates differentiable solvers for discrete problems (e.g., image segmentation [14] or time series alignment [31]). Our work differs in two ways. First, we tackle more difficult optimization problems. Previous work mostly focuses on convex problems [16] or discrete problems with near-lossless convex relaxations [42, 14]. We focus on highly combinatorial problems where the methods of choice are hand-designed discrete algorithms. Second, in response to this difficulty, we differ methodologically in that we do not attempt to include a solver for the exact optimization problem at hand (or a close relaxation of it). Instead, we include a more generic algorithmic skeleton that is automatically fine-tuned to the optimization problem at hand.
There is also recent interest in training neural networks to solve combinatorial optimization problems [40, 25, 6, 27]. While we focus mostly on combining graph learning with optimization, our model can also be trained just to solve an optimization problem given complete information about the input. The main methodological difference is that we include more structure via a differentiable $K$-means layer instead of using more generic tools (e.g., feed-forward or attention layers). Another difference is that prior work mostly trains via reinforcement learning. By contrast, we use a differentiable approximation to the objective which removes the need for a policy gradient estimator. This is a benefit of our architecture, in which the final decision is fully differentiable in terms of the model parameters instead of requiring non-differentiable selection steps (as in [25, 6, 27]). We give our end-to-end baseline (“GCN-e2e”) the same advantage by training it with the same differentiable decision loss as our own model instead of forcing it to use noisier policy gradient estimates.

Finally, some work uses deep architectures as a part of a clustering algorithm [39, 28, 21, 37], or includes a clustering step as a component of a deep network [18, 19]. While some techniques are similar, the overall task we address and the framework we propose are entirely distinct. Our aim is not to cluster a Euclidean dataset (as in [39, 28, 21, 37]), or to solve perceptual grouping problems (as in [18, 19]). Rather, we propose an approach for graph optimization problems. Perhaps the closest such work is Neural EM [19], which uses an unrolled EM algorithm to learn representations of visual objects. Rather than using EM to infer representations for objects, we use $K$-means in graph embedding space to solve an optimization problem. There is also some work which uses deep networks for graph clustering [44, 46]. However, none of this work includes an explicit clustering algorithm in the network, and none considers our goal of integrating graph learning and optimization.
We consider settings that combine learning and optimization. The input is a graph $G = (V, E)$, which is in some way partially observed. We will formalize our problem in terms of link prediction as an example, but our framework applies to other common graph learning problems (e.g., semi-supervised classification). In link prediction, the graph is not entirely known; instead, we observe only training edges $E^{\text{train}} \subset E$. Let $A$ denote the adjacency matrix of the graph and $A^{\text{train}}$ denote the adjacency matrix with only the training edges. The learning task is to predict $A$ from $A^{\text{train}}$. In the domains we consider, the motivation for performing link prediction is to solve a decision problem for which the objective depends on the full graph. Specifically, we have a decision variable $x$, an objective function $f(x, A)$, and a feasible set $\mathcal{X}$. We aim to solve the optimization problem

$$\max_{x \in \mathcal{X}} f(x, A) \qquad (1)$$

However, $A$ is unobserved. We can also consider an inductive setting in which we observe a collection of graphs as training examples and then seek to predict edges for a partially observed graph from the same distribution. The most common approach to either setting is to train a model to reconstruct $A$ from $A^{\text{train}}$ using a standard loss function (e.g., cross-entropy), producing an estimate $\hat{A}$. The two-stage approach plugs $\hat{A}$ into an optimization algorithm for Problem 1, maximizing $f(x, \hat{A})$.

We propose end-to-end models which map from $A^{\text{train}}$ directly to a feasible decision $x$. The model will be trained to maximize $f(x, A^{\text{train}})$, i.e., the quality of its decision evaluated on the training data (instead of a loss that measures purely predictive accuracy). One approach is to “learn away” the problem by training a standard model (e.g., a GCN) to map directly from $A^{\text{train}}$ to $x$. However, this forces the model to entirely rediscover algorithmic concepts, while two-stage methods are able to exploit highly sophisticated optimization methods. We propose an alternative that embeds algorithmic structure into the learned model, getting the best of both worlds.
Our proposed ClusterNet system (Figure 1) merges two differentiable components into a system that is trained end-to-end. The first is a graph embedding layer which uses $A^{\text{train}}$ and any node features to embed the nodes of the graph into a continuous embedding space. In our experiments, we use GCNs [26]. The second is a layer that performs differentiable optimization. This layer takes the continuous-space embeddings as input and uses them to produce a solution to the graph optimization problem. Specifically, we propose to use a layer that implements a differentiable version of $K$-means clustering. This layer produces a soft assignment of the nodes to clusters, along with the cluster centers in embedding space.
The intuition is that cluster assignments can be interpreted as the solution to many common graph optimization problems. For instance, in community detection we can interpret the cluster assignments as assigning the nodes to communities. Or, in maxcut, we can use two clusters to assign nodes to either side of the cut. Another example is maximum coverage and related problems, where we attempt to select a set of $K$ nodes which cover (are neighbors to) as many other nodes as possible. This problem can be approximated by clustering the nodes into $K$ components and choosing nodes whose embedding is close to the center of each cluster. We do not claim that any of these problems is exactly reducible to $K$-means. Rather, the idea is that including $K$-means as a layer in the network provides a useful inductive bias. This algorithmic structure can be fine-tuned to specific problems by training the first component, which produces the embeddings, so that the learned representations induce clusterings with high objective value for the underlying downstream optimization task.

We now explain the optimization layer of our system in greater detail. We start by detailing the forward and the backward pass for the clustering procedure, and then explain how the cluster assignments can be interpreted as solutions to the graph optimization problem.
Let $x_j$ denote the embedding of node $j$ and $\mu_k$ denote the center of cluster $k$. $r_{jk}$ denotes the degree to which node $j$ is assigned to cluster $k$. In traditional $K$-means, this is a binary quantity, but we will relax it to a fractional value such that $\sum_k r_{jk} = 1$ for all $j$. Specifically, we take $r_{jk} = \frac{\exp(-\beta \|x_j - \mu_k\|)}{\sum_{\ell} \exp(-\beta \|x_j - \mu_\ell\|)}$, which is a soft-min assignment of each point to the cluster centers based on distance. While our architecture can be used with any norm $\|\cdot\|$, we use the negative cosine similarity due to its strong empirical performance. $\beta$ is an inverse-temperature hyperparameter; taking $\beta \to \infty$ recovers the standard $K$-means assignment. We can optimize the cluster centers via an iterative process analogous to the typical $K$-means updates by alternately setting

$$\mu_k = \frac{\sum_j r_{jk} x_j}{\sum_j r_{jk}} \;\; \forall k, \qquad r_{jk} = \frac{\exp(-\beta \|x_j - \mu_k\|)}{\sum_{\ell} \exp(-\beta \|x_j - \mu_\ell\|)} \;\; \forall j, k. \qquad (2)$$

These iterates converge to a fixed point where $\mu$ remains the same between successive updates [30].
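The alternating updates described above can be sketched in a few lines of plain Python. This is a minimal illustration rather than the released implementation: the function names (`soft_assign`, `update_centers`, `soft_kmeans`), the toy data, the initial centers, and the fixed iteration count are all assumptions made for the example. Assignments are a softmax over cosine similarity scaled by the inverse temperature, which corresponds to a soft-min over the negative cosine similarity mentioned in the text.

```python
import math

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def soft_assign(x, mu, beta):
    """Soft assignment r[j][k] of each point j to each center k (rows sum to 1)."""
    r = []
    for xj in x:
        scores = [beta * cosine(xj, mk) for mk in mu]
        top = max(scores)  # subtract the max for numerical stability
        w = [math.exp(s - top) for s in scores]
        z = sum(w)
        r.append([wi / z for wi in w])
    return r

def update_centers(x, r):
    """Each center becomes the assignment-weighted mean of the points."""
    K, p = len(r[0]), len(x[0])
    mu = []
    for k in range(K):
        mass = sum(rj[k] for rj in r)
        mu.append([sum(rj[k] * xj[d] for rj, xj in zip(r, x)) / mass
                   for d in range(p)])
    return mu

def soft_kmeans(x, mu, beta, iters=50):
    """Alternate the two updates for a fixed number of iterations (assumed)."""
    for _ in range(iters):
        r = soft_assign(x, mu, beta)
        mu = update_centers(x, r)
    return mu, soft_assign(x, mu, beta)

# Toy embeddings: two points near the first axis, two near the second.
x = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.9]]
mu, r = soft_kmeans(x, mu=[[1.0, 0.0], [0.0, 1.0]], beta=30.0)
hard = [row.index(max(row)) for row in r]  # hard maximum per node
```

With a large inverse temperature the soft assignments are nearly one-hot, so the centers end up close to the per-cluster means.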
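We will use the implicit function theorem to analytically differentiate through the fixed point that the forward pass $K$-means iterates converge to, obtaining expressions for $\partial \mu / \partial x$ and $\partial r / \partial x$. Previous work [16, 42] has used the implicit function theorem to differentiate through the KKT conditions of optimization problems; here we take a more direct approach that characterizes the update process itself. Doing so allows us to backpropagate gradients from the decision loss to the component that produced the embeddings $x$. Define a function $f$ componentwise as

$$f_{i,\ell}(\mu, x) = \mu_i^\ell - \frac{\sum_j r_{ji} x_j^\ell}{\sum_j r_{ji}} \qquad (3)$$

where $\mu_i^\ell$ and $x_j^\ell$ denote the $\ell$-th coordinates of $\mu_i$ and $x_j$. Now, $(\mu, x)$ are a fixed point of the iterates if $f(\mu, x) = 0$. Applying the implicit function theorem yields that $\frac{\partial \mu}{\partial x} = -\left[\frac{\partial f(\mu, x)}{\partial \mu}\right]^{-1} \frac{\partial f(\mu, x)}{\partial x}$, from which $\partial r / \partial x$ can be easily obtained via the chain rule.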
Exact backward pass: We now examine the process of calculating . Both and can be easily calculated in closed form (see appendix). Computing the former requires time . Computing the latter requires time, after which it must be inverted (or else iterative methods must be used to compute the product with its inverse). This requires time since it is a matrix of size . While the exact backward pass may be feasible for some problems, it quickly becomes burdensome for large instances. We now propose a fast approximation.
Approximate backward pass: We start from the observation that $\partial f / \partial \mu$ will often be dominated by its diagonal terms (the identity matrix). The off-diagonal entries capture the extent to which updates to one entry of $\mu$ indirectly impact other entries via changes to the cluster assignments $r$. However, when the cluster assignments are relatively firm, $r$ will not be highly sensitive to small changes to the cluster centers. We find this to be typical empirically, especially since the optimal choice of the parameter $\beta$ (which controls the hardness of the cluster assignments) is typically fairly high. Under these conditions, we can approximate $\partial f / \partial \mu$ by its diagonal, $\partial f / \partial \mu \approx I$. This in turn gives $\partial \mu / \partial x \approx -\partial f / \partial x$.

We can formally justify this approximation when the clusters are relatively balanced and well-separated. More precisely, define $c(j) = \arg\min_i \|x_j - \mu_i\|$ to be the closest cluster to point $j$. Proposition 1 (proved in the appendix) shows that the quality of the diagonal approximation improves exponentially quickly in the product of two terms: $\beta$, the hardness of the cluster assignments, and a term which measures how well separated the clusters are. A further quantity (defined below) measures the balance of the cluster sizes. We assume for convenience that the input is suitably scaled.
Suppose that for all points , for all and that for all clusters , . Moreover, suppose that . Then, where is the operator 1-norm.
We now show that the approximate gradient obtained by taking $\partial f / \partial \mu = I$ can be calculated by unrolling a single iteration of the forward-pass updates from Equation 2 at convergence. Examining Equation 3, we see that the first term ($\mu_i^\ell$, the $\ell$-th coordinate of center $i$) is constant with respect to $x$, since here $\mu$ is a fixed value. Hence,

$$-\frac{\partial f_{i,\ell}}{\partial x} = \frac{\partial}{\partial x} \frac{\sum_j r_{ji} x_j^\ell}{\sum_j r_{ji}},$$

which is just the update equation for $\mu_i$. Since the forward-pass updates are written entirely in terms of differentiable functions, we can automatically compute the approximate backward pass with respect to $x$ (i.e., compute products with our approximations to $\partial \mu / \partial x$ and $\partial r / \partial x$) by applying standard autodifferentiation tools to the final update of the forward pass. Compared to computing the exact analytical gradients, this avoids the need to explicitly reason about or invert $\partial f / \partial \mu$. The final iteration (the one which is differentiated through) requires time linear in the size of the data.
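A small numerical check can make the single-iteration approximation concrete. The sketch below is an illustration under several assumptions, not the paper's procedure: it uses one-dimensional points, absolute distance instead of cosine similarity, and finite differences in place of an autodifferentiation tool. The names `update` and `fixed_point`, the toy data, and the step size `eps` are all hypothetical. It compares the sensitivity of a converged center to one input, obtained by re-solving for the fixed point, against the sensitivity of only the final update with the incoming centers held fixed.

```python
import math

def update(x, mu, beta):
    """One soft K-means step for 1-D points: soft-assign, then recompute centers."""
    weights = []
    for xj in x:
        w = [math.exp(-beta * abs(xj - m)) for m in mu]
        z = sum(w)
        weights.append([wi / z for wi in w])
    new_mu = []
    for k in range(len(mu)):
        mass = sum(r[k] for r in weights)
        new_mu.append(sum(r[k] * xj for r, xj in zip(weights, x)) / mass)
    return new_mu

def fixed_point(x, mu0, beta, iters=200):
    """Iterate the update a fixed number of times (assumed enough to converge here)."""
    mu = list(mu0)
    for _ in range(iters):
        mu = update(x, mu, beta)
    return mu

x = [0.0, 0.2, 0.4, 5.0, 5.2]  # two well-separated groups
beta, eps = 10.0, 1e-6
mu_star = fixed_point(x, [0.0, 5.0], beta)

x_shift = [x[0] + eps] + x[1:]  # perturb the first point only

# Reference: sensitivity of the converged center 0, re-solving the fixed point.
exact = (fixed_point(x_shift, mu_star, beta)[0] - mu_star[0]) / eps

# Approximation: differentiate only the last update, centers held at mu_star.
approx = (update(x_shift, mu_star, beta)[0] - update(x, mu_star, beta)[0]) / eps
```

Because the clusters in this toy example are well separated and the assignments are nearly hard, both quantities come out close to 1/3, the weight of one point in a cluster of three.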
Compared to differentiating by unrolling the entire sequence of updates in the computational graph (as has been suggested for other problems [15, 3, 48]), our approach has two key advantages. First, it avoids storing the entire history of updates and backpropagating through all of them. The runtime for our approximation is independent of the number of updates needed to reach convergence. Second, we can in fact use entirely non-differentiable operations to arrive at the fixed point, e.g., heuristics for the $K$-means problem, stochastic methods which only examine subsets of the data, etc. This allows the forward pass to scale to larger datasets since we can use the best algorithmic tools available, not just those that can be explicitly encoded in the autodifferentiation tool's computational graph.

Having obtained the cluster assignments $r$, along with the centers $\mu$, in a differentiable manner, we need a way to (1) differentiably interpret the clustering as a soft solution to the optimization problem, (2) differentiate a relaxation of the objective value of the graph optimization problem in terms of that solution, and then (3) round to a discrete solution at test time. We give a generic means of accomplishing these three steps for two broad classes of problems: those that involve partitioning the graph into disjoint components, and those that involve selecting a subset of nodes.
Partitioning: (1) We can naturally interpret the cluster assignments $r$ as a soft partitioning of the graph. (2) One generic continuous objective function (defined on soft partitions) follows from the random process of assigning each node $j$ to a partition with probabilities given by $r_j$, repeating this process independently across all nodes. This gives the expected training decision loss $\mathbb{E}_{r^{\text{hard}} \sim r}[f(r^{\text{hard}}, A^{\text{train}})]$, where $r^{\text{hard}} \sim r$ denotes this random assignment. This loss is differentiable in terms of $r$, and can be computed in closed form via standard autodifferentiation tools for many problems of interest (see Section 5). (3) At test time, we simply apply a hard maximum to $r$ to obtain each node's assignment.

Subset selection: (1) Here, it is less obvious how to obtain a subset of $K$ nodes from the cluster assignments. Our continuous solution will be a vector $x \in [0, 1]^n$ with $\|x\|_1 \le K$. Intuitively, $x_j$ is the probability of including node $j$ in the solution. Our approach obtains $x$ by placing greater probability mass on nodes that are near the cluster centers. Specifically, each center is endowed with one unit of probability mass, which it allocates among the points. The total probability allocated to a node is the sum of its allocations over all centers. Since this total may exceed 1, we pass it through a sigmoid function with a tunable parameter to cap the entries at 1. If the resulting $x$ exceeds the budget constraint ($\|x\|_1 > K$), we rescale it so that the budget is met.

(2) We interpret this solution in terms of the objective similarly as above. Specifically, we consider the result of drawing a discrete solution where every node $j$ is included (i.e., set to 1) independently with probability $x_j$ from the end of step (1). The training objective is then the expected objective value of this random solution. For many problems, this can again be computed and differentiated through in closed form (see Section 5).
(3) At test time, we need a feasible discrete vector $x$; note that independently rounding the individual entries may produce a vector with more than $K$ ones. Here, we apply a fairly generic approach based on pipage rounding [1], a randomized rounding scheme which has been applied to many problems (particularly those with submodular objectives). Pipage rounding can be implemented efficiently to produce a random feasible solution [23]; in practice we round several times and take the solution with the best decision loss on the observed edges. While pipage rounding has theoretical guarantees only for specific classes of functions, we find it to work well even in other domains (e.g., facility location). However, more domain-specific rounding methods can be applied if available.
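To illustrate the rounding step, here is a minimal sketch of randomized pipage rounding in plain Python. It is a simple textbook-style variant written for clarity, not the efficient implementation cited as [23], and the function name `pipage_round`, the tolerance `eps`, and the example vector are assumptions for this illustration. At each step it picks two fractional entries and shifts mass between them so that at least one becomes integral, choosing the direction at random so that each entry's expected value is unchanged and the total is preserved.

```python
import random

def pipage_round(x, rng, eps=1e-9):
    """Round a fractional vector with integer sum to a 0/1 vector with the same sum."""
    x = list(x)

    def fractional():
        return [i for i, v in enumerate(x) if eps < v < 1 - eps]

    frac = fractional()
    while len(frac) >= 2:
        i, j = frac[0], frac[1]
        a = min(1 - x[i], x[j])  # largest mass that can move from j to i
        b = min(x[i], 1 - x[j])  # largest mass that can move from i to j
        # Moving a with probability b / (a + b), else moving b the other way,
        # keeps the expected value of both entries unchanged.
        if rng.random() < b / (a + b):
            x[i] += a
            x[j] -= a
        else:
            x[i] -= b
            x[j] += b
        frac = fractional()
    return [int(round(v)) for v in x]

# Example: marginals summing to a budget of 3 selected nodes.
marginals = [0.5, 0.5, 0.25, 0.75, 0.6, 0.4]
selection = pipage_round(marginals, random.Random(0))
```

In the spirit of the text, one could call `pipage_round` several times with different seeds and keep the selection with the best objective value on the observed edges.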
Table 1: Community detection results (modularity on the full graph; higher is better).

| | Learning + optimization | | | | | Optimization | | | | |
|---|---|---|---|---|---|---|---|---|---|---|
| Method | cora | cite. | prot. | adol | fb | cora | cite. | prot. | adol | fb |
| ClusterNet | 0.54 | 0.55 | 0.29 | 0.49 | 0.30 | 0.72 | 0.73 | 0.52 | 0.58 | 0.76 |
| GCN-e2e | 0.16 | 0.02 | 0.13 | 0.12 | 0.13 | 0.19 | 0.03 | 0.16 | 0.20 | 0.23 |
| Train-CNM | 0.20 | 0.42 | 0.09 | 0.01 | 0.14 | 0.08 | 0.34 | 0.05 | 0.57 | 0.77 |
| Train-Newman | 0.09 | 0.15 | 0.15 | 0.15 | 0.08 | 0.20 | 0.23 | 0.29 | 0.30 | 0.55 |
| Train-SC | 0.03 | 0.02 | 0.03 | 0.23 | 0.19 | 0.09 | 0.05 | 0.06 | 0.49 | 0.61 |
| GCN-2stage-CNM | 0.17 | 0.21 | 0.18 | 0.28 | 0.13 | - | - | - | - | - |
| GCN-2stage-Newman | 0.00 | 0.00 | 0.00 | 0.14 | 0.02 | - | - | - | - | - |
| GCN-2stage-SC | 0.14 | 0.16 | 0.04 | 0.31 | 0.25 | - | - | - | - | - |
Table 2: Facility location results (maximum distance from any node to a facility on the full graph; lower is better).

| | Learning + optimization | | | | | Optimization | | | | |
|---|---|---|---|---|---|---|---|---|---|---|
| Method | cora | cite. | prot. | adol | fb | cora | cite. | prot. | adol | fb |
| ClusterNet | 10 | 14 | 6 | 6 | 4 | 9 | 14 | 6 | 5 | 3 |
| GCN-e2e | 12 | 15 | 8 | 6 | 5 | 11 | 14 | 7 | 6 | 5 |
| Train-greedy | 14 | 16 | 8 | 8 | 6 | 9 | 14 | 7 | 6 | 5 |
| Train-gonzalez | 12 | 17 | 8 | 6 | 6 | 10 | 15 | 7 | 7 | 3 |
| GCN-2Stage-greedy | 14 | 17 | 8 | 7 | 6 | - | - | - | - | - |
| GCN-2Stage-gonzalez | 13 | 17 | 8 | 6 | 6 | - | - | - | - | - |
We now show experiments on domains that combine link prediction with optimization.
Learning problem: In link prediction, we observe a partial graph and aim to infer which unobserved edges are present. In each of the experiments, we hold out 60% of the edges in the graph, with 40% observed during training. We used a graph dataset which is not included in our results to set our method's hyperparameters, which were kept constant across datasets (see appendix for details). The learning task is to use the training edges to predict whether the remaining edges are present, after which we will solve an optimization problem on the predicted graph. The objective is to find a solution with high objective value measured on the entire graph, not just the training edges.
Optimization problems: We consider two optimization tasks, one from each of the broad classes introduced above. First, community detection aims to partition the nodes of the graph into $K$ distinct subgroups which are dense internally, but with few edges across groups. Formally, the objective is to find a partition maximizing the modularity [33], defined as $Q(r) = \frac{1}{2m} \sum_{u, v \in V} \sum_{k=1}^{K} \left[ A_{uv} - \frac{d_u d_v}{2m} \right] r_{uk} r_{vk}$. Here, $m$ is the number of edges, $d_v$ is the degree of node $v$, and $r_{vk}$ is 1 if node $v$ is assigned to community $k$ and zero otherwise. This measures the number of edges within communities compared to the expected number if edges were placed randomly. Our clustering module has one cluster for each of the $K$ communities. Defining $B$ to be the modularity matrix with entries $B_{uv} = A_{uv} - \frac{d_u d_v}{2m}$, our training objective (the expected value of a partition sampled according to $r$) is $\frac{1}{2m} \mathrm{Tr}\left[ r^\top B^{\text{train}} r \right]$.
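As a concrete illustration of the modularity objective, the sketch below evaluates it for both hard and soft assignments on a toy graph. It assumes the standard modularity matrix with entries equal to the adjacency entry minus the product of the two degrees divided by twice the number of edges; the function name `modularity` and the example graph (two triangles joined by one edge) are assumptions for this illustration, and the code is a direct dense evaluation rather than the released implementation.

```python
def modularity(A, r):
    """Evaluate (1 / 2m) * sum_{u,v} B[u][v] * sum_k r[u][k] * r[v][k]."""
    n, K = len(A), len(r[0])
    deg = [sum(row) for row in A]
    two_m = sum(deg)  # twice the number of edges
    total = 0.0
    for u in range(n):
        for v in range(n):
            b_uv = A[u][v] - deg[u] * deg[v] / two_m  # modularity matrix entry
            total += b_uv * sum(r[u][k] * r[v][k] for k in range(K))
    return total / two_m

# Two triangles (0,1,2) and (3,4,5) joined by the edge 2-3.
edges = [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]
A = [[0] * 6 for _ in range(6)]
for u, v in edges:
    A[u][v] = A[v][u] = 1

hard = [[1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1]]  # one community per triangle
uniform = [[0.5, 0.5] for _ in range(6)]                 # uninformative soft assignment

q_hard = modularity(A, hard)
q_uniform = modularity(A, uniform)
```

The hard partition into the two triangles scores 5/14, while the uniform soft assignment scores zero because the entries of the modularity matrix sum to zero.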
Second, facility location, where the decision problem is to select a subset of nodes from the graph, minimizing the maximum distance from any node to a facility (selected node). Letting be the shortest path length from a vertex to a set of vertices , the objective is . To obtain the training loss, we take two steps. First, we replace by , where denotes drawing a set from the product distribution with marginals . This can easily be calculated in closed form [23]. Second, we replace the with a softmin.
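The discrete facility location objective described above (the largest shortest-path distance from any node to its nearest selected node) can be evaluated with a multi-source breadth-first search. The sketch below is illustrative only: the adjacency-list representation and the names `distances_to_set` and `facility_objective` are assumptions, edges are treated as unweighted, and the graph is assumed to be connected, in line with the use of the largest connected component in the experiments.

```python
from collections import deque

def distances_to_set(adj, facilities):
    """Hop distance from every node to its nearest facility (multi-source BFS)."""
    dist = {v: None for v in adj}
    queue = deque()
    for s in facilities:
        dist[s] = 0
        queue.append(s)
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if dist[w] is None:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist

def facility_objective(adj, facilities):
    """Maximum distance from any node to the facility set (graph assumed connected)."""
    return max(distances_to_set(adj, facilities).values())

# Path graph 0 - 1 - 2 - 3 - 4.
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
center_only = facility_objective(path, [2])
endpoint_only = facility_objective(path, [0])
two_facilities = facility_objective(path, [1, 3])
```

On this path graph, placing a single facility in the middle halves the worst-case distance compared with placing it at an endpoint, and two well-spread facilities reduce it further.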
Baseline learning methods: We instantiate ClusterNet using a 2-layer GCN for node embeddings, followed by a clustering layer. We compare to three families of baselines. First, GCN-2stage, the two-stage approach which first trains a model for link prediction, and then inputs the predicted graph into an optimization algorithm. For link prediction, we use the GCN-based system of [35] (we also adopt their training procedure, including negative sampling and edge dropout). For the optimization algorithms, we use standard approaches for each domain, outlined below. Second, “train”, which runs each optimization algorithm only on the observed training subgraph (without attempting any link prediction). Third, GCN-e2e, an end-to-end approach which does not include explicit algorithmic structure. We train a GCN-based network to directly predict the final decision variable ($r$ or $x$) using the same training objectives as our own model. Empirically, we observed the best performance with a 2-layer GCN. This baseline allows us to isolate the benefits of including algorithmic structure.
Baseline optimization approaches: In each domain, we compare to expert-designed optimization algorithms found in the literature. In community detection, we compare to “CNM” [9], an agglomerative approach, “Newman”, an approach that recursively partitions the graph [32], and “SC”, which performs spectral clustering [41] on the modularity matrix. In facility location, we compare to “greedy”, the common heuristic of iteratively selecting the point with greatest marginal improvement in objective value, and “gonzalez” [17], an algorithm which iteratively selects the node furthest from the current set (which attains the optimal 2-approximation).

Datasets: We use several standard graph datasets: cora [36] (a citation network with 2,708 nodes), citeseer [36] (a citation network with 3,327 nodes), protein [12] (a protein interaction network with 3,133 nodes), adol [10] (an adolescent social network with 2,539 vertices), and fb [11, 29] (an online social network with 2,888 nodes). For facility location, we use the largest connected component of the graph (since otherwise distances may be infinite). Cora and citeseer have node features (based on a bag-of-words representation of the document), which were given to all GCN-based methods. For the other datasets, we generated unsupervised node2vec features [20] using the training edges.
We start out with results for the combined link prediction and optimization problem. Table 1 shows the objective value obtained by each approach on the full graph for community detection, with Table 2 showing facility location. We focus first on the “Learning + optimization" column which shows the combined link prediction/optimization task. We use ; is very similar and may be found in the appendix. ClusterNet outperforms the baselines in nearly all cases, often substantially. GCN-e2e learns to produce nontrivial solutions, often rivaling the other baseline methods. However, the explicit structure used by our approach ClusterNet results in much higher performance.
Interestingly, the two stage approach sometimes performs worse than the train-only baseline which optimizes just based on the training edges (without attempting to learn). This indicates that approaches which attempt to accurately reconstruct the graph can sometimes miss qualities which are important for optimization, and in the worst case may simply add noise that overwhelms the signal in the training edges. In order to confirm that the two-stage method learned to make meaningful predictions, in the appendix we give AUC values for each dataset. The average AUC value is 0.7584, indicating that the two stage model does learn to make nontrivial predictions. However, the small amount of training data (only 40% of edges are observed) prevents it from perfectly reconstructing the true graph. This drives home the point that decision-focused learning methods can offer substantial benefits when highly accurate predictions are out of reach even for sophisticated methods.
We next examine an optimization-only task where the entire graph is available as input (the “Optimization” column of Tables 1 and 2). This tests ClusterNet's ability to learn to solve combinatorial optimization problems compared to expert-designed algorithms, even when there is no partial information or learning problem in play. We find that ClusterNet is highly competitive, meeting and frequently exceeding the baselines. It is particularly effective for community detection, where we observe large (> 3x) improvements compared to the best baseline on some datasets. At facility location, our method always at least ties the baselines, and frequently improves on them. These experiments provide evidence that our approach, which is automatically specialized during training to optimize on a given graph, can rival and exceed hand-designed algorithms from the literature. The alternate learning approach, GCN-e2e, at best ties the baselines and typically underperforms. This underscores the benefit of including algorithmic structure as a part of the end-to-end architecture.
Table 3: Results on held-out test graphs from the synthetic and pubmed distributions, without and with fine-tuning.

| Community detection | synthetic Avg. | synthetic % | pubmed Avg. | pubmed % |
|---|---|---|---|---|
| No finetune | | | | |
| ClusterNet | 0.57 | 26/30 | 0.30 | 7/8 |
| GCN-e2e | 0.26 | 0/30 | 0.01 | 0/8 |
| Train-CNM | 0.14 | 0/30 | 0.16 | 1/8 |
| Train-Newman | 0.24 | 0/30 | 0.17 | 0/8 |
| Train-SC | 0.16 | 0/30 | 0.04 | 0/8 |
| 2Stage-CNM | 0.51 | 0/30 | 0.24 | 0/8 |
| 2Stage-Newman | 0.01 | 0/30 | 0.01 | 0/8 |
| 2Stage-SC | 0.52 | 4/30 | 0.15 | 0/8 |
| ClstrNet-1train | 0.55 | 0/30 | 0.25 | 0/8 |
| Finetune | | | | |
| ClstrNet-ft | 0.60 | 20/30 | 0.40 | 2/8 |
| ClstrNet-ft-only | 0.60 | 10/30 | 0.42 | 6/8 |

| Facility location | synthetic Avg. | synthetic % | pubmed Avg. | pubmed % |
|---|---|---|---|---|
| No finetune | | | | |
| ClusterNet | 7.90 | 25/30 | 7.88 | 3/8 |
| GCN-e2e | 8.63 | 11/30 | 8.62 | 1/8 |
| Train-greedy | 14.00 | 0/30 | 9.50 | 1/8 |
| Train-gonzalez | 10.30 | 2/30 | 9.38 | 1/8 |
| 2Stage-greedy | 9.60 | 3/30 | 10.00 | 0/8 |
| 2Stage-gonz. | 10.00 | 2/30 | 6.88 | 5/8 |
| ClstrNet-1train | 7.93 | 12/30 | 7.88 | 2/8 |
| Finetune | | | | |
| ClstrNet-ft | 8.08 | 12/30 | 8.01 | 3/8 |
| ClstrNet-ft-only | 7.84 | 16/30 | 7.76 | 4/8 |
Next, we investigate whether our method can learn generalizable strategies for optimization: can we train the model on one set of graphs drawn from some distribution and then apply it to unseen graphs? We consider two graph distributions. First, a synthetic generator introduced by [43], which is based on the spatial preferential attachment model [5] (details in the appendix). We use 20 training graphs, 10 validation, and 30 test. Second, a dataset obtained by splitting the pubmed graph into 20 components using metis [24]. We fix 10 training graphs, 2 validation, and 8 test. At test time, only 40% of the edges in each graph are revealed, matching the “Learning + optimization" setup above.
Table 3 shows the results. To start out, we do not conduct any fine-tuning to the test graphs, evaluating entirely the generalizability of the learned representations. ClusterNet outperforms all baseline methods on all tasks, except for facility location on pubmed where it places second. We conclude that the learned model successfully generalizes to completely unseen graphs. We next investigate (in the “finetune" section of Table 3) whether ClusterNet’s performance can be further improved by fine-tuning to the 40% of observed edges for each test graph (treating each test graph as an instance of the link prediction problem from Section 5.1, but with the model initialized to the parameters of the model learned over the training graphs). We see that ClusterNet’s performance typically improves, indicating that fine-tuning can allow us to extract additional gains if extra training time is available.
Interestingly, only fine-tuning (not using the training graphs at all) yields similar performance (the row “ClstrNet-ft-only"). While our earlier results show that ClusterNet can learn generalizable strategies, doing so may not be necessary when there is the opportunity to fine-tune. This allows a user to trade off between quality and runtime: without fine-tuning, applying our method at test time requires just a single forward pass through the network, which is extremely efficient. If additional computational cost at test time is acceptable, fine-tuning can be used to improve performance. Complete runtimes for all methods are shown in the appendix. ClusterNet’s forward pass (i.e., no fine-tuning) is extremely efficient, requiring at most 0.23 seconds on the largest network, and is always faster than the baselines (on identical hardware). Fine-tuning requires longer, on par with the slowest baseline.
We lastly investigate the reason why pretraining provides little to no improvement over only fine-tuning. Essentially, we find that ClusterNet is extremely sample-efficient: using only a single training graph results in nearly as good performance as the full training set (and still better than all of the baselines), as seen in the “ClstrNet-1train" row of Table 3. That is, ClusterNet is capable of learning optimization strategies that generalize with strong performance to completely unseen graphs after observing only a single training example. This underscores the benefits of including algorithmic structure as a part of the architecture, which guides the model towards learning meaningful strategies.
Neural expectation maximization. In NeurIPS, 2017.
Finding community structure in networks using the eigenvectors of matrices. Physical Review E, 74(3):036104, 2006.
Unsupervised deep embedding for clustering analysis. In International Conference on Machine Learning, pages 478–487, 2016.
Conditional random fields as recurrent neural networks. In Proceedings of the IEEE International Conference on Computer Vision, pages 1529–1537, 2015.