1 Introduction
In contrast to optimization approaches in machine learning that derive a single estimate for the weights of a neural network, Bayesian inference aims at deriving a posterior distribution over the weights of the network. This makes it possible to sample model instances from the distribution over the weights and offers unique advantages. Multiple model instances can be aggregated to obtain robust uncertainty estimates over the network's predictions; uncertainty estimates are crucial in domains such as medical diagnosis and autonomous driving, where following a model's incorrect predictions can result in catastrophe kendall2017uncertainties . Sampling a distribution, as opposed to optimizing a loss, is less prone to overfitting, and additional training does not decrease test performance. Bayesian inference can also be applied to differential privacy, where each individual sample has increased privacy guarantees privacy , and to reinforcement learning, where one can leverage model uncertainty to balance between exploration and exploitation thompsonsampling .
Markov Chain Monte Carlo (MCMC) methods are a standard class of methods for generating samples from the posterior distribution over model parameters. These methods are seldom applied in deep learning because they have traditionally failed to scale well with large datasets and many parameters bigmcmc . Stochastic Gradient MCMC (SGMCMC) methods have fared somewhat better in scaling to large datasets due to their close relationship to stochastic optimization methods. For example, the SGLD sampler sgld amounts to performing stochastic gradient descent while adding Gaussian noise to each parameter update. Despite these improvements, samplers like SGLD are only guaranteed to converge to the correct distribution when the step size is annealed to zero; additional control variates have been developed to mitigate this to some extent sgfs ; bayesthermostat .
The objective of this work is to make Bayesian inference practical for deep learning by making SGMCMC methods scale to large models and datasets. The contributions described in this work fall into three categories. First, we propose the Adaptive Thermostat Monte Carlo (ATMC) sampler, which offers improved convergence and stability. ATMC dynamically adjusts the amount of momentum and noise applied to each model parameter. Second, we improve an existing second-order numerical integration method that is needed for the ATMC sampler. Third, since ATMC, like other SGMCMC samplers, is not directly compatible with stochastic regularization methods such as batch normalization (BatchNorm) and Dropout (see Sect. 4), we construct the ResNet++ network by taking the original ResNet architecture resnet , removing BatchNorm and introducing SELUs selu , Fixup initialization fixup and weight normalization weightnorm . We design ResNet++ so that its parameters are easy to sample from and the gradients are well-behaved even in the absence of BatchNorm.
We show that the ATMC sampler is able to outperform optimization methods in terms of accuracy, log-likelihood and uncertainty calibration in the following settings. First, when using the ResNet++ architecture for both the ATMC sampler and the optimization baseline, the ATMC sampler significantly outperforms the optimization baseline on both Cifar10 and ImageNet. Second, when using the standard ResNet for the optimization baseline and ResNet++ for the ATMC sampler, multiple samples from ATMC that approximate the predictive posterior of the model are still able to outperform the optimization baseline on ImageNet. Using the ResNet++ architecture, the ATMC sampler reduces the need for hyperparameter tuning since it does not require early stopping, does not use stochastic regularization, is not prone to overfitting on the training data and avoids a carefully tuned learning rate decay schedule.
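To make the description of SGLD above concrete, the following is a minimal sketch of an SGLD chain on a toy problem. It is not the ATMC sampler. The toy setup (inferring the mean of a unit-variance Gaussian under a flat prior), the function name `sgld_sample_mean`, the constant step size, and all hyperparameter values are assumptions chosen for illustration. With a constant step size the samples are biased, which is the limitation noted above.

```python
import math
import random


def sgld_sample_mean(data, batch_size=10, step_size=1e-3,
                     num_steps=20000, burn_in=2000, seed=0):
    """SGLD chain for the mean of a unit-variance Gaussian (toy, flat prior)."""
    rng = random.Random(seed)
    n = len(data)
    mu = 0.0
    samples = []
    for step in range(num_steps):
        batch = rng.sample(data, batch_size)
        # Unbiased minibatch estimate of the gradient of the negative log posterior.
        grad = (n / batch_size) * sum(mu - x for x in batch)
        # Gradient step plus Gaussian noise whose variance equals the step size.
        mu += -0.5 * step_size * grad + math.sqrt(step_size) * rng.gauss(0.0, 1.0)
        if step >= burn_in:
            samples.append(mu)
    return samples


data_rng = random.Random(1)
data = [data_rng.gauss(2.0, 1.0) for _ in range(100)]
samples = sgld_sample_mean(data)
posterior_mean = sum(samples) / len(samples)
data_mean = sum(data) / len(data)
```

In this toy problem the exact posterior is centered on the data mean, so the average of the collected samples should land close to `data_mean`.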
2 ATMC Sampler
In this section we define the Stochastic Differential Equation (SDE) that gives rise to the ATMC sampler described in Algorithm 1. A detailed background and framework for constructing SDEs that converge to a target distribution can be found in mcmcrecipe .
2.1 General form of the SDE
Our starting point for constructing the ATMC sampler is the framework of Stochastic Differential Equations. We are interested in SDEs that converge to a distribution $p(z) \propto \exp(-H(z))$ for which we can evaluate $\nabla H(z)$. Because only the gradient of $H(z)$ is required, it is sufficient to define an energy function $H(z)$ up to a constant. As a consequence, we can sample from the posterior distribution by only evaluating the energy function gradient $\nabla H(z)$. The general form of SDEs converging to $p(z)$ for which only the gradient of $H(z)$ is required is as follows mcmcrecipe :
(1) $\mathrm{d}z = -\left[D(z) + Q(z)\right] \nabla H(z)\,\mathrm{d}t + \Gamma(z)\,\mathrm{d}t + \sqrt{2 D(z)}\,\mathrm{d}W_t$
where $D(z)$ is a positive-definite matrix that determines the amount of noise, $Q(z)$ is a skew-symmetric matrix that mixes energy between variables, and $\Gamma_i(z) = \sum_j \frac{\partial}{\partial z_j}\left(D_{ij}(z) + Q_{ij}(z)\right)$ is a correction factor that compensates for dynamics that depend on the current state $z$. The ATMC sampler that we propose is an instance of (1) for specific definitions of $H(z)$, $D(z)$, and $Q(z)$.
2.2 Energy Function
We start by defining the energy function . The energy function for the model posterior is defined by the loss function . Because the dataset is generally large, we would like to only evaluate a minibatch loss . However, naively using a stochastic gradient in (1) will result in significant bias sghmc . Motivated by the Central Limit Theorem, the stochastic gradient is assumed to follow a Gaussian distribution where the covariance is additionally assumed to be diagonal and constant w.r.t. . The energy function for the ATMC sampler is defined as:
(2) 
where is the momentum, defines the momentum distribution, and is a control variate referred to as the temperature. Both and have the same dimensionality as . The distribution of the control variate depends on the amount of noise in the stochastic gradient estimate .
2.3 Noise robust dynamics
Next we define the dynamics and such that the SDE that results from (1) can be simulated without the need to evaluate :
(3) 
where is a nonnegative function that determines how the temperature affects the amount of noise added to the momentum update.
We first illustrate the resulting SDE by using a simpler Gaussian momentum distribution (where is a hyperparameter). We substitute the dynamics and defined in (3) and energy function defined in (2) into (1):
(4) 
where we use to replace the gradient of the loss with the minibatch estimate. The momentum is dampened by a friction term that depends on the choice of . The stochastic gradient noise does not show up in (4) due to the particular choice of energy function and dynamics . Note, however, that this analysis relies on the assumption that the covariance of the stochastic gradient noise is constant in , and that a single temperature variable per parameter can only correct for a diagonal covariance . We do not expect this assumption to hold in practice, and the approximation will therefore lead to bias in the samples. However, annealing the step size will reduce the error due to minibatching together with other sources of discretization error sgld .
2.4 Adaptive Noise Thermostat
Finally, we must choose a function which controls the amount of noise and momentum damping . Previous work uses the Nosé-Hoover thermostat that is defined by where is a constant determining the amount of noise added to the momentum update bayesthermostat . Although the Nosé-Hoover thermostat is able to correct the stochastic gradient noise , the correction comes at the cost of slower convergence because additional friction is applied as increases. Another drawback of the Nosé-Hoover thermostat is that it causes negative friction when . In the negative friction phase , previous gradient terms are amplified rather than dampened. Although this behavior is mathematically sound, we find that it can cause exploding momentum variables.
Our choice of is based on the idea that negative friction should not occur and convergence speed should not be reduced by the stochastic gradient noise. Based on this intuition, we define the ATMC sampler by . The ATMC sampler is best characterized by the various temperature stages. For the total amount of noise added to the momentum is and the friction coefficient . At this stage, the stochastic gradient noise is compensated for by adding less noise to the momentum update. If the dominant stage will be resulting in and zero noise being added to the momentum. Finally, when the friction coefficient and the noise added to the momentum is proportional to . Thus, the momentum always experiences a minimum amount of friction determined by the hyperparameter and the noise added to the momentum update is automatically adjusted based on the amount of noise present in the stochastic gradients.
2.5 Momentum energy function
Following rhmc , we generalize the momentum energy function to the symmetric hyperbolic distribution which is defined as follows rhmc :
(5) 
where and are hyperparameters. The Gaussian kinetic energy is a special case obtained by taking the limit . The magnitude of parameter updates is determined by the gradient of the momentum:
(6) 
Hence, the hyperbolic distribution results in relativistic momentum dynamics where the parameter updates are upper bounded by and the preconditioner depends on . The average update magnitude for . Consequently, the parameters and are interpretable hyperparameters controlling the average and maximum parameter update per step together with the step size .
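The bounded update behaviour described above can be illustrated with a small sketch. It assumes the per-parameter relativistic kinetic energy from the Relativistic Monte Carlo literature rhmc , $K(p) = m c^2 \sqrt{p^2 / (m^2 c^2) + 1}$, where the symbols $m$ and $c$ and the function name `relativistic_velocity` are assumptions made for this example and need not match the notation of (5) and (6).

```python
import math


def relativistic_velocity(p, m, c):
    """Gradient of the assumed kinetic energy K(p) = m c^2 sqrt(p^2 / (m^2 c^2) + 1)."""
    return p / (m * math.sqrt(p * p / (m * m * c * c) + 1.0))


m, c = 1.0, 0.5
# Small momenta behave like the Gaussian case (velocity close to p / m),
# while large momenta saturate just below the maximum speed c.
velocities = [relativistic_velocity(p, m, c) for p in (0.01, 1.0, 100.0)]

# For very large c the Gaussian kinetic energy is recovered: velocity close to p / m.
gaussian_limit = relativistic_velocity(3.0, 2.0, 1e6)
```

Under this assumed form, no single update can exceed `c` times the step size, regardless of how large the momentum becomes.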
The SDE we derive in (4) and integrate in Sec. 3 uses a Gaussian momentum energy function for clarity. Deriving ATMC with a different momentum distribution like the hyperbolic distribution amounts to substituting (2), (3), and the alternative momentum distribution into (1). For the hyperbolic distribution, the dynamic friction coefficient will also depend on . For the numerical integration of (4) with a hyperbolic momentum distribution we assume to be constant in .
3 Improved numerical integrator for MCMC samplers
In this section we construct the numerical integrator required to numerically approximate the ATMC sampler defined in (4). An efficient numerical integrator can be constructed by splitting the SDE into two terms:
(7) 
Hence, we obtain a linear ODE in part (A) that updates the parameters and the thermostats and a linear SDE in part (B) that updates the momentum . The operators that simulate these dynamics exactly for a time step are denoted and , respectively. Using the Strang splitting scheme yields a second order method splitintegrator :
(8) 
The first operator is given by
(9) 
The second operator is an instance of the Ornstein–Uhlenbeck process which can also be computed exactly as follows:
(10)  
(11) 
Previous work splitintegrator on higher-order integrators for samplers splits the SDE into three parts, where the third term is obtained by separating the friction term from the other terms in the momentum update . By integrating (10) exactly, the noise and gradient terms are directly affected by the friction. An exact momentum update provides additional robustness to large gradients because the temperature will increase in order to compensate for momentum updates that would lead to excessively large steps. Another advantage of a two-way split integrator is that the first and last steps in (8) can be fused together such that only a momentum update is performed per iteration. Algorithm 1 shows the pseudocode for the ATMC sampler with the split integrator defined in (9) and (10).
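The two-way splitting idea can be sketched on a simpler system than ATMC. The example below assumes plain underdamped Langevin dynamics with unit mass, a fixed friction coefficient, and no thermostat variable, so it is not the ATMC integrator itself. The ordering (half step on the parameter, exact Ornstein–Uhlenbeck momentum update with the gradient held fixed, half step on the parameter) and the name `strang_langevin` are assumptions for illustration.

```python
import math
import random


def strang_langevin(grad_u, theta, p, step_size, friction, num_steps, seed=0):
    """Strang-split integrator for d(theta) = p dt, dp = (-grad U - friction p) dt + sqrt(2 friction) dW."""
    rng = random.Random(seed)
    decay = math.exp(-friction * step_size)
    noise_scale = math.sqrt(1.0 - decay * decay)
    thetas = []
    for _ in range(num_steps):
        # Half step of the deterministic part: move the parameter along the momentum.
        theta += 0.5 * step_size * p
        # Exact solution of the linear momentum SDE with the gradient frozen.
        g = grad_u(theta)
        p = decay * p - (1.0 - decay) / friction * g + noise_scale * rng.gauss(0.0, 1.0)
        # Second half step of the deterministic part.
        theta += 0.5 * step_size * p
        thetas.append(theta)
    return thetas


# Target: standard Gaussian, U(theta) = theta^2 / 2, so grad U(theta) = theta.
chain = strang_langevin(lambda t: t, theta=0.0, p=0.0,
                        step_size=0.1, friction=1.0, num_steps=200000)
kept = chain[1000:]
mean = sum(kept) / len(kept)
variance = sum((t - mean) ** 2 for t in kept) / len(kept)
```

Because the friction and noise enter through the exact momentum update, the gradient contribution is damped within the same step, which mirrors the robustness argument made above.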
4 The ResNet++ Architecture
The generalization performance of large neural nets trained using optimization depends on stochastic regularization methods like Dropout dropout and BatchNorm batchnorm . These methods implicitly add noise to the model parameters dropoutvi ; batchnormvi and significantly boost training performance and generalization for image classifiers. These methods can be interpreted as a coarse approximation of Bayesian inference dropoutvi ; batchnormvi . However, a stochastic gradient sampler like ATMC already adds the necessary amount of noise, and combining it with BatchNorm or Dropout leads to underfitting. We thus define a BatchNorm-free version of ResNet called ResNet++ that includes SELUs selu , Fixup initialization fixup and weight normalization weightnorm (see Fig. 1). We use ATMC to fill the significant gap in performance due to the absence of BatchNorm in ResNet++.
4.1 SELU
We find the SELU activation to work well in BatchNorm-free networks. SELU forces the statistics of the activations towards zero mean and unit variance selu . The SELU activation function additionally has a nonzero gradient everywhere, which could improve the mixing of the sampler by providing a more informative gradient.
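Both properties can be checked numerically with a short sketch. The constants below are the standard SELU values from selu ; the function names and the Monte Carlo check on standard normal inputs are assumptions made for this illustration.

```python
import math
import random

# Standard SELU constants.
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805


def selu(x):
    """SELU activation: scaled identity for x > 0, scaled exponential otherwise."""
    return SELU_SCALE * (x if x > 0.0 else SELU_ALPHA * (math.exp(x) - 1.0))


def selu_grad(x):
    """Derivative of SELU; strictly positive for any finite input."""
    return SELU_SCALE * (1.0 if x > 0.0 else SELU_ALPHA * math.exp(x))


# Feed standard normal pre-activations through SELU and measure the output statistics.
rng = random.Random(0)
outputs = [selu(rng.gauss(0.0, 1.0)) for _ in range(200000)]
mean = sum(outputs) / len(outputs)
variance = sum((y - mean) ** 2 for y in outputs) / len(outputs)
```

For standard normal inputs the output statistics stay close to zero mean and unit variance, and the gradient remains positive on the negative side, unlike ReLU, where it is exactly zero.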
4.2 Fixup initialization
ResNets are known to scale well with depth resnet . However, the additive effect of the residual branch causes the magnitudes of the activations to increase with the number of residual connections. Fixup is a recently proposed initialization method that mitigates the exploding residual branch problem without using BatchNorm fixup . We use a simplified version of Fixup by initializing the scales of the final layer in each residual branch to a small constant.
4.3 Weight normalization
We use weight normalization weightnorm to separate the direction and scale of each linear feature vector
(12) $w = g \, \frac{v}{\lVert v \rVert}$
where $v$ is the direction vector and $g$ is the magnitude of a feature vector $w$. Weight normalization does not depend on batch statistics and is compatible with MCMC methods.
The scale of the direction vector does not affect the outputs of the model. It does however affect the effective step size wngrad . Therefore the prior on the direction vector is chosen such that it is forced to unit length
(13) 
The prior on the scales is problem specific and can for example be chosen to encode a preference for structurally sparse models.
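The role of the direction vector's scale can be seen in a small sketch of the weight normalization reparameterization from weightnorm , written here as $w = g\, v / \lVert v \rVert$. The symbol names and the helper functions `weight_norm` and `grad_wrt_direction` are assumptions for this example. The sketch shows that rescaling the direction vector leaves the weights unchanged while shrinking the gradient with respect to the direction vector, which is the effective step size effect mentioned above.

```python
import math


def weight_norm(v, g):
    """Weights w = g * v / ||v|| for direction vector v and scalar magnitude g."""
    norm = math.sqrt(sum(x * x for x in v))
    return [g * x / norm for x in v]


def grad_wrt_direction(v, g, grad_w):
    """Gradient of the loss w.r.t. v given the gradient w.r.t. w (chain rule)."""
    norm = math.sqrt(sum(x * x for x in v))
    dot = sum(x * gw for x, gw in zip(v, grad_w))
    # Project out the component of grad_w along v, then rescale by g / ||v||.
    return [(g / norm) * (gw - dot * x / (norm * norm)) for x, gw in zip(v, grad_w)]


v = [3.0, 4.0]
g = 2.0
w = weight_norm(v, g)
w_rescaled = weight_norm([10.0 * x for x in v], g)

grad_w = [1.0, 0.0]
grad_v = grad_wrt_direction(v, g, grad_w)
grad_v_rescaled = grad_wrt_direction([10.0 * x for x in v], g, grad_w)
```

Multiplying `v` by 10 leaves `w` unchanged but divides the gradient with respect to `v` by 10, which is why a prior that keeps the direction vector near unit length helps keep the effective step size fixed.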
5 Experiments
The experiments presented here aim to demonstrate that the ATMC sampler is competitive with a well-tuned optimization baseline for large-scale datasets and models. We use the TensorFlow official implementation of ResNet56 and ResNet50 on Cifar10 and ImageNet, respectively. We compare our ATMC sampler to an optimization baseline with and without BatchNorm. For the optimization baseline without BatchNorm we use the ResNet++ architecture as described in Sec. 4. For the baseline with BatchNorm we found standard ResNet with Xavier initialization and the ReLU nonlinearity to work better.
For the ATMC sampler we report both the performance of a single sample and the estimated posterior predictive based on a finite number of samples. Similar to earlier work mcmccycle we found that far fewer samples are needed when a cyclic step size with cycle length is used. The final sample in each cycle is used to estimate the posterior predictive.
For ResNet++ we further use a group Laplace prior with to regularize the scales of each linear feature in ResNet++. The momentum noise is chosen as such that the friction applied to the momentum is at least .
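The cyclic step size and the rule of keeping the final sample of each cycle can be sketched as follows. The cosine shape follows the cyclical schedule proposed in mcmccycle and is an assumption here, since the exact schedule is not restated in this section. The function names, the unit of one "step", and the numerical values are also illustrative.

```python
import math


def cyclic_step_size(step, cycle_length, initial_step_size):
    """Cosine schedule that restarts at the initial step size every cycle."""
    position = (step % cycle_length) / cycle_length
    return 0.5 * initial_step_size * (math.cos(math.pi * position) + 1.0)


def is_cycle_end(step, cycle_length):
    """True for the last step of each cycle, where a sample is collected."""
    return (step + 1) % cycle_length == 0


cycle_length = 50
schedule = [cyclic_step_size(s, cycle_length, 0.1) for s in range(150)]
collected_steps = [s for s in range(150) if is_cycle_end(s, cycle_length)]
```

Each cycle starts with large steps that move the sampler to a new region and ends with small steps, so the sample kept at the end of the cycle carries little discretization error.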
5.1 Cifar 10
Setup | Top 1 acc. [%] | NLL [Nats]
SGD | |
SGD + BatchNorm | 94.4 |
ATMC (single sample) | |
ATMC (Posterior predictive) | | 0.194
SGNHT (single sample) | |
SGNHT (Posterior predictive) | |
For Cifar10 we choose the step size and the cycle length is set to 50 epochs. The momentum hyperparameters are and such that the average speed and maximum speed per step are and , respectively. The number of convolution filters is doubled to compared to the original ResNet56 implementation. We use a single V100 GPU with a batch size of . The sampler runs for epochs and we start collecting samples for the posterior predictive after epochs. The optimization baseline converges in epochs. We also report the results of sampling with a sampler based on Nosé-Hoover thermostats (SGNHT) bayesthermostat ; rhmc applied to the ResNet++ architecture.
Table 1 lists the test set performance for Cifar10. A single sample from the posterior already outperforms the baseline without BatchNorm by a significant margin in both test accuracy and log-likelihood. Using BatchNorm significantly improves the generalization of the optimization baseline. It outperforms the estimate of the posterior predictive in accuracy, yet it does not have a better test log-likelihood.
To further analyze the quality of the uncertainty estimates, we group each model's predictions into 8 equally sized bins based on the confidence where is the maximum probability class for example . If the probabilities are well-calibrated, the average confidence should be close to the average accuracy. Figure 2 shows the calibration of the uncertainty estimates for the posterior predictive and optimization baselines. The posterior predictive is calibrated for the least confident predictions and shows less bias towards overconfidence compared to the models trained with SGD.
5.2 ImageNet
Setup | Top 1 acc. [%] | NLL [Nats]
SGD | |
SGD + BatchNorm | |
ATMC (single sample) | |
ATMC (Posterior predictive) | 77.5 | 0.883
SGNHT (single sample) | |
SGNHT (Posterior predictive) | |
For the ImageNet experiments we use an initial step size and a cycle length of 20 epochs. The other hyperparameters for the sampler are the same as for the Cifar10 experiments. We use a single Google Cloud TPUv3 with a batch size of 1024. We did not observe a significant difference in wall clock time per epoch for SGD and ATMC. BatchNorm did result in an overhead of roughly compared to the ResNet++ model. Samples for the posterior predictive are collected after epochs and the sampler runs for epochs. The optimization baseline converges in epochs.
Table 2 lists the results for ImageNet classification. A single sample from the posterior outperforms the optimization baseline without BatchNorm. The posterior predictive based on ATMC outperforms the optimizer with BatchNorm by a wide margin in both accuracy and test log-likelihood. We note that the sampler runs significantly longer (x) compared to the optimization baseline because it takes a long time for the posterior predictive estimate to converge. However, the posterior predictive of ATMC matches the accuracy of the optimization baseline with BatchNorm () after epochs.
Figure 3 shows the quality of the uncertainty estimates for various levels of confidence. Again, the ATMC-based posterior predictive produces much better calibrated predictions: it is almost perfectly calibrated for low-confidence predictions and shows less bias towards overconfidence than the optimization baseline.
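The calibration analysis behind Figures 2 and 3 can be sketched as follows. The sketch assumes that "equally sized" bins means bins holding an equal number of predictions, that the confidence is the predicted probability of the most likely class, and that a per-example correctness flag is available. The function name `calibration_bins` and the synthetic data are illustrative, not the experimental pipeline.

```python
import random


def calibration_bins(confidences, correct, num_bins=8):
    """Return (average confidence, accuracy) for equal-count confidence bins."""
    order = sorted(range(len(confidences)), key=lambda i: confidences[i])
    bins = []
    for b in range(num_bins):
        start = b * len(order) // num_bins
        end = (b + 1) * len(order) // num_bins
        idx = order[start:end]
        avg_conf = sum(confidences[i] for i in idx) / len(idx)
        accuracy = sum(1.0 for i in idx if correct[i]) / len(idx)
        bins.append((avg_conf, accuracy))
    return bins


# Synthetic, perfectly calibrated predictions: correct with probability equal to the confidence.
rng = random.Random(0)
confidences = [rng.uniform(0.5, 1.0) for _ in range(80000)]
correct = [rng.random() < c for c in confidences]
bins = calibration_bins(confidences, correct)
max_gap = max(abs(conf - acc) for conf, acc in bins)
```

For a well-calibrated model the two numbers in each bin nearly coincide; an overconfident model shows average confidence above accuracy in the high-confidence bins.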
6 Discussion
The empirical results show that it is possible to sample the posterior distribution of neural networks on large-scale image classification problems like ImageNet. A major obstacle for sampling the posterior of ResNets in particular is the lack of compatibility with BatchNorm. Using recent advances in initialization and the SELU activation function, we are able to stabilize and speed up training of ResNets without resorting to BatchNorm. Nonetheless, we observe that BatchNorm still offers a unique advantage in terms of generalization performance. We hope that future work will allow the implicit inductive bias that BatchNorm has to be transferred into an explicit prior that is compatible with sampling methods.
Multiple posterior samples provide a much more accurate estimate of the posterior predictive, and consequently much better accuracy and uncertainty estimates. For inference, making predictions using a large ensemble of models sampled from the posterior can be costly. Variational Inference methods can be used to quickly characterize a local mode of the posterior bayesbybackprop . More recent work shows that a running estimate of the mean and variance of the parameters during training can also be used to approximate a mode of the posterior swag . Methods like distillation could potentially be used to compress a high-quality ensemble into a single network with a limited computational budget distilation .
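The posterior predictive estimate discussed above can be sketched as an average of the class probabilities predicted by each posterior sample. The function names and the toy probabilities below are assumptions for illustration; in practice each row would come from one sampled network evaluated on the same input.

```python
import math


def posterior_predictive(sample_probs):
    """Average the class probabilities predicted by each posterior sample."""
    num_samples = len(sample_probs)
    num_classes = len(sample_probs[0])
    return [sum(p[k] for p in sample_probs) / num_samples for k in range(num_classes)]


def nll(probs, label):
    """Negative log-likelihood (in nats) of the true label."""
    return -math.log(probs[label])


# Class probabilities from three hypothetical posterior samples for one input.
sample_probs = [[0.9, 0.1], [0.6, 0.4], [0.3, 0.7]]
label = 0

ensemble = posterior_predictive(sample_probs)
ensemble_nll = nll(ensemble, label)
mean_single_nll = sum(nll(p, label) for p in sample_probs) / len(sample_probs)
```

By Jensen's inequality the NLL of the averaged prediction is never worse than the average NLL of the individual samples. The cost is one forward pass per sample, which is the inference overhead noted above.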
Although the form in (4) is very general, alternative methods for dealing with stochastic gradients have been proposed in the literature. One approach is to estimate the covariance of the stochastic gradient noise explicitly and use it to correct and precondition the sampling dynamics sgfs ; sgfsadam .
Other sampling methods are not based on an SDE that converges to the target distribution. Under some conditions stochastic optimization methods can be interpreted as such a biased sampling method sgdbi . Predictions based on multiple samples from the trajectory of SGD have been used successfully for obtaining uncertainty estimates in large-scale deep learning swag . However, these methods rely on tuning hyperparameters in such a way that just the right amount of noise is inserted.
7 Conclusion
This work introduces the ATMC sampler, a robust posterior sampling method that scales to large deep learning problems. To the best of our knowledge, we are the first to successfully train neural networks using MCMC on ImageNet. In a BatchNorm-free setting, a single sample from the posterior generated by ATMC outperforms the optimization baseline. A posterior predictive estimate outperforms the optimization baseline with BatchNorm on ImageNet. Based on these empirical results, we hope the ATMC sampler will enable new applications of Bayesian inference in deep learning.
Acknowledgments
We would like to thank Jascha Sohl-Dickstein and Sebastian Nowozin for helpful feedback. In particular we wish to thank Tim Salimans for his feedback and insightful discussions on MCMC methods.
References

[1] Alex Kendall and Yarin Gal. What uncertainties do we need in bayesian deep learning for computer vision? In Advances in neural information processing systems, pages 5574–5584, 2017.
[2] Yu-Xiang Wang, Stephen Fienberg, and Alex Smola. Privacy for free: Posterior sampling and stochastic gradient monte carlo. In International Conference on Machine Learning, pages 2493–2502, 2015.
[3] Ian Osband and Benjamin Van Roy. Why is posterior sampling better than optimism for reinforcement learning? In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 2701–2710. JMLR.org, 2017.
[4] Bala Rajaratnam and Doug Sparks. MCMC-based inference in the era of big data: A fundamental analysis of the convergence complexity of high-dimensional chains. arXiv preprint arXiv:1508.00947, 2015.
[5] Max Welling and Yee W Teh. Bayesian learning via stochastic gradient langevin dynamics. In Proceedings of the 28th international conference on machine learning (ICML-11), pages 681–688, 2011.
[6] Sungjin Ahn, Anoop Korattikara, and Max Welling. Bayesian posterior sampling via stochastic gradient fisher scoring. arXiv preprint arXiv:1206.6380, 2012.
[7] Nan Ding, Youhan Fang, Ryan Babbush, Changyou Chen, Robert D Skeel, and Hartmut Neven. Bayesian sampling using stochastic gradient thermostats. In Z. Ghahramani, M. Welling, C. Cortes, N. D. Lawrence, and K. Q. Weinberger, editors, Advances in Neural Information Processing Systems 27, pages 3203–3211. Curran Associates, Inc., 2014.
[8] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
[9] Günter Klambauer, Thomas Unterthiner, Andreas Mayr, and Sepp Hochreiter. Self-normalizing neural networks. In Advances in neural information processing systems, pages 971–980, 2017.
[10] Hongyi Zhang, Yann N Dauphin, and Tengyu Ma. Fixup initialization: Residual learning without normalization. arXiv preprint arXiv:1901.09321, 2019.
[11] Tim Salimans and Durk P Kingma. Weight normalization: A simple reparameterization to accelerate training of deep neural networks. In Advances in Neural Information Processing Systems, pages 901–909, 2016.
[12] Yi-An Ma, Tianqi Chen, and Emily Fox. A complete recipe for stochastic gradient mcmc. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett, editors, Advances in Neural Information Processing Systems 28, pages 2917–2925. Curran Associates, Inc., 2015.
[13] Tianqi Chen, Emily Fox, and Carlos Guestrin. Stochastic gradient hamiltonian monte carlo. In International conference on machine learning, pages 1683–1691, 2014.
[14] Xiaoyu Lu, Valerio Perrone, Leonard Hasenclever, Yee Whye Teh, and Sebastian J Vollmer. Relativistic monte carlo. arXiv preprint arXiv:1609.04388, 2016.
[15] Changyou Chen, Nan Ding, and Lawrence Carin. On the convergence of stochastic gradient mcmc algorithms with high-order integrators. In Advances in Neural Information Processing Systems, 2015.
[16] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research, 15(1):1929–1958, 2014.
[17] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pages 448–456, 2015.
[18] Durk P Kingma, Tim Salimans, and Max Welling. Variational dropout and the local reparameterization trick. In Advances in Neural Information Processing Systems, pages 2575–2583, 2015.
[19] Mattias Teye, Hossein Azizpour, and Kevin Smith. Bayesian uncertainty estimation for batch normalized deep networks. arXiv preprint arXiv:1802.06455, 2018.
[20] Xiaoxia Wu, Rachel Ward, and Léon Bottou. Wngrad: learn the learning rate in gradient descent. arXiv preprint arXiv:1803.02865, 2018.
[21] Ruqi Zhang, Chunyuan Li, Jianyi Zhang, Changyou Chen, and Andrew Gordon Wilson. Cyclical stochastic gradient mcmc for bayesian deep learning. arXiv preprint arXiv:1902.03932, 2019.
[22] Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. Weight uncertainty in neural network. In International Conference on Machine Learning, pages 1613–1622, 2015.
[23] Wesley Maddox, Timur Garipov, Pavel Izmailov, Dmitry Vetrov, and Andrew Gordon Wilson. A simple baseline for bayesian uncertainty in deep learning. arXiv preprint arXiv:1902.02476, 2019.
[24] Anoop Korattikara Balan, Vivek Rathod, Kevin P Murphy, and Max Welling. Bayesian dark knowledge. In Advances in Neural Information Processing Systems, pages 3438–3446, 2015.
[25] Chunyuan Li, Changyou Chen, David Carlson, and Lawrence Carin. Preconditioned stochastic gradient langevin dynamics for deep neural networks. In Thirtieth AAAI Conference on Artificial Intelligence, 2016.
[26] Stephan Mandt, Matthew D Hoffman, and David M Blei. Stochastic gradient descent as approximate bayesian inference. The Journal of Machine Learning Research, 18(1):4873–4907, 2017.