NTK/GP Models of Neural Nets Can't Learn Features

by interstice · 4 min read · 22nd Apr 2021 · 7 comments


Since people are talking about the NTK/GP hypothesis of neural nets again, I thought it might be worth bringing up some recent research in the area that casts doubt on their explanatory power. The upshot is: NTK/GP models of neural networks can't learn features. By 'feature learning' I mean the process where intermediate neurons come to represent task-relevant features such as curves, elements of grammar, or cats. Closely related to feature learning is transfer learning, the typical practice whereby a neural net is trained on one task, then 'fine-tuned' with a lower learning rate to fit another task, usually with less data than the first. This is often a powerful way to approach learning in the low-data regime, but NTK/GP models can't do it at all.

The reason for this is pretty simple. During training on the 'old task', NTK stays in the 'tangent space' of the network's initialization. This means that, to first order, none of the functions/derivatives computed by the individual neurons change at all; only the output function does.[1] Feature learning requires the intermediate neurons to adapt to structures in the data that are relevant to the task being learned, but in the NTK limit the intermediate neurons' functions don't change at all. Any meaningful function like a 'car detector' would need to be there at initialization -- extremely unlikely for functions of any complexity. This lack of feature learning implies a lack of meaningful transfer learning as well: since the NTK is just doing linear regression using an (infinite) fixed set of functions, the only 'transfer' that can occur is shifting where the regression starts in this space. This could potentially speed up convergence, but it wouldn't provide any benefits in terms of representation efficiency for tasks with few data points[2]. This property holds for the GP limit as well -- the distribution of functions computed by intermediate neurons doesn't change after conditioning on the outputs, so networks sampled from the GP posterior wouldn't be useful for transfer learning either.
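To make 'linear regression using a fixed set of functions' concrete, here is a minimal sketch in plain Python. It uses the empirical tangent kernel of a small finite-width net as a stand-in, not the analytic infinite-width NTK, and everything in it (the toy two-layer tanh net, the width of 500, the three training points, the helper names) is an assumption made up for illustration. The point is only that the feature map `grad_features` is computed once at initialization and never updated; fitting just solves a linear system on top of it.

```python
import math
import random

def init_net(width, seed=0):
    """Toy two-layer net f(x) = (1/sqrt(N)) * sum_i a_i * tanh(w_i * x)."""
    rng = random.Random(seed)
    w = [rng.gauss(0.0, 1.0) for _ in range(width)]
    a = [rng.gauss(0.0, 1.0) for _ in range(width)]
    return w, a

def output(x, w, a):
    return sum(ai * math.tanh(wi * x) for wi, ai in zip(w, a)) / math.sqrt(len(w))

def grad_features(x, w, a):
    """Gradient of the output w.r.t. all parameters at init: the fixed feature map."""
    s = 1.0 / math.sqrt(len(w))
    d_a = [s * math.tanh(wi * x) for wi in w]
    d_w = [s * ai * (1.0 - math.tanh(wi * x) ** 2) * x for wi, ai in zip(w, a)]
    return d_a + d_w

def kernel(f1, f2):
    """Empirical tangent kernel: inner product of parameter gradients."""
    return sum(u * v for u, v in zip(f1, f2))

def solve(A, b):
    """Gaussian elimination with partial pivoting (small dense systems only)."""
    n = len(b)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    sol = [0.0] * n
    for r in reversed(range(n)):
        sol[r] = (M[r][n] - sum(M[r][c] * sol[c] for c in range(r + 1, n))) / M[r][r]
    return sol

def fit_tangent_model(xs, ys, w, a):
    """Converged squared-loss fit of the linearized net: f0(x) + k(x, X) K^{-1} (y - f0(X))."""
    feats = [grad_features(x, w, a) for x in xs]
    K = [[kernel(fi, fj) for fj in feats] for fi in feats]
    resid = [y - output(x, w, a) for x, y in zip(xs, ys)]
    alpha = solve(K, resid)

    def predict(x):
        fx = grad_features(x, w, a)
        return output(x, w, a) + sum(al * kernel(fx, fi) for al, fi in zip(alpha, feats))

    return predict

w, a = init_net(width=500, seed=0)
train_x = [0.5, 1.0, 2.0]
train_y = [1.0, -1.0, 0.5]
predict = fit_tangent_model(train_x, train_y, w, a)
fitted = [predict(x) for x in train_x]
print(fitted)
```

The model interpolates the training targets, yet `w` and `a` (and hence every hidden unit's function) are exactly what they were at initialization; only the coefficients `alpha` depend on the data.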

This also makes me skeptical of the Mingard et al. result about SGD being equivalent to picking a random neural net with given performance, given that picking a random net is equivalent to running a GP regression in the wide-width limit. In particular, it makes me skeptical that this result will generalize to the complex models and tasks we care about. 'GP/NTK performs similarly to SGD on simple tasks' has been found before, but it tends to break down as the tasks become more complex.[3]

So are there any theoretical models of neural nets which are able to incorporate feature learning? Yes. In fact, there are a few candidate theories, of which I think Greg Yang's Tensor Programs is the best. I got all the above anti-NTK/GP talking points from him, specifically his paper Feature Learning in Infinite Width Neural Networks. The basic idea of this paper is pretty neat -- he derives a general framework for taking the 'infinite-width-limit' of 'tensor programs', general computation graphs containing tensors with a width parameter. He then applies this framework to SGD itself -- the successive iterates of SGD can be represented as just another type of computation graph, so the limit can be taken straightforwardly, leading to an infinite-width limit distinct from the NTK/GP one, and one in which the features computed by intermediate neurons can change. He also shows that this limit outperforms both finite-width nets and NTK/GP, and learns non-trivial feature embeddings. Two caveats: this 'tensor program limit' is much more difficult to compute than NTK/GP, so he's only actually able to run experiments on networks with very few layers and/or linear activations. And the scaling used to take the limit is actually different from that used in practice. Nevertheless, I think this represents the best theoretical attempt yet to capture the non-kernel learning that seems to be going on in neural nets.

To be clear, I think that the NTK/GP models have been a great advance in our understanding of neural networks, and it's good to see people on LW discussing them. However, there are some important phenomena they fail to explain. They're a good first step, but a comprehensive theoretical account of neural nets has yet to be written.[4]


  1. You might be wondering how it's possible for the output function to change if none of the individual neurons' functions change. Basically, since the output is the sum of N things, each of them only needs to change by O(1/N) to change the output by O(1), so they don't change at all in the wide-width limit. (See also my discussion with johnswentworth in the comments.) ↩︎

  2. Sort of. A more exact statement might be that the NTK can technically do transfer learning, but only trivially so, i.e. it can only 'transfer' to tasks to the extent that they are exactly the same as its original task. See this comment. ↩︎

  3. In fairness to the NTK/GP, they also haven't been tried as much on more difficult problems because they scale worse than neural nets in terms of data (D^2 * (kernel eval cost) in the number of data points D, since you need to compute the kernel between all pairs of points). So it's possible that they could do better if people had the chance to try them out more, iterate improved versions, and so on. ↩︎

  4. I'll confess that I would personally find it kind of disappointing if neural nets were mostly just an efficient way to implement some fixed kernels, when it seems possible that they could be doing something much more interesting -- perhaps even implementing something like a simplicity prior over a large class of functions, which I'm pretty sure NTK/GP can't be ↩︎

12 comments
I'll confess that I would personally find it kind of disappointing if neural nets were mostly just an efficient way to implement some fixed kernels, when it seems possible that they could be doing something much more interesting -- perhaps even implementing something like a simplicity prior over a large class of functions, which I'm pretty sure NTK/GP can't be

Wait, why can't NTK/GP be implementing a simplicity prior over a large class of functions? They totally are, it's just that the prior comes from the measure in random initialization space, rather than from the gradient update process. As explained here. Right? No?

There's an important distinction[1] to be made between these two claims:

A) Every function with large volume in parameter-space is simple

B) Every simple function has a large volume in parameter space

For a method of inference to qualify as a 'simplicity prior', you want both claims to hold. This is what lets us derive bounds like 'Solomonoff induction matches the performance of any computable predictor', since all of the simple, computable predictors have relatively large volume in the Solomonoff measure, so they'll be picked out after boundedly many mistakes. In particular, you want there to be an implication like, if a function has complexity , it will have parameter-volume at least .

Now, the Mingard results, at least the ones that have mathematical proof, rely on the Levin bound. This only shows (A), which is the direction that is much easier to prove -- it automatically holds for any mapping from parameter-space to functions with bounded complexity. They also have some empirical results that show there is substantial 'clustering', that is, there are some simple functions that have large volumes. But this still doesn't show that all of them do, and indeed is compatible with the learnable function class being extremely limited. For instance, this could easily be the case even if NTK/GP was only able to learn linear functions. In reality the NTK/GP is capable of approximating arbitrary functions on finite-dimensional inputs but, as I argued in another comment, this is not the right notion of 'universality' for classification problems. I strongly suspect[2] that the NTK/GP can be shown to not be 'universally data-efficient' as I outlined there, but as far as I'm aware no one's looked into the issue formally yet. Empirically, I think the results we have so far suggest that the NTK/GP is a decent first-order approximation for simple tasks that tends to perform worse on the more difficult problems that require non-trivial feature learning/efficiency.


  1. I actually posted basically the same thing underneath another one of your comments a few weeks ago, but maybe you didn't see it because it was only posted on LW, not the alignment forum ↩︎

  2. Basically, because in the NTK/GP limit the functions for all the neurons in a given layer are sampled from a single computable distribution, so I think you can show that the embedding is 'effectively finite' in some sense (although note it is a universal approximator for fixed input dimension). ↩︎

Ah, OK. Interesting, thanks. Would you agree with the following view:

"The NTK/GP stuff has neural nets implementing a "pseudosimplicity prior" which is maybe also a simplicity prior but might not be, the evidence is unclear. A pseudosimplicity prior is like a simplicity prior except that there are some important classes of Kolmogorov-simple functions that don't get high prior / high measure."

Which would you say is more likely: (a) the NTK/GP stuff is indeed not universally data efficient, and thus modern neural nets aren't either, or (b) the NTK/GP stuff is indeed not universally data efficient, and thus modern neural nets aren't well-characterized by the NTK/GP stuff?

During training on the 'old task', NTK stays in the 'tangent space' of the network's initialization. This means that, to first order, none of the functions/derivatives computed by the individual neurons change at all, only the output function does.

Eh? Why does this follow? Derivatives make sense; the derivatives staying approximately-constant is one of the assumptions underlying NTK to begin with. But the functions computed by individual neurons should be able to change for exactly the same reason the output function changes, assuming the network has more than one layer. What am I missing here?

The asymmetry between the output function and the intermediate neuron functions comes from backprop -- from the fact that the gradients are backprop-ed through weight matrices with entries of magnitude O(1/√N). So the gradient of the output w.r.t. itself is obviously 1, then the gradient of the output w.r.t. each neuron in the preceding layer is O(1/√N), since you're just multiplying by a vector with those entries. Then by induction all other preceding layers' gradients are the sum of N random things of size O(1/N), and so are of size O(1/√N) again. So taking a step of backprop will change the output function by O(1) but the intermediate functions by O(1/√N), vanishing in the large-width limit.

(This is kind of an oversimplification since it is possible to have changing intermediate functions while doing backprop, as mentioned in the linked paper. But this is the essence of why it's possible in some limits to move around using backprop without changing the intermediate neurons)
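A quick numerical sanity check of this scaling, as a rough sketch rather than anything from the papers: take a toy two-layer tanh net with a 1/sqrt(N) output scaling (an assumed stand-in for the NTK parameterization), take one gradient step on a single data point, and compare how far the output moves with how far a typical hidden pre-activation moves. The function name, learning rate, input, and target offset are all made up for illustration.

```python
import math
import random

def one_step_changes(width, seed=0, lr=1.0, x=1.0):
    """Return (|change in output|, mean |change in hidden pre-activation|) after one GD step."""
    rng = random.Random(seed)
    w = [rng.gauss(0.0, 1.0) for _ in range(width)]  # input-to-hidden weights
    a = [rng.gauss(0.0, 1.0) for _ in range(width)]  # hidden-to-output weights
    scale = 1.0 / math.sqrt(width)                   # assumed NTK-style output scaling

    def forward(w_cur, a_cur):
        h = [wi * x for wi in w_cur]
        f = scale * sum(ai * math.tanh(hi) for ai, hi in zip(a_cur, h))
        return h, f

    h0, f0 = forward(w, a)
    err = -1.0  # f0 - y, with the target chosen as y = f0 + 1 so the error is O(1)

    # Gradients of the loss 0.5 * (f - y)^2 with respect to each parameter.
    grad_a = [err * scale * math.tanh(hi) for hi in h0]
    grad_w = [err * scale * ai * (1.0 - math.tanh(hi) ** 2) * x for ai, hi in zip(a, h0)]

    w1 = [wi - lr * g for wi, g in zip(w, grad_w)]
    a1 = [ai - lr * g for ai, g in zip(a, grad_a)]
    h1, f1 = forward(w1, a1)

    mean_hidden_change = sum(abs(p - q) for p, q in zip(h1, h0)) / width
    return abs(f1 - f0), mean_hidden_change

for width in (100, 10000):
    d_out, d_hidden = one_step_changes(width)
    print(f"N={width}: output change {d_out:.3f}, mean hidden change {d_hidden:.5f}")
```

In this toy setup the output change stays of order 1 as the width grows, while the typical hidden-unit change shrinks roughly like 1/√N. It only checks a single step, so it says nothing by itself about what happens over many steps.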

Ok, that's at least a plausible argument, although there are some big loopholes. Main problem which jumps out to me: what happens after one step of backprop is not the relevant question. One step of backprop is not enough to solve a set of linear equations (i.e. to achieve perfect prediction on the training set); the relevant question is what happens after one step of Newton's method, or after enough steps of gradient descent to achieve convergence.

What would convince me more is an empirical result - i.e. looking at the internals of an actual NTK model, trying the sort of tricks which work well for interpreting normal NNs, and seeing how well they work. Just relying on proofs makes it way too easy for an inaccurate assumption to sneak in - like the assumption that we're only using one step of backprop. If anyone has tried that sort of empirical work, I'd be interested to hear what it found.

The result that NTK does not learn features in the large N limit is not in dispute at all -- it's right there on page 15 of the original NTK paper, and indeed holds after arbitrarily many steps of backprop. I don't think that there's really much room for loopholes in the math here. See Greg Yang's paper for a lengthy proof that this holds for all architectures. Also worth noting that when people 'take the NTK limit' they often don't initialize an actual net at all, they instead use analytical expressions for what the inner product of the gradients would be at N=infinity to compute the kernel directly.

Alright, I buy the argument on page 15 of the original NTK paper.

I'm still very skeptical of the interpretation of this as "NTK models can't learn features". In general, when someone proves some interesting result which seems to contradict some combination of empirical results, my default assumption is that the proven result is being interpreted incorrectly, so I have a high prior that that's what's happening here. In this case, it could be that e.g. the "features" relevant to things like transfer learning are not individual neuron activations - e.g. IIRC much of the circuit interpretability work involves linear combinations of activations, which would indeed circumvent this theorem.

This whole class of concerns would be ruled out by empirical results - e.g. experimental evidence on transfer learning with NTKs, or someone applying the same circuit interpretability techniques to NTKs which are applied to standard nets.

I don't think taking linear combinations will help, because adding terms to the linear combination will also increase the magnitude of the original activation vector -- e.g. if you add together N units, the magnitude of the sum of their original activations will with high probability be O(√N), dwarfing the O(1) change due to change in the activations. But regardless, it can't help with transfer learning at all, since the tangent kernel (which determines learning in this regime) doesn't change by definition.

What empirical results do you think are being contradicted? As far as I can tell, the empirical results we have are 'NTK/GP have similar performance to neural nets on some, but not all, tasks'. I don't think transfer/feature learning is addressed at all. You might say these results are suggestive evidence that NTK/GP captures everything important about neural nets, but this is precisely what is being disputed with the transfer learning arguments.

I can imagine doing an experiment where we find the 'empirical tangent kernel' of some finite neural net at initialization, solve the linear system, and then analyze the activations of the resulting network. But it's worth noting that this is not what is usually meant by 'NTK' -- that usually includes taking the infinite-width limit at the same time. And to the extent that we expect the activations to change at all, we no longer have reason to think that this linear system is a good approximation of SGD. That's what the above mathematical results mean -- the same mathematical analysis that implies that network training is like solving a linear system, also implies that the activations don't change at all.

They wouldn't be random linear combinations, so the central limit theorem estimate wouldn't directly apply. E.g. this circuit transparency work basically ran PCA on activations. It's not immediately obvious to me what the right big-O estimate would be, but intuitively, I'd expect the PCA to pick out exactly those components dominated by change in activations - since those will be the components which involve large correlations in the activation patterns across data points (at least that's my intuition).

I think this claim is basically wrong:

And to the extent that we expect the activations to change at all, we no longer have reason to think that this linear system is a good approximation of SGD.

There's a very big difference between "no change to first/second order" and "no change". Even in the limit, we do expect most linear combinations of the activations to change. And those are exactly the changes which would potentially be useful for transfer learning. And the tangent kernel not changing does not imply that transfer learning won't work, for two reasons: starting at a better point can accelerate convergence, and (probably more relevant) the starting point can influence the solution chosen when the linear system is underdetermined (which it is, if I understand things correctly).

I do think the empirical results pretty strongly suggest that the NTK/GP model captures everything important about neural nets, at least in terms of their performance on the original task. If that's true, and NTKs can't be used for transfer learning, then that would imply that transfer learning in normal nets works for completely different reasons from good performance on the original task, and that good performance on the original task has nothing to do with learning features. Those both strike me as less plausible than these proofs about "NTK not learning features" being misinterpreted.

(I also did a quick google search for transfer learning with NTKs. I only found one directly-relevant study, which is on way too small and simple a system for me to draw much of a conclusion from it, but it does seem to have worked.)

BTW, thanks for humoring me throughout this thread. This is really useful, and my understanding is updating considerably.

Hmm, so regarding the linear combinations, it's true that there are some linear combinations that will change by O(1) in the large-width limit -- just use the vector of partial derivatives of the output at some particular input; this sum will change by the amount that the output function moves during the regression. Indeed, I suspect (but don't have a proof) that these particular combinations will span the space of linear combinations that change non-trivially during training. I would dispute "we expect most linear combinations to change" though -- the CLT argument implies that we should expect almost all combinations to not appreciably change. Not sure what effect this would have on the PCA, and I still think it's plausible that it doesn't change at all (actually, I think Greg Yang states that it doesn't change in section 9 of his paper; I haven't read that part super carefully though).

And the tangent kernel not changing does not imply that transfer learning won’t work

So I think I was a bit careless in saying that the NTK can't do transfer learning at all -- a more exact statement might be "the NTK does the minimal amount of transfer learning possible". What I mean by this is, any learning algorithm can do transfer learning if the task we are 'transferring' to is sufficiently similar to the original task -- for instance, if it's just the exact same task but with a different data sample. I claim that the 'transfer learning' the NTK does is of this sort. As you say, since the tangent kernel doesn't change at all, the net effect is to move where the network starts in the tangent space. Disregarding convergence speed, the impact this has on generalization is determined by the values set by the old function on axes of the NTK outside of the span of the partial derivatives at the new function's data points. This means that, for the NTK to transfer anything from one task to another, it's not enough for the tasks to both feature, for instance, eyes. It's that the eyes have to correlate with the output in the exact same way in both tasks. Indeed, the transfer learning could actually hurt the generalization. Nor is its effect invariant under simple transformations like flipping the sign of the target function (this would change beneficial transfer to harmful). By default, for functions that aren't simple multiples, I expect the linear correlation between values on different axes to be about 0, even if the functions share many meaningful features. So while the NTK can do 'transfer learning' in a sense, it's about as weak as possible, and I strongly doubt that this sort of transfer is sufficient to explain transfer learning's successes in practice (but don't have empirical proof).
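One way to write down 'moving where the network starts in the tangent space' is the standard closed form for kernel regression from a nonzero starting function. This is a sketch under stated assumptions (squared loss, training run to convergence, a tangent kernel K that stays fixed across both tasks); the symbols f_A, f_B, X_B, y_B are labels introduced here, not notation from the papers.

```latex
% Assumptions: squared loss, gradient flow run to convergence, fixed tangent kernel K.
% f_A      : function reached after training on the old task
% X_B, y_B : inputs and targets of the new task
% f_B      : function reached after fine-tuning on the new task, starting from f_A
f_B(x) \;=\; f_A(x) \;+\; K(x, X_B)\, K(X_B, X_B)^{-1} \bigl( y_B - f_A(X_B) \bigr)
```

The old task enters only through f_A. The kernel K is the same one you would use training from scratch, so whatever is 'transferred' is just the part of f_A that the correction term leaves untouched away from the new data points.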

I do think the empirical results pretty strongly suggest that the NTK/GP model captures everything important about neural nets, at least in terms of their performance on the original task.

It's true that NTK/GP perform pretty closely to finite nets on the tasks we've tried them on so far, but those tasks are pretty simple and we already had decent non-NN solutions. Generally the pattern is 'GP matches NNs on really simple tasks, NTK on somewhat harder ones'. I think the data we have is consistent with this breaking down as we move to the harder problems that have no good non-NN solutions. I would be very interested in seeing an experiment with NTK on, say, ImageNet for this reason, but as far as I know no one's done so because of the prohibitive computational cost.

I only found one directly-relevant study, which is on way too small and simple a system for me to draw much of a conclusion from it, but it does seem to have worked.

Thanks for the link -- will read this tomorrow.

BTW, thanks for humoring me throughout this thread. This is really useful, and my understanding is updating considerably.

And thank you for engaging in detail -- I have also found this very helpful in forcing me to clarify (partially to myself) what my actual beliefs are.

Feature learning requires the intermediate neurons to adapt to structures in the data that are relevant to the task being learned, but in the NTK limit the intermediate neurons' functions don't change at all.
Any meaningful function like a 'car detector' would need to be there at initialization -- extremely unlikely for functions of any complexity.

I used to think it would be extremely unlikely for a randomly initialized neural net to contain a subnetwork that performs just as well as the entire neural net does after training. But the multi-prize lottery ticket results seem to show just that. So now I don't know what to think when it comes to what sorts of things are likely or unlikely when it comes to this stuff. In particular, is it really so unlikely that 'car detector' functions really do exist somewhere in the random jumble of a sufficiently big randomly initialized NN? Or maybe they don't exist right away, but with very slight tweaks they do?