This post was written in the SERI MATS program under John Wentworth.
In our quest to find steering mechanisms in computational systems, we first have to find the right framework for looking at the internals of such a system. This framework should naturally express the systems we normally associate with computation, like the brain and computers themselves, as well as more abstract interfaces like Turing Machines. Turing Machines, arithmetic operations, and Neural Nets are in some sense equivalent, so if we apply our framework to equivalent computational objects (e.g. a Neural Net and a Turing Machine that simulates it), the representations we get should also be equivalent.
We also want to find meaningful concepts in our chosen framework, like modularity, features, search, and steering mechanisms. But it is hard to evaluate beforehand if we will be able to find these things in a given framework.
My first intuition was that Computational Graphs are what we want. They don’t impose a rigid structure like e.g. Turing Machines would, and we can naturally express many things with them. (I don’t know if they can express the brain’s computation.)
This is essentially the same as a causal DAG.
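To make this concrete, here is a minimal sketch of a computational graph in which each node applies a function to its inputs and passes the result on. The names `Node` and `evaluate` are hypothetical, chosen for illustration, and the sketch assumes the graph is given in topological order.

```python
import operator
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass(frozen=True)
class Node:
    """A graph node: a function plus the names of the nodes feeding into it."""
    func: Callable[..., float]
    inputs: Tuple[str, ...] = ()


def evaluate(graph: Dict[str, Node], values: Dict[str, float]) -> Dict[str, float]:
    """Compute every node's value, given values for the input nodes.

    Assumes `graph` is a DAG whose entries are listed in topological order.
    """
    values = dict(values)  # do not mutate the caller's dict
    for name, node in graph.items():
        args = [values[i] for i in node.inputs]
        values[name] = node.func(*args)
    return values


# (b + c) * b expressed as a graph with two function nodes.
graph = {
    "sum": Node(operator.add, ("b", "c")),
    "out": Node(operator.mul, ("sum", "b")),
}
result = evaluate(graph, {"b": 2.0, "c": 3.0})
```

Note that the graph here is the same for every choice of `b` and `c`; only the values flowing through it change.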
There is some ambiguity in which nodes we choose. For example, we can aggregate certain bit flips in a processor into a single addition node.
If we choose the algebraic level of detail (left in picture), we implicitly assume that nothing relevant is going on at the level of bits (right in picture). Intuitively this seems likely. If we were, for example, to look for modularity in a NN, it feels like a module wouldn’t ‘begin’ inside an addition process. Features should not be encoded inside it either.
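The two levels of detail can be sketched in code. The functions below are illustrative assumptions, not taken from any library: `add_algebraic` is a single '+' node, while `add_bitwise` spells out the same addition as a ripple-carry chain of bit-level gates, assuming unsigned inputs of a fixed width.

```python
def add_algebraic(x: int, y: int) -> int:
    """One '+' node: the algebraic level of detail."""
    return x + y


def add_bitwise(x: int, y: int, width: int = 8) -> int:
    """The same addition as a ripple-carry chain of XOR/AND/OR gate nodes.

    Assumes unsigned inputs that fit in `width` bits; the carry out of
    the top bit is dropped, so the result is (x + y) mod 2**width.
    """
    carry = 0
    total = 0
    for i in range(width):
        xb = (x >> i) & 1
        yb = (y >> i) & 1
        sum_bit = xb ^ yb ^ carry
        carry = (xb & yb) | (carry & (xb ^ yb))
        total |= sum_bit << i
    return total
```

Both functions agree on inputs that do not overflow, but the bit-level version exposes `width` times as many intermediate nodes, none of which we expect to matter for modularity or features.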
Computational Graphs work very well for Neural Networks because the graph does not change when we change the inputs. In general, this is not the case. I have not found a crisp way of expressing if-branching in arithmetic graphs. Let's say we want to express the following program:
def program(a, b, c):
    # The branch taken depends on the value of a.
    if a:
        return b + c
    else:
        return b - c
It is clear that we can use a Turing Machine for this. When we change the detail level to arithmetic, however, it is not clear to me how to express that the graph depends on a.
Here the result of a computation does not change which values are propagated through the graph. Rather, it changes the graph itself. To be clear, I am not looking for a way to express an if-statement in arithmetic (possibly with activation functions). I want to extend the notion of graphs to allow for naturally expressing that a computation result changes the nodes of the graph itself.
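The distinction can be illustrated with a small sketch. `program_gated` is the arithmetic encoding set aside above: the graph is fixed, both branches are computed, and `a` only selects a value. `build_graph` makes the node list itself depend on `a`. All names and the triple-based graph representation are hypothetical. This does not solve the problem, since the branching lives in ordinary Python outside the graph; it only shows what 'the graph depends on a' means.

```python
def program_gated(a: int, b: float, c: float) -> float:
    """Fixed graph: both branches are always computed; `a` (0 or 1) gates the values."""
    return a * (b + c) + (1 - a) * (b - c)


def build_graph(a: int) -> list:
    """Input-dependent graph: which node exists depends on `a`.

    Returns (node name, function, input names) triples.
    """
    if a:
        return [("out", lambda b, c: b + c, ("b", "c"))]
    return [("out", lambda b, c: b - c, ("b", "c"))]


def run(a: int, b: float, c: float) -> float:
    """Build the graph for this `a`, then propagate values through it."""
    values = {"b": b, "c": c}
    for name, func, inputs in build_graph(a):
        values[name] = func(*(values[i] for i in inputs))
    return values["out"]
```

Both versions return the same results, but only in the second does the set of nodes differ between `a = 0` and `a = 1`.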
Nonetheless, it might still be worth looking at computational graphs since we don't face those issues when looking at Neural Nets.
If we take an actual implementation of a program that produced a computational graph and stop it at a time t, then at this time, the computational state of the program corresponds to the value of all nodes for which a successor is still not computed. In this case, the result of their computation is still relevant to the process and can’t be thrown away. I call a subset of nodes that correspond to a computational state an incomplete Cut, and the full set a complete Cut.
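As a sketch of this definition, the function below computes the computational state at step `t` for a given order of evaluation. The function name and the `successors` mapping are assumptions made for illustration. As a simplification, output nodes are treated as having no successors, so they never appear in the state.

```python
from typing import Dict, List, Set


def computational_state(
    successors: Dict[str, List[str]], order: List[str], t: int
) -> Set[str]:
    """Nodes computed by step `t` that still have at least one uncomputed successor.

    `successors` maps each node to the nodes that consume its value;
    `order` is the (assumed topological) order in which nodes are computed.
    """
    done = set(order[:t])
    return {n for n in done if any(s not in done for s in successors.get(n, []))}


# Diamond-shaped graph: x feeds h1 and h2, which both feed y.
successors = {"x": ["h1", "h2"], "h1": ["y"], "h2": ["y"], "y": []}
order = ["x", "h1", "h2", "y"]
state_after_2 = computational_state(successors, order, 2)
state_after_3 = computational_state(successors, order, 3)
```

After two steps the state is `{"x", "h1"}`, because `x` is still needed for `h2`; after three steps `x` can be discarded and the state is `{"h1", "h2"}`.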
A Cut is complete if and only if every path from input to output goes through it. In the causal DAG setting it corresponds to a Markov Blanket.
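The completeness condition can be checked directly: delete the cut nodes and test whether any output is still reachable from an input. The following is a minimal sketch with hypothetical names, assuming the graph is given as a successor mapping.

```python
from typing import Dict, Iterable, List


def is_complete_cut(
    successors: Dict[str, List[str]],
    inputs: Iterable[str],
    outputs: Iterable[str],
    cut: Iterable[str],
) -> bool:
    """True iff every directed path from an input to an output passes through `cut`."""
    cut = set(cut)
    outputs = set(outputs)
    # Depth-first search from the inputs, never stepping onto a cut node.
    stack = [n for n in inputs if n not in cut]
    seen = set(stack)
    while stack:
        node = stack.pop()
        if node in outputs:
            return False  # found a path that avoids the cut
        for nxt in successors.get(node, []):
            if nxt not in cut and nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return True


# Diamond-shaped graph: x feeds h1 and h2, which both feed y.
diamond = {"x": ["h1", "h2"], "h1": ["y"], "h2": ["y"], "y": []}
```

In the diamond graph, `{"h1", "h2"}` is a complete cut, while `{"h1"}` alone is incomplete because the path through `h2` avoids it.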
Layer Activation Space is a generalization of looking at neurons: If we optimize activations for the length of the projection onto e_i then this is the same as disregarding all components except the i-th neuron and maximizing its activation. It is not intuitively clear to me that 'projection on a vector' is the right notion of 'feature activation'. There might be other notions of 'feature activation' in activation space. E.g. instead of ignoring orthogonal components, as a projection does, I could imagine penalizing them. Cuts can capture neuron activation or feature activation in activation space in other notions as well.
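The two notions can be compared in a short sketch. `projection_length` is the usual projection onto a direction; `penalized_activation` is one hypothetical way to penalize orthogonal components, with a `penalty` weight that is purely an assumption for illustration.

```python
import math
from typing import Sequence


def projection_length(activation: Sequence[float], direction: Sequence[float]) -> float:
    """Signed length of the projection of `activation` onto `direction`."""
    norm = math.sqrt(sum(d * d for d in direction))
    return sum(a * d for a, d in zip(activation, direction)) / norm


def penalized_activation(
    activation: Sequence[float], direction: Sequence[float], penalty: float = 1.0
) -> float:
    """Projection length minus `penalty` times the norm of the orthogonal remainder."""
    proj = projection_length(activation, direction)
    norm_sq = sum(a * a for a in activation)
    orthogonal = math.sqrt(max(norm_sq - proj * proj, 0.0))
    return proj - penalty * orthogonal


activation = [3.0, 4.0, 0.0]
e_1 = [0.0, 1.0, 0.0]  # basis vector for the neuron at index 1
```

Projecting onto a basis vector simply reads off that neuron's activation (here 4.0), whereas the penalized variant also lowers the score for the activity in the other neurons.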
If we use Neurons as the unit of the computational graph then the Activation Space of a Cut is the vector that corresponds to the neurons in the cut. This Cut Activation Space generalizes both neurons and Layer Activation Space. A neuron is a one-element cut, and a layer is just a cut that contains all the neurons in one Layer.
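A minimal sketch of this generalization, assuming activations are stored in a dictionary keyed by hypothetical `(layer, index)` pairs:

```python
from typing import Dict, List, Tuple

Neuron = Tuple[int, int]  # (layer, index within layer); a hypothetical naming scheme


def cut_activation(activations: Dict[Neuron, float], cut: List[Neuron]) -> List[float]:
    """Activation vector of a cut: the activations of its neurons, in the order given."""
    return [activations[n] for n in cut]


# Made-up activations for a network with two layers of two neurons each.
activations = {(1, 0): 0.5, (1, 1): -1.2, (2, 0): 0.3, (2, 1): 2.0}

single_neuron = cut_activation(activations, [(1, 1)])      # a one-element cut
layer_1 = cut_activation(activations, [(1, 0), (1, 1)])    # a whole layer
mixed = cut_activation(activations, [(1, 0), (2, 1)])      # a cut spanning two layers
```

The same function covers the neuron case, the layer case, and cuts that are neither.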
Because in a linear layer everything is connected to everything, not many cuts are possible. But some connections within neural nets are not meaningful. The simplest example is a connection that is being multiplied with a 0-weight. We might want to prune our network to reduce the number of connections that are not meaningful. Maybe we could use counterfactuals to determine how important connections are and leave them out if they are under a certain importance threshold. (similar to the idea discussed here)
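One simple way to make the counterfactual idea concrete is to zero out a single connection and measure how much the output changes. The sketch below does this for a single linear layer; the function names, the importance measure (mean absolute output change over some sample inputs), and the threshold are all assumptions for illustration, not a claim about how pruning should be done.

```python
from typing import List

Matrix = List[List[float]]


def forward(weights: Matrix, x: List[float]) -> List[float]:
    """Single linear layer: weights[j][i] connects input i to output j."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def connection_importance(weights: Matrix, samples: List[List[float]]) -> Matrix:
    """Mean absolute change in the output when one connection is set to zero."""
    importance = [[0.0] * len(row) for row in weights]
    for j, row in enumerate(weights):
        for i in range(len(row)):
            ablated = [list(r) for r in weights]
            ablated[j][i] = 0.0  # the counterfactual: this connection is removed
            for x in samples:
                base = forward(weights, x)
                changed = forward(ablated, x)
                diff = sum(abs(p - q) for p, q in zip(base, changed))
                importance[j][i] += diff / len(samples)
    return importance


def prune(weights: Matrix, samples: List[List[float]], threshold: float) -> Matrix:
    """Zero every connection whose importance falls below `threshold`."""
    imp = connection_importance(weights, samples)
    return [
        [w if imp[j][i] >= threshold else 0.0 for i, w in enumerate(row)]
        for j, row in enumerate(weights)
    ]


weights = [[1.0, 0.0], [0.01, 2.0]]
samples = [[1.0, 1.0], [2.0, -1.0]]
pruned = prune(weights, samples, threshold=0.1)
```

In this toy example the pruned layer keeps only the two diagonal connections, so the two input-output pairs no longer interact and cuts other than the full layer become possible.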
In my experience, we talk about neurons or activation space in Interpretability. More specifically, for feature visualization, we maximize activations of neurons or directions in activation space in a CNN. With this Cut idea, we should also look for directions in meaningful cuts. (At this point I don’t know when a Cut is meaningful.) In the linear layers a meaningful cut could look like this:
This is still an incomplete cut. There are computation results that don't go through nodes in the cut. To make this a complete cut, we would have to take all the nodes of the previous layer into the cut.
In practice, if we pruned the network or determined connections to be not meaningful then it could look like this:
If the idea of Cut Activation Space has merit, we should be able to find meaningful directions in cuts that don’t correspond to layers or neurons. We could test this by e.g. optimizing for different directions in such a cut in a CNN and seeing if this visualizes a feature. If we only find meaningful features in Layer Activation Space, then there should be a reason for this that we can find in the computational graph setting. There is an argument for expecting Layer Activation Space to be the natural cut: Regardless of the order of computation, the whole layer has to be in the computational state at some point. That is because we need the whole layer to calculate the activations of the next one.
Ideally, we would find an algorithm that determines meaningful cuts and meaningful directions in their activation spaces. This seems like a very hard problem.
Going back to our search for steering mechanisms, I feel like we should be able to find a cut that corresponds to the ‘steering information being sent to successive computations’.
Thanks to Justis Millis for proofreading.
Thanks to Stephen Fowler and Adhiraj Nijjer for their feedback on the draft.
A computational Graph is a Graph where nodes correspond to functions. If a node corresponds to a function f (e.g. '+') and has inputs a1,...,an then the node sends f(a1,...,an) through all outgoing edges.
Choosing the functions implicitly assumes a level of detail and determines the expressiveness of the graph. We generally want those functions to be 'low level'.
This paragraph seems to lead to a misunderstanding of what I mean.