Spaces:
Sleeping
Sleeping
| """ | |
| Baseline inference for EU AI Act Compliance Auditor. | |
| Uses OpenAI function calling through NVIDIA NIM to audit AI systems. | |
| Connects to the live HF Space via HTTP (no WebSocket timeout issues). | |
| Required env vars: | |
| API_BASE_URL LLM endpoint (default: https://integrate.api.nvidia.com/v1) | |
| MODEL_NAME Model identifier (default: google/gemma-4-31b-it) | |
| HF_TOKEN API key for the LLM | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import asyncio | |
| import json | |
| import os | |
| import sys | |
| import time | |
| from typing import Any, Dict, List, Optional | |
| from openai import OpenAI | |
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
| API_BASE_URL = os.getenv("API_BASE_URL", "https://integrate.api.nvidia.com/v1") | |
| MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-4-31b-it") | |
| HF_TOKEN = os.getenv("HF_TOKEN") | |
| LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") | |
| MAX_STEPS = 50 | |
| CONTEXT_CHAR_LIMIT = 100000 | |
| SYSTEM_PROMPT = """You are an expert EU AI Act compliance auditor. You must investigate AI systems and determine their compliance status. | |
| # MISSION | |
| Audit the AI system, classify its risk level, identify all compliance violations, recommend remediation, and submit your final determination. | |
| # TOOLS (call them in this order) | |
| ## Investigation (gather evidence) | |
| - get_system_overview: ALWAYS call this first — understand what you're auditing | |
| - classify_system: Classify risk level (prohibited/high_risk/limited_risk/minimal_risk) | |
| - check_documentation: Review Annex IV technical documentation | |
| - audit_training_data: Check for bias, data governance (Article 10) | |
| - verify_human_oversight: Verify Article 14 human-in-the-loop | |
| - check_transparency: Check Article 50 transparency obligations | |
| - assess_risk_management: Review risk management system (Article 9) | |
| - check_logging: Verify automatic logging (Article 12) | |
| ## Resolution (after investigation) | |
| - submit_finding: Report each violation found (call multiple times if needed) | |
| - recommend_fix: Propose remediation for each finding | |
| - verify_compliance: FINAL — submit your overall compliance determination | |
| # CRITICAL RULES | |
| - ALWAYS call get_system_overview FIRST | |
| - INVESTIGATE before CLASSIFYING — gather evidence before judging | |
| - For PROHIBITED systems: classify as prohibited, submit finding, recommend immediate shutdown | |
| - For HIGH-RISK: check ALL articles (documentation, data, oversight, transparency, risk, logging) | |
| - Call submit_finding for EACH violation separately | |
| - Call verify_compliance LAST with your final risk_classification | |
| """ | |
| # --------------------------------------------------------------------------- | |
| # Tool conversion for OpenAI function calling | |
| # --------------------------------------------------------------------------- | |
| def mcp_tools_to_openai(tools: List[Dict]) -> List[Dict]: | |
| """Convert MCP tool schemas to OpenAI function-calling format.""" | |
| openai_tools = [] | |
| for tool in tools: | |
| name = tool.get("name", "") | |
| description = tool.get("description", "") | |
| schema = tool.get("inputSchema", {}) | |
| properties = {} | |
| required = [] | |
| if schema and "properties" in schema: | |
| for pname, pschema in schema["properties"].items(): | |
| prop = {"type": pschema.get("type", "string")} | |
| if "description" in pschema: | |
| prop["description"] = pschema["description"] | |
| if "enum" in pschema: | |
| prop["enum"] = pschema["enum"] | |
| properties[pname] = prop | |
| required = schema.get("required", []) | |
| openai_tools.append({ | |
| "type": "function", | |
| "function": { | |
| "name": name, | |
| "description": description, | |
| "parameters": { | |
| "type": "object", | |
| "properties": properties, | |
| "required": required, | |
| }, | |
| }, | |
| }) | |
| return openai_tools | |
| # --------------------------------------------------------------------------- | |
| # Context management | |
| # --------------------------------------------------------------------------- | |
| def _summarize_tool_result(content: str, max_chars: int = 200) -> str: | |
| if not content or len(content) <= max_chars: | |
| return content or "(empty)" | |
| try: | |
| data = json.loads(content) | |
| if "error" in data: | |
| return f"error: {data['error'][:100]}" | |
| return json.dumps(data)[:max_chars] + "..." | |
| except (json.JSONDecodeError, TypeError): | |
| return content[:max_chars] + "..." | |
| def summarize_old_messages(messages: List[Dict]) -> List[Dict]: | |
| """Compress old tool calls to stay within context limits.""" | |
| total = sum(len(str(m.get("content", ""))) for m in messages) | |
| if total <= CONTEXT_CHAR_LIMIT: | |
| return messages | |
| system_msg = messages[0] | |
| user_msg = messages[1] | |
| keep_recent = 12 | |
| split_idx = max(2, len(messages) - keep_recent) | |
| old = messages[2:split_idx] | |
| recent = messages[split_idx:] | |
| lines = ["Previous audit steps:"] | |
| i = 0 | |
| while i < len(old): | |
| msg = old[i] | |
| if msg.get("role") == "assistant" and msg.get("tool_calls"): | |
| tc = msg["tool_calls"][0] | |
| name = tc["function"]["name"] | |
| args = tc["function"]["arguments"][:60] | |
| result = "(no response)" | |
| if i + 1 < len(old) and old[i + 1].get("role") == "tool": | |
| result = _summarize_tool_result(old[i + 1].get("content", "")) | |
| i += 1 | |
| lines.append(f"- {name}({args}) -> {result}") | |
| i += 1 | |
| return [system_msg, user_msg, {"role": "user", "content": "\n".join(lines)}] + recent | |
| # --------------------------------------------------------------------------- | |
| # Episode runner | |
| # --------------------------------------------------------------------------- | |
| async def run_episode( | |
| env, | |
| llm_client: OpenAI, | |
| model: str, | |
| tools: List[Dict], | |
| difficulty: str = "medium", | |
| scenario_id: Optional[str] = None, | |
| ) -> Dict[str, Any]: | |
| """Run a single compliance audit episode using OpenAI function calling.""" | |
| reset_kwargs = {"difficulty": difficulty} | |
| if scenario_id: | |
| reset_kwargs["scenario_id"] = scenario_id | |
| reset_result = await env.reset(**reset_kwargs) | |
| task_name = scenario_id or f"{difficulty}_episode" | |
| print(f"[START] task={task_name} env=compliance_auditor_env model={model}", flush=True) | |
| alert_msg = reset_result.get("message", "Compliance audit assigned. Call get_system_overview to begin.") | |
| messages = [ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": alert_msg}, | |
| ] | |
| step_count = 0 | |
| done = False | |
| consecutive_text = 0 | |
| while not done and step_count < MAX_STEPS: | |
| step_count += 1 | |
| # LLM call with retry | |
| response = None | |
| for attempt in range(4): | |
| try: | |
| response = llm_client.chat.completions.create( | |
| model=model, | |
| messages=messages, | |
| tools=tools, | |
| tool_choice="auto", | |
| temperature=0.1, | |
| max_tokens=500, | |
| ) | |
| break | |
| except Exception as e: | |
| if "429" in str(e) or "rate" in str(e).lower(): | |
| wait = 2 ** attempt + 1 | |
| time.sleep(wait) | |
| continue | |
| print(f"[DEBUG] LLM error: {str(e)[:100]}", flush=True) | |
| break | |
| if response is None: | |
| print(f"[END] task={task_name} score=0.01 steps={step_count}", flush=True) | |
| return {"reward": 0.01, "error": "LLM failed", "steps": step_count} | |
| message = response.choices[0].message | |
| # Handle function call | |
| if message.tool_calls: | |
| consecutive_text = 0 | |
| tc = message.tool_calls[0] | |
| tool_name = tc.function.name | |
| tool_call_id = tc.id | |
| try: | |
| tool_args = json.loads(tc.function.arguments) | |
| except (json.JSONDecodeError, TypeError): | |
| messages.append({"role": "assistant", "content": None, "tool_calls": [ | |
| {"id": tool_call_id, "type": "function", "function": {"name": tool_name, "arguments": tc.function.arguments}} | |
| ]}) | |
| messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": "Error: malformed JSON. Retry."}) | |
| continue | |
| # Add to history | |
| messages.append({"role": "assistant", "content": None, "tool_calls": [ | |
| {"id": tool_call_id, "type": "function", "function": {"name": tool_name, "arguments": tc.function.arguments}} | |
| ]}) | |
| # Execute tool via env | |
| try: | |
| result_text = await env.call_tool(tool_name, **tool_args) | |
| except Exception as e: | |
| result_text = json.dumps({"error": str(e)}) | |
| if not isinstance(result_text, str): | |
| result_text = json.dumps(result_text) if result_text else "" | |
| # Check done/reward | |
| reward = 0.0 | |
| if result_text: | |
| try: | |
| parsed = json.loads(result_text) | |
| if parsed.get("done"): | |
| done = True | |
| if "reward" in parsed: | |
| reward = float(parsed["reward"]) | |
| except (json.JSONDecodeError, TypeError): | |
| pass | |
| if hasattr(env, "_last_done") and env._last_done: | |
| done = True | |
| if hasattr(env, "_last_reward") and env._last_reward: | |
| reward = max(reward, env._last_reward) | |
| safe_reward = max(0.01, min(0.99, reward)) | |
| print(f"[STEP] step={step_count} action={tool_name} reward={safe_reward:.2f} done={'true' if done else 'false'} error=null", flush=True) | |
| if done: | |
| final_score = max(0.01, min(0.99, reward)) | |
| print(f"[END] task={task_name} score={final_score:.2f} steps={step_count}", flush=True) | |
| return {"reward": reward, "steps": step_count} | |
| # Add result to history | |
| if len(result_text) > 3000: | |
| result_text = result_text[:3000] + "\n...(truncated)" | |
| messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": result_text or "No result"}) | |
| messages = summarize_old_messages(messages) | |
| elif message.content: | |
| consecutive_text += 1 | |
| messages.append({"role": "assistant", "content": message.content}) | |
| if consecutive_text >= 3: | |
| messages.append({"role": "user", "content": "You MUST call verify_compliance NOW with your best assessment."}) | |
| else: | |
| messages.append({"role": "user", "content": "Please use one of the available tools."}) | |
| else: | |
| continue | |
| print(f"[END] task={task_name} score=0.01 steps={MAX_STEPS}", flush=True) | |
| return {"reward": 0.01, "error": "max_steps", "steps": MAX_STEPS} | |
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
| BASELINE_SCENARIOS = { | |
| "easy": ["easy_chatbot_transparency_001", "easy_recommendation_minimal_001"], | |
| "medium": ["medium_hiring_bias_001", "medium_credit_scoring_001", "medium_medical_triage_001"], | |
| "hard": ["hard_social_scoring_prohibited_001", "hard_deepfake_generation_001", "hard_multi_system_corporate_001"], | |
| } | |
| async def async_main() -> None: | |
| parser = argparse.ArgumentParser(description="EU AI Act Compliance Auditor Inference") | |
| parser.add_argument("--difficulty", default=None, choices=["easy", "medium", "hard"]) | |
| parser.add_argument("--episodes", type=int, default=1) | |
| parser.add_argument("--model", default=None) | |
| parser.add_argument("--space", default=None, help="HF Space URL") | |
| args = parser.parse_args() | |
| api_key = HF_TOKEN | |
| if not api_key: | |
| print("[DEBUG] No HF_TOKEN set. Using dummy key.", flush=True) | |
| api_key = "dummy" | |
| model = args.model or MODEL_NAME | |
| llm_client = OpenAI(base_url=API_BASE_URL, api_key=api_key) | |
| # Determine base URL | |
| if args.space: | |
| base_url = args.space | |
| else: | |
| base_url = "http://localhost:7860" | |
| from client import ComplianceAuditorHTTP | |
| difficulties = [args.difficulty] if args.difficulty else ["easy", "medium", "hard"] | |
| # Start local server if not using Space | |
| server_proc = None | |
| if not args.space: | |
| import subprocess | |
| server_proc = subprocess.Popen( | |
| [sys.executable, "-m", "uvicorn", "server.app:app", "--host", "127.0.0.1", "--port", "7860"], | |
| stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, | |
| ) | |
| time.sleep(4) | |
| try: | |
| # Discover tools | |
| async with ComplianceAuditorHTTP(base_url=base_url) as discover_env: | |
| await discover_env.reset(difficulty="easy") | |
| tools_raw = await discover_env.list_tools() | |
| tools = mcp_tools_to_openai(tools_raw) | |
| print(f"[DEBUG] Mode: {'remote' if args.space else 'local'} | Model: {model}", flush=True) | |
| print(f"[DEBUG] Tools: {[t['function']['name'] for t in tools]}", flush=True) | |
| print(f"[DEBUG] Difficulties: {difficulties}", flush=True) | |
| all_results = {} | |
| for difficulty in difficulties: | |
| scenario_ids = BASELINE_SCENARIOS.get(difficulty, []) | |
| for sid in scenario_ids: | |
| for run in range(args.episodes): | |
| try: | |
| async with ComplianceAuditorHTTP(base_url=base_url) as ep_env: | |
| result = await run_episode(ep_env, llm_client, model, tools, difficulty, sid) | |
| except Exception as e: | |
| print(f"[START] task={sid} env=compliance_auditor_env model={model}", flush=True) | |
| print(f"[END] task={sid} score=0.01 steps=0", flush=True) | |
| result = {"reward": 0.01, "error": str(e)[:100], "steps": 0} | |
| all_results[sid] = result | |
| # Summary | |
| print(f"\n{'='*60}", flush=True) | |
| print(f"BASELINE RESULTS — {model}", flush=True) | |
| for sid, r in all_results.items(): | |
| score = max(0.01, min(0.99, r.get("reward", 0))) | |
| print(f" {sid}: {score:.4f} ({r.get('steps', 0)} steps)", flush=True) | |
| if all_results: | |
| avg = sum(max(0.01, min(0.99, r.get("reward", 0))) for r in all_results.values()) / len(all_results) | |
| print(f" OVERALL: {avg:.4f}", flush=True) | |
| finally: | |
| if server_proc: | |
| server_proc.terminate() | |
| def main() -> None: | |
| asyncio.run(async_main()) | |
| if __name__ == "__main__": | |
| main() | |