JumpReLU SAEs + Early Access to Gemma 2 SAEs

by Senthooran Rajamanoharan, Tom Lieberum, nps29, Arthur Conmy, Vikrant Varma, János Kramár, Neel Nanda
19th Jul 2024
This is a linkpost for https://storage.googleapis.com/jumprelu-saes-paper/JumpReLU_SAEs.pdf
3 comments
leogao:

Very exciting that JumpReLU works well with STE gradient estimation! I think this fixes one of the biggest flaws with TopK, which is that having a fixed number of latents k on each token is kind of wonky. I also like the argument in section 4 a lot - in particular the point about how this works because we're optimizing the expectation of the loss. Because of how sparse the features are, I wonder if it would reduce gradient noise substantially to use a KDE with state persisting across a few recent steps.

Lennart Buerger:

Nice work! I was wondering what context length you were using when you extracted the LLM activations to train the SAE. I could not find it in the paper, but I might also have missed it. I know that OpenAI used a context length of 64 tokens in all their experiments, which is probably not sufficient to elicit many interesting features. Do you use a variable context length or also a fixed value?

Logan Riggs:

Did y'all do any ablations on your loss terms? For example:
1. JumpReLU() -> ReLU
2. L0 (w/ STE) -> L1

I'd be curious to see if the Pareto improvements and high-frequency features are due to one, the other, or both.


New paper from the Google DeepMind mechanistic interpretability team, led by Sen Rajamanoharan!

We introduce JumpReLU SAEs, a new SAE architecture that replaces the standard ReLU with a discontinuous JumpReLU activation, and seems to be (narrowly) state of the art over existing methods like TopK and Gated SAEs for achieving high reconstruction fidelity at a given sparsity level, without a hit to interpretability. We train through the discontinuity with straight-through estimators, which also let us directly optimise L0.
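
As a concrete illustration (and not the paper's code), here is a minimal PyTorch-style sketch of a JumpReLU activation trained with a straight-through estimator. The names theta (the learned per-feature threshold) and bandwidth/eps (the kernel width), and the rectangle-kernel pseudo-derivative, are assumptions for illustration; the linked paper gives the exact pseudo-derivatives used.

```python
# Minimal sketch (assumptions, not the paper's released code) of a JumpReLU
# activation, JumpReLU_theta(z) = z * H(z - theta), with straight-through
# gradients for the learned per-feature threshold theta. Assumes z has shape
# (batch, d_sae) and theta has shape (d_sae,); the rectangle kernel and the
# bandwidth eps are illustrative choices.
import torch


class JumpReLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z, theta, bandwidth):
        ctx.save_for_backward(z, theta)
        ctx.bandwidth = bandwidth
        # Zero out pre-activations below the threshold; keep the rest unchanged.
        return z * (z > theta).to(z.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        z, theta = ctx.saved_tensors
        eps = ctx.bandwidth
        # w.r.t. z: pass the gradient through wherever the unit is active.
        grad_z = grad_out * (z > theta).to(z.dtype)
        # w.r.t. theta: only pre-activations within eps/2 of the threshold
        # contribute (rectangle-kernel pseudo-derivative), which estimates the
        # gradient of the expected loss despite the discontinuity.
        window = ((z - theta).abs() < eps / 2).to(z.dtype)
        grad_theta = (grad_out * (-theta / eps) * window).sum(dim=0)
        return grad_z, grad_theta, None
```

In use, z would be the encoder pre-activations, so this behaves like a ReLU whose cut-off is a learned per-feature threshold rather than zero.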

To accompany this, we will release the weights of hundreds of JumpReLU SAEs on every layer and sublayer of Gemma 2 2B and 9B in a few weeks. Apply now for early access to the 9B ones! We're keen to get feedback from the community, and to get these into the hands of researchers as fast as possible. There are a lot of great projects that we hope will be much easier with open SAEs on capable models!

[Figure: JumpReLU activations]
[Figure: Scenarios where JumpReLU is superior]
[Figure: JumpReLUs are state of the art]

Gated SAEs already reduce to JumpReLU activations once their weights are tied, so JumpReLU SAEs can be thought of as Gated SAEs++: less computationally intensive to train, and better performing. They should be runnable in existing Gated implementations.

Abstract:

Sparse autoencoders (SAEs) are a promising unsupervised approach for identifying causally relevant and interpretable linear features in a language model’s (LM) activations. To be useful for downstream tasks, SAEs need to decompose LM activations faithfully; yet to be interpretable the decomposition must be sparse – two objectives that are in tension. In this paper, we introduce JumpReLU SAEs, which achieve state-of-the-art reconstruction fidelity at a given sparsity level on Gemma 2 9B activations, compared to other recent advances such as Gated and TopK SAEs. We also show that this improvement does not come at the cost of interpretability through manual and automated interpretability studies. JumpReLU SAEs are a simple modification of vanilla (ReLU) SAEs – where we replace the ReLU with a discontinuous JumpReLU activation function – and are similarly efficient to train and run. By utilising straight-through estimators (STEs) in a principled manner, we show how it is possible to train JumpReLU SAEs effectively despite the discontinuous JumpReLU function introduced in the SAE’s forward pass. Similarly, we use STEs to directly train L0 to be sparse, instead of training on proxies such as L1, avoiding problems like shrinkage.
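
To illustrate the final point of the abstract (training on L0 directly rather than on an L1 proxy), here is a rough sketch of how that sparsity penalty could be wired into the loss. Everything below (the names W_enc, b_enc, W_dec, b_dec, lam and eps, the rectangle kernel, and the exact pseudo-derivative) is an illustrative assumption rather than the released implementation.

```python
# Rough sketch (illustrative assumptions, not the released implementation) of a
# JumpReLU SAE loss that penalises L0 directly. The hard step H(z - theta) has
# zero gradient almost everywhere, so the threshold gets a kernel-based
# pseudo-derivative; the paper's exact pseudo-derivatives differ in detail.
import torch


class StepFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z, theta, bandwidth):
        ctx.save_for_backward(z, theta)
        ctx.bandwidth = bandwidth
        return (z > theta).to(z.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        z, theta = ctx.saved_tensors
        eps = ctx.bandwidth
        # No gradient flows to z through the hard step; theta only receives
        # gradient from pre-activations that land near the threshold.
        window = ((z - theta).abs() < eps / 2).to(z.dtype)
        grad_theta = (grad_out * (-1.0 / eps) * window).sum(dim=0)
        return None, grad_theta, None


def jumprelu_sae_loss(x, W_enc, b_enc, W_dec, b_dec, theta, lam, eps):
    """Reconstruction error plus lam times the L0 of the feature activations."""
    z = x @ W_enc + b_enc                      # encoder pre-activations
    gate = StepFunction.apply(z, theta, eps)   # H(z - theta)
    f = z * gate                               # JumpReLU feature activations
    x_hat = f @ W_dec + b_dec                  # reconstruction
    recon = (x - x_hat).pow(2).sum(dim=-1).mean()
    l0 = gate.sum(dim=-1).mean()               # L0 penalty, trainable via the STE
    return recon + lam * l0
```

Because this penalty counts active features rather than summing their magnitudes, active features are not pulled towards zero, which is the shrinkage problem with L1 proxies mentioned above.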