I think working on mechanistic interpretability in a variety of domains, architectures, and modalities seems like a reasonable research diversification bet.
However, it feels pretty odd to me to describe branching out into other modalities as crucial when we haven't yet really done anything useful with mechanistic interpretability in any domain or for any task.
You say:
With recent rapid releases of multimodal models, including Sora, Gemini, and Claude 3, it is crucial that interpretability and safety efforts remain in tandem. While language mechanistic interpretability already has strong conceptual foundations, many research papers, and a thriving community, research in non-language modalities lags behind. Given that multimodal capabilities will be part of AGI, field-building in mechanistic interpretability for non-language modalities is crucial for safety and alignment.
And on X/twitter:
Frontier models are multimodal, and it's increasingly clear that mechanistic interpretability can't only study language models.
But, I feel like the situation is relatively analogous to:
Fusion power plants will need to be built in many countries, and it's increasingly clear that fusion research can't only study building fusion power plants in the US.
Like yeah, you'll eventually need to handle non-language modalities, and you should probably sanity check that they aren't key additional blockers for the methodology, but why would there be key methodological differences that mean mech interp can solve our problems in the language case but not the vision/multimodal case? And isn't the main obstacle demonstrating basic technical feasibility, rather than branching out?
Again, I'd like to stress that studying a variety of cases with mech interp seems like a reasonable research diversification bet.
(And I don't want to be the language police here, just pushing back a bit on the implicit vibes.)
There is another argument that could be made for working on other modalities now: there could be insights which generalize across modalities, but which are easier to discover when working on some modalities vs. others.
I've actually been thinking, for a while now, that people should do more image model interpretability for this sort of reason. I never got around to posting this opinion, but FWIW it is the main reason I'm personally excited by the sort of work reported here. (I have mostly been thinking about generative or autoencoding image models here, rather than classifiers, but the OP says they're building toward that.)
Why would we expect there to be transferable insights that are easier to discover in visual domains than textual domains? I have two thoughts in mind:
First thought:
The tradeoff curve between "model does something impressive/useful that we want to understand" and "model is conveniently small/simple/etc." looks more appealing in the image domain.
Most obviously: if you pick a generative image model and an LLM which do "comparably impressive" things in their respective domains, the image model is going to be way smaller (cf.). So there are, in a very literal way, fewer things we have to interpret -- and a smaller gap between the smallest toy models we can make and the impressive models which are our holy grails.
Like, Stable Diffusion is definitely not a toy model, and does lots of humanlike things very well. Yet it's pretty tiny by LLM standards. Moreover, the SD autoencoder is really tiny, and yet it would be a huge deal if we could come to understand it pretty well.
Beyond mere parameter count, image models have another advantage: the relative ease of constructing non-toy input data for which we know the optimal output. For example, this is true of autoencoders, where the ideal output is simply a faithful reconstruction of the input.
By contrast, in language modeling and classification, we really have no idea what the optimal logits are. So we are limited to making coarse qualitative judgments of logit effects ("it makes this token more likely, which makes sense"), ignoring the important fine-grained quantitative stuff that the model is doing.
None of that is intrinsically about the image domain, I suppose; for instance, one can make text autoencoders too (and people do). But in the image domain, these nice properties come for free with some of the "real" / impressive models we ultimately want to interpret. We don't have to compromise on the realism/relevance of the models we choose for ease of interpretation; sometimes the realistic/relevant models are already convenient for interpretability, as a happy accident. The capabilities people just make them that way, for their own reasons.
The hope, I guess, is that if we came pretty close to "fully understanding" one of these more convenient models, we'd learn a lot of stuff along the way about how to interpret models in general, and that would transfer back to the language domain. Stuff like "we don't know what the logits should be" would no longer be a blocker to making progress on other fronts, even if we do eventually have to surmount that challenge to interpret LLMs. (If we had a much better understanding of everything else, a challenge like that might be more tractable in isolation.)
Second thought:
I have a hunch that the apparent intuitive transparency of language (and tasks expressed in language) might be holding back LLM interpretability.
If we force ourselves to do interpretability in a domain which doesn't have so much pre-existing taxonomical/terminological baggage -- a domain where we no longer feel it's intuitively clear what the "right" concepts are, or even what any breakdown into concepts could look like -- we may learn useful lessons about how to make sense of LLMs when they aren't "merely" breaking language and the world down into conceptual blocks we find familiar and immediately legible.
When I say that "apparent intuitive transparency" affects LLM interpretability work, I'm thinking of choices like:
In both of these lines of work, there's a temptation to try to parse out the LLM computation into operations on parts we already have names for -- and, in cases where this doesn't work, to chalk it up either to our methods failing, or to the LLM doing something "bizarre" or "inhuman" or "heuristic / unsystematic."
But I expect that much of what LLMs do will not be parseable in this way. I expect that the edge that LLMs have over pre-DL AI is not just about more accurate extractors for familiar, "interpretable" features; it's about inventing a decomposition of language/reality into features that is richer, better than anything humans have come up with. Such a decomposition will contain lots of valuable-but-unfamiliar "frobnoloid"-type stuff, and we'll have to cope with it.
To loop back to images: relative to text, with images we have very little in the way of pre-conceived ideas about how the domain should be broken down conceptually.
Like, what even is an "interpretable image feature"?
Maybe this question has some obvious answers when we're talking about image classifiers, where we expect features related to the (familiar-by-design) class taxonomy -- cf. the "floppy ear detectors" and so forth in the original Circuits work.
But once we move to generative / autoencoding / etc. models, we have a relative dearth of pre-conceived concepts. Insofar as these models are doing tasks that humans also do, they are doing tasks which humans have not extensively "theorized" and parsed into concept taxonomies, unlike language and math/code and so on. Some of this conceptual work has been done by visual artists, or photographers, or lighting experts, or scientists who study the visual system ... but those separate expert vocabularies don't live on any single familiar map, and I expect that they cover relatively little of the full territory.
When I prompt a generative image model, and inspect the results, I become immediately aware of a large gap between the amount of structure I recognize and the amount of structure I have names for. I find myself wanting to say, over and over, "ooh, it knows how to do that, and that!" -- while knowing that, if someone were to ask, I would not be able to spell out what I mean by each of these "that"s.
Maybe I am just showing my own ignorance of art, and optics, and so forth, here; maybe a person with the right background would look at the "features" I notice in these images, and find them as familiar and easy to name as the standout interpretable features from a recent LM SAE. But I doubt that's the whole of the story. I think image tasks really do involve a larger fraction of nameless-but-useful, frobnoloid-style concepts. And the sooner we learn how to deal with those concepts -- as represented and used within NNs -- the better.
Join our Discord here.
This article was written by Sonia Joseph, in collaboration with Neel Nanda, and incubated in Blake Richards’s lab at Mila and in the MATS community. Thank you to the Prisma core contributors, including Praneet Suresh, Rob Graham, and Yash Vadi.
Full acknowledgements of contributors are at the end. I am grateful to my collaborators for their guidance and feedback.
Introducing the Prisma Library for Multimodal Mechanistic Interpretability
I am excited to share with the mechanistic interpretability and alignment communities a project I’ve been working on for the last few months. Prisma is a multimodal mechanistic interpretability library based on TransformerLens, currently supporting vanilla vision transformers (ViTs) and their vision-text counterpart, CLIP.
With recent rapid releases of multimodal models, including Sora, Gemini, and Claude 3, it is crucial that interpretability and safety efforts remain in tandem. While language mechanistic interpretability already has strong conceptual foundations, many research papers, and a thriving community, research in non-language modalities lags behind. Given that multimodal capabilities will be part of AGI, field-building in mechanistic interpretability for non-language modalities is crucial for safety and alignment.
The goal of Prisma is to make research in mechanistic interpretability for multimodal models both easy and fun. We are also building a strong and collaborative open source research community around Prisma. You can join our Discord here.
This post includes a brief overview of the library, fleshes out some concrete problems, and gives steps for people to get started.
Prisma Goals
Tutorial Notebooks
To get started, you can check out three tutorial notebooks that show how Prisma works.
Brief ViT Overview
A vision transformer (ViT) is an architecture designed for image classification tasks, similar to the classic transformer architecture used in language models. A ViT consists of transformer blocks; each block consists of an Attention layer and an MLP layer.
Unlike language models, vision transformers do not have a dictionary-style embedding and unembedding matrix. Instead, images are divided into non-overlapping patches, similar to tokens in language models. These patches are flattened and linearly projected to embeddings via a Conv2D layer, akin to word embeddings in language models. A learnable class token (CLS token) is prepended at the start of the sequence, which accrues global information throughout the network. A learned position embedding is added to the patch embeddings.
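To make the shape bookkeeping concrete, here is a minimal sketch (plain PyTorch, not Prisma's actual implementation) of the patch embedding, CLS token, and position embedding just described, assuming a 224x224 image and 32x32 patches:

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=32, in_chans=3, d_model=768):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2           # 49 patches for 224 / 32
        self.proj = nn.Conv2d(in_chans, d_model, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, d_model))

    def forward(self, x):                                         # x: [batch, 3, 224, 224]
        x = self.proj(x)                                          # [batch, d_model, 7, 7]
        x = x.flatten(2).transpose(1, 2)                          # [batch, 49, d_model]
        cls = self.cls_token.expand(x.shape[0], -1, -1)           # [batch, 1, d_model]
        return torch.cat([cls, x], dim=1) + self.pos_embed        # [batch, 50, d_model]

tokens = PatchEmbed()(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # torch.Size([1, 50, 768])
```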
The patch embeddings then pass through the transformer blocks (each block consists of a LayerNorm, an Attention layer, another LayerNorm, and an MLP layer). Each block’s output is added back to its input; this running sum, which flows through the whole network, is called the residual stream.
The final layer of this vision transformer is a classification head with 1000 logit values for ImageNet's 1000 classes. The CLS token is fed into the final layer for 1000-way classification. Adapting TransformerLens, we designed HookedViT to easily capture intermediate activations with custom hook functions, instead of dealing with PyTorch's normal hook functionality.
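A hedged sketch of the intended workflow, assuming HookedViT exposes a TransformerLens-style interface (from_pretrained, run_with_cache); the import path, checkpoint name, and hook names below are assumptions and may differ from the released library:

```python
import torch
from vit_prisma.models.base_vit import HookedViT            # import path is an assumption

model = HookedViT.from_pretrained("vit_base_patch32_224")   # hypothetical checkpoint name
images = torch.randn(1, 3, 224, 224)                         # stand-in for a real image batch

logits, cache = model.run_with_cache(images)                 # logits: [1, 1000]
# cache holds intermediate activations keyed by hook name, e.g. the residual
# stream before block 9 (the "9_pre" notation used later in this post):
resid_pre_9 = cache["blocks.9.hook_resid_pre"]               # [1, 50, d_model]
```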
Prisma Functionality
We’ll demonstrate the functionality with some preliminary research results. The plots are all interactive but the LW site does not let me render HTML. See the original post for interactive graphs.
Emoji Logit Lens
The emoji logit lens is a convenient way to visualize patch-level predictions for each layer of the net.
We treat every patch like the CLS token, and feed it into the ViT’s 1000-way classification head that’s pre-trained on ImageNet, without fine-tuning. This is equivalent to deleting all layers between the layer of your choice and the output classification head.
For convenience, we represent the ImageNet prediction of that patch with its corresponding emoji, drawing from our ImageNet-Emoji Dictionary.
Below are the patch-level predictions of the final layer of a ViT for an image of a cat sitting inside a toilet. Yellow means the top logit was high and blue means it was low (see the Emoji Logit Lens notebook for more details).
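Conceptually, the computation is simple. Below is a minimal sketch (not the library's implementation), assuming you have a cached residual stream `resid` of shape [batch, 50, d_model] plus the model's final LayerNorm and ImageNet classification head:

```python
def patch_logit_lens(resid, ln_final, head):
    """Apply the CLS classification head to every patch position of a residual stream."""
    patch_resid = resid[:, 1:, :]                 # drop the CLS token: [batch, 49, d_model]
    logits = head(ln_final(patch_resid))          # [batch, 49, 1000] ImageNet logits per patch
    preds = logits.argmax(dim=-1)                 # per-patch predicted class (mapped to an emoji)
    top_logits = logits.max(dim=-1).values        # per-patch top logit (the colour scale above)
    # reshape to the 7x7 patch grid of a 224x224 image with 32x32 patches
    return preds.reshape(-1, 7, 7), top_logits.reshape(-1, 7, 7)
```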
Emergent Segmentation
One of my favorite findings so far is that the patch-level logit lens on the image basically acts as a segmentation map. For the image above, the cat patches get classified as cat, and the toilet patches get classified as a toilet!
This is not an obvious result, as vision transformers are optimized to predict a single class with the CLS token, and not segment the image. The segmentation is an emergent property. See the Emoji Logit Lens Notebook for more details and an interactive visualization.
Similar emergent segmentation capabilities were recently reported by Gandelsman, Efros, and Steinhardt (2024), who found that decomposing CLIP's image representation across spatial locations yields zero-shot semantic segmentation masks that outperform prior methods. Our results extend this finding to vanilla vision transformers and provide an intuitive visualization using the emoji logit lens.
We can see similar results on other images.
(Note: For visualization purposes, I’ve changed the coloring to be by emoji class instead of logit value like above; see Emoji Logit Lens notebook for details.)
Interestingly, the net has some biased predictions (“abaya” for the children, perhaps due to their ethnicity), one consequence of only having a 1000-class vocabulary to span concept-space.
Funnily, the net thinks that the center of the green apple (above image, bottom left) is a bagel.
When we do a layer-by-layer logit lens, we see the net’s evolving predictions:
Interestingly, the net picks up on the “animal” at 9_pre (the residual stream before the 9th transformer block) but classifies the cat as a dog. The net only catches onto the cat at 10_pre.
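The same lens can be swept across depth. A hedged sketch, reusing `cache` and `patch_logit_lens` from the earlier sketches (attribute names such as `model.ln_final` and `model.head` are assumptions):

```python
layerwise_preds = {}
for i in range(model.cfg.n_layers):                              # cfg naming follows TransformerLens
    resid = cache[f"blocks.{i}.hook_resid_pre"]                  # residual stream before block i
    preds, _ = patch_logit_lens(resid, model.ln_final, model.head)
    layerwise_preds[f"{i}_pre"] = preds                          # per-patch predictions at this depth
```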
This layer-wise analysis builds upon the work of Gandelsman, Efros, and Steinhardt (2024), who used mean ablations to identify which layers in CLIP have the most significant direct effect on the final representation. Our emoji logit lens provides a complementary view, visualizing how the patch-level predictions evolve across the model's depth.
Interactive code here.
We can also visualize the evolving per-patch predictions for the above cat/toilet image for all the layers at once:
Direct Logit Attribution
The library supports direct logit attribution, including at the layer-level and attention-level.
Below, the net starts making a distinction between tabby/collie and banana at the eighth layer. See the ViT Prisma Main Demo for the interactive graph.
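For intuition, here is a simplified sketch of layer-level direct logit attribution on the tabby-vs-collie direction; it reuses `model` and `cache` from the earlier sketches, ignores final-LayerNorm folding for simplicity, assumes a batch size of 1, and assumes the head weights are accessible as shown:

```python
TABBY, COLLIE = 281, 232                       # ImageNet class indices (tabby cat, Border collie)
W_H = model.head.weight.T                      # [d_model, 1000]; attribute name is an assumption
logit_diff_direction = W_H[:, TABBY] - W_H[:, COLLIE]

attributions = []
for i in range(model.cfg.n_layers):
    # each block's contribution to the residual stream is resid_post - resid_pre
    block_out = cache[f"blocks.{i}.hook_resid_post"] - cache[f"blocks.{i}.hook_resid_pre"]
    cls_out = block_out[:, 0, :]               # the CLS token's share of the update
    attributions.append((cls_out @ logit_diff_direction).item())  # projection onto tabby-vs-collie
```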
Attention Heads
I wrote an interactive JavaScript visualizer so we can see what each vision attention head is attending to on the image.
The x and y axes of the attention pattern are the flattened image sequence. The sequence is 50 tokens in total (49 patches plus the CLS token), so each attention pattern is a 50x50 square.
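For a quick static look without the JavaScript visualizer, one can plot a single head's cached attention pattern directly (hook naming follows TransformerLens conventions and may differ in Prisma):

```python
import matplotlib.pyplot as plt

layer, head = 0, 0
pattern = cache[f"blocks.{layer}.attn.hook_pattern"][0, head]   # [50, 50]: query x key positions
plt.imshow(pattern.detach().cpu().numpy(), cmap="viridis")
plt.xlabel("key position (CLS + 49 patches)")
plt.ylabel("query position (CLS + 49 patches)")
plt.title(f"Layer {layer}, Head {head} attention pattern")
plt.show()
```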
Upon initial inspection, the first layer’s attention heads are extremely geometric.
Corner Head, Edges Head, and Modulus Head
We can see attention heads’ scores specializing for specific patterns in the data, including what we call a Corner Head, an Edges Head, and a Modulus Head. This is fascinating because the flattened image does not explicitly contain corner, edge, or row/column information; detecting these patterns is emergent from training.
These findings echo the recent work of Gandelsman, Efros, and Steinhardt (2024) who identified property-specific attention heads in CLIP that specialize in concepts like colors, locations, and shapes. Our results suggest that such specialization is a more general property of vision transformer architectures, including vanilla models trained solely on image classification, and includes even more basic geometric properties like the coordinates of the image.
Interactive code here.
Video of the Corner Head
Activation Patching
Prisma has the activation patching functionality of TransformerLens.
The Cat-Dog Switch
I found a single attention head (Layer 11, Head 4) wherein patching the CLS token of the z-matrix flips the computation from tabby cat to Border Collie. The CLS token in that z-matrix aggregates patch-level cat ear/face information from the attention pattern.
Our activation patching results demonstrate that this technique can be used to flip the model's prediction by targeting specific heads, providing a powerful tool for understanding and manipulating the model's decision-making process.
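A hedged sketch of that intervention, assuming TransformerLens-style hook names and two placeholder input tensors `cat_image` and `dog_image`:

```python
LAYER, HEAD = 11, 4
z_hook = f"blocks.{LAYER}.attn.hook_z"                 # per-head outputs: [batch, pos, head, d_head]

_, dog_cache = model.run_with_cache(dog_image)         # source run (Border collie image)

def patch_cls_z(z, hook):
    # overwrite only the CLS position of Head 4's z-vector with the dog run's value
    z[:, 0, HEAD, :] = dog_cache[z_hook][:, 0, HEAD, :]
    return z

patched_logits = model.run_with_hooks(                 # destination run (tabby cat image)
    cat_image,
    fwd_hooks=[(z_hook, patch_cls_z)],
)
print(patched_logits.argmax(dim=-1))                   # prediction should flip toward Border collie
```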
This result resonates with Gandelsman, Efros, and Steinhardt (2024), who showed that knowledge of head-specific roles in CLIP can be used to manually intervene in the model's computation, such as removing heads associated with spurious cues.
Interactive code here.
Toy Vision Transformers
We are releasing nine tiny ViTs for testing (equivalent to TransformerLens’ gelu-1l) to better isolate behavior. These tiny ViTs were trained by Yash Vadi and Praneet Suresh.
The repo also contains training code to quickly train custom toy ViTs.
HookedViT
We currently support timm’s vanilla ViTs, TinyCLIP, the video vision transformer, and our own custom tiny transformers. More models will come soon based on demand!
FAQ
Is multimodal mechanistic interpretability really that different from language?
Yes and no. Vision mech interpretability is like language mechanistic interpretability, but in a fun-house mirror. Both architectures are transformers, so many LLM techniques carry over. However, there are a few twists:
If there is demand, I may write up a post giving a deeper and more theoretical take on the differences between language and non-language mechanistic interpretability.
Why start with vision transformers?
Vision transformers have an extremely similar architecture to language transformers, so many of the existing techniques transfer over cleanly.
Diffusion models are the next obvious frontier, but there will be a larger conceptual leap in designing mechanistic interpretability techniques, largely due to their iterative denoising process. I’d be happy to collaborate on this with anyone who is serious about building strong conceptual foundations here.
Getting Started with Vision Mechanistic Interpretability
How to get involved
Open Problems in Vision Mechanistic Interpretability
Here are some Open Problems to get started. If inspired, you are encouraged to post your own in the comments, or comment on the ideas that most grab your attention.
Easy and Exploratory
Expanding Techniques to New Architectures and Datasets
Deeper Investigations
Advanced Investigations
Acknowledgements
Thank you to this most excellent mosaic of communities.
Thank you to Praneet Suresh, Rob Graham, and Yash Vadi, and the other core contributors to the Prisma Repo. Thank you to my PI, Blake Richards, and the rest of our lab at Mila for their support and feedback.
Thank you to Neel Nanda for guidance in bringing mechanistic interpretability to another modality, to Joseph Bloom for your advice on building a repo, to Arthur Conmy for coining the term “dogit lens,” and to the rest of the MATS community for your feedback.
Thank you to the Prisma group at Mila for your feedback, including Santoshi Ravichandran, Ali Kuwajerwala, Mats L. Richter, and Luca Scimeca; members of LiNCLab, including Arna Ghosh and Dan Levenstein; and members of CERC-AAI lab, including Irina Rish and Ethan Caballero. Thank you to Karolis Ramanauskas, Noah MacCallum, Rob Graham, and Romeo Valentin for your feedback on the tutorial notebooks.
Finally, thank you to the South Park Commons community for your support, including Ker Lee Yap, Abhay Kashyap, Jonathan Brebner, and Ruchi Sanghvi and Aditya Agarwal.
This research was generously supported by Blake Richards’s lab, which was funded by the Bank of Montreal; NSERC (Discovery Grant: RGPIN-2020-05105; Discovery Accelerator Supplement: RGPAS-2020-00031; Arthur B. McDonald Fellowship: 566355-2022); CIFAR (Canada AI Chair; Learning in Machine and Brains Fellowship); and a Canada Excellence Research Chair Award to Prof. Irina Rish; and by South Park Commons. This research was enabled in part by support provided by Calcul Québec and the Digital Research Alliance of Canada. We acknowledge the material support of NVIDIA in the form of computational resources.