# compliance-auditor-env / inference.py
# Commit 5d5e37e by Itachi-1824:
#   "feat: eu ai act compliance auditor — mcp-based openenv environment"
"""
Baseline inference for the EU AI Act Compliance Auditor.

Uses OpenAI-style function calling (by default against NVIDIA NIM) to audit
AI systems. The environment is reached over HTTP rather than WebSocket, which
avoids WebSocket timeout issues. It is either a Hugging Face Space passed with
``--space`` or, by default, a local uvicorn server that this script starts on
127.0.0.1:7860.

Environment variables:
    API_BASE_URL  LLM endpoint (default: https://integrate.api.nvidia.com/v1)
    MODEL_NAME    Model identifier (default: google/gemma-4-31b-it)
    HF_TOKEN      API key for the LLM (a placeholder key is used when unset)
"""
from __future__ import annotations
import argparse
import asyncio
import json
import os
import sys
import time
from typing import Any, Dict, List, Optional
from openai import OpenAI
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# LLM endpoint and model name, overridable through environment variables.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://integrate.api.nvidia.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "google/gemma-4-31b-it")
# API key for the LLM endpoint; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): read but not referenced elsewhere in this script
# (presumably a local Docker image name) -- confirm it is still needed.
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")
# Maximum LLM turns per episode before it ends with error "max_steps".
MAX_STEPS = 50
# Once the summed length of all message contents exceeds this many
# characters, older turns are compressed into a summary message.
CONTEXT_CHAR_LIMIT = 100_000
# System prompt sent as the first chat message of every episode (run_episode
# places it at messages[0]). It names the tools the model is expected to call
# and the audit workflow it should follow.
# NOTE(review): the tool list is introduced as "call them in this order" and
# places classify_system second, but CRITICAL RULES says to investigate before
# classifying. The two instructions conflict; confirm which order is intended
# before editing the text (any change alters baseline behavior).
SYSTEM_PROMPT = """You are an expert EU AI Act compliance auditor. You must investigate AI systems and determine their compliance status.
# MISSION
Audit the AI system, classify its risk level, identify all compliance violations, recommend remediation, and submit your final determination.
# TOOLS (call them in this order)
## Investigation (gather evidence)
- get_system_overview: ALWAYS call this first — understand what you're auditing
- classify_system: Classify risk level (prohibited/high_risk/limited_risk/minimal_risk)
- check_documentation: Review Annex IV technical documentation
- audit_training_data: Check for bias, data governance (Article 10)
- verify_human_oversight: Verify Article 14 human-in-the-loop
- check_transparency: Check Article 50 transparency obligations
- assess_risk_management: Review risk management system (Article 9)
- check_logging: Verify automatic logging (Article 12)
## Resolution (after investigation)
- submit_finding: Report each violation found (call multiple times if needed)
- recommend_fix: Propose remediation for each finding
- verify_compliance: FINAL — submit your overall compliance determination
# CRITICAL RULES
- ALWAYS call get_system_overview FIRST
- INVESTIGATE before CLASSIFYING — gather evidence before judging
- For PROHIBITED systems: classify as prohibited, submit finding, recommend immediate shutdown
- For HIGH-RISK: check ALL articles (documentation, data, oversight, transparency, risk, logging)
- Call submit_finding for EACH violation separately
- Call verify_compliance LAST with your final risk_classification
"""
# ---------------------------------------------------------------------------
# Tool conversion for OpenAI function calling
# ---------------------------------------------------------------------------
def mcp_tools_to_openai(tools: List[Dict]) -> List[Dict]:
"""Convert MCP tool schemas to OpenAI function-calling format."""
openai_tools = []
for tool in tools:
name = tool.get("name", "")
description = tool.get("description", "")
schema = tool.get("inputSchema", {})
properties = {}
required = []
if schema and "properties" in schema:
for pname, pschema in schema["properties"].items():
prop = {"type": pschema.get("type", "string")}
if "description" in pschema:
prop["description"] = pschema["description"]
if "enum" in pschema:
prop["enum"] = pschema["enum"]
properties[pname] = prop
required = schema.get("required", [])
openai_tools.append({
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": {
"type": "object",
"properties": properties,
"required": required,
},
},
})
return openai_tools
# ---------------------------------------------------------------------------
# Context management
# ---------------------------------------------------------------------------
def _summarize_tool_result(content: str, max_chars: int = 200) -> str:
if not content or len(content) <= max_chars:
return content or "(empty)"
try:
data = json.loads(content)
if "error" in data:
return f"error: {data['error'][:100]}"
return json.dumps(data)[:max_chars] + "..."
except (json.JSONDecodeError, TypeError):
return content[:max_chars] + "..."
# First line of every compressed-history message; also used to recognize an
# earlier summary so its lines can be carried forward.
_SUMMARY_HEADER = "Previous audit steps:"


