# Itachi-1824
# fix: brutal audit — reset tool_call_counts, date dedup, unused vars, playground overhaul with scenario picker + status dashboard
# b4d7ce3
"""
State-graph engine for the EU AI Act Compliance Auditor.
Each scenario is a directed graph where:
- Nodes represent audit states (e.g., INITIAL, CLASSIFYING, AUDITING_DATA)
- Edges represent tool calls with outcomes: progress / no_effect / worsened
- BFS depth from current node to RESOLVED gives partial credit
- Wrong actions can push the audit backward (worsened transitions)
- Parameter randomization prevents memorization
Reward is computed from 6 components:
1. Classification accuracy (20%) — correct risk category
2. Finding completeness (25%) — found X of Y violations
3. Finding precision (15%) — penalty for false positives
4. Remediation quality (15%) — correct priority ordering
5. Process methodology (15%) — followed correct audit sequence
6. Efficiency (10%) — steps vs optimal path
"""
from __future__ import annotations
import random
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
# ---------------------------------------------------------------------------
# State graph primitives
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class StateNode:
    """Immutable vertex of the audit state graph.

    Being a frozen dataclass, a node cannot be modified after construction
    and compares/hashes by the values of its fields.
    """

    # Key under which StateGraph stores and looks up this node.
    id: str
    # Human-readable name of the state.
    label: str
    # True when this node is a target of the graph's shortest-path queries.
    is_terminal: bool = False
    # True when this node is the graph's entry point.
    is_start: bool = False
@dataclass(frozen=True)
class Transition:
    """Immutable directed edge of the audit state graph.

    An edge leads from ``from_state`` to ``to_state`` and is labelled with the
    tool that triggers it and the outcome of taking it.

    Instances are hashable: ``required_args`` is a ``dict`` (unhashable), so it
    is excluded from the generated ``__hash__`` while still taking part in
    equality. Without that exclusion ``hash(transition)`` raises ``TypeError``,
    which makes the edge unusable in sets or as a dict key.
    """

    # Id of the node the edge leaves.
    from_state: str
    # Id of the node the edge enters.
    to_state: str
    # Name of the tool call this edge is labelled with.
    tool_name: str
    outcome: str  # "progress" | "no_effect" | "worsened"
    # Presumably the arguments the tool call must carry for this edge to
    # apply; nothing in this module reads it -- confirm against the caller.
    required_args: Dict[str, Any] = field(default_factory=dict, hash=False)
    # Free-text explanation of the edge.
    description: str = ""
