Lecture 18: Deep Q-Networks (DQN)
Deep Q-Network (DQN) Demo
This lecture demonstrates how to implement Deep Q-Networks on the CartPole-v1 environment. We build up DQN incrementally, starting from a vanilla version and adding two critical improvements: replay buffers and target networks.
Background: From Q-Learning to DQN
In tabular Q-learning, we maintain a table \(Q(s, a)\) and update it with:
\[Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]\]
where \(\alpha\) is the learning rate and \(\gamma\) is the discount factor. DQN replaces the Q-table with a neural network \(Q_\theta(s, a)\) that takes a state as input and outputs Q-values for all actions. This allows generalization across similar states, which makes DQN feasible for environments with continuous or high-dimensional state spaces.
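As a concrete reference point, the tabular update can be written in a few lines of plain Python. This is a minimal sketch: the dictionary-backed table, the helper name q_learning_update, and the numbers are illustrative assumptions rather than part of the lecture code (only \(\gamma = 0.95\) matches the DQN code later in this lecture).

```python
from collections import defaultdict

def q_learning_update(Q, s, a, r, s_next, actions, alpha=0.1, gamma=0.95, done=False):
    """Apply one tabular Q-learning update in place and return the new Q(s, a)."""
    # On a terminal transition there is no next state to bootstrap from.
    best_next = 0.0 if done else max(Q[(s_next, b)] for b in actions)
    td_target = r + gamma * best_next
    # Move Q(s, a) a fraction alpha of the way toward the TD target.
    Q[(s, a)] += alpha * (td_target - Q[(s, a)])
    return Q[(s, a)]

Q = defaultdict(float)  # unseen (state, action) pairs start at 0.0
Q[("s1", 1)] = 2.0
actions = [0, 1]
new_value = q_learning_update(Q, "s0", 0, r=1.0, s_next="s1", actions=actions)
print(new_value)  # about 0.29
```

Here Q(s0, 0) starts at 0 and moves a fraction \(\alpha = 0.1\) of the way toward the TD target \(1 + 0.95 \times 2.0 = 2.9\), giving 0.29. DQN keeps this target but reads \(\max_{a'} Q(s', a')\) from a network instead of a table.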
However, naively combining Q-learning with neural networks is unstable. Two key problems arise:
- Correlated samples – Consecutive transitions \((s_t, a_t, r_t, s_{t+1})\) are highly correlated, violating the i.i.d. assumption of stochastic gradient descent.
- Moving targets – The TD target \(r + \gamma \max_{a'} Q_\theta(s', a')\) changes with every update to \(\theta\), creating a “chasing your own tail” effect.
Replay Buffer
A replay buffer stores past transitions and samples random mini-batches for training, which breaks the temporal correlation between consecutive samples.
How It Works
- Store each transition \((s, a, r, s', \text{done})\) in a fixed-size buffer (we use a deque with maxlen=50000)
- When updating, sample a random mini-batch (size 32) from the buffer
- This provides diverse, decorrelated training data at each update step
Implementation Details
import random
from collections import deque

class ReplayBuffer:
    def __init__(self, size):
        self.buff = deque(maxlen=size)  # oldest transitions are evicted when full

    def add(self, obs, act, reward, next_obs, done):
        self.buff.append([obs, act, reward, next_obs, done])

    def sample(self, sample_size):
        sample = random.sample(self.buff, sample_size)
        # Convert to tensors and return
        ...
Key design choices:
- Buffer size (50,000): Large enough to hold diverse experiences, small enough to fit in memory. Older transitions are automatically discarded when the buffer is full.
- Minimum buffer size (2,000): We wait until the buffer holds more than 2,000 transitions before starting updates, ensuring sufficient diversity in the initial samples.
- Batch size (32): Each update samples 32 random transitions from the buffer.
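The design choices above can be sketched with the standard library alone. This is a minimal illustration, not the lecture's ReplayBuffer: the class name SimpleReplayBuffer and its ready() helper are hypothetical, integers stand in for CartPole observations, and sample() returns plain tuples where the lecture code builds tensors.

```python
import random
from collections import deque

class SimpleReplayBuffer:
    """Hypothetical fixed-size FIFO store of (obs, act, reward, next_obs, done) transitions."""

    def __init__(self, size):
        self.buff = deque(maxlen=size)  # appending to a full deque drops the oldest item

    def add(self, obs, act, reward, next_obs, done):
        self.buff.append((obs, act, reward, next_obs, done))

    def ready(self, min_size):
        # Hypothetical helper: train only once the buffer holds more than min_size transitions.
        return len(self.buff) > min_size

    def sample(self, batch_size):
        batch = random.sample(self.buff, batch_size)  # uniform, without replacement
        # Transpose the list of transitions into one tuple per field.
        obs, act, reward, next_obs, done = zip(*batch)
        return obs, act, reward, next_obs, done

    def __len__(self):
        return len(self.buff)

buf = SimpleReplayBuffer(size=50_000)
for t in range(60_000):
    # Integers stand in for real observations.
    buf.add(obs=t, act=t % 2, reward=1.0, next_obs=t + 1, done=False)

print(len(buf))        # 50000: the capacity caps the length
print(buf.buff[0][0])  # 10000: the 10,000 oldest transitions were evicted
obs, act, reward, next_obs, done = buf.sample(32)
```

zip(*batch) transposes a list of transitions into one sequence per field, which is the step the lecture's sample() performs before converting each field to a tensor.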
Target Network
The target network is a separate copy of the Q-network that is updated less frequently. It provides stable TD targets during training.
The Problem It Solves
Without a target network, the TD target is:
\[y = r + \gamma \max_{a'} Q_\theta(s', a')\]
Since \(Q_\theta\) is the same network being updated, every gradient step changes both the prediction and the target simultaneously. This feedback loop can cause oscillations and divergence.
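A toy example makes the feedback loop visible. The sketch assumes a hypothetical linear Q-function with one weight per action and a scalar state feature, so Q(s, 0) and Q(s', 0) share the weight w[0], much as all states share the hidden layers of a neural network. One gradient step toward the target raises w[0], and that same change raises the target itself; a frozen copy of the weights keeps the target fixed.

```python
def q_value(w, x, a):
    # Hypothetical toy Q-function: one weight per action, scalar state feature x.
    return w[a] * x

def td_target(w, r, x_next, gamma=0.9):
    return r + gamma * max(q_value(w, x_next, b) for b in range(len(w)))

w = [1.0, 0.5]        # online parameters (theta)
w_frozen = list(w)    # frozen copy (theta^-), never updated below
x, a, r, x_next = 1.0, 0, 1.0, 0.9
lr = 0.1

y_before = td_target(w, r, x_next)
# One semi-gradient step on (Q(x, a) - y)^2, treating y as a constant.
error = q_value(w, x, a) - y_before
w[a] -= lr * 2 * error * x

y_online_after = td_target(w, r, x_next)         # target computed from theta moved
y_frozen_after = td_target(w_frozen, r, x_next)  # target computed from theta^- did not
print(y_before, y_online_after, y_frozen_after)
```

With these numbers the online target moves from 1.81 to about 1.94 after a single update, even though only Q(s, a) was being fitted.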
How It Works
- Create a copy of the Q-network: target_net = copy.deepcopy(q_net)
- Use target_net (not q_net) to compute TD targets
- Periodically sync: target_net.load_state_dict(q_net.state_dict())
- The target network is kept in eval mode and never receives gradients
The TD target becomes:
\[y = r + \gamma \max_{a'} Q_{\theta^-}(s', a')\]
where \(\theta^-\) are the frozen parameters of the target network.
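The batched form of this target, including the (1 - done) mask used in the implementation, can be checked with NumPy. The Q-value array below is a hypothetical stand-in for target_net(next_obs); NumPy replaces PyTorch only to keep the sketch dependency-light.

```python
import numpy as np

gamma = 0.95
reward = np.array([1.0, 1.0, 1.0])
done = np.array([0.0, 0.0, 1.0])  # the third transition ended the episode
# Hypothetical target-network output on next_obs, shape (batch, n_action).
q_target_next = np.array([[1.2, 0.8],
                          [0.5, 1.1],
                          [0.9, 0.3]])

# y = r + gamma * (1 - done) * max_a' Q_{theta^-}(s', a')
y = reward + gamma * (1.0 - done) * q_target_next.max(axis=1)
print(y)
```

The first two targets bootstrap from the larger next-state Q-value (1.2 and 1.1), while the terminal transition's target is just its reward, 1.0.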
Implementation Details
# In __init__:
self.target_net = copy.deepcopy(self.q_net)
self.target_net.eval()  # No dropout/batchnorm training behavior
self.network_sync_counter = 0  # Updates since the last sync
self.network_sync_freq = 10  # Sync every 10 updates

# In update():
if self.network_sync_counter == self.network_sync_freq:
    self.target_net.load_state_dict(self.q_net.state_dict())
    self.network_sync_counter = 0
self.network_sync_counter += 1

# Target computation uses target_net, NOT q_net:
with torch.no_grad():
    y = reward + self.gamma * (1 - done) * torch.max(self.target_net(next_obs), axis=1)[0]
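Because the counter is checked before it is incremented, it is worth confirming how often the copy actually happens. The hypothetical helper sync_calls replays the same counter logic for a given number of update() calls.

```python
def sync_calls(n_updates, sync_freq=10):
    """Return the (1-indexed) update() calls on which the target network is synced."""
    counter = 0
    synced_on = []
    for call in range(1, n_updates + 1):
        # Same order as update(): check the counter, then increment it.
        if counter == sync_freq:
            synced_on.append(call)  # the copy happens before this call's gradient step
            counter = 0
        counter += 1
    return synced_on

print(sync_calls(35))  # [11, 21, 31]
```

The copy runs on calls 11, 21, 31, and so on, before that call's gradient step, so each sync brings in exactly 10 new gradient steps.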
Key design choices:
- Sync frequency (10 updates): This is a “hard” update strategy – the target network is completely replaced every 10 updates. An alternative is “soft” updates (\(\theta^- \leftarrow \tau \theta + (1-\tau)\theta^-\)), used in DDPG and SAC.
- torch.no_grad(): The target computation is wrapped in no_grad() because we do not want gradients flowing through the target network.
- (1 - done) masking: When the episode terminates, there is no future reward, so the target is simply \(r\).
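The hard and soft update rules differ only in how the target parameters are overwritten. The sketch stores parameters as dictionaries of NumPy arrays, an assumed stand-in for a PyTorch state_dict(); soft_update, hard_update, and the value tau = 0.1 are illustrative choices, not taken from the lecture code.

```python
import numpy as np

def soft_update(target_params, online_params, tau):
    """Polyak averaging in place: theta^- <- tau * theta + (1 - tau) * theta^-."""
    for name, p in online_params.items():
        target_params[name] *= (1.0 - tau)
        target_params[name] += tau * p

def hard_update(target_params, online_params):
    """Replace every target parameter with a copy of the online one."""
    for name, p in online_params.items():
        target_params[name] = p.copy()

online = {"w": np.array([1.0, 2.0]), "b": np.array([0.5])}
target = {"w": np.zeros(2), "b": np.zeros(1)}

soft_update(target, online, tau=0.1)
after_soft = target["w"].copy()
print(after_soft)   # moved 10% of the way toward online["w"]

hard_update(target, online)
print(target["w"])  # now identical to online["w"]
```

A hard update makes \(\theta^-\) identical to \(\theta\) at once, whereas each soft update moves it only a fraction \(\tau\) of the way, so the target changes smoothly at every step instead of jumping every few updates.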
The Q-Network
The Q-network is a simple feedforward network that maps states to Q-values for each action:
State (4) -> Linear(64) -> ReLU -> Linear(32) -> ReLU -> Linear(n_action)
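The same architecture can be written as a plain NumPy forward pass to check the tensor shapes. This is a sketch under stated assumptions: the weights are random and use a simple Gaussian initialization rather than PyTorch's default uniform one, and a batch of 32 random vectors stands in for CartPole observations.

```python
import numpy as np

rng = np.random.default_rng(0)
n_state, n_action, batch = 4, 2, 32

def init_linear(n_in, n_out):
    # Assumed Gaussian init; PyTorch's nn.Linear uses a different (uniform) scheme.
    return rng.normal(0.0, 0.1, size=(n_in, n_out)), np.zeros(n_out)

layers = [init_linear(n_state, 64), init_linear(64, 32), init_linear(32, n_action)]

def q_forward(obs):
    """Map observations (batch, n_state) to Q-values (batch, n_action)."""
    h = obs
    for i, (W, b) in enumerate(layers):
        h = h @ W + b
        if i < len(layers) - 1:  # ReLU on hidden layers only; the output stays linear
            h = np.maximum(h, 0.0)
    return h

obs = rng.normal(size=(batch, n_state))
q = q_forward(obs)
greedy_actions = q.argmax(axis=1)  # one greedy action per sample
print(q.shape)  # (32, 2)
```

For CartPole (4 state dimensions, 2 actions) the network has (4×64 + 64) + (64×32 + 32) + (32×2 + 2) = 2,466 parameters.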
The output layer has one neuron per action. Since the network outputs Q-values for all actions, we need to extract only the Q-value for the action that was actually taken. This is done with advanced indexing:
self.q_net(obs)[range(len(obs)), act]
q_net(obs) returns a tensor of shape (sample_size, n_action). The indexing [range(len(obs)), act] selects element act[i] from row i. For example:
q_net(obs) = [[1.2, 0.8],   # sample 0: Q(s0, left)=1.2, Q(s0, right)=0.8
              [0.5, 1.1],   # sample 1: Q(s1, left)=0.5, Q(s1, right)=1.1
              [0.9, 0.3]]   # sample 2: Q(s2, left)=0.9, Q(s2, right)=0.3
act = [1, 0, 1]             # actions actually taken
q_net(obs)[[0,1,2], [1,0,1]] = [0.8, 0.5, 0.3]   # Q-values for taken actions
This is equivalent to q_net(obs).gather(1, act.unsqueeze(1)).squeeze(1) but more concise.
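NumPy follows the same advanced-indexing rule as PyTorch here, and np.take_along_axis plays the role of gather, so the equivalence can be checked on the example values without PyTorch. Using NumPy instead of torch is an assumption made to keep the snippet self-contained.

```python
import numpy as np

q = np.array([[1.2, 0.8],
              [0.5, 1.1],
              [0.9, 0.3]])
act = np.array([1, 0, 1])

# Advanced indexing: row i, column act[i].
taken_fancy = q[np.arange(len(q)), act]
# Gather-style equivalent: index along axis 1, then drop the singleton column.
taken_gather = np.take_along_axis(q, act[:, None], axis=1)[:, 0]
print(taken_fancy)  # [0.8 0.5 0.3]
```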
Epsilon-Greedy Exploration
The policy uses epsilon-greedy action selection:
- With probability \(\varepsilon\): choose a random action (exploration)
- With probability \(1 - \varepsilon\): choose \(\arg\max_a Q_\theta(s, a)\) (exploitation)
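Epsilon starts at 0.5 and is decreased linearly by \(1/5000 = 0.0002\) after every episode, down to a floor of 0.005. It therefore reaches the floor after roughly \((0.5 - 0.005) \times 5000 \approx 2{,}475\) of the 10,000 training episodes, gradually shifting from exploration to exploitation as the Q-network improves.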
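Both pieces of exploration logic are short enough to sketch in plain Python. The helpers epsilon_greedy and eps_schedule are hypothetical; eps_schedule repeats the per-episode update max(0.005, eps - 1/5000) from the training loop so the point where epsilon reaches its floor can be read off directly.

```python
import random

def epsilon_greedy(q_values, eps, rng=random):
    """Return a random action with probability eps, else the greedy action."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    # Greedy choice; ties go to the lowest index, like np.argmax.
    return max(range(len(q_values)), key=lambda a: q_values[a])

def eps_schedule(n_episodes, start=0.5, floor=0.005, step=1.0 / 5000):
    """Per-episode epsilon values, decayed by repeated subtraction as in the loop."""
    eps, values = start, []
    for _ in range(n_episodes):
        values.append(eps)
        eps = max(floor, eps - step)
    return values

schedule = eps_schedule(10_000)
first_at_floor = schedule.index(0.005)
print(first_at_floor)  # about 2475; floating-point rounding may shift it by one
```

Up to floating-point rounding, the floor is reached around episode 2,475, so the later part of training acts greedily more than 99% of the time.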
Training Loop
The training loop ties everything together:
- Collect experience: At each step, select an action via epsilon-greedy, execute it, and store \((s, a, r, s', \text{done})\) in the replay buffer.
- Update policy: Every 4 steps (controlled by update_index), once the buffer holds more than 2,000 transitions, sample a mini-batch from the replay buffer and perform one gradient step.
- Decay epsilon: After each episode, reduce \(\varepsilon\) linearly.
- Visualize progress: Every 500 episodes, run a greedy episode on a render environment to observe the agent’s current behavior.
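The update cadence described above can be simulated without an environment. update_schedule is a hypothetical helper; it assumes one transition is stored per step (after the update check, as in the loop), no transitions are evicted, and the step counter keeps running across episode boundaries.

```python
def update_schedule(total_steps, min_buffer=2000, update_every=4):
    """Return the (1-indexed) environment steps on which a gradient update fires."""
    buffer_len = 0
    update_index = 0
    steps = []
    for t in range(1, total_steps + 1):
        update_index += 1
        # Gate: enough data in the buffer and enough steps since the last update.
        if buffer_len > min_buffer and update_index >= update_every:
            update_index = 0
            steps.append(t)
        buffer_len += 1  # the current transition is stored after the update check
    return steps

print(update_schedule(2020))  # [2002, 2006, 2010, 2014, 2018]
```

The first update fires as soon as the buffer holds more than 2,000 transitions (step 2,002 here), and later updates follow every 4 steps, so training takes one gradient step per 4 environment steps.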
# %% [markdown]
# In this lab, we will implement Q-learning with deep neural networks.
# %%
import numpy as np
import gymnasium as gym
import torch
from torch import nn
import torch.nn.functional as F
import random
from collections import deque
import copy
from tqdm.std import tqdm
env = gym.make('CartPole-v1')
render_env = gym.make('CartPole-v1', render_mode="human")
n_state = int(np.prod(env.observation_space.shape))
n_action = env.action_space.n
print("# of state", n_state)
print("# of action", n_action)
# SEED = 1234
# torch.manual_seed(SEED)
# np.random.seed(SEED)
# random.seed(SEED)
# env.reset(seed=SEED)  # gymnasium has no env.seed(); seed through reset()
# %% [markdown]
# Given a policy, run one episode and record the visited states, actions, and rewards.
# %%
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def run_episode(env, policy, render=False):
    """ Runs one episode and returns the visited states, actions, and rewards """
    obs = env.reset()[0]
    states = []
    rewards = []
    actions = []
    while True:
        if render:
            env.render()
        states.append(obs)
        action = int(policy(obs))
        actions.append(action)
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        rewards.append(reward)
        if done:
            break
    return states, actions, rewards
# %%
class Policy():
    def __init__(self, n_state, n_action, eps):
        self.n_action = n_action
        self.q_net = nn.Sequential(
            nn.Linear(n_state, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, n_action)
        )
        self.eps = eps
        self.gamma = 0.95
        self.target_net = copy.deepcopy(self.q_net)
        self.target_net.to(device)
        self.target_net.eval()
        self.network_sync_counter = 0
        self.network_sync_freq = 10
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=1e-3)
        self.q_net.to(device)
        self.replaybuff = ReplayBuffer(50000)

    def update(self, data=None):
        obs, act, reward, next_obs, done = self.replaybuff.sample(32)
        # Hard update: copy the online weights into the target network every network_sync_freq updates
        if self.network_sync_counter == self.network_sync_freq:
            self.target_net.load_state_dict(self.q_net.state_dict())
            self.network_sync_counter = 0
        self.network_sync_counter += 1
        self.optimizer.zero_grad()
        # TD target from the frozen target network; (1 - done) drops the bootstrap term for terminal transitions
        with torch.no_grad():
            y = reward + self.gamma * (1 - done) * \
                torch.max(self.target_net(next_obs), axis=1)[0]
        # Compare against the Q-values of the actions actually taken
        loss = F.mse_loss(self.q_net(obs)[range(len(obs)), act], y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def __call__(self, state):
        # Epsilon-greedy: random action with probability eps, otherwise greedy w.r.t. q_net
        if np.random.rand() < self.eps:
            return np.random.choice(self.n_action)
        if not torch.is_tensor(state):
            state = torch.FloatTensor(state).to(device)
        with torch.no_grad():
            Q = self.q_net(state).cpu().numpy()
        act = np.argmax(Q)
        return act
# %%
class ReplayBuffer:
    def __init__(self, size):
        self.buff = deque(maxlen=size)

    def add(self, obs, act, reward, next_obs, done):
        self.buff.append([obs, act, reward, next_obs, done])

    def sample(self, sample_size):
        if len(self.buff) < sample_size:
            sample_size = len(self.buff)
        sample = random.sample(self.buff, sample_size)
        # Stack observations with NumPy first: building a tensor from a list of ndarrays is very slow
        obs = torch.FloatTensor(np.array([exp[0] for exp in sample])).to(device)
        act = torch.LongTensor([exp[1] for exp in sample]).to(device)
        reward = torch.FloatTensor([exp[2] for exp in sample]).to(device)
        next_obs = torch.FloatTensor(np.array([exp[3] for exp in sample])).to(device)
        done = torch.FloatTensor([exp[4] for exp in sample]).to(device)
        return obs, act, reward, next_obs, done

    def __len__(self):
        return len(self.buff)
# %%
losses_list, reward_list = [], []
policy = Policy(n_state, n_action, 0.5)
update_index = 0
loss = 0
for i in tqdm(range(10000)):
    obs, rew = env.reset()[0], 0
    while True:
        act = policy(obs)
        next_obs, reward, terminated, truncated, _ = env.step(act)
        done = terminated or truncated
        rew += reward
        update_index += 1
        # Train every 4 environment steps once the buffer holds more than 2,000 transitions
        if len(policy.replaybuff) > 2e3 and update_index >= 4:
            update_index = 0
            loss = policy.update()
        # Store `terminated`, not `done`: a time-limit truncation is not a true terminal
        # state, so the TD target should still bootstrap from next_obs
        policy.replaybuff.add(obs, act, reward, next_obs, terminated)
        obs = next_obs
        if done:
            break
    if i > 0 and i % 500 == 0:
        print("itr:({:>5d}) loss:{:>3.4f} reward:{:>3.1f}".format(
            i, np.mean(losses_list[-500:]), np.mean(reward_list[-500:])))
        # Render one greedy episode to watch the current policy
        old_eps = policy.eps
        policy.eps = 0.0
        run_episode(render_env, policy, render=True)
        policy.eps = old_eps
    # Linear epsilon decay, floored at 0.005
    policy.eps = max(0.005, policy.eps - 1.0/5000)
    losses_list.append(loss)
    reward_list.append(rew)
# %%
policy.eps = 0.0
scores = [sum(run_episode(env, policy, False)[2]) for _ in range(100)]
print("Final score:", np.mean(scores))
import pandas as pd
df = pd.DataFrame({'loss': losses_list, 'reward': reward_list})
df.to_csv("./ClassMaterials/Lecture_14_DQN/data/dqn-target-replay.csv",
          index=False, header=True)