Andrew Lara committed on
Commit
fb0cc18
·
1 Parent(s): 98c3ce1

Tighten tool routing and executor safety

README.md CHANGED
@@ -108,6 +108,10 @@ Execute one tool call. Pass `session_id` (from `/reset`) as a query param to sup
108
  { "tool_id": "commit", "answer": "1889" }
109
  ```
110
 
111
  ### `GET /health`
112
 
113
  Returns `{"status": "ok"}`.
@@ -123,6 +127,7 @@ claude_toolOrchestrator/
123
  ├── openenv.yaml # OpenEnv deployment spec
124
  ├── requirements.txt # Python dependencies
125
  ├── .env.example # Key template (copy → .env, never commit .env)
 
126
 
127
  ├── env/ # ── Core RL environment ──────────────────────────
128
  │ ├── environment.py # CostAwareToolEnvironment: reset() + step()
@@ -141,7 +146,7 @@ claude_toolOrchestrator/
141
  │ ├── ceramic_search.py # Web search (Ceramic AI API)
142
  │ ├── wiki_lookup.py # Wikipedia REST API, first paragraph
143
  │ ├── calculator.py # Safe AST-based math evaluator (no exec)
144
- │ ├── code_executor.py # Sandboxed Python exec (blocks os/sys/subprocess)
145
  │ ├── llm_reason.py # Together AI chain-of-thought (graceful fallback)
146
  │ └── commit.py # Answer pass-through; grading runs in environment
147
 
@@ -164,6 +169,15 @@ claude_toolOrchestrator/
164
 
165
  If no Ceramic key is set, `ceramic_search` falls back to deterministic offline results; all other tools work without any key.
166
 
167
  ---
168
 
169
  ## Running baselines
 
108
  { "tool_id": "commit", "answer": "1889" }
109
  ```
110
 
111
+ ### `GET /tools`
112
+
113
+ Returns the canonical tool manifest with each tool's label, purpose, input field, cost, and safety notes.
114
+
115
  ### `GET /health`
116
 
117
  Returns `{"status": "ok"}`.
 
127
  ├── openenv.yaml # OpenEnv deployment spec
128
  ├── requirements.txt # Python dependencies
129
  ├── .env.example # Key template (copy → .env, never commit .env)
130
+ ├── tools/runtime.py # Tool catalog, validation, and explicit dispatch
131
 
132
  ├── env/ # ── Core RL environment ──────────────────────────
133
  │ ├── environment.py # CostAwareToolEnvironment: reset() + step()
 
146
  │ ├── ceramic_search.py # Web search (Ceramic AI API)
147
  │ ├── wiki_lookup.py # Wikipedia REST API, first paragraph
148
  │ ├── calculator.py # Safe AST-based math evaluator (no exec)
149
+ │ ├── code_executor.py # Sandboxed Python exec (blocked imports, dunder attrs)
150
  │ ├── llm_reason.py # Together AI chain-of-thought (graceful fallback)
151
  │ └── commit.py # Answer pass-through; grading runs in environment
152
 
 
169
 
170
  If no Ceramic key is set, `ceramic_search` falls back to deterministic offline results; all other tools work without any key.
171
 
172
+ The Python executor is intentionally narrow:
173
+
174
+ - import statements are rejected
175
+ - obvious sandbox-escape names such as `open`, `eval`, `globals`, and `__import__` are blocked
176
+ - dunder attribute access such as `.__class__` and `.__subclasses__()` is blocked
177
+ - only a curated builtin/module surface is exposed
178
+
179
+ That keeps the tool usable for intended coding tasks without turning it into a hidden general-purpose shell.
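Rules like the ones listed above can be checked on the parsed syntax tree before any code runs. The following is a minimal sketch under stated assumptions: `BLOCKED_NAMES` and `find_violations` are hypothetical names chosen for illustration, not the contents of `code_executor.py`, and a static check of this kind is only one layer of a sandbox.

```python
import ast

# Hypothetical blocklist based on the names mentioned in the README text;
# the project's real list may be longer or shaped differently.
BLOCKED_NAMES = {"open", "eval", "globals", "__import__"}


def find_violations(source: str) -> list[str]:
    """Return the reasons `source` would be rejected (empty list if none).

    Raises SyntaxError if `source` is not valid Python.
    """
    violations = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            # Covers both "import x" and "from x import y".
            violations.append("import statement")
        elif isinstance(node, ast.Name) and node.id in BLOCKED_NAMES:
            violations.append(f"blocked name: {node.id}")
        elif isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            # Catches traversal such as ().__class__.__subclasses__()
            violations.append(f"dunder attribute: {node.attr}")
    return violations
```

A caller would refuse to execute any snippet for which `find_violations` returns a non-empty list, and only then hand the code to `exec()` with a restricted set of builtins.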
180
+
181
  ---
182
 
183
  ## Running baselines
RESEARCH.md CHANGED
@@ -147,7 +147,7 @@ action.query = "William Shakespeare"
147
 
148
  ### `calculator` — Cost: **0.1**
149
 
150
- **What it does:** Evaluates a math expression safely using Python's `ast` (Abstract Syntax Tree) module. The expression is parsed into a tree structure, and only pre-approved operations are allowed (addition, subtraction, multiplication, division, power, modulo, and common math functions like `sqrt`, `log`, `sin`, `cos`).
151
 
152
  **Why not just use `eval()`?** Because `eval("__import__('os').system('rm -rf /')")` would delete your hard drive. The AST approach means the code is never executed — it's parsed into a data structure and we only compute what we explicitly allow.
153
 
@@ -168,9 +168,9 @@ action.expression = "import os" # BLOCKED — not a valid math expression
168
 
169
  ### `code_executor` — Cost: **0.3**
170
 
171
- **What it does:** Runs arbitrary Python code in a sandboxed `exec()` environment. Captures whatever is printed to stdout and returns it as the result.
172
 
173
- **Security model:** Blocks imports of dangerous modules (`os`, `sys`, `subprocess`, `socket`, `shutil`, `pathlib`, `importlib`, `ctypes`, `multiprocessing`, `threading`, and more). Uses a custom `__import__` wrapper that raises an `ImportError` before the module loads.
174
 
175
  **Best for:** HumanEval coding tasks where the agent needs to actually run code to verify correctness.
176
 
@@ -197,6 +197,8 @@ print(fibonacci(10))
197
 
198
  **Graceful fallback:** If `TOGETHER_API_KEY` is not set, returns a clear error message instead of crashing. The agent learns to avoid this tool when it's unavailable.
199
 
 
 
200
  ---
201
 
202
  ### `commit` — Cost: **0.0**
@@ -503,9 +505,9 @@ A well-trained agent should exhibit these behaviors:
503
  CostAwareToolEnv/
504
 
505
  ├── app.py
506
- │ The FastAPI web server. Handles /reset, /step, /health, /web.
507
  │ Multi-session: each /reset returns a session_id used in /step.
508
- Loads dataset + tools once at startup for efficiency.
509
 
510
  ├── openenv.yaml
511
  │ Deployment spec for the OpenEnv competition framework.
@@ -552,11 +554,12 @@ CostAwareToolEnv/
552
  │ Returns flat List[Dict] with 'domain' key on each item.
553
 
554
  ├── tools/
555
- │ ├── __init__.py build_tool_registry() returns {tool_id: callable}
 
556
  │ ├── ceramic_search.py make_search_tool() factory wrapping CeramicClient
557
  │ ├── wiki_lookup.py Wikipedia REST API, first paragraph
558
- │ ├── calculator.py Safe AST-based math eval
559
- │ ├── code_executor.py Sandboxed exec with blocked dangerous imports
560
  │ ├── llm_reason.py Together AI API, graceful fallback
561
  │ └── commit.py Pass-through; grading is in environment.py
562
 
 
147
 
148
  ### `calculator` — Cost: **0.1**
149
 
150
+ **What it does:** Evaluates a math expression safely using Python's `ast` (Abstract Syntax Tree) module. The expression is parsed into a tree structure, and only pre-approved operations are allowed (addition, subtraction, multiplication, division, power, modulo, comparisons, and common math functions like `sqrt`, `log`, `sin`, `cos`).
151
 
152
  **Why not just use `eval()`?** Because `eval("__import__('os').system('rm -rf /')")` would delete your hard drive. The AST approach means the code is never executed — it's parsed into a data structure and we only compute what we explicitly allow.
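The whitelist idea described above can be sketched as a small recursive evaluator. This is an illustration under assumptions, not the code in `calculator.py`: the `BIN_OPS`, `CMP_OPS`, and `FUNCS` tables and the `safe_eval` name are hypothetical, and only single comparisons such as `a > b` are handled.

```python
import ast
import math
import operator

# Illustrative whitelists; the project's actual tables may differ.
BIN_OPS = {
    ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul,
    ast.Div: operator.truediv, ast.Pow: operator.pow, ast.Mod: operator.mod,
}
CMP_OPS = {
    ast.Lt: operator.lt, ast.LtE: operator.le, ast.Gt: operator.gt,
    ast.GtE: operator.ge, ast.Eq: operator.eq, ast.NotEq: operator.ne,
}
FUNCS = {"sqrt": math.sqrt, "log": math.log, "sin": math.sin, "cos": math.cos}


def safe_eval(expression: str):
    """Evaluate a math expression by walking its AST; never calls eval()."""
    return _eval(ast.parse(expression, mode="eval").body)


def _eval(node):
    # Plain numbers only (bool is excluded even though it subclasses int).
    if (isinstance(node, ast.Constant)
            and isinstance(node.value, (int, float))
            and not isinstance(node.value, bool)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in BIN_OPS:
        return BIN_OPS[type(node.op)](_eval(node.left), _eval(node.right))
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return -_eval(node.operand)
    if (isinstance(node, ast.Compare) and len(node.ops) == 1
            and type(node.ops[0]) in CMP_OPS):
        return CMP_OPS[type(node.ops[0])](
            _eval(node.left), _eval(node.comparators[0]))
    if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
            and node.func.id in FUNCS and not node.keywords):
        return FUNCS[node.func.id](*[_eval(arg) for arg in node.args])
    # Anything not explicitly allowed is rejected.
    raise ValueError(f"Unsupported expression element: {type(node).__name__}")
```

Because every node type must match an explicit branch, an input such as `__import__('os')` falls through to the final `raise` instead of being executed.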
153
 
 
168
 
169
  ### `code_executor` — Cost: **0.3**
170
 
171
+ **What it does:** Runs Python code in a sandboxed `exec()` environment for intended coding tasks. Captures whatever is printed to stdout and returns it as the result.
172
 
173
+ **Security model:** Blocks import statements, dangerous builtin names such as `open`, `eval`, `exec`, `globals`, and obvious object-graph escape paths such as dunder attribute traversal. Only a curated builtin/module surface is exposed.
174
 
175
  **Best for:** HumanEval coding tasks where the agent needs to actually run code to verify correctness.
176
 
 
197
 
198
  **Graceful fallback:** If `TOGETHER_API_KEY` is not set, returns a clear error message instead of crashing. The agent learns to avoid this tool when it's unavailable.
199
 
200
+ **Tool routing note:** The environment exposes the canonical tool manifest at `GET /tools`, and tool dispatch normalizes missing-tool and tool-crash cases into explicit `ToolResult` errors. That keeps the OpenEnv-style contract stable even when a backing service is missing.
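The normalization described in the routing note can be sketched as follows. This is a hedged illustration: the `ToolResult` dataclass here is a stand-in whose fields are inferred from the diff, and the body of `dispatch_tool` is an assumption modeled on the inline logic this commit removes from `env/environment.py`, not the actual code in `tools/runtime.py`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class ToolResult:
    """Stand-in for the project's ToolResult model (fields are assumed)."""
    tool_id: str
    output: str
    cost: float = 0.0
    latency_s: float = 0.0
    error: Optional[str] = None


def dispatch_tool(
    tool_id: str,
    action: Any,
    tools: Dict[str, Callable[[Any], ToolResult]],
) -> ToolResult:
    """Run a tool, turning missing tools and crashes into error results."""
    tool_fn = tools.get(tool_id)
    if tool_fn is None:
        # Missing tool: report it instead of raising KeyError.
        return ToolResult(
            tool_id=tool_id,
            output="[Tool not available in this environment]",
            error="not_available",
        )
    try:
        return tool_fn(action)
    except Exception as exc:
        # Tool crash: surface the message but keep the episode alive.
        return ToolResult(
            tool_id=tool_id,
            output=f"[Error: {exc}]",
            error=str(exc),
        )
```

With this shape the caller always receives a `ToolResult`, so cost and latency can be assigned in one place after dispatch regardless of whether the tool succeeded.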
201
+
202
  ---
203
 
204
  ### `commit` — Cost: **0.0**
 
505
  CostAwareToolEnv/
506
 
507
  ├── app.py
508
+ │ The FastAPI web server. Handles /reset, /step, /health, /tools, /web.
509
  │ Multi-session: each /reset returns a session_id used in /step.
510
+ Lazily loads the dataset and exposes the canonical tool manifest.
511
 
512
  ├── openenv.yaml
513
  │ Deployment spec for the OpenEnv competition framework.
 
554
  │ Returns flat List[Dict] with 'domain' key on each item.
555
 
556
  ├── tools/
557
+ │ ├── runtime.py Tool catalog, validation, and explicit dispatch
558
+ │ ├── __init__.py build_tool_registry() + tool manifest helpers
559
  │ ├── ceramic_search.py make_search_tool() factory wrapping CeramicClient
560
  │ ├── wiki_lookup.py Wikipedia REST API, first paragraph
561
+ │ ├── calculator.py Safe AST-based math eval with comparisons
562
+ │ ├── code_executor.py Sandboxed exec with blocked imports and dunder escapes
563
  │ ├── llm_reason.py Together AI API, graceful fallback
564
  │ └── commit.py Pass-through; grading is in environment.py
565
 
app.py CHANGED
@@ -1,18 +1,20 @@
1
  """FastAPI server for CostAwareToolEnv.
2
 
3
  Exposes the OpenEnv standard endpoints:
4
- POST /reset → OrchestratorObservation + OrchestratorState
5
- POST /step → OrchestratorObservation + reward + done + state
6
- GET /health → {"status": "ok"}
7
- GET /web → simple demo UI
8
- GET /docs → OpenAPI (automatic)
 
9
  """
10
  from __future__ import annotations
11
 
 
12
  import os
13
  import uuid
14
  from contextlib import asynccontextmanager
15
- from typing import Any, Dict, Optional
16
 
17
  from fastapi import FastAPI, HTTPException
18
  from fastapi.responses import HTMLResponse
@@ -22,13 +24,9 @@ from data.loader import load_all
22
  from env.config import EnvConfig
23
  from env.environment import CostAwareToolEnvironment
24
  from env.models import OrchestratorAction
25
- from tools import build_tool_registry
26
 
27
 
28
- # ---------------------------------------------------------------------------
29
- # Request / response wrappers
30
- # ---------------------------------------------------------------------------
31
-
32
  class ResetRequest(BaseModel):
33
  seed: Optional[int] = None
34
  config_overrides: Optional[Dict[str, Any]] = None
@@ -36,27 +34,140 @@ class ResetRequest(BaseModel):
36
 
37
  class StepRequest(BaseModel):
38
  tool_id: str
39
- query: Optional[str] = None
40
- expression: Optional[str] = None
41
  code_snippet: Optional[str] = None
42
- answer: Optional[str] = None
43
- metadata: Optional[Dict[str, Any]] = None
44
 
 
 
 
45
 
46
- # ---------------------------------------------------------------------------
47
- # App factory
48
- # ---------------------------------------------------------------------------
49
 
50
- def create_app() -> FastAPI:
51
- config = EnvConfig()
52
- tools = build_tool_registry(config)
53
- dataset = load_all(split=config.data_split, max_per_domain=200)
 
 
 
54
 
55
- # Multi-session state: session_id → CostAwareToolEnvironment
56
  sessions: Dict[str, CostAwareToolEnvironment] = {}
57
 
58
- # Default shared environment for single-session usage (no session_id)
59
- default_env = CostAwareToolEnvironment(config=config, tools=tools, dataset=dataset)
 
 
 
60
 
61
  @asynccontextmanager
62
  async def lifespan(app: FastAPI):
@@ -74,28 +185,38 @@ def create_app() -> FastAPI:
74
  def health():
75
  return {"status": "ok"}
76
 
 
 
 
 
77
  @app.post("/reset")
78
  def reset(req: ResetRequest):
79
- cfg = EnvConfig()
80
- if req.config_overrides:
81
- for k, v in req.config_overrides.items():
82
- if hasattr(cfg, k):
83
- setattr(cfg, k, v)
84
- env = CostAwareToolEnvironment(config=cfg, tools=tools, dataset=dataset)
85
  obs, state = env.reset(seed=req.seed)
86
 
87
  session_id = str(uuid.uuid4())
88
  sessions[session_id] = env
89
 
90
  return {
91
- "session_id": session_id,
92
  "observation": obs.model_dump(),
93
- "state": state.model_dump(),
94
  }
95
 
96
  @app.post("/step")
97
  def step(req: StepRequest, session_id: Optional[str] = None):
98
- env = sessions.get(session_id or "", default_env)
 
 
 
 
 
 
99
  action = OrchestratorAction(
100
  tool_id=req.tool_id,
101
  query=req.query or "",
@@ -111,98 +232,21 @@ def create_app() -> FastAPI:
111
  except ValueError as exc:
112
  raise HTTPException(status_code=422, detail=str(exc))
113
 
114
- # Clean up finished sessions
115
  if done and session_id and session_id in sessions:
116
  del sessions[session_id]
117
 
118
  return {
119
  "observation": obs.model_dump(),
120
- "reward": reward,
121
- "done": done,
122
- "state": state.model_dump(),
123
  }
124
 
125
  @app.get("/web", response_class=HTMLResponse)
126
  def web_ui():
127
- return _DEMO_HTML
128
 
129
  return app
130
 
131
 
132
  app = create_app()
133
-
134
-
135
- # ---------------------------------------------------------------------------
136
- # Demo UI
137
- # ---------------------------------------------------------------------------
138
-
139
- _DEMO_HTML = """<!DOCTYPE html>
140
- <html lang="en">
141
- <head>
142
- <meta charset="UTF-8">
143
- <title>CostAwareToolEnv</title>
144
- <style>
145
- body { font-family: monospace; max-width: 860px; margin: 40px auto; padding: 0 20px; }
146
- h1 { color: #333; }
147
- pre { background: #f4f4f4; padding: 12px; border-radius: 6px; overflow-x: auto; }
148
- button { padding: 8px 16px; margin: 4px; cursor: pointer; }
149
- input, select, textarea { width: 100%; padding: 6px; margin: 4px 0; box-sizing: border-box; }
150
- label { font-weight: bold; }
151
- .tool-btn { background: #e8f0fe; border: 1px solid #4a90e2; border-radius: 4px; }
152
- .tool-btn:hover { background: #cfe1ff; }
153
- #log { max-height: 480px; overflow-y: auto; }
154
- </style>
155
- </head>
156
- <body>
157
- <h1>CostAwareToolEnv</h1>
158
- <p>Multi-tool cost-aware RL environment — AgentX / OpenEnv</p>
159
-
160
- <button onclick="doReset()">Reset Episode</button>
161
- <hr>
162
- <label>Tool:</label>
163
- <select id="tool">
164
- <option value="ceramic_search">ceramic_search (cost 1.0) — Web retrieval</option>
165
- <option value="wiki_lookup">wiki_lookup (cost 0.5) — Wikipedia</option>
166
- <option value="calculator">calculator (cost 0.1) — Arithmetic / math</option>
167
- <option value="code_executor">code_executor (cost 0.3) — Python execution</option>
168
- <option value="llm_reason">llm_reason (cost 2.0) — LLM chain-of-thought</option>
169
- <option value="commit">commit (cost 0.0) — Submit answer</option>
170
- </select>
171
- <label>Query / Expression / Code / Answer:</label>
172
- <textarea id="query" rows="3" placeholder="Enter query or answer..."></textarea>
173
- <button class="tool-btn" onclick="doStep()">Step</button>
174
- <hr>
175
- <pre id="log">Click "Reset Episode" to start.</pre>
176
-
177
- <script>
178
- const log = document.getElementById('log');
179
- let sessionId = null;
180
-
181
- function append(text) { log.textContent += text + '\\n---\\n'; log.scrollTop = log.scrollHeight; }
182
-
183
- async function doReset() {
184
- log.textContent = '';
185
- const res = await fetch('/reset', { method: 'POST', headers: {'Content-Type':'application/json'}, body: JSON.stringify({seed: 42}) });
186
- const data = await res.json();
187
- sessionId = data.session_id || null;
188
- append('RESET session=' + sessionId + '\\n' + JSON.stringify(data, null, 2));
189
- }
190
-
191
- async function doStep() {
192
- const tool_id = document.getElementById('tool').value;
193
- const input = document.getElementById('query').value;
194
- const body = { tool_id };
195
- if (tool_id === 'commit') body.answer = input;
196
- else if (tool_id === 'calculator') body.expression = input;
197
- else if (tool_id === 'code_executor') body.code_snippet = input;
198
- else body.query = input;
199
-
200
- const url = sessionId ? '/step?session_id=' + encodeURIComponent(sessionId) : '/step';
201
- const res = await fetch(url, { method: 'POST', headers: {'Content-Type':'application/json'}, body: JSON.stringify(body) });
202
- const data = await res.json();
203
- append('STEP tool_id=' + tool_id + '\\n' + JSON.stringify(data, null, 2));
204
- }
205
- </script>
206
- </body>
207
- </html>
208
- """
 
1
  """FastAPI server for CostAwareToolEnv.
2
 
3
  Exposes the OpenEnv standard endpoints:
4
+ POST /reset -> OrchestratorObservation + OrchestratorState
5
+ POST /step -> OrchestratorObservation + reward + done + state
6
+ GET /health -> {"status": "ok"}
7
+ GET /tools -> canonical tool manifest
8
+ GET /web -> simple demo UI
9
+ GET /docs -> OpenAPI (automatic)
10
  """
11
  from __future__ import annotations
12
 
13
+ import copy
14
  import os
15
  import uuid
16
  from contextlib import asynccontextmanager
17
+ from typing import Any, Callable, Dict, List, Optional
18
 
19
  from fastapi import FastAPI, HTTPException
20
  from fastapi.responses import HTMLResponse
 
24
  from env.config import EnvConfig
25
  from env.environment import CostAwareToolEnvironment
26
  from env.models import OrchestratorAction
27
+ from tools import build_tool_catalog, build_tool_registry, catalog_as_dicts, validate_tool_costs
28
 
29
 
 
 
 
 
30
  class ResetRequest(BaseModel):
31
  seed: Optional[int] = None
32
  config_overrides: Optional[Dict[str, Any]] = None
 
34
 
35
  class StepRequest(BaseModel):
36
  tool_id: str
37
+ query: Optional[str] = None
38
+ expression: Optional[str] = None
39
  code_snippet: Optional[str] = None
40
+ answer: Optional[str] = None
41
+ metadata: Optional[Dict[str, Any]] = None
42
+
43
+
44
+ def _merge_config(base: EnvConfig, overrides: Optional[Dict[str, Any]]) -> EnvConfig:
45
+ cfg = copy.deepcopy(base)
46
+ if not overrides:
47
+ return cfg
48
+
49
+ for key, value in overrides.items():
50
+ if not hasattr(cfg, key):
51
+ raise ValueError(f"Unknown config override: {key}")
52
+
53
+ current = getattr(cfg, key)
54
+ if isinstance(current, dict) and isinstance(value, dict):
55
+ merged = copy.deepcopy(current)
56
+ merged.update(value)
57
+ setattr(cfg, key, merged)
58
+ else:
59
+ setattr(cfg, key, value)
60
+ return cfg
61
+
62
+
63
+ def _build_demo_html(tool_catalog: List[Any]) -> str:
64
+ tool_options = "\n".join(
65
+ f' <option value="{spec.tool_id}">{spec.label} (cost {spec.cost}) — {spec.purpose}</option>'
66
+ for spec in tool_catalog
67
+ )
68
+
69
+ return f"""<!DOCTYPE html>
70
+ <html lang="en">
71
+ <head>
72
+ <meta charset="UTF-8">
73
+ <title>CostAwareToolEnv</title>
74
+ <style>
75
+ body {{ font-family: monospace; max-width: 860px; margin: 40px auto; padding: 0 20px; }}
76
+ h1 {{ color: #333; }}
77
+ pre {{ background: #f4f4f4; padding: 12px; border-radius: 6px; overflow-x: auto; }}
78
+ button {{ padding: 8px 16px; margin: 4px; cursor: pointer; }}
79
+ input, select, textarea {{ width: 100%; padding: 6px; margin: 4px 0; box-sizing: border-box; }}
80
+ label {{ font-weight: bold; }}
81
+ .tool-btn {{ background: #e8f0fe; border: 1px solid #4a90e2; border-radius: 4px; }}
82
+ .tool-btn:hover {{ background: #cfe1ff; }}
83
+ #log {{ max-height: 480px; overflow-y: auto; }}
84
+ </style>
85
+ </head>
86
+ <body>
87
+ <h1>CostAwareToolEnv</h1>
88
+ <p>Multi-tool cost-aware RL environment with explicit tool routing and sandboxed execution.</p>
89
+
90
+ <button onclick="doReset()">Reset Episode</button>
91
+ <hr>
92
+ <label>Tool:</label>
93
+ <select id="tool">
94
+ {tool_options}
95
+ </select>
96
+ <label>Query / Expression / Code / Answer:</label>
97
+ <textarea id="query" rows="3" placeholder="Enter query or answer..."></textarea>
98
+ <button class="tool-btn" onclick="doStep()">Step</button>
99
+ <hr>
100
+ <pre id="log">Click "Reset Episode" to start.</pre>
101
 
102
+ <script>
103
+ const log = document.getElementById('log');
104
+ let sessionId = null;
105
 
106
+ function append(text) {{ log.textContent += text + '\\n---\\n'; log.scrollTop = log.scrollHeight; }}
 
 
107
 
108
+ async function doReset() {{
109
+ log.textContent = '';
110
+ const res = await fetch('/reset', {{ method: 'POST', headers: {{'Content-Type':'application/json'}}, body: JSON.stringify({{seed: 42}}) }});
111
+ const data = await res.json();
112
+ sessionId = data.session_id || null;
113
+ append('RESET session=' + sessionId + '\\n' + JSON.stringify(data, null, 2));
114
+ }}
115
 
116
+ async function doStep() {{
117
+ const tool_id = document.getElementById('tool').value;
118
+ const input = document.getElementById('query').value;
119
+ const body = {{ tool_id }};
120
+ if (tool_id === 'commit') body.answer = input;
121
+ else if (tool_id === 'calculator') body.expression = input;
122
+ else if (tool_id === 'code_executor') body.code_snippet = input;
123
+ else body.query = input;
124
+
125
+ const url = sessionId ? '/step?session_id=' + encodeURIComponent(sessionId) : '/step';
126
+ const res = await fetch(url, {{ method: 'POST', headers: {{'Content-Type':'application/json'}}, body: JSON.stringify(body) }});
127
+ const data = await res.json();
128
+ append('STEP tool_id=' + tool_id + '\\n' + JSON.stringify(data, null, 2));
129
+ }}
130
+ </script>
131
+ </body>
132
+ </html>
133
+ """
134
+
135
+
136
+ def create_app(
137
+ config: Optional[EnvConfig] = None,
138
+ tools: Optional[Dict[str, Any]] = None,
139
+ dataset: Optional[List[Dict[str, Any]]] = None,
140
+ load_dataset_fn: Callable[..., List[Dict[str, Any]]] = load_all,
141
+ build_registry_fn: Callable[[EnvConfig | None], Dict[str, Any]] = build_tool_registry,
142
+ ) -> FastAPI:
143
+ base_config = config or EnvConfig()
144
+ validate_tool_costs(base_config)
145
+
146
+ dataset_cache = dataset
147
+ default_env: Optional[CostAwareToolEnvironment] = None
148
  sessions: Dict[str, CostAwareToolEnvironment] = {}
149
+ tool_catalog = build_tool_catalog(base_config)
150
+ demo_html = _build_demo_html(tool_catalog)
151
+
152
+ def get_dataset() -> List[Dict[str, Any]]:
153
+ nonlocal dataset_cache
154
+ if dataset_cache is None:
155
+ dataset_cache = load_dataset_fn(split=base_config.data_split, max_per_domain=200)
156
+ return dataset_cache
157
+
158
+ def make_env(effective_config: EnvConfig) -> CostAwareToolEnvironment:
159
+ registry = tools if tools is not None else build_registry_fn(effective_config)
160
+ return CostAwareToolEnvironment(
161
+ config=effective_config,
162
+ tools=registry,
163
+ dataset=get_dataset(),
164
+ )
165
 
166
+ def get_default_env() -> CostAwareToolEnvironment:
167
+ nonlocal default_env
168
+ if default_env is None:
169
+ default_env = make_env(base_config)
170
+ return default_env
171
 
172
  @asynccontextmanager
173
  async def lifespan(app: FastAPI):
 
185
  def health():
186
  return {"status": "ok"}
187
 
188
+ @app.get("/tools")
189
+ def tools_manifest():
190
+ return catalog_as_dicts(base_config)
191
+
192
  @app.post("/reset")
193
  def reset(req: ResetRequest):
194
+ try:
195
+ cfg = _merge_config(base_config, req.config_overrides)
196
+ except ValueError as exc:
197
+ raise HTTPException(status_code=422, detail=str(exc))
198
+
199
+ env = make_env(cfg)
200
  obs, state = env.reset(seed=req.seed)
201
 
202
  session_id = str(uuid.uuid4())
203
  sessions[session_id] = env
204
 
205
  return {
206
+ "session_id": session_id,
207
  "observation": obs.model_dump(),
208
+ "state": state.model_dump(),
209
  }
210
 
211
  @app.post("/step")
212
  def step(req: StepRequest, session_id: Optional[str] = None):
213
+ if session_id is None:
214
+ env = get_default_env()
215
+ else:
216
+ env = sessions.get(session_id)
217
+ if env is None:
218
+ raise HTTPException(status_code=404, detail="Unknown session_id")
219
+
220
  action = OrchestratorAction(
221
  tool_id=req.tool_id,
222
  query=req.query or "",
 
232
  except ValueError as exc:
233
  raise HTTPException(status_code=422, detail=str(exc))
234
 
 
235
  if done and session_id and session_id in sessions:
236
  del sessions[session_id]
237
 
238
  return {
239
  "observation": obs.model_dump(),
240
+ "reward": reward,
241
+ "done": done,
242
+ "state": state.model_dump(),
243
  }
244
 
245
  @app.get("/web", response_class=HTMLResponse)
246
  def web_ui():
247
+ return demo_html
248
 
249
  return app
250
 
251
 
252
  app = create_app()
client.py CHANGED
@@ -1,123 +1,4 @@
1
- """Compatibility shim; real code lives in ceramic/client.py.
2
-
3
- Mirrors the SearchEconomicsEnv CeramicClient interface so the two
4
- environments share the same retrieval backend.
5
-
6
- Priority for the API key:
7
- 1. CERAMIC_API_KEY env var
8
- 2. SEE_CERAMIC_API_KEY env var (HF Spaces compatibility)
9
- 3. Falls back to FallbackCeramicClient (offline, deterministic)
10
- """
11
  from __future__ import annotations
12
 
13
- import hashlib
14
- import os
15
- import time
16
- from dataclasses import dataclass, field
17
- from typing import List, Optional
18
-
19
- import httpx
20
-
21
-
22
- # ---------------------------------------------------------------------------
23
- # Result model
24
- # ---------------------------------------------------------------------------
25
-
26
- @dataclass
27
- class SearchResult:
28
- title: str
29
- url: str
30
- description: str
31
- score: float = 0.0
32
-
33
-
34
- # ---------------------------------------------------------------------------
35
- # Live client
36
- # ---------------------------------------------------------------------------
37
-
38
- class CeramicClient:
39
- """Thin wrapper around the Ceramic search API."""
40
-
41
- BASE_URL = "https://api.ceramic.ai/v1"
42
-
43
- def __init__(self, api_key: str):
44
- self._key = api_key
45
- self._client = httpx.Client(
46
- headers={"Authorization": f"Bearer {api_key}"},
47
- timeout=10.0,
48
- )
49
-
50
- def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
51
- resp = self._client.post(
52
- f"{self.BASE_URL}/search",
53
- json={"query": query, "top_k": top_k},
54
- )
55
- resp.raise_for_status()
56
- data = resp.json()
57
- results = []
58
- for item in data.get("results", []):
59
- results.append(SearchResult(
60
- title=item.get("title", ""),
61
- url=item.get("url", ""),
62
- description=item.get("description", ""),
63
- score=float(item.get("score", 0.0)),
64
- ))
65
- return results
66
-
67
- def close(self):
68
- self._client.close()
69
-
70
- def __enter__(self):
71
- return self
72
-
73
- def __exit__(self, *args):
74
- self.close()
75
-
76
-
77
- # ---------------------------------------------------------------------------
78
- # Offline fallback
79
- # ---------------------------------------------------------------------------
80
-
81
- class FallbackCeramicClient:
82
- """Deterministic offline client — used when no API key is set."""
83
-
84
- def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
85
- # Stable hash → reproducible fake results per query
86
- h = int(hashlib.sha256(query.encode()).hexdigest(), 16)
87
- results = []
88
- for i in range(min(top_k, 3)):
89
- seed = (h + i) % 10_000
90
- results.append(SearchResult(
91
- title=f"Result {seed}: {query[:40]}",
92
- url=f"https://fallback.example.com/doc/{seed}",
93
- description=f"Offline fallback result #{i+1} for query: {query}",
94
- score=round(0.9 - i * 0.15, 3),
95
- ))
96
- return results
97
-
98
- def close(self):
99
- pass
100
-
101
- def __enter__(self):
102
- return self
103
-
104
- def __exit__(self, *args):
105
- pass
106
-
107
-
108
- # ---------------------------------------------------------------------------
109
- # Factory
110
- # ---------------------------------------------------------------------------
111
-
112
- _DEFAULT_KEY = "<redacted API key>"
113
-
114
-
115
- def get_ceramic_client() -> CeramicClient | FallbackCeramicClient:
116
- key = (
117
- os.environ.get("CERAMIC_API_KEY")
118
- or os.environ.get("SEE_CERAMIC_API_KEY")
119
- or _DEFAULT_KEY
120
- )
121
- if key:
122
- return CeramicClient(api_key=key)
123
- return FallbackCeramicClient()
 
1
+ """Compatibility shim for the legacy top-level Ceramic client import path."""
2
  from __future__ import annotations
3
 
4
+ from ceramic.client import * # noqa: F401,F403
env/environment.py CHANGED
@@ -24,6 +24,7 @@ from .models import (
24
  TOOL_IDS,
25
  )
26
  from .reward import commit_reward, step_reward
 
27
 
28
 
29
  class CostAwareToolEnvironment:
@@ -95,6 +96,8 @@ class CostAwareToolEnvironment:
95
  ) -> Tuple[OrchestratorObservation, float, bool, OrchestratorState]:
96
  if self._episode_done:
97
  raise RuntimeError("Episode is done. Call reset() first.")
 
 
98
 
99
  tool_id = action.tool_id
100
  if tool_id not in TOOL_IDS:
@@ -160,30 +163,17 @@ class CostAwareToolEnvironment:
160
  if budget_after > state.total_budget:
161
  r = config.incorrect_reward
162
  self._episode_done = True
163
- obs = self._make_obs(reward=r, question_done=True, done=True)
 
 
 
 
164
  return obs, r, True, state
165
 
166
  t0 = time.perf_counter()
167
- tool_fn = self.tools.get(tool_id)
168
- if tool_fn is None:
169
- tool_result = ToolResult(
170
- tool_id=tool_id, cost=cost,
171
- output="[Tool not available in this environment]",
172
- latency_s=0.0,
173
- error="not_available",
174
- )
175
- else:
176
- try:
177
- tool_result = tool_fn(action)
178
- tool_result.cost = cost
179
- except Exception as exc:
180
- tool_result = ToolResult(
181
- tool_id=tool_id, cost=cost,
182
- output=f"[Error: {exc}]",
183
- latency_s=time.perf_counter() - t0,
184
- error=str(exc),
185
- )
186
  tool_result.latency_s = time.perf_counter() - t0
 
187
 
188
  state.budget_spent = budget_after
189
  state.budget_remaining_ratio = max(
@@ -234,6 +224,8 @@ class CostAwareToolEnvironment:
234
  ) -> OrchestratorObservation:
235
  state = self._state
236
  cfg = self.config
 
 
237
 
238
  if 0 <= self._current_q_idx < len(self._questions):
239
  q_entry = self._questions[self._current_q_idx]
 
24
  TOOL_IDS,
25
  )
26
  from .reward import commit_reward, step_reward
27
+ from tools import dispatch_tool
28
 
29
 
30
  class CostAwareToolEnvironment:
 
96
  ) -> Tuple[OrchestratorObservation, float, bool, OrchestratorState]:
97
  if self._episode_done:
98
  raise RuntimeError("Episode is done. Call reset() first.")
99
+ if self._state is None:
100
+ raise RuntimeError("Call reset() first.")
101
 
102
  tool_id = action.tool_id
103
  if tool_id not in TOOL_IDS:
 
163
  if budget_after > state.total_budget:
164
  r = config.incorrect_reward
165
  self._episode_done = True
166
+ obs = self._make_obs(
167
+ reward=r,
168
+ question_done=True,
169
+ done=True,
170
+ )
171
  return obs, r, True, state
172
 
173
  t0 = time.perf_counter()
174
+ tool_result = dispatch_tool(tool_id, action, self.tools)
175
  tool_result.latency_s = time.perf_counter() - t0
176
+ tool_result.cost = cost
177
 
178
  state.budget_spent = budget_after
179
  state.budget_remaining_ratio = max(
 
224
  ) -> OrchestratorObservation:
225
  state = self._state
226
  cfg = self.config
227
+ if state is None:
228
+ raise RuntimeError("Call reset() first.")
229
 
230
  if 0 <= self._current_q_idx < len(self._questions):
231
  q_entry = self._questions[self._current_q_idx]
environment.py CHANGED
@@ -1,307 +1,9 @@
1
- """Compatibility shim; real code lives in env/environment.py.
2
 
3
- Step logic:
4
- - Agent receives an OrchestratorObservation with the current question,
5
- budget, context, and available tools.
6
- - Agent picks a tool_id and optional query / code_snippet / answer.
7
- - Environment dispatches to the appropriate tool, charges cost, appends
8
- result to context_window, and returns the next observation + reward.
9
- - Episode ends when budget is exhausted OR all questions are answered.
10
  """
11
  from __future__ import annotations
12
 
13
- import time
14
- import uuid
15
- from typing import Any, Dict, List, Optional, Tuple
16
-
17
- from env.answer_grading import grade
18
- from env.config import EnvConfig
19
- from env.models import (
20
- OrchestratorAction,
21
- OrchestratorObservation,
22
- OrchestratorState,
23
- ToolResult,
24
- TOOL_IDS,
25
- )
26
- from env.reward import commit_reward, step_reward
27
-
28
-
29
- class CostAwareToolEnvironment:
30
- """
31
- OpenEnv-compatible RL environment for multi-tool cost-aware QA.
32
-
33
- Supports external tool injection so the server can wire in live
34
- Ceramic, code executor, etc. Tools are callables with signature:
35
- tool_fn(action: OrchestratorAction) -> ToolResult
36
- """
37
-
38
- def __init__(
39
- self,
40
- config: Optional[EnvConfig] = None,
41
- tools: Optional[Dict[str, Any]] = None,
42
- dataset: Optional[List[Dict[str, Any]]] = None,
43
- ):
44
- self.config = config or EnvConfig()
45
- self.tools = tools or {} # tool_id -> callable
46
- self.dataset = dataset or [] # List of {question, answer, domain}
47
- self._state: Optional[OrchestratorState] = None
48
- self._questions: List[Dict[str, Any]] = []
49
- self._current_q_idx: int = 0
50
- self._context_window: List[str] = []
51
- self._tools_used_this_q: List[str] = []
52
- self._steps_this_q: int = 0
53
- self._episode_done: bool = False
54
-
55
- # -----------------------------------------------------------------------
56
- # Reset
57
- # -----------------------------------------------------------------------
58
-
59
- def reset(self, seed: Optional[int] = None) -> Tuple[OrchestratorObservation, OrchestratorState]:
60
- import random
61
- rng = random.Random(seed if seed is not None else self.config.seed)
62
-
63
- # Sample questions according to domain_mix
64
- questions = _sample_questions(self.dataset, self.config, rng)
65
- self._questions = questions
66
- self._current_q_idx = 0
67
- self._episode_done = False
68
-
69
- self._state = OrchestratorState(
70
- episode_id=str(uuid.uuid4()),
71
- total_budget=self.config.total_budget,
72
- budget_spent=0,
73
- questions_answered=0,
74
- total_correct=0,
75
- current_accuracy=0.0,
76
- budget_remaining_ratio=1.0,
77
- current_question_idx=0,
78
- current_question_steps=0,
79
- )
80
-
81
- self._context_window = []
82
- self._tools_used_this_q = []
83
- self._steps_this_q = 0
84
-
85
- obs = self._make_obs(reward=None, question_done=False, done=False)
86
- return obs, self._state
87
-
88
- # -----------------------------------------------------------------------
89
- # Step
90
- # -----------------------------------------------------------------------
91
-
92
- def step(
93
- self, action: OrchestratorAction
94
- ) -> Tuple[OrchestratorObservation, float, bool, OrchestratorState]:
95
- if self._episode_done:
96
- raise RuntimeError("Episode is done. Call reset() first.")
97
-
98
- tool_id = action.tool_id
99
- if tool_id not in TOOL_IDS:
100
- raise ValueError(f"Unknown tool_id: {tool_id!r}. Valid: {TOOL_IDS}")
101
-
102
- state = self._state
103
- config = self.config
104
-
105
- # Guard against exhausted question list (can happen after last commit)
106
- if self._current_q_idx >= len(self._questions):
107
- self._episode_done = True
108
- raise RuntimeError("Episode is done. Call reset() first.")
109
-
110
- q_entry = self._questions[self._current_q_idx]
111
- gold = q_entry["answer"]
112
-
113
- # ---- Commit ---------------------------------------------------
114
- if tool_id == "commit":
115
- raw_pred = action.answer or ""
116
- em, f1, quality = grade(raw_pred, gold)
117
-
118
- r = commit_reward(
119
- quality=quality,
120
- budget_remaining_ratio=state.budget_remaining_ratio,
121
- config=config,
122
- )
123
-
124
- # Count correct
125
- is_correct = (
126
- em if config.grade_count_correct_mode == "em_only"
127
- else (em or f1 >= config.f1_count_threshold)
128
- )
129
- state.questions_answered += 1
130
- state.total_correct += int(is_correct)
131
- state.current_accuracy = state.total_correct / state.questions_answered
132
-
133
- # Advance to next question or end episode
134
- self._current_q_idx += 1
135
- self._context_window = []
136
- self._tools_used_this_q = []
137
- self._steps_this_q = 0
138
-
139
- episode_done = (
140
- self._current_q_idx >= len(self._questions)
141
- or state.budget_spent >= state.total_budget
142
- )
143
- self._episode_done = episode_done
144
- state.current_question_idx = self._current_q_idx
145
- state.current_question_steps = 0
146
-
147
- obs = self._make_obs(
148
- reward=r,
149
- question_done=True,
150
- done=episode_done,
151
- last_tool_result=ToolResult(
152
- tool_id="commit", cost=0,
153
- output=f"EM={em} F1={f1:.3f} quality={quality:.3f}"
154
- ),
155
- )
156
- return obs, r, episode_done, state
157
-
158
- # ---- Tool call ------------------------------------------------
159
- cost = config.tool_costs.get(tool_id, 0)
160
- budget_after = state.budget_spent + cost
161
-
162
- # If over budget, force commit penalty
163
- if budget_after > state.total_budget:
164
- r = config.incorrect_reward
165
- self._episode_done = True
166
- obs = self._make_obs(reward=r, question_done=True, done=True)
167
- return obs, r, True, state
168
-
169
- # Dispatch tool
170
- t0 = time.perf_counter()
171
- tool_fn = self.tools.get(tool_id)
172
- if tool_fn is None:
173
- tool_result = ToolResult(
174
- tool_id=tool_id, cost=cost,
175
- output="[Tool not available in this environment]",
176
- latency_s=0.0,
177
- error="not_available",
178
- )
179
- else:
180
- try:
181
- tool_result = tool_fn(action)
182
- tool_result.cost = cost
183
- except Exception as exc:
184
- tool_result = ToolResult(
185
- tool_id=tool_id, cost=cost,
186
- output=f"[Error: {exc}]",
187
- latency_s=time.perf_counter() - t0,
188
- error=str(exc),
189
- )
190
- tool_result.latency_s = time.perf_counter() - t0
191
-
192
- # Charge cost and update state
193
- state.budget_spent = budget_after
194
- state.budget_remaining_ratio = max(
195
- 0.0, (state.total_budget - state.budget_spent) / state.total_budget
196
- )
197
- state.step_count += 1
198
- state.current_question_steps += 1
199
- self._steps_this_q += 1
200
- self._tools_used_this_q.append(tool_id)
201
- self._context_window.append(f"[{tool_id}] {tool_result.output}")
202
-
203
- r = step_reward(tool_id, config)
204
-
205
- # Auto-commit if max steps reached
206
- question_done = self._steps_this_q >= config.max_steps_per_question
207
- episode_done = (
208
- state.budget_spent >= state.total_budget
209
- or (question_done and self._current_q_idx + 1 >= len(self._questions))
210
- )
211
- if question_done and not episode_done:
212
- self._current_q_idx += 1
213
- state.questions_answered += 1
214
- self._context_window = []
215
- self._tools_used_this_q = []
216
- self._steps_this_q = 0
217
- state.current_question_idx = self._current_q_idx
218
- state.current_question_steps = 0
219
-
220
- self._episode_done = episode_done
221
-
222
- obs = self._make_obs(
223
- reward=r,
224
- question_done=question_done,
225
- done=episode_done,
226
- last_tool_result=tool_result,
227
- )
228
- return obs, r, episode_done, state
229
-
230
- # -----------------------------------------------------------------------
231
- # Internal helpers
232
- # -----------------------------------------------------------------------
233
-
234
- def _make_obs(
235
- self,
236
- reward: Optional[float],
237
- question_done: bool,
238
- done: bool,
239
- last_tool_result: Optional[ToolResult] = None,
240
- ) -> OrchestratorObservation:
241
- state = self._state
242
- cfg = self.config
243
-
244
- if 0 <= self._current_q_idx < len(self._questions):
245
- q_entry = self._questions[self._current_q_idx]
246
- elif self._questions:
247
- # Episode finished — repeat last question info (obs is terminal anyway)
248
- q_entry = self._questions[-1]
249
- else:
250
- q_entry = {"question": "", "answer": "", "domain": ""}
251
-
252
- return OrchestratorObservation(
253
- question=q_entry.get("question", ""),
254
- question_idx=self._current_q_idx,
255
- domain=q_entry.get("domain", ""),
256
- question_embedding=[], # populated by server if needed
257
- total_budget=cfg.total_budget,
258
- budget_spent=state.budget_spent,
259
- budget_remaining=state.total_budget - state.budget_spent,
260
- budget_remaining_ratio=state.budget_remaining_ratio,
261
- tools_used_this_question=list(self._tools_used_this_q),
262
- steps_used_this_question=self._steps_this_q,
263
- max_steps_per_question=cfg.max_steps_per_question,
264
- last_tool_result=last_tool_result,
265
- context_window=list(self._context_window),
266
- step_idx=state.step_count,
267
- questions_remaining=max(0, len(self._questions) - self._current_q_idx - 1),
268
- questions_answered=state.questions_answered,
269
- accuracy_so_far=state.current_accuracy,
270
- question_done=question_done,
271
- done=done,
272
- reward=reward,
273
- )
274
-
275
-
276
- # ---------------------------------------------------------------------------
277
- # Dataset sampling helper
278
- # ---------------------------------------------------------------------------
279
-
280
- def _sample_questions(
281
- dataset: List[Dict[str, Any]],
282
- config: EnvConfig,
283
- rng: Any,
284
- ) -> List[Dict[str, Any]]:
285
- """Sample `config.num_questions` questions according to domain_mix."""
286
- by_domain: Dict[str, List[Dict]] = {}
287
- for item in dataset:
288
- d = item.get("domain", "hotpotqa")
289
- by_domain.setdefault(d, []).append(item)
290
-
291
- selected = []
292
- for domain, frac in config.domain_mix.items():
293
- n = round(config.num_questions * frac)
294
- pool = by_domain.get(domain, [])
295
- if pool and n > 0:
296
- selected.extend(rng.sample(pool, min(n, len(pool))))
297
-
298
- # Guarantee at least num_questions items by filling from the full dataset
299
- if len(selected) < config.num_questions and dataset:
300
- remaining = [d for d in dataset if d not in selected]
301
- rng.shuffle(remaining)
302
- selected.extend(remaining[: config.num_questions - len(selected)])
303
-
304
- if config.shuffle_questions:
305
- rng.shuffle(selected)
306
-
307
- return selected[: config.num_questions]
 
1
+ """Compatibility shim for the legacy top-level import path.
2
 
3
+ The real environment implementation lives in :mod:`env.environment`.
4
+ This module stays intentionally thin so the two orchestrator entrypoints
5
+ cannot drift apart again.
6
  """
7
  from __future__ import annotations
8
 
9
+ from env.environment import * # noqa: F401,F403
requirements.txt CHANGED
@@ -6,3 +6,4 @@ datasets>=2.18.0
6
  httpx>=0.27.0
7
  requests>=2.31.0
8
  together>=1.2.0
 
 
6
  httpx>=0.27.0
7
  requests>=2.31.0
8
  together>=1.2.0
9
+ pytest>=8.0.0
tests/conftest.py ADDED
@@ -0,0 +1,45 @@
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path
+
+ import pytest
+ from fastapi.testclient import TestClient
+
+ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+
+ from app import create_app
+ from env.config import EnvConfig
+ from env.models import TOOL_IDS, OrchestratorAction, ToolResult
+
+
+ def make_stub_registry():
+     def make_tool(tool_id: str):
+         def _tool(action: OrchestratorAction) -> ToolResult:
+             payload = action.query or action.expression or action.code_snippet or action.answer or ""
+             return ToolResult(tool_id=tool_id, output=f"{tool_id}:{payload}".rstrip(":"))
+
+         return _tool
+
+     return {tool_id: make_tool(tool_id) for tool_id in TOOL_IDS}
+
+
+ @pytest.fixture()
+ def sample_dataset():
+     return [
+         {"question": "What is 2 + 2?", "answer": "4", "domain": "math"},
+         {"question": "What is 3 + 1?", "answer": "4", "domain": "math"},
+     ]
+
+
+ @pytest.fixture()
+ def app_client(sample_dataset):
+     cfg = EnvConfig(
+         num_questions=2,
+         total_budget=5.0,
+         max_steps_per_question=4,
+         shuffle_questions=False,
+         domain_mix={"math": 1.0},
+     )
+     app = create_app(config=cfg, tools=make_stub_registry(), dataset=sample_dataset)
+     return TestClient(app)
tests/test_app.py ADDED
@@ -0,0 +1,82 @@
+ from __future__ import annotations
+
+
+ def test_reset_returns_clean_session(app_client):
+     response = app_client.post("/reset", json={"seed": 42})
+     assert response.status_code == 200
+
+     payload = response.json()
+     assert payload["session_id"]
+
+     observation = payload["observation"]
+     state = payload["state"]
+     assert observation["budget_spent"] == 0
+     assert observation["question_idx"] == 0
+     assert observation["done"] is False
+     assert observation["tools_used_this_question"] == []
+     assert state["budget_spent"] == 0
+     assert state["current_question_idx"] == 0
+     assert state["step_count"] == 0
+
+
+ def test_step_charges_the_right_cost(app_client):
+     reset = app_client.post("/reset", json={"seed": 42})
+     session_id = reset.json()["session_id"]
+
+     response = app_client.post(
+         "/step",
+         params={"session_id": session_id},
+         json={"tool_id": "calculator", "expression": "2 + 2"},
+     )
+     assert response.status_code == 200
+
+     payload = response.json()
+     assert payload["reward"] == -0.1
+     assert payload["state"]["budget_spent"] == 0.1
+     assert payload["observation"]["last_tool_result"]["tool_id"] == "calculator"
+     assert payload["observation"]["last_tool_result"]["cost"] == 0.1
+
+
+ def test_commit_advances_question_state_correctly(app_client):
+     reset = app_client.post("/reset", json={"seed": 42})
+     session_id = reset.json()["session_id"]
+
+     response = app_client.post(
+         "/step",
+         params={"session_id": session_id},
+         json={"tool_id": "commit", "answer": "4"},
+     )
+     assert response.status_code == 200
+
+     payload = response.json()
+     assert payload["done"] is False
+     assert payload["state"]["questions_answered"] == 1
+     assert payload["state"]["current_question_idx"] == 1
+     assert payload["observation"]["question_done"] is True
+
+
+ def test_episode_termination_and_cleanup_behave_as_expected(app_client):
+     reset = app_client.post("/reset", json={"seed": 42})
+     session_id = reset.json()["session_id"]
+
+     first = app_client.post(
+         "/step",
+         params={"session_id": session_id},
+         json={"tool_id": "commit", "answer": "4"},
+     )
+     assert first.status_code == 200
+
+     second = app_client.post(
+         "/step",
+         params={"session_id": session_id},
+         json={"tool_id": "commit", "answer": "4"},
+     )
+     assert second.status_code == 200
+     assert second.json()["done"] is True
+
+     follow_up = app_client.post(
+         "/step",
+         params={"session_id": session_id},
+         json={"tool_id": "commit", "answer": "4"},
+     )
+     assert follow_up.status_code == 404
tests/test_code_executor.py ADDED
@@ -0,0 +1,45 @@
+ from __future__ import annotations
+
+ from env.models import OrchestratorAction
+ from tools.code_executor import code_executor_tool
+
+
+ def test_code_executor_empty_input():
+     result = code_executor_tool(OrchestratorAction(tool_id="code_executor"))
+     assert result.error == "empty"
+     assert "No code provided" in result.output
+
+
+ def test_code_executor_runtime_error():
+     result = code_executor_tool(
+         OrchestratorAction(tool_id="code_executor", code_snippet="print(1 / 0)")
+     )
+     assert result.error is not None
+     assert "division by zero" in result.output.lower()
+
+
+ def test_code_executor_blocks_imports():
+     result = code_executor_tool(
+         OrchestratorAction(tool_id="code_executor", code_snippet="import os\nprint('hi')")
+     )
+     assert result.error == "sandbox_violation"
+     assert "import statements are blocked" in result.output
+
+
+ def test_code_executor_blocks_unsafe_builtins():
+     result = code_executor_tool(
+         OrchestratorAction(tool_id="code_executor", code_snippet="open('tmp.txt', 'w')")
+     )
+     assert result.error == "sandbox_violation"
+     assert "name 'open' is blocked" in result.output
+
+
+ def test_code_executor_blocks_escape_attempts():
+     result = code_executor_tool(
+         OrchestratorAction(
+             tool_id="code_executor",
+             code_snippet="().__class__.__mro__[1].__subclasses__()",
+         )
+     )
+     assert result.error == "sandbox_violation"
+     assert "__class__" in result.output or "__subclasses__" in result.output
tests/test_tools.py ADDED
@@ -0,0 +1,64 @@
+ from __future__ import annotations
+
+ from urllib.error import HTTPError
+
+ from env.models import OrchestratorAction
+ from tools.calculator import calculator_tool
+ from tools.code_executor import code_executor_tool
+ from tools.commit import commit_tool
+ from tools.llm_reason import llm_reason_tool
+ from tools.ceramic_search import make_search_tool
+ from tools.wiki_lookup import wiki_lookup_tool
+
+
+ def test_calculator_happy_path():
+     result = calculator_tool(OrchestratorAction(tool_id="calculator", expression="2 + 2 * 3"))
+     assert result.output == "8"
+     assert result.error is None
+
+
+ def test_calculator_invalid_input():
+     result = calculator_tool(OrchestratorAction(tool_id="calculator", expression="open('x')"))
+     assert result.error is not None
+     assert result.output.startswith("[Calc error:")
+
+
+ def test_search_fallback_is_deterministic(monkeypatch):
+     monkeypatch.delenv("CERAMIC_API_KEY", raising=False)
+     monkeypatch.delenv("SEE_CERAMIC_API_KEY", raising=False)
+
+     tool = make_search_tool()
+     action = OrchestratorAction(tool_id="ceramic_search", query="Eiffel Tower")
+
+     first = tool(action)
+     second = tool(action)
+
+     assert first.error is None
+     assert first.output == second.output
+     assert "Eiffel Tower" in first.output
+
+
+ def test_wiki_lookup_not_found(monkeypatch):
+     def fake_urlopen(*args, **kwargs):
+         raise HTTPError(url="https://example.com", code=404, msg="not found", hdrs=None, fp=None)
+
+     monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
+
+     result = wiki_lookup_tool(OrchestratorAction(tool_id="wiki_lookup", query="Definitely Not A Real Page"))
+     assert result.error == "not_found"
+     assert "no article found" in result.output.lower()
+
+
+ def test_llm_reason_no_api_key(monkeypatch):
+     monkeypatch.delenv("TOGETHER_API_KEY", raising=False)
+     monkeypatch.delenv("TOGETHER_KEY", raising=False)
+
+     result = llm_reason_tool(OrchestratorAction(tool_id="llm_reason", query="Explain gravity"))
+     assert result.error == "no_api_key"
+     assert "not configured" in result.output.lower()
+
+
+ def test_commit_passthrough_behavior():
+     result = commit_tool(OrchestratorAction(tool_id="commit", answer=" 1889 "))
+     assert result.tool_id == "commit"
+     assert result.output == "Committed answer: 1889"
tools/__init__.py CHANGED
@@ -1,29 +1,20 @@
1
- """Tool registry for CostAwareToolEnv.
2
-
3
- Each tool is a callable: (action: OrchestratorAction) -> ToolResult
4
- """
5
  from __future__ import annotations
6
 
7
- from typing import Callable, Dict
8
-
9
- from env.config import EnvConfig
10
- from env.models import OrchestratorAction, ToolResult
11
-
12
- from .calculator import calculator_tool
13
- from .ceramic_search import make_search_tool
14
- from .code_executor import code_executor_tool
15
- from .commit import commit_tool
16
- from .llm_reason import llm_reason_tool
17
- from .wiki_lookup import wiki_lookup_tool
18
-
19
 
20
- def build_tool_registry(config: EnvConfig | None = None) -> Dict[str, Callable]:
21
- """Return a mapping of tool_id → tool function."""
22
- return {
23
- "ceramic_search": make_search_tool(),
24
- "calculator": calculator_tool,
25
- "wiki_lookup": wiki_lookup_tool,
26
- "code_executor": code_executor_tool,
27
- "llm_reason": llm_reason_tool,
28
- "commit": commit_tool,
29
- }
 
+ """Public tool layer for CostAwareToolEnv."""
  from __future__ import annotations

+ from .runtime import (
+     ToolSpec,
+     build_tool_catalog,
+     build_tool_registry,
+     catalog_as_dicts,
+     dispatch_tool,
+     validate_tool_costs,
+ )

+ __all__ = [
+     "ToolSpec",
+     "build_tool_catalog",
+     "build_tool_registry",
+     "catalog_as_dicts",
+     "dispatch_tool",
+     "validate_tool_costs",
+ ]
 
 
tools/calculator.py CHANGED
@@ -24,6 +24,15 @@ _SAFE_OPS = {
24
  ast.UAdd: operator.pos,
25
  }
26
 
 
 
 
 
 
 
 
 
 
27
  _SAFE_FUNCS: dict[str, Any] = {
28
  "abs": abs, "round": round, "min": min, "max": max,
29
  "sqrt": math.sqrt, "log": math.log, "log2": math.log2,
@@ -53,6 +62,17 @@ def _safe_eval(node: ast.AST) -> Any:
53
  if op_type not in _SAFE_OPS:
54
  raise ValueError(f"Unsupported unary: {op_type.__name__}")
55
  return _SAFE_OPS[op_type](_safe_eval(node.operand))
 
 
 
 
 
 
 
 
 
 
 
56
  if isinstance(node, ast.Call):
57
  func = _safe_eval(node.func)
58
  if not callable(func):
 
24
  ast.UAdd: operator.pos,
25
  }
26
 
27
+ _SAFE_COMPARISONS = {
28
+ ast.Eq: operator.eq,
29
+ ast.NotEq: operator.ne,
30
+ ast.Lt: operator.lt,
31
+ ast.LtE: operator.le,
32
+ ast.Gt: operator.gt,
33
+ ast.GtE: operator.ge,
34
+ }
35
+
36
  _SAFE_FUNCS: dict[str, Any] = {
37
  "abs": abs, "round": round, "min": min, "max": max,
38
  "sqrt": math.sqrt, "log": math.log, "log2": math.log2,
 
62
  if op_type not in _SAFE_OPS:
63
  raise ValueError(f"Unsupported unary: {op_type.__name__}")
64
  return _SAFE_OPS[op_type](_safe_eval(node.operand))
65
+ if isinstance(node, ast.Compare):
66
+ left = _safe_eval(node.left)
67
+ for op, comparator in zip(node.ops, node.comparators):
68
+ op_type = type(op)
69
+ if op_type not in _SAFE_COMPARISONS:
70
+ raise ValueError(f"Unsupported comparison: {op_type.__name__}")
71
+ right = _safe_eval(comparator)
72
+ if not _SAFE_COMPARISONS[op_type](left, right):
73
+ return False
74
+ left = right
75
+ return True
76
  if isinstance(node, ast.Call):
77
  func = _safe_eval(node.func)
78
  if not callable(func):
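Review note: the new `ast.Compare` branch evaluates chained comparisons pairwise, so `a < b < c` becomes `a < b` and then `b < c`, stopping at the first false link. The standalone sketch below illustrates that loop under simplifying assumptions: it accepts only numeric literals as operands, and `eval_comparison` and `_value` are hypothetical names rather than functions from `tools/calculator.py`.

```python
import ast
import operator

# Whitelist of comparison operators, mirroring _SAFE_COMPARISONS in the diff.
_COMPARISONS = {
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
}


def _value(node: ast.AST) -> float:
    """Accept only numeric literals (a simplification for this sketch)."""
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    raise ValueError(f"Unsupported operand: {type(node).__name__}")


def eval_comparison(expression: str) -> bool:
    """Evaluate a chained comparison such as '1 < 2 <= 2' without eval()."""
    node = ast.parse(expression, mode="eval").body
    if not isinstance(node, ast.Compare):
        raise ValueError("expected a comparison expression")
    left = _value(node.left)
    for op, comparator in zip(node.ops, node.comparators):
        op_type = type(op)
        if op_type not in _COMPARISONS:
            raise ValueError(f"Unsupported comparison: {op_type.__name__}")
        right = _value(comparator)
        if not _COMPARISONS[op_type](left, right):
            return False  # short-circuit on the first failing link
        left = right  # the right operand becomes the next left operand
    return True
```

Operators outside the whitelist, such as `in` or `is`, raise `ValueError` instead of being evaluated.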
tools/code_executor.py CHANGED
@@ -1,35 +1,169 @@
1
  """Restricted Python code executor.
2
 
3
- Runs code in a sandboxed namespace — blocks os/sys/subprocess imports
4
- and captures stdout. Intended for math / algorithmic tasks.
 
 
 
 
 
 
5
  """
6
  from __future__ import annotations
7
 
8
- import io
9
- import sys
10
  import contextlib
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  from env.models import OrchestratorAction, ToolResult
13
 
14
- _BLOCKED_MODULES = frozenset({
15
- "os", "sys", "subprocess", "socket", "shutil", "pathlib",
16
- "importlib", "builtins", "ctypes", "multiprocessing", "threading",
17
- "signal", "pty", "fcntl", "resource", "gc", "inspect",
18
- })
19
-
20
  _MAX_OUTPUT_CHARS = 2000
21
 
22
 
23
- class _BlockedImport:
24
- """Raise on any import of blocked modules."""
25
- def __init__(self, original_import):
26
- self._orig = original_import
27
 
28
- def __call__(self, name, *args, **kwargs):
29
- base = name.split(".")[0]
30
- if base in _BLOCKED_MODULES:
31
- raise ImportError(f"Module '{name}' is not allowed in code_executor")
32
- return self._orig(name, *args, **kwargs)
33
 
34
 
35
  def code_executor_tool(action: OrchestratorAction) -> ToolResult:
@@ -38,26 +172,21 @@ def code_executor_tool(action: OrchestratorAction) -> ToolResult:
38
  return ToolResult(tool_id="code_executor", output="[No code provided]", error="empty")
39
 
40
  stdout_buf = io.StringIO()
41
- safe_globals = {
42
- "__builtins__": {
43
- k: v for k, v in __builtins__.items() # type: ignore[union-attr]
44
- if k not in ("open", "exec", "eval", "compile", "__import__")
45
- } if isinstance(__builtins__, dict) else {
46
- k: getattr(__builtins__, k) for k in dir(__builtins__)
47
- if k not in ("open", "exec", "eval", "compile", "__import__")
48
- },
49
- "__import__": _BlockedImport(__import__),
50
- "print": lambda *a, **kw: print(*a, **kw, file=stdout_buf),
51
  }
52
 
53
  try:
 
 
54
  with contextlib.redirect_stdout(stdout_buf):
55
- exec(compile(code, "<code_executor>", "exec"), safe_globals) # noqa: S102
56
  output = stdout_buf.getvalue()[:_MAX_OUTPUT_CHARS] or "[Code ran, no output]"
57
  return ToolResult(tool_id="code_executor", output=output)
 
 
58
  except Exception as exc:
59
- return ToolResult(
60
- tool_id="code_executor",
61
- output=f"[Execution error: {exc}]",
62
- error=str(exc),
63
- )
 
  """Restricted Python code executor.

+ The executor is intentionally narrow:
+ - import statements are rejected before execution
+ - dangerous builtin names are blocked
+ - dunder attribute access is blocked to prevent object graph escapes
+ - only a curated builtin/module surface is exposed
+
+ This keeps the tool useful for HumanEval-style tasks while making the
+ security boundaries explicit and testable.
  """
  from __future__ import annotations

+ import ast
+ import builtins
  import contextlib
+ import io
+ import math
+ import operator
+ import collections
+ import functools
+ import itertools
+ import statistics
+ import heapq
+ import bisect
+ import fractions
+ import decimal
+ import re
+ from typing import Any, Dict

  from env.models import OrchestratorAction, ToolResult

  _MAX_OUTPUT_CHARS = 2000

+ _BLOCKED_NAMES = {
+     "__builtins__",
+     "__import__",
+     "open",
+     "exec",
+     "eval",
+     "compile",
+     "globals",
+     "locals",
+     "vars",
+     "dir",
+     "getattr",
+     "setattr",
+     "delattr",
+     "input",
+     "help",
+     "type",
+     "object",
+     "super",
+     "memoryview",
+     "breakpoint",
+     "exit",
+     "quit",
+ }
+
+ _BLOCKED_ATTRS = {
+     "__class__",
+     "__base__",
+     "__bases__",
+     "__subclasses__",
+     "__mro__",
+     "__globals__",
+     "__code__",
+     "__closure__",
+     "__dict__",
+     "__getattribute__",
+     "__getattr__",
+     "__setattr__",
+     "__delattr__",
+     "__reduce__",
+     "__reduce_ex__",
+     "__func__",
+     "__self__",
+     "__module__",
+ }
+
+ _SAFE_BUILTIN_NAMES = {
+     "abs",
+     "all",
+     "any",
+     "bool",
+     "chr",
+     "dict",
+     "enumerate",
+     "float",
+     "int",
+     "isinstance",
+     "issubclass",
+     "len",
+     "list",
+     "map",
+     "max",
+     "min",
+     "pow",
+     "range",
+     "repr",
+     "reversed",
+     "round",
+     "set",
+     "slice",
+     "sorted",
+     "str",
+     "sum",
+     "tuple",
+     "zip",
+     "divmod",
+     "ord",
+     "Exception",
+     "ValueError",
+     "RuntimeError",
+     "TypeError",
+     "KeyError",
+     "IndexError",
+     "AssertionError",
+     "ZeroDivisionError",
+     "object",
+ }
+
+ _SAFE_MODULES: Dict[str, Any] = {
+     "math": math,
+     "collections": collections,
+     "functools": functools,
+     "itertools": itertools,
+     "statistics": statistics,
+     "heapq": heapq,
+     "bisect": bisect,
+     "fractions": fractions,
+     "decimal": decimal,
+     "re": re,
+     "operator": operator,
+ }
+
+
+ class SandboxViolation(ValueError):
+     """Raised when code tries to cross the sandbox boundary."""
+
+
+ def _validate_tree(tree: ast.AST) -> None:
+     for node in ast.walk(tree):
+         if isinstance(node, (ast.Import, ast.ImportFrom)):
+             raise SandboxViolation("import statements are blocked")
+         if isinstance(node, (ast.Global, ast.Nonlocal)):
+             raise SandboxViolation("global and nonlocal are blocked")
+         if isinstance(node, ast.Attribute):
+             if node.attr.startswith("__") or node.attr in _BLOCKED_ATTRS:
+                 raise SandboxViolation(f"attribute access '{node.attr}' is blocked")
+         if isinstance(node, ast.Name):
+             if node.id.startswith("__") or node.id in _BLOCKED_NAMES:
+                 raise SandboxViolation(f"name '{node.id}' is blocked")
+
+
+ def _safe_builtins(stdout_buf: io.StringIO) -> Dict[str, Any]:
+     safe: Dict[str, Any] = {name: getattr(builtins, name) for name in _SAFE_BUILTIN_NAMES}

+     def safe_print(*args: Any, **kwargs: Any) -> None:
+         kwargs = dict(kwargs)
+         kwargs.pop("file", None)
+         builtins.print(*args, **kwargs, file=stdout_buf)

+     safe["print"] = safe_print
+     safe["__build_class__"] = builtins.__build_class__
+     return safe


  def code_executor_tool(action: OrchestratorAction) -> ToolResult:
 
          return ToolResult(tool_id="code_executor", output="[No code provided]", error="empty")

      stdout_buf = io.StringIO()
+     safe_globals: Dict[str, Any] = {
+         "__builtins__": _safe_builtins(stdout_buf),
+         "__name__": "__code_executor__",
+         "__package__": None,
+         **_SAFE_MODULES,
      }

      try:
+         tree = ast.parse(code, mode="exec")
+         _validate_tree(tree)
          with contextlib.redirect_stdout(stdout_buf):
+             exec(compile(tree, "<code_executor>", "exec"), safe_globals)  # noqa: S102
          output = stdout_buf.getvalue()[:_MAX_OUTPUT_CHARS] or "[Code ran, no output]"
          return ToolResult(tool_id="code_executor", output=output)
+     except SandboxViolation as exc:
+         return ToolResult(tool_id="code_executor", output=f"[Sandbox blocked: {exc}]", error="sandbox_violation")
      except Exception as exc:
+         return ToolResult(tool_id="code_executor", output=f"[Execution error: {exc}]", error=str(exc))
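Review note: the executor now rejects code statically, by walking the parsed AST before anything runs. The sketch below isolates that idea with a much smaller blocklist. `check_source`, `is_allowed`, and `BLOCKED_NAMES` are hypothetical names for illustration, and the rules are a reduced version of `_validate_tree`, not a copy of it.

```python
import ast

# Reduced blocklist for illustration; the diff uses a longer set.
BLOCKED_NAMES = {"open", "exec", "eval", "compile", "getattr"}


class SandboxViolation(ValueError):
    """Raised when source text contains a construct the checker rejects."""


def check_source(source: str) -> None:
    """Parse the source and reject imports, dunder access, and blocked names."""
    tree = ast.parse(source, mode="exec")
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            raise SandboxViolation("import statements are blocked")
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            raise SandboxViolation(f"attribute access '{node.attr}' is blocked")
        if isinstance(node, ast.Name):
            if node.id.startswith("__") or node.id in BLOCKED_NAMES:
                raise SandboxViolation(f"name '{node.id}' is blocked")


def is_allowed(source: str) -> bool:
    """Return True when the checker accepts the source, False otherwise."""
    try:
        check_source(source)
    except SandboxViolation:
        return False
    return True
```

A check like this only sees names and attributes that appear literally in the syntax tree, so it is probably best treated as one layer of defense alongside the restricted builtins, not as a complete sandbox on its own.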
 
 
 
 
tools/runtime.py ADDED
@@ -0,0 +1,159 @@
+ """Shared tool catalog and dispatch helpers for CostAwareToolEnv.
+
+ This module keeps the tool contract explicit:
+ - the catalog describes every tool, its purpose, and its input field
+ - registry validation catches config drift early
+ - dispatch normalizes failures into ToolResult objects instead of
+   letting exceptions leak through the environment loop
+ """
+ from __future__ import annotations
+
+ from dataclasses import asdict, dataclass
+ from typing import Any, Callable, Dict, List, Mapping
+
+ from env.config import EnvConfig
+ from env.models import OrchestratorAction, TOOL_IDS, ToolResult
+
+
+ @dataclass(frozen=True)
+ class ToolSpec:
+     tool_id: str
+     label: str
+     purpose: str
+     input_field: str
+     cost: float
+     notes: str
+
+
+ _TOOL_SPEC_TEMPLATES: Dict[str, Dict[str, str]] = {
+     "ceramic_search": {
+         "label": "Ceramic web search",
+         "purpose": "Web retrieval for multi-hop factual QA",
+         "input_field": "query",
+         "notes": "Falls back to deterministic offline search when Ceramic credentials are unavailable.",
+     },
+     "wiki_lookup": {
+         "label": "Wikipedia lookup",
+         "purpose": "Entity facts, definitions, and short summaries",
+         "input_field": "query",
+         "notes": "Returns an explicit not-found or HTTP error result instead of crashing.",
+     },
+     "calculator": {
+         "label": "Calculator",
+         "purpose": "Arithmetic and symbolic math",
+         "input_field": "expression",
+         "notes": "Uses a restricted AST evaluator with comparisons and common math functions.",
+     },
+     "code_executor": {
+         "label": "Python executor",
+         "purpose": "HumanEval-style coding tasks",
+         "input_field": "code_snippet",
+         "notes": "Sandboxed exec with blocked imports, dunder attribute access, and unsafe builtins.",
+     },
+     "llm_reason": {
+         "label": "LLM reasoning",
+         "purpose": "Costly model-backed reasoning on hard problems",
+         "input_field": "query",
+         "notes": "Returns a clear no_api_key error when Together is unavailable.",
+     },
+     "commit": {
+         "label": "Commit answer",
+         "purpose": "Submit the final answer and advance the episode",
+         "input_field": "answer",
+         "notes": "Pass-through only; grading happens inside the environment.",
+     },
+ }
+
+
+ def validate_tool_costs(config: EnvConfig) -> None:
+     """Fail fast if the configured cost map drifts from the canonical tool set."""
+     missing = [tool_id for tool_id in TOOL_IDS if tool_id not in config.tool_costs]
+     if missing:
+         raise ValueError(f"EnvConfig.tool_costs is missing required tools: {missing}")
+
+     negative = {tool_id: cost for tool_id, cost in config.tool_costs.items() if cost < 0}
+     if negative:
+         raise ValueError(f"Tool costs must be non-negative: {negative}")
+
+
+ def build_tool_catalog(config: EnvConfig | None = None) -> List[ToolSpec]:
+     """Return the canonical ordered catalog used by the UI and docs."""
+     cfg = config or EnvConfig()
+     validate_tool_costs(cfg)
+
+     catalog: List[ToolSpec] = []
+     for tool_id in TOOL_IDS:
+         template = _TOOL_SPEC_TEMPLATES[tool_id]
+         catalog.append(
+             ToolSpec(
+                 tool_id=tool_id,
+                 label=template["label"],
+                 purpose=template["purpose"],
+                 input_field=template["input_field"],
+                 cost=cfg.tool_costs[tool_id],
+                 notes=template["notes"],
+             )
+         )
+     return catalog
+
+
+ def build_tool_registry(config: EnvConfig | None = None) -> Dict[str, Callable[[OrchestratorAction], ToolResult]]:
+     """Return the canonical mapping of tool_id -> tool callable."""
+     from .calculator import calculator_tool
+     from .ceramic_search import make_search_tool
+     from .code_executor import code_executor_tool
+     from .commit import commit_tool
+     from .llm_reason import llm_reason_tool
+     from .wiki_lookup import wiki_lookup_tool
+
+     cfg = config or EnvConfig()
+     validate_tool_costs(cfg)
+
+     return {
+         "ceramic_search": make_search_tool(),
+         "calculator": calculator_tool,
+         "wiki_lookup": wiki_lookup_tool,
+         "code_executor": code_executor_tool,
+         "llm_reason": llm_reason_tool,
+         "commit": commit_tool,
+     }
+
+
+ def dispatch_tool(
+     tool_id: str,
+     action: OrchestratorAction,
+     registry: Mapping[str, Callable[[OrchestratorAction], ToolResult]],
+ ) -> ToolResult:
+     """Call a tool and normalize missing-tool and crash cases into ToolResult."""
+     tool_fn = registry.get(tool_id)
+     if tool_fn is None:
+         return ToolResult(
+             tool_id=tool_id,
+             output="[Tool not available in this environment]",
+             error="not_available",
+         )
+
+     try:
+         result = tool_fn(action)
+     except Exception as exc:  # pragma: no cover - defensive wrapper
+         return ToolResult(
+             tool_id=tool_id,
+             output=f"[Error: {exc}]",
+             error=str(exc),
+         )
+
+     if not isinstance(result, ToolResult):
+         return ToolResult(
+             tool_id=tool_id,
+             output=f"[Tool error: unexpected return type {type(result).__name__}]",
+             error="invalid_return_type",
+         )
+
+     if result.tool_id != tool_id:
+         result = result.model_copy(update={"tool_id": tool_id})
+     return result
+
+
+ def catalog_as_dicts(config: EnvConfig | None = None) -> List[dict[str, Any]]:
+     """Convenience helper for JSON serialization."""
+     return [asdict(spec) for spec in build_tool_catalog(config)]
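To see how `dispatch_tool` turns each failure mode into a result object, the sketch below mirrors its logic against a stand-in `ToolResult`. It assumes a plain frozen dataclass in place of `env.models.ToolResult` (the `model_copy` call above suggests the real one is a Pydantic model), and the `echo`, `boom`, and `bad` tools are hypothetical.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Mapping, Optional


@dataclass(frozen=True)
class ToolResult:
    """Stand-in for env.models.ToolResult (assumed fields only)."""
    tool_id: str
    output: str
    error: Optional[str] = None


def dispatch_tool(
    tool_id: str,
    action: Any,
    registry: Mapping[str, Callable[[Any], Any]],
) -> ToolResult:
    tool_fn = registry.get(tool_id)
    if tool_fn is None:
        # Unknown tool: report it instead of raising KeyError.
        return ToolResult(tool_id, "[Tool not available in this environment]", "not_available")
    try:
        result = tool_fn(action)
    except Exception as exc:
        # Tool crashed: fold the exception into the result.
        return ToolResult(tool_id, f"[Error: {exc}]", str(exc))
    if not isinstance(result, ToolResult):
        return ToolResult(
            tool_id,
            f"[Tool error: unexpected return type {type(result).__name__}]",
            "invalid_return_type",
        )
    if result.tool_id != tool_id:
        # Dataclass equivalent of model_copy(update={"tool_id": tool_id}).
        result = replace(result, tool_id=tool_id)
    return result


def crashing_tool(action: Any) -> ToolResult:
    raise RuntimeError("upstream timeout")


# Hypothetical registry exercising each branch.
registry = {
    "echo": lambda action: ToolResult("echo", str(action)),
    "boom": crashing_tool,
    "bad": lambda action: "plain string",
}
```

With this registry, `dispatch_tool("missing", "hi", registry).error` is `"not_available"` and `dispatch_tool("boom", "hi", registry).error` is `"upstream timeout"`, so the environment loop can branch on `error` rather than wrapping every call in its own `try`.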
tools/wiki_lookup.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
  
  import urllib.parse
  import urllib.request
+ import urllib.error
  import json
  
  from env.models import OrchestratorAction, ToolResult