Claude
Add --min-why threshold to filter Stage 3 selections by confidence level
09a248d
Raw
History Blame
26.5 kB
# psq_rag/llm/select.py
# Stage 3: Closed-Set Selection (LangChain-only implementation)
#
# This module intentionally uses LangChain for:
# - prompt templating (including {N})
# - LLM call orchestration
# - JSON parsing
#
# There is NO fallback path. If LangChain dependencies are missing, this module
# should fail loudly so you install them.
import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field, SecretStr
from rapidfuzz import fuzz
from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources)
from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases
# Character-typed tags that are generic categories, not actual named characters.
# These leak through the alias filter because they match common words in captions.
# They are excluded from the entity pipeline and instead routed to general selection.
_GENERIC_CHARACTER_TAGS = frozenset({
"fan_character",
"background_character",
"unnamed_character",
"unknown_character",
"anonymous_character",
"viewer",
"original_character",
})
# Closed vocabulary of rationale codes, ordered from most to least confident.
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]

# Ordinal rank per rationale code (lower = more confident), used for threshold
# filtering. Derived from the position in WHY_ENUM, so the list order above is
# what defines the ranking.
WHY_RANK: Dict[str, int] = {code: rank for rank, code in enumerate(WHY_ENUM)}

# Deterministic mapping from rationale code to a numeric score, used only for
# ordering and debug output.
WHY_TO_SCORE: Dict[str, float] = dict(
    explicit=0.90,
    strong_implied=0.70,
    weak_implied=0.45,
    style_or_meta=0.35,
    other=0.25,
)
# System prompt for the general (non-character) selection calls.
#
# IMPORTANT ABOUT TEMPLATING:
# - This string is rendered by LangChain's f-string template engine.
# - Literal JSON braces must be escaped as {{ and }}.
# - {N} is a real template variable and MUST be provided.
# - The rationale codes listed in the prompt must stay in sync with WHY_ENUM,
#   because parsed responses are validated against that list.
SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags.
Select the tags that correspond to content that would be visible or depicted in the described image.
The list contains only valid tags; many of them are irrelevant to the image.
Return JSON ONLY matching this schema:
{{
\"selections\": [
{{\"i\": <int>, \"why\": \"<one of: explicit|strong_implied|weak_implied|style_or_meta|other>\"}},
...
]
}}
Rules:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Do NOT output any keys other than \"selections\", and inside each item only the item index \"i\" and \"why\".
- Do select both a general tag and a more specific tag when both apply (for example, \"shirt\" and \"grey shirt\").
Define \"why\" as:
- explicit: directly stated in the image description
- strong_implied: very likely given the description, even if not literally stated
- weak_implied: plausible but not strongly supported by the description
- style_or_meta: stylistic or presentation-related tags only if clearly indicated
- other: fallback category; use sparingly
"""
# System prompt for the entity (character) selection calls, which run only on
# character candidates that survived the alias pre-filter.
# Rendered with the f-string template format: literal JSON braces are doubled
# ({{ and }}), and {N} is a template variable that must be supplied.
ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags.
These character tags have already been pre-filtered to only include characters whose names
(or known aliases) appear in the image description. Your job is to confirm which of these
pre-filtered candidates are the correct match for the character mentioned by the user.
Return JSON ONLY matching this schema:
{{
\"selections\": [
{{\"i\": <int>, \"why\": \"explicit\"}},
...
]
}}
Rules for character selection:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Always use \"why\": \"explicit\" for all selections.
- Select the tag that best represents the character as described.
- If the user described a specific variant (e.g. \"pikachu libre\", \"detective pikachu\"),
select that specific variant tag.
- If the user described only the base character (e.g. just \"pikachu\"), select only the
base/default tag, NOT costume or variant tags.
- When uncertain between variants, prefer the simplest/most general tag.
"""
# Human-turn prompt shared by the general and entity calls.
# Template variables (all supplied when the chain is invoked):
# - {image_description}: the original prompt / image description text
# - {candidate_lines}: the numbered candidate list, one "<index>. <tag>" per line
# - {per_call_budget}: maximum number of indices the model is asked to return
USER_TEMPLATE = """IMAGE DESCRIPTION:
{image_description}
CANDIDATES (choose by index only):
{candidate_lines}
Select up to {per_call_budget} indices. Output fewer if uncertain.
"""
@dataclass(frozen=True)
class Selected:
    """A single validated pick from one LLM call (immutable)."""

    # 1-based index into the candidate list shown to the model for that call.
    i: int
    # Canonical tag text (underscore form), looked up from the index.
    tag: str
    # Rationale code reported by the model; one of WHY_ENUM.
    why: str
    # Numeric score derived from the rationale code.
    score: float
# Static-typing counterpart of WHY_ENUM: the only values accepted in "why".
WhyLiteral = Literal["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]


# One entry of the model's answer: which candidate was picked and why.
# (Documented with comments rather than docstrings so the schema that
# pydantic generates for these models is left untouched.)
class Stage3SelectionItem(BaseModel):
    i: int = Field(..., description="1-based index into the candidate list.")
    why: WhyLiteral = Field(..., description="Rationale code from the allowed set.")


# Top-level answer object; used as the target type of the output parser.
# "selections" defaults to an empty list when the key is absent.
class Stage3SelectionResponse(BaseModel):
    selections: List[Stage3SelectionItem] = Field(default_factory=list)
def _build_response_format() -> Dict[str, Any]:
    """Build the strict JSON-Schema ``response_format`` payload for Stage 3.

    Returns:
        A dict of type ``json_schema`` describing an object with a single
        required ``selections`` array. Each array item requires an integer
        ``i`` and a string ``why`` restricted to ``WHY_ENUM``; additional
        properties are rejected at both levels.
    """
    # Schema of a single selection entry.
    item_schema: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
            "why": {"type": "string", "enum": WHY_ENUM},
        },
        "required": ["i", "why"],
        "additionalProperties": False,
    }

    # Schema of the whole response object.
    root_schema: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "selections": {"type": "array", "items": item_schema},
        },
        "required": ["selections"],
        "additionalProperties": False,
    }

    wrapper: Dict[str, Any] = {
        "name": "stage3_selection",
        "strict": True,
        "schema": root_schema,
    }
    return {"type": "json_schema", "json_schema": wrapper}
