Anthropic's recent mechanistic interpretability paper, Toy Models of Superposition, helps to demonstrate the conceptual richness of very small feedforward neural networks. Even when they are trained on synthetic, hand-coded data to reconstruct a very straightforward function (the identity map), there is non-trivial mathematics at play, and the analysis of these small networks seems to provide an interesting playground for mechanistic interpretability.
While trying to understand their work and train my own toy models, I ended up making various notes on the underlying mathematics. This post is a slightly neatened-up version of those notes, but it is still quite rough and unedited, and it is a far-from-optimal presentation of the material. In particular, these notes may contain errors, which are my responsibility.
1. Directly Analyzing the Critical Points of a Linear Toy Model
Throughout, we will be considering feedforward neural networks with one hidden layer. The input and output layers will be of the same size and the hidden layer is smaller. We will only be considering the autoencoding problem, which means that our networks are being trained to reconstruct the data. The first couple of subsections here are largely taken from the appendix of the paper "Neural networks and principal component analysis: Learning from examples without local minima" by Pierre Baldi and Kurt Hornik (Neural Networks 2.1 (1989): 53-58).
To begin with, consider a completely linear model, i.e. one without any activation functions or biases. Suppose the input and output layers have $D$ neurons and that the middle layer has $d < D$ neurons. This means that the function the model implements is of the form $x \mapsto ABx$, where $x \in \mathbb{R}^D$, $B$ is a $d \times D$ matrix, and $A$ is a $D \times d$ matrix. That is, the matrix $B$ contains the weights of the connections between the input layer and the hidden layer, and the matrix $A$ contains the weights of the connections between the hidden layer and the output layer. It is important to realise that even though - for a given set of weights - the function being implemented here is linear, the mathematics of this model and the dynamics of the training are not completely linear.
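To fix ideas, here is a minimal numpy sketch of such a model. The sizes $D$ and $d$, the random weights, and the variable names are purely illustrative and are not taken from the paper.

```python
import numpy as np

# Minimal sketch of the linear toy model x -> A B x:
# D-dimensional input/output, d-dimensional hidden layer, no activations or biases.
D, d = 5, 2                          # illustrative sizes with d < D
rng = np.random.default_rng(0)

A = rng.normal(size=(D, d))          # weights from the hidden layer to the output layer
B = rng.normal(size=(d, D))          # weights from the input layer to the hidden layer

def model(x):
    """Reconstruction produced by the linear autoencoder."""
    return A @ (B @ x)

x = rng.normal(size=D)
print(model(x))                      # lies in the (at most d-dimensional) column space of A
```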
The error on a given input $x$ will be measured by $\|x - ABx\|^2$ and on the data set $\{x_t\}_{t=1}^T$, the total loss is

$$L = L(A, B, \{x_t\}_{t=1}^T) := \sum_{t=1}^T \|x_t - ABx_t\|^2 = \sum_{t=1}^T \sum_{i=1}^D \Big( x_{ti} - \sum_{j=1}^d \sum_{k=1}^D a_{ij} b_{jk} x_{tk} \Big)^2.$$

Define $\Sigma$ to be the matrix whose $(i,j)$th entry $\sigma_{ij}$ is given by

$$\sigma_{ij} = \sum_{t=1}^T x_{ti} x_{tj}.$$
Clearly this matrix is symmetric.
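Concretely, if the dataset is stored as a $T \times D$ array with one example per row, then $\Sigma$ is just $X^T X$. A small illustrative sketch (the data here is made up):

```python
import numpy as np

# Sigma_ij = sum_t x_ti x_tj: with the dataset as a T x D array X, this is X^T X.
rng = np.random.default_rng(0)
T, D = 100, 5
X = rng.normal(size=(T, D))                  # made-up data, one example per row

Sigma = X.T @ X                              # D x D
assert np.allclose(Sigma, Sigma.T)           # symmetric by construction
print(np.linalg.eigvalsh(Sigma))             # real eigenvalues; generically positive and distinct
```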
Assumption. We will assume that the data is such that (a) $\Sigma$ is invertible and (b) $\Sigma$ has distinct eigenvalues.
Let $\lambda_1 > \cdots > \lambda_D$ be the eigenvalues of $\Sigma$.
1.1 The Global Minimum
Proposition 1 (Characterization of Critical Points). Fix the dataset and consider $L$ to be a function of the two matrix variables $A$ and $B$. For any critical point $(A,B)$ of $L$, there is a subset $I \subset \{1, \dots, D\}$ of size $d$ for which
$AB$ is an orthogonal projection onto a $d$-dimensional subspace spanned by orthonormal eigenvectors of $\Sigma$ corresponding to the eigenvalues $\{\lambda_i\}_{i \in I}$; and
$$L(A,B) = \operatorname{tr}\Sigma - \sum_{i \in I} \lambda_i = \sum_{i \notin I} \lambda_i.$$
Corollary 2 (Characterization of the Minimum). The loss has a unique minimum value that is attained when $I = \{1, \dots, d\}$, which corresponds to the situation when $AB$ is an orthogonal projection onto the $d$-dimensional subspace spanned by the eigendirections of $\Sigma$ that have the largest eigenvalues.
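As an illustrative numerical check of the corollary (the sizes, seed, and data below are arbitrary), projecting onto the top-$d$ eigendirections of $\Sigma$ does indeed achieve loss $\operatorname{tr}\Sigma - \sum_{i \le d}\lambda_i$:

```python
import numpy as np

# Check: the projection onto the top-d eigendirections of Sigma attains
# loss = tr(Sigma) - (sum of the d largest eigenvalues) = sum of the rest.
rng = np.random.default_rng(1)
T, D, d = 200, 6, 2
X = rng.normal(size=(T, D))
Sigma = X.T @ X

lam, U = np.linalg.eigh(Sigma)               # ascending order
order = np.argsort(lam)[::-1]
lam, U = lam[order], U[:, order]             # now lambda_1 >= ... >= lambda_D

P = U[:, :d] @ U[:, :d].T                    # orthogonal projection onto the top-d eigenspace
loss = np.sum((X - X @ P) ** 2)              # sum_t ||x_t - P x_t||^2  (P is symmetric)
print(loss, lam[d:].sum())                   # these agree up to floating-point error
```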
Remarks. We won't try to spell out all of the connections to other closely related topics, but for those who want some more keywords to go away and investigate further, we remark that the minimization problem being studied here is about finding a low-rank approximation to the identity and is closely related to Principal Component Analysis. See also the Eckart–Young–Mirsky theorem.
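To make the PCA connection concrete, here is a small sketch (again with made-up data): the top-$d$ right singular vectors of the data matrix are the top-$d$ eigenvectors of $\Sigma = X^TX$, and projecting onto them gives the same optimal reconstruction loss.

```python
import numpy as np

# The PCA / Eckart-Young view: the best rank-d reconstruction of the rows of X
# comes from its top-d right singular vectors, which are the top-d eigenvectors of Sigma.
rng = np.random.default_rng(6)
T, D, d = 200, 5, 2
X = rng.normal(size=(T, D))
Sigma = X.T @ X

lam, U = np.linalg.eigh(Sigma)
order = np.argsort(lam)[::-1]
lam, U = lam[order], U[:, order]

_, s, Vt = np.linalg.svd(X, full_matrices=False)   # singular values in descending order
P_svd = Vt[:d].T @ Vt[:d]                          # projection onto the top-d right singular vectors
P_eig = U[:, :d] @ U[:, :d].T                      # projection onto the top-d eigenvectors of Sigma

print(np.allclose(P_svd, P_eig))                   # True (generically)
print(np.sum((X - X @ P_eig) ** 2), np.sum(s[d:] ** 2))   # equal: the discarded sigma_i^2 = lambda_i
```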
We begin by directly differentiating $L$ with respect to the entries of $A$ and $B$. Using the summation convention on repeated indices, we first take the derivative with respect to $b_{j'k'}$:

$$\frac{\partial L}{\partial b_{j'k'}} = \sum_{t=1}^T \sum_{i=1}^D -2\big(x_{ti} - a_{ij}b_{jk}x_{tk}\big)\,a_{il}\delta_{lj'}\delta_{qk'}x_{tq} = -2\sum_{t=1}^T \big(x_{ti}a_{ij'}x_{tk'} - a_{ij}b_{jk}x_{tk}\,a_{ij'}x_{tk'}\big).$$

Setting this equal to zero and interpreting this equation for all $j' = 1, \dots, d$ and $k' = 1, \dots, D$ gives us that

$$A^T\Sigma = A^TAB\Sigma. \tag{1}$$

Then, separately, we differentiate $L$ with respect to $a_{i'j'}$:

$$\frac{\partial L}{\partial a_{i'j'}} = \sum_{t=1}^T \sum_{i=1}^D -2\big(x_{ti} - a_{ij}b_{jk}x_{tk}\big)\,\delta_{ii'}\delta_{pj'}b_{pq}x_{tq} = -2\sum_{t=1}^T \big(x_{ti'}b_{j'q}x_{tq} - a_{i'j}b_{jk}x_{tk}\,b_{j'q}x_{tq}\big).$$

Setting this equation equal to zero for every $i' = 1, \dots, D$ and $j' = 1, \dots, d$ we have that:

$$\Sigma B^T = AB\Sigma B^T. \tag{2}$$

Thus

$$\nabla L(A,B) = 0 \iff \begin{cases} A^T\Sigma = A^TAB\Sigma \\ \Sigma B^T = AB\Sigma B^T. \end{cases} \tag{3}$$
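As an aside, the two component-wise derivatives above can be packaged into matrix form as $\nabla_B L = -2A^T(I - AB)\Sigma$ and $\nabla_A L = -2(I - AB)\Sigma B^T$, so that (1) and (2) are exactly the stationarity conditions. Here is a rough numpy sketch checking the $A$-gradient against finite differences; the sizes, data, and step size are made up for illustration.

```python
import numpy as np

# Check the matrix form of the gradients of L(A, B) = sum_t ||x_t - A B x_t||^2:
#   dL/dA = -2 (I - AB) Sigma B^T,   dL/dB = -2 A^T (I - AB) Sigma.
rng = np.random.default_rng(2)
T, D, d = 50, 4, 2
X = rng.normal(size=(T, D))
Sigma = X.T @ X
A = rng.normal(size=(D, d))
B = rng.normal(size=(d, D))

def loss(A, B):
    residual = X - X @ (A @ B).T          # row t is x_t - A B x_t
    return np.sum(residual ** 2)

grad_A = -2 * (np.eye(D) - A @ B) @ Sigma @ B.T

# Central finite differences in each entry of A.
eps = 1e-6
num_grad_A = np.zeros_like(A)
for i in range(D):
    for j in range(d):
        E = np.zeros_like(A)
        E[i, j] = eps
        num_grad_A[i, j] = (loss(A + E, B) - loss(A - E, B)) / (2 * eps)

print(np.max(np.abs(grad_A - num_grad_A)))    # small compared to the gradient entries
```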
Since we have assumed that $\Sigma$ is invertible, the first equation immediately implies that $A^T = A^TAB$. If we assume in addition that $A$ has full rank (a reasonable assumption in any case of practical interest), then $A^TA$ is invertible and we have that

$$(A^TA)^{-1}A^T = B, \tag{4}$$

which in turn implies that

$$AB = A(A^TA)^{-1}A^T = P_A, \tag{5}$$

where we have written $P_A$ to denote the orthogonal projection onto the column space of $A$.
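As a quick sanity check on the formula in (5), here is a tiny sketch (with a random full-rank $A$) confirming that $A(A^TA)^{-1}A^T$ is symmetric, idempotent, of rank $d$, and fixes the columns of $A$:

```python
import numpy as np

# P_A = A (A^T A)^{-1} A^T is the orthogonal projection onto the column space of A.
rng = np.random.default_rng(3)
D, d = 5, 2
A = rng.normal(size=(D, d))                  # full rank with probability 1

P_A = A @ np.linalg.inv(A.T @ A) @ A.T
assert np.allclose(P_A, P_A.T)               # symmetric
assert np.allclose(P_A @ P_A, P_A)           # idempotent
assert np.allclose(P_A @ A, A)               # acts as the identity on the column space of A
print(np.linalg.matrix_rank(P_A))            # d
```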
Claim. We next claim that $\Sigma$ commutes with $P_A$.
Proof of claim. Plugging (5) into the second equation of (3), we have:
$$\Sigma B^T = P_A\Sigma B^T. \tag{6}$$
Then right-multiply by $A^T$ and use the fact that $B^TA^T = (AB)^T = P_A^T = P_A$ to get:
$$\Sigma P_A = P_A\Sigma P_A. \tag{7}$$
The right-hand side is manifestly a symmetric matrix, so we deduce that $\Sigma P_A$ is symmetric. If the product of two symmetric matrices is symmetric then they commute (indeed, $(\Sigma P_A)^T = P_A\Sigma$, so symmetry of $\Sigma P_A$ says exactly that $\Sigma P_A = P_A\Sigma$), so this indeed shows that $\Sigma$ commutes with $P_A$ and completes the proof of the claim.
Now let $U$ be the orthogonal matrix which diagonalizes $\Sigma$, i.e. the matrix for which

$$\Sigma = U\Lambda U^T, \tag{8}$$

where $\Lambda$ is a diagonal matrix with entries $\lambda_1 > \lambda_2 > \cdots > \lambda_D > 0$.
Claim. We next claim that $P_A = UP_{U^TA}U^T$ and that $P_{U^TA}$ is diagonal.
Proof of Claim. Firstly, using the standard formula for orthogonal projections, we have
$$P_{U^TA} = U^TA\,(A^TUU^TA)^{-1}A^TU = U^TA\,(A^TA)^{-1}A^TU = U^TP_AU,$$
which implies that
$$P_A = UP_{U^TA}U^T. \tag{9}$$
To show that $P_{U^TA}$ is diagonal, we show that it commutes with the diagonal matrix $\Lambda$ (any matrix that commutes with a diagonal matrix whose diagonal entries are distinct must itself be diagonal, and the entries of $\Lambda$ are distinct by assumption). Starting from $P_{U^TA}\Lambda$, we first insert the identity matrix in the form $U^TU$, and then use (8) and (9) thus:
$$P_{U^TA}\Lambda = U^TUP_{U^TA}U^T\,U\Lambda U^T\,U = U^TP_A\Sigma U.$$
Then recall that we have already established that $P_A$ commutes with $\Sigma$. So we can swap them and then perform the same trick in reverse:
$$U^TP_A\Sigma U = U^T\Sigma P_AU = U^TU\Lambda U^T\,UP_{U^TA}U^T\,U = \Lambda P_{U^TA}.$$
This shows that $P_{U^TA}$ commutes with $\Lambda$ and completes the proof of the claim.
So, given that $P_{U^TA}$ is diagonal and is an orthogonal projection of rank $d$ (being the conjugate by $U^T$ of the rank-$d$ orthogonal projection $P_A$), there exists a set of indices $I = \{i_1, \dots, i_d\}$ with $1 \le i_1 < i_2 < \cdots < i_d \le D$ such that the $(i,j)$th entry of $P_{U^TA}$ is 1 if $i = j$ and $i \in I$, and is zero otherwise. And since $P_A = UP_{U^TA}U^T$, we see that

$$P_A = U_IU_I^T, \tag{10}$$

where $U_I$ is formed from $U$ by simply setting to zero the $j$th column if $j \notin I$. This is manifestly an orthogonal projection onto the span of $\{u_{i_1}, \dots, u_{i_d}\}$, where $u_1, u_2, \dots, u_D$ are the columns of $U$, i.e. an orthonormal basis of eigenvectors of $\Sigma$. Combining these observations with (5), we have that

$$AB = P_A = U_IU_I^T = P_{U_I}. \tag{11}$$
This proves the first claim of the proposition.
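To see the proposition in action numerically, here is an illustrative sketch (arbitrary data, an arbitrary index set $I$, and the matrix-form gradients from the earlier sketch) that builds a critical point out of $d$ eigenvectors of $\Sigma$ and checks both conclusions:

```python
import numpy as np

# Build a critical point from an index set I: take A's columns to be the corresponding
# eigenvectors of Sigma and B = A^T (= (A^T A)^{-1} A^T, since the columns are orthonormal).
rng = np.random.default_rng(4)
T, D, d = 200, 5, 2
X = rng.normal(size=(T, D))
Sigma = X.T @ X

lam, U = np.linalg.eigh(Sigma)
order = np.argsort(lam)[::-1]
lam, U = lam[order], U[:, order]

I = [1, 3]                                   # an arbitrary index set of size d (0-based)
A = U[:, I]
B = A.T

grad_A = -2 * (np.eye(D) - A @ B) @ Sigma @ B.T
grad_B = -2 * A.T @ (np.eye(D) - A @ B) @ Sigma
loss = np.sum((X - X @ (A @ B).T) ** 2)

print(np.max(np.abs(grad_A)), np.max(np.abs(grad_B)))   # both essentially zero
print(loss, lam.sum() - lam[I].sum())                    # loss = tr(Sigma) - sum_{i in I} lambda_i
```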
To prove the second part, write $AB = [p_{ij}]$ and compute thus:

$$\sum_{t=1}^T\|x_t - ABx_t\|^2 = \sum_{t=1}^T\sum_{i=1}^D\Big(x_{ti} - \sum_j p_{ij}x_{tj}\Big)^2 = \sum_i\sigma_{ii} - 2\sum_{i,j}p_{ij}\sigma_{ij} + \sum_{i,j,k}p_{ij}p_{ik}\sigma_{jk} = \operatorname{tr}\Sigma - 2\operatorname{tr}(AB\,\Sigma) + \operatorname{tr}\big(AB\,\Sigma\,(AB)^T\big).$$

By (11), $AB = P_{U_I}$ is symmetric, idempotent, and commutes with $\Sigma$, so $\operatorname{tr}\big(AB\,\Sigma\,(AB)^T\big) = \operatorname{tr}\big((AB)^2\Sigma\big) = \operatorname{tr}(AB\,\Sigma)$, and moreover $\operatorname{tr}(AB\,\Sigma) = \operatorname{tr}(U_IU_I^T\,U\Lambda U^T) = \sum_{i\in I}\lambda_i$. Hence

$$L(A,B) = \operatorname{tr}\Sigma - \operatorname{tr}(AB\,\Sigma) = \operatorname{tr}\Sigma - \sum_{i\in I}\lambda_i = \sum_{i\notin I}\lambda_i,$$

as claimed.
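Finally, as a rough end-to-end illustration (not part of the argument above), one can train the linear toy model by plain gradient descent on random data and watch the loss approach the theoretical optimum $\sum_{i > d}\lambda_i$. The learning rate, step count, and sizes below are ad hoc choices.

```python
import numpy as np

# Train the linear autoencoder x -> A B x by gradient descent and compare the
# final loss with the theoretical optimum from Corollary 2.
rng = np.random.default_rng(5)
T, D, d = 200, 5, 2
X = rng.normal(size=(T, D))
Sigma = X.T @ X

lam = np.sort(np.linalg.eigvalsh(Sigma))[::-1]
optimum = lam[d:].sum()                      # sum of the D - d smallest eigenvalues

A = 0.01 * rng.normal(size=(D, d))           # small random initialization
B = 0.01 * rng.normal(size=(d, D))
lr = 1e-4
for _ in range(30000):
    R = np.eye(D) - A @ B
    grad_A = -2 * R @ Sigma @ B.T
    grad_B = -2 * A.T @ R @ Sigma
    A -= lr * grad_A
    B -= lr * grad_B

loss = np.sum((X - X @ (A @ B).T) ** 2)
print(loss, optimum)                         # the trained loss should be close to the optimum
```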