
Luck Matters: Understanding Training Dynamics of Deep ReLU Networks

by   Yuandong Tian, et al.

We analyze the dynamics of training deep ReLU networks and their implications on generalization capability. Using a teacher-student setting, we discovered a novel relationship between the gradient received by hidden student nodes and the activations of teacher nodes for deep ReLU networks. With this relationship and the assumption of small overlapping teacher node activations, we prove that (1) student nodes whose weights are initialized to be close to teacher nodes converge to them at a faster rate, and (2) in over-parameterized regimes and 2-layer case, while a small set of lucky nodes do converge to the teacher nodes, the fan-out weights of other nodes converge to zero. This framework provides insight into multiple puzzling phenomena in deep learning like over-parameterization, implicit regularization, lottery tickets, etc. We verify our assumption by showing that the majority of BatchNorm biases of pre-trained VGG11/16 models are negative. Experiments on (1) random deep teacher networks with Gaussian inputs, (2) teacher network pre-trained on CIFAR-10 and (3) extensive ablation studies validate our multiple theoretical predictions.



1 Introduction

Although neural networks have made strong empirical progress in a diverse set of domains (e.g., computer vision alexnet ; vgg ; resnet , speech recognition hinton2012deep ; amodei2016deep , natural language processing word2vec ; bert , and games alphago ; alphagozero ; darkforest ; mnih2013playing ), a number of fundamental questions still remain unsolved. How can Stochastic Gradient Descent (SGD) find good solutions to a complicated non-convex optimization problem? Why do neural networks generalize? How can networks trained with SGD fit both random noise and structured data rethinking ; krueger2017deep ; neyshabur2017exploring , but prioritize structured models, even in the presence of massive noise rolnick2017deep ? Why are flat minima related to good generalization? Why does over-parameterization lead to better generalization neyshabur2018towards ; zhang2019identity ; spigler2018jamming ; neyshabur2014search ; li2018measuring ? Why do lottery tickets exist lottery ; lottery-scale ?

In this paper, we propose a theoretical framework for multilayered ReLU networks. Based on this framework, we try to explain these puzzling empirical phenomena with a unified view. We adopt a teacher-student setting where the label provided to an over-parameterized deep student ReLU network is the output of a fixed teacher ReLU network of the same depth and unknown weights (Fig. 1(a)). Here over-parameterization means that at each layer, the number of nodes in the student network is greater than the number of nodes in the teacher network. In this perspective, hidden student nodes are randomly initialized with different activation regions (Fig. 2(a)). During optimization, student nodes compete with each other to explain teacher nodes. From this setting, Theorem 4 shows that lucky student nodes which have greater overlap with teacher nodes converge to those teacher nodes at a fast rate, resulting in winner-take-all behavior. Furthermore, Theorem 5 shows that in the 2-layer case, if a subset of student nodes are close to the teacher nodes, they converge to them and the fan-out weights of other irrelevant nodes of the same layer vanish.

With this framework, we try to intuitively explain various neural network behaviors as follows:

Fitting both structured and random data. Under gradient descent dynamics, some student nodes, which happen to overlap substantially with teacher nodes, will move toward those teacher nodes and cover them. This is true both for structured data, which corresponds to small teacher networks with few intermediate nodes, and for noisy/random data, which corresponds to large teachers with many intermediate nodes. This explains why the same network can fit both structured and random data (Fig. 2(a-b)).

Over-parameterization. In over-parameterization, lots of student nodes are initialized randomly at each layer. Any teacher node is more likely to have a substantial overlap with some student nodes, which leads to fast convergence (Fig. 2(a) and (c), Thm. 4), consistent with lottery ; lottery-scale . This also explains why training models whose capacity just fits the data (or teacher) yields worse performance li2018measuring .

Flat minima. Deep networks often converge to “flat minima” whose Hessian has a lot of small eigenvalues sagun2016eigenvalues ; sagun2017empirical ; lipton2016stuck ; baity2018comparing . Furthermore, while controversial dinh2017sharp , flat minima seem to be associated with good generalization, while sharp minima often lead to poor generalization hochreiter1997flat ; keskar2016large ; wu2017towards ; li2018visualizing . In our theory, when fitting with structured data, only a few lucky student nodes converge to the teacher, while for other nodes, their fan-out weights shrink towards zero, making them (and their fan-in weights) irrelevant to the final outcome (Thm. 5), yielding flat minima in which movement along most dimensions (“unlucky nodes”) results in minimal change in output. On the other hand, sharp minima are related to noisy data (Fig. 2(d)), in which more student nodes match with the teacher.

Implicit regularization. On the other hand, the snapping behavior enforces winner-take-all: after optimization, a teacher node is fully covered (explained) by a few student nodes, rather than splitting amongst student nodes due to over-parameterization. This explains why the same network, once trained with structured data, can generalize to the test set.

Lottery Tickets. Lottery Tickets lottery ; lottery-scale ; zhou2019deconstructing is an interesting phenomenon: if we reset “salient weights” (trained weights with large magnitude) back to the values before optimization but after initialization, prune other weights (often of total weights) and retrain the model, the test performance is the same or better; if we reinitialize salient weights, the test performance is much worse. In our theory, the salient weights are those lucky regions ( and in Fig. 3) that happen to overlap with some teacher nodes after initialization and converge to them in optimization. Therefore, if we reset their weights and prune others away, they can still converge to the same set of teacher nodes, and potentially achieve better performance due to less interference with other irrelevant nodes. However, if we reinitialize them, they are likely to fall into unfavorable regions which cannot cover teacher nodes, and therefore lead to poor performance (Fig. 3(c)), just like in the case of under-parameterization. Recently, Supermask zhou2019deconstructing shows that a supermask can be found from winning tickets. If it is applied to initialized weights, the network without training gives much better test performance than chance. This is also consistent with the intuitive picture in Fig. 3(b).
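The over-parameterization intuition above (more randomly initialized student nodes make it more likely that some student overlaps well with a given teacher node) can be illustrated with a small simulation. This is only a sketch under stated assumptions: it uses the cosine similarity between random weight directions as a stand-in for the overlap of activation regions, and the function name `mean_best_overlap` and all sizes are hypothetical choices, not taken from the paper.

```python
import numpy as np

def mean_best_overlap(num_students, dim=20, trials=2000, seed=0):
    """Average (over random draws) of the best cosine similarity between
    one random teacher direction and `num_students` random student directions.
    Cosine similarity is an assumed proxy for activation-region overlap."""
    rng = np.random.default_rng(seed)
    teacher = rng.standard_normal((trials, dim))
    students = rng.standard_normal((trials, num_students, dim))
    # Normalize all directions to unit length.
    teacher /= np.linalg.norm(teacher, axis=1, keepdims=True)
    students /= np.linalg.norm(students, axis=2, keepdims=True)
    # cos[t, s] = similarity of student s to the teacher in trial t.
    cos = np.einsum("td,tsd->ts", teacher, students)
    return cos.max(axis=1).mean()

overlap_small = mean_best_overlap(num_students=2)    # barely over-parameterized
overlap_large = mean_best_overlap(num_students=50)   # heavily over-parameterized
print(f"best overlap, 2 students:  {overlap_small:.3f}")
print(f"best overlap, 50 students: {overlap_large:.3f}")
```

With more student draws, the best initial overlap with the teacher is larger on average, which is the "lucky draw" picture used in the explanations above.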

Figure 1: (a) Teacher-Student Setting. For each node , the activation region is . (b) A node initialized to overlap substantially with a teacher node converges faster towards (Thm. 4). (c) Student nodes initialized close to teacher nodes converge to them, while the fan-out weights of other irrelevant student nodes go to zero (Thm. 5).
Figure 2: Explanation of implicit regularization. Blue are activation regions of teacher nodes, while orange are students’. (a) When the data labels are structured, the underlying teacher network is small and each layer has few nodes. Over-parameterization (lots of red regions) covers them all. Moreover, those student nodes that heavily overlap with the teacher nodes converge faster (Thm. 4), yielding good generalization performance. (b) If a dataset contains random labels, the underlying teacher network that can fit it has a lot of nodes. Over-parameterization can still handle them and achieve zero training error.
Figure 3: Explanation of lottery ticket phenomenon. (a) A successful training with over-parameterization (2 filters in the teacher network and 4 filters in the student network). Node and are lucky draws with strong overlap with two teacher nodes and , and thus converge with high weight magnitude. (b) Lottery ticket phenomenon: if we initialize node and with the same initial weights, clamp the weights of and to zero, and retrain the model, the test performance becomes better since and still converge to their respective teacher nodes. (c) If we reinitialize node and , it is highly likely that they do not overlap with teacher node and , so the performance is poor.

2 Mathematical Framework

Notation. Consider a student network and its associated teacher network (Fig. 1(a)). Denote the input as $x$. For each node $j$, denote $f_j(x)$ as the activation, $f'_j(x)$ as the ReLU gating (for the top-layer, $f'_c(x)$ and $f'_{c^\circ}(x)$ are always $1$), and $g_j(x)$ as the backpropagated gradient, all as functions of $x$. We use the superscript $\circ$ to represent a teacher node (e.g., $j^\circ$). Therefore, $g_{j^\circ}$ never appears as teacher nodes are not updated. We use $w_{jk}$ to represent the weight between node $j$ and $k$ in the student network. Similarly, $w^*_{j^\circ k^\circ}$ represents the weight between node $j^\circ$ and $k^\circ$ in the teacher network.

We focus on multi-layered ReLU networks. We use the following equality extensively: $\sigma(x) = \sigma'(x)x$. For ReLU node $j$, we use $E_j \equiv \{x : f_j(x) > 0\}$ as the activation region of node $j$.
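As a quick sanity check of the ReLU identity relating a node's activation to its gating (activation = gating × pre-activation), and of what "activation region" means on sampled inputs, here is a minimal NumPy sketch. The helper names `relu` and `relu_gate` and the Gaussian samples are illustrative assumptions.

```python
import numpy as np

def relu(x):
    """ReLU activation."""
    return np.maximum(x, 0.0)

def relu_gate(x):
    """ReLU gating: 1 where the unit is active, 0 otherwise."""
    return (x > 0).astype(x.dtype)

rng = np.random.default_rng(0)
pre = rng.standard_normal(1000)  # pre-activations of one node on 1000 samples

# The identity: relu(x) == gate(x) * x for every sample.
identity_holds = np.allclose(relu(pre), relu_gate(pre) * pre)

# Empirical activation region: the samples on which the node fires.
active_region = pre > 0
print(identity_holds, active_region.mean())
```

The gating is piecewise constant in the input, which is why gradients through ReLU nodes can be written as products of gates and weights.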

Teacher network versus Dataset. The reason why we formulate the problem using teacher network rather than a dataset is the following: (1) It leads to a nice and symmetric formulation for multi-layered ReLU networks (Thm. 1). (2) A teacher network corresponds to an infinite size dataset, which separates the finite sample issues from induction bias in the dataset, which corresponds to the structure of teacher network. (3) If student weights can be shown to converge to teacher ones, generalization bound can naturally follow for the student. (4) The label complexity of data generated from a teacher is automatically reduced, which could lead to better generalization bound. On the other hand, a bound for arbitrary function class can be hard.


Objective. We assume that both the teacher and the student output probabilities over classes. We use the output of the teacher as the training target of the student. At the top layer, each node in the student corresponds to each node in the teacher. Therefore, the objective is:


By the backpropagation rule, we know that for each sample , the (negative) gradient . The gradient gets backpropagated until the first layer is reached.
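A minimal worked form of the objective and its top-layer gradient, assuming a squared-error loss between matching top-layer student and teacher nodes. The symbols $f_c$ (student output), $f_{c^\circ}$ (teacher output), and $g_c$ (negative gradient) are assumed notation for this sketch:

```latex
J(\mathbf{w}) \;=\; \frac{1}{2}\,\mathbb{E}_{x}\!\left[\,\lVert f_c(x) - f_{c^\circ}(x) \rVert^2\,\right],
\qquad
g_c(x) \;\equiv\; -\,\frac{\partial}{\partial f_c}\,\frac{1}{2}\lVert f_c(x) - f_{c^\circ}(x) \rVert^2
\;=\; f_{c^\circ}(x) - f_c(x).
```

Under this assumption, the signal sent down from the top layer is simply the teacher output minus the student output, which is what makes the teacher activations appear in the student's gradient.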

Note that here, the gradient sent to node is correlated with the activation of the corresponding teacher node and other student nodes at the same layer. Intuitively, this means that the gradient “pushes” the student node to align with class of the teacher. If so, then the student learns the corresponding class well. A natural question arises:

Are student nodes at intermediate layers correlated with teacher nodes at the same layers?

One might expect this to be hard, since the student’s intermediate layer receives no direct supervision from the corresponding teacher layer and relies only on the backpropagated gradient. Surprisingly, the following theorem shows that it is possible for every intermediate layer:

Theorem 1 (Recursive Gradient Rule).

If all nodes at layer satisfy Eqn. 2


then all nodes at layer also satisfy Eqn. 2 with and defined recursively from the top layer to the bottom layer as follows:


And for the base case where node and are at the top-most layer, (i.e., when corresponds to , and otherwise).

Note that Theorem 1 applies to arbitrarily deep ReLU networks and allows different numbers of nodes for the teacher and student. The role played by the ReLU activation is to make the expression of concise; otherwise and can take a very complicated (and asymmetric) form.
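The structure of Theorem 1 can be checked numerically in the simplest case. The sketch below assumes a one-hidden-layer teacher and student with a linear top layer and a squared-error loss, so that the top-layer negative gradient is teacher output minus student output. Under those assumptions, the gradient reaching a hidden student node equals its gate times a combination of teacher and student hidden activations, with coefficient matrices `beta_star = W2_s.T @ W2_t` and `beta = W2_s.T @ W2_s`. All variable names and sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_teacher, n_student, n_classes = 5, 3, 8, 4  # student is over-parameterized

# Teacher and student weights (hidden layer, then linear top layer).
W1_t = rng.standard_normal((n_teacher, d))
W2_t = rng.standard_normal((n_classes, n_teacher))
W1_s = rng.standard_normal((n_student, d))
W2_s = rng.standard_normal((n_classes, n_student))

x = rng.standard_normal(d)

# Forward pass: hidden activations and student gates.
h_t = np.maximum(W1_t @ x, 0.0)
pre_s = W1_s @ x
h_s = np.maximum(pre_s, 0.0)
gate_s = (pre_s > 0).astype(float)

# Negative gradient at the top layer (assumed squared-error loss).
g_top = W2_t @ h_t - W2_s @ h_s

# Ordinary backpropagation to the hidden student nodes.
g_hidden_backprop = gate_s * (W2_s.T @ g_top)

# Same gradient written as gate * (beta_star @ teacher_act - beta @ student_act).
beta_star = W2_s.T @ W2_t   # student-node x teacher-node coefficients
beta = W2_s.T @ W2_s        # student-node x student-node coefficients
g_hidden_formula = gate_s * (beta_star @ h_t - beta @ h_s)

max_abs_diff = np.max(np.abs(g_hidden_backprop - g_hidden_formula))
print(max_abs_diff)
```

The two expressions agree up to floating-point error even though teacher and student have different numbers of hidden nodes.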

In particular, we consider the over-parameterization setting: the number of nodes on the student side is much larger (e.g., 5-10x) than the number of nodes on the teacher side. Using Theorem 1, we discover a novel and concise form of gradient update rule:

Assumption 1 (Separation of Expectations).

Theorem 2.

If Assumption 1 holds, the gradient dynamics of deep ReLU networks with objective (Eqn. 1) is:


Here we explain the notation. is teacher weights, , and , , and . We can define similar notations for (which has columns/filters), , , and (Fig. 4(c)). At the lowest layer , , at the highest layer where there is no ReLU, we have due to Eqn. 1. According to network structure, and depend only on weights , while and depend only on .

3 Analysis on the Dynamics

Figure 4: (a) Small overlaps between node activations. Figure drawn in the space spanned by the activations of last layer so all decision boundaries are linear. (b) Lipschitz condition (Assumption 2). (c) Notation in Thm. 2.

In the following, we will use Eqn. 6 to analyze the dynamics of the multi-layer ReLU networks. For convenience, we first define the two functions and ( is the ReLU function):


We assume these two functions have the following property.

Assumption 2 (Lipschitz condition).

There exists and so that:


Using this, we know that , , and so on. For brevity, denote (when notation is heavy) and so on. We impose the following assumption:

Assumption 3 (Small Overlap between teacher nodes).

There exists and so that:


Intuitively, this means that the probability of the simultaneous activation of two teacher nodes and is small. If we have sufficient training data to cover the input space, then a sufficient condition for Assumption 3 to hold is that the teacher nodes have negative biases, which means that they cut corners in the space spanned by the node activations of the lower layer (Fig. 4(a)). We have empirically verified that the majority of biases in BatchNorm layers (after the data are whitened) are negative in VGG11/16 trained on ImageNet (Sec. 4.1).


3.1 Effects of BatchNorm

Batch Normalization batchnorm has been extensively used to speed up the training, reduce the tuning efforts and improve the test performance of neural networks. Here we use an interesting property of BatchNorm: the total “energy” of the incoming weights of each node is conserved over training iterations:

Theorem 3 (Conserved Quantity in Batch Normalization).

For the Linear → ReLU → BN or Linear → BN → ReLU configuration (Fig. 15), the norm of the incoming weights of a filter before BN remains constant during training.

See Appendix for the proof. A similar lemma also appears in arora2018theoretical . This may partially explain why BN has a stabilizing effect: energy will not leak from one layer to nearby ones. Due to this property, in the following, for convenience we assume , and the gradient is always orthogonal to the current weight . Note that on the teacher side we can always push the magnitude component to the upper layer; on the student side, random initialization naturally leads to constant magnitude of weights.
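The orthogonality between the gradient and the current weight can be seen numerically for a single Linear → BN node. The sketch below is an illustration under assumptions, not the paper's proof: it uses a hypothetical linear loss on the whitened output (`bn_loss`), no BN epsilon, and central finite differences. Because whitening makes the output invariant to rescaling the weight vector, the gradient has no component along the weight, so a small gradient step changes the weight norm only at second order in the learning rate.

```python
import numpy as np

def bn_loss(w, X, c):
    """Hypothetical loss of a Linear -> BatchNorm node: a fixed linear
    functional c applied to the whitened (zero-mean, unit-std) outputs."""
    y = X @ w
    y_hat = (y - y.mean()) / np.sqrt(y.var())
    return float(c @ y_hat)

def numerical_grad(f, w, h=1e-6):
    """Central finite-difference gradient of f at w."""
    grad = np.zeros_like(w)
    for i in range(w.size):
        e = np.zeros_like(w)
        e[i] = h
        grad[i] = (f(w + e) - f(w - e)) / (2.0 * h)
    return grad

rng = np.random.default_rng(0)
X = rng.standard_normal((64, 10))   # one batch of inputs
c = rng.standard_normal(64)         # arbitrary top-down signal
w = rng.standard_normal(10)         # incoming weights of the node

grad = numerical_grad(lambda v: bn_loss(v, X, c), w)
radial_component = float(grad @ w)  # close to zero: gradient is orthogonal to w

lr = 1e-3
w_next = w - lr * grad
norm_change = np.linalg.norm(w_next) - np.linalg.norm(w)
print(radial_component, norm_change)
```

In the continuous-time (gradient flow) limit the norm is exactly conserved; with a finite step the drift is proportional to the squared learning rate.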

3.2 Same number of student nodes as teacher

We start with a simple case first. Consider that we only analyze layer without over-parameterization, i.e., . We also assume that , i.e., the input of that layer is whitened, and the top-layer signal is uniform, i.e., (all entries are 1). Then the following theorem shows that weight recovery could follow (we use as ).

Theorem 4.

For dynamics , where is a projection matrix into the orthogonal complement of . , are corresponding -th column in and . Denote and assume . If , then with the rate ( is learning rate). Here and .

See Appendix for the proof. Here we list a few remarks:

Faster convergence near . Because in general becomes larger when (since can be close to ), we expect super-linear convergence near . This brings about an interesting winner-take-all mechanism: if the initial overlap between a student node and a particular teacher node is large, then the student node will snap to it (Fig. 1(c)).

Importance of projection operator . Intuitively, the projection is needed to remove any ambiguity related to weight scaling, in which the output remains constant if the top-layer weights are multiplied by a constant , while the low-layer weights are divided by . Previous work du2017gradient also uses a similar technique, while we justify it with BN. Without , convergence can be harder.

Top-down modulation. Note that here we assume the top-layer signal is uniform, which means that according to , there is no preference on which student node corresponds to which teacher node . If there is a preference (e.g., ), then from the proof, the cross-term will be suppressed due to , making convergence easier. As we will see next, such a top-down modulation plays an important role for 2-layer and over-parameterization case. We believe that it also plays a similar role for deep networks.
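The projection discussed in the remarks above removes the component of an update along the current weight vector. A minimal sketch, assuming the standard orthogonal-complement projector $I - w w^\top / \lVert w \rVert^2$ (the function name `project_orthogonal` is illustrative):

```python
import numpy as np

def project_orthogonal(w, v):
    """Apply (I - w w^T / ||w||^2) to v: drop the component of v along w."""
    return v - w * (w @ v) / (w @ w)

rng = np.random.default_rng(0)
w = rng.standard_normal(6)            # current weight vector
raw_update = rng.standard_normal(6)   # unprojected update direction

update = project_orthogonal(w, raw_update)
residual_along_w = float(update @ w)  # ~0: projected update cannot rescale w

# Projecting twice changes nothing (the operator is idempotent).
update_twice = project_orthogonal(w, update)
print(residual_along_w, np.allclose(update, update_twice))
```

Because the projected update is orthogonal to the weight, it can only rotate the weight direction to first order, which is how the scaling ambiguity is removed.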

3.3 Over-Parameterization and Top-down Modulation in 2-layer Network

In the over-parameterization case (, e.g., 5-10x), we arrange the variables into two parts: , where contains columns (same size as ), while contains columns. We use (or -set) to specify nodes , and (or -set) for the remaining part.

In this case, if we want to show that “the main component” converges to , we face one core question: where will converge to, or will even converge at all? We need to consider not only the dynamics of the current layer, but also the dynamics of the upper layer. Using a 1-hidden layer over-parameterized ReLU network as an example, Theorem 5 shows that the upper-layer dynamics automatically apply top-down modulation to suppress the influence of , regardless of their convergence. Here , where are the weight components of -set. See Fig. 5.

Figure 5: Over-parameterization and top-down modulation. Thm. 5 shows that under certain conditions, the relevant weights and weights connecting to irrelevant student nodes .

Theorem 5 (Over-Parameterization and Top-down Modulation).

Consider with over-parameterization () and its upper-layer dynamics . Assume that initial value is close to : for . If (1) Assumption 3 holds for all pairwise combinations of columns of and , and (2) there exists and so that Eqn. 41 and Eqn. 42 hold, then , and with rate .

See Appendix for the proof (and definition of in Eqn. 45). The intuition is: if is close to and are far away from them due to Assumption 3, the off-diagonal elements of and are smaller than diagonal ones. This causes to move towards and to move towards zero. When becomes small, so does for or . This in turn suppresses the effect of and accelerates the convergence of . exponentially so that stays close to its initial locations, and Assumption 3 holds for all iterations. A few remarks:

Flat minima. Since , can be changed arbitrarily without affecting the outputs of the neural network. This could explain why there are many flat directions in trained networks, and why many eigenvalues of the Hessian are close to zero sagun2016eigenvalues .

Understanding of pruning methods. Theorem 5 naturally relates two different unstructured network pruning approaches: pruning weights of small magnitude han2015learning ; lottery and pruning weights suggested by the Hessian lecun1990optimal ; hassibi1993optimal . It also suggests a principled structured pruning method: instead of pruning a filter by checking its weight norm, prune according to its top-down modulation.

Accelerated convergence and learning rate schedule. For simplicity, the theorem uses a uniform (and conservative) throughout the iterations. In practice, is initially small (due to noise introduced by -set) but will be large after a few iterations when vanishes. Given the same learning rate, this leads to accelerated convergence. At some point, the learning rate becomes too large, leading to fluctuation. In this case, needs to be reduced.

Many-to-one mapping. Theorem 5 shows that under strict conditions, there is a one-to-one correspondence between teacher and student nodes. In general this is not the case. Two student nodes can both be in the vicinity of a teacher node and converge towards it, until that node is fully explained. We leave a rigorous mathematical analysis of many-to-one mappings to future work.

Random initialization. One nice thing about Theorem 5 is that it only requires the initial to be small. In contrast, there is no requirement for small . Therefore, we could expect that with more over-parameterization and random initialization, in each layer , it is more likely to find the -set (of fixed size ), or the lucky weights, so that is quite close to . At the same time, we don’t need to worry about which grows with more over-parameterization. Moreover, random initialization often gives orthogonal weight vectors, which naturally leads to Assumption 3.
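The flat-minima remark above (once the fan-out weights of the irrelevant nodes are zero, their fan-in weights can move freely without changing the output) can be checked directly on a toy network. This is a minimal sketch with hypothetical sizes and names; it only illustrates the end state described by Theorem 5, not the dynamics that lead to it.

```python
import numpy as np

def two_layer_relu(x, W_in, V_out):
    """One-hidden-layer ReLU network with a linear output layer."""
    return V_out @ np.maximum(W_in @ x, 0.0)

rng = np.random.default_rng(0)
d, hidden, out, n_lucky = 5, 8, 3, 3

W_in = rng.standard_normal((hidden, d))     # fan-in weights of hidden nodes
V_out = rng.standard_normal((out, hidden))  # fan-out weights of hidden nodes
V_out[:, n_lucky:] = 0.0                    # "unlucky" nodes: zero fan-out

x = rng.standard_normal(d)
y_before = two_layer_relu(x, W_in, V_out)

# Perturb the fan-in weights of the unlucky nodes by a large amount.
W_perturbed = W_in.copy()
W_perturbed[n_lucky:] += 10.0 * rng.standard_normal((hidden - n_lucky, d))
y_after = two_layer_relu(x, W_perturbed, V_out)

outputs_match = np.allclose(y_before, y_after)
print(outputs_match)
```

Every perturbed direction is a flat direction of the loss, which is consistent with many Hessian eigenvalues being close to zero.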


3.4 Extension to Multi-layer ReLU networks

Using a similar approach, we could extend this analysis to multi-layer cases. We conjecture that similar behaviors happen: for each layer, due to over-parameterization, the weights of some lucky student nodes are close to the teacher ones. While these converge to the teacher, the final values of other irrelevant weights are initialization-dependent. If the irrelevant nodes connect to lucky nodes at the upper layer, then similar to Thm. 5, the corresponding fan-out weights converge to zero. On the other hand, if they connect to nodes that are also irrelevant, then these fan-out weights are undetermined and their final values depend on initialization. However, this does not matter, since these upper-layer irrelevant nodes eventually meet zero weights when going up recursively, as the top-most output layer has no over-parameterization. We leave a formal analysis to future work.

Figure 6: Distribution of BatchNorm bias in pre-trained VGG16 on ImageNet. Orange/blue are positive/negative biases. Conv0 corresponds to the lowest layer (closest to the input). VGG11/13/19 in Fig. 18.

4 Simulations

4.1 Checking Assumption 3

To make Theorem 4 and Theorem 5 work, we make Assumption 3 that the activation fields of different teacher nodes should be well-separated. To justify this, we analyze the bias of BatchNorm layers after the convolutional layers in pre-trained VGG11/13/16/19. We check the BatchNorm bias as these models use the Linear-BatchNorm-ReLU architecture. After BatchNorm first normalizes the input data into a zero-mean distribution, the BatchNorm bias determines how much data passes the ReLU threshold. If the bias is negative, then only a small portion of data passes ReLU gating and Assumption 3 is likely to hold. From Fig. 6, it is quite clear that the majority of BatchNorm bias parameters are negative, in particular for the top layers.
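To make the link between a negative BatchNorm bias and a small active fraction concrete, here is a small worked calculation. It assumes the whitened pre-activation is standard Gaussian and that BatchNorm then applies a nonzero scale and a bias; the function name `fraction_active` and the independence assumption used for the pairwise estimate are hypothetical simplifications.

```python
import math

def fraction_active(bias, scale=1.0):
    """P(scale * z + bias > 0) for z ~ N(0, 1): the share of inputs that
    pass the ReLU after BatchNorm. Assumes scale != 0."""
    return 0.5 * (1.0 + math.erf(bias / (abs(scale) * math.sqrt(2.0))))

for b in (-2.0, -1.0, 0.0, 1.0):
    p = fraction_active(b)
    # If two nodes fired independently (an assumption), p * p would
    # estimate how often they are active at the same time.
    print(f"bias={b:+.1f}  active={p:.4f}  joint(indep.)={p * p:.4f}")
```

Under these assumptions, a more negative bias shrinks each node's activation region, and the chance of two nodes firing together shrinks even faster.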

4.2 Numerical Experiments of Thm. 5

We verify Thm. 5 by checking whether moves close to under different initialization. We use a network with one hidden layer. The teacher network is 10-20-30, while the student network has more nodes in the hidden layers. Input data are Gaussian noise. We initialize the student networks so that the first nodes are close to the teacher. Specifically, we first create matrices and by first filling with i.i.d. Gaussian noise, and then normalizing their columns to . Then the initial value of student is , where is a factor controlling how close is to . For we initialize with noise. Similarly we initialize with a factor . The larger and , the closer the initialization and are to the ground-truth values.
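One way such a controlled initialization could be set up is sketched below. The mixing rule is an assumption made for illustration (a blend of teacher weights and column-normalized noise, renormalized to unit columns); the names `init_student` and `closeness`, the unit column norm, and the 5x student width are hypothetical.

```python
import numpy as np

def normalize_columns(M):
    """Scale every column to unit Euclidean norm (assumed target norm)."""
    return M / np.linalg.norm(M, axis=0, keepdims=True)

def init_student(W_teacher, num_student, closeness, rng):
    """First columns: blend of teacher weights and noise, controlled by
    `closeness` in [0, 1]; remaining columns: pure noise."""
    d, m = W_teacher.shape
    noise = normalize_columns(rng.standard_normal((d, m)))
    lucky = normalize_columns(closeness * W_teacher + (1.0 - closeness) * noise)
    rest = normalize_columns(rng.standard_normal((d, num_student - m)))
    return np.hstack([lucky, rest])

rng = np.random.default_rng(0)
# First layer of a 10-20-30 teacher: 10 inputs, 20 hidden nodes.
W_teacher = normalize_columns(rng.standard_normal((10, 20)))

W_close = init_student(W_teacher, num_student=100, closeness=0.9, rng=rng)
W_far = init_student(W_teacher, num_student=100, closeness=0.1, rng=rng)

# Mean cosine similarity between each teacher column and its paired student column.
cos_close = float(np.sum(W_teacher * W_close[:, :20], axis=0).mean())
cos_far = float(np.sum(W_teacher * W_far[:, :20], axis=0).mean())
print(cos_close, cos_far)
```

A larger `closeness` places the first student columns nearer to the teacher columns, while the remaining columns stay random in both cases.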

Fig. 7 shows the behavior over different iterations. All experiments are repeated 32 times with different random seeds, and (mean ± std) are reported. We can see that a close initialization leads to faster (and low variance) convergence of to small values. In particular, it is important to have close to (large ), which leads to a clear separation between row norms of and , even if they are close to each other at the beginning of training. Having close to makes the initial gap larger and also helps convergence. On the other hand, if is small, then even if is large, the gap between row norms of and only shifts but doesn’t expand over iterations.

Figure 7: Numerical verification of Thm. 5. Rows are , and over-parameterization. Columns are different initializations of student networks. All figures show , i.e., norm of each row of .

5 Experiments

5.1 Experiment Setup

We evaluate both the fully connected (FC) and ConvNet settings. For FC, we use a ReLU teacher network of size 50-75-100-125. For ConvNet, we use a teacher with channel size 64-64-64-64. The student networks have the same depth but with nodes/channels at each layer, such that they are substantially over-parameterized. When BatchNorm is added, it is added after ReLU.

We use random i.i.d. Gaussian inputs with mean 0 and std (abbreviated as GAUS) and CIFAR-10 as our datasets in the experiments. GAUS generates an infinite number of samples while CIFAR-10 is a finite dataset. For GAUS, we use a random teacher network as the label provider (with classes). To make sure the weights of the teacher are weakly overlapped, we sample each entry of from , making sure they are non-zero and mutually different within the same layer, and sample biases from . In the FC case, the data dimension is 20 while in the ConvNet case it is . For CIFAR-10 we use a pre-trained teacher network with BatchNorm. In the FC case, it has an accuracy of ; for ConvNet, the accuracy is . We repeat all experiments 5 times with different random seeds and report min/max values.

Two metrics are used to check our prediction that some lucky student nodes converge to the teacher:

Normalized correlation . We compute normalized correlation (or cosine similarity) between teacher and student activations evaluated on a validation set (footnote 1: for and , , where ). At each layer, we average the best correlation over teacher nodes: , where is computed for each teacher and student pair . means that most teacher nodes are covered by at least one student.

Mean Rank . After training, each teacher node has the most correlated student node . We check the correlation rank of , normalized to ( =rank first), back at initialization and at different epochs, and average them over teacher nodes to yield mean rank . Small means that student nodes that initially correlate well with the teacher keep the lead toward the end of training.
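The two metrics above can be sketched in a few lines of NumPy. This is an illustration under assumptions rather than the authors' evaluation code: activations are taken as arrays of shape (samples, nodes), each node's activation vector is centered and scaled to unit norm before taking inner products, and ranks are normalized to [0, 1] with 0 meaning "ranked first". All function names are hypothetical.

```python
import numpy as np

def normalized_correlation(teacher_act, student_act):
    """Correlation matrix (teachers x students) from activations of shape
    (num_samples, num_nodes); columns are centered and scaled to unit norm."""
    def standardize(a):
        a = a - a.mean(axis=0, keepdims=True)
        return a / (np.linalg.norm(a, axis=0, keepdims=True) + 1e-12)
    return standardize(teacher_act).T @ standardize(student_act)

def mean_best_correlation(rho):
    """Average over teacher nodes of the best-matching student correlation."""
    return rho.max(axis=1).mean()

def mean_rank(rho_init, rho_final):
    """For each teacher node, find its best student after training and look up
    that student's rank at initialization (0 = first); average over teachers."""
    winners = rho_final.argmax(axis=1)
    order = np.argsort(-rho_init, axis=1)   # students sorted by init correlation
    ranks = np.argsort(order, axis=1)       # rank of each student per teacher
    winner_ranks = ranks[np.arange(rho_init.shape[0]), winners]
    return (winner_ranks / (rho_init.shape[1] - 1)).mean()

# Toy usage: the first 3 students copy the 3 teacher nodes, 5 more are random.
rng = np.random.default_rng(0)
teacher_act = np.maximum(rng.standard_normal((500, 3)), 0.0)
extra = np.maximum(rng.standard_normal((500, 5)), 0.0)
student_act = np.hstack([teacher_act, extra])

rho = normalized_correlation(teacher_act, student_act)
rho_bar = mean_best_correlation(rho)
r_bar = mean_rank(rho, rho)  # same snapshot as "init" and "final"
print(rho_bar, r_bar)
```

In the toy usage every teacher node is perfectly covered by a student, so the mean best correlation is essentially 1, and comparing a snapshot with itself gives a mean rank of 0.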

Figure 8: Correlation and mean rank over training on GAUS. steadily grows and quickly improves over time. Layer-0 (the lowest layer that is closest to the input) shows best match with teacher nodes and best mean rank. BatchNorm helps achieve both better correlation and lower , in particular for the CNN case.
Figure 9: Same experiment setting as in Fig. 8 on CIFAR-10. BatchNorm helps achieve lower .
Figure 10: Ablation studies on GAUS. (a) converges much faster in small models (10-15-20-25) than in large models (50-75-100-125). BatchNorm hurts in small models. (b) stalls using finite samples.
Figure 11: Visualization of (transpose of) and matrix before and after optimization (using GAUS). Student node indices are reordered according to teacher-student node correlations. After optimization, a student node that has high correlation with the teacher node also has high entries. Such behavior is more prominent in matrix that combines and the activation patterns of student nodes (Sec. 2).

5.2 Results

Experiments are summarized in Fig. 8 and Fig. 9. indeed grows during training, in particular for low layers that are closer to the input, where moves towards . Furthermore, the final winning student nodes also have a good rank at the early stage of training, in particular after the first epoch, which is consistent with late-resetting used in lottery-scale . BatchNorm helps a lot, in particular for the CNN case with GAUS dataset. For CIFAR-10, the final evaluation accuracy (see Appendix) learned by the student is often higher than the teacher. Using BatchNorm accelerates the growth of accuracy, improves , but seems not to accelerate the growth of .

The theory also predicts that the top-down modulation helps the convergence. For this, we plot at different layers during optimization on GAUS. For better visualization, we align each student node index with a teacher node according to highest . Despite the fact that correlations are computed from the low-layer weights, it matches well with the top-layer modulation (identity matrix structure in Fig. 11). Besides, we also perform ablation studies on GAUS.

Size of teacher network. As shown in Fig. 10(a), for small teacher networks (FC 10-15-20-25), the convergence is much faster and training without BatchNorm is faster than training with BatchNorm. For large teacher networks, BatchNorm definitely increases convergence speed and growth of .

Figure 12: Effect of different degrees of over-parameterization. Left panel: Teacher FC (10-15-20-25) without batchnorm, right panel: teacher CNN (10-15-20-25) with batchnorm.

Degree of over-parameterization. Fig. 12 shows the effects of different degrees of over-parameterization (, , , , and ). We initialize 32 different teacher networks (10-15-20-25) with different random seeds, and plot the standard deviation as a shaded area. We can clearly see that grows more stably and converges to higher values with over-parameterization. On the other hand, and are slower in convergence due to excessive parameters.

Finite versus Infinite Dataset. We also repeat the experiments with a pre-generated finite dataset of GAUS in the CNN case (Fig. 10(b)), and find that the convergence of node similarity stalls after a few iterations. This is because some nodes receive very few data points in their activated regions, which is not a problem for an infinite dataset. We suspect that this is why CIFAR-10, as a finite dataset, does not show behavior similar to GAUS.

6 Conclusion and Future Work

In this paper we propose a new theoretical framework that uses a teacher-student setting to understand the training dynamics of multi-layered ReLU networks. With this framework, we are able to conceptually explain many puzzling phenomena in deep networks, such as why over-parameterization helps generalization, why the same network can fit both random and structured data, and why lottery tickets lottery ; lottery-scale exist. We back up these intuitive explanations with Theorem 4 and Theorem 5, which collectively show that student nodes that are initialized to be close to the teacher nodes converge to them at a faster rate, and the fan-out weights of irrelevant nodes converge to zero. As the next steps, we aim to extend Theorem 5 to the general multi-layer setting (when both and are present), relax Assumption 3, and study more BatchNorm effects than what Theorem 3 suggests.

7 Acknowledgement

The first author thanks Simon Du, Jason Lee, Chiyuan Zhang, Rong Ge, Greg Yang, Jonathan Frankle and many others for the informal discussions.


8 Appendix: Proofs

Figure 13: Teacher-Student Setting, loss function and notations.

8.1 Theorem 1


The first part of gradient backpropagated to node is:


Therefore, for the gradient to node , we have:


And similarly for . Therefore, by mathematical induction, the gradients at nodes in every layer follow the same form. ∎

Figure 14: BatchNorm explanation
Figure 15: Different BatchNorm Configuration.

8.2 Theorem 2


Using Thm. 1, we can write down weight update for weight that connects node and node :


Note that , , and run over all parents and children nodes on the teacher side. This formulation works for over-parameterization (e.g., and can run over different nodes). Applying Assumption 1 and rearranging terms in matrix form yields Eqn. 6. ∎

8.3 Theorem 3


Given a batch with size , denote pre-batchnorm activations as and gradients as (See Fig. 14(a)). is its whitened version, and is the final output of BN. Here and and , are learnable parameters. With vector notation, the gradient update in BN has a compact form with clear geometric meaning:

Lemma 1 (Backpropagation of Batch Norm [35]).

For a top-down gradient , BN layer gives the following gradient update ( is the orthogonal complementary projection of subspace ):


Intuitively, the back-propagated gradient is zero-mean and perpendicular to the input activation of the BN layer, as illustrated in Fig. 14. Unlike [16, 38], which analyze BN in an approximate manner, in Thm. 1 we do not impose any assumptions.

Given Lemma 1, we can prove Thm. 3. For Fig. 15(a), using the property that (the expectation is taken over batch) and the weight update rule (over the same batch), we have:


For Fig. 15(b), note that and conclusion follows. ∎

8.4 Lemmas

For simplicity, in the following, we use .

Lemma 2 (Bottom Bounds).

Assume all . Denote


If Assumption 2 holds, we have:


If Assumption 3 also holds, then:


We have for :


If Assumption 3 also holds, we have:


Lemma 3 (Top Bounds).



If Assumption 2 holds, we have:


If Assumption 3 also holds, then:


The proof is similar to Lemma 2. ∎

Figure 16: Explanation of Lemma. 4.
Lemma 4 (Quadratic fall-off for diagonal elements of ).

For node , we have:


The intuition here is that both the volume of the affected area and the weight difference are proportional to . is their product and thus proportional to . See Fig. 16. ∎

8.5 Theorem 4


First of all, note that . So given , we also have a bound for .

When , the matrix form can be written as the following:


by using (and thus doesn’t matter). Since is conserved, it suffices to check whether the projected weight vector of onto the complementary space of the ground truth node , goes to zero: