These are just some notes I wrote while reading about transformers which I thought might be a useful reference to others. Thanks to Aryan Bhatt for a correction to the attention normalization. Further corrections welcome.

## Overview of Transformers

Many transformer models share the following architecture. Data flows as follows:

1. We take tokens as inputs and pass them through an embedding layer. The embedding layer writes its output into the residual stream ($x_0$), which has shape $(C,E)$, where $C$ is the number of tokens in the context window and $E$ is the embedding dimension.
2. The residual stream is processed by the attention mechanism ($H$) and the result is added back into the residual stream (i.e. $x_1 = H(x_0) + x_0$).
3. The residual stream is processed by an MLP layer ($\mathrm{MLP}$) and the result is added back into the residual stream (i.e. $x_2 = \mathrm{MLP}(x_1) + x_1$).
4. Steps (2) and (3) together define a "residual block". The body of the transformer is a stack of these blocks in series.
5. After the final residual block, we apply an unembedding transformation to produce logits, which are unnormalized log-probabilities over output tokens (a softmax turns them into probabilities).
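As a rough sketch of this data flow, here is a toy NumPy version. The names `W_embed`, `W_unembed`, `attention`, and `mlp` and all the sizes are made up for illustration. The attention and MLP functions are trivial stand-ins rather than real layers, and a real model would have separate learned weights in every block.

```python
import numpy as np

rng = np.random.default_rng(0)
V, C, E, n_blocks = 50, 6, 16, 3  # vocab size, context length, embedding dim, depth

W_embed = rng.normal(size=(V, E)) * 0.1    # hypothetical embedding table
W_unembed = rng.normal(size=(E, V)) * 0.1  # hypothetical unembedding matrix

def attention(x):
    # Stand-in for H: any map (C, E) -> (C, E). Here: uniform averaging over tokens.
    n = len(x)
    return np.full((n, n), 1.0 / n) @ x

def mlp(x):
    # Stand-in for the MLP: the same function applied to every token's row.
    return np.maximum(x, 0.0)

def transformer(tokens):
    x = W_embed[tokens]        # embed: residual stream x0, shape (C, E)
    for _ in range(n_blocks):  # stack of residual blocks
        x = attention(x) + x   # attention output is added back into the stream
        x = mlp(x) + x         # MLP output is added back into the stream
    return x @ W_unembed       # unembed: logits, shape (C, V)

tokens = rng.integers(0, V, size=C)
logits = transformer(tokens)
print(logits.shape)
```

The point of the sketch is only the shape of the computation: every layer reads from and adds to the same `(C, E)` array.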

## Attention Mechanism

The attention mechanism ($H$) is divided into multiple attention heads $h_j$, which act in parallel. That is,

$$H(x) = \sum_j h_j(x)$$

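Note that this decomposition is only useful if attention heads are non-linear, since a sum of linear maps is just another linear map. Fortunately they are! Dropping the head label, each attention head is of the form

$$h_{ij}(x) = \sum_{kl} A_{ik}(x)\, x_{kl}\, S_{lj}$$

where $i$ and $k$ index tokens and $j$ and $l$ index the embedding space. That is, $A_{ik}(x)$ mixes across tokens (the first index of $x$) and $S_{lj}$ transforms each token in parallel. Another way we could have written this is

$$h(x) = A(x) \cdot x \cdot S$$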

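The matrix $S$ is also written in more common notation as $W_O W_V$ (up to a transpose, depending on whether the weights are taken to act on row or column vectors), where $W_O$ and $W_V$ are sometimes called the output and value weights. In general though $S$ is just some low-rank matrix that we learn. $S$ has shape $(E,E)$ because it transforms in the embedding space.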

The matrix A is where the nonlinearity of attention comes in. This is given by

$$A = \mathrm{softmax}(x Y x^T)$$

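where $Y$ is written in more common notation as $W_Q^T W_K/\sqrt{d_k}$, where $W_Q$ and $W_K$ are sometimes called the query and key weights. The dimension $d_k$ is the dimension of the output of $W_K$, so $Y$ has rank at most $d_k$. As with $S$, $Y$ is just some low-rank matrix that we learn. The softmax is applied to each row of the matrix separately, so each row of $A$ sums to one: every token's attention weights over the tokens it attends to are normalized independently.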
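To make the shapes concrete, here is a minimal NumPy sketch of a single head and of the sum over heads. It assumes the softmax normalizes each row (the standard convention) and builds $Y$ and $S$ from random low-rank factors. The helper names (`softmax_rows`, `attention_head`, `random_head`) and the sizes are made up for illustration.

```python
import numpy as np

def softmax_rows(m):
    """Softmax over the last axis, so each row sums to 1."""
    m = m - m.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(m)
    return e / e.sum(axis=-1, keepdims=True)

def attention_head(x, Y, S):
    """h(x) = A(x) x S with A(x) = softmax(x Y x^T)."""
    A = softmax_rows(x @ Y @ x.T)  # (C, C): mixes across tokens
    return A @ x @ S               # (C, E): S acts in the embedding space

rng = np.random.default_rng(0)
C, E, d_k, n_heads = 5, 8, 2, 4

def random_head():
    # Random low-rank Y and S, standing in for learned weights.
    W_Q = rng.normal(size=(d_k, E))
    W_K = rng.normal(size=(d_k, E))
    Y = W_Q.T @ W_K / np.sqrt(d_k)                             # (E, E), rank at most d_k
    S = rng.normal(size=(E, d_k)) @ rng.normal(size=(d_k, E))  # (E, E), rank at most d_k
    return Y, S

heads = [random_head() for _ in range(n_heads)]
x = rng.normal(size=(C, E))
H = sum(attention_head(x, Y, S) for Y, S in heads)  # H(x) = sum_j h_j(x)
print(H.shape)
```

Note that the only place `x` enters non-linearly is inside `softmax_rows`. If `A` were held fixed, each head would be a linear map of `x`.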

## MLP Layer

The MLP (multilayer perceptron) layer processes the residual stream using the same MLP for each token index. That is, there is no communication between tokens in the MLP layer. All this layer does is transform in the embedding space.
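A small NumPy sketch of what "the same MLP for each token index" means in practice. The two-layer shape, the ReLU nonlinearity, and the hidden width are assumptions made for illustration. The notes above do not specify them.

```python
import numpy as np

rng = np.random.default_rng(0)
C, E, hidden = 5, 8, 32  # illustrative sizes

W1, b1 = rng.normal(size=(E, hidden)), np.zeros(hidden)
W2, b2 = rng.normal(size=(hidden, E)), np.zeros(E)

def mlp(x):
    """Two-layer MLP acting on the last axis, i.e. on each token's vector independently."""
    return np.maximum(x @ W1 + b1, 0.0) @ W2 + b2  # ReLU is an assumed choice

x = rng.normal(size=(C, E))
batched = mlp(x)                                    # whole residual stream at once
one_at_a_time = np.stack([mlp(row) for row in x])   # one token at a time
print(np.allclose(batched, one_at_a_time))

# Changing one token leaves the outputs for all other tokens untouched.
x_perturbed = x.copy()
x_perturbed[0] += 1.0
print(np.allclose(mlp(x_perturbed)[1:], batched[1:]))
```

Both checks should hold, because nothing in `mlp` ever combines different rows of `x`.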

## Positional Encodings

A quirk of the attention mechanism is that it is covariant with respect to shuffling the token index. That is, if $P$ is a permutation matrix then

$$h(Px) = P\,h(x)$$

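To see this, we expand the left-hand side:

$$h(Px) = A(Px) \cdot Px \cdot S = \mathrm{softmax}\left((Px)\,Y\,(Px)^T\right) \cdot (Px) \cdot S$$

Inside the softmax we have $(Px)\,Y\,(Px)^T = P\,(xYx^T)\,P^T$, which has the same entries as $xYx^T$ with its rows and columns reordered. Reordering commutes with the softmax, because the exponential acts elementwise and the normalizing sums do not depend on the order of their terms, so the permutations can be pulled outside:

$$h(Px) = P\,\mathrm{softmax}(xYx^T)\,P^T P\,x\,S$$

The transpose of a permutation matrix is its inverse, so $P^T P = I$ and

$$h(Px) = P\,\mathrm{softmax}(xYx^T)\,x\,S = P\,h(x)$$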

Similarly, the MLP layer acts on each token individually and so doesn’t know anything about their orderings.

What this means is that there is no information about token ordering in the transformer unless we put it there in the embedding space. This is what positional encodings do.
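The argument above is easy to check numerically. The sketch below assumes a row-wise softmax and random $Y$ and $S$. The `head` function, the cyclic-shift permutation, and the random position vectors `pos` are just illustrative choices.

```python
import numpy as np

def softmax_rows(m):
    m = m - m.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(m)
    return e / e.sum(axis=-1, keepdims=True)

def head(x, Y, S):
    """h(x) = softmax(x Y x^T) x S"""
    return softmax_rows(x @ Y @ x.T) @ x @ S

rng = np.random.default_rng(1)
C, E = 6, 4
x = rng.normal(size=(C, E))
Y = rng.normal(size=(E, E))
S = rng.normal(size=(E, E))

perm = np.roll(np.arange(C), 1)  # a cyclic shift, so definitely not the identity
P = np.eye(C)[perm]              # permutation matrix: (P @ x)[i] == x[perm[i]]

# Without positional information, shuffling tokens just shuffles the outputs.
covariant = np.allclose(head(P @ x, Y, S), P @ head(x, Y, S))

# Adding a position-dependent vector to each slot breaks the symmetry.
pos = rng.normal(size=(C, E))
with_pos = np.allclose(head(P @ x + pos, Y, S), P @ head(x + pos, Y, S))

print(covariant, with_pos)
```

The first comparison should come out true and the second false: once each position carries its own vector, the head can tell the two orderings apart.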

A typical positional encoding is given by adding a position-dependent vector to the embedding of each token. A common choice is

$$e_{k,j} = \begin{cases} \sin\left(k / N^{2j/E}\right) & j < E/2 \\ \cos\left(k / N^{2(j-E/2)/E}\right) & j \ge E/2 \end{cases}$$

where $k$ is the token index in the context window and $j$ indexes the embedding space. Here $N > C$ so that this is not a periodic function of $k$ within the context window. The reason this choice is common is that there is a linear transformation for shifting $k \to k+1$, identical across all $k$, which makes it easy for models to learn to compare adjacent tokens. If this is not apparent, note that pairs of $j$ offset by $E/2$ give a representation of a complex number $e^{ik/N^{2j/E}}$, and we can increment $k$ by multiplying by a diagonal operator $e^{i/N^{2j/E}}$ which is the same for all $k$.
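Here is a short NumPy sketch of this encoding and of the shift transformation, written with real sines and cosines rather than complex numbers. The function names and the value `N = 10_000` are assumptions for illustration, and $E$ is assumed to be even. In real form the "diagonal operator" becomes a rotation acting on each (sin, cos) pair.

```python
import numpy as np

def positional_encoding(C, E, N=10_000):
    """Rows are positions k; the first E/2 columns are sines, the last E/2 are cosines."""
    k = np.arange(C)[:, None]       # (C, 1)
    j = np.arange(E // 2)[None, :]  # (1, E/2)
    angle = k / N ** (2 * j / E)    # k / N^(2j/E)
    return np.concatenate([np.sin(angle), np.cos(angle)], axis=1)

def shift_matrix(E, N=10_000):
    """Matrix M with e[k+1] = e[k] @ M for every k."""
    theta = 1.0 / N ** (2 * np.arange(E // 2) / E)  # rotation angle for each pair
    c, s = np.diag(np.cos(theta)), np.diag(np.sin(theta))
    # sin(a + t) = sin(a) cos(t) + cos(a) sin(t)
    # cos(a + t) = cos(a) cos(t) - sin(a) sin(t)
    return np.block([[c, -s], [s, c]])

C, E = 16, 8
e = positional_encoding(C, E)
M = shift_matrix(E)
print(np.allclose(e[1:], e[:-1] @ M))  # the same M shifts every position by one
```

Because `M` does not depend on `k`, a single learned linear map is enough to relate the encoding of any token to that of its neighbour.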
