Interpretability in machine learning, especially in language models, is an area with a large number of contributions. While this can be quite useful for improving our understanding of models, one issue is that there is a lack of robust benchmarks for evaluating the efficacy of different interpretability techniques. This makes it difficult to compare techniques and determine their true effectiveness in real-world scenarios.
Interestingly, there exists a parallel in the realm of non-language models under the research umbrella of Machine Unlearning. In this field, the objective is twofold: firstly, to deliberately diminish the model's performance on specified "unlearned" tasks, and secondly, to ensure that the model's proficiency is maintained or even enhanced on certain "retained" tasks. The inherent challenge here is achieving a balance between these seemingly opposing goals, and the field has thus developed a range of metrics for measuring the effectiveness of its techniques.
Drawing inspiration from Machine Unlearning, I believe that the metrics developed in this space could potentially serve as a litmus test for interpretability techniques in language models. By applying interpretability techniques as unlearning strategies, we can better test the effectiveness of interpretability methods, essentially setting benchmarks for how well these techniques can steer language models in desired directions.
If we aspire to have truly interpretable models, we must not only develop sophisticated techniques, but also robust benchmarks against which these techniques can be validated. Machine Unlearning might just offer the rigorous testing ground we need.
The rest of this post will: 1) Give a brief overview of Machine Unlearning, 2) Give a brief list of Machine Unlearning metrics, and how they may be applicable, 3) Give a deeper dive on each of the metrics, 4) Discuss how these fit in with existing metrics in Interpretability.
Many papers in the subfield of Machine Unlearning are motivated by privacy preservation, and pose the question: "If we trained on someone's information that is now retracted, how can we remove that information without needing to retrain the whole model?"
There are multiple ways you might achieve unlearning. The "ideal/standard" is often to train the model again, but without the data you don't want it to learn. Two of the main ideals for an unlearned model are:
Typically for Machine Unlearning, people want the first ideal. It may seem non-obvious that we should care about this distinction, but people do care, as you don't want to "Goodhart" the unlearning process. If the model behaves in the second way, and this differs from the first, you may instead be adding a circuit that identifies your unlearned training set and just adds randomness.
For interpretability, it might be less concerning to differentiate between these ideals unless gradient-based techniques that explicitly optimize for machine unlearning are employed. One main thing to keep in mind is that if you train on A and not B, then the model might still learn some things that are useful for making predictions about B.
In some neural network architectures, unlearning may be more or less difficult, and knowledge may be more or less entangled. Unlearning one piece of information might inadvertently affect the retention of other unrelated information. It would be ideal if we could measure the degree to which this is the case, and avoid making systems where one cannot disentangle various pieces of knowledge.
Here is some terminology often used in the machine unlearning literature (note that there can be some minor differences in use):
Some of the main metrics used for evaluation are described in this Survey of Machine Unlearning. In brackets I have added a comment on my evaluation for how useful this is in practice for interpretability/related techniques on language models.
Note that many of the techniques here involve retraining a model exclusively on the retained tasks. For large language models, this will likely be too expensive for most people.
How good is the model at making predictions? It should stay equal on the "retained" dataset, but get worse at the "unlearned" and "test" datasets. Note that this section could likely be expanded on much further.
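As a rough illustration of what measuring this might look like for next-token prediction, here is a minimal sketch of two common choices, top-k accuracy and perplexity, computed from raw logits. The function names and the array layout (one row of logits per token position) are my own assumptions for illustration, not taken from any particular paper.

```python
import numpy as np

def top_k_accuracy(logits: np.ndarray, targets: np.ndarray, k: int = 1) -> float:
    """Fraction of positions whose true token is among the k highest logits.

    logits:  shape (n_positions, vocab_size)
    targets: shape (n_positions,), integer token ids
    """
    # Indices of the k largest logits per position (order within the k is irrelevant).
    top_k = np.argpartition(logits, -k, axis=-1)[:, -k:]
    return float(np.mean((top_k == targets[:, None]).any(axis=-1)))

def perplexity(logits: np.ndarray, targets: np.ndarray) -> float:
    """exp(mean negative log-likelihood) of the true tokens."""
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]
    return float(np.exp(nll.mean()))
```

One would compute these on both the retained and the unlearned datasets, before and after the intervention, and report the change in each.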
There are a lot of other "accuracy" metrics one could use, or more task-specific metrics. For example, one could use
One can look at this paper I have written to get an example of some of the metrics I have tried for assessing drops in accuracy. These are somewhat dependent on the specific metric, but in particular we use the metrics:
There are, however, many metrics one could use, which makes it difficult to coordinate on which metrics to evaluate your technique on. In addition, some accuracy benchmarks are more qualitative than direct next-token prediction (eg: "write an answer").
One should also consider that there are other ways to measure behaviour that may not be accurately described by the word "accuracy". This could include things such as "toxicity" and "bias", or "refusing harmful requests" and "conforming to instructions". While some papers do try to look at these, there is a wide variety of ways of modelling model behaviour and performance that are not particularly well described in most Machine Unlearning literature, and that would likely be useful to understand for a broader search into interpretability metrics.
Evaluation: You should probably be including this anyway
How long does your technique take? How does this compare to training the original model? You should be collecting this information anyway, so you should probably include it in your report.
How well do you remove the unlearned task from the model? Does the model still possess most of the machinery required to do the task, and you just removed a tiny piece that is inconsequential in the grand scheme of things? Here are a couple of metrics that try to measure this:
Evaluation: Seems OK. Can be expensive
How long does it take to relearn the unlearned skill? Depending on what you are doing (eg: removing a very small amount of knowledge for a specific fact, or removing a large variety of general capabilities), this may or may not be feasible.
If you are making relatively small changes to your language model, I suspect it should be relatively inexpensive to do a Quantized Low-Rank Adapter (QLoRA) finetuning of your model. If so, it would be valuable to see how long this takes. Otherwise, if this is not possible, or you cannot afford to do such experiments, then that seems OK.
Ideally, you would be able to compare this to a model that has been retrained, though retraining a model without the unlearned task is usually prohibitively expensive.
Evaluation: too expensive (requires retraining)
Compare the "relearn time" (rt), i.e. the time needed to get back to within α of the original model's (Morig) performance on the forgotten task, for the unlearned model (Mu) and for the retrained model (Ms). The ratio of these two relearn times is the Anamnesis Index (AIN).
Ideally AIN should be close to 1. If relearning takes longer on the unlearned model, then you likely have Goodhart-ed the unlearning task.
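To make the comparison concrete, here is a minimal sketch. It assumes AIN is the ratio rt(Mu, Morig, α) / rt(Ms, Morig, α), which is my reading of the survey, and that "relearn time" is counted in fine-tuning steps, with a score on the forgotten task recorded after every step. The function names and the step-based notion of time are illustrative assumptions.

```python
from typing import Optional, Sequence

def relearn_time(scores_per_step: Sequence[float],
                 original_score: float,
                 alpha: float) -> Optional[int]:
    """Number of fine-tuning steps until the score on the forgotten task is
    within alpha of the original model's score. scores_per_step[i] is the
    score measured after step i + 1. Returns None if never reached."""
    threshold = original_score - alpha
    for i, score in enumerate(scores_per_step):
        if score >= threshold:
            return i + 1
    return None

def anamnesis_index(rt_unlearned: Optional[int],
                    rt_retrained: Optional[int]) -> float:
    """Assumed definition: relearn time of Mu divided by relearn time of Ms."""
    if rt_unlearned is None or rt_retrained is None:
        raise ValueError("one of the models never recovered within alpha")
    return rt_unlearned / rt_retrained
```

For example, if the unlearned model recovers after 3 steps and the retrained model after 4, the index is 0.75, suggesting that the unlearned model still holds some of the removed information.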
This metric doesn't seem particularly useful for interpretability, and is also quite expensive to run.
Evaluation: too expensive (involves retrained model)
Check if the model fully forgets removed data. Is the model after unlearning like a new model trained without the forgotten data?
Calculate the overlap (using Jaccard distance) between the outputs of the unlearned and retrained models. This checks that no traces of the forgotten data impact the model's predictions.
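A minimal sketch of such an overlap computation is below. How one turns model outputs into sets is a design choice; here I assume, purely for illustration, that each input yields a set of predicted ids (e.g. the top-k tokens) from each model.

```python
from typing import Sequence, Set

def jaccard_similarity(a: Set[int], b: Set[int]) -> float:
    """|a ∩ b| / |a ∪ b|, defined as 1.0 when both sets are empty."""
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)

def mean_jaccard_distance(outputs_unlearned: Sequence[Set[int]],
                          outputs_retrained: Sequence[Set[int]]) -> float:
    """Average Jaccard distance (1 - similarity) over paired per-input output sets.
    0.0 means the two models' outputs overlap completely."""
    if len(outputs_unlearned) != len(outputs_retrained):
        raise ValueError("both models must be evaluated on the same inputs")
    distances = [1.0 - jaccard_similarity(u, r)
                 for u, r in zip(outputs_unlearned, outputs_retrained)]
    return sum(distances) / len(distances)
```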
How much does the unlearning affect parts of the model? How affected is the model on retained tasks? on the unlearned tasks? Here are some metrics that people try to use sometimes:
Evaluation: seems not super useful, but cheap, so maybe worth including?
This is a relatively simple metric: how different are the weights of the original model compared to those of the unlearned model? The retrained model? A randomly initialised model?
I somewhat doubt the practical value of this for interpretability, and don't really understand the point of this metric. I guess if the difference between the original model and the unlearned model is larger than the difference between the original model and the retrained model, I would be somewhat suspicious of the unlearning method.
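For reference, a minimal sketch of one way to compute this, assuming both models are available as dictionaries mapping parameter names to arrays of identical shapes (the dictionary layout and function name are illustrative assumptions):

```python
import numpy as np
from typing import Dict

def weight_distance(params_a: Dict[str, np.ndarray],
                    params_b: Dict[str, np.ndarray]) -> Dict[str, object]:
    """L2 distance between two sets of parameters, per tensor and in total."""
    per_tensor = {
        name: float(np.linalg.norm(params_a[name] - params_b[name]))
        for name in params_a
    }
    # Total distance treats all parameters as one long flattened vector.
    total = float(np.sqrt(sum(d ** 2 for d in per_tensor.values())))
    return {"per_tensor": per_tensor, "total": total}
```

The per-tensor breakdown may be more informative than the total, since it shows which layers an intervention actually changed.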
Evaluation: Seems possibly good.
Originally for this metric, you would get the average L2-distance between the unlearned model and retrained model’s predicted probabilities on the forget set to try to evaluate "indistinguishability". In this case, using a retrained model is too expensive.
However, I think one could build a variation of this metric that compares:
Then one could try to see how much difference there is between these different activations. See also section on ZRF score.
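The core computation is simple whichever pair of models (or layers) one compares. Here is a minimal sketch, assuming the two sets of activations or predicted probabilities have been collected on the same inputs and stacked into arrays of equal shape:

```python
import numpy as np

def activation_distance(acts_a: np.ndarray, acts_b: np.ndarray) -> float:
    """Average L2 distance between paired rows of two activation arrays.

    acts_a, acts_b: shape (n_inputs, dim), e.g. predicted probabilities
    on the forget set from two different models.
    """
    if acts_a.shape != acts_b.shape:
        raise ValueError("activations must be collected on the same inputs")
    return float(np.mean(np.linalg.norm(acts_a - acts_b, axis=-1)))
```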
Evaluation: seems good? unsure
Similar to Activation distance, but instead of L2-Distance, you get the Jensen-Shannon Divergence. Same arguments as above.
Evaluation: seems too expensive? unsure
Measures how much information about a dataset the model has learned. Expensive to compute. My understanding of the method for computation:
Step 1: Compute Fisher Information Matrix (FIM):
Step 2: Compute Influence Function: i(w;D)=tr(I(w;D))
Step 3: Compute Efficacy:
My understanding is that the efficacy measures how much the model has already learned about the data. If you were to measure it for base model vs unlearned model on retained vs unlearned tasks, then you could have a baseline for comparison.
If one has to follow the above method, it seems prohibitively expensive for large models, though there may be ways to get approximately the same information with a less expensive method.
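To give a feel for the computation, here is a toy sketch on a linear softmax classifier, where per-example gradients have a closed form. Several things here are assumptions on my part: it uses the empirical Fisher (gradients evaluated at the observed labels), and it takes efficacy to be 1 / i(w;D) when i(w;D) > 0 and infinity otherwise, which is how I recall the survey defining it. For a large language model, one would need per-example gradients from autodiff instead, which is where the cost comes from.

```python
import numpy as np

def empirical_fisher_trace(W: np.ndarray, X: np.ndarray, y: np.ndarray) -> float:
    """Trace of the empirical Fisher for a linear softmax model, logits = X @ W.

    The trace equals the mean squared norm of the per-example gradient of
    log p(y | x; W). For this model that gradient is outer(x, onehot(y) - p),
    whose squared norm is ||x||^2 * ||onehot(y) - p||^2.
    """
    logits = X @ W
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=1, keepdims=True)
    residual = np.eye(W.shape[1])[y] - probs
    sq_grad_norms = (X ** 2).sum(axis=1) * (residual ** 2).sum(axis=1)
    return float(sq_grad_norms.mean())

def efficacy(fisher_trace: float) -> float:
    """Assumed definition: 1 / i(w;D) if i(w;D) > 0, else infinity."""
    return 1.0 / fisher_trace if fisher_trace > 0 else float("inf")
```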
Evaluation: seems good?
If we use a gradient-based machine unlearning method, we don't want to explicitly train the model to give the opposite answer, or to give a strangely uniform output prediction. This metric kinda checks for this. We get outputs for the unlearned model, and a randomly initialised model, and calculate the Jensen-Shannon divergence between the two, and calculate:
Then we can evaluate:
If the ZRF score is close to 1, that is good. One caveat is that in some cases (e.g. when you explicitly train to mimic a randomly initialised model), being too close to 1 could be a sign of Goodhart-ing the unlearning criteria (since models trained on task A, but not on task B, might still have better-than-random performance on task B). Overall, it seems like a useful metric for understanding how much information has been lost relative to the original activations.
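A minimal sketch of the computation follows, under my understanding that the score is one minus the mean Jensen-Shannon divergence between the unlearned model's and the randomly initialised model's output distributions on the forget set. I use base-2 logarithms so that the divergence lies in [0, 1]; the function names and the clipping constant are illustrative assumptions.

```python
import numpy as np

def js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Row-wise Jensen-Shannon divergence (base 2) between distributions p and q."""
    p = np.clip(p, eps, 1.0)  # avoid log(0)
    q = np.clip(q, eps, 1.0)
    m = 0.5 * (p + q)
    kl_pm = np.sum(p * np.log2(p / m), axis=-1)
    kl_qm = np.sum(q * np.log2(q / m), axis=-1)
    return 0.5 * (kl_pm + kl_qm)

def zrf_score(probs_unlearned: np.ndarray, probs_random: np.ndarray) -> float:
    """Assumed definition: 1 - mean JS divergence over the forget set.

    Both arrays have shape (n_forget_samples, n_classes) and hold the output
    probabilities of the unlearned and the randomly initialised model.
    """
    return float(1.0 - np.mean(js_divergence(probs_unlearned, probs_random)))
```

A score of 1 means the unlearned model's outputs on the forget set match those of the randomly initialised model, while a score near 0 means they are maximally different.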
Note that these metrics seem use-case dependent and not super useful in general, as they are particularly interested in the question of data privacy.
Evaluation: unsure, seems use-case dependent.
In general, Membership Inference Attacks ask: “Was this data point part of the training data?” There are too many methods to list here, and they often work under different assumptions. This might be useful for trying to understand tampering in a model, and may be useful for interpretability, but I am unsure how easily this could be converted into a benchmark.
One example given in the context of Machine Unlearning and privacy preservation is: “Given the Original Model and the Unlearned Model, can you infer what was unlearned?”. While interesting, I am unsure how applicable this specific example is for machine unlearning.
Possible use in interpretability: if one was ablating a part responsible for a task, then membership inference techniques could be useful to understand how completely the ablation removes that capability on that task.
Some things to keep in mind:
Evaluation: not really a benchmark, but can be useful
I think the main idea here is to try to reconstruct the input given the output, using the unlearned model. The approach is basically the same as “Feature Visualisation”, and is already often used to better understand models. This could be useful for trying to get qualitative feedback on the approach. The main drawbacks are that it doesn’t apply as well to text-only language models, and that it is not really a quantitative benchmark.
There are many ways of trying to do interpretability, and many ways of assessing how good your interpretation is. I have listed a couple of the main ones here. While each of these can be a good initial metric, I think there is a lot of potential for better evaluating interpretability techniques. Often the metrics can be quite task-specific.
While I think the Machine Unlearning metrics can provide a rich source of information, how applicable they are is highly dependent on the exact technique you are looking at. I would expect more of these metrics to be applicable to something like Sparse AutoEncoder research, and fewer to something like ActAdd. However, I think having a better explicit list of metrics/benchmarks for Interpretability, and implementations for running these benchmarks, would be quite valuable.
One method used in various cases is to directly try to find features that look interpretable, and see how strongly they activate on some input. Some examples include earlier work in “Feature Visualisation”, and later in “Taking Representations out of Superposition using Sparse Auto-Encoders” (Original, Scaled-Up) and linear-probe based techniques such as “Language Models Represent Space and Time” or “Discovering Latent Knowledge”.
However, it is unclear in some of these cases to what extent the component is solely responsible for the behaviour, as it may also be responsible for other tasks, or there may be other components that fulfil the same function. Here is where Machine Unlearning evaluations seem to be the most useful. By intervening on these components, and using a variety of the metrics above, one could better understand the effect of ablating them.
One of the most common metrics is directly looking at logits for a specific immediate next token prediction. This can be done directly by running the model to the end and looking at the logits, or by inferring the direct effect on logits based on changes in a mid-layer (e.g. Logit Lens, or more recently, Tuned Lens). This can be useful, and provide tight feedback loops, but I think that having a larger range of metrics on the effect on accuracy and activations would be useful.
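As a rough sketch of the mid-layer idea, the Logit Lens applies the model's final normalisation and unembedding to an intermediate hidden state. The version below assumes a model with a final LayerNorm and an unembedding matrix `W_U`. All array names are hypothetical, real architectures differ (some use RMSNorm, for instance), and Tuned Lens additionally learns a per-layer transformation that is not shown here.

```python
import numpy as np

def layer_norm(h: np.ndarray, gamma: np.ndarray, beta: np.ndarray,
               eps: float = 1e-5) -> np.ndarray:
    """Standard LayerNorm over the last (hidden) dimension."""
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return gamma * (h - mean) / np.sqrt(var + eps) + beta

def logit_lens(hidden: np.ndarray, ln_gamma: np.ndarray, ln_beta: np.ndarray,
               W_U: np.ndarray) -> np.ndarray:
    """Project mid-layer hidden states straight to vocabulary logits.

    hidden: (n_positions, d_model) hidden states from some intermediate layer
    W_U:    (d_model, vocab_size) unembedding matrix (hypothetical name)
    """
    return layer_norm(hidden, ln_gamma, ln_beta) @ W_U
```

Comparing these logits before and after an intervention at a given layer gives an estimate of that intervention's direct effect on the next-token prediction.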
Another method that is not-quite-interpretability-related is looking at text generations. This can be seen in, for example, the ActAdd paper, where they make generations, and measure word frequencies. I think having more text generation metrics would be quite interesting, and is something I am actively looking into more.
I think there is a lot of room for better metrics in interpretability and model control. Some of these Machine Unlearning metrics seem potentially useful (while some remain too expensive or not particularly relevant).
One metric that I think is somewhat lacking, is how changes might affect what longer-term generations look like. I am working on a possible metric relevant to this here: [Post Coming Soon™], but I think there is potential for other work to be done as well.
Machine unlearning seems to be a possible direct way of evaluating interpretability methods. I am interested in working on an implementation that makes it easier to run all of these different metrics, and would be excited for more work to be done in the direction of evaluating interpretability methods.
Note: If you think there are important metrics I left out, please comment below. I may update the post to include them.
"Survey of Machine Unlearning" / "Awesome Machine Unlearning"
"Dissecting Language Models: Machine Unlearning via Selective Pruning"
"Can Bad Teaching Induce Forgetting? Unlearning in Deep Networks using an Incompetent Teacher"
"Sparse Autoencoders Find Highly Interpretable Directions in Language Models"
"Towards Monosemanticity: Decomposing Language Models With Dictionary Learning"
"Language Models Represent Space and Time"
"Discovering Latent Knowledge in Language Models Without Supervision"
"Interpreting GPT: the logit lens"
"Eliciting Latent Predictions from Transformers with the Tuned Lens"
"ActAdd: Steering Language Models without Optimization"