Awesome work on this! It definitely looks like a big improvement over standard SAEs for absorption. Some questions/thoughts:
In the decoder cos sim plot, it looks like there's still some slight mixing of features within co-occurring latent groups, including some slight negative cos sims, although it's definitely a lot better than in the standard SAE. Given that the underlying features are orthogonal, I'm curious why the Matryoshka SAE doesn't fully drive these to 0 and perfectly recover the underlying true features. Is it due to the sampling of latent groups, so there's still some chance for the SAE to learn absorption-esque feature mixes on steps where a given group isn't sampled? If there were no sampling and every latent group had its loss calculated on each step (I know this is really inefficient in practice), would the SAE perfectly recover the true features?
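To make the sampling question concrete, here's roughly the comparison I have in mind (a minimal sketch, not the post's actual training code; `prefix_sizes`, `acts`, `W_dec`, and `b_dec` are names I'm making up):

```python
import torch

def matryoshka_recon_loss(x, acts, W_dec, b_dec, prefix_sizes, sample_prefixes=True):
    """Reconstruction loss over nested prefixes of the latent dimension.

    Hypothetical sketch: `prefix_sizes` are the nested group boundaries,
    e.g. [64, 256, 1024, 4096]; `acts` is (batch, n_latents), `W_dec` is
    (n_latents, d_model). With sample_prefixes=True, one prefix is drawn per
    step (cheaper); with False, every prefix contributes every step, which is
    the no-sampling variant I'm asking about.
    """
    if sample_prefixes:
        idx = int(torch.randint(len(prefix_sizes), (1,)).item())
        sizes = [prefix_sizes[idx]]
    else:
        sizes = list(prefix_sizes)

    loss = x.new_zeros(())
    for m in sizes:
        # Reconstruct using only the first m latents (the inner "dolls").
        x_hat = acts[:, :m] @ W_dec[:m] + b_dec
        loss = loss + (x - x_hat).pow(2).mean()
    return loss
```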
It looks like Matryoshka SAEs will solve absorption as long as the parent feature in the hierarchy is learned before the child feature, but this doesn't seem guaranteed. If the child feature happens to fire with much higher magnitude than the parent, I would suspect the SAE learns the child latent first to minimize expected MSE, and ends up with absorption anyway. E.g. if a parent feature fires with probability 0.3 and magnitude 2.0 (expected MSE contribution = 0.3 * 2.0^2 = 1.2), and a child feature fires with probability 0.15 but magnitude 10.0 (expected MSE contribution = 0.15 * 10^2 = 15.0), I'd expect the SAE to learn the child feature before the parent and merge the parent representation into the child, resulting in absorption. This might never actually happen in real LLMs, so it's possibly a non-issue, but it could be something to look out for when training Matryoshka SAEs on real LLMs. A toy-data version of this setup is sketched below.
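In case it's useful for checking this on toy data, here's roughly how I'd generate that high-magnitude-child scenario (a hypothetical sketch, parameter names are mine, not from the post):

```python
import torch

def sample_parent_child_batch(batch_size, d_model, p_parent=0.3, mag_parent=2.0,
                              p_child=0.15, mag_child=10.0, seed=0):
    """Toy data for the high-magnitude-child scenario described above.

    Two orthogonal unit features; the child only ever fires together with the
    parent (marginal probability p_child), but with much larger magnitude, so
    its expected MSE contribution (0.15 * 10^2 = 15.0) dominates the parent's
    (0.3 * 2^2 = 1.2).
    """
    g = torch.Generator().manual_seed(seed)
    # Two orthonormal feature directions in d_model dimensions.
    Q, _ = torch.linalg.qr(torch.randn(d_model, 2, generator=g))
    feats = Q.T
    parent_on = torch.rand(batch_size, generator=g) < p_parent
    # Child fires only when the parent does, with conditional prob p_child / p_parent.
    child_on = parent_on & (torch.rand(batch_size, generator=g) < (p_child / p_parent))
    x = (parent_on.float()[:, None] * mag_parent * feats[0]
         + child_on.float()[:, None] * mag_child * feats[1])
    return x
```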