Learning Nonlinear Brain Dynamics: van der Pol Meets LSTM

05/24/2018 · German Abrevaya et al. · IBM, University of Washington, University of Buenos Aires

Many real-world data sets, especially in biology, are produced by highly multivariate and nonlinear complex dynamical systems. In this paper, we focus on brain imaging data, including both calcium imaging and functional MRI data. Standard vector-autoregressive models are limited by their linearity assumptions, while nonlinear general-purpose, large-scale temporal models, such as LSTM networks, typically require large amounts of training data, not always readily available in biological applications; furthermore, such models have limited interpretability. We introduce here a novel approach for learning a nonlinear differential equation model aimed at capturing brain dynamics. Specifically, we propose a variable-projection optimization approach to estimate the parameters of the multivariate (coupled) van der Pol oscillator, and demonstrate that such a model can accurately represent the nonlinear dynamics of brain data. Furthermore, in order to improve predictive accuracy when forecasting future brain-activity time series, we use this analytical model as an unlimited source of simulated data for pretraining LSTM networks; such a model-specific data-augmentation approach consistently improves LSTM performance on both calcium and fMRI imaging data.


1 Introduction

Complex multivariate nonlinear dynamical systems are abundant in nature and in society, ranging from weather to brain activity and stock market behavior. Building accurate models of such systems is highly nontrivial, and considerably more difficult than modeling linear dynamics. While nonlinear dynamical systems are extensively studied in physics, control theory and related disciplines, learning such systems from data in high-dimensional settings is difficult, and traditional machine learning approaches tend to focus on generic dynamical models, such as recurrent neural networks, rather than on domain-specific types of nonlinear dynamical models, such as the van der Pol (VDP) model considered in this paper.

Our goal is to propose a model that can capture the most relevant features of a complex nonlinear dynamical system, such as brain activity in neuroimaging data. Brain activity exhibits highly nonlinear behavior that can be oscillatory or even chaotic [korn2003there], with sharp phase transitions between different states. The simplest models that can capture these behaviors are relaxation oscillators. One of the most famous examples is the VDP oscillator [guckenheimer2013nonlinear], used to model a variety of problems in physics. It has also played a relevant role in neuroscience, given its equivalence to the FitzHugh-Nagumo equations, which were introduced as a simplified model of the action potential in neurons [izhikevich2007dynamical, fitzhugh1961impulses].

In this paper, we address two main questions. First: can we actually learn both the hidden variables and the structure parameters of the van der Pol oscillator from data, when we only observe some of the variables? There has been a lot of interest in the physics and inverse-problems communities in simultaneously estimating states and parameters. Most approaches for learning nonlinear dynamics and parameters avoid optimization entirely by using the unscented Kalman filter [quach2007estimating, voss2004nonlinear, sitz2002estimation] or other derivative-free dynamic inference methods [havlicek2011dynamic]. Derivative-free methods have limitations: there is no convergence criterion or disciplined way to iterate them to improve estimates. Optimization-based approaches for fitting parameters and dynamics are discussed by [gabor2015robust], who formulate parameter identification under dynamic constraints as an ODE-constrained optimization problem. We take a similar view, and use recent insights into variable projection to develop an efficient optimization algorithm for learning the hidden states and parameters of the van der Pol (VDP) model. The work of [gabor2015robust] focuses on global strategies (e.g., multiple restarts of well-known methods); our contribution is an efficient local technique for an inexact VDP formulation.

Our main scientific question is: are such models useful to neuroscience? How well can we capture the dynamics and predict the temporal evolution of neural activity, and are these results interpretable in the context of prior neuroscientific knowledge? We show that the answer to these questions can be positive, but requires a combination of multiple approaches: (1) using both optimization and stochastic search in order to escape potential local minima and "jump" to more promising parts of an enormous search space; and (2) using our analytical oscillatory model to pre-train generic statistical approaches, such as LSTM networks.

We show that the best predictive accuracy is achieved by first estimating the van der Pol model (with a relatively small number of parameters) from limited training data, then using this model to simulate large amounts of data to pre-train a general-purpose LSTM network, biasing it toward the specific nonlinear dynamics, and finally fine-tuning the network on the limited-size real data. We demonstrate that this hybrid approach consistently improves LSTM performance on both calcium and fMRI imaging data.

2 Calcium Imaging Data

Figure 1: The first 5 SVD temporal components (left column) and the corresponding spatial components (right column) of the zebrafish data.

A recently introduced technique, brain-wide calcium imaging (CaI) [ahrensNature2012], provides a unique perspective on neural function, recording the concentration of calcium at sub-cellular spatial resolution across an entire vertebrate brain, at a temporal resolution commensurate with the timescale of calcium dynamics [ahrensNatureMethods2013].

In [ahrensNatureMethods2013], light-sheet microscopy was used to record the neural activity of the whole brain of a larval zebrafish, reported by a genetically encoded calcium marker, in vivo and at a 0.8 Hz sampling rate. From the publicly available data [CaIdata] it is possible to obtain a movie of 500 frames with a 2D collapsed view of 80% of the approximately 40,000–100,000 neurons in the brain, at a resolution of 400 by 250 pixels (approximately 600 by 270 microns).

In order to obtain functionally relevant information, we performed an SVD analysis of these data. (There are multiple alternative approaches to feature extraction/representation learning and dimensionality reduction that could be explored in this setting, including other component-analysis methods (NMF, ICA), sparse coding/dictionary learning, and various autoencoders. However, before diving into more complex feature extraction, we would like to develop an approach to modeling a coupled dynamical system, which is a nontrivial task even with a relatively small number of SVD components.) Figure 1 shows the first 5 SVD temporal components (left column) and the corresponding spatial components (right column). The spatial components show a clear neural substrate, and therefore the temporal components can be interpreted as traces of neuronal activity from within brain systems identified by each corresponding spatial component. For example, spatial components 1–5 each show pronounced but non-overlapping forebrain island-like structures, often with lateral symmetry. Moreover, the second and third spatial components also include the hindbrain oscillator (seen in the right panels). The corresponding second and third temporal components are dominated by oscillatory activity, consistent with the physiology of the hindbrain oscillator described in [ahrensNatureMethods2013].
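For illustration, this step amounts to a truncated SVD of the pixels-by-time movie matrix; below is a minimal sketch, assuming the movie frames have been flattened into a hypothetical (n_pixels, n_time) array `frames`:

```python
# Minimal sketch of the SVD step: decompose a (pixels x time) matrix into
# spatial components (left singular vectors) and temporal components
# (weighted right singular vectors). `frames` is a hypothetical array.
import numpy as np

def svd_components(frames, k=5):
    U, s, Vt = np.linalg.svd(frames, full_matrices=False)
    spatial = U[:, :k]                 # columns: spatial components (images)
    temporal = s[:k, None] * Vt[:k]    # rows: weighted temporal traces
    return spatial, temporal
```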

3 Van der Pol Model of Neuronal Activity

Because neuronal calcium dynamics are largely driven by transmembrane voltage and voltage-dependent calcium channels, we model the calcium dynamics of a neuron, or of a small cluster of neurons, as a 2D differential equation with a voltage-like variable (activity) and a recovery-like variable (excitability), following similar approaches in the literature [Izhikevich2007]. Given that one salient feature of neural systems is their propensity for oscillations, as well as sharp transitions from passive to active states, we consider the following nonlinear oscillator model for each scalar component:

\[
\begin{aligned}
\dot{x}_i &= y_i + \sum_{j=1}^{m} W_{ij}\, x_j, \\
\dot{y}_i &= \mu_i \,\big(1 - x_i^2\big)\, y_i - x_i, \qquad i = 1, \dots, m,
\end{aligned}
\tag{1}
\]

where m is the number of neural units considered (e.g., SVD components), x_i and y_i represent the (observed) activity and the (hidden) excitability variables of the i-th neural unit, respectively, and the matrix W represents the coupling strength between the observed variables, or neural units. Thus, \(\sum_j W_{ij} x_j\) models the synaptic input to the i-th unit provided by the other units through their observed variables. The parameters \(\mu_i\) determine the bifurcation diagram of the system, allowing for a rich set of dynamical states including oscillations and spike-like responses [Wiggins2003, Izhikevich2007]. However, imaging techniques only provide information about the activity x_i, i.e., the calcium concentration in the case of CaI. Consequently, any model-based analysis requires the inference of the excitability represented by the hidden (unobserved) variables y_i.
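For concreteness, the coupled system (1), in the form reconstructed above, can be simulated with a standard ODE solver; a minimal sketch with per-unit parameters mu and coupling matrix W:

```python
# Simulate the coupled van der Pol model (1) with scipy's ODE solver.
# state = (x_1, y_1, ..., x_m, y_m); mu is a length-m vector, W an (m, m) matrix.
import numpy as np
from scipy.integrate import solve_ivp

def coupled_vdp(t, state, mu, W):
    x, y = state[0::2], state[1::2]
    dx = y + W @ x                     # observed activity, with linear coupling
    dy = mu * (1.0 - x**2) * y - x     # hidden excitability (recovery) variable
    out = np.empty_like(state)
    out[0::2], out[1::2] = dx, dy
    return out

# Example: two weakly coupled units.
mu = np.array([1.0, 2.5])
W = np.array([[0.0, 0.1], [-0.1, 0.0]])
sol = solve_ivp(coupled_vdp, (0, 100), [2.0, 0.0, 0.5, 0.0],
                args=(mu, W), t_eval=np.linspace(0, 100, 500))
x_traces = sol.y[0::2]  # the m observed time series
```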

When the parameters \(\mu\) and W in (1) are known, inferring the hidden components y from observations of x is a nonlinear Kalman smoothing problem. Kalman filtering and smoothing methods are commonly used for inference in noisy dynamical systems. Since their invention [kalman, KalBuc], these algorithms have become a gold standard in a range of applications, including space exploration, missile guidance systems, general tracking and navigation, and weather prediction. Optimization-based approaches with nonlinear and non-Gaussian models require iterative optimization techniques; see, for example, the survey [aravkin2017generalized]. Dynamical modeling was applied to nonlinear systems early on by [Anderson:1979, Mortensen1968]. More recently, the optimization perspective on Kalman smoothing has enabled further extensions, including inference for systems with unknown parameters [Bell2000], systems with constraints [Bell2009], and systems affected by outlier measurements, for both linear [Durovic1999, Meinhold1989, Cipra1997] and nonlinear [aravkin2011ell, aravkin2014robust] models.

Building on the above perspective, we address the challenging problem of estimating from data both the parameters \((\mu, W)\) and the hidden variables y. To the best of our knowledge, this work is the first to propose an approach for learning a coupled van der Pol oscillator model from data. We develop a method to find the hidden variables (y) from the observed ones (x) for given parameter settings, and to learn the unknown parameter settings themselves. The two problems are indeed coupled; however, rather than using alternating optimization (closely related to EM), we use fast optimization techniques available for nonlinear Kalman smoothing to fully minimize over the hidden states for each update of the unknown parameters. The algorithm can be understood in the framework of recent results on variable projection (partial minimization), which is efficient for dealing with nonconvex, possibly ill-conditioned problems. While a detailed convergence and sensitivity analysis of this algorithm is a topic of ongoing work, we present here promising results showing that the obtained van der Pol model can accurately capture nonlinear dynamics in the training data and can be used to predict future time series (test data); in addition, we show that predictive performance is boosted by combining the van der Pol model with LSTM networks.

4 Estimating van der Pol Parameters: ODE-Constrained Inference

We discretize the ODE model in equation (1), and formulate a joint inference problem over the state space and parameters that is informed by noisy direct observations of some components, and constrained by the discretized dynamics.

Inference for a single component. For a time index t = 1, ..., N, let \(\xi_t = (x_t, y_t)\) denote the state of the i-th component of the van der Pol model given earlier in equation (1); the state contains both the observed and the hidden variable. The discretized dynamics governing the evolution can be written as

\[
\xi_{t+1} = g(\xi_t; \theta),
\]

where g is a first-order Euler discretization of the nonlinear ODE (1) and \(\theta\) collects the unknown dynamic parameters. The maps g inform the evolution of the entire time series \(\xi_1, \dots, \xi_N\). Given an initial and possibly inaccurate state \(\xi_0\), we form the vector \(\xi = (\xi_1, \dots, \xi_N)\), and describe the dynamics of the entire i-th component in compact form as \(G(\xi; \theta) = 0\), with

\[
G(\xi; \theta) =
\begin{bmatrix}
\xi_1 - g(\xi_0; \theta) \\
\xi_2 - g(\xi_1; \theta) \\
\vdots \\
\xi_N - g(\xi_{N-1}; \theta)
\end{bmatrix}. \tag{2}
\]
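For concreteness, a minimal sketch of the Euler step g and the stacked residual map G of (2), under the reconstructed form of (1) and with the step size dt as a free parameter:

```python
# Euler step g and stacked residual G for one component, as in (2).
# xi is an (N, 2) array of states (x_t, y_t); theta = mu; dt is the step size.
import numpy as np

def g(state, mu, dt):
    x, y = state
    return np.array([x + dt * y,
                     y + dt * (mu * (1.0 - x**2) * y - x)])

def G(xi, xi0, mu, dt):
    """Stacked dynamics residuals: zero iff xi follows the discretized ODE."""
    prev = np.vstack([xi0, xi[:-1]])               # xi_0, xi_1, ..., xi_{N-1}
    return xi - np.array([g(s, mu, dt) for s in prev])
```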

Given noisy observations \(z_t\) of the activity variable \(x_t\), stacked into a vector z, we consider the following ODE-constrained optimization problem for the i-th component:

\[
\min_{\xi,\,\theta} \;\; \frac{1}{2}\,\|P\xi - z\|^2 \quad \text{subject to} \quad G(\xi; \theta) = 0, \tag{3}
\]

where P is the linear operator that selects the observed coordinates of \(\xi\).

Problem (3) is challenging because (1) the ODE constraint function G is nonlinear in \(\xi\), and (2) it is a joint optimization problem over \(\xi\) and \(\theta\). To solve this problem, we use the technique of partial minimization [aravkin2017efficient] (for particular instances, partial minimization is often called variable projection), often used in PDE-constrained optimization [van2015penalty].

Rewriting (3) with a quadratic penalty, we obtain the relaxed problem

\[
\min_{\xi,\,\theta} \;\; f_\eta(\xi, \theta) := \frac{1}{2}\,\|P\xi - z\|^2 + \frac{\eta}{2}\,\|G(\xi; \theta)\|^2. \tag{4}
\]

The key idea is to then use partial minimization with respect to \(\xi\) at each iteration in \(\theta\), and optimize the value function

\[
v(\theta) := \min_{\xi}\; f_\eta(\xi, \theta).
\]

The intuitive advantages of this method (find the best state estimate for each parameter regime) are borne out by theory. In particular, for a large class of models, the objective function \(v(\theta)\) is well-behaved for large \(\eta\), unlike the joint objective \(f_\eta(\xi, \theta)\) [aravkin2017efficient] (the Lipschitz constant of the gradient of v stays bounded as \(\eta \to \infty\), which is clearly false for \(f_\eta\)).

Evaluating \(v(\theta)\) requires an inner minimization routine over \(\xi\). We compute the gradient and a Gauss-Newton Hessian approximation

\[
\nabla_\xi f_\eta = P^\top (P\xi - z) + \eta\, \nabla_\xi G(\xi;\theta)^\top G(\xi;\theta), \qquad
\nabla^2_\xi f_\eta \approx P^\top P + \eta\, \nabla_\xi G(\xi;\theta)^\top \nabla_\xi G(\xi;\theta),
\]

where \(\nabla_\xi G\) denotes the Jacobian of G with respect to \(\xi\). Evaluating \(v(\theta)\) requires obtaining an (approximate) minimizer \(\hat{\xi}(\theta) = \arg\min_\xi f_\eta(\xi, \theta)\). With \(\hat{\xi}(\theta)\) in hand, \(\nabla v(\theta)\) can be computed using the formula

\[
\nabla v(\theta) = \eta\, \nabla_\theta G\big(\hat{\xi}(\theta); \theta\big)^\top G\big(\hat{\xi}(\theta); \theta\big). \tag{5}
\]

The accuracy of the inner solve in \(\xi\) can be increased as the optimization over \(\theta\) proceeds. Constraints can also be placed on \(\theta\) to eliminate non-physical regimes or to incorporate prior information.

Extension to m components. In addition to estimating the dynamic parameters \(\mu_i\), we are also interested in inferring the connectivity matrix W. Extending the model to m components, let \(\xi\) contain the states of all m components, so that in particular the t-th block of \(\xi\) contains \(\big(x_t^{(1)}, y_t^{(1)}, \dots, x_t^{(m)}, y_t^{(m)}\big)\); and let \(\theta\) contain the parameter sets \(\theta^{(1)}, \dots, \theta^{(m)}\). We can now write down the full nonlinear process model as

\[
G(\xi; \theta) - A(W)\,\xi = 0, \tag{6}
\]

with A(W) the linear operator that applies the coupling term \(\sum_j W_{ij} x_j\) to the observed coordinate of each unit at every time step, and with the dynamics of the previous section replicated across the m components. Without the W matrix, this would be m independent models written jointly; the A(W) term adds linear coupling across the components.

The optimization approach for m components is analogous to the single-component case, but includes all m components simultaneously, and also infers the coupling matrix W:

\[
\min_{\xi,\,\theta,\,W} \;\; f_\eta(\xi, \theta, W) := \frac{1}{2}\,\|P\xi - z\|^2 + \frac{\eta}{2}\,\big\|G(\xi; \theta) - A(W)\,\xi\big\|^2.
\]

Just as for a single component, we optimize this objective using partial minimization in \(\xi\), working with the value function

\[
v(\theta, W) := \min_{\xi}\; f_\eta(\xi, \theta, W).
\]

For the m-parameter case, we optimize over \(\xi\) at each iteration using the Gauss-Newton method detailed in the previous section. The outer iteration is a fast projected-gradient method for minimizing \(v(\theta, W)\) subject to simple bound constraints.
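The following is a simplified, self-contained sketch of this inner-outer scheme. Assumptions: a generic residual callable G, the selection operator P given as a matrix, and a finite-difference gradient of v in place of the analytic formula (5); the paper's actual method uses a Gauss-Newton inner solver and a projected-gradient outer step, omitted here for brevity.

```python
# Sketch of variable projection (partial minimization) for problem (4):
# inner full minimization over the states xi, outer descent on v(theta).
import numpy as np
from scipy.optimize import minimize

def value_function(theta, z, P, G, eta, xi0):
    """v(theta) = min_xi 0.5||P xi - z||^2 + (eta/2)||G(xi, theta)||^2."""
    def f(xi):
        r1 = P @ xi - z
        r2 = G(xi, theta)
        return 0.5 * r1 @ r1 + 0.5 * eta * r2 @ r2
    res = minimize(f, xi0, method="L-BFGS-B")
    return res.fun, res.x

def variable_projection(theta0, z, P, G, eta, xi0, n_outer=50,
                        step=1e-3, eps=1e-6):
    theta, xi = np.asarray(theta0, float), xi0
    for _ in range(n_outer):
        v0, xi = value_function(theta, z, P, G, eta, xi)   # partial minimization
        grad = np.array([(value_function(theta + eps * e, z, P, G, eta, xi)[0]
                          - v0) / eps
                         for e in np.eye(theta.size)])     # finite-diff grad of v
        theta = theta - step * grad        # (projection onto bounds omitted)
    return theta, xi
```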

5 Learning van der Pol: Variable-Projection + Stochastic Search

Optimizing the van der Pol model can benefit considerably from a good initialization of its parameters, as we observed in multiple experiments. To improve initialization, we start with a random walk (stochastic search) in the parameter space, aimed at producing a reasonably good starting point for the optimization procedure: given a combination of parameters, we simulate time series using the corresponding van der Pol model, and measure the correlation and the mean-squared error between the simulated and the real (training) data, discarding parameters whose performance metrics fall below a threshold. Once a sufficiently high-performing model is found, we switch to the variable-projection (VP) method described above, initialized with the current parameters, which are then optimized further (Figure 2). The whole process of alternating between stochastic search and VP optimization is repeated several times since, consistent with reported work [gabor2015robust, rodriguez2006novel], a hybrid stochastic-deterministic method performs better than a purely local optimization method on complex problems. This combined procedure will be referred to in subsequent sections simply as van der Pol optimization.

Figure 2: Van der Pol optimization procedure: variable-projection augmented with stochastic search.

Stochastic search: implementation details. We start with an initial guess for \(\mu\) (the same value for all univariate oscillators), a zero connectivity matrix W, and a random guess for the initial condition of the hidden variables, y(0). At every stochastic-search step, these parameters are updated as described below, and the differential equations with the new parameters are integrated; if the resulting time-series solution improves the fit to the training data, the new parameters are accepted, otherwise they are dismissed. As a measure of goodness of fit we use a linear combination of the Symmetric Mean Absolute Percentage Error (SMAPE) and the Pearson correlation. In the first stage of our search, we only update \(\mu\) and y(0), while keeping the weight matrix W at zero (i.e., disconnected components). In each step, one of the components is chosen randomly, and its corresponding \(\mu_i\) and \(y_i(0)\) are changed using a Gaussian random walk. Steps with larger variance are taken infrequently, to escape potential local minima. After this initialization, we update all parameters, including W; all components of W change at every (low-variance) random step.

VP + Stochastic Search: implementation details. We use up to 200,000 stochastic steps, with a maximum of 50 outer iterations of the VP optimization for every 1,000 stochastic steps. The matrix W is held at zero for the first 15,000 'burn-in' iterations. Large stochastic steps are performed after every 30 small steps; in the large steps, the variance of the \(\mu\) proposals increases by one order of magnitude (from 0.01 to 0.1), while the variance of the y(0) proposals remains the same (0.1), and all components are changed at once, rather than one at a time as in the smaller steps. A fixed value of the penalty parameter \(\eta\) is used in the VP optimization formulation.
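A hedged sketch of the accept/reject random walk follows. Here `simulate` and `score` are caller-supplied helpers: `simulate` integrates the model for given parameters, and `score` combines SMAPE and Pearson correlation (higher is better); the two-scale proposal schedule from the text is simplified to a single proposal scale per step.

```python
# Gaussian random-walk search: keep a proposal only if it improves the fit.
import numpy as np

def stochastic_search(params, simulate, score, train_data, n_steps=1000,
                      big_every=30, seed=0):
    rng = np.random.default_rng(seed)
    best = score(simulate(params), train_data)
    for step in range(1, n_steps + 1):
        sd = 0.1 if step % big_every == 0 else 0.01   # infrequent larger steps
        proposal = {k: v + rng.normal(0.0, sd, size=np.shape(v))
                    for k, v in params.items()}
        s = score(simulate(proposal), train_data)
        if s > best:                                  # accept only improvements
            params, best = proposal, s
    return params
```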

Time series prediction.
Once a van der Pol model is trained on a given time window, we can use it to predict future time series by integrating the model with the learned parameters and the estimated initial hidden state.

Interpretability.
Note that one of the advantages of the analytical van der Pol model is its interpretability: it learns the interaction matrix W among the different spatial components, i.e., brain areas.

6 Hybrid approach: vdP-LSTM

An alternative to learning an analytical model, such as the van der Pol, is to use a generic method for time-series prediction, such as recurrent neural networks, e.g., LSTM. Herein, we used the classical LSTM model proposed by [19-Hochreiter1997], a popular extension of Recurrent Neural Network (RNN) models with improved memory. Our LSTM networks, implemented in Keras, contained two layers with 128 units each, followed by a fully-connected layer with linear activation; we used the mean-squared-error loss and the RMSProp optimizer, and the dropout rate was set to 0.8. We used the LSTM for multivariate time-series prediction, where each time point t is represented by an m-dimensional vector (corresponding, in our case, to the m temporal components of the data at time t). We denote by LSTM(p) the model that uses the previous p time points to predict the (p+1)-st time point. The prediction of the following time step is performed by shifting the window of length p one step forward and using the prediction for the (p+1)-st data point as a new data point, iteratively. In our experiments, several values of p were tried, and the best-performing one was selected.
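As an illustration, a minimal Keras sketch of the architecture and the iterative forecasting loop described above; the exact placement of the dropout layers is our assumption, and p and m denote the window length and data dimension:

```python
# Two LSTM layers of 128 units, dense linear readout, MSE loss, RMSProp.
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout

def build_lstm(p, m):
    model = Sequential([
        LSTM(128, return_sequences=True, input_shape=(p, m)),
        Dropout(0.8),                      # dropout rate from the text
        LSTM(128),
        Dropout(0.8),
        Dense(m, activation="linear"),
    ])
    model.compile(loss="mse", optimizer="rmsprop")
    return model

def forecast(model, history, p, n_steps):
    """Iterative multi-step forecast: feed each prediction back as input."""
    window = np.asarray(history[-p:], dtype="float32")   # last p points, (p, m)
    preds = []
    for _ in range(n_steps):
        y = model.predict(window[None], verbose=0)[0]    # next point, shape (m,)
        preds.append(y)
        window = np.vstack([window[1:], y])              # slide the window
    return np.array(preds)
```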

vdP-LSTM: LSTM Pretrained with vdP Simulations. Training LSTMs requires a large number of samples, while our data were limited to only 500 time points (including both training and test subsets). On the other hand, given our prior knowledge about the data, such as nonlinear dynamical behavior of a certain type (e.g., van der Pol with specific parameters), one can hypothesize that providing the LSTM with information about such domain-specific dynamics can potentially improve its performance.

Thus, we propose a simple approach for providing a general-purpose LSTM with prior information about the domain-specific dynamics, namely a data-augmentation approach that pretrains the LSTM on a large amount of simulated data obtained from a fitted van der Pol model, before fine-tuning it on the relatively small amount of available real data. Such pretraining on data simulated from our analytical model serves as a regularizer in the absence of large training data sets, biasing the LSTM towards the type of dynamics we expect in the data. Training the van der Pol model on the same amount of data is easier, since it has far fewer parameters than a typical LSTM.

vdP-LSTM: implementation details. We train several van der Pol models on the training data; for each of those models, we simulate multiple noisy versions of the time series, of fixed length, obtained by integrating the model.

The LSTM was pretrained for 100 epochs on the above simulated data, and then trained for 50 epochs on the real training dataset; the numbers of epochs were selected so that the total number of samples used for training was the same for the simulated and the real training data.
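A minimal sketch of the pretraining-then-fine-tuning schedule; `make_windows` is a hypothetical helper that slices a series into LSTM training pairs, and the epoch counts follow the text:

```python
# Pretrain on simulated vdP trajectories, then fine-tune on real data.
import numpy as np

def make_windows(series, p):
    """Slice a (T, m) series into (n, p, m) inputs and (n, m) next-step targets."""
    X = np.stack([series[i:i + p] for i in range(len(series) - p)])
    return X, series[p:]

def pretrain_and_finetune(model, simulated_runs, real_series, p):
    # Pretraining: 100 epochs on windows pooled from all simulated vdP runs.
    X_sim, Y_sim = zip(*(make_windows(s, p) for s in simulated_runs))
    model.fit(np.concatenate(X_sim), np.concatenate(Y_sim), epochs=100)
    # Fine-tuning: 50 epochs on the (much smaller) real training data.
    X, Y = make_windows(real_series, p)
    model.fit(X, Y, epochs=50)
    return model
```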

7 Experiments

Figure 3: Spider graph representing the strongest links between the spatial supports of the components, as interpreted from W. Red represents negative links and blue positive ones.

We now present our empirical results, including (1) the van der Pol model fit on training data; (2) the predictive accuracy when forecasting time series using van der Pol, LSTM, and the hybrid vdP-LSTM, with a linear Vector Auto-Regressive (VAR) model used as a baseline; and (3) a brief discussion of interpreting the van der Pol interaction matrix.

Evaluating the van der Pol model fit on real data. We evaluated multiple runs of the van der Pol estimation procedure described above, combining stochastic search with VP optimization. Figure 4 shows the fit to the training data achieved by one of the best-performing models; the correlations between the actual data and the model predictions are high, ranging from 0.76 to 0.83 across all six components.

Interpretability. In the bottom-right corner of Figure 4, we plot the coupling matrix W. We interpret the entries of W as the effective connectivity between the spatial supports of the components. Thus, W contains interesting information about interactions (positive and negative) across different brain regions/subnetworks. For example, we observe a strong positive interaction between components 4 and 5, which correspond to the brain areas where a "flip-flop" oscillating behavior can be clearly observed (e.g., see the 2D version of the temporal data at https://youtu.be/lppAwkek6DI). Figure 3 presents a spider graph representing the strongest links between the spatial supports of the components, obtained from W. Using current knowledge of zebrafish and human neuroanatomy, it is possible to validate to what extent this effective connectivity (at least in absolute value) is consistent with real neural tracts.

Figure 4: Van der Pol model fit on training data; correlations between the true and predicted time series for each of the six temporal SVD components. Bottom right: the interaction matrix W.

Prediction on test data. Figures 5 and 6 show the median correlation between the true and predicted values, and the root-mean-square error, respectively, for several predictive methods: the vector autoregressive (VAR) model (red), the van der Pol model (green), LSTM (blue), and vdP-LSTM, i.e., LSTM pretrained on the data simulated using the above van der Pol model (orange). Here we estimated the parameters of the models on 100 consecutive points of training data, and then predicted the next 30 points (the x-axis plots the index of the time point being predicted). The shaded area around each curve represents the standard error. The linear VAR model (red) performs poorly, unable to capture the nonlinear dynamics; van der Pol (green) outperforms LSTM (blue) in the beginning, but then LSTM catches up; the hybrid vdP-LSTM model combines the best of both. Similarly, the hybrid approach performs best in terms of RMSE (Figure 6).

7.1 Functional MRI Data

We also tested our approach on functional MRI (fMRI) data and obtained promising preliminary results. Though VP optimization was not yet applied on top of the stochastic search (those experiments are in progress), we already obtained results similar to the ones seen on calcium data. We used resting-state fMRI data from 10 healthy control subjects, obtained from the Track-On HD dataset [TRACK-ON]. For each subject we had two datasets, corresponding to two different visits. We used 15 ICA components, with 160 time points each. The datasets from the first visit were used for training, and the ones from the second visit were used for testing. For each training dataset, we ran the stochastic search 10 times, and from each run used the 50 models that correlated highest with the training data for subsequent simulations and LSTM pretraining; i.e., for each subject, we simulated 15 coupled time series, each of length 160, from 500 different (but related) van der Pol models.

In addition, for comparison with a standard method of data augmentation, 500 noisy datasets (i.e., multivariate time series, with 15 components), also of 160 time steps, were created from each subject’s training dataset by adding Gaussian noise with mean 0 and standard deviation 0.1 to the normalized real data.
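A one-function sketch of this baseline augmentation, where the array `series` stands for the (hypothetical) normalized training data of one subject:

```python
# Baseline augmentation: add N(0, 0.1) noise copies of the normalized data.
import numpy as np

def noisy_copies(series, n_copies=500, sd=0.1, seed=0):
    rng = np.random.default_rng(seed)
    return [series + rng.normal(0.0, sd, size=series.shape)
            for _ in range(n_copies)]
```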

The LSTM architecture used was the same as for the calcium imaging experiments. Each subject was treated separately, as a different instance of the experiment. Each dataset contained 15 time series, with 160 time points each. The baseline LSTM was trained for 15 epochs. The data-augmented LSTMs were first trained for 15 epochs, either on the noisy datasets or on the van der Pol-simulated data described above, followed by 15 epochs of training on the real training data (the first-visit data for a given subject).

Figures 7 and 8 summarize the correlation and RMSE performance, respectively, of the several methods we tried on the fMRI data: VAR, LSTM, LSTM pretrained with noisy versions of the real data (the standard data-augmentation approach), and vdP-LSTM (LSTM pretrained on van der Pol-simulated data). We see a much smaller standard error (shaded areas around the plots), due to the larger number of experiments per point (more fMRI data). Overall, we clearly see that VAR performs poorly, and that vdP-LSTM, augmented with simulated data, outperforms both LSTM and LSTM augmented with noisy data, in terms of both correlation and RMSE.

Figure 5: Calcium imaging: predictive performance measured by correlation between the actual and predicted time series.
Figure 6: Calcium imaging: predictive performance measured by the root mean square error (RMSE) between the actual and predicted time series.
Figure 7: Functional MRI: predictive performance measured by correlation between the actual and predicted time series.
Figure 8: Functional MRI: predictive performance measured by the root mean square error (RMSE) between the actual and predicted time series.

8 Conclusions

Motivated by the challenging problem of modeling the nonlinear dynamics of brain activations in calcium imaging, we propose a new approach for learning a nonlinear differential equation model: a variable-projection optimization approach to estimate the parameters of the multivariate coupled van der Pol oscillator. We show how to learn this nonlinear dynamical model, and demonstrate that it can accurately capture the nonlinear dynamics of brain data. Furthermore, in order to improve predictive accuracy when forecasting future brain activity, we used the learned van der Pol model to pretrain LSTM networks, thus imposing an oscillator prior on the LSTM; the resulting approach achieves the highest predictive accuracy among all methods we evaluated.

References