Claude commited on
Commit
133d74c
·
1 Parent(s): 6909d06

Improve eval harness: shuffle samples, always write results

Browse files

- Samples are now randomly shuffled with --seed (default 42) for
reproducible but varied evaluation across runs
- Results JSONL always saved to data/eval_results/ with auto-generated
timestamp filename (or custom path with -o)
- First line of output is run metadata (settings, timestamp, error count)
- Default caption field is caption_cogvlm (vision model, not tag-derived)
- Added --no-shuffle flag for sequential sample order
- Added data/eval_results/ to .gitignore

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG

Files changed (2) hide show
  1. .gitignore +1 -0
  2. scripts/eval_pipeline.py +81 -34
.gitignore CHANGED
@@ -10,3 +10,4 @@ tf_idf_files_420.joblib
10
  e621FastTextModel010Replacement_small.bin
11
  tfidf_hnsw_artists.bin
12
  tfidf_hnsw_tags.bin
 
 
10
  e621FastTextModel010Replacement_small.bin
11
  tfidf_hnsw_artists.bin
12
  tfidf_hnsw_tags.bin
13
+ data/eval_results/
scripts/eval_pipeline.py CHANGED
@@ -10,14 +10,20 @@ Metrics computed:
10
  selected tags match the ground truth
11
 
12
  Usage:
13
- # Full end-to-end (Stage 1 + 2 + 3):
14
  python scripts/eval_pipeline.py --n 20
15
 
16
- # Skip Stage 1 LLM rewrite, use ground-truth tags as retrieval input:
 
 
 
17
  python scripts/eval_pipeline.py --n 20 --skip-rewrite
18
 
19
- # Use a specific caption field:
20
- python scripts/eval_pipeline.py --n 20 --caption-field caption_cogvlm
 
 
 
21
 
22
  Requires:
23
  - OPENROUTER_API_KEY env var (for Stage 1 rewrite and Stage 3 selection)
@@ -30,9 +36,11 @@ from __future__ import annotations
30
  import argparse
31
  import json
32
  import os
 
33
  import sys
34
  import time
35
  from dataclasses import dataclass, field
 
36
  from pathlib import Path
37
  from typing import Any, Dict, List, Optional, Set, Tuple
38
 
@@ -110,6 +118,8 @@ def run_eval(
110
  temperature: float = 0.0,
111
  max_tokens: int = 512,
112
  verbose: bool = False,
 
 
113
  ) -> List[SampleResult]:
114
 
115
  from psq_rag.llm.rewrite import llm_rewrite_prompt