def _get_llm(*, temperature: float, max_tokens: int, response_format: Dict[str, Any]) -> ChatOpenAI:
    """Create the chat model client pointed at OpenRouter's OpenAI-compatible API.

    Configuration is read from the environment:
      - OPENROUTER_API_KEY (required)
      - OPENROUTER_MODEL (optional; defaults to meta-llama/llama-3.1-8b-instruct)
      - OPENROUTER_HTTP_REFERER / OPENROUTER_X_TITLE (optional request headers)

    Args:
        temperature: Sampling temperature passed to the client.
        max_tokens: Completion token limit passed to the client.
        response_format: Structured-output spec sent in the request body.

    Raises:
        RuntimeError: If OPENROUTER_API_KEY is unset or empty.
    """
    raw_key = os.getenv("OPENROUTER_API_KEY")
    if not raw_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set.\n"
            "Set it in your environment before running Stage 3."
        )

    model_id = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")

    # Optional headers: each is sent only when its environment variable is
    # set to a non-empty value.
    optional_headers = (
        ("OPENROUTER_HTTP_REFERER", "HTTP-Referer"),
        ("OPENROUTER_X_TITLE", "X-Title"),
    )
    request_headers: Dict[str, str] = {}
    for env_name, header_name in optional_headers:
        env_value = os.getenv(env_name)
        if env_value:
            request_headers[header_name] = env_value

    # Provider-specific request body fields (OpenAI-compatible).
    # Response Healing plugin reduces malformed-JSON failures (syntax only).
    body_extras = {
        "response_format": response_format,
        "plugins": [{"id": "response-healing"}],
    }

    # OpenRouter OpenAI-compatible endpoint.
    return ChatOpenAI(
        model=model_id,
        base_url="https://openrouter.ai/api/v1",
        api_key=SecretStr(raw_key),
        temperature=temperature,
        max_completion_tokens=max_tokens,
        default_headers=request_headers,
        extra_body=body_extras,
    )
