Food Desert commited on
Commit
3c18372
·
1 Parent(s): 73f56cf

Simplify Stage3 chunking to interleave-only and add eval diagnostics

Browse files
Files changed (2) hide show
  1. psq_rag/llm/select.py +88 -14
  2. scripts/eval_pipeline.py +129 -71
psq_rag/llm/select.py CHANGED
@@ -253,6 +253,13 @@ def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
253
  return out
254
 
255
 
 
 
 
 
 
 
 
256
  def _display_tag(tag: str) -> str:
257
  # Display tags with spaces for the LLM, but keep canonical underscores internally.
258
  return tag.replace("_", " ")
@@ -494,8 +501,13 @@ def llm_select_indices(
494
  temperature: float = 0.0,
495
  max_tokens: int = 512,
496
  return_metadata: bool = False,
 
497
  min_why: Optional[str] = "strong_implied",
498
- ) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
 
 
 
 
499
  """Return indices into the ORIGINAL candidates list (legacy interface).
500
 
501
  min_why: if set, only keep tags whose 'why' is at or above this confidence
@@ -586,6 +598,42 @@ def llm_select_indices(
586
 
587
  # Global union: tag -> best (score, why)
588
  best: Dict[str, Tuple[float, str]] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
 
590
  def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
591
  # Create chain with the provided system template
@@ -598,9 +646,10 @@ def llm_select_indices(
598
  )
599
  chain = prompt | llm | parser
600
 
601
- ordered = _interleave_round_robin(call_cands)
602
  candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
603
  N_local = len(idx_to_tag)
 
604
 
605
  phrases = _phrases_in_call(call_cands)
606
  per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
@@ -618,6 +667,7 @@ def llm_select_indices(
618
  # Invoke LangChain chain (templating fills {N} and other vars)
619
  for att in range(retries + 1):
620
  try:
 
621
  if log:
622
  log(
623
  f"Stage3 {label}: "
@@ -637,6 +687,16 @@ def llm_select_indices(
637
  }
638
  )
639
  selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
 
 
 
 
 
 
 
 
 
 
640
  if log:
641
  log(f"Stage3 {label}: attempt {att+1} diag={diag}")
642
  if not summary_logged and (selected or att == retries):
@@ -662,6 +722,7 @@ def llm_select_indices(
662
  log(f"Stage3 {label} selections: (none)")
663
 
664
  if selected:
 
665
  for s in selected:
666
  prev = best.get(s.tag)
667
  if prev is None or s.score > prev[0]:
@@ -669,11 +730,14 @@ def llm_select_indices(
669
  return
670
 
671
  except Exception as e:
 
 
672
  if log:
673
  log(f"Stage3 {label}: attempt {att+1} error: {e}")
674
 
675
  if log:
676
  log(f"Stage3 {label}: gave up after {retries+1} attempts")
 
677
 
678
  # Split candidates by type (general vs entity)
679
  general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)
@@ -687,12 +751,9 @@ def llm_select_indices(
687
  if mode == "single_shot":
688
  run_call(general_cands, "general_single_shot", SELECT_SYSTEM_TEMPLATE)
689
  else:
690
- for start in range(0, len(general_cands), chunk_size):
691
- run_call(
692
- general_cands[start:start + chunk_size],
693
- f"general_chunk_{start//chunk_size}",
694
- SELECT_SYSTEM_TEMPLATE
695
- )
696
 
697
  # Process entity candidates (characters only) with alias-based pre-filtering
698
  if entity_cands:
@@ -725,12 +786,9 @@ def llm_select_indices(
725
  if mode == "single_shot":
726
  run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE)
727
  else:
728
- for start in range(0, len(filtered_entity_cands), chunk_size):
729
- run_call(
730
- filtered_entity_cands[start:start + chunk_size],
731
- f"entity_chunk_{start//chunk_size}",
732
- ENTITY_SYSTEM_TEMPLATE
733
- )
734
 
735
  # Apply why threshold: drop tags below the minimum confidence level.
736
  if min_why is not None:
@@ -757,7 +815,23 @@ def llm_select_indices(
757
  out_idx.append(tag_to_first_index[t])
758
  tag_why[t] = best[t][1] # why string
759
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
760
  if return_metadata:
 
 
761
  return out_idx, tag_why
762
 
763
  return out_idx
 
253
  return out
254
 
255
 
256
+ def _build_chunks(cands: Sequence[Candidate], chunk_size: int) -> List[List[Candidate]]:
257
+ if chunk_size <= 0:
258
+ raise ValueError(f"chunk_size must be > 0, got {chunk_size}")
259
+ ordered = _interleave_round_robin(cands)
260
+ return [ordered[i:i + chunk_size] for i in range(0, len(ordered), chunk_size)]
261
+
262
+
263
  def _display_tag(tag: str) -> str:
264
  # Display tags with spaces for the LLM, but keep canonical underscores internally.
265
  return tag.replace("_", " ")
 
501
  temperature: float = 0.0,
502
  max_tokens: int = 512,
503
  return_metadata: bool = False,
504
+ return_diagnostics: bool = False,
505
  min_why: Optional[str] = "strong_implied",
506
+ ) -> Union[
507
+ List[int],
508
+ Tuple[List[int], Dict[str, str]],
509
+ Tuple[List[int], Dict[str, str], Dict[str, Any]],
510
+ ]:
511
  """Return indices into the ORIGINAL candidates list (legacy interface).
