Neural Networks can Learn Representations with Gradient Descent

by Alex Damian, et al.

Significant theoretical work has established that in specific regimes, neural networks trained by gradient descent behave like kernel methods. However, in practice, it is known that neural networks strongly outperform their associated kernels. In this work, we explain this gap by demonstrating that there is a large class of functions which cannot be efficiently learned by kernel methods but can be easily learned with gradient descent on a two-layer neural network outside the kernel regime by learning representations that are relevant to the target task. We also demonstrate that these representations allow for efficient transfer learning, which is impossible in the kernel regime. Specifically, we consider the problem of learning polynomials which depend on only a few relevant directions, i.e. of the form f^⋆(x) = g(Ux) where U: ℝ^d → ℝ^r with d ≫ r. When the degree of f^⋆ is p, it is known that n ≍ d^p samples are necessary to learn f^⋆ in the kernel regime. Our primary result is that gradient descent learns a representation of the data which depends only on the directions relevant to f^⋆. This results in an improved sample complexity of n ≍ d^2 r + d r^p. Furthermore, in a transfer learning setup where the data distributions in the source and target domain share the same representation U but have different polynomial heads, we show that a popular heuristic for transfer learning has a target sample complexity independent of d.
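To make the setting above concrete, the following is a minimal Python sketch, not the authors' code. It builds one target of the form f^⋆(x) = g(Ux) and compares the two sample-complexity scalings quoted in the abstract. Everything specific here is an assumption chosen for illustration: the helper names, the orthonormal rows of U, the particular degree-3 head g, and the values d = 100, r = 2, p = 3. The scalings are evaluated with constants and logarithmic factors ignored, so the numbers indicate orders of magnitude only.

```python
import numpy as np


def kernel_regime_samples(d, p):
    """Scaling n ~ d^p quoted for the kernel regime (constants ignored)."""
    return d ** p


def representation_learning_samples(d, r, p):
    """Scaling n ~ d^2 r + d r^p quoted for gradient descent (constants ignored)."""
    return d ** 2 * r + d * r ** p


def make_target(d, r, rng):
    """Build a hypothetical target f_star(x) = g(U x) with U of shape (r, d)."""
    # Assumption: U has orthonormal rows, obtained from a reduced QR factorization.
    q, _ = np.linalg.qr(rng.standard_normal((d, r)))
    U = q.T

    def g(z):
        # Hypothetical degree-3 polynomial head acting on the r = 2 coordinates.
        return z[..., 0] ** 3 - 3.0 * z[..., 0] + z[..., 0] * z[..., 1]

    def f_star(x):
        return g(x @ U.T)

    return U, f_star


d, r, p = 100, 2, 3
rng = np.random.default_rng(0)
U, f_star = make_target(d, r, rng)

# f_star ignores every direction orthogonal to the rows of U.
x = rng.standard_normal(d)
w = rng.standard_normal(d)
v = w - U.T @ (U @ w)  # component of w orthogonal to the relevant subspace
unchanged = bool(np.isclose(f_star(x), f_star(x + v)))

n_kernel = kernel_regime_samples(d, p)
n_gd = representation_learning_samples(d, r, p)
print(unchanged, n_kernel, n_gd)
```

With these illustrative values the kernel-regime scaling is d^p = 1,000,000, while d^2 r + d r^p = 20,800, which shows how the gap opens up when d ≫ r.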




