
Learning Deep Neural Networks by Iterative Linearisation

The excellent real-world performance of deep neural networks has received increasing attention. Despite the capacity to overfit significantly, such large models work better than smaller ones. This phenomenon is often referred to as the scaling law by practitioners. It is of fundamental interest to study why the scaling law exists and how it avoids/controls overfitting. One approach has been looking at infinite width limits of neural networks (e.g., Neural Tangent Kernels, Gaussian Processes); however, in practice, these do not fully explain finite networks as their infinite counterparts do not learn features. Furthermore, the empirical kernel for finite networks (i.e., the inner product of feature vectors) changes significantly during training, in contrast to infinite width networks. In this work we derive an iterative linearised training method. We justify iterative linearisation as an interpolation between finite analogs of the infinite width regime, which do not learn features, and standard gradient descent training, which does. We show some preliminary results where iterative linearised training works well, noting in particular how much feature learning is required to achieve comparable performance. We also provide novel insights into the training behaviour of neural networks.





1 Introduction

Deep neural networks perform well on a wide variety of tasks despite their overparameterisation and capacity to memorise random labels [Zhang et al., 2017], often with improved generalisation behaviour as the number of parameters increases [Nakkiran et al., 2020]. This goes contrary to classical beliefs around learning theory and overfitting, meaning there is likely some implicit regularisation inducing an inductive bias which encourages the networks to converge to well-generalising solutions. One approach to investigate this has been to examine infinite width limits of neural networks using the Neural Tangent Kernel (NTK) [Jacot et al., 2018, Lee et al., 2019]. Interestingly, these often do worse than standard neural networks, though with extra tricks they can perform equivalently well or better in certain scenarios [Lee et al., 2020].

Similarly, despite their use for analysis due to having closed-form expressions for many terms, they don't predict finite network behaviour very closely in many regards. For example, because they do not learn features, they cannot be used for transfer learning, and the empirical NTK (outer product of Jacobians) changes significantly throughout training, whereas NTK theory states that in the infinite limit it is constant. This raises important questions about in what ways they differ, most of which can be summarised as asking how feature learning affects the networks that can be learnt.

To work towards answering this question, we look at an interpolation between standard training and a finite analog of infinite-width training where we fix the empirical NTK by performing weight space linearisation. One interpretation of this is that we are varying the amount of feature learning allowed. We find that essentially any amount of feature learning is enough to eventually converge to a similarly performing network, assuming learning rates are small enough.

1.1 Related Work

Li et al. [2019] create an enhanced NTK for CIFAR10 with significantly better empirical performance than the standard one; however, it still performs worse than the best neural networks. Yang and Hu [2021] use a different limit to allow feature learning. However, neither of these gives much insight into why the standard parameterisation doesn't work well.

Lee et al. [2020] run an empirical study comparing finite and infinite networks under many scenarios and Fort et al. [2020] look at how far SGD training is from fixed-NTK training, and at what point they tend to converge. Lewkowycz et al. [2020] investigate at what points in training the kernel regime applies as a good model of finite network behaviour. Both find better agreement later in training.

Chizat et al. [2019] consider a different way to make finite networks closer to their infinite width analogs by scaling in a particular way, finding that as they get closer to their infinite width analogs, they perform less well empirically.

2 Problem Formulation

Consider a neural network $f(x; \theta)$ parameterised by weights $\theta$ and a mean squared error loss function $\mathcal{L} = \frac{1}{2}\|f(X;\theta) - Y\|^2$,¹ where we minimise for data $X$ and labels $Y$. We can write the change in the function over time under gradient flow with learning rate $\eta$ as:

$$\frac{\partial f(X;\theta_t)}{\partial t} = -\eta\, \hat\Theta_t(X,X)\, \nabla_{f(X;\theta_t)} \mathcal{L} \tag{1}$$

¹ We use MSE for simplicity and compatibility with NTK results here. While this is needed for some NTK results, it does not affect the algorithms we propose, where any differentiable loss function can be used; see Appendix A.

It has been shown [Jacot et al., 2018, Lee et al., 2019, Arora et al., 2019] that in the infinite width limit the empirical neural tangent kernel, $\hat\Theta_t = \nabla_\theta f(X;\theta_t)\, \nabla_\theta f(X;\theta_t)^\top$, converges to a deterministic NTK, $\Theta$. This is a matrix dependent only on architecture and data and does not change during training. From this perspective, training the infinite width model under gradient flow (or gradient descent with a small step size) is equivalent to training the weight-space linearisation of the neural network [Lee et al., 2019]. This raises a number of interesting questions about why training the linearisation doesn't work well in finite networks and what is different in them. This is likely due to a lack of enough random features, whereas running gradient descent on the full network allows features to be learnt, reducing the reliance on having enough initial random features.
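To make the empirical kernel concrete, the sketch below computes $\hat\Theta = \nabla_\theta f\, \nabla_\theta f^\top$ for a tiny one-hidden-layer network. The model, sizes, and helper names are illustrative assumptions, not taken from the paper:

```python
import numpy as np

# Minimal sketch: the empirical NTK is the Gram matrix of per-example
# parameter gradients. Toy scalar-output model f(x) = v . tanh(W x).
rng = np.random.default_rng(1)
d, h, n = 4, 32, 8
X = rng.normal(size=(n, d))
W = rng.normal(size=(h, d)) / np.sqrt(d)   # input weights
v = rng.normal(size=h) / np.sqrt(h)        # output weights

def jacobian(W, v, X):
    # rows are grad_theta f(x_i; theta)
    a = np.tanh(X @ W.T)                                  # (n, h)
    JW = ((1 - a ** 2) * v)[:, :, None] * X[:, None, :]   # df/dW, (n, h, d)
    return np.concatenate([JW.reshape(len(X), -1), a], axis=1)

J = jacobian(W, v, X)
ntk = J @ J.T                              # empirical NTK, Theta_hat(X, X)

# As a Gram matrix, Theta_hat is symmetric and positive semi-definite.
print("symmetric:", np.allclose(ntk, ntk.T))
print("PSD:", np.all(np.linalg.eigvalsh(ntk) > -1e-8))
```

For a finite network this matrix changes as $\theta$ moves; in the infinite width limit it stays fixed at its value at initialisation.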

3 Iterative Linearisation

NTK theory says that if the width is large enough then training the weight-space linearisation is equivalent to training the full network [Lee et al., 2019]. However, in practice, training the fully linearised network performs very poorly for practically sized networks [Lee et al., 2020]. In this section we propose iterative linearisation in order to interpolate between training of the standard network and the linearised network.

Consider standard (full batch) gradient descent on a neural network:

$$\theta_{t+1} = \theta_t + \eta\, \nabla_\theta f(X;\theta_t)^\top \left(Y - f(X;\theta_t)\right) \tag{2}$$

Here we can think of this as two separate variables we update each step, the weights $\theta_t$ and the features $\phi_t = \nabla_\theta f(X;\theta_t)$. However there is no requirement that we always update both, giving rise to the following generalised algorithm:

$$\theta_{t+1} = \theta_t + \eta\, \phi_s^\top \left(Y - f^{\mathrm{lin}_s}(X;\theta_t)\right) \tag{3}$$
$$\phi_s = \nabla_\theta f(X;\theta_s) \tag{4}$$

where $s = K\lfloor t/K \rfloor$. In addition, we write the linearised version of a neural network using its first order Taylor expansion at the weights $\theta_s$ as

$$f^{\mathrm{lin}_s}(x;\theta) = f(x;\theta_s) + \nabla_\theta f(x;\theta_s)\,(\theta - \theta_s) \tag{5}$$

Using this framework, when $K = 1$ this is simply gradient descent and when $K = \infty$ it is fully linearised training. Other values of $K$ interpolate between these two extremes. See Algorithm 1 for more details. Note that we can also generalise this to not be periodic in terms of when we update $\phi$, so we call this version fixed period iterative linearisation.

Input: learning rate $\eta$, update period $K$, pre-initialised parameters $\theta_1$
$\phi \leftarrow \nabla_\theta f(X;\theta_1)$
for t = 1..epochs do
       $\theta_{t+1} \leftarrow \theta_t + \eta\, \phi^\top \left(Y - f^{\mathrm{lin}}(X;\theta_t)\right)$
       if t mod K = 0 then
              $\phi \leftarrow \nabla_\theta f(X;\theta_{t+1})$   (update the features)
       end if
end for
Algorithm 1 Iterative Linearisation (fixed period)
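A minimal numpy sketch of fixed-period iterative linearisation with MSE loss follows. The toy one-hidden-layer model, the sizes, and the helper names (`forward`, `jacobian`, `iterative_linearisation`) are illustrative assumptions, not the paper's code:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, h = 20, 3, 16
X = rng.normal(size=(n, d))
Y = np.sin(X @ rng.normal(size=d))               # arbitrary smooth targets

def unpack(theta):
    return theta[:h * d].reshape(h, d), theta[h * d:]

def forward(theta, X):
    W, v = unpack(theta)
    return np.tanh(X @ W.T) @ v                  # f(X; theta), shape (n,)

def jacobian(theta, X):
    # rows are grad_theta f(x_i; theta) -- the 'features' phi
    W, v = unpack(theta)
    a = np.tanh(X @ W.T)
    JW = ((1 - a ** 2) * v)[:, :, None] * X[:, None, :]
    return np.concatenate([JW.reshape(len(X), -1), a], axis=1)

def iterative_linearisation(theta0, eta, K, steps):
    theta = theta0.copy()
    for t in range(steps):
        if t % K == 0:                           # feature update (Eq. 4)
            theta_s = theta.copy()
            phi = jacobian(theta_s, X)
            f_s = forward(theta_s, X)
        f_lin = f_s + phi @ (theta - theta_s)    # Taylor expansion (Eq. 5)
        theta = theta + eta * phi.T @ (Y - f_lin) / n   # GD step (Eq. 3)
    return theta

theta0 = rng.normal(size=h * d + h) * 0.5
for K in (1, 10, 100):                           # K=1 recovers standard GD
    theta = iterative_linearisation(theta0, eta=0.1, K=K, steps=500)
    print(f"K={K:3d}  final MSE={np.mean((forward(theta, X) - Y) ** 2):.4f}")
```

With $K = 1$ the linearisation point is refreshed every step, so the update reduces exactly to gradient descent on the full network.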

3.1 Interpreting Iterative Linearisation

It is insightful to look a bit more closely at what is happening in Algorithm 1. In standard training of a linear model (e.g. $f(x) = \phi(x)^\top\theta$), the 'features' $\phi(x)$ are fixed and learning decides how to use those features. This is what is happening in Equation (3): with $K = \infty$, we never learn any new features beyond the random ones we got through the initialisation of the network. Interestingly, we can say the same for infinite width networks. The idea that infinite width networks don't learn features is not new (see Yang and Hu [2021] for work trying to avoid this pitfall), but the finite analogy gives us a new perspective on what is happening in Algorithm 1. We therefore call Equation (4) the feature learning step, noting that no feature learning can happen through Equation (3) alone. From this interpretation, the Jacobian $\phi_s = \nabla_\theta f(X;\theta_s)$ gives the features we are using at time $t$, and $K$ tells us how frequently to update the features, putting a limit on the frequency of feature learning updates.
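The $K = \infty$ case can be made concrete: with the features never updated, Equation (3) is just gradient descent on a linear least squares problem in the random features, and it converges to the same solution as a direct solver. A small sketch under assumed toy dimensions (the feature matrix here stands in for the Jacobian at initialisation):

```python
import numpy as np

# With K = infinity, iterative linearisation is linear regression on the
# fixed features phi_0 = grad_theta f(X; theta_0). We work in
# delta = theta - theta_0, starting from delta = 0.
rng = np.random.default_rng(2)
n, p = 12, 40
phi0 = rng.normal(size=(n, p))     # fixed 'features' (Jacobian at init)
f0 = rng.normal(size=n)            # network outputs at initialisation
Y = rng.normal(size=n)

theta = np.zeros(p)
for _ in range(5000):              # Eq. 3 with s = 0 forever
    resid = Y - (f0 + phi0 @ theta)
    theta += 0.01 * phi0.T @ resid / n

# Closed-form minimum-norm least squares solution for comparison
delta_star, *_ = np.linalg.lstsq(phi0, Y - f0, rcond=None)
print("matches lstsq:", np.allclose(theta, delta_star, atol=1e-3))
```

In the overparameterised regime ($p > n$), gradient descent from zero stays in the row space of the features and so converges to the minimum-norm interpolant, matching `lstsq`.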

4 Results

Figure 1: Iterative linearisation results on MNIST and CIFAR10 with learning rates of 1e-4 and 1e-5.

Figure 2: Iterative linearisation on CIFAR10 with large $K$.

To examine the effect of increasing $K$, or equivalently reducing the feature learning frequency, we run a number of experiments on MNIST and CIFAR10 with a slightly larger variant of LeNet [LeCun et al., 1998] with 50 channels in each convolutional layer and a softmax output. The purpose of these experiments is not to achieve amazing performance on the datasets (it only gets to around 60% for CIFAR10) but to examine how training changes as the learning rate and update frequency change. For these experiments we note that, unlike general NTK theory, no step of our derivation relies on the use of MSE, so we instead use cross-entropy loss as is standard for image classification. We additionally include the softmax in the loss function so that it is not linearised by the algorithm. This means both that we continue to have an output for which taking the cross-entropy loss is meaningful, and that we avoid numerical issues caused by the linearisation happening in the middle of the logsumexp trick.

Figure 1 shows the results of these experiments. For MNIST, all values of $K$ follow almost exactly the same learning trajectory, converging very quickly and not saturating the amount learnt from each set of features before updating. For large $K$, it can be seen that the training levels off each time before $\phi$ is updated. This shows that we can get close to 100% accuracy on MNIST by updating the features from their initialisation only twice (all but the largest $K$ are within a 0.4% range at the end of training). From this we can conclude that the initialisation of a neural network with this architecture creates features that are not too far from what is needed to solve MNIST. Compare this with CIFAR10 (learning rate 1e-5), where it still takes only a few feature vector updates to reach the performance of standard training, though this is a much lower accuracy. It is still unclear from these experiments how the small number of updates necessary interacts with different architectures on the same dataset.

Comparing CIFAR10 with learning rates of 1e-4 and 1e-5 also gives interesting conclusions. For 1e-4, training diverges as soon as $K$ goes above 5, and is unstable even for $K = 5$, whereas for 1e-5 training stays stable for every $K$ we tried (see Figure 2). Note that this is not simply because it's moving less far, as this is a 10x reduction in the step size but a 4000x increase in $K$. We include SGD training for comparison, noting that plotting per-step rather than per-epoch gives similar performance to GD training, and so we should expect most runs to reach ~60% accuracy if run for twice as long.

5 Conclusion

This paper has proposed iterative linearisation, a new training algorithm that interpolates between gradient descent on the standard and the linearised neural network, as a finite analog of the infinite width versus finite network distinction. We show that, at least in the case of a LeNet-like architecture with small learning rates, any amount of feature learning is enough to converge to a similarly performing model. This provides an important step towards understanding feature learning and the distinction between how infinite and finite width networks learn. Better understanding how networks change with large numbers of parameters has important connections to empirical phenomena such as deep double descent [Nakkiran et al., 2020].

5.1 Future Work

It is important to do more rigorous empirical investigations to confirm these results, in particular to scale up to larger models in order to disentangle the impact of iterative linearisation training from the fact that this architecture will never do particularly well on CIFAR10. This is also important to better understand under what architectures/learning rates/frequencies iterative linearisation training is stable.

Another direction is to better understand the types of solutions that iterative linearisation finds for various values of $K$. This will shed light on how the inductive bias is changing, in particular understanding whether all finite $K$ find similar solutions and the infinite width limit is a step change, similar to the test performance in the experiments here, or whether there is a gradual change towards the solutions which don't learn features.

Finally, we only consider fixed period iterative linearisation here, where we update the feature vector at regular intervals. However, Fort et al. [2020] showed that the empirical NTK changes faster earlier in training, so it would make sense for $K$ to be more adaptive if this were to be used directly for training.


  • S. Arora, S. S. Du, W. Hu, Z. Li, R. Salakhutdinov, and R. Wang (2019) On exact computation with an infinitely wide neural net. In Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada, H. M. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché-Buc, E. B. Fox, and R. Garnett (Eds.), pp. 8139–8148. External Links: Link Cited by: §2.
  • L. Chizat, E. Oyallon, and F. R. Bach (2019) On lazy training in differentiable programming. In Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada, H. M. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché-Buc, E. B. Fox, and R. Garnett (Eds.), pp. 2933–2943. External Links: Link Cited by: §1.1.
  • S. Fort, G. K. Dziugaite, M. Paul, S. Kharaghani, D. M. Roy, and S. Ganguli (2020) Deep learning versus kernel learning: an empirical study of loss landscape geometry and the time evolution of the neural tangent kernel. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, H. Larochelle, M. Ranzato, R. Hadsell, M. Balcan, and H. Lin (Eds.), External Links: Link Cited by: §1.1, §5.1.
  • A. Jacot, C. Hongler, and F. Gabriel (2018) Neural tangent kernel: convergence and generalization in neural networks. In Advances in Neural Information Processing Systems 31: Annual Conference on Neural Information Processing Systems 2018, NeurIPS 2018, December 3-8, 2018, Montréal, Canada, S. Bengio, H. M. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett (Eds.), pp. 8580–8589. External Links: Link Cited by: §1, §2.
  • Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner (1998) Gradient-based learning applied to document recognition. Proc. IEEE 86 (11), pp. 2278–2324. External Links: Link, Document Cited by: §4.
  • J. Lee, S. S. Schoenholz, J. Pennington, B. Adlam, L. Xiao, R. Novak, and J. Sohl-Dickstein (2020) Finite versus infinite neural networks: an empirical study. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, H. Larochelle, M. Ranzato, R. Hadsell, M. Balcan, and H. Lin (Eds.), External Links: Link Cited by: §1.1, §1, §3.
  • J. Lee, L. Xiao, S. S. Schoenholz, Y. Bahri, R. Novak, J. Sohl-Dickstein, and J. Pennington (2019) Wide neural networks of any depth evolve as linear models under gradient descent. In Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada, H. M. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché-Buc, E. B. Fox, and R. Garnett (Eds.), pp. 8570–8581. External Links: Link Cited by: §1, §2, §3.
  • A. Lewkowycz, Y. Bahri, E. Dyer, J. Sohl-Dickstein, and G. Gur-Ari (2020) The large learning rate phase of deep learning: the catapult mechanism. CoRR abs/2003.02218. External Links: Link, 2003.02218 Cited by: §1.1.
  • Z. Li, R. Wang, D. Yu, S. S. Du, W. Hu, R. Salakhutdinov, and S. Arora (2019) Enhanced convolutional neural tangent kernels. CoRR abs/1911.00809. External Links: Link, 1911.00809 Cited by: §1.1.
  • P. Nakkiran, G. Kaplun, Y. Bansal, T. Yang, B. Barak, and I. Sutskever (2020) Deep double descent: where bigger models and more data hurt. In 8th International Conference on Learning Representations, ICLR 2020, Addis Ababa, Ethiopia, April 26-30, 2020, External Links: Link Cited by: §1, §5.
  • G. Yang and E. J. Hu (2021) Tensor programs IV: feature learning in infinite-width neural networks. In Proceedings of the 38th International Conference on Machine Learning, ICML 2021, 18-24 July 2021, Virtual Event, M. Meila and T. Zhang (Eds.), Proceedings of Machine Learning Research, Vol. 139, pp. 11727–11737. External Links: Link Cited by: §1.1, §3.1.
  • C. Zhang, S. Bengio, M. Hardt, B. Recht, and O. Vinyals (2017) Understanding deep learning requires rethinking generalization. In 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24-26, 2017, Conference Track Proceedings, External Links: Link Cited by: §1.

Appendix A Iterative linearisation with a general loss function

In Section 3 we show how to get to iterative linearisation from standard gradient descent under mean squared error loss. The use of mean squared error is more instructive due to its similarities with NTK results, however it is not strictly necessary. For completeness we include here the same idea but for a general loss function $\mathcal{L}$.

Standard gradient descent on a function $f$ parameterised by $\theta$, with step size $\eta$ and data $X$, can be written as

$$\theta_{t+1} = \theta_t - \eta\, \nabla_\theta \mathcal{L}(f(X;\theta_t), Y)$$

We can apply the chain rule, resulting in

$$\theta_{t+1} = \theta_t - \eta\, \nabla_\theta f(X;\theta_t)^\top\, \mathcal{L}'(f(X;\theta_t), Y)$$

where $\mathcal{L}'$ is the derivative of $\mathcal{L}$ with respect to the function outputs (in the case of mean squared error, this is the residual: $f(X;\theta_t) - Y$). Now again using $\phi_t = \nabla_\theta f(X;\theta_t)$, we can write this as

$$\theta_{t+1} = \theta_t - \eta\, \phi_t^\top\, \mathcal{L}'(f(X;\theta_t), Y)$$

With a similar argument to Section 3, we note that we don't need to update the features every step, resulting in the following formulation:

$$\theta_{t+1} = \theta_t - \eta\, \phi_s^\top\, \mathcal{L}'(f^{\mathrm{lin}_s}(X;\theta_t), Y), \qquad \phi_s = \nabla_\theta f(X;\theta_s), \qquad s = K\lfloor t/K \rfloor$$

This now lets us use softmax followed by cross-entropy in the loss while maintaining the same interpretation, as we do for the MNIST and CIFAR10 results.
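For softmax followed by cross-entropy, the derivative of the loss with respect to the logits has the well-known closed form $\mathcal{L}'(f, Y) = \mathrm{softmax}(f) - Y$. The sketch below plugs this into the generalised update on a fixed set of assumed toy features (standing in for $\phi_s$), keeping the softmax inside the loss so it is never linearised:

```python
import numpy as np

rng = np.random.default_rng(3)
n, p, c = 30, 10, 3
phi = rng.normal(size=(n, p))              # features grad_theta f at theta_s
labels = rng.integers(0, c, size=n)
Y = np.eye(c)[labels]                      # one-hot targets

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)   # numerically stable
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(theta):
    probs = softmax(phi @ theta)           # softmax applied to the logits
    return -np.mean(np.sum(Y * np.log(probs + 1e-12), axis=1))

theta = np.zeros((p, c))
losses = [cross_entropy(theta)]
for _ in range(200):
    Lprime = softmax(phi @ theta) - Y      # L'(f, Y) for softmax + CE
    theta -= 0.1 * phi.T @ Lprime / n      # theta_{t+1} = theta_t - eta phi^T L'
    losses.append(cross_entropy(theta))

print("loss decreased:", losses[-1] < losses[0])
```

Only the network outputs (the logits) are linearised; the softmax and the log in the loss remain exact, which is what keeps the cross-entropy meaningful and sidesteps the logsumexp numerical issues mentioned in Section 4.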