We started out with the question: How does GPT-2 know when to use the word "an" over "a"? The choice depends on whether the word that comes after starts with a vowel or not, but GPT-2 can only output one word at a time.
We still don’t have a full answer, but we did find a single MLP neuron in GPT-2 Large that is crucial for predicting the token " an". And we also found that the weights of this neuron correspond with the embedding of the " an" token, which led us to find other neurons that predict a specific token.
It was surprisingly hard to think of a prompt where GPT-2 would output “ an” (the leading space is part of the token) as the top prediction. Eventually we gave up on GPT-2_small and switched to GPT-2_large. As we’ll see later, even GPT-2_large systematically under-predicts the token “ an”. This may be because smaller language models lean on the higher frequency of " a" to make a best guess. The prompt we finally found that gave a high (64%) probability for “ an” was:
“I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked”
The first sentence was necessary to push the model towards an indefinite article — without it the model would make other predictions such as “[picked] up”.
Before we proceed, here’s a quick overview of the transformer architecture. Each attention block and MLP takes input and adds output to the residual stream.
Using the logit lens technique, we took the logits from the residual stream between each layer and plotted the difference between logit(‘ an’) and logit(‘ a’). We found a big spike after Layer 31’s MLP.
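A minimal sketch of the logit-lens comparison, assuming toy two-dimensional residual snapshots and a made-up unembedding table (none of these numbers come from GPT-2). It also skips the final layer norm that a real logit lens would normally apply before unembedding.

```python
def dot(u, v):
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def logit_lens_diff(residuals, unembed, tok_a, tok_b):
    """For each residual snapshot, return logit(tok_a) - logit(tok_b)."""
    return [dot(r, unembed[tok_a]) - dot(r, unembed[tok_b]) for r in residuals]

# Hypothetical unembedding vectors and residual snapshots (illustrative only).
unembed = {" an": [1.0, 0.0], " a": [0.0, 1.0]}
residuals = [
    [0.0, 1.0],  # early snapshot: favours " a"
    [0.5, 1.0],
    [3.0, 1.0],  # later snapshot: favours " an"
]

diffs = logit_lens_diff(residuals, unembed, " an", " a")
print(diffs)
```

A sudden jump between two consecutive entries of `diffs` is the kind of spike described above.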
Activation patching is a technique introduced by Meng et al. (2022) to analyze the significance of a single layer in a transformer. First, we saved the activation of each layer when running the original prompt through the model (the “clean activation”).
We then ran a corrupted prompt through the model:
“I climbed up the pear tree and picked a pear. I climbed up the lemon tree and picked”
By replacing the word "apple" with "lemon", we induce the model to predict the token " a" instead of " an". With the model predicting " a" over " an", we can replace a layer’s corrupted activation with its clean activation to see how much the model shifts towards the " an" token, which indicates that layer’s significance to predicting " an". We repeat this process over all the layers of the model.
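The patching loop can be sketched on a toy model. This assumes a scalar "residual stream" and three hypothetical layers written as plain functions; it only illustrates the cache-then-overwrite pattern, not GPT-2 itself.

```python
def run(layers, x, patch=None):
    """Run a toy residual model and return (final_residual, layer_outputs).

    layers: functions mapping the current residual to that layer's output.
    patch:  optional (layer_index, value) pair that overwrites one layer's output.
    """
    resid, cache = x, []
    for i, layer in enumerate(layers):
        out = layer(resid)
        if patch is not None and patch[0] == i:
            out = patch[1]  # swap in the saved clean activation
        cache.append(out)
        resid = resid + out  # each layer adds its output to the residual
    return resid, cache

# Hypothetical layers and inputs (illustrative only).
layers = [lambda r: 0.5 * r, lambda r: 2.0 * r, lambda r: -0.25 * r]
clean_out, clean_cache = run(layers, 1.0)   # "clean" run, activations saved
corrupt_out, _ = run(layers, -1.0)          # "corrupted" run

# Patch each layer in turn with its clean activation during the corrupted run.
patched = [run(layers, -1.0, patch=(i, clean_cache[i]))[0] for i in range(len(layers))]
print(clean_out, corrupt_out, patched)
```

In this toy, patching the middle layer moves the corrupted output furthest towards the clean one, which is the kind of signal the layer-wise plots are looking for.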
We're mostly going to ignore attention for the rest of this post, but these results indicate that Layer 26 is where " picked" starts thinking a lot about " apple", which is obviously required to predict " an".
Note: the scale on these patching graphs is the relative logit difference recovery (i.e. "what proportion of logit(" an") - logit(" a") in the clean prompt did this patch recover?").
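One common way to compute such a recovery score is sketched below. The normalization (gap between the clean and corrupted runs in the denominator) is an assumption, since the exact formula is not reproduced in this text.

```python
def logit_diff_recovery(patched, corrupted, clean):
    """Fraction of the clean-vs-corrupted logit-difference gap restored by a patch.

    Each argument is a logit difference, logit(" an") - logit(" a"),
    measured on the patched, corrupted and clean runs respectively.
    Assumed normalization: 0.0 means no recovery, 1.0 means full recovery.
    """
    gap = clean - corrupted
    if gap == 0:
        raise ValueError("clean and corrupted logit differences are identical")
    return (patched - corrupted) / gap

# Hypothetical logit differences (illustrative only).
print(logit_diff_recovery(patched=0.0, corrupted=-2.0, clean=2.0))
```

Under this definition a patch can score above 1.0 if it overshoots the clean run, or below 0.0 if it makes things worse.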
The two MLP layers that stand out are Layer 0 and Layer 31. We already know that Layer 0’s MLP is generally important for GPT-2 to function (although we're not sure why attention in Layer 0 is important). The effect of Layer 31 is more interesting. Our results suggest that Layer 31’s MLP plays a significant role in predicting the " an" token. (See this comment if you're confused how this result fits with the logit lens above.)
Activation patching has been used to investigate transformers by the layer, but can we push this technique further and apply it to individual neurons? Since each MLP in a transformer only has one hidden layer, each neuron’s activation does not affect any other neuron in the MLP. So we should be able to patch individual neurons, because they are independent from each other in the same sense that the attention heads in a single layer are independent from each other.
We run neuron-wise activation patching for Layer 31’s MLP in a similar fashion to the layer-wise patching above. We reintroduce the clean activation of each neuron in the MLP when running the corrupted prompt through the model, and look at how much restoring each neuron contributes to the logit difference between " a" and " an".
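A toy version of neuron-wise patching, assuming a three-neuron MLP with ReLU (GPT-2 uses GELU) and made-up weights. `w_out_diff` is a hypothetical shorthand for each neuron's direct effect on logit(" an") - logit(" a").

```python
def hidden_acts(w_in, x):
    """ReLU hidden activations; each neuron depends only on the MLP input."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in w_in]

def logit_diff(acts, w_out_diff):
    """Direct contribution of the hidden activations to the logit difference."""
    return sum(a * w for a, w in zip(acts, w_out_diff))

# Hypothetical weights and inputs (illustrative only).
w_in = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_out_diff = [4.0, 0.5, -1.0]
clean_acts = hidden_acts(w_in, [2.0, 0.0])
corrupt_acts = hidden_acts(w_in, [0.0, 2.0])
clean_diff = logit_diff(clean_acts, w_out_diff)

patched_diffs = []
for i in range(len(corrupt_acts)):
    acts = list(corrupt_acts)
    acts[i] = clean_acts[i]  # restore only neuron i's clean activation
    patched_diffs.append(logit_diff(acts, w_out_diff))

best_neuron = max(range(len(patched_diffs)), key=lambda i: patched_diffs[i])
print(clean_diff, patched_diffs, best_neuron)
```

In this toy, restoring neuron 0 alone gives a larger logit difference than restoring the whole layer, because the other neurons' clean activations pull in the opposite direction. That is one simple way a single neuron can outperform its layer.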
We see that patching Neuron 892 recovers 50% of the clean prompt's logit difference, while patching the whole layer actually does worse at 49%.
Neuroscope is an online tool that shows the top activating examples in a large dataset for each neuron in GPT-2. When we look at Layer 31 Neuron 892, we see that the neuron maximally activates on tokens where the subsequent token is " an".
But Neuroscope only shows us the top 20 most activating examples. Would there be a correlation for a wider range of activations?
To check this, we ran the pile-10k dataset through the model. This is a diverse set of about 10 million tokens taken from The Pile, split into prompts of 1,024 tokens. We plotted the proportion of " an" predictions across the range of neuron activations:
We see that the " an" predictions increase as the neuron’s activation increases, to the point where " an" is always the top prediction. The trend is somewhat noisy, which suggests that there might be other mechanisms in the model that also contribute towards the " an" prediction. Or maybe when the " an" logit increases, other logits increase at the same time.
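The binning behind a plot like this can be sketched as follows. The records and bin width are made up, and `proportion_by_bin` is a hypothetical helper, not code from the original analysis.

```python
import math

def proportion_by_bin(records, bin_width):
    """Group (activation, is_top_prediction) pairs into activation bins.

    Returns {bin_lower_edge: fraction of records in that bin where the
    token of interest was the model's top prediction}.
    """
    counts = {}
    for act, hit in records:
        edge = math.floor(act / bin_width) * bin_width
        total, hits = counts.get(edge, (0, 0))
        counts[edge] = (total + 1, hits + int(hit))
    return {edge: hits / total for edge, (total, hits) in sorted(counts.items())}

# Hypothetical (neuron activation, was " an" the top prediction?) records.
records = [(0.2, False), (0.7, False), (1.1, False),
           (1.6, True), (2.3, True), (2.9, True)]
proportions = proportion_by_bin(records, bin_width=1.0)
print(proportions)
```

With real data, bins that contain very few tokens give noisy proportions, so it may help to report the count per bin alongside the fraction.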
Note that the model only predicted " an" 1,500 times, even though it actually occurred 12,000 times in the dataset. No wonder it was so hard to find a good prompt!
How does the neuron influence the model’s output? Well, the neuron’s output weights have a high dot product with the embedding for the token “ an”. We call this the congruence of the neuron with the token. Compared to other random tokens like " any" and " had", the neuron’s congruence with " an" is very high:
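As a small illustration, congruence can be computed as a plain dot product. The three-dimensional vectors below are invented for the example and are not GPT-2 weights.

```python
def congruence(neuron_out_weights, token_embedding):
    """Dot product of a neuron's output weights with a token's embedding."""
    return sum(w * e for w, e in zip(neuron_out_weights, token_embedding))

# Hypothetical neuron output weights and token embeddings (illustrative only).
neuron = [2.0, 0.0, 1.0]
embeddings = {
    " an": [3.0, 0.0, 1.0],
    " any": [0.0, 1.0, 0.0],
    " had": [-1.0, 2.0, 1.0],
}

scores = {tok: congruence(neuron, emb) for tok, emb in embeddings.items()}
ranked = sorted(scores, key=scores.get, reverse=True)  # most congruent first
print(scores, ranked)
```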
In fact, when we calculate the neuron’s congruence with all of the tokens, there are a few clear outliers:
It seems like the neuron basically adds the embedding of “ an” to the residual stream, which increases the output probability for “ an” since the unembedding step consists of taking the dot product of the final residual with each token’s unembedding vector.
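A toy check of this reasoning: if the final layer norm is ignored (an assumption made here for simplicity), the change in a token's logit from one neuron is its activation times its congruence with that token. All numbers are made up.

```python
def dot(u, v):
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def logits(resid, unembed):
    """Unembed a residual vector into one logit per token."""
    return {tok: dot(resid, vec) for tok, vec in unembed.items()}

# Hypothetical values (illustrative only).
unembed = {" an": [1.0, 0.0], " a": [0.0, 1.0]}
resid = [1.0, 1.0]
neuron_out = [2.0, 0.0]   # neuron's output weights
activation = 1.5          # neuron's activation on this token

before = logits(resid, unembed)
# The neuron writes activation * output weights into the residual stream.
new_resid = [r + activation * w for r, w in zip(resid, neuron_out)]
after = logits(new_resid, unembed)

delta_an = after[" an"] - before[" an"]
print(delta_an, activation * dot(neuron_out, unembed[" an"]))
```

Both printed values agree, while the logit for " a" is unchanged because the neuron's output weights are orthogonal to that token's vector in this toy.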
Are there other neurons that are also congruent to “ an”? To find out, we plotted the congruence of all neurons with the " an" token:
Our neuron is way above the rest, but there are some other neurons with a fairly high congruence. These other neurons could be part of the reason why the correlation between the " an" neuron’s activation and the prediction of the " an" token isn’t perfect: there may be prompts where " an" is predicted, but the model uses these other neurons to do it.
If this is the case, could we use congruence to find a neuron that is perfectly correlated with a single token prediction?
We can try to find a neuron that is associated with a specific token by running the following search:
With this search, we wanted to find neurons that were uniquely responsible for a token. Our conjecture was that these neurons' activations would be more correlated with their tokens' prediction, since any prediction of that token would “rely” on that neuron.
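The steps of the search are not reproduced in this text. The sketch below is one plausible reading, based on the later mention of the "top 2 neuron difference for each token": for each token, take the most congruent neuron and score it by its margin over the runner-up. The function name, neuron labels and numbers are all hypothetical.

```python
def top2_margin_search(congruence):
    """Rank tokens by how far their most congruent neuron beats the second.

    congruence: {token: {neuron_id: congruence value}}
    Returns [(margin, token, best_neuron_id)] sorted by margin, largest first.
    """
    results = []
    for token, by_neuron in congruence.items():
        ranked = sorted(by_neuron.items(), key=lambda kv: kv[1], reverse=True)
        if len(ranked) < 2:
            continue  # a margin needs at least two neurons
        (best_id, best_val), (_, second_val) = ranked[0], ranked[1]
        results.append((best_val - second_val, token, best_id))
    return sorted(results, reverse=True)

# Hypothetical congruence values (illustrative only).
toy = {
    " though": {"n0": 5.0, "n1": 1.0, "n2": 0.5},
    " the": {"n0": 2.0, "n1": 1.5, "n2": 1.0},
}
hits = top2_margin_search(toy)
print(hits)
```

Note that this only checks that a token depends on one neuron. It does not check whether that neuron is also congruent with other tokens.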
Let’s try running the “ though” neuron — Layer 28 Neuron 1921 — through the dataset and see whether we get a cleaner graph.
Woah, that is much messier than the graph for the " an" neuron. What is going on?
Looking at Neuroscope’s data for the neuron reveals that it predicts both the tokens “ though” and “ however”. This complicates things — it seems that this neuron is correlated with a group of semantically similar tokens (conjunctive adverbs).
When we calculate the neuron’s congruence for all tokens, we find that the same tokens pop up as outliers:
In our large dataset correlation graph above, instances where the neuron activates and " however" is predicted over " though" would be counted as negative examples, since " though" was not the top prediction. This could also explain some of the noise in the " an" correlation, where the neuron is also congruent with "An", " An" and "an".
Can we find a simpler neuron to look at — preferably a neuron that only predicts for one token?
For a neuron to be ‘cleanly associated’ with a token, their congruence with each other should be mutually exclusive, meaning:
(Remember, 'congruence' is just our term for the dot product.)
Both criteria help to simplify the relationship between the neuron and its token. If a neuron’s congruence with a token is a representation of how much it contributes to that token’s prediction, the first criterion can be seen as making sure that only this neuron is responsible for predicting that token, while the second criterion can be seen as making sure that this neuron is responsible for predicting only that token.
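A minimal sketch of how the two criteria described in this paragraph might be combined into a single score. The scoring rule (the smaller of the two margins) is an assumption chosen for illustration and is not necessarily the metric used in the original search. The congruence table is made up.

```python
def mutually_exclusive_pairs(C):
    """Score neuron-token pairs that dominate in both directions.

    C: {neuron_id: {token: congruence}}, at least two neurons and two tokens.
    A pair is kept only if the neuron beats every other neuron on that token
    (criterion 1) and the token beats every other token on that neuron
    (criterion 2). Assumed score: the smaller of the two margins.
    """
    neurons = list(C)
    tokens = list(next(iter(C.values())))
    pairs = []
    for n in neurons:
        for t in tokens:
            neuron_margin = C[n][t] - max(C[m][t] for m in neurons if m != n)
            token_margin = C[n][t] - max(C[n][s] for s in tokens if s != t)
            if neuron_margin > 0 and token_margin > 0:
                pairs.append((min(neuron_margin, token_margin), n, t))
    return sorted(pairs, reverse=True)

# Hypothetical congruence table (illustrative only).
C = {
    "n0": {"x": 9.0, "y": 1.0, "z": 0.0},
    "n1": {"x": 2.0, "y": 6.0, "z": 5.0},
    "n2": {"x": 1.0, "y": 0.5, "z": 1.5},
}
pairs = mutually_exclusive_pairs(C)
print(pairs)
```

Here the pair ("n1", "y") scores low because "n1" is almost as congruent with "z" as with "y", which is the kind of multi-token neuron the second criterion is meant to filter out.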
Our search then is as follows:
For GPT-2_large, Layer 33 Neuron 4142 paired with "i" scores the highest on this metric. Looking at Neuroscope confirms the connection:
And when we plot the graph of top prediction proportion over activation for the top 5 highest scorers:
We do indeed see strong correlations for each pair!
Does the congruence of a neuron with a token actually measure the extent to which the neuron predicts that token? We don't know. There could be several reasons why even token-neuron pairs with high mutually exclusive congruence may not always correlate:
However, we’ve found that the token-neuron pairs with the top 5 highest mutually exclusive congruence do in fact have a strong correlation.
The code to reproduce our results can be found here.
This is a write-up and extension of our winning submission to Apart Research's Mechanistic Interpretability Hackathon. Thanks to the London EA Hub for letting us use their co-working space, Fazl Barez for his comments and Neel Nanda for his feedback and for creating Neuroscope, the pile-10k dataset and TransformerLens.
Neel Nanda’s take on MLP 0: "It's often observed on GPT-2 Small that MLP0 matters a lot, and that ablating it utterly destroys performance. My current best guess is that the first MLP layer is essentially acting as an extension of the embedding (for whatever reason) and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings. Within this frame, the first attention layer doesn't do much. In this framing, it makes sense that MLP0 matters on the second subject token, because that's the one position with a different input token! I'm not entirely sure why this happens, but I would guess that it's because the embedding and unembedding matrices in GPT-2 Small are the same. This is pretty unprincipled, as the tasks of embedding and unembedding tokens are not inverses, but this is common practice, and plausibly models want to dedicate some parameters to overcoming this. I only have suggestive evidence of this, and would love to see someone look into this properly!"
What else could it have done? It might have suppressed the logit for " a" which would have had the same impact on the logit difference. Or it might have added some completely different direction to the residual which would cause a neuron in a later layer to increase the " an" logit.
Note that the " though" neuron is congruent to a group of semantically similar tokens, while the " an" neuron is correlated with a group of syntactically similar tokens (eg. " an" and " Ancients").
Why does " an" have a cleaner correlation despite the other congruent tokens? We're not sure. One possible explanation is that "An" and " An" are simply much less common tokens so they make little impact on the correlation, while "an" has a significantly lower congruence with the neuron than the top 3.
In general, we expect that neurons found by only looking at the top 2 neuron difference for each token will not often have clean correlations with their respective tokens because these neurons may be congruent with multiple tokens.
When we look at the most congruent neuron for each token, we see some familiar troublemakers showing up with very high congruence:
At first, it looks like these 'forbidden tokens' are all associated with a 'forbidden neuron' (Layer 35 Neuron 3354) which they are all very congruent with. But actually if we plot the most congruent tokens of many other neurons we also see some of these weird tokens near the top. Our tentative hypothesis is that this has something to do with the hubness effect.
Neuroscope data wasn't available for this neuron, so we took the max activating dataset examples from the pile-10k dataset. Texts 1, 2, 3 are prompts 1755, 8528 and 6375 respectively.
Note that one of the top 5 tokens is "an", but this is different from " an" that we were talking about earlier, and it will rarely be used as the start of a word or a word on its own. Similarly the neuron with which it is paired, Layer 34 Neuron 4549, is not the " an" neuron named earlier.