This is a linkpost for https://neelnanda.io/mechanistic-interpretability/prereqs

A Barebones Guide to Mechanistic Interpretability Prerequisites

Zac Hatfield-Dodds

Neel Nanda

mic

Yulu Pi

Neel Nanda


For Python basics, I have to *anti*-recommend Zed Shaw's *Learn Python the Hard Way*; it's generally outdated and in some places actively misleading. And why would you want to learn the hard way instead of the best way in any case?

Instead, my standard recommendation is Al Sweigart's *Automate the Boring Stuff* and then *Beyond the Basic Stuff* (both readable for free on inventwithpython.com, or purchasable as books); he's also written some books of exercises. If you prefer a more traditional textbook, *Think Python 2e* is excellent and also available freely online.

Thanks! I learned Python ~10 years ago and have no idea what sources are any good lol. I've edited the post with your recs :)

Thanks for writing this! Here is a quick explanation of all the math concepts – mostly written by ChatGPT with some manual edits.

A **basis** for a vector space is a set of linearly independent vectors that can be used to represent any vector in the space as a linear combination of those basis vectors. For example, in two-dimensional Euclidean space, the standard basis consists of the vectors (1, 0) and (0, 1).

A **change of basis** is the process of expressing a vector in one basis in terms of another basis. For example, if we have a vector v in two-dimensional Euclidean space and we want to express it in terms of the standard basis, we can write v as a linear combination of (1, 0) and (0, 1). Alternatively, we could choose a different basis for the space, such as the basis formed by the vectors (4, 2) and (3, 5). In this case, we would express v in terms of this new basis by writing it as a linear combination of (4, 2) and (3, 5).
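A minimal pure-Python sketch of the example above (our own illustration; the function name `coords_in_basis` and the vector v = (7, 7) are made up for it). Finding the coordinates of v in the basis (4, 2), (3, 5) means solving a 2x2 linear system, done here with Cramer's rule and exact fractions.

```python
from fractions import Fraction


def coords_in_basis(v, b1, b2):
    """Return (a, b) such that v = a*b1 + b*b2, for 2D vectors.

    Solves the 2x2 system [b1 b2] @ [a, b]^T = v with Cramer's rule;
    Fractions keep the arithmetic exact.
    """
    det = b1[0] * b2[1] - b2[0] * b1[1]
    if det == 0:
        raise ValueError("b1 and b2 are linearly dependent, so they are not a basis")
    a = Fraction(v[0] * b2[1] - b2[0] * v[1], det)
    b = Fraction(b1[0] * v[1] - v[0] * b1[1], det)
    return a, b


v = (7, 7)
# In the standard basis the coordinates are just the components of v.
print(coords_in_basis(v, (1, 0), (0, 1)))  # (Fraction(7, 1), Fraction(7, 1))
# In the basis (4, 2), (3, 5): v = 1*(4, 2) + 1*(3, 5).
print(coords_in_basis(v, (4, 2), (3, 5)))  # (Fraction(1, 1), Fraction(1, 1))
```

The same vector gets different coordinates in different bases; the vector itself does not change.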

A **vector space** is a set of vectors that can be added together and multiplied ("scaled") by numbers, called scalars. Scalars are often taken to be real numbers, but there are also vector spaces with scalar multiplication by complex numbers, rational numbers, or generally any field. The operations of vector addition and scalar multiplication must satisfy certain requirements, called axioms. Examples of vector spaces include the set of all two-dimensional vectors (i.e., the set of all points in two-dimensional Euclidean space), the set of all polynomials with real coefficients, and the set of all continuous real-valued functions on a given interval. A vector space can be thought of as a geometric object, but it does not necessarily have a canonical basis, meaning that there is no preferred set of basis vectors that can be used to represent all the vectors in the space.

A **matrix** is a rectangular array of numbers, symbols, or expressions, arranged in rows and columns. Once a basis is chosen for each space, a matrix represents a linear map between two vector spaces, or from a vector space to itself: it takes any vector in the original vector space and transforms it into a vector in the target vector space using a set of linear equations. Each column of the matrix is the image of one of the original basis vectors, i.e. where that basis vector ends up after the transformation. In the expression $Mv$, each element $v_j$ of the original vector scales the $j$-th column of $M$, and these scaled columns are added together to create the new vector.
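To illustrate how $Mv$ combines the columns of $M$, here is a small pure-Python sketch (hypothetical helper names) that computes a matrix-vector product as a weighted sum of columns and checks it against the usual row-by-row dot products. The example matrix is a 90-degree rotation, whose columns are exactly where (1, 0) and (0, 1) land.

```python
def matvec_columns(M, v):
    """Compute M v as a weighted sum of the columns of M (column view)."""
    n_rows = len(M)
    result = [0] * n_rows
    for j, v_j in enumerate(v):
        # v_j scales column j, and the scaled column is added to the result.
        for i in range(n_rows):
            result[i] += v_j * M[i][j]
    return result


def matvec_rows(M, v):
    """Compute M v the usual way: entry i is the dot product of row i with v."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


# A 90-degree rotation: its columns are where (1, 0) and (0, 1) end up.
M = [[0, -1],
     [1, 0]]
print(matvec_columns(M, [1, 0]))  # [0, 1]  (the first column)
print(matvec_columns(M, [2, 3]))  # [-3, 2]
print(matvec_rows(M, [2, 3]))     # [-3, 2]  (same answer, row view)
```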

The **singular value decomposition** (SVD) is a factorization of a matrix M into the product of three matrices: $M = USV^{\top}$, where, for a real matrix M, U and V are orthogonal matrices and S is a (possibly rectangular) diagonal matrix with non-negative real numbers on the diagonal, called the "singular values" of M. The SVD is a useful tool for understanding the properties of a matrix and for solving certain types of linear systems. It can also be used for data compression, image processing, and other applications.
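As a small worked example (our own, chosen so the arithmetic is easy to check by hand), here is an SVD of a diagonal matrix with a negative entry. Because singular values must be non-negative, the sign is absorbed into the orthogonal matrix $V$:

$$
M = \begin{pmatrix} 3 & 0 \\ 0 & -2 \end{pmatrix}
= \underbrace{\begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}}_{U}
\underbrace{\begin{pmatrix} 3 & 0 \\ 0 & 2 \end{pmatrix}}_{S}
\underbrace{\begin{pmatrix} 1 & 0 \\ 0 & -1 \end{pmatrix}}_{V^{\top}}
$$

Here $V = V^{\top} = \mathrm{diag}(1, -1)$ is orthogonal and the singular values are 3 and 2. Putting the sign into $U$ instead also works, which shows the SVD is not unique.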

An **orthogonal matrix** (or orthonormal matrix) is a square matrix whose columns and rows are mutually orthonormal (i.e., they are orthogonal and have unit length). Orthogonal matrices have the property that their inverse is equal to their transpose.

Changing to an orthonormal basis can be importantly different from an arbitrary change of basis, because it preserves the inner product: lengths of vectors and angles between them are unchanged. This also has computational advantages. For example, when working with an orthonormal basis, the inner product of two vectors can be computed simply as the sum of the products of their corresponding coordinates, without the need for any weights or scaling factors, and the coordinates of a vector are just its inner products with the basis vectors, so no linear system has to be solved. This makes calculations such as finding the length of a vector or the angle between two vectors simpler and more efficient.
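A small pure-Python check of these claims, using an orthonormal basis we picked for convenience ((3/5, 4/5) and (-4/5, 3/5), a rotated copy of the standard basis) and two arbitrary vectors. Coordinates come from dot products alone, and the dot product is the same before and after the change of basis.

```python
from fractions import Fraction as F


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# An orthonormal basis of 2D space: the standard basis rotated by an angle
# whose cosine is 3/5 and sine is 4/5. Fractions keep everything exact.
b1 = (F(3, 5), F(4, 5))
b2 = (F(-4, 5), F(3, 5))

# Orthonormality: unit length and mutually orthogonal.
assert dot(b1, b1) == 1 and dot(b2, b2) == 1 and dot(b1, b2) == 0


def coords(v):
    """Coordinates of v in the orthonormal basis (b1, b2).

    For an orthonormal basis each coordinate is just a dot product;
    no linear system needs to be solved.
    """
    return (dot(b1, v), dot(b2, v))


u, w = (2, 1), (-1, 3)
cu, cw = coords(u), coords(w)
# The dot product (and hence lengths and angles) is unchanged by the change of basis.
print(dot(u, w), dot(cu, cw))  # 1 1
```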

**Eigenvalues** and **eigenvectors** are special types of scalars and vectors that are associated with a linear map from a vector space to itself, or with a square matrix. If M is such a map or matrix and v is a non-zero vector, then v is an **eigenvector** of M if there exists a scalar λ, called an **eigenvalue**, such that $Mv = \lambda v$. In other words, when an eigenvector is multiplied by the matrix M, the resulting vector is a scalar multiple of the original vector. Eigenvalues and eigenvectors are important because they provide insight into the properties of the linear map or matrix. For example, a matrix is singular (i.e., not invertible) exactly when 0 is one of its eigenvalues, and its eigenvalues and eigenvectors together determine whether it is diagonalizable (i.e., can be expressed in the form $M = PDP^{-1}$, where P is an invertible matrix whose columns are eigenvectors of M and D is a diagonal matrix of the corresponding eigenvalues). The eigenvectors with eigenvalue 0 span the null space of M, so they also determine its nullity and, through the rank-nullity theorem, its rank.
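As a concrete sketch (our own example, pure Python), the symmetric matrix [[2, 1], [1, 2]] has eigenvectors (1, 1) with eigenvalue 3 and (1, -1) with eigenvalue 1. The code checks $Mv = \lambda v$ for both and then rebuilds M as $PDP^{-1}$. The eigenpairs and $P^{-1}$ were worked out by hand rather than by a library routine.

```python
from fractions import Fraction as F


def matmul(A, B):
    """Multiply two matrices given as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


M = [[2, 1],
     [1, 2]]

# Eigenpairs found by hand for this symmetric example matrix.
pairs = [(3, [1, 1]), (1, [1, -1])]
for lam, v in pairs:
    # Eigenvector check: M v should equal lambda * v.
    assert matvec(M, v) == [lam * x for x in v]

# Diagonalization M = P D P^{-1}: eigenvectors as the columns of P,
# eigenvalues on the diagonal of D.
P = [[1, 1],
     [1, -1]]
D = [[3, 0],
     [0, 1]]
P_inv = [[F(1, 2), F(1, 2)],
         [F(1, 2), F(-1, 2)]]
assert matmul(P, P_inv) == [[1, 0], [0, 1]]
print(matmul(matmul(P, D), P_inv) == M)  # True
```

Neither eigenvalue is 0, which matches the fact that this M is invertible.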

**Probability basics**: Probability is a measure of the likelihood of an event occurring. It is typically represented as a number between 0 and 1, where 0 indicates that the event is impossible and 1 indicates that the event is certain to occur. When all possible outcomes are equally likely, the probability of an event can be calculated by counting the number of outcomes in which the event occurs and dividing by the total number of possible outcomes.
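For instance, with two fair dice every one of the 36 ordered outcomes is equally likely, so counting works. A minimal sketch (the helper name `probability` is our own):

```python
from fractions import Fraction
from itertools import product

# All 36 equally likely outcomes of rolling two fair six-sided dice.
outcomes = list(product(range(1, 7), repeat=2))


def probability(event):
    """P(event) = (# outcomes where event holds) / (# outcomes).

    Only valid because every outcome here is equally likely.
    """
    favorable = sum(1 for o in outcomes if event(o))
    return Fraction(favorable, len(outcomes))


print(probability(lambda o: o[0] + o[1] == 7))   # 1/6
print(probability(lambda o: o[0] == o[1]))       # 1/6
print(probability(lambda o: o[0] + o[1] == 13))  # 0  (impossible event)
```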

**Basics of distributions**: A **distribution** is a function that describes the probability of a random variable taking on different values. The **expected value** of a distribution is a measure of the center of the distribution, and it is calculated as the weighted average of the possible values of the random variable, where the weights are the probabilities of each value occurring. The **standard deviation** is a measure of the dispersion of the distribution, and it is calculated as the square root of the variance, which is the expected value of the squared deviation of a random variable from its mean. A **normal distribution** (or **Gaussian distribution**) is a continuous probability distribution with a bell-shaped curve, which is defined by its mean and standard deviation.
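A fair six-sided die makes these definitions concrete. This sketch (our own example) computes the expected value and variance exactly with fractions, then takes the square root to get the standard deviation:

```python
import math
from fractions import Fraction

# A fair six-sided die: each value has probability 1/6.
dist = {x: Fraction(1, 6) for x in range(1, 7)}


def expected_value(dist):
    """Probability-weighted average of the values."""
    return sum(p * x for x, p in dist.items())


def variance(dist):
    """Expected squared deviation from the mean."""
    mu = expected_value(dist)
    return sum(p * (x - mu) ** 2 for x, p in dist.items())


mean = expected_value(dist)
var = variance(dist)
std = math.sqrt(var)  # standard deviation = sqrt(variance)
print(mean, var, round(std, 4))  # 7/2 35/12 1.7078
```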

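**Log likelihood**: The **log likelihood** of a statistical model is a measure of how well the model fits a given set of data. It is calculated as the logarithm of the probability (or, for continuous data, the probability density) of the data given the model and its parameters; for independent data points, this is the sum of the log probabilities of the individual points. It is often used to compare the relative fit of different models.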
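A minimal sketch with hypothetical coin-flip data (7 heads in 10 flips), modeling each flip as an independent Bernoulli(p) trial. Comparing the log likelihood under a fair coin (p = 0.5) and a biased coin (p = 0.7) shows which model fits this data better:

```python
import math

# Hypothetical data: 10 coin flips, 1 = heads, 0 = tails (7 heads, 3 tails).
data = [1, 1, 0, 1, 1, 1, 0, 1, 0, 1]


def log_likelihood(p, data):
    """Log probability of independent Bernoulli(p) observations.

    The log turns the product of per-flip probabilities into a sum,
    which avoids numerical underflow when there are many data points.
    """
    return sum(math.log(p) if x == 1 else math.log(1 - p) for x in data)


# Compare two candidate models: a fair coin and a coin biased towards heads.
for p in (0.5, 0.7):
    print(p, round(log_likelihood(p, data), 3))
# The biased coin (p = 0.7) has the higher log likelihood, so it fits this data better.
```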

**Maximum value estimators**: A **maximum value estimator** is a statistical method that is used to estimate the value of a parameter that maximizes a given objective function. Examples of maximum value estimators include the maximum likelihood estimator and the maximum a posteriori estimator.

- The **maximum likelihood estimator** (MLE) is a method for estimating the parameters of a statistical model by choosing the parameter values that maximize the likelihood, i.e. the probability of the observed data given those parameters.
- The **maximum a posteriori (MAP) estimator** is a method for estimating the parameters of a statistical model by choosing the parameter values that maximize the posterior probability of the parameters given the data. By Bayes' theorem, the posterior is proportional to the likelihood of the data times the prior probability of the parameters, which encodes prior knowledge about them. The MAP estimator is often used in Bayesian inference, and it is a popular method for estimating the parameters of a model in the presence of prior knowledge.
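Continuing with a hypothetical coin-flip setup (7 heads in 10 flips), this sketch finds both estimates by brute-force search over a grid of values of p. The Beta(2, 2) prior is an assumption chosen for illustration; with it, the MAP estimate is pulled from the MLE of 7/10 toward 1/2, matching the closed forms heads / n and (heads + a - 1) / (n + a + b - 2) for this Bernoulli/Beta model.

```python
import math

# Hypothetical data: 7 heads out of 10 coin flips.
heads, n = 7, 10


def log_likelihood(p):
    """log P(data | p) for this sequence of independent flips."""
    return heads * math.log(p) + (n - heads) * math.log(1 - p)


def log_prior(p, a=2.0, b=2.0):
    """Log density of a Beta(a, b) prior, up to an additive constant.

    Beta(2, 2) is an assumed prior that mildly favors p near 0.5.
    """
    return (a - 1) * math.log(p) + (b - 1) * math.log(1 - p)


# Brute-force search over a grid of candidate parameters in (0, 1).
grid = [i / 1000 for i in range(1, 1000)]
p_mle = max(grid, key=log_likelihood)
# Posterior is proportional to likelihood times prior,
# so log posterior = log likelihood + log prior (+ a constant).
p_map = max(grid, key=lambda p: log_likelihood(p) + log_prior(p))

print(p_mle)  # 0.7    (= heads / n, the closed-form MLE)
print(p_map)  # 0.667  (closest grid point to (heads + 1) / (n + 2) = 8/12)
```

A grid search is only practical for one or two parameters; real models maximize the same objectives with gradient-based optimizers.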

**Random variables**: A **random variable** is a variable whose value is determined by the outcome of a random event. For example, the toss of a coin is a random event, and the number of heads that result from a series of coin tosses is a random variable.

**Central limit theorem**: The **central limit theorem** is a statistical theorem stating that, as the number of independent samples drawn from the same distribution increases, the distribution of the (suitably standardized) sample mean approaches a normal distribution, regardless of the shape of the underlying distribution, as long as it has finite variance.
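A quick simulation sketch of the theorem (our own example, standard library only): the mean of 100 fair die rolls, repeated 2000 times. A single roll is uniform over 1 to 6, which is far from normal, yet the sample means cluster around 3.5 with a standard deviation close to $\sigma/\sqrt{n}$, and roughly 68% of them fall within one standard deviation, as for a normal distribution. The exact printed numbers depend on the random seed.

```python
import random
import statistics

random.seed(0)  # fixed seed so the run is reproducible


def sample_mean(n):
    """Mean of n rolls of a fair six-sided die (a discrete uniform distribution)."""
    return sum(random.randint(1, 6) for _ in range(n)) / n


n, trials = 100, 2000
means = [sample_mean(n) for _ in range(trials)]

mu = statistics.mean(means)
sd = statistics.pstdev(means)
# Theory: the sample mean has mean 3.5 and standard deviation sigma / sqrt(n),
# with sigma = sqrt(35/12) for one die roll, i.e. about 0.171 here.
print(round(mu, 2), round(sd, 3))

# For a normal distribution, about 68% of values fall within one standard deviation.
within_1sd = sum(abs(m - mu) <= sd for m in means) / trials
print(round(within_1sd, 2))
```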

**Calculus basics**: **Calculus** is a branch of mathematics that deals with the study of rates of change and the accumulation of quantities. It is a fundamental tool in the study of functions and is used to model and solve problems in a variety of fields, including physics, engineering, and economics.

**Gradients**: In calculus, the **gradient** of a (scalar-valued, multivariate, differentiable) function is a vector that points in the direction in which the function increases most quickly. It is the vector of partial derivatives of the function with respect to each variable.
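To see the definition in action, this sketch (our own example function $f(x, y) = x^2 + 3xy$) compares the analytic gradient $(2x + 3y,\ 3x)$ with a central finite-difference estimate of each partial derivative:

```python
def f(x, y):
    return x ** 2 + 3 * x * y


def grad_f(x, y):
    """Analytic gradient: (df/dx, df/dy)."""
    return (2 * x + 3 * y, 3 * x)


def numerical_grad(func, point, h=1e-5):
    """Approximate each partial derivative with a central difference."""
    grad = []
    for i in range(len(point)):
        up = list(point)
        up[i] += h
        down = list(point)
        down[i] -= h
        grad.append((func(*up) - func(*down)) / (2 * h))
    return tuple(grad)


print(grad_f(1.0, 2.0))               # (8.0, 3.0)
print(numerical_grad(f, (1.0, 2.0)))  # approximately (8.0, 3.0)
```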

**The chain rule**: The **chain rule** is a fundamental rule of calculus that allows us to calculate the derivative of a composite function. It states that if f is a function of g, and g is a function of x, then the derivative of f with respect to x is equal to the derivative of f with respect to g times the derivative of g with respect to x. In other words, (df / dx) = (df / dg) * (dg / dx).
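A small numerical check of the chain rule (our own example, with f(g) = sin(g) and g(x) = x²): the chain-rule derivative cos(x²) · 2x should match a central-difference estimate of the derivative of the composite function.

```python
import math


def g(x):
    return x ** 2


def f(u):
    return math.sin(u)


def dfdx_chain(x):
    """Chain rule: df/dx = (df/dg) * (dg/dx) = cos(g(x)) * 2x."""
    df_dg = math.cos(g(x))  # derivative of sin, evaluated at g(x)
    dg_dx = 2 * x           # derivative of x**2
    return df_dg * dg_dx


def dfdx_numeric(x, h=1e-6):
    """Central-difference estimate of d/dx f(g(x)), for comparison."""
    return (f(g(x + h)) - f(g(x - h))) / (2 * h)


for x in (0.3, 1.0, 2.5):
    print(x, dfdx_chain(x), dfdx_numeric(x))
```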

On backpropagation:

Backpropagation is an algorithm for training artificial neural networks, which are machine learning models loosely inspired by the structure and function of the brain. It computes the gradients that are used (for example, by gradient descent) to adjust the weights and biases of the network in order to minimize the error between the predicted output and the desired output of the network.

The idea behind backpropagation is that a neural network, together with its error (loss), is a multivariate function of its inputs, weights, and biases, so we can use the chain rule to calculate the gradient of the error with respect to the weights and biases of the network. The gradient tells us how the error changes as we adjust the weights and biases, and we can use this information to update them in a way that reduces the error.

To understand why backpropagation is just the chain rule on multivariate functions, it's helpful to consider the structure of a neural network. A neural network consists of layers of interconnected nodes, each of which performs a calculation based on the inputs it receives from the previous layer. The output of the network is a function of the inputs, and the weights and biases of the network determine how the inputs are transformed as they pass through the layers of the network.

The process of backpropagation involves starting at the output layer of the network and working backwards through the layers, using the chain rule to calculate the gradients of the weights and biases at each layer. This is done by calculating the derivative of the error with respect to the output of each layer, and then using the chain rule to propagate these derivatives back through the layers of the network. This allows us to calculate the gradients of the weights and biases at each layer, which we can use to update the weights and biases in a way that minimizes the error.

Overall, backpropagation is an efficient way to train neural networks because, by working backwards, the derivatives computed for one layer are reused when computing the gradients for the layer before it, so the gradients of all the weights and biases are obtained in a single backward pass rather than being recomputed separately for each parameter. This enables us to adjust the weights and biases in a way that minimizes the error, which is essential for training the network.
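Here is a minimal backpropagation sketch in plain Python (our own toy network, not from the post): one input, one tanh hidden unit, one linear output, and squared-error loss. The `backward` function applies the chain rule from the loss back to each parameter, the gradients are checked against finite differences, and one gradient-descent step with an arbitrary learning rate of 0.1 is taken. The parameter values are hypothetical.

```python
import math


def forward(params, x):
    """Tiny network: one tanh hidden unit, one linear output unit."""
    w1, b1, w2, b2 = params
    z = w1 * x + b1
    h = math.tanh(z)
    y = w2 * h + b2
    return z, h, y


def loss_fn(params, x, t):
    """Squared error 0.5 * (y - t)^2 between prediction y and target t."""
    _, _, y = forward(params, x)
    return 0.5 * (y - t) ** 2


def backward(params, x, t):
    """Gradients of the loss w.r.t. (w1, b1, w2, b2), via the chain rule.

    Each line reuses the derivative computed on the line before it,
    working from the output back towards the input.
    """
    _, _, w2, _ = params
    _, h, y = forward(params, x)
    dL_dy = y - t                 # d/dy of 0.5 * (y - t)^2
    dL_dw2 = dL_dy * h            # y = w2 * h + b2
    dL_db2 = dL_dy
    dL_dh = dL_dy * w2
    dL_dz = dL_dh * (1 - h ** 2)  # d tanh(z)/dz = 1 - tanh(z)^2
    dL_dw1 = dL_dz * x            # z = w1 * x + b1
    dL_db1 = dL_dz
    return [dL_dw1, dL_db1, dL_dw2, dL_db2]


def numeric_grads(params, x, t, eps=1e-6):
    """Central finite differences, used only to sanity-check backward()."""
    grads = []
    for i in range(len(params)):
        up = list(params)
        up[i] += eps
        down = list(params)
        down[i] -= eps
        grads.append((loss_fn(up, x, t) - loss_fn(down, x, t)) / (2 * eps))
    return grads


# Hypothetical parameters (w1, b1, w2, b2) and a single training example.
params = [0.5, -0.2, 1.5, 0.1]
x, t = 0.8, 1.0

grads = backward(params, x, t)
max_err = max(abs(a - b) for a, b in zip(grads, numeric_grads(params, x, t)))
print("max gradient error:", max_err)  # tiny: backprop matches finite differences

# One gradient-descent step: move each parameter against its gradient.
lr = 0.1
new_params = [p - lr * g for p, g in zip(params, grads)]
print(loss_fn(params, x, t), loss_fn(new_params, x, t))  # the loss decreases
```

Libraries such as PyTorch automate exactly this bookkeeping for arbitrary networks, but the underlying computation is the same chain rule.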

hey Neel,

Great post!

I am trying to look into the code here

- Good (but hard) exercise: Code your own tiny GPT-2 and train it. If you can do this, I’d say that you basically fully understand the transformer architecture.
  - Example of basic training boilerplate and train script
  - The EasyTransformer codebase is probably good to riff off of here

But the links don't work anymore! It would be nice if you could help update them!

I don't know if this link works for the original content: https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/clean-transformer-demo/Clean_Transformer_Demo_Template.ipynb

Thanks a lot!

Ah, thanks! Haven't looked at this post in a while; updated it a bit. I've since made my own transformer tutorial which (in my extremely biased opinion) is better, especially for interpretability. It comes with a template notebook to fill out alongside part 2 (with tests!), and by the end you'll have implemented your own GPT-2.

More generally, my getting started in mech interp guide is a better place to start than this guide, and has more on transformers!

Co-authored by Neel Nanda and Jess Smith

Check out Concrete Steps for Getting Started in Mechanistic Interpretability for a better starting point.

Why does this exist?

People often get intimidated when trying to get into AI or AI Alignment research. People often think that the gulf between where they are and where they need to be is huge. This presents practical concerns for people trying to change fields: we all have limited time and energy. And for the most part, people wildly overestimate the actual core skills required.

This guide is our take on the essential skills required to understand, write code and ideally contribute useful research to mechanistic interpretability. We hope that it’s useful and unintimidating. :)

Core Skills:

- 3Blue1Brown or Linear Algebra Done Right
- Automate the Boring Stuff and then Beyond the Basic Stuff (both readable for free on inventwithpython.com, or purchasable as books); he's also written some books of exercises. If you prefer a more traditional textbook, Think Python 2e is excellent and also available freely online.
- https://github.com/rougier/numpy-100. Bonus points for doing them in pytorch on tensors :)
- fast.ai is a good intro, but a fair bit more effort than is necessary. For an 80/20, focus on Andrej Karpathy’s new video explaining neural nets: https://www.youtube.com/watch?v=VMj-3S1tku0
- PyTorch basics … real skill in programming.
- I highly, highly recommend learning how to use einops …
- … really important to deeply understand the architectures of the models you use, all of the moving parts inside of them, and how they fit together. In this case, the main architecture that matters is a transformer! (This is useful in normal ML too, but you can often get away with treating the model as a black box.)
  - what is a transformer and implementing GPT-2 From Scratch video tutorials
  - My transformer glossary/explainer
  - A worthwhile exercise is to fill out the template notebook accompanying the tutorial (no copying and pasting!)
  - Transformers for Software Engineers (also useful to non software engineers!)
  - the illustrated transformer
  - Jacob Hilton’s Deep learning for Alignment syllabus - this is a lot more content than you strictly need, but is well put together and likely a good use of time to go through at least some of!

Once you have the pre-reqs, my Getting Started in Mechanistic Interpretability guide goes into how to get further into mechanistic interpretability!

Note that there are a lot more skills in the “nice-to-haves”, but I think that generally the best way to improve at something is by getting your hands dirty and engaging with the research ideas directly, rather than making sure you learn every nice-to-have skill first - if you have the above, I think you should just jump in and start learning about the topic! Especially for the coding-related skills, your focus should not be on getting your head around concepts; it should be about doing, and actually writing code and playing around with things - the challenge of making something that actually works, and dealing with all of the unexpected practical problems that arise, is the best way of really getting this.