How is it that GPT performs better at question-answering tasks when you first prompt it with a series of positive examples? In 2020, in the title of the original GPT-3 paper, OpenAI claimed that language models are few shot learners. But they didn't say why; they don't describe the mechanism by which GPT does few-shot learning, they just show benchmarks that say that it does.
Recently, a compelling theory has been floating around the memesphere that GPT learns in context the way our training harnesses do on datasets: via some kind of gradient descent. Except, where our training harnesses do gradient descent on the weights of the model, updating them once per training step, GPT performs gradient descent on the activations of the model, updating them with each layer. This would be big if true! Finally, an accidental mesa-optimizer in the wild.
Recently, I read two papers about gradient descent in activation space. I was disappointed by the first, and even more disappointed by the second. In this post, I'll explain why.
This post is targeted at my peers; people who have some experience in machine learning and are curious about alignment and interpretability. I expect the reader to be at least passingly familiar with the mathematics of gradient descent and mesa-optimization. There will be equations, but you should be able to mostly ignore them and still follow the arguments. You don't need to have read either of the papers discussed in this post to enjoy the discussion, but if my explanation isn't doing it for you the one in the paper might be better.
Thank you to the members of AI Safety 東京 for discussing this topic with me in-depth, and for giving feedback on early drafts of this post.
What is activation space gradient descent?
We normally think of gradient descent as a loop, like this:
But we can unroll the loop, revealing an iterative structure: you start with some initial weights, then via successive applications of gradient descent obtain a series of new weights:
You know what else has an iterative structure? A neural network!
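To make the structural analogy concrete, here is a minimal numpy sketch (mine, not from either paper): gradient descent written as a loop, the same loop unrolled into a fixed sequence of updates, and a toy forward pass, which has the same shape except that the evolving state is activations rather than weights.

```python
import numpy as np

# Gradient descent as a loop: the same update rule applied over and over.
def gd_loop(w, grad_fn, lr=0.1, steps=3):
    for _ in range(steps):
        w = w - lr * grad_fn(w)       # w_{i+1} = w_i - lr * grad L(w_i)
    return w

# The same computation, unrolled: a fixed sequence of identical update steps.
def gd_unrolled(w0, grad_fn, lr=0.1):
    w1 = w0 - lr * grad_fn(w0)
    w2 = w1 - lr * grad_fn(w1)
    w3 = w2 - lr * grad_fn(w2)
    return w3

# A (toy) network forward pass has the same shape: a fixed sequence of
# transformations applied to an evolving state -- here activations, not weights.
def forward(x, layers):
    for layer in layers:
        x = layer(x)                  # x_{i+1} = layer_i(x_i)
    return x

# Example: minimise L(w) = ||w - 3||^2, so grad L(w) = 2 * (w - 3).
grad_fn = lambda w: 2 * (w - 3.0)
w0 = np.array([0.0])
assert np.allclose(gd_loop(w0, grad_fn), gd_unrolled(w0, grad_fn))
```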
Maybe GPT does in-context learning by treating its activations as weights of some model, using its layers to perform a series of iterative updates to those weights. Then, perhaps in the final layer, it would run the trained model on some data to make predictions. More concretely, when you feed GPT an in-context learning problem like this (prompt in plain text, completion in bold):
What is the capital of France? Paris
What is the capital of England? London
What is the capital of Spain? Madrid
What is the capital of Germany? Berlin
GPT does the following steps:
construct some representation of a model and loss function in activation space, based on the training examples in the prompt
train the model on the loss function by applying an iterative update to the weights with each layer
execute the model on the test query in the prompt
decode the model's response into text
This would be a really cool thing for GPT to be doing! Not only would it explain how GPT does in-context learning (which is currently mostly mysterious), but it would be a very clear example of a mesa-optimizer—a model discovered during training, that itself optimizes an objective other than the training objective. And an important example, too - looking at GPT's architecture you wouldn't expect it to be doing optimization at all!
My questions about this theory are:
How, mathematically / mechanically, does GPT do gradient descent? What models does it train on which loss functions? How does it represent them as activations?
How does GPT determine what the loss is from the natural language prompt?
Does GPT do gradient descent for every prompt? What prompts cause GPT to do the gradient descent thing?
Why does GPT do mesa-optimization, and not something else? What is it about the next-word-prediction training objective that causes mesa-optimization? Or is mesa-optimization a byproduct of RLHF?
Let's take a look at two papers about this and see how many of these questions we can answer.
Transformers Learn in Context by Gradient Descent (van Oswald et al. 2022)
This was my reaction after skimming the intro / results:
Blaine: this is a very exciting paper indeed
Anon: "Exciting" in a "oh my god I am panicking"-kind of way 🥲
Blaine: nah, exciting in a "finally the mesa-optimizer people have something to poke at" kind of way
Blaine: they show a weight construction of transformers that does gradient descent in activation space, then show that the transformer training procedure actually does find this construction in practice
Blaine: kinda the flip of Zhang et al. 2022, which demonstrates a weight construction for transformers that does correct logical inference, then shows that gradient descent does not in practice find such a construction
Blaine: I would have thought that gradient descent and logical inference were equally difficult problems, so I'm surprised that one is in practice learned from data and the other isn't
In retrospect, my surprise was justified - this paper isn't claiming what I thought it was claiming, and it's not nearly as conclusive as one would think from a skim read. That being said, I still applaud von Oswald et al.; this is good interpretability, and I'll follow future work with great interest.
Why am I disappointed?
I thought this paper was going to tell me how GPT does few-shot learning. In my defence, you can see how I would think that from a skim read of the abstract:
Transformers have become the state-of-the-art neural network architecture across numerous domains of machine learning. This is partly due to their celebrated ability to transfer and to learn in-context based on few examples. Nevertheless, the mechanisms by which Transformers become in-context learners are not well understood and remain mostly an intuition. Here, we argue that training Transformers on auto-regressive [here, my eyes glaze over; I trust that by reading the paper I'll learn what this jargon salad means] [...] Thus we show how trained Transformers implement gradient descent in their forward pass. This allows us, at least in the domain of regression problems, to mechanistically understand the inner workings of optimized Transformers that learn in-context. [...]
GPT is an optimized transformer that learns in context! It's the optimized transformer that learns in context!
But it turns out that the jargon salad was very important. This paper is not interested in explaining large language models like GPT. Instead, von Oswald et al. focus on small (usually one-layer) models trained on toy regression problems:
We now introduce [...] a training dataset $D = \{(x_i, y_i)\}_{i=1}^{N}$ comprising of input samples $x_i \in \mathbb{R}^{N_x}$ and respective labels $y_i \in \mathbb{R}^{N_y}$. [...] we consider an in-context learning problem where we are given $N$ context tokens together with an extra query token, indexed by $N+1$. In terms of our linear regression problem, the $N$ context tokens $e_j = (x_j, y_j) \in \mathbb{R}^{N_x + N_y}$ correspond to the $N$ training points in $D$, and the $N+1$-th token $e_{N+1} = (x_{N+1}, y_{N+1}) = (x_{\text{test}}, \hat{y}_{\text{test}}) = e_{\text{test}}$ to the test input $x_{\text{test}}$ and the corresponding prediction $\hat{y}_{\text{test}}$.
i.e. where an in-context learning problem for GPT might look like this:
What is the capital of France? Paris
What is the capital of England? London
What is the capital of Spain? Madrid
What is the capital of Germany? Berlin
the in-context learning problems van Oswald et al. consider look like this:
Each pair of numbers is treated as a single token[1]; this representation is therefore very natural for autoregressive transformers, whose whole game is next token prediction. Perhaps in response to reviewer comments, van Oswald et al. note that this doesn't quite match the traditional in-context learning framing; notice that in GPT's problem, the query is presented as part of the text stream, and the model is only asked to predict the answer, whereas in van Oswald et al.'s formulation the query and answer are part of the same token. Towards the end of the paper, they reframe the prediction task to look like this:
where each (x,y) pair is presented as a sequence of two tokens, and the model has to learn that they are associated pairs. They demonstrate that this doesn't really impact their argument.
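For concreteness, here is a hypothetical sketch of what one of these in-context regression problems looks like as a sequence of tokens; the dimensions and distributions are my own arbitrary choices, not necessarily the paper's.

```python
import numpy as np

rng = np.random.default_rng(0)
N_x, N_y, N = 10, 1, 10            # input dim, label dim, number of context points

# Each in-context "task" is its own random linear regression problem.
W_task = rng.normal(size=(N_y, N_x))
X = rng.uniform(-1, 1, size=(N, N_x))
Y = X @ W_task.T                                   # noiseless targets y_j = W_task x_j

# Context tokens: each (x_j, y_j) pair concatenated into one token of size N_x + N_y.
context_tokens = np.concatenate([X, Y], axis=1)    # shape (N, N_x + N_y)

# Query token: the test input with its label slot zeroed out; the model's job is
# to fill in the prediction for this slot.
x_test = rng.uniform(-1, 1, size=(1, N_x))
query_token = np.concatenate([x_test, np.zeros((1, N_y))], axis=1)

prompt = np.concatenate([context_tokens, query_token], axis=0)   # (N + 1, N_x + N_y)
print(prompt.shape)   # (11, 11): eleven tokens, each an eleven-dimensional vector
```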
I take a different issue with the framing. The GPT in-context learning task is mostly one of problem identification and recall; the model already "knows" that Berlin is the capital of Germany; the job of the prompt is to get the model to realize that it is being asked to perform a truthful question-answering task. The key issue in zero-shot / few-shot learning is that questions are ambiguous! Without context, all of these are good continuations:
What is the capital of Germany? Berlin
What is the capital of Germany? What is the capital of Sweden? What is the capital of Italy?
What is the capital of Germany? Who cares! Geography is for nerds.
The job of a few-shot / zero-shot learning system is to learn the human prior over problem-space, such that you can answer the "right" question among a selection of equally plausible candidates.
But the tokenized-regression-dataset framing lacks this important quality! The whole training dataset is contained in the prompt, and the question the model answers is totally unambiguous. Further, the model is specifically trained to perform whole-dataset regression tasks. This doesn't at all match how GPT is trained! If GPT does in-context learning, it does so by accident. Nobody at OpenAI was trying to build a few-shot learner—they were trying to build a next-word predictor, and the interesting thing is that they got a few-shot learner for free. In contrast, van Oswald et al.'s model is very specifically and intentionally trained to do many-shot in-context learning.
But even more than that, the most surprising part of the "language models do in-context learning by gradient descent" theory is that "What is the capital of Germany? Berlin" does not look like a problem that can be solved by gradient descent. In order to solve it by gradient descent, one first has to project it into some mathematical framing, and the details of this projection would be super interesting! That's what I came here to find, and I'm sad that I didn't.
On Linearity
But you know what? This is all a fuss about nothing. This paper doesn't teach me anything about whether or not GPT does few-shot learning by gradient descent, and that's fine. They didn't set out to prove that GPT does few-shot learning by gradient descent; they want to show that transformers do in-context learning by gradient descent. Let's meet the paper where it's at and see if it excels on its own terms.
Blaine: as always, on a closer read this paper is much less exciting
Blaine: their results only hold for linear self-attention—self attention that's equivalent to a single matrix multiplication, and only for linear regression problems. I cannot overstate how much mileage they get out of assuming that everything is linear; if they use softmax attention (the kind that everyone uses) they get much less convincing results
Anon: Is that underwhelming result because they were unable to try the same thing with 'softmax attention' (but it might work if it could be tried)?
Blaine: they tried it for softmax attention and it didn't work
Blaine: this is the relevant figure - if the single-layer softmax attention was doing gradient descent, the green triangles and blue crosses in the top row would be on top of each other (as they are in the bottom row)
Blaine: in contrast, this is the figure for a single linear layer
Linearity is a tricky concept to grasp unless you're the kind of person who reads mathematics papers for fun. If you are that kind of person, alarm bells should already be ringing. If you're not, then settle in while I tell you a story.
Are you sitting comfortably?
Good, then I'll begin.
Story Time
Once upon a time, I was working for a self-driving car company. A self-driving car needs to be able to perceive the road around it, and we did this using a bunch of machine learning systems. When transformers became a Big Deal, we tried to replace some of our perception systems (which were mostly CNNs, the king whose throne transformers usurped) with attention-based systems. But it was really hard! An average paragraph of text contains maybe 200 tokens. A 128-channel lidar scan has perhaps 128×3600 ≈ 460 thousand points. A full HD image has 1920×1080 ≈ 2 million pixels. Non-linear self-attention is O(n²) in the number of input tokens[2], and we needed to run our perception systems at least 10 times a second. Clearly, we couldn't just fling all our bytes into a Perceiver IO and call it a day.
Fortunately for us, other people had noticed the problem, and there was a huge wealth of literature on efficient transformer alternatives. This is a good survey paper; it's the one that we used. The most appealing approaches involved linearization. Efficient Attention (Shen et al. 2020) is a central example; they show that if you replace the non-linear softmax with a linear similarity function, then swap a few matrix multiplications around, you can avoid computing a huge matrix of intermediate values, bringing the complexity down from O(n²) to O(n). And all without hurting performance!
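The heart of the linearization trick is just associativity of matrix multiplication: once the softmax is gone, you can multiply keys and values together first and never materialize the n×n attention matrix. A minimal sketch (this drops normalization entirely, so it is not Shen et al.'s exact formulation):

```python
import numpy as np

n, d = 1000, 64                       # sequence length, head dimension
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))

# Standard (softmax-free) attention: materializes an n x n matrix.
quadratic = (Q @ K.T) @ V             # O(n^2 * d) time, O(n^2) memory

# Linearized attention: reassociate the product so the n x n matrix never exists.
linear = Q @ (K.T @ V)                # O(n * d^2) time, O(d^2) memory

assert np.allclose(quadratic, linear)   # identical result, very different cost
```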
Efficient attention achieves substantially better performance-cost trade-off. As rows res3 to fpn5 show, inserting an efficient attention module or a non-local module at the same location in a network has nearly identical effects on the performance, while efficient attention uses orders of magnitude less resources.
But try as we might, we could not replicate these results. It wasn't even that the linearized attention models performed worse than their non-linear counterparts; they didn't perform at all, even in toy problems[3]. And it wasn't just Shen et al. 2020; we had the same trouble with Rethinking Attention with Performers (Choromanski, Likhosherstov, Dohan, Song, Gane, Sarlos, Hawkins and Davis et al 2022), same with Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention (Katharopoulos et al. 2020).
In retrospect, we should have expected this; it's well known that non-linearity is a necessary part of neural networks' success. We ended up taking a different approach that retained non-linearity, exploiting the problem's geometric properties to reduce the size of the context window. But the experience of banging my head repeatedly against the linearity wall has left me with a deep suspicion. If you want to claim that your results generalize from linear to non-linear models, I'm going to make you work for it.
Did van Oswald et al. do the work?
There's a lot to like in this paper. The mathematical presentation is clear and novel, and their experimental results mostly support their claims.
The core of the paper is a delicious mathematical trick. By rearranging the equation for gradient descent, you can think of a step of gradient descent as being an update to the data, rather than an update to the weights. We usually think of the gradient descent algorithm like this:
1. randomly initialize your weights $W_0 \sim \mathcal{N}(0, 1)$
2. calculate $\Delta W_i = -\eta \nabla_{W_i} L(W_i, X, Y)$, a weight update corresponding to a tiny step down the gradient of the loss with respect to the weights
3. update your weights $W_{i+1} = W_i + \Delta W_i$
4. repeat 2 and 3 until convergence or you get bored
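In code, this is the familiar loop; a minimal numpy version for a linear model with L2 loss (the setting the paper cares about), with my own toy data:

```python
import numpy as np

rng = np.random.default_rng(0)
N_x, N_y, N = 10, 1, 50
X = rng.normal(size=(N_x, N))               # inputs, one column per example
Y = rng.normal(size=(N_y, N_x)) @ X         # labels from a random linear teacher

def grad(W, X, Y):
    """Gradient of the L2 loss L(W) = 1/(2N) * ||W X - Y||^2 with respect to W."""
    return (W @ X - Y) @ X.T / X.shape[1]

eta, steps = 0.1, 100
W = rng.normal(size=(N_y, N_x))             # 1. randomly initialize the weights
for _ in range(steps):                      # 4. repeat...
    dW = -eta * grad(W, X, Y)               # 2. weight update down the gradient
    W = W + dW                              # 3. update the weights
```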
They show that, for linear models (the proof does not hold for non-linear models), this is precisely equivalent to the following algorithm:
1. randomly initialize your weights as before $W \sim \mathcal{N}(0, 1)$, and set $Y_0 = Y$
2. calculate $\Delta Y_i = -(\Delta W) X$, where $\Delta W = -\eta \nabla_W L(W, X, Y_i)$. This is a "data update" corresponding to moving the training labels a tiny step in the direction of the outputs of the model given by the random weights.
3. update your data $Y_{i+1} = Y_i + \Delta Y_i$
4. repeat 2 and 3 until convergence or you get bored
If we take this dual approach, we can get predictions on held-out data by adding a test point $(x^\star, -W x^\star)$ and keeping track of the data updates $\Delta y^\star$ at each training iteration. At the end of training we'll have a point $(x^\star, -W x^\star + \sum \Delta y^\star)$, and by linearity of matrix multiplication we have
$$y^\star = \left(W + \sum \Delta W\right) x^\star = W x^\star + \sum (\Delta W)\, x^\star = W x^\star - \sum \Delta y^\star = -\left(-W x^\star + \sum \Delta y^\star\right)$$
i.e. we can recover $y^\star$ by taking the negative of our test point's final label coordinate.
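Here is a numpy sketch checking both claims on a toy problem: the data-update procedure tracks ordinary gradient descent exactly, and the held-out prediction is recovered by negating the test point's label coordinate. This is my own code following the post's description; note that W stays fixed at its random initialization throughout the dual procedure.

```python
import numpy as np

rng = np.random.default_rng(0)
N_x, N_y, N = 10, 1, 50
X = rng.normal(size=(N_x, N))
Y = rng.normal(size=(N_y, N_x)) @ X              # labels from a random linear teacher
x_star = rng.normal(size=(N_x, 1))               # held-out test input

def grad(W, X, Y):
    return (W @ X - Y) @ X.T / X.shape[1]        # grad of 1/(2N) ||WX - Y||^2

eta, steps = 0.1, 100
W0 = rng.normal(size=(N_y, N_x))

# Ordinary gradient descent: update the weights, keep the data fixed.
W = W0.copy()
for _ in range(steps):
    W = W - eta * grad(W, X, Y)
y_star_gd = W @ x_star                           # prediction of the trained model

# Dual form: keep the weights fixed at W0, update the labels instead.
Yi = Y.copy()
y_test = -W0 @ x_star                            # test point starts at (x*, -W0 x*)
for _ in range(steps):
    dW = -eta * grad(W0, X, Yi)                  # gradient at the *fixed* weights
    Yi = Yi - dW @ X                             # data update  Y_{i+1} = Y_i - dW X
    y_test = y_test - dW @ x_star                # same update applied to the test token

y_star_dual = -y_test                            # recover the prediction by negation
assert np.allclose(y_star_gd, y_star_dual)
```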
Importantly, at no point do we have to keep track of the weights of the model. We can do training and inference, simultaneously, just by making iterative updates to the training data. This is very convenient for us, because transformers work by doing iterative updates on their input data. A transformer maintains a residual stream for each token in its context window, and each layer of the transformer updates the latent-space representation of each token. Updating the data instead of the weights is the natural way for transformers to behave!
Van Oswald et al. show that, if you remove the pesky non-linearity and rearrange the matrix multiplications a little bit, you can parameterize a self-attention layer so that it does one step of gradient descent as in the procedure above. Importantly, not all parameterizations work; the value-projection and key-query products have to be of a specific form. So even though in theory it's possible for transformers to do this kind of gradient descent (just as in theory it's possible for any two-layer network to arbitrarily closely approximate any function R→R), it remains to be seen whether the training procedure finds such a parameterization in practice.
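To make "parameterize a self-attention layer so that it does one step of gradient descent" concrete, here is a sketch in the spirit of the paper's construction. The exact key, query, value and projection matrices below are my reconstruction, not necessarily the paper's verbatim choices, but the code checks numerically that this parameterization reproduces the gradient-descent data update from the dual form above.

```python
import numpy as np

rng = np.random.default_rng(0)
N_x, N_y, N = 10, 1, 50
eta = 0.1
d = N_x + N_y                                       # token dimension

X = rng.normal(size=(N_x, N))
Y = rng.normal(size=(N_y, N_x)) @ X
E = np.concatenate([X, Y], axis=0)                  # tokens e_j = (x_j, y_j), shape (d, N)

# One step of gradient descent from W0 = 0, expressed as a data update.
W0 = np.zeros((N_y, N_x))
dW = -eta * (W0 @ X - Y) @ X.T / N                  # weight update
dY_gd = -dW @ X                                     # induced label update  Δy_j = -ΔW x_j

# A linear self-attention (LSA) layer: e_j <- e_j + P (W_V E) (W_K E)^T (W_Q e_j).
# One parameterization that reproduces the GD data update (assumed form, not the
# paper's verbatim matrices):
W_K = W_Q = np.block([[np.eye(N_x),          np.zeros((N_x, N_y))],
                      [np.zeros((N_y, N_x)), np.zeros((N_y, N_y))]])
W_V = np.block([[np.zeros((N_x, N_x)),       np.zeros((N_x, N_y))],
                [W0,                         -np.eye(N_y)        ]])
P = eta / N * np.eye(d)

dE_lsa = P @ (W_V @ E) @ (W_K @ E).T @ (W_Q @ E)

# The attention update touches only the label part of each token, and it matches
# the gradient-descent data update exactly.
assert np.allclose(dE_lsa[:N_x], 0.0)
assert np.allclose(dE_lsa[N_x:], dY_gd)

# The same update applied to a query token (x*, 0) fills its label slot with one
# GD step's prediction on x*, read out with a sign flip as in the dual form.
x_star = rng.normal(size=(N_x, 1))
e_star = np.concatenate([x_star, np.zeros((N_y, 1))], axis=0)
de_star = P @ (W_V @ E) @ (W_K @ E).T @ (W_Q @ e_star)
assert np.allclose(-(e_star + de_star)[N_x:], (W0 + dW) @ x_star)
```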
Van Oswald et al. then show that, in fact, the training procedure finds such a parameterization in practice. This is by far the best bit of the paper. Here's that figure again:
From left to right:
van Oswald et al. compare the training loss of a single-layer transformer with the loss of a linear model trained by one-step gradient descent with L2 loss (henceforth the "reference model"). Pay close attention to the scale on the left - the losses converge to the same value, and that value is around 0.20, not 0.
they show that over the course of training, the models' predictions converge to each other, as do their internals
they show that the two procedures generalize identically, i.e. for regression problems with 5 / 10 / 20 / 35 / 50 datapoints, the transformer and the reference model get the same loss. Note that the transformer has only ever seen datasets with exactly 10 points during training.
a different generalization test. The transformer is only trained on training data with −1<x<1 - this is the area to the left of the dotted vertical line. Again, the two models perform the same out-of-distribution.
This is a lot of effort to go to to convince me that two models are the same. Unfortunately, most of the evidence is merely suggestive—two models can have the same loss, and make the same predictions, without implementing the same algorithm. Of these plots, by far the most important is the centre left. This plot has three lines, and the most important one is the green one labelled "Model cos":
"Model cos" is the cosine similarity between the sensitivities of the two models:
$$\text{Model cos} = \mathrm{sim}_{\cos}\!\left(\frac{\partial \hat{y}_{\theta^{GD}}(x_\tau, x_{\text{test}})}{\partial x_{\text{test}}},\ \frac{\partial \hat{y}_{\theta}(x_\tau, x_{\text{test}})}{\partial x_{\text{test}}}\right), \qquad A \cdot B = \|A\|\,\|B\|\cos\theta, \qquad \mathrm{sim}_{\cos}(A, B) = \cos\theta = \frac{A \cdot B}{\|A\|\,\|B\|}.$$
Here "sensitivity" means the (partial) derivative of the model's output w.r.t. its input, i.e. "if we change the input, how does the output change?". The cosine similarity is the cosine of the angle between two vectors. Van Oswald et al. state (and I agree) that if two linear models' sensitivities have cosine similarity equal to 1, they are the same model (up to a scalar coefficient):
$$\mathrm{sim}_{\cos}(A, B) = 1 \implies \cos\theta = 1 \implies \theta = 0 \implies A \propto B.$$
So
$$\mathrm{sim}_{\cos}\!\left(\frac{\partial}{\partial x} W x,\ \frac{\partial}{\partial x} Q x\right) = 1 \implies \frac{\partial}{\partial x} W x \propto \frac{\partial}{\partial x} Q x \implies W \propto Q.$$
And these two models have cosine similarity 1! They even show that if you repeatedly apply the transformer layer, you get the same loss curve as gradient descent:
Result! A one layer, linearized transformer trained on regression problems will end up implementing one-step gradient descent for a linear model with L2 loss. Even with all the bold caveats, this is a cool finding. 👏👏👏
But remember how linearity makes me suspicious? The class of functions that can be represented by a linear model is really small. Sure, a linear transformer is equivalent to one-step of gradient descent for a linear model on an L2 loss. It's also equivalent to one matrix multiplication. Any linear transform can be expressed as a linear transformer. The cute result here is that one-step gradient descent for linear models is itself a linear transform; once you have that, its representation as a one-layer linear transformer is almost a given[4].
The question now is whether these results apply for the kinds of transformers that people actually use: multi-layer, non-linear transformers.
Scaling up
Unfortunately, the closer we get to architectures people actually use, the fuzzier the picture becomes. Van Oswald et al. try two ways of scaling up to larger models.
standard models, where they stack multiple layers on top of each other but train them independently.
"recurrent" models, where they apply one layer multiple times during training. Since they claim that a one-layer transformer represents one step of gradient descent, applying the layer twice should be equivalent to two steps of gradient descent.
Look first at the green lines in the second column of plots, the ones labelled "Model cos". Notice that these lines do not trend to 1. Since we established that two models are the same iff they have sensitivity cosine similarity 1, that means that neither of these models is doing gradient descent.
The authors could have dropped these figures from the paper, only publishing the convincing single-layer results. To their great credit, they didn't. Instead, they dug deeper to find a different algorithm the network could be implementing:
When optimizing such Transformers with $K$ layers, we observe that these models generally outperform $K$ steps of plain gradient descent, see Figure 3. Their behavior is however well described by a variant of gradient descent, for which we tune a single parameter $\gamma$ defined through the transformation function $H(X)$ which transforms the input data according to $x_j \leftarrow H(X) x_j$, with $H(X) = (I - \gamma X X^\top)$. We explain how a LSA layer can implement this input transformation in Appendix A.7. We term this gradient descent variant GD++.
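As far as I can tell from this description, one step of GD++ is an ordinary gradient step combined with a shrinking of the inputs along the directions the data already covers. Here is my reading of it as code; the placement of the input transformation relative to the gradient step is a guess on my part:

```python
import numpy as np

def gd_step(W, X, Y, eta):
    """One step of plain gradient descent on the L2 loss 1/(2N) ||WX - Y||^2."""
    return W - eta * (W @ X - Y) @ X.T / X.shape[1]

def gd_plus_plus_step(W, X, Y, eta, gamma):
    """My reading of GD++: transform the inputs by H(X) = I - gamma * X X^T,
    then take an ordinary gradient step on the transformed data."""
    H = np.eye(X.shape[0]) - gamma * X @ X.T
    X_new = H @ X                       # x_j <- H(X) x_j
    return gd_step(W, X_new, Y, eta), X_new

# Sanity check: gamma = 0 recovers plain gradient descent.
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 20))
Y = rng.normal(size=(1, 5)) @ X
W = np.zeros((1, 5))
W_pp, _ = gd_plus_plus_step(W, X, Y, eta=0.1, gamma=0.0)
assert np.allclose(W_pp, gd_step(W, X, Y, eta=0.1))
```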
When they plot GD++ on their figures, they find it has sensitivity cosine similarity 1.
I'm torn by this. On the one hand, this is exactly the kind of thing that I think people should be doing; I praise work like Circuits, Transformer Circuits and DEER that peer into the weights of deep networks and speculate as to what functions they might be implementing. I don't want to make an isolated demand for rigour. However, in this particular case, I notice that switching from GD to GD++ gives the authors a bunch of free variables they can adjust until they get results that fit. Despite my worldliness, handsomeness and great wealth of experience in ML, I've never encountered GD++ before; is it a standard technique, or did they do a big search across algorithm space to find one that fit? That they invent a name for it suggests to me the latter. It's also curious to me that one-layer transformers correspond to vanilla gradient descent, not GD++; if GD++ outperforms vanilla gradient descent, and the training procedure can produce models parameterized to perform GD++, then why does it show up when you're training recurrent two-layer models, but not one-layer models?
They also attempt to show that non-linear transformers (both softmax transformers and linear transformers preceded by non-linear MLPs) do the gradient descent thing, to mixed success. First, linear transformers with MLPs:
Observe that the green line on the right labelled "Partial cosine" does not trend to 1. These models are not the same.
Now the softmax transformers:
Again, notice that the green lines labelled Model cos do not trend to 1. The green line for figure 9b trends to almost 1, but almost-1 and 1 are different numbers. These models are not the same.
The Takeaway
Anon: Hard to understand the takeaway. Some forms of LLM probably generate internal models and some definitely don't?
Blaine: takeaway: if you try really really hard, you can get a particular kind of small transformer to do something that looks like gradient descent if you squint
Blaine: this is suggestive that larger more powerful models might be doing some kind of gradient-descent-in-activation-space; maybe when you say to chat-GPT "1 -> 7, 2 -> 4, 3 -> 2, 4 -> 1, 5 ->" it does an optimization to fit a model and then spits out the answer
Blaine: but it's not quite the "LLMs have mesa-optimizers" paper that I thought it was from the abstract and introduction
Remember way back at the start of this article we were looking to learn how GPT does few-shot learning? Maybe we can find another paper that will tell us.
Why Can GPT Learn In-Context? Language Models Secretly Perform Gradient Descent as Meta-Optimizers (Dai et al. 2022)
Oooooooh boy, now we're talking. I was told that GPT does few-shot learning by activation-space gradient descent, and you probably couldn't find me a more explicit claim if you tried. Let's dig in and see if the paper lives up to its title.
In which the paper does not live up to its title
Blaine: this paper is decidedly less impressive. makes very similar claims, but this time cashes them out definitionally
Blaine: they show that you can write "applying a linear model W with one gradient update step" in a form that resembles linear attention.
Blaine: then, they argue, since you can frame any (linear model) gradient update as linear attention, the converse holds — every linear attention layer is a "meta-gradient update"
Blaine: to show that this is true, they compare ICL (in-context-learning) with finetuning by gradient descent, except they only do one step of gradient descent, and they only finetune some of the weights
Blaine: they then "find that ICL has many properties in common with finetuning" but all the properties they list are the properties they explicitly set to be the same
Oh dear.
Recall how in the previous paper, van Oswald et al. do some rearranging of the standard framing of gradient descent to show that updating the weights $W_{i+1} = W_i + \Delta W_i$ is (for certain models and losses) equivalent to updating the labels $Y_{i+1} = Y_i + \Delta Y_i$? Here, Dai et al. do a similar rearrangement to show that updating a linear model by one step of gradient descent is equivalent to one layer of linear attention:
$$F(x) = (W_0 + \Delta W)\, x = W_0 x + \Delta W x = W_0 x + \sum_i \left(e_i \otimes x'^{\top}_i\right) x = W_0 x + \text{LinearAttn}(E, X', x)$$
where $E = (e_i)$ is a matrix of gradients, $X' = (x'_i)$ is training data and $x$ is the query point. Following the citation chain, this formulation comes from The Dual Form of Neural Networks Revisited (Irie et al. 2022). That paper uses the formulation to reframe the linear layers in MLPs as attention layers, layers that attend to the gradients produced during training; this lets them inspect which training examples the model is making use of when it makes its prediction, which is a neat trick (if very computationally expensive).
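Here is a small numpy check of this dual form for the simplest case, a linear model trained for one gradient step on a batch. This is my own toy example: the "values" e_i are the per-example error signals scaled by the learning rate, and the attention scores are unnormalized dot products between the stored training inputs and the query.

```python
import numpy as np

rng = np.random.default_rng(0)
N_x, N_y, N = 10, 1, 50
eta = 0.1
X = rng.normal(size=(N_x, N))                    # training inputs (one column each)
Y = rng.normal(size=(N_y, N_x)) @ X              # training labels
W0 = rng.normal(size=(N_y, N_x))

# One step of gradient descent on the L2 loss 1/(2N) ||W X - Y||^2.
W1 = W0 - eta * (W0 @ X - Y) @ X.T / N

# Dual form: the updated model is the old model plus unnormalized linear
# attention over the training set, with per-example "values" e_i.
E = -eta * (W0 @ X - Y) / N                      # e_i, shape (N_y, N)
x = rng.normal(size=(N_x, 1))                    # an arbitrary query point

direct = W1 @ x                                  # F(x) = (W0 + dW) x
attention_scores = X.T @ x                       # x_i . x  for every stored input
dual = W0 @ x + E @ attention_scores             # W0 x + sum_i e_i (x_i . x)

assert np.allclose(direct, dual)
```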
This paper notices that, since linear-model gradient descent can be framed as linear attention, we can run the process backwards. Any attention layer can then be rearranged to look like gradient descent:
$$\text{Attn}(V, K, q) \approx \left(W_V X (W_K X)^\top + \sum_i W_V x'_i \otimes (W_K x'_i)^\top\right) q.$$
If we label $W_V X (W_K X)^\top = W_{ZSL}$ as "initial parameters" and $\sum_i W_V x'_i \otimes (W_K x'_i)^\top = \Delta W_{ICL}$ as a "meta-gradient update", now every transformer network is doing meta-gradient descent!
$$\text{Attn}(V, K, q) \approx (W_{ZSL} + \Delta W_{ICL})\, q.$$
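To see how little machinery is involved once the softmax is dropped: with linear attention, the "zero-shot weights" and the "meta-gradient update" are just the two halves of one sum, split between the original context tokens and the demonstration tokens. A toy check (variable names mine):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_ctx, n_demo = 16, 8, 24
W_V, W_K = rng.normal(size=(d, d)), rng.normal(size=(d, d))
X = rng.normal(size=(d, n_ctx))          # representations of the original query text
X_prime = rng.normal(size=(d, n_demo))   # representations of the demonstration examples
q = rng.normal(size=(d, 1))              # the attention query

# Softmax-free attention over the concatenation of both token groups.
full = np.concatenate([X, X_prime], axis=1)
attn = W_V @ full @ (W_K @ full).T @ q

# The same quantity, split into the two labelled pieces from the paper.
W_ZSL = W_V @ X @ (W_K @ X).T                    # "initial parameters"
dW_ICL = W_V @ X_prime @ (W_K @ X_prime).T       # "meta-gradient update"
assert np.allclose(attn, (W_ZSL + dW_ICL) @ q)   # just linearity of the sum
```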
I am unimpressed.
Did you notice that they snuck linearity in when you weren't looking? Van Oswald et al. make a big deal out of how the models they're testing are different from the models people actually use, and they include a bunch of extra experiments exploring both linear and non-linear models, presenting even the unfavourable results. While Dai et al. don't exactly hide the non-linearity, they don't call attention to it either; most mentions in this summary are mine. That "≈" is doing a lot of work.
The important part of gradient descent is not the descent, but the gradients. Of what function is the meta-gradient a gradient? What is the thing being optimized? Dai et al. don't even try to tell us.
Continuing, they also note that you can frame fine-tuning by one step of gradient descent in the same dual form, as an update $\Delta W_{FT}$ applied to the zero-shot weights $W_{ZSL}$. They then compare fine-tuning (FT) against In-Context Learning (ICL), but not before making some adjustments "for a more fair comparison":
In order to compare the meta-optimization of ICL with explicit optimization, we design a specific finetuning setting as a baseline for comparison. Considering that ICL directly takes effect on only the attention keys and values, our finetuning setting also updates only the parameters for the key and value projection. [...]
we specify the training examples as the demonstration examples for ICL;
we train each example for only one step in the same order as demonstrated for ICL;
we format each training example with the same template used for ICL $T(x'_i, y'_i)$ and use the causal language modeling objective for finetuning.
They then "find that ICL has many properties in common with finetuning":
Both Perform Gradient Descent Comparing Equation (12) and Equation (13), we find that both ICL and finetuning introduce updates ($\Delta W_{ICL}$ v.s. $\Delta W_{FT}$) to $W_{ZSL}$, which can both be regarded as gradient descent. The only difference is that ICL produces meta-gradients by forward computation while finetuning acquires real gradients by backpropagation.
Same Training Information [...]
Same Causal Order of Training Examples [...]
Both Aim at Attention Compared with zeroshot learning, the direct effect of ICL and our finetuning are both restricted to the computation of attention keys and values.
This is just a list of things they have defined to be the same. The first point is just a restatement of their thesis that all attention models do "meta-gradient descent", and we should treat that like real gradient descent. The second point is vacuous. The third point is specifically addressed by the adjustments to the ordinary fine-tuning setting. The fourth point is addressed by their restriction of fine-tuning to only update the attention key and value matrices. These are tautologies, not novel results.
Am I being unfair here? Maybe they're not trying to present results, just putting a weird amount of emphasis on the steps they took to make their experiments fair. Let's instead look at section 4.4 Results to see what they think their novel contributions are.
Their most compelling result is that the "weight update" terms in the meta-gradient rearrangement of the attention formula tend, in practice, to be more similar to the weight updates produced by their finetuning procedure than they are to random updates. Bonus points for doing these experiments with a GPT, rather than with a weird toy network you expect to generalize to GPT:
The columns to look at are "SimAOU" (similarity between the meta-gradients and the true gradients) and "Random SimAOU" (similarity between the meta-gradients and a random vector). But the problem is that on average a random weight update will make your model worse, and we know that both fine-tuning and in-context learning improve performance. It might just be that weight updates that improve performance are more similar to each other than they are to noise, regardless of the underlying mechanism. See also the SimAM column, where they compute the cosine similarity between the attention maps given by FT and ICL. If two linear models implement the same algorithm, they should have cosine similarity 1! The highest we see here is 0.687.
The rest of the paper's results indicate that whatever the "meta-gradients" are, they're definitely not the same gradients produced by one-step fine-tuning. Here we see that the similarity between the two varies substantially across layers of the network, with some pretty wild error bars:
Dai et al. conclude that
The results prove that ICL behaves similarly to explicit finetuning at the prediction level, the representation level, and the attention behavior level.
They do not prove anything of the sort. Most importantly, they do not show that language models are mesa-optimizers. Calling attention layers "meta-gradient updates" is like calling a rock in a pipe a utility optimizer and suggesting we should be scared lest it maniacally pursue reducing the flow of Earth's water.
Concluding thoughts
I came into this exercise hoping to find a wealth of evidence that the transformer's secret special sauce is that it's doing gradient descent in activation space. This would be a really pleasing result:
It would demystify transformers' magical powers of in-context learning. We don't understand much about gradient descent, but we understand it a hell of a lot better than we currently understand GPT.
It would unlock new avenues of research; understanding which functions GPT is optimising and why could tell us a lot about its generalization behaviour.
Finally we would have an example of a powerful (but not existentially risky) system with an accidental mesa-optimizer. What a great object for alignment researchers to study!
I can see why so many people want this theory to be true, but as far as I can tell the evidence, while suggestive, just doesn't bear it out. Of these two papers, I think only van Oswald et al. 2022 is worth your time, but their most impressive results make liberal use of linearity in a way that makes me doubt they will generalize to larger, non-linear models. I look forward to reading further research on the topic. Were I to work on this, here are some questions I'd pursue:
can we formulate softmax attention as a gradient update, perhaps by working out of which function Dai et al.'s meta-gradient is a gradient?
do single-layer transformers still look like gradient updates if we don't train them specifically to do regression tasks?
can we, by staring deeply into the matrices of GPT à la circuits, recognize any that satisfy van Oswald et al.'s constraints on $P W_V$ and $W_K W_Q$?
[1] People unfamiliar with neural networks might think that the string "(-1, -2.31)" is quite complicated as a token compared to "What" or " is"; surely you would need an infinite number of tokens to represent all pairs of real numbers! Wouldn't most of the network be devoted to learning the mapping from abstract tokens <token 2352> to pairs of numbers?
If we used the same tokenizer for these models as we do for GPT these would be great intuitions! But most of the work of GPT's tokenization is done in the embedding step, where we map symbolic tokens such as "What" or " is" to high-dimensional real-valued vectors in a "semantically meaningful" space. Only once we have real-valued vectors can we actually run the matrix multiplications that make up the bulk of a neural network. But here we start with a pair of real numbers! So it doesn't make sense to map them onto abstract symbols and then reproject them into a high-dimensional semantic space. We can just pass them straight in as a two-dimensional vector, skipping the embedding step entirely.
[2] Read: attention scales horribly—if it takes one second to predict the next word of a 200 word paragraph, it takes one and a half minutes to predict the next word of a 2000 word essay and almost three hours to predict the next word of a 20000 word novella. This is why LLMs have such small context windows.
[3] In contrast, in the kind of toy problems where we could run full-fat non-linear transformers without running out of memory, they solved the problems easily with great performance.
The word "almost" is doing a lot of work, and van Oswald et al. deserve a lot of credit for actually doing the legwork to demonstrate that the equivalence holds in practice as well as in theory.