
Deep Learning without Shortcuts: Shaping the Kernel with Tailored Rectifiers

Training very deep neural networks is still an extremely challenging task. The common solution is to use shortcut connections and normalization layers, which are both crucial ingredients in the popular ResNet architecture. However, there is strong evidence to suggest that ResNets behave more like ensembles of shallower networks than truly deep ones. Recently, it was shown that deep vanilla networks (i.e. networks without normalization layers or shortcut connections) can be trained as fast as ResNets by applying certain transformations to their activation functions. However, this method (called Deep Kernel Shaping) isn't fully compatible with ReLUs, and produces networks that overfit significantly more than ResNets on ImageNet. In this work, we rectify this situation by developing a new type of transformation that is fully compatible with a variant of ReLUs – Leaky ReLUs. We show in experiments that our method, which introduces negligible extra computational cost, achieves validation accuracies with deep vanilla networks that are competitive with ResNets (of the same width/depth), and significantly higher than those obtained with the Edge of Chaos (EOC) method. And unlike with EOC, the validation accuracies we obtain do not get worse with depth.

1 Introduction

Figure 1: Top-1 ImageNet validation accuracy of vanilla deep networks initialized using either EOC (with ReLU) or TAT (with LReLU) and trained with K-FAC.

Thanks to many architectural and algorithmic innovations, the past decade has witnessed the unprecedented success of deep learning in various high-profile challenges, e.g., the ImageNet recognition task (Krizhevsky et al., 2012), the challenging board game of Go (Silver et al., 2017), and human-like text generation (Brown et al., 2020). Among these innovations, shortcut connections (He et al., 2016a; Srivastava et al., 2015) and normalization layers (Ioffe and Szegedy, 2015; Ba et al., 2016) are two architectural components of modern networks that are critically important for achieving fast training at very high depths, and feature prominently in the ubiquitous ResNet architecture of He et al. (2016b).

Despite the success of ResNets, there is significant evidence to suggest that the primary reason they work so well is that they resemble ensembles of shallower networks during training (Veit et al., 2016), which lets them avoid the common pathologies associated with very deep networks (e.g. Hochreiter et al., 2001; Duvenaud et al., 2014). Moreover, ResNets without normalization layers could lose expressivity as the depth goes to infinity (Hayou et al., 2021). In this sense, the question of whether truly deep networks can be efficiently and effectively trained on challenging tasks remains an open one.

As argued by Oyedotun et al. (2020) and Ding et al. (2021), the multi-branch topology of ResNets also has certain drawbacks. For example, it is memory-inefficient at inference time, as the input to every residual block has to be kept in memory until the final addition. In particular, the shortcut branches in ResNet-50 account for about 40% of the memory usage by feature maps. Also, the classical interpretation of why deep networks perform well – because of the hierarchical feature representations they produce – does not strictly apply to ResNets, due to their aforementioned tendency to behave like ensembles of shallower networks. Beyond the drawbacks of ResNets, training vanilla deep neural networks (which we define as networks without shortcut connections or normalization layers) is an interesting research problem in its own right, and finding a solution could open the path to discovering new model architectures. However, recent progress in this direction has not fully succeeded in matching the generalization performance of ResNets.

Schoenholz et al. (2017) used a mean-field analysis of deep MLPs to choose variances for the initial weights and bias parameters, and showed that the resulting method – called Edge of Chaos (EOC) – allowed vanilla networks to be trained at very high depths on small datasets. Building on EOC, and incorporating dynamical isometry theory, Xiao et al. (2018) were able to train vanilla networks with Tanh units at depths of up to 10,000 (dynamical isometry is unavailable for ReLU (Pennington et al., 2017), even with orthogonal weights). While impressive, these EOC-initialized networks trained significantly slower than standard ResNets of the same depth, and also exhibited significantly worse generalization performance. Qi et al. (2020) proposed enforcing the convolution kernels to be near-isometric, but the gap with ResNets remains significant on ImageNet. While Oyedotun et al. (2020) were able to narrow the generalization gap between vanilla networks and ResNets, their experiments were limited to networks with only 30 layers, and their networks required many times more parameters. More recently, Martens et al. (2021) introduced a method called Deep Kernel Shaping (DKS) for initializing and transforming networks based on an analysis of their initialization-time kernel properties. They showed that their approach enabled vanilla networks to train faster than previous methods, even matching the speed of similarly sized ResNets when combined with stronger optimizers like K-FAC (Martens and Grosse, 2015) or Shampoo (Anil et al., 2020). However, their method isn't fully compatible with ReLUs, and in their experiments (which focused on training speed) their networks exhibited significantly more overfitting than ResNets.

Inspired by both DKS and the line of work using mean-field theory, we propose a new method called Tailored Activation Transformation (TAT). TAT inherits the main advantages of DKS, while working particularly well with the "Leaky ReLU" activation function. TAT enables very deep vanilla neural networks to be trained on ImageNet without the use of any additional architectural elements, while only introducing negligible extra computational cost. Using TAT, we demonstrate for the first time that a 50-layer vanilla deep network can nearly match the validation accuracy of its ResNet counterpart when trained on ImageNet. And unlike with the EOC method, the validation accuracy we achieve does not decrease with depth (see Figure 1). Furthermore, TAT can also be applied to ResNets without normalization layers, allowing them to match or even exceed the validation accuracy of standard ResNets of the same width/depth. A multi-framework open source implementation of DKS and TAT is available at https://github.com/deepmind/dks.

2 Background

Our main tool of analysis will be kernel functions for neural networks (Neal, 1996; Cho and Saul, 2009; Daniely et al., 2016) and the related Q/C maps (Saxe et al., 2013; Poole et al., 2016; Martens et al., 2021). In this section, we introduce our notation and some key concepts used throughout.

2.1 Kernel Function Approximation for Wide Networks

For simplicity, we start with the kernel function approximation for feedforward fully-connected networks, and discuss its extensions to convolutional networks and non-feedforward networks later. In particular, we will assume a network that is defined by a sequence of combined layers (each of which is an affine transformation followed by the elementwise activation function $\phi$) as follows:

$f_l(x) = \phi\!\left(W_l\, f_{l-1}(x) + b_l\right), \qquad f_0(x) = x, \qquad l = 1, \dots, L,$   (1)

with weights $W_l \in \mathbb{R}^{d_l \times d_{l-1}}$ initialized as iid Gaussians with mean zero and variance $1/d_{l-1}$ (or as scale-corrected uniform orthogonal matrices (Martens et al., 2021)), and biases $b_l$ initialized to zero. Due to the randomness of the initial parameters, the network can be viewed as a random feature model at each layer $f_l$ at initialization time. This induces a random kernel defined as follows:

$\hat\kappa_{f_l}(x, x') = \frac{1}{d_l}\, f_l(x)^\top f_l(x').$   (2)

Given these assumptions, as the width of each layer goes to infinity, $\hat\kappa_{f_l}$ converges in probability (see Theorem A.1) to a deterministic kernel $\kappa_{f_l}$ that can be computed layer by layer:

$\kappa_{f_l}(x, x') = \mathbb{E}_{(u, v) \sim \mathcal{N}\left(0,\, \Sigma_{l-1}\right)}\!\left[\phi(u)\,\phi(v)\right], \qquad \Sigma_{l-1} = \begin{pmatrix} \kappa_{f_{l-1}}(x, x) & \kappa_{f_{l-1}}(x, x') \\ \kappa_{f_{l-1}}(x, x') & \kappa_{f_{l-1}}(x', x') \end{pmatrix},$   (3)

where $\kappa_{f_0}(x, x') = \frac{1}{d_0}\, x^\top x'$.
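As a quick numerical illustration of this limit (a minimal NumPy sketch, not taken from the paper's code base; it assumes a fan-in Gaussian initialization and estimates the two-dimensional Gaussian expectation in equation 3 by Monte Carlo, with illustrative function names of our own choosing), one can compare the kernel of a single wide randomly initialized MLP against the value produced by the layer-by-layer recursion:

```python
import numpy as np

def empirical_kernel(x1, x2, depth=5, width=4096, phi=np.tanh, seed=0):
    """One draw of the random kernel in equation 2 for a wide fully-connected
    network with fan-in Gaussian initialization and zero biases."""
    rng = np.random.default_rng(seed)
    f1, f2 = x1, x2
    for _ in range(depth):
        fan_in = f1.shape[0]
        W = rng.normal(0.0, 1.0 / np.sqrt(fan_in), size=(width, fan_in))
        f1, f2 = phi(W @ f1), phi(W @ f2)
    return f1 @ f2 / width

def limiting_kernel(x1, x2, depth=5, phi=np.tanh, n_mc=200_000, seed=1):
    """Layer-by-layer recursion of equation 3, with the 2-D Gaussian
    expectation estimated by Monte Carlo."""
    rng = np.random.default_rng(seed)
    d0 = len(x1)
    k11, k12, k22 = x1 @ x1 / d0, x1 @ x2 / d0, x2 @ x2 / d0
    for _ in range(depth):
        cov = np.array([[k11, k12], [k12, k22]])
        u, v = rng.multivariate_normal([0.0, 0.0], cov, size=n_mc).T
        k11, k12, k22 = np.mean(phi(u) ** 2), np.mean(phi(u) * phi(v)), np.mean(phi(v) ** 2)
    return k12

if __name__ == "__main__":
    d0 = 64
    rng = np.random.default_rng(42)
    x1, x2 = rng.normal(size=d0), rng.normal(size=d0)
    # Scale inputs to norm sqrt(d0), matching the normalization assumed later in the text.
    x1 *= np.sqrt(d0) / np.linalg.norm(x1)
    x2 *= np.sqrt(d0) / np.linalg.norm(x2)
    print(empirical_kernel(x1, x2), limiting_kernel(x1, x2))
```

For moderate depths the two printed values agree to within Monte Carlo and finite-width error, which is the sense in which the deterministic kernel "approximates" the random one.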

2.2 Local Q/C maps

By equation 3, any diagonal entry $\kappa_{f_l}(x, x)$ only depends on the corresponding diagonal entry $\kappa_{f_{l-1}}(x, x)$. Hence, we obtain the following recursion for these diagonal entries, which we call q values:

$q_l = \mathcal{Q}(q_{l-1}) := \mathbb{E}_{u \sim \mathcal{N}(0,\, q_{l-1})}\!\left[\phi(u)^2\right],$   (4)

where $\mathcal{Q}$ is the local Q map. We note that $q_l$ is an approximation of $\frac{1}{d_l}\|f_l(x)\|^2$. Analogously, one can write the recursion for the normalized off-diagonal entries, which we call c values, as:

$c_l = \mathcal{C}(c_{l-1}, q_{l-1}, q'_{l-1}) := \frac{\mathbb{E}_{(u, v) \sim \mathcal{N}(0,\, \Lambda_{l-1})}\!\left[\phi(u)\,\phi(v)\right]}{\sqrt{\mathcal{Q}(q_{l-1})\,\mathcal{Q}(q'_{l-1})}}, \qquad \Lambda_{l-1} = \begin{pmatrix} q_{l-1} & c_{l-1}\sqrt{q_{l-1} q'_{l-1}} \\ c_{l-1}\sqrt{q_{l-1} q'_{l-1}} & q'_{l-1} \end{pmatrix},$   (5)

where $\mathcal{C}$ is the local C map and $q_{l-1}$, $q'_{l-1}$ are the q values associated with the two inputs. We note that $c_l$ is an approximation of the cosine similarity between $f_l(x)$ and $f_l(x')$. Because $\mathcal{C}$ is a three dimensional function, it is difficult to analyze, as the associated q values can vary wildly for distinct inputs. However, by scaling the inputs to have norm $\sqrt{d_0}$, and rescaling $\phi$ so that $\mathcal{Q}(1) = 1$, it follows that $q_l = 1$ for all $l$. This allows us to treat $\mathcal{C}$ as a one dimensional function from $[-1, 1]$ to $[-1, 1]$ satisfying $\mathcal{C}(1) = 1$. Additionally, it can be shown that $\mathcal{C}$ possesses special structure as a positive definite function (see Appendix A.4 for details). Going forward, we will thus assume that $q_0 = 1$, and that $\phi$ is scaled so that $\mathcal{Q}(1) = 1$.

2.3 Extensions to convolutional networks and more complex topologies

As argued in Martens et al. (2021), Q/C maps can also be defined for convolutional networks if one adopts a Delta initialization (Balduzzi et al., 2017; Xiao et al., 2018), in which all weights except those in the center of the filter are initialized to zero. Intuitively, this makes convolutional networks behave like a collection of fully-connected networks operating independently over feature map locations. As such, the Q/C map computations for a feed-forward convolutional network are the same as above. Martens et al. (2021) also gives formulas to compute q and c values for weighted sum operations between the outputs of multiple layers (without nonlinearities), thus allowing more complex network topologies. In particular, for a weighted sum $\sum_i w_i z_i$ over the outputs $z_i$ of multiple layers, the sum operation's output q value is given by $\sum_i w_i^2 q_i$, and its output c value is given by $\left(\sum_i w_i^2 \sqrt{q_i q'_i}\, c_i\right) / \sqrt{\left(\sum_i w_i^2 q_i\right)\left(\sum_i w_i^2 q'_i\right)}$. In order to maintain the property that all q values are 1 in the network, we will assume that sum operations are normalized in the sense that $\sum_i w_i^2 = 1$.

Following Martens et al. (2021), we will extend the definition of Q/C maps to include global Q/C maps, which describe the behavior of entire networks. Global maps, denoted by $Q_f$ and $C_f$ for a given network $f$, can be computed by applying the above rules for each layer in $f$. For example, the global C map of a three-layer vanilla network is simply $C_f = \mathcal{C} \circ \mathcal{C} \circ \mathcal{C}$. Like the local C map, global C maps are positive definite functions (see Appendix A.4). In this work, we restrict our attention to the family of networks comprising combined layers and normalized sums between the outputs of multiple affine layers, for which we can compute global Q/C maps. All of our formal results will implicitly assume this family of networks.

2.4 Q/C maps for rescaled ResNets

ResNets consist of a sequence of residual blocks, each of which computes the sum of a residual branch (which consists of a small multi-layer convolutional network) and a shortcut branch (which copies the block’s input). In order to analyze ResNets we will consider the modified version used in Shao et al. (2020) and Martens et al. (2021) which removes the normalization layers found in the residual branches, and replaces the sum at the end of each block with a normalized sum. These networks, which we will call rescaled ResNets, are defined by the following recursion:

$x_{l+1} = w\, x_l + \sqrt{1 - w^2}\; \mathcal{R}(x_l),$   (6)

where $\mathcal{R}$ is the residual branch, and $w$ is the shortcut weight (which must be in $[0, 1]$). Applying the previously discussed rules for computing Q/C maps, we get $q_l = 1$ for all $l$ and

$c_{l+1} = w^2\, c_l + (1 - w^2)\, C_{\mathcal{R}}(c_l).$   (7)
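For concreteness, the following minimal sketch iterates the recursion in equation 7 for a given residual-branch C map (the particular `toy_branch_c_map` used below is just an arbitrary positive definite function mapping 1 to 1, chosen for illustration; it is not the C map of any network in the paper):

```python
def resnet_c_value(c0, n_blocks, w, branch_c_map):
    """Iterate equation 7: c_{l+1} = w^2 * c_l + (1 - w^2) * C_R(c_l).
    `w` is the shortcut weight; w = 0 recovers a vanilla (shortcut-free) network."""
    c = c0
    for _ in range(n_blocks):
        c = w ** 2 * c + (1.0 - w ** 2) * branch_c_map(c)
    return c

# An illustrative residual-branch C map: non-negative coefficients summing to 1,
# so it is positive definite and maps 1 to 1.
toy_branch_c_map = lambda c: 0.2 + 0.5 * c + 0.3 * c ** 2

# Larger shortcut weights keep the output c value closer to the input c value:
print(resnet_c_value(0.5, n_blocks=50, w=0.0),
      resnet_c_value(0.5, n_blocks=50, w=0.9))
```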

3 Existing Solutions and Their Limitations

Global Q/C maps can be intuitively understood as a way of characterizing signal propagation through the network at initialization time. The q value approximates the squared magnitude of the activation vector, so that the global Q map $Q_f$ describes the contraction or expansion of this magnitude through the action of $f$. On the other hand, the c value approximates the cosine similarity of the function values for different inputs, so that the global C map $C_f$ describes how well $f$ preserves this cosine similarity from its input to its output.

Standard initialization methods (LeCun et al., 1998; Glorot and Bengio, 2010; He et al., 2015) are motivated through an analysis of how the variance of the activations evolves throughout the network. This can be viewed as a primitive form of Q map analysis, and from that perspective, these methods are trying to ensure that q values remain stable throughout the network by controlling the local Q map. This is necessary for trainability, since very large or tiny q values can cause numerical issues, saturated activation functions (which have implications for C maps), and problems with scale-sensitive losses. However, as was first observed by Schoenholz et al. (2017), a well-behaved C map is also necessary for trainability. When the global C map is close to a constant function (i.e. degenerate) on $[-1, 1]$, which easily happens in deep networks (as discussed in Appendix A.2), the network's output will appear either constant or random looking, and won't convey any useful information about the input. Xiao et al. (2020) and Martens et al. (2021) give more formal arguments for why this leads to slow optimization and/or poor generalization under gradient descent.

Figure 2: Global C maps for ReLU networks (EOC) and TReLU networks. The global C map of a TReLU network converges to a well-behaved function as depth increases (proved in Appendix C).

Several previous works (Schoenholz et al., 2017; Yang and Schoenholz, 2017; Hayou et al., 2019) attempt to achieve a well-behaved global C map by choosing the variance of the initial weights and biases in each layer such that $\mathcal{C}'(1) = 1$ – a procedure which is referred to as Edge of Chaos (EOC). However, this approach only slows down the convergence (with depth) of the c values from exponential to sublinear (Hayou et al., 2019), and does not solve the fundamental issue of degenerate global C maps for very deep networks. In particular, the global C map of a deep network with ReLU and EOC initialization rapidly concentrates around 1 as depth increases (see Figure 2). While EOC allows very deep vanilla networks to be trained, the training speed and generalization performance is typically much worse than for comparable ResNets. Klambauer et al. (2017) applied an affine transformation to the output of the activation functions in order to control the Q map, while Lu et al. (2020) applied one to impose related Q/C map conditions; the effect of both approaches is similar to EOC.
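To make this slow (sublinear) convergence concrete, the following tiny sketch composes the closed-form C map of a normalized ReLU layer (the arccosine kernel of Cho and Saul (2009), i.e. the $\alpha = 0$ special case of equation 9 below) and prints how $C_f(0)$ creeps towards 1 as depth grows:

```python
import numpy as np

def relu_c_map(c):
    """Local C map of a normalized ReLU layer (arccosine kernel)."""
    return c + (np.sqrt(1.0 - c ** 2) - c * np.arccos(c)) / np.pi

for depth in (10, 100, 1000):
    c = 0.0
    for _ in range(depth):
        c = relu_c_map(c)
    print(depth, c)  # C_f(0) approaches 1 only slowly with depth
```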

To address these problems, Martens et al. (2021) introduced DKS, which enforces the conditions $C_f(0) = 0$ and $C'_f(1) = \zeta$ (for some modest constant $\zeta > 1$) directly on the network's global C map $C_f$. They show that these conditions, along with the positive definiteness of C maps, cause $C_f$ to be close to the identity and thus well-behaved. In addition to these C map conditions, DKS enforces that $\mathcal{Q}(1) = 1$ and $\mathcal{Q}'(1) = 1$, which lead to constant q values of 1 in the network, and lower kernel approximation error (respectively). DKS enforces these Q/C map conditions by applying a model class-preserving transformation $\hat\phi(x) = \gamma\,(\phi(\alpha x + \beta) + \delta)$ with non-trainable parameters $\alpha$, $\beta$, $\gamma$, and $\delta$. The hyperparameter $\zeta$ is chosen to be sufficiently greater than 1 (e.g. 1.5) in order to prevent the transformed activation functions from looking "nearly linear" (as they would be exactly linear if $\zeta = 1$), which Martens et al. (2021) argue makes it hard for the network to achieve nonlinear behavior during training. Using DKS, they were able to match the training speed of ResNets on ImageNet with vanilla networks using K-FAC. However, DKS is not fully compatible with ReLUs, and the networks in their experiments fell substantially short of ResNets in terms of generalization performance.

4 Tailored Activation Transformation (TAT)

The reason why DKS is not fully compatible with ReLUs is that they are positive homogeneous, i.e. $\phi(\lambda x) = \lambda\, \phi(x)$ for $\lambda \ge 0$. This makes one of the parameters of the transformed activation function redundant, thus reducing the degrees of freedom with which to enforce DKS's four Q/C map conditions. Martens et al. (2021) attempt to circumvent this issue by dropping one of those conditions, which leads to vanilla deep networks that are trainable, but slower to optimize compared to using DKS with other activation functions. This is a significant drawback for DKS, as the best generalizing deep models often use ReLU-family activations. We therefore set out to investigate other possible remedies – either in the form of different activation functions, new Q/C map conditions, or both. To this end, we adopt a ReLU-family activation function with an extra degree of freedom (known as "Leaky ReLU"), and modify the Q/C map conditions in order to preserve certain desirable properties of this choice. The resulting method, which we name Tailored Activation Transformation (TAT), achieves competitive generalization performance with ResNets in our experiments.

EOC (smooth) EOC (LReLU) DKS TAT (smooth) TAT (LReLU)
Table 1: Comparison of the Q/C map conditions enforced by the different methods, applied to a network $f$.

4.1 Tailored Activation Transformation for Leaky ReLUs

One way of addressing the issue of DKS’s partial incompatibility with ReLUs is to consider a slightly different activation function – namely the Leaky ReLU (LReLU) (Maas et al., 2013):

$\phi_\alpha(x) = \begin{cases} x, & x \ge 0 \\ \alpha x, & x < 0, \end{cases}$   (8)

where $\alpha$ is the negative slope parameter. While using LReLUs with $\alpha \neq 0$ in place of ReLUs changes the model class, it doesn't limit the model's expressive capabilities compared to ReLU, as assuming $\alpha \neq 1$, one can simulate a ReLU network with a LReLU network of the same depth by doubling the number of neurons (see Appendix C). Rather than using a fixed value for $\alpha$, we will use it as an extra parameter to satisfy our desired Q/C map conditions. Define the normalized LReLU $\tilde\phi_\alpha(x) = \sqrt{\tfrac{2}{1+\alpha^2}}\, \phi_\alpha(x)$. As shown in Appendix C, the local Q and C maps for this choice of activation function are:

$\mathcal{Q}(q) = q, \qquad \mathcal{C}(c) = c + \frac{(1-\alpha)^2}{\pi\,(1+\alpha^2)}\left(\sqrt{1 - c^2} - c\,\cos^{-1}(c)\right).$   (9)

Note that the condition $\mathcal{Q}(q) = q$ is actually stronger than DKS's Q map conditions ($\mathcal{Q}(1) = 1$ and $\mathcal{Q}'(1) = 1$), and has the potential to reduce kernel approximation errors in finite width networks compared to DKS, as it provides a better guarantee on the stability of the q values w.r.t. random perturbations at each layer. Additionally, because the form of $\mathcal{C}$ does not depend on either of the layer's input q values, it won't be affected by such perturbations at all. (Notably, if one uses the negative slope parameter to transform LReLUs with DKS, these properties will not be achieved.) In support of these intuitions is the fact that better bounds on the kernel approximation error exist for ReLU networks than for general smooth ones (as discussed in Appendix A.1).
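In code, the closed-form local C map of equation 9 is only a couple of lines (a direct transcription; $\alpha = 0$ recovers the usual ReLU C map and $\alpha = 1$ gives the identity, i.e. a linear network):

```python
import numpy as np

def lrelu_c_map(c, alpha):
    """Closed-form local C map (equation 9) of the normalized Leaky ReLU
    sqrt(2 / (1 + alpha^2)) * leaky_relu(x, alpha)."""
    c = np.clip(c, -1.0, 1.0)
    coeff = (1.0 - alpha) ** 2 / (np.pi * (1.0 + alpha ** 2))
    return c + coeff * (np.sqrt(1.0 - c ** 2) - c * np.arccos(c))

print(lrelu_c_map(0.0, alpha=1.0), lrelu_c_map(0.0, alpha=0.0))
```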

Another consequence of using $\tilde\phi_\alpha$ for our activation function is that we have $\mathcal{C}'(1) = 1$ (and hence $C_f'(1) = 1$), as in EOC. If combined with the condition $C_f(0) = 0$ (which is used in DKS), this would imply by Theorem 4.1 that $C_f$ is the identity function, which by equation 9 is only true when $\alpha = 1$, thus resulting in a linear network. In order to avoid this situation, and the closely related one where $C_f$ appears "nearly linear", we instead choose the value of $\alpha$ so that $C_f(0) = \eta$, for a hyperparameter $\eta \in (0, 1)$. As shown in the following theorem, $\eta$ controls how close $C_f$ is to the identity, thus allowing us to achieve a well-behaved global C map without making $\tilde\phi_\alpha$ nearly linear: For a network $f$ with $\tilde\phi_\alpha$ as its activation function (with $\alpha \in [0, 1]$), we have

(10)

Another motivation for using $\tilde\phi_\alpha$ as an activation function is given by the following proposition: The global C map of a feedforward network with $\tilde\phi_\alpha$ as its activation function is equal to that of a rescaled ResNet of the same depth (see Section 2.4) with normalized ReLU activation $\sqrt{2}\max(x, 0)$, shortcut weight $\sqrt{2\alpha/(1+\alpha^2)}$, and residual branch consisting of a combined layer (or just a normalized ReLU activation) followed by an affine layer.

This result implies that at initialization, a vanilla network using $\tilde\phi_\alpha$ behaves similarly to a ResNet, a property that is quite desirable given the success that ResNets have already demonstrated.

In summary, we have the following three conditions:

$\mathcal{Q}(q) = q \ \ \text{for all } q \ge 0, \qquad \mathcal{C}'(1) = 1, \qquad C_f(0) = \eta,$   (11)

which we achieve by picking the negative slope parameter $\alpha$ so that $C_f(0) = \eta$. We define the Tailored Rectifier (TReLU) to be $\tilde\phi_\alpha$ with $\alpha$ chosen in this way. Note that the first two conditions are also true when applying the EOC method to LReLUs, and it's only the third which sets TReLU apart. While this might seem like a minor difference, it actually matters a lot to the behavior of the global C map. This can be seen in Figure 2, where the c value quickly converges towards 1 with depth under EOC, resulting in a degenerate global C map. By contrast, the global C map of TReLU for a fixed $\eta$ converges rapidly to a nice function, suggesting that a very deep vanilla network with TReLU has the same well-behaved global C map as a shallow network. We prove this in Appendix C by showing that the local C map in equation 9 converges to an ODE as we increase the depth. For a direct comparison of all Q/C map conditions, we refer the reader to Table 1.

For the hyperparameter $\eta$, we note that a value very close to 0 will produce a network that is "nearly linear", while a value very close to 1 will give rise to a degenerate C map. In practice, we use values of $\eta$ close to (but not equal to) 1, which seems to work well in most settings. Once we have decided on $\eta$, we can solve for the value of $\alpha$ using binary search, by exploiting the closed-form expression for $\mathcal{C}$ in equation 9 to efficiently compute $C_f(0)$. For instance, if $f$ is an $L$-layer vanilla network, one can compute $C_f(0)$ as follows:

$C_f(0) = (\underbrace{\mathcal{C} \circ \mathcal{C} \circ \cdots \circ \mathcal{C}}_{L \ \text{times}})(0),$   (12)

which is a function of $\alpha$. This approach can be generalized to more advanced architectures, such as rescaled ResNets, as discussed in Appendix B.
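A minimal sketch of this procedure for a depth-$L$ vanilla network (our own illustrative implementation, not the one in the dks library) is:

```python
import numpy as np

def c_map_zero(alpha, depth):
    """C_f(0) for a depth-L vanilla LReLU network: compose the local C map
    of equation 9 L times, starting from c = 0 (cf. equation 12)."""
    c = 0.0
    coeff = (1.0 - alpha) ** 2 / (np.pi * (1.0 + alpha ** 2))
    for _ in range(depth):
        c = c + coeff * (np.sqrt(max(1.0 - c ** 2, 0.0))
                         - c * np.arccos(np.clip(c, -1.0, 1.0)))
    return c

def solve_trelu_slope(eta, depth, tol=1e-10):
    """Binary search for the negative slope alpha such that C_f(0) = eta.
    C_f(0) is monotonically decreasing in alpha: close to 1 near alpha = 0
    (for large depth) and equal to 0 at alpha = 1."""
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if c_map_zero(mid, depth) > eta:
            lo = mid   # c value still too large -> need a larger alpha
        else:
            hi = mid
    return 0.5 * (lo + hi)

print(solve_trelu_slope(eta=0.9, depth=50))
```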

4.2 Tailored Activation Transformation for Smooth Activation Functions

Unlike LReLU, most activation functions don't have closed-form formulas for their local C maps. As a result, the computation of $C_f(0)$ involves the numerical approximation of many two-dimensional integrals to high precision (as in equation 5), which can be quite expensive. One alternative way to control how close $C_f$ is to the identity, while maintaining the condition $C_f'(1) = 1$, is to modulate its second derivative $C_f''(1)$. The validity of this approach is established by the following theorem: Suppose $f$ is a network with a smooth activation function. If $C_f'(1) = 1$, then we have

(13)

Given $\mathcal{C}(1) = 1$ and $\mathcal{C}'(1) = 1$, a straightforward computation shows that $C_f''(1) = L\, \mathcal{C}''(1)$ if $f$ is an $L$-layer vanilla network. (See Appendix B for a discussion of how to do this computation for more general architectures.) From this we obtain the following four local Q/C map conditions:

(14)

To achieve these we adopt the same activation transformation as DKS: $\hat\phi(x) = \gamma\,(\phi(\alpha x + \beta) + \delta)$ for non-trainable scalars $\alpha$, $\beta$, $\gamma$, and $\delta$. We emphasize that these conditions cannot be used with LReLU, as LReLU networks have $\mathcal{C}''(1) = \infty$. By equation 4 and basic properties of expectations, we have

$\mathcal{Q}(1) = \mathbb{E}_{u \sim \mathcal{N}(0, 1)}\!\left[\hat\phi(u)^2\right] = \gamma^2\, \mathbb{E}_{u \sim \mathcal{N}(0, 1)}\!\left[\left(\phi(\alpha u + \beta) + \delta\right)^2\right],$   (15)

so that the condition $\mathcal{Q}(1) = 1$ determines $\gamma$ in closed form once the other parameters are fixed. To obtain the values for $\alpha$, $\beta$, and $\delta$, we can treat the remaining conditions as a three-dimensional nonlinear system, which can be written as follows:

(16)

We do not have a closed-form solution of this system. However, each expectation is a one dimensional integral, and so can be quickly evaluated to high precision using Gaussian quadrature. One can then use black-box nonlinear equation solvers, such as modified Powell’s method (Powell, 1964), to obtain a solution. See https://github.com/deepmind/dks for a complete implementation.
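The following is a minimal illustrative sketch of this procedure for the SoftPlus activation. It assumes, for the purposes of the example, that the remaining conditions are $\mathcal{Q}'(1) = 1$, $\mathcal{C}'(1) = 1$, and $\mathcal{C}''(1)$ equal to an arbitrary target value (0.3 below, chosen only for illustration); the exact system is the one in equation 16, and the dks library linked above contains the actual implementation. Expectations use Gauss-Hermite quadrature and the system is handed to SciPy's modified Powell ("hybr") solver; convergence may depend on the initial guess.

```python
import numpy as np
from scipy.optimize import root

# Gauss-Hermite nodes/weights for expectations over N(0, 1).
_z, _w = np.polynomial.hermite_e.hermegauss(60)
_w = _w / np.sqrt(2.0 * np.pi)

def gauss_expect(fn):
    """E[fn(u)] for u ~ N(0, 1), via Gauss-Hermite quadrature."""
    return np.sum(_w * fn(_z))

def softplus(x):    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)
def softplus_d1(x): return 1.0 / (1.0 + np.exp(-x))
def softplus_d2(x): s = softplus_d1(x); return s * (1.0 - s)

def transformed(alpha, beta, delta):
    """gamma * (phi(alpha*x + beta) + delta), with gamma chosen in closed
    form so that Q(1) = 1 (cf. equation 15)."""
    gamma = 1.0 / np.sqrt(gauss_expect(
        lambda u: (softplus(alpha * u + beta) + delta) ** 2))
    f   = lambda x: gamma * (softplus(alpha * x + beta) + delta)
    df  = lambda x: gamma * alpha * softplus_d1(alpha * x + beta)
    d2f = lambda x: gamma * alpha ** 2 * softplus_d2(alpha * x + beta)
    return f, df, d2f

def residuals(params, curvature_target):
    """Illustrative stand-in for equation 16: Q'(1) = 1, C'(1) = 1,
    and C''(1) = curvature_target (formulas as in Appendix A.3)."""
    f, df, d2f = transformed(*params)
    q_slope = gauss_expect(lambda u: f(u) * df(u) * u)
    c_slope = gauss_expect(lambda u: df(u) ** 2)
    c_curve = gauss_expect(lambda u: d2f(u) ** 2)
    return [q_slope - 1.0, c_slope - 1.0, c_curve - curvature_target]

sol = root(residuals, x0=[1.0, 0.0, 0.0], args=(0.3,), method="hybr")
print(sol.x, residuals(sol.x, 0.3))
```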

5 Experiments

Our main experimental evaluation of TAT and competing approaches is on training deep convolutional networks for ImageNet classification (Deng et al., 2009). The goal of these experiments is not to achieve state-of-the-art results, but rather to compare TAT as fairly as possible with existing methods, and standard ResNets in particular. To this end, we use ResNet V2 (He et al., 2016b) as the main reference architecture, from which we obtain rescaled ResNets (by removing normalization layers and weighting the branches as per equation 6), and vanilla networks (by further removing shortcuts). For networks without batch normalization, we add dropout to the penultimate layer for regularization, as was done in Brock et al. (2021a). We train the models for 90 epochs with a batch size of 1024, unless stated otherwise. For TReLU, we obtain $\eta$ by grid search. The weight initialization used for all methods is the Orthogonal Delta initialization, with an extra multiplier $\sigma_w$, and biases are initialized iid from a zero-mean Gaussian. We use the same multiplier in all experiments (unless explicitly stated otherwise), with the single exception of standard ResNets, where we follow standard practice (He et al., 2015). For all other details see Appendix D.
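As a concrete illustration of the "Delta" part of this initialization, here is a minimal NumPy sketch (not the actual initializer from the dks library; it assumes an HWIO kernel layout and omits the exact scale correction used for non-square weight matrices): every spatial location of the convolution kernel is zeroed except the center, which receives a gain-scaled orthogonal matrix, so that at initialization the convolution acts like a fully-connected layer at each feature-map location.

```python
import numpy as np

def orthogonal_delta_init(kernel_h, kernel_w, c_in, c_out, gain=1.0, seed=0):
    """Delta initialization for a conv kernel of shape (kh, kw, c_in, c_out)."""
    rng = np.random.default_rng(seed)
    w = np.zeros((kernel_h, kernel_w, c_in, c_out))
    # Orthogonal matrix via QR decomposition of a Gaussian matrix.
    a = rng.normal(size=(max(c_in, c_out), min(c_in, c_out)))
    q, r = np.linalg.qr(a)
    q = q * np.sign(np.diag(r))                 # fix signs of the columns
    q = q if c_in >= c_out else q.T             # orthonormal rows or columns
    w[kernel_h // 2, kernel_w // 2] = gain * q[:c_in, :c_out]
    return w

print(orthogonal_delta_init(3, 3, 4, 4, gain=1.0)[1, 1])  # only the center is nonzero
```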

5.1 Towards removing batch normalization

Two crucial components for the successful training of very deep neural networks are shortcut connections and batch normalization (BN) layers. As argued in De and Smith (2020) and Shao et al. (2020), BN implicitly biases the residual blocks toward the identity function, which makes the network better behaved at initialization time, and thus easier to train. This suggests that one can compensate for the removal of BN layers, at least in terms of their effect on the behaviour of the network at initialization time, by down-scaling the residual branch of each residual block. Arguably, almost all recent work on training deep networks without normalization layers (Zhang et al., 2018; Shao et al., 2020; Bachlechner et al., 2020; Brock et al., 2021b, a) has adopted this idea by introducing multipliers on the residual branches (which may or may not be optimized during training).

Optimizer Standard ResNet Activation Rescaled ResNet (varying shortcut weights)
K-FAC 76.4 ReLU 72.6 74.5 75.6 75.9
TReLU 74.6 75.5 76.4 75.9
SGD 76.3 ReLU 63.7 72.4 73.9 75.0
TReLU 71.0 72.6 76.0 74.8
Table 2: Top-1 validation accuracy of rescaled ResNet-50 with varying shortcut weights. We set $\eta$ for TReLU as described in Section 5.

In Table 2, we show that one can close most of the gap with standard ResNets by simply adopting the modification in equation 6, without using BN layers. By further replacing ReLU with TReLU, we can exactly match the performance of standard ResNets. With K-FAC as the optimizer, the best rescaled ResNet with ReLU comes within 0.5% of the validation accuracy (76.4%) of the standard ResNet, and with TReLU we match the standard ResNet's performance at the larger shortcut weights.

Depth Optimizers vanilla BN LN
50 K-FAC 72.6 72.8 72.7
SGD 63.7 72.6 58.1
101 K-FAC 71.8 67.6 72.0
SGD 41.6 43.4 28.6
Table 3: Top-1 validation accuracies of shortcut-free networks on ImageNet.

5.2 The difficulty of removing shortcut connections

While the aforementioned works have shown that it is possible to achieve competitive results without normalization layers, they all rely on the use of shortcut connections to make the network look more linear at initialization. A natural question to ask is whether normalization layers could compensate for the removal of shortcut connections. We address this question by training shortcut-free networks with either BN or Layer Normalization (LN) layers. As shown in Table 3, these changes do not seem to make a significant difference, especially with strong optimizers like K-FAC. These findings are in agreement with the analyses of Yang et al. (2019) and Martens et al. (2021), who respectively showed that deep shortcut-free networks with BN layers still suffer from exploding gradients, and deep shortcut-free networks with LN layers still have degenerate C maps.

5.3 Training Deep Neural Networks without Shortcuts

The main motivation for developing TAT is to help deep vanilla networks achieve generalization performance similar to standard ResNets. In our investigations we include rescaled ResNets with a shortcut weight of either 0 (i.e. vanilla networks) or 0.8. In Table 4 we can see that with a strong optimizer like K-FAC, we can reduce the gap on the 50 layer network to only 1.8% accuracy when training for 90 epochs, and further down to 0.6% when training for 180 epochs. For 101 layers, the gaps are 3.6% and 1.7% respectively, which we show can be further reduced with wider networks (see Table 9). To our knowledge, this is the first time that a deep vanilla network has been trained to such a high validation accuracy on ImageNet. In addition, our networks have fewer parameters and run faster than standard ResNets, and use less memory at inference time due to the removal of shortcut connections and BN layers. The gaps when using SGD as the optimizer are noticeably larger, which we explore further in Section 5.5. Lastly, using rescaled ResNets with a shortcut weight of 0.8 and TReLU, we can exactly match or even surpass the performance of standard ResNets.

Depth Optimizer 90 epochs 180 epochs
ResNet w=0 w=0.8 ResNet w=0 w=0.8
50 K-FAC 76.4 74.6 76.4 76.6 76.0 77.0
SGD 76.3 71.0 76.0 76.6 72.3 76.8
101 K-FAC 77.8 74.2 77.8 77.6 75.9 78.4
SGD 77.9 70.0 77.3 77.6 73.8 77.4
Table 4: ImageNet top-1 validation accuracy. For rescaled ResNets (shortcut weight w = 0 or w = 0.8), we do not include any normalization layers. For standard ResNets, batch normalization is included. By default, the ReLU activation is used for standard ResNets, while we use TReLU for the rescaled networks.

5.4 Comparisons with existing approaches

Depth Optimizer Method (L)ReLU Tanh
50 K-FAC EOC 72.6 70.6
TAT 74.6 73.1
SGD EOC 63.7 55.7
TAT 71.0 69.5
101 K-FAC EOC 71.8 69.2
TAT 74.2 72.8
SGD EOC 41.6 54.0
TAT 70.0 69.0
Table 5: ImageNet top-1 validation accuracy comparison between EOC and TAT on deep vanilla networks.

Comparison with EOC. Our first comparison is between TAT and EOC on vanilla deep networks. For EOC with ReLUs we use the He et al. (2015) initialization to achieve $\mathcal{Q}(1) = 1$, since ReLU networks initialized this way always satisfy the EOC condition $\mathcal{C}'(1) = 1$. For Tanh activations, a comprehensive comparison with EOC is more difficult, as there are infinitely many choices of weight and bias variance that achieve criticality; here we use the choice suggested in Hayou et al. (2019). (We also ran experiments with an alternative choice, and with the scheme described in Pennington et al. (2017) and Xiao et al. (2018) for dynamical isometry; the results were worse than those reported in the table.) In Table 5, we can see that in all the settings, networks constructed with TAT outperform EOC-initialized networks by a significant margin, especially when using SGD. Another observation is that the accuracy of EOC-initialized networks drops as depth increases.

Comparison with DKS. The closest approach to TAT in the existing literature is DKS, whose similarities and drawbacks are discussed in Section 4. We compare TAT to DKS on both LReLUs (for DKS, we treat the negative slope as one of the transformed activation function's parameters), and smooth functions like the SoftPlus and Tanh. For smooth activations, we perform a grid search over each method's hyperparameter, and report only the best performing one. From the results shown in Table 7, we observe that TAT, together with LReLU (i.e. TReLU), performs the best in nearly all settings we tested, and that its advantage becomes larger when we remove dropout. One possible reason for the superior performance of TReLU networks is the stronger Q/C map conditions that they satisfy compared to other activations (i.e. $\mathcal{Q}(q) = q$ for all $q$ vs $\mathcal{Q}(1) = 1$ and $\mathcal{Q}'(1) = 1$, and the invariance of $\mathcal{C}$ to the input q values), and the extra resilience to kernel approximation error that these stronger conditions imply. In practice, we found that TReLU indeed has smaller kernel approximation error (compared to DKS with smooth activation functions, see Appendix E.1) and works equally well with Gaussian initialization (see Appendix E.7).

Depth Optimizer TReLU PReLU (init 0) PReLU (init 0.25)
50 K-FAC 74.6 72.5 73.6
SGD 71.0 66.7 67.9
101 K-FAC 74.2 71.9 72.8
SGD 70.0 54.3 66.3
Table 6: Comparison with PReLU.

Comparison with PReLU. The Parametric ReLU (PReLU) introduced in He et al. (2015) differs from LReLU by making the negative slope a trainable parameter. Note that this is distinct from what we are doing with TReLU, since there we compute the negative slope parameter ahead of time and fix it during training. In our comparisons with PReLU we consider two different initializations of the negative slope: 0 (which recovers the standard ReLU at initialization), and 0.25, which was used in He et al. (2015). We report the results on deep vanilla networks in Table 6 (see Appendix E.6 for results on rescaled ResNets). For all settings, our method outperforms PReLU by a large margin, emphasizing the importance of the initial negative slope value. In principle, these two methods can be combined (i.e. we could first initialize the negative slope parameter with TAT, and then optimize it during training); however, we did not see any benefit from doing this in our experiments.

Depth Optimizer Shortcut Weight TAT DKS
LReLU SoftPlus Tanh LReLU SoftPlus Tanh
50 K-FAC 74.6/74.2 74.4/74.2 73.1/72.9 74.3/74.3 74.3/73.7 72.9/72.9
76.4/75.9 76.4/75.0 74.8/74.4 76.2/76.2 76.3/75.1 74.7/74.5
SGD 71.1/71.1 70.2/70.0 69.5/69.5 70.4/70.4 71.8/71.4 69.2/69.2
76.0/75.8 74.3/73.8 72.4/72.2 73.4/73.0 75.2/74.1 72.8/72.8
101 K-FAC 74.2/74.2 74.1/73.4 72.8/72.5 73.5/73.5 73.9/73.1 72.5/72.4
77.8/77.0 76.6/75.7 75.8/75.1 76.8/76.7 76.8/75.6 75.9/75.7
SGD 70.0/70.0 70.3/68.8 69.0/67.8 68.3/68.3 68.3/68.3 69.8/69.8
77.3/76.0 75.3/75.3 73.8/73.5 74.9/74.6 76.3/75.1 74.6/74.6
Table 7: Comparisons between TAT and DKS on rescaled ResNets with shortcut weights 0 and 0.8 (the first and second row of each depth/optimizer pair, respectively). In each cell, the number to the right of the '/' is the result without dropout. The TAT variants are the methods introduced in this paper.

5.5 The role of the optimizer

Optimizer Batch size
128 256 512 1024 2048 4096
K-FAC 74.5 74.4 74.5 74.6 74.2 72.0
SGD 72.7 72.6 72.7 71.0 69.3 62.0
LARS 72.4 72.3 72.6 71.8 71.3 70.2
Table 8: Batch size scaling.

One interesting phenomenon we observed in our experiments, which echoes the findings of Martens et al. (2021), is that a strong optimizer such as K-FAC significantly outperforms SGD on vanilla deep networks in terms of training speed. One plausible explanation is that K-FAC works better than SGD in the large-batch setting, and our default batch size of 1024 is already beyond SGD’s “critical batch size”, at which scaling efficiency begins to drop. Indeed, it was shown by Zhang et al. (2019) that optimization algorithms that employ preconditioning, such as Adam and K-FAC, result in much larger critical batch sizes.

Figure 3: Training speed comparison between K-FAC and SGD on 50 layer vanilla TReLU network.

To investigate this further, we tried batch sizes between 128 and 4096 for training 50-layer vanilla TReLU networks. As shown in Table 8, K-FAC performs equally well for all batch sizes except 4096 (where we see increased overfitting), while the performance of SGD starts to drop when we increase the batch size past 512. Surprisingly, we observe a similar trend for the LARS optimizer (You et al., 2019), which was designed for large-batch training. Even at the smallest batch size we tested (128), K-FAC still outperforms SGD by a gap of 1.8% within our standard epoch budget. We conjecture that the reason behind this is that vanilla networks without normalization and shortcuts give rise to loss landscapes with worse curvature properties compared to ResNets, and that this slows down simpler optimizers like SGD. To investigate further, we also ran SGD (with a batch size of 512) and K-FAC for up to 360 epochs with a "one-cycle" cosine learning rate schedule (Loshchilov and Hutter, 2016) that decreases the learning rate to 0 by the final epoch. As shown in Figure 3, SGD does indeed eventually catch up with K-FAC (using the cosine scheme), requiring just over double the number of epochs to achieve the same validation accuracy. While one may argue that K-FAC introduces additional computational overhead at each step, thus making a head-to-head comparison versus SGD unfair, we note that this overhead can be amortized by not updating K-FAC's preconditioner matrix at every step. In our experiments we found that this strategy allowed K-FAC to achieve a similar per-step runtime to SGD, while retaining its optimization advantage on vanilla networks. (See Appendix E.3.)

6 Conclusions

In this work we considered the problem of training and generalization in vanilla deep neural networks (i.e. those without shortcut connections and normalization layers). To address this we developed a novel method that modifies the activation functions in a way tailored to the specific architecture, and which enables us to achieve generalization performance on par with standard ResNets of the same width/depth. Unlike the most closely related approach (DKS), our method is fully compatible with ReLU-family activation functions, and in fact achieves its best performance with them. By obviating the need for shortcut connections, we believe our method could enable further research into deep models and their representations. In addition, our method may enable new architectures to be trained for which existing techniques, such as shortcuts and normalization layers, are insufficient.

Reproducibility Statement

Here we discuss our efforts to facilitate the reproducibility of this paper. Firstly, we have made an open Python implementation of DKS and TAT, supporting multiple tensor programming frameworks, available at https://github.com/deepmind/dks. Secondly, we have given all important details of our experiments in Appendix D.

References

  • R. Anil, V. Gupta, T. Koren, K. Regan, and Y. Singer (2020) Scalable second order optimization for deep learning. arXiv preprint arXiv:2002.09018. Cited by: §1.
  • J. Ba, R. Grosse, and J. Martens (2017) Distributed second-order optimization using kronecker-factored approximations. In International Conference on Learning Representations, Cited by: Appendix D.
  • J. Ba, J. R. Kiros, and G. E. Hinton (2016) Layer normalization. arXiv preprint arXiv:1607.06450. Cited by: §1.
  • T. Bachlechner, B. P. Majumder, H. H. Mao, G. W. Cottrell, and J. McAuley (2020) Rezero is all you need: fast convergence at large depth. arXiv preprint arXiv:2003.04887. Cited by: §5.1.
  • D. Balduzzi, M. Frean, L. Leary, J. Lewis, K. W. Ma, and B. McWilliams (2017) The shattered gradients problem: if resnets are the answer, then what is the question?. In International Conference on Machine Learning, pp. 342–350. Cited by: Appendix D, §E.7, §2.3.
  • J. Bradbury, R. Frostig, P. Hawkins, M. J. Johnson, C. Leary, D. Maclaurin, G. Necula, A. Paszke, J. VanderPlas, S. Wanderman-Milne, and Q. Zhang (2018) JAX: composable transformations of Python+NumPy programs External Links: Link Cited by: Appendix D.
  • A. Brock, S. De, S. L. Smith, and K. Simonyan (2021a) High-performance large-scale image recognition without normalization. arXiv preprint arXiv:2102.06171. Cited by: §5.1, §5.
  • A. Brock, S. De, and S. L. Smith (2021b) Characterizing signal propagation to close the performance gap in unnormalized resnets. In International Conference on Learning Representations, Cited by: §5.1.
  • T. B. Brown, B. Mann, N. Ryder, M. Subbiah, J. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, et al. (2020) Language models are few-shot learners. arXiv preprint arXiv:2005.14165. Cited by: §1.
  • S. Buchanan, D. Gilboa, and J. Wright (2020) Deep networks and the multiple manifold problem. In International Conference on Learning Representations, Cited by: §A.1.
  • Y. Cho and L. Saul (2009) Kernel methods for deep learning. Advances in Neural Information Processing Systems 22, pp. 342–350. Cited by: §2.
  • A. Daniely, R. Frostig, and Y. Singer (2016) Toward deeper understanding of neural networks: the power of initialization and a dual view on expressivity. Advances In Neural Information Processing Systems 29, pp. 2253–2261. Cited by: §A.1, §A.1, §A.2, §A.4, Appendix C, §2.
  • S. De and S. Smith (2020) Batch normalization biases residual blocks towards the identity function in deep networks. Advances in Neural Information Processing Systems 33. Cited by: §5.1.
  • J. Deng, W. Dong, R. Socher, L. Li, K. Li, and L. Fei-Fei (2009) Imagenet: a large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp. 248–255. Cited by: §5.
  • X. Ding, X. Zhang, N. Ma, J. Han, G. Ding, and J. Sun (2021) Repvgg: making vgg-style convnets great again. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13733–13742. Cited by: §1.
  • D. Duvenaud, O. Rippel, R. Adams, and Z. Ghahramani (2014) Avoiding pathologies in very deep networks. In Artificial Intelligence and Statistics, pp. 202–210. Cited by: §1.
  • X. Glorot and Y. Bengio (2010) Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the thirteenth international conference on artificial intelligence and statistics, pp. 249–256. Cited by: §3.
  • S. Hayou, E. Clerico, B. He, G. Deligiannidis, A. Doucet, and J. Rousseau (2021) Stable resnet. In International Conference on Artificial Intelligence and Statistics, pp. 1324–1332. Cited by: §1.
  • S. Hayou, A. Doucet, and J. Rousseau (2019) On the impact of the activation function on deep neural networks training. In International conference on machine learning, pp. 2672–2680. Cited by: §3, §5.4.
  • K. He, X. Zhang, S. Ren, and J. Sun (2015) Delving deep into rectifiers: surpassing human-level performance on imagenet classification. In Proceedings of the IEEE international conference on computer vision, pp. 1026–1034. Cited by: §A.1, §E.6, §3, §5.4, §5.4, §5.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016a) Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778. Cited by: §1.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016b) Identity mappings in deep residual networks. In European conference on computer vision, pp. 630–645. Cited by: §1, §5.
  • T. Hennigan, T. Cai, T. Norman, and I. Babuschkin (2020) Haiku: Sonnet for JAX External Links: Link Cited by: Appendix D.
  • M. Hessel, D. Budden, F. Viola, M. Rosca, E. Sezener, and T. Hennigan (2020) Optax: composable gradient transformation and optimisation, in jax! External Links: Link Cited by: Appendix D.
  • S. Hochreiter, Y. Bengio, P. Frasconi, J. Schmidhuber, et al. (2001) Gradient flow in recurrent nets: the difficulty of learning long-term dependencies. A field guide to dynamical recurrent neural networks. IEEE Press. Cited by: §1.
  • S. Ioffe and C. Szegedy (2015) Batch normalization: accelerating deep network training by reducing internal covariate shift. In International conference on machine learning, pp. 448–456. Cited by: §1.
  • H. K. Khalil (2008) Nonlinear systems third edition. Cited by: Appendix C.
  • G. Klambauer, T. Unterthiner, A. Mayr, and S. Hochreiter (2017) Self-normalizing neural networks. Advances in neural information processing systems 30. Cited by: §3.
  • A. Krizhevsky et al. (2009) Learning multiple layers of features from tiny images. Cited by: §E.2.
  • A. Krizhevsky, I. Sutskever, and G. E. Hinton (2012) Imagenet classification with deep convolutional neural networks. Advances in neural information processing systems. Cited by: §1.
  • Y. A. LeCun, L. Bottou, G. B. Orr, and K. Müller (1998) Efficient backprop. In Neural networks: Tricks of the trade, Cited by: §3.
  • I. Loshchilov and F. Hutter (2016) SGDR: stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983. Cited by: §5.5, footnote 5.
  • Y. Lu, S. Gould, and T. Ajanthan (2020) Bidirectional self-normalizing neural networks. arXiv preprint arXiv:2006.12169. Cited by: §3.
  • A. L. Maas, A. Y. Hannun, and A. Y. Ng (2013) Rectifier nonlinearities improve neural network acoustic models. In International Conference on Machine Learning, Cited by: §4.1.
  • J. Martens, A. Ballard, G. Desjardins, G. Swirszcz, V. Dalibard, J. Sohl-Dickstein, and S. S. Schoenholz (2021) Rapid training of deep neural networks without skip connections or normalization layers using deep kernel shaping. arXiv preprint arXiv:2110.01765. Cited by: §A.2, §A.2, §A.3, §A.4, §B.1, §B.1, §B.2, Appendix D, §E.1, §E.7, §1, §2.1, §2.3, §2.3, §2.4, §2, §3, §3, §4, §5.2, §5.5.
  • J. Martens and R. Grosse (2015) Optimizing neural networks with kronecker-factored approximate curvature. In International conference on machine learning, pp. 2408–2417. Cited by: §1.
  • J. Martens (2021) On the validity of kernel approximations for orthogonally-initialized neural networks. arXiv preprint arXiv:2104.05878. Cited by: §A.1.
  • R. M. Neal (1996) Bayesian learning for neural networks. Lecture notes in statistics 118. Cited by: §2.
  • O. K. Oyedotun, D. Aouada, B. Ottersten, et al. (2020) Going deeper with neural networks without skip connections. In 2020 IEEE International Conference on Image Processing (ICIP), pp. 1756–1760. Cited by: §1, §1.
  • J. Pennington, S. S. Schoenholz, and S. Ganguli (2017) Resurrecting the sigmoid in deep learning through dynamical isometry: theory and practice. In Proceedings of the 31st International Conference on Neural Information Processing Systems, pp. 4788–4798. Cited by: footnote 1, footnote 2.
  • B. Poole, S. Lahiri, M. Raghu, J. Sohl-Dickstein, and S. Ganguli (2016) Exponential expressivity in deep neural networks through transient chaos. Advances in neural information processing systems 29, pp. 3360–3368. Cited by: §A.2, §A.2, §A.3, §2.
  • M. J. Powell (1964) An efficient method for finding the minimum of a function of several variables without calculating derivatives. The Computer Journal 7 (2), pp. 155–162. Cited by: §4.2.
  • H. Qi, C. You, X. Wang, Y. Ma, and J. Malik (2020) Deep isometric learning for visual recognition. In International Conference on Machine Learning, pp. 7824–7835. Cited by: §1.
  • A. M. Saxe, J. L. McClelland, and S. Ganguli (2013) Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120. Cited by: §2.
  • S. S. Schoenholz, J. Gilmer, S. Ganguli, and J. Sohl-Dickstein (2017) Deep information propagation. In International Conference on Learning Representations, Cited by: §1, §3, §3.
  • J. Shao, K. Hu, C. Wang, X. Xue, and B. Raj (2020) Is normalization indispensable for training deep neural network?. Advances in Neural Information Processing Systems 33. Cited by: §2.4, §5.1.
  • D. Silver, J. Schrittwieser, K. Simonyan, I. Antonoglou, A. Huang, A. Guez, T. Hubert, L. Baker, M. Lai, A. Bolton, et al. (2017) Mastering the game of go without human knowledge. nature 550 (7676), pp. 354–359. Cited by: §1.
  • R. K. Srivastava, K. Greff, and J. Schmidhuber (2015) Highway networks. arXiv preprint arXiv:1505.00387. Cited by: §1.
  • C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, and Z. Wojna (2016) Rethinking the inception architecture for computer vision. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2818–2826. Cited by: Appendix D.
  • A. Veit, M. J. Wilber, and S. Belongie (2016) Residual networks behave like ensembles of relatively shallow networks. Advances in neural information processing systems 29, pp. 550–558. Cited by: §1.
  • L. Xiao, Y. Bahri, J. Sohl-Dickstein, S. Schoenholz, and J. Pennington (2018) Dynamical isometry and a mean field theory of cnns: how to train 10,000-layer vanilla convolutional neural networks. In International Conference on Machine Learning, Cited by: Appendix D, §E.7, §1, §2.3, footnote 2.
  • L. Xiao, J. Pennington, and S. Schoenholz (2020) Disentangling trainability and generalization in deep neural networks. In International Conference on Machine Learning, pp. 10462–10472. Cited by: §3.
  • G. Yang, J. Pennington, V. Rao, J. Sohl-Dickstein, and S. S. Schoenholz (2019) A mean field theory of batch normalization. ArXiv abs/1902.08129. Cited by: §5.2.
  • G. Yang and S. Schoenholz (2017) Mean field residual networks: on the edge of chaos. In Advances in Neural Information Processing Systems, Vol. 30. Cited by: §3.
  • Y. You, J. Li, S. Reddi, J. Hseu, S. Kumar, S. Bhojanapalli, X. Song, J. Demmel, K. Keutzer, and C. Hsieh (2019) Large batch optimization for deep learning: training bert in 76 minutes. arXiv preprint arXiv:1904.00962. Cited by: §5.5.
  • S. Zagoruyko and N. Komodakis (2016) Wide residual networks. In British Machine Vision Conference 2016, Cited by: §E.2.
  • G. Zhang, L. Li, Z. Nado, J. Martens, S. Sachdeva, G. Dahl, C. Shallue, and R. B. Grosse (2019) Which algorithmic choices matter at which batch sizes? insights from a noisy quadratic model. Advances in neural information processing systems. Cited by: §5.5.
  • H. Zhang, Y. N. Dauphin, and T. Ma (2018) Fixup initialization: residual learning without normalization. In International Conference on Learning Representations, Cited by: §5.1.

Appendix A Background

A.1 Kernel Function Approximation Error Bounds

In Section 2.1, we claimed that the kernel defined in equation 2 would converge to a deterministic kernel, as the width of each layer goes to infinity. To be specific, one has the following result bounding the kernel approximation error. [Adapted from Theorem 2 of Daniely et al. (2016)] Consider a fully-connected network of depth with weights initialized independently using a standard Gaussian fan-in initialization. Further suppose that the activation function is -bounded (i.e. , and for some constant ) and satisfies , and that the width of each layer is greater than or equal to . Then at initialization time, for inputs and satisfying , we have that

with probability at least .

The bound in Theorem A.1 predicts an exponential dependence on the depth of the minimum required width of each layer. However, for a network with ReLU activations, this dependence is only quadratic in , as is established in the following theorem: [Adapted from Theorem 3 of Daniely et al. (2016)] Consider a fully-connected network of depth with ReLU activations and weights initialized independently using a He initialization (He et al., 2015), and suppose that the width of each layer is greater than or equal to . Then at initialization time, for inputs and satisfying , and , we have that

with probability at least . According to Lemma D.1 of Buchanan et al. (2020), the requirement of the width for ReLU networks could further be reduced to linear in the depth , but with a worse dependency on .

Although Theorems A.1 and A.2 are only applicable to Gaussian initializations, a similar bound has been given by Martens (2021) for scaled uniform orthogonal initializations in the case that . Moreover, Martens (2021) conjectures that their result could be extended to general values of .

A.2 Degenerate C maps for very deep networks

Daniely et al. (2016), Poole et al. (2016), and Martens et al. (2021) have shown that without very careful interventions, C maps inevitably become “degenerate” in deep networks, tending rapidly towards constant functions on as depth increases. The following proposition is a restatement of Claim 1 from Daniely et al. (2016): Suppose is a deep network consisting of a composition of combined layers. Then for all we have

for some .

While the above result doesn’t characterize the rate of convergence to a constant function, Poole et al. (2016) show that if , it happens exponentially fast as a function of in the asymptotic limit of large . Martens et al. (2021) gives a similar result which holds uniformly for all , and for networks with more general repeated structures.

A.3 C map derivative

Poole et al. (2016) gave the following nice formula for the derivative of the C map of a combined layer with activation function $\phi$ (recalling our standing assumption that $\mathcal{Q}(1) = 1$):

$\mathcal{C}'(c) = \mathbb{E}_{(u, v) \sim \mathcal{N}(0,\, \Lambda_c)}\!\left[\phi'(u)\,\phi'(v)\right], \qquad \Lambda_c = \begin{pmatrix} 1 & c \\ c & 1 \end{pmatrix}.$   (17)

For a rigorous proof of this result we refer the reader to Martens et al. (2021).

One can iterate this formula to obtain a similar equation for higher-order derivatives:

$\mathcal{C}^{(n)}(c) = \mathbb{E}_{(u, v) \sim \mathcal{N}(0,\, \Lambda_c)}\!\left[\phi^{(n)}(u)\,\phi^{(n)}(v)\right].$   (18)
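A quick Monte Carlo check of equation 17 (an illustrative snippet of ours, using tanh, which is not normalized to have $\mathcal{Q}(1) = 1$, hence the explicit division by its Q value):

```python
import numpy as np

rng = np.random.default_rng(0)
n, c, eps = 2_000_000, 0.3, 1e-4
z1, z2 = rng.normal(size=(2, n))         # shared noise -> low-variance finite difference

phi, dphi = np.tanh, lambda x: 1.0 / np.cosh(x) ** 2
q_norm = np.mean(phi(z1) ** 2)           # Q(1) for the unnormalized tanh

def c_map(cc):
    """Sample estimate of the local C map at correlation cc (q values = 1)."""
    v = cc * z1 + np.sqrt(1.0 - cc ** 2) * z2
    return np.mean(phi(z1) * phi(v)) / q_norm

v = c * z1 + np.sqrt(1.0 - c ** 2) * z2
print((c_map(c + eps) - c_map(c - eps)) / (2 * eps),   # finite-difference C'(c)
      np.mean(dphi(z1) * dphi(v)) / q_norm)            # right-hand side of eq. 17
```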

A.4 Some useful properties of C maps

In this section we will assume that $\mathcal{Q}(1) = 1$, so that all q values in the network are equal to 1.

Observe that $\mathcal{C}(1) = 1$, and that $\mathcal{C}$ maps $[-1, 1]$ to $[-1, 1]$ (which follows from its interpretation as computing cosine similarities for infinitely wide networks). Moreover, $\mathcal{C}$ is a positive definite function, which means that it can be written as $\mathcal{C}(c) = \sum_{n \ge 0} b_n c^n$ for coefficients $b_n \ge 0$ (Daniely et al., 2016; Martens et al., 2021). Note that for smooth activation functions, positive definiteness can be easily verified by Taylor-expanding $\mathcal{C}$ about $c = 0$ and using

$\mathcal{C}^{(n)}(0) = \left(\mathbb{E}_{u \sim \mathcal{N}(0, 1)}\!\left[\phi^{(n)}(u)\right]\right)^2 \ge 0.$   (19)

As discussed in Section 2.3, global C maps are computed by recursively taking compositions and weighted averages (with non-negative weights), starting from $\mathcal{C}$. Because all of the above properties are preserved under these operations, it follows that global C maps inherit them from $\mathcal{C}$.

Appendix B Additional details and pseudocode for activation function transformations

B.1 Taking all subnetworks into account

In the main text of this paper we have stated each method's network-level condition directly in terms of the global C map: the condition on $C_f'(1)$ in DKS, on $C_f(0)$ in TAT for Leaky ReLUs, and on $C_f''(1)$ in TAT for smooth activation functions. However, the condition used by Martens et al. (2021) in DKS was actually stated in terms of the so-called "maximal slope function", which takes the maximum of $C_g'(1)$ over all subnetworks $g \subseteq f$,

where "$g \subseteq f$" denotes that $g$ is a subnetwork of $f$. (A subnetwork of $f$ is defined as a (non-strict) connected subset of the layers in $f$ that constitutes a neural network with a single input and output layer. So for example, layers 3, 4 and 5 of a 10 layer MLP form a subnetwork, while layers 3, 4, and 6 do not.) That this maximum is fully determined by $\mathcal{C}'(1)$ follows from the fact that each $C_g$ can be written in terms of compositions, weighted average operations, and applications of $\mathcal{C}$, and that C maps always preserve the value 1. Using the chain rule, and the linearity of derivatives, these facts allow one to write each $C_g'(1)$ as a polynomial function of $\mathcal{C}'(1)$.

The motivation given by Martens et al. (2021) for maximizing over all subnetworks (instead of just considering $f$ itself) is that we want all layers of $f$, in all of its subnetworks, to be readily trainable. For example, a very deep and untrainable MLP could be made to have a reasonable global C map simply by adding a skip connection from its input to its output, but this won't do anything to address the untrainability of the layers being "skipped around" (which form a subnetwork).

In the main text we ignored this complication in the interest of a shorter presentation, and because the maximum over subnetworks happens to be achieved by $f$ itself for the simple network architectures focused on in this work. To remedy this, in the current section we will discuss how to modify the conditions on $C_f(0)$ and $C_f''(1)$ used in TAT so that they take into account all subnetworks. This will be done using a natural generalization of the maximal slope function from DKS. We will then address the computational challenges that result from doing this.

To begin, we will replace the condition on $C_f(0)$ (used in TAT for Leaky ReLUs) by an analogous condition on the maximal c value function of $f$, which is defined by taking the maximum of $C_g(0)$ over all subnetworks $g$ of $f$, viewed as a function of the negative slope parameter $\alpha$ (which determines $\mathcal{C}$ in LReLU networks via equation 9, and thus each $C_g$).

We will similarly replace the condition on $C_f''(1)$ (used in TAT for smooth activations) by an analogous condition on the maximal curvature function of $f$, which is defined by taking the maximum of $C_g''(1)$ over all subnetworks $g$ of $f$, where each $C_g''(1)$ is determined by $\mathcal{C}''(1)$. That each $C_g''(1)$ is a well-defined function of $\mathcal{C}''(1)$ follows from the fact that C maps always map the value 1 to 1, the aforementioned relationship between $C_g'(1)$ and $\mathcal{C}'(1)$, and the fact that we have $\mathcal{C}'(1) = 1$ under TAT (so that $C_g'(1) = 1$ for all subnetworks $g$). These facts allow us to write each $C_g''(1)$ as a constant multiple of $\mathcal{C}''(1)$, using the linearity of 2nd derivatives and the 2nd-order chain rule (which is given by $(h \circ g)'' = (h'' \circ g)\,(g')^2 + (h' \circ g)\,g''$).

B.2 Computing the maximal c value and curvature functions in general

Given these new conditions for TAT, it remains to compute their left hand sides so that we may ultimately solve for the required quantities ( or ). In Section 2.3 we discussed how a (sub)network ’s C map can be computed in terms of the local C map by a series of composition and non-negative weighted sum operations. We can define a generalized version of this construction which replaces with an arbitrary non-decreasing function , so that . A recipe for computing is given in Appendix B.4.

Given , we define the subnetwork maximizing function by

With this definition, it is not hard to see that if , , and , then (where the dependence on is implicit through the dependence of on ), , and . Thus, it suffices to derive a scheme for computing (and inverting) for general networks and non-decreasing functions .

Naively, computing could involve a very large maximization and be quite computationally expensive. But analogously to the maximal slope function computation described in Martens et al. (2021), the computation of can be simplified substantially, so that we rarely have to maximize over more than a few possible subnetworks. In particular, since is a non-decreasing function of for all (which follows from the fact that is non-decreasing), and , it thus follows that for all . This means that for the purposes of the maximization, we can ignore any subnetwork in $f$ which composes with another subnetwork (not necessarily in $f$) to form a strictly larger subnetwork isomorphic to one in $f$. This will typically be the vast majority of them. Note that this does not therefore imply that the maximum is achieved by $f$ itself, since not all subnetworks compose in this way. For example, a sufficiently deep residual branch of a residual block in a rescaled ResNet won't compose with any subnetwork to form a larger one.

B.3 Solving for the negative slope and curvature parameters

Having shown how to efficiently compute the subnetwork maximizing function, and thus both the maximal c value and maximal curvature functions, it remains to show how we can invert them to find solutions for $\alpha$ and $\mathcal{C}''(1)$ (respectively). Fortunately, this turns out to be easy, as both functions are strictly monotonic in their arguments, provided that $f$ contains at least one nonlinear layer. Thus, we may apply a simple 1-dimensional root-finding approach, such as binary search.

To see that is a strictly decreasing function of (or in other words, a strictly increasing function of ), we observe that it is a maximum over terms of the form , which are all either strictly decreasing non-negative functions of , or are identically zero. These properties of follow from the fact that it involves only applications of , along with compositions and non-negative weighted averages, and that is a strictly decreasing function of for all (in Leaky ReLU networks). A similar argument can be used to show that is a strictly increasing function of (and is in fact equal to a non-negative multiple of ).

B.4 Recipe for computing the generalized global map

As defined, is computed from by taking the computational graph for and replacing the local C map with wherever the former appears. So in particular, one can obtain a computational graph for from ’s computational graph by recursively applying the following rules:

  1. Composition of two subnetworks and maps to .

  2. Affine layers map to the identity function.

  3. Nonlinear layers map to .

  4. Normalized sums with weights over the outputs of subnetworks , map to the function

    where are the respective inputs to the ’s.

  5. ’s input layer maps to .

In the special case of computing , one gets the following simplified list of rules:

  1. Composition of two subnetworks and maps to

  2. Affine layers map to .

  3. Nonlinear layers map to .

  4. Normalized sums with weights over the outputs of subnetworks , map to the function

  5. ’s input layer maps to .

Note that this second procedure will always produce a non-negative multiple of , provided that contains at least one nonlinear layer.

B.5 Rescaled ResNet example

In this subsection we will demonstrate how to apply the above rules to compute the maximal curvature function for a rescaled ResNet with shortcut weight $w$ and residual branch $\mathcal{R}$ (as defined in equation 6). We note that this computation also handles the case of a vanilla network, by simply taking $w = 0$.

First, we observe that all subnetworks in compose to form larger ones in , except for itself, and for the residual branches of its residual blocks. We thus have that .

Because each residual branch has a simple feedforward structure with three nonlinear layers, it follows that . And because each shortcut branch has no nonlinear layers, it follows that . Applying the rule for weighted averages to the output of each block we thus have that . Given a network with nonlinear layers, we have blocks, and since the blocks compose in a feedforward manner it thus follows that . We therefore conclude that .

The rescaled ResNets used in our experiments have a slightly more complex structure (based on the ResNet-50 and ResNet-101 architectures), with a nonlinear layer appearing after the sequence of residual blocks, and with four of their blocks being "transition blocks", whose shortcut branches contain a nonlinear layer. In these networks, the total number of residual blocks is given by . Following a similar argument to the one above we have that