There was a recent Twitter thread about this. See here and here.
Optimizing for the outcome metric alone on some training distribution, without any insight into the process producing that outcome, runs the risk that the system won’t behave as desired out of distribution. This is probably a serious concern for the system’s maintainers, even ignoring (largely externalized) X-risks.
Note that their improvement over Strassen on 4x4 matrices is for finite fields only, i.e. modular arithmetic, not what most neural networks use.
Here's a straightforward argument that phase changes are an artifact of AdamW and won't be seen with SGD (or SGD with momentum).
Suppose we have 101 weights all initialized to 0 in a linear model, and two possible ways to fit the training data:
The first is M=[1,1,…,1,1,0]. (It sets the first 100 weights to 1, and the last one to 0.)
The second is G=[0,0,…,0,0,1]. (It sets the first 100 weights to 0, and the last one to 1.)
(Any combination aM+bG with a+b=1 will also fit the training data.)
Intuitively, the first solution M memorizes the training data: we can imagine that each of the first 100 weights corresponds to storing the value of one of the 100 samples in our training set. The second solution G is a simple, generalized algorithm for solving all instances from the underlying data distribution, whether in the training set or not.
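A minimal numpy sketch of this setup (the feature encoding is my own assumption: sample i gets a one-hot "identity" feature plus a shared feature that the true rule uses). Both M and G fit the training set exactly, but only G handles a fresh sample:

```python
import numpy as np

n = 100  # training samples
# assumed encoding: sample i has a one-hot identity feature i
# plus a shared feature (index 100) that the true rule uses
X = np.eye(n, n + 1)
X[:, n] = 1.0
y = np.ones(n)  # every sample's target is 1

M = np.append(np.ones(n), 0.0)   # memorizer: one weight per training sample
G = np.append(np.zeros(n), 1.0)  # generalizer: uses only the shared feature

print(np.allclose(X @ M, y), np.allclose(X @ G, y))  # → True True

# a fresh sample has a brand-new identity feature, so to this model it
# looks like all zeros except the shared feature; only G gets it right
x_new = np.append(np.zeros(n), 1.0)
print(M @ x_new, G @ x_new)  # → 0.0 1.0 (M fails off-distribution)
```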
M has an L2 norm ten times as large as G's. SGD, since it follows the gradient directly, will mostly move toward G, as that is the direction of steepest descent. It will ultimately converge on the minimum-norm solution (1/101)M + (100/101)G. (Momentum won't change this picture much, since it just smears each SGD step out over multiple updates, and each individual SGD step goes in the direction of steepest descent.)
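A quick check of this claim, using full-batch gradient descent as a stand-in for SGD (same limit here) on an assumed toy encoding of the 101-weight model:

```python
import numpy as np

n = 100
X = np.eye(n, n + 1)  # assumed encoding: one-hot identity feature per sample
X[:, n] = 1.0         # plus a shared feature the general rule uses
y = np.ones(n)

w = np.zeros(n + 1)
lr = 0.1
for _ in range(2000):  # full-batch gradient descent on squared loss
    w -= lr * X.T @ (X @ w - y) / n

# from a zero init, (S)GD converges to the minimum-norm interpolator
print(round(w[0], 4), round(w[n], 4))  # → 0.0099 0.9901, i.e. 1/101 and 100/101
```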
AdamW, on the other hand, behaves just like Adam at first, since weight decay does little while the weights are small. Since Adam is a scale-invariant, coordinate-wise adaptive learning rate algorithm, it will move at the same speed on each of the 101 coordinates in whichever direction reduces loss, heading toward the solution (1/2)M + (1/2)G, i.e. with heavy weight on the memorization solution. Weight decay will start to kick in a bit before this point, and over time AdamW will converge to (close to) the same minimum-norm solution (1/101)M + (100/101)G as SGD. This is the phase transition from memorization to generalization.
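The early phase can be seen directly with a hand-rolled plain Adam (weight decay is negligible at this stage, so it is omitted) on the same assumed toy encoding. Because Adam's per-coordinate normalization is scale-invariant, the lone "general" coordinate moves no faster than the 100 "memorization" coordinates despite its much larger raw gradient:

```python
import numpy as np

n = 100
X = np.eye(n, n + 1)  # assumed encoding: one-hot identity feature per sample
X[:, n] = 1.0         # plus a shared feature the general rule uses
y = np.ones(n)

w = np.zeros(n + 1)
m = np.zeros_like(w)  # first-moment estimate
v = np.zeros_like(w)  # second-moment estimate
lr, b1, b2, eps = 1e-3, 0.9, 0.999, 1e-8

for t in range(1, 4001):  # plain Adam with bias correction
    g = X.T @ (X @ w - y) / n
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    w -= lr * (m / (1 - b1**t)) / (np.sqrt(v / (1 - b2**t)) + eps)

# every coordinate has moved at the same speed, landing near (1/2)M + (1/2)G
print(round(w[0], 2), round(w[n], 2))  # both ≈ 0.5
```

From here, adding decoupled weight decay and running much longer is what (per the argument above) slowly drags the iterate along the solution manifold toward the minimum-norm solution.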