TL;DR: This is the fourth main post of Distilling Singular Learning Theory, which is introduced in DSLT0. I explain how to relate SLT to thermodynamics, and therefore how to think about phases and phase transitions in the posterior in statistical learning. I then provide intuitive examples of first and second order phase transitions in a simple loss function. Finally, I experimentally demonstrate phase transitions in two layer ReLU neural networks associated to the node-degeneracy and orientation-reversing phases established in DSLT3, which we can understand precisely through the lens of SLT.
In deep learning, the terms "phase" and "phase transition" are often used informally to refer to a steep change in a metric we care about, like the training or test loss, as a function of SGD steps, or alternatively of some hyperparameter like the number of samples from the truth $n$.
But what exactly are the phases? And why do phase transitions even occur? SLT provides us with a solid theoretical framework for understanding phases and phase transitions in deep learning. In this post, we will argue that in the Bayesian setting,
A phase of the learning process corresponds to a singularity of the set of true parameters $W_0$, and a phase transition corresponds to a drastic change in the posterior as a function of a hyperparameter $t$.
The hyperparameter $t$ could be the number of samples from the truth $n$, some way of varying the model function $f(x, w)$, or something about the true distribution $q(y \mid x)$, amongst other things. At some critical value $t_c$, we recognise a phase transition as being a discontinuous change in the free energy $F_n$ or one of its derivatives, for example the generalisation error $G_n$.
In this post, we will present experiments that observe precise phase transitions in the toy neural network models we studied in DSLT3, for which we understand the set of true parameters $W_0$ and therefore the phases. By the end of this post, you will have a framework for thinking about phase transitions in singular models and an intuition for why SLT predicts them to occur in learning.
This subsection is modelled on [Callen, Ch9], but it is only intended to be a high level discussion of the concepts grounded in some basic physics - don't get too bogged down in the details of the thermodynamics.
Fundamentally, a phase describes an aggregate state of a complex system of many interacting components, where the state retains particular qualities with variations in some hyperparameter. To explain the concept in detail, it is natural to start in physics (thermodynamics in particular), where these ideas originally arose. But there is a deeper reason to build from here: every human has an intuitive understanding of the phases of water and how they change with temperature [1], which serves as the base mental model for what a phase is.
One of the main goals of thermodynamics is to study how the equilibrium state of a system changes as a function of macroscopic parameters. In the case of a vessel of water at 1 atm of pressure in constant contact with a thermal and pressure reservoir, the equilibrium state of the system is the state that minimises the Gibbs free energy [2]. The phases, then, are the equilibrium states, which describe qualitative physical properties of the system. The states of matter - solid, liquid, and gas - are all phases of water, which are characterised by variables like their volume and crystal structure. As anybody who has boiled water before knows, these phases undergo transitions as a function of temperature. Let's make this more precise.
Consider a system of $N$ water molecules moving in a 2D container, each with equal mass $m$. To each particle we can associate a set of microstates describing its physical properties at a point in time, for example its position $x_i$ and its velocity $v_i$. In our discussion we will simply focus on the position, which we will relabel $w_i$ (for reasons that will become clear), so our configuration space of possible microstates is
$$W = \{ w = (w_1, \ldots, w_N) \},$$
where each $w_i$ ranges over the points of the container.
Since it is physically infeasible to know or model the positions of all molecules, we instead reason about the dynamics of the system by calculating macroscopic variables associated to a microstate, for example the temperature or the total volume of the molecules. We will focus on the volume $V(w)$ of a microstate $w$. Importantly, a macroscopic state is an aggregate over the system (for example, temperature being related to average squared velocity), meaning there are many possible configurations of microstates that result in the same macrostate. To this end, we can define regions of our configuration space according to their volume $V$,
$$W_V = \{ w \in W \, : \, V(w) = V \}.$$
In our toy example, we want to study how the system changes as a function of temperature, which we will denote by $T$. In a Gibbs ensemble, we can associate an energy functional, the Hamiltonian $H_T(w)$, to any given microstate $w$ at temperature $T$. The fundamental postulate of such a Gibbs ensemble is that the probability of the system being in a particular microstate is determined by a Gibbs distribution [3]
$$p(w) = \frac{e^{-H_T(w)}}{Z_T}, \qquad \text{where} \qquad Z_T = \int_W e^{-H_T(w)} \, dw.$$
This should look pretty familiar from our statistical learning setup! Indeed, we can then calculate the free energy of the ensemble for different volumes $V$ at temperature $T$,
$$F_T(V) = -\log \int_{W_V} e^{-H_T(w)} \, dw.$$
For a Gibbs ensemble, the equilibrium state of a given system is the state which minimises the free energy. In the context of bringing water to a boiling point, there are two minima of the free energy characterised by the liquid and gaseous states, which for ease we will characterise by their volumes $V_{\mathrm{liquid}}$ and $V_{\mathrm{gas}}$. Then the equilibrium state changes at the critical temperature $T_c$,
$$\text{equilibrium state} = \begin{cases} \text{liquid} & \text{if } T < T_c \\ \text{gas} & \text{if } T > T_c. \end{cases}$$
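To make this concrete, here is a minimal numerical sketch (with made-up constants, not a physical model of water): each of two candidate phases gets a Gibbs-style free energy $F = U - TS$, and the equilibrium phase flips where the two curves cross.

```python
import numpy as np

# Toy sketch (illustrative constants only): two candidate phases, "liquid"
# and "gas", each with free energy F(T) = U - T * S. The equilibrium phase
# at temperature T is whichever has the lower free energy.

def free_energy(U, S, T):
    """Gibbs-style free energy F = U - T*S for a phase with energy U, entropy S."""
    return U - T * S

U_liquid, S_liquid = 1.0, 1.0   # lower energy, lower entropy
U_gas, S_gas = 3.0, 3.0         # higher energy, higher entropy

temps = np.linspace(0.5, 1.5, 101)
F_liq = free_energy(U_liquid, S_liquid, temps)
F_gas = free_energy(U_gas, S_gas, temps)

# The preferred phase switches where the free energy curves cross:
# U_l - T S_l = U_g - T S_g  =>  T_c = (U_g - U_l) / (S_g - S_l)
equilibrium = np.where(F_liq < F_gas, "liquid", "gas")
T_c = (U_gas - U_liquid) / (S_gas - S_liquid)
print(T_c)                              # 1.0
print(equilibrium[0], equilibrium[-1])  # liquid gas
```

Note that the first derivative of the equilibrium free energy, $-S$, jumps at $T_c$ because the entropies of the two phases differ - the hallmark of a first order transition.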
Importantly, while small variations in the temperature away from $T_c$ will change the free energy of each state, they will not change the configuration of these minima with respect to the free energy. In other words, the system will still be a liquid for any $T < T_c$ - its qualitative properties are stable. This is the content of a phase.
A phase of a system is a region of configuration space that minimises the free energy and is invariant to small perturbations in a relevant hyperparameter $T$. Typically, phases are distinguished by some macroscopic variable - in our case the volume, distinguishing the subsets $W_V$. More generally though, a phase describes some qualitative aggregate state of a system, like, as we've discussed in our example, the states of matter.
In some sense, you can define a phase to be any region that induces an equilibrium state with qualities you care about. But what makes phases a powerful concept is their relation to phase transitions - when there is a sudden jump in which state is preferred by the system.
Phase transitions are changes in the structure of the global minima of the free energy, and often arise as non-analyticities of $F$. This is a fancy way of saying they correspond to discontinuities in the free energy or one of its derivatives [4].
A first order phase transition at a critical temperature $T_c$ corresponds to a reconfiguration of which phase is the global minimum of the free energy.
As we discussed above, heating water to boiling point is a classic example of a first order phase transition.
A second order phase transition, by contrast, corresponds to a discontinuity in a second derivative of the free energy - for example, when two minima of the free energy merge into a single minimum at the critical value.
(Note that we have not given a full classification of phase transitions here, because to do so one needs to study the possible types of catastrophes that can occur, as presented in [Gilmore]).
The notation and concepts in the previous section were not presented without reason. For starters, the Gibbs ensemble view of statistical learning is actually quite a rich analogy because, when the prior is uniform, the (random) Hamiltonian is equal to the empirical KL divergence [5],
$$H_n(w) = n K_n(w).$$
The configuration space of microstates of the physical system then corresponds to parameter space $W$, with microstates given by different parameters $w \in W$. This means the posterior is equivalent to the Gibbs probability distribution of the system being in a certain microstate, and the definition of free energy is identical. So, what exactly are the phases?
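As a toy illustration of this correspondence, the following sketch treats a stand-in one dimensional KL divergence $K(w) = w^4$ as the (per-sample) Hamiltonian and computes the resulting Gibbs posterior on a grid. The choice of $K$ is purely illustrative, not from the post.

```python
import numpy as np

# Sketch: the posterior as a Gibbs distribution over parameter space, with
# Hamiltonian H_n(w) = n * K(w) (uniform prior). K is a stand-in KL
# divergence with a degenerate minimum at w = 0.

def K(w):
    return w**4          # singular: the Hessian at w = 0 vanishes

w = np.linspace(-2, 2, 2001)
dw = w[1] - w[0]

def posterior(n):
    unnorm = np.exp(-n * K(w))           # Gibbs weight e^{-H_n(w)}
    return unnorm / (unnorm.sum() * dw)  # normalise to a density on the grid

# As n grows, the posterior concentrates around the singularity at w = 0.
masses = {}
for n in (10, 1000):
    p = posterior(n)
    masses[n] = p[np.abs(w) < 0.5].sum() * dw
    print(n, round(masses[n], 3))
```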
In statistical learning, then,
A phase corresponds to a local neighbourhood $W_i \subseteq W$ containing a singularity of interest.
To say that a phase minimises the free energy is equivalent to saying that it has non-negligible posterior mass. The reason for this, as we explored in DSLT2, is that the singularity structure of a most singular optimal point dominates the behaviour of the free energy, because it minimises the loss and has the smallest RLCT $\lambda$.
You can, in principle, define a phase to be any region of parameter space $W$. But the analysis of phases in the posterior only gets interesting when you have a set of phases with fundamentally different geometric properties. The free energy formula tells us that these geometric properties correspond to different accuracy-complexity tradeoffs.
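This tradeoff can be sketched numerically with the asymptotic free energy formula $F_n \approx n L_n + \lambda \log n$ from DSLT2. The loss and RLCT values below are illustrative, not taken from the post: phase A fits the truth better but is more complex, phase B is less accurate but simpler.

```python
import numpy as np

# Sketch of the accuracy-complexity tradeoff in F_n ≈ n L_n + λ log n for two
# hypothetical phases (illustrative numbers).

def free_energy(n, L, rlct):
    return n * L + rlct * np.log(n)

L_A, rlct_A = 0.000, 1.00   # accurate but complex
L_B, rlct_B = 0.010, 0.25   # less accurate but simple

ns = np.arange(10, 2001)
preferred = np.where(
    free_energy(ns, L_A, rlct_A) < free_energy(ns, L_B, rlct_B), "A", "B"
)
print(preferred[0])    # at small n the complexity term dominates
print(preferred[-1])   # at large n the accuracy term dominates

# The critical sample size where the preference flips:
n_c = ns[np.argmax(preferred == "A")]
print(n_c)
```

At small $n$ the simple-but-inaccurate phase wins; past the critical $n_c$ the accurate-but-complex phase takes over - exactly the first order transition in $n$ discussed below.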
Consequently, in statistical learning, Watanabe states in [Wat18] that
A phase transition is a drastic change in the geometry of the posterior as a function of a hyperparameter $t$.
Our definitions of first and second order phase transitions carry over perfectly from the physics discussion above.
It's important to clarify here that phase transitions in deep learning have many flavours. If one believes that SGD is effectively just "sampling from the posterior", then the conception that phase transitions are related to changes in the geometry of the posterior carries over. There is, however, one fundamentally different kind of "phase transition" that we cannot easily explain with SLT: a phase transition of SGD in time, i.e. in the number of gradient descent steps. The Bayesian framework of SLT does not really allow one to speak of time - the closest quantity is the number of datapoints $n$, but these are not equivalent. We leave this gap as one of the fundamental open questions in relating SLT to current deep learning practice. [6]
The hyperparameter $t$ can affect any number of objects involved in the posterior. Remembering that the posterior is
$$p(w \mid \mathcal{D}_n) = \frac{\varphi(w) \, e^{-n K_n(w)}}{\bar{Z}_n},$$
we could include hyperparameter dependence in any of:
- the model $p(y \mid x, w)$ (for example, via the model function $f(x, w)$),
- the true distribution $q(y \mid x)$, or
- the prior $\varphi(w)$.
In DSLT2 we studied an example of a very simple one-dimensional curve and got a feel for how the accuracy and complexity of a singularity affect the free energy of different neighbourhoods. Having now learned about phase transitions, we can cast new light on this example.
Example 4.1: Consider again a one dimensional KL divergence $K(w)$ with two singularities $w_0$ and $w_1$, where the accuracy of $w_1$ is worse, $K(w_1) > K(w_0)$. Then we can identify two phases corresponding to the two singularities,
$$W_0 = [w_0 - r, \, w_0 + r] \quad \text{and} \quad W_1 = [w_1 - r, \, w_1 + r],$$
for some radius $r > 0$. The accuracy of $W_0$ is better, but the complexity of $W_1$ is smaller, $\lambda_1 < \lambda_0$.
As the hyperparameter $n$ [7] varies, we see a first order phase transition at the critical value $n_c$ where the two free energy curves intersect, causing an exchange of which phase is the global minimum of the free energy. As we argued in that post, this is largely due to the accuracy-complexity tradeoff of the free energy. Notice also how the free energy of the global minimum is non-differentiable at $n_c$, showing an example of the "non-analyticity" of $F$ that we mentioned above.
Example 4.2: We can modify our example slightly to observe a second order phase transition. Let's consider a KL divergence $K_t(w)$, where $t$ is a hyperparameter that shifts the two singularities $w_0$ and $w_1$ towards the origin. We will continue to label these phases $W_0$ and $W_1$, noting their dependence on $t$. [8]
Thus, at $t = 0$ the two phases will merge, and the two singularities will collapse into a single, more degenerate singularity at the origin, with a smaller RLCT than either of the original singularities.
There is a new most singular point caused by the merging of two phases! Again, we can visually depict this phase transition:
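One way to see such a change in the RLCT numerically is via the scaling of the volume of almost-true parameters, $\mathrm{Vol}\{ w : K(w) < \varepsilon \} \sim c \, \varepsilon^{\lambda}$. The sketch below uses stand-in divergences (not the elided formulas from the post): $K(w) = w^2$ with $\lambda = 1/2$, and the degenerate $K(w) = w^4$ you get when two such singularities merge, with $\lambda = 1/4$.

```python
import numpy as np

# Sketch: estimating the RLCT from the volume scaling
# V(eps) = vol{ w : K(w) < eps } ~ c * eps^lambda.
# K2 has a quadratic singularity (lambda = 1/2); K4 is the degenerate
# singularity produced by merging two quadratic ones (lambda = 1/4).

def volume(K, eps, w):
    dw = w[1] - w[0]
    return np.sum(K(w) < eps) * dw   # grid approximation of the volume

def rlct_estimate(K, eps1=1e-4, eps2=1e-2):
    w = np.linspace(-2, 2, 2_000_001)
    V1, V2 = volume(K, eps1, w), volume(K, eps2, w)
    # slope of log V against log eps recovers lambda
    return (np.log(V2) - np.log(V1)) / (np.log(eps2) - np.log(eps1))

print(round(rlct_estimate(lambda w: w**2), 2))  # ≈ 0.5
print(round(rlct_estimate(lambda w: w**4), 2))  # ≈ 0.25
```

The smaller exponent for $w^4$ is exactly the "new most singular point": more posterior volume packs in near the singularity, so it is preferred by the free energy.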
Now that we have the basic intuitions of SLT and phase transitions down pat, let's apply these concepts to the case of two layer feedforward ReLU neural networks.
The main claim of this sequence is that Singular Learning Theory is a solid theoretical framework for understanding phases and phase transitions in neural networks. It's now time to make good on that promise and bring all of the pieces together to understand an actual example of phase transitions in neural networks. The full details of these experiments are explained in my thesis [Carroll], but I will briefly outline some points here for the interested reader. All notation and terminology is explained in detail in DSLT3, so use that post as a reference.
If you are uninterested, just skip to the next subsection to see the results.
We will consider a (model, truth) pair defined by the simple two layer feedforward ReLU neural network models we studied in DSLT3. Phase transitions will be induced by varying the true distribution by a hyperparameter $t$, meaning the set of true parameters $W_0$ depends on $t$. Since we have a full classification of $W_0$ from DSLT3, we understand the phases of the system, and therefore we want to study how their differing geometries affect the posterior. As we explained in that post, the scaling and permutation symmetries are generic (they occur for all parameters $w$), but the node-degeneracy and orientation-reversing symmetries only occur under precise configurations of the truth. Thus, we are interested in studying how the posterior changes as we vary the truth to induce these alternative true parameters - the phases of our setup.
The posterior sampling procedure uses an MCMC variant called HMC NUTS, which is brilliantly explained and interpreted here. Estimating precise numerical free energy values, and in particular the RLCT $\lambda$, using sampling methods is currently very challenging (as explained in [Wei22]). So, for these experiments, our inference about phases and phase transitions will be based on visualising the posterior and observing the posterior concentrations of different phases. With this in mind, the posteriors below are averaged over four trials of 20,000 samples each, for each fixed true distribution defined by $t$. (Bayesian sampling is very computationally expensive, even in simple settings.)
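To give a flavour of what "sampling the posterior" means here, the sketch below uses a random-walk Metropolis sampler as a crude stand-in for HMC NUTS (the actual experiments use NUTS, which mixes far more efficiently), on a stand-in one dimensional posterior $\propto e^{-nK(w)}$ with $K(w) = w^4$.

```python
import numpy as np

# Minimal stand-in for the posterior sampling: random-walk Metropolis on the
# Gibbs posterior p(w) ∝ exp(-n K(w)) with a uniform prior and K(w) = w^4.
# (The post's experiments use HMC NUTS; this only illustrates the idea.)

rng = np.random.default_rng(0)

def log_post(w, n):
    return -n * w**4    # log of the unnormalised posterior

def metropolis(n, steps=20_000, step_size=0.3):
    w, lp = 0.5, log_post(0.5, n)
    samples = []
    for _ in range(steps):
        prop = w + step_size * rng.normal()
        lp_prop = log_post(prop, n)
        if np.log(rng.uniform()) < lp_prop - lp:   # accept/reject step
            w, lp = prop, lp_prop
        samples.append(w)
    return np.array(samples)

samples = metropolis(n=100)
# Most posterior mass concentrates near the singularity at w = 0:
print(round(float(np.mean(np.abs(samples) < 0.5)), 2))
```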
To isolate the phases we care about, we can use the fact that the scaling and permutation symmetries of our networks are generic. To this end we will normalise the weights by defining the effective weight $\tilde{w}_j = q_j w_j$ [9], which preserves functional equivalence [10]. We will say a node is degenerate if $\tilde{w}_j = 0$. We also project different node indices onto the same axes in the figures below.
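The scaling symmetry behind this normalisation can be checked numerically. For any $c > 0$, a ReLU node satisfies $q \, \mathrm{ReLU}(w \cdot x) = (q/c) \, \mathrm{ReLU}(cw \cdot x)$, so the effective weight $q w$ is invariant under the rescaling (notation assumed, following the two layer networks of DSLT3).

```python
import numpy as np

# Sketch: the ReLU scaling symmetry q * relu(w.x) = (q/c) * relu(c w . x)
# for c > 0, which makes the effective weight q*w invariant.

def relu(z):
    return np.maximum(z, 0.0)

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(1000, 2))   # batch of 2d inputs

w = np.array([0.8, -0.4])   # input weights of one node
q = 1.5                     # its output weight
c = 3.0                     # arbitrary positive rescaling

out_original = q * relu(x @ w)
out_rescaled = (q / c) * relu(x @ (c * w))

print(np.allclose(out_original, out_rescaled))   # same function
print(np.allclose(q * w, (q / c) * (c * w)))     # same effective weight
```

So while the raw parameters $(w, q)$ sweep out a whole orbit of functionally equivalent points, the effective weight picks out a single representative, which is what makes the phase plots below readable.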
The prior on inputs is uniform on a square, and the prior on parameters is the standard multivariate normal $\varphi(w) = \mathcal{N}(0, I)$.
In this experiment we will see a first order phase transition induced by deforming a true network from having no degenerate nodes to having (the possibility of) a degenerate node, as discussed in DSLT3 - Node Degeneracy. This example will reinforce the key messages of Watanabe's free energy formula: true parameters are preferred according to their RLCT, and at finite $n$ non-true parameters can be preferred due to the accuracy-complexity tradeoff.
We are going to consider a model network with two nodes,
and a realisable true network with two nodes, which we will denote by $f_t$ to signify its hyperparameter dependence (and distinguish it from the next experiment),
The true weights rotate towards one another as the hyperparameter $t$ varies. [11]
As we explained in DSLT3, we can depict the function and its activation boundaries pictorially:
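The model class can be sketched in code (form assumed from DSLT3: $f(x, w) = \sum_j q_j \, \mathrm{ReLU}(w_j \cdot x)$ with 2d inputs, biases omitted for simplicity), together with the activation boundaries $\{ x : w_j \cdot x = 0 \}$ across which each node switches on or off.

```python
import numpy as np

# Sketch of the (assumed) model class: a two layer feedforward ReLU network
# f(x, w) = sum_j q_j * relu(w_j . x), with activation boundaries
# { x : w_j . x = 0 } -- the lines where node j switches on/off.

def relu(z):
    return np.maximum(z, 0.0)

def f(x, W, q):
    """x: (batch, 2) inputs; W: (nodes, 2) input weights; q: (nodes,) output weights."""
    return relu(x @ W.T) @ q

W_true = np.array([[1.0, 0.0],
                   [np.cos(0.3), np.sin(0.3)]])  # second node rotated by 0.3 rad
q_true = np.array([1.0, 1.0])

x = np.array([[0.5, 0.5], [-0.5, -0.5]])
out = f(x, W_true, q_true)
print(out)   # second point has both nodes inactive, so the output is 0

# x = (0, 1) lies on node 0's activation boundary, since w_0 = (1, 0):
print(W_true[0] @ np.array([0.0, 1.0]))   # 0.0
```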
At the critical value of $t$ where the two true weights align, the truth can be expressed by a network with only one node.
This degeneracy is what we are interested in studying. The WBIC tells us to expect the posterior to prefer the one-degenerate-node configuration, since it has fewer effective parameters. [12]
To identify our phases: at the critical value of $t$ there are two possible configurations of the effective model weights that are true parameters.
To study these configurations we thus define phases based on annuli in the plane of effective weights, each centred on a circle of fixed radius and with a small annular width.
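Measuring the posterior concentration of each phase then amounts to counting how many samples land in each annulus. The sketch below uses hypothetical radii and mock "posterior samples" (not the experimental data from the post) to show the classification step.

```python
import numpy as np

# Sketch: classifying posterior samples of a 2d effective weight into
# annulus-shaped phases. A sample belongs to the phase whose annulus
# (centre radius r, half-width delta) contains it. Radii and the mock
# samples below are hypothetical.

def in_annulus(points, centre_radius, delta):
    norms = np.linalg.norm(points, axis=1)
    return np.abs(norms - centre_radius) < delta

rng = np.random.default_rng(0)
# mock samples: most mass near radius 1, some near radius 0.5
samples = np.concatenate([
    rng.normal(0, 0.02, (800, 2)) + np.array([1.0, 0.0]),
    rng.normal(0, 0.02, (200, 2)) + np.array([0.5, 0.0]),
])

for r in (1.0, 0.5):
    frac = in_annulus(samples, r, delta=0.1).mean()
    print(r, round(float(frac), 2))   # posterior concentration of each phase
```

Tracking these per-phase fractions as the hyperparameter $t$ varies is exactly how a first order transition shows up in the experiments: the dominant annulus switches at the critical value.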