Rohin Shah

Research Scientist at DeepMind. Creator of the Alignment Newsletter. http://rohinshah.com/

Comments

Unless by "shrugs" you mean the details of what the partial hypothesis says in this particular case are still being worked out.

Yes, that's what I mean.

I do agree that it's useful to know whether a partial hypothesis says anything or not; overall I think this is good info to know / ask for. I think I came off as disagreeing more strongly than I actually did, sorry about that.

Do you have any plans to do this?

No, we're moving on to other work: this took longer than we expected, and was less useful for alignment than we hoped (though that part wasn't that unexpected; from the start we expected "science of deep learning" to be more hits-based, or to require significant progress before it actually became useful for practical proposals).

How much time do you think it would take?

Actually running the experiments should be pretty straightforward, I'd expect we could do them in a week given our codebase, possibly even a day. Others might take some time to set up a good codebase but I'd still be surprised if it took a strong engineer longer than two weeks to get some initial results. This gets you observations like "under the particular settings we chose, D_crit tends to increase / decrease as the number of layers increases".

The hard part is then interpreting those results and turning them into something more generalizable -- including handling confounds. For example, maybe for some reason the principled thing to do is to reduce the learning rate as you increase layers, and once you do that your observation reverses -- this is a totally made up example but illustrates the kind of annoying things that come up when doing this sort of research, that prevent you from saying anything general. I don't know how long it would take if you want to include that; it could be quite a while (e.g. months or years).
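
To make "pretty straightforward" concrete, here is a minimal sketch of the kind of sweep I have in mind. It's a stand-in rather than our actual codebase: a small MLP on one-hot modular-addition inputs, trained full-batch with AdamW weight decay, sweeping depth against training-set fraction. All of the hyperparameters (width, learning rate, weight decay, step count) are illustrative, and whether grokking actually shows up at these exact settings is exactly the kind of thing you'd have to check.

```python
# Stand-in sweep: how does D_crit move as depth changes on modular addition?
import itertools
import torch
import torch.nn as nn

P = 97  # modulus for the modular addition task

def make_dataset(frac_train: float, seed: int = 0):
    """All (a, b) pairs with label (a + b) mod P, one-hot encoded, split by frac_train."""
    pairs = torch.tensor(list(itertools.product(range(P), range(P))))
    labels = (pairs[:, 0] + pairs[:, 1]) % P
    x = torch.cat([nn.functional.one_hot(pairs[:, 0], P),
                   nn.functional.one_hot(pairs[:, 1], P)], dim=1).float()
    perm = torch.randperm(len(x), generator=torch.Generator().manual_seed(seed))
    n_train = int(frac_train * len(x))
    tr, te = perm[:n_train], perm[n_train:]
    return (x[tr], labels[tr]), (x[te], labels[te])

def make_mlp(n_layers: int, width: int = 256) -> nn.Module:
    layers, d_in = [], 2 * P
    for _ in range(n_layers):
        layers += [nn.Linear(d_in, width), nn.ReLU()]
        d_in = width
    layers.append(nn.Linear(d_in, P))
    return nn.Sequential(*layers)

def final_test_acc(n_layers: int, frac_train: float, steps: int = 20_000) -> float:
    (xtr, ytr), (xte, yte) = make_dataset(frac_train)
    model = make_mlp(n_layers)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
    for _ in range(steps):  # full-batch training, as in many grokking setups
        opt.zero_grad()
        nn.functional.cross_entropy(model(xtr), ytr).backward()
        opt.step()
    with torch.no_grad():
        return (model(xte).argmax(dim=-1) == yte).float().mean().item()

# D_crit(depth) is roughly the training fraction where final test accuracy
# transitions from near-random to near-perfect.
for n_layers in [1, 2, 3, 4]:
    accs = {f: round(final_test_acc(n_layers, f), 2) for f in (0.2, 0.3, 0.4, 0.5, 0.6)}
    print(n_layers, accs)
```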

And do you have any predictions for what should happen in these cases?

Not really. I've learned from experience not to try to make quantitative predictions yet. We tried to make some theory-inspired quantitative predictions in the settings we studied, and they fell pretty flat.

For example, in our minimal model in Section 3 we have a hyperparameter that determines how param norm and logits scale together -- initially, that was our guess of what would happen in practice (i.e. we expected circuit param norm and circuit logits to obey a power-law relationship in actual grokking settings). But basically every piece of evidence we got seemed to falsify that hypothesis (e.g. Figure 3 in the paper).

(I say "seemed to falsify" because it's still possible that we're just failing to deal with confounders in some way, or measuring something that isn't exactly what we want to measure. For example, Figure 3 logits are not of the Mem circuit in actual grokking setups, but rather the logits produced by networks trained on random labels -- maybe there's a relevant difference between these.)

Which of these theories [...] can predict the same "four novel predictions about grokking" yours did? The relative likelihoods are what matters for updates after all.

I disagree with the implicit view on how science works. When you are a computationally bounded reasoner, you work with partial hypotheses, i.e. hypotheses that only make predictions on a small subset of possible questions, and just shrug at other questions. This is mostly what happens with the other theories:

  1. Difficulty of representation learning: Shrugs at our prediction about Mem / Gen circuit efficiencies, anti-predicts ungrokking (since in that case the representation has already been learned), shrugs at semi-grokking.
  2. Scale of parameters at initialisation: Shrugs at all of our predictions. If you interpret it as making a strong claim that scale of parameters at initialisation is the crucial thing (i.e. other things mostly don't matter) then it anti-predicts semi-grokking.
  3. Spikes in loss / slingshots: Shrugs at all of our predictions.
  4. Random walks among optimal solutions: Shrugs at our prediction about Mem / Gen circuit efficiencies. I'm not sure what this theory says about what happens after you hit the generalising solution -- can you then randomly walk away from the generalising solution? If yes, then it predicts that if you train for long enough without changing the dataset, a grokked network will ungrok (false in our experiments, and we often trained for much longer than the time to grok); if no, then it anti-predicts ungrokking and semi-grokking.
  5. Simplicity of the generalising solution: This is our explanation. Our paper is basically an elaboration, formalization, and confirmation of Nanda et al's theory, as we allude to in the next sentence after the one you quoted.

how does this theory explain other grokking-related phenomena, e.g. Omni-Grok?

My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), for some reason I still don't know.

Happy to speculate on other grokking phenomena as well (though I don't think there are many others?)

And how do things change as you increase parameter count?

We haven't investigated this, but I'd pretty strongly predict that there mostly aren't major qualitative changes. (The one exception is semi-grokking; there's a theoretical reason to expect it may sometimes not occur, and also in practice it can be quite hard to elicit.)

I expect there would be quantitative changes (e.g. maybe the value of D_crit changes, maybe the time taken to learn the generalising circuit changes). Sufficiently big changes in D_crit might mean you don't see the phenomena on modular addition any more, but I'd still expect to see them in more complicated tasks that exhibit grokking.

I'd be interested in investigations that got into these quantitative questions (in addition to the above, there are also things like "quantitatively, how does the strength of weight decay affect the time for the generalising circuit to be learned?", and many more).

From page 6 of the paper:

Ungrokking can be seen as a special case of catastrophic forgetting (McCloskey and Cohen, 1989; Ratcliff, 1990), where we can make much more precise predictions. First, since ungrokking should only be expected once the reduced dataset size falls below D_crit, if we vary the size of the reduced dataset we predict that there will be a sharp transition from very strong to near-random test accuracy (around D_crit). Second, we predict that ungrokking would arise even if we only remove examples from the training dataset, whereas catastrophic forgetting typically involves training on new examples as well. Third, since D_crit does not depend on weight decay, we predict the amount of “forgetting” (i.e. the test accuracy at convergence) also does not depend on weight decay.

(All of these predictions are then confirmed in the experimental section.)
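
For concreteness, the corresponding experiment is roughly the following sketch (a stand-in rather than our actual code: `model` and the data tensors are assumed to come from an existing grokking run, and the hyperparameters are illustrative):

```python
import torch
import torch.nn.functional as F

def ungrokking_test_acc(model, train_x, train_y, test_x, test_y,
                        keep_n: int, steps: int = 20_000,
                        weight_decay: float = 1.0) -> float:
    """Continue training a grokked model on only `keep_n` of its training examples.
    Prediction: if keep_n is below D_crit, test accuracy collapses to near-random
    (ungrokking); the transition is sharp around D_crit, and the final test
    accuracy doesn't depend on the weight decay value."""
    idx = torch.randperm(len(train_x))[:keep_n]
    x, y = train_x[idx], train_y[idx]
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=weight_decay)
    for _ in range(steps):
        opt.zero_grad()
        F.cross_entropy(model(x), y).backward()
        opt.step()
    with torch.no_grad():
        return (model(test_x).argmax(dim=-1) == test_y).float().mean().item()
```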

I think that post has a lot of good ideas, e.g. the idea that generalizing circuits get reinforced by SGD more than memorizing circuits at least rhymes with what we claim is actually going on (that generalizing circuits are more efficient at producing strong logits with small param norm). We probably should have cited it; I forgot that it existed.

But it is ultimately a different take and one that I think ends up being wrong (e.g. I think it would struggle to explain semi-grokking).

I also think my early explanation, which that post compares to, is basically as good or better in hindsight, e.g.:

  1. My early explanation says that memorization produces smaller probabilities than generalization. Quintin's post explicitly calls this out as a difference, I continued to endorse my position in the comments. In hindsight my explanation was right and Quintin's was wrong, at least if you believe our new paper.
  2. My early explanation relies on the assumption that there are no "intermediate" circuits, only a pure memorization and pure generalization circuit that you can interpolate between. Again, this is called out as a difference by Quintin's post, and I continued to endorse my position in the comments. Again, I think in hindsight my explanation was right and Quintin's was wrong (though this is less clearly implied by our paper, and I could still imagine future evidence overturning that conclusion, though I really don't expect that to happen).
  3. On the other hand, my early explanation involves a random walk to the generalizing circuit, whereas in reality it develops smoothly over time. In hindsight, my explanation was wrong, and Quintin's is correct.

I think I would particularly critique DeepMind and OpenAI's interpretability works, as I don't see how this reduces risks more than other works that they could be doing, and I'd appreciate a written plan of what they expect to achieve.

I can't speak on behalf of Google DeepMind or even just the interpretability team (individual researchers have pretty different views), but I personally think of our interpretability work as primarily a bet on creating new affordances upon which new alignment techniques can be built, or existing alignment techniques can be enhanced. For example:

  • It is possible to automatically make and verify claims about what topics a model is internally "thinking about" when answering a question. This is integrated into debate, and allows debaters to critique each other's internal reasoning, not just the arguments they externally make.
    • (It's unclear how much this buys you on top of cross-examination.)
  • It is possible to automatically identify "cruxes" for the model's outputs, making it easier for adversaries to design situations that flip the crux without flipping the overall correct decision.
    • Redwood's adversarial training project is roughly in this category, where the interpretability technique is saliency, specifically the magnitude of the gradient of the classifier output w.r.t. the token embedding (a generic sketch of this kind of saliency appears after this list).
    • (Yes, typical mech interp directions are far more detailed than saliency. The hope is that they would produce affordances significantly more helpful and robust than saliency.)
    • A different theory of change for the same affordance is to use it to analyze warning shots, to understand the underlying cause of the warning shot (was it deceptive alignment? specification gaming? mistake from not knowing a relevant fact? etc).
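
To be concrete about the saliency example, here is a generic sketch of gradient-w.r.t.-token-embedding saliency for a Hugging Face-style classifier. This is not Redwood's actual code; the model/tokenizer interface is the standard transformers one, and everything else is a placeholder:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def token_saliency(model_name: str, text: str, target_class: int):
    """Per-token saliency: L2 norm of the gradient of one classifier logit
    with respect to each input token's embedding."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()
    enc = tokenizer(text, return_tensors="pt")
    embeds = model.get_input_embeddings()(enc["input_ids"]).detach().requires_grad_(True)
    logits = model(inputs_embeds=embeds, attention_mask=enc["attention_mask"]).logits
    logits[0, target_class].backward()
    scores = embeds.grad.norm(dim=-1).squeeze(0)
    tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())
    return list(zip(tokens, scores.tolist()))
```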

I don't usually try to backchain too hard from these theories of change to work done today; I think it's going to be very difficult to predict in advance what kind of affordances we might build in the future with years' more work (similarly to Richard's comment, though I'm focused more on affordances than principled understanding of deep learning; I like principled understanding of deep learning but wouldn't be doing basic research on interpretability if that was my goal).

My attitude is much more that we should be pushing on the boundaries of what interp can do, and as we do so we can keep looking out for new affordances that we can build. As an example of how I reason about what projects to do, I'm now somewhat less excited about projects that do manual circuit analysis of an algorithmic task. They do still teach us new stylized facts about LLMs like "there are often multiple algorithms at different 'strengths' spread across the model" that can help with future mech interp, but overall it feels like these projects aren't pushing the boundaries as much as seems possible, because we're using the same, relatively-well-vetted techniques for all of these projects.

I'm also more keen on applying interpretability to downstream tasks (e.g. fixing issues in a model, generating adversarial examples), but not necessarily because I think it will be better than alternative methods today, but rather because I think the downstream task keeps you honest (if you don't actually understand what's going on, you'll fail at the task) and because I think practice with downstream tasks will help us notice which problems are important to solve vs. which can be set aside. This is an area where other people disagree with me (and I'm somewhat sympathetic to their views, e.g. that the work that best targets a downstream task won't tackle fundamental interp challenges like superposition as well as work that is directly trying to tackle those fundamental challenges).

(EDIT: I mostly agree with Ryan's comment, and I'll note that I am considering a much wider category of work than he is, which is part of why I usually say "interpretability" rather than "mechanistic interpretability".)


Separately, you say:

I don't see how this reduces risks more than other works that they could be doing

I'm not actually sure why you believe this. On the views you've expressed in this post (which, to be clear, I often disagree with), I'd think you should conclude that most of our work is just as bad as interpretability.

In particular we're typically in the business of building aligned models. As far as I can tell, you think that interpretability can't be used for this because (1) it is dual use, and (2) if you optimize against it, you are in part optimizing for the AI system to trick your interpretability tools. But these two points seem to apply to any alignment technique that is aiming to build aligned models. So I'm not sure what other work (within the "build aligned models" category) you think we could be doing that is better than interpretability.

(Similarly, based on the work you express excitement about in your post, it seems like you are targeting an endgame of "indefinite, or at least very long, pause on AI progress". If that's your position, I wish you had instead written a post titled "against almost every theory of impact of alignment" or something like that.)

Yeah, that seems like a reasonable operationalization of "capable of doing X". So my understanding is that (1), (3), (6) and (7) would not falsify the hypothesis under your operationalization, (5) would falsify it, (2) depends on details, and (4) is kinda ambiguous but I tend to think it would falsify it.

Which of (1)-(7) above would falsify the hypothesis if observed? Or if there isn't enough information, what additional information do you need to tell whether the hypothesis has been falsified or not?

The “no sandbagging on checkable tasks” hypothesis: With rare exceptions, if a not-wildly-superhuman ML model is capable of doing some task X, and you can check whether it has done X, then you can get it to do X using already-available training techniques (e.g., fine-tuning it using gradient descent).[1]

I think as phrased this is either not true, or tautological, or otherwise imprecisely specified (in particular I'm not sure what it means for a model to be "capable of" doing some task X -- so far papers define that to be "can you quickly finetune the model to do X"; if you use that definition then it's tautological).

Here are some hypotheticals, all of which seem plausible to me, that I think are useful test cases for your hypothesis (and would likely falsify a reasonable reading of it):

  1. You spend T time trying to prompt a model to solve a task X, and fail to do so, and declare that the model can't do X. Later someone else spends T time trying to prompt the same model to solve X, and succeeds, because they thought of a better prompt than you did.
  2. Like (1), but both you and the other person tried lots of current techniques (prompting, finetuning, chain of thought, etc).
  3. You spend $100 million pretraining a model, and then spend $1,000 of compute to finetune it, and observe it can only get a 50% success rate, so you declare it incapable of doing task X. Later you spend $1 million of compute to finetune it (with a correspondingly bigger dataset), and observe it can now get a 95% accuracy on the task.
  4. Like (3), but later you still spend $1,000 of compute to finetune it, but with a much more curated and high-quality dataset, which gets you from 50% to 95%.
  5. You evaluate GPT-4 using existing techniques and observe that it can't do task X. In 2033, somebody goes back and reevaluates GPT-4 using 2033 techniques (with the same data and similar compute, let's say) and now it does well on task X.
  6. You evaluate a model using existing techniques and observe that it can't do task X. A domain expert comes in and looks at the transcripts of the models, figures out the key things the model is struggling with, writes up a set of guidelines, and puts those in the prompt. The model can now do task X.
  7. Like (6), but the domain expert is a different AI system. (Perhaps one with a larger base model, or perhaps one that was engineered with a lot of domain-specific schlep.)

I think you probably want your hypothesis to instead say something like "if given full autonomy, the AI system could not perform better at task X than what we can elicit with currently-available techniques". (You'd also want to postulate that the AI system only gets to use a similar amount of resources for finetuning as we use.)

(At some point AI will be better at eliciting capabilities from other AI systems than unaided humans; so at that point "currently-available techniques" would have to include ones that leverage AI systems for eliciting capabilities if we wanted the statement to continue to hold. This is a feature of the definition, not a bug.)
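
For concreteness, one instance of "already-available training techniques" combined with "you can check whether it has done X" is the kind of loop sketched below: sample attempts, keep the ones a checker accepts, fine-tune on those. This is only a sketch; `check_solution`, the prompts, and the model name are placeholders, and a serious elicitation attempt would be much more careful about compute and data budgets.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def elicit_by_finetuning(model_name: str, prompts: list[str], check_solution,
                         n_samples: int = 8, lr: float = 1e-5):
    """Sample attempts at a checkable task, keep checker-approved ones, fine-tune on them."""
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)

    # 1. Collect checked successes (the "you can check whether it has done X" part).
    successes = []
    for prompt in prompts:
        inputs = tok(prompt, return_tensors="pt")
        outs = model.generate(**inputs, do_sample=True,
                              num_return_sequences=n_samples, max_new_tokens=256)
        for seq in outs:
            text = tok.decode(seq, skip_special_tokens=True)
            if check_solution(prompt, text):
                successes.append(text)

    # 2. Ordinary gradient-descent fine-tuning on the successes.
    model.train()
    for text in successes:
        batch = tok(text, return_tensors="pt")
        loss = model(**batch, labels=batch["input_ids"]).loss
        opt.zero_grad()
        loss.backward()
        opt.step()
    return model
```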

I expect a delay even in the infinite data case, I think?

Although I'm not quite sure what you mean by "infinite data" here -- if the argument is that every data point will have been seen during training, then I agree that there won't be any delay. But yes training on the test set (even via "we train on everything so there is no possible test set") counts as cheating for this purpose.

Honestly I'd be surprised if you could achieve (2) even with explicit regularization, specifically on the modular addition task.

(You can achieve it by initializing the token embeddings to those of a grokked network so that the representations are appropriately structured; I'm not allowing things like that.)

EDIT: Actually, Omnigrok does this by constraining the parameter norm. I suspect this is mostly making it very difficult for the network to strongly memorize the data -- given the weight decay parameter the network "tries" to learn a high-param norm memorizing solution, but then repeatedly runs into the parameter norm constraint -- and so creates a very strong reason for the network to learn the generalizing algorithm. But that should still count as normal regularization.
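
For reference, the constraint I have in mind is the kind of thing sketched below: after every optimizer step, rescale the weights so the global parameter norm stays fixed. The target norm and exactly where the projection sits in the loop are illustrative, not Omnigrok's exact recipe.

```python
import torch

@torch.no_grad()
def project_to_norm(model: torch.nn.Module, target_norm: float) -> None:
    """Rescale all trainable parameters so their global L2 norm equals target_norm."""
    params = [p for p in model.parameters() if p.requires_grad]
    total = torch.sqrt(sum((p ** 2).sum() for p in params))
    for p in params:
        p.mul_(target_norm / total)

# Inside a normal training loop:
#   loss.backward(); opt.step(); project_to_norm(model, target_norm=50.0)
# The constraint blocks high-param-norm memorising solutions, which (on the story
# above) is what pushes the network toward the generalising algorithm.
```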
