Deep universal probabilistic programming with Python and PyTorch
Pyro is a probabilistic programming language built on Python as a platform for developing advanced probabilistic models in AI research. To scale to large datasets and high-dimensional models, Pyro uses stochastic variational inference algorithms and probability distributions built on top of PyTorch, a modern GPU-accelerated deep learning framework. To accommodate complex or model-specific algorithmic behavior, Pyro leverages Poutine, a library of composable building blocks for modifying the behavior of probabilistic programs.
In recent years, richly structured probabilistic models have demonstrated promising results on a number of fundamental problems in AI (Ghahramani (2015)). However, most such models are still implemented from scratch as one-off systems, slowing their development and limiting their scope and extensibility. Probabilistic programming languages (PPLs) promise to reduce this burden, but in practice more advanced models often require high-performance inference engines tailored to a specific application. We identify design principles that enable a PPL to scale to advanced research applications while retaining flexibility, and we argue that these principles are fully realized together in the Pyro PPL.
First, a PPL suitable for developing state-of-the-art AI research models should be expressive: it should be able to concisely describe models that have data-dependent internal control flow or latent variables whose existence depends on the values of other latent variables, or models that can only be defined in closed form as unnormalized joint distributions.
For a PPL to be practical, it must be scalable: its approximate inference algorithms must be able to seamlessly handle the large datasets and non-conjugate, high-dimensional models common in AI research, and should exploit compiler acceleration when possible.
A PPL targeting research models should be flexible: in addition to scalability, many advanced models require inference algorithms with complex, model-specific behavior. A PPL should enable researchers to quickly and easily implement such behavior and should enforce a separation of concerns between model, inference, and runtime implementations.
Finally, a PPL targeting researchers as users should strive to be minimal: in order to minimize cognitive overhead, it should share most of its syntax and semantics with existing languages and systems and work well with other tools such as libraries for visualization.
As is clear from Table 2, these four principles are often in conflict, with one being achieved at the expense of others. For example, an overly flexible design may be very difficult to implement efficiently and scalably, especially while simultaneously integrating a new language with existing tools. Similarly, enabling development of custom inference algorithms may be difficult without limiting model expressivity. In this section, we describe the design choices we made in Pyro to balance between all four objectives.
Pyro is embedded in Python, and Pyro programs are written as Python functions, or callables, with just two extra language primitives (whose behavior is overridden by inference algorithms): pyro.sample for annotating calls to functions with internal randomness, and pyro.param for registering learnable parameters with inference algorithms that can change them. Pyro models may contain arbitrary Python code and interact with it in arbitrary ways, including expressing unnormalized models through the obs keyword argument to pyro.sample. Pyro’s language primitives may be used with all of Python’s control flow constructs, including recursion, loops, and conditionals. The existence of random variables in a particular execution may thus depend on any Python control flow construct.
Pyro implements several generic probabilistic inference algorithms, including the No-U-Turn Sampler (Hoffman and Gelman (2014)), a variant of Hamiltonian Monte Carlo. However, the primary inference algorithm is gradient-based stochastic variational inference (SVI) (Kingma and Welling (2014)), which uses stochastic gradient descent to optimize Monte Carlo estimates of a divergence measure between approximate and true posterior distributions. Pyro scales to complex, high-dimensional models thanks to GPU-accelerated tensor math and reverse-mode automatic differentiation via PyTorch, and it scales to large datasets thanks to stochastic gradient estimates computed over mini-batches of data in SVI.
Some inference algorithms in Pyro, such as SVI and importance sampling, can use arbitrary Pyro programs (called guides, following webPPL) as approximate posteriors or proposal distributions. A guide for a given model must take the same input arguments as the model and contain a corresponding sample statement for every unconstrained sample statement in the model, but is otherwise unrestricted. Users are then free to express complex hypotheses about the posterior distribution, e.g. its conditional independence structure. Unlike in webPPL and Anglican, guides in Pyro may not depend on values inside the model.
Finally, to achieve flexibility and separation of concerns, Pyro is built on Poutine, a library of effect handlers (Kammar et al. (2013)) that implement individual control and book-keeping operations used for inspecting and modifying the behavior of Pyro programs, separating inference algorithm implementations from language details.
Pyro’s source code is freely available under an MIT license and is developed by the authors and a community of open-source contributors at https://github.com/uber/pyro. Documentation, examples, and a discussion forum are hosted online at https://pyro.ai. A comprehensive test suite is run automatically by a continuous integration service before code is merged into the main codebase to maintain a high level of project quality and usability.
We also found that while PyTorch was an invaluable substrate for tensor operations and automatic differentiation, it was lacking a high-performance library of probability distributions. As a result, several of the authors made substantial open-source contributions upstream to PyTorch Distributions (http://pytorch.org/docs/stable/distributions.html), a new PyTorch core library inspired by TensorFlow Distributions (Dillon et al. (2017)).
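For readers unfamiliar with PyTorch Distributions, the following sketch shows the kind of functionality it provides: batched distribution objects, reparameterized sampling, and log-density evaluation that all participate in automatic differentiation. The shapes and the toy objective are arbitrary choices for illustration.

```python
import torch
from torch.distributions import Normal

# A batch of three independent Normal distributions with learnable parameters.
loc = torch.zeros(3, requires_grad=True)
log_scale = torch.zeros(3, requires_grad=True)
q = Normal(loc, log_scale.exp())

# rsample draws reparameterized samples, so gradients flow back to the parameters.
z = q.rsample((1000,))   # shape (1000, 3)
log_q = q.log_prob(z)    # elementwise log-density, shape (1000, 3)

# A toy Monte Carlo objective: an estimate of E[z^2] under q.
objective = (z ** 2).mean()
objective.backward()     # populates loc.grad and log_scale.grad
```

The gradients stored in loc.grad and log_scale.grad are Monte Carlo estimates, which is the mechanism gradient-based variational inference relies on.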
|System|Expressivity|Scalability|Flexibility|Minimality|
|Stan|Static control flow|Some, CPU|Automated|None|
|Edward|Static control flow|Yes, CPU/GPU|Yes|TensorFlow|
|Pyro|Yes|Yes, CPU/GPU|Yes|Python, PyTorch|
Probabilistic programming and approximate inference are areas of active research, so there are many existing probabilistic programming languages and systems. We briefly mention several that were especially influential in Pyro’s development, regretfully omitting (due to space limitations) many systems for which simple direct comparisons are more difficult. We also emphasize that, as in conventional programming language development, our design decisions are not universally applicable or desirable for probabilistic programming, and that other systems purposefully make different tradeoffs to achieve different goals.
To demonstrate that Pyro meets our design goals, we implemented several state-of-the-art models (available at http://pyro.ai/examples/). Here we focus on the variational autoencoder (VAE; Kingma and Welling (2014)) and the Deep Markov Model (DMM; Krishnan et al. (2017)), a non-linear state space model that has been used for several applications including audio generation and causal inference. The VAE is a standard example in deep probabilistic modeling, while the DMM has several characteristics that make it ideal as a point of comparison: it is a high-dimensional, non-conjugate model designed to be fit to large datasets; the number of latent variables in a sequence depends on the input data; and it uses a hand-designed approximate posterior.
The models and inference procedures derived by Pyro replicate the original papers almost exactly, except that we use Monte Carlo estimates rather than exact analytic expressions for KL divergence terms. For the VAE, we use the MNIST dataset and 2-hidden-layer MLP encoder and decoder networks with varying hidden layer size and latent code size. For the DMM, we train on the same dataset of digitized music as the original paper (the JSB chorales dataset from http://www-etud.iro.umontreal.ca/~boulanni/icml2012).
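The difference between an analytic KL term and its Monte Carlo estimate can be sketched with PyTorch Distributions. The two Normal distributions, the sample count, and the helper name monte_carlo_kl below are illustrative assumptions and are unrelated to the models in the experiments.

```python
import torch
from torch.distributions import Normal, kl_divergence


def monte_carlo_kl(q, p, num_samples=10_000):
    """Estimate KL(q || p) as the sample mean of log q(z) - log p(z), z ~ q."""
    z = q.rsample((num_samples,))
    return (q.log_prob(z) - p.log_prob(z)).mean(0)


torch.manual_seed(0)
q = Normal(torch.tensor(1.0), torch.tensor(0.5))  # approximate posterior
p = Normal(torch.tensor(0.0), torch.tensor(1.0))  # prior

kl_exact = kl_divergence(q, p)  # closed form, available for this pair
kl_mc = monte_carlo_kl(q, p)    # sampled estimate, works for any q and p
```

The sampled estimate is unbiased but noisy; its advantage is that it needs only sampling and log-density evaluation, so it still applies when no closed form exists.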
To demonstrate that Pyro’s abstractions do not reduce its scalability by introducing too much overhead, we compared our VAE implementation with an idiomatic PyTorch implementation (the PyTorch example at https://github.com/pytorch/examples/tree/master/vae). After verifying that they converge to the same test ELBO, we compared the wall-clock time taken to compute one gradient update, averaged over 10 epochs of GPU-accelerated mini-batch stochastic gradient variational inference (batch size 128) on a single NVIDIA GTX 1080Ti. Figure 4 shows that the relative performance gap between the Pyro and PyTorch versions is moderate, and, more importantly, that it shrinks as the total time spent performing tensor operations increases.
We used the DMM to evaluate Pyro’s flexibility and expressiveness. As Figure 4 shows, we were able to quickly and concisely replicate the exact DMM model and inference configuration, and to reproduce the quantitative results reported in the paper after 5000 training epochs. Furthermore, thanks to Pyro’s modular design, we were also able to build DMM variants with more expressive approximate posteriors via inverse autoregressive flows (IAFs) (Kingma et al. (2016)), improving the results with a few lines of code at negligible computational cost.
We would like to acknowledge Du Phan, Adam Scibior, Dustin Tran, Adam Paszke, Soumith Chintala, Robert Hawkins, Andreas Stuhlmueller, our colleagues in Uber AI Labs, and the Pyro and PyTorch open-source communities for helpful contributions and feedback.