Due to their excellent practical performance, deep-learning-based methods are today the de facto standard for many visual and linguistic learning tasks. Nevertheless, despite this practical success, our theoretical understanding of how and what neural networks can learn is still in its infancy zhangUnderstandingDeepLearning2017. Recently, a growing body of work has started to explore the use of linear approximations to analyze deep networks, giving rise to the neural tangent kernel (NTK) framework jacotNeuralTangentKernel.
The NTK framework is based on the observation that for certain initialization schemes, the infinite-width limit of many neural architectures can be exactly characterized using kernel tools jacotNeuralTangentKernel. This reduces key questions in deep learning theory to the study of linear methods and convex functional analysis, for which a rich set of theories exists LearningKernelsSupport2002. This intuitive approach has proven very fertile, leading to important results on the generalization and optimization of very wide networks leeWideNeuralNetworks2020; duGradientDescentFinds; biettiInductiveBiasNeural2019; zouImprovedAnalysisTraining; belkinHessian2020; aroraCNTK2019.
The NTK theory, however, can only fully describe certain infinitely wide neural networks. For the narrow architectures used in practice, it only provides a first-order approximation of their training dynamics (see Fig. 1). Despite these limitations, the intuitiveness of the NTK, which allows one to exploit a powerful set of theoretical tools, has led to a rapid increase in the amount of work that successfully leverages the NTK in applications, such as predicting generalization deshpandeLinearizedFrameworkNew2021 and training speed zancatoPredictingTrainingTime2020b, explaining certain inductive biases mobahiSelfDistillationAmplifiesRegularizationa; tancikFourierFeaturesLet2020; gebhartUnifiedPathsPerspective2020; seyedICML2021, or designing new classifiers aroraHARNESSINGPOWERINFINITELY2020; maddoxFastAdaptation2021.
Recent reports, however, have started questioning the effectiveness of this approach, as one can find multiple examples in which kernel methods are provably outperformed by neural networks allen-zhuWhatCanResNet2019; ghorbaniWhenNeuralNetworks2020a; malachQuantifyingBenefitUsing2021. Most importantly, it has been observed empirically that linearized models, computed using a first-order Taylor expansion around the initialization of a neural network, perform much worse than the networks they approximate on standard image recognition benchmarks fortDeepLearningKernel2020, a phenomenon that has been dubbed the non-linear advantage. However, it has also been observed that, if one linearizes the network at a later stage of training, the non-linear advantage is greatly reduced fortDeepLearningKernel2020. The reasons behind this phenomenon are poorly understood, but they are key to explaining the success of deep learning.
Building on these observations, in this work we delve deeper into the source of the non-linear advantage, trying to understand why previous work could successfully leverage the NTK in some applications. In particular, we shed new light on the question: When can the NTK approximation be used to predict generalization, and what does it actually say about it? We propose, for the first time, to empirically study this problem from the perspective of the characteristics of the training labels. To that end, we conduct a systematic analysis comparing the performance of different neural network architectures with their kernelized versions on different problems with the same data support but different labels. Doing so, we identify the alignment of the target function with the NTK as a key quantity governing important aspects of generalization in deep learning. Namely, one can rank the complexity of solving certain tasks with deep networks based solely on their kernel alignment.
Moreover, we study the evolution of this alignment during training. Interestingly, we discover that optimizing a deep network significantly increases the alignment of its empirical NTK with the target function. This phenomenon allows networks to explore functions beyond the limitations imposed by their linearization, and improves their training speed. Nonetheless, we also show that these kernel dynamics do not always have a positive effect on generalization, and we provide concrete examples where deep networks exhibit a non-linear disadvantage compared to their kernel approximations.
The main contributions of our work are:
We show that the alignment with the empirical NTK at initialization can provide a good measure of relative learning complexity in deep learning for a diverse set of tasks.
We identify, for the first time, a set of non-trivial tasks in which neural networks perform worse than their linearized approximations. This highlights the strong dependency of the non-linear advantage on the properties of the target function.
We then provide a fine-grained analysis of the evolution of the empirical NTK during training, showing that training increases its alignment with the training labels.
Finally, we show that this adaptation towards the target function during training is responsible for the rapid convergence of neural networks, but that it also abides by its own inductive bias, which can either benefit or hurt generalization depending on the target task.
We see our empirical findings as an important step forward in our understanding of deep learning. Our work paves the way for new research avenues based on our newly observed phenomena, but it also provides a fresh perspective to understand how to use the NTK approximation in several applications.
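Let $f_{\boldsymbol\theta}:\mathbb{R}^d\to\mathbb{R}$ denote a neural network parameterized by a set of weights $\boldsymbol\theta\in\mathbb{R}^n$. Without loss of generality, and to reduce the computational load, in this work we only consider the binary classification setting, where the objective is to learn an underlying target function $f:\mathbb{R}^d\to\{\pm 1\}$ by training the weights to minimize the empirical risk $\hat{\mathcal{R}}(f_{\boldsymbol\theta})=\frac{1}{m}\sum_{i=1}^{m}\ell(f_{\boldsymbol\theta}(\boldsymbol x_i),y_i)$ over a finite set $\mathcal{S}=\{(\boldsymbol x_i,y_i)\}_{i=1}^{m}$ of i.i.d. training data from an underlying distribution $\mathcal{D}$, with $y_i=f(\boldsymbol x_i)$. We broadly say that a model generalizes whenever it achieves a low expected risk $\mathcal{R}(f_{\boldsymbol\theta})=\mathbb{E}_{\boldsymbol x\sim\mathcal{D}}\left[\ell(f_{\boldsymbol\theta}(\boldsymbol x),f(\boldsymbol x))\right]$.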
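In a small neighborhood around the weight initialization $\boldsymbol\theta_0$, a neural network can be approximated using a first-order Taylor expansion (see Fig. 1)
$$f_{\boldsymbol\theta}(\boldsymbol x)\approx f_{\boldsymbol\theta_0}(\boldsymbol x)+(\boldsymbol\theta-\boldsymbol\theta_0)^\top\nabla_{\boldsymbol\theta}f_{\boldsymbol\theta_0}(\boldsymbol x),\qquad(1)$$
where $\nabla_{\boldsymbol\theta}f_{\boldsymbol\theta_0}(\boldsymbol x)$ denotes the Jacobian of the network with respect to the parameters evaluated at $\boldsymbol\theta_0$. Here, the model $\hat f_{\boldsymbol\theta}(\boldsymbol x)=(\boldsymbol\theta-\boldsymbol\theta_0)^\top\nabla_{\boldsymbol\theta}f_{\boldsymbol\theta_0}(\boldsymbol x)$ represents a linearized network which maps weight vectors to functions living in a reproducing kernel Hilbert space (RKHS) $\mathcal{H}$, determined by the empirical neural tangent kernel (NTK) at $\boldsymbol\theta_t$, $\Theta_t(\boldsymbol x,\boldsymbol x')=\langle\nabla_{\boldsymbol\theta}f_{\boldsymbol\theta_t}(\boldsymbol x),\nabla_{\boldsymbol\theta}f_{\boldsymbol\theta_t}(\boldsymbol x')\rangle$ jacotNeuralTangentKernel; leeWideNeuralNetworks2020. Unless stated otherwise, we will generally drop the dependency on $t$ and use $\Theta$ to refer to the NTK at initialization.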
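To make the objects in (1) concrete, the following NumPy sketch computes an empirical NTK Gram matrix from per-sample parameter gradients and checks that the first-order expansion tracks the network under a small weight perturbation. It is only an illustration under stated assumptions: the two-layer tanh network, its hand-derived gradients, and the helper names (`init_params`, `forward`, `jacobian`, `empirical_ntk`, `linearized`) are hypothetical and are not the neural_tangents code used in our experiments.

```python
import numpy as np

# Hypothetical toy network f(x) = a^T tanh(W x) / sqrt(h); not an architecture from the paper.
def init_params(d, h, rng):
    W = rng.standard_normal((h, d)) / np.sqrt(d)
    a = rng.standard_normal(h)
    return W, a

def forward(params, X):
    W, a = params
    return np.tanh(X @ W.T) @ a / np.sqrt(W.shape[0])

def jacobian(params, X):
    """Per-sample gradient of f w.r.t. (W, a), flattened to shape (m, h*d + h)."""
    W, a = params
    h = W.shape[0]
    act = np.tanh(X @ W.T)                                   # (m, h)
    d_W = ((1.0 - act**2) * a)[:, :, None] * X[:, None, :]   # (m, h, d), chain rule through tanh
    d_a = act                                                # (m, h)
    return np.concatenate([d_W.reshape(len(X), -1), d_a], axis=1) / np.sqrt(h)

def empirical_ntk(params, X1, X2):
    """Gram matrix of Theta(x, x') = <grad_theta f(x), grad_theta f(x')>."""
    return jacobian(params, X1) @ jacobian(params, X2).T

def linearized(params0, params, X):
    """First-order Taylor expansion (1) of f around params0."""
    delta = np.concatenate([(params[0] - params0[0]).ravel(), params[1] - params0[1]])
    return forward(params0, X) + jacobian(params0, X) @ delta

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 5))
p0 = init_params(5, 64, rng)
p1 = (p0[0] + 1e-3 * rng.standard_normal(p0[0].shape),
      p0[1] + 1e-3 * rng.standard_normal(p0[1].shape))
K = empirical_ntk(p0, X, X)
lin_error = np.max(np.abs(forward(p1, X) - linearized(p0, p1, X)))
```

Because the neglected terms are second order in the weight displacement, `lin_error` should shrink roughly quadratically as the perturbation scale is reduced.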
In most contexts, the NTK evolves during training by following the trajectory of the network Jacobian computed at a checkpoint $\boldsymbol\theta_t$ (see Fig. 1). Remarkably, however, it was recently discovered that for certain types of initialization, and in the limit of infinite network width, the approximation in (1) is provably exact jacotNeuralTangentKernel, making the NTK constant throughout training. In this regime, one can therefore provide generalization guarantees for neural networks using generalization bounds from kernel methods LearningKernelsSupport2002, and show bartlettRademacherGaussianComplexities2001, for instance, that with high probability
where the complexity term depends on a regularization constant and on $\|f\|_{\mathcal H}$, the RKHS norm of the target function, which for positive definite kernels can be computed as
$$\|f\|_{\mathcal H}^2=\sum_{j}\frac{1}{\lambda_j}\langle f,\phi_j\rangle^2,\qquad(3)$$
with $\lambda_j$ and $\phi_j$ denoting the eigenvalues and eigenfunctions of $\Theta$.
Here, the RKHS norm couples the target function with the kernel. This means that, in kernel regimes, the difference between the empirical and the expected risk is smaller when training on target functions with a lower RKHS norm, that is, functions whose projection onto the eigenfunctions of the kernel is mostly concentrated along its largest eigenvalues. One can then see the RKHS norm as a metric that ranks the complexity of learning different targets using the kernel $\Theta$.
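On a finite sample, (3) reduces to a quadratic form in the Gram matrix. The NumPy sketch below uses a random positive definite matrix as a stand-in for an NTK Gram matrix and evaluates the finite-sample analogue $\sum_j(\boldsymbol u_j^\top\boldsymbol y)^2/\mu_j$, where $(\mu_j,\boldsymbol u_j)$ are the Gram eigenpairs; it shows that a target aligned with the top eigenvector has a smaller norm than one aligned with the bottom eigenvector. The helper `rkhs_norm_sq` is hypothetical, and the correspondence with (3) assumes the usual normalization $\phi_j(\boldsymbol x_i)=\sqrt{m}\,u_{j,i}$ and $\lambda_j=\mu_j/m$.

```python
import numpy as np

def rkhs_norm_sq(K, y, rel_tol=1e-10):
    """Finite-sample analogue of (3): sum_j (u_j^T y)^2 / mu_j over the Gram eigenpairs."""
    mu, U = np.linalg.eigh(K)             # eigenvalues in ascending order
    coeffs = U.T @ y                      # projections of y on each eigenvector
    keep = mu > rel_tol * mu.max()        # ignore numerically null directions
    return float(np.sum(coeffs[keep] ** 2 / mu[keep]))

rng = np.random.default_rng(0)
m = 20
A = rng.standard_normal((m, m))
K = A @ A.T + np.eye(m)                   # stand-in for an NTK Gram matrix
mu, U = np.linalg.eigh(K)
easy = rkhs_norm_sq(K, U[:, -1])          # target along the top eigenvector
hard = rkhs_norm_sq(K, U[:, 0])           # target along the bottom eigenvector
y = np.sign(rng.standard_normal(m))       # a generic binary target
```

For an invertible Gram matrix the same quantity is simply $\boldsymbol y^\top\mathbf K^{-1}\boldsymbol y$; the eigen-decomposition form makes explicit that small eigenvalues dominate the norm, which is why estimating (3) is sensitive to the bottom of the spectrum.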
Estimating (3) in practice is challenging as it requires access to the smallest eigenvalues of the kernel. However, one can use the following lemma to compute a more tractable bound of the RKHS norm, which shows that a high target-kernel alignment is a good proxy for a small RKHS norm.
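Let $\alpha(f)=\mathbb{E}_{\boldsymbol x,\boldsymbol x'\sim\mathcal D}\left[f(\boldsymbol x)\Theta(\boldsymbol x,\boldsymbol x')f(\boldsymbol x')\right]$ denote the alignment of the target $f$ with the kernel $\Theta$. Then $\|f\|_{\mathcal H}\geq\|f\|_2^2/\sqrt{\alpha(f)}$. Moreover, for the NTK, $\alpha(f)=\left\|\mathbb{E}_{\boldsymbol x\sim\mathcal D}\left[f(\boldsymbol x)\nabla_{\boldsymbol\theta}f_{\boldsymbol\theta_0}(\boldsymbol x)\right]\right\|_2^2$.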
See Appendix. ∎
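The sketch below checks Lemma 1 numerically, assuming expectations are replaced by empirical averages over $m$ points, so that the alignment, taken as $\alpha(f)=\|\frac{1}{m}\sum_i f(\boldsymbol x_i)\nabla_{\boldsymbol\theta}f_{\boldsymbol\theta_0}(\boldsymbol x_i)\|_2^2$, equals $\boldsymbol f^\top\mathbf K\boldsymbol f/m^2$. A random matrix stands in for the network Jacobian and the helper name `alignment` is hypothetical. The point it illustrates is practical: the alignment needs a single vector-Jacobian product, whereas the RKHS norm needs a linear solve or the full spectrum.

```python
import numpy as np

def alignment(J, f_vals):
    """Empirical alpha(f) = || (1/m) sum_i f(x_i) grad_theta f(x_i) ||^2."""
    m = len(f_vals)
    return float(np.sum((J.T @ f_vals / m) ** 2))

rng = np.random.default_rng(0)
m, n = 30, 200
J = rng.standard_normal((m, n))           # stand-in for the (m x n) network Jacobian
K = J @ J.T                               # corresponding empirical NTK Gram matrix
y = np.sign(rng.standard_normal(m))       # binary target values on the sample

alpha = alignment(J, y)
alpha_via_gram = float(y @ K @ y) / m**2  # same quantity computed through the kernel
rkhs_norm = np.sqrt(y @ np.linalg.solve(K, y))
l2_norm_sq = np.mean(y ** 2)
lemma_lower_bound = l2_norm_sq / np.sqrt(alpha)   # right-hand side of Lemma 1
```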
At this point, it is important to highlight that in most practical applications we do not deal with infinitely wide networks, and hence (2) can only be regarded as a learning guarantee for linearized models, i.e., $\hat f_{\boldsymbol\theta}$. Furthermore, from now on, we will use the terms NTK and empirical NTK jacotNeuralTangentKernel; leeWideNeuralNetworks2020 interchangeably to refer to the finite-width kernels derived from (1). In our experiments, we compute them using the neural_tangents library novakNeuralTangentsFast built on top of the JAX framework GoogleJax2021, which we also use to generate the linearized models. Similarly, as is commonly done in the kernel literature, we use the eigenvectors of the kernel Gram matrix to approximate the values of the eigenfunctions over a finite dataset. We thus use the terms eigenvector and eigenfunction interchangeably. In particular, we use $\boldsymbol\Phi$ to denote the matrix containing the $j$th Gram eigenvector in its $j$th row, where the rows are ordered according to the vector of decreasing eigenvalues $\boldsymbol\lambda$. A more detailed description of this setup can be found in the Appendix.
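A minimal NumPy sketch of how such an eigenvector matrix can be assembled from a Gram matrix. One detail worth flagging: `numpy.linalg.eigh` returns eigenvalues in ascending order with eigenvectors as columns, so both must be reversed, and the eigenvector matrix transposed, to obtain rows ordered by decreasing eigenvalue. The function name `gram_eigenbasis` and the random stand-in Gram matrix are hypothetical.

```python
import numpy as np

def gram_eigenbasis(K):
    """Return (lam, Phi): eigenvalues in decreasing order and a matrix whose
    j-th row is the Gram eigenvector associated with lam[j]."""
    mu, U = np.linalg.eigh(K)             # ascending eigenvalues, eigenvectors as columns
    order = np.argsort(mu)[::-1]          # indices of decreasing eigenvalues
    return mu[order], U[:, order].T       # transpose so eigenvectors become rows

rng = np.random.default_rng(0)
A = rng.standard_normal((12, 40))
K = A @ A.T                               # stand-in for an empirical NTK Gram matrix
lam, Phi = gram_eigenbasis(K)
```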
3 Linearized models can predict relative task complexity for deep networks
As mentioned in Sec. 1, a growing body of work uses the linear approximation of neural networks as kernel methods to analyze networks and build novel algorithms zancatoPredictingTrainingTime2020b; tancikFourierFeaturesLet2020; gebhartUnifiedPathsPerspective2020; mobahiSelfDistillationAmplifiesRegularizationa; maddoxFastAdaptation2021; aroraHARNESSINGPOWERINFINITELY2020; deshpandeLinearizedFrameworkNew2021. Meanwhile, recent reports, both theoretical and empirical, have started to question whether the NTK approximation can really tell anything useful about generalization for finite-width networks allen-zhuWhatCanResNet2019; fortDeepLearningKernel2020; ghorbaniWhenNeuralNetworks2020a; malachQuantifyingBenefitUsing2021. In this section, we try to clear up some of this confusion and shed light on the question: What can the empirical NTK actually predict about generalization?
To that end, we conduct a systematic study with different neural networks and their linearized approximations given by (1), which we train to solve a structured array of predictive tasks of different complexity. Our results indicate that, for many problems, the linear models and the deep networks do agree in the way they order the complexity of learning certain tasks, even if their performance on the same problems can greatly differ. This highlights why the NTK approximation can be successfully applied in many contexts in which the main goal is only to predict relative differences in the performance of neural networks across tasks.
3.1 Learning NTK eigenfunctions
In kernel theory, the sample and optimization complexity required to learn a given function is normally bounded by its kernel norm LearningKernelsSupport2002, which intuitively measures the alignment of the target function with the eigenfunctions of the kernel. For kernel methods, the eigenfunctions themselves thus represent a natural set of target functions whose learning complexity increases as their associated eigenvalues decrease. Since our goal is to determine whether the kernel approximation can indeed predict generalization for neural networks, we evaluate the performance of these networks when learning the eigenfunctions of their NTKs.
In particular, we generate a sequence of datasets constructed using the standard CIFAR10 krizhevskyLearningMultipleLayers2009 samples, which we label using different binarized versions of the NTK eigenfunctions. That is, to every sample $\boldsymbol x$ in CIFAR10 we assign the label $y=\operatorname{sign}(\phi_j(\boldsymbol x))$, where $\phi_j$ represents the $j$th eigenfunction of the NTK at initialization (see Sec. 2). In this construction, the choice of CIFAR10 as the supporting distribution keeps our experiments close to real settings, which might be conditioned by low-dimensional structures in the data manifold ghorbaniWhenNeuralNetworks2020a; paccolatGeometricCompressionInvariant2021a, while the choice of eigenfunctions as targets guarantees a progressive increase in complexity, at least for the linearized networks. Specifically, for $\phi_j$ the alignment is given by $\alpha(\phi_j)=\lambda_j$ (see Appendix for the detailed proof).
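The labelling step can be sketched as follows, again in NumPy with a random stand-in for the network Jacobian. It assumes the Nyström-style normalization $\phi_j(\boldsymbol x_i)=\sqrt{m}\,u_{j,i}$ and $\lambda_j=\mu_j/m$, with $(\mu_j,\boldsymbol u_j)$ the Gram eigenpairs, under which the empirical alignment $\boldsymbol\phi_j^\top\mathbf K\boldsymbol\phi_j/m^2$ equals $\lambda_j$ exactly. The helpers `eigenfunction_labels` and `empirical_alignment` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 50, 400
J = rng.standard_normal((m, n))           # stand-in for the network Jacobian on m samples
K = J @ J.T                               # empirical NTK Gram matrix

mu, U = np.linalg.eigh(K)
order = np.argsort(mu)[::-1]              # sort eigenpairs by decreasing eigenvalue
mu, U = mu[order], U[:, order]

phi = np.sqrt(m) * U                      # column j: phi_j on the sample, unit empirical L2 norm
lam = mu / m                              # eigenvalues of the empirical integral operator

def eigenfunction_labels(j):
    """Binary labels y_i = sign(phi_j(x_i)) for the j-th NTK eigenfunction."""
    return np.sign(phi[:, j])

def empirical_alignment(f_vals):
    """alpha(f) = f^T K f / m^2, the finite-sample alignment with the kernel."""
    return float(f_vals @ K @ f_vals) / m**2

labels_easy = eigenfunction_labels(0)      # target built from the top eigenfunction
labels_hard = eigenfunction_labels(m - 1)  # target built from the last eigenfunction
```

Binarizing with `sign` perturbs the target, so the alignment of the binary labels themselves generally differs from $\lambda_j$; the identity holds for the real-valued eigenfunctions.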
We train different neural network architectures, selected to cover the spectrum from small to large models rosenblattPerceptronProbabilisticModel1958; lecunGradientbasedLearningApplied1998; resnet, as well as their linearized models given by (1). Unless stated otherwise, in all our experiments we use the same standard training procedure, which consists of using stochastic gradient descent (SGD) to optimize a binary cross-entropy loss, with a decaying learning rate starting at and momentum set to . Our results report values of different metrics at the end of epochs of training.
Fig. 2 summarizes the main results of our experiments (results with equivalent findings for other training schemes and datasets can be found in the Appendix). Here, we can clearly see how the validation accuracy of networks trained to predict targets aligned with $\phi_j$ progressively drops with decreasing eigenvalues, both for the linearized models, as predicted by the theory, and for the neural networks. Similarly, Fig. 3 shows that the training dynamics of these networks also correlate with the eigenfunction index. Specifically, we see that networks take more time to fit eigenfunctions associated with smaller eigenvalues, and need to travel a larger distance in weight space to do so.
Overall, our observations reveal that sorting tasks based on their alignment with the NTK is a good predictor of learning complexity both for linear models and for neural networks. Nevertheless, we also see that there are large performance gaps between neural networks and their linearized approximations. Indeed, even if the networks and their approximations agree on which eigenfunctions are harder to learn, the kernel methods seem to be much better at learning them. This clearly differs from what was previously observed for other targets, in which neural networks performed better than kernels allen-zhuWhatCanResNet2019; fortDeepLearningKernel2020; ghorbaniWhenNeuralNetworks2020a; malachQuantifyingBenefitUsing2021, and it highlights that a non-linear advantage does not always exist.
3.2 Learning linear predictors
The NTK eigenfunctions are one example of a canonical set of tasks with increasing hardness for kernel methods whose learning complexity for neural networks follows the same order. However, could there be more examples? And are the previously observed correlations useful to predict other generalization phenomena? To answer these questions, we propose to analyze another set of problems, this time using a sequence of targets of increasing complexity for neural networks.
In this sense, it has recently been observed that, for convolutional neural networks (CNNs), the set of linear predictors, i.e., hyperplanes separating two distributions, represents a function class with a wide range of learning complexities among its elements ortizjimenez2020neural. In particular, it has been confirmed empirically that one can rank the complexity for a neural network of learning different linearly separable tasks based on a sequence of vectors known as neural anisotropy directions (NADs) ortizjimenez2020neural.
Definition (Neural anisotropy directions).
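The NADs of a neural network are the ordered sequence of orthonormal vectors $\{\boldsymbol u_i\}_{i=1}^{d}$ which form a full basis of the input space and whose order is determined by the sample complexity required to learn the linear predictors $\{f_{\boldsymbol u_i}(\boldsymbol x)=\boldsymbol u_i^\top\boldsymbol x\}_{i=1}^{d}$.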
In ortizjimenez2020neural, the authors provided several heuristics to compute the NADs of a neural network. However, we now provide a new, more principled interpretation of the NADs, showing that one can also obtain this sequence using a kernel approximation. To that end, we make use of the following theorem.
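Let $\boldsymbol u\in\mathbb{S}^{d-1}$ be a unit vector that parameterizes a linear predictor $f_{\boldsymbol u}(\boldsymbol x)=\boldsymbol u^\top\boldsymbol x$, and let $\boldsymbol x\sim\mathcal N(\boldsymbol 0,\boldsymbol I)$. The alignment of $f_{\boldsymbol u}$ with $\Theta$ is given by
$$\alpha(f_{\boldsymbol u})=\left\|\mathbb{E}_{\boldsymbol x}\left[\nabla^2_{\boldsymbol x,\boldsymbol\theta}f_{\boldsymbol\theta_0}(\boldsymbol x)\right]\boldsymbol u\right\|_2^2,$$
where $\nabla^2_{\boldsymbol x,\boldsymbol\theta}f_{\boldsymbol\theta_0}(\boldsymbol x)$ denotes the derivative of $f_{\boldsymbol\theta_0}$ with respect to the weights and the input.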
See Appendix. ∎
Theorem 1 gives an alternative method to compute NADs. Indeed, in the kernel regime, the NADs are simply the right singular vectors of the matrix of mixed derivatives of the network (all predictors have the same norm; hence, their alignment is inversely related to their kernel norm). Note, however, that this interpretation is based only on an approximation, and hence there is no explicit guarantee that these NADs will capture the directional inductive bias of deep networks. Nevertheless, our experiments reveal that CNNs do rank the learning complexity of different linear predictors in a way compatible with Theorem 1.
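The singular-vector characterization of Theorem 1 can be sketched without forming explicit second derivatives. Assuming $\boldsymbol x\sim\mathcal N(\boldsymbol 0,\boldsymbol I)$, Stein's identity gives $\mathbb E[\nabla_{\boldsymbol\theta}f_{\boldsymbol\theta_0}(\boldsymbol x)\,\boldsymbol x^\top]=\mathbb E[\nabla^2_{\boldsymbol x,\boldsymbol\theta}f_{\boldsymbol\theta_0}(\boldsymbol x)]$, so the mixed-derivative matrix can be estimated by Monte Carlo from Jacobians and inputs. The toy model below (a random-features network whose only trainable weights are in the output layer) and the helper `estimate_nads` are hypothetical stand-ins chosen for brevity, not the CNNs of our experiments.

```python
import numpy as np

rng = np.random.default_rng(0)
d, h = 10, 64
W = rng.standard_normal((h, d)) * np.linspace(1.0, 0.1, d)  # fixed, anisotropic first layer

def jacobian(X):
    """Gradient of f(x) = a^T tanh(W x) / sqrt(h) w.r.t. the output weights a."""
    return np.tanh(X @ W.T) / np.sqrt(h)

def estimate_nads(jac_fn, d, num_samples, rng):
    """Right singular vectors of C = E[grad_theta f(x) x^T] with x ~ N(0, I)."""
    X = rng.standard_normal((num_samples, d))
    C = jac_fn(X).T @ X / num_samples     # (n_params, d) Monte Carlo estimate
    _, s, Vt = np.linalg.svd(C, full_matrices=False)
    return C, s, Vt                       # rows of Vt are the NADs, by decreasing s

C, s, nads = estimate_nads(jacobian, d, 20000, rng)
# Under the Gaussian assumption, the alignment of the predictor u^T x is ||C u||^2.
alignments = np.array([np.sum((C @ u) ** 2) for u in nads])
```

Estimating $C$ through Stein's identity avoids second derivatives altogether, at the price of Monte Carlo noise that decreases as `num_samples` grows.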
Indeed, as shown in Fig. 4, when trained to classify a set of linearly separable datasets aligned with the NTK-predicted NADs (full details of the experiment can be found in the Appendix), CNNs perform better on predictors with a higher kernel alignment (i.e., corresponding to the first NADs) than on those with a lower one (i.e., later NADs). The fact that NADs can be explained using kernel theory constitutes another clear example that theory derived from a naïve linear expansion of a neural network can sometimes capture important trends in the inductive bias of deep networks, even when we observe a clear performance gap between linear and non-linear models. Surprisingly, neural networks exhibit a strong non-linear advantage on these tasks, even though these NADs were explicitly constructed to be well aligned with the linear models.
Overall, our results explain why previous heuristics that used the NTK to rank the complexity of learning certain tasks tancikFourierFeaturesLet2020; zancatoPredictingTrainingTime2020b; deshpandeLinearizedFrameworkNew2021 were successful in doing so. Specifically, by systematically evaluating the performance of neural networks on tasks of increasing complexity for their linearized approximations, we have observed that the non-linear dynamics of these networks do not change the way in which they sort the complexity of these problems. However, we should not forget that the differences in performance between the neural networks and their approximations are very significant, and whether they favour neural networks or not depends on the task. In what follows, we thus delve deeper into the sources of these differences.
4 Sources of the non-linear advantage
In this section, we study in more detail the mechanisms that separate neural networks from their linear approximations, which can sometimes lead to significant performance differences. Specifically, we show that there are important nuances involved in the comparison of linear and non-linear models, which depend on the number of samples, the architecture, or the target task. To shed more light on these complex relations, we conduct a fine-grained analysis of the optimization dynamics of neural networks, studying the evolution of their empirical NTK. We show that, due to the non-linear optimization dynamics of deep networks, the alignment of their NTK with the training task grows with training. This is an important phenomenon which can explain why, as previously observed but not explained in fortDeepLearningKernel2020, there is a better agreement between the predictions of neural networks and those obtained by linear approximations after some epochs of training. Finally, we show that the kernel adaptation is a major driver of the faster convergence of neural networks with respect to kernel methods. Nevertheless, we also show that this adaptation can sometimes be imperfect and lead neural networks to overfit the training labels.
4.1 The non-linear advantage depends on the sample size
As we have seen, there exist multiple problems in which neural networks perform significantly better than their linearized approximations (see Sec. 3.2), but also others where they do not (see Sec. 3.1). We now show, however, that the magnitude of these differences is influenced by the training set size.
We can illustrate this phenomenon by training several neural networks to predict some semantically meaningful labels of an image dataset. In particular, and for consistency with the previous two-class examples, we deal with a binary version of the popular CIFAR10 dataset, in which we assign one label to all samples from the first five classes in the dataset and the opposite label to the rest. We will refer to this version of the dataset as CIFAR2. As seen in Fig. 5, some neural networks exhibit a large non-linear advantage on this task, but this advantage mostly appears when training on larger datasets. This phenomenon suggests that the inductive bias that boosts neural networks' performance with respect to kernel methods involves an important element of scale.
One can intuitively understand this behaviour by analyzing the distance travelled by the parameters during optimization (see the bottom row of Fig. 5). Indeed, for smaller training set sizes, the networks can more easily find solutions that fit the training data close to their initialization. As a result, the error incurred by the linear approximation in these cases is smaller. This may explain why previous studies could not identify large performance gaps between NTK-based models and neural networks in small-data tasks aroraHARNESSINGPOWERINFINITELY2020, while it also highlights the strength of the linear approximation in this regime.
4.2 The kernel alignment with the target increases during training
So far we have mostly analyzed results dealing with linear expansions around the weight initialization $\boldsymbol\theta_0$. However, recent empirical studies have argued that linearizing at later stages of training induces smaller approximation errors fortDeepLearningKernel2020, suggesting that the NTK dynamics can better explain the final training behaviour of a neural network. To the best of our knowledge, this phenomenon is still poorly understood, mostly because it hinges on understanding the non-linear dynamics of deep networks. We now show, however, that understanding the way the spectrum of the NTK evolves during training can provide important insights into these dynamics.
To that end, we first analyze the evolution of the principal components of the empirical NTK in relation to the target function. Specifically, let $\boldsymbol\Phi_t^{k}$ denote the matrix of the first $k$ eigenvectors (stacked as rows) of the Gram matrix of $\Theta_t$ obtained by linearizing the network after $t$ epochs of training, and let $\boldsymbol y$ be the vector of training labels. In Sec. 3.1, we have seen that both neural networks and their linear approximations perform better on targets aligned with the first eigenfunctions of $\Theta_0$. We therefore propose to track the energy concentration
$$E_k(t)=\frac{\|\boldsymbol\Phi_t^{k}\boldsymbol y\|_2^2}{\|\boldsymbol y\|_2^2}$$
of the labels onto the first $k$ eigenfunctions, with the aim of identifying a significant transformation of the principal eigenspace.
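Tracking the energy concentration only requires the top-$k$ eigenvectors of the Gram matrix at each checkpoint. In the NumPy sketch below, `energy_concentration` is a hypothetical helper, a random positive semidefinite matrix stands in for the Gram matrix at initialization, and adding a rank-one term along the labels is a hand-made, hypothetical stand-in for the realignment that training produces (it does not reproduce the actual training dynamics).

```python
import numpy as np

def energy_concentration(K, y, k):
    """Fraction of the label energy ||y||^2 captured by the top-k Gram eigenvectors."""
    mu, U = np.linalg.eigh(K)
    top_k = U[:, np.argsort(mu)[::-1][:k]]    # columns: eigenvectors of the k largest eigenvalues
    return float(np.sum((top_k.T @ y) ** 2) / np.sum(y ** 2))

rng = np.random.default_rng(0)
m = 40
A = rng.standard_normal((m, m))
K_init = A @ A.T                              # stand-in for the Gram matrix at initialization
y = np.sign(rng.standard_normal(m))           # training labels
y_unit = y / np.linalg.norm(y)
c = 1000 * np.linalg.eigvalsh(K_init).max()
K_aligned = K_init + c * np.outer(y_unit, y_unit)   # hypothetical kernel after realignment

before = energy_concentration(K_init, y, k=1)
after = energy_concentration(K_aligned, y, k=1)
```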
Fig. 7 shows the result of this procedure applied to a CNN trained to classify CIFAR2. Strikingly, the amount of energy of the target function that is concentrated on the first eigenfunctions of the NTK grows significantly during training. This is a heavily non-linear phenomenon (by definition, the linearized models have a fixed kernel), and it hinges on a dynamical realignment of the NTK eigenvectors during training. That is, training a neural network rotates the principal eigenspace of its NTK in a way that increases the energy concentration of the labels.
This phenomenon is prevalent across training setups, and it can also be observed when training on other tasks. In fact, a more fine-grained inspection reveals that the kernel rotation mostly happens along a single functional axis: it maximizes the alignment of the NTK with the target function, but does not greatly affect the rest of the spectrum. Indeed, our results show that, during training, the alignment grows significantly more for the target than for any other function.
Note that, as indicated by Lemma 1, target functions with a small RKHS norm also have a high alignment. This means that tasks with a high alignment are indeed represented by weights closer to the origin of the linearization. Since training increases the alignment of the NTK with the target, linearizing at a later checkpoint thus makes the approximation error of the linearization smaller for these targets, as observed in fortDeepLearningKernel2020.
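The link between the kernel norm and the distance from the linearization point can be made explicit for a linear model $\boldsymbol f=\mathbf J\,\Delta\boldsymbol\theta$: the minimum-norm weight displacement that interpolates the labels is $\Delta\boldsymbol\theta^\star=\mathbf J^\top\mathbf K^{-1}\boldsymbol y$, and its squared norm equals $\boldsymbol y^\top\mathbf K^{-1}\boldsymbol y$, the finite-sample RKHS norm. The NumPy sketch below, with a random stand-in Jacobian and a hypothetical helper `min_norm_displacement`, checks this identity and compares an easy (top-eigenvector) target with a hard (bottom-eigenvector) one.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 25, 300
J = rng.standard_normal((m, n))           # stand-in Jacobian at the linearization point
K = J @ J.T                               # empirical NTK Gram matrix
mu, U = np.linalg.eigh(K)                 # ascending eigenvalues

def min_norm_displacement(y):
    """Smallest weight change d_theta such that J @ d_theta == y."""
    return J.T @ np.linalg.solve(K, y)

targets = {"easy": np.sqrt(m) * U[:, -1], "hard": np.sqrt(m) * U[:, 0]}
distances = {name: np.linalg.norm(min_norm_displacement(y)) for name, y in targets.items()}
```

Here `distances["easy"]` equals $\sqrt{m/\mu_{\max}}$ and `distances["hard"]` equals $\sqrt{m/\mu_{\min}}$, so the well-aligned target is reached with a much smaller weight displacement.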
4.3 Kernel rotation improves speed of convergence, but can hurt generalization
The rotation of during training is an important mechanism that explains why the NTK dynamics can better capture the behaviour of neural networks at the end of training. However, it is also fundamental in explaining the ability of neural networks to quickly fit the training data. Specifically, it is important to highlight the stark contrast, at a dynamical level, between linear models and neural networks. Indeed, we have consistently observed across our experiments that neural networks converge much faster to solutions with a near-zero training loss than their linear approximations.
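We can explain the influence of the rotation of the NTK on this phenomenon through a simple experiment. In particular, we train three different models to predict another arbitrary NTK eigenfunction: i) a ResNet18, ii) its linearization around initialization, and iii) an unbiased linearization around the solution $\boldsymbol\theta^\star$ of the ResNet18, i.e., $\hat f_{\boldsymbol\theta}(\boldsymbol x)=(\boldsymbol\theta-\boldsymbol\theta^\star)^\top\nabla_{\boldsymbol\theta}f_{\boldsymbol\theta^\star}(\boldsymbol x)$.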
Fig. 8 compares the dynamics of these models, revealing that the neural network indeed converges much faster than its linear approximation. We see, however, that the kernelized model constructed using the pretrained NTK of the network also converges faster. Since the two linear models differ only in the kernel they use, we conclude that it is indeed the transformation of the kernel through the non-linear dynamics of neural network training that makes these models converge so quickly.
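The claim that the kernel alone can change convergence speed can be illustrated with function-space gradient descent on the squared loss, $\boldsymbol f_{s+1}=\boldsymbol f_s-\eta\,\mathbf K(\boldsymbol f_s-\boldsymbol y)/m$, which is exactly the training dynamics of an unbiased linear model $\boldsymbol f=\mathbf J\,\Delta\boldsymbol\theta$ started at $\Delta\boldsymbol\theta=\boldsymbol 0$ with $\mathbf K=\mathbf J\mathbf J^\top$ and learning rate $\eta$. This NumPy sketch rests on simplifying assumptions: it uses the squared loss rather than the binary cross-entropy of our experiments, synthetic kernels, and a hypothetical "rotated" kernel that keeps the spectrum of the original one but swaps its top and bottom eigenvectors, so that the target moves from the weakest to the strongest eigendirection.

```python
import numpy as np

def gd_residuals(K, y, lr, steps):
    """Residual norms ||f_s - y|| of function-space gradient descent on the
    squared loss, f_{s+1} = f_s - lr * K (f_s - y) / m, starting from f_0 = 0."""
    m = len(y)
    f = np.zeros_like(y)
    norms = []
    for _ in range(steps):
        f = f - lr * K @ (f - y) / m
        norms.append(np.linalg.norm(f - y))
    return np.array(norms)

rng = np.random.default_rng(0)
m = 30
Q, _ = np.linalg.qr(rng.standard_normal((m, m)))   # orthonormal eigenvectors (columns)
spectrum = np.geomspace(1.0, 1e-3, m)               # decreasing eigenvalues of K / m
K_init = m * Q @ np.diag(spectrum) @ Q.T

Q_rot = Q.copy()
Q_rot[:, [0, -1]] = Q[:, [-1, 0]]                  # swap top and bottom eigenvectors
K_rot = m * Q_rot @ np.diag(spectrum) @ Q_rot.T     # same spectrum, rotated eigenbasis

y = np.sqrt(m) * Q[:, -1]                           # target = weakest eigendirection of K_init
res_init = gd_residuals(K_init, y, lr=0.5, steps=100)
res_rot = gd_residuals(K_rot, y, lr=0.5, steps=100)
```

With `K_init` the residual contracts by a factor $1-0.5\cdot10^{-3}$ per step, whereas with `K_rot` it halves every step, even though both kernels have exactly the same eigenvalues.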
Note, however, that the rapid adaptation of the NTK to the training labels can take a heavy toll on generalization: the model based on the pretrained kernel converges much faster than the one based on the kernel at initialization, but has a much lower test accuracy (comparable to that of the neural network). On the other hand, on CIFAR2 (see Sec. 4.1), the kernel rotation greatly improves test accuracy for some models. The fact that it can both boost and hurt generalization highlights that the kernel rotation is governed by its own form of inductive bias. In this sense, we believe that explaining the non-trivial coupling between the kernel rotation, alignment, and training dynamics is an important avenue for future research that will allow us to better understand deep networks.
5 Final remarks
Explaining generalization in deep learning zhangUnderstandingDeepLearning2017 has been the subject of extensive research in recent years neyshaburUnderstandingRoleOverParametrization2018a; soudryImplicitBiasGradient2018; gunasekarCharacterizingImplicitBias2018a; vcPiecewiseNN2019; universalAbbe2020. The NTK theory is part of this trend, and it has been used to prove convergence and generalization of very wide networks jacotNeuralTangentKernel; leeWideNeuralNetworks2020; duGradientDescentFinds; biettiInductiveBiasNeural2019; zouImprovedAnalysisTraining; belkinHessian2020; aroraCNTK2019. Most of these efforts, however, still cannot explain the behavior of the models used in practice. Therefore, multiple authors have proposed to study neural networks empirically, with the aim to identify novel deep learning phenomena which can be later explained theoretically nagarajanUniformConvergenceMay2019; nakkiran2021the; randomLabels2020; fortDeepLearningKernel2020.
In this work, we have followed a similar approach and presented a systematic study comparing the behaviour of neural networks and their linear approximations on different tasks. Previous studies had shown that there exist tasks that neural networks can solve but kernels cannot allen-zhuWhatCanResNet2019; ghorbaniWhenNeuralNetworks2020a; malachQuantifyingBenefitUsing2021; fortDeepLearningKernel2020. Our work complements those results and provides examples of tasks where kernels perform better than neural networks (see Sec. 3.1). This is important because in machine learning there is no free lunch; hence, any sufficiently powerful model for a set of problems must also be limited on other tasks NFLWolpert96. Moving forward, knowing which tasks a neural network can and cannot solve efficiently will be fundamental to explaining its inductive bias.
Similarly, our work complements fortDeepLearningKernel2020, where the authors also empirically compared neural networks and their linear approximations, but focused on a single training task. In this work, we built on these observations and studied models on a more diverse set of problems. In doing so, we have shown that the alignment with the empirical NTK can rank the learning complexity of certain tasks for neural networks, despite being agnostic to the non-linear advantages and disadvantages. We have also revealed that important factors such as sample size, architecture, or target task can greatly influence the gap between kernelized models and neural networks. Finally, we have seen that the evolution of the NTK during training increases its alignment with the target task, thus explaining important phenomena, including some of the previous observations from fortDeepLearningKernel2020.
Finally, a phenomenon similar to the kernel rotation described in Sec. 4.2 was recently observed for small neural networks, suggesting that networks adapt to the main directions of variability of the data manifold ghorbaniWhenNeuralNetworks2020a; paccolatGeometricCompressionInvariant2021a. Our work goes one step further and reveals that the kernel moves in the direction that maximizes the alignment with the training task. We confirmed that this effect is prevalent across the models and training tasks we studied, that it improves training speed, and that it is also subject to its own inductive bias.
Future work should focus on providing a better theoretical understanding to our empirical findings, and try to explain the intricate connections between the evolution of the NTK and the dynamics of training. From a practical perspective, our work paves the way for new applications that use linearized networks to predict generalization in deep learning, or exploit their more predictable inductive bias.
Appendix A General training setup
As mentioned in the main text, all our models are trained using the same scheme, which was selected without any hyperparameter tuning, besides ensuring good performance of the neural networks on CIFAR2. Namely, we train using stochastic gradient descent (SGD) to optimize a binary cross-entropy loss, with a decaying learning rate starting at and momentum set to . Furthermore, we use a batch size of and train for epochs. This is enough to obtain close-to-zero training losses for the neural networks, and to converge to a stable test accuracy in the case of the linearized models (as mentioned in the main text, the linearized models converge significantly more slowly than the neural networks given the same training setup). In fact, in the experiments involving CIFAR2, we train all models for epochs to allow further optimization of the linearized models. Nevertheless, even then, the neural networks perform significantly better than their linear approximations on this dataset.
In terms of models, all our experiments use the same three models: a multilayer perceptron (MLP) with two hidden layers of neurons each, the standard LeNet5 from lecunGradientbasedLearningApplied1998, and the standard ResNet18 resnet. We used a single V100 GPU to train all models, resulting in training times that ranged from minutes for the MLP to around minutes for the ResNet18.
Appendix B NTK computation details
We now provide a few details regarding the computation of different quantities involving neural tangent kernels and linearized neural networks. In particular, in our experiments we make extensive use of the neural_tangents novakNeuralTangentsFast library written in JAX GoogleJax2021, which provides utilities to compute the empirical NTK or construct linearized neural networks efficiently.
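To make the central object concrete, the sketch below computes an empirical NTK Gram matrix, Θ̂(x_i, x_j) = ⟨∇_θ f(x_i; θ), ∇_θ f(x_j; θ)⟩, for a hypothetical one-hidden-layer tanh network with hand-written Jacobians in plain NumPy. It illustrates the quantity that neural_tangents computes but does not use that library; the network, its 1/√width scaling, and all function names are assumptions.

```python
import numpy as np


def init_params(seed, d_in, width):
    """Hypothetical one-hidden-layer network f(x) = a . tanh(W x) / sqrt(width)."""
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((width, d_in)) / np.sqrt(d_in)
    a = rng.standard_normal(width)
    return W, a


def forward(params, X):
    """Scalar outputs for a batch X of shape (n, d_in)."""
    W, a = params
    return np.tanh(X @ W.T) @ a / np.sqrt(a.shape[0])


def jacobian(params, X):
    """Per-sample gradients of the output w.r.t. (W, a), flattened to shape (n, P)."""
    W, a = params
    width = a.shape[0]
    h = np.tanh(X @ W.T)                                   # (n, width)
    # df/dW[k, j] = a[k] * (1 - tanh^2(w_k . x)) * x[j] / sqrt(width)
    d_W = ((1.0 - h**2) * a)[:, :, None] * X[:, None, :] / np.sqrt(width)
    d_a = h / np.sqrt(width)                               # df/da[k]
    return np.concatenate([d_W.reshape(X.shape[0], -1), d_a], axis=1)


def empirical_ntk(params, X1, X2):
    """Gram matrix Theta[i, j] = <grad f(X1[i]), grad f(X2[j])>."""
    return jacobian(params, X1) @ jacobian(params, X2).T
```

Because the Gram matrix is a product J Jᵀ, it is symmetric and positive semi-definite by construction.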
As is common in the kernel literature, in this work we use the eigenvectors of the kernel Gram matrix to approximate the eigenfunctions of the NTK. Specifically, unless stated otherwise, in all our experiments we compute the Gram matrix of the NTK at initialization using the samples of the CIFAR10 dataset, which include both training and test samples. To that end, we use the empirical_kernel_fn from neural_tangents, which allows one to compute this matrix using a batched implementation. Note that this operation is computationally intensive: it scales quadratically with the number of samples, but also quadratically with the number of classes. For instance, in the single-output setting of our experiments (the binary cross-entropy loss can be computed using a single output logit), it takes up to 32 hours to compute the Gram matrix of the full CIFAR10 dataset for a ResNet18 using 4 V100 GPUs with 32 GB of memory each. The computation for the LeNet and the MLP takes only 20 and 3 minutes, respectively, due to their much smaller sizes.
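The NumPy sketch below, with hypothetical helper names, illustrates both points of this paragraph: the eigenvectors of a Gram matrix, sorted by decreasing eigenvalue, serve as sample approximations of the kernel eigenfunctions (up to a √n rescaling if unit empirical L² norm is wanted), and the storage cost of the full matrix grows with the square of (samples × outputs). The byte counts assume dense float32 storage that does not exploit symmetry.

```python
import numpy as np


def gram_eigenfunctions(K):
    """Eigenpairs of a symmetric Gram matrix, sorted by decreasing eigenvalue.

    Column j of `vecs` approximates the j-th kernel eigenfunction evaluated
    on the n samples used to build K (unit Euclidean norm).
    """
    vals, vecs = np.linalg.eigh(K)          # ascending order for symmetric input
    order = np.argsort(vals)[::-1]          # re-sort in decreasing order
    return vals[order], vecs[:, order]


def gram_matrix_bytes(n_samples, n_outputs=1, bytes_per_entry=4):
    """Memory for a dense (n * C) x (n * C) NTK Gram matrix."""
    side = n_samples * n_outputs
    return side * side * bytes_per_entry
```

For the full 60,000 CIFAR10 images (50,000 training plus 10,000 test), `gram_matrix_bytes(60_000)` gives 14,400,000,000 bytes (14.4 GB) for a single output, whereas `gram_matrix_bytes(60_000, n_outputs=10)` gives 1,440,000,000,000 bytes (1.44 TB) for ten outputs.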
Using neural_tangents, training and evaluating a linear approximation of any neural network is trivial. In fact, the library already comes with a function, linearize, which allows one to obtain a fully trainable model from any differentiable JAX function. Thus, in our experiments, we treat all linearized models as standard neural networks and use the same optimization code to train them. In the case of the ResNet18 network, which includes batch normalization layers ioffeBatchNormalizationAccelerating2015, we fix the batch normalization parameters to their initialization values when performing the linearization. This effectively deactivates batch normalization for the linear approximations. Note that fortDeepLearningKernel2020 also compared neural networks with batch normalization to their linear approximations without it.
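The linearization in question is the first-order Taylor expansion in parameter space, f_lin(x; θ) = f(x; θ₀) + ∇_θ f(x; θ₀)(θ − θ₀). The NumPy sketch below spells this out for a hypothetical one-hidden-layer tanh network with a flat parameter vector; it is not the neural_tangents implementation, names such as `linearize_np` are ours, and the model has no batch normalization, so the freezing step described above does not arise.

```python
import numpy as np

D_IN, WIDTH = 3, 16                 # hypothetical sizes
N_PARAMS = WIDTH * D_IN + WIDTH     # flattened W followed by a


def forward(theta, X):
    """f(x; theta) = a . tanh(W x) / sqrt(WIDTH) for a batch X of shape (n, D_IN)."""
    W = theta[:WIDTH * D_IN].reshape(WIDTH, D_IN)
    a = theta[WIDTH * D_IN:]
    return np.tanh(X @ W.T) @ a / np.sqrt(WIDTH)


def jacobian(theta, X):
    """Per-sample gradients df/dtheta, shape (n, N_PARAMS)."""
    W = theta[:WIDTH * D_IN].reshape(WIDTH, D_IN)
    a = theta[WIDTH * D_IN:]
    h = np.tanh(X @ W.T)
    d_W = ((1.0 - h**2) * a)[:, :, None] * X[:, None, :] / np.sqrt(WIDTH)
    d_a = h / np.sqrt(WIDTH)
    return np.concatenate([d_W.reshape(X.shape[0], -1), d_a], axis=1)


def linearize_np(theta0):
    """Return f_lin(theta, X) = f(X; theta0) + J(X; theta0) (theta - theta0)."""
    def f_lin(theta, X):
        return forward(theta0, X) + jacobian(theta0, X) @ (theta - theta0)
    return f_lin
```

The returned `f_lin` is affine in `theta`, so it can be trained with the same optimization loop as the original network while its features stay fixed at θ₀.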
Obtaining the full eigendecomposition of the NTK Gram matrix is computationally intensive and only provides an approximation of the true eigenfunctions. However, in some cases it is possible to compute some of the spectral properties of the NTK more directly, thus circumventing the need to compute the Gram matrix. This is the case, for example, for the target alignment, which, using the formula in Lemma 1, can be computed directly from a weighted average of the Jacobian. This is precisely how we computed the alignment for different eigenfunctions in Sec. 4.2.
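The saving can be illustrated with an empirical version of the alignment. Assuming the alignment of a target f is α(f) = E_{x,x'}[f(x) Θ(x,x') f(x')] (our assumed notation), its sample estimate over n points is (1/n²) yᵀKy with K = JJᵀ, which equals the squared norm of the label-weighted average of the per-sample gradients. The sketch below, with hypothetical function names, checks that the two computations agree; the second needs O(nP) memory instead of an n × n Gram matrix.

```python
import numpy as np


def alignment_from_gram(K, y):
    """Sample estimate (1/n^2) * y^T K y of E[f(x) Theta(x, x') f(x')]."""
    n = y.shape[0]
    return float(y @ K @ y) / n**2


def alignment_from_jacobian(J, y):
    """Same quantity as || (1/n) * sum_i y_i * grad f(x_i) ||^2, without forming K."""
    g = J.T @ y / y.shape[0]       # label-weighted average of per-sample gradients
    return float(g @ g)
```

Here `J` holds one per-sample gradient per row and `y` the target values f(x_i).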
Appendix C Neural anisotropy experiments
c.1 Dataset construction
We used the same linearly separable dataset construction proposed in ortizjimenez2020neural to test the NADs computed using Theorem 1. Specifically, for every NAD we tested, we sampled training samples from where
As in ortizjimenez2020neural, we used a value of and , and tested on another test samples from the same dataset.
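A minimal sketch of such a linearly separable construction, assuming the form x = ε y v + w with y uniform on {−1, +1} and w ~ N(0, σ²(I − v vᵀ)); this is our reading of the construction in ortizjimenez2020neural and should be checked against that work. The values of `eps` and `sigma` are left as arguments, and the function name is hypothetical.

```python
import numpy as np


def sample_separable_dataset(v, n, eps, sigma, seed=0):
    """Draw n samples x = eps * y * v + w, w ~ N(0, sigma^2 (I - v v^T)), y in {-1, +1}."""
    rng = np.random.default_rng(seed)
    v = v / np.linalg.norm(v)                        # work with a unit-norm direction
    y = rng.choice([-1.0, 1.0], size=n)              # labels drawn uniformly at random
    g = sigma * rng.standard_normal((n, v.size))     # isotropic Gaussian noise
    w = g - np.outer(g @ v, v)                       # remove the component along v
    X = eps * y[:, None] * v[None, :] + w
    return X, y
```

Since the noise has no component along v, every sample satisfies y · (vᵀx) = ε, so the direction v separates the classes with margin ε.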
c.2 NAD computation
We provide a few visual examples of the NADs computed for the LeNet and ResNet18 networks using the singular value decomposition (SVD) of the mixed second derivative. Specifically, in our preliminary experiments we did not find any significant difference between the NADs computed using the average over the data and those computed at the origin, and hence opted to perform the SVD of the mixed derivative at the origin instead of its expectation over the data. This was also done in ortizjimenez2020neural, and can be seen as a first-order approximation of the expectation. Furthermore, to avoid the non-differentiability problems of the ReLU activations described in ortizjimenez2020neural, we used GeLU activations gelu in these experiments. A few examples of the NADs computed using this procedure can be found in Figure 9.
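A minimal NumPy sketch of this procedure for a hypothetical one-hidden-layer network with the tanh approximation of GeLU follows. It estimates the mixed derivative ∇²_{x,θ} f_θ(0) by central finite differences, a choice made here for self-containment rather than the automatic differentiation one would use in JAX, and returns its right singular vectors as NADs, ordered by decreasing singular value. Sizes and names are assumptions.

```python
import numpy as np

D_IN, WIDTH = 4, 8                       # hypothetical input size and hidden width
N_PARAMS = WIDTH * D_IN + 2 * WIDTH      # W, b and a, flattened


def gelu(z):
    # tanh approximation of the GeLU activation (smooth, unlike ReLU)
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))


def f(x, theta):
    """Scalar network output f_theta(x) = a . gelu(W x + b)."""
    W = theta[:WIDTH * D_IN].reshape(WIDTH, D_IN)
    b = theta[WIDTH * D_IN:WIDTH * D_IN + WIDTH]
    a = theta[WIDTH * D_IN + WIDTH:]
    return float(a @ gelu(W @ x + b))


def mixed_derivative(theta, x0, h=1e-4):
    """Central finite-difference estimate of d^2 f / (d theta d x) at x0, shape (P, D)."""
    M = np.zeros((theta.size, x0.size))
    for p in range(theta.size):
        dp = np.zeros(theta.size)
        dp[p] = h
        for i in range(x0.size):
            di = np.zeros(x0.size)
            di[i] = h
            M[p, i] = (f(x0 + di, theta + dp) - f(x0 + di, theta - dp)
                       - f(x0 - di, theta + dp) + f(x0 - di, theta - dp)) / (4 * h * h)
    return M


def nads_at_origin(theta):
    """Singular values and NADs (rows of Vt) of the mixed derivative at x = 0."""
    M = mixed_derivative(theta, np.zeros(D_IN))
    _, s, Vt = np.linalg.svd(M, full_matrices=False)
    return s, Vt                          # sorted by decreasing singular value
```

By construction, the first NAD is the unit vector v that maximizes ‖∇²_{x,θ} f_θ(0) v‖.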
Appendix D Deferred proofs
d.1 Proof of Lemma 1
We give here the proof of Lemma 1, which gives a tractable bound on the kernel norm of a function based on its alignment with the kernel. We restate the lemma to ease readability.
Let denote the alignment of the target with the kernel . Then . Moreover, for the NTK, .
Given the Mercer decomposition of , i.e., , the alignment of with can also be written as
Now, recall that for a positive definite kernel, the kernel norm of a function admits the expression
The two quantities can be related using Cauchy-Schwarz to obtain
On the other hand, in the case of the NTK
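In symbols, and under notation assumed here (Θ is the kernel with Mercer eigensystem {(λ_j, φ_j)} orthonormal in L²(D) and λ_j > 0, ⟨f, φ_j⟩ = E_x[f(x) φ_j(x)], α(f) is the alignment, and f lies in the span of the φ_j), the steps of this proof can be sketched as follows; the last line is the NTK identity that lets α(f) be computed from a weighted average of the Jacobian.

```latex
\begin{align}
\alpha(f) &= \mathbb{E}_{x,x'}\big[f(x)\,\Theta(x,x')\,f(x')\big]
           = \sum_j \lambda_j \langle f,\phi_j\rangle^2,
  \qquad \|f\|_\Theta^2 = \sum_j \frac{\langle f,\phi_j\rangle^2}{\lambda_j}, \\
\|f\|_2^2 &= \sum_j \Big(\sqrt{\lambda_j}\,\langle f,\phi_j\rangle\Big)
                   \Big(\frac{\langle f,\phi_j\rangle}{\sqrt{\lambda_j}}\Big)
           \le \sqrt{\alpha(f)}\,\|f\|_\Theta
  \quad\Longrightarrow\quad \|f\|_\Theta \ge \frac{\|f\|_2^2}{\sqrt{\alpha(f)}}, \\
\alpha(f) &= \mathbb{E}_{x,x'}\big[f(x)\,\nabla_\theta f_\theta(x)^\top \nabla_\theta f_\theta(x')\,f(x')\big]
           = \big\|\mathbb{E}_x\big[f(x)\,\nabla_\theta f_\theta(x)\big]\big\|_2^2
  \quad\text{(NTK)}.
\end{align}
```

The inequality in the second line is Cauchy-Schwarz applied to the sequences √λ_j ⟨f, φ_j⟩ and ⟨f, φ_j⟩/√λ_j.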
d.2 Proof of Theorem 1
We give here the proof of Theorem 1 which gives a closed form expression of the alignment of a linear predictor with the NTK. We restate the theorem to ease readability.
Let be a unit vector that parameterizes a linear predictor , and let . The alignment of with is given by
where denotes the derivative of with respect to the weights and the input.
Plugging the definition of a linear predictor into the expression of the alignment for the NTK (see Lemma 1), we get
which, using Stein’s lemma, becomes
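Under the assumptions x ~ N(0, I_D) and a unit vector v defining the linear predictor f_v(x) = vᵀx (our assumed notation), the two steps read as follows, where ∇²_{x,θ} f_θ(x) ∈ ℝ^{P×D} collects the entries ∂²f_θ(x)/∂θ_p∂x_i.

```latex
\begin{align}
\alpha(f_v) &= \big\|\mathbb{E}_x\big[(v^\top x)\,\nabla_\theta f_\theta(x)\big]\big\|_2^2
             = \big\|\mathbb{E}_x\big[\nabla_\theta f_\theta(x)\,x^\top\big]\,v\big\|_2^2, \\
\mathbb{E}_x\big[x\,g(x)\big] &= \mathbb{E}_x\big[\nabla_x g(x)\big]
  \quad\text{for } x\sim\mathcal{N}(0, I_D)
  \ \text{(Stein's lemma, applied to } g = \partial f_\theta/\partial\theta_p), \\
\alpha(f_v) &= \big\|\mathbb{E}_x\big[\nabla^2_{x,\theta} f_\theta(x)\big]\,v\big\|_2^2 .
\end{align}
```

Consequently, the unit vector that maximizes α(f_v) is the top right singular vector of E_x[∇²_{x,θ} f_θ(x)], which is why the NADs in Appendix C are obtained from an SVD of this matrix.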
Appendix E Additional results
We now provide the results of some experiments which complement the findings in the main text.
e.1 Learning NTK eigenfunctions
In Section 3.2, we saw that neural networks and their linear approximations share the way in which they rank the complexity of learning different NTK eigenfunctions. However, these experiments were performed using a single training setup and a single data distribution. For this reason, we now provide two complementary sets of experiments that highlight the generality of the previous results.
On the one hand, we repeated the experiment in Section 3.2 using a different training strategy, replacing SGD with the popular Adam adam optimization algorithm. As we can see in Figure 10, the main findings of Section 3.2 also transfer to this different training setup. In particular, the performance of all models progressively decays with increasing eigenfunction index, and the linearized models have a clear linear advantage over the non-linear neural networks.
On the other hand, we also repeated the same experiments with a different underlying data distribution: instead of the CIFAR10 samples, we used the MNIST lecunGradientbasedLearningApplied1998 digits. The results in Figure 11 again show the same tendency. (Since the MNIST dataset has more samples than CIFAR10, i.e.,  samples, the quadratic complexity of the Gram matrix computation and budgetary constraints led us not to perform this experiment on ResNet18.) However, we now see that, for the LeNet, the accuracy curves of the neural network and the linearized model cross around , highlighting that whether a linear or non-linear advantage exists greatly depends on the target task.
e.1.1 Convergence speed of other networks
We also provide the collection of training metrics of the MLP (see Figure 12) and the LeNet (see Figure 13) trained on the different eigenfunctions of . Again, as was the case for the ResNet18, training is “harder” for the eigenfunctions corresponding to the smaller eigenvalues, as both the time to reach a low training loss and the distance to the weight initialization grow with the eigenfunction index.
e.2 Energy concentration of CIFAR2 labels
We have seen that training a network on a task greatly increases the energy concentration of the training labels with the final kernel. In the main text, however, we only provided the results for the evolution of the energy concentration, i.e., , for the LeNet trained on CIFAR2. Table 1 shows the complete results for the other networks, clearly showing an increase in the energy concentration at the end of training.
| Energy conc. (init) | 26.0% | 25.8% | 22.7% |
| Energy conc. (end) | 63.6% | 83.0% | 96.7% |
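A minimal sketch of an energy-concentration measure, assuming it is defined as the fraction of the label energy ‖y‖² captured by the top-k eigenvectors of the Gram matrix, i.e., Σ_{j≤k} ⟨y, φ_j⟩² / ‖y‖²; the choice of k and the exact normalization used in the main text may differ, and the function name is hypothetical.

```python
import numpy as np


def energy_concentration(K, y, k):
    """Fraction of ||y||^2 lying in the span of the top-k eigenvectors of K."""
    vals, vecs = np.linalg.eigh(K)                 # K symmetric PSD
    top = vecs[:, np.argsort(vals)[::-1][:k]]      # columns: k leading eigenvectors
    coeffs = top.T @ y                             # projections <y, phi_j>
    return float(coeffs @ coeffs) / float(y @ y)
```

The value lies in [0, 1], is non-decreasing in k, and equals 1 whenever the labels lie entirely in the span of the k leading eigenvectors.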
e.3 Increase of alignment when learning NTK eigenfunctions
Finally, we provide the final alignment plots of the networks trained to predict eigenfunctions other than the one shown in the main text. As we can see in Figure 14 and Figure 15, the relative alignment of the training eigenfunction spikes at the end of training, demonstrating the single-axis rotation of the empirical NTK during training.