1. Introduction
In machine learning, gradient descent is a Markov chain over the parameters of the model being optimized, often augmented with auxiliary quantities for methods such as Adam (Kingma & Ba, 2015). In probabilistic inference, Markov chain Monte Carlo (MCMC) is a Markov chain over the values of random variables, also often augmented with auxiliary variables (e.g., Neal, 2011). Such algorithms are pervasive in the practice of machine learning, probabilistic and otherwise. As such, they have received a great deal of well-deserved attention from library and API developers (e.g., Salvatier et al., 2016; Bingham et al., 2018; Lao et al., 2020).

We present FunMC, a Markov chain library designed specifically to enable and accelerate research into new constant-memory algorithms for machine learning and probabilistic inference. The design of FunMC follows these principles:

All API elements are completely stateless, which simplifies reasoning, composition, and interoperation with function-oriented platforms like JAX (Bradbury et al., 2018) and subsystems like automatic differentiation (through the Markov chain).

Pervasive support for returning and propagating side information.

Composition over configuration. All API elements aim to do one thing well, eschewing long argument or flag lists where possible.

A unified API for MCMC, optimization, and running statistics, for construction of hybrid algorithms.
We have implemented FunMC as a multi-backend library of Python functions, executable with TensorFlow (Abadi et al., 2015) and JAX. It should be relatively easy to port to other Python machine learning frameworks. The code is available at https://github.com/tensorflow/probability/tree/master/discussion/fun_mcmc.

The contributions of this work are:

We believe FunMC is the first framework for Markovian algorithms whose components are small and simple enough to be reused often, while also composing smoothly enough to build sophisticated algorithms from them.

We support this claim by demonstrating how to use FunMC for MCMC sampling using canned (Section 3) or custom (Section 4) methods, including thinning (Section 5); how to reparameterize the target (Section 6); how to do optimization (Section 7), including adapting parameters like step sizes (Figure 4); and how to compute streaming statistics (Section 8); all as independent composable components.
2. Related work
The FunMC library was designed in response to a number of preexisting MCMC and optimization frameworks. Pyro (Bingham et al., 2018) and PyMC3 (Salvatier et al., 2016) separate the choice of MCMC transition kernel from the outer sampling loop. The optimization packages of PyTorch (Steiner et al., 2019) and TensorFlow make a similar decomposition, with the loss on one side and the optimization algorithms or outer training loop on the other. All of these frameworks define a large number of standard MCMC and optimization transition kernels, with arguments to specify the hyperparameters and how to adapt them. In Pyro and PyMC3, the choice of adaptation algorithms is limited, and experimenting with them would require writing new transition kernels with little or no reuse of existing code. PyTorch and TensorFlow support flexible learning rate adaptation, but other hyperparameters do not receive the same support.
The TensorFlow Probability MCMC (tfp.mcmc) library (Lao et al., 2020) makes the same factorization into outer sampling loop and transition kernel, but uses a flexible transition kernel DSL in which hyperparameter adaptation is effected by nesting transition kernels. For example, to adapt the step size for Hamiltonian Monte Carlo (Neal, 2011), one can wrap a step size adaptation kernel around an HMC kernel. A shortcoming of this approach is that it requires complex message-passing logic between the kernels, forces a DAG structure on the compound transition kernel computation, and requires learning a new DSL. FunMC does not distinguish between outer loops and transition kernels, and uses a more decoupled approach to hyperparameter adaptation, described below.

Modern optimization packages and tfp.mcmc support defining losses and target densities as simple callables, using automatic differentiation to compute any necessary gradients. These callables return the quantity of interest (loss or log density), but offer little or no support for returning side information (e.g., intermediate neural network layer activations). PyMC3 has a concept of deterministic variables that provides such a side channel, but it requires a DSL for defining probabilistic models. FunMC retains the flexibility of accepting a simple callable, and makes a small modification to its calling convention to support returning side information, in the spirit of PyMC3's deterministic variables.
Pyro and PyMC3 implement reparameterization by taking transport maps as arguments to their outer sampling loops. TensorFlow Probability implements it via its transition kernel DSL, where a reparameterization kernel alters the space on which the wrapped kernels operate. All of these approaches limit the opportunity to adapt the parameters of the transport maps as part of the sampling procedure. FunMC takes reparameterization completely out of the transition kernels, using a function transformation instead.
Accelerators based on SIMD operations are at the forefront of high-performance computing in machine learning today. The tfp.mcmc library pioneered the use of this capability to run multiple independent chains. Relatedly, NumPyro (Phan et al., 2019) achieved a similar capability using function transformations. FunMC takes a hybrid approach, mixing explicit and automatic batching to run independent operations in parallel.
3. A running example: generating samples using MCMC
As a running example of the FunMC API, consider using Hamiltonian Monte Carlo (HMC) to sample from the posterior of a simple Bayesian logistic regression model (Figure 1). We first define the target density, shown in Figure 1. It is a function from the value of the random variable to two return values. The first is the unnormalized log density evaluated at that value. The second is arbitrary, and can be used to return auxiliary information for debugging and other purposes. In this example we return the logits, which are a function of the random variable.
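To make this two-output calling convention concrete, here is a minimal pure-Python sketch (standard library only, no FunMC) of a logistic regression target density with the same shape. The helper name make_target_log_prob_fn and the Gaussian prior with a prior_scale parameter are our own assumptions for illustration; this is not the code of Figure 1.

```python
import math


def make_target_log_prob_fn(features, labels, prior_scale=1.0):
    """Builds a hypothetical Bayesian logistic regression target density.

    The returned function maps a list of weights to (log_prob, logits),
    i.e. the value of interest plus side information.
    """

    def target_log_prob_fn(weights):
        # Independent Gaussian prior on the weights, up to an additive constant.
        log_prior = -0.5 * sum((w / prior_scale) ** 2 for w in weights)
        # One logit (linear predictor) per data point.
        logits = [sum(w * x for w, x in zip(weights, row)) for row in features]
        # Bernoulli log-likelihood y * l - log(1 + exp(l)), computed stably.
        log_lik = 0.0
        for logit, y in zip(logits, labels):
            softplus = max(logit, 0.0) + math.log1p(math.exp(-abs(logit)))
            log_lik += y * logit - softplus
        return log_prior + log_lik, logits

    return target_log_prob_fn


# Usage: two data points with two features each.
target_log_prob_fn = make_target_log_prob_fn([[1.0, 0.0], [0.0, 1.0]], [1, 0])
log_prob, logits = target_log_prob_fn([0.0, 0.0])
```

A sampler only needs the first output; the logits ride along as side information that a caller may inspect or discard.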
The HMC transition kernel, also defined in Figure 1, conforms to the concept of a Markov transition kernel by taking and returning the state of the HMC sampler, but atypically it also has a second output. This second output is used for returning auxiliary information that is not part of the state of the Markov chain. The kernel is a wrapper around the fun_mc.hamiltonian_monte_carlo transition kernel provided by FunMC. The wrapper's purpose is to partially apply fun_mc.hamiltonian_monte_carlo and to manipulate the side information.
Finally, we iterate our transition kernel using the fun_mc.trace transition kernel (Figure 1). This kernel iterates the function passed as its second argument and stacks the corresponding auxiliary outputs. We defined kernel to return selected quantities of interest as well as whether the state proposed by HMC was accepted; the latter is available from the auxiliary return value of fun_mc.hamiltonian_monte_carlo. The entire history of this side information is thus available for inspection after the HMC run completes.
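The behavior of trace described above can be sketched in a few lines of plain Python. This is a minimal sketch assuming only the (state, side information) kernel convention; it is not FunMC's implementation, which stacks array leaves rather than building a Python list, and counter_kernel is a hypothetical toy kernel.

```python
def trace(state, fn, num_steps):
    """Iterate fn: state -> (state, extra) and collect every extra output."""
    history = []
    for _ in range(num_steps):
        state, extra = fn(state)
        history.append(extra)
    return state, history


def counter_kernel(state):
    # Hypothetical toy kernel: the chain state is an int. The side output
    # reports the new state and a derived quantity that is not part of it.
    new_state = state + 1
    return new_state, {"value": new_state, "even": new_state % 2 == 0}


final_state, history = trace(0, counter_kernel, num_steps=4)
```

Because the kernel itself never stores anything, the same kernel can be traced, nested, or transformed without hidden state getting in the way.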
4. Custom MCMC
FunMC provides the pre-canned fun_mc.hamiltonian_monte_carlo kernel for convenience. Following the philosophy of composition over configuration, however, that kernel simply consists (Appendix A, Figure 6) of sampling the auxiliary momentum, running the integrator, and then performing an accept-reject step, all of which are public functions of FunMC. This makes it straightforward for users to write their own variant if necessary, with maximal code reuse.
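The three-part decomposition (momentum sampling, integration, accept-reject) can be illustrated with a self-contained pure-Python sketch. The function names sample_momentum, leapfrog, metropolis_accept, and hmc_step are hypothetical stand-ins chosen for this illustration, not FunMC's public functions; the point is only that each stage is a small stateless function and HMC is their composition.

```python
import math
import random


def sample_momentum(rng, dim):
    # Auxiliary momentum drawn from a standard normal.
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]


def leapfrog(x, p, grad_log_prob_fn, step_size, num_steps):
    # Leapfrog integrator for H(x, p) = -log_prob(x) + |p|^2 / 2.
    g = grad_log_prob_fn(x)
    for _ in range(num_steps):
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]
        x = [xi + step_size * pi for xi, pi in zip(x, p)]
        g = grad_log_prob_fn(x)
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]
    return x, p


def metropolis_accept(rng, log_accept_ratio):
    # Accept with probability min(1, exp(log_accept_ratio)); 1 - random() is in (0, 1].
    return math.log(1.0 - rng.random()) < log_accept_ratio


def hmc_step(rng, x, log_prob_fn, grad_log_prob_fn, step_size, num_steps):
    """Composes the three pieces into one HMC transition: x -> (x, side info)."""

    def energy(xx, pp):
        return -log_prob_fn(xx) + 0.5 * sum(pi * pi for pi in pp)

    p0 = sample_momentum(rng, len(x))
    x1, p1 = leapfrog(x, p0, grad_log_prob_fn, step_size, num_steps)
    log_accept_ratio = energy(x, p0) - energy(x1, p1)
    accepted = metropolis_accept(rng, log_accept_ratio)
    new_x = x1 if accepted else x
    return new_x, {"accepted": accepted, "log_accept_ratio": log_accept_ratio}


# Usage: sample a one-dimensional standard normal.
rng = random.Random(0)
log_prob_fn = lambda x: -0.5 * sum(xi * xi for xi in x)
grad_log_prob_fn = lambda x: [-xi for xi in x]
x, samples, accepts = [0.0], [], []
for _ in range(2000):
    x, extra = hmc_step(rng, x, log_prob_fn, grad_log_prob_fn, step_size=0.2, num_steps=8)
    samples.append(x[0])
    accepts.append(extra["accepted"])
```

A variant (for example, a different momentum distribution or integrator) would replace one of the three functions and reuse the other two unchanged.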
5. fun_mc.trace transition kernel
The primary use of the fun_mc.trace operator is as shown in Figure 1: computing a (top-level) trace of Markov chain states. However, the type of fun_mc.trace intentionally admits partial application to produce a transition kernel. This can be used, for example, for thinning, as in Figure 2. For this purpose, fun_mc.trace accepts a trace_mask argument, which is prefix-tree-zipped with the auxiliary return. The boolean leaves of the mask signify whether the corresponding subtrees of the auxiliary return are traced. A scalar False matches the entire tree, causing it to be propagated without history, which is what we want for thinning.
The concept of the kernel returning auxiliary information shows its value here, by separating the true state of the Markov chain from information that is derived from it but not propagated. Without framework support for side information, this effect would require either statefully mutated side variables, or augmenting the Markov chain state with additional quantities that are simply discarded before each transition. Such a scheme would conceal the Markovian structure and force the user to come up with an initial state for these auxiliary quantities.
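Thinning by partial application can be sketched in plain Python. In this minimal sketch, trace_mask is assumed to be a scalar boolean only (FunMC's version also accepts a tree of booleans), and with trace_mask=False the sketch returns just the last side output, mirroring "propagated without history". The kernels are hypothetical toys.

```python
def trace(state, fn, num_steps, trace_mask=True):
    """Iterate fn: state -> (state, extra).

    If trace_mask is True, return the full history of extra outputs;
    if False, keep no history and return only the last extra output.
    """
    history = []
    extra = None
    for _ in range(num_steps):
        state, extra = fn(state)
        if trace_mask:
            history.append(extra)
    return state, (history if trace_mask else extra)


def counter_kernel(state):
    # Hypothetical toy kernel whose side output is the new state.
    new_state = state + 1
    return new_state, new_state


def thinned_kernel(state):
    # A partially applied trace is itself a kernel: run 3 inner steps, no history.
    return trace(state, counter_kernel, num_steps=3, trace_mask=False)


final_state, thinned = trace(0, thinned_kernel, num_steps=4)
```

The outer trace records only every third state, yet no kernel had to be extended with a thinning flag.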
6. Reparameterization of potential functions
HMC is well known to be sensitive to the geometry of the target density. One way to ameliorate this issue is reparameterization (Papaspiliopoulos et al., 2007; Parno & Marzouk, 2014). In FunMC, this is effected via function composition, taking care to remap arguments appropriately using the fun_mc.reparam_potential_fn utility function.
That function accepts a density function (operating in some “original” space), a diffeomorphism from the reparameterized space to the original space, and a point in the original space. It returns the corresponding density function in the reparameterized space and the corresponding point in the reparameterized space. The latter can be used to initialize an inference algorithm that then operates in the reparameterized space.
While reparameterization can be advantageous for running the Markov chain, the user may be interested in inspecting the chain's states in the original parameterization. FunMC's pervasive support for side returns makes this easy to arrange: the reparameterized density function returns the point in the original space as side information. The fun_mc.hamiltonian_monte_carlo kernel propagates it, and the kernel in Figure 3 extracts it and exposes only the original parameterization to tracing by fun_mc.trace.
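The change of variables behind this utility can be sketched without FunMC. In this minimal sketch the inverse and the Jacobian log-determinant are passed in explicitly, whereas FunMC obtains them from its backend; the name reparam_potential_sketch and the exponential example are hypothetical. The reparameterized density is log p(f(z)) + log|det df/dz|, and the original-space point f(z) is returned as side information.

```python
import math


def reparam_potential_sketch(log_prob_fn, forward, inverse, log_det_jacobian, init_x):
    """Hypothetical stand-in for a reparameterization utility.

    forward:          z -> x, from the reparameterized space to the original one.
    inverse:          x -> z.
    log_det_jacobian: z -> log |d forward(z) / dz|.
    """

    def reparam_log_prob_fn(z):
        x = forward(z)
        log_prob, extra = log_prob_fn(x)
        # Change of variables, returning the original-space point on the side.
        return log_prob + log_det_jacobian(z), {"x": x, "extra": extra}

    return reparam_log_prob_fn, inverse(init_x)


def exponential_log_prob(x):
    # Exponential(1) log density for x > 0, with no side information.
    return -x, ()


# Reparameterize the positive variable x as x = exp(z), so log|dx/dz| = z.
reparam_fn, z0 = reparam_potential_sketch(
    exponential_log_prob, math.exp, math.log, lambda z: z, init_x=1.0)
log_prob, side = reparam_fn(z0)
```

A sampler can then run on the unconstrained z while a wrapping kernel exposes side["x"] for tracing.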
The diffeomorphism must be an invertible function with a tractable Jacobian log-determinant. FunMC relies on its backend to compute the inverse and the log-determinant of the Jacobian. One practical way to provide these is to code the diffeomorphism using the tfp.bijectors library (Dillon et al., 2017), whose inversion and Jacobian computation API FunMC understands. The automatic inversion mechanism being developed for JAX (Vikram et al., 2020) looks like a promising future alternative.
In summary, to reparameterize the target in our running example, we replace the kernel definition in Figure 1 with the one in Figure 3. Since fun_mc.reparam_potential_fn is just a function, the user can call it inside their kernel, using a different diffeomorphism at each point in the chain. This supports adapting or inferring the diffeomorphism's parameters.
7. Optimization
Optimization can often be cast as a Markov chain, so FunMC provides a number of transition kernels that implement common optimization algorithms such as gradient descent and Adam. Besides their standard uses, these kernels can be reused for hyperparameter adaptation in MCMC. For example, it is conventional to tune the step size parameter of HMC to hit an acceptance rate between 0.6 and 0.8 (Betancourt et al., 2014). Given the acceptance rate at each step, the difference between the target and observed acceptance rates can be used as the gradient of a surrogate loss function of the step size. The step size is then adapted during the run using some gradient-based optimization scheme such as Nesterov dual averaging (Hoffman & Gelman, 2011), a choice usually hard-coded into an MCMC library.
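The surrogate-gradient idea can be sketched with plain gradient descent in place of dual averaging. In this minimal sketch we assume a target acceptance of 0.75 (inside the conventional range), optimize the log of the step size so that it stays positive, and use the hypothetical helper names adapt_log_step_size and accept_prob_from_log_ratio. When the sampler accepts too often, the surrogate gradient (target minus observed) is negative, so a descent step increases the step size.

```python
import math


def adapt_log_step_size(log_step_size, accept_prob, target_accept=0.75, learning_rate=0.05):
    """One gradient descent step on the log step size.

    The surrogate gradient is (target_accept - accept_prob): when the sampler
    accepts too often the gradient is negative, so the step size grows.
    """
    surrogate_grad = target_accept - accept_prob
    return log_step_size - learning_rate * surrogate_grad


def accept_prob_from_log_ratio(log_accept_ratio):
    # Metropolis-Hastings acceptance probability min(1, exp(log_accept_ratio)).
    return math.exp(min(0.0, log_accept_ratio))


log_step_size = 0.0
for _ in range(10):
    # Pretend the sampler keeps reporting a high acceptance probability.
    log_step_size = adapt_log_step_size(log_step_size, accept_prob_from_log_ratio(-0.05))
step_size = math.exp(log_step_size)
```

Because the update is a stateless function of the current value and the latest acceptance statistic, it can sit inside the same kernel as the HMC step rather than in a wrapping adaptation kernel.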
The composability of FunMC's components lets us plug in other methods instead. For example, to use Adam, we just replace the kernel definition in Figure 1 with the one in Figure 4. Since Adam expects a differentiable loss function rather than a gradient, Figure 4 needs a little jiu-jitsu in the form of fun_mc.make_surrogate_loss_fn.
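The surrogate-loss trick can be sketched as follows. We do not reproduce the signature of fun_mc.make_surrogate_loss_fn; instead, the hypothetical make_surrogate_loss_sketch builds a loss that is linear in the parameter, so its derivative is exactly the supplied gradient (in JAX one would typically write param * jax.lax.stop_gradient(grad) to the same effect). A central finite difference stands in for automatic differentiation, and adam_step is our own stateless Adam transition, written as state -> (state, side information).

```python
import math


def make_surrogate_loss_sketch(grad_value):
    """Hypothetical: a loss whose derivative w.r.t. param equals grad_value."""

    def loss_fn(param):
        return grad_value * param

    return loss_fn


def numerical_grad(loss_fn, param, eps=1e-6):
    # Central finite difference, standing in for automatic differentiation.
    return (loss_fn(param + eps) - loss_fn(param - eps)) / (2 * eps)


def adam_step(adam_state, loss_fn, learning_rate=0.01, b1=0.9, b2=0.999, eps=1e-8):
    """Stateless Adam transition on a scalar: (param, m, v, t) -> new state."""
    param, m, v, t = adam_state
    g = numerical_grad(loss_fn, param)
    t += 1
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    param = param - learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return (param, m, v, t), {"grad": g}


# Adapt a log step size when the sampler accepts more often than targeted.
accept_prob, target_accept = 0.95, 0.75
loss_fn = make_surrogate_loss_sketch(target_accept - accept_prob)
state = (0.0, 0.0, 0.0, 0)  # (log_step_size, m, v, t)
for _ in range(20):
    state, extra = adam_step(state, loss_fn)
```

Swapping Adam for another optimizer only changes the transition function; the surrogate loss and the sampler stay the same.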
8. Streaming statistics
In many cases, MCMC is used to generate samples that are summarized and then discarded. Materializing all the samples is an inefficient use of memory, so FunMC provides a number of streaming statistics to compute simple and exponentially weighted moving means, variances, and covariances. As FunMC is explicitly batch-aware, it can compute independent statistics for each chain, but can also pre-aggregate across chains.
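A streaming mean and variance can be kept in constant memory with Welford's update, written here in the same stateless style: each step takes the old statistic state and returns a new one. This is a minimal sketch with hypothetical function names, covering only the simple (not exponentially weighted) case; keeping one state per chain gives independent per-chain statistics.

```python
def running_variance_init():
    # State: (count, mean, sum of squared deviations from the mean).
    return (0, 0.0, 0.0)


def running_variance_step(state, x):
    """Welford's update: returns a new state instead of mutating the old one."""
    count, mean, m2 = state
    count += 1
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)
    return (count, mean, m2)


def running_variance_finalize(state, ddof=0):
    count, mean, m2 = state
    return mean, m2 / (count - ddof)


# Independent statistics for two chains, one new sample per chain per step.
chain_samples = [[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 2.0, 2.0, 2.0, 2.0]]
states = [running_variance_init() for _ in chain_samples]
for step in range(5):
    states = [running_variance_step(s, chain[step]) for s, chain in zip(states, chain_samples)]
results = [running_variance_finalize(s) for s in states]
```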
One perennial defect of MCMC chains is poor mixing. Diagnostics for detecting it have been proposed, but standard implementations of such diagnostics require access to the full chain history, which negates the benefit of streaming statistics. FunMC provides a streaming version of the Gelman-Rubin potential scale reduction (Gelman & Rubin, 1992), a simple wrapper around FunMC's streaming variance estimator. A streaming estimator for the autocovariance function is also provided, which can be used to compute the effective sample size.
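To show why the potential scale reduction needs only streaming statistics, here is a minimal sketch that computes it from per-chain (count, mean, sum of squared deviations) states. It uses one common form of the statistic, R-hat = sqrt(((n - 1)/n * W + B/n) / W), with W the mean within-chain variance and B/n the variance of the chain means; FunMC's exact variant may differ, and the function names are hypothetical. Equal chain lengths n >= 2 and at least two chains are assumed.

```python
import math


def welford_step(state, x):
    # Streaming (count, mean, sum of squared deviations) update.
    count, mean, m2 = state
    count += 1
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)
    return (count, mean, m2)


def potential_scale_reduction(chain_states):
    """R-hat from per-chain streaming states of equal length n >= 2."""
    num_chains = len(chain_states)
    n = chain_states[0][0]
    means = [s[1] for s in chain_states]
    # W: average within-chain (unbiased) variance.
    w = sum(s[2] / (n - 1) for s in chain_states) / num_chains
    # B / n: variance of the chain means across chains.
    grand_mean = sum(means) / num_chains
    b_over_n = sum((mu - grand_mean) ** 2 for mu in means) / (num_chains - 1)
    var_plus = (n - 1) / n * w + b_over_n
    return math.sqrt(var_plus / w)


def streaming_rhat(chains):
    # Consume samples one step at a time, never storing the history.
    states = [(0, 0.0, 0.0) for _ in chains]
    for t in range(len(chains[0])):
        states = [welford_step(s, c[t]) for s, c in zip(states, chains)]
    return potential_scale_reduction(states)


rhat_mixed = streaming_rhat([[0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]])
rhat_stuck = streaming_rhat([[0.0, 1.0, 0.0, 1.0], [10.0, 11.0, 10.0, 11.0]])
```

Chains exploring the same region give a value near 1, while chains stuck in different regions give a value well above 1.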
We can code a chain that tracks a streaming mean, covariance, and potential scale reduction by replacing the kernel definition in Figure 1 with the one in Figure 5. Note that the potential scale reduction requires multiple chains. In this example, we obtain them by adding a leading dimension to the chain state that indexes independent chains. Following the convention established by the tfp.mcmc library, FunMC interprets this encoding by running the HMC leapfrog and Metropolis-Hastings steps on all the chains in parallel.
9. Discussion
We have presented FunMC and illustrated multiple ways its components can be combined to create custom MCMC and optimization algorithms. This composability is a consequence of the core principles of statelessness, returning side information from functions, and writing higher-order functions to propagate this side information. We hope this library will accelerate methodological research, either through direct use or as an example of how to structure a composable API for Markovian algorithms.
Acknowledgements.
The authors would like to thank Srinivas Vasudevan, Matthew D. Hoffman, Jacob Burnim, Rif A. Saurous, Dave Moore and the rest of the TensorFlow Probability team for helpful comments.

References
 Abadi et al. (2015) Abadi, M., Agarwal, A., Barham, P., Brevdo, E., Chen, Z., Citro, C., Corrado, G. S., Davis, A., Dean, J., Devin, M., Ghemawat, S., Goodfellow, I., Harp, A., Irving, G., Isard, M., Jia, Y., Jozefowicz, R., Kaiser, L., Kudlur, M., Levenberg, J., Mané, D., Monga, R., Moore, S., Murray, D., Olah, C., Schuster, M., Shlens, J., Steiner, B., Sutskever, I., Talwar, K., Tucker, P., Vanhoucke, V., Vasudevan, V., Viégas, F., Vinyals, O., Warden, P., Wattenberg, M., Wicke, M., Yu, Y., and Zheng, X. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015. URL http://tensorflow.org/. Software available from tensorflow.org.
 Betancourt et al. (2014) Betancourt, M. J., Byrne, S., and Girolami, M. Optimizing The Integrator Step Size for Hamiltonian Monte Carlo. 2014. URL http://arxiv.org/abs/1411.6669.
 Bingham et al. (2018) Bingham, E., Chen, J. P., Jankowiak, M., Obermeyer, F., Pradhan, N., Karaletsos, T., Singh, R., Szerlip, P., Horsfall, P., and Goodman, N. D. Pyro: Deep Universal Probabilistic Programming. Journal of Machine Learning Research, 2018.
 Bradbury et al. (2018) Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., and Wanderman-Milne, S. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.
 Dillon et al. (2017) Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., Patton, B., Alemi, A., Hoffman, M., and Saurous, R. A. TensorFlow Distributions. 2017. URL http://arxiv.org/abs/1711.10604.
 Gelman & Rubin (1992) Gelman, A. and Rubin, D. B. Inference from iterative simulation using multiple sequences. Statistical Science, 1992.
 Hoffman & Gelman (2011) Hoffman, M. D. and Gelman, A. The No-U-Turn sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. The Journal of Machine Learning Research, 2011.
 Kingma & Ba (2015) Kingma, D. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
 Lao et al. (2020) Lao, J., Suter, C., Chimisov, C., Langmore, I., Saxena, A., Sountsov, P., Moore, D., Hoffman, M. D., and Dillon, J. V. tfp.mcmc: Modern Markov chain Monte Carlo tools built for modern hardware. 2020. URL https://www.tensorflow.org/probability. In preparation.
 Neal (2011) Neal, R. M. MCMC using Hamiltonian dynamics. In Handbook of Markov Chain Monte Carlo. CRC Press New York, NY, 2011.
 Papaspiliopoulos et al. (2007) Papaspiliopoulos, O., Roberts, G. O., and Sköld, M. A general framework for the parametrization of hierarchical models. Statistical Science, 22(1):59–73, 2007.
 Parno & Marzouk (2014) Parno, M. and Marzouk, Y. Transport map accelerated Markov chain Monte Carlo. arXiv preprint arXiv:1412.5492, 2014.
 Phan et al. (2019) Phan, D., Pradhan, N., and Jankowiak, M. Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro. 2019. URL http://arxiv.org/abs/1912.11554.
 Salvatier et al. (2016) Salvatier, J., Wiecki, T. V., and Fonnesbeck, C. Probabilistic programming in Python using PyMC3. PeerJ Computer Science, 2016(4):1–24, 2016.
 Steiner et al. (2019) Steiner, B., DeVito, Z., Chintala, S., Gross, S., Paszke, A., Massa, F., Lerer, A., Chanan, G., Lin, Z., Yang, E., Desmaison, A., Tejani, A., Kopf, A., Bradbury, J., Antiga, L., Raison, M., Gimelshein, N., Chilamkurthy, S., Killeen, T., Fang, L., and Bai, J. PyTorch: An Imperative Style, High-Performance Deep Learning Library. In Advances in Neural Information Processing Systems (NeurIPS), 2019.
 Vikram et al. (2020) Vikram, S., Radul, A., and Hoffman, M. D. Probabilistic Programming Transformations for JAX. 2020. In preparation.
Appendix A FunMC’s Hamiltonian Monte Carlo
Figure 6 lists the pre-canned Hamiltonian Monte Carlo kernel in FunMC. It is a simple composition of reusable components, which users can recombine to implement many HMC variants.