My guess is that this result is very sensitive to the design of the training dataset:
the input/output data pairs are $(e_i, e_i)$ for each feature $i$, where $e_i$ is the $i$th basis vector.
In particular, I think it is likely very sensitive to the implicit assumption that feature i and feature j never co-occur on a single input. I'd be interested to see experiments where each feature is turned on with some (not too small) probability, independently of all other features, similarly to the original toy models setting. This would result in some inputs where feature i and j are on simultaneously. My prediction would be that polysemanticity goes down very significantly (probably to zero if the probabilities are high enough and the training is done for long enough).
I also don't understand why L1 regularization on activations is necessary to show incidental polysemanticity given your setup. Even if you remove the L1 regularization on activations, it is still the case that "benign collisions" impose no cost on the model, since feature i and feature j are never simultaneously present in a given input. So if you do get a benign collision, what causes it to go away? Overall my expectation would be that without the L1 regularization on activations (and with the training dataset as described in this post), you'd get a complicated mess where every neuron is highly polysemantic, i.e. even more polysemanticity than described in this post. Why is that wrong?
Thanks for the feedback!
In particular, I think it is likely very sensitive to the implicit assumption that feature i and feature j never co-occur on a single input.
Definitely! I still think that this assumption is fairly realistic because in practice, most pairs of unrelated features would co-occur only very rarely, and I expect the winner-take-all dynamic to dominate most of the time. But I agree that it would be nice to quantify this and test it out.
Overall my expectation would be that without the L1 regularization on activations (and with the training dataset as described in this post), you'd get a complicated mess where every neuron is highly polysemantic, i.e. even more polysemanticity than described in this post. Why is that wrong?
If there is no L1 regularization on activations, then every hidden neuron would indeed be highly "polysemantic" in the sense that it has nonzero weights for each input feature. But on the other hand, the whole encoding space would become rotationally symmetric, and when that's the case it feels like polysemanticity shouldn't be about individual neurons (since the canonical basis is not special anymore) and instead about the angles that different encodings form. In particular, as long as $m \geq n$, the space of optimal solutions for this setup requires the encodings to form angles of at least 90° with each other, and it's unclear whether we should call this polysemantic.
So one of the reasons why we need L1 regularization is to break the rotational symmetry and create a privileged basis: that way, it's actually meaningful to ask whether a particular hidden neuron is representing more than one feature.
Good point on the rotational symmetry, that makes sense now.
I still think that this assumption is fairly realistic because in practice, most pairs of unrelated features would co-occur only very rarely, and I expect the winner-take-all dynamic to dominate most of the time. But I agree that it would be nice to quantify this and test it out.
Agreed that's a plausible hypothesis. I mostly wish that in this toy model you had a hyperparameter for the frequency of co-occurrence of features, and identified how it affects the rate of incidental polysemanticity.
Great work! Love the push for intuitions especially in the working notes.
My understanding of the superposition hypothesis from the TMS paper has been (feel free to correct me!):
Is it possible that the features here are not sufficiently basis-aligned, and this is closer to case 1? As you already commented, demonstrating polysemanticity when the hidden layer has a non-linearity and $m > n$ would be principled imo.
This is a preliminary research report; we are still building on initial work and would appreciate any feedback.
Summary
Polysemantic neurons (neurons that activate for a set of unrelated features) have been seen as a significant obstacle towards interpretability of task-optimized deep networks,[1] with implications for AI safety.
The classic origin story of polysemanticity is that the data contains more "features" than there are neurons, such that learning to solve a task forces the network to allocate multiple unrelated features to the same neuron, threatening our ability to understand the network's internal processing.
In this work, we present a second and non-mutually exclusive origin story of polysemanticity. Using a combination of theory and experiments, we show that polysemanticity can arise incidentally, even when there are ample neurons to represent all features in the data. This second type of polysemanticity occurs because random initialization can, by chance alone, initially assign multiple features to the same neuron, and the training dynamics then strengthen such overlap. Due to its origin, we term this incidental polysemanticity.
Intuition
The reason why neural networks can learn anything despite starting out with completely random weights is that, just by random chance, some neurons will happen to be very slightly correlated[2] with some useful feature, and this correlation gets amplified by gradient descent until the feature is accurately represented. If in addition to this there is some incentive for activations to be sparse, then the feature will tend to be represented by a single neuron as opposed to a linear combination of neurons: this is a winner-take-all dynamic.[3] When a winner-take-all dynamic is present, then by default, the neuron that is initially most correlated with the feature will be the neuron that wins out and represents the feature when training completes.
Therefore, if at the start of training, one neuron happens to be the most correlated neuron with two unrelated features (say dogs and airplanes), then this might[4] continue being the case throughout the learning process, and that neuron will ultimately end up taking full responsibility for representing both features. We call this phenomenon incidental polysemanticity. Here "incidental" refers to the fact that this phenomenon is contingent on the random initializations of the weights and the dynamics of training, rather than being necessary in order to achieve low loss (and in fact, in some circumstances, incidental polysemanticity might cause the neural network to get stuck in a local optimum).
How often should we expect this to happen? Suppose that we have $n$ useful features to represent and $m \geq n$ neurons to represent them with (so that it is technically possible for each feature to be represented by a different neuron). By symmetry, the probability that the $i$th and $j$th feature "collide", in the sense of being initially most correlated with the same neuron, is exactly $1/m$. And there are $\binom{n}{2} = n(n-1)/2$ pairs of features, so on average we should expect
$$\underbrace{\binom{n}{2}}_{\text{number of pairs } (i,j)} \times \underbrace{\frac{1}{m}}_{\text{probability of } (i,j) \text{ colliding}} = \frac{n(n-1)}{2m} = \Theta\!\left(\frac{n^2}{m}\right)$$
collisions[5] overall. In particular, this means that
Our experiments in a toy model show that this is precisely what happens, and a constant fraction of these collisions result in polysemantic neurons, despite the fact that there would be enough neurons to avoid polysemanticity entirely.
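As a sanity check on this counting argument, here is a minimal simulation sketch (hypothetical code, not our experiment code) that assigns each feature a uniformly random "most correlated" neuron and counts colliding pairs; the empirical average closely matches $n(n-1)/2m$.

```python
import numpy as np

def average_collisions(n: int, m: int, trials: int = 1000, seed: int = 0) -> float:
    """Estimate the expected number of feature pairs whose initially
    most-correlated neuron coincides, assuming each feature's "winner"
    is an independent uniform draw among the m neurons."""
    rng = np.random.default_rng(seed)
    totals = []
    for _ in range(trials):
        winners = rng.integers(0, m, size=n)                  # winning neuron for each feature
        _, sizes = np.unique(winners, return_counts=True)
        totals.append(int((sizes * (sizes - 1) // 2).sum()))  # colliding pairs per neuron
    return float(np.mean(totals))

n, m = 256, 1024  # hypothetical sizes
print(average_collisions(n, m), n * (n - 1) / (2 * m))  # both should be close to ~31.9
```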
Outline
In the rest of this post, we
Setup
Model
We consider a model similar to the ReLU-output model in Toy Models of Superposition. It is an autoencoder with n features (inputs/outputs) which
The output is computed as $\mathrm{ReLU}(WW^T x)$:
The main difference compared to the model from Toy Models of Superposition is the l1 regularization. The role of the l1 regularization is to push for sparsity in the activations and therefore induce a winner-take-all dynamic. We picked this model because it makes incidental polysemanticity particularly easy to demonstrate and study, but we do think the story it tells is representative (see the "Discussion and future work" section for more on this).
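To make the setup concrete, here is a minimal PyTorch sketch of this model and loss (the hyperparameter values are placeholders rather than the ones used in our experiments; the actual code is linked in the "Numerical simulations" section):

```python
import torch

n, m, lam = 256, 1024, 1e-3  # placeholder values for n, m, and the l1 coefficient lambda
W = torch.nn.Parameter(torch.randn(n, m) / m**0.5)  # row W_i is the encoding of feature i

def loss_fn(x: torch.Tensor) -> torch.Tensor:
    h = x @ W                   # hidden activations W^T x (one row per input)
    out = torch.relu(h @ W.T)   # reconstruction ReLU(W W^T x)
    return ((out - x) ** 2).sum() + lam * h.abs().sum()  # squared error + l1 on activations

optimizer = torch.optim.SGD([W], lr=1e-2)
x = torch.eye(n)                # the dataset: one basis vector e_i per feature
for _ in range(10_000):
    optimizer.zero_grad()
    loss_fn(x).backward()
    optimizer.step()
```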
We make the following assumptions on parameter values:
Possible solutions
Let $W_i \in \mathbb{R}^m$ be the $i$th row of $W$. It tells us how the $i$th feature is encoded in the hidden layer. When the input is $e_i$, the output of the model can then be written as $\left(\mathrm{ReLU}(W_1 \cdot W_i), \ldots, \mathrm{ReLU}(W_n \cdot W_i)\right)$, so for this to be equal to $e_i$ we need $\|W_i\|^2 = 1$[6] and $W_i \cdot W_j \leq 0$ for $j \neq i$.
Let $f_k \in \mathbb{R}^m$ denote the $k$th basis vector of $\mathbb{R}^m$. There are both monosemantic and polysemantic solutions that satisfy these conditions:
Loss and dynamics
Let us consider the total squared error loss, which can be decomposed as
$$L = \sum_i \left( \left(1 - \|W_i\|^2\right)^2 + \sum_{j \neq i} \mathrm{ReLU}(W_i \cdot W_j)^2 + \lambda \|W_i\|_1 \right).$$
The training dynamics are
$$\frac{dW_i}{dt} := -\frac{\partial L}{\partial W_i} = \underbrace{4\left(1 - \|W_i\|^2\right) W_i}_{\text{feature benefit}} - \underbrace{4 \sum_{j \neq i} \mathrm{ReLU}(W_i \cdot W_j)\, W_j}_{\text{interference}} - \underbrace{\lambda\, \mathrm{sign}(W_i)}_{\text{regularization}},$$
where $t$ is the training time, which you can roughly think of as the number of training steps. For simplicity, we'll ignore the constants 4 going forward.[7]
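For concreteness, the three terms of this gradient flow could be computed directly as follows (a numpy sketch, dropping the constants 4 as above):

```python
import numpy as np

def forces(W: np.ndarray, lam: float):
    """Return the feature benefit, interference, and regularization
    components of dW/dt, one row per encoding W_i."""
    sq_norms = (W ** 2).sum(axis=1, keepdims=True)   # ||W_i||^2
    feature_benefit = (1 - sq_norms) * W
    overlaps = np.maximum(W @ W.T, 0.0)              # ReLU(W_i . W_j)
    np.fill_diagonal(overlaps, 0.0)                  # exclude j = i
    interference = -overlaps @ W                     # -sum_{j != i} ReLU(W_i . W_j) W_j
    regularization = -lam * np.sign(W)
    return feature_benefit, interference, regularization
```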
These dynamics can be decomposed into three intuitive "forces" acting on the encodings $W_i$:
The winning neuron takes it all
See our working notes (in particular, Feature benefit vs regularization) for a more formal treatment.
Sparsity force
For a moment, let's ignore the interference force, and figure out how (and how fast) regularization will push towards sparsity in some encoding Wi. Since we're only looking at feature benefit and regularization, the other encodings Wj have no influence at all on what happens in Wi.
Assuming $\|W_i\| < 1$, each weight $W_{ik}$ is
Crucially, the upwards push is relative to how large $W_{ik}$ is, while the downwards push is absolute. This means that weights whose absolute value is above some threshold $\theta$ will grow, while those below the threshold will shrink, creating a "rich get richer and poor get poorer" dynamic that will push for sparsity. This threshold is given by
$$\left(1 - \|W_i\|^2\right) W_{ik} = \lambda\, \mathrm{sign}(W_{ik}) \iff |W_{ik}| = \frac{\lambda}{1 - \|W_i\|^2} =: \theta,$$
so we have
$$\frac{d|W_{ik}|}{dt} = \underbrace{\left(1 - \|W_i\|^2\right) |W_{ik}|}_{\text{feature benefit}} - \underbrace{\lambda\, \mathbb{1}[W_{ik} \neq 0]}_{\text{regularization}} = \begin{cases} \underbrace{\left(1 - \|W_i\|^2\right)}_{\text{constant in } k}\, \underbrace{\left(|W_{ik}| - \theta\right)}_{\text{distance from threshold}} & \text{if } W_{ik} \neq 0 \\ 0 & \text{otherwise.} \end{cases} \quad (1)$$
We call this combination of feature benefit and regularization force the sparsity force. It uniformly stretches the gaps between (the absolute values of) different nonzero weights.
Note that the threshold $\theta$ is not fixed: we will see that as $W_i$ gets sparser, $\|W_i\|^2$ will get closer to 1, which increases the threshold and allows it to get rid of larger and larger entries, until only one is left. But how fast will this go?
How fast does it sparsify?
The next two subsections are not critical for understanding the overall message; feel free to skip directly to the section titled "Interference arbiters collisions between features" if you're happy with just accepting the fact that Wi will progressively sparsify over some predictable length of training time.
In order to track how fast $W_i$ sparsifies, we will look at its $l_1$ norm $\|W_i\|_1 = \sum_k |W_{ik}|$ as a proxy for how many nonzero coordinates are left. Indeed, we will have $\|W_i\| \approx 1$ throughout, so if $W_i$ has $m'$ nonzero values at any point in time, their typical value will be $\pm 1/\sqrt{m'}$, which means $\|W_i\|_1 \approx m' \cdot \frac{1}{\sqrt{m'}} = \sqrt{m'}$.
Since the sparsity force is proportional to $1 - \|W_i\|^2$, we need to get a sense of what values $\|W_i\|$ will take over time. As it turns out, $\|W_i\|$ changes relatively slowly, so we can get useful information by assuming the derivative $\frac{d\|W_i\|^2}{dt}$ is 0:
$$0 \approx \frac{d\|W_i\|^2}{dt} = 2\,\frac{dW_i}{dt} \cdot W_i = 2\left( \underbrace{\left(1 - \|W_i\|^2\right)\|W_i\|^2}_{\text{from feature benefit}} - \underbrace{\lambda \|W_i\|_1}_{\text{from regularization}} \right),$$
which means $1 - \|W_i\|^2 \approx \lambda \frac{\|W_i\|_1}{\|W_i\|^2}$. Plugging this back into $\frac{d\|W_i\|_1}{dt} = \sum_k \frac{d|W_{ik}|}{dt}$ and using reasonable assumptions about the initial distribution of $W_i$ (see our working notes for details), we can prove that $\|W_i\|_1$ will decrease as $1/\lambda t$ with training time $t$:
$$\|W_i(t)\|_1 = \begin{cases} \Theta(\sqrt{m}) & t \leq \frac{1}{\lambda \sqrt{m}} \\ \Theta(1/\lambda t) & \frac{1}{\lambda \sqrt{m}} \leq t \leq \frac{1}{\lambda} \\ \Theta(1) & t \geq \frac{1}{\lambda}. \end{cases}$$
Correspondingly, if we approximate the number $m'$ of nonzero coordinates as $\|W_i\|_1^2$, it will start out at $m$, decrease as $1/(\lambda t)^2$, then reach 1 at training time $t = \Theta(1/\lambda)$.
Numerical simulations
We compared our theoretical predictions for $\|W_i\|_1$ and $m'$ (if the constants hidden in $\Theta(\cdot)$ are assumed to be 1) to their actual values over training time when the interference force is turned off. The specific parameter values are $m := 10^5$ and $\lambda := 10^{-5}$, and the standard deviation of the $W_{ik}$'s was $\frac{0.9}{\sqrt{m}}$.
Code is available here.
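For reference, a stripped-down sketch of such a simulation, tracking a single encoding $W_i$ with the interference force turned off, could look like the following (this is not the linked code; it uses smaller illustrative values of $m$ and $\lambda$ so that it runs quickly, and handles the $l_1$ term with a proximal soft-thresholding step so that weights can reach exactly zero):

```python
import numpy as np

m, lam, dt = 10_000, 1e-3, 0.1                   # illustrative values; dt is the discretization step
rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.9 / np.sqrt(m), size=m)    # one encoding W_i at initialization

l1_history = []                                  # (training time, ||W_i||_1) pairs
for step in range(int(2 / (lam * dt))):          # run until t = 2 / lambda
    w = w + dt * (1 - (w ** 2).sum()) * w        # feature benefit force (no interference)
    w = np.sign(w) * np.maximum(np.abs(w) - lam * dt, 0.0)  # proximal step for the l1 term
    if step % 1000 == 0:
        l1_history.append((step * dt, np.abs(w).sum()))
```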
Interference arbiters collisions between features
What happens when you bring the interference force into this picture? In this section, we argue informally that the interference is initially weak if m≥n, and only becomes significant later on in training, in cases where two of the encodings Wi and Wj have a coordinate k such that Wik and Wjk are both large and have the same sign—when that's the case, the larger of the two wins out.
How strong is the interference?
First, observe that in the expression for the interference force on $W_i$, $-\sum_{j \neq i} \mathrm{ReLU}(W_i \cdot W_j)\, W_j$, each $W_j$ contributes only if the angle it forms with $W_i$ is less than $90°$. So the force will mostly point along the direction of $W_i$, but with the opposite sign. That means that we can get a good grasp on its strength by measuring its component in the direction of $W_i$, which we can do by taking an inner product with $W_i$.
We have
$$\left( \sum_{j \neq i} \mathrm{ReLU}(W_i \cdot W_j)\, W_j \right) \cdot W_i = \sum_{j \neq i} \mathrm{ReLU}(W_i \cdot W_j)\,(W_i \cdot W_j) = \sum_{j \neq i} \mathrm{ReLU}(W_i \cdot W_j)^2.$$
Initially, each encoding is a vector of $m$ i.i.d. normals of mean 0 and standard deviation $\Theta(1/\sqrt{m})$, so the distribution of the inner products $W_i \cdot W_j$ is symmetric around 0 and also has standard deviation $\Theta(1/\sqrt{m})$. This means that $\mathrm{ReLU}(W_i \cdot W_j)^2$ has mean $\Theta(1/m)$, and thus the sum has mean $\Theta(n/m)$. As long as $m \geq n$, this is dominated by the feature benefit force: indeed, the same computation for the feature benefit gives $\left((1 - \|W_i\|^2) W_i\right) \cdot W_i = (1 - \|W_i\|^2)\|W_i\|^2 = \Theta(1)$ as long as $\Omega(1) \leq \|W_i\|^2 \leq 1 - \Omega(1)$.
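These orders of magnitude are easy to check numerically at initialization (a sketch, using the same $0.9/\sqrt{m}$ initialization scale as in the numerical simulations above and hypothetical values of $n$ and $m$):

```python
import numpy as np

n, m = 256, 4096                           # hypothetical sizes with m >= n
rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.9 / np.sqrt(m), size=(n, m))

overlaps = np.maximum(W @ W.T, 0.0)
np.fill_diagonal(overlaps, 0.0)
interference = (overlaps ** 2).sum(axis=1)       # sum_{j != i} ReLU(W_i . W_j)^2 ~ Theta(n/m)
sq_norms = (W ** 2).sum(axis=1)
feature_benefit = (1 - sq_norms) * sq_norms      # (1 - ||W_i||^2) ||W_i||^2      ~ Theta(1)

print(interference.mean(), feature_benefit.mean())  # the first is noticeably smaller than the second
```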
Moreover, over time, the positive inner products $W_i \cdot W_j > 0$ will tend to decrease exponentially. This is because the interference force on $W_i$ includes the term $-\mathrm{ReLU}(W_i \cdot W_j)\, W_j$ and the interference force on $W_j$ includes the term $-\mathrm{ReLU}(W_i \cdot W_j)\, W_i$. Together, they affect $W_i \cdot W_j$ as
$$\left(-\mathrm{ReLU}(W_i \cdot W_j)\, W_j\right) \cdot W_j + \left(-\mathrm{ReLU}(W_i \cdot W_j)\, W_i\right) \cdot W_i = -(W_i \cdot W_j)\left(\|W_i\|^2 + \|W_j\|^2\right) = -\Theta(W_i \cdot W_j)$$
as long as $\|W_i\|^2, \|W_j\|^2 = \Theta(1)$, which is definitely the case at the start and will continue to hold true throughout training.
Benign and malign collisions
On the other hand, the interference between two encodings $W_i$ and $W_j$ starts to matter significantly when it affects one coordinate much more strongly than the others (rather than affecting all coordinates proportionally, like the feature benefit force does). This is the case when $W_i$ and $W_j$ share only one nonzero coordinate: a single $k$ such that $W_{ik}, W_{jk} \neq 0$. Indeed, when that's the case, the interference force $-\mathrm{ReLU}(W_i \cdot W_j)\, W_j$
so only $W_{ik}$ can be affected by this force.
When this happens, there are two cases:
Polysemanticity will happen when the largest[8] coordinates in encodings $W_i$ and $W_j$ get into a benign collision. This happens with probability
$$\underbrace{\frac{1}{m}}_{\text{largest weight in } W_i \text{ is also largest in } W_j} \times \underbrace{\frac{1}{2}}_{\text{they have opposite signs}} = \frac{1}{2m},$$
so we should expect roughly $\binom{n}{2} \times \frac{1}{2m} \sim \frac{n^2}{4m}$ polysemantic neurons by the end.
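For instance, with $n = 256$ and $m = 1024$ (values in the range used in the experiments below), this predicts about $\frac{256^2}{4 \cdot 1024} = 16$ polysemantic neurons.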
Experiments
Training the model we described with $n := 256$ and $m$ ranging from 256 to 4096 shows that this trend of $\Theta\!\left(\frac{n^2}{m}\right)$ does hold, and the constant $\frac{1}{4}$ seems to be fairly accurate as well.
Discussion and future work
Implications for mechanistic interpretability
The fact that there are two completely different ways for polysemanticity to occur could have important consequences on how to deal with it.
To our knowledge, polysemanticity has mostly been studied in settings where the encoding space has no privileged basis: the space can be arbitrarily rotated without changing the dynamics, and in particular the corresponding layer doesn't have non-linearities or any regularization other than l2. In this setting, the features can be represented arbitrarily in the encoding space, and we usually observe interference (non-orthogonal encodings) only when there are more features than dimensions.
On the other hand, the incidental polysemanticity we have demonstrated here is inherently tied to the canonical basis, contingent on the random initialization and dynamics, and happens even when there are significantly more dimensions available than features.
This means that some tools that work against one type of polysemanticity might not work against the other. For example:
In addition, it would be interesting to find ways to distinguish incidental polysemanticity from necessary polysemanticity.
A more realistic toy model
The setup we studied is simplistic in several ways. Some of these ways are without loss of (much) generality, such as the fact that encoding and decoding matrices are tied together,[9] or the fact that the input features are basis vectors.[10]
But there are also some choices that we made for simplicity which might be more significant, and which it would be nice to investigate. In particular:
Gaps in the theory
We were able to give strong theoretical guarantees for the sparsification process by considering how the feature benefit force and regularization interact when interference is ignored, but we haven't yet been able to make confident theoretical claims about how the three forces interact together.
In particular:
Author contributions
see e.g. the "Polysemantic Neurons" section in Zoom In: An Introduction to Circuits ↩︎
When we say a neuron is correlated with a feature, what we more formally mean is that the neuron's activation is correlated with whether the feature is present in the input (where the correlation is taken over the data points). But the former is easier to say. ↩︎
Analogous phenomena are known under other names, such as "privileged basis". ↩︎
depending mostly on the specifics of the neural architecture and the data (but also on the random initializations of the weights) ↩︎
Here, we define a "collision" as the event that two features i and j collide. So for example there is a three-way collision between i, j and k, that would count as three collisions between i and j, i and k, and j and k. ↩︎
We use ∥⋅∥ to denote Euclidean length (l2 norm), and ∥⋅∥1 to denote Manhattan length (l1 norm). ↩︎
It's equivalent to making λ four times larger and making training time four times slower. ↩︎
This would not necessarily be the largest weight at initialization, since there might be significant collisions with other encodings, but the largest weight at initialization is still the most likely to win the race all things considered. ↩︎
We're referring to the fact that the encoding matrix $W^T$ is forced to be the transpose of the decoding matrix $W$. This assumption makes sense because even if they were kept independent and initialized to different values, they would naturally acquire similar values over time because of the learning dynamics. Indeed, the $i$th column of the encoding matrix and the $i$th row of the decoding matrix "reinforce each other" through the feature benefit force until they have an inner product of 1, and so as long as they start out small or if there is some weight decay, they would end up almost identical by the end of training. ↩︎
If the input features are not the canonical basis vectors but are still orthogonal (and the outputs are still basis vectors), then we could apply a fixed linear transformation to the encoding matrix and recover the same training dynamics. And in general it makes sense to consider orthogonal input features, because when the features themselves are not orthogonal (or at least approximately orthogonal), the question of what polysemanticity even is becomes more murky. ↩︎