Advantage Actor-Critic (A2C) Demo

This code demonstrates a minimal implementation of the Advantage Actor-Critic (A2C) algorithm in PyTorch on the CartPole-v1 environment. A2C extends vanilla policy gradient (REINFORCE) by using a learned value function (the critic) as a baseline, which dramatically reduces the variance of the gradient estimates used to update the policy (the actor).

Key ideas illustrated:

  • Separate actor (policy) and critic (value) networks
  • Advantage estimation using the Monte Carlo return minus the critic’s value estimate
  • Simultaneous policy and value function optimization
  • Variance reduction compared to REINFORCE
A2C Implementation
# Partially Adapted from https://towardsdatascience.com/learning-reinforcement-learning-reinforce-with-pytorch-5e8ad7fc7da0

# %%
import numpy as np
import gymnasium as gym
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# %%
env = gym.make('CartPole-v1')
render_env = gym.make('CartPole-v1', render_mode='human')
n_state = int(np.prod(env.observation_space.shape))
n_action = env.action_space.n
print("state dim:", n_state)
print("# of actions:", n_action)

# %%


def run_episode(env, policy):
    obs_list = []
    act_list = []
    reward_list = []
    next_obs_list = []
    done_list = []
    obs = env.reset()[0]
    while True:
        action = policy(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        obs_list.append(obs)
        act_list.append(action)
        reward_list.append(reward)
        next_obs_list.append(next_obs)
        done_list.append(terminated)
        if terminated or truncated:
            break
        obs = next_obs

    return obs_list, act_list, reward_list, next_obs_list, done_list
# %%


class Policy():
    def __init__(self, n_state, n_action):
        # Define network
        self.act_net = nn.Sequential(
            nn.Linear(n_state, 16),
            nn.ReLU(),
            nn.Linear(16, n_action),
            nn.Softmax(dim=-1)
        )
        self.act_net.to(device)
        self.v_net = nn.Sequential(
            nn.Linear(n_state, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        self.v_net.to(device)
        self.gamma = 0.99
        self.act_optimizer = torch.optim.Adam(
            self.act_net.parameters(), lr=1e-3)
        self.v_optimizer = torch.optim.Adam(self.v_net.parameters(), lr=1e-3)

    def __call__(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action_probs = self.act_net(state).detach().cpu().numpy()
            action = np.random.choice(n_action, p=action_probs)
        return action

    def update(self, data):
        obs, act, reward, next_obs, done = data
        # Calculate cumulative discounted return
        returns = np.zeros_like(reward)
        s = 0
        for i in reversed(range(len(returns))):
            s = s * self.gamma + reward[i]
            returns[i] = s

        obs = torch.FloatTensor(np.array(obs)).to(device)
        reward = torch.FloatTensor(reward).to(device)
        next_obs = torch.FloatTensor(np.array(next_obs)).to(device)
        returns = torch.FloatTensor(returns).to(device)
        # Actions are used as indices, must be
        # LongTensor
        act = torch.LongTensor(act).to(device)
        done = torch.FloatTensor(done).to(device)
        # Calculate loss
        batch_size = 32
        indices = list(range(len(obs)))
        # Calculate advantages (no need to mask returns with (1 - done),
        # since they involve no bootstrapped network prediction)
        with torch.no_grad():
            adv = returns - self.v_net(obs).squeeze(-1)
        for i in range(0, len(indices), batch_size):
            index = indices[i:i+batch_size]
            for _ in range(1):
                logprob = torch.log(self.act_net(obs[index, :]) + 1e-8)
                adv_logprob = adv[index] * torch.gather(logprob, 1,
                                                        act[index, None]).squeeze()
                act_loss = -adv_logprob.mean()
                self.act_optimizer.zero_grad()
                act_loss.backward()
                self.act_optimizer.step()

            for _ in range(2):
                with torch.no_grad():
                    td_target = reward[index, None] + self.gamma * \
                        (1 - done[index, None]) * self.v_net(next_obs[index, :])
                v_loss = F.mse_loss(
                    self.v_net(obs[index, :]), td_target)
                self.v_optimizer.zero_grad()
                v_loss.backward()
                self.v_optimizer.step()

        return act_loss.item(), v_loss.item()


# %%
loss_act_list, loss_v_list, reward_list = [], [], []
policy = Policy(n_state, n_action)
loss_act, loss_v = 0, 0
n_step = 0
for i in tqdm(range(2000)):
    data = run_episode(env, policy)
    for _ in range(5):
        loss_act, loss_v = policy.update(data)
    rew = sum(data[2])
    if i > 0 and i % 50 == 0:
        print("itr:({:>5d}) loss_act:{:>6.4f} loss_v:{:>6.4f} reward:{:>3.1f}".format(
            i, np.mean(loss_act_list[-50:]), np.mean(loss_v_list[-50:]), np.mean(reward_list[-50:])))
    if i > 0 and i % 500 == 0:
        run_episode(render_env, policy)

    loss_act_list.append(loss_act)
    loss_v_list.append(loss_v)
    reward_list.append(rew)

# %%
render_env.close()
scores = [sum(run_episode(env, policy)[2]) for _ in range(100)]
print("Final score:", np.mean(scores))

import pandas as pd
df = pd.DataFrame({'loss_act': loss_act_list, 'reward': reward_list})
df.to_csv("./ClassMaterials/Lecture_19_Actor_Critic/data/A2C.csv",
          index=False, header=True)

Walk-through of the Code

The Actor-Critic networks

self.act_net = nn.Sequential(
    nn.Linear(n_state, 16), nn.ReLU(),
    nn.Linear(16, n_action), nn.Softmax(dim=-1),
)
self.v_net = nn.Sequential(
    nn.Linear(n_state, 64), nn.ReLU(),
    nn.Linear(64, 1),
)
  • Actor (act_net) maps a state to a probability distribution over the two actions via a softmax output.
  • Critic (v_net) maps a state to a scalar estimate of its value \(V(s)\).

Each network has its own Adam optimizer with learning rate 1e-3, so the two losses do not compete for a shared optimizer step.

Action selection — Policy.__call__

with torch.no_grad():
    action_probs = self.act_net(state).detach().cpu().numpy()
    action = np.random.choice(n_action, p=action_probs)

Actions are sampled from the policy’s distribution, not chosen greedily. This stochasticity is the only source of exploration in the algorithm — a fact that will become important when we discuss its limitations.

Advantage estimation

The update first computes discounted Monte Carlo returns for the trajectory:

\[G_t = \sum_{k=0}^{T-t} \gamma^k r_{t+k}\]
s = 0
for i in reversed(range(len(returns))):
    s = s * self.gamma + reward[i]
    returns[i] = s

and then forms the advantage by subtracting the critic’s value estimate:

\[A(s_t, a_t) \approx G_t - V_\phi(s_t)\]
with torch.no_grad():
    adv = returns - self.v_net(obs).squeeze(-1)

Subtracting \(V_\phi(s_t)\) is a baseline that does not bias the policy gradient but dramatically reduces its variance — the central insight of actor-critic methods. The torch.no_grad() context ensures the advantage is treated as a fixed signal when updating the actor; it is not something we differentiate through.
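A quick numerical check of this claim, using a toy one-parameter Bernoulli policy (the action probability, returns, and baseline value below are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
p = 0.6                                   # pi(a=1) for a toy Bernoulli policy

def grad_samples(baseline, n=100_000):
    a = rng.random(n) < p                 # sample actions from the policy
    glogp = a - p                         # d/dtheta log pi(a) for a logistic parameterization
    R = np.where(a, 10.0, 8.0) + rng.normal(0.0, 1.0, n)  # noisy returns
    g = (R - baseline) * glogp            # per-sample policy-gradient estimate
    return g.mean(), g.std()

mean_raw, std_raw = grad_samples(baseline=0.0)
mean_b, std_b = grad_samples(baseline=9.0)  # baseline near E[R]
print(f"no baseline:   mean={mean_raw:.3f} std={std_raw:.3f}")
print(f"with baseline: mean={mean_b:.3f} std={std_b:.3f}")
```

Both estimates agree in expectation, because the baseline term has zero mean (\(E[\nabla \log \pi] = 0\)), but the standard deviation drops substantially once the baseline is subtracted.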

Actor update — the policy gradient

logprob = torch.log(self.act_net(obs[index, :]) + 1e-8)
adv_logprob = adv[index] * torch.gather(logprob, 1, act[index, None]).squeeze()
act_loss = -adv_logprob.mean()

This is a direct implementation of the policy gradient theorem:

\[\nabla_\theta J(\theta) = E\!\left[ A(s, a)\, \nabla_\theta \log \pi_\theta(a \mid s) \right]\]

We gather the log-probability of the action actually taken, multiply by its advantage, take the mean, and negate it (so that minimizing act_loss maximizes the expected return). The small 1e-8 added inside the log guards against log(0) when one action’s probability has saturated.
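The torch.gather call is what selects each taken action's log-probability out of the full distribution. A small standalone illustration (the probabilities here are made up):

```python
import torch

# Hypothetical mini-batch: 3 states, 2 actions each
probs = torch.tensor([[0.9, 0.1],
                      [0.2, 0.8],
                      [0.5, 0.5]])
act = torch.tensor([0, 1, 1])          # action actually taken in each state

logprob = torch.log(probs + 1e-8)
# gather along dim=1 picks logprob[i, act[i]] for every row i
taken = torch.gather(logprob, 1, act[:, None]).squeeze(-1)
print(taken)   # log-probs of the taken actions: log(0.9), log(0.8), log(0.5)
```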

Critic update — bootstrapped TD(0) target

with torch.no_grad():
    td_target = reward[index, None] + self.gamma * \
        (1 - done[index, None]) * self.v_net(next_obs[index, :])
v_loss = F.mse_loss(self.v_net(obs[index, :]), td_target)

The critic is trained by regressing \(V_\phi(s)\) toward the one-step TD target:

\[y_t = r_t + \gamma \,(1 - d_t)\, V_\phi(s_{t+1})\]

Two details worth emphasizing:

  1. (1 - done) masks bootstrapping at terminal states. When the pole falls, there is no “future value” to add — the target is simply r_t.
  2. The target is wrapped in torch.no_grad(). Otherwise, gradients would flow through \(V_\phi(s_{t+1})\) as well as \(V_\phi(s_t)\), and the optimizer could cheat by pushing both toward a shared constant — a classic and subtle instability.

The critic is updated 2 times per mini-batch while the actor is updated once. This is a common pattern: the critic learns a regression problem and tolerates more steps per batch, while the actor — being on-policy — should not be over-fit to stale data.
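A two-transition toy example of the target (the critic value 20.0 is made up): the first transition bootstraps from the next state, the terminal one does not.

```python
import torch

gamma = 0.99
reward = torch.tensor([1.0, 1.0])
v_next = torch.tensor([20.0, 20.0])   # hypothetical critic estimates of the next states
done = torch.tensor([0.0, 1.0])       # second transition ends the episode

# first entry: 1 + 0.99 * 20 = 20.8; second entry: just the reward, 1.0
td_target = reward + gamma * (1 - done) * v_next
print(td_target)
```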

Training loop

for i in tqdm(range(2000)):
    data = run_episode(env, policy)
    for _ in range(5):
        loss_act, loss_v = policy.update(data)

For each of 2000 episodes: roll out a trajectory, then perform five passes of updates over that trajectory. Every 50 iterations the mean losses and reward over the last 50 episodes are printed; every 500 iterations the current policy is rendered so you can watch it play.


Problems and Limitations of This Method

While this demo illustrates the mechanics of A2C clearly, it also exhibits several well-known shortcomings that motivate the more sophisticated algorithms covered later in the course.

1. Policy collapse from lack of exploration

The only exploration mechanism is the stochasticity of the softmax policy. Once the actor becomes confident in one action, pushing its probability close to 1, the probability of ever sampling the other action approaches zero. From that point on, the agent has effectively lost the means to recover: it cannot generate the data needed to learn that the abandoned action might have been better in some state. This is especially damaging when the policy collapses to a locally good but globally suboptimal behavior early in training. The training curve typically shows long stretches of good performance followed by a sudden collapse that the algorithm cannot climb back out of.

Common remedies include an entropy bonus in the actor loss (\(L_{\text{actor}} = -E[A \log \pi] - \beta \, H[\pi]\)), periodic policy resets, or methods with structured exploration.
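A sketch of how an entropy bonus could be folded into this demo's actor loss (the probabilities, advantages, and \(\beta\) value below are illustrative, not tuned):

```python
import torch

probs = torch.tensor([[0.6, 0.4],
                      [0.99, 0.01]], requires_grad=True)  # stand-in for act_net output
adv = torch.tensor([1.0, -0.5])
act = torch.tensor([0, 0])
beta = 0.01                            # entropy coefficient (hypothetical value)

logprob = torch.log(probs + 1e-8)
pg_loss = -(adv * torch.gather(logprob, 1, act[:, None]).squeeze(-1)).mean()
entropy = -(probs * logprob).sum(dim=1).mean()   # H[pi], averaged over the batch
act_loss = pg_loss - beta * entropy    # subtracting H rewards keeping the policy stochastic
act_loss.backward()                    # the entropy term's gradient pushes probs toward uniform
```

The near-deterministic second row contributes almost no entropy, so the bonus acts hardest exactly where the policy is closest to collapsing.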

2. High variance despite the baseline

Subtracting \(V_\phi(s)\) reduces variance relative to REINFORCE, but the remaining variance is still substantial because we are using the full Monte Carlo return \(G_t\) as the advantage estimator. Individual episode returns in CartPole can differ by an order of magnitude purely due to random action sampling, and each return is propagated through the policy gradient with no smoothing. Generalized Advantage Estimation (GAE) — which blends \(n\)-step TD returns with a decay parameter \(\lambda\) — is the standard next step and is covered in a subsequent lab.
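A minimal sketch of GAE over a single trajectory, using the same reward/value/done arrays this demo already collects (the function name and the gamma/lam values are illustrative):

```python
import numpy as np

def gae_advantages(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one trajectory (sketch)."""
    adv = np.zeros(len(rewards))
    last = 0.0
    for t in reversed(range(len(rewards))):
        # one-step TD error, masked at terminal states
        delta = rewards[t] + gamma * (1 - dones[t]) * next_values[t] - values[t]
        # exponentially weighted sum of future TD errors
        last = delta + gamma * lam * (1 - dones[t]) * last
        adv[t] = last
    return adv

# Sanity check: lam=1 with a zero critic recovers the plain discounted returns
r = np.array([1.0, 1.0, 1.0])
d = np.array([0.0, 0.0, 1.0])
z = np.zeros(3)
print(gae_advantages(r, z, z, d, gamma=0.99, lam=1.0))  # 2.9701, 1.99, 1.0
```

Choosing lam between 0 and 1 interpolates between the low-variance, high-bias TD(0) advantage and the high-variance, low-bias Monte Carlo advantage used here.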

3. On-policy data inefficiency

Every trajectory is used for a handful of gradient steps and then thrown away. Because the policy gradient is only valid in expectation under the current policy, we cannot safely reuse old data. This makes A2C dramatically less sample-efficient than off-policy methods like DQN, DDPG, or SAC, which learn from a replay buffer. On harder environments, the wall-clock cost of A2C’s on-policy constraint becomes prohibitive.

4. No gradient clipping, no learning-rate scheduling

The demo uses a fixed learning rate and no gradient clipping. In more complex environments, policy-gradient methods routinely need both to stay stable. The next evolution of this algorithm — Proximal Policy Optimization (PPO) — addresses these issues directly by constraining the size of each policy update, and is the de facto workhorse of modern on-policy RL.
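For reference, both remedies are one-liners in PyTorch. A sketch on a throwaway network (the layer sizes, max_norm, and scheduler settings are arbitrary choices, not tuned for this demo):

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
# decay the learning rate linearly to 10% of its initial value over 2000 updates
sched = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1.0,
                                          end_factor=0.1, total_iters=2000)

loss = net(torch.randn(8, 4)).pow(2).mean()
opt.zero_grad()
loss.backward()
# rescale all gradients together so their global L2 norm is at most 0.5
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)
opt.step()
sched.step()
```

In the demo's Policy.update, the clipping call would sit between each loss.backward() and the corresponding optimizer step.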