def summarize_old_messages(messages: List[Dict], keep_recent: int = 12) -> List[Dict]:
    """Compress old tool calls to stay within context limits.

    When the summed length of all message contents exceeds
    CONTEXT_CHAR_LIMIT, every message between the first two (system prompt
    and initial task message) and the last *keep_recent* is replaced by one
    user message. That message lists each old tool call with a short preview
    of its result. Lines of an earlier summary are carried forward, so
    repeated compressions do not lose history.

    Args:
        messages: chat history; messages[0] is the system prompt and
            messages[1] the initial user message (as built by run_episode).
        keep_recent: number of trailing messages kept verbatim. The tail is
            extended by one message when it would otherwise start with a
            ``tool`` reply.

    Returns:
        *messages* itself when no compression is needed or possible,
        otherwise a new list.
    """
    total = sum(len(str(m.get("content", ""))) for m in messages)
    if total <= CONTEXT_CHAR_LIMIT:
        return messages

    split_idx = max(2, len(messages) - keep_recent)
    # The chat API rejects a "tool" message that does not follow its
    # assistant tool_calls message, so never start the kept tail with one.
    while 2 < split_idx < len(messages) and messages[split_idx].get("role") == "tool":
        split_idx -= 1

    old = messages[2:split_idx]
    if not old:
        # Nothing old enough to fold; an empty summary would only add noise.
        return messages
    recent = messages[split_idx:]

    lines = [_SUMMARY_HEADER]
    i = 0
    while i < len(old):
        msg = old[i]
        role = msg.get("role")
        if role == "assistant" and msg.get("tool_calls"):
            fn = msg["tool_calls"][0]["function"]
            # arguments may be None (stored as-is on the malformed-JSON path).
            args = str(fn.get("arguments") or "")[:60]
            result = "(no response)"
            if i + 1 < len(old) and old[i + 1].get("role") == "tool":
                result = _summarize_tool_result(old[i + 1].get("content", ""))
                i += 1
            lines.append(f"- {fn['name']}({args}) -> {result}")
        elif role == "user" and str(msg.get("content") or "").startswith(_SUMMARY_HEADER):
            # Keep the step lines of a previous summary (minus its header).
            lines.extend(str(msg["content"]).splitlines()[1:])
        i += 1

    return [messages[0], messages[1], {"role": "user", "content": "\n".join(lines)}] + recent
# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------
# Tool results longer than this many characters are truncated before being
# added to the chat history.
_TOOL_RESULT_CHAR_LIMIT = 3000


async def _call_llm_with_retry(
    llm_client: OpenAI,
    model: str,
    messages: List[Dict],
    tools: List[Dict],
    attempts: int = 4,
):
    """Request one chat completion, retrying with backoff on rate limits.

    Returns the completion response, or None when a non-rate-limit error
    occurs (logged as a [DEBUG] line) or every attempt was rate limited.
    """
    from openai import RateLimitError

    for attempt in range(attempts):
        try:
            return llm_client.chat.completions.create(
                model=model,
                messages=messages,
                tools=tools,
                tool_choice="auto",
                temperature=0.1,
                max_tokens=500,
            )
        except Exception as e:
            text = str(e)
            lowered = text.lower()
            # A bare "rate" substring would also match e.g. "generate".
            rate_limited = (
                isinstance(e, RateLimitError)
                or "429" in text
                or "rate limit" in lowered
                or "rate_limit" in lowered
            )
            if not rate_limited:
                print(f"[DEBUG] LLM error: {text[:100]}", flush=True)
                return None
            if attempt + 1 < attempts:
                # Backoff of 2, 3, 5 ... seconds. asyncio.sleep keeps the
                # event loop running instead of blocking it like time.sleep.
                await asyncio.sleep(2 ** attempt + 1)
    return None


