Mode collapse in RL may be fueled by the update equation



I think KL/entropy regularization is usually used to prevent mode collapse partly because it has nice theoretical properties. In particular, it is easy to reason about the optimal policy for the regularized objective - see for example the analysis in the paper Equivalence Between Policy Gradients and Soft Q-Learning.

Nevertheless, action-dependent baselines do appear in the literature, although the story is a bit confusing. This is my understanding of it from some old notes:

- The idea was explored in Q-Prop. But unlike you, their intention was not to change the optimal policy, but rather to reduce the variance of the policy gradient. Therefore they also incorporated an additional term to cancel out the bias introduced by the action-dependent baseline. (Incidentally, perhaps this analysis is also relevant to understanding ACTDE.)
- Later, The Mirage of Action-Dependent Baselines showed that in fact the variance reduction due to the action-dependent baseline was negligible, and the entire benefit of Q-Prop was essentially due to a bug! The implementation normalized advantage estimates, but failed to apply the same adjustment to the bias-correction term, which turned out to be independently helpful because it's essentially the DDPG training objective.
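The bias that Q-Prop has to cancel can be seen in a small exact computation. The sketch below is not taken from either paper: it assumes a two-armed softmax bandit (reusing the wedding/party rewards from the post), and `expected_policy_gradient` is a hypothetical helper written for this illustration. A state-only baseline leaves the expected gradient unchanged, while an action-dependent baseline changes it.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def expected_policy_gradient(logits, rewards, baseline):
    """Exact E_a[ d log pi(a)/d logits * (R(a) - baseline[a]) ] for a softmax bandit."""
    pi = softmax(logits)
    n = len(logits)
    grad = [0.0] * n
    for a in range(n):
        weight = pi[a] * (rewards[a] - baseline[a])
        for i in range(n):
            score = (1.0 if i == a else 0.0) - pi[i]  # d log pi(a) / d logit_i
            grad[i] += weight * score
    return grad

logits = [0.0, 0.0]      # uniform policy (assumption for the illustration)
rewards = [1.0, 0.5]     # "wedding", "party"
v = sum(p * r for p, r in zip(softmax(logits), rewards))  # on-policy state value

g_none = expected_policy_gradient(logits, rewards, [0.0, 0.0])
g_state = expected_policy_gradient(logits, rewards, [v, v])    # baseline b(s)
g_action = expected_policy_gradient(logits, rewards, rewards)  # baseline b(s,a) = exact q(s,a)

print(g_none, g_state, g_action)
```

With an exact action-value baseline the expected update vanishes entirely in this toy case, which may be one way to read the bias-correction term in Q-Prop.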

I like the philosophical and strategic take here: let's avoid wireheading, arbitrary reinforcement strength is risky^{[1]}, hopefully we can get some values-caring-about-human-stuff.

The ACTDE seems potentially a nice complement/alternative to entropy^{[2]} regularisation for avoiding mode collapse (I haven't evaluated deeply). I think you're misdiagnosing a few things though.

Overall I think the section about oscillating advantage/value estimation is irrelevant (interesting, but unrelated), and I think you should point the finger less at PPO and advantage estimation per se and more at exploration at large. And you might want to flag that *too much* exploration/randomness can also be an issue!

^{[1]} Though note that ideally, once we actually know with confidence what is best, we *should* be near-greedy about it, rather than softmaxing! Say it was 'ice cream' vs 'slap in the face'. I would infinitely (linearly in time) regret softmaxing over that for eternity. As it stands I think humanity is very far from being able to safely aggressively greedily optimise really important things, but this is at least a consideration to keep in mind.

^{[2]} Incidentally, *KL divergence* regularisation is not primarily for avoiding mode collapse AFAIK, it's for approximate trust region constraints - which may incidentally help to avoid mode collapse by penalising large jumps away from initially-high-entropy policies. See the TRPO paper. Entropy regularisation directly addresses mode collapse.

Though note that ideally, once we actually know with confidence what is best, we should be near-greedy about it, rather than softmaxing!

I disagree. I don't view reward/reinforcement as indicating what is "best" (from our perspective), but as chiseling decision-making circuitry into the AI (which may then decide what is "best" from its perspective). One way of putting a related point: I think that we don't need to infinitely reinforce a line of reasoning in order to train an AI which reasons correctly.

(I want to check -- does this response make sense to you? Happy to try explaining my intuition in another way.)

The problem is that this advantage can oscillate forever.

This is a pretty standard point in RL textbooks. But the culprit is the learning rate (which you set to be 1 in the example, but you can construct a nonconverging case for any constant α)! The advantage definition itself is correct and non-oscillating; it's the *estimation* of the expectation using a moving average which is (sometimes) at fault.

Oscillating or nonconvergent value estimation is not the cause of policy mode collapse.
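A minimal sketch of the learning-rate point, under assumptions that are mine rather than the thread's: the rewards strictly alternate between 1 and 0.5, and `track_value` is a hypothetical helper that applies the usual moving-average update `v += step * (r - v)`. With any constant step size the estimate keeps swinging; with a `1/n` schedule the swing dies out.

```python
def track_value(rewards, step_size):
    """Moving-average value estimate; step_size maps the update index n to a step."""
    v = 0.0
    history = []
    for n, r in enumerate(rewards, start=1):
        v += step_size(n) * (r - v)
        history.append(v)
    return history

def last_swing(history):
    """Size of the final change in the estimate."""
    return abs(history[-1] - history[-2])

rewards = [1.0, 0.5] * 5000  # alternating "wedding" / "party" rewards (assumption)

constant = track_value(rewards, lambda n: 0.1)      # constant learning rate
decaying = track_value(rewards, lambda n: 1.0 / n)  # sample-average schedule

print("constant step, last swing:", last_swing(constant))
print("decaying step, last swing:", last_swing(decaying))
print("decaying step, final estimate:", decaying[-1])
```

For a constant step α and this alternating input, the steady-state swing works out to 0.5·α/(2−α), which is nonzero for every α in (0, 1].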

The advantage definition itself is correct and non-oscillating... Oscillating or nonconvergent value estimation is not the cause of policy mode collapse.

The advantage is (IIUC) defined with respect to a given policy, and so the advantage can oscillate and then cause mode collapse. I agree that a constant learning rate schedule is problematic, but note that ACTDE converges even with a constant learning rate schedule. So, I would indeed say that oscillating value estimation caused mode collapse in the toy example I gave?

Is this identical to training the next-to-last layer to predict the rewards directly, and then just transforming those predictions to get a sample? Have you considered going whole hog on model-based RL here?

I'd be interested in avoiding mode collapse in cases where that's not practical, like diffusion models. Actually, could you choose a reward that makes diffusion models equivalent to MCMC? Probably no good safety reason to do such a thing though.

Is this identical to training the next-to-last layer to predict the rewards directly, and then just transforming those predictions to get a sample?

In the tabular case, that's equivalent given uniform π0. Maybe it's also true in the function approximator PG regime, but that's a maybe -- depends on inductive biases. But often we want a pretrained π0 (like when doing RLHF on LLMs), which isn't uniform.
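To illustrate why the uniformity of the initial policy matters here, the following sketch assumes the tabular picture from the post, where reward logits are added to the initial policy's log-probabilities. The function name `policy_from_prior_and_reward` and the skewed prior are hypothetical, chosen only for the example.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def policy_from_prior_and_reward(prior, rewards):
    """Add each action's reward (as logits) to the log-probability of the prior."""
    return softmax([math.log(p) + r for p, r in zip(prior, rewards)])

rewards = [1.0, 0.5]  # "wedding", "party"

# Uniform prior: identical to predicting rewards and softmaxing them directly.
uniform = policy_from_prior_and_reward([0.5, 0.5], rewards)
direct = softmax(rewards)

# Non-uniform (e.g. pretrained) prior: the prior still shapes the result.
skewed = policy_from_prior_and_reward([0.1, 0.9], rewards)

print(uniform, direct, skewed)
```

Only in the uniform case does the prior drop out, so "predict rewards, then softmax" and the additive-logit picture coincide there.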

**TL;DR:** We present an advantage variant which, in certain settings, does not train an optimal policy, but instead uses a fixed reward to update a policy a fixed amount from initialization. Non-tabular empirical results seem mixed: The policy doesn't mode-collapse, but has unclear convergence properties.

**Summary:** Many policy gradient methods allow a network to extract arbitrarily many policy updates from a single kind of reinforcement event (e.g. for outputting tokens related to weddings). Alex proposes a slight modification to the advantage equation, called "action-conditioned TD error" (ACTDE). ACTDE ensures that the network doesn't converge to an "optimal" policy (these almost always put infinite logits on a single action). Instead, ACTDE updates the network by a fixed number of logits.

For example, suppose R(pizza)=10 and R(cookies)=11. In this case, PPO converges to a policy which puts arbitrarily many logits on cookies, even though the reward difference is small. By contrast, under ACTDE, the network converges to the softmax-over-reward policy {pizza: 27%, cookies: 73%}, which seems more reasonable.
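As a quick check of the 27%/73% figure, assuming a uniform initial policy so that only the rewards enter the softmax:

$$\pi(\text{cookies}) = \frac{e^{11}}{e^{10}+e^{11}} = \frac{1}{1+e^{-1}} \approx 0.73, \qquad \pi(\text{pizza}) = \frac{e^{10}}{e^{10}+e^{11}} = \frac{1}{1+e^{1}} \approx 0.27.$$

Only the one-logit gap between the two rewards matters, not their absolute size.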

Then, Michael Einhorn shares initial results which support Alex's theoretical predictions. Using a similar architecture and Q-head loss function to ILQL for a small transformer trained in a prisoner's dilemma, Michael Einhorn collected initial data on ACTDE. Unlike PPO, ACTDE-trained policies did not mode collapse onto a single action and instead learned mixed strategies.

We're interested in additional experiments on ACTDE. We hope that, by using ACTDE instead of advantage, we can automatically mitigate "reward specification" issues and maybe even reduce the need for a KL penalty term. That would make it easier to shape policies which do what we want.

## The advantage equation implies arbitrary amounts of update on a single experience

In PPO, the optimization objective is proportional to the advantage given a policy π, reward function R, and on-policy value function vπ:

$$A^\pi(s,a) := \mathbb{E}_{s' \sim T(s,a)}\left[R(s,a,s') + \gamma v^\pi(s')\right] - v^\pi(s).$$^{[1]}

Alex thinks this equation is actually pretty messed up, although it looked decent at first. The problem is that this advantage can oscillate forever. To explain, let's consider a simple bandit problem: one state ("We had a") and two actions ("wedding" and "party") with rewards R(“We had a wedding”)=1 and R(“We had a party”)=.5.

The failure which happens is:

This continues to happen, which means that "wedding" gets arbitrarily high logits.

This flaw is easiest to see formally. Initialize the $t=0$ tabular value function $v^{\pi_0}$ to 0, and the policy $\pi_0$ to be 50/50 for “party”/“wedding”. Let $\gamma=1$, and we update the value function $v$ using tabular TD learning (with learning rate $\alpha=1$). So, for example, if the system takes the “wedding” action, its new value function $v^{\pi_1}(s)=1$. If the system then takes the “party” action, the value snaps back to $v^{\pi_2}(s)=.5$.^{[2]}

The policy update rule is: If the advantage $A^\pi(s,a)=n$, then action $a$ becomes $n$ bits more probable under $\pi$ (i.e. we add $n$ to $\pi$'s logits on $a$). So, if $\pi_0(s,\text{“wedding”})=.5$ and advantage $A^{\pi_0}(s,\text{“wedding”})=1$, then $\pi_1(s,\text{“wedding”})=.73$.

Episode-by-episode:

With probability 1 as t→∞, πt(wedding)→1. You might think this is good, since wedding is in fact “optimal” at that state. This does not seem good. Here are a few kinds of explanations for why:

- … *feedback* about what kinds of completions are good (or, more abstractly, about what kinds of computations are good to run inside the network). We want a single situation to provide a *finite amount of updating*.
- … *exacerbated* by the network getting penalized relative to its on-policy value estimate $v_{t-1}(s)$.

This doesn’t seem limited to tabular TD-learning, or PPO in more realistic domains. E.g., vanilla policy gradient will also allow a system to extract an unbounded amount of reinforcement from a single kind of event (e.g. “wedding”). Unless very specific care is taken, Alex thinks this kind of failure happens by default in policy gradient methods.
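The unbounded growth can be reproduced in a few lines. This is a sketch under the tabular assumptions stated above (one state, learning rate 1, advantage added directly to the taken action's logit), with one simplification of my own: the actions follow a fixed alternating sequence instead of being sampled from the policy, so the run is deterministic. The function names are hypothetical.

```python
import math

REWARD = {"wedding": 1.0, "party": 0.5}

def run_td_advantage(actions):
    """Tabular bandit: advantage = R(a) - v(s), added to the logit of the taken action."""
    logits = {"wedding": 0.0, "party": 0.0}
    v = 0.0  # value estimate of the single state
    for a in actions:
        advantage = REWARD[a] - v  # no successor state, so the TD target is just R(a)
        logits[a] += advantage
        v = REWARD[a]              # TD update with learning rate 1
    return logits

def p_wedding(logits):
    return 1.0 / (1.0 + math.exp(logits["party"] - logits["wedding"]))

for steps in (2, 20, 200):
    actions = ["wedding", "party"] * (steps // 2)  # fixed sequence (assumption)
    logits = run_td_advantage(actions)
    gap = logits["wedding"] - logits["party"]
    print(steps, "steps: logit gap =", gap, " P(wedding) =", p_wedding(logits))
```

Each switch between the two actions widens the logit gap by another 0.5, so the gap grows without bound as long as both actions keep being tried.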

## Action-conditioned TD error avoids arbitrarily high logits

Given the original advantage equation:

$$A^\pi(s,a) := \mathbb{E}_{s' \sim T(s,a)}\left[R(s,a,s') + \gamma v^\pi(s')\right] - v^\pi(s),$$

replace the last term’s baseline to account for the taken action:

$$A^\pi_*(s,a) := \mathbb{E}_{s' \sim T(s,a)}\left[R(s,a,s') + \gamma v^\pi(s')\right] - q^\pi(s,a).$$

We call this “action-conditioned TD error” (ACTDE).

ACTDE allows the system to account for its decision to go off-policy by selecting a new action $a$ which isn’t the usual recommendation $a' \sim \pi(s)$. Philosophically, Alex wanted to mimic reward prediction error. The network taking a different action is not surprising to the network, so the optimization term should account for the action taken (i.e. by using $q^\pi(s,a)$).

Re-analyzing the situation:

The policy quickly converges to the softmax logits over the reward for the next completions, where $\frac{e^{1}}{e^{1}+e^{.5}} \approx .62$. That is, the learned policy has R(“party”)=.5 logits on “party” and R(“wedding”)=1 logit on “wedding”.
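The same bandit can be re-run with the action-conditioned baseline. This is a sketch under the tabular assumptions used earlier (learning rate 1, zero-initialized tables, ACTDE added directly to the taken action's logit); the fixed action sequence and the function names are my own choices for a deterministic illustration.

```python
import math

REWARD = {"wedding": 1.0, "party": 0.5}

def run_actde(actions):
    """Tabular bandit: ACTDE = R(a) - q(s, a), added to the logit of the taken action."""
    logits = {"wedding": 0.0, "party": 0.0}
    q = {"wedding": 0.0, "party": 0.0}  # action-value table
    for a in actions:
        actde = REWARD[a] - q[a]  # baseline is conditioned on the action taken
        logits[a] += actde
        q[a] = REWARD[a]          # update with learning rate 1
    return logits

def p_wedding(logits):
    return 1.0 / (1.0 + math.exp(logits["party"] - logits["wedding"]))

for steps in (2, 20, 200):
    actions = ["wedding", "party"] * (steps // 2)  # fixed sequence (assumption)
    logits = run_actde(actions)
    print(steps, "steps:", logits, " P(wedding) =", p_wedding(logits))
```

After each action has been tried once, its ACTDE is zero, so the logits stop moving no matter how long the run continues.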

Therefore this process does not converge to the optimal policy, even in the limit of infinite exploration. Correspondingly, there is no mode collapse in this situation.

Reward logits are “added to” initialization logits $\pi_0$ (the prior over what completions to output). RL, in this setting, provides a finite amount of reinforcement for certain kinds of computation/actions.

Furthermore, self-consistent, Bellman-backed-up Q-functions will have zero advantage and zero updates. Networks aren’t penalized for exploring, and there’s a precise and finite amount of reinforcement which can occur given current predictions about future value, as represented by $q_t$. And training should be more stable, with fewer fluctuations in advantage with respect to the policy itself.^{[3]}

## ACTDE doesn't mode-collapse onto wireheading

ACTDE doesn't mode collapse on wireheading, even *given* that the network tries out wireheading! (Which Alex thinks is not that likely for practical RL algorithms.)

Concretely, suppose that reward is 10 if you eat pizza and 100 if you wirehead. You start off with action distribution {pizza: 1%, wirehead: 99%}, and we're doing TD-learning in the tabular setup we just described. If so, then the policy gradients upweight wireheading more and more. This can happen until the network puts arbitrarily many logits on the wireheading action. In this situation, under these exploration assumptions and with probability 1, PPO "selects for" wireheading and the policy ends up {pizza: ϵ, wirehead: 1−ϵ}.

However, ACTDE does not lead to arbitrarily many logits on wireheading. Instead, ACTDE leads to the softmax distribution over actions, with the softmax taken over the reward for each action. Thus, the "optimum"/fixed-point policy of tabular ACTDE is about { pizza: .02%, wirehead: 99.98% }. That's still mostly wireheading, but there are only finitely many logits on that action.

## PPO vs ACTDE on the iterated prisoner's dilemma

In this toy experiment, the model plays prisoner's dilemmas against its past self, similar to the idea by Krueger et al. The model is mingpt with a vocab size of two: one token for "cooperate", and one for "defect". mingpt has 3 layers and an embedding dimension of 12. The model sees the history of cooperates and defections, and outputs the next action.

We are not training via self play against a copy. Instead the model at time $t$ plays against its action at time $t-1$. Playing with its past self for a sequence of `ccddc` has 4 games: `cc`, `cd`, `dd`, `dc`, with rewards of 0.5 (for `cc`), 2 (for `cd`), -0.74 (for `dd`), and -1.76 (for `dc`).^{[4]}

Alternating cooperation (`c`) and defection (`d`) is the (bolded) optimal strategy for both start states:

- `cccc...`
- `cddd...`
- **`cdcd...`: 1.492**
- `dddd...`
- `dccc...`
- **`dcdc...`: -1.015**

**What we're testing:** If ACTDE mode collapses when used on function approximators (like mingpt), then the theoretical predictions above are wrong.

## PPO results

PPO immediately learns the alternating strategy:
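As a sanity check on why alternation is the target, here is a sketch that scores the candidate strategies under the reward table given above. The discount factor is not stated in the text; `gamma = 0.5` is my assumption, chosen because it roughly reproduces the 1.492 and -1.015 figures. The helper names and the 41-action horizon are also illustrative choices.

```python
REWARDS = {"cc": 0.5, "cd": 2.0, "dd": -0.74, "dc": -1.76}

def games(history):
    """Pair each action with the previous one: 'ccddc' -> ['cc', 'cd', 'dd', 'dc']."""
    return [history[i - 1] + history[i] for i in range(1, len(history))]

def discounted_return(history, gamma):
    return sum(gamma ** t * REWARDS[g] for t, g in enumerate(games(history)))

def alternating(first, length):
    other = "d" if first == "c" else "c"
    return "".join(first if i % 2 == 0 else other for i in range(length))

def candidates(first, length):
    """The three strategy shapes listed in the post, for a given start state."""
    other = "d" if first == "c" else "c"
    return [first * length, first + other * (length - 1), alternating(first, length)]

gamma = 0.5   # assumption: the discount is not given in the text
length = 41   # start state plus 40 actions (arbitrary horizon)

for first in "cd":
    for history in candidates(first, length):
        print(history[:4] + "...", round(discounted_return(history, gamma), 3))
```

Without discounting, always cooperating would earn more per game (0.5) than alternating (0.12 on average), so the ranking depends on the discount assumed here.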

## ACTDE results

The model does not collapse onto a pure strategy. Instead, the results are inconsistent across trials. However, ACTDE does reliably:

Here's the first 1K epochs of a training run:

Zooming out to all 10K epochs:

We ran 10 trials and plotted the mean and standard deviation of average returns:

There seems to be very slow convergence,^{[6]} perhaps towards the softmax-over-returns policy (shown by the dotted lines), or towards the uniform policy. We lean towards "convergence to uniform" due to evidence from a trial on a different reward matrix:

Overall, ACTDE's results are sensitive to variations in the algorithm such as whitening advantages, detaching the value and Q-heads, and using the loss function from PPO or ILQL for the value head.

## Speculation

This method might not work very well for e.g. RLHF at scale. Deep RL is notoriously finicky. Furthermore, it would be pretty disappointing if ACTDE generally converges on uniform policies, and that seems like a live possibility given the last graph above.

However, Alex has a few intuitions anyways:

- **Mitigating reward misspecification.** If the reward is really high in a situation which the policy can easily explore into during training, then that's bad and probably leads to distorted policies.
- **Reducing mode collapse.** In the tabular bandit regime (explored above), ACTDE adds "reward logits" ($e^{R(s,a)}$) onto the initial policy logits ($\log \pi_0(s,a)$). Maybe this is true in general (but probably not). If so, then KL penalties might be less important for keeping the trained policy $\pi$ close to the initial policy $\pi_0$.

## Summary

ACTDE seems to avoid mode collapse in simple tabular setups. We showed that ACTDE doesn't mode collapse on a toy prisoner's dilemma learning task, but instead trains a mixed strategy.

We're excited for someone to RLHF a language model using ACTDE. Alex is willing to contribute 30 minutes weekly to giving feedback on such a project, insofar as that would be helpful. If necessary, Alex can also help acquire funding for a prospective researcher who has experience doing deep RL. Email him at `alex@turntrout.com` if interested. Email Michael at `einhorn.michael1@gmail.com` for any questions about the code.

**Contributions:**^{[7]} … `trl_textworld`^{[8]} and `prisonerUnitTest`.

Thanks to Connor Leahy, Evan Hubinger, Ulisse Mini, Nate Soares, Leo Gao, Garrett Baker, janus, David Krueger and others for thoughts and discussions.

## Appendix: Random notes

^{^}This advantage equation, as given, can also be called the "TD error."

^{^}Alex thinks that using a fixed learning rate 0<α<1 shouldn’t fix PPO's "infinite logit update issue", but a decaying learning rate schedule probably does. This isn't that surprising, and he doesn't think it fixes the deeper potential issue with fluctuating value baselines.

^{^}Although Alex hasn't analyzed the sequential tabular setting — possibly infinite logit updating can still happen there?

^{^}Note that the `cd` and `dc` always come in pairs except for at most 1 extra.

^{^}return(s) averages strategy s's return over the first state being cooperate `c` and being defect `d`.

^{^}In the tabular bandit example above, the convergence was extremely fast due to the learning rate and triviality of the problem.

^{^}When Alex wrote this in the fall, he thought that RLHF was responsible for mode collapse behaviors in LMs. However, empirical evidence has since made him think that RLHF is less responsible for these failures. He thinks his theoretical analysis is still correct under the assumptions he made, and he still thinks it's important to investigate empirically.

^{^}One of the goals of `trl-textworld` was to evaluate PPO vs ACTDE finetunings on pretrained language models, but the models were not able to learn to play the text adventure, so this project did not get to a point where the algorithm's results could be compared. The implementation may still be useful: it has been tested up to GPT-NeoX 20B on 8 GPUs.