Advantage Actor-Critic (A2C) Demo

This code demonstrates the implementation of the Advantage Actor-Critic (A2C) algorithm using PyTorch. A2C improves upon vanilla policy gradient methods by using a value function (critic) to reduce variance in gradient estimates.
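
For reference, the quantity the code below estimates is the standard advantage-weighted policy gradient (written here in its usual form; the implementation further down uses Monte Carlo returns for the same idea):

$$
\nabla_\theta J(\theta) \;\approx\; \mathbb{E}\!\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t)\, A(s_t, a_t)\right],
\qquad
A(s_t, a_t) \;\approx\; R_t - V_\phi(s_t),
$$

where $\pi_\theta$ is the actor and $V_\phi$ is the critic. Subtracting the learned baseline $V_\phi(s_t)$ from the return $R_t$ lowers the variance of the gradient estimate compared to REINFORCE, which weights the log-probability by $R_t$ alone.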

Key concepts illustrated:

  • Separate actor (policy) and critic (value) networks
  • Advantage estimation using temporal difference learning (see the sketch after this list)
  • Simultaneous policy and value function optimization
  • Reduced variance compared to REINFORCE
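
The TD-based advantage mentioned above can be made concrete with a short sketch. The helper below (`td_advantage` is an illustrative name, not part of the demo code) estimates the advantage of each transition as the one-step TD error r + γV(s') − V(s); note that the implementation in this demo instead uses the full discounted return minus V(s), a Monte Carlo variant of the same idea.

import torch

def td_advantage(v_net, obs, reward, next_obs, done, gamma=0.99):
    # One-step TD error as an advantage estimate: r + gamma * V(s') - V(s).
    # Terminal transitions (done == 1) bootstrap from 0 instead of V(s').
    with torch.no_grad():
        v_s = v_net(obs).squeeze(-1)
        v_next = v_net(next_obs).squeeze(-1)
        td_target = reward + gamma * (1.0 - done) * v_next
        return td_target - v_s
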
A2C Implementation
# Partially Adapted from https://towardsdatascience.com/learning-reinforcement-learning-reinforce-with-pytorch-5e8ad7fc7da0

# %%
import numpy as np
import gymnasium as gym
import torch
from torch import nn
import torch.nn.functional as F
import random
from collections import deque
import copy

from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# %%
env = gym.make('CartPole-v1')
n_state = int(np.prod(env.observation_space.shape))
n_action = env.action_space.n
print("# of state", n_state)
print("# of action", n_action)

# %%


def run_episode(env, policy, render=False):
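    """Roll out one episode with `policy` and return per-step lists of
    observations, actions, rewards, next observations, and done flags."""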
    obs_list = []
    act_list = []
    reward_list = []
    next_obs_list = []
    done_list = []
    obs = env.reset()[0]
    while True:
        if render:
            env.render()

        action = policy(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        # Treat both termination and truncation (500-step time limit) as episode end
        done = terminated or truncated
        obs_list.append(obs)
        act_list.append(action)
        reward_list.append(reward)
        next_obs_list.append(next_obs)
        done_list.append(done)
        if done:
            break
        obs = next_obs

    return obs_list, act_list, reward_list, next_obs_list, done_list
# %%


class Policy():
    def __init__(self, n_state, n_action):
        # Define network
        self.act_net = nn.Sequential(
            nn.Linear(n_state, 16),
            nn.ReLU(),
            nn.Linear(16, n_action),
            nn.Softmax(dim=-1)
        )
        self.act_net.to(device)
        self.v_net = nn.Sequential(
            nn.Linear(n_state, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        self.v_net.to(device)
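        # Discount factor applied to future rewards (returns and TD targets)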
        self.gamma = 0.99
        self.act_optimizer = torch.optim.Adam(
            self.act_net.parameters(), lr=1e-3)
        self.v_optimizer = torch.optim.Adam(self.v_net.parameters(), lr=1e-3)

    def __call__(self, state):
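        # Sample an action from the categorical distribution produced by the actor network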
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action_probs = self.act_net(state).detach().cpu().numpy()
            action = np.random.choice(n_action, p=action_probs)
        return action

    def update(self, data):
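        # One update pass over the transitions of a single episode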
        obs, act, reward, next_obs, done = data
        # Calculate the discounted cumulative return (reward-to-go) for each step
        returns = np.zeros_like(reward)
        s = 0
        for i in reversed(range(len(returns))):
            s = s * self.gamma + reward[i]
            returns[i] = s

        obs = torch.FloatTensor(obs).to(device)
        reward = torch.FloatTensor(reward).to(device)
        next_obs = torch.FloatTensor(next_obs).to(device)
        returns = torch.FloatTensor(returns).to(device)
        # Actions are used as indices, must be
        # LongTensor
        act = torch.LongTensor(act).to(device)
        done = torch.FloatTensor(done).to(device)
        # Update the actor and critic in mini-batches over the episode
        batch_size = 32
        indices = list(range(len(obs)))
        # Advantage: discounted Monte Carlo return minus the value baseline V(s)
        with torch.no_grad():
            adv = (1 - done) * returns - \
                self.v_net(obs).squeeze()
        for i in range(0, len(indices), batch_size):
            index = indices[i:i+batch_size]
            # One policy-gradient step: maximize the advantage-weighted
            # log-probability of the actions that were actually taken
            logprob = torch.log(self.act_net(obs[index, :]))
            adv_logprob = adv[index] * torch.gather(logprob, 1,
                                                    act[index, None]).squeeze()
            act_loss = -adv_logprob.mean()
            self.act_optimizer.zero_grad()
            act_loss.backward()
            self.act_optimizer.step()

            # Several critic steps: regress V(s) toward the one-step TD target
            # r + gamma * V(s'), without bootstrapping past terminal states and
            # without backpropagating through the target
            for _ in range(5):
                with torch.no_grad():
                    td_target = reward[index, None] + self.gamma * \
                        (1 - done[index, None]) * self.v_net(next_obs[index, :])
                v_loss = F.mse_loss(
                    self.v_net(obs[index, :]), td_target)
                self.v_optimizer.zero_grad()
                v_loss.backward()
                self.v_optimizer.step()

        return act_loss.item(), v_loss.item()


# %%
loss_act_list, loss_v_list, reward_list = [], [], []
policy = Policy(n_state, n_action)
loss_act, loss_v = 0, 0
n_step = 0
for i in tqdm(range(2000)):
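    # Collect one episode with the current policy, then run several update passes over it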
    data = run_episode(env, policy)
    for _ in range(5):
        loss_act, loss_v = policy.update(data)
    rew = sum(data[2])
    if i > 0 and i % 50 == 0:
        print("itr:({:>5d}) loss_act:{:>6.4f} loss_v:{:>6.4f} reward:{:>3.1f}".format(
            i, np.mean(loss_act_list[-50:]), np.mean(loss_v_list[-50:]), np.mean(reward_list[-50:])))

    loss_act_list.append(loss_act)
    loss_v_list.append(loss_v)
    reward_list.append(rew)

# %%
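# Evaluate the trained policy over 100 fresh episodes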
scores = [sum(run_episode(env, policy, False)[2]) for _ in range(100)]
print("Final score:", np.mean(scores))

import pandas as pd
df = pd.DataFrame({'loss_act': loss_act_list, 'reward': reward_list})
df.to_csv("./ClassMaterials/Lecture_19_Actor_Critic/data/A2C.csv",
          index=False, header=True)