This is very neat. I definitely agree that I find the discontinuity from the first transformer block surprising. One thing which occurred to me that might be interesting to do is to try and train a linear model to reconstitute the input from the activations at different layers to get an idea of how the model is encoding the input. You could either train one linear model on data randomly sampled from different layers, or a separate linear model for each layer, and then see if there are any interesting patterns like whether the accuracy increases or decreases as you get further into the model. You could also see if the resulting matrix has any relationship to the embedding matrix (e.g. are the two matrices farther apart or closer together than would be expected by chance?). One possible hypothesis that this might let you test is whether the information about the input is being stored indirectly via what the model's guess is given that input or whether it's just being stored in parts of the embedding space that aren't very relevant to the output (if it's the latter, the linear model should put a lot of weight on basis elements that have very little weight in the embedding matrix).
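For concreteness, here is a rough sketch of the "separate linear model for each layer" version of this. Everything in it is a toy assumption: the activations are synthetic (a token's embedding plus noise that grows with depth), the sizes are tiny, and `fit_probe` is a name I made up. With a real model you would swap in the actual per-layer activations.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, N_EMB, N_LAYERS, N_SAMPLES = 10, 16, 3, 500

# Toy "embedding matrix" and synthetic activations (assumptions, not GPT-2 data):
# each layer holds a progressively noisier copy of the input token's embedding.
W = rng.normal(size=(VOCAB, N_EMB))
tokens = rng.integers(0, VOCAB, size=N_SAMPLES)
one_hot = np.eye(VOCAB)[tokens]

def fit_probe(acts, targets):
    """Least-squares linear map from activations to one-hot input tokens."""
    probe, *_ = np.linalg.lstsq(acts, targets, rcond=None)
    return probe  # shape (N_EMB, VOCAB)

train, test = slice(0, 400), slice(400, None)
accuracies = []
for layer in range(N_LAYERS):
    noise_scale = 0.5 * (layer + 1)
    acts = W[tokens] + noise_scale * rng.normal(size=(N_SAMPLES, N_EMB))
    probe = fit_probe(acts[train], one_hot[train])
    # Predict the input token as the argmax of the probe's output.
    pred = (acts[test] @ probe).argmax(axis=1)
    accuracies.append(float((pred == tokens[test]).mean()))

print(accuracies)  # held-out reconstruction accuracy, one entry per layer
```

Each fitted `probe` has the same shape as the transpose of the embedding matrix, so it could be compared against `W` directly, which is the second experiment suggested above.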
One thing which occurred to me that might be interesting to do is to try and train a linear model to reconstitute the input from the activations at different layers to get an idea of how the model is encoding the input. You could either train one linear model on data randomly sampled from different layers, or a separate linear model for each layer, and then see if there are any interesting patterns like whether the accuracy increases or decreases as you get further into the model.
That's a great idea!
One possible hypothesis that this might let you test is whether the information about the input is being stored indirectly via what the model's guess is given that input or whether it's just being stored in parts of the embedding space that aren't very relevant to the output (if it's the latter, the linear model should put a lot of weight on basis elements that have very little weight in the embedding matrix).
Hmm... I guess there is some reason to think the basis elements have special meaning (as opposed to the elements of any other basis for the same space), since the layer norm step operates in this basis.
But I doubt there are actually individual components the embedding cares little about, as that seems wasteful (you want to compress 50K into 1600 as well as you possibly can), and if the embedding cares about them even a little bit then the model needs to slot in the appropriate predictive information, eventually.
Thinking out loud, I imagine there might be a pattern where embeddings of unlikely tokens (given the context) are repurposed in the middle for computation (you know they're near-impossible so you don't need to track them closely), and then smoothly subtracted out at the end. There's probably a way to check if that's happening.
That's a great idea!
Thanks! I'd be quite excited to know what you find if you end up trying it.
Hmm... I guess there is some reason to think the basis elements have special meaning (as opposed to the elements of any other basis for the same space), since the layer norm step operates in this basis.
But I doubt there are actually individual components the embedding cares little about, as that seems wasteful (you want to compress 50K into 1600 as well as you possibly can), and if the embedding cares about them even a little bit then the model needs to slot in the appropriate predictive information, eventually.
Thinking out loud, I imagine there might be a pattern where embeddings of unlikely tokens (given the context) are repurposed in the middle for computation (you know they're near-impossible so you don't need to track them closely), and then smoothly subtracted out at the end. There's probably a way to check if that's happening.
I wasn't thinking you would do this with the natural component basis—though it's probably worth trying that also—but rather doing some sort of matrix decomposition on the embedding matrix to get a basis ordered by importance (e.g. using PCA or NMF—PCA is simpler though I know NMF is what OpenAI Clarity usually uses when they're trying to extract interpretable basis elements from neural network activations) and then seeing what the linear model looks like in that basis. You could even just do something like what you're saying and find some sort of basis ordered by the frequency of the tokens that each basis element corresponds to (though I'm not sure exactly what the right way would be to generate such a basis).
I also thought of PCA/SVD, but I imagine matrix decompositions like these would be misleading here.
What matters here (I think) is not some basis of N_emb orthogonal vectors in embedding space, but some much larger set of ~exp(N_emb) almost orthogonal vectors. We only have 1600 degrees of freedom to tune, but they're continuous degrees of freedom, and this lets us express >>1600 distinct vectors in vocab space as long as we accept some small amount of reconstruction error.
I expect GPT and many other neural models are effectively working in such a space of nearly orthogonal vectors, and picking/combining elements of it. A decomposition into orthogonal vectors won't really illuminate this. I wish I knew more about this topic -- are there standard techniques?
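As a quick numerical sanity check on the "almost orthogonal" intuition (not a technique for recovering such a set from GPT), here is a sketch that draws random directions in a 1600-dimensional space and reports the largest pairwise cosine similarity. The sample size of 40 is an arbitrary choice to keep the pure-Python loop fast, and the helper names are mine.

```python
import math
import random

def random_unit_vector(dim, rng):
    """Draw a random direction by normalizing a Gaussian sample."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def max_abs_cosine(vectors):
    """Largest |cosine similarity| over all distinct pairs of unit vectors."""
    worst = 0.0
    for i in range(len(vectors)):
        for j in range(i + 1, len(vectors)):
            cos = sum(a * b for a, b in zip(vectors[i], vectors[j]))
            worst = max(worst, abs(cos))
    return worst

rng = random.Random(0)
N_EMB = 1600      # embedding width discussed in the post
N_VECTORS = 40    # small sample, chosen only to keep the demo fast

vectors = [random_unit_vector(N_EMB, rng) for _ in range(N_VECTORS)]
worst = max_abs_cosine(vectors)
n_pairs = N_VECTORS * (N_VECTORS - 1) // 2
print(f"max |cos| over {n_pairs} pairs: {worst:.3f}")
```

Random directions are only a baseline, of course; whatever set of directions a trained model uses need not look random.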
You might want to look into NMF, which, unlike PCA/SVD, doesn't aim to create an orthogonal projection. It works well for interpretability because its components cannot cancel each other out, which makes its features more intuitive to reason about. I think it is essentially what you want, although I don't think it will allow you to find directly the 'larger set of almost orthogonal vectors' you're looking for.
I think this might suggest there is some fundamentally better way to do sampling from GPT models? I'm having trouble writing out the intuition clearly, so I'll leave it for later posts.
Unroll the sampling process: hook up all the individual GPT instances into a single long model, bypass the discretizing/embedding layers to make it differentiable end-to-end, and do gradient ascent to find the sequence which maximizes likelihood conditional on the fixed input.
Interesting, but not (I think?) the direction I was headed in.
I was thinking more about the way the model seems to be managing a tradeoff between preserving the representation of token i and producing the representation of token i+1.
The depth-wise continuity imposed by weight decay means late layers are representing something close to the final output -- in late layers the model is roughly looking at its own guesses, even if they were wrong, which seems suboptimal.
Consider this scenario:
My sampling idea was something like "let's replace (or interpolate) late activations with embeddings of the actual next token, so the model can see what really happened, even when its probability was low." (This is for sampling specifically because it'd be too slow in training, where you want to process a whole window at once with matrix operations; sampling has to be a loop anyway, so there's no cost to adding stuff that only works as a loop.)
But, thinking about it more, the model clearly can perform well in scenarios like the above, e.g. my plasma example and also many other cases naturally arising in language which GPT handles well.
I have no idea how it does it -- indeed the connection structure feels weirdly adverse to such operations -- but apparently it does. So it's probably premature to assume it can't do this well, and attempt to "help it out" with extra tricks.
It doesn't sound hard at all. The things Gwern is describing are the same sort of thing that people do for interpretability where they, eg, find an image that maximizes the probability of the network predicting a target class.
Of course, you need access to the model, so only OpenAI could do it for GPT-3 right now.
Doing it with GPT-3 would be quite challenging just for compute requirements like RAM. You'd want to test this out on GPT-2-117M first, definitely. If the approach works at all, it should work well for the smallest models too.
Hey I'm not finished reading this yet but I noticed something off about what you said.
At the end, the final 1600-dimensional vector is multiplied by W's transpose to project back into vocab space.
This isn't quite right. They don't multiply by W's transpose at the end. Rather there is a completely new matrix at the end, whose shape is the same as the transpose of W.
You can see this in huggingface's code for GPT2. In the class GPT2LMHeadModel the final matrix multiplication is performed by the matrix called "lm_head", whereas the matrix you call W, which is used to map 50,257-dimensional vectors into 1600-dimensional space, is called "wte" (found in the GPT2Model class). You can see from the code that wte has shape "Vocab size x Embed Size" while lm_head has shape "Embed Size x Vocab size", so lm_head does have the same shape as W transpose but doesn't have the same numbers.
Edit: I could be wrong here, though. Maybe lm_head was set to be equal to wte transpose? I'm looking through the GPT-2 paper but don't see anything like that mentioned.
Maybe lm_head was set to be equal to wte transpose?
Yes, this is the case in GPT-2. Perhaps the huggingface implementation supports making these two matrices different, but they are the same in the official GPT-2.
Edit: I think the reason this is obscured in the huggingface implementation is that they always distinguish the internal layers of a transformer from the "head" used to convert the final layer outputs into predictions. The intent is easy swapping between different "heads" with the same "body" beneath.
This forces their code to allow for heads that differ from the input embedding matrix, even when they implement models like GPT-2 where the official specification says they are the same.
Edit2: might as well say explicitly that I find the OpenAI tensorflow code much more readable than the huggingface code. This isn't a critique of the latter; it's trying to support every transformer out there in a unified framework. But if you only care about GPT, this introduces a lot of distracting abstraction.
This post relates an observation I've made in my work with GPT-2, which I have not seen made elsewhere.
IMO, this observation sheds a good deal of light on how the GPT-2/3/etc models (hereafter just "GPT") work internally.
There is an accompanying Colab notebook which will let you interactively explore the phenomenon I describe here.
[Edit: updated with another section on comparing to the inputs, rather than the outputs. This arguably resolves some of my confusion at the end. Thanks to algon33 and Gurkenglas for relevant suggestions here.]
[Edit 5/17/21: I've recently written a new Colab notebook which extends this post in various ways:
]
overview
background on GPT's structure
You can skip or skim this if you already know it.
the logit lens
As described above, GPT schematically looks like
We have a "dictionary," W, that lets us convert between vocab space and embedding space at any point. We know that some vectors in embedding space make sense when converted into vocab space:
What about the 1600-dim vectors produced in the middle of the network, say the output of the 12th layer or the 33rd? If we convert them to vocab space, do the results make sense? The answer is yes.
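Mechanically, the conversion is just "take the vector at some layer, multiply by W^T, softmax." Here is a minimal dependency-free sketch of that step. It is a toy under stated assumptions, not the notebook's code: the sizes are tiny, `W` and the "residual stream" are random stand-ins, the name `logit_lens` is only illustrative, and it skips the final layer norm that GPT-2 applies before the projection.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def logit_lens(hidden, W):
    """Project one embedding-space vector into vocab space with the tied matrix W.

    W has shape (vocab_size, n_emb): row t is the embedding of token t, so
    hidden @ W^T is the dot product of `hidden` with every token embedding.
    """
    logits = [sum(h * w for h, w in zip(hidden, row)) for row in W]
    return softmax(logits)

# Toy stand-ins (assumptions; the real model discussed here has
# a 50257-token vocab and 1600-dim embeddings).
rng = random.Random(0)
VOCAB, N_EMB, N_LAYERS = 20, 8, 4
W = [[rng.gauss(0, 1) for _ in range(N_EMB)] for _ in range(VOCAB)]

# Fake residual stream: each "layer" adds a small random update, x + f(x) style.
hidden = [rng.gauss(0, 1) for _ in range(N_EMB)]
per_layer_probs = []
for _ in range(N_LAYERS):
    hidden = [h + 0.5 * rng.gauss(0, 1) for h in hidden]
    per_layer_probs.append(logit_lens(hidden, W))

for layer, probs in enumerate(per_layer_probs, start=1):
    top = max(range(VOCAB), key=probs.__getitem__)
    print(f"layer {layer}: top-1 token id {top}, p = {probs[top]:.3f}")
```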
logits
For example: the plots below show the logit lens on GPT-2 as it predicts a segment of the abstract of the GPT-3 paper. (This is a segment in the middle of the abstract; it can see all the preceding text, but I'm not visualizing the activations for it.)
For readability, I've made two plots showing two consecutive stretches of 10 tokens. Notes on how to read them:
There are various amusing and interesting things one can glimpse in these plots. The "early guesses" are generally wrong but often sensible enough in some way:
ranks
The view above focuses only on the top-1 guess at each layer, which is a reductive window on the full distributions.
Another way to look at things: we still reduce the final output to the top-1 guess, but we compare other distributions to the final one by looking at the rank of the final top-1 guess.
Even if the middle of the model hasn't yet converged to the final answer, maybe it's got that answer somewhere in its top 3, top 10, etc. That's a lot better than "top 50257."
Here are the same activations, shown as ranks. (Remember: these are ranks of the model's final top-1 prediction, not the true token.)
In most cases, the network's uncertainty has drastically reduced by the middle layers. The order of the top candidates may not be right, and the probabilities may not be perfectly calibrated, but it's got the gist already.
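To make the rank view concrete, here is a small sketch of the computation: find the final layer's top-1 token, then ask where that token ranks under each layer's distribution. The three distributions are made-up numbers over a 5-token vocabulary, and `rank_of` is a hypothetical helper name.

```python
def rank_of(token_id, probs):
    """1-based rank of token_id under probs (rank 1 = most probable)."""
    target = probs[token_id]
    return 1 + sum(1 for p in probs if p > target)

# Hypothetical per-layer distributions over a 5-token vocabulary.
layer_probs = [
    [0.40, 0.30, 0.15, 0.10, 0.05],   # early layer
    [0.25, 0.20, 0.35, 0.15, 0.05],   # middle layer
    [0.05, 0.10, 0.70, 0.10, 0.05],   # final layer
]

# The reference token is the final layer's top-1 guess, not the true next token.
final = layer_probs[-1]
final_top1 = max(range(len(final)), key=final.__getitem__)

ranks = [rank_of(final_top1, probs) for probs in layer_probs]
print(final_top1, ranks)
```

In this toy case the final answer is already rank 1 by the middle layer even though its probability is still far from the final value, which is the kind of situation the rank plots are meant to surface.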
KL divergence and input discarding
Another way of comparing the similarity of two probability distributions is the KL divergence. Taking the KL divergence of the intermediate probabilities w/r/t the final probabilities, we get a more continuous view of how the distributions smoothly converge to the model's output.
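For reference, a minimal sketch of the quantity being plotted, written as KL(final || layer) to match the notation used in the addendum. The distributions are made-up numbers, and the `eps` guard against zero probabilities is my own addition rather than anything from the notebook.

```python
import math

def kl_divergence(p, q, eps=1e-12):
    """D_KL(p || q) in nats; terms where p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical per-layer distributions over a 5-token vocabulary.
layer_probs = [
    [0.40, 0.30, 0.15, 0.10, 0.05],   # early layer
    [0.25, 0.20, 0.35, 0.15, 0.05],   # middle layer
    [0.05, 0.10, 0.70, 0.10, 0.05],   # final layer
]

final = layer_probs[-1]
kls = [kl_divergence(final, probs) for probs in layer_probs]
for layer, kl in enumerate(kls, start=1):
    print(f"layer {layer}: KL(final || layer) = {kl:.3f} nats")
```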
Because KL divergence is a more holistic measure of the similarity between two distributions than the ones I've used above, it's also my preferred metric for making the point that nothing looks like the input.
In the plots above, I've skipped the input layer (i.e. the input tokens in embedding space). Why? Because they're so different from everything else, they distract the eye!
In the plots below, where color is KL divergence, I include the input as well. If we trust that KL divergence is a decent holistic way to compare two distributions (I've seen the same pattern with other metrics), then:
other examples
I show several other examples in the Colab notebook. I'll breeze through a few of them here.
copying a rare token
Sometimes it's clear that the next token should be a "copy" of an earlier token: whatever arbitrary thing was in that slot, spit it out again.
If this is a token with relatively low prior probability, one would think it would be useful to "keep it around" from the input so later positions can look at it and copy it. But as we saw, the input is never "kept around"!
What happens instead? I tried this text:
As shown below (truncated to the last few tokens for visibility), the model correctly predicts "plasma" at the last position, but only figures it out in the very last layers.
Apparently it is keeping around a representation of the token "plasma" with enough resolution to copy it . . . but it only retrieves this representation at the end! (In the rank view, the rank of plasma is quite low until the very end.)
This is surprising to me. The repetition is directly visible in the input: "when people say" is copied verbatim. If you just applied the rule "if input seems to be repeating, keep repeating it," you'd be good. Instead, the model scrambles away the pattern, then recovers it later through some other computational route.
extreme repetition
We've all seen GPT sampling get into a loop where text repeats itself exactly, over and over. When text is repeating like this, where is the pattern "noticed"?
At least in the following example, it's noticed in the upper half of the network, while the lower half can't see it even after several rounds of repetition.
why? / is this surprising?
First, some words about why this trick can even work at all.
One can imagine models that perform the exact same computation as GPT-2, for which this trick would not work. For instance, each layer could perform some arbitrary vector rotation of the previous one before doing anything else to it. This would preserve all the information, but the change of basis would prevent the vectors from making sense when multiplied by W^T.
Why doesn't the model do this? Two relevant facts:
1. Transformers are residual networks. Every connection in them looks like x + f(x) where f is the learned part. So the identity is very easy to learn.
This tends to keep things in the same basis across different layers, unless there's some reason to switch.
2. Transformers are usually trained with weight decay, which is almost the same thing as L2 regularization. This encourages learned weights to have small L2 norm.
That means the model will try to "spread out" a computation across as many layers as possible (since the sum-of-squares is less than the square-of-sums). Given the task of turning an input into an output, the model will generally prefer changing the input a little, then a little more, then a little more, bit by bit.
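The sum-of-squares point can be spelled out with a scalar cartoon. The symbols here are my own shorthand, not anything from the model: suppose the network must make a total change of size $\Delta$, split into non-negative per-layer contributions $\delta_1, \dots, \delta_n$, and the L2 penalty scales like the sum of their squares. Then:

```latex
\sum_{i=1}^{n} \delta_i = \Delta, \quad \delta_i \ge 0
\;\Longrightarrow\;
\frac{\Delta^2}{n} \;\le\; \sum_{i=1}^{n} \delta_i^2 \;\le\; \Delta^2
```

The upper bound is hit when a single layer does all the work, and the lower bound when every layer contributes $\Delta / n$, so under this (very simplified) picture the penalty is smallest when the change is spread evenly over as many layers as possible.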
1+2 are a good story if you want to explain why the same vector basis is used across the network, and why things change smoothly. This story would render the whole thing unsurprising . . . except that the input is discarded in such a discontinuous way!
I would have expected a U-shaped pattern, where the early layers mostly look like the input, the late layers mostly look like the output, and there's a gradual "flip" in the middle between the two perspectives. Instead, the input space immediately vanishes, and we're in output space the whole way.
Maybe there is some math fact I'm missing here.
Or, maybe there's some sort of "hidden" invertible relationship between
so that a token like "plasma" is kept around from the input -- but not in the form "the output is plasma," instead in the form "the output is [the kind of word that comes after plasma]."
However, I'm not convinced by that story as stated. For one thing, GPT layers don't share their weights, so the mapping between these two spaces would have to be separately memorized by each layer, which seems costly. Additionally, if this were true, we'd expect the very early activations to look like naive context-less guesses for the next token. Often they are, but just as often they're weird nonsense like "Garland."
addendum: more on "input discarding"
In comments, Gurkenglas noted that the plots showing KL(final || layer) don't tell the whole story.
The KL divergence is not a metric: it is not symmetric and does not obey the triangle inequality. Hence my intuitive picture of the distribution "jumping" from the input to the first layer, then smoothly converging to the final layer, is misleading: it implies we are measuring distances along a path through some space, but KL divergence does not measure distance in any space.
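A tiny numerical illustration of the asymmetry, using two made-up distributions (one peaked, one uniform):

```python
import math

def kl_divergence(p, q):
    """D_KL(p || q) in nats; assumes q_i > 0 wherever p_i > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

peaked = [0.90, 0.05, 0.05]
uniform = [1 / 3, 1 / 3, 1 / 3]

forward = kl_divergence(peaked, uniform)    # KL(peaked || uniform)
backward = kl_divergence(uniform, peaked)   # KL(uniform || peaked)
print(f"KL(peaked || uniform) = {forward:.3f}")
print(f"KL(uniform || peaked) = {backward:.3f}")
```

The two directions give different numbers for the same pair of distributions, so which distribution goes in which slot changes what a plot is showing.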
Gurkenglas and algon33 suggested plotting the KL divergences of everything w/r/t the input rather than the output: KL(input || layer).
Note that the input is close to a distribution that just assigns probability 1 to the input token ("close" because W * W^T is not invertible), so this is similar to asking "how probable is the input token, according to each layer?" That's a question which is also natural to answer by plotting ranks: what rank is assigned to the input token by each layer?
Below, I show both: KL(input || layer), and the rank of the input token according to later layers.
It's possible that the relatively high ranks -- in the 100s or 1000s, but not the 10000s -- of input tokens in many cases are (related to) the mechanism by which the model "keeps around" rarer tokens in order to copy them later.
As some evidence for this, I will show plots like the above for the plasma example. Here, I show a segment including the first instance of "plasma," rather than the second which copies it.
The preservation of "plasma" here is striking.
My intuitive guess is that the rarity, or (in some sense) "surprisingness," of the token causes early layers to preserve it: this would give the later layers raw access to rare tokens, where otherwise they would only be looking at the more plausible tokens that GPT had guessed for the corresponding positions.
On the other hand, this story has trouble explaining why "G" and "PT" are not better preserved in the GPT3 abstract plots just above. This is the first instance of "GPT" in the full passage, so the model can't rely on copies of these at earlier positions. That said, my sense of scale for "well-preservedness" is a wild guess, and these particular metrics may not be ideal for capturing it anyway.