512
 
513
  min_why: if set, only keep tags whose 'why' is at or above this confidence
 
598
 
599
  # Global union: tag -> best (score, why)
600
  best: Dict[str, Tuple[float, str]] = {}
601
+ diagnostics: Dict[str, Any] = {
602
+ "mode": mode,
603
+ "chunk_strategy": "interleave",
604
+ "chunk_passes": 1,
605
+ "chunk_shuffle_within_call": False,
606
+ "calls_total": 0,
607
+ "calls_with_selection": 0,
608
+ "calls_exhausted_retries": 0,
609
+ "attempts_total": 0,
610
+ "attempt_errors": 0,
611
+ "attempt_parse_fail": 0,
612
+ "attempt_parse_ok": 0,
613
+ "invalid_items_total": 0,
614
+ "oob_indices_total": 0,
615
+ "dupe_indices_total": 0,
616
+ "kept_total": 0,
617
+ "attempts_by_n_local": {},
618
+ }
619
+
620
+ def _record_attempt_for_n(n_local: int, *, parse_ok: bool, error: bool) -> None:
621
+ by_n = diagnostics["attempts_by_n_local"]
622
+ key = str(n_local)
623
+ if key not in by_n:
624
+ by_n[key] = {
625
+ "attempts": 0,
626
+ "parse_ok": 0,
627
+ "parse_fail": 0,
628
+ "errors": 0,
629
+ }
630
+ by_n[key]["attempts"] += 1
631
+ if error:
632
+ by_n[key]["errors"] += 1
633
+ elif parse_ok:
634
+ by_n[key]["parse_ok"] += 1
635
+ else:
636
+ by_n[key]["parse_fail"] += 1
637
 
638
  def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
639
  # Create chain with the provided system template
 
646
  )
647
  chain = prompt | llm | parser
648
 
649
+ ordered = _interleave_round_robin(call_cands) if mode == "single_shot" else list(call_cands)
650
  candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
651
  N_local = len(idx_to_tag)
652
+ diagnostics["calls_total"] += 1
653
 
654
  phrases = _phrases_in_call(call_cands)
655
  per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
 
667
  # Invoke LangChain chain (templating fills {N} and other vars)
668
  for att in range(retries + 1):
669
  try:
670
+ diagnostics["attempts_total"] += 1
671
  if log:
672
  log(
673
  f"Stage3 {label}: "
 
687
  }
688
  )
689
  selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
690
+ diagnostics["invalid_items_total"] += int(diag.get("invalid_items", 0))
691
+ diagnostics["oob_indices_total"] += int(diag.get("oob_indices", 0))
692
+ diagnostics["dupe_indices_total"] += int(diag.get("dupe_indices", 0))
693
+ diagnostics["kept_total"] += int(diag.get("kept", 0))
694
+ if bool(diag.get("parse_ok", False)):
695
+ diagnostics["attempt_parse_ok"] += 1
696
+ _record_attempt_for_n(N_local, parse_ok=True, error=False)
697
+ else:
698
+ diagnostics["attempt_parse_fail"] += 1
699
+ _record_attempt_for_n(N_local, parse_ok=False, error=False)
700
  if log:
