""" dqn_baseline.py — Dueling DQN agent for SecureAI-Guard. Usage: python dqn_baseline.py --episodes 500 --task basic_security """ import argparse import random from collections import deque, namedtuple from typing import Any, Dict, List, Tuple import numpy as np import requests import torch import torch.nn as nn import torch.optim as optim Experience = namedtuple("Experience", ["state", "action", "reward", "next_state", "done"]) class DuelingDQN(nn.Module): def __init__(self, state_size: int, action_size: int, hidden: int = 256): super().__init__() self.shared = nn.Sequential( nn.Linear(state_size, hidden), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden, hidden // 2), nn.ReLU(), ) self.value = nn.Linear(hidden // 2, 1) self.advantage = nn.Linear(hidden // 2, action_size) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.shared(x) v = self.value(x) a = self.advantage(x) return v + (a - a.mean(dim=1, keepdim=True)) class ReplayBuffer: def __init__(self, capacity: int = 10_000): self.buf: deque = deque(maxlen=capacity) def push(self, *args): self.buf.append(Experience(*args)) def sample(self, n: int) -> List[Experience]: return random.sample(self.buf, n) def __len__(self) -> int: return len(self.buf) class DQNAgent: DECISIONS = ["allow", "block", "warn", "investigate"] def __init__(self, state_size: int = 9, action_size: int = 4, lr: float = 1e-3): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.q_net = DuelingDQN(state_size, action_size).to(self.device) self.target_net = DuelingDQN(state_size, action_size).to(self.device) self.target_net.load_state_dict(self.q_net.state_dict()) self.opt = optim.Adam(self.q_net.parameters(), lr=lr) self.buf = ReplayBuffer() self.gamma = 0.99 self.eps = 1.0 self.eps_min = 0.01 self.eps_decay = 0.995 self.batch = 64 self.update_every = 100 self._step = 0 def encode(self, obs: Dict[str, Any]) -> np.ndarray: return np.array([ obs["hf_risk_score"], obs["user_trust"] / 100.0, obs["system_fatigue"] / 100.0, float(obs["channel"] == "sms"), float(obs["channel"] == "email"), float(obs["channel"] == "web"), min(len(obs["content"]) / 500.0, 1.0), len(obs.get("threat_history", [])) / 10.0, obs["timestamp"] % 86400 / 86400, ], dtype=np.float32) def act(self, obs: Dict[str, Any], train: bool = True) -> Tuple[str, float, str]: if train and random.random() < self.eps: idx = random.randint(0, 3) else: s = torch.FloatTensor(self.encode(obs)).unsqueeze(0).to(self.device) with torch.no_grad(): idx = int(self.q_net(s).argmax()) decision = self.DECISIONS[idx] return decision, 0.85, self._reason(obs, decision) def _reason(self, obs: Dict[str, Any], decision: str) -> str: r = obs["hf_risk_score"] t = obs["user_trust"] msgs = { "block": f"High risk score ({r:.2f}) warrants blocking.", "warn": f"Moderate risk ({r:.2f}); warning user to maintain trust ({t:.0f}).", "investigate": f"Ambiguous risk ({r:.2f}); flagging for investigation.", "allow": f"Low risk ({r:.2f}) and trusted context; allowing message.", } return msgs[decision] def remember(self, obs, action_str, reward_val, next_obs, done): s = self.encode(obs) ns = self.encode(next_obs) a_idx = self.DECISIONS.index(action_str) self.buf.push(s, a_idx, reward_val, ns, done) def learn(self) -> float: if len(self.buf) < self.batch: return 0.0 batch = Experience(*zip(*self.buf.sample(self.batch))) s = torch.FloatTensor(np.array(batch.state)).to(self.device) a = torch.LongTensor(np.array(batch.action)).to(self.device) r = torch.FloatTensor(np.array(batch.reward)).to(self.device) ns = torch.FloatTensor(np.array(batch.next_state)).to(self.device) d = torch.BoolTensor(np.array(batch.done)).to(self.device) curr_q = self.q_net(s).gather(1, a.unsqueeze(1)).squeeze() with torch.no_grad(): next_q = self.target_net(ns).max(1)[0] target = r + self.gamma * next_q * ~d loss = nn.HuberLoss()(curr_q, target) self.opt.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0) self.opt.step() self.eps = max(self.eps_min, self.eps * self.eps_decay) self._step += 1 if self._step % self.update_every == 0: self.target_net.load_state_dict(self.q_net.state_dict()) return float(loss.item()) def train(api: str, task_id: str, episodes: int): agent = DQNAgent() rewards = [] for ep in range(episodes): resp = requests.post(f"{api}/reset", json={"task_id": task_id, "seed": ep}) obs = resp.json()["observation"] total = 0.0 done = False while not done: decision, conf, reason = agent.act(obs) step_resp = requests.post( f"{api}/step", json={"action": {"decision": decision, "confidence": conf, "reasoning": reason}}, ).json() r_val = step_resp["reward"]["value"] agent.remember(obs, decision, r_val, step_resp["observation"], step_resp["done"]) agent.learn() obs = step_resp["observation"] total += r_val done = step_resp["done"] rewards.append(total) if ep % 10 == 0: avg = float(np.mean(rewards[-10:])) print(f"Episode {ep:4d} | avg_reward={avg:.4f} | eps={agent.eps:.3f}") torch.save(agent.q_net.state_dict(), "dqn_checkpoint.pth") print(f"Training complete. Checkpoint saved to dqn_checkpoint.pth") return agent, rewards if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--api", default="http://localhost:7860") parser.add_argument("--task", default="basic_security") parser.add_argument("--episodes", type=int, default=500) args = parser.parse_args() train(args.api, args.task, args.episodes)