Self-Review: After a while of being insecure about it, I'm now pretty fucking proud of this paper, and think it's one of the coolest pieces of research I've personally done. (I'm going to review both this post and the subsequent paper.) Though, as discussed below, I think people often overrate it.
Impact: The main impact IMO is proving that mechanistic interpretability is actually possible: that we can take a trained neural network and reverse-engineer non-trivial and unexpected algorithms from it. In particular, I think by focusing on grokking I (semi-accidentally) did a great job of picking a problem that people cared about for non-interp reasons, where mech interp was unusually easy (as it was a small model, on a clean algorithmic task), and where I was able to find real insights about grokking as a direct result of doing the mechanistic interpretability. Real models are fucking complicated (and even this model has some details we didn't fully understand), but I feel great about the field having something that's genuinely detailed, rigorous and successfully reverse-engineered, and this seems an important proof of concept. The other contributions are the specific algorithm I found and the specific insights about how and why grokking happens, but in my opinion these are much less interesting.
Field-Building: Another large source of impact is that this was a major win for mech interp field-building. This is hard to measure, but some data:
Is field-building good? It's plausible to me that, on the margin, less alignment effort should be going into mech interp, and more should go into other agendas. But I'm still excited about mech interp field-building, think it's a solid and high-value thing, and think field-building is often positive-sum: people often have strong personal fits to different areas, and many people are drawn to mech interp from non-alignment motivations. And though there's concern over capabilities externalities, my guess is that good interp work is strongly net good.
Is it worth publishing in academic venues? Submitting this to ICLR was my first serious experience with peer review. I'm not sure what updates I've made about whether this was worth it. I think some, but probably <50%, of the field-building benefit came from this, and that going viral on Twitter was much more important for ensuring people became aware of the work. I think it made the work more legitimate-seeming, more citable, more respectable, etc. On an object level, it made the writing, ideas and experiments much better and clearer, largely thanks to /u/lawrencec's skills (though it also led to some of the more interesting speculation being relegated to appendix E :'( ). I definitely find peer review and conforming to academic conventions very tedious and irritating, and am very grateful to Lawrence for doing almost all of the work.
Personal benefit: It's hard to measure, but I think this turned out to be substantially good for my career, reputation and influence. It's often what people know me for, and I think it has helped me snowball into other career successes.
Ways it's overrated: As noted, I do think there are ways people overrate the results/ideas here, and that there's too much interest in the object-level results (the Fourier Multiplication algorithm, and understanding grokking). Some thoughts:
[Edit Jan 19, 2023: I no longer think the below is accurate. My argument rests on an unstated assumption: that when weight decay kicks in, the counter-pressure against it is stronger for the 101st weight (the "bias/generalizer") than for the other weights (the "memorizers"), since the gradient is stronger in that direction. In fact, this mostly isn't true, for the same reason Adam(W) moved towards the $(\frac{1}{2}, \ldots, \frac{1}{2})$ solution to begin with, before weight decay strongly kicked in: each dimension of the gradient is normalized relative to its typical magnitudes in the past. Hence the counter-pressure is the same on all coordinates.
A caveat: this assumes we're doing full batch AdamW, as opposed to randomized minibatches. In the latter case, the increased noise in the "memorizer" weights will in fact cause Adam to be less confident about those weights, and thus assign less magnitude to them. But this happens essentially right from the start, so it doesn't really explain grokking.
Here's an example of this, taking randomized minibatches of size 10 (out of 100 total) on each step, optimizing with AdamW (learning rate = 0.001, weight decay = 0.01). I show the first three "memorizer" weights (out of 100 total) plus the bias:

As you can see, it does place less magnitude on the memorizers due to the increased noise, but this happens right from the get-go; it never "groks".
If we do full-batch AdamW, then the bias is treated indistinguishably from the memorizers:

For small weight decay settings and zero gradient noise, AdamW is doing something like finding the minimum-norm solution, but in $L^\infty$ space, not $L^2$ space.]
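The minibatch-vs-full-batch comparison above can be re-created in a few lines. This is a minimal pure-Python sketch, not the original experiment's code: it assumes a concrete data model that the post leaves implicit (sample i has features e_i + e_bias and target 1, with squared loss, so sample i is fit exactly when its memorizer weight plus the bias equals 1), and uses a hand-written AdamW update with the learning rate (0.001) and weight decay (0.01) quoted above. The function names and the 3000-step budget are illustrative choices.

```python
import math
import random

N = 100  # assumed: 100 training samples, one "memorizer" weight each; w[N] is the bias
# Assumed toy data: sample i has features e_i + e_bias and target 1,
# so its prediction is w[i] + w[N].


def grads(w, batch):
    """Gradient of the batch mean of 0.5 * (w[i] + w[N] - 1)**2."""
    g = [0.0] * (N + 1)
    for i in batch:
        r = w[i] + w[N] - 1.0
        g[i] += r / len(batch)  # memorizer i only sees its own sample
        g[N] += r / len(batch)  # the bias sees every sample in the batch
    return g


def adamw(batch_size, steps=3000, lr=0.001, weight_decay=0.01,
          beta1=0.9, beta2=0.999, eps=1e-8, seed=0):
    rng = random.Random(seed)
    w, m, v = [0.0] * (N + 1), [0.0] * (N + 1), [0.0] * (N + 1)
    for t in range(1, steps + 1):
        batch = rng.sample(range(N), batch_size)
        g = grads(w, batch)
        for j in range(N + 1):
            m[j] = beta1 * m[j] + (1 - beta1) * g[j]
            v[j] = beta2 * v[j] + (1 - beta2) * g[j] ** 2
            m_hat = m[j] / (1 - beta1 ** t)
            v_hat = v[j] / (1 - beta2 ** t)
            # Adam step plus decoupled weight decay (the "W" in AdamW)
            w[j] -= lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * w[j])
    return w


full = adamw(batch_size=N)   # full batch: every sample on every step
mini = adamw(batch_size=10)  # random minibatches of 10 out of 100
print("full batch:", [round(x, 3) for x in full[:3]], "bias", round(full[N], 3))
print("minibatch: ", [round(x, 3) for x in mini[:3]], "bias", round(mini[N], 3))
```

With full batches every coordinate gets the same normalized step, so the memorizers and the bias stay essentially equal. With minibatches of 10, each memorizer's gradient is zero on about 90% of steps, which shrinks its Adam-normalized step relative to the bias from the very first steps rather than at some later "grokking" point.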
Here's a straightforward argument that phase changes are an artifact of AdamW and won't be seen with SGD (or SGD with momentum).
Suppose we have 101 weights all initialized to 0 in a linear model, and two possible ways to fit the training data:
The first is $w_1 = (1, 1, \ldots, 1, 0)$. (It sets the first 100 weights to 1, and the last one to 0.)
The second is $w_2 = (0, 0, \ldots, 0, 1)$. (It sets the first 100 weights to 0, and the last one to 1.)
(Any combination $a w_1 + b w_2$ with $a + b = 1$ will also fit the training data.)
Intuitively, the first solution memorizes the training data: we can imagine that each of the first 100 weights corresponds to storing the value of one of the 100 samples in our training set. The second solution is a simple, generalized algorithm for solving all instances from the underlying data distribution, whether in the training set or not.
$w_1$ has an $L^2$ norm which is ten times as large as that of $w_2$. SGD, since it follows the gradient directly, will mostly move directly toward $w_2$, as it's the direction of steeper descent. It will ultimately converge on the minimum-norm solution $\frac{1}{101} w_1 + \frac{100}{101} w_2$. (Momentum won't change this picture much, since it's just smearing out each SGD step over multiple updates, and each individual SGD step goes in the direction of steepest descent.)
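The norm comparison and the minimum-norm point can be worked out directly. Writing the two solutions as $w_1 = (1, \ldots, 1, 0)$ and $w_2 = (0, \ldots, 0, 1)$, and assuming (as a concrete reading of the toy setup) that sample $i$ has features $e_i + e_{101}$, target 1 and mean squared loss, the gradient from a zero initialization is 100 times larger on the last weight than on each memorizer. Every gradient step therefore moves along $w_1 + 100\,w_2$, and by symmetry the minimum-norm fit lies in the family $a w_1 + b w_2$:

```latex
\begin{aligned}
\|w_1\|_2 &= \sqrt{100} = 10, \qquad \|w_2\|_2 = 1, \qquad
\nabla_w L\big|_{w=0} \propto (1, \ldots, 1, 100) = w_1 + 100\, w_2, \\
\min_{a+b=1} \|a w_1 + b w_2\|_2^2 &= \min_{a+b=1} \left(100 a^2 + b^2\right)
  \;\Rightarrow\; 200a = 2b \;\Rightarrow\; b = 100a, \\
a + b = 1 &\;\Rightarrow\; a = \tfrac{1}{101}, \quad b = \tfrac{100}{101}.
\end{aligned}
```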
AdamW, on the other hand, is basically the same as Adam at first, since weight decay doesn't do much when the weights are small. Since Adam is a scale-invariant, coordinate-wise adaptive learning rate algorithm, it will move at the same speed for each of the 101 coordinates in the direction which reduces loss, moving towards the solution $\frac{1}{2} w_1 + \frac{1}{2} w_2$, i.e. with heavy weight on the memorization solution. Weight decay will start to kick in a bit before this point, and over time AdamW will converge to (close to) the same minimum-norm solution as SGD. This is the phase transition from memorization to generalization.
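The first half of this argument (before weight decay matters) can be checked with a small simulation. This is a minimal pure-Python sketch under an assumed data model the comment leaves implicit: sample i has features e_i + e_bias and target 1, with mean squared loss, so sample i is fit exactly when w[i] + w[100] = 1. It compares full-batch gradient descent (standing in for SGD; with full batches there is no sampling noise) against hand-written full-batch Adam with no weight decay. Learning rates and step counts are illustrative.

```python
import math

N = 100  # assumed: 100 training samples, one memorizer weight each; w[N] is the bias


def gradient(w):
    """Full-batch gradient of L = (1/N) * sum_i 0.5 * (w[i] + w[N] - 1)**2."""
    residuals = [w[i] + w[N] - 1.0 for i in range(N)]
    g = [r / N for r in residuals]  # each memorizer is touched by one sample
    g.append(sum(residuals) / N)    # the bias is touched by every sample
    return g


def gradient_descent(lr=0.5, steps=500):
    w = [0.0] * (N + 1)
    for _ in range(steps):
        w = [wj - lr * gj for wj, gj in zip(w, gradient(w))]
    return w


def adam(lr=0.01, steps=2000, beta1=0.9, beta2=0.999, eps=1e-8):
    w, m, v = [0.0] * (N + 1), [0.0] * (N + 1), [0.0] * (N + 1)
    for t in range(1, steps + 1):
        g = gradient(w)
        for j in range(N + 1):
            m[j] = beta1 * m[j] + (1 - beta1) * g[j]
            v[j] = beta2 * v[j] + (1 - beta2) * g[j] ** 2
            m_hat = m[j] / (1 - beta1 ** t)
            v_hat = v[j] / (1 - beta2 ** t)
            # Per-coordinate normalization: the bias's 100x larger gradient
            # is divided by a 100x larger sqrt(v_hat), so every weight moves alike.
            w[j] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w


w_gd, w_adam = gradient_descent(), adam()
print("GD  : memorizer %.4f, bias %.4f" % (w_gd[0], w_gd[N]))
print("Adam: memorizer %.4f, bias %.4f" % (w_adam[0], w_adam[N]))
```

Under these assumptions gradient descent lands on the minimum-norm point, with the bias 100 times each memorizer weight, while Adam ends with memorizers and bias roughly equal at about one half. The sketch deliberately stops short of the weight-decay phase, which is the part the edit note above disputes.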
Thanks, this is awesome! Especially cool that you were able to reverse-engineer the learned algorithm! And the theory/analysis seems great too.
[Just thinking aloud my confusion, not important]
-- Not sure I understand the difference between the random walk hypothesis and the lottery-ticket-inspired hypothesis. In both cases there's a phase transition caused by accelerating gradients once weak versions of all the components of the nonmodular circuit are in place.
-- Is the central piece of evidence you bring up that the effectiveness of the nonmodular circuit rises before the phase transition begins? I guess I'm a bit confused about how this works. If I imagine a circuit consisting of a modular memorized answer part and another part, the generalizing bit, that actually computes the answer... if the second part is really crappy/weak/barely there, I don't see how making it stronger is incentivised. Making it stronger means more weights so the regularization should push against it, UNLESS you can simultaneously delete or dampen weights from the memorized answer part, right? But because the generalizing part is so weak, if you delete or dampen weights from the memorized answer part, won't loss go up? Because the generalizing part isn't ready yet to take over the job of the memorized answers?
Thanks! I agree that they're pretty hard to distinguish, and evidence between them is fairly weak - it's hard to distinguish between a winning lottery ticket at initialisation vs one stumbled upon within the first 200 steps, say.
My favourite piece of evidence is [this video from Eric Michaud](https://twitter.com/ericjmichaud_/status/1559305105521926144) - we know that the first 2 principal components of the embedding form a circle at the end of training. But if we fix the axes at the end of training, and project the embedding at the start of training, it's pretty circle-y.
Making it stronger means more weights so the regularization should push against it, UNLESS you can simultaneously delete or dampen weights from the memorized answer part, right?
I think this does happen (and is very surprising to me!). If you look at the excluded loss section, I ablate the model's ability to use one component of the generalising algorithm, in a way that shouldn't affect the memorising algorithm (much), and see the damage of this ablation rise smoothly over training. I hypothesise that it's dampening memorisation weights simultaneously, though I haven't dug deep enough to be confident. Regardless, it clearly seems to be doing some kind of interpolation - I have a lot of progress measures that all (both qualitatively and quantitatively) show clear progress towards generalisation pre-grokking.
You suggest that first the model quickly memorises a chunk of the data set, then (slowly) learns rules afterward to simplify itself. (You mention that noticing x + y = y + x could allow it to halve its memorised library with no loss.)
I am interested in another way this process could go. We could imagine that the early models are doing something like "Check for the input in my library of known results; if it's not in there, just guess something." So the model has two parts, [memorised] and [guess], with some kind of fork to decide between them.
At first the guess is useless, and improves very slowly. So making sure the memorised information is correct and compact is the majority of the improvement. Eventually the memory saturates, and the only way to progress is refining the guess so the model does slightly less awfully in the cases it hasn't memorised. So here the guess steadily improves, but these improvements don't show up in the performance because the guess is still a fallback option that is rarely used. In this picture the phase change is presumably the point where the "guess" becomes so good that the model "realises" it no longer needs its memorised library or the "if" tree.
These two pictures are quite different. In one, information about the problem makes the library more efficient, eventually leading to an algorithm; in the other, the two strategies (memory, algorithm) are essentially independent, and one eventually makes the other redundant.
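The second picture (a [memorised] library plus a [guess], with a fork between them) and the commutativity saving can be made concrete with a lookup table. Everything here is hypothetical and illustrative: a tiny modulus P = 7, a dict standing in for the library, and hand-written guess functions standing in for whatever a network would actually learn.

```python
P = 7  # hypothetical small modulus, only used to count table entries


def full_library(p):
    """[memorised] table storing every ordered pair (x, y) -> (x + y) % p."""
    return {(x, y): (x + y) % p for x in range(p) for y in range(p)}


def commutative_library(p):
    """Same table, storing each unordered pair once since x + y = y + x."""
    return {(x, y): (x + y) % p for x in range(p) for y in range(x, p)}


def useless_guess(x, y, p):
    """A [guess] branch that has learned nothing yet."""
    return 0


def algorithmic_guess(x, y, p):
    """A [guess] branch that has become the true rule."""
    return (x + y) % p


def predict(library, guess, x, y, p):
    """The fork: use the memorised answer if stored, else fall back to the guess."""
    key = (min(x, y), max(x, y))  # canonical order, exploiting commutativity
    return library[key] if key in library else guess(x, y, p)


def accuracy(library, guess, p):
    pairs = [(x, y) for x in range(p) for y in range(p)]
    hits = sum(predict(library, guess, x, y, p) == (x + y) % p for x, y in pairs)
    return hits / len(pairs)


partial = dict(list(commutative_library(P).items())[:10])  # only 10 pairs memorised
print(len(full_library(P)), len(commutative_library(P)))
print(accuracy(partial, useless_guess, P), accuracy(partial, algorithmic_guess, P))
```

Commutativity cuts the table from p² to p(p + 1)/2 entries, roughly half. Swapping the useless guess for the algorithmic one leaves every memorised answer unchanged and only affects the pairs the library misses, which is the "improvements don't show up until the guess takes over" dynamic described above.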
Just the random thoughts of someone interested with no knowledge of the topic.
Interesting hypothesis, thanks!
My guess is that memorisation isn't really a discrete thing - it's not that the model has either memorised a data point or not, it's more that it's fitting a messed up curve to approximate all training data as well as it can. And more parameters means it memorises all data points a bit better, not that it has excellent loss on some data and terrible loss on others, with the excellent set gradually expanding.
I haven't properly tested this though! And I'm using full batch training, which probably messes this up a fair bit.
Possibly also of interest: https://cprimozic.net/blog/reverse-engineering-a-small-neural-network/
You don't talk about human analogs of grokking, and that makes sense for a technical paper like this. Nonetheless, grokking also seems to happen in humans, and everybody has had "Aha!" moments before. Can you maybe comment a bit on the relation to human learning? It seems clear that human grokking is not a process that purely depends on the number of training samples seen but also on the availability of hypotheses. People grok faster if you provide them with symbolic descriptions of what goes on. What are your thoughts on the representation and transfer of the resulting structure, e.g., via language/token streams?
Hmm. So firstly, I don't think ML grokking and human grokking having the same name is that relevant - it could just be a vague analogy. And I definitely don't claim to understand neuroscience!
That said, I'd guess there's something relevant about phase changes? Internally, I know that I initially feel very confused, then have some intuition of 'I half see some structure but it's super fuzzy', and then eventually things magically click into place. And maybe there's some similar structure around how phase changes happen - useful explanations get reinforced, and as explanations become more useful they become reinforced much faster, leading to a feeling of 'clicking'.
It seems less obvious to me that human grokking looks like 'stare at the same data points a bunch of times until things click'. It also seems plausible that there's some transfer learning going on - here I train models from scratch, but when I personally grok things it feels like I'm fitting the new problem into existing structure in my mind - maybe analogous to how induction heads get repurposed for few-shot learning.
This is honestly some of the most significant alignment work I've seen in recent years (for reasons I plan to post on shortly), thank you for going to all this length!
Typo: "Thoughout this process test loss remains low - even a partial memorising solution still performs extremely badly on unseen data!", 'low' should be 'high' (and 'throughout' is misspelled too).
Thanks, I really appreciate it. Though I wouldn't personally consider this among the most significant alignment work, and would love to hear about why you do!
I had an idea when reading it that I think is pretty interesting. You mention that both models that grok a small amount of data repeated many times and models trained on a great deal of data are highly general. Repeated data during training is also mentioned as a significant negative for large models. These are very much in tension.
My idea is this. Split the training data into two parts, one vastly larger than the other. First, train a model on the small part many times in a way designed to make it grok the task, such as with weight decay. Second, train it on the rest of the very large amount of data. Third, compare it to a model trained on both parts without repeats. (I don't know if people have done this. I'm definitely a layman when it comes to such things.)
Models are often 'fine-tuned' on things later. You could see this as sort of the opposite, where we tune it for the task first, and then train it.
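The proposed protocol can be pinned down as a pair of data orders, one per arm. This is a hypothetical sketch: the function name, split fraction and repeat count are illustrative, and the model, optimizer (e.g. weight decay during the repeated phase) and evaluation are left out, since they would be whatever setup one is testing.

```python
import random


def make_schedules(data, small_fraction=0.01, repeats=100, seed=0):
    """Build training orders for the two arms of the proposed experiment.

    Arm A: the small split repeated `repeats` times (the phase meant to
    induce grokking), then a single pass over the large split.
    Arm B (baseline): both splits seen exactly once, shuffled together.
    """
    rng = random.Random(seed)
    data = list(data)
    rng.shuffle(data)
    k = max(1, int(len(data) * small_fraction))
    small, large = data[:k], data[k:]

    arm_a = []
    for _ in range(repeats):
        epoch = small[:]
        rng.shuffle(epoch)  # reshuffle the small split each repeat
        arm_a.extend(epoch)
    arm_a.extend(large)

    arm_b = small + large
    rng.shuffle(arm_b)
    return arm_a, arm_b


arm_a, arm_b = make_schedules(range(1000))
print(len(arm_a), len(arm_b))
```

One design choice to settle up front is whether to match the number of training steps between arms: as written, Arm A takes more steps than Arm B because of the repeats.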
Repeated data during training is also mentioned as a significant negative for large models. These are very much in tension.
To be clear, the paper I cite on data quality focuses on how repeated data is bad for generalisation. From the model's perspective, the only thing it cares about is train loss (and maybe simplicity), and repeated data is great for train loss! The model doesn't care whether it generalises, only whether generalisation is a "more efficient" solution. Grokking happens when the amount of data is such that the model marginally prefers the correct solution, but there's no reason to expect that the amount of repeated data that screws over models is exactly the amount of data at which the correct solution becomes better.
Though the fact that larger models are messed up by fewer repeated data points is fascinating - I don't know if this is a problem with my hypothesis, or just a statement about the relative complexity of different circuits in larger vs smaller models.
Your experiment idea is interesting, I'm not sure what I'd expect to happen! I'd love to see someone try it, and am not aware of anyone who has (the paper I cite is vaguely similar - there they train the model on the repeated data and unrepeated data shuffled together, and compare it to a model trained on just the unrepeated data).
Though I do think that if this is a real task, there wouldn't be a single amount of data that leads to general grokking; rather, the amount of data needed to grok varies heavily between different circuits.
Good stuff. A few thoughts:
1. Assuming a model has memorized the training data and still has enough "spare capacity" to play the lottery ticket hypothesis and find generalizing solutions to a subset of the memorized data, you'll eventually end up with a number of partial solutions that generalize to subsets of the memorized data (obviously assuming some form of regularization towards simplicity). So this may be where the "underparametrized" regime of ML of the past went wrong: that approach tried to force the model into generalization without memorization, but by being stingy with parameters, it forced the model to first and foremost memorize -- there was no spare capacity left to "play / experiment with possibly generalizing solutions". This then led to memorization-only models, to which researchers reacted by restricting parameters more ...
2. Occam's razor favors simpler models (for some definition of simplicity) over more complex models, given equal predictive power. The best definition of "model simplicity" that we have may in fact be Kolmogorov complexity of the weight matrices. This would mean that if we want a model to apply Occam's razor, we should see if we can use a measure of Kolmogorov complexity of the weights as regularization. The "best" approximation we currently have for Kolmogorov complexity is ... compression, which in itself is a prediction problem. So perhaps the way to encourage good generalization in models is to measure how well the weights can be predicted by another model (?). Apologies if this may sound like a crackpot idea.
3. It might be worth experimenting with shifting the regularization term during training, initially encouraging wide connectivity, and then shifting to either sparsity or low Kolmogorov complexity. There's an intriguing parallel to synaptic pruning in childhood.
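A quick way to get a feel for idea 2 is to let an off-the-shelf compressor play the role of "another model predicting the weights". This sketch swaps in Python's zlib on quantised weights as that predictor, which is an assumption and only a very crude upper bound on description length, not Kolmogorov complexity itself. The two weight vectors are synthetic, chosen only to contrast a repetitive pattern with noise.

```python
import math
import random
import struct
import zlib


def compressed_size(weights, decimals=2):
    """Crude complexity proxy: zlib-compressed size (bytes) of quantised weights."""
    ints = [int(round(w * 10 ** decimals)) for w in weights]
    raw = struct.pack(f"<{len(ints)}i", *ints)  # 4-byte little-endian ints
    return len(zlib.compress(raw, 9))


rng = random.Random(0)
n = 1024
structured = [math.cos(2 * math.pi * 3 * i / 64) for i in range(n)]  # repeats every 64 entries
noisy = [rng.gauss(0.0, 1.0) for _ in range(n)]                       # no exploitable pattern

print(compressed_size(structured), compressed_size(noisy))
```

zlib's output size is not differentiable, so as written this can only score a finished set of weights. Using compressibility as a training-time regularizer, as idea 2 suggests, would need a differentiable surrogate, such as the learned predictor the comment proposes.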
So perhaps the way to encourage good generalization in models is to measure how good the weights can be predicted by another model (?). Apologies if this may sound like a crackpot idea.
Interesting idea. The natural next question is: how would you use that second model to determine the Kolmogorov complexity (or a metric similar to Kolmogorov complexity) of the first model's weights? Let's say you want to use the complexity of the second model, assuming that it is the simplest possible model that can predict the first model's weights, to help you determine that. But in order to satisfy that assumption, you could use a third model in a similar way to minimize the complexity of the second. And so on. Eventually you need to determine the complexity of the weights without training another model, using some metric (whether it's weight norms, performance after pruning, or [insert clever method from the future]). Why not just apply this metric to the first model and not train additional ones?
That said, I could be overlooking something and empirical results could suggest otherwise, so it could still be worth testing the idea out.
Some potentially naive thoughts/questions:
Idk, it might be related to double descent? I'm not that convinced.
Firstly, IMO, the most interesting part of deep double descent is the model-size-wise/data-wise descent, which totally doesn't apply here.
They did also find epoch-wise double descent (different from data-wise, because it's trained on the same data a bunch), which is more related, but it looks like test loss going down, then going up again, then going down. You could argue that grokking has test loss going up, but since it starts at uniform test loss I think this doesn't count.
My guess is that the descent part of deep double descent illustrates some underlying competition between different circuits in the model, where some do memorisation and others do generalisation, and that there's some similar competition and switching in grokking. Which is cool and interesting and somewhat related! And it totally wouldn't surprise me if there's some similar tension between hard to reach but simple and easy to reach but complex.
Re single basin, idk, I actually think it's a clear disproof of the single basin hypothesis (in this specific case - it could still easily be mostly true for other problems). Here, there's a solution to modular addition for each of the 56 frequencies! These solutions are analogous, but they are fairly far apart in model space, and definitely can't be bridged by just permuting the weights. (E.g., the embedding picking up on $\cos(w_k x)$ vs $\cos(w_{k'} x)$ for two different frequencies $k \neq k'$ is a totally different solution and set of weights, and would require significant non-linear shuffling around of parameters.)
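One ingredient of the "fairly far apart" claim can be checked directly: the input-side directions cos(w_k x) for different frequencies are exactly orthogonal over the inputs x = 0, ..., p - 1, so embeddings built on different frequencies share no overlap in those directions. This sketch assumes p = 113 (consistent with 56 = (113 - 1)/2 non-trivial frequencies) and w_k = 2πk/p. It only checks these embedding directions, not a full trained solution, and the frequencies 14 and 17 are arbitrary picks.

```python
import math

P = 113  # assumed modulus: (P - 1) // 2 = 56 non-trivial frequencies


def fourier_direction(k, p=P):
    """Embedding direction cos(w_k * x) over inputs x = 0..p-1, with w_k = 2*pi*k/p."""
    return [math.cos(2 * math.pi * k * x / p) for x in range(p)]


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / math.sqrt(sum(a * a for a in u) * sum(b * b for b in v))


num_frequencies = (P - 1) // 2
print(num_frequencies)
print(cosine_similarity(fourier_direction(14), fourier_direction(17)))
```

Up to floating-point error, every pair of distinct frequencies in 1..56 gives a cosine similarity of zero.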