I actually do have some dictionaries publicly hosted, trained only on the residual stream, along with some simple training code.
I want to integrate some basic visualizations (and include Anthropic's tricks) before making a public post on it, but currently:
Dict on pythia-70m-deduped
Dict on Pythia-410m-deduped
Which can be downloaded & interpreted with this notebook
With easy training code for bespoke models here.
I've had trouble figuring out a weight-based approach due to the non-linearity and would appreciate your thoughts actually.
We can learn a dictionary of features at the residual stream (R_d) & another mid-MLP (MLP_d), but you can't straightforwardly multiply the features from R_d with W_in and find the matching features in MLP_d, due to the nonlinearity, AFAIK.
I do think you could find Residual features that are sufficient to activate the MLP features, but not all linear combinations from just the weights.
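To make the nonlinearity point concrete, here is a minimal toy sketch. The 2x2 `W_IN` and the two "features" are made-up stand-ins (not real Pythia weights or learned dictionary directions), and I'm assuming a GELU activation. It just shows that pushing two residual directions through W_in + GELU separately and summing is not the same as pushing their sum through.

```python
import math

def gelu(x: float) -> float:
    """Exact GELU via the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def mlp_post(w_in, resid):
    """Apply a toy W_in (list of neuron rows) and then GELU to a residual vector."""
    return [gelu(sum(w * r for w, r in zip(row, resid))) for row in w_in]

# Hypothetical toy weights: 2 MLP neurons reading from a 2-d residual stream.
W_IN = [[1.0, -1.0],
        [0.5, 0.5]]

feat_a = [1.0, 0.0]  # stand-in for one residual dictionary feature
feat_b = [0.0, 1.0]  # stand-in for another

# MLP activations when both features are present at once ...
joint = mlp_post(W_IN, [a + b for a, b in zip(feat_a, feat_b)])
# ... versus the sum of the activations each feature produces alone.
separate = [x + y for x, y in zip(mlp_post(W_IN, feat_a), mlp_post(W_IN, feat_b))]

# Largest per-neuron disagreement between the two; nonzero because GELU is not linear.
gap = max(abs(j - s) for j, s in zip(joint, separate))
print(joint, separate, gap)
```

In this toy, neuron 0 sees the two features cancel before the GELU, so the joint activation is zero even though each feature alone moves it. That is the kind of interaction a weights-only reading would miss.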
Using a dataset-based method, you could find causal features in practice (the ACDC portion of the paper was a first attempt at that), and I would be interested in an activation*gradient method here (though I'm largely ignorant).
Specifically, I think you should scale the residual stream activations by their in-distribution max-activating examples.
In the ITI paper, they track performance on TruthfulQA w/ human labelers, but mention that other works use an LLM as a noisy signal of truthfulness & informativeness. You might be able to use this as a quick, noisy signal for comparing different layers and magnitudes of the direction to add in.
Preferably, a human annotator labels model answers as true or false given the gold standard answer. Since human annotation is expensive, Lin et al. (2021) propose to use two finetuned GPT-3-13B models (GPT-judge) to classify each answer as true or false and informative or not. Evaluation using GPT-judge is standard practice on TruthfulQA (Nakano et al. (2021); Rae et al. (2021); Askell et al. (2021)). Without knowing which model generates the answers, we do human evaluation on answers from LLaMA-7B both with and without ITI and find that truthfulness is slightly overestimated by GPT-judge and opposite for informativeness. We do not observe GPT-judge favoring any methods, because ITI does not change the style of the generated texts drastically
[word] and [word] can be thought of as "the previous token is ' and'."
I think it's mostly this, but looking at the ablated text, removing the word before ' and' does have a significant effect some of the time. I'm less confident on the specifics of why the previous word matters or in what contexts.
Maybe the reason you found ' and' first is because ' and' is an especially frequent word. If you train on the normal document distribution, you'll find the most frequent features first.
This is a dataset-based method, so I do believe we'd find the features most frequently present in that dataset, plus the ones most important for reconstruction. An example of the latter: the highest MCS feature across many layers & model sizes is the "beginning & end of first sentence" feature, which appears to line up w/ the emergent outlier dimensions from Tim Dettmers' post here, but I do need to do more work to actually show that.
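For anyone unfamiliar with the MCS numbers: a minimal sketch, assuming MCS here means the max cosine similarity between a feature in one dictionary and every feature in a second dictionary (e.g. one trained at a different size). The function names and the tiny 2-d dictionaries are hypothetical, just for illustration.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def max_cosine_similarity(dict_a, dict_b):
    """For each feature in dict_a, the best cosine match among dict_b's features."""
    return [max(cosine(f, g) for g in dict_b) for f in dict_a]

# Hypothetical toy dictionaries (rows are feature directions).
small_dict = [[1.0, 0.0], [0.0, 1.0]]
large_dict = [[2.0, 0.0], [1.0, 1.0], [-1.0, 0.5]]

mcs = max_cosine_similarity(small_dict, large_dict)
print(mcs)
```

A feature with MCS near 1 was found (up to scale) by both dictionaries; a low MCS suggests it is specific to one training run.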
Setup:
Model: Pythia-70m (actually named 160M!)
Transformer lens: "blocks.2.hook_resid_post" (so layer 2)
Data: Neel Nanda's Pile-10k (slice of pile, restricted to have only 25 tokens, same as last post)
Dictionary_feature sizes: 4x residual stream, i.e. 2k (though I have 1x, 2x, 4x, & 8x, which learned progressively more features according to the MCS metric)
Uniform Examples: separate feature activations into bins & sample from each bin (eg one from [0,1], another from [1,2])
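A minimal sketch of the binning idea, not the actual code I used. The function name, the bin width of 1, and skipping zero activations are all assumptions for illustration.

```python
import random
from collections import defaultdict

def sample_uniform_examples(activations, bin_width=1.0, seed=0):
    """Pick one example index per activation bin.

    Bin k covers [k * bin_width, (k + 1) * bin_width). Zero activations are
    skipped (an assumption: we only want examples where the feature fires).
    """
    rng = random.Random(seed)
    bins = defaultdict(list)
    for idx, act in enumerate(activations):
        if act > 0:
            bins[int(act // bin_width)].append(idx)
    return {b: rng.choice(idxs) for b, idxs in sorted(bins.items())}

# Hypothetical feature activations, one per token position in some dataset.
acts = [0.0, 0.2, 0.9, 1.5, 1.7, 3.2, 0.0, 2.4]
picked = sample_uniform_examples(acts)
print(picked)  # maps bin number -> index of the sampled example
```

Compared to only looking at top-activating examples, this also shows what the feature does at low and medium activations.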
Logit Lens: The decoder here had 2k feature directions. Each direction is size d_model, so you can directly unembed the feature direction (e.g. the German Feature) you're looking at. Additionally, I subtract out several high-norm tokens from the unembed, which may be an artifact of the Pythia tokenizer never using those tokens (thanks Wes for mentioning this!)
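A toy sketch of unembedding a feature direction. The vocab, the 2-d unembed rows, and the "German" direction are made up, and I'm modelling "subtracting out high-norm tokens" as simply dropping tokens whose unembed norm exceeds a threshold, which is an assumption about the details.

```python
import math

def logit_lens(direction, unembed, vocab, top_k=2, max_norm=None):
    """Return the top_k tokens whose unembed vectors best match a feature direction.

    unembed is a list of per-token vectors (vocab_size x d_model). Tokens whose
    unembed norm exceeds max_norm are dropped before ranking (if max_norm is set).
    """
    scored = []
    for token, vec in zip(vocab, unembed):
        norm = math.sqrt(sum(v * v for v in vec))
        if max_norm is not None and norm > max_norm:
            continue  # likely a never-used token with an untrained, large embedding
        logit = sum(d * v for d, v in zip(direction, vec))
        scored.append((logit, token))
    scored.sort(reverse=True)
    return [token for _, token in scored[:top_k]]

# Hypothetical toy vocab and unembed; "<unused>" stands in for a high-norm token.
VOCAB = [" der", " und", " the", "<unused>"]
W_U = [[1.0, 0.2], [0.8, 0.1], [-0.5, 0.9], [9.0, 9.0]]
german_direction = [1.0, 0.0]

raw = logit_lens(german_direction, W_U, VOCAB)
filtered = logit_lens(german_direction, W_U, VOCAB, max_norm=5.0)
print(raw, filtered)
```

Without the filter, the high-norm token dominates the top of the list regardless of which direction you unembed, which is why it's worth removing.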
Ablated Text: Say the default feature (or neuron, in your words) activation at Token_pos 10 is 5; you can then remove all tokens from 0 to 10 one at a time and see the effect on the feature activation. I select the token pos by finding the max feature activating position or the uniform one described above. This at least shows some attention head dependencies, but not more complicated ones like (A or B... C), where removing A or B doesn't affect C but removing both would.
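A minimal sketch of the one-token-at-a-time ablation. `toy_feature` is a hypothetical stand-in for "run the model + dictionary and read the feature at the last token"; in practice that would be a real forward pass. I only ablate tokens before the target position here, since removing the target itself changes which token you're reading from.

```python
def toy_feature(tokens):
    """Hypothetical feature read at the last token: fires when the previous token is ' and'."""
    if len(tokens) < 2 or tokens[-2] != " and":
        return 0.0
    # Small extra contribution from the word before ' and' (made up for this toy).
    bonus = 1.0 if len(tokens) >= 3 and tokens[-3] == " salt" else 0.0
    return 4.0 + bonus

def ablate_each_token(tokens, target_pos, feature_fn):
    """Remove each token before target_pos in turn; report the change in activation."""
    context = tokens[: target_pos + 1]  # text that ends on the selected token
    baseline = feature_fn(context)
    effects = []
    for i in range(target_pos):
        ablated = context[:i] + context[i + 1:]
        effects.append((context[i], feature_fn(ablated) - baseline))
    return baseline, effects

tokens = ["I", " like", " salt", " and", " pepper", "."]
baseline, effects = ablate_each_token(tokens, target_pos=4, feature_fn=toy_feature)
print(baseline, effects)
```

Because only one token is removed per run, this is exactly the setup that misses the (A or B... C) case: each single removal would show no change.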
[Note: in the examples, I switch between showing the full text for context & showing the partial text that ends on the uniformly-selected token]
Actually, any feature that is significantly affected in "Ablated Text" is not just reading off the embedding. Ablated Text here means I remove each token in the context & see the effect on the feature activation for the last token. This is true for the StackExchange & Last Name ones (though only ~50% of the activation for last-name; it will still recognize last names by themselves, just not activate as much).
The Beginning & End of First Sentence feature actually doesn't show this effect (I think that's because removing the first word just makes the 2nd word the new first word?), but I haven't rigorously studied this.
We have our replication here for anyone interested!
How likely do you think bilinear layers & dictionary learning will lead to comprehensive interpretability?
Are there other specific areas you're excited about?
Why is loss stickiness deprecated? Were you just not able to see an overlap in basins for L1 & reconstruction loss when you 4x the feature/neuron ratio (i.e. from 2x->8x)?
As (maybe) mentioned in the slides, this method may not be computationally feasible for SOTA models, but I'm interested in the ordering of features turned monosemantic; if the most important features are turned monosemantic first, then you might not need full monosemanticity.
I initially expect the "most important & frequent" features to become monosemantic first based off the superposition paper. AFAIK, this method only captures the most frequent because "importance" would be w/ respect to CE-loss in the model output, not captured in reconstruction/L1 loss.
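To spell out why "importance" isn't captured: here is a toy sketch of a reconstruction + L1 objective, assuming a ReLU encoder with tied encoder/decoder weights (the exact architecture is an assumption, as are the names and numbers). Note that nothing in it references the model's output or CE loss.

```python
def sae_loss(x, dictionary, bias, l1_coeff):
    """Reconstruction + L1 loss for one activation vector x.

    dictionary: list of feature directions (n_features x d_model), used both to
    encode and to decode (tied weights, an assumption for this sketch).
    Returns (total_loss, reconstruction_loss, l1_penalty).
    """
    # Encode: ReLU(W x + b) gives one non-negative coefficient per feature.
    feats = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
             for row, b in zip(dictionary, bias)]
    # Decode: x_hat = W^T f, a weighted sum of feature directions.
    x_hat = [sum(f * row[j] for f, row in zip(feats, dictionary))
             for j in range(len(x))]
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    sparsity = sum(abs(f) for f in feats)
    return recon + l1_coeff * sparsity, recon, sparsity

# Hypothetical toy example: 3 features over a 2-d activation.
x = [1.0, 0.0]
dictionary = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
bias = [0.0, 0.0, 0.0]

total, recon, sparsity = sae_loss(x, dictionary, bias, l1_coeff=0.1)
print(total, recon, sparsity)
```

Both terms only see the activation `x`, so a feature gets learned for being frequent or high-variance in the activations, not for mattering to the next-token prediction.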