def _phrase_key_for_candidate(c: Candidate) -> str:
    """Return the candidate's grouping key: its smallest source phrase.

    Taking the minimum makes the key deterministic regardless of the order in
    which the sources were recorded. Candidates without sources get "".
    """
    return min(c.sources) if c.sources else ""
def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
    """Interleave candidates round-robin across their primary source phrases.

    Candidates are bucketed by primary phrase; each bucket is ordered by
    (score_combined, count) descending; the result then takes the best of
    every bucket (buckets visited in sorted key order), then the second best
    of every bucket, and so on until all buckets are exhausted.

    NOTE: counts are used only for ordering; they are NOT shown to the LLM.
    NOTE: ``count or -1`` means a count of 0 sorts like a missing count.
    """
    buckets: Dict[str, List[Candidate]] = {}
    for cand in cands:
        buckets.setdefault(_phrase_key_for_candidate(cand), []).append(cand)

    def _rank(cand: Candidate):
        return (cand.score_combined, (cand.count or -1))

    # One ranked list per phrase, in sorted phrase order.
    ranked_buckets = [
        sorted(buckets[key], key=_rank, reverse=True) for key in sorted(buckets)
    ]

    interleaved: List[Candidate] = []
    deepest = max((len(bucket) for bucket in ranked_buckets), default=0)
    for depth in range(deepest):
        for bucket in ranked_buckets:
            if depth < len(bucket):
                interleaved.append(bucket[depth])
    return interleaved
def _display_tag(tag: str) -> str:
# Display tags with spaces for the LLM, but keep canonical underscores internally.
return tag.replace("_", " ")
def _format_candidates_local(
    cands: Sequence[Candidate],
) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]:
    """Number the candidates for one LLM call and build the lookup tables.

    Returns:
        A tuple ``(text, idx_to_tag, idx_to_candidate)`` where ``text`` holds
        one ``"<index>. <display tag>"`` line per candidate (indices start at
        1), and the two dicts map each index back to the canonical tag and to
        the candidate object.
    """
    numbered: Dict[int, Candidate] = dict(enumerate(cands, start=1))
    tags_by_index: Dict[int, str] = {pos: cand.tag for pos, cand in numbered.items()}
    listing = "\n".join(
        f"{pos}. {_display_tag(cand.tag)}" for pos, cand in numbered.items()
    )
    return listing, tags_by_index, numbered
def _phrases_in_call(cands: Sequence[Candidate]) -> int:
    """Count the distinct source phrases across all candidates in a call."""
    return len({phrase for cand in cands for phrase in cand.sources})
def _parse_validate_map(
    parsed: Any,
    idx_to_tag: Dict[int, str],
    per_call_budget: int,
) -> Tuple[List[Selected], Dict[str, Any]]:
    """Validate a parsed LLM response and map its indices back to tags.

    Args:
        parsed: Parser output; either a pydantic model or a plain dict.
        idx_to_tag: Local (1-based) index -> canonical tag for this call.
        per_call_budget: Maximum number of selections to keep.

    Returns:
        ``(selected, diag)``. ``selected`` holds the accepted picks in
        response order, capped at ``per_call_budget``. ``diag`` counts what
        was rejected: ``invalid_items`` (non-dict entry, non-int index or
        unknown rationale), ``dupe_indices``, ``oob_indices``, plus
        ``parse_ok`` and ``kept``.
    """
    payload = parsed
    if isinstance(payload, BaseModel):
        # Support both pydantic v2 (model_dump) and v1 (dict).
        payload = payload.model_dump() if hasattr(payload, "model_dump") else payload.dict()

    diag: Dict[str, Any] = {
        "parse_ok": isinstance(payload, dict),
        "invalid_items": 0,
        "oob_indices": 0,
        "dupe_indices": 0,
        "kept": 0,
    }
    if not diag["parse_ok"]:
        return [], diag

    raw_items = payload.get("selections", [])
    if not isinstance(raw_items, list):
        diag["parse_ok"] = False
        return [], diag

    accepted: List[Selected] = []
    used_indices = set()
    for entry in raw_items:
        # Stop as soon as the budget is filled; later entries are ignored
        # without being counted.
        if len(accepted) >= per_call_budget:
            break
        if not isinstance(entry, dict):
            diag["invalid_items"] += 1
            continue

        idx, reason = entry.get("i"), entry.get("why")
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(idx, bool) or not isinstance(idx, int):
            diag["invalid_items"] += 1
        elif idx in used_indices:
            diag["dupe_indices"] += 1
        elif idx not in idx_to_tag:
            diag["oob_indices"] += 1
        elif not isinstance(reason, str) or reason not in WHY_ENUM:
            diag["invalid_items"] += 1
        else:
            used_indices.add(idx)
            accepted.append(
                Selected(i=idx, tag=idx_to_tag[idx], why=reason, score=WHY_TO_SCORE[reason])
            )

    diag["kept"] = len(accepted)
    return accepted, diag
