Itachi-1824
fix: brutal audit β reset tool_call_counts, date dedup, unused vars, playground overhaul with scenario picker + status dashboard
b4d7ce3 | """ | |
| State-graph engine for the EU AI Act Compliance Auditor. | |
| Each scenario is a directed graph where: | |
| - Nodes represent audit states (e.g., INITIAL, CLASSIFYING, AUDITING_DATA) | |
| - Edges represent tool calls with outcomes: progress / no_effect / worsened | |
| - BFS depth from current node to RESOLVED gives partial credit | |
| - Wrong actions can push the audit backward (worsened transitions) | |
| - Parameter randomization prevents memorization | |
| Reward is computed from 6 components: | |
| 1. Classification accuracy (20%) — correct risk category | |
| 2. Finding completeness (25%) — found X of Y violations | |
| 3. Finding precision (15%) — penalty for false positives | |
| 4. Remediation quality (15%) — correct priority ordering | |
| 5. Process methodology (15%) — followed correct audit sequence | |
| 6. Efficiency (10%) — steps vs optimal path | |
| """ | |
| from __future__ import annotations | |
| import random | |
| from collections import deque | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Optional, Set | |
| # --------------------------------------------------------------------------- | |
| # State graph primitives | |
| # --------------------------------------------------------------------------- | |
@dataclass
class StateNode:
    """A node in the audit state graph.

    Attributes:
        id: Unique identifier; used as the key in ``StateGraph.nodes``.
        label: Human-readable name of the audit state.
        is_terminal: Marks a goal state; ``StateGraph`` measures distances to
            terminal nodes.
        is_start: Marks the graph's entry point (the last one added wins).
    """

    id: str
    label: str
    is_terminal: bool = False
    is_start: bool = False
@dataclass
class Transition:
    """A directed edge in the audit state graph, labelled with a tool call.

    Attributes:
        from_state: Id of the node the edge leaves.
        to_state: Id of the node the edge enters.
        tool_name: Name of the tool call associated with this edge.
        outcome: ``"progress"``, ``"no_effect"`` or ``"worsened"``. Only
            ``"progress"`` edges are followed by StateGraph's BFS helpers.
        required_args: Arguments attached to this edge (a fresh dict per
            instance); presumably matched against the agent's tool arguments
            by code outside this module — confirm with callers.
        description: Free-text explanation of the edge.
    """

    from_state: str
    to_state: str
    tool_name: str
    outcome: str  # "progress" | "no_effect" | "worsened"
    required_args: Dict[str, Any] = field(default_factory=dict)
    description: str = ""
