For neural networks trained to perfect loss, the broadness of optima in parameter space is given by the number and norm of independent orthogonal features the neural network has. The inner product that defines this "feature norm" and independence/orthogonality is the $L^2$ inner product of Hilbert space.
Introduction - why do we care about broadness?
Recently, there's been some discussion of what determines the broadness of optima of neural networks in parameter space.
People care about this because the size of the optimum basin may influence how easy it is for gradient descent (or other local optimisation algorithms which you might use to train a neural network) to find the corresponding solution, since it's probably easier to stumble across a broader optimum's attractor.
People also care about this because there are some hypotheses floating around saying that the broadness of optima is correlated with their generality somehow. The broader the basin of your solution is in parameter space, the more likely it seems to be that the network will successfully generalise to new data in deployment. So understanding what kinds of circuit configurations are broad seems like it might help us grasp better how AIs learn things, and how an AGI might really work internally.
Meaning, understanding what kind of solutions are broad is probably relevant to start developing general predictive theories of AI and AI training dynamics, which we'd like to have so we can answer questions like "What goals is my AGI going to develop, given its architecture, the loss function I'm training it with, and the data set I'm training it on?"
Warning: calculus and linear algebra incoming!
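Say we have a loss function $L(\theta) = \sum_x l(f(x, \theta), y(x))$ for a neural network $f$, with some data points $x$ and labels $y$, and we've trained the network to perfect loss, setting the parameters to some set $\theta^*$.
Now, let's try to look at how much the loss changes when we make small changes to the parameters $\theta$. To do this, we're going to perform a polynomial expansion of $L$ up to quadratic order. This isn't going to describe the function behaviour perfectly, but within a small neighbourhood of $\theta^*$, it'll be mostly accurate.
$$L(\theta^* + \delta\theta) \approx L(\theta^*) + \frac{1}{2}\,\delta\theta^T H\,\delta\theta, \qquad H_{ij} = \left.\frac{\partial^2 L}{\partial\theta_i\,\partial\theta_j}\right|_{\theta^*}$$
The first order term vanished because $\nabla_\theta L(\theta^*) = 0$ (since we're at an optimum). $\delta\theta$ is a vector with $N$ entries, specifying how much we're perturbing each parameter in the network from its value at the optimum $\theta^*$. $H$ is going to be a matrix with $N \times N$ entries.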
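To get a feel for how good the quadratic approximation is, here is a minimal sketch in plain Python. Everything in it is an assumption made for illustration: the toy two-parameter network, the data points, the optimum, and the squared-error loss $l = \frac{1}{2}(f - y)^2$. The Hessian is estimated by finite differences, and the true loss is then compared to $\frac{1}{2}\delta\theta^T H \delta\theta$ for perturbations of growing size.

```python
import math

XS = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]   # hypothetical data points
THETA_STAR = [1.5, 0.8]                  # hypothetical optimum

def f(x, theta):
    """Toy two-parameter network (an assumption): theta[0] * tanh(theta[1] * x)."""
    return theta[0] * math.tanh(theta[1] * x)

YS = [f(x, THETA_STAR) for x in XS]      # labels that THETA_STAR fits exactly

def loss(theta):
    """L(theta) = sum over x of 0.5 * (f(x, theta) - y(x))^2."""
    return sum(0.5 * (f(x, theta) - y) ** 2 for x, y in zip(XS, YS))

def hessian(fn, theta, eps=1e-4):
    """Central finite-difference estimate of the Hessian of fn at theta."""
    n = len(theta)
    H = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            def shifted(si, sj):
                t = list(theta)
                t[i] += si * eps
                t[j] += sj * eps
                return fn(t)
            H[i][j] = (shifted(1, 1) - shifted(1, -1)
                       - shifted(-1, 1) + shifted(-1, -1)) / (4 * eps ** 2)
    return H

def quadratic_estimate(H, delta):
    """0.5 * delta^T H delta. L(theta*) and the gradient are both zero here."""
    n = len(delta)
    return 0.5 * sum(delta[i] * H[i][j] * delta[j]
                     for i in range(n) for j in range(n))

HESSIAN = hessian(loss, THETA_STAR)

def relative_error(scale):
    """Relative gap between the true loss and its quadratic estimate."""
    delta = [scale, -scale]
    exact = loss([t + d for t, d in zip(THETA_STAR, delta)])
    return abs(exact - quadratic_estimate(HESSIAN, delta)) / exact

for scale in (0.01, 0.1, 0.5):
    print(scale, relative_error(scale))
```

For this toy setup the estimate should be close for the smallest perturbation and drift as the perturbation grows, which is the sense in which the expansion is only "mostly accurate" near the optimum.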
So far so good. Now, if we wanted to work out how broad the basin is by looking at this matrix, how would we do it?
The matrix is filled with second derivatives. Or in other words, curvatures.
If you only had one parameter, this matrix would be a single number (the second derivative), telling you the curvature of the parabola you can fit to the optimum. And the curvature of a parabola determines its broadness: the closer the curvature is to zero, the broader the parabola, and with it the optimum.
If you have two parameters, you can imagine the optimum as a little valley in the loss function landscape. The valley has two curvatures that characterise it now, but those don't need to line up with our coordinate axes $\theta_1$ and $\theta_2$. You could, for example, have one "principal direction" of curvature lie in the $\theta_1 + \theta_2$ direction of parameter space, and one in the $\theta_1 - \theta_2$ direction.
To find these principal directions and their curvatures, you perform an eigendecomposition of the matrix. The two eigenvalues of the matrix will be the curvatures, and the two eigenvectors the principal directions in parameter space characterised by these curvatures. Multiplying the parabola widths we get from the two, we obtain the basin volume.
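As a small worked case, here is a sketch for two parameters in plain Python. The Hessian $\begin{pmatrix} 2 & 1 \\ 1 & 2 \end{pmatrix}$ and the loss tolerance used to define a "width" are both assumptions chosen for illustration. The closed-form eigendecomposition of a symmetric 2x2 matrix gives the two curvatures and their principal directions, and the product of the widths gives the basin area up to a constant factor.

```python
import math

def eig_sym_2x2(a, b, c):
    """Eigenpairs of the symmetric matrix [[a, b], [b, c]], larger eigenvalue first."""
    mean = 0.5 * (a + c)
    radius = math.hypot(0.5 * (a - c), b)
    # Angle of the principal axis that belongs to the larger eigenvalue.
    angle = 0.5 * math.atan2(2.0 * b, a - c)
    v_big = (math.cos(angle), math.sin(angle))
    v_small = (-math.sin(angle), math.cos(angle))
    return (mean + radius, v_big), (mean - radius, v_small)

def basin_widths(eigenvalues, tolerance=0.01):
    """Half-width along each principal direction where 0.5 * lam * d^2 <= tolerance."""
    return [math.sqrt(2.0 * tolerance / lam) for lam in eigenvalues]

# Hypothetical Hessian [[2, 1], [1, 2]] of a two-parameter network at its optimum.
(lam_1, v_1), (lam_2, v_2) = eig_sym_2x2(2.0, 1.0, 2.0)
widths = basin_widths([lam_1, lam_2])
area_scale = widths[0] * widths[1]   # basin area up to a constant factor

print(lam_1, v_1)   # steep direction
print(lam_2, v_2)   # shallow direction
print(widths, area_scale)
```

In this example the principal directions come out along the two diagonals of parameter space rather than along the coordinate axes, and the direction with the smaller eigenvalue is the wider one.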
We can do this for the $N$-dimensional case as well. If some eigenvalues of the matrix are zero, that means that at least within a small neighbourhood of the optimum, the loss doesn't change at all if you walk along the corresponding principal direction. So each such eigenvalue effectively increases the dimensionality of the optimum, from a point to a line, a line to an area, etc.
Now that's all fine and good, but it's not very interpretable. What makes the eigenvalues of this matrix be small or large for any given network? What determines how many of them are zero? What does all this have to do with interpretability? Let's find out!
Broadness and feature space
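Let's explicitly work out what that matrix, which is also called the Hessian of the network, will look like, by calculating its entries in the $i$-th row and $j$-th column, meaning the derivative of the loss function with respect to the $i$-th and $j$-th parameters of the network:
$$H_{ij} = \frac{\partial^2 L}{\partial\theta_i\,\partial\theta_j} = \sum_x \left[ \frac{d^2 l}{df^2}\,\frac{\partial f}{\partial\theta_i}\,\frac{\partial f}{\partial\theta_j} + \frac{dl}{df}\,\frac{\partial^2 f}{\partial\theta_i\,\partial\theta_j} \right]$$
The second term vanishes again, because the loss function is required to be convex at optima and we're at perfect loss, so $\frac{dl}{df} = 0$ for every data point. That leaves $H_{ij} = \sum_x \frac{d^2 l}{df^2}\,\frac{\partial f}{\partial\theta_i}\,\frac{\partial f}{\partial\theta_j}$.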
Look at this expression for a bit. If you've done some quantum mechanics before, this might remind you of something.
Let's go through a small example. If you have a network that takes in a single floating point number from the input , and maps it to an output using features and a constant, as in , the gradient of the output would look like:
So, each entry in the gradient corresponds to one feature of the network.
The Hessian matrix for this network would be
These are the $L^2$ inner products of the features over the training data set. They tell you how close to orthogonal the features are with respect to each other, and how big they are.
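The claim that the Hessian at a perfect-loss optimum is the matrix of feature inner products can be checked numerically. The sketch below is plain Python under stated assumptions: the three features (a constant, $x$, and $x^2$) are hypothetical stand-ins rather than the ones from the example above, the data points and optimum are made up, and the loss is taken to be $l = \frac{1}{2}(f - y)^2$ so that $\frac{d^2 l}{df^2} = 1$.

```python
FEATURES = [lambda x: 1.0, lambda x: x, lambda x: x * x]   # hypothetical features
XS = [-1.0, -0.5, 0.0, 0.5, 1.0]                           # hypothetical data points
THETA_STAR = [0.3, -1.0, 2.0]                              # hypothetical optimum

def f(x, theta):
    """Network output: a weighted sum of the features."""
    return sum(t * phi(x) for t, phi in zip(theta, FEATURES))

YS = [f(x, THETA_STAR) for x in XS]                        # labels fitted exactly

def loss(theta):
    """Sum over the data of 0.5 * (f - y)^2, so d^2 l / d f^2 = 1."""
    return sum(0.5 * (f(x, theta) - y) ** 2 for x, y in zip(XS, YS))

def gram_matrix():
    """Entry (i, j) is sum_x phi_i(x) * phi_j(x): the inner product of two features."""
    return [[sum(phi_i(x) * phi_j(x) for x in XS) for phi_j in FEATURES]
            for phi_i in FEATURES]

def hessian(fn, theta, eps=1e-3):
    """Central finite-difference estimate of the Hessian of fn at theta."""
    n = len(theta)
    H = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            def shifted(si, sj):
                t = list(theta)
                t[i] += si * eps
                t[j] += sj * eps
                return fn(t)
            H[i][j] = (shifted(1, 1) - shifted(1, -1)
                       - shifted(-1, 1) + shifted(-1, -1)) / (4 * eps ** 2)
    return H

G = gram_matrix()
H = hessian(loss, THETA_STAR)
for row_g, row_h in zip(G, H):
    print(row_g, [round(v, 6) for v in row_h])
```

On this symmetric data set the feature $x$ has zero inner product with both the constant and $x^2$, so it is orthogonal to them, while the constant and $x^2$ overlap. The diagonal entries are the squared norms of the individual features.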
The idea of the $L^2$ inner product / norm and Hilbert space is that we treat the space of functions as a vector space, which can have basis vectors. The size and relative angles of these vectors are given by the inner product: you multiply the functions in question, and sum the result over the function domain.
This gives us a well-defined notion of what constitutes "orthogonal" features: two features are orthogonal if their inner product is zero. Intuitively, this means they're separate, non-interacting components of the network mapping.
This generalises to arbitrary numbers of features in arbitrary numbers of layers, though features in earlier layers enter multiplied with derivatives of activation functions, since we're characterising their effect on the final layer.
Each constant feature is treated as multiplying a constant of $1$. The Hessian of the network at the optimum contains the $L^2$ inner products of all the features in the network with each other.
So, back to our second order approximation of the loss function. When we find the eigendecomposition of our Hessian matrix, that corresponds to shifting into a basis in parameter space where all off-diagonals of the Hessian are zero. So, the eigenvectors of the Hessian correspond to linear combinations of network features that are orthogonal (in terms of the $L^2$ inner product) to each other.
Remember, the bigger the eigenvalues, which are the squared $L^2$ norms of these feature combinations, the bigger the curvature of the loss function is in that direction of parameter space, and the narrower the basin will be.
Eigenvalues of zero mean the network actually had fewer independent, orthogonal features inside it than it had parameters, giving you directions that are perfectly flat.
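A minimal sketch of such a flat direction, in plain Python. The setup is an assumption made for illustration: a network with two parameters but only one independent feature (the hypothetical features $x$ and $2x$), made-up data points, and the loss $l = \frac{1}{2}(f - y)^2$. The matrix of feature inner products then has one zero eigenvalue, and walking along the matching direction changes the parameters without changing the loss.

```python
import math

XS = [-1.0, 0.5, 2.0]                          # hypothetical data points
FEATURES = [lambda x: x, lambda x: 2.0 * x]    # hypothetical, linearly dependent features
THETA_STAR = [1.0, 0.25]                       # hypothetical optimum

def f(x, theta):
    """Two parameters, but both multiply a multiple of the same feature."""
    return theta[0] * FEATURES[0](x) + theta[1] * FEATURES[1](x)

YS = [f(x, THETA_STAR) for x in XS]            # labels fitted exactly

def loss(theta):
    """Sum over the data of 0.5 * (f - y)^2."""
    return sum(0.5 * (f(x, theta) - y) ** 2 for x, y in zip(XS, YS))

# Hessian at the optimum for this loss: inner products of the features.
G = [[sum(p(x) * q(x) for x in XS) for q in FEATURES] for p in FEATURES]

# Closed-form eigenvalues of the symmetric 2x2 matrix G.
mean = 0.5 * (G[0][0] + G[1][1])
radius = math.hypot(0.5 * (G[0][0] - G[1][1]), G[0][1])
eigenvalues = (mean + radius, mean - radius)

FLAT = (-2.0 / math.sqrt(5.0), 1.0 / math.sqrt(5.0))    # eigenvector for eigenvalue 0
STEEP = (1.0 / math.sqrt(5.0), 2.0 / math.sqrt(5.0))    # eigenvector for the other one

def loss_along(direction, step):
    """Loss after moving `step` away from the optimum along `direction`."""
    return loss([t + step * d for t, d in zip(THETA_STAR, direction)])

print(eigenvalues)
print(loss_along(FLAT, 1.0), loss_along(STEEP, 1.0))
```

Here the optimum is a line rather than a point: two parameters, one independent feature, one perfectly flat direction.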
So: the broadness of the basin is determined by the number and size of independent features in the network. At least within second order polynomial approximation range.
That sure sounds like it's pointing out the precise mathematical notion of network "complexity", "fine-tuning" or "generality" that's relevant for the basin broadness.
It also neatly quantifies and extends the formalism found by Vivek in this post. If a set of $k$ features all depend on different independent parts of the input, they are going to have an orthogonal basis of size $k$. If a set of $k$ features depends on the same parts of the input, they are a lot more likely to span a set of dimension less than $k$ in Hilbert space, meaning more zero eigenvalues. So, information loss about the input tends to lead to basin broadness.
Some potential implications
If we leave second order approximation range, features in different layers of the network no longer interact linearly, so you can't just straightforwardly treat them as one big set of Hilbert space basis vectors.
But features in the same layer can still be treated like this, and orthogonalised by taking the vector of neuron activations in the layer times its transpose, summed over the training dataset, to form a matrix of Hilbert norms like the one we got from the Hessian. Then, finding the eigenbasis of that matrix corresponds to finding a basis of linear combinations of features in the layer that are mutually orthogonal. Once the bases of two adjacent layers are found, the weight matrix connecting the two layers can also be transformed into the new basis.
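The change of basis for a single layer can be sketched in a few lines of plain Python. All specifics here are assumptions for illustration: a hypothetical layer with two correlated tanh neurons, made-up inputs, and a made-up outgoing weight matrix. The sketch only rotates the basis of one layer and rewrites the outgoing weights to match. It does not deal with the basis of the next layer or with the nonlinearity in between.

```python
import math

XS = [-2.0, -1.0, 0.0, 1.0, 2.0]            # hypothetical inputs
W_OUT = [[1.0, -1.0], [0.5, 2.0]]           # hypothetical weights into the next layer

def activations(x):
    """Hypothetical layer of two strongly correlated neurons."""
    return [math.tanh(x), math.tanh(x + 0.5)]

def gram(vectors):
    """Sum over the data set of the outer product v v^T."""
    n = len(vectors[0])
    return [[sum(v[i] * v[j] for v in vectors) for j in range(n)] for i in range(n)]

def eigenbasis_2x2(G):
    """Matrix Q whose columns are orthonormal eigenvectors of the symmetric 2x2 matrix G."""
    angle = 0.5 * math.atan2(2.0 * G[0][1], G[0][0] - G[1][1])
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s], [s, c]]

def matvec(M, v):
    return [sum(M[i][k] * v[k] for k in range(len(v))) for i in range(len(M))]

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(M):
    return [list(row) for row in zip(*M)]

acts = [activations(x) for x in XS]
Q = eigenbasis_2x2(gram(acts))
rotated = [matvec(transpose(Q), a) for a in acts]   # a' = Q^T a
W_rot = matmul(W_OUT, Q)                            # W' = W Q, so W' a' = W a

print(gram(acts))      # neuron basis: large off-diagonal overlap
print(gram(rotated))   # orthogonalised basis: off-diagonals vanish up to rounding
```

The signal sent to the next layer is unchanged, since $W' a' = W Q Q^T a = W a$. What changes is that the elements of the layer now have zero inner product with each other over the data set.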
This might give us a new perspective of the network. A basis in which each element in a layer might stand for one “independent” feature in that layer, instead of a neuron. Cancellations between the signals coming from different nodes could no longer occur. The connection weights might tell you exactly how much the network “cares” about the signals coming from each element.
This stands in contrast to the usual method (where we treat the neurons as the fundamental unit of the network, and use measures from network theory to try and measure clustering among neuron groups). As we've pointed out before, how big the weights between neurons are doesn't really indicate how much signal is flowing through their connection, making this a non-optimal strategy for interpreting your network.
This makes us suspect that the orthogonal feature basis of the layers might be close to the right framing to use, if we want to quantify how many "independent" calculations a neural network performs, and how it uses these calculations to make new calculations. In other words, this formalisation might provide the basis for the fundamentally motivated modularity measure we're after.
How might we test this hypothesis?
In our previous post, we discussed the problem of polysemanticity: when a single neuron seems to represent multiple different concepts which aren't naturally bucketed for humans. This problem suggests that individual neurons might not be the best unit of analysis, and we could do better to use features as the basic unit with which to formalise general theories of neural networks. A natural question follows - if we try out graphing the "feature flows" in neural networks with this, do polysemantic neurons mostly disappear and every orthogonalised feature means one, understandable thing?
We plan to find out soon, by experimenting with this change of basis and related ones in small neural networks. If you're interested in helping out with these experiments, or you have any other suggestions for useful experiments to run to test out these ideas, please let us know by commenting below, or sending a private LW message.
Note that we've written this with a sum (and have continued this convention for the rest of the post), but the formalisation would work equally well with an integral. Additionally, for the rest of the post we're assuming $f$ is a scalar to make the math a little less complicated, but the formalism should hold for other output types too.
How this works out when you're not at a perfect optimum and this term stays around is something I'm unsure of, and it's the reason we proposed experiment 8 in the experiment list on LW.
$\frac{dl}{df}$ is zero because we required a perfect loss score, and loss functions are required to be convex.
I wrote out the Hessian computation in a comment to one of Vivek's posts. I actually had a few concerns with his version and I could be wrong but I also think that there are some issues here. (My notation is slightly different because for me the sum over x was included in the function I called "l", but it doesn't affect my main point).
I think the most concrete thing is that the function f - i.e. the 'input-output' function of a neural network - should in general have a vector output, but you write things like
$$\frac{d^2 l}{df^2}$$
without any further explanation or indices. In your main computation it seems like it's being treated as a scalar.
Since we never change the labels or the dataset, one can drop the explicit dependence on y(x) from our notation for l. Then if the network has p neurons in the final layer, the codomain of the function f is $\mathbb{R}^p$ (unless I've misunderstood what you are doing?). So to my mind we have:
$$l(f(x)) = l(f_1(x), f_2(x), \ldots, f_p(x)).$$
Going through the computation in full using the chain rule (and a local minimum of the loss function l) one would get something like:
$$\mathrm{Hess}(L)(\theta) = \sum_x J_f(\theta)^T \left[\mathrm{Hess}(l)(f(\theta))\right] J_f(\theta)$$
Vivek wanted to suppose that Hess(l) were equal to the identity matrix, or a multiple thereof, which is the case for mean squared loss. But without such an assumption, I don't think that the term
$$\sum_x J_f(\theta)^T J_f(\theta)$$
appears (this is the matrix you describe as containing "the L2 inner products of the features over the training data set.")
Another (probably more important but higher-level) issue is basically: What is your definition of 'feature'? I could say: Have you not essentially just defined 'feature' to be something like 'an entry of $J_f(\theta)$'? Is the example not too contrived, in the sense that it clearly supposes that f has a very special form (in particular it is linear in the $\theta$ variables, so that the derivatives are not functions of $\theta$)?
(Moderation note: moved to the Alignment Forum from LessWrong.)