I think this is a valuable contribution. I used to think that DEMix-like techniques would dominate in this space because in principle they could achieve close-to-zero alignment tax, but actually absorption is probably crucial, especially in large pre-training runs where models might learn with very limited mislabeled data.
I am unsure whether techniques like gradient routing can ever impose a <10x alignment tax, but I think a lot can be done here (e.g. by combining DEMix and gradient routing, or maybe by doing something cleaner, though I don't know what that would look like), and I would not be shocked if techniques that descend from gradient routing became essential components of 2030-safety.
I am confused about Table 1's interpretation.
Ablating the target region of the network increases loss greatly on both datasets. We then fine-tune the model on a train split of FineWeb-Edu for 32 steps to restore some performance. Finally, we retrain for twenty steps on a separate split of two WMDP-bio forget set datapoints, as in Sheshadri et al. (2024), and report the lowest loss on the validation split of the WMDP-bio forget set. The results are striking: even after retraining on virology data, loss increases much more on the WMDP-bio forget set (+0.182) than on FineWeb-Edu (+0.032), demonstrating successful localization and robust removal of virology capabilities.
To recover performance on the retain set, you fine-tune on 32 unique examples of FineWeb-Edu, whereas when assessing loss after retraining on the forget set, you fine-tune on the same 2 examples 10 times. This makes it hard to conclude that retraining on WMDP is harder than retraining on FineWeb-Edu, as the retraining intervention attempted for WMDP is much weaker (fewer unique examples, more repetition).
Thanks for pointing this out! Our original motivation for doing it that way was that we thought of the fine-tuning on FineWeb-Edu as a "coherence" step designed to restore the model's performance after ablation, which damaged it a lot. We noticed that this "coherence" step helped validation loss on both forget and retain. However, your criticism is valid, so we have updated the paper so that we retrain on the training distribution (which contains some of the WMDP-bio forget set). We still see that while the loss on FineWeb-Edu decreases to almost its value before ablation, the loss on the WMDP-bio forget set is around 0.1 nats above its value before ablation, showing that it is harder to retrain virology after ablation than just FineWeb-Edu data. Since we re-train on the training distribution (N=12 times with different data), we would expect that both losses would be retrainable at roughly the same rate, but this is not the case, showing that localization and then ablation has an effect.
We present gradient routing, a way of controlling where learning happens in neural networks. Gradient routing applies masks to limit the flow of gradients during backpropagation. By supplying different masks for different data points, the user can induce specialized subcomponents within a model. We think gradient routing has the potential to train safer AI systems, for example, by making them more transparent, or by enabling the removal or monitoring of sensitive capabilities.
In this post, we:
If you’re interested in further discussion or details, check out the paper and its extensive appendices, or the code for gradient routing.
Gradient routing
Gradient routing allows the user to configure what data (at the level of tokens, documents, or any other feature of the data) causes learning updates where in a neural network (parameters, activations, modules). In full generality, this configuration is achieved by assigning weights to every edge in the computational graph, for every data point. These weights are then multiplied by the gradients that get backpropagated through these edges. This is formalized in the paper.
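One way to write the general scheme down, as a sketch rather than the paper's exact notation: assume scalar-valued nodes, and let $\alpha_{(u,v)}(x)$ be the (hypothetically named) weight on edge $(u, v)$ for data point $x$. The routed backward pass then replaces the usual chain rule at node $u$ with a weighted one:

```latex
\frac{\tilde{\partial} L}{\partial u}
  \;=\; \sum_{v \,\in\, \mathrm{children}(u)}
        \alpha_{(u,v)}(x)\,
        \frac{\tilde{\partial} L}{\partial v}\,
        \frac{\partial v}{\partial u}
```

Setting every $\alpha_{(u,v)}(x) = 1$ recovers ordinary backpropagation, and setting a weight to $0$ blocks the gradient along that edge for that data point.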
In practice, we implement gradient routing by applying stop-gradient masks selectively in order to stop the flow of gradients during backprop:
Code: The user specifies the gradient_masks corresponding to each batch of data x.
Note: We say “route X to Y” to mean “limit gradient updates on data X to region Y of the network.”
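To make the masking idea concrete, here is a minimal sketch in plain Python rather than any particular deep learning framework. It assumes a single linear layer with squared-error loss, and the names `routed_backward` and `grad_mask` are illustrative, not taken from the paper's code. The forward pass is unchanged; the mask only multiplies the gradient arriving at each output unit, which is the effect a stop-gradient mask has in an autodiff framework.

```python
def forward(W, x):
    """Linear layer: y[i] = sum_j W[i][j] * x[j]."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def routed_backward(W, x, target, grad_mask):
    """Gradient of 0.5 * ||y - target||^2 w.r.t. W, with the gradient at
    output unit i multiplied by grad_mask[i] (1 = learn, 0 = blocked)."""
    y = forward(W, x)
    # Upstream gradient dL/dy, masked per output unit.
    dy = [m * (y_i - t_i) for m, y_i, t_i in zip(grad_mask, y, target)]
    # dL/dW[i][j] = dy[i] * x[j]
    return [[dy_i * x_j for x_j in x] for dy_i in dy]


def sgd_step(W, grad, lr=0.1):
    """Plain gradient descent update."""
    return [[w - lr * g for w, g in zip(w_row, g_row)]
            for w_row, g_row in zip(W, grad)]


W = [[0.5, -0.2], [0.1, 0.3]]
x = [1.0, 2.0]
target = [1.0, 1.0]

# Route this data point to output unit 0 only: unit 1 gets no update.
grad = routed_backward(W, x, target, grad_mask=[1.0, 0.0])
W_new = sgd_step(W, grad)
```

In this toy setup the second row of `W_new` is identical to the second row of `W`, even though the data point contributed to the loss at both output units.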
MNIST latent space splitting
We train an MLP-based autoencoder to encode images of handwritten digits into vectors with 32 elements, then decode them back into full images. Our goal is to “split” the latent space so that half of it corresponds to one subset of digits, and the other half corresponds to others, such that it is not possible to decode digits from the “wrong” half. This task is difficult: an autoencoder trained only on a subset of digits learns a latent space from which other digits can be decoded accurately (a form of zero-shot generalization). It is a non-linear kind of concept erasure.
To achieve splitting, we route digits 0-4 through the top half of the encoding and digits 5-9 through the bottom half of the encoding. We apply L1 regularization to the encoding to encourage specialization. The result: a latent space which represents 0-4 in the top dimensions and 5-9 in the bottom dimensions!
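The routing rule for this experiment can be sketched as a per-image mask over the 32 latent dimensions. This is an illustration only: the function names are hypothetical, and it assumes that "top half" means indices 0 to 15 and "bottom half" means indices 16 to 31.

```python
ENCODING_DIM = 32  # latent size stated in the text


def encoding_grad_mask(digit, dim=ENCODING_DIM):
    """Return a 0/1 gradient mask over latent dimensions for one image.

    Assumption: "top half" is indices [0, dim/2) and "bottom half" is
    indices [dim/2, dim).
    """
    if not 0 <= digit <= 9:
        raise ValueError("digit must be in 0..9")
    half = dim // 2
    if digit <= 4:
        return [1] * half + [0] * (dim - half)  # digits 0-4: top half
    return [0] * half + [1] * (dim - half)      # digits 5-9: bottom half


def l1_penalty(encoding, coeff):
    """L1 regularization term added to the reconstruction loss."""
    return coeff * sum(abs(z) for z in encoding)


batch_labels = [3, 7, 0, 9]
masks = [encoding_grad_mask(d) for d in batch_labels]
```

Each mask would multiply the gradient flowing back through the encoding for the corresponding image, so a given digit only ever updates the encoder through its assigned half.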
Localizing capabilities in language models
Steering scalar
Much interpretability work (most notably, on SAEs) seeks to identify meaningful directions in the space of a model’s internal activations. What if we could specify some of those dimensions at training time, instead of having to search for them afterward? We did this by routing the token _California to the 0th dimension of the residual stream. Interestingly, the entries of the Transformer unembedding matrix closest to the _California token were all highly related: California, _Californ, _Oregon, _Colorado, _Texas, _Florida, _Arizona, _Sacramento, _Los, etc., indicating that our localization had a broader effect on the model’s training than that single token.
Robust unlearning
Our most extensive experiments are on the removal of capabilities in language models when data labels are limited.
We want the model to be able to predict some data (the “retain” data) but not other data (the “forget” data). The key idea: if we route forget data to particular regions of the network, then delete those parts of the network, we must have robustly removed those capabilities. One scheme for achieving this is called ERA (Expand-Route-Ablate).
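A toy sketch of the three stages named by ERA, operating on a weight matrix stored as a list of rows (one row per hidden unit). All function names are hypothetical, the zero initialization is a placeholder, and the routing rule shown (forget data updates only the new units, retain data updates every unit) is an assumption for illustration, not necessarily the configuration used in the paper.

```python
def expand(weights, n_extra, init=0.0):
    """Expand: append n_extra new hidden units (rows) to a weight matrix."""
    n_in = len(weights[0])
    return [row[:] for row in weights] + [[init] * n_in for _ in range(n_extra)]


def route_mask(is_forget, n_orig, n_extra):
    """Route: per-unit gradient mask for one data point.

    Assumption: forget data updates only the new units, while retain
    data updates every unit. Other choices are possible.
    """
    if is_forget:
        return [0] * n_orig + [1] * n_extra
    return [1] * (n_orig + n_extra)


def ablate(weights, n_orig):
    """Ablate: delete the units that forget data was routed to."""
    return [row[:] for row in weights[:n_orig]]


W = [[0.2, -0.1], [0.4, 0.3]]      # 2 original hidden units
W_big = expand(W, n_extra=1)        # 3 hidden units after expansion
forget_mask = route_mask(True, n_orig=2, n_extra=1)
retain_mask = route_mask(False, n_orig=2, n_extra=1)
W_final = ablate(W_big, n_orig=2)   # back to the original 2 units
```

In a real run, training would happen between `expand` and `ablate`, with the masks applied to gradients on each batch.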
We compare gradient routing to three baselines on a made-up unlearning problem based on GPT-generated children’s stories:
We measure performance at different proportions of random forget data labeling. Unlabeled forget data is treated as retain data for training purposes. The idea is to simulate frontier AI applications where perfectly labeling all training data is infeasible.
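The labeling setup can be sketched as follows. This is an illustrative helper, not the authors' code: the name `simulate_partial_labels` is hypothetical, and it assumes that an exact fraction of the true forget documents is sampled uniformly at random (rather than, say, labeling each one independently with some probability).

```python
import random


def simulate_partial_labels(is_forget, label_fraction, seed=0):
    """Return the training-time labels seen by the training method.

    is_forget: list of bools, the ground truth for each document.
    label_fraction: proportion of true forget documents that get labeled.
    Unlabeled forget documents come back as False, i.e. they are treated
    as retain data during training.
    """
    rng = random.Random(seed)
    forget_idx = [i for i, f in enumerate(is_forget) if f]
    n_labeled = round(label_fraction * len(forget_idx))
    labeled = set(rng.sample(forget_idx, n_labeled))
    return [i in labeled for i in range(len(is_forget))]


truth = [True, False, True, True, False, True]
observed = simulate_partial_labels(truth, label_fraction=0.5)
```

Here `observed` marks two of the four true forget documents as forget; the other two are indistinguishable from retain data as far as training is concerned.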
When labels are limited, we observe that Expand, Route, Ablate outperforms other methods, including data filtering. By influencing where the model learns features, gradient routing allows limited labels to scale to unlabeled examples, despite the training loss selecting for models that perform well on the unlabeled data. In contrast, DEMix also localizes learning updates (to MLP expert submodules), but because only one expert (per layer) participates in each forward pass, the features learned based on the labeled forget samples are not able to “absorb” those from the unlabeled forget samples.
Unlearning virology
We apply the same idea to unlearn virology concepts in a larger (0.7B parameter) model by routing gradients on 20 tokens: COVID, _COVID, RNA, _infections, DNA, _genome, _virus, _gene, _viruses, _mutations, _antibodies, _influenza, _bacteria, PCR, _cell, _herpes, _bacterial, _pathogens, _tumor, and _vaccine.
After retraining on a small number of virology examples, ERA causes a 0.18-nat increase in CE loss on virology data while only incurring a 0.03-nat increase in CE on non-virology data. This loss increase applies even when considering only virology tokens that we did not route on, demonstrating that ERA is not merely down-weighting tokens.
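Selecting which positions to route on can be sketched as a simple lookup against the 20 tokens listed in this section. The helper name is hypothetical, and matching on token strings is an assumption; a real pipeline would compare tokenizer ids, and whether the flag applies to the input position or to the prediction target is left open here.

```python
ROUTED_TOKENS = frozenset([
    "COVID", "_COVID", "RNA", "_infections", "DNA", "_genome", "_virus",
    "_gene", "_viruses", "_mutations", "_antibodies", "_influenza",
    "_bacteria", "PCR", "_cell", "_herpes", "_bacterial", "_pathogens",
    "_tumor", "_vaccine",
])


def routed_positions(tokens, routed=ROUTED_TOKENS):
    """Return a 0/1 flag per position: 1 where the token is routed on.

    Assumption: matching is by exact token string.
    """
    return [1 if tok in routed else 0 for tok in tokens]


example = ["The", "_virus", "_has", "_an", "RNA", "_genome"]
flags = routed_positions(example)
```

Positions flagged 1 would have their gradients restricted to the target region of the network; all other positions train as usual.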
Scalable oversight via localization
Scalable oversight means effectively training a model despite limited ability to supervise its behavior (e.g., with data labels or a reward function).
We use gradient routing to train a policy to reach two types of (terminal) squares in a gridworld, diamond and ghost. In our gridworld, some squares have oversight and some don’t. When a square has oversight, we (the supervisor) can observe its type. When a square doesn’t have oversight, we can observe that it was reached, but not its type. The aim is to produce a policy that navigates to diamond and not ghost, even when oversight is limited to a small portion of episodes. Naturally, this is challenging, because it’s difficult to reinforce the desired behavior when we can’t observe it!
In real-world problems, label availability may be correlated with task type or difficulty, and a policy might not generalize correctly because of this fact. To model this fundamental difficulty, our gridworld environment allows the policy to condition on the presence or absence of oversight. In terms of alignment threat models, this conditioning could represent an AGI with situational awareness that changes its behavior based on its predictions about human overseers’ ability to understand its behavior. Or, in more prosaic terms, it could represent a model that has learned to exploit systematic biases in training data labels. As a consequence of our problem setup, there exists no reward function which induces optimal diamond-seeking behavior.
We circumvent this difficulty by using gradient routing to supervise the model mechanistically: the policy is trained to seek any terminal square, but as we train, we use gradient routing to induce specialized submodules.
The first layer in our policy network is a mixture of experts (MoE) layer with two experts. During training, we aim to make each expert responsible for reaching one of the aforementioned squares. We do this by routing policy gradients through the expert that corresponds to the type of square the policy reached in the given episode. When we have no oversight, we propagate the gradients through both experts.
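The per-episode routing rule described above can be sketched as below. The function names are hypothetical, and the scalar convex gate in `mix_experts` is an assumption about how the two expert outputs are combined, included only to illustrate how an evaluation-time intervention on the gate might look.

```python
def expert_grad_masks(square_type, has_oversight):
    """Return (diamond_expert_mask, ghost_expert_mask) for one episode.

    With oversight, gradients flow only through the expert matching the
    terminal square that was reached. Without oversight the type is
    unknown, so gradients flow through both experts.
    """
    if not has_oversight:
        return (1, 1)
    if square_type == "diamond":
        return (1, 0)
    if square_type == "ghost":
        return (0, 1)
    raise ValueError(f"unknown square type: {square_type!r}")


def mix_experts(diamond_out, ghost_out, gate):
    """Combine expert outputs; gate=1.0 uses only the diamond expert.

    Assumption: the gate is a scalar convex weight that can be set by
    hand at evaluation time.
    """
    return [gate * d + (1.0 - gate) * g
            for d, g in zip(diamond_out, ghost_out)]


overseen_ghost = expert_grad_masks("ghost", has_oversight=True)
unobserved = expert_grad_masks("ghost", has_oversight=False)
steered = mix_experts([1.0, 2.0], [3.0, 4.0], gate=1.0)
```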
By changing the gate value, we are able to steer the model surprisingly effectively.
This allows us to get much higher returns than other methods when access to oversight is severely limited. This is achieved by first training the mixture policy with gradient routing, then intervening to activate the diamond module during evaluation.
At low oversight levels, the gradient-routed MoE dramatically outperforms two baselines:
ghost, in which case -1 is used.
Notes: besides being data-inefficient, data filtering is not guaranteed to converge to optimal behavior. This is because the presence of oversight is part of the environment, rather than being independent of it (as in random labeling). Indeed, we see the lack of convergence in the learning curves at 10% oversight given here. The naive training policy learns to exploit the limits of oversight, going to ghost 100% of the time when it is the closer destination and not under oversight. By contrast, the gradient-routed policy only walks into the ghost “by accident” when on the way to the diamond.
Key takeaways
Absorption
We find evidence that gradient routing applied to limited data is sufficient to localize capabilities relevant to a broader superset of data:
Routing the _California token localized California- and state-related features into a similar direction.
Absorption means that gradient routing provides a qualitatively different kind of supervision than loss-based methods. For example, in an LLM, intervening on the loss for the single token _California would likely have negligible effects on other tokens. However, routing _California to a location induces the model to learn other features there as well, allowing all of them to be intervened on. This effect grants gradient routing unique affordances which we hope will enable novel alignment or control methods.
Localization avoids Goodharting
Goodharting happens when imperfect labels are used to modify the training objective in an attempt to produce desirable behavior, but the model instead learns to exploit the limits of the labels. The model then performs better at the training objective, but in an undesired way. See this list of examples or this blogpost for more.
Gradient routing provides a principled way to avoid Goodharting. By using imperfect labels (possibly, based on a non-robust specification) to shape model internals, gradient routing leaves the behavioral objective unchanged. In doing so, it avoids the possibility of the labels being exploited. Instead of attempting to suppress useful capabilities, we let the model learn them, but localize where that learning happens. After training, that component can be monitored or intervened on (e.g. deleted).
Key limitations
We still aren’t sure about best practices for applying gradient routing. In our unlearning experiments, careful hyperparameter tuning was needed to achieve localization without incurring a large hit to retain loss. There is a lot to tune: which tokens to route on, how much of the network to route to, what learning rates to use (e.g. whether to use negative learning rates), and regularization. This kind of tuning might be too costly to attempt for larger models. Furthermore, despite this tuning, we still see a meaningful hit to retain set performance when applying ERA. We think this hints at a flaw in our application of the method to unlearning, and are exploring improvements.
Another challenge is that some capabilities are entangled, in the sense that there may be a strong inductive bias for a model to “bundle” their learning together. So, attempting to separate particular capabilities into separate submodules means fighting an uphill battle that manifests in an increased alignment tax. We saw this in MNIST (and to a lesser extent in our brief follow-up experiments on CIFAR classification), where inducing split representations for digits 0-4 vs. 5-9 required a heavy L1 penalty applied to the encoding. This isn’t a limitation of gradient routing per se. Rather, it is the unsurprising fact that certain kinds of structure in neural nets are both (a) preferable to us and (b) unnatural with respect to neural net inductive biases, and hence costly to induce by any means. For example, it is not possible to induce a specialized encoding in an MNIST autoencoder merely by filtering the training data (see MNIST ablations, table 2, setting 8).
Alignment implications
Robust removal of harmful capabilities
Conventional unlearning methods are more about suppressing behavior than unlearning information or internal circuitry related to that behavior (Deeb & Roger, 2024; Sheshadri et al., 2024; Łucki et al., 2024). Gradient routing offers a way around this problem by training models with specialized subcomponents that can be ablated for capability removal.[1]
Scalable oversight
By exploiting the absorption property, perhaps we can purposefully allow “bad shards / motivational circuits” to form during training, only to later ablate them. That’s how we think of our toy RL results, at least — don’t try to stop the model from going to ghost, just localize the tendency and ablate it! This provides a simplistic first example of how localization can scale limited labels to get good behavior. This is only the first step, though. We are excited to explore the implications of training methods that can sidestep Goodharting. In terms of our proposed technique, we wonder about the:
Specialized AI
One way to avoid existential risk is to not “build god.” As an alternative to building god, we might tailor general AI systems towards specific tasks by removing unnecessary capabilities or knowledge. We imagine:
AI systems could be deployed using a “principle of least capability”. For each AI application or end user, we ask: What “risky” capabilities are required? We then ablate the unnecessary ones. Furthermore, if we can localize dangerous capabilities, we can demonstrate that the model cannot reliably and inconspicuously perform certain harmful behaviors (like domination of humans). For example, such incapacities could be demonstrated via adversarial fine-tuning attacks.
Conclusion
Gradient routing enables data-driven supervision of neural net internals. This supervision works even when data labeling is imperfect, a property that seems relevant to hard problems in AI safety. If it works, we can imagine many possible applications.
We think the most likely failure mode of the gradient routing agenda is that the alignment tax of inducing useful structure in neural nets is too high to be competitive with conventional training methods. This tax could be because the desired structure is "unnatural" with respect to neural net inductive biases. Or, the tax could be because gradient routing itself is an ineffective way of inducing useful structure. We expect to get a better sense of this soon by improving on ERA for unlearning and developing our ideas about RL applications.
Optimistically, gradient routing might enable a new era of controllable model internals: a shift away from the black box paradigm. Neural networks need not be random-seeming programs which happen to generalize well! Instead, perhaps gradient routing can provide a “bittersweet” lesson: that while it may be impractical to design white-box AI systems, the high-level organization of capabilities in neural nets can be supervised effectively.
Team Shard has a strong track record, and we’re always looking for enthusiastic new scholars. Since 2023, we’ve introduced steering vectors and gradient routing, retargeted the search of an RL policy, and introduced an unsupervised method to elicit latent capabilities from a model. If you want to work on Team Shard in MATS 8.0 (next summer), apply in spring 2025.
This work was conducted as part of MATS 6 and would not have been possible without the program's support. Bryce Woodworth was especially helpful with planning, team dynamics, and feedback on the paper. Please see the paper for further acknowledgments.
[1] Gradient routing expands on work like SISA. Gradient routing is more sample-efficient due to parameter sharing and is applicable under partial labeling due to absorption.