@@ -125,11 +135,9 @@ def run_eval(
125
  print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
126
  sys.exit(1)
127
 
128
- samples = []
129
  with EVAL_DATA_PATH.open("r", encoding="utf-8") as f:
130
  for line in f:
131
- if len(samples) >= n_samples:
132
- break
133
  row = json.loads(line)
134
  caption = row.get(caption_field, "")
135
  if not caption or not caption.strip():
@@ -137,14 +145,20 @@ def run_eval(
137
  gt_tags = _flatten_ground_truth_tags(row.get("tags_ground_truth_categorized", ""))
138
  if not gt_tags:
139
  continue
140
- samples.append({
141
- "id": row.get("id", row.get("row_id", len(samples))),
142
  "caption": caption.strip(),
143
  "gt_tags": gt_tags,
144
  })
145
 
146
- print(f"Loaded {len(samples)} samples (caption_field={caption_field})")
147
- print(f"skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}")
 
 
 
 
 
 
148
  print()
149
 
150
  results: List[SampleResult] = []
@@ -340,7 +354,13 @@ def main(argv=None) -> int:
340
  ap.add_argument("--max-tokens", type=int, default=512)
341
  ap.add_argument("--verbose", "-v", action="store_true", help="Show per-call Stage 3 logs")
342
  ap.add_argument("--output", "-o", type=str, default=None,
343
- help="Save detailed results as JSONL to this path")
 
 
 
 
 
 
344
 
345
  args = ap.parse_args(list(argv) if argv is not None else None)
346
 
@@ -355,34 +375,61 @@ def main(argv=None) -> int:
355
  temperature=args.temperature,
356
  max_tokens=args.max_tokens,
357
  verbose=args.verbose,
 
 
358
  )
359
 
360
  print_summary(results)
361
 
362
- # Optionally save detailed results
363
  if args.output:
364
  out_path = Path(args.output)
365
- out_path.parent.mkdir(parents=True, exist_ok=True)
366
- with out_path.open("w", encoding="utf-8") as f:
367
- for r in results:
368
- row = {
369
- "sample_id": r.sample_id,
370
- "caption": r.caption,
371
- "ground_truth_tags": sorted(r.ground_truth_tags),
372
- "rewrite_phrases": r.rewrite_phrases,
373
- "retrieved_tags": sorted(r.retrieved_tags),
374
- "selected_tags": sorted(r.selected_tags),
375
- "retrieval_recall": round(r.retrieval_recall, 4),
376
- "selection_precision": round(r.selection_precision, 4),
377
- "selection_recall": round(r.selection_recall, 4),
378
- "selection_f1": round(r.selection_f1, 4),
379
- "stage1_time": round(r.stage1_time, 3),
380
- "stage2_time": round(r.stage2_time, 3),
381
- "stage3_time": round(r.stage3_time, 3),
382
- "error": r.error,
383
- }
384
- f.write(json.dumps(row, ensure_ascii=False) + "\n")
385
- print(f"\nDetailed results saved to: {out_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
  return 0
388
 
 
10
  selected tags match the ground truth
11
 
12
  Usage:
13
+ # Full end-to-end (Stage 1 + 2 + 3), 20 random samples:
14
  python scripts/eval_pipeline.py --n 20
15
 
16
+ # Reproducible run with specific seed:
17
+ python scripts/eval_pipeline.py --n 50 --seed 123
18
+
19
+ # Skip Stage 1 LLM rewrite (cheaper, tests Stage 2+3 only):
20
  python scripts/eval_pipeline.py --n 20 --skip-rewrite
21
 
22
+ # First N samples in file order (no shuffle):
23
+ python scripts/eval_pipeline.py --n 20 --no-shuffle
24
+
25
+ Results are always saved as JSONL to data/eval_results/ (auto-named by timestamp)
26
+ or to a custom path with -o.
27
 
28
  Requires:
29
  - OPENROUTER_API_KEY env var (for Stage 1 rewrite and Stage 3 selection)
 
36
  import argparse
37
  import json
38
  import os
39
+ import random
40
  import sys
41
  import time
42
  from dataclasses import dataclass, field
43
+ from datetime import datetime
44
  from pathlib import Path
45
  from typing import Any, Dict, List, Optional, Set, Tuple
46
 
 
118
  temperature: float = 0.0,
119
  max_tokens: int = 512,
120
  verbose: bool = False,
121
+ shuffle: bool = True,
122
+ seed: int = 42,
123
  ) -> List[SampleResult]:
124
 
125
  from psq_rag.llm.rewrite import llm_rewrite_prompt
 
135
  print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
136
  sys.exit(1)
137
 
138
+ all_samples = []
139
  with EVAL_DATA_PATH.open("r", encoding="utf-8") as f:
140
  for line in f:
 
 
141
  row = json.loads(line)
142
  caption = row.get(caption_field, "")
143
  if not caption or not caption.strip():
 
145
  gt_tags = _flatten_ground_truth_tags(row.get("tags_ground_truth_categorized", ""))
146
  if not gt_tags:
147
  continue
148
+ all_samples.append({
149
+ "id": row.get("id", row.get("row_id", len(all_samples))),
150
  "caption": caption.strip(),
151
  "gt_tags": gt_tags,
152
  })
153
 
154
+ if shuffle:
155
+ rng = random.Random(seed)
156
+ rng.shuffle(all_samples)
157
+
158
+ samples = all_samples[:n_samples]
159
+
160
+ print(f"Loaded {len(samples)}/{len(all_samples)} samples (caption_field={caption_field})")
161
+ print(f"shuffle={shuffle}, seed={seed}, skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}")
162
  print()
163
 
164
  results: List[SampleResult] = []
 
354
  ap.add_argument("--max-tokens", type=int, default=512)
355
  ap.add_argument("--verbose", "-v", action="store_true", help="Show per-call Stage 3 logs")
356
  ap.add_argument("--output", "-o", type=str, default=None,
357
+ help="Save detailed results as JSONL (default: auto-generated in data/eval_results/)")
358
+ ap.add_argument("--shuffle", action="store_true", default=True,
359
+ help="Randomly shuffle samples before selecting (default: True)")
360
+ ap.add_argument("--no-shuffle", dest="shuffle", action="store_false",
361
+ help="Use samples in file order (first N)")
362
+ ap.add_argument("--seed", type=int, default=42,
363
+ help="Random seed for shuffle (default: 42)")
364
 
365
  args = ap.parse_args(list(argv) if argv is not None else None)
366
 
 
375
  temperature=args.temperature,
376
  max_tokens=args.max_tokens,
377
  verbose=args.verbose,
378
+ shuffle=args.shuffle,
379
+ seed=args.seed,
380
  )
