Tl;dr: We generalize the mathematical framework for computation in superposition from compressing many boolean logic gates into a neural network, to compressing many small neural networks into a larger neural network. The number of small networks we can fit into the large network depends on the small networks' total parameter count, not their neuron count.

Work done at Apollo Research. The bottom half of this post is just maths that you do not need to read to get the gist.

Introduction

Background

Anthropic's toy model of superposition shows how to compress many sparsely activating variables into a low dimensional vector space and then read them out again. But it doesn't show how to carry out computations on the compressed variables in their native format. The mathematical framework for computation in superposition makes a first stab at closing that gap. It shows how to compute boolean circuits in superposition.

What we do

We show how a network can perform any computations whatsoever in superposition. Specifically, we show how $T$ small residual neural networks, each with  parameters, that perform arbitrary tasks can be compressed into a single larger residual network that performs all $T$ tasks, provided that the large network is only evaluated on sparse combinations of tasks: any particular forward pass only asks for $k \ll T$ tasks to be carried out. In the limit of  going to infinity, this larger network will require  parameters[1].

Crucially, this means that the total number of small networks the larger network can implement scales approximately linearly with the number of weights in the network, not the number of neurons, as would be the case without computation in superposition. For example, if each small network uses $m$ neurons per MLP layer and  dimensions in the residual stream, a large network with $M$ neurons per MLP connected to a -dimensional residual stream could implement about  small networks, not just .

Qualitatively speaking, our construction works using the same basic trick as the one for boolean circuits in superposition. We just generalize it from boolean AND gates to any operations the neural network could implement.
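To make the weights-versus-neurons scaling concrete, here is a small arithmetic sketch. The symbol names `m`, `d`, `M`, `D`, the two counting formulas, and the example sizes are all assumptions chosen for illustration: capacity is taken to be proportional to the ratio of MLP weight counts in one case and to the ratio of neuron counts in the other, with log factors and constant prefactors ignored.

```python
def capacity_by_neurons(M: int, m: int) -> int:
    """Number of small networks if each needs m dedicated neurons per layer."""
    return M // m


def capacity_by_weights(M: int, D: int, m: int, d: int) -> int:
    """Number of small networks if capacity scales with MLP weight count.

    Assumed scaling (M * D) / (m * d); log factors and constants are dropped.
    """
    return (M * D) // (m * d)


# Hypothetical sizes, chosen only for illustration.
m, d = 16, 8        # small network: neurons per layer, residual stream width
M, D = 4096, 1024   # large network: neurons per layer, residual stream width

print(capacity_by_neurons(M, m))        # 256
print(capacity_by_weights(M, D, m, d))  # 32768
```

With these made-up sizes, counting by weights allows 128 times as many small networks as counting by neurons, which is the ratio `D / d`.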

Generalising to circuits

While our derivation here assumes $T$ networks carrying out unrelated tasks in parallel, nothing in the construction stops us from instead chaining the small networks in series, with later small networks taking the outputs of earlier small networks as their inputs. Therefore, the construction in this post can be thought of as a framework for representing arbitrary circuits in superposition.

Some very tentative implications, maybe?

Real neural networks probably don’t work exactly the way this construction does. It's made to be easy for us to prove things about it, not to be efficient in real life. The finite width of real networks might make other constructions better. We're also not dealing with potential correlations between the activations of different circuits, which might change the optimal setup even more. And ultimately, we don't actually know whether the structure of real-world datasets is sparse in the right way to incentivise learning sparsely activating circuits.

Nevertheless, there may be some useful takeaways about real networks, so long as we don't forget that they come with a heavy pinch of salt:

  • There is no superposition in parameter space: In this construction, we cannot compress more small networks into the large network than the large network has parameters. So, while a network can have more features than the dimension of its activation spaces, it can't implement more distinct operations[2] than the dimension of its parameter space[3].
  • Circuits don't have to follow the layer structure: This construction lines up the layers of the small networks with the layers of the large network, but that's just for our convenience. So long as the large network has more layers than the small networks, we can implement things all over the place. A single neuron in a small network could correspond to neurons across a range of layers in the big network. Thus, if somebody is looking at the residual stream activations in a layer of the big network, they might see a lot of half-computed nonsense that's hard to make sense of. You could call this cross-layer superposition.
  • Computation in superposition doesn't need one-dimensional 'features': Our construction doesn't assume that the $T$ small networks internally work using one-dimensional variables represented as directions in activation space. Circuits may be embedded in the larger network as sparsely activating subspaces in the neurons and the residual stream, but within those spaces, their own representations don't have to be sparse or linear.
  • The total parameter vector could be decomposable into a sum of the parameter vectors dedicated to each small network: At least in this construction, the parameter vector of the large network  is a sum of  vectors  parametrizing the individual small networks: . If real networks share this property, then with the right optimization procedure, it might be possible to recover the individual small networks  from  by looking at the network's loss landscape. Apollo Research is trying out a way to do this at the moment.

Future work

  • Other architectures: We think this construction can be straightforwardly extended to transformers and CNNs, without significantly changing any takeaways. We are investigating the error bounds for attention blocks at the moment.
  • Tracr extension: Theoretically, this framework could allow people to create superposed circuits by hand. We'd be excited about someone writing a more sophisticated version of Tracr based on these constructions, which could be used for building a more realistic interpretability benchmark akin to InterpBench. Note that the error bounds in this post are all formulated for the large network width limit, so there is still some work to do to make this practical.
  • Training dynamics: This post makes claims about the expressivity of neural networks, but in real life, the structures learned by neural networks depend greatly on the inductive biases of their training. We would like to build on this framework to explore whether training actually incentivises the learning of sparse circuits. We have some ideas on this front, based on attempting to unify SLT ideas with the idea of the low-hanging fruit prior.

The Construction

Suppose we have $T$ small neural networks. For simplicity we will assume that each small network consists of  layers, with $m$ neurons in each layer, a fixed elementwise nonlinearity, and a fixed residual stream width . We require that these small networks are at least somewhat robust to noise: there is some magnitude of random noise  that we can apply to all the preactivations of any of the small networks' neurons without changing downstream layer activations by more than some small .[4]

Then we can create a large network that is also  layers deep, with residual stream width , $M$ neurons in each layer, and the same activation functions, which can leverage superposition to compute the outputs of all $T$ neural networks in parallel. This works even for  and , provided that only $k \ll T$ small neural networks are being passed a non-zero input vector on most forward passes. This large network will require on the order of  parameters in total[5].

The core idea behind this construction is similar to that for computing many ANDs of binary inputs in superposition. There may be many other constructions that would also work, but we think that in the limit of very wide neural networks, all constructions would perform more or less the same, and yield the same fundamental limits for how many small networks can be superposed into a network with  parameters[6]. As with all constructions involving superposition, the key to the construction working out is in managing the size of the interference between separate small networks, and making sure that it does not become larger than the size of the signal — the correct output of each small network. In this construction, there are two sources of interference:

Read-in interference

Our $T$ small networks have a combined  residual stream dimensions. So, activation vectors of different small networks in the large residual stream cannot be completely orthogonal. This means that when a particular small network is passed an input of $0$ but other small networks are passed nonzero inputs, the value of the inputs that are read in by the weights that implement the first small network won't be exactly zero. In our construction, this read-in interference is what ends up dominating the constraints on how many small networks we can compute in a single large network.

At a high level, we manage read-in interference by making the residual stream width  larger so the overlap between small networks is smaller, and making the MLP width  larger so the read-in interference can be spread across more neurons.
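A quick numerical sketch of why a wider residual stream reduces the overlap between small networks. It assumes each small network's residual stream is embedded through a random matrix with orthonormal columns, stored here as `D x d`. The names `d`, `D`, `random_embedding`, and `rms_overlap` are illustrative and not taken from the construction itself.

```python
import numpy as np


def random_embedding(d, D, rng):
    """Return a D x d matrix with orthonormal columns spanning a random subspace."""
    q, _ = np.linalg.qr(rng.standard_normal((D, d)))
    return q


def rms_overlap(E_s, E_t):
    """Root-mean-square entry of E_s^T E_t, the cross-talk between two embeddings."""
    return float(np.sqrt(np.mean((E_s.T @ E_t) ** 2)))


rng = np.random.default_rng(0)
d = 4  # hypothetical small-network residual stream width
for D in (64, 256, 1024):
    overlap = rms_overlap(random_embedding(d, D, rng), random_embedding(d, D, rng))
    print(f"D={D:5d}  rms overlap={overlap:.4f}  1/sqrt(D)={1 / np.sqrt(D):.4f}")
```

Under these assumptions the typical overlap entry shrinks roughly like `1/sqrt(D)`, since the dot product of two independent random unit vectors in `D` dimensions has variance `1/D`.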


Read-out interference

Our $T$ small networks have a combined $Tm$ neurons per layer. Naively, we could randomly assign every neuron in every small network to one neuron in the big network. But then, if two small networks that happened to share a neuron activated at the same time, that neuron would get conflicting inputs and misfire. So we could only carry out one of the $T$ tasks at a time.

To make the small networks robust to these misfires, we introduce redundancy into the big network, representing each neuron in the small network with many neurons in the big network. This means that each neuron in the big network is assigned to even more small networks than if there were no redundancy, but this cost is worth it: we can now recover the value of any activation of any small network by averaging over the values of every neuron in the large network that represents it. If few enough small networks are active at once, then almost all neurons in the large network assigned to any particular small network's neuron will take on the correct value for that neuron, almost all of the time, and in the limit of , the difference between the value of a small network's neuron and the average of all the neurons in the large network that compute it will go to zero.

Maths

If you don't care about technical details, you can safely skip this section.

Let the input to the -th small network be denoted by  and the activation vector of small network  in layer  for input  by  or simply 
Similarly, denote the activation vector for the large network in layer  by 
We also define a set of random matrices with orthonormal rows :

with  satisfying . Since the matrices are projection matrices to random -dimensional subspaces of , their columns satisfy . These matrices define projections from the residual streams of each small network into a random subspace of the larger residual stream. What we want to prove is that if the number of  that are nonzero is , then for all , there exist terms  satisfying , such that:

.

We'll (sort-of) prove this using induction. 

Embedding Matrix

The base case for the induction is just the embedding in layer . The input to the large network is the concatenated vector . The embedding matrix[7]  is constructed by directly projecting each  into the residual stream using , which we can do by stacking the projection matrices next to each other:

.

Then, the residual stream activation vector at layer zero 

 is equal to  as required.
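The base case can be illustrated numerically. This is a minimal sketch, assuming each projection is stored as a `D x d` matrix with orthonormal columns and the blocks are stacked horizontally. The sizes, the names `T`, `d`, `D`, `E_blocks`, and the choice of which networks are active are all hypothetical.

```python
import numpy as np


def random_embedding(d, D, rng):
    """Return a D x d matrix with orthonormal columns spanning a random subspace."""
    q, _ = np.linalg.qr(rng.standard_normal((D, d)))
    return q


rng = np.random.default_rng(0)
T, d, D = 200, 4, 512          # hypothetical sizes with T * d > D
E_blocks = [random_embedding(d, D, rng) for _ in range(T)]
E = np.hstack(E_blocks)        # embedding matrix, shape (D, T * d)

# Sparse input: only two small networks receive a non-zero (unit-norm) input.
x = np.zeros((T, d))
for t in (3, 17):
    v = rng.standard_normal(d)
    x[t] = v / np.linalg.norm(v)

a0 = E @ x.reshape(-1)         # layer-0 residual stream of the large network

# Reading back with the transpose of a block recovers the input up to interference.
err_active = np.linalg.norm(E_blocks[3].T @ a0 - x[3])
norm_inactive = np.linalg.norm(E_blocks[5].T @ a0)
print(err_active, norm_inactive)
```

Both printed numbers are interference terms: the first is the error when reading out an active network, the second is what an inactive network reads in. They are small compared to the unit-norm inputs when few networks are active, even though `T * d` exceeds `D`.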

Other layers

We'd now like to assume that  is satisfied in layer , and demonstrate that it is satisfied in layer . To do so, we need to work out what the matrices  should be.

Reading from the residual stream

To start, we need a way to compute the outputs of  all at once with the larger matrix . If we had  we could do this by making  block diagonal, but we are looking for a construction with . To make progress, we start by noting that 

   ,

where we have used that . We want the read-in interference 

introduced to network  in layer  to be sufficiently small, staying below the  noise level we assume the subnetworks to be robust to. The justification for  being small will be based on the fact that for  is approximately a matrix of Gaussians with variance . Details are in Section Read-in interference.

Writing to the neurons

We can't just connect the outputs of this multiplication to neurons in layer  of the large network even if the interference is small. This is because  so we'd have to share neurons between many circuits and we wouldn't be able to tell if a neuron  fires due to circuit  activating, or some other circuit that connects to that neuron activating instead. Instead, we need to introduce some redundancy to the representations of the activations of each small network[8]. We do this by multiplying by a distributing matrix . This matrix is defined as follows:

  1. Start with the first  rows (each row is a vector in ), which connect to small network . These are the rows of  which determine which neurons are involved in computing the th layer of the first small network.
  2. Then, pick a random partition of the neurons of the th layer of the big network into 'neuron sets' of size $m$. There are $M/m$ such sets.
  3. Let . For each neuron set, consider the set of submatrices of  which consist of only the first  rows, and only the columns in that set, so each submatrix has shape . For each submatrix, with probability  set it equal to a random permutation of the identity matrix, and with probability , set it equal to the zero matrix.
  4. Repeat for each set of  rows of , corresponding to each small network. Each time, pick a different random partition of the neurons into neuron sets.
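The four steps above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the authors' code: the function name `distributing_matrix`, the argument names, and the example sizes are hypothetical, and the connection probability `p` is passed in as a free parameter rather than set to the value the construction prescribes.

```python
import numpy as np


def distributing_matrix(T, m, M, p, rng):
    """Build a (T * m) x M 0/1 matrix; rows t*m..(t+1)*m-1 belong to small network t."""
    assert M % m == 0, "M must be divisible by m so the neurons split into equal sets"
    V = np.zeros((T * m, M))
    for t in range(T):
        # A fresh random partition of the M large-network neurons into sets of size m.
        neuron_sets = rng.permutation(M).reshape(M // m, m)
        for neurons in neuron_sets:
            if rng.random() < p:
                # Connected set: the m x m submatrix is a random permutation matrix.
                V[t * m + np.arange(m), rng.permutation(neurons)] = 1.0
            # Otherwise the submatrix stays zero (unconnected set).
    return V


rng = np.random.default_rng(0)
T, m, M, p = 20, 4, 400, 0.1   # hypothetical sizes
V = distributing_matrix(T, m, M, p, rng)
print(V.shape)                 # (80, 400)
print(V.sum(axis=1).mean())    # average number of large neurons per small-network neuron
```

By construction, each large-network neuron connects to at most one neuron of any given small network, and every neuron of a given small network is connected to the same number of large-network neurons (one per connected set).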

For the -th small network, the neurons that are in sets which are assigned a permutation matrix are called connected to that small network, and the neurons that are in sets assigned the zero matrix are called unconnected. We denote the set of all sets of neurons in the large network that are connected to the th small network in layer  by  (a subset of the powerset of ), and the set of all neurons in the large network that are connected to the th neuron of the th small network in layer  by . Every small network will on average connect its weights  to  sets of  neurons in the big network. So, we set

 .

Writing back to the residual stream

To write back to the residual stream from the neurons, first we can recover the value of the activations of each small network by averaging all the neurons in the large network that are connected to that small network neuron. We do this by multiplying the activations of the big network with :

  .

Then we can apply each  to recover , and then we can embed these activations back into the residual stream using :

 

If  is small enough (which requires  to be small as well), then we are done, and  will have the correct form.
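The averaging read-out described in this subsection can be sketched as follows. This is a simplified illustration: it leaves out the residual-stream projections and the small networks' weight matrices, uses ReLU as a stand-in nonlinearity, and all names and sizes (`distributing_matrix`, `read_out`, `T`, `m`, `M`, `p`) are hypothetical.

```python
import numpy as np


def distributing_matrix(T, m, M, p, rng):
    """Rows t*m..(t+1)*m-1 connect small network t to neurons of the large network."""
    V = np.zeros((T * m, M))
    for t in range(T):
        for neurons in rng.permutation(M).reshape(M // m, m):
            if rng.random() < p:
                V[t * m + np.arange(m), rng.permutation(neurons)] = 1.0
    return V


def read_out(V, pre_small):
    """Spread small-network preactivations over the large MLP, apply ReLU, then average."""
    pre_large = V.T @ pre_small.reshape(-1)   # each large neuron sums what connects to it
    act_large = np.maximum(pre_large, 0.0)    # ReLU stands in for the nonlinearity
    counts = np.maximum(V.sum(axis=1), 1.0)   # large neurons per small-network neuron
    return (V @ act_large / counts).reshape(pre_small.shape)


rng = np.random.default_rng(0)
T, m, M, p = 40, 4, 8000, 0.02   # hypothetical sizes
V = distributing_matrix(T, m, M, p, rng)


def max_error(active):
    """Largest read-out error over the neurons of the active small networks."""
    pre_small = np.zeros((T, m))
    pre_small[active] = rng.standard_normal((len(active), m))
    target = np.maximum(pre_small, 0.0)
    return np.abs(read_out(V, pre_small) - target)[active].max()


err_one = max_error([0])           # a single active network: no shared neurons misfire
err_three = max_error([0, 1, 2])   # several active networks: a few neurons misfire
print(err_one, err_three)
```

With one active network the average reproduces the small network's activations up to floating-point rounding. With several active networks, some large-network neurons receive input from more than one of them, and averaging over the redundant copies keeps the resulting error small.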

Error analysis

Let  be upper bounds on the L2 norm of the small networks' activations in the residual stream, and operator norm of their MLP input matrices, respectively:

In the analysis below, we find that the L2 size of the total interference added to a subnet in an MLP layer will be 

.

For this noise to stay below the  we assumed the small networks to be robust to at every layer, our large network needs at least

 

parameters in total. Any fewer than that, and the interference will begin to overwhelm the signal. Assuming the noise  isn't larger than the maximum size of the small network's neuron activations, we'll have . So we need  parameters in total.

Read-in interference

In this construction, we find that our total error term is dominated by read-in interference.

The noise from an activation vector  of a circuit  being multiplied by weight matrix  of a different circuit  will be 

The entries of the matrix   will have approximate size . Since the  entries of a row of  are randomly distributed, the entries of  will then have average size . So, the noise  from activation  of small network  being partially projected into preactivations of neurons in small network  will be on the order of 

.

On average, each neuron has  weight rows of small networks connecting to it. Using , if there are  circuits active at a given time, the total read-in interference  on the preactivation on any one neuron in any small network  will be bounded by 

because the noise sources are independent. This noise dominates the total error term.
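The claim that independent noise sources add in quadrature can be checked with a small simulation. This sketch is a simplification under stated assumptions: it only models the random residual-stream embeddings (as `D x d` matrices with orthonormal columns) and ignores the small networks' weight matrices and the spreading across MLP neurons. The names and sizes are hypothetical.

```python
import numpy as np


def random_embedding(d, D, rng):
    """Return a D x d matrix with orthonormal columns spanning a random subspace."""
    q, _ = np.linalg.qr(rng.standard_normal((D, d)))
    return q


def read_in_noise(k, d, D, rng, trials=200):
    """RMS norm of what an inactive network reads in when k other networks are active."""
    norms_sq = []
    for _ in range(trials):
        E_reader = random_embedding(d, D, rng)
        stream = np.zeros(D)
        for _ in range(k):
            x = rng.standard_normal(d)
            # Each active network writes a unit-norm vector into its own random subspace.
            stream += random_embedding(d, D, rng) @ (x / np.linalg.norm(x))
        norms_sq.append(np.sum((E_reader.T @ stream) ** 2))
    return float(np.sqrt(np.mean(norms_sq)))


rng = np.random.default_rng(0)
d, D = 4, 256   # hypothetical widths
noise = {k: read_in_noise(k, d, D, rng) for k in (1, 4, 16)}
for k, value in noise.items():
    print(f"k={k:2d}  noise={value:.3f}  sqrt(k*d/D)={np.sqrt(k * d / D):.3f}")
```

In this simplified setting the interference grows like the square root of the number of active networks rather than linearly, which is what independence of the noise sources buys.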


Read-out interference

In our construction, we find that read-out interference  from multiple circuits using the same neuron is subdominant and vanishes in the limit of large networks. For the read-out of a small network from the MLP of the large network to become inaccurate, some fraction of the  neurons playing the role of one neuron in the original small network have to all 'misfire', activating when they shouldn't, or with incorrect magnitude even when they do fire. Since we assumed that our activation functions are Lipschitz continuous, we can bound any 'misfire' to be smaller than some bound .

We'll assume that there is some critical fraction  which is the maximum number of misfires we can tolerate, which is dependent on the error tolerance of our small networks:  misfires would give us an error  on the read-out of neuron  in small network , which we require to be smaller than the maximum error tolerance of the small networks .

One neuron: Consider a specific neuron  in small network . This neuron is assigned a set  of size approximately  of neurons to compute it in the large network.

k=1: Suppose that only small network  is active on the current forward pass. The chance of any circuit  connecting to a given neuron is . So, if , the probability that there are  misfirings in the set  will follow a binomial distribution: 

.

The last factor is approximately equal to  and can be ignored. 
k>1: Suppose there are  small networks active at once. Each neuron in  can be used in multiple active networks. We can imagine a matrix with  rows and  columns, with a  in the  position if the th neuron in  is connected to the th active small network, and a zero otherwise. The entries of this matrix are i.i.d Bernoulli random variables with probability , and the number of nonzero entries in this matrix is the total number of misfirings in . Again assuming , the probability  has  misfirings will be: 

Using Stirling's formula[9], we can write this as: 

.

We can approximate  as a decaying geometric series in , with initial value  and ratio 

Therefore, we have

 .
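The misfire counts above can also be evaluated exactly for small cases. The sketch below treats the number of misfirings as the number of nonzero entries in an `n x k` matrix of i.i.d. Bernoulli(`p`) variables and sums the binomial tail in log space. The names `n`, `k`, `p`, `c_min` and the example values are assumptions for illustration, with `n` standing for the number of large-network neurons assigned to one small-network neuron.

```python
from math import exp, lgamma, log


def log_binomial_pmf(trials, c, p):
    """Log of the Binomial(trials, p) probability of exactly c successes (0 < p < 1)."""
    log_choose = lgamma(trials + 1) - lgamma(c + 1) - lgamma(trials - c + 1)
    return log_choose + c * log(p) + (trials - c) * log(1.0 - p)


def misfire_tail(n, k, p, c_min):
    """P[at least c_min nonzero entries in an n x k matrix of i.i.d. Bernoulli(p)]."""
    trials = n * k
    return sum(exp(log_binomial_pmf(trials, c, p)) for c in range(c_min, trials + 1))


# Hypothetical example: 50 redundant neurons, connection probability 0.01,
# and a tolerance of 10 misfirings (a critical fraction of 0.2).
for k in (1, 2, 4):
    print(k, misfire_tail(n=50, k=k, p=0.01, c_min=10))
```

The tail probability increases with the number of active networks `k` and falls off quickly as the tolerated number of misfirings `c_min` grows, which is the behaviour the geometric-series bound captures.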

 

One forward pass: We have  sets of neurons . We want the chance of more than  misfirings for any of them on a forward pass to be vanishingly small for all  in the large width limit. That is, we want to scale  with the number of small networks , the size of small networks , and the number of active small networks  such that:

This condition is satisfied for any  so long as:

  1. The neuron count of the large network grows as some fractional power of the neuron counts of the small networks combined: .
  2. The combined number of active neurons in all the small networks on any one forward pass is small compared to the neuron count of the large network: .

The read-in error already imposes , so the former condition is not an additional constraint, except in that it precludes making the residual stream  exponentially wider than the MLP . The latter condition is fulfilled if the small networks activate sparsely.

So, in the large width limit  will vanish. Thus, the total error is dominated by .

Acknowledgements

Thanks to Dan Braun, Stefan Heimersheim, Lee Sharkey, and Bilal Chughtai for lots of discussions that shaped our thinking about this idea. Thanks also to Kaarel Hanni, Dmitry Vaintrob and Lawrence Chan for previous work that this idea builds on heavily, and for helping shape our thinking about this kind of thing.

  1. $\tilde{O}$ basically means '$O$ up to log factors'.
  2. Put differently, we can't have an overcomplete basis of task vectors.
  3. This limit is already suggested by information theory: Every operation we want the network to implement takes some minimum number of bits in its parameters to specify. So, in general, the minimum description length of the large network in bits can't be smaller than the minimum description lengths of the small networks summed together.
  4. The more imprecision we're willing to tolerate in the final result, the larger  will be. If small networks vary in how noise robust they are, we pick the  of the least robust one to be conservative.
  5. These simplifications primarily serve to avoid obfuscating the ideas in the construction. We are pretty confident that the derivations go through if you allow the number of neurons, residual stream width, and number of layers per small network to vary. That is, suppose we are given a set of neural networks indexed by . For the -th network, denote the number of neurons per layer as , residual stream width , and number of layers . Then, there exists a large residual neural network with depth , number of neurons per layer , and residual stream width  which satisfies, and , which can compute the outputs of all  circuits in parallel by leveraging superposition.
  6. We think some additional tinkering might remove the log term, and constant prefactors could likely be improved, but we doubt anything will break the limit . We can't specify more operations than we have bits to specify them in.
  7. Using the convention of left multiplication by matrices.
  8. This is essentially the same idea that is referred to as superpositional codes in this essay.
  9. Which applies because , and the expected number of misfirings is .
