Behold the dogit lens. Patch-level logit attribution is an emergent segmentation map.

Join our Discord here.

This article was written by Sonia Joseph, in collaboration with Neel Nanda, and incubated in Blake Richards’s lab at Mila and in the MATS community. Thank you to the Prisma core contributors, including Praneet Suresh, Rob Graham, and Yash Vadi. 

Full acknowledgements of contributors are at the end. I am grateful to my collaborators for their guidance and feedback.

Outline

  • Part One: Introduction and Motivation
  • Part Two: Tutorial Notebooks
  • Part Three: Brief ViT Overview
  • Part Four: Demo of Prisma’s Functionality
    • Key features, including logit attribution, attention head visualization, and activation patching.
    • Preliminary research results obtained using Prisma, including emergent segmentation maps and canonical attention heads.
  • Part Five: FAQ, including Key Differences between Vision and Language Mechanistic Interpretability
  • Part Six: Getting Started with Vision Mechanistic Interpretability
  • Part Seven: How to Get Involved
  • Part Eight: Open Problems in Vision Mechanistic Interpretability

Introducing the Prisma Library for Multimodal Mechanistic Interpretability

I am excited to share with the mechanistic interpretability and alignment communities a project I’ve been working on for the last few months. Prisma is a multimodal mechanistic interpretability library based on TransformerLens, currently supporting vanilla vision transformers (ViTs) and their vision-text counterpart, CLIP.

With recent rapid releases of multimodal models, including Sora, Gemini, and Claude 3, it is crucial that interpretability and safety efforts remain in tandem. While language mechanistic interpretability already has strong conceptual foundations, many research papers, and a thriving community, research in non-language modalities lags behind. Given that multimodal capabilities will be part of AGI, field-building in mechanistic interpretability for non-language modalities is crucial for safety and alignment.

The goal of Prisma is to make research in mechanistic interpretability for multimodal models both easy and fun. We are also building a strong and collaborative open source research community around Prisma. You can join our Discord here.

This post includes a brief overview of the library, fleshes out some concrete problems, and gives steps for people to get started.

Prisma Goals

  1. Build shared infrastructure (Prisma) to make it easy to run standard language mechanistic interpretability techniques on non-language modalities, starting with vision.
  2. Build shared conceptual foundation for multimodal mechanistic interpretability.
  3. Shape and execute on research agenda for multimodal mechanistic interpretability.
  4. Build an amazing multimodal mechanistic interpretability subcommunity, inspired by current efforts in language.
    1. Set the cultural norms of this subcommunity to be highly collaborative, curious, inventive, friendly, respectful, prolific, and safety/alignment-conscious.
    2. Encourage sharing of early/scrappy research results on Discord/Less Wrong.
    3. Co-create a web of high-quality research.

Tutorial Notebooks

To get started, you can check out three tutorial notebooks that show how Prisma works.

  1. Main ViT Demo - Overview of the main mechanistic interpretability techniques on a ViT, including direct logit attribution, attention head visualization, and activation patching. The activation patching switches the net's prediction from tabby cat to Border collie with a minimal ablation.
  2. Emoji Logit Lens - Deeper dive into layer- and patch-level predictions with interactive plots.
  3. Interactive Attention Head Tour - Deeper dive into the various types of attention heads a ViT contains with interactive JavaScript.
     

Brief ViT Overview

From "An Image is Worth 16x16 words: Transformers for Image Recognition at Scale."

A vision transformer (ViT) is an architecture designed for image classification tasks, similar to the classic transformer architecture used in language models. A ViT consists of transformer blocks; each block consists of an Attention layer and an MLP layer.

Unlike language models, vision transformers do not have a dictionary-style embedding and unembedding matrix. Instead, images are divided into non-overlapping patches, similar to tokens in language models. These patches are flattened and linearly projected to embeddings via a Conv2D layer, akin to word embeddings in language models. A learnable class token (CLS token) is prepended at the start of the sequence, which accrues global information throughout the network. A linear position embedding is added to the patches.
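To make the patch embedding concrete, here is a minimal NumPy sketch of the steps above. It is not Prisma's or timm's code: the sizes (224x224 image, 32x32 patches, width 64) and all weights are illustrative stand-ins, chosen so that the sequence comes out to the 50 tokens used later in this post.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions): 224x224 RGB image, 32x32 patches, width 64.
image_size, patch_size, channels, d_model = 224, 32, 3, 64
grid = image_size // patch_size            # 7 patches per side
n_patches = grid * grid                    # 49 patches

image = rng.standard_normal((channels, image_size, image_size))

# Cut the image into non-overlapping patches and flatten each one.
patches = image.reshape(channels, grid, patch_size, grid, patch_size)
patches = patches.transpose(1, 3, 0, 2, 4)            # (grid, grid, C, P, P)
patches = patches.reshape(n_patches, channels * patch_size * patch_size)

# Linear projection. A Conv2D with kernel size = stride = patch size
# computes the same kind of map, one output vector per patch.
W_embed = rng.standard_normal((patches.shape[1], d_model)) * 0.02
patch_embeddings = patches @ W_embed                   # (49, d_model)

# Prepend the CLS token, then add position embeddings (random stand-ins here).
cls_token = rng.standard_normal((1, d_model)) * 0.02
pos_embed = rng.standard_normal((n_patches + 1, d_model)) * 0.02
tokens = np.concatenate([cls_token, patch_embeddings], axis=0) + pos_embed

print(tokens.shape)  # (50, 64)
```

In a trained model, `W_embed`, `cls_token`, and `pos_embed` would be learned parameters rather than random draws.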

The patch embeddings then pass through the transformer blocks (each block consists of a LayerNorm, an Attention layer, another LayerNorm, and an MLP layer). The output of each block is added back to the previous input. The sum of the block’s output and its previous input is called the residual stream.

The final layer of this vision transformer is a classification head with 1000 logit values for ImageNet's 1000 classes. The CLS token is fed into the final layer for 1000-way classification. Adapting TransformerLens, we designed HookedViT to easily capture intermediate activations with custom hook functions, instead of dealing with PyTorch's normal hook functionality.

Prisma Functionality

We’ll demonstrate the functionality with some preliminary research results. The plots are all interactive but the LW site does not let me render HTML. See the original post for interactive graphs.

Emoji Logit Lens 

The emoji logit lens is a convenient way to visualize patch-level predictions for each layer of the net. 

We treat every patch like the CLS token, and feed it into the ViT’s 1000-way classification head that’s pre-trained on ImageNet, without fine-tuning. This is equivalent to deleting all layers between the layer of your choice and the output classification head. 
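As a rough sketch of that idea (not the notebook's actual code), the following applies a classification head to every token of a cached residual stream and reads off the top class per patch. The residual stream, the LayerNorm parameters, and the head weights are random stand-ins, and applying the final LayerNorm before the head is an assumption of this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions): 7x7 patch grid plus CLS, 1000 classes.
grid, d_model, n_classes = 7, 64, 1000
n_tokens = grid * grid + 1

# Stand-in for the residual stream cached at the layer of your choice.
resid = rng.standard_normal((n_tokens, d_model))

# Stand-ins for the pretrained final LayerNorm and classification head.
ln_weight, ln_bias = np.ones(d_model), np.zeros(d_model)
W_head = rng.standard_normal((d_model, n_classes)) * 0.02
b_head = np.zeros(n_classes)


def layer_norm(x, weight, bias, eps=1e-6):
    """Normalize each token vector, then apply the learned scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias


# Apply the head to every token, not only to CLS (index 0).
logits = layer_norm(resid, ln_weight, ln_bias) @ W_head + b_head   # (50, 1000)

patch_logits = logits[1:]                                  # drop the CLS row
top_class = patch_logits.argmax(axis=-1).reshape(grid, grid)
top_logit = patch_logits.max(axis=-1).reshape(grid, grid)
```

Here `top_class` is the per-patch prediction that would be mapped to an emoji, and `top_logit` is the value that would drive the coloring.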

For convenience, we represent the ImageNet prediction of that patch with its corresponding emoji, drawing from our ImageNet-Emoji Dictionary.

Below are the patch-level predictions of the final layer of a ViT for an image of a cat sitting inside a toilet. The yellow means that the logit prediction was high and blue means the logit prediction was low (see the Emoji Logit Lens notebook for more details).

Emergent Segmentation

One of my favorite findings so far is that the patch-level logit lens on the image basically acts as a segmentation map. For the image above, the cat patches get classified as cat, and the toilet patches get classified as a toilet! 

This is not an obvious result, as vision transformers are optimized to predict a single class with the CLS token, and not segment the image. The segmentation is an emergent property. See the Emoji Logit Lens Notebook for more details and an interactive visualization. 

Similar emergent segmentation capabilities were recently reported by Gandelsman, Efros, and Steinhardt (2024), who found that decomposing CLIP's image representation across spatial locations allowed obtaining zero-shot semantic segmentation masks that outperformed prior methods. Our results extend this finding to vanilla vision transformers and provide an intuitive visualization using the emoji logit lens. 

We can see similar results on other images. 

(Note: For visualization purposes, I’ve changed the coloring to be by emoji class instead of logit value like above; see Emoji Logit Lens notebook for details.)

Image and patch-level logit lens of a cheetah chasing a gazelle.
Image and patch-level logit lens of two children with a basketball playing near baby lions.

Interestingly, the net has some biased predictions (“abaya” for the children, perhaps due to their ethnicity), one consequence of only having a 1000-class vocabulary to span concept-space.


 

Funnily, the net thinks that the center of the green apple (above image, bottom left) is a bagel.


 

When we do a layer-by-layer logit lens, we see the net’s evolving predictions:

Interestingly, the net picks up on the “animal” at 9_pre (the residual stream before the 9th transformer block) but classifies the cat as a dog. The net only catches onto the cat at 10_pre.

This layer-wise analysis builds upon the work of Gandelsman, Efros, and Steinhardt (2024), who used mean ablations to identify which layers in CLIP have the most significant direct effect on the final representation. Our emoji logit lens provides a complementary view, visualizing how the patch-level predictions evolve across the model's depth.

Interactive code here.

We can also visualize the evolving per-patch predictions for the above cat/toilet image for all the layers at once:

Direct Logit Attribution

The library supports direct logit attribution, including at the layer-level and attention-level.

Below, the net starts making a distinction between tabby/collie and banana at the eighth layer. See the ViT Prisma Main Demo for the interactive graph.
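The reason this decomposition works is that the residual stream is a sum of component outputs and the classification head is linear, so a logit difference splits into one term per component. Here is a toy sketch with random stand-in activations. It ignores the final LayerNorm for simplicity, and the two class indices are placeholders rather than values taken from the demo.

```python
import numpy as np

rng = np.random.default_rng(0)
n_layers, d_model, n_classes = 12, 64, 1000

# Stand-ins for what each component writes into the CLS residual stream.
embed_out = rng.standard_normal(d_model)
attn_out = rng.standard_normal((n_layers, d_model))
mlp_out = rng.standard_normal((n_layers, d_model))
W_head = rng.standard_normal((d_model, n_classes)) * 0.02

# The final residual stream is the sum of everything written into it.
final_resid = embed_out + attn_out.sum(axis=0) + mlp_out.sum(axis=0)

# Direction in residual space that separates two classes (placeholder indices).
class_a, class_b = 281, 232
direction = W_head[:, class_a] - W_head[:, class_b]

# One scalar contribution per component.
contributions = {"embed": embed_out @ direction}
for layer in range(n_layers):
    contributions[f"attn_{layer}"] = attn_out[layer] @ direction
    contributions[f"mlp_{layer}"] = mlp_out[layer] @ direction

# The per-component terms add up to the full logit difference.
logit_diff = final_resid @ direction
assert np.isclose(sum(contributions.values()), logit_diff)
```

Splitting `attn_out` further by head gives the attention-level view in the same way.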

Attention Heads

I wrote an interactive JavaScript visualizer so we can see what each vision attention head is attending to on the image. 

The x and y axes of the attention head are the flattened image. The image is 50 patches in total, including the CLS token, which means the total attention head is a 50x50 square.
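For intuition about where the 50x50 square comes from, here is a small sketch with random queries and keys. The head dimension is an arbitrary choice, and the 7x7 grid follows from 49 patches plus one CLS token.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_head = 50, 64   # 49 patches + CLS; d_head is an illustrative choice

q = rng.standard_normal((n_tokens, d_head))   # one query per token
k = rng.standard_normal((n_tokens, d_head))   # one key per token

# Pre-softmax attention scores: rows are queries, columns are keys.
scores = q @ k.T / np.sqrt(d_head)

# Row-wise softmax gives the attention pattern. There is no mask here,
# because classical ViT attention is bidirectional.
shifted = scores - scores.max(axis=-1, keepdims=True)
pattern = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

# How the CLS token (row 0) attends over the 7x7 patch grid.
cls_to_patches = pattern[0, 1:].reshape(7, 7)

print(scores.shape, pattern.shape)  # (50, 50) (50, 50)
```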

Upon initial inspection, the first layer’s attention heads are extremely geometric.

Corner Head, Edges Head, and Modulus Head

We can see attention heads’ scores specializing for specific patterns in the data, including what we call a Corner Head, an Edges Head, and a Modulus Head. This is fascinating because the flattened image does not explicitly contain corner, edge, or row/column information; detecting these patterns is emergent from training.

These findings echo the recent work of Gandelsman, Efros, and Steinhardt (2024) who identified property-specific attention heads in CLIP that specialize in concepts like colors, locations, and shapes. Our results suggest that such specialization is a more general property of vision transformer architectures, including vanilla models trained solely on image classification, and includes even more basic geometric properties like the coordinates of the image.

Interactive code here.


 

Video of the Corner Head

Activation Patching

Prisma has the activation patching functionality of TransformerLens. 

The Cat-Dog Switch

I found a single attention head (Layer 11, Head 4) wherein patching the CLS token of the z-matrix flips the computation from tabby cat to Border Collie. The CLS token in that z-matrix aggregates patch-level cat ear/face information from the attention pattern.

Our activation patching results demonstrate that this technique can be used to flip the model's prediction by targeting specific heads, providing a powerful tool for understanding and manipulating the model's decision-making process.
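To illustrate the mechanics of patching one head at the CLS position, here is a toy sketch. It does not use Prisma's API. The "model tail" is a made-up function that only adds the heads' CLS outputs to the residual stream and classifies (no MLP, no LayerNorm), and all activations and weights are random stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, n_tokens, d_head, d_model, n_classes = 12, 50, 16, 64, 1000
CLS = 0

W_O = rng.standard_normal((n_heads, d_head, d_model)) * 0.1
W_head = rng.standard_normal((d_model, n_classes)) * 0.1


def logits_from_z(z, resid_cls):
    """Toy model tail: add each head's CLS output to the residual, then classify."""
    attn_out = np.einsum("hd,hdm->m", z[:, CLS], W_O)
    return (resid_cls + attn_out) @ W_head


# Cached per-head z activations from two forward passes (random stand-ins).
z_clean = rng.standard_normal((n_heads, n_tokens, d_head))
z_other = rng.standard_normal((n_heads, n_tokens, d_head))
resid_cls = rng.standard_normal(d_model)

# Patch a single head at a single position, leaving everything else untouched.
head = 4
z_patched = z_clean.copy()
z_patched[head, CLS] = z_other[head, CLS]

clean_logits = logits_from_z(z_clean, resid_cls)
patched_logits = logits_from_z(z_patched, resid_cls)

# In this linear toy, the logit change is exactly the patched head's change in output.
delta = (z_other[head, CLS] - z_clean[head, CLS]) @ W_O[head] @ W_head
assert np.allclose(patched_logits - clean_logits, delta)
```

In a real model, the downstream MLP and LayerNorm make the effect nonlinear, so one would rerun the forward pass with the patched activation rather than compute the change in closed form.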

This result resonates with Gandelsman, Efros, and Steinhardt (2024), who showed that knowledge of head-specific roles in CLIP can be used to manually intervene in the model's computation, such as removing heads associated with spurious cues.

Interactive code here. 

Toy Vision Transformers 

We are releasing nine tiny ViTs for testing (equivalent to TransformerLens’ gelu-1l) to better isolate behavior. These tiny ViTs were trained by Yash Vadi and Praneet Suresh.

The repo also contains training code to quickly train custom toy ViTs.

HookedViT

We currently support timm’s vanilla ViTs, TinyCLIP, the video vision transformer, and our own custom tiny transformers. More models will come soon based on demand! 

FAQ

Is multimodal mechanistic interpretability really that different from language?

Yes and no. Vision mech interpretability is like language mechanistic interpretability, but in a fun-house mirror. Both architectures are transformers, so many LLM techniques carry over. However, there are a few twists:

  • The typical ViT is not doing unidirectional sequence modeling. ViTs use bidirectional attention and predict a global CLS token, rather than predicting the next token in an autoregressive manner. (Note: There are autoregressive vision transformers with basically the same architecture as language, such as Image GPT and Parti, which do next-token image generation. However, as of February 2024, autoregressive vision transformers are not frequently used in the wild.)
  • Bidirectional attention vs causal attention. Language transformers have causal (unidirectional) attention, i.e. there is an upper triangular mask on the attention, so that earlier tokens cannot attend to tokens in the future. However, the classical ViT has bidirectional attention. Thus, the ViT does not have the same concept of “time” as language transformers, and some of the original language mechanistic interpretability techniques break. It can be unclear which direction information is flowing. Induction heads, if they are present in vision, would look different from those in language to account for bidirectional attention.
  • CLS token instead of next token prediction/ autoregressive loss. For ViTs, a learnable CLS token, which is prepended to the input, gets fed into the classification head instead of the final token as in language. The CLS token accrues global information from the other patches through self-attention as all the patches pass through the net.
  • No canonical dictionary matrix. Vision is more ambiguous and lacks the standard dictionary matrix like the 50k one for language. For instance, a yellow patch on a goldfinch might represent "yellow," "wing," "goldfinch," "bird," or "animal," depending on the granularity, showing hierarchical ambiguity. An animal might be identified specifically as a "Border collie" or more generally as a "dog." Beyond hierarchy, ambiguity in vision also stems from cultural interpretations and the imprecision of language. Practically, ImageNet's 1000 classes serve as a makeshift "dictionary," but it falls short of fully encompassing visual concepts.
  • Additional hyperparameters. Patch size is a vision-specific hyperparameter, determining the size of the patches into which an image is divided. Using smaller patches increases accuracy but also computational load, because attention scales quadratically with patch number.
  • There is a zoo of vision transformers. Similar to language, vision transformers come in many forms. The most relevant are the vanilla ViT; CLIP, which is co-trained with text using contrastive loss; and DINO, which uses unlabeled data. There is also a gallery of loss functions used, including classification loss and masked autoencoder loss. Different losses may lead to different emergent symmetries in the model, although this is an open question. For a review of important ViT architectures, check out this survey.
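Two of the points above are easy to see in a few lines of NumPy: the difference between bidirectional and causal attention, and how quickly the attention matrix grows as patches shrink. This is a sketch with random scores, and the 224x224 image size is an assumption for the arithmetic.

```python
import numpy as np


def softmax(x):
    """Row-wise softmax that tolerates -inf entries."""
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
n_tokens = 5
scores = rng.standard_normal((n_tokens, n_tokens))   # rows: queries, columns: keys

# Bidirectional (classical ViT): every token may attend to every token.
bidirectional = softmax(scores)

# Causal (language models): mask out keys that come after the query.
future = np.triu(np.ones((n_tokens, n_tokens), dtype=bool), k=1)
causal = softmax(np.where(future, -np.inf, scores))

# Attention cost grows with the square of the token count.
for patch_size in (32, 16):
    n = (224 // patch_size) ** 2 + 1          # patches plus the CLS token
    print(patch_size, n, n * n)
# 32 50 2500
# 16 197 38809
```

Halving the patch size roughly quadruples the token count, so the attention matrix grows by roughly a factor of sixteen.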

If there is demand, I may write up a post giving a deeper and more theoretical take on the differences on language vs non-language mechanistic interpretability.

Why start with vision transformers?

Vision transformers have an extremely similar architecture to language transformers, so many of the existing techniques transfer over cleanly.

Diffusion models are the next obvious frontier, but there will be a larger conceptual leap in designing mechanistic interpretability techniques, largely due to their iterative denoising process. I’d be happy to collaborate on this with anyone who is serious about building strong conceptual foundations here.

Getting Started with Vision Mechanistic Interpretability 

  1. It is wise to first have a strong foundation in language mechanistic interpretability, so check out the loads of resources already on this forum. The ARENA curriculum is a good place to start.
  2. Check out these tutorial notebooks:
    1. Main ViT Demo - Includes direct logit attribution, activation patching, and other standard mech interp techniques. The notebook has a section that switches a ViT’s prediction from tabby cat to Border Collie with a minimum ablation.
    2. Emoji Logit Lens
    3. Interactive Attention Head Tour
    4. Optional:
      1. Mindreader. A viewer for maximally activating images on TinyCLIP, which gives a better intuition of how the model hierarchically processes information.
  3. Check out my brief paper list on Vision Transformers and Mechanistic Interpretability for general context.

How to get involved

  1. Use and contribute to the Prisma repo.
    1. Check out our open Issues
    2. Check out our Spring 2024 Roadmap
  2. Community. Join the Prisma Discord.
  3. Collaboration. Work on the Open Problems below.
    1. Post more Open Problems in the comments, or discuss what excites you in particular! 
  4. Get mentorship. I would be happy to mentor people on any of the Open Problems below, or a new one that you propose. However, please ensure your proposal is well-thought-out and includes preliminary results on any of the medium to hard problems. Feel free to reach out to me on the LW forum or the Prisma Discord.
  5. Funding for open source alignment and interpretability. As an open-source project, we value the support of our community and sponsors. Open source funding helps us cover expenses and invest in the project's growth. Feel free to reach out if you'd like to discuss potential funding opportunities! 

Open Problems in Vision Mechanistic Interpretability

Here are some Open Problems to get started. If inspired, you are encouraged to post your own in the comments, or comment on the ideas that most grab your attention.

Easy and Exploratory

  1. Explore interpretable neurons in Mindreader:
  2. Play with the attention head JavaScript visualizers to visually inspect attention heads.
    • Do any heads seem to pick up on the same patterns across images? 
    • Find 1-3 attention heads that appear to capture interesting features.
  3. Generate maximally activating images for a ViT/CLIP. The maximally activating images were used to make the Mindreader. I will provide sample code to identify the maximally activating images for each neuron on request.
    • Experiment with different models and datasets for finding maximally activating images.
    • Plot histograms of neuron activations to assess false positives and negatives.
  4. Run the logit lens on various images in the Main Demo notebook.
    • Are certain ImageNet classes identified earlier in the net than others?
  5. Load in a toy ViT into Prisma and examine the attention heads. Do you notice any invariant attention scores/patterns for certain types of images? 
    • Look at the attention heads of the 3-layer toy ViT.
    • Find canonical circuit patterns in the 1-4 layer ViTs. (Note these ViTs are patch size 16 whose attention is currently not feasible to render in the JavaScript attention head viewer.)
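For the maximally activating images problem above, the core computation is a top-k over a table of activations. This is only a sketch on random data, not the sample code mentioned above. How to reduce a neuron's per-patch activations to one number per image (here assumed already done) is a design choice left open.

```python
import numpy as np

rng = np.random.default_rng(0)
n_images, n_neurons, k = 1000, 8, 5

# Stand-in for one activation per (image, neuron), e.g. the max over patches.
activations = rng.standard_normal((n_images, n_neurons))

# Indices of the k most strongly activating images per neuron, strongest first.
top_k = np.argsort(-activations, axis=0)[:k].T         # (n_neurons, k)

# Histogram of one neuron's activations, to see how typical the top images are.
counts, bin_edges = np.histogram(activations[:, 0], bins=20)

print(top_k.shape)  # (8, 5)
```

Looking at images sampled from the middle bins of the histogram, not only the top ones, is one way to check for false positives and negatives.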

Expanding Techniques to New Architectures and Datasets

  1. Swap CLIP / Vanilla ViTs. Run the techniques above but swap the architecture.
    • Run the Main Demo on TinyCLIP instead of a vanilla ViT (which is pretrained only on 1k-ImageNet, instead of TinyCLIP’s image/text).
    • For the maximally activating images task (Problem 3), use a vanilla ViT instead of TinyCLIP.
  2. Video Vision Transformer. Run the techniques on a video vision transformer, which you can specify in the HookedViTConfig here (thanks Rob Graham for the idea and adapting the model to Prisma). 
    • How do you account for time as a new dimension? How is the model representing time? 
  3. Patch-level labels. Use patch-level labels for your dataset.
    • ImageNet is currently only labeled at the class level. Often, this is too coarse-grained. For example, a man eating a burrito, which is labeled “burrito” for ImageNet, has many components: the man, the burrito, his shirt, and the cutlery on the table. Patch-level labels were created by Rob Graham using SAM on ImageNet Images and give a boolean mask for every object in the image. See our Huggingface link for more details on patch-level labels and the SAM pipeline.
    • Generating maximally activating images (Problem 3) is a good candidate as a task for patch-level labels, because now your neurons’ labels will be much more fine-grained than merely using ImageNet class-level labels.

Deeper Investigations

  1. Finish off the cat/dog circuit from the Main Demo Notebook.
    • Run linear probes on the notebook/scratchpad token to the right of the Border Collie. How much cat vs dog information does the patch contain at each layer of the net?
    • How general is Attention Head 4, Layer 11 (the “Cat-Dog Decider Head”), which appears to be pushing the net’s decision from Border Collie to tabby cat? Does the attention head make the same decision for other images containing both cats and dogs? 
    • What is the full circuit for cats and dogs, according to the rigorous definition of a circuit?
  2. Attention patterns vs attention scores.
    • In the Interactive Attention Head Tour, the attention scores (pre-softmax) sometimes look more visually interpretable than the attention patterns (post-softmax). How do we connect our observations about the attention scores to our observations about the attention patterns?
  3. Tuned Lens. Use the Tuned Lens instead of the classical logit lens for the Emoji Logit Lens (train a probe for each block). Does the Tuned Lens improve interpretability? 
    • Recreate the “Layer-Level Logit Lens” plots in the Emoji Logit Lens notebook using Tuned Lenses instead of the current vanilla logit lens. Do the results corroborate each other? 
  4. Attention Ablations: Change the ViT’s bidirectional attention to unidirectional (like language models). How does this affect segmentation maps and information flow?
  5. Adding registers. Try adding registers as in Darcet et al (2023) and see what happens to the segmentation map.
  6. Superposition. Find superpositioned ViT neurons and disentangle the layer with SAEs
  7. Circuit Identification. Do full circuit identification for simple naturalistic data like MNIST.
  8. Disentanglement datasets. Run the disentanglement dataset dSprites through the model. Does the internal representation of the net show disentanglement? We have some pre-trained dSprites transformers by Yash Vadi here.
  9. Are there induction heads in vision? 
    1. Induction heads emerge with two or more layers in language. Is there an analog in vision (or the emergence of some other useful symmetry)?
    2. The symmetries may depend a lot on loss function (e.g. masked autoencoder losses may yield different symmetries than classification loss, although this is an open question).
    3. Vision attention is bidirectional, so it’s less obvious what “induction” means here. The canonical definition from language breaks down.
  10. Reverse engineer textual inversion - Add a new, made-up word to the CLIP text encoder and fine-tune with 4-6 corresponding images (e.g. you can finetune the model to label your face with your name). How do the model’s internals change?
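Several of the problems above (the cat/dog probes, and the Tuned Lens in a more elaborate form) come down to fitting a linear map on cached activations. Here is a minimal linear-probe sketch on synthetic data. The activations, the labels, and the planted signal direction are all made up, and least squares is used for brevity where logistic regression would be the more common choice.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, d_model = 400, 32

# Synthetic stand-in: activations of one patch at one layer, labels cat=1, dog=0.
labels = rng.integers(0, 2, size=n_samples)
signal_direction = rng.standard_normal(d_model)
signal_direction /= np.linalg.norm(signal_direction)
acts = rng.standard_normal((n_samples, d_model))
acts += 3.0 * np.outer(2 * labels - 1, signal_direction)   # planted class signal

train, test = slice(0, 300), slice(300, None)

# Least-squares probe with a bias column, regressing onto targets of -1 and +1.
X_train = np.hstack([acts[train], np.ones((300, 1))])
w, *_ = np.linalg.lstsq(X_train, 2.0 * labels[train] - 1.0, rcond=None)

# Evaluate on held-out samples.
X_test = np.hstack([acts[test], np.ones((100, 1))])
pred = (X_test @ w > 0).astype(int)
accuracy = (pred == labels[test]).mean()
```

Repeating this at every layer and plotting `accuracy` against depth would show where the information becomes linearly readable.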

Advanced Investigations

  1. Vision training dynamics and phase transitions.
    1. Detect canonical phase transitions in ViT training loss curves (analogous to induction head loss bumps).
    2. Praneet Suresh found that reconstruction loss, and visualizing the reconstructed image throughout training, is a convenient way to detect phase transitions.
  2. Compare the interpretability of CLIP vs a vanilla ViT.
    1. There is an unproven intuition that CLIP is more interpretable than a vanilla ViT. CLIP has better labels than a vanilla ViT. CLIP co-optimized with captions, which are more granular labels. For example, “the tabby cat sat on the window” (a CLIP-style label) is more precise than the plain ImageNet class “tabby cat.” The higher-quality labels may result in “better-factored” internal representations.
    2. How interpretable is a ViT in comparison to CLIP? You can explore this question by checking maximally activating images on both models, and running the Logit Lens notebook on TinyCLIP instead of a ViT. Brainstorm your own techniques to compare the internal representations of the models.
  3. Patch information.
    1. How does local information in spatial patches propagate to the global CLS token? Could we get a circuit-level breakdown?
    2. Why do the patches in the upper layer of CLIP store global information? How are they computed from local patches?
  4. Reverse-Engineering Vision Transformer Registers
    1. Vision transformer registers (Darcet et al. (2023)) were a recent phenomenon in the vision community where adding blank tokens dramatically improved the attention maps of a ViT.
    2. On a low-level, mechanistic basis, how does local information from patch tokens propagate to the global tokens, including the CLS token?
    3. Try removing a few register tokens, see what breaks. Are there heads that are specialized to attend to register tokens?
    4. Train an SAE on register tokens and see if you can disentangle what they store. Also run linear probes on register tokens.
  5. Visual reasoning. Use a dataset like CLEVR to see if CLIP does visual reasoning. Does CLIP have visual reasoning “circuits”? Note: CLIP may be bad at CLEVR, which is a complex dataset. Try creating your own simpler visual reasoning tasks (e.g. 2 apples + 2 apples = 4 apples), as a baseline.
    1. Create an open source dataset with very simple reasoning tasks for this purpose. This would be a service to the broader research community!
  6. Replicate the results of Gandelsman et al. Do you notice attention heads specializing for certain semantic groups? 
  7. Other architectures
    1. Explore models like Flamingo.
    2. Lay groundwork to analyze diffusion models.

Acknowledgements

Thank you to this most excellent mosaic of communities.

Thank you to Praneet Suresh, Rob Graham, and Yash Vadi, and the other core contributors to the Prisma Repo. Thank you to my PI, Blake Richards, and the rest of our lab at Mila for their support and feedback.

Thank you to Neel Nanda for guidance in bringing mechanistic interpretability to another modality, to Joseph Bloom for your advice on building a repo, to Arthur Conmy for coining the term “dogit lens,” and to the rest of the MATS community for your feedback.

Thank you to the Prisma group at Mila for your feedback, including Santoshi Ravichandran, Ali Kuwajerwala, Mats L. Richter, and Luca Scimeca; members of LiNCLab, including Arna Ghosh and Dan Levenstein; and members of CERC-AAI lab, including Irina Rish and Ethan Caballero. Thank you to Karolis Ramanauskas, Noah MacCallum, Rob Graham, and Romeo Valentin for your feedback on the tutorial notebooks. 

Finally, thank you to the South Park Commons community for your support, including Ker Lee Yap, Abhay Kashyap, Jonathan Brebner, and Ruchi Sanghvi and Aditya Agarwal.

This research was generously supported by Blake Richards’s lab, which was funded by the Bank of Montreal; NSERC (Discovery Grant: RGPIN-2020-05105; Discovery Accelerator Supplement: RGPAS-2020-00031; Arthur B. McDonald Fellowship: 566355-2022); CIFAR (Canada AI Chair; Learning in Machine and Brains Fellowship); and a Canada Excellence Research Chair Award to Prof. Irina Rish; and by South Park Commons. This research was enabled in part by support provided by Calcul Québec and the Digital Research Alliance of Canada. We acknowledge the material support of NVIDIA in the form of computational resources.



 

2 comments

I think working on mechanistic interpretability in a variety of domains, architectures, and modalities seems like a reasonable research diversification bet.

However, it feels pretty odd to me to describe branching out into other modalities as crucial when we haven't yet really done anything useful with mechanistic interpretability in any domain or for any task.

You say:

With recent rapid releases of multimodal models, including Sora, Gemini, and Claude 3, it is crucial that interpretability and safety efforts remain in tandem. While language mechanistic interpretability already has strong conceptual foundations, many research papers, and a thriving community, research in non-language modalities lags behind. Given that multimodal capabilities will be part of AGI, field-building in mechanistic interpretability for non-language modalities is crucial for safety and alignment.

And on X/twitter:

Frontier models are multimodal, and it's increasingly clear that mechanistic interpretability can't only study language models.

But, I feel like the situation is relatively analogous to:

Fusion power plants will need to be built in many countries, and it's increasingly clear that fusion power plant construction can't only study building fusion power in the US.

Like yeah, you'll eventually need to handle non-language modalities and you should probably sanity check that they aren't key additional blockers with the methodology, but also why would there be key methodologies that mean it can solve our problems in the language case but not the vision/multimodal case? And the main obstacle is demonstrating basic technical feasibility, not branching out?

Again, I'd like to stress that studying a variety of cases with mech interp seems like a reasonable research diversification bet.

(And I don't want to be the language police here, just pushing back a bit on the implicit vibes.)

There is another argument that could be made for working on other modalities now: there could be insights which generalize across modalities, but which are easier to discover when working on some modalities vs. others.

I've actually been thinking, for a while now, that people should do more image model interpretability for this sort of reason. I never got around to posting this opinion, but FWIW it is the main reason I'm personally excited by the sort of work reported here. (I have mostly been thinking about generative or autoencoding image models here, rather than classifiers, but the OP says they're building toward that.)


Why would we expect there to be transferable insights that are easier to discover in visual domains than  textual domains?  I have two thoughts in mind:

First thought:

The tradeoff curve between "model does something impressive/useful that we want to understand" and "model is conveniently small/simple/etc." looks more appealing in the image domain.

Most obviously: if you pick a generative image model and an LLM which do "comparably impressive" things in their respective domains, the image model is going to be way smaller (cf.).  So there are, in a very literal way, fewer things we have to interpret -- and a smaller gap between the smallest toy models we can make and the impressive models which are our holy grails. 

Like, Stable Diffusion is definitely not a toy model, and does lots of humanlike things very well.  Yet it's pretty tiny by LLM standards.  Moreover, the SD autoencoder is really tiny, and yet it would be a huge deal if we could come to understand it pretty well.

Beyond mere parameter count, image models have another advantage, which is the relative ease of constructing non-toy input data for which we know the optimal output.  For example, this is true of:

  • Image autoencoders (for obvious reasons).
  • "Coordinate-based MLP" models (like NeRFs) that encode specific objects/scenes in their weights.  We can construct arbitrarily complex objects/scenes using 3D modeling software, train neural nets on renders of them, and easily check the ground-truth output for any input by just inspecting our 3D model at the input coordinates.

By contrast, in language modeling and classification, we really have no idea what the optimal logits are.  So we are limited to making coarse qualitative judgments of logit effects ("it makes this token more likely, which makes sense"), ignoring the important fine-grained quantitative stuff that the model is doing.

None of that is intrinsically about the image domain, I suppose; for instance, one can make text autoencoders too (and people do).  But in the image domain, these nice properties come for free with some of the "real" / impressive models we ultimately want to interpret.  We don't have to compromise on the realism/relevance of the models we choose for ease of interpretation; sometimes the realistic/relevant models are already convenient for interpretability, as a happy accident.  The capabilities people just make them that way, for their own reasons.

The hope, I guess, is that if we came pretty close to "fully understanding" one of these more convenient models, we'd learn a lot of stuff along the way about how to interpret models in general, and that would transfer back to the language domain.  Stuff like "we don't know what the logits should be" would no longer be a blocker to making progress on other fronts, even if we do eventually have to surmount that challenge to interpret LLMs.  (If we had a much better understanding of everything else, a challenge like that might be more tractable in isolation.)

Second thought:

I have a hunch that the apparent intuitive transparency of language (and tasks expressed in language) might be holding back LLM interpretability.

If we force ourselves to do interpretability in a domain which doesn't have so much pre-existing taxonomical/terminological baggage -- a domain where we no longer feel it's intuitively clear what the "right" concepts are, or even what any breakdown into concepts could look like -- we may learn useful lessons about how to make sense of LLMs when they aren't "merely" breaking language and the world down into conceptual blocks we find familiar and immediately legible.

When I say that "apparent intuitive transparency" affects LLM interpretability work, I'm thinking of choices like:

  • In circuit work, researchers select a familiar concept from a pre-existing human "map" of language / the world, and then try to find a circuit for it.
    • For example, we ask "what's the circuit for indirect object identification?", not "what's the circuit for frobnoloid identification?" -- where "a frobnoloid" is some hypothetical type-of-thing we don't have a standard term for, but which LMs identify because it's useful for language modeling.
    • (To be clear, this is not a critique of the IOI work, I'm just talking about a limit to how far this kind of work can go in the long view.)
  • In SAE work, researchers try to identify "interpretable features."
    • It's not clear to me what exactly we mean by "interpretable" here, but being part of a pre-existing "map" (as above) seems to be a large part of the idea.
    • "Frobnoloid"-type features that have recognizable patterns, but are weird and unfamiliar, are "less interpretable" under the prevailing use of the term, I think.

In both of these lines of work, there's a temptation to try to parse out the LLM computation into operations on parts we already have names for -- and, in cases where this doesn't work, to chalk it up either to our methods failing, or to the LLM doing something "bizarre" or "inhuman" or "heuristic / unsystematic."

But I expect that much of what LLMs do will not be parseable in this way.  I expect that the edge that LLMs have over pre-DL AI is not just about more accurate extractors for familiar, "interpretable" features; it's about inventing a decomposition of language/reality into features that is richer, better than anything humans have come up with.  Such a decomposition will contain lots of valuable-but-unfamiliar "frobnoloid"-type stuff, and we'll have to cope with it.

To loop back to images: relative to text, with images we have very little in the way of pre-conceived ideas about how the domain should be broken down conceptually.

Like, what even is an "interpretable image feature"?

Maybe this question has some obvious answers when we're talking about image classifiers, where we expect features related to the (familiar-by-design) class taxonomy -- cf. the "floppy ear detectors" and so forth in the original Circuits work.

But once we move to generative / autoencoding / etc. models, we have a relative dearth of pre-conceived concepts.  Insofar as these models are doing tasks that humans also do, they are doing tasks which humans have not extensively "theorized" and parsed into concept taxonomies, unlike language and math/code and so on.  Some of this conceptual work has been done by visual artists, or photographers, or lighting experts, or scientists who study the visual system ... but those separate expert vocabularies don't live on any single familiar map, and I expect that they cover relatively little of the full territory.

When I prompt a generative image model, and inspect the results, I become immediately aware of a large gap between the amount of structure I recognize and the amount of structure I have names for. I find myself wanting to say, over and over, "ooh, it knows how to do that, and that!" -- while knowing that, if someone were to ask, I would not be able to spell out what I mean by each of these "that"s.

Maybe I am just showing my own ignorance of art, and optics, and so forth, here; maybe a person with the right background would look at the "features" I notice in these images, and find them as familiar and easy to name as the standout interpretable features from a recent LM SAE.  But I doubt that's the whole of the story.  I think image tasks really do involve a larger fraction of nameless-but-useful, frobnoloid-style concepts.  And the sooner we learn how to deal with those concepts -- as represented and used within NNs -- the better.