#!/usr/bin/env python3
"""
Model benchmark harness.

Runs every audio file in `fixtures/{ai,human}/` next to this script (i.e.
`backend/tests/fixtures/`) against the deployed backend's `/analyze/upload`
endpoint and reports:

  * Confusion matrix (TP / FP / TN / FN)
  * Accuracy, precision, recall, F1
  * Per-clip table: expected vs. observed + raw scores
  * wav2vec2 score distribution per class (text bar chart)
  * Optional threshold sweep (--sweep)
  * Optional CSV export for spreadsheet analysis (--csv)

Usage:
    export DETECTOR_API_URL='https://michal-giza-audio-detector-backend.hf.space'
    export DETECTOR_API_KEY='...'

    # 1. Drop AI clips into    backend/tests/fixtures/ai/
    # 2. Drop HUMAN clips into backend/tests/fixtures/human/
    #    (accepted extensions: .mp3 .wav .m4a .aac .flac .ogg)
    # 3. Run:
    python3 benchmark.py                     # verbose
    python3 benchmark.py --csv results.csv   # also write CSV
    python3 benchmark.py --threshold 0.65    # explore other decision thresholds
    python3 benchmark.py --sweep             # also show metrics at 0.1..0.9

Exit codes:
    0  benchmark completed (regardless of model quality)
    1  every clip failed to analyze
    2  setup problem: DETECTOR_API_URL / DETECTOR_API_KEY unset,
       `requests` not installed, or no fixtures present
"""

import argparse
import csv
import math
import os
import sys
import time
from pathlib import Path
from typing import Iterator

# NOTE: `requests` is imported lazily (see analyze_clip / main) so the pure
# metric and histogram helpers can be imported without the HTTP client.

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

GREEN = "\033[92m"
RED = "\033[91m"
YELLOW = "\033[93m"
CYAN = "\033[96m"
BOLD = "\033[1m"
RESET = "\033[0m"

BASE_URL = os.environ.get("DETECTOR_API_URL", "").rstrip("/")
API_KEY = os.environ.get("DETECTOR_API_KEY", "")

FIXTURES_DIR = Path(__file__).parent / "fixtures"
AUDIO_EXTS = {".mp3", ".wav", ".m4a", ".aac", ".flac", ".ogg"}
MIME_FOR_EXT = {
    ".mp3": "audio/mpeg",
    ".wav": "audio/wav",
    ".m4a": "audio/mp4",
    ".aac": "audio/aac",
    ".flac": "audio/flac",
    ".ogg": "audio/ogg",
}


# ---------------------------------------------------------------------------
# Model-under-test wrapper
# ---------------------------------------------------------------------------

def analyze_clip(path: Path, timeout: int = 120) -> dict:
    """POST one clip to /analyze/upload and return the decoded JSON body.

    Args:
        path: audio file to upload; its MIME type is chosen from the suffix
            via MIME_FOR_EXT (falling back to application/octet-stream).
        timeout: request timeout in seconds, passed through to requests.

    Raises:
        RuntimeError: the endpoint answered with a non-200 status.
        Anything raised by `requests` (transport errors) or by
        `Response.json()` (body is not valid JSON) propagates unchanged.
    """
    # Imported here so the rest of the module works without `requests`.
    import requests

    mime = MIME_FOR_EXT.get(path.suffix.lower(), "application/octet-stream")
    with path.open("rb") as f:
        resp = requests.post(
            f"{BASE_URL}/analyze/upload",
            headers={"X-Api-Key": API_KEY},
            files={"file": (path.name, f, mime)},
            timeout=timeout,
        )
    if resp.status_code != 200:
        raise RuntimeError(f"HTTP {resp.status_code}: {resp.text[:200]}")
    return resp.json()


def iter_fixtures() -> Iterator[tuple[Path, bool]]:
    """Yield (path, is_ai_expected) for every clip under fixtures/.

    AI clips come first, then human clips; each directory is walked in
    sorted order and only files with a suffix in AUDIO_EXTS are yielded.
    A missing subdirectory is skipped silently.
    """
    for subdir, expected_ai in [("ai", True), ("human", False)]:
        root = FIXTURES_DIR / subdir
        if not root.exists():
            continue
        for p in sorted(root.iterdir()):
            if p.is_file() and p.suffix.lower() in AUDIO_EXTS:
                yield p, expected_ai


# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------

def compute_metrics(rows: list[dict], threshold: float) -> dict:
    """Compute confusion matrix + derived rates for a given decision threshold.

    Each row needs a numeric "confidence" and a boolean "expected_ai"; a row
    is predicted AI when confidence >= threshold. Callers should pass only
    rows with a finite confidence (a NaN would count as a HUMAN prediction).
    Ratios whose denominator is zero are reported as 0.0.
    """
    tp = fp = tn = fn = 0
    for r in rows:
        predicted_ai = r["confidence"] >= threshold
        actual_ai = r["expected_ai"]
        if predicted_ai and actual_ai:
            tp += 1
        elif predicted_ai and not actual_ai:
            fp += 1
        elif not predicted_ai and not actual_ai:
            tn += 1
        else:
            fn += 1

    total = tp + fp + tn + fn
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) else 0.0
    )
    return {
        "tp": tp, "fp": fp, "tn": tn, "fn": fn,
        "accuracy": accuracy, "precision": precision,
        "recall": recall, "f1": f1, "total": total,
    }


