FunMC: A functional API for building Markov Chains

01/14/2020 ∙ by Pavel Sountsov, et al. ∙ Google

Constant-memory algorithms, also loosely called Markov chains, power the vast majority of probabilistic inference and machine learning applications today. A lot of progress has been made in constructing user-friendly APIs around these algorithms. Such APIs, however, rarely make it easy to research new algorithms of this type. In this work we present FunMC, a minimal Python library for doing methodological research into algorithms based on Markov chains. FunMC is not targeted toward data scientists or others who wish to use MCMC or optimization as a black box, but rather towards researchers implementing new Markovian algorithms from scratch.


1. Introduction

In machine learning, gradient descent is a Markov chain of the parameters of the model being optimized, often augmented with some auxiliary quantities for methods such as Adam (Kingma & Ba, 2015). In probabilistic inference, Markov chain Monte Carlo (MCMC) is a Markov chain of the values of random variables, also often augmented with auxiliary variables (e.g., Neal, 2011). Such algorithms are pervasive in the practice of machine learning, probabilistic and otherwise. As such, they have received a great deal of well-deserved attention among library and API developers (e.g., Salvatier et al., 2016; Bingham et al., 2018; Lao et al., 2020).

We present FunMC, a Markov chain library designed specifically to enable and accelerate research into new constant-memory algorithms for machine learning and probabilistic inference. The design of FunMC follows these principles:

  1. All API elements are completely stateless, simplifying reasoning, composition, and interoperation with function-oriented platforms like JAX (Bradbury et al., 2018) and subsystems like automatic differentiation (through the Markov chain).

  2. Pervasive support for returning and propagating side information.

  3. Composition over configuration. All API elements aim to do one thing well, eschewing long argument or flag lists where possible.

  4. A unified API for MCMC, optimization, and running statistics, for construction of hybrid algorithms.
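Concretely, principles 1 and 2 amount to a single calling convention: every operation is a pure function from a state to a (new state, side information) pair, so kernels compose by ordinary function calls and thread no hidden state. A minimal sketch of that convention (toy names, not FunMC code):

```python
def counting_kernel(state):
    """A toy transition kernel: the state is an integer; the side
    information reports whether the new state is even."""
    new_state = state + 1
    extra = {"is_even": new_state % 2 == 0}
    return new_state, extra

def compose(kernel_a, kernel_b):
    """Sequencing two stateless kernels is just function composition;
    both kernels' side information is kept."""
    def composed(state):
        state, extra_a = kernel_a(state)
        state, extra_b = kernel_b(state)
        return state, (extra_a, extra_b)
    return composed

step_twice = compose(counting_kernel, counting_kernel)
state, (extra_1, extra_2) = step_twice(0)
```

Because nothing is hidden, the composed kernel is itself just another function of the same shape, ready for further composition.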

We have implemented FunMC as a multi-backend library of Python functions, executable with TensorFlow (Abadi et al., 2015) and JAX. It should be relatively easy to port to any other Python machine learning framework. The code is available at https://github.com/tensorflow/probability/tree/master/discussion/fun_mcmc.

The contributions of this work are:

  1. We believe FunMC is the first framework for Markovian algorithms whose components are small and simple enough to be reused often, while composing smoothly enough that sophisticated algorithms can be built from them.

  2. We support this claim by demonstrating how to use FunMC for MCMC sampling using canned (Section 3) or custom (Section 4) methods, including thinning (Section 5); how to reparameterize the target (Section 6); how to do optimization (Section 7), including adapting parameters like step sizes (Figure 4); and how to compute streaming statistics (Section 8); all as independent composable components.

2. Related work

The FunMC library was designed in response to a number of pre-existing MCMC and optimization frameworks. Pyro (Bingham et al., 2018) and PyMC3 (Salvatier et al., 2016) separate the choice of MCMC transition kernel and the outer sampling loop. The optimization packages of PyTorch (Steiner et al., 2019) and TensorFlow make a similar decomposition between the loss on one side and the optimization algorithms or outer training loop on the other. All define a large number of standard MCMC and optimization transition kernels with arguments to specify the hyperparameters and how to adapt them. In the case of Pyro and PyMC3, the choice of adaptation algorithms is limited, and experimentation with them would require writing new transition kernels with little or no reuse of existing code. In the case of PyTorch and TensorFlow, there is support for flexible learning rate adaptation, but other hyperparameters do not have the same support.

The TensorFlow Probability MCMC (tfp.mcmc) library (Lao et al., 2020) makes the same factorization of outer sampling loop and transition kernel, but uses a flexible transition-kernel DSL in which hyperparameter adaptation is effected by nesting transition kernels. For example, to adapt the step size for Hamiltonian Monte Carlo (Neal, 2011), one can wrap a step size adaptation kernel around an HMC kernel. A shortcoming of this approach is that it requires complex message-passing logic between the kernels, forces a DAG structure on the compound transition kernel computation, and requires learning a new DSL. FunMC does not distinguish between outer loops and transition kernels, and uses a more decoupled approach for hyperparameter adaptation, described below.

Modern optimization packages and tfp.mcmc support defining losses and target densities as simple callables, using automatic differentiation to compute any necessary gradients. These callables return the quantity of interest (loss or log density), but have limited or no support for returning side information (e.g., intermediate neural net layer activations). PyMC3 has a concept of deterministic variables which enables this side channel, but it requires a DSL for defining probabilistic models. FunMC retains the flexibility of accepting a simple callable, and makes a simple modification to its calling convention to support returning side information as per PyMC3.
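The modified calling convention can be sketched as follows (toy names, not any library's API): the callable returns a pair (quantity of interest, side information). In JAX this convention lines up directly with jax.grad(fn, has_aux=True), which differentiates the first return value and passes the second through untouched.

```python
import math

def toy_target(w):
    """Illustrative target callable: returns the un-normalized log density
    as the first value and a piece of side information as the second."""
    # Un-normalized log density of a standard normal in w.
    log_prob = -0.5 * w * w
    # Side information: an intermediate, "activation"-like quantity.
    activation = math.tanh(w)
    return log_prob, activation

log_prob, activation = toy_target(0.0)
```

Code that only cares about the log density simply ignores the second return value; code that wants the side channel gets it for free.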

Pyro and PyMC3 implement reparameterization by taking transport maps as arguments to their outer sampling loops. TensorFlow Probability implements this via its transition kernel DSL, where a reparameterization kernel alters the space on which the wrapped kernels operate. All of these approaches limit the opportunity to adapt the parameters of the transport maps as part of the sampling procedure. FunMC takes reparameterization completely out of the transition kernels, using a function transformation instead.

Accelerators based on SIMD operations are at the forefront of high-performance computing in machine learning today. The tfp.mcmc library has pioneered using this capability to run multiple independent chains. Relatedly, NumPyro (Phan et al., 2019) achieved a similar capability using function transformations. FunMC takes a hybrid approach, mixing both explicit and automatic batching to run independent operations in parallel.
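The explicit-batching half of this approach can be illustrated with a toy random-walk step (not FunMC code): a leading array dimension indexes independent chains, and one vectorized call advances them all.

```python
import numpy as np

def batched_random_walk_step(states, rng):
    """states has shape [num_chains, dim]; a single vectorized array
    operation advances every chain at once."""
    return states + 0.1 * rng.standard_normal(states.shape)

rng = np.random.default_rng(0)
states = np.zeros((4, 2))  # 4 independent chains, each with a 2-dim state
states = batched_random_walk_step(states, rng)
```

Automatic batching (e.g., a vmap-style transformation) can then be layered on top for the operations that are awkward to vectorize by hand.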

3. A running example: generating samples using MCMC

def target(w):
  logits = input_features @ w
  log_prob = normal_log_prob(w).sum(-1)
  log_prob += bernoulli_log_prob(outputs, logits[..., 0]).sum(-1)
  return log_prob, logits

def kernel(hmc_state, key):
  hmc_key, key = jax.random.split(key)
  hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo(hmc_state,
    target, step_size, num_integrator_steps, seed=hmc_key)
  w, logits = hmc_state.state, hmc_state.state_extra
  return (hmc_state, key), (w, logits, hmc_extra.is_accepted)

(fin_hmc_state, _), (w_chain, logits_chain, is_accepted_chain) = fun_mc.trace(
  (fun_mc.hamiltonian_monte_carlo_init(w_init, target), jax.random.PRNGKey(0)),
  kernel, num_steps)
Figure 1. Example: HMC on a Bayesian logistic regression model. Note the handling of side information: logits are returned from the target density, combined with the M-H acceptance bit from HMC, and their history returned to the user.

As a running example of the FunMC API, consider using Hamiltonian Monte Carlo (HMC) to sample from the posterior of a simple Bayesian logistic regression model (Figure 1). We first define the target density target, a function of type State -> (Tensor, Extra): it maps the value w of the random variable to two return values. The first is the un-normalized log density evaluated at w. The second is arbitrary, and can be used to return auxiliary information for debugging and other purposes; in this example we return logits, which is a function of w.

The HMC transition kernel is the function kernel in Figure 1. It conforms to the concept of a Markov transition kernel by taking and returning the state of the HMC sampler, but atypically it also has a second output. This second output is used for returning auxiliary information that is not part of the state of the Markov chain. The kernel is a wrapper around the fun_mc.hamiltonian_monte_carlo transition kernel provided by FunMC. The wrapper's purpose is to partially apply fun_mc.hamiltonian_monte_carlo, and to manipulate the side information.

Finally, we iterate our transition kernel using the fun_mc.trace transition kernel. This kernel iterates the function passed as its second argument and stacks its corresponding auxiliary outputs. We defined kernel to return the value of w, the logits, and whether the state proposed by HMC was accepted, the last of which is available from the auxiliary return value of fun_mc.hamiltonian_monte_carlo. The entire history of this side information is thus available for inspection after the completed HMC run.
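The iteration pattern can be sketched in plain Python (a toy stand-in, not FunMC's implementation, which is backend-aware and supports trace masks): iterate a kernel of type state -> (state, extra) and collect the extras.

```python
def toy_trace(state, kernel, num_steps):
    """Iterate kernel num_steps times, collecting its auxiliary outputs."""
    traced = []
    for _ in range(num_steps):
        state, extra = kernel(state)
        traced.append(extra)
    return state, traced

# A toy kernel: increment the state, report the square of the old state.
final_state, history = toy_trace(0, lambda s: (s + 1, s * s), 5)
```

In the real library the collected extras are stacked into arrays, so the chain history comes back as one tensor per traced quantity.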

4. Custom MCMC

FunMC provides the pre-canned fun_mc.hamiltonian_monte_carlo kernel for the sake of convenience. Following the philosophy of composition over configuration, however, that kernel just consists (Appendix A, Figure 6) of sampling the auxiliary momentum, running the integrator, and then performing an accept-reject step, all of which are public functions of FunMC. This makes it straightforward for the user to write their own version if necessary, with maximal code reuse.

5. fun_mc.trace transition kernel

The primary usage of the fun_mc.trace operator is as shown in Figure 1—computing a (top-level) trace of Markov chain states. However, the type of fun_mc.trace intentionally admits partial application to produce a transition kernel. This can be used, for example, for thinning, as in Figure 2. For this purpose, fun_mc.trace accepts a trace_mask argument, which is prefix-tree-zipped with the auxiliary return. The boolean leaves in the mask signify whether the corresponding sub-trees in the auxiliary return get traced or not. A scalar False matches the entire tree, causing it to be propagated without history, which is what we want for thinning.

The concept of the kernel returning auxiliary information shows its value here, by separating the true state of the Markov chain from information derived from it, but not propagated. Without framework provision for side information, this effect would require either statefully mutated side variables, or augmenting the Markov chain state with additional quantities that are simply discarded before each transition. Such a scheme would conceal the Markovian structure as well as force the user to come up with some initial state for these auxiliary quantities.

_, (w_chain, logits_chain, is_accepted_chain) = fun_mc.trace(
  (fun_mc.hamiltonian_monte_carlo_init(w_init, target), jax.random.PRNGKey(0)),
  lambda *state: fun_mc.trace(state, kernel, num_substeps, trace_mask=False),
  num_steps // num_substeps)
Figure 2. Thinning an MCMC chain. The difference from Figure 1 is wrapping kernel in another fun_mc.trace without collecting the intermediate states.

6. Reparameterization of potential functions

HMC is well known to be sensitive to the geometry of the target density. One way to ameliorate this issue is through reparameterization (Papaspiliopoulos et al., 2007; Parno & Marzouk, 2014). In FunMC, this is effected via function composition, taking care to remap arguments appropriately using the fun_mc.reparameterize_potential_fn utility function. It has type

reparameterize_potential_fn :: (State -> (Tensor, Extra_1),
                                ReparamState -> (State, Extra_2),
                                State) ->
    (ReparamState -> (Tensor, (State, Extra_1, Extra_2)), ReparamState).

It accepts a density function (operating in some “original” space), a diffeomorphism from the reparameterized space to the original space, and a point in the original space. It returns the corresponding density function in the reparameterized space and the corresponding point in the reparameterized space. This can be used to initialize an inference algorithm that then operates in the reparameterized space.
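The returned density function follows the standard change-of-variables rule: if x = g(z) for a diffeomorphism g, then log p_Z(z) = log p_X(g(z)) + log|det J_g(z)|. A minimal plain-Python sketch (illustrative names, not FunMC's implementation):

```python
import math

def target_log_prob(x):
    # Un-normalized log density of Exponential(1) on x > 0: -x.
    return -x

def diffeomorphism(z):
    # g: R -> (0, inf), x = exp(z), so dx/dz = exp(z) and log|J| = z.
    return math.exp(z)

def reparam_log_prob(z):
    """Density in the reparameterized (unconstrained) space."""
    x = diffeomorphism(z)
    log_det_jacobian = z  # log|dx/dz| for this particular g
    return target_log_prob(x) + log_det_jacobian

val = reparam_log_prob(0.0)  # -exp(0) + 0 = -1.0
```

Here the exp transform moves the sampler from the constrained positive half-line to an unconstrained space, which is the typical motivation for reparameterization.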

While reparameterization can be advantageous for running the Markov chain, the user may be interested in inspecting the chain's states in the original parameterization. FunMC's pervasive support for side returns makes this easy to arrange: the reparameterized density function just returns the point in the original space on the side. The fun_mc.hamiltonian_monte_carlo kernel propagates it, and the kernel we write in Figure 3 extracts it and exposes just the original parameterization to tracing by fun_mc.trace.

The diffeomorphism must be an invertible function with tractable Jacobian log-determinant. FunMC relies on its backend to compute the inverse and the log-determinant of the Jacobian. One practical way to do this is to code the diffeomorphism using the tfp.bijectors library (Dillon et al., 2017) (FunMC knows the inversion and Jacobian computation API defined thereby). The automatic inversion mechanism being developed for JAX (Vikram et al., 2020) looks like a promising future alternative.

In summary, to reparameterize the target in our running example, we replace the kernel definition from Figure 1 with the one in Figure 3. Since fun_mc.reparameterize_potential_fn is just a function, the user can call it inside their kernel, using a different diffeomorphism at each point in the chain. This supports adapting or inferring the diffeomorphism's parameters.

reparam_potential_fn, reparam_w_init = fun_mc.reparameterize_potential_fn(
    target, diffeomorphism, w_init)

def kernel(hmc_state, key):
  hmc_key, key = jax.random.split(key)
  hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo(hmc_state,
    reparam_potential_fn, step_size, num_integrator_steps, seed=hmc_key)
  w, logits = hmc_state.state_extra[:2]
  return (hmc_state, key), (w, logits, hmc_extra.is_accepted)

_, (w_chain, logits_chain, is_accepted_chain) = fun_mc.trace(
    (fun_mc.hamiltonian_monte_carlo_init(reparam_w_init, reparam_potential_fn),
     jax.random.PRNGKey(0)), kernel, num_steps)
Figure 3. Reparameterization is just a matter of wrapping the original target. Values in the original space can be propagated via the side returns instead of being lost or having to be recomputed.

7. Optimization

def kernel(hmc_state, log_step_size_state, key):
  hmc_key, key = jax.random.split(key)
  step_size = np.exp(log_step_size_state.state)
  hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo(hmc_state,
    target, step_size, num_integrator_steps, seed=hmc_key)
  p_accept = np.exp(np.minimum(0., hmc_extra.log_accept_ratio))
  loss_fn = fun_mc.make_surrogate_loss_fn(lambda _: (0.8 - p_accept, ()))
  log_step_size_state, _ = fun_mc.adam_step(log_step_size_state, loss_fn,
    learning_rate=1e-2)
  w = hmc_state.state; logits = hmc_state.state_extra
  return (hmc_state, log_step_size_state, key), (w, logits, hmc_extra.is_accepted)
Figure 4. HMC with step size adaptation, reusing the FunMC Adam optimizer. Note the surrogate loss function trick via fun_mc.make_surrogate_loss_fn.

Optimization can often be cast as a Markov chain, so FunMC provides a number of transition kernels that implement common optimization algorithms such as gradient descent and Adam. Besides their standard uses, these kernels are reusable for hyper-parameter adaptation in MCMC. For example, it is conventional to tune the step size parameter of HMC to hit an acceptance rate between 0.6 and 0.8 (Betancourt et al., 2014). Given the acceptance rate p_accept at a step, the statistic 0.8 - p_accept can be used as the gradient of a surrogate loss function of the step size. The step size is then adapted during the run using some gradient-based optimization scheme such as Nesterov dual averaging (Hoffman & Gelman, 2011), a choice usually hardcoded into an MCMC library.
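The surrogate-loss idea can be sketched independently of any library: to feed a hand-computed gradient g into a generic gradient-based optimizer, construct a loss whose derivative with respect to the parameter is exactly g. One standard construction is loss(theta) = theta * stop_gradient(g). The plain-Python stand-in below (illustrative names, not FunMC's implementation of fun_mc.make_surrogate_loss_fn) checks this with a finite difference:

```python
def surrogate_loss(theta, g):
    """g is treated as a constant (its gradient is 'stopped'), so the
    derivative of this loss with respect to theta is g itself."""
    return theta * g

def derivative(f, theta, eps=1e-6):
    # Central finite-difference approximation of df/dtheta.
    return (f(theta + eps) - f(theta - eps)) / (2 * eps)

p_accept = 0.65
g = 0.8 - p_accept  # the desired gradient signal: 0.15
d = derivative(lambda t: surrogate_loss(t, g), 1.3)
```

Any optimizer that consumes differentiable losses (Adam, gradient descent, ...) can then adapt the hyperparameter without knowing the gradient was synthesized by hand.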

The composability of FunMC's components lets us plug in other methods instead. For example, to use Adam, we just replace the kernel definition in Figure 1 with the one in Figure 4. Since Adam expects a differentiable loss function rather than a gradient, we need a little jiu-jitsu in the form of fun_mc.make_surrogate_loss_fn (Figure 4).

8. Streaming statistics

In many cases, MCMC is used to generate samples that are then summarized and the samples discarded. Materializing all the samples is an inefficient use of memory, so FunMC provides a number of streaming statistics to compute simple and exponentially weighted moving means, variances and covariances. As FunMC is explicitly batch-aware, it can compute independent statistics for each chain, but also perform pre-aggregation across the chains.
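The kind of constant-memory update such a kernel performs can be sketched with Welford's algorithm for the running mean and variance (FunMC's running_covariance_step generalizes this idea; the toy version below only shows the stateless state -> new state shape, with illustrative names):

```python
def running_mean_variance_init():
    return {"n": 0, "mean": 0.0, "m2": 0.0}

def running_mean_variance_step(state, x):
    """One Welford update: constant memory regardless of stream length."""
    n = state["n"] + 1
    delta = x - state["mean"]
    mean = state["mean"] + delta / n
    m2 = state["m2"] + delta * (x - mean)
    return {"n": n, "mean": mean, "m2": m2}

state = running_mean_variance_init()
for x in [1.0, 2.0, 3.0, 4.0]:
    state = running_mean_variance_step(state, x)
mean = state["mean"]                 # 2.5
variance = state["m2"] / state["n"]  # population variance: 1.25
```

Because the estimator is itself a (state, observation) -> state function, it slots into a FunMC-style kernel exactly like any other transition kernel.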

One perennial defect of MCMC chains is poor mixing. Diagnostics to detect this have been proposed, but standard implementations of such diagnostics require access to the original chain history, which negates the benefit of streaming statistics. FunMC provides a streaming version of the Gelman potential scale reduction R̂ (Gelman & Rubin, 1992), a simple wrapper around FunMC's streaming variance estimator. A streaming estimator for the auto-covariance function is also provided, which can be used to compute the effective sample size.
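The reason a streaming variance estimator suffices is that the Gelman-Rubin statistic depends on the chain history only through per-chain means and variances. A plain-Python sketch of the batch formula (following Gelman & Rubin, 1992; helper names are illustrative, not FunMC's API):

```python
def potential_scale_reduction(chain_means, chain_vars, n):
    """R-hat from per-chain summary statistics for m chains of length n."""
    m = len(chain_means)
    grand_mean = sum(chain_means) / m
    # Between-chain variance B and mean within-chain variance W.
    b = n * sum((mu - grand_mean) ** 2 for mu in chain_means) / (m - 1)
    w = sum(chain_vars) / m
    # Pooled posterior-variance estimate, then R-hat.
    var_hat = (n - 1) / n * w + b / n
    return (var_hat / w) ** 0.5

# Two chains with identical means and unit variances: R-hat is near 1.
r_hat = potential_scale_reduction([0.0, 0.0], [1.0, 1.0], n=100)
```

A streaming implementation simply maintains those per-chain summaries with constant-memory updates and applies the same formula on demand.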

We can code a chain tracking a streaming mean, covariance, and R̂ by replacing the kernel definition in Figure 1 with the one in Figure 5. Note that R̂ requires multiple chains. In this example, we get them by adding a leading dimension to the chain state which indexes independent chains. Following the convention established by the tfp.mcmc library, FunMC interprets this encoding by running the HMC leapfrog and Metropolis-Hastings steps on all the chains in parallel.

def kernel(hmc_state, cov_state, rhat_state, key):
  hmc_key, key = jax.random.split(key)
  hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo(hmc_state,
    target, step_size, num_integrator_steps, seed=hmc_key)
  w, logits = hmc_state.state, hmc_state.state_extra
  cov_state, _ = fun_mc.running_covariance_step(cov_state, (w, logits), axis=0)
  rhat_state, _ = fun_mc.potential_scale_reduction_step(rhat_state, w)
  return (hmc_state, cov_state, rhat_state, key), ()

(_, fin_cov_state, fin_rhat_state, _), _ = fun_mc.trace(
  (fun_mc.hamiltonian_monte_carlo_init(w_init, target),
   fun_mc.running_covariance_init((w_init.shape[-1:], y.shape), (np.float32,) * 2),
   fun_mc.potential_scale_reduction_init(w_init.shape, np.float32),
   jax.random.PRNGKey(0)), kernel, num_steps)
w_cov, logits_cov = fin_cov_state.covariance
r_hat = fun_mc.potential_scale_reduction_extract(fin_rhat_state)
Figure 5. HMC with streaming mean, covariance, and potential scale reduction (R̂) estimation.

9. Discussion

We have presented FunMC and illustrated multiple ways its components can be combined to create custom MCMC and optimization algorithms. This composition is a consequence of the core principles of statelessness, returning side information from functions, and writing higher-order functions to propagate this side information. We hope this library will accelerate methodological research, either through direct use or as an example of how to structure a composable API for Markovian algorithms.

Acknowledgements.
The authors would like to thank Srinivas Vasudevan, Matthew D. Hoffman, Jacob Burnim, Rif A. Saurous, Dave Moore, and the rest of the TensorFlow Probability team for helpful comments.


Appendix A FunMC’s Hamiltonian Monte Carlo

Figure 6 lists the pre-canned Hamiltonian Monte Carlo in FunMC. This is a simple composition of reusable components, which a user can recombine to implement many HMC variants.

def hamiltonian_monte_carlo(hmc_state, target_log_prob_fn, step_size,
    num_integrator_steps, seed):
  # Define the sub-transition kernels.
  kinetic_energy_fn = fun_mc.make_gaussian_kinetic_energy_fn(
    len(hmc_state.target_log_prob.shape))

  momentum_sample_fn = lambda key: fun_mc.gaussian_momentum_sample(
    state=hmc_state.state, seed=key)

  integrator_step_fn = lambda state: fun_mc.leapfrog_step(state, step_size,
    target_log_prob_fn, kinetic_energy_fn)

  integrator_fn = lambda state: fun_mc.hamiltonian_integrator(state,
    num_integrator_steps, integrator_step_fn, kinetic_energy_fn)

  # Run the integration.
  mh_key, sample_key = jax.random.split(seed, 2)
  momentum = momentum_sample_fn(sample_key)

  integrator_state = fun_mc.IntegratorState(hmc_state.state,
    hmc_state.state_extra, hmc_state.state_grads, hmc_state.target_log_prob,
    momentum)

  integrator_state, integrator_extra = integrator_fn(integrator_state)

  # Do the MH accept-reject step.
  proposed_state = fun_mc.HamiltonianMonteCarloState(
    state=integrator_state.state,
    state_grads=integrator_state.state_grads,
    target_log_prob=integrator_state.target_log_prob,
    state_extra=integrator_state.state_extra)

  hmc_state, mh_extra = fun_mc.metropolis_hastings_step(hmc_state,
    proposed_state, integrator_extra.energy_change, seed=mh_key)

  return hmc_state, fun_mc.HamiltonianMonteCarloExtra(
      is_accepted=mh_extra.is_accepted,
      proposed_hmc_state=proposed_state,
      log_accept_ratio=-integrator_extra.energy_change,
      integrator_state=integrator_state, integrator_extra=integrator_extra,
      initial_momentum=momentum)
Figure 6. Default Hamiltonian Monte Carlo in FunMC. Side quantities of interest are returned in the HamiltonianMonteCarloExtra structure for possible accumulation and/or inspection by the user.