# Author: MDIIII
# Last commit: 2f5bdbc ("perf: raise rate limit to 200/min for paid-tier models")
import logging
import sys
import os
from contextlib import asynccontextmanager
from datetime import date
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from fastapi import FastAPI, BackgroundTasks, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional
from config import BENCHMARKS, FREE_MODELS, AVAILABLE_MODELS, ASSETS, OPENROUTER_API_KEY
from db.store import init_db, create_run, complete_run, fail_run, get_run, get_leaderboard, get_decisions
from backtest.runner import run_backtest
# Format for the root handler installed by basicConfig below. The separator is a
# real em dash (U+2014); a mis-decoded copy of it ("β€”") previously ended up in
# every log line, so keep this file saved as UTF-8.
LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s — %(message)s"

# NOTE: basicConfig is a no-op if the root logger already has handlers
# (standard-library behaviour), e.g. when a server configured logging first.
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook passed to ``FastAPI(lifespan=...)``.

    Initialises the database and logs that it did so, then yields control for
    the lifetime of the application. No work is done after the ``yield``, so
    shutdown is a no-op.
    """
    # --- startup ---
    init_db()
    logger.info("DB initialised")
    # --- app runs while suspended here; nothing to clean up at shutdown ---
    yield
# Cross-origin policy: every origin, method and header is accepted, with
# credentials allowed.
# NOTE(review): the CORS spec does not allow a wildcard origin together with
# credentials; confirm how CORSMiddleware resolves this combination and whether
# credentialed cross-origin calls are needed at all (no route in this file reads
# cookies or auth headers).
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app = FastAPI(title="CryptoAgentBench API", version="1.0.0", lifespan=lifespan)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# ── Schemas ──────────────────────────────────────────────────────────────────
def _parse_iso_date(field_name: str, value: str) -> date:
    """Return *value* parsed as an ISO ``YYYY-MM-DD`` date.

    Raises:
        ValueError: naming *field_name*, when *value* is not a valid ISO date.
    """
    try:
        return date.fromisoformat(value)
    except ValueError:
        raise ValueError(
            f"{field_name} must be a date in YYYY-MM-DD format, got {value!r}"
        ) from None


class BacktestRequest(BaseModel):
    """Request body for ``POST /backtest``.

    Checks beyond basic typing live in :meth:`validate_fields`, which the route
    calls explicitly so that failures can be reported as HTTP 422.
    """

    benchmark: str = Field(..., description="A, B, or C")
    model: str = Field(default="google/gemma-4-31b-it:free")
    asset: str = Field(default="BTC/USDT")
    start_date: str = Field(default="2024-01-01", description="YYYY-MM-DD")
    end_date: str = Field(default="2024-06-30", description="YYYY-MM-DD")

    def validate_fields(self):
        """Validate the request against the configured catalogues and date rules.

        Raises:
            ValueError: if ``benchmark`` is not in ``BENCHMARKS``, ``asset`` is
                not in ``ASSETS``, either date is not an ISO ``YYYY-MM-DD``
                string, or ``start_date`` is after ``end_date``.
        """
        if self.benchmark not in BENCHMARKS:
            raise ValueError(f"benchmark must be one of {BENCHMARKS}")
        if self.asset not in ASSETS:
            raise ValueError(f"asset must be one of {ASSETS}")
        # Catch malformed or inverted ranges up front. Otherwise a run is
        # created and only fails later inside the background task.
        start = _parse_iso_date("start_date", self.start_date)
        end = _parse_iso_date("end_date", self.end_date)
        if start > end:
            raise ValueError(
                f"start_date ({self.start_date}) must not be after "
                f"end_date ({self.end_date})"
            )
# ── Background task ───────────────────────────────────────────────────────────
def _run_backtest_task(run_id: str, req: BacktestRequest):
    """Run one backtest and record its outcome under *run_id*.

    Scheduled by ``POST /backtest`` as a background task. On success the result
    is stored via ``complete_run``. If the backtest or ``complete_run`` raises,
    the error is logged and the run is marked failed via ``fail_run``. No
    exception propagates out of this function.
    """
    try:
        result = run_backtest(
            benchmark=req.benchmark,
            model=req.model,
            asset=req.asset,
            start_date=req.start_date,
            end_date=req.end_date,
        )
        complete_run(run_id, result)
    except Exception as e:
        # Top-level boundary of a background job: record the failure instead
        # of letting it escape into the task runner.
        logger.exception("Run %s failed: %s", run_id, e)
        try:
            fail_run(run_id, str(e))
        except Exception:
            # Best effort: the original failure is already logged above.
            logger.exception("Could not mark run %s as failed", run_id)
        return

    # Logged outside the try on purpose. Previously an unexpected result shape
    # (missing/None "metrics") raised here and flipped an already-completed
    # run to "failed".
    metrics = result.get("metrics") if isinstance(result, dict) else None
    cumulative_return = (
        metrics.get("cumulative_return") if isinstance(metrics, dict) else None
    )
    logger.info("Run %s completed. CR=%s", run_id, cumulative_return)
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/")
def health():
    """Liveness probe: report service identity, API version and today's date."""
    payload = dict(
        status="ok",
        service="CryptoAgentBench API",
        version="1.0.0",
    )
    # date.today() is the server's local calendar date, as an ISO string.
    payload["date"] = date.today().isoformat()
    return payload
