When I first heard about EfficientZero, I was amazed that it could learn at a sample efficiency comparable to humans. What's more, it was doing it without the gigantic amount of pre-training that humans have, which I'd always felt made comparing sample efficiencies with humans rather unfair. I also wanted to practice my ML programming, so I thought I'd make my own version.
This article uses what I've learned to give you an idea, not just of how the EfficientZero algorithm works, but also of what it looks like to implement in practice. The algorithm itself has already been well covered in a LessWrong post here. That article inspired me to write this and if it's completely new to you it might be a good place to start - the focus here will be more on what the algorithm looks like as a piece of code.
The code below is all written by me and comes from a cleaned and extra-commented version of EfficientZero which draws from the papers (MuZero, EfficientZero), the open implementation of MuZero by Werner Duvaud, the pseudocode provided by the MuZero paper, and the original implementation of EfficientZero.
You can have a look at the full code and run it on GitHub. It's currently functional and works on trivial games like cartpole, but it struggles to learn much on Atari games within a reasonable timeframe; I'm not certain whether this reflects an error or just insufficient time. Testing on my laptop or Colab for Atari games is slow - if anyone could give access to some compute to do proper testing that would be amazing!
Grateful to Misha Wagner for feedback on both code and post.
EfficientZero is based on MuZero, which itself is based on AlphaZero, a refinement of the architecture which was the first to beat the Go world champion. With AlphaZero, you play a deterministic game, like chess, by developing a neural network that evaluates game states, associating each possible state of the board with a value, the discounted expected return (in zero-sum games like chess there is no discounting, i.e. the discount factor is 1, and this is just win%). Since the algorithm can have access to a game 'simulator', it can test out different moves, and responses to those moves, before actually playing them. More specifically, from an initial game state it can traverse the tree of potential games, making different moves, playing against itself, and evaluating these derived game states. After traversing this tree, and seeing the quality of the states reached, we can average the values of the derived states to get a better estimate of how good that initial game state actually was, and make our final move based on these estimates.
When playing out these hypothetical games, we are playing roughly according to our policy, but if we start finding that a move that looked promising leads to bad situations we can start avoiding that, thereby improving on our original policy. In the limit, this constrains our position evaluation function to be consistent with itself, meaning that if position A is rated highly, and our response in that situation would be to move to position B, then B should also be rated highly, etc. This allows us to maximize the value of our training data, because if we learn that state C is bad, we will also learn to avoid states which would lead to C and vice versa.
Note that this constraint is similar to that enforced by the Minimax algorithm, but AZ and its descendants propagate the average value of the found states, rather than the minimum, up the tree to avoid compounding NN error.
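To make the difference between the two backup rules concrete, here is a toy sketch (not taken from any of the implementations) of what happens when a single leaf evaluation is badly wrong. The function names and the numbers are made up for illustration, and real search weights children by visit count rather than uniformly.

```python
def minimax_backup(reply_values):
    # Minimax assumes the opponent picks the reply that is worst for us,
    # so the backed-up value is the minimum over the replies.
    return min(reply_values)


def average_backup(reply_values):
    # Averaging instead takes the mean of the values found below this move.
    return sum(reply_values) / len(reply_values)


# Values (from our point of view) of three replies to one candidate move.
true_values = [0.6, 0.5, 0.7]
# The same replies, but with one leaf badly mis-evaluated by the network.
noisy_values = [0.6, 0.5, -0.4]

minimax_shift = abs(minimax_backup(noisy_values) - minimax_backup(true_values))
average_shift = abs(average_backup(noisy_values) - average_backup(true_values))
print(minimax_shift, average_shift)
```

The single bad evaluation moves the minimax estimate by the full size of the error, while the averaged estimate only moves by a fraction of it.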
While AlphaZero was very impressive, from a research direction, it seemed (to me) fundamentally limited by the fact that it requires a fully deterministic space in which to play - to search the tree of potential moves you need the existence of a 'board state', to which you can apply an action (e.g. 'knight to e5', Ne5), and get a new board state. This simply doesn't exist in most RL domains so how could this algorithm be used in any of these domains?
MuZero's solution is to incorporate ideas from model-based learning. It learns a mapping from observations (which could be game boards, but could also be images of Atari games, or Starcraft etc.) to a latent vector, which is called the representation function. The game simulator, instead of being e.g. a chess program, is again just a learned mapping, this time from a latent vector and an action to a new latent vector, which is called the dynamics function.
But how should these dynamics and representation functions be learned? What is the 'correct' mapping to vector space? The answer MuZero gives is that all you need to do is train the system end-to-end and correct functions for both of these will be learned! The surprising fact is that this works so well that it can actually outperform AlphaZero, even though AZ has access to a perfect game simulator!
EfficientZero takes MuZero as a base and then makes a number of changes to it, designed to make it much more sample efficient. We'll cover these in detail below, once we've established what a MuZero system looks like in practice.
Now onto how we would actually make such a system!
There are three neural networks, the Representation network, the Dynamics network, and the Prediction network. These are three separate networks, each part of a larger class, and called with separate functions (though they can be neatly combined into two functions, initial_inference(observation), and recurrent_inference(latent_vector)).
Their type signatures are below, and will make more sense once the algorithm is described in detail.
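As a rough sketch of these signatures, using plain sequences of floats as stand-ins for tensors: the method names represent, dynamics and predict come from the code in this post, while the type aliases and the exact argument lists of the two combined functions are assumptions for illustration.

```python
from typing import Protocol, Sequence, Tuple

Latent = Sequence[float]  # stand-in for a latent tensor
Logits = Sequence[float]  # stand-in for unnormalized network outputs


class MuZeroNetwork(Protocol):
    def represent(self, observation: Sequence[float]) -> Latent:
        """Observation -> latent."""
        ...

    def dynamics(self, latent: Latent, action: Sequence[int]) -> Tuple[Latent, Logits]:
        """(latent, one-hot action) -> (new latent, reward logits)."""
        ...

    def predict(self, latent: Latent) -> Tuple[Logits, Logits]:
        """Latent -> (policy logits, value logits)."""
        ...


def initial_inference(net: MuZeroNetwork, observation):
    # Used once per search, at the root of the tree
    latent = net.represent(observation)
    policy_logits, value_logits = net.predict(latent)
    return latent, policy_logits, value_logits


def recurrent_inference(net: MuZeroNetwork, latent, action):
    # Used every time the search expands a new node
    new_latent, reward_logits = net.dynamics(latent, action)
    policy_logits, value_logits = net.predict(new_latent)
    return new_latent, reward_logits, policy_logits, value_logits
```

Note that the recurrent step needs the chosen action as well as the latent vector, since the dynamics network is conditioned on both.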
The exact shapes of these can vary heavily between different implementations of this general pattern. For example the dynamics network in MuZero consisted of 16 ResNet layers, while in EfficientZero there is just one. Also, the latent space can be just a single vector, or in the case of Atari games, a two dimensional tensor mirroring the structure of the input image.
The core of using the MuZero algorithm to play a game is building the tree structure by which the algorithm explores the tree of possible moves, and therefore decides what to do.
Looking at Figure 1 of the original MuZero paper (below), we see the creation of a tree structure using the represent function (h) to take the board state into a latent representation, the predict function (f) to predict the policy and value that the algorithm will reach, and the dynamics function (g), to simulate taking an action within the latent space.
The open implementation has this section written in C++ to minimize the time taken to create this tree, but I've not found this to be a bottleneck. Generating 100k steps of game play on Atari using EfficientZero and tens of rollouts per move is manageable on a CPU, taking perhaps 10h, whereas training for long enough takes a long time and good optimization even with a (single) GPU.
Translating this abstract structure into code, the basic idea is that you have a tree of nodes which represents the state of your exploration. The algorithm is built for the case where the action space is finite, and so each node has an array of slots to hold each potential child.
class TreeNode:
    """
    TreeNode is an individual node of a search tree.
    It has one potential child for each potential action which,
    if it exists, is another TreeNode.
    Its function is to hold the relevant statistics for
    deciding which action to take.
    """

    def __init__(
        self, action_space_size, latent, policy_pred=None, value_pred=None, ...
    ):
        self.action_size = action_space_size
        # These will be filled with other TreeNodes
        self.children = [None] * action_space_size
        # Holding the latent vector and the predicted policy and value
        self.latent = latent
        self.value_pred = value_pred
        self.policy_pred = policy_pred
The initial node is created by first getting a latent representation of the observation, and using the prediction network to estimate the value and predict the eventual action distribution:
# tensor.unsqueeze(0) adds an extra dimension of size 1 at the 0th dimension
# which is needed as the network is designed to take batched inputs.
frame_t = torch.tensor(current_frame, device=device).unsqueeze(0)
# These can be brought together as 'initial_inference'.
init_latent = mu_net.represent(frame_t)
init_policy_logits, init_value = mu_net.predict(init_latent)
# Drop the batch dimension again, as we only have a single example
init_latent = init_latent[0]
init_policy_logits = init_policy_logits[0]
init_value = init_value[0]
From the logits of the predicted action distribution (init_policy_logits) from the prediction network, we get our final probabilities for how we will begin to explore the tree by:
init_policy_probs = torch.softmax(init_policy_logits, 0)
# This adds some noise to the probabilities of taking an action
# to encourage exploration
init_policy_probs = add_dirichlet(init_policy_probs, ...)
root_node = TreeNode(init_latent, init_policy_probs, ...)
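The body of add_dirichlet isn't shown here, so the following is only a guess at its shape: a pure-Python sketch that mixes the prior with a sample from a symmetric Dirichlet distribution, as AlphaZero and MuZero do at the root. The argument names and the default values of alpha and explore_frac are assumptions for illustration, not the values used in my code.

```python
import random


def add_dirichlet(probs, alpha=0.3, explore_frac=0.25, rng=random):
    """Mix a probability vector with symmetric Dirichlet(alpha) noise."""
    # A Dirichlet sample can be drawn by normalizing independent Gamma(alpha, 1) draws
    gammas = [rng.gammavariate(alpha, 1.0) for _ in probs]
    total = sum(gammas)
    noise = [g / total for g in gammas]
    # Keep most of the prior, but give every action some chance of being explored
    return [(1 - explore_frac) * p + explore_frac * n for p, n in zip(probs, noise)]


random.seed(0)
noisy_probs = add_dirichlet([0.7, 0.2, 0.1])
print(noisy_probs)
```

Because both the prior and the noise sum to one, the result is still a valid probability distribution, and no action's probability can fall below (1 - explore_frac) times its prior.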
With this root node we have the basis for our exploration tree and can begin to populate it. One of the key hyperparameters is the number of simulations (config["n_simulations"]), which is the number of times we explore, starting from the root node. In the MuZero paper, this number was 800 and 50 during training for Go and Atari respectively, but it can be pushed much higher during evaluation to boost performance - indeed one of the key results is that the learned dynamics function is sufficiently good that you can raise the number of simulations orders of magnitude above that used in training and still get performance boosts.
The config object is a dictionary of hyperparameters and anything else that could plausibly be changed between runs. For ease it's passed to most functions so they have access to whatever they might need.
# Traversing the tree of possible game decisions
for i in range(config["n_simulations"]):
    # It's vital to have `with torch.no_grad():` or else the size
    # of the computation graph quickly becomes gigantic and we're
    # not training here but evaluating what to do
    with torch.no_grad():
        current_node = root_node
        new_node = False  # Tracks whether we have reached a new node yet
        # tracks the route of the simulation through the tree
        search_list = []
        # We traverse the graph by picking actions, until we reach a new node
        # at which point we revert back to the initial node.
        while not new_node:
            # Decide which action we will 'take' in this tree of potential decisions
            action = current_node.pick_action()
            if current_node.children[action] is None:
                # If this action hasn't been taken then we'll need
                # to do a forward pass of the dynamics and prediction function
                # to evaluate the resulting state

                # Getting the action as a one-hot vector
                action_t = nn.functional.one_hot(
                    torch.tensor([action], device=device),
                    num_classes=len(current_node.children),
                )

                # Simulate the state transition when taking the chosen action
                # and get the predicted policy and value at that node
                # This can be brought together as 'recurrent_inference'
                latent, reward = [
                    x[0]
                    for x in mu_net.dynamics(
                        current_node.latent.unsqueeze(0), action_t
                    )
                ]
                new_policy, new_value = [
                    x[0] for x in mu_net.predict(latent.unsqueeze(0))
                ]

                # Now that we've evaluated this new position,
                # we call the insertion function,
                # which will put a new TreeNode into current_node.children[action].
                current_node.insert(action, latent, new_policy, new_value, ...)
                new_node = True
            else:
                # If we have already explored this node then we take the
                # child as our new current node and repeat
                current_node = current_node.children[action]
Here the pick_action function is doing a lot of work in deciding the form of our exploration.
We pick the action with the following function:
def pick_action(self):
    """
    Gets the score of each of the potential actions and picks the one with the highest
    """
    total_visit_count = sum(a.num_visits for a in self.children if a)
    scores = [
        self.action_score(a, total_visit_count)
        for a in range(self.action_size)
    ]
    max_score = max(scores)
    # Need to be careful not to always pick the first action as it is common
    # that two are scored identically
    action = np.random.choice(
        [a for a in range(self.action_size) if scores[a] == max_score]
    )
    return action
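The action score function has the following formula, which is designed to calculate the upper confidence bound of an action:
$$\text{score}(s, a) = Q(s, a) + P(s, a) \cdot \frac{\sqrt{\sum_b N(s, b)}}{1 + N(s, a)} \cdot \left( c_1 + \log\left( \frac{\sum_b N(s, b) + c_2 + 1}{c_2} \right) \right)$$
Here N(s,a) is the number of times action a has been tried from state s, Q(s,a) is the average value found for it so far, and P(s,a) is the prior from the policy prediction.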
This formula is more complicated than the actual work it does, because the constants used in the paper are c1=1.25 and c2=19652, which means that with on the order of 100 simulations, the argument of the final log never differs much from one, so the log stays close to zero and the whole bracket is roughly the constant c1. The rest is a balance between the score that has been found so far, Q(s,a), and the product of the prior P(s,a) and an explore term favouring new actions. The impact of the explore term is to dilute the strength of the prior as the number of simulations grows, so that a strong prior does not overcome poor empirical results after multiple tries, and actions that score poorly on priors can still be tried. I discuss this formula and its mathematical source in more detail here.
def action_score(self, action_n, total_visit_count):
    """
    Scoring function for the different potential actions,
    following the formula in Appendix B of MuZero
    """
    # Constants used in the MuZero paper
    c1, c2 = 1.25, 19652

    child = self.children[action_n]

    n = child.num_visits if child else 0
    # minmax.normalize interpolates the value between the highest
    # and lowest values seen in the run so far.
    q = self.minmax.normalize(child.average_value) if child else 0

    prior = self.policy_pred[action_n]

    # This term increases the prior on those actions which have been taken
    # only a small fraction of the current number of visits to this node
    explore_term = math.sqrt(total_visit_count) / (1 + n)

    # This is intended to more heavily weight the prior
    # as we take more and more actions.
    # Its utility is questionable, because with on the order of 100
    # simulations, the log is close to 0 and this term stays very close to c1.
    balance_term = c1 + math.log((total_visit_count + c2 + 1) / c2)

    score = q + (prior * explore_term * balance_term)
    return score
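Two pieces that the scoring function relies on but which aren't shown here are the minmax object and the step that fills in num_visits and average_value after each simulation. Below is a minimal sketch of both, modelled on the pseudocode published with the MuZero paper for the single-player case. The Node class is a hypothetical stand-in for TreeNode, and value_sum, reward and backpropagate are names picked for illustration.

```python
class MinMaxStats:
    """Tracks the lowest and highest values seen so far in the search."""

    def __init__(self):
        self.minimum = float("inf")
        self.maximum = float("-inf")

    def update(self, value):
        self.minimum = min(self.minimum, value)
        self.maximum = max(self.maximum, value)

    def normalize(self, value):
        # Until two distinct values have been seen there is no range to scale by
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value


class Node:
    """Hypothetical stand-in for TreeNode, holding only the search statistics."""

    def __init__(self, reward=0.0):
        self.reward = reward  # predicted reward for the step into this node
        self.num_visits = 0
        self.value_sum = 0.0

    @property
    def average_value(self):
        return self.value_sum / self.num_visits if self.num_visits else 0.0


def backpropagate(search_list, leaf_value, minmax, discount=0.997):
    """Pushes the value found at the end of a simulation back up to the root."""
    value = leaf_value
    for node in reversed(search_list):
        node.value_sum += value
        node.num_visits += 1
        minmax.update(node.average_value)
        # One step closer to the root we also collect this node's reward
        value = node.reward + discount * value


minmax = MinMaxStats()
root, child = Node(), Node(reward=1.0)
backpropagate([root, child], leaf_value=4.0, minmax=minmax, discount=0.5)
print(root.average_value, child.average_value, minmax.normalize(3.5))
```

Normalizing Q into the range seen so far is what lets the same exploration constants work across games whose rewards have very different scales.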
Now that we know how to create and traverse this tree of simulations, we can roll this into a search() function, return the root node, and play the game as follows:
while not over and frames < config["max_frames"]:
    # Makes a single tensor from past frames and actions, varies by game type.
    frame_input = get_frame_input()
    # tree is the root node of the exploration tree defined earlier
    tree = search(
        config, mu_net, frame_input, minmax, log_dir, device=device
    )
    # pick_game_action is a function which looks at the tree we've
    # generated and decides which action to take in the actual game,
    # based on the number of visits we've made to each node.
    # temperature defines how noisy we are in picking the most chosen action
    # in the tree
    action = tree.pick_game_action(temperature=temperature)

    # Taking the next action in the environment
    frame, reward, over, _ = env.step(action)

    # Adding the details of this step to the object which saves the trajectory
    game_record.add_step(frame, action, reward, tree)
    frames += 1
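The pick_game_action method isn't shown, so here is a minimal sketch of how such a function could work, following the rule in the MuZero paper that actions are sampled in proportion to visit_count ** (1 / temperature). It is written as a free function over a list of visit counts rather than as a TreeNode method, and that signature is an assumption for illustration.

```python
import random


def pick_game_action(visit_counts, temperature=1.0, rng=random):
    """Chooses the real game action from the root's child visit counts."""
    actions = range(len(visit_counts))
    if temperature == 0:
        # Zero temperature means always playing the most visited action
        return max(actions, key=lambda a: visit_counts[a])
    # Higher temperatures flatten the distribution, lower ones sharpen it
    weights = [count ** (1.0 / temperature) for count in visit_counts]
    return rng.choices(actions, weights=weights, k=1)[0]


random.seed(0)
print(pick_game_action([3, 40, 7], temperature=0))
print(pick_game_action([3, 40, 7], temperature=1.0))
```

Actions that were never visited get a weight of zero and so are never played, while a temperature of 1 plays each action exactly as often as the search visited it.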
We can now play the game using MuZero, and just need to save the results in order to train and learn. We don't need to save the whole tree, just the visit counts (which is our policy), action, rewards, values and observations. The observations for Atari games can quickly get large so these are converted into np.uint8 arrays before saving to minimize their footprint.
Training a network of this type is quite ordinary in many ways but the structure of the system, in which we learn a recurrent dynamics network, requires a bit of extra work.
The first question is what are our training targets?
Just as important though, we also want to train our representation and dynamics functions to be able to simulate a trajectory of the game. To do this, if we include in our batch step i of game j, the batch will contain the observation at i, but the rewards, values, and policies at steps i : i + config["rollout_depth"]. We can then turn the observation into the latent with represent(obs), use the actions taken in game to apply dynamics(latent, action) to this latent multiple times, and then predict the rewards, value and policies for each of these multiple steps with predict(latent). The resulting loss, when backpropagated, will train not just the predict, but also the represent and dynamics functions, all in one step!
weights are included because we want to train more often on the cases where our value guesses have been incorrect, but these need to be down-weighted a corresponding amount so as not to bias the value network.
depths are included as there will not always be enough time left in the game to do a full rollout, and so the depths tensor records how many steps of data each example actually has.
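The weighting scheme described above is the standard prioritized replay trick. As a sketch of the arithmetic only (the function name and the alpha and beta exponents are assumptions for illustration, with both set to 1 as in the MuZero paper): examples are sampled in proportion to their priority, and each sampled example's loss is scaled by the inverse of how over-sampled it was.

```python
def sampling_probs_and_weights(priorities, alpha=1.0, beta=1.0):
    """Returns sampling probabilities and the loss weights that correct for them."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    probs = [s / total for s in scaled]
    n = len(priorities)
    # Uniform sampling would give each example probability 1/n, so an example
    # sampled k times more often than that has its loss scaled down by k
    weights = [(1.0 / (n * p)) ** beta for p in probs]
    return probs, weights


# Priorities here could be e.g. the absolute error of the last value prediction
probs, weights = sampling_probs_and_weights([1.0, 3.0])
print(probs, weights)
```

With beta set to 1, the probability of sampling an example times its weight is the same for every example, which is what keeps the value network's expected gradient unbiased.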
The overall batch therefore looks as follows, with the first two dimensions of each target tensor being batch_size and rollout_depth:
(
    images,
    actions,
    target_values,
    target_rewards,
    target_policies,
    weights,
    depths,
) = ray.get(next_batch)
We need to do this training within a for loop, rather than as a single forward pass, because the dynamics function requires the output of the previous dynamics function. The dynamics function is therefore a single iteration of a recurrent neural network (and getting a recurrent reinforcement learning setup to train correctly can be as fiddly as it sounds). The need for so many different forward passes makes training each batch quite slow, and could probably be significantly optimized.
for i in range(config["rollout_depth"]):
    # The screen_t tensor allows us to remove all cases where
    # there are fewer than i steps of data
    screen_t = torch.tensor(depths) > i
    if torch.sum(screen_t) < 1:
        continue

    target_value_step_i = target_values[:, i]
    target_reward_step_i = target_rewards[:, i]
    target_policy_step_i = target_policies[:, i]

    pred_policy_logits, pred_value_logits = mu_net.predict(latents)
    new_latents, pred_reward_logits = mu_net.dynamics(latents, one_hot_actions)

    # We scale down the gradient, I believe so that the gradient
    # at the base of the unrolled network converges to a maximum
    # rather than increasing linearly with depth
    new_latents.register_hook(lambda grad: grad * 0.5)

    pred_values = support_to_scalar(
        torch.softmax(pred_value_logits[screen_t], dim=1)
    )
    pred_rewards = support_to_scalar(
        torch.softmax(pred_reward_logits[screen_t], dim=1)
    )

    value_criterion = torch.nn.MSELoss()
    reward_criterion = torch.nn.MSELoss()
    value_loss = value_criterion(pred_values, target_value_step_i[screen_t])
    reward_loss = reward_criterion(pred_rewards, target_reward_step_i[screen_t])
    policy_loss = mu_net.policy_loss(
        pred_policy_logits[screen_t], target_policy_step_i[screen_t]
    )

    batch_policy_loss += (policy_loss * weights[screen_t]).mean()
    batch_value_loss += (value_loss * weights[screen_t]).mean()
    batch_reward_loss += (reward_loss * weights[screen_t]).mean()
    latents = new_latents
This is a bit of a wall of code but basically what we're doing is to build up the losses by unrolling, screening at each step to remove games that have finished, and scaling down the gradient at each step so that the gradient converges to a finite value rather than scaling linearly with depth.
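The effect of the gradient scaling can be checked with a little arithmetic. As a simplified sketch, assume every unroll step sends back a gradient of the same size: the contribution from step k then passes through k hooks on its way to the initial latent, so it arrives scaled by 0.5 ** k. The function name is made up for illustration.

```python
def base_gradient_scale(depth, scale=0.5):
    """Total gradient reaching the initial latent, in units of one step's gradient."""
    return sum(scale ** k for k in range(depth))


for depth in (1, 5, 20):
    print(depth, base_gradient_scale(depth), base_gradient_scale(depth, scale=1.0))
```

With scale=0.5 this is a geometric series that can never exceed 2 however deep the unroll, whereas with no scaling (scale=1.0) the total simply equals the depth.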
The network is unrolled to a particular depth, here called config["rollout_depth"], which is always set to 5, but each individual example in a batch may not be this deep, because the game may end in fewer than 5 steps.
When we finally backpropagate, we train the entire system with a single call to optimizer.step().
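# Zero the gradients in the computation graph and then
# propagate the loss back through it
optimizer.zero_grad()
# Summing the three losses with equal weight is an assumption here
batch_loss = batch_policy_loss + batch_value_loss + batch_reward_loss
batch_loss.backward()
# I've found clipping the gradient is very important for training stability.
if config["grad_clip"] != 0:
    torch.nn.utils.clip_grad_norm_(mu_net.parameters(), config["grad_clip"])
optimizer.step()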
One notable detail is the use of support_to_scalar functions (and their inverse, scalar_to_support). These are a slightly peculiar piece of MuZero, by which the value and reward functions, although they are ultimately predicting a scalar, actually predict logits of a distribution over numbers. The numbers represented by each position in the predicted 'support' vector are roughly proportional to the square of their centered position, so a support of width 5 would correspond to values roughly [−4,−1,0,1,4], and logits which softmax to [0.5,0.5,0,0,0] would correspond to a final value of -2.5 (although the details are slightly more complex).
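For a concrete picture of the 'slightly more complex' details, here is a pure-Python sketch based on the scaling function given in the appendix of the MuZero paper, h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x with eps = 0.001. My own functions work on batched tensors, so the signatures below (a single scalar and an explicit half_width) are assumptions for illustration.

```python
import math

EPS = 0.001


def value_transform(x):
    # h(x): squashes large values so that a small support covers a wide range
    return math.copysign(math.sqrt(abs(x) + 1) - 1, x) + EPS * x


def inverse_value_transform(y):
    # Closed-form inverse of h
    root = (math.sqrt(1 + 4 * EPS * (abs(y) + 1 + EPS)) - 1) / (2 * EPS)
    return math.copysign(root ** 2 - 1, y)


def scalar_to_support(x, half_width):
    """Encodes a scalar as probabilities over the integers -half_width..half_width."""
    y = max(-half_width, min(half_width, value_transform(x)))
    low = math.floor(y)
    frac = y - low
    probs = [0.0] * (2 * half_width + 1)
    # Split the probability mass between the two neighbouring integers
    probs[low + half_width] += 1 - frac
    if frac > 0:
        probs[low + 1 + half_width] += frac
    return probs


def support_to_scalar(probs):
    """Decodes by taking the expected position and undoing the squashing."""
    half_width = (len(probs) - 1) // 2
    expected = sum(p * (i - half_width) for i, p in enumerate(probs))
    return inverse_value_transform(expected)


encoded = scalar_to_support(2.5, half_width=2)
print(encoded, support_to_scalar(encoded))
```

The training target for the value and reward heads is then the encoded distribution, which turns a regression problem into a classification-style one.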
This is the addition mentioned in MuZero reanalyse, and basically reassesses the values and policies in past games.
More specifically, the target 'value' is the discounted sum of the next config["value_depth"] = 5 steps of actual reward, plus the estimated future reward after these 5 steps. While clearly not a perfect picture of value, this is enough to bootstrap the value estimating function. This target value will be worse if the value estimation function is worse, which means that the older value estimates will provide a worse signal, and so the reanalyser goes through old games, and updates the value estimates using the new, updated value function.
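As a small worked sketch of this target: the function below is hypothetical and assumes that rewards[t] is the reward received after acting at step t and that values[t] is the search value stored for step t. The indexing in my buffer code may differ.

```python
def value_target(rewards, values, start, depth=5, discount=0.997):
    """Discounted sum of `depth` real rewards plus a bootstrapped value estimate."""
    end = min(start + depth, len(rewards))
    target = sum(discount ** k * rewards[start + k] for k in range(end - start))
    if end < len(values):
        # Bootstrap from the (possibly reanalysed) value estimate at the final step
        target += discount ** (end - start) * values[end]
    return target


rewards = [1.0, 1.0, 1.0, 1.0]
values = [0.0, 0.0, 10.0, 0.0]
print(value_target(rewards, values, start=0, depth=2, discount=0.5))
```

Only the values list changes when the reanalyser runs, which is why old trajectories can keep providing a useful training signal.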
Updating these values basically consists of constructing trees exploring the game at each node, just as if we were playing the game:
p = buffer.get_reanalyse_probabilities()
ndxs = buffer.get_buffer_ndxs()
ndx = np.random.choice(ndxs, p=p)
game_rec = buffer.get_buffer_ndx(ndx)  # Gets the game record at ndx in the buffer

values = []
search_stats = []
for i in range(len(game_rec.observations) - 1):
    obs = game_rec.get_last_n(pos=i)
    new_root = search(current_frame=obs, ...)
    values.append(new_root.average_value)
    search_stats.append(
        [c.num_visits if c else 0 for c in new_root.children]
    )

buffer.update_game_info.remote(ndx=ndx, values=values, search_stats=search_stats)
To speed up training and playing, we parallelize by converting the main classes into 'actors', as defined by the Ray framework. This means wrapping classes with the @ray.remote decorator, and then calling their functions with ray.get(actor.func.remote(*func_args)) instead of actor.func(*func_args).
The basic classes are the Player, Trainer, and Reanalyser, and each of these has access to a Memory class and a Buffer class from which to pull data.
EfficientZero builds upon MuZero. There are three changes to the underlying algorithm, well summarized in this post. They also massively shrink the size of the networks, going from 16 residual blocks in the dynamics function in MuZero to only 1. I'll go through these three changes in turn, and what they look like as changes to the code.
In MuZero, the network tries to predict the reward at each time point. This apparently causes difficulty due to the 'state aliasing' problem, by which the model needs to predict exactly which frame or state will give a reward, but this gets tricky with exponentially compounding error.
In EfficientZero, the 'reward' prediction target changes from being the reward at the current step to the sum of the rewards from the first step being analysed up to the current step of the rollout. At the end of the rollout we just add the estimated value at that step, which is why this cumulative reward is called the value prefix.
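A tiny worked example of the two kinds of target, using a hypothetical helper function and made-up rewards:

```python
from itertools import accumulate


def reward_targets(rewards, ndx, rollout_depth, value_prefix=True):
    """Reward targets for a rollout starting at step ndx of a game."""
    window = rewards[ndx : ndx + rollout_depth]
    if value_prefix:
        # EfficientZero: running total of reward since the start of the rollout
        return list(accumulate(window))
    # MuZero: the reward at each individual step
    return list(window)


game_rewards = [0, 1, 0, 2, 5]
print(reward_targets(game_rewards, ndx=1, rollout_depth=3, value_prefix=False))
print(reward_targets(game_rewards, ndx=1, rollout_depth=3, value_prefix=True))
```

With the value prefix, predicting a reward one frame too early or too late only makes one or two of the targets wrong, rather than swapping which step holds the entire reward.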
Making this change requires small changes to the way batches are put together:
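if config["value_prefix"]:  # flag name assumed for illustration
    # EfficientZero: cumulative reward from the start of the rollout to step i
    target_rewards.append(sum(self.rewards[ndx : ndx + i + 1]))
else:
    # MuZero: just the reward at step i
    target_rewards.append(self.rewards[ndx + i])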
and to the dynamics net, which initially looks like this:
def forward(self, old_latent):
    # ... layers that compute new_latent from old_latent ...
    out = new_latent.reshape(batch_size, -1)
    reward_logits = self.fc2(torch.relu(self.fc1(out)))
    return new_latent, reward_logits
which then becomes the following:
def forward(self, old_latent, reward_hiddens):
    # ... layers that compute new_latent from old_latent ...
    out = new_latent.reshape(batch_size, -1)
    # We collect the lstm section into a function which is largely
    # a series of fully connected layers, but with a single LSTM
    # layer in the middle.
    value_prefix, new_reward_hiddens = dyna_lstm(out, reward_hiddens)
    return new_latent, value_prefix, new_reward_hiddens
When training, we initialize the hidden state as a matrix of zeros at the start of each batch. This is fed into the first iteration of the dynamics network, and the hidden state it returns is then passed back into the dynamics function alongside the latent vector at each subsequent step.
I find the 'state-aliasing problem' explanation of why this is a useful change not totally convincing/sufficient as it seems that rollouts are able to go much deeper than trained and still provide value and policy estimation. I guess it makes the training signal less noisy, and therefore improves the learning? I'm also not sure why an LSTM is needed since the dynamics net is already a form of RNN (maybe just add more latent dimension to help track what reward is already expected?)
The idea here is that in these deterministic games, the latent vector representing the state of the game as the network expects it to be, after a series of actions (i.e. applying the represent network to the initial observation, and then applications of the dynamics network), should be the same as the latent found after that series of actions is actually taken in game, and then the represent network is applied to the final observation.
# The target latent is the representation of the observation
# at time (t + i), from the initial observation
target_latents = mu_net.represent(images[:, i]).detach()
# The latent here is the latent that is found by
# applying the dynamics network with the chosen
# actions to the initial latent i times.
consistency_loss = mu_net.consistency_loss(latents, target_latents)
The consistency loss used here is a cosine loss, meaning that it is based on the cosine of the angle between the latent and target_latent, interpreted as vectors in R^n, and is lowest when the two point in the same direction.
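A minimal pure-Python sketch of such a loss for a single pair of vectors is below. Taking the negative of the cosine similarity, so that minimizing the loss aligns the vectors, is an assumed sign convention, and the real mu_net.consistency_loss works on batched tensors.

```python
import math


def cosine_loss(latent, target_latent, eps=1e-8):
    """Negative cosine similarity between two vectors of equal length."""
    dot = sum(a * b for a, b in zip(latent, target_latent))
    norm = math.sqrt(sum(a * a for a in latent)) * math.sqrt(
        sum(b * b for b in target_latent)
    )
    # eps guards against division by zero for all-zero vectors
    return -dot / max(norm, eps)


print(cosine_loss([1.0, 0.0], [2.0, 0.0]))   # same direction
print(cosine_loss([1.0, 0.0], [0.0, 3.0]))   # orthogonal
print(cosine_loss([1.0, 0.0], [-1.0, 0.0]))  # opposite direction
```

Because only the angle matters, the loss ignores the length of the latent vectors and constrains only their direction.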
This is a simple change that improves the value target.
The idea is simple. The value target is the sum of the next n steps of observed reward, plus the discounted expected value at the nth step. The actions taken can't be changed, so as our policy improves, the actions, and therefore the rewards will become more and more out of date, but thanks to the reanalyser, the expected value function stays up to date. It therefore improves the quality of the value target, as a proxy for what the value would be under the current policy, if we shrink n as the trajectory ages.
def get_reward_depth(self, value, tau=0.3, total_steps=100_000, max_depth=5):
    if self.config["off_policy_correction"]:  # flag name assumed for illustration
        # Varying reward depth depending on the length of time
        # since the trajectory was generated.
        # Follows the formula in A.4 of EfficientZero paper
        steps_ago = self.total_values - value
        depth = max_depth - np.floor((steps_ago / (tau * total_steps)))
        depth = int(np.clip(depth, 1, max_depth))
    else:
        depth = max_depth
    return depth
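To see what this formula does in practice, here is a standalone version that takes steps_ago directly rather than reading it from the buffer, and uses only the standard library. The function name is made up for illustration and the defaults are copied from the signature above.

```python
import math


def reward_depth(steps_ago, tau=0.3, total_steps=100_000, max_depth=5):
    """Number of real reward steps to use for a trajectory of a given age."""
    depth = max_depth - math.floor(steps_ago / (tau * total_steps))
    # Never use fewer than 1 or more than max_depth steps of real reward
    return int(min(max(depth, 1), max_depth))


for steps_ago in (0, 45_000, 100_000, 200_000):
    print(steps_ago, reward_depth(steps_ago))
```

With these defaults the depth drops by one for every 30,000 steps of age, so fresh trajectories use the full 5 steps of observed reward and very old ones rely almost entirely on the reanalysed value estimate.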
The most difficult parts of the process were various pieces of debugging once the code was split into multiple actors. This made stepping into the code more onerous and introduced a new set of potential problems very unlike what I’d been used to.
When running on Colab using Ray actors, the traceback shows the original error class, but gives a traceback in terms of Ray libraries, rather than the original location of the code, and I also can't get into the ray debugger. With multiple actors, even the order of print statements making it to the console can be a bit disordered, making reconstructing the cause of an error tough.
The worst part, though, was when I'd got the code to a point where it was working consistently over long runs, and then set it to perform a test of various hyperparameters, and would find that after several hours, at some points it would just.. die. No error message, no hint of what caused it, the process would just end. Because I was using ray, I guessed that there was some kind of problem that broke the system in such a way that didn't allow it to exit gracefully, some kind of memory error..
After a lot of frustration and confusion, and self-inflicted damage like updating all packages, I started just ignoring it and working on something else, at which point I realized that even trivial errors weren't showing up.
Once I knew that I could replicate the 'no traceback' issue just by introducing a trivial error, I could then easily go back through the commits and find the point at which the traceback disappeared, which made finding the cause super easy.
I'd used ray.wait() instead of ray.get() to get the final results of the actors, and when one of those actors crashed, ray.wait() continued, and immediately hit the end of the script, at which point all the actors were cancelled, before even the error message could be printed! Unfortunately, I'd made this change just after flushing out all the small bugs, so was getting this blank shutdown only after hours of running. I thought it was the result of an out-of-memory error, so instead of being a simple error to find, it was found only after days of confused work.
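The same trap exists in Python's standard library, which makes it easy to show without Ray installed. The sketch below is an analogue using concurrent.futures, not Ray itself: wait() plays the role of ray.wait() and only reports that the task has finished, while result() plays the role of ray.get() and re-raises whatever killed the worker.

```python
from concurrent.futures import ThreadPoolExecutor, wait


def crashing_actor():
    raise RuntimeError("actor crashed")


with ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(crashing_actor)

    # wait() returns normally once the task is finished, crashed or not
    done, not_done = wait([future])
    waited_without_error = future in done

    # result() is what actually surfaces the exception
    try:
        future.result()
        surfaced = None
    except RuntimeError as err:
        surfaced = str(err)

print(waited_without_error, surfaced)
```

If the script only ever waited, it would carry on to the end as though nothing had gone wrong, which is the silent shutdown described above.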
The main takeaway was not to prematurely assume one possible cause of error. The worst case scenario, that I had some deep bug that caused an error in such a way that the process immediately died, was possible, but I'd far too easily focused on this, instead of the case where I'd caused the lack of traceback myself by a silly error.
I found myself naturally converging on similar architectures. When starting off I looked at the open implementation and the pseudocode provided by the MuZero paper to look for ideas when things weren't working, but I also made a conscious decision not to follow the way they'd organized their code, and after a while, the differences compounded to the point where I couldn't take much directly from their code, even if I wanted to.
Nonetheless, I often found that I was forced into becoming more similar.
For example, I'd followed the open implementation in converting my classes into Ray actors, which would then run concurrently. At first this was just the Player and the Trainer, but then having a separate Memory class, quickly became useful to hold state for the others to grab.
At first I had the replay buffer inside the Memory actor, alongside simple statistics like the elapsed number of steps and batches. However, the buffer needs to do a lot of work to retrieve data and format it into batches for training, and these long operations leave the memory actor blocked, which delays lots of things, not least a while loop which checks whether the max steps has been reached. It's therefore helpful to split the memory into one actor which stores and returns basic shared statistics, and a buffer actor which creates batches from the store of saved games.
Even though we're using the same algorithm for different games, there are differences in the operations - for example doing some basic normalization on the pixel values. For just one or two games, it's quite easy to add if/else statements to process these differently, but this gets ugly quickly, and so it becomes a natural pattern to wrap these different functions into a game class, from which the algorithm can call these different functions without the need for switch statements, something that the open implementation also does.
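A stripped-down sketch of that pattern is below. The class and method names are hypothetical, with observations as plain lists rather than arrays.

```python
class Game:
    """Base class with defaults that suit simple vector observations."""

    def preprocess(self, frame):
        return list(frame)


class AtariGame(Game):
    def preprocess(self, frame):
        # Scale 0-255 pixel values into the range [0, 1]
        return [pixel / 255.0 for pixel in frame]


def prepare_observation(game, frame):
    # The algorithm calls the same method whatever the game,
    # with no if/else on the game type
    return game.preprocess(frame)


print(prepare_observation(Game(), (0.1, -0.2)))
print(prepare_observation(AtariGame(), [0, 255]))
```

Supporting a new game then means adding one more subclass rather than another branch in every function that touches observations.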
I'm not going to list potential improvements because I think this kind of architecture is a major stepping stone to intelligent in-the-world actors and I've no desire to speed up their arrival, on the off chance that the ideas are any good.
The huge shrinking of the architecture between MuZero and EfficientZero alone suggests that the parameters of this kind of algorithm aren't particularly optimized at all and there's lots of room for architectural tweaks.
Some are probably already being worked on while others wouldn't work, but I expect to see improved variants on this theme coming out quite soon - or maybe they are already out there.
I'm doing this work to learn the skills needed for technical AI Safety research.
If you might be interested in hiring me for applied AI Safety work please reach out either here on LW or at firstname.lastname@example.org.