
Sparse Graph Learning for Spatiotemporal Time Series

Outstanding achievements of graph neural networks for spatiotemporal time series prediction show that relational constraints introduce a positive inductive bias into neural forecasting architectures. Often, however, the relational information characterizing the underlying data-generating process is unavailable, and the practitioner is left with the problem of inferring from data which relational graph to use in the subsequent processing stages. We propose novel, principled, yet practical probabilistic methods that learn the relational dependencies by modeling distributions over graphs while, at the same time, maximizing the end-to-end forecasting accuracy. Our graph learning approach, based on consolidated variance reduction techniques for Monte Carlo score-based gradient estimation, is theoretically grounded and effective. We show that tailoring the gradient estimators to the graph learning problem also allows us to achieve state-of-the-art forecasting performance while controlling both the sparsity of the learned graph and the computational burden. We empirically assess the effectiveness of the proposed method on synthetic and real-world benchmarks, showing that it can be used both as a stand-alone graph identification procedure and as a learned component of an end-to-end forecasting architecture.



1 Introduction

Traditional statistical and signal processing methods for time series forecasting leverage temporal dependencies to model data-generating processes [harvey1990forecasting]. Graph signal processing methods extend these approaches to dependencies observed both in time and space, i.e., to the setting where temporal signals are observed over the nodes of a graph [ortega2018graph, stankovic2020graph2, di2018adaptive, isufi2019forecasting]. The key ingredient here is the use of graph shift operators, constructed from the graph adjacency matrix, that localize learned filters on the graph structure. The same holds true for the graph deep learning methods that have revolutionized the landscape of machine learning for graphs [bruna2014spectral, bronstein2017geometric, bacciu2020gentle]. However, it is often the case that no prior topological information about the reference graph is available, or that dependencies in the dynamics observed at different locations are not well modeled by the available spatial information (e.g., the physical proximity of the sensors). Examples include social networks, smart grids, and brain networks, to name just a few relevant application domains.

In the deep learning setting, several methods train a graph learning module end-to-end with a neural forecasting architecture to maximize performance on the downstream task [shang2021discrete, wu2020connecting]. A typical deep learning approach consists of exploiting spatial attention mechanisms to discover the reciprocal salience of different spatial locations at each layer [cao2020spectral, satorras2022multivariate]. Despite their effectiveness, attention-based approaches give up two major benefits of graph-based representation learning: they 1) do not allow for the sparse computation enabled by the discrete nature of graphs and 2) do not exploit the structure introduced by the graph topology as a regularization of the learning system. Indeed, sparse computation allows graph neural networks (GNNs; [scarselli2008graph, bacciu2020gentle]) with message-passing architectures [gilmer2017neural] to scale in terms of both network depth and the size of the graphs they can process. At the same time, sparse graphs constrain learned representations to be localized in node space and mitigate overfitting to spurious correlations in the training data. Graph learning approaches that do attempt to learn relational structures for time series forecasting exist, but they often rely on continuous relaxations of the binary adjacency matrix and, thus, on dense computations to enable automatic reverse-mode differentiation through any subsequent processing [shang2021discrete, kipf2018neural]. Conversely, other solutions make the computation sparse [wu2020connecting, deng2021graph] at the expense of the quality of the gradient estimates, as shown in [zugner2021study]. Ideally, we would like a methodology that provides accurate gradients while, at the same time, allowing for sparse computations in the downstream message-passing operations, as is typical of modern GNNs.

In this paper, we model the graph learning problem from a probabilistic perspective, which, besides naturally accounting for uncertainty and the embedding of priors, enables learning sparse graphs as realizations of a discrete probability distribution. In particular, we seek to learn a parametric distribution $p_\Phi$ over adjacency matrices such that graphs sampled from it maximize the performance on the downstream forecasting task. Previous works [shang2021discrete, kipf2018neural] learn $p_\Phi$ by exploiting reparametrization tricks to obtain path-wise gradient estimates [glasserman1991gradient, kingma2014auto]. However, these approaches approximate the discrete distribution with a softmax continuous relaxation [paulus2020gradient], which makes all the downstream computations dense and quadratic in the number of nodes. Here, instead, we adopt the framework of score-function (SF) gradient estimators [rubinstein1969some, williams1992simple, mohamed2020monte], which allows us to preserve the sparsity of the sampled graphs and that of the subsequent processing steps (e.g., the forward pass of a message-passing network).

Our contributions are as follows.


  • We provide an end-to-end methodological framework for probabilistic graph learning in spatiotemporal data, based on score-function (SF) gradient estimators [Sec. 4]. In this regard, we provide a convenient rewriting of the gradient for the considered settings [Prop. 2].

  • We introduce two parametrizations of $p_\Phi$: as a set of Bernoulli distributions and as the sampling without replacement of edges under a sparsity constraint [Sec. 4.2]. We show how to sample graphs from both distributions and derive the associated differentiable likelihood functions. Both parametrized distributions can handle an adaptive number of neighboring nodes.

  • We estimate the parameters of $p_\Phi$ with the SF method and design Monte Carlo (MC) estimators for stochastic message-passing architectures [Sec. 4.1 and 4.3]. We introduce a novel and effective, yet simple to implement, variance reduction method for the estimators, based on evaluating the loss w.r.t. the Fréchet means of the graph distributions, for which we provide closed-form solutions [Prop. 1]. Our method does not require the estimation of additional parameters and, unlike more general-purpose approaches (e.g., [mnih2014neural]), can be computed at least as efficiently as drawing a sample from the considered distributions and evaluating the corresponding cost.

We strongly believe that the sparse computation and efficient training enabled by our probabilistic graph learning methods will be key elements in designing a new and more effective class of graph-based forecasting architectures. The paper is organized as follows. Sec. 2 discusses related works; then, Sec. 3 provides the formulation of the problem. We introduce our approach in Sec. 4 and proceed with its empirical evaluation in Sec. 5; finally, we draw our conclusions in Sec. 6.

2 Related works

Graph neural networks have become increasingly popular in spatiotemporal time series processing [seo2018structured, li2018diffusion, yu2018spatio, wu2019graph, deng2021graph, cini2022filling], and the graph learning problem is well known within this context. GraphWavenet [wu2019graph] learns a structure for time series forecasting by factorizing a weighted adjacency matrix with learned embeddings, and several other methods follow this direction [bai2020adaptive, oreshkin2020fcgaga]. MTGNN [wu2020connecting] and GDN [deng2021graph] sparsify the learned factorized adjacency by connecting each node only to its $K$ most similar ones. satorras2022multivariate showed that hierarchical attention-based architectures are also viable to obtain accurate predictions; however, attention-based approaches do not actually learn a graph. Among models learning from non-temporal data, franceschi2019learning tackle probabilistic graph learning by using a bi-level optimization routine and a straight-through gradient trick [bengio2013estimating], which nonetheless requires dense computations. NRI [kipf2018neural] learns a latent variable model predicting the interactions of physical objects by learning edge attributes of a fully connected graph. GTS [shang2021discrete] simplifies the NRI module by considering only binary relationships and integrates graph inference in a spatiotemporal recurrent graph neural network [li2018diffusion]. Both NRI and GTS exploit path-wise gradient estimators based on the categorical Gumbel trick [maddison2017concrete, jang2017categorical] and, as such, suffer from the computational drawbacks discussed in the introduction. Finally, the graph learning module proposed by kazi2020differentiable shares some similarities with our approach and uses the Gumbel-Top-K trick [kool2019stochastic] to sample a $K$-NN graph; however, node scores are learned by using a heuristic that increases the likelihood of sampling edges that contribute to correct classifications. To the best of our knowledge, we are the first to propose a graph learning module for spatiotemporal signals that keeps the computation sparse both during training and at inference.

3 Problem Formulation

Consider a set of $N$ sensors as nodes. We indicate with $x_t^i \in \mathbb{R}^{d_x}$ the $d_x$-dimensional observation acquired by the $i$-th sensor at time step $t$, and denote by $X_t \in \mathbb{R}^{N \times d_x}$ the aggregated observation matrix at the sensor network level. Similarly, we indicate with $U_t \in \mathbb{R}^{N \times d_u}$ the optional $d_u$-dimensional exogenous variables (e.g., related to current weather conditions) and with $V \in \mathbb{R}^{N \times d_v}$ static node attributes, such as sensor specifications. We assume nodes (sensors) to be identified, hence maintaining a correspondence between nodes at different time steps and allowing us to talk about the time series associated with each node. We encode topological information among the nodes with a (binary) adjacency matrix $A \in \{0, 1\}^{N \times N}$.

Graph learning and time series forecasting

Given a window of $W$ past observations $X_{t-W:t}$ open on the time series, we consider the problem of forecasting the next $H$ measurements $X_{t:t+H}$, where the notation $t:T$ indicates the time interval $[t, T)$. We address the graph inference problem by considering a separate module that models a parametric probability distribution over graphs. Denoting the predictor as $\mathcal{F}_\psi$ and the graph distribution as $p_\Phi$, we consider the family of predictive models

$$\hat{X}_{t:t+H} = \mathcal{F}_\psi\big(X_{t-W:t}, U_{t-W:t}, V; A\big), \qquad A \sim p_\Phi, \tag{1}$$

where $\psi$ and $\Phi$ indicate the model parameters, learned jointly s.t.

$$\psi^*, \Phi^* = \arg\min_{\psi, \Phi} \mathcal{L}_t(\psi, \Phi), \qquad \mathcal{L}_t(\psi, \Phi) = \mathbb{E}_{A \sim p_\Phi}\big[\delta_t(\hat{X}_{t:t+H}, X_{t:t+H})\big], \tag{2}$$

where $\mathcal{L}_t$ is the optimization objective at time step $t$, expressed as the expectation, over the graph distribution $p_\Phi$, of a cost (loss) function $\delta_t(\cdot)$, typically based on an $\ell_p$ norm with, e.g., $p = 1$ or $p = 2$. Note that $p_\Phi$ can be a distribution over either dynamic or static graphs, i.e., conditioned on the current observations or not. When not explicitly needed, we may omit the conditioning of $p_\Phi$ for the sake of clarity. We consider a family of models $\mathcal{F}_\psi$ implemented by a spatiotemporal graph neural network (STGNN) based on the message-passing (MP) framework [gilmer2017neural] and following either the time-then-space (TTS) or the time-and-space (T&S) paradigm to process information along the temporal dimension [gao2021equivalence]. We refer to the appendix for more details on STGNN architectures. Notably, $\mathcal{F}_\psi$ can be designed to exchange messages along the spatial dimension according to either a single realization $A \sim p_\Phi$ or a set $\{A_l\}_{l=1}^{L}$ of $L$ (independent) samples, e.g., one for each MP layer in $\mathcal{F}_\psi$.

Core challenge

Minimizing the expectation in Eq. 2 is challenging, as it involves estimating a gradient w.r.t. the parameters $\Phi$ of the discrete sampling distribution $p_\Phi$. Sampling graphs throughout the learning process results in a stochastic computational graph (CG).¹ While automatic differentiation of CGs is a core component of modern deep learning libraries [paske2019pytorch, abadi2015tensorflow], dealing with stochastic nodes introduces additional challenges, as gradients have to be estimated w.r.t. expectations over the sampling of the stochastic nodes [schulman2015gradient, weber2019credit, mohamed2020monte]. Although tools for automatic differentiation of stochastic CGs are being developed [foerster2018dice, bingham2019pyro, krieken2021storchastic, tfprobability], general approaches can be ineffective and prone to failure, especially in the case of discrete distributions (see also [mohamed2020monte]). Our case is particularly problematic: the MP paradigm constrains the flow of spatial information between nodes, making the structure of the CG dependent on the input message-passing graph (MPG). A stochastic input MPG, then, introduces stochastic nodes in the resulting CG (i.e., one for each stochastic edge in the MPG), leading to a large number of paths that data can potentially flow through. Considering an $L$-layered architecture, the number of stochastic nodes increases to $\mathcal{O}(LN^2)$, making the design of reliable, low-variance (i.e., effective) MC gradient estimators inherently challenging. Furthermore, as already mentioned, computing gradients associated with each stochastic edge introduces additional challenges w.r.t. time and space complexity; further discussion and actionable solutions are given in the next section.

¹ It is important to distinguish graphs representing relationships among data from the computational graph of the architecture, i.e., the directed acyclic graph where nodes and edges characterize the sequence of operations that need to be performed to carry out a certain computation, such as the forward pass of a neural network.

4 Learning Stochastic Diffusion Processes for Time Series Forecasting

Figure 1: Overview of the learning architecture. The graph module samples a graph for each layer of the STGNN predictor; predictions and samples are used to compute costs, log-likelihoods, and baselines. Gradient estimates are propagated back to the respective modules.

In this section, we present our approach to probabilistic graph learning. We start by discussing score-based gradient estimators [Sec. 4.1]. Then, we propose two models for the graph distribution [Sec. 4.2] and focus on the problem of controlling the variance of the estimator with novel and principled variance reduction techniques tailored to graph-based architectures [Sec. 4.3]. Finally, we provide a convenient rewriting of the gradient for $L$-layered MP architectures, leading to a novel surrogate loss [Sec. 4.4]. Fig. 1 shows a schematic overview of the framework.

4.1 Estimating gradients for stochastic message-passing networks

SF estimators are based on the identity

$$\nabla_\Phi \mathbb{E}_{A \sim p_\Phi}\big[f(A)\big] = \mathbb{E}_{A \sim p_\Phi}\big[f(A)\,\nabla_\Phi \log p_\Phi(A)\big], \tag{3}$$

which holds, under mild assumptions, for generic cost functions $f$ and distributions $p_\Phi$. The SF trick allows for estimating gradients easily by MC sampling and backpropagating through the computation of the score function $\nabla_\Phi \log p_\Phi(A)$ [mohamed2020monte]. SF estimators are black-box optimization methods, i.e., they only require evaluating the cost function ($\delta_t$ in our case), which does not need to be differentiable w.r.t. the distributional parameters $\Phi$. In our problem setup, Eq. 3 becomes

$$\nabla_\Phi \mathcal{L}_t(\psi, \Phi) = \mathbb{E}_{A \sim p_\Phi}\big[\delta_t(\hat{X}_{t:t+H}, X_{t:t+H})\,\nabla_\Phi \log p_\Phi(A)\big], \tag{4}$$

which has the appealing property of allowing for computing gradients w.r.t. the graph generative process without requiring a full evaluation of all the stochastic nodes in the CG. Conversely, path-wise gradient estimators tackle this problem by exploiting continuous relaxations of the discrete $p_\Phi$, thus estimating the gradient by differentiating through all nodes of the stochastic CG. Denoting by $E$ the number of edges in a realization of $p_\Phi$, the price of learning a graph with a path-wise estimator is that any subsequent MP operation scales with $\mathcal{O}(N^2)$, instead of the $\mathcal{O}(E)$ complexity that a sparse computational graph would allow. The outcome is even more dramatic for T&S models, where MP is used to propagate information at each time step, making the computational and memory costs scale up to $\mathcal{O}(WN^2)$: simply unsustainable for any practical application at scale.
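As a minimal sketch of Eq. 4, assume a toy graph with a handful of independent Bernoulli edges parametrized by logits φ (so the score w.r.t. φ_e is a_e − sigmoid(φ_e)) and a hypothetical black-box cost `cost`; the names `sf_gradient` and `exact_gradient` are illustrative and not part of the method described here. The estimator only evaluates the cost on sampled graphs and multiplies it by the score; on a graph small enough to enumerate, it can be compared against the exact gradient.

```python
import itertools
import math
import random


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def score(phi, a):
    """Score d/dphi log p(a) for independent Bernoulli edges with logits phi: a - sigmoid(phi)."""
    return [ai - sigmoid(p) for p, ai in zip(phi, a)]


def sf_gradient(phi, cost_fn, num_samples, seed=0):
    """MC score-function estimate of d/dphi E[cost(A)], in the spirit of Eq. 4."""
    rng = random.Random(seed)
    grad = [0.0] * len(phi)
    for _ in range(num_samples):
        a = [1 if rng.random() < sigmoid(p) else 0 for p in phi]
        c = cost_fn(a)  # the cost is only evaluated, never differentiated
        for e, s in enumerate(score(phi, a)):
            grad[e] += c * s / num_samples
    return grad


def exact_gradient(phi, cost_fn):
    """Exact value of the same expectation, enumerating all 2^E edge configurations."""
    grad = [0.0] * len(phi)
    for a in itertools.product([0, 1], repeat=len(phi)):
        prob = math.prod(sigmoid(p) if ai else 1.0 - sigmoid(p) for p, ai in zip(phi, a))
        c = cost_fn(list(a))
        for e, s in enumerate(score(phi, list(a))):
            grad[e] += prob * c * s
    return grad


def cost(a):
    """Hypothetical black-box cost: Hamming distance to a fixed target edge vector."""
    target = [1, 0, 1, 0]
    return sum(abs(x - y) for x, y in zip(a, target))


phi = [0.0, 0.0, 0.0, 0.0]
print("exact:", exact_gradient(phi, cost))
print("SF MC:", sf_gradient(phi, cost, num_samples=20000))
```

In an actual pipeline, the cost would be the forecasting loss of a sparse MP network run on the sampled edges, which is why no relaxation of the discrete sample is needed.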

4.2 Neighborhood sampling and sample likelihood

The parametrized distribution $p_\Phi$ should be chosen so that we can (i) efficiently sample graphs and evaluate their likelihood and (ii) backpropagate the errors through the computation of the score in Eq. 3 to the parameters $\Phi$. We consider $p_\Phi$ as parametrized by a matrix of weights (scores) $\Phi \in \mathbb{R}^{N \times N}$, where each $\phi_{ij}$ is associated with the stochastic edge $(i, j)$. In the case of a static graph, we can simply treat $\Phi$ as free learnable parameters; however, more complex parametrizations are possible, e.g., modeling dynamic graphs by exploiting amortized inference to condition the distribution on the observed values; further discussion is deferred to the supplemental material.

Binary edge sampler

A straightforward approach is to consider a Bernoulli random variable with parameter $\theta_{ij}$ associated with each edge $(i, j)$, s.t. $\Theta = \sigma(\Phi)$, where $\theta_{ij}$ indicates the $(i, j)$ entry of matrix $\Theta$ and $\sigma(\cdot)$ is the element-wise logistic sigmoid. Graphs are then sampled by drawing each $A_{ij}$ independently from $\mathrm{Bernoulli}(\theta_{ij})$. We refer to this graph learning module as binary edge sampler (BES). Here, sampling from $p_\Phi$ can be done efficiently and is highly parallelizable; moreover, computing the log-likelihood of a sample is cheap and differentiable, as it corresponds to the negative binary cross-entropy between the sample and the distributional parameters. BES graph generators are a common choice in the literature [franceschi2019learning, shang2021discrete], as the independence assumption makes the mathematics amenable and avoids the often combinatorial complexity of dealing with more structured distributions. Sparsity priors can be imposed by regularizing $\Theta$, e.g., by adding a Kullback-Leibler regularization term to the loss [shang2021discrete, kipf2018neural].
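The BES module can be sketched in a few lines of plain Python, assuming a dense list-of-lists score matrix Φ and edge probabilities θ_ij = sigmoid(φ_ij); `bes_sample` and `bes_log_likelihood` are hypothetical names. Sampling returns a sparse edge list, so downstream MP would only touch the sampled edges, and the log-likelihood is the negative binary cross-entropy, computed with a numerically stable log-sigmoid.

```python
import math
import random


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def bes_sample(phi, rng):
    """Draw A_ij ~ Bernoulli(sigmoid(phi_ij)) independently; return a sparse edge list."""
    return [
        (i, j)
        for i, row in enumerate(phi)
        for j, p in enumerate(row)
        if rng.random() < math.exp(log_sigmoid(p))
    ]


def bes_log_likelihood(phi, edges):
    """log p(A) = sum_ij A_ij log theta_ij + (1 - A_ij) log(1 - theta_ij), i.e. -BCE(A, Theta)."""
    edge_set = set(edges)
    # log(1 - sigmoid(p)) == log_sigmoid(-p)
    return sum(
        log_sigmoid(p) if (i, j) in edge_set else log_sigmoid(-p)
        for i, row in enumerate(phi)
        for j, p in enumerate(row)
    )


rng = random.Random(0)
phi = [[-3.0, 2.0, 0.0], [2.0, -3.0, -1.0], [0.5, 2.0, -3.0]]
edges = bes_sample(phi, rng)
print(edges, bes_log_likelihood(phi, edges))
```

In an autodiff framework, the same quantity would be obtained as a negated binary cross-entropy with logits, and its gradient w.r.t. Φ is the score A − Θ used by the SF estimator.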

Subset neighborhood sampler

Encoding prior structural information about the sparseness of the graphs directly in $p_\Phi$ is often desirable and might, depending on the problem, remarkably reduce sample complexity. We can use the score matrix $\Phi$ to parametrize a stochastic top-$k$ sampler. For each $i$-th node, we sample a subset $S_i$ of $K$ neighboring nodes by sampling without replacement from the categorical distribution parametrized by the softmax of the scores $\phi_i$ (the $i$-th row of $\Phi$). The probability of sampling neighborhood $S_i$ for each $i$-th node is given by

$$p_\Phi(S_i) = \sum_{B \in \mathcal{P}(S_i)} p_\Phi(B) = \sum_{B \in \mathcal{P}(S_i)} \prod_{k=1}^{K} \frac{\exp \phi_{i b_k}}{\sum_{j \notin \{b_1, \dots, b_{k-1}\}} \exp \phi_{ij}}, \tag{5}$$

where $B = (b_1, \dots, b_K)$ denotes an ordered sample without replacement and $\mathcal{P}(S_i)$ is the set of all the permutations of $S_i$. While sampling can be done efficiently by exploiting the Gumbel-top-k trick [kool2019stochastic], Eq. 5 shows that directly computing $p_\Phi(S_i)$ requires marginalizing over all the $K!$ possible orderings of $S_i$. While exploiting the Gumbel-max trick can bring down computation to $\mathcal{O}(2^K)$ [huijben2022review, kool2020estimating], exact computation remains intractable for any practical application. Luckily, numerical methods to efficiently approximate $p_\Phi(S_i)$ exist [kool2020estimating], and, as we show in Sec. 5, it is possible to backpropagate through the numerical solver to evaluate the score function. For more details on the numerical integration method and its computational complexity, we refer to the supplemental material. The resulting subset neighborhood sampler (SNS), then, allows for embedding structural priors on the sparsity of the latent graph directly into the generative model. Fixing the number of neighbors might introduce an irreducible approximation error when identifying graphs whose nodes have a variable number of neighbors; we solve this problem by adding up to $K$ dummy nodes to the set of candidate neighbors and discarding them after having sampled a graph. By doing so, the hyperparameter $K$ can also be used to cap the maximum number of edges and set a minimum sparsity threshold, which makes the computational complexity of MP layers scale with at most $\mathcal{O}(NK)$.
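A minimal sketch of SNS for a single node follows, assuming the node's scores φ_i are given as a plain list. Sampling uses the Gumbel-top-k trick; the subset probability of Eq. 5 is computed by brute-force enumeration of the K! orderings, which is only viable for tiny K and stands in for the numerical approximation mentioned above. Dummy nodes are modeled, as an assumption, as extra candidates with hypothetical scores appended after the N real nodes and dropped after sampling. All function names are illustrative.

```python
import itertools
import math
import random


def gumbel_top_k(scores, k, rng):
    """Sample k indices without replacement from softmax(scores) via the Gumbel-top-k trick."""
    # Perturb each score with standard Gumbel noise -log(-log u), with u kept inside (0, 1).
    perturbed = [s - math.log(-math.log(max(rng.random(), 1e-300))) for s in scores]
    return sorted(range(len(scores)), key=lambda j: perturbed[j], reverse=True)[:k]


def ordered_prob(scores, order):
    """Probability of the ordered sample `order` drawn without replacement from softmax(scores)."""
    remaining = set(range(len(scores)))
    prob = 1.0
    for j in order:
        prob *= math.exp(scores[j]) / sum(math.exp(scores[r]) for r in remaining)
        remaining.remove(j)
    return prob


def subset_prob(scores, subset):
    """Eq. 5 by brute force: sum over all K! orderings of the subset (tiny K only)."""
    return sum(ordered_prob(scores, perm) for perm in itertools.permutations(subset))


def sns_neighborhood(scores, k, dummy_scores, rng):
    """Sample K candidates among real and dummy nodes, then drop the dummies (index >= N)."""
    n = len(scores)
    picked = gumbel_top_k(list(scores) + list(dummy_scores), k, rng)
    return sorted(j for j in picked if j < n)


rng = random.Random(0)
scores = [2.0, 1.0, 0.0, -1.0]
hits = sum(set(gumbel_top_k(scores, 2, rng)) == {0, 1} for _ in range(20000))
print("empirical:", hits / 20000, "exact (Eq. 5):", subset_prob(scores, (0, 1)))
print(sns_neighborhood(scores, k=2, dummy_scores=[0.5, 0.5], rng=rng))
```

The empirical frequency of a neighborhood approaches its Eq. 5 probability as the number of draws grows; with dummy candidates in the pool, a node can end up with fewer than K real neighbors.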

4.3 Variance reduction

Variance reduction is a critical step in the use of score-based estimators. As for any MC estimator, a direct way to reduce the variance is to increase the number $M$ of independent samples used to compute the estimator, which reduces the variance by a factor $M$ w.r.t. the one-sample estimator. In our setting, sampling $M$ adjacency matrices results in $M$ evaluations of the cost and of the associated score: we are therefore interested in finding more sample-efficient alternatives. In particular, we develop our strategy starting from the control variates method.
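For reference, the factor $M$ follows from independence; writing $\hat g_1, \dots, \hat g_M$ for the single-sample estimates, each computed from its own sampled adjacency matrix and assumed i.i.d.:

```latex
\operatorname{Var}\Big[\tfrac{1}{M}\textstyle\sum_{m=1}^{M}\hat g_m\Big]
  = \frac{1}{M^2}\sum_{m=1}^{M}\operatorname{Var}[\hat g_m]
  = \frac{1}{M}\operatorname{Var}[\hat g_1].
```

Halving the standard deviation of the estimate therefore requires four times as many samples, and hence four times as many cost evaluations.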

Control variates and baseline method

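The control variates method consists of introducing an auxiliary quantity $g(A)$ (possibly input dependent) whose expectation under the sampling distribution can be computed efficiently [mohamed2020monte]. Referring to Eq. 3, this allows us to replace the integrand $f(A)\nabla_\Phi \log p_\Phi(A)$ with $f(A)\nabla_\Phi \log p_\Phi(A) - g(A) + \mathbb{E}_{p_\Phi}[g(A)]$, which has the same expectation but, for a suitable $g$, a lower variance. Quantity $g$ is called control variate. A computationally cheap choice is to use the score function itself as control variate, i.e., in our case, $g(A) = \beta\,\nabla_\Phi \log p_\Phi(A)$, which yields $\mathbb{E}_{p_\Phi}[g(A)] = 0$ and, hence, the estimator $(\delta_t - \beta)\,\nabla_\Phi \log p_\Phi(A)$, with $\delta_t$ short for $\delta_t(\hat{X}_{t:t+H}, X_{t:t+H})$. This narrows the problem to finding an appropriate (possibly observation-dependent) value $\beta$, often referred to as baseline. Unfortunately, finding the optimal $\beta$ can be as hard as estimating the gradient. Since the baseline leaves the expectation of the estimator unchanged, the optimal baseline is the one minimizing its variance, which is given by

$$\beta^* = \frac{\mathbb{E}_{p_\Phi}\big[\delta_t\,\|\nabla_\Phi \log p_\Phi(A)\|^2\big]}{\mathbb{E}_{p_\Phi}\big[\|\nabla_\Phi \log p_\Phi(A)\|^2\big]}. \tag{6}$$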
Finding the exact $\beta^*$ is intractable (note the dependence on both the model's parameters $\psi$ and the input observations). In the supplemental material, we show that a good choice is to approximate $\beta^*$ with


with $\bar{A}$ being the expected adjacency matrix; what we mean here by "expected" is given in the following Proposition 1, where we also provide analytic solutions for finding $\bar{A}$ w.r.t. BES and SNS. Note that the squared-score term in Eq. 7 does not depend on the input observations and can be estimated as a moving average of the squared score. In practice, by assuming the ratio in Eq. 7 to be approximately $1$, computational complexity is brought further down by using $\beta = \delta_t(\bar{A})$, i.e., the cost of the prediction obtained with $\bar{A}$, which requires only a single evaluation of the cost $\delta_t$.
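The form of the optimal baseline in Eq. 6 can be checked with a short calculation, sketched here assuming a single scalar baseline shared by all parameters and writing $s(A) = \nabla_\Phi \log p_\Phi(A)$ and $\delta(A)$ for the cost obtained with graph $A$. Because $\mathbb{E}[s(A)] = 0$, the mean of $(\delta(A) - \beta)\,s(A)$ does not depend on $\beta$, so minimizing the total variance amounts to minimizing the second moment:

```latex
\mathbb{E}\big[\|(\delta(A)-\beta)\,s(A)\|^2\big]
  = \mathbb{E}\big[\delta(A)^2\|s(A)\|^2\big]
  - 2\beta\,\mathbb{E}\big[\delta(A)\|s(A)\|^2\big]
  + \beta^2\,\mathbb{E}\big[\|s(A)\|^2\big]
\quad\Longrightarrow\quad
\beta^* = \frac{\mathbb{E}\big[\delta(A)\|s(A)\|^2\big]}{\mathbb{E}\big[\|s(A)\|^2\big]}.
```

Comparing with $\beta = 0$, the difference in second moments equals $\mathbb{E}\big[\|s(A)\|^2\big]\big((\beta - \beta^*)^2 - \beta^{*2}\big)$, which is negative exactly when $|\beta - \beta^*| < |\beta^*|$; this is why an approximate baseline can still reduce variance.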

Proposition 1.

The expected adjacency matrix $\mathbb{E}_{p_\Phi}[A]$ and the Fréchet mean adjacency matrix $\bar{A}$ w.r.t. the Hamming distance and the support of $p_\Phi$ are the following:


where if is true and otherwise, and , with being the indices of the smallest elements s.t. .

Sketch of the proof.

The components of $A$ in BES are independent, so the derivation of $\mathbb{E}_{p_\Phi}[A]$ is straightforward. For SNS, we analytically derive the probability $\Pr(A_{ij} = 1)$ from the CDF and quantile function of the $\mathrm{Gumbel}(\phi_{ij})$ distribution. Regarding the Fréchet mean, we exploit the fact that the Hamming distance decomposes over the individual edges and derive the analytical expression. The detailed proof is given in the supplemental material. ∎
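The Fréchet-mean part of Proposition 1 can be sanity-checked by brute force on a tiny SNS row. The sketch below (illustrative names, tiny K only) computes exact subset probabilities with Eq. 5, derives the inclusion probabilities Pr(j ∈ S), i.e., one row of the expected adjacency matrix, and searches all K-subsets for the one minimizing the expected Hamming distance. It is a numerical check of the definition, not the closed-form expression of the proposition.

```python
import itertools
import math


def ordered_prob(scores, order):
    """Probability of an ordered sample without replacement from softmax(scores)."""
    remaining = set(range(len(scores)))
    prob = 1.0
    for j in order:
        prob *= math.exp(scores[j]) / sum(math.exp(scores[r]) for r in remaining)
        remaining.remove(j)
    return prob


def subset_distribution(scores, k):
    """Exact p(S) for every k-subset S of candidates (Eq. 5), by enumerating orderings."""
    return {
        s: sum(ordered_prob(scores, perm) for perm in itertools.permutations(s))
        for s in itertools.combinations(range(len(scores)), k)
    }


def inclusion_probs(scores, k):
    """Pr(j in S): one row of the SNS expected adjacency matrix."""
    probs = [0.0] * len(scores)
    for s, p in subset_distribution(scores, k).items():
        for j in s:
            probs[j] += p
    return probs


def frechet_mean_subset(scores, k):
    """Brute-force Fréchet mean: the k-subset Y minimizing E_S[Hamming(1_S, 1_Y)]."""
    dist = subset_distribution(scores, k)

    def expected_hamming(y):
        # Hamming distance between indicator vectors = size of the symmetric difference.
        return sum(p * len(set(s) ^ set(y)) for s, p in dist.items())

    return min(dist, key=expected_hamming)


scores = [0.3, -1.2, 2.0, 0.9, -0.4]
print("inclusion probabilities:", inclusion_probs(scores, 2))
print("Fréchet mean neighborhood:", frechet_mean_subset(scores, 2))
```

Because the expected Hamming distance equals Σ_j Pr(a_j ≠ y_j), a K-subset minimizes it by keeping the K candidates with the largest inclusion probabilities; in this example these are the two highest-scoring nodes (indices 2 and 3).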

While the proof is quite involved, the conclusions are very intuitive. Notably, the above proposition allows us to compute $\bar{A}$ for SNS samplers without explicitly evaluating $\mathbb{E}_{p_\Phi}[A]$, which results in a considerable reduction in computational complexity. Finally, we point out that, even though the resulting baseline $\beta$ may differ from $\beta^*$, the variance is reduced as long as $|\beta - \beta^*| < |\beta^*|$. We denote the modified cost, i.e., the cost minus the baseline, as $\tilde{\delta}_t$; as shown in Fig. 1, the modified cost is computed after each forward pass and used to update the parameters of $p_\Phi$.
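To illustrate the effect of the baseline, the following sketch uses a toy BES setting (an assumption: six independent edges and a hypothetical Hamming-distance cost to a target graph). The baseline is β = δ(Ā), where Ā is obtained by thresholding Θ at 0.5, which minimizes the expected Hamming distance when edges are independent; this mirrors the single-evaluation variant discussed in Sec. 4.3. The script compares mean and variance of the single-sample SF estimate for one edge, with and without the modified cost.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def cost(a, target):
    """Hypothetical black-box cost: Hamming distance to a target edge vector."""
    return sum(abs(x - y) for x, y in zip(a, target))


def single_sample_grads(theta, target, baseline, num_samples, seed=0):
    """Single-sample SF estimates of dE[cost]/dphi_0 using the modified cost (cost - baseline)."""
    rng = random.Random(seed)
    grads = []
    for _ in range(num_samples):
        a = [1 if rng.random() < t else 0 for t in theta]
        modified_cost = cost(a, target) - baseline
        grads.append(modified_cost * (a[0] - theta[0]))  # Bernoulli score of edge 0
    return grads


def mean_var(xs):
    """Sample mean and unbiased sample variance."""
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / (len(xs) - 1)


phi = [1.5, -1.0, 2.0, -2.5, 0.5, -0.5]
target = [1, 1, 1, 0, 0, 0]
theta = [sigmoid(p) for p in phi]
a_bar = [1 if t > 0.5 else 0 for t in theta]  # Hamming Fréchet mean for independent edges
beta = cost(a_bar, target)  # a single cost evaluation: beta = delta(A_bar)

no_bl = single_sample_grads(theta, target, 0.0, 50000)
with_bl = single_sample_grads(theta, target, beta, 50000)
for name, g in [("no baseline", no_bl), ("baseline", with_bl)]:
    print(name, "mean, variance:", mean_var(g))
```

Both versions estimate the same gradient, −θ_0(1 − θ_0) for this cost, but subtracting β removes most of the variance because β is close to the expected cost.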

4.4 Surrogate objective

As a final step, we can leverage the structure of MP neural networks to rewrite the gradient $\nabla_\Phi \mathcal{L}_t$. This formulation yields a different estimator for the case where we sample a different $A_l$ at each MP layer.

Proposition 2.

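Consider a family of models $\mathcal{F}_\psi$ with exactly $L$ message-passing layers propagating messages w.r.t. $L$ different adjacency matrices $A_l$, $l = 1, \dots, L$, sampled from $p_\Phi$ (either BES or SNS). Then

$$\nabla_\Phi \mathcal{L}_t = \mathbb{E}_{A_1, \dots, A_L \sim p_\Phi}\left[\delta_t \sum_{l=1}^{L-1} \nabla_\Phi \log p_\Phi(A_l) + \sum_{i=1}^{N} \delta_t^i\, \nabla_\Phi \log p_\Phi(a_L^i)\right], \tag{8}$$

where $\delta_t^i$ indicates the cost associated with the $i$-th node and $a_l^i$ denotes the $i$-th row of adjacency matrix $A_l$, i.e., the row corresponding to the neighborhood of the $i$-th node.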

Sketch of the proof.

The proof, detailed in the supplemental material, is derived by noticing that, in the stochastic CG of $\mathcal{F}_\psi$, there is a direct path from the stochastic edges in $A_L$ to the costs. ∎

Proposition 2 shows that each row $a_L^i$ impacts only the associated cost $\delta_t^i$, while $A_l$ for $l < L$ affects the entire (global) cost function $\delta_t$. Intuitively, the second sum in Eq. 8 can be interpreted as directly rewarding connections that lead to accurate final predictions w.r.t. the local cost $\delta_t^i$. Besides providing a more general MC estimator, this observation motivates us to consider a similar surrogate approximate loss for the case where we use a single $A$ for all layers, i.e., we consider


as the gradient to learn $\Phi$. Eq. 9 is derived from Eq. 8 by considering a single sample $A$ and introducing a weighting hyperparameter. Note that in this case the resulting gradient is an approximation of the true one, with a reweighting of the contribution of each row $a^i$. Following this consideration, the hyperparameter can be interpreted as a trade-off between the contributions of the local and the global cost to the gradient. In practice, we set it so that the two terms are roughly on the same scale. Empirically, we observed that using the modified objective consistently leads to faster convergence; see Sec. 5.
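A minimal sketch of a one-sample estimate of Eq. 8 for BES follows, assuming L adjacency matrices sampled independently (one per MP layer), per-node costs `node_costs` that would come from a forward pass (a hypothetical input here), and a global cost taken, for illustration only, as the sum of the per-node costs. Edges of layers 1, ..., L−1 are credited with the global cost, while row i of the last layer's sample is credited only with node i's cost.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bes_score(phi, a):
    """d log p(A) / d Phi for independent Bernoulli edges: A - sigmoid(Phi), entrywise."""
    return [[a_ij - sigmoid(p_ij) for p_ij, a_ij in zip(p_row, a_row)]
            for p_row, a_row in zip(phi, a)]


def layerwise_sf_gradient(phi, samples, node_costs):
    """One-sample estimate of Eq. 8.

    samples: list of L adjacency matrices (lists of 0/1 rows), one per MP layer.
    node_costs: per-node costs delta^i; the global cost is assumed to be their sum.
    """
    n = len(phi)
    global_cost = sum(node_costs)
    grad = [[0.0] * n for _ in range(n)]
    # Layers 1..L-1: every edge is credited with the global cost.
    for a in samples[:-1]:
        for i, row in enumerate(bes_score(phi, a)):
            for j, s in enumerate(row):
                grad[i][j] += global_cost * s
    # Last layer: row i only influences node i's output, so it is weighted by delta^i.
    for i, row in enumerate(bes_score(phi, samples[-1])):
        for j, s in enumerate(row):
            grad[i][j] += node_costs[i] * s
    return grad


phi = [[0.0, 0.0], [0.0, 0.0]]
samples = [[[1, 0], [0, 1]], [[1, 1], [0, 0]]]  # A_1, A_2
print(layerwise_sf_gradient(phi, samples, node_costs=[1.0, 3.0]))
```

With a single layer, only the node-wise term survives, so each neighborhood is rewarded exclusively by the accuracy of its own node's prediction.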

5 Experiments

To validate the effectiveness of the proposed framework, we perform two sets of experiments. The first one focuses on the task of graph identification, where we address the problem of retrieving graphs that best explain a set of observations. The second one shows how our approach can be used as a graph learning module in an end-to-end forecasting architecture. We consider one synthetic dataset and three openly available real-world ones. The GPVAR synthetic dataset consists of signals generated by recursively applying a polynomial Graph VAR filter [isufi2019forecasting] and adding Gaussian noise at each time step: this results in complex, yet known and controllable, spatiotemporal dynamics. In particular, we consider the same filter and graph used in [zambon2022az] and split the data into training, validation, and testing sets. As real-world datasets, we consider two popular benchmarks in the traffic forecasting literature, i.e., the METR-LA and PEMS-BAY datasets from [li2018diffusion], which consist of traffic speed measurements taken at crossroads in Los Angeles and San Francisco, respectively. We use the same preprocessing and data splits as [wu2019graph]. Finally, we use a dataset of hourly readings from air quality monitoring stations scattered over different Chinese cities (AQI [zheng2015forecasting]). For AQI, we use the same splits as previous works [yi2016st]. More details are given in the supplemental material.

5.1 Synthetic data

Figure 2: Experiments on GPVAR. All the curves show the validation MAE after each training epoch.

To gather insights on the impact of each component of the methods introduced so far, we start by using the controlled environment provided by the GPVAR dataset. In the first setup, we consider a GPVAR filter as the predictor and assume the true model parameters, i.e., the coefficients of the filter, to be known, to decouple the graph learning module from the forecasting module. Then, in a second scenario, we learn the graph while, at the same time, fitting the filter's parameters. Fig. 2 shows the validation mean absolute error (MAE) after each training epoch when using BES and SNS samplers, with and without the baseline for variance reduction, and when SNS is run with dummy nodes for adaptive node degrees. In particular, two panels of Fig. 2 show results in the graph identification task for the vanilla gradient estimator derived from Eq. 3 and for the surrogate objective from Eq. 8, respectively. Note that, to match the optimal prediction, models have to perfectly retrieve the underlying graph. (Impact of the baseline) The first striking outcome is the effect of the baseline, which dramatically accelerates the learning process in both considered configurations. (Graph distribution) The second notable result is that, although both SNS and BES are able to retrieve the underlying graph exactly, the sparsity prior in SNS yields a more sample-efficient training procedure, as the validation curves are steeper for SNS; note that the approximation error induced by having a fixed number of neighbors is effectively removed with the dummy nodes. However, we observed some numerical instability in computing the likelihood of the unordered subsets which, in some cases, prevented SNS from converging to the optimum; we believe this could be avoided by reducing the learning rate. (Surrogate objective) Fig. 2 shows that the surrogate objective accelerates learning even further for all considered methods. (Joint training) Finally, Fig. 2 reports the results for the joint training of the predictor and graph module with the approximate objective. The curves, in this case, were obtained by initializing the parameters of the filter randomly and specifying a filter order higher than the real one; nonetheless, the learning procedure converged to the optimum when using the baseline from Sec. 4.3. Note that, to obtain the validation scores, we simply evaluated the model by taking a single Monte Carlo sample at each step.

5.2 Real-world datasets

Graph Identification in AQI

                Tested on
Trained on      Beijing        Tianjin
Beijing         9.43 ± 0.03    10.62 ± 0.05
Tianjin         9.55 ± 0.06    10.56 ± 0.03
Baseline (GRU)  10.21 ± 0.01   11.25 ± 0.04
Table 1: AQI experiment (1-step-ahead MAE).

For graph identification, we set up the following scenario. From the AQI dataset, we extract two subsets of sensors corresponding to the monitoring stations in the cities of Beijing and Tianjin, respectively. We build a graph for each subset by constructing a K-NN graph of the stations based on their distance; we refer to these as ground-truth graphs. Then, we train a different predictor for each of the two cities, based on the ground-truth graph. In particular, we use a TTS STGNN with a simple architecture consisting of a GRU [chung2014empirical] encoder followed by MP layers. As a reference value, we also report the performance achieved by a GRU trained on all sensors without using any spatial information. Performance is measured in terms of 1-step-ahead MAE. Results for the two models, trained with early stopping on the validation set and tested on the hold-out test set of the same city, are shown on the main diagonal of Tab. 1. In the second stage of the experiment, we take the model trained on one city, freeze its parameters, discard the ground-truth graph, and train our graph learning module (with the SNS parametrization) on the other city. Results, reported in the off-diagonal entries of Tab. 1, show that our approach almost matches the performance achievable by fitting the model directly on the target dataset with the ground-truth adjacency matrix; moreover, the performance is significantly better than that of the reference GRU.

                  METR-LA                                      PEMS-BAY
Model             MAE @ 15       MAE @ 30       MAE @ 60       MAE @ 15       MAE @ 30       MAE @ 60
Full attention    2.727 ± 0.005  3.049 ± 0.009  3.411 ± 0.007  1.335 ± 0.003  1.655 ± 0.007  1.929 ± 0.007
GTS               2.750 ± 0.005  3.174 ± 0.013  3.653 ± 0.048  1.360 ± 0.011  1.715 ± 0.032  2.054 ± 0.061
MTGNN             2.690          3.050          3.490          1.320          1.650          1.940
Ours (SNS)        2.725 ± 0.005  3.051 ± 0.009  3.412 ± 0.013  1.317 ± 0.002  1.620 ± 0.003  1.873 ± 0.005
 – Truth          2.720 ± 0.004  3.106 ± 0.008  3.556 ± 0.011  1.335 ± 0.001  1.676 ± 0.004  1.993 ± 0.008
 – Random         2.801 ± 0.006  3.160 ± 0.008  3.517 ± 0.009  1.327 ± 0.001  1.636 ± 0.002  1.897 ± 0.003
 – Identity       2.842 ± 0.002  3.264 ± 0.002  3.740 ± 0.004  1.341 ± 0.001  1.684 ± 0.001  2.013 ± 0.003
Table 2: Results on the traffic datasets (MAE at 15, 30, and 60 minute horizons).

Joint training and forecasting in traffic datasets

Finally, we test our approach on widely used traffic forecasting benchmarks. Here, we take the full-graph attention architecture proposed in [satorras2022multivariate], remove the attention gating mechanism, and use the graph learned by our module to restrict message exchange to connected nodes only; in particular, we consider the SNS sampler with a fixed number K of neighbors, dummy nodes, and the surrogate objective. We use the same hyperparameters as [satorras2022multivariate], except for the learning rate schedule and batch size (see supplemental material). As references, we also report results obtained with the ground-truth graph, with a graph containing only self-loops (i.e., whose adjacency matrix is the identity matrix), and with a random graph sampled from the Erdős–Rényi model. For MTGNN [wu2020connecting], we report the results presented in the original paper, while for GTS we report results obtained by running the authors' code. More details are in the supplemental material. Note that GTS is considered the state of the art among methods based on path-wise estimators [zugner2021study]. Tab. 2 reports the MAE for the 15, 30, and 60 min time horizons, averaged over multiple independent runs. Interestingly, the results with the random and ground-truth graphs are comparable, suggesting that the optimal graph might not correspond to the most obvious one. Our approach is always comparable to or better than the state-of-the-art alternatives, and it is statistically better than all the baselines using the reference adjacency matrices.
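A rough sketch of two ingredients of this setup follows: restricting message exchange to the edges of a given graph, and building the two non-learned reference graphs (self-loops only and Erdős–Rényi). It assumes mean aggregation over neighbors and an Erdős–Rényi edge probability `p` chosen by the caller; the attention architecture of [satorras2022multivariate] and the value of p used in the experiments are not reproduced, and both function names are hypothetical.

```python
import numpy as np

def reference_graph(kind: str, n: int, p: float = 0.1, seed: int = 0) -> np.ndarray:
    """Hypothetical helper returning a binary (n, n) reference adjacency matrix."""
    if kind == "identity":
        # Only self-loops: each node exchanges messages with itself alone.
        return np.eye(n, dtype=np.int64)
    if kind == "random":
        # Erdős–Rényi: each directed edge is present independently with probability p.
        rng = np.random.default_rng(seed)
        return (rng.random((n, n)) < p).astype(np.int64)
    raise ValueError(f"unknown graph kind: {kind}")

def masked_mean_aggregate(h: np.ndarray, adj: np.ndarray) -> np.ndarray:
    """Average node features h (n, d) over the neighbors allowed by adj.

    adj[i, j] = 1 means node i receives a message from node j; messages between
    unconnected nodes are discarded, and nodes without neighbors get zeros.
    """
    deg = adj.sum(axis=1, keepdims=True)
    return (adj @ h) / np.maximum(deg, 1)

h = np.arange(8, dtype=float).reshape(4, 2)
out_identity = masked_mean_aggregate(h, reference_graph("identity", 4))
```

With the identity graph, every node only sees itself, so `out_identity` equals `h`; swapping in a sampled adjacency matrix changes which messages survive without changing the rest of the computation.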

6 Conclusions

In this paper, we propose a methodological framework for learning graph structures underlying spatiotemporal data. We design our novel probabilistic framework around score-function gradient estimators, which allow us to keep the computation sparse during both training and inference, and develop variance-reduction methods that yield accurate estimates of the training gradient; both aspects differentiate our approach from the current literature. The proposed graph learning modules are trained end-to-end on a time series forecasting task and can be used both for graph identification and as a component of an end-to-end architecture. Empirical results support our claims and show the effectiveness of our framework. Notably, we achieve forecasting performance on par with state-of-the-art alternatives while retaining the benefits of graph-based processing.


This work was supported by the Swiss National Science Foundation project FNS 204061: Higher-Order Relations and Dynamics in Graph Neural Networks. The authors wish to thank the Institute of Computational Science at USI for granting access to computational resources.



Appendix A Parametrizing the predictor

In this section, we provide a brief overview of the architectures considered for the theoretical and empirical results discussed in the main paper. We start from a general class of message-passing layers and then move on to the different STGNN architectures.

A.1 Message-passing neural networks

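We consider the family of message-passing operators where representations are updated at each layer as

$$\mathbf{h}^{l+1}_{i} = \gamma^{l}\Big(\mathbf{h}^{l}_{i}, \mathrm{Aggr}_{j \in \mathcal{N}(i)}\big\{\rho^{l}\big(\mathbf{h}^{l}_{i}, \mathbf{h}^{l}_{j}, \mathbf{e}_{ji}\big)\big\}\Big), \tag{10}$$

where $\mathbf{h}^{l}_{i}$ indicates the $i$-th node-feature vector at layer $l$, $\mathcal{N}(i)$ the set of its neighboring nodes, and $\mathbf{e}_{ji}$ the features associated with the edge connecting the $j$-th to the $i$-th node. Update and message functions, $\gamma^{l}$ and $\rho^{l}$, respectively, can be implemented by any differentiable function (e.g., a multilayer perceptron), and $\mathrm{Aggr}$ indicates a generic permutation-invariant aggregation function.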
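The message-passing update above (Eq. 10) can be made concrete with a small NumPy sketch. It assumes sum aggregation, a single linear-plus-ReLU map for both the message and the update function, and no edge features; these choices and the name `mp_layer` are illustrative, not the architecture used in the experiments.

```python
import numpy as np

def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)

def mp_layer(h, adj, w_msg, w_upd):
    """One message-passing layer with sum aggregation (illustrative sketch).

    h:     (n, d) node features at layer l
    adj:   (n, n) binary matrix, adj[i, j] = 1 if j is a neighbor of i
    w_msg: (2d, d_msg) weights of the message function rho(h_i, h_j)
    w_upd: (d + d_msg, d_out) weights of the update function gamma(h_i, m_i)
    """
    n, d = h.shape
    # pairs[i, j] = concat(h_i, h_j) for every ordered pair of nodes.
    pairs = np.concatenate(
        [np.repeat(h[:, None, :], n, axis=1), np.repeat(h[None, :, :], n, axis=0)],
        axis=-1,
    )
    messages = relu(pairs @ w_msg)                          # (n, n, d_msg)
    # Permutation-invariant aggregation: sum the messages from neighbors only.
    aggregated = (adj[:, :, None] * messages).sum(axis=1)   # (n, d_msg)
    return relu(np.concatenate([h, aggregated], axis=-1) @ w_upd)

rng = np.random.default_rng(0)
n, d = 5, 3
h = rng.normal(size=(n, d))
adj = (rng.random((n, n)) < 0.5).astype(float)
w_msg = rng.normal(size=(2 * d, 4))
w_upd = rng.normal(size=(d + 4, 6))
out = mp_layer(h, adj, w_msg, w_upd)
```

Because the aggregation is a sum over neighbors, relabeling the nodes (permuting `h` together with the rows and columns of `adj`) permutes the output rows in the same way.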

A.2 Spatiotemporal Graph Neural Networks

STGNNs process input spatiotemporal data by considering operators that use the underlying graph to impose inductive biases in the representation learning process. By adopting the terminology introduced in [gao2021equivalence], we distinguish between time-then-space (TTS) and time-and-space (T&S) models, depending on whether message passing happens after or in between temporal encoding steps.

Time-then-space models

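TTS models are based on an encoder-decoder architecture where the encoder embeds each input time series $\mathbf{x}^{i}_{t-W:t}$, for $i = 1, \dots, N$, in a vector representation, while the decoder, implemented as a multilayer GNN, propagates information across the spatial dimension. In particular, we consider the family of models s.t.

$$\mathbf{h}^{0}_{i,t} = \mathrm{Enc}\big(\mathbf{x}^{i}_{t-W:t}\big),$$
$$\mathbf{h}^{l+1}_{i,t} = \gamma^{l}\Big(\mathbf{h}^{l}_{i,t}, \mathrm{Aggr}_{j \in \mathcal{N}(i)}\big\{\rho^{l}\big(\mathbf{h}^{l}_{i,t}, \mathbf{h}^{l}_{j,t}, \mathbf{e}_{ji}\big)\big\}\Big), \quad l = 0, \dots, L-1,$$
$$\hat{\mathbf{x}}^{i}_{t+1:t+H} = \mathbf{W}_{o}\,\mathbf{h}^{L}_{i,t} + \mathbf{b}_{o},$$

where $W$ is the length of the input window, $H$ the forecasting horizon, $\mathrm{Enc}(\cdot)$ a temporal encoder (e.g., an RNN), $\mathbf{W}_{o}$ and $\mathbf{b}_{o}$ the parameters of the linear readout, and the remaining notation is consistent with that of Eq. 10, with the addition of subscripts for indexing the temporal dimension. Examples of spatiotemporal graph processing models that fall into the time-then-space category are NRI [kipf2018neural] and the encoder-decoder architecture used in [satorras2022multivariate].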
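The time-then-space recipe (encode each node's window into a vector, run message passing on those vectors only, then apply a linear readout) can be sketched as below. The sketch assumes a linear-plus-tanh encoder over the flattened window (the experiments use a GRU), mean aggregation, and a one-step-ahead output; all names are hypothetical.

```python
import numpy as np

def tts_forecast(x, adj, w_enc, w_mp, w_out, b_out):
    """Time-then-space forecast (illustrative sketch).

    x:     (n, window) input window for each of the n nodes
    adj:   (n, n) binary adjacency, adj[i, j] = 1 if j sends messages to i
    w_enc: (window, d) temporal encoder applied independently to each node
    w_mp:  list of (2d, d) weight matrices, one per message-passing layer
    w_out, b_out: (d, 1) and (1,) parameters of the linear readout
    """
    # 1) Time: embed each node's time series; no spatial mixing yet.
    h = np.tanh(x @ w_enc)                                   # (n, d)
    deg = np.maximum(adj.sum(axis=1, keepdims=True), 1)
    # 2) Space: message passing on the encoded vectors only.
    for w in w_mp:
        neigh = (adj @ h) / deg                              # mean over neighbors
        h = np.tanh(np.concatenate([h, neigh], axis=-1) @ w)
    # 3) Readout: linear map to a one-step-ahead prediction per node.
    return h @ w_out + b_out                                 # (n, 1)

rng = np.random.default_rng(0)
n, window, d = 4, 12, 8
x = rng.normal(size=(n, window))
adj = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 0]], dtype=float)
params = dict(
    w_enc=rng.normal(size=(window, d)) / np.sqrt(window),
    w_mp=[rng.normal(size=(2 * d, d)) / np.sqrt(2 * d) for _ in range(2)],
    w_out=rng.normal(size=(d, 1)),
    b_out=np.zeros(1),
)
y_hat = tts_forecast(x, adj, **params)
```

Node 3 has no neighbors in this toy graph, so its prediction depends on its own window only.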

Time-and-space models

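Time-and-space models are a general class of STGNNs where space and time are processed in a cohesive way by extending the message-passing framework [gilmer2017neural] to the spatiotemporal setting, i.e., by implementing the message and update functions of a message-passing GNN with neural networks for sequential data processing. A large subset of this family of models can be seen as performing the following spatiotemporal message-passing operations:

$$\mathbf{h}^{l+1}_{i,t-W:t} = \gamma^{l}\Big(\mathbf{h}^{l}_{i,t-W:t}, \mathrm{Aggr}_{j \in \mathcal{N}(i)}\big\{\rho^{l}\big(\mathbf{h}^{l}_{i,t-W:t}, \mathbf{h}^{l}_{j,t-W:t}, \mathbf{e}_{ji}\big)\big\}\Big), \quad l = 0, \dots, L-1,$$

where $W$ is the length of the input window, $\mathbf{h}^{0}_{i,t-W:t}$ is the input window of node $i$, and $\gamma^{l}$, $\rho^{l}$, and $\mathrm{Aggr}$ are as in Eq. 10, except that $\gamma^{l}$ and $\rho^{l}$ now operate on sequences of representations (e.g., temporal convolutions or recurrent cells).

Note that predictions, here, are obtained by pooling representations along the temporal dimension and then using a linear readout as in the previous case. While this formulation assumes a static adjacency matrix, it could easily be extended to account for a changing topology, e.g., by changing the way messages are aggregated at each time step. Other architectures are possible, e.g., by exploiting recurrent models [seo2018structured, li2018diffusion] or alternating spatial and temporal convolutional layers [yu2018spatio, wu2019graph].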
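By contrast with the TTS sketch, a time-and-space model mixes the two dimensions inside every layer. The sketch below assumes each layer applies a causal temporal convolution of width 2 followed by mean-aggregation message passing at every time step, then mean-pools over time before a linear readout; these operators and names are illustrative choices, not those of the cited architectures.

```python
import numpy as np

def ts_layer(h, adj, w_time, w_space):
    """One time-and-space layer (illustrative sketch).

    h:       (n, T, d_in) representations of n nodes over T time steps
    adj:     (n, n) binary adjacency, adj[i, j] = 1 if j sends messages to i
    w_time:  (2 * d_in, d_out) weights of a causal temporal convolution of width 2
    w_space: (2 * d_out, d_out) weights combining each node with its neighbors
    """
    # Causal temporal convolution: step t sees steps t and t-1 (zero padding at t = 0).
    prev = np.concatenate([np.zeros_like(h[:, :1]), h[:, :-1]], axis=1)
    h = np.tanh(np.concatenate([h, prev], axis=-1) @ w_time)
    # Message passing applied independently at every time step.
    deg = np.maximum(adj.sum(axis=1), 1)[:, None, None]      # (n, 1, 1)
    neigh = np.einsum("ij,jtd->itd", adj, h) / deg
    return np.tanh(np.concatenate([h, neigh], axis=-1) @ w_space)

def ts_forecast(x, adj, layers, w_out, b_out):
    """x: (n, T) input windows; returns (n, horizon) predictions."""
    h = x[..., None]                                          # (n, T, 1)
    for w_time, w_space in layers:
        h = ts_layer(h, adj, w_time, w_space)
    # Pool along the temporal dimension, then apply a linear readout.
    return h.mean(axis=1) @ w_out + b_out

rng = np.random.default_rng(0)
n, T, d, horizon = 3, 6, 4, 2
x = rng.normal(size=(n, T))
adj = np.array([[0, 1, 1], [1, 0, 0], [0, 0, 0]], dtype=float)
layers = [
    (rng.normal(size=(2, d)), rng.normal(size=(2 * d, d))),
    (rng.normal(size=(2 * d, d)), rng.normal(size=(2 * d, d))),
]
w_out, b_out = rng.normal(size=(d, horizon)), np.zeros(horizon)
y_hat = ts_forecast(x, adj, layers, w_out, b_out)
```

Within each layer, the representation at time t depends only on steps up to t, while spatial information is exchanged at every step rather than once after encoding.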

Appendix B Parametrizing the graph distribution

As discussed in the paper, we can parametrize by associating a score with each edge; i.e., by setting . Similarly, one could reduce the number of parameters to estimate by using amortized inference and learning some factorization of , e.g., where . Modeling dynamic graphs, instead, requires accounting for the observations at each considered time step . For example, one can consider models s.t.

where indicates a generic encoding function for in the input window (e.g., an MLP or an RNN), a nonlinear activation function, is a learnable weight matrix, a learnable bias, and the learnable parameters of the output linear transformation.
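The three options above (one free score per edge, a factorized score matrix obtained with amortized inference, and a time-dependent score computed from the input window) can be sketched as follows. The low-rank product of node embeddings and the pairwise one-hidden-layer scorer are assumptions chosen for illustration, since the exact formulas are not reproduced here; all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, window, hidden = 10, 2, 12, 16

# (a) One free parameter per edge: an (n, n) matrix of edge scores.
edge_scores = rng.normal(size=(n, n))

# (b) Factorized scores (assumed low-rank form): two (n, d) embedding tables
#     give an (n, n) score matrix with 2*n*d = 40 instead of n*n = 100 parameters.
z_src, z_tgt = rng.normal(size=(n, d)), rng.normal(size=(n, d))
factorized_scores = z_src @ z_tgt.T

# (c) Dynamic scores: encode each node's window, then score every pair with a
#     one-hidden-layer network (hypothetical form).
def dynamic_scores(x, w_enc, w_pair, b_pair, u, c):
    """x: (n, window) input window; returns an (n, n) score matrix for this window."""
    e = np.tanh(x @ w_enc)                                   # (n, d) node encodings
    m = e.shape[0]
    pairs = np.concatenate(
        [np.repeat(e[:, None], m, axis=1), np.repeat(e[None], m, axis=0)], axis=-1
    )                                                        # (n, n, 2d)
    hid = np.maximum(pairs @ w_pair + b_pair, 0.0)           # nonlinear activation
    return hid @ u + c                                       # output linear map

x = rng.normal(size=(n, window))
scores_t = dynamic_scores(
    x,
    w_enc=rng.normal(size=(window, d)),
    w_pair=rng.normal(size=(2 * d, hidden)),
    b_pair=np.zeros(hidden),
    u=rng.normal(size=hidden),
    c=0.0,
)
```

Any of these score matrices can then feed the BES or SNS samplers; only option (c) changes with the input window.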

Appendix C Details on variance reduction methods

Equation 4 provides a rewriting of the gradient we intend to estimate. With as control variate, the value of leading to the smallest variance is that of Eq. 6, i.e.,

Note that , as depends on the observations and estimating is, essentially, as hard as estimating . Therefore, we propose the following approximation based on a Taylor expansion of

centered at :

Approximating with in Eq. 7, namely,


Although the denominator depends on the current value of , it does not depend on , so it can be estimated incrementally from each training batch.
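For intuition, recall the standard result for score-function estimators: with score $s_k(A) = \partial \log p(A)/\partial\theta_k$, the variance-minimizing constant baseline for coordinate $k$ is $b_k^* = \mathbb{E}[f(A)\, s_k(A)^2] / \mathbb{E}[s_k(A)^2]$, whose denominator does not involve the objective $f$. Whether this coincides term by term with Eq. 6 is not checked here. The sketch below, using an assumed toy objective and independent Bernoulli edges with logits `theta`, estimates $b^*$ from one batch and compares gradient variances with and without it on another.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
theta = np.array([-1.0, 0.5, 2.0])     # logits of 3 independent edges
target = np.array([1.0, 0.0, 1.0])     # toy objective: mismatches w.r.t. a target graph

def f(a):
    # Constant offset of 3 mimics a loss that is far from zero, inflating variance.
    return ((a - target) ** 2).sum(axis=-1) + 3.0

def sample(num):
    a = (rng.random((num, theta.size)) < sigmoid(theta)).astype(float)
    score = a - sigmoid(theta)         # d log p(a) / d theta for Bernoulli(sigmoid(theta))
    return a, score

# Estimate the per-coordinate optimal baseline on an independent batch.
a, score = sample(100_000)
b_opt = (f(a)[:, None] * score**2).mean(axis=0) / (score**2).mean(axis=0)

# Single-sample gradient estimates with and without the baseline.
a, score = sample(100_000)
g_plain = f(a)[:, None] * score
g_base = (f(a)[:, None] - b_opt) * score
print(g_plain.mean(axis=0), g_base.mean(axis=0))   # both approximately unbiased
print(g_plain.var(axis=0), g_base.var(axis=0))     # variance shrinks with the baseline
```

Since the baseline is estimated on a separate batch, it is independent of the samples it corrects, which keeps the estimator unbiased.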

Appendix D Proof of Proposition 1

We restate the proposition and divide the proof into two parts, concerning the expected value and the Fréchet mean, respectively.

Proposition 1.

The expected adjacency matrix and the Fréchet mean adjacency matrix w.r.t. the Hamming distance and the support of are the following:


where $\mathbb{1}[P] = 1$ if $P$ is true and $0$ otherwise, and , with being the indices of the smallest elements s.t. .

D.1 Expected value of the adjacency matrix

The proof for BES is straightforward because each entry of the adjacency matrix is sampled independently of the others, so its expected value coincides with its Bernoulli parameter, for every edge. For SNS, the proof is not as simple, because edges belonging to the same sampled neighborhood are not independent. The remainder of this section proves the claim for an SNS with a given number K of neighbors.
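The difference between the two samplers can be illustrated with a sketch. It assumes that BES draws each edge independently with probability equal to the logistic sigmoid of its score, and that SNS selects the K neighbors of each node via the Gumbel-top-K trick on that node's scores (sampling without replacement), as the notation below suggests. For K = 1 the Gumbel-max trick makes each edge marginal a softmax of the row, which the Monte Carlo check reproduces; helper names are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sample_bes(scores, rng):
    """Binary edge sampler: every entry is an independent Bernoulli(sigmoid(score))."""
    return (rng.random(scores.shape) < sigmoid(scores)).astype(np.int64)

def sample_sns(scores, k, rng):
    """Subset neighborhood sampler: in each row, keep the k entries with the
    largest Gumbel-perturbed scores (Gumbel-top-k, sampling without replacement)."""
    perturbed = scores + rng.gumbel(size=scores.shape)
    top = np.argsort(-perturbed, axis=1)[:, :k]
    adj = np.zeros(scores.shape, dtype=np.int64)
    np.put_along_axis(adj, top, 1, axis=1)
    return adj

rng = np.random.default_rng(0)
scores = np.array([[0.0, 1.0, -1.0, 2.0]])   # scores of one node's candidate neighbors
bes_mean = np.mean([sample_bes(scores, rng) for _ in range(20_000)], axis=0)
sns_mean = np.mean([sample_sns(scores, 1, rng) for _ in range(20_000)], axis=0)
softmax = np.exp(scores) / np.exp(scores).sum()
# bes_mean approximates sigmoid(scores); sns_mean (k = 1) approximates the row softmax.
```

Under SNS every row has exactly K ones, which is precisely what couples the edges of a neighborhood and makes the expected value harder to derive than for BES.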

For convenience, we consider a generic node , and adopt the following notation.


Consider parameter vector , associated with Gumbel random variables , and one realization thereof . Denote the node set as , and the sampled neighborhood with the nodes associated with top elements of and the node set associated with the smallest values of . Define also and ; see Figure 3 for a visual representation. We assume all values are distinct so that no conflicts arise when choosing the top-K or bottom-K values.

Figure 3: Different configurations of and . From left to right: , , , and .

Component is equal to . Node is in if and only if is larger than other realizations.

  • If , then .

  • Conversely, if , then . However, cannot be the maximum in otherwise .

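We conclude that

By the properties of Gumbel random variables, the maximum of a set of independent Gumbel random variables is itself Gumbel distributed,² with parameter given by the log-sum-exp of their parameters; therefore,

where random variable does not depend on . By applying Lemma 1, we conclude the proof of the first part of Proposition 1.

² For any set of independent random variables $G_k \sim \mathrm{Gumbel}(\mu_k)$, $\max_k G_k \sim \mathrm{Gumbel}\big(\log \sum_k e^{\mu_k}\big)$.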
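The two Gumbel facts this argument relies on can be checked by simulation: the maximum of independent unit-scale Gumbel variables with locations $\mu_k$ is Gumbel with location $\log\sum_k e^{\mu_k}$, and, for unit-scale Gumbels with locations $a$ and $b$, the probability that the first exceeds the second is $e^{a}/(e^{a}+e^{b})$, the kind of pairwise comparison that Lemma 1 addresses. This is only a numerical sanity check, not part of the proof; the locations are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
num = 200_000

# Fact 1: max_k Gumbel(mu_k) ~ Gumbel(log sum_k exp(mu_k)).
mu = np.array([0.3, -1.0, 1.5])
maxima = (mu + rng.gumbel(size=(num, mu.size))).max(axis=1)
loc = np.log(np.exp(mu).sum())
qs = np.array([0.1, 0.5, 0.9])
empirical_q = np.quantile(maxima, qs)
theoretical_q = loc - np.log(-np.log(qs))   # inverse CDF of Gumbel(loc)
print(empirical_q, theoretical_q)

# Fact 2: P(G_a > G_b) = exp(a) / (exp(a) + exp(b)) = sigmoid(a - b).
a, b = 0.7, -0.4
freq = np.mean(a + rng.gumbel(size=num) > b + rng.gumbel(size=num))
print(freq, 1.0 / (1.0 + np.exp(b - a)))
```

A Gumbel(loc) variable with unit scale also has mean loc plus the Euler-Mascheroni constant, which gives a second quick check on `maxima`.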

Lemma 1.

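For any two independent Gumbel random variables $G_a \sim \mathrm{Gumbel}(a)$ and $G_b \sim \mathrm{Gumbel}(b)$ of parameters $a, b \in \mathbb{R}$,

$$P(G_a > G_b) = \frac{e^{a}}{e^{a} + e^{b}} = \sigma(a - b),$$

where $\sigma$ is the logistic sigmoid. In particular, if $G_b = \max_k G_k$ with independent $G_k \sim \mathrm{Gumbel}(b_k)$, so that $b = \log \sum_k e^{b_k}$,

$$P(G_a > G_b) = \frac{e^{a}}{e^{a} + \sum_k e^{b_k}}.$$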

When using dummy nodes, we add the extra nodes, apply the standard K-NN sampling, and then drop all dummy nodes together with their connections. Hence, the expected value is computed over both original and dummy nodes, and dropping the dummy nodes afterwards does not affect it.
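The dummy-node procedure can be sketched as follows: extra candidate columns are appended to each node's scores, the usual Gumbel-top-K selection is applied, and the dummy columns are dropped, so a node may keep fewer than K real neighbors. Giving all dummy nodes one shared score (`dummy_score`) is an assumption made for illustration, and the helper name is hypothetical.

```python
import numpy as np

def sample_sns_with_dummies(scores, k, num_dummy, dummy_score, rng):
    """SNS with dummy nodes (illustrative sketch).

    scores: (n, n) scores of the real candidate neighbors.
    Each row is extended with num_dummy dummy candidates sharing dummy_score;
    after Gumbel-top-k selection, the dummy columns are discarded.
    """
    n = scores.shape[0]
    extended = np.concatenate([scores, np.full((n, num_dummy), dummy_score)], axis=1)
    perturbed = extended + rng.gumbel(size=extended.shape)
    top = np.argsort(-perturbed, axis=1)[:, :k]
    adj = np.zeros(extended.shape, dtype=np.int64)
    np.put_along_axis(adj, top, 1, axis=1)
    # Dropping the dummy nodes leaves between 0 and k real neighbors per node.
    return adj[:, :n]

rng = np.random.default_rng(0)
scores = rng.normal(size=(5, 5))
adj = sample_sns_with_dummies(scores, k=3, num_dummy=3, dummy_score=0.0, rng=rng)
print(adj.sum(axis=1))   # number of real neighbors kept for each node
```

A very high dummy score pushes toward empty neighborhoods and a very low one recovers plain SNS with exactly K neighbors, so the dummy scores act as a knob on sparsity.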

Proof of Lemma 1.

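Let $G_a \sim \mathrm{Gumbel}(a)$ and $G_b \sim \mathrm{Gumbel}(b)$ be independent. Then

$$P(G_a > G_b) = \int_{-\infty}^{+\infty} f_a(g)\, P(G_b < g)\, dg = \int_{-\infty}^{+\infty} f_a(g)\, F_b(g)\, dg,$$

where $f_\mu$ and $F_\mu$ are the probability density and cumulative distribution functions of the Gumbel($\mu$), respectively. The above rewriting is analogous to Eq. 31 in [kool2020estimating], except that we only look at $G_b$ and not at all the elements of the Gumbel vector. Similarly to Kool et al., we consider the change of variables $u = F_a(g)$, with $du = f_a(g)\, dg$, and obtain the expression

$$P(G_a > G_b) = \int_{0}^{1} F_b\big(F_a^{-1}(u)\big)\, du.$$

The cumulative distribution of a Gumbel (which is absolutely continuous) is

$$F_\mu(g) = \exp\big(-e^{-(g-\mu)}\big),$$

and its inverse is therefore

$$F_\mu^{-1}(u) = \mu - \log(-\log u).$$

Note that

$$\lim_{u \to 0^{+}} F_a^{-1}(u) = -\infty, \qquad \lim_{u \to 1^{-}} F_a^{-1}(u) = +\infty,$$

and that

$$F_b\big(F_a^{-1}(u)\big) = \exp\big(-e^{\,b-a}(-\log u)\big) = u^{e^{\,b-a}}.$$

So the integral above becomes

$$\int_{0}^{1} u^{\gamma}\, du = \frac{1}{1+\gamma},$$

where we denoted with $\gamma = e^{b-a}$ the exponent, for brevity. We conclude that

$$P(G_a > G_b) = \frac{1}{1 + e^{b-a}} = \frac{e^{a}}{e^{a} + e^{b}}.$$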

D.2 Fréchet mean of the adjacency matrix

The proof of the second part of Proposition 1, concerning the analytic expressions of , follows from Lemma 2, where we show how to construct from .
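A standard argument underlies this construction: for binary matrices, $\mathbb{E}[d_H(A, M)] = \sum_{ij} \big[M_{ij}(1 - p_{ij}) + (1 - M_{ij})\, p_{ij}\big]$ with $p_{ij} = \mathbb{E}[A_{ij}]$, which holds by linearity of expectation even when edges are dependent. Over all binary matrices the minimizer keeps the entries with $p_{ij} > 1/2$; over matrices with exactly K ones per row it keeps the K largest $p_{ij}$ in each row. The sketch checks both rules against brute-force enumeration on a small random instance; it illustrates the argument and is not the authors' proof.

```python
import itertools
import numpy as np

def expected_hamming(m, p):
    """E[d_H(A, M)] for binary M when P(A_ij = 1) = p_ij (linearity of expectation)."""
    return float((m * (1 - p) + (1 - m) * p).sum())

def frechet_unconstrained(p):
    # Keep an edge whenever it is more likely present than absent.
    return (p > 0.5).astype(np.int64)

def frechet_knn(p, k):
    # Keep, in each row, the k edges with the largest marginal probability.
    m = np.zeros(p.shape, dtype=np.int64)
    np.put_along_axis(m, np.argsort(-p, axis=1)[:, :k], 1, axis=1)
    return m

rng = np.random.default_rng(0)
n, k = 3, 2
p = rng.random((n, n))                     # arbitrary edge marginals E[A_ij]

# Brute force over all binary n x n matrices (2**9 = 512 candidates).
all_m = [np.array(bits).reshape(n, n) for bits in itertools.product([0, 1], repeat=n * n)]
best_any = min(all_m, key=lambda m: expected_hamming(m, p))
knn_m = [m for m in all_m if (m.sum(axis=1) == k).all()]
best_knn = min(knn_m, key=lambda m: expected_hamming(m, p))
```

Both Fréchet means depend on the distribution only through the expected adjacency matrix, which is why the first part of Proposition 1 is enough to obtain the second.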

Lemma 2.

If , then the Fréchet mean graph , with respect to the Hamming distance among the adjacency matrices in is, for all ,


For BES, we can combine the above lemma with the fact that to conclude that , for all . For SNS, the proof is more involved and relies on the following Lemma 3.

Lemma 3.

Consider SNS and three nodes , then . In particular, .

According to Lemma 2, the Fréchet mean w.r.t. the set of binary matrices representing a K-NN graph is given by . By Lemma 3, the edges associated with the top-K elements in are the same as the top-K elements in , therefore,

In the remainder of the section, we prove Lemmas 2 and 3.

Proof of Lemma 2.

Let us define the Fréchet function

w.r.t. the Hamming distance. Note that for all

w.r.t. the Frobenius norm. Therefore,

Note now that