Circuit discovery has been restricted to the single-forward-pass setting, because the algorithms to attribute changes in behavior to particular neurons / SAE features need gradients, and you can't take a gradient through the sampled chain of thought. Or... can you?
It turns out that taking gradients through random discrete actions is an essential part of reinforcement learning. We can estimate the gradients of an expectation over CoTs, with respect to the features, using the score function estimator. We can combine this with integrated gradients to produce a version of EAP-IG which works through averages over chains of thought.
The task we attempt is circuit discovery, as defined by Conmy et al. Formally, given a computational DAG $G$ which represents a neural network, we want to find which subgraph $C \subseteq G$ is responsible for a behavior. We do this by defining a 'task loss', which compares the performance of the subgraph to the performance of the whole network. Let that loss be $L$, and let $x$ and $x'$ be the clean and corrupted datapoints. The loss for a single pair of data points is:

$$L(C; x, x') = \ell\big(f_C(x, x'),\, f(x)\big),$$

where $f(x)$ is the output of the whole network on the clean input, $f_C(x, x')$ is the output of the subgraph $C$ with the edges outside $C$ patched to the values they take on the corrupted input, and $\ell$ compares the two outputs (e.g., a KL divergence).
The overall loss of a circuit is simply the average of this loss over all datapoints $x$ and corrupted datapoints $x'$:

$$L(C) = \frac{1}{|\mathcal{D}|} \sum_{(x, x') \in \mathcal{D}} L(C; x, x').$$
To connect this to integrated gradients, we introduce variables $\mathbf{z}$: the scalar $z_i \in [0, 1]$ controls whether the $i$th edge (or node) is included in $C$ or not. That is, the value of the $i$th edge is replaced by:

$$e_i(z_i) = z_i\, e_i(x) + (1 - z_i)\, e_i(x'),$$

where $e_i(x)$ and $e_i(x')$ are the values the edge takes under the clean and corrupted inputs. If we set $z_i = 0$, the edge is not included, i.e., it has the value it would get under the corrupted input. If we set $z_i = 1$, then the edge has the value it gets from running the computational graph forward on the clean input.
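As a minimal sketch (in PyTorch, with hypothetical names), the interpolation and the gradient it exposes look like this:

```python
import torch

def interpolate_edge(e_clean: torch.Tensor, e_corr: torch.Tensor,
                     z: torch.Tensor) -> torch.Tensor:
    """Edge value under interpolation coefficient z:
    z = 1 -> clean value (edge included), z = 0 -> corrupted value (ablated)."""
    return z * e_clean + (1.0 - z) * e_corr

# z requires grad so a loss downstream of the edge can be attributed to it.
z = torch.tensor(0.5, requires_grad=True)
e = interpolate_edge(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.0]), z)
toy_loss = e.sum()  # stand-in for the task loss downstream of this edge
toy_loss.backward()
# dL/dz = sum(e_clean - e_corr) = 3.0
assert torch.isclose(z.grad, torch.tensor(3.0))
```

The key point is that $z_i$ is an ordinary scalar in the forward pass, so autograd gives us $\partial L / \partial z_i$ for free.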
Our first ingredient can be any gradient-based method for circuit discovery. I've chosen to focus on EAP with Integrated Gradients because it's still the circuit discovery algorithm with the best balance of simplicity and performance. You could make a CoT version of Attribution Patching as well.
To attribute behavioral loss to some configuration of edges (a concrete value of $\mathbf{z}$), we compute the gradient of the task loss with respect to $z_i$, which determines whether we include an edge or not: $\partial L / \partial z_i$. In EAP-IG, we average this for $\mathbf{z}$ between $\mathbf{0}$ and $\mathbf{1}$, for all edges of the graph simultaneously. If we interpolate at $m + 1$ points between 0 and 1, the attribution for a single data point is:

$$\mathrm{Attrib}_i(x, x') = \frac{1}{m + 1} \sum_{k=0}^{m} \frac{\partial L(\mathbf{z}; x, x')}{\partial z_i} \bigg|_{\mathbf{z} = \frac{k}{m} \mathbf{1}},$$

for loss defined using the task loss, the full graph and the graph corrupted by $\mathbf{z}$: $L(\mathbf{z}; x, x') = \ell\big(f_{\mathbf{z}}(x, x'),\, f(x)\big)$. Notice that we average over $\mathbf{z}$, so we take $\mathbf{z} = \mathbf{0}$, $\mathbf{z} = \mathbf{1}$, and intermediate points.
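To make the averaging concrete, here is a toy sketch (PyTorch; the hypothetical `loss_fn` stands in for the task loss as a function of the interpolation vector) that averages $\partial L / \partial z_i$ over the interpolation points:

```python
import torch

def eapig_attrib(loss_fn, n_edges: int, m: int = 5) -> torch.Tensor:
    """Average dL/dz_i over z = (k/m) * 1 for k = 0, ..., m."""
    total = torch.zeros(n_edges)
    for k in range(m + 1):
        z = torch.full((n_edges,), k / m, requires_grad=True)
        loss_fn(z).backward()
        total += z.grad
    return total / (m + 1)

# Toy quadratic "task loss" L(z) = sum_i w_i * z_i^2, so dL/dz_i = 2 w_i z_i.
# Averaging 2 w_i (k/m) over k = 0..m gives exactly w_i.
w = torch.tensor([1.0, 2.0, 3.0])
attrib = eapig_attrib(lambda z: (w * z**2).sum(), n_edges=3)
assert torch.allclose(attrib, w)
```

In practice `loss_fn` would run the interpolated forward pass of the network; the toy quadratic just makes the averaged gradient easy to verify by hand.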
Now for the second ingredient: policy gradients. Suppose I have a policy $\pi_\theta$ and some loss function $L(\tau)$, which depends on trajectories $\tau$ of actions from the policy. The policy is parameterized by some parameters $\theta$. The expected loss over trajectories is:

$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\big[L(\tau)\big].$$
We'd like to take gradients $\nabla_\theta J(\theta)$. These are tricky because the loss $L(\tau)$ does not depend on $\theta$ directly. Instead, it depends on $\theta$ through the distribution over actions in the trajectory, which determines the expectation of $L(\tau)$.
The policy gradient theorem tells us that the gradient is the expected gradient of the log-probability of actions, weighted by how big $L(\tau)$ is:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\big[L(\tau)\, \nabla_\theta \log \pi_\theta(\tau)\big].$$
Let's take this formula as given. I explain it in these two posts, but one of the keys to it is that we can swap the integral and differentiation signs, and that $\nabla_\theta \pi_\theta(\tau) = \pi_\theta(\tau)\, \nabla_\theta \log \pi_\theta(\tau)$ by the chain rule.
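As a sanity check on the formula (a toy sketch, not part of the method itself), we can compare the score-function estimate against the exact gradient for a three-action softmax policy, where the exact expectation is computable in closed form:

```python
import torch

torch.manual_seed(0)

theta = torch.tensor([0.2, -0.1, 0.5], requires_grad=True)  # policy logits
loss_per_action = torch.tensor([1.0, 3.0, -2.0])            # L(a) per action

# Exact gradient of J(theta) = E_a[L(a)] = sum_a pi(a) L(a).
probs = torch.softmax(theta, dim=0)
(exact,) = torch.autograd.grad((probs * loss_per_action).sum(), theta)

# Score-function estimator: average L(a) * grad log pi(a) over samples.
dist = torch.distributions.Categorical(logits=theta)
actions = dist.sample((200_000,))
surrogate = (loss_per_action[actions] * dist.log_prob(actions)).mean()
(estimate,) = torch.autograd.grad(surrogate, theta)

# The two agree up to Monte Carlo noise.
assert torch.allclose(estimate, exact, atol=0.05)
```

The `surrogate` trick (differentiating the sample mean of $L(\tau) \log \pi_\theta(\tau)$) is exactly how REINFORCE-style losses are implemented in practice.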
It's worth expanding on what policy gradients are for, and why they're useful. Policy gradients give us the gradient of how the average outcome over many trajectories varies, when we vary the parameters . It's not for a particular rollout, it's for the whole distribution. As such, any gradients that we take include the effect of the CoT on the output.
The function can be a function of any number of steps in the trajectory. It can be just of the final step (if we're looking at e.g. full CoTs and a single token answer). It can be of many steps at the end (if we're considering a CoT + whether an answer matches the truth, as rated by some other model). It can be basically anything. That's why it's the workhorse of modern LLM RL: PPO, GRPO, etc. are all based on policy gradients.
So if we want to attribute behavior sampled through CoTs to parts of the network, we can just use both of these simultaneously.
We define a task loss $L(\tau)$ that depends on the tokens until now, the output of the original model, and the output of the new model. The behavior that we want to study (and find sub-circuits for) is thus the expectation when sampling from the corrupted subgraph:

$$J(\mathbf{z}) = \mathbb{E}_{\tau \sim p_{\mathbf{z}}}\big[L(\tau)\big].$$
We sample from $p_{\mathbf{z}}$ autoregressively: we start with the prompt $x_1$ and corrupt it to find $x'_1$; that lets us compute $p_{\mathbf{z}}(\cdot \mid x_1)$ and sample a token $x_2$ from it; then we corrupt $x_{1:2}$ to get $x'_{1:2}$, etc. I've abbreviated this in the expression above as $\tau \sim p_{\mathbf{z}}$.
Now we see how we can use both elements.
Integrated gradients: to attribute the behavior through the CoT to components $z_i$ of the model, we simply need to take $\nabla_{\mathbf{z}} J(\mathbf{z})$ interpolated at various points for $\mathbf{z}$ between $\mathbf{0}$ and $\mathbf{1}$. That is, we want:

$$\mathrm{Attrib}_i = \frac{1}{m + 1} \sum_{k=0}^{m} \frac{\partial J(\mathbf{z})}{\partial z_i} \bigg|_{\mathbf{z} = \frac{k}{m} \mathbf{1}}.$$
We've removed the dependence of Attrib on the data points because we're sampling things from the model, presumably with some context. But we could average over some contexts, why not.
Policy gradients: The gradient $\nabla_{\mathbf{z}} J(\mathbf{z})$ is of an expectation over a probability distribution that depends on $\mathbf{z}$. To compute it, we need to use the policy gradient theorem:

$$\nabla_{\mathbf{z}} J(\mathbf{z}) = \mathbb{E}_{\tau \sim p_{\mathbf{z}}}\big[L(\tau)\, \nabla_{\mathbf{z}} \log p_{\mathbf{z}}(\tau)\big].$$
To estimate this expectation, we sample a bunch of CoTs from $p_{\mathbf{z}}$ and average their values of $L(\tau)\, \nabla_{\mathbf{z}} \log p_{\mathbf{z}}(\tau)$.
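Putting the two ingredients together, here is a toy end-to-end sketch (hypothetical names; a single-token "CoT" whose logits are interpolated by a scalar $z$). It estimates the attribution by averaging the score-function gradient over the interpolation points, and checks the estimate against the exact gradient, which is computable in this tiny example:

```python
import torch

torch.manual_seed(0)

clean_logits = torch.tensor([2.0, 0.0, -1.0])  # logits of the full graph
corr_logits = torch.tensor([0.0, 0.0, 0.0])    # logits of the corrupted graph
L_tau = torch.tensor([1.0, 0.0, -1.0])         # task loss for each "CoT"

def logits_at(z: torch.Tensor) -> torch.Tensor:
    """Logits of the interpolated model p_z."""
    return z * clean_logits + (1 - z) * corr_logits

def attrib_sampled(m: int = 5, n: int = 100_000) -> float:
    """Average the score-function estimate of dJ/dz over z = k/m."""
    total = 0.0
    for k in range(m + 1):
        z = torch.tensor(k / m, requires_grad=True)
        dist = torch.distributions.Categorical(logits=logits_at(z))
        tau = dist.sample((n,))
        surrogate = (L_tau[tau] * dist.log_prob(tau)).mean()
        (g,) = torch.autograd.grad(surrogate, z)
        total += g.item()
    return total / (m + 1)

def attrib_exact(m: int = 5) -> float:
    """Exact dJ/dz, averaged over the same interpolation points."""
    total = 0.0
    for k in range(m + 1):
        z = torch.tensor(k / m, requires_grad=True)
        probs = torch.softmax(logits_at(z), dim=0)
        (g,) = torch.autograd.grad((probs * L_tau).sum(), z)
        total += g.item()
    return total / (m + 1)

mc, exact = attrib_sampled(), attrib_exact()
assert abs(mc - exact) < 0.05
```

A real implementation would replace the single Categorical with an autoregressive rollout of the interpolated transformer, and the scalar $z$ with one coefficient per edge, but the estimator has the same shape.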
We can just plug this into the previous equation, and there we have it: attribution to circuit components over chains of thought.
This method is very flexible, because it's just the old EAP-IG, except now we can also compute gradients over probability distributions. The $z_i$ can be assigned to neurons, attention heads, SAE components, anything.
They don't even have to be constant across time. We can have a separate component $z_{i,t}$ for each time-step $t$, to study the effect of a component at that particular time step. The same is possible if the 'time step' moves depending on where a token falls, but I think you're missing some of the effect in that case.
I haven't implemented this. It's tricky with open-source packages, because you can't just interpolate between the original and corrupted inputs in vLLM, and Huggingface only has quadratic sampling. To make it really efficient, it's also nice to be able to compute the gradients w.r.t. $\mathbf{z}$ at every step, reusing the same KV-cache attention that you computed while sampling.
I might fill this gap with open-source tooling myself, especially if I can get funding for a month to do it.