Combining Generative and Discriminative Models for Hybrid Inference

06/06/2019 ∙ by Victor Garcia Satorras, et al. ∙ University of Amsterdam

A graphical model is a structured representation of the data generating process. The traditional method to reason over random variables is to perform inference in this graphical model. However, in many cases the generating process is only a poor approximation of the much more complex true data generating process, leading to suboptimal estimation. The subtleties of the generative process are, however, captured in the data itself, and we can “learn to infer”, that is, learn a direct mapping from observations to explanatory latent variables. In this work we propose a hybrid model that combines graphical inference with a learned inverse model, which we structure as a graph neural network, while the iterative algorithm as a whole is formulated as a recurrent neural network. By using cross-validation we can automatically balance the amount of work performed by graphical inference versus learned inference. We apply our ideas to the Kalman filter, a Gaussian hidden Markov model for time sequences, and show, among other things, that our model can estimate the trajectory of a noisy chaotic Lorenz attractor much more accurately than either the learned or the graphical inference run in isolation.




1 Introduction

Before deep learning, one of the dominant paradigms in machine learning was graphical models bishop2006pattern ; murphy2012probabilistic ; koller2009probabilistic . Graphical models structure the space of (random) variables by organizing them into a dependency graph. For instance, some variables are parents/children (directed models) or neighbors (undirected models) of other variables. These dependencies are encoded by conditional probabilities (directed models) or potentials (undirected models). While these interactions can have learnable parameters, the structure of the graph imposes a strong inductive bias onto the model. Reasoning in graphical models is performed by a process called probabilistic inference, where the posterior distribution, or the most probable state of a set of variables, is computed given observations of other variables. Many approximate algorithms have been proposed to solve this problem efficiently, among which are MCMC sampling neal2011mcmc ; salimans2015markov , variational inference kingma2013auto and belief propagation algorithms crick2002loopy ; koller2009probabilistic .

Graphical models are a kind of generative model where we specify important aspects of the generative process. They excel in the low data regime because we maximally utilize expert knowledge (a.k.a. inductive bias). However, human imagination often falls short of modeling all of the intricate details of the true underlying generative process. In the large data regime there is an alternative strategy which we could call “learning to infer”. Here, we create many data pairs of observed variables and the corresponding latent (unobserved) random variables. These pairs can be sampled from the generative model or be available directly in the dataset. Our task is then to learn a flexible mapping that infers the latent variables directly from the observations. This idea is known as “inverse modeling” in some communities. It is also known as “amortized” inference rezende2015variational , or as recognition networks in the world of variational autoencoders kingma2013auto and Helmholtz machines dayan1995helmholtz .
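As a minimal illustration of “learning to infer”, the toy sketch below assumes a scalar linear-Gaussian generative model (not one of the models used in the experiments), samples (observation, latent) pairs from it, and fits a least-squares inverse map from observations to latents. With enough pairs, the learned weight approaches the exact posterior-mean coefficient $a/(a^2+s^2)$, so the regression has effectively amortized Bayesian inference for this model.

```python
import numpy as np

# Toy generative model (an assumption for illustration):
#   latent x ~ N(0, 1), observation y = a * x + noise, noise ~ N(0, s^2).
a, s = 2.0, 1.0
rng = np.random.default_rng(0)
n_pairs = 200_000
x = rng.normal(0.0, 1.0, size=n_pairs)         # latent variables
y = a * x + rng.normal(0.0, s, size=n_pairs)   # observations

# "Learning to infer": fit a direct inverse map x_hat = w * y on the sampled pairs.
# Least squares without a bias term, since both variables have zero mean.
w_learned = np.dot(y, x) / np.dot(y, y)

# For this model the exact posterior mean is E[x | y] = a / (a^2 + s^2) * y.
w_exact = a / (a**2 + s**2)
print(f"learned w = {w_learned:.4f}, exact w = {w_exact:.4f}")
```

A neural network would replace the single weight when the generative process is non-linear; the training signal (paired samples) stays the same.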

In this paper we consider inference as an iterative message passing scheme over the edges of the graphical model. We know that (approximate) inference in graphical models can be formulated as message passing, known as belief propagation, so this is a reasonable way to structure our computations. When we unroll these messages for N steps we have effectively created a recurrent neural network as our computation graph. We enrich the traditional messages with a learnable component whose function is to correct the original messages when enough data is available. In this way we create a hybrid message passing scheme with prior components from the graphical model and learned messages from data. The learned messages may be interpreted as a kind of graph convolutional neural network bruna2013spectral ; henaff2015deep ; kipf2016semi .

Our Hybrid model neatly trades off the benefit of using inductive bias in the small data regime and the benefit of a much more flexible and learnable inference network when sufficient data is available. In this paper we restrict ourselves to a sequential model known as a hidden Markov process.

Figure 1: Examples of inferred 5K-length trajectories for the Lorenz attractor, with the trainable models trained on a 30K-length trajectory. The mean squared errors from left to right are (Observations: 0.2462, GNN: 0.1777, E-Kalman Smoother: 0.0372, Hybrid: 0.0077).

2 The Hidden Markov Process

In this section we briefly explain the Hidden Markov Process and how we intend to extend it. In a Hidden Markov Model (HMM), a set of unobserved variables $X = \{x_1, \dots, x_K\}$ defines the state of a process at every timestep $t = 1, \dots, K$. The set of observable variables from which we want to infer the process states is denoted by $Y = \{y_1, \dots, y_K\}$. HMMs are used in diverse applications such as localization, tracking, weather forecasting and computational finance, among others. (In fact, the Kalman filter was used to land the Eagle on the Moon.)

We can express $p(X \mid Y)$ as the probability distribution of the hidden states given the observations. Our goal is to find which states $\hat{X}$ maximize this probability distribution. More formally:

$$\hat{X} = \arg\max_{X}\, p(X \mid Y) \tag{1}$$

Under the Markov assumption, i) the transition model is described by the transition probability $p(x_t \mid x_{t-1})$, and ii) the measurement model is described by $p(y_t \mid x_t)$. Both distributions are stationary for all $t$. The resulting graphical model can be expressed with the following equation:

$$p(X, Y) = p(x_1) \prod_{t=2}^{K} p(x_t \mid x_{t-1}) \prod_{t=1}^{K} p(y_t \mid x_t) \tag{2}$$


One of the best known approaches for inference problems in this graphical model is the Kalman Filter kalman1960new and Smoother rauch1965maximum . The Kalman Filter assumes that both the transition and measurement distributions are linear and Gaussian. The prior knowledge we have about the process is encoded in linear transition and measurement processes, and the uncertainty of the predictions with respect to the real system is modeled by Gaussian noise:

$$x_t = F x_{t-1} + q_t \tag{3}$$
$$y_t = H x_t + r_t \tag{4}$$

where $q_t$ and $r_t$ come from Gaussian distributions $q_t \sim \mathcal{N}(0, Q)$ and $r_t \sim \mathcal{N}(0, R)$, and $F$ and $H$ are the linear transition and measurement functions respectively. If the process from which we are inferring is actually Gaussian and linear, a Kalman Filter + Smoother with the right parameters is able to infer the optimal state estimates.
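For reference, the Kalman Filter + RTS Smoother baseline for equations 3 and 4 can be sketched in NumPy as below. This is a textbook formulation written for illustration, assuming a Gaussian prior $\mathcal{N}(x_0, P_0)$ on the first state and returning only the smoothed means; the function name and the toy random-walk usage are hypothetical.

```python
import numpy as np

def kalman_smoother(y, F, H, Q, R, x0, P0):
    """Kalman filter forward pass followed by the RTS backward pass.

    y: (K, m) observations; F: (n, n); H: (m, n); Q: (n, n); R: (m, m).
    x0, P0: mean and covariance of the Gaussian prior on the first state.
    Returns the smoothed state means, shape (K, n).
    """
    K, n = y.shape[0], F.shape[0]
    x_pred, P_pred = np.zeros((K, n)), np.zeros((K, n, n))
    x_filt, P_filt = np.zeros((K, n)), np.zeros((K, n, n))
    x, P = np.asarray(x0, dtype=float), np.asarray(P0, dtype=float)
    for t in range(K):
        if t > 0:                                  # predict with eq. 3
            x, P = F @ x, F @ P @ F.T + Q
        x_pred[t], P_pred[t] = x, P
        S = H @ P @ H.T + R                        # innovation covariance
        G = P @ H.T @ np.linalg.inv(S)             # Kalman gain
        x = x + G @ (y[t] - H @ x)                 # update with eq. 4
        P = (np.eye(n) - G @ H) @ P
        x_filt[t], P_filt[t] = x, P
    x_s = x_filt.copy()
    for t in range(K - 2, -1, -1):                 # RTS smoothing, backwards in time
        C = P_filt[t] @ F.T @ np.linalg.inv(P_pred[t + 1])
        x_s[t] = x_filt[t] + C @ (x_s[t + 1] - x_pred[t + 1])
    return x_s

# Usage on a hypothetical scalar random walk observed with unit-variance noise.
rng = np.random.default_rng(0)
K = 2000
x_true = np.cumsum(rng.normal(0.0, 0.1, size=K))[:, None]
y_obs = x_true + rng.normal(0.0, 1.0, size=(K, 1))
I1 = np.eye(1)
x_hat = kalman_smoother(y_obs, I1, I1, 0.01 * I1, I1, np.zeros(1), 10.0 * I1)
mse_obs = np.mean((y_obs - x_true) ** 2)
mse_smooth = np.mean((x_hat - x_true) ** 2)
```

When the data really follow equations 3 and 4 with known $Q$ and $R$, as in this toy case, the smoothed error is far below the raw observation error.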

The real world is usually non-linear and complex, so assuming that a process is linear may be a strong limitation. Some alternatives like the Extended Kalman Filter ljung1979asymptotic and the Unscented Kalman Filter wan2000unscented are used for non-linear estimation, but even when the functions are non-linear, they are still constrained by our knowledge about the dynamics of the process, which may differ from real-world behavior.

To model the complexities of the real world, we intend to learn them from data through flexible models such as neural networks. In this work we present a hybrid inference algorithm that combines the knowledge from a generative model (e.g. physics equations) with a function that is automatically learned from data using a neural network. In our experiments we show that this hybrid method outperforms both the graphical inference methods and the neural network methods, in low and high data regimes alike. In other words, our method benefits from the inductive bias in the limit of small data and from the high capacity of neural networks in the limit of large data. The model is shown to gracefully interpolate between these regimes.

3 Related Work

The proposed method has interesting relations with meta learning andrychowicz2016learning , since it learns more flexible messages on top of an existing algorithm. It is also related to structured prediction energy networks belanger2017end , which are discriminative models that exploit the structure of the output. Structured inference in relational outputs has been effective in a variety of tasks like pose estimation wei2016convolutional , activity recognition deng2016structure or image classification nauata2018structured . One of the closest works is Recurrent Inference Machines (RIM) putzky2017recurrent , where a generative model is also embedded into a Recurrent Neural Network (RNN). However, in that work graphical models played no role.

Another related line of research is the convergence of graphical models with neural networks: mirowski2009dynamic replaced the joint probabilities with trainable factors for time series data. Learning the messages in conditional random fields has been effective in segmentation tasks chen2014semantic ; zheng2015conditional . Relatedly, johnson2016composing runs message passing algorithms on top of a latent representation learned by a deep neural network. More recently, yoon2018inference showed the efficacy of using Graph Neural Networks (GNNs) for inference on a variety of graphical models, and compared their performance with classical inference algorithms. This last work is in a similar vein to ours, but in our case the learned messages are used to correct the messages from graphical inference. In the experiments we show that this hybrid approach improves over running GNNs in isolation.

The Kalman Filter is a widely used algorithm for inference in Hidden Markov Processes. Some works have explored the direction of coupling them with machine learning techniques. A method to discriminatively learn the noise parameters of a Kalman Filter was introduced by abbeel2005discriminative . In order to input more complex variables, haarnoja2016backprop back-propagates through the Kalman Filter such that an encoder can be trained at its input. Similarly, coskun2017long replaces the dynamics defined in the Kalman Filter with a neural network. In our hybrid model, instead of replacing the already considered dynamics, we simultaneously train a learnable function for the purpose of inference.

4 Model

Figure 2: Graphical illustration of our Hybrid algorithm. The GM-module (blue box) sends messages to the GNN-module (red box), which refines the state estimates.

We cast our inference model as a message passing scheme where the nodes of a probabilistic graphical model can send messages to each other to infer estimates of the states $X$. Our aim is to develop a hybrid scheme where messages derived from the generative graphical model are combined with GNN messages:

Graphical Model Messages (GM-messages): These messages are derived from the generative graphical model (e.g. equations of motion from a physics model).

Graph Neural Network Messages (GNN-messages): These messages are learned by a GNN which is trained to reduce the inference error on labelled data in combination with the GM-messages.

In the following two subsections we introduce the two types of messages and the final hybrid inference scheme.

4.1 Graphical Model Messages

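In order to define the GM-messages, we interpret inference as an iterative optimization process to estimate the maximum likelihood values of the states $X$. In its most generic form, the recursive update for each consecutive estimate of $X$ is given by:

$$X^{(i+1)} = X^{(i)} + \gamma\, \nabla_{X^{(i)}} \log p\big(X^{(i)} \mid Y\big) \tag{5}$$

where $\gamma$ is a step size and $i$ indexes the iteration.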


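Factorizing equation 5 according to the hidden Markov Process of equation 2, we get three input messages for each inferred node $x_t$:

$$x_t^{(i+1)} = x_t^{(i)} + \gamma M_t^{(i)}, \qquad M_t^{(i)} = \mu_{x_{t-1}\to x_t}^{(i)} + \mu_{x_{t+1}\to x_t}^{(i)} + \mu_{y_t\to x_t}^{(i)} \tag{6}$$
$$\mu_{x_{t-1}\to x_t}^{(i)} = \frac{\partial}{\partial x_t^{(i)}} \log p\big(x_t^{(i)} \mid x_{t-1}^{(i)}\big) \tag{7}$$
$$\mu_{x_{t+1}\to x_t}^{(i)} = \frac{\partial}{\partial x_t^{(i)}} \log p\big(x_{t+1}^{(i)} \mid x_t^{(i)}\big) \tag{8}$$
$$\mu_{y_t\to x_t}^{(i)} = \frac{\partial}{\partial x_t^{(i)}} \log p\big(y_t \mid x_t^{(i)}\big) \tag{9}$$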


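These messages can be obtained by computing the three derivatives of equations 7, 8 and 9. It is often assumed that the transition and measurement distributions $p(x_t \mid x_{t-1})$ and $p(y_t \mid x_t)$ are linear and Gaussian (e.g. the Kalman Filter model). Next, we provide the expressions of the GM-messages when assuming the linear and Gaussian functions of equations 3 and 4:

$$\mu_{x_{t-1}\to x_t}^{(i)} = -Q^{-1}\big(x_t^{(i)} - F x_{t-1}^{(i)}\big) \tag{10}$$
$$\mu_{x_{t+1}\to x_t}^{(i)} = F^{\top} Q^{-1}\big(x_{t+1}^{(i)} - F x_t^{(i)}\big) \tag{11}$$
$$\mu_{y_t\to x_t}^{(i)} = H^{\top} R^{-1}\big(y_t - H x_t^{(i)}\big) \tag{12}$$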
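The Gaussian GM-messages and the GM-only recursion of equation 6 can be sketched in NumPy as follows. The sketch assumes a flat prior on $x_1$, so that the first and last nodes only receive messages from the neighbours they have; the names `gm_messages`, `log_joint` and `gm_inference` are hypothetical. The usage checks the messages against a finite-difference gradient of $\log p(X, Y)$, which is exactly what equations 7, 8 and 9 state, and verifies that the recursion increases the log joint.

```python
import numpy as np

def gm_messages(X, Y, F, H, Q, R):
    """Summed Gaussian GM-messages M_t for all nodes; X: (K, n), Y: (K, m)."""
    Qi, Ri = np.linalg.inv(Q), np.linalg.inv(R)
    M = (Y - X @ H.T) @ Ri.T @ H           # row t: H^T R^-1 (y_t - H x_t)
    d = X[1:] - X[:-1] @ F.T               # row t-1: x_t - F x_{t-1}
    M[1:] -= d @ Qi.T                      # from x_{t-1}: -Q^-1 (x_t - F x_{t-1})
    M[:-1] += d @ Qi.T @ F                 # from x_{t+1}: F^T Q^-1 (x_{t+1} - F x_t)
    return M

def log_joint(X, Y, F, H, Q, R):
    """log p(X, Y) of eq. 2 up to a constant (flat prior on x_1)."""
    Qi, Ri = np.linalg.inv(Q), np.linalg.inv(R)
    d, e = X[1:] - X[:-1] @ F.T, Y - X @ H.T
    return -0.5 * (np.einsum("ti,ij,tj->", d, Qi, d) + np.einsum("ti,ij,tj->", e, Ri, e))

def gm_inference(Y, F, H, Q, R, gamma, n_iters):
    """GM-only recursion of eq. 6, initialised at argmax p(Y | X)."""
    X = Y @ np.linalg.pinv(H).T
    for _ in range(n_iters):
        X = X + gamma * gm_messages(X, Y, F, H, Q, R)
    return X

# Usage on a small random problem (illustrative values).
rng = np.random.default_rng(0)
K, n, m = 6, 2, 1
F = np.array([[1.0, 0.1], [0.0, 1.0]])
H = np.array([[1.0, 0.0]])
Q, R = 0.05 * np.eye(n), np.array([[0.5]])
X0, Y = rng.normal(size=(K, n)), rng.normal(size=(K, m))

# Central finite differences of log p(X, Y) should match the GM-messages.
eps = 1e-6
num_grad = np.zeros_like(X0)
for idx in np.ndindex(*X0.shape):
    dX = np.zeros_like(X0)
    dX[idx] = eps
    num_grad[idx] = (log_joint(X0 + dX, Y, F, H, Q, R)
                     - log_joint(X0 - dX, Y, F, H, Q, R)) / (2 * eps)
max_err = np.max(np.abs(num_grad - gm_messages(X0, Y, F, H, Q, R)))

X_hat = gm_inference(Y, F, H, Q, R, gamma=0.005, n_iters=200)
X_init = Y @ np.linalg.pinv(H).T
```

The step size must be small relative to the curvature set by $Q^{-1}$ and $R^{-1}$; with these illustrative values, $\gamma = 0.005$ keeps the gradient ascent stable.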


4.2 Adding GNN-messages

We call the collection of nodes of the graphical model $V = X \cup Y$. We also define an equivalent graph on which the GNN operates by propagating the GNN messages. We build a one-to-one mapping from the nodes of the graphical model to the nodes of the GNN, assigning a hidden state $h_{x_t}$ to every $x_t \in X$ and a hidden state $h_{y_t}$ to every $y_t \in Y$; the union of both collections, $\mathcal{H} = \{h_{x_t}\} \cup \{h_{y_t}\}$, is the set of GNN node states. Therefore, each node of the graphical model has a corresponding node in the GNN. The edges of both graphs are also equivalent. The states $h_{x_t}$, which correspond to unobserved variables, are randomly initialized. Instead, the states $h_{y_t}$ are obtained by forwarding $y_t$ through an MLP.

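Next we present the equations of the learned messages, which consist of a GNN message passing operation. Similarly to li2015gated ; kipf2018neural , a GRU chung2014empirical is added to the message passing operation to make it recursive:

$$m_{k,n}^{(i)} = z_{k,n}\, f_e\big(h_{v_k}^{(i)}, h_{v_n}^{(i)}, \mu_{v_n \to v_k}^{(i)}\big) \quad \text{(message from GNN nodes to edge factor)} \tag{13}$$
$$U_{v_k}^{(i)} = \sum_{v_n \neq v_k} m_{k,n}^{(i)} \quad \text{(message from edge factors to GNN node)} \tag{14}$$
$$h_{v_k}^{(i+1)} = \mathrm{GRU}\big(U_{v_k}^{(i)}, h_{v_k}^{(i)}\big) \quad \text{(RNN update)} \tag{15}$$
$$\epsilon_{v_k}^{(i+1)} = f_{dec}\big(h_{v_k}^{(i+1)}\big) \quad \text{(computation of correction factor)} \tag{16}$$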
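A single round of equations 13-16 can be sketched in NumPy as below. Everything concrete here is an assumption for illustration: the edge-list graph representation, the random untrained weights, two-layer ReLU MLPs for $f_e$ and $f_{dec}$, a GRU without bias terms, messages flowing only into state nodes, and the tiny sizes. Only the structure (a message function per edge type, summation per node, a GRU update and a decoded correction) follows the equations.

```python
import numpy as np

rng = np.random.default_rng(0)
nf, d = 8, 2   # hidden features and state dimension (illustrative sizes)

def init_mlp(d_in, d_hidden, d_out):
    return (rng.normal(0, 0.1, (d_in, d_hidden)), np.zeros(d_hidden),
            rng.normal(0, 0.1, (d_hidden, d_out)), np.zeros(d_out))

def mlp(params, z):
    W1, b1, W2, b2 = params
    return np.maximum(z @ W1 + b1, 0.0) @ W2 + b2   # two-layer MLP with ReLU

# One edge function f_e per edge type (eq. 13) and a shared decoder f_dec (eq. 16).
f_e = {etype: init_mlp(2 * nf + d, nf, nf) for etype in ("transition", "measurement")}
f_dec = init_mlp(nf, nf, d)
W_gru = {gate: rng.normal(0, 0.1, (2 * nf, nf)) for gate in ("update", "reset", "cand")}

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_cell(u, h):
    """GRU update h' = GRU(u, h) applied row-wise to all nodes (biases omitted)."""
    uh = np.concatenate([u, h], axis=-1)
    zg = sigmoid(uh @ W_gru["update"])
    rg = sigmoid(uh @ W_gru["reset"])
    h_cand = np.tanh(np.concatenate([u, rg * h], axis=-1) @ W_gru["cand"])
    return (1.0 - zg) * h + zg * h_cand

def gnn_step(h, edges, gm_msgs):
    """One round of eqs. 13-16.

    h: (V, nf) node states; edges: (k, src, etype) triples, i.e. the pairs with
    z_{k,src} = 1 (all other pairs contribute nothing); gm_msgs[(k, src)] is the
    GM-message mu_{src->k} of shape (d,).
    """
    U = np.zeros_like(h)
    for k, src, etype in edges:
        inp = np.concatenate([h[k], h[src], gm_msgs[(k, src)]])
        U[k] += mlp(f_e[etype], inp)      # eqs. 13-14: edge messages summed per node
    h_new = gru_cell(U, h)                # eq. 15: recurrent update
    eps = mlp(f_dec, h_new)               # eq. 16: decoded correction per node
    return h_new, eps

# Usage on a 3-step chain: nodes 0..2 are x_t, nodes 3..5 are the matching y_t.
K = 3
edges = ([(t, t - 1, "transition") for t in range(1, K)]
         + [(t, t + 1, "transition") for t in range(K - 1)]
         + [(t, K + t, "measurement") for t in range(K)])
gm_msgs = {(k, src): rng.normal(size=d) for k, src, _ in edges}  # placeholder GM-messages
h0 = rng.normal(size=(2 * K, nf))
h1, eps = gnn_step(h0, edges, gm_msgs)
```

Because the GRU output is a gated mix of the previous state and a tanh candidate, each new hidden value stays within the larger of its previous magnitude and 1.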

Each GNN message is computed by the function $f_e$, which receives as input the two hidden states from the last recurrent iteration and their corresponding GM-message; this function is different for each type of edge (e.g. transition or measurement for the HMM). $z_{k,n}$ takes value 1 if there is an edge between $v_k$ and $v_n$; otherwise its value is 0. The sum of messages $U_{v_k}^{(i)}$ is provided as input to the GRU that updates the hidden state of each node. Finally, a correction signal $\epsilon^{(i+1)}$ is decoded from each hidden state and added to the recursive operation 6, resulting in the final equation:

$$x_t^{(i+1)} = x_t^{(i)} + \gamma\big(M_t^{(i)} + \epsilon_{x_t}^{(i+1)}\big) \tag{17}$$


In summary, equation 17 defines our hybrid model in a simple recursive form where $x_t$ is updated through two contributions: one that relies on the probabilistic graphical model messages, $M_t^{(i)}$, and one, $\epsilon_{x_t}^{(i+1)}$, that is automatically learned. We note that it is important that the GNN messages model the "residual error" of the GM inference process, which is often simpler than modeling the full signal. A visual representation of the algorithm is shown in Figure 2.
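The recursion of equation 17 is independent of how the correction is produced. The sketch below (hypothetical names, NumPy) keeps the GM-messages and the learned correction as two callables, so that a zero correction recovers the GM-only recursion of equation 6. In the actual model the correction comes from the trained GNN; the lambdas in the usage are stand-ins.

```python
import numpy as np

def hybrid_inference(Y, gm_messages, correction, gamma, n_iters, X0, h0):
    """Eq. 17: x^(i+1) = x^(i) + gamma * (M^(i) + eps^(i+1)).

    gm_messages(X, Y) -> M: summed GM-messages per node (eq. 6).
    correction(X, M, h) -> (eps, h_new): stand-in for the GNN of eqs. 13-16.
    Returns the estimates after every iteration, as needed by the training loss.
    """
    X, h, history = X0.copy(), h0, []
    for _ in range(n_iters):
        M = gm_messages(X, Y)
        eps, h = correction(X, M, h)
        X = X + gamma * (M + eps)
        history.append(X.copy())
    return history

# Toy usage: observation-only GM-message M = y - x (H = 1, R = 1).
Y = np.array([[1.0], [2.0], [3.0]])
gm = lambda X, Y: Y - X
zero_corr = lambda X, M, h: (np.zeros_like(X), h)   # reduces to GM-messages only
cancel = lambda X, M, h: (-M, h)                    # correction that cancels M
hist_gm = hybrid_inference(Y, gm, zero_corr, 0.5, 10, np.zeros_like(Y), None)
hist_cancel = hybrid_inference(Y, gm, cancel, 0.5, 10, np.zeros_like(Y), None)
```

With the zero correction each step halves the distance to $y$, so after 10 iterations the estimate is $y(1 - 0.5^{10})$; the cancelling correction leaves the initial estimate untouched, which shows how much control the learned term has over the update.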

In the experimental section of this work we apply our model to the Hidden Markov Process; however, the above-mentioned GNN-messages are not constrained to this particular graphical structure. The GM-messages can also be obtained for other arbitrary graph structures by applying the recursive inference equation 5 to their respective graphical models.

4.3 Training procedure

In order to provide early feedback, the loss function is computed at every iteration with a weighted sum that emphasizes later iterations, $w_i = i/N$. More formally:

$$\mathcal{L}(\Theta) = \sum_{i=1}^{N} w_i\, \mathcal{L}\big(gt, \varphi(X^{(i)})\big) \tag{18}$$

where the function $\varphi$ extracts the part of the state estimate $X^{(i)}$ that is contained in the ground truth $gt$. In our experiments we use the mean square error for $\mathcal{L}$. The training procedure consists of three main steps. First, we initialize $X^{(0)}$ at the value that maximizes $p(Y \mid X)$. For example, in a trajectory estimation problem we set the position values of $X^{(0)}$ to the observed positions $Y$. Second, we tune the hyper-parameters of the graphical model as would be done with a Kalman Filter; these are usually the variances of the Gaussian distributions. Finally, we train the model using the above-mentioned loss.
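A minimal NumPy sketch of the weighted per-iteration loss, assuming linearly increasing weights $w_i = i/N$ and the mean square error. The callable `extract` plays the role of $\varphi$, and the two-iteration history in the usage is made up for illustration.

```python
import numpy as np

def weighted_iterative_loss(history, gt, extract):
    """sum_i w_i * MSE(gt, extract(X^(i))) with linearly increasing w_i = i / N."""
    N = len(history)
    return sum((i / N) * np.mean((gt - extract(X)) ** 2)
               for i, X in enumerate(history, start=1))

# Hypothetical two-iteration history of [position, velocity] estimates for 2 time steps.
gt = np.array([[0.0], [1.0]])                    # ground-truth positions
history = [np.array([[1.0, 5.0], [1.0, 5.0]]),   # X^(1): first position off by 1
           np.array([[0.0, 9.0], [1.0, 9.0]])]   # X^(2): positions exact
positions = lambda X: X[:, :1]                   # plays the role of phi
loss = weighted_iterative_loss(history, gt, positions)
print(loss)  # 0.25
```

Only the first iteration is penalised here (weight 1/2 times an MSE of 1/2), while the velocities never enter the loss because $\varphi$ keeps positions only.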


5 Experiments

In this section we compare our Hybrid model with the Kalman Smoother and a GNN. We show that our Hybrid model can leverage the benefits of both methods for different data regimes. Next we define the models used in the experiments:

Kalman Smoother: The Kalman Smoother is the widely known Kalman Filter algorithm kalman1960new + the RTS smoothing step rauch1965maximum . In experiments with a non-linear transition function we use the Extended Kalman Filter + the smoothing step, which we call the “E-Kalman Smoother”.
GM-messages: As a special case of our hybrid model, we remove the learned signal and base our predictions only on the graphical model messages of eq. 6.
GNN-messages: The GNN model is another special case of our model, in which all the GM-messages are removed and only GNN messages are propagated. Instead of decoding a refinement for the current estimate, we directly decode the estimate $x_t^{(i+1)}$ from the GNN hidden state of node $x_t$. The resulting algorithm is equivalent to a Gated Graph Neural Network li2015gated .
Hybrid model: This is our full model explained in section 4.2.

We fix the step size $\gamma$ of the recursion and use the Adam optimizer. The number of inference iterations used in the Hybrid model, GNN-messages and GM-messages is N=100. The number of features in the hidden layers of the learned components (the GRU and the 2-layer MLPs) is nf=64. In the trajectory estimation experiments, $y_t$ may take any value in $\mathbb{R}$, and shifting a trajectory to a previously unseen position may hurt the generalization performance of the neural network. To make the problem translation invariant, we modify $y_t$ before mapping it to its GNN node state: we use the difference between the current observed position and the previous one, and between the current one and the next one.
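The translation-invariant input described above can be sketched as follows. The zero padding at the first and last time steps is an assumption of this sketch, since the text does not specify how the boundaries are handled; `relative_observation_features` is a hypothetical name.

```python
import numpy as np

def relative_observation_features(Y):
    """Per-step features [y_t - y_{t-1}, y_{t+1} - y_t] for Y of shape (K, 2).

    Missing neighbours at the two ends are zero-padded (an assumption).
    """
    step = Y[1:] - Y[:-1]                          # y_t - y_{t-1} for t = 2..K
    back = np.vstack([np.zeros_like(Y[:1]), step])
    fwd = np.vstack([step, np.zeros_like(Y[:1])])
    return np.concatenate([back, fwd], axis=1)

# Shifting the whole trajectory leaves the features unchanged.
rng = np.random.default_rng(0)
Y = rng.normal(size=(5, 2))
feats = relative_observation_features(Y)
feats_shifted = relative_observation_features(Y + np.array([100.0, -50.0]))
```

Since only differences of positions enter the network, any constant offset of the trajectory cancels out.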

5.1 Linear dynamics

The aim of this experiment is to infer the position of every node in trajectories generated by linear and Gaussian equations. The advantage of using a synthetic environment is that we know in advance the original equations the motion pattern was generated from, and by providing the right linear and Gaussian equations to a Kalman Smoother we can obtain the optimal inferred estimate as a lower bound of the test loss.

Among other tasks, Kalman Filters are used to refine the noisy measurement of GPS systems. A physics model of the dynamics can be provided to the graphical model that, combined with the noisy measurements, gives a more accurate estimation of the position. The real world is usually more complex than the equations we may provide to our graphical model, leading to a gap between the assumed dynamics and the real world dynamics. Our hybrid model is able to fill this gap without the need to learn everything from scratch.

To show this, we generate synthetic trajectories $T = \{x_t, y_t\}_{t=1}^{K}$. Each state $x_t$ is a 6-dimensional vector that encodes position, velocity and acceleration $(p, v, a)$ for two dimensions. Each $y_t$ is a noisy measurement of the position, also for two dimensions. The transition dynamic is a non-uniform accelerated motion that also considers drag (air resistance):


Here the drag term represents the air resistance falkovich2011fluid , scaled by a constant that depends on the properties of the fluid and the object dimensions. Finally, an additional variable is used to non-uniformly accelerate the object.

To generate the dataset, we sample from the Markov process of equation 2, where the transition probability distribution $p(x_t \mid x_{t-1})$ and the measurement probability distribution $p(y_t \mid x_t)$ follow equations 3 and 4. The values for these distributions are described in the Appendix; in particular, $F$ is analytically obtained from the differential equations 19 above. We sample two different motion trajectories ranging from 50 to 100K time steps each, one for training and the other for validation. An additional 10K time-step trajectory is sampled for testing.
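Ancestral sampling from equations 3 and 4 can be sketched as below. The function name is hypothetical, and the per-axis uniform-acceleration matrix in the usage is only a stand-in: the dataset itself uses the drag dynamics of equation 19, whose matrices are described in the Appendix. The assumed state ordering is $[p_x, v_x, a_x, p_y, v_y, a_y]$.

```python
import numpy as np

def sample_lgssm(F, H, Q, R, x1, K, rng):
    """Ancestral sampling: x_t = F x_{t-1} + q_t,  y_t = H x_t + r_t (eqs. 3-4)."""
    n, m = F.shape[0], H.shape[0]
    X, Y = np.zeros((K, n)), np.zeros((K, m))
    x = np.asarray(x1, dtype=float)
    for t in range(K):
        if t > 0:
            x = F @ x + rng.multivariate_normal(np.zeros(n), Q)   # transition noise q_t
        X[t] = x
        Y[t] = H @ x + rng.multivariate_normal(np.zeros(m), R)    # measurement noise r_t
    return X, Y

# Stand-in per-axis model (uniform acceleration), NOT the drag dynamics of eq. 19.
dt = 0.1
F1 = np.array([[1.0, dt, dt**2 / 2], [0.0, 1.0, dt], [0.0, 0.0, 1.0]])
F = np.kron(np.eye(2), F1)                               # block-diagonal over x and y axes
H = np.kron(np.eye(2), np.array([[1.0, 0.0, 0.0]]))      # observe positions only
rng = np.random.default_rng(0)
X, Y = sample_lgssm(F, H, 1e-4 * np.eye(6), 0.25 * np.eye(2), np.zeros(6), 1000, rng)

# With zero noise the sampler reduces to repeated application of F.
x1 = np.array([0.0, 1.0, 0.5, 0.0, -1.0, 0.2])
X_det, Y_det = sample_lgssm(F, H, np.zeros((6, 6)), np.zeros((2, 2)), x1, 3, rng)
```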

Figure 3: MSE comparison with respect to the number of training samples for the linear dynamics dataset.

In contrast, the graphical model used by our algorithm is limited to a uniform motion pattern: the air friction is no longer considered, and velocity and acceleration are assumed to be uniform. As before, the parameters of the matrices for this uniform motion pattern are obtained analytically and described in the Appendix.


The Mean Square Error with respect to the number of training samples is shown for different algorithms in Figure 3. Note that the MSE curves of the Kalman Smoother and the GM-messages overlap in the plot, since both errors were exactly the same.

Our model outperforms both the GNN and the Kalman Smoother in isolation in all data regimes, and it has a significant edge over the Kalman Smoother when the number of samples is larger than 1K. This shows that our model is able to combine the advantages of prior knowledge and deep learning in a single framework: it benefits from the inductive bias of the graphical model equations when data is scarce, and simultaneously from the flexibility of the GNN when data is abundant.

A clear trade-off can be observed between the Kalman smoother and the GNN. The Kalman Smoother clearly performs better for low data regimes, while the GNN outperforms it for larger amounts of data (>10K). The hybrid model is able to benefit from the strengths of both.

5.2 Lorenz Attractor

The Lorenz equations describe a non-linear chaotic system used to model atmospheric convection. Learning the dynamics of this chaotic system in a supervised way is expected to be more challenging than for linear dynamics, making it an interesting evaluation of our Hybrid model. A Lorenz system is modelled by three differential equations that define the convection rate, the horizontal temperature variation and the vertical temperature variation of a fluid:

$$\frac{dx}{dt} = \sigma (y - x), \qquad \frac{dy}{dt} = x(\rho - z) - y, \qquad \frac{dz}{dt} = xy - \beta z \tag{20}$$


To generate a trajectory, we integrate the Lorenz equations 20 with a fine integration step and subsample the result at a coarser time step, resulting in a single trajectory of 104K time steps. Each point is then perturbed with Gaussian noise. From this trajectory, 4K time steps are separated for testing, and the remaining 100K time steps are equally split between training and validation partitions.

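Assuming $x_t$ is a 3-dimensional vector $[x, y, z]$, we can write the Lorenz differential equations 20 as $\dot{x}_t = A|_{x_t}\, x_t$, with the state-dependent dynamics matrix

$$A|_{x_t} = \begin{bmatrix} -\sigma & \sigma & 0 \\ \rho - z & -1 & 0 \\ y & 0 & -\beta \end{bmatrix},$$

and obtain the transition function labbe2014kalman using the Taylor expansion

$$F_t = I + \sum_{j=1}^{J} \frac{\left(A|_{x_t}\, \Delta t\right)^{j}}{j!},$$

where $I$ is the identity matrix, $\Delta t$ is the sampling time step and $J$ is the number of terms from the Taylor expansion. We only run simulations for J=1 and J=2, because we empirically found that the improvement was minimal for larger $J$. For the measurement model we use the identity matrix, $H = I$. For the noise distributions $Q$ and $R$ we use diagonal matrices. The graphical model then has a single hyper-parameter to tune.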
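The state-dependent matrix and the truncated Taylor expansion can be checked numerically with a short sketch. The parameter values σ = 10, ρ = 28, β = 8/3 are the classic chaotic setting and are an assumption here, as are the helper names. The checks confirm that $A|_{x_t} x_t$ reproduces equation 20 and that $J = 1$ reduces the transition to an explicit Euler step.

```python
import numpy as np
from math import factorial

def lorenz_rhs(s, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Right-hand side of the Lorenz equations for s = [x, y, z]."""
    x, y, z = s
    return np.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

def lorenz_A(s, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """State-dependent dynamics matrix with lorenz_rhs(s) == lorenz_A(s) @ s."""
    x, y, z = s
    return np.array([[-sigma, sigma, 0.0],
                     [rho - z, -1.0, 0.0],
                     [y, 0.0, -beta]])

def transition_matrix(A, dt, J):
    """Truncated Taylor expansion F = I + sum_{j=1}^{J} (A dt)^j / j!."""
    F = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for j in range(1, J + 1):
        term = term @ (A * dt)
        F = F + term / factorial(j)
    return F

# Usage at an arbitrary (illustrative) state and time step.
s = np.array([1.0, 2.0, 20.0])
dt = 0.01
A = lorenz_A(s)
F1 = transition_matrix(A, dt, J=1)
F2 = transition_matrix(A, dt, J=2)
```

Because $A$ must be re-evaluated at every state, the transition changes along the trajectory, which is the non-linearity discussed next.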

Figure 4: MSE comparison with respect to the number of training samples on the Lorenz Attractor.

Since the dynamics are non-linear, the dynamics matrix $A$ depends on the state values $[x, y, z]$. The presence of these variables inside the matrix introduces a simple non-linearity that makes the function much harder to learn.


The results in Figure 4 show that the GNN struggles to learn the dynamics of this chaotic system; e.g. it does not reach the MSE of the E-Kalman Smoother even in data regimes where the Hybrid model already performs far better than both. We attribute this difficulty to the fact that the dynamics matrix $A$ is different at every state $x_t$, which makes the dynamics harder to approximate.

This behavior differs from the previous experiment (linear dynamics), where both the Hybrid model and the GNN converged to the optimal solution for high data regimes. In this experiment, even when the GNN and the E-Kalman Smoother perform poorly, the Hybrid model gets closer to the optimal solution, significantly outperforming both of them in isolation. This shows that the Hybrid model benefits from the labeled data even in situations where its fully-supervised variant or the E-Kalman Smoother are unable to properly model the process. One reason for this could be that the residual dynamics (i.e. the error of the E-Kalman Smoother) are much more linear than the original dynamics and hence easier for the GNN to model.

Qualitative results of the trajectories estimated by the different algorithms on the Lorenz attractor are depicted in Figure 1. The plots correspond to a 5K-length test trajectory. All trainable algorithms have been trained on 30K-length trajectories.

5.3 Real World Dynamics: Michigan NCLT dataset

To demonstrate the generalizability of our Hybrid model to real-world datasets, we use the Michigan NCLT carlevaris2016university dataset, collected using a Segway robot moving around the University of Michigan’s North Campus. It comprises different trajectories for which the GPS measurements and the ground-truth location of the robot are provided. Given these noisy GPS observations, our goal is to infer a more accurate position of the Segway at a given time.

Algorithm MSE
Observations (Baseline) 3.4974
Kalman Smoother 3.0514
GM-Messages 3.0048
GNN-Messages 1.7891
Hybrid model 1.4771
Table 1: MSE for different methods on the Michigan NCLT dataset.

In our experiments we arbitrarily use the session with date 2012-01-22, which consists of a single trajectory of 6.1 km on a cloudy day. Sampling at 1 Hz results in 4,629 time steps, and after removing the parts with an unstable GPS signal, 4,344 time steps remain. Finally, we split the trajectory into three sections: 1,502 time steps for training, 1,469 for validation and 1,373 for testing. The GPS measurements are assumed to be the noisy measurements denoted by $y_t$.

For the transition and measurement graphical model distributions we assume the same uniform motion model used in section 5.1. The only parameters to learn from the graphical model are the variances of the measurement and transition distributions. The detailed equations are presented in the Appendix.


Our results show that our Hybrid model (1.4771 MSE) outperforms the GNN (1.7891 MSE), the Kalman Smoother (3.0514 MSE) and the GM-messages (3.0048 MSE). One advantage of the GNN and Hybrid methods on real-world datasets is that both can model temporal correlations in the noise, while the GM-messages and the Kalman Smoother assume the noise to be uncorrelated through time, as defined in the graphical model. In summary, this experiment shows that our hybrid model generalizes with good performance to a real-world dataset.

6 Discussion

In this work, we explored the combination of recent advances in neural networks (e.g. graph neural networks) with more traditional methods of graphical inference in hidden Markov models for time series. The result is a hybrid algorithm that benefits from the inductive bias of graphical models and from the high flexibility of neural networks. We demonstrated these benefits on three different trajectory estimation tasks: a linear dynamics dataset, a non-linear chaotic system (Lorenz attractor) and a real-world positioning system. In all three experiments, the Hybrid method learns to efficiently combine graphical inference with learned inference, outperforming both when run in isolation.

Possible future directions include exploring hybrid methods for performing probabilistic inference in more general graphical models, as well as learning the graphical model itself. In this work we used cross-validation to make sure we did not overfit the GNN part of the model to the data at hand, optimally balancing prior knowledge and data-driven inference. In the future we intend to explore a more principled Bayesian approach to this. Finally, hybrid models like the one presented in this paper can help improve the interpretability of model predictions due to their graphical model backbone.


Appendix A Appendix

A.1 Equation details for the linear dynamics experiment

Linear and Gaussian dataset matrices:

The differential equations that describe the dynamics are:


Therefore, the dynamics matrix is defined as:


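Using the following Taylor expansion we find the transition matrix:

$$F = e^{A\,\Delta t} = \sum_{j=0}^{\infty} \frac{(A\,\Delta t)^{j}}{j!} = I + A\,\Delta t + \frac{(A\,\Delta t)^{2}}{2!} + \dots$$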


The transition matrix for each dimension is:


Then, the transition matrix and noise distributions are:

The measurement matrix and noise distribution are:

Linear and Gaussian matrices used for hybrid and GM-messages models:

The differential equations that describe the dynamics are:


Then the transition matrix given the last equations is:


And the transition matrix and noise distributions are:

Finally the measurement distribution matrices are:

Such that the only parameter to optimize from the graphical model is the variance of the transition noise distribution

A.2 Equation details for the NCLT dataset

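For the NCLT dataset we use the uniform velocity motion equations. The differential equations that describe the dynamics are:

$$\frac{\partial p}{\partial t} = v, \qquad \frac{\partial v}{\partial t} = 0$$

Such that the transition matrix for one component is:

$$F_1 = \begin{bmatrix} 1 & \Delta t \\ 0 & 1 \end{bmatrix}$$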
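For the uniform velocity model, the Taylor expansion used for the transition matrices terminates after the linear term, because the per-component dynamics matrix is nilpotent. The sketch below checks this numerically with $\Delta t = 1$ s (the 1 Hz sampling of section 5.3) and assembles a full 2D transition matrix under the assumed state ordering $[p_x, v_x, p_y, v_y]$.

```python
import numpy as np
from math import factorial

dt = 1.0                                  # 1 Hz sampling
A1 = np.array([[0.0, 1.0],                # dp/dt = v
               [0.0, 0.0]])               # dv/dt = 0
# Truncated Taylor series sum_j (A1 dt)^j / j!; all terms with j >= 2 vanish.
F1 = sum(np.linalg.matrix_power(A1 * dt, j) / factorial(j) for j in range(5))
# Assumed state ordering [p_x, v_x, p_y, v_y]: one 2x2 block per spatial component.
F = np.kron(np.eye(2), F1)
```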


And the transition matrix and noise distribution are:

Finally the measurement distribution matrices are:

Such that the only parameters to optimize from the graphical model are the variances of the transition and measurement noise distributions.