class StateGraph:
    """Directed graph of audit states with BFS-based partial credit.

    Nodes are keyed by ``StateNode.id``. Edges are kept both as a flat list
    (``transitions``) and in a per-source adjacency map. Only edges whose
    ``outcome`` is ``"progress"`` are followed by the distance helpers.
    """

    # Distance reported by bfs_distance_to_terminal when no terminal is reachable.
    UNREACHABLE = 999

    def __init__(self) -> None:
        self.nodes: Dict[str, StateNode] = {}
        self.transitions: List[Transition] = []
        self._adjacency: Dict[str, List[Transition]] = {}
        self._start_node: Optional[str] = None
        self._terminal_nodes: Set[str] = set()

    def add_node(self, node: StateNode) -> None:
        """Register *node* and record whether it is a start or terminal state.

        Adding another node with ``is_start=True`` replaces the previous start.
        """
        self.nodes[node.id] = node
        self._adjacency.setdefault(node.id, [])
        if node.is_start:
            self._start_node = node.id
        if node.is_terminal:
            self._terminal_nodes.add(node.id)

    def add_transition(self, t: Transition) -> None:
        """Append edge *t*; its source does not have to be a registered node."""
        self.transitions.append(t)
        self._adjacency.setdefault(t.from_state, []).append(t)

    @property
    def start_node(self) -> str:
        """Id of the start node.

        This is a property: ``optimal_path_length`` and
        ``total_progress_states`` read it as an attribute. As a plain method,
        they would BFS from a bound method and always report 999 / 1.

        Raises:
            ValueError: if no node with ``is_start=True`` has been added.
        """
        if self._start_node is None:
            raise ValueError("No start node defined")
        return self._start_node

    def get_transitions(self, state_id: str) -> List[Transition]:
        """All outgoing edges of *state_id* (empty list for unknown states)."""
        return self._adjacency.get(state_id, [])

    def get_progress_transitions(self, state_id: str) -> List[Transition]:
        """Outgoing edges of *state_id* whose outcome is ``"progress"``."""
        return [t for t in self.get_transitions(state_id) if t.outcome == "progress"]

    def bfs_distance_to_terminal(self, state_id: str) -> int:
        """Fewest progress edges from *state_id* to any terminal node.

        Returns:
            0 if *state_id* is itself terminal, ``UNREACHABLE`` (999) if no
            terminal node can be reached, otherwise the BFS distance.
        """
        if state_id in self._terminal_nodes:
            return 0
        visited = {state_id}
        queue = deque([(state_id, 0)])
        while queue:
            current, dist = queue.popleft()
            for t in self.get_progress_transitions(current):
                if t.to_state in visited:
                    continue
                if t.to_state in self._terminal_nodes:
                    return dist + 1
                visited.add(t.to_state)
                queue.append((t.to_state, dist + 1))
        return self.UNREACHABLE

    def optimal_path_length(self) -> int:
        """Minimum number of progress steps from the start node to a terminal."""
        return self.bfs_distance_to_terminal(self.start_node)

    def total_progress_states(self) -> int:
        """Number of states reachable from the start node via progress edges.

        The count includes the start node itself and every terminal node
        that is reached.
        """
        start = self.start_node
        visited = {start}
        queue = deque([start])
        while queue:
            current = queue.popleft()
            for t in self.get_progress_transitions(current):
                if t.to_state not in visited:
                    visited.add(t.to_state)
                    queue.append(t.to_state)
        return len(visited)
| # --------------------------------------------------------------------------- | |
| # Scenario definition | |
| # --------------------------------------------------------------------------- | |
@dataclass
class AuditScenario:
    """A complete compliance audit scenario with state graph and ground truth."""

    scenario_id: str
    title: str
    difficulty: str  # easy | medium | hard
    description: str  # initial alert/assignment text
    # The AI system being audited
    system_name: str
    system_description: str
    system_category: str  # prohibited | high_risk | limited_risk | minimal_risk
    deployer_info: str
    # State graph
    graph: StateGraph = field(default_factory=StateGraph)
    # Ground truth for grading
    correct_classification: str = ""  # prohibited | high_risk | limited_risk | minimal_risk
    ground_truth_findings: List[str] = field(default_factory=list)
    required_remediation: List[str] = field(default_factory=list)
    red_herrings: List[str] = field(default_factory=list)
    # Investigation documents (rich text requiring analysis — no pre-digested verdicts)
    documentation_data: str = ""
    training_data_info: str = ""
    oversight_info: str = ""
    transparency_info: str = ""
    risk_assessment_info: str = ""
    logging_info: str = ""
    # Deep-dive documents (revealed on repeat tool calls — adaptive depth)
    deep_documentation: str = ""
    deep_training_data: str = ""
    deep_oversight: str = ""
    deep_transparency: str = ""
    deep_risk_assessment: str = ""
    deep_logging: str = ""
    # Randomization parameters (re-rolled on each reset)
    _rand_params: Dict[str, str] = field(default_factory=dict)

    def randomize(self, seed: Optional[int] = None) -> None:
        """Re-roll randomizable parameters to prevent memorization.

        Replaces ``_rand_params`` with a company, region, version, deployment
        date (2026-01..03) and user count. The same *seed* reproduces the same
        values; ``None`` seeds from system randomness.
        """
        rng = random.Random(seed)
        company_names = [
            "TechNova Solutions", "QuantumLeap AI", "NeuralPath Inc",
            "DataForge Systems", "CogniTech Labs", "AlphaWave AI",
            "SynthMind Corp", "PrismAI Technologies", "Vertex Analytics",
            "OmniSense AI", "DeepCurrent Inc", "StrataLogic Systems",
        ]
        regions = ["EU-West", "EU-Central", "EU-North", "EU-South", "EU-East"]
        versions = ["v2.1", "v3.0", "v4.2", "v5.1", "v1.8", "v6.0"]
        # Draw order matters for seeded reproducibility: date, user count, then choices.
        deploy_date = f"2026-{rng.randint(1, 3):02d}-{rng.randint(1, 28):02d}"
        user_ct = rng.randint(10000, 5000000)
        self._rand_params = {
            "company": rng.choice(company_names),
            "region": rng.choice(regions),
            "version": rng.choice(versions),
            # The same date is exposed under two keys.
            "date": deploy_date,
            "usercount": f"{user_ct:,}",  # thousands-separated, e.g. "1,234,567"
            "deployment_date": deploy_date,
            "user_count": str(user_ct),  # plain digits
        }

    def get_param(self, key: str) -> str:
        """Return randomized parameter *key*, or ``"Unknown"`` if it is not set."""
        return self._rand_params.get(key, "Unknown")
