Deep Q-Network (DQN) Demo

This lecture demonstrates how to implement Deep Q-Networks on the CartPole-v1 environment. We build up DQN incrementally, starting from a vanilla version and adding two critical improvements: replay buffers and target networks.

Background: From Q-Learning to DQN

In tabular Q-learning, we maintain a table \(Q(s, a)\) and update it with:

\[Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]\]
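For reference, this tabular update is a one-liner over a NumPy array (the state/action counts and learning rate below are illustrative):

```python
import numpy as np

n_states, n_actions = 8, 2
Q = np.zeros((n_states, n_actions))   # the Q-table
alpha, gamma = 0.1, 0.95

def td_update(s, a, r, s_next, done):
    # TD target: r + gamma * max_a' Q(s', a'); no bootstrap on terminal states
    target = r + gamma * (0.0 if done else Q[s_next].max())
    Q[s, a] += alpha * (target - Q[s, a])

td_update(s=0, a=1, r=1.0, s_next=3, done=False)  # Q[0, 1] moves toward 1.0
```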

DQN replaces the Q-table with a neural network \(Q_\theta(s, a)\) that takes a state as input and outputs Q-values for all actions. This allows generalization across similar states, making it feasible for environments with continuous or high-dimensional state spaces.

However, naively combining Q-learning with neural networks is unstable. Two key problems arise:

  1. Correlated samples – Consecutive transitions \((s_t, a_t, r_t, s_{t+1})\) are highly correlated, violating the i.i.d. assumption of stochastic gradient descent.
  2. Moving targets – The TD target \(r + \gamma \max_{a'} Q_\theta(s', a')\) changes with every update to \(\theta\), creating a “chasing your own tail” effect.

Replay Buffer

A replay buffer stores past transitions and samples random mini-batches for training, which breaks the temporal correlation between consecutive samples.

How It Works

  • Store each transition \((s, a, r, s', \text{done})\) into a fixed-size buffer (we use a deque with maxlen=50000)
  • When updating, sample a random mini-batch (size 32) from the buffer
  • This provides diverse, decorrelated training data at each update step

Implementation Details

class ReplayBuffer:
    def __init__(self, size):
        self.buff = deque(maxlen=size)

    def add(self, obs, act, reward, next_obs, done):
        self.buff.append([obs, act, reward, next_obs, done])

    def sample(self, sample_size):
        sample = random.sample(self.buff, sample_size)
        # Convert to tensors and return
        ...

Key design choices:

  • Buffer size (50,000): Large enough to hold diverse experiences, small enough to fit in memory. Older transitions are automatically discarded when the buffer is full.
  • Minimum buffer size (2,000): We wait until the buffer has at least 2,000 transitions before starting updates, ensuring sufficient diversity in the initial samples.
  • Batch size (32): Each update samples 32 random transitions from the buffer.
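The mechanics above can be sketched with a deque and random.sample (the transitions here are dummies for illustration):

```python
import random
from collections import deque

buff = deque(maxlen=50_000)      # old transitions drop off automatically when full

# fill with dummy (obs, act, reward, next_obs, done) transitions
for t in range(3_000):
    buff.append(([0.0] * 4, t % 2, 1.0, [0.0] * 4, False))

MIN_BUFFER, BATCH = 2_000, 32
if len(buff) >= MIN_BUFFER:              # wait for enough diversity before updating
    batch = random.sample(buff, BATCH)   # random mini-batch, sampled without replacement
    print(len(batch))                    # 32
```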

Target Network

The target network is a separate copy of the Q-network that is updated less frequently. It provides stable TD targets during training.

The Problem It Solves

Without a target network, the TD target is:

\[y = r + \gamma \max_{a'} Q_\theta(s', a')\]

Since \(Q_\theta\) is the same network being updated, every gradient step changes both the prediction and the target simultaneously. This feedback loop causes oscillations and divergence.

How It Works

  • Create a copy of the Q-network: target_net = copy.deepcopy(q_net)
  • Use target_net (not q_net) to compute TD targets
  • Periodically sync: target_net.load_state_dict(q_net.state_dict())
  • The target network is kept in eval mode and never receives gradients

The TD target becomes:

\[y = r + \gamma \max_{a'} Q_{\theta^-}(s', a')\]

where \(\theta^-\) are the frozen parameters of the target network.

Implementation Details

# In __init__:
self.target_net = copy.deepcopy(self.q_net)
self.target_net.eval()            # No dropout/batchnorm training behavior
self.network_sync_freq = 10       # Sync every 10 updates

# In update():
if self.network_sync_counter == self.network_sync_freq:
    self.target_net.load_state_dict(self.q_net.state_dict())
    self.network_sync_counter = 0
self.network_sync_counter += 1

# Target computation uses target_net, NOT q_net:
with torch.no_grad():
    y = reward + self.gamma * (1 - done) * torch.max(self.target_net(next_obs), dim=1)[0]

Key design choices:

  • Sync frequency (10 updates): This is a “hard” update strategy – the target network is completely replaced every 10 updates. An alternative is “soft” updates (\(\theta^- \leftarrow \tau \theta + (1-\tau)\theta^-\)), used in DDPG and SAC.
  • torch.no_grad(): The target computation is wrapped in no_grad() because we do not want gradients flowing through the target network.
  • (1 - done) masking: When the episode terminates, there is no future reward, so the target is simply \(r\).
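The soft-update alternative mentioned above can be sketched as Polyak averaging over parameters (the tiny network and the value of tau are illustrative, not from this lab's code):

```python
import copy
import torch
from torch import nn

q_net = nn.Linear(4, 2)
target_net = copy.deepcopy(q_net)

# pretend a gradient step moved the online weights
with torch.no_grad():
    q_net.weight.add_(1.0)

tau = 0.005  # small mixing rate, as used in DDPG/SAC-style soft updates
with torch.no_grad():
    # Polyak averaging: theta_minus <- tau * theta + (1 - tau) * theta_minus
    for p_target, p in zip(target_net.parameters(), q_net.parameters()):
        p_target.mul_(1.0 - tau).add_(tau * p)
```

Each call nudges the target network slightly toward the online network instead of replacing it wholesale every N updates.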

The Q-Network

The Q-network is a simple feedforward network that maps states to Q-values for each action:

State (4) -> Linear(64) -> ReLU -> Linear(32) -> ReLU -> Linear(n_action)

The output layer has one neuron per action. Since the network outputs Q-values for all actions, we need to extract only the Q-value for the action that was actually taken. This is done with advanced indexing:

self.q_net(obs)[range(len(obs)), act]

q_net(obs) returns a tensor of shape (sample_size, n_action). The indexing [range(len(obs)), act] selects element act[i] from row i. For example:

q_net(obs) = [[1.2, 0.8],    # sample 0: Q(s0, left)=1.2, Q(s0, right)=0.8
              [0.5, 1.1],    # sample 1: Q(s1, left)=0.5, Q(s1, right)=1.1
              [0.9, 0.3]]    # sample 2: Q(s2, left)=0.9, Q(s2, right)=0.3

act = [1, 0, 1]              # actions actually taken

q_net(obs)[[0,1,2], [1,0,1]] = [0.8, 0.5, 0.3]  # Q-values for taken actions

This is equivalent to q_net(obs).gather(1, act.unsqueeze(1)).squeeze(1) but more concise.
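A quick check that the two forms agree on the worked example above:

```python
import torch

q_values = torch.tensor([[1.2, 0.8],
                         [0.5, 1.1],
                         [0.9, 0.3]])
act = torch.tensor([1, 0, 1])

# advanced indexing: element act[i] from row i
picked = q_values[range(len(q_values)), act]
# gather along dim 1, then drop the singleton column
gathered = q_values.gather(1, act.unsqueeze(1)).squeeze(1)

print(picked)  # tensor([0.8000, 0.5000, 0.3000])
```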

Epsilon-Greedy Exploration

The policy uses epsilon-greedy action selection:

  • With probability \(\varepsilon\): choose a random action (exploration)
  • With probability \(1 - \varepsilon\): choose \(\arg\max_a Q_\theta(s, a)\) (exploitation)

Epsilon is decayed linearly by 1/5,000 per episode, from an initial 0.5 down to a floor of 0.005 (reached after roughly 2,500 episodes), gradually shifting from exploration to exploitation as the Q-network improves.
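The schedule can be written directly as a function of the episode index (a minimal sketch of the linear decay used in the training loop):

```python
eps_start, eps_min, decay_rate = 0.5, 0.005, 1.0 / 5_000

def epsilon(episode):
    # linear decay of 1/5000 per episode, floored at eps_min
    return max(eps_min, eps_start - episode * decay_rate)

print(epsilon(0), epsilon(3000))  # 0.5 0.005
```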

Training Loop

The training loop ties everything together:

  1. Collect experience: At each step, select an action via epsilon-greedy, execute it, and store \((s, a, r, s', \text{done})\) in the replay buffer.
  2. Update policy: Every 5 steps (update_index counts environment steps and triggers an update once it exceeds 4), sample a mini-batch from the replay buffer and perform one gradient step.
  3. Decay epsilon: After each episode, reduce \(\varepsilon\) linearly.
  4. Visualize progress: Every 500 episodes, run a greedy episode on a render environment to observe the agent’s current behavior.

DQN with Target Network and Replay Buffer
# %% [markdown]
# In this lab, we will implement Q-learning with deep neural networks.

# %%
import numpy as np
import gymnasium as gym
import torch
from torch import nn
import torch.nn.functional as F
import random
from collections import deque
import copy

from tqdm import tqdm


env = gym.make('CartPole-v1')
render_env = gym.make('CartPole-v1', render_mode="human")
n_state = int(np.prod(env.observation_space.shape))
n_action = env.action_space.n
print("# of state", n_state)
print("# of action", n_action)

# SEED = 1234
# torch.manual_seed(SEED)
# np.random.seed(SEED)
# random.seed(SEED)
# env.reset(seed=SEED)
# %% [markdown]
# Given a policy, how can we compute the value function for each state?

# %%

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def run_episode(env, policy, render=False):
    """ Runs an episode and return the total reward """
    obs = env.reset()[0]
    states = []
    rewards = []
    actions = []
    while True:
        if render:
            env.render()

        states.append(obs)
        action = int(policy(obs))
        actions.append(action)
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        rewards.append(reward)
        if done:
            break

    return states, actions, rewards


# %%
class Policy():
    def __init__(self, n_state, n_action, eps):
        self.q_net = nn.Sequential(
            nn.Linear(n_state, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, n_action)
        )
        self.eps = eps
        self.gamma = 0.95

        self.target_net = copy.deepcopy(self.q_net)
        self.target_net.to(device)
        self.target_net.eval()
        self.network_sync_counter = 0
        self.network_sync_freq = 10

        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=1e-3)
        self.q_net.to(device)
        self.replaybuff = ReplayBuffer(50000)

    def update(self, data=None):
        obs, act, reward, next_obs, done = self.replaybuff.sample(32)

        if (self.network_sync_counter == self.network_sync_freq):
            self.target_net.load_state_dict(self.q_net.state_dict())
            self.network_sync_counter = 0
        self.network_sync_counter += 1
        self.optimizer.zero_grad()
        with torch.no_grad():
            y = reward + self.gamma * (1 - done) * \
                torch.max(self.target_net(next_obs), dim=1)[0]

        loss = F.mse_loss(y, self.q_net(obs)[range(len(obs)), act])
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def __call__(self, state):
        if np.random.rand() < self.eps:
            return np.random.choice(n_action)

        if not torch.is_tensor(state):
            state = torch.FloatTensor(state).to(device)
        with torch.no_grad():
            Q = self.q_net(state).cpu().numpy()
            act = np.argmax(Q)

        return act

# %%


class ReplayBuffer:
    def __init__(self, size):
        self.buff = deque(maxlen=size)

    def add(self, obs, act, reward, next_obs, done):
        self.buff.append([obs, act, reward, next_obs, done])

    def sample(self, sample_size):
        if len(self.buff) < sample_size:
            sample_size = len(self.buff)

        sample = random.sample(self.buff, sample_size)
        # Stack into a single ndarray first: building a tensor from a list of
        # arrays is slow and triggers a warning in recent PyTorch versions.
        obs = torch.FloatTensor(np.array([exp[0] for exp in sample])).to(device)
        act = torch.LongTensor([exp[1] for exp in sample]).to(device)
        reward = torch.FloatTensor([exp[2] for exp in sample]).to(device)
        next_obs = torch.FloatTensor(np.array([exp[3] for exp in sample])).to(device)
        done = torch.FloatTensor([exp[4] for exp in sample]).to(device)
        return obs, act, reward, next_obs, done

    def __len__(self):
        return len(self.buff)


# %%
losses_list, reward_list = [], []
policy = Policy(n_state, n_action, 0.5)
update_index = 0
loss = 0
for i in tqdm(range(10000)):
    obs, rew = env.reset()[0], 0
    while True:
        act = policy(obs)
        next_obs, reward, terminated, truncated, _ = env.step(act)
        done = terminated or truncated
        rew += reward

        update_index += 1
        if len(policy.replaybuff) > 2e3 and update_index > 4:
            update_index = 0
            loss = policy.update()

        policy.replaybuff.add(obs, act, reward, next_obs, done)
        obs = next_obs
        if done:
            break
    if i > 0 and i % 500 == 0:
        print("itr:({:>5d}) loss:{:>3.4f} reward:{:>3.1f}".format(
            i, np.mean(losses_list[-500:]), np.mean(reward_list[-500:])))
        old_eps = policy.eps
        policy.eps = 0.0
        run_episode(render_env, policy, render=True)
        policy.eps = old_eps
    policy.eps = max(0.005, policy.eps - 1.0/5000)

    losses_list.append(loss)
    reward_list.append(rew)

# %%
policy.eps = 0.0
scores = [sum(run_episode(env, policy, False)[2]) for _ in range(100)]
print("Final score:", np.mean(scores))

import pandas as pd
df = pd.DataFrame({'loss': losses_list, 'reward': reward_list})
df.to_csv("./ClassMaterials/Lecture_14_DQN/data/dqn-target-replay.csv",
          index=False, header=True)