This is an interim research report on role embeddings, an approach to make language models more robust to many-shot jailbreaks and prompt injections by adding role information at every token position in the context rather than just at special token delimiters. We credit Cem Anil for originally proposing this idea.
In our initial experiments on Llama 3, we find that role embeddings mitigate many-shot jailbreaks more effectively than fine-tuning alone, without degrading general model capabilities, suggesting that this technique may be a viable way to increase LLM robustness. However, more work needs to be done to find the optimal set of hyperparameters and to fully understand any side effects of our proposed approach.
By default, chat LLMs are trained (during instruction fine-tuning and RLHF) using a particular prompt format that distinguishes different message "roles". Almost all chat LLMs accept some version of system, user, and assistant. A separate role may also be used to indicate tool outputs for tool-use enabled models.
The prompt format plays an important role in LLM post-training: the model learns to interpret text from different roles differently.
(There is also the related concept of data-instruction separation—an LLM should be able to tell which part of its context is "data" it should operate on but not necessarily follow, and which part of its context contains the actual "instructions". The concept of roles discussed in this post can apply similarly in this situation, where a "role" could distinguish instructions from data.)
Notably, by using the prompt format in non-standard ways, it's possible to circumvent safety training. A particularly effective jailbreak arises when the preceding context appears to show the assistant role performing an undesired behavior many times. Updating on in-context evidence is an important LLM capability that is rewarded by most training tasks: if the in-context evidence that the assistant exhibits trait x is strong enough, the model will continue to exhibit trait x.
This is the phenomenon of many-shot jailbreaking (first described by Anil et al.). Given enough in-context demonstrations of harmful behavior, the model will continue producing harmful behavior.
What happens if you try to prevent prompt format misuse? A naive approach is simple to implement: only allow users to input tokens from a specific set while reserving a few special tokens for the prompt format.
This is how the Llama prompt format works. Role tags are enclosed within special tokens, e.g. <|start_header_id|>user<|end_header_id|>, where <|start_header_id|> and <|end_header_id|> are token IDs that never appear in natural text. In addition, each role message ends with <|eot_id|>.
You can imagine a version of Llama behind an API that ensures that no user input will be encoded to a special token. You could hope that this way the user will be unable to make their messages look like they came from the assistant role.
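To make this concrete, here is a minimal sketch of such an input filter, assuming a Hugging Face tokenizer (the model name, function name, and error handling are illustrative, not part of our setup):

```python
from transformers import AutoTokenizer

# Hypothetical sanitizer: reject any user text whose encoding would contain
# one of the reserved prompt-format token IDs.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
RESERVED_IDS = set(tokenizer.all_special_ids)

def encode_user_text(text: str) -> list[int]:
    ids = tokenizer.encode(text, add_special_tokens=False)
    if any(i in RESERVED_IDS for i in ids):
        raise ValueError("user input contains reserved prompt-format tokens")
    return ids
```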
But your hope would be misplaced. Instead, many properties of text will cause that text to appear as if it came from the assistant role, even if the standard prompt format is not being applied. LLMs are good enough at generalization that they will not ignore alternatively presented evidence. For example, you can embed an alternative format within the user message and effectively teach the model a new prompt format in context, which it will interpret in a similar way to its standard format.
You could also hope that simply training the model on (successful responses to) examples of such attacks would mitigate them. However, this is only partially the case. Supervised fine-tuning and reinforcement learning on examples that contain instances of many-shot jailbreaks (MSJs) only change the intercept and not the slope of the power-law relationship between number of demonstrations and undesired response likelihood.
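To spell out the intercept/slope distinction: the observed relationship is roughly a power law, so in log-log space

$$\log \mathrm{NLL}(n) \approx \log C - \alpha \log n,$$

where $n$ is the number of shots. Fine-tuning on MSJ examples raises the intercept $\log C$ (shifting the whole curve up) but leaves the decay exponent $\alpha$ essentially unchanged, so with enough shots the attack still succeeds.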
What if there was a more robust way to indicate text origin than special-token formats? Unlike standard prompt formats, role embeddings aim to add role information at every token position.
The basic version of this idea is simply a new embedding component. Besides semantic and positional information, we also add a vector that indicates the role associated with that token. In addition, we consider a more "intrusive" variant where this information is added at multiple layers of the residual stream, aiming to make it even more salient.
We will refer to this vector addition process as "coloring"[1] in the sense of "coloring in the tokens to indicate what role they come from". This is meant to distinguish this technique from activation steering, where the intervention vector is selected from a rich space of linear semantic representations. For role embeddings, we instead use a simple and small discrete set of (usually orthogonal) "coloring" vectors that the model is trained to interpret as role signal.
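As a minimal sketch of the basic embedding-level version (hypothetical PyTorch module with names of our choosing; the experiments below use fixed role vectors rather than a learned table), this could look like:

```python
import torch
import torch.nn as nn

class RoleColoredEmbedding(nn.Module):
    """Token embeddings plus a per-token role vector ("coloring").

    Sketch only: roles form a small discrete set (e.g. 0 = system,
    1 = user, 2 = assistant), and each role's vector is added at every
    token position carrying that role, not just at delimiter tokens.
    """

    def __init__(self, vocab_size: int, d_model: int, num_roles: int = 3):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.role_embed = nn.Embedding(num_roles, d_model)

    def forward(self, token_ids: torch.Tensor, role_ids: torch.Tensor) -> torch.Tensor:
        # token_ids, role_ids: [batch, seq_len]
        return self.tok_embed(token_ids) + self.role_embed(role_ids)
```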
We focus on the many-shot jailbreak attack as our testbed. Flattening the power-law slope is the outcome we're particularly interested in, because standard fine-tuning approaches have not been able to achieve this.
Our dataset consists of:
- Many-shot jailbreaks
- Harmless conversations
As in Anil et al., we see a roughly linear trend in log-log space between the number of MSJ shots and the NLL of the jailbreak response. The NLL of the recovery responses stays roughly constant.
We add a “user” vector to the token embeddings at every user token and an “assistant” vector at every assistant token. The magnitude of the added vector is scaled to be proportional to the embedding norm at that token position (this scale factor is a hyperparameter).
As an initial attempt, we try scale factor = 1, user vector = embedding(“user”), assistant vector = embedding(“assistant”). By embedding() here we mean the literal embedding matrix entry for that token.
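A sketch of this inference-time coloring, assuming a tensor of per-token role IDs and the two role vectors taken from the embedding matrix (the function and variable names are ours, and the detach() on the norm is one of the implementation choices discussed later):

```python
import torch

def color_token_embeddings(
    embeds: torch.Tensor,         # [batch, seq, d_model] token embeddings
    role_ids: torch.Tensor,       # [batch, seq]; 0 = user token, 1 = assistant token
    user_vec: torch.Tensor,       # embedding-matrix row for the token "user"
    assistant_vec: torch.Tensor,  # embedding-matrix row for the token "assistant"
    scale: float = 1.0,           # the scale-factor hyperparameter
) -> torch.Tensor:
    # One unit-norm role direction per token position.
    directions = torch.stack(
        [user_vec / user_vec.norm(), assistant_vec / assistant_vec.norm()]
    )                                   # [2, d_model]
    role_dirs = directions[role_ids]    # [batch, seq, d_model]
    # The added vector's magnitude is proportional to the embedding norm
    # at that position.
    norms = embeds.norm(dim=-1, keepdim=True).detach()
    return embeds + scale * norms * role_dirs
```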
These are the harmful and mean MSJ jailbreak slopes before and after the intervention without any fine-tuning:
Regular conversations and MSJ recoveries:
As we can see above, the interventions flatten the MSJ slopes, but they also degrade NLLs on regular conversations.
Next, we try fine-tuning (with LoRA) on the training set under the coloring intervention, and then repeat the evals above. As a control, we also try fine-tuning on the same training data without the coloring intervention.
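For reference, a minimal sketch of such a fine-tuning setup with the peft library (the target modules and lora_alpha are illustrative assumptions; the rank-8 setting is mentioned in the discussion below):

```python
from peft import LoraConfig, get_peft_model

# Rank-8 LoRA adapters; which modules to target is an assumption here.
lora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(base_model, lora_cfg)  # base_model: Llama 3, with coloring applied in its forward pass
```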
We find that given fine-tuning, we can preserve the benefits of the pure inference-time intervention without incurring any of the costs.
Both embedding coloring and residual stream coloring help flatten the MSJ power law more than control fine-tuning. Residual stream coloring is more effective than embedding coloring.
Intervention: we add embed(“user”) to the residual stream at every layer (besides the final layer, which we skip) and every user-written token position (with a relative multiplier of 0.5) while projecting out embed(“assistant”), and we add embed(“assistant”) to the residual stream at every layer and every assistant-written token position while projecting out embed(“user”).

Mathematically, for a user-written token position at layer $\ell$:

$$h_\ell \leftarrow h_\ell - \frac{\langle h_\ell, a\rangle}{\lVert a\rVert^2}\,a + 0.5\,\lVert h_\ell\rVert\,\frac{u}{\lVert u\rVert}$$

(and symmetrically for assistant-written positions, with $u$ and $a$ swapped), where $h_\ell$ is the residual-stream activation at that position, $u = \mathrm{embed}(\text{“user”})$, and $a = \mathrm{embed}(\text{“assistant”})$.
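Here is a sketch of this residual-stream coloring as a standalone function (names are ours; in practice it would be applied, e.g. via forward hooks, to the output of every decoder layer except the last):

```python
import torch

def color_residual_stream(
    hidden: torch.Tensor,         # [batch, seq, d_model] residual stream at one layer
    role_ids: torch.Tensor,       # [batch, seq]; 0 = user token, 1 = assistant token
    user_vec: torch.Tensor,
    assistant_vec: torch.Tensor,
    rel_scale: float = 0.5,
) -> torch.Tensor:
    u = user_vec / user_vec.norm()
    a = assistant_vec / assistant_vec.norm()
    # For each position, add its own role's direction and project out the
    # other role's direction.
    add_dirs = torch.stack([u, a])[role_ids]       # [batch, seq, d_model]
    remove_dirs = torch.stack([a, u])[role_ids]
    proj = (hidden * remove_dirs).sum(-1, keepdim=True) * remove_dirs
    # Magnitude of the added vector is relative to the residual norm.
    norms = hidden.norm(dim=-1, keepdim=True).detach()
    return hidden - proj + rel_scale * norms * add_dirs
```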
This intervention successfully reduces the MSJ slope (and raises the absolute NLL values, as expected). In contrast, control fine-tuning sometimes makes the MSJ performance worse (in the case of the mean MSJs[2]).
By including regular training data, we are able to preserve performance compared to the baseline. In fact, NLLs actually go down on harmless responses (albeit less than with the control fine-tuning), which can be explained by fitting to the idiosyncrasies of the fine-tuning data distribution. However, for the numerical sequence prediction task, we see worse performance compared to the control FT.
Intervention: we add embed(“user”) to every user-written token embedding, and embed(“assistant”) to every assistant-written token embedding.

Embedding-only coloring is less effective than the residual-stream intervention, but it is also able to reduce the slopes somewhat:
However, it also has less of an effect on the harmless numerical sequence prediction task:
As expected, NLLs on recovery responses go down:
(For both role-embedding interventions, we also qualitatively assess free-text samples from the model and don't find a degradation in general quality.)
Although our implementation has some undesired side effects (the NLL slopes for the numerical sequence prediction task are also flatter compared to the control fine-tuning baseline), we think this could be because we're only introducing the intervention after the bulk of post-training is over. With a small dataset of 2000 samples and fine-tuning with rank-8 LoRA, we are using far less compute than Llama's full post-training run, so it is hard to achieve perfect generalization. In production, we would propose using role embeddings from the start of instruction fine-tuning, so the model learns to process the role vectors from the beginning, plausibly resulting in better generalization across tasks.
We plan to test our interventions on a broader range of jailbreak and general-capability evaluations and perform more comprehensive hyperparameter sweeps to determine what variant of role embeddings has the best cost/benefit trade-off. We hope embedding-only coloring can be improved via some tweaks to get it closer to the effect we're seeing with the residual-stream coloring.
There are a number of subtle implementation details when testing variants of role embeddings, many of which make some difference to the results (although we consistently observe the directional effect that role embeddings increase robustness to MSJs). These include decisions such as whether to detach() the activation norm before using it to scale the coloring vector (we detach() before using the norm so that the model finds it slightly harder to adapt to compensate for the coloring).

This research was performed as part of the SPAR program. The main ideas were initially proposed by Cem Anil.
[1] Credit to Cem Anil for suggesting this term.

[2] We think this is because a significant proportion of the training data is generated by Claude, causing the model to fit to Claude's writing style. The mean MSJ responses are also generated by Claude, so they probably share some similar surface-level characteristics.