class StateGraph:
    """Directed graph of audit states with shortest-path queries over "progress" edges."""

    def __init__(self):
        # Node id -> node object.
        self.nodes: Dict[str, StateNode] = {}
        # Every edge, in insertion order.
        self.transitions: List[Transition] = []
        # Node id -> edges leaving that node.
        self._adjacency: Dict[str, List[Transition]] = {}
        # Id of the most recently added node flagged is_start, if any.
        self._start_node: Optional[str] = None
        # Ids of all nodes flagged is_terminal.
        self._terminal_nodes: Set[str] = set()

    def add_node(self, node: StateNode) -> None:
        """Register ``node``, replacing any earlier node with the same id."""
        node_id = node.id
        self.nodes[node_id] = node
        self._adjacency.setdefault(node_id, [])
        if node.is_start:
            self._start_node = node_id
        if node.is_terminal:
            self._terminal_nodes.add(node_id)

    def add_transition(self, t: Transition) -> None:
        """Register edge ``t``; its source state need not have been added as a node."""
        self.transitions.append(t)
        self._adjacency.setdefault(t.from_state, []).append(t)

    @property
    def start_node(self) -> str:
        """Id of the start node.

        Raises:
            ValueError: if no node flagged ``is_start`` has been added.
        """
        start = self._start_node
        if start is None:
            raise ValueError("No start node defined")
        return start

    def get_transitions(self, state_id: str) -> List[Transition]:
        """Return the edges leaving ``state_id`` (empty list for an unknown id)."""
        return self._adjacency.get(state_id, [])

    def get_progress_transitions(self, state_id: str) -> List[Transition]:
        """Return only the outgoing edges of ``state_id`` whose outcome is "progress"."""
        progress_edges = []
        for edge in self.get_transitions(state_id):
            if edge.outcome == "progress":
                progress_edges.append(edge)
        return progress_edges

    def bfs_distance_to_terminal(self, state_id: str) -> int:
        """Return the fewest "progress" edges from ``state_id`` to a terminal node.

        Returns 0 when ``state_id`` is itself terminal and the sentinel 999
        when no terminal node can be reached.
        """
        terminals = self._terminal_nodes
        if state_id in terminals:
            return 0
        seen = {state_id}
        frontier = [state_id]
        depth = 0
        # Expand the search one layer at a time; ``depth`` is the number of
        # edges needed to reach any node discovered in the current layer.
        while frontier:
            depth += 1
            next_frontier = []
            for node_id in frontier:
                for edge in self.get_progress_transitions(node_id):
                    target = edge.to_state
                    if target in seen:
                        continue
                    if target in terminals:
                        return depth
                    seen.add(target)
                    next_frontier.append(target)
            frontier = next_frontier
        return 999  # unreachable

    def optimal_path_length(self) -> int:
        """Return the shortest "progress" distance from the start node to a terminal."""
        return self.bfs_distance_to_terminal(self.start_node)

    def total_progress_states(self) -> int:
        """Count the states reachable from the start node via "progress" edges.

        The start node itself and any reachable terminal nodes are included
        in the count.
        """
        reached = {self.start_node}
        pending = deque(reached)
        while pending:
            here = pending.popleft()
            for edge in self.get_progress_transitions(here):
                target = edge.to_state
                if target not in reached:
                    reached.add(target)
                    pending.append(target)
        return len(reached)
# ---------------------------------------------------------------------------
# Scenario definition
# ---------------------------------------------------------------------------
@dataclass
class AuditScenario:
    """One compliance-audit scenario: state graph, ground truth and source documents."""

    scenario_id: str
    title: str
    difficulty: str  # easy | medium | hard
    description: str  # initial alert/assignment text

    # The AI system being audited
    system_name: str
    system_description: str
    system_category: str  # prohibited | high_risk | limited_risk | minimal_risk
    deployer_info: str

    # State graph
    graph: StateGraph = field(default_factory=StateGraph)

    # Ground truth for grading
    correct_classification: str = ""  # prohibited | high_risk | limited_risk | minimal_risk
    ground_truth_findings: List[str] = field(default_factory=list)
    required_remediation: List[str] = field(default_factory=list)
    red_herrings: List[str] = field(default_factory=list)

    # Investigation documents (rich text requiring analysis — no pre-digested verdicts)
    documentation_data: str = ""
    training_data_info: str = ""
    oversight_info: str = ""
    transparency_info: str = ""
    risk_assessment_info: str = ""
    logging_info: str = ""

    # Deep-dive documents (revealed on repeat tool calls — adaptive depth)
    deep_documentation: str = ""
    deep_training_data: str = ""
    deep_oversight: str = ""
    deep_transparency: str = ""
    deep_risk_assessment: str = ""
    deep_logging: str = ""

    # Randomization parameters (re-rolled on each reset)
    _rand_params: Dict[str, str] = field(default_factory=dict)

    def randomize(self, seed: Optional[int] = None) -> None:
        """Draw a fresh set of cosmetic parameters to prevent memorization.

        Args:
            seed: Seed for the private random generator; the same seed always
                yields the same parameters, ``None`` seeds from system state.

        The result replaces ``_rand_params``. "date" and "deployment_date"
        carry the same value; "usercount" is the user count with thousands
        separators while "user_count" is the same number as plain digits.
        """
        company_pool = (
            "TechNova Solutions", "QuantumLeap AI", "NeuralPath Inc",
            "DataForge Systems", "CogniTech Labs", "AlphaWave AI",
            "SynthMind Corp", "PrismAI Technologies", "Vertex Analytics",
            "OmniSense AI", "DeepCurrent Inc", "StrataLogic Systems",
        )
        region_pool = ("EU-West", "EU-Central", "EU-North", "EU-South", "EU-East")
        version_pool = ("v2.1", "v3.0", "v4.2", "v5.1", "v1.8", "v6.0")

        generator = random.Random(seed)
        # The order of the draws below determines the values a given seed
        # produces, so it must not be rearranged.
        month = generator.randint(1, 3)
        day = generator.randint(1, 28)
        deployed_on = f"2026-{month:02d}-{day:02d}"
        users = generator.randint(10000, 5000000)
        company = generator.choice(company_pool)
        region = generator.choice(region_pool)
        version = generator.choice(version_pool)

        self._rand_params = {
            "company": company,
            "region": region,
            "version": version,
            "date": deployed_on,
            "usercount": f"{users:,}",
            "deployment_date": deployed_on,
            "user_count": str(users),
        }

    def get_param(self, key: str) -> str:
        """Return the randomized parameter ``key``, or "Unknown" when it is not set."""
        params = self._rand_params
        if key in params:
            return params[key]
        return "Unknown"
