Knowledge Neurons in Pretrained Transformers

by Evan Hubinger · 17th May 2021 · 2 min read · 4 comments


Tags: Transparency / Interpretability (ML & AI), AI
This is a linkpost for https://arxiv.org/abs/2104.08696

This is a link post for the Dai et al. paper “Knowledge Neurons in Pretrained Transformers” that was published on the arXiv last month. I think this paper is probably the most exciting machine learning paper I've read so far this year and I'd highly recommend others check it out as well.

To start with, here are some of the basic things that the paper demonstrates:

  • BERT has specific neurons, which the authors call “knowledge neurons,” in its feed-forward layers that store relational facts (e.g. “the capital of Azerbaijan is Baku”) such that controlling knowledge neuron activations up-weights/down-weights the correct answer in relational knowledge prompts (e.g. “Baku” in “the capital of Azerbaijan is <mask>”) even when the syntax of the prompt is changed—and the prompts that most activate the knowledge neuron all contain the relevant relational fact.
  • Knowledge neurons can reliably be identified via a well-justified integrated gradients attribution method (see also “Self-Attention Attribution”).
  • In general, the feed-forward layers of transformer models can be thought of as key-value stores that memorize relevant information, sometimes semantic and sometimes syntactic (see also “Transformer Feed-Forward Layers Are Key-Value Memories”), such that each knowledge neuron is composed of a “key” (the first layer, prior to the activation function) and a “value” (the second layer, after the activation function).
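To make the key-value reading and the attribution step concrete, here is a minimal numpy sketch on a toy, randomly initialized feed-forward layer (not BERT). It assumes the FFN output is decoded straight into vocabulary logits by a hypothetical embedding matrix `E`, which skips the rest of the network; the sizes, the tanh form of GELU, and the per-neuron midpoint Riemann sum are illustrative choices, not necessarily the paper's exact procedure.

```python
import numpy as np

# Toy sizes, far smaller than BERT-base (768 hidden units, 3072 FFN units).
rng = np.random.default_rng(0)
d_model, d_ff, vocab = 16, 64, 50
K = 0.25 * rng.normal(size=(d_ff, d_model))  # keys: row i of the first FFN matrix
V = 0.1 * rng.normal(size=(d_ff, d_model))   # values: row i of the second FFN matrix
E = rng.normal(size=(vocab, d_model))        # hypothetical output embedding used to score tokens

def gelu(z):
    # tanh approximation of GELU
    return 0.5 * z * (1 + np.tanh(np.sqrt(2 / np.pi) * (z + 0.044715 * z**3)))

def ffn(x):
    # Key-value reading: output = sum_i acts[i] * V[i], where acts[i] = gelu(K[i] . x)
    acts = gelu(K @ x)
    return acts @ V, acts

def answer_prob(acts, target):
    # Softmax over E @ (FFN output): a toy stand-in for "the rest of the model"
    logits = E @ (acts @ V)
    p = np.exp(logits - logits.max())
    p /= p.sum()
    return p[target], p

def integrated_gradients(acts, target, steps=50):
    # Attr_i = acts[i] * integral_0^1 dP/da_i evaluated at alpha * acts[i],
    # scaling one neuron at a time while all other activations stay fixed.
    EV = V @ E.T  # EV[i, k]: change in logit k per unit of activation on neuron i
    scores = np.zeros_like(acts)
    for i in range(len(acts)):
        total = 0.0
        for alpha in (np.arange(steps) + 0.5) / steps:  # midpoint Riemann sum
            scaled = acts.copy()
            scaled[i] = alpha * acts[i]
            p_t, p = answer_prob(scaled, target)
            total += p_t * (EV[i, target] - p @ EV[i])  # softmax gradient wrt acts[i]
        scores[i] = acts[i] * total / steps
    return scores

x = rng.normal(size=d_model)
out, acts = ffn(x)
target = int(np.argmax(E @ out))
attr = integrated_gradients(acts, target)
candidates = np.argsort(attr)[::-1][:5]  # top-attributed neurons: toy "knowledge neuron" candidates
```

With a zero baseline, each neuron's score approximates how much the answer probability drops when that neuron alone is zeroed, which is one way to see why such an attribution tends to pick out the same neurons a direct ablation search would.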

The paper's key results—at least as I see it, however—are the following:

  • Taking knowledge neurons that encode “the $r$ of $h$ is $t$” and literally just adding $\mathbf{t}' - \mathbf{t}$ to the value neurons (where $\mathbf{t}, \mathbf{t}'$ are just the embeddings of $t, t'$) actually changes the knowledge encoded in the network such that it now responds to “the $r$ of $h$ is <mask>” (and other semantically equivalent prompts) with $t'$ instead of $t$.
  • For a given relation (e.g. “place of birth”), if all knowledge neurons encoding that relation (which ends up being a relatively small number, e.g. 5 - 30) have their value neurons effectively erased, the model loses the ability to predict the majority of relational knowledge involving that relation (e.g. 40 - 60%).
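Both operations in these bullets are simple arithmetic on value vectors. Writing $t$ for the stored answer and $t'$ for the replacement, here is a minimal numpy sketch under loud assumptions: a toy FFN whose output is decoded directly by an orthonormal embedding matrix `E`, hypothetical token ids, and knowledge neurons assumed to be already identified. It illustrates the mechanics only; it says nothing about whether the edit works reliably in a real BERT.

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, d_ff, vocab = 64, 20, 50
# Orthonormal toy embeddings (an assumption; real token embeddings overlap).
Q, _ = np.linalg.qr(rng.normal(size=(d_model, d_model)))
E = Q[:vocab]

t, t_new = 7, 42               # hypothetical ids for the stored answer t and its replacement t'
knowledge_neurons = [3, 11]    # hypothetical neurons, assumed already found by attribution

V = 0.01 * rng.normal(size=(d_ff, d_model))
V[knowledge_neurons] += E[t]   # these value vectors "write" t ...
acts = np.full(d_ff, 0.1)
acts[knowledge_neurons] = 3.0  # ... and fire strongly on the relational prompt

def predict(V):
    # Decode the FFN output directly with E (toy stand-in for the rest of the model)
    return int(np.argmax(E @ (acts @ V)))

def edit_fact(V, neurons, old, new):
    # Add (embedding of t') - (embedding of t) to each knowledge neuron's value vector
    V = V.copy()
    V[neurons] += E[new] - E[old]
    return V

def erase(V, neurons, replacement=None):
    # Overwrite the value vectors: zeros by default, or some other vector
    # (e.g. a token embedding, standing in for an [UNK]-style replacement)
    V = V.copy()
    V[neurons] = 0.0 if replacement is None else replacement
    return V

print(predict(V))                                          # 7
print(predict(edit_fact(V, knowledge_neurons, t, t_new)))  # 42
```

In this toy, `erase(V, knowledge_neurons, replacement=E[0])` makes the model predict token 0 on the prompt rather than merely forgetting $t$, which is one simple way an embedding-based overwrite could distort outputs instead of cleanly removing a fact.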

I think that particularly the first of these two results is pretty mind-blowing, in that it demonstrates an extremely simple and straightforward procedure for directly modifying the learned knowledge of transformer-based language models. That being said, it's the second result that probably has the most concrete safety applications—if it can actually be scaled up to remove all the relevant knowledge—since something like that could eventually be used to ensure that a microscope AI isn't modeling humans or ensure that an agent is myopic in the sense that it isn't modeling the future.

Furthermore, the specific procedure used suggests that transformer-based language models might be a lot less inscrutable than previously thought: if we can really just think about the feed-forward layers as encoding simple key-value knowledge pairs literally in the language of the original embedding layer (as I think is also independently suggested by “interpreting GPT: the logit lens”), that provides an extremely useful and structured picture of how transformer-based language models work internally.


Comments

I'm inclined to be more skeptical of these results.

I agree that this paper demonstrates that it's possible to interfere with a small number of neurons in order to mess up retrieval of a particular fact (roughly 6 out of the 40k mlp neurons if I understand correctly), which definitely tells you something about what the model is doing.

But beyond that I think the inferences are dicier:

  • Knowledge neurons don't seem to include all of the model's knowledge about a given question. Cutting them out only decreases the probability on the correct answer by 40%. This makes it sound like accuracy is still quite high for these relational facts even after removing the 6 knowledge neurons, and certainly the model still knows a lot about them. (It's hard to tell because this is a weird way to report deterioration of the model's knowledge and I actually don't know exactly how they defined it.) And of course the correct operation of a model can be fragile, so even if knocking out 6 neurons did mess up knowledge about topic X that doesn't really show you've localized knowledge about X.
  • I don't think there's evidence that these knowledge neurons don't do a bunch of other stuff. After removing about 0.02% of neurons they found that the mean probability on other correct answers decreased by 0.4%. They describe this as "almost unchanged" but it seems like it's larger than I'd expect for a model trained with dropout for knocking out random neurons (if you extrapolate that to knocking out 10% of the mlp neurons, as done during training, you'd have reduced the correct probability by 50x, whereas the model should still operate OK with 10% dropout). So I think this shows that the knowledge neurons are more important for the knowledge in question than for typical inputs, rather than e.g. the model being brittle so that knocking out 6 neurons messes up everything, but definitely doesn't tell us much about whether the knowledge neurons do only this fact.
  • For groups of neurons, they instead report accuracy and find it falls by about half. Again, this does show that they've found 20 neurons that are particularly important for e.g. messing up facts about capitals (though clearly not containing all the info about capitals). I agree this is more impressive than the individual knowledge neurons, but man there are lots of ways to get this kind of result even if the mechanical story of what the neuron is doing is 100% wrong.
  • Looking at that again, it seems potentially relevant that instead of zeroing those neurons they added the embedding of the [UNK] token. I don't know what the effect of that is (or why they changed), you could easily imagine it kind of clobbering the model. I think I'd feel more comfortable if they had showed that the model still worked OK. But at this point they switch to a different way of measuring performance on other prompts (perplexity). It's not clear whether they mean perplexity just on the masked token or on all the tokens, and I haven't thought about these numbers in detail, but I think they might imply that the [UNK] replacement did quite a lot of damage.
  • I doubt that the attribution method matters much for identifying these neurons (and I don't think any evidence is given that it does). I suspect anything reasonable would work fine. Moreover, I'm not sure it's more impressive to get this result using this attribution method rather than by directly searching for neurons whose removal causes bad behavior (which would probably produce even better results). It seems like the attribution method can just act as a slightly noisy way of doing that.
  • When you say "In general, the feed-forward layers can be thought of..." I agree that it's possible to think of them that way, but don't see much evidence that this is a good way to think about how the feed-forward layers work, or even how these neurons work. A priori if a network did work this way, it's unclear why individual neurons would correspond to individual lookups rather than using a distributed representation (and they probably wouldn't, given sparsity---that's a crazy inefficient thing to do and if anything seems harder for SGD to learn) so I'm not sure that this perspective even helps explain the observation that a small number of neurons can have a big effect on particular prompts.
  • I think the evidence about replacement is kind of weak. In general, adding t - t' to the activations at any layer should increase the probability the model says t and decrease the probability it says t'. So adding it to the weights of any activated neuron should tend to have this effect, if only in a super crude way (as well as probably messing up a bunch of other stuff). But they don't give any evidence that the transformation had a reliable effect, or that it didn't mess up other stuff, or that they couldn't have a similar effect by targeting other neurons.
  • Actually looking at the replacement stuff in detail it seems even weaker than that. Unless I'm missing something it looks like they only present 3 cherry-picked examples with no quantitative evaluation at all? It's possible that they just didn't care about exploring this effect experimentally, but I'd guess that they tried some simple stuff and found the effect to be super brittle and so didn't report it. And in the cases they report, they are changing the model from remembering an incorrect fact to a correct one---that seems important because probably the model put significant probability on the correct thing already.
  • When you say "the prompts that most activate the knowledge neuron all contain the relevant relational fact." I think the more precise statement is "Amongst 20 random prompts that mention both Azerbaijan and Baku, the ones that most activate the knowledge neuron also include the word 'capital'." I think this is relatively unsurprising---if you select a neuron to respond to various permutations of (Azerbaijan, Baku, capital) then it makes sense that removing one of the three words would decrease the activation of the neuron relative to having all three, regardless of whether the neuron is a knowledge neuron or anything else that happens to be activated on those prompts.
  • I think the fact that amplifying knowledge neurons helps accuracy doesn't really help the story, and if anything slightly undercuts it. This is exactly what you'd expect if the neurons just happened to have a big gradient on those inputs. But if the model is actually doing knowledge retrieval, then it seems like it would be somewhat less likely you could make things so much better by just increasing the magnitude of the values.
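Paul's point about replacement can be checked exactly in the simplest case, where a vector is read out linearly by an unembedding matrix. This sketch uses random toy matrices and hypothetical token ids, and it only covers that linear final-readout case; in the middle of a real network, later nonlinear layers intervene, so it is a crude version of the argument rather than a proof.

```python
import numpy as np

rng = np.random.default_rng(2)
d_model, vocab = 32, 100
E = rng.normal(size=(vocab, d_model))  # arbitrary, non-orthogonal unembedding (toy)
h = rng.normal(size=d_model)           # arbitrary hidden state being read out
t, t_prime = 5, 9                      # hypothetical token ids

def logit_gap(h):
    # How much the readout prefers t over t'
    logits = E @ h
    return logits[t] - logits[t_prime]

delta = E[t] - E[t_prime]
gain = logit_gap(h + delta) - logit_gap(h)

# Same thing through one FFN neuron: adding delta to the value vector of a neuron with
# activation a shifts the hidden state by a * delta, so the gap moves by a * ||delta||^2.
a = 0.7
gain_via_neuron = logit_gap(h + a * delta) - logit_gap(h)
print(np.isclose(gain, delta @ delta), np.isclose(gain_via_neuron, a * (delta @ delta)))  # True True
```

Because `delta @ delta` is positive whenever the two embeddings differ, any neuron with positive activation pushes the readout toward $t$ in this linear picture, whichever neuron it is; that is why a convincing demonstration needs controls on other neurons and on side effects.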

Knowledge neurons don't seem to include all of the model's knowledge about a given question. Cutting them out only decreases the probability on the correct answer by 40%.

Yeah, agreed—though I would still say that finding the first ~40% of where knowledge of a particular fact is stored counts as progress (though I'm not saying they have necessarily done that).

I don't think there's evidence that these knowledge neurons don't do a bunch of other stuff. After removing about 0.02% of neurons they found that the mean probability on other correct answers decreased by 0.4%. They describe this as "almost unchanged" but it seems like it's larger than I'd expect for a model trained with dropout for knocking out random neurons (if you extrapolate that to knocking out 10% of the mlp neurons, as done during training, you'd have reduced the correct probability by 50x, whereas the model should still operate OK with 10% dropout).

That's a good point—I didn't look super carefully at their number there, but I agree that looking more carefully it does seem rather large.

Looking at that again, it seems potentially relevant that instead of zeroing those neurons they added the embedding of the [UNK] token.

I also thought this was somewhat strange and am not sure what to make of it.

A priori if a network did work this way, it's unclear why individual neurons would correspond to individual lookups rather than using a distributed representation (and they probably wouldn't, given sparsity---that's a crazy inefficient thing to do and if anything seems harder for SGD to learn) so I'm not sure that this perspective even helps explain the observation that a small number of neurons can have a big effect on particular prompts.

I was also surprised that they used individual neurons rather than NMF factors or something—though the fact that it still worked while just using the neuron basis seems like more evidence that the effect is real rather than less.
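The NMF alternative mentioned here can be made concrete: instead of treating each neuron as a unit, factor a (prompts × neurons) activation matrix into a few nonnegative components and treat those as the units. This is a generic Lee-Seung multiplicative-update sketch on synthetic data, not something the paper did; the activation matrix is a made-up stand-in, and clipping to nonnegative values is an assumption needed because GELU outputs can be slightly negative.

```python
import numpy as np

def nmf(A, rank, iters=500, eps=1e-9, seed=0):
    # Lee-Seung multiplicative updates for min ||A - W H||_F with W, H >= 0
    rng = np.random.default_rng(seed)
    n, m = A.shape
    W = rng.random((n, rank)) + eps
    H = rng.random((rank, m)) + eps
    for _ in range(iters):
        H *= (W.T @ A) / (W.T @ W @ H + eps)
        W *= (A @ H.T) / (W @ H @ H.T + eps)
    return W, H

# Hypothetical activations: 200 prompts x 64 FFN neurons with low-rank structure.
rng = np.random.default_rng(3)
raw = rng.random((200, 4)) @ rng.random((4, 64))
acts = np.clip(raw, 0.0, None)  # real GELU activations would need this clipping

W, H = nmf(acts, rank=4)
# Each row of H is a candidate "unit" spread across many neurons;
# W records how strongly each prompt uses each unit.
```

If facts were stored in such distributed directions, ablating a row of `H` (rather than a single neuron) would be the natural analogue of the paper's neuron-level interventions.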

But they don't give any evidence that the transformation had a reliable effect, or that it didn't mess up other stuff, or that they couldn't have a similar effect by targeting other neurons.

Actually looking at the replacement stuff in detail it seems even weaker than that. Unless I'm missing something it looks like they only present 3 cherry-picked examples with no quantitative evaluation at all? It's possible that they just didn't care about exploring this effect experimentally, but I'd guess that they tried some simple stuff and found the effect to be super brittle and so didn't report it. And in the cases they report, they are changing the model from remembering an incorrect fact to a correct one---that seems important because probably the model put significant probability on the correct thing already.

Perhaps I'm too trusting—I agree that everything you're describing seems possible given just the evidence in the paper. All of this is testable, though, and suggests obvious future directions that seem worth exploring.

I think that particularly the first of these two results is pretty mind-blowing, in that it demonstrates an extremely simple and straightforward procedure for directly modifying the learned knowledge of transformer-based language models. That being said, it's the second result that probably has the most concrete safety applications—if it can actually be scaled up to remove all the relevant knowledge—since something like that could eventually be used to ensure that a microscope AI isn't modeling humans or ensure that an agent is myopic in the sense that it isn't modeling the future.

Despite agreeing that the results are impressive, I'm less optimistic than you are about this path to microscope AI and/or myopia. It would require an exhaustive listing of what we don't want the model to know (like human modeling or human manipulation) and a way of deleting that knowledge that doesn't break the whole network. The first requirement seems like a deal-breaker to me, and I'm not convinced this work actually provides much evidence that more advanced knowledge can be removed that way.

Furthermore, the specific procedure used suggests that transformer-based language models might be a lot less inscrutable than previously thought: if we can really just think about the feed-forward layers as encoding simple key-value knowledge pairs literally in the language of the original embedding layer (as I think is also independently suggested by “interpreting GPT: the logit lens”), that provides an extremely useful and structured picture of how transformer-based language models work internally.

Here too, I agree with the sentiment, but I'm not convinced that this is the whole story. This looks like how structured facts are learned, but I see no way as of now to generate the range of stuff GPT-3 and other LMs can do from just key-value knowledge pairs.

Thanks for the link. This has been on my reading list for a little bit and your recco tipped me over.

Mostly I agree with Paul's concerns about this paper. 

However, I did find the "Transformer Feed-Forward Layers Are Key-Value Memories" paper they reference more interesting -- it's more mechanistic, and their results are pretty encouraging. I would personally highlight that one more, as it's IMO stronger evidence for the hypothesis, although not conclusive by any means.

Some experiments they show:

  • Top-k activations of individual 'keys' do seem like coherent patterns in prefixes, and as we move up the layers, these patterns become less shallow and more semantics-driven. (Granted, it's not clear how good the methodology is there: to qualify as a pattern, something only needs to occur in 3 of the top-25 prefixes, and there are 3.6 patterns per key on average. But this is curious enough to keep looking into.)
  • The 'value' distributions corresponding to the keys are in fact somewhat predictive of the actual next word for those top-k prefixes, and exhibit a kind of 'calibration': while the distributions themselves aren't actually calibrated, they are more correct when they assign a higher probability.
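The key-side methodology in the first bullet can be sketched as: rank prefixes by how strongly they activate one key, keep the top 25, and count a candidate pattern if it holds for at least 3 of them. The bag-of-words prefix representations, the vocabulary, and the hand-written predicate (standing in for whatever judgment is used to spot a pattern) are all toy assumptions, not the paper's setup.

```python
import numpy as np

def top_k_prefixes(prefix_reprs, key, k=25):
    # Rank prefixes by how strongly they drive one key (dot product with the key vector)
    scores = prefix_reprs @ key
    return np.argsort(scores)[::-1][:k]

def qualifies_as_pattern(prefixes, top_idx, predicate, min_hits=3):
    # A candidate pattern counts if it holds for at least min_hits of the top-k prefixes
    hits = sum(bool(predicate(prefixes[i])) for i in top_idx)
    return hits >= min_hits

# Toy data: bag-of-words prefix representations over orthonormal word vectors
rng = np.random.default_rng(4)
words = ["the", "capital", "of", "france", "is", "river", "city", "in", "a", "big"]
Q, _ = np.linalg.qr(rng.normal(size=(16, 16)))
emb = {w: Q[i] for i, w in enumerate(words)}
prefixes = [" ".join(rng.choice(words, size=5)) for _ in range(300)]
reprs = np.array([sum(emb[w] for w in p.split()) for p in prefixes])

top = top_k_prefixes(reprs, key=emb["capital"])
print(qualifies_as_pattern(prefixes, top, lambda p: "capital" in p.split()))
```

The `min_hits` threshold is the knob the caveat above is about: with 3 of 25, a predicate that is merely common among prefixes could also clear the bar.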

I also find it very intriguing that you can just decode the value distributions using the embedding matrix, à la the logit lens.
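Decoding a value vector this way takes only a few lines: treat it as if it were a final hidden state, score it against every token embedding, and softmax. Toy orthonormal embeddings and a made-up vocabulary stand in for a real model here, and a real model typically applies extra steps (such as a final layer norm) before the unembedding, which this sketch skips.

```python
import numpy as np

def value_distribution(v, E):
    # Logit-lens-style reading: score every token against the value vector, then softmax
    logits = E @ v
    p = np.exp(logits - logits.max())
    return p / p.sum()

def top_tokens(v, E, vocab_words, n=3):
    p = value_distribution(v, E)
    order = np.argsort(p)[::-1][:n]
    return [(vocab_words[i], float(p[i])) for i in order]

# Toy vocabulary with orthonormal embeddings (hypothetical, not a real model's matrix)
rng = np.random.default_rng(5)
vocab_words = ["the", "Baku", "is", "Paris", "capital", "of"]
Q, _ = np.linalg.qr(rng.normal(size=(8, 8)))
E = Q[:len(vocab_words)]

v = 2.0 * E[1] + 0.5 * E[3]  # a value vector that mostly "writes" Baku, a little Paris
print(top_tokens(v, E, vocab_words, n=2))  # Baku ranks first, Paris second
```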