381
 
382
  print_summary(results)
383
 
384
+ # Always save detailed results
385
  if args.output:
386
  out_path = Path(args.output)
387
+ else:
388
+ results_dir = _REPO_ROOT / "data" / "eval_results"
389
+ results_dir.mkdir(parents=True, exist_ok=True)
390
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
391
+ out_path = results_dir / f"eval_{args.caption_field}_n{args.n}_seed{args.seed}_{timestamp}.jsonl"
392
+
393
+ out_path.parent.mkdir(parents=True, exist_ok=True)
394
+
395
+ # Write run metadata as first line
396
+ meta = {
397
+ "_meta": True,
398
+ "timestamp": datetime.now().isoformat(),
399
+ "n_samples": len(results),
400
+ "caption_field": args.caption_field,
401
+ "skip_rewrite": args.skip_rewrite,
402
+ "allow_nsfw": args.allow_nsfw,
403
+ "mode": args.mode,
404
+ "chunk_size": args.chunk_size,
405
+ "per_phrase_k": args.per_phrase_k,
406
+ "temperature": args.temperature,
407
+ "shuffle": args.shuffle,
408
+ "seed": args.seed,
409
+ "n_errors": sum(1 for r in results if r.error),
410
+ }
411
+
412
+ with out_path.open("w", encoding="utf-8") as f:
413
+ f.write(json.dumps(meta, ensure_ascii=False) + "\n")
414
+ for r in results:
415
+ row = {
416
+ "sample_id": r.sample_id,
417
+ "caption": r.caption,
418
+ "ground_truth_tags": sorted(r.ground_truth_tags),
419
+ "rewrite_phrases": r.rewrite_phrases,
420
+ "retrieved_tags": sorted(r.retrieved_tags),
421
+ "selected_tags": sorted(r.selected_tags),
422
+ "retrieval_recall": round(r.retrieval_recall, 4),
423
+ "selection_precision": round(r.selection_precision, 4),
424
+ "selection_recall": round(r.selection_recall, 4),
425
+ "selection_f1": round(r.selection_f1, 4),
426
+ "stage1_time": round(r.stage1_time, 3),
427
+ "stage2_time": round(r.stage2_time, 3),
428
+ "stage3_time": round(r.stage3_time, 3),
429
+ "error": r.error,
430
+ }
431
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
432
+ print(f"\nDetailed results saved to: {out_path}")
433
 
434
  return 0
435