| # --------------------------------------------------------------------------- | |
| # Reward computation (6 components) | |
| # --------------------------------------------------------------------------- | |
def safe_reward(score: float) -> float:
    """Clamp reward to (0, 1) exclusive — required by OpenEnv validator.

    Scores are clipped to ``[0.001, 0.999]``. NaN maps to the lower bound
    0.001; without the explicit check, ``min(0.999, nan)`` returns 0.999,
    which would turn an invalid score into the maximum reward.
    """
    import math

    if math.isnan(score):
        return 0.001
    return max(0.001, min(0.999, score))
@dataclass
class RewardBreakdown:
    """Scores for the six reward components and their weighted total.

    ``compute_reward`` fills each field with a value between 0 and 1; the
    trailing comment on each field is its weight in ``total``.
    """

    classification: float = 0.0  # 20%
    finding_completeness: float = 0.0  # 25%
    finding_precision: float = 0.0  # 15%
    remediation: float = 0.0  # 15%
    methodology: float = 0.0  # 15%
    efficiency: float = 0.0  # 10%

    def total(self) -> float:
        """Weighted sum of the components (weights sum to 1.0), clamped by ``safe_reward``."""
        raw = (
            self.classification * 0.20
            + self.finding_completeness * 0.25
            + self.finding_precision * 0.15
            + self.remediation * 0.15
            + self.methodology * 0.15
            + self.efficiency * 0.10
        )
        return safe_reward(raw)

    def to_dict(self) -> Dict[str, float]:
        """Components rounded to 3 decimals, plus ``total`` rounded to 4."""
        return {
            "classification": round(self.classification, 3),
            "finding_completeness": round(self.finding_completeness, 3),
            "finding_precision": round(self.finding_precision, 3),
            "remediation": round(self.remediation, 3),
            "methodology": round(self.methodology, 3),
            "efficiency": round(self.efficiency, 3),
            "total": round(self.total(), 4),
        }
# Canonical audit order scored by the methodology component of compute_reward.
_EXPECTED_TOOL_SEQUENCE = (
    "get_system_overview", "classify_system", "check_documentation", "audit_training_data",
    "verify_human_oversight", "check_transparency", "assess_risk_management", "check_logging",
)

# Normalized findings / remediation items shorter than this never count as a
# match: "" or a single letter is a substring of every label.
_MIN_ITEM_LENGTH = 3


