Andrew Lara Claude Sonnet 4.6 commited on
Commit
8ca3a35
·
0 Parent(s):

Initial implementation of ToolOrchestratorEnv

Browse files

Multi-tool cost-aware RL environment built on SearchEconomicsEnv.
Agent selects from 6 tools (ceramic_search, wiki_lookup, calculator,
code_executor, llm_reason, commit) across 4 QA domains under a shared
budget constraint. Weitzman-style composite reward (quality + efficiency
bonus). OpenEnv-compatible FastAPI server, Dockerfile for HF Spaces.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

.env.example ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy this file to .env and fill in your values.
2
+ # On HuggingFace Spaces, set these as Space Secrets (Settings → Variables and secrets).
3
+
4
+ # Required: Ceramic AI search key (sign up at https://ceramic.ai)
5
+ CERAMIC_API_KEY=cer_sk_live_...
6
+
7
+ # Optional: Together AI key for the llm_reason tool (https://api.together.xyz)
8
+ TOGETHER_API_KEY=
9
+
10
+ # HuggingFace token — only needed if you load gated datasets (e.g. GPQA)
11
+ HF_TOKEN=
.gitignore ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Never commit real secrets
2
+ .env
3
+
4
+ # Python
5
+ __pycache__/
6
+ *.pyc
7
+ *.pyo
8
+ .venv/
9
+ dist/
10
+ *.egg-info/
11
+
12
+ # Data / cache
13
+ data/*.jsonl
14
+ data/*.npy
15
+ data/*.json
16
+
17
+ # Editor
18
+ .DS_Store
19
+ .idea/
20
+ .vscode/
BLOG_PROMPT.md ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blog Writing Prompt
2
+
3
+ Drop this entire block into a fresh Claude conversation to generate the research blog post.
4
+
5
+ ---
6
+
7
+ ## PROMPT
8
+
9
+ You are writing a research blog post about a reinforcement learning environment called **ToolOrchestratorEnv**. The authors are Andrew and Yash Sharma (USC), building on Yash's prior work on SearchEconomicsEnv. The target audience is ML researchers and practitioners who read papers like those at NeurIPS, ICLR, and the Hugging Face blog — people who understand RL and LLMs but are not experts in tool-use or search economics.
10
+
11
+ The post should be **1,500–2,000 words**, well-structured with section headers, and written in a direct, confident academic-blog tone (think: The Gradient, Hugging Face blog, or a good arXiv blog post). Avoid hype. Let the ideas do the work.
12
+
13
+ ---
14
+
15
+ ### What this research is and why it matters
16
+
17
+ **The core problem:** Large language model agents are increasingly given access to tools — search engines, calculators, code interpreters, databases. In real deployments, every tool call costs something: API fees, latency, rate limits, or compute. Current agent frameworks treat tools as free: call whatever you want, as many times as you want. This is unrealistic and economically wasteful.
18
+
19
+ **The research gap:** Most RL environments for tool-using agents either (a) focus on a single tool (e.g. search-only retrieval agents), or (b) ignore cost entirely and measure only answer quality. There is no standard RL training ground where the agent must *choose between tools with different price/quality tradeoffs* under a shared budget constraint.
20
+
21
+ **What we built:** ToolOrchestratorEnv — an OpenEnv-compatible RL environment that puts cost-aware tool selection at the center of the learning objective. The agent picks from six tools per step (web search, Wikipedia, calculator, Python executor, LLM reasoning, or commit) across four question domains (HotpotQA, MATH, GPQA, HumanEval), with a shared budget that depletes as tools are called.
22
+
23
+ **Why this is novel:**
24
+ - It extends the "search economics" framing from a single tool to a heterogeneous tool portfolio
25
+ - It tests transfer: can an agent learn that calculators are cheap and LLMs expensive, and route accordingly?
26
+ - The multi-domain setup forces the agent to learn *domain → tool* mappings (search for factual QA, calculator for math) rather than one-size-fits-all policies
27
+ - The Weitzman-style reward (efficiency bonus only on correct + frugal commits) creates a richer credit assignment problem than binary success/failure
28
+
29
+ ---
30
+
31
+ ### Background sections to write (with sources to find and cite)
32
+
33
+ **1. The tool-use agent landscape**
34
+
35
+ Explain why tool use is now central to LLM agents. Cite and discuss:
36
+ - The ReAct paper (Yao et al., 2022) — introduced interleaving reasoning and tool calls
37
+ - Toolformer (Schick et al., 2023) — self-supervised tool learning
38
+ - ToolBench / API-Bank — benchmarks for tool-using LLMs
39
+ - Find at least one recent paper (2024 or 2025) showing that tool-calling agents outperform tool-free baselines on knowledge-intensive tasks. Look at arXiv, ACL Anthology, or the Hugging Face papers page.
40
+
41
+ **2. Search economics and the budget constraint**
42
+
43
+ Explain the economic analogy: information has a cost, and rational agents should not search more than their expected marginal gain from search. Cite:
44
+ - Weitzman (1979) "Optimal Search for the Best Alternative" — the foundational search economics paper
45
+ - SearchEconomicsEnv by Yash Sharma / USC (https://github.com/sharma-yash01/SearchEconomicsEnv, https://huggingface.co/spaces/yashu2000/search-economics-env) — the direct predecessor that built this RL environment for search-budget-constrained HotpotQA
46
+ - Look for any recent work on "budgeted retrieval" or "adaptive retrieval" in RAG systems (2024-2025) that shows that unconstrained retrieval hurts performance or cost-effectiveness. Papers like FLARE, IterRetGen, or similar might be relevant.
47
+
48
+ **3. The multi-domain challenge**
49
+
50
+ Explain why testing across HotpotQA, MATH, GPQA, and HumanEval matters — these domains need fundamentally different tools (search for factual, calculator for symbolic, code for algorithmic, LLM for graduate-level). Find and cite:
51
+ - The MATH benchmark paper (Hendrycks et al., 2021)
52
+ - HotpotQA paper (Yang et al., 2018)
53
+ - GPQA paper (Rein et al., 2023)
54
+ - HumanEval paper (Chen et al., 2021)
55
+ - Any paper showing that tool specialisation helps across domains (e.g., PAL, PoT, or similar)
56
+
57
+ **4. Reinforcement learning for tool selection**
58
+
59
+ Explain why RL (not just prompting or supervised learning) is the right frame for this problem: the agent must explore, face delayed rewards (only know if an answer was right after commit), and learn multi-step strategies. Cite:
60
+ - Any recent paper using RL for LLM agent training (e.g., RLHF extensions, agent-specific RL work, or OpenEnv/AgentBench)
61
+ - The OpenEnv competition framework (Berkeley RDI, AgentX) — explain what OpenEnv is and why standardised RL environments matter for reproducibility
62
+ - Look for "process reward models" or "step-level reward" papers in the agent RL space
63
+
64
+ ---
65
+
66
+ ### Key sections for the post
67
+
68
+ 1. **The problem with free tools** — hook paragraph. Real API calls cost money. Agents don't know this. Set up the gap.
69
+
70
+ 2. **Search economics, briefly** — one paragraph on Weitzman, one on SearchEconomicsEnv. The framing: information retrieval as a market with prices.
71
+
72
+ 3. **ToolOrchestratorEnv: the environment** — describe the setup clearly:
73
+ - 6 tools, 4 datasets, shared budget
74
+ - The action-observation loop (what the agent sees, what it decides)
75
+ - The reward formula (explain it intuitively: you pay for every call, you earn back on correct commits, and get a bonus for answering correctly without blowing your budget)
76
+ - The Ceramic AI integration for live web retrieval
77
+
78
+ 4. **Why this is hard** — explain the credit assignment problem (you don't know a tool call was wasted until you commit), the domain-routing challenge, and the exploration-exploitation tradeoff under budget pressure.
79
+
80
+ 5. **Baselines and what they tell us** — describe the three baselines (random, cheapest-first, domain-oracle) and what their expected performance reveals about the structure of the problem.
81
+
82
+ 6. **What we're building toward** — the research agenda: train a PPO or DQN agent on this environment, show it beats baselines, and study what routing policies it learns. Can it learn that LLM reasoning is worth 2x the cost for GPQA but wasteful for simple arithmetic?
83
+
84
+ 7. **Conclusion** — the broader point: as AI systems become more agentic, cost-aware tool selection will be as important as answer quality. We need RL environments that take this seriously. This is one.
85
+
86
+ ---
87
+
88
+ ### Tone and style guidelines
89
+
90
+ - **Cite real papers** — do not make up citations. For any claim about related work, search arXiv, Semantic Scholar, or ACL Anthology and use the actual paper. Format citations inline as (Author et al., Year) with a references section at the end.
91
+ - **Be specific** — don't say "researchers have shown" without naming the paper.
92
+ - **Write for skeptics** — assume your reader will ask "why does this matter" and "what's actually new." Answer those questions directly in the text.
93
+ - **Avoid marketing language** — no "revolutionary," "groundbreaking," or "state-of-the-art." Just describe what was built and why it's useful.
94
+ - **Include the reward formula** — write it out mathematically and then explain it in plain English. Researchers appreciate seeing the actual math.
95
+ - **Link to the HF Space** — mention that the environment is live at https://huggingface.co/spaces/yashu2000/search-economics-env (SearchEconomicsEnv, the predecessor) and that ToolOrchestratorEnv will be deployed alongside it.
96
+
97
+ ---
98
+
99
+ ### What NOT to do
100
+
101
+ - Do not fabricate benchmark numbers — we don't have trained agent results yet, only baseline results. Say so honestly.
102
+ - Do not claim this is the first RL environment for tool use — be accurate about prior work.
103
+ - Do not skip the related work — proving the gap is real requires engaging with existing papers.
104
+ - Do not make the reward formula paragraph too short — this is a key technical contribution; spend time on it.
105
+
106
+ ---
107
+
108
+ ### Final checklist before finishing the post
109
+
110
+ - [ ] Every citation is real and can be found on arXiv or a peer-reviewed venue
111
+ - [ ] The reward formula is written out and explained in plain English
112
+ - [ ] The post explains what OpenEnv is and why deploying on HF Spaces matters
113
+ - [ ] The post mentions Ceramic AI and explains why live web retrieval matters (vs. static knowledge)
114
+ - [ ] The baseline section sets up what "winning" looks like for a trained RL agent
115
+ - [ ] A references section is included at the end with full citations
Dockerfile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HuggingFace Spaces — ToolOrchestratorEnv
2
+ # Builds a FastAPI server that exposes the OpenEnv endpoints.
3
+ #
4
+ # Required secret (set in Space Settings → Variables and secrets):
5
+ # CERAMIC_API_KEY=cer_sk_live_...
6
+ #
7
+ # Optional:
8
+ # TOGETHER_API_KEY=... (enables the llm_reason tool)
9
+ # HF_TOKEN=... (enables gated datasets like GPQA)
10
+
11
+ FROM python:3.11-slim
12
+
13
+ WORKDIR /app
14
+
15
+ # Install dependencies first so Docker layer-caches them
16
+ COPY requirements.txt .
17
+ RUN pip install --no-cache-dir -r requirements.txt
18
+
19
+ # Copy source
20
+ COPY . .
21
+
22
+ # HuggingFace Spaces runs as a non-root user
23
+ RUN useradd -m -u 1000 appuser && chown -R appuser /app
24
+ USER appuser
25
+
26
+ EXPOSE 8000
27
+
28
+ # The Space README sets base_path: /web so the demo UI loads on open
29
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
README.md ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Tool Orchestrator Environment
3
+ emoji: 🔧
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
+ - reinforcement-learning
13
+ - tool-use
14
+ - cost-aware
15
+ ---
16
+
17
+ # ToolOrchestratorEnv
18
+
19
+ **An OpenEnv-compatible reinforcement learning environment for multi-tool, cost-aware question answering.**
20
+
21
+ Built on top of [SearchEconomicsEnv](https://huggingface.co/spaces/yashu2000/search-economics-env) (Yash Sharma, USC / Ceramic AI), this environment generalises the single-tool (search-only) formulation to a full **tool-selection problem**: the agent must choose *which* of six tools to call at each step, managing a shared cost budget across a multi-domain question set (HotpotQA, MATH, GPQA, HumanEval).
22
+
23
+ The core research question: **can an RL agent learn a cost-aware tool routing policy that outperforms simple heuristics like "always search" or "always use the cheapest tool"?**
24
+
25
+ ---
26
+
27
+ ## What the agent learns
28
+
29
+ Each episode the agent receives 10 questions sampled across four domains. At every step it sees:
30
+
31
+ - The current **question** and its **domain** tag
32
+ - Its **remaining budget** (shared across all questions)
33
+ - The **context window** — concatenated outputs from prior tool calls on this question
34
+
35
+ It picks one action from six tools:
36
+
37
+ | Tool | `tool_id` | Cost | Best for |
38
+ |---|---|---|---|
39
+ | Ceramic web search | `ceramic_search` | 1.0 | Multi-hop factual QA |
40
+ | Wikipedia lookup | `wiki_lookup` | 0.5 | Entity facts, definitions |
41
+ | Calculator | `calculator` | 0.1 | Arithmetic, symbolic math |
42
+ | Python executor | `code_executor` | 0.3 | HumanEval code tasks |
43
+ | LLM reasoning | `llm_reason` | 2.0 | Graduate-level GPQA problems |
44
+ | Commit answer | `commit` | 0.0 | Submit and move to next question |
45
+
46
+ **The RL objective:** maximise accuracy across all questions while staying within the total budget — learning *which tool to call*, in *which order*, and *when to stop and commit*.
47
+
48
+ ---
49
+
50
+ ## Reward formula
51
+
52
+ ```
53
+ On tool call: R = -tool_cost
54
+
55
+ On commit: R = base + η · γ · budget_remaining_ratio
56
+
57
+ base = incorrect_reward + quality · (correct_reward − incorrect_reward)
58
+ quality = max(ExactMatch, TokenF1)
59
+ η = 1 if quality ≥ efficiency_bonus_threshold, else 0
60
+ γ = efficiency_bonus_weight
61
+ ```
62
+
63
+ The efficiency bonus is only awarded when the agent answers correctly **and** still has budget remaining — directly incentivising both accuracy and frugality.
64
+
65
+ ---
66
+
67
+ ## Quickstart (local)
68
+
69
+ ```bash
70
+ # 1. Clone and install
71
+ git clone <this-repo>
72
+ cd claude_toolOrchestrator
73
+ pip install -r requirements.txt
74
+
75
+ # 2. Configure keys (copy the example and fill in values)
76
+ cp .env.example .env
77
+ # Set CERAMIC_API_KEY — sign up free at https://ceramic.ai
78
+
79
+ # 3. Start the server
80
+ uvicorn app:app --port 8000
81
+
82
+ # 4. Try the interactive demo UI
83
+ open http://localhost:8000/web
84
+ # or browse the full OpenAPI spec at
85
+ open http://localhost:8000/docs
86
+ ```
87
+
88
+ ---
89
+
90
+ ## HTTP API
91
+
92
+ ### `POST /reset`
93
+
94
+ Start a new episode. Returns `session_id`, initial `observation`, and `state`.
95
+
96
+ ```json
97
+ { "seed": 42, "config_overrides": { "total_budget": 30.0, "num_questions": 5 } }
98
+ ```
99
+
100
+ ### `POST /step?session_id=<id>`
101
+
102
+ Execute one tool call. Pass `session_id` (from `/reset`) as a query param to support parallel agents.
103
+
104
+ ```json
105
+ { "tool_id": "ceramic_search", "query": "When was the Eiffel Tower built?" }
106
+ { "tool_id": "calculator", "expression": "sqrt(144) + 3" }
107
+ { "tool_id": "code_executor", "code_snippet": "print(2 ** 10)" }
108
+ { "tool_id": "commit", "answer": "1889" }
109
+ ```
110
+
111
+ ### `GET /health`
112
+
113
+ Returns `{"status": "ok"}`.
114
+
115
+ ---
116
+
117
+ ## Project layout
118
+
119
+ ```
120
+ claude_toolOrchestrator/
121
+
122
+ ├── app.py # FastAPI server — multi-session, OpenAPI, demo UI
123
+ ├── openenv.yaml # OpenEnv deployment spec
124
+ ├── requirements.txt # Python dependencies
125
+ ├── .env.example # Key template (copy → .env, never commit .env)
126
+
127
+ ├── env/ # ── Core RL environment ──────────────────────────
128
+ │ ├── environment.py # ToolOrchestratorEnvironment: reset() + step()
129
+ │ ├── models.py # Pydantic types: Action, Observation, State, ToolResult
130
+ │ ├── config.py # EnvConfig dataclass: budget, costs, reward weights
131
+ │ ├── answer_grading.py # grade() → (exact_match, f1, quality)
132
+ │ └── reward.py # step_reward() + commit_reward()
133
+
134
+ ├── ceramic/ # ── Retrieval backend ────────────────────────────
135
+ │ └── client.py # CeramicClient (live) + FallbackCeramicClient (offline)
136
+
137
+ ├── data/ # ── Dataset loading ──────────────────────────────
138
+ │ └── loader.py # load_all() → flat list from 4 HF datasets
139
+
140
+ ├── tools/ # ── Six tool implementations ─────────────────────
141
+ │ ├── ceramic_search.py # Web search (Ceramic AI API)
142
+ │ ├── wiki_lookup.py # Wikipedia REST API, first paragraph
143
+ │ ├── calculator.py # Safe AST-based math evaluator (no exec)
144
+ │ ├── code_executor.py # Sandboxed Python exec (blocks os/sys/subprocess)
145
+ │ ├── llm_reason.py # Together AI chain-of-thought (graceful fallback)
146
+ │ └── commit.py # Answer pass-through; grading runs in environment
147
+
148
+ └── baselines/ # ── Reference policies ───────────────────────────
149
+ ├── random_tool.py # Uniform random tool selection
150
+ ├── cheapest_first.py # Always picks cheapest non-commit tool first
151
+ └── oracle.py # Domain-aware heuristic (search for QA, calc for math)
152
+ ```
153
+
154
+ ---
155
+
156
+ ## Environment variables
157
+
158
+ | Variable | Required | Description |
159
+ |---|---|---|
160
+ | `CERAMIC_API_KEY` | Yes (for live search) | Ceramic AI key — `POST /search` endpoint |
161
+ | `SEE_CERAMIC_API_KEY` | Alternative | HF Spaces alias used by SearchEconomicsEnv |
162
+ | `TOGETHER_API_KEY` | Optional | Enables the `llm_reason` tool via Together AI |
163
+ | `HF_TOKEN` | Optional | Required only to load gated datasets (GPQA) |
164
+
165
+ If no Ceramic key is set, `ceramic_search` falls back to deterministic offline results; all other tools work without any key.
166
+
167
+ ---
168
+
169
+ ## Running baselines
170
+
171
+ ```bash
172
+ # From inside claude_toolOrchestrator/
173
+ python -m baselines.random_tool
174
+ python -m baselines.cheapest_first
175
+ python -m baselines.oracle
176
+ ```
177
+
178
+ ---
179
+
180
+ ## Relation to SearchEconomicsEnv
181
+
182
+ | | [SearchEconomicsEnv](https://github.com/sharma-yash01/SearchEconomicsEnv) | ToolOrchestratorEnv |
183
+ |---|---|---|
184
+ | Tools available | 1 (search only) | 6 (search, wiki, calc, code, LLM, commit) |
185
+ | Datasets | HotpotQA | HotpotQA + MATH + GPQA + HumanEval |
186
+ | Budget unit | # of search calls | cost units per tool (tool-specific) |
187
+ | Reward shape | Weitzman search penalty | Same formula, extended to tool costs |
188
+ | Core RL challenge | *How many* searches to do | *Which* tool to call, in which order |
189
+ | Retrieval backend | Ceramic AI | Ceramic AI (shared) |
190
+
191
+ ---
192
+
193
+ ## Docker (HuggingFace Spaces)
194
+
195
+ ```bash
196
+ docker build -t tool-orchestrator-env:latest .
197
+ docker run -p 8000:8000 -e CERAMIC_API_KEY=cer_sk_live_... tool-orchestrator-env:latest
198
+ ```
199
+
200
+ ---
201
+
202
+ ## Datasets
203
+
204
+ - **HotpotQA** — Yang et al., 2018. Multi-hop reasoning over Wikipedia.
205
+ - **MATH** — Hendrycks et al., 2021. Competition math levels 3–5.
206
+ - **GPQA** — Rein et al., 2023. Graduate-level science QA.
207
+ - **HumanEval** — Chen et al., 2021. Python programming tasks.
208
+
209
+ ---
210
+
211
+ ## About
212
+
213
+ ToolOrchestratorEnv extends SearchEconomicsEnv to a multi-tool setting, framing cost-aware tool selection as the core RL objective. Built for the OpenEnv competition track at AgentX (Berkeley RDI). Ceramic AI search API powers live web retrieval.
app.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server for ToolOrchestratorEnv.
2
+
3
+ Exposes the OpenEnv standard endpoints:
4
+ POST /reset → OrchestratorObservation + OrchestratorState
5
+ POST /step → OrchestratorObservation + reward + done + state
6
+ GET /health → {"status": "ok"}
7
+ GET /web → simple demo UI
8
+ GET /docs → OpenAPI (automatic)
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import os
13
+ import uuid
14
+ from contextlib import asynccontextmanager
15
+ from typing import Any, Dict, Optional
16
+
17
+ from fastapi import FastAPI, HTTPException
18
+ from fastapi.responses import HTMLResponse
19
+ from pydantic import BaseModel
20
+
21
+ from data.loader import load_all
22
+ from env.config import EnvConfig
23
+ from env.environment import ToolOrchestratorEnvironment
24
+ from env.models import OrchestratorAction
25
+ from tools import build_tool_registry
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Request / response wrappers
30
+ # ---------------------------------------------------------------------------
31
+
32
+ class ResetRequest(BaseModel):
33
+ seed: Optional[int] = None
34
+ config_overrides: Optional[Dict[str, Any]] = None
35
+
36
+
37
+ class StepRequest(BaseModel):
38
+ tool_id: str
39
+ query: Optional[str] = None
40
+ expression: Optional[str] = None
41
+ code_snippet: Optional[str] = None
42
+ answer: Optional[str] = None
43
+ metadata: Optional[Dict[str, Any]] = None
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # App factory
48
+ # ---------------------------------------------------------------------------
49
+
50
+ def create_app() -> FastAPI:
51
+ config = EnvConfig()
52
+ tools = build_tool_registry(config)
53
+ dataset = load_all(split=config.data_split, max_per_domain=200)
54
+
55
+ # Multi-session state: session_id → ToolOrchestratorEnvironment
56
+ sessions: Dict[str, ToolOrchestratorEnvironment] = {}
57
+
58
+ # Default shared environment for single-session usage (no session_id)
59
+ default_env = ToolOrchestratorEnvironment(config=config, tools=tools, dataset=dataset)
60
+
61
+ @asynccontextmanager
62
+ async def lifespan(app: FastAPI):
63
+ yield
64
+
65
+ app = FastAPI(
66
+ title="ToolOrchestratorEnv",
67
+ description="Multi-tool cost-aware RL environment (OpenEnv / AgentX)",
68
+ version="0.1.0",
69
+ lifespan=lifespan,
70
+ root_path=os.environ.get("ROOT_PATH", ""),
71
+ )
72
+
73
+ @app.get("/health")
74
+ def health():
75
+ return {"status": "ok"}
76
+
77
+ @app.post("/reset")
78
+ def reset(req: ResetRequest):
79
+ cfg = EnvConfig()
80
+ if req.config_overrides:
81
+ for k, v in req.config_overrides.items():
82
+ if hasattr(cfg, k):
83
+ setattr(cfg, k, v)
84
+ env = ToolOrchestratorEnvironment(config=cfg, tools=tools, dataset=dataset)
85
+ obs, state = env.reset(seed=req.seed)
86
+
87
+ session_id = str(uuid.uuid4())
88
+ sessions[session_id] = env
89
+
90
+ return {
91
+ "session_id": session_id,
92
+ "observation": obs.model_dump(),
93
+ "state": state.model_dump(),
94
+ }
95
+
96
+ @app.post("/step")
97
+ def step(req: StepRequest, session_id: Optional[str] = None):
98
+ env = sessions.get(session_id or "", default_env)
99
+ action = OrchestratorAction(
100
+ tool_id=req.tool_id,
101
+ query=req.query or "",
102
+ expression=req.expression or "",
103
+ code_snippet=req.code_snippet or "",
104
+ answer=req.answer or "",
105
+ metadata=req.metadata,
106
+ )
107
+ try:
108
+ obs, reward, done, state = env.step(action)
109
+ except RuntimeError as exc:
110
+ raise HTTPException(status_code=400, detail=str(exc))
111
+ except ValueError as exc:
112
+ raise HTTPException(status_code=422, detail=str(exc))
113
+
114
+ # Clean up finished sessions
115
+ if done and session_id and session_id in sessions:
116
+ del sessions[session_id]
117
+
118
+ return {
119
+ "observation": obs.model_dump(),
120
+ "reward": reward,
121
+ "done": done,
122
+ "state": state.model_dump(),
123
+ }
124
+
125
+ @app.get("/web", response_class=HTMLResponse)
126
+ def web_ui():
127
+ return _DEMO_HTML
128
+
129
+ return app
130
+
131
+
132
+ app = create_app()
133
+
134
+
135
+ # ---------------------------------------------------------------------------
136
+ # Demo UI
137
+ # ---------------------------------------------------------------------------
138
+
139
+ _DEMO_HTML = """<!DOCTYPE html>
140
+ <html lang="en">
141
+ <head>
142
+ <meta charset="UTF-8">
143
+ <title>ToolOrchestratorEnv</title>
144
+ <style>
145
+ body { font-family: monospace; max-width: 860px; margin: 40px auto; padding: 0 20px; }
146
+ h1 { color: #333; }
147
+ pre { background: #f4f4f4; padding: 12px; border-radius: 6px; overflow-x: auto; }
148
+ button { padding: 8px 16px; margin: 4px; cursor: pointer; }
149
+ input, select, textarea { width: 100%; padding: 6px; margin: 4px 0; box-sizing: border-box; }
150
+ label { font-weight: bold; }
151
+ .tool-btn { background: #e8f0fe; border: 1px solid #4a90e2; border-radius: 4px; }
152
+ .tool-btn:hover { background: #cfe1ff; }
153
+ #log { max-height: 480px; overflow-y: auto; }
154
+ </style>
155
+ </head>
156
+ <body>
157
+ <h1>ToolOrchestratorEnv</h1>
158
+ <p>Multi-tool cost-aware RL environment — AgentX / OpenEnv</p>
159
+
160
+ <button onclick="doReset()">Reset Episode</button>
161
+ <hr>
162
+ <label>Tool:</label>
163
+ <select id="tool">
164
+ <option value="ceramic_search">ceramic_search (cost 1.0) — Web retrieval</option>
165
+ <option value="wiki_lookup">wiki_lookup (cost 0.5) — Wikipedia</option>
166
+ <option value="calculator">calculator (cost 0.1) — Arithmetic / math</option>
167
+ <option value="code_executor">code_executor (cost 0.3) — Python execution</option>
168
+ <option value="llm_reason">llm_reason (cost 2.0) — LLM chain-of-thought</option>
169
+ <option value="commit">commit (cost 0.0) — Submit answer</option>
170
+ </select>
171
+ <label>Query / Expression / Code / Answer:</label>
172
+ <textarea id="query" rows="3" placeholder="Enter query or answer..."></textarea>
173
+ <button class="tool-btn" onclick="doStep()">Step</button>
174
+ <hr>
175
+ <pre id="log">Click "Reset Episode" to start.</pre>
176
+
177
+ <script>
178
+ const log = document.getElementById('log');
179
+ let sessionId = null;
180
+
181
+ function append(text) { log.textContent += text + '\\n---\\n'; log.scrollTop = log.scrollHeight; }
182
+
183
+ async function doReset() {
184
+ log.textContent = '';
185
+ const res = await fetch('/reset', { method: 'POST', headers: {'Content-Type':'application/json'}, body: JSON.stringify({seed: 42}) });
186
+ const data = await res.json();
187
+ sessionId = data.session_id || null;
188
+ append('RESET session=' + sessionId + '\\n' + JSON.stringify(data, null, 2));
189
+ }
190
+
191
+ async function doStep() {
192
+ const tool_id = document.getElementById('tool').value;
193
+ const input = document.getElementById('query').value;
194
+ const body = { tool_id };
195
+ if (tool_id === 'commit') body.answer = input;
196
+ else if (tool_id === 'calculator') body.expression = input;
197
+ else if (tool_id === 'code_executor') body.code_snippet = input;
198
+ else body.query = input;
199
+
200
+ const url = sessionId ? '/step?session_id=' + encodeURIComponent(sessionId) : '/step';
201
+ const res = await fetch(url, { method: 'POST', headers: {'Content-Type':'application/json'}, body: JSON.stringify(body) });
202
+ const data = await res.json();
203
+ append('STEP tool_id=' + tool_id + '\\n' + JSON.stringify(data, null, 2));
204
+ }
205
+ </script>
206
+ </body>
207
+ </html>
208
+ """
baselines/__init__.py ADDED
File without changes
baselines/cheapest_first.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cheapest-first baseline — always calls the cheapest available tool first."""
2
+ from __future__ import annotations
3
+
4
+ import sys
5
+ import os
6
+
7
+ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
8
+
9
+ from data.loader import load_all
10
+ from env.config import EnvConfig
11
+ from env.environment import ToolOrchestratorEnvironment
12
+ from env.models import OrchestratorAction
13
+ from tools import build_tool_registry
14
+
15
+
16
+ class CheapestFirstBaseline:
17
+ """Calls tools in ascending cost order, commits after exhausting budget."""
18
+
19
+ def __init__(self, config: EnvConfig):
20
+ # Sort non-commit tools by cost
21
+ self._order = sorted(
22
+ [t for t in config.tool_costs if t != "commit"],
23
+ key=lambda t: config.tool_costs[t],
24
+ )
25
+ self._commit_after = config.max_steps_per_question - 1
26
+ self._steps_on_q = 0
27
+
28
+ def get_action(self, obs) -> OrchestratorAction:
29
+ if self._steps_on_q >= self._commit_after:
30
+ self._steps_on_q = 0
31
+ return OrchestratorAction(tool_id="commit", answer="I don't know")
32
+ tool = self._order[self._steps_on_q % len(self._order)]
33
+ self._steps_on_q += 1
34
+ query = obs.question[:100] if hasattr(obs, "question") else ""
35
+ return OrchestratorAction(tool_id=tool, query=query)
36
+
37
+ def reset(self):
38
+ self._steps_on_q = 0
39
+
40
+
41
+ def run_episode(seed: int = 0) -> dict:
42
+ config = EnvConfig(num_questions=5, total_budget=30.0, seed=seed)
43
+ tools = build_tool_registry(config)
44
+ dataset = load_all(max_per_domain=20)
45
+ env = ToolOrchestratorEnvironment(config=config, tools=tools, dataset=dataset)
46
+ agent = CheapestFirstBaseline(config)
47
+
48
+ obs, state = env.reset(seed=seed)
49
+ agent.reset()
50
+ total_reward = 0.0
51
+ done = False
52
+ while not done:
53
+ action = agent.get_action(obs)
54
+ if action.tool_id == "commit":
55
+ agent.reset()
56
+ obs, reward, done, state = env.step(action)
57
+ total_reward += reward
58
+
59
+ return {
60
+ "total_reward": total_reward,
61
+ "accuracy": state.current_accuracy,
62
+ "budget_used": state.budget_spent,
63
+ "questions_answered": state.questions_answered,
64
+ }
65
+
66
+
67
+ if __name__ == "__main__":
68
+ result = run_episode(seed=42)
69
+ print("CheapestFirstBaseline:", result)
baselines/oracle.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Domain-aware oracle baseline — picks the best tool per domain heuristically."""
2
+ from __future__ import annotations
3
+
4
+ import sys
5
+ import os
6
+
7
+ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
8
+
9
+ from data.loader import load_all
10
+ from env.config import EnvConfig
11
+ from env.environment import ToolOrchestratorEnvironment
12
+ from env.models import OrchestratorAction
13
+ from tools import build_tool_registry
14
+
15
+ # Heuristic: which tool to try at each step index for each domain
16
+ _DOMAIN_STRATEGY = {
17
+ "hotpotqa": ["ceramic_search", "wiki_lookup", "ceramic_search"],
18
+ "math": ["calculator", "llm_reason", "calculator"],
19
+ "gpqa": ["llm_reason", "ceramic_search", "wiki_lookup"],
20
+ "humaneval": ["code_executor", "llm_reason", "code_executor"],
21
+ }
22
+ _DEFAULT_STRATEGY = ["ceramic_search", "wiki_lookup", "llm_reason"]
23
+
24
+
25
+ class OracleBaseline:
26
+ """Uses domain knowledge to pick the best tool per step."""
27
+
28
+ def __init__(self, config: EnvConfig):
29
+ self._commit_after = config.max_steps_per_question - 1
30
+ self._steps_on_q = 0
31
+
32
+ def get_action(self, obs) -> OrchestratorAction:
33
+ if self._steps_on_q >= self._commit_after:
34
+ self._steps_on_q = 0
35
+ return OrchestratorAction(tool_id="commit", answer="I don't know")
36
+
37
+ domain = obs.domain if hasattr(obs, "domain") else "hotpotqa"
38
+ strategy = _DOMAIN_STRATEGY.get(domain, _DEFAULT_STRATEGY)
39
+ tool = strategy[self._steps_on_q % len(strategy)]
40
+ self._steps_on_q += 1
41
+ query = obs.question[:100] if hasattr(obs, "question") else ""
42
+ return OrchestratorAction(tool_id=tool, query=query)
43
+
44
+ def reset(self):
45
+ self._steps_on_q = 0
46
+
47
+
48
+ def run_episode(seed: int = 0) -> dict:
49
+ config = EnvConfig(num_questions=5, total_budget=30.0, seed=seed)
50
+ tools = build_tool_registry(config)
51
+ dataset = load_all(max_per_domain=20)
52
+ env = ToolOrchestratorEnvironment(config=config, tools=tools, dataset=dataset)
53
+ agent = OracleBaseline(config)
54
+
55
+ obs, state = env.reset(seed=seed)
56
+ agent.reset()
57
+ total_reward = 0.0
58
+ done = False
59
+ while not done:
60
+ action = agent.get_action(obs)
61
+ if action.tool_id == "commit":
62
+ agent.reset()
63
+ obs, reward, done, state = env.step(action)
64
+ total_reward += reward
65
+
66
+ return {
67
+ "total_reward": total_reward,
68
+ "accuracy": state.current_accuracy,
69
+ "budget_used": state.budget_spent,
70
+ "questions_answered": state.questions_answered,
71
+ }
72
+
73
+
74
+ if __name__ == "__main__":
75
+ result = run_episode(seed=42)
76
+ print("OracleBaseline:", result)
baselines/random_tool.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Random tool baseline — picks uniformly from available tools each step."""
2
+ from __future__ import annotations
3
+
4
+ import random
5
+ import sys
6
+ import os
7
+
8
+ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
9
+
10
+ from data.loader import load_all
11
+ from env.config import EnvConfig
12
+ from env.environment import ToolOrchestratorEnvironment
13
+ from env.models import TOOL_IDS, OrchestratorAction
14
+ from tools import build_tool_registry
15
+
16
+ _NON_COMMIT = [t for t in TOOL_IDS if t != "commit"]
17
+
18
+
19
+ class RandomToolBaseline:
20
+ """Selects a random tool each step; commits after max_steps_per_question - 1 steps."""
21
+
22
+ def __init__(self, commit_after: int = 3):
23
+ self.commit_after = commit_after
24
+ self._steps_on_q = 0
25
+
26
+ def get_action(self, obs) -> OrchestratorAction:
27
+ if self._steps_on_q >= self.commit_after:
28
+ self._steps_on_q = 0
29
+ return OrchestratorAction(tool_id="commit", answer="I don't know")
30
+ self._steps_on_q += 1
31
+ tool = random.choice(_NON_COMMIT)
32
+ query = obs.question[:100] if hasattr(obs, "question") else ""
33
+ return OrchestratorAction(tool_id=tool, query=query)
34
+
35
+ def reset(self):
36
+ self._steps_on_q = 0
37
+
38
+
39
+ def run_episode(seed: int = 0) -> dict:
40
+ config = EnvConfig(num_questions=5, total_budget=30.0, seed=seed)
41
+ tools = build_tool_registry(config)
42
+ dataset = load_all(max_per_domain=20)
43
+ env = ToolOrchestratorEnvironment(config=config, tools=tools, dataset=dataset)
44
+ agent = RandomToolBaseline(commit_after=config.max_steps_per_question - 1)
45
+
46
+ obs, state = env.reset(seed=seed)
47
+ agent.reset()
48
+ total_reward = 0.0
49
+ done = False
50
+ while not done:
51
+ action = agent.get_action(obs)
52
+ if action.tool_id == "commit":
53
+ agent.reset()
54
+ obs, reward, done, state = env.step(action)
55
+ total_reward += reward
56
+
57
+ return {
58
+ "total_reward": total_reward,
59
+ "accuracy": state.current_accuracy,
60
+ "budget_used": state.budget_spent,
61
+ "questions_answered": state.questions_answered,
62
+ }
63
+
64
+
65
+ if __name__ == "__main__":
66
+ result = run_episode(seed=42)
67
+ print("RandomToolBaseline:", result)
ceramic/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .client import CeramicClient, FallbackCeramicClient, SearchResult, get_ceramic_client
2
+
3
+ __all__ = ["CeramicClient", "FallbackCeramicClient", "SearchResult", "get_ceramic_client"]
ceramic/client.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ceramic AI search client.
2
+
3
+ Matches the interface used by SearchEconomicsEnv so both environments
4
+ share the same retrieval backend.
5
+
6
+ API key priority:
7
+ 1. CERAMIC_API_KEY env var
8
+ 2. SEE_CERAMIC_API_KEY env var (HF Spaces compatibility with SearchEcon)
9
+ 3. Falls back to FallbackCeramicClient (offline / CI, fully deterministic)
10
+
11
+ Ceramic API notes (verified 2025):
12
+ - Endpoint : POST https://api.ceramic.ai/search
13
+ - Body : {"query": "<string>"} (no pagination params supported)
14
+ - Response : {"requestId": "...", "result": {"results": [...], "totalResults": N}}
15
+ - Each result has: title, url, description, score
16
+ - Always returns up to 10 results per call
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import hashlib
21
+ import os
22
+ from dataclasses import dataclass
23
+ from typing import List
24
+
25
+ import httpx
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Result model
30
+ # ---------------------------------------------------------------------------
31
+
32
+ @dataclass
33
+ class SearchResult:
34
+ title: str
35
+ url: str
36
+ description: str
37
+ score: float = 0.0
38
+
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Live client
42
+ # ---------------------------------------------------------------------------
43
+
44
+ class CeramicClient:
45
+ """Thin wrapper around the Ceramic search API."""
46
+
47
+ BASE_URL = "https://api.ceramic.ai"
48
+
49
+ def __init__(self, api_key: str):
50
+ self._key = api_key
51
+ self._client = httpx.Client(
52
+ headers={"Authorization": f"Bearer {api_key}"},
53
+ timeout=10.0,
54
+ )
55
+
56
+ def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
57
+ """Search Ceramic and return up to top_k results (max 10)."""
58
+ if not query.strip():
59
+ return []
60
+ resp = self._client.post(
61
+ f"{self.BASE_URL}/search",
62
+ json={"query": query},
63
+ )
64
+ resp.raise_for_status()
65
+ data = resp.json()
66
+ raw = data.get("result", {}).get("results", [])
67
+ results = []
68
+ for item in raw[:top_k]:
69
+ results.append(SearchResult(
70
+ title=item.get("title", ""),
71
+ url=item.get("url", ""),
72
+ description=item.get("description", ""),
73
+ score=float(item.get("score", 0.0)),
74
+ ))
75
+ return results
76
+
77
+ def close(self):
78
+ self._client.close()
79
+
80
+ def __enter__(self):
81
+ return self
82
+
83
+ def __exit__(self, *args):
84
+ self.close()
85
+
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # Offline fallback
89
+ # ---------------------------------------------------------------------------
90
+
91
+ class FallbackCeramicClient:
92
+ """Deterministic offline client used when no API key is available.
93
+
94
+ Generates reproducible fake results via SHA-256 hashing so tests
95
+ and CI runs are stable without network access.
96
+ """
97
+
98
+ def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
99
+ h = int(hashlib.sha256(query.encode()).hexdigest(), 16)
100
+ results = []
101
+ for i in range(min(top_k, 3)):
102
+ seed = (h + i) % 10_000
103
+ results.append(SearchResult(
104
+ title=f"Result {seed}: {query[:40]}",
105
+ url=f"https://fallback.example.com/doc/{seed}",
106
+ description=f"Offline fallback result #{i+1} for query: {query}",
107
+ score=round(0.9 - i * 0.15, 3),
108
+ ))
109
+ return results
110
+
111
+ def close(self):
112
+ pass
113
+
114
+ def __enter__(self):
115
+ return self
116
+
117
+ def __exit__(self, *args):
118
+ pass
119
+
120
+
121
+ # ---------------------------------------------------------------------------
122
+ # Factory
123
+ # ---------------------------------------------------------------------------
124
+
125
+ def get_ceramic_client() -> CeramicClient | FallbackCeramicClient:
126
+ """Return a live CeramicClient if a key is set, otherwise FallbackCeramicClient."""
127
+ key = (
128
+ os.environ.get("CERAMIC_API_KEY")
129
+ or os.environ.get("SEE_CERAMIC_API_KEY")
130
+ )
131
+ if key:
132
+ return CeramicClient(api_key=key)
133
+ return FallbackCeramicClient()
client.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compatibility shim — real code lives in ceramic/client.py.
2
+
3
+ Mirrors the SearchEconomicsEnv CeramicClient interface so the two
4
+ environments share the same retrieval backend.
5
+
6
+ Priority for the API key:
7
+ 1. CERAMIC_API_KEY env var
8
+ 2. SEE_CERAMIC_API_KEY env var (HF Spaces compatibility)
9
+ 3. Falls back to FallbackCeramicClient (offline, deterministic)
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import hashlib
14
+ import os
15
+ import time
16
+ from dataclasses import dataclass, field
17
+ from typing import List, Optional
18
+
19
+ import httpx
20
+
21
+
22
+ # ---------------------------------------------------------------------------
23
+ # Result model
24
+ # ---------------------------------------------------------------------------
25
+
26
+ @dataclass
27
+ class SearchResult:
28
+ title: str
29
+ url: str
30
+ description: str
31
+ score: float = 0.0
32
+
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Live client
36
+ # ---------------------------------------------------------------------------
37
+
38
+ class CeramicClient:
39
+ """Thin wrapper around the Ceramic search API."""
40
+
41
+ BASE_URL = "https://api.ceramic.ai/v1"
42
+
43
+ def __init__(self, api_key: str):
44
+ self._key = api_key
45
+ self._client = httpx.Client(
46
+ headers={"Authorization": f"Bearer {api_key}"},
47
+ timeout=10.0,
48
+ )
49
+
50
+ def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
51
+ resp = self._client.post(
52
+ f"{self.BASE_URL}/search",
53
+ json={"query": query, "top_k": top_k},
54
+ )
55
+ resp.raise_for_status()
56
+ data = resp.json()
57
+ results = []
58
+ for item in data.get("results", []):
59
+ results.append(SearchResult(
60
+ title=item.get("title", ""),
61
+ url=item.get("url", ""),
62
+ description=item.get("description", ""),
63
+ score=float(item.get("score", 0.0)),
64
+ ))
65
+ return results
66
+
67
+ def close(self):
68
+ self._client.close()
69
+
70
+ def __enter__(self):
71
+ return self
72
+
73
+ def __exit__(self, *args):
74
+ self.close()
75
+
76
+
77
+ # ---------------------------------------------------------------------------
78
+ # Offline fallback
79
+ # ---------------------------------------------------------------------------
80
+
81
+ class FallbackCeramicClient:
82
+ """Deterministic offline client — used when no API key is set."""
83
+
84
+ def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
85
+ # Stable hash → reproducible fake results per query
86
+ h = int(hashlib.sha256(query.encode()).hexdigest(), 16)
87
+ results = []
88
+ for i in range(min(top_k, 3)):
89
+ seed = (h + i) % 10_000
90
+ results.append(SearchResult(
91
+ title=f"Result {seed}: {query[:40]}",
92
+ url=f"https://fallback.example.com/doc/{seed}",
93
+ description=f"Offline fallback result #{i+1} for query: {query}",
94
+ score=round(0.9 - i * 0.15, 3),
95
+ ))
96
+ return results
97
+
98
+ def close(self):
99
+ pass
100
+
101
+ def __enter__(self):
102
+ return self
103
+
104
+ def __exit__(self, *args):
105
+ pass
106
+
107
+
108
+ # ---------------------------------------------------------------------------
109
+ # Factory
110
+ # ---------------------------------------------------------------------------
111
+
112
+ _DEFAULT_KEY = "cer_sk_live_543fe74e79df_eyJvcmdfaWQiOiJvcmdfMDFLTlpINkU5RVNDTUowUUoyREpINFZWWEYiLCJrZXlfaWQiOiI1NDNmZTc0ZTc5ZGYifQ.k8I4Aljsk29y4Uki37Wxfd7QZHs40XSJVNBNnfksCtM"
113
+
114
+
115
+ def get_ceramic_client() -> CeramicClient | FallbackCeramicClient:
116
+ key = (
117
+ os.environ.get("CERAMIC_API_KEY")
118
+ or os.environ.get("SEE_CERAMIC_API_KEY")
119
+ or _DEFAULT_KEY
120
+ )
121
+ if key:
122
+ return CeramicClient(api_key=key)
123
+ return FallbackCeramicClient()
data/__init__.py ADDED
File without changes
data/loader.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-domain dataset loader for ToolOrchestratorEnv.
2
+
3
+ Returns a flat list of question dicts, each with a 'domain' key.
4
+ Adapted from CostAwareToolEnv/scripts/process_datasets.py.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import random
9
+ import re
10
+ import string
11
+ from typing import Any, Dict, List, Optional
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # HuggingFace loader helper
16
+ # ---------------------------------------------------------------------------
17
+
18
+ def _hf_load(repo_id: str, config: Optional[str], split: str):
19
+ import datasets as hf
20
+ kwargs: Dict[str, Any] = {"split": split, "trust_remote_code": True}
21
+ if config:
22
+ kwargs["name"] = config
23
+ return hf.load_dataset(repo_id, **kwargs)
24
+
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # MATH (levels 3-5)
28
+ # ---------------------------------------------------------------------------
29
+
30
+ def _extract_boxed(solution: str):
31
+ for cmd in ("boxed", "fbox"):
32
+ marker = f"\\{cmd}" + "{"
33
+ start = solution.rfind(marker)
34
+ if start == -1:
35
+ continue
36
+ idx = start + len(marker) - 1
37
+ depth = 0
38
+ for i in range(idx, len(solution)):
39
+ if solution[i] == "{":
40
+ depth += 1
41
+ elif solution[i] == "}":
42
+ depth -= 1
43
+ if depth == 0:
44
+ return solution[i + 1 - (i - idx):i].strip()
45
+ # fallback: last non-empty line
46
+ lines = [l.strip() for l in solution.splitlines() if l.strip()]
47
+ return lines[-1] if lines else ""
48
+
49
+
50
+ def _load_math(split: str, max_rows: int) -> List[Dict]:
51
+ candidates = [
52
+ ("DigitalLearningGmbH/MATH-lighteval", "default", "train"),
53
+ ("lighteval/MATH-Hard", "default", "train"),
54
+ ("hendrycks/competition_math", None, "train"),
55
+ ]
56
+ dataset = None
57
+ for repo_id, cfg, spl in candidates:
58
+ try:
59
+ dataset = _hf_load(repo_id, cfg, spl)
60
+ break
61
+ except Exception:
62
+ continue
63
+ if dataset is None:
64
+ return []
65
+
66
+ rows = []
67
+ for ex in dataset:
68
+ level_text = str(ex.get("level", ""))
69
+ m = re.search(r"(\d+)", level_text)
70
+ if not m or int(m.group(1)) not in (3, 4, 5):
71
+ continue
72
+ answer = _extract_boxed(str(ex.get("solution", "")))
73
+ rows.append({
74
+ "question": str(ex.get("problem", "")).strip(),
75
+ "answer": answer,
76
+ "domain": "math",
77
+ "difficulty": m.group(1),
78
+ "subject": str(ex.get("type", "")),
79
+ "source": "math",
80
+ })
81
+ if len(rows) >= max_rows:
82
+ break
83
+ return rows
84
+
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # HotpotQA
88
+ # ---------------------------------------------------------------------------
89
+
90
+ def _load_hotpotqa(split: str, max_rows: int) -> List[Dict]:
91
+ hf_split = "train" if split in ("train", "validation") else split
92
+ dataset = None
93
+ for cfg in ("distractor", "fullwiki"):
94
+ try:
95
+ dataset = _hf_load("hotpotqa/hotpot_qa", cfg, hf_split)
96
+ break
97
+ except Exception:
98
+ continue
99
+ if dataset is None:
100
+ return []
101
+
102
+ subset = dataset.shuffle(seed=42).select(range(min(max_rows, len(dataset))))
103
+ rows = []
104
+ for ex in subset:
105
+ rows.append({
106
+ "question": str(ex.get("question", "")).strip(),
107
+ "answer": str(ex.get("answer", "")).strip(),
108
+ "domain": "hotpotqa",
109
+ "difficulty": str(ex.get("level", "")),
110
+ "type": str(ex.get("type", "")),
111
+ "source": "hotpotqa",
112
+ })
113
+ return rows
114
+
115
+
116
+ # ---------------------------------------------------------------------------
117
+ # GPQA
118
+ # ---------------------------------------------------------------------------
119
+
120
+ def _resolve_gpqa_answer(ex: Dict) -> str:
121
+ val = str(ex.get("Correct Answer", "")).strip()
122
+ if val.upper() in {"A", "B", "C", "D"}:
123
+ mapping = {
124
+ "A": str(ex.get("Answer A", "")),
125
+ "B": str(ex.get("Answer B", "")),
126
+ "C": str(ex.get("Answer C", "")),
127
+ "D": str(ex.get("Answer D", "")),
128
+ }
129
+ return mapping.get(val.upper(), val).strip()
130
+ return val
131
+
132
+
133
+ def _load_gpqa(split: str, max_rows: int) -> List[Dict]:
134
+ dataset = None
135
+ for repo in ("Idavidrein/gpqa", "Wanfq/gpqa"):
136
+ for cfg in ("gpqa_diamond", "gpqa_main"):
137
+ try:
138
+ dataset = _hf_load(repo, cfg, "train")
139
+ break
140
+ except Exception:
141
+ continue
142
+ if dataset is not None:
143
+ break
144
+ if dataset is None:
145
+ return []
146
+
147
+ rows = []
148
+ for ex in dataset:
149
+ answer = _resolve_gpqa_answer(ex)
150
+ rows.append({
151
+ "question": str(ex.get("Question", "")).strip(),
152
+ "answer": answer,
153
+ "domain": "gpqa",
154
+ "difficulty": "graduate",
155
+ "source": "gpqa",
156
+ })
157
+ if len(rows) >= max_rows:
158
+ break
159
+ return rows
160
+
161
+
162
+ # ---------------------------------------------------------------------------
163
+ # HumanEval
164
+ # ---------------------------------------------------------------------------
165
+
166
+ def _load_humaneval(split: str, max_rows: int) -> List[Dict]:
167
+ dataset = None
168
+ for repo in ("openai/openai_humaneval", "openai/human-eval"):
169
+ try:
170
+ dataset = _hf_load(repo, None, "test")
171
+ break
172
+ except Exception:
173
+ continue
174
+ if dataset is None:
175
+ return []
176
+
177
+ rows = []
178
+ for ex in dataset:
179
+ rows.append({
180
+ "question": str(ex.get("prompt", "")).strip(),
181
+ "answer": str(ex.get("canonical_solution", "")).strip(),
182
+ "domain": "humaneval",
183
+ "difficulty": "code",
184
+ "task_id": str(ex.get("task_id", "")),
185
+ "test": str(ex.get("test", "")),
186
+ "entry_point": str(ex.get("entry_point", "")),
187
+ "source": "humaneval",
188
+ })
189
+ if len(rows) >= max_rows:
190
+ break
191
+ return rows
192
+
193
+
194
+ # ---------------------------------------------------------------------------
195
+ # Synthetic fallback (offline / CI)
196
+ # ---------------------------------------------------------------------------
197
+
198
+ _SYNTHETIC_TEMPLATES = [
199
+ ("What is {a} + {b}?", "{c}", "math"),
200
+ ("Who wrote {work}?", "{author}", "hotpotqa"),
201
+ ("Solve for x: {a}x + {b} = {c}", "{x}", "math"),
202
+ ("What is the capital of {country}?", "{capital}", "hotpotqa"),
203
+ ]
204
+
205
+ _SYNTHETIC_DATA = [
206
+ {"a": 12, "b": 7, "c": 19, "work": "Hamlet", "author": "Shakespeare",
207
+ "country": "France", "capital": "Paris", "x": 3},
208
+ {"a": 25, "b": 13, "c": 38, "work": "1984", "author": "George Orwell",
209
+ "country": "Germany", "capital": "Berlin", "x": 5},
210
+ {"a": 100, "b": 44, "c": 144, "work": "The Odyssey", "author": "Homer",
211
+ "country": "Japan", "capital": "Tokyo", "x": 7},
212
+ ]
213
+
214
+
215
+ def _synthetic_questions(n: int) -> List[Dict]:
216
+ rows = []
217
+ for i in range(n):
218
+ tmpl, ans_tmpl, domain = _SYNTHETIC_TEMPLATES[i % len(_SYNTHETIC_TEMPLATES)]
219
+ data = _SYNTHETIC_DATA[i % len(_SYNTHETIC_DATA)]
220
+ try:
221
+ question = tmpl.format(**data)
222
+ answer = ans_tmpl.format(**data)
223
+ except KeyError:
224
+ question = f"Synthetic question {i}"
225
+ answer = f"answer_{i}"
226
+ rows.append({
227
+ "question": question,
228
+ "answer": str(answer),
229
+ "domain": domain,
230
+ "difficulty": "easy",
231
+ "source": "synthetic",
232
+ })
233
+ return rows
234
+
235
+
236
+ # ---------------------------------------------------------------------------
237
+ # Public API
238
+ # ---------------------------------------------------------------------------
239
+
240
+ _LOADERS = {
241
+ "hotpotqa": _load_hotpotqa,
242
+ "math": _load_math,
243
+ "gpqa": _load_gpqa,
244
+ "humaneval": _load_humaneval,
245
+ }
246
+
247
+
248
+ def load_all(split: str = "validation", max_per_domain: int = 200) -> List[Dict]:
249
+ """Load all four domains and return a flat list with 'domain' keys.
250
+
251
+ Falls back to synthetic questions if a domain is unavailable.
252
+ """
253
+ all_questions: List[Dict] = []
254
+ for domain, loader_fn in _LOADERS.items():
255
+ try:
256
+ rows = loader_fn(split, max_per_domain)
257
+ if rows:
258
+ all_questions.extend(rows)
259
+ print(f"[loader] {domain}: {len(rows)} questions")
260
+ else:
261
+ raise ValueError("empty")
262
+ except Exception as exc:
263
+ print(f"[loader] {domain} unavailable ({exc}), using synthetic fallback")
264
+ synth = _synthetic_questions(max(5, max_per_domain // 10))
265
+ for q in synth:
266
+ q["domain"] = domain
267
+ all_questions.extend(synth)
268
+
269
+ random.shuffle(all_questions)
270
+ return all_questions
env/__init__.py ADDED
File without changes
env/answer_grading.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Answer grading utilities: exact match + token F1.
2
+
3
+ Ported from SearchEconomicsEnv/env/answer_grading.py and adapted for
4
+ multi-domain use (HotpotQA-style EM/F1 + code/math fallback).
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import re
10
+ import string
11
+ from collections import Counter
12
+ from typing import Tuple
13
+
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Normalisation
17
+ # ---------------------------------------------------------------------------
18
+
19
+ def normalize_answer(text: str) -> list[str]:
20
+ """Lowercase, strip articles/punctuation, tokenise."""
21
+ text = text.lower().strip()
22
+ # Remove articles
23
+ text = re.sub(r"\b(a|an|the)\b", " ", text)
24
+ # Remove punctuation
25
+ text = text.translate(str.maketrans("", "", string.punctuation))
26
+ return text.split()
27
+
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Metrics
31
+ # ---------------------------------------------------------------------------
32
+
33
+ def exact_match(pred: str, gold: str) -> bool:
34
+ return normalize_answer(pred) == normalize_answer(gold)
35
+
36
+
37
+ def token_f1(pred: str, gold: str) -> float:
38
+ pred_tokens = normalize_answer(pred)
39
+ gold_tokens = normalize_answer(gold)
40
+ if not pred_tokens or not gold_tokens:
41
+ return float(pred_tokens == gold_tokens)
42
+ common = Counter(pred_tokens) & Counter(gold_tokens)
43
+ num_common = sum(common.values())
44
+ if num_common == 0:
45
+ return 0.0
46
+ precision = num_common / len(pred_tokens)
47
+ recall = num_common / len(gold_tokens)
48
+ return 2 * precision * recall / (precision + recall)
49
+
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # Answer extraction
53
+ # ---------------------------------------------------------------------------
54
+
55
+ def extract_answer(raw: str) -> str:
56
+ """Pull the answer string out of various agent output formats."""
57
+ # Strip markdown fences
58
+ raw = re.sub(r"```[a-z]*\n?", "", raw).strip()
59
+
60
+ # Try JSON {"answer": ...}
61
+ try:
62
+ parsed = json.loads(raw)
63
+ if isinstance(parsed, dict):
64
+ for key in ("answer", "Answer", "result", "Result"):
65
+ if key in parsed:
66
+ return str(parsed[key]).strip()
67
+ except (json.JSONDecodeError, ValueError):
68
+ pass
69
+
70
+ # Prefix patterns
71
+ for prefix in ("Answer:", "Final answer:", "Result:", "Output:"):
72
+ idx = raw.lower().find(prefix.lower())
73
+ if idx != -1:
74
+ return raw[idx + len(prefix):].strip().split("\n")[0].strip()
75
+
76
+ # Last non-empty line
77
+ lines = [line.strip() for line in raw.splitlines() if line.strip()]
78
+ return lines[-1] if lines else raw.strip()
79
+
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # Public entry point
83
+ # ---------------------------------------------------------------------------
84
+
85
+ def grade(predicted: str, ground_truth: str) -> Tuple[bool, float, float]:
86
+ """Return (exact_match, f1, quality) where quality ∈ [0, 1]."""
87
+ extracted = extract_answer(predicted)
88
+ em = exact_match(extracted, ground_truth)
89
+ f1 = token_f1(extracted, ground_truth)
90
+ quality = 1.0 if em else f1
91
+ return em, f1, quality
env/config.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration for ToolOrchestratorEnv.
2
+
3
+ All tuneable parameters live here so training scripts, the server, and
4
+ baselines all read from a single source of truth. Override individual
5
+ fields in /reset via config_overrides, or subclass for experiment sweeps.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import Dict, Optional
11
+
12
+
13
+ @dataclass
14
+ class EnvConfig:
15
+ # ── Episode structure ────────────────────────────────────────────────────
16
+ total_budget: float = 50.0 # Total cost units for the whole episode
17
+ num_questions: int = 10 # Questions drawn per episode
18
+ max_steps_per_question: int = 8 # Auto-commit after this many tool calls
19
+ data_split: str = "validation" # HuggingFace dataset split to load
20
+ seed: Optional[int] = None # Global RNG seed (None = random)
21
+ shuffle_questions: bool = True # Shuffle sampled questions each episode
22
+
23
+ # ── Domain mix ──────────────────────────────────────────────────────────
24
+ # Fraction of questions drawn from each dataset. Must sum to ~1.0.
25
+ domain_mix: Dict[str, float] = field(default_factory=lambda: {
26
+ "hotpotqa": 0.4, # Multi-hop factual QA
27
+ "math": 0.3, # Competition math (levels 3-5)
28
+ "gpqa": 0.2, # Graduate-level science
29
+ "humaneval": 0.1, # Python programming tasks
30
+ })
31
+
32
+ # ── Tool costs ──────────────────────────────────────────────────────────
33
+ # Budget units consumed per tool call. Commit is always free.
34
+ tool_costs: Dict[str, float] = field(default_factory=lambda: {
35
+ "ceramic_search": 1.0,
36
+ "wiki_lookup": 0.5,
37
+ "calculator": 0.1,
38
+ "code_executor": 0.3,
39
+ "llm_reason": 2.0,
40
+ "commit": 0.0,
41
+ })
42
+
43
+ # ── Reward shaping ───────────────────────────────────────────────────────
44
+ correct_reward: float = 1.0 # Base reward for a correct commit
45
+ incorrect_reward: float = -0.5 # Base reward for a wrong commit
46
+ efficiency_bonus_weight: float = 0.1 # γ: scales the efficiency bonus
47
+ efficiency_bonus_threshold: float = 0.5 # Minimum quality to earn the bonus
48
+
49
+ # ── Grading ─────────────────────────────────────────────────────────────
50
+ # "em_only" → only exact match counts as correct
51
+ # "em_or_f1" → token F1 ≥ f1_count_threshold also counts as correct
52
+ grade_count_correct_mode: str = "em_or_f1"
53
+ f1_count_threshold: float = 0.5
env/environment.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Core RL environment: ToolOrchestratorEnvironment.
2
+
3
+ Step logic:
4
+ - Agent receives an OrchestratorObservation with the current question,
5
+ budget, context, and available tools.
6
+ - Agent picks a tool_id and optional query / code_snippet / answer.
7
+ - Environment dispatches to the appropriate tool, charges cost, appends
8
+ result to context_window, and returns the next observation + reward.
9
+ - Episode ends when budget is exhausted OR all questions are answered.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import time
14
+ import uuid
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ from .answer_grading import grade
18
+ from .config import EnvConfig
19
+ from .models import (
20
+ OrchestratorAction,
21
+ OrchestratorObservation,
22
+ OrchestratorState,
23
+ ToolResult,
24
+ TOOL_IDS,
25
+ )
26
+ from .reward import commit_reward, step_reward
27
+
28
+
29
+ class ToolOrchestratorEnvironment:
30
+ """
31
+ OpenEnv-compatible RL environment for multi-tool cost-aware QA.
32
+
33
+ Supports external tool injection so the server can wire in live
34
+ Ceramic, code executor, etc. Tools are callables with signature:
35
+ tool_fn(action: OrchestratorAction) -> ToolResult
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ config: Optional[EnvConfig] = None,
41
+ tools: Optional[Dict[str, Any]] = None,
42
+ dataset: Optional[List[Dict[str, Any]]] = None,
43
+ ):
44
+ self.config = config or EnvConfig()
45
+ self.tools = tools or {} # tool_id -> callable
46
+ self.dataset = dataset or [] # List of {question, answer, domain, ...}
47
+ self._state: Optional[OrchestratorState] = None
48
+ self._questions: List[Dict[str, Any]] = []
49
+ self._current_q_idx: int = 0
50
+ self._context_window: List[str] = []
51
+ self._tools_used_this_q: List[str] = []
52
+ self._steps_this_q: int = 0
53
+ self._episode_done: bool = False
54
+
55
+ # -----------------------------------------------------------------------
56
+ # Reset
57
+ # -----------------------------------------------------------------------
58
+
59
+ def reset(self, seed: Optional[int] = None) -> Tuple[OrchestratorObservation, OrchestratorState]:
60
+ import random
61
+ effective_seed = seed if seed is not None else self.config.seed
62
+ rng = random.Random(effective_seed)
63
+
64
+ questions = _sample_questions(self.dataset, self.config, rng)
65
+ self._questions = questions
66
+ self._current_q_idx = 0
67
+ self._episode_done = False
68
+
69
+ self._state = OrchestratorState(
70
+ episode_id=str(uuid.uuid4()),
71
+ total_budget=self.config.total_budget,
72
+ budget_spent=0,
73
+ questions_answered=0,
74
+ total_correct=0,
75
+ current_accuracy=0.0,
76
+ budget_remaining_ratio=1.0,
77
+ current_question_idx=0,
78
+ current_question_steps=0,
79
+ step_count=0,
80
+ )
81
+
82
+ self._context_window = []
83
+ self._tools_used_this_q = []
84
+ self._steps_this_q = 0
85
+
86
+ obs = self._make_obs(reward=None, question_done=False, done=False)
87
+ return obs, self._state
88
+
89
+ # -----------------------------------------------------------------------
90
+ # Step
91
+ # -----------------------------------------------------------------------
92
+
93
+ def step(
94
+ self, action: OrchestratorAction
95
+ ) -> Tuple[OrchestratorObservation, float, bool, OrchestratorState]:
96
+ if self._episode_done:
97
+ raise RuntimeError("Episode is done. Call reset() first.")
98
+
99
+ tool_id = action.tool_id
100
+ if tool_id not in TOOL_IDS:
101
+ raise ValueError(f"Unknown tool_id: {tool_id!r}. Valid: {TOOL_IDS}")
102
+
103
+ state = self._state
104
+ config = self.config
105
+
106
+ if self._current_q_idx >= len(self._questions):
107
+ self._episode_done = True
108
+ raise RuntimeError("Episode is done. Call reset() first.")
109
+
110
+ q_entry = self._questions[self._current_q_idx]
111
+ gold = q_entry["answer"]
112
+
113
+ # ---- Commit ---------------------------------------------------
114
+ if tool_id == "commit":
115
+ raw_pred = action.answer or ""
116
+ em, f1, quality = grade(raw_pred, gold)
117
+
118
+ r = commit_reward(
119
+ quality=quality,
120
+ budget_remaining_ratio=state.budget_remaining_ratio,
121
+ config=config,
122
+ )
123
+
124
+ is_correct = (
125
+ em if config.grade_count_correct_mode == "em_only"
126
+ else (em or f1 >= config.f1_count_threshold)
127
+ )
128
+ state.questions_answered += 1
129
+ state.total_correct += int(is_correct)
130
+ state.current_accuracy = state.total_correct / state.questions_answered
131
+
132
+ self._current_q_idx += 1
133
+ self._context_window = []
134
+ self._tools_used_this_q = []
135
+ self._steps_this_q = 0
136
+
137
+ episode_done = (
138
+ self._current_q_idx >= len(self._questions)
139
+ or state.budget_spent >= state.total_budget
140
+ )
141
+ self._episode_done = episode_done
142
+ state.current_question_idx = self._current_q_idx
143
+ state.current_question_steps = 0
144
+
145
+ obs = self._make_obs(
146
+ reward=r,
147
+ question_done=True,
148
+ done=episode_done,
149
+ last_tool_result=ToolResult(
150
+ tool_id="commit", cost=0,
151
+ output=f"EM={em} F1={f1:.3f} quality={quality:.3f}"
152
+ ),
153
+ )
154
+ return obs, r, episode_done, state
155
+
156
+ # ---- Tool call ------------------------------------------------
157
+ cost = config.tool_costs.get(tool_id, 0)
158
+ budget_after = state.budget_spent + cost
159
+
160
+ if budget_after > state.total_budget:
161
+ r = config.incorrect_reward
162
+ self._episode_done = True
163
+ obs = self._make_obs(reward=r, question_done=True, done=True)
164
+ return obs, r, True, state
165
+
166
+ t0 = time.perf_counter()
167
+ tool_fn = self.tools.get(tool_id)
168
+ if tool_fn is None:
169
+ tool_result = ToolResult(
170
+ tool_id=tool_id, cost=cost,
171
+ output="[Tool not available in this environment]",
172
+ latency_s=0.0,
173
+ error="not_available",
174
+ )
175
+ else:
176
+ try:
177
+ tool_result = tool_fn(action)
178
+ tool_result.cost = cost
179
+ except Exception as exc:
180
+ tool_result = ToolResult(
181
+ tool_id=tool_id, cost=cost,
182
+ output=f"[Error: {exc}]",
183
+ latency_s=time.perf_counter() - t0,
184
+ error=str(exc),
185
+ )
186
+ tool_result.latency_s = time.perf_counter() - t0
187
+
188
+ state.budget_spent = budget_after
189
+ state.budget_remaining_ratio = max(
190
+ 0.0, (state.total_budget - state.budget_spent) / state.total_budget
191
+ )
192
+ state.step_count += 1
193
+ state.current_question_steps += 1
194
+ self._steps_this_q += 1
195
+ self._tools_used_this_q.append(tool_id)
196
+ self._context_window.append(f"[{tool_id}] {tool_result.output}")
197
+
198
+ r = step_reward(tool_id, config)
199
+
200
+ question_done = self._steps_this_q >= config.max_steps_per_question
201
+ episode_done = (
202
+ state.budget_spent >= state.total_budget
203
+ or (question_done and self._current_q_idx + 1 >= len(self._questions))
204
+ )
205
+ if question_done and not episode_done:
206
+ self._current_q_idx += 1
207
+ state.questions_answered += 1
208
+ self._context_window = []
209
+ self._tools_used_this_q = []
210
+ self._steps_this_q = 0
211
+ state.current_question_idx = self._current_q_idx
212
+ state.current_question_steps = 0
213
+
214
+ self._episode_done = episode_done
215
+
216
+ obs = self._make_obs(
217
+ reward=r,
218
+ question_done=question_done,
219
+ done=episode_done,
220
+ last_tool_result=tool_result,
221
+ )
222
+ return obs, r, episode_done, state
223
+
224
+ # -----------------------------------------------------------------------
225
+ # Internal helpers
226
+ # -----------------------------------------------------------------------
227
+
228
+ def _make_obs(
229
+ self,
230
+ reward: Optional[float],
231
+ question_done: bool,
232
+ done: bool,
233
+ last_tool_result: Optional[ToolResult] = None,
234
+ ) -> OrchestratorObservation:
235
+ state = self._state
236
+ cfg = self.config
237
+
238
+ if 0 <= self._current_q_idx < len(self._questions):
239
+ q_entry = self._questions[self._current_q_idx]
240
+ elif self._questions:
241
+ q_entry = self._questions[-1]
242
+ else:
243
+ q_entry = {"question": "", "answer": "", "domain": ""}
244
+
245
+ return OrchestratorObservation(
246
+ question=q_entry.get("question", ""),
247
+ question_idx=self._current_q_idx,
248
+ domain=q_entry.get("domain", ""),
249
+ question_embedding=[],
250
+ total_budget=cfg.total_budget,
251
+ budget_spent=state.budget_spent,
252
+ budget_remaining=state.total_budget - state.budget_spent,
253
+ budget_remaining_ratio=state.budget_remaining_ratio,
254
+ tools_used_this_question=list(self._tools_used_this_q),
255
+ steps_used_this_question=self._steps_this_q,
256
+ max_steps_per_question=cfg.max_steps_per_question,
257
+ last_tool_result=last_tool_result,
258
+ context_window=list(self._context_window),
259
+ step_idx=state.step_count,
260
+ questions_remaining=max(0, len(self._questions) - self._current_q_idx - 1),
261
+ questions_answered=state.questions_answered,
262
+ accuracy_so_far=state.current_accuracy,
263
+ question_done=question_done,
264
+ done=done,
265
+ reward=reward,
266
+ )
267
+
268
+
269
+ # ---------------------------------------------------------------------------
270
+ # Dataset sampling helper
271
+ # ---------------------------------------------------------------------------
272
+
273
+ def _sample_questions(
274
+ dataset: List[Dict[str, Any]],
275
+ config: EnvConfig,
276
+ rng: Any,
277
+ ) -> List[Dict[str, Any]]:
278
+ """Sample `config.num_questions` questions according to domain_mix."""
279
+ by_domain: Dict[str, List[Dict]] = {}
280
+ for item in dataset:
281
+ d = item.get("domain", "hotpotqa")
282
+ by_domain.setdefault(d, []).append(item)
283
+
284
+ selected = []
285
+ for domain, frac in config.domain_mix.items():
286
+ n = round(config.num_questions * frac)
287
+ pool = by_domain.get(domain, [])
288
+ if pool and n > 0:
289
+ selected.extend(rng.sample(pool, min(n, len(pool))))
290
+
291
+ if len(selected) < config.num_questions and dataset:
292
+ remaining = [d for d in dataset if d not in selected]
293
+ rng.shuffle(remaining)
294
+ selected.extend(remaining[: config.num_questions - len(selected)])
295
+
296
+ if config.shuffle_questions:
297
+ rng.shuffle(selected)
298
+
299
+ return selected[: config.num_questions]
env/models.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic data models for ToolOrchestratorEnv.
2
+
3
+ Three main types flow through the environment:
4
+ OrchestratorAction — agent → env (what tool to call)
5
+ OrchestratorObservation — env → agent (what the agent sees)
6
+ OrchestratorState — env → server (full bookkeeping snapshot)
7
+
8
+ ToolResult is returned by every tool and attached to the observation.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ from typing import Any, Dict, List, Optional
13
+
14
+ from pydantic import BaseModel, Field
15
+
16
+ # Canonical tool IDs — order matters for UI display; must match config.tool_costs keys.
17
+ TOOL_IDS = [
18
+ "ceramic_search", # web retrieval (most useful for HotpotQA)
19
+ "wiki_lookup", # Wikipedia summary (good for entity facts)
20
+ "calculator", # safe AST math eval (essential for MATH)
21
+ "code_executor", # sandboxed Python (HumanEval)
22
+ "llm_reason", # LLM chain-of-thought (GPQA)
23
+ "commit", # submit answer — always free
24
+ ]
25
+
26
+
27
+ class OrchestratorAction(BaseModel):
28
+ """One step of the agent's interaction with the environment.
29
+
30
+ Fields used per tool:
31
+ ceramic_search → query
32
+ wiki_lookup → query
33
+ calculator → expression (falls back to query if blank)
34
+ code_executor → code_snippet (falls back to query if blank)
35
+ llm_reason → query
36
+ commit → answer
37
+ """
38
+ tool_id: str
39
+ query: str = ""
40
+ expression: str = ""
41
+ code_snippet: str = ""
42
+ answer: str = ""
43
+ metadata: Optional[Dict[str, Any]] = None
44
+
45
+
46
+ class ToolResult(BaseModel):
47
+ """Output produced by one tool call.
48
+
49
+ Attached to OrchestratorObservation.last_tool_result and also
50
+ appended (as a string) to the context_window.
51
+ """
52
+ tool_id: str
53
+ output: str = "" # Human-readable result text
54
+ cost: float = 0.0 # Budget units charged (set by environment)
55
+ latency_s: float = 0.0 # Wall-clock seconds (set by environment)
56
+ error: Optional[str] = None # Non-None if the tool call failed
57
+
58
+
59
+ class OrchestratorObservation(BaseModel):
60
+ """Everything the agent sees at the start of each step.
61
+
62
+ Designed to be complete: the agent should be able to make an
63
+ informed tool-selection decision using only this observation.
64
+ """
65
+ # ── Current question ────────────────────────────────────────────────────
66
+ question: str # Full question text
67
+ question_idx: int # Position in the episode (0-indexed)
68
+ domain: str # "hotpotqa" | "math" | "gpqa" | "humaneval"
69
+ question_embedding: List[float] = Field(default_factory=list) # Optional embedding vector
70
+
71
+ # ── Budget ──────────────────────────────────────────────────────────────
72
+ total_budget: float
73
+ budget_spent: float
74
+ budget_remaining: float
75
+ budget_remaining_ratio: float # budget_remaining / total_budget ∈ [0, 1]
76
+
77
+ # ── Progress on the current question ────────────────────────────────────
78
+ tools_used_this_question: List[str] = Field(default_factory=list)
79
+ steps_used_this_question: int = 0
80
+ max_steps_per_question: int = 8
81
+ last_tool_result: Optional[ToolResult] = None
82
+ context_window: List[str] = Field(default_factory=list) # "[tool_id] output" strings
83
+
84
+ # ── Episode-level progress ───────────────────────────────────────────────
85
+ step_idx: int = 0 # Global step counter
86
+ questions_remaining: int = 0 # Questions not yet started
87
+ questions_answered: int = 0 # Questions that received a commit
88
+ accuracy_so_far: float = 0.0 # Running correctness rate
89
+
90
+ # ── Terminal signals ─────────────────────────────────────────────────────
91
+ question_done: bool = False # This question just ended (commit or max_steps)
92
+ done: bool = False # Episode is over
93
+ reward: Optional[float] = None # Reward from the *previous* step
94
+
95
+
96
+ class OrchestratorState(BaseModel):
97
+ """Full bookkeeping snapshot — returned alongside observation for logging.
98
+
99
+ Contains all fields needed to reconstruct the episode history without
100
+ digging into the environment's internal attributes.
101
+ """
102
+ episode_id: str
103
+ total_budget: float
104
+ budget_spent: float = 0.0
105
+ questions_answered: int = 0
106
+ total_correct: int = 0
107
+ current_accuracy: float = 0.0
108
+ budget_remaining_ratio: float = 1.0
109
+ current_question_idx: int = 0
110
+ current_question_steps: int = 0
111
+ step_count: int = 0
env/reward.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward functions for ToolOrchestratorEnv.
2
+
3
+ Adapted from SearchEconomicsEnv/env/reward.py, generalised to a multi-tool
4
+ action space where each tool has its own cost.
5
+
6
+ Two reward signals:
7
+
8
+ step_reward — small negative penalty charged every time the agent calls
9
+ a tool (including commit = 0 cost). Discourages wasted
10
+ calls without forbidding exploration.
11
+
12
+ commit_reward — composite reward awarded when the agent submits an answer.
13
+ Balances answer quality against remaining budget (Weitzman
14
+ style: you earn a bonus for being both correct *and* frugal).
15
+ """
16
+ from __future__ import annotations
17
+
18
+ from .config import EnvConfig
19
+
20
+
21
+ def step_reward(tool_id: str, config: EnvConfig) -> float:
22
+ """Return the (negative) cost of calling tool_id.
23
+
24
+ Example: calculator → -0.1, llm_reason → -2.0, commit → 0.0
25
+ """
26
+ return -config.tool_costs.get(tool_id, 0.0)
27
+
28
+
29
+ def commit_reward(
30
+ quality: float,
31
+ budget_remaining_ratio: float,
32
+ config: EnvConfig,
33
+ ) -> float:
34
+ """Composite reward on commit.
35
+
36
+ Formula
37
+ -------
38
+ base = incorrect_reward + quality × (correct_reward − incorrect_reward)
39
+ η = 1 if quality ≥ efficiency_bonus_threshold, else 0
40
+ bonus = η × efficiency_bonus_weight × budget_remaining_ratio
41
+ R = base + bonus
42
+
43
+ The efficiency bonus (bonus) is only non-zero when the agent both answers
44
+ correctly (quality above threshold) *and* conserves budget. This creates
45
+ a soft incentive to use cheaper tools and commit early when confident.
46
+
47
+ Parameters
48
+ ----------
49
+ quality : float in [0, 1] — max(ExactMatch, TokenF1)
50
+ budget_remaining_ratio: float in [0, 1] — fraction of budget still unspent
51
+ config : EnvConfig
52
+ """
53
+ q = max(0.0, min(1.0, quality))
54
+ base = config.incorrect_reward + q * (config.correct_reward - config.incorrect_reward)
55
+ eta = 1.0 if q >= config.efficiency_bonus_threshold else 0.0
56
+ bonus = eta * config.efficiency_bonus_weight * budget_remaining_ratio
57
+ return base + bonus
environment.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compatibility shim — real code lives in env/environment.py.
2
+
3
+ Step logic:
4
+ - Agent receives an OrchestratorObservation with the current question,
5
+ budget, context, and available tools.
6
+ - Agent picks a tool_id and optional query / code_snippet / answer.
7
+ - Environment dispatches to the appropriate tool, charges cost, appends
8
+ result to context_window, and returns the next observation + reward.
9
+ - Episode ends when budget is exhausted OR all questions are answered.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import time
14
+ import uuid
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ from env.answer_grading import grade
18
+ from env.config import EnvConfig
19
+ from env.models import (
20
+ OrchestratorAction,
21
+ OrchestratorObservation,
22
+ OrchestratorState,
23
+ ToolResult,
24
+ TOOL_IDS,
25
+ )
26
+ from env.reward import commit_reward, step_reward
27
+
28
+
29
+ class ToolOrchestratorEnvironment:
30
+ """
31
+ OpenEnv-compatible RL environment for multi-tool cost-aware QA.
32
+
33
+ Supports external tool injection so the server can wire in live
34
+ Ceramic, code executor, etc. Tools are callables with signature:
35
+ tool_fn(action: OrchestratorAction) -> ToolResult
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ config: Optional[EnvConfig] = None,
41
+ tools: Optional[Dict[str, Any]] = None,
42
+ dataset: Optional[List[Dict[str, Any]]] = None,
43
+ ):
44
+ self.config = config or EnvConfig()
45
+ self.tools = tools or {} # tool_id -> callable
46
+ self.dataset = dataset or [] # List of {question, answer, domain}
47
+ self._state: Optional[OrchestratorState] = None
48
+ self._questions: List[Dict[str, Any]] = []
49
+ self._current_q_idx: int = 0
50
+ self._context_window: List[str] = []
51
+ self._tools_used_this_q: List[str] = []
52
+ self._steps_this_q: int = 0
53
+ self._episode_done: bool = False
54
+
55
+ # -----------------------------------------------------------------------
56
+ # Reset
57
+ # -----------------------------------------------------------------------
58
+
59
+ def reset(self, seed: Optional[int] = None) -> Tuple[OrchestratorObservation, OrchestratorState]:
60
+ import random
61
+ rng = random.Random(seed if seed is not None else self.config.seed)
62
+
63
+ # Sample questions according to domain_mix
64
+ questions = _sample_questions(self.dataset, self.config, rng)
65
+ self._questions = questions
66
+ self._current_q_idx = 0
67
+ self._episode_done = False
68
+
69
+ self._state = OrchestratorState(
70
+ episode_id=str(uuid.uuid4()),
71
+ total_budget=self.config.total_budget,
72
+ budget_spent=0,
73
+ questions_answered=0,
74
+ total_correct=0,
75
+ current_accuracy=0.0,
76
+ budget_remaining_ratio=1.0,
77
+ current_question_idx=0,
78
+ current_question_steps=0,
79
+ )
80
+
81
+ self._context_window = []
82
+ self._tools_used_this_q = []
83
+ self._steps_this_q = 0
84
+
85
+ obs = self._make_obs(reward=None, question_done=False, done=False)
86
+ return obs, self._state
87
+
88
+ # -----------------------------------------------------------------------
89
+ # Step
90
+ # -----------------------------------------------------------------------
91
+
92
+ def step(
93
+ self, action: OrchestratorAction
94
+ ) -> Tuple[OrchestratorObservation, float, bool, OrchestratorState]:
95
+ if self._episode_done:
96
+ raise RuntimeError("Episode is done. Call reset() first.")
97
+
98
+ tool_id = action.tool_id
99
+ if tool_id not in TOOL_IDS:
100
+ raise ValueError(f"Unknown tool_id: {tool_id!r}. Valid: {TOOL_IDS}")
101
+
102
+ state = self._state
103
+ config = self.config
104
+
105
+ # Guard against exhausted question list (can happen after last commit)
106
+ if self._current_q_idx >= len(self._questions):
107
+ self._episode_done = True
108
+ raise RuntimeError("Episode is done. Call reset() first.")
109
+
110
+ q_entry = self._questions[self._current_q_idx]
111
+ gold = q_entry["answer"]
112
+
113
+ # ---- Commit ---------------------------------------------------
114
+ if tool_id == "commit":
115
+ raw_pred = action.answer or ""
116
+ em, f1, quality = grade(raw_pred, gold)
117
+
118
+ r = commit_reward(
119
+ quality=quality,
120
+ budget_remaining_ratio=state.budget_remaining_ratio,
121
+ config=config,
122
+ )
123
+
124
+ # Count correct
125
+ is_correct = (
126
+ em if config.grade_count_correct_mode == "em_only"
127
+ else (em or f1 >= config.f1_count_threshold)
128
+ )
129
+ state.questions_answered += 1
130
+ state.total_correct += int(is_correct)
131
+ state.current_accuracy = state.total_correct / state.questions_answered
132
+
133
+ # Advance to next question or end episode
134
+ self._current_q_idx += 1
135
+ self._context_window = []
136
+ self._tools_used_this_q = []
137
+ self._steps_this_q = 0
138
+
139
+ episode_done = (
140
+ self._current_q_idx >= len(self._questions)
141
+ or state.budget_spent >= state.total_budget
142
+ )
143
+ self._episode_done = episode_done
144
+ state.current_question_idx = self._current_q_idx
145
+ state.current_question_steps = 0
146
+
147
+ obs = self._make_obs(
148
+ reward=r,
149
+ question_done=True,
150
+ done=episode_done,
151
+ last_tool_result=ToolResult(
152
+ tool_id="commit", cost=0,
153
+ output=f"EM={em} F1={f1:.3f} quality={quality:.3f}"
154
+ ),
155
+ )
156
+ return obs, r, episode_done, state
157
+
158
+ # ---- Tool call ------------------------------------------------
159
+ cost = config.tool_costs.get(tool_id, 0)
160
+ budget_after = state.budget_spent + cost
161
+
162
+ # If over budget, force commit penalty
163
+ if budget_after > state.total_budget:
164
+ r = config.incorrect_reward
165
+ self._episode_done = True
166
+ obs = self._make_obs(reward=r, question_done=True, done=True)
167
+ return obs, r, True, state
168
+
169
+ # Dispatch tool
170
+ t0 = time.perf_counter()
171
+ tool_fn = self.tools.get(tool_id)
172
+ if tool_fn is None:
173
+ tool_result = ToolResult(
174
+ tool_id=tool_id, cost=cost,
175
+ output="[Tool not available in this environment]",
176
+ latency_s=0.0,
177
+ error="not_available",
178
+ )
179
+ else:
180
+ try:
181
+ tool_result = tool_fn(action)
182
+ tool_result.cost = cost
183
+ except Exception as exc:
184
+ tool_result = ToolResult(
185
+ tool_id=tool_id, cost=cost,
186
+ output=f"[Error: {exc}]",
187
+ latency_s=time.perf_counter() - t0,
188
+ error=str(exc),
189
+ )
190
+ tool_result.latency_s = time.perf_counter() - t0
191
+
192
+ # Charge cost and update state
193
+ state.budget_spent = budget_after
194
+ state.budget_remaining_ratio = max(
195
+ 0.0, (state.total_budget - state.budget_spent) / state.total_budget
196
+ )
197
+ state.step_count += 1
198
+ state.current_question_steps += 1
199
+ self._steps_this_q += 1
200
+ self._tools_used_this_q.append(tool_id)
201
+ self._context_window.append(f"[{tool_id}] {tool_result.output}")
202
+
203
+ r = step_reward(tool_id, config)
204
+
205
+ # Auto-commit if max steps reached
206
+ question_done = self._steps_this_q >= config.max_steps_per_question
207
+ episode_done = (
208
+ state.budget_spent >= state.total_budget
209
+ or (question_done and self._current_q_idx + 1 >= len(self._questions))
210
+ )
211
+ if question_done and not episode_done:
212
+ self._current_q_idx += 1
213
+ state.questions_answered += 1
214
+ self._context_window = []
215
+ self._tools_used_this_q = []
216
+ self._steps_this_q = 0
217
+ state.current_question_idx = self._current_q_idx
218
+ state.current_question_steps = 0
219
+
220
+ self._episode_done = episode_done
221
+
222
+ obs = self._make_obs(
223
+ reward=r,
224
+ question_done=question_done,
225
+ done=episode_done,
226
+ last_tool_result=tool_result,
227
+ )
228
+ return obs, r, episode_done, state
229
+
230
+ # -----------------------------------------------------------------------
231
+ # Internal helpers
232
+ # -----------------------------------------------------------------------
233
+
234
+ def _make_obs(
235
+ self,
236
+ reward: Optional[float],
237
+ question_done: bool,
238
+ done: bool,
239
+ last_tool_result: Optional[ToolResult] = None,
240
+ ) -> OrchestratorObservation:
241
+ state = self._state
242
+ cfg = self.config
243
+
244
+ if 0 <= self._current_q_idx < len(self._questions):
245
+ q_entry = self._questions[self._current_q_idx]
246
+ elif self._questions:
247
+ # Episode finished — repeat last question info (obs is terminal anyway)
248
+ q_entry = self._questions[-1]
249
+ else:
250
+ q_entry = {"question": "", "answer": "", "domain": ""}
251
+
252
+ return OrchestratorObservation(
253
+ question=q_entry.get("question", ""),
254
+ question_idx=self._current_q_idx,
255
+ domain=q_entry.get("domain", ""),
256
+ question_embedding=[], # populated by server if needed
257
+ total_budget=cfg.total_budget,
258
+ budget_spent=state.budget_spent,
259
+ budget_remaining=state.total_budget - state.budget_spent,
260
+ budget_remaining_ratio=state.budget_remaining_ratio,
261
+ tools_used_this_question=list(self._tools_used_this_q),
262
+ steps_used_this_question=self._steps_this_q,
263
+ max_steps_per_question=cfg.max_steps_per_question,
264
+ last_tool_result=last_tool_result,
265
+ context_window=list(self._context_window),
266
+ step_idx=state.step_count,
267
+ questions_remaining=max(0, len(self._questions) - self._current_q_idx - 1),
268
+ questions_answered=state.questions_answered,
269
+ accuracy_so_far=state.current_accuracy,
270
+ question_done=question_done,
271
+ done=done,
272
+ reward=reward,
273
+ )
274
+
275
+
276
+ # ---------------------------------------------------------------------------
277
+ # Dataset sampling helper
278
+ # ---------------------------------------------------------------------------
279
+
280
+ def _sample_questions(
281
+ dataset: List[Dict[str, Any]],
282
+ config: EnvConfig,
283
+ rng: Any,
284
+ ) -> List[Dict[str, Any]]:
285
+ """Sample `config.num_questions` questions according to domain_mix."""
286
+ by_domain: Dict[str, List[Dict]] = {}
287
+ for item in dataset:
288
+ d = item.get("domain", "hotpotqa")
289
+ by_domain.setdefault(d, []).append(item)
290
+
291
+ selected = []
292
+ for domain, frac in config.domain_mix.items():
293
+ n = round(config.num_questions * frac)
294
+ pool = by_domain.get(domain, [])
295
+ if pool and n > 0:
296
+ selected.extend(rng.sample(pool, min(n, len(pool))))
297
+
298
+ # Guarantee at least num_questions items by filling from the full dataset
299
+ if len(selected) < config.num_questions and dataset:
300
+ remaining = [d for d in dataset if d not in selected]
301
+ rng.shuffle(remaining)
302
+ selected.extend(remaining[: config.num_questions - len(selected)])
303
+
304
+ if config.shuffle_questions:
305
+ rng.shuffle(selected)
306
+
307
+ return selected[: config.num_questions]
openenv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spec: 1
2
+ app: tool-orchestrator-env
3
+ type: space
4
+ runtime: fastapi
5
+ entrypoint: app:app
6
+ port: 8000
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.110.0
2
+ uvicorn[standard]>=0.29.0
3
+ pydantic>=2.0
4
+ numpy>=1.24
5
+ datasets>=2.18.0
6
+ httpx>=0.27.0
7
+ requests>=2.31.0
8
+ together>=1.2.0
tools/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tool registry for ToolOrchestratorEnv.
2
+
3
+ Each tool is a callable: (action: OrchestratorAction) -> ToolResult
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from typing import Callable, Dict
8
+
9
+ from env.config import EnvConfig
10
+ from env.models import OrchestratorAction, ToolResult
11
+
12
+ from .calculator import calculator_tool
13
+ from .ceramic_search import make_search_tool
14
+ from .code_executor import code_executor_tool
15
+ from .commit import commit_tool
16
+ from .llm_reason import llm_reason_tool
17
+ from .wiki_lookup import wiki_lookup_tool
18
+
19
+
20
+ def build_tool_registry(config: EnvConfig | None = None) -> Dict[str, Callable]:
21
+ """Return a mapping of tool_id → tool function."""
22
+ return {
23
+ "ceramic_search": make_search_tool(),
24
+ "calculator": calculator_tool,
25
+ "wiki_lookup": wiki_lookup_tool,
26
+ "code_executor": code_executor_tool,
27
+ "llm_reason": llm_reason_tool,
28
+ "commit": commit_tool,
29
+ }
tools/calculator.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Safe AST-based calculator tool.
2
+
3
+ Supports arithmetic, comparisons, and basic math functions.
4
+ No exec/eval with arbitrary code — uses ast.literal_eval-style restricted eval.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import ast
9
+ import math
10
+ import operator
11
+ from typing import Any
12
+
13
+ from env.models import OrchestratorAction, ToolResult
14
+
15
+ _SAFE_OPS = {
16
+ ast.Add: operator.add,
17
+ ast.Sub: operator.sub,
18
+ ast.Mult: operator.mul,
19
+ ast.Div: operator.truediv,
20
+ ast.Pow: operator.pow,
21
+ ast.Mod: operator.mod,
22
+ ast.FloorDiv: operator.floordiv,
23
+ ast.USub: operator.neg,
24
+ ast.UAdd: operator.pos,
25
+ }
26
+
27
+ _SAFE_FUNCS: dict[str, Any] = {
28
+ "abs": abs, "round": round, "min": min, "max": max,
29
+ "sqrt": math.sqrt, "log": math.log, "log2": math.log2,
30
+ "log10": math.log10, "exp": math.exp,
31
+ "sin": math.sin, "cos": math.cos, "tan": math.tan,
32
+ "floor": math.floor, "ceil": math.ceil,
33
+ "pi": math.pi, "e": math.e,
34
+ }
35
+
36
+
37
+ def _safe_eval(node: ast.AST) -> Any:
38
+ if isinstance(node, ast.Expression):
39
+ return _safe_eval(node.body)
40
+ if isinstance(node, ast.Constant):
41
+ return node.value
42
+ if isinstance(node, ast.Name):
43
+ if node.id in _SAFE_FUNCS:
44
+ return _SAFE_FUNCS[node.id]
45
+ raise ValueError(f"Unknown name: {node.id!r}")
46
+ if isinstance(node, ast.BinOp):
47
+ op_type = type(node.op)
48
+ if op_type not in _SAFE_OPS:
49
+ raise ValueError(f"Unsupported operator: {op_type.__name__}")
50
+ return _SAFE_OPS[op_type](_safe_eval(node.left), _safe_eval(node.right))
51
+ if isinstance(node, ast.UnaryOp):
52
+ op_type = type(node.op)
53
+ if op_type not in _SAFE_OPS:
54
+ raise ValueError(f"Unsupported unary: {op_type.__name__}")
55
+ return _SAFE_OPS[op_type](_safe_eval(node.operand))
56
+ if isinstance(node, ast.Call):
57
+ func = _safe_eval(node.func)
58
+ if not callable(func):
59
+ raise ValueError("Not callable")
60
+ args = [_safe_eval(a) for a in node.args]
61
+ return func(*args)
62
+ raise ValueError(f"Unsupported AST node: {type(node).__name__}")
63
+
64
+
65
+ def calculator_tool(action: OrchestratorAction) -> ToolResult:
66
+ expr = (action.expression or action.query or "").strip()
67
+ if not expr:
68
+ return ToolResult(tool_id="calculator", output="[No expression provided]", error="empty")
69
+ try:
70
+ tree = ast.parse(expr, mode="eval")
71
+ result = _safe_eval(tree)
72
+ return ToolResult(tool_id="calculator", output=str(result))
73
+ except Exception as exc:
74
+ return ToolResult(tool_id="calculator", output=f"[Calc error: {exc}]", error=str(exc))
tools/ceramic_search.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ceramic search tool — wraps CeramicClient."""
2
+ from __future__ import annotations
3
+
4
+ from ceramic.client import get_ceramic_client
5
+ from env.models import OrchestratorAction, ToolResult
6
+
7
+
8
+ def make_search_tool(top_k: int = 3):
9
+ """Factory: creates a search tool with a shared Ceramic client."""
10
+ client = get_ceramic_client()
11
+
12
+ def _search(action: OrchestratorAction) -> ToolResult:
13
+ query = (action.query or "").strip()
14
+ if not query:
15
+ return ToolResult(
16
+ tool_id="ceramic_search",
17
+ output="[No query provided]",
18
+ error="empty_query",
19
+ )
20
+ try:
21
+ results = client.search(query, top_k=top_k)
22
+ snippets = []
23
+ for r in results:
24
+ snippets.append(f"**{r.title}** ({r.score:.2f})\n{r.description}")
25
+ output = "\n\n".join(snippets) if snippets else "[No results found]"
26
+ return ToolResult(tool_id="ceramic_search", output=output)
27
+ except Exception as exc:
28
+ return ToolResult(
29
+ tool_id="ceramic_search",
30
+ output=f"[Search error: {exc}]",
31
+ error=str(exc),
32
+ )
33
+
34
+ return _search
tools/code_executor.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Restricted Python code executor.
2
+
3
+ Runs code in a sandboxed namespace — blocks os/sys/subprocess imports
4
+ and captures stdout. Intended for math / algorithmic tasks.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import io
9
+ import sys
10
+ import contextlib
11
+
12
+ from env.models import OrchestratorAction, ToolResult
13
+
14
+ _BLOCKED_MODULES = frozenset({
15
+ "os", "sys", "subprocess", "socket", "shutil", "pathlib",
16
+ "importlib", "builtins", "ctypes", "multiprocessing", "threading",
17
+ "signal", "pty", "fcntl", "resource", "gc", "inspect",
18
+ })
19
+
20
+ _MAX_OUTPUT_CHARS = 2000
21
+
22
+
23
+ class _BlockedImport:
24
+ """Raise on any import of blocked modules."""
25
+ def __init__(self, original_import):
26
+ self._orig = original_import
27
+
28
+ def __call__(self, name, *args, **kwargs):
29
+ base = name.split(".")[0]
30
+ if base in _BLOCKED_MODULES:
31
+ raise ImportError(f"Module '{name}' is not allowed in code_executor")
32
+ return self._orig(name, *args, **kwargs)
33
+
34
+
35
+ def code_executor_tool(action: OrchestratorAction) -> ToolResult:
36
+ code = (action.code_snippet or action.query or "").strip()
37
+ if not code:
38
+ return ToolResult(tool_id="code_executor", output="[No code provided]", error="empty")
39
+
40
+ stdout_buf = io.StringIO()
41
+ safe_globals = {
42
+ "__builtins__": {
43
+ k: v for k, v in __builtins__.items() # type: ignore[union-attr]
44
+ if k not in ("open", "exec", "eval", "compile", "__import__")
45
+ } if isinstance(__builtins__, dict) else {
46
+ k: getattr(__builtins__, k) for k in dir(__builtins__)
47
+ if k not in ("open", "exec", "eval", "compile", "__import__")
48
+ },
49
+ "__import__": _BlockedImport(__import__),
50
+ "print": lambda *a, **kw: print(*a, **kw, file=stdout_buf),
51
+ }
52
+
53
+ try:
54
+ with contextlib.redirect_stdout(stdout_buf):
55
+ exec(compile(code, "<code_executor>", "exec"), safe_globals) # noqa: S102
56
+ output = stdout_buf.getvalue()[:_MAX_OUTPUT_CHARS] or "[Code ran, no output]"
57
+ return ToolResult(tool_id="code_executor", output=output)
58
+ except Exception as exc:
59
+ return ToolResult(
60
+ tool_id="code_executor",
61
+ output=f"[Execution error: {exc}]",
62
+ error=str(exc),
63
+ )
tools/commit.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Commit tool — passes the answer through; grading happens in environment.py."""
2
+ from __future__ import annotations
3
+
4
+ from env.models import OrchestratorAction, ToolResult
5
+
6
+
7
+ def commit_tool(action: OrchestratorAction) -> ToolResult:
8
+ answer = (action.answer or "").strip()
9
+ return ToolResult(
10
+ tool_id="commit",
11
+ output=f"Committed answer: {answer[:200]}",
12
+ )
tools/llm_reason.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LLM reasoning tool — calls Together AI (or falls back gracefully)."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+
6
+ from env.models import OrchestratorAction, ToolResult
7
+
8
+ _DEFAULT_MODEL = "meta-llama/Llama-3-8b-chat-hf"
9
+ _MAX_TOKENS = 512
10
+
11
+
12
+ def llm_reason_tool(action: OrchestratorAction) -> ToolResult:
13
+ prompt = (action.query or "").strip()
14
+ if not prompt:
15
+ return ToolResult(tool_id="llm_reason", output="[No prompt provided]", error="empty")
16
+
17
+ api_key = os.environ.get("TOGETHER_API_KEY") or os.environ.get("TOGETHER_KEY")
18
+ if not api_key:
19
+ return ToolResult(
20
+ tool_id="llm_reason",
21
+ output="[LLM reasoning not configured — set TOGETHER_API_KEY]",
22
+ error="no_api_key",
23
+ )
24
+
25
+ try:
26
+ import together # type: ignore
27
+ client = together.Together(api_key=api_key)
28
+ resp = client.chat.completions.create(
29
+ model=_DEFAULT_MODEL,
30
+ messages=[{"role": "user", "content": prompt}],
31
+ max_tokens=_MAX_TOKENS,
32
+ temperature=0.0,
33
+ )
34
+ text = resp.choices[0].message.content or ""
35
+ return ToolResult(tool_id="llm_reason", output=text.strip()[:2000])
36
+ except ImportError:
37
+ return ToolResult(
38
+ tool_id="llm_reason",
39
+ output="[together package not installed — pip install together]",
40
+ error="import_error",
41
+ )
42
+ except Exception as exc:
43
+ return ToolResult(tool_id="llm_reason", output=f"[LLM error: {exc}]", error=str(exc))
tools/wiki_lookup.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wikipedia lookup tool — returns the intro paragraph of an article."""
2
+ from __future__ import annotations
3
+
4
+ import urllib.parse
5
+ import urllib.request
6
+ import json
7
+
8
+ from env.models import OrchestratorAction, ToolResult
9
+
10
+ _WIKI_API = "https://en.wikipedia.org/api/rest_v1/page/summary/{}"
11
+
12
+
13
+ def wiki_lookup_tool(action: OrchestratorAction) -> ToolResult:
14
+ query = (action.query or "").strip()
15
+ if not query:
16
+ return ToolResult(tool_id="wiki_lookup", output="[No query provided]", error="empty_query")
17
+
18
+ title = urllib.parse.quote(query.replace(" ", "_"))
19
+ url = _WIKI_API.format(title)
20
+ try:
21
+ req = urllib.request.Request(url, headers={"User-Agent": "ToolOrchestratorEnv/0.1"})
22
+ with urllib.request.urlopen(req, timeout=8) as resp:
23
+ data = json.loads(resp.read().decode())
24
+ extract = data.get("extract", "").strip()
25
+ page_title = data.get("title", query)
26
+ if not extract:
27
+ return ToolResult(tool_id="wiki_lookup", output=f"[No summary found for '{query}']")
28
+ return ToolResult(tool_id="wiki_lookup", output=f"**{page_title}**\n{extract[:800]}")
29
+ except urllib.error.HTTPError as exc:
30
+ if exc.code == 404:
31
+ return ToolResult(
32
+ tool_id="wiki_lookup",
33
+ output=f"[Wikipedia: no article found for '{query}']",
34
+ error="not_found",
35
+ )
36
+ return ToolResult(tool_id="wiki_lookup", output=f"[Wiki HTTP error {exc.code}]", error=str(exc))
37
+ except Exception as exc:
38
+ return ToolResult(tool_id="wiki_lookup", output=f"[Wiki error: {exc}]", error=str(exc))