Spaces:
Running
Running
File size: 7,574 Bytes
68025ee 2f5bdbc 68025ee 2f5bdbc 68025ee 2f5bdbc 68025ee 2f5bdbc 68025ee 2f5bdbc 68025ee | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 | import logging
import sys
import os
from contextlib import asynccontextmanager
from datetime import date
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from fastapi import FastAPI, BackgroundTasks, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional
from config import BENCHMARKS, FREE_MODELS, AVAILABLE_MODELS, ASSETS, OPENROUTER_API_KEY
from db.store import init_db, create_run, complete_run, fail_run, get_run, get_leaderboard, get_decisions
from backtest.runner import run_backtest
# Configure root logging once for the whole service and create this module's logger.
# The message separator previously contained a mis-encoded em dash ("β"); a plain
# ASCII hyphen renders identically regardless of console/file encoding.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: initialise the database at startup.

    Everything before ``yield`` runs once when the app starts; nothing is
    torn down after ``yield`` on shutdown.
    """
    init_db()
    logger.info("DB initialised")
    yield
# The FastAPI application; `lifespan` runs the database initialisation at startup.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="CryptoAgentBench API",
)

# Fully permissive CORS: any origin, method and header may call the API.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed cross-origin requests; confirm credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# ── Schemas ──────────────────────────────────────────────────────────────────
class BacktestRequest(BaseModel):
    """Request body for ``POST /backtest``.

    Fields are plain strings; call ``validate_fields()`` to check them against
    the configured benchmarks/assets and to verify the date range.
    """

    benchmark: str = Field(..., description="A, B, or C")
    model: str = Field(default="google/gemma-4-31b-it:free")
    asset: str = Field(default="BTC/USDT")
    start_date: str = Field(default="2024-01-01", description="YYYY-MM-DD")
    end_date: str = Field(default="2024-06-30", description="YYYY-MM-DD")

    def validate_fields(self):
        """Check the request before a run is created.

        Raises:
            ValueError: if ``benchmark`` is not in BENCHMARKS, ``asset`` is not
                in ASSETS, either date is not a valid ISO ``YYYY-MM-DD`` date,
                or ``start_date`` is after ``end_date``.
        """
        if self.benchmark not in BENCHMARKS:
            raise ValueError(f"benchmark must be one of {BENCHMARKS}")
        if self.asset not in ASSETS:
            raise ValueError(f"asset must be one of {ASSETS}")
        # Parse the dates up front so a bad range is rejected with a clear
        # message instead of failing later inside the background backtest.
        try:
            start = date.fromisoformat(self.start_date)
            end = date.fromisoformat(self.end_date)
        except ValueError:
            raise ValueError(
                "start_date and end_date must be valid dates in YYYY-MM-DD format"
            ) from None
        if start > end:
            raise ValueError("start_date must not be after end_date")
# ── Background task ───────────────────────────────────────────────────────────
def _run_backtest_task(run_id: str, req: BacktestRequest):
    """Execute one backtest and persist its outcome under ``run_id``.

    Meant to be scheduled as a background task. On success the result is
    stored with ``complete_run``; if the backtest or the store call raises,
    the error is logged with its traceback and recorded with ``fail_run``.
    """
    try:
        result = run_backtest(
            benchmark=req.benchmark,
            model=req.model,
            asset=req.asset,
            start_date=req.start_date,
            end_date=req.end_date,
        )
        complete_run(run_id, result)
    except Exception as e:
        logger.exception(f"Run {run_id} failed: {e}")
        fail_run(run_id, str(e))
        return
    # Logged outside the try: a missing/None "metrics" entry must not send an
    # already-completed run through fail_run.
    metrics = result.get("metrics") or {}
    logger.info(f"Run {run_id} completed. CR={metrics.get('cumulative_return')}")
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/")
def health():
    """Liveness probe: static service identity plus the server's local date."""
    today = date.today().isoformat()
    return dict(
        status="ok",
        service="CryptoAgentBench API",
        version="1.0.0",
        date=today,
    )
@app.get("/health/llm")
def health_llm():
    """Probe OpenRouter with a minimal chat completion to verify the API key.

    Any exception raised while calling OpenRouter is reported in the returned
    dict rather than raised. ``key_prefix`` shows only the first six characters
    of the key (or ``"(short)"``) so operators can tell which key is loaded.
    """
    import requests

    api_key = OPENROUTER_API_KEY
    if not api_key:
        return {"llm_ok": False, "error": "OPENROUTER_API_KEY not set", "key_prefix": None}

    key_prefix = "(short)"
    if len(api_key) > 6:
        key_prefix = api_key[:6] + "..."

    try:
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": FREE_MODELS[0],
                "messages": [{"role": "user", "content": "Reply OK"}],
                "max_tokens": 5,
            },
            timeout=20,
        )
        if response.status_code != 200:
            # Truncate the upstream body so the error payload stays small.
            return {
                "llm_ok": False,
                "key_prefix": key_prefix,
                "status": response.status_code,
                "error": response.text[:200],
            }
        return {"llm_ok": True, "key_prefix": key_prefix, "status": 200}
    except Exception as exc:
        return {"llm_ok": False, "key_prefix": key_prefix, "error": str(exc)[:200]}
