TL;DR We experimentally test the mathematical framework for circuits in superposition by hand-coding the weights of an MLP to implement many conditional[1] rotations in superposition on two-dimensional input features. The code can be found here.
This work was supported by Coefficient Giving and Goodfire AI
In a previous post, we sketched out a construction for compressing many different circuits into a neural network such that they can be computed in superposition.[2]
By computation in superposition, we mean that a network represents features in superposition and can perform more possible computations with them than it has neurons (although not all at once), across multiple layers. Having better models of this is important for understanding whether and how networks use superposition, which in turn is important for mechanistic interpretability research more generally.
Performing computation in superposition over multiple layers introduces additional noise compared to just storing features in superposition. This restricts the amount and type of computation that can be implemented in a network of a given size, because the noise needs to be reduced or suppressed to stay smaller than the signal.
Here, we test the construction by using it to compress circuits that perform random rotations on different two-dimensional input vectors, conditional on a Boolean variable being active.
We ended up needing to adjust the construction a bit to do this, because one of the assumptions the original version made about the circuits to be compressed turned out to be less reasonable in practice than we thought it'd be, and did not actually apply to these rotation circuits.[3]
This sort of thing is why we do testing. Reality is complicated.
Note that this is certainly not the ideal way to implement rotations in superposition. We are testing a general method that should work for a wide class of circuits. If we just wanted to implement these specific circuits as compactly and noiselessly as possible, there'd be better ways to do that.[4]
It seems to work. But we can only fit so many independent circuits into the network before the errors compound and explode.
Independent circuits in superposition are costly
Many people we've talked to seem to think that neural networks can easily learn to implement an exponentially large number of independent computations by using superposition. They often seem to base this intuition on toy models that showcase how many independent features neural networks can store in superposition in their activations, but that's very different from implementing independent computations in superposition on these features. The latter is much more difficult and costly.
Our previous post concluded that in theory, we could fit close to circuits that each require ReLU neurons per layer into a neural network with ReLU neurons per layer, provided the circuits activate sparsely enough.[5] But that was an asymptotic expression for the case of network width going to infinity, based only on the leading error terms and mostly ignoring constant prefactors.
In our experiments, we had , and it was already pretty hard to fit in circuits and have the noise stay small enough to not blow up the computation in the first few layers. We think we could perhaps stabilize the computations for a much greater number of layers by implementing more active error correction,[6] but squeezing in many more independent circuits than this seems very difficult to us.
So how can language models memorize so many things?
If fitting many independent circuits into a neural network is so difficult, how can language models memorize so many different facts and even whole text passages? We don't know the answer to that yet. But we have started to train and analyze small toy models of memorization to find out, and we have some suspicions. Our work on this is still at a very early stage, so take the following speculation with an even bigger helping of salt than usual. But:
Many memorization circuits might be much shallower than the ones we investigate here. A shallower depth means fewer stages of computation for errors to accumulate.
Much of the information that models memorize is not actually independent. If a model wants to store that "The Eiffel Tower is in Paris" and that "The capital of France is Paris", it doesn't actually need two independent, non-overlapping lookup circuits to do this. Since both look-ups return the same answer, "Paris", the computation need not distinguish their outputs. Instead, all look-ups that return "Paris" could effectively share the same set of neurons, meaning they'd effectively not be independent look-ups at all, but rather parts of a single general "Sequences that end in Paris" lookup.
This picture of computation in neural networks suggests that there might not be that many more independent circuits than neurons in a large language model like GPT-2. Instead, the models might be exploiting structure in the data to implement what might superficially look like very many different capabilities using a comparatively much smaller set of fairly general circuits.
Lucius: This roughly tracks with some very early results I've been seeing when decomposing the parameters of models like GPT-2 Small into independent components. It currently looks to me like it's actually rare for these models to even have more rank circuit pieces than neurons in a layer, never mind neuron count squared. I'd guess that much wider models have more circuits relative to their neuron count, but at this point I kind of doubt that any real model today gets even remotely close to the theoretical limit of circuit per weight. Which might be great, because it'd mean there aren't that many circuits we need to interpret to understand the models.
The general setup for our circuits-in-superposition framework goes like this:
In this case, the small circuits we want to compress into one network each perform a particular two-dimensional rotation on different two-dimensional vectors , conditional on a Boolean "On-indicator" variable taking the value 1. The circuits have layers, with each layer performing another conditional rotation on the previous layer's output. We randomly selected the rotation angle for each circuit. The hidden dimension of each individual circuit is .
Conditional rotation
In each layer , the circuits are supposed to rotate a two-dimensional vector with norm by some angle , conditional on the Boolean On-indicator taking value 1 rather than 0.
We can implement this with an expression using ReLUs, like this:
If the On-indicator equals 1, this expression simplifies to
which is a two-dimensional rotation by angle , as desired. If the On-indicator equals 0, we get the zero vector.
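To make this concrete, here's a small numpy sketch of one way to gate a rotation behind a Boolean using four ReLUs. The thresholds and the bound `r` on the rotated components are our own illustrative choices, not necessarily the exact parametrization we use in the construction.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def conditional_rotation(v, b, theta, r=1.0):
    """One possible 4-ReLU implementation of `b * R(theta) @ v`.

    Assumes b is (approximately) 0 or 1 and that the components of the
    rotated vector are bounded by r. A hedged sketch of the kind of
    expression described above, not necessarily the exact one we used.
    """
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]])
    u = R @ v  # the plain (linear) rotation
    # Gate each component with two ReLUs: if b = 1 the biases cancel and we
    # recover u; if b = 0 both preactivations are non-positive and we get 0.
    return relu(u + r * b - r) - relu(-u + r * b - r)

v = np.array([0.6, -0.3])
print(conditional_rotation(v, b=1.0, theta=0.7))  # approx R(0.7) @ v
print(conditional_rotation(v, b=0.0, theta=0.7))  # [0., 0.]
```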
On-indicator
At each layer , the circuits are also supposed to apply a step function to the On-indicator , and pass it on to the next layer:
The step function is there to make the On-indicators robust to small noise. Noise robustness is a requirement for circuits to be embedded in superposition. For more details on this and other properties that circuits need to have for computation in superposition to work, see Section 5.2.
If the circuit is active, the On-indicator will be initialized as 1, and remain 1. If the circuit is inactive, the On-indicator will be initialized as 0, and stay 0.
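As an illustration, here's one way to build such a noise-robust step out of two ReLUs. The specific thresholds are ours and may differ from the ones in our actual weights; the point is just that inputs near 0 map exactly to 0 and inputs near 1 map exactly to 1.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def step_two_relu(b):
    """A two-ReLU step function for the On-indicator (illustrative thresholds).

    Inputs within 0.25 of 0 map exactly to 0 and inputs within 0.25 of 1 map
    exactly to 1, so small noise on the On-indicator is wiped out at every
    layer instead of being amplified.
    """
    return relu(2.0 * b - 0.5) - relu(2.0 * b - 1.5)

print(step_two_relu(np.array([0.0, 0.2, 0.9, 1.0, 1.2])))
# -> [0. 0. 1. 1. 1.]
```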
For more detail on the small circuits, see Appendix section 5.1.
Why conditional rotations in particular?
No particular reason. We just wanted some simple circuits. Originally, we tried to just do plain two-dimensional rotations. Then we remembered that plain rotations aren't noise-robust in the way our computation-in-superposition framework requires, because they're analytic at every point. We need circuits that return an output of exactly zero when passed an input vector of sufficiently small size. Rotations don't do that. Making the rotations conditional on some Boolean logic fixes this, and for simplicity we just went with the easiest Boolean logic possible.
Why the step function?
Although we only need one float to encode , we need two ReLUs to make the circuit robust to noise over multiple layers. If we only used a single ReLU, e.g., , then a small error on would be doubled in layer , then doubled again in layer and so on, growing exponentially with circuit depth.
If dedicating two neurons just to copy over the embedding of a single Boolean seems a little silly to you, you're not wrong. If our main goal was just to implement lots of conditional rotation circuits in superposition using as few neurons as possible, there'd be better ways to do it. For example, we could use some fraction of the neurons in the large network to embed the On-indicators, and then just copy this information from layer to layer without ever recalculating these values. Such an implementation would do better at this particular task.
But our goal is not to solve this task in particular, but to test a general method for embedding many different kinds of circuits.
If you like, you can also see this part of the circuits as a stand-in for some more complicated logic that calculates which circuits are on from layer to layer, i.e., some function that is more complicated than "same as last layer" and involves cross-circuit computations.
Here, we'll present the results of embedding conditional rotation circuits into a ReLU MLP of width for various choices of and , using the embedding method described in Section 5.4.
As we would expect, the read-offs diverge more with every subsequent layer as error accumulates, since we didn't implement any active error correction.[6] Most rotation circuits seem to work fine for the first two to three rotations, with the read-offs predicting the ideal values much better than chance. Estimates are much worse in the later layers. The errors also get larger when the number of active circuits is increased.
We quantify the error on the rotated vectors as the norm of the difference between the ideal result and the estimate :
On a population level, we can look at both the average error of circuits , and the worst-case error . The former tells us how noisy the computations are most of the time. The latter tells us whether any one circuit in the ensemble of circuits was implemented so noisily that for at least one input it basically doesn't work at all.
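In code, these two summary statistics are just the mean of the squared per-circuit error norms and their maximum. A minimal sketch, assuming the ideal and estimated rotating vectors are stored as arrays of shape `(num_circuits, 2)` (names and shapes are ours):

```python
import numpy as np

def error_summary(v_ideal, v_est):
    """Mean-squared and worst-case readout errors over a batch of circuits."""
    err = np.linalg.norm(v_est - v_ideal, axis=-1)  # per-circuit error norm
    return {
        "mean_squared": np.mean(err ** 2),  # typical noise level
        "worst_case": np.max(err),          # the single worst circuit
    }
```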
Figure 5: Errors on the rotating vectors , for an MLP of width with circuits each of width . Every circuit neuron is spread across MLP neurons. Upper row: mean squared errors, i.e., the average of over inactive or active circuits. Lower row: worst-case absolute errors, i.e., the maximum of over inactive or active circuits.[7]
The left plots show errors on inactive circuits, i.e., circuits that received an input of . Most circuits will be inactive on any given forward pass. The right plots show errors on active circuits, i.e. circuits that received an input of and some random vector with . As a baseline to put these numbers in perspective, once , the circuits are doing no better than guessing the mean. So, the average circuit beats this trivial baseline for the first three of the five rotations, or the first four if we restrict ourselves to inputs with only active circuits. As one might expect, the worst-case errors are worse. With inputs, all [8] circuits beat the baseline for the first two rotations. With inputs, only the first rotation is implemented with small enough noise by all circuits to beat the baseline, and at least one circuit doesn't work right for an input with for even a single rotation.
We can also see that the mean squared errors grow superlinearly the deeper we go into the network. This does not match our error calculations, which would predict that those errors should grow linearly. We think this is because as the errors grow they start to overwhelm the error suppression mechanism by activating additional ReLUs, and thus letting through additional noise. The calculations assume that the noise always stays small enough that this doesn't happen. More on that in the next subsection, where we compare our error calculations to the empirical results.
We can also look at the error on the On-indicator variables :
Here, an error means that the error on the On-indicator is large enough to pass through our step function and propagate to the next layer. At that point, errors will start to accumulate over layers at a faster-than-linear rate.
Figure 6: Errors on the Boolean On-indicators , for an MLP of width with circuits each of width . Every circuit neuron is spread across MLP neurons. Upper row: mean squared errors, i.e., the average of over inactive or active circuits. Lower row: worst-case absolute error, i.e., the maximum of over inactive or active circuits.[7]
The absolute value of the error for the On-indicator caps out at 1. This is because the estimate is the averaged output of a set of step functions, each of which is constrained to output a value in the range [0, 1]. Their average is thus also constrained to the range [0, 1]. Therefore, regardless of whether the true value of the On-indicator is 0 or 1, the estimate has to be in the range [0, 1], so the error can never exceed 1 in magnitude.
Key takeaway: Most predictions are good when the errors are small in magnitude, but once the errors get too big, they start growing superlinearly, and our predictions begin to systematically undershoot more and more. We think this is because our error propagation calculations assume that the errors stay small enough to never switch on neurons that are not in use by active circuits. But for many settings, the errors get large enough that some of these neurons do switch on, opening more channels for the noise to propagate and compound even further.
Figure 7: Fraction of active neurons, i.e., what fraction of the large network neurons are non-zero, for a network of width with circuits of width . Every circuit neuron is spread across MLP neurons. The dotted lines are what we'd theoretically expect if the errors stayed small enough not to overwhelm the noise suppression threshold for the inactive circuits. If this assumption is broken, the network starts to develop a "seizure"[9]: the noise starts to grow much faster as errors switch on more neurons causing more errors which switch on more neurons.[7]
The key variables that influence the error here are
Predictions are good for layer 1, . For larger , the errors get larger and the predictions get worse, with empirical errors often systematically larger than our theoretical predictions. For layer and onward, our theoretical predictions likewise undershoot more and more as the overall error gets larger. At layer and onward, the superlinear growth from compounding noise becomes very visible for large .
As our error calculations would predict, rotating vectors for inactive circuits have smaller errors than rotating vectors for active circuits, so long as the overall noise level stays small enough for the calculations to be applicable.[10]
In the case of On-indicators for active circuits, our error calculations predict that they just stay zero. This is indeed empirically the case for the early layers. But as we move deeper into the network and or get larger, a critical point seems to be reached. Past this point, noise on other parts of the network seems to become large enough to cause cascading failures, likely by switching the regime of ReLUs. Then, the active On-indicators become noisy as well.
The noise on the inactive On-indicator circuits follows the same pattern: Predictions are mostly good until the noise seems to grow large enough to break the assumptions of the calculation as the network starts to develop a "seizure". Here, this point seems to be reached at particularly early layers though. We think this is because seizures start with the On-indicators for inactive circuits and spread from there to other quantities.
Each individual circuit can be written as
where the superscript indicates the layer, ranging from to , and the subscript indexes the different circuits, ranging from to .
The circuits are supposed to compute/propagate the On-indicator, which controls whether the circuit should be active, and perform rotations on the rotating vector, .
The weight matrices for each circuit are parametrized as
and the biases as[11]
So, the circuit parameters here don't actually vary with the layer index at all.[12]
The initial input vector for each circuit takes the form
The On-indicator for circuit at layer can be read out from the circuit's hidden activations as
The rotating vector can be read out from the circuit's hidden activations as
If the circuit is active, the rotating vector will be initialized as some vector of length , and the On-indicator will be set to . If the circuit is inactive, the rotating vector and On-indicator will be initialized to , and stay .
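Putting the pieces together, one layer of a single small circuit can be sketched as below. The weights and thresholds are an illustrative parametrization of ours rather than the exact matrices of Equation (5.1), but the behaviour is the intended one: an active circuit keeps rotating its vector, and an inactive circuit stays pinned at exactly zero.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def circuit_layer(h, theta, r=1.0):
    """One layer of a single conditional-rotation circuit (illustrative).

    `h` packs the circuit's state as [on_indicator, v_x, v_y].
    """
    b, v = h[0], h[1:]
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]])
    u = R @ v
    # Two ReLUs re-threshold the On-indicator, removing small noise.
    b_next = relu(2.0 * b - 0.5) - relu(2.0 * b - 1.5)
    # Four ReLUs gate the rotated vector on the On-indicator.
    v_next = relu(u + r * b - r) - relu(-u + r * b - r)
    return np.concatenate([[b_next], v_next])

h_active, h_inactive = np.array([1.0, 0.8, 0.0]), np.zeros(3)
for _ in range(3):
    h_active = circuit_layer(h_active, theta=0.5)
    h_inactive = circuit_layer(h_inactive, theta=0.5)
print(h_active)    # On-indicator still 1, vector rotated by 3 * 0.5 radians
print(h_inactive)  # still exactly zero
```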
A circuit's rotating vector is robust to some noise if the circuit is inactive. If a circuit is active, errors on its rotating vector will neither be removed nor amplified.
In any given forward pass, most circuits will be inactive.
From the previous post:
Assumptions
For this construction to work as intended, we need to assume that:
- Only circuits can be active on any given forward pass.
- Small circuits are robust to noise when inactive. I.e. a small deviation to the activation value of an inactive circuit applied in layer will not change the activation value of that circuit in layer .
- If a circuit is inactive, all of its neurons have activation value zero. I.e. if circuit is inactive.
- The entries of the weight matrices for different circuits in the same layer are uncorrelated with each other.
Assumption 1 is just the standard sparsity condition for superposition.
Assumption 2 is necessary, but if it is not true for some of the circuits we want to implement, we can make it true by modifying them slightly, in a way that doesn't change their functionality. How this works will not be covered in this post though.[13]
Assumptions 3 and 4 are not actually necessary for something similar to this construction to work, but without them the construction becomes more complicated. The details of this are also beyond the scope of this post.
Assumption 1 holds for our conditional rotation circuits. Assumption 2 also holds, because the rotation is only applied if the Boolean On-indicator is large enough, and the step function for also doesn't propagate it to the next layer if is too small. Assumption 3 also holds. However, assumption 4 does not hold. So we're going to have to adjust the way we embed the circuits a little.
Further, while it wasn't explicitly listed as an assumption, the previous post assumed that the circuits don't have any biases. But our circuits here do (see Equation 5.1). So we'll need to make some more adjustments to account for that difference as well.
Assumptions 1-3 have stayed the same. We think the construction could likely be modified to make assumption 3 unnecessary,[15] but it simplifies the implementation a lot. The old assumption 4 is not required for the new construction. The new embedding algorithm can work without it (see next section). We think the embedding algorithm could also be modified to make the new assumption 4 unnecessary, but this may come at some cost to the total number of circuits we can embed.[16] The new assumption 5 was actually necessary all along, even in the previous construction. Without it, even a tiny amount of noise can eventually grow to overwhelm the signal if the circuit is deep enough.[17]
Note that assumption 2 is about noise reduction on inactive circuits whereas assumption 5 is about lack of noise amplification on active circuits. Assumption 2 says we should have
for some sufficiently small . Assumption 5 says we should have
Bracket notation is a type of vector notation that's excellent for thinking about outer products. It's also the vector notation usually used in quantum mechanics, and the term "superposition" is a quantum loan-word, so deploying it here seemed kabbalistically accurate.
It works like this:
We will use bracket notation to represent -dimensional vectors, like the embedding vectors in section 5.4.1.
So far, we've used to represent -dimensional vectors, i.e., and , and two-dimensional vectors, i.e., . And we just defined for -dimensional vectors in the previous section 5.3.1.
We will also want to mix and notation to make -dimensional vectors by stacking number of -dimensional vectors.
where is a -dimensional vector and is the section of from index to index . Each of , , , is -dimensional.
Any -matrix will act on as if it were a -dimensional vector, ignoring the dimension. For example:
Any will act on as if it were a , ignoring the -dimension. For example:
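As a concrete illustration of this blockwise convention, here's how the two kinds of action look in numpy, assuming the n blocks of length d are stored contiguously (the index conventions in our notation may differ):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 3                          # number of blocks and block size (illustrative)
blocks = [rng.standard_normal(d) for _ in range(n)]
x = np.concatenate(blocks)           # the stacked n*d-dimensional vector

A = rng.standard_normal((d, d))      # a d x d matrix: acts on each block separately
B = rng.standard_normal((n, n))      # an n x n matrix: mixes blocks as whole units

# "Ignoring the n dimension": A is applied to every d-dimensional block.
assert np.allclose(np.kron(np.eye(n), A) @ x,
                   np.concatenate([A @ blk for blk in blocks]))

# "Ignoring the d dimension": B treats each block as a single entry and mixes them.
assert np.allclose(np.kron(B, np.eye(d)) @ x,
                   np.concatenate([sum(B[i, j] * blocks[j] for j in range(n))
                                   for i in range(n)]))
```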
In this section, we describe a general method for embedding any set of small circuits that satisfy assumptions 1-5 listed in Section 5.2.2.
The activations of the large network at layer are computed from the activations at layer as
To embed the small circuit weights in the large network, we first split their weight matrices into two parts: the averaged weights across all circuits , and the differing weights :
The differing weights are embedded into the weights of the large network using embedding vectors . The averaged weights are embedded using an embedding matrix :
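The split itself is just a mean over the circuit index. A minimal sketch, with the per-circuit weights of one layer stored as an array of shape `(n_circuits, d, d)` (names and shapes are ours):

```python
import numpy as np

def split_weights(W):
    """Split per-circuit weights into their average and per-circuit residuals.

    `W` has shape (n_circuits, d, d). Returns (W_bar, delta_W) such that
    W[k] == W_bar + delta_W[k] for every circuit k.
    """
    W_bar = W.mean(axis=0)   # averaged weights, shared across all circuits
    delta_W = W - W_bar      # differing weights, one residual per circuit
    return W_bar, delta_W
```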
Embedding vectors
The embedding vectors are constructed using exactly the same algorithm described in the previous post. The information about them that matters most for the rest of this post:
Embedding matrix
is constructed as
where the scalar is chosen to balance the weights of , such that
This ensures that interference terms between different circuits do not systematically aggregate along any particular direction.
Layer 0
As in the previous post, the weights at layer are treated differently. Instead of using the weight embedding described in the previous subsection, we just embed the circuit input vectors into the input vector of the network as
This allows us to set the first weight matrix to an identity,
yielding
We embed the circuit biases into the network biases by simply spreading them out over the full dimensions:
where is just the -vector with value 1 everywhere, i.e.:
If the embedding works, we can read out the approximate small circuit activations from the activations of the large network, up to some error tolerance. Using the un-embedding vector , we define the estimator for the small circuit activations for each layer as
Following equations (5.5) and (5.6), we define the estimators for the On-indicator
and the rotating vector
The readout errors are defined as
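To illustrate what reading out "up to some error tolerance" means, here is a generic toy version of the readout, with random orthonormal embedding directions standing in for the actual embedding vectors (which are built with the prime-factorization-based algorithm mentioned in the footnotes, not randomly). The un-embedding recovers an active circuit's activations up to interference from the other active circuits:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, n_circuits = 512, 6, 40    # illustrative sizes, not our actual settings

# Random orthonormal embedding directions, one (N, d) block per circuit.
V = np.stack([np.linalg.qr(rng.standard_normal((N, d)))[0]
              for _ in range(n_circuits)])

# Superpose the hidden activations of two active circuits in one big vector.
h3, h17 = rng.standard_normal(d), rng.standard_normal(d)
a = V[3] @ h3 + V[17] @ h17

# Un-embed circuit 3: correct up to cross-circuit interference noise.
h3_hat = V[3].T @ a
print(np.linalg.norm(h3_hat - h3))   # small, but nonzero: the readout error
```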
In this section we'll attempt to adapt the error calculations from the previous post to the new framework, and use the resulting expressions to try and predict the errors on our conditional rotation circuits. The main change in the construction we need to account for is the presence of the embedding matrix , which didn't exist in the construction the previous post used. We will also include some particular terms in the error calculations that the previous post ignored. These terms are sub-leading in the limit of large network width and number of circuits , but turned out to be non-negligible for some of the empirical data we tested the theory on.
The results here will only hold as long as the size of the errors stays much smaller than the size of the signal, which won't always be the case in reality. Past that point, errors may compound much more quickly over layers than the results here would predict.
We find that the terms for the error on inactive circuits carry over unchanged, but the error on active circuits is a little different.
Note that these calculations are not as careful as the ones in the previous post. In some spots, we just guess at approximations for certain terms and then check those guesses empirically.
Provided the overall error stays small enough to not change the signs of any neuron preactivations, the only source of error on inactive circuits is the embedding overlap error, which was estimated in the previous post as:
This is the error that results from signal in active circuits bleeding over into the inactive circuits they share neurons with. Since this error depends purely on how the activations of the circuits are embedded in the neurons of the network and that embedding hasn't changed from the previous construction, the formula for it remains the same.
In the previous post, we used the approximation to evaluate this error as[19]
Here, we will instead use the slightly more precise approximation
This results in a minor but noticeable improvement in the fitting of the data. Inserting the new approximation into the previous derivation yields
We'll assume for now that , so that Equation (5.20) becomes
From definitions (5.28) and (5.29), it follows that
Using the definition of the readout for circuit in layer , from Equation (5.30), the error can be written as
If we split the sum over circuits into active and inactive circuits, we can simplify this expression:
since it is simple to evaluate, and only neglect the rest.[20]
This leaves us with
As in the last post, we will assume that all ReLU neurons used by the active circuits are turned on, since doing so will simplify the calculations and not overestimate the errors in expectation.[21] This gives
Inserting approximation (6.1) for on inactive circuits into this expression yields
which evaluates to
This derivation assumed that , i.e., the circuits are uncorrelated with each other. However, we're just going to go ahead and guess that (6.16) is still a reasonable approximation in the case as well.[22] We'll see how true that is when we compare these predictions to the errors we measure in practice.
We can break this expression up into three components.
These three components have mean zero and are uncorrelated with each other, so we can calculate their mean squared errors separately.
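Spelled out (writing the three components as $\epsilon_1, \epsilon_2, \epsilon_3$, which may not match the symbols used above), the cross terms vanish in expectation:

$$\mathbb{E}\big[\lVert \epsilon_1 + \epsilon_2 + \epsilon_3 \rVert^2\big] = \sum_i \mathbb{E}\big[\lVert \epsilon_i \rVert^2\big] + 2\sum_{i<j} \mathbb{E}\big[\epsilon_i \cdot \epsilon_j\big] = \sum_i \mathbb{E}\big[\lVert \epsilon_i \rVert^2\big],$$

since each $\mathbb{E}[\epsilon_i \cdot \epsilon_j]$ with $i \neq j$ is zero for mean-zero, uncorrelated components.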
Here we'll calculate the mean square of the error contribution , from Equation (6.18). Inserting the definition of into the definition of gives:
In the third line, we used the fact that
which follows from the definition of in Equation (5.23).
To estimate the expected norm of the error, we insert approximation (6.3) again:
Here we'll calculate the mean square of the error contribution from Equation (6.19). First, we further subdivide into two parts, one for terms involving and one for terms involving .
We can straightforwardly approximate the expectation of the square of . The steps for this are similar to those we used for in the previous subsection, see Equations (6.26) and (6.27).
The expectation of squared is a bit tougher. We didn't see an easy way to formally derive an approximation for it. So instead, we're going to be lazy: we'll guess that it can be approximately described by the same formula as the one for , just with and swapped for . Then we'll check that guess empirically on some example data.
To verify this equation, we can use the fact that can be factored out from Equation (6.30). So, if the result holds for one set of circuits, it should hold for any set of circuits. We just need to make up some values for , , and to check the equation on.[23] To keep it simple, we choose the following:
We find a good fit for small values of and small values of . For larger and we see deviations of up to 50% from the hypothesis. I.e., the approximation is not perfect, but it seems ok. At larger errors, we expect all our error calculations to break down anyway, because the errors start getting large enough to switch on ReLU neurons that should be off, leading to cascading failures as the network has a "seizure".
The two terms also add up as we would expect.
Assumption 5 about our circuits was that active circuits should not amplify the size of errors relative to the size of the signal from layer to layer:
Since for our circuits, this implies that
For simplicity, we use the most pessimistic estimate:
Adding the three terms together and applying the resulting formula recursively to evaluate the propagated error from previous layers, we get
The first term has a prefactor of while the second has a prefactor of because the first term is the propagation error (from the previous post) which only appears from the second layer onward, while the second term is the embedding overlap error, which appears from the first layer onward.
The On-indicators are read out from the circuits' hidden activations as . Provided the overall error stays small enough to not overwhelm any ReLU neuron enough to flip it from off to on, the errors on and will be identical for active circuits, and thus cancel out exactly.
For inactive circuits, equations (5.4) and (6.4) yield:
For active circuits
We can evaluate the errors on active rotation circuits using Equation (6.41). We want the error on though, not on . However, provided that , we have
So, since for active circuits, we can just restrict (6.41) to the last two vector indices of to calculate the error:
where we define
Inserting the definitions of for our circuits into this formula yields
and
where
and hence
Since we didn't use , the above holds equally well for . Additionally, since directions of and are uncorrelated, we can just add them up:
Inserting everything finally yields
For inactive circuits
Inserting equation (6.4) into
yields
Meaning they are only applied if a Boolean variable is true.
It was assumption 4, that the weights of different circuits are uncorrelated with each other.
We did try that as well, and might present those results in a future post. They look overall pretty similar. The average error on the circuits just scales with a better prefactor.
basically means " up to log factors". The full requirement is that must be small. This result lines up well with expressions for the number of logic gates that neural networks can implement in superposition that were derived here [1, 2 and 3]. It also makes intuitive sense from an information theory perspective. The bits of information specifying the parameters of the circuits have to fit into the parameters of the neural network.
By "active error correction" we mean using additional layers for error clean-up. I.e. some of the large network layers are dedicated exclusively to noise reduction and do not carry out any circuit computations. One reason we don't do this is the cost of the extra layers, but another is that this only works if the circuit activations belong to a discrete set of allowed values. In our circuits this is the case for the on-indicator but not the rotated vector. So we would need to discretize the allowed angles to apply error correction to them. See also Appendix section D.4 here, for how to do error correction on Boolean variables in superposition.
Data for is generated by running the network times, once for every possible active circuit. Data for and is generated by running the network times, with randomly generated pairs and triples of active circuits. The sample sizes for were chosen to be large enough that the worst-case absolute error did not seem to change very much any more when the sample size was increased further. It was still fluctuating a little though, so maybe treat those points with some caution.
This number was chosen because it is the largest possible for and in our prime-factorization-based embedding algorithm.
Thanks to Gurkenglas for coming up with this evocative term.
If the noise level gets large enough to significantly overwhelm the inactive circuits' error suppression, then we are outside the applicability of the theory.
has no subscript because the embedding scheme (described in Section 5.4) requires all small circuits to have the same biases, in any given layer. If this is not the case the circuits need to be re-scaled so that the biases match.
In the general formalism, the small circuit weights and biases can be different for each layer . However, in the specific example we are implementing here, and do not depend on .
The "future post" mentioned in this quote from the last post is this post. The trick to modify a circuit such that assumption two holds, is to add an on-indicator. We haven't actualy described this in full generality yet, but you can hopfully see how a similar trick can be used to modify circutis other than rotations.
You might notice that for assumptions 2 and 3 to both hold simultaneously, each small network needs a negative bias.
We haven't actually tried that out in practice though.
As long as all the biases have the same sign, we can rescale the weights and biases of different circuits prior to embedding so that assumption 4 is satisfied. If some biases don't have the same sign, we can increase , and partition the circuits into different parts of such that the biases have the same sign in each part. This potentially comes at the cost of being able to fit in fewer circuits, because it makes larger. will never have to be larger than the initial though.
Assumption 5 doesn't need to strictly hold for circuits of finite depth. It also doesn't strictly need to hold at any layer individually. The circuit could also have active noise 'correction' every few layers such that assumption 5 holds on average.
or 'dot product'
This formula assumes that circuit activation vectors are all of similar magnitude.
We can't neglect this term for , because will get canceled by the , so it isn't sub-leading.
This assumption holds for the conditional rotation circuits.
Intuitively it doesn't seem like should ultimately change much in the derivation, but the calculations to show this were just getting kind of cumbersome. So we figured we might as well just guess and check.
What values we pick shouldn't matter so long as satisfies .