701
  log(f"Stage3 {label}: attempt {att+1} diag={diag}")
702
  if not summary_logged and (selected or att == retries):
 
722
  log(f"Stage3 {label} selections: (none)")
723
 
724
  if selected:
725
+ diagnostics["calls_with_selection"] += 1
726
  for s in selected:
727
  prev = best.get(s.tag)
728
  if prev is None or s.score > prev[0]:
 
730
  return
731
 
732
  except Exception as e:
733
+ diagnostics["attempt_errors"] += 1
734
+ _record_attempt_for_n(N_local, parse_ok=False, error=True)
735
  if log:
736
  log(f"Stage3 {label}: attempt {att+1} error: {e}")
737
 
738
  if log:
739
  log(f"Stage3 {label}: gave up after {retries+1} attempts")
740
+ diagnostics["calls_exhausted_retries"] += 1
741
 
742
  # Split candidates by type (general vs entity)
743
  general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)
 
751
  if mode == "single_shot":
752
  run_call(general_cands, "general_single_shot", SELECT_SYSTEM_TEMPLATE)
753
  else:
754
+ base_chunks = _build_chunks(general_cands, chunk_size)
755
+ for chunk_idx, chunk in enumerate(base_chunks):
756
+ run_call(chunk, f"general_chunk_{chunk_idx}", SELECT_SYSTEM_TEMPLATE)
 
 
 
757
 
758
  # Process entity candidates (characters only) with alias-based pre-filtering
759
  if entity_cands:
 
786
  if mode == "single_shot":
787
  run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE)
788
  else:
789
+ base_chunks = _build_chunks(filtered_entity_cands, chunk_size)
790
+ for chunk_idx, chunk in enumerate(base_chunks):
791
+ run_call(chunk, f"entity_chunk_{chunk_idx}", ENTITY_SYSTEM_TEMPLATE)
 
 
 
792
 
793
  # Apply why threshold: drop tags below the minimum confidence level.
794
  if min_why is not None:
 
815
  out_idx.append(tag_to_first_index[t])
816
  tag_why[t] = best[t][1] # why string
817
 
818
+ if diagnostics["attempts_total"] > 0:
819
+ diagnostics["attempt_failure_rate"] = (
820
+ diagnostics["attempt_parse_fail"] + diagnostics["attempt_errors"]
821
+ ) / diagnostics["attempts_total"]
822
+ else:
823
+ diagnostics["attempt_failure_rate"] = 0.0
824
+
825
+ if diagnostics["calls_total"] > 0:
826
+ diagnostics["call_exhaustion_rate"] = (
827
+ diagnostics["calls_exhausted_retries"] / diagnostics["calls_total"]
828
+ )
829
+ else:
830
+ diagnostics["call_exhaustion_rate"] = 0.0
831
+
832
  if return_metadata:
833
+ if return_diagnostics:
834
+ return out_idx, tag_why, diagnostics
835
  return out_idx, tag_why
836
 
837
  return out_idx
scripts/eval_pipeline.py CHANGED
@@ -162,7 +162,8 @@ class SampleResult:
162
  selection_given_retrieval: float = 0.0 # |selected ∩ gt| / |retrieved ∩ gt|
163
  over_selection_ratio: float = 0.0 # |selected| / |gt|
164
  # Why distribution (from Stage 3 LLM)
165
- why_counts: Dict[str, int] = field(default_factory=dict)
 
166
  # Tag implications
167
  implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
168
  # Structural inference tags (solo/duo/male/female/anthro/biped etc.)
@@ -216,12 +217,12 @@ def _process_one_sample(
216
  per_phrase_final_k: int,
217
  temperature: float,
218
  max_tokens: int,
219
- verbose: bool,
220
- print_lock: threading.Lock,
221
- min_why: Optional[str] = None,
222
- expand_implications: bool = False,
223
- infer_structural: bool = False,
224
- ) -> SampleResult:
225
  """Process a single eval sample through the full pipeline. Thread-safe."""
