Review

I'm writing this post to discuss solutions to the September challenge and to present the challenge for October.

If you haven't read the first post in this sequence, I'd recommend starting there - it outlines the purpose behind these challenges and the recommended prerequisite material.

October Problem

The problem for this month is interpreting a model which has been trained to sort a list. The model is fed sequences like:

[11, 2, 5, 0, 3, 9, SEP, 0, 2, 3, 5, 9, 11]

and has been trained to predict each element in the sorted list (in other words, the output at the SEP token should be a prediction of 0, the output at 0 should be a prediction of 2, etc).
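As a minimal sketch of this data format (not the actual training code), the snippet below builds a sequence and lists which position is expected to predict which token. The `SEP` placeholder, the vocabulary size of 12, the list length of 6, and both function names are assumptions made for illustration; the Streamlit page has the real details.

```python
import random

SEP = "SEP"  # placeholder; the real separator token id is an assumption


def make_example(seq_len=6, vocab_size=12, rng=None):
    """Build [unsorted list, SEP, sorted list]. Sizes are illustrative guesses."""
    rng = rng or random.Random()
    unsorted = [rng.randrange(vocab_size) for _ in range(seq_len)]
    return unsorted + [SEP] + sorted(unsorted)


def prediction_targets(tokens):
    """Return (position, target) pairs: each position from SEP onwards
    should predict the token that comes immediately after it."""
    sep = tokens.index(SEP)
    return [(pos, tokens[pos + 1]) for pos in range(sep, len(tokens) - 1)]


# The example sequence from the post.
example = [11, 2, 5, 0, 3, 9, SEP, 0, 2, 3, 5, 9, 11]
for pos, target in prediction_targets(example):
    print(f"position {pos} ({example[pos]!r}) should predict {target}")
```

Only the positions from SEP onwards carry a target here, which matches the description above: the unsorted half is input, and the sorted half is what the model is asked to predict.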

The model is attention-only, with 1 layer and 2 attention heads. It was trained with layernorm, weight decay, and an Adam optimizer with a linearly decaying learning rate. This model should be much easier to interpret than previous models in the sequence, so I'm expecting solutions to have entirely reverse-engineered this model, as well as presenting adversarial examples and explaining how & why they work.

You can find more details on the Streamlit page. Feel free to reach out if you have any questions!

September Problem - Solutions

You can read full solutions on the Streamlit page, or on my personal website. Both of these sources host interactive charts (Plotly and CircuitsVis), so they are more suitable than a LessWrong/AF post for discussing the solutions in depth. However, I've presented a much shorter version of the solution below. If you're interested in these problems, I'd recommend having a try before you read on!


To calculate each digit Ci, we require 2 components - the sum and the carry. The formula for Ci is (sum + int(carry)) % 10, where sum is the digit sum Ai + Bi, and carry is True when A(i+1) + B(i+1) >= 10, i.e. when the next digit to the right overflows. (This ignores issues of carrying digits multiple times, which I won't discuss in this solution.)
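A short sketch of this formula in plain Python may help make it concrete. The function name is hypothetical, digits are assumed to be stored most significant first (so index i+1 is the digit to the right), and any leading overflow digit is ignored. The second call shows the multiple-carry case that the formula deliberately leaves out.

```python
def add_single_carry(a_digits, b_digits):
    """Apply Ci = (Ai + Bi + carry) % 10, where carry only looks at the
    digit pair immediately to the right. Chained carries are NOT handled."""
    n = len(a_digits)
    out = []
    for i in range(n):
        digit_sum = a_digits[i] + b_digits[i]
        # carry is True when the next (less significant) pair overflows
        carry = i + 1 < n and a_digits[i + 1] + b_digits[i + 1] >= 10
        out.append((digit_sum + int(carry)) % 10)
    return out


print(add_single_carry([3, 8], [4, 5]))        # 38 + 45 -> [8, 3]
print(add_single_carry([1, 9, 9], [0, 0, 1]))  # 199 + 1 -> [1, 0, 0], true answer is 200
```

In the second example the carry out of the last digit should ripple through the middle 9 into the first digit, but the single-step rule only sees 9 + 0 < 10 and so misses it.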

We calculate the carry by using the hierarchy . An attention head in layer 0 will attend to the first number in this hierarchy that it sees, and if that number is  then that means the digit will be carried. There are also some layer 0 attention heads which store the sum information in certain sequence positions - either by attending uniformly to both digits, or by following the reverse hierarchy so it can additively combine with something that follows the hierarchy. Below is a visualisation of the QK circuits for the layer 0 attention heads at the positions which are storing this "carry" information, to show how they're implementing the hierarchy.

At the end of layer 0, the sum information is stored in the residual stream as points around a circle traced out by two vectors, parameterized by an angle. The carry information is stored in the residual stream as a single direction.

The model manages to store the sum of the two digits modulo 10 in a circular way by the end of layer 0 (although it's not stored in exactly the same way it will be at the end of the model). We might guess the model takes advantage of some trig identities to do this, although I didn't have time to verify this conclusively.
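To illustrate the kind of trig identity this guess refers to (and not as a claim about what the model actually computes), here is a small sketch. It assumes each digit d is represented as the point (cos w, sin w) with w = 2*pi*d/10, which is a hypothetical encoding chosen for illustration. The angle-addition identities then give the point for (a + b) mod 10 directly from the points for a and b.

```python
import math


def embed(d):
    """Hypothetical circular encoding of a digit: a point on the unit circle."""
    w = 2 * math.pi * d / 10
    return (math.cos(w), math.sin(w))


def combine(u, v):
    """Angle addition: cos(x+y) = cos x cos y - sin x sin y,
                       sin(x+y) = sin x cos y + cos x sin y."""
    return (u[0] * v[0] - u[1] * v[1], u[1] * v[0] + u[0] * v[1])


# Adding angles wraps around the circle, so the result lands on (a + b) % 10.
a, b = 7, 8
combined = combine(embed(a), embed(b))
expected = embed((a + b) % 10)
print(combined, expected)
```

The wrap-around comes for free because angles are periodic, which is one reason a circular representation is a natural fit for arithmetic modulo 10.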

The heads in layer 1 mostly process this information by self-attending. They don't seem as important as heads 0.1 and 0.2 (measured in terms of loss after ablation), and it seems likely they're mainly clearing up some of the representations learned by the layer 0 heads (and dealing with logic about when to carry digits multiple times).

By the end of layer 1, the residual stream is parameterized by a single value: an angle. The digits from 0-9 are evenly spaced around the unit circle, and the model's prediction depends on which digit's angle the residual stream is closest to. Two visualisations of this are shown below: (1) the singular value decomposition of the unembedding matrix, and (2) the residual stream projected onto these first two singular directions.
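The "closest angle" readout can be sketched as follows. This is a toy stand-in for the unembedding, under the assumption that digit d sits at angle 2*pi*d/10 (the real ordering and orientation of the digits around the circle may differ) and that the score for each digit is the cosine of the angular gap; the function names are hypothetical.

```python
import math


def digit_angle(d):
    """Assumed position of digit d on the unit circle."""
    return 2 * math.pi * d / 10


def decode(theta):
    """Pick the digit whose angle is closest to theta, using
    cos(theta - angle) as a logit-like score."""
    return max(range(10), key=lambda d: math.cos(theta - digit_angle(d)))


print(decode(digit_angle(7) + 0.1))  # small perturbation, still decodes to 7
print(decode(digit_angle(9) + 0.4))  # wraps past 9 and decodes to 0
```

Using the cosine of the gap rather than the raw difference means the wrap-around between 9 and 0 needs no special handling.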

Best Submissions

There were a few strong solutions to this problem, although fewer than for last month (I attribute this to the problem being quite challenging and difficult to get traction on, which as mentioned I hope we'll fix going forwards). The best solution came from Andy Arditi, who did some great work with Fourier transforms to identify the most important directions in various weight matrices in the model. Honorable mention to Cary Jin, who provided the best explanation for how the model decides when to carry a digit.

Best of luck for this and future challenges!
