(Thanks to Adam Scherlis, Kshitij Sachan, Buck Shlegeris, Chris Olah, and Nicholas Schiefer for conversations that informed this post. Thanks to Aryan Bhatt for catching an error in the loss minimization.)

Anthropic recently published a paper on toy models of superposition [Elhage+22]. One of the most striking results is that, when features are sparse, feature embeddings can become divided into disjoint subspaces with just a few vectors per subspace. This type of decomposition is known as a tegum product.

This post aims to give some intuition for why tegum products are a natural feature of the minima of certain loss functions.

Setup

Task

Suppose we’ve got $d$ embedding dimensions and $n$ features. We want to embed features into dimensions in a way that minimizes overlap between their embedding vectors. In the simplest case this could be because we’re building an autoencoder and want to compress the features into a low-dimensional space.

One approach is to encode all of the features in all of the dimensions. With this approach there is some interference between every pair of features (i.e. no pair is embedded in a fully orthogonal way), but we have a lot of degrees of freedom that we can use to minimize this interference.

Another approach is to split the $d$ dimensions into $k$ orthogonal subspaces of $d/k$ dimensions. This has the advantage of making most pairs of vectors exactly orthogonal, but at the cost that some vectors are packed more closely together. In the limit where $k = 1$ this reduces to the first approach.
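A minimal sketch of the second approach, assuming $k$ divides both $n$ and $d$ and using random Gaussian directions inside each block (the helpers `tegum_embedding` and `count_orthogonal_pairs` are hypothetical names, not from the post). Vectors in different blocks have disjoint supports, so every cross-block dot product is exactly zero, which gives $n(n - n/k)$ orthogonal ordered pairs; with $k = 1$ there are none.

```python
import math
import random

def tegum_embedding(n, d, k, seed=0):
    """Embed n unit vectors in R^d as a tegum product of k orthogonal blocks.

    Block b owns coordinates [b*m, (b+1)*m) with m = d // k and holds n // k
    random unit vectors that are zero outside those coordinates.
    """
    if n % k or d % k:
        raise ValueError("k must divide both n and d")
    rng = random.Random(seed)
    m, per_block = d // k, n // k
    vectors = []
    for b in range(k):
        for _ in range(per_block):
            v = [0.0] * d
            for i in range(b * m, (b + 1) * m):
                v[i] = rng.gauss(0.0, 1.0)
            norm = math.sqrt(sum(x * x for x in v))
            vectors.append([x / norm for x in v])
    return vectors

def count_orthogonal_pairs(vectors):
    """Count ordered pairs (i, j), i != j, whose dot product is exactly zero."""
    count = 0
    for i, u in enumerate(vectors):
        for j, v in enumerate(vectors):
            if i != j and sum(a * b for a, b in zip(u, v)) == 0.0:
                count += 1
    return count

n, d, k = 12, 6, 3
vecs = tegum_embedding(n, d, k)
print(count_orthogonal_pairs(vecs), n * (n - n // k))
```

Both printed numbers should agree: only pairs that share a block can interfere.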

Our aim is to figure out the $k$ that minimizes the loss on this task.

Loss

Suppose our loss has the following properties:

  1. $L = \sum_{i \neq j} f\left(|\cos\theta_{ij}|\right)$. That is, the loss decomposes into a sum of terms involving the cosine similarities of feature vectors, and all features are equally important.
  2. $f(0) = 0$. The loss vanishes for orthogonal vectors.
  3. $f'(x) > 0$. The loss is greater the more the vectors overlap.
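To make the three properties concrete, here is a minimal sketch of such a loss, assuming the hypothetical power-law choice $f(x) = x^\alpha$ with $\alpha > 0$. It is one function that satisfies $f(0) = 0$ and increases with overlap; the properties themselves do not single it out.

```python
import math

def cosine(u, v):
    """Cosine similarity of two nonzero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / math.sqrt(sum(a * a for a in u) * sum(b * b for b in v))

def pairwise_loss(vectors, alpha=2.0):
    """Sum of f(|cos theta_ij|) over ordered pairs i != j, with f(x) = x**alpha.

    Every feature gets the same weight, f(0) = 0, and f increases with overlap.
    """
    total = 0.0
    for i, u in enumerate(vectors):
        for j, v in enumerate(vectors):
            if i != j:
                total += abs(cosine(u, v)) ** alpha
    return total

print(pairwise_loss([[1.0, 0.0], [0.0, 1.0]]))  # orthogonal pair
print(pairwise_loss([[1.0, 0.0], [1.0, 1.0]]))  # 45 degrees apart
```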

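Using these properties, we find that the loss is roughly

$$L \approx k\left(\frac{n}{k}\right)^2 f(\epsilon) = \frac{n^2}{k} f(\epsilon),$$

where $\epsilon$ is the typical cosine similarity between vectors in a subspace. Only pairs that share a subspace contribute, because the loss vanishes for orthogonal vectors, and each of the $k$ subspaces holds $n/k$ vectors.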

Loss-Minimizing Subspaces

Edit: The formula for $\epsilon$ below is a simplification in the limit of small $\epsilon$. This simplification turns out to matter, and affects the subsequent loss minimization. I've struck through the affected sections below, explained the correct optimization in this comment below and reproduced the relevant results below the original here. None of the subsequent interpretation is affected.

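The Johnson-Lindenstrauss lemma says that we can pack $N$ nearly-orthogonal vectors into $D$ dimensions, with mutual angles satisfying

$$|\cos\theta| \leq \epsilon,$$

where

$$\epsilon = \sqrt{\frac{C \ln N}{D}}$$

and $C$ is a constant. Setting $N = n/k$ and $D = d/k$ gives

$$\epsilon = \sqrt{\frac{C k \ln(n/k)}{d}}.$$
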
Assuming we pick our vectors optimally to saturate the Johnson-Lindenstrauss bound, we can substitute this for $\epsilon$ in the loss and differentiate with respect to $k$ to find

$$\frac{dL}{dk} = \frac{n^2}{k^2}\left[\frac{\epsilon f'(\epsilon)}{2}\left(1 - \frac{1}{\ln(n/k)}\right) - f(\epsilon)\right].$$

There are three possible cases: either the minimum occurs at $k = d$ (the greatest value it can take), or at $k = 1$ (the smallest value it can take), or at some point in between where $dL/dk$ vanishes.

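The derivative vanishes if

$$\frac{\epsilon f'(\epsilon)}{f(\epsilon)} = \frac{2\ln(n/k)}{\ln(n/k) - 1},$$

which gives

$$\ln\frac{n}{k} = \frac{\alpha}{\alpha - 2},$$

where

$$\alpha \equiv \frac{d\ln f}{d\ln\epsilon} = \frac{\epsilon f'(\epsilon)}{f(\epsilon)}$$

is treated as a constant (as it is for a power law $f \propto \epsilon^\alpha$). When $\alpha < 2$ there is no place where the derivative vanishes, and the optimum is $k = d$. Otherwise there is an optimum at

$$k = n\,e^{-\alpha/(\alpha - 2)}$$

so long as this is less than $d$. If it reaches $d$, the optimum sticks to $k = d$.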
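A numerical cross-check of where this derivative vanishes, as a sketch under stated assumptions: $f(\epsilon) = \epsilon^\alpha$ (so $\alpha$ is constant), the small-$\epsilon$ form $\epsilon = \sqrt{Ck\ln(n/k)/d}$ with a hypothetical $C = 8$ (the constant drops out of the stationarity condition), and hypothetical sizes $n = 10^6$, $d = 10^4$, $\alpha = 2.2$, chosen so that $\epsilon < 1$ over the bracket $1 \leq k \leq 100$. It bisects on a finite-difference slope of $\ln L$ and compares the result with $k = n e^{-\alpha/(\alpha - 2)}$. It only locates the stationary point; it does not by itself say whether that point is a minimum.

```python
import math

C = 8.0  # hypothetical JL constant; it cancels out of the stationarity condition

def eps_small(k, n, d):
    """Small-eps JL interference for N = n/k vectors in D = d/k dimensions."""
    return math.sqrt(C * k * math.log(n / k) / d)

def log_loss(k, n, d, alpha):
    """ln L for L = (n^2 / k) * eps^alpha, i.e. the assumed power law f(eps) = eps**alpha."""
    return 2 * math.log(n) - math.log(k) + alpha * math.log(eps_small(k, n, d))

def slope(k, n, d, alpha, h=1e-5):
    """Central finite-difference estimate of d ln L / d ln k."""
    up = log_loss(k * math.exp(h), n, d, alpha)
    down = log_loss(k * math.exp(-h), n, d, alpha)
    return (up - down) / (2 * h)

def stationary_k(n, d, alpha, lo, hi, iters=200):
    """Bisect in ln k for the point where the slope changes sign between lo and hi."""
    s_lo = slope(lo, n, d, alpha)
    for _ in range(iters):
        mid = math.sqrt(lo * hi)
        s_mid = slope(mid, n, d, alpha)
        if (s_mid > 0) == (s_lo > 0):
            lo, s_lo = mid, s_mid
        else:
            hi = mid
    return math.sqrt(lo * hi)

n, d, alpha = 1_000_000, 10_000, 2.2  # hypothetical sizes and exponent
k_numeric = stationary_k(n, d, alpha, lo=1.0, hi=100.0)
k_closed = n * math.exp(-alpha / (alpha - 2))
print(f"numeric k = {k_numeric:.4f}, closed form k = {k_closed:.4f}")
```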

The Johnson-Lindenstrauss lemma says that we can pack $N$ nearly-orthogonal vectors into $D$ dimensions, with mutual angles satisfying

$$|\cos\theta| \leq \epsilon,$$

where $D \geq \dfrac{4\ln N}{\epsilon^2/2 - \epsilon^3/3}$ (per Scikit and references therein). The cubic term matters because it makes the interference grow faster than the quadratic alone would imply (especially in the vicinity of $\epsilon \approx 1$).

With this correction it's not feasible to do the optimization analytically, but we can still do things numerically. Setting ,  , and  gives:

The top panel shows the normalized loss for a few different $\alpha$, and the lower shows the loss derivative with respect to $k$. Note that the range of $k$ is set by the real roots of $\epsilon^2/2 - \epsilon^3/3 = 4k\ln(n/k)/d$: for larger $k$ there are no real roots with $0 < \epsilon \leq 1$, which corresponds to the interference $\epsilon$ crossing unity. In practice this bound applies well before $k = d$. Intuitively, if there are more vectors than dimensions then the interference becomes order-unity (so there is no information left!) well before the subspace dimension falls to unity.
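The numerical procedure can be sketched as follows, assuming $f(\epsilon) = \epsilon^\alpha$ and hypothetical sizes $n = 10^4$, $d = 2000$ (not the values behind the figure). The sketch inverts $\epsilon^2/2 - \epsilon^3/3 = 4\ln N/D$ by bisection (scikit-learn's `johnson_lindenstrauss_min_dim` uses the same bound in the other direction, from $\epsilon$ to $D$), treats $k$ as outside the domain when no solution with $\epsilon \leq 1$ exists, and scans $L \approx (n^2/k)\,\epsilon^\alpha$ on a geometric grid. The helper names are hypothetical.

```python
import math

def eps_corrected(N, D, iters=100):
    """Solve eps^2/2 - eps^3/3 = 4 ln N / D for eps in [0, 1]; None if no solution.

    The left side increases on [0, 1] and peaks at 1/6 when eps = 1, so a target
    above 1/6 means the interference would have to exceed unity.
    """
    target = 4.0 * math.log(N) / D
    if target > 1.0 / 6.0:
        return None
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if mid**2 / 2 - mid**3 / 3 < target:
            lo = mid
        else:
            hi = mid
    return hi

def loss(k, n, d, alpha):
    """Approximate loss (n^2 / k) * eps^alpha with N = n/k, D = d/k; None outside the domain."""
    eps = eps_corrected(n / k, d / k)
    if eps is None:
        return None
    return (n * n / k) * eps**alpha

def best_k(n, d, alpha, points=4000):
    """Scan k on a geometric grid over [1, d]; return (k, loss) at the smallest valid loss."""
    best = None
    for i in range(points):
        k = d ** (i / (points - 1))
        value = loss(k, n, d, alpha)
        if value is not None and (best is None or value < best[1]):
            best = (k, value)
    return best

n, d = 10_000, 2_000  # hypothetical sizes, not the ones behind the figure
for alpha in (1.0, 2.0, 3.0, 4.0):
    k, value = best_k(n, d, alpha)
    print(f"alpha = {alpha}: lowest loss on the grid at k = {k:.2f}")
```

For these sizes the valid range of $k$ ends between 12 and 13, far below $k = d = 2000$, which is the "well before" behaviour described in the text.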

Anyway, all of these curves have global minima in the interior of the domain (if just barely for ), and the minima move to the left (toward smaller $k$) as $\alpha$ rises. That is, for  we care increasingly about higher moments as we increase $\alpha$ and so we want fewer subspaces.

What happens for ?

The global minima disappear! Now the optimum is always . In fact, though, the transition is no longer at $\alpha = 2$ but a little higher:

Interpretation

We can think of $\alpha$ as the sensitivity of the loss to interference. Specifically, which moment of the interference distribution do we care about?

When $\alpha$ is large, we care more about decreasing higher moments, and in the limit of infinite $\alpha$ what matters is just the maximum interference between vectors. Hence when $\alpha$ is large we want to have fewer subspaces, each with more vectors but smaller cosine similarities.

By contrast, when $\alpha$ is small, we care more about decreasing smaller moments, and in the limit as $\alpha \to 0$ what matters is the fraction of vectors that interfere at all. Hence when $\alpha$ is small we want to have more subspaces, each with fewer vectors but larger cosine similarities.

So tegum products are preferred when we can tolerate larger “peak” interference and want fewer instances of interference, whereas a single large subspace is preferred when we can tolerate lots of instances of interference and want to minimize the worst cases.
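A small illustration of the "which moment" reading, assuming the power-law loss $f(x) = x^\alpha$ and a made-up list of interference values. As $\alpha \to 0$ the mean of $x^\alpha$ tends to the fraction of pairs that interfere at all, and as $\alpha \to \infty$ the power mean $\left(\overline{x^\alpha}\right)^{1/\alpha}$ tends to the largest interference.

```python
def mean_power(values, alpha):
    """Mean of x**alpha over the values (for alpha > 0, exact zeros contribute 0)."""
    return sum(x ** alpha for x in values) / len(values)

def power_mean(values, alpha):
    """Generalized mean (mean of x**alpha) ** (1 / alpha)."""
    return mean_power(values, alpha) ** (1.0 / alpha)

# Hypothetical |cos theta| values: three orthogonal pairs, one mild and one bad overlap.
interference = [0.0, 0.0, 0.0, 0.1, 0.8]

print(mean_power(interference, 1e-6))   # close to 0.4, the fraction of nonzero entries
print(power_mean(interference, 200.0))  # close to 0.8, the largest entry
```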

Relation to Anthropic’s Results

In Anthropic’s Toy Model 2, the dimension of the subspaces increases the sparser the features get, meaning that $k$ falls. We can make sense of this by expanding the loss as they do in powers of the sparsity $S$:

$$L = \sum_{j=0}^{n} (1-S)^j S^{n-j} L_j,$$

where $L_j$ is the loss associated with $j$-sparse vectors. In the sparse limit $S \to 1$, so

$$L \approx S^n L_0 + (1-S) S^{n-1} L_1.$$

The term $L_0$ is a penalty on positive biases and the term $L_1$ is the loss on 1-sparse vectors. In this limit, the biases are negative (to minimize $L_0$), and this has the effect of wiping out the contribution of small interference terms in $L_1$. So the model is pushed to decrease the worst-case interference (which might overcome the bias) rather than minimize the average, corresponding to our large-$\alpha$ limit.

On the other hand, in the dense limit $S \to 0$, so

$$L \approx (1-S)^n L_n.$$

The term $L_n$ is the loss on dense vectors, which means there is interference between all pairs of vectors. This makes minimizing the average interference more important than minimizing the really bad cases (i.e. higher moments), so the model prefers lots of small subspaces, corresponding to our small-$\alpha$ limit.
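A sketch of which terms dominate in each limit, assuming the expansion $L = \sum_j (1-S)^j S^{n-j} L_j$ with each $L_j$ collecting the loss over inputs with exactly $j$ active features, and an arbitrary small $n = 5$. The helper names are hypothetical.

```python
def term_weights(S, n):
    """Coefficient (1 - S)**j * S**(n - j) multiplying L_j, for j = 0..n."""
    return [(1 - S) ** j * S ** (n - j) for j in range(n + 1)]

def dominant_terms(S, n, count=2):
    """Indices j of the `count` largest coefficients, largest first."""
    weights = term_weights(S, n)
    return sorted(range(n + 1), key=lambda j: weights[j], reverse=True)[:count]

n = 5  # hypothetical number of features
print(dominant_terms(0.99, n))  # sparse limit: L_0 and L_1 dominate
print(dominant_terms(0.01, n))  # dense limit: L_n dominates
```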

Just as the relevant limit varies with sparsity for a given toy model, we can also get different phenomenology for different models. This makes sense given that different setups can have different loss functions.

Summary

There are many ways for models to embed features. A surprising observation from Elhage+22 is that sometimes the optimal choice is one which divides the embedding space into many orthogonal subspaces (i.e. a tegum product). We can understand this roughly as coming from a tradeoff between minimizing higher moments of the feature interference (e.g. worst-case) and minimizing lower moments (e.g. average-case interference).

Smaller subspaces minimize the lower moments by making most pairs of vectors exactly orthogonal. The cost of this is that there is less freedom to choose vector pairs in each subspace, so there is worse interference between the pairs that do interfere.

Larger subspaces have the reverse tradeoff: they experience interference between more pairs of vectors, but it tends to be milder because larger-dimensional spaces support packing more nearly-orthogonal vectors, even at a fixed ratio of vectors-to-dimension.
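To make the last point concrete: in the small-$\epsilon$ limit the Johnson-Lindenstrauss relation $D = 4\ln N/(\epsilon^2/2 - \epsilon^3/3)$ reduces to $\epsilon \approx \sqrt{8\ln N/D}$, so at a fixed ratio $N/D$ the typical interference still shrinks as $D$ grows, because $\ln N$ grows much more slowly than $D$. The ratio of 10 below is an arbitrary illustrative choice and `eps_fixed_ratio` is a hypothetical helper.

```python
import math

def eps_fixed_ratio(D, ratio):
    """Small-eps JL estimate sqrt(8 ln N / D) for N = ratio * D vectors in D dims."""
    return math.sqrt(8.0 * math.log(ratio * D) / D)

for D in (1_000, 10_000, 100_000):
    print(D, round(eps_fixed_ratio(D, ratio=10), 3))
```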


Thanks for writing this! I found this a really helpful post for clarifying my own intuitions. Trying to operationalise what confused me before, and what now feels clear:

Confusion: Why does the model want to split vectors into these orthogonal subspaces? This seems somewhat unnatural and wasteful - it loses a lot of degrees of freedom, and surely it wants to spread out and minimise interference as much as possible?

Implicitly, I was imagining something like L2 loss where the model wants to minimise the sum of squared dot products.

New intuition: There is no inherently correct solution to this problem! It all depends on the precise loss function (or, the impact of each pairwise interference on the loss function). If the model has 100 dimensions and needs to fit in 1000 vectors, it can do this by spreading all 1000 across all 100 dimensions, or by packing 500 into the first 50 and 500 into the second 50. The second approach immediately makes 500^2 dot products exactly 0, at the cost of increasing the dot products within each partition of 500.

Intuitively, there's going to be some kind of conservation property affecting the total amount of interference, but the model can choose to allocate that towards minimising the number of significant interferences or the maximum interference. Smearing it across all dimensions minimises the maximum, forming a partition minimises the number. So the choice depends on the model's exact loss function.

In practice, the model's loss function will be really complicated - for any pair of features, cost of interference goes up if they're correlated, up if either is important, down if either is sparse, and down if the model can allocate some parameters to denoising the interference. Importantly, for the ones to do with correlation, interference between correlated features will be way worse, so the model wants to find ways to minimise the max interference, and is happy to tolerate a lot of interference between uncorrelated features. Which means the optimal packing probably involves tegum products, because it's a nice hack to efficiently get lots of the interference terms to zero.

Probably my biggest remaining confusion is why tegum products are the best way to get a lot of interference terms to zero, rather than just some clever packing smeared across all dimensions.

That's good to hear! And I agree with your new intuition.

I think if you want interference terms to actually be zero you have to end up with tegum products, because that means you want orthogonal vectors and that implies disjoint subspaces. Right?

I don't think so? If you have eg 8 vectors arranged evenly in a 2D plane (so at 45 degrees to each other) there's a lot of orthogonality, but no tegum product. I think the key weirdness of a tegum product is that it's a partition, where every pair in different bits of the partition is orthogonal. I could totally imagine that eg the best way to fit 2n vectors in n-dimensional space is two sets of n orthogonal vectors, but at some arbitrary angle to each other.

I can believe that tegum products are the right way to maximise the number of orthogonal pairs, though that still feels a bit weird to me. (technically, I think that the optimal way to fit kn vectors in R^n is to have n orthogonal directions and k vectors along each direction, maybe with different magnitudes - which is a tegum product. It forming 2D-3D subspaces feels odd though).

Oh yes you're totally right.

I think partitions can get you more orthogonality than your specific example of overlapping orthogonal sets. Take n vectors and pack them into d dimensions in two ways:

  1. A tegum product with k subspaces, giving (n/k) vectors per subspace and n^2*(1-1/k) orthogonal pairs.
  2. (n/d) sets of d vectors, each set internally orthogonal but overlapping with the others, giving n*d orthogonal pairs.

If d < n*(1-1/k) the tegum product buys you more orthogonal pairs. If n > d then picking large k (so low-dimensional spaces) makes the tegum product preferred.
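A quick check of these counts, as a sketch: the formulas count ordered pairs, the overlapping construction uses hypothetical random orthonormal bases built by Gram-Schmidt, and its exact count is $n(d-1)$, which rounds to the $n d$ above. The last line plugs in the 1000-vectors-in-100-dimensions example from earlier in the thread.

```python
import math
import random

def tegum_orthogonal_pairs(n, k):
    """Ordered orthogonal pairs in a tegum product: every cross-subspace pair."""
    return n * (n - n // k)

def overlapping_sets_orthogonal_pairs(n, d):
    """Ordered orthogonal pairs for n/d orthonormal bases at generic angles."""
    return n * (d - 1)

def random_orthonormal_basis(d, rng):
    """Gram-Schmidt on random Gaussian vectors (almost surely independent)."""
    basis = []
    while len(basis) < d:
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for b in basis:
            proj = sum(x * y for x, y in zip(v, b))
            v = [x - proj * y for x, y in zip(v, b)]
        norm = math.sqrt(sum(x * x for x in v))
        basis.append([x / norm for x in v])
    return basis

def count_orthogonal(vectors, tol=1e-9):
    """Ordered pairs i != j with |dot product| below tol."""
    return sum(
        1
        for i, u in enumerate(vectors)
        for j, v in enumerate(vectors)
        if i != j and abs(sum(x * y for x, y in zip(u, v))) < tol
    )

rng = random.Random(0)
n, d = 12, 4
vecs = [v for _ in range(n // d) for v in random_orthonormal_basis(d, rng)]
print(count_orthogonal(vecs), overlapping_sets_orthogonal_pairs(n, d))
print(tegum_orthogonal_pairs(1000, 2), overlapping_sets_orthogonal_pairs(1000, 100))
```

The 500,000 ordered pairs for the two-way split are the 500^2 unordered pairs counted in both orders.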

This doesn't mean there isn't some other arrangement that does better though...

Yeah, agreed that's not an optimal arrangement, that was just a proof of concept for 'non-tegum things can get a lot of orthogonality'.

Thanks for the great post! I have a question, if it's not too much trouble: 

Sorry for my confusion about something so silly, but shouldn't the following be "when $\alpha < 2$"?

When $\alpha > 2$ there is no place where the derivative vanishes

I'm also a bit confused about why we can think of $\alpha$ as representing "which moment of the interference distribution we care about."

Perhaps some of my confusion here stems from the fact that it seems to me that the optimal number of subspaces, $k$, is an increasing function of $\alpha$, which doesn't seem to line up with the following:

Hence when $\alpha$ is large we want to have fewer subspaces

What am I missing here?

Sorry for my confusion about something so silly, but shouldn't the following be "when $\alpha < 2$"?

Oh you're totally right. And k=1 should be k=d there. I'll edit in a fix.

I'm also a bit confused about why we can think of $\alpha$ as representing "which moment of the interference distribution we care about."

It's not precisely which moment, but as we vary $\alpha$ the moment(s) of interest vary monotonically.

Perhaps some of my confusion here stems from the fact that it seems to me that the optimal number of subspaces, $k$, is an increasing function of $\alpha$, which doesn't seem to line up with the following:

This comment turned into a fascinating rabbit hole for me, so thank you!

It turns out that there is another term in the Johnson-Lindenstrauss expression that's important. Specifically, the relation between $\epsilon$, $N$, and $D$ should be $D \geq 4\ln N/(\epsilon^2/2 - \epsilon^3/3)$ (per Scikit and references therein). The numerical constants aren't important, but the cubic term is, because it means the interference grows rather faster as  grows (especially in the vicinity of $\epsilon \approx 1$).

With this correction it's no longer feasible to do things analytically, but we can still do things numerically. The plots below are made with :

The top panel shows the normalized loss for a few different $\alpha$, and the lower shows the loss derivative with respect to $k$. Note that the range of $k$ is set by the real roots of $\epsilon^2/2 - \epsilon^3/3 = 4k\ln(n/k)/d$: for larger $k$ there are no real roots with $0 < \epsilon \leq 1$, which corresponds to the interference $\epsilon$ crossing unity. In practice this bound applies well before $k = d$. Intuitively, if there are more vectors than dimensions then the interference becomes order-unity (so there is no information left!) well before the subspace dimension falls to unity.

Anyway, all of these curves have global minima in the interior of the domain (if just barely for ), and the minima move to the left (toward smaller $k$) as $\alpha$ rises. That is, for  we care increasingly about higher moments as we increase $\alpha$ and so we want fewer subspaces.

What happens for ?

The global minima disappear! Now the optimum is always . In fact, though, the transition is no longer at $\alpha = 2$ but a little higher:

So the basic story still holds, but none of the math involved in finding the optimum applies!

I'll edit the post to make this clear.