This is a linkpost for https://arxiv.org/abs/2404.16014

Improving Dictionary Learning with Gated Sparse Autoencoders

Great work! Obviously the results here speak for themselves, but I especially wanted to compliment the authors on the writing. I thought this paper was a pleasure to read, and easily a top 5% exemplar of clear technical writing. Thanks for putting in the effort on that.

I'll post a few questions as children to this comment.

I believe that equation (10) giving the analytical solution to the optimization problem defining the relative reconstruction bias is incorrect. I believe the correct expression should be .

You could compute this by differentiating equation (9), setting it equal to 0 and solving for . But here's a more geometrical argument.

By definition, is the multiple of closest to . Equivalently, this closest such vector can be described as the projection . Setting these equal, we get the claimed expression for .

As a sanity check, when our vectors are 1-dimensional, , and , my expression gives (which is correct), but equation (10) in the paper gives .
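The projection argument above can be checked numerically. This is a minimal sketch, not the paper's code: the names `u` (the vector being rescaled) and `v` (the target) are hypothetical, and it works with a single pair of vectors rather than the expectation over a dataset that the paper's definition uses. It only illustrates the general fact that the scalar `c` minimizing `||c*u - v||^2` is `(u.v)/(u.u)`.

```python
# Sketch under stated assumptions: `u` and `v` are illustrative vectors,
# not quantities taken from the paper.

def dot(a, b):
    """Euclidean inner product of two equal-length lists."""
    return sum(ai * bi for ai, bi in zip(a, b))

def best_scale(u, v):
    """Scalar c minimizing ||c*u - v||^2: the projection coefficient of v onto u."""
    return dot(u, v) / dot(u, u)

def sq_dist(c, u, v):
    """Squared distance between c*u and v."""
    return sum((c * ui - vi) ** 2 for ui, vi in zip(u, v))

u = [2.0, 0.0]
v = [1.0, 3.0]
c = best_scale(u, v)

# Nudging c in either direction should not reduce the distance.
print(c, sq_dist(c, u, v), sq_dist(c - 0.1, u, v), sq_dist(c + 0.1, u, v))

# One-dimensional case: the best multiple of 2 for hitting 1 is one half.
print(best_scale([2.0], [1.0]))
```

Setting the derivative of the quadratic in `c` to zero gives the same coefficient, so the calculus route and the geometric route agree.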

Oh oops, thanks so much. We'll update the paper accordingly. Nit: it's actually

(it's just minimizing a quadratic)

ETA: the reason we have complicated equations is that we didn't compute during training (this quantity is kinda weird). However, you can compute from quantities that are usually tracked in SAE training. Specifically, and all terms here are clearly helpful to track in SAE training.

Oh, one other issue relating to this: in the paper it's claimed that if is the argmin of then is the argmin of . However, this is not actually true: the argmin of the latter expression is . To get an intuition here, consider the case where and are very nearly perpendicular, with the angle between them just slightly less than 90°. Then you should be able to convince yourself that the best factor to scale either or by in order to minimize the distance to the other will be just slightly greater than 0. Thus the optimal scaling factors cannot be reciprocals of each other.

ETA: Thinking on this a bit more, this might actually reflect a general issue with the way we think about feature shrinkage; namely, that whenever there is a nonzero angle between two vectors of the same length, the best way to make either vector close to the other will be by shrinking it. I'll need to think about whether this makes me less convinced that the usual measures of feature shrinkage are capturing a real thing.

ETA2: In fact, now I'm a bit confused why your figure 6 shows no shrinkage. Based on what I wrote above in this comment, we should generally expect to see shrinkage (according to the definition given in equation (9)) whenever the autoencoder isn't perfect. I guess the answer must somehow be "equation (10) actually is a good measure of shrinkage, in fact a better measure of shrinkage than the 'corrected' version of equation (10)." That's pretty cool and surprising, because I don't really have a great intuition for what equation (10) is actually capturing.
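The near-perpendicular case described above is easy to try out. A minimal sketch, assuming two hypothetical unit vectors `a` and `b` in the plane with an angle of 89 degrees between them (these values are illustrative and not from the paper): the best factor for scaling either vector toward the other comes out as the cosine of the angle, so both factors are just above 0 and their product is nowhere near 1.

```python
import math

def dot(a, b):
    """Euclidean inner product of two equal-length lists."""
    return sum(ai * bi for ai, bi in zip(a, b))

def best_scale(u, v):
    """Scalar c minimizing ||c*u - v||^2."""
    return dot(u, v) / dot(u, u)

# Hypothetical example: two unit vectors separated by 89 degrees.
theta = math.radians(89.0)
a = [1.0, 0.0]
b = [math.cos(theta), math.sin(theta)]

c_ab = best_scale(a, b)  # best factor for scaling a toward b
c_ba = best_scale(b, a)  # best factor for scaling b toward a

print(c_ab, c_ba, c_ab * c_ba, math.cos(theta))
```

For equal-length vectors the optimal factor is `cos(theta)`, which is below 1 for any nonzero angle. That matches the observation that some "shrinkage" under this kind of definition shows up whenever the reconstruction is imperfect.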

> Thinking on this a bit more, this might actually reflect a general issue with the way we think about feature shrinkage; namely, that whenever there is a nonzero angle between two vectors of the same length, the best way to make either vector close to the other will be by shrinking it.

This was actually the key motivation for building this metric in the first place, instead of just looking at the ratio . Looking at the that would optimize the reconstruction loss ensures that we're capturing only bias from the L1 regularization, and *not* capturing the "inherent" need to shrink the vector given these nonzero angles. (In particular, if we computed for Gated SAEs, I expect that would be below 1.)

I think the main thing we got wrong is that we accidentally treated as though it were . To the extent that was the main mistake, I think it explains why our results still look how we expected them to -- usually is going to be close to 1 (and should be almost exactly 1 if shrinkage is solved), so in practice the error introduced from this mistake is going to be extremely small.

We're going to take a closer look at this tomorrow, check everything more carefully, and post an update after doing that. I think it's probably worth waiting for that -- I expect we'll provide much more detailed derivations that make everything a lot clearer.

Hey Sam, thanks - you're right. The definition of reconstruction bias is actually the argmin of

which I'd (incorrectly) rearranged as the expression in the paper. As a result, the optimum is

That being said, the derivation we gave was not quite right, as I'd incorrectly substituted the optimised loss rather than the original reconstruction loss, which makes equation (10) incorrect. However the difference between the two is small exactly when gamma is close to one (and indeed vanishes when there is no shrinkage), which is probably why we didn't pick this up. Anyway, we plan to correct these two equations and update the graphs, and will submit a revised version.

UPDATE: we've corrected equations 9 and 10 in the paper (screenshot of the draft below) and also added a footnote that hopefully helps clarify the derivation. I've also attached a revised figure 6, showing that this doesn't change the overall story (for the mathematical reasons I mentioned in my previous comment). These will go up on arXiv, along with some other minor changes (like remembering to mention SAEs' widths), likely some point next week. Thanks again Sam for pointing this out!

Updated equations (draft):

Updated figure 6 (shrinkage comparison for GELU-1L):

I'm a bit perplexed by the choice of loss function for training GSAEs (given by equation (8) in the paper). The intuitive (to me) thing to do here would be to have the and terms, but not the term, since the point of is to tell you which features should be active, not to itself provide good feature coefficients for reconstructing . I can sort of see how not including this term might result in the coordinates of all being extremely small (but barely positive when it's appropriate to use a feature), such that the sparsity term doesn't contribute much to the loss. Is that what goes wrong? Are there ablation experiments you can report for this? If so, including this term still currently seems to me like a pretty unprincipled way to deal with this -- can the authors provide any flavor here?

Here are two ways that I've come up with for thinking about this loss function -- let me know if either of these are on the right track. Let denote the gated encoder, but with a ReLU activation instead of Heaviside. Note then that is just the standard SAE encoder from *Towards Monosemanticity.*

Perspective 1: The usual loss from *Towards Monosemanticity* for training SAEs is (this is the same as your and up to the detaching thing). But now you have this magnitude network which needs to get a gradient signal. Let's do that by adding an additional term -- your . So under this perspective, it's the reconstruction term which is new, with the sparsity and auxiliary terms being carried over from the usual way of doing things.

Perspective 2 (h/t Jannik Brinkmann): let's just add together the usual *Towards Monosemanticity* loss function for both the usual architecture and the new modified architecture: .

However, the gradients with respect to the second term in this sum vanish because of the use of the Heaviside, so the gradient with respect to this loss is the same as the gradient with respect to the loss you actually used.

Possibly I'm missing something, but if you don't have , then the only gradients to and come from (the binarizing Heaviside activation function kills gradients from ), and so would be always non-positive to get perfect zero sparsity loss. (That is, if you only optimize for L1 sparsity, the obvious solution is "none of the features are active".)

(You could use a smooth activation function as the gate, e.g. an element-wise sigmoid, and then you could just stick with from the beginning of Section 3.2.2.)
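The point about the Heaviside gate killing gradients can be illustrated with a one-feature toy. This is a sketch under stated assumptions, not the paper's setup: `pi_gate` and `pi_mag` are hypothetical scalar pre-activations, the "decoder" is the identity, and the derivative is estimated by central finite differences at points away from the step. With a hard gate the reconstruction loss is locally flat in `pi_gate`; with a sigmoid gate it is not.

```python
import math

def heaviside(z):
    """Hard binary gate: 1 if the pre-activation is positive, else 0."""
    return 1.0 if z > 0 else 0.0

def sigmoid(z):
    """Smooth alternative gate."""
    return 1.0 / (1.0 + math.exp(-z))

def relu(z):
    return max(z, 0.0)

def grad_wrt_gate(gate, pi_gate, pi_mag=2.0, target=3.0, eps=1e-4):
    """Central-difference derivative of the squared reconstruction error
    with respect to the gate pre-activation (all values hypothetical)."""
    def loss(z):
        return (gate(z) * relu(pi_mag) - target) ** 2
    return (loss(pi_gate + eps) - loss(pi_gate - eps)) / (2 * eps)

for pi_gate in (-1.0, 0.5, 2.0):
    hard = grad_wrt_gate(heaviside, pi_gate)
    soft = grad_wrt_gate(sigmoid, pi_gate)
    print(pi_gate, hard, soft)
```

A finite difference that straddles zero would pick up the jump, but gradient-based training only ever sees the local slope, which is zero on both sides of the step. That is why some extra term has to carry a training signal into the gate.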

Ah thanks, you're totally right -- that mostly resolves my confusion. I'm still a little bit dissatisfied, though, because the term is optimizing for something that we don't especially want (i.e. for to do a good job of reconstructing ). But I do see how you do need to have some sort of a reconstruction-esque term that actually allows gradients to pass through to the gated network.

Yep, the intuition here indeed was that L1 penalised reconstruction seems to be okay for teaching a standard SAE's encoder to detect which features are on (even if features get shrunk as a result), so that is effectively what this auxiliary loss is teaching the gate sub-layer to do, alongside the sparsity penalty. (The key difference being we freeze the decoder in the auxiliary task, which the ablation study shows helps performance.) Maybe to put it another way, this was an auxiliary task that we had good evidence would teach the gate sublayer to detect active features reasonably well, and it turned out to give good results in practice. It's totally possible though that there are better auxiliary tasks (or even completely different loss functions) out there that we've not explored.
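To make the three pieces discussed in this thread concrete, here is a rough sketch of how they could be computed for a single input. It is my reading of the discussion, not the paper's exact equation (8): every name (`pi_gate`, `pi_mag`, `W_dec`, `b_dec`, `l1_coeff`) is a hypothetical stand-in, the gate and magnitude pre-activations are passed in directly instead of being computed by an encoder, and the "frozen decoder" in the auxiliary term is only a comment because plain Python has no gradients to stop.

```python
def relu(z):
    return max(z, 0.0)

def heaviside(z):
    return 1.0 if z > 0 else 0.0

def decode(f, W_dec, b_dec):
    """Reconstruction b_dec + sum_i f[i] * W_dec[i], where W_dec[i] is the
    (hypothetical) decoder direction of feature i."""
    out = list(b_dec)
    for fi, direction in zip(f, W_dec):
        for j, dj in enumerate(direction):
            out[j] += fi * dj
    return out

def sq_err(x, x_hat):
    return sum((a - b) ** 2 for a, b in zip(x, x_hat))

def gated_loss_terms(x, pi_gate, pi_mag, W_dec, b_dec, l1_coeff):
    # Gated feature activations: binary gate times ReLU'd magnitude.
    f = [heaviside(g) * relu(m) for g, m in zip(pi_gate, pi_mag)]
    # ReLU'd gate pre-activations, used by the sparsity and auxiliary terms.
    gate_acts = [relu(g) for g in pi_gate]

    recon = sq_err(x, decode(f, W_dec, b_dec))
    sparsity = l1_coeff * sum(gate_acts)
    # In an autodiff framework the decoder would be treated as frozen here,
    # so that this term only trains the gate path.
    aux = sq_err(x, decode(gate_acts, W_dec, b_dec))
    return recon, sparsity, aux

recon, sparsity, aux = gated_loss_terms(
    x=[1.0, 2.0],
    pi_gate=[0.5, -1.0],
    pi_mag=[1.0, 2.0],
    W_dec=[[1.0, 0.0], [0.0, 1.0]],
    b_dec=[0.0, 0.0],
    l1_coeff=0.1,
)
print(recon, sparsity, aux)
```

In this toy the second feature is gated off, so the reconstruction term cannot tell the gate that switching it on would help. Only the auxiliary term, which reconstructs from the ReLU'd gate pre-activations, depends smoothly on `pi_gate`.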

> Great work! Obviously the results here speak for themselves, but I especially wanted to compliment the authors on the writing. I thought this paper was a pleasure to read, and easily a top 5% exemplar of clear technical writing. Thanks for putting in the effort on that.

<3 Thanks so much, that's extremely kind. Credit entirely goes to Sen and Arthur, which is even more impressive given that they somehow took this from a blog post to a paper in a two week sprint! (including re-running all the experiments!!)

(The question in this comment is more narrow and probably not interesting to most people.)

The limitations section includes this paragraph:

> One worry about increasing the expressivity of sparse autoencoders is that they will overfit when reconstructing activations (Olah et al., 2023, Dictionary Learning Worries), since the underlying model only uses simple MLPs and attention heads, and in particular lacks discontinuities such as step functions. Overall we do not see evidence for this. Our evaluations use held-out test data and we check for interpretability manually. But these evaluations are not totally comprehensive: for example, they do not test that the dictionaries learned contain causally meaningful intermediate variables in the model’s computation. The discontinuity in particular introduces issues with methods like integrated gradients (Sundararajan et al., 2017) that discretely approximate a path integral, as applied to SAEs by Marks et al. (2024).

I'm not sure I understand the point about integrated gradients here. I understand this sentence as meaning: since model outputs are a discontinuous function of feature activations, integrated gradients will do a bad job of estimating the effect of patching feature activations to counterfactual values.

If that interpretation is correct, then I guess I'm confused because I think IG actually handles this sort of thing pretty gracefully. As long as the number of intermediate points you're using is large enough that you're sampling points pretty close to the discontinuity on both sides, then your error won't be too large. This is in contrast to attribution patching which will have a pretty rough time here (but not really that much worse than with the normal ReLU encoders, I guess). (And maybe you also meant for this point to apply to attribution patching?)

I haven't fully worked through the maths, but I think both IG and attribution patching break down here? The fundamental problem is that the discontinuity is invisible to IG because it only takes derivatives. E.g., the ReLU and Jump ReLU below look identical from the perspective of IG, but not from the perspective of activation patching, I think.
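A scalar toy makes the disagreement easier to reason about. This is a sketch under assumptions of my own: I take the comparison to be between a ReLU shifted to a threshold `theta` and a JumpReLU with the same threshold (the original figure is not available here), and I implement integrated gradients as a midpoint Riemann sum over analytic derivatives, which is how a gradient-only method would see the two functions.

```python
def shifted_relu(z, theta):
    """ReLU shifted so it turns on at theta: continuous, no jump."""
    return max(z - theta, 0.0)

def jump_relu(z, theta):
    """JumpReLU: zero below theta, identity above it, so it jumps by theta."""
    return z if z > theta else 0.0

def step_grad(z, theta):
    """Derivative shared by both functions away from theta:
    0 below the threshold, 1 above it."""
    return 1.0 if z > theta else 0.0

def integrated_gradients(grad_fn, x, baseline, theta, steps=1000):
    """Midpoint Riemann-sum approximation of IG for a scalar input."""
    total = 0.0
    for k in range(steps):
        alpha = (k + 0.5) / steps
        total += grad_fn(baseline + alpha * (x - baseline), theta)
    return (x - baseline) * total / steps

theta, x, baseline = 1.0, 3.0, 0.0
ig = integrated_gradients(step_grad, x, baseline, theta)
true_shifted = shifted_relu(x, theta) - shifted_relu(baseline, theta)
true_jump = jump_relu(x, theta) - jump_relu(baseline, theta)
print(ig, true_shifted, true_jump)
```

Because both functions have the same derivative wherever it is defined, the gradient-only estimate is the same for both and lands near the shifted-ReLU change, missing the size of the jump regardless of how many intermediate points are used. Activation patching evaluates the function itself and so sees the full change. An estimator built from finite differences of function values across the threshold would behave differently, which may be closer to the picture in the parent comment.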

Great paper! The gating approach is an interesting way to learn the JumpReLU threshold and it's exciting that it works well. We've been working on some related directions at OpenAI based on similar intuitions about feature shrinking.

Some questions:

- Is b_mag still necessary in the gated autoencoder?
- Did you sweep learning rates for the baseline and your approach?
- How large is the dictionary of the autoencoder?

We use learning rate 0.0003 for all Gated SAE experiments, and also the GELU-1L baseline experiment. We swept for optimal baseline learning rates on GELU-1L for the *baseline* SAE to generate this value.

For the Pythia-2.8B and Gemma-7B *baseline* SAE experiments, we divided the L2 loss by , motivated by wanting better hyperparameter transfer, and so changed the learning rate to 0.001 or 0.00075 for all the runs (currently in Figure 1, only attention output pre-linear uses 0.00075. In the rerelease we'll state all the values used). We didn't see a noticeable difference in the Pareto frontier when changing between 0.001 and 0.00075, so we did not sweep the baseline hyperparameter further than this.

Re dictionary width: 2**17 (~131K) for most Gated SAEs and 3*(2**16) for baseline SAEs, except that for the (Pythia-2.8B, Residual Stream) sites we used 2**15 for Gated and 3*(2**14) for baseline, since early runs of these had lots of feature death. (This'll be added to the paper soon, sorry!) I'll leave the other Qs for my co-authors.
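For reference, the widths quoted above work out as follows. This is just the arithmetic on the stated expressions; the dictionary keys are labels of my own and not names from the paper.

```python
# Dictionary widths as stated in the comment above, expanded to integers.
widths = {
    "gated, most sites": 2**17,
    "baseline, most sites": 3 * 2**16,
    "gated, Pythia-2.8B residual stream": 2**15,
    "baseline, Pythia-2.8B residual stream": 3 * 2**14,
}

for name, width in widths.items():
    print(f"{name}: {width:,}")

# In both settings the baseline dictionary is 1.5x as wide as the gated one.
ratio_most = widths["baseline, most sites"] / widths["gated, most sites"]
ratio_pythia = (widths["baseline, Pythia-2.8B residual stream"]
                / widths["gated, Pythia-2.8B residual stream"])
print(ratio_most, ratio_pythia)
```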

Got it - do you think with a bit more tuning the feature death at larger scale could be eliminated, or would it be tough to manage with the reinitialization approach?

I'm not sure what you mean by "the reinitialization approach" but feature death doesn't seem to be a major issue at the moment. At all sites besides L27, our Gemma-7B SAEs didn't have much feature death at all (stats at https://arxiv.org/pdf/2404.16014v2 up in a few hours), and also the Anthropic update suggests even in small models the problem can be addressed.

Sorry, I meant the Anthropic-like neuron resampling procedure.

I think I misread Neel's comment, I thought he was saying that 131k was chosen because larger autoencoders would have too many dead latents (as opposed to this only being for Pythia residual).

Ah yeah, Neel's comment makes no claims about feature death beyond Pythia 2.8B residual streams. I trained 524K width Pythia-2.8B MLP SAEs with <5% feature death (not in paper), and Anthropic's work gets to >1M live features (with no claims about interpretability) which together would make me surprised if 131K was near the max of possible numbers of live features even in small models.

Authors: Senthooran Rajamanoharan*, Arthur Conmy*, Lewis Smith, Tom Lieberum, Vikrant Varma, János Kramár, Rohin Shah, Neel Nanda

A new paper from the Google DeepMind mech interp team: Improving Dictionary Learning with Gated Sparse Autoencoders!

Gated SAEs are a new Sparse Autoencoder architecture that seems to be a significant Pareto-improvement over normal SAEs, verified on models up to Gemma 7B. They are now our team's preferred way to train sparse autoencoders, and we'd love to see them adopted by the community! (Or to be convinced that it would be a bad idea for them to be adopted by the community!)

They achieve similar reconstruction with about half as many firing features, while being comparably or more interpretable (confidence interval for the increase is 0%-13%).

See Sen's Twitter summary, my Twitter summary, and the paper!