TL;DR

We achieve better SAE performance by:

  1. Removing the lowest activating features
  2. Replacing the L1(feature_activations) penalty function with L1(sqrt(feature_activations))

with 'better' meaning: we can reconstruct the original LLM activations w/ lower MSE & with fewer features/datapoint.

As a sneak peek (the graph should make more sense as we build up to it, don't worry!):

The L1(sqrt()) curve (ie the dotted one) sits closer to the lower-left corner (this is good!), meaning lower Cross Entropy (CE) loss with fewer features/datapoint (ie a lower L0 norm).

Now in more detail:

Sparse Autoencoders (SAEs) reconstruct each datapoint in [layer 3's residual stream activations of Pythia-70M-deduped] using some number of features (this number is the L0 norm of the SAE's hidden activation). Typically the higher activations are interpretable & the lowest activations are uninterpretable.
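For readers new to SAEs, here's a minimal sketch of what "features per datapoint" means: encode an activation vector into feature activations, decode it back, and count the non-zero features (the L0 norm). It assumes the common ReLU SAE form f = ReLU(W_enc(x - b_dec) + b_enc), x_hat = W_dec^T f + b_dec; the names `W_enc`, `b_enc`, `W_dec`, `b_dec` and this exact parameterization are assumptions, not necessarily what our SAEs use. Plain Python lists keep it dependency-free.

```python
def relu(v):
    """Elementwise ReLU on a list of floats."""
    return [max(0.0, a) for a in v]


def encode(x, W_enc, b_enc, b_dec):
    """Feature activations f = ReLU(W_enc @ (x - b_dec) + b_enc).

    W_enc has one row per feature (n_features x d_model).
    """
    centered = [a - b for a, b in zip(x, b_dec)]
    pre = [sum(w * c for w, c in zip(row, centered)) + b
           for row, b in zip(W_enc, b_enc)]
    return relu(pre)


def decode(f, W_dec, b_dec):
    """Reconstruction x_hat = sum_i f_i * W_dec[i] + b_dec.

    W_dec has one row (decoder direction) per feature (n_features x d_model).
    """
    x_hat = list(b_dec)
    for f_i, direction in zip(f, W_dec):
        x_hat = [a + f_i * d for a, d in zip(x_hat, direction)]
    return x_hat


def l0_norm(f):
    """Number of features active (non-zero) on this datapoint."""
    return sum(1 for f_i in f if f_i != 0.0)


# Toy SAE: d_model = 2, n_features = 3 (hypothetical weights, not a trained SAE)
W_enc = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b_enc = [0.0, 0.0, 0.0]
W_dec = [[1.0, 0.0], [0.0, 1.0], [-0.7071, -0.7071]]
b_dec = [0.0, 0.0]

x = [2.0, 0.5]
f = encode(x, W_enc, b_enc, b_dec)
print(f, l0_norm(f), decode(f, W_dec, b_dec))
```

Here the third feature's pre-activation is negative, so the ReLU zeroes it and this datapoint is reconstructed with 2 features.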

 

This is a histogram of the specific tokens that activated a specific feature. If you run 1M tokens through the LLM, 1000 of them activate this feature. Of those 1000, ~30 activate between 5.53 & 6.15, and all of them are the apostrophe token (aka ').

This feature activates mostly on apostrophes (removing it also makes the model worse at predicting "s"). The lower activations are conceptually similar, but then there's a huge number of tokens that are something else entirely.

From a datapoint viewpoint, there's a similar story: for a given datapoint, the top-activating features make a lot of sense, but the lowest ones don't (ie if 20 features activate to reconstruct a specific datapoint, the top ~5 features make a decent amount of sense & the lower 15 make less and less sense).

Are these low-activating features actually important for downstream performance (eg CE)? Or are they modeling noise in the underlying LLM (which is why we see conceptually similar datapoints in lower activation points)?

Ablating Lowest Features

There are a few different ways to remove the "lowest" feature activations.

Dataset View:

  1. Lowest k-features per datapoint

Feature View: Features have different activation scales. Some are an order of magnitude (OOM) larger than others on average, so we can set feature-specific thresholds.

  1. Percentage of max activation - remove all feature activations that are < [10%] of max activation for that feature
  2. Quantile - Remove all activations in the bottom [10th] percentile of each feature's activations
  3. Global Threshold - Let's treat all features the same. Set all feature activations less than [0.1] to 0.
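The four ablation methods above can be sketched as follows, assuming feature activations are stored as a list of datapoints, each a list of non-negative activations (one per feature). The function names are hypothetical, and computing the quantile over each feature's non-zero activations only is our assumption (since most activations are 0, a quantile over everything would usually just be 0).

```python
def lowest_k_per_datapoint(acts, k):
    """Dataset view: zero out the k smallest active features of each datapoint."""
    out = []
    for row in acts:
        active = sorted((v, i) for i, v in enumerate(row) if v > 0.0)
        drop = {i for _, i in active[:k]}
        out.append([0.0 if i in drop else v for i, v in enumerate(row)])
    return out


def _per_feature(acts):
    """Transpose datapoints x features into one list of activations per feature."""
    return [list(col) for col in zip(*acts)]


def _apply_thresholds(acts, thresholds):
    """Zero every activation below its own feature's threshold."""
    return [[v if v >= t else 0.0 for v, t in zip(row, thresholds)] for row in acts]


def percent_of_max(acts, frac):
    """Feature view: drop activations < frac * (that feature's max activation)."""
    return _apply_thresholds(acts, [frac * max(col) for col in _per_feature(acts)])


def quantile_per_feature(acts, q):
    """Feature view: drop each feature's bottom-q fraction of non-zero activations."""
    thresholds = []
    for col in _per_feature(acts):
        nz = sorted(v for v in col if v > 0.0)
        thresholds.append(nz[min(int(q * len(nz)), len(nz) - 1)] if nz else 0.0)
    return _apply_thresholds(acts, thresholds)


def global_threshold(acts, t):
    """Treat all features the same: drop every activation < t."""
    return [[v if v >= t else 0.0 for v in row] for row in acts]


# 3 datapoints x 3 features (made-up numbers for illustration)
acts = [[0.04, 2.0, 0.5],
        [0.6, 0.0, 3.0],
        [0.2, 1.0, 0.07]]
print(global_threshold(acts, 0.1))
```

The global threshold is the simplest of the four: a single scalar comparison, with no per-feature statistics to compute first.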

It turns out that the simple global threshold performs the best:

MSE is red on the left & CE is blue on the right. The bottom-right corner is the default performance with no features removed; the top-left is with all features removed. The original model's CE is 4.02 on this dataset.

[Note: "CE" refers to the CE when you replace [layer 3 residual stream]'s activations with the reconstruction from the SAE. Ultimately we want the original model's CE with the smallest number of features per datapoint (L0 norm).]

You can halve the L0 w/ a small (~0.08) increase in CE. Sadly, both MSE & CE increase. If MSE had gone up while CE stayed the same, that would have supported the hypothesis that the SAE is modeling noise at lower activations (ie noise that's important for MSE/reconstruction but not for CE/downstream performance). Instead, these lower activations matter similarly for both MSE & CE.

For completeness' sake, here's a messy graph w/ all 4 methods:

 

[Note: this was run on a different SAE than the other images]

There may be more sophisticated methods that take feature-level information into account (such as whether it's an outlier feature, or its feature frequency), but we'll stick w/ the global threshold for the rest of the post.

Sweeping Across SAEs with Different L0s

You can get wildly different L0s just by sweeping the weight on the L1 penalty term: a higher L0 improves reconstruction, but at the cost of more (potentially polysemantic) features per datapoint. Does the above phenomenon extend to SAEs w/ different L0s?
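For concreteness, here's the kind of objective being swept, for one datapoint x with SAE feature activations f_i(x) and reconstruction x_hat. This is a sketch under assumptions: a squared-error reconstruction term, a penalty weight λ (the "weight on the L1 penalty term"), and no extra terms such as decoder-norm weighting.

```latex
\mathcal{L}_{\mathrm{L1}}(x) = \lVert x - \hat{x} \rVert_2^2 + \lambda \sum_i f_i(x),
\qquad
\mathcal{L}_{\mathrm{sqrt}}(x) = \lVert x - \hat{x} \rVert_2^2 + \lambda \sum_i \sqrt{f_i(x)}
```

Raising λ tends to push more f_i(x) to exactly 0 (a lower L0) at the cost of reconstruction; the second form is the L1(sqrt()) variant from the TL;DR.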

We plot 15 SAEs trained with different L1 coefficients, then sweep the global threshold to remove the lowest features for each. The right-most point of each curve is the CE of that SAE's reconstruction with no features removed.
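A sketch of one such sweep, assuming you already have the original activations `X` and the SAE's feature activations `F` for them. It only tracks mean L0 and MSE; getting CE would additionally require splicing each reconstruction back into layer 3 of the LLM, which isn't shown. The function name and toy decoder are hypothetical.

```python
def decode(f, W_dec, b_dec):
    """x_hat = sum_i f_i * W_dec[i] + b_dec (one decoder row per feature)."""
    x_hat = list(b_dec)
    for f_i, direction in zip(f, W_dec):
        x_hat = [a + f_i * d for a, d in zip(x_hat, direction)]
    return x_hat


def sweep_global_threshold(X, F, W_dec, b_dec, thresholds):
    """For each threshold t, zero activations < t and report (t, mean L0, MSE).

    X: original activations, F: the SAE's feature activations for X.
    """
    results = []
    for t in thresholds:
        total_l0, total_sq_err = 0, 0.0
        for x, f in zip(X, F):
            f_t = [v if v >= t else 0.0 for v in f]
            x_hat = decode(f_t, W_dec, b_dec)
            total_l0 += sum(1 for v in f_t if v != 0.0)
            total_sq_err += sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
        results.append((t, total_l0 / len(X), total_sq_err / len(X)))
    return results


# Toy example with an identity decoder (hypothetical; a real SAE has n_features >> d_model)
W_dec = [[1.0, 0.0], [0.0, 1.0]]
b_dec = [0.0, 0.0]
X = [[1.0, 0.05], [0.5, 0.5]]
F = [[1.0, 0.05], [0.5, 0.5]]
for t, mean_l0, mse in sweep_global_threshold(X, F, W_dec, b_dec, [0.0, 0.1]):
    print(t, mean_l0, mse)
```

Each threshold gives one point on a curve like the ones above; running this for every SAE in the L1 sweep gives the whole family.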

Looks like it does, & the models seem to follow a Pareto frontier.
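"Pareto frontier" here means the (L0, CE) points that no other point beats on both axes at once. A minimal sketch of extracting it, with lower being better for both; the numbers below are made up for illustration, not results from this post.

```python
def pareto_frontier(points):
    """Keep the (L0, CE) points that no other point beats on both axes.

    Lower is better for both L0 and CE. After sorting by L0 (then CE), a point
    is kept iff its CE is strictly below that of every point seen before it.
    """
    frontier = []
    best_ce = float("inf")
    for l0, ce in sorted(points):
        if ce < best_ce:
            frontier.append((l0, ce))
            best_ce = ce
    return frontier


# Made-up (L0, CE) points for illustration
points = [(10, 4.6), (20, 4.4), (15, 4.7), (30, 4.35), (25, 4.5)]
print(pareto_frontier(points))
```

(15, 4.7) and (25, 4.5) are dropped because a point with lower L0 already has lower CE.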

Using L1(sqrt(feature_activation))

@Lucia Quirke trained SAEs with L1(sqrt(feature_activations)) (this punishes smaller activations more & larger activations less) and anecdotally noticed fewer of these small, uninterpretable features. Maybe this method can get us a better L0 for the same CE?
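To see what "punishes smaller activations more & larger activations less" means numerically, here's a sketch comparing the two penalties and their per-feature gradients. The penalty weight is left out, and the function names are hypothetical.

```python
import math


def l1_penalty(f):
    """Plain L1: sum of activations (SAE activations are non-negative after ReLU)."""
    return sum(f)


def sqrt_penalty(f):
    """L1(sqrt()): sum of square roots of the activations."""
    return sum(math.sqrt(v) for v in f)


def per_feature_grads(v):
    """d(penalty)/d(activation) for one active feature v > 0, as (L1, sqrt)."""
    return 1.0, 0.5 / math.sqrt(v)


for v in [0.01, 0.25, 1.0, 4.0]:
    g_l1, g_sqrt = per_feature_grads(v)
    print(f"v={v}: penalty L1={v} sqrt={math.sqrt(v):.3g}  grad L1={g_l1} sqrt={g_sqrt:.3g}")
```

In value, sqrt charges more than L1 for any activation below 1 (0.1 vs 0.01 at v=0.01). Its gradient, ie how hard training pushes an activation toward 0, exceeds L1's only below 0.25 and grows without bound as v approaches 0, which is why the gradient is only evaluated for active features here.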

In fact it does! Whoooo! Though it becomes Pareto-worse around SAEs with an L0 norm of ~15.

But in what way is this helping? We can break the loss into a couple of components:

  1. Cos sim: when reconstructing an activation, how well does the SAE approximate its direction? ie a cos sim of 1 means the reconstruction points in the same direction, though it might not have the same magnitude
  2. L2 Ratio: how well does the SAE reconstruct the magnitude of the original activation?
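Both components are easy to compute per datapoint. A minimal sketch, assuming the L2 ratio is defined as ||x_hat|| / ||x|| (our assumption about the exact definition) and that neither vector is all zeros:

```python
import math


def l2_norm(v):
    return math.sqrt(sum(a * a for a in v))


def cos_sim(x, x_hat):
    """Direction match: 1.0 means x_hat points exactly along x (norms must be non-zero)."""
    return sum(a * b for a, b in zip(x, x_hat)) / (l2_norm(x) * l2_norm(x_hat))


def l2_ratio(x, x_hat):
    """Magnitude match: ||x_hat|| / ||x||; below 1.0 means the SAE undershoots the norm."""
    return l2_norm(x_hat) / l2_norm(x)


x = [3.0, 4.0]
x_hat = [1.5, 2.0]  # same direction as x, half the length
print(cos_sim(x, x_hat), l2_ratio(x, x_hat))
```

This toy reconstruction has a perfect cos sim but an L2 ratio of 0.5, so MSE alone would hide which of the two went wrong.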

Most strikingly, it helps with the L2 Ratio significantly:

This makes sense because high feature activations are punished less under sqrt() than under plain L1. If you have higher-activating features, they can reconstruct the norm better even with the decoder normalization constraint. Looking at Cos sim & MSE, we see similar improvements.
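One way to see the decoder-normalization argument, assuming each decoder direction d_i has unit norm (the usual constraint) and writing the reconstruction as x_hat = b_dec + Σ_i f_i d_i:

```latex
\lVert \hat{x} - b_{\mathrm{dec}} \rVert_2
  = \Big\lVert \sum_i f_i d_i \Big\rVert_2
  \le \sum_i f_i \lVert d_i \rVert_2
  = \sum_i f_i
```

So under plain L1, producing a reconstruction of norm r (relative to b_dec) costs a penalty of at least r, which plausibly pushes the SAE to undershoot the norm. The sqrt penalty charges Σ_i √f_i instead, and √f_i < f_i whenever f_i > 1, so the large activations that carry most of the norm are cheaper.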

What about L1(sqrt(sqrt()))?

If sqrt-ing helped, why not do more? 

It got worse than just one sqrt, lol.

Looking at the graph of these loss terms:

Sqrt-ing does penalize low (<1.0) feature activations more & high (>1.0) feature activations less than the usual L1 penalty. But the crossover point at 1 & the shape of sqrt are arbitrary with respect to the actual feature activations (ie other crossover points/shapes may perform better).
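The comparison can be made exact for a general power penalty x^p on an active feature x > 0, with 0 < p < 1 (p = 1/2 is sqrt(), p = 1/4 is sqrt(sqrt())):

```latex
x^{p} > x \iff x^{p-1} > 1 \iff x < 1,
\qquad
\frac{d}{dx}\, x^{p} = p\, x^{p-1} > 1 = \frac{d}{dx}\, x \iff x < p^{1/(1-p)}
```

So the crossover in penalty value is at 1 for every p. The gradient, which is what actually pushes activations toward 0 during training, is larger than L1's only below p^{1/(1-p)}: 1/4 for sqrt() and about 0.157 for sqrt(sqrt()).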

For example, look at a few datapoint activations:

We plot a histogram of the first 4000 datapoints' feature activations for every feature (ie after encoding with the SAE).
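A sketch of how such a histogram and median can be computed from the encoded activations `F` (datapoints x features). We assume the median is taken over non-zero activations only: since each datapoint activates only a handful of features, a median over all entries would just be 0. The helper names are hypothetical.

```python
import statistics


def active_values(F):
    """Flatten every non-zero feature activation across all datapoints."""
    return [v for row in F for v in row if v > 0.0]


def histogram(values, n_bins, lo, hi):
    """Counts in n_bins equal-width bins over [lo, hi); values outside are skipped."""
    counts = [0] * n_bins
    width = (hi - lo) / n_bins
    for v in values:
        if lo <= v < hi:
            counts[min(int((v - lo) / width), n_bins - 1)] += 1
    return counts


# 2 datapoints x 3 features (made-up numbers); in the post F would be the first 4000 datapoints
F = [[0.1, 0.0, 0.5],
     [0.3, 0.9, 0.0]]
vals = active_values(F)
print(histogram(vals, n_bins=2, lo=0.0, hi=1.0), statistics.median(vals))
```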

We do see a right shift in activations for L1(sqrt()), as expected. But the median activation is only 0.28 & 0.43 for L1 & L1(sqrt()) respectively, well below the crossover at 1. Let's say we want to punish values around 0.01 more, but once they reach 0.1, we don't want to punish as much. We could then use sqrt(0.1*x).

We could also sweep this [0.1] value as a parameter (while trying to maintain a similar L0, which is tricky), as well as try exponents between x^1 (plain L1) & x^0.25 (ie sqrt(sqrt())).
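Both knobs fit in one family of penalties, (c·x)^p. Here's a sketch, with hypothetical function names, that also computes where each member crosses plain L1:

```python
def power_penalty(f, c=1.0, p=0.5):
    """Sum of (c * v) ** p over active features.

    c=1, p=0.5 -> L1(sqrt()); c=0.1, p=0.5 -> the sqrt(0.1*x) idea;
    c=1, p=0.25 -> sqrt(sqrt()); c=1, p=1 -> plain L1.
    """
    return sum((c * v) ** p for v in f if v > 0.0)


def crossover(c, p):
    """Activation x > 0 where (c*x)**p == x, i.e. where the penalty equals plain L1.

    Solving c**p * x**p = x gives x = c ** (p / (1 - p)); requires 0 < p < 1.
    Below this point the power penalty is larger than L1, above it smaller.
    """
    return c ** (p / (1.0 - p))


print(crossover(0.1, 0.5))   # sqrt(0.1*x) crosses plain L1 at x = 0.1
print(crossover(1.0, 0.25))  # sqrt(sqrt(x)) crosses plain L1 at x = 1
```

One algebraic point to keep in mind when sweeping c: (c·x)^p = c^p · x^p, so c only rescales the whole penalty by a constant. If the penalty weight is re-tuned to hit the same L0 anyway, changing c on its own is equivalent to changing that weight; the exponent p is what changes the penalty's shape.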

Feature Frequency/Density Plots

Anthropic noticed that very low frequency feature activations were usually uninterpretable (note: they say low "density" where I say low "frequency").

We do see more low-frequency features for L1(sqrt()) here. 

[Note: I chose 1e-5 as the threshold based on the feature frequency histograms in the appendix]
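A sketch of how feature frequency and the low-frequency count can be computed from encoded activations `F` (tokens x features); the helper names are hypothetical. One caveat if the frequencies come from the 100k tokens listed in the appendix: the smallest non-zero frequency is then 1/100,000 = 1e-5, so a strict "< 1e-5" cutoff only counts features that never fired, and whether the comparison is < or <= matters at that sample size.

```python
def feature_frequencies(F):
    """Fraction of datapoints (tokens) on which each feature is active."""
    counts = [0] * len(F[0])
    for row in F:
        for i, v in enumerate(row):
            if v > 0.0:
                counts[i] += 1
    return [c / len(F) for c in counts]


def count_low_frequency(freqs, threshold=1e-5):
    """Number of features active on fewer than `threshold` of tokens (dead ones included)."""
    return sum(1 for fr in freqs if fr < threshold)


# 4 tokens x 3 features (made-up numbers)
F = [[1.0, 0.0, 0.0],
     [0.5, 0.0, 0.2],
     [0.0, 0.0, 0.0],
     [2.0, 0.0, 0.0]]
freqs = feature_frequencies(F)
print(freqs, count_low_frequency(freqs, threshold=0.3))
```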

Future Work

It'd be good to investigate a better L1 penalty than L1(sqrt(x)). This can be done empirically by throwing lots of L1 loss terms at the wall, or there may be a more analytical solution. Let me know if you have any ideas! Comments and DMs are welcome.

We also share intermediate results & discussion in the EleutherAI Sparse-coding channel; feel welcome to join the discussion!

Invite to the server  and link to our project thread

Appendix

  • Feature statistics were calculated over 100k tokens
  • Model is Pythia-70M-deduped
  • SAE is trained on 400M tokens of the Pile and has 4x as many features as there are dimensions in each datapoint. We used neither ghost grads nor neuron resampling.
  • CE is high (ie 5.0 instead of 3.5) because Pythia-70M-deduped has worse CE on text that starts in the middle of a document, which is how most of our text is formatted.

Feature Freq histograms
