Interesting post, thanks for writing it!
I think that the QK section somewhat under-emphasises the importance of the softmax. My intuition is that models rarely care about as precise a task as counting the number of pairs of matching query-key features at each pair of token positions, and that instead softmax is more of an "argmax-like" function that finds a handful of important token positions (though I have not empirically tested this, and would love to be proven wrong!). This enables much cheaper and more efficient solutions, since you just need the correct answer to be the argmax-ish.
For example, ignoring floating point precision, you can implement a duplicate token head with $d_\text{head}=2$ and arbitrarily high $d_\text{vocab}$. If there are $n$ vocab elements, map the $k$th query and key to the point $k/n$ of the way round the unit circle. The dot product is maximised when they are equal.
If you further want the head to look at a resting position unless the duplicate token is there, you can increase , and have a dedicated BOS dimension with a score of , so you only get a higher score for a perfect match. And then make the softmax temperature super low so it's an argmax.
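A minimal NumPy sketch of this construction, in case it helps make it concrete. The head size of 3 (two circle dimensions plus one BOS dimension), the BOS score (halfway between a perfect match and the closest near-miss), the temperature, and the names `circle_points` and `duplicate_token_attention` are all assumptions for illustration, not values from the comment.

```python
import numpy as np

def circle_points(token_ids, n_vocab):
    """Map token k to the point k/n_vocab of the way round the unit circle."""
    angles = 2.0 * np.pi * np.asarray(token_ids) / n_vocab
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1)

def duplicate_token_attention(token_ids, n_vocab, temperature=1e-4):
    """Attention over [BOS, tok_0, ..., tok_{T-1}] for each query token."""
    pts = circle_points(token_ids, n_vocab)                       # (T, 2)
    T = len(token_ids)
    # Assumed BOS score: strictly between a perfect match (1) and the best near-miss.
    best_near_miss = np.cos(2.0 * np.pi / n_vocab)
    bos_score = (1.0 + best_near_miss) / 2.0

    # Third dimension is the dedicated BOS dimension.
    queries = np.concatenate([pts, np.ones((T, 1))], axis=1)      # (T, 3)
    token_keys = np.concatenate([pts, np.zeros((T, 1))], axis=1)  # (T, 3)
    bos_key = np.array([[0.0, 0.0, bos_score]])                   # (1, 3)
    keys = np.concatenate([bos_key, token_keys], axis=0)          # (T+1, 3)

    scores = queries @ keys.T                                     # (T, T+1)
    # Causal mask: query t sees BOS and strictly earlier tokens only.
    visible = np.tril(np.ones((T, T + 1), dtype=bool))
    scores = np.where(visible, scores, -np.inf)

    # Low-temperature softmax, which behaves like an argmax here.
    z = scores / temperature
    z -= z.max(axis=1, keepdims=True)
    weights = np.exp(z)
    return weights / weights.sum(axis=1, keepdims=True)

tokens = [3, 7, 3, 12, 7]
attn = duplicate_token_attention(tokens, n_vocab=50)
# Column 0 is BOS; column j + 1 is token position j.
attended = attn.argmax(axis=1)
print(attended)
```

In this toy run, positions whose token has appeared before attend to the earlier copy, and all other positions rest on BOS. The gap between a perfect match and the nearest miss shrinks as the vocabulary grows, which is where the "ignoring floating point precision" caveat bites.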
Thanks for the comment!
In more detail:
In our discussion of softmax (buried in part 1 of section 4), we argue that our story makes the most sense precisely when the temperature is very low, in which case we only attend to the key(s) that satisfy the most skip feature-bigrams. Also, when features are very sparse, the number of skip feature bigrams present in one query-key pair is almost always 0 or 1, and we aren't trying to super precisely track whether it's, say, 34 or 35.
I agree that if softmax is just being an argmax, then one implication is that we don't need error terms to be , instead, they can just be somewhat less than 1. However, at least in our general framework, this doesn't help us beyond changing the log factor in the tilde inside ). There still will be some log factor because we require the average error to be to prevent the worst-case error being greater than 1. Also, we may want to be able to accept 'ties' in which a small number of token positions are attended to together. To achieve this (assuming that at most one SFB is present for each QK pair for simplicity) we'd want the variation in the values which should be 1 to be much smaller than the gap between the smallest value which should be 1 and the largest value which should be 0.
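To illustrate the point about ties, here is a small sketch with made-up scores: three keys that "should be 1" (with a little variation between them) and two that "should be 0". The scores and temperatures are assumptions chosen only to show the two regimes.

```python
import numpy as np

def softmax(scores, temperature):
    """Numerically stable softmax at a given temperature."""
    z = np.asarray(scores, dtype=float) / temperature
    z -= z.max()
    w = np.exp(z)
    return w / w.sum()

# Three near-tied keys (variation 0.02) and two that should be ignored (gap ~0.7).
scores = np.array([1.00, 0.99, 0.98, 0.30, 0.10])

# Temperature between the variation and the gap: ties share attention roughly evenly.
tied = softmax(scores, temperature=0.1)

# Temperature far below the variation: the tie is broken and one key takes everything.
sharp = softmax(scores, temperature=0.001)

print(np.round(tied, 3))
print(np.round(sharp, 3))
```

Accepting ties only works if there is a temperature that is large compared to the variation among the "should be 1" scores and small compared to the gap down to the "should be 0" scores, which is why the former needs to be much smaller than the latter.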
A few comments about your toy example:
To tell a general story, I'd like to replace the word 'token' with 'feature' in your construction. In particular, I might want to express what the attention head does using the same features as the MLP. The choice of using tokens in your example is special, because the set of features {this is token 1, this is token 2, ...} are mutually exclusive, but once I allow for the possibility that multiple features can be present (for example if I want to talk in terms of features involved in MLP computation), your construction breaks. To avoid this problem, I want the maximum dot product between f-vectors to be at most 1/(the maximum number of features that can be present at once). If I allow several features to be present at once, this starts to look like an $\epsilon$-orthogonal basis again. I guess you could imagine a case where the residual stream is divided into subspaces, and inside each subspace is a set of mutually exclusive features (à la tegum products of TMS). In your picture, there would need to be a 2d subspace allocated to the 'which token' features anyway. This tegum geometry would have to be specifically learned: these orthogonal subspaces do not happen generically, and we don't see a good reason to think that they are likely to be learned by default for reasons not to do with the attention head that uses them, even in the case that there are these sets of mutually exclusive features.
It takes us more than 2 dimensions, but in our framework, it is possible to do a similar construction to yours in dimensions assuming random token vectors (ie without the need for any specific learned structure in the embeddings for this task): simply replace the rescaled projection matrix with where is and is a projection matrix to a -dimensional subspace. Now, with high probability, each vector has a larger dot product with its own projection than another vector's projection (we need to be this large to ensure that projected vectors all have a similar length). Then use the same construction as in our post, and turn the softmax temperature down to zero.
Someone suggested this comment was inscrutable so here's a summary:
I don't think that how argmax-y softmax is being is a crux between us - we think our picture makes the most sense when softmax acts like argmax or top-k so we hope you're right that softmax is argmax-ish. Instead, I think the property that enables your efficient solution is that the set of features 'this token is token (i)' is mutually exclusive, i.e. only one of these features can activate on an input at once. That means that in your example you don't have to worry about how to recover feature values when multiple features are present at once. For more general tasks implemented by an attention head, we do need to worry about what happens when multiple features are present at the same time, and then we need the f-vectors to form a nearly orthogonal basis, and your construction becomes a special case of ours, I think.
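A quick sketch of why mutual exclusivity matters for the unit-circle construction. The feature count (50) and the choice of active features (3 and 5) are arbitrary assumptions for illustration.

```python
import numpy as np

def circle_points(ids, n):
    """Place feature k at the point k/n of the way round the unit circle."""
    angles = 2.0 * np.pi * np.asarray(ids) / n
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1)

n_features = 50
all_keys = circle_points(np.arange(n_features), n_features)  # (50, 2)

# One feature present: the best-matching key is the feature itself.
single = circle_points([3], n_features).sum(axis=0)
best_for_single = int(np.argmax(all_keys @ single))

# Two features present at once: the summed vector points between them,
# so the best-matching key is a feature that is not present at all.
superposed = circle_points([3, 5], n_features).sum(axis=0)
best_for_pair = int(np.argmax(all_keys @ superposed))

print(best_for_single, best_for_pair)
```

With features 3 and 5 both active, the argmax lands on feature 4. Nearby points on the circle have dot products close to 1, which is fine when only one feature can be on but fails once features can co-occur.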
Haven't read everything yet, but that seems like excellent work. In particular, I think this general research avenue is extremely well-motivated.
Figuring out how to efficiently implement computations on the substrate of NNs had always seemed like a neglected interpretability approach to me. Intuitively, there are likely some methods of encoding programs into matrix multiplication which are strictly ground-truth better than any other encoding methods. Hence, inasmuch as what the SGD is doing is writing efficient programs on the NN substrate, it is likely doing so by making use of those better methods. And so nailing down the "principles of good programming" on the NN substrate should yield major insights regarding how the naturally-grown NN circuits are shaped as well.
This post seems to be a solid step in that direction!
(I haven't had the chance to read part 3 in detail, and I also haven't checked the proofs except insofar as they seem reasonable on first viewing. Will probably have a lot more thoughts after I've had more time to digest.)
This is very cool work! I like the choice of U-AND task, which seems way more amenable to theoretical study (and is also a much more interesting task) than the absolute value task studied in Anthropic's Toy Models of Superposition (hereafter TMS). It's also nice to study this toy task with asymptotic theoretical analysis as opposed to the standard empirical analysis, thereby allowing you to use a different set of tools than usual.
The most interesting part of the results was the discussion on the universality of universal calculation -- it reminds me of the interpretations of the lottery ticket hypothesis that claim some parts of the network happen to be randomly initialized to have useful features at the start.
Some examples that are likely to be boolean-interpretable are bigram-finding circuits and induction heads. However, it's possible that most computations are continuous rather than boolean[31].
My guess is that most computations are indeed closer to continuous than to boolean. While it's possible to construct boolean interpretations of bigram circuits or induction heads, my impression (having not looked at either in detail on real models) is that neither of these cleanly occur inside LMs. For example, induction heads demonstrate a wide variety of other behavior, and even on induction-like tasks, often seem to be implementing induction heuristics that involve some degree of semantic content.
Consequently, I'd be especially interested in exploring either the universality of universal calculation, or the extension to arithmetic circuits (or other continuous/more continuous models of computation in superposition).
Some nitpicks:
The post would probably be a lot more readable if it were chunked into 4. The 88 minute read time is pretty scary, and I'd like to comment only on the parts I've read.
Section 2:
Two reasons why this loss function might be principled are
- If there is reason to think of the model as a Gaussian probability model
- If we would like our loss function to be basis independent
A big reason to use MSE as opposed to eps-accuracy in the Anthropic model is for optimization purposes (you can't gradient descent cleanly through eps-accuracy).
Section 5:
4 How relevant are our results to real models?
This should be labeled as section 5.
Appendix to the Appendix:
Here, $f_i$ always denotes the vector.
[..]
with \[\sigma_1\leq n\) with
(TeX compilation failure)
Thanks for the kind feedback!
I'd be especially interested in exploring either the universality of universal calculation
Do you mean the thing we call genericity in the further work section? If so, we have some preliminary theoretical and experimental evidence that genericity of U-AND is true. We trained networks on the U-AND task and the analogous U-XOR task, with a narrow 1-layer MLP and looked at the size of the interference terms after training with a suitable loss function. Then, we reinitialised and froze the first layer of weights and biases, allowing the network only to learn the linear readoff directions, and found that the error terms were comparably small in both cases.
This figure is the size of the errors for (which is pretty small) for readoffs which should be zero in blue and one in yellow (we want all these errors to be close to zero).
This suggests that the AND/XOR directions were $\epsilon$-linearly readoffable at initialisation, but the evidence at this stage is weak because we don't have a good sense yet of what a reasonable value of $\epsilon$ is for considering the task to have been learned correctly: to answer this we want to fiddle around with loss functions and training for longer. For context, an affine readoff (linear + bias) directly on the inputs can read off with , which has an error of . This is larger than all but the largest errors here, and you can’t do anything like this for XOR with affine readoff.
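For readers who want to poke at the frozen-first-layer idea, here is a rough sketch. It is not our experimental setup: the sizes (6 inputs, 32 neurons), the exhaustive enumeration of boolean inputs, and the least-squares fit (in place of gradient training with a chosen loss) are all assumptions made to keep it tiny.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
n_inputs, n_neurons = 6, 32  # assumed toy sizes, not taken from the experiment

# Every boolean input pattern (inputs not in superposition).
X = np.array(list(itertools.product([0.0, 1.0], repeat=n_inputs)))  # (64, 6)
pairs = list(itertools.combinations(range(n_inputs), 2))             # 15 pairs
and_targets = np.stack([X[:, i] * X[:, j] for i, j in pairs], axis=1)
xor_targets = np.stack([np.abs(X[:, i] - X[:, j]) for i, j in pairs], axis=1)

# Frozen first layer: random Gaussian weights and biases followed by a ReLU.
W = rng.standard_normal((n_inputs, n_neurons))
b = rng.standard_normal(n_neurons)
hidden = np.maximum(X @ W + b, 0.0)                                   # (64, 32)

def worst_readoff_error(features, targets):
    """Fit a linear readoff by least squares; return the largest absolute error."""
    readoff, *_ = np.linalg.lstsq(features, targets, rcond=None)
    return float(np.max(np.abs(features @ readoff - targets)))

# Baseline: affine readoff (linear + bias) directly on the inputs.
affine_inputs = np.concatenate([X, np.ones((len(X), 1))], axis=1)

results = {
    "AND, frozen random layer": worst_readoff_error(hidden, and_targets),
    "XOR, frozen random layer": worst_readoff_error(hidden, xor_targets),
    "AND, affine on inputs": worst_readoff_error(affine_inputs, and_targets),
    "XOR, affine on inputs": worst_readoff_error(affine_inputs, xor_targets),
}
for name, err in results.items():
    print(f"{name}: worst-case error {err:.3f}")
```

On exhaustive inputs, the least-squares affine baseline in this sketch comes out at a worst-case error of 0.25 for AND and 0.5 for XOR, which gives a reference point for the frozen-layer numbers it prints. Note that least squares minimises squared error rather than worst-case error, so this is only a crude proxy for the readoff errors discussed above.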
After we did this, Kaarel came up with an argument that networks randomly initialised with weights from a standard Gaussian and zero bias solve U-AND with inputs not in superposition (although it probably can be generalised to the superposition case) for suitable readoffs. To sketch the idea:
Let be the vector of weights from the th input to the neurons. Then consider the linear readoff vector with th component given by:
where is the indicator function. There are 4 free parameters here, which are set by 4 constraints given by requiring that the expectation of this vector dotted with the activation vector has the correct value in the 4 cases . In the limit of large the value of the dot product will be very close to its expectation and we are done. There are a bunch of details to work out here and, as with the experiments, we aren't 100% sure the details all work out, but we wanted to share these new results since you asked.
A big reason to use MSE as opposed to eps-accuracy in the Anthropic model is for optimization purposes (you can't gradient descent cleanly through eps-accuracy).
We've suggested that perhaps it would be more principled to use something like $L^p$ loss for $p$ larger than 2, as this is closer to $\epsilon$-accuracy. It's worth mentioning that we are currently finding that the best loss function for the task seems to be something like $L^p$ with extra weighting on the target values that should be 1. We do this to avoid the problem that if the inputs are sparse, then the ANDs are sparse too, and the model can get good loss on $L^p$ (for low $p$) by sending all inputs to the zero vector. Once we weight the ones appropriately, we find that lower values of $p$ may be better for training dynamics.
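A sketch of the kind of loss we mean, assuming an $L^p$-style loss with a multiplicative weight on targets that should be one. The function name `weighted_lp_loss`, the inverse-frequency weighting, and the example values are illustrative assumptions, not the exact loss we train with.

```python
import numpy as np

def weighted_lp_loss(pred, target, p=4.0, one_weight=1.0):
    """Mean of |pred - target|^p, with extra weight on targets equal to 1."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    weights = np.where(target == 1.0, one_weight, 1.0)
    return float(np.mean(weights * np.abs(pred - target) ** p))

# Sparse targets: only one of 100 ANDs is on.
target = np.zeros(100)
target[0] = 1.0
all_zero_pred = np.zeros(100)  # the degenerate "send everything to zero" solution

# Without reweighting, the degenerate solution already looks good.
unweighted = weighted_lp_loss(all_zero_pred, target, p=2.0)

# Assumed scheme: weight the ones by the ratio of zeros to ones.
one_weight = (target == 0).sum() / (target == 1).sum()
reweighted = weighted_lp_loss(all_zero_pred, target, p=2.0, one_weight=one_weight)

print(unweighted, reweighted)
```

In this example the all-zero prediction gets a loss of 0.01 without reweighting and 0.99 with it, so the degenerate solution stops being attractive once the ones are weighted up.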
or the extension to arithmetic circuits (or other continuous/more continuous models of computation in superposition)
We agree and are keen to look into that!
(TeX compilation failure)
Thanks - fixed.
How are you setting when ? I might be totally misunderstanding something but at - feels like you need to push up towards like 2k to get something reasonable? (and the argument in 1.4 for using clearly doesn't hold here because it's not greater than for this range of values).