def text_histogram(values: list[float], width: int = 40, buckets: int = 20) -> str:
    """Render a text bar chart of scores expected to lie in [0, 1].

    Non-finite values (NaN/inf, e.g. a score the API did not return) are
    ignored. Finite values outside [0, 1] are clamped into the first/last
    bucket, and the last bucket is closed on the right so 1.0 is counted.
    The tallest bar is `width` characters wide. Returns "(no data)" when
    nothing is left to plot.
    """
    finite = [v for v in values if math.isfinite(v)]
    if not finite:
        return "(no data)"
    counts = [0] * buckets
    for v in finite:
        # Clamp both ends: a negative index would silently hit the last bucket.
        idx = min(max(int(v * buckets), 0), buckets - 1)
        counts[idx] += 1
    peak = max(counts)  # >= 1 because `finite` is non-empty
    lines = []
    for i, c in enumerate(counts):
        lo = i / buckets
        hi = (i + 1) / buckets
        close = "]" if i == buckets - 1 else ")"
        bar = "█" * int(c / peak * width)
        lines.append(f"  [{lo:.2f}-{hi:.2f}{close}  {c:3d}  {bar}")
    return "\n".join(lines)


def _as_score(value: object) -> float:
    """Coerce an optional score from the API response to float (NaN if None)."""
    return float("nan") if value is None else float(value)


# ---------------------------------------------------------------------------
# Report sections
# ---------------------------------------------------------------------------

def _run_all(fixtures: list[tuple[Path, bool]], threshold: float) -> list[dict]:
    """Analyze every fixture, print one table row per clip, and return the rows.

    A clip that fails (HTTP/transport error, invalid JSON, missing or
    non-finite confidence, malformed details) is still returned, with NaN
    scores and an "error" key, so it shows up in the CSV.
    """
    rows: list[dict] = []
    print(f"{BOLD}{'path':<45} {'expect':<7} {'conf':<6} {'wav2vec':<7} {'fp':<6} {'verdict':<7}{RESET}")
    print("-" * 86)

    for path, expected_ai in fixtures:
        rel = str(path.relative_to(FIXTURES_DIR))
        try:
            start = time.perf_counter()
            body = analyze_clip(path)
            elapsed = time.perf_counter() - start
            conf = float(body["confidence"])
            if not math.isfinite(conf):
                raise ValueError(f"non-finite confidence: {conf!r}")
            # `details` may be absent or explicitly null in the response.
            details = body.get("details") or {}
            wav2vec = _as_score(details.get("wav2vec2_score"))
            fp_score = _as_score(details.get("fingerprint_score"))
        except Exception as e:  # per-clip boundary: record the failure, keep going
            print(f"{RED}{rel:<45} ERROR: {e}{RESET}")
            rows.append({
                "path": rel,
                "expected_ai": expected_ai,
                "confidence": float("nan"),
                "wav2vec2_score": float("nan"),
                "fingerprint_score": float("nan"),
                "elapsed_seconds": 0.0,
                "error": str(e),
            })
            continue

        predicted = conf >= threshold
        correct = predicted == expected_ai
        verdict = "AI" if predicted else "HUMAN"
        color = GREEN if correct else RED
        exp_label = "AI" if expected_ai else "HUMAN"
        print(
            f"{color}{rel:<45} {exp_label:<7} {conf:<6.3f} "
            f"{wav2vec:<7.3f} {fp_score:<6.3f} {verdict:<7}{RESET} "
            f"({elapsed:.1f}s)"
        )
        rows.append({
            "path": rel,
            "expected_ai": expected_ai,
            "confidence": conf,
            "wav2vec2_score": wav2vec,
            "fingerprint_score": fp_score,
            "elapsed_seconds": elapsed,
        })
    return rows


def _print_metrics(metrics: dict, threshold: float) -> None:
    """Print the confusion matrix and the derived rates."""
    print()
    print(f"{BOLD}Confusion matrix @ threshold={threshold}{RESET}")
    print("                  predicted AI   predicted HUMAN")
    print(f"  actual AI       {metrics['tp']:>12d}   {metrics['fn']:>15d}")
    print(f"  actual HUMAN    {metrics['fp']:>12d}   {metrics['tn']:>15d}")
    print()
    print(f"  accuracy   {metrics['accuracy']:.3f}")
    print(f"  precision  {metrics['precision']:.3f}   (of predicted-AI, how many were AI)")
    print(f"  recall     {metrics['recall']:.3f}   (of actual-AI, how many we caught)")
    print(f"  f1         {metrics['f1']:.3f}")


