The Loss Surface of Residual Networks: Ensembles and the Role of Batch Normalization

11/08/2016 ∙ by Etai Littwin, et al.

Deep Residual Networks offer a performance premium over conventional networks of the same depth and are trainable at extreme depths. It has recently been shown that Residual Networks behave like ensembles of relatively shallow networks. We show that these ensembles are dynamic: while initially the virtual ensemble is mostly at depths lower than half the network's depth, as training progresses, it becomes deeper and deeper. The main mechanism that controls the dynamic ensemble behavior is the scaling introduced, e.g., by the Batch Normalization technique. We explain this behavior and demonstrate the driving force behind it. As a main tool in our analysis, we employ generalized spin glass models, which we also use to study the number of critical points in the optimization of Residual Networks.

1 Introduction

Residual Networks (ResNets) (He et al., 2015) are neural networks with skip connections. These networks, which are a specific case of Highway Networks (Srivastava et al., 2015), achieve state-of-the-art results in the most competitive computer vision tasks, including image classification and object detection.

The success of residual networks was attributed to the ability to train very deep networks when employing skip connections (He et al., 2016). A complementary view is presented by Veit et al. (2016), who attribute it to the power of ensembles and present an unraveled view that depicts ResNets as an ensemble of networks that share weights, with a binomial depth distribution around half depth. They also present experimental evidence that paths shorter than half the depth dominate the ResNet gradient during training.

The analysis presented here shows that ResNets are ensembles with a dynamic depth behavior. When starting the training process, the ensemble is dominated by shallow networks, with depths lower than half-depth. As training progresses, the effective depth of the ensemble increases. This increase in depth allows the ResNet to increase its effective capacity as the network becomes more and more accurate.

Our analysis reveals the mechanism for this dynamic behavior and explains the driving force behind it. Remarkably, this mechanism takes place within the parameters of Batch Normalization (Ioffe & Szegedy, 2015), which is usually regarded as a normalization and fine-grained whitening mechanism that addresses the problem of internal covariate shift and allows for higher learning rates.

We show that the scaling introduced by batch normalization determines the depth distribution in the virtual ensemble of the ResNet. These scales grow dynamically as training progresses, shifting the effective ensemble distribution toward greater depths.

The main tool we employ in our analysis is spin glass models. Choromanska et al. (2015) established a link between conventional networks and such models, leading to a comprehensive study of the critical points of neural networks based on the spin glass analysis of Auffinger et al. (2013). In our work, we generalize these results and link ResNets to generalized spin glass models. These models allow us to analyze the dynamic behavior presented above. Finally, we apply the results of Auffinger & Arous (2013) to study the loss surface of ResNets.

2 A recap of Choromanska et al. (2015)

We briefly summarize Choromanska et al. (2015), which connects the loss function of multilayer networks with the hamiltonian of the p-spherical spin glass model, and state their main contributions and results. The notation of our paper is summarized in Appendix A and differs slightly from that of Choromanska et al. (2015).

A simple feed-forward, fully connected network with a single output unit is considered. The number of units is fixed per layer: the width of the input layer is the input dimension d, and the output layer has a single unit. It is further assumed that ReLU activation functions are used. The output of the network given an input vector x can be expressed as

(1)

where the first summation is over the network inputs, and the second is over all paths from input to output. For each input, the number of such paths is the product of the widths of the subsequent layers, and all paths that start at the same input carry the same input value. The activation variable of a path denotes whether the path is active, i.e., whether all of the ReLU units along this path produce positive activations, and the product of weights along the path represents the specific weight configuration multiplying the input for that path. It is assumed throughout the paper that the input variables are sampled i.i.d. from a normal Gaussian distribution.
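The path expansion in Eq. 1 can be checked numerically. The following Python sketch is illustrative only. The widths, the random weights, and the linear (non-ReLU) output unit are our own assumptions. The sketch compares an ordinary forward pass through a two-hidden-layer ReLU network with a path sum. For every (input, hidden unit, hidden unit) path, the sum takes the input times the path's weight product, gated by whether both ReLUs on the path are active.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n1, n2 = 4, 5, 3                     # hypothetical widths: input, hidden 1, hidden 2
x = rng.standard_normal(d)              # inputs drawn i.i.d. from a standard normal
W1 = rng.standard_normal((d, n1))
W2 = rng.standard_normal((n1, n2))
w3 = rng.standard_normal(n2)            # weights into the single (linear) output unit

# Ordinary forward pass.
z1 = x @ W1
h1 = np.maximum(z1, 0.0)
z2 = h1 @ W2
h2 = np.maximum(z2, 0.0)
y_forward = h2 @ w3

# Path expansion: a path is (input i, unit a, unit b). Its gate is 1 only if both ReLUs
# on the path are active, and it contributes x_i * W1[i, a] * W2[a, b] * w3[b].
g1 = (z1 > 0).astype(float)
g2 = (z2 > 0).astype(float)
y_paths = np.einsum("i,ia,a,ab,b,b->", x, W1, g1, W2, g2, w3)

print(d * n1 * n2, "paths; forward =", y_forward, "; path sum =", y_paths)
```

The gates are all computed from the same input x, so the activation variables of different paths are strongly dependent. This is one reason why modeling them as independent variables is only an approximation.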

Definition 1.

The mass of the network is defined as .

The path activation variables are modeled as independent Bernoulli random variables with a common success probability, i.e., each path is equally likely to be active. Therefore,

(2)
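The Bernoulli modeling step can be sketched with a small Monte Carlo experiment in Python. Everything here is hypothetical: the sizes, the success probability 0.5, and the random numbers standing in for the weight products along each path. Each path's gate is drawn independently with the same success probability. The empirical mean of the output is then compared with the success probability times the full path sum, which is the value that linearity of expectation gives.

```python
import numpy as np

rng = np.random.default_rng(1)
d, gamma, rho = 3, 8, 0.5                        # hypothetical: inputs, paths per input, success probability
x = rng.standard_normal(d)
path_weights = rng.standard_normal((d, gamma))   # stand-ins for the weight product along each path

# Each path's gate is an independent Bernoulli(rho) draw; Y is recomputed for every draw.
n_draws = 100_000
A = (rng.random((n_draws, d, gamma)) < rho).astype(float)
Y = np.einsum("i,ij,sij->s", x, path_weights, A)

# Linearity of expectation: E_A[Y] = rho * sum_ij x_i * path_weights[i, j].
expected = rho * float(np.sum(x[:, None] * path_weights))
print(round(float(Y.mean()), 3), round(expected, 3))
```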

The task of binary classification using the network is considered, with either the hinge loss or the absolute loss:

(3)

where the target is a random variable corresponding to the true label of sample x. To equate either loss with the hamiltonian of the p-spherical spin glass model, a few key approximations are made:


A1. Variable independence: the inputs are modeled as independent normal Gaussian random variables.

A2. Redundancy in network parameterization: it is assumed that the set of all the network weights contains only a limited number of unique weights, fewer than the total number of weights.

A3. Uniformity: it is assumed that all unique weights are close to being evenly distributed on the graph of connections defining the network. In practice, this means that we assume every node is adjacent to an edge with any one of the unique weights.

A4. Spherical constraint: the following is assumed:

(4)

for some constant C.

These assumptions are made for the sake of analysis and do not hold in practice. For example, A1 does not hold, since each input is associated with many different paths, all of which carry the same input value. See Choromanska et al. (2015) for further justification of these approximations.

Under A1-A4, the loss takes the form of a centered Gaussian process on a sphere. Specifically, it is shown to resemble the hamiltonian of the spherical p-spin glass model, given by:

(5)

with spherical constraint

(6)

where the couplings are independent normal Gaussian variables.

In Auffinger et al. (2013), the asymptotic complexity of the spherical p-spin glass model is analyzed using random matrix theory. In Choromanska et al. (2015), these results are used to shed light on the optimization process of neural networks. For example, the asymptotic complexity of spherical spin glasses reveals a layered structure of low-index critical points near the global optimum. These findings are then offered as a possible explanation for several central phenomena in neural network optimization, such as the similar performance of large nets and the improbability of getting stuck in a “bad” local minimum.

As part of our work, we follow a similar path. First, a link is formed between residual networks and the general multi-interaction spherical spin glass model. Then, using Auffinger & Arous (2013), we obtain insights into residual networks. The other part of our work studies the dynamic behavior of neural networks using the same spin glass models.

3 Residual nets and general spin glass models

We begin by establishing a connection between the loss function of deep residual networks and the hamiltonian of the general spherical spin glass model. We consider a simple feed-forward, fully connected network with ReLU activation functions and residual connections. For simplicity of notation, and without loss of generality, we assume that every layer after the input has the same number of units n. In our ResNet model, there are identity connections that each skip a single layer, starting from the first hidden layer. The output of each layer is given by:

(7)

where the weight matrix connects the previous layer with the current one. Notice that the first hidden layer has no parallel skip connection, so its output consists of the ReLU term alone. Without loss of generality, the scalar output of the network is the sum of the outputs of the output layer and is expressed as

(8)

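where the indicator variable denotes whether a given path of a given length is open. The residual connections imply that the output is now a sum of products of different lengths, indexed by the path length. The first hidden layer has no skip connection, so each path of length r includes r - 1 non-skip connections (those involving the first term in Eq. 7 and not the second, identity term) out of the layers that follow the first hidden layer. Therefore, the number of length-r paths is proportional to the binomial coefficient that counts these choices. We define the following measure on the network: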

Definition 2.

The mass of a depth subnetwork in is defined as .
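The path count above can be checked with a forward pass. The Python sketch below makes the following assumptions, which match Eqs. 7-8:

- A ResNet with p weight layers and n units in every hidden layer.
- An input of dimension d.
- No skip connection around the first hidden layer.
- A scalar output formed by summing the last layer.

Every input is set to 1 and every weight to a formal variable t. All activations are then non-negative, so the ReLUs act as the identity. The output is a polynomial in t, and its coefficient of t^r counts the input-to-output paths of length r. Under these assumptions, that coefficient should equal d · C(p - 1, r - 1) · n^r. The function names are hypothetical.

```python
from math import comb

def polymul(a, b):
    """Product of two integer polynomials given as coefficient lists (lowest degree first)."""
    out = [0] * (len(a) + len(b) - 1)
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            out[i + j] += ai * bj
    return out

def path_counts_by_forward_pass(d, n, p):
    """Coefficients of the output polynomial in t; entry r counts paths of length r.

    All inputs are 1 and all weights equal t, so every unit in a layer holds the
    same polynomial and one list per layer suffices.
    """
    y = [0, d]                    # first hidden layer (no skip): each unit sums d inputs times t
    for _ in range(p - 1):
        y = polymul(y, [1, n])    # residual layer: n * t * y from the weights, plus y from the identity
    return [n * c for c in y]     # scalar output: sum over the n units of the last layer

def path_counts_closed_form(d, n, p):
    """d * C(p - 1, r - 1) * n**r for r = 1..p; entry 0 (no weights) is 0."""
    return [0] + [d * comb(p - 1, r - 1) * n ** r for r in range(1, p + 1)]

print(path_counts_by_forward_pass(2, 4, 3))   # [0, 8, 64, 128]
```

Summing the coefficients gives d · n · (1 + n)^(p - 1) paths in total. The binomial factor produces the ensemble-like depth distribution discussed by Veit et al. (2016).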

The redundancy of the network parameters and their uniform distribution, as described in Sec. 2, allow us to re-index Eq. 8.

Lemma 1.

Assuming assumptions hold, and , then the output can be expressed after reindexing as:

(9)

All proofs can be found in Appendix B.

Making the modeling assumption that the ReLU gates are independent Bernoulli random variables with probability , we obtain that for every path of length , and

(10)

In order to connect ResNets to generalized spherical spin glass models, we denote the variables:

(11)

Note that since the input variables are sampled from a centered Gaussian distribution (dependent or not), these new variables are dependent normal Gaussian variables.

Lemma 2.

Assuming hold, and then the following holds:

(12)

We approximate the expected output with by assuming the minimal value in 12 holds such that . The following expression for is thus obtained:

(13)

The independence assumption A1 has not been used yet, and Eq. 13 holds regardless. Assuming A4 and denoting the scaled weights , we can link the distribution of to the distribution on :

(14)

where and is a normalization factor such that .

The following lemma gives a generalized expression for the absolute and hinge losses of the network.

Lemma 3 (Choromanska et al., 2015).

Assuming assumptions hold, then both the losses and can be generalized to a distribution of the form:

(15)

where are positive constants that do not affect the optimization process, and will be omitted in the following sections.

The model in Eq. 15 has the form of a spin glass model, except for the dependency between the variables. We later use an independence assumption similar to A1 for these variables to link the two binary classification losses to the general spherical spin glass model. However, the results in this section do not require it.

We denote the important quantities:

(16)

The series determines the weight of interactions of each specific length in the loss surface. Notice that for constant depth and large enough , . Therefore, for wide networks, where and, therefore, are large, interactions of order dominate the loss surface, and the effect of the residual connections diminishes. Conversely, for constant and a large enough (deep networks), we have that , and we can expect interactions of order to dominate the loss. The asymptotic behavior of is captured by the following theorem:

Theorem 1.

Assuming , we have that:

(17)

As the next theorem shows, the epsilons are concentrated in a narrow band near the maximal value.

Theorem 2.

For any , and assuming , it holds that:

(18)

Thm. 2 implies that for deep residual networks, the contribution of weight products of orders far from the maximum is negligible. The loss is, therefore, similar in complexity to that of an ensemble of potentially shallow conventional nets. A common weight initialization scheme for neural networks scales the weights inversely with the square root of the layer width (Orr & Müller, 2003; Glorot & Bengio, 2010). With this initialization, the maximal weight is obtained at less than half the network's depth. Therefore, at initialization, the loss function is primarily influenced by interactions of considerably lower order than the depth, which facilitates optimization.
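The behavior described by Thms. 1 and 2 can be explored numerically. Eq. 16 is not reproduced here, so the Python sketch below makes a simplifying assumption: the interaction weights take the form ε_r ∝ C(p, r) β^r for r = 1, ..., p, normalized so that Σ_r ε_r² = 1. This mirrors the appendix proofs, which ignore constants in the binomial coefficient and express the normalization through a Legendre polynomial. Here β stands for the global scale parameter. The sketch reports the dominant order. It also reports the squared weight lying outside a band of relative width δ around p·β/(1 + β).

```python
import math

def squared_weights(p, beta):
    """Assumed eps_r**2 for r = 1..p, with eps_r proportional to C(p, r) * beta**r
    and normalized so that sum_r eps_r**2 == 1 (computed in log space)."""
    logs = [2 * (math.lgamma(p + 1) - math.lgamma(r + 1) - math.lgamma(p - r + 1)
                 + r * math.log(beta))
            for r in range(1, p + 1)]
    top = max(logs)
    unnormalized = [math.exp(v - top) for v in logs]
    total = sum(unnormalized)
    return [v / total for v in unnormalized]

def mode_depth(p, beta):
    """Dominant interaction order r (the index of the largest weight)."""
    w = squared_weights(p, beta)
    return 1 + max(range(p), key=lambda i: w[i])

def mass_outside_band(p, beta, delta):
    """Squared weight carried by orders r with |r / p - beta / (1 + beta)| > delta."""
    center = beta / (1 + beta)
    w = squared_weights(p, beta)
    return sum(wi for r, wi in enumerate(w, start=1) if abs(r / p - center) > delta)

for p in (100, 1000):
    print(p, mode_depth(p, 0.5), mass_outside_band(p, 0.5, 0.1))
```

For example, with p = 100 and β = 0.5, the dominant order is 33, which is below half the depth. The weight outside the band shrinks as p grows from 100 to 1000.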

4 Dynamic behavior of residual nets

The expression for the output of a residual net in Eq. 14 provides valuable insights into the machinery at work when optimizing such models. Thm. 1 and 2 imply that the loss surface resembles that of an ensemble of shallow nets (although not a real ensemble due to obvious dependencies), with various depths concentrated in a narrow band. As noticed in Veit et al. (2016), viewing ResNets as ensembles of relatively shallow networks helps in explaining some of the apparent advantages of these models, particularly the apparent ease of optimization of extremely deep models, since deep paths barely affect the overall loss of the network. However, this alone does not explain the increase in accuracy of deep residual nets over actual ensembles of standard networks. In order to explain the improved performance of ResNets, we make the following claims:

  1. The mixture vector determines the distribution of the depths of the networks within the ensemble, and is controlled by the scaling parameter .

  2. During training, the scaling parameter changes and causes a shift of focus from a shallow ensemble to deeper and deeper ensembles, which leads to additional capacity.

  3. In networks that employ batch normalization, the scaling parameter is directly embodied as the batch normalization scale parameter. Its common initialization offers a good starting condition that involves extremely shallow nets.

The next lemma validates item 1 from this list of claims. It shows that we can shift the effective depth to any value by simply controlling the scaling parameter.

Lemma 4.

For any integer there exists a global scaling parameter such that .

A simple global scaling of the weights is, therefore, enough to change the loss surface from an ensemble of shallow conventional nets to an ensemble of deep nets. This is illustrated in Fig. 1(a-c) for various values of the scaling parameter.
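Lemma 4 only asserts existence. Under an additional simplifying assumption, one constructive choice is available in closed form: we assume weights proportional to C(p, r) β^r, which is our assumption rather than the exact Eq. 16. The weight sequence is then proportional to a binomial distribution with success probability q = β/(1 + β). Its mode is ⌊(p + 1)q⌋ whenever (p + 1)q is not an integer. Choosing q = (k + 1/2)/(p + 1) therefore places the mode at k. A Python sketch:

```python
from math import comb

def mode_depth(p, beta):
    """Interaction order r in 1..p with the largest assumed weight C(p, r) * beta**r."""
    return max(range(1, p + 1), key=lambda r: comb(p, r) * beta ** r)

def beta_for_depth(p, k):
    """A global scale beta whose dominant interaction order is k, for 1 <= k <= p.

    With q = beta / (1 + beta), the weights are proportional to a Binomial(p, q) pmf,
    whose mode is floor((p + 1) * q) when (p + 1) * q is not an integer.
    Setting q = (k + 0.5) / (p + 1) makes (p + 1) * q = k + 0.5.
    """
    q = (k + 0.5) / (p + 1)
    return q / (1 - q)

p = 20
for k in (1, 5, 10, 15, 20):
    beta = beta_for_depth(p, k)
    print(k, round(beta, 3), mode_depth(p, beta))
```

The chosen β grows monotonically with k. This matches the statement that a larger global scale moves the ensemble toward deeper paths.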

In order to gain additional insight into this dynamic mechanism, we investigate the derivative of the loss with respect to the scale parameter . By noticing that , and using Eq. 15 we obtain:

(19)

Notice that the additional multiplier, which grows with the interaction order, indicates that the derivative is increasingly influenced by deeper networks.
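The effect of the extra multiplier can be illustrated with a toy loss in Python. For illustration only, we assume that the loss is Σ_r C(p, r) β^r h_r and ignore the normalization of ε_r. The h_r are hypothetical random numbers standing in for the order-r terms at fixed weights. Then d(β^r)/dβ = (r/β)β^r, and a central finite difference confirms the analytic derivative. The sketch also computes the average interaction order under two weightings and compares them. The first uses the prefactor magnitudes in the loss, C(p, r) β^r. The second uses the prefactors in the derivative, r C(p, r) β^r.

```python
import numpy as np
from math import comb

rng = np.random.default_rng(0)
p, beta = 12, 0.7
r = np.arange(1, p + 1)
weights = np.array([comb(p, k) for k in r], dtype=float)  # assumed C(p, r) prefactors
h = rng.standard_normal(p)   # hypothetical values of the order-r terms at fixed weights

def loss(b):
    """Toy loss sum_r C(p, r) * b**r * h_r (normalization of eps_r ignored)."""
    return float(np.sum(weights * b ** r * h))

def dloss_dbeta(b):
    """Analytic derivative: d(b**r)/db = (r / b) * b**r, so order r gains a factor r."""
    return float(np.sum((r / b) * weights * b ** r * h))

step = 1e-6
finite_diff = (loss(beta + step) - loss(beta - step)) / (2 * step)
analytic = dloss_dbeta(beta)

# Average interaction order, weighted as in the loss versus as in its derivative.
mag = weights * beta ** r
mean_order_loss = float(np.sum(r * mag) / np.sum(mag))
mean_order_grad = float(np.sum(r * r * mag) / np.sum(r * mag))
print(finite_diff, analytic, mean_order_loss, mean_order_grad)
```

The derivative-weighted average order is always at least as large as the loss-weighted one, since the derivative reweights each order by r.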

4.1 Batch normalization

Batch normalization has been shown to be a crucial factor in the successful training of deep residual networks. As we will show, batch normalization layers offer an easy starting condition for the network, such that gradients early in the training process originate from extremely shallow paths.

We consider a simple batch normalization procedure that ignores the additive terms: the output of each ReLU unit in a layer is normalized by a factor and then multiplied by a scale parameter. The output of the layer is therefore:

(20)

where the normalization factor is the mean of the estimated standard deviations of the elements of the vector being normalized. Furthermore, a typical initialization of the batch normalization parameters sets every scale parameter to 1. In this case, provided that units in the same layer have equal variance, a recursive relation between the variances of consecutive layers holds for any unit in the layer. This, in turn, implies that the output of the ReLU units has increasing variance as a function of depth. Multiplying the weight parameters in deep layers by an increasingly small scaling factor effectively reduces the influence of deeper paths, so that extremely short paths dominate the early stages of optimization. We next analyze how the weight scaling, as introduced by batch normalization, provides a driving force for the effective ensemble to become deeper as training progresses.

Figure 1: (a) A histogram of , , for and . (b) Same for . (c) Same for . (d) Values (y-axis) of the batch normalization parameters (x-axis) for a 10-layer ResNet trained to discriminate between 50 multivariate Gaussians; higher plot lines indicate later stages of training. (e) The norm of the weights of a residual network that does not employ batch normalization, as a function of the iteration. (f) The asymptotic mean number of critical points of a finite index as a function of .

4.2 The driving force behind the scale increase

In the following analysis, we examine the mechanics of a simple example, which can be extrapolated to more general architectures.

We consider a simple network of depth , with a single residual connection skipping layers. We further assume that batch normalization is applied at the output of each ReLU unit as described in Eq. 20. We denote by the indices of layers that are not skipped by the residual connection, and , . Since every path of length is multiplied by , and every path of length is multiplied by , the expression for the loss can be written:

(21)

We denote by the derivative operator with respect to the parameters , and the gradient evaluated at point .

Theorem 3.

Considering the loss in 21, and assuming , then for a small learning rate the following hold:

  1. For any , we have:

    (22)
  2. Assuming , for any we have:

    (23)

Thm. 3 suggests that the batch normalization scale parameter will increase for layers that do not have skip connections. Conversely, if a layer has a parallel skip connection, its scale parameter will increase when the gradient from deeper paths becomes dominant as shallow paths reach a local minimum. Notice that an increase in results in an increase in , while remains unchanged, therefore shifting the balance toward deeper ensembles.

This steady increase of the scale parameters, as predicted by our theoretical analysis, is also supported by experimental results, as depicted in Fig. 1(d). Note that the first layer, which cannot be skipped, behaves differently from the other layers.

It is worth noting that the mechanism for this dynamic property of residual networks can also be observed without batch normalization, as a steady increase in the norm of the weights, as shown in Fig. 1(e). To model this, consider the residual network discussed above, without batch normalization layers. Recalling the scaled weights defined in Sec. 3, the loss of this network is expressed as:

(24)
Theorem 4.

Considering the loss in 24, and assuming , then for a small learning rate the following hold:

(25)

Thm. 4 indicates that when deeper gradients become dominant (for example, near local minima of the shallow network), the scaling of the weights will increase. This expansion will, in turn, emphasize the contribution of deeper paths and increase the overall capacity of the residual network.

5 The loss surface of Ensembles

We now present the results of Auffinger & Arous (2013) regarding the asymptotic complexity of the multi-spherical spin glass model, given by:

(26)

where are independent centered standard Gaussian variables, and are positive real numbers such that . A configuration of the spin spherical spin-glass model is a vector in satisfying the spherical constraint:

(27)

Note that the variance of the process is independent of :

(28)
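The direction-independence of the variance can be verified for a small mixed model. The Python sketch below rests on three assumptions:

- Each pure term has the standard form Λ^(-(r-1)/2) Σ J_{i1...ir} w_{i1}···w_{ir}, with i.i.d. standard Gaussian couplings.
- The pure terms are independent.
- The mixture coefficients satisfy Σ_r β_r² = 1. The two-term mixture used here is hypothetical.

On the sphere Σ_i w_i² = Λ, the exact variance Σ_r β_r² Λ^(-(r-1)) (Σ_i w_i²)^r equals Λ for every configuration. A Monte Carlo estimate over random couplings should agree.

```python
from functools import reduce

import numpy as np

def monomials(w, r):
    """Tensor of all degree-r products w[i1] * ... * w[ir]."""
    return reduce(np.multiply.outer, [w] * r)

def exact_variance(w, betas):
    """Var of sum_r beta_r H_r(w): sum_r beta_r**2 * L**(-(r-1)) * sum of squared monomials."""
    L = w.size
    return sum(b ** 2 * L ** (-(r - 1)) * float(np.sum(monomials(w, r) ** 2))
               for r, b in betas.items())

def sample_hamiltonian(w, betas, n_samples, rng):
    """Draw the mixed hamiltonian at w for n_samples independent sets of N(0, 1) couplings."""
    L = w.size
    total = np.zeros(n_samples)
    for r, b in betas.items():
        J = rng.standard_normal((n_samples,) + (L,) * r)
        total += b * L ** (-(r - 1) / 2) * np.tensordot(J, monomials(w, r), axes=r)
    return total

rng = np.random.default_rng(0)
L = 5
betas = {2: np.sqrt(0.6), 3: np.sqrt(0.4)}   # hypothetical mixture, sum of beta_r**2 == 1
results = []
for _ in range(3):
    v = rng.standard_normal(L)
    w = v * np.sqrt(L) / np.linalg.norm(v)     # spherical constraint: sum(w**2) == L
    mc = float(sample_hamiltonian(w, betas, 4000, rng).var())
    results.append((exact_variance(w, betas), mc))
    print(results[-1])
```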
Definition 3.

We define the following:

(29)

Note that for the single interaction spherical spin model . The index of a critical point of the hamiltonian is defined as the number of negative eigenvalues of the hessian evaluated at the critical point.
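The index of a critical point can be computed directly from the eigenvalues of the hessian. The Python sketch below uses a central-difference hessian and a hypothetical quadratic test function. For brevity, it works in flat Euclidean space, whereas for the spherical models the hessian is taken along the sphere.

```python
import numpy as np

def numerical_hessian(f, w, h=1e-4):
    """Central-difference estimate of the Hessian of a scalar function f at w."""
    n = w.size
    hess = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            ei = np.zeros(n)
            ei[i] = h
            ej = np.zeros(n)
            ej[j] = h
            hess[i, j] = (f(w + ei + ej) - f(w + ei - ej)
                          - f(w - ei + ej) + f(w - ei - ej)) / (4 * h * h)
    return 0.5 * (hess + hess.T)          # symmetrize away rounding differences

def critical_point_index(f, w):
    """Index of a critical point: the number of negative Hessian eigenvalues."""
    return int(np.sum(np.linalg.eigvalsh(numerical_hessian(f, w)) < 0))

def f(w):
    """Hypothetical quadratic with a critical point (a saddle) at the origin."""
    return w[0] ** 2 - w[1] ** 2 + 3 * w[0] * w[2] - 2 * w[2] ** 2

print(critical_point_index(f, np.zeros(3)))   # 2
```

For this test function, the hessian at the origin has eigenvalues -2 and -1 ± 3√2. Two of these are negative, so the origin is a critical point of index 2. By contrast, the bowl Σ w_i² has index 0 at the origin.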

Definition 4.

For any and , we denote the random number as the number of critical points of the hamiltonian in the set with index . That is:

(30)

Furthermore, define . Corollary 1.1 of Auffinger & Arous (2013) states that for any :

(31)

Eq. 31 provides the asymptotic mean total number of critical points with non-diverging index. It is presumed that the SGD algorithm will easily avoid critical points with a high index, which have many descent directions, and maneuver towards low-index critical points. We, therefore, investigate how the mean total number of low-index critical points varies as the ensemble distribution, embodied in the mixture coefficients, changes its shape through a steady increase of the scaling parameter.

Fig. 1(f) shows that as the ensemble progresses toward deeper networks, the mean number of low-index critical points increases, which might cause the SGD optimizer to get stuck in local minima. This is, however, resolved by the fact that by the time the ensemble becomes deep enough, the loss function has already reached a point of low energy, as shallower ensembles were more dominant earlier in training. In the following theorem, we assume a finite ensemble such that .

Theorem 5.

For any , we denote the solution to the following constrained optimization problems:

(32)

It holds that:

(33)

Theorem 5 implies that any heterogeneous mixture of spin glasses contains fewer critical points of a finite index than a mixture in which only interactions of the highest order are considered. Therefore, for any distribution of the mixture coefficients that is attainable during the training of a ResNet of a given depth, the number of critical points is lower than the number of critical points for a conventional network of the same depth.

6 Conclusion

Ensembles are a powerful model for ResNets that helps unravel some of the key questions that have surrounded ResNets since their introduction. Here, we show that ResNets display a dynamic ensemble behavior, which explains the ease of training such networks even at very large depths, while still maintaining the advantage of depth. As far as we know, the dynamic behavior of the effective capacity is unlike anything documented in the deep learning literature. Surprisingly, the dynamic mechanism typically takes place within the outer multiplicative factor of the batch normalization module.

References

Appendix A Summary of notations

Table 1 presents the various symbols used throughout this work and their meaning.

SYMBOL DESCRIPTION
x Input vector, sampled from a normal distribution

d The dimensionality of the input
The output of layer of network given input
The final output of the network
True label of input
Loss function of network
Hinge loss
Absolute loss
The depth of network
Weights of the network
C A positive scale factor such that
Scaled weights such that
n The number of units in layers
The number of unique weights in the network
The total number of weights in the network
The weight matrix connecting layer to layer in .
The hamiltonian of the interaction spherical spin glass model.
The hamiltonian of the general spherical spin glass model.
Total number of paths from input to output in network
Total number of paths from input to output in network of length
ReLU activation function
Bernoulli random variable associated with the ReLU activation function, indexed by .
Parameter of the Bernoulli distribution associated with the ReLU unit
multiplier associated with paths of length in .
.
Normalization factor.
Batch normalization multiplicative factor in layer .
The mean of the estimated standard deviations of the elements of the vector being normalized.
Table 1: Notations

Appendix B Proofs

Proof of Lemma 1.

There are a total of paths of length from input to output, and a total of unique length configurations of weights. The uniformity assumption then implies that each configuration of weights is repeated times. By summing over the unique configurations and re-indexing the input, we arrive at Eq. 9. ∎

Proof of Lemma 2.

From 11, we have that for each there exists a sequence such that , and . We, therefore, have that . Note that the minimum value of is a solution to the following:

(34)

which achieves its minimal value at . Similarly, the maximum value is achieved at for some index . ∎

Proof of Thm. 1.

We use the Stirling approximation, which states , where . Ignoring the constants which do not depend on ,

(35)

which achieves its maximum value at . ∎

Proof of Thm. 2.

For brevity, we provide a sketch of the proof. It is enough to show that for . Ignoring the constants in the binomial terms, we have:

(36)

where , which can be expressed using the Legendre polynomial of order :

(37)

To compute the limit of Eq. 36, we use the asymptotic form of the Legendre polynomial of order for , . For the term in the numerator of Eq. 36, we use the Stirling approximation for factorials. Substituting both approximations in Eq. 36 and taking the limit completes the proof. ∎
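The Legendre representation used in this proof can be checked numerically. We assume here that the normalizing sum has the form Σ_r C(p, r)² β^{2r}, which is consistent with the simplification ε_r ∝ C(p, r)β^r. This is our reading of the proof rather than a quoted formula. Summed from r = 0, this sum satisfies the classical identity Σ_r C(p, r)² x^r = (1 − x)^p P_p((1 + x)/(1 − x)) with x = β², valid for x ≠ 1. The sketch compares both sides using NumPy's Legendre evaluation.

```python
from math import comb

import numpy as np
from numpy.polynomial import legendre

def squared_binomial_sum(p, x):
    """sum_{r=0}^{p} C(p, r)**2 * x**r."""
    return sum(comb(p, r) ** 2 * x ** r for r in range(p + 1))

def legendre_form(p, x):
    """(1 - x)**p * P_p((1 + x) / (1 - x)), valid for x != 1."""
    coeffs = np.zeros(p + 1)
    coeffs[p] = 1.0                        # coefficient vector selecting P_p alone
    return (1 - x) ** p * float(legendre.legval((1 + x) / (1 - x), coeffs))

for p in (1, 2, 5, 10):
    beta = 0.8
    print(p, squared_binomial_sum(p, beta ** 2), legendre_form(p, beta ** 2))
```

The r = 0 term equals 1, so a sum starting from r = 1 differs only by that constant.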

Proof of Lemma 4.

For simplicity, we ignore the constants in the binomial coefficient, and assume . Notice that for , we have that , and . From the monotonicity and continuity of , any value can be attained. The linear dependency completes the proof. ∎

Proof of Thm. 3.

1. Notice that by definition, the layer is not skipped by the residual connection, and therefore its scale parameter multiplies every path in the network. Therefore, . Using a Taylor series expansion:

(38)

Substituting in 38 we have:

(39)

And hence:

(40)

Finally:

(41)

2. Since paths of length skip layer , we have that . Therefore: