Transformers learn in-context by gradient descent

12/15/2022
by Johannes von Oswald, et al.

Transformers have become the state-of-the-art neural network architecture across numerous domains of machine learning. This is partly due to their celebrated ability to transfer and to learn in-context from few examples. Nevertheless, the mechanisms by which Transformers become in-context learners are not well understood and remain largely a matter of intuition. Here, we argue that training Transformers on auto-regressive tasks can be closely related to well-known gradient-based meta-learning formulations. We start by providing a simple weight construction that shows the equivalence of the data transformations induced by 1) a single linear self-attention layer and by 2) gradient descent (GD) on a regression loss. Motivated by that construction, we show empirically that when training self-attention-only Transformers on simple regression tasks, either the models learned by GD and by the Transformer show great similarity or, remarkably, the weights found by optimization match the construction. Thus we show how trained Transformers implement gradient descent in their forward pass. This allows us, at least in the domain of regression problems, to mechanistically understand the inner workings of optimized Transformers that learn in-context. Furthermore, we identify how Transformers surpass plain gradient descent by an iterative curvature correction, and how they learn linear models on deep data representations to solve non-linear regression tasks. Finally, we discuss intriguing parallels to a mechanism identified as crucial for in-context learning, termed the induction head (Olsson et al., 2022), and show how it could be understood as a specific case of in-context learning by gradient descent within Transformers.
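The core weight construction can be checked numerically in a few lines. The sketch below is a minimal NumPy illustration under my own assumptions (it is not the paper's code, and it omits the paper's sign convention, in which the query token's label slot accumulates the negated prediction): the key/query matrices compare only the input parts of the tokens, the value matrix writes the scaled labels, and the resulting linear self-attention update to the query token's label coordinate coincides with the prediction obtained from one gradient-descent step on the in-context regression loss, starting from zero weights.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, eta = 5, 20, 0.1

# In-context linear regression task: scalar targets y_i = w* . x_i
w_star = rng.normal(size=d)
X = rng.normal(size=(n, d))
y = X @ w_star
x_q = rng.normal(size=d)          # query input

# --- 1) One GD step on L(W) = 1/(2n) sum_i (W x_i - y_i)^2, from W0 = 0 ---
# dW = (eta / n) * sum_i y_i x_i ; prediction on the query is dW . x_q
dW = (eta / n) * (y @ X)
y_gd = dW @ x_q

# --- 2) Linear self-attention with the constructed weights ---
# Tokens e_i = (x_i, y_i); the query token carries (x_q, 0).
E = np.hstack([X, y[:, None]])                  # context tokens, shape (n, d+1)
e_q = np.concatenate([x_q, [0.0]])              # query token

# Key/query product reads out only the x-part of each token;
# the value matrix writes (eta / n) * y_i into the last (label) coordinate.
W_KQ = np.zeros((d + 1, d + 1)); W_KQ[:d, :d] = np.eye(d)
W_V  = np.zeros((d + 1, d + 1)); W_V[d, d] = eta / n

# Un-normalised (linear) attention update of the query token:
#   e_q <- e_q + sum_i (W_V e_i) * (e_i^T W_KQ e_q)
scores = E @ W_KQ @ e_q                         # s_i = x_i . x_q
attn = (E @ W_V.T).T @ scores                   # sum_i s_i * (W_V e_i)
y_attn = (e_q + attn)[d]                        # read out the label coordinate

print(np.allclose(y_gd, y_attn))                # True: the two predictions agree
```

In the paper this equivalence is then stacked: deeper self-attention layers correspond to further gradient-descent steps, which is what makes the trained-Transformer experiments interpretable as learned optimization.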


Related research

04/26/2023
The Closeness of In-Context Learning and Weight Shifting for Softmax Regression

06/01/2023
Transformers learn to implement preconditioned gradient descent for in-context learning

07/07/2023
One Step of Gradient Descent is Provably the Optimal In-Context Learner with One Layer of Linear Self-Attention

09/11/2023
Uncovering mesa-optimization algorithms in Transformers

09/04/2023
Gated recurrent neural networks discover attention

11/28/2022
What learning algorithm is in-context learning? Investigations with linear models

08/14/2023
CausalLM is not optimal for in-context learning