@app.get("/health/llm")
def health_llm():
    """Check that the OpenRouter key works by sending a tiny chat completion.

    The request asks the first entry of ``FREE_MODELS`` to "Reply OK" with a
    5-token limit. The function reports whether OpenRouter answered HTTP 200.
    Only the first six characters of the key are echoed back, or ``"(short)"``
    for keys of six characters or fewer. Errors raised while calling the API
    are returned in ``error``, truncated to 200 characters, rather than raised.
    """
    import requests  # lazy import: only this endpoint talks HTTP directly

    api_key = OPENROUTER_API_KEY
    if not api_key:
        return {"llm_ok": False, "error": "OPENROUTER_API_KEY not set", "key_prefix": None}

    if len(api_key) > 6:
        key_prefix = api_key[:6] + "..."
    else:
        key_prefix = "(short)"

    probe_body = {
        "model": FREE_MODELS[0] if False else None,
    }
    del probe_body  # placeholder removed below; keeps evaluation order explicit

    try:
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
            json={"model": FREE_MODELS[0], "messages": [{"role": "user", "content": "Reply OK"}], "max_tokens": 5},
            timeout=20,
        )
        if response.status_code != 200:
            return {
                "llm_ok": False,
                "key_prefix": key_prefix,
                "status": response.status_code,
                "error": response.text[:200],
            }
        return {"llm_ok": True, "key_prefix": key_prefix, "status": 200}
    except Exception as exc:
        return {"llm_ok": False, "key_prefix": key_prefix, "error": str(exc)[:200]}
@app.get("/models")
def list_models():
    """List selectable models, split into free and paid tiers.

    ``paid_models`` is every entry of ``AVAILABLE_MODELS`` that is not in
    ``FREE_MODELS``, kept in ``AVAILABLE_MODELS`` order.
    """
    # Filter by membership instead of slicing off len(FREE_MODELS) entries.
    # The slice is only correct when AVAILABLE_MODELS begins with exactly the
    # FREE_MODELS, in the same order; in that case both give the same result.
    free = set(FREE_MODELS)
    paid = [model for model in AVAILABLE_MODELS if model not in free]
    return {
        "free_models": FREE_MODELS,
        "paid_models": paid,
        "models": AVAILABLE_MODELS,
        "note": "Free models via OpenRouter free tier; paid models are affordable open-source.",
    }
