I made a series of mech interp puzzles for my MATS scholars, which seemed well received, so I thought I'd share them more widely! I'll be posting a sequence of puzzles, approx one a week - these are real questions about models which I think are interesting, and where thinking about them should teach you some real principle about how to think about mech interp. Here's a short one to start:
Mech Interp Puzzle 1:
This is a histogram of the pairwise cosine similarities between token embeddings in GPT-Neo (the 125M parameter language model). Note that the mean is very high! (>0.9) Is this surprising? Why does this happen?
Bonus question: Here's the same histogram for GPT-2 Small, with a mean closer to 0.3. Is this surprising? What, if anything, can you infer from the fact that they differ?
!pip install transformer_lens plotly
import torch
from transformer_lens import HookedTransformer
import plotly.express as px
model = HookedTransformer.from_pretrained("gpt-neo-small")
subsample = torch.randperm(model.cfg.d_vocab)[:5000].to(model.cfg.device)
W_E = model.W_E[subsample] # Take a random subset of 5,000 tokens for memory reasons
W_E_normed = W_E / W_E.norm(dim=-1, keepdim=True) # [5000, d_model]
cosine_sims = W_E_normed @ W_E_normed.T # [5000, 5000]
px.histogram(cosine_sims.flatten().detach().cpu().numpy(), title="Pairwise cosine sims of embedding")
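For the bonus question it helps to run the identical computation on both models. Here is a minimal sketch of a reusable helper, under a couple of assumptions: the name `pairwise_cosine_sims` and its `n_sample` and `seed` parameters are made up for illustration, and unlike the snippet above it drops the diagonal (each row's similarity with itself, which is always 1) and counts each pair only once. With 5,000 rows this barely changes the histogram.

```python
import torch

def pairwise_cosine_sims(W, n_sample=5000, seed=0):
    """Cosine sims between distinct rows of W, computed on a random subsample of rows."""
    generator = torch.Generator().manual_seed(seed)
    n = min(n_sample, W.shape[0])
    idx = torch.randperm(W.shape[0], generator=generator)[:n].to(W.device)
    rows = W[idx]
    rows = rows / rows.norm(dim=-1, keepdim=True)  # unit-normalise each row
    sims = rows @ rows.T  # [n, n]
    # Keep only entries above the diagonal: each unordered pair once, no self-similarity.
    i, j = torch.triu_indices(n, n, offset=1, device=W.device)
    return sims[i, j]
```

To compare the two models, pass `model.W_E.detach()` for each of them and histogram the results as above. I believe `"gpt2"` is the TransformerLens name for GPT-2 Small, but check the model table if loading fails.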
Answer: (decode with rot13)
Gur zrna irpgbe bs TCG-Arb vf whfg ernyyl ovt - gur zbqny pbfvar fvz jvgu nal gbxra rzorq naq gur zrna vf nobhg 95% (frr orybj).
Gur pbaprcghny yrffba oruvaq guvf vf znxr fher lbhe mreb cbvag vf zrnavatshy. Zrgevpf yvxr pbfvar fvz naq qbg cebqhpgf vaureragyl cevivyrtr gur mreb cbvag bs lbhe qngn. Lbh jnag gb or pnershy gung vg'f zrnavatshy, naq gur mreb rzorqqvat inyhr vf abg vaureragyl zrnavatshy. (Mreb noyngvbaf ner nyfb bsgra abg cevapvcyrq!) V trarenyyl zrna prager zl qngn - guvf vf abg n havirefny ehyr, ohg gur uvtu-yriry cbvag vf gb or pnershy naq gubhtugshy nobhg jurer lbhe bevtva vf!
V qba'g unir n terng fgbel sbe jul, be jul bgure zbqryf ner fb qvssrerag, V onfvpnyyl whfg guvax bs vg nf n ovnf grez gung yvxryl freirf fbzr checbfr sbe pbagebyyvat gur YnlreAbez fpnyr. Vg'f onfvpnyyl n serr inevnoyr bgurejvfr, fvapr zbqryf pna nyfb serryl yrnea ovnfrf sbe rnpu ynlre ernqvat sebz gur erfvqhny fgernz (naq rnpu ynlre qverpgyl nqqf n ovnf gb gur erfvqhny fgernz naljnl). Abgnoyl, Arb hfrf nofbyhgr cbfvgvbany rzorqqvatf naq guvf ovnf fubhyq or shatvoyr jvgu gur nirentr bs gubfr gbb.
Please share your thoughts in the comments!
What background knowledge do you think this requires? If I know a bit about how ML and language models work in general, should I be able to reason this out from first principles (or by following a fairly obvious trail of "look up relevant terms and quickly get up to speed on the domain")? Or does it require some amount of pre-existing ML taste?
Also, do you have a rough sense of how long it took for MATS scholars?
Great questions, thanks!
Background: You don't need to know anything beyond "a language model is a stack of matrix multiplications and non-linearities. The input is a series of tokens (words and sub-words) which get converted to vectors by a massive lookup table called the embedding (the vectors are called token embeddings). These vectors have really high cosine sim in GPT-Neo".
Re how long it took for scholars, hmm, maybe an hour? Not sure, I expect it varied a ton. I gave this in their first or second week, I think.
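In case it helps to make the background above concrete, here is a toy sketch of "the embedding is a lookup table" and of what cosine sim between two token embeddings means. Everything in it is made up for illustration: the sizes are tiny, the table is random rather than trained, and it uses NumPy only to keep the example dependency-light.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes; real models are far larger.
d_vocab, d_model = 10, 4

# The embedding is a lookup table with one row (vector) per token in the vocabulary.
W_E = rng.normal(size=(d_vocab, d_model))

# An input is a sequence of token ids; embedding it is just row indexing.
tokens = np.array([3, 7, 3])
token_embeddings = W_E[tokens]  # shape [3, d_model]

def cosine_sim(a, b):
    """Dot product of a and b divided by the product of their norms."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Positions 0 and 2 hold the same token, positions 0 and 1 hold different tokens.
same_token = cosine_sim(token_embeddings[0], token_embeddings[2])
different_tokens = cosine_sim(token_embeddings[0], token_embeddings[1])
print(same_token, different_tokens)
```

The puzzle's histogram is this `cosine_sim` computed for every pair of distinct rows of the real model's table.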