Browse these SAE Features on Neuronpedia

Update 1: Since we posted this last night, someone pointed out that our implementation of ghost grads has a non-trivial error (which makes the results a-priori quite surprising). We computed the ghost grad forward pass using Exp(Relu(W_enc(x)[dead_neuron_mask])) rather than Exp((W_enc(x)[dead_neuron_mask])). I'm running some ablation experiments now to get to the bottom of this.     

Update 2: I've since investigated this further and run some ablation studies. Ghost grads weren't working as intended due to the Exp(Relu(x)) bug, but the resulting SAEs were still quite good (later layers had few dead neurons simply because dropping the number of features gives you fewer dead neurons). I've found that with a correct ghost grads implementation you can get fewer dead neurons, and I will update the library shortly. (I will make edits to the rest of this post to reflect my current views.) Sorry for the confusion.
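A small illustration of why the Exp(Relu(x)) version defeats the purpose of the ghost pass, assuming a dead feature is a ReLU feature whose pre-activation is never positive. The function names below are illustrative and are not taken from the library; a finite difference stands in for autograd.

```python
import math

def relu(x: float) -> float:
    return max(x, 0.0)

def ghost_act_buggy(pre: float) -> float:
    # Exp(Relu(x)): the form the update above says was implemented by mistake.
    return math.exp(relu(pre))

def ghost_act_intended(pre: float) -> float:
    # Exp(x): the form the update above says was intended.
    return math.exp(pre)

def numerical_grad(f, x: float, h: float = 1e-6) -> float:
    # Central finite difference, standing in for autograd.
    return (f(x + h) - f(x - h)) / (2 * h)

# Hypothetical pre-activations of a dead feature (all negative).
dead_pre_activations = [-3.0, -1.5, -0.2]

buggy_grads = [numerical_grad(ghost_act_buggy, p) for p in dead_pre_activations]
intended_grads = [numerical_grad(ghost_act_intended, p) for p in dead_pre_activations]

print(buggy_grads)     # zeros: nothing flows back to the pre-activation
print(intended_grads)  # positive, approximately exp(pre)
```

With the ReLU inside the exponential, every dead feature outputs exp(0) = 1 regardless of its input, so no gradient reaches the pre-activation (and hence the encoder) through the ghost pass.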

This work was produced as part of the ML Alignment & Theory Scholars Program - Winter 2023-24 Cohort, under mentorship from Neel Nanda and Arthur Conmy. Funding for this work was provided by the Manifund Regranting Program and donors as well as LightSpeed Grants. 

This is intended to be a fairly informal post sharing a set of Sparse Autoencoders trained on the residual stream of GPT2-small which achieve fairly good reconstruction performance and contain fairly sparse / interpretable features. More importantly, advice from Anthropic and community members has enabled us to train these considerably more efficiently / faster than before. The specific methods that were most useful were: ghost gradients, learning rate warmup, and initializing the decoder bias with the geometric median. We discuss each of these in more detail below.

Feature found in Residual Stream (Pre) Layer 10 of GPT2 Small. The force is strong with this one. 

5 Minute Summary

We’re publishing a set of 12 Sparse AutoEncoders for the GPT2 Small residual stream. 

  • These dictionaries have approximately 25,000 features each, with very few dead features (mainly in the early layers) and high quality reconstruction (log loss when the activations are replaced with the output is 3.3 - 3.6 as compared with 3.3 normally). 
  • The L0’s range from 5 in the first layer to 70 in the 9th SAE (increasing by about 5-10 per layer and dropping in the last two layers).
  • By choosing a fixed dictionary size, we can see how statistics like the number of dead features or reconstruction cross entropy loss change with layer giving some indication of how properties of the feature distribution change with layer depth. 
  • We haven’t yet extensively analyzed these dictionaries, but will share the automatically generated dashboards.

Readers can access the Sparse Autoencoder weights in this HuggingFace Repo. Training code and code for loading the weights / model and data loaders can be found in this Github Repository. Training curves and feature dashboards can also be found in this wandb report. Users can download all 25k feature dashboards generated for the layer 2 and 10 SAEs and the first 5000 of the layer 5 SAE features here (note the left-hand column of the dashboards should currently be ignored).

| Layer | Variance Explained | L1 loss | L0* | % Alive Features | Reconstruction CE Log Loss |
|---|---|---|---|---|---|
| 0 | 99.15% | 4.58 | 12.24 | 80.0% | 3.32 |
| 1 | 98.37% | 41.04 | 14.68 | 83.4% | 3.33 |
| 2 | 98.07% | 51.88 | 18.80 | 80.0% | 3.37 |
| 3 | 96.97% | 74.96 | 25.75 | 86.3% | 3.48 |
| 4 | 95.77% | 90.23 | 33.14 | 97.7% | 3.44 |
| 5 | 94.90% | 108.59 | 43.61 | 99.7% | 3.45 |
| 6 | 93.90% | 136.07 | 49.68 | 100% | 3.44 |
| 7 | 93.08% | 138.05 | 57.29 | 100% | 3.45 |
| 8 | 92.57% | 167.35 | 65.47 | 100% | 3.45 |
| 9 | 92.05% | 198.42 | 71.10 | 100% | 3.45 |
| 10 | 91.12% | 215.11 | 53.79 | 100% | 3.52 |
| 11 | 93.30% | 270.13 | 59.16 | 100% | 3.57 |
| Original Model | | | | | 3.3 |

Summary Statistics for GPT2 Small Residual Stream SAEs. *L0 = Average number of features firing per token.

Training SAEs that we were happy with used to take much longer than it does now. Last week, it took me 20 hours to train a 50k feature SAE on 1 billion tokens, and over the weekend it took us 3 hours to train a 25k feature SAE on 300M tokens with similar variance explained, L0 and CE loss recovered.

We attribute the improvement to having implemented various pieces of advice that have made our lives a lot easier:

  • Ghost Gradients / Avoiding Resampling:  Prior to ghost gradients (which we were made aware of last week in the Anthropic January Update), we were training SAEs with approximately 50k features on 1 billion tokens with 3 resampling events to reduce the number of dead features. This took around 20 hours and might cost about $10 with an A6000 GPU. With ghost gradients, we don’t need to resample (or wait for loss curves to plateau after resampling). Now we can train on only 300M tokens instead. Simultaneously, since we now have very few dead features, we can make the SAE smaller which speeds up the training time significantly. Edit: It turns out that ghost gradients do reduce dead neurons but also dropping the number of features can have this effect. When I dropped the number of features after I thought ghost grads were working, I obscured the buggy implementation. 
  • Using a learning rate warmup: Before implementing ghost gradients, Arthur found that using a learning rate warmup at the beginning of training also kept alive features which are otherwise killed off early in training. Using this warm-up leads to fewer dead features, enabling us to train smaller SAEs. Since ghost gradients don’t kick in until we’re sure a feature is dead (5000 steps where it didn’t fire), it seems like a learning rate warm-up is likely worth keeping around.
  • Initializing the decoder bias at the Geometric Median. Another improvement that made a significant impact on our results was initializing the decoder bias at the Geometric Median (as recommended by Anthropic). This seemed to help avoid dense/uninterpretable features which had caused deceptively good looking statistics in previous training runs. Anecdotally, we did notice that resampling helped eliminate these features and that initializing at the mean was possibly just as good as initializing with the geometric median. 
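As a rough sketch of the learning rate warmup mentioned in the list above: the post does not say which schedule shape or length was used, so the linear ramp and the 1000-step duration here are assumptions for illustration only. The base learning rate of 4e-4 is the one reported later in the post.

```python
def warmup_lr(step: int, base_lr: float = 4e-4, warmup_steps: int = 1000) -> float:
    """Linearly ramp the learning rate up to base_lr, then hold it constant.

    warmup_steps is a hypothetical value, not one taken from the post.
    """
    if warmup_steps <= 0:
        return base_lr
    scale = min(1.0, (step + 1) / warmup_steps)
    return base_lr * scale

for step in (0, 499, 999, 5000):
    print(step, warmup_lr(step))
```

The idea is simply that early updates are small, so features are less likely to be pushed into a permanently inactive region before the encoder has settled.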

While we haven’t tested our code extensively since implementing these improvements, we suspect that hyperparameter tuning may be easier in the future since these method improvements make the process generally less sensitive. 

To demonstrate the interpretability of these SAEs, we share screenshots of feature dashboards we produced using a reproduction of Anthropic’s dashboard developed by Callum McDougall. 

Finally, we end by discussing how readers can access these SAE’s, some experiments which they could perform to upskill with SAEs and possible research directions to pursue. 

Introduction

What are Sparse AutoEncoders and why should we care about them?

Sparse autoencoders (SAEs) are an unsupervised technique to take a model's activations and decompose them into interpretable feature vectors. We highly recommend this tutorial on SAEs for those interested. Recent papers on the topic can be found here and here.

I’m particularly excited about Sparse Autoencoders for two reasons:

  1. SAEs are an unsupervised technique that might help us understand how model internals work and may be robust to us being wrong about how models think about the world. Understanding model internals in detail could help with lots of alignment proposals such as Eliciting Latent Knowledge, Mechanistic Anomaly Detection and Retargeting the Search.
  2. SAEs represent a major scientific breakthrough. As a former computational biologist, they seem to be analogous to DNA or mRNA sequencing. A huge proportion of modern biology is informatic, and operates based on information we have about genes, gene products (mRNA and proteins) and databases which relate these internal components to variables we care about like cancer and other pathologies. Knowing an organism's genome doesn’t immediately give you the ability to arbitrarily intervene on internals or cure all disease, though it’s an important start! Moreover, there are classes of therapies which you would just never arrive at without this level of insight into cells (such as immuno-cancer therapies). Bringing it back to neural networks, it’s possible that there are some forms of AI misalignment which are tractable with insight into model internals, but are intractable otherwise!

General Advice for Training SAEs

Why can training Sparse AutoEncoders be difficult? 

Sparse autoencoders are an unsupervised method which attempts to trade off reconstruction accuracy against interpretability, which we achieve by inducing activation sparsity. Since we don’t have good metrics for interpretability / reconstruction quality, it’s hard to know when we are actually optimizing what we care about. On top of this, we’re trying to pick a good point on the pareto frontier between interpretability and reconstruction quality which is a hard thing to assess well. 

The main objective is to have your Sparse Autoencoder learn a population of sparse features (which are likely to be interpretable) without having some dense features (features which activate all the time and are likely uninterpretable) or too many dead features (features which never fire). As discussed in the 5 minute summary, we went from training GPT2 small residual stream SAEs in 12+ hours to ~ 3 hours (so 4x faster / cheaper). Though the L0 and CE loss were somewhat similar, the feature density histograms also suggested we avoided dead / dense features way more effectively after using ghost gradients.

Let’s dig into the challenges associated with dead / dense features a bit more:

  • Too many dead features. Dead features don’t receive gradients and so represent permanently lost capacity in your SAE. They make training slower and more expensive and are present in almost everyone's SAEs. 
    • One solution is to use a resampling strategy but this results in much longer training times. Resampling (if it's enough features) causes your loss curves to go a bit nuts and, though it's useful, means you gotta wait ages between resampling and between the last resample and ending training.
    • Another solution is just use a larger dictionary (but then you are wasting compute during training). 
  • Dense features. One issue you can get when hyperparameter tuning SAEs is that you get dense features (which fire on > 1/100 or 1/10 tokens). These features seem generally uninterpretable, but help with your reconstruction immensely. 
    • Previously it felt like a bit of an art form to avoid both dense features and dead features simultaneously, since reducing the sparsity penalty kills fewer features but encourages dense features.
  • Reading Feature Density Histograms: Feature density histograms are a good measure of SAE quality. We plot the log10 feature sparsity (how often it fires) for all features. In order to make this easier to operationalize, I’ve drawn a diagram that captures my sense of the issues these histograms help you diagnose. Feature density histograms can be broken down into:
    • Too Dense:  dense features will occur at a frequency > 1 / 100. Some dense-ish features are likely fine (such as a feature representing that a token begins with a space) but too many is likely an issue.
    • Too Sparse: Dead features won’t be sampled so will turn up at log10(epsilon), for epsilon added to avoid logging 0 numbers. Too many of these mean you’re over penalizing with L1. 
    • Just-Right: Without too many dead or dense features, we see a distribution that has most mass between -5 or -4 and -3 log10 feature sparsity. The exact range can vary depending on the model / SAE size but the dense or dead features tend to stick out. 
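The histogram reading described above can be sketched in a few lines. This is a minimal illustration rather than the code used for the plots in the post: the epsilon value, the "dead" cutoff and the toy firing counts are all assumptions, and the dense cutoff of 1/100 is taken from the text.

```python
import math

def log10_feature_density(fire_counts, n_tokens, eps=1e-10):
    """log10 of how often each feature fires; eps avoids log10(0) for dead features."""
    return [math.log10(c / n_tokens + eps) for c in fire_counts]

def classify(log_density, dense_threshold=-2.0, dead_threshold=-9.0):
    # dense: fires on more than 1/100 tokens (log10 density above -2).
    # dead: never fired, so it sits at roughly log10(eps); the -9 cutoff is illustrative.
    if log_density > dense_threshold:
        return "dense"
    if log_density < dead_threshold:
        return "dead"
    return "ok"

fire_counts = [0, 12, 350, 40_000]  # hypothetical counts over 1M tokens
densities = log10_feature_density(fire_counts, n_tokens=1_000_000)
labels = [classify(d) for d in densities]
print(list(zip(fire_counts, labels)))
```

Counting how many features land in each bucket gives a quick numeric summary to track alongside the histogram itself.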

Which tricks help the most?

In terms of solutions, Anthropic published useful advice (especially ghost gradients) and the research community is building consensus on how to train SAEs well (what your loss curves and final statistics should look like for example). I found Arthur Conmy’s post very useful.

The top 3 changes that I made which led to my ability to train these SAEs cheaply/quickly were:

  • Ghost Gradients / Avoiding Resampling: Ghost gradients work by adding a loss term to your overall loss which causes gradient updates to dead features in the direction of fitting the error term (the part of the activation vector your SAE is not fitting). In other words, we take our spare capacity that isn’t being used and point it at our error. This creates a kind of “buoyancy” around feature density which continuously resuscitates features.
    • Importantly, this means you can train a smaller sparse autoencoder, which has the same number of alive features, increasing the speed of your training.
    • Simultaneously, the resampling strategies we used previously were a little like restarting training from scratch if you had lots of dead features (L1 and MSE loss spike after resampling and it takes time to “re-equilibrate”). The total number of training tokens is therefore much lower with ghost grads and no resampling.
  • Using a learning rate warmup: This helps with avoiding dead features early in training. I haven’t done an ablation study with this / ghost grads but before I was using ghost grads this made a fairly large difference to the feature density histograms. 
  • Initializing the decoder bias with an estimate of the geometric median of the activations. This seems to help avoid dense/dead features.
    • Anecdotally, I calculated the distance between the geometric median of ~250k randomly sampled activations and the final decoder bias in some of my autoencoders and they were fairly close (closer than the mean of those activations!) and so this seems like a straightforward case of “if you can guess a parameter before training, initialize close to it for better results”. 
    • Note, Anthropic had suggested initializing the decoder bias this way from the start and we simply hadn’t implemented it. In retrospect it seems obvious, but I think there were many other hypotheses we’d considered at the time for what we should do to improve.
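For readers unfamiliar with the geometric median, here is a minimal sketch of one standard way to estimate it, Weiszfeld's iteration. The post does not say which estimator was used, so this is an assumption chosen for illustration, and the toy "activations" are made up.

```python
import math

def geometric_median(points, n_iter=100, tol=1e-9):
    """Weiszfeld iteration: approximates the point minimizing summed Euclidean distance."""
    dim = len(points[0])
    # Start from the coordinate-wise mean.
    guess = [sum(p[d] for p in points) / len(points) for d in range(dim)]
    for _ in range(n_iter):
        # Each point is weighted by the inverse of its distance to the current guess.
        weights = [1.0 / max(math.dist(p, guess), 1e-12) for p in points]
        total = sum(weights)
        new_guess = [
            sum(w * p[d] for w, p in zip(weights, points)) / total
            for d in range(dim)
        ]
        if math.dist(new_guess, guess) < tol:
            return new_guess
        guess = new_guess
    return guess

# Toy "activations": three clustered points plus one outlier.
acts = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [100.0, 100.0]]
mean = [sum(p[d] for p in acts) / len(acts) for d in range(2)]
median = geometric_median(acts)
print(mean)    # pulled far towards the outlier
print(median)  # stays near the cluster
```

The contrast between the two printed vectors is the reason one might prefer the geometric median over the mean as a starting point for the decoder bias, although the post notes that in practice the mean may have been just as good.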

Sparse AutoEncoders for the GPT2 Residual Stream

Why GPT2 small? Why the residual stream?

GPT2 small has been extensively studied by the mechanistic interpretability community and whilst not an incredibly performant model, it certainly has some kind of “prototypical object of study” property. We chose the residual stream because this enables us to analyze “the total sum of previous output” in a manner not dissimilar to the logit lens approach. This may be useful for understanding how features are constructed from earlier features as well as studying how the distribution of features over time changes in a model. 

Architecture and Hyperparameters

We trained 12 Sparse Autoencoders on the Residual Stream of GPT2-small. 

  • Each of these contains ~ 25k features as we used an expansion factor of 32 and the residual stream of GPT2 small has 768 dimensions.
  • We trained with an L1 coefficient of 8e-5, a learning rate of 4e-4 for 300 Million tokens  from OpenWebText.
  • We store activations in a buffer of ~500k tokens which is refilled and shuffled whenever 50% of the tokens are used (ie: Neel’s approach). 
  • To avoid dead features, we use ghost gradients. 
  • Our encoder/decoder weights are untied but we do use a tied decoder bias initialized at the geometric median per Bricken et al. 
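The hyperparameters listed above can be collected into a small config object. The values are the ones stated in the post; the class and field names are hypothetical and do not necessarily match the training library.

```python
from dataclasses import dataclass

@dataclass
class SAEConfig:
    # Values from the post; field names are illustrative only.
    d_in: int = 768                  # GPT2-small residual stream width
    expansion_factor: int = 32
    l1_coefficient: float = 8e-5
    lr: float = 4e-4
    total_training_tokens: int = 300_000_000
    buffer_tokens: int = 500_000     # approximate activation buffer size
    buffer_refill_fraction: float = 0.5  # refill and shuffle once half is used

    @property
    def d_sae(self) -> int:
        # Number of dictionary features.
        return self.d_in * self.expansion_factor

cfg = SAEConfig()
print(cfg.d_sae)  # 24576, i.e. the "~25k" features per layer
```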

Were I to be training SAEs on another model or part of the same model, I wouldn’t change any of the architectural choices (except maybe expansion factor). Other parameters like learning rate, l1 coefficient, number of tokens to train on all likely need to be tuned in practice. It also seems plausible we’ll continue to see methodological advances in the future which I’m excited about!

What do we think about when choosing hyper-parameters / evaluating SAEs?

We train against:

  • A reconstruction objective (MSE)
  • A sparsity inducing term (L1 norm on the activations).
  • A ghost gradient term (used to resuscitate dead features). 
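A minimal sketch of the first two terms for a single activation vector, with toy numbers. The ghost gradient term is omitted, and the exact reduction or normalization used in the library may differ from the plain mean used here, so treat this as an illustration of the structure of the objective only.

```python
def sae_loss(x, x_hat, feature_acts, l1_coefficient=8e-5):
    """Reconstruction (MSE) plus L1 sparsity penalty for one activation vector."""
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    l1 = sum(abs(f) for f in feature_acts)
    return mse + l1_coefficient * l1, mse, l1

x = [1.0, -2.0, 0.5]                  # original activation (toy values)
x_hat = [0.9, -1.8, 0.5]              # SAE reconstruction (toy values)
feature_acts = [0.0, 3.0, 0.0, 1.5]   # post-ReLU feature activations (toy values)

total, mse, l1 = sae_loss(x, x_hat, feature_acts)
print(total, mse, l1)
```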

However, what we actually care about is whether we reconstructed information required for model performance and how useful the features are for interpretability. Better proxies for these desiderata are:

  • The cross entropy loss (how well the model performs if we replace its activations with the output of the autoencoder).
  • The L0 (number of features which fire when we reconstruct any given activation).

Though L0 is a pretty good proxy for interpretability, in practice the feature density histogram (the distribution of how frequent features are) turns out to be one of the most important things we need to get right when tuning hyperparameters. 

How good are these Sparse AutoEncoders? 

At a glance, the summary metrics seem fairly good. I’ll make a number of comments:

  • It's unclear how many features we should expect to see per token (what the L0 should be), but clearly higher L0’s enable us to achieve better reconstruction. Furthermore, it makes sense that larger models and later layers have more features firing and our spot checks indicated that higher L0 SAEs still seem interpretable and anecdotally have been useful for circuit analysis. Therefore, we tolerate much higher L0’s in practice than we had previously considered ideal. 
  • Our reconstruction scores were pretty good. We found GPT2 small achieves a cross entropy loss of about 3.3, and with reconstructed activations in place of the original activation, the CE Log Loss stays below 3.6. 
  • Most obviously, the proportion of dead features is just way lower than you’d be able to get without ghost gradients.
  • I was very happy with the feature density histograms (see below) which showed few dead features for most layers and very few dense features.

Georg Lang pointed out to me that the L2 loss grows quadratically with the activation norm (which increases with layer), whilst the L1 penalty grows only linearly. This means that since I didn’t vary the L1 coefficient when training these, we’re effectively pushing less hard for sparsity in later layers (which would explain the trend in L0 / L1 and the feature density histograms). Interestingly, the variance explained still gets worse with layers.
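One way to make this argument concrete, under the simplifying assumption that activations at a later layer behave like a scaled copy $kx$ of earlier ones and that the SAE's reconstruction $\hat{x}$ and feature activations $f$ scale by the same factor $k$ (the symbols $k$, $f$ and $\lambda$ are introduced here for illustration, with $\lambda$ the L1 coefficient):

```latex
\mathcal{L}(x) = \lVert x - \hat{x} \rVert_2^2 + \lambda \lVert f \rVert_1
\qquad\Longrightarrow\qquad
\mathcal{L}(kx) = k^2 \lVert x - \hat{x} \rVert_2^2 + \lambda k \lVert f \rVert_1
= k^2 \left( \lVert x - \hat{x} \rVert_2^2 + \frac{\lambda}{k} \lVert f \rVert_1 \right)
```

Up to the overall factor $k^2$, training on the scaled activations looks like training on the original ones with an effective L1 coefficient of $\lambda / k$, which shrinks as the norm grows.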

| Layer | Variance Explained | L1 loss | L0* | % Alive Features | Reconstruction CE Log Loss |
|---|---|---|---|---|---|
| 0 | 99.15% | 4.58 | 12.24 | 80.0% | 3.32 |
| 1 | 98.37% | 41.04 | 14.68 | 83.4% | 3.33 |
| 2 | 98.07% | 51.88 | 18.80 | 80.0% | 3.37 |
| 3 | 96.97% | 74.96 | 25.75 | 86.3% | 3.48 |
| 4 | 95.77% | 90.23 | 33.14 | 97.7% | 3.44 |
| 5 | 94.90% | 108.59 | 43.61 | 99.7% | 3.45 |
| 6 | 93.90% | 136.07 | 49.68 | 100% | 3.44 |
| 7 | 93.08% | 138.05 | 57.29 | 100% | 3.45 |
| 8 | 92.57% | 167.35 | 65.47 | 100% | 3.45 |
| 9 | 92.05% | 198.42 | 71.10 | 100% | 3.45 |
| 10 | 91.12% | 215.11 | 53.79 | 100% | 3.52 |
| 11 | 93.30% | 270.13 | 59.16 | 100% | 3.57 |
| Original Model | | | | | 3.3 |

Summary Statistics for GPT2 Small Residual Stream SAEs. *L0 = Average number of features firing per token.

 

Log Feature Sparsity Histograms for each of the residual stream SAEs of GPT2-small

How interpretable are the features in each layer?

Feature interpretability is far from a settled science, but feature dashboards sure do automate a huge chunk of the work. We use a reproduction of Anthropic’s dashboard developed by Callum McDougall. 

We’re still working on making our dashboard generating code more efficient (to keep up with the improvements in our ability to train sparse autoencoders!). In the meantime, we’ve collected some anecdotal examples of features at layers 2, 5 and 10 to give examples of the kinds of features we can detect in GPT2 small.
Though we share some features below, you can look through more features at the bottom of the dashboard here.

Layer 2: The President Feature

For example, below we show a “President” feature which promotes the first names of presidents. 

 

Layer 2: The “c” subword token feature

Another example of a fairly typical layer 2 feature is this feature which fires on “c” due to tokenization which splits a word that starts with c. 

Layer 5: The what you are saying thanks OR sorry for feature

For example, this feature appears to fire for short stretches of text involving thanks or apologies.

Layer 5: The Force is strong with this Feature. 

Though there are plenty of features that seem interesting about layer 5 SAE, some are just way stronger in the force than others. 

Layer 10: Violence / Conflict Feature.

How to get involved

I want to look at more dashboards!

You can download all 25k dashboards for layers 2 and 10 and the first 5k for layer 5 here.

How can I download and analyze these SAEs?

For those who would like to play around with these sparse autoencoders, my codebase is pretty crazy right now but you can mostly ignore it once you have the SAE. The codebase has:

  •  A README explaining:
    • The interface for training SAE’s
    • The interface for generating dashboards. 
  • A tutorial showing how to use the SAE with TransformerLens 

In order to speed up my own analysis, I created a “SessionLoader” class which takes a path to the saved SAE and then instantiates the model it was trained on, the sparse autoencoder and the activations_loader (which gets your tokens/activations). With these three artifacts, you can start analyzing an SAE very quickly post training.

After cloning the repo and installing the requirements.txt, users can simply run the following commands:

from sae_training.utils import LMSparseAutoencoderSessionloader

path = "path/to/sparse_autoencoder.pt"
model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)

What kinds of analysis can we do with Residual Stream Sparse Autoencoders?

Without getting into entire research directions, it’s worth discussing briefly the kinds of experiments that can be run with Sparse Autoencoders. These are projects that might enable you to get a taste of working with SAEs and decide if you’re excited to work with them more seriously. 

Some example upskilling projects could be:

  • Studying Co-Occurrence: Collect feature activations on a bunch of data and then look for co-occurring features. Co-occurrence of features is disincentivized by the L1 penalty but might be retained if it’s useful for the reconstruction objective. Understanding which features co-occur and why within the same SAE / layer of GPT2 or between layers could be quite an interesting mini-project. Plotting the joint distribution of features for highly co-occurring features can be interesting too.
  • Identifying Redundancy between SAEs trained on nearby layers. If we have 25k residual stream features in 12 residual stream SAE’s, then that’s a lot of features! Having automatic ways to identify redundant features which track similar or the same variables between layers could be an interesting task to solve. 
  • Red-Teaming Feature Dashboards: Feature dashboards are super useful, but rely on max-activating examples which can be very misleading. Finding examples of where feature dashboards are misleading and finding ways to improve them or automatic ways to detect these issues could be quite valuable. For some inspiration, you could think about how Simpson’s Paradox or Berkson’s paradox might be relevant.
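As a starting point for the co-occurrence project above, here is a minimal sketch that scores feature pairs by the Jaccard similarity of the token sets they fire on. The choice of Jaccard similarity and the toy data are assumptions; other measures (for example correlation of activation magnitudes) would be equally reasonable.

```python
from itertools import combinations

def cooccurrence_jaccard(fired_per_token, n_features):
    """Jaccard similarity between the sets of tokens on which each feature pair fires."""
    tokens_by_feature = [set() for _ in range(n_features)]
    for token_idx, fired in enumerate(fired_per_token):
        for feature_idx in fired:
            tokens_by_feature[feature_idx].add(token_idx)
    scores = {}
    for i, j in combinations(range(n_features), 2):
        union = tokens_by_feature[i] | tokens_by_feature[j]
        if union:  # skip pairs where neither feature ever fired
            scores[(i, j)] = len(tokens_by_feature[i] & tokens_by_feature[j]) / len(union)
    return scores

# Hypothetical data: indices of the features that fired on each of 5 tokens.
fired_per_token = [{0, 1}, {0, 1}, {2}, {0, 1, 2}, {2}]
scores = cooccurrence_jaccard(fired_per_token, n_features=3)
print(max(scores, key=scores.get))  # the most strongly co-occurring pair
```

The all-pairs loop is quadratic in the number of features, so for a 25k-feature dictionary you would likely restrict it to features above some firing frequency or use a sparse matrix product instead.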

For more ideas, posts published by researchers currently working on SAEs can be a great source of inspiration:

What research directions could you pursue with SAEs? 

It’s likely that there is a bunch of low hanging fruit with SAEs right now. Logan Riggs posted a bunch of ideas here and Anthropic list some directions for future work which are worth reading here.

Two directions I’m excited by are:

  • Red-teaming / Improving SAEs: Understanding why models don’t perform better when we replace activations with SAE output seems like an important target. If you were sequencing DNA, but your sequencer wasn’t reliable enough for you to put the DNA back in the bacteria, then you wouldn’t be sure your sequencer was working. There may be issues with the way we’re currently training SAEs or more fundamental issues with our conceptualisations. Maybe some fraction of features really are dense or aren’t well described by a sparse overcomplete basis. How would we know? 
  • Understanding Higher Level Structures composed of Features: Previously, we’ve been talking about features as being connected by circuits, but there might be other lenses through which to understand the ways features relate to each other. For example, are there sets of mutually exclusive features which represent similar concepts (like colors)?

Appendix

Thanks

I’d like to thank Neel Nanda and Arthur Conmy for their support and feedback while I’ve been working on this and other SAE related work. I also appreciate feedback and support from members of the Mechanistic Interpretability Stream in the MATS 5.0 cohort, especially Ben Wu and Andy Arditi. 

I’d also like to thank the interpretability team at Anthropic for continually sharing their advice on how to train sparse autoencoders, and to Callum McDougall for his awesome SAE visualizer (replication of Anthropic’s dashboard).

Funding Note

This work was produced as part of the ML Alignment & Theory Scholars Program - Winter 2023-24 Cohort, with support from Neel Nanda and Arthur Conmy. Funding for this work was provided by the Manifund Regranting Program and donors as well as LightSpeed Grants. 

How to Cite

@misc{bloom2024gpt2residualsaes,
   title = {Open Source Sparse Autoencoders for all Residual Stream Layers of GPT2 Small},
   author = {Joseph Bloom},
   year = {2024},
   howpublished = {\url{https://www.alignmentforum.org/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream}},
}
Comments

Hey Joseph (and coauthors),

Your directions are really fantastic. I hope you don't mind, but I generated the activation data for the first 3000+ directions for each of the 12 layers and uploaded your directions to Neuronpedia:

https://www.neuronpedia.org/gpt2-small/res-jb 

Your directions are also linked on the home page and the model page.

They're also accessible by layer (sorted by top activation), eg layer 6: https://neuronpedia.org/gpt2-small/6-res-jb

I added the "Anthropic dashboard" to Neuronpedia for your dataset.

Explanations, comments, and autointerp scoring are also working - anyone can do this:

  • Click a direction and submit explanation on the top-left. Here's another Star Wars direction (5-RES-JB:1681) where GPT4 gave me a score of 96:
    • Click the score for the scoring details:

I plan to do some autointerp explaining on a batch of these directions too.

Btw - your directions are so good that it's easy to find super interesting stuff. 5-RES-JB:5 is about astronomy:

I'm aware that you're going to do some library updates to get even better directions, and I'm excited for that - will re-generate/upload all layers after the new changes come in.

Things that I'm still working on and hope to get working in the next few days:

  • Making activation testing work for each neuron
  • "Search / test" the same way that we have search/test for OpenAI's directions

Again, your directions look fantastic - congrats. I hope this is useful/interesting for you and anyone trying to browse/explain them. Also, I didn't know how to provide a citation/reference to you (and your team?) so I just used RES-JB = Residuals by Joseph Bloom and included links to all relevant sources on your directions page.

If there's anything you'd like me to modify about this, or any feature you'd like me to add to make it better, please do not hesitate to let me know.

I tried replicating your statistics using my own evaluation code (in evaluation.py here). I pseudo-randomly chose layer 1 and layer 7. Sadly, my results look rather different from yours:

| Layer | MSE Loss | % Variance Explained | L1 | L0 | % Alive | CE Reconstructed |
|---|---|---|---|---|---|---|
| 1 | 0.11 | 92 | 44 | 17.5 | 54 | 5.95 |
| 7 | 1.1 | 82 | 137 | 65.4 | 95 | 4.29 |

Places where our metrics agree: L1 and L0.

Places where our metrics disagree, but probably for a relatively benign reason:

  • Percent variance explained: my numbers are slightly lower than yours, and from a brief skim of your code I think it's because you're calculating variance slightly incorrectly: you're not subtracting off the activation's mean before doing .pow(2).sum(-1). This will slightly overestimate the variance of the original activations, so probably also overestimate percent variance explained.
  • Percent alive: my numbers are slightly lower than yours, and this is probably because I determined whether neurons are alive on a (somewhat small) batch of 8192 tokens. So my number is probably an underestimate and yours is correct.
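To illustrate the variance-explained point above with toy numbers: if the activations share a large offset, skipping the mean subtraction inflates the total sum of squares and so inflates the fraction explained. The function below is a sketch of the general idea and is not taken from either codebase.

```python
def variance_explained(xs, x_hats, centre=True):
    """1 - (residual sum of squares) / (total sum of squares) over a batch of vectors."""
    dim = len(xs[0])
    if centre:
        mean = [sum(x[d] for x in xs) / len(xs) for d in range(dim)]
    else:
        mean = [0.0] * dim  # no mean subtraction
    total = sum((x[d] - mean[d]) ** 2 for x in xs for d in range(dim))
    residual = sum((x[d] - xh[d]) ** 2 for x, xh in zip(xs, x_hats) for d in range(dim))
    return 1.0 - residual / total

# Toy activations with a large shared offset in the first coordinate.
xs = [[10.0, 1.0], [11.0, -1.0], [9.0, 0.0]]
x_hats = [[10.5, 0.5], [10.5, -0.5], [9.5, 0.0]]

centred = variance_explained(xs, x_hats, centre=True)
uncentred = variance_explained(xs, x_hats, centre=False)
print(centred, uncentred)  # the uncentred figure is much higher
```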

Our metrics disagree strongly on CE reconstructed, and this is a bit alarming. It means that either you have a bug which significantly underestimates reconstructed CE loss, or I have a bug which significantly overestimates it. I think I'm 50/50 on which it is. Note that according to my stats, your MSE loss is kinda bad, which would suggest that you should also have high CE reconstructed (especially when working with residual stream dictionaries! (in contrast to e.g. MLP dictionaries which are much more forgiving)).

Spitballing a possible cause: when computing CE loss, did you exclude padding tokens? If not, then it's possible that many of the tokens on which you're computing CE are padding tokens, which is artificially making your CE look extremely good.

Here is my code. You'll need to pip install nnsight before running it. Many thanks to Caden Juang for implementing the UnifiedTransformer functionality in nnsight, which is a crazy Frankenstein marriage of nnsight and transformer_lens; it would have been very hard for me to attempt this replication without this feature.

Why do you scale your MSE by 1/(x_centred**2).sum(dim=-1, keepdim=True).sqrt() ? In particular, I'm confused about why you have the square root. Shouldn't it just be 1/(x_centred**2).sum(dim=-1, keepdim=True)?

For your dashboards, how many tokens are you retrieving the top examples from?