def _parse_tool_args(raw_args: str) -> Optional[Dict[str, Any]]:
    """Decode tool-call arguments.

    Returns ``{}`` for empty or ``null`` arguments, the decoded dict for a
    JSON object, and None for anything else (invalid JSON, lists, scalars).
    """
    if not raw_args.strip():
        return {}
    try:
        parsed = json.loads(raw_args)
    except json.JSONDecodeError:
        return None
    if parsed is None:
        return {}
    return parsed if isinstance(parsed, dict) else None


def _parse_tool_outcome(result_text: str) -> tuple[bool, float]:
    """Extract ``(done, reward)`` from a JSON-object tool result.

    Anything that is not a JSON object yields ``(False, 0.0)``, as does an
    unparseable "reward" value.
    """
    if not result_text:
        return False, 0.0
    try:
        parsed = json.loads(result_text)
    except (json.JSONDecodeError, TypeError):
        return False, 0.0
    if not isinstance(parsed, dict):
        return False, 0.0
    done = bool(parsed.get("done"))
    reward = 0.0
    if "reward" in parsed:
        try:
            reward = float(parsed["reward"])
        except (TypeError, ValueError):
            pass
    return done, reward


async def run_episode(
    env,
    llm_client: OpenAI,
    model: str,
    tools: List[Dict],
    difficulty: str = "medium",
    scenario_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Run a single compliance audit episode using OpenAI function calling.

    Resets *env* (to *scenario_id* when given). It then alternates LLM turns
    and tool calls until a tool result reports done, the LLM call fails, or
    MAX_STEPS turns elapse. [START]/[STEP]/[END] lines are printed to stdout.

    Args:
        env: async env client; ``reset(**kwargs)`` and
            ``call_tool(name, **kwargs)`` are awaited.
        llm_client: OpenAI-compatible client used for chat completions.
        model: model identifier passed to the LLM.
        tools: OpenAI function-calling tool specs.
        difficulty: difficulty passed to ``env.reset``.
        scenario_id: optional scenario id passed to ``env.reset``; also used
            as the task name in log lines.

    Returns:
        ``{"reward", "steps"}`` when the episode finishes. Otherwise it also
        includes ``"error"`` (``"LLM failed"`` or ``"max_steps"``), with
        reward 0.01.
    """
    reset_kwargs: Dict[str, Any] = {"difficulty": difficulty}
    if scenario_id:
        reset_kwargs["scenario_id"] = scenario_id
    reset_result = await env.reset(**reset_kwargs)

    task_name = scenario_id or f"{difficulty}_episode"
    print(f"[START] task={task_name} env=compliance_auditor_env model={model}", flush=True)

    alert_msg = reset_result.get("message", "Compliance audit assigned. Call get_system_overview to begin.")
    messages: List[Dict] = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": alert_msg},
    ]

    step_count = 0
    consecutive_text = 0
    while step_count < MAX_STEPS:
        step_count += 1

        response = await _call_llm_with_retry(llm_client, model, messages, tools)
        if response is None:
            print(f"[END] task={task_name} score=0.01 steps={step_count}", flush=True)
            return {"reward": 0.01, "error": "LLM failed", "steps": step_count}

        message = response.choices[0].message

        if message.tool_calls:
            consecutive_text = 0
            # Only the first tool call of a turn is executed and recorded.
            tc = message.tool_calls[0]
            tool_name = tc.function.name
            tool_call_id = tc.id
            # Never echo None back as "arguments"; the API expects a string.
            raw_args = tc.function.arguments or ""
            messages.append({"role": "assistant", "content": None, "tool_calls": [
                {"id": tool_call_id, "type": "function", "function": {"name": tool_name, "arguments": raw_args}}
            ]})

            tool_args = _parse_tool_args(raw_args)
            if tool_args is None:
                messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": "Error: malformed JSON. Retry."})
                continue

            try:
                result = await env.call_tool(tool_name, **tool_args)
            except Exception as e:
                # Report tool failures to the model instead of aborting.
                result = json.dumps({"error": str(e)})
            if isinstance(result, str):
                result_text = result
            else:
                result_text = json.dumps(result) if result else ""

            done, reward = _parse_tool_outcome(result_text)
            # The env client may also expose done/reward as attributes.
            if getattr(env, "_last_done", False):
                done = True
            last_reward = getattr(env, "_last_reward", None)
            if last_reward:
                reward = max(reward, last_reward)

            # Reported scores are clamped into [0.01, 0.99].
            safe_reward = max(0.01, min(0.99, reward))
            print(
                f"[STEP] step={step_count} action={tool_name} reward={safe_reward:.2f} "
                f"done={'true' if done else 'false'} error=null",
                flush=True,
            )
            if done:
                print(f"[END] task={task_name} score={safe_reward:.2f} steps={step_count}", flush=True)
                return {"reward": reward, "steps": step_count}

            if len(result_text) > _TOOL_RESULT_CHAR_LIMIT:
                result_text = result_text[:_TOOL_RESULT_CHAR_LIMIT] + "\n...(truncated)"
            messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": result_text or "No result"})
            messages = summarize_old_messages(messages)

        elif message.content:
            consecutive_text += 1
            messages.append({"role": "assistant", "content": message.content})
            if consecutive_text >= 3:
                nudge = "You MUST call verify_compliance NOW with your best assessment."
            else:
                nudge = "Please use one of the available tools."
            messages.append({"role": "user", "content": nudge})
        # An empty reply (no tool call, no text) just consumes a step; the
        # same history is sent again on the next iteration.

    print(f"[END] task={task_name} score=0.01 steps={MAX_STEPS}", flush=True)
    return {"reward": 0.01, "error": "max_steps", "steps": MAX_STEPS}
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
# Scenario ids run for each difficulty level by async_main, in this order.
BASELINE_SCENARIOS = dict(
    easy=[
        "easy_chatbot_transparency_001",
        "easy_recommendation_minimal_001",
    ],
    medium=[
        "medium_hiring_bias_001",
        "medium_credit_scoring_001",
        "medium_medical_triage_001",
    ],
    hard=[
        "hard_social_scoring_prohibited_001",
        "hard_deepfake_generation_001",
        "hard_multi_system_corporate_001",
    ],
)
# Address, port and startup deadline (seconds) for the locally spawned
# environment server, used when --space is not given.
_LOCAL_HOST = "127.0.0.1"
_LOCAL_PORT = 7860
_SERVER_STARTUP_TIMEOUT = 30.0


def _start_local_server(host: str, port: int):
    """Launch ``uvicorn server.app:app`` on *host*:*port* as a child process."""
    import subprocess

    return subprocess.Popen(
        [sys.executable, "-m", "uvicorn", "server.app:app", "--host", host, "--port", str(port)],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )


async def _wait_for_server(proc, host: str, port: int, timeout: float) -> None:
    """Wait until *host*:*port* accepts TCP connections.

    Returns as soon as the port is reachable. It gives up, with a [DEBUG]
    line, if *proc* exits or *timeout* seconds pass; the first env call then
    surfaces the failure.
    """
    import socket

    deadline = time.monotonic() + timeout
    while True:
        try:
            with socket.create_connection((host, port), timeout=0.5):
                return
        except OSError:
            pass
        if proc.poll() is not None:
            print(f"[DEBUG] Local server exited early (code {proc.returncode}).", flush=True)
            return
        if time.monotonic() >= deadline:
            print(f"[DEBUG] Local server not reachable after {timeout:.0f}s.", flush=True)
            return
        await asyncio.sleep(0.25)


def _stop_server(proc) -> None:
    """Terminate *proc* and reap it, killing it if it has not exited within 5 s."""
    import subprocess

    proc.terminate()
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


def _print_summary(model: str, all_results: Dict[str, Dict[str, Any]]) -> None:
    """Print each run's clamped score and the mean over all runs."""
    print(f"\n{'='*60}", flush=True)
    print(f"BASELINE RESULTS — {model}", flush=True)
    scores = []
    for sid, r in all_results.items():
        score = max(0.01, min(0.99, r.get("reward", 0)))
        scores.append(score)
        print(f" {sid}: {score:.4f} ({r.get('steps', 0)} steps)", flush=True)
    if scores:
        print(f" OVERALL: {sum(scores) / len(scores):.4f}", flush=True)


async def async_main() -> None:
    """Command-line entry point: run the baseline scenarios and print a summary.

    Options are --difficulty (restrict to one level), --episodes (runs per
    scenario), --model (overrides MODEL_NAME) and --space (remote env URL).
    Without --space, a local uvicorn server is started and stopped around
    the run.
    """
    parser = argparse.ArgumentParser(description="EU AI Act Compliance Auditor Inference")
    parser.add_argument("--difficulty", default=None, choices=["easy", "medium", "hard"])
    parser.add_argument("--episodes", type=int, default=1)
    parser.add_argument("--model", default=None)
    parser.add_argument("--space", default=None, help="HF Space URL")
    args = parser.parse_args()

    api_key = HF_TOKEN
    if not api_key:
        print("[DEBUG] No HF_TOKEN set. Using dummy key.", flush=True)
        api_key = "dummy"
    model = args.model or MODEL_NAME
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=api_key)

    # Remote Space when given; otherwise the local server started below
    # (same address uvicorn binds to).
    base_url = args.space or f"http://{_LOCAL_HOST}:{_LOCAL_PORT}"

    from client import ComplianceAuditorHTTP

    difficulties = [args.difficulty] if args.difficulty else ["easy", "medium", "hard"]

    server_proc = None if args.space else _start_local_server(_LOCAL_HOST, _LOCAL_PORT)
    try:
        # Waiting inside the try guarantees the server is stopped even if
        # startup is interrupted.
        if server_proc is not None:
            await _wait_for_server(server_proc, _LOCAL_HOST, _LOCAL_PORT, _SERVER_STARTUP_TIMEOUT)

        # Discover tools
        async with ComplianceAuditorHTTP(base_url=base_url) as discover_env:
            await discover_env.reset(difficulty="easy")
            tools_raw = await discover_env.list_tools()
        tools = mcp_tools_to_openai(tools_raw)

        print(f"[DEBUG] Mode: {'remote' if args.space else 'local'} | Model: {model}", flush=True)
        print(f"[DEBUG] Tools: {[t['function']['name'] for t in tools]}", flush=True)
        print(f"[DEBUG] Difficulties: {difficulties}", flush=True)

        all_results: Dict[str, Dict[str, Any]] = {}
        for difficulty in difficulties:
            for sid in BASELINE_SCENARIOS.get(difficulty, []):
                for run in range(args.episodes):
                    # One summary entry per run; keying by scenario id alone
                    # let later runs overwrite earlier ones.
                    key = sid if args.episodes == 1 else f"{sid}#{run + 1}"
                    try:
                        async with ComplianceAuditorHTTP(base_url=base_url) as ep_env:
                            result = await run_episode(ep_env, llm_client, model, tools, difficulty, sid)
                    except Exception as e:
                        print(f"[START] task={sid} env=compliance_auditor_env model={model}", flush=True)
                        print(f"[END] task={sid} score=0.01 steps=0", flush=True)
                        result = {"reward": 0.01, "error": str(e)[:100], "steps": 0}
                    all_results[key] = result

        _print_summary(model, all_results)
    finally:
        if server_proc is not None:
            _stop_server(server_proc)
def main() -> None:
    """Synchronous entry point that runs the async CLI to completion."""
    entry = async_main()
    asyncio.run(entry)


if __name__ == "__main__":  # only when executed as a script, not on import
    main()