Meta-Learning for Stochastic Gradient MCMC

06/12/2018 · Wenbo Gong et al. · University of Cambridge

Stochastic gradient Markov chain Monte Carlo (SG-MCMC) has become increasingly popular for simulating posterior samples in large-scale Bayesian modeling. However, existing SG-MCMC schemes are not tailored to any specific probabilistic model; even a simple modification of the underlying dynamical system requires significant physical intuition. This paper presents the first meta-learning algorithm that allows automated design of the underlying continuous dynamics of an SG-MCMC sampler. The learned sampler generalizes Hamiltonian dynamics with state-dependent drift and diffusion, enabling fast traversal and efficient exploration of neural network energy landscapes. Experiments validate the proposed approach on both Bayesian fully connected neural network and Bayesian recurrent neural network tasks, showing that the learned sampler outperforms generic, hand-designed SG-MCMC algorithms and generalizes to different datasets and larger architectures.

1 Introduction

There is a resurgence of research interest in Bayesian deep learning (graves2011practical; blundell2015weight; hernandez2015probabilistic; hernandez2016black; gal2016dropout; ritter2018a), which applies Bayesian inference to neural networks for better uncertainty estimation, a quantity crucial for, e.g., better exploration in reinforcement learning (deisenroth2011pilco; depeweg2017learning), resisting adversarial attacks (feinman2017detecting; li2017dropout; louizos2017multiplicative), and continual learning (nguyen2018variational). A popular approach to performing Bayesian inference on neural networks is stochastic gradient Markov chain Monte Carlo (SG-MCMC), which adds properly scaled Gaussian noise to a stochastic gradient ascent procedure (welling2011bayesian). Recent advances in this area further introduced optimization techniques such as pre-conditioning (ahn2012bayesian; patterson2013stochastic), annealing (ding2014bayesian) and adaptive learning rates (li2016preconditioned; chen2016bridging). All these efforts have made SG-MCMC highly scalable to many deep learning tasks, including shape and texture modeling in computer vision (li2016cvpr) and language modeling with recurrent neural networks (gan2017scalable). However, inventing novel dynamics for SG-MCMC requires significant mathematical work to ensure that the stationary distribution is the target distribution, which makes the approach less accessible to practitioners. Furthermore, many of these algorithms are designed as generic sampling procedures, and the associated physical mechanisms might not be best suited for sampling neural network weights.

Can we automate the design of SG-MCMC algorithms which are tailored to the problem of sampling from certain types of densities, e.g. Bayesian neural network posterior distributions? This paper aims to answer this question by introducing meta-learning, or learning-to-learn, techniques (schmidhuber1987evolutionary; bengio1992optimization; naik1992meta; thrun1998learning). The scope of meta-learning research is very broad, but the general idea is to train a learner on one or multiple tasks in order to acquire common knowledge that generalizes to future tasks. Recent applications of meta-learning include learning to transfer knowledge to unseen few-shot learning tasks (santoro2016meta; ravi2017fewshot; finn2017model), and learning algorithms such as gradient descent (andrychowicz2016learning; li2016learning; wichrowska2017learned), Bayesian optimization (chen2017learning) and reinforcement learning (duan2016rl; wang2016learning). Unfortunately, these recent advances cannot be directly transferred to the world of MCMC samplers, since a naive neural network parameterization of the transition kernel does not guarantee that the posterior distribution is a stationary distribution.

We present, to the best of our knowledge, the first attempt towards meta-learning an SG-MCMC algorithm. Concretely, our contributions include:

  • An SG-MCMC sampler that extends Hamiltonian dynamics with learnable diffusion and curl matrices. Once trained, the sampler can generalize to different datasets and architectures.

  • Extensive evaluation of the proposed sampler on Bayesian fully connected neural networks and Bayesian recurrent neural networks, with comparisons to popular SG-MCMC schemes based on e.g. Hamiltonian Monte Carlo (chen2014stochastic) and pre-conditioned Langevin dynamics (li2016preconditioned).

2 Background: a complete framework for SG-MCMC

Consider sampling from a target density $\pi(\theta)$ defined by an energy function $U(\theta)$: $\pi(\theta) \propto \exp(-U(\theta))$, $\theta \in \mathbb{R}^D$. In this paper we focus on this sampling task in a Bayesian modeling set-up, i.e. given observed data $\mathcal{D} = \{(x_n, y_n)\}_{n=1}^N$, we define a probabilistic model $p(\theta, \mathcal{D}) = p(\theta) \prod_{n=1}^N p(y_n | x_n, \theta)$, and then the target density is the posterior distribution $\pi(\theta) = p(\theta | \mathcal{D})$. Using Bayesian neural networks as an illustrating example, here $\theta$ denotes the network weights, the model typically uses a Gaussian prior $p(\theta) = \mathcal{N}(\theta; 0, \lambda^{-1} I)$, and the energy function is defined as

$$U(\theta) = -\sum_{n=1}^{N} \log p(y_n | x_n, \theta) - \log p(\theta), \qquad (1)$$

with $-\log p(y_n | x_n, \theta)$ usually defined as the $\ell_2$ loss for regression or the cross-entropy loss for classification. A typical MCMC sampler constructs a Markov chain with a transition kernel, and corrects the proposed samples with Metropolis-Hastings (MH) rejection steps. Some of these methods, e.g. Hamiltonian Monte Carlo (HMC) (duane1987hybrid; neal2011mcmc) and slice sampling (neal2003slice), further augment the state space with auxiliary variables $r$ and sample from the augmented distribution $\pi(\theta, r) \propto \exp(-H(\theta, r))$, where the corresponding Hamiltonian satisfies $H(\theta, r) = U(\theta) + g(r)$ with $\int \exp(-g(r)) \, \mathrm{d}r < +\infty$. Thus, marginalizing the auxiliary variable $r$ will not affect the stationary distribution $\pi(\theta)$.

For deep learning tasks, the observed dataset often contains thousands, if not millions, of instances, making MH rejection steps computationally prohibitive. Fortunately, this is mitigated by SG-MCMC, whose transition kernel is implicitly defined by a stochastic differential equation (SDE) that leaves the target density invariant (welling2011bayesian; ahn2012bayesian; patterson2013stochastic; chen2014stochastic; ding2014bayesian). With a carefully selected discretization step-size (like learning rates in optimization), the MH rejection steps are usually dropped. Also, simulating one step of SG-MCMC only requires evaluating the gradient on a small mini-batch of data, which exhibits the same cost as many stochastic optimization algorithms. These two distinctive features make SG-MCMC highly scalable for sampling posterior distributions of neural network weights conditioned on big datasets.
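To make this concrete: in the simplest instance, stochastic gradient Langevin dynamics discretizes $\mathrm{d}\theta = -\nabla U(\theta)\mathrm{d}t + \sqrt{2}\,\mathrm{d}W(t)$ and substitutes a mini-batch estimate of the energy gradient. Below is a minimal sketch (our illustration, not code from the paper); the toy 1D Gaussian model, the step-size and all names are placeholder assumptions:

```python
import numpy as np

# Minimal SGLD sketch: sample theta ~ p(theta | data) for a toy 1D model
# y_n ~ N(theta, 1) with prior theta ~ N(0, 1).
rng = np.random.default_rng(0)
data = rng.normal(2.0, 1.0, size=1000)    # observed data
N, M, eta = len(data), 100, 1e-4          # dataset size, minibatch size, step-size

def stoch_grad_U(theta):
    """Stochastic gradient of the energy U(theta) = -log p(theta, data),
    estimated on a random minibatch and rescaled by N/M."""
    batch = rng.choice(data, size=M, replace=False)
    grad_log_lik = np.sum(batch - theta) * (N / M)  # d/dtheta of the log-likelihood
    grad_log_prior = -theta
    return -(grad_log_lik + grad_log_prior)

theta, samples = 0.0, []
for t in range(5000):
    # SGLD update: a gradient step plus properly scaled Gaussian noise
    theta += -eta * stoch_grad_U(theta) + np.sqrt(2 * eta) * rng.normal()
    samples.append(theta)

print(np.mean(samples[1000:]))  # approaches the posterior mean of theta (about 2)
```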

Generally speaking, the continuous-time dynamics of an SG-MCMC method over the augmented state $z = (\theta, r)$ is governed by the following SDE (and the corresponding Markov process is called an Itô diffusion):

$$\mathrm{d}z = f(z)\,\mathrm{d}t + \sqrt{2 D(z)}\,\mathrm{d}W(t), \qquad (2)$$

with $f(z)$ the deterministic drift, $W(t)$ the Wiener process, and $D(z)$ the diffusion matrix. ma2015complete derived an extensive framework of SG-MCMC samplers using advanced statistical mechanics (yin2006existence; shi2012relation), which directly parameterizes the drift term with the target density $\pi(z) \propto \exp(-H(z))$:

$$f(z) = -\left[ D(z) + Q(z) \right] \nabla H(z) + \Gamma(z), \qquad \Gamma_i(z) = \sum_j \frac{\partial}{\partial z_j} \left( D_{ij}(z) + Q_{ij}(z) \right), \qquad (3)$$

with $Q(z)$ the curl matrix and $\Gamma(z)$ a correction term. Remarkably, ma2015complete showed the completeness of their framework:

  • $\pi(z) \propto \exp(-H(z))$ is a stationary distribution of the SDE (2) for any pair of positive semi-definite matrix $D(z)$ and skew-symmetric matrix $Q(z)$;

  • for any Itô diffusion process that has the unique stationary distribution $\pi(z)$, under mild conditions there exist matrices $D(z)$ and $Q(z)$ such that the process is governed by (2) and (3).

As a consequence, the construction of an SG-MCMC algorithm reduces to defining its $D(z)$ and $Q(z)$ matrices. Indeed, ma2015complete also cast existing SG-MCMC samplers within the framework, and proposed an improved version of SG-Riemannian-HMC. In general, an appropriate design of these two matrices leads to significant improvements in mixing as well as reductions in sample bias (li2016preconditioned; ma2015complete). However, historically this design has been based on strong physical intuitions from, e.g., Hamiltonian mechanics (duane1987hybrid; neal2011mcmc) and thermodynamics (ding2014bayesian). Therefore it can still be difficult for practitioners to understand and engineer the sampling method that is best suited to their machine learning tasks.
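For concreteness, the sketch below (ours, not code from the paper) instantiates the framework with the constant matrices $D = \begin{pmatrix} 0 & 0 \\ 0 & C I \end{pmatrix}$ and $Q = \begin{pmatrix} 0 & -I \\ I & 0 \end{pmatrix}$, for which $\Gamma(z) = 0$; this choice recovers SGHMC, and the Euler discretization below samples a 2D standard Gaussian:

```python
import numpy as np

# The complete-recipe update for z = (theta, p) with constant
#   D = [[0, 0], [0, C I]]  (diffusion)  and  Q = [[0, -I], [I, 0]]  (curl),
# which recovers SGHMC; Gamma(z) = 0 because D and Q are state-independent.
rng = np.random.default_rng(1)
C, eta = 1.0, 1e-2                      # friction and step-size

def grad_U(theta):                      # energy gradient of a standard Gaussian target
    return theta

theta, p = np.zeros(2), np.zeros(2)
samples = []
for t in range(20000):
    # Euler discretization of dz = -[D + Q] grad H(z) dt + sqrt(2 D) dW:
    theta = theta + eta * p                                  # -Q block moves theta along p
    p = (p - eta * grad_U(theta)                             # +Q block: Hamiltonian force
           - eta * C * p                                     # -D block: friction
           + np.sqrt(2 * eta * C) * rng.normal(size=2))      # sqrt(2 D) dW: injected noise
    samples.append(theta.copy())

print(np.cov(np.array(samples[2000:]).T))  # should approach the identity matrix
```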

3 Meta-learning for SG-MCMC

This section presents a meta-learning approach to learning an SG-MCMC sampler from data. Our aim is to design an appropriate parameterization of $D(z)$ and $Q(z)$, so that the sampler can be trained on small tasks with a meta-learning procedure and generalize to more complicated densities in high dimensions. For simplicity, in this paper we only augment the state space by introducing one auxiliary variable $p$, called the momentum (duane1987hybrid; neal2011mcmc), although the generalization to thermostat variable augmentation (ding2014bayesian) is fairly straightforward. Thus, the augmented state space is $z = (\theta, p)$, and the Hamiltonian is defined as $H(\theta, p) = U(\theta) + \frac{1}{2} p^\top p$ (i.e. with an identity mass matrix for the momentum).

3.1 Efficient parameterization of diffusion and curl matrices

For neural networks, the dimensionality of $\theta$ can be at least tens of thousands. Thus, training and applying full $D(z)$ and $Q(z)$ matrices would incur a huge computational burden, let alone the gradient computations required by $\Gamma(z)$. To address this, we define the preconditioning matrices as follows:

$$D(z) = \begin{pmatrix} 0 & 0 \\ 0 & D_f(\theta, p) \end{pmatrix}, \quad Q(z) = \begin{pmatrix} 0 & -Q_f(\theta, p) \\ Q_f(\theta, p) & 0 \end{pmatrix}, \quad Q_f = \mathrm{diag}\left(f_{\phi_Q}(\theta, p)\right) + cI, \quad D_f = \alpha Q_f^2 + \mathrm{diag}\left(f_{\phi_D}(\theta, p)\right), \qquad (4)$$

where $f_{\phi_Q}$ and $f_{\phi_D}$ are neural network parameterized functions that will be detailed in Section 3.2, $c$ is a small positive constant, and $\alpha$ is a positive scaling constant. We choose $Q_f$ and $D_f$ to be diagonal for fast computation, although future work can explore low-rank matrix solutions. From ma2015complete, our design has the unique stationary distribution $\pi(z) \propto \exp(-H(z))$ provided that $D_f(\theta, p)$ is non-negative for all $(\theta, p)$.

We discuss the role of each preconditioning matrix for better intuition. The curl matrix $Q(z)$ in (3) mainly controls the deterministic drift forces introduced by the energy gradient (as seen in many HMC-like procedures). Usually we only have access to the stochastic gradient through data sub-sampling, so an additional friction term is needed to counteract the associated noise, which mainly affects the dynamics of the momentum $p$. This explains the design of the diffusion matrix $D(z)$, which uses $f_{\phi_D}$ to control the amount of friction and injected noise applied to the momentum. Furthermore, $D_f$ should also account for the pre-conditioning effect introduced by $Q_f$: e.g. when the magnitude of $Q_f$ is large, without an MH step we can only prevent divergence by increasing the momentum friction. This explains the squared term in the design of $D_f$. The scaling positive constant $\alpha$ is heuristically selected following (chen2014stochastic; ma2015complete) (see appendix). Finally, the extra term $\Gamma(z)$ is responsible for compensating for the changes introduced by the preconditioning matrices $D(z)$ and $Q(z)$.

The discretized dynamics of the state $z = (\theta, p)$ with step-size $\eta$ and stochastic gradient $\nabla_\theta \tilde{U}(\theta)$ is

$$\begin{aligned} \theta_{t+1} &= \theta_t + \eta Q_f(\theta_t, p_t)\, p_t + \eta \Gamma_\theta(\theta_t, p_t), \\ p_{t+1} &= p_t - \eta Q_f(\theta_t, p_t)\, \nabla_\theta \tilde{U}(\theta_t) - \eta D_f(\theta_t, p_t)\, p_t + \eta \Gamma_p(\theta_t, p_t) + \mathcal{N}\!\left(0,\; 2 \eta D_f(\theta_t, p_t)\right). \end{aligned} \qquad (5)$$

Again we notice that $Q_f$ is responsible for the acceleration of $\theta$, and from the $-\eta D_f(\theta_t, p_t) p_t$ term in the update equation of $p$, we see that $D_f$ controls the friction introduced to the momentum. Observing that the noisy gradient is approximately Gaussian distributed in the big data setting, ma2015complete further suggested a correction scheme to counteract the stochastic gradient noise, which instead samples the Gaussian noise with covariance $2\eta \left( D_f(\theta_t, p_t) - \eta \hat{B}_t \right)$, where $\hat{B}_t$ is an empirical estimate of the gradient variance. These corrections can be dropped when the discretization step-size is small, therefore we do not consider them in our experiments.

3.2 Choices of inputs to the neural networks

We now present detailed functional forms for $f_{\phi_Q}$ and $f_{\phi_D}$. When designing these, our goal was to achieve a good balance between generalization power and computational efficiency. Recall that the curl matrix mainly controls the drift of the dynamics, and the desired behavior is that it should produce accelerations for fast traversal through low density regions. One useful source of information for identifying these regions is the energy function $U(\theta)$, which can be used to determine whether the particles have reached high density regions.¹ We also include the momentum in the inputs of $f_{\phi_Q}$, allowing the matrix to observe the velocity information of the particles. We further add an offset $c$ to $f_{\phi_Q}$ to prevent this matrix from vanishing. Putting these together, we define the $i$-th diagonal element of $Q_f$ as

$$Q_f(\theta, p)_{ii} = f_{\phi_Q}\left( U(\theta), p_i \right) + c. \qquad (6)$$

The corresponding $\Gamma(z)$ term requires both $\partial Q_f / \partial \theta$ and $\partial Q_f / \partial p$. The energy gradient $\nabla_\theta U(\theta)$ also appears in (5), so it remains to compute $\partial f_{\phi_Q} / \partial U$, which, along with $\partial Q_f / \partial p$, can be obtained by automatic differentiation (abadi2016tensorflow).

¹The energy gradient is also informative here; however, using it would require computing the diagonal Hessian of $U(\theta)$ for $\Gamma(z)$, which is costly for high dimensional problems. For similar reasons we do not consider the (diagonal) Fisher information matrix or the Hessian as an input of $f_{\phi_D}$.

The matrix $D_f$ is responsible for the friction term and the stochastic gradient noise, which are crucial for better exploration around high density regions. Therefore we also add the energy gradient to the inputs, meaning that the $i$-th diagonal element of $D_f$ is

$$D_f(\theta, p)_{ii} = \alpha Q_f(\theta, p)_{ii}^2 + f_{\phi_D}\left( U(\theta), p_i, \nabla_{\theta_i} U(\theta) \right). \qquad (7)$$

By construction of the $D$ matrix, the $\Gamma(z)$ vector only requires $\partial D_f / \partial p$, therefore the Hessian of the energy function is not required.

In practice both $U(\theta)$ and $\nabla_\theta U(\theta)$ are replaced by their stochastic estimates $\tilde{U}(\theta)$ and $\nabla_\theta \tilde{U}(\theta)$, respectively. To keep the scale of the inputs roughly the same across tasks, we rescale all the inputs using statistics computed by simulating the sampler with randomly initialized $f_{\phi_Q}$ and $f_{\phi_D}$. When the computational budget is limited, we replace the exact gradient computations required by $\Gamma(z)$ with finite difference approximations. We refer the reader to the appendix for details.
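To make the update concrete, the following PyTorch sketch (our illustration, not code from the paper's repository) performs a single step of (5) under the parameterizations (6) and (7) on a toy Gaussian target. The small MLPs f_phi_Q, f_phi_D and all constants are placeholder assumptions, and the Γ(z) terms are obtained by automatic differentiation rather than the finite difference scheme of the appendix:

```python
import torch

# One update of Eq. (5) with the parameterizations of Eqs. (6)-(7).
# f_phi_Q / f_phi_D are placeholder MLPs; all constants are illustrative.
torch.manual_seed(0)
D_dim, eta, alpha, c = 5, 1e-2, 1e-2, 1e-2
f_phi_Q = torch.nn.Sequential(torch.nn.Linear(2, 8), torch.nn.Tanh(),
                              torch.nn.Linear(8, 1))
f_phi_D = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.Tanh(),
                              torch.nn.Linear(8, 1), torch.nn.Softplus())

def U(theta):                                    # toy energy: standard Gaussian
    return 0.5 * (theta ** 2).sum()

theta = torch.randn(D_dim, requires_grad=True)
p = torch.randn(D_dim, requires_grad=True)

energy = U(theta)
grad_U = torch.autograd.grad(energy, theta, create_graph=True)[0]
e = energy * torch.ones(D_dim, 1)                # per-dimension copy of U(theta)

# Eq. (6): Q_f sees (U, p_i); Eq. (7): D_f additionally sees dU/dtheta_i.
# The Softplus keeps f_phi_D >= 0, so D_f >= 0 as required for stationarity.
Qf = f_phi_Q(torch.cat([e, p.unsqueeze(-1)], -1)).squeeze(-1) + c
Df = alpha * Qf ** 2 + f_phi_D(
    torch.cat([e, p.unsqueeze(-1), grad_U.unsqueeze(-1)], -1)).squeeze(-1)

# Gamma terms of Eq. (3); Q_f, D_f are diagonal, so only d(Q_f)_ii/dp_i,
# d(Q_f)_ii/dtheta_i and d(D_f)_ii/dp_i survive.
gamma_theta = -torch.autograd.grad(Qf.sum(), p, retain_graph=True)[0]
dQf_dU = torch.autograd.grad(Qf.sum(), e, retain_graph=True)[0].squeeze(-1)
gamma_p = dQf_dU * grad_U + torch.autograd.grad(Df.sum(), p, retain_graph=True)[0]

with torch.no_grad():                            # Eq. (5): one Euler step
    noise = torch.sqrt(2 * eta * Df) * torch.randn(D_dim)
    theta_next = theta + eta * Qf * p + eta * gamma_theta
    p_next = p - eta * Qf * grad_U - eta * Df * p + eta * gamma_p + noise
print(theta_next, p_next)
```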

3.3 Loss function design for meta-learning

Another challenge is to design a meta-learning procedure for the sampler that encourages fast convergence and low bias on test tasks. To achieve these goals we propose two loss functions, which we name the cross-chain loss and the in-chain loss. From now on we consider the discretized dynamics and define $q_t(\theta)$ as the marginal distribution of the random variable $\theta_t$ at time $t$.

Cross-chain loss

We introduce the cross-chain loss to encourage fast convergence of the sampler. Since the framework guarantees that the sampler has the target density $\pi$ as its unique stationary distribution, fast convergence means that $\mathrm{KL}[q_t \| \pi]$ is close to zero when $t$ is small. Therefore this KL-divergence becomes a sensible objective to minimize, which is equivalent to maximizing the variational lower-bound (or ELBO) $\mathcal{L}(q_t) = \mathbb{E}_{q_t}[-U(\theta)] - \mathbb{E}_{q_t}[\log q_t(\theta)]$ (jordan1999introduction; beal2003variational). We further make the objective doubly stochastic: (1) the energy term is approximated by its stochastic estimate $\tilde{U}(\theta)$; (2) we use Monte Carlo variational inference (MCVI) (ranganath2014black; blundell2015weight), which estimates the lower-bound with samples $\theta_t^k \sim q_t$. These particles are obtained by simulating $K$ parallel Markov chains with the sampler, and the cross-chain loss is defined by accumulating the lower-bounds through time:

$$\mathcal{L}_{\text{cross}}(\phi) = -\sum_{t \in \mathcal{T}} \mathcal{L}(q_t), \qquad \mathcal{L}(q_t) \approx \frac{1}{K} \sum_{k=1}^{K} \left[ -\tilde{U}(\theta_t^k) - \log q_t(\theta_t^k) \right]. \qquad (8)$$

By minimizing this objective, we can improve the convergence of the sampler, especially at the early times of the Markov chain. The objective also takes the sampler bias into account because the two distributions will match when the KL-divergence is minimized.

In-chain loss

For very big neural networks, simulating multiple Markov chains is prohibitively expensive. This issue is mitigated by thinning, which collects a sample every $\tau$ steps (after burn-in). Effectively, after thinning, the samples are drawn from the averaged distribution $\bar{q}(\theta) = \frac{1}{|\mathcal{S}|} \sum_{t \in \mathcal{S}} q_t(\theta)$. The in-chain loss is therefore defined as the ELBO evaluated at the averaged distribution $\bar{q}$, which is then approximated by Monte Carlo with the samples $\{\theta_t\}_{t \in \mathcal{S}}$ obtained by thinning:

$$\mathcal{L}_{\text{in}}(\phi) = -\mathcal{L}(\bar{q}), \qquad \mathcal{L}(\bar{q}) \approx \frac{1}{|\mathcal{S}|} \sum_{t \in \mathcal{S}} \left[ -\tilde{U}(\theta_t) - \log \bar{q}(\theta_t) \right]. \qquad (9)$$
Gradient approximation

Gradient-based optimization requires the gradient $\nabla_\theta \log q_t(\theta)$ for the cross-chain loss and $\nabla_\theta \log \bar{q}(\theta)$ for the in-chain loss. Since the density of $q_t$ is intractable, we leverage the recently proposed Stein gradient estimator (li2017stein) for gradient approximation. Precisely, by the chain rule we have $\nabla_\phi \log q_t(\theta_t) = \nabla_\theta \log q_t(\theta_t)^\top \nabla_\phi \theta_t$. Denote by $G = \left( \nabla_\theta \log q_t(\theta^1), \ldots, \nabla_\theta \log q_t(\theta^K) \right)^\top$ the matrix collecting the gradients of $\log q_t$ at the sampled locations $\{\theta^k\}_{k=1}^K$. The recipe first constructs a kernel matrix $K$ with $K_{ij} = \mathcal{K}(\theta^i, \theta^j)$, then computes an estimate of the $G$ matrix by $\hat{G} = -(K + \eta I)^{-1} \langle \nabla, K \rangle$, where $\langle \nabla, K \rangle_{ij} = \sum_{k=1}^{K} \partial \mathcal{K}(\theta^k, \theta^i) / \partial \theta_j^k$. In our experiments we use RBF kernels, for which the gradient estimator has a simple analytic form that can be computed efficiently in $\mathcal{O}(K^3 + K^2 D)$ time (usually $K \ll D$).
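A NumPy sketch of this estimator with an RBF kernel (our illustration of the recipe; the median-heuristic bandwidth and the value of the regularizer η are assumptions):

```python
import numpy as np

def stein_gradient_estimator(X, eta=1e-3):
    """Estimate G[k, :] = grad_x log q(x) at the samples X (K x D) via the
    Stein gradient estimator with an RBF kernel:
        G_hat = -(K + eta * I)^{-1} <grad, K>.
    """
    K_num, D = X.shape
    sq_dist = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    sigma2 = np.median(sq_dist) + 1e-12        # median heuristic for the bandwidth
    K = np.exp(-sq_dist / (2 * sigma2))        # K_ij = k(x_i, x_j)
    # <grad, K>_ij = sum_k d k(x_k, x_i) / d x_{k,j}; for the RBF kernel this is
    # (x_{i,j} * sum_k K_ki - (K X)_ij) / sigma2.
    grad_K = (X * K.sum(0)[:, None] - K @ X) / sigma2
    return -np.linalg.solve(K + eta * np.eye(K_num), grad_K)

# Sanity check on a standard Gaussian, where grad log q(x) = -x:
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
G = stein_gradient_estimator(X)
print(np.mean(np.abs(G - (-X))))   # average error is small: estimator roughly recovers -X
```

When training the sampler, the rows of $\hat{G}$ (treated as constants) stand in for $\nabla_\theta \log q_t(\theta_t^k)$ in (8) and (9), and back-propagating the resulting surrogate through the unrolled chain yields the gradient with respect to the meta-parameters $\phi$.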

4 Related work

Since the development of stochastic gradient Langevin dynamics (SGLD, welling2011bayesian), SG-MCMC has become increasingly popular for sampling from posterior densities of big models conditioned on big data. In detail, chen2014stochastic scaled up HMC with stochastic gradients, ding2014bayesian further augmented the state space with a temperature auxiliary variable and performed sampling in the joint space, and SpringenbergKFH16 improved robustness through scale adaptation. SG-MCMC extensions of Riemannian Langevin dynamics and HMC (girolami2011riemann) have also been proposed (patterson2013stochastic; ma2015complete). Our proposed sampler architecture further generalizes SG-Riemannian-HMC as it decouples the design of the $D(z)$ and $Q(z)$ matrices, and the detailed functional forms of these two matrices are also learned from data.

Our approach is closely related to the recent line of work on learning optimization algorithms. Specifically, andrychowicz2016learning trained a recurrent neural network (RNN) based optimizer that transfers to similar tasks with supervised learning. Later, chen2017learning generalized this approach to Bayesian optimization (brochu2010tutorial; snoek2012practical), which is gradient-free. We do not use an RNN in our approach, as it cannot be represented within the framework of ma2015complete; we leave the combination with learnable RNN proposals to future work. Also, li2017stein presented an initial attempt to meta-learn an approximate inference algorithm, which simply combines the stochastic gradient and the Gaussian noise with a neural network. Thus the stationary distribution of that sampler (if it exists) is only an approximation to the exact posterior. On the other hand, the proposed sampler (with $D_f \geq 0$) is guaranteed to be correct by the complete framework of ma2015complete. Very recently, wu2018understanding discussed how short-horizon meta-objectives for learning optimizers can cause serious issues for long-time generalization. We found this bias to be less severe in our approach, again because the learned sampler is provably correct.

Recent research also considered improving HMC with a trainable transition kernel. salimans2015markov improved upon vanilla HMC by introducing a trainable re-sampling distribution for the momentum. song2017nice parameterized the HMC transition kernel with a trainable invertible transformation called non-linear independent components estimation (NICE) (dinh2014nice), and learned this operator with Wasserstein adversarial training (arjovsky2017wasserstein). levy2017generalizing generalized HMC by augmenting the state space with a binary direction variable, and parameterized the transition kernel with a non-volume preserving invertible transformation inspired by real-valued non-volume preserving (RealNVP) flows (dinh2016density); the sampler is then trained by maximizing the expected squared jump distance (pasarica2010adaptively). We do not explore the adversarial training idea in this paper, since these techniques become less reliable for very high dimensional distributions. Moreover, the jump distance does not explicitly take the sampling bias and convergence speed into account. More importantly, the purpose of these approaches is to directly improve the HMC-like sampler on the target distribution, and with the NICE/RealNVP parametrization it is difficult to generalize the sampler to densities of different dimensions. In contrast, our goal is to learn an SG-MCMC sampler that can later be transferred to sample from different Bayesian neural network posterior distributions, which will typically have different dimensionality and include tens of thousands of random variables.

5 Experiments

We evaluate the meta-learned SG-MCMC sampler, which is referred to as NNSGHMC or the meta sampler in the following. Detailed test set-ups are reported in the appendix. Code is available at https://github.com/WenboGong/MetaSGMCMC.

5.1 Synthetic example: sampling Gaussian random variables with noisy gradients

We first consider sampling Gaussian variables to demonstrate the fast convergence and low bias of the meta sampler. To mimic stochastic gradient settings, we manually inject Gaussian noise with unit variance into the gradient, as suggested by chen2014stochastic. The training density is a 10D Gaussian with randomly generated diagonal covariance, and the test density is a 20D Gaussian. For evaluation, we simulate parallel chains for a fixed number of steps. Then, following ma2015complete, the sampler's bias is measured by the KL divergence from the empirical Gaussian estimate to the ground truth. Results are visualized on the left panel of Figure 1, showing that the meta sampler both converges much faster and achieves lower bias than SGHMC. The effective sample sizes² for SGHMC and NNSGHMC are 22 and 59, respectively, again indicating the better efficiency of the meta sampler. For illustration purposes, we also plot in the other two panels the trajectory of a particle obtained by simulating NNSGHMC (middle) and SGHMC (right) on a 2D Gaussian for a fixed amount of time. This confirms that the meta sampler explores more efficiently and is less affected by the injected noise.

²The implementation follows the ESS function in the BEAST package http://beast.community.

Figure 1: (Left) Sampler’s bias measured by KL. (Middle) NNSGHMC trajectory plot on a 2D-Gaussian with manually injected gradient noise. (Right) SGHMC plot for the same settings.
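The bias metric above can be computed as the closed-form KL divergence between the moment-matched empirical Gaussian and the ground truth; a small sketch (our code, using the standard formula):

```python
import numpy as np

def gaussian_kl(mu0, cov0, mu1, cov1):
    """KL[N(mu0, cov0) || N(mu1, cov1)] in closed form."""
    d = mu0.shape[0]
    cov1_inv = np.linalg.inv(cov1)
    diff = mu1 - mu0
    return 0.5 * (np.trace(cov1_inv @ cov0) + diff @ cov1_inv @ diff
                  - d + np.log(np.linalg.det(cov1) / np.linalg.det(cov0)))

# Empirical Gaussian estimate from samples (S x D) vs. the ground truth:
rng = np.random.default_rng(0)
truth_mu, truth_cov = np.zeros(2), np.diag([1.0, 2.0])
samples = rng.multivariate_normal(truth_mu, truth_cov, size=5000)
bias = gaussian_kl(np.mean(samples, 0), np.cov(samples.T), truth_mu, truth_cov)
print(bias)   # close to zero for unbiased samples
```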

5.2 Bayesian feedforward neural network

Next we consider Bayesian neural network classification on MNIST data with three generalization tests: network architecture generalization (NT), activation function generalization (AF) and dataset generalization (Data). In all tests the sampler is trained with a 1-hidden-layer MLP (20 units, ReLU activations) as the underlying model for the target distribution. We also report long-time horizon generalization results, meaning that the number of simulation steps at test time is much longer than during training (cf. andrychowicz2016learning). Algorithms in comparison include SGLD (welling2011bayesian), SGHMC (chen2014stochastic) and preconditioned SGLD (PSGLD, li2016preconditioned). Note that PSGLD uses RMSprop-like preconditioning techniques (tieleman2012rmsprop) that require moving average estimates of the gradient's second moments. Therefore the underlying dynamics of PSGLD cannot be represented within our framework (see Section 4). Thus we mainly focus on comparisons with SGLD and SGHMC, and leave the PSGLD results as a reference. The discretization step-sizes for the samplers are tuned on the validation dataset for each task.

Architecture generalization (NT)

In this test we use the trained sampler to draw samples from the posterior distribution of a 2-hidden-layer MLP with 40 units per layer and ReLU activations. Figure 2 shows learning curves of the test error and the negative test log-likelihood (NLL) for 100 epochs, and the final performance is reported in Table 1. Overall, NNSGHMC achieves the fastest convergence even when compared with PSGLD, and it has the lowest test error compared to SGLD and SGHMC. NNSGHMC's final test NLL is on par with SGLD and slightly worse than PSGLD, but it is still better than SGHMC.

Figure 2: Learning curves on test error (top) and negative test LL (bottom).
Figure 3: (Left) The contour plot of $f_{\phi_Q}$ as a function of the energy and the momentum. (Middle) The contour plot of $f_{\phi_D}$ for input dimensions 1 and 2 (energy and momentum) with fixed energy gradient. (Right) The same plot for input dimensions 2 and 3 (momentum and energy gradient) with fixed energy.
Sampler | NT Err. | AF Err. | Data Err. | NT NLL | AF NLL | Data NLL
NNSGHMC | 98.36% | 97.72% | 98.62% | 640 | 875 | 230
SGHMC | 98.21% | 97.72% | 98.52% | 705 | 929 | 246
SGLD | 98.27% | 97.62% | 98.54% | 631 | 905 | 232
PSGLD | 98.31% | 97.67% | 98.60% | 610 | 975 | 224
Table 1: The final performance for the samplers, averaged over 10 independent runs.
Activation function generalization (AF)

Next we replace the test network's activation functions with the sigmoid function and re-run the same test as before. Again, the results in Figure 2 and Table 1 show that NNSGHMC converges faster than the others in both test error and NLL. It also achieves the best NLL results among all samplers, and the same test error as SGHMC.

Dataset generalization (Data)

In this test we split MNIST into a training task (classifying digits 0-4) and a test task (classifying digits 5-9). We train the meta sampler on the training task using the small MLP as before, and evaluate the learned sampler on the test task with the larger MLP. From the plots, we see that NNSGHMC, although a bit slower at the start, catches up quickly and proceeds to a lower error. The differences between the samplers' NLL results are marginal, and NNSGHMC is on par with PSGLD.

Learned strategies

For better intuition, we visualize in Figure 3 the contours of $f_{\phi_Q}$ and $f_{\phi_D}$ as functions of their inputs. Recall that the curl matrix is determined by $f_{\phi_Q}$. From the left panel, $f_{\phi_Q}$ has learned a nearly linear strategy w.r.t. the energy, with small variations w.r.t. the momentum. This enables fast traversal through low density (high energy) regions and better exploration in high density (low energy) areas.

The strategy learned for the diffusion matrix is rather interesting. Recall that $D_f$ is parametrized by both $f_{\phi_Q}$ and $f_{\phi_D}$ (eq. (4)). Since Figure 3 (left) indicates that $f_{\phi_Q}$ is large in high energy regions, the amount of friction there is already adequate, so $f_{\phi_D}$ tends to reduce its output to maintain momentum (see the middle plot). By contrast, in low energy regions $f_{\phi_D}$ increases to add friction in order to prevent divergence. The right panel visualizes the interaction between the momentum and the mean gradient at a fixed energy level. It indicates that the meta sampler has learned a strategy that prevents overshooting by producing large friction: indeed, $f_{\phi_D}$ returns large values when the signs of the momentum and the gradient differ.

5.3 Bayesian recurrent neural networks

Lastly we consider a more challenging setup: sequence modeling with Bayesian RNNs. Here a single datum is a sequence $x_{1:T}$, and the log-likelihood is defined as $\log p(x_{1:T} | \theta) = \sum_{t=1}^{T} \log p(x_t | x_{1:t-1}, \theta)$, with each of the conditional densities produced by a gated recurrent unit (GRU) network (cho2014learning). We consider four polyphonic music datasets for this task: Piano-midi (Piano) as training data, and Nottingham (Nott), MuseData (Muse) and JSB chorales (JSB) for evaluation. The meta sampler is trained on a small GRU with 100 hidden states. At test time we follow chen2016bridging for the step-size setting. We found that SGLD significantly under-performs, so we instead report the performance of two optimizers, Adam (kingma2014adam) and Santa, taken from chen2016bridging. Again, these two optimizers use moving average schemes that are out of the scope of our framework, so we mainly compare the meta sampler with SGHMC and leave the others as references.

Method | Piano | Muse | Nott | JSB
NNSGHMC | 7.66 | 7.27 | 3.37 | 8.49
SGHMC | 7.65 | 7.33 | 3.35 | 8.40
PSGLD | 7.67 | 7.48 | 3.28 | 8.42
Santa | 7.6 | 7.2 | 3.39 | 8.46
Adam | 8 | 7.56 | 3.7 | 8.51
Figure 9: Test NLL learning curves (with zoom-in for the sampling methods) and the best performances. Santa and Adam results are from chen2016bridging.

The meta sampler is tested on the four datasets using a 200-unit GRU. For Piano this corresponds to architecture generalization only, and from Figure 9 we see that the meta sampler converges faster than SGHMC and at a similar speed to Santa in the early stages. All samplers achieve best results close to Santa on Piano. The meta sampler successfully generalizes to the other three datasets, consistently demonstrating faster convergence than SGHMC, and better final performance on Muse. Interestingly, the meta sampler's final results on Nott and JSB are slightly worse than those of the other samplers. Presumably these two datasets are very different from Muse and Piano, so their energy landscapes are less similar to the training density (see appendix). Specifically, JSB is a dataset with much shorter sequences; in fact SGHMC also exhibits over-fitting there, though less severely. Therefore, we further test the meta sampler on JSB without the offset in $Q_f$ to reduce the acceleration (denoted as NNSGHMC-s). Surprisingly, NNSGHMC-s converges at a similar speed to the original sampler, but with less over-fitting and a better final test NLL of 8.40.

6 Conclusions and future work

We have presented a meta-learning algorithm that can learn an SG-MCMC sampler on simple tasks, such that the learned sampler generalizes to more complicated densities in high dimensions. Experiments on both Bayesian MLPs and Bayesian RNNs confirmed strong generalization of the trained sampler to long time horizons as well as across datasets and network architectures. Future work will focus on better designs for both the sampler and the meta-learning procedure. For the former, temperature variable augmentation as well as moving average estimation will be explored. For the latter, better loss functions will be proposed for faster training, e.g. by reducing the number of unrolled steps of the sampler during training. Finally, the automated design of generic MCMC algorithms that might not be derived from continuous Markov processes remains an open challenge.

Acknowledgements

We thank Shixiang Gu, Mark Rowland and Cheng Zhang for comments on the manuscript. WG is supported by the CSC-Cambridge Trust Scholarship.

References

Appendix A Finite difference approximation for the Gamma vector

The main computational burden is the gradient computation required by the $\Gamma(z)$ vector. From the parametrization of the $D$ and $Q$ matrices in (4), for $z = (\theta, p)$ we have $\Gamma(z) = (\Gamma_\theta(z), \Gamma_p(z))$. For the first term $\Gamma_\theta$, we have

$$\Gamma_\theta(z)_i = -\frac{\partial Q_f(\theta, p)_{ii}}{\partial p_i}. \qquad (10)$$

Due to the two-stage update of the Euler integrator, at time $t$ we have stored $Q_f(\theta_t, p_t)$ and $Q_f(\theta_{t-1}, p_{t-1})$. Thus a proper finite difference method would require $Q_f(\theta_t, p_{t-1})$, which is not exactly the history from the previous time step. Therefore we further approximate it using a delayed estimate:

$$\frac{\partial Q_f(\theta_t, p_t)_{ii}}{\partial p_i} \approx \frac{Q_f(\theta_t, p_t)_{ii} - Q_f(\theta_{t-1}, p_{t-1})_{ii}}{p_{t,i} - p_{t-1,i}}. \qquad (11)$$

Similarly, the term $\Gamma_p$ expands as

$$\Gamma_p(z)_i = \frac{\partial Q_f(\theta, p)_{ii}}{\partial \theta_i} + \frac{\partial D_f(\theta, p)_{ii}}{\partial p_i}. \qquad (12)$$

We further approximate $\partial Q_f(\theta_t, p_t)_{ii} / \partial \theta_i$ by the following:

$$\frac{\partial Q_f(\theta_t, p_t)_{ii}}{\partial \theta_i} \approx \frac{Q_f(\theta_t, p_t)_{ii} - Q_f(\theta_{t-1}, p_{t-1})_{ii}}{\theta_{t,i} - \theta_{t-1,i}}. \qquad (13)$$

This only requires the storage of the previous $Q_f$ matrix. However, $\partial D_f(\theta_t, p_t)_{ii} / \partial p_i$ requires one further forward pass to obtain $D_f(\theta_t, p_{t-1})$; thus, we have

$$\frac{\partial D_f(\theta_t, p_t)_{ii}}{\partial p_i} \approx \frac{D_f(\theta_t, p_t)_{ii} - D_f(\theta_t, p_{t-1})_{ii}}{p_{t,i} - p_{t-1,i}}. \qquad (14)$$

Therefore the proposed finite difference method only requires one more forward pass, to compute $D_f(\theta_t, p_{t-1})$, and saves three back-propagations. As back-propagation is typically more expensive than a forward pass, our approach reduces running time drastically, especially when the sampler is applied to large neural networks.
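A NumPy sketch of the resulting scheme (our illustration; Q_f and D_f are placeholder callables returning the matrix diagonals, and the small denominator guard eps is an assumption not present in the equations):

```python
import numpy as np

def gamma_finite_difference(Qf, Df, theta_t, p_t, theta_prev, p_prev,
                            Qf_prev, eps=1e-8):
    """Approximate Gamma_theta and Gamma_p with the delayed finite differences
    (11), (13) and (14). Qf, Df return the diagonals of Q_f, D_f; Qf_prev stores
    the previous Q_f diagonal, so only one extra forward pass (for D_f at
    (theta_t, p_prev)) is needed."""
    Qf_t = Qf(theta_t, p_t)
    dQf_dp = (Qf_t - Qf_prev) / (p_t - p_prev + eps)                  # Eq. (11)
    dQf_dtheta = (Qf_t - Qf_prev) / (theta_t - theta_prev + eps)      # Eq. (13)
    dDf_dp = (Df(theta_t, p_t) - Df(theta_t, p_prev)) / (p_t - p_prev + eps)  # Eq. (14)
    gamma_theta = -dQf_dp
    gamma_p = dQf_dtheta + dDf_dp
    return gamma_theta, gamma_p

# Toy check with simple element-wise choices of Q_f and D_f:
Qf = lambda th, p: th * p          # dQ/dtheta = p, dQ/dp = theta
Df = lambda th, p: p ** 2          # dD/dp = 2p
th_t, p_t = np.array([1.0, 2.0]), np.array([0.5, -0.5])
th_prev, p_prev = th_t - 0.01, p_t - 0.01
g_th, g_p = gamma_finite_difference(Qf, Df, th_t, p_t, th_prev, p_prev,
                                    Qf(th_prev, p_prev))
print(g_th, g_p)
```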

Time complexity figures

Every SG-MCMC method (including the meta sampler) requires the stochastic gradient $\nabla_\theta \tilde{U}(\theta)$. The main additional burden of the meta sampler is the forward pass and back-propagation through the $Q_f$ and $D_f$ networks, where the latter has been replaced by the proposed finite difference scheme. The time complexity of both the forward pass and the finite difference is linear in the number of hidden units of the meta sampler's neural networks. Parallel computation with GPUs improves real-time speed; indeed, in our MNIST experiment the meta sampler spends roughly 1.5x the time of SGHMC.

During meta sampler training, the Stein gradient estimator requires a kernel matrix inversion, which costs $\mathcal{O}(K^3)$ for cross-chain training with $K$ parallel chains. In practice we only run a few parallel Markov chains, thus this does not incur a huge computational cost. For the in-chain loss, the computation can also be reduced with proper thinning schemes.

Appendix B Training details

We visualize the unrolled computation scheme on the left panel of Figure 12. We apply truncated back-propagation through time (BPTT) to train the sampler. Specifically, we manually stop the gradient flow through the inputs of the $Q_f$ and $D_f$ networks to avoid computing higher-order gradients.

(a)
(b)
Figure 12: (Left) The unrolled scheme of the meta sampler updates. Stop-gradient operations are applied to the dashed arrows. (Right) A visualization of cross-chain and in-chain training. The grey area represents samples across multiple chains, and we compute the cross-chain loss every 5 time steps. The purple area indicates the samples taken across time with sub-sampled chains 1 and 3. In this visualization the initial 15 samples are discarded for burn-in, and the thinning length is 1 (effectively no thinning).

We also illustrate cross-chain and in-chain training on the right panel of Figure 12. Cross-chain training encourages both fast convergence and low bias, provided that the samples are taken from parallel chains. On the other hand, in-chain training encourages sample diversity inside a chain. In practice, we might consider thinning the chains when performing in-chain training. Empirically this improves the Stein gradient estimator's accuracy, as the samples are more spread out. Computationally, this also prevents inverting big matrices for the Stein gradient estimator, and reduces the number of back-propagation operations. Another trick we apply is parallel chain sub-sampling: if all the chains are used, then there is less encouragement of single-chain mixing, since the parallel chain samples can already be diverse enough to give a reasonable gradient estimate.
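The stop-gradient trick amounts to detaching the state-dependent inputs of the two networks from the autodiff graph. A minimal PyTorch sketch (ours; the step is simplified and omits the offset, the α-scaling and the Γ terms):

```python
import torch

def meta_sampler_step(theta, p, f_phi_Q, f_phi_D, U, eta):
    """One simplified sampler step for meta-training; theta is assumed to
    require gradients. The f_phi_Q / f_phi_D inputs are detached, so
    back-propagating the meta-loss through the unrolled chain needs no
    higher-order gradient terms."""
    energy = U(theta)
    grad_U = torch.autograd.grad(energy, theta, create_graph=True)[0]
    e = (energy.detach() * torch.ones_like(p)).unsqueeze(-1)
    inp_Q = torch.cat([e, p.detach().unsqueeze(-1)], -1)       # detached inputs
    Qf = f_phi_Q(inp_Q).squeeze(-1)
    inp_D = torch.cat([inp_Q, grad_U.detach().unsqueeze(-1)], -1)
    Df = f_phi_D(inp_D).squeeze(-1)
    noise = torch.sqrt(2 * eta * Df.detach()) * torch.randn_like(p)
    theta_next = theta + eta * Qf * p
    p_next = p - eta * Qf * grad_U - eta * Df * p + noise
    return theta_next, p_next
```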

Appendix C Input pre-processing

One potential challenge is that for different tasks and problem dimensionalities, the energy function, momentum and energy gradient can have very different scales and magnitudes. This affects the meta sampler's generalization: for example, if the training and test densities have completely different energy scales, then the meta sampler is likely to produce wrong strategies. This is especially the case when the meta sampler is transferred to much bigger networks or to very different datasets.

To mitigate this issue, we propose to pre-process the inputs to both the $f_{\phi_Q}$ and $f_{\phi_D}$ networks so that they have similar scales to those in the training task. Recall that the energy function is $U(\theta) = -\sum_{n=1}^N \log p(y_n | x_n, \theta) - \log p(\theta)$, where the prior is often an isotropic Gaussian distribution. Thus the energy function scales roughly linearly w.r.t. both the dimensionality of $\theta$ and the total number of observations $N$. Often the energy function is further approximated using mini-batches of data points. Putting these together, we propose pre-processing the energy as

$$\tilde{U}(\theta) \leftarrow \tilde{U}(\theta) \times \frac{D_{\text{train}}}{D_{\text{test}}} \times \frac{N_{\text{train}}}{N_{\text{test}}}, \qquad (15)$$

where $D_{\text{train}}$ and $D_{\text{test}}$ are the dimensionality of $\theta$ in the training task and the test task, respectively. Importantly, for RNNs $N$ represents the total sequence length, namely $N = \sum_{n=1}^{N_{\text{seq}}} T_n$, where $N_{\text{seq}}$ is the total number of sequences and $T_n$ is the sequence length of datum $n$; we define $N_{\text{train}}$ and $N_{\text{test}}$ accordingly. The momentum and energy gradient magnitudes are estimated by simulating a randomly initialized meta sampler for a small number of iterations. With these statistics we normalize both the momentum and the energy gradient to have roughly zero mean and unit variance.
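A small sketch of this rescaling (our code; the exact form of (15) is reconstructed from the description above, and the numbers in the usage example are made up):

```python
import numpy as np

def preprocess_energy(U_tilde, D_train, D_test, N_train, N_test):
    """Rescale a (stochastic) test-task energy estimate to the training-task
    scale, following Eq. (15): the energy grows roughly linearly in both the
    dimensionality of theta and the number of observations."""
    return U_tilde * (D_train / D_test) * (N_train / N_test)

# e.g. a test task with 10x more weights and 5x more data:
print(preprocess_energy(1.2e4, D_train=1000, D_test=10000,
                        N_train=2000, N_test=10000))
```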

Appendix D Experiment Setup

D.1 Toy Example

We train our meta sampler on a 10D uncorrelated Gaussian with a randomly generated diagonal covariance matrix. We do not use the offset or any additional friction. The noise estimation matrix is set to the same value for both the meta sampler and SGHMC. To mimic stochastic gradients, we manually inject Gaussian noise with zero mean and unit variance into the energy gradient. The functions $f_{\phi_Q}$ and $f_{\phi_D}$ are represented by 1-hidden-layer MLPs with 40 hidden units. For the training task, the meta sampler step-size is 0.01 and the initial positions are drawn randomly. We train our sampler for 100 epochs, where each epoch consists of 4 x 100 steps. Every 100 steps, we update $f_{\phi_Q}$ and $f_{\phi_D}$ using the Adam optimizer with learning rate 0.0005, then continue running the updated sampler from the last position and momentum until the 4 sub-epochs are finished, after which we re-initialize the momentum and position. We use both the cross-chain and in-chain losses. The Stein gradient estimator uses an RBF kernel with bandwidth chosen to be 0.5 times the median-heuristic estimate. We unroll the Markov chain for 20 steps before we manually stop the gradient. For cross-chain training, we take samples across chains every 2 time steps. For in-chain training, we discard the initial 50 points for burn-in and sub-sample the chains with batch size 5, thinning the samples every 3 steps. For both training and evaluation, we run 50 parallel Markov chains.

The test task is to draw samples from a 20D correlated Gaussian with a randomly generated covariance matrix. The step-size is 0.025 for both the meta sampler and SGHMC. To stabilize the meta sampler we also clamp the output values of the two networks within a fixed range. The friction matrix for SGHMC is set to a fixed constant.

D.2 Bayesian MLP MNIST

In the MNIST experiment, we apply input pre-processing to the energy function as in (15) and scale the energy gradient by 70. We also apply a scale-up factor of 50 to account for the sum of stochastic noise. The offset in $Q_f$ is selected as a function of the per-batch learning rate, as suggested by chen2014stochastic; the additional offset and noise estimation terms are turned off. We run 20 parallel chains for both training and evaluation. We only adopt cross-chain training, with samples thinned every 5 time steps. We also use the finite difference technique during evaluation to speed up computations.

D.2.1 Architecture Generalization

We train the meta sampler on a smaller BNN with architecture 784-20-10 and ReLU activations, then test it on a larger one with architecture 784-40-40-10. In both cases the batch size is 500, following chen2014stochastic. Both $f_{\phi_Q}$ and $f_{\phi_D}$ are parameterized by 1-hidden-layer MLPs with 10 units. The per-batch learning rate is 0.007. We train the sampler for 100 epochs, each consisting of 7 sub-epochs; for each sub-epoch, we run the sampler for 100 steps. We re-initialize $\theta$ and the momentum after each epoch. To stabilize the meta sampler in evaluation, we first run it with a small per-batch learning rate for 3 data epochs while clamping the network outputs. Afterwards, we increase the per-batch learning rate to its final value, keeping the outputs clipped. The learning rates for SGHMC, SGLD and PSGLD are fixed throughout; all of these step-sizes are tuned on MNIST validation data.

D.2.2 Activation function generalization

We modify the test network’s activation function to sigmoid. We use almost the same settings as in network generalization tests, except that the per-batch learning rates are tuned again on validation data. For the meta sampler and SGHMC, they are and . For SGLD and PSGLD, they are and .

D.3 Dataset Generalization

We train the meta sampler on a ReLU network with architecture 784-20-5 to classify digits 0-4, and test the sampler on a ReLU network with architecture 784-40-40-5 to classify digits 5-9. The settings are mostly the same as in the network architecture generalization test, for both training and evaluation. One exception is again the per-batch learning rate for PSGLD, which is re-tuned. Note that even though we use the same per-batch learning rate as before, the discretization step-size is now different due to the smaller training dataset, and is thus adjusted automatically.

D.4 Bayesian RNN

The Piano data is selected as the training task, and is further split into training, validation and test subsets. We use batch size 1, meaning that the energy and the gradient are estimated on a single sequence. The meta sampler uses similar neural network architectures to those in the MNIST tests. The training and evaluation per-batch learning rate for all samplers is set to 0.001, following chen2016bridging. We train the meta sampler for 40 epochs of 7 sub-epochs each, using only the cross-chain loss; each sub-epoch consists of 70 iterations. We scale the network output by 20 and set the scaling constant as a function of the per-batch learning rate, defined in the same way as before. We use a zero offset during training. We apply input pre-processing for both the energy and the energy gradient. To prevent divergence of the meta sampler at the early training stage, we also set the constant $\alpha$ of $D_f$ to a small value. For dataset generalization, we tune the offset value based on the Piano validation set and transfer the tuned setting to the other three datasets. For the Piano architecture generalization, we do not tune any hyper-parameters and use exactly the same settings as in training. Exact gradients are used in the RNN experiments instead of finite differences.

Appendix E RNN dataset description

We list some data statistics in Table 2, which roughly indicate the similarity between the datasets.

 | Piano | Muse | Nott | JSB
Size: train | 87 | 524 | 694 | 229
Size: test | 25 | 124 | 170 | 77
Avg. Time: train | 872 | 467 | 254 | 60
Avg. Time: test | 761 | 518 | 261 | 61
Energy scale: train | | | |
Table 2: Basic statistics for the four RNN datasets; bold figures represent large differences compared to the others. Size is the number of data points, Avg. Time is the averaged sequence length, and Energy scale is the rough scale of the training NLL when the sampler converges.

The Piano dataset is the smallest in terms of the number of data points; however, its average sequence length is the largest. The Muse dataset is similar to Piano in sequence length and energy scale but much larger in terms of the number of data points. On the other hand, the Nott dataset has a very different energy scale compared to the other three. This potentially makes generalization much harder due to the inconsistent energy scale fed into $f_{\phi_Q}$ and $f_{\phi_D}$. For JSB, we notice a very short average sequence length, so the GRU model is more likely to over-fit. Indeed, some algorithms exhibit significant over-fitting on the JSB dataset compared to the other datasets (Santa is particularly severely affected).

Appendix F Additional Plots

F.1 Short run comparison

We also run the samplers using the same settings as in the MNIST experiments for a short period of time (500 iterations), and additionally compare with optimization methods including momentum SGD (SGD-M) and Adam. We use the same per-batch learning rates for SGD-M and SGHMC as in the MNIST experiment. For Adam, we use 0.002 for the ReLU network and 0.01 for the sigmoid network.

Figure 13: We only test network architecture generalization and activation function generalization. The upper row shows the test error curves and the lower row shows the negative test LL curves.

The results are shown in Figure 13. The meta sampler and Adam achieve the fastest convergence. This again confirms the fast convergence of the meta sampler, especially at the initial stages. We also provide additional contour plots (Figure 14) to further demonstrate the learned strategy for reference.

Figure 14: The contour plots of $f_{\phi_D}$ for other input values.