Stochastic natural gradient descent draws posterior samples in function space

06/25/2018 ∙ by Samuel L. Smith, et al. ∙ Google

Natural gradient descent (NGD) minimises the cost function on a Riemannian manifold whose metric is defined by the Fisher information. In this work, we prove that if the model predictions on the training set approach the true conditional distribution of labels given inputs, then the noise inherent in minibatch gradients causes the stationary distribution of NGD to approach a Bayesian posterior, whose temperature T ≈ ϵN/(2B) is controlled by the learning rate ϵ, training set size N and batch size B. The parameter-dependence of the Fisher metric introduces an implicit prior over the parameters, which we identify as the well-known Jeffreys prior. To support our claims, we show that the distribution of samples from NGD is close to the Laplace approximation to the posterior when T = 1. Furthermore, the test loss of ensembles drawn using NGD falls rapidly as we increase the batch size until B ≈ ϵN/2, while above this point the test loss is constant or rises slowly.


1 Introduction

Stochastic gradient descent remains the dominant optimization algorithm for supervised learning, but it performs poorly when the loss landscape is ill-conditioned. Natural gradient descent (NGD) provides a robust alternative based on the functional view of learning (Amari, 1998). This view observes that the model parameters ω are not important, only the function they represent. This function relates randomly sampled inputs x to a distribution over labels p(y|x; ω), and the goal of learning is to uncover the true conditional distribution p̃(y|x). We will use D to refer to the training set, composed of N individual training examples (x_i, y_i). A single NGD update has the form,

Δω = −(ϵ/N) F⁻¹(ω) ∇C(ω),   (1)

where C(ω) is the summed cost function over the training set, N is the training set size, and Δω is the parameter update. We pre-condition the gradient by the inverse Fisher information F⁻¹(ω) (defined in Section 4), which we estimate over the training inputs (Ly et al., 2017). As the learning rate ϵ → 0, the change in function during an update becomes independent of the model parameterization (Martens, 2014). NGD also ensures stability, by bounding the KL-divergence between the functions before and after an update, such that D_KL(p(y|x; ω) ∥ p(y|x; ω + Δω)) = O(ϵ²), subject to weak requirements on the smoothness and support of p(y|x; ω) (Ollivier et al., 2011).
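To make update (1) concrete, the following sketch applies full-batch NGD to logistic regression, where the Fisher information is available in closed form. The helper names and the per-example normalisation below are our own choices, not prescribed by the text:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fisher(X, w):
    # Average Fisher information of a logistic model over the training
    # inputs: the expectation over *predicted* labels gives
    # X^T diag(p(1-p)) X / N in closed form.
    p = sigmoid(X @ w)
    return (X * (p * (1.0 - p))[:, None]).T @ X / len(X)

def ngd_step(X, y, w, eps, damping=1e-4):
    # A sketch of update (1): precondition the average cross-entropy
    # gradient by the (damped) inverse Fisher.
    p = sigmoid(X @ w)
    grad = X.T @ (p - y) / len(y)
    F = fisher(X, w) + damping * np.eye(len(w))
    return w - eps * np.linalg.solve(F, grad)

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
w_true = np.array([2.0, -1.0, 0.5])
y = (rng.random(500) < sigmoid(X @ w_true)).astype(float)

def loss(w):
    p = np.clip(sigmoid(X @ w), 1e-12, 1 - 1e-12)
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

w = np.zeros(3)
l0 = loss(w)
for _ in range(50):
    w = ngd_step(X, y, w, eps=0.2)
print(loss(w) < l0)  # NGD should reduce the training loss
```

For logistic regression the Fisher coincides with the Hessian of the cross-entropy, so this update behaves like a damped Newton method.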

However these properties require full-batch gradients, while in practice we estimate the gradient over a minibatch (Bottou, 2010; Robbins & Monro, 1951). Therefore in this work we analyse the influence of stochastic gradients on the stationary distribution of NGD as ϵ → 0.¹ Remarkably, we find that if the model predictions on the training set approach the true conditional distribution p̃(y|x), then this stationary distribution approaches a Bayesian posterior near local minima. The temperature T ≈ ϵN/(2B) is controlled by the ratio of the learning rate, ϵ, to the batch size, B.

¹Note however that throughout this work we assume the uncertainty in the estimate of the Fisher is negligible.

However, unlike full-batch NGD, minibatch NGD is not parameterisation invariant. To recover parameterisation invariance, we introduce additional terms which arise from the parameter dependence of the Fisher metric. We propose a novel optimiser, “stochastic natural gradient descent" (SNGD),

Δω_i = −(ϵ/N) (F⁻¹)_ij ∂_j Ĉ(ω) + (ϵT/N) |F|^(−1/2) ∂_j (|F|^(1/2) (F⁻¹)_ij),   (2)

where Ĉ(ω) denotes the minibatch estimate of the summed cost function and repeated indices are summed.

SNGD reduces to a conventional NGD step as T → 0 or when the Fisher is stationary. Meanwhile in the limit ϵ → 0 at fixed temperature, Equation 2 will preserve parameterisation invariance and draw samples from a valid Bayesian posterior throughout parameter space at temperature T = ϵN/(2B). This novel optimiser introduces a multiplicative bias √|F(ω)| to the stationary distribution, which we identify as the Jeffreys prior (Jeffreys, 1946). This popular default prior is both uninformative and invariant to the model parameterization (Firth, 1993; Gelman, 2009), enabling us to sample from “function space”, rather than parameter space. We provide a range of experiments to support our claims:

  • We show that when the batch size B ≈ ϵN/2 and ϵ is small, the distribution of samples from NGD is close to the Laplace approximation to the true Bayesian posterior at T = 1.

  • If the prior over parameters is chosen appropriately, then Bayesians recommend drawing inferences using ensembles sampled at T = 1. We observe empirically that the test loss of ensembles drawn using NGD exhibits minima close to T = 1. This suggests we can analytically predict the optimal ratio between the learning rate and the batch size.
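The temperature relation underlying both claims is simple arithmetic; the helper names below are hypothetical:

```python
def temperature(eps, N, B):
    """Sampling temperature of minibatch NGD, T = eps * N / (2 * B)."""
    return eps * N / (2.0 * B)

def batch_size_for_unit_temperature(eps, N):
    """Solving T = 1 for B recovers the predicted optimum B = eps * N / 2."""
    return eps * N / 2.0

print(temperature(eps=0.1, N=1000, B=50))                 # -> 1.0
print(batch_size_for_unit_temperature(eps=0.1, N=1000))   # -> 50.0
```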

Our work builds on a number of recent works in the Bayesian community, which we discuss in Section 2. In Section 3 we derive the update rules for Langevin posterior sampling on arbitrary Riemannian manifolds. In Section 4 we consider Langevin dynamics on the Fisher manifold, noting that the parameter dependence of the Fisher information introduces a data-dependent Jeffreys prior over the parameters. In Section 5 we analyse the behaviour of minibatch NGD near local minima, and we introduce our novel stochastic NGD which preserves parameterisation invariance throughout parameter space. We present our experimental results in Section 6, applying minibatch NGD to logistic regression, shallow fully connected neural networks and deep convolutional networks.

2 Related work

Langevin dynamics (Gardiner, 1985) can be used to draw posterior samples but often mixes slowly. Girolami & Calderhead (2011) proposed Riemannian Langevin dynamics, which combines Langevin dynamics and NGD to enable faster mixing when the loss landscape is poorly conditioned. However their method still requires exact gradients and remains prohibitively expensive for large training sets. Welling & Teh (2011) proposed SGLD, the first algorithm to draw posterior samples with stochastic gradients. Patterson & Teh (2013) combined SGLD with Riemannian Langevin dynamics, while Li et al. (2016) combined SGLD with RMSprop. These methods scale well to large datasets but they still mix slowly, since the learning rate must satisfy the Robbins-Monro conditions (Robbins & Monro, 1951). Ahn et al. (2012) observed that the noise inherent in stochastic gradients converges to the empirical Fisher near local minima. By explicitly accounting for this noise, they draw samples from a Gaussian approximation to the posterior with finite learning rates, while converging to the exact posterior as ϵ → 0. A similar analysis was published by Mandt et al. (2017), who set the gradient preconditioner equal to the inverse of the empirical gradient covariances. These methods enable fast mixing with stochastic gradients, but the analysis in both papers only holds near local minima.

The link between natural gradients and the Jeffreys prior is mentioned in earlier work by Amari (1998), while Mahony & Williamson (2001) analysed the relationship between Riemannian manifolds and implicit priors in online learning. While we focus on exact NGD in this work, we anticipate that our conclusions could be scaled to larger networks using approximate natural gradient methods like K-FAC (Martens & Grosse, 2015) or other quasi-diagonal approximations to the Fisher (Marceau-Caron & Ollivier, 2017). Zhang et al. (2017) recently proposed applying K-FAC to perform variational inference, while Nado et al. (2018) combined K-FAC with SGLD for posterior sampling.

3 Posterior sampling on Riemannian manifolds

The Langevin equation updates the parameters ω in continuous time according to (Gardiner, 1985),

dω/dt = −∇C(ω) + η(t),   (3)

where C(ω) is the summed cost function and η(t) denotes isotropic Gaussian noise. This noise has mean ⟨η(t)⟩ = 0 and covariance ⟨η_i(t) η_j(t′)⟩ = 2T δ_ij δ(t − t′), where δ(t − t′) is a Dirac delta function, δ_ij is the identity matrix, and T is a scalar known as the “temperature”. In the infinite time limit, Equation 3 will sample parameters from p(ω) ∝ e^(−C(ω)/T). To draw samples numerically we introduce the learning rate ϵ, and integrate over a finite step of length ϵ to obtain,

Δω = −ϵ ∇C(ω) + α,   (4)

where α is a Gaussian random variable with mean ⟨α⟩ = 0 and covariance ⟨ααᵀ⟩ = 2ϵT I. If we perform parameter updates using Equation 4 and ϵ → 0, then we will draw a single sample from the distribution above as t → ∞. Meanwhile if the posterior is proportional to the exponentiated negative cost function, p(ω|D) ∝ e^(−C(ω)), then we sample parameters from this posterior when T = 1. We now generalise the Langevin equation to a static Riemannian manifold G,

Δω = −ϵ G⁻¹ ∇C(ω) + α,   (5)
Theorem 1.

We assume that ∇C(ω) is Lipschitz continuous everywhere, and that the metric G is positive definite. In the limit ϵ → 0, repeated application of Equation 5 will sample parameters from the stationary distribution p(ω) ∝ e^(−C(ω)/T), if ⟨α⟩ = 0 and ⟨ααᵀ⟩ = 2ϵT G⁻¹.

Proof.

Since G is positive definite, G^(1/2) exists. To prove Theorem 1, we simply perform a basis transformation on the original Langevin equation of Equation 3, such that ω → G^(1/2) ω. ∎
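The Euler-discretised Langevin update of Equation 4 can be checked on a one-dimensional quadratic cost, where the stationary density e^(−C/T) has a known variance. This is an illustrative sketch of our own, not the experimental setup of the paper:

```python
import numpy as np

# Euler-discretised Langevin dynamics (Eq. 4) on the quadratic cost
# C(w) = a w^2 / 2, whose stationary density exp(-C/T) has variance T/a.
rng = np.random.default_rng(1)
a, eps, T = 1.0, 0.05, 2.0
w, samples = 0.0, []
for t in range(120_000):
    noise = rng.normal(0.0, np.sqrt(2.0 * eps * T))  # covariance 2*eps*T
    w = w - eps * a * w + noise
    if t >= 20_000:            # discard burn-in
        samples.append(w)
var = np.var(samples)
print(round(var, 2))  # close to T/a = 2.0 for small eps
```

The residual discrepancy is the O(ϵ) discretisation bias, which vanishes as ϵ → 0, consistent with Theorem 1.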

Note however that the metric G in Equation 5 is constant and independent of the parameters. Meanwhile the metric of NGD is defined by the Fisher information, which is a function of the parameters. A detailed discussion is beyond the scope of this paper, but great care must be taken when discretising stochastic differential equations on a non-stationary metric. To build an intuition for this, we note that the magnitude of the noise contribution in Equation 5 is O(√ϵ). We therefore cannot expand a stochastic differential equation to first order in ϵ without considering second order contributions from the noise source, and these second order contributions introduce additional correction terms which vanish if the metric derivatives are zero. We introduce these correction terms below,

Theorem 2.

We assume both ∇C(ω) and G(ω) are Lipschitz continuous and G is positive definite. In the limit ϵ → 0, repeated application of the update rule,

Δω_i = ϵ[−(G⁻¹)_ij ∂_j C + T |G|^(−1/2) ∂_j (|G|^(1/2) (G⁻¹)_ij)] + α_i,   (6)

will sample parameters from p(ω) ∝ √|G(ω)| e^(−C(ω)/T), if ⟨α⟩ = 0 and ⟨ααᵀ⟩ = 2ϵT G⁻¹.

This theorem is discussed further in the appendix. Notice that the stationary distribution is modified by an implicit bias, √|G(ω)|, which arises from the parameter dependence of the metric. If one does not wish to introduce this implicit bias, one could instead apply the following update rule,

Δω_i = ϵ[−(G⁻¹)_ij ∂_j C + T ∂_j (G⁻¹)_ij] + α_i.   (7)
Theorem 3.

We assume both ∇C(ω) and G(ω) are Lipschitz continuous and G is positive definite. In the limit ϵ → 0, repeated application of Equation 7 will sample parameters from p(ω) ∝ e^(−C(ω)/T), if ⟨α⟩ = 0 and ⟨ααᵀ⟩ = 2ϵT G⁻¹.

Proof.

Consider Theorem 2, and let C′(ω) = C(ω) + T ln√|G(ω)|. This modified cost function removes the implicit bias from the stationary distribution, such that p(ω) ∝ √|G| e^(−C′(ω)/T) ∝ e^(−C(ω)/T), and by directly evaluating the derivative of C′(ω) we arrive at Equation 7. ∎

4 Preconditioned Langevin dynamics and the Jeffreys prior

In this Section, we apply Theorem 2 to draw posterior samples with full batch gradient estimates on the Fisher manifold. For simplicity, we assume the cost function C(ω) = R(ω) − Σ_i ln p(y_i|x_i; ω), where R(ω) is a regularizer and −ln p(y_i|x_i; ω) is the cross-entropy of a unique categorical label. We set the metric equal to the Fisher information,

F(ω) = (1/N) Σ_i E_{p(y|x_i; ω)}[∇ln p(y|x_i; ω) ∇ln p(y|x_i; ω)ᵀ].   (8)

Notice that the expectation is taken over the predicted label distribution given the current parameter values, p(y|x_i; ω), not the empirical labels y_i, and therefore F depends only on the training inputs x. Following Theorem 2, for sufficiently small ϵ, repeated application of its update rule will draw samples from p(ω) ∝ √|F(ω)| e^(−C(ω)/T). The data likelihood p(D|ω) = Π_i p(y_i|x_i; ω). We introduce the temperature adjusted posterior p_T(ω|D) ∝ p(D|ω)^(1/T) p_T(ω), and recall that we obtain the true posterior when T = 1. Notice that the temperature adjusted prior p_T(ω) may depend on the training inputs x but not the labels y. We set p_T(ω|D) ∝ √|F(ω)| e^(−C(ω)/T) to identify,

p_T(ω) ∝ √|F(ω)| e^(−R(ω)/T).   (9)

The multiplicative bias introduced by the Fisher metric modifies the prior imposed by the regulariser. When R(ω) = 0, we obtain a temperature independent prior p(ω) ∝ √|F(ω)|, which we identify as the Jeffreys prior; a common default prior which is both uninformative and invariant to the model parameterization (Jeffreys, 1946; Firth, 1993). Like the uniform prior, the Jeffreys prior is improper. It may therefore be necessary to include additional regularisation to ensure that the posterior is well-defined. Notice that the contribution to the prior from the regulariser grows weaker as the temperature increases, matching the temperature dependence of the likelihood term. Consequently the contribution to the prior from the metric becomes increasingly dominant as the temperature rises.
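For a Bernoulli model the Fisher information, and hence the Jeffreys prior √|F|, can be written down exactly. The sketch below (function names are our own) checks the closed form F(θ) = 1/(θ(1 − θ)), whose square root is the unnormalised Beta(1/2, 1/2) density:

```python
import numpy as np

def bernoulli_fisher(theta):
    # E over y ~ Bern(theta) of (d log p / d theta)^2:
    # the y=1 term contributes theta * (1/theta)^2,
    # the y=0 term contributes (1-theta) * (1/(1-theta))^2.
    return theta * (1.0 / theta) ** 2 + (1.0 - theta) * (1.0 / (1.0 - theta)) ** 2

def jeffreys_unnorm(theta):
    # Jeffreys prior ~ sqrt(det F); for a Bernoulli parameter this is
    # the (unnormalised) Beta(1/2, 1/2) density.
    return np.sqrt(bernoulli_fisher(theta))

theta = 0.2
print(np.isclose(bernoulli_fisher(theta), 1.0 / (theta * (1 - theta))))        # True
print(np.isclose(jeffreys_unnorm(theta), theta**-0.5 * (1 - theta)**-0.5))     # True
```

Note that the prior diverges at θ = 0 and θ = 1, illustrating why the Jeffreys prior is improper.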

Note that we can only reinterpret the multiplicative bias as a prior if the gradient pre-conditioner is independent of the training labels. However practitioners often replace the Fisher information by the empirical Fisher information, estimating the inner expectation above over the training set labels y_i,

F_emp(ω) = (1/N) Σ_i ∇ln p(y_i|x_i; ω) ∇ln p(y_i|x_i; ω)ᵀ.   (10)

In this case, the update rule of Theorem 2 will draw samples from p(ω) ∝ √|F_emp(ω)| e^(−C(ω)/T), but we can only interpret this as a Bayesian posterior if the empirical Fisher is constant (for instance near local minima), in which case the correction terms vanish and we recover the simpler update rule of Equation 5. We could draw valid posterior samples under a uniform prior with the empirical Fisher and Equation 7.
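The distinction between the two estimators can be made concrete for logistic regression: `true_fisher` below averages over the model's predicted labels and so depends only on the inputs, while `empirical_fisher` uses the observed labels (the names and the un-normalised convention are our own):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def true_fisher(X, w):
    # Expectation over the *model's* predicted label distribution:
    # depends only on the inputs X, not on the observed labels.
    p = sigmoid(X @ w)
    return (X * (p * (1 - p))[:, None]).T @ X

def empirical_fisher(X, y, w):
    # Sum of outer products of per-example gradients at the *observed* labels.
    p = sigmoid(X @ w)
    g = X * (p - y)[:, None]          # per-example gradients
    return g.T @ g

rng = np.random.default_rng(2)
X = rng.normal(size=(400, 2))
w = np.array([1.0, -0.5])
# Labels drawn from the model itself, so the two estimates agree in
# expectation; with finite data they are merely close.
y = (rng.random(400) < sigmoid(X @ w)).astype(float)

F, F_emp = true_fisher(X, w), empirical_fisher(X, y, w)
print(np.allclose(F, F.T), np.all(np.linalg.eigvalsh(F) > 0))  # True True
```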

5 Stochastic natural gradients and minibatch noise

We showed in Section 4 how to draw posterior samples with full batch natural gradients and Gaussian noise. This required us to introduce additional correction terms, as shown in Theorem 2. However in practice we usually estimate the gradient over a minibatch. The minibatch NGD update has the form,

Δω = −(ϵ/N) F⁻¹(ω) (∇C(ω) + β),   (11)

where the gradient noise,

β = (N/B) Σ_{i ∈ batch} ∇C_i(ω) − Σ_{i=1}^{N} ∇C_i(ω).   (12)

We assume the training set is randomly sorted between each update, such that Equation 12 samples a minibatch of B training examples without replacement. Below we prove Theorem 4,

Theorem 4.

We assume ∇C(ω) is Lipschitz continuous everywhere, and that the Fisher information matrix is positive definite. Training inputs x are drawn from a fixed distribution p̃(x), while labels are assigned by a conditional distribution p̃(y|x). If ϵ → 0, B ≪ N, and p(y|x; ω) → p̃(y|x) for all observed x, repeated application of Equation 11 will draw samples from p(ω) ∝ e^(−C(ω)/T) near local minima (where the Fisher is stationary), with T = ϵN/(2B).

Proof.

Since the gradients of individual training examples are independent, the covariance of the gradient noise,

E[ββᵀ] = N² (1/B − 1/N) Σ(ω),   (13)

where Σ(ω) denotes the empirical gradient covariances. Applying the central limit theorem over the B examples in the minibatch, we conclude the gradient noise β is a Gaussian random variable. By direct substitution E[β] = 0, while the covariance of the parameter update ⟨ΔωΔωᵀ⟩ ≈ (ϵ²/N²) F⁻¹ E[ββᵀ] F⁻¹. To compute Σ(ω), we expand the empirical covariance over the training examples,

Σ(ω) = (1/N) Σ_{i=1}^{N} ∇C_i(ω) ∇C_i(ω)ᵀ − (1/N²) ∇C(ω) ∇C(ω)ᵀ.   (14)

As N → ∞, Σ(ω) approaches the covariance of the gradients under the underlying data distribution. Meanwhile under mild regularity conditions, this covariance approaches the Fisher information when the model predictions match the true conditional distribution, and consequently if p(y|x; ω) → p̃(y|x) for all x, we also obtain Σ(ω) → F(ω). Since we assumed F is positive definite, this implies Σ(ω) is positive definite, which ensures the gradient noise cannot vanish and prevents the stationary distribution from collapsing to a fixed point. Finally we note that near local minima, the Fisher metric is stationary. Comparing Equations 5 and 11 and following Theorem 1, we obtain Theorem 4. ∎
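The covariance of the minibatch gradient noise can be checked by simulation. The scalar toy model below (our own construction) samples batches without replacement and compares the empirical noise variance against the finite-population prediction N²S²(1/B − 1/N), where S² denotes the variance of the per-example gradients:

```python
import numpy as np

rng = np.random.default_rng(3)
N, B = 1000, 50
g = rng.normal(size=N)               # per-example gradients (scalar toy model)
full = g.sum()

# Noise in the rescaled minibatch gradient:
# beta = (N/B) * sum_{i in batch} g_i - sum_i g_i, batch drawn w/o replacement.
noise = []
for _ in range(20_000):
    batch = rng.choice(N, size=B, replace=False)
    noise.append((N / B) * g[batch].sum() - full)
emp_var = np.var(noise)

S2 = g.var(ddof=1)                   # variance of the per-example gradients
pred_var = N**2 * S2 * (1.0 / B - 1.0 / N)
print(round(emp_var / pred_var, 2))  # ratio close to 1
```

When B ≪ N the prediction reduces to N²S²/B, which is the scaling that produces the temperature T = ϵN/(2B).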

This remarkable result shows that as the predictions of the model on the training set grow closer to the true conditional distribution of labels given inputs, the stationary distribution of minibatch NGD approaches a Bayesian posterior near local minima at temperature T = ϵN/(2B) (at least as ϵ → 0). A similar analysis was proposed by Ahn et al. (2012), while Mandt et al. (2017) proposed to directly precondition the gradient by the empirical covariances. Both these works also require that the Fisher is stationary. In order to extend our analysis to non-stationary Fisher matrices, we replace the conventional minibatch NGD step of Equation 11 by our novel SNGD step proposed in Equation 2.

Theorem 5.

We assume both ∇C(ω) and F(ω) are Lipschitz continuous everywhere, and that the Fisher information matrix is positive definite. Training inputs x are drawn from a fixed distribution p̃(x), while labels are assigned by a conditional distribution p̃(y|x). If ϵ → 0, B ≪ N, and p(y|x; ω) → p̃(y|x) for all observed x, repeated application of Equation 2 will draw samples from p(ω) ∝ √|F(ω)| e^(−C(ω)/T), where T = ϵN/(2B).

Proof.

This result follows directly from the proofs of Theorems 2 and 4. Notice that the update rule of Theorem 2 reduces to Equation 2 if one sets G = F and rescales the learning rate by the training set size, ϵ → ϵ/N. ∎

As discussed in Section 4, the implicit bias in the stationary distribution of Equation 2 introduces a parameterisation invariant Jeffreys prior over the parameters. Notice that we could also exploit Equation 7 and Theorem 3 to draw (approximate) posterior samples under a uniform prior.

Finally, we note that the metric must be positive definite, while the Fisher is positive semi-definite. To resolve this, it is common practice to introduce the modified metric G = F + λI (Martens, 2014). The “Fisher damping” λ imposes a trust region, ensuring the eigenvalues of the pre-conditioner G⁻¹ are bounded by 1/λ. We therefore expect that we will be able to achieve stable training at larger learning rates by increasing λ, while we converge to exact SNGD as λ → 0. The Fisher damping will also modify the implicit prior and break parameterisation invariance.
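The trust-region effect of the damping can be verified directly: since F is positive semi-definite, every eigenvalue of (F + λI)⁻¹ is at most 1/λ. A minimal numerical check (names are our own):

```python
import numpy as np

rng = np.random.default_rng(4)
A = rng.normal(size=(5, 5))
F = A @ A.T                      # a random PSD "Fisher"
lam = 0.1
# Damped preconditioner: eigenvalues of F + lam*I are at least lam,
# so eigenvalues of the inverse are at most 1/lam.
P = np.linalg.inv(F + lam * np.eye(5))
print(np.linalg.eigvalsh(P).max() <= 1.0 / lam + 1e-9)  # True
```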

6 Experiments

In the following, we provide experiments analysing the behaviour of conventional minibatch NGD in the light of Theorem 4. We leave empirical analysis of our novel “SNGD” update rule to future work.

6.1 Comparing preconditioned Langevin dynamics and stochastic NGD

Figure 1: Samples from (a) preconditioned Langevin dynamics at T = 1 and (b) NGD when the batch size B = ϵN/2. We plot the iterates in red, the Laplace approximation to the posterior in blue and the covariance of the iterates in green (to 3 standard deviations). As predicted, the stationary distributions of preconditioned Langevin dynamics and stochastic NGD are remarkably similar.

We first consider a binary logistic regression task, with exclusive labels y ∈ {0, 1} and inputs x drawn from an n-dimensional unit Gaussian. Our model infers labels via p(y = 1|x; ω) = σ(ωᵀx), where σ denotes the sigmoid. In Figure 1a, we show the distribution of samples from preconditioned Langevin dynamics, for which we set the metric equal to the Fisher information, G = F, and we evaluate the gradient over the entire training set. To set the temperature, we explicitly add Gaussian noise α to the update, with the covariance prescribed by Theorem 1, as described in Section 4.² Our training set comprises 1000 examples and we set the input dimensionality n = 2. We sample the training set labels of each input from the Bernoulli distribution of the “true” parameter values ω̃. We fix the learning rate ϵ and the Fisher damping λ. We first run 5000 parameter updates to burn in, before sampling the following 5000 parameter values. The blue point denotes the cost function minimum, obtained by LBFGS, while the blue curve illustrates the width of the Bayesian posterior (to 3 standard deviations), estimated using the Laplace approximation. To confirm that the samples are drawn from a distribution close to the posterior, we plot the covariance of the samples in green. In Figure 1b we replace preconditioned Langevin dynamics by NGD. To set T = 1, we simply estimate the gradient over minibatches of size B = ϵN/2. Batches are sampled randomly without replacement, and we continue to estimate the Fisher information over the full training set. As predicted, when we set the batch size correctly the stationary distribution of NGD is close to the posterior, drawing samples from the same distribution as preconditioned Langevin dynamics.

²However, for simplicity we do not include the additional correction terms of Theorem 2, instead following the simpler update rule of Equation 5, despite the parameter dependence of the Fisher metric.
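This experiment can be sketched at small scale. The snippet below is our own minimal reconstruction, under the assumptions that the Fisher is the per-example average and the minibatch gradient is rescaled by N/B; it compares the covariance of minibatch-NGD iterates at B = ϵN/2 with the Laplace covariance:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(5)
N, eps = 1000, 0.1
B = int(eps * N / 2)                     # batch size giving temperature T = 1
X = rng.normal(size=(N, 2))
w_true = np.array([1.0, -1.0])
y = (rng.random(N) < sigmoid(X @ w_true)).astype(float)

def avg_fisher(w):
    p = sigmoid(X @ w)
    return (X * (p * (1 - p))[:, None]).T @ X / N

# Find the optimum with a few Newton steps (Fisher = Hessian here), then
# form the Laplace covariance (inverse of the summed-cost Hessian).
w = np.zeros(2)
for _ in range(30):
    p = sigmoid(X @ w)
    w -= np.linalg.solve(N * avg_fisher(w) + 1e-6 * np.eye(2), X.T @ (p - y))
laplace_cov = np.linalg.inv(N * avg_fisher(w))

# Minibatch NGD: dw = -(eps/N) F^{-1} grad_hat, grad_hat = (N/B) * batch sum.
samples = []
for t in range(6000):
    idx = rng.choice(N, size=B, replace=False)
    p = sigmoid(X[idx] @ w)
    grad_hat = (N / B) * X[idx].T @ (p - y[idx])
    w = w - (eps / N) * np.linalg.solve(avg_fisher(w), grad_hat)
    if t >= 2000:                        # discard burn-in
        samples.append(w.copy())
sample_cov = np.cov(np.array(samples).T)
ratio = np.trace(sample_cov) / np.trace(laplace_cov)
print(round(ratio, 1))  # close to 1 when T = 1
```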

We now increase the difficulty of the task by setting the input dimensionality n = 64 and choosing true parameter values ω̃ with a single non-zero component, such that the input comprises one relevant feature and 63 irrelevant features. In Figure 2, we plot the mean test set accuracy and test cross entropy of preconditioned Langevin dynamics as we vary the temperature T. Once again, we fix the learning rate ϵ and the damping λ, and we run 1000 parameter updates to burn in. In blue we plot the performance of a conventional ensemble average over the final 1000 parameter values, while in green we plot the accuracy of a single sample. The ensemble outperforms single parameter samples, and the test cross entropy of the ensemble is minimised when T ≈ 1. Meanwhile in Figure 3, we plot the test set accuracy and test cross entropy of NGD on the same task. We continue to estimate the Fisher over the full training set, but we sample the gradient using minibatches of size B = ϵN/(2T). As predicted by our theoretical analysis, the performance of NGD is remarkably similar to preconditioned Langevin dynamics, and the test cross entropy is minimised at T ≈ 1.

Figure 2: Preconditioned Langevin dynamics for logistic regression. The test set accuracy (a) and test cross-entropy (b), as a function of the sampling temperature T, which we set by adding Gaussian noise to the gradient. The test cross-entropy of the ensemble is minimised at T ≈ 1.
Figure 3: NGD for logistic regression. The test set accuracy (a) and test cross-entropy (b), as a function of the sampling temperature T, which we control by setting the batch size B = ϵN/(2T). As predicted, the performance of the ensemble is similar to preconditioned Langevin dynamics.

6.2 Stochastic NGD and an MLP

To confirm that our conclusions are relevant to non-convex settings we now apply NGD to train a simple MLP on MNIST. Our model has a single hidden layer with 40 hidden units and ReLU activations. To reduce the number of parameters, we use a matrix whose elements are drawn randomly from the unit Gaussian to project the input features down to 10 dimensions, and to emphasize the influence of the temperature on training we reduce the training set size N. We also found it was necessary to include additional regularisation to ensure the samples converge to a stationary distribution. We therefore introduce L2-regularization with a small regularization coefficient. We increase the Fisher damping λ to ensure stability, and fix the learning rate ϵ. In Figure 4a, we plot the test set accuracy of NGD as a function of the batch size B used to estimate the gradient. The Fisher is estimated over separate batches of images, using a single sampled label per example. For each batch size, we perform 500 gradient updates to burn in, before sampling the parameters over a further 500 updates. We plot both the performance of the ensemble average across the 500 samples, as well as the mean accuracy of a single sample. When training with an ensemble, the accuracy increases rapidly as we increase the batch size until we reach B ≈ ϵN/2, at which the temperature T ≈ 1, while beyond this point the accuracy drops.³ In Figure 4b, we exhibit the test set cross-entropy. While the mean test cross-entropy of a single sample always rises as we reduce the batch size, the ensemble test cross-entropy also shows a minimum at B ≈ ϵN/2.

³We note that the optimum temperature may depend on the damping coefficient or the regulariser, although we would usually expect this dependence to be relatively weak.

Figure 4: The test set accuracy (a) and test cross entropy (b) of NGD as a function of batch size B, when training an MLP on randomly projected input features from MNIST. We also provide the batch size at which the temperature T = 1. We note that the peak test accuracy is very low, since this task is substantially harder than classifying a complete MNIST image.

Figure 5: The test cross entropy of NGD as a function of batch size B, when training an MLP on randomly projected input features from MNIST. a) The cross-entropy of the ensemble shows a minimum at the batch size for which T ≈ 1. b) As the training set size N increases, the minimum near T = 1 is increasingly flat.

In Figure 5a we exhibit the test cross entropy for the same model when we increase the training set size N. Now the cross entropy of the ensemble shows a minimum at an increased batch size. Once again, this matches the Bayesian prediction B = ϵN/2. Notice that this leads to a linear scaling rule between the batch size and the training set size (B ∝ N), as was previously observed for SGD by Smith & Le (2018). In Figure 5b, we further increase the training set size. The minimum in the test cross entropy becomes increasingly flat as the training set grows.

Figure 6: The test cross entropy of NGD as a function of temperature T, when training a CNN on MNIST. a) The test cross entropy as a function of temperature for a range of learning rates ϵ, at fixed training set size N. b) The test cross entropy as a function of temperature for a range of training set sizes N, at fixed learning rate ϵ.

6.3 Stochastic NGD and a CNN

In Figure 6, we apply NGD to train a CNN on MNIST. Our model is comprised of 3 convolutional layers and 1 fully connected softmax layer. Each convolutional layer has a small spatial window, 10 output channels and a fixed stride, resulting in a modest number of trainable parameters. We introduce L2 regularization with a small coefficient. The Fisher information is estimated during training using independent minibatches of 1024 examples, using a single sampled label per example. To reduce variance, we store a moving average of the Fisher over previous updates, controlled by a smoothing coefficient, and we apply Fisher damping λ. To reduce the burn in time, we initialise the weights at the start of training from the final parameters of a single SGD training run on the full MNIST training set. We average our predictions over ensembles sampled from the final portion of the gradient updates of each training run.

In Figure 6a, we plot the test cross entropy at fixed training set size N, for a range of learning rates ϵ. We set the temperature by choosing the batch size B = ϵN/(2T). As expected, the test cross entropy at constant temperature is largely independent of the learning rate, indicating that stochastic NGD obeys a linear scaling rule between batch size and learning rate (B ∝ ϵ at constant temperature), as already observed by many authors for SGD (Goyal et al., 2017; Smith et al., 2018; Balles et al., 2016). The test cross entropy for all learning rates rises rapidly once T > 1.

In Figure 6b, we exhibit the test cross entropy at fixed learning rate ϵ for a range of training set sizes. The test cross entropy rises as the size of the training set falls, but we consistently observe that the test cross entropy begins to increase rapidly once T > 1. We note however that the minimum does exhibit a weak shift towards smaller temperatures for smaller training sets. This may reflect the breakdown of the approximations behind Theorem 4, since the samples from minibatch NGD will lie increasingly far from the local minimum as the training set size is reduced, and additionally the model predictions will be increasingly far from the true conditional distribution between inputs and labels.

7 Conclusions

We prove that, if the model predictions on the training set approach the true conditional distribution between inputs and labels, the stationary distribution of NGD approaches a Bayesian posterior near local minima as the learning rate ϵ → 0, at a temperature T = ϵN/(2B), where N denotes the training set size and B the batch size. To confirm our claims, we demonstrate that samples from NGD at T = 1 are close to the Laplace approximation to the Bayesian posterior. We also find that the test cross entropy of ensembles sampled from NGD is minimised when T ≈ 1. This confirms minibatch noise can improve the generalization performance of NGD, and suggests Bayesian principles may help predict the optimal batch size during training. Furthermore, we propose a novel algorithm, “stochastic natural gradient descent", which draws parameterization invariant posterior samples with minibatch gradients by introducing the Jeffreys prior to the stationary distribution.

Acknowledgements

We thank Martin Abadi, Matthew Johnson, Matt Hoffman, Roger Grosse, Yasaman Bahri and Alex Botev for helpful feedback, and Yann Ollivier for gently informing us of errors in our original proof.

Appendix

In what follows, repeated indices are summed over, and ∂_i denotes ∂/∂ω_i. The stochastic differential equation (SDE),

dω_i/dt = f_i(ω) + σ_ij(ω) η_j(t),   (15)

is intended to describe the ordinary differential equation,

dω_i/dt = f_i(ω),   (16)

perturbed by an uncorrelated Gaussian noise source η(t) with mean ⟨η(t)⟩ = 0 and variance ⟨η_i(t) η_j(t′)⟩ = δ_ij δ(t − t′). Equivalently, one could choose to define Equation 15 as,

ω(t + dt) = ω(t) + f(ω) dt + √dt σ(ω) ν,   (17)

where ν is randomly sampled from the unit Gaussian. However it is important to note that, while Equation 16 uniquely defines the evolution of an ordinary differential equation, Equations 15 or 17 do not uniquely define a stochastic differential equation. To understand this discrepancy, recall that there are many different numerical methods to approximate solutions to Equation 16 (“Euler”, “Euler-Heun”, etc.). For an ordinary differential equation, all of these numerical methods converge to the same evolution for ω(t) as the step-size dt → 0. However, different numerical methods may converge to different probability distributions p(ω, t) when applied to Equation 17.

Therefore, in order to describe a stochastic process, we must define a stochastic differential equation, in the form of Equation 17, and also specify an “interpretation” of that SDE. This interpretation tells us which numerical methods can be used. The two most common interpretations are the Ito interpretation, which uses the Euler method, and the Stratonovich interpretation, which uses the Euler-Heun method. We will adopt the Ito interpretation, which can be discretized as follows,

ω(t + ϵ) = ω(t) + ϵ f(ω(t)) + √ϵ σ(ω(t)) ν,   (18)
ν ∼ N(0, I).   (19)

For clarity, we have introduced the learning rate ϵ = dt. The evolution of the probability distribution p(ω, t) under the Ito interpretation is governed by the following Fokker-Planck equation,

∂p/∂t = −∂_i [f_i(ω) p] + (1/2) ∂_i ∂_j [(σσᵀ)_ij p].   (20)

If there is a probability density p_s(ω) for which ∂p_s/∂t = 0, then this defines the stationary distribution, and under mild conditions we expect p(ω, t) → p_s(ω) as t → ∞.

In the Ito interpretation, Brownian motion on a Riemannian manifold is defined by the SDE,

dω_i/dt = −(G⁻¹)_ij ∂_j C + |G|^(−1/2) ∂_j (|G|^(1/2) (G⁻¹)_ij) + √2 (G^(−1/2))_ij η_j,   (21)

where G(ω) is positive definite. Applying the Euler method, we obtain the update rule,

Δω_i = ϵ[−(G⁻¹)_ij ∂_j C + |G|^(−1/2) ∂_j (|G|^(1/2) (G⁻¹)_ij)] + √(2ϵ) (G^(−1/2))_ij ν_j.   (22)

This is equivalent to the update rule of Theorem 2 in the main text at temperature T = 1, if we redefine α = √(2ϵ) G^(−1/2) ν. Note that to obtain Equation 22, we applied the identity,

∂_j ln|G| = Tr(G⁻¹ ∂_j G).   (23)

One can confirm by direct substitution into Equation 20 that the stationary distribution of Equation 22 satisfies p_s(ω) ∝ √|G(ω)| e^(−C(ω)). Finally we generalise this result to arbitrary temperatures by setting C → C/T and scaling the noise variance by T, thus proving Theorem 2.⁴

⁴We note that the earlier derivation of Equation 22 by Girolami & Calderhead (2011) differs by a factor of 2. They claim their scheme samples the uniform prior, although we have confirmed empirically this is not the case.
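A useful identity when manipulating metric derivatives of this kind is Jacobi's formula, ∂_j ln|G| = Tr(G⁻¹ ∂_j G). The sketch below (our own construction) verifies it numerically by finite differences for a parameter-dependent positive-definite 2×2 metric:

```python
import numpy as np

# A parameter-dependent positive-definite 2x2 metric G(t).
def G(t):
    return np.array([[2.0 + t**2, 0.5 * t],
                     [0.5 * t,    1.0 + np.sin(t)**2]])

t, h = 0.7, 1e-6
# Central finite difference of log det G.
lhs = (np.log(np.linalg.det(G(t + h))) - np.log(np.linalg.det(G(t - h)))) / (2 * h)
# Jacobi's formula: tr(G^{-1} dG/dt).
dG = (G(t + h) - G(t - h)) / (2 * h)
rhs = np.trace(np.linalg.solve(G(t), dG))
print(abs(lhs - rhs) < 1e-5)  # True
```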