Deep Q-Network (DQN) — Atari with Neural Networks
DQN (Mnih et al., DeepMind 2015) replaced the Q-table with a neural network, enabling RL for high-dimensional state spaces such as Atari game pixels. Two key innovations make DQN stable: Experience Replay (sampling random minibatches from a memory buffer breaks the correlation between consecutive transitions) and a Target Network (a separate, periodically frozen copy of the Q-network that provides stable regression targets).
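Both tricks serve the same update rule: the online network Q_θ is regressed toward the one-step Bellman target computed with the frozen target network Q_θ⁻ (this is exactly what train_step in the code below implements, with D the replay buffer):

L(θ) = E_{(s, a, r, s', done) ~ D} [ ( r + γ · (1 − done) · max_{a'} Q_θ⁻(s', a') − Q_θ(s, a) )² ]

In the code, the squared error is swapped for the more robust Huber loss.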
DQN with PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import numpy as np
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# DQN ARCHITECTURE
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class DQNNetwork(nn.Module):
    '''
    Maps state to Q-values for all actions simultaneously.
    For Atari: 3-layer CNN processes 84x84 grayscale frames.
    For CartPole: simple MLP processes 4 sensor readings.
    '''
    def __init__(self, state_dim: int, n_actions: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),  # one Q-value per action (no activation!)
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state)  # [B, n_actions] -- Q(s, a) for all a
class ReplayBuffer:
    '''
    Experience Replay: store transitions, sample RANDOMLY for training.
    WHY: breaks correlation between consecutive samples (stabilizes training).
    '''
    def __init__(self, capacity: int = 10_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones),
        )

    def __len__(self):
        return len(self.buffer)
class DQNAgent:
    def __init__(self, state_dim: int, n_actions: int,
                 lr: float = 1e-4, gamma: float = 0.99,
                 epsilon_start: float = 1.0, epsilon_end: float = 0.01,
                 target_update_freq: int = 1000):
        self.n_actions = n_actions
        self.gamma = gamma
        self.target_update_freq = target_update_freq
        self.steps = 0

        self.q_net = DQNNetwork(state_dim, n_actions)
        self.target_net = DQNNetwork(state_dim, n_actions)  # TARGET NETWORK
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()  # never trained directly; only refreshed by copying online weights

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(capacity=50_000)

        # Linear epsilon decay
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay_steps = 10_000

    @property
    def epsilon(self) -> float:
        progress = min(self.steps / self.epsilon_decay_steps, 1.0)
        return self.epsilon_start + progress * (self.epsilon_end - self.epsilon_start)

    def act(self, state: np.ndarray) -> int:
        if random.random() < self.epsilon:
            return random.randrange(self.n_actions)  # explore
        with torch.no_grad():
            q_vals = self.q_net(torch.FloatTensor(state).unsqueeze(0))
        return q_vals.argmax().item()  # exploit

    def train_step(self, batch_size: int = 64) -> float:
        if len(self.buffer) < batch_size:
            return 0.0
        states, actions, rewards, next_states, dones = self.buffer.sample(batch_size)

        # Q(s, a) from online network
        q_values = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Target: r + gamma * max_a' Q_target(s', a'), zeroed at terminal states
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1).values
            targets = rewards + self.gamma * next_q * (1 - dones)

        loss = nn.HuberLoss()(q_values, targets)  # Huber loss is more stable than MSE for RL
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 10.0)  # gradient clipping
        self.optimizer.step()

        self.steps += 1
        # Periodically copy online net weights to target net
        if self.steps % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())
        return loss.item()
# Training on CartPole (balance a pole on a moving cart)
import gym  # requires gym >= 0.26 (or gymnasium), whose step() returns 5 values

env = gym.make("CartPole-v1")  # state: 4 values, actions: 2 (left/right)
agent = DQNAgent(state_dim=4, n_actions=2)

for ep in range(500):
    state, _ = env.reset()
    total_reward = 0
    for _ in range(500):  # max 500 steps per episode
        action = agent.act(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.buffer.push(state, action, reward, next_state, float(done))
        agent.train_step()
        state = next_state
        total_reward += reward
        if done:
            break
    if (ep + 1) % 100 == 0:
        print(f"Episode {ep+1} | Reward: {total_reward:.0f} | Epsilon: {agent.epsilon:.3f}")
# Expected: reward of 500/500 (perfect balance) after roughly 300 episodes
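Once training finishes, it helps to measure the learned policy without exploration noise. Below is a minimal evaluation sketch; it assumes the `agent` and `env` objects from the script above, and `evaluate` is just an illustrative helper name, not part of the agent's API. It acts greedily by querying the online Q-network directly, bypassing epsilon-greedy exploration.

def evaluate(agent: DQNAgent, env, episodes: int = 10) -> float:
    '''Average return of the greedy (argmax-Q) policy over a few episodes.'''
    total = 0.0
    for _ in range(episodes):
        state, _ = env.reset()
        done = False
        while not done:
            with torch.no_grad():
                q_vals = agent.q_net(torch.FloatTensor(state).unsqueeze(0))
            state, reward, terminated, truncated, _ = env.step(q_vals.argmax().item())
            total += reward
            done = terminated or truncated
    return total / episodes

print(f"Greedy policy average return: {evaluate(agent, env):.1f}")  # ~500 for a solved CartPole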
Tip
Practice DQN in small, isolated examples (CartPole is ideal for this) before integrating it into larger projects. Breaking concepts into small experiments builds genuine understanding faster than reading alone.
Practice Task
(1) Write a working DQN agent from scratch without looking at notes. (2) Modify it to handle an edge case (for example, an empty replay buffer or a malformed observation). (3) Share your solution in the Priygop community for feedback.
Common Mistake
A common mistake with DQN is skipping edge-case testing: training before the replay buffer holds enough samples, mishandling the terminal-state flag in the target, or passing observations with the wrong shape or dtype. Always validate these boundary conditions to write robust, production-ready RL code.