Claude committed on
Commit
09a248d
·
1 Parent(s): 962e2b4

Add --min-why threshold to filter Stage 3 selections by confidence level

Browse files

select.py: add a WHY_RANK ordinal mapping and a min_why parameter that
filters the best{} dict before output. E.g. min_why="explicit" keeps only
explicitly matched tags; min_why="strong_implied" keeps explicit + strong_implied.

eval_pipeline.py: --min-why CLI arg threaded through to llm_select_indices.

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG

Files changed (2) hide show
  1. psq_rag/llm/select.py +23 -0
  2. scripts/eval_pipeline.py +9 -1
psq_rag/llm/select.py CHANGED
@@ -39,6 +39,15 @@ _GENERIC_CHARACTER_TAGS = frozenset({
39
 
40
  WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
41
 
 
 
 
 
 
 
 
 
 
42
  # Deterministic mapping: ordinal "why" -> numeric score for ordering/debug.
43
  WHY_TO_SCORE: Dict[str, float] = {
44
  "explicit": 0.90,
@@ -484,9 +493,14 @@ def llm_select_indices(
484
  temperature: float = 0.0,
485
  max_tokens: int = 512,
486
  return_metadata: bool = False,
 
487
  ) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
488
  """Return indices into the ORIGINAL candidates list (legacy interface).
489
 
 
 
 
 
490
  This implementation uses LangChain ONLY.
491
 
492
  NOTE: query_text is treated as the image description (original prompt).
@@ -716,6 +730,15 @@ def llm_select_indices(
716
  ENTITY_SYSTEM_TEMPLATE
717
  )
718
 
 
 
 
 
 
 
 
 
 
719
  # Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
720
  count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
721
  ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)
 
39
 
40
  WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
41
 
42
+ # Ordinal rank: lower = more confident. Used for threshold filtering.
43
+ WHY_RANK: Dict[str, int] = {
44
+ "explicit": 0,
45
+ "strong_implied": 1,
46
+ "weak_implied": 2,
47
+ "style_or_meta": 3,
48
+ "other": 4,
49
+ }
50
+
51
  # Deterministic mapping: ordinal "why" -> numeric score for ordering/debug.
52
  WHY_TO_SCORE: Dict[str, float] = {
53
  "explicit": 0.90,
 
493
  temperature: float = 0.0,
494
  max_tokens: int = 512,
495
  return_metadata: bool = False,
496
+ min_why: Optional[str] = None,
497
  ) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
498
  """Return indices into the ORIGINAL candidates list (legacy interface).
499
 
500
+ min_why: if set, only keep tags whose 'why' is at or above this confidence
501
+ level. E.g. min_why="explicit" keeps only explicit matches;
502
+ min_why="strong_implied" keeps explicit + strong_implied.
503
+
504
  This implementation uses LangChain ONLY.
505
 
506
  NOTE: query_text is treated as the image description (original prompt).
 
730
  ENTITY_SYSTEM_TEMPLATE
731
  )
732
 
733
+ # Apply why threshold: drop tags below the minimum confidence level.
734
+ if min_why is not None:
735
+ max_rank = WHY_RANK.get(min_why, 4)
736
+ before = len(best)
737
+ best = {t: v for t, v in best.items() if WHY_RANK.get(v[1], 4) <= max_rank}
738
+ if log:
739
+ log(f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
740
+ f"before={before} after={len(best)} dropped={before - len(best)}")
741
+
742
  # Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
743
  count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
744
  ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)
scripts/eval_pipeline.py CHANGED
@@ -170,6 +170,7 @@ def _process_one_sample(
170
  max_tokens: int,
171
  verbose: bool,
172
  print_lock: threading.Lock,
 
173
  ) -> SampleResult:
174
  """Process a single eval sample through the full pipeline. Thread-safe."""
175
  from psq_rag.llm.rewrite import llm_rewrite_prompt
@@ -250,6 +251,7 @@ def _process_one_sample(
250
  temperature=temperature,
251
  max_tokens=max_tokens,
252
  return_metadata=True,
 
253
  )
254
  result.stage3_time = time.time() - t0
255
 
@@ -351,6 +353,7 @@ def run_eval(
351
  shuffle: bool = True,
352
  seed: int = 42,
353
  workers: int = 1,
 
354
  ) -> List[SampleResult]:
355
 
356
  # Load eval samples
@@ -400,7 +403,7 @@ def run_eval(
400
  sample, i, total,
401
  skip_rewrite, allow_nsfw, mode, chunk_size,
402
  per_phrase_k, temperature, max_tokens, verbose,
403
- print_lock,
404
  )
405
  results.append(result)
406
  else:
@@ -647,6 +650,9 @@ def main(argv=None) -> int:
647
  help="Random seed for shuffle (default: 42)")
648
  ap.add_argument("--workers", "-w", type=int, default=4,
649
  help="Number of parallel workers (default: 4, use 1 for sequential)")
 
 
 
650
 
651
  args = ap.parse_args(list(argv) if argv is not None else None)
652
 
@@ -664,6 +670,7 @@ def main(argv=None) -> int:
664
  shuffle=args.shuffle,
665
  seed=args.seed,
666
  workers=args.workers,
 
667
  )
668
 
669
  print_summary(results)
@@ -694,6 +701,7 @@ def main(argv=None) -> int:
694
  "shuffle": args.shuffle,
695
  "seed": args.seed,
696
  "workers": args.workers,
 
697
  "n_errors": sum(1 for r in results if r.error),
698
  }
699
 
 
170
  max_tokens: int,
171
  verbose: bool,
172
  print_lock: threading.Lock,
173
+ min_why: Optional[str] = None,
174
  ) -> SampleResult:
175
  """Process a single eval sample through the full pipeline. Thread-safe."""
176
  from psq_rag.llm.rewrite import llm_rewrite_prompt
 
251
  temperature=temperature,
252
  max_tokens=max_tokens,
253
  return_metadata=True,
254
+ min_why=min_why,
255
  )
256
  result.stage3_time = time.time() - t0
257
 
 
353
  shuffle: bool = True,
354
  seed: int = 42,
355
  workers: int = 1,
356
+ min_why: Optional[str] = None,
357
  ) -> List[SampleResult]:
358
 
359
  # Load eval samples
 
403
  sample, i, total,
404
  skip_rewrite, allow_nsfw, mode, chunk_size,
405
  per_phrase_k, temperature, max_tokens, verbose,
406
+ print_lock, min_why,
407
  )
408
  results.append(result)
409
  else:
 
650
  help="Random seed for shuffle (default: 42)")
651
  ap.add_argument("--workers", "-w", type=int, default=4,
652
  help="Number of parallel workers (default: 4, use 1 for sequential)")
653
+ ap.add_argument("--min-why", default=None,
654
+ choices=["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"],
655
+ help="Minimum 'why' confidence to keep (e.g. 'explicit' keeps only explicit matches)")
656
 
657
  args = ap.parse_args(list(argv) if argv is not None else None)
658
 
 
670
  shuffle=args.shuffle,
671
  seed=args.seed,
672
  workers=args.workers,
673
+ min_why=args.min_why,
674
  )
675
 
676
  print_summary(results)
 
701
  "shuffle": args.shuffle,
702
  "seed": args.seed,
703
  "workers": args.workers,
704
+ "min_why": args.min_why,
705
  "n_errors": sum(1 for r in results if r.error),
706
  }
707