def _normalize_label(text: str) -> str:
    """Lower-case *text* and join its words with single underscores.

    Hyphens, underscores and whitespace all act as separators, so
    "High-Risk", " high risk " and "high_risk" normalize identically.
    """
    return "_".join(text.lower().replace("-", " ").replace("_", " ").split())


def _finding_tokens(label: str) -> Set[str]:
    """Non-empty underscore-separated tokens of a normalized label."""
    return {tok for tok in label.split("_") if tok}


def _finding_token_match(submitted: str, truth: str) -> bool:
    """Recall criterion: equal, >=2 shared tokens, or >=40% of *truth*'s tokens shared."""
    if submitted == truth:
        return True
    s_tok = _finding_tokens(submitted)
    t_tok = _finding_tokens(truth)
    overlap = len(s_tok & t_tok)
    return overlap >= 2 or (bool(t_tok) and overlap / len(t_tok) >= 0.4)


def _substring_match(a: str, b: str) -> bool:
    """Whether either label contains the other."""
    return a in b or b in a


def compute_reward(
    scenario: AuditScenario,
    classification_submitted: str,
    findings_submitted: List[str],
    remediation_submitted: List[str],
    tool_sequence: List[str],
    steps_taken: int,
) -> RewardBreakdown:
    """Compute the 6-component reward for a completed audit.

    Free-text labels (classification, findings, remediation and the scenario's
    ground truth) are normalized with ``_normalize_label`` before comparison.
    Findings and remediation items shorter than ``_MIN_ITEM_LENGTH`` after
    normalization never match anything. This closes the empty-string exploit,
    where "" was a substring of every ground-truth label.

    Args:
        scenario: Provides the ground truth, red herrings and state graph.
        classification_submitted: Risk category claimed by the agent.
        findings_submitted: Violation labels reported by the agent.
        remediation_submitted: Remediation items in the agent's priority order.
        tool_sequence: Tool names in call order.
        steps_taken: Step count compared against the graph's optimal path length.

    Returns:
        A RewardBreakdown with all six components set.
    """
    breakdown = RewardBreakdown()

    # 1. Classification accuracy (20%). Both sides are normalized the same way,
    #    so "high risk" earns full credit for "high_risk" instead of zero.
    if _normalize_label(classification_submitted) == _normalize_label(scenario.correct_classification):
        breakdown.classification = 1.0
    elif _partial_classification_match(classification_submitted, scenario.correct_classification):
        breakdown.classification = 0.4
    else:
        breakdown.classification = 0.0

    submitted = {_normalize_label(f) for f in findings_submitted}
    candidates = {f for f in submitted if len(f) >= _MIN_ITEM_LENGTH}
    truth = {_normalize_label(t) for t in scenario.ground_truth_findings} - {""}
    red = {_normalize_label(r) for r in scenario.red_herrings} - {""}

    # 2. Finding completeness (25%) — recall of ground-truth findings.
    if truth:
        matched = sum(
            1 for t in truth if any(_finding_token_match(f, t) for f in candidates)
        )
        breakdown.finding_completeness = matched / len(truth)
    else:
        breakdown.finding_completeness = 1.0  # no findings expected

    # 3. Finding precision (15%) — penalize false positives. A finding that
    #    earns recall credit also counts as a true positive here. Too-short
    #    entries stay in the denominator as noise.
    if submitted:
        true_positives = sum(
            1
            for f in candidates
            if any(_substring_match(f, t) or _finding_token_match(f, t) for t in truth)
        )
        false_positives = sum(
            1 for f in candidates if any(_substring_match(f, r) for r in red)
        )
        precision = true_positives / len(submitted)
        breakdown.finding_precision = max(0.0, precision - false_positives * 0.15)
    elif not truth:
        # Nothing expected and nothing reported is a perfectly precise audit.
        breakdown.finding_precision = 1.0
    else:
        breakdown.finding_precision = 0.0

    # 4. Remediation quality (15%) — correct fixes in priority order.
    required = [_normalize_label(r) for r in scenario.required_remediation]
    if required:
        proposed = [
            item
            for item in (_normalize_label(r) for r in remediation_submitted)
            if len(item) >= _MIN_ITEM_LENGTH
        ]
        present = sum(
            1 for req in required if any(_substring_match(req, p) for p in proposed)
        )
        presence_score = present / len(required)
        order_score = _check_ordering(proposed, required)
        breakdown.remediation = presence_score * 0.7 + order_score * 0.3
    else:
        breakdown.remediation = 1.0

    # 5. Process methodology (15%) — coverage of the canonical audit tools,
    #    minus 0.15 per backwards step (capped at 0.5).
    position = {tool: i for i, tool in enumerate(_EXPECTED_TOOL_SEQUENCE)}
    audit_tools = [t for t in tool_sequence if t in position]
    if audit_tools:
        order_violations = sum(
            1
            for prev, nxt in zip(audit_tools, audit_tools[1:])
            if position[nxt] < position[prev]
        )
        coverage = len(set(audit_tools)) / len(_EXPECTED_TOOL_SEQUENCE)
        order_penalty = min(order_violations * 0.15, 0.5)
        breakdown.methodology = max(0.0, coverage - order_penalty)
    else:
        breakdown.methodology = 0.0

    # 6. Efficiency (10%) — steps vs optimal.
    # Anti-gaming: taking FEWER steps than optimal means skipping investigation,
    # so it is penalized; beyond optimal, efficiency decays as optimal / steps.
    optimal = scenario.graph.optimal_path_length()
    if optimal > 0 and steps_taken > 0:
        if steps_taken < optimal:
            breakdown.efficiency = steps_taken / optimal * 0.5
        else:
            breakdown.efficiency = min(optimal / steps_taken, 1.0)
    else:
        breakdown.efficiency = 0.3

    return breakdown