def _print_distributions(clean: list[dict]) -> None:
    """Print per-class wav2vec2 histograms and a mean-separation verdict."""
    # Drop missing (NaN) scores: they would crash the histogram and make the
    # means NaN, which compares False everywhere and reads as "meaningful".
    ai_scores = [r["wav2vec2_score"] for r in clean
                 if r["expected_ai"] and math.isfinite(r["wav2vec2_score"])]
    human_scores = [r["wav2vec2_score"] for r in clean
                    if not r["expected_ai"] and math.isfinite(r["wav2vec2_score"])]

    print()
    print(f"{BOLD}wav2vec2 score distribution — AI clips (n={len(ai_scores)}){RESET}")
    print(text_histogram(ai_scores))
    print()
    print(f"{BOLD}wav2vec2 score distribution — HUMAN clips (n={len(human_scores)}){RESET}")
    print(text_histogram(human_scores))
    print()

    # Quick sanity read — overlapping means = model doesn't discriminate.
    if ai_scores and human_scores:
        mean_ai = sum(ai_scores) / len(ai_scores)
        mean_human = sum(human_scores) / len(human_scores)
        separation = abs(mean_ai - mean_human)
        print(f"  mean(AI wav2vec2)    = {mean_ai:.3f}")
        print(f"  mean(HUMAN wav2vec2) = {mean_human:.3f}")
        print(f"  separation           = {separation:.3f}")
        if separation < 0.1:
            print(f"  {RED}→ model does not discriminate — replace it.{RESET}")
        elif separation < 0.3:
            print(f"  {YELLOW}→ weak discrimination — consider alternatives.{RESET}")
        else:
            print(f"  {GREEN}→ meaningful discrimination.{RESET}")


def _print_sweep(clean: list[dict]) -> None:
    """Print accuracy/precision/recall/F1 at thresholds 0.1 .. 0.9."""
    print()
    print(f"{BOLD}Threshold sweep{RESET}")
    print(f"  {'t':<6} {'accuracy':<10} {'precision':<11} {'recall':<8} {'f1':<6}")
    for t in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9):
        m = compute_metrics(clean, t)
        print(f"  {t:<6.2f} {m['accuracy']:<10.3f} "
              f"{m['precision']:<11.3f} {m['recall']:<8.3f} {m['f1']:<6.3f}")


def _write_csv(path: Path, rows: list[dict]) -> None:
    """Write one CSV line per clip (failed clips included) to *path*."""
    fieldnames = [
        "path", "expected_ai", "confidence",
        "wav2vec2_score", "fingerprint_score",
        "elapsed_seconds", "error",
    ]
    # newline="" is what the csv module requires; utf-8 keeps non-ASCII paths safe.
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for r in rows:
            writer.writerow({k: r.get(k, "") for k in fieldnames})
    print()
    print(f"CSV written to {path}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> int:
    """CLI entry point; returns the process exit code (see module docstring)."""
    parser = argparse.ArgumentParser(
        description="Benchmark the deployed audio detector against labelled fixtures.",
    )
    parser.add_argument("--threshold", type=float, default=0.5,
                        help="Decision threshold on `confidence` (default 0.5)")
    parser.add_argument("--csv", type=Path, default=None,
                        help="Optional CSV export path")
    parser.add_argument("--sweep", action="store_true",
                        help="Also show metrics at 9 thresholds 0.1..0.9")
    args = parser.parse_args()

    if not BASE_URL or not API_KEY:
        print("DETECTOR_API_URL and DETECTOR_API_KEY must be set.", file=sys.stderr)
        return 2

    # Fail fast with one clear message instead of an identical error per clip.
    try:
        import requests  # noqa: F401  (availability check only)
    except ImportError:
        print("The `requests` package is required: pip install requests", file=sys.stderr)
        return 2

    fixtures = list(iter_fixtures())
    if not fixtures:
        print(f"{YELLOW}No fixtures found in {FIXTURES_DIR}/.{RESET}", file=sys.stderr)
        print("  Expected layout:", file=sys.stderr)
        print(f"    {FIXTURES_DIR}/ai/*.mp3", file=sys.stderr)
        print(f"    {FIXTURES_DIR}/human/*.mp3", file=sys.stderr)
        return 2

    n_ai = sum(1 for _, is_ai in fixtures if is_ai)
    n_human = len(fixtures) - n_ai
    print(f"{BOLD}Benchmark — {BASE_URL}{RESET}")
    print(f"  fixtures:  {len(fixtures)}  ({n_ai} AI, {n_human} human)")
    print(f"  threshold: {args.threshold}")
    print()

    rows = _run_all(fixtures, args.threshold)

    clean = [r for r in rows if "error" not in r]
    if not clean:
        print(f"\n{RED}No successful runs.{RESET}")
        return 1

    _print_metrics(compute_metrics(clean, args.threshold), args.threshold)
    _print_distributions(clean)
    if args.sweep:
        _print_sweep(clean)
    if args.csv:
        _write_csv(args.csv, rows)
    return 0


if __name__ == "__main__":
    sys.exit(main())