# ---------------------------------------------------------------------------
# Reward computation (6 components)
# ---------------------------------------------------------------------------
def safe_reward(score: float) -> float:
    """Clamp reward to (0, 1) exclusive — required by OpenEnv validator.

    Args:
        score: Raw reward value.

    Returns:
        ``score`` itself when it lies within [0.001, 0.999], otherwise the
        nearer bound. NaN is mapped to the lower bound 0.001: the previous
        ``max(0.001, min(0.999, score))`` form turned NaN into 0.999, i.e. a
        broken score was rewarded with the maximum.
    """
    if score >= 0.999:
        return 0.999
    if score > 0.001:
        return score
    # Reached for values <= 0.001 and for NaN (every comparison with NaN is
    # false), so an undefined score can never earn more than the floor.
    return 0.001
@dataclass
class RewardBreakdown:
    """Per-component audit scores together with their weighted total."""

    classification: float = 0.0  # 20%
    finding_completeness: float = 0.0  # 25%
    finding_precision: float = 0.0  # 15%
    remediation: float = 0.0  # 15%
    methodology: float = 0.0  # 15%
    efficiency: float = 0.0  # 10%

    def total(self) -> float:
        """Return the weighted sum of the six components, clamped by ``safe_reward``."""
        weighted_components = (
            (self.classification, 0.20),
            (self.finding_completeness, 0.25),
            (self.finding_precision, 0.15),
            (self.remediation, 0.15),
            (self.methodology, 0.15),
            (self.efficiency, 0.10),
        )
        # Accumulate strictly left to right so the floating-point result is
        # the same as a single chained addition of the six products.
        raw = 0.0
        for value, weight in weighted_components:
            raw += value * weight
        return safe_reward(raw)

    def to_dict(self) -> Dict[str, float]:
        """Return the components rounded to 3 places plus "total" rounded to 4."""
        component_names = (
            "classification",
            "finding_completeness",
            "finding_precision",
            "remediation",
            "methodology",
            "efficiency",
        )
        report = {name: round(getattr(self, name), 3) for name in component_names}
        report["total"] = round(self.total(), 4)
        return report
# Submitted findings/remediation entries shorter than this (after stripping)
# are too short to identify anything and never earn substring-match credit.
_MIN_ITEM_LENGTH = 3

# Tools of the reference audit procedure, in the order they are expected.
_EXPECTED_AUDIT_SEQUENCE = (
    "get_system_overview", "classify_system", "check_documentation", "audit_training_data",
    "verify_human_oversight", "check_transparency", "assess_risk_management", "check_logging",
)


def _normalize_risk_label(label: str) -> str:
    """Canonicalize a risk label: lower-case, trimmed, '-' and ' ' turned into '_'."""
    return label.lower().strip().replace("-", "_").replace(" ", "_")


