Lawrence Chan

I do AI Alignment research. Currently at ARC Evals, though I still dabble in grantmaking and interpretability in my spare time. 

I'm also currently on leave from my PhD at UC Berkeley's CHAI. 

Obligatory research billboard website: https://chanlawrence.me/

Sequences

(Lawrence's) Reflections on Research
[Redwood Research] Causal Scrubbing

Comments

I mean, yeah, as your footnote says:

Another simpler but less illuminating way to put this is that higher serial reasoning depth can't be parallelized.[1]

Transformers do get more computation per token on longer sequences, but they also don't get more serial depth, so I'm not sure if this is actually an issue in practice?

 

  1. ^

    [C]ompactly represent f ∘ g (f composed with g) in a way that makes computing f ∘ g more efficient for general choices of f and g.

    As an aside, I actually can't think of any class of interesting functions with this property -- when reading the paper, the closest I could think of are functions on discrete sets (lol), polynomials (but simplifying these is often more expensive than just computing the terms serially), and rational functions (ditto).

     

I finally got around to reading the Mamba paper. H/t Ryan Greenblatt and Vivek Hebbar for helpful comments that got me unstuck. 

TL;DR: authors propose a new deep learning architecture for sequence modeling with scaling laws that match transformers while being much more efficient to sample from.

A brief historical digression

As of ~2017, the three primary ways people had for doing sequence modeling were RNNs, Conv Nets, and Transformers, each with a unique “trick” for handling sequence data: recurrence, 1d convolutions, and self-attention.
 

  • RNNs are easy to sample from: to compute the logit for x_t+1, you only need the most recent hidden state h_t and the last token x_t, which makes sampling both fast and memory efficient. RNNs generate a sequence of length L with O(1) memory and O(L) time. However, they’re super hard to train, because you need to sequentially generate all the hidden states and then (reverse) sequentially calculate the gradients. The way you actually did this is called backpropagation through time (you basically unroll the RNN over time), which requires constructing a graph of depth equal to the sequence length. Not only was this slow, but the graph being so deep caused vanishing/exploding gradients without careful normalization. The strategy that people used was to train on short sequences and finetune on longer ones. That being said, in practice, this meant you couldn’t train on long sequences (>a few hundred tokens) at all. The best LSTMs for modeling raw audio could only handle being trained on ~5s of speech, if you chunk up the data into 25ms segments.
  • Conv Nets had a fixed receptive field size and pattern, so weren’t that suited for long sequence modeling. Also, generating each token takes O(L) time, assuming the receptive field is about the same size as the sequence. But they had significantly more stability (the depth was small, and could be as low as O(log(L))), which meant you could train them a lot easier. (Also, you could use FFT to efficiently compute the conv, meaning it trains one sequence in O(L log(L)) time.) That being said, you still couldn’t make them that big. The most impressive example was DeepMind’s WaveNet, a conv net used to model human speech, which could handle sequences of up to 4800 samples … which was 0.3s of actual speech at 16k samples/second (note that most audio is sampled at 44k samples/second…), and even to get to that amount, they had to really gimp the model’s ability to focus on particular inputs.
  • Transformers are easy to train, can handle variable length sequences, and also allow the model to “decide” which tokens it should pay attention to. In addition to both being parallelizable and having relatively shallow computation graphs (like conv nets), you could do the RNN trick of pretraining on short sequences and then finetuning on longer sequences to save even more compute. Transformers could be trained with comparable sequence length to conv nets but get much better performance; for example, OpenAI’s MuseNet was trained on length-4096 sequences of MIDI files. But as we all know, transformers have the unfortunate downside of being expensive to sample from: it takes O(L) time and O(L) memory to generate a single token (!).

The better performance of transformers over conv nets and their ability to handle variable length data let them win out.

That being said, people have been trying to get around the O(L) time and memory requirements for transformers since basically their inception. For a while, people were super into sparse or linear attention of various kinds, which could reduce the per-token compute/memory requirements to O(log(L)) or O(1).

The what and why of Mamba

If the input -> hidden and hidden -> hidden map for RNNs were linear (h_t+1 = A h_t + B x_t), then it’d be possible to train an entire sequence in parallel. This is because you can just … compose the transformation with itself a bunch (computing A^k for k in 2…L-1), and effectively unroll the graph with the convolutional kernel defined by B, A B, A^2 B, … A^{L-1} B. Not only can you FFT during training to get the O(L log (L)) time of a conv net forward/backward pass (as opposed to O(L^2) for the transformer), you still keep the O(1) sampling time/memory of the RNN!
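To make the recurrence/convolution equivalence concrete, here is a minimal NumPy sketch. It assumes h_0 = 0 and dense random A and B (both are my choices for illustration), and it evaluates the convolution directly in O(L^2) rather than with an FFT, so it only checks that the two views agree, not the speedup.

```python
import numpy as np

def recurrent_states(A, B, xs):
    """Run h_{t+1} = A h_t + B x_t one step at a time, starting from h_0 = 0."""
    h = np.zeros(A.shape[0])
    states = []
    for x in xs:
        h = A @ h + B @ x
        states.append(h)
    return np.stack(states)

def convolutional_states(A, B, xs):
    """Same states from the kernel K_k = A^k B: h_{t+1} = sum_k K_k x_{t-k}."""
    kernel = []
    K = B.copy()
    for _ in range(len(xs)):
        kernel.append(K)
        K = A @ K  # next kernel entry: one more power of A
    states = []
    for t in range(len(xs)):
        states.append(sum(kernel[k] @ xs[t - k] for k in range(t + 1)))
    return np.stack(states)

rng = np.random.default_rng(0)
A = 0.5 * rng.standard_normal((4, 4))  # hypothetical hidden -> hidden map
B = rng.standard_normal((4, 3))        # hypothetical input -> hidden map
xs = rng.standard_normal((8, 3))       # a length-8 input sequence

match = np.allclose(recurrent_states(A, B, xs), convolutional_states(A, B, xs))
```

The kernel depends only on A and B, never on the inputs, which is exactly what breaks once the dynamics become input-dependent.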

The problem is that linear hidden state dynamics are kinda boring. For example, you can’t even learn to update your existing hidden state in a different way if you see particular tokens! And indeed, previous results gave scaling laws that were much worse than transformers in terms of performance/training compute. 

In Mamba, you basically learn a time varying A and B. The parameterization is a bit wonky here, because of historical reasons, but it goes something like: A_t = exp(-\delta(x_t) * exp(A)), B_t = \delta(x_t) B, where \delta(x_t) = softplus(W_\delta x_t), so that the update becomes h_t+1 = A_t h_t + B_t x_t. Also note that in Mamba, they constrain A to be diagonal and W_\delta to be low rank, for computational reasons.

Since exp(A) is diagonal and has only positive entries, we can interpret the model as follows: \delta controls how much to “learn” from the current example — with high \delta, A_t approaches 0 and B_t is large, causing h_t+1 ~= B_t x_t, while with \delta approaching 0, A_t approaches 1 and B_t approaches 0, meaning h_t+1 ~= h_t.
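A quick numerical sketch of the two regimes. This is the simplified parameterization from this comment, not the actual Mamba code: I'm assuming a scalar \delta (so W_\delta is just a vector `w_delta`), a diagonal A stored as a vector `a_log`, and B_t = \delta(x_t) B so that the update reads h_t+1 = A_t h_t + B_t x_t. All names and numbers are made up for illustration.

```python
import numpy as np

def softplus(z):
    # log(1 + exp(z)), computed stably
    return np.logaddexp(0.0, z)

def step_params(x_t, a_log, B, w_delta):
    """Per-token A_t (diagonal, stored as a vector) and B_t for the simplified model."""
    delta = softplus(w_delta @ x_t)        # scalar step size, always > 0
    A_t = np.exp(-delta * np.exp(a_log))   # every entry lies in (0, 1)
    B_t = delta * B
    return A_t, B_t

a_log = np.zeros(3)             # exp(a_log) = 1 on the diagonal
B = np.ones((3, 2))
w_delta = np.array([1.0, 0.0])

# Large delta: the old state is forgotten and the new token dominates.
A_hi, B_hi = step_params(np.array([20.0, 0.0]), a_log, B, w_delta)
# delta near 0: the old state is carried over and the token is ignored.
A_lo, B_lo = step_params(np.array([-20.0, 0.0]), a_log, B, w_delta)
```

Here `A_hi` is essentially 0 while `B_hi` is large, and `A_lo` is essentially 1 while `B_lo` is essentially 0, matching the two limits described above.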

Now, you can’t exactly unroll the hidden state as a convolution with a predefined convolution kernel anymore, but you can still efficiently compute the implied “convolution” using parallel scanning.
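The reason a scan still works is that each step is an affine map h -> a_t * h + b_t, and composing affine maps is associative. Below is a rough sketch of that idea using a Hillis-Steele style scan in NumPy, assuming diagonal dynamics (so `a` and `b` are elementwise) and h = 0 before the first token. The paper's actual implementation is a fused, hardware-aware kernel; this only illustrates why O(log L) depth is possible.

```python
import numpy as np

def sequential_scan(a, b):
    """h_t = a_t * h_{t-1} + b_t, one step at a time, with zero initial state."""
    h = np.zeros_like(b[0])
    states = []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        states.append(h)
    return np.stack(states)

def parallel_scan(a, b):
    """Inclusive scan over the affine maps h -> a_t * h + b_t in O(log L) rounds."""
    a, b = a.copy(), b.copy()
    shift = 1
    while shift < a.shape[0]:
        # Compose each position's map with the map ending `shift` steps earlier:
        # (a2, b2) after (a1, b1) is (a2 * a1, a2 * b1 + b2).
        b[shift:] = a[shift:] * b[:-shift] + b[shift:]
        a[shift:] = a[shift:] * a[:-shift]
        shift *= 2
    return b  # with zero initial state, the offset term is the hidden state

rng = np.random.default_rng(0)
a = rng.uniform(0.0, 1.0, size=(13, 4))  # stand-in for the per-token A_t
b = rng.standard_normal((13, 4))         # stand-in for the per-token B_t x_t
scan_match = np.allclose(sequential_scan(a, b), parallel_scan(a, b))
```

Each round of the while loop is a batch of independent elementwise operations, so on parallel hardware the depth is the number of rounds, not the sequence length.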

Despite being much cheaper to sample from, Mamba matches the pretraining flops efficiency of modern transformers (Transformer++ = the current SOTA open source Transformer with RMSNorm, a better learning rate schedule, and corrected AdamW hyperparameters, etc.). And on a toy induction task, it generalizes to much longer sequences than it was trained on.

So, about those capability externalities from mech interp...

Yes, those are the same induction heads from the Anthropic ICL paper! 

Like the previous Hippo and Hyena papers, they cite mech interp as one of their inspirations, in that it inspired them to think about what the linear hidden state model could not model and how to fix that. I still don’t think mech interp has that much Shapley here (the idea of studying how models perform toy tasks is not new, and the authors don't even use the induction metric or RRT task from the Olsson et al paper), but I'm not super sure on this.

IMO, this line of work is the strongest argument for mech interp (or maybe interp in general) having concrete capabilities externalities. In addition, I think the previous argument Neel and I gave of "these advances are extremely unlikely to improve frontier models" feels substantially weaker now.

Is this a big deal?

I don't know, tbh.

Right, the step I missed on was that P(X|Y) = P(X|Z) for all y, z implies P(X|Z) = P(X). Thanks!

Hm, it sounds like you're claiming that if each pair of x, y, z are pairwise independent conditioned on the third variable, and p(x, y, z) =/= 0 for all x, y, z with nonzero p(x), p(y), p(z), then ?

I tried for a bit to show this but couldn't prove it, let alone the general case without strong invariance. My guess is I'm probably missing something really obvious. 

Probabilities of zero are extremely load-bearing for natural latents in the exact case, and probabilities near zero are load-bearing in the approximate case; if the distribution is zero nowhere, then it can only have a natural latent if the X_i’s are all independent (in which case the trivial variable is a natural latent).

I'm a bit confused why this is the case. It seems like in the theorems, the only thing "near zero" is that D_KL (joint, factorized) < epsilon ~= 0. But you can satisfy this quite easily even with all probabilities > 0.

E.g. the trivial case where all variables are completely independent satisfies all the conditions of your theorem, but can clearly have every pair of probabilities > 0. Even in nontrivial cases, this is pretty easy (e.g. by mixing in irreducible noise with every variable).
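As a sanity check of the trivial case, here is a small sketch with two binary variables. I'm assuming the relevant "factorized" distribution for a trivial latent is just the product of marginals; that is my reading for illustration, not a restatement of the theorem. The tables are made-up numbers.

```python
import numpy as np

def kl_to_product_of_marginals(joint):
    """D_KL(P[X1, X2] || P[X1] P[X2]) in nats, for a strictly positive 2-D table."""
    p1 = joint.sum(axis=1, keepdims=True)
    p2 = joint.sum(axis=0, keepdims=True)
    return float(np.sum(joint * np.log(joint / (p1 * p2))))

# Fully independent, every entry strictly positive.
independent = np.outer([0.3, 0.7], [0.6, 0.4])

# Slightly dependent: the perturbation leaves both marginals unchanged.
nearly = independent + 0.001 * np.array([[1.0, -1.0], [-1.0, 1.0]])

kl_independent = kl_to_product_of_marginals(independent)
kl_nearly = kl_to_product_of_marginals(nearly)
```

`kl_independent` is zero up to floating point error and `kl_nearly` is tiny but positive, even though no entry of either table is anywhere near zero.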

After having spent a few hours playing with Opus, I think "slightly better than best public gpt-4" seems qualitatively correct -- both models tend to get tripped up on the same kinds of tasks, but Opus can inconsistently solve some tasks in my workflow that gpt-4 cannot. 

And yeah, it seems likely that I will also swap to Claude over ChatGPT. 

(I haven't had the chance to read part 3 in detail, and I also haven't checked the proofs except insofar as they seem reasonable on first viewing. Will probably have a lot more thoughts after I've had more time to digest.)

This is very cool work! I like the choice of U-AND task, which seems way more amenable to theoretical study (and is also a much more interesting task) than the absolute value task studied in Anthropic's Toy Models of Superposition (hereafter TMS). It's also nice to study this toy task with asymptotic theoretical analysis as opposed to the standard empirical analysis, thereby allowing you to use a different set of tools than usual. 

The most interesting part of the results was the discussion on the universality of universal calculation -- it reminds me of the interpretations of the lottery ticket hypothesis that claim some parts of the network happen to be randomly initialized to have useful features at the start. 

Some examples that are likely to be boolean-interpretable are bigram-finding circuits and induction heads. However, it's possible that most computations are continuous rather than boolean[31].

My guess is that most computations are indeed closer to continuous than to boolean. While it's possible to construct boolean interpretations of bigram circuits or induction heads, my impression (having not looked at either in detail on real models) is that neither of these cleanly occur inside LMs. For example, induction heads demonstrate a wide variety of other behavior, and even on induction-like tasks, often seem to be implementing induction heuristics that involve some degree of semantic content. 

Consequently, I'd be especially interested in exploring either the universality of universal calculation, or the extension to arithmetic circuits (or other continuous/more continuous models of computation in superposition).


Some nitpicks:

The post would probably be a lot more readable if it were chunked into 4.  The 88 minute read time is pretty scary, and I'd like to comment only on the parts I've read. 

Section 2:

Two reasons why this loss function might be principled are

  1. If there is reason to think of the model as a Gaussian probability model
  2. If we would like our loss function to be basis independent

A big reason to use MSE as opposed to eps-accuracy in the Anthropic model is for optimization purposes (you can't gradient descent cleanly through eps-accuracy). 
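To illustrate the optimization point, here is a small finite-difference sketch, with made-up numbers and eps = 0.1 as an arbitrary choice: eps-accuracy is piecewise constant in the predictions, so its gradient is zero almost everywhere, while MSE gives a usable signal.

```python
import numpy as np

def mse(pred, target):
    return float(np.mean((pred - target) ** 2))

def eps_accuracy(pred, target, eps=0.1):
    """Fraction of outputs within eps of the target."""
    return float(np.mean(np.abs(pred - target) < eps))

def finite_diff(loss, pred, target, i, h=1e-4):
    """Forward-difference estimate of d loss / d pred[i]."""
    bumped = pred.copy()
    bumped[i] += h
    return (loss(bumped, target) - loss(pred, target)) / h

target = np.array([1.0, 0.0, 2.0])
pred = np.array([0.5, 0.4, 2.5])  # every output is well outside the eps band

grad_mse = finite_diff(mse, pred, target, 0)
grad_acc = finite_diff(eps_accuracy, pred, target, 0)
```

`grad_acc` is exactly 0 here (the small bump doesn't move any output into the eps band), while `grad_mse` is about -1/3, pointing toward the target.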

 

Section 5:

4 How relevant are our results to real models?

This should be labeled as section 5. 

 

Appendix to the Appendix:

Here, $f_i$ always denotes the vector.

[..]

with \[\sigma_1\leq n\) with

(TeX compilation failure)

As with the CCS post, I'm reviewing both the paper and the post, though the majority of the review is on the paper. Writing this quickly (total time on review: ~1.5h), but I expect to be willing to defend the points being made --

There's a lot of reasons I like the work. It's an example of:

  1. Actually poking inside a real model. A lot of the mech interp work in early-mid 2022 was focused on getting a deep understanding of toy models trained on algorithmic tasks (at least in this community).[1] There was some effort at Redwood to do neuron-by-neuron replacement, and Nix completed his work on the parentheses balancer concurrently to the IOI results, but insofar as there was mech interp work being done, most of it was on simple models such as the ones featured in Toy Models of Superposition or Modular Arithmetic Grokking (with the primary exception being the Induction Head results from Anthropic, which are substantively weaker outside of very small transformers). 

    This work was one of the first attempts to explain a particular, nontrivial behavior inside of a small but real LM (GPT-2-small). 
  2. Demonstrating the feasibility of patching and circuit-based analysis on language models. I think it's notable that this work doesn't just mechanistically study behavior inside of a language model, it finds a circuit (a small subgraph) implementing the behavior. This is valuable both as a confirmation that patching could be used to find circuits in "real" models, and as evidence that we can find these circuits at all. In turn, this has led to a veritable explosion of "poking LLMs with various kinds of patching/scrubbing to identify subgraphs of particular behaviors", which I think has been pretty valuable on net. 

    Also, as Neel says below, it's important for pedagogical reasons. 
  3. Field-building via example. As with the Modular Arithmetic work by Neel, this was published in ICLR '23 as the joint first mech interp work to be published in a top conference. This helped build a substantial amount of legitimacy and academic interest for the field of mech interp (and broadly, AI x-risk flavored interp in general). 
  4. Demonstrating failure modes and limitations of mech interp techniques. As stated in this post, an earlier version of this work used mean ablation in a way that preserved "information that helped compute the task", which incorrectly suggested that parts of the circuit were unimportant for performance. It's a concrete example of why it's important to think about what exactly you're ablating, and how your ablation serves as a valid test of your hypothesis. 

    This work also directly inspired Causal Scrubbing, which was an attempt to more completely remove information that helps complete a task. 
  5. Validating interp via adversarial inputs. I appreciate the use of adversarial example discovery as a downstream use case of the interp. 

But there's also some reservations I have:

  1. Some of the presentation was misleading. Originally, the paper defined the IOI task as something along the lines of:
    '... sentences like “When John and Mary went to the store, John gave a drink to” should be completed with “Mary”.'
    That is, it did not make it clear that IOI was about assigning a higher logit to "Mary" than to "John", and not about assigning an (absolutely) high logit to "Mary". IIRC, this was only clarified near the end thanks to the effort of one of the critical ICLR reviewers.[2] There were also other strong claims that were significantly ameliorated by the ICLR review process.[3]
  2. The circuit is likely overfit to the metric. I think that the mean logit difference is indeed the correct metric to look at, both because of how the task was defined and also for many use cases in general.[4] However, it's worth noting that this circuit does not hold up well if we replace the mean logit difference with other superficially similar metrics, e.g. the mean absolute logit difference (i.e. E[|logit diff_model - logit diff_subgraph|]). 
  3. The circuit is likely incomplete. Running Causal Scrubbing on the hypothesis suggests that it is importantly incomplete, see for example Alexandre's comment below. The incompleteness of the circuit also suggests some limitations of node-based causal interventions (i.e. activation patching in this case), as previously discussed. That being said, this wasn't something that could've really been known when the experiments were being done for this paper, as Causal Scrubbing was inspired by these results (and thus could not have been used to generate them). 
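To see how the two metrics in point 2 can come apart, here is a toy sketch with made-up per-example logit differences (not numbers from the paper): a subgraph can match the model's average logit difference exactly while being off on every individual example.

```python
import numpy as np

def mean_logit_diff_gap(model_ld, circuit_ld):
    """|E[ld_model] - E[ld_subgraph]|: compares only the dataset averages."""
    return float(abs(np.mean(model_ld) - np.mean(circuit_ld)))

def mean_abs_logit_diff_gap(model_ld, circuit_ld):
    """E[|ld_model - ld_subgraph|]: compares the two example by example."""
    return float(np.mean(np.abs(model_ld - circuit_ld)))

# Hypothetical logit differences on four prompts.
model_ld = np.array([3.0, 3.0, 3.0, 3.0])
circuit_ld = np.array([5.0, 1.0, 4.0, 2.0])

gap_mean = mean_logit_diff_gap(model_ld, circuit_ld)
gap_abs = mean_abs_logit_diff_gap(model_ld, circuit_ld)
```

Here `gap_mean` is 0.0 but `gap_abs` is 1.5: per-example errors cancel in the mean but not in the mean absolute difference.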

And there's two big points that I'm very, very torn on (it's less to do with the work itself than general approaches to/issues with mech interp):

  • Using an algorithmic task (IOI). As this post says, it's an example of "streetlight interpretability" -- looking at cases that are easy, as opposed to where it's useful or realistic. I think it's valuable to do some amount of streetlight interpretability, and it's especially understandable in the case of this work (as one of the earliest mech interp pieces) but I do think that this is a weakness of the work. I also think the fact that seminal works in mech interp used algorithmic tasks may have contributed to a lack of attention paid to soft heuristics/memorization/n-gram statistic-style behavior inside of models, which I think are quite neglected. 
  • Low percent performance recovered. While the headline numbers for completeness/faithfulness are pretty high in terms of percentage, this actually is quite bad in terms of downstream performance. This isn't specific to this work. But, to use causal scrubbing as an example, if random performance on a task is 10 nats of log loss and the model's performance is 2.1 nats, recovering up to 2.6 nats might give the impressive number of 93.7% loss recovered. But in practice, 2.6 nats might be the performance of a model 1/100 or 1/1000 the size of the model we're trying to explain. If the behavior you're trying to explain is present in the most capable models but not in models a generation or two back, this work does not provide significant evidence that explaining it is possible. Again, this isn't specific to this work, but to circuit-style mech interp on real models in general. 
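For reference, the 93.7% figure is the usual "fraction of loss recovered" calculation. A minimal sketch, where the function name is mine and the numbers are the hypothetical ones from the bullet above:

```python
def fraction_loss_recovered(random_loss, model_loss, explained_loss):
    """Share of the random -> model loss gap that the explanation recovers."""
    return (random_loss - explained_loss) / (random_loss - model_loss)

frac = fraction_loss_recovered(random_loss=10.0, model_loss=2.1, explained_loss=2.6)
print(f"{frac:.1%}")  # 93.7%
```

The denominator is dominated by the gap between random and model performance, so a 0.5 nat shortfall looks small in percentage terms even when it corresponds to a much weaker model.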

I think the post itself is pretty good though not exceptional -- I appreciate the explanation of how the task and approach were chosen, as well as the key takeaway that causal interventions can be powerful for mech interp, if they are performed appropriately, but doing them appropriately is challenging.

All said, I'm giving this a 4 on the annual review.

  1. ^

    Note that there was plenty of non-mechanistic interp work that looked at real models and tasks; in fact, the majority of interp has always been on non-toy models and tasks. But mech interp was focused on toy tasks.

  2. ^

    I helped out with rebuttals on this paper, and was honestly impressed by the two critical reviews posted by reviewer jy1a (official review, response to author rebuttal), who among other (imo correct) issues correctly pointed out that the paper was using this incorrect definition of IOI. Notice also how in the rebuttal response, they also point out the issue of using mean logit difference versus mean absolute logit difference. I think that (alongside the RR AT paper) this was one of the reasons I updated to be more in favor of the existing academic peer review system. 

  3. ^

    See e.g. this comment from the Program Chairs:

    The major concerns from the reviewers are the current limited limitation section and a few not-well supported/overstated claims in the paper. Request to the authors: Please update the paper to have a stronger and more critical limitation discussion, as well as substantially change the writing to justify all claims/assumptions (or not to overstate claims) in order to reflect reviewers’ comments.

  4. ^

    The main reason is that we don't really care about 'noise' when explaining good performance, e.g. from the Causal Scrubbing appendix:

    Suppose that one of the drivers of the model’s behavior is noise: trying to capture the full distribution would require us to explain what causes the noise. For example, you’d have to explain the behavior of a randomly initialized model despite the model doing ‘nothing interesting’.

    That being said, this claim depends greatly on the implied downstream use case of interp. E.g. if the goal is to understand failure modes, then explaining just the success is insufficient. 