@app.get("/models")
def list_models():
    """List the model identifiers a backtest can be run with.

    ``paid_models`` is every entry of AVAILABLE_MODELS after the first
    ``len(FREE_MODELS)`` — this presumes AVAILABLE_MODELS lists the free models
    first (TODO confirm in config).
    """
    free_count = len(FREE_MODELS)
    return dict(
        free_models=FREE_MODELS,
        paid_models=AVAILABLE_MODELS[free_count:],
        models=AVAILABLE_MODELS,
        note="Free models via OpenRouter free tier; paid models are affordable open-source.",
    )
@app.get("/benchmarks")
def list_benchmarks():
    """Describe benchmark tiers A, B and C with their listed agents and data sources."""
    baseline = {
        "name": "Baseline",
        "description": "Single agent: LLM sees price + indicators directly",
        "agents": ["Trader"],
        "data": ["OHLCV", "Technical Indicators"],
    }
    intermediate = {
        "name": "Intermediate",
        "description": "Technical Analyst + News Analyst -> Trader",
        "agents": ["TechnicalAnalyst", "NewsAnalyst", "Trader"],
        "data": ["OHLCV", "Technical Indicators", "News"],
    }
    full_multi_agent = {
        "name": "Full Multi-Agent",
        "description": "Technical + Sentiment + News -> Researcher (bull/bear debate) -> Risk Manager -> Trader",
        "agents": [
            "TechnicalAnalyst",
            "SentimentAnalyst",
            "NewsAnalyst",
            "Researcher",
            "RiskManager",
            "Trader",
        ],
        "data": ["OHLCV", "Technical Indicators", "News", "Fear & Greed", "Funding Rates"],
    }
    tiers = {"A": baseline, "B": intermediate, "C": full_multi_agent}
    return {"benchmarks": tiers}
@app.post("/backtest")
def start_backtest(req: BacktestRequest, background_tasks: BackgroundTasks):
    """Validate the request, record a new run and schedule it in the background.

    Returns immediately with the new ``run_id``; ``_run_backtest_task`` does
    the actual work once the response has been sent.

    Raises:
        HTTPException: 422 if ``req.validate_fields()`` rejects the request.
    """
    try:
        req.validate_fields()
    except ValueError as e:
        # Chain the cause so the validation error is preserved in tracebacks.
        raise HTTPException(status_code=422, detail=str(e)) from e
    run_id = create_run(req.benchmark, req.model, req.asset, req.start_date, req.end_date)
    background_tasks.add_task(_run_backtest_task, run_id, req)
    return {
        "run_id": run_id,
        "status": "running",
        # Interpolated so the client receives the exact URL to poll.
        "message": f"Backtest started. Poll /runs/{run_id} for results.",
    }
@app.get("/runs/{run_id}")
def get_run_detail(run_id: str):
    """Return a stored run with its equity and HODL curves placed last.

    Both curves are always lists: a missing or null curve is reported as
    ``[]`` (presumably the case for runs that have not finished — confirm
    against db.store).

    Raises:
        HTTPException: 404 if no run with ``run_id`` exists.
    """
    run = get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")
    curve_keys = ("equity_curve", "hodl_curve")
    run_out = {k: v for k, v in run.items() if k not in curve_keys}
    for key in curve_keys:
        # `or []` rather than a .get() default so a stored None also becomes [].
        run_out[key] = run.get(key) or []
    return run_out
@app.get("/runs/{run_id}/decisions")
def get_run_decisions(run_id: str):
    """Return the recorded decisions for an existing run.

    Raises:
        HTTPException: 404 if no run with ``run_id`` exists.
    """
    if not get_run(run_id):
        raise HTTPException(status_code=404, detail="Run not found")
    return {"run_id": run_id, "decisions": get_decisions(run_id)}
@app.get("/leaderboard")
def leaderboard():
    """Flatten the runs returned by ``get_leaderboard()`` into leaderboard rows.

    Each row holds the run's identity and date range, the selected metrics
    (None when absent), and its completion time; the order of runs is the
    order ``get_leaderboard()`` returns them in.
    """
    metric_names = (
        "cumulative_return",
        "sharpe_ratio",
        "sortino_ratio",
        "max_drawdown",
        "win_rate",
        "num_trades",
        "hodl_return",
        "alpha",
        "final_value",
    )

    def to_row(run):
        # A missing or empty/None metrics entry is treated as no metrics.
        run_metrics = run.get("metrics") or {}
        row = {
            "run_id": run["id"],
            "benchmark": run["benchmark"],
            "model": run["model"],
            "asset": run["asset"],
            "start_date": run.get("start_date"),
            "end_date": run.get("end_date"),
        }
        for name in metric_names:
            row[name] = run_metrics.get(name)
        row["completed_at"] = run.get("completed_at")
        return row

    rows = [to_row(run) for run in get_leaderboard()]
    return {"leaderboard": rows, "total": len(rows)}
|