def _contains_either_way(left: str, right: str) -> bool:
    """Return True when one string is a substring of the other."""
    return left in right or right in left


def _finding_tokens_match(submitted: str, ground_truth: str) -> bool:
    """Fuzzy-match two finding ids on their '_'/'-' separated tokens.

    Matches when the ids are equal, share at least two tokens, or share at
    least 40% of the ground-truth tokens.
    """
    submitted_tokens = set(submitted.replace("-", "_").split("_")) - {""}
    truth_tokens = set(ground_truth.replace("-", "_").split("_")) - {""}
    overlap = len(submitted_tokens & truth_tokens)
    if overlap >= 2 or submitted == ground_truth:
        return True
    return bool(truth_tokens) and overlap / len(truth_tokens) >= 0.4


def _score_classification(submitted: str, correct: str) -> float:
    """1.0 for the correct category, 0.4 for an adjacent one, else 0.0."""
    # Both sides go through the same normalization the partial matcher applies,
    # so "High-Risk" or "high risk" is accepted as "high_risk".
    if _normalize_risk_label(submitted) == _normalize_risk_label(correct):
        return 1.0
    if _partial_classification_match(submitted, correct):
        return 0.4
    return 0.0


def _score_finding_completeness(findings_submitted: List[str], ground_truth: List[str]) -> float:
    """Recall: share of ground-truth findings matched by a submitted finding."""
    if not ground_truth:
        return 1.0  # no findings expected
    found = {
        f.lower().strip() for f in findings_submitted if len(f.strip()) >= _MIN_ITEM_LENGTH
    }
    truth = {f.lower() for f in ground_truth}
    matched = sum(1 for t in truth if any(_finding_tokens_match(f, t) for f in found))
    return matched / len(truth)


def _score_finding_precision(
    findings_submitted: List[str],
    ground_truth: List[str],
    red_herrings: List[str],
) -> float:
    """Share of submitted findings that are real, minus 0.15 per red herring hit."""
    if not findings_submitted:
        return 0.0
    found = {f.lower().strip() for f in findings_submitted}
    truth = {f.lower() for f in ground_truth}
    red = {r.lower() for r in red_herrings}
    # Blank or near-blank entries are substrings of almost anything, so they
    # must not be credited as matches; they still count in the denominator
    # and therefore lower precision.
    matchable = [f for f in found if len(f) >= _MIN_ITEM_LENGTH]
    true_positives = sum(
        1 for f in matchable if any(_contains_either_way(f, t) for t in truth)
    )
    false_positives = sum(
        1 for f in matchable if any(_contains_either_way(f, r) for r in red)
    )
    precision = true_positives / len(found)
    return max(0.0, precision - false_positives * 0.15)


def _score_remediation(remediation_submitted: List[str], required: List[str]) -> float:
    """70% presence of the required fixes plus 30% correct priority order."""
    if not required:
        return 1.0
    # Drop blank/near-blank entries: an empty string is a substring of every
    # required item and would otherwise earn full presence credit.
    submitted = [
        cleaned
        for cleaned in (r.lower().strip() for r in remediation_submitted)
        if len(cleaned) >= _MIN_ITEM_LENGTH
    ]
    wanted = [r.lower() for r in required]
    present = sum(
        1 for req in wanted if any(_contains_either_way(req, r) for r in submitted)
    )
    presence_score = present / len(wanted)
    order_score = _check_ordering(submitted, wanted)
    return presence_score * 0.7 + order_score * 0.3


def _score_methodology(tool_sequence: List[str]) -> float:
    """Coverage of the expected audit tools minus a penalty for going backwards."""
    actual = [tool for tool in tool_sequence if tool in _EXPECTED_AUDIT_SEQUENCE]
    if not actual:
        return 0.0
    rank = {tool: position for position, tool in enumerate(_EXPECTED_AUDIT_SEQUENCE)}
    # A violation is any consecutive pair where the later call belongs to an
    # earlier stage of the expected sequence.
    order_violations = sum(
        1 for previous, current in zip(actual, actual[1:]) if rank[current] < rank[previous]
    )
    coverage = len(set(actual)) / len(_EXPECTED_AUDIT_SEQUENCE)
    order_penalty = min(order_violations * 0.15, 0.5)
    return max(0.0, coverage - order_penalty)


