(Thanks to Adam Scherlis, Kshitij Sachan, Buck Shlegeris, Chris Olah, and Nicholas Schiefer for conversations that informed this post. Thanks to Aryan Bhatt for catching an error in the loss minimization.)
Anthropic recently published a paper on toy models of superposition [Elhage+22]. One of the most striking results is that, when features are sparse, feature embeddings can become divided into disjoint subspaces with just a few vectors per subspace. This type of decomposition is known as a tegum product.
This post aims to give some intuition for why tegum products are a natural feature of the minima of certain loss functions.
Suppose we’ve got $d$ embedding dimensions and $n > d$ features. We want to embed the $n$ features into the $d$ dimensions in a way that minimizes overlap between their embedding vectors. In the simplest case this could be because we’re building an autoencoder and want to compress the features into a low-dimensional space.
One approach is to encode all of the features in all of the dimensions. With this approach there is some interference between every pair of features (i.e. no pair is embedded in a fully orthogonal way), but we have a lot of degrees of freedom that we can use to minimize this interference.
Another approach is to split the $d$ embedding dimensions into $k$ orthogonal subspaces of $d/k$ dimensions each, with $n/k$ of the $n$ features embedded in each subspace. This has the advantage of making most pairs of vectors exactly orthogonal, but at the cost that some vectors are packed more closely together. In the limit where $k = 1$ this reduces to the first approach.
Our aim is to figure out the number of subspaces $k$ that minimizes the loss on this task.
Suppose our loss has the following properties:
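- $L = \sum_{i \neq j} \ell(\cos\theta_{ij})$, where $\theta_{ij}$ is the angle between the embedding vectors of features $i$ and $j$. That is, the loss decomposes into a sum of terms involving the cosine similarities of feature vectors, and all features are equally important.
- $\ell(0) = 0$. The loss vanishes for orthogonal vectors.
- $d\ell/d|\cos\theta| > 0$. The loss is greater the more the vectors overlap.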
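Using these properties, we find that the loss is roughly
$$L \approx \frac{n^2}{k}\,\ell(\epsilon),$$
where $\epsilon$ is the typical cosine similarity between vectors in a subspace. This follows because each of the $k$ subspaces holds about $n/k$ features and so contributes about $(n/k)^2$ interfering pairs, while pairs in different subspaces are orthogonal and contribute nothing.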
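The $n^2/k$ pair count can be checked by brute force. The sketch below counts ordered pairs of distinct features that share a subspace and compares the result with $n^2/k$. The helper name `within_subspace_pairs`, the even split of features across subspaces, and the value of `n` are assumptions made for illustration.

```python
def within_subspace_pairs(n: int, k: int) -> int:
    """Count ordered pairs (i, j), i != j, of features sharing a subspace
    when n features are split as evenly as possible across k subspaces."""
    base, extra = divmod(n, k)
    # `extra` subspaces get one more feature than the rest.
    sizes = [base + 1] * extra + [base] * (k - extra)
    return sum(m * (m - 1) for m in sizes)


n = 1000  # illustrative value, not taken from the post
for k in (1, 4, 10, 50):
    exact = within_subspace_pairs(n, k)
    estimate = n * n / k
    print(f"k={k:3d}  exact={exact:7d}  n^2/k={estimate:9.1f}")
```

When $k$ divides $n$ the exact count is $n^2/k - n$, so the estimate is off by a relative factor of about $k/n$, which is small whenever each subspace holds many features.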
Edit: The formula for $\epsilon$ below is a simplification in the limit of small $\epsilon$. This simplification turns out to matter, and affects the subsequent loss minimization. I've struck through the affected sections below, explained the correct optimization in this comment below and reproduced the relevant results below the original here. None of the subsequent interpretation is affected.
The Johnson-Lindenstrauss lemma says that we can pack nearly-orthogonal vectors into dimensions, with mutual angles satisfying where and is a constant. Setting and gives Assuming we pick our vectors optimally to saturate the Johnson-Lindenstrauss bound, we can substitute this for in the loss and differentiate with respect to to find There are three possible cases: either the minimum occurs at (the greatest value it can take), or at (the smallest value it can take) or at some point in between where vanishes. The derivative vanishes if which gives where When there is no place where the derivative vanishes, and the optimum is . Otherwise there is an optimum at so long as this is less than . If it reaches , the optimum sticks to .
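The Johnson-Lindenstrauss lemma says that we can pack $m$ nearly-orthogonal vectors into $D$ dimensions, with mutual angles satisfying
$$|\cos\theta| \le \epsilon,$$
where $D \ge 4 \ln m / (\epsilon^2/2 - \epsilon^3/3)$ (per Scikit and references therein). For a single subspace we have $D = d/k$ and $m = n/k$. The cubic term matters because it makes the interference grow faster than the quadratic alone would imply (especially in the vicinity of $\epsilon \approx 1$).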
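To see how much the cubic term matters, the sketch below solves the saturated bound $\epsilon^2/2 - \epsilon^3/3 = 4\ln m / D$ for $\epsilon$ by bisection and compares it with the quadratic-only estimate $\epsilon \approx \sqrt{8 \ln m / D}$. The bound is written in the form used by scikit-learn's `johnson_lindenstrauss_min_dim`. The function name `jl_interference` and the values of `m` and `D` are assumptions chosen for illustration.

```python
import math
from typing import Optional


def jl_interference(m: int, D: float, tol: float = 1e-12) -> Optional[float]:
    """Root in [0, 1] of eps^2/2 - eps^3/3 = 4 ln(m) / D, or None if none exists."""
    target = 4.0 * math.log(m) / D

    def f(e: float) -> float:
        return e * e / 2.0 - e ** 3 / 3.0

    # f is increasing on [0, 1] and peaks at eps = 1, where f(1) = 1/6.
    if target > f(1.0):
        return None
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if f(mid) < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


m = 1000  # illustrative number of vectors
for D in (10000, 1000, 200, 100):
    eps = jl_interference(m, D)
    approx = math.sqrt(8.0 * math.log(m) / D)
    print(f"D={D:6d}  with cubic: {eps}  quadratic only: {approx:.4f}")
```

The root with the cubic term is always larger than the quadratic-only estimate, and once $4 \ln m / D$ exceeds $1/6$ there is no root with $\epsilon \le 1$ at all.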
With this correction it's not feasible to do the optimization analytically, but we can still do things numerically. Taking a power-law loss $\ell(\epsilon) = \epsilon^p$ and fixing numerical values for the remaining parameters gives:
The top panel shows the normalized loss for a few different values of $p$, and the lower shows the loss derivative with respect to $k$. Note that the range of $k$ is set by the real roots of the Johnson-Lindenstrauss relation for $\epsilon$: for larger $k$ there are no real roots, which corresponds to the interference $\epsilon$ crossing unity. In practice this bound applies well before $k = d$. Intuitively, if there are more vectors than dimensions then the interference becomes order-unity (so there is no information left!) well before the subspace dimension falls to unity.
Anyway, all of these curves have global minima in the interior of the domain (if just barely for $p = 2$), and the minima move to the left as $p$ rises. That is, for $p \le 2$ we care increasingly about higher moments as we increase $p$ and so we want fewer subspaces.
What happens for $p > 2$?
The interior minima disappear! Now the optimum is always $k = 1$. In fact, though, the transition is no longer at $p = 2$ but a little higher.
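The numerical optimization can be sketched in a few lines. The code below is not the calculation behind the original figures. It assumes a power-law per-pair loss $\epsilon^p$, a total loss of $(n^2/k)\,\epsilon^p$, the saturated bound $\epsilon^2/2 - \epsilon^3/3 = 4 (k/d) \ln(n/k)$, and illustrative values $n = 10^5$ and $d = 10^4$. All function and constant names are hypothetical.

```python
import math
from typing import Dict, Optional

N_FEATURES = 10 ** 5  # n: assumed value for illustration
N_DIMS = 10 ** 4      # d: assumed value for illustration


def interference(k: int, n: int = N_FEATURES, d: int = N_DIMS) -> Optional[float]:
    """Solve eps^2/2 - eps^3/3 = 4 (k/d) ln(n/k) for eps in [0, 1] by bisection."""
    target = 4.0 * (k / d) * math.log(n / k)
    if target > 1.0 / 6.0:  # left side peaks at 1/6 when eps = 1
        return None
    lo, hi = 0.0, 1.0
    for _ in range(80):
        mid = 0.5 * (lo + hi)
        if mid ** 2 / 2.0 - mid ** 3 / 3.0 < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def loss_curve(p: float) -> Dict[int, float]:
    """Loss (n^2 / k) * eps(k)^p for every k with a valid interference."""
    curve = {}
    for k in range(1, N_DIMS + 1):
        eps = interference(k)
        if eps is None:
            # k ln(n/k) increases with k for k < n/e, so all larger k fail too.
            break
        curve[k] = (N_FEATURES ** 2 / k) * eps ** p
    return curve


def best_k(p: float) -> int:
    """Number of subspaces that minimizes the loss for exponent p."""
    curve = loss_curve(p)
    return min(curve, key=curve.get)


for p in (0.5, 1.0, 2.0, 2.5, 3.0):
    print(f"p={p}: best k = {best_k(p)} (largest valid k = {max(loss_curve(p))})")
```

With these assumed values the optimum sits in the interior of the allowed range for $p = 1$, moves toward small $k$ by $p = 2$, and lands on $k = 1$ for $p = 3$.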
We can think of $p$ as the sensitivity of the loss to interference. Specifically, which moment of the interference distribution do we care about?
When $p$ is large, we care more about decreasing higher moments, and in the limit of infinite $p$ what matters is just the maximum interference between vectors. Hence when $p$ is large we want to have fewer subspaces, each with more vectors but smaller cosine similarities.
By contrast, when $p$ is small, we care more about decreasing smaller moments, and in the limit as $p \to 0$ what matters is the fraction of vectors that interfere at all. Hence when $p$ is small we want to have more subspaces, each with fewer vectors but larger cosine similarities.
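A toy comparison makes the role of the exponent concrete. The sketch below assumes a per-pair loss of $\epsilon^p$ and compares two made-up interference patterns: many pairs with small interference (as in one big subspace) and few pairs with large interference (as in a tegum product). The pair counts and interference values are invented for illustration and do not come from any fit.

```python
def total_loss(n_pairs: int, eps: float, p: float) -> float:
    """Summed loss for n_pairs interfering pairs, each with cosine similarity eps."""
    return n_pairs * eps ** p


# (number of interfering pairs, typical interference): illustrative numbers only.
one_big_subspace = (1000, 0.1)  # many pairs, mild interference
tegum_product = (100, 0.5)      # few pairs, strong interference

for p in (0.5, 1.0, 2.0, 4.0):
    big = total_loss(*one_big_subspace, p)
    tegum = total_loss(*tegum_product, p)
    winner = "tegum product" if tegum < big else "one big subspace"
    print(f"p={p}: one big subspace={big:.4g}  tegum product={tegum:.4g}  -> {winner}")
```

For these numbers the preference flips where $1000 \cdot 0.1^p = 100 \cdot 0.5^p$, that is at $p = \ln 10 / \ln 5 \approx 1.43$: below it the smaller number of interfering pairs wins, above it the smaller peak interference wins.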
So tegum products are preferred when we can tolerate larger “peak” interference and want fewer instances of interference, whereas a single large subspace is preferred when we can tolerate lots of instances of interference and want to minimize the worst cases.
Relation to Anthropic’s Results
In Anthropic’s Toy Model 2, the dimension of the subspaces increases the sparser the features get, meaning that $k$ falls. We can make sense of this by expanding the loss as they do in powers of the sparsity $S$:
$$L = (1-S)^n L_n + \dots + (1-S)S^{n-1} L_1 + S^n L_0,$$
where $L_j$ is the loss associated with $j$-sparse vectors. In the sparse limit $S \to 1$, so
$$L \approx S^n L_0 + (1-S)S^{n-1} L_1.$$
The term $L_0$ is a penalty on positive biases and the term $L_1$ is the loss on 1-sparse vectors. In this limit, the biases are negative (to minimize $L_0$), and this has the effect of wiping out the contribution of small interference terms in $L_1$. So the model is pushed to decrease the worst case interference (which might overcome the bias) rather than minimize the average, corresponding to our large-$p$ limit.
On the other hand, in the dense limit $S \to 0$, so
$$L \approx L_n.$$
The term $L_n$ is the loss on dense vectors, which means there is interference between all pairs of vectors. This makes minimizing the average interference more important than minimizing the really bad cases (i.e. higher moments), so the model prefers lots of small subspaces, corresponding to our small-$p$ limit.
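The two limits can be checked by asking how many features are active in a typical input. The sketch below assumes each of $n$ features is independently zero with probability $S$ (the sparsity), so the number of active features is binomially distributed. That independence assumption, the function name `active_count_weights`, and the chosen values of `n` and `S` are illustrative.

```python
import math
from typing import List


def active_count_weights(n: int, S: float) -> List[float]:
    """P(exactly j of n features are active) for j = 0..n, with each feature
    independently zero with probability S."""
    return [math.comb(n, j) * (1 - S) ** j * S ** (n - j) for j in range(n + 1)]


n = 100  # illustrative number of features
for S in (0.999, 0.001):
    w = active_count_weights(n, S)
    print(f"S={S}: P(0 active)={w[0]:.4f}  P(1 active)={w[1]:.4f}  "
          f"P(all active)={w[n]:.4f}")
```

For very sparse features nearly all of the weight falls on inputs with zero or one active feature, and for dense features it falls on inputs with every feature active, matching the terms that dominate the loss in each limit.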
Just as the relevant limit varies with sparsity for a given toy model, we can also get different phenomenology for different models. This makes sense given that different setups can have different loss functions.
There are many ways for models to embed features. A surprising observation from Elhage+22 is that sometimes the optimal choice is one which divides the embedding space into many orthogonal subspaces (i.e. a tegum product). We can understand this roughly as coming from a tradeoff between minimizing higher moments of the feature interference (e.g. worst-case) and minimizing lower moments (e.g. average-case interference).
Smaller subspaces minimize the lower moments by making most pairs of vectors exactly orthogonal. The cost of this is that there is less freedom to choose vector pairs in each subspace, so there is worse interference between the pairs that do interfere.
Larger subspaces have the reverse tradeoff: they experience interference between more pairs of vectors, but it tends to be milder because larger-dimensional spaces support packing more nearly-orthogonal vectors, even at a fixed ratio of vectors-to-dimension.