Spaces:
Running
Running
Claude commited on
Commit ·
09a248d
1
Parent(s): 962e2b4
Add --min-why threshold to filter Stage 3 selections by confidence level
Browse filesselect.py: WHY_RANK ordinal mapping, min_why parameter filters the best{}
dict before output. E.g. min_why="explicit" keeps only explicitly matched
tags, min_why="strong_implied" keeps explicit + strong_implied.
eval_pipeline.py: --min-why CLI arg threaded through to llm_select_indices.
https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG
- psq_rag/llm/select.py +23 -0
- scripts/eval_pipeline.py +9 -1
psq_rag/llm/select.py
CHANGED
|
@@ -39,6 +39,15 @@ _GENERIC_CHARACTER_TAGS = frozenset({
|
|
| 39 |
|
| 40 |
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# Deterministic mapping: ordinal "why" -> numeric score for ordering/debug.
|
| 43 |
WHY_TO_SCORE: Dict[str, float] = {
|
| 44 |
"explicit": 0.90,
|
|
@@ -484,9 +493,14 @@ def llm_select_indices(
|
|
| 484 |
temperature: float = 0.0,
|
| 485 |
max_tokens: int = 512,
|
| 486 |
return_metadata: bool = False,
|
|
|
|
| 487 |
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
|
| 488 |
"""Return indices into the ORIGINAL candidates list (legacy interface).
|
| 489 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
This implementation uses LangChain ONLY.
|
| 491 |
|
| 492 |
NOTE: query_text is treated as the image description (original prompt).
|
|
@@ -716,6 +730,15 @@ def llm_select_indices(
|
|
| 716 |
ENTITY_SYSTEM_TEMPLATE
|
| 717 |
)
|
| 718 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
# Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
|
| 720 |
count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
|
| 721 |
ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)
|
|
|
|
| 39 |
|
| 40 |
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
|
| 41 |
|
| 42 |
+
# Ordinal rank: lower = more confident. Used for threshold filtering.
|
| 43 |
+
WHY_RANK: Dict[str, int] = {
|
| 44 |
+
"explicit": 0,
|
| 45 |
+
"strong_implied": 1,
|
| 46 |
+
"weak_implied": 2,
|
| 47 |
+
"style_or_meta": 3,
|
| 48 |
+
"other": 4,
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
# Deterministic mapping: ordinal "why" -> numeric score for ordering/debug.
|
| 52 |
WHY_TO_SCORE: Dict[str, float] = {
|
| 53 |
"explicit": 0.90,
|
|
|
|
| 493 |
temperature: float = 0.0,
|
| 494 |
max_tokens: int = 512,
|
| 495 |
return_metadata: bool = False,
|
| 496 |
+
min_why: Optional[str] = None,
|
| 497 |
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
|
| 498 |
"""Return indices into the ORIGINAL candidates list (legacy interface).
|
| 499 |
|
| 500 |
+
min_why: if set, only keep tags whose 'why' is at or above this confidence
|
| 501 |
+
level. E.g. min_why="explicit" keeps only explicit matches;
|
| 502 |
+
min_why="strong_implied" keeps explicit + strong_implied.
|
| 503 |
+
|
| 504 |
This implementation uses LangChain ONLY.
|
| 505 |
|
| 506 |
NOTE: query_text is treated as the image description (original prompt).
|
|
|
|
| 730 |
ENTITY_SYSTEM_TEMPLATE
|
| 731 |
)
|
| 732 |
|
| 733 |
+
# Apply why threshold: drop tags below the minimum confidence level.
|
| 734 |
+
if min_why is not None:
|
| 735 |
+
max_rank = WHY_RANK.get(min_why, 4)
|
| 736 |
+
before = len(best)
|
| 737 |
+
best = {t: v for t, v in best.items() if WHY_RANK.get(v[1], 4) <= max_rank}
|
| 738 |
+
if log:
|
| 739 |
+
log(f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
|
| 740 |
+
f"before={before} after={len(best)} dropped={before - len(best)}")
|
| 741 |
+
|
| 742 |
# Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
|
| 743 |
count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
|
| 744 |
ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)
|
scripts/eval_pipeline.py
CHANGED
|
@@ -170,6 +170,7 @@ def _process_one_sample(
|
|
| 170 |
max_tokens: int,
|
| 171 |
verbose: bool,
|
| 172 |
print_lock: threading.Lock,
|
|
|
|
| 173 |
) -> SampleResult:
|
| 174 |
"""Process a single eval sample through the full pipeline. Thread-safe."""
|
| 175 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
|
@@ -250,6 +251,7 @@ def _process_one_sample(
|
|
| 250 |
temperature=temperature,
|
| 251 |
max_tokens=max_tokens,
|
| 252 |
return_metadata=True,
|
|
|
|
| 253 |
)
|
| 254 |
result.stage3_time = time.time() - t0
|
| 255 |
|
|
@@ -351,6 +353,7 @@ def run_eval(
|
|
| 351 |
shuffle: bool = True,
|
| 352 |
seed: int = 42,
|
| 353 |
workers: int = 1,
|
|
|
|
| 354 |
) -> List[SampleResult]:
|
| 355 |
|
| 356 |
# Load eval samples
|
|
@@ -400,7 +403,7 @@ def run_eval(
|
|
| 400 |
sample, i, total,
|
| 401 |
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 402 |
per_phrase_k, temperature, max_tokens, verbose,
|
| 403 |
-
print_lock,
|
| 404 |
)
|
| 405 |
results.append(result)
|
| 406 |
else:
|
|
@@ -647,6 +650,9 @@ def main(argv=None) -> int:
|
|
| 647 |
help="Random seed for shuffle (default: 42)")
|
| 648 |
ap.add_argument("--workers", "-w", type=int, default=4,
|
| 649 |
help="Number of parallel workers (default: 4, use 1 for sequential)")
|
|
|
|
|
|
|
|
|
|
| 650 |
|
| 651 |
args = ap.parse_args(list(argv) if argv is not None else None)
|
| 652 |
|
|
@@ -664,6 +670,7 @@ def main(argv=None) -> int:
|
|
| 664 |
shuffle=args.shuffle,
|
| 665 |
seed=args.seed,
|
| 666 |
workers=args.workers,
|
|
|
|
| 667 |
)
|
| 668 |
|
| 669 |
print_summary(results)
|
|
@@ -694,6 +701,7 @@ def main(argv=None) -> int:
|
|
| 694 |
"shuffle": args.shuffle,
|
| 695 |
"seed": args.seed,
|
| 696 |
"workers": args.workers,
|
|
|
|
| 697 |
"n_errors": sum(1 for r in results if r.error),
|
| 698 |
}
|
| 699 |
|
|
|
|
| 170 |
max_tokens: int,
|
| 171 |
verbose: bool,
|
| 172 |
print_lock: threading.Lock,
|
| 173 |
+
min_why: Optional[str] = None,
|
| 174 |
) -> SampleResult:
|
| 175 |
"""Process a single eval sample through the full pipeline. Thread-safe."""
|
| 176 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
|
|
|
| 251 |
temperature=temperature,
|
| 252 |
max_tokens=max_tokens,
|
| 253 |
return_metadata=True,
|
| 254 |
+
min_why=min_why,
|
| 255 |
)
|
| 256 |
result.stage3_time = time.time() - t0
|
| 257 |
|
|
|
|
| 353 |
shuffle: bool = True,
|
| 354 |
seed: int = 42,
|
| 355 |
workers: int = 1,
|
| 356 |
+
min_why: Optional[str] = None,
|
| 357 |
) -> List[SampleResult]:
|
| 358 |
|
| 359 |
# Load eval samples
|
|
|
|
| 403 |
sample, i, total,
|
| 404 |
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 405 |
per_phrase_k, temperature, max_tokens, verbose,
|
| 406 |
+
print_lock, min_why,
|
| 407 |
)
|
| 408 |
results.append(result)
|
| 409 |
else:
|
|
|
|
| 650 |
help="Random seed for shuffle (default: 42)")
|
| 651 |
ap.add_argument("--workers", "-w", type=int, default=4,
|
| 652 |
help="Number of parallel workers (default: 4, use 1 for sequential)")
|
| 653 |
+
ap.add_argument("--min-why", default=None,
|
| 654 |
+
choices=["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"],
|
| 655 |
+
help="Minimum 'why' confidence to keep (e.g. 'explicit' keeps only explicit matches)")
|
| 656 |
|
| 657 |
args = ap.parse_args(list(argv) if argv is not None else None)
|
| 658 |
|
|
|
|
| 670 |
shuffle=args.shuffle,
|
| 671 |
seed=args.seed,
|
| 672 |
workers=args.workers,
|
| 673 |
+
min_why=args.min_why,
|
| 674 |
)
|
| 675 |
|
| 676 |
print_summary(results)
|
|
|
|
| 701 |
"shuffle": args.shuffle,
|
| 702 |
"seed": args.seed,
|
| 703 |
"workers": args.workers,
|
| 704 |
+
"min_why": args.min_why,
|
| 705 |
"n_errors": sum(1 for r in results if r.error),
|
| 706 |
}
|
| 707 |
|