Easy A2C

There are a great number of simple, easy-to-understand tutorials around for how to build a DQN agent. However, the modern baseline for reinforcement learning is Advantage Actor-Critic (sometimes Asynchronous Advantage Actor-Critic, which I’ll come back to in a bit). It is typically a stronger choice than plain DQN for a few reasons: it converges on better optima, it runs faster, and it is simpler to implement. Unfortunately, A2C doesn’t have many tutorials, and those that exist are hard to follow. This tutorial aims to give the next wave of reinforcement learning scholars a simple, intuitive explanation of the details of the algorithm.

Before following along with this post I’d recommend reading this comic; it does a fantastic job of explaining the intuition behind A2C.

I’m going to write this with an assumption that you have a basic knowledge of reinforcement learning (DQN in particular). So if you have the fundamentals under your belt, let’s dive in.

A2C vs A3C

A3C was the first to come along. It has a separate thread for each copy of the environment being explored, each with its own copy of the network. Each thread explores, learns from its explorations, updates the central network and then updates its own network to match the central network. Because of the asynchronicity (multiple agents inferring and learning at the same time), it is very hard to use a GPU, and typical implementations use only a CPU (GitHub user dgriff777 did manage to get it working on a GPU though, and it performs very well).

Removing the multiple workers part of the algorithm killed its performance. However, it turns out that what was important wasn’t the asynchronicity, but the way that having multiple workers exploring environments at the same time decorrelated the data. Instead of a stream of very correlated data (not much change from datapoint to datapoint), every few timesteps the data switches to a completely different context, forcing the agent to learn the different contexts if it wants to perform well. To get the same effect without threads, we gather data from a number of different agents in turn and then do an update on all of their data together. Since only one environment is being processed at a time, we can use a GPU for all our computation (speeding up learning).

There have been other implementations that blend the two, but today we’ll stick with the simplest one to get working. The whole agent is uploaded as a Jupyter notebook here. Feel free to use it to debug your own agent.

Setup

import torch
import gym
import numpy as np
from IPython.display import clear_output
from IPython.core.debugger import set_trace
import matplotlib.pyplot as plt
import common.wrappers
%matplotlib inline

First are some imports. We’re going to be using PyTorch for the implementation, OpenAI Gym for the environment, NumPy for occasional data processing, and Matplotlib for visualising the learning progress. There are also some wrappers imported and modified from OpenAI’s Baselines repository. Those are available along with the notebook on the GitHub page.

max_frames = 5000000
batch_size = 5
learning_rate = 7e-4
gamma = 0.99
entropy_coef = 0.01
critic_coef = 0.5
env_name = 'PongNoFrameskip-v4'
no_of_workers = 16

max_frames is the maximum number of frames (summed across all environments) that the training will run for. A2C is generally less data efficient than DQN, but this is a simple A2C implementation and I feel like this could be improved with some more complex code. For me, the whole training run on a P4000 GPU takes a few hours.

batch_size is the number of timesteps each worker will run for before handing over to the next worker. Anywhere from 5 to 20 is typically pretty good, but the best value depends heavily on the environment.

learning_rate is fairly self-explanatory.

gamma is how much to discount future rewards per timestep. So a reward 1 timestep in the future is worth \(r\gamma^1\) and one 100 timesteps in the future is worth \(r\gamma^{100}\). This results in closer rewards being valued more than rewards further away. More on this later.
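To make the discounting concrete, here is a minimal sketch in plain Python (no PyTorch needed). The helper name discounted is made up for illustration; gamma is the same value as the hyperparameter above.

```python
gamma = 0.99  # same value as the hyperparameter above

def discounted(reward, steps_ahead, gamma=gamma):
    """Present value of a reward received `steps_ahead` timesteps in the future."""
    return reward * gamma ** steps_ahead

# A reward of 1 shrinks the further away it is.
for steps in (1, 10, 100):
    print(steps, round(discounted(1.0, steps), 4))
```

With gamma at 0.99, a reward 100 timesteps away is worth a little over a third of the same reward received right now.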

entropy_coef is how much we maximise the entropy of the agent’s output. An output where all values are equal is high entropy and an output with one value much higher than the others is low entropy. We try to keep a little bit of entropy around so that the agent explores more.

critic_coef is how much weight the critic’s loss gets when it is combined with the actor’s loss into the total loss.

env_name is the name of the Gym environment we will be solving.

no_of_workers is how many environments to run simultaneously.
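As a quick sanity check on how these numbers interact, here is some back-of-the-envelope arithmetic. It assumes one gradient update per round of polling every worker, which is how the training loop later in this post is structured; the names frames_per_update and total_updates are just for illustration.

```python
max_frames = 5_000_000
batch_size = 5
no_of_workers = 16

# Every update consumes one batch from every worker.
frames_per_update = batch_size * no_of_workers
# Number of gradient updates in a full training run.
total_updates = max_frames // frames_per_update

print(frames_per_update, total_updates)
```

So each update sees 80 transitions, and a full run performs 62,500 updates.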

if torch.cuda.is_available():
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor

This code just tells PyTorch to use a GPU if one is available. I don’t recommend running this code without a GPU, as training will be very slow.

env = common.wrappers.make_atari(env_name)
env = common.wrappers.wrap_deepmind(env)
env = common.wrappers.wrap_pytorch(env)

Here we set up our environment. We’re actually going to instantiate more environments (one for each worker) later, but we need one as a reference for determining the action space, observation space, etc. In the interest of keeping things simple, we’re using premade wrappers. Some of these are from OpenAI’s Baselines, and the other (wrap_pytorch) is from Higgsfield’s RL Adventure project.

make_atari creates the environment.

wrap_deepmind makes a few changes to the environment. Firstly, it resizes the image of the screen down to 84x84 pixels and converts it to grayscale. Secondly, it skips 4 frames for every timestep, and for each pixel returns the highest pixel value that occurred in the 4 frames. This gives our agent some idea of how things are moving during the game (not as good as an RNN, but good enough to play Pong). It makes a few other minor changes too, but they aren’t particularly important for the agent to function. It’s actually quite important to note that this doesn’t normalise the output to be a float between 0 and 1, and instead leaves it at 0 to 255. I’m not entirely sure why (someone please educate me on this), but the agent performs much better with values between 0 and 255.

Finally, wrap_pytorch converts the observation format \((\text{width}\times\text{height}\times\text{color\_channels})\) into PyTorch’s format \((\text{color\_channels}\times\text{width}\times\text{height})\).

The Model

Now onto the model itself. In Jupyter this is all one cell, but I’ll split it up to make it easier to digest.

class Model(torch.nn.Module):
    def __init__(self, action_space):
        super(Model, self).__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=8, stride=4),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, kernel_size=4, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1),
            torch.nn.ReLU()
        )
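        # Run a dummy observation through the CNN to find the flattened feature size.
        # No .cuda() here: the model is still on the CPU while it is being built.
        feature_size = self.features(
            torch.zeros(1, *env.observation_space.shape)).view(1, -1).size(1)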
        
        self.critic = torch.nn.Sequential(
            torch.nn.Linear(feature_size, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 1)
        )
        self.actor = torch.nn.Sequential(
            torch.nn.Linear(feature_size, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, action_space),
            torch.nn.Softmax(dim=-1)
        )

We define the model as a PyTorch Module. This gives us a number of quality of life features (such as being able to use model.parameters()). Next up is defining the architecture.

First, a very standard CNN is used for feature extraction. There is no batch normalization, pooling or anything other than convolution and ReLU. As a future improvement, it might be worth adding more filters and using max pooling.

The feature_size variable is used to figure out how many features we actually have. It sends some dummy data through the feature extractor and flattens the output. This actually isn’t my preferred method of doing flattening any more. Nowadays I like to make a Flatten() module that contains return x.view(x.size(0), -1) and put it in the feature extractor itself. This way you don’t have to include x = x.view(x.size(0), -1) in your forward() method.
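As a cross-check on what the dummy forward pass computes, the same number can be worked out by hand. This sketch assumes an 84x84 single-channel input (as described for wrap_deepmind above) and unpadded convolutions, which matches the Conv2d layers in the model; conv_out is a hypothetical helper, not part of PyTorch.

```python
def conv_out(size, kernel_size, stride):
    """Output length along one spatial axis for an unpadded convolution."""
    return (size - kernel_size) // stride + 1

size = 84  # assumed input height and width
# (kernel_size, stride) for the three conv layers in the feature extractor
for kernel_size, stride in [(8, 4), (4, 2), (3, 1)]:
    size = conv_out(size, kernel_size, stride)

feature_size = 64 * size * size  # 64 output channels in the last conv layer
print(size, feature_size)
```

The spatial size goes 84, 20, 9, 7, so the flattened feature vector has 64 x 7 x 7 = 3136 entries under these assumptions.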

Next is the part of the net that is the reason we call it ‘Actor-Critic’. We have a critic, which is responsible for predicting how much reward the agent will receive until the end of the episode. Then we have an actor, which is responsible for picking actions to perform. Exactly how we teach these parts their jobs will be covered later, but suffice it to say they are a core part of A2C. This being a simple A2C implementation, they are both just 2 layer feed-forward nets.

The value of 512 for the number of neurons is pretty arbitrary. I just found it worked better than 256, while 1024 was a bit slow. The same applies to the convolutional layer parameters: they are the ones I found to work best when I was doing a DQN implementation.

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        value = self.critic(x)
        actions = self.actor(x)
        return value, actions

Then we have the main forward pass of the agent’s network. This is a very standard pass through the whole network, returning the critic’s value estimate and the actor’s action probabilities as PyTorch tensors.

    def get_critic(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.critic(x)

Later, I found myself needing just the critic without the actor portion of the net. To save computation time we have this get_critic function.

    def evaluate_action(self, state, action):
        value, actor_features = self.forward(state)
        dist = torch.distributions.Categorical(actor_features)
        
        log_probs = dist.log_prob(action).view(-1, 1)
        entropy = dist.entropy().mean()
        
        return value, log_probs, entropy

This function takes a batch of states and actions and returns the expected value of each state, the log probability of taking each action in its state, and the mean entropy of the action distributions across the batch.

I’d like to state upfront that I’m not 100% on the maths behind why we take the log of the probability rather than the actual probability of taking the action in the state. However, it works nicely later on when we multiply the advantages with the log probabilities to get our loss for the actor. The reason we use dist.log_prob rather than taking the probabilities and calculating the logs afterwards is that the latter is numerically unstable in PyTorch. I was getting NaNs after about 1,000,000 frames of training.
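To illustrate the kind of instability involved, here is a plain-Python sketch of one way that taking the log of a probability can break down. It is not necessarily the exact failure I hit in PyTorch. Both function names are made up, and the logits are deliberately extreme so that one probability underflows to zero.

```python
import math

def log_softmax(logits):
    """Log-probabilities computed directly from logits via the log-sum-exp trick."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def naive_log_probs(logits):
    """Softmax first, then log: breaks once a probability underflows to 0."""
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

logits = [0.0, -800.0]  # made-up, deliberately extreme logits
print(log_softmax(logits))  # finite values for both actions

try:
    naive_log_probs(logits)
except ValueError as err:  # math.log(0.0) is a domain error
    print("naive version failed:", err)
```

Working in log space from the start keeps every value finite, which is the general idea behind preferring a dedicated log-probability routine.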

The entropy of a distribution of probabilities is how close they are to each other. Simply put, a distribution of \([0.01, 0.98, 0.01]\) has a very low entropy and a distribution of \([0.33, 0.34, 0.33]\) has a very high entropy. We use this in the loss function later to encourage high-entropy outputs so that the agent explores more.
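The two distributions above can be compared numerically with a few lines of plain Python. This is a sketch of the standard Shannon entropy formula using natural logs; the function name entropy is only for illustration.

```python
import math

def entropy(probs):
    """Shannon entropy in nats; zero-probability terms contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)

peaked = [0.01, 0.98, 0.01]
flat = [0.33, 0.34, 0.33]

# math.log(3) is the largest entropy any three-way distribution can have.
print(entropy(peaked), entropy(flat), math.log(3))
```

The peaked distribution comes out at roughly 0.11 nats, while the flat one sits just under the three-action maximum of ln 3 (about 1.10).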

    def act(self, state):
        value, actor_features = self.forward(state)
        dist = torch.distributions.Categorical(actor_features)
        
        chosen_action = dist.sample()
        return chosen_action.item()

Finally, we have a function that returns an action from the agent. This is pretty straightforward. We don’t use \(\epsilon\)-greedy exploration or anything else like it. We just sample from the probability distribution output by the actor.
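Sampling from the actor’s output can be mimicked without PyTorch. The sketch below uses the standard library’s random.choices as a stand-in for Categorical.sample(); the probabilities are made up and sample_action is a hypothetical helper.

```python
import random

def sample_action(probs, rng=random):
    """Draw one action index according to the actor's probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(0)   # seeded only so the sketch is repeatable
probs = [0.7, 0.2, 0.1]  # made-up actor output for three actions
counts = [0, 0, 0]
for _ in range(10_000):
    counts[sample_action(probs, rng)] += 1

print(counts)  # roughly proportional to probs
```

Even the least likely action gets picked now and then, which is where the exploration comes from.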

Memory

class Memory(object):
    def __init__(self):
        self.states, self.actions, self.true_values = [], [], []
    
    def push(self, state, action, true_value):
        self.states.append(state)
        self.actions.append(action)
        self.true_values.append(true_value)
    
    def pop_all(self):
        states = torch.stack(self.states)
        actions = LongTensor(self.actions)
        true_values = FloatTensor(self.true_values).unsqueeze(1)
        
        self.states, self.actions, self.true_values = [], [], []
        
        return states, actions, true_values

We have a little class for storing and retrieving the batches created by the workers in their environments. The code should be fairly self-explanatory, as there’s nothing really fancy going on. The only parts with some potential to be confusing are torch.stack, unsqueeze, etc. Those are simple data management steps that get the data into the right shape for the network.

Compute True Values

If we wanted to figure out the actual value of a state (the sum of all rewards received after that point in time), we would have to wait until the end of the episode and add up all the rewards. On the other hand, in A2C we can use our critic to get an immediate but inaccurate (biased) approximation of the expected value of a state.

However, we can actually combine the two to get something in between: something less biased than our critic that doesn’t have to wait until the end of the episode to get results.

We can do this by taking a number of state-reward pairs at a time, and considering the value of the last one to be whatever our critic predicts. We then calculate the value of the other states assuming the critic is correct (which it won’t be, but it will be close enough).

For example, say that your rewards are \([1, 3, -1, 5]\). Your critic says the 5th state (the one after your last reward) is worth \(50\). Your true values become \([58, 57, 54, 55, 50]\).

We actually multiply each total reward by a \(\gamma\) constant before adding the next reward to it to represent that near rewards are more important than far away ones. You can tune the gamma value so that the agent places more value on long-term planning, but it typically takes a lot longer to train an agent with long-term focus (if it trains at all).

def compute_true_values(states, rewards, dones):
    true_values = []
    rewards = FloatTensor(rewards)
    dones = FloatTensor(dones)
    states = torch.stack(states)
    
    if dones[-1] == True:
        next_value = rewards[-1]
    else:
        next_value = model.get_critic(states[-1].unsqueeze(0))
        
    true_values.append(next_value)
    for i in reversed(range(0, len(rewards) - 1)):
        if not dones[i]:
            next_value = rewards[i] + next_value * gamma
        else:
            next_value = rewards[i]
        true_values.append(next_value)
        
    true_values.reverse()
    
    return FloatTensor(true_values)

Of course, if a state is terminal (end of an episode) we don’t take any future rewards into account.
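The worked example from earlier can be reproduced with a small plain-Python sketch. It follows the example (bootstrapping from the critic’s estimate of the state after the last reward) rather than the exact indexing of compute_true_values, and n_step_values is a hypothetical name. Setting gamma to 1.0 switches discounting off so the numbers match.

```python
def n_step_values(rewards, dones, bootstrap_value, gamma):
    """Work backwards from the critic's estimate, discounting as we go.

    rewards[i] and dones[i] belong to the transition out of state i;
    bootstrap_value is the critic's estimate for the state after the last reward.
    """
    values = [bootstrap_value]
    next_value = bootstrap_value
    for reward, done in zip(reversed(rewards), reversed(dones)):
        # A terminal transition cuts off everything that comes after it.
        next_value = reward if done else reward + gamma * next_value
        values.append(next_value)
    values.reverse()
    return values

rewards = [1, 3, -1, 5]
dones = [False, False, False, False]
print(n_step_values(rewards, dones, bootstrap_value=50, gamma=1.0))
# [58.0, 57.0, 54.0, 55.0, 50]
```

With gamma below 1, each step back shrinks the contribution of everything after it, so the earlier states end up with smaller values than in this undiscounted example.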

Reflection

Now comes the important bit, the part where the agent learns.

While this is a short function, it’s very dense, so I’ll break it down as much as I can.

def reflect(memory):
    states, actions, true_values = memory.pop_all()

First, we get the stuff we stored in our memory object. This is a batch of states, actions and the ‘true’ (calculated above) values of each timestep for each worker. So if our batch_size is 5 and we have 10 workers that’s 50 states, actions and true_values. The order isn’t important, since they are all processed together to come up with an average loss for the whole batch.

    values, log_probs, entropy = model.evaluate_action(states, actions)

Next, we get the predicted values (straight from the critic), log probabilities and entropy values for each state from our model. This was discussed in detail earlier, so I won’t go over it again.

    advantages =  true_values - values
    critic_loss = advantages.pow(2).mean()

The critic’s loss is very simple. We calculate the advantage of each timestep (the difference between the actual value and our critic’s predicted value). Then we want to minimise that, so we take the square (to make negative advantages positive) and use that as our loss.
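Here is the same calculation on a handful of made-up numbers, in plain Python rather than tensors, just to show the arithmetic.

```python
true_values = [58.0, 57.0, 54.0, 55.0]  # bootstrapped targets (made-up numbers)
predicted = [50.0, 60.0, 54.0, 52.0]    # what the critic guessed for the same states

# Advantage: how much better or worse things went than the critic expected.
advantages = [t - v for t, v in zip(true_values, predicted)]
# Critic loss: mean of the squared advantages.
critic_loss = sum(a ** 2 for a in advantages) / len(advantages)

print(advantages, critic_loss)
```

The advantages are 8, -3, 0 and 3, giving a critic loss of 20.5. The single large miss dominates because of the squaring.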

    actor_loss = -(log_probs * advantages.detach()).mean()

The actor’s loss is a little bit more complex, so I’ll give some examples for the intuition. The formula is \(-\frac{1}{n} \sum_{i=1}^{n}\log(p_i)\,a_i\) where \(p_i\) is the probability of taking the action that it did in the state at the time, and \(a_i\) is the advantage of that action (how much better or worse it did than expected).

Here are a few examples. They use base-10 logarithms to keep the numbers round; PyTorch uses natural logarithms, which scales every number by the same constant factor and leaves the intuition unchanged. First off, if we took a big risk and chose an action with a very low probability, but it paid off with a big reward, we get a very big loss (making those actions more likely in the future):

\[\log_{10}(0.1) \times 10 = -1 \times 10 = -10\]

Meanwhile, if the reward was pretty much as expected, then there is very little change (whether we take the action or not doesn’t matter, so don’t risk changing the policy):

\[\log_{10}(0.1) \times 0.1 = -1 \times 0.1 = -0.1\]

However, even if the reward is a lot bigger than expected, if the policy already thought that was a good move (high probability) there’s no real need to change the policy.

\[\log_{10}(0.9) \times 10 \approx -0.046 \times 10 \approx -0.46\]

Of course, the inverse applies for negative advantages. Note that all these values have their sign flipped (the leading minus in the formula) before being used as the loss.
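The round numbers in the examples above correspond to base-10 logs (that is where \(\log(0.1) = -1\) comes from). Here is a sketch of the same three cases with natural logs, which is what dist.log_prob works in, and with the sign flip applied. The function name actor_loss is only for illustration.

```python
import math

def actor_loss(probs, advantages):
    """Negative mean of log(p) * advantage, using natural logs."""
    terms = [math.log(p) * a for p, a in zip(probs, advantages)]
    return -sum(terms) / len(terms)

print(actor_loss([0.1], [10.0]))  # unlikely action, big payoff: large loss
print(actor_loss([0.1], [0.1]))   # unlikely action, about as expected: tiny loss
print(actor_loss([0.9], [10.0]))  # likely action, big payoff: modest loss
```

The absolute numbers are about 2.3 times larger than in the base-10 examples, but the ordering of the three cases is the same.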

    total_loss = (critic_coef * critic_loss) + actor_loss - (entropy_coef * entropy)

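Finally, we combine the critic loss, actor loss and entropy (to incentivise exploration) to get the final loss. The critic loss and the entropy are each multiplied by their respective coefficients to fine-tune the impact they have on the model’s training. The entropy term is subtracted because we want to maximise entropy while minimising the rest.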

    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
        
    return values.mean().item()    

After computing the loss, we just have to apply it to the model. We use gradient clipping to prevent catastrophic failure in the event of an unusually large reward.
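For intuition about what clipping by norm does, here is a plain-Python sketch of the idea. It is not PyTorch’s actual implementation; clip_by_global_norm is a hypothetical name, the gradients are made up, and the small 1e-6 term is only there to avoid dividing by zero.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients together if their combined L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm

grads = [3.0, 4.0]  # made-up gradients with an L2 norm of 5
clipped, norm = clip_by_global_norm(grads, max_norm=0.5)
print(norm, clipped)  # the direction is kept, the length shrinks to about 0.5
```

Gradients that are already small are left untouched, so clipping only kicks in on the occasional oversized update.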

The Worker

We need multiple copies of the environment to decorrelate the data as described earlier. As such, we have a class to take care of instantiating an environment and getting data from it.

class Worker(object):
    def __init__(self, env_name):
        self.env = common.wrappers.make_atari(env_name)
        self.env = common.wrappers.wrap_deepmind(self.env, scale=True)
        self.env = common.wrappers.wrap_pytorch(self.env)
        self.episode_reward = 0
        self.state = FloatTensor(self.env.reset())
        
    def get_batch(self):
        states, actions, rewards, dones = [], [], [], []
        for _ in range(batch_size):
            action = model.act(self.state.unsqueeze(0))
            next_state, reward, done, _ = self.env.step(action)
            self.episode_reward += reward

            states.append(self.state)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            
            if done:
                self.state = FloatTensor(self.env.reset())
                data['episode_rewards'].append(self.episode_reward)
                self.episode_reward = 0
            else:
                self.state = FloatTensor(next_state)
                
        values = compute_true_values(states, rewards, dones).unsqueeze(1)
        return states, actions, values

There’s really nothing fancy going on here. If you’ve ever used OpenAI’s Gym you’ll recognise everything here. We simply get an action from the model, send it to the environment, get the environment’s next state and reward, then repeat.

Plotting our Data

Being able to see some data on how the training is going is essential when building an agent.

def plot(data, frame_idx):
    clear_output(True)
    plt.figure(figsize=(20, 5))
    if data['episode_rewards']:
        ax = plt.subplot(121)
        ax = plt.gca()
        average_score = np.mean(data['episode_rewards'][-100:])
        plt.title(f"Frame: {frame_idx} - Average Score: {average_score}")
        plt.grid()
        plt.plot(data['episode_rewards'])
    if data['values']:
        ax = plt.subplot(122)
        average_value = np.mean(data['values'][-1000:])
        plt.title(f"Frame: {frame_idx} - Average Values: {average_value}")
        plt.plot(data['values'])
    plt.show()

Here, I’ve used Matplotlib to plot some very simple data about the training. I chose to display two metrics (the total reward for each episode and the average value for each episode returned by the critic) for good reasons. Firstly, the total reward for each episode is absolutely necessary because without it you can’t tell how well your agent is doing (and if it has finished training). However, the episode score has very high variance, making it a not terribly accurate indicator of progress. This is why we also include the value returned by the critic. The value is a much more stable metric (even if it does introduce a lot of bias) that tells you how well the agent thinks it’s doing. It also serves some other purposes, such as telling you if your learning rate is too low (if the critic is taking a very long time to reach an accurate estimate of its own skill, you can probably increase the learning rate) and telling you if training has diverged (if the agent massively over or underestimates its skill and the gap continues to grow, you’ve diverged).
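If the raw episode scores are too noisy to read, one option is to smooth them before plotting. The notebook itself only averages the last 100 episodes for the plot title, so this is an optional extra rather than part of the original code; moving_average is a hypothetical helper and the scores are made up.

```python
def moving_average(values, window):
    """Mean of the last `window` values at each point (fewer at the start)."""
    averaged = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            # Drop the value that just fell out of the window.
            running_sum -= values[i - window]
        averaged.append(running_sum / min(i + 1, window))
    return averaged

episode_rewards = [-21, -19, -21, -17, -20, -15, -18, -12]  # made-up Pong scores
print(moving_average(episode_rewards, window=4))
```

The smoothed list could be passed to plt.plot in place of the raw scores, trading some responsiveness for a clearer trend line.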

Executing the Training

First, we have some initialisation to do. This means setting up the model, the optimiser, workers, etc.

model = Model(env.action_space.n).cuda()
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, eps=1e-5)
memory = Memory()
workers = []
for _ in range(no_of_workers):
    workers.append(Worker(env_name))
frame_idx = 0
data = {
    'episode_rewards': [],
    'values': []
}

I found RMSProp worked better here than Adam. When it comes to picking hyperparameters, it’s often best to just try a few things and see what works. I’m certain the ones I’ve found aren’t the best, but they’ll do for now. If anyone running this at home finds better hyperparameters, please let me know.

Then, finally, we run the training. This takes a few hours for me, but it’s easy to lose yourself in watching the rewards creep up over time.

%debug
state = FloatTensor(env.reset())
episode_reward = 0
while frame_idx < max_frames:
    for worker in workers:
        states, actions, true_values = worker.get_batch()
        for i, _ in enumerate(states):
            memory.push(
                states[i],
                actions[i],
                true_values[i]
            )
        frame_idx += batch_size
        
    value = reflect(memory)
    if frame_idx % 1000 == 0:
        data['values'].append(value)
        plot(data, frame_idx)

Again, there is almost nothing new going on here. Each worker is polled in order for their batch, then the batch is added to the memory. Once we have data from every worker, we reflect on the data (as described above).
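The polling pattern is easier to see with the game stripped out. In this sketch FakeWorker is a made-up stand-in for Worker that returns labelled dummy transitions, and a plain list stands in for Memory.

```python
batch_size = 5
no_of_workers = 16

class FakeWorker:
    """Stand-in for Worker: returns (worker_id, timestep) pairs instead of frames."""
    def __init__(self, worker_id):
        self.worker_id = worker_id
        self.step = 0

    def get_batch(self):
        batch = [(self.worker_id, self.step + i) for i in range(batch_size)]
        self.step += batch_size
        return batch

workers = [FakeWorker(i) for i in range(no_of_workers)]
memory = []
for worker in workers:  # poll each worker in turn
    memory.extend(worker.get_batch())

print(len(memory))           # 80 transitions go into a single update
print(memory[4], memory[5])  # neighbouring entries jump between environments
```

Every five entries the data switches to a different environment, which is the decorrelation effect described at the start of the post.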

Potential Improvements

  • Max pooling in the feature extraction
  • Each environment in its own thread
  • Use an LSTM or GRU for learning temporal information

Conclusion

I hope this has helped someone out there. I know that when I first implemented A2C, I was frustrated at the lack of good writeups on how such a commonplace algorithm works. If anyone has questions, you can contact me via the email on my GitHub page.