This is a linkpost for Sparse Autoencoders Find Highly Interpretable Directions in Language Models

We use a scalable, unsupervised method called Sparse Autoencoders to find interpretable, monosemantic features in real LLMs (Pythia-70M/410M), in both the residual stream and the MLPs. We showcase monosemantic features, feature replacement for Indirect Object Identification (IOI), and use OpenAI's automatic interpretation protocol to demonstrate a significant improvement in interpretability.

Paper Overview

Sparse Autoencoders & Superposition

To reverse engineer a neural network, we'd like to first break it down into smaller units (features) that can be analysed in isolation. Using individual neurons as these units can be useful, but neurons are often polysemantic, activating for several unrelated types of feature, so looking at neurons alone is insufficient. Also, for some types of network activations, like the residual stream of a transformer, there is little reason to expect features to align with the neuron basis, so we don't even have a good place to start.

Overview of the methodology.

Toy Models of Superposition investigates why polysemanticity might arise and hypothesises that it may result from models learning more distinct features than there are dimensions in the layer, taking advantage of the fact that features are sparse, each one only being active a small proportion of the time. This suggests that we may be able to recover the network's features by finding a set of directions in activation space such that each activation vector can be reconstructed from a sparse linear combination of these directions.

We attempt to reconstruct these hypothesised network features by training linear autoencoders on model activation vectors. We use a sparsity penalty on the embedding, and tied weights between the encoder and decoder, training the models on 10M to 50M activation vectors each. For more detail on the methods used, see the paper.
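As a rough illustration of the setup described above, here is a minimal pure-Python sketch of a tied-weight autoencoder with an L1 sparsity penalty on the code. The ReLU on the code, the toy sizes, the value of `alpha`, and the function names (`encode`, `decode`, `sae_loss`) are assumptions made for this example, not details taken from the post; see the paper for the actual method.

```python
import random

def relu(v):
    return [max(0.0, a) for a in v]

def encode(W, b, x):
    """c = ReLU(W x + b): one coefficient per dictionary feature (ReLU is an assumption)."""
    return relu([sum(w_j * x_j for w_j, x_j in zip(row, x)) + b_i
                 for row, b_i in zip(W, b)])

def decode(W, c):
    """x_hat = W^T c: tied weights, so the decoder reuses the encoder matrix."""
    d = len(W[0])
    return [sum(c_i * row[j] for c_i, row in zip(c, W)) for j in range(d)]

def sae_loss(W, b, x, alpha):
    """Mean squared reconstruction error plus alpha times the L1 norm of the code."""
    c = encode(W, b, x)
    x_hat = decode(W, c)
    recon = sum((a - h) ** 2 for a, h in zip(x, x_hat)) / len(x)
    sparsity = sum(abs(c_i) for c_i in c)
    return recon + alpha * sparsity, c, x_hat

# Toy sizes (hypothetical): a 4-dimensional activation and an
# overcomplete dictionary of 8 candidate feature directions.
rng = random.Random(0)
d_act, n_feats = 4, 8
W = [[rng.gauss(0, 1) for _ in range(d_act)] for _ in range(n_feats)]
b = [0.0] * n_feats
x = [rng.gauss(0, 1) for _ in range(d_act)]

loss, code, x_hat = sae_loss(W, b, x, alpha=1e-3)
```

Each row of `W` plays the role of one dictionary direction. Training would minimise `sae_loss` over many activation vectors, which is omitted here.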

Automatic Interpretation

We use the same automatic interpretation technique that OpenAI used to interpret the neurons in GPT2 to analyse our features, as well as alternative methods of decomposition. This was demonstrated in a previous post but we now extend these results across all 6 layers in Pythia-70M, showing a clear improvement over all baselines in all but the final layers. Case studies later in the paper suggest that the features are still meaningful in these later layers but that automatic interpretation struggles to perform well.

Automatic interpretability score of learned features and baselines in the residual stream for different layers of Pythia-70M.

IOI Feature Identification

We are able to use less-than-rank-one ablations to precisely edit activations to restore uncorrupted behaviour on the IOI task. With normal activation patching, patches occur at a module-wide level, while here we perform interventions of the form

$$\bar{x}' = \bar{x} + \sum_{i \in F} (c_i - \bar{c}_i) f_i$$

where $\bar{x}$ is the embedding of the corrupted datapoint, $F$ is the set of patched features, and $c_i$ and $\bar{c}_i$ are the activations of feature $f_i$ on the clean and corrupted datapoint respectively.
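A minimal sketch of this kind of feature-level patch, assuming the edit adds, for each patched feature, the difference between its clean and corrupted activation multiplied by that feature's direction. The function name `patch_features` and all the toy values are hypothetical and chosen only to make the arithmetic easy to follow.

```python
def patch_features(x_corrupt, features, c_clean, c_corrupt, patched):
    """Return x_corrupt plus, for each i in `patched`,
    (c_clean[i] - c_corrupt[i]) * features[i]."""
    x_new = list(x_corrupt)
    for i in patched:
        delta = c_clean[i] - c_corrupt[i]
        for j, f_ij in enumerate(features[i]):
            x_new[j] += delta * f_ij
    return x_new

# Toy dictionary (hypothetical): three orthogonal feature directions in 3 dimensions.
features = [[1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0]]
c_clean = [2.0, 0.0, 1.0]     # feature activations on the clean datapoint
c_corrupt = [0.5, 0.0, 3.0]   # feature activations on the corrupted datapoint
x_corrupt = [0.5, 0.0, 3.0]   # embedding of the corrupted datapoint

# Patch only feature 0: the other coordinates keep their corrupted values.
x_patched = patch_features(x_corrupt, features, c_clean, c_corrupt, patched={0})
```

Patching a single feature changes the activation along one direction only, in contrast with module-wide patching, which replaces the whole activation vector.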

We show that our features are better able to precisely reconstruct the data than other activation decomposition methods (like PCA), and moreover that the fine-grainedness of our edits increases with dictionary sparsity. Unfortunately, as our autoencoders are not able to perfectly reconstruct the data, they have a positive minimum KL-divergence from the base model, while PCA does not.

The tradeoff curve between number of interventions and amount of reconstruction of clean performance.

Dictionary Features are Highly Monosemantic & Causal

(Left) Histogram of activations for a specific dictionary feature. The majority of activations are for apostrophe (in blue), where the y-axis is the number of datapoints that activate in that bin. (Right) Histogram of the drop in logits (ie how much the LLM predicts a specific token) when ablating this dictionary feature direction.

This is in contrast to the residual stream basis:
 

which appears highly polysemantic (ie many semantic meanings). More examples can be found in Appendix E. We've found many context-neurons (e.g. [medical/Biology/Stack Exchange/German]-context), with some shown in a previous post, so this is an existence proof against concerns that this method only finds token-level features.

Automatic Circuit Discovery

The previous section covered a dictionary feature's relationship to the input tokens and its effect on the logits. We can also look at the relationships between features themselves.

Each feature is labeled as "layer_feature-index". Darker arrow means that previous feature is more important. Importance is measured by finding activating examples from e.g. 5_2079, ablating each feature from layer 4 while rerunning those activating examples, and showing the top-3 features that cause the biggest difference. This is then recursively applied to found features.
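The ranking step described above can be sketched roughly as follows. Here `downstream_fn` is a hypothetical stand-in for rerunning the model and reading off the activation of the later-layer feature, and zero-ablation with a mean absolute difference as the score is an assumption of this example rather than a detail confirmed by the post.

```python
def top_k_upstream(examples, downstream_fn, k=3):
    """Rank upstream features by the mean absolute change in a downstream
    feature's activation when each upstream feature is zero-ablated."""
    n_feats = len(examples[0])
    scores = []
    for i in range(n_feats):
        total = 0.0
        for acts in examples:
            ablated = list(acts)
            ablated[i] = 0.0  # ablate upstream feature i on this example
            total += abs(downstream_fn(acts) - downstream_fn(ablated))
        scores.append((total / len(examples), i))
    # Largest effect first; ties broken by feature index.
    scores.sort(key=lambda s: (-s[0], s[1]))
    return [(i, score) for score, i in scores[:k]]

# Toy stand-in for the model (hypothetical): the downstream feature is a
# ReLU of a fixed linear readout of five upstream feature activations.
weights = [0.1, 2.0, 0.0, -1.0, 0.5]

def downstream(acts):
    return max(0.0, sum(w * a for w, a in zip(weights, acts)))

# Upstream activations on two examples that activate the downstream feature.
examples = [[1.0, 1.0, 1.0, 0.0, 1.0],
            [0.0, 2.0, 1.0, 0.0, 2.0]]

top = top_k_upstream(examples, downstream, k=3)
```

Applying the same function to each of the returned upstream features, one layer further back, gives the recursive version.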

Layer 5 is the last layer in Pythia-70M, and this feature directly unembeds into various forms of the closing parenthesis. We can view the previous layers as calculating "What are all the reasons one might predict a closing parenthesis?".

Conclusion

Sparse autoencoders are a scalable, unsupervised approach to disentangling language model network features from superposition. We have demonstrated that the dictionary features they learn are more interpretable by autointerpretability, are better for performing precise model steering, and are more monosemantic than comparable methods. 

The ability to find these dictionary features gives us a new, fully unsupervised tool to investigate model behaviour, allows us to make targeted edits, and can be trained using a manageable amount of computing power. 

An ambitious dream in the field of interpretability is enumerative safety: the ability to understand the full set of computations that a model applies. If this were achieved, it could allow us to create models for which we have strong guarantees that the model is not able to perform certain dangerous actions, such as deception or advanced bioengineering. While this is still remote, dictionary learning hopefully marks a small step towards making it possible. 

In summary, sparse autoencoders bring a new tool to the interpretability and editing of language models, which we hope others can build upon. The potential for innovations and applications is vast, and we’re excited to see what happens next.

Bonus Section: Did We Find All the Features?

No.

In general, we get a reconstruction loss, and if that's 0, then we've perfectly reconstructed e.g. Layer 4 with our sparse autoencoder. But what does a reconstruction loss of 0.01 mean compared to 0.0001?

We can ground this out to the difference in perplexity (a measure of prediction loss) on some dataset. This will better measure the functional equivalence (ie they have the same loss on the same data). As non-released, preliminary results, with GPT2 (small) on layer 4 on a subset of OpenWebText:

Each dictionary has a sparsity of ~60 features/datapoint, which is ~10% of the residual stream dimension of 768 for GPT2 small.
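To make the perplexity comparison concrete, here is a small sketch of how a perplexity difference and a KL-divergence between the original model's outputs and the outputs after swapping in a reconstructed layer could be computed. The logits below are made-up toy numbers, not results, and the function names are assumptions for this example.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def perplexity(token_logits, targets):
    """exp of the mean negative log-likelihood of the target tokens."""
    nll = [-math.log(softmax(l)[t]) for l, t in zip(token_logits, targets)]
    return math.exp(sum(nll) / len(nll))

def kl_divergence(p_logits, q_logits):
    """KL(p || q) between the two next-token distributions."""
    p, q = softmax(p_logits), softmax(q_logits)
    return sum(p_i * math.log(p_i / q_i) for p_i, q_i in zip(p, q) if p_i > 0)

# Toy next-token logits at two positions over a 3-token vocabulary (hypothetical).
orig_logits = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]]    # original model
recon_logits = [[1.5, 0.2, 0.0], [0.0, 2.5, 0.3]]   # layer replaced by its reconstruction
targets = [0, 1]                                    # correct next tokens

ppl_orig = perplexity(orig_logits, targets)
ppl_recon = perplexity(recon_logits, targets)
ppl_diff = ppl_recon - ppl_orig
mean_kl = sum(kl_divergence(p, q)
              for p, q in zip(orig_logits, recon_logits)) / len(targets)
```

A perfect reconstruction would give `ppl_diff == 0` and `mean_kl == 0`, so both quantities measure how far the edited model is from functional equivalence on the chosen data.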

A difference in perplexity of 2.6 for training directly on KL-divergence[1] is quite small, especially for 4 months of effort between 3 main researchers. The two possibilities are 

  1. People better at maths/ML/sparse dictionary learning than us can get it to ~0-perplexity difference
  2. A subset of features aren't linearly-represented.

If (2) is the case, then we'll now have a dataset of datapoints that aren't linearly represented which we can study![2] This would show that superposition only explains a subset of features, and provide concrete counterexamples to the linear-part of the hypothesis.

We would like to give two big caveats though:

  1. We don't have a perfect monosemanticity metric, so even if we have 0-reconstruction loss, we can't claim each feature is monosemantic, although a lower sparsity is partial evidence for that.
  2. What if every 1000 features decreases the remaining reconstruction loss by half, so we're really infinity features away from perfect reconstruction?

Come Work With Us

We are currently discussing research in the #unsupervised-interp channel (under Interpretability) in the EleutherAI Discord server. If you're a researcher and have directions you'd like to apply sparse autoencoders to, feel free to message Logan on Discord (loganriggs) or LW & we can chat!

For specific questions on sections (we're all on discord as well):
1. Hoagy - autointerp & MLP results

2. Aidan - IOI Feature Identification

3. Logan - Monosemantic features & Auto-circuits

  1. ^

    KL-divergence is calculated by getting the original LLM's output, then reconstructing e.g. layer 4 w/ the autoencoder to get a different output, then finding the KL-div between these two outputs. In practice, we found training on KL-div & reconstruction (and sparsity) to converge to lower perplexity.

  2. ^

    These datapoints can be found by finding datapoints with the highest perplexity-difference.

7 comments

Did you try searching for similar ideas to your work in the broader academic literature? There seems to be lots of closely related work that you'd find interesting. For example:

Elite BackProp: Training Sparse Interpretable Neurons. They train CNNs to have "class-wise activation sparsity." They claim their method achieves "high degrees of activation sparsity with no accuracy loss" and "can assist in understanding the reasoning behind a CNN."

Accelerating Convolutional Neural Networks via Activation Map Compression. They "propose a three-stage compression and acceleration pipeline that sparsifies, quantizes, and entropy encodes activation maps of Convolutional Neural Networks." The sparsification step adds an L1 penalty to the activations in the network, which they do at finetuning time. The work just examines accuracy, not interpretability.

Enhancing Adversarial Defense by k-Winners-Take-All. Proposes the k-Winners-Take-All activation function, which keeps only the k largest activations and sets all other activations to 0. This is a drop-in replacement during neural network training, and they find it improves adversarial robustness in image classification. How Can We Be So Dense? The Benefits of Using Highly Sparse Representations also uses the k-Winners-Take-All activation function, among other sparsification techniques.

The Neural LASSO: Local Linear Sparsity for Interpretable Explanations. Adds an L1 penalty to the gradient wrt the input. The intuition is to make the final output have a "sparse local explanation" (where "local explanation" = input gradient)

Adaptively Sparse Transformers. They replace softmax with α-entmax, "a differentiable generalization of softmax that allows low-scoring words to receive precisely zero weight." They claim "improve[d] interpretability and [attention] head diversity" and also that "at no cost in accuracy, sparsity in attention heads helps to uncover different head specializations."

Interpretable Neural Predictions with Differentiable Binary Variables. They train two neural networks. One "selects a rationale (i.e. a short and informative part of the input text)", and the other "classifies... from the words in the rationale alone."

I ask because your paper doesn't seem to have a related works section, and most of your citations in the intro are from other safety research teams (eg Anthropic, OpenAI, CAIS, and Redwood.)

Hi Scott, thanks for this!

Yes I did do a fair bit of literature searching (though maybe not enough tbf) but very focused on sparse coding and approaches to learning decompositions of model activation spaces rather than approaches to learning models which are monosemantic by default which I've never had much confidence in, and it seems that there's not a huge amount beyond Yun et al's work, at least as far as I've seen.

Still though, I've not seen almost any of these which suggests a big hole in my knowledge, and in the paper I'll go through and add a lot more background to attempts to make more interpretable models.

Awesome work! I like the autoencoder approach a lot.

Cool work! I really like the ACDC on the parenthesis feature part, I'd love to see more work like that, and work digging into exactly how things compose with each other in terms of the weights.

I've had trouble figuring out a weight-based approach due to the non-linearity and would appreciate your thoughts actually.

We can learn a dictionary of features at the residual stream (R_d) & another mid-MLP (MLP_d), but you can't straightforwardly multiply the features from R_d with W_in, and find the matching features in MLP_d due to the nonlinearity, AFAIK.

I do think you could find Residual features that are sufficient to activate the MLP features[1], but not all linear combinations from just the weights.

Using a dataset-based method, you could find causal features in practice (the ACDC portion of the paper was a first attempt at that), and I would be interested in an activation*gradient method here (though I'm largely ignorant).

 

  1. ^

    Specifically, I think you should scale the residual stream activations by their in-distribution max-activating examples.

Did you ever try out independent component analysis? There's a scikit-learn implementation even. If you haven't, I'm strongly tempted to throw an undergrad at it (in an RL setting where it makes sense to look for features that are coherent across time).

EDIT: Nevermind, it's in the paper. And also I guess in the figure if I was paying closer attention :P

Hi Charlie, yep it's in the paper - but I should say that we did not find a working CUDA-compatible version and used the scikit version you mention. This meant that the data volumes used are somewhat limited - still on the order of a million examples but 10-50x less than went into the autoencoders.

It's not clear whether the extra data would provide much signal since it can't learn an overcomplete basis and so has no way of learning rare features but it might be able to outperform our ICA baseline presented here, so if you wanted to give someone a project of making that available, I'd be interested to see it!