Although the residuals for each of the four component matrices (after removing the first two principal components) are both small and seem to be noise, it is hard to prove that there is no structure causing the noise to interact constructively and “blow up” when we multiply the matrices.
Have you tried replacing what you believe is noise with actual random noise, with similar statistical properties, and then testing the performance of the resulting model? You may not be able to prove the original model is safe, but you can produce a model that has had all potential structure that you hypothesize is just noise replaced, where you know the noise hypothesis is true.
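For concreteness, one way to run this kind of test might look like the following minimal sketch (the rank-2 split and matched Gaussian noise here are illustrative choices, not a specific proposal):

```python
import numpy as np

def randomize_residual(W: np.ndarray, k: int = 2, seed: int = 0) -> np.ndarray:
    """Keep the top-k principal components of W and replace everything else
    with Gaussian noise matched to the residual's mean and standard deviation.
    (Hypothetical helper for testing the 'it's just noise' hypothesis.)"""
    rng = np.random.default_rng(seed)
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    low_rank = U[:, :k] @ np.diag(S[:k]) @ Vt[:k, :]
    residual = W - low_rank
    noise = rng.normal(residual.mean(), residual.std(), size=residual.shape)
    return low_rank + noise

# One would then rebuild the model with the randomized weight matrices and
# re-measure task accuracy to see whether performance survives.
```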
I believe what you describe is effectively causal scrubbing. Edit: Note that it is not exactly the same as causal scrubbing, which looks at the activations for another input sampled at random.
On our particular model, doing this replacement shows us that the noise bound is actually about 4 standard deviations worse than random, probably because the training procedure (sequences chosen uniformly at random) means we care a lot more about large possible maxes than small ones. (See Appendix H.1.2 for some very sparse details.)
On other toy models we've looked at (modular addition in particular, writeup forthcoming), we have (very) preliminary evidence suggesting that when we randomize the noise, bound tightness drops off steeply (as a function of how compact a proof the noise term comes from), in a very similar fashion to what we see with proofs. There seems to be a pretty narrow band of hypotheses for which the noise is structureless but we can't prove it. This is supported by a handful of comments about how causal scrubbing indicates that many existing mech interp hypotheses in fact don't capture enough of the behavior.
We recently released a paper on using mechanistic interpretability to generate compact formal guarantees on model performance. In this companion blog post to our paper, we'll summarize the paper and flesh out some of the motivation and inspiration behind our work.
Paper abstract
Introduction
One hope for interpretability is that as we get AGI, we’ll be able to use increasingly capable automation to accelerate the pace at which we can interpret ever more powerful models. These automatically generated interpretations need to satisfy two criteria: they need to correspond to what the model is actually doing, and they need to compress the model’s behavior into something substantially shorter than the model itself.
Progress happens best when there are clear and unambiguous targets and quantitative metrics. For correspondence, the field has developed increasingly targeted metrics for measuring performance: ablations, patching, and causal scrubbing. In our paper, we use mathematical proof to ensure correspondence, and present proof length as the first quantitative measure of explanation compression that is theoretically grounded, objective, and avoids trivial Goodharting.
We see our core contributions in the paper as:
While the proofs themselves (in particular, our proof whose length is linear in the number of model parameters for the parts of the model we understand adequately) may be of most interest to those working on guarantees, we believe the insights about explanation compression from this methodology and our results apply more broadly to the field of mechanistic interpretability.
Correspondence vs compression
Consider two extremal proof strategies: minimal compression with maximal correspondence, and maximal compression with minimal correspondence.
The central obstacle to good explanations is the trade-off between making arguments compact and obtaining sufficiently tight bounds. Broadly, being able to compress behavior shows that we have understanding. The thesis of our paper is that constructing compact proofs with good bounds requires and implies mechanistic understanding.
How to compact a proof
We can think of the brute force proof as treating every input to the model as a distinct case, where exhaustive case analysis gives us the final bound. Constructing a shorter argument requires reasoning over fewer distinct cases.
The naïve way to get shorter arguments is to throw away large parts of the input distribution (for example, focusing only on a distribution made of contrast pairs). To avoid doing this, we instead abstract away details of the specific model implementation from our argument, group the inputs, and construct a cheap-to-compute proxy for the model’s behavior on all data points in each group.
To ensure that the argument is a valid formal proof and that our proxies are valid, we bound the model’s worst-case behavior on each group when using the proxies. If our relaxations lose critical details, our bounds become much less tight; in contrast, if our relaxations abstract away only irrelevant details, we can still obtain tight bounds. Mechanistic understanding of model internals allows us to choose better groups and proxies, which in turn leads to shorter proofs with tighter bounds.
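As a rough sketch of this pattern (illustrative, not the paper's actual code), a proof strategy of this shape boils down to: partition the inputs into groups, compute a cheap worst-case proxy on each group, and count only the inputs in groups the proxy certifies:

```python
from typing import Callable, Dict, Hashable

def certified_accuracy_lower_bound(
    group_sizes: Dict[Hashable, int],                 # group id -> number of inputs in that group
    worst_case_margin: Callable[[Hashable], float],   # cheap proxy: lower bound on the correct-answer margin over the group
) -> float:
    """Lower-bound the model's accuracy from a group decomposition.

    Every input in a group whose worst-case margin is positive is certified
    correct; inputs in all other groups are pessimistically counted as wrong,
    so the returned value is a valid lower bound whenever the proxy is.
    """
    total = sum(group_sizes.values())
    certified = sum(n for g, n in group_sizes.items() if worst_case_margin(g) > 0)
    return certified / total
```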
Proofs on a toy model
In our paper, we prototype the proofs approach on small transformers trained on a Max-of-K task. We construct three classes of proof strategies, each of which uses more model understanding than the last and is correspondingly cheaper to check.
The first is the brute force proof, which uses no mechanistic understanding and treats each possible input as its own case.
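A minimal sketch of what this amounts to, assuming a hypothetical `model` that maps a batch of token ids to logits of shape [batch, seq, d_vocab]:

```python
import itertools
import torch

def brute_force_accuracy(model, d_vocab: int, seq_len: int) -> float:
    """Exhaustively check every length-`seq_len` sequence over a vocabulary of
    size `d_vocab`: the model is 'correct' on a sequence iff its final-position
    argmax equals the true maximum token. Cost is one forward pass per input,
    i.e. d_vocab ** seq_len cases."""
    correct = total = 0
    for seq in itertools.product(range(d_vocab), repeat=seq_len):
        x = torch.tensor([seq])
        logits = model(x)                      # [1, seq_len, d_vocab]
        pred = logits[0, -1].argmax().item()   # prediction at the final position
        correct += int(pred == max(seq))
        total += 1
    return correct / total
```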
The second class of proof strategies breaks models into paths (i.e. circuits), and groups the inputs by their max token, the largest non-max token, and the final query token. Besides the path decomposition, these strategies use two facts: that our models generate the correct output by paying attention to the max token in a sequence, and that the amount of attention paid to a given max token is primarily determined by the identity of the max token together with the largest non-max token and the final query token.
The third class of proof strategies examines each path independently, and groups the inputs for each path by two of the three types of tokens mentioned previously. To get cheaper proxies for model behavior, strategies in this class use knowledge that we might associate with traditional “mech interp”, such as the fact that the QK circuit is approximately rank one, with the principal component measuring the size of the key token.
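As a concrete (if expensive) sanity check of this kind of claim, one can form the full score table and look at its spectrum directly; the sketch below is illustrative, with E, Q, K standing for the embedding, query, and key matrices rather than the paper's exact parametrization (the proofs themselves avoid forming this product, as discussed below):

```python
import numpy as np

def qk_rank_one_check(E: np.ndarray, Q: np.ndarray, K: np.ndarray):
    """Measure how close the attention score table E Q K^T E^T is to rank one.

    Returns the fraction of squared spectral mass in the top singular
    component, and the corresponding key-side direction, which for our models
    should vary roughly monotonically with token value.
    """
    EQKE = E @ Q @ K.T @ E.T                     # [d_vocab, d_vocab] score table
    U, S, Vt = np.linalg.svd(EQKE)
    rank_one_fraction = S[0] ** 2 / np.sum(S ** 2)
    key_direction = Vt[0]
    return rank_one_fraction, key_direction
```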
We then qualitatively and quantitatively examine the connection between mechanistic understanding and proof length/bound tightness:
Reasoning about error in compressing the weights
As we impose tighter demands on proof length, there is a steep drop-off in tightness of bound. This is perhaps even more apparent visually:
The fundamental issue seems to be that our compression is lossy, and worst-case error bounds in compression add up quickly. Going from $\mathcal{O}(d_{\text{vocab}}^2 d_{\text{model}})$ to $\mathcal{O}(d_{\text{vocab}} d_{\text{model}})$ in the skip connection $EU$ costs us about 15%. Going from $\mathcal{O}(d_{\text{vocab}}^2 d_{\text{model}})$ to $\mathcal{O}(d_{\text{vocab}} d_{\text{model}})$ in the QK attention circuit $EQKE$ costs us about 30%–40% in our best worst-case accuracy bound.
Focusing on the QK circuit, we claim that $EQKE$ is approximately rank one:
One way to cheaply check this without multiplying out $EQKE$ (which would take $\mathcal{O}(d_{\text{vocab}}^2 d_{\text{model}})$ time) is to factor out the principal components of each of the component matrices. Although the residuals for each of the four component matrices (after removing the first two principal components) are both small and seem to be noise, it is hard to prove that there is no structure causing the noise to interact constructively and “blow up” when we multiply the matrices.
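A standard way to get such a worst-case bound cheaply (a sketch of the general idea, not the paper's exact argument) is to bound the product of the residual factors by a product of norms that are linear-time in the number of matrix entries:

```python
import numpy as np

def residual_product_bound(A_err: np.ndarray, B_err: np.ndarray) -> float:
    """Worst-case bound on the contribution of two residual ('noise') factors
    to their product, without ever forming A_err @ B_err:

        ||A_err @ B_err||_2 <= ||A_err||_F * ||B_err||_F

    The Frobenius norms cost O(#entries), versus a full matrix multiply for
    the product itself. The bound is loose exactly when the residuals do have
    structure that lines up, which is why unstructured-looking noise that we
    cannot prove unstructured still hurts the final bound.
    """
    return np.linalg.norm(A_err, ord='fro') * np.linalg.norm(B_err, ord='fro')
```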
In general, it's often hard to prove an absence of structure, and we believe that this will be a key limitation on scaling proofs, even with a high degree of mechanistic understanding.
See our paper for more details.
Our takeaways
We think that there are two main takeaways from our work:
The work on this project was done by @Jason Gross, @rajashree, @Thomas Kwa, @Euan Ong, Chun Hei Yip, Alex Gibson, Soufiane Noubir, and @LawrenceC. The paper write-up was done by @LawrenceC, @rajashree, and @Jason Gross; @Adrià Garriga-alonso assisted in writing up this blog post.
Citation Info
To reference this work, please cite our paper:
We think that our work serves as an example of how to leverage a downstream task to pick a metric for evaluating mechanistic interpretations. Specifically, formally proving that an explanation captures why the model has a particular behavior can be thought of as a pessimal ablation of the parts the explanation claims are unimportant.[2] That is, if we can replace the unimportant parts of the model with their worst possible values (relative to our performance metric) while maintaining performance, this provides a proof that our model implements the same behavior as in our explanation.
[2] Compare to the zero, mean, or resample ablation, where we replace the unimportant parts of the model with zeros, their mean values, or randomly sampled values from other data points.
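To make the contrast concrete, here is an illustrative skeleton (the `eval_accuracy` callback and candidate set are hypothetical) of how pessimal ablation differs from those single-replacement ablations:

```python
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")

def pessimal_ablation(eval_accuracy: Callable[[T], float],
                      candidate_values: Iterable[T]) -> float:
    """Replace the claimed-unimportant component with whichever value in
    `candidate_values` hurts the performance metric most; the result is a
    guarantee that holds for *any* value in that set. Zero, mean, or resample
    ablation each test only a single replacement value instead."""
    return min(eval_accuracy(v) for v in candidate_values)
```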