Research Scientist at DeepMind. Creator of the Alignment Newsletter. http://rohinshah.com/
Which of these theories [...] can predict the same "four novel predictions about grokking" yours did? The relative likelihoods are what matters for updates after all.
I disagree with the implicit view on how science works. When you are a computationally bounded reasoner, you work with partial hypotheses, i.e. hypotheses that only make predictions on a small subset of possible questions, and just shrug at other questions. This is mostly what happens with the other theories:
how does this theory explain other grokking-related phenomena, e.g. Omni-Grok?
My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), for some reason I still don't know.
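To make the third ingredient concrete, here is a minimal sketch (assuming PyTorch; the MLP architecture, the scale factor alpha, and the optimiser settings are illustrative choices of mine, not taken from the Omnigrok paper) of what "large parameter norms at initialisation" amounts to in a setting like MNIST:

```python
import torch
import torch.nn as nn

# A standard small MLP for MNIST; the architecture is just for illustration.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 200),
    nn.ReLU(),
    nn.Linear(200, 10),
)

# The third ingredient: scale up the parameters at initialisation so the
# network starts at a large parameter norm (alpha is an illustrative value).
alpha = 8.0
with torch.no_grad():
    for p in model.parameters():
        p.mul_(alpha)

# Then train as usual with weight decay, e.g.:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
```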
Happy to speculate on other grokking phenomena as well (though I don't think there are many others?)
And how do things change as you increase parameter count?
We haven't investigated this, but I'd pretty strongly predict that there mostly aren't major qualitative changes. (The one exception is semi-grokking; there's a theoretical reason to expect it may sometimes not occur, and also in practice it can be quite hard to elicit.)
I expect there would be quantitative changes (e.g. maybe the value of D_crit changes, maybe the time taken to learn C_gen changes). Sufficiently big changes in D_crit might mean you don't see the phenomena on modular addition any more, but I'd still expect to see them in more complicated tasks that exhibit grokking.
I'd be interested in investigations that got into these quantitative questions (in addition to the above, there are also things like "quantitatively, how does the strength of weight decay affect the time for C_gen to be learned?", and many more).
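As a concrete illustration of the weight-decay question, here is a minimal self-contained sketch (assuming PyTorch; the one-hidden-layer MLP on one-hot pairs, the 40% training fraction, and all hyperparameters are illustrative and untuned, and this is not our codebase) that sweeps weight decay and records the first step at which test accuracy crosses a threshold, as a crude proxy for the time for C_gen to be learned:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

P = 97  # modulus for the modular addition task
pairs = torch.cartesian_prod(torch.arange(P), torch.arange(P))  # all (a, b)
labels = (pairs[:, 0] + pairs[:, 1]) % P

perm = torch.randperm(len(pairs))
n_train = int(0.4 * len(pairs))  # 40% training fraction, illustrative
train_idx, test_idx = perm[:n_train], perm[n_train:]

def encode(idx):
    # Concatenated one-hot encodings of a and b.
    x = pairs[idx]
    return torch.cat([F.one_hot(x[:, 0], P), F.one_hot(x[:, 1], P)], dim=-1).float()

X_train, y_train = encode(train_idx), labels[train_idx]
X_test, y_test = encode(test_idx), labels[test_idx]

def steps_until_generalisation(weight_decay, max_steps=30000, threshold=0.9):
    """First step at which test accuracy exceeds `threshold`, or None."""
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(2 * P, 256), nn.ReLU(), nn.Linear(256, P))
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=weight_decay)
    for step in range(max_steps):
        opt.zero_grad()
        F.cross_entropy(model(X_train), y_train).backward()
        opt.step()
        if step % 100 == 0:
            with torch.no_grad():
                acc = (model(X_test).argmax(dim=-1) == y_test).float().mean().item()
            if acc > threshold:
                return step
    return None

for wd in [1e-2, 1e-1, 1.0]:
    print(f"weight_decay={wd}: generalised at step {steps_until_generalisation(wd)}")
```

Whether this exact setup groks cleanly will depend on the hyperparameters; the point is just the shape of the experiment, not the specific numbers.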
From page 6 of the paper:
Ungrokking can be seen as a special case of catastrophic forgetting (McCloskey and Cohen, 1989; Ratcliff, 1990), where we can make much more precise predictions. First, since ungrokking should only be expected once D' < D_crit, if we vary D' we predict that there will be a sharp transition from very strong to near-random test accuracy (around D_crit). Second, we predict that ungrokking would arise even if we only remove examples from the training dataset, whereas catastrophic forgetting typically involves training on new examples as well. Third, since D_crit does not depend on weight decay, we predict the amount of “forgetting” (i.e. the test accuracy at convergence) also does not depend on weight decay.
(All of these predictions are then confirmed in the experimental section.)
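For concreteness, the first prediction can be written as a simple threshold condition, using D' for the reduced training-set size and D_crit for the paper's critical dataset size (roughly, the size at which the generalising and memorising circuits are equally efficient):

```latex
\text{test accuracy at convergence} \;\approx\;
\begin{cases}
\text{high (generalisation retained)} & \text{if } D' > D_{\mathrm{crit}},\\
\text{near-random (ungrokking)} & \text{if } D' < D_{\mathrm{crit}},
\end{cases}
\qquad \text{with a sharp transition at } D' \approx D_{\mathrm{crit}}.
```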
I think that post has a lot of good ideas, e.g. the idea that generalizing circuits get reinforced by SGD more than memorizing circuits at least rhymes with what we claim is actually going on (that generalizing circuits are more efficient at producing strong logits with small param norm). We probably should have cited it, I forgot that it existed.
But it is ultimately a different take and one that I think ends up being wrong (e.g. I think it would struggle to explain semi-grokking).
I also think my early explanation, which that post compares to, is basically as good or better in hindsight, e.g.:
I think I would particularly critique DeepMind and OpenAI's interpretability works, as I don't see how this reduces risks more than other works that they could be doing, and I'd appreciate a written plan of what they expect to achieve.
I can't speak on behalf of Google DeepMind or even just the interpretability team (individual researchers have pretty different views), but I personally think of our interpretability work as primarily a bet on creating new affordances upon which new alignment techniques can be built, or existing alignment techniques can be enhanced. For example:
I don't usually try to backchain too hard from these theories of change to work done today; I think it's going to be very difficult to predict in advance what kind of affordances we might build in the future with years' more work (similarly to Richard's comment, though I'm focused more on affordances than principled understanding of deep learning; I like principled understanding of deep learning but wouldn't be doing basic research on interpretability if that was my goal).
My attitude is much more that we should be pushing on the boundaries of what interp can do, and as we do so we can keep looking out for new affordances that we can build. As an example of how I reason about what projects to do, I'm now somewhat less excited about projects that do manual circuit analysis of an algorithmic task. They do still teach us new stylized facts about LLMs like "there are often multiple algorithms at different 'strengths' spread across the model" that can help with future mech interp, but overall it feels like these projects aren't pushing the boundaries as much as seems possible, because we're using the same, relatively-well-vetted techniques for all of these projects.
I'm also more keen on applying interpretability to downstream tasks (e.g. fixing issues in a model, generating adversarial examples), but not necessarily because I think it will be better than alternative methods today, but rather because I think the downstream task keeps you honest (if you don't actually understand what's going on, you'll fail at the task) and because I think practice with downstream tasks will help us notice which problems are important to solve vs. which can be set aside. This is an area where other people disagree with me (and I'm somewhat sympathetic to their views, e.g. that the work that best targets a downstream task won't tackle fundamental interp challenges like superposition as well as work that is directly trying to tackle those fundamental challenges).
(EDIT: I mostly agree with Ryan's comment, and I'll note that I am considering a much wider category of work than he is, which is part of why I usually say "interpretability" rather than "mechanistic interpretability".)
Separately, you say:
I don't see how this reduces risks more than other works that they could be doing
I'm not actually sure why you believe this. On the views you've expressed in this post (which, to be clear, I often disagree with), I'd expect you to conclude that most of our work is just as bad as interpretability.
In particular we're typically in the business of building aligned models. As far as I can tell, you think that interpretability can't be used for this because (1) it is dual use, and (2) if you optimize against it, you are in part optimizing for the AI system to trick your interpretability tools. But these two points seem to apply to any alignment technique that is aiming to build aligned models. So I'm not sure what other work (within the "build aligned models" category) you think we could be doing that is better than interpretability.
(Similarly, based on the work you express excitement about in your post, it seems like you are targeting an endgame of "indefinite, or at least very long, pause on AI progress". If that's your position, I wish you had instead written a post titled "against almost every theory of impact of alignment" or something like that.)
Yeah, that seems like a reasonable operationalization of "capable of doing X". So my understanding is that (1), (3), (6) and (7) would not falsify the hypothesis under your operationalization, (5) would falsify it, (2) depends on details, and (4) is kinda ambiguous but I tend to think it would falsify it.
Which of (1)-(7) above would falsify the hypothesis if observed? Or if there isn't enough information, what additional information do you need to tell whether the hypothesis has been falsified or not?
The “no sandbagging on checkable tasks” hypothesis: With rare exceptions, if a not-wildly-superhuman ML model is capable of doing some task X, and you can check whether it has done X, then you can get it to do X using already-available training techniques (e.g., fine-tuning it using gradient descent).[1]
I think as phrased this is either not true, or tautological, or otherwise imprecisely specified (in particular I'm not sure what it means for a model to be "capable of" doing some task X -- so far papers define that to be "can you quickly finetune the model to do X"; if you use that definition then it's tautological).
Here are some hypotheticals, all of which seem plausible to me, that I think are useful test cases for your hypothesis (and would likely falsify a reasonable reading of it):
I think you probably want your hypothesis to instead say something like "if given full autonomy, the AI system could not perform better at task X than what we can elicit with currently-available techniques". (You'd also want to postulate that the AI system only gets to use a similar amount of resources for finetuning as we use.)
(At some point AI will be better at eliciting capabilities from other AI systems than unaided humans; so at that point "currently-available techniques" would have to include ones that leverage AI systems for eliciting capabilities if we wanted the statement to continue to hold. This is a feature of the definition, not a bug.)
I expect a delay even in the infinite data case, I think?
Although I'm not quite sure what you mean by "infinite data" here -- if the argument is that every data point will have been seen during training, then I agree that there won't be any delay. But yes training on the test set (even via "we train on everything so there is no possible test set") counts as cheating for this purpose.
Honestly I'd be surprised if you could achieve (2) even with explicit regularization, specifically on the modular addition task.
(You can achieve it by initializing the token embeddings to those of a grokked network so that the representations are appropriately structured; I'm not allowing things like that.)
EDIT: Actually, Omnigrok does this by constraining the parameter norm. I suspect this is mostly making it very difficult for the network to strongly memorize the data -- given the weight decay parameter the network "tries" to learn a high-param norm memorizing solution, but then repeatedly runs into the parameter norm constraint -- and so creates a very strong reason for the network to learn the generalizing algorithm. But that should still count as normal regularization.
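For reference, a norm constraint of the kind I have in mind can be implemented by rescaling the parameters after each optimiser step. This is a minimal sketch in the spirit of Omnigrok's fixed parameter norm (assuming PyTorch), not their actual implementation, and the target norm is an illustrative value:

```python
import torch

def project_to_norm(model: torch.nn.Module, target_norm: float) -> None:
    """Rescale all parameters so their overall L2 norm equals target_norm."""
    with torch.no_grad():
        total_norm = torch.sqrt(sum(p.pow(2).sum() for p in model.parameters()))
        for p in model.parameters():
            p.mul_(target_norm / (total_norm + 1e-12))

# Usage inside an otherwise-ordinary training loop (model, loss, optimizer
# assumed to be defined elsewhere):
#   loss.backward()
#   optimizer.step()
#   project_to_norm(model, target_norm=30.0)
# The network can never reach a high-param-norm memorising solution, so the
# pressure towards the generalising solution is much stronger.
```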
Yes, that's what I mean.
I do agree that it's useful to know whether a partial hypothesis says anything or not; overall I think this is good info to know / ask for. I think I came off as disagreeing more strongly than I actually did, sorry about that.
No, we're moving on to other work: this took longer than we expected, and was less useful for alignment than we hoped (though that part wasn't that unexpected; from the start we expected "science of deep learning" to be more hits-based, or to require significant progress before it actually became useful for practical proposals).
Actually running the experiments should be pretty straightforward, I'd expect we could do them in a week given our codebase, possibly even a day. Others might take some time to set up a good codebase but I'd still be surprised if it took a strong engineer longer than two weeks to get some initial results. This gets you observations like "under the particular settings we chose, D_crit tends to increase / decrease as the number of layers increases".
The hard part is then interpreting those results and turning them into something more generalizable -- including handling confounds. For example, maybe for some reason the principled thing to do is to reduce the learning rate as you increase layers, and once you do that your observation reverses -- this is a totally made up example but illustrates the kind of annoying things that come up when doing this sort of research, that prevent you from saying anything general. I don't know how long it would take if you want to include that; it could be quite a while (e.g. months or years).
Not really. I've learned from experience not to try to make quantitative predictions yet. We tried to make some theory-inspired quantitative predictions in the settings we studied, and they fell pretty flat.
For example, in our minimal model in Section 3 we have a hyperparameter κ that determines how param norm and logits scale together -- initially, that was our guess of what would happen in practice (i.e. we expected circuit param norm and circuit logits to obey a power-law relationship in actual grokking settings). But basically every piece of evidence we got seemed to falsify that hypothesis (e.g. Figure 3 in the paper).
(I say "seemed to falsify" because it's still possible that we're just failing to deal with confounders in some way, or measuring something that isn't exactly what we want to measure. For example, Figure 3 logits are not of the Mem circuit in actual grokking setups, but rather the logits produced by networks trained on random labels -- maybe there's a relevant difference between these.)