I’ve recently been given funding from the Long Term Future Fund to develop work on an agenda I'll tentatively call Distilled Representations, and I'll be working on this full-time over the next 6 months with Misha Wagner (part time).
We're working on a way of training autoencoders so that they can only represent information in certain ways - ways that we can define in a flexible manner.
It works by training multiple autoencoders to encode a set of objects while, for some of those objects, defining a preferred representation that the autoencoders are encouraged to use. We then distill these multiple autoencoders into a single autoencoder which encodes only the information that is encoded in the same way across the different autoencoders. If we are correct, this new autoencoder should only encode information using the preferred strategy. Vitally, this can include not just the original information in the preferred representations, but also information represented by generalizations of that encoding strategy.
It is similar to work such as Concept Bottleneck Models but we hope the distillation from multiple models should allow interpretable spaces in a much broader range of cases.
The rest of this post gives more detail of the intuition that we hope to build into a useful tool, some toy experiments we’ve performed to validate the basic concepts, the experiments that we hope to build in the future, and the reasons we hope it can be a useful tool for alignment.
We'd like to make sure we understand what similar work has been done and where this work could be useful. If you're familiar with disentangled representations, or interpretability tools more generally, we're interested in having a chat. You can reach me here on LessWrong or at firstname.lastname@example.org.
Previous versions of similar ideas can be found in my ELK submission and especially Note-taking Without Hidden Messages.
The intuition that this work builds on is the following:
There are quite a few leaps in this reasoning, and we view the key assumptions / hypotheses to be tested as the following:
We've been able to validate all of these assumptions in toy scenarios using compression of random binary vectors. The question now is whether we'll be able to make this a useful tool which scales to more complex environments.
We train an autoencoder to reconstruct 2n-digit binary strings using a latent space of R^n. On top of the reconstruction loss, we introduce a representation loss, which is the gap between the latent vector and our preferred representation. This incentivizes the model to use the i-th latent variable to represent the i-th binary digit, for i from 1 to n. The (n+1)-th to 2n-th binary digits have no preferred representation.
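import torch
# encoder, decoder, N and random_binary_vector are defined elsewhere and not shown here.
vector = random_binary_vector()  # shape=2N
latent_vector = encoder(vector)  # shape=N
reconstructed_vector = decoder(latent_vector)  # shape=2N
preferred_representation = vector[:N]  # the first N digits are their own preferred encoding
# MSELoss takes (input, target); the value is the same in either order.
reconstruction_loss = torch.nn.MSELoss()(reconstructed_vector, vector)
representation_loss = torch.nn.MSELoss()(latent_vector, preferred_representation)
loss = reconstruction_loss + representation_loss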
We find that, perhaps unsurprisingly, the autoencoder learns to “hide” the second half of the vector in the representation of the first half: the encoder maps the first n digits closely to themselves, and hides the second n digits in small variations around the preferred representations.
To remove this additional, hidden information, we train multiple encoder/decoder pairs, and then freeze the weights of the decoders while training a new encoder from scratch. When training this new encoder, the encoded representations are decoded by a decoder that is randomly selected for each batch.
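To make the intuition concrete, here is a hand-built toy in plain Python. It is not the trained setup described above: the encoders and decoders are fixed functions written by hand, and the names `make_key`, `encode`, `decode` and `half_accuracies`, along with `N = 4` and `EPS = 0.05`, are hypothetical choices made purely for illustration. Each encoder/decoder pair sends the first half of the vector through the shared preferred representation and hides the second half in small offsets whose sign convention is private to that pair.

```python
import random

N = 4       # half-length of the binary vector (hypothetical toy size)
EPS = 0.05  # size of the hidden perturbation (hypothetical)

def make_key(rng):
    """Private per-pair key that decides how hidden bits are signed."""
    return [rng.randint(0, 1) for _ in range(N)]

def encode(vector, key):
    """Put the first half in the preferred slots; hide the second half in small offsets."""
    first, second = vector[:N], vector[N:]
    return [f + EPS * (1 if (s ^ k) else -1) for f, s, k in zip(first, second, key)]

def decode(latent, key):
    """Recover the first half by rounding and the second half from the offset sign."""
    first = [round(z) for z in latent]
    second = [int(z - f > 0) ^ k for z, f, k in zip(latent, first, key)]
    return first + second

def half_accuracies(enc_key, dec_key, vectors):
    """Fraction of correctly recovered bits in the first and second halves."""
    hits_first = hits_second = 0
    for v in vectors:
        out = decode(encode(v, enc_key), dec_key)
        hits_first += sum(a == b for a, b in zip(v[:N], out[:N]))
        hits_second += sum(a == b for a, b in zip(v[N:], out[N:]))
    total = N * len(vectors)
    return hits_first / total, hits_second / total

rng = random.Random(0)
keys = [make_key(rng) for _ in range(8)]  # eight hand-built pairs
vectors = [[rng.randint(0, 1) for _ in range(2 * N)] for _ in range(200)]

matched = half_accuracies(keys[0], keys[0], vectors)          # encoder with its own decoder
mixed = [half_accuracies(keys[0], k, vectors) for k in keys]  # encoder 0 with every decoder
print("matched:", matched)
print("mixed:", mixed)
```

With a matched pair both halves come back intact. With a decoder from a different pair the first half is still recovered exactly, while second-half bits are flipped wherever the two private keys disagree, so a decoder picked at random gets little reliable use out of the hidden channel.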
What we find is that when training 8 separate autoencoders (averaged into the blue line), the information which has a preferred representation is encoded very quickly (p1_representation_loss), while the other information is learned more slowly by the separate models (p2_representation_loss). The representation loss is the loss from the gap between the representation and the preferred representation.
Then, when training a new set of encoders (orange line) whose outputs are decoded by randomly chosen decoders with frozen weights, the information that had a preferred representation is again communicated very quickly. Meanwhile the rest of the information, being encoded in different ways, is mostly not successfully encoded. In fact it consistently does worse than having no information at all, though this is somewhat strange and is not true when retraining decoders (green line). We can tell it's close to, or worse than, communicating no useful information because it's not far below (or is even above) the red line, which is the error with zero information (always guessing 0.5 for every binary value).
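The zero-information baseline can be checked by hand. This is a minimal sketch assuming the plotted error is the same mean squared error used in the training loss; the helper name `constant_guess_mse` is hypothetical. For a binary target, guessing 0.5 costs 0.5^2 = 0.25 whichever value the bit takes, and any other constant guess does worse on balanced bits.

```python
def constant_guess_mse(guess, bits):
    """Mean squared error of predicting the same value for every bit."""
    return sum((guess - b) ** 2 for b in bits) / len(bits)

# A balanced sample of binary values, as expected from uniform random vectors.
bits = [0, 1] * 50

baseline = constant_guess_mse(0.5, bits)
other_guesses = {g: constant_guess_mse(g, bits) for g in (0.0, 0.25, 0.75, 1.0)}
print(baseline)
print(other_guesses)
```

Under that assumption, an error above 0.25 on the second half of the vector means the decoder is doing worse than it would by ignoring the latent vector entirely.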
It's trivial to force a correct representation if we always know exactly what our preferred representation should be. The test is whether we can do this in cases where the preferred representation is sparse and incomplete. We can make this signal sparse, restricting it to a small fraction of the orthants (the n-dimensional analogue of quadrants), and find that this also works. Below is the same setup, but where the preferred representation is only available for fewer than 20% of orthants, and even then only added 10% of the time, so the preferred representation is only given in under 1 in 50 cases.
The preferred representation is only included when: sum(vector) < 4 and random.random() < 0.1
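As a quick check on these proportions, the sketch below counts how often the condition `sum(vector) < 4 and random.random() < 0.1` holds. It assumes the sum runs over all digits of a uniformly random binary vector and uses a hypothetical length of 10 digits, which is not stated in the text; `fraction_below` and `has_preferred_representation` are illustrative names.

```python
import random
from math import comb

def fraction_below(length, threshold=4):
    """Fraction of binary vectors of this length with fewer than `threshold` ones."""
    return sum(comb(length, k) for k in range(threshold)) / 2 ** length

def has_preferred_representation(vector, rng, keep_probability=0.1):
    """Apply the sparsity condition: few ones, and then only some of the time."""
    return sum(vector) < 4 and rng.random() < keep_probability

LENGTH = 10  # hypothetical vector length, chosen for illustration
eligible = fraction_below(LENGTH)  # share of vectors with sum < 4
labelled = eligible * 0.1          # share that actually receives a preferred representation

# Empirical check on random vectors with a fixed seed.
rng = random.Random(0)
sample = [[rng.randint(0, 1) for _ in range(LENGTH)] for _ in range(20000)]
observed = sum(has_preferred_representation(v, rng) for v in sample) / len(sample)
print(eligible, labelled, observed)
```

With ten digits, 176 of the 1024 possible vectors satisfy `sum(vector) < 4`, about 17%, and keeping one in ten of those leaves under 2% of cases with a preferred representation. That is in line with the figures quoted above, though the true vector length may differ.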
These results with sparse inputs show that the method can work well even when preferred representations are not always present and never cover the full range of inputs. We also see that the model learns to encode information in the preferred representation much faster than it adds additional information, so we would get the same property simply by stopping early, though we don't expect this effect to be as stark in more complex environments.
We have also got similar results for cases where the preferred representation is more complex than just the identity function of a subset of the input, though the learning process is slower.
Still, this requires only a very basic kind of generalization compared to the kind we would hope for in an ELK scenario: ideally, using language as humans do to represent situations, thereby being a direct translator rather than a human simulator. We'll therefore test the setup on progressively more complex environments.
This idea comes from working on the Eliciting Latent Knowledge (ELK) challenge and is basically the combination of two posts in the ELK prizes - 'Train a reporter that is useful for another AI' (my own, amongst others) and 'Train different reporters and require them to agree'. If you are the person who suggested the latter then thank you!
The picture that I think is hopeful for this kind of system starts with the background of Eliciting Latent Knowledge, in which we have a powerful system which we can ask questions of in certain situations where it knows the answer. These may be detailed Q&As or it may simply be a case where we know the state of the system and so we know the world-variables that it should be representing.
This information represents a preferred representation, and we restrict this representation to cases where we're confident that it doesn't favour the human simulation over the honest translator. We then train these systems to compress as much of their internal state as possible into the answers to these questions, while ensuring that known information uses the preferred representation. The need to compress the information should then ensure we get something that is a direct translator, but much of this 'translation' will not be in an interpretable form, and so we use the multi-encoder strategy to get only the parts of the direct translator that we are able to understand.
This work can be thought of as trying to unify question-answering and autoencoders. Question-answering gives an easily comprehensible response but it's hard to ensure that it actually contains the right information. Autoencoders condense the information, but not in a way that is easily interpretable. The question that this work hopes to answer is 'Can we use question-answering as a form of autoencoder, without hiding a pile of incomprehensible information through the answers?'. This technique of distilled representations is a step towards a working unification of these two ways of understanding a system.
While this is the most ambitious form of hope for this approach, we think that there's potential for it to be useful as an interpretability tool.
It can be seen as a contribution to the related field of disentangled representations, which creates latent representations, usually of images, in which the dimensions correspond to human-identifiable concepts.
This procedure could potentially be used to automatically make these, by labeling images according to certain dimensions, encoding images while deriving preferred representations from these labels, and then distilling out non-shared variability in encoding to get a latent space where the variables strictly only vary the dimensions of interest.
It also may be possible to use it within networks to create neurons which only respond in particular situations without having any additional meanings, making mechanistic interpretability easier.
There have been proposals for using chain-of-thought reasoning to make superhuman reasoning interpretable by default, and concerns that this could lead to steganography in the 'thoughts'.
I hope that this kind of system could be used to prevent a chain-of-thought reasoner from hiding additional information in what appear to be trivially understandable written statements. (In this situation the preferred representation would be implicitly created by the large pre-training dataset and refined during fine-tuning phases.)
All work done in collaboration with Misha Wagner. Graphs and code on GitHub. Work supported by the Long Term Future Fund.
Hey - recommend looking at this paper: https://arxiv.org/abs/1807.07306. It shows a more elegant way than KL regularization for bounding the bit-rate of an auto-encoder bottleneck. This can be used to find the representations which are most important at a given level of information.