* Authors sorted alphabetically.
An appendix to this post.
As mentioned above, our method allows us to explain quantitatively measured model behavior operationalized as the expectation of a function f on a distribution D.
Note that no part of our method distinguishes between the part of the input or computational graph that belongs to the “model” vs the “metric.”
It turns out that you can phrase a lot of mechanistic interpretability in this way. For example, here are some results obtained from attempting to explain how a model has low loss:
That being said, you can set up experiments using other metrics besides loss as well:
If you’re trying to explain the expectation of f, we always consider it a valid move to suggest an alternative function f′ if f(x)=f′(x) on every input (“extensional equality”), and then explain f′ instead. In particular, we’ll often start with our model’s computational graph and a simple interpretation, and then perform “algebraic rewrites” on both graphs to naturally specify the correspondence.
Common rewrites include:
Note that there are many trivial or unenlightening algebraic rewrites. For example, you could always replace f′ with a lookup table of f, and in cases where the model performs perfectly, you can also replace f with the constant zero function. Causal scrubbing is not intended to generate mechanistic interpretations or ensure that only mechanistic interpretations are allowed, but instead to check that a given interpretation is faithful. We discuss this more in the limitations section of the main post.
We allow hypotheses at a wide variety of levels of specificity. For example, here are two potential interpretations of the same f:
These interpretations correspond to the same input-output mappings, but the hypothesis on the right is more specific, because it says that there are three separate nodes in the graph expressing this computation instead of one. So when we construct G2 to correspond to I2 we would need three different activations that we claim are important in different ways, instead of just one for G1 mapping to I1. In interpretability, all else equal, we prefer more specific explanations, but defining that is out of scope here: we’re just trying to provide a way of looking at the predictions made by hypotheses, rather than expressing any a priori preference over them.
In both of these results posts, in order to measure the similarity between the scrubbed and unscrubbed models, we use % loss recovered.
As a baseline we use $E_\text{randomized}$, the ‘randomized loss’, defined as the loss when we shuffle the connection between the correct labels and the model’s output. Note this randomized loss will be higher than the loss for a calibrated guess with no information. We use randomized loss as the baseline since we are interested in explaining why the model makes the guesses it makes. If we had no idea, we could propose the trivial correspondence that the model’s inputs and outputs are unrelated, for which $E_\text{scrubbed} = E_\text{randomized}$.
Thus we define:
$$\%\text{ loss recovered}(E_\text{model}, E_\text{scrubbed}) = \frac{E_\text{scrubbed} - E_\text{randomized}}{E_\text{model} - E_\text{randomized}} \cdot 100\%.$$
This percentage can exceed 100% or be negative. It is not very meaningful as a fraction, and is rather an arithmetic aid for comparing the magnitude of expected losses under various distributions. However, it is the case that hypotheses with a “% loss recovered” closer to 100% result in predictions that are more consistent with the model.
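As a minimal sketch of this definition (the function name and the example losses below are hypothetical, chosen only for illustration), the metric can be computed as:

```python
def percent_loss_recovered(e_model: float, e_scrubbed: float, e_randomized: float) -> float:
    """Return (E_scrubbed - E_randomized) / (E_model - E_randomized) * 100."""
    denom = e_model - e_randomized
    if denom == 0:
        raise ValueError("E_model equals E_randomized, so the ratio is undefined")
    return (e_scrubbed - e_randomized) / denom * 100.0


# Hypothetical losses: the model gets 2.0, the randomized baseline gets 6.0.
print(percent_loss_recovered(e_model=2.0, e_scrubbed=2.5, e_randomized=6.0))  # 87.5
print(percent_loss_recovered(e_model=2.0, e_scrubbed=1.5, e_randomized=6.0))  # 112.5
print(percent_loss_recovered(e_model=2.0, e_scrubbed=7.0, e_randomized=6.0))  # -25.0
```

The last two calls show how the value leaves the 0–100% range: a scrubbed loss below the model's loss gives more than 100%, and a scrubbed loss above the randomized baseline gives a negative number.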
Above, we rate our hypotheses using the distance between the expectation under the dataset and the scrubbed distribution, $|\mathbb{E}[f(x)] - E_\text{scrubbed}(h, D)|$.
You could instead rate hypotheses by comparing the full distribution of input-output behavior. That is, you would measure the difference between the distribution of the random variable f(x) under the dataset D and the distribution of f(x) under $D_\text{scrubbed}$.
In this work, we prefer the expected loss. Suppose that one of the drivers of the model’s behavior is noise: trying to capture the full distribution would require us to explain what causes the noise. For example, you’d have to explain the behavior of a randomly initialized model despite the model doing ‘nothing interesting’.
Earlier, we noted our preference for “resampling ablation” of a component of a model (patch an activation of that component from a randomly selected input in the dataset) over zero or mean ablation of that component (set that component’s activation to 0 or its mean over the entire dataset, respectively) in order to test the claim “this component doesn’t matter for our explanation of the model”. We also mentioned three specific problems we see with using zero or mean ablation to test this claim. Here, we’ll discuss these problems in greater detail.
1) Zero and mean ablations take your model off distribution in an unprincipled manner.
The first problem we see with these ablations is that they destroy various properties of the distribution of activations in a way that seems unprincipled and could lead to the ablated model performing either worse or better than it should.
As an informal argument, imagine we have a module whose activations are in a two dimensional space. In the picture below we’ve drawn some of its activations as gray crosses, the mean as a green cross, and the zero as a red cross:
It seems to us that zero ablating takes your model out of distribution in an unprincipled way. (If the model was trained with dropout, it’s slightly more reasonable, but it’s rarely clear how a model actually handles dropout internally.) Mean ablating also takes the model out of distribution because the mean is not necessarily on the manifold of plausible activations.
2) Zero and mean ablations can have unpredictable effects on measured performance.
Another problem is that these ablations can have unpredictable effects on measured performance. For example, suppose that you’re looking at a regression model that happens to output larger answers when the activation from this module is at its mean activation (which, let’s suppose, is off-distribution and therefore unconstrained by SGD). Also, suppose you’re looking at it on a data distribution where this module is in fact unimportant. If you’re analyzing model performance on a data subdistribution where the model generally guesses too high, then mean ablation will make it look like ablating this module harms performance. If the model generally guesses too low on the subdistribution, mean ablation will improve performance. Both of these failure modes are avoided by using random patches, as resampling ablation does, instead of mean ablation.
3) Zero and mean ablations remove variation that your model might depend on for performance.
The final problem we see with these ablations is that they neglect the variation in the outputs of the module. Removing this variation doesn’t seem reasonable when claiming that the module doesn't matter.
For an illustrative toy example, suppose we’re trying to explain the performance of a model with three modules M1, M2, and M3. This model has been trained with dropout and usually only depends on components M1 and M2 to compute its output, but if dropout is active and knocks out M2, the model uses M3 instead and can perform almost as well as if it were able to use M1 and M2.
If we zero/mean ablate M2 (assume mean 0), it will look like M2 wasn't doing anything at all and our hypothesis that it wasn't relevant will be seemingly vindicated. If instead we resample ablate M2, the model will perform significantly worse (exactly how much worse is dependent on exactly how the output of M2 is relevant to the final output).
This example, while somewhat unrealistic, hopefully conveys our concern here: sometimes the variation in the outputs of a component is important to your model and performing mean or zero ablation forces this component to only act as a fixed bias term, which is unlikely to be representative of its true contribution to the model’s outputs.
We think these examples provide sufficient reasons to be skeptical about the validity of zero or mean ablation and demonstrate our rationale for preferring resampling ablation.
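The three-module toy example can be simulated with a minimal sketch. Everything concrete here is an assumption made for illustration: the two hidden signs `t` and `s`, the target `t + s`, the 0.9 factor that makes M3 a slightly worse backup, and the rule that the model falls back to M3 whenever M2 is exactly zero.

```python
import random

# Hypothetical balanced dataset: every (t, s) sign combination, repeated.
DATA = [(t, s) for t in (-1.0, 1.0) for s in (-1.0, 1.0)] * 250


def modules(t, s):
    """Activations of M1, M2, M3 on one example (M3 is a weaker copy of M2)."""
    return t, s, 0.9 * s


def model(m1, m2, m3):
    """Use M2 normally; fall back to M3 when M2 has been knocked out (is 0)."""
    return m1 + (m2 if m2 != 0.0 else m3)


def mean_loss(ablate_m2):
    """Mean squared error when M2's activation is replaced via ablate_m2."""
    total = 0.0
    for t, s in DATA:
        m1, m2, m3 = modules(t, s)
        total += (model(m1, ablate_m2(m2), m3) - (t + s)) ** 2
    return total / len(DATA)


rng = random.Random(0)
m2_mean = sum(modules(t, s)[1] for t, s in DATA) / len(DATA)  # exactly 0 here

loss_clean = mean_loss(lambda m2: m2)
loss_zero = mean_loss(lambda m2: 0.0)
loss_mean = mean_loss(lambda m2: m2_mean)
# Resampling ablation: take M2's activation from a randomly chosen example.
loss_resample = mean_loss(lambda m2: modules(*rng.choice(DATA))[1])

print(loss_clean, loss_zero, loss_mean, loss_resample)
```

Under these assumptions, zero and mean ablation leave the loss near 0.01 because the backup module takes over, while resampling ablation pushes it to roughly 2, which exposes how much the model relies on the variation in M2's output.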
Suppose we have the following hypothesis where I maps to the nodes of G in blue:
There are four activations in G that we claim are unimportant.
Causal scrubbing requires performing a resampling ablation on these activations. When doing so, should we pick one data point to get all four activations on? Two different data points, one for R and S (which both feed into V) and a different one for X and Y? Or four different data points?
In our opinion, all are reasonable experiments that correspond to subtly different hypotheses. This may not be something you considered when proposing your informal hypothesis, but following the causal scrubbing algorithm forces you to resolve this ambiguity. In particular, the more we sample unimportant activations independently, the more specific the hypothesis becomes, because it allows you to make strictly more swaps. It also sometimes makes it easier for the experimenter to reason about the correlations between different inputs. For a concrete example where this matters, see the paren balance checker experiment.
And so, in the pseudocode above we sample the pairs (R, S) and (X, Y) separately, although we allow hypotheses that require all unimportant inputs throughout the model to be sampled together.
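The three sampling choices discussed above can be sketched as follows. This is not the pseudocode from the main post; `sample_unimportant` and its `groups` argument are hypothetical names, and integers stand in for dataset elements.

```python
import random


def sample_unimportant(groups, dataset, rng):
    """Assign one dataset element to each group of unimportant nodes.

    Nodes in the same group share a datum (they are sampled together);
    different groups get independent draws.
    """
    assignment = {}
    for group in groups:
        datum = rng.choice(dataset)
        for node in group:
            assignment[node] = datum
    return assignment


rng = random.Random(0)
dataset = list(range(1000))  # stand-in for real dataset elements

# One data point for all four unimportant activations.
together = sample_unimportant([("R", "S", "X", "Y")], dataset, rng)
# One data point per important parent: (R, S) feed V, (X, Y) feed Z.
by_parent = sample_unimportant([("R", "S"), ("X", "Y")], dataset, rng)
# Four independent data points.
separate = sample_unimportant([("R",), ("S",), ("X",), ("Y",)], dataset, rng)

print(together, by_parent, separate, sep="\n")
```

Making the grouping an explicit argument mirrors the point in the text: the choice of which unimportant inputs are sampled together is part of the hypothesis, not an implementation detail.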
Why not go more extreme, and sample every single unimportant node separately? One reason is that it is not well-defined: we can always rewrite our model to an equivalent one consisting of a different set of nodes, and this would lead to completely different sampling! Another is that we don’t actually intend this: we do believe it’s important that the inputs to our treeified model be “somewhat reasonable”, i.e. have some of the correlations that they usually do in the training distribution, though we’re not sure exactly which ones matter. So if we started from saying that all nodes are sampled separately, we’d immediately want to hypothesize something about them needing to be sampled together in order for our scrubbed model to not get very high loss. Thus this default makes it simpler to specify hypotheses.
In general we don’t require hypotheses to be surjective, meaning not all nodes of G need to be mapped onto by c, nor do we require that I contains all edges of G. This is convenient for expressing claims that some nodes (or edges) of G are unimportant for the behavior. It leaves a degree of freedom, however, in how to treat these unimportant nodes, as discussed in the preceding section.
It is possible to remove this ambiguity by requiring that the correspondence be an isomorphism between G and I. In this section we’ll demonstrate how to do this in a way that is consistent with the pseudocode presented, by combining all the unimportant parents of each important node.
In the example below, both R and S are unimportant inputs to the node V, and both X and Y are unimportant inputs to the node Z. We make the following rewrites in the example below:
If you want to take a different approach to sampling the unimportant inputs, you can rewrite the graphs in a different way (for instance, keeping X and Y as separate nodes).
One general lesson from this is that rewriting the computational graphs G and I is extremely expressive. In practice, we have found that with some care it allows us to run the experiments we intuitively wanted to.
Suppose we have a function f to which we want to apply the causal scrubbing algorithm. Consider an isomorphic (see above) treeified hypothesis h=(GT,IT,cT) for f. In this appendix we will show that causal scrubbing preserves the joint distribution of inputs to each node of IT (Lemma 1). Then we show that the distribution of inputs induced by causal scrubbing is the maximum entropy distribution satisfying this constraint (Theorem 2).
Let X be the domain of f and D be the input distribution for f (a distribution on X). Let ~D be the distribution given by the causal scrubbing algorithm (so the domain of ~D is X^n, where n is the number of times that the input is repeated in IT).
We find it useful to define two sets of random variables: one set for the values of wires (i.e. edges) in IT when IT is run on a consistent input drawn from D (i.e. on (x,…,x) for some x); and one set for the values of wires in IT induced by the causal scrubbing algorithm:
Definition (f-consistent random variables): For all the edges of IT, we call the “f-consistent random variable” the result of evaluating the interpretation IT on (x,…,x), for a random input x∼D. For each node u∈IT, we will speak of the joint distribution of its input wires, and call the resulting random variable the “f-consistent inputs (to u)”. We also refer to the value of the wire going out of u∈IT as the “f-consistent output (of u)”.
Definition (scrubbed random variables): Suppose that we run IT on (x1,x2,…,xn)∼~D. In the same way, this defines a set of random variables, which we call the scrubbed random variables (and use the terms "scrubbed inputs" and "scrubbed output" accordingly).
Lemma 1: For every node u∈IT, the joint distribution of scrubbed inputs to u is equal to the joint distribution of f-consistent inputs to u.
Proof: Recall that the causal scrubbing algorithm assigns a datum in X to every node of IT, starting from the root and moving up. The key observation is that for every node u of IT, the distribution of the datum of u is exactly D. We can see this by induction. Clearly this is true for the root. Now, consider an arbitrary non-root node u and assume that this claim is true for the parent v of u. Consider the equivalence classes on X defined as follows: x1 and x2 are equivalent if (x1,…,x1) has the same value at u as (x2,…,x2) when IT is run on each input. Then the datum of u is chosen by sampling from D subject to being in the same equivalence class as the datum of v. Since (by assumption) the datum of v is distributed according to D, so is the datum of u.
Now, by the definition of the causal scrubbing algorithm, for every node u, the scrubbed inputs to u are equal to the inputs to u when IT is run on the datum of u. Since the datum of u is distributed according to D, it follows that the joint distribution of scrubbed inputs to u is equal to the joint distribution of f-consistent inputs to u.
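The key induction step (if the parent's datum is distributed according to D, then so is the child's) can be checked exactly on a small finite example. This is only a sketch: the distribution `D` and the function `value_at_u`, which stands in for the value at node u when IT is run on (x,…,x), are hypothetical.

```python
from fractions import Fraction

# Hypothetical finite input distribution D over X = {0, ..., 5}.
D = {0: Fraction(1, 12), 1: Fraction(1, 6), 2: Fraction(1, 4),
     3: Fraction(1, 12), 4: Fraction(1, 4), 5: Fraction(1, 6)}


def value_at_u(x):
    """Stand-in for the value at node u when I_T is run on (x, ..., x)."""
    return x % 3


def child_datum_distribution(parent_dist):
    """Distribution of u's datum given the distribution of its parent's datum.

    u's datum is drawn from D conditioned on lying in the same equivalence
    class (same value_at_u) as the parent's datum.
    """
    class_mass = {}
    for x, p in D.items():
        class_mass[value_at_u(x)] = class_mass.get(value_at_u(x), 0) + p
    child = {x: Fraction(0) for x in D}
    for x_v, p_v in parent_dist.items():
        k = value_at_u(x_v)
        for x_u, p_u in D.items():
            if value_at_u(x_u) == k:
                child[x_u] += p_v * p_u / class_mass[k]
    return child


child_dist = child_datum_distribution(D)
print(child_dist == D)  # True
```

Exact rational arithmetic is used so that the equality is checked exactly rather than up to sampling noise.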
Theorem 2: The joint distribution of (top-level) scrubbed inputs is the maximum-entropy distribution on X^n, subject to the constraints imposed by Lemma 1.
Proof: We proceed by induction on a stronger statement: consider any way to "cut" through IT in a way that separates all of the inputs to IT from the root (and does so minimally, i.e. if any edge is un-cut then there is a path from some leaf to the root). (See below for an example.) Then the joint scrubbed distribution of the cut wires has maximal entropy subject to the constraints imposed by Lemma 1 on the joint distribution of scrubbed inputs to all nodes lying on the root's side of the cut.
Our base case is the cut through the input wires to the root (in which case the claim holds trivially, since Lemma 1 already fixes the joint distribution of the root's inputs). Our inductive step will take any cut and move it up through some node u, so that if previously the cut passed through the output of u, it will now pass through the inputs of u. We will show that if the original cut satisfies our claim, then so will the new one.
Consider any cut and let u be the node through which we will move the cut up. Let x denote the vector of inputs to u, y be the output of u (so y=u(x)), and z denote the values along all cut wires besides y. Note that x and z are independent conditional on y; this follows by conditional independence rules on Bayesian networks (x and z are d-separated by y).
Next, we show that this distribution is the maximum-entropy distribution. The following equality holds for *any* random variables X,Y,Z such that Y is a function of X:
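Reconstructed from the four steps justified in the next paragraph (the layout is an assumption; each line applies exactly one of the stated identities), the chain of equalities reads:

$$
\begin{aligned}
H(X,Z) &= H(X,Y,Z) \\
&= H(X) + H(Y,Z) - I(X;Y,Z) \\
&= H(X) + H(Y,Z) - I(X;Z \mid Y) - I(X;Y) \\
&= H(X) + H(Y,Z) - I(X;Z \mid Y) - H(Y).
\end{aligned}
$$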
Where I(⋅) is mutual information. The first step follows from the fact that Y is a function of X. The second step follows from the identity I(A;B)=H(A)+H(B)−H(A,B). The third step follows from the identity that I(A;B∣C)=I(A;B,C)−I(A;C). The last step follows from the fact that I(X;Y)=H(Y)−H(Y∣X)=H(Y), again because Y is a function of X.
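The identity these steps establish, H(X,Z) = H(X) + H(Y,Z) − I(X;Z∣Y) − H(Y) whenever Y is a function of X, can also be sanity-checked numerically. The sketch below uses a hypothetical random joint distribution and a hypothetical function `g`, and computes the conditional mutual information directly from its definition rather than from entropies, so that the check is not circular.

```python
import math
import random
from collections import defaultdict


def entropy(dist):
    """Shannon entropy (in bits) of a dict mapping outcomes to probabilities."""
    return -sum(p * math.log2(p) for p in dist.values() if p > 0)


def marginal(joint, keep):
    """Marginalize a joint over (x, y, z) tuples onto the positions in keep."""
    out = defaultdict(float)
    for outcome, p in joint.items():
        out[tuple(outcome[i] for i in keep)] += p
    return dict(out)


def g(x):
    """Hypothetical deterministic function, so that Y = g(X)."""
    return x % 2


rng = random.Random(0)
# Random joint distribution over (X, Z), extended with Y = g(X).
weights = {(x, z): rng.random() for x in range(4) for z in range(3)}
total = sum(weights.values())
joint = {(x, g(x), z): w / total for (x, z), w in weights.items()}

p_y = marginal(joint, [1])
p_xy = marginal(joint, [0, 1])
p_yz = marginal(joint, [1, 2])

# I(X;Z|Y) = sum p(x,y,z) * log[ p(y) p(x,y,z) / (p(x,y) p(y,z)) ].
cmi = sum(
    p * math.log2(p * p_y[(y,)] / (p_xy[(x, y)] * p_yz[(y, z)]))
    for (x, y, z), p in joint.items() if p > 0
)

lhs = entropy(marginal(joint, [0, 2]))  # H(X,Z)
rhs = entropy(marginal(joint, [0])) + entropy(p_yz) - cmi - entropy(p_y)
print(abs(lhs - rhs) < 1e-9)  # True
```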
Now, consider all possible distributions of (x,z) subject to the constraints imposed by Lemma 1 on the joint distribution of scrubbed inputs to all nodes lying on the root's side of the updated cut. The lemma specifies the distribution of x and (therefore) y. Thus, subject to these constraints, H(x,z) is equal to H(y,z)−I(x;z∣y) plus H(x)−H(y), which is a constant. By the inductive hypothesis, H(y,z) is as large as possible subject to the lemma's constraints. Mutual information is non-negative, so it follows that if I(x;z∣y)=0, then H(x,z) is as large as possible subject to the aforementioned constraints. Since x and z are independent conditional on y, this is indeed the case.
This concludes the induction. So far we have only proven that the joint distribution of scrubbed inputs is *some* maximum-entropy distribution subject to the lemma's constraints. Is this distribution unique? Assuming that the space of possible inputs is finite (which it is if we're doing things on computers), the answer is yes: entropy is a strictly concave function and the constraints imposed by the lemma on the distribution of scrubbed inputs are convex (linear, in particular). A strictly concave function has a unique maximum on a convex set. This concludes the proof.
Fun Fact 3: The entropy of the joint distribution of scrubbed inputs is equal to the entropy of the output of IT, plus the sum over all nodes u∈IT of the information lost by u (i.e. the entropy of the joint input to u minus the entropy of the output). (By Lemma 1, this number does not depend on whether we imagine IT being fed f-consistent inputs or scrubbed inputs.) By direct consequence of the proof of Theorem 2, we have H(x,z)−H(y,z)=H(x)−H(y) (with x,y,z as in the proof of Theorem 2). Proceeding by the same induction as in Theorem 2 yields this fact.
In our polysemanticity toy model paper, we introduced an analytically tractable setting where the optimal model represents features in superposition. In this section, we’ll analyze this model using causal scrubbing, as an example of what it looks like to handle polysemantic activations.
The simplest form of this model is the two-variable, one-neuron case, where we have independent variables x1 and x2 which both have zero expectation and unit variance, and we are choosing the parameters c and d to minimize loss in the following setting:
Where ~y is our model, c and d are the parameters we’re optimizing, and a and b are part of the task definition. As discussed in our toy model paper, in some cases (when you have some combination of a and b having similar values and x1 and x2 having high kurtosis (e.g. because they are usually equal to zero)), c and d will both be set to nonzero values, and so (cx1+dx2) can be thought of as a superposed representation of both x1 and x2.
To explain the performance of this model with causal scrubbing, we take advantage of function extensionality and expand ~y:
And then we explain it with the following hypothesis:
When we sample outputs using our algorithm here, we’re going to sample the interference term from random other examples. And so the scrubbed model will have roughly the same estimated loss as the original model–the errors due to interference will no longer appear on the examples that actually suffer from interference, but the average effect of interference will be approximately reproduced.
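A small simulation illustrates this. The formulas for the task and model are not reproduced in this text, so the sketch assumes a task y = a·x1² + b·x2², a model ~y = (c·x1 + d·x2)² with squared-error loss, and an expansion into c²·x1² + d²·x2² plus the interference term 2cd·x1·x2. The parameter values and the sparse input distribution are likewise hypothetical.

```python
import random

rng = random.Random(0)
N = 100_000
a, b = 1.0, 1.0  # task coefficients (hypothetical values)
c, d = 0.7, 0.7  # model parameters (hypothetical values)


def sparse_unit_variance():
    """Zero-mean, unit-variance, high-kurtosis variable: usually exactly 0."""
    if rng.random() < 0.1:
        return rng.choice((-1.0, 1.0)) * 10 ** 0.5
    return 0.0


xs = [(sparse_unit_variance(), sparse_unit_variance()) for _ in range(N)]


def loss(x1, x2, interference):
    target = a * x1**2 + b * x2**2                       # assumed task
    output = c**2 * x1**2 + d**2 * x2**2 + interference  # expanded model
    return (target - output) ** 2


# Original model: the interference term comes from the same example.
original = sum(loss(x1, x2, 2 * c * d * x1 * x2) for x1, x2 in xs) / N

# Scrubbed model: the interference term comes from a random other example.
scrubbed = 0.0
for x1, x2 in xs:
    r1, r2 = rng.choice(xs)
    scrubbed += loss(x1, x2, 2 * c * d * r1 * r2)
scrubbed /= N

# For contrast: dropping the interference term entirely (zero ablation).
dropped = sum(loss(x1, x2, 0.0) for x1, x2 in xs) / N

print(original, scrubbed, dropped)
```

Under these assumptions the original and scrubbed losses agree in expectation, so the two estimates should be close, whereas zero-ablating the interference term underestimates the loss because it removes the cost of interference altogether.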
In general, this is our strategy for explaining polysemantic models: we do an algebraic rewrite on the model so that the model now has monosemantic components and an error term, and then we say that the monosemantic components explain why the model is able to do the computation that it does, and we say that we don’t have any explanation for the error term.
This works as long as the error is actually unstructured–if the model was actively compensating for the interference errors (as in, doing something in a way that correlates with the interference errors to reduce their cost), we’d need to describe that in the explanation in order to capture the true loss.
This strategy also works if you have more neurons and more variables–we’ll again write our model as a sum of many monosemantic components and a residual. And it’s also what we’d do with real models–we take our MLP or other nonlinear components and make many copies of the set of neurons that are required for computing a particular feature.
This strategy means that we generally have to consider an explanation that’s as large as the model would be if we expanded it to be monosemantic. But it’s hard to see how we could have possibly avoided this.
Note that this isn’t a solution to finding a monosemantic basis - we’re just claiming that if you had a hypothesized monosemantic reformulation of the model you could test it with causal scrubbing.
This might feel vacuous–what did we achieve by rewriting our model as if it was monosemantic and then adding an error term? We claim that this is actually what we wanted. The hypothesis explained the loss because the model actually was representing the two input variables in a superposed fashion and resigning itself to the random error due to interference. The success of this hypothesis reassures us that the model isn’t doing anything more complicated than that. For example, if the model was taking advantage of some relationship between these features that we don’t understand, then this hypothesis would not replicate the loss of the model.
Now, suppose we rewrite the model from the form we used above:
To the following form:
Where we’ve split the noise term into two pieces. If we sample these two parts of the noise term independently, we will have effectively reduced the magnitude of the noise, for the usual reason that averages of two samples from a random variable have lower variance than single samples. And so if we ignore this correlation, we’ll estimate the cost of the noise to be lower than it is for the real model. This is another mechanism by which ignoring a correlation can cause the model to seem to perform better than the real model does; as before, this error gives us the opportunity to neglect some positive contribution to performance elsewhere in the model.
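The variance reduction can be made exact on a small example. The rewritten forms are not reproduced in this text, so the sketch assumes the noise term e is split into two equal halves, e/2 + e/2, and uses a hypothetical zero-mean noise distribution.

```python
from fractions import Fraction
from itertools import product

# Hypothetical noise values, each equally likely; the mean is zero.
NOISE = [Fraction(v) for v in (-3, -1, 0, 1, 3)]


def variance(values):
    """Variance of a uniform distribution over the given values."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


# Both halves taken from the same sample: e/2 + e/2 = e.
var_together = variance([e / 2 + e / 2 for e in NOISE])

# Halves sampled independently: e1/2 + e2/2 over all equally likely pairs.
var_independent = variance(
    [e1 / 2 + e2 / 2 for e1, e2 in product(NOISE, repeat=2)]
)

print(var_together, var_independent)  # 4 2
```

Sampling the two halves independently halves the variance of the noise term, which is why a hypothesis that ignores this correlation underestimates the cost of the noise.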
We can construct cases where the explanation can make the model look better by sneaking in information. For example, consider the following setting:
The model’s input is a tuple of a natural number and the current game setting, which is either EASY or HARD (with equal frequency). The model outputs the answer either “0”, “1”, or “I don’t know”. The task is to guess the last bit of the hash of the number.
Here’s the reward function for this task:
If the model has no idea how to hash numbers, its optimal strategy is to guess when in EASY mode and say “I don’t know” in HARD mode.
Now, suppose we propose the hypothesis that claims that the model outputs:
To apply causal scrubbing, we consider the computational graph of both the model and the hypothesis to consist of the input nodes and a single output node. In this limited setting, the projected model runs the following algorithm:
Now consider running the projected model on a HARD case. According to the hypothesis, we output the correct answer, so we replace the input
So, when you do causal scrubbing on HARD cases, the projected model will now guess correctly half the time, because half its “I don’t know” answers will be transformed into the correct answer. The projected model’s performance will be worse on the EASY cases, but the HARD cases mattered much more, so the projected model’s performance will be much better than the original model’s performance, even though the explanation is wrong!
In examples like this one, hypotheses can cheat and get great scores while being very false.
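The effect can be computed exactly in a stripped-down version of this game. The reward table and the full statement of the hypothesis are not reproduced in this text, so the sketch makes several assumptions: the reward values are hypothetical (chosen so that guessing is optimal in EASY mode and "I don't know" is optimal in HARD mode), the hypothesis is taken to output the model's guess on EASY inputs and the correct bit on HARD inputs, and each input is abstracted to a (mode, correct bit, uninformed guess) triple.

```python
from fractions import Fraction
from itertools import product

# Hypothetical rewards; the original table is not reproduced here.
REWARD = {"EASY": {"correct": 2, "wrong": -1, "idk": 0},
          "HARD": {"correct": 10, "wrong": -20, "idk": 0}}

# Each input is (mode, correct_bit, guess_bit), all equally likely.
# guess_bit is the model's uninformed guess, independent of correct_bit.
INPUTS = list(product(("EASY", "HARD"), (0, 1), (0, 1)))


def model(inp):
    mode, _, guess = inp
    return guess if mode == "EASY" else "idk"


def hypothesis(inp):
    """Assumed hypothesis: the model's guess on EASY, the correct bit on HARD."""
    mode, correct, guess = inp
    return guess if mode == "EASY" else correct


def reward(inp, output):
    mode, correct, _ = inp
    if output == "idk":
        return REWARD[mode]["idk"]
    return REWARD[mode]["correct" if output == correct else "wrong"]


def expected_reward(scrub):
    total = Fraction(0)
    for inp in INPUTS:
        if scrub:
            # Run the model on every input the hypothesis treats as equivalent.
            swaps = [o for o in INPUTS if hypothesis(o) == hypothesis(inp)]
        else:
            swaps = [inp]
        total += Fraction(sum(reward(inp, model(s)) for s in swaps), len(swaps))
    return total / len(INPUTS)


print(expected_reward(scrub=False), expected_reward(scrub=True))  # 1/4 21/8
```

With these assumed rewards the scrubbed estimate (21/8) is far higher than the model's true expected reward (1/4), even though the hypothesis is false: on HARD inputs, half of the swapped-in inputs are EASY inputs whose guess happens to equal the correct bit.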
(Credit for the ideas in this section is largely due to ARC.)
We might have hoped that we’d be able to use causal scrubbing as a check on our hypotheses analogous to using a proof checker like Lean or Coq to check our mathematical proofs, but this doesn’t work. Our guess is that it’s probably impossible to have an efficient algorithm for checking interpretability explanations which always rejects false explanations. This is mostly because we suspect that interpretability explanations should be regarded as an example of defeasible reasoning. Checking interpretations in a way that rejects all false explanations is probably NP-hard, and so we want to choose a notion of checking which is weaker.
We aren’t going to be able to check hypotheses by treating as uncorrelated everything that the hypotheses claimed wasn’t relevantly correlated. This would have worked if ignoring correlations could only harm the model. But as shown above, we have several cases where ignoring correlations helps the model.
So we can’t produce true explanations by finding hypotheses subject to the constraint that they predict the observed metrics. As an alternative proposal, we can check if hypotheses are comprehensive by seeing if any adversarial additions to the hypothesis would cause the predicted metric to change considerably. In all of the counterexamples above, the problem is that the metric was being overestimated because there were important correlations that were being neglected and which would reduce the estimated metric if they were included. If we explicitly check for additional details to add to our hypotheses which cause the estimated metric to change, all the counterexamples listed above are solved.
To set up this adversarial validation scheme, we need some mechanism for hypotheses to be constructed adversarially. That is, we need to handle cases where the adversary wants to rewrite f to an extensionally-equal function. One way of thinking about this is that we want a function `join` which is a binary operation on hypotheses, taking the two hypotheses to the hypothesis which preserves all structure in the model that either of the two hypotheses preserved.
Here are two ways of defining this operation:
The main failure of the algorithm listed above is that we don’t know how to handle cases where the adversary wants to rewrite f to an extensionally-equal function in a way which is mutually incompatible with the original hypothesis (for example, because their computational graphs have different shapes and there’s no way to splice the two computational graphs together). This is a pretty bad problem because the function extensionality move seems very important in practice. ARC has worked on basically this problem for a while and hasn’t yet solved it, regrettably.
Some other questions that we haven’t answered:
Overall, it seems plausible that these problems can be overcome, but they are definitely not currently solved. We hold out hope for an interpretability process which has validity properties which allow us to use powerful optimization inside it and still trust the conclusions, and hope to see future work in this direction.
This is also true when you’re training models with an autodiff library–you construct a computational graph that computes loss, and run backprop on the whole thing, which quickly recurses into the model but doesn’t inherently treat it differently.
This allows for testing out human interpretable approximations to neural network components: ‘Artificial Artificial Neural networks’. We think it’s more informative to see how the model performs with the residual of this approximation resampling ablated as opposed to zero ablated.
In general, you could have the output be non-scalar with any distance metric δ to evaluate the deviation of the scrubbed expectation, but we’ll keep things simple here.
Another way of thinking about this is: when we consider the adversarial game setting, we would like each side to be able to request that terms are sampled together. By default therefore we would like terms (even random ones!) to be sampled separately.
Great stuff! Excited to see this extended and applied. I hope to dive deeper into this series and your followup work.
Came to the appendix for 2.2 on metrics, still feel curious about the metric choice.
I’m trying to figure out why this is wrong: “loss is not a good basis for a primary metric even though it’s worth looking at and intuitive, because it hides potentially large+important changes to the X -> Y mapping learned by the network that have equivalent loss. Instead, we should just measure how yscrubbed_i has changed from yhat_i (original model) at each xi we care about.” I think I might have heard people call this a “function space” view (been a while since I read that stuff) but that is confusing wording with your notation of f.
Dumb regression example. Suppose my training dataset is scalar (x,y) pairs that almost all fall along y=sin(x). I fit a humungo network N and when I plot N(x) for all my xs I see a great approximation of sin(x). I pick a weird subset of my data where instead of y=sin(x), this data is all y=0 (as far as I can tell this is allowed? I don’t recall restrictions on training distribution having to match) and use it to compute my MSE loss during scrubbing. I find a hypothesis that recovers 100% of performance! But I plot and it looks like cos(x), which unless I’m tired has the same MSE from the origin in expectation.
I probably want to know if my subnetwork is actually computing a very different y for the same exact x, right? Even if it happens to have a low or even equal or better loss?
(I see several other benefits of comparing model output against scrubbed model output directly, for instance allowing application to data which is drawn from your target distribution but not labelled)
Even if this is correct, I doubt this matters much right now compared to the other immediate priorities for this work, but I’d hope someone was thinking about it and/or I can become less confused about why the loss is justified.