Understanding Training Dynamics of Deep ReLU Networks
We analyze the dynamics of training deep ReLU networks and their implications on generalization capability. Using a teacher-student setting, we discover a novel relationship between the gradient received by hidden student nodes and the activations of teacher nodes for deep ReLU networks. With this relationship and the assumption of small overlapping teacher node activations, we prove that (1) student nodes whose weights are initialized to be close to teacher nodes converge to them at a faster rate, and (2) in the over-parameterized regime and the 2-layer case, while a small set of lucky nodes do converge to the teacher nodes, the fan-out weights of other nodes converge to zero. This framework provides insight into multiple puzzling phenomena in deep learning like over-parameterization, implicit regularization, lottery tickets, etc. We verify our assumption by showing that the majority of BatchNorm biases of pre-trained VGG11/16 models are negative. Experiments on (1) random deep teacher networks with Gaussian inputs, (2) a teacher network pre-trained on CIFAR-10 and (3) extensive ablation studies validate our multiple theoretical predictions.
), a number of fundamental questions still remain unsolved. How can Stochastic Gradient Descent (SGD) find good solutions to a complicated non-convex optimization problem? Why do neural networks generalize? How can networks trained with SGD fit both random noise and structured data rethinking ; krueger2017deep ; neyshabur2017exploring , but prioritize structured models, even in the presence of massive noise rolnick2017deep ? Why are flat minima related to good generalization? Why does over-parameterization lead to better generalization neyshabur2018towards ; zhang2019identity ; spigler2018jamming ; neyshabur2014search ; li2018measuring ? Why do lottery tickets exist lottery ; lottery-scale ?
In this paper, we propose a theoretical framework for multilayered ReLU networks. Based on this framework, we try to explain these puzzling empirical phenomena with a unified view. We adopt a teacher-student setting where the label provided to an over-parameterized deep student ReLU network is the output of a fixed teacher ReLU network of the same depth and unknown weights (Fig. 1(a)). Here over-parameterization means that at each layer, the student network has more nodes than the teacher network. In this perspective, hidden student nodes are randomly initialized with different activation regions (Fig. 2(a)). During optimization, student nodes compete with each other to explain teacher nodes. From this setting, Theorem 4 shows that lucky student nodes which have greater overlap with teacher nodes converge to those teacher nodes at a fast rate, resulting in winner-take-all behavior. Furthermore, Theorem 5 shows that in the 2-layer case, if a subset of student nodes are close to the teacher nodes, they converge to them and the fan-out weights of other irrelevant nodes of the same layer vanish.
With this framework, we try to intuitively explain various neural network behaviors as follows:
Fitting both structured and random data. Under gradient descent dynamics, some student nodes, which happen to overlap substantially with teacher nodes, will move into the teacher nodes and cover them. This is true both for structured data that correspond to small teacher networks with few intermediate nodes, and for noisy/random data that correspond to large teachers with many intermediate nodes. This explains why the same network can fit both structured and random data (Fig. 2(a-b)).
Over-parameterization. In over-parameterization, lots of student nodes are initialized randomly at each layer. Any teacher node is more likely to have a substantial overlap with some student nodes, which leads to fast convergence (Fig. 2(a) and (c), Thm. 4), consistent with lottery ; lottery-scale . This also explains why training models whose capacity just fits the data (or teacher) yields worse performance li2018measuring .
Flat minima. Deep networks often converge to “flat minima” whose Hessian has a lot of small eigenvalues sagun2016eigenvalues ; sagun2017empirical ; lipton2016stuck ; baity2018comparing . Furthermore, while controversial dinh2017sharp , flat minima seem to be associated with good generalization, while sharp minima often lead to poor generalization hochreiter1997flat ; keskar2016large ; wu2017towards ; li2018visualizing . In our theory, when fitting with structured data, only a few lucky student nodes converge to the teacher, while for other nodes, their fan-out weights shrink towards zero, making them (and their fan-in weights) irrelevant to the final outcome (Thm. 5), yielding flat minima in which movement along most dimensions (“unlucky nodes”) results in minimal change in output. On the other hand, sharp minima are related to noisy data (Fig. 2(d)), in which more student nodes match with the teacher.
Implicit regularization. On the other hand, the snapping behavior enforces winner-take-all: after optimization, a teacher node is fully covered (explained) by a few student nodes, rather than splitting amongst student nodes due to over-parameterization. This explains why the same network, once trained with structured data, can generalize to the test set.
Lottery Tickets. Lottery Tickets lottery ; lottery-scale ; zhou2019deconstructing is an interesting phenomenon: if we reset “salient weights” (trained weights with large magnitude) back to the values before optimization but after initialization, prune other weights (often of total weights) and retrain the model, the test performance is the same or better; if we reinitialize salient weights, the test performance is much worse. In our theory, the salient weights are those lucky regions ( and in Fig. 3) that happen to overlap with some teacher nodes after initialization and converge to them in optimization. Therefore, if we reset their weights and prune others away, they can still converge to the same set of teacher nodes, and potentially achieve better performance due to less interference with other irrelevant nodes. However, if we reinitialize them, they are likely to fall into unfavorable regions which cannot cover teacher nodes, and therefore lead to poor performance (Fig. 3(c)), just like in the case of under-parameterization. Recently, Supermask zhou2019deconstructing shows that a supermask can be found from winning tickets. If it is applied to initialized weights, the network without training gives much better test performance than chance. This is also consistent with the intuitive picture in Fig. 3(b).
Notation. Consider a student network and its associated teacher network (Fig. 1(a)). Denote the input as . For each node , denote as the activation, as the ReLU gating (for the top-layer, and are always ), and
as the backpropagated gradient, all as functions of. We use the superscript to represent a teacher node (e.g., ). Therefore, never appears as teacher nodes are not updated. We use to represent weight between node and in the student network. Similarly, represents the weight between node and in the teacher network.
We focus on multi-layered ReLU networks. We use the following equality extensively: . For ReLU node , we use as the activation region of node .
Teacher network versus Dataset. We formulate the problem using a teacher network rather than a dataset for the following reasons: (1) It leads to a nice and symmetric formulation for multi-layered ReLU networks (Thm. 1). (2) A teacher network corresponds to an infinite-size dataset, which separates finite-sample issues from the inductive bias in the dataset, which corresponds to the structure of the teacher network. (3) If student weights can be shown to converge to teacher ones, a generalization bound naturally follows for the student. (4) The label complexity of data generated from a teacher is automatically reduced, which could lead to a better generalization bound. On the other hand, a bound for an arbitrary function class can be hard to obtain.
. We assume that both the teacher and the student output probabilities over classes. We use the output of teacher as the input of the student. At the top layer, each node in the student corresponds to each node in the teacher. Therefore, the objective is:
By the backpropagation rule, we know that for each sample , the (negative) gradient . The gradient gets backpropagated until the first layer is reached.
Note that here, the gradient sent to node is correlated with the activation of the corresponding teacher node and other student nodes at the same layer. Intuitively, this means that the gradient “pushes” the student node to align with class of the teacher. If so, then the student learns the corresponding class well. A natural question arises:
Are student nodes at intermediate layers correlated with teacher nodes at the same layers?
One might think this is hard, since the student’s intermediate layer receives no direct supervision from the corresponding teacher layer, but relies only on the backpropagated gradient. Surprisingly, the following theorem shows that it is possible for every intermediate layer:
Note that Theorem 1 applies to arbitrarily deep ReLU networks and allows different numbers of nodes for the teacher and student. The role played by ReLU activation is to make the expression of concise, otherwise and can take a very complicated (and asymmetric) form.
In particular, we consider the over-parameterization setting: the number of nodes on the student side is much larger (e.g., 5-10x) than the number of nodes on the teacher side. Using Theorem 1, we discover a novel and concise form of gradient update rule:
Here we explain the notations. is teacher weights, , and , , and . We can define similar notations for (which has columns/filters), , , and (Fig. 4(c)). At the lowest layer , , at the highest layer where there is no ReLU, we have due to Eqn. 1. According to network structure, and only depends on weights , while and only depend on .
In the following, we will use Eqn. 6 to analyze the dynamics of the multi-layer ReLU networks. For convenience, we first define the two functions and ( is the ReLU function):
We assume these two functions have the following property .
There exists and so that:
Using this, we know that , , and so on. For brevity, denote (when notation is heavy) and so on. We impose the following assumption:
There exists and so that:
Intuitively, this means that the probability of the simultaneous activation of two teacher nodes and is small. If we have sufficient training data to cover the input space, then a sufficient condition for Assumption 3 to hold is that the teacher nodes have negative biases, which means that they cut corners in the space spanned by the node activations of the lower layer (Fig. 4(a)). We have empirically verified that the majority of biases in BatchNorm layers (after the data are whitened) are negative in VGG11/16 trained on ImageNet (Sec. 4.1).
Batch Normalization batchnorm has been extensively used to speed up the training, reduce the tuning efforts and improve the test performance of neural networks. Here we use an interesting property of BatchNorm: the total “energy” of the incoming weights of each node is conserved over training iterations:
For a Linear-ReLU-BN or Linear-BN-ReLU configuration, of a filter before BN remains constant in training (Fig. 15).
See Appendix for the proof. A similar lemma also appears in arora2018theoretical . This may partially explain why BN has a stabilizing effect: energy will not leak from one layer to nearby ones. Due to this property, in the following, for convenience we assume , and the gradient is always orthogonal to the current weight . Note that on the teacher side we can always push the magnitude component to the upper layer; on the student side, random initialization naturally leads to constant magnitude of weights.
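The orthogonality between the gradient and the current weight can be checked numerically. The sketch below is only an illustration under assumptions that are not in the paper: a single filter in a Linear-BN-ReLU configuration, batch normalization without the usual epsilon term or learnable scale and bias, a squared loss against an arbitrary target, and a central finite-difference gradient. All names (`bn_relu_loss`, `numerical_grad`) are hypothetical.

```python
import numpy as np

def bn_relu_loss(w, X, target):
    """Loss of one filter under Linear -> BN -> ReLU (BN without epsilon, scale or bias)."""
    f = X @ w                             # pre-BN activations over the batch
    f_hat = (f - f.mean()) / f.std()      # whitening: zero mean, unit variance
    out = np.maximum(f_hat, 0.0)          # ReLU
    return 0.5 * np.mean((out - target) ** 2)

def numerical_grad(fn, w, h=1e-6):
    """Central finite-difference gradient of fn at w."""
    g = np.zeros_like(w)
    for i in range(w.size):
        e = np.zeros_like(w)
        e[i] = h
        g[i] = (fn(w + e) - fn(w - e)) / (2 * h)
    return g

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8))     # a batch of inputs (assumed Gaussian)
target = rng.normal(size=256)     # arbitrary regression target (assumption)
w = rng.normal(size=8)            # incoming weights of the filter

g = numerical_grad(lambda v: bn_relu_loss(v, X, target), w)
# Cosine of the angle between the weight and its gradient; close to zero.
cosine = float(w @ g / (np.linalg.norm(w) * np.linalg.norm(g)))
print(f"cosine(w, grad) = {cosine:.2e}")
```

Because the normalized activation does not change when `w` is rescaled by a positive constant, the loss is flat along the direction of `w`, which is why the gradient has no component along it and an infinitesimal gradient step leaves the weight norm unchanged.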
We start with a simple case first. Consider that we only analyze layer without over-parameterization, i.e., . We also assume that , i.e., the input of that layer is whitened, and the top-layer signal is uniform, i.e., (all entries are 1). Then the following theorem shows that weight recovery could follow (we use as ).
For dynamics , where is a projection matrix into the orthogonal complement of . , are corresponding -th column in and . Denote and assume . If , then with the rate ( is learning rate). Here and .
See Appendix for the proof. Here we list a few remarks:
Faster convergence near . We can see that due to the fact that in general becomes larger when (since can be close to ), we expect a super-linear convergence near . This brings about an interesting winner-take-all mechanism: if the initial overlap between a student node and a particular teacher node is large, then the student node will snap to it (Fig. 1(c)).
Importance of projection operator . Intuitively, the projection is needed to remove any ambiguity related to weight scaling, in which the output remains constant if the top-layer weights are multiplied by a constant , while the low-layer weights are divided by . Previous work du2017gradient also uses similar techniques, while we justify it with BN. Without , convergence can be harder.
Top-down modulation. Note that here we assume the top-layer signal is uniform, which means that according to , there is no preference on which student node corresponds to which teacher node . If there is a preference (e.g., ), then from the proof, the cross-term will be suppressed due to , making convergence easier. As we will see next, such a top-down modulation plays an important role for 2-layer and over-parameterization case. We believe that it also plays a similar role for deep networks.
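The scaling ambiguity and the role of the projection mentioned in the remarks above can be illustrated with a small NumPy sketch. It assumes a bias-free 2-layer ReLU network and a positive scaling constant, and it uses the standard projection onto the orthogonal complement of a vector; the names `two_layer` and `project_orthogonal` are hypothetical and not taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 4))       # a few inputs
W1 = rng.normal(size=(4, 3))      # low-layer weights
W2 = rng.normal(size=(3, 2))      # top-layer weights
c = 2.5                           # any positive constant

def two_layer(x, W1, W2):
    """Bias-free 2-layer ReLU network (assumption for this sketch)."""
    return np.maximum(x @ W1, 0.0) @ W2

# Scaling ambiguity: dividing the low layer by c and multiplying the
# top layer by c leaves the output unchanged, since ReLU(a / c) = ReLU(a) / c.
out_ref = two_layer(x, W1, W2)
out_scaled = two_layer(x, W1 / c, W2 * c)

def project_orthogonal(w, g):
    """Remove from g its component along w."""
    return g - w * (w @ g) / (w @ w)

w = W1[:, 0]                      # incoming weights of one hidden node
g = rng.normal(size=4)            # an arbitrary update direction
g_perp = project_orthogonal(w, g)

print(np.allclose(out_ref, out_scaled))   # the two networks agree
print(abs(float(w @ g_perp)))             # projected update is orthogonal to w
```

Updating `w` only along `g_perp` changes its direction but, to first order, not its norm, so the dynamics cannot drift along the scaling direction that leaves the output unchanged.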
In the over-parameterization case (, e.g., 5-10x), we arrange the variables into two parts: , where contains columns (same size as ), while contains columns. We use (or -set) to specify nodes , and (or -set) for the remaining part.
In this case, if we want to show “the main component” converges to , we will meet with one core question: to where will converge, or whether will even converge at all? We need to consider not only the dynamics of the current layer, but also the dynamics of the upper layer. Using a 1-hidden layer over-parameterized ReLU network as an example, Theorem 5 shows that the upper-layer dynamics automatically apply top-down modulation to suppress the influence of , regardless of their convergence. Here , where are the weight components of -set. See Fig. 5.
See Appendix for the proof (and definition of in Eqn. 45). The intuition is: if is close to and are far away from them due to Assumption 3, the off-diagonal elements of and are smaller than diagonal ones. This causes to move towards and to move towards zero. When becomes small, so does for or . This in turn suppresses the effect of and accelerates the convergence of . exponentially so that stays close to its initial locations, and Assumption 3 holds for all iterations. A few remarks:
Flat minima. Since , can be changed arbitrarily without affecting the outputs of the neural network. This could explain why there are many flat directions in trained networks, and why many eigenvalues of the Hessian are close to zero sagun2016eigenvalues .
Understanding of pruning methods. Theorem 5 naturally relates two different unstructured network pruning approaches: pruning weights of small magnitude han2015learning ; lottery and pruning weights suggested by the Hessian lecun1990optimal ; hassibi1993optimal . It also suggests a principled structured pruning method: instead of pruning a filter by checking its weight norm, prune it according to its top-down modulation.
Accelerated convergence and learning rate schedule. For simplicity, the theorem uses a uniform (and conservative) throughout the iterations. In practice, is initially small (due to noise introduced by -set) but will be large after a few iterations when vanishes. Given the same learning rate, this leads to accelerated convergence. At some point, the learning rate becomes too large, leading to fluctuation. In this case, needs to be reduced.
Many-to-one mapping. Theorem 5 shows that under strict conditions, there is a one-to-one correspondence between teacher and student nodes. In general this is not the case. Two student nodes can both be in the vicinity of a teacher node and converge towards it, until that node is fully explained. We leave a rigorous mathematical analysis of many-to-one mappings to future work.
Random initialization. One nice thing about Theorem 5 is that it only requires the initial to be small. In contrast, there is no requirement for small . Therefore, we could expect that with more over-parameterization and random initialization, in each layer , it is more likely to find the -set (of fixed size ), or the lucky weights, so that is quite close to . At the same time, we don’t need to worry about which grows with more over-parameterization. Moreover, random initialization often gives orthogonal weight vectors, which naturally leads to Assumption 3.
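The flat-minima remark above can be made concrete with a short NumPy sketch. It assumes a bias-free network with one hidden ReLU layer in which the fan-out weights of the "unlucky" nodes are set exactly to zero; the sizes and variable names are chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hidden, n_out, n_lucky = 6, 10, 3, 4   # illustrative sizes (assumption)

W_in = rng.normal(size=(n_in, n_hidden))       # fan-in weights of hidden nodes
W_out = rng.normal(size=(n_hidden, n_out))     # fan-out weights of hidden nodes
W_out[n_lucky:] = 0.0                          # unlucky nodes: fan-out weights are zero

def forward(x, W_in, W_out):
    return np.maximum(x @ W_in, 0.0) @ W_out

x = rng.normal(size=(32, n_in))
before = forward(x, W_in, W_out)

# Perturb the fan-in weights of the unlucky nodes by an arbitrary amount.
W_in_perturbed = W_in.copy()
W_in_perturbed[:, n_lucky:] += rng.normal(size=(n_in, n_hidden - n_lucky))
after = forward(x, W_in_perturbed, W_out)

max_change = float(np.max(np.abs(after - before)))
print(max_change)   # the output does not depend on the perturbed weights
```

Every coordinate of the perturbed fan-in weights is a direction along which the output, and hence the loss, does not change, which corresponds to a zero eigenvalue of the Hessian at such a solution.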
Using a similar approach, we could extend this analysis to multi-layer cases. We conjecture that similar behaviors happen: for each layer, due to over-parameterization, the weights of some lucky student nodes are close to the teacher ones. While these converge to the teacher, the final values of other irrelevant weights are initialization-dependent. If the irrelevant nodes connect to lucky nodes at the upper layer, then similar to Thm. 5, the corresponding fan-out weights converge to zero. On the other hand, if they connect to nodes that are also irrelevant, then these fan-out weights are undetermined and their final values depend on initialization. However, this doesn’t matter since these upper-layer irrelevant nodes eventually meet with zero weights if going up recursively, since the top-most output layer has no over-parameterization. We leave a formal analysis to future work.
To make Theorem 4 and Theorem 5 work, we make Assumption 3 that the activation field of different teacher nodes should be well-separated. To justify this, we analyze the bias of BatchNorm layers after the convolutional layers in pre-trained VGG11/13/16/19. We check the BatchNorm bias as these models use Linear-BatchNorm-ReLU architecture. After BatchNorm first normalizes the input data into zero mean distribution, the BatchNorm bias determines how much data pass the ReLU threshold. If the bias is negative, then a small portion of data pass ReLU gating and Assumption 3 is likely to hold. From Fig. 6, it is quite clear that the majority of BatchNorm bias parameters are negative, in particular for the top layers.
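The link between a negative bias and sparse activation can be illustrated numerically. The sketch below assumes that the normalized pre-activation is standard Gaussian and that the BatchNorm scale is 1, neither of which is guaranteed for a real network; `active_fraction` and `gaussian_cdf` are hypothetical helper names.

```python
import math
import numpy as np

def active_fraction(bias, n_samples=200_000, seed=0):
    """Fraction of standard Gaussian pre-activations z with ReLU(z + bias) > 0."""
    rng = np.random.default_rng(seed)
    z = rng.normal(size=n_samples)
    return float(np.mean(z + bias > 0))

def gaussian_cdf(x):
    """Standard normal CDF, used as the closed-form reference."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

for bias in (-2.0, -1.0, 0.0, 1.0):
    empirical = active_fraction(bias)
    # P(z + bias > 0) = P(z > -bias) = Phi(bias) for a standard Gaussian z.
    closed_form = gaussian_cdf(bias)
    print(f"bias={bias:+.1f}  empirical={empirical:.4f}  closed form={closed_form:.4f}")
```

Under these assumptions a node with zero bias is active on half of the inputs, while a more negative bias shrinks its activation region, which makes simultaneous activation of two teacher nodes less likely.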
We verify Thm. 5 by checking whether moves close to under different initialization. We use a network with one hidden layer. The teacher network is 10-20-30, while the student network has more nodes in the hidden layers. Input data are Gaussian noise. We initialize the student networks so that the first nodes are close to the teacher. Specifically, we first create matrices and by first filling with i.i.d. Gaussian noise, and then normalizing their columns to . Then the initial value of student is , where is a factor controlling how close is to . For we initialize with noise. Similarly we initialize with a factor . The larger and , the closer the initialization and to the ground truth values.
Fig. 7 shows the behavior over different iterations. All experiments are repeated 32 times with different random seeds, and (mean ± std) are reported. We can see that a close initialization leads to faster (and low variance) convergence of to small values. In particular, it is important to have close to (large ), which leads to a clear separation between row norms of and , even if they are close to each other at the beginning of training. Having close to makes the initial gap larger and also helps convergence. On the other hand, if is small, then even if is large, the gap between row norms of and only shifts but doesn’t expand over iterations.
We evaluate both the fully connected (FC) and ConvNet setting. For FC, we use a ReLU teacher network of size 50-75-100-125. For ConvNet, we use a teacher with channel size 64-64-64-64. The student networks have the same depth but with nodes/channels at each layer, such that they are substantially over-parameterized. When BatchNorm is added, it is added after ReLU.
We use random i.i.d. Gaussian inputs with mean 0 and std (abbreviated as GAUS) and CIFAR-10 as our datasets in the experiments. GAUS generates an infinite number of samples while CIFAR-10 is a finite dataset. For GAUS, we use a random teacher network as the label provider (with classes). To make sure the weights of the teacher are weakly overlapped, we sample each entry of from , making sure they are non-zero and mutually different within the same layer, and sample biases from . In the FC case, the data dimension is 20 while in the ConvNet case it is . For CIFAR-10 we use a pre-trained teacher network with BatchNorm. In the FC case, it has an accuracy of ; for ConvNet, the accuracy is . We repeat all experiments 5 times with different random seeds and report min/max values.
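A minimal sketch of the GAUS setting for the FC case is given below. Only the teacher hidden sizes (50-75-100-125) and the input dimension (20) come from the text; the number of classes, the over-parameterization multiplier, the weight initialization, the zero biases, the unit input std and the squared loss are all assumptions made for illustration, and `make_relu_net`, `forward` and `sample_batch` are hypothetical names.

```python
import numpy as np

def make_relu_net(sizes, rng):
    """Return a list of (weight, bias) pairs for a fully connected net."""
    return [
        (rng.normal(size=(n_in, n_out)) / np.sqrt(n_in), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]

def forward(net, x):
    """ReLU on every hidden layer, linear output at the top layer."""
    h = x
    for layer, (W, b) in enumerate(net):
        h = h @ W + b
        if layer < len(net) - 1:
            h = np.maximum(h, 0.0)
    return h

rng = np.random.default_rng(0)
input_dim = 20                          # from the text (FC case)
teacher_hidden = [50, 75, 100, 125]     # from the text
num_classes, multiplier = 10, 2         # assumed values, not from the text
student_hidden = [multiplier * n for n in teacher_hidden]

teacher = make_relu_net([input_dim, *teacher_hidden, num_classes], rng)
student = make_relu_net([input_dim, *student_hidden, num_classes], rng)

def sample_batch(batch_size):
    """Fresh Gaussian inputs labeled by the fixed teacher (an 'infinite' dataset)."""
    x = rng.normal(size=(batch_size, input_dim))
    return x, forward(teacher, x)

x, y = sample_batch(64)
loss = 0.5 * np.mean(np.sum((forward(student, x) - y) ** 2, axis=1))
print(x.shape, y.shape, float(loss))
```

Since every call to `sample_batch` draws new inputs, the student never sees the same sample twice, which is what distinguishes this setting from training on a pre-generated finite dataset.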
Two metrics are used to check our prediction that some lucky student nodes converge to the teacher:
Normalized correlation. We compute the normalized correlation (or cosine similarity) between teacher and student activations, evaluated on a validation set (footnote: for and , , where ). At each layer, we average the best correlation over teacher nodes: , where is computed for each teacher and student pair . means that most teacher nodes are covered by at least one student.
Mean Rank . After training, each teacher node has the most correlated student node . We check the correlation rank of , normalized to ( =rank first), back at initialization and at different epochs, and average them over teacher nodes to yield mean rank. Small means that student nodes that initially correlate well with the teacher keep the lead toward the end of training.
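The two metrics can be sketched in NumPy as follows. The exact normalization is not recoverable from the text above, so this sketch assumes mean-centered activations scaled to unit norm, and a rank normalized to the range from 0 to 1 with 0 meaning rank first; all function names are hypothetical.

```python
import numpy as np

def normalized_correlation(teacher_act, student_act):
    """Correlation matrix of shape (n_teacher, n_student).

    Inputs have shape (n_samples, n_nodes). Each column is mean-centered
    and scaled to unit norm before taking inner products (assumption).
    """
    def normalize(a):
        a = a - a.mean(axis=0, keepdims=True)
        return a / (np.linalg.norm(a, axis=0, keepdims=True) + 1e-12)
    return normalize(teacher_act).T @ normalize(student_act)

def mean_best_correlation(corr):
    """Average over teacher nodes of the best correlation with any student node."""
    return float(corr.max(axis=1).mean())

def mean_rank(corr_early, corr_final):
    """Average normalized rank, at an earlier stage, of each teacher's final winner."""
    winners = corr_final.argmax(axis=1)            # most correlated student per teacher
    n_student = corr_early.shape[1]
    ranks = []
    for j, k in enumerate(winners):
        order = np.argsort(-corr_early[j])         # students by descending correlation
        rank = int(np.where(order == k)[0][0])     # 0 means rank first
        ranks.append(rank / (n_student - 1))
    return float(np.mean(ranks))

# Toy usage: the student contains exact copies of the 3 teacher nodes plus 5 noise nodes.
rng = np.random.default_rng(1)
teacher_act = rng.normal(size=(500, 3))
student_act = np.concatenate([rng.normal(size=(500, 5)), teacher_act], axis=1)
corr = normalized_correlation(teacher_act, student_act)
print(mean_best_correlation(corr), mean_rank(corr, corr))
```

In practice `corr_early` would be computed from activations at initialization or at an early epoch and `corr_final` from the trained student, both on the same validation inputs.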
Experiments are summarized in Fig. 8 and Fig. 9. indeed grows during training, in particular for low layers that are closer to the input, where moves towards . Furthermore, the final winning student nodes also have a good rank at the early stage of training, in particular after the first epoch, which is consistent with late-resetting used in lottery-scale . BatchNorm helps a lot, in particular for the CNN case with GAUS dataset. For CIFAR-10, the final evaluation accuracy (see Appendix) learned by the student is often higher than the teacher. Using BatchNorm accelerates the growth of accuracy, improves , but seems not to accelerate the growth of .
The theory also predicts that the top-down modulation helps the convergence. For this, we plot at different layers during optimization on GAUS. For better visualization, we align each student node index with a teacher node according to highest . Despite the fact that correlations are computed from the low-layer weights, it matches well with the top-layer modulation (identity matrix structure in Fig. 11). Besides, we also perform ablation studies on GAUS.
Size of teacher network. As shown in Fig. 10(a), for small teacher networks (FC 10-15-20-25), the convergence is much faster and training without BatchNorm is faster than training with BatchNorm. For large teacher networks, BatchNorm definitely increases convergence speed and growth of .
Degree of over-parameterization. Fig. 12 shows the effects of different degrees of over-parameterization (, , , , and ). We initialize 32 different teacher networks (10-15-20-25) with different random seeds, and plot the standard deviation as a shaded area. We can clearly see that grows more stably and converges to higher values with over-parameterization. On the other hand, and are slower in convergence due to excessive parameters.
Finite versus Infinite Dataset. We also repeat the experiments with a pre-generated finite dataset of GAUS in the CNN case (Fig. 10(b)), and find that the convergence of node similarity stalls after a few iterations. This is because some nodes receive very few data points in their activated regions, which is not a problem for an infinite dataset. We suspect that this is probably the reason why CIFAR-10, as a finite dataset, does not show behavior similar to GAUS.
In this paper we propose a new theoretical framework that uses a teacher-student setting to understand the training dynamics of multi-layered ReLU networks. With this framework, we are able to conceptually explain many puzzling phenomena in deep networks, such as why over-parameterization helps generalization, why the same network can fit both random and structured data, and why lottery tickets lottery ; lottery-scale exist. We back up these intuitive explanations with Theorem 4 and Theorem 5, which collectively show that student nodes that are initialized to be close to the teacher nodes converge to them at a faster rate, and the fan-out weights of irrelevant nodes converge to zero. As the next steps, we aim to extend Theorem 5 to the general multi-layer setting (when both and are present), relax Assumption 3 and study more BatchNorm effects than what Theorem 3 suggests.
The first author thanks Simon Du, Jason Lee, Chiyuan Zhang, Rong Ge, Greg Yang, Jonathan Frankle and many others for the informal discussions.
International conference on machine learning, pages 173–182, 2016.
Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
Imagenet classification with deep convolutional neural networks. In Advances in neural information processing systems, pages 1097–1105, 2012.
The first part of gradient backpropagated to node is:
Therefore, for the gradient to node , we have:
And similarly for . Therefore, by mathematical induction, we know that the gradients at nodes in all layers follow the same form. ∎
Using Thm. 1, we can write down weight update for weight that connects node and node :
Given a batch with size , denote pre-batchnorm activations as and gradients as (see Fig. 14(a)). is its whitened version, and is the final output of BN. Here and and , are learnable parameters. With vector notation, the gradient update in BN has a compact form with clear geometric meaning:
For a top-down gradient , BN layer gives the following gradient update ( is the orthogonal complementary projection of subspace ):
Intuitively, the back-propagated gradient is zero-mean and perpendicular to the input activation of the BN layer, as illustrated in Fig. 14. Unlike [16, 38], which analyze BN in an approximate manner, in Thm. 1 we do not impose any assumptions.
For simplicity, in the following, we use .
We have for :
If Assumption 3 also holds, we have:
The proof is similar to Lemma 2. ∎
For node , we have:
The intuition here is that both the volume of the affected area and the weight difference are proportional to . is their product and thus proportional to . See Fig. 16. ∎
First of all, note that . So given , we also have a bound for .
When , the matrix form can be written as the following:
by using (and thus doesn’t matter). Since is conserved, it suffices to check whether the projected weight vector of onto the complementary space of the ground truth node , goes to zero: