Food Desert commited on
Commit
dacabb8
·
1 Parent(s): 827e786

Tune probe selection budget and optional sequential species pass

Browse files
Files changed (1) hide show
  1. psq_rag/llm/select.py +111 -29
psq_rag/llm/select.py CHANGED
@@ -1563,10 +1563,27 @@ def _split_probe_tags_by_bundle(
1563
  return out
1564
 
1565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1566
  def llm_infer_probe_tags(
1567
  query_text: str,
1568
  log=None,
1569
- *,
1570
  temperature: float = 0.0,
1571
  max_tokens: int = 512,
1572
  retries: int = 2,
@@ -1615,36 +1632,101 @@ def llm_infer_probe_tags(
1615
  except Exception:
1616
  split_calls = 1
1617
 
1618
- probe_chunks = _split_probe_tags_by_bundle(probe_tags, split_calls, log=log)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1619
  selected: List[str] = []
1620
- seen: Set[str] = set()
1621
- for chunk_idx, chunk_tags in enumerate(probe_chunks, start=1):
1622
- if split_calls > 1 and log:
1623
- log(f"Stage3p: call {chunk_idx}/{len(probe_chunks)} with {len(chunk_tags)} probe tags")
1624
- out = llm_select_indices(
1625
- query_text=query_text,
1626
- candidates=chunk_tags,
1627
- max_pick=len(chunk_tags),
1628
- log=log,
1629
- retries=retries,
1630
- mode="single_shot",
1631
- chunk_size=max(1, len(chunk_tags)),
1632
- per_phrase_k=max(1, len(chunk_tags)),
1633
- temperature=temperature,
1634
- max_tokens=max_tokens,
1635
- return_metadata=False,
1636
- return_diagnostics=False,
1637
- min_why=None,
1638
- candidate_display=candidate_display,
1639
- user_template=_get_probe_user_template(),
1640
- model_override=probe_model_override,
1641
- )
1642
- for i in out:
1643
- if 0 <= i < len(chunk_tags):
1644
- t = chunk_tags[i]
1645
- if t not in seen:
1646
- seen.add(t)
1647
  selected.append(t)
 
 
 
1648
 
1649
  selected = _apply_species_anchor_mapping(selected, query_text=query_text, log=log)
1650
 
 
1563
  return out
1564
 
1565
 
1566
+ def _species_cap_from_count_tags(selected_tags: Sequence[str]) -> Optional[int]:
1567
+ """Derive species max-picks cap from count-like probe tags."""
1568
+ s = set(selected_tags)
1569
+ if "zero_pictured" in s:
1570
+ return 0
1571
+ if "solo" in s:
1572
+ return 1
1573
+ if "duo" in s:
1574
+ return 2
1575
+ if "trio" in s:
1576
+ return 3
1577
+ if "group" in s:
1578
+ # Conservative finite cap for open-ended "group".
1579
+ return 4
1580
+ return None
1581
+
1582
+
1583
  def llm_infer_probe_tags(
1584
  query_text: str,
1585
  log=None,
1586
+ *,
1587
  temperature: float = 0.0,
1588
  max_tokens: int = 512,
1589
  retries: int = 2,
 
1632
  except Exception:
1633
  split_calls = 1
1634
 
1635
+ try:
1636
+ # Default probe cap is 2.
1637
+ # Rationale (caption-evident n=10 evals, 2026-03-12):
1638
+ # - max1 reduced pollution but hurt recall too much
1639
+ # - max3 increased pollution and reduced overall F1
1640
+ # - max2 provided the best observed precision/recall tradeoff and highest overall F1
1641
+ max_pick_override = int((os.environ.get("PSQ_PROBE_MAX_PICK_OVERRIDE", "2") or "2").strip())
1642
+ except Exception:
1643
+ max_pick_override = 0
1644
+ if max_pick_override < 0:
1645
+ max_pick_override = 0
1646
+
1647
+ sequential_species = (os.environ.get("PSQ_PROBE_SEQUENTIAL_SPECIES", "0") or "0").strip().lower() in {
1648
+ "1", "true", "yes", "on"
1649
+ }
1650
+
1651
+ def _call_chunks(
1652
+ chunks: Sequence[Sequence[str]],
1653
+ *,
1654
+ label: str,
1655
+ max_pick_cap: Optional[int] = None,
1656
+ ) -> List[str]:
1657
+ out_tags: List[str] = []
1658
+ out_seen: Set[str] = set()
1659
+ for chunk_idx, chunk_tags_seq in enumerate(chunks, start=1):
1660
+ chunk_tags = list(chunk_tags_seq)
1661
+ if not chunk_tags:
1662
+ continue
1663
+ if len(chunks) > 1 and log:
1664
+ log(f"Stage3p {label}: call {chunk_idx}/{len(chunks)} with {len(chunk_tags)} probe tags")
1665
+
1666
+ per_call_budget = len(chunk_tags)
1667
+ if max_pick_override > 0:
1668
+ per_call_budget = min(per_call_budget, max_pick_override)
1669
+ if max_pick_cap is not None:
1670
+ per_call_budget = min(per_call_budget, max(0, int(max_pick_cap)))
1671
+ if per_call_budget <= 0:
1672
+ if log:
1673
+ log(f"Stage3p {label}: skipping call with budget=0")
1674
+ continue
1675
+
1676
+ out = llm_select_indices(
1677
+ query_text=query_text,
1678
+ candidates=chunk_tags,
1679
+ max_pick=per_call_budget,
1680
+ log=log,
1681
+ retries=retries,
1682
+ mode="single_shot",
1683
+ chunk_size=max(1, len(chunk_tags)),
1684
+ per_phrase_k=max(1, per_call_budget),
1685
+ temperature=temperature,
1686
+ max_tokens=max_tokens,
1687
+ return_metadata=False,
1688
+ return_diagnostics=False,
1689
+ min_why=None,
1690
+ candidate_display=candidate_display,
1691
+ user_template=_get_probe_user_template(),
1692
+ model_override=probe_model_override,
1693
+ )
1694
+ for i in out:
1695
+ if 0 <= i < len(chunk_tags):
1696
+ t = chunk_tags[i]
1697
+ if t not in out_seen:
1698
+ out_seen.add(t)
1699
+ out_tags.append(t)
1700
+ return out_tags
1701
+
1702
  selected: List[str] = []
1703
+
1704
+ if sequential_species:
1705
+ bundle_by_tag = _probe_bundle_map(log=log)
1706
+ species_tags = [t for t in probe_tags if bundle_by_tag.get(t, "") == "species_taxonomy"]
1707
+ non_species_tags = [t for t in probe_tags if t not in set(species_tags)]
1708
+ if log:
1709
+ log(
1710
+ "Stage3p: sequential species mode on "
1711
+ f"(non_species={len(non_species_tags)}, species={len(species_tags)})"
1712
+ )
1713
+
1714
+ non_species_chunks = _split_probe_tags_by_bundle(non_species_tags, split_calls, log=log) if non_species_tags else []
1715
+ selected_non_species = _call_chunks(non_species_chunks, label="non_species")
1716
+ selected.extend(selected_non_species)
1717
+
1718
+ if species_tags:
1719
+ species_cap = _species_cap_from_count_tags(selected_non_species)
1720
+ if log:
1721
+ log(f"Stage3p: species cap from count tags = {species_cap!r}")
1722
+ species_chunks = [species_tags]
1723
+ selected_species = _call_chunks(species_chunks, label="species", max_pick_cap=species_cap)
1724
+ for t in selected_species:
1725
+ if t not in selected:
 
 
 
 
1726
  selected.append(t)
1727
+ else:
1728
+ probe_chunks = _split_probe_tags_by_bundle(probe_tags, split_calls, log=log)
1729
+ selected = _call_chunks(probe_chunks, label="all")
1730
 
1731
  selected = _apply_species_anchor_mapping(selected, query_text=query_text, log=log)
1732