226
  from psq_rag.llm.rewrite import llm_rewrite_prompt
227
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
@@ -300,20 +301,22 @@ def _process_one_sample(
300
 
301
  # --- Stage 3: LLM Selection ---
302
  t0 = time.time()
303
- picked_indices, tag_why = llm_select_indices(
304
- query_text=caption,
305
- candidates=candidates,
306
- max_pick=0,
307
- log=log,
308
- mode=mode,
309
- chunk_size=chunk_size,
310
- per_phrase_k=per_phrase_k,
311
- temperature=temperature,
312
- max_tokens=max_tokens,
313
- return_metadata=True,
314
- min_why=min_why,
315
- )
316
- result.stage3_time = time.time() - t0
 
 
317
 
318
  result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
319
  result.stage3_selected_tags = set(result.selected_tags)
@@ -497,10 +500,11 @@ def run_eval(
497
  max_tokens: int = 512,
498
  verbose: bool = False,
499
  shuffle: bool = True,
500
- seed: int = 42,
501
- workers: int = 1,
502
- min_why: Optional[str] = "strong_implied",
503
- expand_implications: bool = False,
 
504
  infer_structural: bool = False,
505
  ) -> List[SampleResult]:
506
  expand_gt = expand_implications
@@ -508,20 +512,26 @@ def run_eval(
508
  from psq_rag.retrieval.state import expand_tags_via_implications as _expand_gt_tags
509
 
510
  # Load eval samples — prefer expanded file, fall back to raw
511
- eval_path = EVAL_DATA_PATH
512
- if not eval_path.is_file():
513
- eval_path = EVAL_DATA_PATH_RAW
514
- if not eval_path.is_file():
515
- print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
516
- sys.exit(1)
517
- print(f"WARNING: Expanded eval data not found, falling back to raw: {eval_path}")
518
- print(" Run: python scripts/preprocess_eval_data.py")
519
-
520
- all_samples = []
521
- using_expanded = False
522
- with eval_path.open("r", encoding="utf-8") as f:
523
- for line in f:
524
- row = json.loads(line)
 
 
 
 
 
 
525
  caption = row.get(caption_field, "")
526
  if not caption or not caption.strip():
527
  continue
@@ -543,19 +553,20 @@ def run_eval(
543
  "caption": caption.strip(),
544
  "gt_tags": gt_tags,
545
  })
546
- if using_expanded:
547
- print("Using implication-expanded ground truth")
548
-
549
  if shuffle:
550
  rng = random.Random(seed)
551
  rng.shuffle(all_samples)
552
 
553
  samples = all_samples[:n_samples]
554
 
555
- print(f"Loaded {len(samples)}/{len(all_samples)} samples (caption_field={caption_field})")
556
- print(f"shuffle={shuffle}, seed={seed}, skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}")
557
- print(f"workers={workers}")
558
- print()
 
559
 
560
  # Pre-warm shared retrieval assets before spawning threads
561
  _prewarm_retrieval_assets()
@@ -572,10 +583,11 @@ def run_eval(
572
  sample, i, total,
573
  skip_rewrite, allow_nsfw, mode, chunk_size,
574
  per_phrase_k, per_phrase_final_k, temperature, max_tokens, verbose,
575
- print_lock, min_why, expand_implications,
 
576
  infer_structural,
577
  )
578
- results.append(result)
579
  else:
580
  # Parallel mode
581
  print(f"Processing {total} samples with {workers} parallel workers...")
@@ -584,12 +596,13 @@ def run_eval(
584
  results_by_index: Dict[int, SampleResult] = {}
585
  with ThreadPoolExecutor(max_workers=workers) as executor:
586
  futures = {
587
- executor.submit(
588
  _process_one_sample,
589
  sample, i, total,
590
  skip_rewrite, allow_nsfw, mode, chunk_size,
591
  per_phrase_k, per_phrase_final_k, temperature, max_tokens, verbose,
592
- print_lock, min_why, expand_implications,
 
593
  infer_structural,
594
  ): i
595
  for i, sample in enumerate(samples)
@@ -688,13 +701,52 @@ def print_summary(results: List[SampleResult]) -> None:
688
  print(f" Avg leaf ground-truth:{avg_leaf_gt:.1f}")
689
 
690
  print()
691
- print("Diagnostic Metrics:")
692
- print(f" Retrieval precision: {avg_retrieval_precision:.4f} (|ret∩gt|/|ret|, noise level fed to Stage 3)")
693
- print(f" Sel-given-retrieval: {avg_sel_given_ret:.4f} (of gt tags retrieved, fraction kept by Stage 3)")
694
- print(f" Over-selection ratio: {avg_over_sel:.2f}x (|selected|/|gt|, ideal ~1.0)")
695
-
696
- # Why distribution across all samples
697
- total_why: Dict[str, int] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
698
  for r in valid:
699
  for w, cnt in r.why_counts.items():
700
  total_why[w] = total_why.get(w, 0) + cnt
@@ -830,8 +882,8 @@ def main(argv=None) -> int:
830
  ap.add_argument("--skip-rewrite", action="store_true",
831
  help="Skip Stage 1 LLM rewrite; split caption directly into phrases")
832
  ap.add_argument("--allow-nsfw", action="store_true", help="Allow NSFW tags")
833
- ap.add_argument("--mode", default="chunked_map_union",
834
- choices=["single_shot", "chunked_map_union"])
835
  ap.add_argument("--chunk-size", type=int, default=60)
836
  ap.add_argument("--per-phrase-k", type=int, default=2)
837
  ap.add_argument("--per-phrase-final-k", type=int, default=10,
@@ -847,8 +899,10 @@ def main(argv=None) -> int:
847
  help="Use samples in file order (first N)")
848
  ap.add_argument("--seed", type=int, default=42,
849
  help="Random seed for shuffle (default: 42)")
850
- ap.add_argument("--workers", "-w", type=int, default=4,
851
- help="Number of parallel workers (default: 4, use 1 for sequential)")
 
 
852
  ap.add_argument("--min-why", default="strong_implied",
853
  choices=["explicit", "strong_implied", "weak_implied", "style_or_meta", "other", "none"],
854
  help="Minimum 'why' confidence to keep (default: strong_implied). Use 'none' to disable filtering.")
@@ -875,12 +929,13 @@ def main(argv=None) -> int:
875
  max_tokens=args.max_tokens,
876
  verbose=args.verbose,
877
  shuffle=args.shuffle,
878
- seed=args.seed,
879
- workers=args.workers,
880
- min_why=min_why_val,
881
- expand_implications=args.expand_implications,
882
- infer_structural=args.infer_structural,
883
- )
 
884
 
885
  print_summary(results)
886
 
@@ -907,9 +962,10 @@ def main(argv=None) -> int:
907
  "n_samples": len(results),
908
  "caption_field": args.caption_field,
909
  "skip_rewrite": args.skip_rewrite,
910
- "allow_nsfw": args.allow_nsfw,
911
  "mode": args.mode,
912
  "chunk_size": args.chunk_size,
 
913
  "per_phrase_k": args.per_phrase_k,
914
  "per_phrase_final_k": args.per_phrase_final_k,
915
  "temperature": args.temperature,
@@ -953,7 +1009,8 @@ def main(argv=None) -> int:
953
  "ret_P": round(r.retrieval_precision, 4),
954
  "sel_given_ret": round(r.selection_given_retrieval, 4),
955
  "over_sel": round(r.over_selection_ratio, 2),
956
- "why": r.why_counts,
 
957
  # Character metrics (compact)
958
  "n_gt_char": len(r.gt_character_tags),
959
  "n_sel_char": len(r.selected_character_tags),
@@ -1005,8 +1062,9 @@ def main(argv=None) -> int:
1005
  "implied_tags": sorted(r.implied_tags),
1006
  "structural_tags": r.structural_tags,
1007
  "categorized_suggestions": r.categorized_suggestions,
1008
- "why_counts": r.why_counts,
1009
- "tag_evidence": r.tag_evidence,
 
1010
  "gt_character_tags": sorted(r.gt_character_tags),
1011
  "selected_character_tags": sorted(r.selected_character_tags),
1012
  "gt_general_tags": sorted(r.gt_general_tags),
 
162
  selection_given_retrieval: float = 0.0 # |selected ∩ gt| / |retrieved ∩ gt|
163
  over_selection_ratio: float = 0.0 # |selected| / |gt|
164
  # Why distribution (from Stage 3 LLM)
165
+ why_counts: Dict[str, int] = field(default_factory=dict)
166
+ stage3_diag: Dict[str, Any] = field(default_factory=dict)
167
  # Tag implications
168
  implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
169
  # Structural inference tags (solo/duo/male/female/anthro/biped etc.)
 
217
  per_phrase_final_k: int,
218
  temperature: float,
219
  max_tokens: int,
220
+ verbose: bool,
221
+ print_lock: threading.Lock,
222
+ min_why: Optional[str] = None,
223
+ expand_implications: bool = False,
224
+ infer_structural: bool = False,
225
+ ) -> SampleResult:
226
  """Process a single eval sample through the full pipeline. Thread-safe."""
227
  from psq_rag.llm.rewrite import llm_rewrite_prompt
228
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
 
301
 
302
  # --- Stage 3: LLM Selection ---
303
  t0 = time.time()
304
+ picked_indices, tag_why, stage3_diag = llm_select_indices(
305
+ query_text=caption,
306
+ candidates=candidates,
307
+ max_pick=0,
308
+ log=log,
309
+ mode=mode,
310
+ chunk_size=chunk_size,
311
+ per_phrase_k=per_phrase_k,
312
+ temperature=temperature,
313
+ max_tokens=max_tokens,
314
+ return_metadata=True,
315
+ return_diagnostics=True,
316
+ min_why=min_why,
317
+ )
318
+ result.stage3_time = time.time() - t0
319
+ result.stage3_diag = stage3_diag or {}
320
 
321
  result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
322
  result.stage3_selected_tags = set(result.selected_tags)
 
500
  max_tokens: int = 512,
501
  verbose: bool = False,
502
  shuffle: bool = True,
503
+ seed: int = 42,
504
+ workers: int = 1,
505
+ min_why: Optional[str] = "strong_implied",
506
+ eval_path: Optional[str] = None,
507
+ expand_implications: bool = False,
508
  infer_structural: bool = False,
509
  ) -> List[SampleResult]:
