Spaces:
Sleeping
Sleeping
Andrew Lara Claude Sonnet 4.6 commited on
Commit ·
8ca3a35
0
Parent(s):
Initial implementation of ToolOrchestratorEnv
Browse filesMulti-tool cost-aware RL environment built on SearchEconomicsEnv.
Agent selects from 6 tools (ceramic_search, wiki_lookup, calculator,
code_executor, llm_reason, commit) across 4 QA domains under a shared
budget constraint. Weitzman-style composite reward (quality + efficiency
bonus). OpenEnv-compatible FastAPI server, Dockerfile for HF Spaces.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- .env.example +11 -0
- .gitignore +20 -0
- BLOG_PROMPT.md +115 -0
- Dockerfile +29 -0
- README.md +213 -0
- app.py +208 -0
- baselines/__init__.py +0 -0
- baselines/cheapest_first.py +69 -0
- baselines/oracle.py +76 -0
- baselines/random_tool.py +67 -0
- ceramic/__init__.py +3 -0
- ceramic/client.py +133 -0
- client.py +123 -0
- data/__init__.py +0 -0
- data/loader.py +270 -0
- env/__init__.py +0 -0
- env/answer_grading.py +91 -0
- env/config.py +53 -0
- env/environment.py +299 -0
- env/models.py +111 -0
- env/reward.py +57 -0
- environment.py +307 -0
- openenv.yaml +6 -0
- requirements.txt +8 -0
- tools/__init__.py +29 -0
- tools/calculator.py +74 -0
- tools/ceramic_search.py +34 -0
- tools/code_executor.py +63 -0
- tools/commit.py +12 -0
- tools/llm_reason.py +43 -0
- tools/wiki_lookup.py +38 -0
.env.example
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy this file to .env and fill in your values.
|
| 2 |
+
# On HuggingFace Spaces, set these as Space Secrets (Settings → Variables and secrets).
|
| 3 |
+
|
| 4 |
+
# Required: Ceramic AI search key (sign up at https://ceramic.ai)
|
| 5 |
+
CERAMIC_API_KEY=cer_sk_live_...
|
| 6 |
+
|
| 7 |
+
# Optional: Together AI key for the llm_reason tool (https://api.together.xyz)
|
| 8 |
+
TOGETHER_API_KEY=
|
| 9 |
+
|
| 10 |
+
# HuggingFace token — only needed if you load gated datasets (e.g. GPQA)
|
| 11 |
+
HF_TOKEN=
|
.gitignore
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Never commit real secrets
|
| 2 |
+
.env
|
| 3 |
+
|
| 4 |
+
# Python
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.pyc
|
| 7 |
+
*.pyo
|
| 8 |
+
.venv/
|
| 9 |
+
dist/
|
| 10 |
+
*.egg-info/
|
| 11 |
+
|
| 12 |
+
# Data / cache
|
| 13 |
+
data/*.jsonl
|
| 14 |
+
data/*.npy
|
| 15 |
+
data/*.json
|
| 16 |
+
|
| 17 |
+
# Editor
|
| 18 |
+
.DS_Store
|
| 19 |
+
.idea/
|
| 20 |
+
.vscode/
|
BLOG_PROMPT.md
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blog Writing Prompt
|
| 2 |
+
|
| 3 |
+
Drop this entire block into a fresh Claude conversation to generate the research blog post.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## PROMPT
|
| 8 |
+
|
| 9 |
+
You are writing a research blog post about a reinforcement learning environment called **ToolOrchestratorEnv**. The authors are Andrew and Yash Sharma (USC), building on Yash's prior work on SearchEconomicsEnv. The target audience is ML researchers and practitioners who read papers like those at NeurIPS, ICLR, and the Hugging Face blog — people who understand RL and LLMs but are not experts in tool-use or search economics.
|
| 10 |
+
|
| 11 |
+
The post should be **1,500–2,000 words**, well-structured with section headers, and written in a direct, confident academic-blog tone (think: The Gradient, Hugging Face blog, or a good arXiv blog post). Avoid hype. Let the ideas do the work.
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
### What this research is and why it matters
|
| 16 |
+
|
| 17 |
+
**The core problem:** Large language model agents are increasingly given access to tools — search engines, calculators, code interpreters, databases. In real deployments, every tool call costs something: API fees, latency, rate limits, or compute. Current agent frameworks treat tools as free: call whatever you want, as many times as you want. This is unrealistic and economically wasteful.
|
| 18 |
+
|
| 19 |
+
**The research gap:** Most RL environments for tool-using agents either (a) focus on a single tool (e.g. search-only retrieval agents), or (b) ignore cost entirely and measure only answer quality. There is no standard RL training ground where the agent must *choose between tools with different price/quality tradeoffs* under a shared budget constraint.
|
| 20 |
+
|
| 21 |
+
**What we built:** ToolOrchestratorEnv — an OpenEnv-compatible RL environment that puts cost-aware tool selection at the center of the learning objective. The agent picks from six tools per step (web search, Wikipedia, calculator, Python executor, LLM reasoning, or commit) across four question domains (HotpotQA, MATH, GPQA, HumanEval), with a shared budget that depletes as tools are called.
|
| 22 |
+
|
| 23 |
+
**Why this is novel:**
|
| 24 |
+
- It extends the "search economics" framing from a single tool to a heterogeneous tool portfolio
|
| 25 |
+
- It tests transfer: can an agent learn that calculators are cheap and LLMs expensive, and route accordingly?
|
| 26 |
+
- The multi-domain setup forces the agent to learn *domain → tool* mappings (search for factual QA, calculator for math) rather than one-size-fits-all policies
|
| 27 |
+
- The Weitzman-style reward (efficiency bonus only on correct + frugal commits) creates a richer credit assignment problem than binary success/failure
|
| 28 |
+
|
| 29 |
+
---
|
| 30 |
+
|
| 31 |
+
### Background sections to write (with sources to find and cite)
|
| 32 |
+
|
| 33 |
+
**1. The tool-use agent landscape**
|
| 34 |
+
|
| 35 |
+
Explain why tool use is now central to LLM agents. Cite and discuss:
|
| 36 |
+
- The ReAct paper (Yao et al., 2022) — introduced interleaving reasoning and tool calls
|
| 37 |
+
- Toolformer (Schick et al., 2023) — self-supervised tool learning
|
| 38 |
+
- ToolBench / API-Bank — benchmarks for tool-using LLMs
|
| 39 |
+
- Find at least one recent paper (2024 or 2025) showing that tool-calling agents outperform tool-free baselines on knowledge-intensive tasks. Look at arXiv, ACL Anthology, or the Hugging Face papers page.
|
| 40 |
+
|
| 41 |
+
**2. Search economics and the budget constraint**
|
| 42 |
+
|
| 43 |
+
Explain the economic analogy: information has a cost, and rational agents should not search more than their expected marginal gain from search. Cite:
|
| 44 |
+
- Weitzman (1979) "Optimal Search for the Best Alternative" — the foundational search economics paper
|
| 45 |
+
- SearchEconomicsEnv by Yash Sharma / USC (https://github.com/sharma-yash01/SearchEconomicsEnv, https://huggingface.co/spaces/yashu2000/search-economics-env) — the direct predecessor that built this RL environment for search-budget-constrained HotpotQA
|
| 46 |
+
- Look for any recent work on "budgeted retrieval" or "adaptive retrieval" in RAG systems (2024-2025) that shows that unconstrained retrieval hurts performance or cost-effectiveness. Papers like FLARE, IterRetGen, or similar might be relevant.
|
| 47 |
+
|
| 48 |
+
**3. The multi-domain challenge**
|
| 49 |
+
|
| 50 |
+
Explain why testing across HotpotQA, MATH, GPQA, and HumanEval matters — these domains need fundamentally different tools (search for factual, calculator for symbolic, code for algorithmic, LLM for graduate-level). Find and cite:
|
| 51 |
+
- The MATH benchmark paper (Hendrycks et al., 2021)
|
| 52 |
+
- HotpotQA paper (Yang et al., 2018)
|
| 53 |
+
- GPQA paper (Rein et al., 2023)
|
| 54 |
+
- HumanEval paper (Chen et al., 2021)
|
| 55 |
+
- Any paper showing that tool specialisation helps across domains (e.g., PAL, PoT, or similar)
|
| 56 |
+
|
| 57 |
+
**4. Reinforcement learning for tool selection**
|
| 58 |
+
|
| 59 |
+
Explain why RL (not just prompting or supervised learning) is the right frame for this problem: the agent must explore, face delayed rewards (only know if an answer was right after commit), and learn multi-step strategies. Cite:
|
| 60 |
+
- Any recent paper using RL for LLM agent training (e.g., RLHF extensions, agent-specific RL work, or OpenEnv/AgentBench)
|
| 61 |
+
- The OpenEnv competition framework (Berkeley RDI, AgentX) — explain what OpenEnv is and why standardised RL environments matter for reproducibility
|
| 62 |
+
- Look for "process reward models" or "step-level reward" papers in the agent RL space
|
| 63 |
+
|
| 64 |
+
---
|
| 65 |
+
|
| 66 |
+
### Key sections for the post
|
| 67 |
+
|
| 68 |
+
1. **The problem with free tools** — hook paragraph. Real API calls cost money. Agents don't know this. Set up the gap.
|
| 69 |
+
|
| 70 |
+
2. **Search economics, briefly** — one paragraph on Weitzman, one on SearchEconomicsEnv. The framing: information retrieval as a market with prices.
|
| 71 |
+
|
| 72 |
+
3. **ToolOrchestratorEnv: the environment** — describe the setup clearly:
|
| 73 |
+
- 6 tools, 4 datasets, shared budget
|
| 74 |
+
- The action-observation loop (what the agent sees, what it decides)
|
| 75 |
+
- The reward formula (explain it intuitively: you pay for every call, you earn back on correct commits, and get a bonus for answering correctly without blowing your budget)
|
| 76 |
+
- The Ceramic AI integration for live web retrieval
|
| 77 |
+
|
| 78 |
+
4. **Why this is hard** — explain the credit assignment problem (you don't know a tool call was wasted until you commit), the domain-routing challenge, and the exploration-exploitation tradeoff under budget pressure.
|
| 79 |
+
|
| 80 |
+
5. **Baselines and what they tell us** — describe the three baselines (random, cheapest-first, domain-oracle) and what their expected performance reveals about the structure of the problem.
|
| 81 |
+
|
| 82 |
+
6. **What we're building toward** — the research agenda: train a PPO or DQN agent on this environment, show it beats baselines, and study what routing policies it learns. Can it learn that LLM reasoning is worth 2x the cost for GPQA but wasteful for simple arithmetic?
|
| 83 |
+
|
| 84 |
+
7. **Conclusion** — the broader point: as AI systems become more agentic, cost-aware tool selection will be as important as answer quality. We need RL environments that take this seriously. This is one.
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
### Tone and style guidelines
|
| 89 |
+
|
| 90 |
+
- **Cite real papers** — do not make up citations. For any claim about related work, search arXiv, Semantic Scholar, or ACL Anthology and use the actual paper. Format citations inline as (Author et al., Year) with a references section at the end.
|
| 91 |
+
- **Be specific** — don't say "researchers have shown" without naming the paper.
|
| 92 |
+
- **Write for skeptics** — assume your reader will ask "why does this matter" and "what's actually new." Answer those questions directly in the text.
|
| 93 |
+
- **Avoid marketing language** — no "revolutionary," "groundbreaking," or "state-of-the-art." Just describe what was built and why it's useful.
|
| 94 |
+
- **Include the reward formula** — write it out mathematically and then explain it in plain English. Researchers appreciate seeing the actual math.
|
| 95 |
+
- **Link to the HF Space** — mention that the environment is live at https://huggingface.co/spaces/yashu2000/search-economics-env (SearchEconomicsEnv, the predecessor) and that ToolOrchestratorEnv will be deployed alongside it.
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
### What NOT to do
|
| 100 |
+
|
| 101 |
+
- Do not fabricate benchmark numbers — we don't have trained agent results yet, only baseline results. Say so honestly.
|
| 102 |
+
- Do not claim this is the first RL environment for tool use — be accurate about prior work.
|
| 103 |
+
- Do not skip the related work — proving the gap is real requires engaging with existing papers.
|
| 104 |
+
- Do not make the reward formula paragraph too short — this is a key technical contribution; spend time on it.
|
| 105 |
+
|
| 106 |
+
---
|
| 107 |
+
|
| 108 |
+
### Final checklist before finishing the post
|
| 109 |
+
|
| 110 |
+
- [ ] Every citation is real and can be found on arXiv or a peer-reviewed venue
|
| 111 |
+
- [ ] The reward formula is written out and explained in plain English
|
| 112 |
+
- [ ] The post explains what OpenEnv is and why deploying on HF Spaces matters
|
| 113 |
+
- [ ] The post mentions Ceramic AI and explains why live web retrieval matters (vs. static knowledge)
|
| 114 |
+
- [ ] The baseline section sets up what "winning" looks like for a trained RL agent
|
| 115 |
+
- [ ] A references section is included at the end with full citations
|
Dockerfile
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HuggingFace Spaces — ToolOrchestratorEnv
|
| 2 |
+
# Builds a FastAPI server that exposes the OpenEnv endpoints.
|
| 3 |
+
#
|
| 4 |
+
# Required secret (set in Space Settings → Variables and secrets):
|
| 5 |
+
# CERAMIC_API_KEY=cer_sk_live_...
|
| 6 |
+
#
|
| 7 |
+
# Optional:
|
| 8 |
+
# TOGETHER_API_KEY=... (enables the llm_reason tool)
|
| 9 |
+
# HF_TOKEN=... (enables gated datasets like GPQA)
|
| 10 |
+
|
| 11 |
+
FROM python:3.11-slim
|
| 12 |
+
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
|
| 15 |
+
# Install dependencies first so Docker layer-caches them
|
| 16 |
+
COPY requirements.txt .
|
| 17 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 18 |
+
|
| 19 |
+
# Copy source
|
| 20 |
+
COPY . .
|
| 21 |
+
|
| 22 |
+
# HuggingFace Spaces runs as a non-root user
|
| 23 |
+
RUN useradd -m -u 1000 appuser && chown -R appuser /app
|
| 24 |
+
USER appuser
|
| 25 |
+
|
| 26 |
+
EXPOSE 8000
|
| 27 |
+
|
| 28 |
+
# The Space README sets base_path: /web so the demo UI loads on open
|
| 29 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
|
README.md
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Tool Orchestrator Environment
|
| 3 |
+
emoji: 🔧
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
- reinforcement-learning
|
| 13 |
+
- tool-use
|
| 14 |
+
- cost-aware
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# ToolOrchestratorEnv
|
| 18 |
+
|
| 19 |
+
**An OpenEnv-compatible reinforcement learning environment for multi-tool, cost-aware question answering.**
|
| 20 |
+
|
| 21 |
+
Built on top of [SearchEconomicsEnv](https://huggingface.co/spaces/yashu2000/search-economics-env) (Yash Sharma, USC / Ceramic AI), this environment generalises the single-tool (search-only) formulation to a full **tool-selection problem**: the agent must choose *which* of six tools to call at each step, managing a shared cost budget across a multi-domain question set (HotpotQA, MATH, GPQA, HumanEval).
|
| 22 |
+
|
| 23 |
+
The core research question: **can an RL agent learn a cost-aware tool routing policy that outperforms simple heuristics like "always search" or "always use the cheapest tool"?**
|
| 24 |
+
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
## What the agent learns
|
| 28 |
+
|
| 29 |
+
Each episode the agent receives 10 questions sampled across four domains. At every step it sees:
|
| 30 |
+
|
| 31 |
+
- The current **question** and its **domain** tag
|
| 32 |
+
- Its **remaining budget** (shared across all questions)
|
| 33 |
+
- The **context window** — concatenated outputs from prior tool calls on this question
|
| 34 |
+
|
| 35 |
+
It picks one action from six tools:
|
| 36 |
+
|
| 37 |
+
| Tool | `tool_id` | Cost | Best for |
|
| 38 |
+
|---|---|---|---|
|
| 39 |
+
| Ceramic web search | `ceramic_search` | 1.0 | Multi-hop factual QA |
|
| 40 |
+
| Wikipedia lookup | `wiki_lookup` | 0.5 | Entity facts, definitions |
|
| 41 |
+
| Calculator | `calculator` | 0.1 | Arithmetic, symbolic math |
|
| 42 |
+
| Python executor | `code_executor` | 0.3 | HumanEval code tasks |
|
| 43 |
+
| LLM reasoning | `llm_reason` | 2.0 | Graduate-level GPQA problems |
|
| 44 |
+
| Commit answer | `commit` | 0.0 | Submit and move to next question |
|
| 45 |
+
|
| 46 |
+
**The RL objective:** maximise accuracy across all questions while staying within the total budget — learning *which tool to call*, in *which order*, and *when to stop and commit*.
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## Reward formula
|
| 51 |
+
|
| 52 |
+
```
|
| 53 |
+
On tool call: R = -tool_cost
|
| 54 |
+
|
| 55 |
+
On commit: R = base + η · γ · budget_remaining_ratio
|
| 56 |
+
|
| 57 |
+
base = incorrect_reward + quality · (correct_reward − incorrect_reward)
|
| 58 |
+
quality = max(ExactMatch, TokenF1)
|
| 59 |
+
η = 1 if quality ≥ efficiency_bonus_threshold, else 0
|
| 60 |
+
γ = efficiency_bonus_weight
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
The efficiency bonus is only awarded when the agent answers correctly **and** still has budget remaining — directly incentivising both accuracy and frugality.
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## Quickstart (local)
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
# 1. Clone and install
|
| 71 |
+
git clone <this-repo>
|
| 72 |
+
cd claude_toolOrchestrator
|
| 73 |
+
pip install -r requirements.txt
|
| 74 |
+
|
| 75 |
+
# 2. Configure keys (copy the example and fill in values)
|
| 76 |
+
cp .env.example .env
|
| 77 |
+
# Set CERAMIC_API_KEY — sign up free at https://ceramic.ai
|
| 78 |
+
|
| 79 |
+
# 3. Start the server
|
| 80 |
+
uvicorn app:app --port 8000
|
| 81 |
+
|
| 82 |
+
# 4. Try the interactive demo UI
|
| 83 |
+
open http://localhost:8000/web
|
| 84 |
+
# or browse the full OpenAPI spec at
|
| 85 |
+
open http://localhost:8000/docs
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
---
|
| 89 |
+
|
| 90 |
+
## HTTP API
|
| 91 |
+
|
| 92 |
+
### `POST /reset`
|
| 93 |
+
|
| 94 |
+
Start a new episode. Returns `session_id`, initial `observation`, and `state`.
|
| 95 |
+
|
| 96 |
+
```json
|
| 97 |
+
{ "seed": 42, "config_overrides": { "total_budget": 30.0, "num_questions": 5 } }
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
### `POST /step?session_id=<id>`
|
| 101 |
+
|
| 102 |
+
Execute one tool call. Pass `session_id` (from `/reset`) as a query param to support parallel agents.
|
| 103 |
+
|
| 104 |
+
```json
|
| 105 |
+
{ "tool_id": "ceramic_search", "query": "When was the Eiffel Tower built?" }
|
| 106 |
+
{ "tool_id": "calculator", "expression": "sqrt(144) + 3" }
|
| 107 |
+
{ "tool_id": "code_executor", "code_snippet": "print(2 ** 10)" }
|
| 108 |
+
{ "tool_id": "commit", "answer": "1889" }
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
### `GET /health`
|
| 112 |
+
|
| 113 |
+
Returns `{"status": "ok"}`.
|
| 114 |
+
|
| 115 |
+
---
|
| 116 |
+
|
| 117 |
+
## Project layout
|
| 118 |
+
|
| 119 |
+
```
|
| 120 |
+
claude_toolOrchestrator/
|
| 121 |
+
│
|
| 122 |
+
├── app.py # FastAPI server — multi-session, OpenAPI, demo UI
|
| 123 |
+
├── openenv.yaml # OpenEnv deployment spec
|
| 124 |
+
├── requirements.txt # Python dependencies
|
| 125 |
+
├── .env.example # Key template (copy → .env, never commit .env)
|
| 126 |
+
│
|
| 127 |
+
├── env/ # ── Core RL environment ──────────────────────────
|
| 128 |
+
│ ├── environment.py # ToolOrchestratorEnvironment: reset() + step()
|
| 129 |
+
│ ├── models.py # Pydantic types: Action, Observation, State, ToolResult
|
| 130 |
+
│ ├── config.py # EnvConfig dataclass: budget, costs, reward weights
|
| 131 |
+
│ ├── answer_grading.py # grade() → (exact_match, f1, quality)
|
| 132 |
+
│ └── reward.py # step_reward() + commit_reward()
|
| 133 |
+
│
|
| 134 |
+
├── ceramic/ # ── Retrieval backend ────────────────────────────
|
| 135 |
+
│ └── client.py # CeramicClient (live) + FallbackCeramicClient (offline)
|
| 136 |
+
│
|
| 137 |
+
├── data/ # ── Dataset loading ──────────────────────────────
|
| 138 |
+
│ └── loader.py # load_all() → flat list from 4 HF datasets
|
| 139 |
+
│
|
| 140 |
+
├── tools/ # ── Six tool implementations ─────────────────────
|
| 141 |
+
│ ├── ceramic_search.py # Web search (Ceramic AI API)
|
| 142 |
+
│ ├── wiki_lookup.py # Wikipedia REST API, first paragraph
|
| 143 |
+
│ ├── calculator.py # Safe AST-based math evaluator (no exec)
|
| 144 |
+
│ ├── code_executor.py # Sandboxed Python exec (blocks os/sys/subprocess)
|
| 145 |
+
│ ├── llm_reason.py # Together AI chain-of-thought (graceful fallback)
|
| 146 |
+
│ └── commit.py # Answer pass-through; grading runs in environment
|
| 147 |
+
│
|
| 148 |
+
└── baselines/ # ── Reference policies ───────────────────────────
|
| 149 |
+
├── random_tool.py # Uniform random tool selection
|
| 150 |
+
├── cheapest_first.py # Always picks cheapest non-commit tool first
|
| 151 |
+
└── oracle.py # Domain-aware heuristic (search for QA, calc for math)
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
---
|
| 155 |
+
|
| 156 |
+
## Environment variables
|
| 157 |
+
|
| 158 |
+
| Variable | Required | Description |
|
| 159 |
+
|---|---|---|
|
| 160 |
+
| `CERAMIC_API_KEY` | Yes (for live search) | Ceramic AI key — `POST /search` endpoint |
|
| 161 |
+
| `SEE_CERAMIC_API_KEY` | Alternative | HF Spaces alias used by SearchEconomicsEnv |
|
| 162 |
+
| `TOGETHER_API_KEY` | Optional | Enables the `llm_reason` tool via Together AI |
|
| 163 |
+
| `HF_TOKEN` | Optional | Required only to load gated datasets (GPQA) |
|
| 164 |
+
|
| 165 |
+
If no Ceramic key is set, `ceramic_search` falls back to deterministic offline results; all other tools work without any key.
|
| 166 |
+
|
| 167 |
+
---
|
| 168 |
+
|
| 169 |
+
## Running baselines
|
| 170 |
+
|
| 171 |
+
```bash
|
| 172 |
+
# From inside claude_toolOrchestrator/
|
| 173 |
+
python -m baselines.random_tool
|
| 174 |
+
python -m baselines.cheapest_first
|
| 175 |
+
python -m baselines.oracle
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
---
|
| 179 |
+
|
| 180 |
+
## Relation to SearchEconomicsEnv
|
| 181 |
+
|
| 182 |
+
| | [SearchEconomicsEnv](https://github.com/sharma-yash01/SearchEconomicsEnv) | ToolOrchestratorEnv |
|
| 183 |
+
|---|---|---|
|
| 184 |
+
| Tools available | 1 (search only) | 6 (search, wiki, calc, code, LLM, commit) |
|
| 185 |
+
| Datasets | HotpotQA | HotpotQA + MATH + GPQA + HumanEval |
|
| 186 |
+
| Budget unit | # of search calls | cost units per tool (tool-specific) |
|
| 187 |
+
| Reward shape | Weitzman search penalty | Same formula, extended to tool costs |
|
| 188 |
+
| Core RL challenge | *How many* searches to do | *Which* tool to call, in which order |
|
| 189 |
+
| Retrieval backend | Ceramic AI | Ceramic AI (shared) |
|
| 190 |
+
|
| 191 |
+
---
|
| 192 |
+
|
| 193 |
+
## Docker (HuggingFace Spaces)
|
| 194 |
+
|
| 195 |
+
```bash
|
| 196 |
+
docker build -t tool-orchestrator-env:latest .
|
| 197 |
+
docker run -p 8000:8000 -e CERAMIC_API_KEY=cer_sk_live_... tool-orchestrator-env:latest
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
---
|
| 201 |
+
|
| 202 |
+
## Datasets
|
| 203 |
+
|
| 204 |
+
- **HotpotQA** — Yang et al., 2018. Multi-hop reasoning over Wikipedia.
|
| 205 |
+
- **MATH** — Hendrycks et al., 2021. Competition math levels 3–5.
|
| 206 |
+
- **GPQA** — Rein et al., 2023. Graduate-level science QA.
|
| 207 |
+
- **HumanEval** — Chen et al., 2021. Python programming tasks.
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
## About
|
| 212 |
+
|
| 213 |
+
ToolOrchestratorEnv extends SearchEconomicsEnv to a multi-tool setting, framing cost-aware tool selection as the core RL objective. Built for the OpenEnv competition track at AgentX (Berkeley RDI). Ceramic AI search API powers live web retrieval.
|
app.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI server for ToolOrchestratorEnv.
|
| 2 |
+
|
| 3 |
+
Exposes the OpenEnv standard endpoints:
|
| 4 |
+
POST /reset → OrchestratorObservation + OrchestratorState
|
| 5 |
+
POST /step → OrchestratorObservation + reward + done + state
|
| 6 |
+
GET /health → {"status": "ok"}
|
| 7 |
+
GET /web → simple demo UI
|
| 8 |
+
GET /docs → OpenAPI (automatic)
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import uuid
|
| 14 |
+
from contextlib import asynccontextmanager
|
| 15 |
+
from typing import Any, Dict, Optional
|
| 16 |
+
|
| 17 |
+
from fastapi import FastAPI, HTTPException
|
| 18 |
+
from fastapi.responses import HTMLResponse
|
| 19 |
+
from pydantic import BaseModel
|
| 20 |
+
|
| 21 |
+
from data.loader import load_all
|
| 22 |
+
from env.config import EnvConfig
|
| 23 |
+
from env.environment import ToolOrchestratorEnvironment
|
| 24 |
+
from env.models import OrchestratorAction
|
| 25 |
+
from tools import build_tool_registry
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Request / response wrappers
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
class ResetRequest(BaseModel):
|
| 33 |
+
seed: Optional[int] = None
|
| 34 |
+
config_overrides: Optional[Dict[str, Any]] = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class StepRequest(BaseModel):
|
| 38 |
+
tool_id: str
|
| 39 |
+
query: Optional[str] = None
|
| 40 |
+
expression: Optional[str] = None
|
| 41 |
+
code_snippet: Optional[str] = None
|
| 42 |
+
answer: Optional[str] = None
|
| 43 |
+
metadata: Optional[Dict[str, Any]] = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# App factory
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
def create_app() -> FastAPI:
|
| 51 |
+
config = EnvConfig()
|
| 52 |
+
tools = build_tool_registry(config)
|
| 53 |
+
dataset = load_all(split=config.data_split, max_per_domain=200)
|
| 54 |
+
|
| 55 |
+
# Multi-session state: session_id → ToolOrchestratorEnvironment
|
| 56 |
+
sessions: Dict[str, ToolOrchestratorEnvironment] = {}
|
| 57 |
+
|
| 58 |
+
# Default shared environment for single-session usage (no session_id)
|
| 59 |
+
default_env = ToolOrchestratorEnvironment(config=config, tools=tools, dataset=dataset)
|
| 60 |
+
|
| 61 |
+
@asynccontextmanager
|
| 62 |
+
async def lifespan(app: FastAPI):
|
| 63 |
+
yield
|
| 64 |
+
|
| 65 |
+
app = FastAPI(
|
| 66 |
+
title="ToolOrchestratorEnv",
|
| 67 |
+
description="Multi-tool cost-aware RL environment (OpenEnv / AgentX)",
|
| 68 |
+
version="0.1.0",
|
| 69 |
+
lifespan=lifespan,
|
| 70 |
+
root_path=os.environ.get("ROOT_PATH", ""),
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
@app.get("/health")
|
| 74 |
+
def health():
|
| 75 |
+
return {"status": "ok"}
|
| 76 |
+
|
| 77 |
+
@app.post("/reset")
|
| 78 |
+
def reset(req: ResetRequest):
|
| 79 |
+
cfg = EnvConfig()
|
| 80 |
+
if req.config_overrides:
|
| 81 |
+
for k, v in req.config_overrides.items():
|
| 82 |
+
if hasattr(cfg, k):
|
| 83 |
+
setattr(cfg, k, v)
|
| 84 |
+
env = ToolOrchestratorEnvironment(config=cfg, tools=tools, dataset=dataset)
|
| 85 |
+
obs, state = env.reset(seed=req.seed)
|
| 86 |
+
|
| 87 |
+
session_id = str(uuid.uuid4())
|
| 88 |
+
sessions[session_id] = env
|
| 89 |
+
|
| 90 |
+
return {
|
| 91 |
+
"session_id": session_id,
|
| 92 |
+
"observation": obs.model_dump(),
|
| 93 |
+
"state": state.model_dump(),
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
@app.post("/step")
|
| 97 |
+
def step(req: StepRequest, session_id: Optional[str] = None):
|
| 98 |
+
env = sessions.get(session_id or "", default_env)
|
| 99 |
+
action = OrchestratorAction(
|
| 100 |
+
tool_id=req.tool_id,
|
| 101 |
+
query=req.query or "",
|
| 102 |
+
expression=req.expression or "",
|
| 103 |
+
code_snippet=req.code_snippet or "",
|
| 104 |
+
answer=req.answer or "",
|
| 105 |
+
metadata=req.metadata,
|
| 106 |
+
)
|
| 107 |
+
try:
|
| 108 |
+
obs, reward, done, state = env.step(action)
|
| 109 |
+
except RuntimeError as exc:
|
| 110 |
+
raise HTTPException(status_code=400, detail=str(exc))
|
| 111 |
+
except ValueError as exc:
|
| 112 |
+
raise HTTPException(status_code=422, detail=str(exc))
|
| 113 |
+
|
| 114 |
+
# Clean up finished sessions
|
| 115 |
+
if done and session_id and session_id in sessions:
|
| 116 |
+
del sessions[session_id]
|
| 117 |
+
|
| 118 |
+
return {
|
| 119 |
+
"observation": obs.model_dump(),
|
| 120 |
+
"reward": reward,
|
| 121 |
+
"done": done,
|
| 122 |
+
"state": state.model_dump(),
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
@app.get("/web", response_class=HTMLResponse)
|
| 126 |
+
def web_ui():
|
| 127 |
+
return _DEMO_HTML
|
| 128 |
+
|
| 129 |
+
return app
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
app = create_app()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# ---------------------------------------------------------------------------
|
| 136 |
+
# Demo UI
|
| 137 |
+
# ---------------------------------------------------------------------------
|
| 138 |
+
|
| 139 |
+
_DEMO_HTML = """<!DOCTYPE html>
|
| 140 |
+
<html lang="en">
|
| 141 |
+
<head>
|
| 142 |
+
<meta charset="UTF-8">
|
| 143 |
+
<title>ToolOrchestratorEnv</title>
|
| 144 |
+
<style>
|
| 145 |
+
body { font-family: monospace; max-width: 860px; margin: 40px auto; padding: 0 20px; }
|
| 146 |
+
h1 { color: #333; }
|
| 147 |
+
pre { background: #f4f4f4; padding: 12px; border-radius: 6px; overflow-x: auto; }
|
| 148 |
+
button { padding: 8px 16px; margin: 4px; cursor: pointer; }
|
| 149 |
+
input, select, textarea { width: 100%; padding: 6px; margin: 4px 0; box-sizing: border-box; }
|
| 150 |
+
label { font-weight: bold; }
|
| 151 |
+
.tool-btn { background: #e8f0fe; border: 1px solid #4a90e2; border-radius: 4px; }
|
| 152 |
+
.tool-btn:hover { background: #cfe1ff; }
|
| 153 |
+
#log { max-height: 480px; overflow-y: auto; }
|
| 154 |
+
</style>
|
| 155 |
+
</head>
|
| 156 |
+
<body>
|
| 157 |
+
<h1>ToolOrchestratorEnv</h1>
|
| 158 |
+
<p>Multi-tool cost-aware RL environment — AgentX / OpenEnv</p>
|
| 159 |
+
|
| 160 |
+
<button onclick="doReset()">Reset Episode</button>
|
| 161 |
+
<hr>
|
| 162 |
+
<label>Tool:</label>
|
| 163 |
+
<select id="tool">
|
| 164 |
+
<option value="ceramic_search">ceramic_search (cost 1.0) — Web retrieval</option>
|
| 165 |
+
<option value="wiki_lookup">wiki_lookup (cost 0.5) — Wikipedia</option>
|
| 166 |
+
<option value="calculator">calculator (cost 0.1) — Arithmetic / math</option>
|
| 167 |
+
<option value="code_executor">code_executor (cost 0.3) — Python execution</option>
|
| 168 |
+
<option value="llm_reason">llm_reason (cost 2.0) — LLM chain-of-thought</option>
|
| 169 |
+
<option value="commit">commit (cost 0.0) — Submit answer</option>
|
| 170 |
+
</select>
|
| 171 |
+
<label>Query / Expression / Code / Answer:</label>
|
| 172 |
+
<textarea id="query" rows="3" placeholder="Enter query or answer..."></textarea>
|
| 173 |
+
<button class="tool-btn" onclick="doStep()">Step</button>
|
| 174 |
+
<hr>
|
| 175 |
+
<pre id="log">Click "Reset Episode" to start.</pre>
|
| 176 |
+
|
| 177 |
+
<script>
|
| 178 |
+
const log = document.getElementById('log');
|
| 179 |
+
let sessionId = null;
|
| 180 |
+
|
| 181 |
+
function append(text) { log.textContent += text + '\\n---\\n'; log.scrollTop = log.scrollHeight; }
|
| 182 |
+
|
| 183 |
+
async function doReset() {
|
| 184 |
+
log.textContent = '';
|
| 185 |
+
const res = await fetch('/reset', { method: 'POST', headers: {'Content-Type':'application/json'}, body: JSON.stringify({seed: 42}) });
|
| 186 |
+
const data = await res.json();
|
| 187 |
+
sessionId = data.session_id || null;
|
| 188 |
+
append('RESET session=' + sessionId + '\\n' + JSON.stringify(data, null, 2));
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
async function doStep() {
|
| 192 |
+
const tool_id = document.getElementById('tool').value;
|
| 193 |
+
const input = document.getElementById('query').value;
|
| 194 |
+
const body = { tool_id };
|
| 195 |
+
if (tool_id === 'commit') body.answer = input;
|
| 196 |
+
else if (tool_id === 'calculator') body.expression = input;
|
| 197 |
+
else if (tool_id === 'code_executor') body.code_snippet = input;
|
| 198 |
+
else body.query = input;
|
| 199 |
+
|
| 200 |
+
const url = sessionId ? '/step?session_id=' + encodeURIComponent(sessionId) : '/step';
|
| 201 |
+
const res = await fetch(url, { method: 'POST', headers: {'Content-Type':'application/json'}, body: JSON.stringify(body) });
|
| 202 |
+
const data = await res.json();
|
| 203 |
+
append('STEP tool_id=' + tool_id + '\\n' + JSON.stringify(data, null, 2));
|
| 204 |
+
}
|
| 205 |
+
</script>
|
| 206 |
+
</body>
|
| 207 |
+
</html>
|
| 208 |
+
"""
|
baselines/__init__.py
ADDED
|
File without changes
|
baselines/cheapest_first.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cheapest-first baseline — always calls the cheapest available tool first."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
| 8 |
+
|
| 9 |
+
from data.loader import load_all
|
| 10 |
+
from env.config import EnvConfig
|
| 11 |
+
from env.environment import ToolOrchestratorEnvironment
|
| 12 |
+
from env.models import OrchestratorAction
|
| 13 |
+
from tools import build_tool_registry
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CheapestFirstBaseline:
|
| 17 |
+
"""Calls tools in ascending cost order, commits after exhausting budget."""
|
| 18 |
+
|
| 19 |
+
def __init__(self, config: EnvConfig):
|
| 20 |
+
# Sort non-commit tools by cost
|
| 21 |
+
self._order = sorted(
|
| 22 |
+
[t for t in config.tool_costs if t != "commit"],
|
| 23 |
+
key=lambda t: config.tool_costs[t],
|
| 24 |
+
)
|
| 25 |
+
self._commit_after = config.max_steps_per_question - 1
|
| 26 |
+
self._steps_on_q = 0
|
| 27 |
+
|
| 28 |
+
def get_action(self, obs) -> OrchestratorAction:
|
| 29 |
+
if self._steps_on_q >= self._commit_after:
|
| 30 |
+
self._steps_on_q = 0
|
| 31 |
+
return OrchestratorAction(tool_id="commit", answer="I don't know")
|
| 32 |
+
tool = self._order[self._steps_on_q % len(self._order)]
|
| 33 |
+
self._steps_on_q += 1
|
| 34 |
+
query = obs.question[:100] if hasattr(obs, "question") else ""
|
| 35 |
+
return OrchestratorAction(tool_id=tool, query=query)
|
| 36 |
+
|
| 37 |
+
def reset(self):
|
| 38 |
+
self._steps_on_q = 0
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def run_episode(seed: int = 0) -> dict:
|
| 42 |
+
config = EnvConfig(num_questions=5, total_budget=30.0, seed=seed)
|
| 43 |
+
tools = build_tool_registry(config)
|
| 44 |
+
dataset = load_all(max_per_domain=20)
|
| 45 |
+
env = ToolOrchestratorEnvironment(config=config, tools=tools, dataset=dataset)
|
| 46 |
+
agent = CheapestFirstBaseline(config)
|
| 47 |
+
|
| 48 |
+
obs, state = env.reset(seed=seed)
|
| 49 |
+
agent.reset()
|
| 50 |
+
total_reward = 0.0
|
| 51 |
+
done = False
|
| 52 |
+
while not done:
|
| 53 |
+
action = agent.get_action(obs)
|
| 54 |
+
if action.tool_id == "commit":
|
| 55 |
+
agent.reset()
|
| 56 |
+
obs, reward, done, state = env.step(action)
|
| 57 |
+
total_reward += reward
|
| 58 |
+
|
| 59 |
+
return {
|
| 60 |
+
"total_reward": total_reward,
|
| 61 |
+
"accuracy": state.current_accuracy,
|
| 62 |
+
"budget_used": state.budget_spent,
|
| 63 |
+
"questions_answered": state.questions_answered,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
result = run_episode(seed=42)
|
| 69 |
+
print("CheapestFirstBaseline:", result)
|
baselines/oracle.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Domain-aware oracle baseline — picks the best tool per domain heuristically."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
| 8 |
+
|
| 9 |
+
from data.loader import load_all
|
| 10 |
+
from env.config import EnvConfig
|
| 11 |
+
from env.environment import ToolOrchestratorEnvironment
|
| 12 |
+
from env.models import OrchestratorAction
|
| 13 |
+
from tools import build_tool_registry
|
| 14 |
+
|
| 15 |
+
# Heuristic: which tool to try at each step index for each domain
|
| 16 |
+
_DOMAIN_STRATEGY = {
|
| 17 |
+
"hotpotqa": ["ceramic_search", "wiki_lookup", "ceramic_search"],
|
| 18 |
+
"math": ["calculator", "llm_reason", "calculator"],
|
| 19 |
+
"gpqa": ["llm_reason", "ceramic_search", "wiki_lookup"],
|
| 20 |
+
"humaneval": ["code_executor", "llm_reason", "code_executor"],
|
| 21 |
+
}
|
| 22 |
+
_DEFAULT_STRATEGY = ["ceramic_search", "wiki_lookup", "llm_reason"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class OracleBaseline:
|
| 26 |
+
"""Uses domain knowledge to pick the best tool per step."""
|
| 27 |
+
|
| 28 |
+
def __init__(self, config: EnvConfig):
|
| 29 |
+
self._commit_after = config.max_steps_per_question - 1
|
| 30 |
+
self._steps_on_q = 0
|
| 31 |
+
|
| 32 |
+
def get_action(self, obs) -> OrchestratorAction:
|
| 33 |
+
if self._steps_on_q >= self._commit_after:
|
| 34 |
+
self._steps_on_q = 0
|
| 35 |
+
return OrchestratorAction(tool_id="commit", answer="I don't know")
|
| 36 |
+
|
| 37 |
+
domain = obs.domain if hasattr(obs, "domain") else "hotpotqa"
|
| 38 |
+
strategy = _DOMAIN_STRATEGY.get(domain, _DEFAULT_STRATEGY)
|
| 39 |
+
tool = strategy[self._steps_on_q % len(strategy)]
|
| 40 |
+
self._steps_on_q += 1
|
| 41 |
+
query = obs.question[:100] if hasattr(obs, "question") else ""
|
| 42 |
+
return OrchestratorAction(tool_id=tool, query=query)
|
| 43 |
+
|
| 44 |
+
def reset(self):
|
| 45 |
+
self._steps_on_q = 0
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def run_episode(seed: int = 0) -> dict:
|
| 49 |
+
config = EnvConfig(num_questions=5, total_budget=30.0, seed=seed)
|
| 50 |
+
tools = build_tool_registry(config)
|
| 51 |
+
dataset = load_all(max_per_domain=20)
|
| 52 |
+
env = ToolOrchestratorEnvironment(config=config, tools=tools, dataset=dataset)
|
| 53 |
+
agent = OracleBaseline(config)
|
| 54 |
+
|
| 55 |
+
obs, state = env.reset(seed=seed)
|
| 56 |
+
agent.reset()
|
| 57 |
+
total_reward = 0.0
|
| 58 |
+
done = False
|
| 59 |
+
while not done:
|
| 60 |
+
action = agent.get_action(obs)
|
| 61 |
+
if action.tool_id == "commit":
|
| 62 |
+
agent.reset()
|
| 63 |
+
obs, reward, done, state = env.step(action)
|
| 64 |
+
total_reward += reward
|
| 65 |
+
|
| 66 |
+
return {
|
| 67 |
+
"total_reward": total_reward,
|
| 68 |
+
"accuracy": state.current_accuracy,
|
| 69 |
+
"budget_used": state.budget_spent,
|
| 70 |
+
"questions_answered": state.questions_answered,
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
result = run_episode(seed=42)
|
| 76 |
+
print("OracleBaseline:", result)
|
baselines/random_tool.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Random tool baseline — picks uniformly from available tools each step."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import random
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
| 9 |
+
|
| 10 |
+
from data.loader import load_all
|
| 11 |
+
from env.config import EnvConfig
|
| 12 |
+
from env.environment import ToolOrchestratorEnvironment
|
| 13 |
+
from env.models import TOOL_IDS, OrchestratorAction
|
| 14 |
+
from tools import build_tool_registry
|
| 15 |
+
|
| 16 |
+
_NON_COMMIT = [t for t in TOOL_IDS if t != "commit"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class RandomToolBaseline:
|
| 20 |
+
"""Selects a random tool each step; commits after max_steps_per_question - 1 steps."""
|
| 21 |
+
|
| 22 |
+
def __init__(self, commit_after: int = 3):
|
| 23 |
+
self.commit_after = commit_after
|
| 24 |
+
self._steps_on_q = 0
|
| 25 |
+
|
| 26 |
+
def get_action(self, obs) -> OrchestratorAction:
|
| 27 |
+
if self._steps_on_q >= self.commit_after:
|
| 28 |
+
self._steps_on_q = 0
|
| 29 |
+
return OrchestratorAction(tool_id="commit", answer="I don't know")
|
| 30 |
+
self._steps_on_q += 1
|
| 31 |
+
tool = random.choice(_NON_COMMIT)
|
| 32 |
+
query = obs.question[:100] if hasattr(obs, "question") else ""
|
| 33 |
+
return OrchestratorAction(tool_id=tool, query=query)
|
| 34 |
+
|
| 35 |
+
def reset(self):
|
| 36 |
+
self._steps_on_q = 0
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def run_episode(seed: int = 0) -> dict:
|
| 40 |
+
config = EnvConfig(num_questions=5, total_budget=30.0, seed=seed)
|
| 41 |
+
tools = build_tool_registry(config)
|
| 42 |
+
dataset = load_all(max_per_domain=20)
|
| 43 |
+
env = ToolOrchestratorEnvironment(config=config, tools=tools, dataset=dataset)
|
| 44 |
+
agent = RandomToolBaseline(commit_after=config.max_steps_per_question - 1)
|
| 45 |
+
|
| 46 |
+
obs, state = env.reset(seed=seed)
|
| 47 |
+
agent.reset()
|
| 48 |
+
total_reward = 0.0
|
| 49 |
+
done = False
|
| 50 |
+
while not done:
|
| 51 |
+
action = agent.get_action(obs)
|
| 52 |
+
if action.tool_id == "commit":
|
| 53 |
+
agent.reset()
|
| 54 |
+
obs, reward, done, state = env.step(action)
|
| 55 |
+
total_reward += reward
|
| 56 |
+
|
| 57 |
+
return {
|
| 58 |
+
"total_reward": total_reward,
|
| 59 |
+
"accuracy": state.current_accuracy,
|
| 60 |
+
"budget_used": state.budget_spent,
|
| 61 |
+
"questions_answered": state.questions_answered,
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
|
| 66 |
+
result = run_episode(seed=42)
|
| 67 |
+
print("RandomToolBaseline:", result)
|
ceramic/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .client import CeramicClient, FallbackCeramicClient, SearchResult, get_ceramic_client
|
| 2 |
+
|
| 3 |
+
__all__ = ["CeramicClient", "FallbackCeramicClient", "SearchResult", "get_ceramic_client"]
|
ceramic/client.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Ceramic AI search client.
|
| 2 |
+
|
| 3 |
+
Matches the interface used by SearchEconomicsEnv so both environments
|
| 4 |
+
share the same retrieval backend.
|
| 5 |
+
|
| 6 |
+
API key priority:
|
| 7 |
+
1. CERAMIC_API_KEY env var
|
| 8 |
+
2. SEE_CERAMIC_API_KEY env var (HF Spaces compatibility with SearchEcon)
|
| 9 |
+
3. Falls back to FallbackCeramicClient (offline / CI, fully deterministic)
|
| 10 |
+
|
| 11 |
+
Ceramic API notes (verified 2025):
|
| 12 |
+
- Endpoint : POST https://api.ceramic.ai/search
|
| 13 |
+
- Body : {"query": "<string>"} (no pagination params supported)
|
| 14 |
+
- Response : {"requestId": "...", "result": {"results": [...], "totalResults": N}}
|
| 15 |
+
- Each result has: title, url, description, score
|
| 16 |
+
- Always returns up to 10 results per call
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import hashlib
|
| 21 |
+
import os
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from typing import List
|
| 24 |
+
|
| 25 |
+
import httpx
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Result model
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class SearchResult:
|
| 34 |
+
title: str
|
| 35 |
+
url: str
|
| 36 |
+
description: str
|
| 37 |
+
score: float = 0.0
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# Live client
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
class CeramicClient:
|
| 45 |
+
"""Thin wrapper around the Ceramic search API."""
|
| 46 |
+
|
| 47 |
+
BASE_URL = "https://api.ceramic.ai"
|
| 48 |
+
|
| 49 |
+
def __init__(self, api_key: str):
|
| 50 |
+
self._key = api_key
|
| 51 |
+
self._client = httpx.Client(
|
| 52 |
+
headers={"Authorization": f"Bearer {api_key}"},
|
| 53 |
+
timeout=10.0,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
|
| 57 |
+
"""Search Ceramic and return up to top_k results (max 10)."""
|
| 58 |
+
if not query.strip():
|
| 59 |
+
return []
|
| 60 |
+
resp = self._client.post(
|
| 61 |
+
f"{self.BASE_URL}/search",
|
| 62 |
+
json={"query": query},
|
| 63 |
+
)
|
| 64 |
+
resp.raise_for_status()
|
| 65 |
+
data = resp.json()
|
| 66 |
+
raw = data.get("result", {}).get("results", [])
|
| 67 |
+
results = []
|
| 68 |
+
for item in raw[:top_k]:
|
| 69 |
+
results.append(SearchResult(
|
| 70 |
+
title=item.get("title", ""),
|
| 71 |
+
url=item.get("url", ""),
|
| 72 |
+
description=item.get("description", ""),
|
| 73 |
+
score=float(item.get("score", 0.0)),
|
| 74 |
+
))
|
| 75 |
+
return results
|
| 76 |
+
|
| 77 |
+
def close(self):
|
| 78 |
+
self._client.close()
|
| 79 |
+
|
| 80 |
+
def __enter__(self):
|
| 81 |
+
return self
|
| 82 |
+
|
| 83 |
+
def __exit__(self, *args):
|
| 84 |
+
self.close()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
# Offline fallback
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
class FallbackCeramicClient:
|
| 92 |
+
"""Deterministic offline client used when no API key is available.
|
| 93 |
+
|
| 94 |
+
Generates reproducible fake results via SHA-256 hashing so tests
|
| 95 |
+
and CI runs are stable without network access.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
|
| 99 |
+
h = int(hashlib.sha256(query.encode()).hexdigest(), 16)
|
| 100 |
+
results = []
|
| 101 |
+
for i in range(min(top_k, 3)):
|
| 102 |
+
seed = (h + i) % 10_000
|
| 103 |
+
results.append(SearchResult(
|
| 104 |
+
title=f"Result {seed}: {query[:40]}",
|
| 105 |
+
url=f"https://fallback.example.com/doc/{seed}",
|
| 106 |
+
description=f"Offline fallback result #{i+1} for query: {query}",
|
| 107 |
+
score=round(0.9 - i * 0.15, 3),
|
| 108 |
+
))
|
| 109 |
+
return results
|
| 110 |
+
|
| 111 |
+
def close(self):
|
| 112 |
+
pass
|
| 113 |
+
|
| 114 |
+
def __enter__(self):
|
| 115 |
+
return self
|
| 116 |
+
|
| 117 |
+
def __exit__(self, *args):
|
| 118 |
+
pass
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ---------------------------------------------------------------------------
|
| 122 |
+
# Factory
|
| 123 |
+
# ---------------------------------------------------------------------------
|
| 124 |
+
|
| 125 |
+
def get_ceramic_client() -> CeramicClient | FallbackCeramicClient:
|
| 126 |
+
"""Return a live CeramicClient if a key is set, otherwise FallbackCeramicClient."""
|
| 127 |
+
key = (
|
| 128 |
+
os.environ.get("CERAMIC_API_KEY")
|
| 129 |
+
or os.environ.get("SEE_CERAMIC_API_KEY")
|
| 130 |
+
)
|
| 131 |
+
if key:
|
| 132 |
+
return CeramicClient(api_key=key)
|
| 133 |
+
return FallbackCeramicClient()
|
client.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility shim — real code lives in ceramic/client.py.
|
| 2 |
+
|
| 3 |
+
Mirrors the SearchEconomicsEnv CeramicClient interface so the two
|
| 4 |
+
environments share the same retrieval backend.
|
| 5 |
+
|
| 6 |
+
Priority for the API key:
|
| 7 |
+
1. CERAMIC_API_KEY env var
|
| 8 |
+
2. SEE_CERAMIC_API_KEY env var (HF Spaces compatibility)
|
| 9 |
+
3. Falls back to FallbackCeramicClient (offline, deterministic)
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import hashlib
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import List, Optional
|
| 18 |
+
|
| 19 |
+
import httpx
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
# Result model
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class SearchResult:
|
| 28 |
+
title: str
|
| 29 |
+
url: str
|
| 30 |
+
description: str
|
| 31 |
+
score: float = 0.0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
# Live client
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
|
| 38 |
+
class CeramicClient:
|
| 39 |
+
"""Thin wrapper around the Ceramic search API."""
|
| 40 |
+
|
| 41 |
+
BASE_URL = "https://api.ceramic.ai/v1"
|
| 42 |
+
|
| 43 |
+
def __init__(self, api_key: str):
|
| 44 |
+
self._key = api_key
|
| 45 |
+
self._client = httpx.Client(
|
| 46 |
+
headers={"Authorization": f"Bearer {api_key}"},
|
| 47 |
+
timeout=10.0,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
|
| 51 |
+
resp = self._client.post(
|
| 52 |
+
f"{self.BASE_URL}/search",
|
| 53 |
+
json={"query": query, "top_k": top_k},
|
| 54 |
+
)
|
| 55 |
+
resp.raise_for_status()
|
| 56 |
+
data = resp.json()
|
| 57 |
+
results = []
|
| 58 |
+
for item in data.get("results", []):
|
| 59 |
+
results.append(SearchResult(
|
| 60 |
+
title=item.get("title", ""),
|
| 61 |
+
url=item.get("url", ""),
|
| 62 |
+
description=item.get("description", ""),
|
| 63 |
+
score=float(item.get("score", 0.0)),
|
| 64 |
+
))
|
| 65 |
+
return results
|
| 66 |
+
|
| 67 |
+
def close(self):
|
| 68 |
+
self._client.close()
|
| 69 |
+
|
| 70 |
+
def __enter__(self):
|
| 71 |
+
return self
|
| 72 |
+
|
| 73 |
+
def __exit__(self, *args):
|
| 74 |
+
self.close()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ---------------------------------------------------------------------------
|
| 78 |
+
# Offline fallback
|
| 79 |
+
# ---------------------------------------------------------------------------
|
| 80 |
+
|
| 81 |
+
class FallbackCeramicClient:
|
| 82 |
+
"""Deterministic offline client — used when no API key is set."""
|
| 83 |
+
|
| 84 |
+
def search(self, query: str, top_k: int = 5) -> List[SearchResult]:
|
| 85 |
+
# Stable hash → reproducible fake results per query
|
| 86 |
+
h = int(hashlib.sha256(query.encode()).hexdigest(), 16)
|
| 87 |
+
results = []
|
| 88 |
+
for i in range(min(top_k, 3)):
|
| 89 |
+
seed = (h + i) % 10_000
|
| 90 |
+
results.append(SearchResult(
|
| 91 |
+
title=f"Result {seed}: {query[:40]}",
|
| 92 |
+
url=f"https://fallback.example.com/doc/{seed}",
|
| 93 |
+
description=f"Offline fallback result #{i+1} for query: {query}",
|
| 94 |
+
score=round(0.9 - i * 0.15, 3),
|
| 95 |
+
))
|
| 96 |
+
return results
|
| 97 |
+
|
| 98 |
+
def close(self):
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
def __enter__(self):
|
| 102 |
+
return self
|
| 103 |
+
|
| 104 |
+
def __exit__(self, *args):
|
| 105 |
+
pass
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# Factory
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
|
| 112 |
+
_DEFAULT_KEY = "cer_sk_live_543fe74e79df_eyJvcmdfaWQiOiJvcmdfMDFLTlpINkU5RVNDTUowUUoyREpINFZWWEYiLCJrZXlfaWQiOiI1NDNmZTc0ZTc5ZGYifQ.k8I4Aljsk29y4Uki37Wxfd7QZHs40XSJVNBNnfksCtM"
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_ceramic_client() -> CeramicClient | FallbackCeramicClient:
|
| 116 |
+
key = (
|
| 117 |
+
os.environ.get("CERAMIC_API_KEY")
|
| 118 |
+
or os.environ.get("SEE_CERAMIC_API_KEY")
|
| 119 |
+
or _DEFAULT_KEY
|
| 120 |
+
)
|
| 121 |
+
if key:
|
| 122 |
+
return CeramicClient(api_key=key)
|
| 123 |
+
return FallbackCeramicClient()
|
data/__init__.py
ADDED
|
File without changes
|
data/loader.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-domain dataset loader for ToolOrchestratorEnv.
|
| 2 |
+
|
| 3 |
+
Returns a flat list of question dicts, each with a 'domain' key.
|
| 4 |
+
Adapted from CostAwareToolEnv/scripts/process_datasets.py.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
import re
|
| 10 |
+
import string
|
| 11 |
+
from typing import Any, Dict, List, Optional
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
# HuggingFace loader helper
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
def _hf_load(repo_id: str, config: Optional[str], split: str):
|
| 19 |
+
import datasets as hf
|
| 20 |
+
kwargs: Dict[str, Any] = {"split": split, "trust_remote_code": True}
|
| 21 |
+
if config:
|
| 22 |
+
kwargs["name"] = config
|
| 23 |
+
return hf.load_dataset(repo_id, **kwargs)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# MATH (levels 3-5)
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
def _extract_boxed(solution: str):
|
| 31 |
+
for cmd in ("boxed", "fbox"):
|
| 32 |
+
marker = f"\\{cmd}" + "{"
|
| 33 |
+
start = solution.rfind(marker)
|
| 34 |
+
if start == -1:
|
| 35 |
+
continue
|
| 36 |
+
idx = start + len(marker) - 1
|
| 37 |
+
depth = 0
|
| 38 |
+
for i in range(idx, len(solution)):
|
| 39 |
+
if solution[i] == "{":
|
| 40 |
+
depth += 1
|
| 41 |
+
elif solution[i] == "}":
|
| 42 |
+
depth -= 1
|
| 43 |
+
if depth == 0:
|
| 44 |
+
return solution[i + 1 - (i - idx):i].strip()
|
| 45 |
+
# fallback: last non-empty line
|
| 46 |
+
lines = [l.strip() for l in solution.splitlines() if l.strip()]
|
| 47 |
+
return lines[-1] if lines else ""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _load_math(split: str, max_rows: int) -> List[Dict]:
|
| 51 |
+
candidates = [
|
| 52 |
+
("DigitalLearningGmbH/MATH-lighteval", "default", "train"),
|
| 53 |
+
("lighteval/MATH-Hard", "default", "train"),
|
| 54 |
+
("hendrycks/competition_math", None, "train"),
|
| 55 |
+
]
|
| 56 |
+
dataset = None
|
| 57 |
+
for repo_id, cfg, spl in candidates:
|
| 58 |
+
try:
|
| 59 |
+
dataset = _hf_load(repo_id, cfg, spl)
|
| 60 |
+
break
|
| 61 |
+
except Exception:
|
| 62 |
+
continue
|
| 63 |
+
if dataset is None:
|
| 64 |
+
return []
|
| 65 |
+
|
| 66 |
+
rows = []
|
| 67 |
+
for ex in dataset:
|
| 68 |
+
level_text = str(ex.get("level", ""))
|
| 69 |
+
m = re.search(r"(\d+)", level_text)
|
| 70 |
+
if not m or int(m.group(1)) not in (3, 4, 5):
|
| 71 |
+
continue
|
| 72 |
+
answer = _extract_boxed(str(ex.get("solution", "")))
|
| 73 |
+
rows.append({
|
| 74 |
+
"question": str(ex.get("problem", "")).strip(),
|
| 75 |
+
"answer": answer,
|
| 76 |
+
"domain": "math",
|
| 77 |
+
"difficulty": m.group(1),
|
| 78 |
+
"subject": str(ex.get("type", "")),
|
| 79 |
+
"source": "math",
|
| 80 |
+
})
|
| 81 |
+
if len(rows) >= max_rows:
|
| 82 |
+
break
|
| 83 |
+
return rows
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
# HotpotQA
|
| 88 |
+
# ---------------------------------------------------------------------------
|
| 89 |
+
|
| 90 |
+
def _load_hotpotqa(split: str, max_rows: int) -> List[Dict]:
|
| 91 |
+
hf_split = "train" if split in ("train", "validation") else split
|
| 92 |
+
dataset = None
|
| 93 |
+
for cfg in ("distractor", "fullwiki"):
|
| 94 |
+
try:
|
| 95 |
+
dataset = _hf_load("hotpotqa/hotpot_qa", cfg, hf_split)
|
| 96 |
+
break
|
| 97 |
+
except Exception:
|
| 98 |
+
continue
|
| 99 |
+
if dataset is None:
|
| 100 |
+
return []
|
| 101 |
+
|
| 102 |
+
subset = dataset.shuffle(seed=42).select(range(min(max_rows, len(dataset))))
|
| 103 |
+
rows = []
|
| 104 |
+
for ex in subset:
|
| 105 |
+
rows.append({
|
| 106 |
+
"question": str(ex.get("question", "")).strip(),
|
| 107 |
+
"answer": str(ex.get("answer", "")).strip(),
|
| 108 |
+
"domain": "hotpotqa",
|
| 109 |
+
"difficulty": str(ex.get("level", "")),
|
| 110 |
+
"type": str(ex.get("type", "")),
|
| 111 |
+
"source": "hotpotqa",
|
| 112 |
+
})
|
| 113 |
+
return rows
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ---------------------------------------------------------------------------
|
| 117 |
+
# GPQA
|
| 118 |
+
# ---------------------------------------------------------------------------
|
| 119 |
+
|
| 120 |
+
def _resolve_gpqa_answer(ex: Dict) -> str:
|
| 121 |
+
val = str(ex.get("Correct Answer", "")).strip()
|
| 122 |
+
if val.upper() in {"A", "B", "C", "D"}:
|
| 123 |
+
mapping = {
|
| 124 |
+
"A": str(ex.get("Answer A", "")),
|
| 125 |
+
"B": str(ex.get("Answer B", "")),
|
| 126 |
+
"C": str(ex.get("Answer C", "")),
|
| 127 |
+
"D": str(ex.get("Answer D", "")),
|
| 128 |
+
}
|
| 129 |
+
return mapping.get(val.upper(), val).strip()
|
| 130 |
+
return val
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _load_gpqa(split: str, max_rows: int) -> List[Dict]:
|
| 134 |
+
dataset = None
|
| 135 |
+
for repo in ("Idavidrein/gpqa", "Wanfq/gpqa"):
|
| 136 |
+
for cfg in ("gpqa_diamond", "gpqa_main"):
|
| 137 |
+
try:
|
| 138 |
+
dataset = _hf_load(repo, cfg, "train")
|
| 139 |
+
break
|
| 140 |
+
except Exception:
|
| 141 |
+
continue
|
| 142 |
+
if dataset is not None:
|
| 143 |
+
break
|
| 144 |
+
if dataset is None:
|
| 145 |
+
return []
|
| 146 |
+
|
| 147 |
+
rows = []
|
| 148 |
+
for ex in dataset:
|
| 149 |
+
answer = _resolve_gpqa_answer(ex)
|
| 150 |
+
rows.append({
|
| 151 |
+
"question": str(ex.get("Question", "")).strip(),
|
| 152 |
+
"answer": answer,
|
| 153 |
+
"domain": "gpqa",
|
| 154 |
+
"difficulty": "graduate",
|
| 155 |
+
"source": "gpqa",
|
| 156 |
+
})
|
| 157 |
+
if len(rows) >= max_rows:
|
| 158 |
+
break
|
| 159 |
+
return rows
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ---------------------------------------------------------------------------
|
| 163 |
+
# HumanEval
|
| 164 |
+
# ---------------------------------------------------------------------------
|
| 165 |
+
|
| 166 |
+
def _load_humaneval(split: str, max_rows: int) -> List[Dict]:
|
| 167 |
+
dataset = None
|
| 168 |
+
for repo in ("openai/openai_humaneval", "openai/human-eval"):
|
| 169 |
+
try:
|
| 170 |
+
dataset = _hf_load(repo, None, "test")
|
| 171 |
+
break
|
| 172 |
+
except Exception:
|
| 173 |
+
continue
|
| 174 |
+
if dataset is None:
|
| 175 |
+
return []
|
| 176 |
+
|
| 177 |
+
rows = []
|
| 178 |
+
for ex in dataset:
|
| 179 |
+
rows.append({
|
| 180 |
+
"question": str(ex.get("prompt", "")).strip(),
|
| 181 |
+
"answer": str(ex.get("canonical_solution", "")).strip(),
|
| 182 |
+
"domain": "humaneval",
|
| 183 |
+
"difficulty": "code",
|
| 184 |
+
"task_id": str(ex.get("task_id", "")),
|
| 185 |
+
"test": str(ex.get("test", "")),
|
| 186 |
+
"entry_point": str(ex.get("entry_point", "")),
|
| 187 |
+
"source": "humaneval",
|
| 188 |
+
})
|
| 189 |
+
if len(rows) >= max_rows:
|
| 190 |
+
break
|
| 191 |
+
return rows
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ---------------------------------------------------------------------------
|
| 195 |
+
# Synthetic fallback (offline / CI)
|
| 196 |
+
# ---------------------------------------------------------------------------
|
| 197 |
+
|
| 198 |
+
_SYNTHETIC_TEMPLATES = [
|
| 199 |
+
("What is {a} + {b}?", "{c}", "math"),
|
| 200 |
+
("Who wrote {work}?", "{author}", "hotpotqa"),
|
| 201 |
+
("Solve for x: {a}x + {b} = {c}", "{x}", "math"),
|
| 202 |
+
("What is the capital of {country}?", "{capital}", "hotpotqa"),
|
| 203 |
+
]
|
| 204 |
+
|
| 205 |
+
_SYNTHETIC_DATA = [
|
| 206 |
+
{"a": 12, "b": 7, "c": 19, "work": "Hamlet", "author": "Shakespeare",
|
| 207 |
+
"country": "France", "capital": "Paris", "x": 3},
|
| 208 |
+
{"a": 25, "b": 13, "c": 38, "work": "1984", "author": "George Orwell",
|
| 209 |
+
"country": "Germany", "capital": "Berlin", "x": 5},
|
| 210 |
+
{"a": 100, "b": 44, "c": 144, "work": "The Odyssey", "author": "Homer",
|
| 211 |
+
"country": "Japan", "capital": "Tokyo", "x": 7},
|
| 212 |
+
]
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _synthetic_questions(n: int) -> List[Dict]:
|
| 216 |
+
rows = []
|
| 217 |
+
for i in range(n):
|
| 218 |
+
tmpl, ans_tmpl, domain = _SYNTHETIC_TEMPLATES[i % len(_SYNTHETIC_TEMPLATES)]
|
| 219 |
+
data = _SYNTHETIC_DATA[i % len(_SYNTHETIC_DATA)]
|
| 220 |
+
try:
|
| 221 |
+
question = tmpl.format(**data)
|
| 222 |
+
answer = ans_tmpl.format(**data)
|
| 223 |
+
except KeyError:
|
| 224 |
+
question = f"Synthetic question {i}"
|
| 225 |
+
answer = f"answer_{i}"
|
| 226 |
+
rows.append({
|
| 227 |
+
"question": question,
|
| 228 |
+
"answer": str(answer),
|
| 229 |
+
"domain": domain,
|
| 230 |
+
"difficulty": "easy",
|
| 231 |
+
"source": "synthetic",
|
| 232 |
+
})
|
| 233 |
+
return rows
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# ---------------------------------------------------------------------------
|
| 237 |
+
# Public API
|
| 238 |
+
# ---------------------------------------------------------------------------
|
| 239 |
+
|
| 240 |
+
_LOADERS = {
|
| 241 |
+
"hotpotqa": _load_hotpotqa,
|
| 242 |
+
"math": _load_math,
|
| 243 |
+
"gpqa": _load_gpqa,
|
| 244 |
+
"humaneval": _load_humaneval,
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def load_all(split: str = "validation", max_per_domain: int = 200) -> List[Dict]:
|
| 249 |
+
"""Load all four domains and return a flat list with 'domain' keys.
|
| 250 |
+
|
| 251 |
+
Falls back to synthetic questions if a domain is unavailable.
|
| 252 |
+
"""
|
| 253 |
+
all_questions: List[Dict] = []
|
| 254 |
+
for domain, loader_fn in _LOADERS.items():
|
| 255 |
+
try:
|
| 256 |
+
rows = loader_fn(split, max_per_domain)
|
| 257 |
+
if rows:
|
| 258 |
+
all_questions.extend(rows)
|
| 259 |
+
print(f"[loader] {domain}: {len(rows)} questions")
|
| 260 |
+
else:
|
| 261 |
+
raise ValueError("empty")
|
| 262 |
+
except Exception as exc:
|
| 263 |
+
print(f"[loader] {domain} unavailable ({exc}), using synthetic fallback")
|
| 264 |
+
synth = _synthetic_questions(max(5, max_per_domain // 10))
|
| 265 |
+
for q in synth:
|
| 266 |
+
q["domain"] = domain
|
| 267 |
+
all_questions.extend(synth)
|
| 268 |
+
|
| 269 |
+
random.shuffle(all_questions)
|
| 270 |
+
return all_questions
|
env/__init__.py
ADDED
|
File without changes
|
env/answer_grading.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Answer grading utilities: exact match + token F1.
|
| 2 |
+
|
| 3 |
+
Ported from SearchEconomicsEnv/env/answer_grading.py and adapted for
|
| 4 |
+
multi-domain use (HotpotQA-style EM/F1 + code/math fallback).
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import re
|
| 10 |
+
import string
|
| 11 |
+
from collections import Counter
|
| 12 |
+
from typing import Tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
# Normalisation
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
|
| 19 |
+
def normalize_answer(text: str) -> list[str]:
|
| 20 |
+
"""Lowercase, strip articles/punctuation, tokenise."""
|
| 21 |
+
text = text.lower().strip()
|
| 22 |
+
# Remove articles
|
| 23 |
+
text = re.sub(r"\b(a|an|the)\b", " ", text)
|
| 24 |
+
# Remove punctuation
|
| 25 |
+
text = text.translate(str.maketrans("", "", string.punctuation))
|
| 26 |
+
return text.split()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# Metrics
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
def exact_match(pred: str, gold: str) -> bool:
|
| 34 |
+
return normalize_answer(pred) == normalize_answer(gold)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def token_f1(pred: str, gold: str) -> float:
|
| 38 |
+
pred_tokens = normalize_answer(pred)
|
| 39 |
+
gold_tokens = normalize_answer(gold)
|
| 40 |
+
if not pred_tokens or not gold_tokens:
|
| 41 |
+
return float(pred_tokens == gold_tokens)
|
| 42 |
+
common = Counter(pred_tokens) & Counter(gold_tokens)
|
| 43 |
+
num_common = sum(common.values())
|
| 44 |
+
if num_common == 0:
|
| 45 |
+
return 0.0
|
| 46 |
+
precision = num_common / len(pred_tokens)
|
| 47 |
+
recall = num_common / len(gold_tokens)
|
| 48 |
+
return 2 * precision * recall / (precision + recall)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
# Answer extraction
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
|
| 55 |
+
def extract_answer(raw: str) -> str:
|
| 56 |
+
"""Pull the answer string out of various agent output formats."""
|
| 57 |
+
# Strip markdown fences
|
| 58 |
+
raw = re.sub(r"```[a-z]*\n?", "", raw).strip()
|
| 59 |
+
|
| 60 |
+
# Try JSON {"answer": ...}
|
| 61 |
+
try:
|
| 62 |
+
parsed = json.loads(raw)
|
| 63 |
+
if isinstance(parsed, dict):
|
| 64 |
+
for key in ("answer", "Answer", "result", "Result"):
|
| 65 |
+
if key in parsed:
|
| 66 |
+
return str(parsed[key]).strip()
|
| 67 |
+
except (json.JSONDecodeError, ValueError):
|
| 68 |
+
pass
|
| 69 |
+
|
| 70 |
+
# Prefix patterns
|
| 71 |
+
for prefix in ("Answer:", "Final answer:", "Result:", "Output:"):
|
| 72 |
+
idx = raw.lower().find(prefix.lower())
|
| 73 |
+
if idx != -1:
|
| 74 |
+
return raw[idx + len(prefix):].strip().split("\n")[0].strip()
|
| 75 |
+
|
| 76 |
+
# Last non-empty line
|
| 77 |
+
lines = [line.strip() for line in raw.splitlines() if line.strip()]
|
| 78 |
+
return lines[-1] if lines else raw.strip()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
# Public entry point
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
|
| 85 |
+
def grade(predicted: str, ground_truth: str) -> Tuple[bool, float, float]:
|
| 86 |
+
"""Return (exact_match, f1, quality) where quality ∈ [0, 1]."""
|
| 87 |
+
extracted = extract_answer(predicted)
|
| 88 |
+
em = exact_match(extracted, ground_truth)
|
| 89 |
+
f1 = token_f1(extracted, ground_truth)
|
| 90 |
+
quality = 1.0 if em else f1
|
| 91 |
+
return em, f1, quality
|
env/config.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration for ToolOrchestratorEnv.
|
| 2 |
+
|
| 3 |
+
All tuneable parameters live here so training scripts, the server, and
|
| 4 |
+
baselines all read from a single source of truth. Override individual
|
| 5 |
+
fields in /reset via config_overrides, or subclass for experiment sweeps.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Dict, Optional
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class EnvConfig:
|
| 15 |
+
# ── Episode structure ────────────────────────────────────────────────────
|
| 16 |
+
total_budget: float = 50.0 # Total cost units for the whole episode
|
| 17 |
+
num_questions: int = 10 # Questions drawn per episode
|
| 18 |
+
max_steps_per_question: int = 8 # Auto-commit after this many tool calls
|
| 19 |
+
data_split: str = "validation" # HuggingFace dataset split to load
|
| 20 |
+
seed: Optional[int] = None # Global RNG seed (None = random)
|
| 21 |
+
shuffle_questions: bool = True # Shuffle sampled questions each episode
|
| 22 |
+
|
| 23 |
+
# ── Domain mix ──────────────────────────────────────────────────────────
|
| 24 |
+
# Fraction of questions drawn from each dataset. Must sum to ~1.0.
|
| 25 |
+
domain_mix: Dict[str, float] = field(default_factory=lambda: {
|
| 26 |
+
"hotpotqa": 0.4, # Multi-hop factual QA
|
| 27 |
+
"math": 0.3, # Competition math (levels 3-5)
|
| 28 |
+
"gpqa": 0.2, # Graduate-level science
|
| 29 |
+
"humaneval": 0.1, # Python programming tasks
|
| 30 |
+
})
|
| 31 |
+
|
| 32 |
+
# ── Tool costs ──────────────────────────────────────────────────────────
|
| 33 |
+
# Budget units consumed per tool call. Commit is always free.
|
| 34 |
+
tool_costs: Dict[str, float] = field(default_factory=lambda: {
|
| 35 |
+
"ceramic_search": 1.0,
|
| 36 |
+
"wiki_lookup": 0.5,
|
| 37 |
+
"calculator": 0.1,
|
| 38 |
+
"code_executor": 0.3,
|
| 39 |
+
"llm_reason": 2.0,
|
| 40 |
+
"commit": 0.0,
|
| 41 |
+
})
|
| 42 |
+
|
| 43 |
+
# ── Reward shaping ───────────────────────────────────────────────────────
|
| 44 |
+
correct_reward: float = 1.0 # Base reward for a correct commit
|
| 45 |
+
incorrect_reward: float = -0.5 # Base reward for a wrong commit
|
| 46 |
+
efficiency_bonus_weight: float = 0.1 # γ: scales the efficiency bonus
|
| 47 |
+
efficiency_bonus_threshold: float = 0.5 # Minimum quality to earn the bonus
|
| 48 |
+
|
| 49 |
+
# ── Grading ─────────────────────────────────────────────────────────────
|
| 50 |
+
# "em_only" → only exact match counts as correct
|
| 51 |
+
# "em_or_f1" → token F1 ≥ f1_count_threshold also counts as correct
|
| 52 |
+
grade_count_correct_mode: str = "em_or_f1"
|
| 53 |
+
f1_count_threshold: float = 0.5
|
env/environment.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core RL environment: ToolOrchestratorEnvironment.
|
| 2 |
+
|
| 3 |
+
Step logic:
|
| 4 |
+
- Agent receives an OrchestratorObservation with the current question,
|
| 5 |
+
budget, context, and available tools.
|
| 6 |
+
- Agent picks a tool_id and optional query / code_snippet / answer.
|
| 7 |
+
- Environment dispatches to the appropriate tool, charges cost, appends
|
| 8 |
+
result to context_window, and returns the next observation + reward.
|
| 9 |
+
- Episode ends when budget is exhausted OR all questions are answered.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import time
|
| 14 |
+
import uuid
|
| 15 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
from .answer_grading import grade
|
| 18 |
+
from .config import EnvConfig
|
| 19 |
+
from .models import (
|
| 20 |
+
OrchestratorAction,
|
| 21 |
+
OrchestratorObservation,
|
| 22 |
+
OrchestratorState,
|
| 23 |
+
ToolResult,
|
| 24 |
+
TOOL_IDS,
|
| 25 |
+
)
|
| 26 |
+
from .reward import commit_reward, step_reward
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ToolOrchestratorEnvironment:
|
| 30 |
+
"""
|
| 31 |
+
OpenEnv-compatible RL environment for multi-tool cost-aware QA.
|
| 32 |
+
|
| 33 |
+
Supports external tool injection so the server can wire in live
|
| 34 |
+
Ceramic, code executor, etc. Tools are callables with signature:
|
| 35 |
+
tool_fn(action: OrchestratorAction) -> ToolResult
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
config: Optional[EnvConfig] = None,
|
| 41 |
+
tools: Optional[Dict[str, Any]] = None,
|
| 42 |
+
dataset: Optional[List[Dict[str, Any]]] = None,
|
| 43 |
+
):
|
| 44 |
+
self.config = config or EnvConfig()
|
| 45 |
+
self.tools = tools or {} # tool_id -> callable
|
| 46 |
+
self.dataset = dataset or [] # List of {question, answer, domain, ...}
|
| 47 |
+
self._state: Optional[OrchestratorState] = None
|
| 48 |
+
self._questions: List[Dict[str, Any]] = []
|
| 49 |
+
self._current_q_idx: int = 0
|
| 50 |
+
self._context_window: List[str] = []
|
| 51 |
+
self._tools_used_this_q: List[str] = []
|
| 52 |
+
self._steps_this_q: int = 0
|
| 53 |
+
self._episode_done: bool = False
|
| 54 |
+
|
| 55 |
+
# -----------------------------------------------------------------------
|
| 56 |
+
# Reset
|
| 57 |
+
# -----------------------------------------------------------------------
|
| 58 |
+
|
| 59 |
+
def reset(self, seed: Optional[int] = None) -> Tuple[OrchestratorObservation, OrchestratorState]:
|
| 60 |
+
import random
|
| 61 |
+
effective_seed = seed if seed is not None else self.config.seed
|
| 62 |
+
rng = random.Random(effective_seed)
|
| 63 |
+
|
| 64 |
+
questions = _sample_questions(self.dataset, self.config, rng)
|
| 65 |
+
self._questions = questions
|
| 66 |
+
self._current_q_idx = 0
|
| 67 |
+
self._episode_done = False
|
| 68 |
+
|
| 69 |
+
self._state = OrchestratorState(
|
| 70 |
+
episode_id=str(uuid.uuid4()),
|
| 71 |
+
total_budget=self.config.total_budget,
|
| 72 |
+
budget_spent=0,
|
| 73 |
+
questions_answered=0,
|
| 74 |
+
total_correct=0,
|
| 75 |
+
current_accuracy=0.0,
|
| 76 |
+
budget_remaining_ratio=1.0,
|
| 77 |
+
current_question_idx=0,
|
| 78 |
+
current_question_steps=0,
|
| 79 |
+
step_count=0,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
self._context_window = []
|
| 83 |
+
self._tools_used_this_q = []
|
| 84 |
+
self._steps_this_q = 0
|
| 85 |
+
|
| 86 |
+
obs = self._make_obs(reward=None, question_done=False, done=False)
|
| 87 |
+
return obs, self._state
|
| 88 |
+
|
| 89 |
+
# -----------------------------------------------------------------------
|
| 90 |
+
# Step
|
| 91 |
+
# -----------------------------------------------------------------------
|
| 92 |
+
|
| 93 |
+
def step(
|
| 94 |
+
self, action: OrchestratorAction
|
| 95 |
+
) -> Tuple[OrchestratorObservation, float, bool, OrchestratorState]:
|
| 96 |
+
if self._episode_done:
|
| 97 |
+
raise RuntimeError("Episode is done. Call reset() first.")
|
| 98 |
+
|
| 99 |
+
tool_id = action.tool_id
|
| 100 |
+
if tool_id not in TOOL_IDS:
|
| 101 |
+
raise ValueError(f"Unknown tool_id: {tool_id!r}. Valid: {TOOL_IDS}")
|
| 102 |
+
|
| 103 |
+
state = self._state
|
| 104 |
+
config = self.config
|
| 105 |
+
|
| 106 |
+
if self._current_q_idx >= len(self._questions):
|
| 107 |
+
self._episode_done = True
|
| 108 |
+
raise RuntimeError("Episode is done. Call reset() first.")
|
| 109 |
+
|
| 110 |
+
q_entry = self._questions[self._current_q_idx]
|
| 111 |
+
gold = q_entry["answer"]
|
| 112 |
+
|
| 113 |
+
# ---- Commit ---------------------------------------------------
|
| 114 |
+
if tool_id == "commit":
|
| 115 |
+
raw_pred = action.answer or ""
|
| 116 |
+
em, f1, quality = grade(raw_pred, gold)
|
| 117 |
+
|
| 118 |
+
r = commit_reward(
|
| 119 |
+
quality=quality,
|
| 120 |
+
budget_remaining_ratio=state.budget_remaining_ratio,
|
| 121 |
+
config=config,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
is_correct = (
|
| 125 |
+
em if config.grade_count_correct_mode == "em_only"
|
| 126 |
+
else (em or f1 >= config.f1_count_threshold)
|
| 127 |
+
)
|
| 128 |
+
state.questions_answered += 1
|
| 129 |
+
state.total_correct += int(is_correct)
|
| 130 |
+
state.current_accuracy = state.total_correct / state.questions_answered
|
| 131 |
+
|
| 132 |
+
self._current_q_idx += 1
|
| 133 |
+
self._context_window = []
|
| 134 |
+
self._tools_used_this_q = []
|
| 135 |
+
self._steps_this_q = 0
|
| 136 |
+
|
| 137 |
+
episode_done = (
|
| 138 |
+
self._current_q_idx >= len(self._questions)
|
| 139 |
+
or state.budget_spent >= state.total_budget
|
| 140 |
+
)
|
| 141 |
+
self._episode_done = episode_done
|
| 142 |
+
state.current_question_idx = self._current_q_idx
|
| 143 |
+
state.current_question_steps = 0
|
| 144 |
+
|
| 145 |
+
obs = self._make_obs(
|
| 146 |
+
reward=r,
|
| 147 |
+
question_done=True,
|
| 148 |
+
done=episode_done,
|
| 149 |
+
last_tool_result=ToolResult(
|
| 150 |
+
tool_id="commit", cost=0,
|
| 151 |
+
output=f"EM={em} F1={f1:.3f} quality={quality:.3f}"
|
| 152 |
+
),
|
| 153 |
+
)
|
| 154 |
+
return obs, r, episode_done, state
|
| 155 |
+
|
| 156 |
+
# ---- Tool call ------------------------------------------------
|
| 157 |
+
cost = config.tool_costs.get(tool_id, 0)
|
| 158 |
+
budget_after = state.budget_spent + cost
|
| 159 |
+
|
| 160 |
+
if budget_after > state.total_budget:
|
| 161 |
+
r = config.incorrect_reward
|
| 162 |
+
self._episode_done = True
|
| 163 |
+
obs = self._make_obs(reward=r, question_done=True, done=True)
|
| 164 |
+
return obs, r, True, state
|
| 165 |
+
|
| 166 |
+
t0 = time.perf_counter()
|
| 167 |
+
tool_fn = self.tools.get(tool_id)
|
| 168 |
+
if tool_fn is None:
|
| 169 |
+
tool_result = ToolResult(
|
| 170 |
+
tool_id=tool_id, cost=cost,
|
| 171 |
+
output="[Tool not available in this environment]",
|
| 172 |
+
latency_s=0.0,
|
| 173 |
+
error="not_available",
|
| 174 |
+
)
|
| 175 |
+
else:
|
| 176 |
+
try:
|
| 177 |
+
tool_result = tool_fn(action)
|
| 178 |
+
tool_result.cost = cost
|
| 179 |
+
except Exception as exc:
|
| 180 |
+
tool_result = ToolResult(
|
| 181 |
+
tool_id=tool_id, cost=cost,
|
| 182 |
+
output=f"[Error: {exc}]",
|
| 183 |
+
latency_s=time.perf_counter() - t0,
|
| 184 |
+
error=str(exc),
|
| 185 |
+
)
|
| 186 |
+
tool_result.latency_s = time.perf_counter() - t0
|
| 187 |
+
|
| 188 |
+
state.budget_spent = budget_after
|
| 189 |
+
state.budget_remaining_ratio = max(
|
| 190 |
+
0.0, (state.total_budget - state.budget_spent) / state.total_budget
|
| 191 |
+
)
|
| 192 |
+
state.step_count += 1
|
| 193 |
+
state.current_question_steps += 1
|
| 194 |
+
self._steps_this_q += 1
|
| 195 |
+
self._tools_used_this_q.append(tool_id)
|
| 196 |
+
self._context_window.append(f"[{tool_id}] {tool_result.output}")
|
| 197 |
+
|
| 198 |
+
r = step_reward(tool_id, config)
|
| 199 |
+
|
| 200 |
+
question_done = self._steps_this_q >= config.max_steps_per_question
|
| 201 |
+
episode_done = (
|
| 202 |
+
state.budget_spent >= state.total_budget
|
| 203 |
+
or (question_done and self._current_q_idx + 1 >= len(self._questions))
|
| 204 |
+
)
|
| 205 |
+
if question_done and not episode_done:
|
| 206 |
+
self._current_q_idx += 1
|
| 207 |
+
state.questions_answered += 1
|
| 208 |
+
self._context_window = []
|
| 209 |
+
self._tools_used_this_q = []
|
| 210 |
+
self._steps_this_q = 0
|
| 211 |
+
state.current_question_idx = self._current_q_idx
|
| 212 |
+
state.current_question_steps = 0
|
| 213 |
+
|
| 214 |
+
self._episode_done = episode_done
|
| 215 |
+
|
| 216 |
+
obs = self._make_obs(
|
| 217 |
+
reward=r,
|
| 218 |
+
question_done=question_done,
|
| 219 |
+
done=episode_done,
|
| 220 |
+
last_tool_result=tool_result,
|
| 221 |
+
)
|
| 222 |
+
return obs, r, episode_done, state
|
| 223 |
+
|
| 224 |
+
# -----------------------------------------------------------------------
|
| 225 |
+
# Internal helpers
|
| 226 |
+
# -----------------------------------------------------------------------
|
| 227 |
+
|
| 228 |
+
def _make_obs(
|
| 229 |
+
self,
|
| 230 |
+
reward: Optional[float],
|
| 231 |
+
question_done: bool,
|
| 232 |
+
done: bool,
|
| 233 |
+
last_tool_result: Optional[ToolResult] = None,
|
| 234 |
+
) -> OrchestratorObservation:
|
| 235 |
+
state = self._state
|
| 236 |
+
cfg = self.config
|
| 237 |
+
|
| 238 |
+
if 0 <= self._current_q_idx < len(self._questions):
|
| 239 |
+
q_entry = self._questions[self._current_q_idx]
|
| 240 |
+
elif self._questions:
|
| 241 |
+
q_entry = self._questions[-1]
|
| 242 |
+
else:
|
| 243 |
+
q_entry = {"question": "", "answer": "", "domain": ""}
|
| 244 |
+
|
| 245 |
+
return OrchestratorObservation(
|
| 246 |
+
question=q_entry.get("question", ""),
|
| 247 |
+
question_idx=self._current_q_idx,
|
| 248 |
+
domain=q_entry.get("domain", ""),
|
| 249 |
+
question_embedding=[],
|
| 250 |
+
total_budget=cfg.total_budget,
|
| 251 |
+
budget_spent=state.budget_spent,
|
| 252 |
+
budget_remaining=state.total_budget - state.budget_spent,
|
| 253 |
+
budget_remaining_ratio=state.budget_remaining_ratio,
|
| 254 |
+
tools_used_this_question=list(self._tools_used_this_q),
|
| 255 |
+
steps_used_this_question=self._steps_this_q,
|
| 256 |
+
max_steps_per_question=cfg.max_steps_per_question,
|
| 257 |
+
last_tool_result=last_tool_result,
|
| 258 |
+
context_window=list(self._context_window),
|
| 259 |
+
step_idx=state.step_count,
|
| 260 |
+
questions_remaining=max(0, len(self._questions) - self._current_q_idx - 1),
|
| 261 |
+
questions_answered=state.questions_answered,
|
| 262 |
+
accuracy_so_far=state.current_accuracy,
|
| 263 |
+
question_done=question_done,
|
| 264 |
+
done=done,
|
| 265 |
+
reward=reward,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# ---------------------------------------------------------------------------
|
| 270 |
+
# Dataset sampling helper
|
| 271 |
+
# ---------------------------------------------------------------------------
|
| 272 |
+
|
| 273 |
+
def _sample_questions(
|
| 274 |
+
dataset: List[Dict[str, Any]],
|
| 275 |
+
config: EnvConfig,
|
| 276 |
+
rng: Any,
|
| 277 |
+
) -> List[Dict[str, Any]]:
|
| 278 |
+
"""Sample `config.num_questions` questions according to domain_mix."""
|
| 279 |
+
by_domain: Dict[str, List[Dict]] = {}
|
| 280 |
+
for item in dataset:
|
| 281 |
+
d = item.get("domain", "hotpotqa")
|
| 282 |
+
by_domain.setdefault(d, []).append(item)
|
| 283 |
+
|
| 284 |
+
selected = []
|
| 285 |
+
for domain, frac in config.domain_mix.items():
|
| 286 |
+
n = round(config.num_questions * frac)
|
| 287 |
+
pool = by_domain.get(domain, [])
|
| 288 |
+
if pool and n > 0:
|
| 289 |
+
selected.extend(rng.sample(pool, min(n, len(pool))))
|
| 290 |
+
|
| 291 |
+
if len(selected) < config.num_questions and dataset:
|
| 292 |
+
remaining = [d for d in dataset if d not in selected]
|
| 293 |
+
rng.shuffle(remaining)
|
| 294 |
+
selected.extend(remaining[: config.num_questions - len(selected)])
|
| 295 |
+
|
| 296 |
+
if config.shuffle_questions:
|
| 297 |
+
rng.shuffle(selected)
|
| 298 |
+
|
| 299 |
+
return selected[: config.num_questions]
|
env/models.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic data models for ToolOrchestratorEnv.
|
| 2 |
+
|
| 3 |
+
Three main types flow through the environment:
|
| 4 |
+
OrchestratorAction — agent → env (what tool to call)
|
| 5 |
+
OrchestratorObservation — env → agent (what the agent sees)
|
| 6 |
+
OrchestratorState — env → server (full bookkeeping snapshot)
|
| 7 |
+
|
| 8 |
+
ToolResult is returned by every tool and attached to the observation.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from typing import Any, Dict, List, Optional
|
| 13 |
+
|
| 14 |
+
from pydantic import BaseModel, Field
|
| 15 |
+
|
| 16 |
+
# Canonical tool IDs — order matters for UI display; must match config.tool_costs keys.
|
| 17 |
+
TOOL_IDS = [
|
| 18 |
+
"ceramic_search", # web retrieval (most useful for HotpotQA)
|
| 19 |
+
"wiki_lookup", # Wikipedia summary (good for entity facts)
|
| 20 |
+
"calculator", # safe AST math eval (essential for MATH)
|
| 21 |
+
"code_executor", # sandboxed Python (HumanEval)
|
| 22 |
+
"llm_reason", # LLM chain-of-thought (GPQA)
|
| 23 |
+
"commit", # submit answer — always free
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class OrchestratorAction(BaseModel):
|
| 28 |
+
"""One step of the agent's interaction with the environment.
|
| 29 |
+
|
| 30 |
+
Fields used per tool:
|
| 31 |
+
ceramic_search → query
|
| 32 |
+
wiki_lookup → query
|
| 33 |
+
calculator → expression (falls back to query if blank)
|
| 34 |
+
code_executor → code_snippet (falls back to query if blank)
|
| 35 |
+
llm_reason → query
|
| 36 |
+
commit → answer
|
| 37 |
+
"""
|
| 38 |
+
tool_id: str
|
| 39 |
+
query: str = ""
|
| 40 |
+
expression: str = ""
|
| 41 |
+
code_snippet: str = ""
|
| 42 |
+
answer: str = ""
|
| 43 |
+
metadata: Optional[Dict[str, Any]] = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ToolResult(BaseModel):
|
| 47 |
+
"""Output produced by one tool call.
|
| 48 |
+
|
| 49 |
+
Attached to OrchestratorObservation.last_tool_result and also
|
| 50 |
+
appended (as a string) to the context_window.
|
| 51 |
+
"""
|
| 52 |
+
tool_id: str
|
| 53 |
+
output: str = "" # Human-readable result text
|
| 54 |
+
cost: float = 0.0 # Budget units charged (set by environment)
|
| 55 |
+
latency_s: float = 0.0 # Wall-clock seconds (set by environment)
|
| 56 |
+
error: Optional[str] = None # Non-None if the tool call failed
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class OrchestratorObservation(BaseModel):
|
| 60 |
+
"""Everything the agent sees at the start of each step.
|
| 61 |
+
|
| 62 |
+
Designed to be complete: the agent should be able to make an
|
| 63 |
+
informed tool-selection decision using only this observation.
|
| 64 |
+
"""
|
| 65 |
+
# ── Current question ────────────────────────────────────────────────────
|
| 66 |
+
question: str # Full question text
|
| 67 |
+
question_idx: int # Position in the episode (0-indexed)
|
| 68 |
+
domain: str # "hotpotqa" | "math" | "gpqa" | "humaneval"
|
| 69 |
+
question_embedding: List[float] = Field(default_factory=list) # Optional embedding vector
|
| 70 |
+
|
| 71 |
+
# ── Budget ──────────────────────────────────────────────────────────────
|
| 72 |
+
total_budget: float
|
| 73 |
+
budget_spent: float
|
| 74 |
+
budget_remaining: float
|
| 75 |
+
budget_remaining_ratio: float # budget_remaining / total_budget ∈ [0, 1]
|
| 76 |
+
|
| 77 |
+
# ── Progress on the current question ────────────────────────────────────
|
| 78 |
+
tools_used_this_question: List[str] = Field(default_factory=list)
|
| 79 |
+
steps_used_this_question: int = 0
|
| 80 |
+
max_steps_per_question: int = 8
|
| 81 |
+
last_tool_result: Optional[ToolResult] = None
|
| 82 |
+
context_window: List[str] = Field(default_factory=list) # "[tool_id] output" strings
|
| 83 |
+
|
| 84 |
+
# ── Episode-level progress ───────────────────────────────────────────────
|
| 85 |
+
step_idx: int = 0 # Global step counter
|
| 86 |
+
questions_remaining: int = 0 # Questions not yet started
|
| 87 |
+
questions_answered: int = 0 # Questions that received a commit
|
| 88 |
+
accuracy_so_far: float = 0.0 # Running correctness rate
|
| 89 |
+
|
| 90 |
+
# ── Terminal signals ─────────────────────────────────────────────────────
|
| 91 |
+
question_done: bool = False # This question just ended (commit or max_steps)
|
| 92 |
+
done: bool = False # Episode is over
|
| 93 |
+
reward: Optional[float] = None # Reward from the *previous* step
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class OrchestratorState(BaseModel):
|
| 97 |
+
"""Full bookkeeping snapshot — returned alongside observation for logging.
|
| 98 |
+
|
| 99 |
+
Contains all fields needed to reconstruct the episode history without
|
| 100 |
+
digging into the environment's internal attributes.
|
| 101 |
+
"""
|
| 102 |
+
episode_id: str
|
| 103 |
+
total_budget: float
|
| 104 |
+
budget_spent: float = 0.0
|
| 105 |
+
questions_answered: int = 0
|
| 106 |
+
total_correct: int = 0
|
| 107 |
+
current_accuracy: float = 0.0
|
| 108 |
+
budget_remaining_ratio: float = 1.0
|
| 109 |
+
current_question_idx: int = 0
|
| 110 |
+
current_question_steps: int = 0
|
| 111 |
+
step_count: int = 0
|
env/reward.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward functions for ToolOrchestratorEnv.
|
| 2 |
+
|
| 3 |
+
Adapted from SearchEconomicsEnv/env/reward.py, generalised to a multi-tool
|
| 4 |
+
action space where each tool has its own cost.
|
| 5 |
+
|
| 6 |
+
Two reward signals:
|
| 7 |
+
|
| 8 |
+
step_reward — small negative penalty charged every time the agent calls
|
| 9 |
+
a tool (including commit = 0 cost). Discourages wasted
|
| 10 |
+
calls without forbidding exploration.
|
| 11 |
+
|
| 12 |
+
commit_reward — composite reward awarded when the agent submits an answer.
|
| 13 |
+
Balances answer quality against remaining budget (Weitzman
|
| 14 |
+
style: you earn a bonus for being both correct *and* frugal).
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from .config import EnvConfig
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def step_reward(tool_id: str, config: EnvConfig) -> float:
|
| 22 |
+
"""Return the (negative) cost of calling tool_id.
|
| 23 |
+
|
| 24 |
+
Example: calculator → -0.1, llm_reason → -2.0, commit → 0.0
|
| 25 |
+
"""
|
| 26 |
+
return -config.tool_costs.get(tool_id, 0.0)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def commit_reward(
|
| 30 |
+
quality: float,
|
| 31 |
+
budget_remaining_ratio: float,
|
| 32 |
+
config: EnvConfig,
|
| 33 |
+
) -> float:
|
| 34 |
+
"""Composite reward on commit.
|
| 35 |
+
|
| 36 |
+
Formula
|
| 37 |
+
-------
|
| 38 |
+
base = incorrect_reward + quality × (correct_reward − incorrect_reward)
|
| 39 |
+
η = 1 if quality ≥ efficiency_bonus_threshold, else 0
|
| 40 |
+
bonus = η × efficiency_bonus_weight × budget_remaining_ratio
|
| 41 |
+
R = base + bonus
|
| 42 |
+
|
| 43 |
+
The efficiency bonus (bonus) is only non-zero when the agent both answers
|
| 44 |
+
correctly (quality above threshold) *and* conserves budget. This creates
|
| 45 |
+
a soft incentive to use cheaper tools and commit early when confident.
|
| 46 |
+
|
| 47 |
+
Parameters
|
| 48 |
+
----------
|
| 49 |
+
quality : float in [0, 1] — max(ExactMatch, TokenF1)
|
| 50 |
+
budget_remaining_ratio: float in [0, 1] — fraction of budget still unspent
|
| 51 |
+
config : EnvConfig
|
| 52 |
+
"""
|
| 53 |
+
q = max(0.0, min(1.0, quality))
|
| 54 |
+
base = config.incorrect_reward + q * (config.correct_reward - config.incorrect_reward)
|
| 55 |
+
eta = 1.0 if q >= config.efficiency_bonus_threshold else 0.0
|
| 56 |
+
bonus = eta * config.efficiency_bonus_weight * budget_remaining_ratio
|
| 57 |
+
return base + bonus
|
environment.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility shim — real code lives in env/environment.py.
|
| 2 |
+
|
| 3 |
+
Step logic:
|
| 4 |
+
- Agent receives an OrchestratorObservation with the current question,
|
| 5 |
+
budget, context, and available tools.
|
| 6 |
+
- Agent picks a tool_id and optional query / code_snippet / answer.
|
| 7 |
+
- Environment dispatches to the appropriate tool, charges cost, appends
|
| 8 |
+
result to context_window, and returns the next observation + reward.
|
| 9 |
+
- Episode ends when budget is exhausted OR all questions are answered.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import time
|
| 14 |
+
import uuid
|
| 15 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
from env.answer_grading import grade
|
| 18 |
+
from env.config import EnvConfig
|
| 19 |
+
from env.models import (
|
| 20 |
+
OrchestratorAction,
|
| 21 |
+
OrchestratorObservation,
|
| 22 |
+
OrchestratorState,
|
| 23 |
+
ToolResult,
|
| 24 |
+
TOOL_IDS,
|
| 25 |
+
)
|
| 26 |
+
from env.reward import commit_reward, step_reward
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ToolOrchestratorEnvironment:
|
| 30 |
+
"""
|
| 31 |
+
OpenEnv-compatible RL environment for multi-tool cost-aware QA.
|
| 32 |
+
|
| 33 |
+
Supports external tool injection so the server can wire in live
|
| 34 |
+
Ceramic, code executor, etc. Tools are callables with signature:
|
| 35 |
+
tool_fn(action: OrchestratorAction) -> ToolResult
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
config: Optional[EnvConfig] = None,
|
| 41 |
+
tools: Optional[Dict[str, Any]] = None,
|
| 42 |
+
dataset: Optional[List[Dict[str, Any]]] = None,
|
| 43 |
+
):
|
| 44 |
+
self.config = config or EnvConfig()
|
| 45 |
+
self.tools = tools or {} # tool_id -> callable
|
| 46 |
+
self.dataset = dataset or [] # List of {question, answer, domain}
|
| 47 |
+
self._state: Optional[OrchestratorState] = None
|
| 48 |
+
self._questions: List[Dict[str, Any]] = []
|
| 49 |
+
self._current_q_idx: int = 0
|
| 50 |
+
self._context_window: List[str] = []
|
| 51 |
+
self._tools_used_this_q: List[str] = []
|
| 52 |
+
self._steps_this_q: int = 0
|
| 53 |
+
self._episode_done: bool = False
|
| 54 |
+
|
| 55 |
+
# -----------------------------------------------------------------------
|
| 56 |
+
# Reset
|
| 57 |
+
# -----------------------------------------------------------------------
|
| 58 |
+
|
| 59 |
+
def reset(self, seed: Optional[int] = None) -> Tuple[OrchestratorObservation, OrchestratorState]:
|
| 60 |
+
import random
|
| 61 |
+
rng = random.Random(seed if seed is not None else self.config.seed)
|
| 62 |
+
|
| 63 |
+
# Sample questions according to domain_mix
|
| 64 |
+
questions = _sample_questions(self.dataset, self.config, rng)
|
| 65 |
+
self._questions = questions
|
| 66 |
+
self._current_q_idx = 0
|
| 67 |
+
self._episode_done = False
|
| 68 |
+
|
| 69 |
+
self._state = OrchestratorState(
|
| 70 |
+
episode_id=str(uuid.uuid4()),
|
| 71 |
+
total_budget=self.config.total_budget,
|
| 72 |
+
budget_spent=0,
|
| 73 |
+
questions_answered=0,
|
| 74 |
+
total_correct=0,
|
| 75 |
+
current_accuracy=0.0,
|
| 76 |
+
budget_remaining_ratio=1.0,
|
| 77 |
+
current_question_idx=0,
|
| 78 |
+
current_question_steps=0,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
self._context_window = []
|
| 82 |
+
self._tools_used_this_q = []
|
| 83 |
+
self._steps_this_q = 0
|
| 84 |
+
|
| 85 |
+
obs = self._make_obs(reward=None, question_done=False, done=False)
|
| 86 |
+
return obs, self._state
|
| 87 |
+
|
| 88 |
+
# -----------------------------------------------------------------------
|
| 89 |
+
# Step
|
| 90 |
+
# -----------------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
def step(
|
| 93 |
+
self, action: OrchestratorAction
|
| 94 |
+
) -> Tuple[OrchestratorObservation, float, bool, OrchestratorState]:
|
| 95 |
+
if self._episode_done:
|
| 96 |
+
raise RuntimeError("Episode is done. Call reset() first.")
|
| 97 |
+
|
| 98 |
+
tool_id = action.tool_id
|
| 99 |
+
if tool_id not in TOOL_IDS:
|
| 100 |
+
raise ValueError(f"Unknown tool_id: {tool_id!r}. Valid: {TOOL_IDS}")
|
| 101 |
+
|
| 102 |
+
state = self._state
|
| 103 |
+
config = self.config
|
| 104 |
+
|
| 105 |
+
# Guard against exhausted question list (can happen after last commit)
|
| 106 |
+
if self._current_q_idx >= len(self._questions):
|
| 107 |
+
self._episode_done = True
|
| 108 |
+
raise RuntimeError("Episode is done. Call reset() first.")
|
| 109 |
+
|
| 110 |
+
q_entry = self._questions[self._current_q_idx]
|
| 111 |
+
gold = q_entry["answer"]
|
| 112 |
+
|
| 113 |
+
# ---- Commit ---------------------------------------------------
|
| 114 |
+
if tool_id == "commit":
|
| 115 |
+
raw_pred = action.answer or ""
|
| 116 |
+
em, f1, quality = grade(raw_pred, gold)
|
| 117 |
+
|
| 118 |
+
r = commit_reward(
|
| 119 |
+
quality=quality,
|
| 120 |
+
budget_remaining_ratio=state.budget_remaining_ratio,
|
| 121 |
+
config=config,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Count correct
|
| 125 |
+
is_correct = (
|
| 126 |
+
em if config.grade_count_correct_mode == "em_only"
|
| 127 |
+
else (em or f1 >= config.f1_count_threshold)
|
| 128 |
+
)
|
| 129 |
+
state.questions_answered += 1
|
| 130 |
+
state.total_correct += int(is_correct)
|
| 131 |
+
state.current_accuracy = state.total_correct / state.questions_answered
|
| 132 |
+
|
| 133 |
+
# Advance to next question or end episode
|
| 134 |
+
self._current_q_idx += 1
|
| 135 |
+
self._context_window = []
|
| 136 |
+
self._tools_used_this_q = []
|
| 137 |
+
self._steps_this_q = 0
|
| 138 |
+
|
| 139 |
+
episode_done = (
|
| 140 |
+
self._current_q_idx >= len(self._questions)
|
| 141 |
+
or state.budget_spent >= state.total_budget
|
| 142 |
+
)
|
| 143 |
+
self._episode_done = episode_done
|
| 144 |
+
state.current_question_idx = self._current_q_idx
|
| 145 |
+
state.current_question_steps = 0
|
| 146 |
+
|
| 147 |
+
obs = self._make_obs(
|
| 148 |
+
reward=r,
|
| 149 |
+
question_done=True,
|
| 150 |
+
done=episode_done,
|
| 151 |
+
last_tool_result=ToolResult(
|
| 152 |
+
tool_id="commit", cost=0,
|
| 153 |
+
output=f"EM={em} F1={f1:.3f} quality={quality:.3f}"
|
| 154 |
+
),
|
| 155 |
+
)
|
| 156 |
+
return obs, r, episode_done, state
|
| 157 |
+
|
| 158 |
+
# ---- Tool call ------------------------------------------------
|
| 159 |
+
cost = config.tool_costs.get(tool_id, 0)
|
| 160 |
+
budget_after = state.budget_spent + cost
|
| 161 |
+
|
| 162 |
+
# If over budget, force commit penalty
|
| 163 |
+
if budget_after > state.total_budget:
|
| 164 |
+
r = config.incorrect_reward
|
| 165 |
+
self._episode_done = True
|
| 166 |
+
obs = self._make_obs(reward=r, question_done=True, done=True)
|
| 167 |
+
return obs, r, True, state
|
| 168 |
+
|
| 169 |
+
# Dispatch tool
|
| 170 |
+
t0 = time.perf_counter()
|
| 171 |
+
tool_fn = self.tools.get(tool_id)
|
| 172 |
+
if tool_fn is None:
|
| 173 |
+
tool_result = ToolResult(
|
| 174 |
+
tool_id=tool_id, cost=cost,
|
| 175 |
+
output="[Tool not available in this environment]",
|
| 176 |
+
latency_s=0.0,
|
| 177 |
+
error="not_available",
|
| 178 |
+
)
|
| 179 |
+
else:
|
| 180 |
+
try:
|
| 181 |
+
tool_result = tool_fn(action)
|
| 182 |
+
tool_result.cost = cost
|
| 183 |
+
except Exception as exc:
|
| 184 |
+
tool_result = ToolResult(
|
| 185 |
+
tool_id=tool_id, cost=cost,
|
| 186 |
+
output=f"[Error: {exc}]",
|
| 187 |
+
latency_s=time.perf_counter() - t0,
|
| 188 |
+
error=str(exc),
|
| 189 |
+
)
|
| 190 |
+
tool_result.latency_s = time.perf_counter() - t0
|
| 191 |
+
|
| 192 |
+
# Charge cost and update state
|
| 193 |
+
state.budget_spent = budget_after
|
| 194 |
+
state.budget_remaining_ratio = max(
|
| 195 |
+
0.0, (state.total_budget - state.budget_spent) / state.total_budget
|
| 196 |
+
)
|
| 197 |
+
state.step_count += 1
|
| 198 |
+
state.current_question_steps += 1
|
| 199 |
+
self._steps_this_q += 1
|
| 200 |
+
self._tools_used_this_q.append(tool_id)
|
| 201 |
+
self._context_window.append(f"[{tool_id}] {tool_result.output}")
|
| 202 |
+
|
| 203 |
+
r = step_reward(tool_id, config)
|
| 204 |
+
|
| 205 |
+
# Auto-commit if max steps reached
|
| 206 |
+
question_done = self._steps_this_q >= config.max_steps_per_question
|
| 207 |
+
episode_done = (
|
| 208 |
+
state.budget_spent >= state.total_budget
|
| 209 |
+
or (question_done and self._current_q_idx + 1 >= len(self._questions))
|
| 210 |
+
)
|
| 211 |
+
if question_done and not episode_done:
|
| 212 |
+
self._current_q_idx += 1
|
| 213 |
+
state.questions_answered += 1
|
| 214 |
+
self._context_window = []
|
| 215 |
+
self._tools_used_this_q = []
|
| 216 |
+
self._steps_this_q = 0
|
| 217 |
+
state.current_question_idx = self._current_q_idx
|
| 218 |
+
state.current_question_steps = 0
|
| 219 |
+
|
| 220 |
+
self._episode_done = episode_done
|
| 221 |
+
|
| 222 |
+
obs = self._make_obs(
|
| 223 |
+
reward=r,
|
| 224 |
+
question_done=question_done,
|
| 225 |
+
done=episode_done,
|
| 226 |
+
last_tool_result=tool_result,
|
| 227 |
+
)
|
| 228 |
+
return obs, r, episode_done, state
|
| 229 |
+
|
| 230 |
+
# -----------------------------------------------------------------------
|
| 231 |
+
# Internal helpers
|
| 232 |
+
# -----------------------------------------------------------------------
|
| 233 |
+
|
| 234 |
+
def _make_obs(
|
| 235 |
+
self,
|
| 236 |
+
reward: Optional[float],
|
| 237 |
+
question_done: bool,
|
| 238 |
+
done: bool,
|
| 239 |
+
last_tool_result: Optional[ToolResult] = None,
|
| 240 |
+
) -> OrchestratorObservation:
|
| 241 |
+
state = self._state
|
| 242 |
+
cfg = self.config
|
| 243 |
+
|
| 244 |
+
if 0 <= self._current_q_idx < len(self._questions):
|
| 245 |
+
q_entry = self._questions[self._current_q_idx]
|
| 246 |
+
elif self._questions:
|
| 247 |
+
# Episode finished — repeat last question info (obs is terminal anyway)
|
| 248 |
+
q_entry = self._questions[-1]
|
| 249 |
+
else:
|
| 250 |
+
q_entry = {"question": "", "answer": "", "domain": ""}
|
| 251 |
+
|
| 252 |
+
return OrchestratorObservation(
|
| 253 |
+
question=q_entry.get("question", ""),
|
| 254 |
+
question_idx=self._current_q_idx,
|
| 255 |
+
domain=q_entry.get("domain", ""),
|
| 256 |
+
question_embedding=[], # populated by server if needed
|
| 257 |
+
total_budget=cfg.total_budget,
|
| 258 |
+
budget_spent=state.budget_spent,
|
| 259 |
+
budget_remaining=state.total_budget - state.budget_spent,
|
| 260 |
+
budget_remaining_ratio=state.budget_remaining_ratio,
|
| 261 |
+
tools_used_this_question=list(self._tools_used_this_q),
|
| 262 |
+
steps_used_this_question=self._steps_this_q,
|
| 263 |
+
max_steps_per_question=cfg.max_steps_per_question,
|
| 264 |
+
last_tool_result=last_tool_result,
|
| 265 |
+
context_window=list(self._context_window),
|
| 266 |
+
step_idx=state.step_count,
|
| 267 |
+
questions_remaining=max(0, len(self._questions) - self._current_q_idx - 1),
|
| 268 |
+
questions_answered=state.questions_answered,
|
| 269 |
+
accuracy_so_far=state.current_accuracy,
|
| 270 |
+
question_done=question_done,
|
| 271 |
+
done=done,
|
| 272 |
+
reward=reward,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# ---------------------------------------------------------------------------
|
| 277 |
+
# Dataset sampling helper
|
| 278 |
+
# ---------------------------------------------------------------------------
|
| 279 |
+
|
| 280 |
+
def _sample_questions(
|
| 281 |
+
dataset: List[Dict[str, Any]],
|
| 282 |
+
config: EnvConfig,
|
| 283 |
+
rng: Any,
|
| 284 |
+
) -> List[Dict[str, Any]]:
|
| 285 |
+
"""Sample `config.num_questions` questions according to domain_mix."""
|
| 286 |
+
by_domain: Dict[str, List[Dict]] = {}
|
| 287 |
+
for item in dataset:
|
| 288 |
+
d = item.get("domain", "hotpotqa")
|
| 289 |
+
by_domain.setdefault(d, []).append(item)
|
| 290 |
+
|
| 291 |
+
selected = []
|
| 292 |
+
for domain, frac in config.domain_mix.items():
|
| 293 |
+
n = round(config.num_questions * frac)
|
| 294 |
+
pool = by_domain.get(domain, [])
|
| 295 |
+
if pool and n > 0:
|
| 296 |
+
selected.extend(rng.sample(pool, min(n, len(pool))))
|
| 297 |
+
|
| 298 |
+
# Guarantee at least num_questions items by filling from the full dataset
|
| 299 |
+
if len(selected) < config.num_questions and dataset:
|
| 300 |
+
remaining = [d for d in dataset if d not in selected]
|
| 301 |
+
rng.shuffle(remaining)
|
| 302 |
+
selected.extend(remaining[: config.num_questions - len(selected)])
|
| 303 |
+
|
| 304 |
+
if config.shuffle_questions:
|
| 305 |
+
rng.shuffle(selected)
|
| 306 |
+
|
| 307 |
+
return selected[: config.num_questions]
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec: 1
|
| 2 |
+
app: tool-orchestrator-env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
entrypoint: app:app
|
| 6 |
+
port: 8000
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.110.0
|
| 2 |
+
uvicorn[standard]>=0.29.0
|
| 3 |
+
pydantic>=2.0
|
| 4 |
+
numpy>=1.24
|
| 5 |
+
datasets>=2.18.0
|
| 6 |
+
httpx>=0.27.0
|
| 7 |
+
requests>=2.31.0
|
| 8 |
+
together>=1.2.0
|
tools/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool registry for ToolOrchestratorEnv.
|
| 2 |
+
|
| 3 |
+
Each tool is a callable: (action: OrchestratorAction) -> ToolResult
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from typing import Callable, Dict
|
| 8 |
+
|
| 9 |
+
from env.config import EnvConfig
|
| 10 |
+
from env.models import OrchestratorAction, ToolResult
|
| 11 |
+
|
| 12 |
+
from .calculator import calculator_tool
|
| 13 |
+
from .ceramic_search import make_search_tool
|
| 14 |
+
from .code_executor import code_executor_tool
|
| 15 |
+
from .commit import commit_tool
|
| 16 |
+
from .llm_reason import llm_reason_tool
|
| 17 |
+
from .wiki_lookup import wiki_lookup_tool
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def build_tool_registry(config: EnvConfig | None = None) -> Dict[str, Callable]:
|
| 21 |
+
"""Return a mapping of tool_id → tool function."""
|
| 22 |
+
return {
|
| 23 |
+
"ceramic_search": make_search_tool(),
|
| 24 |
+
"calculator": calculator_tool,
|
| 25 |
+
"wiki_lookup": wiki_lookup_tool,
|
| 26 |
+
"code_executor": code_executor_tool,
|
| 27 |
+
"llm_reason": llm_reason_tool,
|
| 28 |
+
"commit": commit_tool,
|
| 29 |
+
}
|
tools/calculator.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Safe AST-based calculator tool.
|
| 2 |
+
|
| 3 |
+
Supports arithmetic, comparisons, and basic math functions.
|
| 4 |
+
No exec/eval with arbitrary code — uses ast.literal_eval-style restricted eval.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import ast
|
| 9 |
+
import math
|
| 10 |
+
import operator
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
from env.models import OrchestratorAction, ToolResult
|
| 14 |
+
|
| 15 |
+
_SAFE_OPS = {
|
| 16 |
+
ast.Add: operator.add,
|
| 17 |
+
ast.Sub: operator.sub,
|
| 18 |
+
ast.Mult: operator.mul,
|
| 19 |
+
ast.Div: operator.truediv,
|
| 20 |
+
ast.Pow: operator.pow,
|
| 21 |
+
ast.Mod: operator.mod,
|
| 22 |
+
ast.FloorDiv: operator.floordiv,
|
| 23 |
+
ast.USub: operator.neg,
|
| 24 |
+
ast.UAdd: operator.pos,
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
_SAFE_FUNCS: dict[str, Any] = {
|
| 28 |
+
"abs": abs, "round": round, "min": min, "max": max,
|
| 29 |
+
"sqrt": math.sqrt, "log": math.log, "log2": math.log2,
|
| 30 |
+
"log10": math.log10, "exp": math.exp,
|
| 31 |
+
"sin": math.sin, "cos": math.cos, "tan": math.tan,
|
| 32 |
+
"floor": math.floor, "ceil": math.ceil,
|
| 33 |
+
"pi": math.pi, "e": math.e,
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _safe_eval(node: ast.AST) -> Any:
|
| 38 |
+
if isinstance(node, ast.Expression):
|
| 39 |
+
return _safe_eval(node.body)
|
| 40 |
+
if isinstance(node, ast.Constant):
|
| 41 |
+
return node.value
|
| 42 |
+
if isinstance(node, ast.Name):
|
| 43 |
+
if node.id in _SAFE_FUNCS:
|
| 44 |
+
return _SAFE_FUNCS[node.id]
|
| 45 |
+
raise ValueError(f"Unknown name: {node.id!r}")
|
| 46 |
+
if isinstance(node, ast.BinOp):
|
| 47 |
+
op_type = type(node.op)
|
| 48 |
+
if op_type not in _SAFE_OPS:
|
| 49 |
+
raise ValueError(f"Unsupported operator: {op_type.__name__}")
|
| 50 |
+
return _SAFE_OPS[op_type](_safe_eval(node.left), _safe_eval(node.right))
|
| 51 |
+
if isinstance(node, ast.UnaryOp):
|
| 52 |
+
op_type = type(node.op)
|
| 53 |
+
if op_type not in _SAFE_OPS:
|
| 54 |
+
raise ValueError(f"Unsupported unary: {op_type.__name__}")
|
| 55 |
+
return _SAFE_OPS[op_type](_safe_eval(node.operand))
|
| 56 |
+
if isinstance(node, ast.Call):
|
| 57 |
+
func = _safe_eval(node.func)
|
| 58 |
+
if not callable(func):
|
| 59 |
+
raise ValueError("Not callable")
|
| 60 |
+
args = [_safe_eval(a) for a in node.args]
|
| 61 |
+
return func(*args)
|
| 62 |
+
raise ValueError(f"Unsupported AST node: {type(node).__name__}")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def calculator_tool(action: OrchestratorAction) -> ToolResult:
|
| 66 |
+
expr = (action.expression or action.query or "").strip()
|
| 67 |
+
if not expr:
|
| 68 |
+
return ToolResult(tool_id="calculator", output="[No expression provided]", error="empty")
|
| 69 |
+
try:
|
| 70 |
+
tree = ast.parse(expr, mode="eval")
|
| 71 |
+
result = _safe_eval(tree)
|
| 72 |
+
return ToolResult(tool_id="calculator", output=str(result))
|
| 73 |
+
except Exception as exc:
|
| 74 |
+
return ToolResult(tool_id="calculator", output=f"[Calc error: {exc}]", error=str(exc))
|
tools/ceramic_search.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Ceramic search tool — wraps CeramicClient."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from ceramic.client import get_ceramic_client
|
| 5 |
+
from env.models import OrchestratorAction, ToolResult
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def make_search_tool(top_k: int = 3):
|
| 9 |
+
"""Factory: creates a search tool with a shared Ceramic client."""
|
| 10 |
+
client = get_ceramic_client()
|
| 11 |
+
|
| 12 |
+
def _search(action: OrchestratorAction) -> ToolResult:
|
| 13 |
+
query = (action.query or "").strip()
|
| 14 |
+
if not query:
|
| 15 |
+
return ToolResult(
|
| 16 |
+
tool_id="ceramic_search",
|
| 17 |
+
output="[No query provided]",
|
| 18 |
+
error="empty_query",
|
| 19 |
+
)
|
| 20 |
+
try:
|
| 21 |
+
results = client.search(query, top_k=top_k)
|
| 22 |
+
snippets = []
|
| 23 |
+
for r in results:
|
| 24 |
+
snippets.append(f"**{r.title}** ({r.score:.2f})\n{r.description}")
|
| 25 |
+
output = "\n\n".join(snippets) if snippets else "[No results found]"
|
| 26 |
+
return ToolResult(tool_id="ceramic_search", output=output)
|
| 27 |
+
except Exception as exc:
|
| 28 |
+
return ToolResult(
|
| 29 |
+
tool_id="ceramic_search",
|
| 30 |
+
output=f"[Search error: {exc}]",
|
| 31 |
+
error=str(exc),
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
return _search
|
tools/code_executor.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Restricted Python code executor.
|
| 2 |
+
|
| 3 |
+
Runs code in a sandboxed namespace — blocks os/sys/subprocess imports
|
| 4 |
+
and captures stdout. Intended for math / algorithmic tasks.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import io
|
| 9 |
+
import sys
|
| 10 |
+
import contextlib
|
| 11 |
+
|
| 12 |
+
from env.models import OrchestratorAction, ToolResult
|
| 13 |
+
|
| 14 |
+
_BLOCKED_MODULES = frozenset({
|
| 15 |
+
"os", "sys", "subprocess", "socket", "shutil", "pathlib",
|
| 16 |
+
"importlib", "builtins", "ctypes", "multiprocessing", "threading",
|
| 17 |
+
"signal", "pty", "fcntl", "resource", "gc", "inspect",
|
| 18 |
+
})
|
| 19 |
+
|
| 20 |
+
_MAX_OUTPUT_CHARS = 2000
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class _BlockedImport:
|
| 24 |
+
"""Raise on any import of blocked modules."""
|
| 25 |
+
def __init__(self, original_import):
|
| 26 |
+
self._orig = original_import
|
| 27 |
+
|
| 28 |
+
def __call__(self, name, *args, **kwargs):
|
| 29 |
+
base = name.split(".")[0]
|
| 30 |
+
if base in _BLOCKED_MODULES:
|
| 31 |
+
raise ImportError(f"Module '{name}' is not allowed in code_executor")
|
| 32 |
+
return self._orig(name, *args, **kwargs)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def code_executor_tool(action: OrchestratorAction) -> ToolResult:
|
| 36 |
+
code = (action.code_snippet or action.query or "").strip()
|
| 37 |
+
if not code:
|
| 38 |
+
return ToolResult(tool_id="code_executor", output="[No code provided]", error="empty")
|
| 39 |
+
|
| 40 |
+
stdout_buf = io.StringIO()
|
| 41 |
+
safe_globals = {
|
| 42 |
+
"__builtins__": {
|
| 43 |
+
k: v for k, v in __builtins__.items() # type: ignore[union-attr]
|
| 44 |
+
if k not in ("open", "exec", "eval", "compile", "__import__")
|
| 45 |
+
} if isinstance(__builtins__, dict) else {
|
| 46 |
+
k: getattr(__builtins__, k) for k in dir(__builtins__)
|
| 47 |
+
if k not in ("open", "exec", "eval", "compile", "__import__")
|
| 48 |
+
},
|
| 49 |
+
"__import__": _BlockedImport(__import__),
|
| 50 |
+
"print": lambda *a, **kw: print(*a, **kw, file=stdout_buf),
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
with contextlib.redirect_stdout(stdout_buf):
|
| 55 |
+
exec(compile(code, "<code_executor>", "exec"), safe_globals) # noqa: S102
|
| 56 |
+
output = stdout_buf.getvalue()[:_MAX_OUTPUT_CHARS] or "[Code ran, no output]"
|
| 57 |
+
return ToolResult(tool_id="code_executor", output=output)
|
| 58 |
+
except Exception as exc:
|
| 59 |
+
return ToolResult(
|
| 60 |
+
tool_id="code_executor",
|
| 61 |
+
output=f"[Execution error: {exc}]",
|
| 62 |
+
error=str(exc),
|
| 63 |
+
)
|
tools/commit.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Commit tool — passes the answer through; grading happens in environment.py."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from env.models import OrchestratorAction, ToolResult
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def commit_tool(action: OrchestratorAction) -> ToolResult:
|
| 8 |
+
answer = (action.answer or "").strip()
|
| 9 |
+
return ToolResult(
|
| 10 |
+
tool_id="commit",
|
| 11 |
+
output=f"Committed answer: {answer[:200]}",
|
| 12 |
+
)
|
tools/llm_reason.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM reasoning tool — calls Together AI (or falls back gracefully)."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from env.models import OrchestratorAction, ToolResult
|
| 7 |
+
|
| 8 |
+
_DEFAULT_MODEL = "meta-llama/Llama-3-8b-chat-hf"
|
| 9 |
+
_MAX_TOKENS = 512
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def llm_reason_tool(action: OrchestratorAction) -> ToolResult:
|
| 13 |
+
prompt = (action.query or "").strip()
|
| 14 |
+
if not prompt:
|
| 15 |
+
return ToolResult(tool_id="llm_reason", output="[No prompt provided]", error="empty")
|
| 16 |
+
|
| 17 |
+
api_key = os.environ.get("TOGETHER_API_KEY") or os.environ.get("TOGETHER_KEY")
|
| 18 |
+
if not api_key:
|
| 19 |
+
return ToolResult(
|
| 20 |
+
tool_id="llm_reason",
|
| 21 |
+
output="[LLM reasoning not configured — set TOGETHER_API_KEY]",
|
| 22 |
+
error="no_api_key",
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
import together # type: ignore
|
| 27 |
+
client = together.Together(api_key=api_key)
|
| 28 |
+
resp = client.chat.completions.create(
|
| 29 |
+
model=_DEFAULT_MODEL,
|
| 30 |
+
messages=[{"role": "user", "content": prompt}],
|
| 31 |
+
max_tokens=_MAX_TOKENS,
|
| 32 |
+
temperature=0.0,
|
| 33 |
+
)
|
| 34 |
+
text = resp.choices[0].message.content or ""
|
| 35 |
+
return ToolResult(tool_id="llm_reason", output=text.strip()[:2000])
|
| 36 |
+
except ImportError:
|
| 37 |
+
return ToolResult(
|
| 38 |
+
tool_id="llm_reason",
|
| 39 |
+
output="[together package not installed — pip install together]",
|
| 40 |
+
error="import_error",
|
| 41 |
+
)
|
| 42 |
+
except Exception as exc:
|
| 43 |
+
return ToolResult(tool_id="llm_reason", output=f"[LLM error: {exc}]", error=str(exc))
|
tools/wiki_lookup.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Wikipedia lookup tool — returns the intro paragraph of an article."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import urllib.parse
|
| 5 |
+
import urllib.request
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
from env.models import OrchestratorAction, ToolResult
|
| 9 |
+
|
| 10 |
+
_WIKI_API = "https://en.wikipedia.org/api/rest_v1/page/summary/{}"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def wiki_lookup_tool(action: OrchestratorAction) -> ToolResult:
|
| 14 |
+
query = (action.query or "").strip()
|
| 15 |
+
if not query:
|
| 16 |
+
return ToolResult(tool_id="wiki_lookup", output="[No query provided]", error="empty_query")
|
| 17 |
+
|
| 18 |
+
title = urllib.parse.quote(query.replace(" ", "_"))
|
| 19 |
+
url = _WIKI_API.format(title)
|
| 20 |
+
try:
|
| 21 |
+
req = urllib.request.Request(url, headers={"User-Agent": "ToolOrchestratorEnv/0.1"})
|
| 22 |
+
with urllib.request.urlopen(req, timeout=8) as resp:
|
| 23 |
+
data = json.loads(resp.read().decode())
|
| 24 |
+
extract = data.get("extract", "").strip()
|
| 25 |
+
page_title = data.get("title", query)
|
| 26 |
+
if not extract:
|
| 27 |
+
return ToolResult(tool_id="wiki_lookup", output=f"[No summary found for '{query}']")
|
| 28 |
+
return ToolResult(tool_id="wiki_lookup", output=f"**{page_title}**\n{extract[:800]}")
|
| 29 |
+
except urllib.error.HTTPError as exc:
|
| 30 |
+
if exc.code == 404:
|
| 31 |
+
return ToolResult(
|
| 32 |
+
tool_id="wiki_lookup",
|
| 33 |
+
output=f"[Wikipedia: no article found for '{query}']",
|
| 34 |
+
error="not_found",
|
| 35 |
+
)
|
| 36 |
+
return ToolResult(tool_id="wiki_lookup", output=f"[Wiki HTTP error {exc.code}]", error=str(exc))
|
| 37 |
+
except Exception as exc:
|
| 38 |
+
return ToolResult(tool_id="wiki_lookup", output=f"[Wiki error: {exc}]", error=str(exc))
|