Andrew Lara committed on
Commit ·
fb0cc18
1
Parent(s): 98c3ce1
Tighten tool routing and executor safety
- README.md +15 -1
- RESEARCH.md +11 -8
- app.py +159 -115
- client.py +2 -121
- env/environment.py +12 -20
- environment.py +5 -303
- requirements.txt +1 -0
- tests/conftest.py +45 -0
- tests/test_app.py +82 -0
- tests/test_code_executor.py +45 -0
- tests/test_tools.py +64 -0
- tools/__init__.py +17 -26
- tools/calculator.py +20 -0
- tools/code_executor.py +164 -35
- tools/runtime.py +159 -0
- tools/wiki_lookup.py +1 -0
README.md
CHANGED
|
@@ -108,6 +108,10 @@ Execute one tool call. Pass `session_id` (from `/reset`) as a query param to sup
|
|
| 108 |
{ "tool_id": "commit", "answer": "1889" }
|
| 109 |
```
|
| 110 |
|
| 111 |
### `GET /health`
|
| 112 |
|
| 113 |
Returns `{"status": "ok"}`.
|
|
@@ -123,6 +127,7 @@ claude_toolOrchestrator/
|
|
| 123 |
├── openenv.yaml # OpenEnv deployment spec
|
| 124 |
├── requirements.txt # Python dependencies
|
| 125 |
├── .env.example # Key template (copy → .env, never commit .env)
|
|
|
|
| 126 |
│
|
| 127 |
├── env/ # ── Core RL environment ──────────────────────────
|
| 128 |
│ ├── environment.py # CostAwareToolEnvironment: reset() + step()
|
|
@@ -141,7 +146,7 @@ claude_toolOrchestrator/
|
|
| 141 |
│ ├── ceramic_search.py # Web search (Ceramic AI API)
|
| 142 |
│ ├── wiki_lookup.py # Wikipedia REST API, first paragraph
|
| 143 |
│ ├── calculator.py # Safe AST-based math evaluator (no exec)
|
| 144 |
-
│ ├── code_executor.py # Sandboxed Python exec (
|
| 145 |
│ ├── llm_reason.py # Together AI chain-of-thought (graceful fallback)
|
| 146 |
│ └── commit.py # Answer pass-through; grading runs in environment
|
| 147 |
│
|
|
@@ -164,6 +169,15 @@ claude_toolOrchestrator/
|
|
| 164 |
|
| 165 |
If no Ceramic key is set, `ceramic_search` falls back to deterministic offline results; all other tools work without any key.
|
| 166 |
|
| 167 |
---
|
| 168 |
|
| 169 |
## Running baselines
|
|
|
|
| 108 |
{ "tool_id": "commit", "answer": "1889" }
|
| 109 |
```
|
| 110 |
|
| 111 |
+
### `GET /tools`
|
| 112 |
+
|
| 113 |
+
Returns the canonical tool manifest with each tool's label, purpose, input field, cost, and safety notes.
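As a rough illustration of what such a manifest could look like, the sketch below serializes a small catalog into JSON-friendly dictionaries. The `ToolSpec` dataclass, the `specs_to_manifest` helper, and the field names `input_field` and `safety_notes` are assumptions made for this example; only `tool_id`, `label`, `purpose`, and `cost` appear in the commit's own code.

```python
from dataclasses import asdict, dataclass
from typing import Dict, List


@dataclass(frozen=True)
class ToolSpec:
    # Hypothetical stand-in for the catalog entry type. `input_field` and
    # `safety_notes` are assumed names, not confirmed by the repository.
    tool_id: str
    label: str
    purpose: str
    input_field: str
    cost: float
    safety_notes: str = ""


def specs_to_manifest(catalog: List[ToolSpec]) -> List[Dict[str, object]]:
    """Serialize catalog entries into plain dictionaries for a JSON response."""
    return [asdict(spec) for spec in catalog]


# Two entries using the costs and purposes listed elsewhere in this commit.
EXAMPLE_CATALOG = [
    ToolSpec("calculator", "calculator", "Arithmetic / math", "expression", 0.1),
    ToolSpec("commit", "commit", "Submit answer", "answer", 0.0),
]

manifest = specs_to_manifest(EXAMPLE_CATALOG)
```

A client could read such a manifest once and use the input field of each entry to decide which request field to populate for `/step`.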
|
| 114 |
+
|
| 115 |
### `GET /health`
|
| 116 |
|
| 117 |
Returns `{"status": "ok"}`.
|
|
|
|
| 127 |
├── openenv.yaml # OpenEnv deployment spec
|
| 128 |
├── requirements.txt # Python dependencies
|
| 129 |
├── .env.example # Key template (copy → .env, never commit .env)
|
| 130 |
+
├── tools/runtime.py # Tool catalog, validation, and explicit dispatch
|
| 131 |
│
|
| 132 |
├── env/ # ── Core RL environment ──────────────────────────
|
| 133 |
│ ├── environment.py # CostAwareToolEnvironment: reset() + step()
|
|
|
|
| 146 |
│ ├── ceramic_search.py # Web search (Ceramic AI API)
|
| 147 |
│ ├── wiki_lookup.py # Wikipedia REST API, first paragraph
|
| 148 |
│ ├── calculator.py # Safe AST-based math evaluator (no exec)
|
| 149 |
+
│ ├── code_executor.py # Sandboxed Python exec (blocked imports, dunder attrs)
|
| 150 |
│ ├── llm_reason.py # Together AI chain-of-thought (graceful fallback)
|
| 151 |
│ └── commit.py # Answer pass-through; grading runs in environment
|
| 152 |
│
|
|
|
|
| 169 |
|
| 170 |
If no Ceramic key is set, `ceramic_search` falls back to deterministic offline results; all other tools work without any key.
|
| 171 |
|
| 172 |
+
The Python executor is intentionally narrow:
|
| 173 |
+
|
| 174 |
+
- import statements are rejected
|
| 175 |
+
- obvious sandbox-escape names such as `open`, `eval`, `globals`, and `__import__` are blocked
|
| 176 |
+
- dunder attribute access such as `.__class__` and `.__subclasses__()` is blocked
|
| 177 |
+
- only a curated builtin/module surface is exposed
|
| 178 |
+
|
| 179 |
+
That keeps the tool usable for intended coding tasks without turning it into a hidden general-purpose shell.
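The first three rules above can be sketched as a static pass over the parsed source. This is a minimal illustration, not the code in `tools/code_executor.py`: the `find_violations` helper and the exact contents of `BLOCKED_NAMES` are assumptions based on the names listed above.

```python
import ast
from typing import List

# Assumed blocklist, taken from the names mentioned in the README text.
BLOCKED_NAMES = {"open", "eval", "exec", "globals", "__import__"}


def find_violations(source: str) -> List[str]:
    """Return a description of each policy violation found in `source`."""
    violations = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            violations.append("import statement")
        elif isinstance(node, ast.Name) and node.id in BLOCKED_NAMES:
            violations.append(f"blocked name: {node.id}")
        elif (
            isinstance(node, ast.Attribute)
            and node.attr.startswith("__")
            and node.attr.endswith("__")
        ):
            violations.append(f"dunder attribute: {node.attr}")
    return violations
```

Because the check runs on the syntax tree before anything executes, a snippet with any violation can be rejected without running a single line of it.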
|
| 180 |
+
|
| 181 |
---
|
| 182 |
|
| 183 |
## Running baselines
|
RESEARCH.md
CHANGED
|
@@ -147,7 +147,7 @@ action.query = "William Shakespeare"
|
|
| 147 |
|
| 148 |
### `calculator` — Cost: **0.1**
|
| 149 |
|
| 150 |
-
**What it does:** Evaluates a math expression safely using Python's `ast` (Abstract Syntax Tree) module. The expression is parsed into a tree structure, and only pre-approved operations are allowed (addition, subtraction, multiplication, division, power, modulo, and common math functions like `sqrt`, `log`, `sin`, `cos`).
|
| 151 |
|
| 152 |
**Why not just use `eval()`?** Because `eval("__import__('os').system('rm -rf /')")` would delete your hard drive. The AST approach means the code is never executed — it's parsed into a data structure and we only compute what we explicitly allow.
|
| 153 |
|
|
@@ -168,9 +168,9 @@ action.expression = "import os" # BLOCKED — not a valid math expression
|
|
| 168 |
|
| 169 |
### `code_executor` — Cost: **0.3**
|
| 170 |
|
| 171 |
-
**What it does:** Runs
|
| 172 |
|
| 173 |
-
**Security model:** Blocks
|
| 174 |
|
| 175 |
**Best for:** HumanEval coding tasks where the agent needs to actually run code to verify correctness.
|
| 176 |
|
|
@@ -197,6 +197,8 @@ print(fibonacci(10))
|
|
| 197 |
|
| 198 |
**Graceful fallback:** If `TOGETHER_API_KEY` is not set, returns a clear error message instead of crashing. The agent learns to avoid this tool when it's unavailable.
|
| 199 |
|
|
|
|
|
|
|
| 200 |
---
|
| 201 |
|
| 202 |
### `commit` — Cost: **0.0**
|
|
@@ -503,9 +505,9 @@ A well-trained agent should exhibit these behaviors:
|
|
| 503 |
CostAwareToolEnv/
|
| 504 |
│
|
| 505 |
├── app.py
|
| 506 |
-
│ The FastAPI web server. Handles /reset, /step, /health, /web.
|
| 507 |
│ Multi-session: each /reset returns a session_id used in /step.
|
| 508 |
-
│
|
| 509 |
│
|
| 510 |
├── openenv.yaml
|
| 511 |
│ Deployment spec for the OpenEnv competition framework.
|
|
@@ -552,11 +554,12 @@ CostAwareToolEnv/
|
|
| 552 |
│ Returns flat List[Dict] with 'domain' key on each item.
|
| 553 |
│
|
| 554 |
├── tools/
|
| 555 |
-
│ ├──
|
|
|
|
| 556 |
│ ├── ceramic_search.py make_search_tool() factory wrapping CeramicClient
|
| 557 |
│ ├── wiki_lookup.py Wikipedia REST API, first paragraph
|
| 558 |
-
│ ├── calculator.py Safe AST-based math eval
|
| 559 |
-
│ ├── code_executor.py Sandboxed exec with blocked
|
| 560 |
│ ├── llm_reason.py Together AI API, graceful fallback
|
| 561 |
│ └── commit.py Pass-through; grading is in environment.py
|
| 562 |
│
|
|
|
|
| 147 |
|
| 148 |
### `calculator` — Cost: **0.1**
|
| 149 |
|
| 150 |
+
**What it does:** Evaluates a math expression safely using Python's `ast` (Abstract Syntax Tree) module. The expression is parsed into a tree structure, and only pre-approved operations are allowed (addition, subtraction, multiplication, division, power, modulo, comparisons, and common math functions like `sqrt`, `log`, `sin`, `cos`).
|
| 151 |
|
| 152 |
**Why not just use `eval()`?** Because `eval("__import__('os').system('rm -rf /')")` would delete your hard drive. The AST approach means the code is never executed — it's parsed into a data structure and we only compute what we explicitly allow.
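A minimal sketch of the allow-list idea, assuming a helper named `safe_eval` (hypothetical, not the function in `tools/calculator.py`) and omitting comparisons for brevity. Every node type that is not explicitly handled raises an error, so nothing outside the table can run.

```python
import ast
import math
import operator

# Only these operators and functions are ever computed.
BIN_OPS = {
    ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul,
    ast.Div: operator.truediv, ast.Pow: operator.pow, ast.Mod: operator.mod,
}
FUNCS = {"sqrt": math.sqrt, "log": math.log, "sin": math.sin, "cos": math.cos}


def safe_eval(expression: str) -> float:
    """Evaluate a math expression by walking its AST against an allow-list."""
    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if (isinstance(node, ast.Constant)
                and isinstance(node.value, (int, float))
                and not isinstance(node.value, bool)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in BIN_OPS:
            return BIN_OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -walk(node.operand)
        if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
                and node.func.id in FUNCS and not node.keywords):
            return FUNCS[node.func.id](*[walk(arg) for arg in node.args])
        # Anything else (attribute access, names, other calls) is rejected.
        raise ValueError(f"Unsupported expression element: {type(node).__name__}")

    return walk(ast.parse(expression, mode="eval"))
```

With this shape, `safe_eval("2 + 3 * 4")` computes a number, while the `__import__` string from the paragraph above is rejected with a `ValueError` before any call is made.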
|
| 153 |
|
|
|
|
| 168 |
|
| 169 |
### `code_executor` — Cost: **0.3**
|
| 170 |
|
| 171 |
+
**What it does:** Runs Python code in a sandboxed `exec()` environment for intended coding tasks. Captures whatever is printed to stdout and returns it as the result.
|
| 172 |
|
| 173 |
+
**Security model:** Blocks import statements, dangerous builtin names such as `open`, `eval`, `exec`, `globals`, and obvious object-graph escape paths such as dunder attribute traversal. Only a curated builtin/module surface is exposed.
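The "curated builtin surface" and stdout capture can be sketched as follows. This is an illustration under stated assumptions: `run_snippet` and the contents of `SAFE_BUILTINS` are hypothetical, and a restricted builtin dictionary on its own is not a security boundary, which is why the model above also blocks imports and dunder traversal.

```python
import contextlib
import io

# Assumed allow-list; the real curated surface is defined in the repository.
SAFE_BUILTINS = {
    "print": print, "range": range, "len": len,
    "sum": sum, "min": min, "max": max,
}


def run_snippet(code: str) -> str:
    """Execute `code` with only SAFE_BUILTINS visible and return its stdout."""
    buffer = io.StringIO()
    # Replacing __builtins__ hides names such as open() and __import__().
    sandbox_globals = {"__builtins__": SAFE_BUILTINS}
    with contextlib.redirect_stdout(buffer):
        exec(code, sandbox_globals)
    return buffer.getvalue()
```

A snippet that calls `open(...)` fails with `NameError` under this setup, because the name is simply absent from the globals it runs in.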
|
| 174 |
|
| 175 |
**Best for:** HumanEval coding tasks where the agent needs to actually run code to verify correctness.
|
| 176 |
|
|
|
|
| 197 |
|
| 198 |
**Graceful fallback:** If `TOGETHER_API_KEY` is not set, returns a clear error message instead of crashing. The agent learns to avoid this tool when it's unavailable.
|
| 199 |
|
| 200 |
+
**Tool routing note:** The environment exposes the canonical tool manifest at `GET /tools`, and tool dispatch normalizes missing-tool and tool-crash cases into explicit `ToolResult` errors. That keeps the OpenEnv-style contract stable even when a backing service is missing.
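The normalization described here can be sketched as a single dispatch function that always returns a `ToolResult`. The dataclass below is a stand-in with the field names used in this commit (`tool_id`, `cost`, `output`, `latency_s`, `error`); the `dispatch` signature itself is an assumption for illustration, not the API of `tools/runtime.py`.

```python
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class ToolResult:
    # Stand-in for the project's result model.
    tool_id: str
    cost: float
    output: str
    latency_s: float
    error: Optional[str] = None


def dispatch(
    tools: Dict[str, Callable[[Any], ToolResult]],
    tool_id: str,
    action: Any,
    cost: float,
) -> ToolResult:
    """Run a tool and convert missing-tool and crash cases into error results."""
    tool_fn = tools.get(tool_id)
    if tool_fn is None:
        return ToolResult(
            tool_id, cost, "[Tool not available in this environment]", 0.0, "not_available"
        )
    t0 = time.perf_counter()
    try:
        result = tool_fn(action)
        result.cost = cost  # the environment's configured cost is authoritative
        return result
    except Exception as exc:
        return ToolResult(
            tool_id, cost, f"[Error: {exc}]", time.perf_counter() - t0, str(exc)
        )
```

Callers never need a `try`/`except` around a tool call in this design; they inspect `error` on the returned result instead.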
|
| 201 |
+
|
| 202 |
---
|
| 203 |
|
| 204 |
### `commit` — Cost: **0.0**
|
|
|
|
| 505 |
CostAwareToolEnv/
|
| 506 |
│
|
| 507 |
├── app.py
|
| 508 |
+
│ The FastAPI web server. Handles /reset, /step, /health, /tools, /web.
|
| 509 |
│ Multi-session: each /reset returns a session_id used in /step.
|
| 510 |
+
│ Lazily loads the dataset and exposes the canonical tool manifest.
|
| 511 |
│
|
| 512 |
├── openenv.yaml
|
| 513 |
│ Deployment spec for the OpenEnv competition framework.
|
|
|
|
| 554 |
│ Returns flat List[Dict] with 'domain' key on each item.
|
| 555 |
│
|
| 556 |
├── tools/
|
| 557 |
+
│ ├── runtime.py Tool catalog, validation, and explicit dispatch
|
| 558 |
+
│ ├── __init__.py build_tool_registry() + tool manifest helpers
|
| 559 |
│ ├── ceramic_search.py make_search_tool() factory wrapping CeramicClient
|
| 560 |
│ ├── wiki_lookup.py Wikipedia REST API, first paragraph
|
| 561 |
+
│ ├── calculator.py Safe AST-based math eval with comparisons
|
| 562 |
+
│ ├── code_executor.py Sandboxed exec with blocked imports and dunder escapes
|
| 563 |
│ ├── llm_reason.py Together AI API, graceful fallback
|
| 564 |
│ └── commit.py Pass-through; grading is in environment.py
|
| 565 |
│
|
app.py
CHANGED
|
@@ -1,18 +1,20 @@
|
|
| 1 |
"""FastAPI server for CostAwareToolEnv.
|
| 2 |
|
| 3 |
Exposes the OpenEnv standard endpoints:
|
| 4 |
-
POST /reset
|
| 5 |
-
POST /step
|
| 6 |
-
GET /health
|
| 7 |
-
GET /
|
| 8 |
-
GET /
|
|
|
|
| 9 |
"""
|
| 10 |
from __future__ import annotations
|
| 11 |
|
|
|
|
| 12 |
import os
|
| 13 |
import uuid
|
| 14 |
from contextlib import asynccontextmanager
|
| 15 |
-
from typing import Any, Dict, Optional
|
| 16 |
|
| 17 |
from fastapi import FastAPI, HTTPException
|
| 18 |
from fastapi.responses import HTMLResponse
|
|
@@ -22,13 +24,9 @@ from data.loader import load_all
|
|
| 22 |
from env.config import EnvConfig
|
| 23 |
from env.environment import CostAwareToolEnvironment
|
| 24 |
from env.models import OrchestratorAction
|
| 25 |
-
from tools import build_tool_registry
|
| 26 |
|
| 27 |
|
| 28 |
-
# ---------------------------------------------------------------------------
|
| 29 |
-
# Request / response wrappers
|
| 30 |
-
# ---------------------------------------------------------------------------
|
| 31 |
-
|
| 32 |
class ResetRequest(BaseModel):
|
| 33 |
seed: Optional[int] = None
|
| 34 |
config_overrides: Optional[Dict[str, Any]] = None
|
|
@@ -36,27 +34,140 @@ class ResetRequest(BaseModel):
|
|
| 36 |
|
| 37 |
class StepRequest(BaseModel):
|
| 38 |
tool_id: str
|
| 39 |
-
query:
|
| 40 |
-
expression:
|
| 41 |
code_snippet: Optional[str] = None
|
| 42 |
-
answer:
|
| 43 |
-
metadata:
|
| 44 |
|
| 45 |
|
| 46 |
-
|
| 47 |
-
# App factory
|
| 48 |
-
# ---------------------------------------------------------------------------
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
|
| 55 |
-
|
| 56 |
sessions: Dict[str, CostAwareToolEnvironment] = {}
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
|
| 61 |
@asynccontextmanager
|
| 62 |
async def lifespan(app: FastAPI):
|
|
@@ -74,28 +185,38 @@ def create_app() -> FastAPI:
|
|
| 74 |
def health():
|
| 75 |
return {"status": "ok"}
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
@app.post("/reset")
|
| 78 |
def reset(req: ResetRequest):
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
env =
|
| 85 |
obs, state = env.reset(seed=req.seed)
|
| 86 |
|
| 87 |
session_id = str(uuid.uuid4())
|
| 88 |
sessions[session_id] = env
|
| 89 |
|
| 90 |
return {
|
| 91 |
-
"session_id":
|
| 92 |
"observation": obs.model_dump(),
|
| 93 |
-
"state":
|
| 94 |
}
|
| 95 |
|
| 96 |
@app.post("/step")
|
| 97 |
def step(req: StepRequest, session_id: Optional[str] = None):
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
action = OrchestratorAction(
|
| 100 |
tool_id=req.tool_id,
|
| 101 |
query=req.query or "",
|
|
@@ -111,98 +232,21 @@ def create_app() -> FastAPI:
|
|
| 111 |
except ValueError as exc:
|
| 112 |
raise HTTPException(status_code=422, detail=str(exc))
|
| 113 |
|
| 114 |
-
# Clean up finished sessions
|
| 115 |
if done and session_id and session_id in sessions:
|
| 116 |
del sessions[session_id]
|
| 117 |
|
| 118 |
return {
|
| 119 |
"observation": obs.model_dump(),
|
| 120 |
-
"reward":
|
| 121 |
-
"done":
|
| 122 |
-
"state":
|
| 123 |
}
|
| 124 |
|
| 125 |
@app.get("/web", response_class=HTMLResponse)
|
| 126 |
def web_ui():
|
| 127 |
-
return
|
| 128 |
|
| 129 |
return app
|
| 130 |
|
| 131 |
|
| 132 |
app = create_app()
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
# ---------------------------------------------------------------------------
|
| 136 |
-
# Demo UI
|
| 137 |
-
# ---------------------------------------------------------------------------
|
| 138 |
-
|
| 139 |
-
_DEMO_HTML = """<!DOCTYPE html>
|
| 140 |
-
<html lang="en">
|
| 141 |
-
<head>
|
| 142 |
-
<meta charset="UTF-8">
|
| 143 |
-
<title>CostAwareToolEnv</title>
|
| 144 |
-
<style>
|
| 145 |
-
body { font-family: monospace; max-width: 860px; margin: 40px auto; padding: 0 20px; }
|
| 146 |
-
h1 { color: #333; }
|
| 147 |
-
pre { background: #f4f4f4; padding: 12px; border-radius: 6px; overflow-x: auto; }
|
| 148 |
-
button { padding: 8px 16px; margin: 4px; cursor: pointer; }
|
| 149 |
-
input, select, textarea { width: 100%; padding: 6px; margin: 4px 0; box-sizing: border-box; }
|
| 150 |
-
label { font-weight: bold; }
|
| 151 |
-
.tool-btn { background: #e8f0fe; border: 1px solid #4a90e2; border-radius: 4px; }
|
| 152 |
-
.tool-btn:hover { background: #cfe1ff; }
|
| 153 |
-
#log { max-height: 480px; overflow-y: auto; }
|
| 154 |
-
</style>
|
| 155 |
-
</head>
|
| 156 |
-
<body>
|
| 157 |
-
<h1>CostAwareToolEnv</h1>
|
| 158 |
-
<p>Multi-tool cost-aware RL environment — AgentX / OpenEnv</p>
|
| 159 |
-
|
| 160 |
-
<button onclick="doReset()">Reset Episode</button>
|
| 161 |
-
<hr>
|
| 162 |
-
<label>Tool:</label>
|
| 163 |
-
<select id="tool">
|
| 164 |
-
<option value="ceramic_search">ceramic_search (cost 1.0) — Web retrieval</option>
|
| 165 |
-
<option value="wiki_lookup">wiki_lookup (cost 0.5) — Wikipedia</option>
|
| 166 |
-
<option value="calculator">calculator (cost 0.1) — Arithmetic / math</option>
|
| 167 |
-
<option value="code_executor">code_executor (cost 0.3) — Python execution</option>
|
| 168 |
-
<option value="llm_reason">llm_reason (cost 2.0) — LLM chain-of-thought</option>
|
| 169 |
-
<option value="commit">commit (cost 0.0) — Submit answer</option>
|
| 170 |
-
</select>
|
| 171 |
-
<label>Query / Expression / Code / Answer:</label>
|
| 172 |
-
<textarea id="query" rows="3" placeholder="Enter query or answer..."></textarea>
|
| 173 |
-
<button class="tool-btn" onclick="doStep()">Step</button>
|
| 174 |
-
<hr>
|
| 175 |
-
<pre id="log">Click "Reset Episode" to start.</pre>
|
| 176 |
-
|
| 177 |
-
<script>
|
| 178 |
-
const log = document.getElementById('log');
|
| 179 |
-
let sessionId = null;
|
| 180 |
-
|
| 181 |
-
function append(text) { log.textContent += text + '\\n---\\n'; log.scrollTop = log.scrollHeight; }
|
| 182 |
-
|
| 183 |
-
async function doReset() {
|
| 184 |
-
log.textContent = '';
|
| 185 |
-
const res = await fetch('/reset', { method: 'POST', headers: {'Content-Type':'application/json'}, body: JSON.stringify({seed: 42}) });
|
| 186 |
-
const data = await res.json();
|
| 187 |
-
sessionId = data.session_id || null;
|
| 188 |
-
append('RESET session=' + sessionId + '\\n' + JSON.stringify(data, null, 2));
|
| 189 |
-
}
|
| 190 |
-
|
| 191 |
-
async function doStep() {
|
| 192 |
-
const tool_id = document.getElementById('tool').value;
|
| 193 |
-
const input = document.getElementById('query').value;
|
| 194 |
-
const body = { tool_id };
|
| 195 |
-
if (tool_id === 'commit') body.answer = input;
|
| 196 |
-
else if (tool_id === 'calculator') body.expression = input;
|
| 197 |
-
else if (tool_id === 'code_executor') body.code_snippet = input;
|
| 198 |
-
else body.query = input;
|
| 199 |
-
|
| 200 |
-
const url = sessionId ? '/step?session_id=' + encodeURIComponent(sessionId) : '/step';
|
| 201 |
-
const res = await fetch(url, { method: 'POST', headers: {'Content-Type':'application/json'}, body: JSON.stringify(body) });
|
| 202 |
-
const data = await res.json();
|
| 203 |
-
append('STEP tool_id=' + tool_id + '\\n' + JSON.stringify(data, null, 2));
|
| 204 |
-
}
|
| 205 |
-
</script>
|
| 206 |
-
</body>
|
| 207 |
-
</html>
|
| 208 |
-
"""
|
|
|
|
| 1 |
"""FastAPI server for CostAwareToolEnv.
|
| 2 |
|
| 3 |
Exposes the OpenEnv standard endpoints:
|
| 4 |
+
POST /reset -> OrchestratorObservation + OrchestratorState
|
| 5 |
+
POST /step -> OrchestratorObservation + reward + done + state
|
| 6 |
+
GET /health -> {"status": "ok"}
|
| 7 |
+
GET /tools -> canonical tool manifest
|
| 8 |
+
GET /web -> simple demo UI
|
| 9 |
+
GET /docs -> OpenAPI (automatic)
|
| 10 |
"""
|
| 11 |
from __future__ import annotations
|
| 12 |
|
| 13 |
+
import copy
|
| 14 |
import os
|
| 15 |
import uuid
|
| 16 |
from contextlib import asynccontextmanager
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 18 |
|
| 19 |
from fastapi import FastAPI, HTTPException
|
| 20 |
from fastapi.responses import HTMLResponse
|
|
|
|
| 24 |
from env.config import EnvConfig
|
| 25 |
from env.environment import CostAwareToolEnvironment
|
| 26 |
from env.models import OrchestratorAction
|
| 27 |
+
from tools import build_tool_catalog, build_tool_registry, catalog_as_dicts, validate_tool_costs
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
class ResetRequest(BaseModel):
|
| 31 |
seed: Optional[int] = None
|
| 32 |
config_overrides: Optional[Dict[str, Any]] = None
|
|
|
|
| 34 |
|
| 35 |
class StepRequest(BaseModel):
|
| 36 |
tool_id: str
|
| 37 |
+
query: Optional[str] = None
|
| 38 |
+
expression: Optional[str] = None
|
| 39 |
code_snippet: Optional[str] = None
|
| 40 |
+
answer: Optional[str] = None
|
| 41 |
+
metadata: Optional[Dict[str, Any]] = None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _merge_config(base: EnvConfig, overrides: Optional[Dict[str, Any]]) -> EnvConfig:
|
| 45 |
+
cfg = copy.deepcopy(base)
|
| 46 |
+
if not overrides:
|
| 47 |
+
return cfg
|
| 48 |
+
|
| 49 |
+
for key, value in overrides.items():
|
| 50 |
+
if not hasattr(cfg, key):
|
| 51 |
+
raise ValueError(f"Unknown config override: {key}")
|
| 52 |
+
|
| 53 |
+
current = getattr(cfg, key)
|
| 54 |
+
if isinstance(current, dict) and isinstance(value, dict):
|
| 55 |
+
merged = copy.deepcopy(current)
|
| 56 |
+
merged.update(value)
|
| 57 |
+
setattr(cfg, key, merged)
|
| 58 |
+
else:
|
| 59 |
+
setattr(cfg, key, value)
|
| 60 |
+
return cfg
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _build_demo_html(tool_catalog: List[Any]) -> str:
|
| 64 |
+
tool_options = "\n".join(
|
| 65 |
+
f' <option value="{spec.tool_id}">{spec.label} (cost {spec.cost}) — {spec.purpose}</option>'
|
| 66 |
+
for spec in tool_catalog
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
return f"""<!DOCTYPE html>
|
| 70 |
+
<html lang="en">
|
| 71 |
+
<head>
|
| 72 |
+
<meta charset="UTF-8">
|
| 73 |
+
<title>CostAwareToolEnv</title>
|
| 74 |
+
<style>
|
| 75 |
+
body {{ font-family: monospace; max-width: 860px; margin: 40px auto; padding: 0 20px; }}
|
| 76 |
+
h1 {{ color: #333; }}
|
| 77 |
+
pre {{ background: #f4f4f4; padding: 12px; border-radius: 6px; overflow-x: auto; }}
|
| 78 |
+
button {{ padding: 8px 16px; margin: 4px; cursor: pointer; }}
|
| 79 |
+
input, select, textarea {{ width: 100%; padding: 6px; margin: 4px 0; box-sizing: border-box; }}
|
| 80 |
+
label {{ font-weight: bold; }}
|
| 81 |
+
.tool-btn {{ background: #e8f0fe; border: 1px solid #4a90e2; border-radius: 4px; }}
|
| 82 |
+
.tool-btn:hover {{ background: #cfe1ff; }}
|
| 83 |
+
#log {{ max-height: 480px; overflow-y: auto; }}
|
| 84 |
+
</style>
|
| 85 |
+
</head>
|
| 86 |
+
<body>
|
| 87 |
+
<h1>CostAwareToolEnv</h1>
|
| 88 |
+
<p>Multi-tool cost-aware RL environment with explicit tool routing and sandboxed execution.</p>
|
| 89 |
+
|
| 90 |
+
<button onclick="doReset()">Reset Episode</button>
|
| 91 |
+
<hr>
|
| 92 |
+
<label>Tool:</label>
|
| 93 |
+
<select id="tool">
|
| 94 |
+
{tool_options}
|
| 95 |
+
</select>
|
| 96 |
+
<label>Query / Expression / Code / Answer:</label>
|
| 97 |
+
<textarea id="query" rows="3" placeholder="Enter query or answer..."></textarea>
|
| 98 |
+
<button class="tool-btn" onclick="doStep()">Step</button>
|
| 99 |
+
<hr>
|
| 100 |
+
<pre id="log">Click "Reset Episode" to start.</pre>
|
| 101 |
|
| 102 |
+
<script>
|
| 103 |
+
const log = document.getElementById('log');
|
| 104 |
+
let sessionId = null;
|
| 105 |
|
| 106 |
+
function append(text) {{ log.textContent += text + '\\n---\\n'; log.scrollTop = log.scrollHeight; }}
|
|
|
|
|
|
|
| 107 |
|
| 108 |
+
async function doReset() {{
|
| 109 |
+
log.textContent = '';
|
| 110 |
+
const res = await fetch('/reset', {{ method: 'POST', headers: {{'Content-Type':'application/json'}}, body: JSON.stringify({{seed: 42}}) }});
|
| 111 |
+
const data = await res.json();
|
| 112 |
+
sessionId = data.session_id || null;
|
| 113 |
+
append('RESET session=' + sessionId + '\\n' + JSON.stringify(data, null, 2));
|
| 114 |
+
}}
|
| 115 |
|
| 116 |
+
async function doStep() {{
|
| 117 |
+
const tool_id = document.getElementById('tool').value;
|
| 118 |
+
const input = document.getElementById('query').value;
|
| 119 |
+
const body = {{ tool_id }};
|
| 120 |
+
if (tool_id === 'commit') body.answer = input;
|
| 121 |
+
else if (tool_id === 'calculator') body.expression = input;
|
| 122 |
+
else if (tool_id === 'code_executor') body.code_snippet = input;
|
| 123 |
+
else body.query = input;
|
| 124 |
+
|
| 125 |
+
const url = sessionId ? '/step?session_id=' + encodeURIComponent(sessionId) : '/step';
|
| 126 |
+
const res = await fetch(url, {{ method: 'POST', headers: {{'Content-Type':'application/json'}}, body: JSON.stringify(body) }});
|
| 127 |
+
const data = await res.json();
|
| 128 |
+
append('STEP tool_id=' + tool_id + '\\n' + JSON.stringify(data, null, 2));
|
| 129 |
+
}}
|
| 130 |
+
</script>
|
| 131 |
+
</body>
|
| 132 |
+
</html>
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def create_app(
|
| 137 |
+
config: Optional[EnvConfig] = None,
|
| 138 |
+
tools: Optional[Dict[str, Any]] = None,
|
| 139 |
+
dataset: Optional[List[Dict[str, Any]]] = None,
|
| 140 |
+
load_dataset_fn: Callable[..., List[Dict[str, Any]]] = load_all,
|
| 141 |
+
build_registry_fn: Callable[[EnvConfig | None], Dict[str, Any]] = build_tool_registry,
|
| 142 |
+
) -> FastAPI:
|
| 143 |
+
base_config = config or EnvConfig()
|
| 144 |
+
validate_tool_costs(base_config)
|
| 145 |
+
|
| 146 |
+
dataset_cache = dataset
|
| 147 |
+
default_env: Optional[CostAwareToolEnvironment] = None
|
| 148 |
sessions: Dict[str, CostAwareToolEnvironment] = {}
|
| 149 |
+
tool_catalog = build_tool_catalog(base_config)
|
| 150 |
+
demo_html = _build_demo_html(tool_catalog)
|
| 151 |
+
|
| 152 |
+
def get_dataset() -> List[Dict[str, Any]]:
|
| 153 |
+
nonlocal dataset_cache
|
| 154 |
+
if dataset_cache is None:
|
| 155 |
+
dataset_cache = load_dataset_fn(split=base_config.data_split, max_per_domain=200)
|
| 156 |
+
return dataset_cache
|
| 157 |
+
|
| 158 |
+
def make_env(effective_config: EnvConfig) -> CostAwareToolEnvironment:
|
| 159 |
+
registry = tools if tools is not None else build_registry_fn(effective_config)
|
| 160 |
+
return CostAwareToolEnvironment(
|
| 161 |
+
config=effective_config,
|
| 162 |
+
tools=registry,
|
| 163 |
+
dataset=get_dataset(),
|
| 164 |
+
)
|
| 165 |
|
| 166 |
+
def get_default_env() -> CostAwareToolEnvironment:
|
| 167 |
+
nonlocal default_env
|
| 168 |
+
if default_env is None:
|
| 169 |
+
default_env = make_env(base_config)
|
| 170 |
+
return default_env
|
| 171 |
|
| 172 |
@asynccontextmanager
|
| 173 |
async def lifespan(app: FastAPI):
|
|
|
|
| 185 |
def health():
|
| 186 |
return {"status": "ok"}
|
| 187 |
|
| 188 |
+
@app.get("/tools")
|
| 189 |
+
def tools_manifest():
|
| 190 |
+
return catalog_as_dicts(base_config)
|
| 191 |
+
|
| 192 |
@app.post("/reset")
|
| 193 |
def reset(req: ResetRequest):
|
| 194 |
+
try:
|
| 195 |
+
cfg = _merge_config(base_config, req.config_overrides)
|
| 196 |
+
except ValueError as exc:
|
| 197 |
+
raise HTTPException(status_code=422, detail=str(exc))
|
| 198 |
+
|
| 199 |
+
env = make_env(cfg)
|
| 200 |
obs, state = env.reset(seed=req.seed)
|
| 201 |
|
| 202 |
session_id = str(uuid.uuid4())
|
| 203 |
sessions[session_id] = env
|
| 204 |
|
| 205 |
return {
|
| 206 |
+
"session_id": session_id,
|
| 207 |
"observation": obs.model_dump(),
|
| 208 |
+
"state": state.model_dump(),
|
| 209 |
}
|
| 210 |
|
| 211 |
@app.post("/step")
|
| 212 |
def step(req: StepRequest, session_id: Optional[str] = None):
|
| 213 |
+
if session_id is None:
|
| 214 |
+
env = get_default_env()
|
| 215 |
+
else:
|
| 216 |
+
env = sessions.get(session_id)
|
| 217 |
+
if env is None:
|
| 218 |
+
raise HTTPException(status_code=404, detail="Unknown session_id")
|
| 219 |
+
|
| 220 |
action = OrchestratorAction(
|
| 221 |
tool_id=req.tool_id,
|
| 222 |
query=req.query or "",
|
|
|
|
| 232 |
except ValueError as exc:
|
| 233 |
raise HTTPException(status_code=422, detail=str(exc))
|
| 234 |
|
|
|
|
| 235 |
if done and session_id and session_id in sessions:
|
| 236 |
del sessions[session_id]
|
| 237 |
|
| 238 |
return {
|
| 239 |
"observation": obs.model_dump(),
|
| 240 |
+
"reward": reward,
|
| 241 |
+
"done": done,
|
| 242 |
+
"state": state.model_dump(),
|
| 243 |
}
|
| 244 |
|
| 245 |
@app.get("/web", response_class=HTMLResponse)
|
| 246 |
def web_ui():
|
| 247 |
+
return demo_html
|
| 248 |
|
| 249 |
return app
|
| 250 |
|
| 251 |
|
| 252 |
app = create_app()
|
client.py
CHANGED
|
@@ -1,123 +1,4 @@
|
|
| 1 |
-
"""Compatibility shim
|
| 2 |
-
|
| 3 |
-
Mirrors the SearchEconomicsEnv CeramicClient interface so the two
|
| 4 |
-
environments share the same retrieval backend.
|
| 5 |
-
|
| 6 |
-
Priority for the API key:
|
| 7 |
-
1. CERAMIC_API_KEY env var
|
| 8 |
-
2. SEE_CERAMIC_API_KEY env var (HF Spaces compatibility)
|
| 9 |
-
3. Falls back to FallbackCeramicClient (offline, deterministic)
|
| 10 |
-
"""
|
| 11 |
from __future__ import annotations
|
| 12 |
|
| 13 |
-
import
|
| 14 |
-
import os
|
| 15 |
-
import time
|
| 16 |
-
from dataclasses import dataclass, field
|
| 17 |
-
from typing import List, Optional
|
| 18 |
-
|
| 19 |
-
import httpx
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# ---------------------------------------------------------------------------
|
| 23 |
-
# Result model
|
| 24 |
-
# ---------------------------------------------------------------------------
|
| 25 |
-
|
| 26 |
-
@dataclass
|
| 27 |
-
class SearchResult:
|
| 28 |
-
title: str
|
| 29 |
-
url: str
|
| 30 |
-
description: str
|
| 31 |
-
score: float = 0.0
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
# ---------------------------------------------------------------------------
|
| 35 |
-
# Live client
|
| 36 |
-
# ---------------------------------------------------------------------------
|
| 37 |
-
|
| 38 |
-
class CeramicClient:
|
| 39 |
-
"""Thin wrapper around the Ceramic search API."""
|
| 40 |
-
|
| 41 |
-
BASE_URL = "https://api.ceramic.ai/v1"
|
| 42 |
-
|
| 43 |
-
def __init__(self, api_key: str):
|
| 44 |
-
self._key = api_key
|
| 45 |
-
self._client = httpx.Client(
|
| 46 |
-
headers={"Authorization": f"Bearer {api_key}"},
|
| 47 |
-
timeout=10.0,
|
| 48 |
-
)
|
| 49 |
-
|
| 50 |
-
def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
|
| 51 |
-
resp = self._client.post(
|
| 52 |
-
f"{self.BASE_URL}/search",
|
| 53 |
-
json={"query": query, "top_k": top_k},
|
| 54 |
-
)
|
| 55 |
-
resp.raise_for_status()
|
| 56 |
-
data = resp.json()
|
| 57 |
-
results = []
|
| 58 |
-
for item in data.get("results", []):
|
| 59 |
-
results.append(SearchResult(
|
| 60 |
-
title=item.get("title", ""),
|
| 61 |
-
url=item.get("url", ""),
|
| 62 |
-
description=item.get("description", ""),
|
| 63 |
-
score=float(item.get("score", 0.0)),
|
| 64 |
-
))
|
| 65 |
-
return results
|
| 66 |
-
|
| 67 |
-
def close(self):
|
| 68 |
-
self._client.close()
|
| 69 |
-
|
| 70 |
-
def __enter__(self):
|
| 71 |
-
return self
|
| 72 |
-
|
| 73 |
-
def __exit__(self, *args):
|
| 74 |
-
self.close()
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
# ---------------------------------------------------------------------------
|
| 78 |
-
# Offline fallback
|
| 79 |
-
# ---------------------------------------------------------------------------
|
| 80 |
-
|
| 81 |
-
class FallbackCeramicClient:
|
| 82 |
-
"""Deterministic offline client — used when no API key is set."""
|
| 83 |
-
|
| 84 |
-
def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
|
| 85 |
-
# Stable hash → reproducible fake results per query
|
| 86 |
-
h = int(hashlib.sha256(query.encode()).hexdigest(), 16)
|
| 87 |
-
results = []
|
| 88 |
-
for i in range(min(top_k, 3)):
|
| 89 |
-
seed = (h + i) % 10_000
|
| 90 |
-
results.append(SearchResult(
|
| 91 |
-
title=f"Result {seed}: {query[:40]}",
|
| 92 |
-
url=f"https://fallback.example.com/doc/{seed}",
|
| 93 |
-
description=f"Offline fallback result #{i+1} for query: {query}",
|
| 94 |
-
score=round(0.9 - i * 0.15, 3),
|
| 95 |
-
))
|
| 96 |
-
return results
|
| 97 |
-
|
| 98 |
-
def close(self):
|
| 99 |
-
pass
|
| 100 |
-
|
| 101 |
-
def __enter__(self):
|
| 102 |
-
return self
|
| 103 |
-
|
| 104 |
-
def __exit__(self, *args):
|
| 105 |
-
pass
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# ---------------------------------------------------------------------------
|
| 109 |
-
# Factory
|
| 110 |
-
# ---------------------------------------------------------------------------
|
| 111 |
-
|
| 112 |
-
_DEFAULT_KEY = "cer_sk_live_543fe74e79df_eyJvcmdfaWQiOiJvcmdfMDFLTlpINkU5RVNDTUowUUoyREpINFZWWEYiLCJrZXlfaWQiOiI1NDNmZTc0ZTc5ZGYifQ.k8I4Aljsk29y4Uki37Wxfd7QZHs40XSJVNBNnfksCtM"
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def get_ceramic_client() -> CeramicClient | FallbackCeramicClient:
|
| 116 |
-
key = (
|
| 117 |
-
os.environ.get("CERAMIC_API_KEY")
|
| 118 |
-
or os.environ.get("SEE_CERAMIC_API_KEY")
|
| 119 |
-
or _DEFAULT_KEY
|
| 120 |
-
)
|
| 121 |
-
if key:
|
| 122 |
-
return CeramicClient(api_key=key)
|
| 123 |
-
return FallbackCeramicClient()
|
|
|
|
| 1 |
+
"""Compatibility shim for the legacy top-level Ceramic client import path."""
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
+
from ceramic.client import * # noqa: F401,F403
|
env/environment.py
CHANGED
|
@@ -24,6 +24,7 @@ from .models import (
|
|
| 24 |
TOOL_IDS,
|
| 25 |
)
|
| 26 |
from .reward import commit_reward, step_reward
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class CostAwareToolEnvironment:
|
|
@@ -95,6 +96,8 @@ class CostAwareToolEnvironment:
|
|
| 95 |
) -> Tuple[OrchestratorObservation, float, bool, OrchestratorState]:
|
| 96 |
if self._episode_done:
|
| 97 |
raise RuntimeError("Episode is done. Call reset() first.")
|
|
|
|
|
|
|
| 98 |
|
| 99 |
tool_id = action.tool_id
|
| 100 |
if tool_id not in TOOL_IDS:
|
|
@@ -160,30 +163,17 @@ class CostAwareToolEnvironment:
|
|
| 160 |
if budget_after > state.total_budget:
|
| 161 |
r = config.incorrect_reward
|
| 162 |
self._episode_done = True
|
| 163 |
-
obs = self._make_obs(
|
| 164 |
return obs, r, True, state
|
| 165 |
|
| 166 |
t0 = time.perf_counter()
|
| 167 |
-
|
| 168 |
-
if tool_fn is None:
|
| 169 |
-
tool_result = ToolResult(
|
| 170 |
-
tool_id=tool_id, cost=cost,
|
| 171 |
-
output="[Tool not available in this environment]",
|
| 172 |
-
latency_s=0.0,
|
| 173 |
-
error="not_available",
|
| 174 |
-
)
|
| 175 |
-
else:
|
| 176 |
-
try:
|
| 177 |
-
tool_result = tool_fn(action)
|
| 178 |
-
tool_result.cost = cost
|
| 179 |
-
except Exception as exc:
|
| 180 |
-
tool_result = ToolResult(
|
| 181 |
-
tool_id=tool_id, cost=cost,
|
| 182 |
-
output=f"[Error: {exc}]",
|
| 183 |
-
latency_s=time.perf_counter() - t0,
|
| 184 |
-
error=str(exc),
|
| 185 |
-
)
|
| 186 |
tool_result.latency_s = time.perf_counter() - t0
|
|
|
|
| 187 |
|
| 188 |
state.budget_spent = budget_after
|
| 189 |
state.budget_remaining_ratio = max(
|
|
@@ -234,6 +224,8 @@ class CostAwareToolEnvironment:
|
|
| 234 |
) -> OrchestratorObservation:
|
| 235 |
state = self._state
|
| 236 |
cfg = self.config
|
|
|
|
|
|
|
| 237 |
|
| 238 |
if 0 <= self._current_q_idx < len(self._questions):
|
| 239 |
q_entry = self._questions[self._current_q_idx]
|
|
|
|
| 24 |
TOOL_IDS,
|
| 25 |
)
|
| 26 |
from .reward import commit_reward, step_reward
|
| 27 |
+
from tools import dispatch_tool
|
| 28 |
|
| 29 |
|
| 30 |
class CostAwareToolEnvironment:
|
|
|
|
| 96 |
) -> Tuple[OrchestratorObservation, float, bool, OrchestratorState]:
|
| 97 |
if self._episode_done:
|
| 98 |
raise RuntimeError("Episode is done. Call reset() first.")
|
| 99 |
+
if self._state is None:
|
| 100 |
+
raise RuntimeError("Call reset() first.")
|
| 101 |
|
| 102 |
tool_id = action.tool_id
|
| 103 |
if tool_id not in TOOL_IDS:
|
|
|
|
| 163 |
if budget_after > state.total_budget:
|
| 164 |
r = config.incorrect_reward
|
| 165 |
self._episode_done = True
|
| 166 |
+
obs = self._make_obs(
|
| 167 |
+
reward=r,
|
| 168 |
+
question_done=True,
|
| 169 |
+
done=True,
|
| 170 |
+
)
|
| 171 |
return obs, r, True, state
|
| 172 |
|
| 173 |
t0 = time.perf_counter()
|
| 174 |
+
tool_result = dispatch_tool(tool_id, action, self.tools)
|
| 175 |
tool_result.latency_s = time.perf_counter() - t0
|
| 176 |
+
tool_result.cost = cost
|
| 177 |
|
| 178 |
state.budget_spent = budget_after
|
| 179 |
state.budget_remaining_ratio = max(
|
|
|
|
| 224 |
) -> OrchestratorObservation:
|
| 225 |
state = self._state
|
| 226 |
cfg = self.config
|
| 227 |
+
if state is None:
|
| 228 |
+
raise RuntimeError("Call reset() first.")
|
| 229 |
|
| 230 |
if 0 <= self._current_q_idx < len(self._questions):
|
| 231 |
q_entry = self._questions[self._current_q_idx]
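This diff replaces the inline lookup and `try`/`except` block in `step()` with a single `dispatch_tool(tool_id, action, self.tools)` call. The body of `tools/runtime.py` is not shown in this section, so the following is only a sketch of what such a helper could look like, reconstructed from the removed inline code. The `ToolResult` dataclass is a simplified stand-in for `env.models.ToolResult`, and the real implementation may differ.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class ToolResult:
    """Simplified stand-in for env.models.ToolResult (assumed fields)."""
    tool_id: str
    output: str
    cost: float = 0.0
    latency_s: float = 0.0
    error: Optional[str] = None


ToolFn = Callable[[Any], ToolResult]


def dispatch_tool(tool_id: str, action: Any, tools: Dict[str, ToolFn]) -> ToolResult:
    """Run the registered tool, converting every failure into a ToolResult."""
    tool_fn = tools.get(tool_id)
    if tool_fn is None:
        # Same message and error code as the removed inline branch.
        return ToolResult(
            tool_id=tool_id,
            output="[Tool not available in this environment]",
            error="not_available",
        )
    try:
        return tool_fn(action)
    except Exception as exc:  # a tool crash must not abort the episode
        return ToolResult(tool_id=tool_id, output=f"[Error: {exc}]", error=str(exc))
```

Because a helper of this shape always returns a `ToolResult`, the caller can set `latency_s` and `cost` unconditionally afterwards, as the new `step()` does.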
|
environment.py
CHANGED
|
@@ -1,307 +1,9 @@
|
|
| 1 |
-
"""Compatibility shim
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
- Agent picks a tool_id and optional query / code_snippet / answer.
|
| 7 |
-
- Environment dispatches to the appropriate tool, charges cost, appends
|
| 8 |
-
result to context_window, and returns the next observation + reward.
|
| 9 |
-
- Episode ends when budget is exhausted OR all questions are answered.
|
| 10 |
"""
|
| 11 |
from __future__ import annotations
|
| 12 |
|
| 13 |
-
import
|
| 14 |
-
import uuid
|
| 15 |
-
from typing import Any, Dict, List, Optional, Tuple
|
| 16 |
-
|
| 17 |
-
from env.answer_grading import grade
|
| 18 |
-
from env.config import EnvConfig
|
| 19 |
-
from env.models import (
|
| 20 |
-
OrchestratorAction,
|
| 21 |
-
OrchestratorObservation,
|
| 22 |
-
OrchestratorState,
|
| 23 |
-
ToolResult,
|
| 24 |
-
TOOL_IDS,
|
| 25 |
-
)
|
| 26 |
-
from env.reward import commit_reward, step_reward
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class CostAwareToolEnvironment:
|
| 30 |
-
"""
|
| 31 |
-
OpenEnv-compatible RL environment for multi-tool cost-aware QA.
|
| 32 |
-
|
| 33 |
-
Supports external tool injection so the server can wire in live
|
| 34 |
-
Ceramic, code executor, etc. Tools are callables with signature:
|
| 35 |
-
tool_fn(action: OrchestratorAction) -> ToolResult
|
| 36 |
-
"""
|
| 37 |
-
|
| 38 |
-
def __init__(
|
| 39 |
-
self,
|
| 40 |
-
config: Optional[EnvConfig] = None,
|
| 41 |
-
tools: Optional[Dict[str, Any]] = None,
|
| 42 |
-
dataset: Optional[List[Dict[str, Any]]] = None,
|
| 43 |
-
):
|
| 44 |
-
self.config = config or EnvConfig()
|
| 45 |
-
self.tools = tools or {} # tool_id -> callable
|
| 46 |
-
self.dataset = dataset or [] # List of {question, answer, domain}
|
| 47 |
-
self._state: Optional[OrchestratorState] = None
|
| 48 |
-
self._questions: List[Dict[str, Any]] = []
|
| 49 |
-
self._current_q_idx: int = 0
|
| 50 |
-
self._context_window: List[str] = []
|
| 51 |
-
self._tools_used_this_q: List[str] = []
|
| 52 |
-
self._steps_this_q: int = 0
|
| 53 |
-
self._episode_done: bool = False
|
| 54 |
-
|
| 55 |
-
# -----------------------------------------------------------------------
|
| 56 |
-
# Reset
|
| 57 |
-
# -----------------------------------------------------------------------
|
| 58 |
-
|
| 59 |
-
def reset(self, seed: Optional[int] = None) -> Tuple[OrchestratorObservation, OrchestratorState]:
|
| 60 |
-
import random
|
| 61 |
-
rng = random.Random(seed if seed is not None else self.config.seed)
|
| 62 |
-
|
| 63 |
-
# Sample questions according to domain_mix
|
| 64 |
-
questions = _sample_questions(self.dataset, self.config, rng)
|
| 65 |
-
self._questions = questions
|
| 66 |
-
self._current_q_idx = 0
|
| 67 |
-
self._episode_done = False
|
| 68 |
-
|
| 69 |
-
self._state = OrchestratorState(
|
| 70 |
-
episode_id=str(uuid.uuid4()),
|
| 71 |
-
total_budget=self.config.total_budget,
|
| 72 |
-
budget_spent=0,
|
| 73 |
-
questions_answered=0,
|
| 74 |
-
total_correct=0,
|
| 75 |
-
current_accuracy=0.0,
|
| 76 |
-
budget_remaining_ratio=1.0,
|
| 77 |
-
current_question_idx=0,
|
| 78 |
-
current_question_steps=0,
|
| 79 |
-
)
|
| 80 |
-
|
| 81 |
-
self._context_window = []
|
| 82 |
-
self._tools_used_this_q = []
|
| 83 |
-
self._steps_this_q = 0
|
| 84 |
-
|
| 85 |
-
obs = self._make_obs(reward=None, question_done=False, done=False)
|
| 86 |
-
return obs, self._state
|
| 87 |
-
|
| 88 |
-
# -----------------------------------------------------------------------
|
| 89 |
-
# Step
|
| 90 |
-
# -----------------------------------------------------------------------
|
| 91 |
-
|
| 92 |
-
def step(
|
| 93 |
-
self, action: OrchestratorAction
|
| 94 |
-
) -> Tuple[OrchestratorObservation, float, bool, OrchestratorState]:
|
| 95 |
-
if self._episode_done:
|
| 96 |
-
raise RuntimeError("Episode is done. Call reset() first.")
|
| 97 |
-
|
| 98 |
-
tool_id = action.tool_id
|
| 99 |
-
if tool_id not in TOOL_IDS:
|
| 100 |
-
raise ValueError(f"Unknown tool_id: {tool_id!r}. Valid: {TOOL_IDS}")
|
| 101 |
-
|
| 102 |
-
state = self._state
|
| 103 |
-
config = self.config
|
| 104 |
-
|
| 105 |
-
# Guard against exhausted question list (can happen after last commit)
|
| 106 |
-
if self._current_q_idx >= len(self._questions):
|
| 107 |
-
self._episode_done = True
|
| 108 |
-
raise RuntimeError("Episode is done. Call reset() first.")
|
| 109 |
-
|
| 110 |
-
q_entry = self._questions[self._current_q_idx]
|
| 111 |
-
gold = q_entry["answer"]
|
| 112 |
-
|
| 113 |
-
# ---- Commit ---------------------------------------------------
|
| 114 |
-
if tool_id == "commit":
|
| 115 |
-
raw_pred = action.answer or ""
|
| 116 |
-
em, f1, quality = grade(raw_pred, gold)
|
| 117 |
-
|
| 118 |
-
r = commit_reward(
|
| 119 |
-
quality=quality,
|
| 120 |
-
budget_remaining_ratio=state.budget_remaining_ratio,
|
| 121 |
-
config=config,
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
# Count correct
|
| 125 |
-
is_correct = (
|
| 126 |
-
em if config.grade_count_correct_mode == "em_only"
|
| 127 |
-
else (em or f1 >= config.f1_count_threshold)
|
| 128 |
-
)
|
| 129 |
-
state.questions_answered += 1
|
| 130 |
-
state.total_correct += int(is_correct)
|
| 131 |
-
state.current_accuracy = state.total_correct / state.questions_answered
|
| 132 |
-
|
| 133 |
-
# Advance to next question or end episode
|
| 134 |
-
self._current_q_idx += 1
|
| 135 |
-
self._context_window = []
|
| 136 |
-
self._tools_used_this_q = []
|
| 137 |
-
self._steps_this_q = 0
|
| 138 |
-
|
| 139 |
-
episode_done = (
|
| 140 |
-
self._current_q_idx >= len(self._questions)
|
| 141 |
-
or state.budget_spent >= state.total_budget
|
| 142 |
-
)
|
| 143 |
-
self._episode_done = episode_done
|
| 144 |
-
state.current_question_idx = self._current_q_idx
|
| 145 |
-
state.current_question_steps = 0
|
| 146 |
-
|
| 147 |
-
obs = self._make_obs(
|
| 148 |
-
reward=r,
|
| 149 |
-
question_done=True,
|
| 150 |
-
done=episode_done,
|
| 151 |
-
last_tool_result=ToolResult(
|
| 152 |
-
tool_id="commit", cost=0,
|
| 153 |
-
output=f"EM={em} F1={f1:.3f} quality={quality:.3f}"
|
| 154 |
-
),
|
| 155 |
-
)
|
| 156 |
-
return obs, r, episode_done, state
|
| 157 |
-
|
| 158 |
-
# ---- Tool call ------------------------------------------------
|
| 159 |
-
cost = config.tool_costs.get(tool_id, 0)
|
| 160 |
-
budget_after = state.budget_spent + cost
|
| 161 |
-
|
| 162 |
-
# If over budget, force commit penalty
|
| 163 |
-
if budget_after > state.total_budget:
|
| 164 |
-
r = config.incorrect_reward
|
| 165 |
-
self._episode_done = True
|
| 166 |
-
obs = self._make_obs(reward=r, question_done=True, done=True)
|
| 167 |
-
return obs, r, True, state
|
| 168 |
-
|
| 169 |
-
# Dispatch tool
|
| 170 |
-
t0 = time.perf_counter()
|
| 171 |
-
tool_fn = self.tools.get(tool_id)
|
| 172 |
-
if tool_fn is None:
|
| 173 |
-
tool_result = ToolResult(
|
| 174 |
-
tool_id=tool_id, cost=cost,
|
| 175 |
-
output="[Tool not available in this environment]",
|
| 176 |
-
latency_s=0.0,
|
| 177 |
-
error="not_available",
|
| 178 |
-
)
|
| 179 |
-
else:
|
| 180 |
-
try:
|
| 181 |
-
tool_result = tool_fn(action)
|
| 182 |
-
tool_result.cost = cost
|
| 183 |
-
except Exception as exc:
|
| 184 |
-
tool_result = ToolResult(
|
| 185 |
-
tool_id=tool_id, cost=cost,
|
| 186 |
-
output=f"[Error: {exc}]",
|
| 187 |
-
latency_s=time.perf_counter() - t0,
|
| 188 |
-
error=str(exc),
|
| 189 |
-
)
|
| 190 |
-
tool_result.latency_s = time.perf_counter() - t0
|
| 191 |
-
|
| 192 |
-
# Charge cost and update state
|
| 193 |
-
state.budget_spent = budget_after
|
| 194 |
-
state.budget_remaining_ratio = max(
|
| 195 |
-
0.0, (state.total_budget - state.budget_spent) / state.total_budget
|
| 196 |
-
)
|
| 197 |
-
state.step_count += 1
|
| 198 |
-
state.current_question_steps += 1
|
| 199 |
-
self._steps_this_q += 1
|
| 200 |
-
self._tools_used_this_q.append(tool_id)
|
| 201 |
-
self._context_window.append(f"[{tool_id}] {tool_result.output}")
|
| 202 |
-
|
| 203 |
-
r = step_reward(tool_id, config)
|
| 204 |
-
|
| 205 |
-
# Auto-commit if max steps reached
|
| 206 |
-
question_done = self._steps_this_q >= config.max_steps_per_question
|
| 207 |
-
episode_done = (
|
| 208 |
-
state.budget_spent >= state.total_budget
|
| 209 |
-
or (question_done and self._current_q_idx + 1 >= len(self._questions))
|
| 210 |
-
)
|
| 211 |
-
if question_done and not episode_done:
|
| 212 |
-
self._current_q_idx += 1
|
| 213 |
-
state.questions_answered += 1
|
| 214 |
-
self._context_window = []
|
| 215 |
-
self._tools_used_this_q = []
|
| 216 |
-
self._steps_this_q = 0
|
| 217 |
-
state.current_question_idx = self._current_q_idx
|
| 218 |
-
state.current_question_steps = 0
|
| 219 |
-
|
| 220 |
-
self._episode_done = episode_done
|
| 221 |
-
|
| 222 |
-
obs = self._make_obs(
|
| 223 |
-
reward=r,
|
| 224 |
-
question_done=question_done,
|
| 225 |
-
done=episode_done,
|
| 226 |
-
last_tool_result=tool_result,
|
| 227 |
-
)
|
| 228 |
-
return obs, r, episode_done, state
|
| 229 |
-
|
| 230 |
-
# -----------------------------------------------------------------------
|
| 231 |
-
# Internal helpers
|
| 232 |
-
# -----------------------------------------------------------------------
|
| 233 |
-
|
| 234 |
-
def _make_obs(
|
| 235 |
-
self,
|
| 236 |
-
reward: Optional[float],
|
| 237 |
-
question_done: bool,
|
| 238 |
-
done: bool,
|
| 239 |
-
last_tool_result: Optional[ToolResult] = None,
|
| 240 |
-
) -> OrchestratorObservation:
|
| 241 |
-
state = self._state
|
| 242 |
-
cfg = self.config
|
| 243 |
-
|
| 244 |
-
if 0 <= self._current_q_idx < len(self._questions):
|
| 245 |
-
q_entry = self._questions[self._current_q_idx]
|
| 246 |
-
elif self._questions:
|
| 247 |
-
# Episode finished — repeat last question info (obs is terminal anyway)
|
| 248 |
-
q_entry = self._questions[-1]
|
| 249 |
-
else:
|
| 250 |
-
q_entry = {"question": "", "answer": "", "domain": ""}
|
| 251 |
-
|
| 252 |
-
return OrchestratorObservation(
|
| 253 |
-
question=q_entry.get("question", ""),
|
| 254 |
-
question_idx=self._current_q_idx,
|
| 255 |
-
domain=q_entry.get("domain", ""),
|
| 256 |
-
question_embedding=[], # populated by server if needed
|
| 257 |
-
total_budget=cfg.total_budget,
|
| 258 |
-
budget_spent=state.budget_spent,
|
| 259 |
-
budget_remaining=state.total_budget - state.budget_spent,
|
| 260 |
-
budget_remaining_ratio=state.budget_remaining_ratio,
|
| 261 |
-
tools_used_this_question=list(self._tools_used_this_q),
|
| 262 |
-
steps_used_this_question=self._steps_this_q,
|
| 263 |
-
max_steps_per_question=cfg.max_steps_per_question,
|
| 264 |
-
last_tool_result=last_tool_result,
|
| 265 |
-
context_window=list(self._context_window),
|
| 266 |
-
step_idx=state.step_count,
|
| 267 |
-
questions_remaining=max(0, len(self._questions) - self._current_q_idx - 1),
|
| 268 |
-
questions_answered=state.questions_answered,
|
| 269 |
-
accuracy_so_far=state.current_accuracy,
|
| 270 |
-
question_done=question_done,
|
| 271 |
-
done=done,
|
| 272 |
-
reward=reward,
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
# ---------------------------------------------------------------------------
|
| 277 |
-
# Dataset sampling helper
|
| 278 |
-
# ---------------------------------------------------------------------------
|
| 279 |
-
|
| 280 |
-
def _sample_questions(
|
| 281 |
-
dataset: List[Dict[str, Any]],
|
| 282 |
-
config: EnvConfig,
|
| 283 |
-
rng: Any,
|
| 284 |
-
) -> List[Dict[str, Any]]:
|
| 285 |
-
"""Sample `config.num_questions` questions according to domain_mix."""
|
| 286 |
-
by_domain: Dict[str, List[Dict]] = {}
|
| 287 |
-
for item in dataset:
|
| 288 |
-
d = item.get("domain", "hotpotqa")
|
| 289 |
-
by_domain.setdefault(d, []).append(item)
|
| 290 |
-
|
| 291 |
-
selected = []
|
| 292 |
-
for domain, frac in config.domain_mix.items():
|
| 293 |
-
n = round(config.num_questions * frac)
|
| 294 |
-
pool = by_domain.get(domain, [])
|
| 295 |
-
if pool and n > 0:
|
| 296 |
-
selected.extend(rng.sample(pool, min(n, len(pool))))
|
| 297 |
-
|
| 298 |
-
# Guarantee at least num_questions items by filling from the full dataset
|
| 299 |
-
if len(selected) < config.num_questions and dataset:
|
| 300 |
-
remaining = [d for d in dataset if d not in selected]
|
| 301 |
-
rng.shuffle(remaining)
|
| 302 |
-
selected.extend(remaining[: config.num_questions - len(selected)])
|
| 303 |
-
|
| 304 |
-
if config.shuffle_questions:
|
| 305 |
-
rng.shuffle(selected)
|
| 306 |
-
|
| 307 |
-
return selected[: config.num_questions]
|
|
|
|
| 1 |
+
"""Compatibility shim for the legacy top-level import path.
|
| 2 |
|
| 3 |
+
The real environment implementation lives in :mod:`env.environment`.
|
| 4 |
+
This module stays intentionally thin so the two orchestrator entrypoints
|
| 5 |
+
cannot drift apart again.
|
| 6 |
"""
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
+
from env.environment import * # noqa: F401,F403
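One property of a star-import shim worth keeping in mind: unless the target module defines `__all__`, `from module import *` skips every name that starts with an underscore. So, assuming `env.environment` defines no `__all__`, a private helper such as the removed `_sample_questions` would not be re-exported through the legacy path. The sketch below demonstrates the rule with a throwaway module called `fake_env` (a made-up name, not part of the project).

```python
import sys
import types

# Build a throwaway module with one public and one private name.
fake_env = types.ModuleType("fake_env")
exec(
    "class CostAwareToolEnvironment:\n"
    "    pass\n"
    "\n"
    "def _sample_questions():\n"
    "    return []\n",
    fake_env.__dict__,
)
sys.modules["fake_env"] = fake_env

# Emulate the shim: star-import into a fresh namespace.
shim_namespace: dict = {}
exec("from fake_env import *", shim_namespace)

# Ignore interpreter-injected dunder entries such as __builtins__.
public_names = sorted(n for n in shim_namespace if not n.startswith("__"))
```

If legacy callers do need a private helper, the shim could import it explicitly by name alongside the star import.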
|
requirements.txt
CHANGED
|
@@ -6,3 +6,4 @@ datasets>=2.18.0
|
|
| 6 |
httpx>=0.27.0
|
| 7 |
requests>=2.31.0
|
| 8 |
together>=1.2.0
|
|
|
|
|
|
| 6 |
httpx>=0.27.0
|
| 7 |
requests>=2.31.0
|
| 8 |
together>=1.2.0
|
| 9 |
+
pytest>=8.0.0
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
from fastapi.testclient import TestClient
|
| 8 |
+
|
| 9 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 10 |
+
|
| 11 |
+
from app import create_app
|
| 12 |
+
from env.config import EnvConfig
|
| 13 |
+
from env.models import TOOL_IDS, OrchestratorAction, ToolResult
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_stub_registry():
|
| 17 |
+
def make_tool(tool_id: str):
|
| 18 |
+
def _tool(action: OrchestratorAction) -> ToolResult:
|
| 19 |
+
payload = action.query or action.expression or action.code_snippet or action.answer or ""
|
| 20 |
+
return ToolResult(tool_id=tool_id, output=f"{tool_id}:{payload}".rstrip(":"))
|
| 21 |
+
|
| 22 |
+
return _tool
|
| 23 |
+
|
| 24 |
+
return {tool_id: make_tool(tool_id) for tool_id in TOOL_IDS}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@pytest.fixture()
|
| 28 |
+
def sample_dataset():
|
| 29 |
+
return [
|
| 30 |
+
{"question": "What is 2 + 2?", "answer": "4", "domain": "math"},
|
| 31 |
+
{"question": "What is 3 + 1?", "answer": "4", "domain": "math"},
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@pytest.fixture()
|
| 36 |
+
def app_client(sample_dataset):
|
| 37 |
+
cfg = EnvConfig(
|
| 38 |
+
num_questions=2,
|
| 39 |
+
total_budget=5.0,
|
| 40 |
+
max_steps_per_question=4,
|
| 41 |
+
shuffle_questions=False,
|
| 42 |
+
domain_mix={"math": 1.0},
|
| 43 |
+
)
|
| 44 |
+
app = create_app(config=cfg, tools=make_stub_registry(), dataset=sample_dataset)
|
| 45 |
+
return TestClient(app)
|
tests/test_app.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def test_reset_returns_clean_session(app_client):
|
| 5 |
+
response = app_client.post("/reset", json={"seed": 42})
|
| 6 |
+
assert response.status_code == 200
|
| 7 |
+
|
| 8 |
+
payload = response.json()
|
| 9 |
+
assert payload["session_id"]
|
| 10 |
+
|
| 11 |
+
observation = payload["observation"]
|
| 12 |
+
state = payload["state"]
|
| 13 |
+
assert observation["budget_spent"] == 0
|
| 14 |
+
assert observation["question_idx"] == 0
|
| 15 |
+
assert observation["done"] is False
|
| 16 |
+
assert observation["tools_used_this_question"] == []
|
| 17 |
+
assert state["budget_spent"] == 0
|
| 18 |
+
assert state["current_question_idx"] == 0
|
| 19 |
+
assert state["step_count"] == 0
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_step_charges_the_right_cost(app_client):
|
| 23 |
+
reset = app_client.post("/reset", json={"seed": 42})
|
| 24 |
+
session_id = reset.json()["session_id"]
|
| 25 |
+
|
| 26 |
+
response = app_client.post(
|
| 27 |
+
"/step",
|
| 28 |
+
params={"session_id": session_id},
|
| 29 |
+
json={"tool_id": "calculator", "expression": "2 + 2"},
|
| 30 |
+
)
|
| 31 |
+
assert response.status_code == 200
|
| 32 |
+
|
| 33 |
+
payload = response.json()
|
| 34 |
+
assert payload["reward"] == -0.1
|
| 35 |
+
assert payload["state"]["budget_spent"] == 0.1
|
| 36 |
+
assert payload["observation"]["last_tool_result"]["tool_id"] == "calculator"
|
| 37 |
+
assert payload["observation"]["last_tool_result"]["cost"] == 0.1
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_commit_advances_question_state_correctly(app_client):
|
| 41 |
+
reset = app_client.post("/reset", json={"seed": 42})
|
| 42 |
+
session_id = reset.json()["session_id"]
|
| 43 |
+
|
| 44 |
+
response = app_client.post(
|
| 45 |
+
"/step",
|
| 46 |
+
params={"session_id": session_id},
|
| 47 |
+
json={"tool_id": "commit", "answer": "4"},
|
| 48 |
+
)
|
| 49 |
+
assert response.status_code == 200
|
| 50 |
+
|
| 51 |
+
payload = response.json()
|
| 52 |
+
assert payload["done"] is False
|
| 53 |
+
assert payload["state"]["questions_answered"] == 1
|
| 54 |
+
assert payload["state"]["current_question_idx"] == 1
|
| 55 |
+
assert payload["observation"]["question_done"] is True
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def test_episode_termination_and_cleanup_behave_as_expected(app_client):
|
| 59 |
+
reset = app_client.post("/reset", json={"seed": 42})
|
| 60 |
+
session_id = reset.json()["session_id"]
|
| 61 |
+
|
| 62 |
+
first = app_client.post(
|
| 63 |
+
"/step",
|
| 64 |
+
params={"session_id": session_id},
|
| 65 |
+
json={"tool_id": "commit", "answer": "4"},
|
| 66 |
+
)
|
| 67 |
+
assert first.status_code == 200
|
| 68 |
+
|
| 69 |
+
second = app_client.post(
|
| 70 |
+
"/step",
|
| 71 |
+
params={"session_id": session_id},
|
| 72 |
+
json={"tool_id": "commit", "answer": "4"},
|
| 73 |
+
)
|
| 74 |
+
assert second.status_code == 200
|
| 75 |
+
assert second.json()["done"] is True
|
| 76 |
+
|
| 77 |
+
follow_up = app_client.post(
|
| 78 |
+
"/step",
|
| 79 |
+
params={"session_id": session_id},
|
| 80 |
+
json={"tool_id": "commit", "answer": "4"},
|
| 81 |
+
)
|
| 82 |
+
assert follow_up.status_code == 404
|
tests/test_code_executor.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from env.models import OrchestratorAction
|
| 4 |
+
from tools.code_executor import code_executor_tool
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test_code_executor_empty_input():
|
| 8 |
+
result = code_executor_tool(OrchestratorAction(tool_id="code_executor"))
|
| 9 |
+
assert result.error == "empty"
|
| 10 |
+
assert "No code provided" in result.output
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_code_executor_runtime_error():
|
| 14 |
+
result = code_executor_tool(
|
| 15 |
+
OrchestratorAction(tool_id="code_executor", code_snippet="print(1 / 0)")
|
| 16 |
+
)
|
| 17 |
+
assert result.error is not None
|
| 18 |
+
assert "division by zero" in result.output.lower()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_code_executor_blocks_imports():
|
| 22 |
+
result = code_executor_tool(
|
| 23 |
+
OrchestratorAction(tool_id="code_executor", code_snippet="import os\nprint('hi')")
|
| 24 |
+
)
|
| 25 |
+
assert result.error == "sandbox_violation"
|
| 26 |
+
assert "import statements are blocked" in result.output
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_code_executor_blocks_unsafe_builtins():
|
| 30 |
+
result = code_executor_tool(
|
| 31 |
+
OrchestratorAction(tool_id="code_executor", code_snippet="open('tmp.txt', 'w')")
|
| 32 |
+
)
|
| 33 |
+
assert result.error == "sandbox_violation"
|
| 34 |
+
assert "name 'open' is blocked" in result.output
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_code_executor_blocks_escape_attempts():
|
| 38 |
+
result = code_executor_tool(
|
| 39 |
+
OrchestratorAction(
|
| 40 |
+
tool_id="code_executor",
|
| 41 |
+
code_snippet="().__class__.__mro__[1].__subclasses__()",
|
| 42 |
+
)
|
| 43 |
+
)
|
| 44 |
+
assert result.error == "sandbox_violation"
|
| 45 |
+
assert "__class__" in result.output or "__subclasses__" in result.output
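These tests pin down three sandbox behaviours: import statements, unsafe builtins such as `open`, and dunder-attribute escape chains are all rejected with `error == "sandbox_violation"`. The new `tools/code_executor.py` is not fully visible in this section, so the following is only a sketch of one way to get those outcomes with a static `ast` pass. The function name `find_violation` and the contents of the blocked-name set are assumptions; only the message wording for imports and blocked names is taken from the test assertions.

```python
from __future__ import annotations

import ast
from typing import Optional

# Assumed deny-list; the project's actual list is not shown in this diff.
_BLOCKED_NAMES = frozenset({"open", "eval", "exec", "compile", "__import__", "input"})


def find_violation(source: str) -> Optional[str]:
    """Return a description of the first sandbox violation, or None if clean.

    Raises SyntaxError if `source` is not valid Python.
    """
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            return "import statements are blocked"
        if isinstance(node, ast.Name) and node.id in _BLOCKED_NAMES:
            return f"name '{node.id}' is blocked"
        # Dunder attributes (__class__, __subclasses__, ...) are the usual
        # route from a harmless object back to dangerous builtins.
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            return f"attribute '{node.attr}' is blocked"
    return None
```

A static check like this is a first line of defence rather than a complete sandbox; it would normally be combined with a restricted builtins mapping and a timeout at execution time.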
|
tests/test_tools.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from urllib.error import HTTPError
|
| 4 |
+
|
| 5 |
+
from env.models import OrchestratorAction
|
| 6 |
+
from tools.calculator import calculator_tool
|
| 7 |
+
from tools.code_executor import code_executor_tool
|
| 8 |
+
from tools.commit import commit_tool
|
| 9 |
+
from tools.llm_reason import llm_reason_tool
|
| 10 |
+
from tools.ceramic_search import make_search_tool
|
| 11 |
+
from tools.wiki_lookup import wiki_lookup_tool
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_calculator_happy_path():
|
| 15 |
+
result = calculator_tool(OrchestratorAction(tool_id="calculator", expression="2 + 2 * 3"))
|
| 16 |
+
assert result.output == "8"
|
| 17 |
+
assert result.error is None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_calculator_invalid_input():
|
| 21 |
+
result = calculator_tool(OrchestratorAction(tool_id="calculator", expression="open('x')"))
|
| 22 |
+
assert result.error is not None
|
| 23 |
+
assert result.output.startswith("[Calc error:")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_search_fallback_is_deterministic(monkeypatch):
|
| 27 |
+
monkeypatch.delenv("CERAMIC_API_KEY", raising=False)
|
| 28 |
+
monkeypatch.delenv("SEE_CERAMIC_API_KEY", raising=False)
|
| 29 |
+
|
| 30 |
+
tool = make_search_tool()
|
| 31 |
+
action = OrchestratorAction(tool_id="ceramic_search", query="Eiffel Tower")
|
| 32 |
+
|
| 33 |
+
first = tool(action)
|
| 34 |
+
second = tool(action)
|
| 35 |
+
|
| 36 |
+
assert first.error is None
|
| 37 |
+
assert first.output == second.output
|
| 38 |
+
assert "Eiffel Tower" in first.output
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_wiki_lookup_not_found(monkeypatch):
|
| 42 |
+
def fake_urlopen(*args, **kwargs):
|
| 43 |
+
raise HTTPError(url="https://example.com", code=404, msg="not found", hdrs=None, fp=None)
|
| 44 |
+
|
| 45 |
+
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
|
| 46 |
+
|
| 47 |
+
result = wiki_lookup_tool(OrchestratorAction(tool_id="wiki_lookup", query="Definitely Not A Real Page"))
|
| 48 |
+
assert result.error == "not_found"
|
| 49 |
+
assert "no article found" in result.output.lower()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_llm_reason_no_api_key(monkeypatch):
|
| 53 |
+
monkeypatch.delenv("TOGETHER_API_KEY", raising=False)
|
| 54 |
+
monkeypatch.delenv("TOGETHER_KEY", raising=False)
|
| 55 |
+
|
| 56 |
+
result = llm_reason_tool(OrchestratorAction(tool_id="llm_reason", query="Explain gravity"))
|
| 57 |
+
assert result.error == "no_api_key"
|
| 58 |
+
assert "not configured" in result.output.lower()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_commit_passthrough_behavior():
|
| 62 |
+
result = commit_tool(OrchestratorAction(tool_id="commit", answer=" 1889 "))
|
| 63 |
+
assert result.tool_id == "commit"
|
| 64 |
+
assert result.output == "Committed answer: 1889"
|
tools/__init__.py
CHANGED
|
@@ -1,29 +1,20 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
| 3 |
-
Each tool is a callable: (action: OrchestratorAction) -> ToolResult
|
| 4 |
-
"""
|
| 5 |
from __future__ import annotations
|
| 6 |
|
| 7 |
-
from
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
from .commit import commit_tool
|
| 16 |
-
from .llm_reason import llm_reason_tool
|
| 17 |
-
from .wiki_lookup import wiki_lookup_tool
|
| 18 |
-
|
| 19 |
|
| 20 |
-
|
| 21 |
-
""
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
"commit": commit_tool,
|
| 29 |
-
}
|
|
|
|
| 1 |
+
"""Public tool layer for CostAwareToolEnv."""
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
+
from .runtime import (
|
| 5 |
+
ToolSpec,
|
| 6 |
+
build_tool_catalog,
|
| 7 |
+
build_tool_registry,
|
| 8 |
+
catalog_as_dicts,
|
| 9 |
+
dispatch_tool,
|
| 10 |
+
validate_tool_costs,
|
| 11 |
+
)
|
| 12 |
|
| 13 |
+
__all__ = [
|
| 14 |
+
"ToolSpec",
|
| 15 |
+
"build_tool_catalog",
|
| 16 |
+
"build_tool_registry",
|
| 17 |
+
"catalog_as_dicts",
|
| 18 |
+
"dispatch_tool",
|
| 19 |
+
"validate_tool_costs",
|
| 20 |
+
]
|
|
|
|
|
|
tools/calculator.py
CHANGED
|
@@ -24,6 +24,15 @@ _SAFE_OPS = {
|
|
| 24 |
ast.UAdd: operator.pos,
|
| 25 |
}
|
| 26 |
|
| 27 |
_SAFE_FUNCS: dict[str, Any] = {
|
| 28 |
"abs": abs, "round": round, "min": min, "max": max,
|
| 29 |
"sqrt": math.sqrt, "log": math.log, "log2": math.log2,
|
|
@@ -53,6 +62,17 @@ def _safe_eval(node: ast.AST) -> Any:
|
|
| 53 |
if op_type not in _SAFE_OPS:
|
| 54 |
raise ValueError(f"Unsupported unary: {op_type.__name__}")
|
| 55 |
return _SAFE_OPS[op_type](_safe_eval(node.operand))
|
| 56 |
if isinstance(node, ast.Call):
|
| 57 |
func = _safe_eval(node.func)
|
| 58 |
if not callable(func):
|
|
|
|
| 24 |
ast.UAdd: operator.pos,
|
| 25 |
}
|
| 26 |
|
| 27 |
+
_SAFE_COMPARISONS = {
|
| 28 |
+
ast.Eq: operator.eq,
|
| 29 |
+
ast.NotEq: operator.ne,
|
| 30 |
+
ast.Lt: operator.lt,
|
| 31 |
+
ast.LtE: operator.le,
|
| 32 |
+
ast.Gt: operator.gt,
|
| 33 |
+
ast.GtE: operator.ge,
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
_SAFE_FUNCS: dict[str, Any] = {
|
| 37 |
"abs": abs, "round": round, "min": min, "max": max,
|
| 38 |
"sqrt": math.sqrt, "log": math.log, "log2": math.log2,
|
|
|
|
| 62 |
if op_type not in _SAFE_OPS:
|
| 63 |
raise ValueError(f"Unsupported unary: {op_type.__name__}")
|
| 64 |
return _SAFE_OPS[op_type](_safe_eval(node.operand))
|
| 65 |
+
if isinstance(node, ast.Compare):
|
| 66 |
+
left = _safe_eval(node.left)
|
| 67 |
+
for op, comparator in zip(node.ops, node.comparators):
|
| 68 |
+
op_type = type(op)
|
| 69 |
+
if op_type not in _SAFE_COMPARISONS:
|
| 70 |
+
raise ValueError(f"Unsupported comparison: {op_type.__name__}")
|
| 71 |
+
right = _safe_eval(comparator)
|
| 72 |
+
if not _SAFE_COMPARISONS[op_type](left, right):
|
| 73 |
+
return False
|
| 74 |
+
left = right
|
| 75 |
+
return True
|
| 76 |
if isinstance(node, ast.Call):
|
| 77 |
func = _safe_eval(node.func)
|
| 78 |
if not callable(func):
|
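The `ast.Compare` branch added above evaluates chained comparisons pairwise, short-circuiting on the first false link. The following is a minimal standalone sketch of that idea, not the project's `_safe_eval`; the names `evaluate`, `check`, `_COMPARISONS`, and `_BINOPS` are hypothetical and the operator tables are deliberately trimmed.

```python
import ast
import operator

# Hypothetical, trimmed-down operator tables for illustration only.
_COMPARISONS = {
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
}
_BINOPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}


def evaluate(node):
    """Evaluate a small whitelist of expression nodes."""
    if isinstance(node, ast.Expression):
        return evaluate(node.body)
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _BINOPS:
        return _BINOPS[type(node.op)](evaluate(node.left), evaluate(node.right))
    if isinstance(node, ast.Compare):
        # "a < b < c" is stored as left=a, ops=[Lt, Lt], comparators=[b, c].
        left = evaluate(node.left)
        for op, comparator in zip(node.ops, node.comparators):
            if type(op) not in _COMPARISONS:
                raise ValueError(f"Unsupported comparison: {type(op).__name__}")
            right = evaluate(comparator)
            if not _COMPARISONS[type(op)](left, right):
                return False  # short-circuit on the first failing link
            left = right  # the right operand becomes the next left operand
        return True
    raise ValueError(f"Unsupported node: {type(node).__name__}")


def check(expression: str):
    """Parse an expression string and evaluate it with the whitelist above."""
    return evaluate(ast.parse(expression, mode="eval"))
```

Because the operator is looked up before the comparator is evaluated, operators outside the table (for example `in` or `is`) are rejected without evaluating their right-hand side.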
tools/code_executor.py
CHANGED
|
@@ -1,35 +1,169 @@
|
|
| 1 |
"""Restricted Python code executor.
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
-
import
|
| 9 |
-
import
|
| 10 |
import contextlib
|
| 11 |
|
| 12 |
from env.models import OrchestratorAction, ToolResult
|
| 13 |
|
| 14 |
-
_BLOCKED_MODULES = frozenset({
|
| 15 |
-
"os", "sys", "subprocess", "socket", "shutil", "pathlib",
|
| 16 |
-
"importlib", "builtins", "ctypes", "multiprocessing", "threading",
|
| 17 |
-
"signal", "pty", "fcntl", "resource", "gc", "inspect",
|
| 18 |
-
})
|
| 19 |
-
|
| 20 |
_MAX_OUTPUT_CHARS = 2000
|
| 21 |
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
raise ImportError(f"Module '{name}' is not allowed in code_executor")
|
| 32 |
-
return self._orig(name, *args, **kwargs)
|
| 33 |
|
| 34 |
|
| 35 |
def code_executor_tool(action: OrchestratorAction) -> ToolResult:
|
|
@@ -38,26 +172,21 @@ def code_executor_tool(action: OrchestratorAction) -> ToolResult:
|
|
| 38 |
return ToolResult(tool_id="code_executor", output="[No code provided]", error="empty")
|
| 39 |
|
| 40 |
stdout_buf = io.StringIO()
|
| 41 |
-
safe_globals = {
|
| 42 |
-
"__builtins__":
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
k: getattr(__builtins__, k) for k in dir(__builtins__)
|
| 47 |
-
if k not in ("open", "exec", "eval", "compile", "__import__")
|
| 48 |
-
},
|
| 49 |
-
"__import__": _BlockedImport(__import__),
|
| 50 |
-
"print": lambda *a, **kw: print(*a, **kw, file=stdout_buf),
|
| 51 |
}
|
| 52 |
|
| 53 |
try:
|
|
|
|
|
|
|
| 54 |
with contextlib.redirect_stdout(stdout_buf):
|
| 55 |
-
exec(compile(
|
| 56 |
output = stdout_buf.getvalue()[:_MAX_OUTPUT_CHARS] or "[Code ran, no output]"
|
| 57 |
return ToolResult(tool_id="code_executor", output=output)
|
|
|
|
|
|
|
| 58 |
except Exception as exc:
|
| 59 |
-
return ToolResult(
|
| 60 |
-
tool_id="code_executor",
|
| 61 |
-
output=f"[Execution error: {exc}]",
|
| 62 |
-
error=str(exc),
|
| 63 |
-
)
|
|
|
|
| 1 |
"""Restricted Python code executor.
|
| 2 |
|
| 3 |
+
The executor is intentionally narrow:
|
| 4 |
+
- import statements are rejected before execution
|
| 5 |
+
- dangerous builtin names are blocked
|
| 6 |
+
- dunder attribute access is blocked to prevent object graph escapes
|
| 7 |
+
- only a curated builtin/module surface is exposed
|
| 8 |
+
|
| 9 |
+
This keeps the tool useful for HumanEval-style tasks while making the
|
| 10 |
+
security boundaries explicit and testable.
|
| 11 |
"""
|
| 12 |
from __future__ import annotations
|
| 13 |
|
| 14 |
+
import ast
|
| 15 |
+
import builtins
|
| 16 |
import contextlib
|
| 17 |
+
import io
|
| 18 |
+
import math
|
| 19 |
+
import operator
|
| 20 |
+
import collections
|
| 21 |
+
import functools
|
| 22 |
+
import itertools
|
| 23 |
+
import statistics
|
| 24 |
+
import heapq
|
| 25 |
+
import bisect
|
| 26 |
+
import fractions
|
| 27 |
+
import decimal
|
| 28 |
+
import re
|
| 29 |
+
from typing import Any, Dict
|
| 30 |
|
| 31 |
from env.models import OrchestratorAction, ToolResult
|
| 32 |
|
| 33 |
_MAX_OUTPUT_CHARS = 2000
|
| 34 |
|
| 35 |
+
_BLOCKED_NAMES = {
|
| 36 |
+
"__builtins__",
|
| 37 |
+
"__import__",
|
| 38 |
+
"open",
|
| 39 |
+
"exec",
|
| 40 |
+
"eval",
|
| 41 |
+
"compile",
|
| 42 |
+
"globals",
|
| 43 |
+
"locals",
|
| 44 |
+
"vars",
|
| 45 |
+
"dir",
|
| 46 |
+
"getattr",
|
| 47 |
+
"setattr",
|
| 48 |
+
"delattr",
|
| 49 |
+
"input",
|
| 50 |
+
"help",
|
| 51 |
+
"type",
|
| 52 |
+
"object",
|
| 53 |
+
"super",
|
| 54 |
+
"memoryview",
|
| 55 |
+
"breakpoint",
|
| 56 |
+
"exit",
|
| 57 |
+
"quit",
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
_BLOCKED_ATTRS = {
|
| 61 |
+
"__class__",
|
| 62 |
+
"__base__",
|
| 63 |
+
"__bases__",
|
| 64 |
+
"__subclasses__",
|
| 65 |
+
"__mro__",
|
| 66 |
+
"__globals__",
|
| 67 |
+
"__code__",
|
| 68 |
+
"__closure__",
|
| 69 |
+
"__dict__",
|
| 70 |
+
"__getattribute__",
|
| 71 |
+
"__getattr__",
|
| 72 |
+
"__setattr__",
|
| 73 |
+
"__delattr__",
|
| 74 |
+
"__reduce__",
|
| 75 |
+
"__reduce_ex__",
|
| 76 |
+
"__func__",
|
| 77 |
+
"__self__",
|
| 78 |
+
"__module__",
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
_SAFE_BUILTIN_NAMES = {
|
| 82 |
+
"abs",
|
| 83 |
+
"all",
|
| 84 |
+
"any",
|
| 85 |
+
"bool",
|
| 86 |
+
"chr",
|
| 87 |
+
"dict",
|
| 88 |
+
"enumerate",
|
| 89 |
+
"float",
|
| 90 |
+
"int",
|
| 91 |
+
"isinstance",
|
| 92 |
+
"issubclass",
|
| 93 |
+
"len",
|
| 94 |
+
"list",
|
| 95 |
+
"map",
|
| 96 |
+
"max",
|
| 97 |
+
"min",
|
| 98 |
+
"pow",
|
| 99 |
+
"range",
|
| 100 |
+
"repr",
|
| 101 |
+
"reversed",
|
| 102 |
+
"round",
|
| 103 |
+
"set",
|
| 104 |
+
"slice",
|
| 105 |
+
"sorted",
|
| 106 |
+
"str",
|
| 107 |
+
"sum",
|
| 108 |
+
"tuple",
|
| 109 |
+
"zip",
|
| 110 |
+
"divmod",
|
| 111 |
+
"ord",
|
| 112 |
+
"Exception",
|
| 113 |
+
"ValueError",
|
| 114 |
+
"RuntimeError",
|
| 115 |
+
"TypeError",
|
| 116 |
+
"KeyError",
|
| 117 |
+
"IndexError",
|
| 118 |
+
"AssertionError",
|
| 119 |
+
"ZeroDivisionError",
|
| 120 |
+
"object",
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
_SAFE_MODULES: Dict[str, Any] = {
|
| 124 |
+
"math": math,
|
| 125 |
+
"collections": collections,
|
| 126 |
+
"functools": functools,
|
| 127 |
+
"itertools": itertools,
|
| 128 |
+
"statistics": statistics,
|
| 129 |
+
"heapq": heapq,
|
| 130 |
+
"bisect": bisect,
|
| 131 |
+
"fractions": fractions,
|
| 132 |
+
"decimal": decimal,
|
| 133 |
+
"re": re,
|
| 134 |
+
"operator": operator,
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class SandboxViolation(ValueError):
|
| 139 |
+
"""Raised when code tries to cross the sandbox boundary."""
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _validate_tree(tree: ast.AST) -> None:
|
| 143 |
+
for node in ast.walk(tree):
|
| 144 |
+
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
| 145 |
+
raise SandboxViolation("import statements are blocked")
|
| 146 |
+
if isinstance(node, (ast.Global, ast.Nonlocal)):
|
| 147 |
+
raise SandboxViolation("global and nonlocal are blocked")
|
| 148 |
+
if isinstance(node, ast.Attribute):
|
| 149 |
+
if node.attr.startswith("__") or node.attr in _BLOCKED_ATTRS:
|
| 150 |
+
raise SandboxViolation(f"attribute access '{node.attr}' is blocked")
|
| 151 |
+
if isinstance(node, ast.Name):
|
| 152 |
+
if node.id.startswith("__") or node.id in _BLOCKED_NAMES:
|
| 153 |
+
raise SandboxViolation(f"name '{node.id}' is blocked")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _safe_builtins(stdout_buf: io.StringIO) -> Dict[str, Any]:
|
| 157 |
+
safe: Dict[str, Any] = {name: getattr(builtins, name) for name in _SAFE_BUILTIN_NAMES}
|
| 158 |
|
| 159 |
+
def safe_print(*args: Any, **kwargs: Any) -> None:
|
| 160 |
+
kwargs = dict(kwargs)
|
| 161 |
+
kwargs.pop("file", None)
|
| 162 |
+
builtins.print(*args, **kwargs, file=stdout_buf)
|
| 163 |
|
| 164 |
+
safe["print"] = safe_print
|
| 165 |
+
safe["__build_class__"] = builtins.__build_class__
|
| 166 |
+
return safe
|
|
|
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
def code_executor_tool(action: OrchestratorAction) -> ToolResult:
|
|
|
|
| 172 |
return ToolResult(tool_id="code_executor", output="[No code provided]", error="empty")
|
| 173 |
|
| 174 |
stdout_buf = io.StringIO()
|
| 175 |
+
safe_globals: Dict[str, Any] = {
|
| 176 |
+
"__builtins__": _safe_builtins(stdout_buf),
|
| 177 |
+
"__name__": "__code_executor__",
|
| 178 |
+
"__package__": None,
|
| 179 |
+
**_SAFE_MODULES,
|
| 180 |
}
|
| 181 |
|
| 182 |
try:
|
| 183 |
+
tree = ast.parse(code, mode="exec")
|
| 184 |
+
_validate_tree(tree)
|
| 185 |
with contextlib.redirect_stdout(stdout_buf):
|
| 186 |
+
exec(compile(tree, "<code_executor>", "exec"), safe_globals) # noqa: S102
|
| 187 |
output = stdout_buf.getvalue()[:_MAX_OUTPUT_CHARS] or "[Code ran, no output]"
|
| 188 |
return ToolResult(tool_id="code_executor", output=output)
|
| 189 |
+
except SandboxViolation as exc:
|
| 190 |
+
return ToolResult(tool_id="code_executor", output=f"[Sandbox blocked: {exc}]", error="sandbox_violation")
|
| 191 |
except Exception as exc:
|
| 192 |
+
return ToolResult(tool_id="code_executor", output=f"[Execution error: {exc}]", error=str(exc))
|
|
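The pre-execution check in `_validate_tree` can be illustrated in isolation. This is a rough sketch under stated assumptions: `find_violation` and the abbreviated `BLOCKED_NAMES` set are hypothetical, and the function returns a message instead of raising `SandboxViolation` so it stays self-contained.

```python
import ast

# Hypothetical, abbreviated block list; the diff above uses a longer one.
BLOCKED_NAMES = {"open", "exec", "eval", "getattr"}


def find_violation(source: str):
    """Return a description of the first blocked construct, or None if clean."""
    tree = ast.parse(source, mode="exec")
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            return "import statements are blocked"
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            return f"attribute access '{node.attr}' is blocked"
        if isinstance(node, ast.Name) and (
            node.id.startswith("__") or node.id in BLOCKED_NAMES
        ):
            return f"name '{node.id}' is blocked"
    return None
```

A syntactic pass like this only sees names and attributes written literally in the source. It cannot see attribute names assembled at runtime, such as those inside `str.format` field expressions, so it is probably best treated as one layer alongside the restricted builtins rather than a complete sandbox.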
tools/runtime.py
ADDED
|
@@ -0,0 +1,159 @@
|
| 1 |
+
"""Shared tool catalog and dispatch helpers for CostAwareToolEnv.
|
| 2 |
+
|
| 3 |
+
This module keeps the tool contract explicit:
|
| 4 |
+
- the catalog describes every tool, its purpose, and its input field
|
| 5 |
+
- registry validation catches config drift early
|
| 6 |
+
- dispatch normalizes failures into ToolResult objects instead of
|
| 7 |
+
letting exceptions leak through the environment loop
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from dataclasses import asdict, dataclass
|
| 12 |
+
from typing import Any, Callable, Dict, List, Mapping
|
| 13 |
+
|
| 14 |
+
from env.config import EnvConfig
|
| 15 |
+
from env.models import OrchestratorAction, TOOL_IDS, ToolResult
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True)
|
| 19 |
+
class ToolSpec:
|
| 20 |
+
tool_id: str
|
| 21 |
+
label: str
|
| 22 |
+
purpose: str
|
| 23 |
+
input_field: str
|
| 24 |
+
cost: float
|
| 25 |
+
notes: str
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
_TOOL_SPEC_TEMPLATES: Dict[str, Dict[str, str]] = {
|
| 29 |
+
"ceramic_search": {
|
| 30 |
+
"label": "Ceramic web search",
|
| 31 |
+
"purpose": "Web retrieval for multi-hop factual QA",
|
| 32 |
+
"input_field": "query",
|
| 33 |
+
"notes": "Falls back to deterministic offline search when Ceramic credentials are unavailable.",
|
| 34 |
+
},
|
| 35 |
+
"wiki_lookup": {
|
| 36 |
+
"label": "Wikipedia lookup",
|
| 37 |
+
"purpose": "Entity facts, definitions, and short summaries",
|
| 38 |
+
"input_field": "query",
|
| 39 |
+
"notes": "Returns an explicit not-found or HTTP error result instead of crashing.",
|
| 40 |
+
},
|
| 41 |
+
"calculator": {
|
| 42 |
+
"label": "Calculator",
|
| 43 |
+
"purpose": "Arithmetic and symbolic math",
|
| 44 |
+
"input_field": "expression",
|
| 45 |
+
"notes": "Uses a restricted AST evaluator with comparisons and common math functions.",
|
| 46 |
+
},
|
| 47 |
+
"code_executor": {
|
| 48 |
+
"label": "Python executor",
|
| 49 |
+
"purpose": "HumanEval-style coding tasks",
|
| 50 |
+
"input_field": "code_snippet",
|
| 51 |
+
"notes": "Sandboxed exec with blocked imports, dunder attribute access, and unsafe builtins.",
|
| 52 |
+
},
|
| 53 |
+
"llm_reason": {
|
| 54 |
+
"label": "LLM reasoning",
|
| 55 |
+
"purpose": "Costly model-backed reasoning on hard problems",
|
| 56 |
+
"input_field": "query",
|
| 57 |
+
"notes": "Returns a clear no_api_key error when Together is unavailable.",
|
| 58 |
+
},
|
| 59 |
+
"commit": {
|
| 60 |
+
"label": "Commit answer",
|
| 61 |
+
"purpose": "Submit the final answer and advance the episode",
|
| 62 |
+
"input_field": "answer",
|
| 63 |
+
"notes": "Pass-through only; grading happens inside the environment.",
|
| 64 |
+
},
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def validate_tool_costs(config: EnvConfig) -> None:
|
| 69 |
+
"""Fail fast if the configured cost map drifts from the canonical tool set."""
|
| 70 |
+
missing = [tool_id for tool_id in TOOL_IDS if tool_id not in config.tool_costs]
|
| 71 |
+
if missing:
|
| 72 |
+
raise ValueError(f"EnvConfig.tool_costs is missing required tools: {missing}")
|
| 73 |
+
|
| 74 |
+
negative = {tool_id: cost for tool_id, cost in config.tool_costs.items() if cost < 0}
|
| 75 |
+
if negative:
|
| 76 |
+
raise ValueError(f"Tool costs must be non-negative: {negative}")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def build_tool_catalog(config: EnvConfig | None = None) -> List[ToolSpec]:
|
| 80 |
+
"""Return the canonical ordered catalog used by the UI and docs."""
|
| 81 |
+
cfg = config or EnvConfig()
|
| 82 |
+
validate_tool_costs(cfg)
|
| 83 |
+
|
| 84 |
+
catalog: List[ToolSpec] = []
|
| 85 |
+
for tool_id in TOOL_IDS:
|
| 86 |
+
template = _TOOL_SPEC_TEMPLATES[tool_id]
|
| 87 |
+
catalog.append(
|
| 88 |
+
ToolSpec(
|
| 89 |
+
tool_id=tool_id,
|
| 90 |
+
label=template["label"],
|
| 91 |
+
purpose=template["purpose"],
|
| 92 |
+
input_field=template["input_field"],
|
| 93 |
+
cost=cfg.tool_costs[tool_id],
|
| 94 |
+
notes=template["notes"],
|
| 95 |
+
)
|
| 96 |
+
)
|
| 97 |
+
return catalog
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def build_tool_registry(config: EnvConfig | None = None) -> Dict[str, Callable[[OrchestratorAction], ToolResult]]:
|
| 101 |
+
"""Return the canonical mapping of tool_id -> tool callable."""
|
| 102 |
+
from .calculator import calculator_tool
|
| 103 |
+
from .ceramic_search import make_search_tool
|
| 104 |
+
from .code_executor import code_executor_tool
|
| 105 |
+
from .commit import commit_tool
|
| 106 |
+
from .llm_reason import llm_reason_tool
|
| 107 |
+
from .wiki_lookup import wiki_lookup_tool
|
| 108 |
+
|
| 109 |
+
cfg = config or EnvConfig()
|
| 110 |
+
validate_tool_costs(cfg)
|
| 111 |
+
|
| 112 |
+
return {
|
| 113 |
+
"ceramic_search": make_search_tool(),
|
| 114 |
+
"calculator": calculator_tool,
|
| 115 |
+
"wiki_lookup": wiki_lookup_tool,
|
| 116 |
+
"code_executor": code_executor_tool,
|
| 117 |
+
"llm_reason": llm_reason_tool,
|
| 118 |
+
"commit": commit_tool,
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def dispatch_tool(
|
| 123 |
+
tool_id: str,
|
| 124 |
+
action: OrchestratorAction,
|
| 125 |
+
registry: Mapping[str, Callable[[OrchestratorAction], ToolResult]],
|
| 126 |
+
) -> ToolResult:
|
| 127 |
+
"""Call a tool and normalize missing-tool and crash cases into ToolResult."""
|
| 128 |
+
tool_fn = registry.get(tool_id)
|
| 129 |
+
if tool_fn is None:
|
| 130 |
+
return ToolResult(
|
| 131 |
+
tool_id=tool_id,
|
| 132 |
+
output="[Tool not available in this environment]",
|
| 133 |
+
error="not_available",
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
result = tool_fn(action)
|
| 138 |
+
except Exception as exc: # pragma: no cover - defensive wrapper
|
| 139 |
+
return ToolResult(
|
| 140 |
+
tool_id=tool_id,
|
| 141 |
+
output=f"[Error: {exc}]",
|
| 142 |
+
error=str(exc),
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
if not isinstance(result, ToolResult):
|
| 146 |
+
return ToolResult(
|
| 147 |
+
tool_id=tool_id,
|
| 148 |
+
output=f"[Tool error: unexpected return type {type(result).__name__}]",
|
| 149 |
+
error="invalid_return_type",
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
if result.tool_id != tool_id:
|
| 153 |
+
result = result.model_copy(update={"tool_id": tool_id})
|
| 154 |
+
return result
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def catalog_as_dicts(config: EnvConfig | None = None) -> List[dict[str, Any]]:
|
| 158 |
+
"""Convenience helper for JSON serialization."""
|
| 159 |
+
return [asdict(spec) for spec in build_tool_catalog(config)]
|
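The normalization that `dispatch_tool` performs (missing tool, crash, wrong return type, mismatched `tool_id`) can be sketched without the project's models. Everything below is an assumption for illustration: `Result` is a hypothetical dataclass standing in for `ToolResult`, `dispatch` takes a plain string payload instead of an `OrchestratorAction`, and `dataclasses.replace` stands in for `model_copy`.

```python
from dataclasses import dataclass, replace
from typing import Callable, Mapping, Optional


@dataclass(frozen=True)
class Result:
    """Hypothetical stand-in for the project's ToolResult model."""
    tool_id: str
    output: str
    error: Optional[str] = None


def dispatch(
    tool_id: str,
    payload: str,
    registry: Mapping[str, Callable[[str], Result]],
) -> Result:
    """Call a tool and turn every failure mode into a Result."""
    tool_fn = registry.get(tool_id)
    if tool_fn is None:
        return Result(tool_id, "[Tool not available in this environment]", "not_available")
    try:
        result = tool_fn(payload)
    except Exception as exc:  # a crashing tool must not break the caller's loop
        return Result(tool_id, f"[Error: {exc}]", str(exc))
    if not isinstance(result, Result):
        name = type(result).__name__
        return Result(tool_id, f"[Tool error: unexpected return type {name}]", "invalid_return_type")
    if result.tool_id != tool_id:
        # Relabel so the caller always sees the id it asked for.
        result = replace(result, tool_id=tool_id)
    return result


def _crash(payload: str) -> Result:
    raise RuntimeError("boom")


# Hypothetical registry exercising each branch.
registry = {
    "echo": lambda payload: Result("echo", payload),
    "crash": _crash,
    "bad": lambda payload: payload.upper(),
    "alias": lambda payload: Result("other", payload),
}
```

With this shape the caller can branch on `error` alone and never needs its own `try`/`except` around tool calls.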
tools/wiki_lookup.py
CHANGED
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
| 3 |
|
| 4 |
import urllib.parse
|
| 5 |
import urllib.request
|
|
|
|
| 6 |
import json
|
| 7 |
|
| 8 |
from env.models import OrchestratorAction, ToolResult
|
|
|
|
| 3 |
|
| 4 |
import urllib.parse
|
| 5 |
import urllib.request
|
| 6 |
+
import urllib.error
|
| 7 |
import json
|
| 8 |
|
| 9 |
from env.models import OrchestratorAction, ToolResult
|