All of Joseph Bloom's Comments + Replies

Ahhh I see. Sorry, I was way too hasty to jump at this as the explanation. Your code does use the tied decoder bias (though it was a little harder to read because of how your module is structured). It is strange how assuming that bug seemed to help on some of the SAEs, but I ran my evals over all your residual stream SAEs and it only worked for some and not others, so it certainly didn't seem like a good explanation after I'd run it on more than one.

I've been talking to Logan Riggs, who says he was able to load in my SAEs and saw fairly similar reconstru...

Sam Marks
Yep, as you say, @Logan Riggs figured out what's going on here: you evaluated your reconstruction loss on contexts of length 128, whereas I evaluated on contexts of arbitrary length. When I restrict to context length 128, I'm able to replicate your results. Here's Logan's plot for one of your dictionaries (not sure which), and here's my replication of Logan's plot for your layer 1 dictionary. Interestingly, this does not happen for my dictionaries! Here's the same plot but for my layer 1 residual stream output dictionary for pythia-70m-deduped. (Note that all three plots have a different y-axis scale.)

Why the difference? I'm not really sure. Two guesses:

  1. The model: GPT2-small uses learned positional embeddings, whereas Pythia models use rotary embeddings.
  2. The training: I train my autoencoders on variable-length sequences up to length 128; left padding is used to pad shorter sequences up to length 128. Maybe this makes a difference somehow.

In terms of standardization of which metrics to report, I'm torn. On one hand, for the task your dictionaries were trained on (reconstructing activations taken from length-128 sequences), they're performing well, and this should be reflected in the metrics. On the other hand, people should be aware that if they just plug your autoencoders into GPT2-small and start doing inference on inputs found in the wild, things will go off the rails pretty quickly. Maybe the answer is that CE diff should be reported both for sequences of the same length used in training and for arbitrary-length sequences?
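For concreteness, here's a rough sketch of the per-position comparison behind those plots, written in TransformerLens terms rather than my actual nnsight eval code (the hook name and the SAE's encode/decode interface are placeholder assumptions):

```python
def per_position_mse(model, sae, tokens, hook_name="blocks.1.hook_resid_pre"):
    # Cache activations at one hook, reconstruct them with the SAE, and
    # take the MSE separately at each context position.
    _, cache = model.run_with_cache(tokens, names_filter=hook_name)
    acts = cache[hook_name]                        # [batch, pos, d_model]
    recon = sae.decode(sae.encode(acts))           # reconstruct every position
    return ((recon - acts) ** 2).mean(dim=(0, 2))  # one MSE value per position

# If the SAE was only trained on length-128 contexts, this curve may climb
# sharply for positions past 128, matching the plots described above.
```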
  • MSE Losses were in the WandB report (screenshot below).
  • I've loaded in your weights for one SAE and at first I get very bad performance (high L0, high L1, and bad MSE Loss).
  • It turns out that this is because my forward pass uses a tied decoder bias which is subtracted from the initial activations and added as part of the decoder forward pass. AFAICT, you don't do this. 
  • To verify this, I added the decoder bias to the activations of your SAE prior to running a forward pass with my code (to effectively remove the decoder bias subtraction from my method ...
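A hypothetical one-liner version of that correction (with my_sae and b_dec as illustrative names, not my exact code):

```python
# If the forward pass computes encode(x - b_dec), pre-adding b_dec to the
# loaded activations cancels the subtraction that the other codebase lacks.
x_hat = my_sae(activations + my_sae.b_dec)
```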
Sam Marks
My SAEs also have a tied decoder bias which is subtracted from the original activations. Here's the relevant code in dictionary.py:

```python
def encode(self, x):
    return nn.ReLU()(self.encoder(x - self.bias))

def decode(self, f):
    return self.decoder(f) + self.bias

def forward(self, x, output_features=False, ghost_mask=None):
    [...]
    f = self.encode(x)
    x_hat = self.decode(f)
    [...]
    return x_hat
```

Note that I checked that our SAEs have the same input-output behavior in my linked colab notebook. I think I'm a bit confused why subtracting off the decoder bias had to be done explicitly in your code -- maybe you used dictionary.encoder and dictionary.decoder instead of dictionary.encode and dictionary.decode? (Sorry, I know this is confusing.)

ETA: Simple things I tried based on the hypothesis "one of us needs to shift our inputs by +/- the decoder bias" only made things worse, so I'm pretty sure that you had just initially converted my dictionaries into your infrastructure in a way that messed up the initial decoder bias, and therefore had to hand-correct it.

I note that the MSE Loss you reported for my dictionary actually is noticeably better than any of the MSE losses I reported for my residual stream dictionaries! Which layer was this? Seems like something to dig into.
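The input-output check was of roughly this shape -- a sketch, not the notebook's exact cells (sae_a and sae_b are illustrative names for the two implementations; 512 is pythia-70m's residual stream width):

```python
import torch

d_model = 512                    # pythia-70m's residual stream width
x = torch.randn(64, d_model)     # random stand-in activations
# sae_a, sae_b: the two SAE implementations under comparison
torch.testing.assert_close(sae_a(x), sae_b(x), atol=1e-5, rtol=1e-4)
```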

I've run some of the SAEs through more thorough eval code this morning (getting variance explained with the centring and calculating mean CE losses with more batches). As far as I can tell, the CE loss is not that high at all and the MSE loss is quite low. I'm wondering whether you might be using the wrong hooks? These are resid_pre, so layer 0 is just the embeddings, layer 1 is after the first transformer block, and so on. One other possibility is that you are using a different dataset? I trained these SAEs on OpenWebText. I don't much padding at all, th...

Sam Marks
In the notebook I link in my original comment, I check that the activations I get out of nnsight are the same as the activations that come from transformer_lens. Together with the fact that our sparsity statistics broadly align, I'm guessing that the issue isn't that I'm extracting different activations than you are.

Repeating my replication attempt with data from OpenWebText, I get this:

| Layer | MSE Loss | % Variance Explained | L1  | L0   | % Alive | CE Reconstructed |
|-------|----------|-----------------------|-----|------|---------|------------------|
| 1     | 0.069    | 95                    | 40  | 15   | 46      | 6.45             |
| 7     | 0.81     | 86                    | 125 | 59.2 | 96      | 4.38             |

Broadly speaking, same story as above, except that the MSE losses look better (still not great), and that the CE reconstructed looks very bad for layer 1.

> I don't much padding at all

Seems like there was a typo here -- what do you mean?

Logan Riggs reports that he tried to replicate your results and got something more similar to you. I think Logan is making decisions about padding and tokenization more like the decisions you make, so it's possible that the difference is down to something around padding and tokenization.

Possible next steps:

  • Can you report your MSE Losses (instead of just variance explained)?
  • Can you try to evaluate the residual stream dictionaries in the 5_32768 set released here? If you get CE reconstructed much better than mine, then it means that we're computing CE reconstructed in different ways, where your way consistently reports better numbers. If you get CE reconstructed much worse than mine, then it might mean that there's a translation error between our codebases (e.g. using different activations).
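To make "CE reconstructed" concrete, here's roughly the computation I mean, sketched with TransformerLens-style hooks rather than my actual nnsight code (the hook name is illustrative):

```python
def ce_reconstructed(model, sae, tokens, hook_name="blocks.1.hook_resid_pre"):
    # Swap the residual stream at `hook_name` for the SAE's reconstruction
    # of it, then let the model finish the forward pass and report CE loss.
    def splice(acts, hook):
        return sae.decode(sae.encode(acts))

    return model.run_with_hooks(
        tokens,
        return_type="loss",               # mean next-token cross-entropy
        fwd_hooks=[(hook_name, splice)],
    ).item()
```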

Oh no. I'll look into this and get back to you shortly. One obvious candidate is that I was reporting CE for some batch at the end of training that was very small, so the statistics likely had high variance and the last datapoint may have been fairly low. In retrospect, I should have explicitly recalculated this again post-training. However, I'll take a deeper dive now to see what's up.

Joseph Isaac Bloom
I've run some of the SAEs through more thorough eval code this morning (getting variance explained with the centring and calculating mean CE losses with more batches). As far as I can tell, the CE loss is not that high at all and the MSE loss is quite low. I'm wondering whether you might be using the wrong hooks? These are resid_pre, so layer 0 is just the embeddings, layer 1 is after the first transformer block, and so on. One other possibility is that you are using a different dataset? I trained these SAEs on OpenWebText. I don't much padding at all, that might be a big difference too. I'm curious to get to the bottom of this.

One sanity check I've done is just sampling from the model when using the SAE to reconstruct activations, and it seems to be about as good, which I think rules out CE loss in the ranges you quote above.

For percent alive neurons, a batch size of 8192 would be far too few to estimate dead neurons (since many neurons have a feature sparsity < 10^-3).

You're absolutely right about missing the centring in percent variance explained. I've estimated variance explained again for the same layers and get very similar results to what I had originally. I'll make some updates to my code to produce CE score metrics that have less variance in the future, at the cost of slightly more train time.

If we don't find a simple answer, I'm happy to run some more experiments, but I'd guess an 80% probability that there's a simple bug which would explain the difference in what you get. Rank order of most likely: using the wrong activations, using datapoints with lots of padding, using a different dataset (I tried the Pile and it wasn't that bad either).
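For reference, the two metrics under discussion look roughly like this -- a sketch with illustrative names, not my exact eval code:

```python
def frac_variance_explained(x, x_hat):
    # Variance explained *with centring*: compare residual error to the
    # variance of x about its mean, not to raw squared activations.
    total = ((x - x.mean(dim=0)) ** 2).sum()
    return 1 - ((x - x_hat) ** 2).sum() / total

def alive_fraction(sae, activation_batches, threshold=1e-3):
    # Features with sparsity < 1e-3 fire less than once per 1000 tokens,
    # so a single batch of 8192 tokens is far too small to call them dead;
    # accumulate firing counts over many batches instead.
    fires, n_tokens = 0, 0
    for x in activation_batches:
        f = sae.encode(x.reshape(-1, x.shape[-1]))
        fires = fires + (f > 0).sum(dim=0)
        n_tokens += f.shape[0]
    return ((fires / n_tokens) > threshold).float().mean().item()
```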

My vibe from this post is something like "we're working on stuff that could be helpful, so there's stuff to work on!" and this is a vibe I like. However, I suspect that for people who might not be as excited about these approaches, you're likely not touching on important cruxes (e.g.: do these approaches really scale? Are some agendas capabilities-enhancing? Will these solve deceptive alignment or just corrigible alignment?)

I also think that if the goal is to actually make progress and not to maximize the number of people making progress or who feel like they'...

Chris_Leong
Hmm... I suppose this is pretty good evidence that CCS may not be as promising as it first appeared, esp. the banana/shed results. https://www.lesswrong.com/posts/wtfvbsYjNHYYBmT3k/discussion-challenges-with-unsupervised-llm-knowledge-1

Update: Seems like the banana results are being challenged.

Really exciting! I added a version of AVEC to my interpretability tool for gridworld agents and am keen to explore it more. I really like that the injection coefficient is a scalar, and this has enabled me to do what I call an "injection coefficient scan".

The procedure I'm using looks like this (a code sketch follows the list):

  1. Repeat your input tokens, say, 128 times.
  2. Apply the activation vector at 128 different coefficients, stepped between -10 and 10 (one per repeated copy), to each of your input tokens when doing your AVEC forward pass.
  3. Decompose the resulting residual stream to whatever granularity ...
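A minimal sketch of the scan, assuming a TransformerLens-style model and a precomputed activation vector vec (on the model's device) for the hooked layer; the hook name and shapes are illustrative, not my gridworld tooling:

```python
import torch

def injection_coefficient_scan(model, tokens, vec,
                               hook_name="blocks.6.hook_resid_pre",
                               n_steps=128, lo=-10.0, hi=10.0):
    coeffs = torch.linspace(lo, hi, n_steps)   # one coefficient per repeat
    batch = tokens.repeat(n_steps, 1)          # step 1: repeat the input

    def inject(acts, hook):                    # step 2: add coeff * vec
        return acts + coeffs[:, None, None].to(acts.device) * vec

    with model.hooks(fwd_hooks=[(hook_name, inject)]):
        logits, cache = model.run_with_cache(batch)
    return logits, cache                       # step 3: decompose `cache`
```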
Alex Turner
I don't think I follow your procedure. Would you be willing to walk me through an example situation?

> We would love to see more ideas & hypotheses on why the model might be doing this, as well as attempts to test this! We mainly wrote up this post because both Alex and I independently noticed this and weren't aware of this previously, so we wanted to make a reference post.

Happy to provide! I think I'm pretty interested in testing this/working on this in the future. Currently a bit tied up but I think (as Alex hints at) there could be some big implications for interpretability here.

TLDR: Documenting existing circuits is good but explaining what relation...

Alex Turner
Can you say more on this point? The latter kind of good (useful when integrated with other features) doesn't necessarily imply that direct unembed (logit lens) or learned linear unembed (tuned lens iirc) would be able to extract use from such goods. I suspect that I probably just missed your point, though.
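(By "direct unembed (logit lens)" I mean roughly the following, sketched in TransformerLens terms with an illustrative layer index:)

```python
def logit_lens(model, tokens, layer=6):
    # Read logits directly off an intermediate residual stream by applying
    # the final LayerNorm and the unembedding to it.
    _, cache = model.run_with_cache(tokens)
    resid = cache["resid_post", layer]         # [batch, pos, d_model]
    return model.ln_final(resid) @ model.W_U   # project to vocab logits
```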

Second pass through this post, which solidly nerd-sniped me!

A quick summary of my understanding of the post (intentionally being very reductive, though I understand the post may make more subtle points):

  1. There appears to be exponential growth in the norm of the residual stream in a range of models. Why is this the case?
  2. You consider two hypotheses: 
    1. That the parameters in the Attention and/or MLP weights increase later in the network.
    2. That there is some monkey business with the layer norm sneaking in a single extra feature.
  3. In ter...
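Point 1 is easy to check directly; a quick sketch assuming a TransformerLens model (not the post's exact code):

```python
def resid_norms_by_layer(model, tokens):
    # Mean L2 norm of the accumulated residual stream after each layer.
    _, cache = model.run_with_cache(tokens)
    accum, labels = cache.accumulated_resid(return_labels=True)
    return labels, accum.norm(dim=-1).mean(dim=(1, 2))

# Plotting these norms on a log y-axis should give a roughly straight
# line if the growth is exponential, as the post observes.
```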
Stefan Heimersheim
Thanks for the extensive comment! Your summary is really helpful to see how this came across; here's my take on a couple of these points:

2.b: The network would be sneaking information about the size of the residual stream past LayerNorm. So the network wants to implement a sort of "grow by a factor X every layer" and wants to prevent LayerNorm from resetting its progress.

1. There's the difference between (i) How does the model make the residual stream grow exponentially -- the answer is probably theory 1, that something in the weights grows exponentially -- and (ii) our best guess on Why the model would ever want this, which is the information deletion thing. Yep, we give some evidence for How, but for Why we have only a guess.

Yes, all we have is some intuition here. It seems plausible that the model needs to communicate stuff between some layers, but doesn't want this to take up space in the residual stream. So this exponential growth is a neat way to make old information decay away (relatively). And it seems plausible to implement a few amplification circuits for information that has to be preserved for much later in the network.

We would love to see more ideas & hypotheses on why the model might be doing this, as well as attempts to test this! We mainly wrote up this post because both Alex and I independently noticed this and weren't aware of this previously, so we wanted to make a reference post.

Thank you for letting me know about your work on procgen with MI. It sounds like you're making progress; in particular, I'd be interested in your visualisation techniques (how do they compare to what was done in Understanding RL Vision?) and the reproduction of the cheese-maze policies (is this tricky? Do you think a DT could be well-calibrated on this problem?).

Some questions that might be useful to discuss more:

  • What are the pros/cons of doing DT vs actor-critic MI? (You're using actor-critic of some form?) It could also be interesting to study analo...
Alex Turner
We're studying a net with the structure I describe in a comment below, trained via PPO. I'd be happy to discuss more at EAG. Not posting much publicly right now so that we can (a) work on the research sprint and (b) let people preregister credences in various mechint / generalization propositions, so that they can calibrate / see how their opinions evolve over time.

Hey Adam, thanks for running Refine and writing this up. 

Out of curiosity, do you (or anyone else) know if there are statistics for previous SERI-MATS cohorts/other programs designed to generate conceptual alignment researchers? 

Adam Shimi
Thanks for the kind words! I'm not aware of any such statistics, but I'm guessing that MATS organizers might have some.