510
  expand_gt = expand_implications
 
512
  from psq_rag.retrieval.state import expand_tags_via_implications as _expand_gt_tags
513
 
514
  # Load eval samples — prefer expanded file, fall back to raw
515
+ eval_path_obj = Path(eval_path) if eval_path else EVAL_DATA_PATH
516
+ if not eval_path_obj.is_absolute():
517
+ eval_path_obj = (_REPO_ROOT / eval_path_obj).resolve()
518
+
519
+ if not eval_path_obj.is_file() and eval_path is None:
520
+ eval_path_obj = EVAL_DATA_PATH_RAW
521
+ if not eval_path_obj.is_file():
522
+ print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
523
+ sys.exit(1)
524
+ print(f"WARNING: Expanded eval data not found, falling back to raw: {eval_path_obj}")
525
+ print(" Run: python scripts/preprocess_eval_data.py")
526
+ elif not eval_path_obj.is_file():
527
+ print(f"ERROR: Eval data not found: {eval_path_obj}")
528
+ sys.exit(1)
529
+
530
+ all_samples = []
531
+ using_expanded = False
532
+ with eval_path_obj.open("r", encoding="utf-8") as f:
533
+ for line in f:
534
+ row = json.loads(line)
535
  caption = row.get(caption_field, "")
