File size: 6,457 Bytes
b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e 9c88cc5 b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 9c88cc5 b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e 2dc4e1e b4d7ce3 2dc4e1e b4d7ce3 2dc4e1e b4d7ce3 2dc4e1e b4d7ce3 2dc4e1e b4d7ce3 2dc4e1e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 5d5e37e b2cf922 51a14c0 5d5e37e b2cf922 5d5e37e b2cf922 a0d8e70 a2bdc87 b2cf922 a2bdc87 b2cf922 5d5e37e b2cf922 5d5e37e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 | """FastAPI application for EU AI Act Compliance Auditor.
Architecture (modeled on Maverick98's winning pattern):
- create_app() for standard OpenEnv endpoints (/reset, /step, /state, /health, /ws)
- Custom HTTP session API (/api/reset, /api/call_tool, /api/close)
- Custom Gradio landing mounted at '/' replacing default inspector
"""
import inspect
import json
import uuid
import asyncio
import uvicorn
from typing import Any, Dict, Optional
from fastapi import Body, HTTPException
from pydantic import BaseModel
from openenv.core.env_server import create_app
from models import ComplianceAction, ComplianceObservation
from server.environment import ComplianceAuditorEnvironment, QUERY_BUDGET
from scenarios.registry import SCENARIO_LIST, DIFFICULTY_TIERS
# ── Create base OpenEnv app ──────────────────────────────────────
# create_app provides the standard OpenEnv routes (/reset, /step, /state,
# /health, /ws). NOTE(review): max_concurrent_envs presumably caps the
# environments managed by those routes only; the /api/* session endpoints
# in this module construct ComplianceAuditorEnvironment directly and are not
# limited by it — confirm against openenv's create_app.
_ENV_NAME = "compliance_auditor_env"
app = create_app(
    ComplianceAuditorEnvironment,
    ComplianceAction,
    ComplianceObservation,
    env_name=_ENV_NAME,
    max_concurrent_envs=5,
)
# ── /tasks endpoint (hackathon validator) ───────────────────────
@app.get("/tasks")
def list_tasks():
    """Expose the scenario registry's task list for the hackathon validator."""
    task_catalog = SCENARIO_LIST
    return dict(tasks=task_catalog)
# ── HTTP Session API ────────────────────────────────────────────
# Live sessions created by /api/reset, keyed by the UUID string returned to
# the client. Entries are removed only by /api/close.
_sessions: Dict[str, ComplianceAuditorEnvironment] = dict()
# Serializes access to _sessions across concurrent requests.
# NOTE(review): on Python < 3.10 an asyncio.Lock created at import time binds
# to the event loop current at that moment, which may not be the loop uvicorn
# runs on; confirm the deployment's Python version.
_session_lock: asyncio.Lock = asyncio.Lock()
class GraderBody(BaseModel):
    """Request payload for POST /grader: how to pick the scenario, plus the agent's submissions."""

    # Scenario selection: a scenario id or difficulty tier, with an optional seed.
    task_id: str = "easy"
    episode_id: Optional[str] = None  # accepted, but not read by the /grader handler
    seed: Optional[int] = None

    # Submissions forwarded to compute_reward. The `[]` defaults are not shared
    # between requests: pydantic copies field defaults per model instance.
    classification: str = ""
    findings: list = []
    remediation: list = []
    tool_sequence: list = []
    steps_taken: int = 10
@app.post("/grader")
async def grader_endpoint(body: GraderBody):
    """Grade a completed episode. Returns score in [0.001, 0.999].

    The scenario is resolved with ``get_scenario(body.task_id, body.seed)``.
    If that raises ValueError, ``task_id`` is looked up in a fixed
    difficulty-tier table ("easy" / "medium" / "hard" -> a representative
    scenario id). Any other value falls back to the easy scenario.
    ``body.episode_id`` is accepted but not used here.

    Returns:
        ``{"score": breakdown.total(), "breakdown": breakdown.to_dict()}``
        where ``breakdown`` is the result of ``compute_reward``.

    Raises:
        HTTPException: 400 when even the fallback scenario cannot be resolved.
    """
    from server.engine import compute_reward
    from scenarios.registry import get_scenario

    # Used only when get_scenario() rejects task_id outright.
    tier_fallbacks = {
        "easy": "easy_chatbot_transparency_001",
        "medium": "medium_hiring_bias_001",
        "hard": "hard_social_scoring_prohibited_001",
    }
    try:
        sc = get_scenario(body.task_id, body.seed)
    except ValueError:
        fallback_id = tier_fallbacks.get(body.task_id, tier_fallbacks["easy"])
        try:
            sc = get_scenario(fallback_id, body.seed)
        except ValueError as err:
            # Report an unresolvable task as a client error rather than letting
            # the ValueError escape as an unhandled 500.
            raise HTTPException(
                400,
                f"Could not resolve a scenario for task_id={body.task_id!r}: {err}",
            ) from err

    breakdown = compute_reward(
        scenario=sc,
        classification_submitted=body.classification,
        findings_submitted=body.findings,
        remediation_submitted=body.remediation,
        tool_sequence=body.tool_sequence,
        steps_taken=body.steps_taken,
    )
    return {"score": breakdown.total(), "breakdown": breakdown.to_dict()}
class ResetBody(BaseModel):
    """Request payload for POST /api/reset; every field has a default."""

    difficulty: str = "medium"  # passed to env.reset(difficulty=...)
    scenario_id: Optional[str] = None  # passed to env.reset(scenario_id=...)
    seed: Optional[int] = None  # passed to env.reset(seed=...)
class CallToolBody(BaseModel):
    """Request payload for POST /api/call_tool."""

    session_id: str  # id handed out by /api/reset
    tool_name: str  # key into the session environment's _tool_fns
    arguments: Dict[str, Any] = {}  # expanded as keyword arguments to the tool
