Implementing Deep Q-Learning (DQN) from Scratch with RLax, JAX, Haiku, and Optax to Train a CartPole Reinforcement Learning Agent

In this tutorial, we use RLax, a research-oriented library developed by Google DeepMind for building reinforcement learning algorithms with JAX. We combine RLax with JAX, Haiku, and Optax to build a Deep Q-Learning (DQN) agent that learns to solve the CartPole environment. Instead of relying on a fully packaged RL framework, we assemble our own training pipeline so we can clearly understand how the main components of reinforcement learning work together. We define a neural network, build a replay buffer, compute temporal-difference errors with RLax, and train the agent using gradient-based optimization. Throughout, we focus on how RLax provides RL primitives that can be composed into custom reinforcement learning pipelines. We use JAX for efficient computation, Haiku for neural network modeling, and Optax for optimization.
!pip -q install "jax[cpu]" dm-haiku optax rlax gymnasium matplotlib numpy
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import random
import time
from dataclasses import dataclass
from collections import deque
import gymnasium as gym
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import rlax
seed = 42
random.seed(seed)
np.random.seed(seed)
env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")
obs_dim = env.observation_space.shape[0]
num_actions = env.action_space.n
def q_network(x):
    mlp = hk.Sequential([
        hk.Linear(128), jax.nn.relu,
        hk.Linear(128), jax.nn.relu,
        hk.Linear(num_actions),
    ])
    return mlp(x)
q_net = hk.without_apply_rng(hk.transform(q_network))
dummy_obs = jnp.zeros((1, obs_dim), dtype=jnp.float32)
rng = jax.random.PRNGKey(seed)
params = q_net.init(rng, dummy_obs)
target_params = params
optimizer = optax.chain(
optax.clip_by_global_norm(10.0),
optax.adam(3e-4),
)
opt_state = optimizer.init(params)

We install the necessary libraries and import all the modules needed to implement the reinforcement learning pipeline. We initialize the environment, define the network architecture with Haiku, and configure a Q-network that predicts action values. We also initialize the online and target network parameters, along with the optimizer that will be used during training.
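To make the network's role concrete, here is a plain-NumPy sketch of the forward pass the Haiku MLP performs. The weights below are random stand-ins, not the trained parameters; only the shapes and layer structure mirror the tutorial's network.

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, hidden, num_actions = 4, 128, 2  # CartPole sizes

# Hypothetical weights standing in for the Haiku-initialized parameters.
W1 = rng.normal(scale=0.1, size=(obs_dim, hidden))
b1 = np.zeros(hidden)
W2 = rng.normal(scale=0.1, size=(hidden, hidden))
b2 = np.zeros(hidden)
W3 = rng.normal(scale=0.1, size=(hidden, num_actions))
b3 = np.zeros(num_actions)

def q_forward(x):
    h = np.maximum(x @ W1 + b1, 0.0)  # Linear(128) + ReLU
    h = np.maximum(h @ W2 + b2, 0.0)  # Linear(128) + ReLU
    return h @ W3 + b3                # Linear(num_actions), no activation

obs = np.zeros((1, obs_dim), dtype=np.float32)
print(q_forward(obs).shape)  # (1, 2): one Q-value per action
```

The final layer deliberately has no activation: Q-values are unbounded regression targets, so the network outputs raw scores, one per action.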
@dataclass
class Transition:
    obs: np.ndarray
    action: int
    reward: float
    discount: float
    next_obs: np.ndarray
    done: float

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, *args):
        self.buffer.append(Transition(*args))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        obs = np.stack([t.obs for t in batch]).astype(np.float32)
        action = np.array([t.action for t in batch], dtype=np.int32)
        reward = np.array([t.reward for t in batch], dtype=np.float32)
        discount = np.array([t.discount for t in batch], dtype=np.float32)
        next_obs = np.stack([t.next_obs for t in batch]).astype(np.float32)
        done = np.array([t.done for t in batch], dtype=np.float32)
        return {
            "obs": obs,
            "action": action,
            "reward": reward,
            "discount": discount,
            "next_obs": next_obs,
            "done": done,
        }

    def __len__(self):
        return len(self.buffer)
replay = ReplayBuffer(capacity=50000)
def epsilon_by_frame(frame_idx, eps_start=1.0, eps_end=0.05, decay_frames=20000):
    mix = min(frame_idx / decay_frames, 1.0)
    return eps_start + mix * (eps_end - eps_start)

def select_action(params, obs, epsilon):
    if random.random() < epsilon:
        return env.action_space.sample()
    q_values = q_net.apply(params, obs[None, :])
    return int(jnp.argmax(q_values[0]))

We define a transition structure and a replay buffer that stores past experiences from the environment, with methods to add transitions and to sample mini-batches used later to train the agent. We also implement the epsilon-greedy exploration strategy, which linearly anneals epsilon from 1.0 to 0.05 over the first 20,000 frames.
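To see what the schedule actually produces, we can trace the same `epsilon_by_frame` function (copied from above) at a few frame indices:

```python
def epsilon_by_frame(frame_idx, eps_start=1.0, eps_end=0.05, decay_frames=20000):
    # Linear interpolation from eps_start to eps_end, clamped after decay_frames.
    mix = min(frame_idx / decay_frames, 1.0)
    return eps_start + mix * (eps_end - eps_start)

print(epsilon_by_frame(0))      # 1.0, fully random at the start
print(epsilon_by_frame(10000))  # ≈ 0.525, halfway through the decay
print(epsilon_by_frame(40000))  # ≈ 0.05, clamped at eps_end
```

Early on the agent explores almost uniformly at random; by frame 20,000 it acts greedily 95% of the time, which lets the Q-network's own estimates drive data collection once they become informative.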
@jax.jit
def soft_update(target_params, online_params, tau):
    return jax.tree_util.tree_map(lambda t, s: (1.0 - tau) * t + tau * s, target_params, online_params)

def batch_td_errors(params, target_params, batch):
    q_tm1 = q_net.apply(params, batch["obs"])
    q_t = q_net.apply(target_params, batch["next_obs"])
    td_errors = jax.vmap(
        lambda q1, a, r, d, q2: rlax.q_learning(q1, a, r, d, q2)
    )(q_tm1, batch["action"], batch["reward"], batch["discount"], q_t)
    return td_errors
@jax.jit
def train_step(params, target_params, opt_state, batch):
    def loss_fn(p):
        td_errors = batch_td_errors(p, target_params, batch)
        loss = jnp.mean(rlax.huber_loss(td_errors, delta=1.0))
        metrics = {
            "loss": loss,
            "td_abs_mean": jnp.mean(jnp.abs(td_errors)),
            "q_mean": jnp.mean(q_net.apply(p, batch["obs"])),
        }
        return loss, metrics

    (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, metrics

We define the core learning computations used during training. We calculate the temporal-difference errors using RLax's Q-learning primitive and turn them into a loss with the Huber loss function. We then implement a training step that computes the gradients, applies the optimizer's updates, and returns the training metrics.
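For a single transition, `rlax.q_learning` returns the TD error r + γ·max_a Q_target(s′, a) − Q(s, a). A minimal NumPy re-derivation for one toy transition makes the quantity explicit (the Huber loss then simply squashes large errors before averaging):

```python
import numpy as np

def q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t):
    # Same quantity rlax.q_learning computes for one transition:
    # target = r + discount * max_a Q_target(s', a); error = target - Q(s, a)
    target = r_t + discount_t * np.max(q_t)
    return target - q_tm1[a_tm1]

q_tm1 = np.array([1.0, 2.0])  # online Q-values at s
q_t = np.array([0.5, 3.0])    # target-network Q-values at s'
td = q_learning_td_error(q_tm1, a_tm1=1, r_t=1.0, discount_t=0.99, q_t=q_t)
print(td)  # 1.0 + 0.99 * 3.0 - 2.0 ≈ 1.97
```

Note that when a transition is terminal, the stored discount is 0.0, so the target collapses to the reward alone, exactly as the replay buffer's `discount` field encodes.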
def evaluate_agent(params, episodes=5):
    returns = []
    for ep in range(episodes):
        obs, _ = eval_env.reset(seed=seed + 1000 + ep)
        done = False
        truncated = False
        total_reward = 0.0
        while not (done or truncated):
            q_values = q_net.apply(params, obs[None, :])
            action = int(jnp.argmax(q_values[0]))
            next_obs, reward, done, truncated, _ = eval_env.step(action)
            total_reward += reward
            obs = next_obs
        returns.append(total_reward)
    return float(np.mean(returns))

num_frames = 40000
batch_size = 128
warmup_steps = 1000
train_every = 4
eval_every = 2000
gamma = 0.99
tau = 0.01
max_grad_updates_per_step = 1

obs, _ = env.reset(seed=seed)
episode_return = 0.0
episode_returns = []
eval_returns = []
losses = []
td_means = []
q_means = []
eval_steps = []
start_time = time.time()

We define an evaluation function that measures the agent's performance with greedy action selection. We set the training hyperparameters, including the number of frames, the batch size, the discount factor, and the target-network update rate. We also initialize containers for tracking training statistics, including episode returns, losses, and evaluation metrics.
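The target-network update rate tau = 0.01 means the target parameters are a slowly moving (Polyak) average of the online parameters rather than a periodic hard copy. A NumPy sketch of the same update applied to toy scalar "parameters" shows how quickly the target catches up:

```python
import numpy as np

def soft_update(target, online, tau):
    # Polyak averaging: move the target a small step toward the online params.
    return (1.0 - tau) * target + tau * online

target = np.array([0.0, 0.0])
online = np.array([1.0, 1.0])
for _ in range(100):
    target = soft_update(target, online, tau=0.01)
print(target)  # ≈ [0.634 0.634], i.e. 1 - 0.99**100 after 100 updates
```

After n updates the target has closed the fraction 1 − (1 − tau)ⁿ of the gap, so with tau = 0.01 it takes a few hundred gradient steps for the target to substantially track the online network, which keeps the bootstrapped TD targets stable.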
for frame_idx in range(1, num_frames + 1):
    epsilon = epsilon_by_frame(frame_idx)
    action = select_action(params, obs.astype(np.float32), epsilon)
    next_obs, reward, done, truncated, _ = env.step(action)
    terminal = done or truncated
    discount = 0.0 if terminal else gamma
    replay.add(
        obs.astype(np.float32),
        action,
        float(reward),
        float(discount),
        next_obs.astype(np.float32),
        float(terminal),
    )
    obs = next_obs
    episode_return += reward
    if terminal:
        episode_returns.append(episode_return)
        obs, _ = env.reset()
        episode_return = 0.0
    if len(replay) >= warmup_steps and frame_idx % train_every == 0:
        for _ in range(max_grad_updates_per_step):
            batch_np = replay.sample(batch_size)
            batch = {k: jnp.asarray(v) for k, v in batch_np.items()}
            params, opt_state, metrics = train_step(params, target_params, opt_state, batch)
            target_params = soft_update(target_params, params, tau)
            losses.append(float(metrics["loss"]))
            td_means.append(float(metrics["td_abs_mean"]))
            q_means.append(float(metrics["q_mean"]))
    if frame_idx % eval_every == 0:
        avg_eval_return = evaluate_agent(params, episodes=5)
        eval_returns.append(avg_eval_return)
        eval_steps.append(frame_idx)
        recent_train = np.mean(episode_returns[-10:]) if episode_returns else 0.0
        recent_loss = np.mean(losses[-100:]) if losses else 0.0
        print(
            f"step={frame_idx:6d} | epsilon={epsilon:.3f} | "
            f"recent_train_return={recent_train:7.2f} | "
            f"eval_return={avg_eval_return:7.2f} | "
            f"recent_loss={recent_loss:.5f} | buffer={len(replay)}"
        )
elapsed = time.time() - start_time
final_eval = evaluate_agent(params, episodes=10)
print("\nTraining complete")
print(f"Elapsed time: {elapsed:.1f} seconds")
print(f"Final 10-episode evaluation return: {final_eval:.2f}")
plt.figure(figsize=(14, 4))
plt.subplot(1, 3, 1)
plt.plot(episode_returns)
plt.title("Training Episode Returns")
plt.xlabel("Episode")
plt.ylabel("Return")
plt.subplot(1, 3, 2)
plt.plot(eval_steps, eval_returns)
plt.title("Evaluation Returns")
plt.xlabel("Environment Steps")
plt.ylabel("Avg Return")
plt.subplot(1, 3, 3)
plt.plot(losses, label="Loss")
plt.plot(td_means, label="|TD Error| Mean")
plt.title("Optimization Metrics")
plt.xlabel("Gradient Updates")
plt.legend()
plt.tight_layout()
plt.show()
obs, _ = eval_env.reset(seed=999)
frames = []
done = False
truncated = False
total_reward = 0.0
render_env = gym.make("CartPole-v1", render_mode="rgb_array")
obs, _ = render_env.reset(seed=999)
while not (done or truncated):
    frame = render_env.render()
    frames.append(frame)
    q_values = q_net.apply(params, obs[None, :])
    action = int(jnp.argmax(q_values[0]))
    obs, reward, done, truncated, _ = render_env.step(action)
    total_reward += reward
render_env.close()
print(f"Demo episode return: {total_reward:.2f}")
try:
    import matplotlib.animation as animation
    from IPython.display import HTML, display

    fig = plt.figure(figsize=(6, 4))
    patch = plt.imshow(frames[0])
    plt.axis("off")

    def animate(i):
        patch.set_data(frames[i])
        return (patch,)

    anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=30, blit=True)
    display(HTML(anim.to_jshtml()))
    plt.close(fig)
except Exception as e:
    print("Animation display skipped:", e)

We run the full reinforcement learning training loop: we periodically update the network parameters, evaluate the agent's performance, and record metrics for visibility. We then plot the training results and render a demonstration episode to observe how the trained agent behaves.
In conclusion, we have built a complete Deep Q-Learning agent by integrating RLax with the JAX-based machine learning ecosystem. We built a neural network to estimate action values, used experience replay to stabilize learning, and calculated TD errors using RLax's Q-learning primitive. During training, we updated the network parameters using gradient-based optimization and periodically evaluated the agent to track performance improvements. We also saw how RLax enables a modular approach to reinforcement learning by providing reusable algorithmic components instead of full algorithms. This flexibility allows us to easily explore different architectures, learning rules, and optimization strategies. By expanding this base, we can build more advanced agents, such as Double DQN, distributional reinforcement learning agents, and actor-critic methods, using the same RLax primitives.
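As one example of such an extension, Double DQN decouples action selection from action evaluation to reduce Q-value overestimation: the online network picks the next action, but the target network scores it. A minimal NumPy sketch of the TD error it uses (RLax ships this recipe as `rlax.double_q_learning`):

```python
import numpy as np

def double_q_td_error(q_tm1, a_tm1, r_t, discount_t, q_t_target, q_t_online):
    # Online network selects the next action; target network evaluates it.
    a_star = np.argmax(q_t_online)
    target = r_t + discount_t * q_t_target[a_star]
    return target - q_tm1[a_tm1]

q_tm1 = np.array([1.0, 2.0])       # online Q-values at s
q_t_online = np.array([4.0, 1.0])  # online net prefers action 0 at s'
q_t_target = np.array([0.5, 3.0])  # target net scores action 0 as only 0.5
td = double_q_td_error(q_tm1, a_tm1=1, r_t=1.0, discount_t=0.99,
                       q_t_target=q_t_target, q_t_online=q_t_online)
print(td)  # 1.0 + 0.99 * 0.5 - 2.0 ≈ -0.505
```

Compare this with plain Q-learning on the same transition, which would bootstrap from max(q_t_target) = 3.0: decoupling selection from evaluation prevents a single over-optimistic estimate from inflating both steps at once.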



