This post relates an observation I've made in my work with GPT-2, which I have not seen made elsewhere.
IMO, this observation sheds a good deal of light on how the GPT-2/3/etc models (hereafter just "GPT") work internally.
There is an accompanying Colab notebook which will let you interactively explore the phenomenon I describe here.
[Edit: updated with another section on comparing to the inputs, rather than the outputs. This arguably resolves some of my confusion at the end. Thanks to algon33 and Gurkenglas for relevant suggestions here.]
[Edit 5/17/21: I've recently written a new Colab notebook which extends this post in various ways:
- trying the "lens" on various models from 125M to 2.7B parameters, including GPT-Neo and CTRL
- exploring the contributions of the attention and MLP sub-blocks within transformer blocks/layers
- trying out a variant of the "decoder" used in this post, which dramatically helps with interpreting some models]
- GPT's probabilistic predictions are a linear function of the activations in its final layer. If one applies the same function to the activations of intermediate GPT layers, the resulting distributions make intuitive sense.
- This "logit lens" provides a simple (if partial) interpretability lens for GPT's internals.
- Other work on interpreting transformer internals has focused mostly on what the attention is looking at. The logit lens focuses on what GPT "believes" after each step of processing, rather than how it updates that belief inside the step.
- These distributions gradually converge to the final distribution over the layers of the network, often getting close to that distribution long before the end.
- At some point in the middle, GPT will have formed a "pretty good guess" as to the next token, and the later layers seem to be refining these guesses in light of one another.
- The general trend, as one moves from earlier to later layers, is
- "nonsense / not interpretable" (sometimes, in very early layers) -->
- "shallow guesses (words that are the right part of speech / register / etc)" -->
- "better guesses"
- ...though some of those phases are sometimes absent.
- On the other hand, only the inputs look like the input tokens.
- In the logit lens, the early layers sometimes look like nonsense, and sometimes look like very simple guesses about the output. They almost never look like the input.
- Apparently, the model does not "keep the inputs around" for a while and gradually process them into some intermediate representation, then into a prediction.
- Instead, the inputs are immediately converted to a very different representation, which is smoothly refined into the final prediction.
- This is reminiscent of the perspective in Universal Transformers which sees transformers as iteratively refining a guess.
- However, Universal Transformers have both an encoder and decoder, while GPT is only a decoder. This means GPT faces a tradeoff between keeping around the input tokens, and producing the next tokens.
- Eventually it has to spit out the next token, so the longer it spends (in depth terms) processing something that looks like token i, the less time it has to convert it into token i+1. GPT has a deadline, and the clock is ticking.
- More speculatively, this suggests that GPT mostly "thinks in predictive space," immediately converting inputs to predicted outputs, then refining guesses in light of other guesses that are themselves being refined.
- I think this might suggest there is some fundamentally better way to do sampling from GPT models? I'm having trouble writing out the intuition clearly, so I'll leave it for later posts.
- Caveat: I call this a "lens" because it is one way of extracting information from GPT's internal activations. I imagine there is other information present in the activations that cannot be understood by looking at logits over tokens. The logit lens shows us some of what is going on, not all of it.
background on GPT's structure
You can skip or skim this if you already know it.
- Input and output
- As input, GPT takes a sequence of tokens. Each token is a single item from a vocabulary of N_v=50257 byte pairs (mostly English words).
- As output, GPT returns a probability distribution over the vocabulary. It is trained so this distribution predicts the next token.
- That is, the model's outputs are shifted forward by one position relative to the inputs. The token at position i should, after flowing through the layers of the model, turn into the token at position i+1. (More accurately, a distribution over the token at position i+1.)
- Vocab and embedding spaces
- The vocab has size N_v=50257, but GPT works internally in a smaller "embedding" vector space, of dimension N_e.
- For example, in the GPT-2 1558M model size, N_e=1600. (Below, I'll often assume we're talking about GPT-2 1558M for concreteness.)
- There is an N_v-by-N_e embedding matrix W which is used to project the vocab space into the embedding space and vice versa.
- In, blocks, out
- The first thing that happens to the inputs is a multiplication by W, which projects them into the embedding space. 
- The resulting 1600-dimensional vector then passes through many neural network blocks, each of which returns another 1600-dimensional vector.
- At the end, the final 1600-dimensional vector is multiplied by W's transpose to project back into vocab space.
- The resulting 50257-dim vectors are treated as logits. Applying the softmax function to them gives you the output probability distribution.
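The "in, blocks, out" pipeline above can be sketched in a few lines of plain Python. This is a toy under loud assumptions: the sizes are tiny, W is random rather than learned, and `toy_block` is a made-up residual update standing in for a real transformer block (no attention, no layer norm, no positional embeddings). It only shows the shape of the computation: embed with W, modify the vector repeatedly, project back with W's transpose, then softmax.

```python
import math
import random

random.seed(0)

# Toy sizes (assumption). The post's GPT-2 1558M has N_v=50257, N_e=1600, 48 blocks.
N_V, N_E, N_BLOCKS = 8, 4, 3

# Embedding matrix W: one N_E-dimensional row per vocab item (random here, learned in GPT).
W = [[random.gauss(0, 1) for _ in range(N_E)] for _ in range(N_V)]

def embed(token_id):
    # Multiplying a one-hot vocab vector by W just selects one row of W.
    return list(W[token_id])

def toy_block(x, k):
    # Hypothetical stand-in for a transformer block: a residual update x + f(x).
    return [xi + 0.1 * math.tanh(xi + k) for xi in x]

def unembed(x):
    # Multiply by W's transpose: one logit per vocab item.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def forward(token_id):
    x = embed(token_id)
    for k in range(N_BLOCKS):
        x = toy_block(x, k)
    return softmax(unembed(x))

probs = forward(3)  # a probability distribution over the toy vocab
```

Note that the vector keeps the same dimension N_E from `embed` to `unembed`, which is what makes it possible to call `unembed` on any intermediate state as well.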
the logit lens
As described above, GPT schematically looks like
- Project the input tokens from vocab space into the 1600-dim embedding space
- Modify this 1600-dim vector many times
- Project the final 1600-dim vector back into vocab space
We have a "dictionary," W, that lets us convert between vocab space and embedding space at any point. We know that some vectors in embedding space make sense when converted into vocab space:
- The very first embedding vectors are just the input tokens (in embedding space)
- The very last embedding vectors are just the output logits (in embedding space)
What about the 1600-dim vectors produced in the middle of the network, say the output of the 12th layer or the 33rd? If we convert them to vocab space, do the results make sense? The answer is yes.
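To make "convert them to vocab space" concrete, here is a minimal sketch of the lens itself on a toy model. Everything about the model is an assumption for illustration: W is random, and `toy_block` is a hypothetical block that is rigged to drift toward a chosen target embedding, so the output says nothing about real GPT behavior. The part that carries over is `logit_lens`: keep every intermediate vector and decode each one with the same W transpose that is normally applied only at the end.

```python
import random

random.seed(0)

N_V, N_E, N_BLOCKS = 8, 4, 6  # toy sizes (assumption)

# Random stand-in for the embedding matrix W.
W = [[random.gauss(0, 1) for _ in range(N_E)] for _ in range(N_V)]

def unembed(x):
    # Project an embedding-space vector into vocab space (multiply by W^T).
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def toy_block(x, target, step=0.3):
    # Hypothetical residual block: moves x a fraction of the way toward `target`.
    return [xi + step * (ti - xi) for xi, ti in zip(x, target)]

def logit_lens(input_id, target_id):
    # Run the toy model, recording the vector after the embedding and after each block.
    x = list(W[input_id])
    states = [x]
    for _ in range(N_BLOCKS):
        x = toy_block(x, W[target_id])
        states.append(x)
    # Decode every recorded state with the same "dictionary" W.
    return [unembed(s) for s in states]

def argmax(values):
    return max(range(len(values)), key=values.__getitem__)

per_layer_logits = logit_lens(input_id=2, target_id=5)
top1_per_layer = [argmax(logits) for logits in per_layer_logits]
print(top1_per_layer)  # top guess after the embedding and after each block
```

With a real model, the recorded states would be the block outputs, and the printed list corresponds to one column of the plots described in this section.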
For example: the plots below show the logit lens on GPT-2 as it predicts a segment of the abstract of the GPT-3 paper. (This is a segment in the middle of the abstract; it can see all the preceding text, but I'm not visualizing the activations for it.)
For readability, I've made two plots showing two consecutive stretches of 10 tokens. Notes on how to read them:
- The input tokens are shown as 45-degree tilted axis labels at the bottom.
- The correct output (i.e. the input shifted by one) is likewise shown at the top.
- A (*) is added in these labels when the model's top guess matched the correct output.
- The vertical axis indexes the layers (or "blocks"), zero-indexed from 0 to 47. To make the plots less huge I skip every other intermediate layer. The Colab notebook lets you control this skipping as you like.
- The top guess for each token, according to the model's activations at a given layer, is printed in each cell.
- The colors show the logit associated with the top guess. These tend to increase steadily as the model converges on a "good guess," then get refined in the last layers.
- Cells are outlined when their top guess matches the final top guess.
- For transformer experts: the "activations" here are the block outputs after layer norm, but before the learned point-wise transformation.
There are various amusing and interesting things one can glimpse in these plots. The "early guesses" are generally wrong but often sensible enough in some way:
- "We train GPT-3..." 000? (someday!)
- "GPT-3, an..." enormous? massive? (not wrong!)
- "We train GPT-3, an aut..." oreceptor? (later converges to the correct oregressive)
- "model with 175..." million? (later converges to a comma, not the correct billion)
The view above focuses only on the top-1 guess at each layer, which is a reductive window on the full distributions.
Another way to look at things: we still reduce the final output to the top-1 guess, but we compare other distributions to the final one by looking at the rank of the final top-1 guess.
Even if the middle of the model hasn't yet converged to the final answer, maybe it's got that answer somewhere in its top 3, top 10, etc. That's a lot better than "top 50257."
Here are the same activations, shown as ranks. (Remember: these are ranks of the model's final top-1 prediction, not the true token.)
In most cases, the network's uncertainty has drastically reduced by the middle layers. The order of the top candidates may not be right, and the probabilities may not be perfectly calibrated, but it's got the gist already.
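The rank view can be computed as follows. This is a small sketch, not the notebook's code: the function names are my own, and the per-layer logits are made-up numbers for a 4-token vocab. Rank 1 means "this layer's top guess already matches the final top guess."

```python
def rank_of(token_id, logits):
    # Rank 1 = highest logit. Count how many tokens score strictly higher.
    return 1 + sum(1 for l in logits if l > logits[token_id])

def ranks_of_final_top1(per_layer_logits):
    # Find the final layer's top-1 token, then ask each layer how it ranks that token.
    final = per_layer_logits[-1]
    final_top1 = max(range(len(final)), key=final.__getitem__)
    return [rank_of(final_top1, logits) for logits in per_layer_logits]

# Hypothetical logits for one position, earliest layer first (toy numbers).
per_layer_logits = [
    [2.0, 0.5, 0.1, -1.0],
    [1.0, 0.8, 0.9, -0.5],
    [0.2, 0.9, 1.5, 0.0],
    [0.0, 1.0, 3.0, 0.1],
]

ranks = ranks_of_final_top1(per_layer_logits)
print(ranks)  # [3, 2, 1, 1]
```

In this toy example the final answer (token 2) climbs from rank 3 to rank 1 by the third layer, which is the kind of trajectory the rank plots display.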
KL divergence and input discarding
Another way of comparing the similarity of two probability distributions is the KL divergence. Taking the KL divergence of the intermediate probabilities w/r/t the final probabilities, we get a more continuous view of how the distributions smoothly converge to the model's output.
Because KL divergence is a more holistic measure of the similarity between two distributions than the ones I've used above, it's also my preferred metric for making the point that nothing looks like the input.
In the plots above, I've skipped the input layer (i.e. the input tokens in embedding space). Why? Because they're so different from everything else, they distract the eye!
In the plots below, where color is KL divergence, I include the input as well. If we trust that KL divergence is a decent holistic way to compare two distributions (I've seen the same pattern with other metrics), then:
- Immediately, after the very first layer, the input has been transformed into something that looks more like the final output (47 layers later) than it does like the input.
- After this one discontinuous jump, the distribution progresses in a much more smooth way to the final output distribution.
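For reference, the quantity plotted here is KL(final || layer), computed per position. A minimal sketch, assuming we already have one logit vector per layer; the logits below are made-up toy numbers and the function names are my own.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    # KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute 0.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def kl_final_vs_layers(per_layer_logits):
    # Compare the final distribution to the distribution read off each layer.
    final = softmax(per_layer_logits[-1])
    return [kl_divergence(final, softmax(logits)) for logits in per_layer_logits]

# Hypothetical logits for one position, earliest layer first (toy numbers).
per_layer_logits = [
    [2.0, 0.5, 0.1, -1.0],
    [1.0, 0.8, 0.9, -0.5],
    [0.2, 0.9, 1.5, 0.0],
    [0.0, 1.0, 3.0, 0.1],
]

kls = kl_final_vs_layers(per_layer_logits)
print([round(k, 3) for k in kls])  # shrinks toward 0 at the final layer
```

The last entry is zero by construction, since it compares the final distribution with itself. Swapping the arguments of `kl_divergence` gives a different number in general, because KL divergence is not symmetric.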
I show several other examples in the Colab notebook. I'll breeze through a few of them here.
copying a rare token
Sometimes it's clear that the next token should be a "copy" of an earlier token: whatever arbitrary thing was in that slot, spit it out again.
If this is a token with relatively low prior probability, one would think it would be useful to "keep it around" from the input so later positions can look at it and copy it. But as we saw, the input is never "kept around"!
What happens instead? I tried this text:
Sometimes, when people say plasma, they mean a state of matter. Other times, when people say plasma
As shown below (truncated to the last few tokens for visibility), the model correctly predicts "plasma" at the last position, but only figures it out in the very last layers.
Apparently it is keeping around a representation of the token "plasma" with enough resolution to copy it . . . but it only retrieves this representation at the end! (In the rank view, the rank of plasma is quite low until the very end.)
This is surprising to me. The repetition is directly visible in the input: "when people say" is copied verbatim. If you just applied the rule "if input seems to be repeating, keep repeating it," you'd be good. Instead, the model scrambles away the pattern, then recovers it later through some other computational route.
We've all seen GPT sampling get into a loop where text repeats itself exactly, over and over. When text is repeating like this, where is the pattern "noticed"?
At least in the following example, it's noticed in the upper half of the network, while the lower half can't see it even after several rounds of repetition.
why? / is this surprising?
First, some words about why this trick can even work at all.
One can imagine models that perform the exact same computation as GPT-2, for which this trick would not work. For instance, each layer could perform some arbitrary vector rotation of the previous one before doing anything else to it. This would preserve all the information, but the change of basis would prevent the vectors from making sense when multiplied by W^T.
Why doesn't the model do this? Two relevant facts:
1. Transformers are residual networks. Every connection in them looks like x + f(x) where f is the learned part. So the identity is very easy to learn.
This tends to keep things in the same basis across different layers, unless there's some reason to switch.
2. Transformers are usually trained with weight decay, which is almost the same thing as L2 regularization. This encourages learned weights to have small L2 norm.
That means the model will try to "spread out" a computation across as many layers as possible (since the sum-of-squares is less than the square-of-sums). Given the task of turning an input into an output, the model will generally prefer changing the input a little, then a little more, then a little more, bit by bit.
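The parenthetical about sums of squares can be spelled out with a simplified scalar picture. As an assumption for illustration, treat each layer's contribution as a non-negative step size a_i, with the L2 penalty proportional to the sum of squared steps:

```latex
\[
\sum_{i=1}^{n} a_i^2 \;\le\; \Bigl(\sum_{i=1}^{n} a_i\Bigr)^2
\quad \text{for } a_i \ge 0,
\qquad \text{and for equal steps } a_i = \tfrac{a}{n}: \quad
\sum_{i=1}^{n} \Bigl(\tfrac{a}{n}\Bigr)^2 \;=\; \frac{a^2}{n} \;<\; a^2
\quad (n > 1,\ a \neq 0).
\]
```

For example, a total change of a = 1 made in a single layer costs 1, while the same total change spread evenly over n = 48 layers costs 1/48, or about 0.021. Real weight decay acts on weight matrices rather than on scalar step sizes, so this is only the intuition, not a derivation.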
1+2 are a good story if you want to explain why the same vector basis is used across the network, and why things change smoothly. This story would render the whole thing unsurprising . . . except that the input is discarded in such a discontinuous way!
I would have expected a U-shaped pattern, where the early layers mostly look like the input, the late layers mostly look like the output, and there's a gradual "flip" in the middle between the two perspectives. Instead, the input space immediately vanishes, and we're in output space the whole way.
Maybe there is some math fact I'm missing here.
Or, maybe there's some sort of "hidden" invertible relationship between
- the embedding of a given token, and
- the model's prior for what token comes after it (given no other information)
so that a token like "plasma" is kept around from the input -- but not in the form "the output is plasma," instead in the form "the output is [the kind of word that comes after plasma]."
However, I'm not convinced by that story as stated. For one thing, GPT layers don't share their weights, so the mapping between these two spaces would have to be separately memorized by each layer, which seems costly. Additionally, if this were true, we'd expect the very early activations to look like naive context-less guesses for the next token. Often they are, but just as often they're weird nonsense like "Garland."
addendum: more on "input discarding"
In comments, Gurkenglas noted that the plots showing KL(final || layer) don't tell the whole story.
The KL divergence is not a metric: it is not symmetric and does not obey the triangle inequality. Hence my intuitive picture of the distribution "jumping" from the input to the first layer, then smoothly converging to the final layer, is misleading: it implies we are measuring distances along a path through some space, but KL divergence does not measure distance in any space.
Gurkenglas and algon33 suggested plotting the KL divergences of everything w/r/t the input rather than the output: KL(input || layer).
Note that the input is close to a distribution that just assigns probability 1 to the input token ("close" because W * W^T is not invertible), so this is similar to asking "how probable is the input token, according to each layer?" That's a question which is also natural to answer by plotting ranks: what rank is assigned to the input token by each layer?
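A small sketch of what "the input as a distribution" means: embed the token, immediately project back with W transpose, and apply softmax. The assumptions are loud here: W is a random matrix with toy sizes, not GPT's learned embedding, so the exact numbers mean nothing. The point is only that the token's own logit is its embedding's squared norm, which tends to dominate, while the other tokens still get small nonzero probability, so the result is close to one-hot rather than exactly one-hot.

```python
import math
import random

random.seed(0)

N_V, N_E = 50, 64  # toy sizes (assumption); GPT-2 1558M has N_v=50257, N_e=1600

# Random stand-in for the embedding matrix W (real embeddings are learned).
W = [[random.gauss(0, 1) for _ in range(N_E)] for _ in range(N_V)]

def unembed(x):
    # Multiply by W^T: one logit per vocab item.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def rank_of(token_id, logits):
    # Rank 1 = highest logit. Count how many tokens score strictly higher.
    return 1 + sum(1 for l in logits if l > logits[token_id])

token_id = 7
input_logits = unembed(W[token_id])  # one row of W * W^T
input_probs = softmax(input_logits)
input_rank = rank_of(token_id, input_logits)

print(input_rank, input_probs[token_id])
```

To get the rank plots, the same `rank_of(token_id, ...)` call would be applied to the logits read off each later layer instead of `input_logits`.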
Below, I show both: KL(input || layer), and the rank of the input token according to later layers.
- For KL(input || layer), I use the same color scale as in the plots for KL(final || layer), so the two are comparable.
- For the ranks, I do not use the same color scale: I have the colors bottom out at rank 1000 instead of rank 100. This gives more visual insight into where the model could be preserving input information.
- There is still a fast jump in KL(input || layer) after the input.
- However, it's far smaller than the jump in KL(final || layer) at the same point.
- Note that the darkest color, meaning KL=30, does not appear on the plot of KL(input || layer).
- On the plot of KL(final || layer), however, the maximum values were in fact much greater than 30; I cut off the color scale at 30 so other distinctions were perceptible at all.
- Likewise, while ranks jump quickly after the input, they often stay relatively high in the context of a ~50K vocab.
- I am curious about the differences here: some tokens are "preserved" much more in this sense than others.
- This is apparently contextual, not just based on the token itself. Note the stark differences between the rank trajectories of the first, second, and third commas in the passage.
It's possible that the relatively high ranks -- in the 100s or 1000s, but not the 10000s -- of input tokens in many cases are (related to) the mechanism by which the model "keeps around" rarer tokens in order to copy them later.
As some evidence for this, I will show plots like the above for the plasma example. Here, I show a segment including the first instance of "plasma," rather than the second which copies it.
The preservation of "plasma" here is striking.
My intuitive guess is that the rarity, or (in some sense) "surprisingness," of the token causes early layers to preserve it: this would provide a mechanism for providing raw access to rare tokens in the later layers, which would otherwise only be looking at more plausible tokens that GPT had guessed for the corresponding positions.
On the other hand, this story has trouble explaining why "G" and "PT" are not better preserved in the GPT3 abstract plots just above. This is the first instance of "GPT" in the full passage, so the model can't rely on copies of these at earlier positions. That said, my sense of scale for "well-preservedness" is a wild guess, and these particular metrics may not be ideal for capturing it anyway.
- Footnote (on the initial multiplication by W): Right after this, positional embeddings are added. I'm ignoring positional embeddings in the post, but mention them in this footnote for accuracy.
This is very neat. I definitely agree that I find the discontinuity from the first transformer block surprising. One thing which occurred to me that might be interesting to do is to try and train a linear model to reconstitute the input from the activations at different layers to get an idea of how the model is encoding the input. You could either train one linear model on data randomly sampled from different layers, or a separate linear model for each layer, and then see if there are any interesting patterns like whether the accuracy increases or decreases as you get further into the model. You could also see if the resulting matrix has any relationship to the embedding matrix (e.g. are the two matrices farther apart or closer together than would be expected by chance?). One possible hypothesis that this might let you test is whether the information about the input is being stored indirectly via what the model's guess is given that input or whether it's just being stored in parts of the embedding space that aren't very relevant to the output (if it's the latter, the linear model should put a lot of weight on basis elements that have very little weight in the embedding matrix).
That's a great idea!
Hmm... I guess there is some reason to think the basis elements have special meaning (as opposed to the elements of any other basis for the same space), since the layer norm step operates in this basis.
But I doubt there are actually individual components the embedding cares little about, as that seems wasteful (you want to compress 50K into 1600 as well as you possibly can), and if the embedding cares about them even a little bit then the model needs to slot in the appropriate predictive information, eventually.
Thinking out loud, I imagine there might be a pattern where embeddings of unlikely tokens (given the context) are repurposed in the middle for computation (you know they're near-impossible so you don't need to track them closely), and then smoothly subtracted out at the end. There's probably a way to check if that's happening.
Thanks! I'd be quite excited to know what you find if you end up trying it.
I wasn't thinking you would do this with the natural component basis—though it's probably worth trying that also—but rather doing some sort of matrix decomposition on the embedding matrix to get a basis ordered by importance (e.g. using PCA or NMF—PCA is simpler though I know NMF is what OpenAI Clarity usually uses when they're trying to extract interpretable basis elements from neural network activations) and then seeing what the linear model looks like in that basis. You could even just do something like what you're saying and find some sort of basis ordered by the frequency of the tokens that each basis element corresponds to (though I'm not sure exactly what the right way would be to generate such a basis).
I also thought of PCA/SVD, but I imagine matrix decompositions like these would be misleading here.
What matters here (I think) is not some basis of N_emb orthogonal vectors in embedding space, but some much larger set of ~exp(N_emb) almost orthogonal vectors. We only have 1600 degrees of freedom to tune, but they're continuous degrees of freedom, and this lets us express >>1600 distinct vectors in vocab space as long as we accept some small amount of reconstruction error.
I expect GPT and many other neural models are effectively working in such a space of nearly orthogonal vectors, and picking/combining elements of it. A decomposition into orthogonal vectors won't really illuminate this. I wish I knew more about this topic -- are there standard techniques?
You might want to look into NMF, which, unlike PCA/SVD, doesn't aim to create an orthogonal projection. It works well for interpretability because its components cannot cancel each other out, which makes its features more intuitive to reason about. I think it is essentially what you want, although I don't think it will allow you to find directly the 'larger set of almost orthogonal vectors' you're looking for.
Unroll the sampling process: hook up all the individual GPT instances into a single long model, bypass the discretizing/embedding layers to make it differentiable end-to-end, and do gradient ascent to find the sequence which maximizes likelihood conditional on the fixed input.
Interesting, but not (I think?) the direction I was headed in.
I was thinking more about the way the model seems to be managing a tradeoff between preserving the representation of token i and producing the representation of token i+1.
The depth-wise continuity imposed by weight decay means late layers are representing something close to the final output -- in late layers the model is roughly looking at its own guesses, even if they were wrong, which seems suboptimal.
Consider this scenario:
My sampling idea was something like "let's replace (or interpolate) late activations with embeddings of the actual next token, so the model can see what really happened, even when its probability was low." (This is for sampling specifically because it'd be too slow in training, where you want to process a whole window at once with matrix operations; sampling has to be a loop anyway, so there's no cost to adding stuff that only works as a loop.)
But, thinking about it more, the model clearly can perform well in scenarios like the above, e.g. my plasma example and also many other cases naturally arising in language which GPT handles well.
I have no idea how it does it -- indeed the connection structure feels weirdly adverse to such operations -- but apparently it does. So it's probably premature to assume it can't do this well, and attempt to "help it out" with extra tricks.
How far away is this from being implementable?
It doesn't sound hard at all. The things Gwern is describing are the same sort of thing that people do for interpretability where they, eg, find an image that maximizes the probability of the network predicting a target class.
Of course, you need access to the model, so only OpenAI could do it for GPT-3 right now.
Doing it with GPT-3 would be quite challenging just for compute requirements like RAM. You'd want to test this out on GPT-2-117M first, definitely. If the approach works at all, it should work well for the smallest models too.
Related layer visualizations: "Looking for Grammar in All The Right Places".
Hey I'm not finished reading this yet but I noticed something off about what you said.
This isn't quite right. They don't multiply by W's transpose at the end. Rather there is a completely new matrix at the end, whose shape is the same as the transpose of W.
You can see this in huggingface's code for GPT2. In the class GPT2LMHeadModel the final matrix multiplication is performed by the matrix called "lm_head", whereas the matrix you call W, which is used to map 50,257 dimensional vectors into 1600 dimensional space, is called "wte" (found in the GPT2Model class). You can see from the code that wte has shape "Vocab size x Embed Size" while lm_head has shape "Embed Size x Vocab size", so lm_head does have the same shape as W transpose but doesn't have the same numbers.
Edit: I could be wrong here, though. Maybe lm_head was set to be equal to wte transpose? I'm looking through the GPT-2 paper but don't see anything like that mentioned.
Yes, this is the case in GPT-2. Perhaps the huggingface implementation supports making these two matrices different, but they are the same in the official GPT-2.
Edit: I think the reason this is obscured in the huggingface implementation is that they always distinguish the internal layers of a transformer from the "head" used to convert the final layer outputs into predictions. The intent is easy swapping between different "heads" with the same "body" beneath.
This forces their code to allow for heads that differ from the input embedding matrix, even when they implement models like GPT-2 where the official specification says they are the same.
Edit2: might as well say explicitly that I find the OpenAI tensorflow code much more readable than the huggingface code. This isn't a critique of the latter; it's trying to support every transformer out there in a unified framework. But if you only care about GPT, this introduces a lot of distracting abstraction.
Thanks for the info.
This was a great read, very informative.