The notion of a preferred (linear) transformation for interpretability has been called a "privileged basis" in the mechanistic interpretability literature. See for example Softmax Linear Units, where the idea is discussed at length.
In practice, the typical reason to expect a privileged basis is in fact SGD – or more precisely, the choice of architecture. Specifically, activation functions such as ReLU often privilege the standard basis. I would not generally expect the data or the initialization to privilege any basis beyond the start of the network or the start of training. The data may itself have a privileged basis, but this should be lost as soon as the first linear layer is reached. The initialization is usually Gaussian and hence isotropic anyway, but if it did have a privileged basis I would also expect this to be quickly lost without some other reason to hold onto it.
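For concreteness, here is a tiny numpy sketch (my own illustration, not from the comment) of why an elementwise nonlinearity like ReLU privileges the standard basis: it does not commute with a generic rotation, so rotating the hidden activations changes what the layer computes.

```python
# Minimal sketch: an elementwise ReLU does not commute with a generic rotation,
# which is why it singles out the standard basis.
import numpy as np

rng = np.random.default_rng(0)
relu = lambda z: np.maximum(z, 0.0)

x = rng.normal(size=8)
R, _ = np.linalg.qr(rng.normal(size=(8, 8)))  # random orthogonal matrix

lhs = relu(R @ x)             # rotate, then apply ReLU
rhs = R @ relu(x)             # apply ReLU, then rotate
print(np.allclose(lhs, rhs))  # False for a generic rotation
```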
Yeah, I'm familiar with privileged bases. Once we generalize to a whole privileged coordinate system, the ReLUs are no longer enough.
Isotropy of the initialization distribution still applies, but the key is that we only get to pick one rotation for the parameters, and that same rotation has to be used for all data points. That constraint is baked into the framing when thinking about privileged bases, but it has to be derived when thinking about privileged coordinate systems.
The data may itself have a privileged basis, but this should be lost as soon as the first linear layer is reached.
Not totally lost if the layer is e.g. a convolutional layer, because while the pixels within the convolutional window can get arbitrarily scrambled, it is not possible for a convolutional layer to scramble things across different windows in different parts of the picture.
Agreed. Likewise, in a transformer, the token dimension should maintain some relationship with the input and output tokens. This is sometimes taken for granted, but it is a good example of the data preferring a coordinate system. My remark that you quoted only really applies to the channel dimension, across which layers typically scramble everything.
I think we can get additional information from the topological representation. We can look at the relationship between the different level sets under different cumulative probabilities, although this requires evaluating the model over the whole dataset.
Let's say we've trained a continuous normalizing flow model (these are defined by ordinary differential equations). Models of this kind require that the input and output dimensionality be the same, but we can narrow the model as the depth increases by directing many of those dimensions to isotropic Gaussian noise. I haven't trained any of these models before, so I don't know if this works in practice.
Here is an example of the topology of an input space. The data may be knotted or tangled, and includes noise. The contours show level sets.
The model projects the data into a high dimensionality, then projects it back down into an arbitrary basis, but in the process untangling knots. (We can regularize the model to use the minimum number of dimensions by using an L1 activation loss.)
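As a rough illustration of the kind of penalty meant here (the function name and form of the loss are my own choices, just a sketch of an L1 activation loss):

```python
# Hedged sketch of an L1 activation penalty: penalizing the mean absolute
# activation pushes the model to route information through as few
# coordinates as possible, leaving the rest near zero.
import numpy as np

def l1_activation_penalty(activations: np.ndarray, lam: float = 1e-3) -> float:
    # activations: (batch, dim) intermediate states of the model
    return lam * float(np.abs(activations).mean())

# total_loss = task_loss + l1_activation_penalty(intermediate_states)
```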
Lastly, we can view this topology as the Cartesian product of noise distributions and a hierarchical model. (I have some ideas for GAN losses that might be able to discover these directly.)
We can use topological structures like these as anchors. If a model is strong enough, they will correspond to real relationships between natural classes. This means that very similar structures will be present in different models. If these structures are large enough or heterogeneous enough, they may be unique, in which case we can use them to find transformations between (subspaces of) the latent spaces of two different models trained on similar data.
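As a sketch of that last step, assuming we have already matched up corresponding points on such a structure in the two models (the matching itself is the hard part, and the setup below is my own illustration), orthogonal Procrustes gives the best rotation between the two latent subspaces:

```python
# Sketch: given matched point sets A and B (coordinates of corresponding points
# on a shared structure in two latent spaces), orthogonal Procrustes recovers
# the rotation that best maps one onto the other.
import numpy as np

def procrustes_map(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """A, B: (n_points, dim) matched coordinates in two latent spaces."""
    U, _, Vt = np.linalg.svd(B.T @ A)
    return U @ Vt  # rotation W minimizing ||A @ W.T - B||_F

rng = np.random.default_rng(0)
A = rng.normal(size=(100, 16))
R_true, _ = np.linalg.qr(rng.normal(size=(16, 16)))   # hidden "ground truth" map
B = A @ R_true.T + 0.01 * rng.normal(size=(100, 16))  # same structure, plus noise
W = procrustes_map(A, B)
print(np.allclose(W, R_true, atol=0.05))              # True: map recovered
```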
Some interpretability work assigns meaning to activations of individual neurons or small groups of neurons. Some interpretability work assigns meaning to directions in activation-space. These are two different ontologies through which to view a net’s internals. Probably neither is really the “right” ontology, and there is at least one other ontology which would strictly outperform both of them in terms of yielding accurate interpretable structure.
One of the core problems of interpretability (I would argue the core problem) is that we don’t know what the “right” internal ontology is for a net - which internal structures we should assign meaning to. The goal of this post is to ask what things we could possibly assign meaning to under a maximally-general ontology constraint: coordinate freedom.
What Does Coordinate Freedom Mean?
Let's think of a net as a sequence of activation-states $x_i$, with the layer $i \to$ layer $i+1$ function given by $x_{i+1} = f_i(x_i)$.
We could use some other coordinate system to represent each $x_i$. For instance, we could use (high dimensional) polar coordinates, with $r_i = \|x_i\|$ and $\phi_i$ a high-dimensional angle (e.g. all but one entry of a unit vector). Or, we could apply some fixed rotation to $x_i$, e.g. in an attempt to find a basis which makes things sparse. In general, in order to represent $x_i$ in some other coordinate system, we apply a reversible transformation $x_i' = g_i(x_i)$, where $x_i'$ is the representation under the new coordinate system. In order to use these new coordinates while keeping the net the same overall, we transform the layer transition functions:
$$f'_{i-1} = g_i \circ f_{i-1}$$
$$f'_i = f_i \circ g_i^{-1}$$
In English: we transform into the new coordinate system when calculating the layer state $x_i$, and undo that transformation when computing the next layer state $x_{i+1}$. That way, the overall behavior remains the same while using new coordinates in the middle.
The basic idea of coordinate freedom is that our interpretability tools should not depend on which coordinate system we use for any of the internal states. We should be able to transform any layer to any coordinate system, and our interpretability procedure should still assign the same meaning to the same (transformed) internal structures.
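Here's a toy numerical check of that bookkeeping (the two-layer toy net is just for illustration): absorb an invertible map $g$ into one layer and $g^{-1}$ into the next, and end-to-end behavior is unchanged.

```python
# Toy check that a reversible change of coordinates at layer 1, absorbed into
# the adjacent layer functions, leaves the net's end-to-end behavior unchanged.
import numpy as np

rng = np.random.default_rng(0)
relu = lambda z: np.maximum(z, 0.0)

W0, W1 = rng.normal(size=(16, 16)), rng.normal(size=(16, 16))
f0 = lambda x: relu(W0 @ x)          # layer 0 -> 1
f1 = lambda x: relu(W1 @ x)          # layer 1 -> 2

M = rng.normal(size=(16, 16))        # generic invertible matrix
g = lambda x: M @ x                  # change of coordinates at layer 1
g_inv = lambda x: np.linalg.solve(M, x)

f0_prime = lambda x: g(f0(x))        # f'_0 = g o f_0
f1_prime = lambda x: f1(g_inv(x))    # f'_1 = f_1 o g^{-1}

x = rng.normal(size=16)
print(np.allclose(f1(f0(x)), f1_prime(f0_prime(x))))  # True (up to float error)
```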
What Kind Of Coordinate Free Internal Structure Is Even Possible?
Here's one example of a coordinate free internal structure one could look for: maybe the layer $i \to i+1$ function can be written as
$$f_i(x_i) = F(G(x_i))$$
for some low-dimensional $G$. For instance, maybe $x_i$ and $x_{i+1}$ are both 512-dimensional, but $x_{i+1}$ can be calculated (to reasonable precision) from a 22-dimensional summary $G(x_i)$. We call this a low-dimensional “factorization” of $f_i$.
(Side note: I’m assuming throughout this post that everything is differentiable. Begone, pedantic mathematicians; you know what you were thinking.)
This kind of structure is also easy to detect in practice: just calculate the singular value decomposition of the jacobian $df_i/dx_i$ at a bunch of points, and see whether the jacobian is consistently (approximately) low rank. In other words, do the obvious thing which we were going to do anyway.
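A minimal sketch of that procedure, on a toy layer deliberately built to factor through a 3-dimensional summary (finite differences here stand in for whatever autodiff you'd actually use):

```python
# Estimate the layer's jacobian at a few points and look at its singular values.
# The toy layer factors through a 3-dimensional summary, so its jacobian has
# rank <= 3 everywhere.
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 32))    # G: 32 -> 3 summary
B = rng.normal(size=(32, 3))    # F: 3 -> 32
f = lambda x: B @ np.tanh(A @ x)

def jacobian(f, x, eps=1e-5):
    n = x.size
    J = np.empty((f(x).size, n))
    for j in range(n):
        dx = np.zeros(n); dx[j] = eps
        J[:, j] = (f(x + dx) - f(x - dx)) / (2 * eps)  # central differences
    return J

for _ in range(3):
    x = rng.normal(size=32)
    s = np.linalg.svd(jacobian(f, x), compute_uv=False)
    print(np.sum(s > 1e-6))     # numerical rank: 3 at every sampled point
```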
Why is this structure coordinate free? Well, no matter how we transform $x_i$ and $x_{i+1}$, so long as the coordinate changes are reversible, the transformed function $f'_i$ will still factor through a low-dimensional summary. Indeed, it will factor through the same low-dimensional summary, up to isomorphism. We can also see the corresponding fact in the first-order approximation: we can multiply the jacobian on the left and right by any invertible matrix, its rank won't change, and low-rank components will be transformed by the transformation matrices.
… and as far as local structure goes (i.e. first-order approximation near any given point), that completes the list of coordinate free internal structures. It all boils down to just that one (and things which can be derived/constructed from that one). Here’s the argument: by choosing our coordinate transformations, we can make the jacobian anything we please, so long as the rank and dimensions of the matrix stay the same. The rank is the only feature we can’t change.
But that’s only a local argument. Are there any other nonlocal coordinate free structures?
Are There Any Other Coordinate Free Internal Structures?
Let's switch to the discrete case for a moment. Before we had $f_i(x_i)$ mapping from a 512-dimensional space to a 512-dimensional space, but factoring through a 22-dimensional “summary”. A simple (and smaller) discrete analogue would be a function $f_i(x_i)$ which maps the five possible values {1, 2, 3, 4, 5} to the same five values, but factors through a 2-value summary. For instance, maybe the function maps like this:
Coordinate freedom means we can relabel the 1, 2, 3, 4, 5 any way we please, on the input or output side. While maintaining coordinate freedom, we can still identify whether the function factors through some “smaller” intermediate set - in this case the set {“a”, “b”}. Are there any other coordinate free structures we can identify? Or, to put it differently: if two functions factor through the same intermediate sets, does that imply that there exists some reversible coordinate transformation between the two?
It turns out that we can find an additional structure. Here’s another function from {1, 2, 3, 4, 5} to itself, which factors through the same intermediate sets as our previous function, but is not equivalent under any reversible coordinate transformation:
Why is this not equivalent? Well, no matter how we transform the input set in the first function, we’ll always find that three input values map to one output value, and the other two input values map to another output value. The “level sets” - i.e. sets of inputs which map to the same output - have size 3 and 2, no matter what coordinates we use. Whereas, for the second function, the level sets have size 4 and 1.
Does that complete the list of coordinate free internal structures in the discrete case? Yes: if we have two functions with the same level set sizes, whose input and output spaces are the same size, then we can reversibly map between them. Just choose the coordinate transformation to match up level sets of the same size, and then match up the corresponding outputs.
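A quick sanity check of that claim in code. The specific mappings below are my own stand-ins, chosen to have level sets of sizes 3 + 2 and 4 + 1 as described above:

```python
# Two functions on {1,...,5} are related by relabelings of the input and output
# iff their multisets of level-set sizes match.
from collections import Counter

def level_set_sizes(f: dict) -> list:
    return sorted(Counter(f.values()).values())

f_a = {1: 1, 2: 1, 3: 1, 4: 4, 5: 4}   # level sets of size 3 and 2
f_b = {1: 2, 2: 2, 3: 2, 4: 2, 5: 5}   # level sets of size 4 and 1

print(level_set_sizes(f_a))  # [2, 3]
print(level_set_sizes(f_b))  # [1, 4]
print(level_set_sizes(f_a) == level_set_sizes(f_b))  # False: not equivalent
```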
Ok, so that’s the discrete case. Switching back to the continuous case (and bringing back the differentiable transformation constraint), what other coordinate free internal structure might exist in a net?
Well, in the continuous case, “size of the level set” isn’t really relevant, since e.g. we can reversibly map the unit interval to the real line. But, since our transformations need to be smooth, topology is relevant - for instance, if the set of inputs which map to 0 is 1 dimensional, is it topologically a circle? A line? Two circles and a line? A knot?
Indeed, “structure which is invariant under smooth reversible transformation” is kinda the whole point of topology! Insofar as we want our interpretability tools to be coordinate free, topological features are exactly the structures to which we can try to assign meaning.
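To make that a bit more concrete, here's a rough way to estimate one very simple topological feature numerically: the number of connected components of a level set, by sampling points near it and linking samples that land close together. The function, thresholds, and sampling box are all hand-picked for this illustration; it's only a sketch, not a robust procedure.

```python
# Rough estimate of the number of connected components of a level set
# {x : f(x) ~= c}: rejection-sample points near it, link nearby samples,
# and count components with union-find.
import numpy as np

def components_near_level_set(f, c, n=20000, tol=0.05, link=0.3, seed=0):
    rng = np.random.default_rng(seed)
    pts = rng.uniform(-3, 3, size=(n, 2))
    pts = pts[np.abs(f(pts) - c) < tol]          # keep points near the level set
    parent = list(range(len(pts)))
    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]; i = parent[i]
        return i
    for i in range(len(pts)):                    # union nearby pairs
        d = np.linalg.norm(pts - pts[i], axis=1)
        for j in np.nonzero(d < link)[0]:
            parent[find(i)] = find(int(j))
    return len({find(i) for i in range(len(pts))})

# f(x) = (||x|| - 1)^2 has the level set f = 0.25 equal to two circles (radii 0.5, 1.5)
f = lambda p: (np.linalg.norm(p, axis=-1) - 1.0) ** 2
print(components_near_level_set(f, 0.25))        # expect 2
```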
Great, we’ve reinvented topology.
… So Now What?
There are some nontrivial things we can build up just from low-dimensional summaries between individual layers and topological features. But ultimately, I don't expect to unlock most of interpretability this way. I'd guess that low-dimensional summaries of the particular form relevant here unlock a bit less than half of interpretability (i.e. all the low-rank stuff, along the lines of the ROME paper), and other topological structures add a nonzero but small chunk on top of that. (For those who are into topology, I strongly encourage you to prove me wrong!) What's missing? Well, for instance, one type of structure which should definitely play a big role in a working theory of interpretability is sparsity. With full coordinate freedom, we can always choose coordinates in which the layer functions are sparse, and therefore we gain no information by finding sparsity in a net.
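A small illustration of that point for a single linear layer: the SVD hands us invertible input and output coordinate changes under which any dense weight matrix becomes diagonal, i.e. about as sparse as possible.

```python
# Under full coordinate freedom, "sparsity" of a linear layer carries no
# information: the SVD supplies coordinate changes that diagonalize any W.
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(16, 16))             # a dense, "non-sparse" linear layer
U, s, Vt = np.linalg.svd(W)

W_sparse = U.T @ W @ Vt.T                 # the layer in the new coordinates
print(np.allclose(W_sparse, np.diag(s)))  # True: diagonal, hence sparse
```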
So let’s assume we can’t get everything we want from pure coordinate free interpretability. Somehow, we need to restrict allowable transformations further. Next interesting question: where might a preferred coordinate system or an additional restriction on transformations come from?
One possible answer: the data. We’ve implicitly assumed that we can apply arbitrary coordinate transformations to the data, but that doesn’t necessarily make sense. Something like a stream of text or an image does have a bunch of meaningful structure in it (like e.g. nearby-ness of two pixels in an image) which would be lost under arbitrary transformations. So one natural next step is to allow coordinate preference to be inherited from the data. On the other hand, we’d be importing our own knowledge of structure in the data; really, we’d prefer to only use the knowledge learned by the net.
Another possible answer: the initialization distribution of the net parameters. For instance, there will always be some coordinate transformation which makes every layer sparse, but maybe that transformation is highly sensitive to the parameter values. That would indicate that any interpretation which relies on that coordinate system is not very robust; some small change in the parameters $\theta$ which leaves network behavior roughly the same could totally change the sparsifying coordinate system. To avoid that, we could restrict ourselves to transformations which are not very parameter-sensitive. I currently consider that the most promising direction.
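One way this could be operationalized (a sketch with my own choices, not a worked-out proposal): perturb the parameters slightly, recompute the coordinate system, and measure how far the relevant subspaces move, e.g. via principal angles.

```python
# Sketch: how parameter-sensitive is a chosen coordinate system? Perturb the
# weights slightly and compare the top singular subspaces before and after via
# principal angles. Large angles would mean any interpretation tied to that
# basis is fragile.
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(16, 16))
W_perturbed = W + 1e-3 * rng.normal(size=(16, 16))

def top_left_subspace(M, k=4):
    U, _, _ = np.linalg.svd(M)
    return U[:, :k]

Q0, Q1 = top_left_subspace(W), top_left_subspace(W_perturbed)
cosines = np.linalg.svd(Q0.T @ Q1, compute_uv=False)       # cos of principal angles
angles = np.degrees(np.arccos(np.clip(cosines, -1.0, 1.0)))
print(angles)   # small angles -> the coordinate choice is robust to this perturbation
```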
The last answer I currently see is SGD. We could maybe argue that SGD introduces a preferred coordinate system, but then the right move is probably to look at the whole training process in a coordinate free way rather than just the trained net by itself. That does sound potentially useful, although my guess is that it mostly just reproduces the parameter-sensitivity thing.
Meta note: I’d be surprised if the stuff in this post hasn’t already been done; it’s one of those things where it’s easy and obvious enough that it’s faster to spend a day or two doing it than to find someone else who’s done it. If you know of a clean write-up somewhere, please do leave a link, I’d like to check whether I missed anything crucial.