Hey Joseph (and coauthors),
Your directions are really fantastic. I hope you don't mind, but I generated the activation data for the first 3000+ directions for each of the 12 layers and uploaded your directions to Neuronpedia:
https://www.neuronpedia.org/gpt2-small/res-jb
Your directions are also linked on the home page and the model page.
They're also accessible by layer (sorted by top activation), eg layer 6: https://neuronpedia.org/gpt2-small/6-res-jb
I added the "Anthropic dashboard" to Neuronpedia for your dataset.
Explanations, comments, and autointerp scoring are also working - anyone can do this:
I plan to do some autointerp explaining on a batch of these directions too.
Btw - your directions are so good that it's easy to find super interesting stuff. 5-RES-JB:5 is about astronomy:
I'm aware that you're going to do some library updates to get even better directions, and I'm excited for that - will re-generate/upload all layers after the new changes come in.
Things that I'm still working on and hope to get working in the next few days:
Again, your directions look fantastic - congrats. I hope this is useful/interesting for you and anyone trying to browse/explain them. Also, I didn't know how to provide a citation/reference to you (and your team?) so I just used RES-JB = Residuals by Joseph Bloom and included links to all relevant sources on your directions page.
If there's anything you'd like me to modify about this, or any feature you'd like me to add to make it better, please do not hesitate to let me know.
I tried replicating your statistics using my own evaluation code (in evaluation.py here). I pseudo-randomly chose layer 1 and layer 7. Sadly, my results look rather different from yours:
Layer | MSE Loss | % Variance Explained | L1 | L0 | % Alive | CE Reconstructed |
---|---|---|---|---|---|---|
1 | 0.11 | 92 | 44 | 17.5 | 54 | 5.95 |
7 | 1.1 | 82 | 137 | 65.4 | 95 | 4.29 |
Places where our metrics agree: L1 and L0.
Places where our metrics disagree, but probably for a relatively benign reason:
Our metrics disagree strongly on CE reconstructed, and this is a bit alarming. It means that either you have a bug which significantly underestimates reconstructed CE loss, or I have a bug which significantly overestimates it. I think I'm 50/50 on which it is. Note that according to my stats, your MSE loss is kinda bad, which would suggest that you should also have high CE reconstructed (especially when working with residual stream dictionaries! (in contrast to e.g. MLP dictionaries which are much more forgiving)).
Spitballing a possible cause: when computing CE loss, did you exclude padding tokens? If not, then it's possible that many of the tokens on which you're computing CE are padding tokens, which is artificially making your CE look extremely good.
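To illustrate what I mean (a minimal sketch, not the eval code either of us is actually running; `pad_token_id` is whatever your tokenizer uses):

```python
import torch.nn.functional as F

def ce_loss_excluding_padding(logits, tokens, pad_token_id):
    """Next-token CE loss, ignoring positions whose target is a padding token.

    logits: (batch, seq_len, vocab), tokens: (batch, seq_len).
    """
    logits = logits[:, :-1, :]   # predict token t+1 from position t
    targets = tokens[:, 1:]
    return F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        targets.reshape(-1),
        ignore_index=pad_token_id,  # padding targets contribute nothing to the average
        reduction="mean",
    )
```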
Here is my code. You'll need to `pip install nnsight` before running it. Many thanks to Caden Juang for implementing the UnifiedTransformer functionality in nnsight, which is a crazy Frankenstein marriage of nnsight and transformer_lens; it would have been very hard for me to attempt this replication without this feature.
Oh no. I'll look into this and get back to you shortly. One obvious candidate is that I was reporting CE for some batch at the end of training that was very small, so the statistics likely had high variance and the last datapoint may have been fairly low. In retrospect I should have explicitly recalculated this again post-training. However, I'll take a deeper dive now to see what's up.
I've run some of the SAEs through more thorough eval code this morning (getting variance explained with the centring and calculating mean CE losses with more batches). As far as I can tell the CE loss is not that high at all and the MSE loss is quite low. I'm wondering whether you might be using the wrong hooks? These are resid_pre, so layer 0 is just the embeddings and layer 1 is after the first transformer block and so on. One other possibility is that you are using a different dataset? I trained these SAEs on OpenWebText. I don't much padding at all, that might be a big difference too. I'm curious to get to the bottom of this.
One sanity check I've done is just sampling from the model when using the SAE to reconstruct activations and it seems to be about as good, which I think rules out CE loss in the ranges you quote above.
For percent alive neurons, a batch size of 8192 tokens would be far too few to estimate dead neurons (since many neurons have a feature sparsity < 10**-3).
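As a rough back-of-envelope illustration (not my eval code): if a feature fires with probability p per token, the chance it never fires in a sample of n tokens is (1 - p)^n, so plausibly-alive but rare features will look dead in small samples:

```python
p = 1e-4  # a plausibly-alive feature with sparsity well below 10**-3
for n in [8_192, 100_000, 1_000_000]:
    print(f"n={n:>9,}  P(never fires) = {(1 - p) ** n:.2e}")
# At n=8,192 there is roughly a 44% chance such a feature never fires at all,
# so it would be counted as dead; by n=1,000,000 that chance is negligible.
```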
You're absolutely right about missing the centring in percent variance explained. I've estimated variance explained again for the same layers and get very similar results to what I had originally. I'll make some updates to my code to produce CE score metrics that have less variance in the future, at the cost of slightly more train time.
If we don't find a simple answer I'm happy to run some more experiments but I'd guess an 80% probability that there's a simple bug which would explain the difference in what you get. Rank order of most likely: Using the wrong activations, using datapoints with lots of padding, using a different dataset (I tried the pile and it wasn't that bad either).
In the notebook I link in my original comment, I check that the activations I get out of nnsight are the same as the activations that come from transformer_lens. Together with the fact that our sparsity statistics broadly align, I'm guessing that the issue isn't that I'm extracting different activations than you are.
Repeating my replication attempt with data from OpenWebText, I get this:
Layer | MSE Loss | % Variance Explained | L1 | L0 | % Alive | CE Reconstructed |
---|---|---|---|---|---|---|
1 | 0.069 | 95 | 40 | 15 | 46 | 6.45 |
7 | 0.81 | 86 | 125 | 59.2 | 96 | 4.38 |
Broadly speaking, same story as above, except that the MSE losses look better (still not great), and that the CE reconstructed looks very bad for layer 1.
I don't much padding at all, that might be a big difference too.
Seems like there was a typo here -- what do you mean?
Logan Riggs reports that he tried to replicate your results and got something more similar to you. I think Logan is making decisions about padding and tokenization more like the decisions you make, so it's possible that the difference is down to something around padding and tokenization.
Possible next steps:
Another sanity check: when you compute CE loss using the same code path that you use for autoencoder-reconstructed activations, but plug the original activations back in instead of the autoencoder's output, do you get the same answer (~3.3) as when you evaluate CE loss normally?
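For concreteness, here's a minimal sketch of that check in TransformerLens-style code (the hook point and model are illustrative assumptions, not your actual eval setup):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens("The quick brown fox jumps over the lazy dog.")
hook_name = "blocks.1.hook_resid_pre"  # whichever hook point the SAE was trained on

def identity_substitution(activations, hook):
    # Where the eval would normally return sae.decode(sae.encode(activations)),
    # just hand the original activations straight back.
    return activations

loss_normal = model(tokens, return_type="loss")
loss_spliced = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[(hook_name, identity_substitution)],
)
print(loss_normal.item(), loss_spliced.item())  # these two should agree
```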
I'd be pretty interested in knowing if my SAEs seem good now based on your evals :) Hopefully this was the only issue.
My SAEs also have a tied decoder bias which is subtracted from the original activations. Here's the relevant code in dictionary.py:

```python
def encode(self, x):
    return nn.ReLU()(self.encoder(x - self.bias))

def decode(self, f):
    return self.decoder(f) + self.bias

def forward(self, x, output_features=False, ghost_mask=None):
    [...]
    f = self.encode(x)
    x_hat = self.decode(f)
    [...]
    return x_hat
```
Note that I checked that our SAEs have the same input-output behavior in my linked colab notebook. I think I'm a bit confused why subtracting off the decoder bias had to be done explicitly in your code -- maybe you used `dictionary.encoder` and `dictionary.decoder` instead of `dictionary.encode` and `dictionary.decode`? (Sorry, I know this is confusing.) ETA: Simple things I tried based on the hypothesis "one of us needs to shift our inputs by +/- the decoder bias" only made things worse, so I'm pretty sure that you had just initially converted my dictionaries into your infrastructure in a way that messed up the initial decoder bias, and therefore had to hand-correct it.
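(For what it's worth, the equivalence check amounts to something like the sketch below, with each SAE passed in as a reconstruction callable since our wrapper classes differ; this is illustrative rather than the notebook's exact code.)

```python
import torch

def same_io_behavior(reconstruct_a, reconstruct_b, d_model=768, n_samples=256, atol=1e-4):
    """Check that two SAE wrappings reconstruct identically on the same inputs.

    reconstruct_a / reconstruct_b: callables mapping (n_samples, d_model) activations
    to reconstructions, e.g. lambda x: sae.decode(sae.encode(x)).
    """
    x = torch.randn(n_samples, d_model)
    return torch.allclose(reconstruct_a(x), reconstruct_b(x), atol=atol)
```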
I note that the MSE Loss you reported for my dictionary actually is noticeably better than any of the MSE losses I reported for my residual stream dictionaries! Which layer was this? Seems like something to dig into.
Ahhh I see. Sorry, I was way too hasty to jump at this as the explanation. Your code does use the tied decoder bias (and yeah, it was a little harder to read because of how your module is structured). It is strange how assuming that bug seemed to help on some of the SAEs, but I ran my evals over all your residual stream SAEs and it only worked for some and not others, so it certainly didn't seem like a good explanation after I'd run it on more than one.
I've been talking to Logan Riggs, who says he was able to load in my SAEs and saw fairly similar reconstruction performance to me, but that outside of the context length of 128 tokens, performance markedly decreases. He also mentioned your eval code uses very long prompts whereas mine limits to 128 tokens, so this may be the main cause of the difference. Logan mentioned you had discussed this with him, so I'm guessing you've got more details on this than I have? I'll build some evals specifically to look at this in the future, I think.
Scientifically, I am fairly surprised about the token length effect and want to try training on activations from much longer context sizes now. I have noticed (anecdotally) that the number of features firing sometimes increases over the prompt, so an SAE trained on activations from shorter prompts is plausibly going to have a much easier time balancing reconstruction and sparsity, which might explain the generally lower MSE / better reconstruction. Though we shouldn't really compare between models and with different levels of sparsity, as we're likely to be at different locations on the Pareto frontier.
One final note is that I'm excited to see whether performance on the first 128 tokens actually improves in SAEs trained on activations from > 128 token forward passes (since maybe the SAE becomes better in general).
Yep, as you say, @Logan Riggs figured out what's going on here: you evaluated your reconstruction loss on contexts of length 128, whereas I evaluated on contexts of arbitrary length. When I restrict to context length 128, I'm able to replicate your results.
Here's Logan's plot for one of your dictionaries (not sure which)
and here's my replication of Logan's plot for your layer 1 dictionary
Interestingly, this does not happen for my dictionaries! Here's the same plot but for my layer 1 residual stream output dictionary for pythia-70m-deduped
(Note that all three plots have a different y-axis scale.)
Why the difference? I'm not really sure. Two guesses:
In terms of standardization of which metrics to report, I'm torn. On one hand, for the task your dictionaries were trained on (reconstructing activations taken from length-128 sequences), they're performing well, and this should be reflected in the metrics. On the other hand, people should be aware that if they just plug your autoencoders into GPT2-small and start doing inference on inputs found in the wild, things will go off the rails pretty quickly. Maybe the answer is that CE diff should be reported both for sequences of the same length used in training and for arbitrary-length sequences?
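If it helps, here's roughly what reporting both numbers could look like (a sketch with TransformerLens-style hooks; a real SAE's reconstruction would replace the identity placeholder below):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
hook_name = "blocks.1.hook_resid_pre"

def spliced_ce(tokens, reconstruct):
    """CE loss with the residual stream at hook_name replaced by reconstruct(acts)."""
    def replace_resid(activations, hook):
        return reconstruct(activations)
    return model.run_with_hooks(
        tokens, return_type="loss", fwd_hooks=[(hook_name, replace_resid)]
    ).item()

# For a real SAE this would be lambda x: sae.decode(sae.encode(x));
# the identity just keeps the sketch self-contained.
reconstruct = lambda x: x

text = "some long document " * 200      # stand-in for an OpenWebText document
long_tokens = model.to_tokens(text)[:, :512]
short_tokens = long_tokens[:, :128]     # the training context length

print("CE, ctx 128:     ", spliced_ce(short_tokens, reconstruct))
print("CE, full context:", spliced_ce(long_tokens, reconstruct))
```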
Why do you scale your MSE by `1/(x_centred**2).sum(dim=-1, keepdim=True).sqrt()`? In particular, I'm confused about why you have the square root. Shouldn't it just be `1/(x_centred**2).sum(dim=-1, keepdim=True)`?
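To make the question concrete, here's how I'd write the two candidates (the centring axis and the final reduction are my guesses, since only the scaling factor appears in your code):

```python
import torch

def mse_scaled_by_norm(x, x_hat):
    # your version: divide by the norm of the centred input (with the .sqrt())
    x_centred = x - x.mean(dim=0, keepdim=True)
    norm = (x_centred ** 2).sum(dim=-1, keepdim=True).sqrt()
    return (((x_hat - x) ** 2) / norm).mean()

def mse_scaled_by_squared_norm(x, x_hat):
    # what I'd have expected: divide by the squared norm, i.e. a per-token
    # "fraction of variance unexplained"-style quantity
    x_centred = x - x.mean(dim=0, keepdim=True)
    sq_norm = (x_centred ** 2).sum(dim=-1, keepdim=True)
    return (((x_hat - x) ** 2) / sq_norm).mean()
```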
Browse these SAE Features on Neuronpedia!
Update 1: Since we posted this last night, someone pointed out that our implementation of ghost grads has a non-trivial error (which makes the results a-priori quite surprising). We computed the ghost grad forward pass using Exp(Relu(W_enc(x)[dead_neuron_mask])) rather than Exp((W_enc(x)[dead_neuron_mask])). I'm running some ablation experiments now to get to the bottom of this.
Update 2: I've since investigated this further and run some ablation studies with the following results. Ghost grads weren't working as intended due to the Exp(Relu(x)) bug, but the resulting SAEs were still quite good (later layers had few dead neurons simply because, when we dropped the number of features, you get fewer dead neurons). I've found that with a correct ghost grads implementation you can get fewer dead neurons, and will update the library shortly. (I will make edits to the rest of this post to reflect my current views.) Sorry for the confusion.
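For readers following along, here is a hedged sketch of the ghost grads forward pass being discussed. Only the Exp vs Exp(Relu) detail comes from the updates above; the rescaling and auxiliary loss below loosely follow Anthropic's published description and may differ from what the library actually does:

```python
import torch

def ghost_grads_loss(x, x_hat, pre_acts, W_dec, dead_neuron_mask):
    """Sketch of the dead-neuron ("ghost") forward pass discussed above.

    x:        (batch, d_model)      original activations
    x_hat:    (batch, d_model)      SAE reconstruction from live features
    pre_acts: (batch, n_features)   encoder pre-activations (before the ReLU)
    W_dec:    (n_features, d_model) decoder weights
    dead_neuron_mask: (n_features,) boolean mask of dead features
    """
    residual = x - x_hat                                    # what the live features miss

    # Buggy version from Update 1:  torch.exp(torch.relu(pre_acts[:, dead_neuron_mask]))
    ghost_acts = torch.exp(pre_acts[:, dead_neuron_mask])   # corrected: no ReLU inside the exp
    ghost_out = ghost_acts @ W_dec[dead_neuron_mask]        # reconstruct using only dead features

    # Rescale the ghost reconstruction towards the residual's norm so gradients
    # reach dead features at a sensible scale.
    scale = residual.norm(dim=-1, keepdim=True) / (2 * ghost_out.norm(dim=-1, keepdim=True) + 1e-8)
    ghost_out = ghost_out * scale.detach()

    return ((ghost_out - residual.detach()) ** 2).mean()
```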
This work was produced as part of the ML Alignment & Theory Scholars Program - Winter 2023-24 Cohort, under mentorship from Neel Nanda and Arthur Conmy. Funding for this work was provided by the Manifund Regranting Program and donors as well as LightSpeed Grants.
This is intended to be a fairly informal post sharing a set of Sparse Autoencoders trained on the residual stream of GPT2-small which achieve fairly good reconstruction performance and contain fairly sparse / interpretable features. More importantly, advice from Anthropic and community members has enabled us to train these far more efficiently / faster than before. The specific methods that were most useful were: ghost gradients, learning rate warmup, and initializing the decoder bias with the geometric median. We discuss each of these in more detail below.
5 Minute Summary
We’re publishing a set of 12 Sparse AutoEncoders for the GPT2 Small residual stream.
Readers can access the Sparse Autoencoder weights in this HuggingFace Repo. Training code and code for loading the weights / model and data loaders can be found in this Github Repository. Training curves and feature dashboards can also be found in this wandb report. Users can download all 25k feature dashboards generated for the layer 2 and layer 10 SAEs and the first 5,000 of the layer 5 SAE features here (note the left-hand column of the dashboards should currently be ignored).
| Layer | Variance Explained | L1 loss | L0* | % Alive Features | Reconstruction CE Log Loss |
|---|---|---|---|---|---|
| 0 | 99.15% | 4.58 | 12.24 | 80.0% | 3.32 |
| 1 | 98.37% | 41.04 | 14.68 | 83.4% | 3.33 |
| 2 | 98.07% | 51.88 | 18.80 | 80.0% | 3.37 |
| 3 | 96.97% | 74.96 | 25.75 | 86.3% | 3.48 |
| 4 | 95.77% | 90.23 | 33.14 | 97.7% | 3.44 |
| 5 | 94.90% | 108.59 | 43.61 | 99.7% | 3.45 |
| 6 | 93.90% | 136.07 | 49.68 | 100% | 3.44 |
| 7 | 93.08% | 138.05 | 57.29 | 100% | 3.45 |
| 8 | 92.57% | 167.35 | 65.47 | 100% | 3.45 |
| 9 | 92.05% | 198.42 | 71.10 | 100% | 3.45 |
| 10 | 91.12% | 215.11 | 53.79 | 100% | 3.52 |
| 11 | 93.30% | 270.13 | 59.16 | 100% | 3.57 |
| Original Model | | | | | 3.3 |
Summary Statistics for GPT2 Small Residual Stream SAEs. *L0 = Average number of features firing per token.
Training SAEs that we were happy with used to take much longer than it does now. Last week, it took me 20 hours to train a 50k-feature SAE on 1 billion tokens, and over the weekend it took 3 hours for us to train a 25k-feature SAE on 300M tokens with similar variance explained, L0 and CE loss recovered.
We attribute the improvement to having implemented various pieces of advice that have made our lives a lot easier:
While we haven’t tested our code extensively since implementing these improvements, we suspect that hyperparameter tuning may be easier in the future since these method improvements make the process generally less sensitive.
To demonstrate the interpretability of these SAEs, we share screenshots of feature dashboards we produced using a reproduction of Anthropic’s dashboard developed by Callum McDougall.
Finally, we end by discussing how readers can access these SAEs, some experiments which they could perform to upskill with SAEs, and possible research directions to pursue.
Introduction
What are Sparse AutoEncoders and why should we care about them?
Sparse autoencoders (SAEs) are an unsupervised technique for taking a model's activations and decomposing them into interpretable feature vectors. We highly recommend this tutorial on SAEs for those interested. Recent papers on the topic can be found here and here.
I’m particularly excited about Sparse Autoencoders for two reasons:
General Advice for Training SAEs
Why can training Sparse AutoEncoders be difficult?
Sparse autoencoders are an unsupervised method which attempts to trade off reconstruction accuracy against interpretability, which we achieve by inducing activation sparsity. Since we don’t have good metrics for interpretability / reconstruction quality, it’s hard to know when we are actually optimizing what we care about. On top of this, we’re trying to pick a good point on the Pareto frontier between interpretability and reconstruction quality, which is a hard thing to assess well.
The main objective is to have your Sparse Autoencoder learn a population of sparse features (which are likely to be interpretable) without having some dense features (features which activate all the time and are likely uninterpretable) or too many dead features (features which never fire). As discussed in the 5 minute summary, we went from training GPT2-small residual stream SAEs in 12+ hours to ~3 hours (so 4x faster / cheaper). Though the L0 and CE loss were somewhat similar, the feature density histograms also suggested we avoided dead / dense features far more effectively after using ghost gradients.
Let’s dig into the challenges associated with dead / dense features a bit more:
Which tricks help the most?
In terms of solutions, Anthropic published useful advice (especially ghost gradients) and the research community is building consensus on how to train SAEs well (what your loss curves and final statistics should look like for example). I found Arthur Conmy’s post very useful.
The top 3 changes that I made which led to my ability to train these SAEs cheaply/quickly were:
Sparse AutoEncoders for the GPT2 Residual Stream
Why GPT2 small? Why the residual stream?
GPT2 small has been extensively studied by the mechanistic interpretability community and whilst not an incredibly performant model, it certainly has some kind of “prototypical object of study” property. We chose the residual stream because this enables us to analyze “the total sum of previous output” in a manner not dissimilar to the logit lens approach. This may be useful for understanding how features are constructed from earlier features as well as studying how the distribution of features over time changes in a model.
Architecture and Hyperparameters
We trained 12 Sparse Autoencoders on the Residual Stream of GPT2-small.
Were I to be training SAEs on another model or part of the same model, I wouldn’t change any of the architectural choices (except maybe expansion factor). Other parameters like learning rate, l1 coefficient, number of tokens to train on all likely need to be tuned in practice. It also seems plausible we’ll continue to see methodological advances in the future which I’m excited about!
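To make the setup concrete, here is a minimal sketch of the kind of SAE this post describes (decoder bias subtracted from the input before encoding, an expansion factor over the residual stream width, and an MSE + L1 objective). The hyperparameter values below are placeholders, not the values used for the released SAEs:

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Minimal sketch of the kind of SAE described in this post."""

    def __init__(self, d_model: int = 768, expansion_factor: int = 32):
        super().__init__()
        d_features = d_model * expansion_factor
        self.W_enc = nn.Parameter(torch.randn(d_model, d_features) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_features))
        self.W_dec = nn.Parameter(torch.randn(d_features, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_model))  # e.g. initialized at the geometric median

    def encode(self, x):
        # Decoder bias is subtracted from the input before encoding,
        # as in the dictionary code quoted in the comments.
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, f):
        return f @ self.W_dec + self.b_dec

    def forward(self, x, l1_coefficient: float = 1e-3):
        f = self.encode(x)
        x_hat = self.decode(f)
        mse_loss = ((x_hat - x) ** 2).mean()
        l1_loss = l1_coefficient * f.abs().sum(dim=-1).mean()
        return x_hat, mse_loss + l1_loss
```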
What do we think about when choosing hyper-parameters / evaluating SAEs?
We train against:
However, what we actually care about is whether we reconstructed information required for model performance and how useful the features are for interpretability. Better proxies for these desiderata are:
Though L0 is a pretty good proxy for interpretability, in practice the feature density histogram (the distribution of how frequent features are) turns out to be one of the most important things we need to get right when tuning hyperparameters.
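Since the feature density histogram is such an important diagnostic, here is a rough sketch of how one might compute it (the `sae.encode` interface and the log-sparsity floor are assumptions, not our exact code):

```python
import torch

@torch.no_grad()
def log_feature_density(sae, activation_batches):
    """Log10 firing frequency per feature over a stream of activation batches.

    sae.encode is assumed to return (n_tokens, n_features) feature activations.
    """
    fired = None
    n_tokens = 0
    for acts in activation_batches:          # each: (n_tokens_in_batch, d_model)
        feats = sae.encode(acts)
        counts = (feats > 0).float().sum(dim=0)
        fired = counts if fired is None else fired + counts
        n_tokens += acts.shape[0]
    density = fired / n_tokens
    return torch.log10(density + 1e-10)      # floor so dead features show up at -10

# A healthy histogram has few features near 0 (dense) and few stuck at the floor (dead),
# e.g. plt.hist(log_feature_density(sae, batches).cpu(), bins=50).
```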
How good are these Sparse AutoEncoders?
At a glance, the summary metrics seem fairly good. I’ll make a number of comments:
Georg Lang pointed out to me that the L2 loss grows quadratically with the norm which increases with layer whilst the L1 coefficient grows linearly. This means that since I didn’t vary the L1 coefficient when training these, we’re effectively pushing less hard for sparsity in later layers (which would explain the trend in L0 / L1 and the feature density histograms). Interestingly, the variance explained still gets worse with layers.
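One heuristic way to write out Georg's point:

$$
\mathcal{L}(x) \;=\; \underbrace{\lVert x - \hat{x}(x)\rVert_2^2}_{\text{grows like } \lVert x\rVert^2} \;+\; \lambda \underbrace{\lVert f(x)\rVert_1}_{\text{grows roughly like } \lVert x\rVert}
\quad\Rightarrow\quad
\frac{\text{sparsity pressure}}{\text{reconstruction pressure}} \;\sim\; \frac{\lambda}{\lVert x\rVert}.
$$

So with a fixed L1 coefficient, the larger residual stream norms at later layers apply relatively less pressure towards sparsity.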
| Layer | Variance Explained | L1 loss | L0* | % Alive Features | Reconstruction CE Log Loss |
|---|---|---|---|---|---|
| 0 | 99.15% | 4.58 | 12.24 | 80.0% | 3.32 |
| 1 | 98.37% | 41.04 | 14.68 | 83.4% | 3.33 |
| 2 | 98.07% | 51.88 | 18.80 | 80.0% | 3.37 |
| 3 | 96.97% | 74.96 | 25.75 | 86.3% | 3.48 |
| 4 | 95.77% | 90.23 | 33.14 | 97.7% | 3.44 |
| 5 | 94.90% | 108.59 | 43.61 | 99.7% | 3.45 |
| 6 | 93.90% | 136.07 | 49.68 | 100% | 3.44 |
| 7 | 93.08% | 138.05 | 57.29 | 100% | 3.45 |
| 8 | 92.57% | 167.35 | 65.47 | 100% | 3.45 |
| 9 | 92.05% | 198.42 | 71.10 | 100% | 3.45 |
| 10 | 91.12% | 215.11 | 53.79 | 100% | 3.52 |
| 11 | 93.30% | 270.13 | 59.16 | 100% | 3.57 |
| Original Model | | | | | 3.3 |
Summary Statistics for GPT2 Small Residual Stream SAEs. *L0 = Average number of features firing per token.
Log Feature Sparsity Histograms for each of the residual stream SAEs of GPT2-small
How interpretable are the features in each layer?
Feature interpretability is far from a settled science, but feature dashboards sure do automate a huge chunk of the work. We use a reproduction of Anthropic’s dashboard developed by Callum McDougall.
We’re still working on making our dashboard generating code more efficient (to keep up with the improvements in our ability to train sparse autoencoders!). In the meantime, we’ve collected some anecdotal examples of features at layers 2, 5 and 10 to give examples of the kinds of features we can detect in GPT2 small.
Though we share some features below, you can look through more features at the bottom of the dashboard here.
Layer 2: The President Feature
For example, below we show a “President” feature which promotes the first names of presidents.
Layer 2: The “c” subword token feature
Another example of a fairly typical layer 2 feature is this feature, which fires on the subword token "c" produced when tokenization splits a word that starts with "c".
Layer 5: The what you are saying thanks OR sorry for feature
For example, this feature appears to fire for short stretches of text involving thanks or apologies.
Layer 5: The Force is strong with this Feature.
Though there are plenty of features that seem interesting about layer 5 SAE, some are just way stronger in the force than others.
Layer 10: Violence / Conflict Feature.
How to get involved
I want to look at more dashboards!
You can download all 25k feature dashboards for layers 2 and 10, and the first 5k for layer 5, here.
How can I download and analyze these SAEs?
For those who would like to play around with these sparse autoencoders, my codebase is pretty crazy right now but you can mostly ignore it once you have the SAE. The codebase has:
In order to speed up my own analysis, I created a “SessionLoader” class which takes a path to the saved SAE and then instantiates the model it was trained on, the sparse autoencoder, and the activations_loader (which gets your tokens/activations). Between these three artifacts, you can start analyzing an SAE very quickly post-training.
After cloning the repo and installing the requirements.txt, users can simply run the following commands:
What kinds of analysis can we do with Residual Stream Sparse Autoencoders?
Without getting into entire research directions, it’s worth discussing briefly the kinds of experiments that can be run with Sparse Autoencoders. These are projects that might enable you to get a taste of working with SAEs and decide if you’re excited to work with them more seriously.
Some example upskilling projects could be:
For more ideas, posts published by researchers currently working on SAEs can be a great source of inspiration:
What research directions could you pursue with SAEs?
It’s likely that there are a bunch of low hanging fruit with SAEs right now. Logan Riggs posted a bunch of ideas here and Anthropic list some directions for future work which are worth reading here.
Two directions I’m excited by are:
Appendix
Thanks
I’d like to thank Neel Nanda and Arthur Conmy for their support and feedback while I’ve been working on this and other SAE related work. I also appreciate feedback and support from members of the Mechanistic Interpretability Stream in the MATS 5.0 cohort, especially Ben Wu and Andy Arditi.
I’d also like to thank the interpretability team at Anthropic for continually sharing their advice on how to train sparse autoencoders, and to Callum McDougall for his awesome SAE visualizer (replication of Anthropic’s dashboard).
Funding Note
This work was produced as part of the ML Alignment & Theory Scholars Program - Winter 2023-24 Cohort, with support from Neel Nanda and Arthur Conmy. Funding for this work was provided by the Manifund Regranting Program and donors as well as LightSpeed Grants.
Related Work
How to Cite