[Lucius] Identify better SAE sparsity penalties by reasoning about the distribution of feature activations
- In sparse coding, one can derive what prior over encoded variables a particular sparsity penalty corresponds to. E.g. an L1 penalty assumes a Laplacian prior over feature activations, while a log(1+a^2) penalty would assume a Cauchy prior. Can we figure out what distribution of feature activations over the data we’d expect, and use this to derive a better sparsity penalty that improves SAE quality?
This is very interesting! What prior does log(1+|a|) correspond to? And what about using log(1+|a|) instead of log(1+a^2)? Does this only hold if we expect feature activations to be independent (rather than, say, mutually exclusive)?
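For concreteness, here is a minimal numerical sketch of the penalty-to-prior correspondence, using the usual MAP reading (penalty = negative log-prior, up to an additive constant); the final comment is my own reading of the log(1+|a|) question:

```python
# MAP view: a sparsity penalty rho(a) corresponds (up to an additive
# constant) to a prior p(a) proportional to exp(-rho(a)).
import numpy as np

a = np.linspace(-5, 5, 101)

# Laplace prior p(a) = 0.5 * exp(-|a|)  =>  -log p(a) = |a| + log 2  (the L1 penalty)
laplace_neglogp = -np.log(0.5 * np.exp(-np.abs(a)))
print(np.allclose(laplace_neglogp - np.log(2), np.abs(a)))  # True

# Cauchy prior p(a) = 1 / (pi * (1 + a^2))  =>  -log p(a) = log(1 + a^2) + log pi
cauchy_neglogp = -np.log(1.0 / (np.pi * (1.0 + a**2)))
print(np.allclose(cauchy_neglogp - np.log(np.pi), np.log(1.0 + a**2)))  # True

# By the same recipe, log(1+|a|) would correspond to p(a) proportional to
# 1 / (1 + |a|), which is not normalizable on the reals, so it only makes
# sense as an improper (or truncated) prior.
```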
[Nix] Toy model of feature splitting
- There are at least two explanations for feature splitting I find plausible:
- Activations exist on higher-dimensional manifolds in feature space; feature splitting is a symptom of one higher-dimensional, mostly-continuous feature being chunked into discrete features at different resolutions.
- There is a finite number of highly related discrete features that activate on similar (but not identical) inputs and cause similar (but not identical) output actions. These can be summarized as a single feature with reasonable explained variance, but are better summarized as a collection of “split” features.
These do not sound like different explanations to me. In particular, the distinction between "mostly-continuous but approximated as discrete" and "discrete but very similar" seems ill-formed. All features are in fact discrete (because floating point numbers are discrete) and approximately continuous (because we posit that replacing floats with reals won't change the behavior of the network meaningfully).
As far as toy models go, I'm pretty confident that the max-of-K setup from Compact Proofs of Model Performance via Mechanistic Interpretability will be a decent toy model. If you train SAEs post-unembed (probably also pre-unembed) with width d_vocab, you should find one feature for each sequence maximum (roughly). If you train with a larger SAE width, I expect each feature to split into features corresponding to the choice of query token, largest non-max token, and the number of copies of the maximum token. (How the SAE training data is distributed will change which exact features (principal directions of variation) are important to learn.) I'm quite interested in chatting with anyone working on / interested in this, and I expect my MATS scholar will get to testing this within the next month or two.
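For concreteness, a rough sketch of the kind of SAE training loop this experiment would need; `model`, `batches`, and the hyperparameters are placeholder assumptions, not anything from the paper:

```python
# Rough sketch of the experiment described above. `model` (a trained
# max-of-K transformer whose output is [batch, seq, d_vocab] logits) and
# `batches` (an iterable of integer token sequences) are assumed, not defined.
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    def __init__(self, d_in: int, d_hidden: int):
        super().__init__()
        self.enc = nn.Linear(d_in, d_hidden)
        self.dec = nn.Linear(d_hidden, d_in)

    def forward(self, x):
        acts = torch.relu(self.enc(x))  # feature activations
        return self.dec(acts), acts

def train_sae(model, batches, d_vocab, width_mult=1, l1_coeff=1e-3, lr=1e-3):
    # width_mult=1 is the "one feature per sequence maximum" regime;
    # larger width_mult is where feature splitting should show up.
    sae = SparseAutoencoder(d_vocab, width_mult * d_vocab)
    opt = torch.optim.Adam(sae.parameters(), lr=lr)
    for tokens in batches:
        with torch.no_grad():
            logits = model(tokens)[:, -1, :]  # post-unembed activations
        recon, acts = sae(logits)
        loss = (recon - logits).pow(2).mean() + l1_coeff * acts.abs().sum(-1).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return sae
```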
Edit: I expect this toy model will also permit exploring:
[Lee] Is there structure in feature splitting?
- Suppose we have a trained SAE with N features. If we apply e.g. NMF or SAEs to these directions, are there directions that explain the structure of the splitting? As in, suppose we have a feature for math and a feature for physics. And suppose these split into (among other things)
- 'topology in a math context'
- 'topology in a physics context'
- 'high dimensions in a math context'
- 'high dimensions in a physics context'
- Is the topology-ifying direction the same for both features? Is the high-dimensionifying direction the same for both features? And if so, why did/didn't the original SAEs find these directions?
I predict that whether or not the SAE finds the splitting directions depends on details about how much non-sparsity is penalized and how wide the SAE is. Given enough capacity, the SAE benefits (sparsity-wise) from replacing the (topology, math, physics) features with (topology-in-math, topology-in-physics), because split features activate more sparsely. Conversely, if the sparsity penalty is strong enough and there is not enough capacity to split, the loss recovered from having a topology feature at all (on top of the math/physics feature) may not outweigh the cost in sparsity.
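A back-of-the-envelope version of that sparsity argument, with made-up co-occurrence rates:

```python
# Back-of-the-envelope numbers for the sparsity argument above (all made up).
# Assume "topology" text is 2% of tokens, split evenly between math and
# physics contexts, and assume a split feature absorbs its context direction.
p_topology_math = 0.01
p_topology_physics = 0.01

# Unsplit dictionary (topology, math, physics): a topology token activates
# both the generic topology feature and its context feature -> 2 active features.
l0_contribution_unsplit = 2 * (p_topology_math + p_topology_physics)

# Split dictionary (topology-in-math, topology-in-physics): each topology
# token activates a single split feature -> 1 active feature, at the cost
# of one extra dictionary entry.
l0_contribution_split = 1 * (p_topology_math + p_topology_physics)

print(l0_contribution_unsplit, l0_contribution_split)  # 0.04 vs 0.02 per token
```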
Progress Measures for Grokking via Mechanistic Interpretability (Neel Nanda et al.) - nothing important in mech interp has properly built on this IMO, but there's just a ton of gorgeous results in there. I think it's the most (only?) truly rigorous reverse-engineering work out there.
Totally agree that this has gorgeous results, and this is what got me into mech interp in the first place! Re "most (only?) truly rigorous reverse-engineering work out there": I think the clock and pizza paper seems comparably rigorous, and there's also my recent Compact Proofs of Model Performance via Mechanistic Interpretability (and Gabe's heuristic analysis of the same Max-of-K model), and the work one of my MARS scholars did showing that some pizza models use a ReLU to compute numerical integration, which is the first nontrivial mechanistic explanation of a nonlinearity found in a trained model (nontrivial in the sense that it asymptotically compresses the brute-force input-output behavior with a (provably) non-vacuous bound).
I believe what you describe is effectively Causal Scrubbing. Edit: Note that it is not exactly the same as causal scrubbing, which looks at the activations for another input sampled at random.
Doing this replacement on our particular model shows us that the noise bound is actually about 4 standard deviations worse than random, probably because the training procedure (sequences chosen uniformly at random) means we care a lot more about large possible maxes than small ones. (See Appendix H.1.2 for some very sparse details.)
On other toy models we've looked at (modular addition in particular, write-up forthcoming), we have (very) preliminary evidence suggesting that randomizing the noise causes a steep drop-off in bound tightness (as a function of how compact a proof the noise term comes from), in a very similar fashion to what we see with proofs. There seems to be a pretty narrow band of hypotheses for which the noise is structureless but we can't prove it. This is supported by a handful of comments about how causal scrubbing indicates that many existing mech interp hypotheses in fact don't capture enough of the behavior.
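A rough sketch of the randomization check being described, assuming the final logits decompose additively into a signal term and a noise term (the tensor names and shapes below are assumptions for illustration):

```python
# Sketch of the randomization check, assuming an additive decomposition
# logits = signal + noise; `signal`, `noise` ([batch, d_vocab]) and
# `labels` ([batch]) are hypothetical precomputed tensors.
import torch

def accuracy(logits, labels):
    return (logits.argmax(-1) == labels).float().mean().item()

def randomized_noise_accuracy(signal, noise, labels):
    # Swap in the noise term from a randomly chosen other input
    # (the resampling flavour of the check discussed above) and compare
    # against keeping the true noise term.
    perm = torch.randperm(noise.shape[0])
    return accuracy(signal + noise[perm], labels), accuracy(signal + noise, labels)
```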
We propose a simple fix: Use sqrt(|a|) instead of |a| as the sparsity penalty, which seems to be a Pareto improvement over L1 (at least in some real models, though results might be mixed) in terms of the number of features required to achieve a given reconstruction error.
When I was discussing better sparsity penalties with Lawrence and mentioned that I observed some instability in the sqrt(|a|) penalty in toy models of superposition, he pointed out that the gradient of sqrt(|a|) explodes near zero, meaning that features with "small errors" that cause them to have very small but non-zero overlap with some activations might be killed off entirely rather than merely having the overlap penalized.
See here for a brief write-up and animations.
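A quick PyTorch check of the blow-up (the epsilon-smoothed variant at the end is just one possible mitigation, not necessarily what the write-up does):

```python
# The derivative of sqrt(|a|) is 1 / (2 * sqrt(|a|)), which blows up as a -> 0,
# while the L1 gradient stays fixed at 1.
import torch

for val in (1.0, 1e-2, 1e-4, 1e-6):
    a = torch.tensor(val, requires_grad=True)
    torch.sqrt(a.abs()).backward()
    print(f"a={val:g}  d/da sqrt(|a|) = {a.grad.item():g}")
# a=1      0.5
# a=0.01   5
# a=0.0001 50
# a=1e-06  500

# One possible mitigation (an assumption, not necessarily what the linked
# write-up does): smooth the penalty with a small epsilon.
def sqrt_penalty(acts, eps=1e-6):
    return torch.sqrt(acts.abs() + eps).sum(-1).mean()
```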
"explanation of (network, dataset)": I'm afraid I don't have a great formalish definition beyond just pointing at the intuitive notion.
What's wrong with "proof" as a formal definition of explanation (of behavior of a network on a dataset)? I claim that description length works pretty well on "formal proof"; I'm in the process of producing a write-up on results exploring this.
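As a toy illustration of proof length as a description-length metric (all numbers and the case-count formula below are made-up assumptions, not figures from the write-up):

```python
# Toy, made-up comparison of certificate sizes for a max-of-K model on
# length-n_ctx sequences over d_vocab tokens.
d_vocab, n_ctx = 64, 4

brute_force_cases = d_vocab ** n_ctx      # enumerate every input: 16,777,216
mechanistic_cases = n_ctx * d_vocab ** 2  # hypothetical per-(query, max) analysis: 16,384

# Under "proof length as description length", the shorter mechanistic proof
# is the better explanation of the same behavior.
print(brute_force_cases / mechanistic_cases)  # 1024.0
```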
I believe the closest research to this topic is under the heading "Performative Power" (cf., e.g., this arXiv paper). I think "The Age of Surveillance Capitalism: The Fight for a Human Future at the New Frontier of Power" by Shoshana Zuboff is also a pretty good book that seems related.