Can't say much about transformers, but the tensor product definition seems off. There can be many elements in V⊗W that aren't expressible as v⊗w, only as a linear combination of multiple such. That can be seen from a dimension count: if v and w have dimensions n and m, then the pairs (v,w) are described by only n+m parameters (the Cartesian product), but the full tensor product has nm dimensions.
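To see this concretely, here is a minimal numpy sketch (identifying V⊗W with n×m matrices, as in the post below): pure tensors v⊗w are exactly the rank-1 matrices, so a sum of two of them generically has rank 2 and is not expressible as any single v⊗w.

```python
import numpy as np

# Identify V ⊗ W with n x m matrices; the pure tensor v ⊗ w is the rank-1 matrix outer(v, w).
n, m = 3, 4
rng = np.random.default_rng(0)
v1, w1 = rng.normal(size=n), rng.normal(size=m)
v2, w2 = rng.normal(size=n), rng.normal(size=m)

pure = np.outer(v1, w1)                      # a single v ⊗ w
mixed = np.outer(v1, w1) + np.outer(v2, w2)  # a linear combination of two pure tensors

print(np.linalg.matrix_rank(pure))   # 1  -> pure tensor
print(np.linalg.matrix_rank(mixed))  # 2  -> cannot be written as any single v ⊗ w
```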
Here's an explanation of tensor products that I came up with some time ago in an attempt to make it "click". Imagine you have a linear function that takes in two vectors and spits out a number. But wait, there are two natural but incompatible ways to imagine it:
1. f(a,b) + f(c,d) = f(a+c,b+d), linear in both arguments combined. The space of such functions has dimension n+m, and corresponds to the Cartesian product.
2. f(a,b) + f(a,c) = f(a,b+c) and also f(a,c) + f(b,c) = f(a+b,c), in other words, linear in each argument separately. The space of such functions has dimension nm, and corresponds to the tensor product.
It's especially simple to work through the case n=m=1. In that case all functions satisfying (1) have the form f(x,y)=ax+by, so their space is 2-dimensional, while all functions satisfying (2) have the form f(x,y)=axy, so their space is 1-dimensional. Admittedly this case is a bit funny because nm<n+m, but you can see how in higher dimensions the space of functions of type (2) becomes much bigger, because it will have terms for x1y1, x1y2, etc.
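A minimal numpy sketch of the contrast, assuming the natural parametrizations a·x + b·y for type (1) and xᵀMy for type (2):

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 3, 4

# Type (1): jointly linear, f(x, y) = a·x + b·y  -> parametrized by (a, b), dimension n + m.
a, b = rng.normal(size=n), rng.normal(size=m)
f1 = lambda x, y: a @ x + b @ y

# Type (2): bilinear, f(x, y) = xᵀ M y  -> parametrized by M, dimension n * m.
M = rng.normal(size=(n, m))
f2 = lambda x, y: x @ M @ y

x, u = rng.normal(size=n), rng.normal(size=n)
y, w = rng.normal(size=m), rng.normal(size=m)

print(np.isclose(f1(x, y) + f1(u, w), f1(x + u, y + w)))  # True: additive in both arguments at once
print(np.isclose(f2(x, y) + f2(u, w), f2(x + u, y + w)))  # False (generically)
print(np.isclose(f2(x, y) + f2(x, w), f2(x, y + w)))      # True: additive in each argument separately
print(np.isclose(f2(x, y) + f2(u, y), f2(x + u, y)))      # True
```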
I was trying to understand the tensor product formulation in transformer circuits, and I had basically forgotten all I ever knew about tensor products, if I ever knew anything. This very brief post is aimed at me from Wednesday the 22nd, when I didn't understand why that formulation of attention was true. It basically just gives a bit more background and includes a few more steps. I hope it will be helpful to someone else, too.
Tensor product
For understanding this, it is necessary to understand tensor products. Given two finite-dimensional vector spaces $V, W$ we can construct the tensor product space $V \otimes W$ as the span[1] of all matrices $v \otimes w$, where $v \in V, w \in W$, with the property $(v \otimes w)_{ij} = v_i w_j$ [2]. We can equivalently define it as a vector space with basis elements $e^V_i \otimes e^W_j$, where we used the basis elements of $V$ and $W$ respectively.
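A small numpy sketch of the equivalence of the two definitions, identifying V⊗W with n×m matrices: the basis element $e^V_i \otimes e^W_j$ is the matrix with a single 1 in entry (i, j), and every element decomposes over that basis.

```python
import numpy as np

n, m = 2, 3
T = np.arange(6.0).reshape(n, m)   # an arbitrary element of V ⊗ W, viewed as an n x m matrix
I_n, I_m = np.eye(n), np.eye(m)

# e^V_i ⊗ e^W_j is the matrix with a single 1 in entry (i, j) ...
print(np.outer(I_n[0], I_m[1]))
# ... and T decomposes as sum_ij T_ij (e^V_i ⊗ e^W_j):
rebuilt = sum(T[i, j] * np.outer(I_n[i], I_m[j]) for i in range(n) for j in range(m))
print(np.allclose(rebuilt, T))     # True
```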
But we can define tensor products not only between vectors but also between linear maps from one vector space to another (i.e. matrices!):
Given two linear maps (matrices) A:V→X,B:W→Y we can define A⊗B:V⊗W→X⊗Y, where each map simply operates on its own vector space, not interacting with the other:
(A⊗B)(v⊗w)=A(v)⊗B(w)
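A quick numpy check of this identity, using np.kron(A, B) as the matrix of A⊗B with respect to the flattened basis $e^V_i \otimes e^W_j$ (ordered with i as the slow index):

```python
import numpy as np

rng = np.random.default_rng(2)
# A : V -> X and B : W -> Y, with dim V = 3, dim X = 2, dim W = 4, dim Y = 5.
A = rng.normal(size=(2, 3))
B = rng.normal(size=(5, 4))
v = rng.normal(size=3)
w = rng.normal(size=4)

# np.kron(v, w) flattens the pure tensor v ⊗ w; np.kron(A, B) is the matrix of A ⊗ B.
lhs = np.kron(A, B) @ np.kron(v, w)   # (A ⊗ B)(v ⊗ w)
rhs = np.kron(A @ v, B @ w)           # A(v) ⊗ B(w)
print(np.allclose(lhs, rhs))          # True
```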
For more information on the tensor product, I recommend this intuitive explanation and the Wikipedia entry.
How does this connect to the attention-only transformer?
In the "attention-only" formulation of the transformer we can write the "residual" of a fixed head as AXWVWO, with the values weight matrix WV, the attention matrix A, the output weight matrix WO, and the current embeddings at each position X
Let $E$ be the embedding dimension, $L$ the total context length and $D$ the dimension of the values; then we have that $A \in \mathbb{R}^{L \times L}$, $X \in \mathbb{R}^{L \times E}$, $W_V \in \mathbb{R}^{E \times D}$ and $W_O \in \mathbb{R}^{D \times E}$.
Let's identify the participating vector spaces:
$A$ maps from the "position" space back to the "position" space, which we will call $P$ (and which is isomorphic to $\mathbb{R}^L$). Similarly, we have the "embedding" space $E \cong \mathbb{R}^E$ and the "value" space $V \cong \mathbb{R}^D$.
It might become clear now that we can identify $X$ with an element of $P \otimes E$, i.e. that we can write $X = \sum_{ij} X_{ij} \, (e^P_i \otimes e^E_j)$.
From that lens, we can see that right-multiplying $X$ with $W_V$ is equivalent to multiplying with $\mathrm{Id} \otimes W_V$, which maps an element from $P \otimes E$ to an element from $P \otimes V$, by applying $W_V$ to the $E$-part of the tensor [3]:
$$\begin{aligned}(\mathrm{Id} \otimes W_V)(X) &= (\mathrm{Id} \otimes W_V) \sum_{ij} X_{ij}\, e^P_i \otimes e^E_j = \sum_{ij} X_{ij}\, e^P_i \otimes W_V(e^E_j) \\ &= \sum_{ij} X_{ij}\, e^P_i \otimes \sum_k W_{jk}\, e^V_k = \sum_{ik} \Big(\sum_j X_{ij} W_{jk}\Big)\, e^P_i \otimes e^V_k \\ &= \sum_{ik} (X W_V)_{ik}\, e^P_i \otimes e^V_k = X W_V\end{aligned}$$
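To double-check this numerically, here is a minimal numpy sketch with toy dimensions; flattening $X$ row-major makes the position index the slow one, and (per footnote [3]) the map on the $E$-factor has matrix $W_V^T$:

```python
import numpy as np

rng = np.random.default_rng(3)
L, E, D = 4, 6, 2                  # toy context length, embedding dim, value dim
X = rng.normal(size=(L, E))
W_V = rng.normal(size=(E, D))

# Flatten X row-major, so index (i, j) of P ⊗ E becomes i*E + j.
# The map on the E-factor sends e^E_j to W_V.T @ e^E_j (footnote [3]), so its matrix is W_V.T.
lhs = (np.kron(np.eye(L), W_V.T) @ X.ravel()).reshape(L, D)   # (Id ⊗ W_V)(X)
rhs = X @ W_V                                                  # plain right multiplication
print(np.allclose(lhs, rhs))                                   # True
```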
Identical arguments hold for $W_O$ and $A$, so that we get the formulation from the paper:
$$A X W_V W_O = (A \otimes W_O W_V) \cdot X$$
(The order $W_O W_V$ in the tensor factor is composition of linear maps: $W_V$ acts first and $W_O$ second on each position's embedding, which in the row-vector convention used here is exactly right multiplication by $W_V W_O$; cf. footnote [3].)
Note that there is nothing special about this in terms of what these matrices represent. So it seems that a takeaway message is that whenever you have a matrix product of the form $ABC$ you can rewrite it as $(A \otimes C) \cdot B$, with $C$ understood as acting by right multiplication (sorry to everyone who thought that was blatantly obvious from the get-go ;P).[4]
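As a numerical sanity check on this takeaway (and on the transposition subtlety from footnote [3]), here is a small numpy sketch; once everything is flattened row-major, the factor that multiplies from the right shows up as its transpose in the Kronecker product:

```python
import numpy as np

rng = np.random.default_rng(4)
L, E, D = 4, 6, 2
A   = rng.normal(size=(L, L))
X   = rng.normal(size=(L, E))
W_V = rng.normal(size=(E, D))
W_O = rng.normal(size=(D, E))

def tensor_apply(A, C, B):
    """Compute (A ⊗ C) · B, where C acts on B by right multiplication."""
    return (np.kron(A, C.T) @ B.ravel()).reshape(A.shape[0], C.shape[1])

# A B C == (A ⊗ C) · B for any compatible matrices ...
B, C = rng.normal(size=(L, E)), rng.normal(size=(E, 3))
print(np.allclose(tensor_apply(A, C, B), A @ B @ C))                    # True

# ... and in particular for the attention head, where the right factor is the
# matrix W_V @ W_O (i.e. the composed map written W_O W_V in the paper's convention):
print(np.allclose(tensor_apply(A, W_V @ W_O, X), A @ X @ W_V @ W_O))    # True
```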
A previous edition of this post said that it was the space of all such matrices, which is inaccurate: the span of a set of vectors/matrices is the space of all linear combinations of elements from that set. ↩︎
I'm limiting myself to finite-dimensional spaces because that's what is relevant to the transformer circuits paper. The actual formal definition is more general/stricter but imo doesn't add much to understanding the application in this paper. ↩︎
Note that the 'linear map' that we use here is basically right multiplying with $W_V$, so that it maps $e^E_k \mapsto W_V^T e^E_k$. ↩︎
I should note that this is also what is mentioned in the paper's introduction on tensor products, but it didn't click with me, whereas going through the above steps did. ↩︎