def _partial_classification_match(submitted: str, correct: str) -> bool:
    """Whether *submitted* is exactly one risk level away from *correct*.

    Levels are adjacent in the order prohibited, high_risk, limited_risk,
    minimal_risk (e.g. high_risk vs limited_risk is a partial match). Both
    labels are normalized identically: lower-cased, stripped, with ``-`` and
    spaces turned into ``_``. Previously only *submitted* was normalized, so a
    "high-risk" ground truth never matched. Labels outside the four levels
    never match.
    """
    risk_levels = ["prohibited", "high_risk", "limited_risk", "minimal_risk"]

    def _norm(label: str) -> str:
        return label.lower().strip().replace("-", "_").replace(" ", "_")

    sub, cor = _norm(submitted), _norm(correct)
    if sub not in risk_levels or cor not in risk_levels:
        return False
    return abs(risk_levels.index(sub) - risk_levels.index(cor)) == 1
def _check_ordering(submitted: List[str], required: List[str]) -> float:
    """Score how well *submitted* follows the priority order of *required*.

    Each required item is located at the first submitted index whose text
    contains it or is contained in it. The score is the fraction of consecutive
    located items whose submitted indices strictly increase.

    Returns:
        0.0 if either list is empty or no required item is located. An
            unrelated list must not earn ordering credit.
        1.0 if the only required item is located (one item is trivially in order).
        0.5 (neutral) if several items are required but only one is located.
        Otherwise, the in-order fraction in [0, 1].
    """
    if not submitted or not required:
        return 0.0
    positions = []
    for req in required:
        idx = next(
            (i for i, sub in enumerate(submitted) if req in sub or sub in req),
            None,
        )
        if idx is not None:
            positions.append(idx)
    if not positions:
        return 0.0
    if len(positions) == 1:
        return 1.0 if len(required) == 1 else 0.5
    in_order = sum(1 for a, b in zip(positions, positions[1:]) if a < b)
    return in_order / (len(positions) - 1)