SecureAI-Gaurd / dqn_baseline.py
mohdbelal010's picture
Upload folder using huggingface_hub
eccdd94 verified
Raw
History Blame Contribute Delete
6.59 kB
"""
dqn_baseline.py — Dueling DQN agent for SecureAI-Guard.
Usage:
python dqn_baseline.py --episodes 500 --task basic_security
"""
import argparse
import random
from collections import deque, namedtuple
from typing import Any, Dict, List, Tuple
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
# One replay-memory transition: (state, action, reward, next_state, done).
Experience = namedtuple("Experience", "state action reward next_state done")
class DuelingDQN(nn.Module):
    """Dueling Q-network: shared MLP trunk feeding separate value and advantage heads.

    Q-values are recombined as ``V(s) + (A(s, a) - mean_a A(s, a))``; subtracting
    the mean advantage makes the value/advantage split identifiable.

    Args:
        state_size: Number of input features per state.
        action_size: Number of discrete actions (number of Q-values produced).
        hidden: Width of the first two hidden layers; the third is ``hidden // 2``.
        dropout: Dropout probability after the first two hidden layers. Dropout is
            only active in training mode, so call ``.eval()`` on the network before
            using it for greedy action selection or as a target network.
    """

    def __init__(self, state_size: int, action_size: int, hidden: int = 256, dropout: float = 0.2):
        super().__init__()
        # Module order is kept identical to the original so saved state_dict
        # keys (e.g. "shared.3.weight") still line up with older checkpoints.
        self.shared = nn.Sequential(
            nn.Linear(state_size, hidden), nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2), nn.ReLU(),
        )
        self.value = nn.Linear(hidden // 2, 1)
        self.advantage = nn.Linear(hidden // 2, action_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map states ``(..., state_size)`` to Q-values ``(..., action_size)``."""
        x = self.shared(x)
        v = self.value(x)
        a = self.advantage(x)
        # Centre the advantages over the action axis, which is always the last
        # one; this works for both batched (B, S) and single (S,) inputs.
        return v + (a - a.mean(dim=-1, keepdim=True))
class ReplayBuffer:
    """Bounded FIFO memory of transitions used for experience replay.

    Once ``capacity`` transitions are stored, each new one evicts the oldest.
    """

    def __init__(self, capacity: int = 10_000):
        # A deque with maxlen drops items from the left when it overflows.
        self.buf: deque = deque([], maxlen=capacity)

    def push(self, *args):
        """Pack the positional ``args`` into an ``Experience`` and store it."""
        record = Experience(*args)
        self.buf.append(record)

    def sample(self, n: int) -> List[Experience]:
        """Return ``n`` distinct stored transitions chosen uniformly at random."""
        population = self.buf
        return random.sample(population, k=n)

    def __len__(self) -> int:
        """Number of transitions currently held."""
        size = len(self.buf)
        return size
class DQNAgent:
    """Epsilon-greedy Dueling-DQN agent that maps observations to security decisions.

    ``q_net`` is the online network updated by :meth:`learn`.
    ``target_net`` is a frozen copy used for bootstrapped targets. It is
    re-synchronised from ``q_net`` every ``update_every`` gradient updates.
    """

    # Action index -> decision string; order must match the Q-network outputs.
    DECISIONS = ["allow", "block", "warn", "investigate"]

    def __init__(self, state_size: int = 9, action_size: int = 4, lr: float = 1e-3):
        """Build networks, optimizer and replay buffer.

        Args:
            state_size: Length of the feature vector produced by :meth:`encode`.
            action_size: Number of Q-network outputs; must equal ``len(DECISIONS)``.
            lr: Adam learning rate.

        Raises:
            ValueError: if ``action_size`` does not match the number of decisions
                (any other value would later fail with an IndexError or a
                shape error inside ``gather``).
        """
        if action_size != len(self.DECISIONS):
            raise ValueError(
                f"action_size must be {len(self.DECISIONS)} to match DECISIONS, got {action_size}"
            )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.q_net = DuelingDQN(state_size, action_size).to(self.device)
        self.target_net = DuelingDQN(state_size, action_size).to(self.device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        # The target network is never optimised directly. Eval mode disables its
        # dropout so bootstrapped targets are not randomly perturbed.
        # load_state_dict() does not change the mode, so later syncs keep eval.
        self.target_net.eval()
        self.opt = optim.Adam(self.q_net.parameters(), lr=lr)
        self.buf = ReplayBuffer()
        self.gamma = 0.99  # discount factor
        self.eps = 1.0  # current exploration probability
        self.eps_min = 0.01  # lower bound for eps
        self.eps_decay = 0.995  # multiplicative eps decay per gradient update
        self.batch = 64  # minibatch size; learn() is a no-op until this many transitions exist
        self.update_every = 100  # gradient updates between target-network syncs
        self._step = 0  # number of gradient updates performed so far

    def encode(self, obs: Dict[str, Any]) -> np.ndarray:
        """Turn an observation dict into a 9-element float32 feature vector.

        Features, in order:
        - ``hf_risk_score`` (used as-is);
        - ``user_trust / 100``;
        - ``system_fatigue / 100``;
        - one-hot ``channel`` for "sms", "email" and "web" (other channels give all zeros);
        - content length / 500, capped at 1;
        - ``len(threat_history) / 10``, or 0 if the key is absent;
        - ``timestamp % 86400 / 86400``, i.e. the fraction of the day if the
          timestamp is in seconds (assumption; confirm with the environment).
        """
        return np.array([
            obs["hf_risk_score"],
            obs["user_trust"] / 100.0,
            obs["system_fatigue"] / 100.0,
            float(obs["channel"] == "sms"),
            float(obs["channel"] == "email"),
            float(obs["channel"] == "web"),
            min(len(obs["content"]) / 500.0, 1.0),
            len(obs.get("threat_history", [])) / 10.0,
            obs["timestamp"] % 86400 / 86400,
        ], dtype=np.float32)

    def act(self, obs: Dict[str, Any], train: bool = True) -> Tuple[str, float, str]:
        """Pick a decision for ``obs`` with an epsilon-greedy policy.

        Args:
            obs: Observation dict (see :meth:`encode` for the keys read).
            train: If True, explore uniformly at random with probability
                ``self.eps``; if False, always act greedily.

        Returns:
            ``(decision, confidence, reasoning)``; confidence is the constant 0.85.
        """
        if train and random.random() < self.eps:
            idx = random.randrange(len(self.DECISIONS))
        else:
            idx = self._greedy_index(obs)
        decision = self.DECISIONS[idx]
        return decision, 0.85, self._reason(obs, decision)

    def _greedy_index(self, obs: Dict[str, Any]) -> int:
        """Return the argmax-Q action index for ``obs`` with dropout disabled."""
        s = torch.as_tensor(self.encode(obs), device=self.device).unsqueeze(0)
        was_training = self.q_net.training
        # With dropout active, "greedy" selection would be random.
        # Evaluate in eval mode, then restore whatever mode the caller had.
        self.q_net.eval()
        try:
            with torch.no_grad():
                return int(self.q_net(s).argmax(dim=1).item())
        finally:
            self.q_net.train(was_training)

    def _reason(self, obs: Dict[str, Any], decision: str) -> str:
        """Return a templated, human-readable justification for ``decision``.

        The text is chosen by decision only and embeds the observation's risk
        score and trust; it does not consult the Q-network.
        """
        r = obs["hf_risk_score"]
        t = obs["user_trust"]
        msgs = {
            "block": f"High risk score ({r:.2f}) warrants blocking.",
            "warn": f"Moderate risk ({r:.2f}); warning user to maintain trust ({t:.0f}).",
            "investigate": f"Ambiguous risk ({r:.2f}); flagging for investigation.",
            "allow": f"Low risk ({r:.2f}) and trusted context; allowing message.",
        }
        return msgs[decision]

    def remember(self, obs, action_str, reward_val, next_obs, done):
        """Encode a transition and store it in the replay buffer.

        ``action_str`` must be one of ``DECISIONS``; otherwise ``list.index``
        raises ValueError.
        """
        s = self.encode(obs)
        ns = self.encode(next_obs)
        a_idx = self.DECISIONS.index(action_str)
        self.buf.push(s, a_idx, reward_val, ns, done)

    def learn(self) -> float:
        """Perform one DQN gradient update on a random minibatch.

        The Huber loss is taken between Q(s, a) from ``q_net`` and
        ``r + gamma * max_a' Q_target(s', a')``. The bootstrap term is dropped
        for terminal transitions.

        Side effects:
            - Gradients are clipped to norm 1.0.
            - ``eps`` is decayed once per update.
            - ``target_net`` is synced every ``update_every`` updates.

        Returns:
            The loss value, or 0.0 if fewer than ``self.batch`` transitions are stored.
        """
        if len(self.buf) < self.batch:
            return 0.0
        # Make sure dropout is active for the training forward pass.
        self.q_net.train()
        batch = Experience(*zip(*self.buf.sample(self.batch)))
        s = torch.FloatTensor(np.array(batch.state)).to(self.device)
        a = torch.LongTensor(np.array(batch.action)).to(self.device)
        r = torch.FloatTensor(np.array(batch.reward)).to(self.device)
        ns = torch.FloatTensor(np.array(batch.next_state)).to(self.device)
        d = torch.BoolTensor(np.array(batch.done)).to(self.device)
        # squeeze(1) rather than squeeze(): never collapse the batch axis.
        curr_q = self.q_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q = self.target_net(ns).max(1)[0]
            # ~d zeroes the bootstrap term for terminal transitions.
            target = r + self.gamma * next_q * ~d
        loss = nn.HuberLoss()(curr_q, target)
        self.opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0)
        self.opt.step()
        self.eps = max(self.eps_min, self.eps * self.eps_decay)
        self._step += 1
        if self._step % self.update_every == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())
        return float(loss.item())
def _post_json(session, url: str, payload: Dict[str, Any], timeout: float) -> Dict[str, Any]:
    """POST ``payload`` as JSON to ``url`` via ``session`` and return the decoded JSON body.

    Raises ``requests.HTTPError`` on a 4xx/5xx status. Otherwise an error
    response would surface later as an unhelpful ``KeyError`` on a missing field.
    """
    resp = session.post(url, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()


def train(
    api: str,
    task_id: str,
    episodes: int,
    *,
    timeout: float = 30.0,
    checkpoint_path: str = "dqn_checkpoint.pth",
):
    """Train a fresh DQNAgent against the environment's HTTP API.

    Each episode:
    - POSTs ``{api}/reset`` with the episode index as the seed;
    - repeatedly POSTs the agent's decision to ``{api}/step`` until the
      response reports ``done``;
    - stores every transition and runs one ``learn()`` call after it.

    Args:
        api: Base URL of the environment server.
        task_id: Task identifier sent with every reset.
        episodes: Number of episodes to run.
        timeout: Per-request timeout in seconds, so a stalled server cannot
            hang training indefinitely.
        checkpoint_path: File the final ``q_net`` weights are saved to.

    Returns:
        ``(agent, rewards)``: the trained agent and the total reward of each episode.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if a request exceeds ``timeout``.
    """
    agent = DQNAgent()
    rewards: List[float] = []
    # Reuse one connection pool for the many small reset/step requests.
    with requests.Session() as session:
        for ep in range(episodes):
            reset = _post_json(session, f"{api}/reset", {"task_id": task_id, "seed": ep}, timeout)
            obs = reset["observation"]
            total = 0.0
            done = False
            while not done:
                decision, conf, reason = agent.act(obs)
                step = _post_json(
                    session,
                    f"{api}/step",
                    {"action": {"decision": decision, "confidence": conf, "reasoning": reason}},
                    timeout,
                )
                r_val = step["reward"]["value"]
                next_obs = step["observation"]
                done = step["done"]
                agent.remember(obs, decision, r_val, next_obs, done)
                agent.learn()
                obs = next_obs
                total += r_val
            rewards.append(total)
            if ep % 10 == 0:
                # Mean over the most recent (up to) 10 episodes.
                avg = float(np.mean(rewards[-10:]))
                print(f"Episode {ep:4d} | avg_reward={avg:.4f} | eps={agent.eps:.3f}")
    torch.save(agent.q_net.state_dict(), checkpoint_path)
    print(f"Training complete. Checkpoint saved to {checkpoint_path}")
    return agent, rewards
if __name__ == "__main__":
    # Command-line entry point: read options, then launch a training run.
    cli = argparse.ArgumentParser()
    cli.add_argument("--api", default="http://localhost:7860")
    cli.add_argument("--task", default="basic_security")
    cli.add_argument("--episodes", type=int, default=500)
    opts = cli.parse_args()
    train(api=opts.api, task_id=opts.task, episodes=opts.episodes)