class CloseBody(BaseModel):
    """Request payload for POST /api/close."""

    session_id: str  # id handed out by /api/reset
# JSON Schema types advertised for tool parameters, matched against each
# parameter's annotation. Anything not listed here (including unannotated
# parameters and container types) is advertised as "string".
_JSON_SCHEMA_TYPES = (("integer", int), ("number", float), ("boolean", bool))


def _json_schema_type(annotation: Any) -> str:
    """Return the JSON Schema type name for a tool parameter's annotation."""
    for schema_type, py_type in _JSON_SCHEMA_TYPES:
        # Postponed annotations (``from __future__ import annotations``) arrive
        # as the bare type name, so match that spelling as well.
        if annotation is py_type or (
            isinstance(annotation, str) and annotation == py_type.__name__
        ):
            return schema_type
    return "string"


def _build_tool_schemas(env) -> list:
    """Describe each entry of ``env._tool_fns`` as a name/description/inputSchema dict."""
    tools = []
    for name, fn in env._tool_fns.items():
        props = {}
        required = []
        for pname, param in inspect.signature(fn).parameters.items():
            # *args / **kwargs have no fixed name, so they cannot be JSON properties.
            if param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue
            props[pname] = {"type": _json_schema_type(param.annotation)}
            if param.default is inspect.Parameter.empty:
                required.append(pname)
        tools.append({
            "name": name,
            # First docstring line only; empty string when the tool has no docstring.
            "description": (fn.__doc__ or "").strip().split("\n")[0],
            "inputSchema": {"type": "object", "properties": props, "required": required},
        })
    return tools


@app.post("/api/reset")
async def api_reset(body: ResetBody = Body(default_factory=ResetBody)):
    """Create a session, reset its environment, and return session_id + tools + observation.

    The request body is optional; when it is omitted, ResetBody's defaults apply.
    The full response is built before the session is registered, so a failure
    while building it cannot leave an unreachable session in ``_sessions``.
    Sessions are removed only by /api/close; there is no expiry in this module.
    """
    env = ComplianceAuditorEnvironment()
    obs = env.reset(
        seed=body.seed,
        difficulty=body.difficulty,
        scenario_id=body.scenario_id,
    )
    session_id = str(uuid.uuid4())
    response = {
        "session_id": session_id,
        "observation": getattr(obs, "metadata", {}),
        "done": obs.done,
        "reward": obs.reward,
        "tools": _build_tool_schemas(env),
    }
    async with _session_lock:
        _sessions[session_id] = env
    return response
@app.post("/api/call_tool")
async def api_call_tool(body: CallToolBody):
    """Invoke one tool on a live session and report the episode status.

    Raises HTTP 404 for an unknown session_id and HTTP 400 for an unknown tool
    name. Exceptions raised by the tool itself are not propagated: they are
    returned as a JSON string ``{"error": "..."}`` in ``result``, alongside the
    environment's current ``done`` / ``reward``.
    """
    async with _session_lock:
        env = _sessions.get(body.session_id)
    if env is None:
        raise HTTPException(404, f"Session not found: {body.session_id}")

    tool = env._tool_fns.get(body.tool_name)
    if tool is None:
        available = list(env._tool_fns.keys())
        raise HTTPException(400, f"Tool not found: {body.tool_name}. Available: {available}")

    try:
        result = tool(**body.arguments)
    except Exception as exc:
        # Tool failures (including bad keyword arguments) go back to the agent
        # as data rather than as an HTTP error.
        result = json.dumps({"error": str(exc)})
    return {"result": result, "done": env._done, "reward": env._reward}
@app.post("/api/close")
async def api_close(body: CloseBody):
    """Close and clean up a session.

    The session is removed from the registry under the lock, and its
    environment is then closed outside the lock. Unknown or already-closed
    session ids get the same ``{"closed": True, ...}`` response.
    """
    async with _session_lock:
        env = _sessions.pop(body.session_id, None)
    # Compare against None explicitly: a registered environment must always be
    # closed, even if the object happens to define a falsy __bool__/__len__.
    if env is not None:
        env.close()
    return {"closed": True, "session_id": body.session_id}
# ── Mount Gradio landing at '/' ─────────────────────────────────
# Best-effort: any exception here (e.g. gradio missing, landing build error)
# is printed with its traceback to stderr, and `app` stays bound to the base
# OpenEnv app, so the API keeps serving without the landing page.
try:
    import gradio as gr
    from server.gradio_landing import create_landing_app
    _landing = create_landing_app()
    # Mount at / — exactly like Maverick98's working pattern.
    # Rebinding the module-level `app` makes the mounted app the one main() serves.
    app = gr.mount_gradio_app(app, _landing, path="/")
    # NOTE(review): the "β" in this log message looks like a mis-encoded dash;
    # confirm against the original file.
    print(f"[gradio_landing] mounted at / β gradio {gr.__version__}", flush=True)
except Exception as e:
    import sys
    import traceback
    print(f"[gradio_landing] MOUNT FAILED: {e}", file=sys.stderr, flush=True)
    traceback.print_exc(file=sys.stderr)
# ── Entry point ─────────────────────────────────────────────────
def main(host: str = "0.0.0.0", port: int = 7860):
    """Serve the module-level ``app`` with uvicorn on ``host``:``port``.

    The WebSocket ping interval and timeout are both passed as None, which
    uvicorn/websockets treat as disabled (no keep-alive pings, no pong timeout).
    """
    server_options = {
        "host": host,
        "port": port,
        "ws_ping_interval": None,
        "ws_ping_timeout": None,
    }
    uvicorn.run(app, **server_options)


if __name__ == "__main__":
    main()
|