This post is the first in a sequence that will describe James Crutchfield's Computational Mechanics framework. We feel this is one of the most theoretically sound and promising approaches towards understanding Transformers in particular and interpretability more generally. As a heads up: Crutchfield's framework will take many posts to go through fully, but even if you don't finish the sequence, we hope you will pick up many deep insights along the way.
EDIT: since there was some confusion about this in the comments: these initial posts are meant to be introductory and won't yet get into the genuinely novel aspects of Crutchfield's framework. They are also not a dunk on existing information-theoretic measures; rather, they are an ode to them!
To better understand the capabilities and limitations of large language models, it is crucial to understand the inherent structure and uncertainty ('entropy') of language data. It is natural to quantify this structure with complexity measures. We can then compare the performance of transformers to the theoretically optimal limits achieved by minimal circuits. This will be key to interpreting transformers.
The two most well-known complexity measures are the Shannon entropy and the Kolmogorov complexity. We will describe why these measures are not sufficient to understand the inherent structure of language. This will serve as a motivation for more sophisticated complexity measures that better probe the intrinsic structure of language data. We will describe these new complexity measures in subsequent posts. Later in this sequence we will discuss some directions for transformer interpretability work.
Imagine you are an agent coming across some natural system. You stick an appendage into the system, effectively measuring its states. You measure for a million timepoints and get mysterious data that looks like this:
You want to understand how this system generates this data: so that you can predict its output, so that you can use the system to your own ends, and because gaining understanding is an intrinsic joy. In reality the data was generated in the following way: output 0, then 1, then flip a fair coin, and then repeat. Is there some kind of framework or algorithm with which we can reliably come to this understanding?
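As a minimal sketch, the generating process just described (output 0, then 1, then a fair coin flip, and repeat) fits in a few lines of Python. The function name `zero_one_random` and the optional `seed` argument are our own illustrative choices.

```python
import random


def zero_one_random(n_bits: int, seed=None) -> list:
    """Generate n_bits of the 0-1-random process: 0, then 1, then a fair coin, repeat."""
    rng = random.Random(seed)
    bits = []
    for i in range(n_bits):
        phase = i % 3
        if phase == 0:
            bits.append(0)                  # deterministic 0
        elif phase == 1:
            bits.append(1)                  # deterministic 1
        else:
            bits.append(rng.randint(0, 1))  # fair coin flip
    return bits


print(zero_one_random(12, seed=0))
```

Every first and second symbol of each triple is fixed; only every third symbol carries randomness.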
As others have noted, understanding is related to abstraction, prediction, and compression. We operationalize understanding by saying an agent has an understanding of a dataset if it possesses a compressed generative model: i.e. a program that is able to generate samples that (approximately) simulate the hidden structure, both deterministic and random, in the data.
Note that pure prediction is not understanding. As a simple example take the case of predicting the outcomes of 100 fair coin tosses. Predicting tails every flip will give you maximum expected predictive accuracy (50%), but it is not the correct generative model for the data. Over the course of this sequence, we will come to formally understand why this is the case.
To start let's consider the Kolmogorov Complexity and Shannon Entropy as measures of compression, and see why they don't quite work for what we want.
Recall that the Kolmogorov(-Chaitin-Solomonoff) complexity K(x) of a bit string x is defined as the length of the shortest programme that outputs x when run from a blank input on a chosen universal Turing machine.
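Written out, the definition reads as follows, where $U$ is the fixed universal machine, $p$ ranges over programs (bit strings) and $|p|$ is the length of $p$ in bits. By the invariance theorem, switching to a different universal machine changes $K_U(x)$ by at most an additive constant that does not depend on $x$.

$$K_U(x) = \min\{\, |p| \;:\; U(p) = x \,\}$$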
One often-discussed downside of K complexity is that it is incomputable. But there is another, more conceptual downside if we want Kolmogorov complexity to measure structure in natural systems: it assigns maximal 'complexity' to random strings.
Consider again the 0-1-random sequence
The Turing Machine is forced to explicitly represent every randomly generated bit, since there is no compression available for a string generated by a fair coin. Each of those bits costs a full 1 bit ($10^6/3$ bits in total here). For the deterministic 0 and 1 bits, we need only remember where in the sequence we are: the deterministic 0 position, the deterministic 1 position, or the random position. This requires $\log_2 3 \approx 1.58$ bits.
There are two main things to note here. First, the K complexity is separable into two components, one corresponding to randomness and the other corresponding to deterministic structure. Second, the component corresponding to randomness is a much larger contributor to the K complexity, by more than five orders of magnitude (roughly $3.3 \times 10^5$ bits versus 1.58 bits)! This arises from the fact that we want the Turing Machine to recreate the string exactly, accounting for every bit.
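The back-of-the-envelope split can be checked with a few lines of Python. This is a sketch that only counts the two contributions discussed in the text (one bit per random symbol, plus $\log_2 3$ bits for the position in the cycle); it ignores the constant length of the loop itself, and the variable names are ours.

```python
import math

n_symbols = 10**6               # a million measured time points
random_part = n_symbols // 3    # one full bit per coin flip
phase_part = math.log2(3)       # which of the 3 positions the sequence starts in

ratio = random_part / phase_part
print(f"random part: {random_part} bits")
print(f"phase part:  {phase_part:.2f} bits")
print(f"ratio: {ratio:.2e} (~{math.log10(ratio):.1f} orders of magnitude)")
```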
If we want to have a compressed understanding of this string, why should we memorize every single random bit? The best thing to do would be to recognize that the random bits are random, and simply characterize them by the entropy associated with that token, instead of trying to account for every sample. In other words, the program associated with the standard way of thinking of K complexity is something like:
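```python
# use up a lot of computational resources storing every single
# random bit explicitly. In full, this list is 10^6/3 ~ 333,333 entries long!!!
random_bits = [0, 1, 1, 1, 1, 0, 0, 1,
               0, 0, 1, 1, 0, 1, 0, 0]  # ...truncated here for display
data_length = 3 * len(random_bits)

data = []
# a relatively compact for-loop of 3 lines. The last line fetches the stored
# random data
for i in range(data_length // 3):
    data.append(0)
    data.append(1)
    data.append(random_bits[i])
```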
The first statement memorizes every random bit we have to explicitly represent. The loop then writes out the deterministic 0 and 1 and appends the stored "random" bits one by one. But here's a much more compact understanding:
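```python
import random

data_length = 10**6
data = []
for i in range(data_length // 3):
    data.append(0)
    data.append(1)
    data.append(random.randint(0, 1))  # generate a fresh random bit
```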
Note that the last line of this program is not the same as the last line of the previous program. Here we loop through appending the deterministic 0-1 and then randomly generate a bit.
If we were right in our assessment that the random part was really random, then the first programme will overfit: if we do the observation again, the nondeterministic part will be different - we've just 'wasted' ~$10^6/3$ bits on overfitting!
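To make the overfitting point concrete, here is a minimal sketch, assuming the 0-1-random process: we store the random bits from one observation (as the first program does) and compare them with the random bits from a second, independent observation. The helper `random_positions` and the seeds are illustrative choices, not part of the original argument.

```python
import random


def random_positions(n_bits: int, rng: random.Random) -> list:
    """Return only the coin-flip bits (every third symbol) of one observation."""
    return [rng.randint(0, 1) for _ in range(n_bits // 3)]


n_bits = 10**6
stored = random_positions(n_bits, random.Random(1))   # memorized by the first program
new_obs = random_positions(n_bits, random.Random(2))  # a fresh observation of the system

agreement = sum(a == b for a, b in zip(stored, new_obs)) / len(stored)
print(f"stored bits match the new observation {agreement:.1%} of the time")
```

The memorized bits agree with a new observation only about half the time, i.e. no better than guessing, so all of those stored bits buy nothing for prediction.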
The issue with K complexity is that it must account for all randomly generated data exactly. But an agent trying to create a compact understanding of the world only suffers when they try to account for random bits as if they were not random. A Turing Machine has no mechanism to instantiate uncertainty in the string generating process.
Since K complexity algorithmically accounts for every bit, whether generated by random or deterministic means, it overestimates the complexity of data generated by an at least partially random process. Maybe then we can try Shannon Entropy, since that seems to be a measure of the random nature of a system. As a reminder, here is the mystery string:
Recall that the Shannon Entropy of a distribution $p(x)$ is $-\sum_x p(x) \log p(x)$, and that entropy is maximized by uniform distributions, where randomness is maximal. What is the Shannon Entropy of our data? We need to look at the distribution of 0s and 1s in our string:
There are equal numbers of 0's and 1's in our data, so the Shannon Entropy is maximized. From the perspective of Shannon Entropy, our data is indistinguishable from IID fair coin flips! This corresponds to something like the following program:
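```python
import random

data_length = 10**6
data = []
for i in range(data_length):
    data.append(random.randint(0, 1))  # every bit treated as an IID fair coin flip
```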
which is indeed very compact. But it also does not have great predictive power. There is structure in the data which we can't see from the perspective of Shannon Entropy. Whereas the perspective of K complexity led to overfitting the data by treating randomly generated bits as if they were deterministically generated, the perspective of Shannon Entropy is akin to underfitting the data by treating deterministic structure in the data as if it were generated by a random process.
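The single-symbol entropy computation can be sketched as a plug-in estimate of $p(x)$ from counts, with base-2 logarithms so the result is in bits. As a minimal illustration (the seed, sample size and helper name `shannon_entropy` are our own choices), it compares simulated 0-1-random data with IID fair coin flips; both come out at essentially 1 bit per symbol, which is the sense in which the marginal entropy cannot tell them apart.

```python
import math
import random
from collections import Counter


def shannon_entropy(symbols) -> float:
    """Entropy in bits of the empirical single-symbol distribution."""
    counts = Counter(symbols)
    total = sum(counts.values())
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


rng = random.Random(0)
n = 10**6
# 0-1-random data: positions 0 and 1 of each triple are fixed, position 2 is a coin flip
structured = [rng.randint(0, 1) if i % 3 == 2 else i % 3 for i in range(n)]
# IID fair coin flips of the same length
iid = [rng.randint(0, 1) for _ in range(n)]

print(f"0-1-random data: {shannon_entropy(structured):.4f} bits/symbol")
print(f"IID coin flips:  {shannon_entropy(iid):.4f} bits/symbol")
```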
These are the key takeaways:
As a teaser, the first step in attacking this last question will be investigating how the entropy (irreducible probabilistic uncertainty) changes at different scales.
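As a rough preview of what "different scales" can mean, one can estimate the entropy $H(L)$ of length-$L$ blocks (overlapping windows) of the 0-1-random data. This is a plug-in estimate on one simulated sample, sketched under our own choices of seed, sample size and helper name `block_entropy`; the formal treatment is left to later posts. For IID fair coin flips $H(L)$ would grow by a full bit per extra symbol, while here it grows more slowly, because two out of every three symbols are determined once the position in the cycle is known.

```python
import math
import random
from collections import Counter


def block_entropy(symbols, L: int) -> float:
    """Plug-in entropy (bits) of the overlapping length-L blocks of symbols."""
    blocks = Counter(tuple(symbols[i:i + L]) for i in range(len(symbols) - L + 1))
    total = sum(blocks.values())
    return -sum((c / total) * math.log2(c / total) for c in blocks.values())


rng = random.Random(0)
data = [rng.randint(0, 1) if i % 3 == 2 else i % 3 for i in range(300_000)]

for L in range(1, 7):
    print(L, round(block_entropy(data, L), 3))
```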
In Crutchfield's framework these are called 'Epsilon Machines'.
Understanding is not just prediction: prediction is purely correlational and does not, in general, lift to the full causal picture. Recall that for most large data sets there are many causal models compatible with the empirical joint distribution. To distinguish these causal models we need interventional data. To have 'deep' understanding we need full causal understanding.
A simple way to see that All Tails is not the correct generative model is to consider sampling it many times: the empirical distribution observed when sampling Tails×100 is very different from that of CoinFlip×100.
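A minimal simulation of this point, assuming we summarize each batch of 100 tosses by its number of heads: the All Tails model always produces 0 heads, whereas fair coin flips cluster around 50. The function names, the seed and the number of repetitions are illustrative choices.

```python
import random
from collections import Counter


def heads_all_tails(n_tosses: int, rng: random.Random) -> int:
    """The 'predict tails every time' model, used as a generator: never any heads."""
    return 0


def heads_coin_flips(n_tosses: int, rng: random.Random) -> int:
    """The true generative model: independent fair coin flips (1 = heads)."""
    return sum(rng.randint(0, 1) for _ in range(n_tosses))


rng = random.Random(0)
tails_counts = Counter(heads_all_tails(100, rng) for _ in range(1_000))
coin_counts = Counter(heads_coin_flips(100, rng) for _ in range(1_000))

print("All Tails, heads per 100 tosses:", dict(tails_counts))
print("Coin flips, most common head counts:", coin_counts.most_common(3))
```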
Sorry if this is a spoiler for your next post, but I take issue with the heading "Standard measures of information theory do not work" and the implication that this post contains the pre-Crutchfield state of the art.
The standard approach to this in information theory (which underlies the loss function of autoregressive LMs) isn't to try to match the Shannon entropy of the marginal distribution of bits (a 50-50 distribution in your post), it's to treat the generative model as a distribution for each bit conditional on the previous bits and use the cross-entropy of that distribution under the data distribution as the loss function or measure of goodness of the generative model.
So in this example, "look at the previous bits, identify the current position relative to the 01x01x pattern, and predict 0, 1, or [50-50 distribution] as appropriate" is the best you can do (given sufficient data for the 50-50 proportion to be reasonably accurate) and is indeed an accurate model of the process that generated the data.
We can see the pattern and take the current position into account because the distribution is conditional on previous bits.
Predicting 011011011... doesn't do as well because cross-entropy penalizes unwarranted overconfidence.
Predicting 50-50 for each bit doesn't do as well because cross-entropy still cares about successful predictions.
(Formally, cross-entropy is an expectation over the data distribution instead of an empirical average over a bunch of sampled data, but the term is used in both cases in practice. "Log[-likelihood] loss" and "the log scoring rule" are other common terms for the empirical version.)
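The comparison above can be sketched numerically. Below, each model assigns a probability to the next bit, and we compute the average log loss (empirical cross-entropy) in bits per symbol on simulated 0-1-random data. For simplicity each model is handed the position `i % 3` directly instead of inferring it from the previous bits, as a real autoregressive model would have to. The small `EPS` that keeps the overconfident model's loss finite is our own assumption; without it that loss would be infinite.

```python
import math
import random

EPS = 1e-9  # assumption: clip probabilities so the log loss stays finite


def log_loss_bits(p_one, data) -> float:
    """Average -log2 P(observed bit), in bits per symbol.

    p_one(i) is the model's probability that the bit at position i is 1.
    """
    total = 0.0
    for i, bit in enumerate(data):
        p = min(max(p_one(i), EPS), 1 - EPS)
        total -= math.log2(p if bit == 1 else 1 - p)
    return total / len(data)


def phase_aware(i):  # 0, then 1, then a 50-50 guess
    return (0.0, 1.0, 0.5)[i % 3]


def always_011(i):  # overconfident: predicts 0, 1, 1 with certainty
    return (0.0, 1.0, 1.0)[i % 3]


def iid_fair(i):  # ignores the structure: 50-50 for every bit
    return 0.5


rng = random.Random(0)
data = [rng.randint(0, 1) if i % 3 == 2 else i % 3 for i in range(300_000)]

for name, model in [("phase-aware", phase_aware), ("011011...", always_011), ("IID 50-50", iid_fair)]:
    print(f"{name:>12}: {log_loss_bits(model, data):.3f} bits/symbol")
```

The phase-aware model pays about 1/3 bit per symbol (one bit on each coin flip), the IID model pays 1 bit everywhere, and the overconfident model is punished heavily on every coin flip it gets wrong.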
As I said above, this isn't just a standard information theory approach to this, it's actually how GPT-3 and other LLMs were trained.
I'm curious about Crutchfield's thing, but so far not convinced that standard information theory isn't adequate in this context.
(I think Kolmogorov complexity is also relevant to LLM interpretability, philosophically if not practically, but that's beyond the scope of this comment.)
A couple of differences between Kolmogorov complexity/Shannon entropy and the loss function of autoregressive LMs (just to highlight them, not trying to say anything you don't already know):
So it seems like there's plenty of room for a measure which is "more sensible" than the former and "more principled" than the latter.
Yeah follow-up posts will definitely get into that!
To be clear: (1) the initial posts won't be about Crutchfield's work yet - just introducing some background material and an overarching philosophy; (2) the claim isn't that standard measures of information theory are bad. To the contrary! If anything, we hope these posts will be somewhat of an ode to information theory as a tool for interpretability.
Adam wanted to add a lot of academic caveats - I was adamant that we streamline the presentation to make it short and snappy for a general audience, but it appears I might have overshot! I will make an edit to clarify. Thank you!
I agree with you about the importance of Kolmogorov complexity philosophically and would love to read a follow-up post on your thoughts about Kolmogorov complexity and LLM interpretability:)
interested to see what's next.
One notable absence is the Solomonoff prior, where you weight predictions (of prefix-free TMs) by $2^{-K}$ to get a probability distribution. Related would be approximations like MML prediction.
Another nitpick would be that Shannon entropy is defined for distributions, not just raw strings of data, so you also have to fix the inference process you're using to extract probabilities from data.