Spaces:
Running
Running
Food Desert committed on
Commit ·
3c18372
1
Parent(s): 73f56cf
Simplify Stage3 chunking to interleave-only and add eval diagnostics
Browse files
- psq_rag/llm/select.py +88 -14
- scripts/eval_pipeline.py +129 -71
psq_rag/llm/select.py
CHANGED
|
@@ -253,6 +253,13 @@ def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
|
|
| 253 |
return out
|
| 254 |
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
def _display_tag(tag: str) -> str:
|
| 257 |
# Display tags with spaces for the LLM, but keep canonical underscores internally.
|
| 258 |
return tag.replace("_", " ")
|
|
@@ -494,8 +501,13 @@ def llm_select_indices(
|
|
| 494 |
temperature: float = 0.0,
|
| 495 |
max_tokens: int = 512,
|
| 496 |
return_metadata: bool = False,
|
|
|
|
| 497 |
min_why: Optional[str] = "strong_implied",
|
| 498 |
-
) -> Union[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
"""Return indices into the ORIGINAL candidates list (legacy interface).
|
| 500 |
|
| 501 |
min_why: if set, only keep tags whose 'why' is at or above this confidence
|
|
@@ -586,6 +598,42 @@ def llm_select_indices(
|
|
| 586 |
|
| 587 |
# Global union: tag -> best (score, why)
|
| 588 |
best: Dict[str, Tuple[float, str]] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
|
| 590 |
def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
|
| 591 |
# Create chain with the provided system template
|
|
@@ -598,9 +646,10 @@ def llm_select_indices(
|
|
| 598 |
)
|
| 599 |
chain = prompt | llm | parser
|
| 600 |
|
| 601 |
-
ordered = _interleave_round_robin(call_cands)
|
| 602 |
candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
|
| 603 |
N_local = len(idx_to_tag)
|
|
|
|
| 604 |
|
| 605 |
phrases = _phrases_in_call(call_cands)
|
| 606 |
per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
|
|
@@ -618,6 +667,7 @@ def llm_select_indices(
|
|
| 618 |
# Invoke LangChain chain (templating fills {N} and other vars)
|
| 619 |
for att in range(retries + 1):
|
| 620 |
try:
|
|
|
|
| 621 |
if log:
|
| 622 |
log(
|
| 623 |
f"Stage3 {label}: "
|
|
@@ -637,6 +687,16 @@ def llm_select_indices(
|
|
| 637 |
}
|
| 638 |
)
|
| 639 |
selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
if log:
|
| 641 |
log(f"Stage3 {label}: attempt {att+1} diag={diag}")
|
| 642 |
if not summary_logged and (selected or att == retries):
|
|
@@ -662,6 +722,7 @@ def llm_select_indices(
|
|
| 662 |
log(f"Stage3 {label} selections: (none)")
|
| 663 |
|
| 664 |
if selected:
|
|
|
|
| 665 |
for s in selected:
|
| 666 |
prev = best.get(s.tag)
|
| 667 |
if prev is None or s.score > prev[0]:
|
|
@@ -669,11 +730,14 @@ def llm_select_indices(
|
|
| 669 |
return
|
| 670 |
|
| 671 |
except Exception as e:
|
|
|
|
|
|
|
| 672 |
if log:
|
| 673 |
log(f"Stage3 {label}: attempt {att+1} error: {e}")
|
| 674 |
|
| 675 |
if log:
|
| 676 |
log(f"Stage3 {label}: gave up after {retries+1} attempts")
|
|
|
|
| 677 |
|
| 678 |
# Split candidates by type (general vs entity)
|
| 679 |
general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)
|
|
@@ -687,12 +751,9 @@ def llm_select_indices(
|
|
| 687 |
if mode == "single_shot":
|
| 688 |
run_call(general_cands, "general_single_shot", SELECT_SYSTEM_TEMPLATE)
|
| 689 |
else:
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
f"general_chunk_{start//chunk_size}",
|
| 694 |
-
SELECT_SYSTEM_TEMPLATE
|
| 695 |
-
)
|
| 696 |
|
| 697 |
# Process entity candidates (characters only) with alias-based pre-filtering
|
| 698 |
if entity_cands:
|
|
@@ -725,12 +786,9 @@ def llm_select_indices(
|
|
| 725 |
if mode == "single_shot":
|
| 726 |
run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE)
|
| 727 |
else:
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
f"entity_chunk_{start//chunk_size}",
|
| 732 |
-
ENTITY_SYSTEM_TEMPLATE
|
| 733 |
-
)
|
| 734 |
|
| 735 |
# Apply why threshold: drop tags below the minimum confidence level.
|
| 736 |
if min_why is not None:
|
|
@@ -757,7 +815,23 @@ def llm_select_indices(
|
|
| 757 |
out_idx.append(tag_to_first_index[t])
|
| 758 |
tag_why[t] = best[t][1] # why string
|
| 759 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 760 |
if return_metadata:
|
|
|
|
|
|
|
| 761 |
return out_idx, tag_why
|
| 762 |
|
| 763 |
return out_idx
|
|
|
|
| 253 |
return out
|
| 254 |
|
| 255 |
|
| 256 |
+
def _build_chunks(cands: Sequence[Candidate], chunk_size: int) -> List[List[Candidate]]:
|
| 257 |
+
if chunk_size <= 0:
|
| 258 |
+
raise ValueError(f"chunk_size must be > 0, got {chunk_size}")
|
| 259 |
+
ordered = _interleave_round_robin(cands)
|
| 260 |
+
return [ordered[i:i + chunk_size] for i in range(0, len(ordered), chunk_size)]
|
| 261 |
+
|
| 262 |
+
|
| 263 |
def _display_tag(tag: str) -> str:
|
| 264 |
# Display tags with spaces for the LLM, but keep canonical underscores internally.
|
| 265 |
return tag.replace("_", " ")
|
|
|
|
| 501 |
temperature: float = 0.0,
|
| 502 |
max_tokens: int = 512,
|
| 503 |
return_metadata: bool = False,
|
| 504 |
+
return_diagnostics: bool = False,
|
| 505 |
min_why: Optional[str] = "strong_implied",
|
| 506 |
+
) -> Union[
|
| 507 |
+
List[int],
|
| 508 |
+
Tuple[List[int], Dict[str, str]],
|
| 509 |
+
Tuple[List[int], Dict[str, str], Dict[str, Any]],
|
| 510 |
+
]:
|
| 511 |
"""Return indices into the ORIGINAL candidates list (legacy interface).
|
| 512 |
|
| 513 |
min_why: if set, only keep tags whose 'why' is at or above this confidence
|
|
|
|
| 598 |
|
| 599 |
# Global union: tag -> best (score, why)
|
| 600 |
best: Dict[str, Tuple[float, str]] = {}
|
| 601 |
+
diagnostics: Dict[str, Any] = {
|
| 602 |
+
"mode": mode,
|
| 603 |
+
"chunk_strategy": "interleave",
|
| 604 |
+
"chunk_passes": 1,
|
| 605 |
+
"chunk_shuffle_within_call": False,
|
| 606 |
+
"calls_total": 0,
|
| 607 |
+
"calls_with_selection": 0,
|
| 608 |
+
"calls_exhausted_retries": 0,
|
| 609 |
+
"attempts_total": 0,
|
| 610 |
+
"attempt_errors": 0,
|
| 611 |
+
"attempt_parse_fail": 0,
|
| 612 |
+
"attempt_parse_ok": 0,
|
| 613 |
+
"invalid_items_total": 0,
|
| 614 |
+
"oob_indices_total": 0,
|
| 615 |
+
"dupe_indices_total": 0,
|
| 616 |
+
"kept_total": 0,
|
| 617 |
+
"attempts_by_n_local": {},
|
| 618 |
+
}
|
| 619 |
+
|
| 620 |
+
def _record_attempt_for_n(n_local: int, *, parse_ok: bool, error: bool) -> None:
|
| 621 |
+
by_n = diagnostics["attempts_by_n_local"]
|
| 622 |
+
key = str(n_local)
|
| 623 |
+
if key not in by_n:
|
| 624 |
+
by_n[key] = {
|
| 625 |
+
"attempts": 0,
|
| 626 |
+
"parse_ok": 0,
|
| 627 |
+
"parse_fail": 0,
|
| 628 |
+
"errors": 0,
|
| 629 |
+
}
|
| 630 |
+
by_n[key]["attempts"] += 1
|
| 631 |
+
if error:
|
| 632 |
+
by_n[key]["errors"] += 1
|
| 633 |
+
elif parse_ok:
|
| 634 |
+
by_n[key]["parse_ok"] += 1
|
| 635 |
+
else:
|
| 636 |
+
by_n[key]["parse_fail"] += 1
|
| 637 |
|
| 638 |
def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
|
| 639 |
# Create chain with the provided system template
|
|
|
|
| 646 |
)
|
| 647 |
chain = prompt | llm | parser
|
| 648 |
|
| 649 |
+
ordered = _interleave_round_robin(call_cands) if mode == "single_shot" else list(call_cands)
|
| 650 |
candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
|
| 651 |
N_local = len(idx_to_tag)
|
| 652 |
+
diagnostics["calls_total"] += 1
|
| 653 |
|
| 654 |
phrases = _phrases_in_call(call_cands)
|
| 655 |
per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
|
|
|
|
| 667 |
# Invoke LangChain chain (templating fills {N} and other vars)
|
| 668 |
for att in range(retries + 1):
|
| 669 |
try:
|
| 670 |
+
diagnostics["attempts_total"] += 1
|
| 671 |
if log:
|
| 672 |
log(
|
| 673 |
f"Stage3 {label}: "
|
|
|
|
| 687 |
}
|
| 688 |
)
|
| 689 |
selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
|
| 690 |
+
diagnostics["invalid_items_total"] += int(diag.get("invalid_items", 0))
|
| 691 |
+
diagnostics["oob_indices_total"] += int(diag.get("oob_indices", 0))
|
| 692 |
+
diagnostics["dupe_indices_total"] += int(diag.get("dupe_indices", 0))
|
| 693 |
+
diagnostics["kept_total"] += int(diag.get("kept", 0))
|
| 694 |
+
if bool(diag.get("parse_ok", False)):
|
| 695 |
+
diagnostics["attempt_parse_ok"] += 1
|
| 696 |
+
_record_attempt_for_n(N_local, parse_ok=True, error=False)
|
| 697 |
+
else:
|
| 698 |
+
diagnostics["attempt_parse_fail"] += 1
|
| 699 |
+
_record_attempt_for_n(N_local, parse_ok=False, error=False)
|
| 700 |
if log:
|
| 701 |
log(f"Stage3 {label}: attempt {att+1} diag={diag}")
|
| 702 |
if not summary_logged and (selected or att == retries):
|
|
|
|
| 722 |
log(f"Stage3 {label} selections: (none)")
|
| 723 |
|
| 724 |
if selected:
|
| 725 |
+
diagnostics["calls_with_selection"] += 1
|
| 726 |
for s in selected:
|
| 727 |
prev = best.get(s.tag)
|
| 728 |
if prev is None or s.score > prev[0]:
|
|
|
|
| 730 |
return
|
| 731 |
|
| 732 |
except Exception as e:
|
| 733 |
+
diagnostics["attempt_errors"] += 1
|
| 734 |
+
_record_attempt_for_n(N_local, parse_ok=False, error=True)
|
| 735 |
if log:
|
| 736 |
log(f"Stage3 {label}: attempt {att+1} error: {e}")
|
| 737 |
|
| 738 |
if log:
|
| 739 |
log(f"Stage3 {label}: gave up after {retries+1} attempts")
|
| 740 |
+
diagnostics["calls_exhausted_retries"] += 1
|
| 741 |
|
| 742 |
# Split candidates by type (general vs entity)
|
| 743 |
general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)
|
|
|
|
| 751 |
if mode == "single_shot":
|
| 752 |
run_call(general_cands, "general_single_shot", SELECT_SYSTEM_TEMPLATE)
|
| 753 |
else:
|
| 754 |
+
base_chunks = _build_chunks(general_cands, chunk_size)
|
| 755 |
+
for chunk_idx, chunk in enumerate(base_chunks):
|
| 756 |
+
run_call(chunk, f"general_chunk_{chunk_idx}", SELECT_SYSTEM_TEMPLATE)
|
|
|
|
|
|
|
|
|
|
| 757 |
|
| 758 |
# Process entity candidates (characters only) with alias-based pre-filtering
|
| 759 |
if entity_cands:
|
|
|
|
| 786 |
if mode == "single_shot":
|
| 787 |
run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE)
|
| 788 |
else:
|
| 789 |
+
base_chunks = _build_chunks(filtered_entity_cands, chunk_size)
|
| 790 |
+
for chunk_idx, chunk in enumerate(base_chunks):
|
| 791 |
+
run_call(chunk, f"entity_chunk_{chunk_idx}", ENTITY_SYSTEM_TEMPLATE)
|
|
|
|
|
|
|
|
|
|
| 792 |
|
| 793 |
# Apply why threshold: drop tags below the minimum confidence level.
|
| 794 |
if min_why is not None:
|
|
|
|
| 815 |
out_idx.append(tag_to_first_index[t])
|
| 816 |
tag_why[t] = best[t][1] # why string
|
| 817 |
|
| 818 |
+
if diagnostics["attempts_total"] > 0:
|
| 819 |
+
diagnostics["attempt_failure_rate"] = (
|
| 820 |
+
diagnostics["attempt_parse_fail"] + diagnostics["attempt_errors"]
|
| 821 |
+
) / diagnostics["attempts_total"]
|
| 822 |
+
else:
|
| 823 |
+
diagnostics["attempt_failure_rate"] = 0.0
|
| 824 |
+
|
| 825 |
+
if diagnostics["calls_total"] > 0:
|
| 826 |
+
diagnostics["call_exhaustion_rate"] = (
|
| 827 |
+
diagnostics["calls_exhausted_retries"] / diagnostics["calls_total"]
|
| 828 |
+
)
|
| 829 |
+
else:
|
| 830 |
+
diagnostics["call_exhaustion_rate"] = 0.0
|
| 831 |
+
|
| 832 |
if return_metadata:
|
| 833 |
+
if return_diagnostics:
|
| 834 |
+
return out_idx, tag_why, diagnostics
|
| 835 |
return out_idx, tag_why
|
| 836 |
|
| 837 |
return out_idx
|
scripts/eval_pipeline.py
CHANGED
|
@@ -162,7 +162,8 @@ class SampleResult:
|
|
| 162 |
selection_given_retrieval: float = 0.0 # |selected ∩ gt| / |retrieved ∩ gt|
|
| 163 |
over_selection_ratio: float = 0.0 # |selected| / |gt|
|
| 164 |
# Why distribution (from Stage 3 LLM)
|
| 165 |
-
why_counts: Dict[str, int] = field(default_factory=dict)
|
|
|
|
| 166 |
# Tag implications
|
| 167 |
implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
|
| 168 |
# Structural inference tags (solo/duo/male/female/anthro/biped etc.)
|
|
@@ -216,12 +217,12 @@ def _process_one_sample(
|
|
| 216 |
per_phrase_final_k: int,
|
| 217 |
temperature: float,
|
| 218 |
max_tokens: int,
|
| 219 |
-
verbose: bool,
|
| 220 |
-
print_lock: threading.Lock,
|
| 221 |
-
min_why: Optional[str] = None,
|
| 222 |
-
expand_implications: bool = False,
|
| 223 |
-
infer_structural: bool = False,
|
| 224 |
-
) -> SampleResult:
|
| 225 |
"""Process a single eval sample through the full pipeline. Thread-safe."""
|
| 226 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 227 |
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
|
|
@@ -300,20 +301,22 @@ def _process_one_sample(
|
|
| 300 |
|
| 301 |
# --- Stage 3: LLM Selection ---
|
| 302 |
t0 = time.time()
|
| 303 |
-
picked_indices, tag_why = llm_select_indices(
|
| 304 |
-
query_text=caption,
|
| 305 |
-
candidates=candidates,
|
| 306 |
-
max_pick=0,
|
| 307 |
-
log=log,
|
| 308 |
-
mode=mode,
|
| 309 |
-
chunk_size=chunk_size,
|
| 310 |
-
per_phrase_k=per_phrase_k,
|
| 311 |
-
temperature=temperature,
|
| 312 |
-
max_tokens=max_tokens,
|
| 313 |
-
return_metadata=True,
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
|
|
|
|
|
|
| 317 |
|
| 318 |
result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
|
| 319 |
result.stage3_selected_tags = set(result.selected_tags)
|
|
@@ -497,10 +500,11 @@ def run_eval(
|
|
| 497 |
max_tokens: int = 512,
|
| 498 |
verbose: bool = False,
|
| 499 |
shuffle: bool = True,
|
| 500 |
-
seed: int = 42,
|
| 501 |
-
workers: int = 1,
|
| 502 |
-
min_why: Optional[str] = "strong_implied",
|
| 503 |
-
|
|
|
|
| 504 |
infer_structural: bool = False,
|
| 505 |
) -> List[SampleResult]:
|
| 506 |
expand_gt = expand_implications
|
|
@@ -508,20 +512,26 @@ def run_eval(
|
|
| 508 |
from psq_rag.retrieval.state import expand_tags_via_implications as _expand_gt_tags
|
| 509 |
|
| 510 |
# Load eval samples — prefer expanded file, fall back to raw
|
| 511 |
-
|
| 512 |
-
if not
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
caption = row.get(caption_field, "")
|
| 526 |
if not caption or not caption.strip():
|
| 527 |
continue
|
|
@@ -543,19 +553,20 @@ def run_eval(
|
|
| 543 |
"caption": caption.strip(),
|
| 544 |
"gt_tags": gt_tags,
|
| 545 |
})
|
| 546 |
-
if using_expanded:
|
| 547 |
-
print("Using implication-expanded ground truth")
|
| 548 |
-
|
| 549 |
if shuffle:
|
| 550 |
rng = random.Random(seed)
|
| 551 |
rng.shuffle(all_samples)
|
| 552 |
|
| 553 |
samples = all_samples[:n_samples]
|
| 554 |
|
| 555 |
-
print(f"Loaded {len(samples)}/{len(all_samples)} samples (caption_field={caption_field})")
|
| 556 |
-
print(f"
|
| 557 |
-
print(f"
|
| 558 |
-
print()
|
|
|
|
| 559 |
|
| 560 |
# Pre-warm shared retrieval assets before spawning threads
|
| 561 |
_prewarm_retrieval_assets()
|
|
@@ -572,10 +583,11 @@ def run_eval(
|
|
| 572 |
sample, i, total,
|
| 573 |
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 574 |
per_phrase_k, per_phrase_final_k, temperature, max_tokens, verbose,
|
| 575 |
-
print_lock, min_why,
|
|
|
|
| 576 |
infer_structural,
|
| 577 |
)
|
| 578 |
-
results.append(result)
|
| 579 |
else:
|
| 580 |
# Parallel mode
|
| 581 |
print(f"Processing {total} samples with {workers} parallel workers...")
|
|
@@ -584,12 +596,13 @@ def run_eval(
|
|
| 584 |
results_by_index: Dict[int, SampleResult] = {}
|
| 585 |
with ThreadPoolExecutor(max_workers=workers) as executor:
|
| 586 |
futures = {
|
| 587 |
-
executor.submit(
|
| 588 |
_process_one_sample,
|
| 589 |
sample, i, total,
|
| 590 |
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 591 |
per_phrase_k, per_phrase_final_k, temperature, max_tokens, verbose,
|
| 592 |
-
print_lock, min_why,
|
|
|
|
| 593 |
infer_structural,
|
| 594 |
): i
|
| 595 |
for i, sample in enumerate(samples)
|
|
@@ -688,13 +701,52 @@ def print_summary(results: List[SampleResult]) -> None:
|
|
| 688 |
print(f" Avg leaf ground-truth:{avg_leaf_gt:.1f}")
|
| 689 |
|
| 690 |
print()
|
| 691 |
-
print("Diagnostic Metrics:")
|
| 692 |
-
print(f" Retrieval precision: {avg_retrieval_precision:.4f} (|ret∩gt|/|ret|, noise level fed to Stage 3)")
|
| 693 |
-
print(f" Sel-given-retrieval: {avg_sel_given_ret:.4f} (of gt tags retrieved, fraction kept by Stage 3)")
|
| 694 |
-
print(f" Over-selection ratio: {avg_over_sel:.2f}x (|selected|/|gt|, ideal ~1.0)")
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 698 |
for r in valid:
|
| 699 |
for w, cnt in r.why_counts.items():
|
| 700 |
total_why[w] = total_why.get(w, 0) + cnt
|
|
@@ -830,8 +882,8 @@ def main(argv=None) -> int:
|
|
| 830 |
ap.add_argument("--skip-rewrite", action="store_true",
|
| 831 |
help="Skip Stage 1 LLM rewrite; split caption directly into phrases")
|
| 832 |
ap.add_argument("--allow-nsfw", action="store_true", help="Allow NSFW tags")
|
| 833 |
-
ap.add_argument("--mode", default="chunked_map_union",
|
| 834 |
-
choices=["single_shot", "chunked_map_union"])
|
| 835 |
ap.add_argument("--chunk-size", type=int, default=60)
|
| 836 |
ap.add_argument("--per-phrase-k", type=int, default=2)
|
| 837 |
ap.add_argument("--per-phrase-final-k", type=int, default=10,
|
|
@@ -847,8 +899,10 @@ def main(argv=None) -> int:
|
|
| 847 |
help="Use samples in file order (first N)")
|
| 848 |
ap.add_argument("--seed", type=int, default=42,
|
| 849 |
help="Random seed for shuffle (default: 42)")
|
| 850 |
-
ap.add_argument("--workers", "-w", type=int, default=4,
|
| 851 |
-
help="Number of parallel workers (default: 4, use 1 for sequential)")
|
|
|
|
|
|
|
| 852 |
ap.add_argument("--min-why", default="strong_implied",
|
| 853 |
choices=["explicit", "strong_implied", "weak_implied", "style_or_meta", "other", "none"],
|
| 854 |
help="Minimum 'why' confidence to keep (default: strong_implied). Use 'none' to disable filtering.")
|
|
@@ -875,12 +929,13 @@ def main(argv=None) -> int:
|
|
| 875 |
max_tokens=args.max_tokens,
|
| 876 |
verbose=args.verbose,
|
| 877 |
shuffle=args.shuffle,
|
| 878 |
-
seed=args.seed,
|
| 879 |
-
workers=args.workers,
|
| 880 |
-
min_why=min_why_val,
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
|
|
|
|
| 884 |
|
| 885 |
print_summary(results)
|
| 886 |
|
|
@@ -907,9 +962,10 @@ def main(argv=None) -> int:
|
|
| 907 |
"n_samples": len(results),
|
| 908 |
"caption_field": args.caption_field,
|
| 909 |
"skip_rewrite": args.skip_rewrite,
|
| 910 |
-
"allow_nsfw": args.allow_nsfw,
|
| 911 |
"mode": args.mode,
|
| 912 |
"chunk_size": args.chunk_size,
|
|
|
|
| 913 |
"per_phrase_k": args.per_phrase_k,
|
| 914 |
"per_phrase_final_k": args.per_phrase_final_k,
|
| 915 |
"temperature": args.temperature,
|
|
@@ -953,7 +1009,8 @@ def main(argv=None) -> int:
|
|
| 953 |
"ret_P": round(r.retrieval_precision, 4),
|
| 954 |
"sel_given_ret": round(r.selection_given_retrieval, 4),
|
| 955 |
"over_sel": round(r.over_selection_ratio, 2),
|
| 956 |
-
"why": r.why_counts,
|
|
|
|
| 957 |
# Character metrics (compact)
|
| 958 |
"n_gt_char": len(r.gt_character_tags),
|
| 959 |
"n_sel_char": len(r.selected_character_tags),
|
|
@@ -1005,8 +1062,9 @@ def main(argv=None) -> int:
|
|
| 1005 |
"implied_tags": sorted(r.implied_tags),
|
| 1006 |
"structural_tags": r.structural_tags,
|
| 1007 |
"categorized_suggestions": r.categorized_suggestions,
|
| 1008 |
-
"why_counts": r.why_counts,
|
| 1009 |
-
"
|
|
|
|
| 1010 |
"gt_character_tags": sorted(r.gt_character_tags),
|
| 1011 |
"selected_character_tags": sorted(r.selected_character_tags),
|
| 1012 |
"gt_general_tags": sorted(r.gt_general_tags),
|
|
|
|
| 162 |
selection_given_retrieval: float = 0.0 # |selected ∩ gt| / |retrieved ∩ gt|
|
| 163 |
over_selection_ratio: float = 0.0 # |selected| / |gt|
|
| 164 |
# Why distribution (from Stage 3 LLM)
|
| 165 |
+
why_counts: Dict[str, int] = field(default_factory=dict)
|
| 166 |
+
stage3_diag: Dict[str, Any] = field(default_factory=dict)
|
| 167 |
# Tag implications
|
| 168 |
implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
|
| 169 |
# Structural inference tags (solo/duo/male/female/anthro/biped etc.)
|
|
|
|
| 217 |
per_phrase_final_k: int,
|
| 218 |
temperature: float,
|
| 219 |
max_tokens: int,
|
| 220 |
+
verbose: bool,
|
| 221 |
+
print_lock: threading.Lock,
|
| 222 |
+
min_why: Optional[str] = None,
|
| 223 |
+
expand_implications: bool = False,
|
| 224 |
+
infer_structural: bool = False,
|
| 225 |
+
) -> SampleResult:
|
| 226 |
"""Process a single eval sample through the full pipeline. Thread-safe."""
|
| 227 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 228 |
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
|
|
|
|
| 301 |
|
| 302 |
# --- Stage 3: LLM Selection ---
|
| 303 |
t0 = time.time()
|
| 304 |
+
picked_indices, tag_why, stage3_diag = llm_select_indices(
|
| 305 |
+
query_text=caption,
|
| 306 |
+
candidates=candidates,
|
| 307 |
+
max_pick=0,
|
| 308 |
+
log=log,
|
| 309 |
+
mode=mode,
|
| 310 |
+
chunk_size=chunk_size,
|
| 311 |
+
per_phrase_k=per_phrase_k,
|
| 312 |
+
temperature=temperature,
|
| 313 |
+
max_tokens=max_tokens,
|
| 314 |
+
return_metadata=True,
|
| 315 |
+
return_diagnostics=True,
|
| 316 |
+
min_why=min_why,
|
| 317 |
+
)
|
| 318 |
+
result.stage3_time = time.time() - t0
|
| 319 |
+
result.stage3_diag = stage3_diag or {}
|
| 320 |
|
| 321 |
result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
|
| 322 |
result.stage3_selected_tags = set(result.selected_tags)
|
|
|
|
| 500 |
max_tokens: int = 512,
|
| 501 |
verbose: bool = False,
|
| 502 |
shuffle: bool = True,
|
| 503 |
+
seed: int = 42,
|
| 504 |
+
workers: int = 1,
|
| 505 |
+
min_why: Optional[str] = "strong_implied",
|
| 506 |
+
eval_path: Optional[str] = None,
|
| 507 |
+
expand_implications: bool = False,
|
| 508 |
infer_structural: bool = False,
|
| 509 |
) -> List[SampleResult]:
|
| 510 |
expand_gt = expand_implications
|
|
|
|
| 512 |
from psq_rag.retrieval.state import expand_tags_via_implications as _expand_gt_tags
|
| 513 |
|
| 514 |
# Load eval samples — prefer expanded file, fall back to raw
|
| 515 |
+
eval_path_obj = Path(eval_path) if eval_path else EVAL_DATA_PATH
|
| 516 |
+
if not eval_path_obj.is_absolute():
|
| 517 |
+
eval_path_obj = (_REPO_ROOT / eval_path_obj).resolve()
|
| 518 |
+
|
| 519 |
+
if not eval_path_obj.is_file() and eval_path is None:
|
| 520 |
+
eval_path_obj = EVAL_DATA_PATH_RAW
|
| 521 |
+
if not eval_path_obj.is_file():
|
| 522 |
+
print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
|
| 523 |
+
sys.exit(1)
|
| 524 |
+
print(f"WARNING: Expanded eval data not found, falling back to raw: {eval_path_obj}")
|
| 525 |
+
print(" Run: python scripts/preprocess_eval_data.py")
|
| 526 |
+
elif not eval_path_obj.is_file():
|
| 527 |
+
print(f"ERROR: Eval data not found: {eval_path_obj}")
|
| 528 |
+
sys.exit(1)
|
| 529 |
+
|
| 530 |
+
all_samples = []
|
| 531 |
+
using_expanded = False
|
| 532 |
+
with eval_path_obj.open("r", encoding="utf-8") as f:
|
| 533 |
+
for line in f:
|
| 534 |
+
row = json.loads(line)
|
| 535 |
caption = row.get(caption_field, "")
|
| 536 |
if not caption or not caption.strip():
|
| 537 |
continue
|
|
|
|
| 553 |
"caption": caption.strip(),
|
| 554 |
"gt_tags": gt_tags,
|
| 555 |
})
|
| 556 |
+
if using_expanded:
|
| 557 |
+
print("Using implication-expanded ground truth")
|
| 558 |
+
|
| 559 |
if shuffle:
|
| 560 |
rng = random.Random(seed)
|
| 561 |
rng.shuffle(all_samples)
|
| 562 |
|
| 563 |
samples = all_samples[:n_samples]
|
| 564 |
|
| 565 |
+
print(f"Loaded {len(samples)}/{len(all_samples)} samples (caption_field={caption_field})")
|
| 566 |
+
print(f"eval_path={eval_path_obj}")
|
| 567 |
+
print(f"shuffle={shuffle}, seed={seed}, skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}")
|
| 568 |
+
print(f"workers={workers}")
|
| 569 |
+
print()
|
| 570 |
|
| 571 |
# Pre-warm shared retrieval assets before spawning threads
|
| 572 |
_prewarm_retrieval_assets()
|
|
|
|
| 583 |
sample, i, total,
|
| 584 |
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 585 |
per_phrase_k, per_phrase_final_k, temperature, max_tokens, verbose,
|
| 586 |
+
print_lock, min_why,
|
| 587 |
+
expand_implications,
|
| 588 |
infer_structural,
|
| 589 |
)
|
| 590 |
+
results.append(result)
|
| 591 |
else:
|
| 592 |
# Parallel mode
|
| 593 |
print(f"Processing {total} samples with {workers} parallel workers...")
|
|
|
|
| 596 |
results_by_index: Dict[int, SampleResult] = {}
|
| 597 |
with ThreadPoolExecutor(max_workers=workers) as executor:
|
| 598 |
futures = {
|
| 599 |
+
executor.submit(
|
| 600 |
_process_one_sample,
|
| 601 |
sample, i, total,
|
| 602 |
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 603 |
per_phrase_k, per_phrase_final_k, temperature, max_tokens, verbose,
|
| 604 |
+
print_lock, min_why,
|
| 605 |
+
expand_implications,
|
| 606 |
infer_structural,
|
| 607 |
): i
|
| 608 |
for i, sample in enumerate(samples)
|
|
|
|
| 701 |
print(f" Avg leaf ground-truth:{avg_leaf_gt:.1f}")
|
| 702 |
|
| 703 |
print()
|
| 704 |
+
print("Diagnostic Metrics:")
|
| 705 |
+
print(f" Retrieval precision: {avg_retrieval_precision:.4f} (|ret∩gt|/|ret|, noise level fed to Stage 3)")
|
| 706 |
+
print(f" Sel-given-retrieval: {avg_sel_given_ret:.4f} (of gt tags retrieved, fraction kept by Stage 3)")
|
| 707 |
+
print(f" Over-selection ratio: {avg_over_sel:.2f}x (|selected|/|gt|, ideal ~1.0)")
|
| 708 |
+
|
| 709 |
+
stage3_diag_rows = [r.stage3_diag for r in valid if r.stage3_diag]
|
| 710 |
+
if stage3_diag_rows:
|
| 711 |
+
calls_total = sum(int(d.get("calls_total", 0)) for d in stage3_diag_rows)
|
| 712 |
+
calls_exhausted = sum(int(d.get("calls_exhausted_retries", 0)) for d in stage3_diag_rows)
|
| 713 |
+
attempts_total = sum(int(d.get("attempts_total", 0)) for d in stage3_diag_rows)
|
| 714 |
+
attempts_parse_fail = sum(int(d.get("attempt_parse_fail", 0)) for d in stage3_diag_rows)
|
| 715 |
+
attempts_errors = sum(int(d.get("attempt_errors", 0)) for d in stage3_diag_rows)
|
| 716 |
+
|
| 717 |
+
print()
|
| 718 |
+
print("Stage 3 Structured Output Reliability:")
|
| 719 |
+
print(f" Calls total: {calls_total}")
|
| 720 |
+
print(f" Calls exhausted: {calls_exhausted} ({(100 * calls_exhausted / calls_total) if calls_total else 0:.1f}%)")
|
| 721 |
+
print(f" Attempts total: {attempts_total}")
|
| 722 |
+
print(f" Parse/schema failures:{attempts_parse_fail} ({(100 * attempts_parse_fail / attempts_total) if attempts_total else 0:.1f}%)")
|
| 723 |
+
print(f" Call errors/exc: {attempts_errors} ({(100 * attempts_errors / attempts_total) if attempts_total else 0:.1f}%)")
|
| 724 |
+
|
| 725 |
+
by_n_agg: Dict[int, Dict[str, int]] = {}
|
| 726 |
+
for d in stage3_diag_rows:
|
| 727 |
+
for n_str, n_stats in d.get("attempts_by_n_local", {}).items():
|
| 728 |
+
try:
|
| 729 |
+
n_local = int(n_str)
|
| 730 |
+
except Exception:
|
| 731 |
+
continue
|
| 732 |
+
cur = by_n_agg.setdefault(n_local, {"attempts": 0, "parse_fail": 0, "errors": 0})
|
| 733 |
+
cur["attempts"] += int(n_stats.get("attempts", 0))
|
| 734 |
+
cur["parse_fail"] += int(n_stats.get("parse_fail", 0))
|
| 735 |
+
cur["errors"] += int(n_stats.get("errors", 0))
|
| 736 |
+
|
| 737 |
+
if by_n_agg:
|
| 738 |
+
print(" Failure by call size (N_local):")
|
| 739 |
+
for n_local in sorted(by_n_agg.keys()):
|
| 740 |
+
s = by_n_agg[n_local]
|
| 741 |
+
fail = s["parse_fail"] + s["errors"]
|
| 742 |
+
rate = (100 * fail / s["attempts"]) if s["attempts"] else 0.0
|
| 743 |
+
print(
|
| 744 |
+
f" N={n_local:3d} attempts={s['attempts']:4d} "
|
| 745 |
+
f"fail={fail:4d} ({rate:5.1f}%)"
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
# Why distribution across all samples
|
| 749 |
+
total_why: Dict[str, int] = {}
|
| 750 |
for r in valid:
|
| 751 |
for w, cnt in r.why_counts.items():
|
| 752 |
total_why[w] = total_why.get(w, 0) + cnt
|
|
|
|
| 882 |
ap.add_argument("--skip-rewrite", action="store_true",
|
| 883 |
help="Skip Stage 1 LLM rewrite; split caption directly into phrases")
|
| 884 |
ap.add_argument("--allow-nsfw", action="store_true", help="Allow NSFW tags")
|
| 885 |
+
ap.add_argument("--mode", default="chunked_map_union",
|
| 886 |
+
choices=["single_shot", "chunked_map_union"])
|
| 887 |
ap.add_argument("--chunk-size", type=int, default=60)
|
| 888 |
ap.add_argument("--per-phrase-k", type=int, default=2)
|
| 889 |
ap.add_argument("--per-phrase-final-k", type=int, default=10,
|
|
|
|
| 899 |
help="Use samples in file order (first N)")
|
| 900 |
ap.add_argument("--seed", type=int, default=42,
|
| 901 |
help="Random seed for shuffle (default: 42)")
|
| 902 |
+
ap.add_argument("--workers", "-w", type=int, default=4,
|
| 903 |
+
help="Number of parallel workers (default: 4, use 1 for sequential)")
|
| 904 |
+
ap.add_argument("--eval-path", type=str, default=None,
|
| 905 |
+
help="Optional path to eval JSONL (defaults to expanded 1000-sample set).")
|
| 906 |
ap.add_argument("--min-why", default="strong_implied",
|
| 907 |
choices=["explicit", "strong_implied", "weak_implied", "style_or_meta", "other", "none"],
|
| 908 |
help="Minimum 'why' confidence to keep (default: strong_implied). Use 'none' to disable filtering.")
|
|
|
|
| 929 |
max_tokens=args.max_tokens,
|
| 930 |
verbose=args.verbose,
|
| 931 |
shuffle=args.shuffle,
|
| 932 |
+
seed=args.seed,
|
| 933 |
+
workers=args.workers,
|
| 934 |
+
min_why=min_why_val,
|
| 935 |
+
eval_path=args.eval_path,
|
| 936 |
+
expand_implications=args.expand_implications,
|
| 937 |
+
infer_structural=args.infer_structural,
|
| 938 |
+
)
|
| 939 |
|
| 940 |
print_summary(results)
|
| 941 |
|
|
|
|
| 962 |
"n_samples": len(results),
|
| 963 |
"caption_field": args.caption_field,
|
| 964 |
"skip_rewrite": args.skip_rewrite,
|
| 965 |
+
"allow_nsfw": args.allow_nsfw,
|
| 966 |
"mode": args.mode,
|
| 967 |
"chunk_size": args.chunk_size,
|
| 968 |
+
"eval_path": args.eval_path,
|
| 969 |
"per_phrase_k": args.per_phrase_k,
|
| 970 |
"per_phrase_final_k": args.per_phrase_final_k,
|
| 971 |
"temperature": args.temperature,
|
|
|
|
| 1009 |
"ret_P": round(r.retrieval_precision, 4),
|
| 1010 |
"sel_given_ret": round(r.selection_given_retrieval, 4),
|
| 1011 |
"over_sel": round(r.over_selection_ratio, 2),
|
| 1012 |
+
"why": r.why_counts,
|
| 1013 |
+
"stage3_diag": r.stage3_diag,
|
| 1014 |
# Character metrics (compact)
|
| 1015 |
"n_gt_char": len(r.gt_character_tags),
|
| 1016 |
"n_sel_char": len(r.selected_character_tags),
|
|
|
|
| 1062 |
"implied_tags": sorted(r.implied_tags),
|
| 1063 |
"structural_tags": r.structural_tags,
|
| 1064 |
"categorized_suggestions": r.categorized_suggestions,
|
| 1065 |
+
"why_counts": r.why_counts,
|
| 1066 |
+
"stage3_diag": r.stage3_diag,
|
| 1067 |
+
"tag_evidence": r.tag_evidence,
|
| 1068 |
"gt_character_tags": sorted(r.gt_character_tags),
|
| 1069 |
"selected_character_tags": sorted(r.selected_character_tags),
|
| 1070 |
"gt_general_tags": sorted(r.gt_general_tags),
|