Pointer Graph Networks

06/11/2020 · by Petar Veličković, et al. · Google

Graph neural networks (GNNs) are typically applied to static graphs that are assumed to be known upfront. This static input structure is often informed purely by the insight of the machine learning practitioner, and might not be optimal for the actual task the GNN is solving. In the absence of reliable domain expertise, one might resort to inferring the latent graph structure, which is often difficult due to the vast search space of possible graphs. Here we introduce Pointer Graph Networks (PGNs) which augment sets or graphs with additional inferred edges for improved model expressivity. PGNs allow each node to dynamically point to another node, followed by message passing over these pointers. The sparsity of this adaptable graph structure makes learning tractable while still being sufficiently expressive to simulate complex algorithms. Critically, the pointing mechanism is directly supervised to model long-term sequences of operations on classical data structures, incorporating useful structural inductive biases from theoretical computer science. Qualitatively, we demonstrate that PGNs can learn parallelisable variants of pointer-based data structures, namely disjoint set unions and link/cut trees. PGNs generalise out-of-distribution to 5x larger test inputs on dynamic graph connectivity tasks, outperforming unrestricted GNNs and Deep Sets.


1 Introduction

Graph neural networks (GNNs) have seen recent successes in applications such as quantum chemistry gilmer2017neural , social networks qiu2018deepinf and physics simulations battaglia2016interaction ; kipf2018neural ; sanchez2020learning . For problems where a graph structure is known (or can be approximated), GNNs thrive. This places a burden upon the practitioner: which graph structure should be used? In many applications, particularly with few nodes, fully connected graphs work well, but on larger problems, sparsely connected graphs are required. As the complexity of the task imposed on the GNN increases and, separately, the number of nodes increases, not allowing the choice of graph structure to be data-driven limits the applicability of GNNs.

Classical algorithms cormen2009introduction span computations that can be substantially more expressive than typical machine learning subroutines (e.g. matrix multiplications), making them hard to learn, and a good benchmark for GNNs chen2020can ; dwivedi2020benchmarking . Prior work has explored learning primitive algorithms (e.g. arithmetic) by RNNs zaremba2014learning ; kaiser2015neural ; trask2018neural , neural approximations to NP-hard problems vinyals2015pointer ; kool2018attention , making GNNs learn (and transfer between) graph algorithms velivckovic2019neural ; georgiev2020neural , recently recovering a single neural core yan2020neural capable of sorting, path-finding and binary addition. Here, we propose Pointer Graph Networks (PGNs), a framework that further expands the space of general-purpose algorithms that can be neurally executed.

Ideally, the graph structure underlying GNNs would be fully learnt from data, but the number of possible graphs grows super-exponentially in the number of nodes stanley2009acyclic , making searching over this space a challenging (and interesting) problem. In addition, applying GNNs further necessitates learning the messages passed on top of the learnt structure, making it hard to disambiguate errors from having either the wrong structure or the wrong message. Several approaches have been proposed for searching over the graph structure li2018learning ; grover2018graphite ; kipf2018neural ; kazi2020differentiable ; wang2019dynamic ; franceschi2019learning ; serviansky2020set2graph ; jiang2019semi ; liu2020nonlocal . Our PGNs take a hybrid approach, assuming that the practitioner may specify part of the graph structure, and then adaptively learn a linear number of pointer edges between nodes (as in vinyals2015pointer for RNNs). The pointers are optimised through direct supervision on classical data structures cormen2009introduction . We empirically demonstrate that PGNs further increase GNN expressivity beyond those with static graph structures garg2020generalization , without sacrificing computational cost or sparsity for this added flexibility in graph structure.

Unlike prior work on algorithm learning with GNNs xu2019can ; velivckovic2019neural ; yan2020neural , we consider algorithms that do not directly align to dynamic programming (making them inherently non-local xu2018powerful ) and, crucially, the optimal known algorithms rely upon pointer-based data structures. The pointer connectivity of these structures dynamically changes as the algorithm executes. We learn algorithms that operate on two distinct data structures—disjoint set unions galler1964improved and link/cut trees sleator1983data . We show that baseline GNNs are unable to learn the complicated, data-driven manipulations that they perform and, through PGNs, show that extending GNNs with learnable dynamic pointer links enables such modelling.

Finally, the hallmark of having learnt an algorithm well, and the purpose of an algorithm in general, is that it may be applied to a wide range of input sizes. Thus, by learning these algorithms, we are able to demonstrate generalisation far beyond the sizes of the instances included in the training set.

Our PGN work presents three main contributions: we expand neural algorithm execution xu2019can ; velivckovic2019neural ; yan2020neural to handle algorithms relying on complicated data structures; we provide a novel supervised method for sparse and efficient latent graph inference; and we demonstrate that our PGN model can deviate from the structure it is imitating, to produce useful and parallelisable data structures.

2 Problem setup and PGN architecture

Figure 1: High-level overview of the pointer graph network (PGN) dataflow. Using descriptions of entity operations ($\vec{e}_i^{(t)}$), the PGN re-estimates latent features $\vec{h}_i^{(t)}$, masks $\mu_i^{(t)}$, and (asymmetric) pointers $\widetilde{\Pi}^{(t)}$. The symmetrised pointers, $\Pi^{(t)}$, are then used as edges for a GNN that computes next-step latents, $\vec{h}_i^{(t+1)}$, continuing the process. The latents may be used to provide answers, $\vec{y}^{(t)}$, to queries about the underlying data. We highlight masked out nodes in red, and modified pointers and latents in blue. See Appendix A for a higher-level visualisation, along with PGN's gradient flow.

Problem setup

We consider the following sequential supervised learning setting: assume an underlying set of $n$ entities. Given is a sequence of inputs $\mathcal{E}^{(1)}, \mathcal{E}^{(2)}, \dots$, where each $\mathcal{E}^{(t)}$ (for $t \geq 1$) is defined by a list of feature vectors $\vec{e}_i^{(t)} \in \mathbb{R}^m$, one for every entity $i$. We will suggestively refer to $\vec{e}_i^{(t)}$ as an operation on entity $i$ at time $t$. The task then consists in predicting target outputs $\vec{y}^{(t)}$ based on the operation sequence up to time $t$, i.e. $\mathcal{E}^{(1)}, \dots, \mathcal{E}^{(t)}$.

A canonical example for this setting is characterising graphs with dynamic connectivity, where the inputs $\vec{e}_i^{(t)}$ indicate edges being added to or removed from a graph at time $t$, and the target outputs $\vec{y}^{(t)}$ are binary indicators of whether pairs of vertices are connected. We describe this problem in depth in Section 3.

Pointer Graph Networks

As the above sequential prediction task is defined on the underlying, un-ordered set of entities, any generalising prediction model is required to be invariant under permutations of the entity set vinyals2015order ; murphy2018janossy ; zaheer2017deep . Furthermore, successfully predicting the targets $\vec{y}^{(t)}$ in general requires the prediction model to maintain a robust data structure to represent the history of operations for all entities throughout their lifetime. In the following we present our proposed prediction model, the Pointer Graph Network (PGN), which combines these desiderata in an efficient way.

At each step $t$, our PGN model computes latent features $\vec{h}_i^{(t)}$ for each entity $i$. Initially, $\vec{h}_i^{(0)} = \vec{0}$. Further, the PGN model determines dynamic pointers—one per entity and time step (chosen to match the semantics of C/C++ pointers: a pointer of a particular type has exactly one endpoint)—which may be summarised in a pointer adjacency matrix $\Pi^{(t)} \in \{0, 1\}^{n \times n}$. Pointers correspond to undirected edges between two entities: $\Pi_{ij}^{(t)} = 1$ indicates that one of them points to the other. $\Pi^{(t)}$ is a binary symmetric matrix, indicating locations of pointers as 1-entries. Initially, we assume each node points to itself: $\Pi^{(0)} = I_n$. A summary of the coming discussion may be found in Figure 1.

The Pointer Graph Network follows closely the encode-process-decode hamrick2018relational paradigm. The current operation $\vec{e}_i^{(t)}$ is encoded together with the latents $\vec{h}_i^{(t-1)}$ in each entity using an encoder network $f$:

$\vec{z}_i^{(t)} = f\left(\vec{e}_i^{(t)}, \vec{h}_i^{(t-1)}\right)$ (1)

after which the derived entity representations $\mathbf{Z}^{(t)} = \{\vec{z}_i^{(t)}\}_{i=1}^{n}$ are given to a processor network, $P$, which takes into account the current pointer adjacency matrix $\Pi^{(t-1)}$ as relational information:

$\mathbf{H}^{(t)} = P\left(\mathbf{Z}^{(t)}, \Pi^{(t-1)}\right)$ (2)

yielding next-step latent features, $\vec{h}_i^{(t)}$; we discuss choices of $P$ below. These latents can be used to answer set-level queries using a decoder network $g$:

$\vec{y}^{(t)} = g\left(\bigoplus_{i} \left(\vec{z}_i^{(t)} \,\Vert\, \vec{h}_i^{(t)}\right)\right)$ (3)

where $\bigoplus$ is any permutation-invariant readout aggregator, such as summation or maximisation, and $\Vert$ denotes concatenation.

Many efficient data structures only modify a small (typically on the order of $\log n$ elements) subset of the entities at once cormen2009introduction . We can incorporate this inductive bias into PGNs by masking their pointer modifications through a sparse mask $\mu_i^{(t)} \in \{0, 1\}$ for each node, generated by a masking network $\psi$:

$\mu_i^{(t)} = \psi\left(\vec{z}_i^{(t)}, \vec{h}_i^{(t)}\right)$ (4)

where the output activation function for $\psi$ is the logistic sigmoid, enforcing a probabilistic interpretation. In practice, we threshold the output of $\psi$ as follows:

$\mu_i^{(t)} = \mathbb{1}_{\psi\left(\vec{z}_i^{(t)}, \vec{h}_i^{(t)}\right) > 0.5}$ (5)

The PGN now re-estimates the pointer adjacency matrix from the current latents. To do this, we leverage self-attention vaswani2017attention , computing all-pairs dot products between queries $\vec{q}_i^{(t)}$ and keys $\vec{k}_j^{(t)}$:

$\alpha_{ij}^{(t)} = \underset{j}{\mathrm{softmax}}\left(\left\langle \vec{q}_i^{(t)}, \vec{k}_j^{(t)} \right\rangle\right)$ (6)

where the queries and keys are produced by learnable linear transformations $\mathbf{W}_q$ and $\mathbf{W}_k$ of the entity latents, and $\langle \cdot, \cdot \rangle$ is the dot product operator. $\alpha_{ij}^{(t)}$ indicates the relevance of entity $j$ to entity $i$, and we derive the pointer for $i$ by choosing the entity $j$ with maximal $\alpha_{ij}^{(t)}$. To simplify the dataflow, we found it beneficial to symmetrise this matrix:

$\Pi_{ij}^{(t)} = \left(1 - \mu_i^{(t)}\right) \Pi_{ij}^{(t-1)} + \mu_i^{(t)} \left(\widetilde{\Pi}_{ij}^{(t)} \vee \widetilde{\Pi}_{ji}^{(t)}\right), \qquad \widetilde{\Pi}_{ij}^{(t)} = \mathbb{1}_{j = \arg\max_k \alpha_{ik}^{(t)}}$ (7)

where $\mathbb{1}$ is the indicator function, $\widetilde{\Pi}^{(t)}$ denotes the pointers before symmetrisation, $\vee$ denotes logical disjunction between the two operands, and $1 - \mu_i^{(t)}$ corresponds to negating the mask. Nodes $i$ and $j$ will be linked together in $\Pi^{(t)}$ (i.e., $\Pi_{ij}^{(t)} = 1$) if $j$ is the most relevant to $i$, or vice-versa.

Unlike prior work which relied on the Gumbel trick kazi2020differentiable ; kipf2018neural , we will provide direct supervision with respect to ground-truth pointers, $\hat{\Pi}^{(t)}$, of a target data structure. Applying the mask $\mu_i^{(t)}$ effectively masks out parts of the computation graph for Equation 6, yielding a graph attention network-style update velickovic2017gat . Further, our data-driven conditional masking is reminiscent of neural execution engines yan2020neural . While Equation 6 involves computation of $O(n^2)$ coefficients, this constraint exists only at training time; at test time, computing the entries of $\widetilde{\Pi}^{(t)}$ reduces to 1-NN queries in key/query space, which can be implemented storage-efficiently pritzel2017neural . The attention mechanism may also be sparsified, as in kitaev2020reformer .
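To make the pointer re-estimation concrete, here is a minimal NumPy sketch of Equations 4–7 on random data. The 0.5 mask threshold, the argmax-based pointer choice and the row-wise symmetrisation follow the description above; the random weight matrices, the choice of feeding the latents to the attention, and all variable names are illustrative assumptions rather than the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 8                        # entities, latent dimension
Z = rng.normal(size=(n, d))        # encoder outputs z_i^(t)
H = rng.normal(size=(n, d))        # processor outputs h_i^(t)
Pi_prev = np.eye(n, dtype=bool)    # previous pointers Pi^(t-1) (initially: self-pointers)

# Illustrative stand-ins for the learnt linear layers psi, W_q and W_k.
w_psi = rng.normal(size=2 * d)
W_q, W_k = rng.normal(size=(d, d)), rng.normal(size=(d, d))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Equations 4-5: the masking network decides which nodes may rewrite their pointer.
mu = sigmoid(np.concatenate([Z, H], axis=-1) @ w_psi) > 0.5          # (n,) boolean

# Equation 6: all-pairs attention coefficients between queries and keys.
q, k = H @ W_q, H @ W_k
logits = q @ k.T
alpha = np.exp(logits - logits.max(-1, keepdims=True))
alpha /= alpha.sum(-1, keepdims=True)

# Equation 7: every node proposes a pointer to its most relevant entity; the
# proposal matrix is symmetrised (i -> j or j -> i suffices), and only nodes
# flagged by the mask actually update; the rest keep their previous pointers.
Pi_tilde = np.zeros((n, n), dtype=bool)
Pi_tilde[np.arange(n), alpha.argmax(-1)] = True
Pi = np.where(mu[:, None], Pi_tilde | Pi_tilde.T, Pi_prev)

print(Pi.astype(int))
```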

PGN components and optimisation

In our implementation, the encoder, decoder, masking and key/query networks are all linear layers of appropriate dimensionality—placing most of the computational burden on the processor, $P$, which explicitly leverages the computed pointer information.

In practice, $P$ is realised as a graph neural network (GNN), operating over the edges specified by $\Pi^{(t-1)}$. If an additional input graph between entities is provided upfront, then its edges may also be included, or even serve as a completely separate “head” of GNN computation.

Echoing the results of prior work on algorithmic modelling with GNNs velivckovic2019neural , we recovered the strongest performance when using message passing neural networks (MPNNs) gilmer2017neural for $P$, with an elementwise maximisation aggregator. Hence, the computation of Equation 2 is realised as follows:

$\vec{h}_i^{(t)} = U\left(\vec{z}_i^{(t)}, \max_{j \,:\, \Pi_{ij}^{(t-1)} = 1} M\left(\vec{z}_i^{(t)}, \vec{z}_j^{(t)}\right)\right)$ (8)

where $M$ and $U$ are linear layers producing vector messages, and the maximisation is performed elementwise. Accordingly, we found that elementwise-max was the best readout operation for $\bigoplus$ in Equation 3; while other aggregators (e.g. sum) performed comparably, maximisation had the least variance in the final result. This is in line with its alignment to the algorithmic task, as previously studied richter2020normalized ; xu2019can . We apply ReLU to the outputs of $M$ and $U$.
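As an illustration of Equation 8, here is a small (untrained) NumPy sketch of one message-passing step over the pointer edges, with elementwise-max aggregation over each node's pointer neighbourhood; the random matrices stand in for the learnt linear layers M and U, and the ReLU on the outputs matches the description above.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 6, 8
Z = rng.normal(size=(n, d))         # current entity representations z_i^(t)
Pi = np.eye(n, dtype=bool)          # pointer adjacency Pi^(t-1); self-pointers guarantee
                                    # every node has at least one neighbour

W_m = rng.normal(size=(2 * d, d))   # stand-in for the message layer M
W_u = rng.normal(size=(2 * d, d))   # stand-in for the update layer U

def mpnn_step(Z, Pi):
    H = np.empty_like(Z)
    for i in range(Z.shape[0]):
        nbrs = np.flatnonzero(Pi[:, i] | Pi[i, :])        # pointer neighbourhood of i
        pairs = np.concatenate(
            [np.broadcast_to(Z[i], (len(nbrs), Z.shape[1])), Z[nbrs]], axis=-1)
        msgs = np.maximum(pairs @ W_m, 0.0)               # ReLU(M(z_i, z_j))
        H[i] = np.maximum(
            np.concatenate([Z[i], msgs.max(axis=0)]) @ W_u, 0.0)  # ReLU(U(z_i, max_j ...))
    return H

print(mpnn_step(Z, Pi).shape)       # (6, 8): next-step latents h_i^(t)
```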

Besides the downstream query loss on $\vec{y}^{(t)}$ (Equation 3), PGNs optimise two additional losses, using ground-truth information from the data structure they are imitating: cross-entropy of the attentional coefficients (Equation 6) against the ground-truth pointers, $\hat{\Pi}^{(t)}$, and binary cross-entropy of the masking network (Equation 4) against the ground-truth entities being modified at time $t$. This provides a mechanism for introducing domain knowledge into the learning process. At training time we readily apply teacher forcing, feeding ground-truth pointers and masks as inputs whenever appropriate. We allow gradients to flow from these auxiliary losses back in time through the latent states $\vec{h}_i^{(t)}$ and $\vec{z}_i^{(t)}$ (see Appendix A for a diagram of the backward pass of our model).
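A minimal sketch of how the three per-step terms could be combined, assuming the soft query predictions, the attention matrix from Equation 6 and the soft mask outputs are already available, and that the ground-truth pointers and masks come from the imitated data structure. The equal weighting of the three terms is our assumption, not a reported configuration.

```python
import numpy as np

def pgn_step_loss(y_pred, y_true, alpha, ptr_true, mask_pred, mask_true, eps=1e-9):
    """Query BCE + pointer cross-entropy (Eq. 6 vs. ground truth) + mask BCE (Eq. 4 vs. ground truth)."""
    query_loss = -np.mean(y_true * np.log(y_pred + eps)
                          + (1.0 - y_true) * np.log(1.0 - y_pred + eps))
    # ptr_true[i] is the index of node i's ground-truth pointer target; alpha is row-stochastic.
    pointer_loss = -np.mean(np.log(alpha[np.arange(len(ptr_true)), ptr_true] + eps))
    mask_loss = -np.mean(mask_true * np.log(mask_pred + eps)
                         + (1.0 - mask_true) * np.log(1.0 - mask_pred + eps))
    return query_loss + pointer_loss + mask_loss      # assumed unit weight on each term

rng = np.random.default_rng(2)
alpha = rng.dirichlet(np.ones(4), size=4)             # toy attention matrix over 4 nodes
print(pgn_step_loss(np.array([0.8]), np.array([1.0]), alpha,
                    np.array([0, 0, 3, 2]), rng.uniform(size=4),
                    np.array([1.0, 0.0, 0.0, 1.0])))
```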

PGNs share similarities with and build on prior work on using latent $k$-NN graphs franceschi2019learning ; kazi2020differentiable ; wang2019dynamic , primarily through the addition of pointer-based losses against a ground-truth data structure and explicit entity masking—which will prove critical to generalising out-of-distribution in our experiments.

3 Task: Dynamic graph connectivity

We focus on instances of the dynamic graph connectivity setup to illustrate the benefits of PGNs. Even the simplest of structural detection tasks are known to be very challenging for GNNs chen2020can , motivating dynamic connectivity as one example of a task where GNNs are unlikely to perform optimally.

Dynamic connectivity querying is an important subroutine in computer science, e.g. when computing minimum spanning trees—deciding if an edge can be included in the solution without inducing cycles kruskal1956shortest , or maximum flow—detecting existence of paths from source to sink with available capacity dinic1970algorithm .

Formally, we consider undirected and unweighted graphs of $n$ nodes, with evolving edge sets through time; we denote the graph at time $t$ by $G^{(t)} = (V, E^{(t)})$. Initially, we assume the graphs to be completely disconnected: $E^{(0)} = \emptyset$. At each step, an edge $(u, v)$ may be added to or removed from $E^{(t)}$, yielding $E^{(t+1)} = E^{(t)} \,\triangle\, \{(u, v)\}$, where $\triangle$ is the symmetric difference operator.

A connectivity query is then defined as follows: for a given pair of vertices $(u, v)$, does there exist a path between them in $G^{(t)}$? This yields binary ground-truth query answers $\hat{y}^{(t)}$ which we can supervise towards. Several classical data structures exist for answering variants of connectivity queries on dynamic graphs, on-line, in sub-linear time galler1964improved ; shiloach1981line ; sleator1983data ; tarjan1984finding , of which we will discuss two. All the derived inputs and outputs to be used for training PGNs are summarised in Appendix B.
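To pin the task down, the snippet below maintains the evolving edge set via symmetric difference and answers a connectivity query by brute-force BFS; it merely defines the ground truth a model has to predict, and is of course not an efficient dynamic data structure.

```python
from collections import deque

def toggle_edge(edges, u, v):
    """E_(t+1) = E_(t) symmetric-difference {(u, v)}."""
    edges ^= {(min(u, v), max(u, v))}
    return edges

def connected(edges, n, u, v):
    """Brute-force BFS ground truth for the connectivity query on n nodes."""
    adj = {i: [] for i in range(n)}
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    seen, queue = {u}, deque([u])
    while queue:
        x = queue.popleft()
        if x == v:
            return True
        for y in adj[x]:
            if y not in seen:
                seen.add(y)
                queue.append(y)
    return False

edges = set()
toggle_edge(edges, 0, 1)
toggle_edge(edges, 1, 2)
print(connected(edges, 4, 0, 2))   # True: 0-1-2 is a path
toggle_edge(edges, 1, 2)           # toggling again removes the edge
print(connected(edges, 4, 0, 2))   # False
```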

Incremental graph connectivity with disjoint-set unions

We initially consider incremental graph connectivity: edges can only be added to the graph. Knowing that edges can never get removed, it is sufficient to combine connected components through set union. Therefore, this problem only requires maintaining disjoint sets, supporting an efficient union(u, v) operation that performs a union of the sets containing $u$ and $v$. Querying connectivity then simply requires checking whether $u$ and $v$ are in the same set, which in turn requires an efficient find(u) operation retrieving the set containing $u$.

To put the emphasis on data structure modelling, we consider a combined operation on $(u, v)$: first, query whether $u$ and $v$ are connected, then perform a union on them if they are not. Pseudocode for this query-union(u, v) operation is given in Figure 2 (Right). query-union is a key component of important algorithms, such as Kruskal's algorithm kruskal1956shortest for minimum spanning trees—and was not modellable by prior work velivckovic2019neural .

Figure 2: Pseudocode of DSU operations; initialisation and find(u) (Left), union(u, v) (Middle) and query-union(u, v), giving ground-truth values of $\hat{y}^{(t)}$ (Right). All manipulations of ground-truth pointers ($\pi_u$ for node $u$) are in blue; the path compression heuristic is highlighted in green.

The tree-based disjoint-set union (DSU) data structure galler1964improved is known to yield optimal complexity fredman1989cell for this task. DSU represents sets as rooted trees—each node, $u$, has a parent pointer, $\pi_u$—and the set identifier is its root node, $r$, which by convention points to itself ($\pi_r = r$). Hence, find(u) reduces to recursively calling find(pi[u]) until the root is reached—see Figure 2 (Left). Further, path compression tarjan1984worst is applied: upon calling find(u), all nodes on the path from $u$ to $r$ will point directly to $r$. This self-organisation substantially reduces future querying time along the path.

Calling union(u, v) reduces to finding the roots of $u$'s and $v$'s sets, then making one of these roots point to the other. To avoid pointer ambiguity, we assign a random priority, $\rho_u$, to every node at initialisation time, and always prefer the node with higher priority as the new root—see Figure 2 (Middle). This achieves a worst-case time complexity of $O(m \log n)$ for $n$ nodes and $m$ calls to find(u) tarjan1984worst , which, while not optimal (making the priorities size-dependent recovers the optimal amortised time complexity of $O(\alpha(n))$ per operation tarjan1975efficiency , where $\alpha$ is the inverse Ackermann function; essentially a constant for all sensible values of $n$), is still highly efficient. A runnable sketch of these operations is given below.
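The sketch below re-implements the operations just described in Python (the paper's reference pseudocode is in Figure 2; the class layout and variable names here are our own). find applies path compression, and union prefers the higher-priority root.

```python
import random

class DSU:
    def __init__(self, n, seed=0):
        rnd = random.Random(seed)
        self.pi = list(range(n))                         # parent pointers; roots point to themselves
        self.rho = [rnd.random() for _ in range(n)]      # random priorities, fixed at initialisation

    def find(self, u):
        if self.pi[u] == u:
            return u
        root = self.find(self.pi[u])
        self.pi[u] = root                                # path compression: point directly at the root
        return root

    def union(self, u, v):
        ru, rv = self.find(u), self.find(v)
        if self.rho[ru] < self.rho[rv]:                  # the higher-priority root becomes the new root
            ru, rv = rv, ru
        self.pi[rv] = ru

    def query_union(self, u, v):
        """Return whether u and v were already in the same set; if not, merge their sets."""
        same_set = self.find(u) == self.find(v)
        if not same_set:
            self.union(u, v)
        return same_set

dsu = DSU(5)
print(dsu.query_union(0, 1))   # False: different sets, now merged
print(dsu.query_union(0, 1))   # True: already in the same set
```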

Casting this to the framework of Section 2: at each step $t$, we call query-union(u, v), specified by operation descriptions $\vec{e}_i^{(t)}$ containing the nodes' priorities $\rho_i$, along with a binary feature indicating which nodes are $u$ and $v$. The corresponding output $\hat{y}^{(t)}$ indicates the return value of query-union(u, v). Finally, we provide supervision for the PGN's (asymmetric) pointers, $\widetilde{\Pi}^{(t)}$, by making them match the ground-truth DSU's pointers (i.e., $\hat{\Pi}_{ij}^{(t)} = 1$ iff $\pi_i = j$, and $0$ otherwise). Ground-truth mask values, $\hat{\mu}_i^{(t)}$, are set to $1$ only for the paths from $u$ and $v$ to their respective roots—no other node's state is changed (i.e., $\hat{\mu}_j^{(t)} = 0$ for the remaining nodes).

Before moving on, consider how having access to these pointers helps the PGN answer queries, compared to a baseline without them: checking connectivity of $u$ and $v$ boils down to following their pointer links and checking whether they meet, which drastically relieves the learning pressure on the latent state.

Fully dynamic tree connectivity with link/cut trees

We move on to fully dynamic connectivity—edges may now be removed, and hence set unions are insufficient to model all possible connected component configurations. The problem of fully dynamic tree connectivity—with the restriction that $G^{(t)}$ remains acyclic—is solvable in $O(\log n)$ amortised time per operation by link/cut trees (LCTs) sleator1983data , elegant data structures that maintain forests of rooted trees, requiring only one pointer per node.

The operations supported by LCTs are: find-root(u) retrieves the root of node $u$; link(u, v) links nodes $u$ and $v$, with the precondition that $u$ is the root of its own tree; cut(v) removes the edge from $v$ to its parent; evert(u) re-roots $u$'s tree, such that $u$ becomes the new root.

LCTs also support efficient path-aggregate queries on the (unique) path from $u$ to $v$, which is very useful for reasoning on dynamic trees. Originally, this sped up bottleneck computations in network flow algorithms dinic1970algorithm . Nowadays, the LCT has found usage across online versions of many classical graph algorithms, such as minimum spanning forests and shortest paths tarjan2010dynamic . Here, however, we focus only on checking connectivity of $u$ and $v$; hence find-root(u) will be sufficient for our queries.

Similarly to our DSU analysis, we will compress updates and queries into one operation, query-toggle(u, v), which our models will attempt to support. This operation first calls evert(u), then checks if $u$ and $v$ are connected: if they are not, adding the edge between them cannot introduce a cycle (and $u$ is now the root of its tree), so link(u, v) is performed. Otherwise, cut(v) is performed—it is guaranteed to succeed, as $v$ is not the root node. A naive reference sketch of these semantics follows; pseudocode of query-toggle(u, v), along with a visualisation of its effects, is provided in Figure 3.
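The following deliberately naive parent-pointer forest reproduces these semantics in Python: evert reverses the parent chain, and connectivity is checked by walking to the roots. It runs in O(n) per operation rather than the LCT's amortised O(log n), and is only meant to pin down the behaviour being imitated; the class and its naming are ours.

```python
class NaiveForest:
    """O(n)-per-operation stand-in for a link/cut tree; parent[u] == u marks a root."""
    def __init__(self, n):
        self.parent = list(range(n))

    def root(self, u):
        while self.parent[u] != u:
            u = self.parent[u]
        return u

    def evert(self, u):
        """Re-root u's tree at u by reversing the parent chain from u to the old root."""
        prev, cur = u, self.parent[u]
        self.parent[u] = u
        while cur != prev:
            nxt = self.parent[cur]
            self.parent[cur] = prev
            prev, cur = cur, nxt

    def query_toggle(self, u, v):
        """Return whether u and v were connected; then link if they were not, cut otherwise."""
        self.evert(u)
        were_connected = self.root(v) == u
        if were_connected:
            self.parent[v] = v      # cut(v): detach v from its parent (u is the root, so v is not)
        else:
            self.parent[u] = v      # link(u, v): u is a root, so its tree can hang below v
        return were_connected

f = NaiveForest(4)
print(f.query_toggle(0, 1))   # False: disconnected before, edge added
print(f.query_toggle(2, 1))   # False: 0-1-2 now form one tree
print(f.query_toggle(0, 2))   # True: connected, so the edge between 2 and its parent is removed
```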

We encode each query-toggle(u, v) analogously to the DSU setup: the operation descriptions $\vec{e}_i^{(t)}$ contain each node's priority and a binary feature indicating whether it is one of the operated-on nodes. Random priorities, $\rho_u$, are again used; this time to determine whether $u$ or $v$ will be the node to call evert on, breaking ambiguity. As for DSU, we supervise the asymmetric pointers, $\widetilde{\Pi}^{(t)}$, using the ground-truth LCT's pointers, and ground-truth mask values, $\hat{\mu}_i^{(t)}$, are set to $1$ only if node $i$ is modified by the operation at time $t$. Link/cut trees require elaborate bookkeeping; for brevity, we delegate further descriptions of their operations to Appendix C, and provide our C++ implementation of the LCT in the supplementary material.

Figure 3: Left: Pseudocode of the query-toggle(u, v) operation, which will be handled by our models; Right: Effect of calling query-toggle(h, d) on a specific forest (Top), followed by calling query-toggle(g, e) (Bottom). Edges affected by evert (blue), link (green), and cut (red) are highlighted. N.B. this figure represents changes to the forest being modelled, and not the underlying LCT pointers; see Appendix C for more information on pointer manipulation.

4 Evaluation and results

Experimental setup

As in velivckovic2019neural ; yan2020neural , we evaluate out-of-distribution generalisation—training on operation sequences for small input sets ($n = 20$ entities), then testing on up to $5\times$ larger inputs ($n = 50$ and $n = 100$). In line with velivckovic2019neural , we generate 70 sequences for training, and 35 sequences for each test size category. We generate operations by sampling input node pairs uniformly at random at each step $t$; query-union(u, v) or query-toggle(u, v) is then called to generate the ground truths $\hat{y}^{(t)}$, $\hat{\mu}_i^{(t)}$ and $\hat{\Pi}^{(t)}$. This is known to be a good test-bed for spanning many possible DSU/LCT configurations and benchmarking various data structures (see e.g. Section 3.5 in tarjan2010dynamic ).

All models compute latent features of the same dimensionality in each layer, and are trained with the Adam optimiser kingma2014adam . We perform early stopping, retrieving the model which achieved the best query F1 score on a validation set of 35 small sequences ($n = 20$). We attempted running the processor (Equation 8) for more than one layer between steps, and using a separate GNN for computing pointers—neither yielded significant improvements.

We evaluate the PGN model of Section 2 against three baseline variants, seeking to illustrate the benefits of its various graph inference components. We describe the baselines in turn.

Deep Sets zaheer2017deep independently process individual entities, followed by an aggregation layer for resolving queries. This yields an only-self-pointers mechanism, $\Pi^{(t)} = I_n$ for all $t$, within our framework. Deep Sets are popular for set-based summary statistic tasks. Only the query loss is used.

(Unrestricted) GNNs gilmer2017neural ; santoro2017simple ; xu2019can make no prior assumptions on node connectivity, yielding an all-ones adjacency matrix: $\Pi_{ij}^{(t)} = 1$ for all $i, j, t$. Such models are a popular choice when relational structure is assumed but not known. Only the query loss is used.

PGN without masks (PGN-NM) removes the masking mechanism of Equations 4–7. This repeatedly overwrites all pointers, i.e. $\mu_i^{(t)} = 1$ for all $i, t$. PGN-NM is related to a directly-supervised variant of the prior art in learnable $k$-NN graphs franceschi2019learning ; kazi2020differentiable ; wang2019dynamic . PGN-NM has no masking loss in its training.

These models are universal approximators on permutation-invariant inputs xu2019can ; zaheer2017deep , meaning they are all able to model the DSU and LCT setup perfectly. However, unrestricted GNNs suffer from oversmoothing as graphs grow zhao2019pairnorm ; chen2019measuring ; wang2019improving ; luan2019break , making it harder to perform robust credit assignment of relevant neighbours. Conversely, Deep Sets must process the entire operation history within their latent state, in a manner that is decomposable after the readout—which is known to be hard xu2019can .

To assess the utility of the data structure learnt by the PGN mechanism, as well as its performance limits, we perform two tests with fixed pointers, supervised only on the query:

PGN-Ptrs: first, a PGN model is learnt on the training data. Then it is applied to derive and fix the pointers at all steps for all training/validation/test inputs. In the second phase, a new GNN over these inferred pointers is learnt and evaluated, solely on query answering.

Oracle-Ptrs: learn a query-answering GNN over the ground-truth pointers $\hat{\Pi}^{(t)}$. Note that this setup is, especially for link/cut trees, made substantially easier than the PGN's: the model no longer needs to imitate the complex sequences of pointer rotations of LCTs.

Results and discussion

Our results, summarised in Table 1, clearly indicate the outperformance and generalisation of our PGN model, especially on the larger-scale test sets. The competitive performance of PGN-Ptrs implies that the PGN models a robust data structure that GNNs can readily reuse. While the PGN-NM model is potent in-distribution, its performance rapidly decays once it is tasked with modelling larger sets of pointers at test time. Further, on the LCT task, baseline models often failed to make meaningful advances at all—PGNs are capable of surpassing this limitation, with results that even approach those of the Oracle-Ptrs models as input sizes increase.

We corroborate some of these results by evaluating pointer accuracy (w.r.t. the ground truth) with the analysis in Table 2. Without masking, the PGNs fail to meaningfully model useful pointers on larger test sets, whereas the masked PGN consistently models the ground-truth pointers to high accuracy. Mask accuracies remain consistently high, implying that the inductive bias is well-respected.

Model Disjoint-set union Link/cut tree
GNN
Deep Sets
PGN-NM
PGN
PGN-Ptrs
Oracle-Ptrs
Table 1: F1 scores on the dynamic graph connectivity tasks for all models considered, over five random seeds. All models are trained on $n = 20$ and tested on larger test sets ($n = 50$ and $n = 100$).
Accuracy of Disjoint-set union Link/cut tree
Pointers (NM)
Pointers
Masks
Table 2: Pointer and mask accuracies of the PGN model w.r.t. ground-truth pointers.

Using the max readout in Equation 3 provides an opportunity for a qualitative analysis of the PGN’s credit assignment. DSU and LCT focus on paths from the two nodes operated upon to their roots in the data structure, implying they are highly relevant to queries. As each global embedding dimension is pooled from exactly one node, in Appendix D we visualise how often these relevant nodes appear in the final embedding—revealing that the PGN’s inductive bias amplifies their credit substantially.

Rollout analysis of PGN pointers

Figure 4: Visualisation of a PGN rollout on the DSU setup, for a pathological ground-truth case of repeated union(i, i+1) (Left). The first few steps of the learnt pointers, $\Pi^{(t)}$, are visualised (Middle), as well as the final state (Right)—the PGN produced a valid DSU at all times, but a shallower one than the ground truth.

Tables 1–2 indicate a substantial deviation of PGNs from the ground-truth pointers, $\hat{\Pi}^{(t)}$, while maintaining strong query performance. These learnt pointers are still meaningful: given our 1-NN-like inductive bias, even minor discrepancies that result in modelling invalid data structures could have negative effects on performance, if made uninformedly.

We observe the learnt PGN pointers on a pathological DSU example (Figure 4). Repeatedly calling query-union(i, i+1) with nodes ordered by priority yields a linearised DSU (note that this is not problematic for the ground-truth algorithm: it is designed for a single-threaded CPU execution model, and any subsequent find(i) calls would trigger path compression, amortising the cost). Such graphs (of large diameter) are difficult for message propagation with GNNs. During the rollout, the PGN models a correct DSU at all times, but one of roughly half the depth—easing GNN usage and GPU parallelisability.

Effectively, the PGN learns to use the query supervision from $\hat{y}^{(t)}$ to “nudge” its pointers in a direction more amenable to GNNs, discovering parallelisable data structures which may substantially deviate from the ground truth $\hat{\Pi}^{(t)}$. Note that this also explains the reduced performance gap between PGNs and Oracle-Ptrs on LCT: as LCTs cannot apply path-compression-like tricks, the ground-truth LCT pointer graphs are expected to be of substantially larger diameter as the test set size increases.

5 Conclusions

We presented pointer graph networks (PGNs), a method for simultaneously learning a latent pointer-based graph and using it to answer challenging algorithmic queries. Introducing step-wise structural supervision from classical data structures, we incorporated useful inductive biases from theoretical computer science, enabling outperformance of standard set-/graph-based models on two dynamic graph connectivity tasks, known to be challenging for GNNs. Out-of-distribution generalisation, as well as interpretable and parallelisable data structures, have been recovered by PGNs.

Broader Impact

It is our opinion that this work does not have a specific immediate and predictable real-world application and hence no specific ethical risks associated. However, PGN offers a natural way to introduce domain knowledge (borrowed from data structures) into the learning of graph neural networks, which has the potential of improving their performance, particularly when dealing with large graphs. Graph neural networks have seen a lot of successes in modelling diverse real world problems, such as social networks, quantum chemistry, computational biomedicine, physics simulations and fake news detection. Therefore, indirectly, through improving GNNs, our work could impact these domains and carry over any ethical risks present within those works.

We would like to thank the developers of JAX jax2018github and Haiku haiku2020github . Further, we are very thankful to Danny Sleator for his invaluable advice on the theoretical aspects and practical applications of link/cut trees, and Abe Friesen, Daan Wierstra, Heiko Strathmann, Jess Hamrick and Kim Stachenfeld for reviewing the paper prior to submission.

References

Appendix A Pointer graph networks gradient computation

Figure 5: Detailed view of the dataflow within the PGN model, highlighting inputs (outlined), objects optimised against ground-truths (query answers $\vec{y}^{(t)}$, masks $\mu_i^{(t)}$ and pointers $\Pi^{(t)}$) (shaded) and all intermediate latent states ($\vec{z}_i^{(t)}$ and $\vec{h}_i^{(t)}$). Solid lines indicate differentiable computation with gradient flow in red, while dotted lines indicate non-differentiable operations (teacher-forced at training time). N.B. this computation graph should also include edges from $\vec{z}_i^{(t)}$ into the query answers, masks and pointers (as it gets concatenated with $\vec{h}_i^{(t)}$)—we omit these edges for clarity.

To provide a more high-level overview of the PGN model’s dataflow across all relevant variables (and for realising its computational graph and differentiability), we provide the visualisation in Figure 5.

Most operations of the PGN are realised as standard neural network layers and are hence differentiable; the two exceptions are the thresholding operations that decide the final masks $\mu_i^{(t)}$ and pointers $\Pi^{(t)}$, based on the soft coefficients computed by the masking network and the self-attention, respectively. This makes no difference to the training algorithm, as the masks and pointers are teacher-forced, and the soft coefficients are directly optimised against the ground-truth values $\hat{\mu}_i^{(t)}$ and $\hat{\Pi}^{(t)}$.

Further, note that our setup allows a clear path to end-to-end backpropagation (through the latent vectors) at all steps, allowing the representation $\vec{h}_i^{(t)}$ to be optimised with respect to all predictions made at future steps $t' > t$.

Appendix B Summary of operation descriptions and supervision signals

Disjoint-set union [10], operation query-union(u, v):
    Operation descriptions, $\vec{e}_i^{(t)}$: $\rho_i$, the randomly sampled priority of node $i$; is node $i$ being operated on?
    Supervision signals: $\hat{y}^{(t)}$: are $u$ and $v$ in the same set?; $\hat{\mu}_i^{(t)}$: is node $i$ visited by find(u) or find(v)?; $\hat{\Pi}_{ij}^{(t)}$: is $\pi_i = j$ after executing? (asymmetric pointer)

Link/cut tree [36], operation query-toggle(u, v):
    Operation descriptions, $\vec{e}_i^{(t)}$: $\rho_i$, the randomly sampled priority of node $i$; is node $i$ being operated on?
    Supervision signals: $\hat{y}^{(t)}$: are $u$ and $v$ connected?; $\hat{\mu}_i^{(t)}$: is node $i$ visited during query-toggle(u, v)?; $\hat{\Pi}_{ij}^{(t)}$: is $\pi_i = j$ after executing? (asymmetric pointer)

Table 3: Summary of operation descriptions and supervision signals on the data structures considered.

To aid clarity, we provide in Table 3 an overview of all the operation descriptions and outputs (supervision signals) for the data structures considered here (disjoint-set unions and link/cut trees).

Note that the manipulation of the ground-truth pointers ($\pi_u$) is not discussed for LCTs in the main text for purposes of brevity; for more details, consult Appendix C.

Appendix C Link/cut tree operations

In this section, we provide a detailed overview of the link/cut tree (LCT) data structure [36], as well as the various operations it supports. This appendix is designed to be as self-contained as possible, and we provide the C++ implementation used to generate our dataset within the supplementary material.

Before covering the specifics of LCT operations, it is important to understand how it represents the forest it models; namely, in order to support efficient operations and path queries, the pointers used by the LCT can differ significantly from the edges in the forest being modelled.

Preferred path decomposition

Many design choices in LCTs follow the principle of “most-recent access”: if a node was recently accessed, it is likely to get accessed again soon—hence we should keep it in a location that makes it easily accessible.

The first such design is the preferred path decomposition: the modelled forest is partitioned into preferred paths, such that each node may have at most one preferred child: the child most-recently accessed during a node-to-root operation. As we will see, any LCT operation on a node $u$ will involve looking up the path to its respective root $r$—hence every LCT operation will be composed of several node-to-root operations.

One example of a preferred path decomposition is demonstrated in Figure 6 (Left). Note how each node may have at most one preferred child. When a node is not a preferred child, its parent edge is used to jump between paths, and is hence often called a path-parent.

LCT pointers

Each preferred path is represented by LCTs in a way that enables fast access—in a binary search tree (BST) keyed by depth. This implies that the nodes along the path will be stored in a binary tree (each node will potentially have a left and/or right child) which respects the following recursive invariant: for each node, all nodes in its left subtree will be closer to the root, and all nodes in its right subtree will be further from the root.

For now, it is sufficient to recall the invariant above—the specific implementation of binary search trees used in LCTs will be discussed towards the end of this section. It should be apparent that these trees should be balanced: for each node, its left and right subtrees should be of (roughly) comparable sizes, recovering an optimal lookup complexity of $O(\log n)$ for a BST of $n$ nodes.

Each of the preferred-path BSTs will specify its own set of pointers. Additionally, we still need to include the path-parents, to allow recombining information across different preferred paths. While we could keep these links unchanged, it is in fact canonical to place the path-parent in the root node of the path's BST (N.B. this node may be different from the top-of-path node: the top-of-path node is always the minimum node of the BST, obtained by recursively following left children from the root, while possible!).

As we will notice, this will enable more elegant operation of the LCT, and further ensures that each LCT node will have exactly one parent pointer (either in-BST parent or path-parent, allowing for jumping between different path BSTs), which aligns perfectly with our PGN model assumptions.

The ground-truth pointers of LCTs, $\hat{\Pi}^{(t)}$, are then recovered as all the parent pointers contained within these binary search trees, along with all the path-parents. Similarly, the ground-truth masks, $\hat{\mu}_i^{(t)}$, will tag the subset of LCT nodes whose pointers may change during the operation at time $t$. We illustrate how a preferred path decomposition can be represented with LCTs within Figure 6 (Right).

Figure 6: Left: Rooted tree modelled by LCTs, with its four preferred paths indicated by solid lines; the most-recently accessed path is the one containing the root. Right: One possible configuration of LCT pointers which models the tree. Each preferred path is stored in a binary search tree (BST) keyed by depth (colour-coded to match the LHS figure), and path-parents (dashed) emanate from the root node of each BST—hence their source node may differ from the node holding the corresponding edge in the modelled forest.

LCT operations

Now we are ready to cover the specifics of how individual LCT operations (find-root(u), link(u, v), cut(u) and evert(u)) are implemented.

All of these operations rely on an efficient operation which exposes the path from a node $u$ to its root, making it preferred—and making $u$ the root of the entire LCT (i.e. the root node of the top-most BST). We will denote this operation as expose(u), and assume for now that its implementation is provided to us. As we will see, all of the interesting LCT operations will necessarily start with calls to expose(u) for the nodes we are targeting.

Before discussing each of the LCT operations, note one important invariant after running expose(u): node $u$ is now the root node of the top-most BST (containing the nodes on the path from $u$ to the tree root $r$), and it has no right children in this BST—as it is the deepest node on this path.

As in the main document, we will highlight in blue all changes to the ground-truth LCT pointers $\pi_u$, which are the union of the ground-truth BST parents, $p_u$, and the path-parents, $pp_u$. Note that each node will have at most one of the two; we denote unused pointers with nil. By convention, root nodes, $r$, of the entire LCT will point to themselves using a BST parent; i.e. $p_r = r$, $pp_r$ = nil.

  • find-root(u) can be implemented as follows: first execute expose(u). This guarantees that $u$ is in the same BST as $r$, the root of the entire tree. Now, since the BST is keyed by depth and $r$ is the shallowest node in the BST's preferred path, we just follow left children, starting from $u$, while possible: $r$ is the node at which this is no longer possible. We conclude by calling expose on $r$, to avoid the pathological behaviour of repeatedly querying the root, which would accumulate excess complexity from following long chains of left children.

  • link(u, v) has the preconditions that $u$ must be the root node of its respective tree (i.e. $p_u = u$), and that $u$ and $v$ are not in the same tree. We start by running expose(u) and expose(v). Attaching the edge $(u, v)$ extends the preferred path from $v$ to its root, $r_v$, to incorporate $u$. Given that $u$ can have no left children in its BST (it is a root node, hence shallowest; and it has just been exposed, so it has no parents either), this can be done simply by making $v$ a left child of $u$ (given that every node on the path from $r_v$ to $v$ is shallower than $u$).

  • cut(u), as above, will initially execute expose(u). As a result, $u$ will retain all the nodes that are deeper than it (through path-parents pointing into $u$), and can simply be cut off from all shallower nodes along its preferred path (contained in $u$'s left subtree, if it exists): detaching $u$ from its left child makes $u$ a root node of its own component.

  • evert(u), as visualised in Figure 3, needs to isolate the path from $u$ to $r$, and flip the direction of all edges along it. The first part is already handled by calling expose(u), while the second is implemented by recursively flipping left and right subtrees within the entire BST containing $u$ (this makes shallower nodes on the path become deeper, and vice-versa).

    This is implemented via lazy propagation: each node $u$ stores a flip bit, $\phi_u$ (initially set to 0). Calling evert(u) will toggle node $u$'s flip bit. Whenever we process a node $u$, we further issue a call to a special operation, release(u), which performs any pending swap of $u$'s left and right subtrees, followed by propagating the flip bit onwards to both children (toggling their flip bits via binary exclusive OR). Note that release(u) does not affect parent pointers—but it may affect the outcomes of future operations on them.

Implementing expose(u)

It only remains to provide an implementation for expose(u), in order to specify the LCT operations fully.

LCTs use splay trees as the particular binary search tree implementation to represent each preferred path. These trees are also designed with “most-recent access” in mind: nodes recently accessed in a splay tree are likely to get accessed again, therefore any accessed node is turned into the root node of the splay tree, using the splay(u) operation. The manner in which splay(u) realises its effect is, in turn, via a sequence of tree rotations, where rotate(u) performs a rotation that brings $u$ one level higher in the tree.

We describe these three operations in a bottom-up fashion: first, the lowest-level rotate(u), which merely requires carefully updating all the pointer information (including re-attaching to the grandparent). Depending on whether $u$ is its parent's left or right child, a zig or a zag rotation is performed—they are entirely symmetrical. Refer to Figure 7 for an example of a zig rotation followed by a zag rotation (often called a zig-zag for short).


Figure 7: A schematic of a zig-zag rotation: first, a node is rotated around its parent; then, it is rotated around its new parent, bringing it two levels higher in the BST without breaking any invariants.

Armed with the rotation primitive, we may define splay(u) as a repeated application of zig, zig-zig and zig-zag rotations, until node $u$ becomes the root of its BST (note: this exact sequence of operations is required to achieve the optimal amortised complexity). We also repeatedly perform lazy propagation by calling release(u) on any encountered nodes.


Finally, we may define expose(u) as repeatedly interchanging calls to splay(u) (which will render $u$ the root of its preferred-path BST) and appropriately following path-parents, $pp$, to fuse $u$'s BST with the one above. This concludes the description of the LCT operations.

In pseudocode: splay $u$ to the root of its BST and detach any nodes deeper than $u$ along the preferred path (they are cut off into their own BST and gain a path-parent into $u$); then, while $u$ still has a path-parent $v$, splay $v$, detach $v$'s old (deeper) preferred child, convert $u$'s path-parent into a BST parent by attaching $u$ as $v$'s right child, and splay $u$ again, repeating until $u$ is the root of its LCT. A compact Python sketch of the full data structure is given below.
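For completeness, here is a compact Python re-sketch of a splay-based link/cut tree with lazy flip bits, following the conventions of this appendix (one parent slot per node, which doubles as the path-parent when the node is not a BST child). The paper ships a C++ implementation in its supplementary material; this sketch is our own, unoptimised illustration.

```python
class LCT:
    """Splay-tree-based link/cut trees over nodes 0..n-1 (illustrative sketch)."""
    def __init__(self, n):
        self.left, self.right = [None] * n, [None] * n
        self.parent = [None] * n          # BST parent, or path-parent if not a BST child
        self.flip = [False] * n           # lazy subtree-reversal bit, used by evert

    def _is_bst_root(self, x):
        p = self.parent[x]
        return p is None or (self.left[p] != x and self.right[p] != x)

    def _push(self, x):                   # "release(x)": apply and propagate the lazy flip bit
        if self.flip[x]:
            self.left[x], self.right[x] = self.right[x], self.left[x]
            for c in (self.left[x], self.right[x]):
                if c is not None:
                    self.flip[c] ^= True
            self.flip[x] = False

    def _rotate(self, x):                 # single zig/zag rotation, bringing x one level up
        p, g = self.parent[x], self.parent[self.parent[x]]
        p_was_root = self._is_bst_root(p)
        if self.left[p] == x:             # zig
            self.left[p] = self.right[x]
            if self.right[x] is not None: self.parent[self.right[x]] = p
            self.right[x] = p
        else:                             # zag
            self.right[p] = self.left[x]
            if self.left[x] is not None: self.parent[self.left[x]] = p
            self.left[x] = p
        self.parent[p], self.parent[x] = x, g
        if not p_was_root:                # adjust the grandparent's child pointer
            if self.left[g] == p: self.left[g] = x
            else: self.right[g] = x

    def _splay(self, x):                  # make x the root of its preferred-path BST
        path, y = [x], x
        while not self._is_bst_root(y):
            y = self.parent[y]; path.append(y)
        for y in reversed(path):          # push lazy flips top-down before rotating
            self._push(y)
        while not self._is_bst_root(x):
            p = self.parent[x]
            if not self._is_bst_root(p):
                g = self.parent[p]
                self._rotate(p if (self.left[g] == p) == (self.left[p] == x) else x)
            self._rotate(x)

    def expose(self, x):                  # make the root..x path preferred; x tops its BST
        self._splay(x)
        self.right[x] = None              # deeper nodes are cut off, keeping x as their path-parent
        while self.parent[x] is not None:
            pp = self.parent[x]           # path-parent of x's BST
            self._splay(pp)
            self.right[pp] = x            # pp's old preferred child is detached (keeps pp as path-parent)
            self._splay(x)

    def find_root(self, x):
        self.expose(x)
        r = x
        while True:                       # the tree root is the shallowest (leftmost) node
            self._push(r)
            if self.left[r] is None: break
            r = self.left[r]
        self._splay(r)                    # re-splay r, cf. the "re-expose" in the text
        return r

    def link(self, u, v):                 # precondition: u is the root of its tree; trees are disjoint
        self.expose(u)                    # u's BST is now just {u}
        self.expose(v)
        self.left[u] = v                  # v's whole path is shallower than u
        self.parent[v] = u

    def cut(self, u):                     # remove the edge between u and its parent
        self.expose(u)
        if self.left[u] is not None:
            child = self.left[u]
            self.parent[child] = None     # detach u from everything shallower than it
            self.left[u] = None

    def evert(self, u):                   # re-root u's tree at u by flipping the exposed path
        self.expose(u)
        self.flip[u] ^= True
        self._push(u)

lct = LCT(4)
lct.link(0, 1)                            # 0 was a root; edge (0, 1) added
lct.link(2, 1)                            # edge (1, 2) added
print(lct.find_root(0) == lct.find_root(2))   # True
lct.evert(0)                              # re-root at 0: the path 2 -> 1 -> 0 is reversed
lct.cut(2)                                # removes the edge between 2 and its (new) parent 1
print(lct.find_root(0) == lct.find_root(2))   # False
```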

It is worth reflecting on the overall complexity of individual LCT operations, taking into account the fact that they are all built on top of expose(u), which itself requires reasoning about tree rotations and then appropriately leveraging the preferred path decomposition. This makes the LCT modelling task substantially more challenging than DSU.

Remarks on computational complexity and applications

As can be seen throughout the analysis, the computational complexity of all LCT operations reduces to the computational complexity of calling expose(u)—adding only a constant overhead otherwise. splay(u) has a known amortised complexity of $O(\log n)$, for $n$ nodes in the BST; it may seem that the overall complexity of exposing is this multiplied by the worst-case number of different preferred paths encountered.

However, detailed complexity analysis shows that splay trees combined with the preferred path decomposition yield an amortised time complexity of $O(\log n)$ for all link/cut tree operations. The storage complexity is also highly efficient, requiring only $O(1)$ additional bookkeeping per node.

Finally, we remark on the utility of LCTs for performing path-aggregate queries. When calling expose(u), all nodes from $u$ to the root become exposed in the same BST, simplifying computations of important path aggregates (such as bottlenecks, lowest common ancestors, etc.). This can be generalised to an arbitrary path(u, v) operation by first calling evert(u) followed by expose(v)—this will expose exactly the nodes along the unique path from $u$ to $v$ within the same BST.

Appendix D Credit assignment analysis

Figure 8: Credit assignment study results for the DSU setup, for the baseline GNN (Top) and the PGN (Bottom), arranged left-to-right by test graph size. PGNs learn to put larger emphasis on both the two nodes being operated on (blue) and the nodes on their respective paths-to-roots (green).
Figure 9: Credit assignment study results for the LCT setup, following same convention as Figure 8.

Firstly, recall how our decoder network, $g$, is applied to the latent state ($\vec{z}_i^{(t)}$, $\vec{h}_i^{(t)}$) in order to derive the predicted query answers, $\vec{y}^{(t)}$ (Equation 3). Knowing that the elementwise maximisation aggregator performed best as the aggregation function, we can rewrite Equation 3 as follows:

$\vec{y}^{(t)} = g\left(\max_{i} \left(\vec{z}_i^{(t)} \,\Vert\, \vec{h}_i^{(t)}\right)\right)$ (9)

This form of max-pooling readout has a unique feature: each dimension of the input vector to $g$ will be contributed to by exactly one node (the one which maximises the corresponding dimension of $\vec{z}_i^{(t)}$ or $\vec{h}_i^{(t)}$). This provides us with an opportunity to perform a credit assignment study: we can verify how often every node has propagated its features into this vector—and hence obtain a direct estimate of how “useful” this node is for the decision-making of any of our considered models.
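The bookkeeping behind such a study is straightforward: for every output dimension of the max-pooled readout, record which node attained the maximum, and accumulate per-node counts. A small NumPy sketch on random vectors (the set of "relevant" nodes is chosen arbitrarily here):

```python
import numpy as np

rng = np.random.default_rng(3)
n, d = 10, 16
ZH = rng.normal(size=(n, 2 * d))          # per-node concatenations [z_i^(t) || h_i^(t)]

winners = ZH.argmax(axis=0)               # node contributing each dimension of the pooled vector
credit = np.bincount(winners, minlength=n) / ZH.shape[1]

relevant = np.zeros(n, dtype=bool)        # e.g. nodes tagged by the ground-truth mask
relevant[[0, 3, 4]] = True                # arbitrary illustrative choice

print("total credit assigned to relevant nodes:", credit[relevant].sum())
```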

We know from the direct analysis of disjoint-set union (Section 3) and link/cut tree (Appendix C) operations that only a subset of the nodes are directly involved in decision-making for dynamic connectivity. These are exactly the nodes along the paths from $u$ and $v$, the two nodes being operated on, to their respective roots in the data structure. Equivalently, these nodes directly correspond to the nodes tagged by the ground-truth masks (nodes for which $\hat{\mu}_i^{(t)} = 1$).

With the above in mind, we compare a trained baseline GNN model against a PGN model, in terms of how much credit is assigned to these “important” nodes throughout the rollout. The results of this study are visualised in Figures 8 (for DSU) and 9 (for LCT), separately showing the credit assigned to the two nodes being operated on (blue) and to the remaining nodes along their paths-to-roots (green).

From these plots, we can make several direct observations:

  • In all settings, the PGN amplifies the overall credit assigned to these relevant nodes.

  • On the DSU setup, the baseline GNN is likely suffering from oversmoothing effects: at larger test set sizes, it seems to hardly distinguish the paths-to-root (which are often very short due to path compression) from the remainder of the neighbourhoods. The PGN explicitly encodes the inductive bias of the structure, and hence more explicitly models such paths.

  • As the ground-truth LCT pointers are not amenable to path compression, paths-to-root may grow significantly in length as graph size increases. Hence the oversmoothing effect is less pronounced for the baselines here; but in this case, LCT operations are highly centred on the node being operated on. The PGN learns to provide additional emphasis to the nodes operated on, $u$ and $v$.

In all cases, it appears that, through a carefully and purposefully constructed graph, the PGN is able to significantly overcome the oversmoothing issues of fully-connected GNNs, providing further encouragement for applying PGNs in problem settings where strong credit assignment is required, one example of which are search problems.