536
  if not caption or not caption.strip():
537
  continue
 
553
  "caption": caption.strip(),
554
  "gt_tags": gt_tags,
555
  })
556
+ if using_expanded:
557
+ print("Using implication-expanded ground truth")
558
+
559
  if shuffle:
560
  rng = random.Random(seed)
561
  rng.shuffle(all_samples)
562
 
563
  samples = all_samples[:n_samples]
564
 
565
+ print(f"Loaded {len(samples)}/{len(all_samples)} samples (caption_field={caption_field})")
566
+ print(f"eval_path={eval_path_obj}")
567
+ print(f"shuffle={shuffle}, seed={seed}, skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}")
568
+ print(f"workers={workers}")
569
+ print()
570
 
571
  # Pre-warm shared retrieval assets before spawning threads
572
  _prewarm_retrieval_assets()
 
583
  sample, i, total,
584
  skip_rewrite, allow_nsfw, mode, chunk_size,
585
  per_phrase_k, per_phrase_final_k, temperature, max_tokens, verbose,
586
+ print_lock, min_why,
587
+ expand_implications,
588
  infer_structural,
589
  )
590
+ results.append(result)
591
  else:
592
  # Parallel mode
593
  print(f"Processing {total} samples with {workers} parallel workers...")
 
596
  results_by_index: Dict[int, SampleResult] = {}
