Spaces:
Running
Running
Claude committed on
Commit ·
349b999
1
Parent(s): df66964
Add diagnostic eval metrics, why-distribution tracking, and generic character filter
Browse files
- select.py: return_metadata option exposes per-tag 'why' rationale from LLM
- select.py: Route generic character-category tags (fan_character, viewer, etc.)
to general pipeline instead of entity pipeline to reduce false positives
- eval_pipeline.py: New metrics — retrieval precision, selection-given-retrieval,
over-selection ratio, and why distribution breakdown in summary output
- eval_pipeline.py: All new metrics saved to JSONL output for analysis
https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG
- psq_rag/llm/select.py +29 -2
- scripts/eval_pipeline.py +53 -1
psq_rag/llm/select.py
CHANGED
|
@@ -23,6 +23,19 @@ from rapidfuzz import fuzz
|
|
| 23 |
from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources)
|
| 24 |
from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
|
| 28 |
|
|
@@ -334,11 +347,18 @@ def _split_candidates_by_type(
|
|
| 334 |
unknown_count = 0
|
| 335 |
copyright_count = 0
|
| 336 |
|
|
|
|
|
|
|
| 337 |
for idx, cand in enumerate(candidates):
|
| 338 |
type_name = get_tag_type_name(cand.tag)
|
| 339 |
|
| 340 |
if type_name == "character":
|
| 341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
elif type_name == "copyright":
|
| 343 |
# Filter out copyright/series tags - too broad for image generation
|
| 344 |
copyright_count += 1
|
|
@@ -355,6 +375,7 @@ def _split_candidates_by_type(
|
|
| 355 |
f"general={len(general_with_idx)} "
|
| 356 |
f"entity={len(entity_with_idx)} "
|
| 357 |
f"copyright_filtered={copyright_count} "
|
|
|
|
| 358 |
f"unknown_type={unknown_count}"
|
| 359 |
)
|
| 360 |
|
|
@@ -462,7 +483,8 @@ def llm_select_indices(
|
|
| 462 |
per_phrase_k: int = 2, # per-call budget = per_phrase_k * phrases_in_call
|
| 463 |
temperature: float = 0.0,
|
| 464 |
max_tokens: int = 512,
|
| 465 |
-
|
|
|
|
| 466 |
"""Return indices into the ORIGINAL candidates list (legacy interface).
|
| 467 |
|
| 468 |
This implementation uses LangChain ONLY.
|
|
@@ -704,8 +726,13 @@ def llm_select_indices(
|
|
| 704 |
|
| 705 |
# Map back to original indices
|
| 706 |
out_idx: List[int] = []
|
|
|
|
| 707 |
for t in ordered_tags:
|
| 708 |
if t in tag_to_first_index:
|
| 709 |
out_idx.append(tag_to_first_index[t])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 710 |
|
| 711 |
return out_idx
|
|
|
|
| 23 |
from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources)
|
| 24 |
from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases
|
| 25 |
|
| 26 |
+
# Character-typed tags that are generic categories, not actual named characters.
|
| 27 |
+
# These leak through the alias filter because they match common words in captions.
|
| 28 |
+
# They are excluded from the entity pipeline and instead routed to general selection.
|
| 29 |
+
_GENERIC_CHARACTER_TAGS = frozenset({
|
| 30 |
+
"fan_character",
|
| 31 |
+
"background_character",
|
| 32 |
+
"unnamed_character",
|
| 33 |
+
"unknown_character",
|
| 34 |
+
"anonymous_character",
|
| 35 |
+
"viewer",
|
| 36 |
+
"original_character",
|
| 37 |
+
})
|
| 38 |
+
|
| 39 |
|
| 40 |
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
|
| 41 |
|
|
|
|
| 347 |
unknown_count = 0
|
| 348 |
copyright_count = 0
|
| 349 |
|
| 350 |
+
generic_char_count = 0
|
| 351 |
+
|
| 352 |
for idx, cand in enumerate(candidates):
|
| 353 |
type_name = get_tag_type_name(cand.tag)
|
| 354 |
|
| 355 |
if type_name == "character":
|
| 356 |
+
if cand.tag in _GENERIC_CHARACTER_TAGS:
|
| 357 |
+
# Route generic character-category tags to general selection
|
| 358 |
+
general_with_idx.append((idx, cand))
|
| 359 |
+
generic_char_count += 1
|
| 360 |
+
else:
|
| 361 |
+
entity_with_idx.append((idx, cand))
|
| 362 |
elif type_name == "copyright":
|
| 363 |
# Filter out copyright/series tags - too broad for image generation
|
| 364 |
copyright_count += 1
|
|
|
|
| 375 |
f"general={len(general_with_idx)} "
|
| 376 |
f"entity={len(entity_with_idx)} "
|
| 377 |
f"copyright_filtered={copyright_count} "
|
| 378 |
+
f"generic_char_to_general={generic_char_count} "
|
| 379 |
f"unknown_type={unknown_count}"
|
| 380 |
)
|
| 381 |
|
|
|
|
| 483 |
per_phrase_k: int = 2, # per-call budget = per_phrase_k * phrases_in_call
|
| 484 |
temperature: float = 0.0,
|
| 485 |
max_tokens: int = 512,
|
| 486 |
+
return_metadata: bool = False,
|
| 487 |
+
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
|
| 488 |
"""Return indices into the ORIGINAL candidates list (legacy interface).
|
| 489 |
|
| 490 |
This implementation uses LangChain ONLY.
|
|
|
|
| 726 |
|
| 727 |
# Map back to original indices
|
| 728 |
out_idx: List[int] = []
|
| 729 |
+
tag_why: Dict[str, str] = {}
|
| 730 |
for t in ordered_tags:
|
| 731 |
if t in tag_to_first_index:
|
| 732 |
out_idx.append(tag_to_first_index[t])
|
| 733 |
+
tag_why[t] = best[t][1] # why string
|
| 734 |
+
|
| 735 |
+
if return_metadata:
|
| 736 |
+
return out_idx, tag_why
|
| 737 |
|
| 738 |
return out_idx
|
scripts/eval_pipeline.py
CHANGED
|
@@ -127,6 +127,12 @@ class SampleResult:
|
|
| 127 |
general_precision: float = 0.0
|
| 128 |
general_recall: float = 0.0
|
| 129 |
general_f1: float = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
# Timing
|
| 131 |
stage1_time: float = 0.0
|
| 132 |
stage2_time: float = 0.0
|
|
@@ -233,7 +239,7 @@ def _process_one_sample(
|
|
| 233 |
|
| 234 |
# --- Stage 3: LLM Selection ---
|
| 235 |
t0 = time.time()
|
| 236 |
-
picked_indices = llm_select_indices(
|
| 237 |
query_text=caption,
|
| 238 |
candidates=candidates,
|
| 239 |
max_pick=0,
|
|
@@ -243,17 +249,34 @@ def _process_one_sample(
|
|
| 243 |
per_phrase_k=per_phrase_k,
|
| 244 |
temperature=temperature,
|
| 245 |
max_tokens=max_tokens,
|
|
|
|
| 246 |
)
|
| 247 |
result.stage3_time = time.time() - t0
|
| 248 |
|
| 249 |
result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
# Overall selection metrics
|
| 252 |
p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
|
| 253 |
result.selection_precision = p
|
| 254 |
result.selection_recall = r
|
| 255 |
result.selection_f1 = f1
|
| 256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
# Split ground-truth and selected tags by type
|
| 258 |
gt_char, gt_gen = _classify_tags(gt_tags, get_tag_type_name)
|
| 259 |
sel_char, sel_gen = _classify_tags(result.selected_tags, get_tag_type_name)
|
|
@@ -456,6 +479,11 @@ def print_summary(results: List[SampleResult]) -> None:
|
|
| 456 |
print("Stage 2 - Retrieval:")
|
| 457 |
print(f" Avg recall@300: {avg_retrieval_recall:.4f}")
|
| 458 |
print(f" Avg candidates: {avg_retrieved:.1f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
print()
|
| 460 |
print("Stage 3 - Selection (ALL tags):")
|
| 461 |
print(f" Avg precision: {avg_sel_precision:.4f}")
|
|
@@ -463,6 +491,25 @@ def print_summary(results: List[SampleResult]) -> None:
|
|
| 463 |
print(f" Avg F1: {avg_sel_f1:.4f}")
|
| 464 |
print(f" Avg selected tags: {avg_selected:.1f}")
|
| 465 |
print(f" Avg ground-truth tags:{avg_gt:.1f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
|
| 467 |
# --- Character tag breakdown ---
|
| 468 |
# Only include samples that actually have character tags in ground truth
|
|
@@ -678,6 +725,11 @@ def main(argv=None) -> int:
|
|
| 678 |
"general_precision": round(r.general_precision, 4),
|
| 679 |
"general_recall": round(r.general_recall, 4),
|
| 680 |
"general_f1": round(r.general_f1, 4),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 681 |
# Timing
|
| 682 |
"stage1_time": round(r.stage1_time, 3),
|
| 683 |
"stage2_time": round(r.stage2_time, 3),
|
|
|
|
| 127 |
general_precision: float = 0.0
|
| 128 |
general_recall: float = 0.0
|
| 129 |
general_f1: float = 0.0
|
| 130 |
+
# New diagnostic metrics
|
| 131 |
+
retrieval_precision: float = 0.0 # |retrieved ∩ gt| / |retrieved|
|
| 132 |
+
selection_given_retrieval: float = 0.0 # |selected ∩ gt| / |retrieved ∩ gt|
|
| 133 |
+
over_selection_ratio: float = 0.0 # |selected| / |gt|
|
| 134 |
+
# Why distribution (from Stage 3 LLM)
|
| 135 |
+
why_counts: Dict[str, int] = field(default_factory=dict)
|
| 136 |
# Timing
|
| 137 |
stage1_time: float = 0.0
|
| 138 |
stage2_time: float = 0.0
|
|
|
|
| 239 |
|
| 240 |
# --- Stage 3: LLM Selection ---
|
| 241 |
t0 = time.time()
|
| 242 |
+
picked_indices, tag_why = llm_select_indices(
|
| 243 |
query_text=caption,
|
| 244 |
candidates=candidates,
|
| 245 |
max_pick=0,
|
|
|
|
| 249 |
per_phrase_k=per_phrase_k,
|
| 250 |
temperature=temperature,
|
| 251 |
max_tokens=max_tokens,
|
| 252 |
+
return_metadata=True,
|
| 253 |
)
|
| 254 |
result.stage3_time = time.time() - t0
|
| 255 |
|
| 256 |
result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
|
| 257 |
|
| 258 |
+
# Why distribution
|
| 259 |
+
why_counts: Dict[str, int] = {}
|
| 260 |
+
for w in tag_why.values():
|
| 261 |
+
why_counts[w] = why_counts.get(w, 0) + 1
|
| 262 |
+
result.why_counts = why_counts
|
| 263 |
+
|
| 264 |
# Overall selection metrics
|
| 265 |
p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
|
| 266 |
result.selection_precision = p
|
| 267 |
result.selection_recall = r
|
| 268 |
result.selection_f1 = f1
|
| 269 |
|
| 270 |
+
# New diagnostic metrics
|
| 271 |
+
retrieved_and_gt = result.retrieved_tags & gt_tags
|
| 272 |
+
selected_and_gt = result.selected_tags & gt_tags
|
| 273 |
+
if result.retrieved_tags:
|
| 274 |
+
result.retrieval_precision = len(retrieved_and_gt) / len(result.retrieved_tags)
|
| 275 |
+
if retrieved_and_gt:
|
| 276 |
+
result.selection_given_retrieval = len(selected_and_gt) / len(retrieved_and_gt)
|
| 277 |
+
if gt_tags:
|
| 278 |
+
result.over_selection_ratio = len(result.selected_tags) / len(gt_tags)
|
| 279 |
+
|
| 280 |
# Split ground-truth and selected tags by type
|
| 281 |
gt_char, gt_gen = _classify_tags(gt_tags, get_tag_type_name)
|
| 282 |
sel_char, sel_gen = _classify_tags(result.selected_tags, get_tag_type_name)
|
|
|
|
| 479 |
print("Stage 2 - Retrieval:")
|
| 480 |
print(f" Avg recall@300: {avg_retrieval_recall:.4f}")
|
| 481 |
print(f" Avg candidates: {avg_retrieved:.1f}")
|
| 482 |
+
avg_retrieval_precision = _safe_avg([r.retrieval_precision for r in valid])
|
| 483 |
+
avg_sel_given_ret = _safe_avg([r.selection_given_retrieval for r in valid
|
| 484 |
+
if (r.retrieved_tags & r.ground_truth_tags)])
|
| 485 |
+
avg_over_sel = _safe_avg([r.over_selection_ratio for r in valid])
|
| 486 |
+
|
| 487 |
print()
|
| 488 |
print("Stage 3 - Selection (ALL tags):")
|
| 489 |
print(f" Avg precision: {avg_sel_precision:.4f}")
|
|
|
|
| 491 |
print(f" Avg F1: {avg_sel_f1:.4f}")
|
| 492 |
print(f" Avg selected tags: {avg_selected:.1f}")
|
| 493 |
print(f" Avg ground-truth tags:{avg_gt:.1f}")
|
| 494 |
+
print()
|
| 495 |
+
print("Diagnostic Metrics:")
|
| 496 |
+
print(f" Retrieval precision: {avg_retrieval_precision:.4f} (|ret∩gt|/|ret|, noise level fed to Stage 3)")
|
| 497 |
+
print(f" Sel-given-retrieval: {avg_sel_given_ret:.4f} (of gt tags retrieved, fraction kept by Stage 3)")
|
| 498 |
+
print(f" Over-selection ratio: {avg_over_sel:.2f}x (|selected|/|gt|, ideal ~1.0)")
|
| 499 |
+
|
| 500 |
+
# Why distribution across all samples
|
| 501 |
+
total_why: Dict[str, int] = {}
|
| 502 |
+
for r in valid:
|
| 503 |
+
for w, cnt in r.why_counts.items():
|
| 504 |
+
total_why[w] = total_why.get(w, 0) + cnt
|
| 505 |
+
if total_why:
|
| 506 |
+
total_selections = sum(total_why.values())
|
| 507 |
+
print()
|
| 508 |
+
print("Why Distribution (Stage 3 LLM rationale):")
|
| 509 |
+
for w in ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]:
|
| 510 |
+
cnt = total_why.get(w, 0)
|
| 511 |
+
pct = 100 * cnt / total_selections if total_selections else 0
|
| 512 |
+
print(f" {w:20s} {cnt:4d} ({pct:5.1f}%)")
|
| 513 |
|
| 514 |
# --- Character tag breakdown ---
|
| 515 |
# Only include samples that actually have character tags in ground truth
|
|
|
|
| 725 |
"general_precision": round(r.general_precision, 4),
|
| 726 |
"general_recall": round(r.general_recall, 4),
|
| 727 |
"general_f1": round(r.general_f1, 4),
|
| 728 |
+
# Diagnostic metrics
|
| 729 |
+
"retrieval_precision": round(r.retrieval_precision, 4),
|
| 730 |
+
"selection_given_retrieval": round(r.selection_given_retrieval, 4),
|
| 731 |
+
"over_selection_ratio": round(r.over_selection_ratio, 2),
|
| 732 |
+
"why_counts": r.why_counts,
|
| 733 |
# Timing
|
| 734 |
"stage1_time": round(r.stage1_time, 3),
|
| 735 |
"stage2_time": round(r.stage2_time, 3),
|