Mechanistic Transparency for Machine Learning



Cross-posted on my blog.

Lately I've been trying to come up with a thread of AI alignment research that (a) I can concretely see contributing significantly to actually building aligned AI and (b) seems like something that I could actually make progress on. After some thinking and narrowing down of possibilities, I've come up with one: basically, a particular angle on machine learning transparency research.

The angle that I'm interested in is what I'll call mechanistic transparency. This roughly means developing tools that take a neural network designed to do well on some task and output something like pseudocode for the algorithm that the network implements, which developers of AI systems could read and understand without having to actually run the system. This pseudocode might use high-level primitives like 'sort' or 'argmax' or 'detect cats', which should themselves be reducible to pseudocode of a similar type, until eventually each is ideally reduced to a very small part of the original neural network, small enough that one could understand its functional behaviour with pen and paper within an hour. These tools might also slightly modify the network to make it more amenable to this analysis, in such a way that the modified network performs approximately as well as the original network.
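
As a toy illustration of where such a reduction might bottom out, here is a minimal sketch of a network fragment small enough to understand with pen and paper. The weights are hand-picked for illustration (they are not taken from any trained network), and the names `tiny_network`, `pseudocode` and `max_disagreement` are hypothetical. The three hidden units compute relu(a - b) + relu(b) - relu(-b), which simplifies to max(a, b), so the one-line pseudocode is a faithful description of this fragment.

```python
import numpy as np

# Hand-picked weights for a 2-input, 3-hidden-unit, 1-output ReLU network.
# These are illustrative assumptions, not weights from a trained model.
W1 = np.array([[1.0, -1.0],   # h0 = relu(a - b)
               [0.0,  1.0],   # h1 = relu(b)
               [0.0, -1.0]])  # h2 = relu(-b)
w2 = np.array([1.0, 1.0, -1.0])  # out = h0 + h1 - h2


def tiny_network(x):
    """Forward pass of the small network: w2 . relu(W1 x)."""
    hidden = np.maximum(W1 @ x, 0.0)
    return float(w2 @ hidden)


def pseudocode(x):
    """Candidate human-readable description of the same computation."""
    a, b = x
    return float(max(a, b))


def max_disagreement(n_samples=1000, seed=0):
    """Largest gap between network and pseudocode on random inputs."""
    rng = np.random.default_rng(seed)
    xs = rng.normal(size=(n_samples, 2))
    return max(abs(tiny_network(x) - pseudocode(x)) for x in xs)


if __name__ == "__main__":
    print("max disagreement:", max_disagreement())
```

The random comparison is only a sanity check. The point of the exercise is that the equivalence can be established by reading the weights and doing the algebra, without running anything.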

There are a few properties that this pseudocode must satisfy. Firstly, it must be faithful to the network that is explained, such that if one substitutes in the pseudocode for each high-level primitive recursively, the result should be the original neural network, or a network close enough to the original that the differences are irrelevant (although just in case, the reconstructed network that is exactly explained should presumably be the one deployed). Secondly, the high-level primitives must be somewhat understandable: the pseudocode for a 256-layer neural network for image classification should not be output = f2(f1(input)) where f1 is the action of the first 128 layers and f2 is the action of the next 128 layers, but should rather break down into edge detectors being used to find floppy ears and spheres and textures, and those being combined in reasonable ways to form judgements of what the image depicts. The high-level primitives should be as human-understandable as possible, ideally 'carving the computation at the joints' by representing any independent sub-computations or repeated applications of the same function (so, for instance, if a convolutional network is represented as if it were fully connected, these tools should be able to recover convolutional structure). Finally, the high-level primitives in the pseudocode should ideally be understandable enough to be modularised and used in different places for the same function.
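
To make the convolutional example concrete, here is a minimal sketch of detecting repeated application of the same function in a dense weight matrix. It makes strong simplifying assumptions that a real tool could not rely on: the layer is one-dimensional, has stride 1 and no padding, shares its weights exactly, and its inputs and outputs are already in spatial order. The function names `conv_as_dense` and `recover_kernel` are hypothetical.

```python
import numpy as np


def conv_as_dense(kernel, input_len):
    """Dense matrix of a 1-D 'valid' cross-correlation with `kernel`."""
    k = len(kernel)
    out_len = input_len - k + 1
    M = np.zeros((out_len, input_len))
    for i in range(out_len):
        # Row i applies the same kernel to inputs i .. i + k - 1.
        M[i, i:i + k] = kernel
    return M


def recover_kernel(M, tol=1e-8):
    """Return the shared kernel if every row of M is a shifted copy of row 0.

    Returns None when M does not have this convolutional structure.
    """
    out_len, input_len = M.shape
    k = input_len - out_len + 1  # kernel length implied by the shapes
    if k < 1:
        return None
    kernel = M[0, :k]
    expected = conv_as_dense(kernel, input_len)
    if np.allclose(M, expected, atol=tol):
        return kernel.copy()
    return None


if __name__ == "__main__":
    dense = conv_as_dense(np.array([1.0, -2.0, 0.5]), input_len=8)
    print("recovered kernel:", recover_kernel(dense))
```

When the check succeeds, a 6-by-8 matrix of weights can be summarised as 'slide this 3-weight kernel along the input', which is the kind of compression into a reusable primitive that the pseudocode is meant to expose.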

This agenda nicely relates to some existing work in machine learning. For instance, I think that there are strong synergies with research on compression of neural networks. This is partially due to background models about compression being related to understanding (see the ideas in common between Kolmogorov complexity, MDL, Solomonoff induction, and Martin-Löf randomness), and partially due to object-level details about this research. For example, sparsification seems related to increased modularity, which should make it easier to write understandable pseudocode. Another example is the efficacy of weight quantisation, which means that the least significant bits of the weights aren't very important, indicating that the relations between the high-level primitives should be modular in an understandable way and not have crucial details depend on some of the least significant bits of the output.
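
The suggested link between sparsification and modularity can be sketched on a toy layer. The example below assumes simple magnitude pruning (one of several sparsification methods) and a made-up 4-by-4 weight matrix, `W_EXAMPLE`, with two strongly connected blocks and a few tiny cross-connections. The helper names are hypothetical. After pruning, the layer splits into independent modules, found as connected components of the graph linking each input to the outputs it feeds.

```python
import numpy as np

# Made-up weights: two 2x2 blocks plus tiny cross-block connections.
W_EXAMPLE = np.array([[1.00, -0.80, 0.01,  0.00],
                      [0.50,  1.20, 0.00, -0.02],
                      [0.01,  0.00, 0.90,  0.70],
                      [0.00,  0.02, -1.10, 0.60]])


def prune_by_magnitude(W, threshold):
    """Zero out weights whose absolute value is below `threshold`."""
    return np.where(np.abs(W) >= threshold, W, 0.0)


def find_modules(W):
    """Group inputs and outputs of a layer into independent modules.

    W has shape (n_out, n_in). Input j and output i are linked when
    W[i, j] != 0. Returns a sorted list of (input_indices, output_indices)
    pairs, one per connected component of that bipartite graph.
    """
    n_out, n_in = W.shape
    # Nodes 0..n_in-1 are inputs; nodes n_in..n_in+n_out-1 are outputs.
    parent = list(range(n_in + n_out))

    def find(a):
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving
            a = parent[a]
        return a

    for i, j in zip(*np.nonzero(W)):
        parent[find(int(j))] = find(n_in + int(i))

    groups = {}
    for node in range(n_in + n_out):
        groups.setdefault(find(node), []).append(node)
    modules = []
    for nodes in groups.values():
        ins = [n for n in nodes if n < n_in]
        outs = [n - n_in for n in nodes if n >= n_in]
        modules.append((ins, outs))
    return sorted(modules)


if __name__ == "__main__":
    print("before pruning:", find_modules(W_EXAMPLE))
    print("after pruning: ", find_modules(prune_by_magnitude(W_EXAMPLE, 0.1)))
```

Each module found this way could then be described by its own piece of pseudocode, provided the pruned network still performs about as well as the original.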

The Distill post on the building blocks of interpretability includes some other examples of work that I feel is relevant. For instance, work on using matrix factorisation to group neurons seems very related to constructing high-level primitives, and work on neuron visualisation should help with understanding the high-level primitives if their output corresponds to a subset of neurons in the original network.
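
As a rough sketch of how matrix factorisation might group neurons into candidate primitives, the example below applies non-negative matrix factorisation with the classic multiplicative updates to a matrix of activations. The choice of this particular factorisation and update rule is an assumption made for illustration, the activations are synthetic, and the function names are hypothetical. Each neuron is assigned to the factor that loads on it most strongly.

```python
import numpy as np


def nmf(A, rank, n_iter=500, seed=0, eps=1e-9):
    """Approximate non-negative A (samples x neurons) as W @ H.

    Uses multiplicative updates for the squared reconstruction error.
    """
    rng = np.random.default_rng(seed)
    n_samples, n_neurons = A.shape
    W = rng.uniform(0.1, 1.0, size=(n_samples, rank))
    H = rng.uniform(0.1, 1.0, size=(rank, n_neurons))
    for _ in range(n_iter):
        H *= (W.T @ A) / (W.T @ W @ H + eps)
        W *= (A @ H.T) / (W @ H @ H.T + eps)
    return W, H


def group_neurons(H):
    """Assign each neuron to the factor with the largest rescaled loading."""
    # Rescale each factor to unit peak so that factors are comparable.
    H_scaled = H / (H.max(axis=1, keepdims=True) + 1e-12)
    return H_scaled.argmax(axis=0)


def synthetic_activations(n_samples=200, seed=1):
    """Made-up data: neurons 0-2 and neurons 3-5 fire as two groups."""
    rng = np.random.default_rng(seed)
    patterns = np.array([[1.0, 0.8, 0.6, 0.0, 0.0, 0.0],
                         [0.0, 0.0, 0.0, 1.0, 0.7, 0.9]])
    strengths = rng.uniform(0.0, 1.0, size=(n_samples, 2))
    return strengths @ patterns


if __name__ == "__main__":
    A = synthetic_activations()
    W, H = nmf(A, rank=2)
    print("relative error:", np.linalg.norm(A - W @ H) / np.linalg.norm(A))
    print("group per neuron:", group_neurons(H))
```

The intent is that neurons which tend to fire together end up in the same group, which is then a candidate for a single high-level primitive. On real activations the number of factors is a judgement call, and the groups would still need to be interpreted with something like neuron visualisation.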

I'm excited about this agenda because I see it as giving the developers of AI systems tools to detect and correct properties of their AI systems that they see as undesirable, without having to deploy the system in a test environment that they must laboriously ensure is adequately sandboxed. You could imagine developers checking if their systems conform to theories of aligned AI, or detecting any 'deceive the human' subroutine that might exist. I see this as fairly robustly useful, being helpful in most stories of how one would build an aligned AI. The exception is if AGI is built without things which look like modern machine learning algorithms, which I see as unlikely, and at any rate I would hope that the lessons transfer to whatever methods are used.

I also believe that this line of research has a shot at working for systems which act in the world. It seems hard for me to describe how I detect laptops given visual information, but given visual primitives like 'there's a laptop there', it seems much easier for me to describe how I play Tetris or even Go. As such, I would expect tools developed in this way to illuminate the strategy followed by Tetris-playing DQNs by referring to high-level primitives like 'locate T tetromino', which themselves would have to be understood using neuron visualisation techniques.

Visual primitives are probably not the only things that would be hard to fully understand using the pseudocode technique. In cases where humans evade oversight by other humans, I assert that it is often not due to consequentialist reasoning, but rather due to avoiding things which are frustrating or irritating, where frustration/irritation is hard to introspect on but seems to reliably steer away from oversight in cases where that oversight would be negative. A possible reason that this frustration/irritation is hard to introspect upon is that it is complicated and hard to decompose cleanly, like our object recognition systems are. Similarly, you could imagine that one high-level primitive that guides the AI system's behaviour is hard to decompose and needs techniques like neuron visualisation to understand. However, the mechanistic decomposition would at least allow us to locate this subsystem and determine how it is used in the network, guiding the tests we perform on it. Furthermore, in the case of humans, it's quite possible that our frustration/irritation is hard to introspect upon not because it's hard to understand, but rather because it's strategically better to not be able to introspect upon it (see the ideas in the book The Elephant in the Brain), suggesting that this problem might be less severe than it seems.