Understanding how stochastic gradient descent (SGD) manages to train artificial neural networks with good generalization capabilities by exploring the high-dimensional non-convex loss landscape is one of the central problems in the theory of machine learning. A popular attempt to explain this behavior is to show that the loss landscape itself is simple, with no spurious (i.e. leading to bad test error) local minima. Some empirical evidence instead leads to the conclusion that the loss landscape of state-of-the-art deep neural networks actually has spurious local (or even global) minima and that stochastic gradient descent is able to find them Safran and Shamir (2017); Liu et al. (2019). Still, the stochastic gradient descent algorithm, initialized at random, leads to good generalization properties in practice. It became clear that a theory explaining this success needs to account for the whole trajectory of the algorithm. Yet this remains a challenging task, certainly for state-of-the-art deep networks trained on real datasets.
Related work —
A detailed description of the whole trajectory taken by (stochastic) gradient descent has so far been obtained only in several special cases. The first such case is deep linear networks, for which the dynamics of gradient descent has been analyzed in Bös and Opper (1997); Saxe et al. (2013). While this line of work has led to very interesting insights about the dynamics, linear networks lack the expressivity of non-linear ones, and the large-time behavior of the algorithm can be obtained with a simple spectral algorithm. Moreover, the analysis of dynamics in deep linear networks was not extended to the case of stochastic gradient descent. The second case where the trajectory of the algorithm is understood in detail is one-pass (online) stochastic gradient descent for two-layer neural networks with a small hidden layer in the teacher-student setting Saad and Solla (1995a, b); Saad (2009); Goldt et al. (2019a, b). However, the one-pass assumption made in those analyses is far from what is done in practice and is unable to access the subtle difference between the training and test error that underlies many of the empirical mysteries observed in deep learning. A third very interesting line of research that recently provided insight into the behavior of stochastic gradient descent concerns two-layer networks with a hidden layer of diverging width. This mean-field limit Rotskoff and Vanden-Eijnden (2018); Mei et al. (2018); Chizat and Bach (2018) maps the dynamics into the space of functions, where its description is simpler and the dynamics can be written in terms of a closed set of differential equations. It is not clear yet whether this analysis can be extended in a sufficiently explicit way to deeper or finite-width neural networks.
Our present work belongs to the above line of research, offering the dynamical mean-field theory (DMFT) formalism Mézard et al. (1987); Georges et al. (1996); Parisi et al. (2020) (note that the word mean-field in the name of the method is taken from its usage in physics, and has in this case nothing to do with the width of an eventual hidden layer), which leads to a closed set of integro-differential equations that track the full trajectory of gradient descent (stochastic or not) from a random initial condition in the high-dimensional limit for in-general non-convex losses. While in general the DMFT is a heuristic statistical physics method, it has been amenable to rigorous proof in some cases Arous et al. (1997). Proving it for the case considered in the present paper is hence an important future direction. The DMFT has been applied recently to a high-dimensional inference problem in Mannelli et al. (2020, 2019), studying the spiked matrix-tensor model. However, this problem does not allow a natural way to study stochastic gradient descent or to explore the difference between training and test errors. In particular, the spiked matrix-tensor model does not allow for the study of the so-called interpolating regime, where the loss function is optimized to zero while the test error remains positive. As such, its landscape is intrinsically different from that of supervised learning problems, since in the former the spurious minima proliferate at high values of the loss while the good ones lie at the bottom of the landscape. Instead, deep networks have both spurious and good minima at 100% training accuracy, and their landscape resembles much more the one of continuous constraint satisfaction problems Franz et al. (2017, 2019).
Main contributions —
We study a natural problem of supervised classification where the input data come from a high-dimensional Gaussian mixture of several clusters, and all samples in one cluster are assigned one of two possible output labels. We then consider a single-layer neural network classifier with a general non-convex loss function. We analyze a stochastic gradient descent algorithm in which, at each iteration, the batch used to compute the gradient of the loss is extracted at random, and we define a particular stochastic process for which SGD can be extended to a continuous-time limit that we call stochastic gradient flow (SGF). In the full-batch limit we recover the standard gradient flow (GF). We describe the high-dimensional limit of the randomly initialized SGF with the DMFT, which leads to a description of the dynamics in terms of a self-consistent stochastic process that we compare with numerical simulations. In particular, we show that a finite batch size can have a beneficial effect on the test error, acting as an effective regularization that prevents overfitting.
2 Setting and definitions
In all that follows, we will consider the high-dimensional setting where the dimension of each point in the dataset is $d$ and the size of the training set is $n = \alpha d$, $\alpha$ being a control parameter that we keep of order one. We consider a training set made of $n$ points.
The patterns are given by
Without loss of generality, we choose a basis where .
We will illustrate our results on a two-cluster example where the coefficients $c^\mu$ are taken at random, $c^\mu = \pm 1$ with equal probability. The labels of the data points are fixed by $y^\mu = c^\mu$. If the noise level $\Delta$ is small enough for a given number of samples, the two Gaussian clouds are linearly separable by a hyperplane, as specified in detail in Mignacco et al. (2020), and therefore a single-layer neural network is enough to perform the classification task in this case. We hence consider learning with the simplest neural network that classifies the data according to $\hat{y} = \mathrm{sign}(w \cdot x / \sqrt{d})$.
We consider also an example of three clusters where a good generalization error cannot be obtained by separating the points linearly. In this case we define with probability , and with probability . The labels are then assigned as
One hence has three clouds of Gaussian points, two external ones and one centered at zero. In order to fit the data we consider a single-layer neural network with the door activation function, defined as
The onset parameter could be learned, but we will instead fix it to a constant.
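For concreteness, the two datasets described above can be sampled as in the following minimal sketch. The conventions (a cluster direction v = (1,…,1), centers at c·v/√d, noise of variance Δ, and label-assignment rules) are our illustrative choices consistent with the text, not the paper's exact definitions.

```python
import numpy as np

def make_clusters(n, d, delta, three_clusters=False, rng=None):
    """Sample n points in d dimensions from a Gaussian mixture.

    Each point is a cluster center c * v / sqrt(d) plus Gaussian noise of
    variance delta. The names (v, c, delta) and the three-cluster label rule
    are assumptions for illustration.
    """
    rng = rng or np.random.default_rng(0)
    v = np.ones(d)  # cluster direction; a fixed basis choice
    if three_clusters:
        c = rng.choice([-1.0, 0.0, 1.0], size=n)
        y = np.where(c == 0.0, -1.0, 1.0)  # central cluster vs. outer ones
    else:
        c = rng.choice([-1.0, 1.0], size=n)
        y = c.copy()  # label equals the cluster sign
    x = np.outer(c, v) / np.sqrt(d) + np.sqrt(delta) * rng.standard_normal((n, d))
    return x, y

x, y = make_clusters(n=200, d=100, delta=0.5)
```

In the two-cluster case the Bayes-optimal direction is v itself, which makes the linear separability statement above easy to check numerically.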
We study the dynamics of learning by the empirical risk minimization of the loss
where we added a ridge regularization term. The activation function is given by
The DMFT analysis is valid for a generic loss function $\ell$. However, for concreteness, in the results section we will focus on the logistic loss $\ell(u) = \log(1 + e^{-u})$. Note that in this setting the two-cluster dataset leads to a convex optimization problem, with a unique minimum at finite regularization and implicit regularization towards the max-margin solution when the regularization vanishes Rosset et al. (2004); it was analyzed in detail in Deng et al. (2019); Mignacco et al. (2020). Still, the performance of stochastic gradient descent with finite batch size cannot be obtained by static approaches. The three-cluster dataset, instead, leads to a generically non-convex optimization problem, which can present many spurious minima with different generalization abilities as the control parameters are changed. We note that our analysis can be extended to neural networks with a small hidden layer Seung et al. (1992); this would allow one to study the role of overparametrization, but it is left for future work.
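As an illustration of the objective being minimized, here is a minimal sketch of the empirical risk with logistic loss and ridge regularization for a linear classifier, together with its gradient. The 1/√d scaling of the pre-activations is an assumption matching standard conventions for this class of models.

```python
import numpy as np

def logistic_loss_and_grad(w, x, y, lam):
    """Empirical risk sum_mu log(1 + exp(-y_mu * h_mu)) + (lam/2) ||w||^2,
    with pre-activations h_mu = x_mu . w / sqrt(d) (linear single-layer
    classifier; the three-cluster 'door' activation is not covered here)."""
    n, d = x.shape
    h = x @ w / np.sqrt(d)
    margins = y * h
    # log(1 + exp(-m)) computed stably via logaddexp(0, -m)
    loss = np.sum(np.logaddexp(0.0, -margins)) + 0.5 * lam * w @ w
    # gradient: -y_mu x_mu / sqrt(d) * sigmoid(-margin_mu), summed, plus lam*w
    sig = 1.0 / (1.0 + np.exp(margins))  # = sigmoid(-margins)
    grad = -(x.T @ (y * sig)) / np.sqrt(d) + lam * w
    return loss, grad
```

A finite-difference check of the gradient is a quick way to validate such an implementation before plugging it into any descent dynamics.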
3 Stochastic gradient-descent training dynamics
Discrete SGD dynamics —
We consider the discrete gradient-descent dynamics for which the weight update is given by
where we have introduced the function entering the gradient, indicating with a prime the derivative with respect to its argument. We consider an initialization in which each component of the weight vector is drawn i.i.d. from a Gaussian distribution whose width is a parameter that tunes the average length of the weight vector at the beginning of the dynamics (the DMFT equations we derive can be easily generalized to the case in which the initial distribution over the weights is different; we only need it to be separable and independent of the dataset). The variables $s_\mu(t)$ are i.i.d. binary random variables. Their discrete-time dynamics can be chosen in two ways:
In classical SGD, at each iteration one extracts the samples with the following probability distribution: each $s_\mu(t)$ is set to one with probability $b$ and to zero with probability $1-b$. In this way, at each time iteration one extracts on average $bn$ patterns at random on which the gradient is computed, and therefore the batch size is given by $bn$. Note that for $b=1$ one gets full-batch gradient descent.
Persistent-SGD is defined by a stochastic process for the variables $s_\mu(t)$, given by the following probability rules
where the initial condition is drawn from the probability distribution (8). In this case, at each time slice one has on average $bn$ patterns that are active and enter the computation of the gradient. The main difference with respect to usual SGD is that one keeps the same patterns, and hence the same mini-batch, for a characteristic time $\tau$. Again, setting $b=1$ one gets full-batch gradient descent.
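The persistent mini-batch process can be sketched as follows. We choose the switching rates so that a sample leaves the batch with rate 1/τ and enters with rate b/((1-b)τ), which makes the stationary fraction of active samples equal to b; this parameterization is our assumption, consistent with the text but not a quote of Eq. (9).

```python
import numpy as np

def persistent_mask(n, steps, dt, b, tau, rng=None):
    """Simulate per-sample inclusion variables s_mu(t) for Persistent-SGD.

    A sample switches off with rate 1/tau and on with rate b/((1-b)*tau),
    so a fraction b of samples is active on average and each one stays in
    the mini-batch for a characteristic time tau (assumed rates)."""
    rng = rng or np.random.default_rng(0)
    s = (rng.random(n) < b).astype(float)  # stationary initial condition
    history = np.empty((steps, n))
    for t in range(steps):
        u = rng.random(n)
        turn_off = (s == 1.0) & (u < dt / tau)
        turn_on = (s == 0.0) & (u < dt * b / ((1.0 - b) * tau))
        s[turn_off], s[turn_on] = 0.0, 1.0
        history[t] = s
    return history

h = persistent_mask(n=1000, steps=2000, dt=0.01, b=0.3, tau=1.0)
```

One can check that the time- and sample-averaged activity stays close to b, and that the autocorrelation of each mask decays on the timescale τ.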
Stochastic gradient flow —
To write the DMFT we consider a continuous-time dynamics defined by the limit of vanishing learning rate. This limit is not well defined for the usual SGD dynamics described by the rule (8), and we consider instead its persistent version described by Eq. (9). In this case the stochastic process for the variables $s_\mu(t)$ remains well defined in the continuous-time limit and one can write a continuous-time equation as
Again, for $b=1$ one recovers the gradient flow. We call Eq. (10) stochastic gradient flow (SGF).
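Putting the pieces together, a minimal Euler discretization of such a stochastic gradient flow for the two-cluster logistic model might look as follows. The mask rates, the 1/√d scalings, and the Gaussian initialization are our assumptions for illustration, not the paper's exact equations.

```python
import numpy as np

def sgf_euler(x, y, dt, steps, lam, b, tau, rng=None):
    """Euler integration of a stochastic gradient flow for the logistic
    loss on a linear classifier: only samples with an active mask s_mu(t)
    contribute to the gradient at time t. The mask switches off with rate
    1/tau and on with rate b/((1-b)*tau) (assumed parameterization with
    average batch fraction b)."""
    rng = rng or np.random.default_rng(0)
    n, d = x.shape
    w = rng.standard_normal(d)  # random Gaussian initialization
    s = (rng.random(n) < b).astype(float)
    for _ in range(steps):
        h = x @ w / np.sqrt(d)
        sig = 1.0 / (1.0 + np.exp(y * h))  # sigmoid(-margin)
        grad = -(x.T @ (s * y * sig)) / np.sqrt(d) + lam * w
        w -= dt * grad
        u = rng.random(n)  # refresh the persistent mask
        s = np.where(s == 1.0,
                     (u >= dt / tau).astype(float),
                     (u < dt * b / ((1.0 - b) * tau)).astype(float))
    return w
```

Setting b = 1 (so the mask is always fully on) reduces this loop to plain full-batch gradient flow, as in the text.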
4 Dynamical mean-field theory for SGF
We will now analyze the SGF in the infinite-size limit $n, d \to \infty$, with $\alpha = n/d$, the batch fraction $b$, and the persistence time $\tau$ fixed and of order one. In order to do that we use dynamical mean-field theory (DMFT). The derivation of the DMFT equations is given in section A of the appendix; here we just present the main steps. The derivation extends the one reported in Agoritsas et al. (2018)
for the non-convex perceptron model Franz et al. (2017) (motivated there as a model of glassy phases of hard spheres). The main differences of the present work with respect to Agoritsas et al. (2018) are that here we consider finite-batch gradient descent and a structured dataset, while in Agoritsas et al. (2018) the derivation was done for full-batch gradient descent with random i.i.d. inputs and i.i.d. labels, i.e. a case where one cannot investigate the generalization error and its properties. The starting point of the DMFT is the dynamical partition function
where the measure is over the dynamical trajectories starting from the initial condition. Since $\mathcal{Z}_{\rm dyn} = 1$ (it is just an integral of a Dirac delta function) Dominicis (1976), one can average directly over the training set, the initial condition, and the stochastic processes of the $s_\mu(t)$. We indicate this average with the brackets $\langle \cdot \rangle$. Hence we can write
where we have defined
and we have introduced a set of auxiliary fields to produce the integral representation of the Dirac delta function. The average over the training set can then be performed explicitly and the dynamics in the high-dimensional limit satisfies a large deviation principle
where two dynamical order parameters, defined in section A.1 of the appendix, appear. The limit is therefore controlled by a saddle point. In particular, one can show that the saddle-point equations can be recast into a self-consistent stochastic process for a single representative variable, which evolves according to the stochastic equation:
where $m(t)$ is the magnetization, namely the projection of the weight vector on the direction of the cluster centers; the details of the computation are provided in section A.2 of the appendix. There are several sources of stochasticity in Eq. (15). First, one has a dynamical noise
that is Gaussian distributed and characterized by the correlations
Furthermore, the starting point of the stochastic process is random and distributed according to
Moreover, one has to introduce a quenched Gaussian random variable with mean zero and unit variance. We recall that the random variable $c = \pm 1$ with equal probability in the two-cluster model, while in the three-cluster one it also takes the value zero. The label is therefore equal to $c$ in the two-cluster case, and is given by Eq. (3) in the three-cluster one. Finally, one has a dynamical stochastic process for the batch variable whose statistical properties are specified in Eq. (9). The magnetization is obtained from the following deterministic differential equation
The stochastic process, the evolution of the magnetization, as well as the statistical properties of the dynamical noise depend on a series of kernels that must be computed self-consistently and are given by
In Eq. (19) the brackets denote the average over all the sources of stochasticity in the self-consistent stochastic process. Therefore one needs to solve the stochastic process self-consistently. Note that the external field in Eq. (15) is set to zero; we need it only to define the response kernel. The set of Eqs. (15), (18) and (19) can be solved by a straightforward iterative algorithm. One starts with a guess for the kernels and then repeatedly runs the stochastic process, using the samples to update the kernels. The iteration is stopped when a desired precision on the kernels is reached Eissfeller and Opper (1992).
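The structure of this iterative scheme can be sketched generically as a damped fixed-point loop. The routine below is a skeleton, not the paper's solver: `sample_paths` stands in for the integration of the effective process of Eq. (15), and the single correlation kernel `C` stands in for the full set of kernels.

```python
import numpy as np

def solve_self_consistent(sample_paths, n_iter=50, n_samples=1000,
                          damping=0.5, t_grid=100, tol=1e-6, rng=None):
    """Damped fixed-point iteration of the kind used to solve DMFT-like
    equations: guess the kernels, sample the effective stochastic process
    many times, re-estimate the kernels from the samples, and mix old and
    new estimates until convergence. `sample_paths(kernels, n, rng)` must
    return an (n, t_grid) array of trajectories."""
    rng = rng or np.random.default_rng(0)
    kernels = {"C": np.eye(t_grid)}  # initial guess for the correlation kernel
    for _ in range(n_iter):
        paths = sample_paths(kernels, n_samples, rng)
        new_c = paths.T @ paths / n_samples  # empirical two-time correlation
        update = (1 - damping) * kernels["C"] + damping * new_c
        converged = np.max(np.abs(update - kernels["C"])) < tol
        kernels["C"] = update
        if converged:
            break
    return kernels
```

The damping factor plays the role mentioned in the appendix: it averages out the Monte Carlo noise of each sampling round and prevents oscillations of the kernel estimates.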
Note that, in order to solve Eqs. (15), (18) and (19), one needs to discretize time. In the results section 5, in order to compare with numerical simulations, we take the time discretization of the DMFT equal to the learning rate in the simulations. In the time-discretized DMFT, this allows us to extract the batch variables either from (8) (SGD) or (9) (Persistent-SGD). In the former case this gives us an SGD-inspired discretization of the DMFT equations.
Finally, once the self-consistent stochastic process is solved, one also has access to the dynamical correlation functions, encoded in the dynamical order parameter that appears in the large deviation principle of Eq. (14). This order parameter concentrates in the high-dimensional limit and is therefore controlled by the equations
where we use a shorthand notation for the kernels, and the response function controls the variation of the weights when their dynamical evolution is perturbed by an infinitesimal local field. It is interesting to note that the second of Eqs. (20) controls the evolution of the norm of the weight vector and, even if we set the regularization to zero, it contains an effective regularization that is dynamically self-generated Soudry et al. (2018).
Dynamics of the loss and the generalization error —
Once the solution for the self-consistent stochastic process is found, one can get several interesting quantities. First, one can look at the training loss, which can be obtained as
where again the brackets denote the average over the realization of the stochastic process in Eq. (15). The training accuracy is given by
and, by definition, it is equal to one as soon as all vectors in the training set are correctly classified. Finally, one can compute the generalization error. At any time step, it is defined as the fraction of mislabeled instances:
where $\mathcal{D}$ is the training set, $x$ is an unseen data point and $\hat{y}(x)$
is the estimator for the new label. The dependence on the training set here is hidden in the weight vector $w$. In the two-cluster case one can easily show that
Conversely, for the door activation trained on the three-cluster dataset we get
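Whatever its closed form, the generalization error can also be estimated numerically by drawing fresh samples. A minimal Monte Carlo sketch for the two-cluster case with the linear classifier follows; the conventions (v = (1,…,1), labels equal to the cluster sign) are our illustrative choices.

```python
import numpy as np

def generalization_error(w, delta, n_test=100000, rng=None):
    """Monte Carlo estimate of the generalization error for the two-cluster
    model with the linear classifier y_hat = sign(x . w / sqrt(d)): draw
    fresh points and count the fraction of mislabeled ones."""
    rng = rng or np.random.default_rng(0)
    d = w.shape[0]
    v = np.ones(d)  # assumed cluster direction
    c = rng.choice([-1.0, 1.0], size=n_test)
    x = np.outer(c, v) / np.sqrt(d) + np.sqrt(delta) * rng.standard_normal((n_test, d))
    y_hat = np.sign(x @ w / np.sqrt(d))
    return np.mean(y_hat != c)
```

For instance, with w aligned to v the pre-activation is c plus unit-variance Gaussian noise scaled by √Δ, so the error approaches the Gaussian tail probability Φ(-1/√Δ), a useful sanity check for the estimator.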
5 Results
In this section, we compare the theoretical curves resulting from the solution of the DMFT equations derived in section 4 to numerical simulations. This analysis allows us to gain insight into the learning dynamics of stochastic gradient descent and its dependence on the various control parameters in the two models under consideration.
The left panel of Fig. 1 shows the learning dynamics of Persistent-SGD in the two-cluster model without regularization. We clearly see a good match between the numerical simulations and the theoretical curves obtained from DMFT, notably also for small values of the batch size and of the dimension. The figure shows that there exist regions in control-parameter space where Persistent-SGD is able to reach 100% training accuracy while the generalization error is bounded away from zero. Remarkably, we observe that the additional noise introduced by decreasing the batch size shifts the early-stopping minimum of the generalization error to larger times and that, on the time window we show, a batch fraction smaller than one has a beneficial effect on the generalization error at long times. The right panel illustrates the role of regularization in the same model trained with full-batch gradient descent, showing that regularization influences the learning curve similarly to a small batch size, but without the slow-down incurred by Persistent-SGD.
The influence of the batch size and the regularization for the three-cluster model is shown in Fig. 2. We see an effect analogous to the two-cluster case in Fig. 1. In the inset of Fig. 2, we show the norm of the weights as a function of the training time. Both with a smaller mini-batch size and with larger regularization the norm stays small, which is further evidence that the two play a similar role in this case.
One difference between the two-cluster and the three-cluster models we observe concerns the behavior of the generalization error at small times. For the three-cluster model, good generalization is reached because of finite-size effects. Indeed, the corresponding loss function displays a symmetry under flipping the sign of the weight vector, so that for each local minimum there is another one with exactly the same properties. Note that this symmetry is inherited from the activation function (6), which is even. This implies that in the strict infinite-dimensional limit the generalization error would not move away from its initial value in finite time. However, when the dimension is large but finite, at initialization the weight vector has a small but finite projection on the cluster direction, which is responsible for the dynamical symmetry breaking and eventually for a low generalization error at long times. In order to obtain agreement between the theory and simulations, we initialize the magnetization in the DMFT equations with its corresponding finite-size average value at the initial time. In the left panel of Fig. 3, we show that while this produces a small discrepancy at intermediate times that diminishes with growing size, at long times the DMFT tracks perfectly the evolution of the algorithm.
The right panel of Fig. 3 summarizes the effect of the characteristic time $\tau$ in Persistent-SGD, i.e. the typical persistence time of each pattern in the training mini-batch. When $\tau$ decreases, the Persistent-SGD algorithm achieves a better early-stopping generalization error and the dynamics gets closer to the usual SGD dynamics. As expected, the small-$\tau$ limit of Persistent-SGD converges to SGD. It is remarkable that the SGD-inspired discretization of the DMFT equations, which is in principle an ad-hoc construction since the corresponding flow limit in which the derivation holds does not exist, shows perfect agreement with the numerics.
Fig. 4 presents the influence of the weight norm at initialization on the dynamics, for the two-cluster (left) and three-cluster (right) models. For the two-cluster case, the gradient descent algorithm with all-zeros initialization "jumps" to the Bayes-optimal error at the first iteration, as derived in Mignacco et al. (2020), and in this particular setting the generalization error is monotonically increasing in time. As the initialization norm increases, the early-stopping error gets worse. At large times all the initializations converge to the same value of the error, as they must, since this is full-batch gradient descent without regularization, which at large times converges to the max-margin estimator according to Rosset et al. (2004). For the three-cluster model we observe a qualitatively similar behavior.
We acknowledge funding from the ERC under the European Union's Horizon 2020 Research and Innovation Programme Grant Agreement 714608-SMiLe, from the Fondation CFM pour la Recherche-ENS, as well as from the French Agence Nationale de la Recherche under grant ANR-17-CE23-0023-01 PAIL and ANR-19-P3IA-0001 PRAIRIE. This work was supported by "Investissements d'Avenir" LabEx PALM (ANR-10-LABX-0039-PALM).
- E. Agoritsas, G. Biroli, P. Urbani, and F. Zamponi (2018) Out-of-equilibrium dynamical mean-field equations for the perceptron model. Journal of Physics A: Mathematical and Theoretical 51 (8), pp. 085002.
- G. Ben Arous and A. Guionnet (1997) Symmetric Langevin spin glass dynamics. The Annals of Probability 25 (3), pp. 1367–1422.
- S. Bös and M. Opper (1997) Dynamics of training. In Advances in Neural Information Processing Systems, pp. 141–147.
- L. Chizat and F. Bach (2018) On the global convergence of gradient descent for over-parameterized models using optimal transport. In Advances in Neural Information Processing Systems, pp. 3036–3046.
- Z. Deng, A. Kammoun, and C. Thrampoulidis (2019) A model of double descent for high-dimensional binary linear classification. arXiv preprint arXiv:1911.05822.
- C. De Dominicis (1976) Technics of field renormalization and dynamics of critical phenomena. In J. Phys. (Paris), Colloq., pp. C1–247.
- H. Eissfeller and M. Opper (1992) New method for studying the dynamics of disordered spin systems without finite-size effects. Physical Review Letters 68 (13), pp. 2094.
- H. Eissfeller and M. Opper (1994) Mean-field Monte Carlo approach to the Sherrington-Kirkpatrick model with asymmetric couplings. Physical Review E 50 (2), pp. 709.
- S. Franz, S. Hwang, and P. Urbani (2019) Jamming in multilayer supervised learning models. Physical Review Letters 123 (16), pp. 160602.
- S. Franz, G. Parisi, M. Sevelev, P. Urbani, and F. Zamponi (2017) Universality of the SAT-UNSAT (jamming) threshold in non-convex continuous constraint satisfaction problems. SciPost Physics 2 (3), pp. 019.
- A. Georges, G. Kotliar, W. Krauth, and M. J. Rozenberg (1996) Dynamical mean-field theory of strongly correlated fermion systems and the limit of infinite dimensions. Reviews of Modern Physics 68 (1), pp. 13.
- S. Goldt, M. Advani, A. Saxe, F. Krzakala, and L. Zdeborová (2019a) Dynamics of stochastic gradient descent for two-layer neural networks in the teacher-student setup. In Advances in Neural Information Processing Systems, pp. 6979–6989.
- S. Goldt, M. Mézard, F. Krzakala, and L. Zdeborová (2019b) Modelling the influence of data structure on learning in neural networks. arXiv preprint arXiv:1909.11500.
- J. Kurchan (1992) Supersymmetry in spin glass dynamics. Journal de Physique I 2 (7), pp. 1333–1352.
- J. Kurchan (2002) Supersymmetry, replica and dynamic treatments of disordered systems: a parallel presentation. arXiv preprint cond-mat/0209399.
- S. Liu, D. Papailiopoulos, and D. Achlioptas (2019) Bad global minima exist and SGD can reach them. arXiv preprint arXiv:1906.02613.
- A. Manacorda, G. Schehr, and F. Zamponi (2020) Numerical solution of the dynamical mean field theory of infinite-dimensional equilibrium liquids. The Journal of Chemical Physics 152 (16), pp. 164506.
- S. Sarao Mannelli, G. Biroli, C. Cammarota, F. Krzakala, P. Urbani, and L. Zdeborová (2020) Marvels and pitfalls of the Langevin algorithm in noisy high-dimensional inference. Physical Review X 10 (1), pp. 011057.
- S. Sarao Mannelli, F. Krzakala, P. Urbani, and L. Zdeborová (2019) Passed & spurious: descent algorithms and local minima in spiked matrix-tensor models. In International Conference on Machine Learning, pp. 4333–4342.
- S. Mei, A. Montanari, and P.-M. Nguyen (2018) A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences 115 (33), pp. E7665–E7671.
- M. Mézard, G. Parisi, and M. A. Virasoro (1987) Spin Glass Theory and Beyond. World Scientific, Singapore.
- F. Mignacco, F. Krzakala, Y. M. Lu, and L. Zdeborová (2020) The role of regularization in classification of high-dimensional noisy Gaussian mixture. arXiv preprint arXiv:2002.11544.
- G. Parisi, P. Urbani, and F. Zamponi (2020) Theory of Simple Glasses: Exact Solutions in Infinite Dimensions. Cambridge University Press.
- S. Rosset, J. Zhu, and T. Hastie (2004) Margin maximizing loss functions. In Advances in Neural Information Processing Systems, pp. 1237–1244.
- G. M. Rotskoff and E. Vanden-Eijnden (2018) Neural networks as interacting particle systems: asymptotic convexity of the loss landscape and universal scaling of the approximation error. arXiv preprint arXiv:1805.00915.
- F. Roy, G. Biroli, G. Bunin, and C. Cammarota (2019) Numerical implementation of dynamical mean field theory for disordered systems: application to the Lotka–Volterra model of ecosystems. Journal of Physics A: Mathematical and Theoretical 52 (48), pp. 484001.
- D. Saad and S. A. Solla (1995a) Exact solution for on-line learning in multilayer neural networks. Physical Review Letters 74 (21), pp. 4337.
- D. Saad and S. A. Solla (1995b) On-line learning in soft committee machines. Physical Review E 52 (4), pp. 4225.
- D. Saad (Ed.) (2009) On-line Learning in Neural Networks. Vol. 17, Cambridge University Press.
- I. Safran and O. Shamir (2017) Spurious local minima are common in two-layer ReLU neural networks. arXiv preprint arXiv:1712.08968.
- A. M. Saxe, J. L. McClelland, and S. Ganguli (2013) Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120.
- H. S. Seung, M. Opper, and H. Sompolinsky (1992) Query by committee. In Proceedings of the Fifth Annual Workshop on Computational Learning Theory, pp. 287–294.
- D. Soudry, E. Hoffer, M. Shpigel Nacson, S. Gunasekar, and N. Srebro (2018) The implicit bias of gradient descent on separable data. The Journal of Machine Learning Research 19 (1), pp. 2822–2878.
- J. Zinn-Justin (1996) Quantum Field Theory and Critical Phenomena. Clarendon Press.
Appendix A Derivation of the dynamical mean-field equations
The derivation of the self-consistent stochastic process discussed in the main text can be obtained using tools from the statistical physics of disordered systems. In particular, it has been done very recently for a related model, the spherical perceptron with random labels, in Agoritsas et al. (2018). Our derivation extends the known DMFT equations by including
structure in the data;
a stochastic version of gradient descent as discussed in the main text;
the relaxation of the spherical constraint over the weights and the introduction of a Ridge regularization term.
There are at least two ways to derive the DMFT equations: one uses field-theoretical techniques, the other a dynamical version of the so-called cavity method. Here we opt for the first, which is generically very compact and immediate, and has a form that closely resembles a static treatment of the Gibbs measure of the problem. We use a supersymmetric (SUSY) representation to derive the dynamical mean-field theory (DMFT) equations [14, 1]. We do not report all the details, which can be found in Agoritsas et al. (2018) along with an alternative derivation based on the cavity method, but limit ourselves to the main points. We first consider the dynamical partition function, corresponding to Eq. (11) in the main text
where the brackets stand for the average over the stochastic process of the batch variables and the realization of the noise in the training set. The average over the initial condition is written explicitly. Note that we chose an initial condition that is Gaussian, but we could have chosen a different probability measure over the initial configuration of the weights. The equations can be generalized to other initial conditions as long as they do not depend on the quenched random variables that enter the stochastic gradient descent (SGD) dynamics and their distribution is separable. As observed in the main text, we have $\mathcal{Z}_{\rm dyn} = 1$. We can write the integral representation of the Dirac delta function in Eq. (A.1) by introducing a set of auxiliary fields
where the dynamical action is defined as in Eq. (13) of the main text
A.1 SUSY formulation
The dynamical action (A.3) can be rewritten in a supersymmetric form by extending the time coordinate to include two Grassmann coordinates. The dynamic variable and the auxiliary variable are encoded in a super-field
From the properties of Grassmann variables
it follows that
We can use Eq. (A.6) to rewrite . We obtain
where we have defined and we have implicitly defined the kernel such that
By inserting the definition of in the partition function, we have
Let us consider the last factor in the integral in (A.9). We can perform the average over the random vectors , denoted by an overline as
where we have defined
By inserting the definitions of and in the partition function, we obtain
We have used that the samples are i.i.d. and dropped the sample index. The brackets denote the average over the random variable $c$, which has the same distribution as the $c^\mu$, over the label, and over the random process of the batch variables, defined by Eq. (9) in the main text. If we perform a change of variables, we obtain
where the effective local action is given by
Performing a Hubbard-Stratonovich transformation on and a set of transformations on the fields , we obtain that we can rewrite as
A.2 Saddle-point equations
We are interested in the large-size limit, in which, according to Eq. (A.12), the partition function is dominated by the saddle-point value of the dynamical order parameters:
The saddle-point equation for gives
The saddle-point equation for is instead
The brackets in the previous equations denote, at the same time, the average over the label, over the process of the batch variables, as well as the average over the noise and over the remaining random variables, whose probability distributions are specified above. In other words, one has a set of kernels that can be obtained as averages over the stochastic process and therefore must be computed self-consistently.
A.3 Numerical solution of DMFT equations
The algorithm to solve the DMFT equations summed up in Eq. (A.20) is the most natural one and can be understood as follows. The outcome of the DMFT is the computation of the kernels and functions appearing in it. They are determined as averages over the stochastic process that is, in turn, defined through them. Therefore, one needs to solve the system of equations self-consistently. The straightforward way to do that is to proceed by iterations: one starts with a random guess of these kernels and then samples the stochastic process several times; from this sampling one constructs a new guess for the kernels, from which a new sampling is done. The algorithm proceeds in this way until the kernels reach a fixed point. As in all iterative solutions of fixed-point equations, it is natural to introduce some damping in the update of the kernels to avoid wild oscillations. This procedure was first implemented in [7, 8] and recently developed further in other applications [26, 17]. However, DMFT has a long tradition in condensed-matter physics, where more involved algorithms have been developed.
Appendix B Generalization error
The generalization error at any time step is defined as the fraction of mislabeled instances:
where $\mathcal{D}$ is the training set, $x$ is an unseen data point and $\hat{y}(x)$ is the estimator for the new label. The dependence on the training set here is hidden in the weight vector $w$.
B.1 Perceptron with linear activation function
In this case, the estimator for a new label is the sign of the pre-activation. The generalization error has been computed in Mignacco et al. (2020) and reads