NRI-MPM
Code for Neural Relational Inference with Efficient Message Passing Mechanisms (AAAI 2021).
Many complex processes can be viewed as dynamical systems of interacting agents. In many cases, only the state sequences of individual agents are observed, while the interacting relations and the dynamical rules are unknown. The neural relational inference (NRI) model adopts graph neural networks that pass messages over a latent graph to jointly learn the relations and the dynamics based on the observed data. However, NRI infers the relations independently and suffers from error accumulation in multi-step prediction during the dynamics learning procedure. Besides, relation reconstruction without prior knowledge becomes more difficult in more complex systems. This paper introduces efficient message passing mechanisms to the graph neural networks with structural prior knowledge to address these problems. A relation interaction mechanism is proposed to capture the coexistence of all relations, and a spatio-temporal message passing mechanism is proposed to use historical information to alleviate error accumulation. Additionally, the structural prior knowledge, symmetry as a special case, is introduced for better relation prediction in more complex systems. The experimental results on simulated physics systems show that the proposed method outperforms existing state-of-the-art methods.
Many complex processes in natural and social areas including multi-agent systems Yang et al. (2018); Li et al. (2020), swarm systems Oliveira et al. (2020), physical systems Ha and Jeong (2020); Bapst et al. (2020) and social systems Almaatouq et al. (2020); Zhang et al. (2020) can be viewed as dynamical systems of interacting agents. Revealing the underlying interactions and dynamics can help us understand, predict, and control the behavior of these systems. However, in many cases, only the state sequences of individual agents are observed, while the interacting relations and the dynamical rules are unknown.
Many works Hoshen (2017); van Steenkiste et al. (2018); Watters et al. (2017) use implicit interaction models to learn the dynamics. These models can be regarded as graph neural networks (GNNs) over a fully-connected graph, where the implicit interactions are modeled via message passing operations Watters et al. (2017) or attention mechanisms Hoshen (2017). Compared with modeling implicit interactions, modeling explicit interactions offers a more interpretable way to understand dynamical systems. A motivating example is shown in Fig. 1 Kipf et al. (2018), where a dynamical system consists of 5 particles linked by invisible springs. It is of interest to infer the relations among the particles and predict their future states. Kipf et al. (2018) propose the neural relational inference (NRI) model, a variational auto-encoder (VAE) Kingma and Welling (2014), to jointly learn explicit interactions and dynamical rules in an unsupervised manner.
Currently, there are three main limitations of NRI. First, the interacting relations are inferred independently, while the coexistence of these relations is not considered. Alet et al. (2019) tackle this problem by taking all relations as a whole and iteratively improving the prediction through modular meta-learning. However, this method is computationally costly and it is limited to small-scale systems. Second, NRI predicts multiple steps into the future to emphasize the effect of the interactions, which leads to error accumulation and prevents the model from reconstructing the dynamics accurately. Third, as the scale and the complexity of systems increase, it is difficult to infer the interactions solely based on the observed data, while incorporating some structural prior knowledge can help better explore the structural space and promote the precision of relation recovery.
To address the problems above, this paper introduces efficient message passing mechanisms with structural prior knowledge for neural relational inference. For relation reconstruction, a relation interaction mechanism is introduced to capture the dependencies among different relations. For dynamics reconstruction, a spatio-temporal message passing mechanism is introduced to utilize historical information to alleviate the error accumulation of multi-step prediction. Additionally, the prior knowledge about the relations is incorporated as a regularization term in the loss function to impose soft constraints, and as a special case, the symmetry of relations is taken into consideration.
The contributions of this work are summarized as follows.
Efficient message passing mechanisms are introduced to make a joint prediction of relations and alleviate the error accumulation of multi-step state prediction.
The prior knowledge about relations, symmetry as a special case, is introduced to better reconstruct relations on more complex dynamical systems.
Extensive experiments on physics simulation datasets are conducted. The results show the superiority of our method by comparing it with the state-of-the-art methods.
NRI Kipf et al. (2018) is an unsupervised model that infers interacting structure from observational data and learns system dynamics. NRI takes the form of VAE, where the encoder infers the interacting relations, and the decoder predicts the future states of individual agents.
Specifically, the encoder adopts GNNs with multiple rounds of message passing, and infers the distribution of potential interactions based on the input state,
$$q_\phi(\mathbf{z}\,|\,\mathbf{x}) = \prod_{i \ne j} q_\phi(z_{ij}\,|\,\mathbf{x}), \tag{1}$$

$$q_\phi(z_{ij}\,|\,\mathbf{x}) = \mathrm{softmax}\big(f_{\mathrm{enc},\phi}(\mathbf{x})_{ij}\big), \tag{2}$$
where $\mathbf{x} = (\mathbf{x}^1, \ldots, \mathbf{x}^T)$ denotes the observed trajectories of the objects in the system over $T$ time steps, $f_{\mathrm{enc},\phi}$ is a GNN acting on the fully-connected graph, and $q_\phi(z_{ij}\,|\,\mathbf{x})$ is the factorized distribution of the edge type $z_{ij}$.
Since directly sampling edges from , a discrete distribution, is a non-differentiable process, the back-propagation cannot be used. To solve this problem, NRI uses the Gumbel-Softmax trick Maddison et al. (2017), which simulates a differentiable sampling process from a discrete distribution using a continuous function, i.e.,
$$z_{ij} = \mathrm{softmax}\big((h_{ij} + g)/\tau\big), \tag{3}$$
where $h_{ij} = f_{\mathrm{enc},\phi}(\mathbf{x})_{ij}$ are the logits of the edge type $z_{ij}$ between nodes $v_i$ and $v_j$, $g$ is a vector of i.i.d. samples from a $\mathrm{Gumbel}(0, 1)$ distribution, and the temperature $\tau$ is a parameter controlling the "smoothness" of the samples.

According to the inferred relations and past state sequences, the decoder uses another GNN that models the effect of interactions to predict the future state,
$$p_\theta(\mathbf{x}\,|\,\mathbf{z}) = \prod_{t=1}^{T-1} p_\theta\big(\mathbf{x}^{t+1}\,|\,\mathbf{x}^{t}, \ldots, \mathbf{x}^{1}, \mathbf{z}\big), \tag{4}$$
where $p_\theta(\mathbf{x}^{t+1}\,|\,\mathbf{x}^{t}, \ldots, \mathbf{x}^{1}, \mathbf{z})$ is the conditional likelihood of $\mathbf{x}^{t+1}$.
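The Gumbel-Softmax sampling of Eq. (3) is easy to sketch; the following NumPy version is an illustration only (the helper names are ours, not from any released code), showing how Gumbel noise plus a tempered softmax yields a differentiable, near-one-hot edge-type sample:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def gumbel_softmax_sample(logits, tau=0.5, rng=None):
    """Continuous relaxation of a categorical sample, as in Eq. (3)."""
    rng = rng or np.random.default_rng(0)
    u = rng.uniform(size=logits.shape)
    g = -np.log(-np.log(u))              # i.i.d. Gumbel(0, 1) noise
    return softmax((logits + g) / tau)   # differentiable w.r.t. logits

# 20 candidate edges, 2 edge types: each row is a soft edge-type assignment
z = gumbel_softmax_sample(np.random.randn(20, 2), tau=0.5)
```

Smaller values of $\tau$ push each row closer to a one-hot vector, at the cost of higher-variance gradients.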
As a variational auto-encoder model, NRI is trained to maximize the evidence lower bound,
$$\mathcal{L} = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\big[\log p_\theta(\mathbf{x}\,|\,\mathbf{z})\big] - \mathrm{KL}\big[q_\phi(\mathbf{z}\,|\,\mathbf{x})\,\|\,p_\theta(\mathbf{z})\big], \tag{5}$$
where the prior $p_\theta(\mathbf{z}) = \prod_{i \ne j} p_\theta(z_{ij})$ is a factorised uniform distribution over edge types. In the right-hand side of Eq. (5), the first term is the expected reconstruction error, while the second term encourages the approximate posterior $q_\phi(\mathbf{z}\,|\,\mathbf{x})$ to approach the prior $p_\theta(\mathbf{z})$.

GNNs are a widely used class of neural networks that operate on graph-structured data by message passing mechanisms Gilmer et al. (2017). For a graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$ with vertices $\mathcal{V}$ and edges $\mathcal{E}$, the node-to-edge ($v \to e$) and edge-to-node ($e \to v$) message passing operations are defined as follows,
$$v \to e:\quad h^l_{(i,j)} = f^l_e\big(\big[h^l_i, h^l_j, x_{(i,j)}\big]\big), \tag{6}$$

$$e \to v:\quad h^{l+1}_j = f^l_v\Big(\Big[\sum_{i \in \mathcal{N}_j} h^l_{(i,j)}, x_j\Big]\Big), \tag{7}$$
where $h^l_i$ is the embedding of node $v_i$ in the $l$-th layer of the GNNs, $x_{(i,j)}$ is the feature of edge $(i,j)$ (e.g., the edge type), and $h^l_{(i,j)}$ is the embedding of edge $(i,j)$ in the $l$-th layer of the GNNs. $\mathcal{N}_j$ denotes the set of indices of neighboring nodes with an incoming edge connected to vertex $v_j$, and $[\cdot,\cdot]$ denotes the concatenation of vectors. $f^l_v$ and $f^l_e$ are node-specific and edge-specific neural networks, such as multilayer perceptrons (MLPs), respectively. Node embeddings are mapped to edge embeddings in Eq. (6) and vice versa in Eq. (7).

The structure of our method is shown in Fig. 2. Our method follows the VAE framework of NRI. In the encoder, a relation interaction mechanism is used to capture the dependencies among the latent edges for a joint relation prediction. In the decoder, a spatio-temporal message passing mechanism is used to incorporate historical information for alleviating the error accumulation in multi-step prediction. As both mechanisms mentioned above can be regarded as Message Passing Mechanisms in GNNs for Neural Relational Inference, our method is named NRI-MPM. Additionally, the structural prior knowledge, symmetry as a special case, is incorporated as a regularization term in the loss function to improve relation prediction in more complex systems.
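For reference, one round of the generic node-to-edge and edge-to-node operations of Eqs. (6)-(7) can be sketched in NumPy as below; the random tanh layers are merely stand-ins for the learned networks $f_e$ and $f_v$, and the fully-connected graph matches the NRI setting:

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 5, 8                              # number of nodes, embedding size
h = rng.normal(size=(N, D))              # node embeddings h_i
W_e = rng.normal(size=(2 * D, D)) / D    # stand-in for the edge network f_e
W_v = rng.normal(size=(2 * D, D)) / D    # stand-in for the node network f_v

# node-to-edge (Eq. 6): an edge embedding from its two endpoint embeddings
h_edge = {(i, j): np.tanh(np.concatenate([h[i], h[j]]) @ W_e)
          for i in range(N) for j in range(N) if i != j}

# edge-to-node (Eq. 7): sum incoming-edge embeddings, combine with node state
h_next = np.stack([
    np.tanh(np.concatenate(
        [sum(h_edge[(i, j)] for i in range(N) if i != j), h[j]]) @ W_v)
    for j in range(N)])
```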
The encoder is aimed at inferring the edge types $\mathbf{z}$ based on the observed state sequences $\mathbf{x}$. From the perspective of message passing, the encoder defines a node-to-edge message passing operation at a high level. As shown in Eqs. (8)-(11), NRI first maps the observed states $\mathbf{x}_i$ to latent vectors $h^1_i$, and then applies two rounds of node-to-edge and one round of edge-to-node message passing alternately to obtain the edge embeddings $h^2_{(i,j)}$ that integrate both local and global information. Then, $h^2_{(i,j)}$ are used to predict the pairwise interacting types independently.
$$h^1_i = f_{\mathrm{emb}}(\mathbf{x}_i), \tag{8}$$

$$h^1_{(i,j)} = f^1_e\big(\big[h^1_i, h^1_j\big]\big), \tag{9}$$

$$h^2_j = f^1_v\Big(\sum_{i \in \mathcal{N}_j} h^1_{(i,j)}\Big), \tag{10}$$

$$h^2_{(i,j)} = f^2_e\big(\big[h^2_i, h^2_j\big]\big). \tag{11}$$
However, the relations are generally dependent on each other since they jointly affect the future states of individual agents. Within the original formulation of GNNs, Eqs. (8)-(11) cannot effectively model the dependencies among edges. Typical GNNs are designed to learn node embeddings, while the edge embeddings are treated as a transient part of the computation. To capture the coexistence of all relations, this paper introduces an edge-to-edge ($e \to e$) message passing operation that directly passes messages among edges, named the relation interaction mechanism,
$$e \to e:\quad \big\{\tilde{h}^2_{(i,j)}\big\} = f_{e \to e}\big(\big\{h^2_{(i,j)}\big\}\big). \tag{12}$$
Ideally, this operation includes modeling the pairwise dependencies among all edges, which is computationally costly as its time complexity is $O(|\mathcal{E}|^2)$ in the number of edges. Alternatively, as shown in Fig. 3, our method decomposes this operation into two sub-operations, the intra-edge and inter-edge interaction operations, for modeling the interactions among incoming edges of the same node and those among the incoming edges of different nodes, respectively.
The intra-edge interaction operation is defined as follows,
$$\big\{\tilde{h}^2_{(i,j)}\big\}_{i \in \mathcal{N}_j} = f_{\mathrm{intra}}\big(\big\{h^2_{(i,j)}\big\}_{i \in \mathcal{N}_j}\big). \tag{13}$$
From the above definition, $f_{\mathrm{intra}}$ is required to be permutation equivariant Lee et al. (2019) to preserve the correspondences between the input and output edge embeddings. Formally, let $\mathbf{e} = (e_1, \ldots, e_n)$ be the ascending sequence of indices in $\mathcal{N}_j$ and $S_n$ be the set of all permutations on $\{1, \ldots, n\}$. For any permutation $\pi \in S_n$, $\pi(\mathbf{e}) = (e_{\pi(1)}, \ldots, e_{\pi(n)})$. $f_{\mathrm{intra}}$ is said to be permutation equivariant if $f_{\mathrm{intra}} \circ \pi = \pi \circ f_{\mathrm{intra}}$, where $\circ$ denotes the composition of functions. Although permutation invariant operators such as sum and max aggregators are widely used in GNNs Xu et al. (2019), the design of permutation equivariant operators is less explored. Inspired by Hamilton et al. (2017), this paper treats a set as an unordered sequence and defines $f_{\mathrm{intra}}$ as a sequence model over it, appended by an inverse permutation to restore the order, i.e.,
$$f_{\mathrm{intra}} = \pi^{-1} \circ f_{\mathrm{seq}} \circ \pi, \tag{14}$$
where $f_{\mathrm{seq}}$ is a sequence model such as a recurrent neural network (RNN) or a convolutional neural network (CNN), and $\pi$ is an arbitrary permutation. When $f_{\mathrm{seq}}$ is implemented by self-attention, $f_{\mathrm{intra}}$ is permutation equivariant, which may not hold when $f_{\mathrm{seq}}$ is an RNN. In this case, $f_{\mathrm{intra}}$ is expected to be approximately permutation equivariant with a well-trained $f_{\mathrm{seq}}$.

Similarly, one can define an inter-edge interaction operation. The difference is that this operation treats all incoming edges of a node as a whole. For simplicity, this paper applies mean-pooling to the edge embeddings to get an overall representation. Formally, the inter-edge interaction operation is defined to be the composition of the following steps,
$$\bar{h}_j = \mathrm{Mean}\big(\big\{h^2_{(i,j)}\big\}_{i \in \mathcal{N}_j}\big), \tag{15}$$

$$\big\{\tilde{h}_j\big\}_j = f_{v \to v}\big(\big\{\bar{h}_j\big\}_j\big), \tag{16}$$
where $f_{v \to v}$ is a node-to-node ($v \to v$) operation that passes messages among nodes and takes a form similar to Eq. (14). To analyze the time complexity of relation interaction, this paper assumes that all sequence models involved are RNNs. Since $f_{\mathrm{intra}}$ and $\mathrm{Mean}$ can be applied to each node in parallel and the time complexities of Eqs. (14)-(16) are all linear in the number of nodes, the overall time complexity of relation interaction is much lower than that of the pairwise interaction, making it far more efficient.
Finally, an MLP is used to unify the results of the two operations, and the predicted edge distribution is defined as
$$\hat{h}_{(i,j)} = f_{\mathrm{mlp}}\big(\big[\tilde{h}^2_{(i,j)}, \tilde{h}_j\big]\big), \tag{17}$$

$$q_\phi(z_{ij}\,|\,\mathbf{x}) = \mathrm{softmax}\big(\hat{h}_{(i,j)}\big). \tag{18}$$
As in Eq. (3), $z_{ij}$ is sampled from $q_\phi(z_{ij}\,|\,\mathbf{x})$ via the Gumbel-Softmax trick to allow back-propagation.
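The permutation trick of Eq. (14) can be sketched as follows. The cumulative-mean "memory" below is a toy stand-in for the RNN sequence model (so, as noted above, the operation is only approximately permutation equivariant); all names are illustrative:

```python
import numpy as np

def seq_model(xs):
    """Toy recurrent model: output t depends on inputs 1..t, like an RNN."""
    out, acc = [], np.zeros_like(xs[0])
    for t, x in enumerate(xs, start=1):
        acc = acc + (x - acc) / t        # running mean as a simple "memory"
        out.append(acc.copy())
    return out

def intra_edge_interaction(edge_embs, perm):
    """Eq. (14): f_intra = inverse(perm) o seq_model o perm."""
    processed = seq_model([edge_embs[p] for p in perm])
    restored = [None] * len(edge_embs)
    for out_pos, in_pos in enumerate(perm):
        restored[in_pos] = processed[out_pos]   # undo the permutation
    return restored

embs = [np.ones(4) * k for k in range(3)]   # three incoming-edge embeddings
out = intra_edge_interaction(embs, perm=[2, 0, 1])
```

Because the inverse permutation is applied at the end, each output embedding stays aligned with the edge it came from, regardless of the order in which the sequence model consumed them.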
The decoder is aimed at predicting the future state $\mathbf{x}^{t+1}$ using the inferred relations $\mathbf{z}$ and the historical states $\mathbf{x}^{1}, \ldots, \mathbf{x}^{t}$. Since the interactions can have a small effect on short-term dynamics, NRI predicts multiple steps into the future to avoid a degenerate decoder that ignores the effect of interactions. In multi-step prediction, the ground-truth state $\mathbf{x}^{t}$ is replaced by the previously predicted value $\boldsymbol{\mu}^{t}$, i.e.,
$$\boldsymbol{\mu}^{2} = f_{\mathrm{dec}}\big(\mathbf{x}^{1}\big), \tag{19}$$

$$\boldsymbol{\mu}^{t+1} = f_{\mathrm{dec}}\big(\boldsymbol{\mu}^{t}\big), \quad t = 2, \ldots, M, \tag{20}$$
where $f_{\mathrm{dec}}$ denotes the decoder, and $M$ is the number of time steps to predict. This means that errors in the prediction process accumulate over $M$ steps. As the decoder of NRI only uses the current state to predict future states, it is difficult to avoid error accumulation. Apparently, the current state of an interacting system is related to the previous states, and thus incorporating historical information can help learn the dynamical rules and alleviate the error accumulation in multi-step prediction. To this end, sequence models are used to capture the non-linear correlation between previous states and the current state. GNNs combined with sequence models can be used to capture spatio-temporal information, which has been widely used in traffic flow forecasting Li et al. (2018); Guo et al. (2019). In this paper, the sequence model can be composed of one or more of CNNs, RNNs, attention mechanisms, etc.
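The feedback loop of Eqs. (19)-(20) is worth making explicit: after the first step, the decoder only ever sees its own output. A schematic rollout, with a simple linear map standing in for the real GNN decoder, looks like this:

```python
import numpy as np

def f_dec(x):
    """Placeholder one-step dynamics; the real decoder is a GNN."""
    return 0.9 * x + 0.1

def rollout(x_1, M):
    """Eqs. (19)-(20): only the first step sees a ground-truth state."""
    mu = f_dec(x_1)                  # Eq. (19)
    preds = [mu]
    for _ in range(M - 1):
        mu = f_dec(mu)               # Eq. (20): prediction fed back in,
        preds.append(mu)             # so errors can compound over M steps
    return np.stack(preds)

preds = rollout(np.zeros(3), M=10)   # shape (M, state_dim)
```

Any bias in `f_dec` is applied repeatedly along the rollout, which is exactly the error-accumulation problem the spatio-temporal mechanism below is designed to mitigate.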
Inspired by interaction networks Battaglia et al. (2016) and graph networks Sanchez-Gonzalez et al. (2018), sequence models are added to the original node-to-edge and edge-to-node operations to obtain their spatio-temporal versions. In this way, the decoder can integrate the spatio-temporal interaction information at different fine-grained levels, i.e., the node level and the edge level, which can help better learn the dynamical rules.
As shown in Fig. 4, the decoder contains a node-to-edge and an edge-to-node spatio-temporal message passing operation. The node-to-edge spatio-temporal message passing operation is defined as
$$\tilde{h}^t_{(i,j)} = \sum_{k} z_{ij,k}\, f^k_e\big(\big[h^t_i, h^t_j\big]\big), \tag{21}$$

$$h^{t+1}_{(i,j)} = f^{\mathrm{seq}}_e\big(\tilde{h}^t_{(i,j)}, h^{\le t}_{(i,j)}\big), \tag{22}$$
where $z_{ij,k}$ is the $k$-th element of the vector $z_{ij}$, representing the edge type $k$, whose effect is modeled by an MLP $f^k_e$. For each edge $(i,j)$, the effects of all potential interactions are aggregated into $\tilde{h}^t_{(i,j)}$ as a weighted sum. Its concatenation with the previous hidden states $h^{\le t}_{(i,j)}$ is fed to $f^{\mathrm{seq}}_e$ to generate the future hidden state of interactions $h^{t+1}_{(i,j)}$ that captures the temporal dependencies at the edge level.
Similarly, the edge-to-node spatio-temporal message passing operation is defined as
$$\tilde{h}^t_j = \sum_{i \in \mathcal{N}_j} h^{t+1}_{(i,j)}, \tag{23}$$

$$h^{t+1}_j = f^{\mathrm{seq}}_v\big(\big[\tilde{h}^t_j, \mathbf{x}^t_j\big], h^{\le t}_j\big). \tag{24}$$
For each node $v_j$, the spatial dependencies are aggregated into $\tilde{h}^t_j$. $\tilde{h}^t_j$, together with the current state $\mathbf{x}^t_j$ and the previous hidden states $h^{\le t}_j$, is fed to $f^{\mathrm{seq}}_v$ to generate the future hidden state of nodes $h^{t+1}_j$ that captures the temporal dependencies at the node level.
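To make Eq. (21) and the recurrent update of Eq. (22) concrete, the sketch below computes an edge message as an edge-type-weighted sum of per-type transformations, then folds it into a running hidden state. All weights are random stand-ins, and the leaky average is a deliberately simplified substitute for the GRU used in the paper:

```python
import numpy as np

rng = np.random.default_rng(1)
K, D = 2, 4                              # edge types, embedding size
W = rng.normal(size=(K, 2 * D, D)) / D   # stand-ins for the per-type MLPs f_e^k

def edge_message(h_i, h_j, z_ij):
    """Eq. (21): effects of all edge types, weighted by the soft sample z_ij."""
    pair = np.concatenate([h_i, h_j])
    return sum(z_ij[k] * np.tanh(pair @ W[k]) for k in range(K))

def edge_hidden_update(msg, hidden, alpha=0.5):
    """Eq. (22), with a leaky average in place of the GRU over past states."""
    return (1 - alpha) * hidden + alpha * np.tanh(msg)

h_i, h_j = rng.normal(size=D), rng.normal(size=D)
z_ij = np.array([0.9, 0.1])              # soft edge-type assignment
hidden = np.zeros(D)
hidden = edge_hidden_update(edge_message(h_i, h_j, z_ij), hidden)
```

The key point carried over from the text is that the hidden state, not just the instantaneous message, enters the next step, so the edge representation retains a memory of past interactions.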
Finally, the predicted future state is defined as
$$\Delta\hat{\mathbf{x}}^{t}_j = f_{\mathrm{out}}\big(h^{t+1}_j\big), \tag{25}$$

$$\boldsymbol{\mu}^{t+1}_j = \mathbf{x}^{t}_j + \Delta\hat{\mathbf{x}}^{t}_j, \tag{26}$$

$$p\big(\mathbf{x}^{t+1}_j \,|\, \mathbf{x}^{1}, \ldots, \mathbf{x}^{t}, \mathbf{z}\big) = \mathcal{N}\big(\boldsymbol{\mu}^{t+1}_j, \sigma^2 \mathbf{I}\big), \tag{27}$$
where $f_{\mathrm{out}}$ is an MLP predicting the change in state $\Delta\hat{\mathbf{x}}^{t}_j$, and $\sigma^2$ is a fixed variance. Note that since both $h^{t+1}_{(i,j)}$ and $h^{t+1}_j$ rely on the historical states, the decoder is implicitly a function of $\mathbf{x}^{1}, \ldots, \mathbf{x}^{t}$.

As the scale and complexity of systems increase, it becomes more difficult to infer relations solely based on the state sequences of individual agents. Therefore, it is desirable to incorporate possible structural prior knowledge, such as sparsity Kipf et al. (2018), symmetry and node degree distribution Li et al. (2019). Since symmetric relations widely exist in physical dynamical systems, the symmetry of relations is studied as a special case. Li et al. (2019) impose a hard constraint on symmetry, i.e., setting $z_{ij} = z_{ji}$. However, the hard constraint may limit the exploration of the model during the training procedure, and sometimes it may lead to a decrease in prediction precision. By contrast, this paper imposes a soft constraint by adding a regularization term to the original loss function.
Specifically, an auxiliary distribution $\tilde{q}_\phi(\mathbf{z}\,|\,\mathbf{x})$ is introduced as the "transpose" of the predicted relation distribution $q_\phi(\mathbf{z}\,|\,\mathbf{x})$, namely,

$$\tilde{q}_\phi(z_{ij}\,|\,\mathbf{x}) = q_\phi(z_{ji}\,|\,\mathbf{x}). \tag{28}$$
Then, the Kullback-Leibler divergence between $\tilde{q}_\phi(\mathbf{z}\,|\,\mathbf{x})$ and $q_\phi(\mathbf{z}\,|\,\mathbf{x})$ is used as a regularization term for symmetry, i.e.,

$$\mathcal{L}_{\mathrm{sym}} = \lambda\,\mathrm{KL}\big[q_\phi(\mathbf{z}\,|\,\mathbf{x}) \,\|\, \tilde{q}_\phi(\mathbf{z}\,|\,\mathbf{x})\big], \tag{29}$$
where $\lambda$ is a penalty factor for the symmetric prior.
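Assuming the edge distributions are stored in an array `q_edges` with `q_edges[i, j]` holding $q_\phi(z_{ij}\,|\,\mathbf{x})$, the regularizer of Eqs. (28)-(29) can be sketched as follows (illustrative names, not the authors' code):

```python
import numpy as np

def kl(p, q, eps=1e-12):
    """KL divergence between two categorical distributions."""
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def symmetry_penalty(q_edges, lam=0.5):
    """Eq. (29): lam * KL between q and its edge-transpose (Eq. 28)."""
    q_t = np.swapaxes(q_edges, 0, 1)     # Eq. (28): q~(z_ij) = q(z_ji)
    n = q_edges.shape[0]
    off_diag = [(i, j) for i in range(n) for j in range(n) if i != j]
    return lam * sum(kl(q_edges[i, j], q_t[i, j]) for i, j in off_diag)

# a perfectly symmetric prediction incurs zero penalty
q = np.full((3, 3, 2), 0.5)
```

Because the penalty is a differentiable function of the predicted distribution rather than an equality constraint, asymmetric predictions are discouraged but never forbidden, which is the soft/hard distinction drawn above.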
Notations used in this paper, details of the computation, and the pseudo code of NRI-MPM are given in Appendix A, Appendix B and Appendix C, respectively.
All methods are tested on three types of simulated physical systems: particles connected by springs, charged particles and phase-coupled oscillators, named Springs, Charged and Kuramoto, respectively. For the Springs and Kuramoto datasets, objects do or do not interact with equal probability. For the Charged datasets, objects attract or repel with equal probability. For each type of system, a 5-object dataset and a 10-object dataset are simulated. All datasets contain 50k training samples, and 10k validation and test samples. Further details on data generation can be found in
Kipf et al. (2018). All methods are evaluated w.r.t. two metrics: the accuracy of relation reconstruction and the mean squared error (MSE) of future state prediction.

For our method, the sequence models used in both the encoders and the decoders are composed of gated recurrent units (GRUs) Cho et al. (2014), except that attention mechanisms are added to the decoder in the Kuramoto datasets (see Appendix D). The penalty factor $\lambda$ is set to a common value for all datasets, except for the Kuramoto datasets: it is set to 1 for the 5-object case and to a different value for the 10-object case (see Appendix E).

To evaluate the performance of our method, we compare it with several competitive methods as follows.
Correlation Kipf et al. (2018): a baseline that predicts the relations between two particles based on the correlation of their state sequences.
LSTM Kipf et al. (2018): an LSTM that takes the concatenation of all state vectors and predicts all future states simultaneously.
NRI Kipf et al. (2018): the neural relational inference model that jointly learns the relations and the dynamics with a VAE.
SUGAR Li et al. (2019): a method that introduces structural prior knowledge such as hard symmetric constraint and node degree distribution for relational inference.
ModularMeta Alet et al. (2019): a method that solves the relational inference problem via modular meta-learning.
To compare the performance of our method with “gold standard” methods, i.e., those trained given the ground truth relations, this paper introduces two variants as follows.
Supervised: a variant of our method that only trains an encoder with the ground truth relations.
NRI-MPM (true graph): a variant of our method that only trains a decoder given the ground truth relations.
The codes of NRI (https://github.com/ethanfetaya/nri) and ModularMeta (https://github.com/FerranAlet/modular-metalearning) are public and thus directly used in our experiments. SUGAR is coded by ourselves according to the original paper. Note that Correlation and Supervised are only designed for relation reconstruction, while LSTM and NRI-MPM (true graph) are only designed for state prediction.
To verify the effectiveness of the proposed message passing mechanisms and the structural prior knowledge, this paper introduces some variants of our method as follows.
NRI-MPM w/o RI: a variant of our method without the relation interaction mechanism.
NRI-MPM w/o intra-RI, NRI-MPM w/o inter-RI: variants of our method without the intra- and inter-edge interaction operations, respectively.
NRI-MPM w/o ST: a variant of our method without the spatio-temporal message passing mechanism.
NRI-MPM w/o Sym, NRI-MPM w/ hard Sym: variants of our method without the symmetric prior and that with hard symmetric constraint, respectively.
Table 1: Accuracy of relation reconstruction in the 5-object and 10-object datasets (numeric values not recoverable from this extraction).

Model | Springs | Charged | Kuramoto
---|---|---|---
5 objects | | |
Correlation | | |
NRI | | |
SUGAR | | |
ModularMeta | | |
NRI-MPM | | |
Supervised | | |
10 objects | | |
Correlation | | |
NRI | | |
SUGAR | | |
ModularMeta | | |
NRI-MPM | | |
Supervised | | |
The results in these datasets are unavailable in the original paper, and they are obtained by running the codes provided by the authors.
Table 2: MSE of future state prediction in the 5-object datasets.

Datasets | Springs | | | Charged | | | Kuramoto | |
---|---|---|---|---|---|---|---|---|---
Prediction steps | 1 | 10 | 20 | 1 | 10 | 20 | 1 | 10 | 20
LSTM | 4.13e-8 | 2.19e-5 | 7.02e-4 | 1.68e-3 | 6.45e-3 | 1.49e-2 | 3.44e-4 | 1.29e-2 | 4.74e-2 |
NRI | 3.12e-8 | 3.29e-6 | 2.13e-5 | 1.05e-3 | 3.21e-3 | 7.06e-3 | 1.40e-2 | 2.01e-2 | 3.26e-2 |
SUGAR | 3.71e-8 | 3.86e-6 | 1.53e-5 | 1.18e-3 | 3.43e-3 | 7.38e-3 | 2.12e-2 | 9.45e-2 | 1.83e-1 |
ModularMeta | 3.13e-8 | 3.25e-6 | - | 1.03e-3 | 3.11e-3 | - | 2.35e-2 | 1.10e-1 | 1.96e-1 |
NRI-MPM | 8.89e-9 | 5.99e-7 | 2.52e-6 | 7.29e-4 | 2.57e-3 | 5.41e-3 | 1.57e-2 | 2.73e-2 | 5.36e-2 |
NRI-MPM (true graph) | 1.60e-9 | 9.06e-9 | 1.50e-7 | 8.06e-4 | 2.51e-3 | 5.66e-3 | 1.73e-2 | 2.49e-2 | 4.09e-2 |
Results in these datasets are unavailable in the original paper, and they are obtained by running the codes provided by the authors.
The results of relation reconstruction and future state prediction in the 5-object datasets (for MSEs in the 10-object datasets, see Appendix E) are shown in Table 1 and Table 2, respectively. Our method significantly outperforms all baselines in nearly all datasets in terms of both accuracy and MSE. In the Springs datasets, the accuracies of NRI, SUGAR, ModularMeta and our method are all comparable with the supervised baseline, while our method achieves significantly lower MSEs in the 5-object systems. This indicates that our method can better learn the relations and dynamics simultaneously. In the Charged datasets, our method outperforms the baselines by 4.9%-9.6% in terms of accuracy. The reason may be that the charged systems are densely connected since a charged particle interacts with all other particles, and our method can better handle this situation. SUGAR achieves higher accuracies than NRI, suggesting the effectiveness of structural prior knowledge. Still, our method performs better with the extra help of the proposed message passing mechanisms.
In the Kuramoto datasets, the performance of our method is mixed. Our method achieves higher accuracies than NRI, while its MSEs are larger. ModularMeta achieves higher accuracy than our method in the 10-object system, while its MSEs are much worse than those of all other methods. Perhaps the interactions among objects in the Kuramoto dataset are relatively weak Kuramoto (1975), making it more difficult to infer the relations based on the observed states, and the situation worsens as the scale of systems increases. ModularMeta infers all relations as a whole and modifies its predictions with the help of simulated annealing, which may help better search the structural space and achieve better relation predictions. However, this does not translate to better state prediction. According to Kuramoto (1975), Kuramoto is a synchronization system where the object states converge to a certain set of values in the long run, so the interactions may be less helpful for state prediction in this case.
One observes that LSTM achieves lower MSEs for short-term prediction in the 5-object Kuramoto dataset, but its performance declines for long-term prediction. To further understand the predictive behavior of LSTM and our method, we conduct a qualitative analysis by visualizing the predicted states together with the ground truth states in 49 steps, and the results are shown in Fig. 5. From Fig. 5(a), LSTM can capture the shape of the sinusoidal waveform but fails to make accurate prediction for time steps larger than 40 (e.g., curves in green and purple). By contrast, as shown in Fig. 5(c), the predicted states of our method closely match the ground truth states except for those in the last few time steps for the third particle, whose curve is colored in green. Maybe our method can better capture the interactions among particles that affect the long-term dynamics. Note that this result is consistent with that reported in Appendix A.1 in the original paper of NRI.
The running times of different methods in a single epoch are reported in Fig.
6. It can be seen that our method requires somewhat more time than NRI, a natural consequence of its more complex model. The running time of SUGAR is comparable with that of our method. It is worth noting that ModularMeta requires much more running time than the others, and the situation becomes more severe in larger systems; perhaps meta-learning the proposal function is computationally costly in ModularMeta. These results show that our method can achieve better performance with a modest additional cost of computation.

The ablation studies are conducted in the 10-object Charged dataset, and the results are shown in Table 3.
Table 3: Ablation results in the 10-object Charged dataset (accuracy values not recoverable from this extraction).

Model | Accuracy | MSE (1 step) | MSE (10 steps) | MSE (20 steps)
---|---|---|---|---
NRI-MPM w/o RI | | 8.07e-4 | 4.13e-3 | 1.13e-2
NRI-MPM w/o intra-RI | | 7.83e-4 | 3.90e-3 | 1.26e-2
NRI-MPM w/o inter-RI | | 7.71e-4 | 4.11e-3 | 1.13e-2
NRI-MPM w/o ST | | 1.26e-3 | 4.92e-3 | 1.38e-2
NRI-MPM w/o Sym | | 7.60e-4 | 3.77e-3 | 1.06e-2
NRI-MPM w/ hard Sym | | 7.96e-4 | 3.83e-3 | 1.04e-2
NRI-MPM | | 7.52e-4 | 3.80e-3 | 1.03e-2
As shown in Table 3, the accuracy drops significantly when the relation interaction mechanism is removed, verifying its effectiveness in relation prediction. The MSEs also increase without it, indicating that more accurate relation prediction can help with future state prediction. Moreover, the contribution of the inter-edge interaction is higher than that of the intra-edge interaction. Intuitively, the intra- and inter-edge operations capture local and global interactions, respectively, and in this dataset, global interactions may be more informative than local ones.
To gain an intuitive understanding of the effect of the relation interaction mechanism, we conduct a case study on the 10-object charged particle systems. The distributions of predicted relations of NRI-MPM, NRI-MPM w/o RI and NRI, together with the ground truth relations, are visualized in Fig. 7. The two types of relations are highlighted in red and green, respectively, while the diagonal elements are all in white as there are no self-loops. It is known that two particles with the same charge repel each other while those with opposite charges attract each other. Consequently, for any particles $v_i$, $v_j$ and $v_k$, the relations $z_{ij}$ and $z_{ik}$ are correlated. As shown in Fig. 7, our method can model the dependencies among all relations much better than NRI, while removing the relation interaction mechanism results in less consistent prediction.
As shown in Table 3, the MSEs increase when the spatio-temporal message passing mechanism is removed, and the differences are more significant for 10- and 20-step prediction, verifying that the spatio-temporal message passing mechanism can alleviate error accumulation in multi-step prediction. Furthermore, the MSEs of different methods that predict 40 steps into the future are shown in Fig. 8. Note that the differences among different methods narrow for large prediction steps. Perhaps it is still challenging for all methods to handle error accumulation in the long run.
Interestingly, the decrease of NRI-MPM w/o ST in terms of accuracy is more significant than that of NRI-MPM w/o RI. Maybe the spatio-temporal message passing mechanism imposes strong dependencies of historical interactions, which indirectly helps with relation reconstruction.
As shown in Table 3, without the symmetric prior, the accuracy of NRI-MPM decreases by 2.3%, while the MSEs are on par with the original model, indicating that the symmetric prior can help with relation reconstruction without hurting the precision of future state prediction. Compared with NRI-MPM w/ hard Sym, our method achieves higher accuracy with lower MSEs. Maybe the hard constraint of symmetry limits the exploration of the model in the training procedure, while a soft constraint provides more flexibility.
The effect of the penalty factor $\lambda$ is shown in Fig. 9. The rate of asymmetry decreases significantly as $\lambda$ increases, while the accuracy increases steadily and peaks at an intermediate value, verifying that adjusting $\lambda$ can control the effect of the symmetric prior and that reasonable values benefit relation reconstruction.
This paper is part of an emerging direction of research attempting to model the explicit interactions in dynamical systems using neural networks as in NRI Kipf et al. (2018).
Most closely related are the papers of Alet et al. (2019), Li et al. (2019), Webb et al. (2019) and Zhang et al. (2019). Alet et al. (2019) frame this problem as a modular meta-learning problem to jointly infer the relations and use the data more effectively. To deal with more complex systems, Li et al. (2019) incorporate various structural prior knowledge as a complement to the observed states of agents. Webb et al. (2019) extend NRI to multiplex interaction graphs. Zhang et al. (2019) explore relational inference over a wider class of dynamical systems, such as discrete systems and chaotic systems, assuming a shared interaction graph for all state sequences, which is different from the experimental settings of Kipf et al. (2018) and Alet et al. (2019). Compared with this line of work, our method focuses on introducing efficient message passing mechanisms to enrich the representative power of NRI.
Many recent works seek to extend the message passing mechanisms of GNNs. Zhu et al. (2020) define a bilinear aggregator to incorporate the possible interactions among all neighbors of a given node. Brockschmidt (2020) defines node-aware transformations over messages to impose feature-wise modulation. Nevertheless, these variants treat the messages as a transient part of the computation of the node embeddings, whereas our relation interaction mechanism is aimed at learning edge embeddings. Herzig et al. (2019) use non-local operations to capture the interactions among all relations, which requires quadratic time complexity.
Besides, many works extend GNNs to handle structured time series. Graph convolutional networks with CNNs Yu et al. (2018), GRUs Li et al. (2018), or attention mechanisms Guo et al. (2019) are introduced to deal with spatial and temporal dependencies separately for traffic flow forecasting. Furthermore, spatio-temporal graphs Song et al. (2020) and spatio-temporal attention mechanisms Zheng et al. (2020) are proposed to capture complex spatio-temporal correlations. Our method borrows these ideas from traffic flow forecasting to define spatio-temporal message passing operations for neural relational inference.
This paper introduces efficient message passing mechanisms with structural prior knowledge for neural relational inference. The relation interaction mechanism can effectively capture the coexistence of all relations and help make a joint prediction. By incorporating historical information, the spatio-temporal message passing mechanism can effectively alleviate error accumulation in multi-step state prediction. Additionally, the structural prior knowledge, symmetry as a special case, can promote the accuracy of relation reconstruction in more complex systems. The results of extensive experiments on simulated physics systems validate the effectiveness of our method.
Currently, only simple yet effective implementations using GRUs and attention mechanisms are adopted for the sequence models. Future work includes introducing more advanced models like Transformers to further improve the performance. Besides, current experiments are conducted on simulated systems over static and homogeneous graphs. Future work includes extending our method to systems over dynamic and heterogeneous graphs. Furthermore, the proposed method will be applied to study the mechanism of the emergence of intelligence in different complex systems, including multi-agent systems, swarm systems, physical systems and social systems.
This work is supported by the National Key R&D Program of China (2018AAA0101203), and the National Natural Science Foundation of China (62072483, 61673403, U1611262). This work is also supported by MindSpore.
References

- Towards automated statistical physics: data-driven modeling of complex systems with deep learning. arXiv preprint arXiv:2001.02539.
- The concrete distribution: a continuous relaxation of discrete random variables. In ICLR.
- Relational neural expectation maximization: unsupervised discovery of objects and their interactions. In ICLR.
- A study of AI population dynamics with million-agent reinforcement learning. In AAMAS, pp. 2133-2135.