Part 7 of 12 in the Engineer’s Interpretability Sequence.
Thanks to Neel Nanda. I used some very nicely-written code of his from here. And thanks to both Chris Olah and Neel Nanda for briefly discussing this challenge with me.
MI = “mechanistic interpretability”
In the last post, I argued that existing works in MI focus on solving problems that are too easy. Here, I am posing a challenge for mechanists that is still a toy problem but one that is quite a bit less convenient than studying a simple model or circuit implementing a trivial, known task. To the best of my knowledge:
Beating this challenge would, unlike prior MI work from the AI safety interpretability community, be the first example of mechanistically explaining a network’s solution to a task that was not cherry-picked by the researcher(s) doing the explaining.
Gaining a mechanistic understanding of the models in this challenge may be difficult, but it will probably be much less difficult than mechanistically interpreting highly intelligent systems in high stakes settings in the real world. So if an approach can’t solve the type of challenge posed here, it may not be very promising for doing much heavy lifting with AI safety work.
This post comes with a GitHub repository. Check it out here. The challenge is actually two challenges in one, and the basic idea is similar to some ideas presented in Lindner et al. (2023).
I made up a nonlinear labeling function that labels approximately half of all MNIST images as 0’s and the other half as 1’s. Then I trained a small CNN on these labels, and it got 96% testing accuracy. The challenge is to use MI tools on the network to recover that labeling function.
Hint 1: The labels are binary.
Hint 2: The network gets 95.58% accuracy on the test set.
Hint 3: The labeling function can be described in words in one sentence.
Hint 4: This image may be helpful.
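To make the setup of this first challenge concrete, here is a minimal sketch of what "a nonlinear labeling function that labels approximately half of all images as 1's" could look like. Everything in it is an assumption for illustration: `placeholder_score` is a made-up rule and is not the labeling function used in the challenge, the median threshold is just one convenient way to get a balanced split, and random arrays stand in for MNIST digits so the block runs without any downloads.

```python
import random
import statistics

SIDE = 28  # MNIST images are 28x28 pixels


def placeholder_score(image):
    """Hypothetical nonlinear score; NOT the real labeling function.

    Multiplies the total brightness of the top half by that of the bottom half.
    `image` is a SIDE x SIDE list of floats in [0, 1].
    """
    top = sum(sum(row) for row in image[: SIDE // 2])
    bottom = sum(sum(row) for row in image[SIDE // 2:])
    return top * bottom


def make_binary_labels(images):
    """Label an image 1 if its score is above the median score, else 0."""
    scores = [placeholder_score(img) for img in images]
    threshold = statistics.median(scores)
    return [1 if s > threshold else 0 for s in scores]


# Random stand-in images (an assumption; the challenge uses real MNIST digits).
rng = random.Random(0)
fake_images = [
    [[rng.random() for _ in range(SIDE)] for _ in range(SIDE)]
    for _ in range(200)
]

labels = make_binary_labels(fake_images)
frac_ones = sum(labels) / len(labels)
print(f"fraction labeled 1: {frac_ones:.2f}")
```

A CNN trained on labels like these never sees the rule itself, only images and 0/1 targets, which is why recovering the rule from the weights is the hard part.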
I made up a labeling function that takes in two integers from 0 to 113 and outputs either a 0 or 1. Then, using a lot of code from Neel Nanda’s grokking work, I trained a 1-layer transformer on half of the data. It then got 97% accuracy on the test half. As before, the challenge is to use MI tools to recover the labeling function.
Hint 2: The network is trained on 50% of examples and gets 97.27% accuracy on the test half.
Hint 3: Here are the ground truth and learned labels. Notice how the mistakes the network makes are all near curvy parts of the decision boundary...
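The second setup, and the observation that mistakes cluster near the decision boundary, can be sketched in a few lines. This is a toy under stated assumptions: `placeholder_label` is a made-up rule with a curved boundary (a circle) and is not the real labeling function, `standin_network` is a slightly wrong copy of that rule rather than a trained transformer, and treating the upper bound 113 as inclusive is my reading of the text.

```python
import random

MAX_INT = 113  # upper bound from the post; treating it as inclusive is an assumption


def placeholder_label(x, y):
    """Hypothetical rule with a curved boundary; NOT the real labeling function."""
    return 1 if (x - 56) ** 2 + (y - 56) ** 2 < 45 ** 2 else 0


def standin_network(x, y):
    """Stand-in for the trained model: the same circle with a slightly smaller radius."""
    return 1 if (x - 56) ** 2 + (y - 56) ** 2 < 44 ** 2 else 0


# Enumerate every input pair, then split at random into train and test halves.
pairs = [(x, y) for x in range(MAX_INT + 1) for y in range(MAX_INT + 1)]
rng = random.Random(0)
rng.shuffle(pairs)
half = len(pairs) // 2
train_pairs, test_pairs = pairs[:half], pairs[half:]


def near_boundary(x, y, label_fn):
    """True if any of the 4 grid neighbours of (x, y) has a different label."""
    here = label_fn(x, y)
    for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
        nx, ny = x + dx, y + dy
        if 0 <= nx <= MAX_INT and 0 <= ny <= MAX_INT and label_fn(nx, ny) != here:
            return True
    return False


def boundary_mistake_fraction(predict_fn, points, label_fn):
    """Fraction of misclassified points that touch the true decision boundary."""
    mistakes = [(x, y) for x, y in points if predict_fn(x, y) != label_fn(x, y)]
    if not mistakes:
        return None  # no mistakes, so the fraction is undefined
    touching = sum(near_boundary(x, y, label_fn) for x, y in mistakes)
    return touching / len(mistakes)


accuracy = sum(
    standin_network(x, y) == placeholder_label(x, y) for x, y in test_pairs
) / len(test_pairs)
frac = boundary_mistake_fraction(standin_network, test_pairs, placeholder_label)
print(f"test accuracy: {accuracy:.4f}, mistakes touching the boundary: {frac:.2f}")
```

The same two helpers could be pointed at a candidate program and the real network's predictions to check how well a guessed labeling function agrees with the model, though agreement alone is not a mechanistic explanation.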
If you are the first person to send me the labeling function's program and a justified mechanistic explanation for either challenge (e.g. in the form of a colab notebook), I will sing your praises on my Twitter, and I would be happy to help you write a post about how you solved a problem I thought would be very difficult. Neel Nanda and I are also offering a cash prize. (Thanks to Neel for offering to contribute to the pool!) Neel will donate $250, and I will donate $500 to a high-impact charity of choice for the first person to solve each challenge. That makes the total donation prize pool $1,500.
For this challenge, I intentionally designed the labeling functions to not be overly simple. But I will not be too surprised if someone reverse-engineers them with MI tools, and if so, I will be extremely interested in how.
Neither of the models perfectly labels the validation set. One may object that this makes the problem unfairly difficult: if the network does not converge on the same behavior as the actual labeling function, how is one supposed to find that function inside the model? This is kind of the point, though. Real models that real engineers have to work with don’t tend to conveniently grok onto a simple, elegant, programmatic solution to a problem.
Here’s the GitHub link again.
Edit: I played around with the models, and it seems like the transformer only gets 99.7% train accuracy and 97.5% test accuracy!
The MNIST CNN was trained only on the 50k training examples.
I did not guarantee that the models had perfect train accuracy. I don't believe they did.
I think that any interpretability tools are allowed. Saliency maps are fine. But to 'win,' a submission needs to come with a mechanistic explanation and sufficient evidence for it. It is possible to beat this challenge by using non-mechanistic techniques to figure out the labeling function and then using that knowledge to find the mechanisms by which the networks classify the data.
At the end of the day, I (and possibly Neel) will have the final say in things.
Does the “ground truth” show the correct label function on 100% of the training and test data? If so, what’s the relevance of the transformer, which imperfectly implements the label function?
Yes, it does show the ground truth.
The goal of the challenge is not to find the labels, but to find the program that explains them using MI tools. In the post, when I say labeling "function", I really mean labeling "program" in this case.