Graph Traversal with Tensor Functionals: A Meta-Algorithm for Scalable Learning

02/08/2021 ∙ by Elan Markowitz, et al.

Graph Representation Learning (GRL) methods have impacted fields from chemistry to social science. However, their algorithmic implementations are specialized to specific use-cases, e.g., message passing methods are run differently from node embedding ones. Despite their apparent differences, all these methods utilize the graph structure, and therefore, their learning can be approximated with stochastic graph traversals. We propose Graph Traversal via Tensor Functionals (GTTF), a unifying meta-algorithm framework for easing the implementation of diverse graph algorithms and enabling transparent and efficient scaling to large graphs. GTTF is founded upon a data structure (stored as a sparse tensor) and a stochastic graph traversal algorithm (described using tensor operations). The algorithm is a functional that accepts two functions, and can be specialized to obtain a variety of GRL models and objectives, simply by changing those two functions. We show that, for a wide class of methods, our algorithm learns in an unbiased fashion and, in expectation, approximates the learning as if the specialized implementations were run directly. With these capabilities, we scale otherwise non-scalable methods to set the state-of-the-art on large graph datasets while being more efficient than existing GRL libraries, with only a handful of lines of code for each method specialization. GTTF and its various GRL implementations are on:




1 Introduction

Graph representation learning (GRL) has become an invaluable approach for a variety of tasks, such as node classification (e.g., in biological and citation networks; gat; kipf; sage; xu2018jknets), edge classification (e.g., link prediction for social and protein networks; DeepWalk; node2vec), entire graph classification (e.g., for chemistry and drug discovery gilmer; chen2018rise), etc.

In this work, we propose an algorithmic unification of various GRL methods that allows us to re-implement existing GRL methods and introduce new ones, in merely a handful of code lines per method. Our algorithm (abbreviated GTTF, Section 3.2) receives graphs as input, traverses them using efficient tensor operations (to disambiguate: by tensors, we refer to multi-dimensional arrays, as used in Deep Learning literature; and by operations, we refer to routines such as matrix multiplication, advanced indexing, etc.), and invokes specializable functions during the traversal. We show function specializations for recovering popular GRL methods (Section 3.3). Moreover, since GTTF is stochastic, these specializations automatically scale to arbitrarily large graphs, without careful derivation per method. Importantly, such specializations, in expectation, recover unbiased gradient estimates of the objective w.r.t. model parameters.

GTTF uses a data structure (Compact Adjacency, Section 3.1): a sparse encoding of the adjacency matrix. Node $u$ contains its neighbors in row $u$, notably, in the first $\delta_u$ columns of the sparse matrix $\widehat{E}$. This encoding allows stochastic graph traversals using standard tensor operations. GTTF is a functional, as it accepts functions AccumulateFn and BiasFn, to be provided by each GRL specialization to, respectively, accumulate the information necessary for computing the objective and, optionally, parametrize the sampling procedure. The traversal internally constructs a walk forest as part of the computation graph. Figure 1 depicts the data structure and the computation. From a generalization perspective, GTTF shares similarities with Dropout (dropout).

Our contributions are: (i) A stochastic graph traversal algorithm (GTTF) based on tensor operations that inherits the benefits of vectorized computation and libraries such as PyTorch and TensorFlow. (ii) We list specialization functions, allowing GTTF to approximately recover the learning of a broad class of popular GRL methods. (iii) We prove that this learning is unbiased, with controllable variance. For this class of methods, (iv) we show that GTTF can scale previously-unscalable GRL algorithms, setting the state-of-the-art on a range of datasets. Finally, (v) we open-source GTTF along with new stochastic traversal versions of several algorithms, to aid practitioners from various fields in applying and designing state-of-the-art GRL methods for large graphs.





Figure 1: (c)&(d) depict our data structure & traversal algorithm on a toy graph in (a)&(b). Panels: (a) example graph $\mathcal{G}$; (b) adjacency matrix for $\mathcal{G}$; (c) CompactAdj for $\mathcal{G}$, with sparse $\widehat{E}$ and dense $\delta$ (we store IDs of adjacent nodes in $\widehat{E}$); (d) walk forest (GTTF invokes AccumulateFn once per (green) instance).

2 Related Work

We take a broad standpoint in summarizing related work to motivate our contribution.





Method Family Learning
GCN, GAT MP exact
node2vec NE approx
WYS NE exact
Stochastic Sampling Methods
SAGE MP approx
FastGCN MP approx
LADIES MP approx
GraphSAINT MP approx
ClusterGCN MP heuristic
Software Frameworks
PyG Both inherits / re-implements
DGL Both inherits / re-implements
Algorithmic Abstraction (ours)
GTTF Both approx
(MP = message passing; NE = node embedding.)

Models for GRL have been proposed, including message passing (MP) algorithms, such as Graph Convolutional Network (GCN) (kipf), Graph Attention (GAT) (gat); as well as node embedding (NE) algorithms, including node2vec (node2vec), WYS (abu-wys); among many others (xu2018jknets; simplegcn; DeepWalk). The full-batch GCN of kipf, which drew recent attention and has motivated many MP algorithms, was not initially scalable to large graphs, as it processes all graph nodes at every training step. To scale MP methods to large graphs, researchers proposed Stochastic Sampling Methods that, at each training step, assemble a batch constituting subgraph(s) of the (large) input graph. Some of these sampling methods yield unbiased gradient estimates (with some variance) including SAGE (sage), FastGCN (fastgcn), LADIES (ladies2019), and GraphSAINT (graphsaint-iclr2020). On the other hand, ClusterGCN (clustergcn) is a heuristic in the sense that, despite its good performance, it provides no guarantee of unbiased gradient estimates of the full-batch learning. gilmer and chami2021taxonomy generalized many GRL models into Message Passing and Auto-Encoder frameworks. These frameworks prompt bundling of GRL methods under Software Libraries, like PyG (pyg) and DGL (dgl), offering consistent interfaces on data formats.

We now position our contribution relative to the above. Unlike generalized message passing (gilmer), rather than abstracting the model computation, we abstract the learning algorithm. As a result, GTTF can be specialized to recover the learning of MP as well as NE methods. Moreover, unlike Software Frameworks, which are re-implementations of many algorithms and therefore inherit the scale and learning of the copied algorithms, we re-write the algorithms themselves, giving them new properties (memory and computation complexity), while maintaining (in expectation) the original algorithm outcomes. Further, while the listed Stochastic Sampling Methods target MP algorithms (such as GCN, GAT, and the like), as their initial construction could not scale to large graphs, our learning algorithm applies to a wider class of GRL methods, additionally encapsulating NE methods. Finally, some NE methods such as node2vec (node2vec) and DeepWalk (DeepWalk) are scalable in their original form, but their scalability stems from a multi-step process: sample many (short) random walks, save them to disk, and then learn node embeddings using positional embedding methods (e.g., word2vec (word2vec)). They are sub-optimal in the sense that their first step (walk sampling) takes considerable time before training even starts, and it also places an artificial limit on the number of training samples (the number of simulated walks), whereas our algorithm conducts walks on-the-fly while training.

3 Graph Traversal via Tensor Functionals (GTTF)

At its core, GTTF is a stochastic algorithm that recursively conducts graph traversals to build representations of the graph. We describe the data structure and traversal algorithm below, using the following notation. $\mathcal{G} = (V, E)$ is an unweighted graph with $n = |V|$ nodes and $m = |E|$ edges, described as a sparse adjacency matrix $A \in \{0, 1\}^{n \times n}$. Without loss of generality, let the nodes be zero-based numbered, i.e. $V = \{0, \dots, n-1\}$. We denote the out-degree vector $\delta \in \mathbb{Z}^n$; it can be calculated by summing over rows of $A$ as $\delta_u = \sum_{v \in V} A[u, v]$. We assume $\delta_u > 0$ for all $u \in V$: pre-processing can add self-connections to orphan nodes. $B$ denotes a batch of nodes.

3.1 Data Structure

Internally, GTTF relies on a reformulation of the adjacency matrix, which we term CompactAdj (for "Compact Adjacency", Figure 1(c)). It consists of two tensors:

  1. $\delta \in \mathbb{Z}^n$, a dense out-degree vector (Figure 1(c), right);

  2. $\widehat{E} \in \mathbb{Z}^{n \times n}$, a sparse edge-list matrix in which row $u$ contains $\delta_u$ left-aligned non-zero values. The consecutive entries $\{\widehat{E}[u, 0], \widehat{E}[u, 1], \dots, \widehat{E}[u, \delta_u - 1]\}$ contain IDs of nodes receiving an edge from node $u$. The remaining $n - \delta_u$ entries are left unset; therefore, $\widehat{E}$ only occupies $\mathcal{O}(m)$ memory when stored as a sparse matrix (Figure 1(c), left).

CompactAdj allows us to concisely describe stochastic traversals using standard tensor operations. To uniformly sample a neighbor of node $u \in V$, one can draw $r \sim \mathcal{U}[0..(\delta_u - 1)]$, then get the neighbor ID with $\widehat{E}[u, r]$. In vectorized form, given node batch $B$ and access to continuous $\mathcal{U}[0, 1)$, we sample neighbors for each node in $B$ as: $R \sim \mathcal{U}[0, 1)^b$, where $b = |B|$; then $B' = \widehat{E}[B, \lfloor R \circ \delta[B] \rfloor]$ is a $b$-sized vector, with $B'_i$ containing a neighbor of $B_i$, the floor operation $\lfloor \cdot \rfloor$ is applied element-wise, and $\circ$ is the Hadamard product.
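The CompactAdj construction and the vectorized rule $B' = \widehat{E}[B, \lfloor R \circ \delta[B] \rfloor]$ can be sketched in NumPy as follows. This is a minimal illustration under our own assumptions, not the GTTF release: `build_compact_adj` and `sample_neighbors` are hypothetical helper names, and $\widehat{E}$ is kept as a dense array padded with -1 purely for readability, whereas the text stores it sparsely in $\mathcal{O}(m)$ memory.

```python
import numpy as np

def build_compact_adj(edges, n):
    """Hypothetical helper: CompactAdj (delta, E_hat) from a directed edge list.

    E_hat is dense here, with unset entries marked -1, only for readability;
    the text stores it as a sparse matrix occupying O(m) memory.
    """
    delta = np.zeros(n, dtype=np.int64)
    for u, _ in edges:
        delta[u] += 1                               # out-degree of u
    E_hat = -np.ones((n, n), dtype=np.int64)
    for u, v in edges:
        E_hat[u, np.sum(E_hat[u] >= 0)] = v         # left-align the neighbors of u
    return delta, E_hat

def sample_neighbors(B, delta, E_hat, rng):
    """Uniformly sample one neighbor per node of batch B: E_hat[B, floor(R * delta[B])]."""
    R = rng.random(len(B))                          # R ~ U[0, 1)^b
    cols = np.floor(R * delta[B]).astype(np.int64)  # always in 0..delta[B]-1
    return E_hat[B, cols]

# Toy undirected graph with edges 0-1, 0-2, 1-2, 2-3, stored in both directions.
edges = [(0, 1), (1, 0), (0, 2), (2, 0), (1, 2), (2, 1), (2, 3), (3, 2)]
delta, E_hat = build_compact_adj(edges, n=4)
rng = np.random.default_rng(0)
print(sample_neighbors(np.array([0, 1, 2, 3]), delta, E_hat, rng))
```

Because $R < 1$, every column index stays inside the first $\delta_u$ entries of its row, so no unset entry is ever read.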

input: $u$ (current node); $T \leftarrow []$ (path leading to $u$, starts empty); $F$ (list of fanouts); AccumulateFn (function with side-effects and no return; it is model-specific and records information for computing the model and/or objective, see text); BiasFn $\leftarrow \mathcal{U}$ (function mapping $u$ to a distribution on $u$'s neighbors, defaults to uniform)

def Traverse(T, u, F, AccumulateFn, BiasFn):
    if F.size() = 0 then return                        # Base case. Traversed up-to requested depth
    f ← F[0]                                            # fanout duplication factor (i.e. breadth) at this depth
    sample_bias ← BiasFn(T, u)
    if sample_bias.sum() = 0 then return               # Special case. No sampling from zero mass
    sample_bias ← sample_bias / sample_bias.sum()      # valid distribution
    K ← Sample(Ê[u, :δ_u], sample_bias, f)             # Sample f nodes from u's neighbors
    for k ← 0 to f − 1 do
        T_next ← concatenate(T, [u])
        AccumulateFn(T_next, K[k], f)
        Traverse(T_next, K[k], F[1:], AccumulateFn, BiasFn)   # Recursion

def Sample(N, W, f):
    C ← tf.cumsum(W)                                   # Cumulative sum. Last entry must = 1.
    coin_flips ← tf.random.uniform((f,), 0, 1)
    indices ← tf.searchsorted(C, coin_flips)
    return N[indices]
Algorithm 1 Stochastic Traverse Functional, parametrized by AccumulateFn and BiasFn.
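For readers who want to execute Algorithm 1, here is a minimal single-seed NumPy sketch. It is our illustration, not the authors' API: it assumes CompactAdj is held as a small dense, -1-padded array rather than the sparse tensor of Section 3.1, passes the remaining fanouts explicitly as `F[1:]`, and uses `np.searchsorted` in place of TensorFlow's `tf.searchsorted`. The AccumulateFn `record_visit` is a hypothetical example that simply logs each visited walk-forest instance.

```python
import numpy as np

def sample(neighbors, weights, f, rng):
    """Draw f neighbors i.i.d. from `weights` by inverse-CDF sampling."""
    cdf = np.cumsum(weights)
    cdf = cdf / cdf[-1]                             # make the last entry exactly 1
    # With side="right", an entry is chosen only if its weight increases the CDF,
    # so zero-weight neighbors are never selected.
    idx = np.searchsorted(cdf, rng.random(f), side="right")
    return neighbors[idx]

def traverse(T, u, F, accumulate_fn, bias_fn, delta, E_hat, rng):
    """Recursive stochastic traversal from node u; T is the path leading to u."""
    if len(F) == 0:                                 # base case: requested depth reached
        return
    f = F[0]                                        # fanout at this depth
    bias = np.asarray(bias_fn(T, u), dtype=float)
    if bias.sum() == 0:                             # special case: no mass to sample from
        return
    bias = bias / bias.sum()                        # valid distribution
    K = sample(E_hat[u, :delta[u]], bias, f, rng)   # f samples from u's neighbors
    T_next = T + [u]
    for k in range(f):
        accumulate_fn(T_next, int(K[k]), f)
        traverse(T_next, int(K[k]), F[1:], accumulate_fn, bias_fn, delta, E_hat, rng)

# Toy graph with edges 0-1, 0-2, 1-2, 2-3 (both directions); -1 marks unset entries.
delta = np.array([2, 2, 3, 1])
E_hat = np.array([[1, 2, -1], [0, 2, -1], [0, 1, 3], [2, -1, -1]])

visits = []
def record_visit(T, u, f):                          # hypothetical AccumulateFn
    visits.append((tuple(T), u))

uniform_bias = lambda T, u: np.ones(delta[u])       # default (uniform) BiasFn
traverse([], 0, [3, 2], record_visit, uniform_bias, delta, E_hat, np.random.default_rng(0))
print(len(visits))  # 9 = 3 + 3 * 2 visited instances
```

Passing `F[1:]` rather than mutating `F` keeps sibling subtrees independent, which is what makes each seed the root of a tree with $f_1$ children, $f_1 f_2$ grandchildren, and so on.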

3.2 Stochastic Traversal Functional Algorithm

Our traversal algorithm starts from a batch of nodes. It expands from each into a tree, resulting in a walk forest rooted at the nodes in the batch, as depicted in Figure 1(d). In particular, given a node batch $B$, the algorithm instantiates $|B|$ seed walkers, placing one at every node in $B$. Iteratively, each walker first replicates itself a fanout ($f$) number of times. Each replica then samples and transitions to a neighbor. This process repeats a depth ($h$) number of times. Therefore, each seed walker becomes the ancestor of an $f$-ary tree with height $h$. Setting $f = 1$ recovers the traditional random walk. In practice, we provide flexibility by allowing a custom fanout value per depth.

Functional Traverse is listed in Algorithm 1. It accepts: a batch of nodes (our pseudo-code displays the traversal starting from one node rather than a batch only for clarity; our actual implementation is vectorized, e.g. $u$ would be a vector of nodes and $T$ would be a 2D matrix with each row containing the transition path preceding the corresponding entry in $u$; refer to the Appendix and code); a list of fanout values $F$ (e.g. $F = [3, 5]$ to sample 3 neighbors per $u \in B$, then 5 neighbors for each of those); and, more notably, two functions: AccumulateFn and BiasFn. These functions will be called by the functional on every node visited along the traversal, and will be passed relevant information (e.g. the path taken from the root seed node). Custom settings of these functions allow recovering wide classes of graph learning methods. At a high level, our functional can be used in the following manner:

  1. Construct the model & initialize parameters (e.g. to random). Define AccumulateFn and BiasFn.

  2. Repeat (many rounds):

    a. Reset accumulation information (from the previous round) and then sample batch $B \subset V$.

    b. Invoke Traverse on ($B$, AccumulateFn, BiasFn), which invokes the Fn's, allowing the first to accumulate information sufficient for running the model and estimating an objective.

    c. Use the accumulated information to: run the model, estimate the objective, and apply a learning rule (e.g. SGD).

AccumulateFn is a function that is used to track the information necessary for computing the model and the objective function. For instance, an implementation of DeepWalk (DeepWalk) on top of GTTF specializes AccumulateFn to measure an estimate of the sampled softmax likelihood of nodes' positional distribution, modeled as a dot-product of node embeddings. On the other hand, GCN (kipf) on top of GTTF uses it to accumulate a sampled adjacency matrix, which it passes to the underlying model (e.g. a 2-layer GCN) as if this were the full adjacency matrix.


BiasFn is a function that customizes the sampling procedure for the stochastic transitions. If provided, it must yield a probability distribution over nodes, given the current node and the path that led to it. If not provided, it defaults to $\mathcal{U}$, transitioning to any neighbor with equal probability. It can be defined to read edge weights, if they denote importance, or, more intricately, be used to parameterize a second-order Markov Chain (node2vec), or use neighborhood attention to guide sampling (gat), as discussed in the Appendix.
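As an illustration of a non-uniform BiasFn, the following hypothetical sketch (our own, not the authors' code) implements node2vec's second-order transition rule with return parameter $p$ and in-out parameter $q$: from current node $u$, reached via previous node $t = T_{-1}$, a neighbor $x$ gets weight $1/p$ if $x = t$, $1$ if $x$ is also a neighbor of $t$, and $1/q$ otherwise. It assumes the dense, -1-padded toy CompactAdj used for illustration and returns unnormalized weights, since Algorithm 1 normalizes `sample_bias` itself.

```python
import numpy as np

def make_node2vec_bias(delta, E_hat, p, q):
    """Hypothetical factory returning a node2vec-style BiasFn(T, u)."""
    def bias_fn(T, u):
        neighbors = E_hat[u, :delta[u]].tolist()
        if len(T) == 0:                        # seed node: no previous node, uniform
            return np.ones(len(neighbors))
        t = T[-1]                              # node the walker came from
        t_neighbors = set(E_hat[t, :delta[t]].tolist())
        weights = []
        for x in neighbors:
            if x == t:
                weights.append(1.0 / p)        # return to the previous node
            elif x in t_neighbors:
                weights.append(1.0)            # stay at distance 1 from t
            else:
                weights.append(1.0 / q)        # move outward, distance 2 from t
        return np.array(weights)               # unnormalized; Traverse normalizes
    return bias_fn

delta = np.array([2, 2, 3, 1])
E_hat = np.array([[1, 2, -1], [0, 2, -1], [0, 1, 3], [2, -1, -1]])
bias_fn = make_node2vec_bias(delta, E_hat, p=2.0, q=4.0)
print(bias_fn([0], 2).tolist())  # neighbors of 2 are [0, 1, 3] -> [0.5, 1.0, 0.25]
```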

3.3 Some Specializations of AccumulateFn & BiasFn

3.3.1 Message Passing: Graph Convolutional variants

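These methods, including (kipf; sage; simplegcn; mixhop; xu2018jknets), can be approximated by initializing $\widehat{A}$ to an empty sparse $n \times n$ matrix, then invoking Traverse (Algorithm 1) with $u = B$ and $F$ set to a list of fanouts of size $h$. Thus, AccumulateFn and BiasFn become:

$$\text{RootedAdjAcc}(T, u, f) \triangleq \widehat{A}[T_{-1}, u] \leftarrow 1; \tag{1}$$

$$\text{NoRevisitBias}(T, u) \triangleq \mathbb{1}\left[\widehat{A}[u].\text{sum}() = 0\right] \frac{\vec{1}_{\delta_u}}{\delta_u}; \tag{2}$$

where $\vec{1}_k$ denotes a $k$-dimensional all-ones vector, and negative indexing $T_{-k}$ is the $k$-th last entry of $T$ (so $T_{-1}$ is the node preceding $u$). If a node has been visited through the stochastic traversal, then it already has fanout number of neighbors, and NoRevisitBias ensures it does not get revisited, for efficiency, via the zero-mass special case of Algorithm 1. Afterwards, the accumulated stochastic $\widehat{A}$ will be fed into the underlying model (before feeding the batch to the model, in practice, we find nodes not reached by the traversal and remove their corresponding rows, and also columns, from $X$ and $\widehat{A}$), e.g. for a 2-layer GCN of kipf:

$$\text{GCN}(X, \widehat{A}; W_1, W_2) = \text{softmax}\left(\mathring{A}\, \text{ReLU}\left(\mathring{A} X W_1\right) W_2\right); \tag{3}$$

with $\mathring{A} = D'^{-1/2} A' D'^{-1/2}$, $D' = \text{diag}(\delta')$, $\delta' = A' \vec{1}_n$, and $A' = I_{n \times n} + \widehat{A}$.

Lastly, $h$ should be set to the receptive field required by the model for obtaining output features at the labeled node batch: in particular, the number of GC layers multiplied by the number of hops each layer accesses. E.g. hops $= 1$ for GCN, but customizable for MixHop and SimpleGCN.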
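To make the message-passing specialization concrete, the sketch below (hypothetical names, toy dense arrays) accumulates a rooted adjacency with RootedAdjAcc, prunes re-expansion with NoRevisitBias, and runs the renormalized 2-layer GCN forward pass of kipf on the result. It assumes the rooted adjacency records the edge from the parent $T_{-1}$ to the child $u$, uses row sums for the degrees, samples with `rng.choice` instead of cumsum/searchsorted, and re-declares a compact traversal so that the block is self-contained; it also skips the removal of unreached nodes mentioned above.

```python
import numpy as np

# Toy graph (edges 0-1, 0-2, 1-2, 2-3); E_hat is dense and -1-padded for readability.
n = 4
delta = np.array([2, 2, 3, 1])
E_hat = np.array([[1, 2, -1], [0, 2, -1], [0, 1, 3], [2, -1, -1]])
A_hat = np.zeros((n, n))                     # rooted adjacency, starts empty

def rooted_adj_acc(T, u, f):
    A_hat[T[-1], u] = 1.0                    # record edge parent -> child

def no_revisit_bias(T, u):
    # Uniform over u's neighbors, or all-zero if u was already expanded (row non-empty).
    return float(A_hat[u].sum() == 0) * np.ones(delta[u]) / delta[u]

def traverse(T, u, F, rng):
    """Compact version of Algorithm 1 with the two functions above plugged in."""
    if not F:
        return
    bias = no_revisit_bias(T, u)
    if bias.sum() == 0:                      # zero mass: do not re-expand u
        return
    K = rng.choice(E_hat[u, :delta[u]], size=F[0], p=bias / bias.sum())
    for v in K.tolist():
        rooted_adj_acc(T + [u], v, F[0])
        traverse(T + [u], v, F[1:], rng)

def renormalize(A_hat):
    A_prime = np.eye(len(A_hat)) + A_hat     # add self-loops
    d = A_prime.sum(axis=1)                  # degrees (row sums)
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
    return D_inv_sqrt @ A_prime @ D_inv_sqrt

def gcn_forward(X, A_ring, W1, W2):
    H = np.maximum(A_ring @ X @ W1, 0.0)     # ReLU
    logits = A_ring @ H @ W2
    logits = logits - logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    return P / P.sum(axis=1, keepdims=True)  # row-wise softmax

rng = np.random.default_rng(0)
for seed in [0, 3]:                          # batch B = [0, 3]; h = 2 for a 2-layer GCN
    traverse([], seed, [2, 2], rng)
X = rng.normal(size=(n, 5))
W1, W2 = rng.normal(size=(5, 8)), rng.normal(size=(8, 3))
P = gcn_forward(X, renormalize(A_hat), W1, W2)
```

Recording parent-to-child edges means row $u$ stays empty until $u$ itself is expanded, which is exactly the condition NoRevisitBias checks.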

3.3.2 Node Embeddings

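Given a batch of nodes $B$, DeepWalk (we present more methods in the Appendix) can be implemented in GTTF by first initializing the loss to the contrastive term estimating the partition function of log-softmax:

$$\text{loss} \leftarrow \sum_{u \in B} \log \mathbb{E}_{v \sim P_n(V)}\left[\exp\left(\langle Z_u, Z_v \rangle\right)\right], \tag{4}$$

where $\langle \cdot, \cdot \rangle$ is dot-product notation, $Z \in \mathbb{R}^{n \times d}$ is the trainable embedding matrix, and $Z_u \in \mathbb{R}^d$ is the $d$-dimensional embedding of node $u \in V$. In our experiments, we estimate the expectation by taking 5 samples, and we set the negative distribution $P_n(V = v) \propto \delta_v^{3/4}$, following word2vec.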
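The initialization of the loss can be sketched as below. The helper names are hypothetical, and the sketch assumes the contrastive term is the log of a Monte Carlo average of $\exp\langle Z_u, Z_v \rangle$ over 5 negatives drawn from $P_n(v) \propto \delta_v^{3/4}$, as described above.

```python
import numpy as np

def negative_distribution(delta):
    """P_n(v) proportional to delta_v ** 0.75, following word2vec."""
    w = np.asarray(delta, dtype=float) ** 0.75
    return w / w.sum()

def partition_term(Z, B, P_n, rng, num_samples=5):
    """Monte Carlo estimate of sum_{u in B} log E_{v ~ P_n}[exp(<Z_u, Z_v>)]."""
    total = 0.0
    for u in B:
        vs = rng.choice(len(P_n), size=num_samples, p=P_n)   # negative samples
        total += float(np.log(np.mean(np.exp(Z[vs] @ Z[u]))))
    return total

rng = np.random.default_rng(0)
Z = rng.normal(scale=0.1, size=(4, 8))       # n = 4 nodes, d = 8 dimensions
P_n = negative_distribution([2, 2, 3, 1])    # out-degrees of a toy graph
loss = partition_term(Z, [0, 3], P_n, rng)   # initial value of the loss
```

The $3/4$ exponent flattens the degree distribution, so high-degree nodes are still favored as negatives but less aggressively than under $P_n \propto \delta$.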

The functional is invoked with no BiasFn and AccumulateFn set to:

$$\text{AccumulateFn}(T, u, f) \triangleq \eta[u] \leftarrow \frac{\eta[T_{-1}]}{f}; \quad \text{loss} \leftarrow \text{loss} - \eta[u] \sum_{k=1}^{\min(C, |T|)} \frac{C - k + 1}{C} \langle Z_u, Z_{T_{-k}} \rangle, \tag{5}$$

where hyperparameter $C$ indicates the maximum window size (inherited from word2vec (word2vec)); $\min(C, |T|)$ in the summation over $k$ ensures it does not access invalid entries of $T$; the scalar fraction $\frac{C - k + 1}{C}$ is inherited from context sampling of word2vec (Section 3.1 in levy2015improving) and rederived for graph context by abu-wys; and $\eta$ stores a scalar per node on the traversal Walk Forest, which defaults to 1 for non-initialized entries and is used as a correction term. DeepWalk conducts random walks (visualized as a straight line), whereas our walk tree has a branching factor of $f$. Setting fanout $f = 1$ recovers DeepWalk's simulation, though we found that $f > 1$ outperforms within fewer iterations: e.g., within 1 epoch, it outperforms DeepWalk's published implementation. Learning can be performed using the accumulated loss as: $Z \leftarrow Z - \epsilon \nabla_Z \text{loss}$.
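A hypothetical accumulator in the spirit of this DeepWalk specialization is sketched below; it is our reading of the description, not the authors' code. It assumes the correction term is propagated as $\eta(\text{child}) = \eta(\text{parent}) / f$ with seeds at 1, keys $\eta$ by walk-forest node (the full path) rather than by graph node, weights the context at distance $k$ by $(C - k + 1)/C$, and accumulates only the positive (context) terms; the partition term would be added separately. The instance is callable with the $(T, u, f)$ signature of AccumulateFn.

```python
import numpy as np

class DeepWalkAccumulator:
    """Hypothetical DeepWalk-style AccumulateFn, callable as acc(T, u, f)."""

    def __init__(self, Z, C):
        self.Z = Z          # embedding matrix, one row per node
        self.C = C          # maximum window size
        self.loss = 0.0     # positive terms only; add the partition term separately
        self.eta = {}       # correction term per walk-forest node, keyed by its path

    def __call__(self, T, u, f):
        parent = tuple(T)                                # walk-forest node of u's parent
        node = parent + (u,)
        self.eta[node] = self.eta.get(parent, 1.0) / f   # seeds default to 1
        positive = 0.0
        for k in range(1, min(self.C, len(T)) + 1):      # never index past the root
            weight = (self.C - k + 1) / self.C           # word2vec dynamic-window weight
            positive += weight * float(self.Z[u] @ self.Z[T[-k]])
        self.loss -= self.eta[node] * positive

Z = np.ones((4, 2))
acc = DeepWalkAccumulator(Z, C=2)
acc([0], 1, 2)       # depth-1 node: eta = 1/2, one context node
acc([0, 1], 2, 3)    # depth-2 node: eta = 1/6, two context nodes
print(acc.loss)
```

In this toy call sequence every dot product equals 2, so the first call subtracts $\tfrac{1}{2} \cdot 2$ and the second subtracts $\tfrac{1}{6} \cdot (2 + 1)$, giving a total close to $-1.5$.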

4 Theoretical Analysis

Due to space limitations, we include the full proofs of all propositions in the appendix.

4.1 Estimating power of transition matrix

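We show that it is possible with GTTF to accumulate an estimate of transition matrix $T$ to power $k$. Let $\Omega$ denote the walk forest generated by GTTF, $\Omega(u, k)$ the vector of nodes at depth $k$ of the walk tree rooted at $u \in B$, $\Omega(u, k, i)$ its $i$-th node, and $t^{u,v,k}_i$ the indicator random variable $\mathbb{1}[\Omega(u, k, i) = v]$. Let the estimate of the $k$-th power of the transition matrix be denoted $\widehat{T}^k$. Entry $\widehat{T}^k_{u,v}$ should be an unbiased estimate of $T^k_{u,v}$ for $u \in B$, with controllable variance. We define:

$$\widehat{T}^k_{u,v} = \frac{\sum_{i=1}^{|\Omega(u, k)|} t^{u,v,k}_i}{|\Omega(u, k)|} \tag{6}$$

The fraction in Equation 6 counts the number of times the walker starting at $u$ visits $v$ in $\Omega$, divided by the total number of nodes visited at the $k$-th step from $u$.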
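Equation 6 can be checked numerically on a small toy graph. The sketch below is our illustration (hypothetical helper names, dense -1-padded $\widehat{E}$): it grows the depth-$k$ layer of one walk tree by replicating walkers and taking one uniform step each, forms the fraction of Equation 6, and compares the average over many trees with the exact row of $T^k$. The exact row assumes uniform transitions, $T = D^{-1}A$, and is computed by $k$ repeated vector-matrix products (dense here for brevity).

```python
import numpy as np

delta = np.array([2, 2, 3, 1])
E_hat = np.array([[1, 2, -1], [0, 2, -1], [0, 1, 3], [2, -1, -1]])
n = len(delta)

def estimate_Tk_row(u, fanouts, rng):
    """Estimate row u of T^k (k = len(fanouts)) from one walk tree rooted at u."""
    frontier = np.array([u])
    for f in fanouts:
        frontier = np.repeat(frontier, f)                         # replicate walkers
        cols = np.floor(rng.random(len(frontier)) * delta[frontier]).astype(np.int64)
        frontier = E_hat[frontier, cols]                          # one uniform step each
    # visits of each v at depth k, divided by the number of depth-k nodes
    return np.bincount(frontier, minlength=n) / len(frontier)

def exact_Tk_row(u, k):
    """Row u of T^k via k vector-matrix products with T = D^{-1} A."""
    A = np.zeros((n, n))
    for a in range(n):
        A[a, E_hat[a, :delta[a]]] = 1.0
    T = A / A.sum(axis=1, keepdims=True)
    x = np.eye(n)[u]
    for _ in range(k):
        x = x @ T
    return x

rng = np.random.default_rng(0)
estimates = np.array([estimate_Tk_row(0, [3, 3], rng) for _ in range(2000)])
print(estimates.mean(axis=0), exact_Tk_row(0, 2))
```

Each single-tree estimate is noisy, but its average over trees approaches the exact row, which is the unbiasedness stated in Proposition 1.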

Proposition 1.

(UnbiasedTk) $\widehat{T}^k_{u,v}$, as defined in Equation 6, is an unbiased estimator of $T^k_{u,v}$.

Proposition 2.

(VarianceTk) Variance of our estimate is upper-bounded:

Naive computation of $k$-th powers of the transition matrix can be performed efficiently via repeated sparse matrix-vector multiplication. Specifically, each column of $T^k$ can be computed in $\mathcal{O}(km)$, where $m$ is the number of edges in the graph. Thus, computing $T^k$ in its entirety can be accomplished in $\mathcal{O}(nkm)$. However, this can still become prohibitively expensive if the graph grows beyond a certain size. GTTF, on the other hand, can estimate $T^k$ in time complexity independent of the size of the graph (Prop. 8), with low variance. Transition matrix powers are useful for many GRL methods (qiu2018network).

4.2 Unbiased Learning

As a consequence of Propositions 1 and 2, GTTF enables unbiased learning with variance control for classes of node embedding methods, and provides a convergence guarantee for graph convolution models under certain simplifying assumptions.

We start by analyzing node embedding methods. Specifically, we cover two general types. The first is based on matrix factorization of the power-series of the transition matrix, and the second is based on cross-entropy objectives, e.g., DeepWalk (DeepWalk) and node2vec (node2vec). These two are covered in Propositions 3 and 4, respectively.

Proposition 3.

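(UnbiasedTFactorization) Suppose $\mathcal{L} = \frac{1}{2}\left\|LR - \sum_{k} c_k T^k\right\|_F^2$, i.e. a factorization objective that can be optimized by gradient descent by calculating $\nabla_{L,R}\mathcal{L}$, where the $c_k$'s are scalar coefficients. Let its estimate be $\widehat{\mathcal{L}} = \frac{1}{2}\left\|LR - \sum_{k} c_k \widehat{T}^k\right\|_F^2$, where $\widehat{T}^k$ is obtained by GTTF according to Equation 6. Then $\mathbb{E}\left[\nabla_{L,R}\widehat{\mathcal{L}}\right] = \nabla_{L,R}\mathcal{L}$.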

Proposition 4.

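(UnbiasedLearnNE) When learning node embeddings $Z \in \mathbb{R}^{n \times d}$ with an objective function $\mathcal{L}(Z)$ decomposable as $\mathcal{L}(Z) = \sum_{u \in V} \mathcal{L}_1(Z, u) - \sum_{u, v \in V} \mathcal{L}_2(T, u, v)\, \mathcal{L}_3(Z, u, v)$, where $\mathcal{L}_2$ is linear over $T$, using $\widehat{T}$ in place of $T$ yields an unbiased estimate of $\nabla_Z \mathcal{L}$.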

Generally, $\mathcal{L}_1$ (and $\mathcal{L}_3$) score the similarity between disconnected (and connected) nodes $u$ and $v$. The above form of $\mathcal{L}$ covers a family of contrastive learning objectives that use cross-entropy loss and assume a logistic or (sampled-)softmax distribution. We provide, in the Appendix, the decompositions for the objectives of DeepWalk (DeepWalk), node2vec (node2vec) and WYS (abu-wys).

Proposition 5.

(UnbiasedMP) Given input activations, , graph conv layer can use rooted adjacency accumulated by RootedAdjAcc (1), to provide unbiased pre-activation output, i.e. , with and defined in (3).

Proposition 6.

(UnbiasedLearnMP) If the objective of a graph convolution model is convex and Lipschitz continuous, with minimizer $\theta^*$, then utilizing GTTF for graph convolution converges to $\theta^*$.

4.3 Complexity Analysis

Proposition 7.

Storage complexity of GTTF is $\mathcal{O}(m + n)$.

Proposition 8.

Time complexity of GTTF is $\mathcal{O}(b f^h)$ for batch size $b$, fanout $f$, and depth $h$.

Proposition 8 implies the speed of computation is irrespective of graph size. Methods implemented in GTTF inherit this advantage. For instance, the node embedding algorithm WYS (abu-wys) is $\mathcal{O}(n^3)$; however, we apply its GTTF implementation on large graphs.
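A back-of-the-envelope count shows where the bound of Proposition 8 comes from. Without NoRevisitBias-style pruning, a walk tree with per-depth fanouts $f_1, \dots, f_h$ has $\sum_{d=1}^{h} \prod_{j=1}^{d} f_j$ non-root nodes, i.e., that many AccumulateFn calls per seed; for constant fanout $f$ this equals $(f^{h+1} - f)/(f - 1) = \mathcal{O}(f^h)$, and neither $n$ nor $m$ appears. The helper below is a hypothetical illustration of that count.

```python
from itertools import accumulate
from operator import mul

def num_visits(b, fanouts):
    """AccumulateFn calls for b seeds with per-depth fanouts (no pruning)."""
    nodes_per_depth = accumulate(fanouts, mul)   # f_1, f_1*f_2, f_1*f_2*f_3, ...
    return b * sum(nodes_per_depth)

print(num_visits(1, [3, 2]))        # 9 = 3 + 6
print(num_visits(1024, [3, 3, 3]))  # 39936 = 1024 * (3 + 9 + 27)
```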

5 Experiments

We conduct experiments on 10 different graph datasets, listed in Table 1. We experimentally demonstrate the following. (1) Re-implementing baseline methods using GTTF maintains performance. (2) Previously-unscalable methods can be made scalable when implemented in GTTF. (3) GTTF achieves good empirical performance when compared to other sampling-based approaches hand-designed for Message Passing. (4) GTTF consumes less memory and trains faster than other popular Software Frameworks for GRL. To replicate our experimental results, for each cell of the tables we provide, in our code repository, one shell script that produces the metric, except where we indicate that the metric is copied from another paper. Unless otherwise stated, we used a fanout factor of 3 for GTTF implementations. Learning rates and model hyperparameters are included in the Appendix.

5.1 Node Embeddings for Link Prediction

In link prediction tasks, a graph is partially obstructed by hiding a portion of its edges. The task is to recover the hidden edges. We follow a popular approach to tackle this task: first learn node embeddings $Z$ from the observed graph, then predict the link between nodes $u$ and $v$ with score $\langle Z_u, Z_v \rangle$. We use two ranking metrics for evaluation: ROC-AUC, which measures how well methods rank the hidden edges above randomly-sampled negative edges, and Mean Rank.
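The two metrics can be made concrete with a short sketch. The function names are hypothetical; ROC-AUC is computed as the fraction of (hidden edge, negative edge) pairs in which the hidden edge scores higher (ties count as half), and Mean Rank is assumed here to be the average 1-based rank of each hidden edge among its own negative candidates, counting only strictly higher-scoring negatives. The exact candidate sets and tie handling behind Table 2 may differ.

```python
import numpy as np

def edge_scores(Z, edges):
    """Score each candidate edge (u, v) by the dot product <Z_u, Z_v>."""
    return np.array([float(Z[u] @ Z[v]) for u, v in edges])

def roc_auc(pos_scores, neg_scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = np.asarray(pos_scores, dtype=float)[:, None]
    neg = np.asarray(neg_scores, dtype=float)[None, :]
    return float(np.mean((pos > neg) + 0.5 * (pos == neg)))

def mean_rank(pos_scores, neg_scores_per_pos):
    """Average 1-based rank of each positive among its own negative candidates."""
    ranks = [1 + int(np.sum(np.asarray(negs) > p))
             for p, negs in zip(pos_scores, neg_scores_per_pos)]
    return float(np.mean(ranks))

print(roc_auc([3.0, 2.0], [1.0, 2.5]))                       # 0.75 (3 of 4 pairs)
print(mean_rank([3.0, 1.0], [[1.0, 2.0], [2.0, 0.5, 5.0]]))  # 2.0 (ranks 1 and 3)
```

Higher ROC-AUC is better, while a lower Mean Rank is better.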

We re-implement the node embedding methods DeepWalk (DeepWalk) and WYS (abu-wys) in GTTF (abbreviated $\mathcal{F}$). Table 2 summarizes link prediction test performance.

LiveJournal and Reddit are large datasets to which the original implementation of WYS is unable to scale. However, the scalable GTTF implementation, $\mathcal{F}$(WYS), sets a new state-of-the-art on these datasets. For the PPI and HepTh datasets, we copy accuracy numbers for DeepWalk and WYS from (abu-wys). For LiveJournal, we copy accuracy numbers for DeepWalk and PBG from (pytorch-biggraph). Note that a well-engineered approach (PBG, (pytorch-biggraph)), using a mapreduce-like framework, underperforms $\mathcal{F}$(WYS), which is a specialization of GTTF in a few lines of code.

5.2 Message Passing for Node Classification

We implement in GTTF the message passing models GCN (kipf), GraphSAGE (sage), MixHop (mixhop), and SimpleGCN (simplegcn), as their computation is straightforward. For GAT (gat) and GCNII (gcnii), as they are more intricate, we download the authors' code and wrap it as-is with our functional.

We show that we are able to run these models in Table 3 (left and middle), and that the GTTF implementations match the baselines' performance. For the left table, we copy numbers from the published papers. However, we update GAT to work with TensorFlow 2.0 and use our updated code (GAT*).

Dataset Split # Nodes # Edges # Classes Nodes Edges Tasks
PPI (a) 3,852 20,881 N/A proteins interaction LP
ca-HepTh (a) 80,638 24,827 N/A researchers co-authorship LP
ca-AstroPh (a) 17,903 197,031 N/A researchers co-authorship LP
LiveJournal (b) 4.85M 68.99M N/A users friendship LP
Reddit (c) 233,965 11.60M 41 posts user co-comment LP/FSC
Amazon (b) 2.6M 48.33M 31 products co-purchased FSC
Cora (d) 2,708 5,429 7 articles citation SSC
CiteSeer (d) 3,327 4,732 6 articles citation SSC
PubMed (d) 19,717 44,338 3 articles citation SSC
Products (e) 2.45M 61.86M 47 products co-purchased SSC
Table 1: Dataset summary. Tasks are LP, SSC, FSC, for link prediction, semi- and fully-supervised classification. Split indicates the train/validate/test partitioning, with (a) = (abu-wys), (b) = to be released, (c) = (sage), (d) = (planetoid); (e) = (ogb).
Method PPI HepTh Reddit
DeepWalk 70.6 91.8 93.5
$\mathcal{F}$(DeepWalk) 87.9 89.9 95.5
WYS 89.8 93.6 OOM
$\mathcal{F}$(WYS) 90.5 93.5 98.6

Method LiveJournal
DeepWalk 234.6
PBG 245.9
WYS OOM*
$\mathcal{F}$(WYS) 185.6
Table 2: Results of node embeddings on Link Prediction. Left: Test ROC-AUC scores. Right: Mean Rank on LiveJournal (lower is better), reported for consistency with pytorch-biggraph. $\mathcal{F}$(·) denotes the method re-implemented in GTTF. *OOM = Out of Memory.
Method Cora Citeseer Pubmed
GCN 81.5 70.3 79.0
$\mathcal{F}$(GCN) 81.9 69.8 79.4
MixHop 81.9 71.4 80.8
$\mathcal{F}$(MixHop) 83.1 71.8 80.9
GAT* 83.2 72.4 77.7
$\mathcal{F}$(GAT) 83.3 72.5 77.8
GCNII 85.5 73.4 80.3
$\mathcal{F}$(GCNII) 85.3 74.4 80.2

Method Reddit Amazon
SAGE 95.0 88.3
$\mathcal{F}$(SAGE) 95.9 88.5
SimpGCN 94.9 83.4
$\mathcal{F}$(SimpGCN) 94.8 83.8

Method Products
node2vec 72.1
ClusterGCN 75.2
GraphSAINT 77.3
$\mathcal{F}$(SAGE) 77.0
Table 3: Node classification tasks. Left: test accuracy scores on semi-supervised classification (SSC) of citation networks. Middle: test micro-F1 scores for large fully-supervised classification. Right: test accuracy on an SSC task, showing only scalable baselines. $\mathcal{F}$(·) denotes the method implemented in (or wrapped by) GTTF.
Speed (s per epoch) Reddit Products
DGL 17.3 13.4
PyG 5.8 9.2
GTTF 1.8 1.4

Memory (GB) Reddit Cora Citeseer Pubmed
DGL OOM 1.1 1.1 1.1
PyG OOM 1.2 1.3 1.6
GTTF 2.44 0.32 0.40 0.43
Table 4: Performance of GTTF against frameworks DGL and PyG. Left: Speed is the per-epoch time in seconds when training GraphSAGE. Memory is the memory in GB used when training GCN. All experiments were conducted using an AMD Ryzen 3 1200 Quad-Core CPU and an Nvidia GTX 1080Ti GPU. Right: Training curve for GTTF and PyG implementations of Node2Vec.

5.3 Experiments comparing against Sampling methods for Message Passing

We now compare models trained with GTTF (where samples are walk forests) against sampling methods that are specifically designed for Message Passing algorithms (GraphSAINT and ClusterGCN), whose sampling strategies differ from ours.

Table 3 (right) shows test node classification accuracy on a large dataset: Products. We calculate the accuracy for $\mathcal{F}$(SAGE), the GTTF implementation of GraphSAGE, but copy from (ogb) the accuracy for the baselines: GraphSAINT (graphsaint-iclr2020) and ClusterGCN (clustergcn) (both are message passing methods), and also node2vec (node2vec) (a node embedding method).

5.4 Runtime and Memory comparison against optimized Software Frameworks

In addition to the accuracy metrics discussed above, we also care about computational performance. We compare against the software frameworks DGL (dgl) and PyG (pyg), which offer implementations of many methods. Table 4 summarizes the following. First (left), we show the time per epoch on large graphs of their implementation of GraphSAGE, compared with GTTF's, where we set all hyperparameters to be the same (model architecture, and number of neighbors at message passing layers). Second (middle), we run their GCN implementation on small datasets (Cora, Citeseer, Pubmed) to show peak memory usage. The run times of GTTF, PyG and DGL are similar on these datasets; the comparison can be found in the Appendix. While the aforementioned two comparisons are on popular message passing methods, the third (right) chart shows a popular node embedding method: node2vec's link prediction test ROC-AUC in relation to its training runtime.

6 Conclusion

We present a new algorithm, Graph Traversal via Tensor Functionals (GTTF), that can be specialized to re-implement the algorithms of various Graph Representation Learning methods. The specialization takes little effort per method, making it straightforward to port existing methods or introduce new ones. Methods implemented in GTTF run efficiently, as GTTF uses tensor operations to traverse graphs. In addition, the traversal is stochastic and therefore automatically makes the implementations scalable to large graphs. We theoretically show that, for the popular GRL methods we analyze, the learning outcome of the stochastic traversal is in expectation equivalent to that of the baseline when the graph is observed all at once. Our thorough experimental evaluation confirms that methods implemented in GTTF maintain their empirical performance, and can be trained faster and using less memory, even compared to software frameworks that have been thoroughly optimized.

7 Acknowledgements

We acknowledge support from the Defense Advanced Research Projects Agency (DARPA) under award FA8750-17-C-0106.



Appendix A Hyperparameters

For the general link prediction tasks we used a , , , 10 negative samples per edge, Adam optimizer with a learning rate of 0.5, multiplied by a factor of 0.2, every 50 steps, for 200 total iterations. The differences are listed below.
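Reading "multiplied by a factor of 0.2, every 50 steps" as a piecewise-constant step decay (our interpretation), the general link-prediction schedule can be written as the hypothetical helper below; the Reddit setting corresponds to `base_lr=2.0, factor=0.5, every=10`.

```python
def step_decay_lr(step, base_lr=0.5, factor=0.2, every=50):
    """Piecewise-constant decay: base_lr * factor ** (completed `every`-step periods)."""
    return base_lr * factor ** (step // every)

general = [step_decay_lr(t) for t in range(200)]       # 200 total iterations
reddit = [step_decay_lr(t, base_lr=2.0, factor=0.5, every=10)
          for t in range(30)]                          # first 30 iterations, for illustration
print(general[0], general[50], general[199])
```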

The Reddit dataset was trained using a starting learning rate of 2.0, decaying 50% every 10 iterations.

The LiveJournal task was trained using a fixed learning rate of 0.001, , , and 50 negative samples per edge.

For the node classifications tasks:

For on Amazon, we use , a batch size of 1024, and a learning rate of 0.02, decaying by a factor of 0.2 after 2 and 6 epochs for a total of 25 epochs. On Reddit, it is the same except

For on Amazon we use , a two layer model, a batch size of 256, and fixed learning rates of 0.001 and 0.002 respectively. On Reddit we use , a fixed learning rate of 0.001, a hidden dimension of 256 and a batch size of 1024. On the Products dataset, we used , a fixed learning rate of 0.001 and a batch size of 1024, a hidden dimension of 256 and a fixed learning rate of 0.003.

For GAT (baseline), we follow the authors' code and hyperparameters: for Cora and Citeseer, we use Adam with a learning rate of 0.005, L2 regularization of 0.0005, 8 attention heads on the first layer and 1 attention head on the output layer. For Pubmed, we use Adam with a learning rate of 0.01, L2 regularization of 0.01, 8 attention heads on the first layer and 8 attention heads on the output layer. For , we use the same aforementioned hyperparameters, a fanout of 3 and a traversal depth of 2 (to cover two layers), i.e. . For , we use the authors' recommended hyperparameters: a learning rate of 0.005, 0.001 L2 regularization, and , for all datasets. For both methods, we apply "patience" and stop the training if the validation loss does not improve for 100 consecutive epochs, reporting the test accuracy at the best validation loss. For , we wrap the authors' script and use their hyperparameters. For , we use , as their models are deep (64 layers for Cora). Otherwise, we inherit their network hyperparameters (latent dimensions, number of layers, dropout factor, and their introduced coefficients), as they have tuned them per dataset, but we change the learning rate to (half of what they use), extend the patience from 100 to 1000, and extend the maximum number of epochs from 1500 to 5000. This is because we are presenting a subgraph at each epoch, and therefore we intuitively want to slow down the learning per epoch, which is similar to the practice when someone applies Dropout to a neural network. We re-run their shell scripts, with their code modified to use the Rooted Adjacency rather than the real adjacency, which is sampled at every epoch.

MLP was trained with 1 layer and a learning rate of 0.01.

Appendix B Proofs

B.1 Proof of Proposition 1


B.2 Proof of Proposition 2


Since , then is maximized with . Hence

B.3 Proof of Proposition 3


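Consider a $d$-dimensional factorization of $\sum_k c_k T^k$, where the $c_k$'s are scalar coefficients:

$$\mathcal{L} = \frac{1}{2}\left\|LR - \sum_{k} c_k T^k\right\|_F^2,$$

parametrized by $L \in \mathbb{R}^{n \times d}$ and $R \in \mathbb{R}^{d \times n}$. The gradients of $\mathcal{L}$ w.r.t. the parameters are:

$$\nabla_L \mathcal{L} = \left(LR - \sum_{k} c_k T^k\right) R^\top, \qquad \nabla_R \mathcal{L} = L^\top \left(LR - \sum_{k} c_k T^k\right).$$

Given the estimated objective (replacing $T^k$ with the GTTF-estimated $\widehat{T}^k$):

$$\widehat{\mathcal{L}} = \frac{1}{2}\left\|LR - \sum_{k} c_k \widehat{T}^k\right\|_F^2,$$

it follows that:

$$\begin{aligned}
\mathbb{E}\left[\nabla_L \widehat{\mathcal{L}}\right] &= \mathbb{E}\left[\left(LR - \sum_{k} c_k \widehat{T}^k\right) R^\top\right] \\
&= \mathbb{E}\left[LR - \sum_{k} c_k \widehat{T}^k\right] R^\top && \text{(scaling property of expectation)} \\
&= \left(LR - \sum_{k} c_k \mathbb{E}\left[\widehat{T}^k\right]\right) R^\top && \text{(linearity of expectation)} \\
&= \left(LR - \sum_{k} c_k T^k\right) R^\top = \nabla_L \mathcal{L} && \text{(Proposition 1)}
\end{aligned}$$

The above steps can similarly be used to show $\mathbb{E}\left[\nabla_R \widehat{\mathcal{L}}\right] = \nabla_R \mathcal{L}$.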

B.4 Proof of Proposition 4


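We want to show that $\mathbb{E}\left[\nabla_Z \widehat{\mathcal{L}}\right] = \nabla_Z \mathcal{L}$. Since the terms of $\mathcal{L}_1$ are unaffected by $\widehat{T}$, they are excluded w.l.o.g. from $\mathcal{L}$ in the proof.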

The following table gives the decomposition for DeepWalk, node2vec, and Watch Your Step. Node2vec also introduces a biased sampling procedure based on hyperparameters (they name and ) instead of uniform transition probabilities. We can equivalently bias the transitions in GTTF to match node2vec’s. This would then show up as a change in in the objective. This effect can also be included in the objective by multiplying by the probability of such a transition in . In this format, the and variables appear in the objective and can be included in the optimization. For WYS, are also trainable parameters.

Watch Your Step
Table 5: Decomposition of graph embedding methods to demonstrate unbiased learning. For WYS, .

For methods in which the transition distribution is not uniform, such as node2vec, there are two options for incorporating this distribution in the loss. The obvious choice is to sample from a biased transition matrix, , where are the transition weights. Alternatively, the transition bias can be used as a weight on the objective itself. This approach is still unbiased as
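The equivalence of the two options can be checked exactly on a toy neighborhood. This sketch assumes the weight applied under uniform sampling is the importance ratio (biased probability over uniform probability); the candidate nodes, weights, and per-transition loss values are hypothetical.

```python
from fractions import Fraction
from typing import Dict

Dist = Dict[str, Fraction]


def normalize(weights: Dict[str, Fraction]) -> Dist:
    """Turn unnormalized transition weights into probabilities."""
    total = sum(weights.values())
    return {x: w / total for x, w in weights.items()}


def expectation(dist: Dist, loss: Dict[str, Fraction]) -> Fraction:
    """Exact expectation of a per-transition loss under `dist`."""
    return sum((p * loss[x] for x, p in dist.items()), Fraction(0))


# Hypothetical example: three candidate next nodes with unnormalized
# node2vec-style transition weights and an arbitrary per-transition loss term.
weights = {"a": Fraction(2), "b": Fraction(1), "c": Fraction(1, 2)}
loss = {"a": Fraction(3), "b": Fraction(-1), "c": Fraction(5)}

biased = normalize(weights)                                # option 1: sample from the biased matrix
uniform = {x: Fraction(1, len(weights)) for x in weights}  # option 2: sample uniformly...
reweighted = {x: biased[x] / uniform[x] * loss[x] for x in loss}  # ...and weight the objective

assert expectation(biased, loss) == expectation(uniform, reweighted)
print(expectation(biased, loss))  # 15/7
```

Option 1 spends samples where the bias puts mass; option 2 keeps uniform sampling and moves the bias into the loss, which is what lets the bias parameters enter the optimization.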

b.5 Proof of Proposition 5


Let be the neighborhood patch returned by GTTF, and let indicate a measurement based on the sampled graph, , such as the degree vector, , or diagonal degree matrix, . For the remainder of this proof, let all notation for adjacency matrices, or , and diagonal degree matrices, or , and degree vector, , refer to the corresponding measure on the graph with self loops e.g. . We now show that the expectation of the layer output is unbiased.

implies that is unbiased if .

Let be the set of all walks , and let indicate that the path exists in the graph given by . Let be the transition probability from to in steps, and let be the probability of a random walker traversing the graph along path .

Thus, and

For ease of exposition, we assumed nodes have degree, , though the proof still holds if that is not the case: for such a low-degree node every outgoing edge is always present, so its transition probabilities coincide with the full-graph ones, i.e. the same as no estimate at all.
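For a single row, the unbiasedness argument can be verified by exact enumeration. The sketch assumes GTTF keeps min(fanout, degree) distinct neighbors of a node, each subset equally likely, and that the sampled row is normalized by the sampled degree; the function name is hypothetical and the neighbor list may include the node itself when self-loops are added.

```python
from fractions import Fraction
from itertools import combinations
from typing import Dict, List


def expected_sampled_row(neighbors: List[int], fanout: int) -> Dict[int, Fraction]:
    """Exact E[row u of D_hat^{-1} A_hat] when min(fanout, deg) distinct
    neighbors of u are kept, every such subset being equally likely."""
    keep = min(fanout, len(neighbors))
    subsets = list(combinations(neighbors, keep))
    expected = {v: Fraction(0) for v in neighbors}
    for subset in subsets:
        for v in subset:
            # In the sampled patch u has degree `keep`, so each kept edge
            # carries normalized weight 1/keep.
            expected[v] += Fraction(1, keep) / len(subsets)
    return expected


nbrs = [1, 2, 3, 4, 5]
for f in (1, 2, 3, 7):  # f = 7 exceeds the degree: every edge is always kept
    assert all(w == Fraction(1, len(nbrs)) for w in expected_sampled_row(nbrs, f).values())
print(expected_sampled_row(nbrs, 2))
```

Every entry equals one over the full degree, i.e. the corresponding entry of the full normalized adjacency; multiplying by fixed features and weights preserves this by linearity.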

b.6 Proof of Proposition 6

GTTF can be seen as a way of applying dropout (dropout), and our proof is contingent on the convergence of dropout, which is shown in dropout2014baldi. Our dropout is on the adjacency, rather than the features. Denote the output of a graph convolution network (the following definition averages the node features, i.e. uses non-symmetric normalization, and appears in multiple GCNs, including sage) with :

We restrict our analysis to GCNs with linear activations. We are interested in quantifying the change of as changes, and therefore the fixed (always visible) features are placed in the subscript. Let denote the adjacency accumulated by GTTF’s RootedAdjAcc (Eq. 1).

Let denote the (countable) set of all adjacency matrices realizable by GTTF. For the analysis, assume the graph is -regular: the assumption eases the notation, though it is not needed. Therefore, degree for all . Our analysis depends on (if the graph is not -regular, it would instead depend on ), i.e. the average matrix realizable by GTTF is proportional (entry-wise) to the full adjacency. This can be shown by considering one row at a time: given a node with outgoing neighbors, each of its neighbors has the same appearance probability . Summing over all combinations makes each edge appear with the same frequency , noting that evenly divides for all .
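The appearance probability follows from a counting argument. Writing d for the common degree and f ≤ d for the fanout (symbols chosen here for illustration), and assuming f distinct neighbors are kept with every subset equally likely, an edge (u, v) is kept in exactly C(d-1, f-1) of the C(d, f) subsets:

```latex
\Pr\big[(u,v)\ \text{is kept}\big]
  = \frac{\binom{d-1}{f-1}}{\binom{d}{f}}
  = \frac{(d-1)!}{(f-1)!\,(d-f)!}\cdot\frac{f!\,(d-f)!}{d!}
  = \frac{f}{d},
\qquad\text{so}\qquad
\mathbb{E}\big[\widehat{A}_{u,:}\big] = \frac{f}{d}\,A_{u,:}.
```

The ratio is the same for every edge of the row, which is the entry-wise proportionality the analysis relies on.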

We define a dropout module:


where acts as a Multinoulli selector over the elements of , with one of its entries set to 1 and all others to zero. With this definition, GCNs can be seen in the dropout framework as: . However, in order to inherit the analysis of (dropout2014baldi, see their equations 140 & 141), we need to satisfy two conditions on which their analysis is founded:

  (i) Unbiasedness of the sampled output in expectation: in the usual (feature-wise) dropout, such a condition is easily verified.

  (ii) The backpropagated error signal does not vary too much around the mean, across all realizations of .

Condition (i) is satisfied due to the proof of Proposition 5. To analyze the error signal, i.e. the gradient of the error w.r.t. the network, assume the loss function , which outputs a scalar loss, is Lipschitz continuous. The Lipschitz continuity allows us to bound the difference in error signal between and :


where (a) is by Lipschitz continuity, (b) is by the Cauchy–Schwarz inequality, and “w.p.” means “with probability” and follows from Chebyshev’s inequality, with the subsequent equality holding because the variance of is shown element-wise in the proof of Proposition 2. Finally, we obtain the last line by dividing both sides by the common term. This shows that one can make the difference in error signal between realizations arbitrarily small, for example, by choosing a larger fanout value or by putting (convex) norm constraints on and , e.g. through batchnorm and/or weightnorm. Since we can have with high probability, the analysis of dropout2014baldi applies. Effectively, GTTF can be thought of as an online learning algorithm in which the elements of are the stochastic training examples, analyzed as in (onlinealgorithms; stochasticlearning) and as explained by dropout2014baldi.

b.7 Proof of Proposition 7

The storage complexity of CompactAdj is .

Moreover, for extremely large graphs, the adjacency can be row-wise partitioned across multiple machines, admitting linear scaling. However, we acknowledge that the choice of which rows to assign to which machines can drastically affect performance. Balanced partitioning is ideal; it is an NP-hard problem, but many approximations have been proposed. Nonetheless, reducing inter-machine communication when distributing the data structure across machines is outside our scope.
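To make the storage and partitioning claims concrete, here is a hypothetical CompactAdj-style store. It assumes a layout of a degree vector plus a flat neighbor array addressed by per-row offsets (CSR style), which may differ from the paper's own sparse-tensor layout; the class name and the naive contiguous split in `row_block` are illustrative only.

```python
from typing import Iterable, List, Tuple


class CompactAdjSketch:
    """Hypothetical CompactAdj-style store: a degree vector plus one flat
    neighbor array addressed by per-row offsets (CSR layout).

    Memory: len(degree) = |V|, len(offsets) = |V| + 1, len(neighbors) = |E|.
    """

    def __init__(self, num_nodes: int, edges: Iterable[Tuple[int, int]]):
        rows: List[List[int]] = [[] for _ in range(num_nodes)]
        for u, v in edges:  # directed; add both (u, v) and (v, u) if undirected
            rows[u].append(v)
        self.degree = [len(r) for r in rows]
        self.offsets = [0]
        for d in self.degree:
            self.offsets.append(self.offsets[-1] + d)
        self.neighbors = [v for r in rows for v in r]

    def row(self, u: int) -> List[int]:
        """All neighbors of node u."""
        return self.neighbors[self.offsets[u]:self.offsets[u + 1]]

    def neighbor(self, u: int, i: int) -> int:
        """The i-th neighbor of node u, in O(1)."""
        if not 0 <= i < self.degree[u]:
            raise IndexError(f"node {u} has only {self.degree[u]} neighbors")
        return self.neighbors[self.offsets[u] + i]

    def row_block(self, lo: int, hi: int) -> "CompactAdjSketch":
        """Rows [lo, hi) as a standalone store (local row ids, global neighbor
        ids), e.g. one machine's share in a naive contiguous partition."""
        return CompactAdjSketch(hi - lo, ((u - lo, v) for u in range(lo, hi)
                                          for v in self.row(u)))


g = CompactAdjSketch(3, [(0, 1), (1, 0), (1, 2), (2, 1)])
print(g.degree, g.neighbor(1, 1), g.row_block(1, 3).row(0))  # [1, 2, 1] 2 [0, 2]
```

Constant-time access to the i-th neighbor is what lets a traversal sample a neighbor index from the degree and then look it up directly; a contiguous split like `row_block` ignores balance and communication, which the text leaves out of scope.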

b.8 Proof of Proposition 8

For each step of GTTF, the computational complexity is . This follows directly from the GTTF functional: each node in the batch ( of them) builds a tree with depth and fanout , i.e. with tree nodes. This calculation assumes that random number generation, AccumulateFn, and BiasFn take constant time. The searchsorted function is linear, as it is called on a sorted list: the cumulative sum of probabilities.
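The tree-size count can be sketched with a toy traversal. The sketch assumes uniform bias, sampling with replacement, and that every node has at least one neighbor; `traverse` and `sample_index` are hypothetical stand-ins for GTTF's BiasFn and AccumulateFn hooks, not the paper's implementation. Sampling uses a searchsorted-style binary search (`bisect`) over a cumulative sum of probabilities.

```python
import bisect
import itertools
import random
from typing import List, Sequence, Tuple


def sample_index(probs: Sequence[float], rng: random.Random) -> int:
    """Inverse-CDF sampling: binary-search (searchsorted) a uniform draw in
    the cumulative sum of (unnormalized) probabilities."""
    cdf = list(itertools.accumulate(probs))
    r = rng.random() * cdf[-1]
    return min(bisect.bisect_right(cdf, r), len(cdf) - 1)


def traverse(adj: List[List[int]], batch: Sequence[int], fanouts: Sequence[int],
             seed: int = 0) -> List[Tuple[int, int, int]]:
    """Grow one sampled tree per batch node; returns (depth, parent, child) edges."""
    rng = random.Random(seed)
    edges: List[Tuple[int, int, int]] = []
    frontier = list(batch)
    for depth, fanout in enumerate(fanouts, start=1):
        next_frontier = []
        for node in frontier:
            nbrs = adj[node]
            probs = [1.0] * len(nbrs)  # BiasFn stand-in: uniform weights
            for _ in range(fanout):
                child = nbrs[sample_index(probs, rng)]
                edges.append((depth, node, child))  # AccumulateFn stand-in
                next_frontier.append(child)
        frontier = next_frontier
    return edges


cycle = [[(i - 1) % 5, (i + 1) % 5] for i in range(5)]
edges = traverse(cycle, batch=[0, 1], fanouts=[3, 3])
# b * (f + f^2) = 2 * (3 + 9) = 24 sampled tree nodes (excluding the roots)
print(len(edges))  # 24
```

The number of sampled nodes depends only on the batch size and fanouts, not on the size of the graph, which is the property the complexity statement captures.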

Appendix C Additional GTTF Implementations

c.1 Message Passing Implementations

c.1.1 Graph Attention Networks (GAT, gat)

One can implement GAT by following the previous subsection, utilizing AccumulateFn and BiasFn defined in (1) and (2), but replacing the model (3) with GAT’s:


where is the Hadamard product and is an matrix placing a positive scalar (an attention value) on each edge, parametrized by multi-headed attention described in (gat). However, for some high-degree nodes that put most of the attention weight on a small subset of their neighbors, sampling uniformly (with BiasFn=NoRevisitBias) might mostly sample neighbors whose entries in are close to zero, and could require more epochs for convergence. Our flexible functional, however, allows us to propose a sample-efficient alternative that is, in expectation, equivalent to the above:
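One possible form of such an alternative, offered as an assumption rather than the paper's exact formula, is importance sampling proportional to attention. The sketch draws a single neighbor for one node with scalar features and compares two unbiased estimators of the attention-weighted sum: uniform sampling rescaled by the degree, and attention-proportional sampling. Exact means and variances are computed with rational arithmetic; the attention values are hypothetical.

```python
from fractions import Fraction
from typing import List, Tuple


def mean_var(values: List[Fraction], probs: List[Fraction]) -> Tuple[Fraction, Fraction]:
    """Exact mean and variance of a discrete random variable."""
    mean = sum((p * v for p, v in zip(probs, values)), Fraction(0))
    var = sum((p * (v - mean) ** 2 for p, v in zip(probs, values)), Fraction(0))
    return mean, var


def uniform_estimator(alpha: List[Fraction], x: List[Fraction]):
    """Draw one neighbor j uniformly; estimate sum_j alpha_j x_j by d * alpha_j * x_j."""
    d = len(alpha)
    return mean_var([d * a * xi for a, xi in zip(alpha, x)], [Fraction(1, d)] * d)


def attention_estimator(alpha: List[Fraction], x: List[Fraction]):
    """Draw j with probability alpha_j / sum(alpha); estimate by sum(alpha) * x_j."""
    total = sum(alpha)
    return mean_var([total * xi for xi in x], [a / total for a in alpha])


# Hypothetical node putting most of its attention on one neighbor.
alpha = [Fraction(9, 10)] + [Fraction(1, 30)] * 3
x = [Fraction(1), Fraction(2), Fraction(3), Fraction(4)]
target = sum(a * xi for a, xi in zip(alpha, x))  # 6/5

u_mean, u_var = uniform_estimator(alpha, x)
a_mean, a_var = attention_estimator(alpha, x)
assert u_mean == a_mean == target  # both are unbiased
print(u_var, a_var)  # attention-proportional sampling has the lower variance
```

Both estimators agree in expectation, but concentrating samples where the attention mass lies shrinks the variance, which is the sample-efficiency argument made above.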


c.1.2 Deep Graph Infomax (DGI, dgi)

DGI can be implemented on GTTF with AccumulateFn=RootedAdjAcc, defined in (1). To create the positive graph, one can sample some nodes, pass them to GTTF’s Traverse, and use the accumulated adjacency to run the GCN twice: once on the original features, and once on features in which the order of the nodes has been randomly permuted. Finally, the outputs of those GCNs can be fed into a readout function, whose output goes to a discriminator trying to classify whether the readout latent vector corresponds to the real or to the permuted features.
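A minimal sketch of this recipe follows, under several assumptions not taken from the text: the GCN is a single weightless mean-aggregation layer over the accumulated adjacency (a trainable layer would go there), the readout is an element-wise sigmoid of the mean embedding, and the discriminator is a sigmoid of a dot product (an identity bilinear weight), as in the standard DGI formulation. All names are hypothetical.

```python
import math
import random
from typing import Dict, List

Matrix = List[List[float]]


def mean_aggregate(adj: Dict[int, List[int]], feats: Matrix) -> Matrix:
    """Weightless linear GCN layer: average each node's own features with
    those of its neighbors in the accumulated (sampled) adjacency."""
    dim = len(feats[0])
    out = []
    for u in range(len(feats)):
        group = [u] + adj.get(u, [])
        out.append([sum(feats[v][c] for v in group) / len(group) for c in range(dim)])
    return out


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def readout(h: Matrix) -> List[float]:
    """Graph summary: element-wise sigmoid of the mean node embedding."""
    return [sigmoid(sum(row[c] for row in h) / len(h)) for c in range(len(h[0]))]


def dgi_loss(adj: Dict[int, List[int]], feats: Matrix, rng: random.Random) -> float:
    """Binary cross-entropy of a discriminator separating real from corrupted nodes."""
    perm = list(range(len(feats)))
    rng.shuffle(perm)
    corrupted = [feats[i] for i in perm]  # negative sample: permuted node order
    h_pos = mean_aggregate(adj, feats)
    h_neg = mean_aggregate(adj, corrupted)
    summary = readout(h_pos)

    def score(h: List[float]) -> float:
        # Discriminator with an identity bilinear weight (assumption).
        return sigmoid(sum(a * b for a, b in zip(h, summary)))

    eps = 1e-12  # guards log(0)
    loss = 0.0
    for hp, hn in zip(h_pos, h_neg):
        loss -= math.log(score(hp) + eps) + math.log(1.0 - score(hn) + eps)
    return loss / (2 * len(feats))


adj = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}  # toy accumulated adjacency
feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
print(dgi_loss(adj, feats, random.Random(0)))
```

Permuting feature rows keeps the accumulated adjacency fixed while breaking the association between structure and features, which is what the discriminator is trained to detect.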

c.2 Node Embedding Implementations

c.2.1 Node2Vec (node2vec)

A simple implementation follows from the above: set N2vAcc to DeepWalkAcc, but override BiasFn =


where denotes the indicator function, are hyperparameters of node2vec assigning (unnormalized) probabilities for transitioning back to the previous node or to a node connected to it, and counts the mutual neighbors between the considered node and the previous node .
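For reference, the unnormalized second-order weights can be sketched as follows, using node2vec's conventional names p (return parameter) and q (in-out parameter). The rule is the standard one from the node2vec paper; whereas the BiasFn above is vectorized through mutual-neighbor counts, this per-node sketch checks set membership, and the function names are hypothetical.

```python
from typing import Dict, Set


def node2vec_bias(graph: Dict[int, Set[int]], prev: int, cur: int,
                  p: float, q: float) -> Dict[int, float]:
    """Unnormalized second-order weights for stepping from `cur` to each
    neighbor x, given that the walk arrived at `cur` from `prev`."""
    weights = {}
    for x in graph[cur]:
        if x == prev:
            weights[x] = 1.0 / p   # return to the previous node
        elif x in graph[prev]:
            weights[x] = 1.0       # x is also adjacent to prev (a mutual neighbor)
        else:
            weights[x] = 1.0 / q   # x moves further away from prev
    return weights


def to_probabilities(weights: Dict[int, float]) -> Dict[int, float]:
    """Normalize the weights into transition probabilities."""
    total = sum(weights.values())
    return {x: w / total for x, w in weights.items()}


graph = {0: {1, 2}, 1: {0, 2, 3}, 2: {0, 1}, 3: {1}}
w = node2vec_bias(graph, prev=0, cur=1, p=2.0, q=0.5)
print(sorted(w.items()))  # [(0, 0.5), (2, 1.0), (3, 2.0)]
```

With q < 1 the walk favors moving away from the previous node (DFS-like), while a large p discourages immediate returns.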

An alternative implementation is to not override BiasFn but rather fold it into AccumulateFn, as: