Itachi-1824 commited on
Commit
9c88cc5
·
1 Parent(s): e834bf5

fix: typed Action model, OPENAI_API_KEY support, proper spec compliance

Browse files
Files changed (3) hide show
  1. inference.py +1 -1
  2. models.py +15 -3
  3. server/app.py +3 -3
inference.py CHANGED
@@ -28,7 +28,7 @@ from openai import OpenAI
28
 
29
  API_BASE_URL = os.getenv("API_BASE_URL", "https://integrate.api.nvidia.com/v1")
30
  MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-4-31b-it")
31
- HF_TOKEN = os.getenv("HF_TOKEN")
32
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
33
 
34
  MAX_STEPS = 50
 
28
 
29
  API_BASE_URL = os.getenv("API_BASE_URL", "https://integrate.api.nvidia.com/v1")
30
  MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-4-31b-it")
31
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
32
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
33
 
34
  MAX_STEPS = 50
models.py CHANGED
@@ -1,16 +1,28 @@
1
  """
2
  Data models for the EU AI Act Compliance Auditor Environment.
3
 
4
- MCP pattern: no custom Action class needed the framework provides CallToolAction.
5
- We define Observation and State for typed environment responses.
6
  """
7
 
8
  from typing import Any, Dict, List, Optional
9
 
10
  from pydantic import BaseModel, Field
11
 
 
 
 
 
 
12
 
13
- class ComplianceObservation(BaseModel):
 
 
 
 
 
 
 
 
14
  """Observation returned after each environment interaction."""
15
 
16
  done: bool = False
 
1
  """
2
  Data models for the EU AI Act Compliance Auditor Environment.
3
 
4
+ Typed Action, Observation, and State models for OpenEnv spec compliance.
 
5
  """
6
 
7
  from typing import Any, Dict, List, Optional
8
 
9
  from pydantic import BaseModel, Field
10
 
11
+ try:
12
+ from openenv.core.env_server.types import Action, Observation
13
+ except ImportError:
14
+ from pydantic import BaseModel as Action
15
+ from pydantic import BaseModel as Observation
16
 
17
+
18
+ class ComplianceAction(Action):
19
+ """Action for the Compliance Auditor — an MCP tool call."""
20
+
21
+ tool_name: str = Field(default="", description="Name of the audit tool to call")
22
+ arguments: Dict[str, Any] = Field(default_factory=dict, description="Tool arguments as JSON")
23
+
24
+
25
+ class ComplianceObservation(Observation):
26
  """Observation returned after each environment interaction."""
27
 
28
  done: bool = False
server/app.py CHANGED
@@ -17,8 +17,8 @@ from typing import Any, Dict, Optional
17
  from fastapi import Body, HTTPException
18
  from pydantic import BaseModel
19
  from openenv.core.env_server import create_app
20
- from openenv.core.env_server.types import Action, Observation
21
 
 
22
  from server.environment import ComplianceAuditorEnvironment, QUERY_BUDGET
23
  from scenarios.registry import SCENARIO_LIST, DIFFICULTY_TIERS
24
 
@@ -26,8 +26,8 @@ from scenarios.registry import SCENARIO_LIST, DIFFICULTY_TIERS
26
 
27
  app = create_app(
28
  ComplianceAuditorEnvironment,
29
- Action,
30
- Observation,
31
  env_name="compliance_auditor_env",
32
  max_concurrent_envs=5,
33
  )
 
17
  from fastapi import Body, HTTPException
18
  from pydantic import BaseModel
19
  from openenv.core.env_server import create_app
 
20
 
21
+ from models import ComplianceAction, ComplianceObservation
22
  from server.environment import ComplianceAuditorEnvironment, QUERY_BUDGET
23
  from scenarios.registry import SCENARIO_LIST, DIFFICULTY_TIERS
24
 
 
26
 
27
  app = create_app(
28
  ComplianceAuditorEnvironment,
29
+ ComplianceAction,
30
+ ComplianceObservation,
31
  env_name="compliance_auditor_env",
32
  max_concurrent_envs=5,
33
  )