597
  with ThreadPoolExecutor(max_workers=workers) as executor:
598
  futures = {
599
+ executor.submit(
600
  _process_one_sample,
601
  sample, i, total,
602
  skip_rewrite, allow_nsfw, mode, chunk_size,
603
  per_phrase_k, per_phrase_final_k, temperature, max_tokens, verbose,
604
+ print_lock, min_why,
605
+ expand_implications,
606
  infer_structural,
607
  ): i
608
  for i, sample in enumerate(samples)
 
701
  print(f" Avg leaf ground-truth:{avg_leaf_gt:.1f}")
702
 
703
  print()
704
+ print("Diagnostic Metrics:")
705
+ print(f" Retrieval precision: {avg_retrieval_precision:.4f} (|ret∩gt|/|ret|, noise level fed to Stage 3)")
706
+ print(f" Sel-given-retrieval: {avg_sel_given_ret:.4f} (of gt tags retrieved, fraction kept by Stage 3)")
707
+ print(f" Over-selection ratio: {avg_over_sel:.2f}x (|selected|/|gt|, ideal ~1.0)")
708
+
709
+ stage3_diag_rows = [r.stage3_diag for r in valid if r.stage3_diag]
710
+ if stage3_diag_rows:
711
+ calls_total = sum(int(d.get("calls_total", 0)) for d in stage3_diag_rows)
712
+ calls_exhausted = sum(int(d.get("calls_exhausted_retries", 0)) for d in stage3_diag_rows)
713
+ attempts_total = sum(int(d.get("attempts_total", 0)) for d in stage3_diag_rows)
714
+ attempts_parse_fail = sum(int(d.get("attempt_parse_fail", 0)) for d in stage3_diag_rows)
715
+ attempts_errors = sum(int(d.get("attempt_errors", 0)) for d in stage3_diag_rows)
716
+
717
+ print()
718
+ print("Stage 3 Structured Output Reliability:")
719
+ print(f" Calls total: {calls_total}")
720
+ print(f" Calls exhausted: {calls_exhausted} ({(100 * calls_exhausted / calls_total) if calls_total else 0:.1f}%)")
721
+ print(f" Attempts total: {attempts_total}")
722
+ print(f" Parse/schema failures:{attempts_parse_fail} ({(100 * attempts_parse_fail / attempts_total) if attempts_total else 0:.1f}%)")
723
+ print(f" Call errors/exc: {attempts_errors} ({(100 * attempts_errors / attempts_total) if attempts_total else 0:.1f}%)")
724
+
725
+ by_n_agg: Dict[int, Dict[str, int]] = {}
726
+ for d in stage3_diag_rows:
727
+ for n_str, n_stats in d.get("attempts_by_n_local", {}).items():
728
+ try:
729
+ n_local = int(n_str)
730
+ except Exception:
731
+ continue
732
+ cur = by_n_agg.setdefault(n_local, {"attempts": 0, "parse_fail": 0, "errors": 0})
733
+ cur["attempts"] += int(n_stats.get("attempts", 0))
734
+ cur["parse_fail"] += int(n_stats.get("parse_fail", 0))
735
+ cur["errors"] += int(n_stats.get("errors", 0))
736
+
737
+ if by_n_agg:
738
+ print(" Failure by call size (N_local):")
739
+ for n_local in sorted(by_n_agg.keys()):
740
+ s = by_n_agg[n_local]
741
+ fail = s["parse_fail"] + s["errors"]
742
+ rate = (100 * fail / s["attempts"]) if s["attempts"] else 0.0
743
+ print(
744
+ f" N={n_local:3d} attempts={s['attempts']:4d} "
745
+ f"fail={fail:4d} ({rate:5.1f}%)"
746
+ )
747
+
748
+ # Why distribution across all samples
749
+ total_why: Dict[str, int] = {}
750
  for r in valid:
751
  for w, cnt in r.why_counts.items():
752
  total_why[w] = total_why.get(w, 0) + cnt
 
882
  ap.add_argument("--skip-rewrite", action="store_true",
883
  help="Skip Stage 1 LLM rewrite; split caption directly into phrases")
884
  ap.add_argument("--allow-nsfw", action="store_true", help="Allow NSFW tags")
885
+ ap.add_argument("--mode", default="chunked_map_union",
886
+ choices=["single_shot", "chunked_map_union"])
887
  ap.add_argument("--chunk-size", type=int, default=60)
888
  ap.add_argument("--per-phrase-k", type=int, default=2)
889
  ap.add_argument("--per-phrase-final-k", type=int, default=10,
 
899
  help="Use samples in file order (first N)")
900
  ap.add_argument("--seed", type=int, default=42,
901
  help="Random seed for shuffle (default: 42)")
902
+ ap.add_argument("--workers", "-w", type=int, default=4,
903
+ help="Number of parallel workers (default: 4, use 1 for sequential)")
904
+ ap.add_argument("--eval-path", type=str, default=None,
905
+ help="Optional path to eval JSONL (defaults to expanded 1000-sample set).")
906
  ap.add_argument("--min-why", default="strong_implied",
907
  choices=["explicit", "strong_implied", "weak_implied", "style_or_meta", "other", "none"],
908
  help="Minimum 'why' confidence to keep (default: strong_implied). Use 'none' to disable filtering.")
 
929
  max_tokens=args.max_tokens,
930
  verbose=args.verbose,
931
  shuffle=args.shuffle,
932
+ seed=args.seed,
933
+ workers=args.workers,
934
+ min_why=min_why_val,
935
+ eval_path=args.eval_path,
936
+ expand_implications=args.expand_implications,
937
+ infer_structural=args.infer_structural,
938
+ )
939
 
