Cool work! Do you/any others have plans to train and open source some sparse autoencoders on some open source LLMs? (Eg smaller Pythia or GPT-2 models). Seems like a cool resource to exist, that might help enable some of this work.
Yep! We are planning to do exactly that for (at least) the models we focus on in the paper (Pythia-70m + Pythia-410m), and probably also GPT2 small. We are also working on cleaning up our codebase (https://github.com/HoagyC/sparse_coding) and implementing some easy dictionary training solutions.
Awesome! On the residual stream, or also on the MLP/attention outputs? I think both would be great if you have the resources, I expect there's a lot of interest in both and how they interact. (IMO the Anthropic paper training on MLP activations is equivalent to training it on the MLP layer output, just with 4x the parameters). Ideally if you're doing it on attn_out you could instead do it on the mixed value (z in TransformerLens) which has the same dimensionality, but makes it super clear which head the dictionary is looking at, and is robust to head superposition.
Having a clean codebase also seems like a really useful resource, esp if you implement some of the tricks from the Anthropic paper, like neuron resampling. Looking forward to it!
I actually do have some publicly hosted (residual stream only), along with some simple training code.
I want to integrate some basic visualizations (and include Anthropic's tricks) before making a public post on it, but currently:
Which can be downloaded & interpreted with this notebook
With easy training code for bespoke models here.
Some more suggestions of things to look for:
I'd love to get involved, I'll hit you up on the Discord channel you mention.
Mostly my own writing, except for the 'Better Training Methods' section which was written by @Aidan Ewart.
We made a lot of progress in 4 months working on Sparse Autoencoders, an unsupervised method to scalably find monosemantic features in LLMs, but there's still plenty of work to do. Below I (Logan) give research ideas, as well as my current, half-baked thoughts on how to pursue them.
Find All the Circuits!
Feature Search
There are three ways to find features AFAIK:
1. Which input tokens activate it?
2. What output logits are causally downstream from it?
3. Which intermediate features cause it/are caused by it?
1) Input Tokens
When finding the input tokens that activate a feature, you may run into outlier dimensions that activate highly for most tokens (predominantly the first token), so you need to account for that.
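As a rough sketch, here's one way to pull the top-activating tokens for a feature while ignoring the first-token position. It assumes a trained autoencoder `sae` with an `encode` method over the layer-5 residual stream; that interface, the feature index, and the prompts are all illustrative rather than from our codebase:

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("pythia-70m")
tokens = model.to_tokens([
    "The initials of Logan Riggs are LR.",
    "NASA (National Aeronautics and Space Administration)",
])

_, cache = model.run_with_cache(tokens)
resid = cache["resid_post", 5]                 # [batch, pos, d_model]
feats = sae.encode(resid)                      # [batch, pos, n_features] (assumed interface)

feature_idx = 783                              # made-up feature index
acts = feats[..., feature_idx].clone()         # [batch, pos]
acts[:, 0] = 0.0                               # zero out the first token, which often sits in
                                               # an outlier direction and fires on everything

top_vals, flat_idx = acts.flatten().topk(5)
batch_idx, pos_idx = flat_idx // acts.shape[1], flat_idx % acts.shape[1]
for b, p, v in zip(batch_idx, pos_idx, top_vals):
    print(repr(model.to_string(tokens[b, : p + 1])), "->", round(v.item(), 2))
```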
2) Output Logits
For output logits, if you have a dataset task (e.g. predicting stereotypical gender), you can remove each feature one at a time, and sort by greatest effect. This also extends to substituting features between two distributions and finding the smallest substitution to go from one to the other. For example,
Suppose the token "Jane" activates two features, A & B ([1,1,0]), and "Dave" activates two features, B & C ([0,1,1]). Then we can find the smallest substitution between the two that makes "Jane" complete as " male". If A is the "female" feature, then ablating it (setting it to zero) should make the model put equal probability on male/female. Adding the female feature to "Dave" and subtracting the male direction should make "Dave" complete as " female".[1]
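For concreteness, here is a sketch of the ablation half of this. Everything is assumed/illustrative: `model` is a HookedTransformer, `sae` is a trained residual-stream autoencoder with `encode`/`decode`, the hook point, feature index, and prompt are made up, and " male"/" female" are assumed to be single tokens:

```python
import torch

prompt = "Jane walked in. The doctor said the patient is"
tokens = model.to_tokens(prompt)
male_id = model.to_single_token(" male")       # assumes these are single tokens
female_id = model.to_single_token(" female")
FEATURE_A = 123                                # hypothetical "female" feature

def patch_resid(resid, hook, ablate_feature=None):
    feats = sae.encode(resid)
    if ablate_feature is not None:
        feats[..., ablate_feature] = 0.0       # ablate = set the feature to zero
    return sae.decode(feats)

# Run once with the plain reconstruction, once with feature A ablated.
clean_logits = model.run_with_hooks(
    tokens, fwd_hooks=[("blocks.5.hook_resid_post", patch_resid)])
ablated_logits = model.run_with_hooks(
    tokens, fwd_hooks=[("blocks.5.hook_resid_post",
                        lambda resid, hook: patch_resid(resid, hook, FEATURE_A))])

gap = lambda logits: (logits[0, -1, male_id] - logits[0, -1, female_id]).item()
print("male-female logit gap:", gap(clean_logits), "->", gap(ablated_logits))
```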
3) Intermediate Features
Say we're looking at layer 5, feature 783, which activates ~10 on average across 20 datapoints. We can ablate each feature in layer 4, one at a time, and see which feature makes those 20 datapoints' activations go down the most. This generally resulted in features that make a lot of sense, e.g. the feature "acronyms after (" is affected when you ablate the previous layer's feature for acronyms & the one for "(". Other times it's largely the same feature, since this is the residual stream.[2]
This can be extended to dictionaries trained on the outputs of the MLP & attention layers. Additionally, one could take a weight-based approach going from the residual stream to the MLP layer, which may allow predicting beforehand what a feature is from the weights alone, e.g. "This feature is just 0.5*(acronyms feature) + 2.3*(open-parenthesis feature)."
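Here's a rough sketch of the ablation version of this. Everything named here is assumed (`model`, `tokens`, layer-4 & layer-5 dictionaries `sae4`/`sae5` with `encode`/`decode`, and `n_layer4_features`), and a real implementation should batch the loop rather than re-running the model per feature:

```python
import torch

L5_FEAT = 783   # the layer-5 feature we want to explain (made-up index)

def layer5_feature_act(fwd_hooks):
    """Total activation of the layer-5 feature, with the given hooks applied."""
    store = {}
    def grab(resid, hook):
        store["resid5"] = resid                # returning None leaves the activation unchanged
    model.run_with_hooks(
        tokens, fwd_hooks=list(fwd_hooks) + [("blocks.5.hook_resid_post", grab)])
    return sae5.encode(store["resid5"])[..., L5_FEAT].sum().item()

def reconstruct(resid, hook, ablate=None):
    feats = sae4.encode(resid)
    if ablate is not None:
        feats[..., ablate] = 0.0
    return sae4.decode(feats)

# Baseline uses the un-ablated reconstruction, so the only per-run difference
# is the single layer-4 feature we remove.
baseline = layer5_feature_act([("blocks.4.hook_resid_post", reconstruct)])
drops = {}
for f in range(n_layer4_features):
    hook_fn = lambda resid, hook, f=f: reconstruct(resid, hook, ablate=f)
    drops[f] = baseline - layer5_feature_act([("blocks.4.hook_resid_post", hook_fn)])

print(sorted(drops.items(), key=lambda kv: kv[1], reverse=True)[:10])
```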
Prompt Feature Diff
If I want to understand the effect of few-shot prompts, I can take the 0-shot prompt:
"The initials of Logan Riggs are", and see which features activate for those ~6 tokens. Then add in few-shot prompts before, and see the different features that activate for those ~6 tokens. In general, this can be applied to:
Feature diff between Features in [prompt] & Features in [prompt] given [Pre-prompt]
With examples being:
[few-shot prompts/Chain-of-thought/adversarial prompts/soft prompts][prompt]
(though I don't know how to extend this to appending "Let's think step-by-step")
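Concretely, the diff could be computed as a set difference over which features fire on the prompt's tokens. This is only a sketch: `model`, `sae`, the layer, the activation threshold, and the pre-prompt are all assumed/illustrative, and tokenization differences at the prompt boundary are ignored:

```python
import torch

def active_features(text, n_suffix_tokens, threshold=1.0):
    """Features whose max activation over the last n_suffix_tokens exceeds threshold."""
    tokens = model.to_tokens(text)
    _, cache = model.run_with_cache(tokens)
    feats = sae.encode(cache["resid_post", 5])[0, -n_suffix_tokens:]
    return set(torch.where(feats.max(dim=0).values > threshold)[0].tolist())

prompt = "The initials of Logan Riggs are"
pre_prompt = "The initials of Ada Lovelace are AL. The initials of Alan Turing are AT. "
n = model.to_tokens(prompt).shape[1] - 1          # the prompt's ~6 tokens (minus BOS)

zero_shot = active_features(prompt, n)
few_shot = active_features(pre_prompt + prompt, n)
print("features gained with few-shot:", few_shot - zero_shot)
print("features lost with few-shot:", zero_shot - few_shot)
```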
Useful related work is Causal Scrubbing.
ACDC
Automatic Circuit DisCovery (ACDC) is a simple technique: to find what's relevant for X, remove everything upstream of it one at a time and see what breaks, then apply the same step recursively. We apply a similar technique in our paper, but only on the residual stream. Dictionaries (the decoder part of autoencoders) can also be trained on the outputs of the MLP & attention units. We've in fact done this before and it appears quite interpretable!
We can apply this technique to connect features found in the residual stream to the MLP & attn units. Ideally, we could do a more weight-based method, such as connecting the features learned in the residual stream to the MLP. This may straightforwardly work going from the residual stream to the MLP_out dictionary. If not, it may work with dictionaries trained on the neurons of MLP (ie the activations post non-linearity).
For attention units, I have a half-baked thought of connecting residual stream directions at one layer w/ another layer (or Attn_out) using the QK & OV circuits for a given attention head, but haven't thought very much about this.
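To flesh out that last idea slightly: with TransformerLens weights you can push each residual-stream feature direction through one head's OV circuit and ask which later-layer dictionary directions it lines up with. This is only the OV half (it ignores where the head attends), and `dict4`/`dict5` (decoder matrices of shape [n_features, d_model]) plus the layer/head indices are assumptions:

```python
import torch

layer, head = 5, 3                                       # illustrative head
OV = model.W_V[layer, head] @ model.W_O[layer, head]     # [d_model, d_model]

# What each layer-4 residual feature would write into the stream via this head.
written = dict4 @ OV
written = written / written.norm(dim=-1, keepdim=True)
dict5_normed = dict5 / dict5.norm(dim=-1, keepdim=True)

cos = written @ dict5_normed.T                           # [n_feats_4, n_feats_5]
vals, idx = cos.max(dim=-1)
print("layer-4 feature 0 most strongly writes into layer-5 feature",
      idx[0].item(), "with cosine sim", round(vals[0].item(), 3))
```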
Better Sparse Autoencoders
I think we are quite close to finding all the features for one layer in GPT2 small. Perfecting this will help find more accurate and predictive circuits. This includes driving reconstruction & perplexity-difference down, better training methods, and better, less-Goodhart-able interpretability methods.
Reconstruction, Sparsity, & Perplexity-Diff
Reconstruction loss - How well the autoencoder reconstructs the activations of e.g. Layer 6 of the model.
Sparsity - How many features activate per datapoint on average? (ie the L0 norm of the latent activations)
Perplexity-diff - When you run the LLM on a dataset, you get some prediction loss (which can be converted to perplexity). You can then run the LLM on the same dataset, but replace e.g. Layer 6 w/ the autoencoder's reconstruction, and get a different prediction loss. Subtract. If the difference is 0, then this is strong evidence for the autoencoder being functionally equivalent to the original model.
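All three metrics fit in a few lines with TransformerLens (sketch only: `model`, `sae`, the hook point, and `batch_of_text` are assumed):

```python
import torch
import torch.nn.functional as F

HOOK = "blocks.6.hook_resid_post"                 # illustrative layer
tokens = model.to_tokens(batch_of_text)           # batch_of_text: any list of strings

orig_loss, cache = model.run_with_cache(tokens, return_type="loss")
acts = cache[HOOK]
feats = sae.encode(acts)
recon = sae.decode(feats)

reconstruction_loss = F.mse_loss(recon, acts)              # how well we reconstruct layer 6
sparsity = (feats > 0).float().sum(-1).mean()              # avg features per datapoint (L0)

patched_loss = model.run_with_hooks(
    tokens, return_type="loss",
    fwd_hooks=[(HOOK, lambda resid, hook: recon)])
perplexity_diff = torch.exp(patched_loss) - torch.exp(orig_loss)   # ~0 => functionally equivalent

print(reconstruction_loss.item(), sparsity.item(), perplexity_diff.item())
```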
Typically, we plot unexplained variance (ie reconstruction loss that takes into account variance) vs sparsity, where we would want solutions in the bottom-left corner: perfectly explaining the data w/ as few features/datapoint as possible. We have seen evidence (by hand and GPT-autointerp) that sparser solutions are more monosemantic. Until we have better interp methods, driving down these 3 metrics is a useful proxy.
One effective method, not written in our paper, is directly optimizing for minimal KL-divergence in addition to reconstruction & sparsity. This has driven perplexity-difference down, for similar sparsity, at the cost of some reconstruction loss.
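A sketch of what that extra term could look like in the training loop (assuming a frozen `model`, the autoencoder `sae` being trained, and illustrative hook point & coefficients):

```python
import torch
import torch.nn.functional as F

HOOK = "blocks.6.hook_resid_post"
L1_COEF, KL_COEF = 1e-3, 1.0                      # illustrative coefficients

def loss_on_batch(tokens):
    with torch.no_grad():                         # the LLM is frozen; only `sae` is trained
        clean_logits, cache = model.run_with_cache(tokens)
    acts = cache[HOOK]
    feats = sae.encode(acts)
    recon = sae.decode(feats)

    # Re-run the model with the reconstruction patched in (gradients flow back into `sae`).
    patched_logits = model.run_with_hooks(
        tokens, fwd_hooks=[(HOOK, lambda resid, hook: recon)])

    mse = F.mse_loss(recon, acts)                 # reconstruction
    l1 = feats.abs().sum(-1).mean()               # sparsity
    kl = F.kl_div(patched_logits.log_softmax(-1), # KL between clean & patched next-token dists
                  clean_logits.softmax(-1),
                  reduction="batchmean")
    return mse + L1_COEF * l1 + KL_COEF * kl
```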
Better Training Methods
In their work, Yun et al. use an iterative method (FISTA) to find sparse codes for activations, optimising the dictionary to lower MSE with respect to those codes. We used autoencoders as we think this better reflects what the model might be computing, but it is possible that methods like the one Yun et al. use will result in a better dictionary.
Possible options here include using Yun et al's method, pre-training a dictionary as an autoencoder and further optimising using FISTA, or simply using FISTA with a pre-trained dictionary to reduce MSE.
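For reference, a bare-bones FISTA inference step against a fixed dictionary looks roughly like this. It's a sketch of the standard LASSO-style objective, not Yun et al.'s actual implementation:

```python
import torch

def fista_codes(x, D, l1_coef=1e-3, n_steps=100):
    """Minimise 0.5*||x - z @ D||^2 + l1_coef*||z||_1 over codes z with FISTA.

    x: [batch, d_model] activations; D: [n_features, d_model] fixed dictionary.
    """
    L = torch.linalg.matrix_norm(D, ord=2) ** 2        # Lipschitz constant of the gradient
    z = torch.zeros(x.shape[0], D.shape[0], device=x.device)
    y, t = z.clone(), 1.0
    for _ in range(n_steps):
        grad = (y @ D - x) @ D.T                       # gradient of the quadratic term at y
        v = y - grad / L
        z_new = torch.sign(v) * torch.clamp(v.abs() - l1_coef / L, min=0)  # soft threshold
        t_new = (1 + (1 + 4 * t ** 2) ** 0.5) / 2
        y = z_new + ((t - 1) / t_new) * (z_new - z)    # momentum step
        z, t = z_new, t_new
    return z
```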
We could also find different methods of decomposing activations, using nonlinear autoencoders or VAEs with sparse priors. This is a very interesting line of work which might result in a better understanding of how transformers can represent information nonlinearly. We've faced convergence issues trying to train more powerful decompositional tools (both linear & not), but these can be helped by using softplus activations during training. Also, it seems that the link between sparsity and monosemanticity might break down very quickly as you apply more and more complex methods, and perhaps there is an alternative form of regularisation (instead of sparsity) which would work better for stronger autoencoders.
Better Interp Methods
How do we know we found good features? We can't just say 0 reconstruction loss & 0 perplexity-diff, because the identity function trivially achieves that (you just get the original model back)! That's why we have sparsity, but is 20 features/datapoint better than 60 features/datapoint? How does this scale as you scale model size or layers?
It'd be good to have a clean, objective measure of interpretability. You could do a subjective measure of 10 randomly selected features, but that's noisy!
I have some preliminary work on making a monosemanticity measure I can share shortly, but no good results yet!
Our previous proxies for "right hyperparams for feature goodness" have been from toy models, specifically MMCS (mean max cosine similarity) ie how similar features between two dictionaries are (if two dictionaries learned similar features, then these are "realer" features...maybe), and dead features. Check the toy model results for more details, both Lee's original work & update and our open sourced replication.
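MMCS itself is a one-liner (dictionaries here are decoder matrices of shape [n_features, d_model]; this is just the standard definition, not code from our repo):

```python
import torch

def mmcs(dict_a, dict_b):
    """Mean max cosine similarity between two learned dictionaries.

    For each feature in dict_a, find its closest match (by cosine sim) in dict_b,
    then average. Values near 1 suggest the two runs learned similar features.
    """
    a = dict_a / dict_a.norm(dim=-1, keepdim=True)
    b = dict_b / dict_b.norm(dim=-1, keepdim=True)
    cos = a @ b.T                      # [n_a, n_b] pairwise cosine similarities
    return cos.max(dim=-1).values.mean()
```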
Come Work With Us
We are currently discussing research in the #unsupervised-interp channel (under Interpretability) in the EleutherAI Discord server. If you're a researcher and have directions you'd like to apply sparse autoencoders to, feel free to message me on Discord (loganriggs) or LW & we can chat!
[1] Now that I write it though, I think you could just find the features that make both distributions "neutral", and just add those directions.
[2] One can verify this by checking the cosine similarity between two features at different layers. If they have high cosine sim, then they're pointing in very similar directions and will be decoded by future layers/unembedded in the same way.