Claude commited on
Commit
349b999
·
1 Parent(s): df66964

Add diagnostic eval metrics, why-distribution tracking, and generic character filter

Browse files

- select.py: return_metadata option exposes per-tag 'why' rationale from LLM
- select.py: Route generic character-category tags (fan_character, viewer, etc.)
to general pipeline instead of entity pipeline to reduce false positives
- eval_pipeline.py: New metrics — retrieval precision, selection-given-retrieval,
over-selection ratio, and why distribution breakdown in summary output
- eval_pipeline.py: All new metrics saved to JSONL output for analysis

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG

Files changed (2) hide show
  1. psq_rag/llm/select.py +29 -2
  2. scripts/eval_pipeline.py +53 -1
psq_rag/llm/select.py CHANGED
@@ -23,6 +23,19 @@ from rapidfuzz import fuzz
23
  from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources)
24
  from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
28
 
@@ -334,11 +347,18 @@ def _split_candidates_by_type(
334
  unknown_count = 0
335
  copyright_count = 0
336
 
 
 
337
  for idx, cand in enumerate(candidates):
338
  type_name = get_tag_type_name(cand.tag)
339
 
340
  if type_name == "character":
341
- entity_with_idx.append((idx, cand))
 
 
 
 
 
342
  elif type_name == "copyright":
343
  # Filter out copyright/series tags - too broad for image generation
344
  copyright_count += 1
@@ -355,6 +375,7 @@ def _split_candidates_by_type(
355
  f"general={len(general_with_idx)} "
356
  f"entity={len(entity_with_idx)} "
357
  f"copyright_filtered={copyright_count} "
 
358
  f"unknown_type={unknown_count}"
359
  )
360
 
@@ -462,7 +483,8 @@ def llm_select_indices(
462
  per_phrase_k: int = 2, # per-call budget = per_phrase_k * phrases_in_call
463
  temperature: float = 0.0,
464
  max_tokens: int = 512,
465
- ) -> List[int]:
 
466
  """Return indices into the ORIGINAL candidates list (legacy interface).
467
 
468
  This implementation uses LangChain ONLY.
@@ -704,8 +726,13 @@ def llm_select_indices(
704
 
705
  # Map back to original indices
706
  out_idx: List[int] = []
 
707
  for t in ordered_tags:
708
  if t in tag_to_first_index:
709
  out_idx.append(tag_to_first_index[t])
 
 
 
 
710
 
711
  return out_idx
 
23
  from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources)
24
  from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases
25
 
26
+ # Character-typed tags that are generic categories, not actual named characters.
27
+ # These leak through the alias filter because they match common words in captions.
28
+ # They are excluded from the entity pipeline and instead routed to general selection.
29
+ _GENERIC_CHARACTER_TAGS = frozenset({
30
+ "fan_character",
31
+ "background_character",
32
+ "unnamed_character",
33
+ "unknown_character",
34
+ "anonymous_character",
35
+ "viewer",
36
+ "original_character",
37
+ })
38
+
39
 
40
  WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
41
 
 
347
  unknown_count = 0
348
  copyright_count = 0
349
 
350
+ generic_char_count = 0
351
+
352
  for idx, cand in enumerate(candidates):
353
  type_name = get_tag_type_name(cand.tag)
354
 
355
  if type_name == "character":
356
+ if cand.tag in _GENERIC_CHARACTER_TAGS:
357
+ # Route generic character-category tags to general selection
358
+ general_with_idx.append((idx, cand))
359
+ generic_char_count += 1
360
+ else:
361
+ entity_with_idx.append((idx, cand))
362
  elif type_name == "copyright":
363
  # Filter out copyright/series tags - too broad for image generation
364
  copyright_count += 1
 
375
  f"general={len(general_with_idx)} "
376
  f"entity={len(entity_with_idx)} "
377
  f"copyright_filtered={copyright_count} "
378
+ f"generic_char_to_general={generic_char_count} "
379
  f"unknown_type={unknown_count}"
380
  )
381
 
 
483
  per_phrase_k: int = 2, # per-call budget = per_phrase_k * phrases_in_call
484
  temperature: float = 0.0,
485
  max_tokens: int = 512,
486
+ return_metadata: bool = False,
487
+ ) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
488
  """Return indices into the ORIGINAL candidates list (legacy interface).
489
 
490
  This implementation uses LangChain ONLY.
 
726
 
727
  # Map back to original indices
728
  out_idx: List[int] = []
729
+ tag_why: Dict[str, str] = {}
730
  for t in ordered_tags:
731
  if t in tag_to_first_index:
732
  out_idx.append(tag_to_first_index[t])
733
+ tag_why[t] = best[t][1] # why string
734
+
735
+ if return_metadata:
736
+ return out_idx, tag_why
737
 
738
  return out_idx
scripts/eval_pipeline.py CHANGED
@@ -127,6 +127,12 @@ class SampleResult:
127
  general_precision: float = 0.0
128
  general_recall: float = 0.0
129
  general_f1: float = 0.0
 
 
 
 
 
 
130
  # Timing
131
  stage1_time: float = 0.0
132
  stage2_time: float = 0.0
@@ -233,7 +239,7 @@ def _process_one_sample(
233
 
234
  # --- Stage 3: LLM Selection ---
235
  t0 = time.time()
236
- picked_indices = llm_select_indices(
237
  query_text=caption,
238
  candidates=candidates,
239
  max_pick=0,
@@ -243,17 +249,34 @@ def _process_one_sample(
243
  per_phrase_k=per_phrase_k,
244
  temperature=temperature,
245
  max_tokens=max_tokens,
 
246
  )
247
  result.stage3_time = time.time() - t0
248
 
249
  result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
250
 
 
 
 
 
 
 
251
  # Overall selection metrics
252
  p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
253
  result.selection_precision = p
254
  result.selection_recall = r
255
  result.selection_f1 = f1
256
 
 
 
 
 
 
 
 
 
 
 
257
  # Split ground-truth and selected tags by type
258
  gt_char, gt_gen = _classify_tags(gt_tags, get_tag_type_name)
259
  sel_char, sel_gen = _classify_tags(result.selected_tags, get_tag_type_name)
@@ -456,6 +479,11 @@ def print_summary(results: List[SampleResult]) -> None:
456
  print("Stage 2 - Retrieval:")
457
  print(f" Avg recall@300: {avg_retrieval_recall:.4f}")
458
  print(f" Avg candidates: {avg_retrieved:.1f}")
 
 
 
 
 
459
  print()
460
  print("Stage 3 - Selection (ALL tags):")
461
  print(f" Avg precision: {avg_sel_precision:.4f}")
@@ -463,6 +491,25 @@ def print_summary(results: List[SampleResult]) -> None:
463
  print(f" Avg F1: {avg_sel_f1:.4f}")
464
  print(f" Avg selected tags: {avg_selected:.1f}")
465
  print(f" Avg ground-truth tags:{avg_gt:.1f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
  # --- Character tag breakdown ---
468
  # Only include samples that actually have character tags in ground truth
@@ -678,6 +725,11 @@ def main(argv=None) -> int:
678
  "general_precision": round(r.general_precision, 4),
679
  "general_recall": round(r.general_recall, 4),
680
  "general_f1": round(r.general_f1, 4),
 
 
 
 
 
681
  # Timing
682
  "stage1_time": round(r.stage1_time, 3),
683
  "stage2_time": round(r.stage2_time, 3),
 
127
  general_precision: float = 0.0
128
  general_recall: float = 0.0
129
  general_f1: float = 0.0
130
+ # New diagnostic metrics
131
+ retrieval_precision: float = 0.0 # |retrieved ∩ gt| / |retrieved|
132
+ selection_given_retrieval: float = 0.0 # |selected ∩ gt| / |retrieved ∩ gt|
133
+ over_selection_ratio: float = 0.0 # |selected| / |gt|
134
+ # Why distribution (from Stage 3 LLM)
135
+ why_counts: Dict[str, int] = field(default_factory=dict)
136
  # Timing
137
  stage1_time: float = 0.0
138
  stage2_time: float = 0.0
 
239
 
240
  # --- Stage 3: LLM Selection ---
241
  t0 = time.time()
242
+ picked_indices, tag_why = llm_select_indices(
243
  query_text=caption,
244
  candidates=candidates,
245
  max_pick=0,
 
249
  per_phrase_k=per_phrase_k,
250
  temperature=temperature,
251
  max_tokens=max_tokens,
252
+ return_metadata=True,
253
  )
254
  result.stage3_time = time.time() - t0
255
 
256
  result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
257
 
258
+ # Why distribution
259
+ why_counts: Dict[str, int] = {}
260
+ for w in tag_why.values():
261
+ why_counts[w] = why_counts.get(w, 0) + 1
262
+ result.why_counts = why_counts
263
+
264
  # Overall selection metrics
265
  p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
266
  result.selection_precision = p
267
  result.selection_recall = r
268
  result.selection_f1 = f1
269
 
270
+ # New diagnostic metrics
271
+ retrieved_and_gt = result.retrieved_tags & gt_tags
272
+ selected_and_gt = result.selected_tags & gt_tags
273
+ if result.retrieved_tags:
274
+ result.retrieval_precision = len(retrieved_and_gt) / len(result.retrieved_tags)
275
+ if retrieved_and_gt:
276
+ result.selection_given_retrieval = len(selected_and_gt) / len(retrieved_and_gt)
277
+ if gt_tags:
278
+ result.over_selection_ratio = len(result.selected_tags) / len(gt_tags)
279
+
280
  # Split ground-truth and selected tags by type
281
  gt_char, gt_gen = _classify_tags(gt_tags, get_tag_type_name)
282
  sel_char, sel_gen = _classify_tags(result.selected_tags, get_tag_type_name)
 
479
  print("Stage 2 - Retrieval:")
480
  print(f" Avg recall@300: {avg_retrieval_recall:.4f}")
481
  print(f" Avg candidates: {avg_retrieved:.1f}")
482
+ avg_retrieval_precision = _safe_avg([r.retrieval_precision for r in valid])
483
+ avg_sel_given_ret = _safe_avg([r.selection_given_retrieval for r in valid
484
+ if (r.retrieved_tags & r.ground_truth_tags)])
485
+ avg_over_sel = _safe_avg([r.over_selection_ratio for r in valid])
486
+
487
  print()
488
  print("Stage 3 - Selection (ALL tags):")
489
  print(f" Avg precision: {avg_sel_precision:.4f}")
 
491
  print(f" Avg F1: {avg_sel_f1:.4f}")
492
  print(f" Avg selected tags: {avg_selected:.1f}")
493
  print(f" Avg ground-truth tags:{avg_gt:.1f}")
494
+ print()
495
+ print("Diagnostic Metrics:")
496
+ print(f" Retrieval precision: {avg_retrieval_precision:.4f} (|ret∩gt|/|ret|, noise level fed to Stage 3)")
497
+ print(f" Sel-given-retrieval: {avg_sel_given_ret:.4f} (of gt tags retrieved, fraction kept by Stage 3)")
498
+ print(f" Over-selection ratio: {avg_over_sel:.2f}x (|selected|/|gt|, ideal ~1.0)")
499
+
500
+ # Why distribution across all samples
501
+ total_why: Dict[str, int] = {}
502
+ for r in valid:
503
+ for w, cnt in r.why_counts.items():
504
+ total_why[w] = total_why.get(w, 0) + cnt
505
+ if total_why:
506
+ total_selections = sum(total_why.values())
507
+ print()
508
+ print("Why Distribution (Stage 3 LLM rationale):")
509
+ for w in ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]:
510
+ cnt = total_why.get(w, 0)
511
+ pct = 100 * cnt / total_selections if total_selections else 0
512
+ print(f" {w:20s} {cnt:4d} ({pct:5.1f}%)")
513
 
514
  # --- Character tag breakdown ---
515
  # Only include samples that actually have character tags in ground truth
 
725
  "general_precision": round(r.general_precision, 4),
726
  "general_recall": round(r.general_recall, 4),
727
  "general_f1": round(r.general_f1, 4),
728
+ # Diagnostic metrics
729
+ "retrieval_precision": round(r.retrieval_precision, 4),
730
+ "selection_given_retrieval": round(r.selection_given_retrieval, 4),
731
+ "over_selection_ratio": round(r.over_selection_ratio, 2),
732
+ "why_counts": r.why_counts,
733
  # Timing
734
  "stage1_time": round(r.stage1_time, 3),
735
  "stage2_time": round(r.stage2_time, 3),