def _split_candidates_by_type(
    candidates: List[Candidate],
    log,
) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]:
    """Partition candidates into the general pass and the entity pass.

    Routing, based on the tag type name:
      - "character": entity list, unless the tag is one of the generic
        character-category tags, which go to the general list.
      - "copyright": dropped (too broad for image generation).
      - "general", "artist", "species", "meta": general list.
      - anything else (including an unknown type): general list, counted
        as unknown.

    Args:
        candidates: Candidates to split.
        log: Optional callable taking one message string; skipped if falsy.

    Returns:
        ``(general_list, entity_list)``; every item is
        ``(original_index, candidate)``.
    """
    general_bucket: List[Tuple[int, Candidate]] = []
    entity_bucket: List[Tuple[int, Candidate]] = []
    n_unknown = n_copyright = n_generic_char = 0
    known_general_types = ("general", "artist", "species", "meta")

    for position, cand in enumerate(candidates):
        kind = get_tag_type_name(cand.tag)

        if kind == "copyright":
            # Copyright/series tags are filtered out entirely.
            n_copyright += 1
            continue

        if kind == "character" and cand.tag not in _GENERIC_CHARACTER_TAGS:
            entity_bucket.append((position, cand))
            continue

        # Everything that reaches this point is handled by the general pass.
        if kind == "character":
            n_generic_char += 1
        elif kind not in known_general_types:
            n_unknown += 1
        general_bucket.append((position, cand))

    if log:
        log(
            f"Stage3 split: "
            f"general={len(general_bucket)} "
            f"entity={len(entity_bucket)} "
            f"copyright_filtered={n_copyright} "
            f"generic_char_to_general={n_generic_char} "
            f"unknown_type={n_unknown}"
        )
    return general_bucket, entity_bucket
# Regex to strip series/franchise suffixes from aliases, e.g. _(sonic), _(mlp), _(character)
_SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$")
def _normalize_for_matching(text: str) -> str:
"""Lowercase, replace underscores with spaces, strip series suffixes."""
text = text.lower().strip()
text = _SERIES_SUFFIX_RE.sub("", text)
text = text.replace("_", " ")
return text
def _query_words(query: str) -> Set[str]:
    """Return the set of whitespace-separated words of the normalized query."""
    normalized = _normalize_for_matching(query)
    return {*normalized.split()}
