The authors apply an AlphaZero-like algorithm to discover new matrix multiplication algorithms. They do this by turning matrix multiplication into a one-player game. The state is a residual tensor representing how far from correct the current algorithm is. Each move subtracts a rank-one tensor, which corresponds to one scalar multiplication in the algorithm. The reward is -1 per step, plus a terminal reward of -rank(final state) if the final state is not a zero tensor. On small matrices, they find that AlphaTensor can discover algorithms that use fewer scalar multiplications than the best known human-designed matrix multiplication algorithms. They also use it to find hardware-specific matmuls (by adding a terminal reward equal to minus the measured runtime) that achieve a 10-20% larger speedup than Strassen's algorithm on NVIDIA V100s and TPU v2s (saving 4%/7.5% of wall clock time).
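A minimal sketch of the game described above. It assumes integer coefficients and a row-major flattening of A, B and C. That indexing is my own choice, and the paper's convention may differ. The names `matmul_tensor`, `STRASSEN`, `play` and `is_zero` are hypothetical. The move sequence is Strassen's well-known seven-term decomposition of the 2x2 matmul tensor, and the code checks that the residual reaches zero after 7 moves, for a total reward of -7. The terminal rank penalty is not modeled.

```python
from itertools import product


def matmul_tensor(n):
    """Return the (n*n) x (n*n) x (n*n) matmul tensor as nested lists.

    Entry [a][b][c] is 1 when A-entry a times B-entry b contributes to C-entry c,
    using row-major flattening: a = i*n + k, b = k*n + j, c = i*n + j.
    """
    size = n * n
    t = [[[0] * size for _ in range(size)] for _ in range(size)]
    for i, j, k in product(range(n), repeat=3):
        t[i * n + k][k * n + j][i * n + j] = 1
    return t


# Strassen's seven rank-one terms (u, v, w). Coefficient vectors index
# A, B and C flattened as [X11, X12, X21, X22]; each triple is one move,
# i.e. one scalar multiplication M_r = (u . A)(v . B) added to C with weights w.
STRASSEN = [
    ([1, 0, 0, 1], [1, 0, 0, 1], [1, 0, 0, 1]),    # M1 = (A11+A22)(B11+B22)
    ([0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 1, -1]),   # M2 = (A21+A22)B11
    ([1, 0, 0, 0], [0, 1, 0, -1], [0, 1, 0, 1]),   # M3 = A11(B12-B22)
    ([0, 0, 0, 1], [-1, 0, 1, 0], [1, 0, 1, 0]),   # M4 = A22(B21-B11)
    ([1, 1, 0, 0], [0, 0, 0, 1], [-1, 1, 0, 0]),   # M5 = (A11+A12)B22
    ([-1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 0, 1]),   # M6 = (A21-A11)(B11+B12)
    ([0, 1, 0, -1], [0, 0, 1, 1], [1, 0, 0, 0]),   # M7 = (A12-A22)(B21+B22)
]


def play(state, moves):
    """Apply moves to a copy of the residual tensor; reward is -1 per move."""
    size = len(state)
    residual = [[row[:] for row in plane] for plane in state]
    reward = 0
    for u, v, w in moves:
        # Subtract the rank-one tensor u (x) v (x) w from the residual.
        for a, b, c in product(range(size), repeat=3):
            residual[a][b][c] -= u[a] * v[b] * w[c]
        reward -= 1
    return residual, reward


def is_zero(t):
    """The game ends successfully when the residual tensor is all zeros."""
    return all(x == 0 for plane in t for row in plane for x in row)


residual, reward = play(matmul_tensor(2), STRASSEN)
print(is_zero(residual), reward)  # True -7
```

Stopping after only six of the seven moves leaves a nonzero residual. In this framing, a shorter winning game means an algorithm that needs fewer multiplications.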

Paper abstract:

Improving the efficiency of algorithms for fundamental computations can have a widespread impact, as it can affect the overall speed of a large amount of computations. Matrix multiplication is one such primitive task, occurring in many systems—from neural networks to scientific computing routines. The automatic discovery of algorithms using machine learning offers the prospect of reaching beyond human intuition and outperforming the current best human-designed algorithms. However, automating the algorithm discovery procedure is intricate, as the space of possible algorithms is enormous. Here we report a deep reinforcement learning approach based on AlphaZero for discovering efficient and provably correct algorithms for the multiplication of arbitrary matrices. Our agent, AlphaTensor, is trained to play a single-player game where the objective is finding tensor decompositions within a finite factor space. AlphaTensor discovered algorithms that outperform the state-of-the-art complexity for many matrix sizes. Particularly relevant is the case of 4 × 4 matrices in a finite field, where AlphaTensor’s algorithm improves on Strassen’s two-level algorithm for the first time, to our knowledge, since its discovery 50 years ago. We further showcase the flexibility of AlphaTensor through different use-cases: algorithms with state-of-the-art complexity for structured matrix multiplication and improved practical efficiency by optimizing matrix multiplication for runtime on specific hardware. Our results highlight AlphaTensor’s ability to accelerate the process of algorithmic discovery on a range of problems, and to optimize for different criteria.


I'm surprised they got a paper out of this. The optimization problem they're solving isn't actually that hard at small sizes (like the example in DeepMind's post) and does not require deep learning. I played around with it a few years ago using a vanilla solver from scipy and found similar results. I assume nobody bothered to publish results like the ones DeepMind found because they don't yield a big-O speedup on recursive decompositions compared to just using Strassen's algorithm. That is why I never bothered writing up the results from my own playing around.

[ETA: actually they did find a big-O speedup over Strassen, see Paul below.]

Computationally brute-forcing the optimization problem for Strassen's algorithm certainly isn't a new idea, and it doesn't look like the deep learning part actually made any difference. Which isn't surprising; IIUC researchers in the area generally expect that practically-useful further big-O improvements on matmul will need a different functional form from Strassen (and therefore wouldn't be in the search space of the optimization problem for Strassen-like algorithms). The Strassen-like optimization problem has been pretty heavily searched for decades now.

Their improved 4x4 matrix multiplication algorithm does yield improved asymptotics compared to just using Strassen's algorithm. They do 47 multiplies for 4x4 matrix multiplication, so after log(n)/log(4) rounds of decomposition you get a runtime of 47^(log(n) / log(4)), which is less than 7^(log(n) / log(2)) from Strassen.
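The arithmetic in this comparison can be checked directly. Applying a k x k algorithm with r multiplications recursively gives a runtime of about n^(log r / log k). The helper name below is hypothetical. Two levels of Strassen on 4x4 blocks take 7^2 = 49 multiplications, which gives the same exponent as plain Strassen. The 47 figure comes from this comment; as the thread notes further down, it holds only over Z/2Z.

```python
import math


def exponent(block_size, multiplications):
    """Exponent e in n**e from recursively applying a block_size x block_size
    algorithm that uses `multiplications` block multiplications per level."""
    return math.log(multiplications) / math.log(block_size)


strassen_exp = exponent(2, 7)             # log2(7), about 2.807
strassen_two_level = exponent(4, 49)      # 49 = 7**2, so identical exponent
alphatensor_f2 = exponent(4, 47)          # about 2.777 (Z/2Z only)
print(f"{strassen_exp:.3f} {strassen_two_level:.3f} {alphatensor_f2:.3f}")
# prints 2.807 2.807 2.777
```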

(ETA: this improved 4x4 algorithm is only over Z/2Z; it's not relevant for real matrix multiplies. This was a huge misreading. They also show an improvement for some rectangular matrices and 11 x 11 matrix multiplication, but those don't represent asymptotic improvements; they just deal with awkwardness from e.g. 11 being prime. They do show an improvement in measured running time for 8k x 8k real multiplies on V100, but that seems like it's just weird XLA stuff rather than anything that has been aggressively optimized by the CS community. And I'd expect there to be crazier algorithms over Z/2Z that work "by coincidence." So overall I think your comment was roughly as right as this one.)

Of course this is not state-of-the-art asymptotically, because we know of other, bigger improvements over Strassen for sufficiently large matrices. I'm not sure what you mean by "different functional form from Strassen", but it is known that you can approximate the matrix multiplication exponent arbitrarily well by recursively applying an efficient matrix multiplication algorithm for constant-sized matrices.
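The standard argument behind this claim is sketched below in the usual notation. The symbols are not used elsewhere in the thread: ω is the matrix multiplication exponent, R is tensor rank, and ⟨k,k,k⟩ is the k x k matmul tensor. A k x k algorithm with r block multiplications, applied recursively, satisfies the recurrence below. Strassen is the case k = 2, r = 7. The last line says that ω is determined by the ranks of fixed-size cases.

```latex
T(n) = r\,T(n/k) + O(n^2) \;\Longrightarrow\; T(n) = O\!\left(n^{\log_k r}\right) \quad (r > k^2)
\omega \le \log_k r \qquad \text{e.g. } \omega \le \log_2 7 \approx 2.807
\omega = \inf_{k \ge 2} \log_k R(\langle k,k,k \rangle)
```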

People do use computer-assisted search to find matrix multiplication algorithms, and as you say the optimization problem has been studied extensively. As far as I can tell the results in this paper are better than anything that is known for 4x4 or 5x5 matrices, and I think they give the best asymptotic performance of any explicitly known multiplication algorithm on small matrices. I might be missing something, but if not then I'm quite skeptical that you got anything similar.

As I mentioned, we know better algorithms for sufficiently large matrices. But for 8k x 8k matrices in practice I believe that 1-2 rounds of Strassen is state of the art. It looks like the 47-multiply algorithm they find is not better than Strassen in practice on GPUs at that scale because of the cost of additions and other practical considerations. But they also do an automated search based on measured running time rather than tensor rank alone, and they claim to find an algorithm that is ~4% faster than their reference implementation of Strassen for 8k matrices on a V100 (which is itself ~4% faster than the naive matmul).
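To make "1-2 rounds of Strassen" concrete, here is an illustrative pure-Python sketch. It applies Strassen's 2x2 block scheme for a fixed number of levels and then falls back to the ordinary product. Cutting the recursion off like this is common in practical implementations. All function names are hypothetical, and real GPU kernels look nothing like this. The sketch assumes square matrices whose size is divisible by 2**levels.

```python
import random


def add(X, Y):
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def sub(X, Y):
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def naive_matmul(X, Y):
    """Schoolbook O(n^3) product, used once the recursion stops."""
    cols = list(zip(*Y))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in X]


def split(M):
    """Split an even-sized square matrix into its four quadrants."""
    h = len(M) // 2
    return ([r[:h] for r in M[:h]], [r[h:] for r in M[:h]],
            [r[:h] for r in M[h:]], [r[h:] for r in M[h:]])


def strassen(A, B, levels=1):
    """Apply `levels` rounds of Strassen (7 block products instead of 8)."""
    if levels == 0:
        return naive_matmul(A, B)

    def rec(X, Y):
        return strassen(X, Y, levels - 1)

    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    M1 = rec(add(A11, A22), add(B11, B22))
    M2 = rec(add(A21, A22), B11)
    M3 = rec(A11, sub(B12, B22))
    M4 = rec(A22, sub(B21, B11))
    M5 = rec(add(A11, A12), B22)
    M6 = rec(sub(A21, A11), add(B11, B12))
    M7 = rec(sub(A12, A22), add(B21, B22))
    C11 = add(sub(add(M1, M4), M5), M7)
    C12 = add(M3, M5)
    C21 = add(M2, M4)
    C22 = add(add(sub(M1, M2), M3), M6)
    # Reassemble the four quadrants into one matrix.
    top = [r1 + r2 for r1, r2 in zip(C11, C12)]
    bottom = [r1 + r2 for r1, r2 in zip(C21, C22)]
    return top + bottom


random.seed(0)
A = [[random.randint(-5, 5) for _ in range(8)] for _ in range(8)]
B = [[random.randint(-5, 5) for _ in range(8)] for _ in range(8)]
print(strassen(A, B, levels=2) == naive_matmul(A, B))  # True (exact integers)
```

Each level saves one of eight block multiplications but adds 18 block additions and subtractions. That extra work is why the practical gains at 8k sizes are only a few percent.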

This search also probably used more compute than existing results, and that may be a more legitimate basis for a complaint. I don't know if they report comparisons between RL and more-traditional solvers at massive scale (I didn't see one on a skim, and I do imagine such a comparison would be hard to run). But I would guess that RL adds value here.

From their results it also looks like this is probably practical in the near future. The largest win is probably not the 4% speedup but the ability to automatically capture Strassen-like gains while being able to handle annoying practical performance considerations involved in optimizing matmuls in the context of a particular large model. Those gains have been historically small, but language models are now large enough that I suspect it would be easily worth doing if not for the engineering hassle of writing the new kernels. It looks to me like it would be worth applying automated methods to write new special-purpose kernels for each sufficiently large LM training run or deployment (though I suspect there are a bunch of difficulties to get from a tech demo to something that's worth using). (ETA: and I didn't think about precision at all here, and even a 1 bit precision loss could be a dealbreaker when we are talking about ~8% performance improvements.)

Note that their improvement over Strassen on 4x4 matrices is for finite fields only, i.e. modular arithmetic, not what most neural networks use.

That's a very important correction. For real arithmetic they only improve for rectangular matrices (e.g. 3x4 multiplied by 4x5) which is less important and less well studied.

In fact, the 47-multiplication result is over Z/2Z (arithmetic mod 2), so it's not even general modular arithmetic.

That being said, there are still speedups on standard floating point arithmetic, both in the number of multiplications and in wall clock time.

Ah, guess I should have looked at the paper. I foolishly assumed that if they had an actual big-O improvement over Strassen, they'd mention that important fact in the abstract and blog post.

Any improvement over Strassen on 4x4 matrices represents an asymptotic improvement over Strassen.

ETA: this claim is still right, but they only asymptotically improve Strassen over F2, not for normal matrix multiplies. So your intuition about what they would definitely mention in the abstract / blog post may well be right.

Indeed. Unfortunately, I didn't catch that when skimming.

(Most of my comment was ninja'ed by Paul)

I'll add that I'm pretty sure that RL is doing something. The authors claim that no one has applied search methods for 4x4 matrix multiplication or larger, and the branching factor on brute force search without a big heuristic grows something like the 6th power of n? So it seems doubtful that brute-force search will scale.
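A rough illustration of why unguided search struggles. As an assumption, suppose each entry of u, v and w is drawn from {-2, -1, 0, 1, 2}. I believe this is close to the coefficient set used in the paper, but treat it as an assumption. Under that assumption, the number of distinct single moves for n x n matmul is 5^(3n^2), while n^6 is the number of entries in the tensor itself. The helper names are hypothetical.

```python
def action_space_size(n, num_coeffs=5):
    """Distinct (u, v, w) moves for n x n matmul when each of the 3 * n*n
    factor entries takes one of num_coeffs values (assumed coefficient set)."""
    return num_coeffs ** (3 * n * n)


def tensor_entries(n):
    """Entries in the (n*n) x (n*n) x (n*n) matmul tensor, i.e. n**6."""
    return (n * n) ** 3


for n in (2, 3, 4, 5):
    print(n, tensor_entries(n), f"{action_space_size(n):.2e}")
```

Even a single step at 4x4 then has roughly 10^33 options. A full algorithm needs dozens of such steps.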

That being said, I agree that it's a bit odd not to do a head-to-head comparison at equal compute. The authors just cite related work (which uses much less compute) and claim superiority over it.

Beyond that, it seems TensorFlow and PyTorch don't even bother to use Strassen's algorithm over N^3 matrix multiplication (or perhaps something Strassen-like is used in the low-level GPU circuits?).

See also Scott Aaronson on experimental computational complexity theory (haha it's a joke wait no maybe he's not joking wait what?)

The meeting ended with a “Wild & Crazy Ideas Session,” at which I (naturally) spoke. I briefly considered talking about quantum gravity computing, closed timelike curves, or quantum anthropic postselection, but ultimately decided on something a little less mainstream. My topic was “Experimental Computational Complexity Theory,” or “why do theoretical physicists get $8-billion machines for the sole purpose of confirming or refuting their speculative ideas, whereas theoretical computer scientists get diddlysquat?” More concretely, my proposal is to devote some of the world’s computing power to an all-out attempt to answer questions like the following: does computing the permanent of a 4-by-4 matrix require more arithmetic operations than computing its determinant?