"""End-to-end evaluation harness for the Prompt Squirrel RAG pipeline. Measures per-stage and overall metrics using ground-truth tagged samples from the e621 evaluation dataset. Metrics computed: - Stage 2 (Retrieval): Recall@k — what fraction of ground-truth tags appear among the retrieved candidates - Stage 3 (Selection): Precision, Recall, F1 — how well the final selected tags match the ground truth Usage: # Full end-to-end (Stage 1 + 2 + 3): python scripts/eval_pipeline.py --n 20 # Skip Stage 1 LLM rewrite, use ground-truth tags as retrieval input: python scripts/eval_pipeline.py --n 20 --skip-rewrite # Use a specific caption field: python scripts/eval_pipeline.py --n 20 --caption-field caption_cogvlm Requires: - OPENROUTER_API_KEY env var (for Stage 1 rewrite and Stage 3 selection) - fluffyrock_3m.csv and other retrieval assets in the project root - data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000.jsonl """ from __future__ import annotations import argparse import json import os import sys import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple _REPO_ROOT = Path(__file__).resolve().parents[1] if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) os.chdir(_REPO_ROOT) EVAL_DATA_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl" def _flatten_ground_truth_tags(tags_categorized_str: str) -> Set[str]: """Parse the categorized ground-truth JSON string into a flat set of tags.""" if not tags_categorized_str: return set() try: cats = json.loads(tags_categorized_str) except json.JSONDecodeError: return set() tags = set() for tag_list in cats.values(): if isinstance(tag_list, list): for t in tag_list: tags.add(t.strip()) return tags @dataclass class SampleResult: sample_id: Any caption: str ground_truth_tags: Set[str] # Stage 1 rewrite_phrases: List[str] = field(default_factory=list) # Stage 2 retrieved_tags: Set[str] = field(default_factory=set) retrieval_recall: float = 0.0 # Stage 3 selected_tags: Set[str] = field(default_factory=set) selection_precision: float = 0.0 selection_recall: float = 0.0 selection_f1: float = 0.0 # Timing stage1_time: float = 0.0 stage2_time: float = 0.0 stage3_time: float = 0.0 # Errors error: Optional[str] = None def _compute_metrics(predicted: Set[str], ground_truth: Set[str]) -> Tuple[float, float, float]: """Compute precision, recall, F1.""" if not predicted and not ground_truth: return 1.0, 1.0, 1.0 if not predicted: return 0.0, 0.0, 0.0 if not ground_truth: return 0.0, 0.0, 0.0 tp = len(predicted & ground_truth) precision = tp / len(predicted) recall = tp / len(ground_truth) f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 return precision, recall, f1 def run_eval( n_samples: int = 20, caption_field: str = "caption_cogvlm", skip_rewrite: bool = False, allow_nsfw: bool = False, mode: str = "chunked_map_union", chunk_size: int = 60, per_phrase_k: int = 2, temperature: float = 0.0, max_tokens: int = 512, verbose: bool = False, ) -> List[SampleResult]: from psq_rag.llm.rewrite import llm_rewrite_prompt from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases from psq_rag.llm.select import llm_select_indices def log(msg: str) -> None: if verbose: print(f" {msg}") # Load eval samples if not EVAL_DATA_PATH.is_file(): print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}") sys.exit(1) samples = [] with EVAL_DATA_PATH.open("r", encoding="utf-8") as f: for line in f: if len(samples) >= n_samples: break row = json.loads(line) caption = row.get(caption_field, "") if not caption or not caption.strip(): continue gt_tags = _flatten_ground_truth_tags(row.get("tags_ground_truth_categorized", "")) if not gt_tags: continue samples.append({ "id": row.get("id", row.get("row_id", len(samples))), "caption": caption.strip(), "gt_tags": gt_tags, }) print(f"Loaded {len(samples)} samples (caption_field={caption_field})") print(f"skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}") print() results: List[SampleResult] = [] for i, sample in enumerate(samples): sid = sample["id"] caption = sample["caption"] gt_tags = sample["gt_tags"] result = SampleResult( sample_id=sid, caption=caption[:120] + ("..." if len(caption) > 120 else ""), ground_truth_tags=gt_tags, ) print(f"[{i+1}/{len(samples)}] id={sid} gt_tags={len(gt_tags)}") try: # --- Stage 1: LLM Rewrite --- if skip_rewrite: # Use the caption directly as comma-separated phrases phrases = [p.strip() for p in caption.split(",") if p.strip()] # Also split on periods/sentences for natural language captions if len(phrases) <= 1: phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()] result.rewrite_phrases = phrases result.stage1_time = 0.0 else: t0 = time.time() rewritten = llm_rewrite_prompt(caption, log) result.stage1_time = time.time() - t0 if rewritten: result.rewrite_phrases = [p.strip() for p in rewritten.split(",") if p.strip()] else: result.rewrite_phrases = [p.strip() for p in caption.split(",") if p.strip()] if len(result.rewrite_phrases) <= 1: result.rewrite_phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()] if verbose: log(f"Phrases ({len(result.rewrite_phrases)}): {result.rewrite_phrases[:5]}") # --- Stage 2: Retrieval --- t0 = time.time() retrieval_result = psq_candidates_from_rewrite_phrases( rewrite_phrases=result.rewrite_phrases, allow_nsfw_tags=allow_nsfw, global_k=300, verbose=False, ) result.stage2_time = time.time() - t0 if isinstance(retrieval_result, tuple): candidates, _ = retrieval_result else: candidates = retrieval_result result.retrieved_tags = {c.tag for c in candidates} # Retrieval recall: what fraction of ground truth was retrieved if gt_tags: result.retrieval_recall = len(result.retrieved_tags & gt_tags) / len(gt_tags) if verbose: log(f"Retrieved {len(candidates)} candidates, recall={result.retrieval_recall:.3f}") # --- Stage 3: LLM Selection --- t0 = time.time() picked_indices = llm_select_indices( query_text=caption, candidates=candidates, max_pick=0, log=log, mode=mode, chunk_size=chunk_size, per_phrase_k=per_phrase_k, temperature=temperature, max_tokens=max_tokens, ) result.stage3_time = time.time() - t0 result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set() # Selection metrics p, r, f1 = _compute_metrics(result.selected_tags, gt_tags) result.selection_precision = p result.selection_recall = r result.selection_f1 = f1 print( f" retrieval_recall={result.retrieval_recall:.3f} " f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} " f"selected={len(result.selected_tags)} " f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s" ) except Exception as e: result.error = str(e) print(f" ERROR: {e}") results.append(result) return results def print_summary(results: List[SampleResult]) -> None: """Print aggregate metrics across all samples.""" valid = [r for r in results if r.error is None] errored = [r for r in results if r.error is not None] if not valid: print("\nNo valid results to summarize.") return n = len(valid) avg_retrieval_recall = sum(r.retrieval_recall for r in valid) / n avg_sel_precision = sum(r.selection_precision for r in valid) / n avg_sel_recall = sum(r.selection_recall for r in valid) / n avg_sel_f1 = sum(r.selection_f1 for r in valid) / n avg_retrieved = sum(len(r.retrieved_tags) for r in valid) / n avg_selected = sum(len(r.selected_tags) for r in valid) / n avg_gt = sum(len(r.ground_truth_tags) for r in valid) / n avg_t1 = sum(r.stage1_time for r in valid) / n avg_t2 = sum(r.stage2_time for r in valid) / n avg_t3 = sum(r.stage3_time for r in valid) / n print() print("=" * 60) print(f"EVALUATION SUMMARY ({n} samples, {len(errored)} errors)") print("=" * 60) print() print("Stage 2 - Retrieval:") print(f" Avg recall@300: {avg_retrieval_recall:.4f}") print(f" Avg candidates: {avg_retrieved:.1f}") print() print("Stage 3 - Selection (final output):") print(f" Avg precision: {avg_sel_precision:.4f}") print(f" Avg recall: {avg_sel_recall:.4f}") print(f" Avg F1: {avg_sel_f1:.4f}") print(f" Avg selected tags: {avg_selected:.1f}") print(f" Avg ground-truth tags:{avg_gt:.1f}") print() print("Timing (avg per sample):") print(f" Stage 1 (rewrite): {avg_t1:.2f}s") print(f" Stage 2 (retrieval): {avg_t2:.2f}s") print(f" Stage 3 (selection): {avg_t3:.2f}s") print(f" Total: {avg_t1 + avg_t2 + avg_t3:.2f}s") print() # Show worst and best F1 samples by_f1 = sorted(valid, key=lambda r: r.selection_f1) print("Lowest F1 samples:") for r in by_f1[:3]: print(f" id={r.sample_id} F1={r.selection_f1:.3f} P={r.selection_precision:.3f} R={r.selection_recall:.3f}") missed = r.ground_truth_tags - r.selected_tags extra = r.selected_tags - r.ground_truth_tags if missed: print(f" missed: {sorted(missed)[:10]}") if extra: print(f" extra: {sorted(extra)[:10]}") print() print("Highest F1 samples:") for r in by_f1[-3:]: print(f" id={r.sample_id} F1={r.selection_f1:.3f} P={r.selection_precision:.3f} R={r.selection_recall:.3f}") if errored: print() print(f"Errors ({len(errored)}):") for r in errored[:5]: print(f" id={r.sample_id}: {r.error}") print("=" * 60) def main(argv=None) -> int: ap = argparse.ArgumentParser(description="End-to-end pipeline evaluation") ap.add_argument("--n", type=int, default=20, help="Number of samples to evaluate") ap.add_argument("--caption-field", default="caption_cogvlm", choices=["caption_cogvlm", "caption_llm_0", "caption_llm_1", "caption_llm_2", "caption_llm_3", "caption_llm_4", "caption_llm_5", "caption_llm_6", "caption_llm_7"], help="Which caption field to use as input") ap.add_argument("--skip-rewrite", action="store_true", help="Skip Stage 1 LLM rewrite; split caption directly into phrases") ap.add_argument("--allow-nsfw", action="store_true", help="Allow NSFW tags") ap.add_argument("--mode", default="chunked_map_union", choices=["single_shot", "chunked_map_union"]) ap.add_argument("--chunk-size", type=int, default=60) ap.add_argument("--per-phrase-k", type=int, default=2) ap.add_argument("--temperature", type=float, default=0.0) ap.add_argument("--max-tokens", type=int, default=512) ap.add_argument("--verbose", "-v", action="store_true", help="Show per-call Stage 3 logs") ap.add_argument("--output", "-o", type=str, default=None, help="Save detailed results as JSONL to this path") args = ap.parse_args(list(argv) if argv is not None else None) results = run_eval( n_samples=args.n, caption_field=args.caption_field, skip_rewrite=args.skip_rewrite, allow_nsfw=args.allow_nsfw, mode=args.mode, chunk_size=args.chunk_size, per_phrase_k=args.per_phrase_k, temperature=args.temperature, max_tokens=args.max_tokens, verbose=args.verbose, ) print_summary(results) # Optionally save detailed results if args.output: out_path = Path(args.output) out_path.parent.mkdir(parents=True, exist_ok=True) with out_path.open("w", encoding="utf-8") as f: for r in results: row = { "sample_id": r.sample_id, "caption": r.caption, "ground_truth_tags": sorted(r.ground_truth_tags), "rewrite_phrases": r.rewrite_phrases, "retrieved_tags": sorted(r.retrieved_tags), "selected_tags": sorted(r.selected_tags), "retrieval_recall": round(r.retrieval_recall, 4), "selection_precision": round(r.selection_precision, 4), "selection_recall": round(r.selection_recall, 4), "selection_f1": round(r.selection_f1, 4), "stage1_time": round(r.stage1_time, 3), "stage2_time": round(r.stage2_time, 3), "stage3_time": round(r.stage3_time, 3), "error": r.error, } f.write(json.dumps(row, ensure_ascii=False) + "\n") print(f"\nDetailed results saved to: {out_path}") return 0 if __name__ == "__main__": sys.exit(main())