Hey Joseph (and coauthors),
Your directions are really fantastic. I hope you don't mind, but I generated the activation data for the first 3000+ directions for each of the 12 layers and uploaded your directions to Neuronpedia:
https://www.neuronpedia.org/gpt2-small/res-jb
Your directions are also linked on the home page and the model page.
They're also accessible by layer (sorted by top activation), eg layer 6: https://neuronpedia.org/gpt2-small/6-res-jb
I added the "Anthropic dashboard" to Neuronpedia for your dataset.
Explanations, comments, and autointerp scoring are also working - anyone can do this:
I plan to do some autointerp explaining on a batch of these directions too.
Btw - your directions are so good that it's easy to find super interesting stuff. 5-RES-JB:5 is about astronomy:
I'm aware that you're going to do some library updates to get even better directions, and I'm excited for that - will re-generate/upload all layers after the new changes come in.
Things that I'm still working on and hope to get working in the next few days:
Again, your directions look fantastic - congrats. I hope this is useful/interesting for you and anyone trying to browse/explain them. Also, I didn't know how to provide a citation/reference to you (and your team?) so I just used RES-JB = Residuals by Joseph Bloom and included links to all relevant sources on your directions page.
If there's anything you'd like me to modify about this, or any feature you'd like me to add to make it better, please do not hesitate to let me know.
I tried replicating your statistics using my own evaluation code (in evaluation.py here). I pseudo-randomly chose layer 1 and layer 7. Sadly, my results look rather different from yours:
Layer | MSE Loss | % Variance Explained | L1 | L0 | % Alive | CE Reconstructed |
---|---|---|---|---|---|---|
1 | 0.11 | 92 | 44 | 17.5 | 54 | 5.95 |
7 | 1.1 | 82 | 137 | 65.4 | 95 | 4.29 |
Places where our metrics agree: L1 and L0.
Places where our metrics disagree, but probably for a relatively benign reason:
Our metrics disagree strongly on CE reconstructed, and this is a bit alarming. It means that either you have a bug which significantly underestimates reconstructed CE loss, or I have a bug which significantly overestimates it. I think I'm 50/50 on which it is. Note that according to my stats, your MSE loss is kinda bad, which would suggest that you should also have high CE reconstructed (especially when working with residual stream dictionaries! (in contrast to e.g. MLP dictionaries which are much more forgiving)).
Spitballing a possible cause: when computing CE loss, did you exclude padding tokens? If not, then it's possible that many of the tokens on which you're computing CE are padding tokens, which is artificially making your CE look extremely good.
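To make the padding concern concrete, here is a minimal sketch in plain Python (not the code from either evaluation) of how averaging cross-entropy over padding positions can pull the mean down. The padding id, the toy logits and the helper name mean_cross_entropy are all assumptions made up for illustration.

```python
import math

def mean_cross_entropy(logits, targets, pad_id=None):
    """Mean next-token cross-entropy over a flat list of positions.

    logits: one list of logits per position; targets: one token id per position.
    If pad_id is given, positions whose target equals pad_id are skipped.
    """
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if pad_id is not None and target == pad_id:
            continue  # padding position: excluded from the mean
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[target]
        count += 1
    return total / count

# Toy vocabulary of 3 tokens; id 0 plays the role of a hypothetical padding id.
PAD = 0
logits = [
    [0.0, 2.0, 0.0],  # real token, target 1
    [0.0, 0.0, 0.0],  # real token, target 2 (model is unsure)
    [9.0, 0.0, 0.0],  # padding, trivially predicted
    [9.0, 0.0, 0.0],  # padding, trivially predicted
]
targets = [1, 2, PAD, PAD]

ce_all = mean_cross_entropy(logits, targets)               # padding included
ce_real = mean_cross_entropy(logits, targets, pad_id=PAD)  # padding excluded
print(ce_all, ce_real)
```

In this toy case the padded positions are predicted almost perfectly, so including them makes the mean CE look better than it is on real tokens. Whether that is what happened in either set of numbers above is not something this sketch can tell you.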
Here is my code. You'll need to pip install nnsight before running it. Many thanks to Caden Juang for implementing the UnifiedTransformer functionality in nnsight, which is a crazy Frankenstein marriage of nnsight and transformer_lens; it would have been very hard for me to attempt this replication without this feature.
Why do you scale your MSE by 1/(x_centred**2).sum(dim=-1, keepdim=True).sqrt()? In particular, I'm confused about why you have the square root. Shouldn't it just be 1/(x_centred**2).sum(dim=-1, keepdim=True)?
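A small worked example of the difference between the two normalisations, written in plain Python for a single token rather than with tensors. It assumes x_centred is the centred activation vector and the residual is x minus its reconstruction; the numbers and the helper name normalised_errors are made up.

```python
import math

def normalised_errors(x_centred, residual):
    """Return (squared error / norm, squared error / squared norm) for one token."""
    sq_err = sum(r * r for r in residual)
    sq_norm = sum(v * v for v in x_centred)
    return sq_err / math.sqrt(sq_norm), sq_err / sq_norm

x = [3.0, 4.0]    # norm 5
res = [0.3, 0.4]  # reconstruction misses 10% of x
with_sqrt, without_sqrt = normalised_errors(x, res)

# Same relative error, but activations that are 10x larger.
x10 = [10 * v for v in x]
res10 = [10 * r for r in res]
with_sqrt10, without_sqrt10 = normalised_errors(x10, res10)

print(with_sqrt, with_sqrt10)        # grows with the activation norm
print(without_sqrt, without_sqrt10)  # unchanged by the rescaling
```

Without the square root the quantity is a scale-free fraction of unexplained variance. With the square root it still grows linearly with the activation norm, which is one way to read the question above.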
Browse these SAE Features on Neuronpedia!
Update 1: Since we posted this last night, someone pointed out that our implementation of ghost grads has a non-trivial error (which makes the results a-priori quite surprising). We computed the ghost grad forward pass using Exp(Relu(W_enc(x)[dead_neuron_mask])) rather than Exp((W_enc(x)[dead_neuron_mask])). I'm running some ablation experiments now to get to the bottom of this.
Update 2: I've since investigated this further and run some ablation studies, with the following results. Ghost grads weren't working as intended due to the Exp(Relu(x)) bug, but the resulting SAEs were still quite good (later layers had few dead neurons simply because, when we dropped the number of features, we got fewer dead neurons). I've found that with a correct ghost grads implementation you can get fewer dead neurons, and I will update the library shortly. (I will make edits to the rest of this post to reflect my current views.) Sorry for the confusion.
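To illustrate why the Exp(Relu(x)) form matters, here is a toy sketch in plain Python (not the library code). It assumes a dead neuron is one whose encoder pre-activation is negative on the inputs it sees, and it uses finite differences in place of autograd; the pre-activation values are made up.

```python
import math

def relu(v):
    return max(v, 0.0)

def ghost_act_buggy(pre):
    """Exp(Relu(pre)), the form described in Update 1."""
    return math.exp(relu(pre))

def ghost_act_intended(pre):
    """Exp(pre), the intended form."""
    return math.exp(pre)

def slope(f, v, eps=1e-6):
    """Central finite-difference estimate of df/dv."""
    return (f(v + eps) - f(v - eps)) / (2 * eps)

# Hypothetical pre-activations of dead neurons (all negative).
dead_pre_acts = [-3.0, -1.0, -0.1]

buggy = [(ghost_act_buggy(p), slope(ghost_act_buggy, p)) for p in dead_pre_acts]
intended = [(ghost_act_intended(p), slope(ghost_act_intended, p)) for p in dead_pre_acts]
print(buggy)     # activation is constant, slope is zero
print(intended)  # activation varies with the pre-activation, slope is positive
```

Under these assumptions the buggy form outputs the same constant for every dead neuron and passes no gradient back through the pre-activation, while the intended form does. This sketch says nothing about the decoder side or about how the ablations in Update 2 turned out.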
This work was produced as part of the ML Alignment & Theory Scholars Program - Winter 2023-24 Cohort, under mentorship from Neel Nanda and Arthur Conmy. Funding for this work was provided by the Manifund Regranting Program and donors as well as LightSpeed Grants.
This is intended to be a fairly informal post sharing a set of Sparse Autoencoders trained on the residual stream of GPT2-small which achieve fairly good reconstruction performance and contain fairly sparse / interpretable features. More importantly, advice from Anthropic and community members has enabled us to train these far more efficiently / faster than before. The specific methods that were most useful were: ghost gradients, learning rate warmup, and initializing the decoder bias with the geometric median. We discuss each of these in more detail below.
5 Minute Summary
We’re publishing a set of 12 Sparse AutoEncoders for the GPT2 Small residual stream.
Readers can access the Sparse Autoencoder weights in this HuggingFace Repo. Training code and code for loading the weights / model and data loaders can be found in this Github Repository. Training curves and feature dashboards can also be found in this wandb report. Users can download all 25k feature dashboards generated for the layer 2 and layer 10 SAEs and the first 5000 of the layer 5 SAE features here (note that the left-hand column of the dashboards should currently be ignored).
Layer | Variance Explained | L1 loss | L0* | % Alive Features | Reconstruction CE Log Loss |
---|---|---|---|---|---|
0 | 99.15% | 4.58 | 12.24 | 80.0% | 3.32 |
1 | 98.37% | 41.04 | 14.68 | 83.4% | 3.33 |
2 | 98.07% | 51.88 | 18.80 | 80.0% | 3.37 |
3 | 96.97% | 74.96 | 25.75 | 86.3% | 3.48 |
4 | 95.77% | 90.23 | 33.14 | 97.7% | 3.44 |
5 | 94.90% | 108.59 | 43.61 | 99.7% | 3.45 |
6 | 93.90% | 136.07 | 49.68 | 100% | 3.44 |
7 | 93.08% | 138.05 | 57.29 | 100% | 3.45 |
8 | 92.57% | 167.35 | 65.47 | 100% | 3.45 |
9 | 92.05% | 198.42 | 71.10 | 100% | 3.45 |
10 | 91.12% | 215.11 | 53.79 | 100% | 3.52 |
11 | 93.30% | 270.13 | 59.16 | 100% | 3.57 |
Original Model | | | | | 3.3 |
Summary Statistics for GPT2 Small Residual Stream SAEs. *L0 = Average number of features firing per token.
Training SAEs that we were happy with used to take much longer than it is taking us now. Last week, it took me 20 hours to train a 50k feature SAE on 1 billion tokens, and over the weekend it took us 3 hours to train a 25k feature SAE on 300M tokens with similar variance explained, L0 and CE loss recovered.
We attribute the improvement to having implemented various pieces of advice that have made our lives a lot easier:
While we haven’t tested our code extensively since implementing these improvements, we suspect that hyperparameter tuning may be easier in the future since these method improvements make the process generally less sensitive.
To demonstrate the interpretability of these SAEs, we share screenshots of feature dashboards we produced using a reproduction of Anthropic’s dashboard developed by Callum McDougall.
Finally, we end by discussing how readers can access these SAE’s, some experiments which they could perform to upskill with SAEs and possible research directions to pursue.
Introduction
What are Sparse AutoEncoders and why should we care about them?
Sparse autoencoders (SAEs) are an unsupervised technique to take a model's activations and decompose them into interpretable feature vectors. We highly recommend this tutorial on SAEs for those interested. Recent papers on the topic can be found here and here.
I’m particularly excited about Sparse Autoencoders for two reasons:
General Advice for Training SAEs
Why can training Sparse AutoEncoders be difficult?
Sparse autoencoders are an unsupervised method which attempts to trade off reconstruction accuracy against interpretability, which we achieve by inducing activation sparsity. Since we don’t have good metrics for interpretability / reconstruction quality, it’s hard to know when we are actually optimizing what we care about. On top of this, we’re trying to pick a good point on the Pareto frontier between interpretability and reconstruction quality, which is a hard thing to assess well.
The main objective is to have your Sparse Autoencoder learn a population of sparse features (which are likely to be interpretable) without having dense features (features which activate all the time and are likely uninterpretable) or too many dead features (features which never fire). As discussed in the 5 minute summary, we went from training GPT2 small residual stream SAEs in 12+ hours to ~3 hours (so 4x faster / cheaper). Though the L0 and CE loss were somewhat similar, the feature density histograms also suggested we avoided dead / dense features much more effectively after using ghost gradients.
Let’s dig into the challenges associated with dead / dense features a bit more:
Which tricks help the most?
In terms of solutions, Anthropic published useful advice (especially ghost gradients) and the research community is building consensus on how to train SAEs well (what your loss curves and final statistics should look like for example). I found Arthur Conmy’s post very useful.
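The top 3 changes that I made which led to my ability to train these SAEs cheaply/quickly were the ones named at the start of this post: ghost gradients, learning rate warmup, and initializing the decoder bias with the geometric median.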
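Two of the methods named at the start of this post, initializing the decoder bias with the geometric median and learning rate warmup, are simple enough to sketch in plain Python. This is not the training code from the repository. The use of Weiszfeld iteration, the linear warmup shape, the function names and the toy data are all assumptions for illustration.

```python
import math

def geometric_median(points, n_iter=100, eps=1e-8):
    """Weiszfeld iteration for the point minimising summed Euclidean distance."""
    dim = len(points[0])
    # Start from the coordinate-wise mean.
    est = [sum(p[d] for p in points) / len(points) for d in range(dim)]
    for _ in range(n_iter):
        # Weight each point by inverse distance (clamped to avoid division by zero).
        weights = [1.0 / max(math.dist(p, est), eps) for p in points]
        total = sum(weights)
        est = [sum(w * p[d] for w, p in zip(weights, points)) / total
               for d in range(dim)]
    return est

def warmup_lr(step, base_lr, warmup_steps):
    """Linear warmup up to base_lr over warmup_steps, then constant."""
    return base_lr * min(1.0, (step + 1) / warmup_steps)

# Toy "activations": four points on the unit square plus one large outlier.
acts = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [100.0, 100.0]]
mean = [sum(a[d] for a in acts) / len(acts) for d in range(2)]
median = geometric_median(acts)
print(mean, median)

print([warmup_lr(s, base_lr=1e-3, warmup_steps=5) for s in range(7)])
```

In the toy data the mean is dragged far towards the outlier while the geometric median stays inside the unit square, which is the sense in which it is a more robust centre for the activations.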
Sparse AutoEncoders for the GPT2 Residual Stream
Why GPT2 small? Why the residual stream?
GPT2 small has been extensively studied by the mechanistic interpretability community and whilst not an incredibly performant model, it certainly has some kind of “prototypical object of study” property. We chose the residual stream because this enables us to analyze “the total sum of previous output” in a manner not dissimilar to the logit lens approach. This may be useful for understanding how features are constructed from earlier features as well as studying how the distribution of features over time changes in a model.
Architecture and Hyperparameters
We trained 12 Sparse Autoencoders on the Residual Stream of GPT2-small.
Were I to be training SAEs on another model or part of the same model, I wouldn’t change any of the architectural choices (except maybe expansion factor). Other parameters like learning rate, l1 coefficient, number of tokens to train on all likely need to be tuned in practice. It also seems plausible we’ll continue to see methodological advances in the future which I’m excited about!
What do we think about when choosing hyper-parameters / evaluating SAEs?
We train against:
However, what we actually care about is whether we reconstructed information required for model performance and how useful the features are for interpretability. Better proxies for these desiderata are:
Though L0 is a pretty good proxy for interpretability, in practice the feature density histogram (the distribution of how frequent features are) turns out to be one of the most important things we need to get right when tuning hyperparameters.
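As a rough sketch of how these statistics relate, the snippet below computes L0, per-feature firing frequency (the quantity behind the density histogram), and the dead / dense features for a tiny matrix of feature activations. It is plain Python rather than the evaluation code, and the dense threshold and toy values are assumptions.

```python
import math

def sparsity_stats(feature_acts, dense_threshold=0.1):
    """feature_acts: one list of non-negative feature activations per token."""
    n_tokens = len(feature_acts)
    n_features = len(feature_acts[0])
    # L0: average number of features firing (activation > 0) per token.
    l0 = sum(sum(1 for a in tok if a > 0) for tok in feature_acts) / n_tokens
    # Density: fraction of tokens on which each feature fires.
    density = [sum(1 for tok in feature_acts if tok[j] > 0) / n_tokens
               for j in range(n_features)]
    dead = [j for j, d in enumerate(density) if d == 0]
    dense = [j for j, d in enumerate(density) if d > dense_threshold]
    # Histograms are usually drawn over log10 density; dead features have no log.
    log_density = [math.log10(d) for d in density if d > 0]
    pct_alive = 100.0 * (1 - len(dead) / n_features)
    return l0, density, dead, dense, log_density, pct_alive

# 4 tokens x 4 features (toy values).
acts = [
    [0.5, 0.0, 0.0, 1.2],
    [0.0, 0.0, 0.0, 0.9],
    [0.0, 0.3, 0.0, 1.1],
    [0.0, 0.0, 0.0, 0.7],
]
l0, density, dead, dense, log_density, pct_alive = sparsity_stats(
    acts, dense_threshold=0.5)
print(l0, density, dead, dense, pct_alive)
```

Two SAEs can share the same L0 while one has many dead features and a few dense ones, which is why the histogram carries information the single number does not.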
How good are these Sparse AutoEncoders?
At a glance, the summary metrics seem fairly good. I’ll make a number of comments:
Georg Lang pointed out to me that the L2 loss grows quadratically with the norm which increases with layer whilst the L1 coefficient grows linearly. This means that since I didn’t vary the L1 coefficient when training these, we’re effectively pushing less hard for sparsity in later layers (which would explain the trend in L0 / L1 and the feature density histograms). Interestingly, the variance explained still gets worse with layers.
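The scaling argument can be checked with a few lines of plain Python. The sketch assumes that when the activation norm grows by a factor c, the reconstruction and the feature activations grow by the same factor; the L1 coefficient and the vectors are made-up values, and sae_losses is a hypothetical helper rather than the training loss from the repository.

```python
def sae_losses(x, x_hat, feature_acts, l1_coeff):
    """Squared-error reconstruction loss and L1 sparsity penalty for one token."""
    l2 = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    l1 = l1_coeff * sum(abs(f) for f in feature_acts)
    return l2, l1

def scaled(vec, c):
    return [c * v for v in vec]

x, x_hat, feats = [3.0, 4.0], [2.7, 3.6], [1.0, 0.5]
l2_base, l1_base = sae_losses(x, x_hat, feats, l1_coeff=0.01)

# Same SAE behaviour on activations whose norm is c times larger.
c = 4.0
l2_big, l1_big = sae_losses(scaled(x, c), scaled(x_hat, c), scaled(feats, c),
                            l1_coeff=0.01)

print(l2_big / l2_base)  # grows like c**2
print(l1_big / l1_base)  # grows like c
```

With a fixed L1 coefficient, the sparsity term therefore shrinks relative to the reconstruction term as the norm grows, which matches the observation that later layers are pushed less hard towards sparsity.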
Layer | Variance Explained | L1 loss | L0* | % Alive Features | Reconstruction CE Log Loss |
---|---|---|---|---|---|
0 | 99.15% | 4.58 | 12.24 | 80.0% | 3.32 |
1 | 98.37% | 41.04 | 14.68 | 83.4% | 3.33 |
2 | 98.07% | 51.88 | 18.80 | 80.0% | 3.37 |
3 | 96.97% | 74.96 | 25.75 | 86.3% | 3.48 |
4 | 95.77% | 90.23 | 33.14 | 97.7% | 3.44 |
5 | 94.90% | 108.59 | 43.61 | 99.7% | 3.45 |
6 | 93.90% | 136.07 | 49.68 | 100% | 3.44 |
7 | 93.08% | 138.05 | 57.29 | 100% | 3.45 |
8 | 92.57% | 167.35 | 65.47 | 100% | 3.45 |
9 | 92.05% | 198.42 | 71.10 | 100% | 3.45 |
10 | 91.12% | 215.11 | 53.79 | 100% | 3.52 |
11 | 93.30% | 270.13 | 59.16 | 100% | 3.57 |
Original Model | | | | | 3.3 |
Summary Statistics for GPT2 Small Residual Stream SAEs. *L0 = Average number of features firing per token.
Log Feature Sparsity Histograms for each of the residual stream SAEs of GPT2-small
How interpretable are the features in each layer?
Feature interpretability is far from a settled science, but feature dashboards sure do automate a huge chunk of the work. We use a reproduction of Anthropic’s dashboard developed by Callum McDougall.
We’re still working on making our dashboard generating code more efficient (to keep up with the improvements in our ability to train sparse autoencoders!). In the meantime, we’ve collected some anecdotal examples of features at layers 2, 5 and 10 to give examples of the kinds of features we can detect in GPT2 small.
Though we share some features below, you can look through more features at the bottom of the dashboard here.
Layer 2: The President Feature
For example, below we show a “President” feature which promotes the first names of presidents.
Layer 2: The “c” subword token feature
Another example of a fairly typical layer 2 feature is this feature which fires on “c” due to tokenization which splits a word that starts with c.
Layer 5: The what you are saying thanks OR sorry for feature
For example, this feature appears to fire for short stretches of text involving thanks or apologies.
Layer 5: The Force is strong with this Feature.
Though there are plenty of features that seem interesting in the layer 5 SAE, some are just way stronger in the force than others.
Layer 10: Violence / Conflict Feature.
How to get involved
I want to look at more dashboards!
You can download all 25k for layer 10 and 2 and the first 5k for layer 5 here.
How can I download and analyze these SAEs?
For those who would like to play around with these sparse autoencoders, my codebase is pretty crazy right now but you can mostly ignore it once you have the SAE. The codebase has:
In order to speed up my own analysis, I created a “SessionLoader” class which takes a path to the saved SAE and then instantiates the model it was trained on, the sparse autoencoder and the activations_loader (which gets your tokens/activations). Between these three artifacts, you can start analyzing an SAE very quickly post training.
After cloning the repo and installing the requirements.txt, users can simply run the following commands:
What kinds of analysis can we do with Residual Stream Sparse Autoencoders?
Without getting into entire research directions, it’s worth discussing briefly the kinds of experiments that can be run with Sparse Autoencoders. These are projects that might enable you to get a taste of working with SAEs and decide if you’re excited to work with them more seriously.
Some example upskilling projects could be:
For more ideas, posts published by researchers currently working on SAEs can be a great source of inspiration:
What research directions could you pursue with SAEs?
It’s likely that there is a lot of low-hanging fruit with SAEs right now. Logan Riggs posted a bunch of ideas here and Anthropic list some directions for future work which are worth reading here.
Two directions I’m excited by are:
Appendix
Thanks
I’d like to thank Neel Nanda and Arthur Conmy for their support and feedback while I’ve been working on this and other SAE related work. I also appreciate feedback and support from members of the Mechanistic Interpretability Stream in the MATS 5.0 cohort, especially Ben Wu and Andy Arditi.
I’d also like to thank the interpretability team at Anthropic for continually sharing their advice on how to train sparse autoencoders, and to Callum McDougall for his awesome SAE visualizer (replication of Anthropic’s dashboard).
Funding Note
This work was produced as part of the ML Alignment & Theory Scholars Program - Winter 2023-24 Cohort, with support from Neel Nanda and Arthur Conmy. Funding for this work was provided by the Manifund Regranting Program and donors as well as LightSpeed Grants.
Related Work
How to Cite