TL;DR: We generalize the mathematical framework for computation in superposition from compressing many boolean logic gates into a neural network, to compressing many small neural networks into a larger neural network. The number of small networks we can fit into the large network depends on the small networks' total parameter count, not their neuron count.
Work done at Apollo Research. The bottom half of this post is just maths that you do not need to read to get the gist.
Introduction
Background
Anthropic's toy model of superposition shows how to compress many sparsely activating variables into a low dimensional vector space and then read them out again. But it doesn't show how to carry out computations on the compressed variables in their native format. The mathematical framework for computation in superposition makes a first stab at closing that gap. It shows how to compute boolean circuits in superposition.
What we do
We show how a network can perform any computations whatsoever in superposition. Specifically, we show how T small residual neural networks, each with n parameters and performing arbitrary tasks, can be compressed into a single larger residual network that performs all T tasks, provided that the large network is only evaluated on sparse combinations of tasks — any particular forward pass only asks for k≪T tasks to be carried out. In the limit of T and n going to infinity, this larger network will require $N=\tilde{O}(kTn)$ parameters[1].
Crucially, this means that the total number of small networks the larger network can implement scales approximately linearly with the number of weights in the large network, not with its number of neurons, as would be the case without computation in superposition. For example, if each small network uses m neurons per MLP layer and d dimensions in the residual stream, a large network with M neurons per MLP connected to a D-dimensional residual stream could implement about $\tilde{O}\!\left(\frac{MD}{kmd}\right)$ small networks, not just $\tilde{O}\!\left(\frac{M}{m}\right)$. Qualitatively speaking, our construction works using the same basic trick as the one for boolean circuits in superposition. We just generalize it from boolean AND gates to any operations the neural network could implement.
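To make the gap concrete, here is a quick back-of-the-envelope comparison in Python, with illustrative numbers of my own choosing (the $\tilde{O}$ expressions hide log factors and constants):

```python
# Toy comparison of how many small networks fit, with vs. without superposition.
# All sizes below are made up for illustration; only the formulas come from the text above.

m, d = 10, 20         # neurons per MLP layer and residual width of each small network
M, D = 10_000, 2_000  # neurons per MLP layer and residual width of the large network
k = 5                 # number of small networks active on a typical forward pass

without_superposition = M / m               # each small network gets its own m dedicated neurons
with_superposition = (M * D) / (k * m * d)  # ~ MD / (kmd), up to log factors

print(f"without superposition: ~{without_superposition:,.0f} small networks")
print(f"with superposition:    ~{with_superposition:,.0f} small networks")
```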
Generalising to circuits
While our derivation here assumes T networks carrying out unrelated tasks in parallel, nothing in the construction stops us from instead chaining the small networks in series, with later small networks taking the outputs of earlier small networks as their inputs. Therefore, the construction in this post can be thought of as a framework for representing arbitrary circuits in superposition.
Some very tentative implications, maybe?
Real neural networks probably don’t work exactly the way this construction does. It's made to be easy for us to prove things about it, not to be efficient in real life. The finite width of real networks might make other constructions better. We're also not dealing with potential correlations between the activations of different circuits, which might change the optimal setup even more. And ultimately, we don't actually know whether the structure of real-world datasets is sparse in the right way to incentivise learning sparsely activating circuits.
Nevertheless, there may be some useful takeaways about real networks, so long as we don't forget that they come with a heavy pinch of salt:
There is no superposition in parameter space: In this construction, we cannot compress more small networks into the large network than the large network has parameters. So, while a network can have more features than the dimension of its activation spaces, it can't implement more distinct operations[2] than the dimension of its parameter space[3].
Circuits don't have to follow the layer structure: This construction lines up the layers of the small networks with the layers of the large network, but that's just for our convenience. So long as the large network has more layers than the small networks, we can implement things all over the place. A single neuron in a small network could correspond to neurons across a range of layers in the big network. Thus, if somebody is looking at the residual stream activations in a layer of the big network, they might see a lot of half-computed nonsense that's hard to make sense of. You could call this cross-layer superposition.
Computation in superposition doesn't need one-dimensional 'features': Our construction doesn't assume that the T small networks internally work using one-dimensional variables represented as directions in activation space. Circuits may be embedded in the larger network as sparsely activating subspaces in the neurons and the residual stream, but within those spaces, their own representations don't have to be sparse or linear.
The total parameter vector could be decomposable into a sum of the parameter vectors dedicated to each small network: At least in this construction, the parameter vector of the large network $\theta$ is a sum of T vectors $\theta_t$ parametrizing the individual small networks: $\theta = \sum_{t=1}^{T} \theta_t$. If real networks share this property, then with the right optimization procedure, it might be possible to recover the individual small networks $\theta_t$ from $\theta$ by looking at the network's loss landscape. Apollo Research is trying out a way to do this at the moment.
Future work
Other architectures: We think this construction can be straightforwardly extended to transformers and CNNs without significantly changing any takeaways. We are investigating the error bounds for attention blocks at the moment.
Tracr extension: Theoretically, this framework could allow people to create superposed circuits by hand. We'd be excited about someone writing a more sophisticated version of Tracr based on these constructions, which could be used for building a more realistic interpretability benchmark akin to InterpBench. Note that the error bounds in this post are all formulated for the large network width limit — there is still some work to do to make this practical.
Training dynamics: This post makes claims about the expressivity of neural networks, but in real life, the structures learned by neural networks depend greatly on the inductive biases of their training. We would like to build on this framework to explore whether training actually incentivises the learning of sparse circuits. We have some ideas on this front, based on attempting to unify SLT ideas with the idea of the low-hanging fruit prior.
The Construction
Suppose we have T small neural networks. For simplicity, we will assume that each small network consists of L layers, with m neurons in each layer, a fixed elementwise nonlinearity, and a fixed residual stream width d. We require that these small networks are at least somewhat robust to noise: there is some magnitude of random noise $\epsilon_{\max} > 0$ that we can apply to all the preactivations of any of the small networks' neurons without changing downstream layer activations by more than some small $\delta$.[4]
Then we can create a large network that is also L layers deep, with a residual stream width D≫d, M≫m neurons in each layer, and the same activation functions, which can leverage superposition to compute the outputs of all T neural networks in parallel. This works even for D≪Td and M≪Tm, provided that only k≪T small neural networks are being passed a non-zero input vector on most forward passes. This large network will require on the order of $N=\tilde{O}(kTn)$ parameters in total[5].
The core idea behind this construction is similar to that for computing many ANDs of binary inputs in superposition. There may be many other constructions that would also work, but we think that in the limit of very wide neural networks, all constructions would perform more or less the same, and yield the same fundamental limits for how many small networks can be superposed into a network with N parameters[6]. As with all constructions involving superposition, the key to the construction working out is in managing the size of the interference between separate small networks, and making sure that it does not become larger than the size of the signal — the correct output of each small network. In this construction, there are two sources of interference:
Read-in interference
Our T small networks have a combined Td≫D residual stream dimensions. So, activation vectors of different small networks in the large residual stream cannot be completely orthogonal. This means that when a particular small network is passed an input of 0 but other small networks are passed nonzero inputs, the value of the inputs that are read in by the weights that implement the first small network won't be exactly zero. In our construction, this read-in interference is what ends up dominating the constraints on how many small networks we can compute in a single large network.
At a high level, we manage read-in interference by making the residual stream width D larger so the overlap between small networks is smaller, and making the MLP width M larger so the read-in interference can be spread across more neurons.
Read-out interference
Our T small networks have a combined mT≫M neurons per layer. Naively, we could randomly assign every neuron in every small network to one neuron in the big network. But then, if two small networks that happened to share a neuron activated at the same time, that neuron would get conflicting inputs and misfire. So we could only carry out one of the T tasks at a time.
To make the small networks robust to these misfires, we introduce redundancy into the big network, representing each neuron in the small network with many neurons in the big network. This means that each neuron in the big network is assigned to even more small networks than if there was no redundancy, but this cost is worth it: we can now recover the value of any activation of any small network by averaging over the values of every neuron in the large neuron that represents it. If few enough small networks are active at once, then almost all neurons in the large network assigned to any particular small network's neuron will take on the correct value for that neuron, almost all of the time, and in the limit of M→∞, the difference between the value of a small network's neuron and the average of all the neurons in the large network that compute that small network will go to zero.
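As a toy illustration of why this redundancy helps (a sketch of my own, not part of the construction itself): if a value is stored in many redundant copies and only a few of them are corrupted, the average of the copies stays close to the true value.

```python
import numpy as np

rng = np.random.default_rng(0)
true_value = 0.7       # the value one small-network neuron should take (arbitrary)
n_copies = 20          # redundant big-network neurons representing it (~ log M in the construction)
misfire_prob = 0.05    # chance that any given copy is corrupted by other active small networks

copies = np.full(n_copies, true_value)
misfired = rng.random(n_copies) < misfire_prob
copies[misfired] = rng.uniform(-1.0, 1.0, size=misfired.sum())  # corrupted copies take unrelated values

print(f"misfired copies: {misfired.sum()} / {n_copies}")
print(f"averaged read-out: {copies.mean():.3f}   (true value: {true_value})")
```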
Maths
If you don't care about technical details, you can safely skip this section.
Let the input to the t-th small network be denoted by $x_t \in \mathbb{R}^d$, and the activation vector of small network t in layer l for input $x_t$ by $a^l_t(x_t)$, or simply $a^l_t$. Similarly, denote the activation vector of the large network in layer l by $A^l$. We also define a set of random matrices with orthonormal columns $\{E_t \in \mathbb{R}^{D \times d}\}$:
$E_t = \begin{pmatrix} \vert & & \vert \\ e^1_t & \cdots & e^d_t \\ \vert & & \vert \end{pmatrix}$
with $e^i_t \in \mathbb{R}^D$ satisfying $e^i_t \cdot e^j_t = \delta_{ij}$. Since the matrices are projection matrices onto random d-dimensional subspaces of $\mathbb{R}^D$, their columns satisfy $\mathbb{E}_{t \neq s}\!\left[(e^i_t \cdot e^j_s)^2\right] = O(1/D)$. These matrices define projections from the residual streams of each small network into a random subspace of the larger residual stream. What we want to prove is that if the number of $x_t$ that are nonzero is k≪T, then for all $l=1,\ldots,L$, there exist terms $\delta^l$ satisfying $\|\delta^l\|_2 \ll \|\sum_{t=1}^T E_t a^l_t\|_2$, such that:
$A^l = \sum_{t=1}^{T} E_t a^l_t + \delta^l.$
We'll (sort-of) prove this using induction.
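Before the induction, here is a quick numerical sanity check (my own numpy sketch, with arbitrary sizes) of the claim that the random embeddings $E_t$ are nearly orthogonal to each other, with cross-overlaps of order $1/\sqrt{D}$ per entry:

```python
import numpy as np

rng = np.random.default_rng(0)
D, d = 4096, 16  # large and small residual stream widths (arbitrary choices)

def random_embedding(D, d, rng):
    """D x d matrix with orthonormal columns spanning a random d-dimensional subspace of R^D."""
    Q, _ = np.linalg.qr(rng.standard_normal((D, d)))  # reduced QR: Q has shape (D, d)
    return Q

E_t = random_embedding(D, d, rng)
E_s = random_embedding(D, d, rng)

overlap = E_t.T @ E_s  # d x d matrix of cross inner products e_t^i . e_s^j
print("E_t^T E_t == Id:        ", np.allclose(E_t.T @ E_t, np.eye(d)))
print("RMS entry of E_t^T E_s: ", overlap.std())  # roughly 1/sqrt(D)
print("1/sqrt(D):              ", 1 / np.sqrt(D))
```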
Embedding Matrix
The base case for the induction is just the embedding in layer 0. The input to the large network is the concatenated vector $X = (x_1, x_2, \ldots, x_T) \in \mathbb{R}^{Td}$. The embedding matrix[7] $W_E \in \mathbb{R}^{D \times Td}$ is constructed by directly projecting each $x_t$ into the residual stream using $E_t$, which we can do by stacking the projection matrices next to each other:
$W_E = \begin{pmatrix} E_1 & \cdots & E_T \end{pmatrix}.$
Then, the residual stream activation vector at layer zero, $A^0 := W_E X$, is equal to $A^0 = \sum_{t=1}^{T} E_t x_t$, as required.
Other layers
We'd now like to assume that $A^{l-1} = \sum_{t=1}^{T} E_t a^{l-1}_t + \delta^{l-1}$ is satisfied in layer l−1, and demonstrate that the corresponding expression is satisfied in layer l. To do so, we need to work out what the matrices $W^{l,\mathrm{in}}$ and $W^{l,\mathrm{out}}$ should be.
Reading from the residual stream
To start, we need a way to compute the outputs of all the small read-in matrices $W^{l,\mathrm{in}}_1, \ldots, W^{l,\mathrm{in}}_T \in \mathbb{R}^{m \times d}$ at once with the larger matrix $W^{l,\mathrm{in}} \in \mathbb{R}^{M \times D}$. If we had D≥Td and N≥Tn, we could do this by making $W^{l,\mathrm{in}}$ block diagonal, but we are looking for a construction with D≪Td and N≪Tn. To make progress, we start by noting that

$W^{l,\mathrm{in}}_t E_t^\top A^{l-1} = W^{l,\mathrm{in}}_t a^{l-1}_t + W^{l,\mathrm{in}}_t E_t^\top \delta^{l-1} + W^{l,\mathrm{in}}_t \sum_{s \neq t} E_t^\top E_s a^{l-1}_s,$

where we have used that $E_t^\top E_t = \mathrm{Id}_d$. We want the read-in interference
$\epsilon^{l,\mathrm{in}}_t := \sum_{s \neq t} E_t^\top E_s a^{l-1}_s$
introduced to network t in layer l to be sufficiently small, staying below the $\epsilon_{\max}$ noise level we assume the subnetworks to be robust to. The justification for $\epsilon^{l,\mathrm{in}}_t$ being small will be based on the fact that for $t \neq s$, $E_t^\top E_s$ is approximately a matrix of Gaussians with variance 1/D. Details are in the section Read-in interference below.
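The following sketch (again my own, with arbitrary sizes) shows the read-in step in isolation: k active small networks write their residual vectors into a shared D-dimensional stream, and reading one of them back out with $E_t^\top$ returns its activation plus an interference term that is small when $D \gg kd$:

```python
import numpy as np

rng = np.random.default_rng(1)
D, d, T, k = 4096, 16, 1000, 8   # arbitrary illustrative sizes

def random_embedding(D, d, rng):
    Q, _ = np.linalg.qr(rng.standard_normal((D, d)))
    return Q  # D x d with orthonormal columns

# Only the k active small networks contribute to the stream, so we only sample their embeddings.
active = rng.choice(T, size=k, replace=False)
E = {t: random_embedding(D, d, rng) for t in active}
a = {t: rng.standard_normal(d) for t in active}   # their layer-(l-1) activations

A = sum(E[t] @ a[t] for t in active)  # superposed residual stream: sum_t E_t a_t

t0 = active[0]
readout = E[t0].T @ A                 # = a_{t0} + sum_{s != t0} E_{t0}^T E_s a_s
interference = readout - a[t0]
print("signal norm:      ", np.linalg.norm(a[t0]))
print("interference norm:", np.linalg.norm(interference))  # ~ sqrt((k-1) * d / D) times a typical |a_s|
```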
Writing to the neurons
We can't just connect the outputs of this multiplication to neurons in layer l of the large network, even if the interference is small. This is because mT≫M, so we'd have to share neurons between many circuits, and we wouldn't be able to tell whether a neuron i fires due to circuit t activating, or due to some other circuit connected to that neuron. Instead, we need to introduce some redundancy into the representations of the activations of each small network[8]. We do this by multiplying by a distributing matrix $V^l \in \mathbb{R}^{mT \times M}$. This matrix is defined as follows:
Start with the first m rows (each row is a vector in RM), which connect to small network 1. These are the rows of Vl which determine which neurons are involved in computing the lth layer of the first small network.
Then, pick a random partition of the neurons of the lth layer of the big network into `neuron sets' of size m. There are M/m many sets.
Let $p = \frac{m \log M}{M}$. For each neuron set, consider the submatrix of $V^l$ which consists of only the first m rows and only the columns in that set, so each submatrix has shape $m \times m$. For each submatrix, with probability p, set it equal to a random permutation of the identity matrix, and with probability 1−p, set it equal to the zero matrix.
Repeat for each set of m rows of Vl, corresponding to each small network. Each time, pick a different random partition of the neurons into neuron sets.
For the t-th small network, the neurons that are in sets which are assigned a permutation matrix are called connected to that small network, and the neurons that are in sets assigned the zero matrix are called unconnected. We denote the set of all sets of neurons in the large network that are connected to the t-th small network in layer l by $S^l_t$ (a subset of the powerset of $\{1,\ldots,M\}$), and the set of all neurons in the large network that are connected to the i-th neuron of the t-th small network in layer l by $S^l_{t,i}$. Every small network will on average connect its weights $W^{l,\mathrm{in}}_t$ to $r = \mathbb{E}[|S^l_t|] = \log M$ sets of m neurons in the big network. So, we set
$W^{l,\mathrm{in}} = \sum_t V^l_t W^{l,\mathrm{in}}_t E_t^\top.$
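A rough numpy sketch of how the block of $V^l$ belonging to one small network could be sampled, following the recipe above (this is my own illustration; I'm not tracking the exact transpose conventions used in the surrounding formulas):

```python
import numpy as np

rng = np.random.default_rng(2)
M, m = 1200, 10                 # illustrative sizes; assume M is a multiple of m
p = m * np.log(M) / M           # probability that a given neuron set is connected

def sample_V_block(M, m, p, rng):
    """m x M block of V^l for one small network: row i marks the big-network neurons computing its neuron i."""
    V_t = np.zeros((m, M))
    neuron_sets = rng.permutation(M).reshape(M // m, m)   # random partition into M/m sets of size m
    for neuron_set in neuron_sets:
        if rng.random() < p:                              # connect this set with probability p
            perm = rng.permutation(m)                     # random permutation of the identity
            V_t[np.arange(m), neuron_set[perm]] = 1.0
    return V_t

V_t = sample_V_block(M, m, p, rng)
print("connected sets:", int(V_t.sum()) // m, " (expected ~ log M ≈ %.1f)" % np.log(M))
```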
Writing back to the residual stream
To write back to the residual stream from the neurons, we first recover the value of the activations of each small network by averaging over all the neurons in the large network that are connected to the corresponding small-network neuron. We do this by multiplying the activations of the big network with $\frac{1}{|S^l_t|}(V^l_t)^\top$:

$\frac{1}{|S^l_t|}(V^l_t)^\top \mathrm{ReLU}\!\left(W^{l,\mathrm{in}} A^{l-1}\right) = \mathrm{ReLU}\!\left(W^{l,\mathrm{in}}_t a^{l-1}_t\right) + \epsilon^{l,\mathrm{out}}_t.$

Then we can apply each $W^{l,\mathrm{out}}_t$ to recover small network t's MLP output for this layer (which, added to the residual stream, yields $a^l_t$), and embed these outputs back into the residual stream using $E_t$:
$W^{l,\mathrm{out}} = \sum_t \frac{1}{|S^l_t|}\, E_t W^{l,\mathrm{out}}_t (V^l_t)^\top.$
If $\epsilon^{l,\mathrm{out}}_t$ is small enough (which requires $\epsilon^{l,\mathrm{in}}_t$ to be small as well), then we are done, and $A^l$ will have the correct form.
Error analysis
Let $a, w \in \mathbb{R}^+$ be upper bounds on the $L_2$ norm of the small networks' activations in the residual stream and on the operator norm of their MLP input matrices, respectively:

$\|a^l_t\|_2 \le a \quad \text{and} \quad \|W^{\mathrm{in},l}_t\|_{\mathrm{op}} \le w \qquad \forall\, l,\; t \in \{1,\ldots,T\}.$
In the analysis below, we find that the L2 size of the total interference added to a subnet in an MLP layer will be
$\epsilon = O\!\left(wa\sqrt{\frac{kTmd}{MD}\log M}\right).$
For this noise to stay below the ϵmax we assumed the small networks to be robust to at every layer, our large network needs at least
$N = O\!\left(\frac{w^2 a^2}{\epsilon_{\max}^2}\, kTn \log M\right)$
parameters in total. Any less than that, and the interference will begin to overwhelm the signal. Assuming the noise tolerance $\epsilon_{\max}$ isn't larger than the maximum size of the small networks' neuron activations, we'll have $\frac{w^2 a^2}{\epsilon_{\max}^2} \ge 1$. So we need $N = \tilde{O}(kTn)$ parameters in total.
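As a sanity check of these two formulas, here is a quick numeric evaluation with made-up values (none of these numbers come from the post):

```python
import math

# Made-up example values, just to exercise the formulas above.
w, a, eps_max = 1.0, 1.0, 0.1        # weight-norm bound, activation-norm bound, noise tolerance
T, k, m, d = 10_000, 20, 16, 32      # number of small networks, active per pass, MLP width, residual width
M, D = 1_000_000, 200_000            # large-network MLP width and residual stream width
n = m * d                            # rough per-layer parameter count of one small network

eps = w * a * math.sqrt(k * T * m * d / (M * D) * math.log(M))
N_needed = (w**2 * a**2 / eps_max**2) * k * T * n * math.log(M)

print(f"interference per layer: eps ≈ {eps:.3f}   (tolerance eps_max = {eps_max})")
print(f"parameters needed:      N ≈ {N_needed:.2e}")
print(f"parameters available:   M*D = {M * D:.2e} per layer")
```

With these numbers the interference stays below the tolerance, and the required parameter count comes out roughly three orders of magnitude above the bare $kTn \approx 1.0 \times 10^8$, reflecting the log factor and the $w^2a^2/\epsilon_{\max}^2$ prefactor hidden in the $\tilde{O}$.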
Read-in interference
In this construction, we find that our total error term is dominated by read-in interference.
The noise from an activation vector $a^l_s$ of a circuit s being multiplied by the weight matrix $W^{\mathrm{in}}_t$ of a different circuit t will be

$\epsilon^{l,\mathrm{in}}_{t,s} = W^{\mathrm{in}}_t E_t^\top E_s a^l_s.$
The entries of the matrix $E_t^\top E_s \in \mathbb{R}^{d \times d}$ will have approximate size $O(1/\sqrt{D})$. Since the d entries of a row of $E_t^\top E_s$ are randomly distributed, the entries of $E_t^\top E_s a^l_s$ will then have average size $O(\sqrt{d/D})$. So, the noise $\epsilon^{l,\mathrm{in}}_{t,s}$ from activation $a^l_s$ of small network s being partially projected into preactivations of neurons in small network t will be on the order of

$\epsilon^{l,\mathrm{in}}_{t,s} = O\!\left(\sqrt{\tfrac{d}{D}}\; \|W^{\mathrm{in},l}_t\|_{\mathrm{op}} \|a^l_s\|_2\right).$
On average, each neuron has $Tp = \frac{Tm}{M}\log M$ weight rows of small networks connecting to it. Using $\|a^l_s\|_2 \le a$ and $\|W^{\mathrm{in},l}_t\|_{\mathrm{op}} \le w$, if there are k circuits active at a given time, the total read-in interference $\epsilon^{l,\mathrm{in}}_t = \sum_{s \neq t}\epsilon^{l,\mathrm{in}}_{t,s}$ on the preactivation of any one neuron in any small network t will be bounded by

$\epsilon^{l,\mathrm{in}}_t = O\!\left(wa\sqrt{\frac{kTmd}{MD}\log M}\right),$
because the noise sources are independent. This noise dominates the total error term.
Read-out interference
In our construction, we find that the read-out interference $\epsilon^{l,\mathrm{out}}_t$ from multiple circuits using the same neuron is subdominant and vanishes in the limit of large networks. For the read-out of a small network from the MLP of the large network to become inaccurate, some fraction of the $\log M$ neurons playing the role of one neuron in the original small network have to all `misfire', activating when they shouldn't, or with incorrect magnitude even when they do fire. Since we assumed that our activation functions are Lipschitz continuous, we can bound any `misfire' to be smaller than some bound $K \in \mathbb{R}$.
We'll assume that there is some critical fraction 0<c<1 which is the maximum fraction of misfires we can tolerate, and which depends on the error tolerance of our small networks: $c\log(T)$ misfires would give us an error $\epsilon^{l,\mathrm{out}}_{t,i} \le c\log(T)\, K$ on the read-out of neuron i in small network t, which we require to be smaller than the maximum error tolerance of the small networks, $\epsilon_{\max}$.
One neuron: Consider a specific neuron i in small network s. This neuron is assigned a set $S^l_{s,i}$ of approximately $\log M$ neurons in the large network that compute it.
k=1: Suppose that only small network t≠s is active on the current forward pass. The chance of circuit t connecting to a given neuron is $p = \frac{m}{M}\log M$. So, if c≪1, the probability that there are $c\log M$ misfirings in the set $S^l_{s,i}$ follows a binomial distribution:
$P(c\log M \text{ misfirings in } S^l_{s,i}) = \binom{\log M}{c\log M}\left(\frac{m\log M}{M}\right)^{c\log M}\left(1-\frac{m\log M}{M}\right)^{(1-c)\log M}.$
The last factor is approximately equal to 1 and can be ignored.

k>1: Suppose there are k>1 small networks active at once. Each neuron in $S^l_{s,i}$ can be used by multiple active networks. We can imagine a matrix with k rows and $\log M$ columns, with a 1 in the (i,j) position if the j-th neuron in $S^l_{s,i}$ is connected to the i-th active small network, and a zero otherwise. The entries of this matrix are i.i.d. Bernoulli random variables with probability p, and the number of nonzero entries in this matrix is the total number of misfirings in $S^l_{s,i}$. Again assuming c≪1, the probability that $S^l_{s,i}$ has $c\log M$ misfirings will be:
$P(c\log M \text{ misfirings in } S^l_{s,i}) = \binom{k\log M}{c\log M}\left(\frac{m\log M}{M}\right)^{c\log M}.$
Using Stirling's formula[9], we can write this as:
$P(c\log M \text{ misfirings in } S^l_{s,i}) < \left(\frac{kme\log M}{Mc}\right)^{c\log M}.$
We can approximate $P(c\log M + x \text{ misfirings in } S^l_{s,i})$ as a decaying geometric series in x, with initial value $P_0 = P(c\log M \text{ misfirings in } S^l_{s,i})$ and ratio $r = \frac{P_{x+1}}{P_x} \simeq \frac{k\log M\, p}{c\log M} = \frac{km\log M}{cM} \ll 1.$
Therefore, we have
$P(\text{at least } c\log M \text{ misfirings in } S^l_{s,i}) = \frac{P_0}{1-r} < \left(\frac{kme\log M}{Mc}\right)^{c\log M}.$
One forward pass: We have Tm sets of neurons $S^l_{s,i}$. We want the chance of more than $c\log M$ misfirings for any of them on a forward pass to be vanishingly small, for all c, in the large width limit. That is, we want to scale M with the number of small networks T, the size of the small networks m, and the number of active small networks k such that:

$\lim_{M,T\to\infty} Tm\left(\frac{e}{c}\,\frac{km\log M}{M}\right)^{c\log M} = 0.$
This condition is satisfied for any c≪1 so long as:
The neuron count of the large network grows as some fractional power of the neuron counts of the small networks combined: Tm=poly(M).
The combined number of active neurons in all the small networks on any one forward pass is small compared to the neuron count of the large network: km=o(M).
The read-in error already imposes $MD = \Omega(kTmd)$ (up to log factors), so the former condition is not an additional constraint, except in that it precludes making the residual stream exponentially wider than the MLP width M. The latter condition is fulfilled if the small networks activate sparsely.
So, in the large width limit M→∞, $\epsilon^{l,\mathrm{out}}_t$ will vanish. Thus, the total error is dominated by $\epsilon^{l,\mathrm{in}}_t$.
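To get a feel for how fast this tail probability decays, here is a quick numeric evaluation of the bound $Tm\left(\frac{e}{c}\frac{km\log M}{M}\right)^{c\log M}$ with made-up values for k, m and c, and with $Tm = M$ as one arbitrary choice satisfying $Tm = \mathrm{poly}(M)$ (at small M the bound is vacuous; it only starts to bite once M is large):

```python
import math

k, m, c = 8, 10, 0.25   # made-up values: active networks, small-network width, misfire fraction

for M in [10**4, 10**5, 10**6, 10**7, 10**8]:
    T_times_m = M                          # arbitrary choice with Tm = poly(M)
    log_M = math.log(M)
    per_set = ((math.e / c) * k * m * log_M / M) ** (c * log_M)
    print(f"M = {M:>11,}:  bound on P(any bad set) ≈ {T_times_m * per_set:.3e}")
```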
Acknowledgements
Thanks to Dan Braun, Stefan Heimersheim, Lee Sharkey, and Bilal Chughtai for lots of discussions that shaped our thinking about this idea. Thanks also to Kaarel Hanni, Dmitry Vaintrob and Lawrence Chan for previous work that this idea builds on heavily, and for helping shape our thinking about this kind of thing.
Footnotes

[1] $N=\tilde{O}(kTn)$ basically means '$N=O(kTn)$ up to log factors'.

[2] Put differently, we can't have an overcomplete basis of task vectors.

[3] This limit is already suggested by information theory: every operation we want the network to implement takes some minimum number of bits in its parameters to specify. So, in general, the minimum description length of the large network in bits can't be smaller than the minimum description lengths of the small networks summed together.

[4] The more imprecision we're willing to tolerate in the final result, the larger $\epsilon_{\max}$ will be. If small networks vary in how noise-robust they are, we pick the $\epsilon_{\max}$ of the least robust one to be conservative.

[5] These simplifications primarily serve to avoid obfuscating the ideas in the construction. We are pretty confident that the derivations go through if you allow the number of neurons, residual stream width, and number of layers per small network to vary. That is, suppose we are given a set of neural networks indexed by $t=1,\ldots,T$. For the t-th network, denote the number of neurons per layer by $m_t$, the residual stream width by $d_t$, and the number of layers by $\ell_t$. Then, there exists a large residual neural network with depth L, number of neurons per layer M, and residual stream width D which satisfies $\forall t \in \{1,\ldots,T\}: m_t \ll M,\; d_t \ll D,\; \ell_t \le L$, and $\sum_t m_t \gg M,\; \sum_t d_t \gg D$, which can compute the outputs of all T circuits in parallel by leveraging superposition.

[6] We think some additional tinkering might remove the log term, and constant prefactors could likely be improved, but we doubt anything will break the limit $N \ge \sum_{t=1}^{T} n_t$. We can't specify more operations than we have bits to specify them in.

[7] Using the convention of left multiplication by matrices.

[8] This is essentially the same idea that is referred to as superpositional codes in this essay.

[9] Which applies because $p \ll 1$, and the expected number of misfirings is $pk\log M = \frac{mk\log^2 M}{M} \ll c\log M$.