940
  print_summary(results)
941
 
 
962
  "n_samples": len(results),
963
  "caption_field": args.caption_field,
964
  "skip_rewrite": args.skip_rewrite,
965
+ "allow_nsfw": args.allow_nsfw,
966
  "mode": args.mode,
967
  "chunk_size": args.chunk_size,
968
+ "eval_path": args.eval_path,
969
  "per_phrase_k": args.per_phrase_k,
970
  "per_phrase_final_k": args.per_phrase_final_k,
971
  "temperature": args.temperature,
 
1009
  "ret_P": round(r.retrieval_precision, 4),
1010
  "sel_given_ret": round(r.selection_given_retrieval, 4),
1011
  "over_sel": round(r.over_selection_ratio, 2),
1012
+ "why": r.why_counts,
1013
+ "stage3_diag": r.stage3_diag,
1014
  # Character metrics (compact)
1015
  "n_gt_char": len(r.gt_character_tags),
1016
  "n_sel_char": len(r.selected_character_tags),
 
1062
  "implied_tags": sorted(r.implied_tags),
1063
  "structural_tags": r.structural_tags,
1064
  "categorized_suggestions": r.categorized_suggestions,
1065
+ "why_counts": r.why_counts,
1066
+ "stage3_diag": r.stage3_diag,
1067
+ "tag_evidence": r.tag_evidence,
1068
  "gt_character_tags": sorted(r.gt_character_tags),
1069
  "selected_character_tags": sorted(r.selected_character_tags),
1070
  "gt_general_tags": sorted(r.gt_general_tags),