Very exciting that JumpReLU works well with STE gradient estimation! I think this fixes one of the biggest flaws with TopK, which is that having a fixed number of latents k on each token is kind of wonky. I also like the argument in section 4 a lot - in particular the point about how this works because we're optimizing the expectation of the loss. Because of how sparse the features are, I wonder if it would reduce gradient noise substantially to use a KDE with state persisting across a few recent steps.
Nice work! I was wondering what context length you used when extracting the LLM activations to train the SAE. I could not find it in the paper, though I might have missed it. I know that OpenAI used a context length of 64 tokens in all their experiments, which is probably not sufficient to elicit many interesting features. Do you use a variable context length or a fixed value as well?
Did y'all do any ablations on your loss terms? For example:
1. JumpReLU() -> ReLU
2. L0 (w/ STE) -> L1
I'd be curious to see if the Pareto improvements and high-frequency features are due to one, the other, or both.
New paper from the Google DeepMind mechanistic interpretability team, led by Sen Rajamanoharan!
We introduce JumpReLU SAEs, a new SAE architecture that replaces the standard ReLUs with discontinuous JumpReLU activations, and seems to be (narrowly) state of the art over existing methods like TopK and Gated SAEs for achieving high reconstruction at a given sparsity level, without a hit to interpretability. We train through the discontinuity with straight-through estimators, which also let us directly optimise the L0.
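To make the idea concrete, here is a minimal scalar sketch of a JumpReLU activation and the kind of straight-through pseudo-derivative that lets gradients reach the threshold. It is an illustration under assumptions, not the training code: the rectangle kernel, the bandwidth `eps`, and all function names are choices made for this example.

```python
def heaviside(x):
    """Step function: 1 when x > 0, else 0."""
    return 1.0 if x > 0 else 0.0


def jumprelu(z, theta):
    """JumpReLU: pass z through unchanged if it exceeds the threshold theta, else output 0."""
    return z * heaviside(z - theta)


def rect_kernel(u):
    """Rectangle kernel (assumed choice): 1 on [-1/2, 1/2], 0 elsewhere."""
    return 1.0 if abs(u) <= 0.5 else 0.0


def l0(pre_acts, thetas):
    """L0 sparsity: number of latents whose pre-activation exceeds its threshold."""
    return sum(heaviside(z - t) for z, t in zip(pre_acts, thetas))


def l0_grad_theta(z, theta, eps):
    """Pseudo-derivative of H(z - theta) w.r.t. theta.

    The true derivative is zero almost everywhere, so the straight-through
    estimator substitutes a kernel of width eps centred on the threshold.
    """
    return -rect_kernel((z - theta) / eps) / eps


def jumprelu_grad_theta(z, theta, eps):
    """Pseudo-derivative of JumpReLU_theta(z) w.r.t. theta, using the same kernel."""
    return -(theta / eps) * rect_kernel((z - theta) / eps)
```

Only pre-activations within `eps / 2` of the threshold contribute a nonzero pseudo-gradient, so each individual gradient is sparse and noisy; it is the average over many tokens that gives a useful signal for the threshold.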
To accompany this, we will release the weights of hundreds of JumpReLU SAEs on every layer and sublayer of Gemma 2 2B and 9B in a few weeks. Apply now for early access to the 9B ones! We're keen to get feedback from the community, and to get these into the hands of researchers as fast as possible. There are a lot of great projects that we hope will be much easier with open SAEs on capable models!
Gated SAEs already reduce to JumpReLU activations after weight tying, so this can be thought of as Gated SAEs++, but less computationally intensive to train, and better performing. They should be runnable in existing Gated implementations.
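The reduction can be checked numerically for a single latent. The sketch below assumes the usual weight-tied parameterisation of a Gated encoder, where the magnitude pre-activation is `exp(r_mag) * z + b_mag` and the gate pre-activation is `z + b_gate`, with `z` the shared projection of the input. Under that assumption the threshold works out to `max(0, b_mag - exp(r_mag) * b_gate)`. The symbol names and the test grid are illustrative, not taken from any released implementation.

```python
import math


def gated_tied(z, r_mag, b_gate, b_mag):
    """Weight-tied Gated encoder for one latent: binary gate times ReLU magnitude."""
    pi_gate = z + b_gate
    pi_mag = math.exp(r_mag) * z + b_mag
    gate = 1.0 if pi_gate > 0 else 0.0
    return gate * max(pi_mag, 0.0)


def as_jumprelu(z, r_mag, b_gate, b_mag):
    """Same latent written as a JumpReLU applied to the magnitude pre-activation."""
    pi_mag = math.exp(r_mag) * z + b_mag
    # Gate is open iff z > -b_gate, i.e. pi_mag > b_mag - exp(r_mag) * b_gate;
    # the ReLU additionally requires pi_mag > 0.
    theta = max(0.0, b_mag - math.exp(r_mag) * b_gate)
    return pi_mag if pi_mag > theta else 0.0


def max_gap():
    """Largest disagreement between the two forms over a small hand-picked grid."""
    zs = [-2.0, -1.0, -0.3, 0.4, 1.0, 2.5]
    params = [(0.3, -0.5, 0.2), (0.3, 0.5, -0.4)]  # (r_mag, b_gate, b_mag)
    return max(
        abs(gated_tied(z, *p) - as_jumprelu(z, *p))
        for z in zs
        for p in params
    )
```

The two parameter settings cover both cases: one where the gate sets the threshold and one where the ReLU does, so the threshold clamps to zero.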
Abstract: