Neural Relational Inference with Efficient Message Passing Mechanisms

01/23/2021 ∙ by Siyuan Chen, et al. ∙ SUN YAT-SEN UNIVERSITY

Many complex processes can be viewed as dynamical systems of interacting agents. In many cases, only the state sequences of individual agents are observed, while the interacting relations and the dynamical rules are unknown. The neural relational inference (NRI) model adopts graph neural networks that pass messages over a latent graph to jointly learn the relations and the dynamics based on the observed data. However, NRI infers the relations independently and suffers from error accumulation in multi-step prediction during dynamics learning. Besides, relation reconstruction without prior knowledge becomes more difficult in more complex systems. This paper introduces efficient message passing mechanisms to the graph neural networks with structural prior knowledge to address these problems. A relation interaction mechanism is proposed to capture the coexistence of all relations, and a spatio-temporal message passing mechanism is proposed to use historical information to alleviate error accumulation. Additionally, the structural prior knowledge, with symmetry as a special case, is introduced for better relation prediction in more complex systems. The experimental results on simulated physics systems show that the proposed method outperforms existing state-of-the-art methods.




Code Repositories


Code for Neural Relational Inference with Efficient Message Passing Mechanisms (AAAI 2021).



Many complex processes in natural and social domains, including multi-agent systems Yang et al. (2018); Li et al. (2020), swarm systems Oliveira et al. (2020), physical systems Ha and Jeong (2020); Bapst et al. (2020) and social systems Almaatouq et al. (2020); Zhang et al. (2020), can be viewed as dynamical systems of interacting agents. Revealing the underlying interactions and dynamics can help us understand, predict, and control the behavior of these systems. However, in many cases, only the state sequences of individual agents are observed, while the interacting relations and the dynamical rules are unknown.

Many works Hoshen (2017); van Steenkiste et al. (2018); Watters et al. (2017) use implicit interaction models to learn the dynamics. These models can be regarded as graph neural networks (GNNs) over a fully-connected graph, where the implicit interactions are modeled via message passing operations Watters et al. (2017) or attention mechanisms Hoshen (2017). Compared with implicit interaction models, modeling the explicit interactions offers a more interpretable way to understand dynamical systems. A motivating example from Kipf et al. (2018) is shown in Fig. 1, where a dynamical system consists of 5 particles linked by invisible springs. The goal is to infer the relations among the particles and predict their future states. Kipf et al. (2018) propose the neural relational inference (NRI) model, a variational auto-encoder (VAE) Kingma and Welling (2014), to jointly learn explicit interactions and dynamical rules in an unsupervised manner.

Figure 1: A dynamical system consisting of 5 particles that are linked by invisible springs. The interacting relations and future states are to be predicted based on the observed state sequences.

Currently, NRI has three main limitations. First, the interacting relations are inferred independently, so the coexistence of these relations is not considered. Alet et al. (2019) tackle this problem by treating all relations as a whole and iteratively improving the prediction through modular meta-learning. However, this method is computationally costly and limited to small-scale systems. Second, NRI predicts multiple steps into the future to emphasize the effect of the interactions, which leads to error accumulation and prevents the model from reconstructing the dynamics accurately. Third, as the scale and complexity of systems increase, it becomes difficult to infer the interactions solely from the observed data, whereas incorporating structural prior knowledge can help explore the structural space and improve the precision of relation recovery.

To address the problems above, this paper introduces efficient message passing mechanisms with structural prior knowledge for neural relational inference. For relation reconstruction, a relation interaction mechanism is introduced to capture the dependencies among different relations. For dynamics reconstruction, a spatio-temporal message passing mechanism is introduced to utilize historical information to alleviate the error accumulation of multi-step prediction. Additionally, the prior knowledge about the relations is incorporated as a regularization term in the loss function to impose soft constraints, and as a special case, the symmetry of relations is taken into consideration.

The contributions of this work are summarized as follows.

  • Efficient message passing mechanisms are introduced to make a joint prediction of relations and alleviate the error accumulation of multi-step state prediction.

  • The prior knowledge about relations, symmetry as a special case, is introduced to better reconstruct relations on more complex dynamical systems.

  • Extensive experiments on physics simulation datasets are conducted, and comparisons with state-of-the-art methods show the superiority of our method.


Neural Relational Inference (NRI)

NRI Kipf et al. (2018) is an unsupervised model that infers interacting structure from observational data and learns system dynamics. NRI takes the form of VAE, where the encoder infers the interacting relations, and the decoder predicts the future states of individual agents.

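Specifically, the encoder adopts GNNs with multiple rounds of message passing, and infers the distribution of potential interactions based on the input states,

$$\mathbf{h}_{(i,j)} = f_{\mathrm{enc}}(\mathbf{x})_{ij}, \tag{1}$$
$$q_\phi(\mathbf{z}_{ij}\mid\mathbf{x}) = \mathrm{softmax}(\mathbf{h}_{(i,j)}), \tag{2}$$

where $\mathbf{x} = (\mathbf{x}^1, \dots, \mathbf{x}^T)$ denotes the observed trajectories of the $N$ objects in the system over $T$ time steps, $f_{\mathrm{enc}}$ is a GNN acting on the fully-connected graph, $\mathbf{h}_{(i,j)} \in \mathbb{R}^K$ is its output for edge $(i,j)$ over $K$ edge types, and $q_\phi(\mathbf{z}_{ij}\mid\mathbf{x})$ is the factorized distribution of the edge type $\mathbf{z}_{ij}$.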

Since directly sampling edges from $q_\phi(\mathbf{z}_{ij}\mid\mathbf{x})$, a discrete distribution, is a non-differentiable process, back-propagation cannot be applied. To solve this problem, NRI uses the Gumbel-Softmax trick Maddison et al. (2017), which simulates a differentiable sampling process from a discrete distribution using a continuous function, i.e.,

$$\mathbf{z}_{ij} = \mathrm{softmax}\big((f_{\mathrm{enc}}(\mathbf{x})_{ij} + \mathbf{g})/\tau\big), \tag{3}$$

where $\mathbf{z}_{ij}$ represents the edge type between the nodes $v_i$ and $v_j$, $\mathbf{g} \in \mathbb{R}^K$ is a vector of i.i.d. samples from a Gumbel(0, 1) distribution, and the temperature $\tau$ is a parameter controlling the "smoothness" of the samples.
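To make Eq. (3) concrete, here is a minimal pure-Python sketch of Gumbel-Softmax sampling for a single edge. It is not code from NRI; the function names, the toy two-type logits and the temperature are illustrative assumptions.

```python
import math
import random


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) sample by inverse transform: g = -log(-log(u))."""
    u = max(rng.random(), 1e-12)  # guard against log(0) when u == 0
    return -math.log(-math.log(u))


def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample z = softmax((logits + g) / tau), g_k i.i.d. Gumbel(0, 1)."""
    return softmax([(logit + sample_gumbel(rng)) / tau for logit in logits])


# Toy edge with two types (e.g. "no interaction" vs "spring") and q = (0.7, 0.3).
rng = random.Random(0)
logits = [math.log(0.7), math.log(0.3)]
samples = [gumbel_softmax(logits, tau=0.05, rng=rng) for _ in range(2000)]

# The argmax of each sample follows q (the Gumbel-max property) ...
frac_type0 = sum(z[0] > z[1] for z in samples) / len(samples)
# ... and a small temperature pushes each sample close to one-hot.
mean_max = sum(max(z) for z in samples) / len(samples)
print(f"P(type 0) ~ {frac_type0:.3f}, mean max(z) ~ {mean_max:.3f}")
```

The argmax of a sample does not depend on $\tau$; the temperature only controls how close the relaxed vector is to one-hot, which is the smoothness trade-off mentioned above.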

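According to the inferred relations and past state sequences, the decoder uses another GNN that models the effect of interactions to predict the future states,

$$p_\theta(\mathbf{x}\mid\mathbf{z}) = \prod_{t=1}^{T} p_\theta(\mathbf{x}^{t+1}\mid\mathbf{x}^{t}, \dots, \mathbf{x}^{1}, \mathbf{z}), \tag{4}$$

where $p_\theta(\mathbf{x}^{t+1}\mid\mathbf{x}^{t}, \dots, \mathbf{x}^{1}, \mathbf{z})$ is the conditional likelihood of $\mathbf{x}^{t+1}$.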

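As a variational auto-encoder model, NRI is trained to maximize the evidence lower bound (ELBO),

$$\mathcal{L} = \mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}\big[\log p_\theta(\mathbf{x}\mid\mathbf{z})\big] - \mathrm{KL}\big[q_\phi(\mathbf{z}\mid\mathbf{x}) \,\|\, p_\theta(\mathbf{z})\big], \tag{5}$$

where the prior $p_\theta(\mathbf{z}) = \prod_{i \neq j} p_\theta(\mathbf{z}_{ij})$ is a factorised uniform distribution over edge types. On the right-hand side of Eq. (5), the first term is the expected log-likelihood of the observations (i.e., the negative reconstruction error), while the second term encourages the approximate posterior $q_\phi(\mathbf{z}\mid\mathbf{x})$ to approach the prior $p_\theta(\mathbf{z})$.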
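As a rough illustration of Eq. (5), the sketch below evaluates a single-sample ELBO estimate with a Gaussian decoder likelihood of fixed variance and the KL divergence from each edge's categorical distribution to the uniform prior. Plain Python lists stand in for tensors; the helper names and toy inputs are assumptions for illustration, not NRI's code.

```python
import math


def gaussian_log_likelihood(x, mu, var):
    """log N(x; mu, var * I), summed over coordinates."""
    return sum(
        -0.5 * math.log(2 * math.pi * var) - (xi - mi) ** 2 / (2 * var)
        for xi, mi in zip(x, mu)
    )


def kl_to_uniform(q):
    """KL[q || Uniform(K)] = sum_k q_k log q_k + log K for one edge's distribution q."""
    return sum(p * math.log(p) for p in q if p > 0) + math.log(len(q))


def elbo(states, predictions, edge_probs, var):
    """Reconstruction log-likelihood minus the KL term summed over edges.

    states / predictions: target vectors x^{t+1} and decoder means mu^{t+1}.
    edge_probs: one categorical distribution q(z_ij | x) per edge.
    """
    recon = sum(
        gaussian_log_likelihood(x, mu, var) for x, mu in zip(states, predictions)
    )
    kl = sum(kl_to_uniform(q) for q in edge_probs)
    return recon - kl


# Uniform edge beliefs cost nothing; a confident (one-hot) belief costs log K.
print(kl_to_uniform([0.5, 0.5]), kl_to_uniform([1.0, 0.0]))
# Perfect predictions of a 2-D state with var = 1 give -log(2 * pi).
print(elbo([[0.0, 0.0]], [[0.0, 0.0]], [[0.5, 0.5]], var=1.0))
```

Because the prior is uniform, the KL term acts as a penalty on overconfident edge beliefs that is paid unless the reconstruction term gains from them.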

Message Passing Mechanisms in GNNs

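GNNs are a widely used class of neural networks that operate on graph-structured data by message passing mechanisms Gilmer et al. (2017). For a graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$ with vertices $v \in \mathcal{V}$ and edges $e = (v_i, v_j) \in \mathcal{E}$, the node-to-edge ($v \to e$) and edge-to-node ($e \to v$) message passing operations are defined as follows,

$$v \to e:\quad \mathbf{h}^{l}_{(i,j)} = f^{l}_{e}\big([\mathbf{h}^{l}_{i}, \mathbf{h}^{l}_{j}, \mathbf{x}_{(i,j)}]\big), \tag{6}$$
$$e \to v:\quad \mathbf{h}^{l+1}_{j} = f^{l}_{v}\Big(\sum_{i \in \mathcal{N}_j} \mathbf{h}^{l}_{(i,j)}\Big), \tag{7}$$

where $\mathbf{h}^{l}_{i}$ is the embedding of node $v_i$ in the $l$-th layer of the GNN, $\mathbf{x}_{(i,j)}$ is the feature of edge $e_{(i,j)}$ (e.g., the edge type), and $\mathbf{h}^{l}_{(i,j)}$ is the embedding of edge $e_{(i,j)}$ in the $l$-th layer. $\mathcal{N}_j$ denotes the set of indices of neighboring nodes with an incoming edge connected to vertex $v_j$, and $[\cdot, \cdot]$ denotes the concatenation of vectors. $f_v$ and $f_e$ are node-specific and edge-specific neural networks, such as multilayer perceptrons (MLPs), respectively. Node embeddings are mapped to edge embeddings in Eq. (6) and vice versa in Eq. (7).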
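A minimal sketch of the two operations in Eqs. (6) and (7), with Python lists as embeddings and edges stored as `(i, j)` pairs. The toy `f_e` (sum of the concatenated inputs) and the identity `f_v` are hypothetical stand-ins for the MLPs, and edge features are omitted for brevity.

```python
from itertools import permutations


def node_to_edge(h_nodes, edges, f_e):
    """v -> e: h_(i,j) = f_e([h_i, h_j]); list '+' performs the concatenation."""
    return {(i, j): f_e(h_nodes[i] + h_nodes[j]) for (i, j) in edges}


def edge_to_node(h_edges, num_nodes, f_v):
    """e -> v: h_j = f_v(sum of h_(i,j) over the incoming edges i -> j)."""
    dim = len(next(iter(h_edges.values())))
    agg = [[0.0] * dim for _ in range(num_nodes)]
    for (_, j), h in h_edges.items():
        for d, value in enumerate(h):
            agg[j][d] += value
    return [f_v(a) for a in agg]


def f_e(v):
    """Hypothetical stand-in for an edge MLP."""
    return [sum(v)]


def f_v(v):
    """Hypothetical stand-in for a node MLP (identity)."""
    return v


# Fully-connected directed graph on 3 nodes without self-loops.
n = 3
edges = list(permutations(range(n), 2))
h_nodes = [[1.0], [2.0], [3.0]]

h_edges = node_to_edge(h_nodes, edges, f_e)
print(edge_to_node(h_edges, n, f_v))  # [[7.0], [8.0], [9.0]]
```

Each node receives $S + (N-2)\,h_j$ here, where $S$ is the sum of all node values, which shows how one round of message passing mixes information from every neighbor.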


The structure of our method is shown in Fig. 2. Our method follows the VAE framework of NRI. In the encoder, a relation interaction mechanism is used to capture the dependencies among the latent edges for a joint relation prediction. In the decoder, a spatio-temporal message passing mechanism is used to incorporate historical information to alleviate the error accumulation in multi-step prediction. As both mechanisms can be regarded as Message Passing Mechanisms in GNNs for Neural Relational Inference, our method is named NRI-MPM. Additionally, the structural prior knowledge, with symmetry as a special case, is incorporated as a regularization term in the loss function to improve relation prediction in more complex systems.

Figure 2: Overview of NRI-MPM. The encoder uses the state sequences $\mathbf{x}$ to generate relation embeddings, and applies a relation interaction mechanism to jointly predict the relations $\mathbf{z}$. The structural prior knowledge of symmetry is imposed as a soft constraint for relation prediction. The decoder takes the predicted relations and the historical state sequences to predict the change in state $\Delta\mathbf{x}^t$.

Relation Interaction Mechanism

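The encoder aims to infer the edge types $\mathbf{z}_{ij}$ based on the observed state sequences $\mathbf{x}$. From the perspective of message passing, the encoder defines a node-to-edge message passing operation at a high level. As shown in Eqs. (8)-(11), NRI first maps the observed state $\mathbf{x}_j$ of each object to a latent vector $\mathbf{h}^1_j$, and then alternately applies two rounds of node-to-edge and one round of edge-to-node message passing to obtain the edge embeddings $\mathbf{h}^2_{(i,j)}$ that integrate both local and global information. Then, $\mathbf{h}^2_{(i,j)}$ are used to predict the pairwise interaction types independently.

$$\mathbf{h}^{1}_{j} = f_{\mathrm{emb}}(\mathbf{x}_j), \tag{8}$$
$$v \to e:\quad \mathbf{h}^{1}_{(i,j)} = f^{1}_{e}\big([\mathbf{h}^{1}_{i}, \mathbf{h}^{1}_{j}]\big), \tag{9}$$
$$e \to v:\quad \mathbf{h}^{2}_{j} = f^{1}_{v}\Big(\sum_{i \neq j} \mathbf{h}^{1}_{(i,j)}\Big), \tag{10}$$
$$v \to e:\quad \mathbf{h}^{2}_{(i,j)} = f^{2}_{e}\big([\mathbf{h}^{2}_{i}, \mathbf{h}^{2}_{j}]\big). \tag{11}$$
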
However, the relations are generally dependent on each other since they jointly affect the future states of individual agents. Within the original formulation of GNNs, Eqs. (8)-(11) cannot effectively model the dependencies among edges: typical GNNs are designed to learn node embeddings, while edge embeddings are treated as a transient part of the computation. To capture the coexistence of all relations, this paper introduces an edge-to-edge ($e \to e$) message passing operation that directly passes messages among edges, named the relation interaction mechanism,

$$e \to e:\quad \{\mathbf{h}^{\mathrm{RI}}_{(i,j)}\}_{(i,j) \in \mathcal{E}} = f_{\mathrm{RI}}\big(\{\mathbf{h}^{2}_{(i,j)}\}_{(i,j) \in \mathcal{E}}\big). \tag{12}$$

Ideally, this operation would model the pairwise dependencies among all edges, which is computationally costly as its time complexity is $O(|\mathcal{E}|^2)$, i.e., $O(N^4)$ on a fully-connected graph with $N$ nodes. Instead, as shown in Fig. 3, our method decomposes this operation into two sub-operations, the intra-edge and inter-edge interaction operations, which model the interactions among incoming edges of the same node and those among incoming edges of different nodes, respectively.

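The intra-edge interaction operation is defined as follows,

$$\{\mathbf{h}^{\mathrm{intra}}_{(i,j)}\}_{i \in \mathcal{N}_j} = f_{\mathrm{intra}}\big(\{\mathbf{h}^{2}_{(i,j)}\}_{i \in \mathcal{N}_j}\big). \tag{13}$$
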
Figure 3: Relation interaction mechanism. Given the edge embeddings $\mathbf{h}^2_{(i,j)}$, the intra-edge and inter-edge interaction operations are used to model the interactions among incoming edges of the same node and those among the incoming edges of different nodes, respectively. The resulting embeddings $\mathbf{h}^{\mathrm{intra}}_{(i,j)}$ and $\mathbf{h}^{\mathrm{inter}}_{(i,j)}$ are concatenated to obtain the final edge representations.

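From the above definition, $f_{\mathrm{intra}}$ is required to be permutation equivariant Lee et al. (2019) to preserve the correspondence between the input and output edge embeddings. Formally, let $\mathcal{I}_j$ be the ascending sequence of $\mathcal{N}_j$ and $S_{\mathcal{I}_j}$ be the set of all permutations on $\mathcal{I}_j$. For any permutation $\pi \in S_{\mathcal{I}_j}$, $\pi(\mathbf{h}_{\mathcal{I}_j}) = \mathbf{h}_{\pi(\mathcal{I}_j)}$. $f_{\mathrm{intra}}$ is said to be permutation equivariant if $f_{\mathrm{intra}} \circ \pi = \pi \circ f_{\mathrm{intra}}$, where $\circ$ denotes the composition of functions. Although permutation invariant operators such as sum and max aggregators are widely used in GNNs Xu et al. (2019), the design of permutation equivariant operators is less explored. Inspired by Hamilton et al. (2017), this paper treats a set as an unordered sequence and defines $f_{\mathrm{intra}}$ as a sequence model over it, followed by an inverse permutation to restore the order, i.e.,

$$f_{\mathrm{intra}}\big(\mathbf{h}_{\mathcal{I}_j}\big) = \pi^{-1} \circ g_{\mathrm{intra}} \circ \pi\big(\mathbf{h}_{\mathcal{I}_j}\big),$$

where $\pi \in S_{\mathcal{I}_j}$ is a permutation of the incoming edges and $g_{\mathrm{intra}}$ is a sequence model such as a recurrent neural network (RNN) or a convolutional neural network (CNN). When $g_{\mathrm{intra}}$ is implemented by self-attention (without positional encodings), $f_{\mathrm{intra}}$ is permutation equivariant; this may not hold when $g_{\mathrm{intra}}$ is an RNN. In that case, $f_{\mathrm{intra}}$ is expected to be approximately permutation equivariant with a well-trained $g_{\mathrm{intra}}$.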
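The permute, apply, restore construction described above can be sketched as follows for scalar edge embeddings. The two toy sequence models are assumptions for illustration: `center` is permutation equivariant, as self-attention without positional encodings is, while `running_sum` depends on order like an RNN, so the wrapper is exactly equivariant only for the former.

```python
import random


def permute_apply_restore(seq, g, rng):
    """Apply sequence model g to a random permutation of seq, then undo the permutation."""
    order = list(range(len(seq)))
    rng.shuffle(order)                      # position k of the shuffled input holds seq[order[k]]
    out_shuffled = g([seq[idx] for idx in order])
    out = [None] * len(seq)
    for k, idx in enumerate(order):         # inverse permutation: write outputs back in place
        out[idx] = out_shuffled[k]
    return out


def center(xs):
    """Order-independent toy model: subtract the set mean from every element."""
    mean = sum(xs) / len(xs)
    return [x - mean for x in xs]


def running_sum(xs):
    """Order-dependent toy 'RNN': each output is the running sum of its prefix."""
    out, acc = [], 0.0
    for x in xs:
        acc += x
        out.append(acc)
    return out


rng = random.Random(1)
edges_into_j = [1.0, 2.0, 4.0, 7.0]   # embeddings of the incoming edges of one node
print(permute_apply_restore(edges_into_j, center, rng))       # identical for every permutation
print(permute_apply_restore(edges_into_j, running_sum, rng))  # depends on the drawn permutation
```

Restoring the order is what keeps the output for edge $(i,j)$ aligned with its input; with an order-dependent model, training over many random permutations is what pushes the result toward approximate equivariance.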

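Similarly, one can define an inter-edge interaction operation. The difference is that this operation treats all incoming edges of a node as a whole. For simplicity, this paper applies mean-pooling to the edge embeddings to get an overall representation. Formally, the inter-edge interaction operation is defined as the composition of the following steps,

$$\bar{\mathbf{h}}_{j} = \mathrm{Mean}\big(\{\mathbf{h}^{2}_{(i,j)}\}_{i \in \mathcal{N}_j}\big), \tag{14}$$
$$v \to v:\quad \{\tilde{\mathbf{h}}_{j}\}_{j \in \mathcal{V}} = f_{vv}\big(\{\bar{\mathbf{h}}_{j}\}_{j \in \mathcal{V}}\big), \tag{15}$$
$$\mathbf{h}^{\mathrm{inter}}_{(i,j)} = \tilde{\mathbf{h}}_{j}, \quad i \in \mathcal{N}_j, \tag{16}$$

where $f_{vv}$ is a node-to-node ($v \to v$) operation that passes messages among nodes and takes a form similar to $f_{\mathrm{intra}}$. To analyze the time complexity of relation interaction, this paper assumes that all sequence models involved are RNNs. Since $f_{\mathrm{intra}}$ and Mean can be applied to each node in parallel and the time complexity of each of Eqs. (14)-(16) is $O(N)$, the overall time complexity of relation interaction is $O(N)$, which is much more efficient than the pairwise interaction.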
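A minimal sketch of Eqs. (14)-(16) with scalar edge embeddings: mean-pool each node's incoming edges, mix the node summaries with a node-to-node operation, then broadcast the result back to the incoming edges. The toy `f_vv` (subtracting the mean over nodes) is a hypothetical stand-in for the sequence model, and `pairwise_edge_messages` only illustrates the cost of a full edge-to-edge operation.

```python
from itertools import permutations


def inter_edge_interaction(h_edges, num_nodes, f_vv):
    """Scalar sketch of mean-pooling (14), node-to-node mixing (15) and broadcast (16)."""
    sums = [0.0] * num_nodes
    counts = [0] * num_nodes
    for (_, j), h in h_edges.items():        # (14) pool the incoming edges of node j
        sums[j] += h
        counts[j] += 1
    pooled = [s / c if c else 0.0 for s, c in zip(sums, counts)]
    mixed = f_vv(pooled)                      # (15) pass messages among nodes
    return {(i, j): mixed[j] for (i, j) in h_edges}  # (16) broadcast back to edges


def f_vv(xs):
    """Hypothetical node-to-node operation: subtract the mean over nodes."""
    mean = sum(xs) / len(xs)
    return [x - mean for x in xs]


def pairwise_edge_messages(n):
    """Edges and edge pairs that a full e -> e operation would touch on n nodes."""
    num_edges = n * (n - 1)
    return num_edges, num_edges ** 2


n = 3
h_edges = {(i, j): float(i + j) for (i, j) in permutations(range(n), 2)}
print(inter_edge_interaction(h_edges, n, f_vv))
print(pairwise_edge_messages(10))  # (90, 8100)
```

With 10 objects, a full pairwise edge-to-edge operation would exchange 8100 messages, whereas the decomposed version only runs sequence models over the 9 incoming edges of each node and over the 10 nodes.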

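Finally, an MLP is used to unify the results of the two operations, and the predicted edge distribution is defined as

$$q_\phi(\mathbf{z}_{ij}\mid\mathbf{x}) = \mathrm{softmax}\Big(f_{\mathrm{MLP}}\big([\mathbf{h}^{\mathrm{intra}}_{(i,j)}, \mathbf{h}^{\mathrm{inter}}_{(i,j)}]\big)\Big). \tag{17}$$

As in Eq. (3), $\mathbf{z}_{ij}$ is sampled via the Gumbel-Softmax trick to allow back-propagation.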

Spatio-Temporal Message Passing Mechanism

The decoder aims to predict the future state $\mathbf{x}^{t+1}$ using the inferred relations $\mathbf{z}$ and the historical states $\mathbf{x}^{1}, \dots, \mathbf{x}^{t}$. Since the interactions may have only a small effect on short-term dynamics, NRI predicts multiple steps into the future to avoid a degenerate decoder that ignores the effect of interactions. In multi-step prediction, the decoder's own prediction $\boldsymbol{\mu}^{t}$ replaces the ground truth state $\mathbf{x}^{t}$ as the input for the next step, i.e.,

$$\boldsymbol{\mu}^{2} = f_{\mathrm{dec}}(\mathbf{x}^{1}), \qquad \boldsymbol{\mu}^{t+1} = f_{\mathrm{dec}}(\boldsymbol{\mu}^{t}), \quad t = 2, \dots, M, \tag{18}$$

where $f_{\mathrm{dec}}$ denotes the decoder and $M$ is the number of time steps to predict, after which the decoder restarts from the ground truth state. This means that errors in the prediction process accumulate over $M$ steps. As the decoder of NRI only uses the current state to predict future states, it is difficult to avoid error accumulation. Intuitively, the current state of an interacting system is related to its previous states; thus, incorporating historical information can help learn the dynamical rules and alleviate the error accumulation in multi-step prediction. To this end, sequence models are used to capture the non-linear correlation between previous states and the current state. GNNs combined with sequence models can capture spatio-temporal information, an approach widely used in traffic flow forecasting Li et al. (2018); Guo et al. (2019). In this paper, the sequence model can be composed of one or more of CNNs, RNNs, attention mechanisms, etc.
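The error accumulation described above can be reproduced with a toy one-dimensional system. In the sketch below, the true dynamics move one unit per step while a hypothetical decoder with a small bias moves 1.5 units; feeding predictions back for `m` steps makes the per-step error of 0.5 compound, whereas `m = 1` (always restarting from ground truth) keeps it constant. The dynamics, the bias and the helper names are illustrative assumptions.

```python
def multi_step_predictions(states, f_dec, m):
    """NRI-style rollout: restart from ground truth every m steps.

    preds[t] is the prediction of states[t + 1].
    """
    preds = []
    for t in range(len(states) - 1):
        inp = states[t] if t % m == 0 else preds[-1]  # ground truth or own output
        preds.append(f_dec(inp))
    return preds


def mse(preds, states):
    """Mean squared error of preds against states[1:]."""
    return sum((p - x) ** 2 for p, x in zip(preds, states[1:])) / len(preds)


def biased_decoder(x):
    """Hypothetical decoder: the true step is +1.0, the model predicts +1.5."""
    return x + 1.5


states = [float(t) for t in range(11)]  # ground truth trajectory
for m in (1, 5, 10):
    print(m, mse(multi_step_predictions(states, biased_decoder, m), states))
# prints: 1 0.25, then 5 2.75, then 10 9.625
```

The same small per-step error yields an MSE that grows roughly quadratically with the rollout length, which is why a decoder that uses more history to reduce the per-step error can pay off in multi-step prediction.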

Inspired by interaction networks Battaglia et al. (2016) and graph networks Sanchez-Gonzalez et al. (2018), sequence models are added to the original node-to-edge and edge-to-node operations to obtain their spatio-temporal versions. In this way, the decoder can integrate spatio-temporal interaction information at two levels of granularity, i.e., the node level and the edge level, which helps it better learn the dynamical rules.

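As shown in Fig. 4, the decoder contains a node-to-edge and an edge-to-node spatio-temporal message passing operation. The node-to-edge spatio-temporal message passing operation is defined as

$$v \to e:\quad \bar{\mathbf{h}}^{t}_{(i,j)} = \sum_{k} z_{ij,k}\, f^{k}_{e}\big([\mathbf{x}^{t}_{i}, \mathbf{x}^{t}_{j}]\big), \qquad \mathbf{h}^{t}_{(i,j)} = g_{e}\big([\bar{\mathbf{h}}^{t}_{(i,j)}, \mathbf{h}^{t-1}_{(i,j)}]\big), \tag{19}$$

where $z_{ij,k}$ is the $k$-th element of the vector $\mathbf{z}_{ij}$, representing the edge type $k$, whose effect is modeled by an MLP $f^{k}_{e}$. For each edge $e_{(i,j)}$, the effects of all potential interactions are aggregated into $\bar{\mathbf{h}}^{t}_{(i,j)}$ as a weighted sum. Its concatenation with the previous hidden state $\mathbf{h}^{t-1}_{(i,j)}$ is fed to the sequence model $g_{e}$ to generate the future hidden state of interactions $\mathbf{h}^{t}_{(i,j)}$, which captures the temporal dependencies at the edge level.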

Figure 4: The structure of the decoder. The decoder takes the interacting relations $\mathbf{z}$, the current state $\mathbf{x}^t$, and the historical hidden states $\mathbf{h}^{t-1}_{(i,j)}$ and $\mathbf{h}^{t-1}_{j}$ as inputs to predict the change in state $\Delta\mathbf{x}^t$. Sequence models are added to the message passing operations to jointly capture the spatio-temporal dependencies. Elements that are currently updated are highlighted in blue, e.g., $\mathbf{h}^{t}_{(i,j)}$ is an edge embedding updated in the node-to-edge ($v \to e$) message passing operation.

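Similarly, the edge-to-node spatio-temporal message passing operation is defined as

$$e \to v:\quad \bar{\mathbf{h}}^{t}_{j} = \sum_{i \in \mathcal{N}_j} \mathbf{h}^{t}_{(i,j)}, \qquad \mathbf{h}^{t}_{j} = g_{v}\big([\bar{\mathbf{h}}^{t}_{j}, \mathbf{x}^{t}_{j}, \mathbf{h}^{t-1}_{j}]\big). \tag{20}$$

For each node $v_j$, the spatial dependencies are aggregated into $\bar{\mathbf{h}}^{t}_{j}$, which, together with the current state $\mathbf{x}^{t}_{j}$ and the previous hidden state $\mathbf{h}^{t-1}_{j}$, is fed to $g_{v}$ to generate the future hidden state of nodes $\mathbf{h}^{t}_{j}$; this captures the temporal dependencies at the node level.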

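Finally, the predicted future state is defined as

$$\Delta\mathbf{x}^{t}_{j} = f_{\mathrm{out}}(\mathbf{h}^{t}_{j}), \qquad \boldsymbol{\mu}^{t+1}_{j} = \mathbf{x}^{t}_{j} + \Delta\mathbf{x}^{t}_{j}, \qquad p_\theta(\mathbf{x}^{t+1}_{j}\mid\mathbf{x}^{t}, \dots, \mathbf{x}^{1}, \mathbf{z}) = \mathcal{N}\big(\boldsymbol{\mu}^{t+1}_{j}, \sigma^{2}\mathbf{I}\big), \tag{21}$$

where $f_{\mathrm{out}}$ is an MLP predicting the change in state $\Delta\mathbf{x}^{t}_{j}$, and $\sigma^{2}$ is a fixed variance. Note that since both $\mathbf{h}^{t-1}_{(i,j)}$ and $\mathbf{h}^{t-1}_{j}$ rely on the historical states $\mathbf{x}^{1}, \dots, \mathbf{x}^{t-1}$, the decoder is implicitly a function of $\mathbf{x}^{1}, \dots, \mathbf{x}^{t}$.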
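One step of the spatio-temporal decoder described above can be sketched with scalar states. Everything concrete here is an assumption for illustration, not the authors' implementation: two edge types (type 0 has no effect, type 1 acts like a spring pulling particle j toward particle i), exponential-moving-average updates standing in for the GRU-based sequence models, and a linear output function.

```python
def st_decoder_step(x, z, h_edge_prev, h_node_prev, f_types, g_e, g_v, f_out):
    """One spatio-temporal decoder step with scalar states (toy sketch).

    x: current states x_j^t; z: {(i, j): [z_ij,k for each edge type k]}.
    Returns the predicted means mu^{t+1} and the updated hidden states.
    """
    h_edge = {}
    for (i, j), probs in z.items():
        # Node-to-edge: effects of all edge types, weighted by z_ij,k ...
        agg = sum(p * f_k(x[i], x[j]) for p, f_k in zip(probs, f_types))
        # ... followed by a temporal update with the edge's previous hidden state.
        h_edge[(i, j)] = g_e(agg, h_edge_prev[(i, j)])
    h_node = []
    for j in range(len(x)):
        # Edge-to-node: sum incoming edge states, then a node-level temporal update.
        msg = sum(h for (_, dst), h in h_edge.items() if dst == j)
        h_node.append(g_v(msg, x[j], h_node_prev[j]))
    # Predicted mean = current state + predicted change in state.
    mu = [x[j] + f_out(h_node[j]) for j in range(len(x))]
    return mu, h_edge, h_node


def no_interaction(xi, xj):
    """Hypothetical effect of edge type 0."""
    return 0.0


def spring(xi, xj):
    """Hypothetical effect of edge type 1: pull x_j toward x_i."""
    return xi - xj


def g_e(inp, h_prev):
    """Exponential moving average standing in for an edge-level GRU."""
    return 0.5 * inp + 0.5 * h_prev


def g_v(msg, xj, h_prev):
    """Stand-in for a node-level GRU; x_j is accepted but unused in this toy."""
    return 0.5 * msg + 0.5 * h_prev


def f_out(h):
    """Stand-in for the output MLP predicting the change in state."""
    return 0.1 * h


x = [0.0, 1.0]
z = {(0, 1): [0.0, 1.0], (1, 0): [0.0, 1.0]}  # both directed edges are springs
mu, h_edge, h_node = st_decoder_step(
    x, z, {e: 0.0 for e in z}, [0.0, 0.0], [no_interaction, spring], g_e, g_v, f_out
)
print(mu)  # approximately [0.025, 0.975]: the particles move toward each other
```

In a rollout, the returned `h_edge` and `h_node` would be passed back in as the previous hidden states, which is how history enters later predictions.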

Structural Prior Knowledge

As the scale and complexity of systems increase, it becomes more difficult to infer relations solely from the state sequences of individual agents. Therefore, it is desirable to incorporate available structural prior knowledge, such as sparsity Kipf et al. (2018), symmetry and node degree distribution Li et al. (2019). Since symmetric relations are widespread in physical dynamical systems, the symmetry of relations is studied as a special case. Li et al. (2019) impose a hard constraint on symmetry, i.e., setting $\mathbf{z}_{ji} = \mathbf{z}_{ij}$. However, the hard constraint may limit the exploration of the model during training, and it may sometimes decrease the prediction precision. By contrast, this paper imposes a soft constraint by adding a regularization term to the original loss function.

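Specifically, an auxiliary distribution $q'_\phi$ is introduced as the "transpose" of the predicted relation distribution $q_\phi$, namely,

$$q'_\phi(\mathbf{z}_{ij}\mid\mathbf{x}) = q_\phi(\mathbf{z}_{ji}\mid\mathbf{x}). \tag{22}$$

Then, the Kullback-Leibler divergence between $q'_\phi$ and $q_\phi$ is used as a regularization term for symmetry, i.e.,

$$\mathcal{L}_{\mathrm{sym}} = \mathcal{L} - \lambda\, \mathrm{KL}\big[q'_\phi(\mathbf{z}\mid\mathbf{x}) \,\|\, q_\phi(\mathbf{z}\mid\mathbf{x})\big], \tag{23}$$

where $\mathcal{L}$ is the evidence lower bound in Eq. (5) and $\lambda$ is a penalty factor for the symmetric prior.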
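A minimal sketch of the soft symmetry constraint, written as a loss to minimise (the negative ELBO plus the penalty). It assumes the KL term is summed over all ordered edges and uses a small epsilon to avoid division by zero; the function names and the toy distributions are illustrative.

```python
import math


def symmetry_penalty(q, eps=1e-12):
    """Sum over edges of KL[q'_ij || q_ij], where q'_ij = q_ji is the 'transpose'."""
    total = 0.0
    for (i, j), q_ij in q.items():
        q_t = q[(j, i)]  # transposed distribution: q'(z_ij | x) = q(z_ji | x)
        total += sum(a * math.log(a / max(b, eps)) for a, b in zip(q_t, q_ij) if a > 0)
    return total


def regularized_loss(neg_elbo, q, lam):
    """Loss to minimise: -ELBO + lambda * symmetry penalty (soft constraint)."""
    return neg_elbo + lam * symmetry_penalty(q)


symmetric = {(0, 1): [0.8, 0.2], (1, 0): [0.8, 0.2]}
asymmetric = {(0, 1): [0.9, 0.1], (1, 0): [0.1, 0.9]}
print(symmetry_penalty(symmetric))    # 0.0
print(symmetry_penalty(asymmetric))   # 1.6 * log(9), about 3.52
```

Unlike tying $\mathbf{z}_{ji}$ to $\mathbf{z}_{ij}$ outright, the penalty only makes asymmetric beliefs more expensive, so $\lambda$ sets how strongly the prior is enforced.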

Notations used in this paper, details of the computation of the symmetry regularization term, and the pseudo code of NRI-MPM are given in Appendix A, Appendix B and Appendix C, respectively.


All methods are tested on three types of simulated physical systems: particles connected by springs, charged particles, and phase-coupled oscillators, named Springs, Charged and Kuramoto, respectively. For the Springs and Kuramoto datasets, objects do or do not interact with equal probability. For the Charged datasets, objects attract or repel each other with equal probability. For each type of system, a 5-object dataset and a 10-object dataset are simulated. Each dataset contains 50k training samples, 10k validation samples and 10k test samples. Further details on data generation can be found in Kipf et al. (2018). All methods are evaluated with two metrics: the accuracy of relation reconstruction and the mean squared error (MSE) of future state prediction.

For our method, the sequence models used in both the encoder and the decoder are composed of gated recurrent units (GRUs) Cho et al. (2014), except that attention mechanisms are added to the decoder for the Kuramoto datasets (see Appendix D). The penalty factor $\lambda$ takes a common value for all datasets except the 5-object and 10-object Kuramoto datasets, for which it is set separately (to 1 for the 5-object dataset; see Appendix E).


To evaluate the performance of our method, we compare it with several competitive methods as follows.

  • Correlation Kipf et al. (2018): a baseline that predicts the relations between two particles based on the correlation of their state sequences.

  • LSTM Kipf et al. (2018): an LSTM that takes the concatenation of all state vectors and predicts all future states simultaneously.

  • NRI Kipf et al. (2018): the neural relational inference model that jointly learns the relations and the dynamics with a VAE.

  • SUGAR Li et al. (2019): a method that introduces structural prior knowledge such as hard symmetric constraint and node degree distribution for relational inference.

  • ModularMeta Alet et al. (2019): a method that solves the relational inference problem via modular meta-learning.

To compare the performance of our method with “gold standard” methods, i.e., those trained given the ground truth relations, this paper introduces two variants as follows.

  • Supervised: a variant of our method that only trains an encoder with the ground truth relations.

  • NRI-MPM (true graph): a variant of our method that only trains a decoder given the ground truth relations.

The code of NRI and ModularMeta is publicly available and is therefore used directly in our experiments. SUGAR is implemented by ourselves according to the original paper. Note that Correlation and Supervised are only designed for relation reconstruction, while LSTM and NRI-MPM (true graph) are only designed for state prediction.

To verify the effectiveness of the proposed message passing mechanisms and the structural prior knowledge, this paper introduces some variants of our method as follows.

  • NRI-MPM w/o RI: a variant of our method without the relation interaction mechanism.

  • NRI-MPM w/o intra-RI, NRI-MPM w/o inter-RI: variants of our method without the intra- and inter-edge interaction operations, respectively.

  • NRI-MPM w/o ST: a variant of our method without the spatio-temporal message passing mechanism.

  • NRI-MPM w/o Sym, NRI-MPM w/ hard Sym: variants of our method without the symmetric prior and with a hard symmetric constraint, respectively.

Comparisons with Baselines

Note: some results in Table 1 are unavailable in the original papers and were obtained by running the code provided by the authors.

Table 1: Accuracy (%) of relation reconstruction on the Springs, Charged and Kuramoto datasets with 5 and 10 objects.
Dataset               Springs                      Charged                      Kuramoto
Prediction steps      1        10       20         1        10       20         1        10       20
LSTM                  4.13e-8  2.19e-5  7.02e-4    1.68e-3  6.45e-3  1.49e-2    3.44e-4  1.29e-2  4.74e-2
NRI                   3.12e-8  3.29e-6  2.13e-5    1.05e-3  3.21e-3  7.06e-3    1.40e-2  2.01e-2  3.26e-2
SUGAR                 3.71e-8  3.86e-6  1.53e-5    1.18e-3  3.43e-3  7.38e-3    2.12e-2  9.45e-2  1.83e-1
ModularMeta           3.13e-8  3.25e-6  -          1.03e-3  3.11e-3  -          2.35e-2  1.10e-1  1.96e-1
NRI-MPM               8.89e-9  5.99e-7  2.52e-6    7.29e-4  2.57e-3  5.41e-3    1.57e-2  2.73e-2  5.36e-2
NRI-MPM (true graph)  1.60e-9  9.06e-9  1.50e-7    8.06e-4  2.51e-3  5.66e-3    1.73e-2  2.49e-2  4.09e-2
Note: some of these results are unavailable in the original papers and were obtained by running the code provided by the authors.

Table 2: Mean squared error in predicting future states for simulations with 5 interacting objects.
Figure 5: Visualization of the ground truth states (b) together with the predicted states for LSTM (a) and NRI-MPM (c) in the 5-object Kuramoto dataset.

The results of relation reconstruction and future state prediction in the 5-object datasets (for MSEs in the 10-object datasets, see Appendix E) are shown in Table 1 and Table 2, respectively. Our method outperforms all baselines on nearly all datasets in terms of both accuracy and MSE. In the Springs datasets, the accuracies of NRI, SUGAR, ModularMeta and our method are all comparable with the supervised baseline, while our method achieves markedly lower MSEs in the 5-object systems (e.g., 8.89e-9 versus 3.12e-8 for NRI in 1-step prediction). This indicates that our method can better learn the relations and dynamics simultaneously. In the Charged datasets, our method outperforms the baselines by 4.9%-9.6% in terms of accuracy. A possible reason is that charged systems are densely connected, since each charged particle interacts with all other particles, and our method handles this situation better. SUGAR achieves higher accuracies than NRI, suggesting the effectiveness of structural prior knowledge. Still, our method performs better with the extra help of the proposed message passing mechanisms.

In the Kuramoto datasets, our method does not dominate: it achieves higher accuracies than NRI, but larger MSEs. ModularMeta achieves higher accuracy than our method in the 10-object system, while its MSEs are much higher than those of all other methods. A possible reason is that the interactions among objects in the Kuramoto dataset are relatively weak Kuramoto (1975), making it more difficult to infer the relations from the observed states, and the situation worsens as the scale of the system increases. ModularMeta infers all relations as a whole and refines its predictions with simulated annealing, which may help it search the structural space more effectively and achieve better relation predictions. However, this does not translate into better state prediction. According to Kuramoto (1975), the Kuramoto model is a synchronization system in which the object states converge to a certain set of values in the long run. Perhaps the interactions are less helpful for state prediction in this case.

Figure 6: Running times of different methods.

Qualitative Analysis of Future State Prediction

One observes that LSTM achieves lower MSEs for short-term prediction in the 5-object Kuramoto dataset (e.g., 3.44e-4 versus 1.57e-2 for our method in 1-step prediction), but its performance declines for long-term prediction. To further understand the predictive behavior of LSTM and our method, we conduct a qualitative analysis by visualizing the predicted states together with the ground truth states over 49 steps; the results are shown in Fig. 5. As Fig. 5(a) shows, LSTM captures the shape of the sinusoidal waveform but fails to make accurate predictions for time steps beyond 40 (e.g., the green and purple curves). By contrast, as shown in Fig. 5(c), the predicted states of our method closely match the ground truth states except for the last few time steps of the third particle, whose curve is colored green. Our method may better capture the interactions among particles that affect the long-term dynamics. Note that this result is consistent with that reported in Appendix A.1 of the original NRI paper.

Figure 7: Predicted relations of different methods.

Running Time of Different Methods

The running times of different methods in a single epoch are reported in Fig. 6. Our method requires somewhat more time than NRI, which is a natural consequence of its more complex model. The running time of SUGAR is comparable with that of our method. Notably, ModularMeta requires much more running time than the others, and the gap widens in larger systems, possibly because meta-learning the proposal function in ModularMeta is computationally costly. These results show that our method achieves better performance at a modest additional computational cost.

Ablation Study

Figure 8: MSEs of multi-step prediction.

The ablation studies are conducted in the 10-object Charged dataset, and the results are shown in Table 3.

Model                   MSE
Prediction steps        1        10       20
NRI-MPM w/o RI          8.07e-4  4.13e-3  1.13e-2
NRI-MPM w/o intra-RI    7.83e-4  3.90e-3  1.26e-2
NRI-MPM w/o inter-RI    7.71e-4  4.11e-3  1.13e-2
NRI-MPM w/o ST          1.26e-3  4.92e-3  1.38e-2
NRI-MPM w/o Sym         7.60e-4  3.77e-3  1.06e-2
NRI-MPM w/ hard Sym     7.96e-4  3.83e-3  1.04e-2
NRI-MPM                 7.52e-4  3.80e-3  1.03e-2
Table 3: Ablation study in the 10-object Charged dataset.

Effect of Relation Interaction Mechanism

As shown in Table 3, the accuracy drops significantly when the relation interaction mechanism is removed, verifying its effectiveness in relation prediction. Moreover, the MSEs increase (e.g., from 1.03e-2 to 1.13e-2 for 20-step prediction), indicating that more accurate relation prediction can help with future state prediction. In addition, the contribution of the inter-edge interaction is higher than that of the intra-edge interaction. Intuitively, the intra- and inter-edge operations capture local and global interactions, respectively, and in this dataset global interactions may be more informative than local ones.

To gain an intuitive understanding of the effect of the relation interaction mechanism, we conduct a case study on the 10-object charged particle systems. The distributions of predicted relations of NRI-MPM, NRI-MPM w/o RI and NRI, together with the ground truth relations, are visualized in Fig. 7. The two types of relations are highlighted in red and green, respectively, while the diagonal elements are all white as there are no self-loops. Two particles with the same charge repel each other, while those with opposite charges attract each other. Consequently, for any particles $v_i$, $v_j$ and $v_k$, the relations are correlated: $\mathbf{z}_{ik}$ is determined by $\mathbf{z}_{ij}$ and $\mathbf{z}_{jk}$, being repulsive if the latter two are of the same type and attractive otherwise. As shown in Fig. 7, our method models the dependencies among all relations much better than NRI, while removing the relation interaction mechanism results in less consistent predictions.
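The triangle rule for charged particles can be checked mechanically. The sketch below derives ground-truth relations from a toy list of charges and counts ordered triples that violate the rule; the charges and the idea of counting violated triples as a consistency measure are illustrative assumptions, not an evaluation metric from the paper.

```python
from itertools import permutations


def charge_relation(qi, qj):
    """Opposite charges attract; equal charges repel."""
    return "attract" if qi * qj < 0 else "repel"


def inconsistent_triples(rel, n):
    """Count ordered triples (i, j, k) violating: z_ik is 'repel' iff z_ij == z_jk."""
    bad = 0
    for i, j, k in permutations(range(n), 3):
        expected = "repel" if rel[(i, j)] == rel[(j, k)] else "attract"
        bad += rel[(i, k)] != expected
    return bad


charges = [1, -1, 1, -1, -1]
n = len(charges)
truth = {(i, j): charge_relation(charges[i], charges[j])
         for i, j in permutations(range(n), 2)}
print(inconsistent_triples(truth, n))   # 0: ground-truth relations are consistent

noisy = dict(truth)
noisy[(0, 1)] = "repel"                 # flip a single predicted relation
print(inconsistent_triples(noisy, n))   # 9
```

A single flipped edge breaks every triple it participates in (as the first, second or closing pair), which illustrates why predicting each relation independently can produce globally inconsistent relation maps.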

Effect of Spatio-temporal Message Passing Mechanism

As shown in Table 3, the MSEs increase when the spatio-temporal message passing mechanism is removed, and the absolute differences grow with the number of prediction steps, verifying that the spatio-temporal message passing mechanism can alleviate error accumulation in multi-step prediction. Furthermore, the MSEs of different methods that predict 40 steps into the future are shown in Fig. 8. Note that the differences among methods narrow as the number of prediction steps becomes large, suggesting that handling error accumulation in the long run remains challenging for all methods.
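The gap between NRI-MPM and NRI-MPM w/o ST can be computed directly from the MSEs in Table 3. The short calculation below uses only those reported values and prints the absolute and relative increase for each prediction horizon; the variable names are illustrative.

```python
# MSEs from Table 3 (10-object Charged dataset) for 1, 10 and 20 prediction steps.
steps = [1, 10, 20]
full_model = [7.52e-4, 3.80e-3, 1.03e-2]   # NRI-MPM
without_st = [1.26e-3, 4.92e-3, 1.38e-2]   # NRI-MPM w/o ST

absolute = [b - a for a, b in zip(full_model, without_st)]
relative = [d / a for d, a in zip(absolute, full_model)]
for s, d, r in zip(steps, absolute, relative):
    print(f"{s:>2} steps: +{d:.2e} MSE ({r:.0%} higher)")
```

The absolute increase grows with the horizon (about 5.1e-4, 1.1e-3 and 3.5e-3), while the relative increase is largest for one-step prediction, so the growing gap refers to absolute differences.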

Interestingly, the accuracy drop of NRI-MPM w/o ST is larger than that of NRI-MPM w/o RI. A possible explanation is that the spatio-temporal message passing mechanism imposes strong dependencies on historical interactions, which indirectly helps relation reconstruction.

Effect of Symmetric Prior

Figure 9: The accuracy and the rate of asymmetry w.r.t. $\lambda$.

As shown in Table 3, without the symmetric prior, the accuracy of NRI-MPM decreases by 2.3%, while the MSEs are on par with those of the full model, indicating that the symmetric prior can help relation reconstruction without hurting the precision of future state prediction. Compared with NRI-MPM w/ hard Sym, our method achieves higher accuracy and lower MSEs. A possible explanation is that the hard symmetry constraint limits the exploration of the model during training, while a soft constraint provides more flexibility.

The effect of the penalty factor $\lambda$ is shown in Fig. 9. The rate of asymmetry decreases significantly as $\lambda$ increases, while the accuracy increases steadily until it peaks, verifying that adjusting $\lambda$ controls the effect of the symmetric prior and that reasonable values benefit relation reconstruction.

Related Work

This paper is part of an emerging direction of research attempting to model the explicit interactions in dynamical systems using neural networks as in NRI Kipf et al. (2018).

Most closely related are the papers of Alet et al. (2019), Li et al. (2019), Webb et al. (2019) and Zhang et al. (2019). Alet et al. (2019) frame this problem as a modular meta-learning problem to jointly infer the relations and use the data more effectively. To deal with more complex systems, Li et al. (2019) incorporate various structural prior knowledge as a complement to the observed states of agents. Webb et al. (2019) extend NRI to multiplex interaction graphs. Zhang et al. (2019) explore relational inference over a wider class of dynamical systems, such as discrete systems and chaotic systems, assuming a shared interaction graph for all state sequences, which differs from the experimental settings of Kipf et al. (2018) and Alet et al. (2019). Compared with this line of work, our method focuses on introducing efficient message passing mechanisms to enrich the representational power of NRI.

Many recent works seek to extend the message passing mechanisms of GNNs. Zhu et al. (2020) define a bilinear aggregator to incorporate the possible interactions among all neighbors of a given node. Brockschmidt (2020) defines node-aware transformations over messages to impose feature-wise modulation. Whereas these variants treat messages as a transient part of the computation of node embeddings, our relation interaction mechanism aims at learning edge embeddings. Herzig et al. (2019) use non-local operations to capture the interactions among all relations, which requires quadratic time complexity in the number of relations.

Besides, many works extend GNNs to handle structured time series. Graph convolutional networks with CNNs Yu et al. (2018), GRUs Li et al. (2018), or attention mechanisms Guo et al. (2019) are introduced to deal with spatial and temporal dependencies separately for traffic flow forecasting. Furthermore, spatio-temporal graphs Song et al. (2020) and spatio-temporal attention mechanisms Zheng et al. (2020) are proposed to capture complex spatio-temporal correlations. Our method borrows these ideas from traffic flow forecasting to define spatio-temporal message passing operations for neural relational inference.


This paper introduces efficient message passing mechanisms with structural prior knowledge for neural relational inference. The relation interaction mechanism effectively captures the coexistence of all relations and helps predict them jointly. By incorporating historical information, the spatio-temporal message passing mechanism alleviates error accumulation in multi-step state prediction. Additionally, structural prior knowledge, with symmetry as a special case, improves the accuracy of relation reconstruction in more complex systems. Extensive experiments on simulated physics systems validate the effectiveness of our method.
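As one concrete reading of the symmetry prior, assume undirected interactions, so the relation type of pair (i, j) should match that of (j, i). Per-pair edge-type logits, such as those an NRI-style encoder produces before sampling, can then be tied across the two directions by averaging them before the softmax. The sketch below is our own illustration and not necessarily how the paper incorporates the prior (a soft penalty on asymmetric predictions would be another option); the function names are hypothetical.

```python
import math
from typing import Dict, List, Tuple

Pair = Tuple[int, int]


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def symmetric_edge_probs(logits: Dict[Pair, List[float]]) -> Dict[Pair, List[float]]:
    """Average the edge-type logits of (i, j) and (j, i) so both directions share one distribution."""
    probs = {}
    for (i, j), l_ij in logits.items():
        l_ji = logits.get((j, i), l_ij)  # fall back to l_ij if the reverse pair is absent
        avg = [(a + b) / 2.0 for a, b in zip(l_ij, l_ji)]
        probs[(i, j)] = softmax(avg)
    return probs


logits = {(0, 1): [2.0, 0.0], (1, 0): [0.0, 2.0],
          (0, 2): [0.0, 3.0], (2, 0): [1.0, 1.0]}
probs = symmetric_edge_probs(logits)
print(probs[(0, 1)])  # [0.5, 0.5]
```

Tying the logits halves the number of free relation predictions, which is one way a symmetry prior can make reconstruction easier when the system has many agents.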

Currently, only simple yet effective implementations based on GRUs and attention mechanisms are adopted for the sequence models. Future work includes introducing more advanced models, such as Transformers, to further improve performance. In addition, the current experiments are conducted on simulated systems over static and homogeneous graphs; extending our method to systems over dynamic and heterogeneous graphs is another direction for future work. Furthermore, we plan to apply the proposed method to study the mechanism of the emergence of intelligence in different complex systems, including multi-agent systems, swarm systems, physical systems and social systems.


This work is supported by the National Key R&D Program of China (2018AAA0101203), and the National Natural Science Foundation of China (62072483, 61673403, U1611262). This work is also supported by MindSpore.


  • F. Alet, E. Weng, T. Lozano-Pérez, and L. P. Kaelbling (2019) Neural relational inference with fast modular meta-learning. In NeurIPS, pp. 11827–11838.
  • A. Almaatouq, A. Noriega-Campero, A. Alotaibi, P. M. Krafft, M. Moussaid, and A. Pentland (2020) Adaptive social networks promote the wisdom of crowds. PNAS 117 (21), pp. 11379–11386.
  • V. Bapst, T. Keck, A. Grabska-Barwińska, C. Donner, E. D. Cubuk, S. S. Schoenholz, A. Obika, A. W. R. Nelson, T. Back, D. Hassabis, and P. Kohli (2020) Unveiling the predictive power of static structure in glassy systems. Nature Physics 16 (4), pp. 448–454.
  • P. Battaglia, R. Pascanu, M. Lai, D. J. Rezende, and K. Kavukcuoglu (2016) Interaction networks for learning about objects, relations and physics. In NeurIPS, pp. 4509–4517.
  • M. Brockschmidt (2020) GNN-FiLM: graph neural networks with feature-wise linear modulation. In ICML.
  • K. Cho, B. Van Merrienboer, C. Gulcehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio (2014) Learning phrase representations using RNN encoder–decoder for statistical machine translation. In EMNLP, pp. 1724–1734.
  • J. Gilmer, S. S. Schoenholz, P. F. Riley, O. Vinyals, and G. E. Dahl (2017) Neural message passing for quantum chemistry. In ICML, pp. 1263–1272.
  • S. Guo, Y. Lin, N. Feng, C. Song, and H. Wan (2019) Attention based spatial-temporal graph convolutional networks for traffic flow forecasting. In AAAI, pp. 922–929.
  • S. Ha and H. Jeong (2020) Towards automated statistical physics: data-driven modeling of complex systems with deep learning. arXiv preprint arXiv:2001.02539.
  • W. Hamilton, Z. Ying, and J. Leskovec (2017) Inductive representation learning on large graphs. In NeurIPS, pp. 1024–1034.
  • R. Herzig, E. Levi, H. Xu, H. Gao, E. Brosh, X. Wang, A. Globerson, and T. Darrell (2019) Spatio-temporal action graph networks. In ICCVW.
  • Y. Hoshen (2017) VAIN: attentional multi-agent predictive modeling. In NeurIPS, pp. 2701–2711.
  • D. P. Kingma and M. Welling (2014) Auto-encoding variational Bayes. In ICLR.
  • T. Kipf, E. Fetaya, K. Wang, M. Welling, and R. Zemel (2018) Neural relational inference for interacting systems. In ICML, pp. 2688–2697.
  • Y. Kuramoto (1975) Self-entrainment of a population of coupled non-linear oscillators. In International Symposium on Mathematical Problems in Theoretical Physics, pp. 420–422.
  • J. Lee, Y. Lee, J. Kim, A. Kosiorek, S. Choi, and Y. W. Teh (2019) Set transformer: a framework for attention-based permutation-invariant neural networks. In ICML, pp. 3744–3753.
  • J. Li, F. Yang, M. Tomizuka, and C. Choi (2020) EvolveGraph: multi-agent trajectory prediction with dynamic relational reasoning. In NeurIPS.
  • Y. Li, C. Meng, C. Shahabi, and Y. Liu (2019) Structure-informed graph auto-encoder for relational inference and simulation. In ICML Workshop on Learning and Reasoning with Graph-Structured Representations.
  • Y. Li, R. Yu, C. Shahabi, and Y. Liu (2018) Diffusion convolutional recurrent neural network: data-driven traffic forecasting. In ICLR.
  • C. J. Maddison, A. Mnih, and Y. W. Teh (2017) The concrete distribution: a continuous relaxation of discrete random variables. In ICLR.
  • M. Oliveira, D. Pinheiro, M. Macedo, C. Bastos-Filho, and R. Menezes (2020) Uncovering the social interaction network in swarm intelligence algorithms. Applied Network Science 5 (1), pp. 1–20.
  • A. Sanchez-Gonzalez, N. Heess, J. T. Springenberg, J. Merel, M. Riedmiller, R. Hadsell, and P. Battaglia (2018) Graph networks as learnable physics engines for inference and control. In ICML, pp. 4467–4476.
  • C. Song, Y. Lin, S. Guo, and H. Wan (2020) Spatial-temporal synchronous graph convolutional networks: a new framework for spatial-temporal network data forecasting. In AAAI, pp. 914–921.
  • S. van Steenkiste, M. Chang, K. Greff, and J. Schmidhuber (2018) Relational neural expectation maximization: unsupervised discovery of objects and their interactions. In ICLR.
  • N. Watters, D. Zoran, T. Weber, P. Battaglia, R. Pascanu, and A. Tacchetti (2017) Visual interaction networks: learning a physics simulator from video. In NeurIPS, pp. 4539–4547.
  • E. Webb, B. Day, H. Andres-Terre, and P. Lió (2019) Factorised neural relational inference for multi-interaction systems. In ICML Workshop on Learning and Reasoning with Graph-Structured Representations.
  • K. Xu, W. Hu, J. Leskovec, and S. Jegelka (2019) How powerful are graph neural networks? In ICLR.
  • Y. Yang, L. Yu, Y. Bai, Y. Wen, W. Zhang, and J. Wang (2018) A study of AI population dynamics with million-agent reinforcement learning. In AAMAS, pp. 2133–2135.
  • B. Yu, H. Yin, and Z. Zhu (2018) Spatio-temporal graph convolutional networks: a deep learning framework for traffic forecasting. In IJCAI, pp. 3634–3640.
  • J. Zhang, W. Wang, F. Xia, Y. Lin, and H. Tong (2020) Data-driven computational social science: a survey. Big Data Research, pp. 100145.
  • Z. Zhang, Y. Zhao, J. Liu, S. Wang, R. Tao, R. Xin, and J. Zhang (2019) A general deep learning framework for network reconstruction and dynamics learning. Applied Network Science 4 (1), pp. 1–17.
  • C. Zheng, X. Fan, C. Wang, and J. Qi (2020) GMAN: a graph multi-attention network for traffic prediction. In AAAI, pp. 1234–1241.
  • H. Zhu, F. Feng, X. He, X. Wang, Y. Li, K. Zheng, and Y. Zhang (2020) Bilinear graph neural network with neighbor interactions. In IJCAI, pp. 1452–1458.