def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str,
fuzzy_threshold: int = 85) -> bool:
"""Check if an alias matches the user query.
Matching logic:
1. Exact substring: alias appears as a substring of the query
2. Word subset: all words in the alias appear in the query words
3. Fuzzy: alias is close to a word in the query (handles typos)
"""
# Exact substring match
if alias_norm in query_norm:
return True
alias_words = alias_norm.split()
if not alias_words:
return False
# Word subset match: all alias words must appear in query
if all(w in query_words for w in alias_words):
return True
# For single-word aliases, try fuzzy matching against each query word
if len(alias_words) == 1:
for qw in query_words:
if fuzz.ratio(alias_words[0], qw) >= fuzzy_threshold:
return True
# For multi-word aliases, try fuzzy partial ratio against whole query
if len(alias_words) > 1:
if fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold:
return True
return False
def _character_matches_via_aliases(
    tag: str,
    query: str,
    tag2aliases: Dict[str, List[str]],
    query_words: Set[str],
    query_norm: str,
    fuzzy_threshold: int = 85,
) -> bool:
    """Decide whether a character tag is referenced by the user query.

    The tag matches when its own normalized name matches the query, or when
    at least one of its registered aliases does. A tag with no registered
    aliases is judged on its name alone.

    Args:
        tag: Canonical character tag.
        query: Raw query text (not used here; the pre-normalized forms
            below are what gets compared).
        tag2aliases: Mapping of tag -> list of alias strings.
        query_words: Words of the normalized query.
        query_norm: The normalized query as one string.
        fuzzy_threshold: Minimum fuzzy score that counts as a match.
    """
    # The tag's own name comes first.
    own_name = _normalize_for_matching(tag)
    if _alias_matches_query(own_name, query_words, query_norm, fuzzy_threshold):
        return True

    # Then every registered alias; aliases that normalize to "" are skipped.
    normalized_aliases = (
        _normalize_for_matching(raw_alias) for raw_alias in tag2aliases.get(tag, [])
    )
    return any(
        _alias_matches_query(candidate, query_words, query_norm, fuzzy_threshold)
        for candidate in normalized_aliases
        if candidate
    )
