Lecture 22: A2C Demo Code
Advantage Actor-Critic (A2C) Demo
This code demonstrates a minimal implementation of the Advantage Actor-Critic (A2C) algorithm in PyTorch on the CartPole-v1 environment. A2C extends vanilla policy gradient (REINFORCE) by using a learned value function (the critic) as a baseline, which dramatically reduces the variance of the gradient estimates used to update the policy (the actor).
Key ideas illustrated:
- Separate actor (policy) and critic (value) networks
- Advantage estimation using the Monte Carlo return minus the critic’s value estimate
- Simultaneous policy and value function optimization
- Variance reduction compared to REINFORCE
# Partially Adapted from https://towardsdatascience.com/learning-reinforcement-learning-reinforce-with-pytorch-5e8ad7fc7da0
# %%
import numpy as np
import gymnasium as gym
import torch
from torch import nn
import torch.nn.functional as F
import random
from collections import deque
import copy
from tqdm.std import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# %%
env = gym.make('CartPole-v1')
render_env = gym.make('CartPole-v1', render_mode='human')
n_state = int(np.prod(env.observation_space.shape))
n_action = env.action_space.n
print("# of state", n_state)
print("# of action", n_action)
# %%
def run_episode(env, policy):
    obs_list = []
    act_list = []
    reward_list = []
    next_obs_list = []
    done_list = []
    obs = env.reset()[0]
    while True:
        action = policy(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        obs_list.append(obs)
        act_list.append(action)
        reward_list.append(reward)
        next_obs_list.append(next_obs)
        done_list.append(terminated)
        if terminated or truncated:
            break
        obs = next_obs
    return obs_list, act_list, reward_list, next_obs_list, done_list
# %%
class Policy():
    def __init__(self, n_state, n_action):
        # Define actor (policy) network
        self.act_net = nn.Sequential(
            nn.Linear(n_state, 16),
            nn.ReLU(),
            nn.Linear(16, n_action),
            nn.Softmax(dim=-1)
        )
        self.act_net.to(device)
        # Define critic (value) network
        self.v_net = nn.Sequential(
            nn.Linear(n_state, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        self.v_net.to(device)
        self.gamma = 0.99
        self.act_optimizer = torch.optim.Adam(
            self.act_net.parameters(), lr=1e-3)
        self.v_optimizer = torch.optim.Adam(self.v_net.parameters(), lr=1e-3)

    def __call__(self, state):
        # Sample an action from the current policy distribution
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action_probs = self.act_net(state).detach().cpu().numpy()
            action = np.random.choice(n_action, p=action_probs)
        return action
    def update(self, data):
        obs, act, reward, next_obs, done = data
        # Calculate cumulative discounted return for each time step
        returns = np.zeros_like(reward)
        s = 0
        for i in reversed(range(len(returns))):
            s = s * self.gamma + reward[i]
            returns[i] = s
        obs = torch.FloatTensor(obs).to(device)
        reward = torch.FloatTensor(reward).to(device)
        next_obs = torch.FloatTensor(next_obs).to(device)
        returns = torch.FloatTensor(returns).to(device)
        # Actions are used as indices, must be LongTensor
        act = torch.LongTensor(act).to(device)
        done = torch.FloatTensor(done).to(device)
        # Calculate losses in mini-batches
        batch_size = 32
        indices = [j for j in range(len(obs))]
        # Calculate advantages (no need to use (1 - done) on the returns,
        # as they involve no neural-network prediction)
        with torch.no_grad():
            adv = returns - self.v_net(obs).squeeze(-1)
        for i in range(0, len(indices), batch_size):
            index = indices[i:i+batch_size]
            # Actor update: one policy-gradient step per mini-batch
            for _ in range(1):
                logprob = torch.log(self.act_net(obs[index, :]) + 1e-8)
                adv_logprob = adv[index] * torch.gather(logprob, 1,
                                                        act[index, None]).squeeze()
                act_loss = -adv_logprob.mean()
                self.act_optimizer.zero_grad()
                act_loss.backward()
                self.act_optimizer.step()
            # Critic update: two TD(0) regression steps per mini-batch
            for _ in range(2):
                with torch.no_grad():
                    td_target = reward[index, None] + self.gamma * \
                        (1 - done[index, None]) * self.v_net(next_obs[index, :])
                v_loss = F.mse_loss(
                    self.v_net(obs[index, :]), td_target)
                self.v_optimizer.zero_grad()
                v_loss.backward()
                self.v_optimizer.step()
        return act_loss.item(), v_loss.item()
# %%
loss_act_list, loss_v_list, reward_list = [], [], []
policy = Policy(n_state, n_action)
loss_act, loss_v = 0, 0
n_step = 0
for i in tqdm(range(2000)):
    data = run_episode(env, policy)
    for _ in range(5):
        loss_act, loss_v = policy.update(data)
    rew = sum(data[2])
    if i > 0 and i % 50 == 0:
        print("itr:({:>5d}) loss_act:{:>6.4f} loss_v:{:>6.4f} reward:{:>3.1f}".format(
            i, np.mean(loss_act_list[-50:]), np.mean(loss_v_list[-50:]), np.mean(reward_list[-50:])))
    if i > 0 and i % 500 == 0:
        run_episode(render_env, policy)
    loss_act_list.append(loss_act)
    loss_v_list.append(loss_v)
    reward_list.append(rew)
# %%
render_env.close()
scores = [sum(run_episode(env, policy)[2]) for _ in range(100)]
print("Final score:", np.mean(scores))
import pandas as pd
df = pd.DataFrame({'loss_act': loss_act_list, 'reward': reward_list})
df.to_csv("./ClassMaterials/Lecture_19_Actor_Critic/data/A2C.csv",
index=False, header=True)
Walk-through of the Code
The Actor-Critic networks
self.act_net = nn.Sequential(
    nn.Linear(n_state, 16), nn.ReLU(),
    nn.Linear(16, n_action), nn.Softmax(dim=-1),
)
self.v_net = nn.Sequential(
    nn.Linear(n_state, 64), nn.ReLU(),
    nn.Linear(64, 1),
)
- Actor (act_net) maps a state to a probability distribution over the two actions via a softmax output.
- Critic (v_net) maps a state to a scalar estimate of its value \(V(s)\).
Each network has its own Adam optimizer with learning rate 1e-3, so the two losses do not compete for a shared optimizer step.
Action selection — Policy.__call__
with torch.no_grad():
    action_probs = self.act_net(state).detach().cpu().numpy()
    action = np.random.choice(n_action, p=action_probs)
Actions are sampled from the policy’s distribution, not chosen greedily. This stochasticity is the only source of exploration in the algorithm — a fact that will become important when we discuss its limitations.
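For evaluation it can be useful to switch off this stochasticity. Below is a minimal sketch of a greedy variant; the helper name greedy_action is ours and is not part of the demo.
# Hypothetical helper (not in the demo): pick the highest-probability action
# instead of sampling, e.g. to measure the learned policy without exploration.
def greedy_action(policy, state):
    with torch.no_grad():
        state_t = torch.FloatTensor(state).to(device)
        probs = policy.act_net(state_t).cpu().numpy()
    return int(np.argmax(probs))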
Advantage estimation
The update first computes discounted Monte Carlo returns for the trajectory:
\[G_t = \sum_{k=0}^{T-t} \gamma^k r_{t+k}\]
s = 0
for i in reversed(range(len(returns))):
    s = s * self.gamma + reward[i]
    returns[i] = s
and then forms the advantage by subtracting the critic’s value estimate:
\[A(s_t, a_t) \approx G_t - V_\phi(s_t)\]
with torch.no_grad():
    adv = returns - self.v_net(obs).squeeze(-1)
Subtracting \(V_\phi(s_t)\) is a baseline that does not bias the policy gradient but dramatically reduces its variance — the central insight of actor-critic methods. The torch.no_grad() context ensures the advantage is treated as a fixed signal when updating the actor; it is not something we differentiate through.
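As a quick sanity check of the discounted-return loop above (a toy example, not part of the demo), three unit rewards with \(\gamma = 0.99\) give \(G_2 = 1\), \(G_1 = 1.99\), \(G_0 = 2.9701\):
# Toy check of the discounted-return recursion with gamma = 0.99.
gamma = 0.99
reward = [1.0, 1.0, 1.0]
returns = np.zeros(len(reward))
s = 0
for i in reversed(range(len(reward))):
    s = s * gamma + reward[i]
    returns[i] = s
print(returns)  # approximately [2.9701, 1.99, 1.0]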
Actor update — the policy gradient
logprob = torch.log(self.act_net(obs[index, :]) + 1e-8)
adv_logprob = adv[index] * torch.gather(logprob, 1, act[index, None]).squeeze()
act_loss = -adv_logprob.mean()
This is a direct implementation of the policy gradient theorem:
\[\nabla_\theta J(\theta) = E\!\left[ A(s, a)\, \nabla_\theta \log \pi_\theta(a \mid s) \right]\]
We gather the log-probability of the action actually taken, multiply by its advantage, take the mean, and negate it (so that minimizing act_loss maximizes the expected return). The small 1e-8 added inside the log guards against log(0) when one action's probability has saturated.
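An equivalent way to obtain the same log-probabilities, and one that is slightly more robust than log(p + 1e-8), is torch.distributions.Categorical. The snippet below is a sketch of a drop-in replacement for the actor step inside Policy.update; it is not what the demo actually uses.
# Alternative actor loss (not in the demo): Categorical computes
# log-probabilities internally instead of relying on log(p + 1e-8).
from torch.distributions import Categorical

dist = Categorical(probs=self.act_net(obs[index, :]))
act_loss = -(adv[index] * dist.log_prob(act[index])).mean()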
Critic update — bootstrapped TD(0) target
with torch.no_grad():
    td_target = reward[index, None] + self.gamma * \
        (1 - done[index, None]) * self.v_net(next_obs[index, :])
v_loss = F.mse_loss(self.v_net(obs[index, :]), td_target)
The critic is trained by regressing \(V_\phi(s)\) toward the one-step TD target:
\[y_t = r_t + \gamma \,(1 - d_t)\, V_\phi(s_{t+1})\]
Two details worth emphasizing:
- (1 - done) masks bootstrapping at terminal states. When the pole falls, there is no “future value” to add; the target is simply r_t.
- The target is wrapped in torch.no_grad(). Otherwise, gradients would flow through \(V_\phi(s_{t+1})\) as well as \(V_\phi(s_t)\), and the optimizer could cheat by pushing both toward a shared constant, a classic and subtle instability.
The critic is updated 2 times per mini-batch while the actor is updated once. This is a common pattern: the critic learns a regression problem and tolerates more steps per batch, while the actor — being on-policy — should not be over-fit to stale data.
Training loop
for i in tqdm(range(2000)):
    data = run_episode(env, policy)
    for _ in range(5):
        loss_act, loss_v = policy.update(data)
For each of 2000 episodes: roll out a trajectory, then perform several passes of updates over that trajectory. Every 50 iterations the mean reward is printed; every 500 iterations the current policy is rendered so you can watch it play.
Problems and Limitations of This Method
While this demo illustrates the mechanics of A2C clearly, it also exhibits several well-known shortcomings that motivate the more sophisticated algorithms covered later in the course.
1. Policy collapse from lack of exploration
The only exploration mechanism is the stochasticity of the softmax policy. Once the actor becomes confident on one action — pushing its probability close to 1 — the probability of ever sampling the other action approaches zero. From that point on, the agent has effectively lost its means to recover: it cannot generate the data needed to learn that the abandoned action might have been better in some state. This is especially damaging when the policy collapses to a locally good but globally suboptimal behavior early in training. The training curve typically shows long stretches of good performance followed by a sudden collapse that the algorithm cannot climb back out of.
Common remedies include an entropy bonus in the actor loss (\(L_{\text{actor}} = -E[A \log \pi] - \beta \, H[\pi]\)), periodic policy resets, or methods with structured exploration.
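A minimal sketch of the entropy-bonus remedy, written as a drop-in for the actor step in Policy.update; the coefficient beta = 0.01 is an assumed value, not tuned for this environment.
# Sketch (not in the demo): entropy-regularized actor loss, which penalizes
# the policy for collapsing onto a single action.
beta = 0.01                                    # assumed entropy coefficient
probs = self.act_net(obs[index, :])
logprob = torch.log(probs + 1e-8)
entropy = -(probs * logprob).sum(dim=-1)       # H[pi(.|s)] per state
pg_term = adv[index] * torch.gather(logprob, 1, act[index, None]).squeeze()
act_loss = -(pg_term.mean() + beta * entropy.mean())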
2. High variance despite the baseline
Subtracting \(V_\phi(s)\) reduces variance relative to REINFORCE, but the remaining variance is still substantial because we are using the full Monte Carlo return \(G_t\) as the advantage estimator. Individual episode returns in CartPole can differ by an order of magnitude purely due to random action sampling, and each return is propagated through the policy gradient with no smoothing. Generalized Advantage Estimation (GAE) — which blends \(n\)-step TD returns with a decay parameter \(\lambda\) — is the standard next step and is covered in a subsequent lab.
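A sketch of how GAE could replace the Monte Carlo advantage in this demo, assuming per-step arrays of rewards, value estimates, next-state value estimates, and done flags for one trajectory; the function name gae_advantages and \(\lambda = 0.95\) are our choices, not part of the demo.
# Sketch (not in the demo): Generalized Advantage Estimation.
#   delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t)
#   A_t     = delta_t + gamma * lam * (1 - d_t) * A_{t+1}
def gae_advantages(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
    adv = np.zeros(len(rewards), dtype=np.float32)
    last = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * (1 - dones[t]) * next_values[t] - values[t]
        last = delta + gamma * lam * (1 - dones[t]) * last
        adv[t] = last
    return adv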
3. On-policy data inefficiency
Every trajectory is used for a handful of gradient steps and then thrown away. Because the policy gradient is only valid in expectation under the current policy, we cannot safely reuse old data. This makes A2C dramatically less sample-efficient than off-policy methods like DQN, DDPG, or SAC, which learn from a replay buffer. On harder environments, the wall-clock cost of A2C’s on-policy constraint becomes prohibitive.
4. No gradient clipping, no learning-rate scheduling
The demo uses a fixed learning rate and no gradient clipping. In more complex environments, policy-gradient methods routinely need both to stay stable. The next evolution of this algorithm — Proximal Policy Optimization (PPO) — addresses these issues directly by constraining the size of each policy update, and is the de facto workhorse of modern on-policy RL.
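A sketch of how gradient clipping and a learning-rate schedule could be bolted onto this demo; the clipping norm of 0.5 and the StepLR settings are assumed values, not tuned.
# Sketch (not in the demo): create a decay schedule for the actor optimizer ...
act_scheduler = torch.optim.lr_scheduler.StepLR(
    policy.act_optimizer, step_size=500, gamma=0.5)

# ... clip gradients between backward() and step() inside the actor update ...
self.act_optimizer.zero_grad()
act_loss.backward()
torch.nn.utils.clip_grad_norm_(self.act_net.parameters(), max_norm=0.5)
self.act_optimizer.step()

# ... and advance the schedule once per episode in the training loop.
act_scheduler.step()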