Rethinking the limiting dynamics of SGD: modified loss, phase space oscillations, and anomalous diffusion

07/19/2021 ∙ by Daniel Kunin, et al. ∙ Stanford University

In this work we explore the limiting dynamics of deep neural networks trained with stochastic gradient descent (SGD). We find empirically that long after performance has converged, networks continue to move through parameter space by a process of anomalous diffusion in which distance travelled grows as a power law in the number of gradient updates with a nontrivial exponent. We reveal an intricate interaction between the hyperparameters of optimization, the structure in the gradient noise, and the Hessian matrix at the end of training that explains this anomalous diffusion. To build this understanding, we first derive a continuous-time model for SGD with finite learning rates and batch sizes as an underdamped Langevin equation. We study this equation in the setting of linear regression, where we can derive exact, analytic expressions for the phase space dynamics of the parameters and their instantaneous velocities from initialization to stationarity. Using the Fokker-Planck equation, we show that the key ingredient driving these dynamics is not the original training loss, but rather the combination of a modified loss, which implicitly regularizes the velocity, and probability currents, which cause oscillations in phase space. We identify qualitative and quantitative predictions of this theory in the dynamics of a ResNet-18 model trained on ImageNet. Through the lens of statistical physics, we uncover a mechanistic origin for the anomalous limiting dynamics of deep neural networks trained with SGD.


1 Introduction

Deep neural networks have demonstrated incredible generalization abilities across a variety of datasets and tasks. Essential to their success has been a collection of good practices and procedures on how to train these models with stochastic gradient descent (SGD). Yet, despite their importance, these procedures are mainly based on heuristic arguments and trial and error search. Without a general theory connecting the hyperparameters of optimization, the architecture of the network, and the geometry of the dataset, theory-driven design of deep learning systems is impossible. Existing theoretical works studying this interaction have leveraged the random structure of neural networks at initialization [26, 32, 20] and in their infinite width limits in order to study their dynamics [10, 21, 36]. Here we take a different approach and study the training dynamics of pre-trained networks that are ready to be used for inference. By leveraging the mathematical structures found at the end of training, we uncover an intricate interaction between the hyperparameters of optimization, the structure in the gradient noise, and the Hessian matrix at the end of training. Understanding the limiting dynamics of SGD is a critical stepping stone to building a complete theory for the learning dynamics of neural networks. Combining empirical exploration and theoretical tools from statistical physics, we identify and uncover a mechanistic origin for the limiting dynamics of neural networks trained with SGD. Our main contributions are:

  1. We demonstrate empirically that long after performance has converged, networks continue to move through parameter space by a process of anomalous diffusion where distance travelled grows as a power law in the number of steps with a nontrivial exponent (section 2).

  2. To understand this empirical behavior, we derive a continuous-time model for SGD as an underdamped Langevin equation, accounting for the discretization error due to finite learning rates and the gradient noise introduced by stochastic batches (section 3).

  3. We show that for linear regression, these dynamics give rise to an Ornstein-Uhlenbeck process whose moments can be derived analytically as the sum of damped harmonic oscillators in the eigenbasis of the data (section 4).

  4. We prove via the Fokker-Planck equation that the stationary distribution for this process is a Gibbs distribution on a modified (not the original) loss, which breaks detailed balance and gives rise to non-zero probability currents in phase space (section 5).

  5. We demonstrate empirically that the limiting dynamics of a ResNet-18 model trained on ImageNet display these qualitative characteristics – no matter how anisotropic the original training loss, the limiting trajectory of the network will behave isotropically (section 6).

  6. We derive theoretical expressions for the influence of the learning rate, batch size, and momentum coefficient on the limiting instantaneous speed of the network and the anomalous diffusion exponent, which quantitatively match the empirical measurements (section 7).

2 Diffusive Behavior in the Limiting Dynamics of SGD

Figure 1: Despite performance convergence, the network continues to move through parameter space. We plot the squared Euclidean norm for the local and global displacement ($\delta_t$ and $\Delta_t$) of five classic convolutional neural network architectures. The networks are standard Pytorch models pre-trained on ImageNet [28]. Their training is resumed for 10 additional epochs. We show the global displacement on a log-log plot, where the slope $c$ of the least squares line is the exponent of the power law $\Delta_t \propto t^c$. See appendix H for experimental details.

Contrary to common intuition, a network that has converged in performance will still continue to move through parameter space. To demonstrate this behavior, we resume training of pre-trained convolutional networks while tracking the network trajectory through parameter space. Let $\theta_0$ be the parameter vector for a pre-trained network and $\theta_t$ be the parameter vector after $t$ steps of resumed training. We track two metrics of the training trajectory, namely the local parameter displacement $\delta_t$ between consecutive steps, and the global displacement $\Delta_t$ after $t$ steps from the pre-trained initialization:

$$\delta_t = \|\theta_t - \theta_{t-1}\|_2^2, \qquad \Delta_t = \|\theta_t - \theta_0\|_2^2. \tag{1}$$

Surprisingly, as shown in Fig. 1, neither of these displacements converges to zero across a variety of architectures, indicating that despite performance convergence, the networks continue to move through parameter space, both locally and globally. The local displacement $\delta_t$ remains near a constant value, indicating the network is moving at a fixed instantaneous speed. However, the global displacement $\Delta_t$ is monotonically growing for all networks, implying that even once trained, the network continues to diverge from where it has been. Indeed, Fig. 1 indicates a power law relationship between the global displacement and the number of steps, $\Delta_t \propto t^c$. As we will see in section 7, this relationship is indicative of a diffusive process, where the particular value of $c$ corresponds to the diffusion exponent; standard Brownian motion corresponds to $c = 1$. These empirical results demonstrate that in the terminal stage of training, neural networks continue to move through parameter space away from their earlier high performance states. These observations raise natural questions: where is the network moving to and why? To answer these questions we will build a diffusion-based theory of SGD, study these dynamics in the setting of linear regression, and use lessons learned in this fundamental setting to understand the limiting dynamics of neural networks.
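The displacement metrics in equation (1) and the power-law fit are simple to compute. Below is a minimal sketch (not the authors' code) that applies them to a synthetic random-walk trajectory standing in for a sequence of SGD checkpoints; for i.i.d. steps the fit recovers the Brownian exponent $c \approx 1$:

```python
import numpy as np

rng = np.random.default_rng(0)
d, T = 1000, 5000
theta0 = rng.normal(size=d)                  # stand-in for pre-trained parameters
steps = 1e-3 * rng.normal(size=(T, d))       # stand-in for per-step SGD updates
thetas = theta0 + np.cumsum(steps, axis=0)   # theta_1, ..., theta_T

delta = np.sum(np.diff(thetas, axis=0) ** 2, axis=1)   # local displacement, eq. (1)
Delta = np.sum((thetas - theta0) ** 2, axis=1)         # global displacement, eq. (1)

t = np.arange(1, T + 1)
c, _ = np.polyfit(np.log(t), np.log(Delta), 1)         # slope of log-log fit
print(f"fitted diffusion exponent c = {c:.2f}")        # ~1.0 for Brownian steps
```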

3 Modeling SGD as an Underdamped Langevin Equation

Following the route of previous works [25, 11, 4] studying the limiting dynamics of neural networks, we first seek to model SGD as a continuous stochastic process. We consider a network parameterized by $\theta$, a training dataset of size $N$, and a training loss $f(\theta) = \frac{1}{N}\sum_{i=1}^{N} f_i(\theta)$ with corresponding gradient $g(\theta) = \nabla f(\theta)$. The state of the network at the $t$-th step of training is defined by the position vector $\theta_t$ and velocity vector $v_t$ of the same dimension. The gradient descent update with learning rate $\eta$, momentum constant $\beta$, and weight decay rate $\lambda$ is given by the system of equations

$$v_{t+1} = \beta v_t - \eta\big(g(\theta_t) + \lambda\theta_t\big), \qquad \theta_{t+1} = \theta_t + v_{t+1}, \tag{2}$$

where we initialize the network such that $v_0 = 0$ and $\theta_0$ is the parameter initialization. In order to understand the dynamics of the network through position and velocity space, which we will refer to as phase space, we express these discrete recursive equations as the discretization of some unknown ordinary differential equation (ODE). By incorporating a previous time step $\theta_{t-1}$, we can rearrange the two update equations into the finite difference discretization,

$$\underbrace{\frac{\theta_{t+1} - \theta_t}{\eta}}_{\text{forward Euler}} - \beta\,\underbrace{\frac{\theta_t - \theta_{t-1}}{\eta}}_{\text{backward Euler}} = -\big(g(\theta_t) + \lambda\theta_t\big). \tag{3}$$

Forward and backward Euler discretizations are explicit and implicit discretizations respectively of the first order temporal derivative $\dot\theta$. Naively replacing the discretizations with the derivative would generate an inaccurate first-order model for the discrete process. Like all discretizations, the Euler discretizations introduce higher-order error terms proportional to the step size, which in this case are proportional to $\frac{\eta}{2}\ddot\theta$. These second-order error terms are commonly referred to as artificial diffusion, as they are not part of the original first-order ODE being discretized, but introduced by the discretization process. Incorporating the artificial diffusion terms into the first-order ODE, we get a second-order ODE, sometimes referred to as a modified equation as in [16, 18], describing the dynamics of gradient descent,

$$\frac{(1+\beta)\eta}{2}\ddot\theta + (1-\beta)\dot\theta = -\big(g(\theta) + \lambda\theta\big).$$

While this second-order ODE models the gradient descent process, even at finite learning rates, it fails to account for the stochasticity introduced by choosing a random batch $\mathcal{B}$ of size $S$ drawn uniformly (with or without replacement) from the set of $N$ training points. This sampling yields the stochastic gradient $g_{\mathcal{B}}(\theta) = \frac{1}{S}\sum_{i\in\mathcal{B}}\nabla f_i(\theta)$. To model this effect, we make the following assumption:

Assumption 1 (Central Limit Theorem).

We assume the batch gradient is a noisy version of the true gradient such that $g_{\mathcal{B}}(\theta)$ is a Gaussian random variable with mean $g(\theta)$ and covariance $\frac{1}{S}\Sigma(\theta)$.

Incorporating this model of stochastic gradients into the previous finite difference equation and applying the stochastic counterparts of the Euler discretizations results in the second-order stochastic differential equation (SDE),

$$\frac{(1+\beta)\eta}{2}\ddot\theta + (1-\beta)\dot\theta = -\big(g(\theta) + \lambda\theta\big) + \sqrt{\tfrac{\eta}{S}\Sigma(\theta)}\,\xi(t),$$

where $\xi(t)$ represents a fluctuating force. An equation of this form is commonly referred to as an underdamped Langevin equation and has a natural physical interpretation as the equation of motion for a particle moving in a potential field with a fluctuating force. In particular, the mass of the particle is $\frac{(1+\beta)\eta}{2}$, the friction constant is $(1-\beta)$, the potential is the regularized training loss $f(\theta) + \frac{\lambda}{2}\|\theta\|_2^2$, and the fluctuating force is introduced by the gradient noise. While this form for the dynamics provides useful intuition, we must expand back into phase space in order to write the equation in the standard drift-diffusion form for an SDE,

$$d\begin{bmatrix}\theta_t\\ v_t\end{bmatrix} = \begin{bmatrix} v_t \\ -\frac{2}{(1+\beta)\eta}\big((1-\beta)v_t + g(\theta_t) + \lambda\theta_t\big)\end{bmatrix}dt + \begin{bmatrix}0\\ \frac{2}{1+\beta}\sqrt{\frac{\Sigma(\theta_t)}{\eta S}}\end{bmatrix}dW_t, \tag{4}$$

where $W_t$ is a standard Wiener process. This is the continuous model we will study in this work:

Assumption 2 (Stochastic Differential Equation).

We assume the underdamped Langevin equation (4) accurately models the trajectory of the network driven by SGD through phase space.

See appendix A for further discussion on the nuances of modeling SGD with an SDE.
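As a sanity check on this model, the phase-space SDE (4) can be integrated numerically. Below is a minimal Euler–Maruyama sketch for a toy two-dimensional quadratic loss; the Hessian, noise covariance $\Sigma = \sigma^2 H$, and hyperparameter values are illustrative assumptions, and the coefficients follow the reconstruction of equation (4) above:

```python
import numpy as np

rng = np.random.default_rng(1)
eta, beta, lam, S = 0.1, 0.5, 0.0, 32        # assumed hyperparameters
H = np.diag([1.0, 0.1])                      # toy anisotropic Hessian
sigma2 = 1e-2                                # assumed noise scale
sqrt_Sigma = np.sqrt(sigma2) * np.sqrt(H)    # Sigma = sigma^2 H, H diagonal

theta, v = np.array([1.0, 1.0]), np.zeros(2)
dt = eta                                     # one SGD step corresponds to dt = eta
for _ in range(10_000):
    grad = H @ theta + lam * theta
    dW = np.sqrt(dt) * rng.normal(size=2)    # Wiener increment, variance dt
    v += (-2.0 / ((1 + beta) * eta)) * ((1 - beta) * v + grad) * dt \
         + (2.0 / (1 + beta)) * np.sqrt(1.0 / (eta * S)) * (sqrt_Sigma @ dW)
    theta += v * dt
print(theta, v)                              # hovers near the minimum at zero
```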

4 Linear Regression with SGD leads to an Ornstein-Uhlenbeck Process

Equipped with a model for SGD, we now seek to understand its dynamics in the fundamental setting of linear regression, one of the few cases where we have a complete model for the interaction of the dataset, architecture, and optimizer. Let $X \in \mathbb{R}^{N\times d}$ be the input data, $y \in \mathbb{R}^N$ be the output labels, and $\theta \in \mathbb{R}^d$ be our vector of regression coefficients. The least squares loss is the convex quadratic $f(\theta) = \frac{1}{2N}\|y - X\theta\|_2^2$ with gradient $g(\theta) = H\theta - b$, where $H = \frac{X^\top X}{N}$ and $b = \frac{X^\top y}{N}$. Plugging this expression for the gradient into the underdamped Langevin equation (4), and rearranging terms, results in the multivariate Ornstein-Uhlenbeck (OU) process,

$$d\begin{bmatrix}\theta_t\\ v_t\end{bmatrix} = -A\left(\begin{bmatrix}\theta_t\\ v_t\end{bmatrix} - \begin{bmatrix}\mu\\ 0\end{bmatrix}\right)dt + \sqrt{2\kappa D}\,dW_t, \tag{5}$$

where $A$ and $D$ are the drift and diffusion matrices respectively,

$$A = \begin{bmatrix}0 & -I\\ \frac{2}{(1+\beta)\eta}(H+\lambda I) & \frac{2(1-\beta)}{(1+\beta)\eta} I\end{bmatrix}, \qquad D = \begin{bmatrix}0 & 0\\ 0 & \frac{2}{(1+\beta)\eta}\Sigma\end{bmatrix}, \tag{6}$$

$\kappa = \frac{1}{(1+\beta)S}$ is a temperature constant, and $\mu = (H + \lambda I)^{-1}b$ is the ridge regression solution. The solution to an OU process is a Gaussian process. By solving for the temporal dynamics of the first and second moments of the process, we can obtain an analytic expression for the trajectory at any time $t$. In particular, we can decompose the trajectory as the sum of a deterministic and stochastic component defined by the first and second moments respectively. In order to gain analytic expressions for the moments in terms of just the hyperparameters of optimization and the dataset, we introduce the following assumption relating the gradient noise $\Sigma$ with the Hessian $H$:

Assumption 3 (Covariance Structure).

We assume the covariance of the gradient noise is proportional to the Hessian, $\Sigma(\theta) = \sigma^2 H$, where $\sigma$ is some unknown scalar.

In fact, in the setting of linear regression, this is a very natural assumption. If we assume the classic generative model for the data $y = X\bar\theta + \epsilon$, where $\bar\theta$ is the true model and $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$, then provably $\Sigma(\theta) \approx \sigma^2 H$ near the minimum. See appendix B for a derivation and further discussion.
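This covariance structure is easy to check numerically. The sketch below (an illustration under the stated generative model, not code from the paper) draws a synthetic regression dataset, computes the per-sample gradient covariance at the least squares minimum, and compares it to $\sigma^2 H$:

```python
import numpy as np

rng = np.random.default_rng(2)
N, d, sigma = 10_000, 5, 0.3
X = rng.normal(size=(N, d))
theta_bar = rng.normal(size=d)                       # true model
y = X @ theta_bar + sigma * rng.normal(size=N)       # y = X theta_bar + eps

H = X.T @ X / N                                      # loss Hessian
theta = np.linalg.solve(X.T @ X, X.T @ y)            # least squares minimum
per_sample_grads = X * (X @ theta - y)[:, None]      # grad of 0.5*(x_i' theta - y_i)^2
Sigma = np.cov(per_sample_grads.T)                   # empirical gradient covariance
rel_err = np.linalg.norm(Sigma - sigma**2 * H) / np.linalg.norm(Sigma)
print(f"relative error of Sigma vs sigma^2 H: {rel_err:.3f}")   # small
```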

Figure 2: Oscillatory dynamics in linear regression. We train a linear network to perform regression on the CIFAR-10 dataset by using an MSE loss on the one-hot encoding of the labels. We compute the Hessian of the loss, as well as its top eigenvectors. The position and velocity trajectories are projected onto the first eigenvector of the Hessian and visualized in black. The theoretically derived mean, equation (7), is shown in red. The top and bottom panels demonstrate the effect of varying momentum on the oscillation mode.

Deterministic component. Using the form of $A$ in equation (6) we can decompose the expectation as a sum of harmonic oscillators in the eigenbasis $\{q_i\}$ of the Hessian,

$$\mathbb{E}[\theta_t] = \mu + \sum_i \big(a_i(t)\langle q_i,\ \theta_0 - \mu\rangle + b_i(t)\langle q_i,\ v_0\rangle\big)\,q_i. \tag{7}$$

Here the coefficients $a_i(t)$ and $b_i(t)$ depend on the optimization hyperparameters $\eta$, $\beta$, $\lambda$, $S$ and the respective eigenvalue $h_i$ of the Hessian, as further explained in appendix F. We verify this expression nearly perfectly matches empirics on complex datasets under various hyperparameter settings as shown in Fig. 2.

Stochastic component. The cross-covariance of the process between two points in time $t \le s$, is

$$\mathrm{Cov}(z_t, z_s) = \kappa\big(B - e^{-At}Be^{-A^\top t}\big)e^{-A^\top(s-t)}, \tag{8}$$

where $z_t = (\theta_t; v_t)$ and $B$ solves the Lyapunov equation $AB + BA^\top = 2D$. We can explicitly solve for $B$, as shown in appendix F, in terms of the optimization hyperparameters, the eigendecomposition of the Hessian, and the variance $\sigma^2$ of the intrinsic noise in the generative model.

Stationary solution. In the limit as $t \to \infty$, and assuming $H + \lambda I$ is positive definite, the process approaches a stationary solution,

$$z_\infty \sim \mathcal{N}\left(\begin{bmatrix}\mu\\ 0\end{bmatrix},\ \kappa B\right), \tag{9}$$

with stationary cross-covariance $\mathrm{Cov}(z_t, z_s) = \kappa B e^{-A^\top(s-t)}$.
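The stationary covariance $\kappa B$ can be computed directly with an off-the-shelf Lyapunov solver. A sketch, using the matrices reconstructed in equation (6) for a toy Hessian (all numerical values are assumptions for illustration):

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov

eta, beta, lam, S, sigma2 = 0.1, 0.5, 1e-3, 32, 1e-2
H = np.diag([1.0, 0.1])                    # toy Hessian
I2, c = np.eye(2), 2.0 / ((1 + beta) * eta)

A = np.block([[np.zeros((2, 2)), -I2],
              [c * (H + lam * I2), c * (1 - beta) * I2]])   # drift, eq. (6)
D = np.block([[np.zeros((2, 2)), np.zeros((2, 2))],
              [np.zeros((2, 2)), c * sigma2 * H]])          # diffusion, eq. (6)
kappa = 1.0 / ((1 + beta) * S)

B = solve_continuous_lyapunov(A, 2 * D)    # solves A B + B A^T = 2 D
print(kappa * B)                           # stationary phase-space covariance
```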

5 Understanding the Stationary Solution via the Fokker-Planck Equation

Figure 3: An anisotropic OU process is driven by a modified loss. We sample from an OU process with anisotropic diffusion and plot the trajectory (same black line on both plots). The left plot shows the original loss $\Phi$ generating the drift. The right plot shows the modified loss $\Psi$. Notice the trajectory more closely resembles the curvature of $\Psi$ than $\Phi$. The grey lines depict the stationary probability current $J_{ss}$.

The OU process is unique in that it is one of the few SDEs which we can solve exactly. As shown in section 4, we were able to derive exact expressions for the dynamics of linear regression trained with SGD from initialization to stationarity by simply solving for the first and second moments. While the expression for the first moment provides an understanding of the intricate oscillatory relationship in the deterministic component of the process, the second moment, driving the stochastic component, is much more opaque. An alternative route to solving the OU process that potentially provides more insight is the Fokker-Planck equation.

The Fokker-Planck (FP) equation is a PDE describing the time evolution for the probability distribution of a particle governed by Langevin dynamics. For an arbitrary potential $\Phi$ and diffusion matrix $D$, the Fokker-Planck equation is

$$\partial_t p = \nabla\cdot\big(p\nabla\Phi + \kappa\nabla\cdot(Dp)\big) = -\nabla\cdot J, \tag{10}$$

where $p(x, t)$ represents the time-dependent probability distribution, and $J(x, t) = -p\nabla\Phi - \kappa\nabla\cdot(Dp)$ is a vector field commonly referred to as the probability current. The FP equation is especially useful for explicitly solving for the stationary solution, assuming one exists, of the Langevin dynamics. The stationary solution is the distribution $p_{ss}$ such that $\partial_t p_{ss} = 0$, or equivalently $\nabla\cdot J_{ss} = 0$. From this second definition we see that there are two distinct settings of stationarity: detailed balance when $J_{ss} = 0$ everywhere, or broken detailed balance when $\nabla\cdot J_{ss} = 0$ and $J_{ss} \neq 0$.

For a general OU process, the potential is a convex quadratic function defined by the drift matrix $A$. When the diffusion matrix is isotropic ($D \propto I$) and spatially independent ($\nabla\cdot D = 0$), the resulting stationary solution is a Gibbs distribution $p_{ss} \propto e^{-\Phi/\kappa}$ determined by the original loss and is in detailed balance. Lesser known, somewhat surprising, properties of the OU process arise when the diffusion matrix is anisotropic or spatially dependent. In this setting the solution is still a Gaussian process, but the stationary solution, if it exists, is no longer defined by the Gibbs distribution of the original loss $\Phi$, but actually of a modified loss $\Psi$. Furthermore, the stationary solution may be in broken detailed balance, leading to a non-zero probability current $J_{ss}$. Depending on the relationship between the drift matrix $A$ and the diffusion matrix $D$, the resulting dynamics of the OU process can have very nontrivial behavior, as shown in Fig. 3.

In the setting of linear regression, anisotropy in the data distribution will lead to anisotropy in the gradient noise and thus an anisotropic diffusion matrix. This implies that for most datasets we should expect that the SGD trajectory is not driven by the original least squares loss, but by a modified loss, and converges to a stationary solution with broken detailed balance. Using the explicit expressions for the drift and diffusion matrices in equation (6), we can express analytically the modified loss and stationary probability current,

$$\Psi(\theta, v) = \frac{1}{2}\begin{bmatrix}\theta-\mu\\ v\end{bmatrix}^\top U\begin{bmatrix}\theta-\mu\\ v\end{bmatrix}, \qquad J_{ss}(\theta, v) = -QU\begin{bmatrix}\theta-\mu\\ v\end{bmatrix}p_{ss}(\theta, v), \tag{11}$$

where $Q$ is a skew-symmetric matrix and $U$ is a positive definite matrix defined as,

$$Q = \begin{bmatrix}0 & -\frac{\sigma^2}{1-\beta}H\\ \frac{\sigma^2}{1-\beta}H & 0\end{bmatrix}, \qquad U = \begin{bmatrix}\frac{2(1-\beta)}{(1+\beta)\eta\sigma^2}H^{-1}(H+\lambda I) & 0\\ 0 & \frac{1-\beta}{\sigma^2}H^{-1}\end{bmatrix}. \tag{12}$$

These new fundamental matrices, $Q$ and $U$, relate to the original drift and diffusion matrices through the unique decomposition $A = (D + Q)U$, first introduced by Kwon et al. [19]. Using this decomposition we can easily show that $B = U^{-1}$ solves the Lyapunov equation and indeed the stationary solution described in equation (9) is the Gibbs distribution $p_{ss} \propto e^{-\Psi/\kappa}$ defined by the modified loss in equation (11). Further, the stationary cross-covariance solved in section 4 reflects the oscillatory dynamics introduced by the stationary probability currents in equation (11). Taken together, we gain the intuition that the limiting dynamics of SGD in linear regression are driven by a modified loss subject to oscillatory probability currents.
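Numerically, $Q$ and $U$ fall out of the Lyapunov solution itself: since $B = U^{-1}$, we have $Q = AU^{-1} - D$, which is skew-symmetric whenever $B$ solves the Lyapunov equation. A short sketch on the toy system from the previous snippet (all values are illustrative assumptions):

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov

eta, beta, lam, S, sigma2 = 0.1, 0.5, 1e-3, 32, 1e-2
H = np.diag([1.0, 0.1])
I2, c = np.eye(2), 2.0 / ((1 + beta) * eta)
A = np.block([[np.zeros((2, 2)), -I2],
              [c * (H + lam * I2), c * (1 - beta) * I2]])
D = np.block([[np.zeros((2, 2)), np.zeros((2, 2))],
              [np.zeros((2, 2)), c * sigma2 * H]])
kappa = 1.0 / ((1 + beta) * S)

B = solve_continuous_lyapunov(A, 2 * D)    # B = U^{-1}
U = np.linalg.inv(B)
Q = A @ B - D                              # A = (D + Q) U  =>  Q = A U^{-1} - D
assert np.allclose(Q, -Q.T, atol=1e-6)     # skew-symmetric, as claimed

z = np.array([0.3, -0.2, 0.1, 0.0])        # phase-space point (theta - mu, v)
psi = 0.5 * z @ U @ z                      # modified loss Psi, eq. (11)
J = -(Q @ U @ z) * np.exp(-psi / kappa)    # stationary probability current
print(psi, J)
```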

6 Qualitative Evidence of a Modified Loss and Phase Space Oscillations

Does the theory derived in the linear regression setting (sections 3, 4 and 5) help explain the empirical phenomena observed in the non-linear setting of deep neural networks (section 2)? In order for the theory built in the previous sections to apply to the limiting dynamics of neural networks, we must introduce an additional assumption on the loss landscape at the end of training:

Assumption 4 (Quadratic Loss).

We assume that at the end of training the loss for a neural network can be approximated by the quadratic loss $f(\theta) = \frac{1}{2}(\theta - \mu)^\top H(\theta - \mu)$, where $H$ is the training loss Hessian and $\mu$ is some unknown mean vector, corresponding to a local minimum.

Incorporating this assumption with the previous three (again, we assume the covariance of the gradient noise is proportional to the Hessian of the quadratic loss), the expressions we derived in the linear regression setting apply directly. Of course, both of these assumptions are quite strong, but without mathematically proving their theoretical validity, we can empirically test their qualitative implications: (1) a modified isotropic loss driving the limiting dynamics through parameter space, (2) implicit regularization of the velocity trajectory, and (3) oscillatory dynamics in phase space determined by the eigendecomposition of the Hessian.

Modified loss. As discussed in section 5, due to the anisotropy of the diffusion matrix, the loss landscape driving the dynamics at the end of training is not the original training loss $f(\theta)$, but a modified loss $\Psi(\theta, v)$ in phase space. As shown in equation (11), the modified loss decouples into a term $\Psi_\theta$ that only depends on the parameters and a term $\Psi_v$ that only depends on the velocities. The parameter dependent component is proportional to the convex quadratic,

$$\Psi_\theta(\theta) \propto (\theta - \mu)^\top\big(I + \lambda H^{-1}\big)(\theta - \mu). \tag{13}$$

This quadratic function has the same mean $\mu$ as the training loss, but a different curvature. Using this expression, notice that when $\lambda = 0$, the modified loss is isotropic in the column space of $H$, regardless of what the nonzero eigenspectrum of $H$ is. This striking prediction suggests that no matter how anisotropic the original training loss – as reflected by poor conditioning of the Hessian eigenspectrum – the training trajectory of the network will behave isotropically, since it is driven not by the original anisotropic loss, but by a modified isotropic loss.

Figure 4: The training trajectory behaves isotropically, regardless of the training loss. We resume training of a pre-trained ResNet-18 model on ImageNet and project its parameter trajectory (black line) onto the plane spanned by the top and bottom eigenvectors, $q_1$ and $q_{30}$, of its pre-trained Hessian (with eigenvalue ratio $h_1/h_{30}$). We sample the training and test loss within the same 2D subspace and visualize them as heatmaps in the left and center panels respectively. We visualize the modified loss, computed from the eigenvalues ($h_1$, $h_{30}$) and optimization hyperparameters according to equation (13), in the right plot. Notice how the projected trajectory is isotropic, despite the anisotropy of the training and test loss. As predicted, the limiting dynamics are driven by the isotropic modified loss.

We test this prediction by studying the limiting dynamics of a pre-trained ResNet-18 model with batch normalization that we continue to train on ImageNet according to the last setting of its hyperparameters [8]. Let $\theta_0$ represent the initial pre-trained parameters of the network, depicted with the white dot in figures 4 and 5. We estimate the top thirty eigenvectors $q_1, \dots, q_{30}$ of the Hessian matrix evaluated at $\theta_0$ (to estimate the eigenvectors of $H$ we use subspace iteration; see appendix H for details) and project the limiting trajectory for the parameters onto the plane spanned by the top and bottom eigenvectors, $q_1$ and $q_{30}$. We sample the train and test loss in this subspace for a region around the projected trajectory. Additionally, using the hyperparameters of the optimization, the eigenvalues $h_1$ and $h_{30}$, and the estimate for the mean $\mu = \theta_0 - H^{-1}g(\theta_0)$ (where $g(\theta_0)$ is the gradient evaluated at $\theta_0$), we also sample from the modified loss equation (13) in the same region. Figure 4 shows the projected parameter trajectory on the sampled train, test and modified losses. Contour lines of both the train and test loss exhibit anisotropic structure, with sharper curvature along eigenvector $q_1$ compared to eigenvector $q_{30}$, as expected. However, as predicted, the trajectory appears to cover both directions equally. This striking isotropy of the trajectory within a highly anisotropic slice of the loss landscape is exactly predicted by our theory, which reveals that the trajectory evolves in a modified isotropic loss landscape.
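The subspace iteration mentioned above can be implemented with Hessian-vector products via double backpropagation. The sketch below is an assumption of how such an estimator might look (not the authors' implementation), shown on a toy quadratic:

```python
import torch

def top_hessian_eigvecs(loss_fn, params, k=30, iters=20):
    """Estimate the top-k Hessian eigenvectors of loss_fn at params
    via subspace iteration with Hessian-vector products."""
    n = params.numel()
    V = torch.linalg.qr(torch.randn(n, k))[0]      # random orthonormal basis
    for _ in range(iters):
        loss = loss_fn(params)
        (g,) = torch.autograd.grad(loss, params, create_graph=True)
        HV = torch.stack([
            torch.autograd.grad(g, params,
                                grad_outputs=V[:, j].view_as(params),
                                retain_graph=True)[0].reshape(-1)
            for j in range(k)
        ], dim=1)                                  # Hessian-vector products H V
        V = torch.linalg.qr(HV)[0]                 # re-orthonormalize
    return V

# Example on a toy quadratic 0.5 * x^T H x, whose Hessian is H:
H = torch.diag(torch.linspace(1.0, 0.1, 10))
x = torch.zeros(10, requires_grad=True)
V = top_hessian_eigvecs(lambda p: 0.5 * p @ H @ p, x, k=3)
print(V[:3, :])   # columns align with the leading coordinate directions
```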

Figure 5: Implicit velocity regularization defined by the inverse Hessian. The shape of the projected velocity trajectory closely resembles the contours of the modified loss equation (14).

Implicit velocity regularization. A second qualitative prediction of the theory is that the velocity is regularized by the inverse Hessian of the training loss. Of course there are no explicit terms in either the train or test losses that depend on the velocity. Yet, the modified loss contains a component, $\Psi_v(v)$, that only depends on the velocities,

$$\Psi_v(v) \propto v^\top H^{-1} v. \tag{14}$$

This additional term can be understood as a form of implicit regularization on the velocity trajectory. Indeed, when we project the velocity trajectory onto the plane spanned by the $q_1$ and $q_{30}$ eigenvectors, as shown in Fig. 5, we see that the trajectory closely resembles the curvature of the inverse Hessian $H^{-1}$. The modified loss effectively penalizes SGD for moving along eigenvectors of the Hessian with small eigenvalues. A similar qualitative effect was recently proposed by Barrett and Dherin [2] as a consequence of the discretization error due to finite learning rates.

Figure 6: Phase space oscillations are determined by the eigendecomposition of the Hessian. We visualize the projected position and velocity trajectories in phase space. The top and bottom panels show the projections onto $q_1$ and $q_{30}$ respectively. As predicted, oscillations at different rates are distinguishable for the different eigenvectors.

Phase space oscillations. A final implication of the theory is that at stationarity the network is in broken detailed balance, leading to non-zero probability currents flowing through phase space:

$$J_{ss}(\theta, v) = \begin{bmatrix} v\\ -\frac{2}{(1+\beta)\eta}(H + \lambda I)(\theta - \mu)\end{bmatrix}p_{ss}(\theta, v). \tag{15}$$

These probability currents generate oscillatory dynamics in the phase space planes characterized by the eigenvectors of the Hessian, at rates proportional to their eigenvalues. We consider the same projected trajectory of the ResNet-18 model visualized in figures 4 and 5, but plot the trajectory in phase space for the two eigenvectors $q_1$ and $q_{30}$ separately. Shown in Fig. 6, we see that both trajectories look like noisy clockwise rotations. As predicted, the spiral is more evident and at a faster rate for the top eigenvector $q_1$ than for the bottom eigenvector $q_{30}$.

The integral curves of the stationary probability current are one-dimensional paths confined to level sets of the modified loss. These paths might close on themselves, in which case they are limit cycles, or they could cover the entire surface of the level sets, in which case they are space-filling curves. This distinction depends on the relative frequencies of the oscillations, as determined by the pairwise ratios of the eigenvalues of the Hessian. For real-world datasets, with a large spectrum of incommensurate frequencies, we should expect to be in the latter setting, which deepens the insight provided by Chaudhari and Soatto [4].

7 Predicting the Diffusive Nature of Neural Networks at the End of Training

Taken together the empirical results shown in section 6 indicate that many of the same qualitative behaviors of SGD identified theoretically for linear regression are evident in the limiting dynamics of neural networks. But what does this alignment of theory and empirics imply for the learning dynamics? Can this theory quantitatively explain the results we identified in section 2?

Constant instantaneous speed. As noted in section 2, we observed that at the end of training, across various architectures, the squared norm of the local displacement remains constant. Assuming the limiting dynamics are described by the stationary solution in equation (9), the expectation of the local displacement is

$$\mathbb{E}[\delta_t] = \mathbb{E}\big[\|\theta_t - \theta_{t-1}\|_2^2\big] = \frac{\eta^2\sigma^2}{S(1-\beta^2)}\,\mathrm{tr}(H), \tag{16}$$

as derived in appendix G. We cannot test this prediction directly, as we do not know $\sigma^2$ and computing $\mathrm{tr}(H)$ is computationally prohibitive. However, we can estimate the product $\sigma^2\,\mathrm{tr}(H)$ by resuming training for a model, measuring the average $\delta_t$, and then inverting equation (16). Using this single estimate, we find that for a sweep of models with varying hyperparameters, equation (16) accurately predicts their instantaneous speed. Indeed, Fig. 7 shows a close match between the empirics and theory, which strongly suggests that despite changing hyperparameters at the end of training, the model remains in the same quadratic basin.
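Concretely, the estimation procedure amounts to a one-line inversion of equation (16). The snippet below illustrates it with a hypothetical measured displacement and hyperparameter values (all numbers are made up for illustration, and the constant in equation (16) comes from our reconstruction above):

```python
# Invert eq. (16) once to estimate sigma^2 * tr(H), then predict delta for
# a new hyperparameter setting. All numbers here are hypothetical.
eta, beta, S = 0.1, 0.9, 256
delta_bar = 3.2e-5                      # hypothetical measured average delta_t

sigma2_trH = delta_bar * S * (1 - beta**2) / eta**2   # invert eq. (16)

eta2, beta2, S2 = 0.05, 0.99, 128       # a different hyperparameter setting
delta_pred = eta2**2 * sigma2_trH / (S2 * (1 - beta2**2))
print(sigma2_trH, delta_pred)
```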

Figure 7: Understanding how the hyperparameters of optimization influence the diffusion. We resume training of pre-trained ResNet-18 models on ImageNet using a range of different learning rates, batch sizes, and momentum coefficients, tracking $\delta_t$ and $\Delta_t$. Starting from the default hyperparameter setting, we vary each hyperparameter while keeping the others fixed. The top row shows the measured $\delta_t$ in color, with the default hyperparameter setting highlighted in black. The dotted lines depict the value predicted by equation (16). The bottom row shows the estimated exponent $c$ found by fitting a line to the $\Delta_t$ trajectories in a log-log plot.

Exponent of anomalous diffusion. The expected value for the global displacement under the stationary solution can also be analytically expressed in terms of the optimization hyperparameters and the eigendecomposition of the Hessian as,

$$\mathbb{E}[\Delta_t] = t\,\mathbb{E}[\delta] + 2\eta^2\sum_{\tau=1}^{t-1}(t-\tau)\sum_i \frac{\sigma^2 h_i}{S(1-\beta^2)}\,\varphi_{\zeta_i}(\tau), \tag{17}$$

where $\varphi_{\zeta_i}$ is a trigonometric function describing the velocity of a harmonic oscillator with damping ratio $\zeta_i$, see appendix G for details. As shown empirically in section 2, the squared norm $\Delta_t$ monotonically increases as a power law in the number of steps, suggesting its expectation is proportional to $t^c$ for some unknown, constant exponent $c$. The exponent $c$ determines the regime of diffusion for the process. When $c = 1$, the process corresponds to standard Brownian diffusion. For $c > 1$ or $c < 1$ the process corresponds to anomalous super-diffusion or sub-diffusion respectively. Unfortunately, it is not immediately clear how to extract the explicit exponent $c$ from equation (17). However, by exploring the functional form of $\varphi_{\zeta}$ and its relationship to the hyperparameters of optimization through the damping ratio $\zeta$, we can predict overall trends in the diffusion exponent $c$.

Akin to how the exponent $c$ determines the regime of diffusion, the damping ratio $\zeta_i$ determines the regime for the harmonic oscillator describing the stationary velocity-velocity correlation in the $i$-th eigenvector of the Hessian. When $\zeta_i = 1$, the oscillator is critically damped, implying the velocity correlations converge to zero as quickly as possible. In the extreme setting of $\zeta_i = 1$ for all $i$, equation (17) simplifies to standard Brownian diffusion, $\mathbb{E}[\Delta_t] \propto t$. When $\zeta_i > 1$, the oscillator is overdamped, implying the velocity correlations dampen slowly and remain positive even over long temporal ranges. In the extreme setting of $\zeta_i \to \infty$ for all $i$, equation (17) simplifies to a form of anomalous super-diffusion, $\mathbb{E}[\Delta_t] \propto t^2$. When $\zeta_i < 1$, the oscillator is underdamped, implying the velocity correlations will oscillate between positive and negative values. Indeed, the only way equation (17) could describe anomalous sub-diffusion is if $\varphi_{\zeta_i}(\tau)$ took on negative values for certain $\tau$.

Using the same sweep of models described above, we can empirically confirm that the optimization hyperparameters each influence the diffusion exponent $c$. As shown in Fig. 7, the learning rate, batch size, and momentum can each independently drive the exponent $c$ into different regimes of anomalous diffusion. Notice how the influence of the learning rate and momentum on the diffusion exponent closely resembles their respective influences on the damping ratio $\zeta$. Contrary to intuition, a larger learning rate leads to underdamped oscillations that reduce the exponent of diffusion. The batch size, on the other hand, has no influence on the damping ratio, but has an interesting, non-monotonic influence on the diffusion exponent.
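These trends can be reproduced in a toy simulation of the discrete update (2) on a quadratic loss with Hessian-aligned gradient noise. The sketch below (a toy setting, not the paper's ImageNet experiment) fits the diffusion exponent $c$ for a few learning rates; larger $\eta$ pushes the oscillators underdamped and lowers $c$:

```python
import numpy as np

rng = np.random.default_rng(3)
h = np.logspace(-2, 0, 100)          # toy diagonal Hessian spectrum
beta, S, sigma = 0.9, 32, 0.1        # assumed hyperparameters

for eta in [0.01, 0.1, 1.0]:
    theta, v = np.zeros_like(h), np.zeros_like(h)
    theta0, Delta = theta.copy(), []
    for t in range(20_000):
        # gradient noise with covariance sigma^2 * H / S (Assumption 3)
        noise = sigma * np.sqrt(h / S) * rng.normal(size=h.size)
        v = beta * v - eta * (h * theta + noise)   # update (2), lambda = 0
        theta = theta + v
        Delta.append(np.sum((theta - theta0) ** 2))
    t_axis = np.arange(2_000, 20_000)              # fit the late, stationary phase
    c = np.polyfit(np.log(t_axis), np.log(np.array(Delta)[t_axis]), 1)[0]
    print(f"eta={eta}: fitted exponent c = {c:.2f}")
```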

8 Related Work

Here we contextualize our analysis and experiments among relevant works studying the limiting dynamics of SGD in deep learning.

Continuous models for SGD. Many works have considered how to improve the classic gradient flow model for SGD to more realistically reflect the heavy-ball effect of momentum [30], the discretization effect due to finite learning rates [18, 2], and the stochasticity of random batches [22, 34]. However, recently a series of works have questioned the validity of using an SDE to model SGD. The main argument, as nicely explained in Yaida [38], is that most SDE approximations simultaneously assume that $\eta \to 0$, while maintaining that the square root of the learning rate $\sqrt{\eta}$ is finite. Other works have questioned the correctness of using the central limit theorem to model the gradient noise as Gaussian [23], arguing that the weak dependence between batches and heavy-tailed structure in the gradient noise lead the CLT to break down. In our work, we maintain that the CLT assumption holds, which we discuss further in appendix A, but importantly avoid the pitfalls of many previous SDE approximations by simultaneously modeling the effect of finite learning rates and stochasticity.

Limiting dynamics. A series of works have applied SDE models of SGD to study the limiting dynamics of neural networks. In the seminal work by Mandt et al. [25], the limiting dynamics were modeled with a multivariate Ornstein-Uhlenbeck process by combining the naive first-order SDE model for SGD with assumptions on the geometry of the loss and covariance matrix for the gradient noise. This analysis was extended by Jastrzębski et al. [11] through additional assumptions on the covariance matrix to gain tractable insights and applied by Ali et al. [1] to the simpler setting of linear regression where the loss is definitionally quadratic. A different approach was taken by Chaudhari and Soatto [4], which did not formulate the dynamics as an OU process, nor assume directly a structure on the loss or gradient noise. Rather, this analysis studied the same first-order SDE via the Fokker-Planck equation to propose the existence of a modified loss and probability currents driving the limiting dynamics, but did not provide explicit expressions. Our analysis deepens and combines ideas from all these works, where our key insight is to lift the dynamics into phase space. By studying the dynamics of the parameters and their velocities, and by applying the analysis first in the setting of linear regression where assumptions are provably true, we are able to identify analytic expressions and explicit insights which lead to concrete predictions and testable hypotheses.

Empirical exploration. Another set of works analyzing the limiting dynamics of SGD has taken a purely empirical approach. Building on the intuition that flat minima generalize better than sharp minima, Keskar et al. [15] demonstrated empirically that the hyperparameters of optimization influence the eigenvalue spectrum of the Hessian matrix at the end of training. Many subsequent works have thoroughly studied the eigendecomposition of the Hessian during and at the end of training. Jastrzębski et al. [12] and Cohen et al. [5] studied the dynamics of the top eigenvalues during training, Papyan [27] demonstrated the spectrum is composed of a bulk of values near zero and a smaller number of important spike values, and Gur-Ari et al. [7] demonstrated that the learning dynamics are constrained to the subspace spanned by the top eigenvectors, but found no special properties of the dynamics within this subspace. In our work we also determine that the top eigensubspace of the Hessian plays a crucial role in the limiting dynamics and by projecting the dynamics into this subspace in phase space, we see that the motion is not random, but rather oscillatory. Furthermore, by measuring simple metrics of motion, we identified that, at the end of training, the network is moving at constant instantaneous speed and anomalously diffusing through parameter space.

Hyperparameter schedules and algorithm development. Lastly, a set of works have used theoretical and empirical insights of the limiting dynamics to construct hyperparameter schedules and algorithms to improve performance. Most famously, the linear scaling rule, derived by Krizhevsky [17] and Goyal et al. [6], relates the influence of the batch size and the learning rate, facilitating stable training with increased parallelism. This relationship was extended by Smith et al. [35] to account for the effect of momentum. Lucas et al. [24] proposed a variant of SGD with multiple velocity buffers to increase stability in the presence of a nonuniform Hessian spectrum. Chaudhari et al. [3] introduced a variant of SGD guided by local entropy to bias the dynamics into wider minima. Izmailov et al. [9] demonstrated how a simple algorithm of stochastically averaging samples from the limiting dynamics of a network can improve generalization performance. While algorithm development is not the focus of our work, we believe that our careful and precise understanding of the deep learning limiting dynamics will similarly provide insight for future work in this direction.

9 Discussion

In this work, through empirical exploration and theoretical tools from statistical physics, we uncovered an intricate interaction between the hyperparameters of optimization, the structure in the gradient noise, and the Hessian matrix at the end of training. In many ways this empirical and theoretical understanding challenges us to rethink the limiting dynamics of deep neural networks. Natural intuitions, such as "the network converges in parameter space" or "the network stays within a local region", are wrong. The expectation that the training trajectory would explore the underlying anisotropy of the training loss driving the dynamics is also wrong. Even the belief that the limiting dynamics are simply diffusing in detailed balance is wrong. So how should we interpret these theoretical insights? Should we consider the stochasticity in SGD as a form of adaptivity, removing pathological curvature by introducing a modified isotropic loss? Or should we consider this effect a form of implicit regularization, explaining the generalization ability of these networks despite massive overparameterization? Or maybe we should take a Bayesian perspective and consider the existence of probability currents as a mechanism for efficient posterior sampling? Our careful and precise understanding of the limiting dynamics of SGD will provide many directions for future work and analysis. By leveraging the mathematical structures found at the end of training, this work challenges us to rethink our understanding of deep learning dynamics.

Acknowledgments

We thank Jing An, Pratik Chaudhari, Ekdeep Singh, Ben Sorscher and Atsushi Yamamura for helpful discussions. This work was funded in part by the IBM-Watson AI Lab. D.K. thanks the Stanford Data Science Scholars program and NTT Research for support. J.S. thanks the Mexican National Council of Science and Technology (CONACYT) for support. S.G. thanks the James S. McDonnell and Simons Foundations, NTT Research, and an NSF CAREER Award for support while at Stanford. D.L.K.Y thanks the McDonnell Foundation (Understanding Human Cognition Award Grant No. 220020469), the Simons Foundation (Collaboration on the Global Brain Grant No. 543061), the Sloan Foundation (Fellowship FG-2018-10963), the National Science Foundation (RI 1703161 and CAREER Award 1844724), and the DARPA Machine Common Sense program for support and the NVIDIA Corporation for hardware donations.

Contributions

D.K. developed the theory and wrote the manuscript. J.S. ran the experiments and edited the manuscript. L.G. and E.M. ran initial experiments and edited the manuscript. H.T., S.G., and D.L.K.Y. advised throughout the work and edited the manuscript.

References

  • [1] A. Ali, E. Dobriban, and R. Tibshirani (2020) The implicit regularization of stochastic gradient flow for least squares. In International Conference on Machine Learning, pp. 233–244.
  • [2] D. G. Barrett and B. Dherin (2020) Implicit gradient regularization. arXiv preprint arXiv:2009.11162.
  • [3] P. Chaudhari, A. Choromanska, S. Soatto, Y. LeCun, C. Baldassi, C. Borgs, J. Chayes, L. Sagun, and R. Zecchina (2019) Entropy-SGD: biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment 2019 (12), pp. 124018.
  • [4] P. Chaudhari and S. Soatto (2018) Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks. In 2018 Information Theory and Applications Workshop (ITA), pp. 1–10.
  • [5] J. M. Cohen, S. Kaur, Y. Li, J. Z. Kolter, and A. Talwalkar (2021) Gradient descent on neural networks typically occurs at the edge of stability. arXiv preprint arXiv:2103.00065.
  • [6] P. Goyal, P. Dollár, R. Girshick, P. Noordhuis, L. Wesolowski, A. Kyrola, A. Tulloch, Y. Jia, and K. He (2017) Accurate, large minibatch SGD: training ImageNet in 1 hour. arXiv preprint arXiv:1706.02677.
  • [7] G. Gur-Ari, D. A. Roberts, and E. Dyer (2018) Gradient descent happens in a tiny subspace. arXiv preprint arXiv:1812.04754.
  • [8] K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778.
  • [9] P. Izmailov, D. Podoprikhin, T. Garipov, D. Vetrov, and A. G. Wilson (2018) Averaging weights leads to wider optima and better generalization. arXiv preprint arXiv:1803.05407.
  • [10] A. Jacot, F. Gabriel, and C. Hongler (2018) Neural tangent kernel: convergence and generalization in neural networks. arXiv preprint arXiv:1806.07572.
  • [11] S. Jastrzębski, Z. Kenton, D. Arpit, N. Ballas, A. Fischer, Y. Bengio, and A. Storkey (2017) Three factors influencing minima in SGD. arXiv preprint arXiv:1711.04623.
  • [12] S. Jastrzębski, Z. Kenton, N. Ballas, A. Fischer, Y. Bengio, and A. Storkey (2018) On the relation between the sharpest directions of DNN loss and the SGD step length. arXiv preprint arXiv:1807.05031.
  • [13] R. Jordan, D. Kinderlehrer, and F. Otto (1997) Free energy and the Fokker-Planck equation. Physica D: Nonlinear Phenomena 107 (2-4), pp. 265–271.
  • [14] R. Jordan, D. Kinderlehrer, and F. Otto (1998) The variational formulation of the Fokker–Planck equation. SIAM Journal on Mathematical Analysis 29 (1), pp. 1–17.
  • [15] N. S. Keskar, D. Mudigere, J. Nocedal, M. Smelyanskiy, and P. T. P. Tang (2016) On large-batch training for deep learning: generalization gap and sharp minima. arXiv preprint arXiv:1609.04836.
  • [16] N. B. Kovachki and A. M. Stuart (2019) Analysis of momentum methods. arXiv preprint arXiv:1906.04285.
  • [17] A. Krizhevsky (2014) One weird trick for parallelizing convolutional neural networks. arXiv preprint arXiv:1404.5997.
  • [18] D. Kunin, J. Sagastuy-Brena, S. Ganguli, D. L. Yamins, and H. Tanaka (2020) Neural mechanics: symmetry and broken conservation laws in deep learning dynamics. arXiv preprint arXiv:2012.04728.
  • [19] C. Kwon, P. Ao, and D. J. Thouless (2005) Structure of stochastic dynamics near fixed points. Proceedings of the National Academy of Sciences 102 (37), pp. 13029–13033.
  • [20] J. Lee, Y. Bahri, R. Novak, S. S. Schoenholz, J. Pennington, and J. Sohl-Dickstein (2017) Deep neural networks as Gaussian processes. arXiv preprint arXiv:1711.00165.
  • [21] J. Lee, L. Xiao, S. S. Schoenholz, Y. Bahri, R. Novak, J. Sohl-Dickstein, and J. Pennington (2019) Wide neural networks of any depth evolve as linear models under gradient descent. arXiv preprint arXiv:1902.06720.
  • [22] Q. Li, C. Tai, and E. Weinan (2017) Stochastic modified equations and adaptive stochastic gradient algorithms. In International Conference on Machine Learning, pp. 2101–2110.
  • [23] Z. Li, S. Malladi, and S. Arora (2021) On the validity of modeling SGD with stochastic differential equations (SDEs). arXiv preprint arXiv:2102.12470.
  • [24] J. Lucas, S. Sun, R. Zemel, and R. Grosse (2018) Aggregated momentum: stability through passive damping. arXiv preprint arXiv:1804.00325.
  • [25] S. Mandt, M. Hoffman, and D. Blei (2016) A variational analysis of stochastic gradient algorithms. In International Conference on Machine Learning, pp. 354–363.
  • [26] R. M. Neal (1996) Priors for infinite networks. In Bayesian Learning for Neural Networks, pp. 29–53.
  • [27] V. Papyan (2018) The full spectrum of deepnet Hessians at scale: dynamics with SGD training and sample size. arXiv preprint arXiv:1811.07062.
  • [28] A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison, L. Antiga, and A. Lerer (2017) Automatic differentiation in PyTorch. Neural Information Processing Systems Workshop.
  • [29] T. Poggio, K. Kawaguchi, Q. Liao, B. Miranda, L. Rosasco, X. Boix, J. Hidary, and H. Mhaskar (2017) Theory of deep learning III: explaining the non-overfitting puzzle. arXiv preprint arXiv:1801.00173.
  • [30] N. Qian (1999) On the momentum term in gradient descent learning algorithms. Neural Networks 12 (1), pp. 145–151.
  • [31] L. Sagun, L. Bottou, and Y. LeCun (2016) Eigenvalues of the Hessian in deep learning: singularity and beyond. arXiv preprint arXiv:1611.07476.
  • [32] S. S. Schoenholz, J. Gilmer, S. Ganguli, and J. Sohl-Dickstein (2016) Deep information propagation. arXiv preprint arXiv:1611.01232.
  • [33] U. Simsekli, L. Sagun, and M. Gurbuzbalaban (2019) A tail-index analysis of stochastic gradient noise in deep neural networks. In International Conference on Machine Learning, pp. 5827–5837.
  • [34] S. L. Smith, B. Dherin, D. G. Barrett, and S. De (2021) On the origin of implicit regularization in stochastic gradient descent. arXiv preprint arXiv:2101.12176.
  • [35] S. L. Smith, P. Kindermans, C. Ying, and Q. V. Le (2017) Don't decay the learning rate, increase the batch size. arXiv preprint arXiv:1711.00489.
  • [36] M. Song, A. Montanari, and P. Nguyen (2018) A mean field view of the landscape of two-layers neural networks. Proceedings of the National Academy of Sciences 115, pp. E7665–E7671.
  • [37] V. Thomas, F. Pedregosa, B. Merriënboer, P. Manzagol, Y. Bengio, and N. Le Roux (2020) On the interplay between noise and curvature and its effect on optimization and generalization. In International Conference on Artificial Intelligence and Statistics, pp. 3503–3513.
  • [38] S. Yaida (2018) Fluctuation-dissipation relations for stochastic gradient descent. arXiv preprint arXiv:1810.00004.

Appendix A Modeling SGD with an SDE

As explained in section 3, in order to understand the dynamics of stochastic gradient descent we build a continuous Langevin equation in phase space modeling the effect of discrete updates and stochastic batches simultaneously.

A.1 Modeling Discretization

To model the discretization effect we assume that the system of update equations (2) is actually a discretization of some unknown ordinary differential equation. To uncover this ODE, we combine the two update equations in (2), by incorporating a previous time step $\theta_{t-1}$, and rearrange into the form of a finite difference discretization, as shown in equation (3). Like all discretizations, the Euler discretizations introduce error terms proportional to the step size, which in this case is the learning rate $\eta$. Taylor expanding $\theta_{t+1}$ and $\theta_{t-1}$ around $\theta_t$, it is easy to show that both Euler discretizations introduce a second-order error term proportional to $\frac{\eta}{2}\ddot\theta$.

Notice how the momentum coefficient $\beta$ regulates the amount of backward Euler incorporated into the discretization. When $\beta = 0$, we remove all backward Euler discretization, leaving just the forward Euler discretization. When $\beta = 1$, we have equal amounts of backward Euler as forward Euler, resulting in a central second-order discretization (the difference between a forward Euler and backward Euler discretization is a second-order central discretization, $\frac{\theta_{t+1} - 2\theta_t + \theta_{t-1}}{\eta} \approx \eta\ddot\theta$), as noticed in [30].

A.2 Modeling Stochasticity

In order to model the effect of stochastic batches, we first model a batch gradient with the following assumption:

Assumption 1 (Central Limit Theorem). We assume the batch gradient is a noisy version of the true gradient such that $g_{\mathcal{B}}(\theta)$ is a Gaussian random variable with mean $g(\theta)$ and covariance $\frac{1}{S}\Sigma(\theta)$.

The two conditions needed for the CLT to hold are not exactly met in the setting of SGD. Independent and identically distributed: generally we perform SGD by making a complete pass through the entire dataset before using a sample again, which introduces a weak dependence between samples. While the covariance matrix for sampling without replacement more accurately models the dependence between samples within a batch, it fails to account for the dependence between batches. Finite variance: a different line of work has questioned the Gaussian assumption entirely because of the need for finite variance random variables. This work instead suggests using the generalized central limit theorem, implying the noise would be a heavy-tailed $\alpha$-stable random variable [33]. Thus, the previous assumption implicitly assumes the i.i.d. and finite variance conditions apply for large enough datasets and small enough batches.

Under the CLT assumption, we must also replace the Euler discretizations with Euler–Maruyama discretizations. For a general stochastic process, $dX_t = b(X_t)\,dt + \sigma(X_t)\,dW_t$, the Euler–Maruyama method extends the Euler method for ODEs to SDEs, resulting in the update equation $X_{t+\Delta t} = X_t + \Delta t\, b(X_t) + \sqrt{\Delta t}\,\sigma(X_t)\,\xi_t$, where $\xi_t \sim \mathcal{N}(0, I)$. Notice, the key difference is that if the temporal step size is $\Delta t$, then the noise is scaled by the square root $\sqrt{\Delta t}$. In fact, the main argument against modeling SGD with an SDE, as nicely explained in Yaida [38], is that most SDE approximations simultaneously assume that $\eta \to 0$, while maintaining that the square root of the learning rate $\sqrt{\eta}$ is finite. However, by modeling the discretization and stochastic effect simultaneously we can avoid this argument, bringing us to our second assumption:

Assumption 2 (Stochastic Differential Equation). We assume the underdamped Langevin equation (4) accurately models the trajectory of the network driven by SGD through phase space.

This approach of modeling discretization and stochasticity simultaneously is called stochastic modified equations, as further explained in Li et al. [22].
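For concreteness, here is a minimal Euler–Maruyama step for the scalar OU process $dX = -X\,dt + \sigma\,dW$, highlighting the $\sqrt{\Delta t}$ scaling of the noise that distinguishes it from a naive Euler step (a generic illustration, not code from the paper):

```python
import numpy as np

rng = np.random.default_rng(4)
dt, sigma, X = 0.01, 0.5, 1.0
for _ in range(1_000):
    X += -X * dt + sigma * np.sqrt(dt) * rng.normal()   # noise enters at sqrt(dt)
print(X)   # fluctuates around 0 with stationary variance sigma^2 / 2
```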

Appendix B Structure in the Covariance of the Gradient Noise

As we’ve mentioned before, SGD introduces highly structured noise into an optimization process, often assumed to be an essential ingredient for its ability to avoid local minima.

Assumption 3 (Covariance Structure). We assume the covariance of the gradient noise is proportional to the Hessian of the quadratic loss, $\Sigma(\theta) = \sigma^2 H$, where $\sigma$ is some unknown scalar.

In the setting of linear regression, this is a very natural assumption. If we assume the classic generative model for linear regression data, $y = X\bar\theta + \epsilon$, where $\bar\theta$ is the true model and $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$, then provably $\Sigma(\theta) \approx \sigma^2 H$.

Proof.

We can estimate the covariance of the gradient noise as $\Sigma(\theta) \approx \mathbb{E}\big[\nabla f_i(\theta)\nabla f_i(\theta)^\top\big] - g(\theta)g(\theta)^\top$. Near stationarity $g(\theta) \approx 0$, and thus,

$$\Sigma(\theta) \approx \mathbb{E}\big[\nabla f_i(\theta)\nabla f_i(\theta)^\top\big].$$

Under the generative model, where $y_i = x_i^\top\bar\theta + \epsilon_i$ and $\epsilon_i \sim \mathcal{N}(0, \sigma^2)$, then the per-sample gradient is

$$\nabla f_i(\theta) = x_i\big(x_i^\top\theta - y_i\big) = x_ix_i^\top(\theta - \bar\theta) - x_i\epsilon_i,$$

and the matrix $\mathbb{E}\big[\nabla f_i\nabla f_i^\top\big]$ is

$$\mathbb{E}\big[\nabla f_i\nabla f_i^\top\big] = \mathbb{E}\big[x_ix_i^\top(\theta-\bar\theta)(\theta-\bar\theta)^\top x_ix_i^\top\big] + \sigma^2\,\mathbb{E}\big[x_ix_i^\top\big].$$

Assuming $\theta \approx \bar\theta$ at stationarity, then $\mathbb{E}\big[\nabla f_i\nabla f_i^\top\big] \approx \sigma^2\,\mathbb{E}\big[x_ix_i^\top\big] = \sigma^2 H$. Thus,

$$\Sigma(\theta) \approx \sigma^2 H.$$

Also notice that weight decay is independent of the data or batch and thus simply shifts the gradient distribution, but leaves the covariance of the gradient noise unchanged. ∎

While the above analysis is in the linear regression setting, for deep neural networks it is reasonable to make the same assumption. See the appendix of Jastrzębski et al. [11] for a discussion on this assumption in the non-linear setting.

Recent work by Ali et al. [1] also studies the dynamics of SGD (without momentum) in the setting of linear regression. This work, while studying the classic first-order stochastic differential equation, made a point to not introduce an assumption on the diffusion matrix. In particular, they make the point that even in the setting of linear regression, a constant covariance matrix will fail to capture the actual dynamics. To illustrate this point they consider the univariate responseless least squares problem,

$$\min_\theta\ f(\theta) = \frac{1}{2n}\sum_{i=1}^n (x_i\theta)^2.$$

As they explain, the SGD update for this problem would be

$$\theta_{t+1} = \theta_t - \eta x_{i_t}^2\theta_t = \big(1 - \eta x_{i_t}^2\big)\theta_t,$$

from which they conclude that for a small enough learning rate $\eta$, with probability one $\theta_t \to 0$. They contrast this with the Ornstein-Uhlenbeck process given by a constant covariance matrix, where while the mean for $\theta_t$ converges to zero, its variance converges to a positive constant. So is this discrepancy evidence that an Ornstein-Uhlenbeck process with a constant covariance matrix fails to capture the updates of SGD? In many ways this problem is not a simple example, rather a pathological edge case. Consider the generative model that would give rise to this problem,

$$y = X\bar\theta + \epsilon, \qquad \bar\theta = 0, \quad \epsilon = 0.$$

In other words, the true model is $\bar\theta = 0$ and the standard deviation for the noise is $\sigma = 0$. This would imply, by the assumption used in our paper, that there would be zero diffusion and the resulting SDE would simplify to a deterministic ODE that exponentially converges to zero.

Appendix C A Quadratic Loss at the End of Training

Assumption 4 (Quadratic Loss). We assume that at the end of training the loss for a neural network can be approximated by the quadratic loss $f(\theta) = \frac{1}{2}(\theta - \mu)^\top H(\theta - \mu)$, where $H$ is the training loss Hessian and $\mu$ is some unknown mean vector, corresponding to a local minimum.

This assumption has been amply used in previous works such as Mandt et al. [25], Jastrzębski et al. [11], and Poggio et al. [29]. Particularly, Mandt et al. [25] discuss how this assumption makes sense for smooth loss functions for which the stationary solution to the stochastic process reaches a deep local minimum from which it is difficult to escape.

It is a well-studied fact, both empirically and theoretically, that the Hessian is low-rank near local minima as noted by Sagun et al. [31], and Kunin et al. [18]. This degeneracy results in flat directions of equal loss. Kunin et al. [18] discuss how differentiable symmetries, architectural features that keep the loss constant under certain weight transformations, give rise to these flat directions. Importantly, the Hessian and the covariance matrix share the same null space, and thus we can always restrict ourselves to the image space of the Hessian, where the drift and diffusion matrix will be full rank. Further discussion on the relationship between the Hessian and the covariance matrix can be found in Thomas et al. [37].

It is also a well known empirical fact that even at the end of training the Hessian can have negative eigenvalues [27]. This empirical observation is at odds with our assumption that the Hessian $H$ is positive semi-definite. Further analysis is needed to alleviate this inconsistency.

Appendix D Solving an Ornstein-Uhlenbeck Process with Anisotropic Noise

We will study the multivariate Ornstein-Uhlenbeck process described by the stochastic differential equation

$$dX_t = -A(X_t - \mu)\,dt + \sqrt{2\kappa D}\,dW_t, \qquad X_0 = x_0, \tag{18}$$

where $A$ is a positive definite drift matrix, $\mu$ is a mean vector, $\kappa$ is some positive constant, and $D$ is a positive definite diffusion matrix. This OU process is unique in that it is one of the few SDEs we can solve explicitly. We can derive an expression for $X_t$ as,

$$X_t = \mu + e^{-At}(x_0 - \mu) + \sqrt{2\kappa}\int_0^t e^{-A(t-s)}\sqrt{D}\,dW_s. \tag{19}$$

Proof.

Consider the function $f(t, x) = e^{At}x$, where $e^{At}$ is a matrix exponential. Then by Itô's Lemma (Itô's Lemma states that for any Itô drift-diffusion process $dX_t = \mu_t\,dt + \sigma_t\,dW_t$ and twice differentiable scalar function $f(t, x)$, then $df = \big(\partial_t f + \mu_t\partial_x f + \frac{\sigma_t^2}{2}\partial_x^2 f\big)dt + \sigma_t\partial_x f\,dW_t$) we can evaluate the derivative of $f(t, X_t)$ as

$$df = \big(Ae^{At}X_t - e^{At}A(X_t - \mu)\big)dt + e^{At}\sqrt{2\kappa D}\,dW_t = Ae^{At}\mu\,dt + e^{At}\sqrt{2\kappa D}\,dW_t.$$

Integrating this expression from $0$ to $t$ gives

$$e^{At}X_t - x_0 = \big(e^{At} - I\big)\mu + \sqrt{2\kappa}\int_0^t e^{As}\sqrt{D}\,dW_s,$$

which rearranged gives the expression for $X_t$. ∎

From this expression it is clear that $X_t$ is a Gaussian process. The mean of the process is

$$\mathbb{E}[X_t] = \mu + e^{-At}(x_0 - \mu), \tag{20}$$

and the covariance and cross-covariance of the process are

$$\mathrm{Cov}(X_t, X_t) = 2\kappa\int_0^t e^{-A(t-s)}De^{-A^\top(t-s)}\,ds, \tag{21}$$
$$\mathrm{Cov}(X_t, X_s) = 2\kappa\int_0^{\min(t,s)} e^{-A(t-\tau)}De^{-A^\top(s-\tau)}\,d\tau. \tag{22}$$

These last two expressions are derived by Itô Isometry (Itô Isometry states that for any standard Itô process $X_t$, $\mathbb{E}\big[\big(\int_0^t X_s\,dW_s\big)^2\big] = \mathbb{E}\big[\int_0^t X_s^2\,ds\big]$).

D.1 The Lyapunov Equation

We can explicitly solve the integral expressions for the covariance and cross-covariance exactly by solving for the unique matrix $B$ that solves the Lyapunov equation,

$$AB + BA^\top = 2D. \tag{23}$$

If $B$ solves the Lyapunov equation, notice

$$\frac{d}{ds}\Big(e^{-A(t-s)}Be^{-A^\top(t-s)}\Big) = e^{-A(t-s)}\big(AB + BA^\top\big)e^{-A^\top(t-s)} = 2e^{-A(t-s)}De^{-A^\top(t-s)}.$$

Using this derivative, the integral expressions for the covariance and cross-covariance simplify as,

$$\mathrm{Cov}(X_t, X_t) = \kappa\big(B - e^{-At}Be^{-A^\top t}\big), \tag{24}$$
$$\mathrm{Cov}(X_t, X_s) = \kappa\big(B - e^{-At}Be^{-A^\top t}\big)e^{-A^\top(s-t)}, \tag{25}$$

where we implicitly assume $t \le s$.

D.2 Decomposing the Drift Matrix

While the Lyapunov equation simplifies the expressions for the covariance and cross-covariance, it does not explain how to actually solve for the unknown matrix $B$. Following a method proposed by Kwon et al. [19], we will show how to solve for $B$ explicitly in terms of the drift $A$ and diffusion $D$.

The drift matrix $A$ can be uniquely decomposed as,

$$A = (D + Q)U, \tag{26}$$

where $D$ is our symmetric diffusion matrix, $Q$ is a skew-symmetric matrix (i.e. $Q = -Q^\top$), and $U$ is a positive definite matrix. Using this decomposition, then $B = U^{-1}$ solves the Lyapunov equation.

Proof.

Plug $B = U^{-1}$ into the left-hand side of equation (23),

$$AU^{-1} + U^{-1}A^\top = (D + Q)UU^{-1} + U^{-1}U(D - Q) = (D + Q) + (D - Q) = 2D.$$

Here we used the symmetry of $D$ and $U$ and the skew-symmetry of $Q$. ∎

All that is left to do is solve for the unknown matrices $Q$ and $U$. First notice the following identity,

$$AD - DA^\top = AQ + QA^\top. \tag{27}$$

Proof.

By the symmetry of $U = (D + Q)^{-1}A$, we have $(D + Q)^{-1}A = A^\top(D - Q)^{-1}$, which multiplied on the left by $(D + Q)$ and on the right by $(D - Q)$ gives,

$$A(D - Q) = (D + Q)A^\top,$$

which rearranged gives the desired equation. ∎

Let $A = S\Lambda S^{-1}$ be the eigendecomposition of $A$ and define the matrices $\tilde D = S^{-1}DS^{-\top}$ and $\tilde Q = S^{-1}QS^{-\top}$. These matrices observe the following relationship,

$$\tilde Q_{ij} = \frac{\lambda_i - \lambda_j}{\lambda_i + \lambda_j}\,\tilde D_{ij}. \tag{28}$$

Proof.

Replace $A$ in the previous equality (27) with its eigendecomposition,

$$S\Lambda S^{-1}D - DS^{-\top}\Lambda S^\top = S\Lambda S^{-1}Q + QS^{-\top}\Lambda S^\top.$$

Multiply this equation on the right by $S^{-\top}$ and on the left by $S^{-1}$,

$$\Lambda\tilde D - \tilde D\Lambda = \Lambda\tilde Q + \tilde Q\Lambda.$$

Looking at this equality element-wise and using the fact that $\Lambda$ is diagonal gives the scalar equality for any $i, j$,

$$(\lambda_i - \lambda_j)\tilde D_{ij} = (\lambda_i + \lambda_j)\tilde Q_{ij},$$

which rearranged gives the desired expression. ∎

Thus, $Q$ and $U$ are given by,

$$Q = S\tilde QS^\top, \qquad U = (D + Q)^{-1}A. \tag{29}$$

This decomposition always holds uniquely when $\lambda_i + \lambda_j \neq 0$ for all $i, j$, as then $\tilde Q$ exists and $(D + Q)$ is invertible. See [19] for a discussion on the singularities of this decomposition.
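The decomposition is easy to verify numerically. The sketch below builds a random positive stable drift $A$ and positive definite diffusion $D$, computes $Q$ and $U$ from equation (29), and checks that $U^{-1}$ solves the Lyapunov equation:

```python
import numpy as np

rng = np.random.default_rng(5)
n = 4
M = rng.normal(size=(n, n))
D = M @ M.T + n * np.eye(n)                 # symmetric positive definite diffusion
A = D + 0.5 * rng.normal(size=(n, n))       # generic positive stable drift

lams, Smat = np.linalg.eig(A)               # A = S Lambda S^{-1}
S_inv = np.linalg.inv(Smat)
D_tilde = S_inv @ D @ S_inv.T
Q_tilde = (lams[:, None] - lams[None, :]) / (lams[:, None] + lams[None, :]) * D_tilde
Q = (Smat @ Q_tilde @ Smat.T).real          # eq. (29); real up to roundoff
U = np.linalg.solve(D + Q, A)               # U = (D + Q)^{-1} A

print(np.linalg.norm(Q + Q.T))              # ~0: Q is skew-symmetric
print(np.linalg.norm(U - U.T))              # ~0: U is symmetric
B = np.linalg.inv(U)
print(np.linalg.norm(A @ B + B @ A.T - 2 * D))   # ~0: B solves the Lyapunov eq.
```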

D.3 Stationary Solution

Using the Lyapunov equation and the drift decomposition, then $\mathrm{Cov}(X_t, X_t) = \kappa\big(B - e^{-At}Be^{-A^\top t}\big)$, where

$$B = U^{-1} = A^{-1}(D + Q). \tag{30}$$

In the limit as $t \to \infty$, then $e^{-At} \to 0$ and $\mathrm{Cov}(X_t, X_t) \to \kappa B$ where

$$\kappa B = \kappa U^{-1}. \tag{31}$$

Similarly, the cross-covariance converges to the stationary cross-covariance,

$$\mathrm{Cov}(X_t, X_s) \to \kappa U^{-1}e^{-A^\top(s-t)}. \tag{32}$$

Appendix E A Variational Formulation of the OU Process with Anisotropic Noise

In this section we will describe an alternative, variational, route towards solving the dynamics of the OU process studied in appendix D.

Let $\Phi$ be an arbitrary, non-negative potential and consider the stochastic differential equation describing the Langevin dynamics of a particle in this potential field,

$$dX_t = -\nabla\Phi(X_t)\,dt + \sqrt{2\kappa D(X_t)}\,dW_t, \qquad X_0 = x_0, \tag{33}$$

where $D(x)$ is an arbitrary, spatially-dependent, diffusion matrix, $\kappa$ is a temperature constant, and $x_0$ is the particle's initial position. The Fokker-Planck equation describes the time evolution for the probability distribution $p(x, t)$ of the particle's position, such that $X_t \sim p(x, t)$. The FP equation is the partial differential equation (this PDE is also known as the Forward Kolmogorov equation),

$$\partial_t p = \nabla\cdot\big(p\nabla\Phi + \kappa\nabla\cdot(Dp)\big), \qquad p(x, 0) = \delta(x - x_0), \tag{34}$$

where $\nabla\cdot$ denotes the divergence and $\delta$ is a dirac delta distribution centered at the initialization $x_0$. To assist in the exploration of the FP equation we define the vector field,

$$J(x, t) = -p\nabla\Phi - \kappa\nabla\cdot(Dp), \tag{35}$$

which is commonly referred to as the probability current. Notice, that this gives an alternative expression for the FP equation, $\partial_t p = -\nabla\cdot J$, demonstrating that $J$ defines the flow of probability mass through space and time. This interpretation is especially useful for solving for the stationary solution $p_{ss}$, which is the unique distribution that satisfies,

$$\partial_t p_{ss} = -\nabla\cdot J_{ss} = 0, \tag{36}$$

where $J_{ss}$ is the probability current for $p_{ss}$. The stationary condition can be obtained in two distinct ways:

  1. Detailed balance. This is when $J_{ss}(x) = 0$ for all $x$. This is analogous to reversibility for discrete Markov chains, which implies that the probability mass flowing from a state $x$ to any state $x'$ is the same as the probability mass flowing from state $x'$ to state $x$.

  2. Broken detailed balance. This is when $\nabla\cdot J_{ss}(x) = 0$ but $J_{ss}(x) \neq 0$ for some $x$. This is analogous to irreversibility for discrete Markov chains, which only implies that the total probability mass flowing out of state $x$ equals the total probability mass flowing into state $x$.

The distinction between these two cases is critical for understanding the limiting dynamics of the process.

E.1 The Variational Formulation of the Fokker-Planck Equation with Isotropic Diffusion

We will now consider the restricted setting of standard, isotropic diffusion ($D = I$). It is easy enough to check that in this setting the stationary solution is

$$p_{ss}(x) = \frac{1}{Z}e^{-\frac{\Phi(x)}{\kappa}}, \qquad Z = \int e^{-\frac{\Phi(x)}{\kappa}}\,dx, \tag{37}$$

where $p_{ss}$ is called a Gibbs distribution and $Z$ is the partition function. Under this distribution, the stationary probability current is zero ($J_{ss} = 0$) and thus the process is in detailed balance. Interestingly, the Gibbs distribution has another interpretation as the unique minimizer of the Gibbs free energy functional,

$$F(p) = \mathbb{E}_p[\Phi] - \kappa S(p), \tag{38}$$

where $\mathbb{E}_p[\Phi]$ is the expectation of the potential under the distribution $p$ and $S(p) = -\int p\log p\,dx$ is the Shannon entropy of $p$.

Proof.

To prove that $p_{ss}$ is indeed the unique minimizer of the Gibbs free energy functional, consider the following equivalent expression

$$F(p) = \int \Phi p\,dx + \kappa\int p\log p\,dx = \kappa\int p\log\frac{p}{e^{-\Phi/\kappa}}\,dx = \kappa\,\mathrm{KL}\big(p\,\|\,p_{ss}\big) - \kappa\log Z.$$

From this expression, it is clear that the Kullback–Leibler divergence is uniquely minimized when $p = p_{ss}$. ∎

In other words, with isotropic diffusion the stationary solution can be thought of as the limiting distribution given by the Fokker-Planck equation or the unique minimizer of an energetic-entropic functional.

Seminal work by Jordan et al. [14] deepened this connection between the Fokker-Planck equation and the Gibbs free energy functional. In particular, their work demonstrates that the solution to the Fokker-Planck equation is the Wasserstein gradient flow trajectory on the Gibbs free energy functional.

Steepest descent is always defined with respect to a distance metric. For example, the update equation, $x_{t+1} = x_t - \eta\nabla\Phi(x_t)$, for classic gradient descent on a potential $\Phi$, can be formulated as the solution to the minimization problem $x_{t+1} = \arg\min_x\big(\langle\nabla\Phi(x_t), x\rangle + \frac{1}{2\eta}d(x, x_t)^2\big)$, where $d(x, x') = \|x - x'\|_2$ is the Euclidean distance metric. Gradient flow is the continuous-time limit of gradient descent where we take $\eta \to 0$. Similarly, Wasserstein gradient flow is the continuous-time limit of steepest descent optimization defined by the Wasserstein metric. The Wasserstein metric is a distance metric between probability measures defined as,

$$W_2(\mu, \nu)^2 = \inf_{\pi\in\Pi(\mu,\nu)}\int \|x - y\|_2^2\,d\pi(x, y), \tag{39}$$

where $\mu$ and $\nu$ are two probability measures on $\mathbb{R}^d$ with finite second moments and $\Pi(\mu, \nu)$ defines the set of joint probability measures with marginals $\mu$ and $\nu$. Thus, given an initial distribution $p_0$ and learning rate $\eta$, we can use the Wasserstein metric to derive a sequence of distributions minimizing some functional in the sense of steepest descent. In the continuous-time limit as $\eta \to 0$ this sequence defines a continuous trajectory of probability distributions minimizing the functional. Surprisingly, Jordan et al. [13] proved, through the following theorem, that this process applied to the Gibbs free energy functional converges to the solution to the Fokker-Planck equation with the same initialization:

Theorem 1 (JKO).

Given an initial condition