Modeling often has two goals: first, to learn a flexible representation of complex high-dimensional data, such as images or speech recordings, and second, to find structure that is interpretable and generalizes to new tasks. Probabilistic graphical models (koller2009probabilistic; murphy2012machine) provide many tools to build structured representations, but often make rigid assumptions and may require significant feature engineering. Alternatively, deep learning methods allow flexible data representations to be learned automatically, but may not directly encode interpretable or tractable probabilistic structure. Here we develop a general modeling and inference framework that combines these complementary strengths.
Consider learning a generative model for video of a mouse. Learning interpretable representations for such data, and comparing them as the animal’s genes are edited or its brain chemistry altered, gives useful behavioral phenotyping tools for neuroscience and for high-throughput drug discovery (wiltschko2015mapping). Even though each image is encoded by hundreds of pixels, the data lie near a low-dimensional nonlinear manifold. A useful generative model must not only learn this manifold but also provide an interpretable representation of the mouse’s behavioral dynamics. A natural representation from ethology (wiltschko2015mapping) is that the mouse’s behavior is divided into brief, reused actions, such as darts, rears, and grooming bouts. Therefore an appropriate model might switch between discrete states, with each state representing the dynamics of a particular action. These two learning tasks, identifying an image manifold and a structured dynamics model, are complementary: we want to learn the image manifold in terms of coordinates in which the structured dynamics fit well. A similar challenge arises in speech (hinton2012deep), where high-dimensional spectrographic data lie near a low-dimensional manifold because they are generated by a physical system with relatively few degrees of freedom (deng1999computational) but also include the discrete latent dynamical structure of phonemes, words, and grammar (deng2004switching).
To address these challenges, we propose a new framework to design and learn models that couple nonlinear likelihoods with structured latent variable representations. Our approach uses graphical models for representing structured probability distributions while enabling fast exact inference subroutines, and uses ideas from variational autoencoders (kingma2013autoencoding; rezende2014stochastic) for learning not only the nonlinear feature manifold but also bottom-up recognition networks to improve inference. Thus our method enables the combination of flexible deep learning feature models with structured Bayesian (and even nonparametric (johnson2014svi)) priors. Our approach yields a single variational inference objective in which all components of the model are learned simultaneously. Furthermore, we develop a scalable fitting algorithm that combines several advances in efficient inference, including stochastic variational inference (hoffman2013stochastic), graphical model message passing (koller2009probabilistic), and backpropagation with the reparameterization trick (kingma2013autoencoding). Thus our algorithm can leverage conjugate exponential family structure where it exists to efficiently compute natural gradients with respect to some variational parameters, enabling effective second-order optimization (martens2015perspectives), while using backpropagation to compute gradients with respect to all other parameters. We refer to our general approach as the structured variational autoencoder (SVAE).
2 Latent graphical models with neural net observations
In this paper we propose a broad family of models. Here we develop three specific examples.
2.1 Warped mixtures for arbitrary cluster shapes
One particularly natural structure used frequently in graphical models is the discrete mixture model. By fitting a discrete mixture model to data, we can discover natural clusters or units. These discrete structures are difficult to represent directly in neural network models.
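Consider the problem of modeling the data shown in Fig. 0(a). A standard approach to finding the clusters in data is to fit a Gaussian mixture model (GMM) with a conjugate prior:

$$\pi \sim \mathrm{Dir}(\alpha), \qquad (\mu_k, \Sigma_k) \overset{\text{iid}}{\sim} \mathrm{NIW}(\lambda), \qquad z_n \mid \pi \overset{\text{iid}}{\sim} \pi, \qquad y_n \mid z_n, \{(\mu_k, \Sigma_k)\}_{k=1}^K \overset{\text{iid}}{\sim} \mathcal{N}(\mu_{z_n}, \Sigma_{z_n}).$$

However, the fitted GMM does not represent the natural clustering of the data (Fig. 0(b)). Its inflexible Gaussian observation model limits its ability to fit the data parsimoniously and to capture their natural semantics.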
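Instead of using a GMM, a more flexible alternative would be a neural network density model:

$$\gamma \sim p(\gamma), \qquad x_n \overset{\text{iid}}{\sim} \mathcal{N}(0, I), \qquad y_n \mid x_n, \gamma \overset{\text{iid}}{\sim} \mathcal{N}\big(\mu(x_n; \gamma), \Sigma(x_n; \gamma)\big), \tag{1}$$

where $\mu(x_n; \gamma)$ and $\Sigma(x_n; \gamma)$ depend on $x_n$ through some smooth parametric function, such as a multilayer perceptron (MLP), and where $p(\gamma)$ is a Gaussian prior (mackay1999density). This model fits the data density well (Fig. 0(c)) but does not explicitly represent discrete mixture components, which might provide insights into the data or natural units for generalization. See Fig. 1(a) for a graphical model.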
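By composing a latent GMM with nonlinear observations, we can combine the modeling strengths of both (IwaDuvGha2012warped), learning both discrete clusters and non-Gaussian cluster shapes:

$$\pi \sim \mathrm{Dir}(\alpha), \qquad (\mu_k, \Sigma_k) \overset{\text{iid}}{\sim} \mathrm{NIW}(\lambda), \qquad \gamma \sim p(\gamma),$$
$$z_n \mid \pi \overset{\text{iid}}{\sim} \pi, \qquad x_n \mid z_n, \{(\mu_k, \Sigma_k)\}_{k=1}^K \overset{\text{iid}}{\sim} \mathcal{N}(\mu_{z_n}, \Sigma_{z_n}), \qquad y_n \mid x_n, \gamma \overset{\text{iid}}{\sim} \mathcal{N}\big(\mu(x_n; \gamma), \Sigma(x_n; \gamma)\big).$$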
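To make the warped-mixture generative process concrete, here is a minimal NumPy sketch of ancestral sampling from it. It is an illustration under simplifying assumptions, not the authors' implementation: the mixture weights and cluster parameters are fixed rather than drawn from the Dirichlet and NIW priors, the MLP weights `gamma` are random placeholders, and the observation covariance is a fixed isotropic `obs_std**2 * I` instead of a learned $\Sigma(x_n; \gamma)$. The function names are hypothetical.

```python
import numpy as np

def mlp_mean(x, W1, b1, W2, b2):
    """Hypothetical one-hidden-layer MLP giving the observation mean mu(x; gamma)."""
    return np.tanh(x @ W1 + b1) @ W2 + b2

def sample_warped_mixture(n, pis, mus, covs, gamma, obs_std, rng):
    """Draw z_n ~ pi, x_n ~ N(mu_{z_n}, Sigma_{z_n}), y_n ~ N(mu(x_n; gamma), obs_std^2 I)."""
    z = rng.choice(len(pis), size=n, p=pis)
    x = np.stack([rng.multivariate_normal(mus[k], covs[k]) for k in z])
    y_mean = mlp_mean(x, *gamma)
    y = y_mean + obs_std * rng.standard_normal(y_mean.shape)
    return z, x, y

rng = np.random.default_rng(0)
K, latent_dim, hidden, obs_dim = 3, 2, 16, 2
pis = np.full(K, 1.0 / K)                          # fixed weights (assumption)
mus = rng.normal(scale=2.0, size=(K, latent_dim))  # fixed cluster means (assumption)
covs = np.stack([0.1 * np.eye(latent_dim)] * K)    # fixed cluster covariances (assumption)
gamma = (rng.normal(size=(latent_dim, hidden)), np.zeros(hidden),
         rng.normal(size=(hidden, obs_dim)), np.zeros(obs_dim))
z, x, y = sample_warped_mixture(500, pis, mus, covs, gamma, obs_std=0.05, rng=rng)
```

Each latent cluster is Gaussian, but the MLP bends it into a non-Gaussian shape in observation space, which is exactly what the plain GMM cannot express.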
2.2 Latent linear dynamical systems for modeling video
Now we consider a harder problem: generatively modeling video. Since a video is a sequence of image frames, a natural place to start is with a model for images. kingma2013autoencoding shows that the density network of Eq. (1) can accurately represent a dataset of high-dimensional images $\{y_n\}$ in terms of the low-dimensional latent variables $\{x_n\}$, each with an independent Gaussian distribution.
To extend this image model into a model for videos, we can introduce dependence through time between the latent Gaussian samples $\{x_t\}_{t=1}^T$. For instance, we can make each latent variable $x_{t+1}$ depend on the previous latent variable $x_t$ through a Gaussian linear dynamical system, writing

$$x_{t+1} = A x_t + B u_t, \qquad u_t \overset{\text{iid}}{\sim} \mathcal{N}(0, I),$$

where the matrices $A$ and $B$ have a conjugate prior. This model has low-dimensional latent states and dynamics as well as a rich nonlinear generative model of images. In addition, the timescales of the dynamics are represented directly in the eigenvalue spectrum of $A$, providing both interpretability and a natural way to encode prior information. See Fig. 1(c) for a graphical model.
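A small sketch of how timescales can be read off $A$: each eigenvalue $\lambda$ of $A$ defines a mode that decays like $|\lambda|^t = e^{-t/\tau}$, so $\tau = -1/\log|\lambda|$ steps. The damped-rotation matrix and the helper names below are assumptions chosen for illustration.

```python
import numpy as np

def sample_lds(A, B, x0, T, rng):
    """Simulate x_{t+1} = A x_t + B u_t with u_t ~ N(0, I)."""
    xs = [x0]
    for _ in range(T - 1):
        xs.append(A @ xs[-1] + B @ rng.standard_normal(B.shape[1]))
    return np.stack(xs)

def timescales(A):
    """Decay time constant tau (in steps) of each eigenmode, from |lambda|^t = exp(-t / tau)."""
    return -1.0 / np.log(np.abs(np.linalg.eigvals(A)))

angle = 0.1  # rotation per step: oscillation period 2*pi/0.1, about 63 steps
A = 0.95 * np.array([[np.cos(angle), -np.sin(angle)],
                     [np.sin(angle),  np.cos(angle)]])
B = 0.1 * np.eye(2)
xs = sample_lds(A, B, np.array([1.0, 0.0]), T=200, rng=np.random.default_rng(0))
taus = timescales(A)  # both modes: -1 / log(0.95), about 19.5 steps
```

The eigenvalue magnitudes set the decay time and their phases set the oscillation period, which is why a prior on $A$ is a natural place to encode expected behavioral timescales.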
2.3 Latent switching linear dynamical systems for parsing behavior from video
As a final example that combines both time series structure and discrete latent units, consider again the behavioral phenotyping problem described in Section 1. Drawing on graphical modeling tools, we can construct a latent switching linear dynamical system (SLDS) (fox2011slds) to represent the data in terms of continuous latent states that evolve according to a discrete library of linear dynamics, and drawing on deep learning methods we can generate video frames with a neural network image model.
At each time $t$ there is a discrete-valued latent state $z_t \in \{1, 2, \ldots, K\}$ that evolves according to Markovian dynamics. The discrete state indexes a set of linear dynamical parameters, and the continuous-valued latent state $x_t$ evolves according to the corresponding dynamics,

$$z_{t+1} \mid z_t, \pi \sim \pi_{z_t}, \qquad x_{t+1} = A^{(z_{t+1})} x_t + B^{(z_{t+1})} u_t, \qquad u_t \overset{\text{iid}}{\sim} \mathcal{N}(0, I),$$

where $\pi = \{\pi_k\}_{k=1}^K$ denotes the Markov transition matrix and $\pi_k$ is its $k$th row. We use the same neural net observation model as in Section 2.2. This SLDS model combines both continuous and discrete latent variables with rich nonlinear observations. See Fig. 1(d) for a graphical model.
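A minimal ancestral sampler for the SLDS prior, assuming two hypothetical discrete states with hand-picked dynamics (one slowly rotating, one decaying) and an arbitrary initial distribution. It draws $z_{t+1}$ from row $z_t$ of the transition matrix and then applies the selected linear dynamics; generating images from each $x_t$ with the neural network observation model is omitted.

```python
import numpy as np

def sample_slds(pi0, P, As, Bs, T, rng):
    """Sample (z_{1:T}, x_{1:T}); the initial-state distributions are assumptions."""
    dim = As[0].shape[0]
    z = np.empty(T, dtype=int)
    x = np.zeros((T, dim))
    z[0] = rng.choice(len(pi0), p=pi0)
    x[0] = rng.standard_normal(dim)
    for t in range(T - 1):
        z[t + 1] = rng.choice(P.shape[1], p=P[z[t]])  # z_{t+1} ~ pi_{z_t}
        k = z[t + 1]
        x[t + 1] = As[k] @ x[t] + Bs[k] @ rng.standard_normal(dim)
    return z, x

rng = np.random.default_rng(1)
P = np.array([[0.95, 0.05], [0.10, 0.90]])  # sticky transitions between two states
c, s = np.cos(0.3), np.sin(0.3)
As = [0.99 * np.array([[c, -s], [s, c]]), 0.9 * np.eye(2)]  # rotate vs. decay
Bs = [0.05 * np.eye(2), 0.2 * np.eye(2)]
z, x = sample_slds(np.array([0.5, 0.5]), P, As, Bs, T=300, rng=rng)
```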
3 Structured mean field inference and recognition networks
Why aren’t such rich hybrid models used more frequently? The main difficulty with combining rich latent variable structure and flexible likelihoods is inference. The most efficient inference algorithms used in graphical models, like structured mean field and message passing, depend on conjugate exponential family likelihoods to preserve tractable structure. When the observations are more general, like neural network models, inference must either fall back to general algorithms that do not exploit the model structure or else rely on bespoke algorithms developed for one model at a time.
In this section, we review inference ideas from conjugate exponential family probabilistic graphical models and variational autoencoders, which we combine and generalize in the next section.
3.1 Inference in graphical models with conjugacy structure
Graphical models and exponential families provide many algorithmic tools for efficient inference (wainwright2008graphical). Given an exponential family latent variable model, when the observation model is a conjugate exponential family, the conditional distributions stay in the same exponential families as in the prior and hence allow for the same efficient inference algorithms.
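For example, consider learning a Gaussian linear dynamical system model with linear Gaussian observations. The generative model for latent states $x = \{x_t\}_{t=1}^T$ and observations $y = \{y_t\}_{t=1}^T$ is

$$x_{t+1} = A x_t + B u_t, \quad u_t \overset{\text{iid}}{\sim} \mathcal{N}(0, I), \qquad y_t = C x_t + D v_t, \quad v_t \overset{\text{iid}}{\sim} \mathcal{N}(0, I),$$

given parameters $\theta = (A, B, C, D)$ with a conjugate prior $p(\theta)$. To approximate the posterior $p(\theta, x \mid y)$, consider the mean field family $q(\theta)q(x)$ and the variational inference objective

$$\mathcal{L}[q(\theta)q(x)] = \mathbb{E}_{q(\theta)q(x)}\left[\log \frac{p(\theta)\,p(x \mid \theta)\,p(y \mid x, \theta)}{q(\theta)\,q(x)}\right], \tag{7}$$

where we can optimize the variational family $q(\theta)q(x)$ to approximate the posterior $p(\theta, x \mid y)$ by maximizing Eq. (7). Because the observation model $p(y \mid x, \theta)$ is conjugate to the latent variable model $p(x \mid \theta)$, for any fixed $q(\theta)$ the optimal factor $q^*(x) \triangleq \arg\max_{q(x)} \mathcal{L}[q(\theta)q(x)]$ is itself a Gaussian linear dynamical system with parameters that are simple functions of the expected statistics of $q(\theta)$ and the data $y$. As a result, for fixed $q(\theta)$ we can easily compute $q^*(x)$ and use message passing algorithms to perform exact inference in it. However, when the observation model is not conjugate to the latent variable model, these algorithmically exploitable structures break down.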
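The claim that conjugacy keeps inference exact can be checked on a tiny example. The sketch below assumes a scalar state, scalar observations, a zero-mean initial state, and known parameters (that is, $q(\theta)$ collapsed to a point). It runs the forward (Kalman filter) pass of message passing and compares the final filtered marginal against brute-force conditioning of the joint Gaussian over $(x_{1:T}, y_{1:T})$. Function names are hypothetical.

```python
import numpy as np

def kalman_filter(y, a, b, c, d, m0, v0):
    """Forward messages for x_1 ~ N(m0, v0), x_{t+1} = a x_t + b u_t, y_t = c x_t + d v_t."""
    m_pred, v_pred = m0, v0
    means, variances = [], []
    for yt in y:
        k = v_pred * c / (c * c * v_pred + d * d)  # gain for conditioning on y_t
        m = m_pred + k * (yt - c * m_pred)
        v = (1.0 - k * c) * v_pred
        means.append(m)
        variances.append(v)
        m_pred, v_pred = a * m, a * a * v + b * b  # push the message through the dynamics
    return np.array(means), np.array(variances)

def brute_force_posterior(y, a, b, c, d, v0):
    """Condition the joint Gaussian over (x_{1:T}, y_{1:T}) directly, assuming m0 = 0."""
    T = len(y)
    var = np.empty(T)
    var[0] = v0
    for t in range(T - 1):
        var[t + 1] = a * a * var[t] + b * b
    # Cov(x_t, x_s) = a^{|s - t|} Var(x_{min(s, t)}).
    Sxx = np.array([[a ** abs(s - t) * var[min(s, t)] for s in range(T)] for t in range(T)])
    Syy = c * c * Sxx + d * d * np.eye(T)
    gain = c * Sxx @ np.linalg.inv(Syy)  # Cov(x, y) Cov(y, y)^{-1}
    return gain @ y, Sxx - gain @ (c * Sxx)

a, b, c, d, v0 = 0.9, 0.5, 1.2, 0.3, 1.0
y = np.random.default_rng(0).standard_normal(6)  # any observation vector works for this check
kf_means, kf_vars = kalman_filter(y, a, b, c, d, m0=0.0, v0=v0)
post_mean, post_cov = brute_force_posterior(y, a, b, c, d, v0)
# The filtered marginal of the last state conditions on all of y_{1:T}, so it must match.
```

The filter costs $O(T)$ while brute-force conditioning costs $O(T^3)$; this gap is what conjugate message passing buys.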
3.2 Recognition networks in variational autoencoders
The variational autoencoder (VAE) (kingma2013autoencoding) handles general non-conjugate observation models by introducing recognition networks. For example, when a Gaussian latent variable model $p(x)$ is paired with a general nonlinear observation model $p(y \mid x, \gamma)$, the posterior $p(x \mid y, \gamma)$ is non-Gaussian, and it is difficult to compute an optimal Gaussian approximation. The VAE instead learns to directly output a suboptimal Gaussian factor $q(x \mid y)$ by fitting a parametric map from data $y$ to a mean and covariance, $\mu(y; \phi)$ and $\Sigma(y; \phi)$, such as an MLP with parameters $\phi$. By optimizing over $\phi$, the VAE effectively learns how to condition on non-conjugate observations and produce a good approximating factor.
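A minimal sketch of such a Gaussian recognition network, assuming a one-hidden-layer tanh MLP, a diagonal covariance parameterized by its log variance, and random untrained weights; in a real VAE, $\phi$ would be trained jointly with the generative model. The helper names are hypothetical.

```python
import numpy as np

def init_mlp(sizes, rng):
    """Random (untrained) weights for a fully connected network with the given layer sizes."""
    return [(rng.normal(scale=1.0 / np.sqrt(m), size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def gaussian_recognition(y, params):
    """Map observations y to the mean and diagonal variance of q(x | y)."""
    h = y
    for W, b in params[:-1]:
        h = np.tanh(h @ W + b)
    W, b = params[-1]
    mean, log_var = np.split(h @ W + b, 2, axis=-1)
    return mean, np.exp(log_var)

rng = np.random.default_rng(0)
obs_dim, latent_dim = 5, 2
phi = init_mlp([obs_dim, 32, 2 * latent_dim], rng)
y = rng.standard_normal((4, obs_dim))
mean, var = gaussian_recognition(y, phi)
x = mean + np.sqrt(var) * rng.standard_normal(mean.shape)  # one sample from q(x | y)
```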
4 Structured variational autoencoders
We can combine the tractability of conjugate graphical model inference with the flexibility of variational autoencoders. The main idea is to use a conditional random field (CRF) variational family. We learn recognition networks that output conjugate graphical model potentials instead of outputting the complete variational distribution’s parameters directly. These potentials are then used in graphical model inference algorithms in place of the non-conjugate observation likelihoods.
The SVAE algorithm computes stochastic gradients of a mean field variational inference objective. It can be viewed as a generalization both of the natural gradient SVI algorithm for conditionally conjugate models (hoffman2013stochastic) and of the AEVB algorithm for variational autoencoders (kingma2013autoencoding). Intuitively, it proceeds by sampling a data minibatch, applying the recognition model to compute graphical model potentials, and using graphical model inference algorithms to compute the variational factor, combining the evidence from the potentials with the prior structure in the model. This variational factor is then used to compute gradients of the mean field objective. See Fig. 3 for graphical models of the variational families with recognition networks for the models developed in Section 2.
In this section, we outline the SVAE model class more formally, write the mean field variational inference objective, and show how to efficiently compute unbiased stochastic estimates of its gradients. The resulting algorithm for computing gradients of the mean field objective, shown in Algorithm 1, is simple and efficient and can be readily applied to a variety of learning problems and graphical model structures. See the supplementals for details and proofs.
4.1 SVAE model class
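To set up notation for a general SVAE, we first define a conjugate pair of exponential family densities on global latent variables $\theta$ and local latent variables $x = \{x_n\}_{n=1}^N$. Let $p(x \mid \theta)$ be an exponential family and let $p(\theta)$ be its corresponding natural exponential family conjugate prior, writing

$$p(\theta) = \exp\{\langle \eta_\theta^0, t_\theta(\theta) \rangle - \log Z_\theta(\eta_\theta^0)\},$$
$$p(x \mid \theta) = \exp\{\langle \eta_x^0(\theta), t_x(x) \rangle - \log Z_x(\eta_x^0(\theta))\} = \exp\{\langle t_\theta(\theta), (t_x(x), 1) \rangle\},$$

where we used exponential family conjugacy to write $t_\theta(\theta) = \big(\eta_x^0(\theta), -\log Z_x(\eta_x^0(\theta))\big)$. The local latent variables $x$ could have additional structure, like including both discrete and continuous latent variables or tractable graph structure, but here we keep the notation simple.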
Next, we define a general likelihood function. Let $p(y \mid x, \gamma)$ be a general family of densities and let $p(\gamma)$ be an exponential family prior on its parameters. For example, each observation $y_n$ may depend on the latent value $x_n$ through an MLP, as in the density network model of Section 2. This generic non-conjugate observation model provides modeling flexibility, yet the SVAE can still leverage conjugate exponential family structure in inference, as we show next.
4.2 Stochastic variational inference algorithm
Though the general observation model means that conjugate updates and natural gradient SVI (hoffman2013stochastic) cannot be directly applied, we show that by generalizing the recognition network idea we can still approximately optimize out the local variational factors by leveraging conjugacy structure.
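For fixed $y$, consider the mean field family $q(\theta)q(\gamma)q(x)$ and the variational inference objective

$$\mathcal{L}[q(\theta)q(\gamma)q(x)] \triangleq \mathbb{E}_{q(\theta)q(\gamma)q(x)}\left[\log \frac{p(\theta)\,p(\gamma)\,p(x \mid \theta)\,p(y \mid x, \gamma)}{q(\theta)\,q(\gamma)\,q(x)}\right]. \tag{10}$$

Without loss of generality we can take the global factor $q(\theta)$ to be in the same exponential family as the prior $p(\theta)$, and we denote its natural parameters by $\eta_\theta$. We restrict $q(\gamma)$ to be in the same exponential family as $p(\gamma)$ with natural parameters $\eta_\gamma$. Finally, we restrict $q(x)$ to be in the same exponential family as $p(x \mid \theta)$, writing its natural parameter as $\eta_x$. Using these explicit variational parameters, we write the mean field variational inference objective in Eq. (10) as $\mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x)$.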
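To perform efficient optimization of the objective $\mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x)$, we consider choosing the variational parameter $\eta_x$ as a function of the other parameters $\eta_\theta$ and $\eta_\gamma$. One natural choice is to set $\eta_x$ to be a local partial optimizer of $\mathcal{L}$. However, without conjugacy structure finding a local partial optimizer may be computationally expensive for general densities $p(y \mid x, \gamma)$, and in the large data setting this expensive optimization would have to be performed for each stochastic gradient update. Instead, we choose $\eta_x$ by optimizing over a surrogate objective $\widehat{\mathcal{L}}$ with conjugacy structure, given by

$$\widehat{\mathcal{L}}(\eta_\theta, \eta_x, \phi) \triangleq \mathbb{E}_{q(\theta)q(x)}\left[\log \frac{p(\theta)\,p(x \mid \theta)\exp\{\psi(x; y, \phi)\}}{q(\theta)\,q(x)}\right], \qquad \psi(x; y, \phi) \triangleq \langle r(y; \phi), t_x(x) \rangle,$$

where $\{r(y; \phi)\}_{\phi \in \mathbb{R}^m}$ is some parameterized class of functions that serves as the recognition model. Note that the potentials $\psi(x; y, \phi)$ have a form conjugate to the exponential family $p(x \mid \theta)$. We define $\eta_x^*(\eta_\theta, \phi)$ to be a local partial optimizer of $\widehat{\mathcal{L}}$ along with the corresponding factor $q^*(x)$,

$$\eta_x^*(\eta_\theta, \phi) \triangleq \arg\max_{\eta_x} \widehat{\mathcal{L}}(\eta_\theta, \eta_x, \phi), \qquad q^*(x) = \exp\{\langle \eta_x^*(\eta_\theta, \phi), t_x(x) \rangle - \log Z_x(\eta_x^*(\eta_\theta, \phi))\}.$$

As with the variational autoencoder of Section 3.2, the resulting variational factor $q^*(x)$ is suboptimal for the variational objective $\mathcal{L}$. However, because the surrogate objective has the same form as a variational inference objective for a conjugate observation model, the factor $q^*(x)$ not only is easy to compute but also inherits exponential family and graphical model structure for tractable inference.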
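Because $\psi$ is conjugate, the partial optimization defining $\eta_x^*$ has a closed form when $q(x)$ is a single factor: collecting the terms of $\widehat{\mathcal{L}}$ that involve $q(x)$ gives $q^*(x) \propto \exp\{\langle \mathbb{E}_{q(\theta)}[\eta_x^0(\theta)] + r(y; \phi), t_x(x) \rangle\}$, so the recognition output simply adds to the expected prior natural parameter. The sketch assumes a diagonal Gaussian local variable with statistic $t_x(x) = (x, -\tfrac{1}{2}x^2)$ per coordinate and a hypothetical linear recognition model whose precision output is kept positive with a softplus; it is an illustration, not the paper's code.

```python
import numpy as np

def recognition_potentials(y, W_h, W_J):
    """Hypothetical recognition model r(y; phi) returning potentials (h_r, J_r) with J_r > 0."""
    return y @ W_h, np.logaddexp(0.0, y @ W_J)  # softplus keeps precisions positive

def local_optimum(prior_nat, potentials):
    """eta_x^* = E_q(theta)[eta_x^0(theta)] + r(y; phi): natural parameters add."""
    (h0, J0), (hr, Jr) = prior_nat, potentials
    return h0 + hr, J0 + Jr

def to_mean_params(h, J):
    """For t_x(x) = (x, -x^2/2) per coordinate, q*(x) = N(h / J, 1 / J)."""
    return h / J, 1.0 / J

# Scalar check: prior N(1, 2) combined with a potential equivalent to a Gaussian factor N(1.5, 0.5).
h0, J0 = 1.0 / 2.0, 1.0 / 2.0
mean, var = to_mean_params(*local_optimum((h0, J0), (3.0, 2.0)))
# The product of N(1, 2) and N(1.5, 0.5) has precision 2.5 and mean 1.4.

# Batched use with the hypothetical recognition model.
rng = np.random.default_rng(0)
y = rng.standard_normal((8, 5))
W_h, W_J = rng.standard_normal((5, 3)), rng.standard_normal((5, 3))
h_star, J_star = local_optimum((np.zeros(3), np.ones(3)), recognition_potentials(y, W_h, W_J))
```

For structured $x$, such as a linear dynamical system, the same addition happens at each node's potential, and message passing then replaces the elementwise `to_mean_params` step.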
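Given this choice of $\eta_x^*$, the SVAE objective is $\mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) \triangleq \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x^*(\eta_\theta, \phi))$. This objective is a lower bound for the variational inference objective in Eq. (10) in the following sense.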
Proposition 4.1 (The SVAE objective lower-bounds the mean field objective)
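The SVAE objective function $\mathcal{L}_{\mathrm{SVAE}}$ lower-bounds the mean field objective $\mathcal{L}$ in the sense that

$$\max_{q(x)} \mathcal{L}[q(\theta)q(\gamma)q(x)] \geq \max_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x) \geq \mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) \qquad \forall \phi \in \mathbb{R}^m,$$

for any parameterized function class $\{r(y; \phi)\}_{\phi \in \mathbb{R}^m}$. Furthermore, if there is some $\phi^* \in \mathbb{R}^m$ such that $\psi(x; y, \phi^*) = \mathbb{E}_{q(\gamma)} \log p(y \mid x, \gamma)$, then the bound can be made tight in the sense that

$$\mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi^*) = \max_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x) = \max_{q(x)} \mathcal{L}[q(\theta)q(\gamma)q(x)].$$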
Thus by using gradient-based optimization to maximize $\mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi)$ we are maximizing a lower bound on the model log evidence $\log p(y)$. In particular, by optimizing over $\phi$ we are effectively learning how to condition on observations so as to best approximate the posterior while maintaining conjugacy structure. Furthermore, to provide the best lower bound we may choose the recognition model function class $\{r(y; \phi)\}_{\phi \in \mathbb{R}^m}$ to be as rich as possible.
Choosing $\eta_x^*(\eta_\theta, \phi)$ to be a local partial optimizer of $\widehat{\mathcal{L}}$ provides two computational advantages. First, it allows $\eta_x^*$ and expectations with respect to $q^*(x)$ to be computed efficiently by exploiting exponential family graphical model structure. Second, it provides computationally efficient ways to estimate the natural gradient with respect to the latent model parameters $\eta_\theta$, as we summarize next.
Proposition 4.2 (Natural gradient of the SVAE objective)
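The natural gradient of the SVAE objective $\mathcal{L}_{\mathrm{SVAE}}$ with respect to $\eta_\theta$ can be estimated as

$$\widetilde{\nabla}_{\eta_\theta} \mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) = \big(\eta_\theta^0 + \mathbb{E}_{q^*(x)}[(t_x(x), 1)] - \eta_\theta\big) + F(\eta_\theta)^{-1}\, \nabla_{\eta_\theta} \eta_x^*(\eta_\theta, \phi)\, \nabla_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x^*(\eta_\theta, \phi)),$$

where $F(\eta_\theta) \triangleq \nabla^2 \log Z_\theta(\eta_\theta)$. When there is only one local variational factor $q(x)$, we can simplify the estimator to

$$\widetilde{\nabla}_{\eta_\theta} \mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) = \big(\eta_\theta^0 + \mathbb{E}_{q^*(x)}[(t_x(x), 1)] - \eta_\theta\big) + \big(\nabla_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x^*(\eta_\theta, \phi)), 0\big).$$

Note that the first term in Eq. (13) is the same as the expression for the natural gradient in SVI for conjugate models (hoffman2013stochastic), while a stochastic estimate of $\nabla_{\eta_\theta} \eta_x^*(\eta_\theta, \phi)\, \nabla_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x^*(\eta_\theta, \phi))$ in the first expression or, alternatively, a stochastic estimate of $\nabla_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x^*(\eta_\theta, \phi))$ in the second expression is computed automatically as part of the backward pass for computing the gradients with respect to the other parameters, as described next. Thus we have an expression for the natural gradient with respect to the latent model’s parameters that is almost as simple as the one for conjugate models, differing only by a term involving the neural network likelihood function. Natural gradients are invariant to smooth invertible reparameterizations of the variational family (amari1998natural; amari2007methods) and provide effective second-order optimization updates (martens2015optimizing; martens2015perspectives).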
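The conjugate part of the natural gradient can be checked numerically. The sketch assumes a toy conjugate pair that is not one of the paper's models: $p(x \mid \theta) = \mathcal{N}(x; \theta, 1)$, so $t_\theta(\theta) = (\theta, -\theta^2/2)$ up to base-measure terms and $q(\theta)$ is Gaussian with natural parameters $\eta_\theta = (a, b)$. It differentiates only the $\eta_\theta$-dependent terms at a fixed $\mathbb{E}_{q^*(x)}[(t_x(x), 1)]$, so it verifies the first (SVI) term and ignores the recognition-network term.

```python
import numpy as np

def log_Z(eta):
    """log partition of q(theta) proportional to exp(a*theta - b*theta^2/2): mean a/b, precision b."""
    a, b = eta
    return a * a / (2 * b) + 0.5 * np.log(2 * np.pi / b)

def mean_params(eta):
    """grad log Z = E_q[t_theta(theta)] = (E[theta], E[-theta^2 / 2])."""
    a, b = eta
    return np.array([a / b, -a * a / (2 * b * b) - 1 / (2 * b)])

def fisher(eta):
    """F(eta) = Hessian of log Z."""
    a, b = eta
    return np.array([[1 / b, -a / b ** 2],
                     [-a / b ** 2, a * a / b ** 3 + 1 / (2 * b * b)]])

def objective(eta, eta0, stats):
    """Terms of L that depend on eta_theta: -KL(q(theta) || p(theta)) + <E_q[t_theta], stats>."""
    mu = mean_params(eta)
    return (eta0 - eta) @ mu - log_Z(eta0) + log_Z(eta) + mu @ stats

def numerical_grad(f, eta, eps=1e-6):
    """Central-difference gradient."""
    g = np.zeros_like(eta)
    for i in range(eta.size):
        step = np.zeros_like(eta)
        step[i] = eps
        g[i] = (f(eta + step) - f(eta - step)) / (2 * eps)
    return g

eta0 = np.array([0.0, 1.0])   # prior N(0, 1)
eta = np.array([0.7, 2.5])    # current q(theta)
stats = np.array([1.3, 1.0])  # assumed E_{q*(x)}[(t_x(x), 1)] for one observation
natural = np.linalg.solve(fisher(eta), numerical_grad(lambda e: objective(e, eta0, stats), eta))
svi_term = eta0 + stats - eta  # (0.6, -0.5)
```

Premultiplying the Euclidean gradient by $F(\eta_\theta)^{-1}$ cancels the Fisher factor that the chain rule introduces, which is why the natural gradient reduces to a simple difference of natural parameters.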
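The gradients of the objective $\mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi)$ with respect to the other variational parameters, namely $\nabla_{\eta_\gamma} \mathcal{L}_{\mathrm{SVAE}}$ and $\nabla_{\phi} \mathcal{L}_{\mathrm{SVAE}}$, can be computed using the reparameterization trick and standard automatic differentiation techniques. To isolate the terms that require the reparameterization trick, we rearrange the objective as

$$\mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) = \mathbb{E}_{q(\gamma)q^*(x)} \log p(y \mid x, \gamma) - \mathrm{KL}\big(q(\theta)q^*(x) \,\|\, p(\theta, x)\big) - \mathrm{KL}\big(q(\gamma) \,\|\, p(\gamma)\big).$$

The KL divergence terms are between members of the same tractable exponential families. An unbiased estimate of the first term can be computed by sampling $\hat{x} \sim q^*(x)$ and $\hat{\gamma} \sim q(\gamma)$ and computing $\nabla_{\eta_\gamma, \phi} \log p(y \mid \hat{x}, \hat{\gamma})$ with automatic differentiation.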
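A minimal sketch of the reparameterization trick behind that unbiased estimate, in the simplest assumed setting: a scalar Gaussian $q = \mathcal{N}(\mu, \sigma^2)$ and a test function $f(x) = x^2$ whose derivative is written by hand instead of obtained by automatic differentiation. Writing $x = \mu + \sigma\epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$ moves the parameters inside the expectation, so averaging $f'(x)$ and $f'(x)\,\epsilon$ over samples estimates the gradients with respect to $\mu$ and $\sigma$.

```python
import numpy as np

def reparam_grads(f_prime, mu, sigma, num_samples, rng):
    """Pathwise Monte Carlo estimates of d/dmu and d/dsigma of E_{x ~ N(mu, sigma^2)}[f(x)].

    With x = mu + sigma * eps and eps ~ N(0, 1), dx/dmu = 1 and dx/dsigma = eps,
    so the chain rule gives f'(x) and f'(x) * eps inside the expectation.
    """
    eps = rng.standard_normal(num_samples)
    x = mu + sigma * eps
    df = f_prime(x)
    return df.mean(), (df * eps).mean()

mu, sigma = 0.5, 1.5
g_mu, g_sigma = reparam_grads(lambda x: 2 * x, mu, sigma, 200_000, np.random.default_rng(0))
# For f(x) = x^2, E[f(x)] = mu^2 + sigma^2, so the exact gradients are 2*mu = 1.0 and 2*sigma = 3.0.
```

In the SVAE the sampled quantity is $\log p(y \mid \hat{x}, \hat{\gamma})$, and $\hat{x}$ depends on $\phi$ through $q^*(x)$, so the same pathwise derivative also carries gradient information back to the recognition network.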
5 Related work
In addition to the papers already referenced, there are several recent papers to which this work is related.
The two papers closest to this work are krishnan2015deep and archer2015black. In krishnan2015deep the authors consider combining variational autoencoders with continuous state-space models, emphasizing the relationship to linear dynamical systems (also called Kalman filter models). They primarily focus on nonlinear dynamics and an RNN-based variational family, as well as allowing control inputs. However, the approach does not extend to general graphical models or discrete latent variables. It also does not leverage natural gradients or exact inference subroutines.
In archer2015black the authors also consider the problem of variational inference in general continuous state space models but focus on using a structured Gaussian variational family without considering parameter learning. As with krishnan2015deep, this approach does not include discrete latent variables (or any latent variables other than the continuous states). However, the method they develop could be used with an SVAE to handle inference with nonlinear dynamics.
In addition, both gregor2015draw and chung2015recurrent extend the variational autoencoder framework to sequential models, though they focus on RNNs rather than probabilistic graphical models.
Finally, there is much related work on handling nonconjugate model terms in mean field variational inference. In khan2015kullback and khan2016faster the authors present a general scheme that is able to exploit conjugate exponential family structure while also handling arbitrary nonconjugate model factors, including the nonconjugate observation models we consider here. In particular, they propose using a proximal gradient framework and splitting the variational inference objective into a difficult term to be linearized (with respect to mean parameters) and a tractable concave term, so that the resulting proximal gradient update is easy to compute, just like in a fully conjugate model. In knowles2011ncvmp, the authors propose performing natural gradient descent with respect to natural parameters on each of the variational factors in turn, and they focus on approximating expectations of nonconjugate energy terms in the objective with model-specific lower-bounds (rather than estimating them with generic Monte Carlo). As in conjugate SVI (hoffman2013stochastic), they observe that, on conjugate factors and with an undamped update (i.e. a unit step size), the natural gradient update reduces to the standard conjugate mean field update.
In contrast to the approaches of khan2015kullback, khan2016faster, and knowles2011ncvmp, rather than linearizing intractable terms around the current iterate, in this work we handle intractable terms via recognition networks and amortized inference (and the remaining tractable objective terms are multi-concave in general, analogous to SVI (hoffman2013stochastic)). That is, we use parametric function approximators to learn to condition on evidence in a conjugate form. We expect these approaches to handling nonconjugate objective terms may be complementary, and the best choice may be situation-dependent. For models with local latent variables and datasets where minibatch-based updating is important, using inference networks to compute local variational parameters in a fixed-depth circuit (as in the VAE (kingma2013autoencoding; rezende2014stochastic)) or optimizing out the local variational factors using fast conjugate updates (as in conjugate SVI (hoffman2013stochastic)) can be advantageous because in both cases local variational parameters for the entire dataset need not be maintained across updates. The SVAE we propose here is a way to combine the inference network and conjugate SVI approaches.
We apply the SVAE to both synthetic and real data and demonstrate its ability to learn feature representations and latent structure. Code is available at github.com/mattjj/svae.
6.1 LDS SVAE for modeling synthetic data
Consider a sequence of 1D images representing a dot bouncing from one side of the image to the other, as shown at the top of Fig. 4. We use an LDS SVAE to find a low-dimensional latent state space representation along with a nonlinear image model. The model is able to represent the image accurately and to make long-term predictions with uncertainty. See supplementals for details.
This experiment also demonstrates the optimization advantages that can be provided by the natural gradient updates. In Fig. 4(a) we compare natural gradient updates with standard gradient updates at three different learning rates. The natural gradient algorithm not only learns much faster but also is less dependent on parameterization details: while the natural gradient update used an untuned step size of 0.1, the standard gradient dynamics at step sizes of both 0.1 and 0.05 caused some matrix parameters to be updated to indefinite values.
6.2 LDS SVAE for modeling video
We also apply an LDS SVAE to model depth video recordings of mouse behavior. We use the dataset from wiltschko2015mapping in which a mouse is recorded from above using a Microsoft Kinect. We used a subset consisting of 8 recordings, each of a distinct mouse, 20 minutes long at 30 frames per second, for a total of 288000 video frames downsampled to pixels.
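As a consistency check, the frame count quoted above follows directly from the recording setup:

$$8 \text{ recordings} \times 20\,\text{min} \times 60\,\tfrac{\text{s}}{\text{min}} \times 30\,\tfrac{\text{frames}}{\text{s}} = 288{,}000 \text{ frames}.$$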
We use MLP observation and recognition models with two hidden layers of 200 units each and a 10D latent space. Fig. 4(b) shows images corresponding to a regular grid on a random 2D subspace of the latent space, illustrating that the learned image manifold accurately captures smooth variation in the mouse’s body pose. Fig. 6 shows predictions from the model paired with real data.
6.3 SLDS SVAE for parsing behavior
Finally, because the LDS SVAE can accurately represent the depth video over short timescales, we apply the latent switching linear dynamical system (SLDS) model to discover the natural units of behavior. Fig. 7 and Fig. 8 in the appendix show some of the discrete states that arise from fitting an SLDS SVAE with 30 discrete states to the depth video data. The discrete states that emerge show a natural clustering of short-timescale patterns into behavioral units. See the supplementals for more.
Examples of behavior states inferred from depth video. Each frame sequence is padded on both sides, with a square in the lower-right of a frame depicting when the state is the most probable.
Structured variational autoencoders provide a general framework that combines some of the strengths of probabilistic graphical models and deep learning methods. In particular, they use graphical models both to give models rich latent representations and to enable fast variational inference with CRF-like structured approximating distributions. To complement these structured representations, SVAEs use neural networks to produce not only flexible nonlinear observation models but also fast recognition networks that map observations to conjugate graphical model potentials.
Appendix A Optimization
In this section we fix our notation for gradients and establish some basic definitions and results that we use in the sequel.
A.1 Gradient notation
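We follow the notation in bertsekas1999nonlinear. In particular, if $f: \mathbb{R}^n \to \mathbb{R}^m$ is a continuously differentiable function, we define the gradient matrix of $f$, denoted $\nabla f(x)$, to be the $n \times m$ matrix in which the $i$th column is the gradient $\nabla f_i(x)$ of $f_i$, the $i$th coordinate function of $f$, for $i = 1, 2, \ldots, m$. That is,

$$\nabla f(x) = \begin{bmatrix} \nabla f_1(x) & \cdots & \nabla f_m(x) \end{bmatrix}.$$

The transpose of $\nabla f$ is the Jacobian matrix of $f$, in which the $ij$th entry is the function $\partial f_i / \partial x_j$.

If $f: \mathbb{R}^n \to \mathbb{R}$ is continuously differentiable with continuously differentiable partial derivatives, then we define the Hessian matrix of $f$, denoted $\nabla^2 f$, to be the matrix in which the $ij$th entry is the function $\partial^2 f / \partial x_i \partial x_j$.

Finally, if $f: \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}$ is a function of $(x, y)$ with $x \in \mathbb{R}^n$ and $y \in \mathbb{R}^m$, we write

$$\nabla_x f(x, y) = \begin{pmatrix} \frac{\partial f(x, y)}{\partial x_1} \\ \vdots \\ \frac{\partial f(x, y)}{\partial x_n} \end{pmatrix}, \qquad \nabla_y f(x, y) = \begin{pmatrix} \frac{\partial f(x, y)}{\partial y_1} \\ \vdots \\ \frac{\partial f(x, y)}{\partial y_m} \end{pmatrix},$$
$$\nabla^2_{xx} f(x, y) = \left(\frac{\partial^2 f(x, y)}{\partial x_i \partial x_j}\right), \qquad \nabla^2_{xy} f(x, y) = \left(\frac{\partial^2 f(x, y)}{\partial x_i \partial y_j}\right), \qquad \nabla^2_{yy} f(x, y) = \left(\frac{\partial^2 f(x, y)}{\partial y_i \partial y_j}\right).$$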
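A short numerical illustration of this convention, using an arbitrary example map $f: \mathbb{R}^2 \to \mathbb{R}^4$ and a central-difference helper (`gradient_matrix` is a hypothetical name): the result has shape $n \times m$, column $i$ holds $\nabla f_i$, and its transpose matches the Jacobian.

```python
import numpy as np

def gradient_matrix(f, x, eps=1e-6):
    """Numerical n x m gradient matrix: column i is the gradient of the coordinate function f_i."""
    x = np.asarray(x, dtype=float)
    m = np.asarray(f(x)).size
    G = np.zeros((x.size, m))
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = eps
        G[j, :] = (np.asarray(f(x + e)) - np.asarray(f(x - e))) / (2 * eps)  # row j: d/dx_j of all f_i
    return G

f = lambda x: np.array([x[0] * x[1], np.sin(x[0]), x[1] ** 2, x[0] + x[1]])  # R^2 -> R^4
x = np.array([0.3, -1.2])
G = gradient_matrix(f, x)
# Jacobian of f, with entry (i, j) = df_i / dx_j.
jac = np.array([[x[1], x[0]], [np.cos(x[0]), 0.0], [0.0, 2 * x[1]], [1.0, 1.0]])
```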
A.2 Local and partial optimizers
In this section we state the definitions of local partial optimizer and necessary conditions for optimality that we use in the sequel.
Definition A.1 (Partial optimizer, local partial optimizer)
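Let $f: \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}$ be an objective function to be maximized. For a fixed $x \in \mathbb{R}^n$, we call a point $y^* \in \mathbb{R}^m$ an unconstrained partial optimizer of $f$ given $x$ if

$$f(x, y^*) \geq f(x, y) \qquad \forall y \in \mathbb{R}^m,$$

and we call $y^*$ an unconstrained local partial optimizer of $f$ given $x$ if there exists an $\epsilon > 0$ such that

$$f(x, y^*) \geq f(x, y) \qquad \forall y \text{ with } \|y - y^*\| < \epsilon,$$

where $\|\cdot\|$ is any vector norm.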
Proposition A.2 (Necessary conditions for optimality, Prop. 3.1.1 of bertsekas1999nonlinear)
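Let $f: \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}$ be continuously differentiable. For fixed $x \in \mathbb{R}^n$, if $y^* \in \mathbb{R}^m$ is an unconstrained local partial optimizer for $f$ given $x$, then

$$\nabla_y f(x, y^*) = 0.$$

If instead $x$ and $y$ are subject to the constraints $h(x, y) = 0$ for some continuously differentiable $h: \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}^p$ and $y^*$ is a constrained local partial optimizer for $f$ given $x$ with the regularity condition that $\nabla_y h(x, y^*)$ is full rank, then there exists a Lagrange multiplier $\lambda^* \in \mathbb{R}^p$ such that

$$\nabla_y f(x, y^*) + \nabla_y h(x, y^*)\, \lambda^* = 0,$$

and hence the cost gradient $\nabla_y f(x, y^*)$ is orthogonal to the first-order feasible variations in $y$ given by the null space of $\nabla_y h(x, y^*)^\mathsf{T}$.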
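To see the Lagrange condition concretely, consider a toy problem chosen for illustration: maximize $f(y) = -\|y - c\|^2$ subject to one linear constraint $h(y) = a^\mathsf{T} y - b = 0$, with no dependence on $x$. The maximizer is the projection of $c$ onto the hyperplane, and since $a \neq 0$ the regularity condition holds. The sketch verifies the multiplier condition and that the cost gradient is orthogonal to the null space of $\nabla_y h^\mathsf{T}$.

```python
import numpy as np

# Toy problem: maximize f(y) = -||y - c||^2 subject to h(y) = a^T y - b = 0 (p = 1 constraint).
c = np.array([1.0, 2.0, -0.5])
a = np.array([1.0, -1.0, 2.0])
b = 0.7

y_star = c - a * (a @ c - b) / (a @ a)        # Euclidean projection of c onto the hyperplane
grad_f = -2.0 * (y_star - c)                  # cost gradient at y*
grad_h = a[:, None]                           # gradient matrix of h, shape (m, p) = (3, 1)
lam = -2.0 * (a @ c - b) / (a @ a)            # multiplier with grad_f + grad_h lam = 0
residual = grad_f + grad_h @ np.array([lam])  # should vanish

# Orthonormal basis for the null space of grad_h^T (the first-order feasible variations).
_, _, Vt = np.linalg.svd(grad_h.T)
null_basis = Vt[1:].T                         # shape (3, 2)
```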
Note that the regularity condition on the constraints is not needed if the constraints are linear (bertsekas1999nonlinear, Prop. 3.3.7).
For a continuously differentiable function $f: \mathbb{R}^n \to \mathbb{R}$, we say $\bar{x}$ is a stationary point of $f$ if $\nabla f(\bar{x}) = 0$. For general unconstrained smooth optimization, the limit points of gradient-based algorithms are guaranteed only to be stationary points of the objective, not necessarily local optima. Block coordinate ascent methods, when available, provide slightly stronger guarantees: not only is every limit point a stationary point of the objective, in addition each coordinate block is a partial optimizer of the objective. Note that the objective functions we consider maximizing in the following are bounded above.
A.3 Partial optimization and the Implicit Function Theorem
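Let $f: \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}$ be a scalar-valued objective function of two unconstrained arguments $x \in \mathbb{R}^n$ and $y \in \mathbb{R}^m$, and let $y^*: \mathbb{R}^n \to \mathbb{R}^m$ be some function that assigns to each $x \in \mathbb{R}^n$ a value $y^*(x) \in \mathbb{R}^m$. Define the composite function $g: \mathbb{R}^n \to \mathbb{R}$ as

$$g(x) \triangleq f(x, y^*(x))$$

and using the chain rule write its gradient as

$$\nabla g(x) = \nabla_x f(x, y^*(x)) + \nabla y^*(x)\, \nabla_y f(x, y^*(x)). \tag{25}$$

One choice of the function $y^*(x)$ is to partially optimize $f$ for any fixed value of $x$. For example, assuming that $\arg\max_y f(x, y)$ is nonempty for every $x \in \mathbb{R}^n$, we could choose $y^*$ to satisfy $y^*(x) \in \arg\max_y f(x, y)$, so that $g(x) = \max_y f(x, y)$.¹ Similarly, if $y^*(x)$ is chosen so that $\nabla_y f(x, y^*(x)) = 0$, which is satisfied when $y^*(x)$ is an unconstrained local partial optimizer for $f$ given $x$, then the expression in Eq. (25) can be simplified as in the following proposition.

¹ For a discussion of differentiability issues when there is more than one optimizer, i.e. when $\arg\max_y f(x, y)$ has more than one element, see danskin1967theory, fiacco1984introduction, and bonnans2000perturbation. Here we only consider the sensitivity of local stationary points and assume differentiability almost everywhere.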
Proposition A.3 (Gradients of locally partially optimized objectives)
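Let $f: \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}$ be continuously differentiable, let $y^*$ be a local partial optimizer of $f$ given $x$ such that $y^*(x)$ is differentiable, and define $g(x) \triangleq f(x, y^*(x))$. Then

$$\nabla g(x) = \nabla_x f(x, y^*(x)).$$

Proof. If $y^*(x)$ is an unconstrained local partial optimizer of $f$ given $x$, then it satisfies $\nabla_y f(x, y^*) = 0$, and if $y^*$ is a regularly-constrained local partial optimizer, then the feasible variation $\nabla y^*(x)$ is orthogonal to the cost gradient $\nabla_y f(x, y^*)$. In both cases the second term in the expression for $\nabla g(x)$ in Eq. (25) is zero. ∎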
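A scalar sanity check of Proposition A.3, using an example objective chosen for illustration, $f(x, y) = -(y - \sin x)^2 + xy$, whose maximizer in $y$ is $y^*(x) = \sin x + x/2$. The finite-difference derivative of $g(x) = f(x, y^*(x))$ matches the partial derivative $\nabla_x f$ evaluated at $y^*(x)$, with no term involving $\nabla y^*(x)$.

```python
import numpy as np

def f(x, y):
    return -(y - np.sin(x)) ** 2 + x * y

def y_star(x):
    # Solving df/dy = -2 (y - sin x) + x = 0 gives the unique maximizer in y.
    return np.sin(x) + x / 2

def df_dx(x, y):
    # Partial derivative of f with respect to x, holding y fixed.
    return 2 * (y - np.sin(x)) * np.cos(x) + y

def g(x):
    return f(x, y_star(x))

x0, eps = 0.8, 1e-6
dg_numeric = (g(x0 + eps) - g(x0 - eps)) / (2 * eps)
dg_prop = df_dx(x0, y_star(x0))  # Prop. A.3: no term involving dy*/dx is needed
```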
In general, when $y^*(x)$ is not a stationary point of $f(x, \cdot)$, to evaluate the gradient $\nabla g(x)$ we need to evaluate $\nabla y^*(x)$ in Eq. (25). However, this term may be difficult to compute directly. The function $y^*(x)$ may arise implicitly from some system of equations of the form $h(x, y) = 0$ for some continuously differentiable function $h: \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}^m$. For example, the value of $y$ may be computed from $x$ and $h$ using a black-box iterative numerical algorithm. However, the Implicit Function Theorem provides another means to compute $\nabla y^*(x)$ using only the derivatives of $h$ and the value of $y^*(x)$.
Proposition A.4 (Implicit Function Theorem, Prop. A.25 of bertsekas1999nonlinear)
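Let $h: \mathbb{R}^{n+m} \to \mathbb{R}^m$ be a function and $\bar{x} \in \mathbb{R}^n$ and $\bar{y} \in \mathbb{R}^m$ be points such that

1. $h(\bar{x}, \bar{y}) = 0$;
2. $h$ is continuous and has a continuous nonsingular gradient matrix $\nabla_y h(x, y)$ in an open set containing $(\bar{x}, \bar{y})$.

Then there exist open sets $S_{\bar{x}} \subseteq \mathbb{R}^n$ and $S_{\bar{y}} \subseteq \mathbb{R}^m$ containing $\bar{x}$ and $\bar{y}$, respectively, and a continuous function $y^*: S_{\bar{x}} \to S_{\bar{y}}$ such that $\bar{y} = y^*(\bar{x})$ and $h(x, y^*(x)) = 0$ for all $x \in S_{\bar{x}}$. The function $y^*$ is unique in the sense that if $x \in S_{\bar{x}}$, $y \in S_{\bar{y}}$, and $h(x, y) = 0$, then $y = y^*(x)$. Furthermore, if for some $p > 0$, $h$ is $p$ times continuously differentiable, the same is true for $y^*$, and we have

$$\nabla y^*(x) = -\nabla_x h(x, y^*(x)) \big(\nabla_y h(x, y^*(x))\big)^{-1}, \qquad \forall x \in S_{\bar{x}}.$$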
As a special case, the equations $h(x, y) = 0$ may be the first-order stationary conditions of another unconstrained optimization problem. That is, the value of $y$ may be chosen by locally partially optimizing the value of $u(x, y)$ for a function $u: \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}$ with no constraints on $y$, leading to the following corollary.
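Applying Proposition A.4 with $h = \nabla_y u$ gives, wherever $\nabla^2_{yy} u(x, y^*(x))$ is nonsingular, $\nabla y^*(x) = -\nabla^2_{xy} u(x, y^*(x)) \big(\nabla^2_{yy} u(x, y^*(x))\big)^{-1}$. The sketch checks this on an example chosen for illustration, $u(x, y) = y^\mathsf{T} A x - \sum_j e^{y_j}$, whose stationary point is $y^*(x) = \log(Ax)$ (elementwise, for $Ax > 0$). Plain Newton iterations stand in for a black-box solver, and the implicit gradient is compared with finite differences taken through that solver.

```python
import numpy as np

A = np.array([[1.0, 0.5, 0.2],
              [0.3, 1.0, 0.4]])  # m = 2 outputs, n = 3 inputs

def solve_y_star(x, iters=50):
    """Newton iterations on h(x, y) = grad_y u(x, y) = A x - exp(y), treated as a black box."""
    y = np.zeros(A.shape[0])
    for _ in range(iters):
        y = y + (A @ x - np.exp(y)) / np.exp(y)  # Newton step; dh/dy = -diag(exp(y))
    return y

def implicit_grad(y):
    """n x m gradient matrix of y*(x): -grad2_xy u (grad2_yy u)^{-1}."""
    H_xy = A.T                     # entries d^2 u / dx_i dy_j
    H_yy = -np.diag(np.exp(y))     # Hessian of u in y
    return -H_xy @ np.linalg.inv(H_yy)

x0 = np.array([0.5, 1.0, 1.5])
G_ift = implicit_grad(solve_y_star(x0))
eps = 1e-6
G_fd = np.stack([(solve_y_star(x0 + eps * e) - solve_y_star(x0 - eps * e)) / (2 * eps)
                 for e in np.eye(3)])  # row i: derivative of y* along x_i
```

The implicit route needs only second derivatives of $u$ at the solution, so it does not require differentiating through the solver's iterations.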
Appendix B Exponential families
In this section we set up notation for exponential families and outline some basic results. Throughout this section we take all densities to be absolutely continuous with respect to the appropriate Lebesgue measure (when the underlying set $\mathcal{X}$ is Euclidean space) or counting measure (when $\mathcal{X}$ is discrete), and denote the Borel $\sigma$-algebra of a set $\mathcal{X}$ as $\mathcal{B}(\mathcal{X})$ (generated by Euclidean and discrete topologies, respectively). We assume measurability of all functions as necessary.
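Given a statistic function $t_x: \mathcal{X} \to \mathbb{R}^n$ and a base measure $\nu_{\mathcal{X}}$, we can define an exponential family of probability densities on $\mathcal{X}$ relative to $\nu_{\mathcal{X}}$ and indexed by natural parameter $\eta_x \in \mathbb{R}^n$ by

$$p(x \mid \eta_x) \propto \exp\{\langle \eta_x, t_x(x) \rangle\}, \qquad \forall \eta_x \in \mathbb{R}^n,$$

where $\langle \cdot, \cdot \rangle$ is the standard inner product on $\mathbb{R}^n$. We also define the partition function as

$$Z_x(\eta_x) \triangleq \int \exp\{\langle \eta_x, t_x(x) \rangle\}\, \nu_{\mathcal{X}}(dx)$$

and define $H \subseteq \mathbb{R}^n$ to be the set of all normalizable natural parameters,

$$H \triangleq \{\eta \in \mathbb{R}^n : Z_x(\eta) < \infty\}.$$

We can write the normalized probability density as

$$p(x \mid \eta_x) = \exp\{\langle \eta_x, t_x(x) \rangle - \log Z_x(\eta_x)\}.$$

We say that an exponential family is regular if $H$ is open, and minimal if there is no $\eta \in \mathbb{R}^n \setminus \{0\}$ such that $\langle \eta, t_x(x) \rangle$ is constant ($\nu_{\mathcal{X}}$-a.e.). We assume all families are regular and minimal.² Finally, when we parameterize the family with some other coordinates $\theta$, we write the natural parameter as a continuous function $\eta_x(\theta)$ and write the density as

$$p(x \mid \theta) = \exp\{\langle \eta_x(\theta), t_x(x) \rangle - \log Z_x(\eta_x(\theta))\}$$

and take $\Theta = \eta_x^{-1}(H)$ to be the open set of parameters that correspond to normalizable densities. We summarize this notation in the following definition.

² Families that are not minimal, like the density of the categorical distribution, can be treated by restricting all algebraic operations to the subspace spanned by the statistic, i.e. to the smallest $V \subseteq \mathbb{R}^n$ with $\operatorname{range} t_x \subseteq V$.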
Definition B.1 (Exponential family of densities)
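Given a measure space $(\mathcal{X}, \mathcal{B}(\mathcal{X}), \nu_{\mathcal{X}})$, a statistic function $t_x: \mathcal{X} \to \mathbb{R}^n$, and a natural parameter function $\eta_x: \Theta \to \mathbb{R}^n$, the corresponding exponential family of densities relative to $\nu_{\mathcal{X}}$ is

$$p(x \mid \theta) = \exp\{\langle \eta_x(\theta), t_x(x) \rangle - \log Z_x(\eta_x(\theta))\},$$

where

$$\log Z_x(\eta_x) \triangleq \log \int \exp\{\langle \eta_x, t_x(x) \rangle\}\, \nu_{\mathcal{X}}(dx)$$

is the log partition function.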
When we write exponential families of densities for different random variables, we change the subscripts on the statistic function, natural parameter function, and log partition function to correspond to the symbol used for the random variable. When the corresponding random variable is clear from context, we drop the subscripts to simplify notation.
The next proposition shows that the log partition function of an exponential family generates cumulants of the statistic.
Proposition B.2 (Gradients of $\log Z_x$ and expected statistics)
The gradient of the log partition function of an exponential family gives the expected sufficient statistic,

$$\nabla \log Z_x(\eta_x) = \mathbb{E}_{p(x \mid \eta_x)}[t_x(x)],$$

where the expectation is over the random variable $x$ with density $p(x \mid \eta_x)$. More generally, the moment generating function of $t_x(x)$ can be written

$$M_{t_x(x)}(s) \triangleq \mathbb{E}_{p(x \mid \eta_x)}\big[e^{\langle s, t_x(x) \rangle}\big] = e^{\log Z_x(\eta_x + s) - \log Z_x(\eta_x)},$$

and so derivatives of $\log Z_x$ give cumulants of $t_x(x)$, where the first cumulant is the mean and the second and third cumulants are the second and third central moments, respectively.
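As a concrete check of Proposition B.2, the sketch below uses the Poisson family as an assumed example: $t_x(x) = x$, base measure $1/x!$ on $\{0, 1, 2, \ldots\}$, and $\log Z_x(\eta) = e^\eta$, so the mean, variance, and third central moment all equal $\lambda = e^\eta$. It evaluates $\log Z_x$ by a truncated sum, takes finite differences in $\eta$, and compares them with moments computed directly from the probability mass function.

```python
import math
import numpy as np

SUPPORT = np.arange(200)                                     # truncation of {0, 1, 2, ...}
LOG_BASE = np.array([-math.lgamma(k + 1) for k in SUPPORT])  # log(1 / x!)

def log_partition(eta):
    """log Z(eta) = log sum_x exp(eta * x) / x!  (truncated sum)."""
    return np.logaddexp.reduce(eta * SUPPORT + LOG_BASE)

eta, h = math.log(3.0), 1e-2
f = log_partition
d1 = (f(eta + h) - f(eta - h)) / (2 * h)                     # first cumulant
d2 = (f(eta + h) - 2 * f(eta) + f(eta - h)) / h ** 2         # second cumulant
d3 = (f(eta + 2 * h) - 2 * f(eta + h) + 2 * f(eta - h) - f(eta - 2 * h)) / (2 * h ** 3)

pmf = np.exp(eta * SUPPORT + LOG_BASE - log_partition(eta))
mean = pmf @ SUPPORT
var = pmf @ (SUPPORT - mean) ** 2
third = pmf @ (SUPPORT - mean) ** 3
```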
Given an exponential family of densities on $\mathcal{X}$ as in Definition B.1, we can define a related exponential family of densities on $\Theta$ by defining a statistic function $t_\theta(\theta)$ in terms of the functions $\eta_x(\theta)$ and $\log Z_x(\eta_x(\theta))$.
Definition B.3 (Natural exponential family conjugate prior)
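Given the exponential family $p(x \mid \theta)$ of Definition B.1, define the statistic function $t_\theta: \Theta \to \mathbb{R}^{n+1}$ as the concatenation

$$t_\theta(\theta) \triangleq \big(\eta_x(\theta), -\log Z_x(\eta_x(\theta))\big),$$

where the first $n$ coordinates of $t_\theta(\theta)$ are given by $\eta_x(\theta)$ and the last coordinate is given by $-\log Z_x(\eta_x(\theta))$. We call the exponential family with statistic $t_\theta(\theta)$ the natural exponential family conjugate prior to the density $p(x \mid \theta)$ and write

$$p(\theta) = \exp\{\langle \eta_\theta, t_\theta(\theta) \rangle - \log Z_\theta(\eta_\theta)\},$$

where $\eta_\theta \in \mathbb{R}^{n+1}$ and the density is taken relative to some measure $\nu_\Theta$ on $(\Theta, \mathcal{B}(\Theta))$.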
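Notice that using $t_\theta(\theta)$ we can rewrite the original density $p(x \mid \theta)$ as

$$p(x \mid \theta) = \exp\{\langle \eta_x(\theta), t_x(x) \rangle - \log Z_x(\eta_x(\theta))\} = \exp\{\langle t_\theta(\theta), (t_x(x), 1) \rangle\}.$$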
This relationship is useful in Bayesian inference: when the exponential family $p(x \mid \theta)$ is a likelihood function and the family $p(\theta)$ is used as a prior, the pair enjoy a convenient conjugacy property, as summarized in the next proposition.
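Both the rewriting of $p(x \mid \theta)$ in terms of $t_\theta(\theta)$ and the resulting conjugacy can be checked on the Bernoulli family, used here purely as an assumed example. With $\eta_x(\theta) = \log\frac{\theta}{1 - \theta}$ and $\log Z_x(\eta) = \log(1 + e^\eta)$, the statistic is $t_\theta(\theta) = (\eta_x(\theta), \log(1 - \theta))$, the natural conjugate prior is a Beta density, and conditioning on $x_1, \ldots, x_N$ adds $\sum_n (t_x(x_n), 1)$ to its natural parameter. The update shown is the standard conjugate posterior; helper names are hypothetical.

```python
import numpy as np

def eta_x(theta):
    """Natural parameter of Bernoulli(theta): the logit."""
    return np.log(theta) - np.log1p(-theta)

def log_Z_x(eta):
    """Bernoulli log partition function log(1 + e^eta)."""
    return np.logaddexp(0.0, eta)

def t_theta(theta):
    """Conjugate-prior statistic t_theta(theta) = (eta_x(theta), -log Z_x(eta_x(theta)))."""
    e = eta_x(theta)
    return np.array([e, -log_Z_x(e)])

# Identity check: exp<t_theta(theta), (t_x(x), 1)> equals the Bernoulli pmf, with t_x(x) = x.
theta = 0.3
pmf_direct = [theta ** x * (1 - theta) ** (1 - x) for x in (0, 1)]
pmf_conj = [float(np.exp(t_theta(theta) @ np.array([x, 1.0]))) for x in (0, 1)]

def beta_to_natural(alpha, beta):
    """Beta(alpha, beta) is proportional to exp<eta_theta, t_theta(theta)> with this eta_theta."""
    return np.array([alpha - 1.0, alpha + beta - 2.0])

def natural_to_beta(eta):
    a, b = eta
    return float(a + 1.0), float(b - a + 1.0)

def conjugate_update(eta_prior, xs):
    """Posterior natural parameter: eta_theta + sum_n (t_x(x_n), 1)."""
    xs = np.asarray(xs, dtype=float)
    return eta_prior + np.array([xs.sum(), float(len(xs))])

posterior = natural_to_beta(conjugate_update(beta_to_natural(2.0, 3.0), [1, 0, 1, 1, 0, 1]))
print(posterior)  # (6.0, 5.0), i.e. Beta(2 + 4 ones, 3 + 2 zeros)
```

Working in natural parameters turns Bayesian updating into vector addition, which is the property the SVI and SVAE updates exploit.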