1 Introduction
Epidemic disease propagation that involves large populations and wide areas can have a significant impact on society. The Centers for Disease Control and Prevention (CDC) estimates that 79,400 deaths from influenza occurred during the 2017–2018 season in the United States (https://tinyurl.com/y3tf8ebl). Early forecasting of infectious diseases such as influenza-like illness (ILI) provides optimal opportunities for timely intervention and resource allocation. It also helps health care departments prepare corresponding vaccines in time, which reduces the financial burden. For instance, the World Health Organization (WHO) reports that Australia spent over 352 million dollars on routine immunization in the 2017 fiscal year (https://tinyurl.com/y2duz5p8). We focus on the problem of long-term ILI forecasting with lead times from 1 to 20 weeks, based on influenza surveillance data collected for multiple locations (states and regions). Given the process of data collection and surveillance lag, accurate statistics for influenza warning systems are often delayed by a few weeks, making early prediction imperative. However, long-term epidemic forecasting poses several challenges. First, temporal dependency is hard to capture from short-term input data; without manually added seasonal trends, most statistical models fail to provide high accuracy. Second, the influence of other locations has not been exhaustively explored with limited data input. Spatiotemporal effects have been studied, but existing approaches usually require adequate data sources to achieve good performance [22].

Existing work on epidemic prediction has focused on various aspects: 1) Traditional causal models [15, 8, 3], including compartmental models and agent-based models, employ disease progression mechanisms such as Susceptible-Infectious-Recovered (SIR) to capture the dynamics of ILI diseases. Compartmental models focus on mathematical modeling of population-level dynamics, while agent-based models simulate the propagation process at the individual level with contact networks. Calibrating these models is challenging due to the high dimensionality of the parameter space. 2) Time series prediction with statistical models such as autoregression (AR) and its variants (e.g., VAR) is not suitable for long-term ILI trend forecasting, given that disease activities and human environments evolve over time. 3) Machine learning and deep learning methods [21, 26, 31, 29], such as recurrent neural networks, have been explored in recent years, but they barely consider cross-spatial effects in long-term disease propagation.
In this paper, we focus on long-term (10–20 weeks ahead) prediction of the count of ILI patients using data from a limited time range (20 weeks). To tackle this problem, we explore a graph propagation model with deep spatial representations to compensate for the loss of temporal information. Treating each location as a node, we design a graph neural network framework to model epidemic propagation at the population level. Meanwhile, we investigate recurrent neural networks for capturing sequential dependencies in local time series data, and temporal convolutions for identifying short-window patterns. Our key contributions are summarized as follows:

We propose a novel graph-based deep learning framework for long-term epidemic prediction from a time-series forecasting perspective. This is one of the first works to adapt graph neural networks to epidemic forecasting.

We investigate a location-aware attention mechanism to capture location correlations. The influence between locations is directed and automatically optimized during model learning. The attention matrix is further used as an adjacency matrix in the graph neural network for modeling disease propagation.

We design a temporal convolution module to automatically extract temporal dependencies and hidden features for time series data of multiple locations. The learned temporal features for each location are utilized as node attributes for the graph neural network.

The proposed method, ColaGNN, outperforms a broad range of state-of-the-art models on three real-world datasets under different long-term prediction settings. We also demonstrate the effectiveness of its learned attention matrix compared to a geographical adjacency matrix in an ablation study.
2 Related Work
2.1 Influenza Prediction
In many studies, forecasting influenza or influenza-like illness (ILI) case counts is formulated as a time series regression problem, where autoregressive models are widely used [27, 1, 11, 30]. Instead of focusing on seasonal effects, Wang et al. [30] propose a dynamic Poisson autoregressive model to improve short-term prediction accuracy (e.g., 1–4 weeks). Furthermore, variations of particle filters and ensemble filters have been used to predict influenza activity. Yang et al. [32] evaluate the performance of six state-of-the-art filters for forecasting influenza activity and conclude that the models have comparable performance. Ensemble methods such as matrix factorization based regression and nearest neighbor based regression have also been studied [5]. While autoregressive, filter-based, and ensemble models are simple and straightforward, they often neglect the geographical dependence in disease propagation.

Attempts to study spatiotemporal effects in influenza disease modeling are not rare. Waller et al. [28] propose a hierarchical Bayesian parametric model for the spatiotemporal interaction in generic disease mapping. A nonparametric Bayesian method [22] is proposed for predicting spatial and temporal variation of influenza cases. Venna et al. develop data-driven approaches involving climatic and geographical factors for real-time influenza forecasting [26]. Wu et al. use deep learning to model spatiotemporal patterns in epidemiological prediction problems [31]. Despite their impressive performance, these methods have limitations: they may require additional data that are not readily available, and their long-term prediction performance is not satisfactory. Improving long-term epidemiological prediction with restricted training data is an open research problem.

2.2 Long-term Epidemic Prediction
Long-term prediction (a.k.a. multi-step prediction), that is, predicting several steps ahead, is a challenge in time series prediction. It faces growing uncertainty arising from problems such as the accumulation of errors and the lack of information. Long-term prediction methods can be categorized into two types: (i) direct methods and (ii) iterative methods. Direct methods predict a future value from the past values in one shot. Iterative methods recursively invoke short-term predictors to make long-term predictions: they use the observed data up to time $t$ to predict $\hat{x}_{t+1}$, then use $\hat{x}_{t+1}$ to predict $\hat{x}_{t+2}$, and so on.
For long-term predictions using time series data, Sorjamaa et al. combine a direct prediction strategy with sophisticated input selection criteria [23]; Qian-Li et al. and Du et al. develop neural network based methods to improve long-term prediction performance [20, 10]. Recent works [26, 31] explore deep learning models for direct long-term epidemiological predictions. DEFSI [29] combines deep neural networks with causal models to address high-resolution ILI incidence forecasting. Yet most of these models rely heavily on extrinsic data to improve accuracy.
3 The Proposed Method
3.1 Problem Formulation
We formulate the epidemic prediction problem as a regression task with multiple time series as input. Throughout the paper, we denote the number of locations by $N$ and the time span of one input example by $T$. We use the terms region and location interchangeably.

At each time step $t$, the multi-location epidemiology profile is denoted by $\mathbf{x}_t \in \mathbb{R}^N$, whose elements are the observations from $N$ sources/locations, e.g., the weekly influenza patient counts in $N$ locations. We further denote the training data in a time span of size $T$ as $\mathbf{X} = [\mathbf{x}_{t-T+1}, \dots, \mathbf{x}_t] \in \mathbb{R}^{N \times T}$. The objective is to predict the epidemiology profile $\mathbf{x}_{t+h}$ at a future time point, where $h$ refers to the horizon/lead time of the prediction.
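To make the setup concrete, the following sketch shows how (input, target) pairs can be built with a moving window; the helper name make_windows and the toy shapes are illustrative assumptions, not part of the paper.

```python
import numpy as np

def make_windows(series, window=20, horizon=5):
    """Build (input, target) pairs from a multi-location weekly series.

    series: array of shape (num_weeks, N), e.g. weekly ILI counts.
    Returns X of shape (num_samples, N, window) and y of shape
    (num_samples, N); y[m] is the profile `horizon` weeks after the
    end of the m-th input window.
    """
    num_weeks, _ = series.shape
    X, y = [], []
    for end in range(window, num_weeks - horizon + 1):
        X.append(series[end - window:end].T)  # (N, window) input block
        y.append(series[end + horizon - 1])   # (N,) target profile
    return np.stack(X), np.stack(y)

# toy usage: 100 weeks, 10 locations, predicting 5 weeks ahead
X, y = make_windows(np.random.rand(100, 10), window=20, horizon=5)
print(X.shape, y.shape)  # (76, 10, 20) (76, 10)
```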
The proposed framework, shown in Figure 1, consists of three modules: 1) location-aware attention to capture location-wise interactions, 2) a temporal convolutional layer to capture local temporal features, and 3) global graph message passing, which combines the temporal features and the location-aware attentions to generate further hidden features and make predictions. The pseudocode is given in Algorithm 1, and each module is described below.
3.2 Locationaware Attention
In this study, without precise population movement data, we dynamically model the impact of one area on other areas during epidemics of infectious disease. We first learn a hidden state for each location over a given time period using a Recurrent Neural Network (RNN), given its great success in sequential (temporal) data prediction. Specifically, we use a simple, classic vanilla RNN in this module. The RNN can be replaced by a Gated Recurrent Unit (GRU) [7] or Long Short-Term Memory (LSTM) [14]; however, in this application, the vanilla RNN achieves the best performance compared to GRU and LSTM.

Given the multi-location time series data $\mathbf{X} \in \mathbb{R}^{N \times T}$, we employ a global RNN model to capture the temporal dependencies of all locations. For location $i$, an instance of a time series is represented by $\mathbf{x}^{(i)} = (x_{i,t-T+1}, \dots, x_{i,t})$. Let $D$ be the dimension of the hidden state. For each element in the input sequence, the RNN updates its hidden state according to
$$\mathbf{h}_{i,t} = \sigma\left(\mathbf{W} x_{i,t} + \mathbf{U} \mathbf{h}_{i,t-1} + \mathbf{b}\right), \qquad (1)$$

where $\mathbf{h}_{i,t} \in \mathbb{R}^D$ is the hidden state vector at time $t$ and $\mathbf{h}_{i,t-1}$ is the hidden state vector at time $t-1$; $\sigma$ is the nonlinear activation function; $\mathbf{W}$, $\mathbf{U}$, and $\mathbf{b}$ determine the adaptive weight and bias vectors of the RNN. Let $\mathbf{h}_i = \mathbf{h}_{i,t}$ be the last hidden state; we will use it to represent location $i$.
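A minimal PyTorch sketch of this shared encoder follows; the class name and defaults are our assumptions about one reasonable implementation, not the authors' code.

```python
import torch
import torch.nn as nn

class LocationEncoder(nn.Module):
    """Shared ("global") vanilla RNN of Eq. (1): the same parameters are
    applied to every location, and the last hidden state represents it."""
    def __init__(self, hidden_dim=20):
        super().__init__()
        self.rnn = nn.RNN(input_size=1, hidden_size=hidden_dim, batch_first=True)

    def forward(self, x):             # x: (N, T), one scalar per week
        seq = x.unsqueeze(-1)         # (N, T, 1): locations act as a batch
        _, h_last = self.rnn(seq)     # h_last: (1, N, D)
        return h_last.squeeze(0)      # (N, D): one vector per location

enc = LocationEncoder(hidden_dim=20)
h = enc(torch.randn(49, 20))          # 49 locations, 20-week window
print(h.shape)                        # torch.Size([49, 20])
```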
Next, we define an attention coefficient $a_{i,j}$ for measuring the impact of location $j$ on location $i$. Additive attention (or multi-layer perceptron attention) [2] and multiplicative attention (or dot-product attention) [25, 24] are the two most commonly used attention mechanisms. They share the same idea of computing an alignment score between elements from two sources, but with different compatibility functions. We utilize the compatibility function of additive attention due to its better predictive quality, which is defined as:

$$a_{i,j} = \mathbf{v}^\top g\left(\mathbf{W}^s \mathbf{h}_i + \mathbf{W}^t \mathbf{h}_j + \mathbf{b}^s\right) + b^v, \qquad (2)$$

where $g$ is an activation function that is applied element-wise; $\mathbf{W}^s, \mathbf{W}^t \in \mathbb{R}^{d_a \times D}$, $\mathbf{b}^s \in \mathbb{R}^{d_a}$, $\mathbf{v} \in \mathbb{R}^{d_a}$, and $b^v \in \mathbb{R}$ are trainable parameters. $d_a$ is a hyperparameter that controls the dimensions of the parameters in Eq. 2. Assuming that the impact of location $j$ on location $i$ differs from the impact of $i$ on $j$, we obtain an asymmetric attention coefficient matrix $\mathbf{A} \in \mathbb{R}^{N \times N}$, where each row indicates the degree of influence of other locations on the current location. Usually, a softmax function is used to transform attention scores into a probability distribution. In our problem, however, the overall impact of other locations varies across places; for instance, compared to New York, Hawaii may be less affected overall by other states. Instead, we perform normalization over the rows of $\mathbf{A}$ to scale the impact of other locations on one location:

$$\hat{a}_{i,:} = \frac{a_{i,:}}{\|a_{i,:}\|_p + \epsilon}, \qquad (3)$$

where $\epsilon$ is a small value to avoid division by zero, and $\|\cdot\|_p$ denotes the $\ell_p$ norm.
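A sketch of Eqs. (2)–(3) in PyTorch is shown below; ELU is used for $g$, following the implementation details in Section 4.3, while the parameter names and module structure are our assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocationAwareAttention(nn.Module):
    """Sketch of the additive attention of Eq. (2) and the row
    normalization of Eq. (3); names and defaults are assumptions."""
    def __init__(self, hidden_dim=20, attn_dim=10, eps=1e-12):
        super().__init__()
        self.W_s = nn.Linear(hidden_dim, attn_dim, bias=True)   # W^s, b^s
        self.W_t = nn.Linear(hidden_dim, attn_dim, bias=False)  # W^t
        self.v = nn.Linear(attn_dim, 1, bias=True)              # v, b^v
        self.eps = eps

    def forward(self, h):                          # h: (N, D) location states
        # a[i, j]: impact of location j on location i (asymmetric)
        src = self.W_s(h).unsqueeze(1)             # (N, 1, d_a)
        tgt = self.W_t(h).unsqueeze(0)             # (1, N, d_a)
        a = self.v(F.elu(src + tgt)).squeeze(-1)   # (N, N) raw scores
        # normalize each row by its l2 norm instead of softmax, so the
        # total incoming influence can differ across locations
        return a / (a.norm(p=2, dim=1, keepdim=True) + self.eps)

attn = LocationAwareAttention()
A_hat = attn(torch.randn(49, 20))
print(A_hat.shape)                                 # torch.Size([49, 49])
```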
Given the geographic nature of this task, we also consider the spatial distance between locations. We use a geographic adjacency matrix $\mathbf{A}^{geo} \in \{0,1\}^{N \times N}$ to indicate the connectivity of locations: $a^{geo}_{i,j} = 1$ means locations $i$ and $j$ are neighbors (by default, each location is adjacent to itself). The correlation of two locations may be affected by their geographic distance, i.e., nearby areas may have similar topographic or climatic characteristics that lead to similar flu outbreaks. Non-adjacent areas may also have potential dependencies due to population movements and similar geographical features. Simulating all the factors related to a flu outbreak is difficult. Therefore, we consider both the attention derived from historical data and the geographical distances between locations. The final location-aware attention matrix is obtained by combining the geographical adjacency matrix $\mathbf{A}^{geo}$ and the attention matrix $\hat{\mathbf{A}}$. The combination is accomplished by an element-wise gate $\mathbf{M}$, learned from the attention matrix, which evolves over time. We treat the attention matrix as a feature matrix, with the gate $\mathbf{M}$ adapted from the feature fusion gate [13]:
$$\tilde{\mathbf{A}}^{geo} = \mathbf{D}^{-\frac{1}{2}} \mathbf{A}^{geo} \mathbf{D}^{-\frac{1}{2}}, \qquad (4)$$
$$\mathbf{M} = \mathrm{sigmoid}\left(\mathbf{W}^m \odot \hat{\mathbf{A}} + \mathbf{b}^m\right), \qquad (5)$$
$$\hat{\mathbf{A}}' = \mathbf{M} \odot \hat{\mathbf{A}} + (1 - \mathbf{M}) \odot \tilde{\mathbf{A}}^{geo}, \qquad (6)$$

where Eq. 4 performs normalization, $\mathbf{D}$ is the degree matrix defined as $D_{ii} = \sum_j a^{geo}_{i,j}$, $\odot$ denotes element-wise multiplication, and $\mathbf{W}^m, \mathbf{b}^m \in \mathbb{R}^{N \times N}$ are trainable parameters.
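The sketch below follows our reconstruction of Eqs. (4)–(6) above: the geographic adjacency is degree-normalized and blended with the attention matrix through a learned element-wise gate. The random adjacency and the initialization are placeholders.

```python
import torch
import torch.nn as nn

N = 49
# A_geo: symmetric 0/1 connectivity with self-loops (Section 3.2)
A_geo = (torch.rand(N, N) < 0.1).float()
A_geo = ((A_geo + A_geo.t()) > 0).float()
A_geo.fill_diagonal_(1.0)

# Eq. (4): symmetric degree normalization of the geographic adjacency
deg = A_geo.sum(dim=1)                         # degrees are >= 1 (self-loops)
D_inv_sqrt = torch.diag(deg.pow(-0.5))
A_geo_norm = D_inv_sqrt @ A_geo @ D_inv_sqrt

# Eqs. (5)-(6), as reconstructed: an element-wise gate learned from the
# attention matrix blends data-driven and geographic structure.
# W_m, b_m are assumed to be N x N trainable parameters.
W_m = nn.Parameter(torch.randn(N, N) * 0.01)
b_m = nn.Parameter(torch.zeros(N, N))

A_hat = torch.rand(N, N)                       # row-normalized attention (Eq. 3)
M = torch.sigmoid(W_m * A_hat + b_m)           # element-wise gate
A_final = M * A_hat + (1 - M) * A_geo_norm     # location-aware adjacency
print(A_final.shape)                           # torch.Size([49, 49])
```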
3.3 Temporal Convolution Layer
Besides the spatial dependencies, the outbreak of influenza also has its own characteristics over time. For instance, the United States experiences annual epidemics of seasonal flu; most of the time, flu activity peaks between December and February, and it can last as late as May (https://tinyurl.com/yxevpqs9). Convolutional Neural Networks (CNNs) have shown successful results in capturing important local patterns from grid data and sequence data. We apply 1-D CNN filters to every row $\mathbf{x}^{(i)}$ of $\mathbf{X}$ to capture the temporal dependency; note that row $i$ is the observed sequential data at location $i$. Specifically, we define $K$ filters $\{\mathbf{f}_k\}_{k=1}^{K}$, where each filter $\mathbf{f}_k \in \mathbb{R}^{s}$ and the filter size $s$ is chosen to be the maximum window length $T$ in our experiments. The convolutional operations yield $\mathbf{C} \in \mathbb{R}^{N \times K}$, where $c_{i,k}$ represents the convolutional value of the $i$-th row vector and the $k$-th filter. Formally, this convolution operation is given by

$$c_{i,k} = \mathbf{f}_k \star \mathbf{x}^{(i)} = \sum_{m=1}^{s} f_{k,m} \, x_{i,\,t-T+m}, \qquad (7)$$

Max pooling is needed when $s < T$, as in Kim [16]. To constrain the range of the outputs, we also apply a nonlinearity to the convolution results. The detected temporal feature of each row/location $i$ is then $\mathbf{c}_i = [c_{i,1}, \dots, c_{i,K}] \in \mathbb{R}^{K}$.
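A compact sketch of this layer: with full-window filters, each location's series collapses to one value per filter, giving a $K$-dimensional node feature. The tanh nonlinearity is an assumption; the text only states that some nonlinearity is applied.

```python
import torch
import torch.nn as nn

class TemporalConv(nn.Module):
    """Sketch of the temporal convolution layer: K 1-D filters slide over
    each location's T-week series; with filter length equal to T, the
    convolution yields one value per filter, so no pooling is needed."""
    def __init__(self, window=20, num_filters=10):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=1, out_channels=num_filters,
                              kernel_size=window)  # full-window filters

    def forward(self, x):                   # x: (N, T) raw series per location
        out = self.conv(x.unsqueeze(1))     # (N, K, 1) since kernel covers T
        return torch.tanh(out.squeeze(-1))  # (N, K) node features

tc = TemporalConv(window=20, num_filters=10)
feats = tc(torch.randn(49, 20))
print(feats.shape)                          # torch.Size([49, 10])
```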
3.4 Graph Message Passing (the Propagation Model)
After learning the cross-location attentions (Section 3.2) and the local hidden features (Section 3.3), we design a flu propagation model using graph neural networks. Graph neural networks iteratively update node features from their neighbors; when generalized to irregular domains, this operation is often referred to as message passing or neighbor aggregation. Epidemic disease propagation at the population level is usually driven by human connectivity and transmission. Considering each location as a node in a graph, we take advantage of graph neural networks to model epidemic disease propagation among different locations. We model the adjacency matrix using the cross-location attention matrix $\hat{\mathbf{A}}'$ and the nodes' initial features using the temporal convolutional features. With $\mathbf{h}_i^{(l)}$ denoting the features of node $i$ in layer $l$ and $\hat{a}'_{j,i}$ denoting the location-aware attention coefficient from node $j$ to node $i$, the message passing graph neural network can be described as

$$\mathbf{h}_i^{(l+1)} = \sigma\left(\sum_{j \in \mathcal{V}} \hat{a}'_{j,i} \, \mathbf{W}^{(l)} \mathbf{h}_j^{(l)} + \mathbf{b}^{(l)}\right), \qquad (8)$$

where $\sigma$ denotes a nonlinear activation function, $\mathbf{W}^{(l)}$ is the weight matrix for hidden layer $l$ with $F^{(l)}$ feature maps, and $\mathbf{b}^{(l)}$ is a bias. $\mathcal{V}$ is the set of locations. $\mathbf{h}_i^{(0)}$ is initialized with the temporal feature $\mathbf{c}_i$ at the first layer.
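Written densely, one layer of Eq. (8) is a matrix product of the attention/adjacency matrix with the node features. The sketch below stacks two such layers, matching the layer count reported in Section 4.3; ReLU as the activation is our assumption.

```python
import torch
import torch.nn as nn

class GraphLayer(nn.Module):
    """One round of the message passing in Eq. (8), written densely:
    every node aggregates its neighbors' features weighted by the
    location-aware attention coefficients."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)  # W^(l), b^(l)

    def forward(self, H, A):    # H: (N, F_in) features, A: (N, N) attention
        # row i of (A @ H) is the attention-weighted sum of neighbor features
        return torch.relu(self.lin(A @ H))

# two graph layers over the temporal-convolution features
layer1, layer2 = GraphLayer(10, 16), GraphLayer(16, 16)
A = torch.softmax(torch.randn(49, 49), dim=1)  # stand-in adjacency
H0 = torch.randn(49, 10)                       # node features from Section 3.3
H2 = layer2(layer1(H0, A), A)
print(H2.shape)                                # torch.Size([49, 16])
```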
3.5 Output Layer (Prediction)
For each location $i$, we have the RNN hidden state $\mathbf{h}_i$ learned from its own historical sequence data, as well as the graph feature $\mathbf{h}_i^{(L)}$ learned from other locations' data through our propagation model. We combine these two features and feed them to the output layer for prediction, which is defined as:

$$\hat{x}_{i,t+h} = \phi\left(\mathbf{w}^\top \left[\mathbf{h}_i \,;\, \mathbf{h}_i^{(L)}\right] + b\right), \qquad (9)$$

where $\phi$ is the activation function (identity or nonlinear), $[\cdot \,;\, \cdot]$ denotes concatenation, and $\mathbf{w}, b$ are model parameters.
3.6 Optimization
We compare the predicted value of each location with the corresponding ground truth and optimize a regularized $\ell_1$-norm loss:

$$\mathcal{L}(\Theta) = \sum_{m=1}^{M} \sum_{i=1}^{N} \left| \hat{x}_i^{(m)} - x_i^{(m)} \right| + \lambda R(\Theta), \qquad (10)$$

where $M$ is the number of samples per location obtained by a moving window, shared by all locations; $x_i^{(m)}$ is the true value of location $i$ in sample $m$, and $\hat{x}_i^{(m)}$ is the model prediction. $\Theta$ stands for all training parameters and $R(\Theta)$ is the regularization term (e.g., the $\ell_2$ norm). All model parameters can be trained via backpropagation and optimized with the Adam algorithm [18], given its efficiency.
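A minimal training-step sketch under these choices: an $\ell_1$ prediction loss with the regularization realized as Adam weight decay (matching the setting in Section 4.3). The stand-in model is a placeholder for the full ColaGNN.

```python
import torch
import torch.nn as nn

# placeholder model mapping an (N, T) window to (N,) predictions
model = nn.Sequential(nn.Linear(20, 8), nn.ELU(), nn.Linear(8, 1), nn.Flatten(0))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
loss_fn = nn.L1Loss()  # l1 loss of Eq. (10)

for step in range(100):
    x = torch.rand(49, 20)          # one sample: 49 locations, 20 weeks
    y_true = torch.rand(49)         # ground truth at time t + h
    y_pred = model(x)               # (49,) predicted profile
    loss = loss_fn(y_pred, y_true)  # mean |y_pred - y_true| over locations
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```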
4 Experiment Setup
4.1 Datasets
We prepare three real-world datasets for experiments: Japan-Prefectures, US-States, and US-Regions. Their statistics are shown in Table 1.

Japan-Prefectures We collect this data from the Infectious Diseases Weekly Report (IDWR) in Japan (https://tinyurl.com/y5dt7stm). This dataset contains weekly influenza-like illness statistics (patient counts) from 47 prefectures in Japan, ranging from August 2012 to March 2019.

US-States We collect the influenza disease data from the Centers for Disease Control and Prevention (CDC) (https://tinyurl.com/y39tog3h). It contains the count of patient visits for ILI for each week and each state in the United States from 2010 to 2017. After removing one state with missing data, 49 states remain in this dataset.

US-Regions This dataset is the ILINet portion of the US-HHS (Department of Health and Human Services) dataset (https://tinyurl.com/y39tog3h), consisting of weekly influenza activity levels for the 10 HHS regions of the U.S. mainland from 2002 to 2017. Each HHS region comprises a collection of associated states. We use flu patient counts for each region, calculated by combining state-specific data.
Data is normalized to the 0–1 range for each location, using the maximum and minimum values of the training data: the training maximum of a location maps to 1 and the minimum to 0. After ordering the data by time, the first 50% is used for training, the next 20% for validation, and the last 30% for testing. The validation data is used to determine the number of epochs to run in order to avoid overfitting. We fix the validation and test sets by dates for different lead time values. In this setting, the test data covers 2.1, 4.5, and 2.1 flu seasons in Japan-Prefectures, US-States, and US-Regions, respectively; accordingly, there are at least 3, 7.2, and 3 flu seasons in the three training sets.
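A sketch of this preprocessing, assuming the helper name split_and_scale and per-location min-max statistics taken from the training portion only:

```python
import numpy as np

def split_and_scale(series, train_frac=0.5, val_frac=0.2):
    """Chronological split, then min-max scaling per location using
    statistics from the training portion only (a sketch of Section 4.1)."""
    T = series.shape[0]
    t1, t2 = int(T * train_frac), int(T * (train_frac + val_frac))
    train, val, test = series[:t1], series[t1:t2], series[t2:]
    lo = train.min(axis=0, keepdims=True)
    hi = train.max(axis=0, keepdims=True)
    scale = np.where(hi > lo, hi - lo, 1.0)  # guard against constant series
    return [(part - lo) / scale for part in (train, val, test)]

data = np.random.rand(360, 49)               # 360 weeks, 49 states
train, val, test = split_and_scale(data)
print(train.shape, val.shape, test.shape)    # (180, 49) (72, 49) (108, 49)
```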
Dataset | Size (locations × weeks) | Min | Max | Mean | SD
Japan-Prefectures | 47 × 348 | 0 | 26,635 | 655 | 1,711
US-Regions | 10 × 785 | 0 | 16,526 | 1,009 | 1,351
US-States | 49 × 360 | 0 | 9,716 | 223 | 428

Table 1: Dataset statistics: min, max, mean, and standard deviation (SD) of patient counts; size is the number of locations multiplied by the number of weeks.
4.2 Evaluation Metrics
In the experiments, we adopt the following metrics for evaluation. Denote the predicted and true values by $\hat{y}_i$ and $y_i$ ($i = 1, \dots, n$), respectively. We do not distinguish regions in evaluation.

The Root Mean Squared Error (RMSE) measures the difference between predicted and true values after projecting the normalized values back into the real range:

$$\mathrm{RMSE} = \sqrt{\frac{1}{n} \sum_{i=1}^{n} \left(\hat{y}_i - y_i\right)^2}.$$

The Mean Absolute Error (MAE) is a measure of the difference between two continuous variables:

$$\mathrm{MAE} = \frac{1}{n} \sum_{i=1}^{n} \left|\hat{y}_i - y_i\right|.$$

The Pearson's Correlation (PCC) is a measure of the linear dependence between two variables:

$$\mathrm{PCC} = \frac{\sum_{i=1}^{n} (\hat{y}_i - \bar{\hat{y}})(y_i - \bar{y})}{\sqrt{\sum_{i=1}^{n} (\hat{y}_i - \bar{\hat{y}})^2} \sqrt{\sum_{i=1}^{n} (y_i - \bar{y})^2}}.$$

Lead time is the number of weeks ahead that the model predicts. For instance, if we use $[\mathbf{x}_{t-T+1}, \dots, \mathbf{x}_t]$ as input and predict the infected patient counts of the fifth week after the current week $t$ (lead time = 5), the ground truth (expected output) is $\mathbf{x}_{t+5}$.
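For reference, the three metrics can be computed as follows, assuming the predictions have already been projected back to the real range:

```python
import numpy as np

def rmse(y_pred, y_true):
    return np.sqrt(np.mean((y_pred - y_true) ** 2))

def mae(y_pred, y_true):
    return np.mean(np.abs(y_pred - y_true))

def pcc(y_pred, y_true):
    """Pearson correlation over all (location, week) pairs pooled together."""
    return np.corrcoef(y_pred.ravel(), y_true.ravel())[0, 1]

y_true = np.random.rand(100, 49)   # 100 test weeks x 49 locations
y_pred = y_true + 0.05 * np.random.randn(*y_true.shape)
print(rmse(y_pred, y_true), mae(y_pred, y_true), pcc(y_pred, y_true))
```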
4.3 Comparison Methods
We compare our model with several state-of-the-art methods and their variants, listed below.

Global Autoregression (GAR) This model is mainly used when training data is limited. We train one global model using the data available from each location.

Vector Autoregression (VAR) VAR models cross-signal dependence to address a potential drawback of the AR model, namely that signal sources are processed independently of each other. It therefore introduces more parameters and is more expensive to train.

Autoregressive Moving Average (ARMA) ARMA combines autoregressive terms and moving-average terms. A considerable amount of preprocessing has to be performed before fitting such a model. The order of the moving average is set to 2 in our implementation; a fitting sketch appears after this list.

Recurrent Neural Network (RNN) RNNs have demonstrated a powerful ability to model temporal dependencies. We employ a global RNN for our problem, that is, parameters are shared across different regions. The RNN cell can be replaced by GRU or LSTM; experimentally, these gated variants did not achieve better results, so we only consider the simple RNN for comparison.

RNN+Attn [6] This model adds a self-attention mechanism to a global RNN. In the computation of the RNN units, the hidden state is replaced by a summary vector that uses the attention mechanism to aggregate the information of all previous hidden states.

CNNRNNRes [31] A deep learning framework that combines CNN, RNN and residual links to solve epidemiological prediction problems.

GCNRNNRes A variation of CNNRNNRes. We replace the CNN module with a GCN [19] module with two hidden layers whose feature dimensions remain unchanged, and we use the given geographical adjacency matrix.
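As referenced in the ARMA item above, a hedged per-location fitting sketch with moving-average order 2 follows; the autoregressive order of 2 here is purely illustrative.

```python
import numpy as np
from statsmodels.tsa.arima.model import ARIMA

# Per-location ARMA(p, q) baseline with the moving-average order q = 2
# as stated above; p = 2 and the random series are illustrative.
series = np.random.rand(200)              # one location's weekly counts
model = ARIMA(series, order=(2, 0, 2))    # d = 0 gives a plain ARMA
result = model.fit()
forecast = result.forecast(steps=15)      # 15-week-ahead lead time
print(forecast.shape)                     # (15,)
```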
Hyperparameter Setting & Implementation Details In our model, we adopt the exponential linear unit (ELU) [9] as the nonlinearity $g$ in Eq. 2, and the identity function as $\phi$ in Eq. 9. In the experiments, the input window size is $T = 20$ weeks, which spans roughly five months. The hyperparameter $d_a$ in the location-aware attention is chosen to reduce the number of parameters compared to standard additive attention. The order $p$ of the norm in Eq. 3 is set to 2, and $\epsilon$ is 1e-12. The number of filters $K$ in Eq. 7 is 10. For all methods using an RNN module, we tune the hidden dimension over {10, 20, 30}, and 20 yields the best performance in most cases. The numbers of RNN hidden layers and graph layers are optimized to 1 and 2, respectively. In training, the best models are selected by early stopping when the validation accuracy does not increase for 200 consecutive epochs, with a maximum of 1500 epochs. All parameters are initialized with Glorot initialization [12] and trained using the Adam [17] optimizer with weight decay 5e-4 and dropout rate 0.2. The initial learning rate of all methods is searched over {0.001, 0.005, 0.01}. The batch size is set to 32 across all datasets. All experimental results are averages of 10 randomized trials.
Let $F$ denote the dimension of the weight matrices in graph message passing. The number of parameters of the proposed model is then determined by $N$, $D$, $d_a$, $K$, and $F$; in our epidemiological prediction problems, all of these are limited to relatively small numbers.
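The parameter counts reported later (Table 5) can be obtained with a generic utility such as the following; this is not the authors' script.

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Count trainable parameters of a PyTorch module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(nn.Linear(20, 20)))  # 420 = 20*20 weights + 20 biases
```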
5 Results
Table 2: RMSE (↓) and PCC (↑) of all methods on the three datasets in short-term settings (lead time = 2, 3, 4 weeks).
  Japan-Prefectures  US-Regions  US-States
RMSE (↓)  2  3  4  2  3  4  2  3  4
GAR  1232  1628  1865  536  715  859  150  187  213 
AR  1377  1705  1901  570  757  888  161  204  231 
VAR  1361  1711  1910  741  870  967  290  276  283 
ARMA  1371  1703  1902  560  742  874  161  200  228 
RNN  1001  1259  1366  513  689  805  149  181  204 
RNN+Attn  1166  1572  1706  613  753  962  152  186  210 
CNNRNNRes  1133  1550  1795  571  738  802  205  239  253 
GCNRNNRes  1031  1129  1133  736  847  935  194  210  236 
ColaGNN  919  1060  1072  483  633  765  136  167  191 
PCC (↑)  2  3  4  2  3  4  2  3  4
GAR  0.804  0.626  0.461  0.932  0.881  0.835  0.945  0.914  0.893 
AR  0.752  0.579  0.428  0.927  0.878  0.834  0.94  0.909  0.885 
VAR  0.754  0.585  0.419  0.859  0.797  0.741  0.765  0.79  0.78 
ARMA  0.754  0.579  0.428  0.927  0.876  0.833  0.939  0.909  0.886 
RNN  0.892  0.833  0.813  0.94  0.895  0.855  0.948  0.922  0.9 
RNN+Attn  0.85  0.668  0.604  0.887  0.859  0.774  0.947  0.922  0.903 
CNNRNNRes  0.852  0.673  0.513  0.92  0.862  0.829  0.904  0.86  0.842 
GCNRNNRes  0.893  0.889  0.886  0.871  0.831  0.796  0.903  0.884  0.854 
ColaGNN  0.911  0.893  0.894  0.944  0.905  0.863  0.955  0.933  0.907 
Table 3: RMSE (↓) and PCC (↑) of all methods on the three datasets in long-term settings (lead time = 5, 10, 15 weeks).
  Japan-Prefectures  US-Regions  US-States
RMSE (↓)  5  10  15  5  10  15  5  10  15
GAR  1988  2065  2016  991  1377  1465  236  314  340 
AR  2013  2107  2042  997  1330  1404  251  306  327 
VAR  2025  1942  1899  1059  1270  1299  295  324  352 
ARMA  2013  2105  2041  989  1322  1400  250  306  326 
RNN  1376  1696  1629  896  1328  1434  217  274  315 
RNN+Attn  1746  1612  1823  1065  1367  1368  234  315  334 
CNNRNNRes  1942  1865  1862  936  1233  1285  267  260  250 
GCNRNNRes  1178  1384  1457  1051  1298  1402  248  275  288 
ColaGNN  1156  1403  1500  871  1126  1218  202  241  232 
PCC (↑)  5  10  15  5  10  15  5  10  15
GAR  0.339  0.288  0.47  0.79  0.581  0.485  0.875  0.777  0.742 
AR  0.310  0.238  0.483  0.792  0.612  0.527  0.863  0.773  0.723 
VAR  0.3  0.426  0.474  0.685  0.508  0.467  0.758  0.709  0.6529 
ARMA  0.31  0.253  0.486  0.792  0.614  0.52  0.862  0.773  0.725 
RNN  0.821  0.616  0.709  0.821  0.587  0.499  0.886  0.821  0.758 
RNN+Attn  0.59  0.741  0.522  0.752  0.554  0.552  0.884  0.78  0.739 
CNNRNNRes  0.38  0.438  0.467  0.782  0.552  0.4851  0.822  0.82  0.847 
GCNRNNRes  0.875  0.823  0.774  0.739  0.554  0.4471  0.844  0.814  0.814 
ColaGNN  0.883  0.818  0.754  0.832  0.719  0.639  0.897  0.822  0.859 
5.1 Prediction Performance
We evaluate our approach in short-term (lead time = 2, 3, 4) and long-term (lead time = 5, 10, 15) settings. We ignore the case of lead time = 1 because symptom monitoring data are usually delayed by at least one week. Table 2 summarizes the results of all methods in terms of RMSE and PCC in the short-term settings. When the lead time is relatively small, our method achieves the most stable and best performance on all datasets. Most methods achieve relatively good performance in this regime because the information gap between the history window and the predicted time is small, so the models can fit the temporal pattern more easily. One exception is the Japan-Prefectures dataset, where the results of most baseline methods deteriorate with even a slight increase in lead time. A possible reason is that the seasonal influenza curves in this dataset are noisier and less predictable; the dataset statistics also show that Japan-Prefectures has the largest standard deviation.
Table 3 reports the RMSE and PCC results in the long-term settings. Overall, the proposed method achieves the best performance on most datasets with long lead time windows (lead time = 5, 10, or 15 weeks). Autoregressive models perform poorly, especially VAR, which has the largest number of model parameters; this suggests the importance of controlling model complexity when data are insufficient. Recurrent neural network models only achieve good predictive performance when the lead time is small, which indicates that long-term predictions require a better design for capturing spatial and temporal dependencies. CNNRNNRes uses geographic location information, but it only performs well on the US-States dataset; on Japan-Prefectures and US-Regions it performs poorly with long lead time windows. Its variant GCNRNNRes contains a graph convolutional module that learns features from adjacent regions, and it achieves good results on the Japan-Prefectures and US-States datasets, which suggests that the graph convolution module helps capture long-term dependencies. The performance of CNNRNNRes and GCNRNNRes is unstable on the three datasets and often shows large variance over multiple rounds of training. To better visualize the results, we show the mean and standard deviation of 10 different runs of selected models in Figure 2 and Figure 3.

Looking at the big picture of prediction performance, the difference among all methods is relatively small when the lead time is 2, but as the lead time increases, the predictive power of simple methods (such as autoregression) decreases significantly. This suggests that modeling temporal dependence is challenging when a relatively large gap exists between the historical window and the expected prediction time.
5.2 Case Studies
To evaluate the long-term predictive performance of the proposed model, we plot a sequence of predictions with lead time 15 on the test set. Four of the stronger baselines were chosen, and the comparison on the three datasets is shown in Figure 4. We randomly select three locations from each dataset and observe that, even though we use a relatively small window (20 weeks) to predict the flu count far ahead (lead time = 15), our model better captures the trend and timing of epidemic outbreaks.
We also fix the input window and plot the prediction curves for lead times from 1 to 20. Likewise, we randomly sample three locations from each dataset. As observed in Figure 5, our model tends to capture future peaks and trends based on the given historical data.
5.3 Attention Visualization
Figure 6 shows an example of the location-aware attention mechanism with a lead time of 15 on the US-Regions dataset. In this example, we focus on region 5. We visualize the input data of region 5, the two regions with the highest attention values for region 5 (regions 3 and 4), and the two regions with the lowest attention values (regions 1 and 8). We normalize the data by region to better compare flu outbreaks across regions. The period under the light yellow shade is the input sequence with window size 20, and the vertical line indicates the predicted time. Although we use only a small part of each region's sequence to predict the epidemic outbreak of region 5 fifteen weeks ahead, the regions with higher attention share the same early outbreak pattern as region 5, while the regions with lower attention values have later outbreak times.
We show the normalized geolocation distance matrix, calculated according to Eq. 4, together with the Pearson correlation coefficients of the input time series and the learned attention matrix in Figure 7. The learned attention matrix utilizes geolocation information as well as the additive attention among regions. From the learned attention matrix, we observe that adjacent regions sometimes receive higher attention values; meanwhile, non-adjacent regions can also receive high attention values given their similar long-term influenza trends. The learned attention thus reveals hidden dynamics (e.g., epidemic outbreaks and peak times) among regions.
5.4 Ablation Tests
Table 4: Ablation test results in RMSE (↓) and PCC (↑) for lead times from 2 to 15 weeks. "w/o TC" removes the temporal convolution module; "w/o LA" removes the location-aware attention (see the settings below).
RMSE (↓)  2  3  4  5  10  15
Japan-Prefectures
ColaGNN w/o TC  911  1115  1204  1310  1388  1517
ColaGNN w/o LA  942  1154  1164  1195  1473  1576
ColaGNN  919  1060  1072  1156  1403  1500
US-Regions
ColaGNN w/o TC  485  662  772  888  1144  1228
ColaGNN w/o LA  499  666  782  890  1179  1292
ColaGNN  483  633  765  871  1126  1218
US-States
ColaGNN w/o TC  138  169  188  194  251  251
ColaGNN w/o LA  138  169  193  202  246  246
ColaGNN  136  167  191  202  241  232
PCC (↑)  2  3  4  5  10  15
Japan-Prefectures
ColaGNN w/o TC  0.91  0.867  0.846  0.818  0.793  0.744
ColaGNN w/o LA  0.914  0.881  0.89  0.88  0.781  0.727
ColaGNN  0.911  0.893  0.894  0.883  0.818  0.754
US-Regions
ColaGNN w/o TC  0.944  0.902  0.861  0.824  0.712  0.588
ColaGNN w/o LA  0.942  0.898  0.858  0.824  0.682  0.582
ColaGNN  0.944  0.905  0.863  0.832  0.719  0.639
US-States
ColaGNN w/o TC  0.953  0.93  0.908  0.908  0.833  0.836
ColaGNN w/o LA  0.955  0.931  0.913  0.904  0.856  0.855
ColaGNN  0.955  0.933  0.907  0.897  0.822  0.859
To analyze the effect of each component in our framework, we perform ablation tests on all datasets with the following settings:

ColaGNN w/o TC: Remove the temporal convolution module from the proposed model, and use the raw time series input as node features in graph message passing.

ColaGNN w/o LA: Remove the location-aware attention module and directly use the geographical adjacency matrix, which encodes the spatial proximity between pairs of locations.
The RMSE and PCC results are shown in Table 4. In most cases, the ablated variants of the proposed method still achieve good performance. On the US-States dataset, models without the temporal convolution or location-aware attention module are sometimes slightly better than the full model; US-States has the lowest number of reported influenza cases among the three datasets, and its standard deviation is small. Overall, the full model achieves the best performance across all datasets. Note that all datasets are relatively small, so adding more parameters may hurt performance due to overfitting. Adding the temporal and spatial modules does not change short-term (lead time = 2, 3, 4) prediction very much; for long-term predictions (lead time = 15), however, these two modules produce clearly better results.
5.5 Sensitivity Analysis
In this section, we investigate how the prediction performance varies with some hyperparameters.
Size of History Windows
To test whether our model is sensitive to the length of historical data, we evaluate window sizes from 10 to 50 in steps of 5. The experiments were conducted on the US-Regions and US-States datasets, as shown in Figure 8. The predictive performance in RMSE and MAE is fairly stable across window sizes, so we can avoid training with very long sequences and still achieve comparable results.
Size of Graph Features
We learn the RNN hidden states from the historical sequence data and the graph features, which incorporate other regions' information via message passing over the location-aware attentions. We vary the dimension of the graph features from 1 to 15 and evaluate the predictive performance on the US-States dataset with a lead time of 15. Figure 9 reports the RMSE and MAE results. Features of smaller dimensions result in poor predictive performance due to limited encoding power; the model makes better predictions when the feature dimension is larger.
RNN Modules
The RNN module outputs a hidden state vector for each location based on the given historical data; this vector is then fed to the location-aware attention module. We replace the RNN module with GRU and LSTM to assess their impact on model performance. Figure 10 shows RMSE results for lead time = 2, 5, 10, 15 on the US-Regions and US-States datasets. We find that GRU and LSTM do not outperform a simple RNN; the likely reason is that they involve more model parameters and tend to overfit on these epidemiological datasets.
5.6 Model Complexity
Table 5: Number of trainable parameters and runtime of each model on the US-States dataset.
Architecture  Parameters  Runtime (s)
GAR  21  0.01 
AR  1,029  0.02 
VAR  48,069  0.02 
ARMA  1,960  0.03 
RNN  481  0.04 
RNN+Attn  1,321  0.58 
CNNRNNRes  7,695  0.04 
GCNRNNRes  6,214  0.04 
ColaGNN  3,778  0.21 
Table 5 compares the runtime and number of parameters of each model on the US-States dataset, which has the largest number of regions among the three datasets. All methods can be trained efficiently given the modest size of the datasets. Note that we only utilize flu disease data and geographic location data, ignoring other external features. Compared with the other methods, the proposed method incurs no significant loss in training efficiency, and it keeps the number of parameters small enough to help prevent overfitting.
6 Conclusion
In this work, we propose a graph-based deep learning framework with cross-location attention to capture spatiotemporal influence for long-term epidemiological prediction, and we demonstrate its effectiveness on real-world epidemiological datasets. One limitation is flexibility: a separate model must be trained for each lead time setting. Future work will consider iterative predictions to increase the flexibility of the model. Another direction is to incorporate more complex dependencies such as weather, social factors, and population migration; we intend to determine whether prediction accuracy improves when using such external indicators. Furthermore, it is also worthwhile to identify the main factors affecting the epidemic outbreak of one area by learning from multiple areas simultaneously.
References
[1] (2011) Predicting flu trends using Twitter data. In 2011 IEEE Conference on Computer Communications Workshops (INFOCOM WKSHPS), pp. 702–707.
[2] (2014) Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
[3] (2009) EpiFast: a fast algorithm for large scale realistic epidemic simulations on distributed memory systems. In Proceedings of the 23rd International Conference on Supercomputing, pp. 430–439.
[4] (2017) Combining participatory influenza surveillance with modeling and forecasting: three alternative approaches. JMIR Public Health and Surveillance 3 (4), pp. e83.
[5] (2014) Forecasting a moving target: ensemble models for ILI case count predictions. In Proceedings of the 2014 SIAM International Conference on Data Mining, pp. 262–270.
[6] (2016) Long short-term memory-networks for machine reading. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, Austin, Texas, pp. 551–561.
[7] (2014) Learning phrase representations using RNN encoder–decoder for statistical machine translation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), Doha, Qatar, pp. 1724–1734.
[8] (2008) Seasonal influenza in the United States, France, and Australia: transmission and prospects for control. Epidemiology & Infection 136 (6), pp. 852–864.
[9] (2015) Fast and accurate deep network learning by exponential linear units (ELUs). In Proceedings of the 2015 International Conference on Learning Representations.
[10] (2014) Prediction of chaotic time series of RBF neural network based on particle swarm optimization. In Intelligent Data Analysis and its Applications, Volume II, pp. 489–497.
[11] (2013) Influenza forecasting with Google Flu Trends. PLoS ONE 8 (2), pp. e56176.
[12] (2010) Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pp. 249–256.
[13] (2018) Ruminating reader: reasoning with gated multi-hop attention. In Proceedings of the Workshop on Machine Reading for Question Answering, Melbourne, Australia, pp. 1–11.
[14] (1997) Long short-term memory. Neural Computation 9 (8), pp. 1735–1780.
[15] (1927) A contribution to the mathematical theory of epidemics. Proceedings of the Royal Society of London, Series A 115 (772), pp. 700–721.
[16] (2014) Convolutional neural networks for sentence classification. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), Doha, Qatar, pp. 1746–1751.
[17] (2015) Adam: a method for stochastic optimization. In International Conference on Learning Representations.
[18] (2015) Adam: a method for stochastic optimization. In Proceedings of the 2015 International Conference on Learning Representations.
[19] (2016) Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907.
[20] (2008) Multi-step-prediction of chaotic time series based on co-evolutionary recurrent neural network. Chinese Physics B 17 (2), pp. 536.
[21] (2014) Analysing Twitter and web queries for flu trend prediction. Theoretical Biology and Medical Modelling 11 (1), pp. S6.
[22] (2016) Predicting spatio-temporal propagation of seasonal influenza using variational Gaussian process regression. In Thirtieth AAAI Conference on Artificial Intelligence.
[23] (2007) Methodology for long-term prediction of time series. Neurocomputing 70 (16–18), pp. 2861–2869.
[24] (2015) End-to-end memory networks. In Advances in Neural Information Processing Systems, pp. 2440–2448.
[25] (2017) Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008.
[26] (2018) A novel data-driven model for real-time influenza forecasting. IEEE Access 7, pp. 7691–7701.
[27] (2003) Prediction of the spread of influenza epidemics by the method of analogues. American Journal of Epidemiology 158 (10), pp. 996–1006.
[28] (1997) Hierarchical spatio-temporal mapping of disease rates. Journal of the American Statistical Association 92 (438), pp. 607–617.
[29] (2019) DEFSI: deep learning based epidemic forecasting with synthetic information. In Proceedings of the AAAI Conference on Artificial Intelligence, Vol. 33, pp. 9607–9612.
[30] (2015) Dynamic Poisson autoregression for influenza-like-illness case count prediction. In Proceedings of the 21st ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1285–1294.
[31] (2018) Deep learning for epidemiological predictions. In The 41st International ACM SIGIR Conference on Research & Development in Information Retrieval, pp. 1085–1088.
[32] (2014) Comparison of filtering methods for the modeling and retrospective forecasting of influenza epidemics. PLoS Computational Biology 10 (4), pp. e1003583.