Scare quotes are here because their example is really disanalogous to MLP superposition. I.e., as they point out in their second post, their task is naturally thought of as decomposing into two attention heads, and a model with n >= 2 heads isn't really "placing circuits in superposition" so much as performing a natural task decomposition that they hadn't thought of.
In fact, that result feels like a cautionary tale: just because a model implements an algorithm in a non-basis-aligned manner does not mean it is implementing an approximate algorithm that exploits near-orthogonality in high-dimensional space (the traditional kind of residual-stream/MLP-activation superposition), nor does it mean the model is "implementing more circuits than is feasible" (i.e. the sense that they try to construct in the May 2023 update). You might just not understand the algorithm the model is implementing!
If I were to speculate further, it seems like they were tripped up by continuing to think of one-layer attention-only models as sets of skip-trigrams, which they are not. More poetically: if your "natural" basis isn't natural, then of course your model won't use your "natural" basis.
Note that this construction isn't optimal, in part because output tokens corresponding to the same token occurring twice appear half as often as those corresponding to two distinct tokens, while this construction achieves lower log loss in the repeated-token case than in the two-distinct-token case. But the qualitative analysis carries through regardless.
Background
This project was inspired by Anthropic’s post on attention head superposition, which constructed a toy model trained to learn a circuit identifying skip-trigrams that are OV-incoherent (attending from multiple destination tokens to a single source token) as a way to ensure that superposition would occur. Since the OV circuit only sees half of the information – the source tokens – a single head cannot distinguish between multiple possible skip-trigrams. As long as there are more skip-trigrams sharing a source token than there are heads, the model cannot represent them in the naive way, and may resort to superposition.
In a more recent update post, they found that the underlying algorithm for OV-incoherent skip-trigrams in a simpler 2-head model implemented a conditional on the source token. One head predicts the output for the skip-trigram [current token] … [current token] -> [ground truth([0] … [current token])], one of which will yield the right answer. The second head destructively interferes with this result by writing out the negative of the first head's logit contribution whenever the source token is not the one common to all skip-trigrams (in this case, [0]). Because their example cleanly separated tasks between the two attention heads, the authors argued that it was more like the building of high-level features out of low-level ones than a single feature superimposed across multiple attention heads.
OV-coherent Superposition
Instead, we claim there is an analogous force pushing the model toward adopting a distributed representation/head superposition whenever the model must learn patterns that require implementing nonlinear functions of multiple source tokens given a fixed destination token. We call this “OV-coherent” superposition: despite the information at the destination position being fixed, the information copied from an attended-to token depends on the information at source tokens to which it is not attending. This pushes the model to form interference patterns between heads attending to different tokens.
To test this, we implemented a 1-layer, attention-only toy model with one-hot (un)embeddings trained to solve a problem requiring attention to multiple source tokens, described below. Here, we focus on a 2-head model which solves the task with perfect accuracy, and lay out some interesting motifs for further investigation.
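For concreteness, a minimal PyTorch sketch of this kind of architecture is given below, with hyperparameters roughly as in the example model; the initialization, positional handling, and training loop are illustrative assumptions rather than an exact specification of our setup.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class OneLayerAttnOnly(nn.Module):
    """1-layer, attention-only toy model with one-hot (un)embeddings and no LayerNorm."""
    def __init__(self, d_vocab=12, n_heads=2, d_head=5):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_head
        self.d_model = d_vocab  # one-hot embedding: the residual stream basis is the token basis
        init = lambda *shape: nn.Parameter(torch.randn(*shape) * 0.02)
        self.W_Q = init(n_heads, self.d_model, d_head)
        self.W_K = init(n_heads, self.d_model, d_head)
        self.W_V = init(n_heads, self.d_model, d_head)
        self.W_O = init(n_heads, d_head, self.d_model)

    def forward(self, tokens):                                    # tokens: (batch, seq)
        x = F.one_hot(tokens, self.d_model).float()               # W_E = identity
        q = torch.einsum("bsd,hdk->bhsk", x, self.W_Q)
        k = torch.einsum("bsd,hdk->bhsk", x, self.W_K)
        v = torch.einsum("bsd,hdk->bhsk", x, self.W_V)
        scores = torch.einsum("bhqk,bhsk->bhqs", q, k) / self.d_head ** 0.5
        causal = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool,
                                       device=scores.device), diagonal=1)
        pattern = scores.masked_fill(causal, float("-inf")).softmax(dim=-1)
        z = torch.einsum("bhqs,bhsk->bhqk", pattern, v)           # per-head value reads
        out = torch.einsum("bhqk,hkd->bqd", z, self.W_O)          # summed over heads
        return x + out                                            # W_U = identity, so this is the logits
```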
Key Takeaways:
Heads in our model seem to implement nested conditional statements that exploit the if-else nature of the QK circuits. This allows the model to learn to write more specific information conditional on attending to certain tokens, given that it can implicitly rule out the existence of other tokens elsewhere in the context. The heads furthermore implement these nested conditionals in such a way that they distribute important source tokens between them, and constructively interfere to produce the correct answer.
Most of the time, we found that this “conditional dependence” relies on heads implementing an “all or nothing” approach to attention (sketched just after these takeaways). Heads do not generally* spread their attention across multiple interesting tokens, but instead move through the hierarchy of features in their QK circuits and attend to the most “interesting” (still a poorly defined term!) token present. This seems to be a common property of attention patterns in real-model heads as well.
When there are multiple important source tokens in the context, heads implementing this interference scheme tend to learn QK circuits that distribute the tokens amongst themselves, so that no crucial information goes unattended. In 2-head models, this manifests as reversed “preference orderings” over source tokens, though potentially more complicated arrangements work in larger models.
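As a cartoon of this “all or nothing” behaviour, each head’s attention can be read as a nested if-else over its preference ordering. The following is a hedged Python sketch of the idea rather than anything extracted from the model; the example ordering is the one H0 learns in the example model below.

```python
# Reading a head's QK circuit as a nested conditional (illustrative sketch only).
# The head puts (nearly) all of its attention on the highest-priority signal
# token present, so its OV output can condition on the absence of every token
# above that one in the ordering.
def attended_token(context, preference_ordering):
    for token in preference_ordering:    # e.g. [2, 3, 4, 1] for H0 below
        if token in context:
            return token                 # "all or nothing": attend here, ignore the rest
    return None                          # no signal token present in the context
```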
Problem Details:
The inputs are sequences of integers in the range [0, 11]. They can be broken down into two categories:
Example Model Solution: (d_head = 5, n_heads = 2, d_model = 11, no LayerNorm)
QK-circuit Behaviour:
Both heads learn a strict “preference ordering” over signal tokens, with the ordering generally reversed between the two heads. This guarantees that, for any given context, both signal tokens are fully attended to. In the example below, H0 attends solely to “2” if it is in the context, else to “3”, then “4”, and finally “1”. H1 instead has the preference ordering “1”, else “4”, then “3”, and finally “2”.
Attention scores from “0” to each signal token for each head
While this “flipped hierarchy” scheme is overwhelmingly more common, the model sometimes learned a “split attention” scheme during training, in which the heads would attend to the same tokens to varying degrees. Notably, we only saw split attention for d_head < 3, suggesting that this may be a bottleneck issue. However, we should acknowledge that the model may have other ways of solving this simple task than the one outlined above, and note that this was by no means an exhaustive analysis.*
OV-circuit Behaviour:
In the OV circuit, heads use the attention hierarchy described above to write to the logits of completions consistent with the tokens they attend to. For example, if H0 attends to “1”, it will contribute positively to the output logits of the tokens mapped to by the unordered pairs (1,1), (1,2), (1,3), and (1,4).
However, the heads also exploit a “conditional dependence” between source tokens: if H0 attends to token X, it “knows” that the other source position does not contain a token Y higher in its attention hierarchy than X, or else it would be attending there instead. It can therefore safely withhold any contribution to the logits of the outputs mapped to by such pairs (X, Y).
We can see this clearly in the graph below, which shows the direct logit effect of each head conditional on attending to each signal token:
Logit contribution of heads to completions corresponding to pairs of signal tokens (x-axis), conditional on attention to each source token (y-axis). Cross-referencing with the attention hierarchy above shows that heads get more specific in their outputs on less interesting tokens, and that for any given pair of inputs they constructively interfere on the correct answer.
When a head attends to its “favourite” token (“2” and “1” for H0 and H1 respectively), it implicitly has no information about the other position, and so writes roughly uniformly to the logits of all possible completions. But as they run down their preference orderings, the heads successively write strongly to the logits of one fewer completion. For contexts containing a head’s “least favourite” token repeated, that head can confidently guess the correct answer by a process of elimination over its own attention hierarchy.
In contexts where both heads write to multiple outputs, they will only constructively interfere on the correct token. For example, when the sequence is “3…4 0”, one head writes to the completions corresponding to (3, …) and the other to those corresponding to (4, …). However, only the correct answer – the completion corresponding to (3, 4) – receives positive logit contributions from both heads.
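To make this concrete, the following is a schematic, hand-coded Python version of the algorithm as we have described it. The preference orderings are the ones from the example model, but the unit logit weights and the exact selection rule are simplifying assumptions, not the trained weights.

```python
# Schematic re-implementation of the described algorithm (not the trained model):
# each head attends to its most-preferred signal token present, then writes +1 to
# every completion consistent with that token AND with the absence of any token
# it prefers more (the "conditional dependence").
from itertools import combinations_with_replacement

SIGNAL_TOKENS = [1, 2, 3, 4]
PAIRS = list(combinations_with_replacement(SIGNAL_TOKENS, 2))  # 10 possible completions

PREFS = {"H0": [2, 3, 4, 1],   # preference orderings from the example model,
         "H1": [1, 4, 3, 2]}   # reversed between the two heads

def head_logits(head, signal_pair):
    prefs = PREFS[head]
    attended = min(signal_pair, key=prefs.index)       # "all or nothing" attention
    allowed = set(prefs[prefs.index(attended):])       # attended token and everything below it
    return {p: float(attended in p and set(p) <= allowed) for p in PAIRS}

def predict(signal_pair):
    h0, h1 = head_logits("H0", signal_pair), head_logits("H1", signal_pair)
    totals = {p: h0[p] + h1[p] for p in PAIRS}          # constructive interference
    return max(totals, key=totals.get)                  # only the true pair scores 2

assert predict((3, 4)) == (3, 4)                        # the "3 … 4 0" context above
assert predict((2, 2)) == (2, 2)                        # repeated least-favourite token for H1
assert all(predict(p) == p for p in PAIRS)              # correct on every signal pair
```

With exactly reversed orderings, the true pair is the only completion that lies in both heads’ “allowed” sets, which is why the constructive interference singles it out.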
Observations
We think this shows an interesting example of heads performing computation in superposition. Moreover, the incentive pushing the model towards interference is qualitatively very different from the OV-incoherence explored by Anthropic. Instead of needing to copy different information from a fixed source token to a variable destination token, our problem imposes the constraint that although the destination token is fixed, the information that a head would need to copy from a source token depends on information elsewhere in the context.
For n_heads = 2, the “mutual reverse ordering” between the heads is an elegant way for the model to ensure that, no matter which signal tokens show up in the context, each will be attended to.* It is plausible that real world models implement this mechanism (albeit much less crisply) to distribute important features across heads.
This conditional dependence seems to us an interesting way to view the relationship between the QK and OV circuits. OV circuits can extract more information from individual tokens by learning to exploit the distribution of features of their QK circuits. One way of interpreting this might be that heads can implicitly copy information from multiple tokens even when attending to a single token. In other words: the heads can learn to write information as a function of the broader context by exploiting the fact that conditioning on their attention to a token gives information about the distribution of tokens in the entire sequence.
Future Work - Toy Models
Sparsity:
First, we did not vary the sparsity or importance of signal tokens in these experiments. We have a pretty poor idea of how these variables affect the behaviours we observed here, so this seems like something worth looking into.
*The split-attention model, d_head bottlenecks, and the surprising resourcefulness of networks:
We presented our findings for d_head = 5 because the learned algorithm is more human-parsable, but we were surprised by how small we could make this model while still achieving perfect accuracy on a signal-only test set. We were particularly surprised by the model’s capability at d_head = 1, since this essentially limits each head to a scalar degree of freedom per token. The split-attention mechanisms we stumbled upon for d_head < 3 are much harder to parse, and may rely on more complicated, context-specific solutions. However, the fact that we haven’t seen these for larger d_head supports the idea that, for d_head >= 3, the signal tokens can be stored orthogonally in that dimension.
We were also surprised that this problem can be solved with a single head, as long as d_head >= 4. Intuitively, once a head has enough dimensions to store every "interesting" token orthogonally, its OV circuit can simply learn to map each of these basis vectors to the corresponding completions. There may be a trade-off here between degrees of freedom in d_head and in n_heads, though this is not super crisp. There is also the possibility that superposition is happening both across heads (in the n_heads dimension) and within each head (in the d_head dimension). In contrast with the “hierarchical” superposition presented above, we can call the latter type “bottleneck superposition”. While this complicates the picture substantially, it also opens up some pretty interesting possibilities and is something we’d like to investigate more thoroughly!
Future Work - LLMs
We initially set out to investigate attention head superposition “in the wild” as part of Neel Nanda’s SERI MATS stream, by studying possible linear combinations of name mover heads in Redwood Research’s IOI task. At first, this seemed like a good way to study a realistic (but not too realistic) problem that already had a circuit of attention heads in place. However, we quickly realized that this task represents a strange middle ground in terms of complexity: simple enough to want to represent a single feature (name moving), but complicated enough to have more features hidden within it. While even the toy model seems to have opened up some puzzling new directions, we also think it would be worthwhile to look at head superposition in real-world LLMs.
For example, it could be interesting to look for cases of inverted preference orderings in the wild by hunting for pairs of QK circuits that exhibit similar behaviour. This will likely be messy, and will require a better understanding of the relationship between the two types of superposition mentioned above.