def llm_select_indices(
    query_text: str,  # kept for compatibility; treated as IMAGE DESCRIPTION
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
    max_pick: int,  # legacy param; applied after union + ordering (optional)
    log,
    retries: int = 2,
    *,
    mode: str = "chunked_map_union",  # "single_shot" or "chunked_map_union"
    chunk_size: int = 60,
    per_phrase_k: int = 2,  # per-call budget = per_phrase_k * phrases_in_call
    temperature: float = 0.0,
    max_tokens: int = 512,
    return_metadata: bool = False,
    min_why: Optional[str] = None,
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
    """Return indices into the ORIGINAL candidates list (legacy interface).

    Candidates are split into a general pass and an entity (character) pass.
    Entity candidates are first pre-filtered by alias matching against the
    description. Each pass is sent to the LLM either in one call or in chunks
    of ``chunk_size``; the selections of all calls are unioned per tag
    (keeping the best score), optionally filtered by ``min_why``, ordered by
    score then tag count, and capped at ``max_pick``.

    This implementation uses LangChain ONLY.
    NOTE: query_text is treated as the image description (original prompt).

    Args:
        query_text: The image description.
        candidates: ``Candidate`` objects, plain tag strings, or
            ``(tag, score)`` pairs. The kind is detected from the first item.
        max_pick: Cap on the number of results; ignored unless a positive int.
        log: Optional callable taking one message string; skipped if falsy.
        retries: Extra attempts per LLM call after an error or an empty
            selection.
        mode: "single_shot" or "chunked_map_union".
        chunk_size: Candidates per call in chunked mode; must be positive.
        per_phrase_k: Per-call budget multiplier (budget = per_phrase_k *
            distinct phrases in the call; per_phrase_k alone if there are no
            phrases).
        temperature: Sampling temperature for the LLM.
        max_tokens: Completion token limit for the LLM.
        return_metadata: If True, also return a ``tag -> why`` mapping.
        min_why: if set, only keep tags whose 'why' is at or above this
            confidence level. E.g. min_why="explicit" keeps only explicit
            matches; min_why="strong_implied" keeps explicit +
            strong_implied. A value that is not a known rationale code
            filters nothing (a warning is logged).

    Returns:
        The list of original indices, or ``(indices, tag_why)`` when
        ``return_metadata`` is True.

    Raises:
        ValueError: On an invalid ``mode``, a non-positive ``chunk_size`` in
            chunked mode, or a malformed tuple candidate.
        RuntimeError: If the OpenRouter API key is not configured.
    """
    image_description = query_text

    # Normalize candidates:
    # - preferred: List[Candidate]
    # - legacy: List[(tag, sim)] (count/sources unavailable)
    norm: List[Candidate] = []
    tag_to_first_index: Dict[str, int] = {}
    branch = "empty"
    cand0_type = type(candidates[0]).__name__ if candidates else "none"
    if candidates and isinstance(candidates[0], Candidate):
        branch = "candidate"
        typed_candidates = cast(Sequence[Candidate], candidates)
        for idx, c in enumerate(typed_candidates):
            # Duplicate tags are collapsed onto their first occurrence.
            if c.tag not in tag_to_first_index:
                tag_to_first_index[c.tag] = idx
                norm.append(c)
    elif candidates and isinstance(candidates[0], str):
        branch = "string"
        typed_candidates = cast(Sequence[str], candidates)
        for idx, tag in enumerate(typed_candidates):
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
                norm.append(
                    Candidate(
                        tag=tag,
                        score_combined=0.0,
                        score_fasttext=None,
                        score_context=None,
                        count=None,
                        sources=[],
                    )
                )
    else:
        if candidates:
            branch = "tuple"
        typed_candidates = cast(Sequence[Tuple[str, float]], candidates)
        for idx, row in enumerate(typed_candidates):
            if not isinstance(row, (list, tuple)) or len(row) < 2:
                raise ValueError("Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples.")
            tag, sim = row[0], row[1]
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
                norm.append(
                    Candidate(
                        tag=tag,
                        score_combined=float(sim),
                        score_fasttext=None,
                        score_context=None,
                        count=None,
                        sources=[],
                    )
                )

    if log:
        if norm:
            log(
                "Stage3 input: "
                f"type0={cand0_type} "
                f"branch={branch} "
                f"norm0_score={norm[0].score_combined!r} "
                f"norm0_sources_empty={not bool(norm[0].sources)}"
            )
        else:
            log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)")

    # Validate arguments before any LLM client is built or call is made.
    if mode not in ("single_shot", "chunked_map_union"):
        raise ValueError(f"Invalid mode: {mode}")
    if mode == "chunked_map_union" and chunk_size <= 0:
        # A negative step would make every chunk loop empty (silently
        # selecting nothing); zero would fail later inside range().
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
    if min_why is not None and min_why not in WHY_RANK and log:
        # Kept lenient for backward compatibility, but no longer silent.
        log(
            f"Stage3 why filter: unrecognized min_why={min_why!r}; "
            f"no tags will be filtered (expected one of {WHY_ENUM})"
        )

    response_format = _build_response_format()
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens, response_format=response_format)
    # Read again here only so the model name can be included in log lines.
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
    parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse)

    # Global union: tag -> best (score, why)
    best: Dict[str, Tuple[float, str]] = {}

    def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
        """Run one LLM selection call (with retries) and merge it into `best`."""
        # Create chain with the provided system template
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", USER_TEMPLATE),
            ],
            template_format="f-string",
        )
        chain = prompt | llm | parser

        ordered = _interleave_round_robin(call_cands)
        candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
        N_local = len(idx_to_tag)
        phrases = _phrases_in_call(call_cands)
        per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
        summary_logged = False

        if log:
            log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}")
            if phrases > 0:
                distinct_phrases = sorted({src for c in call_cands for src in c.sources})
                log(
                    f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} "
                    f"phrases={', '.join(distinct_phrases)}"
                )

        # Invoke LangChain chain (templating fills {N} and other vars)
        for att in range(retries + 1):
            try:
                if log:
                    log(
                        f"Stage3 {label}: "
                        f"model={model_name} "
                        f"N={N_local} "
                        f"phrases={phrases} "
                        f"per_call_budget={per_call_budget} "
                        f"response_healing=on"
                    )
                parsed = chain.invoke(
                    {
                        "N": N_local,
                        "image_description": image_description,
                        "candidate_lines": candidate_lines,
                        "per_call_budget": per_call_budget,
                    }
                )
                selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
                if log:
                    log(f"Stage3 {label}: attempt {att+1} diag={diag}")
                    # Summarize once: on the first non-empty result, or on
                    # the final attempt if every result was empty.
                    if not summary_logged and (selected or att == retries):
                        log(
                            f"Stage3 {label}: summary "
                            f"N={N_local} selected={len(selected)} per_call_budget={per_call_budget}"
                        )
                        summary_logged = True
                        if selected:
                            lines = [
                                f"Stage3 {label} selections:",
                                *[
                                    (
                                        f' - i={s.i} tag="{s.tag}" '
                                        f"why={s.why} score={s.score:.2f} "
                                        f"sources={idx_to_candidate.get(s.i).sources if idx_to_candidate.get(s.i) else []}"
                                    )
                                    for s in selected
                                ],
                            ]
                            log("\n".join(lines))
                        else:
                            log(f"Stage3 {label} selections: (none)")
                if selected:
                    # Merge into the global union, keeping the best score per tag.
                    for s in selected:
                        prev = best.get(s.tag)
                        if prev is None or s.score > prev[0]:
                            best[s.tag] = (s.score, s.why)
                    return
                # An empty selection falls through and is retried.
            except Exception as e:
                # Best-effort: a failed attempt is logged and retried.
                if log:
                    log(f"Stage3 {label}: attempt {att+1} error: {e}")
        if log:
            log(f"Stage3 {label}: gave up after {retries+1} attempts")

    # Split candidates by type (general vs entity)
    general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)

    # Extract just the candidates for LLM calls
    general_cands = [cand for _, cand in general_with_idx]
    entity_cands = [cand for _, cand in entity_with_idx]

    # Process general candidates (attributes, actions, species, etc.)
    if general_cands:
        if mode == "single_shot":
            run_call(general_cands, "general_single_shot", SELECT_SYSTEM_TEMPLATE)
        else:
            for start in range(0, len(general_cands), chunk_size):
                run_call(
                    general_cands[start:start + chunk_size],
                    f"general_chunk_{start//chunk_size}",
                    SELECT_SYSTEM_TEMPLATE,
                )

    # Process entity candidates (characters only) with alias-based pre-filtering
    if entity_cands:
        tag2aliases = get_tag2aliases()
        qwords = _query_words(image_description)
        qnorm = _normalize_for_matching(image_description)
        filtered_entity_cands: List[Candidate] = []
        filtered_out: List[str] = []
        for cand in entity_cands:
            if _character_matches_via_aliases(
                cand.tag, image_description, tag2aliases, qwords, qnorm
            ):
                filtered_entity_cands.append(cand)
            else:
                filtered_out.append(cand.tag)
        if log:
            log(
                f"Stage3 entity alias filter: "
                f"before={len(entity_cands)} "
                f"after={len(filtered_entity_cands)} "
                f"removed={len(filtered_out)}"
            )
            if filtered_out:
                log(f"Stage3 entity alias filter removed: {filtered_out[:20]}")
        if filtered_entity_cands:
            if mode == "single_shot":
                run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE)
            else:
                for start in range(0, len(filtered_entity_cands), chunk_size):
                    run_call(
                        filtered_entity_cands[start:start + chunk_size],
                        f"entity_chunk_{start//chunk_size}",
                        ENTITY_SYSTEM_TEMPLATE,
                    )

    # Apply why threshold: drop tags below the minimum confidence level.
    if min_why is not None:
        # Unknown codes fall back to the loosest rank, i.e. nothing is dropped.
        max_rank = WHY_RANK.get(min_why, 4)
        before = len(best)
        best = {t: v for t, v in best.items() if WHY_RANK.get(v[1], 4) <= max_rank}
        if log:
            log(f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
                f"before={before} after={len(best)} dropped={before - len(best)}")

    # Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
    count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
    ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)

    # Legacy cap: apply AFTER union + ordering.
    if isinstance(max_pick, int) and max_pick > 0:
        ordered_tags = ordered_tags[:max_pick]

    # Map back to original indices
    out_idx: List[int] = []
    tag_why: Dict[str, str] = {}
    for t in ordered_tags:
        if t in tag_to_first_index:
            out_idx.append(tag_to_first_index[t])
            tag_why[t] = best[t][1]  # why string

    if return_metadata:
        return out_idx, tag_why
    return out_idx