Same person as nostalgebraist2point0, but now I have my account back.
This post introduces a model, and shows that it behaves sort of like a noisy version of gradient descent.
However, the term "stochastic gradient descent" does not just mean "gradient descent with noise." It refers more specifically to mini-batch gradient descent. (See e.g. Wikipedia.)
In mini-batch gradient descent, the "true" fitness[1] function is the expectation of some per-example function $\ell(x, \theta)$ over a data distribution $\mathcal{D}$, i.e. $L(\theta) = \mathbb{E}_{x \sim \mathcal{D}}[\ell(x, \theta)]$. But you never have access to this function or its gradient. Instead, you draw a finite sample from $\mathcal{D}$, compute the mean of $\nabla_\theta \ell$ over the sample, and take a step in this direction. The noise comes from the variance of the finite-sample mean as an estimator of the expectation.
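For concreteness, here's a minimal sketch of one mini-batch SGD step (the toy loss, data, and hyperparameters are illustrative placeholders of my own, not anything from the post):

```python
import numpy as np

rng = np.random.default_rng(0)

def per_example_grad(theta, x):
    # Toy per-example loss: squared error between theta and the example x,
    # so the per-example gradient is 2 * (theta - x).
    return 2.0 * (theta - x)

def sgd_step(theta, data, batch_size=32, lr=0.01):
    # Draw a finite sample ("mini-batch") from the data...
    batch = data[rng.choice(len(data), size=batch_size, replace=False)]
    # ...and use the sample mean of the per-example gradients as a noisy
    # estimate of the true (expected) gradient.
    grad_estimate = np.mean([per_example_grad(theta, x) for x in batch], axis=0)
    return theta - lr * grad_estimate

# The noise in each step is exactly the estimation error of the sample mean.
data = rng.normal(size=(10_000, 2))   # stand-in for samples from the data distribution
theta = np.array([1.0, -1.0])
theta = sgd_step(theta, data)
```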
The model here is quite different. There is no "data distribution," and the true fitness function is not an expectation value which we could noisily estimate with sampling. The noise here comes not from a noisy estimate of the gradient, but from a prescribed stochastic relationship between the true gradient and the next step.
I don't think the model in this post behaves like mini-batch gradient descent. Consider a case where we're doing SGD on a vector $\theta$, and two of its components, $\theta_1$ and $\theta_2$, have per-example gradients that are always perfectly correlated with one another.
If you like, you can think of the per-example gradient as sampling a single number $z$ from a distribution with mean 0, and setting the $\theta_1$ and $\theta_2$ components to $c_1 z$ and $c_2 z$ respectively, for some positive constants $c_1$ and $c_2$.
When we sample a mini-batch and average over it, these components are simply $c_1 \bar{z}$ and $c_2 \bar{z}$, where $\bar{z}$ is the average of $z$ over the mini-batch. So the perfect correlation carries over to the mini-batch gradient, and thus to the SGD step. If SGD increases $\theta_1$, it will always increase $\theta_2$ alongside it (etc.)
However, applying the model from this post to the same case, the noise enters each component of the step independently, rather than through a shared random variable like $\bar{z}$.
Thus, the $\theta_1$ and $\theta_2$ components of the step will be uncorrelated, rather than perfectly correlated as in SGD.
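Here's a quick simulation of the contrast (the constants, batch size, and the "independent per-coordinate noise" stand-in for the post's model are my own illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(0)
c1, c2 = 0.5, 2.0          # the positive constants from the example above (arbitrary values)
batch_size, n_steps = 32, 10_000

# Mini-batch SGD: each per-example gradient is (c1 * z, c2 * z) for a single z,
# so the averaged mini-batch gradient is (c1 * zbar, c2 * zbar).
zbar = rng.normal(size=(n_steps, batch_size)).mean(axis=1)
sgd_steps = np.stack([c1 * zbar, c2 * zbar], axis=1)

# Stand-in for the post's model: the true gradient here is zero (z has mean 0),
# with noise added to each component independently.
model_steps = rng.normal(size=(n_steps, 2))

print(np.corrcoef(sgd_steps[:, 0], sgd_steps[:, 1])[0, 1])      # ~1.0
print(np.corrcoef(model_steps[:, 0], model_steps[:, 1])[0, 1])  # ~0.0
```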
Some other comments:
I'm using this term here for consistency with the post, though I call it into question later on. "Loss function" or "cost function" would be more standard in SGD.
There is no such thing as a per-example gradient in the model. I'm assuming the "true gradient" from SGD corresponds to the gradient of the fitness function in the model, since the intended analogy seems to be "the model steps look like ascending the fitness function plus noise, just like SGD steps look like descending the true loss function plus noise."
Very interesting! Some thoughts:
Is there a clear motivation for choosing the MLP activations as the autoencoder target? There are other choices of target that seem more intuitive to me (as I'll explain below), namely the MLP's output (i.e. the update it writes to the residual stream), or the residual stream state just after that update has been added.
In principle, we could also imagine using the "logit versions" of each of these as the target, i.e. their projections onto the logits through the unembedding.
(In practice, the "logit versions" might be prohibitively expensive because the vocab is larger than other dimensions in the problem. But it's worth thinking through what might happen if we did autoencode these quantities.)
At the outset, our goal is something like "understand what the MLP is doing." But that could really mean one of two things: (1) understanding the effect the MLP has on the network's output, or (2) understanding the MLP's internal computation, e.g. what its individual neurons are doing.
The feature decomposition in the paper provides a potentially satisfying answer for (1). If someone runs the network on a particular input, and asks you to explain what the MLP was doing during the forward pass, you can say something like:
Here is a list of features that were activated by the input. Each of these features is active because of a particular, intuitive/"interpretable" property of the input.
Each of these features has an effect on the logits (its logit weights), which is intuitive/"interpretable" on the basis of the input properties that cause it to be active.
The net effect of the MLP on the network's output (i.e. the logits) is approximately[2] a weighted sum over these effects, weighted by how active the features were. So if you understand the list of features, you understand the effect of the MLP on the output.
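In code, that last point looks roughly like this (a sketch with my own toy shapes and variable names, not the paper's code; layer norm nonlinearity ignored, as in footnote 2):

```python
import torch

torch.manual_seed(0)
n_features, d_mlp, d_model, n_vocab = 512, 128, 64, 1000   # toy sizes, not the paper's

# Stand-ins for the learned weights (random here, just to fix the shapes):
W_dec = torch.randn(n_features, d_mlp)   # dictionary (decoder) vectors
W_out = torch.randn(d_mlp, d_model)      # MLP output projection
W_U   = torch.randn(d_model, n_vocab)    # unembedding
feature_acts = torch.relu(torch.randn(n_features))  # feature activations on one token

# Per-feature "logit weights": the logit-space direction each feature writes.
logit_weights = W_dec @ W_out @ W_U                  # (n_features, n_vocab)

# The MLP's net effect on the logits is approximately a weighted sum of these
# directions, weighted by feature activations (up to reconstruction error and
# the layer norm nonlinearity).
approx_logit_update = feature_acts @ logit_weights   # (n_vocab,)
```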
However, if this person now asks you to explain what MLP neuron A/neurons/472 was doing during the forward pass, you may not be able to provide a satisfying answer, even with the feature decomposition in hand.
The story above appealed to the interpretability of each feature's logit weights. To explain individual neuron activations in the same manner, we'd need the dictionary weights to be similarly interpretable. The paper doesn't directly address this question (I think?), but I expect that the matrix of dictionary weights is fairly dense[3] and thus difficult to interpret, with each neuron being a long and complicated sum over many apparently unrelated features. So, even if we understand all the features, we still don't understand how they combine to "cause" any particular neuron's activation.
Is this a bad thing? I don't think so!
An MLP sub-block in a transformer only affects the function computed by the transformer through the update it adds to the residual stream. If we understand this update, then we fully understand "what the MLP is doing" as a component of that larger computation. The activations are a sort of "epiphenomenon" or "implementation detail"; any information in the activations that is not in the update is inaccessible to the rest of the network, and has no effect on the function it computes[4].
From this perspective, the activations don't seem like the right target for a feature decomposition. The residual stream update seems more appropriate, since it's what the rest of the network can actually see[5].
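(A tiny linear-algebra illustration of the "inaccessible information" point, with toy shapes of my own choosing:)

```python
import numpy as np

rng = np.random.default_rng(0)
d_mlp, d_model = 8, 4                      # MLP wider than the residual stream
W_out = rng.normal(size=(d_mlp, d_model))  # MLP output projection (toy, random)

acts = rng.normal(size=d_mlp)              # some MLP activation vector

# Any direction in the (left) null space of W_out is invisible downstream:
u, s, vt = np.linalg.svd(W_out, full_matrices=True)
null_direction = u[:, d_model]             # orthogonal to everything W_out can "see"

update_before = acts @ W_out
update_after = (acts + 10.0 * null_direction) @ W_out

# The residual stream update (and hence the rest of the network) is unchanged.
print(np.allclose(update_before, update_after))  # True
```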
In the paper, the MLP that is decomposed into features is the last sub-block in the network.
Because this MLP is the last sub-block, the "residual stream update" is really just an update to the logits. There are no indirect paths going through later layers, only the direct path.
Note also that the MLP activations have a much more direct relationship with this logit update than they do with the inputs. If we ignore the nonlinear part of the layernorm, the logit update is just a (low-rank) linear transformation of the activations. The input, on the other hand, is related to the activations in a much more complex and distant manner, involving several nonlinearities and indeed most of the network.
With this in mind, consider a feature like A/1/2357. Is it fundamentally a detector of a certain kind of input, which happens to have some particular effect on the logits? Or is it fundamentally a direction in logit-update space, which happens to get deployed on certain kinds of inputs?
The paper implicitly takes the former view: the features are fundamentally a sparse and interpretable decomposition of the inputs, which also have interpretable effects on the logits as a derived consequence of the relationship between inputs and correct language-modeling predictions.
(For instance, although the automated interpretability experiments involved both input and logit information[6], the presentation of these results in the paper and the web app (e.g. the "Autointerp" and its score) focuses on the relationship between features and inputs, not features and outputs.)
Yet, the second view -- in which features are fundamentally directions in logit-update space -- seems closer to the way the autoencoder works mechanistically.
The features are a decomposition of activations, and activations in the final MLP are approximately equivalent to logit updates. So, the features found by the autoencoder are, to a good approximation, directions in logit-update space.
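Concretely, under this view the natural way to "read off" what a feature is would be something like the sketch below (made-up variable names; W_dec, W_out, W_U as in the earlier snippet, plus a hypothetical tokenizer):

```python
import torch

def feature_as_logit_direction(i, W_dec, W_out, W_U, tokenizer, k=10):
    """Interpret feature i as a direction in logit-update space:
    the tokens whose logits it pushes up the most when it fires."""
    direction = W_dec[i] @ W_out @ W_U          # (n_vocab,) logit weights of feature i
    top_vals, top_ids = direction.topk(k)
    return [(tokenizer.decode([int(t)]), v.item()) for t, v in zip(top_ids, top_vals)]

# e.g. for a feature like A/1/1078, we'd hope to see tokens such as
# ' magnetic', ' coupling', ' scattering' near the top of this list.
```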
To illustrate the value of this perspective, consider the token-in-context features. When viewed as detectors for specific kinds of inputs, these can seem mysterious or surprising:
But why do we see hundreds of different features for "the" (such as "the" in Physics, as distinct from "the" in mathematics)? We also observe this for other common words (e.g. "a", "of"), and for punctuation like periods. These features are not what we expected to find when we set out to investigate one-layer models!
An example of such a feature is A/1/1078, which Claude glosses as
The [feature] fires on the word "the", especially in materials science writing.
This is, indeed, a weird-sounding category to delineate in the space of inputs.
But now consider this feature as a direction in logit-update space, whose properties as a "detector" in input space derive from its logit weights -- it "detects" exactly those inputs on which the MLP wants to move the logits in this particular, rarely-deployed direction.
The question "when is this feature active?" has a simple, non-mysterious answer in terms of the logit updates it causes: "this feature is active when the MLP wants to increase the logit for the particular tokens ' magnetic', ' coupling', 'electron', ' scattering' (etc.)"
Which inputs correspond to logit updates in this direction? One can imagine multiple scenarios in which this update would be appropriate. But if we go looking for inputs on which the update was actually deployed, our search will be weighted by how common each of those scenarios is in the data.
The Pile contains all of Arxiv, so it contains a lot of materials science papers. And these papers contain a lot of "materials science noun phrases": phrases that start with "the," followed by a word like "magnetic" or "coupling," and possibly more words.
This is not necessarily the only input pattern "detected" by this feature[8] -- because it is not necessarily the only case where this update direction is appropriate -- but it is an especially common one, so it appears at a glance to be "the thing the feature is 'detecting.' " Further inspection of the feature's activations might complicate this story, making the feature seem like a "detector" of an even weirder and more non-obvious category -- and thus even more mysterious from the "detector" perspective. Yet these traits are non-mysterious, and perhaps even predictable in advance, from the "direction in logit-update space" perspective.
That's a lot of words. What does it all imply? Does it matter?
I'm not sure.
The fact that other teams have gotten similar-looking results, while (1) interpreting inner layers from real, deep LMs and (2) interpreting the residual stream rather than the MLP activations, suggests that these results are not a quirk of the experimental setup in the paper.
But in deep networks, eventually the idea that "features are just logit directions" has to break down somewhere, because inner MLPs are not only working through the direct path. Maybe there is some principled way to get the autoencoder to split things up into "direct-path features" (with interpretable logit weights) and "indirect-path features" (with non-interpretable logit weights)? But IDK if that's even desirable.
We could compute this exactly, or we could use a linear approximation that ignores the layer norm rescaling. I'm not sure one choice is better motivated than the other, and the difference is presumably small.
because of the (hopefully small) nonlinear effect of the layer norm
There's a figure in the paper showing dictionary weights from one feature (A/1/3450) to all neurons. It has many large values, both positive and negative. I'm imagining that this case is typical, so that the matrix of dictionary vectors looks like a bunch of these dense vectors stacked together.
It's possible that slicing this matrix along the other axis (i.e. weights from all features to a single neuron) might reveal more readily interpretable structure -- and I'm curious to know whether that's the case! -- but it seems a priori unlikely based on the evidence available in the paper.
However, while the "implementation details" of the MLP don't affect the function computed during inference, they do affect the training dynamics. Cf. the distinctive training dynamics of deep linear networks, even though they are equivalent to single linear layers during inference.
If the MLP is wider than the residual stream, as it is in real transformers, then the MLP output weights have a nontrivial null space, and thus some of the information in the activation vector gets discarded when the update is computed.
A feature decomposition of the activations has to explain this "irrelevant" structure along with the "relevant" stuff that gets handed onwards.
Claude was given logit information when asked to describe inputs on which a feature is active; also, in a separate experiment, it was asked to predict parts of the logit update.
Caveat: L2 reconstruction loss on logit updates != L2 reconstruction loss on activations, and one might not even be a close approximation to the other.
That said, I have a hunch they will give similar results in practice, based on a vague intuition that the training loss will tend to encourage the neurons to have approximately equal "importance" in terms of average impacts on the logits.
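To illustrate why the two losses can come apart (toy random matrices, purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
d_mlp, n_vocab = 64, 1000
W = rng.normal(size=(d_mlp, n_vocab))        # activations -> logit update (toy, random)

acts = rng.normal(size=d_mlp)
recon = acts + 0.1 * rng.normal(size=d_mlp)  # an imperfect reconstruction

err = acts - recon
print(np.sum(err**2))            # L2 reconstruction loss in activation space
print(np.sum((err @ W)**2))      # L2 reconstruction loss in logit-update space

# The two losses weight the error differently: the logit-space loss only "sees"
# the component of the error that W maps to large logit changes, so the same
# activation-space error can score very differently depending on its direction.
```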
At a glance, it seems to also activate sometimes on tokens like " each" or " with" in similar contexts.
Nice catch, thank you!
I re-ran some of the models with a prompt ending in "I believe the best answer is (", rather than just "(" as before.
Some of the numbers change a little bit. But only a little, and the magnitude and direction of the change is inconsistent across models even at the same size. For instance:
davinci's rate of agreement w/ the user is now 56.7% (CI 56.0% - 57.5%), up slightly from the original 53.7% (CI 51.2% - 56.4%).
davinci-002's rate of agreement w/ the user is now 52.6% (CI 52.3% - 53.0%), down slightly from the original 53.5% (CI 51.3% - 55.8%).

Oh, interesting! You are right that I measured the average probability -- that seemed closer to "how often will the model exhibit the behavior during sampling," which is what we care about.
I updated the colab with some code to measure the % of cases where the probability on the sycophantic answer exceeds the probability of the non-sycophantic answer (you can turn this on by passing example_statistic='matching_more_likely' to various functions).
And I added a new appendix showing results using this statistic instead.
The bottom line: results with this statistic are very similar to those I originally obtained with average probabilities. So, this doesn't explain the difference.
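(For anyone curious, the new statistic is essentially the following -- a simplified sketch of the idea, not the actual colab code:)

```python
import numpy as np

def average_probability(p_sycophantic):
    # Original statistic: mean probability assigned to the sycophantic answer.
    return np.mean(p_sycophantic)

def matching_more_likely(p_sycophantic, p_non_sycophantic):
    # New statistic: fraction of examples where the sycophantic answer is
    # strictly more likely than the non-sycophantic one.
    return np.mean(np.asarray(p_sycophantic) > np.asarray(p_non_sycophantic))
```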
(Edited to remove an image that failed to embed.)
To check this, you'd want to look at a model trained with untied embeddings. Sadly, all the ones I'm aware of (Eleuther's Pythia, and my interpretability-friendly models) were trained on the GPT-NeoX tokenizer or variants, which doesn't seem to have stupid tokens in the same way.
GPT-J uses the GPT-2 tokenizer and has untied embeddings.
This post provides a valuable reframing of a common question in futurology: "here's an effect I'm interested in -- what sorts of things could cause it?"
That style of reasoning ends by postulating causes. But causes have a life of their own: they don't just cause the one effect you're interested in, through the one causal pathway you were thinking about. They do all kinds of things.
In the case of AI and compute, it's common to ask
But once we have an answer to this question, we can always ask
If you've asked the first question, you ought to ask the second one, too.
The first question includes a hidden assumption: that the imagined technology is a reasonable use of the resources it would take to build. This isn't always true: given those resources, there may be easier ways to accomplish the same thing, or better versions of that thing that are equally feasible. These facts are much easier to see when you fix a given resource level, and ask yourself what kinds of things you could do with it.
This high-level point seems like an important contribution to the AI forecasting conversation. The impetus to ask "what does future compute enable?" rather than "how much compute might TAI require?" influenced my own view of Bio Anchors, an influence that's visible in the contrarian summary at the start of this post.
I find the specific examples much less convincing than the higher-level point.
For the most part, the examples don't demonstrate that you could accomplish any particular outcome by applying more compute. Instead, they simply restate the idea that more compute is being used.
They describe inputs, not outcomes. The reader is expected to supply the missing inference: "wow, I guess if we put those big numbers in, we'd probably get magical results out." But this inference is exactly what the examples ought to be illustrating. We already know we're putting in +12 OOMs; the question is what we get out, in return.
This is easiest to see with Skunkworks, which amounts to: "using 12 OOMs more compute in engineering simulations, with 6 OOMs allocated to the simulations themselves, and the other 6 to evolutionary search." Okay -- and then what? What outcomes does this unlock?
We could replace the entire Skunkworks example with the sentence "+12 OOMs would be useful for engineering simulations, presumably?" We don't even need to mention that evolutionary search might be involved, since (as the text notes) evolutionary search is one of the tools subsumed under the category "engineering simulations."
Amp suffers from the same problem. It includes two sequential phases: first, training a large language model ("GPT-7"); then, a large evolutionary search that uses the trained model.
Each of the two steps takes about 1e34 FLOP, so we don't get the second step "for free" by spending extra compute that went unused in the first. We're simply training a big model, and then doing a second big project that takes the same amount of compute as training the model.
We could also do the same evolutionary search project in our world, with GPT-3. Why haven't we? It would be smaller-scale, of course, just as GPT-3 is smaller scale than "GPT-7" (but GPT-3 was worth doing!).
With GPT-3's budget of 3.14e23 FLOP, we could do a GPT-3 variant of Amp with, for example, something like 10000 function evaluations per run and 1600 evolutionary search steps.
100,000,000 evaluations per run (Amp) sure sounds like a lot, but then, so does 10000 (above). Is 1600 steps "not enough"? Not enough for what? (For that matter, is 50000 steps even "enough" for whatever outcome we are interested in?)
The numbers sound intuitively big, but they have no sense of scale, because we don't know how they relate to outcomes. What do we get in return for doing 50000 steps instead of 1600, or 1e8 function evaluations instead of 1e5? What capabilities do we expect out of Amp? How does the compute investment cause those capabilities?
The question "What could you do with +12 OOMs of Compute?" is an important one, and this post deserves credit for raising it.
The concrete examples of "fun" are too fun for their own good. They're focused on sounding cool and big, not on accomplishing anything. Little would be lost if they were replaced with the sentence "we could dramatically scale up LMs, game-playing RL, artificial life, engineering simulations, and brain simulations."
Answering the question in a less "fun," more outcomes-focused manner sounds like a valuable exercise, and I'd love to read a post like that.
uses about six FLOP per parameter per token
Shouldn't this be 2 FLOP per parameter per token, since our evolutionary search is not doing backward passes?
On the other hand, the calculation in the footnote seems to assume that 1 function call = 1 token, which is clearly an unrealistic lower bound.
A "lowest-level" function (one that only uses a single context window) will use somewhere between 1 and tokens. Functions defined by composition over "lowest-level" functions, as described two paragraphs above, will of course require more tokens per call than their constituents.
An operational definition which I find helpful for thinking about memorization is Zhang et al's counterfactual memorization.
The counterfactual memorization of a document $x$ is (roughly) the amount that the model's loss on $x$ degrades when you remove $x$ from its training dataset.
More precisely, it's the difference in expected loss on $x$ between models trained on data distribution samples that happen to include $x$, and models trained on data distribution samples that happen not to include $x$.
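In code, the definition looks something like this (a rough sketch; sample_dataset, train_model, and loss_on are hypothetical stand-ins, and Zhang et al. estimate the quantity roughly this way, by training many models on random subsets of the data):

```python
import numpy as np

def counterfactual_memorization(x, sample_dataset, train_model, loss_on, n_trials=20):
    """Counterfactual memorization of document x (rough sketch):
    expected loss on x for models whose training sample excluded x, minus
    expected loss on x for models whose training sample included x.
    Assumes both cases occur at least once across the trials."""
    losses_with, losses_without = [], []
    for _ in range(n_trials):
        data = sample_dataset()          # a finite sample from the data distribution
        model = train_model(data)
        if x in data:
            losses_with.append(loss_on(model, x))
        else:
            losses_without.append(loss_on(model, x))
    return np.mean(losses_without) - np.mean(losses_with)
```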
This will be lower for documents that are easy for the LM to predict using general features learned elsewhere, and higher for documents that the LM can't predict well except by memorizing them.
Note that the true likelihood under the data distribution only matters through its effect on the likelihood predicted by the LM. On average, likely texts will be easier than unlikely ones, but when these two things come apart, easy-vs-hard is what matters: a document can be more plausible as natural text than another and yet be harder for the LM to predict, in which case it is the one with higher counterfactual memorization. (This is an intuitive guess, not an experimental result!)
On the other hand, if we put many near duplicates of a document in the dataset -- say, many copies with a random edit to a single token -- then every individual near-duplicate will have low counterfactual memorization.
This is not very satisfying, since it feels like something is getting memorized here, even if it's not localized in a single document.
To fix the problem, we might imagine broadening the concept of "whether a document is in the training set." For example, instead of keeping or removing a literal document, we might keep/remove every document that includes a specific substring like a Bible quote.
But if we keep doing this, for increasingly abstract and distant notions of "near duplication" (e.g. "remove all documents that are about frogs, even if they don't contain the word 'frog'") -- then we're eventually just talking about generalization!
Perhaps we could define memorization in a more general way in terms of distances along this spectrum. If we can select examples for removal using a very simple function, and removing the selected examples from the training set destroys the model's performance on them, then it was memorizing them. But if the "document selection function" grows more complex, and starts to do generalization internally, we then say the model is generalizing as opposed to memorizing.
(ETA: though we also need some sort of restriction on the total number of documents removed. "Remove all documents containing some common word" and "remove all but the first document" are simple rules with very damaging effects, but obviously they don't tell us anything about whether those subsets were memorized.)
Hmm, this comment ended up more involved than I originally intended ... mostly I wanted to drop a reference to counterfactual memorization. Hope this was of some interest anyway.
My hunch about the ultra-rare features is that they're trying to become fully dead features, but haven't gotten there yet. Some reasons to believe this:
To test this hypothesis, I guess we could watch how density evolves for rare features over training, up until the point where they are re-initialized? Maybe choose a random subset of them to not re-initialize, and then watch them?
I'd expect these features to get steadily rarer over time, and to never reach some "equilibrium rarity" at which they stop getting rarer. (On this hypothesis, the actual log-density we observe for an ultra-rare feature is an artifact of the training step at which we happen to look -- it's not useful for autoencoding that this feature activates on a fraction 1e-6 of tokens or whatever, it's simply that we have not waited long enough for the density to become 1e-7, then 1e-8, etc.)
Intuitively, when such a "useless" feature fires in training, the W_enc gradient is dominated by the L1 term and tries to get the feature to stop firing, while the W_dec gradient is trying to stop the feature from interfering with the useful ones if it does fire. There's no obvious reason these should have similar directions.
Although it's conceivable that the ultra-rare features are "conspiring" to do useful work collectively, in a very different way from how the high-density features do useful work.
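To spell out the gradient intuition a bit more, using the generic SAE objective $\mathcal{L} = \|x - \hat{x}\|_2^2 + \lambda \|f\|_1$ with $f = \mathrm{ReLU}(W_{enc} x + b_{enc})$ and $\hat{x} = W_{dec} f + b_{dec}$ (the paper's exact parameterization may differ in details): for a feature $i$ that is currently active, the gradient w.r.t. its dictionary vector $d_i$ (the $i$-th column of $W_{dec}$) is $2 (\hat{x} - x) f_i$, which points along the reconstruction residual, while the gradient w.r.t. its encoder row $e_i$ is $(2\, d_i \cdot (\hat{x} - x) + \lambda)\, x$, which points along the input, with a coefficient dominated by $\lambda$ when the feature isn't doing much reconstruction work. A direction proportional to the residual and a direction proportional to the input have no particular reason to be aligned, which is the sense in which the two gradients are pulling toward different things.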