@app.get("/benchmarks")
def list_benchmarks():
    """Describe the three benchmark tiers: which agents run and what data they see."""
    # Every tier starts from the same price-derived inputs; fresh lists are
    # built for each tier so no response shares a mutable list.
    price_inputs = ["OHLCV", "Technical Indicators"]

    baseline = dict(
        name="Baseline",
        description="Single agent: LLM sees price + indicators directly",
        agents=["Trader"],
        data=list(price_inputs),
    )
    intermediate = dict(
        name="Intermediate",
        description="Technical Analyst + News Analyst -> Trader",
        agents=["TechnicalAnalyst", "NewsAnalyst", "Trader"],
        data=[*price_inputs, "News"],
    )
    full_multi_agent = dict(
        name="Full Multi-Agent",
        description="Technical + Sentiment + News -> Researcher (bull/bear debate) -> Risk Manager -> Trader",
        agents=["TechnicalAnalyst", "SentimentAnalyst", "NewsAnalyst", "Researcher", "RiskManager", "Trader"],
        data=[*price_inputs, "News", "Fear & Greed", "Funding Rates"],
    )
    return {"benchmarks": {"A": baseline, "B": intermediate, "C": full_multi_agent}}
@app.post("/backtest")
def start_backtest(req: BacktestRequest, background_tasks: BackgroundTasks):
    """Validate *req*, register a new run and schedule it as a background task.

    Returns immediately with the new ``run_id``. The backtest itself is executed
    by ``_run_backtest_task`` via FastAPI's ``BackgroundTasks``.

    Raises:
        HTTPException: 422 when ``req.validate_fields()`` rejects the request.
    """
    try:
        req.validate_fields()
    except ValueError as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=422, detail=str(e)) from e
    run_id = create_run(req.benchmark, req.model, req.asset, req.start_date, req.end_date)
    background_tasks.add_task(_run_backtest_task, run_id, req)
    return {
        "run_id": run_id,
        "status": "running",
        # NOTE(review): not an f-string, so clients receive the literal
        # "{run_id}" route template; confirm that is intended.
        "message": "Backtest started. Poll /runs/{run_id} for results.",
    }
@app.get("/runs/{run_id}")
def get_run_detail(run_id: str):
    """Return one stored run; ``equity_curve`` and ``hodl_curve`` are always lists.

    Raises:
        HTTPException: 404 if no run with *run_id* exists.
    """
    run = get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")
    curve_keys = ("equity_curve", "hodl_curve")
    run_out = {k: v for k, v in run.items() if k not in curve_keys}
    # The curves go last and are normalised to lists. `.get(key, [])` only
    # covers a *missing* key, so a stored None used to reach clients as JSON
    # null. Same `... or {}` pattern /leaderboard uses for metrics.
    for key in curve_keys:
        run_out[key] = run.get(key) or []
    return run_out
@app.get("/runs/{run_id}/decisions")
def get_run_decisions(run_id: str):
    """Return the recorded decisions for *run_id*.

    Raises:
        HTTPException: 404 if the run itself does not exist.
    """
    if not get_run(run_id):
        raise HTTPException(status_code=404, detail="Run not found")
    return {"run_id": run_id, "decisions": get_decisions(run_id)}
# Metric fields copied, in this order, from each run's ``metrics`` dict.
_LEADERBOARD_METRICS = (
    "cumulative_return",
    "sharpe_ratio",
    "sortino_ratio",
    "max_drawdown",
    "win_rate",
    "num_trades",
    "hodl_return",
    "alpha",
    "final_value",
)


def _leaderboard_row(run):
    """Flatten one stored run (identity, date range, metrics) into a leaderboard row."""
    metrics = run.get("metrics", {}) or {}
    row = {
        "run_id": run["id"],
        "benchmark": run["benchmark"],
        "model": run["model"],
        "asset": run["asset"],
        "start_date": run.get("start_date"),
        "end_date": run.get("end_date"),
    }
    row.update((name, metrics.get(name)) for name in _LEADERBOARD_METRICS)
    row["completed_at"] = run.get("completed_at")
    return row


@app.get("/leaderboard")
def leaderboard():
    """Return every run from ``get_leaderboard()`` as a flat row, plus the row count."""
    rows = [_leaderboard_row(entry) for entry in get_leaderboard()]
    return {"leaderboard": rows, "total": len(rows)}