MR-GNN: Multi-Resolution and Dual Graph Neural Network for Predicting Structured Entity Interactions

05/23/2019 · Nuo Xu et al. · Xi'an Jiaotong University

Predicting interactions between structured entities lies at the core of numerous tasks such as drug regimen design and new material design. In recent years, graph neural networks have become attractive for this purpose. They represent structured entities as graphs and then extract features from each individual graph using graph convolution operations. However, these methods have some limitations: i) their networks only extract features from a fixed-size subgraph structure (i.e., a fixed-size receptive field) around each node, and ignore features in substructures of different sizes, and ii) features are extracted by considering each entity independently, which may not effectively reflect the interaction between two entities. To resolve these problems, we present MR-GNN, an end-to-end graph neural network with the following features: i) it uses a multi-resolution based architecture to extract node features from different-sized neighborhoods of each node, and ii) it uses dual graph-state long short-term memory networks (LSTMs) to summarize the local features of each graph and extract the interaction features between pairwise graphs. Experiments conducted on real-world datasets show that MR-GNN improves prediction over state-of-the-art methods.


1 Introduction

A large variety of applications require understanding the interactions between structured entities. For example, when one medicine is taken together with another, each medicine's intended efficacy may be altered substantially (see Fig. 1). Understanding their interactions is important to minimize side effects and maximize synergistic benefits [Ryu et al.2018]. In chemistry, understanding what chemical reactions will occur between two chemicals is helpful in designing new materials with desired properties [Kwon and Yoon2017]. Despite its importance, examining all interactions by performing clinical or laboratory experiments is impractical due to the potential harm to patients and the high time and monetary costs.

Recently, machine learning methods have been proposed to address this problem, and they have been demonstrated to be effective in many tasks [Duvenaud et al.2015, Li et al.2017, Tian et al.2016, Ryu et al.2018]. These methods use features extracted from entities to train a classifier that predicts entity interactions. However, the features have to be carefully provided by domain experts [Ryu et al.2018, Tian et al.2016], which is labor-intensive. To automate feature extraction, graph convolutional neural networks (GCNs) have been proposed [Alex et al.2017, Kwon and Yoon2017, Zitnik et al.2018]. GCNs represent structured entities as graphs and use graph convolution operators to extract features. One of the state-of-the-art GCN models, proposed by Alex et al. [2017], extracts features from the three-hop neighborhood of each node. We thus say that their model uses a fixed-size receptive field (RF). However, using a fixed-size RF to extract features has limitations, which can be illustrated by the following example.

Figure 1: Overview of the graph-based framework. We transform two drugs, Allopurinol and Amoxicillin, into graphs, where nodes represent atoms and edges represent chemical bonds between atoms, and predict the interaction between them. If there exists an adverse reaction between them, they cannot be taken together.
Example 1.

Figure 2 shows two weak acids, Hydroquinone and Acetic acid. They are weak acids due to the existence of the substructures phenolic hydroxyl (ArOH) and carboxyl (COOH), respectively. Representing these two chemical compounds as graphs, we need a three-hop neighborhood to accurately extract ArOH from Hydroquinone, but only a two-hop neighborhood to accurately extract COOH from Acetic acid. Using a fixed-size neighborhood therefore results in either incomplete substructures being extracted (when the RF is too small) or useless substructures being included (when the RF is too large).
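To make the notion of a receptive field concrete, the following minimal Python sketch (our own illustration; the toy adjacency list and function name are not from the paper) computes the k-hop neighborhood of a node with a breadth-first search. A stack of k graph convolution layers "sees" exactly this node set.

from collections import deque

def k_hop_neighborhood(adj, start, k):
    """Return all nodes within k hops of `start` in an adjacency-list graph."""
    dist = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if dist[node] == k:
            continue
        for nbr in adj[node]:
            if nbr not in dist:
                dist[nbr] = dist[node] + 1
                queue.append(nbr)
    return set(dist)

# Toy graph: heavy atoms of acetic acid CH3-C(=O)-OH
# (0: methyl C, 1: carboxyl C, 2: carbonyl O, 3: hydroxyl O).
adj = {0: [1], 1: [0, 2, 3], 2: [1], 3: [1]}
print(k_hop_neighborhood(adj, 1, 1))  # one hop around the carboxyl carbon
print(k_hop_neighborhood(adj, 1, 2))  # two hops: the whole toy molecule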

Figure 2: The structure of two weak acids: Hydroquinone and Acetic acid. The blue box shows the acidic substructures: ArOH and COOH. The red dashed circle shows the receptive field of the corresponding red node in different convolution layers.

Another limitation of existing GCNs is that they learn each graph's representation independently and model the interaction only in the final prediction step. However, for different entities, the interaction also involves substructures of different sizes. Take Fig. 2 as an example again: when these two weak acids are neutralized with the same strong base, the interaction can be accurately modeled by features of the second convolution layer for Acetic acid, because the key substructure COOH can be accurately extracted there; for Hydroquinone, however, the best choice is to model the interaction by features of the third convolution layer. Thus, modeling the interaction only in the final step may introduce considerable noise into the prediction.

To address these limitations, this work presents a novel GCN model named Multi-Resolution RF based Graph Neural Network (MR-GNN), which leverages different-sized local features and models interactions during feature extraction to predict structured entity interactions.

1.0.1 Overview of our approach.

MR-GNN uses a multi-resolution RF, which consists of multiple graph convolution layers with different RFs, to extract local structure features effectively (see Fig. 2). When aggregating these multi-resolution local features, MR-GNN uses two dual graph-state LSTMs. One is the Summary-LSTM (S-LSTM), which aggregates the multi-resolution local features of each graph. Compared with the straightforward method that simply sums all multi-resolution features, S-LSTM learns additional effective features by modeling the diffusion process of node information in graphs, which greatly enriches the graph representation. The other is the Interaction-LSTM (I-LSTM), which extracts interaction features between pairwise graphs during feature extraction.

Our contributions are as follows:

  • In MR-GNN, we design a multi-resolution based architecture that mines features from multi-scale substructures to predict graph interactions. It is more effective than considering only fixed-size RFs.

  • We develop two dual graph-state LSTMs: one summarizes subgraph features of different-sized RFs while modeling the diffusion process of node information, and the other extracts interaction features of pairwise graphs during feature extraction.

  • Experimental results on two benchmark datasets show that MR-GNN outperforms the state-of-the-art methods.

2 Problem Definition

Notations. We denote a structured entity by a graph $G = (V, E)$, where $V$ is the node set and $E$ is the edge set. Each node $v \in V$ is associated with a $d_0$-dimensional feature vector $\mathbf{x}_v$. The feature vectors can be low-dimensional latent representations (embeddings) of the nodes or explicit features that directly reflect node attributes. We let $N(v)$ denote $v$'s neighbors and $d_v$ denote $v$'s degree.

Entity Interaction Prediction. Let $\mathcal{Y}$ denote the set of interaction labels between two entities. The entity interaction prediction task is formulated as a supervised learning problem: given a training dataset $\mathcal{D} = \{((G_x^i, G_y^i), y_i)\}_{i=1}^{n}$, where $(G_x^i, G_y^i)$ is an input entity pair, $y_i \in \mathcal{Y}$ is the corresponding interaction label, and $n$ is the size of $\mathcal{D}$, we want to accurately predict the interaction label of an unseen entity pair $(G_x, G_y)$.
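As a concrete illustration of this setup (our own sketch; the class and field names are hypothetical, not prescribed by the paper), a training sample can be represented as follows:

from dataclasses import dataclass
from typing import List, Tuple
import numpy as np

@dataclass
class EntityGraph:
    node_features: np.ndarray     # shape (|V|, d0): one feature vector x_v per node
    adjacency: List[List[int]]    # adjacency[v] lists the neighbors N(v)

    def degree(self, v: int) -> int:
        return len(self.adjacency[v])

# A labeled training example: a pair of entity graphs and an interaction label from Y.
Sample = Tuple[EntityGraph, EntityGraph, int]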

3 Method

In this section, we propose a graph neural network, i.e., MR-GNN, to address the entity interaction prediction problem.

3.1 Overview

Figure 3 depicts the architecture of MR-GNN, which mainly consists of three parts: 1) multiple weighted graph convolution layers, which extract structure features from receptive fields of different sizes, 2) dual graph-state LSTMs, which summarize multi-resolution structure features and extract interaction features, and 3) fully connected layers, which predict the entity interaction labels.

Figure 3: A three-layer framework of MR-GNN. For each input graph, it uses several graph convolution layers (GCLs) to learn multi-resolution structure features. Then, for each GCL, a graph-gather layer sums the node vectors of the same resolution into a graph-state. We feed the graph-states of the different GCLs, which have different receptive fields, into the S-LSTM and I-LSTM to learn the final representation comprehensively. Finally, the final S-LSTM hidden vectors of the two graphs, the final I-LSTM hidden vector, and the graph pooling (GP) vectors of the two entire graphs are concatenated and passed to the fully connected layers for learning a predictive model.

3.2 Weighted graph convolution layers

Before introducing the motivation and design of our weighted graph convolution operator in detail, we first describe the standard graph convolution operator.

Standard Graph Convolution Operator. Inspired by the convolution operator on images, the general spatial graph convolution [Duvenaud et al.2015] aggregates the features of a node together with those of its one-hop neighbors to obtain the node's new features. Taking node $v$ as an example, the formula is:

$\mathbf{h}_v^{(l+1)} = \sigma\Big(\mathbf{W}^{(l)}\mathbf{h}_v^{(l)} + \sum_{u \in N(v)} \mathbf{W}^{(l)}\mathbf{h}_u^{(l)}\Big)$   (1)

where $\mathbf{h}_v^{(l)}$ denotes the feature vector of $v$ in the $l$-th graph convolution layer, $\mathbf{W}^{(l)}$ is the weight matrix of the $l$-th layer, and $\sigma(\cdot)$ is the tanh activation function. Note that $\mathbf{h}_v^{(0)} = \mathbf{x}_v$.

Because the output graph of each graph convolution layer is exactly the same as the input graph, MR-GNN can conveniently learn structural characteristics at different resolutions through successive graph convolution layers. Take node A in Fig. 3 as an example: after three graph convolution layers, its receptive field in the third layer is the three-hop neighborhood centered on it.

However, since graphs are not regular grids, it is difficult for the existing graph convolution operator to distinguish weights by spatial position as the convolution operator on grid-like data does; e.g., in image processing, the right neighbor and the left neighbor of a pixel can be treated with different weights in each convolution kernel. Inspired by the fact that node degree reflects the importance of nodes in a network for many applications, we modify the graph convolution operator by adding weights according to the node degree $d_v$. (Other metrics such as betweenness centrality can also work well; in this paper we choose node degree because it is simple to compute.) Furthermore, Sukhbaatar et al. [2016] treat different agents with different weights in order to distinguish the features of the original node from those of its neighbors. We likewise treat each node and its neighbors with different weight matrices, $\mathbf{W}_c^{(l)}$ and $\mathbf{W}_n^{(l)}$. Our improved weighted graph convolution is as follows:

$\tilde{\mathbf{h}}_v^{(l+1)} = \sigma\Big(\theta_{d_v}^{(l)}\mathbf{W}_c^{(l)}\mathbf{h}_v^{(l)} + \sum_{u \in N(v)} \theta_{d_u}^{(l)}\mathbf{W}_n^{(l)}\mathbf{h}_u^{(l)} + \mathbf{b}^{(l)}\Big)$   (2)

where $\theta_{d}^{(l)}$ denotes the weight of a node with degree $d$ in the $l$-th layer, $d^{(l)}$ denotes the dimension of the feature vector in the $l$-th graph convolution layer, and $\mathbf{b}^{(l)} \in \mathbb{R}^{d^{(l+1)}}$ is a bias.
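A possible PyTorch implementation of the weighted convolution in Eq. (2) is sketched below; the module name, the cap on distinct degrees, and other details are our assumptions rather than the authors' released code.

import torch
import torch.nn as nn

class WeightedGraphConv(nn.Module):
    """Degree-weighted graph convolution, a sketch of Eq. (2)."""
    def __init__(self, in_dim, out_dim, max_degree=10):
        super().__init__()
        self.w_center = nn.Linear(in_dim, out_dim, bias=False)    # W_c
        self.w_neighbor = nn.Linear(in_dim, out_dim, bias=False)  # W_n
        self.theta = nn.Parameter(torch.ones(max_degree + 1))     # one scalar per node degree
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, h, adj):
        # h: (|V|, in_dim) node features; adj: list of neighbor index lists
        degrees = torch.tensor([min(len(n), len(self.theta) - 1) for n in adj])
        theta = self.theta[degrees]                         # (|V|,)
        center = theta.unsqueeze(1) * self.w_center(h)      # theta_{d_v} * W_c h_v
        neigh = self.w_neighbor(theta.unsqueeze(1) * h)     # theta_{d_u} * W_n h_u per node
        agg = torch.stack([neigh[nbrs].sum(0) if nbrs else torch.zeros_like(neigh[0])
                           for nbrs in adj])                # sum over N(v)
        return torch.tanh(center + agg + self.bias)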

After each convolution operation, similar to classical CNNs, we use a graph pooling operation to summarize the information within each neighborhood (i.e., a center node and its neighbors). For a specific node, Graph Pooling [Altae-Tran et al.2017] returns a new feature vector in which each element is the maximum activation of the corresponding element over the node's one-hop neighborhood. We denote this operation by the following formula and obtain the feature vectors of the next layer:

$\mathbf{h}_v^{(l+1)} = \max\big(\{\tilde{\mathbf{h}}_v^{(l+1)}\} \cup \{\tilde{\mathbf{h}}_u^{(l+1)} \mid u \in N(v)\}\big)$   (3)

where $\max$ is taken element-wise.
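A minimal sketch of the graph pooling in Eq. (3): each node takes the element-wise maximum over itself and its one-hop neighbors (illustrative code, not the original implementation).

import torch

def graph_max_pool(h, adj):
    """Element-wise max over each node's closed one-hop neighborhood, cf. Eq. (3)."""
    pooled = []
    for v, nbrs in enumerate(adj):
        group = h[[v] + list(nbrs)]        # the node itself plus its neighbors
        pooled.append(group.max(dim=0).values)
    return torch.stack(pooled)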

3.3 Graph-gather layers

Graph interaction prediction is a graph-level problem rather than a node-level problem. To learn graph-level features for different-sized receptive fields, we aggregate the node representations of each convolution layer into a graph-state through a graph-gather layer. A graph-gather layer computes a weighted sum of all node vectors in the connected graph convolution layer. The formula is:

$\mathbf{g}^{(l)} = \sigma\Big(\mathbf{W}_g^{(l)} \sum_{v \in V} \gamma_{d_v}^{(l)}\mathbf{h}_v^{(l)} + \mathbf{b}_g^{(l)}\Big)$   (4)

where $\gamma_{d}^{(l)}$ is the graph-gather weight of nodes with degree $d$ in the $l$-th graph convolution layer, $\mathbf{g}^{(l)} \in \mathbb{R}^{d_g}$ is the graph-state vector of the $l$-th convolution layer, $d_g$ denotes the dimension of graph-states, $|V|$ is the number of nodes in the graph, and $\mathbf{b}_g^{(l)}$ is a bias. In particular, the first graph-state $\mathbf{g}^{(0)}$ only includes the individual nodes' information.
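The graph-gather layer of Eq. (4) can be sketched as follows; placing the linear map outside the degree-weighted sum, and the parameter names, are our assumptions.

import torch
import torch.nn as nn

class GraphGather(nn.Module):
    """Degree-weighted sum of node vectors into a graph-state, cf. Eq. (4)."""
    def __init__(self, node_dim, state_dim, max_degree=10):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(max_degree + 1))  # one weight per node degree
        self.proj = nn.Linear(node_dim, state_dim)              # maps to the graph-state dimension

    def forward(self, h, adj):
        degrees = torch.tensor([min(len(n), len(self.gamma) - 1) for n in adj])
        weighted = self.gamma[degrees].unsqueeze(1) * h          # (|V|, node_dim)
        return torch.tanh(self.proj(weighted.sum(dim=0)))        # graph-state g^(l)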

3.4 Dual graph-state LSTMs

To solve graph-level tasks, existing graph convolution network (GCN) methods [Altae-Tran et al.2017] generally choose the graph-state of the last convolution layer, which has the largest receptive field, as the input for subsequent prediction. But such a state may lose many important features.

In a CNN on images, each convolution layer has multiple convolution kernels for extracting different features, which ensures that the hidden representation of the final convolution layer can fully capture the features of the input image. A GCN, however, is equivalent to a CNN with only one kernel per layer. It is therefore difficult for the output of the final graph convolution layer to capture all features in its large receptive field, especially structure features of small receptive fields. A straightforward remedy is to design multiple graph convolution kernels and aggregate their outputs; however, this is computationally expensive.

To solve the above problem, we propose a multi-resolution based architecture in which the graph-state of each graph convolution layer is leveraged to learn the final representation. We propose a Summary-LSTM (S-LSTM) to aggregate the graph-states of different-sized receptive fields and learn the final features comprehensively. Instead of the straightforward method that directly sums all graph-states, S-LSTM models the node-information diffusion process of a graph by sequentially receiving the graph-states, from small receptive fields to large ones, as inputs. It is inspired by the idea that a representation which encapsulates graph diffusion can provide a better basis for prediction than the graph itself. The formula of S-LSTM is:

$\big(\mathbf{h}_S^{(l)}, \mathbf{c}_S^{(l)}\big) = \mathrm{LSTM}_S\big(\mathbf{g}^{(l)}, \mathbf{h}_S^{(l-1)}, \mathbf{c}_S^{(l-1)}\big)$   (5)

where $\mathbf{h}_S^{(l)}$ is the hidden vector of the S-LSTM after it receives the $l$-th graph-state and $\mathbf{c}_S^{(l)}$ is the corresponding cell state. To further enhance the global information of the graph, we concatenate the final hidden output of the S-LSTM and the output of a global graph pooling layer as the final graph-state of the input graph:

$\mathbf{s} = \big[\mathbf{h}_S^{(L)} \,\|\, \mathbf{p}\big]$   (6)

where $L$ is the number of graph convolution layers, $\|$ denotes concatenation, and $\mathbf{p}$ is the result of global graph pooling on the final graph convolution layer.
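Eqs. (5)-(6) can be sketched as an LSTM that consumes the graph-states from the smallest to the largest receptive field, with the final hidden vector concatenated to a global pooling vector; the sketch below is illustrative and assumes single-graph, unbatched inputs.

import torch
import torch.nn as nn

class SummaryLSTM(nn.Module):
    """Aggregates graph-states g^(0..L) sequentially, cf. Eqs. (5)-(6)."""
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_size=state_dim, hidden_size=hidden_dim, batch_first=True)

    def forward(self, graph_states, global_pool):
        # graph_states: (L+1, state_dim), ordered from small to large receptive field
        # global_pool:  (pool_dim,), element-wise max over node vectors of the last layer
        seq = graph_states.unsqueeze(0)             # add a batch dimension
        _, (h_n, _) = self.lstm(seq)                # h_n: (1, 1, hidden_dim)
        return torch.cat([h_n[0, 0], global_pool])  # final graph representation s, Eq. (6)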

In addition, to extract the interaction features of pairwise graphs, we propose an Interaction-LSTM (I-LSTM), which takes the concatenation of the dual graph-states as input:

$\big(\mathbf{h}_I^{(l)}, \mathbf{c}_I^{(l)}\big) = \mathrm{LSTM}_I\big(\big[\mathbf{g}_x^{(l)} \,\|\, \mathbf{g}_y^{(l)}\big], \mathbf{h}_I^{(l-1)}, \mathbf{c}_I^{(l-1)}\big)$   (7)

where $\mathbf{h}_I^{(l)}$ is the hidden vector of the I-LSTM and $\mathbf{g}_x^{(l)}$, $\mathbf{g}_y^{(l)}$ are the graph-states of the two input graphs $G_x$ and $G_y$. We initialize the hidden and cell states as all-zero vectors, and the S-LSTM is shared by both input graphs.
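Similarly, Eq. (7) can be sketched with an LSTM cell that reads the concatenated graph-states of the two graphs layer by layer; the zero initialization follows the text, while the class name and unbatched interface are our assumptions.

import torch
import torch.nn as nn

class InteractionLSTM(nn.Module):
    """Reads the concatenated graph-states of both graphs layer by layer, cf. Eq. (7)."""
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.cell = nn.LSTMCell(input_size=2 * state_dim, hidden_size=hidden_dim)

    def forward(self, states_x, states_y):
        # states_x, states_y: (L+1, state_dim) graph-state sequences of the two graphs
        h = torch.zeros(1, self.cell.hidden_size)   # zero-initialized hidden state
        c = torch.zeros(1, self.cell.hidden_size)   # zero-initialized cell state
        for gx, gy in zip(states_x, states_y):
            h, c = self.cell(torch.cat([gx, gy]).unsqueeze(0), (h, c))
        return h.squeeze(0)                          # interaction feature h_I^(L)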

3.5 Fully connected layers

For interaction prediction, we simply concatenate the final graph representations and the interaction features of the input graphs (i.e., $\mathbf{s}_x$, $\mathbf{s}_y$, and $\mathbf{h}_I^{(L)}$) and use fully connected layers for prediction. Formally, we have:

$\mathbf{z} = \phi\big(\mathbf{W}_1\big[\mathbf{s}_x \,\|\, \mathbf{s}_y \,\|\, \mathbf{h}_I^{(L)}\big] + \mathbf{b}_1\big)$   (8)

$\hat{\mathbf{y}} = \mathrm{softmax}\big(\mathbf{W}_2\mathbf{z} + \mathbf{b}_2\big)$   (9)

where $\mathbf{W}_1$ and $\mathbf{W}_2$ are trainable weight matrices, $d_h$ is the dimension of the hidden vector $\mathbf{z}$, and $K$ is the number of interaction labels. The activation function $\phi$ is a rectified linear unit (ReLU), i.e., $\phi(x) = \max(0, x)$. $\hat{\mathbf{y}}$ is the output of the softmax function, whose $k$-th element is computed as $\hat{y}_k = \exp(o_k) / \sum_{j=1}^{K}\exp(o_j)$ with $\mathbf{o} = \mathbf{W}_2\mathbf{z} + \mathbf{b}_2$. Finally, we choose the cross-entropy function as the loss function, that is:

$\mathcal{L} = -\sum_{k=1}^{K} y_k \log \hat{y}_k$   (10)

where $\mathbf{y}$ is the ground-truth (one-hot) vector.
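Eqs. (8)-(10) amount to a small fully connected classifier with a softmax cross-entropy loss. The sketch below uses PyTorch's CrossEntropyLoss, which folds the softmax of Eq. (9) into the loss of Eq. (10); the layer sizes and interface are placeholders.

import torch
import torch.nn as nn

class PredictionHead(nn.Module):
    """Fully connected layers over [s_x || s_y || h_I], cf. Eqs. (8)-(10)."""
    def __init__(self, in_dim, hidden_dim, num_labels):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, num_labels),   # logits; softmax is applied inside the loss
        )
        self.loss_fn = nn.CrossEntropyLoss()     # softmax + cross entropy, Eqs. (9)-(10)

    def forward(self, s_x, s_y, h_i, label=None):
        logits = self.mlp(torch.cat([s_x, s_y, h_i]).unsqueeze(0))
        if label is None:
            return logits.softmax(dim=-1)        # predicted distribution over labels
        return self.loss_fn(logits, torch.tensor([label]))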

4 Experiment

In this section, we conduct experiments to validate our method (code available at https://github.com/prometheusXN/MR-GNN). We consider two prediction tasks: 1) predicting whether there is an interaction between two chemicals (binary classification), and 2) predicting the interaction label between two drugs (multi-class classification).

4.1 Dataset

CCI Dataset. For the binary classification task, we use the CCI dataset (http://stitch.embl.de/download/chemical_chemical.links.detailed.v5.0.tsv.gz). This dataset assigns each pair of compounds a score describing their interaction level: the higher the score, the more likely the interaction. Thresholding the scores at 900, 800, and 700, we obtain the positive samples of three datasets: CCI900, CCI800, and CCI700. As for negative samples, we choose chemical pairs with low scores. For each pair of chemicals, we assign a label "1" or "0" to indicate whether an interaction occurs between them. We use a publicly available toolkit, DeepChem (https://deepchem.io/), to convert compounds into graphs in which each node has a 75-dimensional feature vector.
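This preprocessing step can be sketched with DeepChem's ConvMolFeaturizer, which produces 75-dimensional atom features; exact API details may vary across DeepChem versions, so treat the snippet as illustrative.

import deepchem as dc

featurizer = dc.feat.ConvMolFeaturizer()
mols = featurizer.featurize(["OC1=CC=C(O)C=C1", "CC(=O)O"])  # Hydroquinone, Acetic acid

for mol in mols:
    atom_feats = mol.get_atom_features()   # shape (num_atoms, 75)
    neighbors = mol.get_adjacency_list()   # list of neighbor indices per atom
    print(atom_feats.shape, len(neighbors))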

DDI Dataset. For the multi-class classification task, we use the DDI dataset (http://www.pnas.org/content/suppl/2018/04/14/1803294115.DCSupplemental). This dataset labels each drug pair with one of a set of interaction types, and each drug is represented by a SMILES string [Weininger1988]. In our preprocessing, we remove the data items that cannot be converted into graphs from their SMILES strings.

Dataset Graph Meaning #Graphs #Pairs
CCI900 Chemical Compounds 11990 19624
CCI800 Chemical Compounds 73602 151796
CCI700 Chemical Compounds 114734 343277
DDI Drug Molecule Graphs 1704 191400
Table 1: Statistics of datasets.

4.2 Baselines

Table 2: Experimental results of the binary classification task (AUC, accuracy, recall, and F1 on CCI900, CCI800, and CCI700 for PIP, SNR, DGCNN, DeepDDI, DeepCCI, and MR-GNN).

We compare our method with the following state-of-the-art models:

  • DeepCCI [Kwon and Yoon2017] is one of the state-of-the-art methods on the CCI datasets. It represents SMILES strings of chemicals as one-hot vector matrices and uses a classical CNN to predict interaction labels.

  • DeepDDI [Ryu et al.2018] is one of the state-of-the-art methods on the DDI dataset. DeepDDI designs a feature called the structural similarity profile (SSP), combined with a multilayer perceptron (MLP), for prediction.

  • PIP [Alex et al.2017] is proposed to predict protein interfaces. It extracts features from the fixed three-hop neighborhood of each node to learn node representations. When building this model, we use our graph-gather layer to aggregate node representations into a graph representation.

  • DGCNN [Zhang et al.2018a] uses the standard graph convolution operator as described in Section 3. It concatenates the node vectors of each graph convolution layer and applies CNN with a node ordering scheme to generate a graph representation.

  • SNR [Li et al.2017] uses a graph convolution layer similar to ours. The difference is that it introduces an additional node that sums all node features into a graph representation.

4.3 Binary classification

Settings. We divide each CCI dataset into a training set and a testing set, and randomly choose a portion of the training set as a validation set. We use three graph convolution layers with different numbers of output units, set the number of output units of the graph-gather layers equal to that of the LSTM layers, and use a fully connected hidden layer followed by a softmax output layer. To evaluate the experimental results, we choose four metrics: area under the ROC curve (AUC), accuracy, recall, and F1.
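For reference, the four metrics can be computed with scikit-learn roughly as follows (variable names and the thresholding convention are our assumptions):

import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, recall_score, f1_score

def binary_metrics(y_true, y_prob, threshold=0.5):
    """AUC from predicted probabilities; accuracy/recall/F1 from thresholded labels."""
    y_pred = (y_prob >= threshold).astype(int)
    return {
        "AUC": roc_auc_score(y_true, y_prob),
        "accuracy": accuracy_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "F1": f1_score(y_true, y_pred),
    }

print(binary_metrics(np.array([1, 0, 1, 1]), np.array([0.9, 0.2, 0.4, 0.8])))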

Results. Table 2 shows the performance of the different methods. MR-GNN performs the best on all evaluation metrics. Compared with the state-of-the-art method DeepCCI, MR-GNN improves accuracy, F1, recall, and AUC on all three CCI datasets. The improvement in AUC is relatively small, which we ascribe to the fact that the baseline AUC is already so high that little room for improvement remains; measured against that remaining room, the relative AUC gain is considerably larger. The performance improvement shows that the feature extraction of MR-GNN, which represents structured entities as graphs, is more effective than that of DeepCCI, which treats SMILES strings as character sequences without considering the topological information of structured entities. Compared with PIP, the performance of MR-GNN demonstrates that the multi-resolution based architecture is more effective than a fixed-size RF based framework. In addition, compared with SNR, which directly sums all node features to obtain the graph representation, the results show that our S-LSTM summarizes local features more effectively and comprehensively. We attribute this improvement to the diffusion process and the interaction that our graph-state LSTMs model during feature extraction.

4.4 Multi-class classification

Settings. For a direct comparison, we split the dataset into training, validation, and testing sets with the same proportions as DeepDDI. All hyper-parameter selections are the same as in the binary classification task. To evaluate the experimental results, we choose five metrics for the multi-class classification problem: AUPRC, micro average, macro recall, macro precision, and macro F1. (In particular, we choose the AUPRC metric due to the class imbalance of the DDI dataset.) We show the results on the DDI dataset in Table 3.

Table 3: Results on the DDI dataset (micro average, macro recall, macro precision, macro F1, and AUPRC for PIP, SNR, DGCNN, DeepCCI, DeepDDI, MR-GNN, and MR-GNN ablations without I-LSTM, without S-LSTM, without weighted GCLs, and without both LSTMs).
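The five metrics above can likewise be computed with scikit-learn; the sketch below shows macro-averaged scores and a micro-averaged AUPRC over binarized labels, and is illustrative rather than the authors' evaluation code.

import numpy as np
from sklearn.metrics import (accuracy_score, recall_score, precision_score,
                             f1_score, average_precision_score)
from sklearn.preprocessing import label_binarize

def multiclass_metrics(y_true, y_prob):
    """y_true: (n,) integer labels; y_prob: (n, K) predicted class probabilities, K > 2."""
    classes = np.arange(y_prob.shape[1])
    y_pred = y_prob.argmax(axis=1)
    y_bin = label_binarize(y_true, classes=classes)
    return {
        "Mi_avg": accuracy_score(y_true, y_pred),   # micro-averaged P/R/F1 equal accuracy for single-label problems
        "Ma_recall": recall_score(y_true, y_pred, average="macro"),
        "Ma_pre": precision_score(y_true, y_pred, average="macro"),
        "Ma_F1": f1_score(y_true, y_pred, average="macro"),
        "AUPRC": average_precision_score(y_bin, y_prob, average="micro"),
    }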

Results. We observe that MR-GNN performs the best on all five evaluation metrics and improves each of them over the best baseline. Compared with the state-of-the-art method DeepDDI, the performance improvement of MR-GNN is attributed to the higher-quality representations learned by end-to-end training, as opposed to the human-designed SSP representation. In addition, we also conduct experiments on the CCI and DDI datasets and observe that MR-GNN indeed improves performance.

Ablation experiment. We also conducted ablation experiments on the DDI dataset to study the effects of three components of our model, namely S-LSTM, I-LSTM, and the weighted GCLs. We find that each of the three components improves performance. Among them, the weighted GCLs contribute most significantly, followed by S-LSTM and then I-LSTM.

4.5 Efficiency and robustness

In the third set of experiments, we analyze the efficiency and robustness of MR-GNN.

Effects of training dataset size. We carried out a comparative experiment with training sets of different proportions on the CCI900 dataset. In each comparison, we kept the same held-out portion of the dataset as the test set to evaluate all six methods. Figure 4(a) shows that MR-GNN always performs the best under every training set size. In particular, as the training set proportion increases, the improvement of MR-GNN increases significantly, demonstrating that MR-GNN has better robustness. This is because MR-GNN is good at learning subgraph information of different-sized receptive fields, especially subgraphs with small receptive fields that frequently appear across graphs.

Training efficiency. Figure 4(b) shows that the training time of MR-GNN is at a moderate level among all methods. Although the graph-state LSTMs take additional time, the training of MR-GNN remains acceptably fast.

Effects of hyper-parameter variation. In this experiment, we consider the impact of the hyper-parameters of MR-GNN, including the number of output units of the GCLs and of the LSTMs and the number of hidden units of the fully connected layer. The results are shown in Fig. 5. We see that the impact of hyper-parameter variation is small: the absolute differences in accuracy are minor. Figures 5(a) and 5(b) show that larger numbers of GCL and LSTM output units give better performance up to a salient point, after which performance declines slightly; the remaining hyper-parameters behave similarly, each with a clear best setting.

Figure 4: Results on CCI900: (a) accuracy under different training set proportions; (b) training time per epoch.

5 Related Work

Node-level Applications. Many neural network based methods have been proposed for node-level tasks such as node classification [Henaff et al.2015, Li et al.2015, Defferrard et al.2016, Kipf and Welling2016, Veličković et al.2018] and link prediction [Zhang and Chen2018, Zhang et al.2018b]. They rely on node embedding techniques, including skip-gram based methods such as DeepWalk [Perozzi et al.2014] and LINE [Tang et al.2015], autoencoder based methods such as SDNE [Wang et al.2016], and neighbor aggregation based methods such as GCN [Defferrard et al.2016, Kipf and Welling2017] and GraphSAGE [Hamilton et al.2017a].

Figure 5: Parameter sensitivities with respect to the main hyper-parameters of MR-GNN (panels (a)-(d)).

Single Graph Based Applications. Attention has also been paid to graph-level tasks. Most existing works focus on classifying graphs and predicting graph properties [Duvenaud et al.2015, Atwood and Towsley2016, Li et al.2017, Zhang et al.2018a], computing one embedding per graph. To learn graph representations, the most straightforward way is to aggregate node embeddings, including average-based methods (simple and weighted averages) [Li et al.2017, Duvenaud et al.2015, Zhao et al.2018], sum-based methods [Hamilton et al.2017b], and more sophisticated schemes such as aggregating nodes via histograms [Kearnes et al.2016] or learning a node ordering to make graphs suitable for CNNs [Zhang et al.2018a].

Pairwise Graph Based Applications. To date, few neural network based works address pairwise-graph tasks, whose input is a pair of graphs; the existing ones focus on learning "similarity" relations between graphs [Bai et al.2018, Yanardag and Vishwanathan2015] or links between nodes across graphs [Alex et al.2017]. In this work, we study the prediction of general graph interactions.

6 Conclusion

In this paper, we propose a novel graph neural network, MR-GNN, to predict the interactions between structured entities. MR-GNN learns comprehensive and effective features by leveraging a multi-resolution architecture. We empirically analyze the performance of MR-GNN on different interaction prediction tasks, and the results demonstrate the effectiveness of our model. Moreover, MR-GNN can easily be extended to large graphs by assigning weights to groups of nodes based on the distribution of node degrees. In the future, we will apply it to other domains.

Acknowledgments

The research presented in this paper is supported in part by the National Key R&D Program of China (2018YFC0830500), the National Natural Science Foundation of China (U1736205, 61603290), the Shenzhen Basic Research Grant (JCYJ20170816100819428), the Natural Science Basic Research Plan in Shaanxi Province of China (2019JM-159), and the Natural Science Basic Research Plan in Zhejiang Province of China (LGG18F020016).

References