Second pass through this post which solidly nerd-sniped me!
A quick summary of my understanding of the post (intentionally very reductive, though I understand the post may make more subtle points):
My thoughts:
Thanks for writing this up! Looking forward to subsequent post/details :)
PS: Is there a non-trivial relationship between this post and tuned lens/logit lens? https://arxiv.org/pdf/2303.08112.pdf Seems possible.
Thanks for the extensive comment! Your summary is really helpful for seeing how this came across. Here's my take on a couple of these points:
2.b: The network would be sneaking information about the size of the residual stream past LayerNorm. So the network wants to implement a sort of "grow by a factor X every layer" and wants to prevent LayerNorm from resetting its progress.
How and why disconnected
Yep we give some evidence for How, but for Why we have only a guess.
still don't feel like I know why though
earn generic "amplification" functions
Yes, all we have is some intuition here. It seems plausible that the model needs to communicate stuff between some layers, but doesn't want this to take up space in the residual stream. So this exponential growth is a neat way to make old information decay away (relatively). And it seems plausible to implement a few amplification circuits for information that has to be preserved for much later in the network.
We would love to see more ideas & hypotheses on why the model might be doing this, as well as attempts to test this! We mainly wrote-up this post because both Alex and I independently noticed this and weren't aware of this previously, so we wanted to make a reference post.
And it seems plausible to implement a few amplification circuits for information that has to be preserved for much later in the network.
Although -- naive speculation -- the deletion-by-magnitude theory could enforce locality in what layers read what information, which seems like it would cut away exponentially many virtual heads? That would be awfully convenient for interpretability. (More trying to gesture at some soft "locality" constraint, rather than make a confident / crisp claim in this comment.)
We would love to see more ideas & hypotheses on why the model might be doing this, as well as attempts to test this! We mainly wrote-up this post because both Alex and I independently noticed this and weren't aware of this previously, so we wanted to make a reference post.
Happy to provide! I think I'm pretty interested in testing this/working on this in the future. Currently a bit tied up but I think (as Alex hints at) there could be some big implications for interpretability here.
TLDR: Documenting existing circuits is good, but explaining what relationship circuits have to each other within the model seems important, for example by understanding how the model allocates limited resources such as residual stream and weights between different learnable circuits.
The general topic I think we are getting at is something like "circuit economics". The thing I'm trying to gesture at is that while circuits might deliver value in distinct ways (such as reducing loss on different inputs, activating on distinct patterns), they share capacity in weights (see Polysemanticity and Capacity in Neural Networks) and I guess "bandwidth" (getting penalized for interfering signals in activations). There are a few reasons why I think this feels like economics, which include: scarce resources, value chains (features composed of other features) and competition (if a circuit is predicting something well with one heuristic, maybe there will be smaller gradient updates to encourage another circuit learning a different heuristic to emerge).
So to tie this back to your post and Alex's comment "which seems like it would cut away exponentially many virtual heads? That would be awfully convenient for interpretability.". I think that what interpretability has recently dealt with in elucidating specific circuits is something like "micro-interpretability" and is akin to microeconomics. However this post seems to show a larger trend ie "macro-interpretability" which would possibly affect which of such circuits are possible/likely to be in the final model.
I'll elaborate briefly on the off chance this seems like it might be a useful analogy/framing to motivate further work.
This is a very speculative "theory", if you can call it that, but I guess I feel this would be "big if true". I also make no claims about this being super original or actually that useful in practice, but it does feel intuition generating. I think this is totally the kind of thing people might have worked on sooner, but it's likely been historically hard to measure the kinds of things that might be relevant. What your post shows is that between the transformer circuits framework and TransformerLens we are able to take a bunch of interesting measurements relatively quickly, which may provide more traction on this than previously possible.
More generally, "circuit economics" as a framing seems to suggest that there are different types of "goods" in the transformer economy: those which directly lead to better predictions and those which are useful for making better predictions when integrated with other features. The success of Logit Lens seems to suggest that the latter category increases over the course of the layers. Maybe this is the only kind of good, in which case transformers would be "fundamentally interpretable" in some sense. All intermediate signals could be interpreted as final products.
Can you say more on this point? The latter kind of good (useful when integrated with other features) doesn't necessarily imply that direct unembed (logit lens) or learned linear unembed (tuned lens iirc) would be able to extract use from such goods. I suspect that I probably just missed your point, though.
I wonder if this is related to vector-packing and unpacking via cosine similarity: the activation norm is increased so layers can select a large & variable number of semi-orthogonal bases. (This is very much related to your information packing idea.)
Easy experimental manipulation to test this would be to increase the number of heads, thereby decreasing the dimensionality of the cos_sim for attention, which should increase the per-layer norm growth. (Alas, this will change the loss too - so not a perfect manipulation)
Summary: For a range of language models and a range of input prompts, the norm of each residual stream grows exponentially over the forward pass, with average per-layer growth rate of about 1.045 in GPT2-XL. We show a bunch of evidence for this. We discuss to what extent different weights and parts of the network are responsible.
We find that some model weights increase exponentially as a function of layer number. We finally note our current favored explanation: Due to LayerNorm, it's hard to cancel out existing residual stream features, but easy to overshadow existing features by just making new features 4.5% larger.
Thanks to Aryan Bhatt, Marius Hobbhahn, Neel Nanda, and Nicky Pochinkov for discussion.
Plots showing exponential norm and variance growth
Our results are reproducible in this Colab.
Alex noticed exponential growth in the contents of GPT-2-XL's residual streams. He ran dozens of prompts through the model, plotted for each layer the distribution of residual stream norms in a histogram, and found exponential growth in the L2 norm of the residual streams:
Here's the norm of each residual stream for a specific prompt:
Stefan had previously noticed this phenomenon in GPT2-small, back in MATS 3.0:
Basic Facts about Language Model Internals also finds a growth in the norms of the attention-out matrices WO and the norms of MLP out matrices Wout ("writing weights"), while they find stable norms for WQ, WK, and Win ("reading weights"):
Comparison of various transformer models
We started our investigation by computing these residual stream norms for a variety of models, recovering Stefan's results (rescaled by √dmodel=√768) and Alex's earlier numbers. We see a number of straight lines in these logarithmic plots, which shows phases of exponential growth.
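A straight line in a logarithmic plot corresponds to a constant per-layer growth factor, which can be read off with a linear fit to the log-norms. The following is a minimal sketch of that idea, not the code used for the plots: the helper `fit_growth_rate` is hypothetical, and the norms are synthetic stand-ins rather than measurements from any model.

```python
import numpy as np

def fit_growth_rate(norms, first_layer=0, last_layer=None):
    """Fit norm ~ a * g**layer by least squares on log(norm) and return g."""
    norms = np.asarray(norms, dtype=float)
    layers = np.arange(len(norms))
    window = slice(first_layer, last_layer)
    slope, _intercept = np.polyfit(layers[window], np.log(norms[window]), deg=1)
    return float(np.exp(slope))

# Synthetic per-layer norms (NOT real model data): exponential trend plus small noise.
rng = np.random.default_rng(0)
layers = np.arange(48)
synthetic_norms = 50.0 * 1.045 ** layers * np.exp(rng.normal(0.0, 0.01, size=48))

# Fit only over a window of layers, since growth need not be exponential everywhere.
growth = fit_growth_rate(synthetic_norms, first_layer=5, last_layer=42)
print(f"fitted per-layer growth factor: {growth:.3f}")
```

Restricting the fit to a layer window matters because the first and last few layers often deviate from the trend.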
We are surprised by the decrease in Residual Stream norm in some of the EleutherAI models.[2] We would have expected that, because the transformer blocks can only access the normalized activations, it's hard for the model to "cancel out" a direction in the residual stream. Therefore, the norm always grows. However, this isn't what we see above. One explanation is that the model is able to memorize or predict the LayerNorm scale. If the model does this well enough it can (partially) delete activations and reduce the norm by writing vectors that cancel out previous activations.
The very small models (distilgpt2, gpt2-small) have superexponential norm growth, but most models show exponential growth throughout extended periods. For example, from layer 5 to 41 in GPT2-XL, we see an exponential increase in residual stream norm at a rate of ~1.045 per layer. We showed this trend as an orange line in the above plot, and below we demonstrate the growth for a specific example:
BOS and padding tokens
In our initial tests, we noticed some residual streams showed an irregular and surprising growth curve:
As for the reason behind this shape, we expect that the residual stream (norm) is very predictable at BOS and padding positions. This is because these positions cannot attend to other positions and thus always have the same values (up to positional embedding). Thus it would be no problem for the model to cancel out activations, and our arguments about this being hard do not hold for BOS and padding positions. We don't know whether there is a particular meaning behind this shape.
We suspect that this is the source of the U-shape shown in Basic facts about language models during training:
Theories for the source of the growth
From now on we focus on the GPT2-XL case. Here is the residual stream growth curve again (orange dots), but also including the `resid_mid` hook between the Attention and MLP sub-layers (blue dots).
Our first idea upon hearing exponential growth was:
However, we think that this does not work due to LayerNorm (LN). With LayerNorm, the input into the Attention and MLP sub-layers is normalized to have standard deviation 1 and norm √dmodel (neglecting the learned LN parameters, which we discuss in footnote [3]). Despite this, the Attention and MLP sub-layers have output contributions which increase proportionally to the overall residual stream norm that is exponentially increasing.
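To make the normalization claim concrete, here is a minimal sketch of LayerNorm without its learned gain and bias (the function name `layer_norm_no_params` is hypothetical). It shows that inputs differing only in scale become nearly identical after normalization, with norm close to √d_model (40 for d_model = 1600), so a sub-layer cannot directly see how large the residual stream was.

```python
import numpy as np

def layer_norm_no_params(x, eps=1e-5):
    """LayerNorm without learned gain/bias: subtract the mean, divide by the std."""
    x = x - x.mean(axis=-1, keepdims=True)
    return x / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

d_model = 1600  # residual stream width assumed here (sqrt(1600) = 40)
rng = np.random.default_rng(0)

small = rng.normal(size=d_model)   # stand-in for an early-layer residual stream
large = 1000.0 * small             # same direction, 1000x larger norm

out_small = layer_norm_no_params(small)
out_large = layer_norm_no_params(large)

print(np.linalg.norm(small), np.linalg.norm(large))          # very different input norms
print(np.linalg.norm(out_small), np.linalg.norm(out_large))  # both close to sqrt(d_model)
```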
We can think of two ways to get exponential growth of the residual stream despite LN:
To illustrate the second theory, consider the following toy example where x and y have the same norm, but x contains only one feature (the "alternating" feature) while y contains two features (the "alternating" feature, and the "1st != 3rd number" feature).
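$$x = \frac{1}{2} \cdot \begin{pmatrix} -1 \\ 1 \\ -1 \\ 1 \end{pmatrix} \quad \text{and} \quad y = \frac{1}{2} \cdot \begin{pmatrix} -1 \\ 1 \\ 1 \\ 1 \end{pmatrix}$$
$$W_{\text{in}} = 100 \cdot \begin{pmatrix} -1 & 1 & -1 & 1 \\ -1 & 0 & 1 & 0 \end{pmatrix},$$
and sigmoid activation function.
Then $W_{\text{in}} x = \begin{pmatrix} 200 \\ 0 \end{pmatrix}$ and $W_{\text{in}} y = \begin{pmatrix} 100 \\ 100 \end{pmatrix}$, which gives approximately $\begin{pmatrix} 1 \\ 0.5 \end{pmatrix}$ and $\begin{pmatrix} 1 \\ 1 \end{pmatrix}$ after the sigmoid function (since sigmoid(0) = 0.5).
This way, a property (number of features) can be hidden in the inputs (hidden as in, the inputs have identical norms), and affect the norms of the outputs. This works less nicely with ReLU or GELU but the output norms still differ.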
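The toy example can be checked numerically. This is a small sketch using the vectors and matrix as we read them from the example above. Note that sigmoid(0) = 0.5, so the second output entry for x is 0.5 rather than 0, and the two output norms still differ.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# x carries only the "alternating" feature; y additionally carries "1st != 3rd number".
x = 0.5 * np.array([-1.0, 1.0, -1.0, 1.0])
y = 0.5 * np.array([-1.0, 1.0, 1.0, 1.0])

# Row 0 reads the "alternating" feature, row 1 reads "1st != 3rd number".
W_in = 100.0 * np.array([[-1.0, 1.0, -1.0, 1.0],
                         [-1.0, 0.0, 1.0, 0.0]])

pre_x, pre_y = W_in @ x, W_in @ y
out_x, out_y = sigmoid(pre_x), sigmoid(pre_y)

print("input norms: ", np.linalg.norm(x), np.linalg.norm(y))
print("pre-acts:    ", pre_x, pre_y)
print("output norms:", np.linalg.norm(out_x), np.linalg.norm(out_y))
```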
To distinguish these two theories, we can test whether we see an exponential increase in the norm of Attention/MLP weights, or alternatively, an exponential increase in the norm of Attention/MLP outputs on random layer-independent inputs. Either of these would mean we don't need theory 2's sneaking features-shenanigans and can explain the exponential growth as being "hard-coded" into the model weights.
Note: It's possible for just one of the sub-layer types (Attention or MLP) to grow exponentially and still cause the overall exponential growth (see appendix 1 for a related proof). But this seems unlikely as the non-exponential sub-layer would lose impact on the residual stream, and we expect the model to make use of both of them. Indeed, plotting the outputs `attn_out` and `mlp_out` shows both increasing at the exponential rate (but `attn_out` seems to fall off at layer ~30).
Analyzing the model weights to understand the behaviour
We want to know why the residual stream norm is growing. Is it some process that naturally creates an exponential increase (maybe features accumulating in the residual stream)—and how would that work? Or are the weights[3] of later layers inherently larger and thus cause larger outputs?
We know that both `attn_out` and `mlp_out` grow exponentially. In the next two sections we look at the Attention and MLP weights, respectively.
TL;DR: We do find evidence for exponentially increasing weights in both sub-layers, although in both cases we are somewhat confused about what is happening.
Analyzing the Attention weights
What do we want to get evidence on? We want to know why `attn_out` grows exponentially with layer number: Is the growth a property inherent to the Attention weights in each of the layers (theory 1), or does the growth rely on properties of the residual stream (theory 2)?
What test do we run and why does that give us evidence? We test whether the Attention OV-circuit weights grow exponentially with layer number, at the same rate as the actual Attention outputs `attn_out`. If true, this is evidence for theory 1.
The Attention layer output `attn_out` is determined by the QK-circuits (select which inputs to attend to), and the OV-circuits (determine how the inputs are transformed). For the purposes of understanding the overall residual stream growth (why the outputs have larger norm than the inputs) we want to focus on the OV-circuits, which determine how the norm changes from input to output.
The OV-circuits consist of the WOV matrices (product of the value WV and output WO matrices) and the bias bO.[4] There are 25 attention heads per layer in GPT2-XL, i.e. 25 WOV matrices. In the figure below we plot the Frobenius norm[5] of the WOV matrices (grey solid lines) and L2 norm of the bO vector (pink line), and compare it to the L2 norm of `attn_out` (blue solid line).
The Frobenius norms of the attention heads (grey lines) match the actual `attn_out` norms (blue line) somewhat accurately, and grow exponentially. The bias term (pink line) seems mostly negligible except in the final layers.
What did we find? We find that the Attention weights, specifically the WOV norms, grow approximately exponentially at the rate of `attn_out`. This is evidence for theory 1 because it means that the model bothered to learn weights that increase exponentially with layer number.[6]
Caveats: We do not understand the full picture of how `attn_out` is generated; all we notice is that the weight norms and the output norms grow at the same rate.
What we would have liked: To show that any normalized random input into Attention layer N leads to an Attention output of the observed norm.
What we got: For a unit-normalized, Gaussian-sampled vector x, consider the sum of WOV⋅x over all 25 WOV matrices (one for each head). This sum's norm is 5 times larger than the `attn_out` norm, as shown in the figure.[7]
Analyzing the MLP weights
What do we want to get evidence on? We want to know why `mlp_out` grows exponentially with layer number: Is the growth a property inherent to the MLP weights in each of the layers (theory 1), or does the growth rely on properties of the residual stream (theory 2)?
What test do we run and why does that give us evidence? We test whether feeding layer-independent inputs to the MLPs produces outputs that scale exponentially with layer, in a way which follows the exponential growth of `mlp_out`.
If this is true, this is evidence for theory 1 and against theory 2. If this is false, we cannot draw strong evidence from this.
We do not attempt to find the right way to combine model weights into a "norm" of the MLP layer. Instead, we draw input vectors from a Normal distribution, and normalize them to mean 0 and variance 1. We feed these vectors into the MLP. [8]
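The probing procedure can be sketched as follows. This is a toy stand-in under loud assumptions, not the experiment itself: the weights are random matrices rather than GPT2-XL's, the helper names (`normalized_gaussian_inputs`, `probe_mlp`) are hypothetical, GELU uses the tanh approximation, and the per-layer scaling of `W_out` by 1.045 is put in by hand to show that the probe recovers growth that is inherent to the weights.

```python
import numpy as np

def gelu(z):
    # tanh approximation of GELU
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))

def normalized_gaussian_inputs(n, d_model, rng):
    """Gaussian vectors, each normalized to mean 0 and variance 1."""
    x = rng.normal(size=(n, d_model))
    x = x - x.mean(axis=-1, keepdims=True)
    return x / x.std(axis=-1, keepdims=True)

def probe_mlp(x, W_in, b_in, W_out, b_out):
    """Return (mean L2 norm of the MLP output, fraction of pre-activations > 0)."""
    pre = x @ W_in + b_in
    out = gelu(pre) @ W_out + b_out
    return float(np.linalg.norm(out, axis=-1).mean()), float((pre > 0).mean())

rng = np.random.default_rng(0)
d_model, d_mlp, n_layers = 64, 256, 6   # toy sizes, not GPT2-XL's
x = normalized_gaussian_inputs(128, d_model, rng)   # layer-independent inputs

# Hypothetical toy weights; W_out is scaled by 1.045**layer by construction.
W_in = rng.normal(scale=d_model ** -0.5, size=(d_model, d_mlp))
W_out = rng.normal(scale=d_mlp ** -0.5, size=(d_mlp, d_model))
b_in, b_out = np.zeros(d_mlp), np.zeros(d_model)

results = [probe_mlp(x, W_in, b_in, W_out * 1.045 ** layer, b_out)
           for layer in range(n_layers)]
out_norms = [norm for norm, _rate in results]
activation_rates = [rate for _norm, rate in results]
print(out_norms)
print(activation_rates)
```

Because the same inputs are fed to every layer, any trend in `out_norms` across layers must come from the weights, which is the logic of the test.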
What did we find? We find that the MLP outputs of normalized random Gaussian inputs do scale exponentially with layer number, for layers 30 - 43, at the same rate as `mlp_out`. This is evidence for theory 1.
Caveats: We do not reproduce the `mlp_out` norms but find a much larger output norm with the random inputs. We discuss this further in an appendix, but the bottom line is that random vectors are indeed qualitatively different from residual stream vectors, and notably random vectors cause 4x more of the GELU activations to be active (>0) than normal residual stream vectors. (On the second theory: do random vectors have "more features", and thus higher norm?)
Why an exponential residual stream norm increase might be useful
Transformers might sometimes want to delete information from the residual stream, maybe to make space for new information. However, since all blocks only receive the normalized (LayerNorm) residual stream, it may be impossible to do deletions the intuitive way of "just write −v to the residual stream" to delete a vector v. It might approximately work if the model can predict the LayerNorm scale, but it seems hard to do accurately.
Alternatively, the model could write all new information with an increased norm. An exponential growth would make the most recent layers have an exponentially larger effect on the residual stream at any given layer.
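A back-of-the-envelope sketch of this "overshadowing" effect, under a simplifying assumption that is ours and not a measurement: each layer writes a vector whose norm is proportional to the current residual stream norm, the written vector is never modified afterwards, and the stream norm grows by a constant factor per layer. The helper `relative_share` is hypothetical.

```python
import math

def relative_share(written_at: int, read_at: int, growth: float = 1.045) -> float:
    """Norm of a layer's write, relative to the stream norm at a later layer,
    normalized so that the share is 1.0 at the layer where it was written."""
    if read_at < written_at:
        raise ValueError("read_at must not be earlier than written_at")
    return growth ** (written_at - read_at)

# How much a write at layer 5 still "weighs" 35 layers later.
share_after_35_layers = relative_share(written_at=5, read_at=40)

# Number of layers after which a write's relative share has halved.
half_life_layers = math.log(2) / math.log(1.045)

print(f"relative share after 35 layers: {share_after_35_layers:.3f}")
print(f"half-life in layers: {half_life_layers:.1f}")
```

Under these assumptions old information is never deleted, it just shrinks relative to everything written later.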
However, this is complicated by weight decay, which is a term in the loss that penalizes large weight magnitudes. While we analyzed GPT2-XL's weights in this post, we also earlier displayed similar residual stream norm trends for a range of models. The OPT and GPT-Neo models were trained with weight decay of 0.1, while the Pythia models were trained with 0.01. We don't know about distilgpt2 or the normal GPT2-series. If models trained with weight decay still exhibit weight norms which increase exponentially with layer number, then that means something is happening which somehow merits an exponential hit to loss.[9]
ETA 5/7/23: Apparently, LN parameters are often excluded from weight decay. (For example, see the minGPT implementation.) This means that the gain parameters can freely magnify the LN output, without incurring extra regularization loss. (However, this also suggests that `W_in` and `W_OV` should in general become extremely tiny, up to precision limits. This is because their norm can be folded into the LN parameters in order to avoid regularization penalties.)
Conclusion
We documented a basic tendency of transformers: residual stream variance grows exponentially. We think that a big chunk of the exponential increase does come from the model weights, but have not fully understood the underlying mechanics (e.g. GELU activation rates).
Contributions:
Stefan (StefanHex) wrote a lot of the post, noticed this in GPT2-small, compared the phenomenon between models, and did the analysis of activations and weights.
Alex (TurnTrout) wrote some of the post and edited it, noticed the phenomenon in GPT2-XL, made about half of the assets and some of the hooking code for computing residual stream norms. He also wrote appendix 1.
Appendix 1: Attention+MLP contribution norms must exceed block-over-block norm growth rate
Proposition: Attention + MLP norm contributions must exceed the growth rate.
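Consider residual streams $x_i \in \mathbb{R}^{d_{\text{model}}}$ for the activation vector just before transformer layer $i$, in a transformer where the MLP comes after the Attention sublayer. Suppose that, for layer $n$, $\frac{|x_n|}{|x_{n-1}|} = g$ for growth rate $g \geq 0$. Then
$$g - 1 \leq \frac{|\mathrm{Attn}_{n-1}(x_{n-1})| + |\mathrm{MLP}(x_{n-1} + \mathrm{Attn}_{n-1}(x_{n-1}))|}{|x_{n-1}|}.$$
Proof.
$$\begin{aligned}
g = \frac{|x_n|}{|x_{n-1}|} &= \frac{|x_{n-1} + \mathrm{Attn}_{n-1}(x_{n-1}) + \mathrm{MLP}(x_{n-1} + \mathrm{Attn}_{n-1}(x_{n-1}))|}{|x_{n-1}|} && \text{(Transformer block)} \\
&\leq \frac{|x_{n-1}| + |\mathrm{Attn}_{n-1}(x_{n-1})| + |\mathrm{MLP}(x_{n-1} + \mathrm{Attn}_{n-1}(x_{n-1}))|}{|x_{n-1}|} && \text{(Triangle inequality)} \\
&= 1 + \frac{|\mathrm{Attn}_{n-1}(x_{n-1})| + |\mathrm{MLP}(x_{n-1} + \mathrm{Attn}_{n-1}(x_{n-1}))|}{|x_{n-1}|}.
\end{aligned}$$
Then
$$g - 1 \leq \frac{|\mathrm{Attn}_{n-1}(x_{n-1})| + |\mathrm{MLP}(x_{n-1} + \mathrm{Attn}_{n-1}(x_{n-1}))|}{|x_{n-1}|}.$$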
QED.
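As a quick numerical sanity check of the proposition, here is a sketch with random stand-in vectors (not actual Attention or MLP outputs; the scales 0.03 and 0.04 are arbitrary choices of ours). The bound holds by the triangle inequality for any vectors, and for uncorrelated contributions the realized growth sits far below the bound.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 1600

x_prev = rng.normal(size=d_model)              # stand-in for the residual stream x_{n-1}
attn_out = 0.03 * rng.normal(size=d_model)     # stand-in for Attn_{n-1}(x_{n-1})
mlp_out = 0.04 * rng.normal(size=d_model)      # stand-in for MLP(x_{n-1} + Attn_{n-1}(x_{n-1}))
x_next = x_prev + attn_out + mlp_out           # transformer block update

g = np.linalg.norm(x_next) / np.linalg.norm(x_prev)
bound = (np.linalg.norm(attn_out) + np.linalg.norm(mlp_out)) / np.linalg.norm(x_prev)

print(f"g - 1 = {g - 1:.4f}, bound = {bound:.4f}")
```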
For example, if $g = 1.05$, then the norms of the attention and MLP contributions must together be at least 5% of the norm of the `resid_pre` $x_{n-1}$ for layer $n-1$.
Appendix 2: Explaining the difference between `attn_out` and `mlp_out`
Remembering the two plots from Theories for the source of the growth, we notice a surprisingly large y-axis difference between the norms. We repeat those norm curves here:
Now we show the same lines again, but switch to using the standard deviation. This is equivalent[1] (norm divided by standard deviation = √D = 40) but more intuitive to reason about. We also divide all lines by $1.045^N$ to make the lines fit better into the plot. The difference from `resid_pre` to `resid_post` at each layer has to be approximately a factor of 1.045 for the exponential growth to hold.
Intuitively, we expected these standard deviations to add up like those of independent (Gaussian) random vectors, $\sigma_{a+b}^2 = \sigma_a^2 + \sigma_b^2$ ("error propagation" formula), but this doesn't work. We realized that correlated random vectors can have a higher summed variance, up to a maximum of $\sigma_{a+b}^2 = (\sigma_a + \sigma_b)^2$. It would be interesting to see where in that range `attn_out` and `mlp_out` lie, i.e. how correlated the Attention and MLP outputs are with the residual stream input.
In both plots we see that the uncorrelated addition of residual stream and sub-layer output (lower end of the range) is much lower than required, providing nowhere near the observed growth for the residual stream. Our (somewhat extreme) upper end of the range is much larger, so if `attn_out` or `mlp_out` were perfectly proportional to their input residual stream we would see a much larger growth.
This does not directly affect our argument, which relies on just realizing the exponential growth at various points. We shared this since we initially did not take into account the correlation, and found this interesting.
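The range between the two formulas is spanned by the correlation coefficient ρ, via the standard identity σ²(a+b) = σ²(a) + σ²(b) + 2ρ·σ(a)·σ(b). The sketch below uses hypothetical numbers (a residual stream with standard deviation 1 and a sub-layer output with standard deviation 0.1, neither taken from the plots) to show how one could back out the correlation implied by a given growth factor.

```python
import numpy as np

def std_of_sum(sigma_a, sigma_b, rho):
    """Std of a + b for components with stds sigma_a, sigma_b and correlation rho."""
    return np.sqrt(sigma_a**2 + sigma_b**2 + 2.0 * rho * sigma_a * sigma_b)

def implied_correlation(sigma_a, sigma_b, sigma_sum):
    """Correlation needed for a + b to have std sigma_sum."""
    return (sigma_sum**2 - sigma_a**2 - sigma_b**2) / (2.0 * sigma_a * sigma_b)

# Hypothetical numbers, not measurements from the post.
sigma_resid, sigma_out = 1.0, 0.1
required_growth = 1.045

lower = std_of_sum(sigma_resid, sigma_out, rho=0.0)   # uncorrelated ("error propagation")
upper = std_of_sum(sigma_resid, sigma_out, rho=1.0)   # perfectly proportional
rho_needed = implied_correlation(sigma_resid, sigma_out, required_growth * sigma_resid)

print(f"uncorrelated: {lower:.4f}, perfectly correlated: {upper:.4f}")
print(f"correlation needed for 1.045x growth: {rho_needed:.3f}")
```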
Appendix 3: Which of the MLP weights are the source of the exponential growth?
We showed that the MLP output for random layer-independent inputs grows exponentially with layer number. This proves that there is something inherent to the MLP weights in each layer that causes the output to grow, but it does not show us what that is. The behaviour should be predictable from the MLP weights Win, bin, Wout, and bout. In this section we want to show our investigation into this question, even though we have not completely solved it. This will also explain the large difference in norm between the random-input MLP output, and the actual-model MLP output we showed (figure from above inserted again)
Our first step is to plot the norms of the individual MLP weight components. We are very surprised to not see any exponential increase at the expected rate in any of these norms!
The other important part of an MLP layer is the non-linear activation function, in our case GELU. It turns out that the average neuron activation rate, i.e. the fraction of hidden neurons with pre-activation values >0, rises exponentially throughout the network! This is an essential component to the exponential growth of `resid_out`, and we did not notice this trend in any of the weight matrix norms. Note however that we only observe this exponential growth here from layer 5 until ~20.
In the plot below we see that even the neuron activation rate for random inputs (blue line) rises exponentially, so the exponential increase is still inherent to the layer weights; it was just not visible in the norms.
The plot below also explains the difference in L2 norm between actual `mlp_out` and the random outputs (the first plot in this appendix): The neuron activation rate is simply much higher for random inputs (blue line) than in the actual model run (red line), or for randomly-resampled[10] residual stream inputs (orange and green lines). The random vectors clearly differ from actual residual stream vectors in some significant way, but we have not investigated this further.
Note on norm, variance, and mean of the residual stream: All our models' residual streams have mean zero. One can always rewrite the model weights to make the residual stream mean zero, by subtracting the mean of weights from the weights themselves. We use the TransformerLens library which does this by default (`center_writing_weights`). Then the L2 norm $\|x\|_2$ and variance or standard deviation are related by
$$\mathrm{Var} = \mathbb{E}[(x-\mu)^2] = \mathbb{E}[x^2] - \mathbb{E}[x]^2 = \mathbb{E}[x^2] = \|x\|_2^2 / D, \qquad \mathrm{Std} = \|x\|_2 / \sqrt{D}$$
with the residual stream size D.
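The relation between norm and standard deviation can be verified numerically for any mean-zero vector. A small sketch with a random stand-in vector of size D = 1600 (so √D = 40):

```python
import numpy as np

rng = np.random.default_rng(0)
D = 1600                      # residual stream size assumed here

x = rng.normal(size=D)
x = x - x.mean()              # enforce mean zero, as in a centered residual stream

l2_norm = np.linalg.norm(x)
variance = x.var()            # population variance (ddof=0)
std = x.std()
ratio = l2_norm / std         # should equal sqrt(D)

print(f"norm / std = {ratio:.6f}, sqrt(D) = {np.sqrt(D):.6f}")
```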
According to the model card, the Pythia models have "exactly the same" architectures as their OPT counterparts.
Note that we fold together the LayerNorm weights with the following Win or WOV weights. So when we show an exponential increase in, say, WOV weights this might actually be fully or partially coming from the LayerNorm weights. It does not make a conceptual difference (the model still stores exponentially increasing weights), but may affect regularization.
That is, after each residual stream is set to mean 0 and std 1, LN applies learned gain parameters. If the residual stream norm can be recovered using these gain parameters, then there are only dmodel such parameters to scale (and thus penalize). But if Win has to amplify the post-LN residual stream, then there are 4⋅d2model parameters which would have to be scaled up by the same amount. This roughly seems like a quadratic increase in the regularization term in the loss function, but this is just a heuristic guess.
ETA 5/7/23: Apparently LN parameters are not, in general, weight-decayed.
Note that the value-bias bV is set to zero in TransformerLens, using another weight-rewrite trick.
According to Stefan's experimental data, the Frobenius norm of a matrix W is equivalent to the expectation value of the L2 vector norm of W⋅x for a random vector x (sampled from normal distribution and normalized to mean 0 and variance 1). So calculating the Frobenius norm seems equivalent to testing the behaviour on random inputs. Maybe this is a theorem?
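On whether this is a theorem: for a random vector x with identity covariance, E[‖Wx‖²] = tr(W E[xxᵀ] Wᵀ) = ‖W‖_F², so the Frobenius norm equals the root-mean-square of ‖Wx‖ exactly (the plain mean is slightly smaller, and normalizing x to mean 0 and variance 1 changes the covariance slightly). The sketch below checks this by Monte Carlo with a random stand-in matrix, not with model weights.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n_samples = 64, 48, 20000

W = rng.normal(size=(d_out, d_in))   # random stand-in for a weight matrix

# Gaussian inputs, each normalized to mean 0 and variance 1.
x = rng.normal(size=(n_samples, d_in))
x = x - x.mean(axis=-1, keepdims=True)
x = x / x.std(axis=-1, keepdims=True)

out_norms = np.linalg.norm(x @ W.T, axis=-1)   # ||W x|| for every sample
frobenius = np.linalg.norm(W)                  # Frobenius norm (default for 2-D arrays)
rms_norm = np.sqrt((out_norms ** 2).mean())    # root-mean-square of ||W x||
mean_norm = out_norms.mean()                   # plain mean of ||W x||

print(f"Frobenius: {frobenius:.2f}, RMS: {rms_norm:.2f}, mean: {mean_norm:.2f}")
```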
If GPT-2 is using weight decay, then the model learning exponentially large weights is a strong sign that this exponential scaling is really necessary for something else loss-relevant. Apparently the model is taking an exponential loss hit in order to implement these increasing weight norms.
ETA 5/7/23: Apparently LN parameters are not, in general, weight-decayed.
Possible reasons for this discrepancy: (i) We do not take the attention pattern into account. The attention could give above-average weight to the BOS token whose OV-circuit output may be smaller. (ii) We measured the output norm for a random Gaussian input which may be a bad model for the residual stream.
Seeing the exponential growth (∝ `mlp_out`) here would not be necessary but would be sufficient as evidence for theory 1 and against theory 2. This is because random vectors might qualitatively differ from typical residual stream activations and not reproduce the typical behaviour. If they do however reproduce the `mlp_out` scaling, this is unlikely to be coincidence.
Note that we used TransformerLens to test all models, which (by default) does a couple of weight-rewriting tricks (such as `fold_ln`, `center_writing_weights`) that do not change the model output, but might affect the regularization.
Randomly resampled `resid_mid` activations, taken from positions 1 to 6 to avoid BOS and padding tokens.