Thanks to Kshitij Sachan for helpful feedback on the draft of this post.
If you train a neural network with SGD, you can embed any state machine within the weights of the network: the network encodes the state $s_t$ in its weights, uses it and the current input to compute the new state $s_{t+1}$, and then uses SGD to encode the new state in its weights.
We present an algorithm to embed a state machine into a PyTorch module made out of neural network primitives (+, ×, sigmoid, ...) but using a non-standard architecture. We use it for several toy experiments, such as one where we translate a simple state machine into a neural network and make the output of the network change when we reach an intermediate state.
In this post, we explain how to embed a state machine into a neural network, discuss some of the limitations of our algorithm and how these limitations could be overcome, and present some experimental results.
Our embedding is quite efficient. If you’re allowed to use stopgrad (a function that stops the gradient from passing through), we only need one parameter to store one real number, and storing it only requires a small constant number of floating-point operations. If you aren’t allowed to use stopgrad, we can’t encode as efficiently.
We did this work for two reasons:
Simple notebooks with our experimental results can be found here: https://github.com/redwoodresearch/Gradient-Machine
As a simple example, suppose you’re training a model N on the regression problem of matching the function f, with loss $L=\frac{1}{2}(N(x)-f(x))^2$.
We want to piggyback on the SGD update step to make an update of our choice to a particular parameter s.
Given some arbitrary function $\phi(x,s)$ defining a state machine, and access to the function f to match, suppose we define our model as follows:
$$N(x)=\text{stopgrad}(f(x))+s-\text{stopgrad}(\phi(x,s))$$
Then our stopgradded computation of f(x) and the actual label cancel out, so the loss will be:
$$L=\frac{1}{2}\big(s-\text{stopgrad}(\phi(x,s))\big)^2$$
Thus, SGD will make s closer to $\phi(x,s)$, since $\partial L/\partial s=s-\phi(x,s)$. And even better, if the learning rate is 1, after one step of SGD:
$$s\leftarrow s-\big(s-\phi(x,s)\big)=\phi(x,s)$$
We have successfully updated s to ϕ(x,s)!
Let’s say that instead of having access to the ground-truth function, you have a very good predictor of the labels $M_\theta: X\to Y$, and that you approximate the state machine $\phi:(X,S)\to S$ you wish to embed with a neural network Φ. Then you can build a Δ function, which computes the deviation from the optimal output that leads to the desired change of state, and combine them in the following neural network N:
$$N(x)=\Delta\big(M_\theta(x),\,\Phi(x,s),\,s\big)$$
For example, if the loss function is the mean-squared error $L=\frac{1}{2}(N(x)-f(x))^2$, then, if M and Φ are good enough, the following function provides the desired change: $\Delta(\hat y,s',s)=\text{stopgrad}(\hat y)+\frac{1}{\sqrt{\alpha}}\big(s-\text{stopgrad}(s')\big)$, where α is the learning rate. (We can show this using the same derivation as the one from the previous section.)
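A minimal sketch of this construction in plain Python, assuming perfect predictors (M returns the true label and Φ equals ϕ) and writing the gradient out by hand: every stopgradded quantity is treated as a constant, so the only gradient that reaches s comes from the $(s-s')/\sqrt{\alpha}$ term. In PyTorch the stopgrads would be `.detach()` calls. The running-maximum state machine `phi` and the target `f` are hypothetical examples.

```python
import math

def phi(x, s):
    # Hypothetical state machine: the running maximum of all inputs seen so far.
    return max(s, x)

def delta(y_hat, s_prime, s, lr):
    # Delta(y_hat, s', s) = stopgrad(y_hat) + (s - stopgrad(s')) / sqrt(lr)
    return y_hat + (s - s_prime) / math.sqrt(lr)

def sgd_step(s, x, f, lr):
    """One SGD step on L = 0.5 * (N(x) - f(x))**2, with N(x) = Delta(M(x), Phi(x, s), s).

    M and Phi are assumed perfect (M = f, Phi = phi). Both outputs are stopgradded,
    so they are plain constants here and the only gradient path is through s.
    """
    y_hat = f(x)
    s_prime = phi(x, s)
    residual = delta(y_hat, s_prime, s, lr) - f(x)  # dL/dN
    grad_s = residual / math.sqrt(lr)               # dN/ds = 1 / sqrt(lr)
    return s - lr * grad_s

def f(x):
    # Hypothetical regression target.
    return 2.0 * x + 1.0

lr = 0.01
s = -10.0
for x in [0.3, -1.2, 2.5, 0.7, 1.9]:
    s = sgd_step(s, x, f, lr)
print(s)  # approximately 2.5, the running maximum
```

Each step moves s by $-\alpha\cdot\frac{s-s'}{\alpha}$, so s lands on $s'$ for any learning rate, not just α = 1.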
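If $\dim s>\dim y=1$, we can modify Δ as follows: $\Delta(\hat y,s',s)=\text{stopgrad}(\hat y)+\frac{1}{\sqrt{2\alpha}}\,n\Big(\sum_i\big[s_i-\text{stopgrad}(s'_i)\big]^2\Big)$, where $n(x)=\frac{x}{\sqrt{\text{stopgrad}(x)}}$.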
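A minimal plain-Python sketch of the multi-dimensional Δ, again assuming a perfect predictor and hand-written gradients. The sketch takes the coefficient to be $c=1/\sqrt{2\alpha}$: because n’s denominator is stopgradded, the gradient of n(q) is twice that of $\sqrt{q}$, and that factor appears only once in $\partial L/\partial s_i$, so a single SGD step lands exactly on $s'$ when $2c^2\alpha=1$. The function names are hypothetical.

```python
import math

def delta_multi(y_hat, s_prime, s, lr):
    """Value of the multi-dimensional Delta and its gradient with respect to s.

    Delta = stopgrad(y_hat) + c * n(q), q = sum_i (s_i - stopgrad(s'_i))**2,
    n(q) = q / sqrt(stopgrad(q)), and c = 1 / sqrt(2 * lr).
    """
    c = 1.0 / math.sqrt(2.0 * lr)
    q = sum((si - spi) ** 2 for si, spi in zip(s, s_prime))
    if q == 0.0:
        # Already at the target: n(0) is 0/0, so treat the state term as absent.
        return y_hat, [0.0] * len(s)
    root = math.sqrt(q)               # the stopgradded denominator sqrt(stopgrad(q))
    value = y_hat + c * q / root      # numerically y_hat + c * sqrt(q)
    # dn/dq = 1 / root (the denominator is a constant), dq/ds_i = 2 * (s_i - s'_i)
    grad = [c * 2.0 * (si - spi) / root for si, spi in zip(s, s_prime)]
    return value, grad

def sgd_step_multi(s, s_prime, y, lr):
    # Perfect predictor: y_hat is the label y itself.
    value, dN_ds = delta_multi(y, s_prime, s, lr)
    residual = value - y              # dL/dN for L = 0.5 * (N - y)**2
    return [si - lr * residual * g for si, g in zip(s, dN_ds)]

s = [1.0, -2.0, 0.5]
s_target = [0.25, 3.0, 0.5]
new_s = sgd_step_multi(s, s_target, y=0.7, lr=0.05)
print(new_s)  # approximately [0.25, 3.0, 0.5]
```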
Similar Δ functions could be found for other loss functions: the idea is to make s responsible for an increase in loss of the right magnitude.
A key part of most Δ functions is stopping the gradient from modifying either Φ or $M_\theta$: without the stopgrad (which prevents the gradient from propagating through $s'$ and $\hat y$), either of those neural networks could learn to anticipate the action of Δ to lower the loss. For instance, with the MSE example above, and ignoring the gradient through Φ, M would eventually learn to predict $f(x)-\frac{1}{\sqrt{\alpha}}\big(s-\text{stopgrad}(s')\big)$ rather than f(x), removing the gradient used to update s.
The stopgrad function can be approximated if sigmoid activations $\sigma(x)=\frac{1}{1+\exp(-x)}$ are natively available: using the plateaus on either side of the sigmoid function, the gradient from a binary variable b can be mostly removed by passing it through $\text{sigmoid\_stopgrad}(b)=\sigma\big((b-\tfrac{1}{2})\cdot 1000\big)$ (which has tiny gradients, even if the slope p = 1000 is a differentiable parameter). This can be extended to variables that can take a finite number of values: using d layers with n sigmoid neurons, you can build a staircase with around $n^d$ steps, and because each step is flat, it mostly blocks gradients from flowing through.
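A minimal sketch of a depth-1 sigmoid staircase in plain Python, assuming integer-valued inputs and step edges placed at 0.5, 1.5, ... (a hypothetical but natural choice). The derivatives are written out by hand; they show that on integer inputs both the gradient with respect to the input and the gradient with respect to the slope are vanishingly small. With `n_steps=1` and `x=b` it reduces to $\sigma((b-\tfrac{1}{2})\cdot 1000)$.

```python
import math

def sigmoid(z):
    # Numerically stable logistic function sigma(z) = 1 / (1 + exp(-z)).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def staircase(x, n_steps, slope=1000.0):
    """Depth-1 staircase: a sum of n_steps steep sigmoids with edges at 0.5, 1.5, ...

    Returns (value, d value / d x, d value / d slope). On integer inputs in
    [0, n_steps] the value is (almost exactly) x and both derivatives are tiny.
    """
    value = d_dx = d_dslope = 0.0
    for k in range(n_steps):
        offset = x - k - 0.5
        sig = sigmoid(slope * offset)
        local = sig * (1.0 - sig)   # sigma'(z) = sigma(z) * (1 - sigma(z))
        value += sig
        d_dx += slope * local
        d_dslope += offset * local
    return value, d_dx, d_dslope

value, d_dx, d_dslope = staircase(3.0, n_steps=40)
print(value, d_dx, d_dslope)  # value ~3.0, gradients astronomically small
```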
However, a neural network would need a very large amount of control over its own activations to be able to create such a pattern.
Stopgrad can also be implemented by multiplying a big weight by a tiny constant, e.g. replacing $y(x)=\theta x$ by $y(x)=\epsilon\theta' x$ (using the initialization $\theta'=\theta/\epsilon$). This multiplies the effective learning rate by $\epsilon^2$, since after one step of SGD on $x_0$, $y(x)=\epsilon\Big(\theta'-\alpha\,\frac{dL}{dy}\Big|_{x_0}\epsilon x_0\Big)x$. This means the gradient with respect to θ′ can be ignored if the effective learning rate is small enough relative to the number of training steps used in practice, or if the update of θ′ is too small to register in floating point and gets rounded away. This might be achieved with regular neural network primitives by using layernorm, which makes even large-magnitude inputs small, thus decreasing the effective learning rate, though the denominator in layernorm also depends on its input, which may make it difficult to use in practice.
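A minimal plain-Python check of the $\epsilon^2$ claim, assuming a single scalar weight and a fixed loss derivative $\frac{dL}{dy}$ at $x_0$ (all numbers are arbitrary illustrations). It also shows the floating-point effect: with a small enough ε, the update to θ′ falls below θ′’s float resolution.

```python
def effective_weight_after_step(theta, x0, dL_dy, lr, eps=None):
    """Effective weight of y = w * x after one SGD step on input x0.

    eps=None: plain parametrization y = theta * x.
    otherwise: y = eps * theta_p * x with theta_p initialized to theta / eps.
    dL_dy is the derivative of the loss w.r.t. the output at x0.
    """
    if eps is None:
        return theta - lr * dL_dy * x0            # dL/dtheta = dL/dy * x0
    theta_p = theta / eps
    theta_p -= lr * dL_dy * eps * x0              # dL/dtheta_p = dL/dy * eps * x0
    return eps * theta_p

theta, x0, dL_dy, lr, eps = 0.5, 2.0, 0.3, 0.1, 1e-3
plain_change = effective_weight_after_step(theta, x0, dL_dy, lr) - theta
scaled_change = effective_weight_after_step(theta, x0, dL_dy, lr, eps) - theta
print(scaled_change / plain_change)  # approximately eps**2 = 1e-6

# With a much smaller eps, the update to theta_p is below its float resolution.
tiny_eps = 1e-12
theta_p = theta / tiny_eps
print(theta_p - lr * dL_dy * tiny_eps * x0 == theta_p)  # True: the step is rounded away
```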
Implementing stopgrad with ReLUs is possible as long as the slope is a fixed large parameter P: $\text{step}(x)=\text{ReLU}(Px)-\text{ReLU}(P(x-\epsilon))$ is a sharp step function of height Pε and width ε, which can be used to build stopgrad. But this construction does not work in usual MLPs, since the slopes of the two ReLUs are individually tunable: the derivative of step(x) with respect to the slope parameter of the first ReLU is not negligible for x > 0 (unlike the derivative with respect to the single slope parameter of a sigmoid).
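A minimal plain-Python sketch of why separately tunable slopes break the ReLU step, with hypothetical names `p1` and `p2` for the two slopes. Past the step, the derivative with respect to x is zero (the stopgrad behavior we want), but the derivative with respect to `p1` equals x, so SGD can still reshape the step.

```python
def relu(z):
    return max(z, 0.0)

def step_and_grads(x, p1, p2, eps):
    """step(x) = ReLU(p1 * x) - ReLU(p2 * (x - eps)) with separately tunable slopes.

    Returns the value and its derivatives w.r.t. x, p1 and p2 (using 0 at the kinks).
    """
    value = relu(p1 * x) - relu(p2 * (x - eps))
    on1 = 1.0 if p1 * x > 0 else 0.0
    on2 = 1.0 if p2 * (x - eps) > 0 else 0.0
    d_dx = p1 * on1 - p2 * on2
    d_dp1 = x * on1
    d_dp2 = -(x - eps) * on2
    return value, d_dx, d_dp1, d_dp2

P, eps = 1000.0, 1e-3
value, d_dx, d_dp1, d_dp2 = step_and_grads(2.0, P, P, eps)
print(value, d_dx, d_dp1)  # height ~P * eps = 1.0, flat in x, but d/dp1 = 2.0
```

Once SGD moves `p1` and `p2` apart, the flat region acquires slope `p1 - p2`, and gradients flow through the step again.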
But we think models might be able to implement stopgrad in neural networks in a better way; for example, it might be possible to build stopgrad without a sigmoid activation function. Thus, the complexity of our implementation may overestimate the complexity of a stopgrad that gradient-hacking AIs could use in practice.
If the batch size is greater than 1, the network Φ approximating the state machine is still limited to operations on individual inputs, but updates to the state are done in batches, and applying the algorithm naively makes the target states s′ obtained for each input in the batch average out. The algorithm described above can nevertheless be extended to state machines with aggregatable updates: the state machine still operates on individual inputs, but $\phi(x_B,\phi(x_{B-1},\dots,\phi(x_1,s_0)))-s_0\approx C\sum_{i=1}^{B}\big(\phi(x_i,s_0)-s_0\big)$, where C is a constant known when creating the embedding and B is the batch size.
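A minimal plain-Python sketch of the batched case, assuming perfect predictors and a loss averaged over the batch (as with PyTorch’s default `reduction='mean'`), with a hypothetical counting state machine in the spirit of the $x_1>3$ counter used in the experiments below. Averaging makes s land on the mean of the per-example targets; this sketch rescales each target’s deviation by C·B (C = 1 for a counter) so that the mean becomes the aggregated update.

```python
import random

def phi(x, s):
    # Hypothetical aggregatable state machine: count points with x[0] > 3.
    return s + (1.0 if x[0] > 3 else 0.0)

def batch_sgd_step(s, batch, lr, C=1.0):
    """One SGD step on the batch-mean loss 0.5 * mean_i (N(x_i) - y_i)**2, perfect predictors.

    Per example, N(x_i) - y_i = (s - t_i) / sqrt(lr), so dL/ds = mean_i (s - t_i) / lr
    and the step sends s to mean_i t_i.
    """
    B = len(batch)
    # Scale each per-example deviation by C * B so the batch mean becomes
    # s + C * sum_i (phi(x_i, s) - s) instead of mean_i phi(x_i, s).
    targets = [s + C * B * (phi(x, s) - s) for x in batch]
    grad = sum((s - t) / lr for t in targets) / B
    return s - lr * grad

random.seed(0)
s, true_count = 0.0, 0
for _ in range(50):
    batch = [(random.gauss(0, 2), random.gauss(0, 2)) for _ in range(32)]
    true_count += sum(1 for x in batch if x[0] > 3)
    s = batch_sgd_step(s, batch, lr=0.01)
print(s, true_count)  # s matches the true count up to float rounding
```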
Updates are aggregatable for state machines that operate on the number or the proportion of inputs on which a certain property holds. But this prevents the embedding of state machines like the ones computing $\max_x P(x)$ for P a scalar property $X\to\mathbb{R}$.
The idea described in this post can’t be applied to optimizers that have a state, like Adam or SGD with momentum, because Δ would need to know the optimizer state corresponding to s in order to create the appropriate output. This means each state tracked by the network would require two states (the state itself and the corresponding optimizer state), which seems impossible (though further research might find a way to adapt the ideas described here to stateful optimizers).
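A quick plain-Python illustration, using the standard Adam update rule with its usual default hyperparameters and zero-initialized moments. The gradient fed in is the one the MSE Δ produces with a perfect predictor. SGD lands on s′, but on Adam’s first step bias correction makes $\hat m/\sqrt{\hat v}\approx\text{sign}(g)$, so s moves by about the learning rate regardless of how far away s′ is.

```python
import math

def adam_first_step(s, grad, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    # Standard Adam update with zero-initialized moments and bias correction (t = 1).
    m = (1 - beta1) * grad
    v = (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1)
    v_hat = v / (1 - beta2)
    return s - lr * m_hat / (math.sqrt(v_hat) + eps)

s, s_prime, lr = 1.0, 0.2, 0.1
grad = (s - s_prime) / lr              # dL/ds from the MSE Delta with a perfect predictor
sgd_s = s - lr * grad                  # SGD: lands on s' = 0.2
adam_s = adam_first_step(s, grad, lr)  # Adam: moves by about lr, whatever s - s' is
print(sgd_s, adam_s)
```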
Imperfect predictors limit our ability to change s precisely. If we can afford a very large loss, simply scaling s allows meaningful manipulation of the state: given a regression problem where the label $y^*=f(x)$ is known to be in [−1, 1], and using the Δ function described above for the MSE loss with the constant predictor M(x) = 0, the update becomes $s\leftarrow s'+\sqrt{\alpha}\,f(x)$, which conveys meaningful information as long as $|s'-s|\gg\sqrt{\alpha}$. This naive strategy leads to a very high loss and huge loss spikes. But if we tolerate noisy writes and have an accurate (but not perfect) predictor, the amount of additional loss required can be relatively small.
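A minimal plain-Python sketch of a noisy write, assuming the one-variable MSE Δ with a hand-written gradient. In general one step gives $s\leftarrow s'+\sqrt{\alpha}\,(y-\hat y)$, so with the constant predictor $\hat y=0$ the write error is $\sqrt{\alpha}\,y$, at most $\sqrt{\alpha}$ in magnitude. The numbers are arbitrary illustrations.

```python
import math

def noisy_write(s, s_prime, y_hat, y, lr):
    """One SGD step with Delta = stopgrad(y_hat) + (s - stopgrad(s')) / sqrt(lr)
    on L = 0.5 * (N - y)**2, where the predictor output y_hat may be wrong."""
    residual = y_hat + (s - s_prime) / math.sqrt(lr) - y   # dL/dN
    return s - lr * residual / math.sqrt(lr)                # dN/ds = 1 / sqrt(lr)

lr = 1e-4
# Constant predictor M(x) = 0 with labels in [-1, 1]: the write error is sqrt(lr) * y.
for y in (-1.0, 0.3, 1.0):
    new_s = noisy_write(s=5.0, s_prime=2.0, y_hat=0.0, y=y, lr=lr)
    print(y, new_s - 2.0)  # error sqrt(lr) * y, at most 0.01 in magnitude
```

Here the residual is about $(s-s')/\sqrt{\alpha}=300$, which is the kind of loss spike the paragraph above warns about.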
Our toy task is the following: a model should output 1 if the data point $(x_1, x_2)$ is at L1 distance greater than 1 from the origin, and 0 otherwise, and is scored with an MSE loss. The state should track the number of data points seen where $x_1>3$ (which is an aggregatable state machine). The training data is normally distributed. We use a depth-1, width-40 staircase as stopgrad, so that the final model is only made out of +, × and σ (as well as some indexing and concatenation).
We compare two scenarios: one where the model M has its weights hard-coded to the exact solution of the problem, and one where the weights are learned (from a random initialization, using SGD on 500k normally distributed inputs). In both cases, the model is an MLP with 4 intermediate neurons. We see that small imperfections in M lead to inaccurate state tracking.
We use the same methodology as for the previous toy model, except that we use two states, we use the variant of Δ for multiple variables (which in particular means that we use division on top of +, × and σ), and we use PyTorch’s built-in stopgrad function (detach) for simplicity. The two states are $s_0$: “number of points with $x_1>3$” and $s_1$: “number of points with $x_2>3$ when $s_0>10$”.
We then use some logic (implemented with sigmoids) to change the output of the network when $s_0>10$ and $s_1\le 10$. We chose a particular change that doesn’t affect the gradient (we added a constant inside the n function) and causes a constant loss increase, but in principle any modification is possible, as long as it is accounted for by Δ when the state needs to be changed.
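A plain-Python sketch of the two-state machine and of the intermediate-state test, with hypothetical function names. We take the intermediate state to be $s_0>10$ and $s_1\le 10$ ($s_0$ has passed its threshold, $s_1$ has not yet). Whether “$s_0>10$” is checked before or after the current input updates $s_0$ is not specified above; this sketch assumes before.

```python
def phi_two_states(x, s):
    """Hypothetical two-state machine: s0 counts points with x1 > 3,
    s1 counts points with x2 > 3 seen while s0 > 10 (checked before this input)."""
    s0, s1 = s
    new_s1 = s1 + (1 if x[1] > 3 and s0 > 10 else 0)
    new_s0 = s0 + (1 if x[0] > 3 else 0)
    return new_s0, new_s1

def in_intermediate_state(s):
    # The network's output is modified only while s0 has crossed 10 but s1 has not.
    s0, s1 = s
    return s0 > 10 and s1 <= 10

s = (0, 0)
history = []
for x in [(4.0, 4.0)] * 25:
    s = phi_two_states(x, s)
    history.append(in_intermediate_state(s))
print(s, sum(history))  # (25, 14) and 11 steps spent in the intermediate state
```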
Because we have a very good predictor, we can scale down the state by a large factor (e.g. “count=7” is represented by 0.07 instead of 7), which leads to smaller loss spikes when the state is changed.
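A back-of-the-envelope view of this trade-off, as a sketch assuming the one-variable MSE Δ with learning rate α (the multi-variable Δ also gives a write loss proportional to $(s-s')^2/\alpha$): the loss at a write is quadratic in the size of the write, while the write error is set by the predictor.

$$L_{\text{write}}=\frac{1}{2}\left(\frac{s-s'}{\sqrt{\alpha}}\right)^2=\frac{(s-s')^2}{2\alpha}\ \ (\hat y=y),\qquad s\leftarrow s'+\sqrt{\alpha}\,(y-\hat y).$$

If a count c is stored as $s=k\,c$, a unit increment means $|s-s'|=k$, so the spike is $k^2/(2\alpha)$: going from k = 1 to k = 0.01 divides it by $10^4$. The price is that the write error $\sqrt{\alpha}\,|y-\hat y|$ must stay well below k, which is why this only works with a very good predictor.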
To validate the general methodology in a more realistic setting, we study the following setup: we use the MNIST dataset, and we flip a proportion p of the images in each batch along their vertical axis. We want to see whether the model can maintain an internal parameter that matches p. To do so, we train two small CNNs: M, which classifies the images (trained with the cross-entropy loss, reaching 94% top-1 accuracy), and Φ, which classifies images as flipped or not flipped (trained with the binary cross-entropy loss, reaching 99% top-1 accuracy).
Here, M and Φ are noisy, so we don’t try to get an exact Δ, and simply use $\Delta(y,s',s)=\text{stopgrad}(y)\,\big(1-|s-\text{stopgrad}(s')|\big)^{\gamma}$, where y are the logits and γ is the strength of the incentive; this incentivizes s to be close to s′ (the proportion of images flipped).
We find that this technique works well across a wide range of strengths. We report the 1-sigma standard deviation of the p_flip value measured at the end of each of the 5 epochs we train the model for.
We have embedded a state machine into a neural network built with neural network primitives, but with a non-standard architecture. Future work could figure out how to embed state machines inside regular MLPs or Transformers.
Future work could make our discussion of "What if you don’t have a perfect predictor of the labels?" more quantitative, and thereby establish a bound on how much computation can happen using imperfect predictors. This bound might be enough to rule out this kind of behavior happening in current and future models.
The expression differs from the one-variable case because we need to tie each $s_i$ to $s'_i$. To do so, we sum the $(s_i-s'_i)^2$. But without n, the gradient with respect to $s_i$ would be proportional to $2(s_i-s'_i)\sum_j[s_j-s'_j]^2$, which is why we need the n function.
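Writing out the gradient makes the role of n explicit. A sketch, assuming the ½-MSE loss, a perfect predictor (so $N-y=c\,n(q)$ for a coefficient c), and $q=\sum_j(s_j-s'_j)^2$ with $s'$ stopgradded:

$$\text{without } n:\quad \frac{\partial L}{\partial s_i}=c\,q\cdot c\cdot 2(s_i-s'_i)=2c^2\,q\,(s_i-s'_i)$$

$$\text{with } n(q)=\frac{q}{\sqrt{\text{stopgrad}(q)}}\ \text{(value } \sqrt{q},\ \partial n/\partial q=1/\sqrt{q}\text{)}:\quad \frac{\partial L}{\partial s_i}=c\sqrt{q}\cdot\frac{c}{\sqrt{q}}\cdot 2(s_i-s'_i)=2c^2(s_i-s'_i)$$

One SGD step $s_i\leftarrow s_i-\alpha\,\partial L/\partial s_i$ therefore lands exactly on $s'_i$ when $2c^2\alpha=1$, i.e. $c=1/\sqrt{2\alpha}$.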