Deep neural networks perform well on a wide variety of tasks despite their overparameterisation and their capacity to memorise random labels [Zhang et al., 2017], and their generalisation often improves as the number of parameters increases [Nakkiran et al., 2020]. This runs contrary to classical learning theory and its account of overfitting, suggesting that some implicit regularisation induces an inductive bias which encourages the networks to converge to well-generalising solutions. One approach to investigating this has been to examine infinite width limits of neural networks using the Neural Tangent Kernel (NTK) [Jacot et al., 2018, Lee et al., 2019]. Interestingly, these often perform worse than standard neural networks, though with extra tricks they can perform equally well or better in certain scenarios [Lee et al., 2020]. Similarly, despite their usefulness for analysis thanks to closed-form expressions for many quantities, they do not closely predict finite network behaviour in several regards. For example, because they do not learn features, they cannot be used for transfer learning, and the empirical NTK (the outer product of the network's Jacobians) changes significantly throughout training, whereas NTK theory states that in the infinite width limit it is constant. This raises important questions about the ways in which they differ, most of which can be summarised as how feature learning affects the networks that can be learnt.
To work towards answering this question, we study an interpolation between standard training and a finite analogue of infinite width training in which we fix the empirical NTK by linearising the network in weight space. One interpretation is that we are varying the amount of feature learning allowed. We find that essentially any amount of feature learning is enough to eventually converge to a similarly performing network, provided the learning rate is small enough.
1.1 Related Work
Li et al. [2019] construct an enhanced NTK for CIFAR10 with significantly better empirical performance than the standard one; however, it still performs worse than the best neural networks. Yang and Hu [2021] use a different limit which allows feature learning, but neither work gives much insight into why the standard parameterisation does not perform well.
Lee et al. [2020] run an empirical study comparing finite and infinite networks under many scenarios. Fort et al. [2020] look at how far SGD training is from fixed-NTK training, and at what point the two tend to converge, while Lewkowycz et al. [2020] investigate at which points in training the kernel regime is a good model of finite network behaviour. Both find better agreement later in training.
Chizat et al. [2019] consider a different way to bring finite networks closer to their infinite width analogues by scaling the model in a particular way, finding that as the networks approach their infinite width analogues, they perform worse empirically.
2 Problem Formulation
Consider a neural network $f(x;\theta)$ parameterised by weights $\theta$, and a mean squared error loss function $L(\theta) = \frac{1}{2}\lVert f(X;\theta) - Y\rVert^2$, which we minimise for data $X$ and labels $Y$. (We use MSE for simplicity and compatibility with NTK results here. While this is needed for some NTK results, it does not affect the algorithms we propose, where any differentiable loss function can be used; see Appendix A.) We can write the change in the function over time under gradient flow with learning rate $\eta$ as:

$$\frac{\mathrm{d}f(X;\theta_t)}{\mathrm{d}t} = -\eta\,\nabla_\theta f(X;\theta_t)\,\nabla_\theta f(X;\theta_t)^\top\,\big(f(X;\theta_t) - Y\big) \tag{1}$$
It has been shown [Jacot et al., 2018, Lee et al., 2019, Arora et al., 2019] that in the infinite width limit the empirical neural tangent kernel, $\hat\Theta_t = \nabla_\theta f(X;\theta_t)\,\nabla_\theta f(X;\theta_t)^\top$, converges to a deterministic NTK, $\Theta$. This matrix depends only on the architecture and data and does not change during training. From this perspective, training the infinite width model under gradient flow (or gradient descent with a small step size) is equivalent to training the weight-space linearisation of the neural network [Lee et al., 2019]. This raises a number of interesting questions about why this does not work well in finite networks and what differs in them. A likely explanation is the lack of sufficiently many random features, whereas running gradient descent on the full network allows features to be learnt, reducing the reliance on having enough initial random features.
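As a concrete illustration, the empirical NTK can be computed directly as the Gram matrix of network Jacobians. The sketch below is purely illustrative (a tiny two-layer tanh network with a finite-difference Jacobian, not the networks used in the experiments): it builds $\hat\Theta_t$ as the outer product of Jacobians for a handful of inputs.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_params(d_in, width):
    """Random initialisation of a two-layer tanh network with scalar output."""
    return [rng.standard_normal((d_in, width)) / np.sqrt(d_in),
            np.zeros(width),
            rng.standard_normal((width, 1)) / np.sqrt(width),
            np.zeros(1)]

def forward(params, X):
    W1, b1, W2, b2 = params
    return (np.tanh(X @ W1 + b1) @ W2 + b2).ravel()

def flatten(params):
    return np.concatenate([p.ravel() for p in params])

def unflatten(vec, like):
    out, i = [], 0
    for p in like:
        out.append(vec[i:i + p.size].reshape(p.shape))
        i += p.size
    return out

def network_jacobian(params, X, eps=1e-5):
    """Finite-difference Jacobian df/dtheta, shape (n_points, n_params)."""
    theta = flatten(params)
    cols = []
    for j in range(theta.size):
        e = np.zeros_like(theta)
        e[j] = eps
        f_plus = forward(unflatten(theta + e, params), X)
        f_minus = forward(unflatten(theta - e, params), X)
        cols.append((f_plus - f_minus) / (2 * eps))
    return np.stack(cols, axis=1)

X = rng.standard_normal((5, 3))
params = init_params(3, 16)
J = network_jacobian(params, X)
ntk = J @ J.T  # empirical NTK: outer product of Jacobians, shape (5, 5)
# By construction the kernel is symmetric positive semi-definite.
```

In the infinite width limit this matrix would stay fixed throughout training; for a finite network it drifts as $\theta$ moves, which is exactly the feature learning at issue here.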
3 Iterative Linearisation
NTK theory says that if the width is large enough then training the weight-space linearisation is equivalent to training the full network [Lee et al., 2019]. However, in practice training the fully linearised network performs very poorly for practically sized networks [Lee et al., 2020]. In this section we propose iterative linearisation in order to interpolate between training the standard network and training the linearised network.
Consider standard (full batch) gradient descent on a neural network:

$$\theta_{t+1} = \theta_t - \eta\,\nabla_\theta f(X;\theta_t)^\top\,\big(f(X;\theta_t) - Y\big) \tag{2}$$

Here we can think of this as two separate variables that we update each step: the weights $\theta_t$ and the features $\phi_t = \nabla_\theta f(X;\theta_t)$. However, there is no requirement that we always update both, giving rise to the following generalised algorithm:

$$\theta_{t+1} = \theta_t - \eta\,\phi_s^\top\,\big(f^{\mathrm{lin}}_{\theta_s}(X;\theta_t) - Y\big) \tag{3}$$

$$\phi_s = \nabla_\theta f(X;\theta_s) \tag{4}$$

where $s = K\lfloor t/K\rfloor$ is the most recent feature update step and $K \in \mathbb{N}$. In addition, we write the linearised version of a neural network using its first order Taylor expansion at the weights $\theta_s$ as

$$f^{\mathrm{lin}}_{\theta_s}(x;\theta) = f(x;\theta_s) + \nabla_\theta f(x;\theta_s)\,(\theta - \theta_s)$$
Using this framework, when $K = 1$ this is simply gradient descent, and when $K = \infty$ it is fully linearised training. Other values of $K$ interpolate between these two extremes. See Algorithm 1 for more details. Note that we can also generalise this so that $\phi$ need not be updated periodically; we therefore call the variant above fixed period iterative linearisation.
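A minimal sketch of fixed period iterative linearisation under MSE, assuming a tiny one-hidden-layer tanh network on toy regression data (all names, sizes, and hyperparameters here are illustrative, not the experimental setup): the features $\phi$ are recomputed every $K$ steps, and between refreshes gradient descent runs on the linearised model.

```python
import numpy as np

rng = np.random.default_rng(1)

d, h, n = 3, 32, 20                      # input dim, hidden width, data points
X = rng.standard_normal((n, d))
Y = np.sin(X[:, 0])                      # toy regression targets

def unpack(theta):
    return theta[:h * d].reshape(h, d), theta[h * d:]

def f(theta, X):
    """One-hidden-layer tanh network with scalar output."""
    W, v = unpack(theta)
    return np.tanh(X @ W.T) @ v

def jac(theta, X):
    """Analytic Jacobian of f w.r.t. theta, shape (n, h*d + h)."""
    W, v = unpack(theta)
    A = np.tanh(X @ W.T)                 # (n, h) hidden activations
    D = (1 - A ** 2) * v                 # (n, h): v_i * tanh'(w_i . x)
    JW = (D[:, :, None] * X[:, None, :]).reshape(len(X), -1)
    return np.concatenate([JW, A], axis=1)

def iterative_linearisation(theta0, K, steps=2000, lr=1e-3):
    """Sketch of Algorithm 1 (MSE): refresh the features phi every K steps."""
    theta = theta0.copy()
    for t in range(steps):
        if t % K == 0:                   # Equation (4): feature update
            theta_ref = theta.copy()
            phi = jac(theta_ref, X)
            f_ref = f(theta_ref, X)
        f_lin = f_ref + phi @ (theta - theta_ref)   # linearised network
        theta = theta - lr * phi.T @ (f_lin - Y)    # Equation (3): weight update
    return theta

theta0 = np.concatenate([rng.standard_normal(h * d) / np.sqrt(d),
                         rng.standard_normal(h) / np.sqrt(h)])
for K in (1, 50, 10 ** 9):               # K=1: plain GD; huge K: fully linearised
    theta = iterative_linearisation(theta0, K)
    print(K, np.mean((f(theta, X) - Y) ** 2))
```

With $K = 1$ the loop reduces exactly to gradient descent, since the linearised model is re-expanded at the current weights on every step.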
3.1 Interpreting Iterative Linearisation
It is insightful to look more closely at what is happening in Algorithm 1. In standard training of a linear model (e.g. $f(x;\theta) = \theta^\top \phi(x)$ for a fixed feature map $\phi$), the 'features' are fixed and learning decides how to use those features. This is what happens in Equation (3): with $K = \infty$, we never learn any features beyond the random ones given by the initialisation of the network. Interestingly, we can say the same of infinite width networks; the idea that infinite width networks do not learn features is not new (see Yang and Hu [2021] for work trying to avoid this pitfall), but the finite analogue gives us a new perspective on what is happening in Algorithm 1. We therefore call Equation (4) feature learning, noting that no feature learning can happen in Equation (3). Under this interpretation, the Jacobian $\phi_s$ comprises the features we are using at time $t$, and $K$ tells us how frequently to update them, putting a limit on the frequency of feature learning updates.
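To make the 'no feature learning' reading concrete: with $K = \infty$, Equation (3) is gradient descent on a linear model whose feature matrix is the Jacobian at initialisation, so under MSE with a small learning rate it converges to the minimum-norm parameter change fitting the residual. A toy sketch, using a random matrix as a purely illustrative stand-in for the Jacobian:

```python
import numpy as np

rng = np.random.default_rng(2)

n, p = 10, 40                       # data points, parameters (overparameterised)
phi0 = rng.standard_normal((n, p))  # stand-in for the Jacobian at initialisation
f0 = rng.standard_normal(n)         # network outputs at initialisation
Y = rng.standard_normal(n)          # targets

# Gradient descent on the linearised model, started from the initial weights,
# only moves theta within the row space of phi0, so it converges to the
# minimum-norm parameter change that fits the residual Y - f0:
delta = np.linalg.pinv(phi0) @ (Y - f0)
f_final = f0 + phi0 @ delta
print(np.allclose(f_final, Y))      # overparameterised: residual fit exactly
```

The learning here lies entirely in how the fixed random features are combined, which is the finite analogue of kernel regression with the NTK.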
4 Experiments
To examine the effect of increasing $K$, or equivalently reducing the frequency of feature learning, we run a number of experiments on MNIST and CIFAR10 with a slightly larger variant of LeNet [LeCun et al., 1998] with 50 channels in each convolutional layer and a softmax output. The purpose of these experiments is not to achieve state-of-the-art performance on the datasets (on CIFAR10 the model only reaches around 60% accuracy) but to examine how training changes as the learning rate and update frequency change. For these experiments we note that, unlike general NTK theory, no step of our derivation relies on the use of MSE, so we instead use cross-entropy loss, as is standard for image classification. We additionally include the softmax in the loss function so that it is not linearised by the algorithm. This means both that we retain an output for which taking the cross-entropy loss is meaningful, and that we avoid numerical issues caused by the linearisation happening in the middle of the logsumexp trick.
Figure 1 shows the results of these experiments. For MNIST, all values of $K$ follow almost exactly the same learning trajectory, converging very quickly without saturating what can be learnt from each set of features before they are updated. For the largest values of $K$, it can be seen that training levels off each time before $\phi$ is updated. This shows that we can get close to 100% accuracy on MNIST by updating the features only twice beyond their initialisation (all runs except the fully linearised one are within a 0.4% range at the end of training). From this we conclude that the initialisation of a neural network with this architecture creates features that are not far from what is needed to solve MNIST. Compare this with CIFAR10 (learning rate 1e-5), where it still takes only a few feature updates to reach the performance of standard training ($K = 1$), though this is a much lower accuracy. It remains unclear from these experiments how the small number of updates needed interacts with different architectures on the same dataset.
Comparing CIFAR10 at learning rates of 1e-4 and 1e-5 also yields interesting conclusions. At 1e-4, training diverges as soon as $K$ goes above 5, and is unstable even for $K = 5$, whereas at 1e-5 training stays stable for all $K$ up to 20,000 (see Figure 2). Note that this is not simply because training moves less far: the step size is reduced 10x but $K$ is increased 4000x. We include SGD training for comparison, noting that plotting per step rather than per epoch gives similar performance to full batch gradient descent, so we should expect most runs to reach 60% accuracy if run for twice as long.
5 Conclusion
This paper has proposed iterative linearisation, a new training algorithm that interpolates between gradient descent on the standard and on the linearised neural network, as a finite-width parallel to the distinction between finite and infinite width networks. We show that, at least for a LeNet-like architecture with small learning rates, any amount of feature learning is enough to converge to a similarly performing model. This provides an important step towards understanding feature learning and the distinction between how infinite and finite width networks learn. Better understanding how networks change with large numbers of parameters has important connections to empirical phenomena such as deep double descent [Nakkiran et al., 2020].
5.1 Future Work
It is important to carry out more rigorous empirical investigations to confirm these results, in particular scaling up to larger models in order to disentangle the impact of iterative linearisation from the fact that this architecture will never do particularly well on CIFAR10. This is also important for better understanding under which architectures, learning rates, and update frequencies iterative linearisation is stable.
Another direction is to better understand the types of solutions that iterative linearisation finds for various values of $K$. This would shed light on how the inductive bias changes, in particular whether all finite values of $K$ find similar solutions, with the infinite width limit a step change (as with the test performance in the experiments here), or whether there is a gradual change towards solutions that do not learn features.
Finally, we only consider fixed period iterative linearisation here, where the feature vector is updated at regular intervals. However, Fort et al. [2020] showed that the empirical NTK changes faster earlier in training, so it would make sense for the update schedule to be adaptive if iterative linearisation were to be used directly for training.
References
- S. Arora, S. S. Du, W. Hu, Z. Li, R. Salakhutdinov, and R. Wang (2019). On exact computation with an infinitely wide neural net. NeurIPS 2019, pp. 8139–8148.
- L. Chizat, E. Oyallon, and F. Bach (2019). On lazy training in differentiable programming. NeurIPS 2019, pp. 2933–2943.
- S. Fort, G. K. Dziugaite, M. Paul, S. Kharaghani, D. M. Roy, and S. Ganguli (2020). Deep learning versus kernel learning: an empirical study of loss landscape geometry and the time evolution of the neural tangent kernel. NeurIPS 2020.
- A. Jacot, F. Gabriel, and C. Hongler (2018). Neural tangent kernel: convergence and generalization in neural networks. NeurIPS 2018, pp. 8580–8589.
- Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE 86(11), pp. 2278–2324.
- J. Lee, S. S. Schoenholz, J. Pennington, B. Adlam, L. Xiao, R. Novak, and J. Sohl-Dickstein (2020). Finite versus infinite neural networks: an empirical study. NeurIPS 2020.
- J. Lee, L. Xiao, S. S. Schoenholz, Y. Bahri, R. Novak, J. Sohl-Dickstein, and J. Pennington (2019). Wide neural networks of any depth evolve as linear models under gradient descent. NeurIPS 2019, pp. 8570–8581.
- A. Lewkowycz, Y. Bahri, E. Dyer, J. Sohl-Dickstein, and G. Gur-Ari (2020). The large learning rate phase of deep learning: the catapult mechanism. arXiv:2003.02218.
- Z. Li, R. Wang, D. Yu, S. S. Du, W. Hu, R. Salakhutdinov, and S. Arora (2019). Enhanced convolutional neural tangent kernels. arXiv:1911.00809.
- P. Nakkiran, G. Kaplun, Y. Bansal, T. Yang, B. Barak, and I. Sutskever (2020). Deep double descent: where bigger models and more data hurt. ICLR 2020.
- G. Yang and E. J. Hu (2021). Tensor programs IV: feature learning in infinite-width neural networks. ICML 2021, PMLR 139, pp. 11727–11737.
- C. Zhang, S. Bengio, M. Hardt, B. Recht, and O. Vinyals (2017). Understanding deep learning requires rethinking generalization. ICLR 2017.
Appendix A Iterative linearisation with a general loss function
In Section 3 we show how to arrive at iterative linearisation from standard gradient descent under mean squared error loss. The use of mean squared error is instructive due to its similarity with NTK results, but it is not strictly necessary. For completeness, we present the same idea for a general differentiable loss function $L$.
Standard gradient descent on a function $f$ parameterised by $\theta$, with step size $\eta$ and data $X$, can be written as

$$\theta_{t+1} = \theta_t - \eta\,\nabla_\theta L\big(f(X;\theta_t), Y\big)$$

We can apply the chain rule, resulting in

$$\theta_{t+1} = \theta_t - \eta\,\nabla_\theta f(X;\theta_t)^\top\,L'\big(f(X;\theta_t), Y\big)$$

where $L'$ is the derivative of $L$ with respect to the function output (in the case of mean squared error, this is the residual $f(X;\theta_t) - Y$). Now, again writing $\phi_t = \nabla_\theta f(X;\theta_t)$, we can express this as

$$\theta_{t+1} = \theta_t - \eta\,\phi_t^\top\,L'\big(f(X;\theta_t), Y\big)$$

By a similar argument to Section 3, we need not update the features every step, resulting in the following formulation:

$$\theta_{t+1} = \theta_t - \eta\,\phi_s^\top\,L'\big(f^{\mathrm{lin}}_{\theta_s}(X;\theta_t), Y\big), \qquad \phi_s = \nabla_\theta f(X;\theta_s), \quad s = K\lfloor t/K\rfloor$$
This now lets us use a softmax followed by cross-entropy in the loss while maintaining the same interpretation, as we do for the MNIST and CIFAR10 results.
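The general-loss recipe can be sketched in code. The example below is illustrative only (a tiny tanh network with scalar logits and logistic loss on toy binary labels, not our experimental setup): the MSE residual is replaced by the loss derivative $L'$, and the nonlinearity of the loss stays outside the linearisation, analogous to keeping the softmax inside the loss for the cross-entropy experiments.

```python
import numpy as np

rng = np.random.default_rng(3)

d, h, n = 3, 16, 40
X = rng.standard_normal((n, d))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(float)   # toy binary labels in {0, 1}

def unpack(theta):
    return theta[:h * d].reshape(h, d), theta[h * d:]

def f(theta, X):
    """Scalar logits of a one-hidden-layer tanh network."""
    W, v = unpack(theta)
    return np.tanh(X @ W.T) @ v

def jac(theta, X):
    """Analytic Jacobian of the logits w.r.t. theta."""
    W, v = unpack(theta)
    A = np.tanh(X @ W.T)
    D = (1 - A ** 2) * v
    JW = (D[:, :, None] * X[:, None, :]).reshape(len(X), -1)
    return np.concatenate([JW, A], axis=1)

def loss_deriv(logits):
    """L'(f) for the logistic loss: sigmoid(f) - y."""
    return 1.0 / (1.0 + np.exp(-logits)) - y

def iterative_linearisation(theta0, K, steps=2000, lr=0.2):
    theta = theta0.copy()
    for t in range(steps):
        if t % K == 0:                                # feature update
            theta_ref, phi = theta.copy(), jac(theta, X)
            f_ref = f(theta_ref, X)
        f_lin = f_ref + phi @ (theta - theta_ref)     # linearised logits
        theta = theta - lr * phi.T @ loss_deriv(f_lin) / n   # chain-rule step
    return theta

theta0 = 0.1 * rng.standard_normal(h * d + h)
theta = iterative_linearisation(theta0, K=100)
print(((f(theta, X) > 0) == (y > 0.5)).mean())        # training accuracy
```

Only the residual term changes relative to the MSE version; the split between weight updates and feature updates is identical.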