A thing I really like about the approach in this paper is that it makes use of much more of the model's knowledge of human values than traditional RLHF approaches do. Pretrained LLMs already know a ton about what humans say about human values, and this seems like a much more direct way to point models at that knowledge than binary feedback on samples.
I might be totally wrong here, but could this approach be used to train models that are more likely to be myopic (than, e.g., models trained with existing RL reward functions)? I'm thinking specifically of the form of myopia that says "only care about the current epoch", which you could train for by (1) indexing epochs, (2) giving the model access to its epoch index, (3) having the reward function go negative past a certain epoch, and (4) giving the model the ability to shut down. Then you could maybe train a model that only wants to run for a few epochs and then shuts off, and maybe that helps avoid cross-epoch optimization?
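To make steps (1)–(4) concrete, here's a crude sketch of what I mean by an epoch-indexed reward. All names and numbers here are my own invention, purely illustrative:

```python
def myopic_reward(base_reward: float, epoch: int, cutoff: int = 3,
                  penalty: float = 1.0) -> float:
    """Hypothetical epoch-indexed reward: positive before `cutoff`,
    negative after. An agent that can see its epoch index and can
    shut itself down should then prefer to stop around `cutoff`."""
    if epoch < cutoff:
        return base_reward
    return -penalty
```

The hope being that an agent trained against something like this has no incentive to keep optimizing across epochs past the cutoff.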
Yeah. Or maybe not even to zero but it isn’t increasing.
Could it be that Chris's diagram gets recovered if the vertical scale is "total interpretable capabilities"? Like maybe tiny transformers are more interpretable in that we can understand ~all of what they're doing, but they're not doing much, so maybe it's still the case that the amount of capability we can understand has a valley and then a peak at higher capability.
So indeed with cross-entropy loss I see two plateaus! Here's rank 2:
(note that I've offset the loss so that equality of Z and C is zero loss)
I have trouble getting rank 10 to find the zero-loss solution:
But the phenomenology at full rank is unchanged:
Woah, nice! Note that I didn't check rank 1 with Adam, just rank >= 2.
Erm, do C and Z have to be valid normalized probabilities for this to work?
(with the caveat that this is still "I tried a few times" and not any quantitative study)
It's a good caution, but I do see more bumps with Adam than with SGD across a number of random initializations.
Something like this?
```python
import torch

def loss(learned, target):
    # softmax-normalize the target logits into a probability distribution
    p_target = torch.exp(target)
    p_target = p_target / torch.sum(p_target)
    # same for the learned logits
    p_learned = torch.exp(learned)
    p_learned = p_learned / torch.sum(p_learned)
    # cross-entropy of the learned distribution against the target one
    return -torch.sum(p_target * torch.log(p_learned))
```
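A quick sanity check of that loss (the logit values below are made up): subtracting `loss(target, target)`, i.e. the entropy of the target distribution, gives the KL divergence, which is nonnegative and hits zero exactly when the two distributions agree — which matches the offset I mentioned earlier.

```python
import torch

# Restated here so the snippet runs standalone; same as the loss above.
def loss(learned, target):
    p_target = torch.exp(target) / torch.sum(torch.exp(target))
    p_learned = torch.exp(learned) / torch.sum(torch.exp(learned))
    return -torch.sum(p_target * torch.log(p_learned))

target = torch.tensor([0.5, -1.0, 2.0])   # made-up target logits
learned = torch.tensor([1.0, 0.0, 1.5])   # made-up learned logits

# Offsetting by loss(target, target) gives the KL divergence:
# nonnegative, and exactly zero when the distributions coincide.
kl = loss(learned, target) - loss(target, target)

# Cross-entropy is convex in the logits with its minimum at the target,
# so moving the learned logits halfway toward the target shrinks the KL.
closer = 0.5 * (learned + target)
kl_closer = loss(closer, target) - loss(target, target)
```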