Self-Supervised Learning of Contextual Embeddings for Link Prediction in Heterogeneous Networks

07/22/2020 · by Ping Wang, et al. · Virginia Polytechnic Institute and State University · PNNL

Representation learning methods for heterogeneous networks produce a low-dimensional vector embedding for each node that is typically fixed for all tasks involving the node. Many of the existing methods focus on obtaining a static vector representation for a node in a way that is agnostic to the downstream application where it is being used. In practice, however, downstream tasks require specific contextual information that can be extracted from the subgraphs related to the nodes provided as input to the task. To tackle this challenge, we develop SLiCE, a framework bridging static representation learning methods using global information from the entire graph with localized attention-driven mechanisms to learn contextual node representations. We first pre-train our model in a self-supervised manner by introducing higher-order semantic associations and masking nodes, and then fine-tune our model for a specific link prediction task. Instead of training node representations by aggregating information from all semantic neighbors connected via metapaths, we automatically learn the composition of different metapaths that characterize the context for a specific task without the need for any pre-defined metapaths. SLiCE significantly outperforms both static and contextual embedding learning methods on several publicly available benchmark network datasets. We also interpret the semantic association matrix and demonstrate its utility and relevance in making successful link predictions between heterogeneous nodes in the network.


1 Introduction

The topic of representation learning for heterogeneous networks has gained a lot of attention in recent years Dong et al. (2017); Cen et al. (2019); Yun et al. (2019); Vashishth et al. (2020); Wang et al. (2019); Abu-El-Haija et al. (2018), where a low-dimensional vector representation of each node in the graph is used for downstream tasks (e.g., link prediction Zhang and Chen (2018); Cen et al. (2019); Abu-El-Haija et al. (2018) or multi-hop reasoning Hamilton et al. (2018); Das et al. (2017); Zhang et al. (2018)). Many of the existing methods focus on obtaining a static vector representation per node that is agnostic to the specific task. This representation is typically obtained by learning the importance of all of the node's immediate and multi-hop neighbors in the graph. In practice, most downstream tasks require specific contextual information that can be extracted from the subgraphs related to the nodes provided as input to the task. Inspired by the recent success of contextual learning in the natural language processing community Peters et al. (2018); Devlin et al. (2019), we develop SLiCE, a framework bridging static representation learning methods using global information from the entire graph with localized attention-driven mechanisms to learn contextual node representations in heterogeneous networks.

Figure 1: Contextual learning of node representations in heterogeneous networks. (a) Illustration of nodes participating in diverse contexts. (b) State-of-the-art methods aggregate global semantics. (c) Our approach: subgraph-driven contextual node embeddings.

Figure 1 provides an illustration of our idea. Given an author who publishes in diverse communities, we posit that a downstream task such as link prediction involving that author would perform better if the author's representation is reflective of the common context shared with the other endpoint. State-of-the-art methods learn a single embedding that reflects information aggregation from diverse contexts. SLiCE will contextualize global embeddings (shown in green) as a function of their local connected subgraph, and shift them closer in vector space (shown in purple), leading to improved downstream task performance. Most state-of-the-art methods are designed to answer the following question: "what is the best representation for a node?". Instead, our objective is to answer "what are the best collective node representations for a given subgraph?" and "how can such representations be potentially useful in a downstream application?"

More formally, we tackle the following problem: Given a heterogeneous graph $G$, a subgraph $\mathcal{G}_c$ and a task $T$, compute a function that maps each node in the set of vertices in $\mathcal{G}_c$, denoted as $\mathcal{V}_c$, to a real-valued embedding vector in a low-dimensional space $\mathbb{R}^d$ such that $d \ll |\mathcal{V}|$. We also require that $\phi$, a function serving as a proxy for a downstream task, satisfies the following: $\phi(\mathcal{G}_c) \approx 1$ when $\mathcal{G}_c$ is a subgraph of the input graph $G$ and $\phi(\mathcal{G}_c) \approx 0$ when it is not. Embeddings of nodes that participate in diverse contexts are known to be influenced by a global averaging effect Yang et al. (2018); Liu et al. (2019). However, for modeling the context of a subgraph $\mathcal{G}_c$, contextualization of embeddings requires us to reverse the averaging effect and update each node embedding such that (1) the resulting set is more consistent with the structure of $\mathcal{G}_c$ Vashishth et al. (2020); Rossi et al. (2020), and (2) it increases the discriminative ability of $\phi$ over the updated embeddings $\{g_c(v_i)\}_{i=1}^{N_c}$ Cen et al. (2019); Zhang et al. (2020), where $N_c$ is the number of nodes in $\mathcal{G}_c$.

Contextual Translation: Building on the concept of translation-based embedding models Bordes et al. (2013), given a node $v$ and its embedding $g(v)$ computed using a global representation method, we formulate graph-based learning of contextual embeddings as applying a vector-space translation (informally referred to as a shifting process) such that $g_c(v) = g(v) + \Delta(v, \mathcal{G}_c)$, where $g_c(v)$ is the contextualized representation of $v$. The key idea behind SLiCE is to learn the translation $\Delta(v, \mathcal{G}_c)$ for each node $v \in \mathcal{V}_c$.
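To make the translation view concrete, the following minimal NumPy sketch uses placeholder vectors (the dimension of 128 matches the setting used later in the paper) to show how a context-dependent offset shifts a global embedding; the values here are illustrative, not learned:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 128                                    # embedding dimension (same as the paper's setting)

g_v = rng.normal(size=d)                   # global embedding g(v) from a static method (e.g., node2vec)
delta_c_v = rng.normal(scale=0.1, size=d)  # context-dependent translation Delta(v, G_c) (placeholder values)

g_c_v = g_v + delta_c_v                    # contextualized embedding g_c(v) = g(v) + Delta(v, G_c)
print(np.linalg.norm(g_c_v - g_v))         # magnitude of the shift induced by the context subgraph
```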

We achieve this contextualization as follows: We first learn the higher-order semantic association (HSA) between nodes in a context subgraph. We do not assume any prior knowledge about important metapaths, which are paths connected via heterogeneous relations. We also do not enforce any constraint on the structure of the context subgraph, such as limiting it to star-shaped subgraphs or paths. Specifically, 1) we pursue a self-supervised learning approach that pre-trains a model to learn a HSA matrix on a context-by-context basis. 2) We then fine-tune the model in a task-specific manner, where given a context subgraph as input, we encode the subgraph with global features and then transform that initial representation via a HSA-based non-linear transformation to produce contextual embeddings (see Figure 2).

Our Contributions: 1) We propose contextual embedding learning for graphs, generalizing from single-relation contexts to arbitrary subgraphs. 2) We introduce a novel self-supervised learning approach to learn higher-order semantic associations between nodes by simultaneously capturing the global and local factors that characterize a context subgraph. 3) We show that SLiCE significantly outperforms existing static and contextual embedding learning methods by 11.95% and 26.9% on average (in F1-score), respectively, on the link prediction task. We also demonstrate SLiCE's ability to learn higher-order semantic associations by correctly generating top-$k$ metapaths.

2 Related Work

Node representation learning

The basic representation learning algorithms for networks can be broadly categorized into three groups based on their usage of matrix factorization, random walks, or graph neural network methods. Given a graph $G = (V, E)$, matrix factorization based methods such as GraRep Cao et al. (2015) and Ahmed et al. (2013); Ou et al. (2016) seek to learn a representation $Z$ that minimizes a loss function of the form $\|M - Z Z^{\top}\|_{F}^{2}$, where $M$ is a matrix containing pairwise proximity measures for $G$. Random walk based methods such as DeepWalk Perozzi et al. (2014) and node2vec Grover and Leskovec (2016) try to learn representations that roughly minimize a cross-entropy loss function of the form $-\sum_{(u,v)} \log p_{T}(v \mid u)$, where $p_{T}(v \mid u)$ is the probability of visiting a node $v$ on a random walk of length $T$ starting from node $u$. The node2vec approach has been further extended to incorporate multi-relational properties of networks by constraining random walks Dong et al. (2017); Huang and Mamoulis (2017); Chen et al. (2018). There are some recent efforts Qiu et al. (2018) to unify the first two categories by demonstrating the equivalence of Perozzi et al. (2014) and Grover and Leskovec (2016)-like methods to matrix factorization approaches. The third category represents graph neural networks and their variants Scarselli et al. (2008); Kipf and Welling (2016). Attention mechanisms, i.e., techniques that learn a distribution for aggregating information from node neighbors, are investigated in Veličković et al. (2017). Extending graph neural networks to heterogeneous networks Yun et al. (2019) and supporting attention over semantic neighbors, i.e., nodes that are connected via multi-hop metapaths Wang et al. (2019), also represent some of the key research directions in recent times.

Contextual Representations The authors of Liu et al. (2019) study the "polysemy" effect and motivate the need to account for the various facets that nodes in a heterogeneous graph participate in. However, their methodology still produces a fixed vector for each node in the graph. Similarly, the work in Abu-El-Haija et al. (2018) computes a node's representation by learning the attention distribution over a graph walk context where it occurs. The work presented in Cen et al. (2019) is a metapath-constrained random walk based method that contextualizes node representations per relation. It combines a base embedding derived from global structure (similar to the above methods) with a relation-specific component learnt from the metapaths. More recently, a GCN-based method Vashishth et al. (2020) was proposed to jointly learn the embeddings of nodes and relations in a heterogeneous graph. It uses multiple node-relation embedding composition functions to adapt a node's embedding based on the associated relational context. Significantly departing from these existing works, our node representation learning problem is formulated on the basis of a subgraph. Such a contextualization objective distinguishes SLiCE from Cen et al. (2019) and Vashishth et al. (2020). While subgraph-based representation learning objectives have been superficially investigated in the literature Zhang et al. (2020), they do not focus on generating contextual embeddings.

3 The Proposed Framework

3.1 Preliminaries

Unlike node-oriented computation graphs, our work aims at learning node representations under a given context. However, heterogeneous graphs are typically defined by a set of relations of the form (subject, predicate, object), and there is no standard way to quantify the context. Hence, one of the challenges in learning contextual representations for nodes in a graph is the lack of a clear definition of context for a specific node or node-pair in the graph. We define the context of a node based on a subgraph. Before presenting our overall framework, we first briefly provide the formal definitions and notations required to comprehend the proposed approach.

Definition: (Heterogeneous Graph). We represent a heterogeneous graph as a 6-tuple $G = (\mathcal{V}, \mathcal{E}, \mathcal{T}_{\mathcal{V}}, \mathcal{T}_{\mathcal{E}}, \tau_v, \tau_e)$ where $\mathcal{V}$ (alternately referred to as $V$) is the set of nodes and $\mathcal{E}$ (or $E$) denotes the set of relationships between nodes. $\tau_v: \mathcal{V} \rightarrow \mathcal{T}_{\mathcal{V}}$ and $\tau_e: \mathcal{E} \rightarrow \mathcal{T}_{\mathcal{E}}$ are functions mapping each node (or edge) to its node (or edge) type in $\mathcal{T}_{\mathcal{V}}$ and $\mathcal{T}_{\mathcal{E}}$, respectively.

Definition: (Context Subgraph). Given a heterogeneous graph $G$, the context of a node $v$ or node-pair $(u, v)$ in $G$ can be represented as the subgraph $\mathcal{G}_c$ that includes a set of nodes selected with certain criteria (e.g., top-$k$ most similar nodes or $h$-hop neighbors) along with their related edges. The context of the node or node-pair can be represented as $\mathcal{G}_c(v)$ and $\mathcal{G}_c(u, v)$, respectively.

Figure 2: SLiCE architecture consists of a global feature generator step, followed by a series of semantic attention + feed-forward network layers. The global feature generator learns embeddings based on the complete graph $G$, capturing global context for each node. Each layer in SLiCE shifts the embeddings of all nodes in $\mathcal{G}_c$ to emphasize local dependencies in the contextual subgraph. The final embeddings for nodes in context subgraphs are determined as a function of the outputs from the last layers to combine global with local contextual semantics for each node.

3.2 SLiCE Framework

Our proposed SLiCE framework consists of the following steps: Step 1) Contextual Subgraph Generation and Representation: generating a collection of context subgraphs, which are encoded as vector representations by considering various graph attributes about nodes, relations and graph structure. Step 2) Model Pre-training: learning higher-order relations with the self-supervised contextualized node prediction task. Step 3) Model Fine-tuning: the model is then tuned by the supervised link prediction task with more fine-grained contexts for node pairs. Figure 2 shows the framework of the proposed SLiCE model. We now provide more details of these steps.

Context Subgraph Generation and Representation: In this work, we use the following strategies for sampling a collection of context subgraphs from the input graph $G$: (1) the Shortest Path strategy considers the shortest path between two nodes as the context. (2) The Random strategy, on the other hand, generates contexts following one of the random walks between two nodes, limited to a pre-defined maximum number of hops. Note that the context generation strategies are generic and can be applied for generating contexts in many downstream tasks such as link prediction Zhang and Chen (2018), knowledge base completion Socher et al. (2013) or multi-hop reasoning Das et al. (2017); Hamilton et al. (2018).
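The two sampling strategies can be sketched as follows. This is a simplified illustration on a toy NetworkX graph; the function names are ours, and the actual sampler may differ (for example, it may require the random walk to terminate at the second node):

```python
import random
import networkx as nx

def shortest_path_context(G, u, v):
    """Shortest Path strategy: the context is the shortest path between u and v."""
    return nx.shortest_path(G, u, v)

def random_walk_context(G, u, v, max_hops=6):
    """Random strategy: one random walk starting at u, truncated at v or at max_hops."""
    walk = [u]
    node = u
    for _ in range(max_hops - 1):
        nbrs = list(G.neighbors(node))
        if not nbrs:
            break
        node = random.choice(nbrs)
        walk.append(node)
        if node == v:
            break
    return walk

# Toy graph (node and edge types omitted for brevity).
G = nx.Graph([(0, 1), (1, 2), (2, 3), (0, 4), (4, 3)])
print(shortest_path_context(G, 0, 3))   # e.g., [0, 4, 3]
print(random_walk_context(G, 0, 3))
```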

Each generated context subgraph $\mathcal{G}_c$ is encoded as a set of nodes denoted by $\{v_1, \ldots, v_{n_c}\}$, where $n_c$ represents the number of nodes in $\mathcal{G}_c$, and $X_c$ denotes the one-hot representations of the nodes in $\mathcal{G}_c$. Different from the sequential order enforced on graphs sampled using pre-defined metapaths, the order of nodes in this sequence is not important, and hence this allows us to handle context subgraphs with arbitrary structures. We first represent each node in the context subgraph as a low-dimensional vector representation by $H^{(0)} = X_c W_e$, where $W_e$ is the learnable embedding matrix, and we represent the input node embeddings in $\mathcal{G}_c$ as $H^{(0)}$. It is flexible to incorporate the node and relation attributes (if available) for attributed networks Cen et al. (2019) in the low-dimensional representations, or to initialize them with the output embeddings learnt from other global feature generation approaches that capture the multi-relational graph structure Grover and Leskovec (2016); Dong et al. (2017); Wang et al. (2019); Yun et al. (2019); Vashishth et al. (2020).
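A sketch of this encoding step, assuming PyTorch and a node2vec-style initialization of the learnable embedding matrix (variable names are illustrative, and the random tensor stands in for actual precomputed global features):

```python
import torch
import torch.nn as nn

num_nodes, d = 10_099, 128                     # e.g., the Amazon graph size and embedding dimension
embedding = nn.Embedding(num_nodes, d)         # learnable embedding matrix W_e

# Optionally initialize from precomputed global features (e.g., node2vec vectors).
pretrained = torch.randn(num_nodes, d)         # placeholder for actual global embeddings
embedding.weight.data.copy_(pretrained)

context_nodes = torch.tensor([17, 42, 256, 4081])   # node ids of one context subgraph (order-agnostic)
H0 = embedding(context_nodes)                        # (n_c, d) input representations H^(0) for the context
print(H0.shape)
```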

Self-supervised Contextualized Node Prediction: Our model pre-training is performed via the self-supervised contextualized node prediction task. More specifically, for each node in $G$, we generate the node context with diameter $D$ (defined as the longest shortest path between any pair of nodes) using the aforementioned context generation methods and randomly mask a node for prediction based on the context subgraph. The graph structure is left unperturbed by the masking procedure. Therefore, the pre-training objective maximizes the probability of observing this masked node $v_m$ given the context subgraph, in the following form:

$\max_{\theta} \; \sum_{\mathcal{G}_c \in \mathcal{D}} \log p\left(v_m \mid \mathcal{G}_c; \theta\right)$    (1)

where $\theta$ represents the set of model parameters and $\mathcal{D}$ denotes the collection of generated context subgraphs. In a departure from traditional skip-gram methods that predict a node from the path prefix that precedes it in a random walk, our random masking strategy forces the model to learn higher-order relationships between nodes that are arbitrarily connected by variable-length paths with diverse relational patterns.
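The masking and prediction step can be sketched as follows. This is a hedged illustration: the encoder stands in for the stacked contextual translation layers of Section 3.3, and the reserved mask id is an assumption of this sketch rather than a detail taken from the released implementation:

```python
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

num_nodes, d = 1_000, 128
embedding = nn.Embedding(num_nodes, d)
encoder = nn.Linear(d, d)                   # stand-in for the stacked contextual translation layers
node_classifier = nn.Linear(d, num_nodes)   # linear projection predicting the masked node id

context = [3, 57, 120, 411, 8]              # node ids in one context subgraph
mask_pos = random.randrange(len(context))
target = torch.tensor([context[mask_pos]])

masked = list(context)
masked[mask_pos] = 0                         # id 0 reserved as a [MASK] token (assumption of this sketch)

H = encoder(embedding(torch.tensor(masked)))         # (n_c, d) contextual embeddings
logits = node_classifier(H[mask_pos]).unsqueeze(0)   # predict the masked node from its position
loss = F.cross_entropy(logits, target)               # maximizes log p(v_m | G_c; theta), cf. Eq. (1)
loss.backward()
```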

Fine-tuning with Supervised Link Prediction: The SLiCE model is further fine-tuned on the contextualized link prediction task by generating multiple fine-grained contexts for each specific node-pair that is under consideration for link prediction. Based on the predicted scores, this stage is trained by maximizing the probability of observing a positive edge $e$ given its context $\mathcal{G}_c$, while also learning to assign low probability to negatively sampled edges $e'$ and their associated contexts $\mathcal{G}_{c'}$. The overall objective is obtained by summing over the training data subsets with positive edges ($\mathcal{D}^{+}$) and negative edges ($\mathcal{D}^{-}$):

$\max_{\theta} \; \sum_{(e, \mathcal{G}_c) \in \mathcal{D}^{+}} \log p\left(e \mid \mathcal{G}_c; \theta\right) \;+\; \sum_{(e', \mathcal{G}_{c'}) \in \mathcal{D}^{-}} \log\left(1 - p\left(e' \mid \mathcal{G}_{c'}; \theta\right)\right)$    (2)
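A sketch of the fine-tuning objective in Eq. (2) as a binary cross-entropy over positive and negative (edge, context) pairs; the edge probability used here, a sigmoid of the dot product of the endpoints' contextual embeddings, is one plausible choice consistent with the evaluation protocol in Section 4, not necessarily the exact head used in the released code:

```python
import torch

def edge_score(z_u, z_v):
    """Probability of an edge from the contextualized embeddings of its endpoints."""
    return torch.sigmoid((z_u * z_v).sum(-1))

# Placeholder contextual embeddings for positive pairs (D+) and negative pairs (D-).
z_pos_u, z_pos_v = torch.randn(32, 128), torch.randn(32, 128)
z_neg_u, z_neg_v = torch.randn(64, 128), torch.randn(64, 128)

p_pos = edge_score(z_pos_u, z_pos_v)
p_neg = edge_score(z_neg_u, z_neg_v)

# Eq. (2): maximize log p(e | G_c) over D+ and log(1 - p(e' | G_c')) over D-.
loss = -(torch.log(p_pos + 1e-9).mean() + torch.log(1.0 - p_neg + 1e-9).mean())
print(loss.item())
```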

3.3 SLiCE Model for Contextual Translation

Given the set of nodes $\mathcal{V}_c$ in a context subgraph $\mathcal{G}_c$ and their global input embeddings $H^{(0)}$, the primary goal of contextual learning is to translate (or shift) the global embeddings in the vector space towards new positions that indicate the most representative roles of the nodes in the structure of $\mathcal{G}_c$. Before introducing the details of this translation mechanism, we first provide the definition of the semantic association matrix, which serves as the primary indicator for the translation of embeddings according to specific contexts.

Definition: (Semantic Association Matrix). A semantic association matrix, denoted as $A$, is an asymmetric weighted matrix that indicates the high-order relational dependencies between nodes in the context subgraph $\mathcal{G}_c$.

Note that the semantic association matrix will be asymmetric since the influences of two nodes on one another in a context subgraph tend to be different. The adjacency matrix of the context subgraph, denoted by $A_{\mathcal{G}_c}$, can be considered a trivial candidate for $A$, since it includes the local relational information of the context subgraph $\mathcal{G}_c$. However, the goal of contextual embedding learning is to translate the global embeddings using the connectivity structure of the specific context while keeping the nodes' connectivity through the global graph. Hence, instead of setting it to $A_{\mathcal{G}_c}$, we contextually learn the semantic associations, or more specifically the weights of the matrix $A$, in each learning step by incorporating the connectivity between nodes through both the local context subgraph $\mathcal{G}_c$ and the global graph $G$.

Implementation of Contextual Translation: In learning step $l$, the semantic association matrix $A^{(l)}$ is updated by the transformation operation defined in Eq. (3). It is accomplished by performing message passing across all nodes in the context subgraph $\mathcal{G}_c$ and updating the node embeddings from $H^{(l)}$ to $H^{(l+1)}$:

$H^{(l+1)} = \sigma\left(A^{(l)} H^{(l)} W^{(l)}\right) + H^{(l)}$    (3)

where $\sigma$ is a non-linear function and the transformation matrix $W^{(l)}$ is the learnable parameter. The residual connection He et al. (2016) is applied to preserve the contextual embeddings from the previous step. This allows us to still maintain the global relations by passing the original global embeddings through the layers while learning the contextual embeddings. Given two nodes $v_i$ and $v_j$ in the context subgraph $\mathcal{G}_c$, the corresponding entry $A^{(l)}_{ij}$ in the semantic association matrix can be computed using the multi-head (with $K$ heads) attention mechanism Vaswani et al. (2017) in order to capture relational dependencies under different subspaces. For each head, we calculate $A^{(l)}_{ij}$ as follows:

$A^{(l)}_{ij} = \mathrm{softmax}_{j}\left(\left(W_{Q}^{(l)} h_{i}^{(l)}\right)^{\top} \left(W_{K}^{(l)} h_{j}^{(l)}\right)\right)$    (4)

where $h_{i}^{(l)}$ denotes the embedding of node $v_i$ at layer $l$, and the transformation matrices $W_{Q}^{(l)}$ and $W_{K}^{(l)}$ are learnable parameters. Note that, different from the aggregation procedure performed across all nodes in the general graph $G$, the proposed translation operation is only performed within the local context subgraph $\mathcal{G}_c$. The updated embeddings after applying the translation operation according to context $\mathcal{G}_c$ indicate the most representative roles of each node in the specific local context neighborhood. In order to capture the higher-order association relations within the context, we apply multiple layers of the transformation operation defined in Eq. (3) by stacking $L$ layers as shown in Figure 2, where $L$ is the largest diameter of the subgraphs sampled in the context generation process. We concatenate the embeddings learnt from the different layers and feed them into the classifier. In the pre-training step, a linear projection function is applied to predict the probability of masked nodes. For the fine-tuning step, we apply a single-layer feed-forward network with a softmax activation function for binary link prediction.
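A simplified PyTorch sketch of one contextual translation layer (Eqs. (3) and (4)) and the layer stacking and concatenation described above; details such as how heads are combined and how scores are normalized are assumptions of this sketch and may differ from the released implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContextualTranslationLayer(nn.Module):
    """One SLiCE-style layer: per-head semantic association + message passing + residual."""
    def __init__(self, d, n_heads=4):
        super().__init__()
        assert d % n_heads == 0
        self.n_heads, self.d_head = n_heads, d // n_heads
        self.W_q = nn.Linear(d, d, bias=False)    # learnable parameters for Eq. (4)
        self.W_k = nn.Linear(d, d, bias=False)
        self.W = nn.Linear(d, d, bias=False)      # transformation matrix W^(l) in Eq. (3)

    def forward(self, H):                         # H: (n_c, d) embeddings of the context nodes
        n_c, d = H.shape
        q = self.W_q(H).view(n_c, self.n_heads, self.d_head).transpose(0, 1)
        k = self.W_k(H).view(n_c, self.n_heads, self.d_head).transpose(0, 1)
        A = F.softmax(q @ k.transpose(1, 2) / self.d_head ** 0.5, dim=-1)   # (heads, n_c, n_c)
        A = A.mean(0)                              # combine heads into one association matrix
        return torch.relu(A @ self.W(H)) + H       # Eq. (3) with a residual connection

layers = nn.ModuleList(ContextualTranslationLayer(128) for _ in range(4))
H = torch.randn(6, 128)                            # a context subgraph with 6 nodes
outputs = []
for layer in layers:                               # stack L layers
    H = layer(H)
    outputs.append(H)
H_final = torch.cat(outputs, dim=-1)               # concatenate per-layer embeddings for the classifier
print(H_final.shape)                               # (6, 512)
```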

4 Experiments

We address the following questions through experimental analysis: Q1. Does subgraph-based contextual learning improve the performance of downstream tasks? Q2. How do we quantify the embedding shift from static to subgraph-based contextual embeddings during link prediction? Q3. How do we interpret the semantic associations learnt by SLiCE?

Dataset Amazon DBLP Freebase YouTube Twitter
# Nodes 10,099 37,791 14,541 2,000 9,990
# Edges 129,811 170,794 248,611 835,330 294,330
# Relations 2 3 237 5 4
# Training edges (positive) 126,535 119,554 272,115 1,114,025 282,115
# Development edges 14,756 51,242 35,070 131,024 32,926
# Testing edges 29,492 51,238 40,932 262,014 65,838
Table 1: The basic statistics of the benchmark datasets used in the paper.

Datasets: We use five publicly available benchmark datasets covering multiple applications: e-commerce (Amazon, https://github.com/THUDM/GATNE/tree/master/data), an academic graph (DBLP, https://github.com/Jhy1993/HAN/tree/master/data), knowledge graphs (Freebase, https://github.com/malllabiisc/CompGCN/tree/master/data_compressed) and social networks (YouTube and Twitter, from the same repository as Amazon). We use the same data split for training, development and testing as described in previous works Cen et al. (2019); Abu-El-Haija et al. (2018); Vashishth et al. (2020). Table 1 provides the basic statistics of each benchmark dataset. More details about the datasets are provided in the Appendix.

Baseline Methods: We compare SLiCE against state-of-the-art static and contextual embedding learning methods. (1) Static embedding: TransE Bordes et al. (2013) treats the relations between nodes as translation operations in a low-dimensional embedding space. RefE Chami et al. (2020) incorporates hyperbolic space and attention-based geometric transformations to learn the hierarchical and logical patterns of networks. node2vec Grover and Leskovec (2016) is a random-walk based method that was developed for homogeneous networks and remains a popular choice for the link prediction task. metapath2vec Dong et al. (2017) is an extension of node2vec that constrains random walks to specified metapaths in the heterogeneous network. (2) Contextual embedding: Watch Your Step (GAN) Abu-El-Haija et al. (2018) learns node embeddings by analyzing the attention distribution over the graph walk context. GATNE-T Cen et al. (2019) is a metapath-constrained random-walk based method that learns relation-specific embeddings by combining a base embedding that factors in global structure with a relation-specific component learnt from the metapaths. The recent GCN-based method CompGCN Vashishth et al. (2020) jointly learns the embeddings of nodes and relations for heterogeneous graphs and updates a node representation with multiple composition functions.

Evaluation setup: SLiCE is implemented using PyTorch 1.3 and all evaluations were performed using NVIDIA Tesla P100 GPUs. The implementation of SLiCE is made publicly available at https://github.com/wangpinggl/slicelink. The dimension of contextual node embeddings is set to 128. We used a skip-gram based random walk approach to encode context subgraphs with global node features. Both the pre-training and fine-tuning steps in SLiCE are trained for at most 10 epochs using the cross-entropy loss function. The model parameters are trained with the Adam optimizer Kingma and Ba (2014) with a learning rate of 0.0001 and 0.001 for the pre-training and fine-tuning steps, respectively. The best model parameters were selected based on the development set. The best performance reported here is obtained by setting both the number of contextual translation layers and the number of self-attention heads to 4. We generate context subgraphs by performing random walks between node pairs with the maximum context subgraph size set to 6. We performed ablation studies to determine the optimal values of these parameters. The implementation details of the baseline methods and the ablation studies are provided in the Appendix.
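For reference, the stated settings collected in one place; the key names below are illustrative and are not the argument names of the released code:

```python
# Hyperparameters reported in the evaluation setup (key names are illustrative).
slice_config = {
    "embedding_dim": 128,
    "n_translation_layers": 4,
    "n_attention_heads": 4,
    "max_context_size": 6,     # maximum context subgraph size for random-walk contexts
    "max_epochs": 10,          # both pre-training and fine-tuning
    "optimizer": "adam",
    "lr_pretrain": 1e-4,
    "lr_finetune": 1e-3,
}
```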

Methods Amazon DBLP Freebase YouTube Twitter
TransE (Bordes et al., 2013) 50.28 49.60 47.78 50.32 50.60
RefE (Chami et al., 2020) 51.86 49.60 50.25 50.20 48.55
node2vec (Grover and Leskovec, 2016) 88.06 86.71 83.69 65.13 72.72
metapath2vec (Dong et al., 2017) 88.86 44.58 77.18 62.41 66.73
Watch your step (GAN) (Abu-El-Haija et al., 2018) 85.47 - - 68.70 85.01
GATNE-T (Cen et al., 2019) 89.06 57.04 - 76.21 68.16
CompGCN (Vashishth et al., 2020) 83.42 40.10 65.39 58.40 40.75
SLiCE (Proposed Method) 96.00 90.70 90.26 76.39 89.30
Table 2: Performance comparison of different models on link prediction task using micro-F1 score. The symbol “-” indicates that it is computationally prohibitive to obtain the results.

4.1 Performance Evaluation on Link Prediction

We evaluate the impact of contextual embeddings (addressing Q1) using the binary link prediction task, which has been widely used to study the structure-preserving properties of node embeddings Zhang and Chen (2018); Chen et al. (2018). To predict the link between two given nodes $u$ and $v$, we compute the similarity by $\sigma(z_{u}^{\top} z_{v})$ Abu-El-Haija et al. (2018), where $z_u$ and $z_v$ are the embeddings of $u$ and $v$, respectively.
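As a small sketch, assuming the sigmoid-of-dot-product similarity given above:

```python
import torch

def link_score(z_u: torch.Tensor, z_v: torch.Tensor) -> torch.Tensor:
    """Similarity score sigma(z_u . z_v) used to decide whether an edge (u, v) exists."""
    return torch.sigmoid((z_u * z_v).sum(-1))

z_u, z_v = torch.randn(128), torch.randn(128)
print(link_score(z_u, z_v).item())   # thresholded (e.g., at 0.5) for the binary prediction
```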

Table 2 provides the link prediction results of the different methods on five datasets using the micro-F1 score. The prediction scores for SLiCE are reported from the context subgraph that produces the largest similarity score on the validation set, among multiple randomly generated contexts. Compared to the state-of-the-art methods, we observe that SLiCE significantly outperforms both static and contextual embedding learning methods, by 11.95% and 26.9% in F1-score, respectively. Static methods perform better than relation-based contextual learning methods. We attribute this to the ability of static learning methods to capture the connectivity patterns in the global network, whereas the relation-based contextual learning methods (GATNE-T and CompGCN) limit their contextualization by overly emphasizing the impact of relations on nodes. These results indicate that the generated contexts are able to provide sufficient contextual information for link prediction between node pairs and further lead to the translation of the global embeddings to the localized contextual embeddings.

(a) node2vec: Amazon
(b) SLiCE: Amazon
(c) node2vec: Twitter
(d) SLiCE: Twitter
Figure 3: Comparisons between distributions of similarity scores of both positive and negative node-pairs obtained by node2vec and SLiCE on Amazon and Twitter.

Effect of Contextual Translation on Link Prediction (addressing Q2): Figure 3 provides the distribution of similarity scores for both positive and negative edges obtained by SLiCE on the Amazon and Twitter datasets. We compare our embeddings against the embeddings produced by node2vec Grover and Leskovec (2016), which is one of the best performing static embedding methods in Table 2. We observe that for the static embeddings produced by node2vec, the distributions of similarity scores for positive and negative edges overlap significantly on both datasets. On the contrary, SLiCE increases the margin between the distributions of positive and negative edges significantly. Effectively, it brings the embeddings of nodes in positive edges closer and shifts the nodes in negative edges further away in the low-dimensional space. This indicates that the generated subgraphs provide informative contexts during link prediction and enhance the embeddings in a way that improves the ability to discriminate between positive and negative node-pairs.

4.2 SLiCE Model Interpretation

Interpretation of Semantic Association Matrix (addressing Q3): We provide a visualization of the semantic association matrix $A$ as used in Eq. (3) to investigate how the node dependencies evolve through the different layers in SLiCE. Given a node pair ($v_i$, $v_j$) in the context subgraph $\mathcal{G}_c$, a high value of $A^{(1)}_{ij}$ indicates a strong global dependency of node $v_i$ on $v_j$, while a high value of $A^{(l)}_{ij}$ for larger $l$ (the association after applying more translation layers) indicates a prominent high-order relation in the subgraph context.

Figure 4 shows the weights of the semantic association matrix for the context graph generated for the node pair (N0: Summarizing itemset patterns: a profile-based approach (Paper), N1: Jiawei Han (Author)). Nodes in the context subgraph consist of N2: Patterns (Topic), N3: CloSpan: Mining Closed Sequential Patterns in Large Databases (Paper), N4: SDM (Conference) and N5: SpaRClus: Spatial Relationship Pattern-Based Hierarchical Clustering (Paper). We observe that at layer 1 (Figure 4(a)), the association between source node N0 and target node N1 is relatively low. Instead, they both assign high weights to nodes N2 and N4. However, the dependencies between nodes are dynamically updated when applying more learning layers, consequently enabling us to identify higher-order relations. For example, the dependency of N1 on N0 becomes higher from layer 3 (Figure 4(c)) onwards, and N0 primarily depends on itself without being highly influenced by other nodes in layer 4 (Figure 4(d)). This visualization of the semantic association matrix provides an intuitive overview of how the global node embedding is translated into the localized embedding for contextual learning.

Symbolic Interpretation of Semantic Associations via Metapaths: Metapaths provide a symbolic interpretation of the higher-order relations in a heterogeneous graph (addressing Q3). We analyze the ability of SLiCE to learn relevant metapaths that characterize positive semantic associations in the graph by comparing with graph transformer networks (GTN) Yun et al. (2019). To our knowledge, GTN is the only reported method with such capability.

Learning Methods Paper-Author Paper-Conference Paper-Topic
Predefined Yun et al. (2019) APCPA, APA - -
GTN Yun et al. (2019) APCPA, APAPA, APA CPC -
SLiCE + Shortest Path TPA, APA, CPA TPC, APC, TPTPC TPT, CPT, APT
SLiCE + Random APA, APAPA TPTPC, TPAPC TPTPT, APTPT
Table 3: Comparison of metapaths learned by SLiCE with both predefined metapaths and those learned by GTN on the DBLP dataset for each relation type. Here, P, A, C and T represent Paper, Author, Conference and Topic, respectively.

We observe from Table 3 that SLiCE is able to match existing metapaths and also identify new metapath patterns for the prediction of each relation type. For example, to predict the paper-author relationship, SLiCE learns three shortest metapaths, including "TPA" (authors who publish on the same topic), "APA" (co-authors who publish together) and "CPA" (authors who published in the same conference). Longer metapaths such as "APAPA" (a chain of co-authorship) are also identified to be highly indicative for predicting paper-author relationships. Interestingly, our learning suggests that the longer metapath "APCPA", which is commonly used to sample academic graphs for co-author relationships, is not as highly predictive of a positive relationship. This indicates that "all authors who publish in the same conference do not necessarily publish together". This analysis demonstrates SLiCE's ability to discover higher-order semantic associations in heterogeneous networks.

(a) Layer 1
(b) Layer 2
(c) Layer 3
(d) Layer 4
Figure 4: Visualization of the semantic association matrix (after normalization) learnt from different layers on a DBLP subgraph for link prediction between paper N0 and author N1. An intense color indicates a higher association. Initially (layer 1), nodes N0 and N1 have low association, but more association to topic N2 and conference N4. In layer 4, SLiCE learns higher semantic association from N1 to N0.

5 Conclusions

We introduce SLiCE, a framework for learning contextual subgraph representations. Our model brings together structural information from the entire graph and then learns deep representations of higher-order relations in arbitrary context subgraphs. SLiCE learns the composition of different metapaths that characterize the context for a specific task in a drastically different manner compared to existing methods, which primarily aggregate information from either direct neighbors or semantic neighbors connected via certain pre-defined metapaths. SLiCE significantly outperforms several competitive baseline methods on various benchmark datasets for the link prediction task. We also interpret the semantic association matrix and demonstrate its utility and relevance in making successful predictions.

Broader Impact

Making the leap: towards contextual learning for graphs We were inspired by the impact of contextual learning methods, such as BERT Devlin et al. (2019) and ELMo Peters et al. (2018), which have kicked off a new generation of research in the NLP community. As the graph-based research community has found out through trial and error, naively applying NLP methods to graphs was not sufficient to power geometric deep learning. To address this, fundamental advancements have been made and continue to be made to support contextual learning for graph-based research. Our work is among the first to address contextual learning in graphs. We anticipate that there will be a surge in such efforts in the near future. However, the path to impact will be clear only when the research community arrives at a convergence on important questions.

Which applications will benefit? The first and foremost of such questions involves defining the context and where contextual learning will have maximum payoff. Context is readily available in natural language text, where each sentence provides a natural definition of context. Even though it is easy to motivate contextual learning for graphs, arriving at a precise formulation of contextual learning is a hard problem. In that spirit, we point to a number of applications and methods well-studied in the NeurIPS and related communities. We suggest how these applications can be mapped into the SLiCE framework. Though we demonstrate the applicability of the proposed model using publicly available datasets, one can easily envision its utility on other complex proprietary datasets, since heterogeneous networks are ubiquitous. The ability to identify future links between nodes in a complex network can have several interesting real-world applications such as recommender systems, search engines, retrieval tasks, and matching applications. We hope our proposed framework provides initial guidelines for advancing the state-of-the-art for these methods and target applications in both academia and industry.

Figure 5: AI-driven drug discovery is a critical societal need for responding to pandemics. Figure 5a shows a heterogeneous network comprising genes and drugs Cheng et al. (2019). Figure 5b illustrates recent work from the NeurIPS community Hamilton et al. (2018) aimed at accelerating AI-driven drug discovery. Certain drugs and diseases can be associated with many contexts. Using frameworks such as SLiCE will provide contextual node representations for subgraphs such as the one shown in Figure 5b. We predict that such adaptation will lead to enhanced performance for critical efforts such as drug discovery.

References

  • Abu-El-Haija et al. (2018) Sami Abu-El-Haija, Bryan Perozzi, Rami Al-Rfou, and Alexander A Alemi. 2018. Watch your step: Learning node embeddings via graph attention. In Advances in Neural Information Processing Systems. 9180–9190.
  • Ahmed et al. (2013) Amr Ahmed, Nino Shervashidze, Shravan Narayanamurthy, Vanja Josifovski, and Alexander J Smola. 2013. Distributed large-scale natural graph factorization. In Proceedings of the 22nd international conference on World Wide Web. ACM, 37–48.
  • Bordes et al. (2013) Antoine Bordes, Nicolas Usunier, Alberto Garcia-Duran, Jason Weston, and Oksana Yakhnenko. 2013. Translating embeddings for modeling multi-relational data. In Advances in neural information processing systems. 2787–2795.
  • Cao et al. (2015) Shaosheng Cao, Wei Lu, and Qiongkai Xu. 2015. Grarep: Learning graph representations with global structural information. In Proceedings of the 24th ACM International on Conference on Information and Knowledge Management. ACM, 891–900.
  • Cen et al. (2019) Yukuo Cen, Xu Zou, Jianwei Zhang, Hongxia Yang, Jingren Zhou, and Jie Tang. 2019. Representation Learning for Attributed Multiplex Heterogeneous Network. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM.
  • Chami et al. (2020) Ines Chami, Adva Wolf, Da-Cheng Juan, Frederic Sala, Sujith Ravi, and Christopher Ré. 2020. Low-Dimensional Hyperbolic Knowledge Graph Embeddings. In Annual Meeting of the Association for Computational Linguistics.
  • Chen et al. (2018) Hongxu Chen, Hongzhi Yin, Weiqing Wang, Hao Wang, Quoc Viet Hung Nguyen, and Xue Li. 2018. PME: projected metric embedding on heterogeneous networks for link prediction. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 1177–1186.
  • Cheng et al. (2019) Feixiong Cheng, István A Kovács, and Albert-László Barabási. 2019. Network-based prediction of drug combinations. Nature communications 10, 1 (2019), 1–11.
  • Das et al. (2017) Rajarshi Das, Arvind Neelakantan, David Belanger, and Andrew McCallum. 2017. Chains of Reasoning over Entities, Relations, and Text using Recurrent Neural Networks. In Proceedings of the 15th Conference of the European Chapter of the Association for Computational Linguistics: Volume 1, Long Papers. 132–141.
  • Devlin et al. (2019) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers). 4171–4186.
  • Dong et al. (2017) Yuxiao Dong, Nitesh V Chawla, and Ananthram Swami. 2017. metapath2vec: Scalable representation learning for heterogeneous networks. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. ACM, 135–144.
  • Grover and Leskovec (2016) Aditya Grover and Jure Leskovec. 2016. node2vec: Scalable feature learning for networks. In Proceedings of the 22nd ACM SIGKDD international conference on Knowledge discovery and data mining. ACM, 855–864.
  • Hamilton et al. (2018) Will Hamilton, Payal Bajaj, Marinka Zitnik, Dan Jurafsky, and Jure Leskovec. 2018. Embedding logical queries on knowledge graphs. In Advances in Neural Information Processing Systems. 2026–2037.
  • He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 2016. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition. 770–778.
  • Huang and Mamoulis (2017) Zhipeng Huang and Nikos Mamoulis. 2017. Heterogeneous information network embedding for meta path based proximity. arXiv preprint arXiv:1701.05291 (2017).
  • Kingma and Ba (2014) Diederik P Kingma and Jimmy Ba. 2014. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014).
  • Kipf and Welling (2016) Thomas N Kipf and Max Welling. 2016. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907 (2016).
  • Liu et al. (2019) Ninghao Liu, Qiaoyu Tan, Yuening Li, Hongxia Yang, Jingren Zhou, and Xia Hu. 2019. Is a single vector enough? exploring node polysemy for network embedding. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 932–940.
  • Ou et al. (2016) Mingdong Ou, Peng Cui, Jian Pei, Ziwei Zhang, and Wenwu Zhu. 2016. Asymmetric transitivity preserving graph embedding. In Proceedings of the 22nd ACM SIGKDD international conference on Knowledge discovery and data mining. 1105–1114.
  • Perozzi et al. (2014) Bryan Perozzi, Rami Al-Rfou, and Steven Skiena. 2014. Deepwalk: Online learning of social representations. In Proceedings of the 20th ACM SIGKDD international conference on Knowledge discovery and data mining. 701–710.
  • Peters et al. (2018) Matthew E Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton Lee, and Luke Zettlemoyer. 2018. Deep contextualized word representations. In Proceedings of NAACL-HLT. 2227–2237.
  • Qiu et al. (2018) Jiezhong Qiu, Yuxiao Dong, Hao Ma, Jian Li, Kuansan Wang, and Jie Tang. 2018. Network embedding as matrix factorization: Unifying deepwalk, line, pte, and node2vec. In Proceedings of the Eleventh ACM International Conference on Web Search and Data Mining.
  • Rossi et al. (2020) Ryan A Rossi, Nesreen K Ahmed, Eunyee Koh, Sungchul Kim, Anup Rao, and Yasin Abbasi-Yadkori. 2020. A structural graph representation learning framework. In Proceedings of the 13th International Conference on Web Search and Data Mining. 483–491.
  • Scarselli et al. (2008) Franco Scarselli, Marco Gori, Ah Chung Tsoi, Markus Hagenbuchner, and Gabriele Monfardini. 2008. The graph neural network model. IEEE Transactions on Neural Networks 20, 1 (2008), 61–80.
  • Socher et al. (2013) Richard Socher, Danqi Chen, Christopher D Manning, and Andrew Ng. 2013. Reasoning with neural tensor networks for knowledge base completion. In Advances in neural information processing systems. 926–934.
  • Vashishth et al. (2020) Shikhar Vashishth, Soumya Sanyal, Vikram Nitin, and Partha Talukdar. 2020. Composition-based multi-relational graph convolutional networks. In International Conference on Learning Representations.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in neural information processing systems. 5998–6008.
  • Veličković et al. (2017) Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua Bengio. 2017. Graph attention networks. International Conference on Learning Representations.
  • Wang et al. (2019) Xiao Wang, Houye Ji, Chuan Shi, Bai Wang, Yanfang Ye, Peng Cui, and Philip S Yu. 2019. Heterogeneous Graph Attention Network. In The World Wide Web Conference. ACM, 2022–2032.
  • Yang et al. (2018) Liang Yang, Yuanfang Guo, and Xiaochun Cao. 2018. Multi-facet network embedding: Beyond the general solution of detection and representation. In Thirty-Second AAAI Conference on Artificial Intelligence.
  • Yun et al. (2019) Seongjun Yun, Minbyul Jeong, Raehyun Kim, Jaewoo Kang, and Hyunwoo J Kim. 2019. Graph Transformer Networks. In Advances in Neural Information Processing Systems. 11960–11970.
  • Zhang and Chen (2018) Muhan Zhang and Yixin Chen. 2018. Link prediction based on graph neural networks. In Advances in Neural Information Processing Systems. 5165–5175.
  • Zhang et al. (2020) Ruochi Zhang, Yuesong Zou, and Jian Ma. 2020. Hyper-SAGNN: a self-attention based graph neural network for hypergraphs. In International Conference on Learning Representations.
  • Zhang et al. (2018) Yuyu Zhang, Hanjun Dai, Zornitsa Kozareva, Alexander J Smola, and Le Song. 2018. Variational reasoning for question answering with knowledge graph. In Thirty-Second AAAI Conference on Artificial Intelligence.

Appendix A Complexity Analysis of SLiCE

A.1 Complexity Analysis

We assume that $c$ denotes the number of context subgraphs generated for each node, $s$ represents the maximum number of nodes in any context subgraph, and $n$ represents the number of nodes in the input graph $G$. Then, the total number of context subgraphs considered in the pre-training stage can be calculated as $n \cdot c$, and the cost of iterating over all these subgraphs through multiple epochs will be $O(n \cdot c)$. Since the generated context subgraphs should provide a good approximation of the total number of edges in the entire graph, we approximate the total cost as $O(m_{train})$, where $m_{train}$ is the number of edges in the training dataset. It can also be represented as $O(r \cdot m)$, where $m$ is the total number of edges in the input graph and $r$ represents the ratio of the training split. The cost of each contextual translation layer in the SLiCE model is $O(s^{2})$, since the dot product for calculating node similarity is the dominant computation and is quadratic in the number of nodes in the context subgraph. In this case, the total training complexity will be $O(r \cdot m \cdot s^{2})$. The maximum number of nodes $s$ in the context subgraphs is relatively small and can be considered a constant that does not depend on the size of the input graph. Therefore, the training complexity of SLiCE is approximately linear in the number of edges in the input graph.

A.2 Sampling Analysis

In the complexity analysis, we approximate the total number of training edges in the entire graph as $n \cdot c$. This also provides guidelines for determining the number of context subgraphs $c$ for each node. By incorporating the training ratio $r$ into the approximation, we can estimate the number of context subgraphs per node as $c \approx r \cdot m / n$. Table 4 shows the estimated numbers (with $r = 0.6$) for the five datasets used in this work. These estimates provide an approximate range for the value of $c$ during the context generation step. In the parameter sensitivity analysis, we generally consider 1, 5 and 10 for the value of $c$ on all five datasets to keep the total run time (pre-training and fine-tuning) of a dataset to a maximum of 2 days. This also explains that the lower performance of SLiCE on YouTube compared to the other datasets may be caused by the smaller value of $c$ we considered, since the above quantitative analysis would require around 250 contexts per node for the YouTube dataset. However, a large number of context subgraphs for each node would substantially increase the training time and make it prohibitive to run many experiments.

Dataset Amazon DBLP Freebase YouTube Twitter
# Nodes () 10,099 37,791 14,541 2,000 9,990
# Edges () 129,811 170,794 248,611 835,330 294,330
# Contexts () 7.74 2.71 10.26 250.60 17.67
Table 4: Estimation of the number of context subgraphs for each node in the knowledge graph.
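The estimate $c \approx r \cdot m / n$ with $r = 0.6$ can be checked directly against Table 4 using the node and edge counts above:

```python
# Estimate contexts per node as c ~= r * m / n (training ratio r = 0.6), reproducing Table 4.
datasets = {
    "Amazon":   (10_099, 129_811),
    "DBLP":     (37_791, 170_794),
    "Freebase": (14_541, 248_611),
    "YouTube":  (2_000,  835_330),
    "Twitter":  (9_990,  294_330),
}
r = 0.6
for name, (n, m) in datasets.items():
    print(f"{name:9s} c ~= {r * m / n:7.2f}")
# Approximately matches Table 4: Amazon ~7.7, DBLP ~2.7, Freebase ~10.3, YouTube ~250.6, Twitter ~17.7
```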

Appendix B Feature Generation

In general, there are two main types of methods for generating node features in knowledge graphs: encoder based and random walk based approaches. The encoder based approaches leverage the properties of the adjacency matrix of the knowledge graph and perform message passing to aggregate node information. Random walk based approaches feed the walk paths into a skip-gram model for feature generation by learning the contexts of nodes. We include both types of approaches as baselines, as shown in Table 2 and Table 5. Compared with the other baseline methods, we observe that the node embeddings obtained from node2vec (random walk based) produce competitive performance on link prediction tasks. Therefore, in the proposed SLiCE model, we mainly consider the pre-trained node representation vectors from node2vec as the node features. More specifically, we first collect the set of subgraphs with the shortest path or random strategy based on the context subgraph generation methods described in Section 3. The node paths extracted from the generated context subgraphs are then fed into the skip-gram model for feature generation, which produces similar embeddings for nodes with similar context neighbors.
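A hedged sketch of this feature generation step using gensim's skip-gram Word2Vec on toy context paths; the parameter values here are illustrative, and the actual node2vec settings are listed in Section C.2:

```python
from gensim.models import Word2Vec   # gensim >= 4.0

# Node paths extracted from generated context subgraphs (toy example; node ids as strings).
paths = [
    ["0", "4", "3"],
    ["0", "1", "2", "3"],
    ["1", "2", "3", "4"],
]

# Skip-gram (sg=1) feature generation: nodes with similar context neighbors get similar embeddings.
model = Word2Vec(paths, vector_size=128, window=10, min_count=0, sg=1, epochs=5, workers=2)
global_features = {node: model.wv[node] for node in model.wv.index_to_key}
print(len(global_features), global_features["3"].shape)   # (number of nodes, 128-d vectors)
```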

Appendix C Experiments

In this section, we first introduce more details about the knowledge graph datasets used in this work and the experimental setup for the baseline methods. In addition, more experimental results are also provided, including link prediction performance evaluation with AUROC and a parameter sensitivity analysis.

C.1 Dataset Description

We provide the details about the node and relation types in each knowledge graph as follows.

Amazon (https://github.com/THUDM/GATNE/tree/master/data): This knowledge graph includes the co-viewing and co-purchasing links between products. The two edge types, also_bought and also_viewed, represent that two products are co-bought or co-viewed by the same user, respectively.

DBLP (https://github.com/Jhy1993/HAN/tree/master/data): This knowledge graph includes the relationships between papers, authors, venues and terms. The edge types include paper_has_term, published_at and has_author. Original dataset IDs were mapped to an integer range from 0 to 37790 and maps were saved to preserve interpretability. The same number of negative edges as positive edges were generated for the link prediction task.

Freebase (https://github.com/malllabiisc/CompGCN/tree/master/data_compressed): This knowledge graph is a pruned version of FB15K with inverse relations removed. It includes the links between people and their nationality, gender, profession, institution, place of birth, and place of death, along with other demographic features. Original dataset IDs were mapped to an integer range from 0 to 14540 and maps were saved to preserve interpretability. The same number of negative edges as positive edges were generated for the link prediction task.

YouTube (https://github.com/THUDM/GATNE/tree/master/data): This knowledge graph includes various links between YouTube users, including contact, shared friends, shared subscription, shared subscriber and shared favorite videos.

Twitter (https://github.com/THUDM/GATNE/tree/master/data): This knowledge graph between Twitter users is generated based on tweets related to the discovery of the Higgs boson between 1st and 7th July 2012. The edge types included in the network are re-tweet, reply, mention and friendship/follower.

In addition, for each positive edge in the training set, we generated double the number of negative edges during training phase of link prediction task.

C.2 Experimental Setup

In this section, we present the implementation details of the baseline methods. Parameters not specified here use the default settings. The best model parameters are selected using the development data.

TransE and RefE (https://github.com/HazyResearch/KGEmb): The dimension of the node embeddings is set to 128. The Adagrad optimizer is used to train the model parameters with a learning rate of 0.01. The number of negative samples is 50.

node2vec (https://github.com/aditya-grover/node2vec): We set the dimension of the node embedding to 128 and sampled 10 random walks of length 80 starting from each node. The parameters $p$ and $q$ used for the neighborhood sampling are both set to 1. The size of the sliding window is set to 10.

metapath2vec (https://ericdongyx.github.io/metapath2vec/m2v.html): The performance of metapath2vec is evaluated based on the context subgraphs generated for the proposed SLiCE model. We generate 12 walks for each node. Both the number of negative samples and the size of the sliding window are set to 5. We set the dimension of the node embedding to the default of 100. The experiments are run with 32 threads.

Watch your step (GAN) (https://github.com/google-research/google-research/tree/master/graph_embedding/watch_your_step): We experimented with learning rates of [0.1, 0.01, 0.001] and report the best performance, with the maximum number of steps set to 10 in the GAN network. A transition power of 5 was used and the embedding dimension was set to 128.

GATNE-T (https://github.com/THUDM/GATNE): We set the dimension of the node embedding to 200 and generated 20 walks of length 10 for each node by considering 10 neighbors. Both the number of negative samples and the size of the sliding window are set to 5.

CompGCN (https://github.com/malllabiisc/CompGCN): We set the dimension of the node embedding to 200. We apply 1 GCN layer and use the multiplication operation for the composition of node and relation embeddings.

C.3 Experimental Results

C.3.1 Performance Evaluation on Link Prediction

Besides the micro-F1 scores provided in Table 2 for comparing the link prediction performance of the different models, we provide the performance in terms of AUROC in Table 5. Compared to both static and contextual embedding learning methods, SLiCE outperforms all other methods in AUROC on four of the five datasets, except the YouTube network, which requires a higher sampling rate (contexts per node) owing to its very dense nature.

Figure 6 shows the distribution of similarity scores for both positive and negative node-pairs obtained by SLiCE on DBLP, Freebase and YouTube. Compared with the distribution patterns shown in Figure 3 for Amazon and Twitter, we observe similar patterns on these three datasets. The embeddings learnt from node2vec produce high similarity scores for both positive and negative node-pairs, while the contextual embeddings provided by SLiCE are able to effectively differentiate the positive and negative node-pairs by producing smaller similarity scores for negative node-pairs and higher similarity scores for positive node-pairs. In this case, the distribution of negative node-pairs is pushed to the left end of the plots. These results indicate that the generated subgraphs are able to provide informative context information that improves the performance of the link prediction task.

Methods Amazon DBLP Freebase YouTube Twitter
TransE (Bordes et al., 2013) 50.53 49.05 48.18 50.03 50.26
RefE (Chami et al., 2020) 51.74 48.50 50.41 50.13 49.28
node2vec (Grover and Leskovec, 2016) 94.48 93.87 89.77 71.98 80.48
metapath2vec (Dong et al., 2017) 95.42 38.41 84.33 67.64 72.16
Watch your step (GAN) (Abu-El-Haija et al., 2018) 92.86 - - 75.24 92.39
GATNE-T (Cen et al., 2019) 94.74 58.44 - 83.50 72.07
CompGCN (Vashishth et al., 2020) 90.14 34.04 72.01 61.33 39.86
SLiCE (Proposed Method) 99.02 96.69 96.41 79.28 95.73
Table 5: Performance comparison of different models on link prediction task using AUROC. The symbol “-” indicates that it is computationally prohibitive to obtain the results.
(a) node2vec: DBLP
(b) node2vec: Freebase
(c) node2vec: YouTube
(d) SLiCE: DBLP
(e) SLiCE: Freebase
(f) SLiCE: YouTube
Figure 6: Comparisons between distributions of similarity scores of both positive and negative node-pairs obtained by node2vec and SLiCE on DBLP, Freebase and YouTube.

C.3.2 Parameter Sensitivity

In Figure 7, we provide the link prediction performance in micro-F1 score on the five datasets while varying four parameters used in the SLiCE model: the number of heads, the number of layers, the walk length, and the number of (context) walks per node in pre-training. The performance shown in these plots is the averaged performance obtained by fixing one parameter and varying the other three parameters. In Figure 8, we also show the range of the micro-F1 scores using boxplots when fixing each of the four parameters. We observe that the averaged performance of SLiCE is stable when we vary the number of heads, the walk length and the number of walks per node. This indicates that involving more contexts or more nodes in the context subgraphs does not affect the model performance. However, for the number of layers, we observe that applying 4 layers of contextual translation provides the best performance on all the datasets and that the performance drops significantly when applying 16 layers. Based on these analyses, we set the default values for both the number of heads and the number of layers to 4, and generate 1 walk for each node with length 6 in the pre-training step.

(a)
(b)
(c)
(d)
Figure 7: Micro-F1 scores for link prediction with different parameters in SLiCE on five datasets.
(a)
(b)
(c)
(d)
Figure 8: Sensitivity analysis of SLiCE as a function of number of heads, layers, number of nodes per context and number of context subgraphs per node.