Anthropic's recent mechanistic interpretability paper, Toy Models of Superposition, helps to demonstrate the conceptual richness of very small feedforward neural networks. Even when such networks are trained on synthetic, hand-coded data to reconstruct a very straightforward function (the identity map), there appears to be non-trivial mathematics at play, and the analysis of these small networks seems to provide an interesting playground for mechanistic interpretability. 

While trying to understand their work and train my own toy models, I ended up making various notes on the underlying mathematics. This post is a slightly neatened-up version of those notes, but it is still quite rough and unedited, and is a far-from-optimal presentation of the material. In particular, these notes may contain errors, which are my responsibility.

1. Directly Analyzing the Critical Points of a Linear Toy Model

Throughout, we will be considering feedforward neural networks with one hidden layer. The input and output layers are of the same size and the hidden layer is smaller. We will only be considering the autoencoding problem, which means that our networks are being trained to reconstruct the data. The first couple of subsections here are largely taken from the Appendix of the paper "Neural networks and principal component analysis: Learning from examples without local minima" by Pierre Baldi and Kurt Hornik (Neural Networks 2.1 (1989): 53–58).

Consider, to begin with, a completely linear model, i.e. one without any activation functions or biases. Suppose the input and output layers have $n$ neurons and that the middle layer has $m < n$ neurons. This means that the function that the model is implementing is of the form $x \mapsto ABx$, where $B$ is an $m \times n$ matrix and $A$ is an $n \times m$ matrix. That is, the matrix $B$ contains the weights of the connections between the input layer and the hidden layer, and the matrix $A$ contains the weights of the connections between the hidden layer and the output layer. It is important to realise that even though - for a given set of weights - the function that is being implemented here is linear, the mathematics of this model and the dynamics of the training are not completely linear.
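To fix conventions, here is a minimal numpy sketch of this setup (the dimensions $n = 5$, $m = 2$, the random weights, and the variable names are illustrative assumptions of mine, not anything taken from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

n, m = 5, 2                    # input/output dimension n, hidden dimension m < n
B = rng.normal(size=(m, n))    # weights from the input layer to the hidden layer
A = rng.normal(size=(n, m))    # weights from the hidden layer to the output layer

def model(x):
    """The completely linear toy model x -> A B x (no activations, no biases)."""
    return A @ (B @ x)

x = rng.normal(size=n)
print(model(x))                # an n-dimensional attempted reconstruction of x
```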

The error on a given input $x \in \mathbb{R}^n$ will be measured by $\|x - ABx\|^2$ and, on the data set $\mathcal{X} = \{x^{(1)}, \dots, x^{(T)}\}$, the total loss is
$$L(A, B) = \sum_{x \in \mathcal{X}} \|x - ABx\|^2. \tag{1}$$

Define $\Sigma$ to be the $n \times n$ matrix whose $(i,j)$ entry $\Sigma_{ij}$ is given by
$$\Sigma_{ij} = \sum_{x \in \mathcal{X}} x_i x_j.$$

Clearly this matrix is symmetric.  

Assumption. We will assume that the data is such that a) $\Sigma$ is invertible and b) $\Sigma$ has distinct eigenvalues. 

Let $\lambda_1 > \lambda_2 > \cdots > \lambda_n$ be the eigenvalues of $\Sigma$.
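Concretely, if the data points are stacked as the rows of a matrix $X$, then $\Sigma = X^\top X$, and the assumptions can be checked numerically. A small sketch (the synthetic, anisotropic Gaussian data here is just an illustrative assumption of mine):

```python
import numpy as np

rng = np.random.default_rng(0)
n, T = 5, 1000
# illustrative synthetic data: T points x^(1), ..., x^(T) as rows, different scales per coordinate
X = rng.normal(size=(T, n)) * np.array([3.0, 2.0, 1.5, 1.0, 0.5])

Sigma = X.T @ X                             # Sigma_ij = sum over the data of x_i x_j
assert np.allclose(Sigma, Sigma.T)          # symmetric by construction

eigvals = np.linalg.eigvalsh(Sigma)[::-1]   # lambda_1 > lambda_2 > ... > lambda_n
print(eigvals)
assert np.all(eigvals > 0)                  # Sigma is invertible (positive definite)
assert np.all(np.diff(eigvals) < 0)         # eigenvalues are distinct (generically true)
```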

 

1.1 The Global Minimum


Proposition 1. (Characterization of Critical Points) Fix the dataset and consider $L$ to be a function of the two matrix variables $A$ and $B$. For any critical point $(A, B)$ of $L$, there is a subset $\mathcal{I} \subseteq \{1, \dots, n\}$ of size $m$ for which 

  1. $AB$ is an orthogonal projection onto an $m$-dimensional subspace spanned by orthonormal eigenvectors of $\Sigma$ corresponding to the eigenvalues $\{\lambda_i : i \in \mathcal{I}\}$; and
  2. $L(A, B) = \sum_{i \notin \mathcal{I}} \lambda_i$.

Corollary 2. (Characterization of the Minimum) The loss has a unique minimum value that is attained when $\mathcal{I} = \{1, \dots, m\}$, which corresponds to the situation when $AB$ is the orthogonal projection onto the $m$-dimensional subspace spanned by the eigendirections of $\Sigma$ that have the $m$ largest eigenvalues. 

Remarks. We won't try to spell out all of the various connections to other closely related things, but for those who want some more keywords to go away and investigate further, we just remark that the minimization problem being studied here is about finding a low-rank approximation to the identity and is closely related to Principal Component Analysis. See also the Eckart–Young–Mirsky Theorem.
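Corollary 2 is also easy to check numerically before working through the proof: among the projections onto $m$ of the eigendirections of $\Sigma$, the top-$m$ choice minimizes the loss, and each such projection costs exactly the sum of the discarded eigenvalues. A sketch, under the same illustrative synthetic-data assumptions as above:

```python
import numpy as np
from itertools import combinations

rng = np.random.default_rng(0)
n, m, T = 5, 2, 1000
X = rng.normal(size=(T, n)) * np.array([3.0, 2.0, 1.5, 1.0, 0.5])  # illustrative data
Sigma = X.T @ X

eigvals, U = np.linalg.eigh(Sigma)
order = np.argsort(eigvals)[::-1]            # sort the eigenpairs in decreasing order
eigvals, U = eigvals[order], U[:, order]

def total_loss(M):
    """Sum over the data of ||x - M x||^2 for a linear map M."""
    R = X - X @ M.T
    return np.sum(R ** 2)

# projecting onto the eigendirections indexed by I costs the sum of the other eigenvalues
for I in combinations(range(n), m):
    U_I = U[:, list(I)]
    P = U_I @ U_I.T
    rest = [i for i in range(n) if i not in I]
    assert np.isclose(total_loss(P), eigvals[rest].sum())

# the minimum is attained by projecting onto the top-m eigendirections
print(total_loss(U[:, :m] @ U[:, :m].T), eigvals[m:].sum())   # these two numbers agree
```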

We begin by directly differentiating $L$ with respect to the entries of $A$ and $B$. Using the summation convention on repeated indices, we first take the derivative with respect to $B_{ab}$:

$$\frac{\partial L}{\partial B_{ab}} = \sum_{x \in \mathcal{X}} \frac{\partial}{\partial B_{ab}}\big(x_i - A_{ik}B_{kl}x_l\big)\big(x_i - A_{ij}B_{jm}x_m\big) = -2\sum_{x \in \mathcal{X}} A_{ia}\big(x_i - A_{ik}B_{kl}x_l\big)x_b = -2\big[A^\top\Sigma - A^\top AB\,\Sigma\big]_{ab}.$$

Setting this equal to zero and interpreting this equation for all $a$ and $b$ gives us that 
$$A^\top AB\,\Sigma = A^\top\Sigma. \tag{2}$$

Then, separately, we differentiate $L$ with respect to $A_{ab}$: 
$$\frac{\partial L}{\partial A_{ab}} = -2\sum_{x \in \mathcal{X}} \big(x_a - A_{ak}B_{kl}x_l\big)B_{bm}x_m = -2\big[\Sigma B^\top - AB\,\Sigma B^\top\big]_{ab}.$$

Setting this equation equal to zero for every $a$ and $b$, we have that: 
$$\Sigma B^\top - AB\,\Sigma B^\top = 0.$$

Thus 
$$AB\,\Sigma B^\top = \Sigma B^\top. \tag{3}$$
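As a sanity check on these two computations, the closed-form gradients implied by the working above, $\partial L/\partial B = 2(A^\top AB\,\Sigma - A^\top\Sigma)$ and $\partial L/\partial A = 2(AB\,\Sigma B^\top - \Sigma B^\top)$, can be compared against central finite differences at a random point. A sketch (the data and the particular entries checked are arbitrary choices of mine):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, T = 4, 2, 200
X = rng.normal(size=(T, n))      # illustrative synthetic data
Sigma = X.T @ X
A = rng.normal(size=(n, m))
B = rng.normal(size=(m, n))

def total_loss(A, B):
    R = X - X @ (A @ B).T        # rows are x - A B x
    return np.sum(R ** 2)

grad_A = 2 * (A @ B @ Sigma @ B.T - Sigma @ B.T)   # gradient wrt A; vanishing gives (3)
grad_B = 2 * (A.T @ A @ B @ Sigma - A.T @ Sigma)   # gradient wrt B; vanishing gives (2)

# finite-difference check of a single entry of each gradient
eps = 1e-6
dA = np.zeros_like(A); dA[1, 0] = eps
dB = np.zeros_like(B); dB[0, 2] = eps
print((total_loss(A + dA, B) - total_loss(A - dA, B)) / (2 * eps), grad_A[1, 0])
print((total_loss(A, B + dB) - total_loss(A, B - dB)) / (2 * eps), grad_B[0, 2])
```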

Since we have assumed that $\Sigma$ is invertible, the first equation immediately implies that $A^\top AB = A^\top$. If we assume in addition that $A$ has full rank (a reasonable assumption in any case of practical interest), then $A^\top A$ is invertible and we have that 
$$B = (A^\top A)^{-1}A^\top, \tag{4}$$

which in turn implies that 
$$AB = A(A^\top A)^{-1}A^\top = P_A, \tag{5}$$

where we have written $P_A$ to denote the orthogonal projection onto the column space of $A$.  
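As a quick numerical illustration of this step: for any full-rank $A$, setting $B = (A^\top A)^{-1}A^\top$ does indeed make $AB$ symmetric and idempotent and make it act as the identity on the columns of $A$, i.e. $AB = P_A$. A sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 5, 2
A = rng.normal(size=(n, m))          # a random n x m matrix, full rank almost surely

B = np.linalg.solve(A.T @ A, A.T)    # B = (A^T A)^{-1} A^T
P_A = A @ B                          # should be the orthogonal projection onto col(A)

assert np.allclose(P_A, P_A.T)       # symmetric
assert np.allclose(P_A @ P_A, P_A)   # idempotent
assert np.allclose(P_A @ A, A)       # acts as the identity on the column space of A
print(np.trace(P_A))                 # trace = rank = m (here 2, up to rounding)
```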

Claim. We next claim that $P_A$ commutes with $\Sigma$.

Proof of claim. Plugging (5) into (3), we have:
$$\Sigma B^\top = P_A\,\Sigma B^\top. \tag{6}$$

Then, right-multiply by $A^\top$ and use the fact that $B^\top A^\top = (AB)^\top = P_A$ to get:
$$\Sigma P_A = P_A\,\Sigma P_A. \tag{7}$$
The right-hand side is manifestly a symmetric matrix, so we deduce that $\Sigma P_A$ is symmetric. If the product of two symmetric matrices is symmetric then they commute, so (since $\Sigma$ and $P_A$ are each symmetric) this indeed shows that $P_A$ commutes with $\Sigma$ and completes the proof of the claim.

 

Now let $U$ be the orthogonal matrix which diagonalizes $\Sigma$, i.e. the matrix for which 
$$U^\top\Sigma U = \Lambda, \tag{8}$$

where $\Lambda$ is the diagonal matrix with entries $\Lambda_{ii} = \lambda_i$.

 

Claim. We next claim that $\tilde{P}_A := U^\top P_A U$ is an orthogonal projection and that $\tilde{P}_A$ is diagonal.  

Proof of Claim. Firstly, using the standard formula for orthogonal projections, we have 
$$P_A^\top = \big(A(A^\top A)^{-1}A^\top\big)^\top = P_A \qquad \text{and} \qquad P_A^2 = P_A,$$

which implies that
$$\tilde{P}_A^\top = \tilde{P}_A, \qquad \tilde{P}_A^2 = U^\top P_A U\,U^\top P_A U = \tilde{P}_A, \qquad \text{and} \qquad P_A = U\tilde{P}_A U^\top, \tag{9}$$

i.e. $\tilde{P}_A$ is indeed an orthogonal projection.
To show that $\tilde{P}_A$ is diagonal, we show that it commutes with the diagonal matrix $\Lambda$ (a matrix that commutes with a diagonal matrix whose diagonal entries are distinct must itself be diagonal, and the entries of $\Lambda$ are distinct by assumption). Starting from $\tilde{P}_A\Lambda$, we first insert the identity matrix in the form $I = U^\top U$, and then use (8) and (9) thus: 
$$\tilde{P}_A\Lambda = U^\top U\,\tilde{P}_A\,U^\top U\,\Lambda\,U^\top U = U^\top\big(U\tilde{P}_A U^\top\big)\big(U\Lambda U^\top\big)U = U^\top P_A\,\Sigma\,U.$$
Then recall that we have already established that $P_A$ commutes with $\Sigma$. So we can swap them and then perform the same trick in reverse: 
$$U^\top P_A\,\Sigma\,U = U^\top\Sigma\,P_A\,U = U^\top\big(U\Lambda U^\top\big)\big(U\tilde{P}_A U^\top\big)U = \Lambda\tilde{P}_A.$$
This shows that $\tilde{P}_A$ commutes with $\Lambda$ and completes the proof of the claim.

 

So, given that $\tilde{P}_A$ is an orthogonal projection of rank $m$ (conjugating by $U$ preserves rank, and $P_A$ has rank $m$ because $A$ has full rank) and is diagonal, there exists a set of indices $\mathcal{I} \subseteq \{1, \dots, n\}$ with $|\mathcal{I}| = m$ such that the $i$-th diagonal entry of $\tilde{P}_A$ is zero if $i \notin \mathcal{I}$ and 1 if $i \in \mathcal{I}$. And since $P_A = U\tilde{P}_A U^\top$, we see that 
$$P_A = U\tilde{P}_A U^\top = U_{\mathcal{I}}\,U_{\mathcal{I}}^\top,$$

where $U_{\mathcal{I}}$ is formed from $U$ by simply setting to zero the $i$-th column if $i \notin \mathcal{I}$. This is manifestly an orthogonal projection onto the span of $\{u_i : i \in \mathcal{I}\}$, where $u_1, \dots, u_n$ is an orthonormal basis of eigenvectors of $\Sigma$ (and indeed these are the columns of $U$).  Combining these observations with (5), we have that
$$AB = P_A = U_{\mathcal{I}}\,U_{\mathcal{I}}^\top.$$
This proves the first claim of the proposition.
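To see this first part of the proposition concretely, one can build a critical point for every index set $\mathcal{I}$ by taking the columns of $A$ to be the corresponding orthonormal eigenvectors and $B = A^\top$ (which here equals $(A^\top A)^{-1}A^\top$), and then check that equations (2) and (3) hold and that $AB$ projects onto the chosen eigendirections. A numerical sketch, again under my illustrative synthetic-data assumptions:

```python
import numpy as np
from itertools import combinations

rng = np.random.default_rng(0)
n, m, T = 5, 2, 500
X = rng.normal(size=(T, n)) * np.array([3.0, 2.0, 1.5, 1.0, 0.5])  # illustrative data
Sigma = X.T @ X

eigvals, U = np.linalg.eigh(Sigma)
order = np.argsort(eigvals)[::-1]
eigvals, U = eigvals[order], U[:, order]

for I in combinations(range(n), m):
    A = U[:, list(I)]     # columns: the orthonormal eigenvectors u_i with i in I
    B = A.T               # equals (A^T A)^{-1} A^T because A^T A = I_m here
    # both critical-point equations hold ...
    assert np.allclose(A.T @ A @ B @ Sigma, A.T @ Sigma, atol=1e-6)   # equation (2)
    assert np.allclose(A @ B @ Sigma @ B.T, Sigma @ B.T, atol=1e-6)   # equation (3)
    # ... and AB is the orthogonal projection onto span{u_i : i in I}
    P = A @ B
    for i in range(n):
        target = U[:, i] if i in I else np.zeros(n)
        assert np.allclose(P @ U[:, i], target, atol=1e-6)
```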

 

To prove the second part, write $x = \sum_{i=1}^{n} (u_i^\top x)\,u_i$ for each data point $x$ and compute thus: