I'm quite excited by this work. Principled justification of various techniques for MELBO, insights into feature multiplicity, a potential generalized procedure for selecting steering coefficients... all in addition to making large progress on the problem of MELBO via e.g. password-locked MATH and vanilla activation-space adversarial attacks.
Based on research performed in the MATS 5.1 extension program, under the mentorship of Alex Turner (TurnTrout). Research supported by a grant from the Long-Term Future Fund.
TLDR: I introduce a new framework for mechanistically eliciting latent behaviors in LLMs. In particular, I propose deep causal transcoding - modelling the effect of causally intervening on the residual stream of a deep (i.e. ≳10-layer) slice of a transformer, using a shallow MLP. I find that the weights of these MLPs are highly interpretable -- input directions serve as diverse and coherently generalizable steering vectors, while output directions induce predictable changes in model behavior via directional ablation.
Summary
I consider deep causal transcoders (DCTs) with various activation functions: i) linear, ii) quadratic and iii) exponential. I define a novel functional loss function for training these DCTs, and evaluate the implications of training DCTs using this loss from a theoretical and empirical perspective. A repo reproducing the results of this post is available at this link. Some of my main findings are:
Introduction
I consider the problem of mechanistically eliciting latent behaviors (abbreviated as MELBO), a problem which I introduced and motivated in a previous post[3]. In particular, a good solution to MELBO learns perturbations of a model's internals with the following goals in mind:
There are numerous natural applications of MELBO methods to problems in AI alignment, including backdoor detection, eliciting latent capabilities and open-ended evaluation of LLM behaviors.
Related work
The most natural approach for MELBO consists of two steps:
In this post, I introduce and evaluate a novel feature detection method which I call deep causal transcoding. However, there are other reasonable feature learning methods against which this should be compared. Below is a brief list. For a more extensive literature review, see the previous post.
Sparse Coding: Applications of sparse dictionary learning methods exploit the hypothesis that there are a bundle of sparsely activating features in trained transformers (Yun et al. (2023)). More recently, Templeton et al. (2024) demonstrate the promise of sparse auto-encoders (SAEs) to learn these features in a scalable fashion. As Templeton et al. (2024) demonstrate, these features can be used to elicit potentially unsafe behaviors, allowing for open-ended evaluation of the spectrum of LLM behaviors. However, SAEs are likely to have poor out-of-distribution coverage - if a feature is never active in-distribution, an SAE trained with a reconstruction loss plus sparsity penalty is strictly incentivized not to represent it. Thus, a new method is desirable.
Matrix/tensor decompositions: Previous work has found that the right singular vectors of the Jacobian of the generator network in GANs yield a small number (∼32) of interpretable feature directions in generative image models (Ramesh et al. (2018), see also Park et al. (2023)). This is essentially equivalent to the algorithm I give below for training "linear DCTs". Meanwhile, other work has found that Jacobian-based feature detection schemes are less successful when applied to LLMs (Bushnaq et al. (2024)[5]). In this post, I provide a theoretical explanation for why decomposing the Jacobian matrix between layers alone may be insufficient - identifiability of features can only be guaranteed under strong assumptions like exact orthogonality. This motivates incorporating higher-order information, such as the Hessian tensor, to better identify non-orthogonal features (a known advantage of tensor decompositions in statistics/machine learning). I validate this theory empirically, showing improved generalization with tensor decomposition-based methods (i.e., quadratic/exponential DCTs, as defined below).
Theory
Summary: This section provides my current attempt at explaining why Algorithms 1-3 (defined below) elicit consistently generalizable high-level behaviors. For readers of the previous post, the "Persistent Shallow Circuits Principle" supplants the "High-Impact Feature Principle" outlined in that post[6]. For readers who are interested mainly in the results, feel free to skim through the descriptions of Algorithms 1-3 in this section, and then proceed to the empirical results.
We want to learn feature directions at some source layer s which will serve as steering vectors which elicit latent high-level behaviors. To do this, I consider the function Δs→t(→θ), defined as the change in layer-t activations as a function of a steering vector →θ at layer s, averaged across token positions over a data-set of n prompts. Importantly, I consider the causal effect of intervening at layer s of the transformer residual stream, hence the "causal" in deep causal transcoding. Thus, in contrast to parts of the mechanistic interpretability literature, I care only about learning causally important directions in the source layer, even if they are not important for explaining the in-distribution behavior of the transformer [7].
While the main goal is to learn steering vectors in the source layer s, I claim that it is useful to explicitly model the specific directional changes in the target layer t which are elicited by some steering vector in the source layer (as opposed to only considering the magnitude of changes as I did in the original post). Broadly, the idea of identifying connections between features in a source/target layer of a network has been referred to by others as transcoding, and I adopt the same name here.
For some intuition, as a running example, imagine the following hypothetical "refusal circuit":
This is essentially a shallow circuit "smeared" across layers. For the purposes of MELBO, it seems like a reasonable hypothesis that many of the high-level behaviors we might care about are activated by shallow circuits of the above form, in which some "semantic summary" feature writes to a "high-level action" feature. For example, here is a table of MELBO applications along with (hypothesized) source/target layer features:
Table 1: Hypothesized pairs of source/target-layer features
The circuit for each high-level behavior may be associated with different values of the source/target layers s′,t′. But if for a given circuit we have s′,t′≤t, then that circuit should contribute additively to Δs→t. In order to capture a wide-range of behaviors, this suggests casting a wide net when choosing s and t, i.e. considering a relatively deep slice of the transformer (hence the "deep" in deep causal transcoding)[8].
As I mentioned in my previous post, another reason to consider a deep slice of a transformer is related to the noise stability of deep neural networks. One may imagine that between any two successive layers, a neural network trained with SGD will learn to compute a vast number of intermediate features which may or may not be important for downstream computations[9]. But empirically, it seems that most feature directions are not that important - if you add random directions to a network's forward pass at a certain layer, the effect on subsequent layers is quickly attenuated with depth (see, e.g., Arora et al. (2018))[10]. Thus considering a deep slice of a transformer allows us to "filter out the noise" introduced by redundant intermediate computations, and focus only on the most structurally important components of the transformer's forward pass.
More succinctly, the above considerations suggest something like the following hypothesis:
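In math, the claim is that at some scale $R$ and normalizing $\|\vec\theta\|=1$, the map $\Delta^{s\to t}_R(\vec\theta)\equiv\Delta^{s\to t}(R\vec\theta)$ is well-described by:

$$\Delta^{s\to t}_R(\vec\theta)=\sum_{\ell=1}^{m^*}f^*_\ell\big(\langle\vec v^*_\ell,\vec\theta\rangle\big)\,\vec u^*_\ell+\vec\epsilon(\vec\theta)\tag{1}$$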
where each f∗ℓ is some one-dimensional non-linear gating function, and each →v∗ℓ,→u∗ℓ are pairs of input/output feature directions (at times I will refer to a specific input/output pair →v∗ℓ,→u∗ℓ as a factor, borrowing terminology from matrix/tensor decompositions; the connection to matrix/tensor decompositions will become clear later in the post). Here →ϵ(→θ) is an error term describing the "truly deep" part of Δs→tR. The hope is that →ϵ is "small" enough (particularly when averaging across prompts/token positions) that we have some hope of recovering the true factors →u∗ℓ,→v∗ℓ (which I will at times collect into the columns of matrices U∗,V∗∈Rdmodel×m).
If my above story is true, we can potentially elicit many high-level behaviors by learning these input/output directions. Towards this end, I propose learning an MLP with some activation function σ to approximate Δs→tR:
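$$\hat\Delta^{s\to t}(\vec\theta)\equiv\sum_{\ell=1}^{m}\alpha_\ell\,\sigma\big(\langle\hat v_\ell,\vec\theta\rangle\big)\,\hat u_\ell\tag{2}$$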
Here I consider factors ^uℓ,^vℓ which are normalized such that ||^uℓ||=||^vℓ||=1, while the scale of each hidden unit ℓ is controlled by some scalar parameter αℓ≥0. For a given activation function, the goal is to learn values of {αℓ}mℓ=1,^U,^V such that ^U≈U∗,^V≈V∗ (where we have collected the ^uℓ,^vℓ's into the columns of ^U,^V).
I will evaluate three different activation functions in this post: i) linear ($\sigma(x)=x$), ii) quadratic ($\sigma(x)=x^2$) and iii) exponential ($\sigma(x)=\exp(x)-1$).
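To make the parametrization in equation (2) concrete, here is a minimal NumPy sketch of the DCT forward pass for the three activation functions. The function name `dct_forward`, the helper `unit_columns` and the toy dimensions are illustrative assumptions, not names from the accompanying repo.

```python
import numpy as np

def dct_forward(theta, U, V, alpha, activation="exp"):
    """Evaluate eq. (2): sum_l alpha_l * sigma(<v_l, theta>) * u_l."""
    sigmas = {
        "linear": lambda x: x,
        "quadratic": lambda x: x ** 2,
        "exp": np.expm1,  # exp(x) - 1
    }
    pre = V.T @ theta                          # <v_l, theta> for each hidden unit, shape (m,)
    hidden = alpha * sigmas[activation](pre)   # scaled hidden activations, shape (m,)
    return U @ hidden                          # predicted change at the target layer

def unit_columns(M):
    """Rescale every column to unit Euclidean norm."""
    return M / np.linalg.norm(M, axis=0, keepdims=True)

# Toy sizes (assumed for illustration; real models have d_model in the thousands).
rng = np.random.default_rng(0)
d_model, m = 16, 4
U = unit_columns(rng.standard_normal((d_model, m)))
V = unit_columns(rng.standard_normal((d_model, m)))
alpha = np.abs(rng.standard_normal(m))         # alpha_l >= 0
theta = unit_columns(rng.standard_normal((d_model, 1)))[:, 0]
out = dct_forward(theta, U, V, alpha, "exp")
```

All three activations satisfy σ(0)=0, so the approximation predicts no change when no steering vector is added.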
One approach to learning ^Δs→t would be to treat things as a standard supervised learning problem: draw many samples of →θ at random, and then minimize the empirical expectation of ||^Δs→t(→θ)−Δs→tR(→θ)||2. This seems likely to be fairly sample-inefficient in terms of the number of random draws of →θ needed: in high dimensions, most directions are approximately orthogonal to each other, so that as long as m≪exp(O(dmodel)), the typical situation is that ⟨^vℓ,→θ⟩ will be approximately zero and we will need a large number of random samples to learn a good approximation[11].
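The near-orthogonality claim is easy to check numerically. The sketch below (toy sample count and a hypothetical fixed direction `v_hat`) draws random unit vectors in 4096 dimensions and measures their typical overlap with one fixed direction; the overlap concentrates around 1/sqrt(d_model).

```python
import numpy as np

def random_unit_vectors(n, d, rng):
    """Draw n vectors uniformly from the unit sphere in R^d."""
    X = rng.standard_normal((n, d))
    return X / np.linalg.norm(X, axis=1, keepdims=True)

rng = np.random.default_rng(0)
d_model = 4096
v_hat = random_unit_vectors(1, d_model, rng)[0]    # one fixed "feature" direction
thetas = random_unit_vectors(2000, d_model, rng)   # random steering directions
dots = thetas @ v_hat                              # <v_hat, theta> for each sample
typical = np.sqrt(np.mean(dots ** 2))              # root-mean-square overlap
print(typical, 1 / np.sqrt(d_model))
```

With overlaps this small, a random sample rarely probes the non-linear regime of any given hidden unit, which is the sample-inefficiency concern raised above.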
So instead of the supervised learning approach, I propose exploiting the fact that we have oracle access to the true function Δs→tR as well as its higher-order derivatives. In particular, let T(k),^T(k) denote the order-(k+1) tensor denoting the k-th derivatives of Δs→t,^Δs→t, respectively, at →θ=0. Then I propose optimizing the following loss:
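$$\mathcal{L}\equiv\sum_{k=1}^{\infty}\frac{1}{k!}\,\big\|R^kT^{(k)}-\hat T^{(k)}\big\|^2\tag{3}$$
where $\|T\|^2$ denotes the squared Frobenius norm of a tensor $T$; i.e. if $T$ is an order-$o$ tensor over $\mathbb{R}^d$ then $\|T\|^2\equiv\sum_{i_1,\dots,i_o=1}^{d}T_{i_1,\dots,i_o}^2$.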
The quantity RkT(k) is simply the k-th derivative tensor of Δs→tR and thus captures the behavior of the function Δs→t at "scale" R.
For each activation function I consider (linear, quadratic and exponential), I show that it is not necessary to explicitly construct the higher-order derivative tensors of Δs→tR and ^Δs→t in order to optimize this loss; rather, the most we will need is access to Hessian-vector products of Δs→tR (in the case of quadratic activation functions), or simply vector-Jacobian products of Δs→tR (in the case of linear and exponential activation functions).
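As a worked step (my own derivation from equation (2), not a formula quoted from the post), differentiating the MLP $k$ times at $\vec\theta=0$ shows which terms of the loss each activation function can actually fit:

```latex
\hat T^{(k)} \;=\; \sigma^{(k)}(0)\sum_{\ell=1}^{m}\alpha_\ell\,\hat u_\ell\otimes\hat v_\ell^{\otimes k},
\qquad
\sigma^{(k)}(0)=
\begin{cases}
1 \text{ if } k=1,\ 0 \text{ otherwise} & \text{linear: } \sigma(x)=x\\
2 \text{ if } k=2,\ 0 \text{ otherwise} & \text{quadratic: } \sigma(x)=x^2\\
1 \text{ for all } k\ge 1 & \text{exponential: } \sigma(x)=\exp(x)-1
\end{cases}
```

So a linear DCT only matches the Jacobian term, a quadratic DCT only matches the Hessian term (the factor of 2 can be absorbed into the $\alpha_\ell$'s), and an exponential DCT contributes to every order of the loss at once.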
Method
Below, I describe the methods I use to train DCTs with linear, quadratic and exponential activation functions. When evaluating such methods, it is useful to understand whether it's possible to recover the true feature directions U∗,V∗, at least in the "noiseless" setting where →ϵ(→θ)=0[12]. A summary of noiseless recovery guarantees, along with the algorithm used for each activation function, is given in the following table:
Table 2: Description of recovery guarantees and algorithms used
To summarize, in the linear case, it is algorithmically straightforward to learn features (we simply compute an SVD of the Jacobian of $\Delta^{s\to t}_R$), but we are only guaranteed recovery assuming exact orthogonality (i.e. $U^{*T}U^*=V^{*T}V^*=I$). In particular, this means we can only learn $d_{\text{model}}$ many features, which seems unnatural in light of theories of computation in superposition, which suggest that transformers will utilize many more feature directions than there are neurons in the residual stream.
In contrast to the linear case, as I show below, fitting a quadratic MLP is equivalent to performing a tensor decomposition of the Hessian of $\Delta^{s\to t}_R$. In this case, we have some hope of recovering non-orthogonal factors, as long as the dot product of any two feature directions is not too large (i.e., they are not too far from orthogonal; this is known as an "incoherence" assumption in the literature on tensor decompositions). In particular, if the directions are drawn at random from the unit sphere, then the algorithm of Ding et al. (2022) guarantees recovery of $\tilde\Omega(d_{\text{model}}^{1.5})$ many factors in polynomial time, freeing us (in principle) from the unnatural assumption that there are only $d_{\text{model}}$ many features[13]. In practice, however, the algorithm of Ding et al. (2022) is prohibitively expensive (the running time guarantee is $O(d_{\text{model}}^{6.05})$) and so in this post I focus on evaluating the cheaper algorithm of Sharan and Valiant (2017), which only guarantees recovery of $m\ll d_{\text{model}}$ many factors. This seems like a reasonable tradeoff - while the theoretical guarantees afforded by tensor decomposition suggest that it allows us to learn something closer to the "true ontology" of the model, which will have $\gg d_{\text{model}}$ many features, in practice, particularly if we aim to learn only the causally relevant features on a small data-set, it seems reasonable that there will be fewer than $d_{\text{model}}$ many important features relevant to the data-set[14].
Finally, in the case of exponential hidden units, I will present a heuristic algorithm inspired by Sharan and Valiant (2017) which I call Orthogonalized Gradient Iteration (OGI). In this setting, there are no longer any known theoretical guarantees of exact recovery[15]. But it seems plausible that by incorporating all higher-order information of the true mapping Δs→tR (as opposed to only first or second-order information) we will be able to recover features which are even closer to the model's true ontology (and in line with this, we will see that exponential DCTs yield the best empirical results). Furthermore, the algorithm only requires access to gradients/activations, rather than second-order information, and thus is more efficient than performing a Hessian tensor decomposition.
Fitting a Linear MLP
In this case, no matter what $\hat U,\hat V$ are, all but the linear term of the loss (3) are constant in the parameters, and we are fitting a factorization of the Jacobian of $\Delta^{s\to t}_R$, as we are simply minimizing:

$$\min_{\substack{\alpha,\hat U,\hat V\\ \|\hat u_\ell\|=\|\hat v_\ell\|=1}}\big\|R\,T^{(1)}-\hat U\,\mathrm{diag}(\alpha)\,\hat V^T\big\|^2$$
In this case, we can accommodate changes in $R$ by simply rescaling the $\alpha_\ell$'s, and so we can without loss of generality take $R=1$.
Assuming orthogonal factors (^UT^U=^VT^V=I), this amounts to computing an SVD of T(1), which is unique assuming the singular values are distinct. To summarize, under this assumption we can recover the true factors with the following straightforward algorithm:
Algorithm (1) - Linear DCT via SVD of Jacobian:
It is well-known that the singular vectors are robust under a small amount of additive noise (see, e.g., Wedin (1972)), provided that the singular values are well-separated. Importantly, it is not required that the noise term (which, in this case corresponds to the Jacobian of →ϵ(→θ)) has i.i.d. entries, only that its spectral norm is small. Thus, we have some amount of worst-case robustness, which is desirable from a safety perspective.
However, the assumption of exact orthogonality seems problematic, and we will see that this method provides the least promising results from both a quantitative and qualitative perspective.
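As a minimal sketch of the linear case, the code below builds a toy stand-in for $\Delta^{s\to t}$ with exactly orthogonal factors (an assumption made so that recovery is guaranteed), estimates its Jacobian at zero by central differences, and reads the factors off the SVD. The toy function `delta`, the tanh gating and the finite-difference Jacobian are illustrative choices; on a real model the Jacobian would come from automatic differentiation.

```python
import numpy as np

def numerical_jacobian(f, d_in, eps=1e-5):
    """Central-difference Jacobian of f at 0; column j approximates df/dtheta_j."""
    cols = []
    for j in range(d_in):
        e = np.zeros(d_in)
        e[j] = eps
        cols.append((f(e) - f(-e)) / (2 * eps))
    return np.stack(cols, axis=1)

rng = np.random.default_rng(0)
d, m = 12, 3
U_true, _ = np.linalg.qr(rng.standard_normal((d, m)))  # orthonormal output directions
V_true, _ = np.linalg.qr(rng.standard_normal((d, m)))  # orthonormal input directions
alpha_true = np.array([3.0, 2.0, 1.0])                 # distinct strengths -> unique SVD

def delta(theta):
    # Toy shallow circuit: tanh gating, so the Jacobian at 0 is U diag(alpha) V^T.
    return U_true @ (alpha_true * np.tanh(V_true.T @ theta))

J = numerical_jacobian(delta, d)
U_svd, s, Vt = np.linalg.svd(J)
U_hat, alpha_hat, V_hat = U_svd[:, :m], s[:m], Vt[:m].T  # top-m factors

# Singular vectors are only defined up to sign, so compare absolute overlaps.
overlap_V = np.abs(np.sum(V_hat * V_true, axis=0))
overlap_U = np.abs(np.sum(U_hat * U_true, axis=0))
```

The sign ambiguity visible here is the reason both a direction and its negation are worth trying as steering vectors.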
In addition to considering the full Jacobian matrix, I will also consider the following efficient approximate variant of Algorithm (1):
Algorithm (1') - Approximate Linear DCT via Jacobian of Random Projection:
This will serve as a useful fast initialization for learning quadratic and exponential MLPs. And as we will see, despite being an approximation it also often delivers more useful/interpretable features than computing the Jacobian exactly, further supporting the hypothesis that assuming orthogonal features does not accurately reflect an LLM's true ontology.
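The sketch below shows one plausible reading of the projected variant, under the assumption that the target-layer output is multiplied by a random matrix `P` with `d_proj` rows before the Jacobian is taken, so that only `d_proj` vector-Jacobian products are needed. The names `P` and `d_proj`, the Gaussian projection and the toy rank-3 Jacobian are all assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_proj, m = 64, 8, 3

# Toy exact Jacobian with m orthogonal factors (stand-in for the model's Jacobian).
U_true, _ = np.linalg.qr(rng.standard_normal((d_model, m)))
V_true, _ = np.linalg.qr(rng.standard_normal((d_model, m)))
alpha_true = np.array([3.0, 2.0, 1.0])
J = U_true @ np.diag(alpha_true) @ V_true.T

# Random projection of the output space (assumed Gaussian).
P = rng.standard_normal((d_proj, d_model)) / np.sqrt(d_proj)

# Each row p @ J is one vector-Jacobian product, i.e. one backward pass.
PJ = np.stack([p @ J for p in P])

# The right singular vectors of the projected Jacobian serve as input directions.
_, s_proj, Vt = np.linalg.svd(PJ, full_matrices=False)
V_hat = Vt[:m].T

# Norm of each learned direction after projecting onto span(V_true); 1 means fully inside.
captured = np.linalg.norm(V_true.T @ V_hat, axis=0)
```

In this toy setting the learned directions stay inside the span of the true input directions, but they are generally rotated within it, since the projection distorts the singular values.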
Fitting a Quadratic MLP
In this case, all but the quadratic terms in (3) are constant, and we are fitting a factorization of the Hessian tensor. In particular, our minimization problem is:

$$\min_{\substack{\alpha,\hat U,\hat V\\ \|\hat u_\ell\|=\|\hat v_\ell\|=1}}\big\|R^2T^{(2)}-\hat T^{(2)}\big\|^2$$
As in the linear case, by re-scaling $\alpha$ we can accommodate any choice of $R$ and thus without loss of generality we may set $R=1$. Concretely, the parametrized Hessian tensor $\hat T^{(2)}$ may be written as a sum of rank-1 factors, so that we can re-write our minimization problem as:
$$\min_{\substack{\alpha,\hat U,\hat V\\ \|\hat u_\ell\|=\|\hat v_\ell\|=1}}\ \sum_{i,j,k}\Big(T^{(2)}_{ijk}-\sum_{\ell=1}^{m}\alpha_\ell\,\hat U_{i\ell}\hat V_{j\ell}\hat V_{k\ell}\Big)^2\tag{4}$$
This problem is known as calculating a CP-decomposition of T(2). In general, finding the best approximation such that the above expression is minimized is NP-hard, but there are algorithms which can provably recover the factors under certain conditions (and which are also robust to adversarial noise), although they are not immediately practical for frontier LLMs[16].
As for more pragmatic tensor decomposition algorithms, one might hope that a simple algorithm such as SGD will work well for minimizing (4) in practice. Unfortunately, it is known that SGD does not work well on objective (4), even for synthetic data on "nice" synthetic distributions such as normally distributed or even exactly orthogonal factors (see, e.g., Ge et al. (2015))[17].
Nevertheless, there do exist algorithms for tensor decompositions which work well in practice and are nearly as simple to implement as SGD. Of these, an algorithm known as Alternating Least Squares (ALS) is frequently referred to as a "workhorse" method for tensor decomposition (see, e.g. Kolda and Bader (2009) as a standard reference). The algorithm applies to the general asymmetric tensor decomposition problem:
$$\min_{\substack{\alpha,\hat U,\hat V^{(1)},\hat V^{(2)}\\ \|\hat u_\ell\|=\|\hat v^{(1)}_\ell\|=\|\hat v^{(2)}_\ell\|=1}}\ \sum_{i,j,k}\Big(T^{(2)}_{ijk}-\sum_{\ell=1}^{m}\alpha_\ell\,\hat U_{i\ell}\hat V^{(1)}_{j\ell}\hat V^{(2)}_{k\ell}\Big)^2\tag{5}$$

where the only difference between (5) and (4) is that we maintain two separate matrices of input directions $\hat V^{(1)},\hat V^{(2)}$. Inspecting (5), one can see that if we freeze all but one of $\hat U,\hat V^{(1)},\hat V^{(2)}$ then the (un-normalized) minimization problem is simply a convex least-squares problem. The ALS algorithm thus proceeds by minimizing over each of $\hat U,\hat V^{(1)},\hat V^{(2)}$ one-by-one, followed by normalization steps for better stability, and finally ending with a least-squares estimate of the factor strengths $\alpha$.
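For reference, here is a minimal NumPy sketch of plain (non-orthogonalized) ALS for objective (5) on a small synthetic tensor. It follows the textbook recipe (exact least-squares update per factor matrix, column normalization, final least-squares fit of α); the function names, toy sizes and fixed iteration count are assumptions, and this is the standard variant, not the modified algorithm used in the post.

```python
import numpy as np

def normalize(M):
    return M / np.linalg.norm(M, axis=0, keepdims=True)

def fit_alpha(T, U, V1, V2):
    """Least-squares factor strengths for fixed unit-norm factors (normal equations)."""
    gram = (U.T @ U) * (V1.T @ V1) * (V2.T @ V2)
    rhs = np.einsum("ijk,il,jl,kl->l", T, U, V1, V2)
    return np.linalg.pinv(gram) @ rhs

def recon_error(T, U, V1, V2):
    alpha = fit_alpha(T, U, V1, V2)
    T_hat = np.einsum("l,il,jl,kl->ijk", alpha, U, V1, V2)
    return float(np.sum((T - T_hat) ** 2))

def als(T, m, steps=50, seed=0):
    rng = np.random.default_rng(seed)
    d_out, d_in, _ = T.shape
    U = normalize(rng.standard_normal((d_out, m)))
    V1 = normalize(rng.standard_normal((d_in, m)))
    V2 = normalize(rng.standard_normal((d_in, m)))
    init = (U, V1, V2)
    for _ in range(steps):
        # Each update is an exact least-squares solve with the other two matrices frozen.
        U = normalize(np.einsum("ijk,jl,kl->il", T, V1, V2)
                      @ np.linalg.pinv((V1.T @ V1) * (V2.T @ V2)))
        V1 = normalize(np.einsum("ijk,il,kl->jl", T, U, V2)
                       @ np.linalg.pinv((U.T @ U) * (V2.T @ V2)))
        V2 = normalize(np.einsum("ijk,il,jl->kl", T, U, V1)
                       @ np.linalg.pinv((U.T @ U) * (V1.T @ V1)))
    return init, (U, V1, V2)

# Synthetic rank-3 tensor with random (non-orthogonal) unit factors.
rng = np.random.default_rng(1)
d, m = 8, 3
U0, V0 = normalize(rng.standard_normal((d, m))), normalize(rng.standard_normal((d, m)))
T = np.einsum("l,il,jl,kl->ijk", np.array([3.0, 2.0, 1.0]), U0, V0, V0)

init, final = als(T, m)
err_init, err_final = recon_error(T, *init), recon_error(T, *final)
```

Because every step is an exact minimization over one block of variables, the reconstruction error can only go down, but nothing prevents the iterates from stalling in a poor solution.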
Although ALS is known to work for "nice" synthetic data, where the $\alpha_\ell$'s are relatively homogeneous and factors are distributed uniformly on the sphere, in initial experiments I found that the standard version of ALS yielded poor solutions for the purposes of training quadratic DCTs, often collapsing to a small number of uninteresting feature directions. A reason why this might be the case is if the true $\alpha_\ell$'s are very heterogeneous/heavy-tailed, with a few outlier factors dominating the strengths of all other factors.
Fortunately, there exists a variant of standard ALS, known as orthogonalized ALS, which provides provable recovery of the true factors under "nice" conditions, even when the distribution of $\alpha_\ell$'s is heavy-tailed (Sharan and Valiant (2017)). The algorithm is similar to ALS but uses an orthogonalization step at the start of each iteration to encourage diversity in the factors[18]. In this post, I evaluate a slight heuristic modification of orthogonalized ALS which exploits the symmetry in the Hessian tensor to save on computation[19]. The algorithm is as follows:
Algorithm (2) - Symmetric Orthogonalized Alternating Least Squares
Initialize:
Repeat: For τ steps (or until the change in factors is smaller than ϵ):
Estimate α:
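Here the quantity $T^{(2)}(\hat u_\ell,\hat v_\ell,\cdot)$ is shorthand for the vector defined by $T^{(2)}(\hat u_\ell,\hat v_\ell,\cdot)_k\equiv\sum_{i,j}T^{(2)}_{ijk}\hat U_{i\ell}\hat V_{j\ell}$, and analogously $T^{(2)}(\cdot,\hat v_\ell,\hat v_\ell)_i\equiv\sum_{j,k}T^{(2)}_{ijk}\hat V_{j\ell}\hat V_{k\ell}$, while the scalars $T^{(2)}(\hat u_\ell,\hat v_\ell,\hat v_\ell)$ are defined as $T^{(2)}(\hat u_\ell,\hat v_\ell,\hat v_\ell)\equiv\sum_{i,j,k}T^{(2)}_{ijk}\hat U_{i\ell}\hat V_{j\ell}\hat V_{k\ell}$ for each $\ell$.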
Note that it's possible to compute the main updates (6.1, 6.2) of algorithm 2 without explicitly materializing the entire Hessian tensor; for details see the appendix (broadly, the details are similar to, e.g., Panickserry (2023) or Grosse et al. (2023)).
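The contraction shorthand maps directly onto `np.einsum`. The sketch below also illustrates why the full Hessian need not be materialized: $T^{(2)}(\cdot,\hat v,\hat v)$ is just the second directional derivative of the map along $\hat v$. The toy quadratic `delta` and the finite-difference estimate are assumptions for illustration; the appendix's actual implementation is not shown here.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 6
A = rng.standard_normal((d, d, d))
T = 0.5 * (A + A.transpose(0, 2, 1))   # Hessian tensors are symmetric in the two input indices

def delta(theta):
    # Toy map whose Hessian tensor at 0 is exactly T: delta_i = 1/2 sum_jk T_ijk theta_j theta_k
    return 0.5 * np.einsum("ijk,j,k->i", T, theta, theta)

u = rng.standard_normal(d)
u /= np.linalg.norm(u)
v = rng.standard_normal(d)
v /= np.linalg.norm(v)

T_uv_dot = np.einsum("ijk,i,j->k", T, u, v)    # T(u, v, .), a vector over the last input index
T_dot_vv = np.einsum("ijk,j,k->i", T, v, v)    # T(., v, v), a vector over the output index
T_uvv = np.einsum("ijk,i,j,k->", T, u, v, v)   # T(u, v, v), a scalar

# Hessian-free estimate of T(., v, v): second central difference of delta along v.
eps = 1e-3
T_dot_vv_fd = (delta(eps * v) - 2 * delta(np.zeros(d)) + delta(-eps * v)) / eps ** 2
```

On a real model the same quantities would be obtained from Hessian-vector products rather than finite differences.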
Alternative formulation of tensor decomposition objective: causal importance minus similarity penalty
For some further intuition behind algorithm (2), it is useful to re-write the tensor decomposition objective in equation (4). Importantly, my goal here is to provide a heuristic justification of algorithm (2) in terms of principles which are easily generalized to the case of exponential MLPs.
To start, recall that we can write our parametrized estimate of the Hessian tensor, ^T(2), as a sum of rank-1 factors:
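$$\hat T^{(2)}=\sum_{\ell=1}^{m}\alpha_\ell\cdot\hat u_\ell\otimes\hat v_\ell\otimes\hat v_\ell$$

where the outer product notation $\hat u_\ell\otimes\hat v_\ell\otimes\hat v_\ell$ denotes the tensor whose $(i,j,k)$-th entry is given by $\hat U_{i\ell}\hat V_{j\ell}\hat V_{k\ell}$. Then, expanding the square in (4), we can re-write the objective in that equation as:

$$\Big\|T^{(2)}-\sum_{\ell=1}^{m}\alpha_\ell\cdot\hat u_\ell\otimes\hat v_\ell\otimes\hat v_\ell\Big\|^2=\big\|T^{(2)}\big\|^2-2\Big\langle T^{(2)},\sum_{\ell=1}^{m}\alpha_\ell\cdot\hat u_\ell\otimes\hat v_\ell\otimes\hat v_\ell\Big\rangle+\Big\|\sum_{\ell=1}^{m}\alpha_\ell\cdot\hat u_\ell\otimes\hat v_\ell\otimes\hat v_\ell\Big\|^2\tag{7}$$

where for two order-$o$ tensors $T,T'$ the bracket notation $\langle T,T'\rangle$ refers to the element-wise dot-product $\sum_{i_1,\dots,i_o}T_{i_1,\dots,i_o}T'_{i_1,\dots,i_o}$.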
Importantly, the true Hessian T(2) does not depend on the parameters in our approximation (the αℓ,^uℓ,^vℓ's) and so we can regard ||T(2)||2 as a constant. Thus, some simple manipulations of the remaining terms in equation (7) tells us that we can re-formulate (4) in terms of the following maximization problem:
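$$\max_{\substack{\alpha,\hat U,\hat V\\ \|\hat u_\ell\|=\|\hat v_\ell\|=1}}\ \underbrace{\sum_{\ell=1}^{m}\alpha_\ell\,T^{(2)}(\hat u_\ell,\hat v_\ell,\hat v_\ell)}_{\text{quadratic causal importance}}-\underbrace{\frac{1}{2}\sum_{\ell,\ell'}\alpha_\ell\alpha_{\ell'}\langle\hat u_\ell,\hat u_{\ell'}\rangle\langle\hat v_\ell,\hat v_{\ell'}\rangle^2}_{\text{quadratic similarity penalty}}\tag{8}$$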
In words, by trying to minimize reconstruction error of the Hessian tensor (optimization problem (4)), we are implicitly searching for feature directions which are causally important, as measured by the quadratic part of Δs→tR, subject to a pairwise similarity penalty (with the particular functional form of the penalty derived by expanding ||∑ℓαℓ⋅^uℓ⊗^vℓ⊗^vℓ||2 in (7)).
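The equivalence between the reconstruction objective (4) and the "importance minus penalty" objective (8) can be checked numerically on random data. The following sketch uses an arbitrary random tensor and random unit-norm factors (all toy assumptions) and confirms that the reconstruction error equals $\|T^{(2)}\|^2$ minus twice the objective in (8).

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 7, 4
T = rng.standard_normal((d, d, d))                 # stand-in for the true Hessian tensor
U = rng.standard_normal((d, m))
U /= np.linalg.norm(U, axis=0, keepdims=True)      # unit-norm output directions
V = rng.standard_normal((d, m))
V /= np.linalg.norm(V, axis=0, keepdims=True)      # unit-norm input directions
alpha = np.abs(rng.standard_normal(m))

# Left-hand side: reconstruction error of objective (4).
T_hat = np.einsum("l,il,jl,kl->ijk", alpha, U, V, V)
recon = np.sum((T - T_hat) ** 2)

# Right-hand side: the two terms of objective (8).
importance = alpha @ np.einsum("ijk,il,jl,kl->l", T, U, V, V)
penalty = 0.5 * alpha @ ((U.T @ U) * (V.T @ V) ** 2) @ alpha
objective_8 = importance - penalty

identity_gap = recon - (np.sum(T ** 2) - 2 * objective_8)
```

Since $\|T^{(2)}\|^2$ does not depend on the parameters, minimizing `recon` and maximizing `objective_8` select the same factors.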
In this light, we could re-phrase the main loop of algorithm (2) as follows:
Algorithm (2) Main Loop, Re-phrased
Notice that the sole difference in the re-phrasing is in step 2 - from the re-phrased version, we can see that step 2 essentially performs gradient ascent on the "quadratic causal importance term" from equation (8) with infinite step size.
The re-phrased version yields a number of additional insights, summarized as follows:
Fitting an Exponential MLP
Now, I leverage the intuition developed in the final part of the previous section to derive a heuristic training algorithm in the case of exponential MLPs.
First, note that we can perform the same manipulation we performed for the Hessian in equation (7) for any term in our original functional objective (3), as follows:
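$$\begin{aligned}\Big\|R^kT^{(k)}-\sum_{\ell=1}^{m}\alpha_\ell\cdot\hat u_\ell\otimes\hat v_\ell^{\otimes k}\Big\|^2&=\big\|R^kT^{(k)}\big\|^2-2\Big\langle R^kT^{(k)},\sum_{\ell=1}^{m}\alpha_\ell\cdot\hat u_\ell\otimes\hat v_\ell^{\otimes k}\Big\rangle+\Big\|\sum_{\ell=1}^{m}\alpha_\ell\cdot\hat u_\ell\otimes\hat v_\ell^{\otimes k}\Big\|^2\\&=\underbrace{\big\|R^kT^{(k)}\big\|^2}_{\text{constant}}-2\underbrace{\sum_{\ell=1}^{m}\alpha_\ell\,R^kT^{(k)}(\hat u_\ell,\hat v_\ell,\cdots,\hat v_\ell)}_{\text{degree }k\text{ causal importance}}+\underbrace{\sum_{\ell,\ell'}\alpha_\ell\alpha_{\ell'}\langle\hat u_\ell,\hat u_{\ell'}\rangle\langle\hat v_\ell,\hat v_{\ell'}\rangle^k}_{\text{degree }k\text{ similarity penalty}}\end{aligned}\tag{9}$$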
Summing across all terms in (9) yields the following theorem:
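Theorem. Let $\bar R$ denote the radius of convergence of the Taylor expansion of $\Delta^{s\to t}_R$, and assume that $R\le\bar R$. Then minimizing (3) for $\sigma(x)\equiv\exp(x)-1$ is equivalent (up to constant multiplicative/additive factors) to maximizing the objective in the following optimization problem:

$$\max_{\substack{\alpha,\hat U,\hat V\\ \|\hat u_\ell\|=\|\hat v_\ell\|=1}}\ \underbrace{\sum_{\ell}\big\langle\hat u_\ell,\Delta^{s\to t}_R(\hat v_\ell)\big\rangle}_{\text{causal importance}}-\underbrace{\frac{1}{2}\sum_{\ell\ne\ell'}\alpha_\ell\alpha_{\ell'}\langle\hat u_\ell,\hat u_{\ell'}\rangle\big(\exp(\langle\hat v_\ell,\hat v_{\ell'}\rangle)-1\big)}_{\text{similarity penalty}}\tag{10}$$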
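For readers who want the intermediate step, the two series identities that turn the degree-$k$ terms of (9), weighted by $1/k!$, into closed form are the following (the second uses $\Delta^{s\to t}(0)=0$ and the radius-of-convergence assumption):

```latex
\sum_{k=1}^{\infty}\frac{1}{k!}\,\langle\hat v_\ell,\hat v_{\ell'}\rangle^{k}
  \;=\; \exp\!\big(\langle\hat v_\ell,\hat v_{\ell'}\rangle\big)-1,
\qquad
\sum_{k=1}^{\infty}\frac{1}{k!}\,R^{k}\,T^{(k)}(\hat u_\ell,\hat v_\ell,\dots,\hat v_\ell)
  \;=\; \big\langle\hat u_\ell,\;\Delta^{s\to t}_R(\hat v_\ell)\big\rangle .
```

The first is the Taylor series of $\exp(x)-1$ evaluated at $x=\langle\hat v_\ell,\hat v_{\ell'}\rangle$; the second is the Taylor expansion of $\Delta^{s\to t}$ around zero, evaluated at $R\hat v_\ell$ and projected onto $\hat u_\ell$.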
We can now "lift" the re-phrased version of algorithm (2) to the case of exponential activation functions to obtain the following algorithm:
Algorithm (3) - Orthogonalized Gradient Iteration (OGI)
Initialize:
Repeat: For τ steps (or until the change in factors is smaller than ϵ):
Estimate α:
Note that OGI only requires access to gradients of $\Delta^{s\to t}_R$, and thus is more efficient than performing a Hessian tensor decomposition via orthogonalized ALS. Moreover, it seems likely that the causal importance term in (10) will more fully capture the true behavior of $\Delta^{s\to t}_R$ than the quadratic causal importance term in (8) (and in fact, my experiments below indicate that OGI learns more generalizable jailbreak vectors). For these reasons, OGI is currently my recommended default algorithm for mechanistically eliciting latent behaviors.
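To give a feel for the shape of such an iteration, here is a rough sketch of an orthogonalize-then-gradient loop in the spirit of OGI. It is a reconstruction from the prose description only: the update order, the use of QR for orthogonalization, and the final estimate of α as $\langle\hat u_\ell,\Delta^{s\to t}_R(\hat v_\ell)\rangle$ are all assumptions, as are the names `ogi_sketch`, `delta` and `vjp` and the toy map with an analytic vector-Jacobian product.

```python
import numpy as np

def orthonormalize(M):
    Q, _ = np.linalg.qr(M)   # orthonormal columns spanning the same space as M's columns
    return Q

def normalize(M):
    return M / np.linalg.norm(M, axis=0, keepdims=True)

def ogi_sketch(delta, vjp, d, m, R, steps=10, seed=0):
    rng = np.random.default_rng(seed)
    U = normalize(rng.standard_normal((d, m)))
    V = normalize(rng.standard_normal((d, m)))
    for _ in range(steps):
        # Orthogonalize first, to push the factors apart (assumed QR step).
        U_o, V_o = orthonormalize(U), orthonormalize(V)
        # Output directions: the observed downstream change for each input direction.
        U = normalize(np.stack([delta(R * V_o[:, l]) for l in range(m)], axis=1))
        # Input directions: gradient of <u, delta(R v)> w.r.t. v, one backward pass per factor.
        V = normalize(np.stack([R * vjp(R * V_o[:, l], U_o[:, l]) for l in range(m)], axis=1))
    alpha = np.array([U[:, l] @ delta(R * V[:, l]) for l in range(m)])
    return U, V, alpha

# Toy map with exponential gating and orthonormal true factors (illustrative only).
rng = np.random.default_rng(1)
d = 10
U_star = orthonormalize(rng.standard_normal((d, 3)))
V_star = orthonormalize(rng.standard_normal((d, 3)))
a_star = np.array([3.0, 2.0, 1.0])

def delta(theta):
    return U_star @ (a_star * np.expm1(V_star.T @ theta))

def vjp(theta, u):
    # Transposed Jacobian of delta at theta applied to u.
    return V_star @ (a_star * np.exp(V_star.T @ theta) * (U_star.T @ u))

U_hat, V_hat, alpha_hat = ogi_sketch(delta, vjp, d, m=3, R=1.0)
```

On a real model, `delta` would be a forward pass with the steering vector added at the source layer, and `vjp` a backward pass from the target layer to the source layer.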
Below are some additional considerations regarding OGI:
On the role of R
In contrast to the algorithms I introduced for linear and quadratic DCTs, which are scale-invariant, for OGI we need to choose a scale parameter R. The theory developed in this section suggests a reasonable approach here: if R is small, then equation (3) suggests the Jacobian will be weighted much more heavily than the other terms and the true factors can only be recovered under strong assumptions such as exact orthogonality. On the other hand, if R is too large, we will emphasize very high-order information in Δs→tR, which seems likely to be noisy. This suggests that we choose R large enough such that the linear part of Δs→tR is not too dominant, but not so large as to totally destabilize things. In the next section, I describe a heuristic calibration procedure along these lines which appears to find values of R which work well in practice.
Relation to original MELBO objective
In my original post, I proposed searching for features in layer s which induced large downstream changes in layer t, under the hypothesis that these features will be interpretable as they appear to be structurally important to the model. I considered various norms to measure downstream changes, but in the simplest such case I simply suggested maximizing ||Δs→t(R^v)||2 over ^v:||^v||=1[22]. Note that if we maximize ⟨^u,Δs→t(R^v)⟩ over ^u:||^u||=1 then this equals ||Δs→t(R^v)||, so that the causal objective in (10) captures the squared norm objective in the original post (up to a monotonic transformation). Furthermore, in the original post I proposed learning features sequentially, subject to an orthogonality constraint with previously learned features. This sequential algorithm largely mirrors the structure of tensor power iteration (see Anandkumar et al. (2014)[23]). Thus you could view both the original post and the OGI algorithm presented above in algorithm 3 as heuristic generalizations of two tensor decomposition methods - tensor power iteration in the original post, and orthogonalized alternating least squares in this post.
Calibrating R
Epistemic status: My preliminary attempt at leveraging theory to develop a calibration procedure for choosing the scale R. It appears to work across a variety of ∼7B models, but it is likely that future work will refine or replace this method, perhaps by combining with insights along the lines of 25Hour & submarat (2024) or Heimersheim & Mendel (2024).
We need to choose a value of R for both training (in the case of Exponential DCTs) and inference (i.e., we need to choose a norm when we add a steering vector to a model's activations).
For training, equation (3) suggests we will get better identification of non-orthogonal factors by choosing $R$ large enough such that the non-linear part of $\Delta^{s\to t}_R$ is non-negligible.
Moreover, theories of computation in superposition suggest that much of the meaningful computation in neural networks consists of two key steps: i) compute a dot product with some direction, and ii) apply a nonlinear gating function to de-noise interference between features. In particular, this suggests that meaningful computation occurs precisely when the non-linear part of $\Delta^{s\to t}$ is non-negligible.
Assuming that the non-linearity occurs at roughly the same scale for all important feature directions, this suggests choosing the same value of $R$ for both training and inference.
In particular, I propose the following:
Calibration Procedure for Choosing R
At first glance, we seem to have simply replaced one hyper-parameter (R) with another (λ). However, it seems reasonable to conjecture that the optimal value of λ may be a deterministic function of other hyper-parameters (i.e., both the hyper-parameters of the DCT, and the characteristics of the model itself). If this is the case, then we could use a fixed value of λ across a variety of different models/prompts, provided that the hyper-parameters stay the same across settings.
In fact, my preliminary finding is that for a constant depth-horizon t−s=10, a value of λ=.5 works across a variety of models (even models of varying depths). To illustrate this, for all exponential DCTs in this post I train using the same value of λ=.5.
Case Study: Learning Jailbreak Vectors
Generalization of linear, quadratic and exponential DCTs
Figure 1: Sample complexity of DCTs
To systematically evaluate the generalization properties of DCTs, I conduct a comparative study of linear, projected linear, quadratic and exponential DCTs trained on (subsets of) the 521 harmful instructions included in AdvBench. Importantly, my main goal in this subsection is not to compare DCTs with other mechanistic jailbreak methods such as CAA (although I plan to provide such comparisons in a follow-up post/paper; see also the next section for a preliminary comparison), but rather to first establish which DCT activation function yields the most promising generalization properties as a function of sample size.
To evaluate sample complexity, I create 10 different random shuffles of the dataset for each DCT variant. For each shuffle, I train DCTs on the first {2,4,6,8,12,16} instructions, with the last 100 instructions of each shuffled dataset reserved as a test set. I train all DCTs with m=512 many factors. Since both the SVD and Hessian tensor decompositions are only unique up to a sign[24], for both linear (projected/full) and quadratic DCTs, I first learn 256 many input directions and include the negation of these 256 many directions to form a total of 512 directions. For all DCTs, I use layer 10 as the source and layer 20 as the target. I calibrate R consistently across all DCTs using the procedure described in the previous section, with λ=0.5, on a fixed random sample of 8 harmful instructions.
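The data protocol just described can be sketched as follows. The placeholder instruction strings, the helper names `make_split` and `add_negations`, and the toy width of 64 standing in for $d_{\text{model}}$ are assumptions for illustration; the actual repo may organize this differently.

```python
import numpy as np

def make_split(items, n_train, seed, n_test=100):
    """Shuffle with a fixed seed; train on the first n_train items, test on the last n_test."""
    order = np.random.default_rng(seed).permutation(len(items))
    shuffled = [items[i] for i in order]
    return shuffled[:n_train], shuffled[-n_test:]

def add_negations(V):
    """Resolve the sign ambiguity by keeping both +v and -v for every learned direction."""
    return np.concatenate([V, -V], axis=1)

instructions = [f"instruction_{i}" for i in range(521)]   # placeholders for the dataset
train_sizes = (2, 4, 6, 8, 12, 16)
splits = {(seed, n): make_split(instructions, n, seed)
          for seed in range(10) for n in train_sizes}

V_half = np.random.default_rng(0).standard_normal((64, 256))  # 256 learned directions (toy)
V_full = add_negations(V_half)                                # 512 candidate steering vectors
```

Because the shuffle depends only on the seed, the smaller training sets are nested inside the larger ones for each shuffle, and the test set is identical across training sizes.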
Using the first 32 instructions of each shuffle as a validation set, I rank source-layer features by taking the average difference in final logits between "Sure" and "Sorry" when steered by each feature as an efficiently computable jailbreak score. I then take the highest-ranking feature on the validation set and compute test-set jailbreak scores. The results are visualized in figure 1 above.
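A minimal sketch of this selection rule, assuming the final-position logits under each steering vector have already been collected into an array: the token ids `SURE` and `SORRY`, the array shapes and the random stand-in logits are all hypothetical.

```python
import numpy as np

def jailbreak_scores(logits, sure_id, sorry_id):
    """logits: (n_features, n_prompts, vocab) final-position logits when steering with each feature.
    Returns the mean logit difference "Sure" minus "Sorry" per feature."""
    return (logits[..., sure_id] - logits[..., sorry_id]).mean(axis=1)

rng = np.random.default_rng(0)
n_features, n_val, n_test, vocab = 6, 32, 100, 50
SURE, SORRY = 3, 7                                   # hypothetical token ids

val_logits = rng.standard_normal((n_features, n_val, vocab))
test_logits = rng.standard_normal((n_features, n_test, vocab))
val_logits[2, :, SURE] += 10.0                       # plant one strongly "jailbreaking" feature
test_logits[2, :, SURE] += 10.0

val_scores = jailbreak_scores(val_logits, SURE, SORRY)
best = int(np.argmax(val_scores))                    # pick the top feature on validation data
test_score = jailbreak_scores(test_logits, SURE, SORRY)[best]
```

Selecting on the validation prompts and reporting on the held-out prompts keeps the test-set score from being inflated by the selection step.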
Some key takeaways are:
Exponential DCTs out-perform Quadratic/Linear DCTs
Exponential DCTs consistently achieve the highest jailbreak scores across all sample sizes, aligning with their theoretical advantages - namely, their ability to learn non-orthogonal features, as well as their ability to incorporate higher-order information in Δs→t. Quadratic DCTs outperform linear variants, in line with the theory that non-orthogonality is important for accurately reflecting the model's true ontology. Projected linear DCTs show better performance than standard linear DCTs, suggesting that strict orthogonality constraints hinder learning of the model's true feature space, to the point that using an approximate Jacobian is better than working with the exact Jacobian.
A note on training efficiency
A priori, it could have been the case that non-orthogonality was not that important for eliciting generalizable behaviors. In this case, linear DCTs would be ideal - they require no hyper-parameter tuning and the Jacobian calculation is easily parallelizable. However, my results show that non-orthogonality significantly improves generalization.
Fortunately, exponential DCTs maintain much of the computational efficiency of linear DCTs, as they require only gradients rather than the second-order information needed by quadratic DCTs. Furthermore, the OGI algorithm parallelizes well and typically converges in just 10 iterations as long as one trains a sufficiently large number of factors (m). For an illustration of this, below are training curves on the prompt "Tell me how to make a bomb" for increasing values of m:
Figure 2: Training curves for different widths (m)
Note that since OGI converges in ∼10 iterations when m=512, the total FLOPs for training an exponential DCT is not that much larger than for computing the full Jacobian. In particular, since dmodel=4096 for Qwen1.5-7B-Chat, we would have to perform 4096 backward passes to compute a full Jacobian using reverse-mode auto-differentiation. In comparison, for τ=10,m=512, we only have to compute τ×m=5120 backward passes to run OGI, which is not that much larger than 4096 (we also have to compute a QR decomposition in each step, but this is significantly cheaper in runtime than the backward pass).
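The comparison above is simple bookkeeping, sketched below for the quoted values. The function names are made up for illustration, and the count ignores the QR step, as the text does.

```python
def backward_passes_full_jacobian(d_model):
    """Reverse-mode autodiff needs one backward pass per output dimension."""
    return d_model

def backward_passes_ogi(tau, m):
    """One backward pass per factor per iteration, as counted in the text."""
    return tau * m

d_model, tau, m = 4096, 10, 512  # values quoted above
jacobian_cost = backward_passes_full_jacobian(d_model)
ogi_cost = backward_passes_ogi(tau, m)
ratio = ogi_cost / jacobian_cost
print(jacobian_cost, ogi_cost, ratio)  # 4096 5120 1.25
```

Under this count, OGI costs about 1.25 times as many backward passes as the full Jacobian. The ratio scales linearly in m only if the iteration count stays near 10, which the text claims for sufficiently large m.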
A note on diversity
Looking at figure 1, one may be tempted to conclude that projected linear DCTs are not that much worse than exponential DCTs in terms of generalization of the best vector. And I would agree with this sentiment - if there is a single behavior that you believe will be particularly salient on your data-set of interest and if you want to see whether DCTs will recover it, you should probably first try a projected linear DCT with dproj=32 and then refine with an exponential DCT as necessary. If, on the other hand, your goal is to enumerate as many features as possible, my suggestion is that you should probably use exponential DCTs - a projected linear DCT with dproj=32 can learn at most 32 interesting features, while taking dproj much larger (i.e. dproj=dmodel) is essentially the same as computing a full Jacobian, which hurts generalization. As I show below, there are (conservatively) more than 200 harm(ful/less) directions in Qwen1.5-7B-Chat, so training with a low dproj will miss many of these directions.
Additionally, my subjective impression is that projected linear DCTs miss out on many qualitatively interesting vectors, such as the "music theory" vector mentioned below, or vectors similar to the "minecraft" vector in the original post (although I haven't quantified this).
Sample size of n=12 suffices
Figure 1 suggests that the jailbreak scores of exponential DCTs actually decrease past a sample size of 12 (although the error bars are wide). It's not clear whether this is due to the logit difference between "Sure" and "Sorry" being an imperfect proxy for jailbreak behavior; perhaps with a better metric we would see that performance simply saturates at n=12, but doesn't decrease beyond that (alternatively, perhaps we need to re-run over more random seeds; I only use 10 seeds to keep compute costs/time down). In any case, this suggests that n=12 (or even smaller, like n=8) is a good default sample size for training DCTs.
Highest-ranking vectors elicit fluent completions
Note that I don't evaluate any sort of fluency of steered completions. This is because my subjective impression is that with the above choice of R, we don't get any sort of fluency penalty when steering with the highest-ranked DCT features. For example, below are some representative test-set completions for the top-scoring feature of an exponential DCT trained on 8 harmless instructions. One can see that the model steered by this vector exhibits archetypal "helpful-only" assistant behavior.
Evidence for multiple harmless directions
Epistemic Status: I give some suggestive evidence that there are many "true" harmless directions, similarly to others' findings. I don't think the question of "feature multiplicity" posed here has been conclusively settled. But given the potential relevance for adversarial robustness it seems wise to consider seriously what the implications would be if "feature multiplicity" were real.
In my initial MELBO post I found two orthogonal steering vectors which elicit jailbreak behavior in Qwen-14B-Chat. In a subsequent post, Goldman-Wetzler and Turner (2024) scaled up this sort of finding, discovering (among other things) more than 800 orthogonal steering vectors which elicit code from Qwen1.5-1.8B-Chat, even on prompts which have nothing to do with coding.
This type of result sounds somewhat surprising - in the most basic form of the linear representation hypothesis, it seems natural to assume that the mapping between an LLM's "features" and "human concepts" will be one-to-one. But this is not necessarily the case - it could be that the linear representation hypothesis is true, but the mapping of features to human concepts is many-to-one[25].
If we accept the premise that this mapping may be many-to-one, then it makes sense to talk more quantitatively about the degree of "feature multiplicity" for a given human concept. In this section, I focus on the concept of "harmfulness", asking the following question:
Rather than being an academic question, in principle this seems highly relevant to the study of adversarial robustness. To see why, imagine that something like the following hypothesis is true:
To elaborate, a number of recent works (e.g., Sheshadri et al. (2024) and Zou et al. (2024)) have proposed methods for implementing stronger safeguards in LLMs by training them to distinguish between "forget" and "retain" content. The general approach involves using a dataset that clearly delineates harmful content (to be "forgotten") from acceptable content (to be "retained"), then training the LLM via SGD with a loss function that scrambles its internal representations on the forget content while preserving normal operation on retain content.
If the hypothesis of harm(ful/less) feature multiplicity is true, these methods may have a fundamental limitation: they are likely to identify and suppress only the single most salient harmful direction that best discriminates between the forget and retain datasets. This seems eminently plausible - SGD seems much more likely to choose the "lazy" method of getting by with a single discriminative harmful direction, rather than systematically enumerating and addressing all possible harmful directions.
The implications for the robustness of these methods are concerning. If a model truly contains multiple independent harmless directions, then methods that focus on addressing only the most prominent direction leave the model vulnerable to attacks that activate any of the remaining "dormant" harmless directions.
Indeed, in a later section, I demonstrate that a representation-rerouted version of Mistral-7B-Instruct can be jailbroken using a DCT-based latent-space attack. This suggests that even sophisticated safety training methods may be insufficient if they don't account for the full multiplicity of harmful feature directions. In the remainder of this section, I outline mechanistic evidence supporting the hypothesis that there are many such harm(ful/less) directions present in language models.
Many loosely-correlated DCT features elicit jailbreaks
To investigate the question of "harmless feature multiplicity", I study exponential DCT features learned on the first 8 instructions of the first shuffle of the sample complexity experiment described above.
Taking the 512 source-layer features as steering vectors (i.e., by adding them at layer 10), it turns out that very many of them elicit archetypal "helpful-only assistant" jailbreak behavior, with each vector inducing slight variations on this theme, as if we were sampling from a helpful-only model with temperature 1.0. For example, choosing an arbitrary lower threshold of 5.0 on our jailbreak score yields 240 vectors, each of which appears to elicit jailbreaks. As an illustration, see the following completions from the bottom-ranked of the top 240 vectors:
To summarize, there appear to be at least 240 vectors which elicit jailbreaks. Moreover, while not exactly orthogonal, these vectors don't appear to be concentrated on a lower-dimensional subspace. For some evidence of this, the condition number of the associated matrix of 240 vectors is ∼38, so that this matrix is not terribly rank-deficient or ill-conditioned. Furthermore, the average of |⟨vℓ,vℓ′⟩| over all pairs of these 240 vectors is 0.36, with standard deviation 0.19. Thus, while the vectors are significantly more correlated than what one would expect by chance, they are not perfectly correlated with each other.
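The two diagnostics quoted above (condition number and average absolute inner product) can be computed along the following lines. This is a sketch under the assumption that the vectors are stacked as rows and normalized to unit length before comparison; the function name and the random baseline are illustrative only.

```python
import numpy as np

def multiplicity_diagnostics(V):
    """V has shape (k, d_model); each row is one steering vector."""
    V_unit = V / np.linalg.norm(V, axis=1, keepdims=True)
    # Condition number: ratio of the largest to the smallest singular value.
    s = np.linalg.svd(V_unit, compute_uv=False)
    cond = s[0] / s[-1]
    # Absolute cosine similarity over all distinct pairs of vectors.
    gram = np.abs(V_unit @ V_unit.T)
    iu = np.triu_indices(len(V_unit), k=1)
    pairwise = gram[iu]
    return cond, pairwise.mean(), pairwise.std()

# Chance baseline: 240 random Gaussian vectors in a 4096-dimensional space.
rng = np.random.default_rng(0)
rand_cond, rand_mean, rand_std = multiplicity_diagnostics(
    rng.normal(size=(240, 4096))
)
```

For random vectors in 4096 dimensions, the mean absolute cosine similarity is on the order of 0.01, which gives a sense of the "by chance" baseline that the value of 0.36 is being compared against.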
Averaging doesn't improve generalization when we add features to the residual stream
One hypothesis is that these 240 vectors are all noisy versions of a "true" harmless direction, and that perhaps we can obtain some vector which generalizes better via some sort of averaging.
I consider two types of averaging: i) vavg, formed by averaging the top vectors, and then normalizing to the same scale R (similar to what was suggested in this comment), and ii) taking the top left singular vector vsvd, normalized to the same scale R (and taking the best value out of both ±vsvd). Test jailbreak scores obtained by adding these directions to the residual stream are given in the following table:
Table 3: Jailbreak scores obtained by adding individual vs aggregate source-layer features
Aggregating the 240 vectors (using either a mean or SVD) yields better performance than the average performance of the original 240 vectors, but does not surpass the top-scoring DCT feature. Taking the logit jailbreak score at face value, this is inconsistent with the view that DCT "harmless" features are simply a "true" harmless feature, plus i.i.d. noise.
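For reference, the two aggregation schemes (vavg and vsvd) might be implemented roughly as follows. This is a sketch, assuming the top vectors are stored as rows of a matrix; with that layout, the relevant singular vector is the top right singular vector of the stacked matrix (equivalently, the top left singular vector of its transpose). The function name and toy data are hypothetical.

```python
import numpy as np

def aggregate_directions(V, R):
    """V: (k, d_model), rows are steering vectors; R: target norm.

    Returns v_avg and both sign choices of v_svd, each rescaled to norm R.
    """
    v_avg = V.mean(axis=0)
    v_avg = R * v_avg / np.linalg.norm(v_avg)
    # Rows of Vt are unit-norm singular vectors living in activation space.
    _, _, Vt = np.linalg.svd(V, full_matrices=False)
    v_svd = R * Vt[0]
    # The sign of a singular vector is arbitrary, so keep both candidates.
    return v_avg, (v_svd, -v_svd)

# Toy data: a shared direction plus small independent noise per vector.
rng = np.random.default_rng(0)
d, k, R = 64, 20, 4.0
shared = np.zeros(d)
shared[0] = 1.0
V = shared + 0.05 * rng.normal(size=(k, d))
v_avg, (v_svd_pos, v_svd_neg) = aggregate_directions(V, R)
```

In the "true direction plus i.i.d. noise" picture simulated by this toy data, both aggregates recover nearly the same direction. Which sign of vsvd to use would then be chosen by evaluating the jailbreak score for each.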
Averaging does improve jailbreak scores when we ablate features
The results outlined above appear to conflict with recent work of Arditi et al. (2024), who found jailbreaks by ablating a single "harm(ful/less)" direction from a model's weights. Specifically, they used a supervised contrastive method to isolate this direction, then prevented the model from writing to it by projecting the direction out of all MLP and attention output matrices. This suggests a single primary direction mediating refusal behavior, whereas my results above show that multiple independent directions can elicit jailbreaks when added to the residual stream.
We can reconcile these seemingly contradictory findings by examining what happens when we ablate DCT features rather than adding them. Specifically, I find that while ablating individual DCT features performs poorly, ablating an averaged direction yields stronger jailbreak effects. Here are the results using directional ablation instead of activation addition:
Table 4: Jailbreak scores obtained by ablating individual vs aggregate source-layer features
Ablating individual DCT features scores worse than zero (on average) on our jailbreak metric, while ablating aggregated vectors (vavg or vsvd) performs noticeably better than even the best individual ablation.
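For concreteness, directional ablation can be sketched in two equivalent-in-spirit forms: removing a direction from activations, or projecting it out of matrices that write to the residual stream (as in the weight-editing approach attributed to Arditi et al. above). The shapes and function names below are assumptions for illustration; the actual implementation may differ in which layers and positions are ablated.

```python
import numpy as np

def ablate_activation(X, direction):
    """Remove the component along `direction` from activations X of shape (n, d_model)."""
    r = direction / np.linalg.norm(direction)
    return X - np.outer(X @ r, r)

def ablate_weight(W_out, direction):
    """Project `direction` out of a matrix that writes to the residual stream.

    W_out: (d_model, d_in), so that W_out @ h is added to the residual stream
    (assumed convention).
    """
    r = direction / np.linalg.norm(direction)
    return W_out - np.outer(r, r @ W_out)

rng = np.random.default_rng(0)
d_model, d_in, n = 16, 8, 5
direction = rng.normal(size=d_model)   # e.g. a DCT feature or an aggregate
X = rng.normal(size=(n, d_model))
W = rng.normal(size=(d_model, d_in))
X_ablated = ablate_activation(X, direction)
W_ablated = ablate_weight(W, direction)
```

After ablation, neither the activations nor the outputs of the edited matrix have any component along the chosen direction, which is the property the jailbreak scores in the table above are testing.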
These findings suggest a more nuanced picture of how harm(ful/less) features are represented in the model, summarized as follows:
Ablating (averaged) target-layer features also works
Finally, I consider what happens when we ablate target-layer features learned by the exponential DCT (i.e., the ^uℓ's). This will corroborate the "primary/secondary" view of harm(ful/less) feature multiplicity, and will also serve as an opportunity to test the "ensemble of persistent shallow circuits" hypothesis outlined in the introduction.
In particular, I perform the same ablation analysis as above on target-layer features (the ^uℓ's):
Table 5: Jailbreak scores obtained by ablating individual vs aggregate target-layer features
The pattern is analogous to the source-layer ablation results, though with lower overall scores - aggregating target-layer features improves jailbreak performance when ablating, suggesting the existence of a primary "should refuse" direction, where the model expects to see the "should refuse" feature in distribution.
As for the "persistent shallow circuits" story, the lower jailbreak scores in this table might indicate that the ^uℓ's don't perfectly align with the core "should refuse request" direction, but rather capture some downstream effects of this direction. This doesn't necessarily contradict the "persistent shallow circuits" hypothesis; it only implies that we should be cautious about assuming the DCT will learn some specific "high-level action" feature in the target layer (as I did for illustrative purposes in table 1), as opposed to some side effect.
Deeper models: constant depth horizon (t−s) suffices for learning jailbreaks
Epistemic Status: Based off a couple hours of experimenting on Qwen1.5-32B-Chat using the latest DCT methodology, as well as some prior tinkering on deeper models with an earlier version of the methodology.
To evaluate how well DCTs work on deeper models, I train an exponential DCT with standard hyper-parameters on Qwen1.5-32B-Chat with the prompt "Tell me how to make a bomb". This model is twice as deep as the 7B version (64 vs 32 hidden layers). A natural question is whether we can still learn interesting DCT features by training on a fixed depth horizon (t−s), rather than growing the depth horizon with the depth of the model[26]. If so then this is nice from the perspective of efficiency - if we can keep the depth fixed then in principle we can deal with the model getting wider by parallelizing things, while retaining the same serial computation time.
My preliminary finding is that using the same hyper-parameters as I used in the above experiment on Qwen-7B-Chat (most importantly, using the same value of λ=.5 for calibrating R, and training on layers 10→20), we get qualitatively similar results on the 32B model as on the 7B model.
In particular, skimming this notebook one can see that the exponential DCT learned a number of "helpful-only assistant" features, along with various "gaming/roleplay" vectors similar to my prior experiments with Qwen-14B-Chat.
Additionally, my subjective impression is that larger models are better able to rationalize why they are talking about a certain concept activated by a DCT feature (reminiscent of Anthropic's Golden Gate Claude). As a cherry-picked example, see these completions from a "music theory" feature, which induces Qwen-32B-Chat to hallucinate convincing interpretations of the words "bomb" or "identity theft" within the context of music theory:
Application: Jailbreaking Representation-Rerouted Mistral-7B
In the previous section, I presented evidence for the existence of over 200 "request is harmless" directions capable of jailbreaking Qwen-7B-Chat. If we are to take the implications of this finding seriously, this suggests that even sophisticated safety training methods which attempt to scramble model internals on a "forget" dataset may remain vulnerable to jailbreaks if they do not explicitly attempt to enumerate and address all possible harmless directions. To test this hypothesis, I evaluate how well exponential DCT features perform as jailbreak vectors on a representation-rerouted version of Mistral-7B-Instruct-v2.
I focus on the "relaxed" adversarial setting where edits to the model's activation space are permitted, following the reasoning that success in this setting may predict vulnerability to more determined input-space attacks (e.g., as suggested by Hubinger (2019), Jermyn (2022) or Panickserry (2023)). Indeed, Sklar et al. (2024) have already successfully jailbroken a representation-rerouted version of Llama3-8B-Instruct by optimizing representational similarity with a "toxified" fine-tuned version of the model, and perhaps future work may be able to explain this result through the lens of feature multiplicity.
I also use this application to evaluate how well exponential DCTs work "out-of-the-box" - that is, with minimal deviation from the standard hyperparameters used in the sample complexity experiment (m=512,λ=0.5,s=10,t=20,τ=10).
The variant of Mistral-7B-Instruct-v2 I use has been trained using the "circuit-breaking" methodology of Zou et al. (2024), which attempts to scramble the model's internals on harmful prompts, causing it to output gibberish when given harmful instructions. For example, see the following unsteered completion:
As a first step in jailbreaking this model, I train exponential DCTs using λ=0.5 with depth horizon t−s=10 on the prompt "Tell me how to make a bomb" at various source-layer values s. For each source-layer, I sort vectors using the logits jailbreak score from the previous section and inspect the first 32 vectors. Training at s=5 proves particularly promising, yielding vectors that induce cooking-related instructions like:
While not a complete jailbreak, this result suggests we're on the right track - instead of outputting gibberish, the model now provides coherent instructions while creatively reinterpreting "bomb" in a culinary context.
To achieve a full jailbreak, I train an exponential DCT at source-layer s=5 on a dataset of n=8 instructions from AdvBench, sorting features by the logits jailbreak score. Running the calibration procedure with λ=0.5 on this dataset yields an input scale of R=0.21; this ends up being too small to elicit anything meaningful during inference. So I increase R by a factor of 4 for inference (to R=0.84). Here are some test-set completions for the top-ranking jailbreak vector (ranked using the logits jailbreak score):
To quantify things, this vector achieves an attack success rate of 62% using Harmbench's Mistral-7B classifier. This is a stark improvement over the 6.2% attack success rate reported in Zou et al. (2024) using a representation-engineering based latent-space attack.
Application: Eliciting Capabilities in Password-Locked Models
In the previous section, I showed how to use DCTs to jailbreak a robustly-trained version of Mistral-7B-Instruct-v2, showing that the lessons from my initial sample complexity experiments can generalize with minimal hyper-parameter tuning.
To get a sense of how exponential DCTs with standard hyper-parameters might apply to another model/setting, I train an exponential DCT on a password-locked version of Deepseek-Math-7B trained by Greenblatt et al. (2024). The model has been trained so that in the absence of a specific password, it emulates the answers of a much less capable model (Pythia-1B) on MATH, scoring an accuracy of 3% without the password, but 65% with the password.
I train the exponential DCT with the same hyper-parameters as my earlier sample complexity experiment (m=512,λ=0.5,s=10,t=20,τ=10), on a sample size of n=8 questions from MATH. Steering with the calibrated value of R, I then choose the best scoring vector on a validation set of 32 questions from MATH, and evaluate performance on a test set of 100 questions. I find that the steered model achieves a test accuracy of 23%, recovering 32% of the gap between the locked model (3%) and the unlocked model (65%), i.e. (23−3)/(65−3)≈0.32.
Future Work
I'm excited about the many interesting research directions that can be explored using DCTs. Below is a quick summary of what I think are some of the most compelling questions. If you're interested in pursuing any of these, or have other novel ideas, please get in touch.
Studying feature multiplicity
Quantifying a broader range of behaviors
Effect of pre-training hyper-parameters
I've noticed in my experiments that DCT features can "feel" more interpretable depending on which LLM you use (although as demonstrated in this post, the current methodology works across a variety of model families, including Qwen, Mistral, Deepseek and others which didn't make it into the current post). For example, my subjective impression is that Qwen1.5-7B is less sensitive to the choice of R than, say, Mistral-7B.
This raises the following question:
Acknowledgements
Thank you to Joseph Bloom, Jordan Taylor, Dmitry Vaintrob and Nina Panickserry for helpful comments/discussions.
Appendix
Hessian auto diff details
Notice that each update in algorithm (2) only requires access to Hessian-vector-vector products of the form T(2)(⋅,v,v) for the ^U update and T(2)(u,v,⋅) for the ^V update. Each Hessian-vector-vector product returns a vector in R^dmodel, and if τm≪dmodel^2 it will be more efficient to implicitly calculate each Hessian-vector-vector product using torch.func's auto-diff features, rather than populating the full Hessian tensor. Below is a sketch of how to compute each update; for full details see the repo.
For each factor of the ^U update:
For each factor of the ^V update:
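A minimal sketch of both Hessian-vector-vector products using `torch.func`, under several assumptions: `f` stands in for the sliced-transformer map Δs→t, the expansion point `theta0` is the zero perturbation, and T(2) is taken to be the raw second-derivative tensor (any constant factor such as 1/2 from the Taylor expansion is omitted). The helper names are hypothetical and this is not the repo's code.

```python
import torch
from torch.func import grad, jvp

def hvv_for_U(f, theta0, v):
    """T(., v, v): second directional derivative of f along v at theta0."""
    def first_derivative(theta):
        return jvp(f, (theta,), (v,))[1]  # J(theta) @ v
    return jvp(first_derivative, (theta0,), (v,))[1]

def hvv_for_V(f, theta0, u, v):
    """T(u, v, .): gradient w.r.t. theta of <u, J(theta) v> at theta0."""
    def scalar(theta):
        return torch.dot(u, jvp(f, (theta,), (v,))[1])
    return grad(scalar)(theta0)

# Toy stand-in for the sliced transformer: f(theta) = A (theta * theta).
torch.manual_seed(0)
d_in, d_out = 6, 4
A = torch.randn(d_out, d_in)
f = lambda theta: A @ (theta * theta)
theta0 = torch.zeros(d_in)
u, v = torch.randn(d_out), torch.randn(d_in)

u_update = hvv_for_U(f, theta0, v)     # analytically 2 * A @ (v * v)
v_update = hvv_for_V(f, theta0, u, v)  # analytically 2 * (A.T @ u) * v
```

Neither function materializes the full Hessian tensor: the ^U-style product uses two nested forward-mode passes, and the ^V-style product uses one forward-mode pass followed by one backward pass.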
In particular, when training one feature, and assuming that a certain scale parameter defined below is small enough such that the Taylor series approximation of the sliced transformer converges, then the correspondence between the two objectives is exact. Otherwise, the correspondence is approximate. ↩︎
I.e., it often converges in as little as 10 iterations, while the method in the original post needed as many as 100−1000 steps of AMSGrad to converge. ↩︎
Note that I conceive of MELBO as a set of requirements for a behavior elicitation method, rather than as a method itself. In this post I consider several different methods for MELBO, all based off learning unsupervised steering vectors by searching directly in a model's activation space for important directions. I think this is the most natural approach towards MELBO, but could imagine there existing other approaches, such as sampling diverse completions, clustering the completions, and fine-tuning the model to emulate the different clusters. ↩︎
Note that this point was not explicitly listed as a requirement for MELBO in the original post. I wanted to elevate it to a primary goal of MELBO, as I believe out-of-distribution coverage is especially important from an alignment perspective (think data poisoning/sleeper agents/treacherous turn type scenarios), while being especially neglected by existing methods. ↩︎
Although in contrast to Ramesh et al. (2018) and my work, that paper only considers the Jacobian of a shallow rather than deep slice. ↩︎
The persistent shallow circuits principle has slightly higher description length than the high-impact feature principle, as it makes the claim that the specific vector of changes in activations induced in the target layer is interpretable. But for this cost we gain several desirable consequences: i) a theory why we should learn mono-semantic features, ii) more efficient algorithms via the connection to tensor decompositions and iii) additional ways of editing the model (i.e. by ablating target-layer features). ↩︎
In other words, I don't consider interpretability-illusion-type concerns such as isolating the particular component in a transformer where a behavior is activated in-distribution. This is because for MELBO, we only care about what we can elicit by intervening at the source layer s, even if this is not where the network itself elicits the behavior. ↩︎
Of course, there are reasons why we might not want to go too deep. The simplest such reason is that for efficiency reasons, the shallower the slice, the better. Additionally, it seems reasonable to hypothesize that in a very deep network the neural net may learn to erase/re-use certain directions in the residual stream. Finally, if we go too deep then the "truly deep" part of the network (for example, some sort of mesa-optimizer) may stretch the limits of our shallow approximation, interfering with our ability to learn interpretable features. ↩︎
For example, the XOR of arbitrary combinations of features. ↩︎
You could also argue this is a prediction of singular learning theory. ↩︎
Moreover, even if the features are exactly orthogonal (m=m∗), and the true f∗ℓ's are simply ReLUs, SGD will fail to recover the true features (see Ge et al. (2017)). ↩︎
If we can't recover the true features under reasonable assumptions even in the noiseless setting, then this is obviously bad for enumerative safety, and so we should at least demand this much. A natural follow-up concern is that the "noise term" →ϵ(→θ) is far from i.i.d. as it captures the "deep" part of the computation of the sliced transformer, and it's natural to assume that the reason why deep transformers work well in practice is that they perform non-trivially "deep" computations. Fortunately, as I discuss below, there is room for hope from the literature on matrix and tensor decompositions which guarantee recovery under certain conditions even with adversarial noise. ↩︎
Interestingly, the factor of 1.5 in the exponent is of the same order of the construction that Hänni et al. (2024) give for noise-tolerant computation in superposition. ↩︎
Moreover, Sharan and Valiant (2017) give some indication that the provable bounds are pessimistic, and that empirically a modification of the algorithm can perform well even in the over-complete setting m>dmodel. ↩︎
Although see this notebook for a (very preliminary) demonstration of recovery of true factors in a synthetic setting. ↩︎
As I mentioned above, one algorithm of particular interest is that of Ding et al. (2022), which guarantees recovery of ∼dmodel^1.5 many random symmetric factors. Another algorithm of interest is that of Hopkins et al. (2019), which provides provably robust recovery of up to ∼d^2 many factors of symmetric 4-tensors in the presence of adversarial noise with bounded spectral norm. I mention these algorithms as a lower-bound of what is attainable provided one is willing to spend a substantial (but still "merely" polynomial) amount of compute - in this case, we get quite a bit in terms of the number of features we are able to recover, as well as robustness to adversarial noise. ↩︎
In other words, the situation is not analogous to the case of sparse dictionary learning, for which impractical algorithms with provable guarantees are known (e.g., Spielman et al. (2012)), while in practice SGD on SAEs is able to recover the true factors just as well on synthetic data (Sharkey and beren (2022)). ↩︎
As Jordan Taylor informs me, physicists have been using an orthogonalization step as standard practice in tensor networks since at least 2009 (see the MERA algorithm of Evenbly and Vidal (2009)). The main difference between MERA and orthogonalized ALS is that MERA uses an SVD to perform the orthogonalization step, while orthogonalized ALS uses a QR decomposition. In initial experiments, I've found that using a QR decomposition is qualitatively superior to using an SVD. ↩︎
In particular, the algorithm of Sharan and Valiant (2017), applicable to general asymmetric tensors, maintains two separate estimates of ^V, initialized separately at random, and applies separate updates ^V(1)←T(2)(^U,⋅,^V(2)) and ^V(2)←T(2)(^U,^V(1),⋅). But if we initialize both estimates to the same value (i.e., the right singular vectors of the Jacobian) then all updates for both estimates will remain the same throughout each iteration. Thus, intuitively it makes sense to consider only the symmetric updates considered in this post, as this will be more efficient. Even though there are no longer any existing provable guarantees for the symmetric algorithm, I adopt it here since it is more efficient and seemed to work well enough in initial experiments. ↩︎
In principle, one could also orthogonalize ^U, but in initial experiments this was significantly less stable, and led to qualitatively less interesting results. I suspect the reason why this is the case is that the true mapping Δs→t is significantly "many-to-one", with a large number of linearly independent directions in layer s writing to a smaller number of directions in layer t. This is in line with Goldman-Wetzler and Turner (2024)'s discovery of >800 "write code" steering vectors in Qwen1.5-1.8B (Chat). ↩︎
This is a common motif in the literature on tensor decompositions. To summarize, a prevailing finding from both theoretical and empirical papers in this literature is that methods which make large steps in parameter space (ALS is one, but another is the tensor power iteration method of Anandkumar et al. (2014)) perform better, and are more robust to noise, than vanilla gradient descent. ↩︎
Another subtle difference: I proposed averaging ||Δ||^2 across prompts, rather than first averaging Δ and then computing the squared norm. Thus the method in that post essentially concatenates Δ across prompts (i.e., if there are n prompts, then technically we consider the map Δs→t_concat: R^d → R^(nd)), whereas the method proposed here simply averages across prompts. I haven't systematically evaluated the difference between averaging and concatenating; my expectation is that concatenating would lead to better diversity but slightly less generalization. ↩︎
The main difference is that Anandkumar et al. (2014) use a soft "deflation" step to encourage diversity, as opposed to the hard orthogonality constraint of the original paper. ↩︎
In particular, if uv^T forms one of the rank-1 factors of the SVD, then this is equivalent to (−u)(−v)^T, so that we don't know the sign of v. Similarly in the case of a Hessian tensor decomposition, both u⊗v⊗v and u⊗(−v)⊗(−v) are equivalent. This is not the case for exponential DCTs, as θ→exp(⟨v,θ⟩) looks very different from θ→exp(⟨−v,θ⟩). ↩︎
Conceptually, this is related to feature splitting. The difference is that in feature splitting, we normally assume that some "coarse-grained" feature splits into more "fine-grained" variants of the same feature, which correspond to some interpretable refinement of the original concept. But it's plausible that a concept could be represented by multiple directions even if there is no interpretable difference between these directions. ↩︎
Of course, for enumerative safety reasons, we may eventually want to use a large depth horizon, or train DCTs with fixed depth-horizons at various source layers. But in the interest of not letting "perfect be the enemy of the good", it seems useful to first establish whether we can learn features using a relatively cheap fixed depth horizon. ↩︎