def _score_efficiency(optimal: int, steps_taken: int) -> float:
    """Compare the steps used with the optimal path length."""
    if optimal <= 0 or steps_taken <= 0:
        return 0.3
    if steps_taken < optimal:
        # Anti-gaming: fewer steps than optimal means investigation was
        # skipped, so the ratio is halved instead of rewarded.
        return steps_taken / optimal * 0.5
    # Normal case: efficiency decreases as steps grow beyond optimal.
    return min(optimal / steps_taken, 1.0)


def compute_reward(
    scenario: AuditScenario,
    classification_submitted: str,
    findings_submitted: List[str],
    remediation_submitted: List[str],
    tool_sequence: List[str],
    steps_taken: int,
) -> RewardBreakdown:
    """Compute the 6-component reward for a completed audit.

    Args:
        scenario: Scenario providing the ground truth and the state graph.
        classification_submitted: Risk category chosen by the agent; case,
            surrounding whitespace and '-'/' ' vs '_' are ignored.
        findings_submitted: Violation ids reported by the agent.
        remediation_submitted: Remediation ids in the agent's priority order.
        tool_sequence: Names of the tools the agent called, in call order.
        steps_taken: Number of steps the agent used.

    Returns:
        A ``RewardBreakdown`` whose six components are each within [0, 1]:
        classification (20%), finding completeness (25%), finding precision
        (15%), remediation (15%), methodology (15%) and efficiency (10%).
    """
    breakdown = RewardBreakdown()

    # 1. Classification accuracy (20%)
    breakdown.classification = _score_classification(
        classification_submitted, scenario.correct_classification
    )

    # 2. Finding completeness (25%) — recall of ground truth findings
    breakdown.finding_completeness = _score_finding_completeness(
        findings_submitted, scenario.ground_truth_findings
    )

    # 3. Finding precision (15%) — penalize false positives
    breakdown.finding_precision = _score_finding_precision(
        findings_submitted, scenario.ground_truth_findings, scenario.red_herrings
    )

    # 4. Remediation quality (15%) — correct fixes in priority order
    breakdown.remediation = _score_remediation(
        remediation_submitted, scenario.required_remediation
    )

    # 5. Process methodology (15%) — correct audit sequence
    breakdown.methodology = _score_methodology(tool_sequence)

    # 6. Efficiency (10%) — steps vs optimal
    # NOTE(review): when no terminal state is reachable the graph reports the
    # sentinel 999, which is treated here like a real path length — confirm
    # whether that case should use the 0.3 fallback instead.
    breakdown.efficiency = _score_efficiency(
        scenario.graph.optimal_path_length(), steps_taken
    )

    return breakdown
def _partial_classification_match(submitted: str, correct: str) -> bool:
"""Check if classification is partially correct (e.g., high_risk vs limited_risk)."""
risk_levels = ["prohibited", "high_risk", "limited_risk", "minimal_risk"]
sub = submitted.lower().strip().replace("-", "_").replace(" ", "_")
cor = correct.lower().strip()
if sub not in risk_levels or cor not in risk_levels:
return False
return abs(risk_levels.index(sub) - risk_levels.index(cor)) == 1
def _check_ordering(submitted: List[str], required: List[str]) -> float:
"""Score how well submitted items match the required priority order."""
if not submitted or not required:
return 0.0
matched_indices = []
for req in required:
for i, sub in enumerate(submitted):
if req in sub or sub in req:
matched_indices.append(i)
break
if len(matched_indices) < 2:
return 0.5
# Check if matched items are in increasing order
in_order = sum(1 for i in range(len(matched_indices) - 1) if matched_indices[i] < matched_indices[i + 1])
return in_order / (len(matched_indices) - 1)