Building a model for which you're confident your interpretability is correct, by reverse-engineering each part of the model to work how your interpretability says it should work. (Based on discussion in alignment reading group, ideas from William, Adam, dmz, Evan, Leo, maybe others)

The basic idea

  1. Train an untrusted network
  2. Do interpretability on it, getting some human-understandable schema that is intended to capture what (some of) the model is doing
  3. Identify anything that should be changed about what the model is (supposedly) doing, and make a new human-understandable schema of the desired behavior
  4. Train a model to do what this new representation says it should be doing, by having a human implement each part of the schema then doing imitation + distillation

I think this can be thought of in a few different ways:

  1. One way you might implement Microscope AI in practice
  2. A way to do imitative amplification where the decomposition is generated by gradient descent instead of by humans, combined with some aspects of imitative generalization

The main advantages here are:

  • Compared to just using interpretability to audit your trained model, this is safer, because you can be more confident your interpretation matches what the final model is doing (because you trained the final model to conform to your interpretation). If you have mistakes in your interpretability, you should get performance issues rather than safety issues.
  • Compared to normal IDA or ‘trying to build a schema from scratch’, this is easier, because you use the power of the initial model to decide how to decompose tasks or how to solve individual tasks. You can also use patterns the model learnt from large volumes of data without needing to have been able to generate that knowledge yourself.

The main disadvantages are:

  • To be competitive this requires really good interpretability - basically understanding everything about what the model is doing. This scheme might actually require more out of your interpretability techniques than just auditing your model would.
  • The ‘ontology mismatch’ issue for ELK is still a problem.
  • A naive implementation of this training procedure is really expensive (in terms of both compute and human labels).

The scheme in more detail:

First, train some big, untrusted model on untrusted, unsupervised data.

Next, use some (automated) interpretability tools to get natural-language (or otherwise) descriptions of functions of components of the model (e.g. neurons, circuits, attention heads, etc.).


Now we’re going to build a new model that is constructed based on the description of this model. Each component in the new model is going to be a small model trained to imitate a human computing the function that the description of the component specifies.

For the low-level components, we train a small model to e.g. recognise edges, based on human labels of whether a region contains edges.

For the components above this, we train a small model to go from the human-understandable descriptions of the features that go into it, to the human-understandable description of behavior. For example, we train a model to go from the description of the locations of windows, wheels and car body in the image to a description of the location of likely cars in the image.
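As a rough illustration of the assembly described above, here is a minimal sketch in which each component is a description plus a callable that stands in for the small imitation model. Everything here (`Component`, `run_schema`, the toy wheel/window counts) is hypothetical and only meant to show the shape of the idea: higher-level components see nothing except the named, human-understandable outputs of the components below them.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Component:
    name: str
    description: str         # human-readable statement of what the component computes
    inputs: List[str]        # names of upstream features ("image" means the raw input)
    fn: Callable[..., bool]  # stand-in for a small model trained on human labels


def run_schema(components: List[Component], image: dict) -> Dict[str, bool]:
    """Evaluate components in order; each one only sees its named inputs."""
    features: Dict[str, object] = {"image": image}
    for c in components:
        args = [features[name] for name in c.inputs]
        features[c.name] = c.fn(*args)
    del features["image"]  # return only the interpretable features
    return features


# Toy schema: two low-level detectors and one component built on top of them.
schema = [
    Component("wheels", "the image contains at least two wheels",
              ["image"], lambda img: img["n_wheels"] >= 2),
    Component("windows", "the image contains at least one window",
              ["image"], lambda img: img["n_windows"] >= 1),
    Component("car", "wheels and windows are both present, so a car is likely",
              ["wheels", "windows"], lambda wheels, windows: wheels and windows),
]

result = run_schema(schema, {"n_wheels": 4, "n_windows": 2})
print(result)
```

In a real version each `fn` would be a trained model rather than a hand-written rule, and the features would carry spatial information rather than a single boolean.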

Now we have a model that is definitely well-described by the interpretability we did, because we constructed each component to have the function we understood it to. (There’s another assumption needed here - that we can ‘trust’ IID supervised learning. That is, if we have a dataset X, and we can get human labels h(x) for any particular example x ∈ X, we can train a model to perfectly imitate h on X.)

The hope is that we can now do things like removing parts of the model that reason about human supervisors, correcting parts of the model’s reasoning that will generalize badly, or make other modifications to make the model safe and aligned based on our (now hopefully good) understanding of its functionality.

Reducing costs

To reduce human data costs, we might want to train a general ‘human component imitator’ model that imitates humans implementing a component based on instructions. I.e., we collect a dataset of ([description of component], [human imitating that component on some examples]) pairs and train a model on this. This model would be relatively large.
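One way the dataset for such an imitator might be laid out is sketched below. The record type `HumanDemo`, the function `to_training_pairs`, and the prompt template are all assumptions made for illustration; the only thing taken from the scheme itself is that each record pairs a component description with examples of a human carrying it out.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class HumanDemo:
    component_description: str
    examples: List[Tuple[str, str]]  # (input shown to the human, the human's answer)


def to_training_pairs(demos: List[HumanDemo]) -> List[Tuple[str, str]]:
    """Flatten demos into (prompt, target) pairs for a text-to-text model."""
    pairs = []
    for demo in demos:
        for shown_input, human_answer in demo.examples:
            prompt = (
                f"Component: {demo.component_description}\n"
                f"Input: {shown_input}\n"
                "Output:"
            )
            pairs.append((prompt, human_answer))
    return pairs


demos = [
    HumanDemo(
        "Given which parts are present, say whether a car is likely present.",
        [("wheels: yes, windows: yes", "yes"), ("wheels: no, windows: yes", "no")],
    ),
]
pairs = to_training_pairs(demos)
print(pairs[0][0])
```

Because the description is part of the prompt, a single model trained on these pairs could in principle be asked to implement a component it has never seen demonstrated, which is where the savings in human labels would come from.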

To reduce compute costs, as a final step we distill our assembly of models (one model for each component) into a single large model. This should yield a model with similar inference cost as the original model. (This distillation step also requires that we can trust IID supervised learning)
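A toy sketch of this distillation step, under heavy simplifying assumptions: the ‘assembly’ is two hand-written components plus a combiner over binary inputs, and the student is a single perceptron trained only on the assembly's outputs. The names (`teacher`, `train_student`, `student_predict`) are hypothetical, and a real student would be a large network trained on samples from the actual input distribution.

```python
from itertools import product


# Teacher: an assembly of separately built components (toy stand-ins).
def wheels_component(x):
    return x[0] == 1


def windows_component(x):
    return x[1] == 1


def teacher(x):
    return int(wheels_component(x) and windows_component(x))


def student_predict(params, x):
    w, b = params
    return int(w[0] * x[0] + w[1] * x[1] + b > 0)


def train_student(inputs, epochs=50, lr=1.0):
    """Perceptron rule: fit one linear unit to the teacher's labels."""
    w, b = [0.0, 0.0], 0.0
    for _ in range(epochs):
        for x in inputs:
            err = teacher(x) - student_predict((w, b), x)
            w[0] += lr * err * x[0]
            w[1] += lr * err * x[1]
            b += lr * err
    return w, b


inputs = list(product([0, 1], repeat=2))
params = train_student(inputs)
agreement = sum(student_predict(params, x) == teacher(x) for x in inputs) / len(inputs)
print(agreement)
```

The student is only checked against the teacher on the same distribution it was trained on, which is exactly the place where the ‘trust IID supervised learning’ assumption is doing the work.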

Somewhat Silly Example

Run automated interpretability on a vision model as described here:

Now pick a layer near the start, and replace every neuron in this layer with an image-to-text model, e.g. CLIP, that looks at the image and answers whether the image contains the feature that the neuron is supposed to detect.

In each layer after this, replace each neuron with a language model, e.g. GPT-3, that takes the questions and answers about features from the previous layer, and answers whether the feature that this neuron is supposed to predict is present.

There are many reasons why this wouldn’t actually work as stated. Our interpretability isn’t good enough to actually describe the behavior of all the neurons. We’d need to retrain or at least finetune CLIP and GPT-3 on this particular weird task. It’s extraordinarily inefficient to run this scheme. You’d want to preserve spatial information about where in an image the feature was, not just whether it was present.

Various hacky improvements/compromises

For each part of the model you can choose whether to leave it as-is or to re-implement based on your interpretation - in some cases you might leave part of the model as a black box if you really can’t figure out what it is doing but think it’s probably safe. So you can make tradeoffs of performance vs safety on a finer-grained level than the whole training procedure.

In a similar vein, you could imagine doing interpretability on some parts of the components, reverse-engineering those, then retraining the model with the reverse-engineered components substituted in. You can then repeat this. This might help you regain some performance while maintaining safety properties, if your interpretability is mostly good but not totally reliable.

Additionally, you could train a model to predict the human label from two inputs: the outputs of the human-understandable layer below, plus the outputs of a smaller neural network.

This makes it easier to get comparable performance but at the cost of introducing some uninterpretable features. But maybe you could either use a regularizer to tune how much uninterpretable stuff you're okay with, or you could then try to analyze the uninterpretable stuff and see if you can find more patterns in it that can be turned into interpretable features.
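To make the regularizer idea concrete, here is a minimal sketch with one interpretable feature `z`, one uninterpretable side-channel feature `u`, and a penalty `lam` applied only to the side-channel weight. The linear model, the squared-error loss, and all names are assumptions chosen so the optimum has a closed form; nothing here is specific to the scheme.

```python
def fit_with_side_channel(z, u, y, lam):
    """Minimise mean((w_z*z + w_u*u - y)**2) + lam * w_u**2 in closed form."""
    n = len(y)
    szz = sum(a * a for a in z) / n
    szu = sum(a * b for a, b in zip(z, u)) / n
    suu = sum(b * b for b in u) / n
    szy = sum(a * t for a, t in zip(z, y)) / n
    suy = sum(b * t for b, t in zip(u, y)) / n
    # Normal equations; the penalty only touches the side-channel weight.
    a11, a12, a22 = szz, szu, suu + lam
    det = a11 * a22 - a12 * a12
    w_z = (szy * a22 - a12 * suy) / det
    w_u = (a11 * suy - a12 * szy) / det
    return w_z, w_u


z = [0.0, 1.0, 0.0, 1.0]  # interpretable feature
u = [0.0, 0.0, 1.0, 1.0]  # uninterpretable side-channel feature
y = [0.0, 1.0, 0.5, 1.5]  # human label, here y = z + 0.5 * u

free = fit_with_side_channel(z, u, y, lam=0.0)
penalised = fit_with_side_channel(z, u, y, lam=1.0)
print(free, penalised)
```

With no penalty the fit recovers the full dependence on `u`; raising `lam` shrinks the side-channel weight and pushes the model to lean on the interpretable feature instead, at some cost in fit.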


Scale of interpretation and understandability

We have to pick what level of detail to interpret/build the model at. At one extreme, if we make the scale too small (e.g. ‘this part of the model computes the AND of these two inputs’), we haven’t achieved any better understanding of what the model’s actually doing overall, so we won’t be able to do things like detecting power-seeking behavior. However, in this case we’ll at least be able to preserve performance, because we can see and describe exactly what the model is doing.

If we only break the model down into large components, then it should be easy for humans to understand whether any of the components combine to lead to e.g. power-seeking behavior. However, in this case it’s going to be hard to preserve performance, because it’s hard for humans to do the required interpretability to understand and describe exactly what these larger components do. In addition, if we get too large we may find that the humans are not actually able to compute the function that we think the component computes.


By training a small NN on each part of the model, we’re massively increasing the number of parameters. I think the hope here is that you can then distill the new model while retaining the ability to run the big model (or parts of it) on any particular example you want to check, so you still have your IID SFT guarantees.

Relationship to ELK

If we put this approach in the ELK setting, what does that look like?

The idea is that we’ve trained some powerful prediction model, then we want to use that to reverse-engineer an honest reporter. We have the same ‘ontology mismatch’ problem: maybe the predictor is understanding the world using all kinds of crazy representations that are completely opaque to the human, such that they can only replicate the performance of the predictor by replicating the low-level functionality of the components, and don’t understand what the components are doing enough to identify whether the answer being given is generated by the desired reporter procedure or a bad reporter.

Concrete Experiments

The first experiment I’d want to try here would be something like taking the most well-understood vision models, and seeing how many of the components can be replaced with either code, or human computation, or an image-to-text or text-to-text model, and how much this hurts performance.

E.g., replace the output of the curve detector neuron with a human-chosen number to represent whether there are curves at different locations in the image, or replace the dog neuron with a human who looks at the activations of labeled neurons in the previous layer. Compute the accuracy across a moderate number of images - e.g. 100 - and see how much this hurts performance. Then try this using CLIP instead of a human.
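The measurement itself could look something like the sketch below. The two-stage toy ‘network’, the `curviness` field, and the miscalibrated `human_curve_rating` are all invented for illustration; the point is just the comparison between accuracy with the original neuron and accuracy with its output swapped for a substitute.

```python
def lower_layers(image):
    """Toy stand-in for the network up to the layer containing the curve neuron."""
    return {"curve": image["curviness"]}


def upper_layers(features):
    """Toy stand-in for the rest of the network."""
    return int(features["curve"] > 0.5)


def human_curve_rating(image):
    """Hypothetical substitute: a coarse, slightly miscalibrated yes/no rating."""
    return 1.0 if image["curviness"] > 0.4 else 0.0


def accuracy(images, substitute=None):
    correct = 0
    for image in images:
        features = lower_layers(image)
        if substitute is not None:
            features["curve"] = substitute(image)  # swap in the replacement output
        correct += int(upper_layers(features) == image["label"])
    return correct / len(images)


images = [
    {"curviness": 0.10, "label": 0},
    {"curviness": 0.45, "label": 0},
    {"curviness": 0.60, "label": 1},
    {"curviness": 0.90, "label": 1},
]

baseline = accuracy(images)
substituted = accuracy(images, substitute=human_curve_rating)
print(baseline, substituted)
```

Swapping `human_curve_rating` for a call to an image-to-text model would give the CLIP variant of the same experiment, with the rest of the harness unchanged.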



2 comments

Now we’re going to build a new model that is constructed based on the description of this model. Each component in the new model is going to be a small model trained to imitate a human computing the function that the description of the component specifies.


Some of the recent advances in symbolic regression and equation learning might be useful during this step to help generate functions describing component behavior, if what the component in the model is doing is moderately complicated. (E.g. A Mechanistic Interpretability Analysis of Grokking found that a model trained to do modular arithmetic ended up implementing it using discrete Fourier transforms and trig identities, which sounds like the sort of thing that might be a lot easier to figure out from a learned equation describing the component's behavior.) Being able to reduce a neural circuit to an equation or a Bayesnet or whatever would help a lot with interpretability, and at that point you might not even need to train an implementation model: we could maybe just use the symbolic form directly, as a more compact and more efficiently computable representation.

At the end of this process, you might even end up with something symbolic that looked a lot like a "Good Old Fashioned AI" (GOFAI) model — but a "Bitter Lesson compatible" one first learnt by a neural net and then reverse engineered using interpretability. Obviously doing this would put high demands on our interpretation tools. 

If I had such a function describing a neural net component, one of my first questions would be: which portions of the function's domain are well covered by the training set that the initial neural net model was trained on, or are at least near enough to items in that training set that interpolating the function to them seems likely to be safe (given its local first, second, third... partial derivatives), and which portions are untested extrapolations? Did the symbolic regression/function learning process give us multiple candidate functions, and if so, how much do they differ outside that well-tested region of the function domain?

This seems like it would give us some useful intuition for when the model might be unsafely extrapolating outside the training distribution and we need to be particularly cautious.
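A crude one-dimensional sketch of both checks, assuming we already have the training inputs and two candidate symbolic forms for a component. The nearest-neighbour distance test, the `radius` threshold, and the candidate functions are all made up for illustration; a real coverage measure would need to handle high-dimensional inputs and take the function's local derivatives into account.

```python
train_xs = [0.0, 0.5, 1.0, 1.5, 2.0]  # inputs the component saw during training


def candidate_a(x):
    return x


def candidate_b(x):
    # Agrees with candidate_a at every training input, diverges elsewhere.
    bump = x * (x - 0.5) * (x - 1.0) * (x - 1.5) * (x - 2.0)
    return x + 0.1 * bump


def is_covered(x, radius=0.5):
    """Treat x as covered if some training input lies within `radius` of it."""
    return min(abs(x - t) for t in train_xs) <= radius


def disagreement(x):
    return abs(candidate_a(x) - candidate_b(x))


for x in (1.2, 4.0):
    print(x, is_covered(x), disagreement(x))
```

Both candidates fit the training inputs exactly, so the data cannot choose between them; a large disagreement at a point far from the training set is the kind of signal that the model's behavior there is an untested extrapolation.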

Some portions of the neural net may turn out to be irreducibly complex — I suspect it would be good to be able to identify when something genuinely is complex, and when we're just looking at a big tangled-up blob of memorized instances from the training set (e.g. by somehow localizing sources of loss on the test set).

Image link is broken