Spaces:
Running
Running
Food Desert commited on
Commit ·
dacabb8
1
Parent(s): 827e786
Tune probe selection budget and optional sequential species pass
Browse files- psq_rag/llm/select.py +111 -29
psq_rag/llm/select.py
CHANGED
|
@@ -1563,10 +1563,27 @@ def _split_probe_tags_by_bundle(
|
|
| 1563 |
return out
|
| 1564 |
|
| 1565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1566 |
def llm_infer_probe_tags(
|
| 1567 |
query_text: str,
|
| 1568 |
log=None,
|
| 1569 |
-
*,
|
| 1570 |
temperature: float = 0.0,
|
| 1571 |
max_tokens: int = 512,
|
| 1572 |
retries: int = 2,
|
|
@@ -1615,36 +1632,101 @@ def llm_infer_probe_tags(
|
|
| 1615 |
except Exception:
|
| 1616 |
split_calls = 1
|
| 1617 |
|
| 1618 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1619 |
selected: List[str] = []
|
| 1620 |
-
|
| 1621 |
-
|
| 1622 |
-
|
| 1623 |
-
|
| 1624 |
-
|
| 1625 |
-
|
| 1626 |
-
|
| 1627 |
-
|
| 1628 |
-
|
| 1629 |
-
|
| 1630 |
-
|
| 1631 |
-
|
| 1632 |
-
|
| 1633 |
-
|
| 1634 |
-
|
| 1635 |
-
|
| 1636 |
-
|
| 1637 |
-
|
| 1638 |
-
|
| 1639 |
-
|
| 1640 |
-
|
| 1641 |
-
|
| 1642 |
-
|
| 1643 |
-
if 0 <= i < len(chunk_tags):
|
| 1644 |
-
t = chunk_tags[i]
|
| 1645 |
-
if t not in seen:
|
| 1646 |
-
seen.add(t)
|
| 1647 |
selected.append(t)
|
|
|
|
|
|
|
|
|
|
| 1648 |
|
| 1649 |
selected = _apply_species_anchor_mapping(selected, query_text=query_text, log=log)
|
| 1650 |
|
|
|
|
| 1563 |
return out
|
| 1564 |
|
| 1565 |
|
| 1566 |
+
def _species_cap_from_count_tags(selected_tags: Sequence[str]) -> Optional[int]:
|
| 1567 |
+
"""Derive species max-picks cap from count-like probe tags."""
|
| 1568 |
+
s = set(selected_tags)
|
| 1569 |
+
if "zero_pictured" in s:
|
| 1570 |
+
return 0
|
| 1571 |
+
if "solo" in s:
|
| 1572 |
+
return 1
|
| 1573 |
+
if "duo" in s:
|
| 1574 |
+
return 2
|
| 1575 |
+
if "trio" in s:
|
| 1576 |
+
return 3
|
| 1577 |
+
if "group" in s:
|
| 1578 |
+
# Conservative finite cap for open-ended "group".
|
| 1579 |
+
return 4
|
| 1580 |
+
return None
|
| 1581 |
+
|
| 1582 |
+
|
| 1583 |
def llm_infer_probe_tags(
|
| 1584 |
query_text: str,
|
| 1585 |
log=None,
|
| 1586 |
+
*,
|
| 1587 |
temperature: float = 0.0,
|
| 1588 |
max_tokens: int = 512,
|
| 1589 |
retries: int = 2,
|
|
|
|
| 1632 |
except Exception:
|
| 1633 |
split_calls = 1
|
| 1634 |
|
| 1635 |
+
try:
|
| 1636 |
+
# Default probe cap is 2.
|
| 1637 |
+
# Rationale (caption-evident n=10 evals, 2026-03-12):
|
| 1638 |
+
# - max1 reduced pollution but hurt recall too much
|
| 1639 |
+
# - max3 increased pollution and reduced overall F1
|
| 1640 |
+
# - max2 provided the best observed precision/recall tradeoff and highest overall F1
|
| 1641 |
+
max_pick_override = int((os.environ.get("PSQ_PROBE_MAX_PICK_OVERRIDE", "2") or "2").strip())
|
| 1642 |
+
except Exception:
|
| 1643 |
+
max_pick_override = 0
|
| 1644 |
+
if max_pick_override < 0:
|
| 1645 |
+
max_pick_override = 0
|
| 1646 |
+
|
| 1647 |
+
sequential_species = (os.environ.get("PSQ_PROBE_SEQUENTIAL_SPECIES", "0") or "0").strip().lower() in {
|
| 1648 |
+
"1", "true", "yes", "on"
|
| 1649 |
+
}
|
| 1650 |
+
|
| 1651 |
+
def _call_chunks(
|
| 1652 |
+
chunks: Sequence[Sequence[str]],
|
| 1653 |
+
*,
|
| 1654 |
+
label: str,
|
| 1655 |
+
max_pick_cap: Optional[int] = None,
|
| 1656 |
+
) -> List[str]:
|
| 1657 |
+
out_tags: List[str] = []
|
| 1658 |
+
out_seen: Set[str] = set()
|
| 1659 |
+
for chunk_idx, chunk_tags_seq in enumerate(chunks, start=1):
|
| 1660 |
+
chunk_tags = list(chunk_tags_seq)
|
| 1661 |
+
if not chunk_tags:
|
| 1662 |
+
continue
|
| 1663 |
+
if len(chunks) > 1 and log:
|
| 1664 |
+
log(f"Stage3p {label}: call {chunk_idx}/{len(chunks)} with {len(chunk_tags)} probe tags")
|
| 1665 |
+
|
| 1666 |
+
per_call_budget = len(chunk_tags)
|
| 1667 |
+
if max_pick_override > 0:
|
| 1668 |
+
per_call_budget = min(per_call_budget, max_pick_override)
|
| 1669 |
+
if max_pick_cap is not None:
|
| 1670 |
+
per_call_budget = min(per_call_budget, max(0, int(max_pick_cap)))
|
| 1671 |
+
if per_call_budget <= 0:
|
| 1672 |
+
if log:
|
| 1673 |
+
log(f"Stage3p {label}: skipping call with budget=0")
|
| 1674 |
+
continue
|
| 1675 |
+
|
| 1676 |
+
out = llm_select_indices(
|
| 1677 |
+
query_text=query_text,
|
| 1678 |
+
candidates=chunk_tags,
|
| 1679 |
+
max_pick=per_call_budget,
|
| 1680 |
+
log=log,
|
| 1681 |
+
retries=retries,
|
| 1682 |
+
mode="single_shot",
|
| 1683 |
+
chunk_size=max(1, len(chunk_tags)),
|
| 1684 |
+
per_phrase_k=max(1, per_call_budget),
|
| 1685 |
+
temperature=temperature,
|
| 1686 |
+
max_tokens=max_tokens,
|
| 1687 |
+
return_metadata=False,
|
| 1688 |
+
return_diagnostics=False,
|
| 1689 |
+
min_why=None,
|
| 1690 |
+
candidate_display=candidate_display,
|
| 1691 |
+
user_template=_get_probe_user_template(),
|
| 1692 |
+
model_override=probe_model_override,
|
| 1693 |
+
)
|
| 1694 |
+
for i in out:
|
| 1695 |
+
if 0 <= i < len(chunk_tags):
|
| 1696 |
+
t = chunk_tags[i]
|
| 1697 |
+
if t not in out_seen:
|
| 1698 |
+
out_seen.add(t)
|
| 1699 |
+
out_tags.append(t)
|
| 1700 |
+
return out_tags
|
| 1701 |
+
|
| 1702 |
selected: List[str] = []
|
| 1703 |
+
|
| 1704 |
+
if sequential_species:
|
| 1705 |
+
bundle_by_tag = _probe_bundle_map(log=log)
|
| 1706 |
+
species_tags = [t for t in probe_tags if bundle_by_tag.get(t, "") == "species_taxonomy"]
|
| 1707 |
+
non_species_tags = [t for t in probe_tags if t not in set(species_tags)]
|
| 1708 |
+
if log:
|
| 1709 |
+
log(
|
| 1710 |
+
"Stage3p: sequential species mode on "
|
| 1711 |
+
f"(non_species={len(non_species_tags)}, species={len(species_tags)})"
|
| 1712 |
+
)
|
| 1713 |
+
|
| 1714 |
+
non_species_chunks = _split_probe_tags_by_bundle(non_species_tags, split_calls, log=log) if non_species_tags else []
|
| 1715 |
+
selected_non_species = _call_chunks(non_species_chunks, label="non_species")
|
| 1716 |
+
selected.extend(selected_non_species)
|
| 1717 |
+
|
| 1718 |
+
if species_tags:
|
| 1719 |
+
species_cap = _species_cap_from_count_tags(selected_non_species)
|
| 1720 |
+
if log:
|
| 1721 |
+
log(f"Stage3p: species cap from count tags = {species_cap!r}")
|
| 1722 |
+
species_chunks = [species_tags]
|
| 1723 |
+
selected_species = _call_chunks(species_chunks, label="species", max_pick_cap=species_cap)
|
| 1724 |
+
for t in selected_species:
|
| 1725 |
+
if t not in selected:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1726 |
selected.append(t)
|
| 1727 |
+
else:
|
| 1728 |
+
probe_chunks = _split_probe_tags_by_bundle(probe_tags, split_calls, log=log)
|
| 1729 |
+
selected = _call_chunks(probe_chunks, label="all")
|
| 1730 |
|
| 1731 |
selected = _apply_species_anchor_mapping(selected, query_text=query_text, log=log)
|
| 1732 |
|