Graph Message Passing with Cross-location Attentions for Long-term ILI Prediction

12/21/2019 ∙ by Songgaojun Deng, et al. ∙ University of Virginia, George Mason University, Stevens Institute of Technology

Forecasting influenza-like illness (ILI) is of prime importance to epidemiologists and health-care providers. Early prediction of epidemic outbreaks plays a pivotal role in disease intervention and control. Most existing work either has limited long-term prediction performance or lacks a comprehensive ability to capture spatio-temporal dependencies in data. Accurate and early disease forecasting models would markedly improve both epidemic prevention and the management of an epidemic's onset. In this paper, we design a cross-location attention based graph neural network (Cola-GNN) for learning time series embeddings and location-aware attentions. We propose a graph message passing framework that combines learned feature embeddings and an attention matrix to model disease propagation over time. We compare the proposed method with state-of-the-art statistical approaches and deep learning models on real-world epidemic-related datasets from the United States and Japan. The proposed method shows strong predictive performance and yields interpretable results for long-term epidemic prediction.




1 Introduction

Epidemic disease propagation that involves large populations and wide areas can have a significant impact on society. The Centers for Disease Control and Prevention (CDC) estimates that 79,400 deaths from influenza occurred during the 2017-2018 season in the United States. Early forecasting of infectious diseases such as influenza-like illness (ILI) provides the best opportunity for timely intervention and resource allocation. It helps health care departments prepare the corresponding vaccines in time, which reduces the financial burden. For instance, the World Health Organization (WHO) reports that Australia spent over 352 million dollars on routine immunization in the 2017 fiscal year.

We focus on the problem of long-term ILI forecasting with lead times from 1 to 20 weeks based on influenza surveillance data collected for multiple locations (states and regions). Because of the data collection process and surveillance lag, accurate statistics for influenza warning systems are often delayed by a few weeks, making early prediction imperative. However, long-term epidemic forecasting poses several challenges. First, temporal dependencies are hard to capture from short-term input data; without manually added seasonal trends, most statistical models fail to provide high accuracy. Second, the influence of other locations has not been thoroughly explored when data input is limited. Spatio-temporal effects have been studied, but such approaches usually require adequate data sources to achieve good performance [22].

Existing work on epidemic prediction has focused on various aspects. 1) Traditional causal models [15, 8, 3], including compartmental models and agent-based models, employ disease progression mechanisms such as Susceptible-Infectious-Recovered (SIR) to capture the dynamics of ILI diseases. Compartmental models focus on mathematical modeling of population-level dynamics, while agent-based models simulate the propagation process at the individual level with contact networks. Calibrating these models is challenging due to the high dimensionality of the parameter space. 2) Time series prediction with statistical models such as Autoregressive (AR) models and their variants (e.g., VAR) is not well suited to long-term ILI trend forecasting, given that disease activity and human environments evolve over time. 3) Machine learning and deep learning methods [21, 26, 31, 29], such as recurrent neural networks, have been explored in recent years, but they rarely consider cross-spatial effects in long-term disease propagation.

In this paper, we focus on long-term (10-20 weeks) prediction of the number of ILI patients using data from a limited time range (20 weeks). To tackle this problem, we explore a graph propagation model with deep spatial representations to compensate for the loss of temporal information. Treating each location as a node, we design a graph neural network framework to model epidemic propagation at the population level. Meanwhile, we use recurrent neural networks to capture sequential dependencies in local time series data and temporal convolutions to identify short-window patterns. Our key contributions are summarized as follows:

  • We propose a novel graph-based deep learning framework for long-term epidemic prediction from a time-series forecasting perspective. This is one of the first works to adapt graph neural networks to epidemic forecasting.

  • We investigate a location-aware attention mechanism to capture location correlations. The influence between locations can be directed (asymmetric) and is optimized automatically during model learning. The attention matrix is further used as the adjacency matrix of the graph neural network for modeling disease propagation.

  • We design a temporal convolution module to automatically extract temporal dependencies and hidden features for time series data of multiple locations. The learned temporal features for each location are utilized as node attributes for the graph neural network.

  • The proposed method, Cola-GNN, outperforms a broad range of state-of-the-art models on three real-world datasets under different long-term prediction settings. We also demonstrate the effectiveness of its learned attention matrix compared to a geographical adjacency matrix in an ablation study.

2 Related Work

2.1 Influenza Prediction

In many studies, forecasting influenza or ILI case counts is formulated as a time series regression problem, where autoregressive models are widely used [27, 1, 11, 30]. Instead of focusing on seasonal effects, Wang et al. [30] propose a dynamic Poisson autoregressive model to improve short-term prediction accuracy (e.g., 1-4 weeks). Variations of particle filters and ensemble filters have also been used to predict influenza activity. Yang et al. [32] evaluate the performance of six state-of-the-art filters for forecasting influenza activity and conclude that the models have comparable performance. Ensemble methods such as matrix factorization based regression and nearest neighbor based regression have also been studied [5]. While autoregressive, filter-based, and ensemble models are simple and straightforward, they often neglect the geographical dependence in disease propagation.

Attempts to study spatio-temporal effects in influenza disease modeling are not rare. Waller et al. propose a hierarchical Bayesian parametric model for the spatio-temporal interaction of generic disease mapping [28]. A non-parametric Bayesian method [22] has been proposed for predicting spatial and temporal variation in influenza cases. Venna et al. develop data-driven approaches involving climatic and geographical factors for real-time influenza forecasting [26]. Wu et al. use deep learning to model spatio-temporal patterns in epidemiological prediction problems [31]. Despite their impressive performance, these methods have limitations: some require additional data that are not readily available, and their long-term prediction performance is not satisfactory. Improving long-term epidemiological prediction with restricted training data remains an open research problem.

2.2 Long-term Epidemic Prediction

Long-term prediction (also known as multi-step prediction), that is, predicting several steps ahead, is a challenging problem in time series prediction. It must cope with growing uncertainty arising from various sources, such as the accumulation of errors and a lack of information. Long-term prediction methods can be categorized into two types: (i) direct methods and (ii) iterative methods. Direct methods predict a future value from the past values in one shot. Iterative methods recursively invoke a short-term predictor to make long-term predictions. Specifically, they use the observed data $(x_1, \dots, x_T)$ to predict the next step $\hat{x}_{T+1}$, then use $(x_2, \dots, x_T, \hat{x}_{T+1})$ to predict $\hat{x}_{T+2}$, and so on.
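To make the distinction concrete, here is a minimal Python sketch that contrasts the two strategies on a univariate series. The one-step predictor `one_step` is a hypothetical AR(1)-style rule with a fixed coefficient `phi`. `direct_forecast` stands in for a separate model trained for one specific horizon. Neither is a method from the cited works.

```python
from typing import Callable, List


def one_step(window: List[float], phi: float = 0.9) -> float:
    """Hypothetical short-term predictor: AR(1)-style rule on the last value."""
    return phi * window[-1]


def iterative_forecast(history: List[float], horizon: int,
                       predictor: Callable[[List[float]], float] = one_step) -> float:
    """Iterative strategy: each prediction is fed back as the newest input."""
    window = list(history)
    for _ in range(horizon):
        next_value = predictor(window)
        window = window[1:] + [next_value]  # slide the window, keep its length fixed
    return window[-1]


def direct_forecast(history: List[float], horizon: int, phi: float = 0.9) -> float:
    """Direct strategy: one horizon-specific model maps the window to x_{T+h} in one shot.

    Here that model is the hypothetical closed form phi**h * x_T.
    """
    return (phi ** horizon) * history[-1]
```

In this noise-free linear toy, both routes return the same value (about 2.916 for the history `[1.0, 2.0, 4.0]` and a horizon of 3). If the one-step model is misspecified or noisy, the iterative route feeds its own errors back in at every step. This is the error accumulation mentioned above.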

For long-term prediction with time series data, Sorjamaa et al. combine a direct prediction strategy with sophisticated input selection criteria [23]; Qian-Li et al. and Du et al. develop neural network based methods to improve long-term prediction performance [20, 10]. Recent works [26, 31] explore deep learning models for direct long-term epidemiological prediction. DEFSI [29] combines deep neural networks with causal models to address high-resolution ILI incidence forecasting. Yet most of these models rely heavily on extrinsic data to improve accuracy.

3 The Proposed Method

3.1 Problem Formulation

We formulate the epidemic prediction problem as a regression task with multiple time series as input. Throughout the paper, we denote the number of locations by $N$ and the time span of one input example by $T$. We use the terms region and location interchangeably.

At each time step $t$, the multi-location epidemiology profile is denoted by $\mathbf{x}_t \in \mathbb{R}^{N}$, whose elements are the observations from $N$ sources/locations, e.g., the weekly influenza patient counts in $N$ locations. We further denote the training data in a time span of size $T$ as $X = [\mathbf{x}_1, \dots, \mathbf{x}_T] \in \mathbb{R}^{N \times T}$. The objective is to predict the epidemiology profile $\mathbf{x}_{T+h}$ at a future time point, where $h$ refers to the horizon (lead time) of the prediction.
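The window/horizon setup can be illustrated by how training pairs are cut from a longer record. The sketch below assumes that the full record is stored as a list of N rows (one per location) of equal length L. `make_samples` is a hypothetical helper name.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # N rows (locations) x L columns (weeks)


def make_samples(series: Matrix, window: int, horizon: int) -> List[Tuple[Matrix, List[float]]]:
    """Cut (input, target) pairs from an N x L multi-location series.

    Each input is the N x T block ending at week t; the target is the
    column t + horizon, i.e. one value per location.
    """
    length = len(series[0])
    samples = []
    # t runs over every week that leaves room for a target `horizon` weeks later
    for t in range(window - 1, length - horizon):
        x = [row[t - window + 1:t + 1] for row in series]
        y = [row[t + horizon] for row in series]
        samples.append((x, y))
    return samples
```

With L weeks of data, this yields L - T - h + 1 samples. Longer lead times therefore leave slightly fewer training pairs.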

The proposed framework, shown in Figure 1, consists of three modules:

1) location-aware attention to capture location-wise interactions;
2) a temporal convolutional layer to capture local temporal features;
3) global graph message passing to combine the temporal features with the location-aware attentions, generate further hidden features, and make predictions.

The pseudocode is given in Algorithm 1, and each module is described below.

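Input: Time series data $X$ from multiple locations, geographical adjacency matrix $A^{g}$
Output: Model parameters $\Theta$
for each epoch do
    Randomly sample a mini batch
    for each region $i$ do
        Compute the last RNN hidden state of region $i$ (Section 3.2)
    for each region pair $(i, j)$ do    // simultaneous calculations for all regions
        Compute the attention coefficient of region $j$ on region $i$ (Section 3.2)
    Combine the normalized attention matrix with $A^{g}$ through the gate (Section 3.2)
    for each region $i$ do
        Compute the temporal features (Section 3.3), the graph features (Section 3.4), and the prediction (Section 3.5)
    Take an SGD step on the loss (Section 3.6) to update $\Theta$
Algorithm 1 Cola-GNN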
Figure 1: The overview of the proposed framework. Eq. 4-6 are skipped for brevity.

3.2 Location-aware Attention

In this study, lacking precise population movement data, we dynamically model the impact of one area on other areas during epidemics of infectious disease. We first learn a hidden state for each location over a given time period using a Recurrent Neural Network (RNN), given its success in sequential (temporal) data prediction. Specifically, we use a simple, classic vanilla RNN in this module. The RNN module can be replaced by a Gated Recurrent Unit (GRU) or a Long Short-Term Memory (LSTM) network [14]; however, in this application, the vanilla RNN outperforms both GRU and LSTM.

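Given the multi-location time series data $X$, we employ a global RNN model to capture the temporal dependencies of all locations. For location $i$, an instance of a time series is represented by $\mathbf{x}_{i,:} = [x_{i,1}, \dots, x_{i,T}]$. Let $D$ be the dimension of the hidden state. For each element in the input sequence, the RNN updates its hidden state according to

$$\mathbf{h}_{i,t} = g\left(\mathbf{w}\, x_{i,t} + U \mathbf{h}_{i,t-1} + \mathbf{b}\right), \tag{1}$$

where $\mathbf{h}_{i,t} \in \mathbb{R}^{D}$ is the hidden state vector at time $t$ and $\mathbf{h}_{i,t-1}$ is the hidden state vector at time $t-1$; $g$ is the nonlinear activation function; and $\mathbf{w} \in \mathbb{R}^{D}$, $U \in \mathbb{R}^{D \times D}$, and $\mathbf{b} \in \mathbb{R}^{D}$ are the adaptive weight and bias parameters of the RNN. Let $\mathbf{h}_i = \mathbf{h}_{i,T}$ be the last hidden state; we use it to represent location $i$.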
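The hidden-state update described above can be sketched in pure Python. The sketch assumes tanh for the unspecified activation g and a zero initial hidden state; both are assumptions. Because the RNN is global, the same w, U, and b are reused for every location, and only the input sequence changes.

```python
import math
from typing import List


def rnn_last_hidden(xs: List[float], w: List[float],
                    U: List[List[float]], b: List[float]) -> List[float]:
    """Run h_t = tanh(w * x_t + U h_{t-1} + b) over one location's series.

    xs: the scalar sequence x_{i,1..T}; w, b: length-D vectors; U: D x D matrix.
    Returns the last hidden state h_{i,T}, used as the location's representation.
    """
    D = len(b)
    h = [0.0] * D  # assumed zero initial state
    for x in xs:
        # the comprehension reads the previous h before the name is rebound
        h = [math.tanh(w[d] * x + sum(U[d][k] * h[k] for k in range(D)) + b[d])
             for d in range(D)]
    return h


# Shared (global) parameters applied to every location's sequence.
w, U, b = [0.5, -0.5], [[0.1, 0.0], [0.0, 0.1]], [0.0, 0.0]
H = [rnn_last_hidden(seq, w, U, b) for seq in ([0.1, 0.2, 0.3], [0.0, 0.4, 0.9])]
```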

Next, we define an attention coefficient $a_{i,j}$ to measure the impact of location $j$ on location $i$.

Additive attention (or multi-layer perceptron attention) [2] and multiplicative attention (or dot-product attention) [25, 24] are the two most commonly used attention mechanisms. Both compute an alignment score between elements from two sources, but they use different compatibility functions. We use the compatibility function of additive attention because of its better predictive quality:

$$a_{i,j} = \mathbf{v}^{\top} g\left(W^{s} \mathbf{h}_i + W^{t} \mathbf{h}_j + \mathbf{b}^{s}\right) + b^{v}, \tag{2}$$

where $g$ is an activation function applied element-wise, and $W^{s}, W^{t} \in \mathbb{R}^{d_a \times D}$, $\mathbf{v} \in \mathbb{R}^{d_a}$, $\mathbf{b}^{s} \in \mathbb{R}^{d_a}$, and $b^{v} \in \mathbb{R}$ are trainable parameters. $d_a$ is a hyperparameter that controls the dimensions of the parameters in Eq. 2. Because the impact of location $j$ on location $i$ may differ from the impact of $i$ on $j$, we obtain an asymmetric attention coefficient matrix $A = [a_{i,j}] \in \mathbb{R}^{N \times N}$, where each row indicates the degree of influence of other locations on the current location. Usually, a softmax function is used to transform attention scores into a probability distribution. In our problem, however, the overall impact of other locations varies across places; for instance, compared to New York, Hawaii may be less affected overall by other states. Instead, we normalize each row of $A$ to rescale the impact of other locations on one location:

$$\mathbf{a}_{i,:} \leftarrow \frac{\mathbf{a}_{i,:}}{\max\left(\lVert \mathbf{a}_{i,:} \rVert_p,\ \epsilon\right)}, \tag{3}$$

where $\epsilon$ is a small value to avoid division by zero, and $\lVert \cdot \rVert_p$ denotes the $\ell_p$-norm.
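The additive score and the row-wise normalization can be sketched as below. The sketch uses ELU for g in Eq. 2 and p = 2, ε = 1e-12 for Eq. 3, as reported in the implementation details. The names `Ws`, `Wt`, `v`, `bs`, and `bv` are illustrative labels for the trainable weights, not identifiers from the authors' code.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def elu(z: float) -> float:
    return z if z > 0 else math.exp(z) - 1.0


def matvec(M: Matrix, v: Vector) -> Vector:
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def attention_matrix(H: List[Vector], Ws: Matrix, Wt: Matrix,
                     v: Vector, bs: Vector, bv: float) -> Matrix:
    """Additive scores a[i][j] = v . elu(Ws h_i + Wt h_j + bs) + bv.

    Ws and Wt differ, so a[i][j] (impact of j on i) need not equal a[j][i].
    """
    src = [matvec(Ws, h) for h in H]  # Ws h_i, computed once per location
    tgt = [matvec(Wt, h) for h in H]  # Wt h_j, computed once per location
    n = len(H)
    return [[sum(vk * elu(s + t + c) for vk, s, t, c in zip(v, src[i], tgt[j], bs)) + bv
             for j in range(n)] for i in range(n)]


def row_normalize(A: Matrix, p: float = 2.0, eps: float = 1e-12) -> Matrix:
    """Divide each row by max(||row||_p, eps) instead of applying a softmax."""
    out = []
    for row in A:
        norm = sum(abs(a) ** p for a in row) ** (1.0 / p)
        out.append([a / max(norm, eps) for a in row])
    return out
```

Unlike a softmax, this rescaling keeps negative scores negative and does not force the entries of a row to sum to one.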

Given the geographic nature of this task, we also consider the spatial proximity of locations. We use $A^{g} \in \{0,1\}^{N \times N}$ to indicate the connectivity of locations: $a^{g}_{i,j} = 1$ means that locations $i$ and $j$ are neighbors (by default, each location is adjacent to itself). The correlation between two locations may be affected by their geographic distance: nearby areas may have similar topographic or climatic characteristics that lead to similar flu outbreaks. Non-adjacent areas may also have latent dependencies due to population movement and similar geographical features. Modeling all factors related to a flu outbreak is difficult. Therefore, we consider both the attention derived from historical data and the geographical adjacency of the locations. The final location-aware attention matrix $\hat{A}$ combines the geographical adjacency matrix $A^{g}$ with the attention matrix $A$. The combination uses an element-wise gate $M$ that is learned from the attention matrix and therefore evolves over time. We treat the attention matrix as a feature matrix and adapt the gate from the feature fusion gate [13]:

$$\tilde{A}^{g} = (D^{g})^{-\frac{1}{2}} A^{g} (D^{g})^{-\frac{1}{2}}, \tag{4}$$

$$M = \sigma\left(W^{m} A + b^{m}\right), \tag{5}$$

$$\hat{A} = M \odot \tilde{A}^{g} + (1 - M) \odot A, \tag{6}$$

where Eq. 4 is for normalization, and $D^{g} \in \mathbb{R}^{N \times N}$ is the degree matrix defined as $d^{g}_{i,i} = \sum_{j} a^{g}_{i,j}$. $\sigma$ is the sigmoid function, $\odot$ denotes element-wise multiplication (with $1 - M$ also taken element-wise), and $W^{m} \in \mathbb{R}^{N \times N}$ and $b^{m} \in \mathbb{R}$ are trainable parameters.
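The gated combination can be sketched as follows. The sketch makes three assumptions. First, the adjacency matrix is normalized symmetrically by its degree matrix, and every location counts as its own neighbor, so no degree is zero. Second, the gate is a sigmoid of a learned N × N weight applied to the attention matrix, plus a scalar bias. Third, the gate weights the normalized geography, and its complement weights the learned attention. `Wm` and `bm` are illustrative names for the gate parameters.

```python
import math
from typing import List

Matrix = List[List[float]]


def sym_normalize(Ag: Matrix) -> Matrix:
    """D^{-1/2} Ag D^{-1/2}; degrees are row sums (self-loops keep them > 0)."""
    deg = [sum(row) for row in Ag]
    n = len(Ag)
    return [[Ag[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def fuse(A: Matrix, Ag: Matrix, Wm: Matrix, bm: float) -> Matrix:
    """Gate M = sigmoid(Wm @ A + bm); fused = M * norm(Ag) + (1 - M) * A element-wise."""
    n = len(A)
    Ag_norm = sym_normalize(Ag)
    WA = [[sum(Wm[i][k] * A[k][j] for k in range(n)) for j in range(n)] for i in range(n)]
    fused = []
    for i in range(n):
        row = []
        for j in range(n):
            m = sigmoid(WA[i][j] + bm)  # per-entry gate value in (0, 1)
            row.append(m * Ag_norm[i][j] + (1.0 - m) * A[i][j])
        fused.append(row)
    return fused
```

When the gate saturates at 1, the fused matrix falls back to the normalized geography. When the gate is near 0, it reduces to the learned attention alone.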

3.3 Temporal Convolution Layer

Besides spatial dependencies, influenza outbreaks also have their own characteristic patterns over time. For instance, the United States experiences annual epidemics of seasonal flu. Most of the time, flu activity peaks between December and February, and it can last as late as May. Convolutional Neural Networks (CNNs) have been successful at capturing important local patterns in grid data and sequence data. We apply 1D CNN filters to every row of $X$ to capture the temporal dependency; note that the row $\mathbf{x}_{i,:}$ is the observed sequential data at location $i$. Specifically, we define $K$ filters, where each filter $C_k \in \mathbb{R}^{1 \times T}$ and $T$ is chosen to be the maximum window length in our experiments. The convolution operations yield $H^{C} \in \mathbb{R}^{N \times K}$, where $h_{i,k}$ represents the convolutional value of the $i$-th row vector and the $k$-th filter. Formally, this convolution operation is given by

$$h_{i,k} = C_k \star \mathbf{x}_{i,:} = \sum_{l=1}^{T} C_{k,l}\, x_{i,l}. \tag{7}$$

Max pooling is needed when the filter length is smaller than $T$, as in Kim [16]. To constrain the data, we also apply a nonlinearity to the convolution results. The newly detected temporal feature of each row/location is then $\mathbf{h}^{C}_i = [h_{i,1}, \dots, h_{i,K}] \in \mathbb{R}^{K}$.
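The full-length temporal convolution can be sketched as below. The sketch assumes that each of the K filters spans the whole window, so every "valid" convolution yields a single number and no pooling is needed. It uses tanh as a stand-in for the unspecified nonlinearity. As in common CNN libraries, the filter is applied without flipping, which is technically a cross-correlation.

```python
import math
from typing import List


def temporal_features(X: List[List[float]], filters: List[List[float]]) -> List[List[float]]:
    """Return the N x K matrix of per-location temporal features.

    X: N x T input window (row i = location i); filters: K filters of length T.
    Entry [i][k] = tanh(sum_l filters[k][l] * X[i][l]).
    """
    T = len(X[0])
    if any(len(c) != T for c in filters):
        raise ValueError("this sketch assumes every filter spans the full window T")
    return [[math.tanh(sum(c_l * x_l for c_l, x_l in zip(c, row))) for c in filters]
            for row in X]
```

With K = 10 filters, as in the experiments, each location gets a 10-dimensional feature vector. Shorter filters would need the max pooling step mentioned above.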

3.4 Graph Message Passing (the Propagation Model)

After learning the cross-location attentions (Section 3.2) and the local hidden features (Section 3.3), we design a flu propagation model using graph neural networks. Graph neural networks iteratively update node features using information from each node's neighbors; when generalized to irregular domains, this operation is often called message passing or neighbor aggregation. Epidemic disease propagation at the population level is usually affected by human connectivity and transmission. Treating each location as a node in a graph, we use graph neural networks to model epidemic propagation among locations. The cross-location attention matrix $\hat{A}$ serves as the adjacency matrix, and the temporal convolutional features serve as the initial node features. Let $\mathbf{h}^{(l)}_i$ denote the features of node $i$ in layer $l$, and let $\hat{a}_{i,j}$ denote the location-aware attention coefficient of node $j$ on node $i$. The message passing graph neural network can then be described as

$$\mathbf{h}^{(l)}_i = g\left(\sum_{j \in \mathcal{V}} \hat{a}_{i,j}\, W^{(l-1)} \mathbf{h}^{(l-1)}_j + \mathbf{b}^{(l-1)}\right), \tag{8}$$

where $g$ denotes a nonlinear activation function, $W^{(l-1)} \in \mathbb{R}^{F^{(l)} \times F^{(l-1)}}$ is the weight matrix for hidden layer $l$ with $F^{(l)}$ feature maps, $\mathbf{b}^{(l-1)} \in \mathbb{R}^{F^{(l)}}$ is a bias, and $\mathcal{V}$ is the set of locations. At the first layer, $\mathbf{h}^{(0)}_i$ is initialized with $\mathbf{h}^{C}_i$.
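One message-passing layer can be sketched as below. The sketch assumes ELU for the activation, since the text does not say which nonlinearity this layer uses. Node features are stored row-wise. Stacking two calls mirrors the two graph layers used in the experiments.

```python
import math
from typing import List

Matrix = List[List[float]]


def elu(z: float) -> float:
    return z if z > 0 else math.exp(z) - 1.0


def message_passing_layer(H: Matrix, A_hat: Matrix, W: Matrix, b: List[float]) -> Matrix:
    """One layer: h_i' = elu(sum_j A_hat[i][j] * (W h_j) + b).

    H is N x F_in (row i = node i), W is F_out x F_in, b has length F_out.
    """
    N = len(H)
    WH = [[sum(w * h for w, h in zip(w_row, H[j])) for w_row in W] for j in range(N)]  # W h_j
    out = []
    for i in range(N):
        # aggregate transformed neighbor features, weighted by attention of j on i
        agg = [sum(A_hat[i][j] * WH[j][f] for j in range(N)) for f in range(len(b))]
        out.append([elu(a + bf) for a, bf in zip(agg, b)])
    return out
```

In a two-node example where node 0 weights both nodes equally and node 1 attends only to itself, each extra layer pulls more of node 1's signal into node 0. This is the propagation effect the model is meant to capture.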

3.5 Output Layer (Prediction)

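For each location, we learn the RNN hidden state ($\mathbf{h}_i$) from its own historical sequence data, as well as the graph features ($\mathbf{h}^{(L)}_i$) learned from other locations' data in our propagation model. We combine these two features and feed them to the output layer for prediction, which is defined as:

$$\hat{y}_i = \phi\left(W^{o}\left[\mathbf{h}^{(L)}_i ; \mathbf{h}_i\right] + b^{o}\right), \tag{9}$$

where $[\cdot\,;\cdot]$ denotes concatenation, $\phi$ is the activation function (identity or nonlinear), and $W^{o}$ and $b^{o}$ are model parameters.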
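With the identity activation (the setting reported in the experiments), the output layer reduces to a linear read-out over the concatenated features. A minimal sketch follows; `w_o` and `b_o` are illustrative names for the output weights.

```python
from typing import List


def predict(h_graph: List[float], h_rnn: List[float], w_o: List[float], b_o: float) -> float:
    """y_hat_i = w_o . [h_graph ; h_rnn] + b_o (identity activation)."""
    z = h_graph + h_rnn  # list concatenation: graph features first, then RNN state
    if len(z) != len(w_o):
        raise ValueError("w_o must match the concatenated feature length")
    return sum(w * v for w, v in zip(w_o, z)) + b_o
```

Concatenation lets the read-out weigh a location's own history (the RNN state) separately from what it receives from other locations (the graph features).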

3.6 Optimization

We compare the prediction for each location with the corresponding ground truth and then optimize a regularized $\ell_1$-norm loss:

$$\mathcal{L}(\Theta) = \sum_{i=1}^{N} \sum_{j=1}^{N_{\text{train}}} \left| y_{i,j} - \hat{y}_{i,j} \right| + \lambda\, \mathcal{R}(\Theta), \tag{10}$$

where $N_{\text{train}}$ is the number of samples in location $i$ obtained by a moving window (the same for all locations), $y_{i,j}$ is the true value of location $i$ in sample $j$, and $\hat{y}_{i,j}$ is the model prediction. $\Theta$ stands for all training parameters, and $\mathcal{R}(\Theta)$ is the regularization term (e.g., the $\ell_2$-norm). All model parameters can be trained via back-propagation and optimized with the Adam algorithm [18], given its efficiency and ability to avoid overfitting.
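A minimal sketch of the training objective follows. It assumes an ℓ1 data term and a squared ℓ2 penalty as the regularizer, in the spirit of the weight decay used in the experiments. In practice, the optimizer usually applies weight decay rather than the loss. `lam` is an illustrative name for λ.

```python
from typing import List


def regularized_l1_loss(y_true: List[List[float]], y_pred: List[List[float]],
                        params: List[float], lam: float) -> float:
    """Sum over locations i and samples j of |y_ij - y_hat_ij|, plus lam * ||theta||_2^2.

    y_true, y_pred: N x N_train nested lists; params: all weights flattened.
    """
    data_term = sum(abs(t - p)
                    for row_t, row_p in zip(y_true, y_pred)
                    for t, p in zip(row_t, row_p))
    reg_term = sum(w * w for w in params)
    return data_term + lam * reg_term
```

The ℓ1 term penalizes errors linearly, so a single badly missed peak contributes less to the loss than it would under a squared-error loss.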

4 Experiment Setup

4.1 Datasets

We prepare three real-world datasets for our experiments: Japan-Prefectures, US-States, and US-Regions. Their statistics are shown in Table 1.

  • Japan-Prefectures We collect this data from the Infectious Diseases Weekly Report (IDWR) in Japan. It contains weekly influenza-like illness statistics (patient counts) for 47 prefectures in Japan, from August 2012 to March 2019.

  • US-States We collect the influenza data from the Centers for Disease Control and Prevention (CDC). It contains the weekly count of patient visits for ILI in each state of the United States from 2010 to 2017. After removing one state with missing data, 49 states remain in this dataset.

  • US-Regions This dataset is the ILINet portion of the US-HHS (Department of Health and Human Services) dataset. It consists of weekly influenza activity levels for the 10 HHS regions of the U.S. mainland from 2002 to 2017. Each HHS region is a collection of associated states. We use the flu patient counts for each region, calculated by combining state-specific data.

Data is normalized to the 0-1 range for each region: the maximum value of the region is mapped to 1 and the minimum value to 0. After ordering the data by time, the first 50% is used for training, the next 20% for validation, and the last 30% for testing. The validation data determines the number of training epochs, which prevents overfitting. We fix the validation and test sets by date across lead time values. As a result, the test data covers 2.1, 4.5, and 2.1 flu seasons in Japan-Prefectures, US-Regions, and US-States, respectively, and the three training sets contain at least 3, 7.2, and 3 flu seasons. All data is normalized based on the maximum and minimum values of the training data.
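The chronological split and the training-set min-max scaling described above can be sketched as follows. The function names are hypothetical, and the integer rounding of the split sizes is an assumption.

```python
from typing import List, Tuple


def chronological_split(n_weeks: int, train: float = 0.5,
                        valid: float = 0.2) -> Tuple[range, range, range]:
    """First 50% of weeks for training, next 20% for validation, rest for testing."""
    n_train = int(n_weeks * train)
    n_valid = int(n_weeks * valid)
    return (range(0, n_train),
            range(n_train, n_train + n_valid),
            range(n_train + n_valid, n_weeks))


def minmax_normalize(series: List[List[float]], train_idx: range) -> List[List[float]]:
    """Scale each location with the min/max of its training weeks only."""
    out = []
    for row in series:
        lo = min(row[t] for t in train_idx)
        hi = max(row[t] for t in train_idx)
        scale = (hi - lo) or 1.0  # guard against a constant training segment
        out.append([(v - lo) / scale for v in row])
    return out
```

For Japan-Prefectures (348 weeks), this gives 174 training, 69 validation, and 105 test weeks. Test values may fall outside [0, 1] because the scaling statistics come from the training weeks only.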

Dataset            Size (N × weeks)  Min  Max    Mean  SD
Japan-Prefectures  47 × 348          0    26635  655   1711
US-Regions         10 × 785          0    16526  1009  1351
US-States          49 × 360          0    9716   223   428
Table 1: Dataset statistics: min, max, mean, and standard deviation (SD) of patient counts. The dataset size is the number of locations (N) multiplied by the number of weeks.

4.2 Evaluation Metrics

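In the experiments, we adopt the following evaluation metrics. Denote the predicted and true values by $\hat{y}_1, \dots, \hat{y}_n$ and $y_1, \dots, y_n$, respectively. We do not distinguish regions in evaluation.

The Root Mean Squared Error (RMSE) measures the difference between predicted and true values after projecting the normalized values back into the real range:

$$\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2}$$

The Mean Absolute Error (MAE) measures the difference between two continuous variables:

$$\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n}\left|y_i - \hat{y}_i\right|$$

The Pearson Correlation Coefficient (PCC) measures the linear dependence between two variables:

$$\mathrm{PCC} = \frac{\sum_{i=1}^{n}\left(\hat{y}_i - \bar{\hat{y}}\right)\left(y_i - \bar{y}\right)}{\sqrt{\sum_{i=1}^{n}\left(\hat{y}_i - \bar{\hat{y}}\right)^2}\,\sqrt{\sum_{i=1}^{n}\left(y_i - \bar{y}\right)^2}}$$

where $\bar{y}$ and $\bar{\hat{y}}$ are the means of the true and predicted values.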
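The three metrics can be computed in pure Python as below. Because regions are not distinguished, all locations' predictions are assumed to be flattened into one list before the call. `denormalize` is a hypothetical helper that inverts the min-max scaling before computing RMSE in patient counts.

```python
import math
from typing import List


def denormalize(values: List[float], lo: float, hi: float) -> List[float]:
    """Project [0, 1]-scaled values back to patient counts (inverse min-max)."""
    return [v * (hi - lo) + lo for v in values]


def rmse(y: List[float], y_hat: List[float]) -> float:
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(y, y_hat)) / len(y))


def mae(y: List[float], y_hat: List[float]) -> float:
    return sum(abs(a - b) for a, b in zip(y, y_hat)) / len(y)


def pcc(y: List[float], y_hat: List[float]) -> float:
    my, mp = sum(y) / len(y), sum(y_hat) / len(y_hat)
    cov = sum((a - my) * (b - mp) for a, b in zip(y, y_hat))
    sy = math.sqrt(sum((a - my) ** 2 for a in y))
    sp = math.sqrt(sum((b - mp) ** 2 for b in y_hat))
    return cov / (sy * sp)
```

PCC is invariant to positive affine rescaling of the predictions and the ground truth, so it is the same on normalized and de-normalized values. RMSE and MAE change with the scale, which is why RMSE is reported after projecting back to patient counts.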

Leadtime is the number of weeks that the model predicts in advance. For instance, if we use $\mathbf{x}_{t-T+1}, \dots, \mathbf{x}_t$ as input and predict the number of infected patients in the fifth week (leadtime = 5) after the current week $t$, the ground truth (expected output) is $\mathbf{x}_{t+5}$.

4.3 Comparison Methods

We compare our model with several state-of-the-art methods and their variants listed as below.

  • Autoregressive (AR) Autoregressive models have been widely applied for time series forecasting [4, 30]. Basically, the future state is modeled as a linear combination of past data points. We train an autoregressive model for each location. No data and parameters are shared among locations.

  • Global Autoregression (GAR) This model is mainly used when training data is limited. We train a single global model on the data pooled from all locations.

  • Vector Autoregression (VAR) VAR models cross-signal dependence to address a potential drawback of the AR model, namely that the signal sources are processed independently of each other. As a result, it introduces more parameters and is more expensive to train.

  • Autoregressive Moving Average (ARMA) ARMA combines autoregressive terms and moving-average terms. A considerable amount of preprocessing has to be performed before fitting such a model. The order of the moving average is set to 2 in our implementation.

  • Recurrent Neural Network (RNN) RNNs have demonstrated a powerful ability to model temporal dependencies. We employ a global RNN for our problem; that is, parameters are shared across regions. The RNN can be replaced by a GRU or LSTM, but in our experiments these more complex variants did not achieve better results, so we only consider the simple RNN for comparison.

  • RNN+Attn [6] This model adds a self-attention mechanism to a global RNN. When the RNN units are computed, the hidden state is replaced by a summary vector that uses attention to aggregate the information from all previous hidden states.

  • CNNRNN-Res [31] A deep learning framework that combines CNN, RNN and residual links to solve epidemiological prediction problems.

  • GCNRNN-Res A variation of CNNRNN-Res in which the CNN module is replaced by a GCN [19] module with two hidden layers whose feature dimensions remain unchanged. We use the given geographical adjacency matrix.
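Counting parameters makes the size difference among the autoregressive baselines concrete. The sketch below assumes an order-p model with one intercept per output and N locations. The order p = 20 used in the check is only an illustration, not a setting reported here.

```python
from typing import Dict


def ar_param_counts(n_locations: int, order: int) -> Dict[str, int]:
    """Parameter counts for order-p autoregressive baselines (with intercepts).

    AR:  one independent model per location -> N * (p + 1)
    GAR: a single model shared by all locations -> p + 1
    VAR: each location regresses on all locations' lags -> N * (N * p + 1)
    """
    return {
        "AR": n_locations * (order + 1),
        "GAR": order + 1,
        "VAR": n_locations * (n_locations * order + 1),
    }
```

For 47 prefectures and p = 20, this gives 987 (AR), 21 (GAR), and 44,227 (VAR) parameters. This quadratic growth in N is why VAR is the most data-hungry of the three.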

Hyper-parameter Setting & Implementation Details. In our model, we use the exponential linear unit (ELU) [9] as the nonlinearity $g$ in Eq. 2 and the identity for $\phi$ in Eq. 9. The input window size is 20 weeks, which spans roughly five months. The attention dimension hyperparameter $d_a$ in the location-aware attention is set to … to reduce the number of parameters compared to standard additive attention. The order $p$ of the norm in Eq. 3 is set to 2, and $\epsilon$ is 1e-12. The number of filters $K$ in Eq. 7 is 10. For all methods using the RNN module, we tune the hidden dimension of the RNN over {10, 20, 30}; 20 yields the best performance in most cases. The numbers of RNN hidden layers and graph layers are set to 1 and 2, respectively. During training, the best models are selected by early stopping when the validation accuracy does not improve for 200 consecutive epochs, and the maximum number of epochs is 1500. All parameters are initialized with Glorot initialization [12] and trained using the Adam [17] optimizer with a weight decay of 5e-4 and a dropout rate of 0.2. The initial learning rate of all methods is searched over {0.001, 0.005, 0.01}. The batch size is 32 for all datasets. All experimental results are averages over 10 randomized trials.

Suppose the dimension of the weight matrices in graph message passing is set to $F$; the number of parameters of the proposed model is then …. In our epidemiological prediction problems, … and … are relatively small.
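The sketch below gives a rough, module-by-module picture of where the parameters sit. It is a hypothetical count under the following assumptions:

- a vanilla RNN with scalar input and hidden size D;
- additive attention with two d_a × D projections;
- an N × N gate weight with a scalar bias;
- K bias-free filters of length T;
- two graph layers of width F;
- a single-output linear read-out.

These structural choices, and the example sizes in the check (F = 10, d_a = 10), are illustrative assumptions, not the authors' exact count.

```python
from typing import Dict


def cola_gnn_param_count(N: int, T: int, D: int, d_a: int, K: int, F: int) -> Dict[str, int]:
    """Hypothetical module-by-module parameter count under the stated assumptions."""
    counts = {
        "rnn": D * D + 2 * D,                       # U (D x D), input weight w (D), bias b (D)
        "attention": 2 * d_a * D + 2 * d_a + 1,     # Ws, Wt (d_a x D), v and b^s (d_a), scalar b^v
        "gate": N * N + 1,                          # W^m (N x N), scalar b^m
        "temporal_conv": K * T,                     # K filters of length T, no bias
        "graph_layers": (F * K + F) + (F * F + F),  # two message-passing layers
        "output": (F + D) + 1,                      # read-out over [graph ; RNN] features
    }
    counts["total"] = sum(counts.values())
    return counts
```

Under these assumptions, with N = 47, T = 20, D = 20, d_a = 10, K = 10, and F = 10, the N × N gate weight is the largest single block (2,210 of 3,522 parameters). The count is therefore driven mainly by N, D, and F, all of which are small in these datasets.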

5 Results

Japan-Prefectures US-Regions US-States
RMSE (↓) 2 3 4 2 3 4 2 3 4
GAR 1232 1628 1865 536 715 859 150 187 213
AR 1377 1705 1901 570 757 888 161 204 231
VAR 1361 1711 1910 741 870 967 290 276 283
ARMA 1371 1703 1902 560 742 874 161 200 228
RNN 1001 1259 1366 513 689 805 149 181 204
RNN+Attn 1166 1572 1706 613 753 962 152 186 210
CNNRNN-Res 1133 1550 1795 571 738 802 205 239 253
GCNRNN-Res 1031 1129 1133 736 847 935 194 210 236
Cola-GNN 919 1060 1072 483 633 765 136 167 191
PCC (↑) 2 3 4 2 3 4 2 3 4
GAR 0.804 0.626 0.461 0.932 0.881 0.835 0.945 0.914 0.893
AR 0.752 0.579 0.428 0.927 0.878 0.834 0.94 0.909 0.885
VAR 0.754 0.585 0.419 0.859 0.797 0.741 0.765 0.79 0.78
ARMA 0.754 0.579 0.428 0.927 0.876 0.833 0.939 0.909 0.886
RNN 0.892 0.833 0.813 0.94 0.895 0.855 0.948 0.922 0.9
RNN+Attn 0.85 0.668 0.604 0.887 0.859 0.774 0.947 0.922 0.903
CNNRNN-Res 0.852 0.673 0.513 0.92 0.862 0.829 0.904 0.86 0.842
GCNRNN-Res 0.893 0.889 0.886 0.871 0.831 0.796 0.903 0.884 0.854
Cola-GNN 0.911 0.893 0.894 0.944 0.905 0.863 0.955 0.933 0.907
Table 2: RMSE and PCC of different methods on the three datasets with leadtime = 2, 3, 4 (short-term). Lower RMSE (↓) and higher PCC (↑) are better.
Japan-Prefectures US-Regions US-States
RMSE (↓) 5 10 15 5 10 15 5 10 15
GAR 1988 2065 2016 991 1377 1465 236 314 340
AR 2013 2107 2042 997 1330 1404 251 306 327
VAR 2025 1942 1899 1059 1270 1299 295 324 352
ARMA 2013 2105 2041 989 1322 1400 250 306 326
RNN 1376 1696 1629 896 1328 1434 217 274 315
RNN+Attn 1746 1612 1823 1065 1367 1368 234 315 334
CNNRNN-Res 1942 1865 1862 936 1233 1285 267 260 250
GCNRNN-Res 1178 1384 1457 1051 1298 1402 248 275 288
Cola-GNN 1156 1403 1500 871 1126 1218 202 241 232
PCC (↑) 5 10 15 5 10 15 5 10 15
GAR 0.339 0.288 0.47 0.79 0.581 0.485 0.875 0.777 0.742
AR 0.310 0.238 0.483 0.792 0.612 0.527 0.863 0.773 0.723
VAR 0.3 0.426 0.474 0.685 0.508 0.467 0.758 0.709 0.6529
ARMA 0.31 0.253 0.486 0.792 0.614 0.52 0.862 0.773 0.725
RNN 0.821 0.616 0.709 0.821 0.587 0.499 0.886 0.821 0.758
RNN+Attn 0.59 0.741 0.522 0.752 0.554 0.552 0.884 0.78 0.739
CNNRNN-Res 0.38 0.438 0.467 0.782 0.552 0.4851 0.822 0.82 0.847
GCNRNN-Res 0.875 0.823 0.774 0.739 0.554 0.4471 0.844 0.814 0.814
Cola-GNN 0.883 0.818 0.754 0.832 0.719 0.639 0.897 0.822 0.859
Table 3: RMSE and PCC of different methods on the three datasets with leadtime = 5, 10, 15 (long-term). Lower RMSE (↓) and higher PCC (↑) are better.
Figure 2: PCC of the flu prediction models with different leadtimes on three datasets.
Figure 3: PCC of the flu prediction models with different leadtimes on three datasets.

5.1 Prediction Performance

We evaluate our approach in short-term (leadtime = 2, 3, 4) and long-term (leadtime = 5, 10, 15) settings. We omit leadtime = 1 because symptom monitoring data is usually delayed by at least one week. Table 2 summarizes the RMSE and PCC of all methods in the short-term settings. When the lead time is relatively small, our method achieves the best and most stable performance on all datasets. In this regime, most methods perform relatively well on all three datasets. The information gap between the history window and the prediction time is small, so the models can fit the temporal pattern more easily. The one exception is the Japan-Prefectures dataset, where the results of most baseline methods deteriorate even with a slight increase in lead time. A possible reason is that the seasonal influenza curve in this dataset is noisier and less predictable. The statistics in Table 1 also show that Japan-Prefectures has the largest standard deviation.

Table 3 reports the RMSE and PCC results in the long-term settings. Overall, the proposed method achieves the best performance on most datasets with long lead times (leadtime = 5, 10, or 15 weeks). Autoregressive models perform poorly, especially VAR, which has the largest number of model parameters. This suggests that controlling model complexity is important when data are insufficient. Recurrent neural network models achieve good predictive performance only when the lead time is small, which indicates that long-term prediction requires a better design to capture spatial and temporal dependencies. CNNRNN-Res uses geographic location information, but it performs well only on the US-States dataset; on the Japan-Prefectures and US-Regions datasets, it performs poorly with long lead times. Its variant GCNRNN-Res contains a graph convolutional module that learns features from adjacent regions, and it achieves good results on the Japan-Prefectures and US-States datasets. This suggests that the graph convolution module can help capture long-term dependencies. However, the performance of CNNRNN-Res and GCNRNN-Res is unstable across the three datasets and often shows large variance over multiple training runs. To better visualize the results, Figure 2 and Figure 3 show the mean and standard deviation over 10 runs for some of the models.

Looking at the big picture, the performance difference among all methods is relatively small when the lead time is 2. As the lead time increases, however, the predictive power of simple methods (such as autoregressive models) decreases significantly. This suggests that modeling temporal dependence is challenging when a relatively large gap exists between the historical window and the prediction time.

5.2 Case Studies

To evaluate the long-term predictive performance of the proposed model, we plot a sequence of predictions on the test set with a lead time of 15. We select four of the stronger baselines for comparison on the three datasets, as shown in Figure 4. We randomly select three locations from each dataset. Even though we use a relatively small window (20 weeks) to predict long-term flu counts (leadtime = 15), our model better captures the trend and the timing of the epidemic outbreak.

Figure 4: Iterative prediction results when leadtime = 15. We test the models trained in leadtime = 15 by moving the history window.

We then fix the input window and plot the prediction curve for lead times from 1 to 20, again randomly sampling three locations from each dataset. As Figure 5 shows, our model tends to capture future peaks and trends from the given historical data.

Figure 5: Direct prediction curves with fixed input windows. We fix the input window and test the models trained with a lead time from 1 to 20.

5.3 Attention Visualization

Figure 6 shows an example of the location-aware attention mechanism with a lead time of 15 on the US-Regions dataset, focusing on region 5. We visualize the input data of region 5 together with the two regions that receive the highest attention values for region 5 (regions 3 and 4). We also show the two regions with the lowest attention values (regions 1 and 8). We normalize the data by region to better compare flu outbreaks across regions. The light yellow shaded period is the input sequence (window = 20), and the vertical line indicates the prediction time. Only a small part of each region's sequence is used to predict region 5's epidemic outbreak 15 weeks ahead. The regions with higher attention values share region 5's early outbreak timing, while the regions with lower attention values have later outbreaks.

Figure 7(a) shows the normalized geolocation distance matrix, calculated according to Eq. 4. Figure 7(b) shows the Pearson correlation coefficients of the input time series. The learned attention matrix (Figure 7(c)) uses geolocation information as well as additive attention among regions. In the learned attention matrix, adjacent regions sometimes receive higher attention values. Meanwhile, non-adjacent regions can also receive high attention values when they share similar long-term influenza trends. The learned attention thus reveals hidden dynamics among regions, such as the timing of epidemic outbreaks and peaks.

Figure 6: An example of the location-aware attention mechanism with a lead time of 15 in the US-Regions dataset. Yellow line indicates prediction time. Shaded area is the input.
Figure 7: Comparison of the original geolocation matrix, the input correlation matrix, and the learned attention matrix (US-Regions).

5.4 Ablation Tests

RMSE (lower is better)
Leadtime                2       3       4       5       10      15
Cola-GNN w/o            911     1115    1204    1310    1388    1517
Cola-GNN w/o            942     1154    1164    1195    1473    1576
Cola-GNN                919     1060    1072    1156    1403    1500

Cola-GNN w/o            485     662     772     888     1144    1228
Cola-GNN w/o            499     666     782     890     1179    1292
Cola-GNN                483     633     765     871     1126    1218

Cola-GNN w/o            138     169     188     194     251     251
Cola-GNN w/o            138     169     193     202     246     246
Cola-GNN                136     167     191     202     241     232

PCC (higher is better)
Leadtime                2       3       4       5       10      15
Cola-GNN w/o            0.91    0.867   0.846   0.818   0.793   0.744
Cola-GNN w/o            0.914   0.881   0.89    0.88    0.781   0.727
Cola-GNN                0.911   0.893   0.894   0.883   0.818   0.754

Cola-GNN w/o            0.944   0.902   0.861   0.824   0.712   0.588
Cola-GNN w/o            0.942   0.898   0.858   0.824   0.682   0.582
Cola-GNN                0.944   0.905   0.863   0.832   0.719   0.639

Cola-GNN w/o            0.953   0.93    0.908   0.908   0.833   0.836
Cola-GNN w/o            0.955   0.931   0.913   0.904   0.856   0.855
Cola-GNN                0.955   0.933   0.907   0.897   0.822   0.859
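Table 4: Ablation test results in RMSE (top) and PCC (bottom) for leadtime = 2, 3, 4, 5, 10, 15 on the three datasets. Each dataset occupies a block of three rows: the two ablated variants followed by the full Cola-GNN.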

To analyze the effect of each component of our framework, we perform ablation tests on all datasets with the following settings:

  • Cola-GNN w/o : Remove the temporal convolution module from the proposed model and use the raw time series input as features in graph message passing.

  • Cola-GNN w/o : Remove the location-aware attention module and directly use the geographical adjacency matrix, which defines the spatial distance between pairs of locations.
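Each ablation swaps one input of the message-passing step. The sketch below is a generic single message-passing step, not the paper's exact equations: each location aggregates the features of all locations, weighted by a row-normalized location-by-location matrix, followed by a linear projection and a tanh activation (the normalization, projection, and activation are our assumptions). Passing raw input windows as `features` mirrors the first ablation; passing the geographic adjacency matrix as `weights` instead of a learned attention matrix mirrors the second. All names and toy values are hypothetical.

```python
import numpy as np


def row_normalize(m: np.ndarray) -> np.ndarray:
    """Scale rows to sum to 1 so each location takes a weighted average."""
    s = m.sum(axis=1, keepdims=True)
    return m / np.where(s == 0, 1.0, s)


def message_pass(weights: np.ndarray, features: np.ndarray, W: np.ndarray) -> np.ndarray:
    """One generic message-passing step (illustrative, not Cola-GNN's exact form).

    weights:  (N, N) location-to-location matrix, e.g. a learned attention
              matrix or the geographic adjacency matrix
    features: (N, F) per-location features, e.g. temporal-convolution outputs
              or raw input windows
    W:        (F, D) projection matrix
    """
    return np.tanh(row_normalize(weights) @ features @ W)


rng = np.random.default_rng(0)
N, F, D = 3, 4, 2
raw_windows = rng.normal(size=(N, F))     # stand-in for raw input windows
attention = rng.uniform(size=(N, N))      # stand-in for a learned attention matrix
adjacency = np.array([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=float)
W = rng.normal(size=(F, D))

with_attention = message_pass(attention, raw_windows, W)
with_adjacency = message_pass(adjacency, raw_windows, W)   # adjacency replaces attention
```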

The RMSE and PCC results are shown in Table 4. In most cases, the ablated variants of the proposed method still perform well. On the US-States dataset, the variants without the temporal or the location-aware attention module are sometimes slightly better than the full model; this dataset has the lowest number of reported influenza cases of the three datasets, and its standard deviation is small. Overall, the full model performs best across the datasets. Because all datasets are relatively small, adding parameters could hurt performance through overfitting. Nevertheless, adding the temporal and spatial modules changes short-term (leadtime = 2, 3, 4) predictions only slightly, and for long-term predictions (leadtime = 15) including both modules produces better results.
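The two metrics in Table 4 can be computed as below. How predictions are pooled is not stated in this section, so we assume both metrics are computed over all locations and test weeks at once; the original evaluation may instead average per location. The last lines only re-read the leadtime-15 RMSE column of Table 4 to quantify the long-term gain discussed above.

```python
import numpy as np


def rmse(y_true, y_pred) -> float:
    """Root mean squared error over all locations and test weeks."""
    diff = np.asarray(y_true, float) - np.asarray(y_pred, float)
    return float(np.sqrt(np.mean(diff ** 2)))


def pcc(y_true, y_pred) -> float:
    """Pearson correlation between the flattened targets and predictions."""
    a = np.asarray(y_true, float).ravel()
    b = np.asarray(y_pred, float).ravel()
    a, b = a - a.mean(), b - b.mean()
    return float(a @ b / np.sqrt((a @ a) * (b @ b)))


# Leadtime-15 RMSE values copied from Table 4, one tuple per dataset block in
# table order: (first ablated variant, second ablated variant, full Cola-GNN).
rmse_lead15 = [(1517, 1576, 1500), (1228, 1292, 1218), (251, 246, 232)]
gains = [100 * (min(a, b) - full) / min(a, b) for a, b, full in rmse_lead15]
for g in gains:
    print(f"full model RMSE is {g:.1f}% below the better ablated variant")
```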

Figure 8: Sensitivity analysis on window size.

5.5 Sensitivity Analysis

In this section, we investigate how the prediction performance varies with some hyperparameters.

Size of History Windows

To test whether our model is sensitive to the length of the historical data, we evaluate window sizes from 10 to 50 in steps of 5 on the US-Regions and US-States datasets, as shown in Figure 8. Predictive performance in RMSE and MAE remains fairly stable across window sizes, so we can avoid training on very long sequences while achieving comparable results.
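The window-size experiment changes how many past weeks form each training input. A minimal sketch of building input/target pairs from a (weeks × locations) array, assuming that a leadtime of h means the target lies h weeks after the last input week; `make_windows` and the toy data are hypothetical.

```python
import numpy as np


def make_windows(series: np.ndarray, window: int, leadtime: int):
    """Build (input, target) pairs from a (T, N) array of weekly values.

    series:   T time steps x N locations
    window:   number of past weeks used as input
    leadtime: weeks between the last input week and the target (1 = next week)
    Returns X with shape (S, window, N) and y with shape (S, N).
    """
    num_samples = series.shape[0] - window - leadtime + 1
    if num_samples <= 0:
        raise ValueError("series too short for this window/leadtime")
    X = np.stack([series[s:s + window] for s in range(num_samples)])
    # Sample s ends at week s + window - 1; its target is `leadtime` weeks later.
    y = np.stack([series[s + window + leadtime - 1] for s in range(num_samples)])
    return X, y


# Window sizes 10, 15, ..., 50 on a toy series of 100 weeks x 2 locations.
toy = np.arange(200, dtype=float).reshape(100, 2)
sample_counts = {w: make_windows(toy, window=w, leadtime=15)[0].shape[0]
                 for w in range(10, 55, 5)}
```

Longer windows leave fewer training samples from the same history, which is one practical reason to prefer shorter windows when performance is comparable.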

Size of Graph Features

We learn RNN hidden states from the historical sequence data, and graph features that incorporate the features of other regions through message passing over the location-aware attentions. We vary the dimension of the graph features from 1 to 15 and evaluate predictive performance on the US-States dataset with leadtime = 15. Figure 9 reports the RMSE and MAE results. Small feature dimensions lead to poor predictive performance, likely because of their limited encoding power, and performance improves as the feature dimension grows.

Figure 9: Sensitivity analysis of graph feature size.

RNN Modules

The RNN module outputs a hidden state vector for each location based on the given historical data, and this vector is then provided to the location-aware attention module. We replaced the RNN module with GRU and LSTM to assess their impact on model performance. Figure 10 shows RMSE results for leadtime = 2, 5, 10, 15 on the US-Regions and US-States datasets. GRU and LSTM do not outperform the simple RNN, likely because they involve more model parameters and tend to overfit the small epidemiological datasets.
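The parameter argument can be made concrete by counting the parameters of a single recurrent layer. Under PyTorch's convention for `nn.RNN`, `nn.GRU`, and `nn.LSTM` (an input-to-hidden weight, a hidden-to-hidden weight, and two bias vectors per gate block), a GRU has three times and an LSTM four times the parameters of a vanilla RNN of the same size. The sizes below (one input feature, 20 hidden units) are hypothetical and not the configuration used in the experiments.

```python
def recurrent_params(n_in: int, n_hidden: int, gate_blocks: int) -> int:
    """Parameter count of one recurrent layer.

    Each gate block has an (n_hidden x n_in) input weight, an
    (n_hidden x n_hidden) recurrent weight, and two bias vectors of length
    n_hidden. gate_blocks is 1 for a vanilla RNN, 3 for a GRU, 4 for an LSTM.
    """
    per_block = n_hidden * n_in + n_hidden * n_hidden + 2 * n_hidden
    return gate_blocks * per_block


counts = {name: recurrent_params(n_in=1, n_hidden=20, gate_blocks=g)
          for name, g in [("RNN", 1), ("GRU", 3), ("LSTM", 4)]}
print(counts)  # {'RNN': 460, 'GRU': 1380, 'LSTM': 1840}
```

At this size the GRU and LSTM add roughly 900 and 1,400 parameters on top of the RNN's 460, a large relative increase for small datasets.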

Figure 10: Sensitivity analysis on RNN modules.

5.6 Model Complexity

Architecture Parameters Runtime(s)
GAR 21 0.01
AR 1,029 0.02
VAR 48,069 0.02
ARMA 1,960 0.03
RNN 481 0.04
RNN+Attn 1,321 0.58
CNNRNN-Res 7,695 0.04
GCNRNN-Res 6,214 0.04
Cola-GNN 3,778 0.21
Table 5: Runtime comparison of models on the US-States dataset. Runtime is the time spent on a single GPU per epoch.

Table 5 compares the runtimes and parameter counts of all models on the US-States dataset, which has the largest number of regions among the three datasets. All methods train quickly because the datasets are small and because we use only influenza data and geographic location data, without other external features. Cola-GNN takes 0.21 s per epoch, slower than every baseline except RNN+Attn (0.58 s), so it does not substantially affect training efficiency in practice. Its parameter count (3,778) is also smaller than those of VAR, CNNRNN-Res, and GCNRNN-Res, which helps limit overfitting.

6 Conclusion

In this work, we propose a graph-based deep learning framework with cross-location attentions to model spatio-temporal influence for long-term epidemiological prediction, and we demonstrate the effectiveness of the proposed model on real-world epidemiological datasets. One limitation is flexibility: a separate model must be trained for each lead time setting. Future work will consider iterative predictions to increase the flexibility of the model. Another research direction is to incorporate more complex dependencies such as weather, social factors, and population migration, and to determine whether external indicators improve prediction accuracy. Furthermore, it is essential to identify the main factors affecting the epidemic outbreak in one area by learning multiple areas simultaneously.
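Iterative (recursive) prediction trains a single one-step model and feeds its own predictions back as inputs, so one model covers every lead time. A minimal sketch of that general strategy, using a hypothetical linear least-squares one-step predictor on a synthetic seasonal series rather than Cola-GNN; the function names are illustrative only.

```python
import numpy as np


def fit_one_step(series: np.ndarray, window: int) -> np.ndarray:
    """Least-squares linear model predicting the next week from `window` lags."""
    n = len(series) - window
    X = np.stack([series[s:s + window] for s in range(n)])
    X = np.hstack([X, np.ones((n, 1))])   # append a bias column
    y = series[window:]                   # target of sample s is series[s + window]
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef


def iterative_forecast(coef: np.ndarray, last_window: np.ndarray, horizon: int) -> np.ndarray:
    """Roll the one-step model forward, feeding each prediction back as input."""
    window = len(last_window)
    buf = list(last_window)
    preds = []
    for _ in range(horizon):
        x = np.append(buf[-window:], 1.0)
        nxt = float(x @ coef)
        preds.append(nxt)
        buf.append(nxt)
    return np.array(preds)


# Synthetic, noise-free seasonal signal with a 52-week period (illustration only).
t = np.arange(300)
series = np.sin(2 * np.pi * t / 52)
coef = fit_one_step(series, window=20)
preds = iterative_forecast(coef, series[-20:], horizon=20)   # weeks 300..319
```

The trade-off is that errors made at early steps propagate into later ones, whereas training one model per lead time avoids this at the cost of maintaining many models.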


References

  • [1] H. Achrekar, A. Gandhe, R. Lazarus, S. Yu, and B. Liu (2011) Predicting flu trends using twitter data. In 2011 IEEE conference on computer communications workshops (INFOCOM WKSHPS), pp. 702–707.
  • [2] D. Bahdanau, K. Cho, and Y. Bengio (2014) Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
  • [3] K. R. Bisset, J. Chen, X. Feng, V. Kumar, and M. V. Marathe (2009) EpiFast: a fast algorithm for large scale realistic epidemic simulations on distributed memory systems. In Proceedings of the 23rd international conference on Supercomputing, pp. 430–439.
  • [4] J. S. Brownstein, S. Chu, A. Marathe, M. V. Marathe, A. T. Nguyen, D. Paolotti, N. Perra, D. Perrotta, M. Santillana, S. Swarup, et al. (2017) Combining participatory influenza surveillance with modeling and forecasting: three alternative approaches. JMIR public health and surveillance 3 (4), pp. e83.
  • [5] P. Chakraborty, P. Khadivi, B. Lewis, A. Mahendiran, J. Chen, P. Butler, E. O. Nsoesie, S. R. Mekaru, J. S. Brownstein, M. V. Marathe, et al. (2014) Forecasting a moving target: ensemble models for ili case count predictions. In Proceedings of the 2014 SIAM international conference on data mining, pp. 262–270.
  • [6] J. Cheng, L. Dong, and M. Lapata (2016) Long short-term memory-networks for machine reading. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, Austin, Texas, pp. 551–561.
  • [7] K. Cho, B. van Merriënboer, C. Gulcehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio (2014) Learning phrase representations using RNN encoder–decoder for statistical machine translation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), Doha, Qatar, pp. 1724–1734.
  • [8] G. Chowell, M. Miller, and C. Viboud (2008) Seasonal influenza in the united states, france, and australia: transmission and prospects for control. Epidemiology & Infection 136 (6), pp. 852–864.
  • [9] D. Clevert, T. Unterthiner, and S. Hochreiter (2015) Fast and accurate deep network learning by exponential linear units (elus). In Proceedings of the 2015 International Conference on Learning Representations, Vol. abs/1511.07289.
  • [10] B. Du, W. Xu, B. Song, Q. Ding, and S. Chu (2014) Prediction of chaotic time series of rbf neural network based on particle swarm optimization. In Intelligent Data analysis and its Applications, Volume II, pp. 489–497.
  • [11] A. F. Dugas, M. Jalalpour, Y. Gel, S. Levin, F. Torcaso, T. Igusa, and R. E. Rothman (2013) Influenza forecasting with google flu trends. PloS one 8 (2), pp. e56176.
  • [12] X. Glorot and Y. Bengio (2010) Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the thirteenth international conference on artificial intelligence and statistics, pp. 249–256.
  • [13] Y. Gong and S. Bowman (2018) Ruminating reader: reasoning with gated multi-hop attention. In Proceedings of the Workshop on Machine Reading for Question Answering, Melbourne, Australia, pp. 1–11.
  • [14] S. Hochreiter and J. Schmidhuber (1997) Long short-term memory. Neural computation 9 (8), pp. 1735–1780.
  • [15] W. O. Kermack and A. G. McKendrick (1927) A contribution to the mathematical theory of epidemics. Proceedings of the royal society of london. Series A, Containing papers of a mathematical and physical character 115 (772), pp. 700–721.
  • [16] Y. Kim (2014) Convolutional neural networks for sentence classification. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), Doha, Qatar, pp. 1746–1751.
  • [17] D. P. Kingma and J. Ba (2015) Adam: a method for stochastic optimization. In International Conference on Learning Representations.
  • [18] D. P. Kingma and J. Ba (2015) Adam: a method for stochastic optimization. In Proceedings of the 2015 International Conference on Learning Representations, Vol. abs/1412.6980.
  • [19] T. N. Kipf and M. Welling (2016) Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907.
  • [20] M. Qian-Li, Z. Qi-Lun, P. Hong, Z. Tan-Wei, and Q. Jiang-Wei (2008) Multi-step-prediction of chaotic time series based on co-evolutionary recurrent neural network. Chinese Physics B 17 (2), pp. 536.
  • [21] J. C. Santos and S. Matos (2014) Analysing twitter and web queries for flu trend prediction. Theoretical Biology and Medical Modelling 11 (1), pp. S6.
  • [22] R. Senanayake, S. O’Callaghan, and F. Ramos (2016) Predicting spatio-temporal propagation of seasonal influenza using variational gaussian process regression. In Thirtieth AAAI Conference on Artificial Intelligence.
  • [23] A. Sorjamaa, J. Hao, N. Reyhani, Y. Ji, and A. Lendasse (2007) Methodology for long-term prediction of time series. Neurocomputing 70 (16-18), pp. 2861–2869.
  • [24] S. Sukhbaatar, J. Weston, R. Fergus, et al. (2015) End-to-end memory networks. In Advances in neural information processing systems, pp. 2440–2448.
  • [25] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In Advances in neural information processing systems, pp. 5998–6008.
  • [26] S. R. Venna, A. Tavanaei, R. N. Gottumukkala, V. V. Raghavan, A. S. Maida, and S. Nichols (2018) A novel data-driven model for real-time influenza forecasting. IEEE Access 7, pp. 7691–7701.
  • [27] C. Viboud, P. Boëlle, F. Carrat, A. Valleron, and A. Flahault (2003) Prediction of the spread of influenza epidemics by the method of analogues. American Journal of Epidemiology 158 (10), pp. 996–1006.
  • [28] L. A. Waller, B. P. Carlin, H. Xia, and A. E. Gelfand (1997) Hierarchical spatio-temporal mapping of disease rates. Journal of the American Statistical association 92 (438), pp. 607–617.
  • [29] L. Wang, J. Chen, and M. Marathe (2019) DEFSI: deep learning based epidemic forecasting with synthetic information. Vol. 33, pp. 9607–9612.
  • [30] Z. Wang, P. Chakraborty, S. R. Mekaru, J. S. Brownstein, J. Ye, and N. Ramakrishnan (2015) Dynamic poisson autoregression for influenza-like-illness case count prediction. In Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1285–1294.
  • [31] Y. Wu, Y. Yang, H. Nishiura, and M. Saitoh (2018) Deep learning for epidemiological predictions. In The 41st International ACM SIGIR Conference on Research & Development in Information Retrieval, pp. 1085–1088.
  • [32] W. Yang, A. Karspeck, and J. Shaman (2014) Comparison of filtering methods for the modeling and retrospective forecasting of influenza epidemics. PLoS computational biology 10 (4), pp. e1003583.