
A Coding Implementation to Train Agents for Safety-Critical Offline Reinforcement Learning Using Conservative Q-Learning with d3rlpy and Static Historical Data

In this tutorial, we build a safety-aware offline reinforcement learning pipeline that learns entirely from static, pre-collected data rather than live exploration. We design a custom environment, generate a behavior dataset from a constrained policy, and train both a baseline Behavior Cloning agent and a Conservative Q-Learning agent using d3rlpy. Through a workflow of offline dataset construction, careful analysis, and conservative learning objectives, we show how robust decision-making policies can be trained in settings where unsafe exploration is not an option. Check out the FULL CODES here.

!pip -q install -U "d3rlpy" "gymnasium" "numpy" "torch" "matplotlib" "scikit-learn"


import os
import time
import random
import inspect
import numpy as np
import matplotlib.pyplot as plt


import gymnasium as gym
from gymnasium import spaces


import torch
import d3rlpy




SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)




def pick_device():
   if torch.cuda.is_available():
       return "cuda:0"
   return "cpu"




DEVICE = pick_device()
print("d3rlpy:", getattr(d3rlpy, "__version__", "unknown"), "| torch:", torch.__version__, "| device:", DEVICE)




def make_config(cls, **kwargs):
   sig = inspect.signature(cls.__init__)
   allowed = set(sig.parameters.keys())
   allowed.discard("self")
   filtered = {k: v for k, v in kwargs.items() if k in allowed}
   return cls(**filtered)

We set up the environment by installing dependencies, importing libraries, and seeding the random number generators for reproducibility. We detect and configure the compute device so the code behaves consistently across systems. We also define a utility that builds configuration objects safely across d3rlpy versions by filtering out unsupported keyword arguments. Check out the FULL CODES here.
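
To make the version-tolerant configuration helper concrete, here is a minimal sketch that uses a hypothetical ToyConfig class (not part of d3rlpy) to show how make_config silently drops keyword arguments the target class does not accept.

class ToyConfig:
    # Hypothetical stand-in for a d3rlpy config class, used only to illustrate make_config.
    def __init__(self, learning_rate=1e-3, batch_size=32):
        self.learning_rate = learning_rate
        self.batch_size = batch_size


# "unknown_flag" is not in ToyConfig's signature, so make_config filters it out.
toy_cfg = make_config(ToyConfig, learning_rate=3e-4, batch_size=256, unknown_flag=True)
print(toy_cfg.learning_rate, toy_cfg.batch_size)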

class SafetyCriticalGridWorld(gym.Env):
   metadata = {"render_modes": []}


   def __init__(
       self,
       size=15,
       max_steps=80,
       hazard_coords=None,
       start=(0, 0),
       goal=None,
       slip_prob=0.05,
       seed=0,
   ):
       super().__init__()
       self.size = int(size)
       self.max_steps = int(max_steps)
       self.start = tuple(start)
       self.goal = tuple(goal) if goal is not None else (self.size - 1, self.size - 1)
       self.slip_prob = float(slip_prob)


       if hazard_coords is None:
           hz = set()
           rng = np.random.default_rng(seed)
           for _ in range(max(1, self.size // 2)):
               x = rng.integers(2, self.size - 2)
               y = rng.integers(2, self.size - 2)
               hz.add((int(x), int(y)))
           self.hazards = hz
       else:
           self.hazards = set(tuple(x) for x in hazard_coords)


       self.action_space = spaces.Discrete(4)
       self.observation_space = spaces.Box(low=0.0, high=float(self.size - 1), shape=(2,), dtype=np.float32)


       self._rng = np.random.default_rng(seed)
       self._pos = None
       self._t = 0


   def reset(self, *, seed=None, options=None):
       if seed is not None:
           self._rng = np.random.default_rng(seed)
       self._pos = [int(self.start[0]), int(self.start[1])]
       self._t = 0
       obs = np.array(self._pos, dtype=np.float32)
       return obs, {}


   def _clip(self):
       self._pos[0] = int(np.clip(self._pos[0], 0, self.size - 1))
       self._pos[1] = int(np.clip(self._pos[1], 0, self.size - 1))


   def step(self, action):
       self._t += 1


       a = int(action)
       if self._rng.random() < self.slip_prob:
           a = int(self._rng.integers(0, 4))


       if a == 0:
           self._pos[1] += 1
       elif a == 1:
           self._pos[0] += 1
       elif a == 2:
           self._pos[1] -= 1
       elif a == 3:
           self._pos[0] -= 1


       self._clip()


       x, y = int(self._pos[0]), int(self._pos[1])
       terminated = False
       truncated = self._t >= self.max_steps


       reward = -1.0


       if (x, y) in self.hazards:
           reward = -100.0
           terminated = True


       if (x, y) == self.goal:
           reward = +50.0
           terminated = True


       obs = np.array([x, y], dtype=np.float32)
       return obs, float(reward), terminated, truncated, {}

We define a safety-critical GridWorld environment with hazards, terminal states, and stochastic (slippery) transitions. We assign a large penalty for entering unsafe regions and a bonus reward for reaching the goal. We keep the environment deliberately constrained so it mirrors real-world safety-critical settings where mistakes are costly. Check out the FULL CODES here.
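
As a quick sanity check (a sketch that is not part of the main pipeline), we can run a short random rollout to confirm that step penalties, hazard penalties, and termination behave as described.

env_check = SafetyCriticalGridWorld(size=15, max_steps=80, slip_prob=0.05, seed=SEED)
obs, _ = env_check.reset(seed=SEED)
partial_return = 0.0
for _ in range(10):
    # sample a random action; rewards should be -1 per step unless a hazard or the goal is reached
    obs, r, terminated, truncated, _ = env_check.step(env_check.action_space.sample())
    partial_return += r
    if terminated or truncated:
        break
print("random rollout partial return:", partial_return)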

def safe_behavior_policy(obs, env: SafetyCriticalGridWorld, epsilon=0.15):
   x, y = int(obs[0]), int(obs[1])
   gx, gy = env.goal


   preferred = []
   if gx > x:
       preferred.append(1)
   elif gx < x:
       preferred.append(3)
   if gy > y:
       preferred.append(0)
   elif gy < y:
       preferred.append(2)


   if len(preferred) == 0:
       preferred = [int(env._rng.integers(0, 4))]


   if env._rng.random() < epsilon:
       return int(env._rng.integers(0, 4))


   candidates = []
   for a in preferred:
       nx, ny = x, y
       if a == 0:
           ny += 1
       elif a == 1:
           nx += 1
       elif a == 2:
           ny -= 1
       elif a == 3:
           nx -= 1
       nx = int(np.clip(nx, 0, env.size - 1))
       ny = int(np.clip(ny, 0, env.size - 1))
       if (nx, ny) not in env.hazards:
           candidates.append(a)


   if len(candidates) == 0:
       return preferred[0]
   return int(random.choice(candidates))




def generate_offline_episodes(env, n_episodes=400, epsilon=0.20, seed=0):
   episodes = []
   for i in range(n_episodes):
       obs, _ = env.reset(seed=int(seed + i))
       obs_list = []
       act_list = []
       rew_list = []
       done_list = []


       done = False
       while not done:
           a = safe_behavior_policy(obs, env, epsilon=epsilon)
           nxt, r, terminated, truncated, _ = env.step(a)
           done = bool(terminated or truncated)


           obs_list.append(np.array(obs, dtype=np.float32))
           act_list.append(np.array([a], dtype=np.int64))
           rew_list.append(np.array([r], dtype=np.float32))
           done_list.append(np.array([1.0 if done else 0.0], dtype=np.float32))


           obs = nxt


       episodes.append(
           {
               "observations": np.stack(obs_list, axis=0),
               "actions": np.stack(act_list, axis=0),
               "rewards": np.stack(rew_list, axis=0),
               "terminals": np.stack(done_list, axis=0),
           }
       )
   return episodes




def build_mdpdataset(episodes):
   obs = np.concatenate([ep["observations"] for ep in episodes], axis=0).astype(np.float32)
   acts = np.concatenate([ep["actions"] for ep in episodes], axis=0).astype(np.int64)
   rews = np.concatenate([ep["rewards"] for ep in episodes], axis=0).astype(np.float32)
   terms = np.concatenate([ep["terminals"] for ep in episodes], axis=0).astype(np.float32)


   if hasattr(d3rlpy, "dataset") and hasattr(d3rlpy.dataset, "MDPDataset"):
       return d3rlpy.dataset.MDPDataset(observations=obs, actions=acts, rewards=rews, terminals=terms)


   raise RuntimeError("d3rlpy.dataset.MDPDataset not found. Upgrade d3rlpy.")

We define a conservative behavior policy that generates offline data without unsafe exploration. We roll out this policy to collect trajectories and organize them into episodes. We then convert these episodes into an MDPDataset compatible with d3rlpy's offline learning API. Check out the FULL CODES here.
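
Before training, it can help to check how safe the behavior data actually is. The helper below is an illustrative sketch (not part of the original code) that counts how many episodes end with the hazard penalty versus the goal bonus, using the episode dictionaries produced by generate_offline_episodes.

def summarize_episodes(episodes):
    # terminal rewards of -100 indicate a hazard hit, +50 indicates the goal
    hazard_endings = sum(1 for ep in episodes if ep["rewards"][-1, 0] <= -100.0)
    goal_endings = sum(1 for ep in episodes if ep["rewards"][-1, 0] >= 50.0)
    n_transitions = sum(len(ep["rewards"]) for ep in episodes)
    return {
        "episodes": len(episodes),
        "transitions": n_transitions,
        "hazard_endings": hazard_endings,
        "goal_endings": goal_endings,
    }


# Example usage on a small sample:
# env_tmp = SafetyCriticalGridWorld(seed=SEED)
# print(summarize_episodes(generate_offline_episodes(env_tmp, n_episodes=50, seed=SEED)))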

def _get_episodes_from_dataset(dataset):
   if hasattr(dataset, "episodes") and dataset.episodes is not None:
       return dataset.episodes
   if hasattr(dataset, "get_episodes"):
       return dataset.get_episodes()
   raise AttributeError("Could not find episodes in dataset (d3rlpy version mismatch).")




def _iter_all_observations(dataset):
   for ep in _get_episodes_from_dataset(dataset):
       obs = getattr(ep, "observations", None)
       if obs is None:
           continue
       for o in obs:
           yield o




def _iter_all_transitions(dataset):
   for ep in _get_episodes_from_dataset(dataset):
       obs = getattr(ep, "observations", None)
       acts = getattr(ep, "actions", None)
       rews = getattr(ep, "rewards", None)
       if obs is None or acts is None:
           continue
       n = min(len(obs), len(acts))
       for i in range(n):
           o = obs[i]
           a = acts[i]
           r = rews[i] if rews is not None and i < len(rews) else None
           yield o, a, r




def visualize_dataset(dataset, env, title="Offline Dataset"):
   state_visits = np.zeros((env.size, env.size), dtype=np.float32)
   for obs in _iter_all_observations(dataset):
       x, y = int(obs[0]), int(obs[1])
       x = int(np.clip(x, 0, env.size - 1))
       y = int(np.clip(y, 0, env.size - 1))
       state_visits[y, x] += 1


   plt.figure(figsize=(6, 5))
   plt.imshow(state_visits, origin="lower")
   plt.colorbar(label="Visits")
   plt.scatter([env.start[0]], [env.start[1]], marker="o", label="start")
   plt.scatter([env.goal[0]], [env.goal[1]], marker="*", label="goal")
   if len(env.hazards) > 0:
       hz = np.array(list(env.hazards), dtype=np.int32)
       plt.scatter(hz[:, 0], hz[:, 1], marker="x", label="hazards")
   plt.title(f"{title} — State visitation")
   plt.xlabel("x")
   plt.ylabel("y")
   plt.legend()
   plt.show()


   rewards = []
   for _, _, r in _iter_all_transitions(dataset):
       if r is not None:
           rewards.append(float(r))
   if len(rewards) > 0:
       plt.figure(figsize=(6, 4))
       plt.hist(rewards, bins=60)
       plt.title(f"{title} — Reward distribution")
       plt.xlabel("reward")
       plt.ylabel("count")
       plt.show()

We add dataset helpers that iterate over episodes rather than flat arrays, so the code works across d3rlpy versions. We visualize state visitation to understand coverage and bias in the offline dataset. We also inspect the reward distribution to gauge the learning signal available to the agent. Check out the FULL CODES here.
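
Coverage can also be summarized numerically. The helper below is a small illustrative addition (not in the original tutorial) that reports the fraction of grid cells appearing at least once in the offline dataset, a rough proxy for how much of the state space the policies can learn about.

def coverage_fraction(dataset, env):
    # collect every distinct (x, y) cell visited anywhere in the dataset
    visited = set()
    for obs in _iter_all_observations(dataset):
        visited.add((int(obs[0]), int(obs[1])))
    return len(visited) / float(env.size * env.size)


# Example usage after building the dataset:
# print("state coverage:", coverage_fraction(dataset, env))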

def rollout_eval(env, algo, n_episodes=25, seed=0):
   returns = []
   lengths = []
   hazard_hits = 0
   goal_hits = 0


   for i in range(n_episodes):
       obs, _ = env.reset(seed=seed + i)
       done = False
       total = 0.0
       steps = 0
       while not done:
           a = int(algo.predict(np.asarray(obs, dtype=np.float32)[None, ...])[0])
           obs, r, terminated, truncated, _ = env.step(a)
           total += float(r)
           steps += 1
           done = bool(terminated or truncated)
           if terminated:
               x, y = int(obs[0]), int(obs[1])
               if (x, y) in env.hazards:
                   hazard_hits += 1
               if (x, y) == env.goal:
                   goal_hits += 1


       returns.append(total)
       lengths.append(steps)


   return {
       "return_mean": float(np.mean(returns)),
       "return_std": float(np.std(returns)),
       "len_mean": float(np.mean(lengths)),
       "hazard_rate": float(hazard_hits / max(1, n_episodes)),
       "goal_rate": float(goal_hits / max(1, n_episodes)),
       "returns": np.asarray(returns, dtype=np.float32),
   }




def action_mismatch_rate_vs_data(dataset, algo, sample_obs=7000, seed=0):
   rng = np.random.default_rng(seed)
   obs_all = []
   act_all = []
   for o, a, _ in _iter_all_transitions(dataset):
       obs_all.append(np.asarray(o, dtype=np.float32))
       act_all.append(int(np.asarray(a).reshape(-1)[0]))
       if len(obs_all) >= 80_000:
           break


   obs_all = np.stack(obs_all, axis=0)
   act_all = np.asarray(act_all, dtype=np.int64)


   idx = rng.choice(len(obs_all), size=min(sample_obs, len(obs_all)), replace=False)
   obs_probe = obs_all[idx]
   act_probe_data = act_all[idx]
   act_probe_pi = algo.predict(obs_probe).astype(np.int64)


   mismatch = (act_probe_pi != act_probe_data).astype(np.float32)
   return float(mismatch.mean())




def create_discrete_bc(device):
   if hasattr(d3rlpy.algos, "DiscreteBCConfig"):
       cls = d3rlpy.algos.DiscreteBCConfig
       cfg = make_config(
           cls,
           learning_rate=3e-4,
           batch_size=256,
       )
       return cfg.create(device=device)
   if hasattr(d3rlpy.algos, "DiscreteBC"):
       return d3rlpy.algos.DiscreteBC()
   raise RuntimeError("DiscreteBC not available in this d3rlpy version.")




def create_discrete_cql(device, conservative_weight=6.0):
   if hasattr(d3rlpy.algos, "DiscreteCQLConfig"):
       cls = d3rlpy.algos.DiscreteCQLConfig
       cfg = make_config(
           cls,
           learning_rate=3e-4,
           actor_learning_rate=3e-4,
           critic_learning_rate=3e-4,
           temp_learning_rate=3e-4,
           alpha_learning_rate=3e-4,
           batch_size=256,
           conservative_weight=float(conservative_weight),
           # some d3rlpy versions expose the conservative regularization term as "alpha";
           # make_config drops whichever keyword the installed version does not accept
           alpha=float(conservative_weight),
           n_action_samples=10,
           rollout_interval=0,
       )
       return cfg.create(device=device)
   if hasattr(d3rlpy.algos, "DiscreteCQL"):
       algo = d3rlpy.algos.DiscreteCQL()
       if hasattr(algo, "conservative_weight"):
           try:
               algo.conservative_weight = float(conservative_weight)
           except Exception:
               pass
       return algo
   raise RuntimeError("DiscreteCQL not available in this d3rlpy version.")

We define a controlled rollout evaluation to measure policy quality with only a small number of online episodes. We compute returns and safety metrics, including hazard-hit and goal-reach rates. We also add a mismatch diagnostic that measures how often the learned policy's actions deviate from the dataset's actions at the same states. Check out the FULL CODES here.
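
If a single number is needed for model selection, the return and safety metrics can be folded into one score. The function below is a hedged example; the hazard weighting is an arbitrary illustrative choice, not something prescribed by this tutorial.

def safety_weighted_score(metrics, hazard_penalty=100.0):
    # penalize policies that terminate in hazards, on top of their average return
    return metrics["return_mean"] - hazard_penalty * metrics["hazard_rate"]


# Example usage after rollout_eval:
# print("BC :", safety_weighted_score(bc_metrics))
# print("CQL:", safety_weighted_score(cql_metrics))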

def main():
   env = SafetyCriticalGridWorld(
       size=15,
       max_steps=80,
       slip_prob=0.05,
       seed=SEED,
   )


   raw_eps = generate_offline_episodes(env, n_episodes=500, epsilon=0.22, seed=SEED)
   dataset = build_mdpdataset(raw_eps)


   print("dataset built:", type(dataset).__name__)
   visualize_dataset(dataset, env, title="Behavior Dataset (Offline)")


   bc = create_discrete_bc(DEVICE)
   cql = create_discrete_cql(DEVICE, conservative_weight=6.0)


   print("nTraining Discrete BC (offline)...")
   t0 = time.time()
   bc.fit(
       dataset,
       n_steps=25_000,
       n_steps_per_epoch=2_500,
       experiment_name="grid_bc_offline",
   )
   print("BC train sec:", round(time.time() - t0, 2))


   print("nTraining Discrete CQL (offline)...")
   t0 = time.time()
   cql.fit(
       dataset,
       n_steps=80_000,
       n_steps_per_epoch=8_000,
       experiment_name="grid_cql_offline",
   )
   print("CQL train sec:", round(time.time() - t0, 2))


   print("nControlled online evaluation (small number of rollouts):")
   bc_metrics = rollout_eval(env, bc, n_episodes=30, seed=SEED + 1000)
   cql_metrics = rollout_eval(env, cql, n_episodes=30, seed=SEED + 2000)


   print("BC :", {k: v for k, v in bc_metrics.items() if k != "returns"})
   print("CQL:", {k: v for k, v in cql_metrics.items() if k != "returns"})


   print("nOOD-ish diagnostic (policy action mismatch vs data action at same states):")
   bc_mismatch = action_mismatch_rate_vs_data(dataset, bc, sample_obs=7000, seed=SEED + 1)
   cql_mismatch = action_mismatch_rate_vs_data(dataset, cql, sample_obs=7000, seed=SEED + 2)
   print("BC mismatch rate :", bc_mismatch)
   print("CQL mismatch rate:", cql_mismatch)


   plt.figure(figsize=(6, 4))
   labels = ["BC", "CQL"]
   means = [bc_metrics["return_mean"], cql_metrics["return_mean"]]
   stds = [bc_metrics["return_std"], cql_metrics["return_std"]]
   plt.bar(labels, means, yerr=stds)
   plt.ylabel("Return")
   plt.title("Online Rollout Return (Controlled)")
   plt.show()


   plt.figure(figsize=(6, 4))
   plt.plot(np.sort(bc_metrics["returns"]), label="BC")
   plt.plot(np.sort(cql_metrics["returns"]), label="CQL")
   plt.xlabel("Episode (sorted)")
   plt.ylabel("Return")
   plt.title("Return Distribution (Sorted)")
   plt.legend()
   plt.show()


   out_dir = "/content/offline_rl_artifacts"
   os.makedirs(out_dir, exist_ok=True)
   bc_path = os.path.join(out_dir, "grid_bc_policy.pt")
   cql_path = os.path.join(out_dir, "grid_cql_policy.pt")


   if hasattr(bc, "save_policy"):
       bc.save_policy(bc_path)
       print("Saved BC policy:", bc_path)
   if hasattr(cql, "save_policy"):
       cql.save_policy(cql_path)
       print("Saved CQL policy:", cql_path)


   print("nDone.")




if __name__ == "__main__":
   main()

We train both Behavior Cloning and Conservative Q-Learning agents from the offline dataset. We compare their performance using controlled rollouts and diagnostic metrics. We complete the workflow by saving the trained policies and summarizing the safety-aware learning outcomes.
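
For deployment, the exported policies can be loaded without d3rlpy. The snippet below is a minimal sketch, assuming save_policy produced a TorchScript module (which d3rlpy does for a ".pt" filename); the exact input/output convention of the scripted policy may vary between d3rlpy versions.

import numpy as np
import torch

# load the exported CQL policy saved by main()
policy = torch.jit.load("/content/offline_rl_artifacts/grid_cql_policy.pt")
start_obs = torch.tensor(np.array([[0.0, 0.0]], dtype=np.float32))
with torch.no_grad():
    action = policy(start_obs)
print("greedy action at the start state:", action)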

In conclusion, we have shown that Conservative Q-Learning produces a more reliable policy than naive imitation (Behavior Cloning) when learning from historical data in safety-sensitive environments. By comparing offline training results, a controlled online evaluation, and an action-mismatch diagnostic, we show how conservatism in value estimation helps suppress risky, out-of-distribution behavior. Overall, we have presented a complete, reproducible offline RL workflow that can be extended to complex domains such as robotics, healthcare, or finance without compromising safety.

