# psq_rag/llm/select.py
# Stage 3: Closed-Set Selection (LangChain-only implementation)
#
# This module intentionally uses LangChain for:
# - prompt templating (including {N})
# - LLM call orchestration
# - JSON parsing
#
# There is NO fallback path. If LangChain dependencies are missing, this module
# should fail loudly so you install them.

import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field, SecretStr
from rapidfuzz import fuzz

from psq_rag.retrieval.psq_retrieval import Candidate  # Candidate(tag, score_*, count, sources)
from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases

# Character-typed tags that name generic categories rather than specific,
# named characters. They slip through the alias filter because they match
# common words in captions, so they are kept out of the entity pipeline and
# routed to general selection instead.
_GENERIC_CHARACTER_TAGS = frozenset(
    (
        "anonymous_character",
        "background_character",
        "fan_character",
        "original_character",
        "unknown_character",
        "unnamed_character",
        "viewer",
    )
)

# Single source of truth for the "why" rationale codes: ordered from most to
# least confident, each paired with the numeric score used for ordering/debug.
_WHY_LEVELS: Tuple[Tuple[str, float], ...] = (
    ("explicit", 0.90),
    ("strong_implied", 0.70),
    ("weak_implied", 0.45),
    ("style_or_meta", 0.35),
    ("other", 0.25),
)

# Allowed "why" codes in confidence order (also the JSON-schema enum).
WHY_ENUM = [code for code, _score in _WHY_LEVELS]

# Ordinal rank: lower = more confident. Used for threshold filtering.
WHY_RANK: Dict[str, int] = {code: rank for rank, (code, _score) in enumerate(_WHY_LEVELS)}

# Deterministic mapping: ordinal "why" -> numeric score for ordering/debug.
WHY_TO_SCORE: Dict[str, float] = dict(_WHY_LEVELS)

# IMPORTANT ABOUT TEMPLATING:
# - This string is rendered by LangChain's f-string template engine.
# - Literal JSON braces must be escaped as {{ and }}.
# - {N} is a real template variable and MUST be provided.
# NOTE(review): the per-item schema examples below show no placeholder values
# for "i" / "why" (e.g. `"i": ,`); confirm this is intended and not lost text.
SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags.
Select the tags that correspond to content that would be visible or depicted in the described image.
The list contains only valid tags; many of them are irrelevant to the image.

Return JSON ONLY matching this schema:
{{
  \"selections\": [
    {{\"i\": , \"why\": \"\"}},
    ...
  ]
}}

Rules:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Do NOT output any keys other than \"selections\", and inside each item only the item index \"i\" and \"why\".
- Do select both a general tag and a more specific tag when both apply (for example, \"shirt\" and \"grey shirt\").

Define \"why\" as:
- explicit: directly stated in the image description
- strong_implied: very likely given the description, even if not literally stated
- weak_implied: plausible but not strongly supported by the description
- style_or_meta: stylistic or presentation-related tags only if clearly indicated
- other: fallback category; use sparingly
"""

ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags.
These character tags have already been pre-filtered to only include characters whose names (or known aliases) appear in the image description.
Your job is to confirm which of these pre-filtered candidates are the correct match for the character mentioned by the user.

Return JSON ONLY matching this schema:
{{
  \"selections\": [
    {{\"i\": , \"why\": \"explicit\"}},
    ...
  ]
}}

Rules for character selection:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Always use \"why\": \"explicit\" for all selections.
- Select the tag that best represents the character as described.
- If the user described a specific variant (e.g. \"pikachu libre\", \"detective pikachu\"), select that specific variant tag.
- If the user described only the base character (e.g. just \"pikachu\"), select only the base/default tag, NOT costume or variant tags.
- When uncertain between variants, prefer the simplest/most general tag.
"""

USER_TEMPLATE = """IMAGE DESCRIPTION:
{image_description}

CANDIDATES (choose by index only):
{candidate_lines}

Select up to {per_call_budget} indices. Output fewer if uncertain.
"""


@dataclass(frozen=True)
class Selected:
    """One validated LLM selection, mapped back to its canonical tag."""

    i: int  # 1-based index into the per-call candidate list
    tag: str  # canonical tag (underscore form)
    why: str  # rationale code; one of WHY_ENUM
    score: float  # WHY_TO_SCORE[why]


WhyLiteral = Literal["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]


class Stage3SelectionItem(BaseModel):
    i: int = Field(..., description="1-based index into the candidate list.")
    why: WhyLiteral = Field(..., description="Rationale code from the allowed set.")


class Stage3SelectionResponse(BaseModel):
    selections: List[Stage3SelectionItem] = Field(default_factory=list)


def _build_response_format() -> Dict[str, Any]:
    """Return the OpenAI-style strict JSON Schema ``response_format`` payload.

    The schema mirrors ``Stage3SelectionResponse``: an object with a single
    ``selections`` array whose items carry only an integer ``i`` and a ``why``
    restricted to ``WHY_ENUM``.
    """
    schema = {
        "type": "object",
        "properties": {
            "selections": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "i": {"type": "integer"},
                        "why": {"type": "string", "enum": WHY_ENUM},
                    },
                    "required": ["i", "why"],
                    "additionalProperties": False,
                },
            }
        },
        "required": ["selections"],
        "additionalProperties": False,
    }
    return {
        "type": "json_schema",
        "json_schema": {
            "name": "stage3_selection",
            "strict": True,
            "schema": schema,
        },
    }


def _get_llm(*, temperature: float, max_tokens: int, response_format: Dict[str, Any]) -> ChatOpenAI:
    """Build a ChatOpenAI client pointed at OpenRouter's OpenAI-compatible endpoint.

    Environment:
        OPENROUTER_API_KEY: required.
        OPENROUTER_MODEL: model id; defaults to "meta-llama/llama-3.1-8b-instruct".
        OPENROUTER_HTTP_REFERER / OPENROUTER_X_TITLE: optional attribution
            headers, sent only when set to a non-empty value.

    Raises:
        RuntimeError: if OPENROUTER_API_KEY is unset or empty.
    """
    raw_key = os.getenv("OPENROUTER_API_KEY")
    if not raw_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set.\n"
            "Set it in your environment before running Stage 3."
        )
    api_key = SecretStr(raw_key)
    model = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")

    headers: Dict[str, str] = {}
    if referer := os.getenv("OPENROUTER_HTTP_REFERER"):
        headers["HTTP-Referer"] = referer
    if title := os.getenv("OPENROUTER_X_TITLE"):
        headers["X-Title"] = title

    # OpenRouter OpenAI-compatible endpoint.
    return ChatOpenAI(
        model=model,
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        temperature=temperature,
        max_completion_tokens=max_tokens,
        default_headers=headers,
        # Provider-specific request body fields (OpenAI-compatible).
        # Response Healing plugin reduces malformed-JSON failures (syntax only).
        extra_body={
            "response_format": response_format,
            "plugins": [{"id": "response-healing"}],
        },
    )


def _phrase_key_for_candidate(c: Candidate) -> str:
    """Deterministic "primary phrase" for grouping: the smallest source, or ""."""
    return min(c.sources) if c.sources else ""


def _missing_count_last(count: Optional[int]) -> int:
    """Sort key for candidate counts: unknown (None) sorts below any real count.

    Matches the tie-break used for final ordering in ``llm_select_indices``;
    a genuine count of 0 is kept as 0 rather than being treated as unknown.
    """
    return count if count is not None else -1


def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
    """Round-robin interleave by primary source phrase.

    Candidates are grouped by ``_phrase_key_for_candidate``; within a group
    they are ordered by (score_combined, count) descending, then the output
    takes the 1st of every group (groups in sorted key order), then the 2nd,
    and so on, so no single phrase dominates the top of the list.

    NOTE: counts are used only for ordering; they are NOT shown to the LLM.
    """
    groups: Dict[str, List[Candidate]] = {}
    for c in cands:
        groups.setdefault(_phrase_key_for_candidate(c), []).append(c)

    ordered_groups: List[List[Candidate]] = []
    for key in sorted(groups):
        members = groups[key]
        members.sort(key=lambda x: (x.score_combined, _missing_count_last(x.count)), reverse=True)
        ordered_groups.append(members)

    out: List[Candidate] = []
    depth = max((len(g) for g in ordered_groups), default=0)
    for rank in range(depth):
        for members in ordered_groups:
            if rank < len(members):
                out.append(members[rank])
    return out


def _display_tag(tag: str) -> str:
    """Display tags with spaces for the LLM; canonical underscores stay internal."""
    return tag.replace("_", " ")


def _format_candidates_local(
    cands: Sequence[Candidate],
) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]:
    """Number candidates from 1 and render them as prompt lines.

    Returns:
        (candidate_lines, idx_to_tag, idx_to_candidate): the newline-joined
        "<j>. <display tag>" listing, plus lookups keyed by the same 1-based
        index the LLM is asked to answer with.
    """
    idx_to_candidate: Dict[int, Candidate] = dict(enumerate(cands, start=1))
    idx_to_tag: Dict[int, str] = {j: c.tag for j, c in idx_to_candidate.items()}
    lines = [f"{j}. {_display_tag(c.tag)}" for j, c in idx_to_candidate.items()]
    return "\n".join(lines), idx_to_tag, idx_to_candidate


def _phrases_in_call(cands: Sequence[Candidate]) -> int:
    """Count the distinct source phrases across *cands*."""
    return len({src for c in cands for src in c.sources})


def _parse_validate_map(
    parsed: Any,
    idx_to_tag: Dict[int, str],
    per_call_budget: int,
) -> Tuple[List[Selected], Dict[str, Any]]:
    """Validate the parsed LLM output and map local indices back to tags.

    Accepts a pydantic model or a plain dict shaped like
    ``{"selections": [{"i": int, "why": str}, ...]}``. Items are taken in
    order until ``per_call_budget`` have been kept; booleans/non-ints, unknown
    ``why`` codes, out-of-range and repeated indices are skipped and counted.

    Returns:
        (selected, diag) where diag has keys parse_ok, invalid_items,
        oob_indices, dupe_indices and kept.
    """
    if isinstance(parsed, BaseModel):
        # pydantic v2 exposes model_dump(); v1 only has dict().
        parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()

    diag: Dict[str, Any] = {
        "parse_ok": isinstance(parsed, dict),
        "invalid_items": 0,
        "oob_indices": 0,
        "dupe_indices": 0,
        "kept": 0,
    }
    if not isinstance(parsed, dict):
        return [], diag

    selections = parsed.get("selections", [])
    if not isinstance(selections, list):
        diag["parse_ok"] = False
        return [], diag

    out: List[Selected] = []
    seen_i: Set[int] = set()
    for item in selections:
        if len(out) >= per_call_budget:
            break
        if not isinstance(item, dict):
            diag["invalid_items"] += 1
            continue
        i = item.get("i")
        why = item.get("why")
        # bool is a subclass of int; reject it explicitly.
        if isinstance(i, bool) or not isinstance(i, int):
            diag["invalid_items"] += 1
            continue
        if i in seen_i:
            diag["dupe_indices"] += 1
            continue
        if i not in idx_to_tag:
            diag["oob_indices"] += 1
            continue
        if not isinstance(why, str) or why not in WHY_ENUM:
            diag["invalid_items"] += 1
            continue
        seen_i.add(i)
        out.append(Selected(i=i, tag=idx_to_tag[i], why=why, score=WHY_TO_SCORE[why]))

    diag["kept"] = len(out)
    return out, diag


def _split_candidates_by_type(
    candidates: List[Candidate],
    log,
) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]:
    """Split candidates into general vs entity (character only) lists.

    Returns:
        (general_list, entity_list) where each item is (original_index, candidate)

    Tag types:
    - General: 0 (general), 1 (artist), 5 (species), 7 (meta)
    - Entity: 4 (character) only
    - Filtered: 3 (copyright) - too broad for image generation

    Character tags listed in ``_GENERIC_CHARACTER_TAGS`` are routed to the
    general list, and tags whose type name is not recognised (including None)
    also default to general. When ``log`` is truthy, a one-line count summary
    is logged.
    """
    general_with_idx: List[Tuple[int, Candidate]] = []
    entity_with_idx: List[Tuple[int, Candidate]] = []
    unknown_count = 0
    copyright_count = 0
    generic_char_count = 0

    for idx, cand in enumerate(candidates):
        type_name = get_tag_type_name(cand.tag)
        if type_name == "copyright":
            # Copyright/series tags are too broad for image generation.
            copyright_count += 1
            continue
        if type_name == "character" and cand.tag not in _GENERIC_CHARACTER_TAGS:
            entity_with_idx.append((idx, cand))
            continue

        general_with_idx.append((idx, cand))
        if type_name == "character":
            generic_char_count += 1
        elif type_name not in ("general", "artist", "species", "meta"):
            unknown_count += 1

    if log:
        log(
            f"Stage3 split: "
            f"general={len(general_with_idx)} "
            f"entity={len(entity_with_idx)} "
            f"copyright_filtered={copyright_count} "
            f"generic_char_to_general={generic_char_count} "
            f"unknown_type={unknown_count}"
        )
    return general_with_idx, entity_with_idx


# Regex to strip series/franchise suffixes from aliases, e.g.
# _(sonic), _(mlp), _(character)
_SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$")


def _normalize_for_matching(text: str) -> str:
    """Lowercase, replace underscores with spaces, strip series suffixes.

    The suffix is stripped before underscores are replaced because the regex
    anchors on the literal "_(" prefix, e.g. "Pikachu_(pokemon)" -> "pikachu".
    """
    text = _SERIES_SUFFIX_RE.sub("", text.lower().strip())
    return text.replace("_", " ")


def _query_words(query: str) -> Set[str]:
    """Extract the whitespace-separated words of the normalized query.

    Punctuation is not stripped, so e.g. "pikachu," is its own word.
    """
    return set(_normalize_for_matching(query).split())


def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str, fuzzy_threshold: int = 85) -> bool:
    """Check if an alias matches the user query.

    ``alias_norm`` and ``query_norm`` are expected to come from
    ``_normalize_for_matching``; ``query_words`` from ``_query_words``.

    Matching logic (first hit wins):
    1. Whole-word phrase: the alias words occur consecutively in the query and
       are not glued to other word characters on either side, so "pikachu"
       matches "pikachu's" / "pikachu," but "eve" does not match inside "every".
    2. Word subset: all words in the alias appear in the query words.
    3. Fuzzy (handles typos): a single-word alias scores >= fuzzy_threshold
       (rapidfuzz ratio) against some query word, or a multi-word alias scores
       >= fuzzy_threshold (rapidfuzz partial_ratio) against the whole query.

    An empty or whitespace-only alias never matches.
    NOTE(review): step 1 assumes space-delimited text; scripts written without
    spaces can only match through steps 2-3.
    """
    alias_words = alias_norm.split()
    if not alias_words:
        # An empty alias would otherwise be a "substring" of every query.
        return False

    phrase = r"\s+".join(re.escape(w) for w in alias_words)
    if re.search(r"(?<!\w)" + phrase + r"(?!\w)", query_norm):
        return True

    if all(w in query_words for w in alias_words):
        return True

    if len(alias_words) == 1:
        return any(fuzz.ratio(alias_words[0], qw) >= fuzzy_threshold for qw in query_words)
    return fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold


def _character_matches_via_aliases(
    tag: str,
    query: str,
    tag2aliases: Dict[str, List[str]],
    query_words: Set[str],
    query_norm: str,
    fuzzy_threshold: int = 85,
) -> bool:
    """Check if a character tag matches the user query via its name or aliases.

    For a character tag to match:
    - The tag name itself (normalized) must match, OR
    - At least one of its registered aliases must match.
    A tag with no registered aliases is still checked by its own name.

    ``query`` is not read here (the pre-normalized ``query_words`` /
    ``query_norm`` are used instead); it is kept for interface compatibility.
    """
    names = [tag, *(tag2aliases.get(tag) or [])]
    return any(
        _alias_matches_query(_normalize_for_matching(name), query_words, query_norm, fuzzy_threshold)
        for name in names
    )


def _normalize_candidates(
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
) -> Tuple[List[Candidate], Dict[str, int], str, str]:
    """Coerce any supported candidate input form into ``Candidate`` objects.

    The type of the first element decides how the whole sequence is read:
    ``Candidate`` objects are used as-is; tag strings (score 0.0) and legacy
    ``(tag, score)`` rows become Candidates without count or sources. Each tag
    is kept once, at its first occurrence, so the LLM never sees the same tag
    under two indices.

    Returns:
        (norm, tag_to_first_index, branch, cand0_type): the normalized
        candidates, each tag's first index in ``candidates``, and the branch
        name / first-element type name used for logging.

    Raises:
        ValueError: a legacy row is not a list/tuple with at least two items.
    """
    if not candidates:
        return [], {}, "empty", "none"

    first = candidates[0]
    cand0_type = type(first).__name__
    norm: List[Candidate] = []
    tag_to_first_index: Dict[str, int] = {}

    def first_sighting(tag: str, idx: int) -> bool:
        """Record *tag*'s first index; True only the first time it is seen."""
        if tag in tag_to_first_index:
            return False
        tag_to_first_index[tag] = idx
        return True

    def bare_candidate(tag: str, score: float) -> Candidate:
        """Build a Candidate for legacy inputs that carry no count/sources."""
        return Candidate(
            tag=tag,
            score_combined=score,
            score_fasttext=None,
            score_context=None,
            count=None,
            sources=[],
        )

    if isinstance(first, Candidate):
        for idx, cand in enumerate(cast(Sequence[Candidate], candidates)):
            if first_sighting(cand.tag, idx):
                norm.append(cand)
        return norm, tag_to_first_index, "candidate", cand0_type

    if isinstance(first, str):
        for idx, tag in enumerate(cast(Sequence[str], candidates)):
            if first_sighting(tag, idx):
                norm.append(bare_candidate(tag, 0.0))
        return norm, tag_to_first_index, "string", cand0_type

    for idx, row in enumerate(cast(Sequence[Tuple[str, float]], candidates)):
        if not isinstance(row, (list, tuple)) or len(row) < 2:
            raise ValueError("Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples.")
        tag, sim = row[0], row[1]
        if first_sighting(tag, idx):
            norm.append(bare_candidate(tag, float(sim)))
    return norm, tag_to_first_index, "tuple", cand0_type


def _filter_entity_candidates(
    entity_cands: List[Candidate],
    image_description: str,
    log,
) -> List[Candidate]:
    """Keep only character candidates whose name or an alias matches the description."""
    tag2aliases = get_tag2aliases()
    qwords = _query_words(image_description)
    qnorm = _normalize_for_matching(image_description)

    kept: List[Candidate] = []
    filtered_out: List[str] = []
    for cand in entity_cands:
        if _character_matches_via_aliases(cand.tag, image_description, tag2aliases, qwords, qnorm):
            kept.append(cand)
        else:
            filtered_out.append(cand.tag)

    if log:
        log(
            f"Stage3 entity alias filter: "
            f"before={len(entity_cands)} "
            f"after={len(kept)} "
            f"removed={len(filtered_out)}"
        )
        if filtered_out:
            log(f"Stage3 entity alias filter removed: {filtered_out[:20]}")
    return kept


def llm_select_indices(
    query_text: str,  # kept for compatibility; treated as IMAGE DESCRIPTION
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
    max_pick: int,  # legacy param; applied after union + ordering (optional)
    log,
    retries: int = 2,
    *,
    mode: str = "chunked_map_union",  # "single_shot" or "chunked_map_union"
    chunk_size: int = 60,
    per_phrase_k: int = 2,  # per-call budget = per_phrase_k * phrases_in_call
    temperature: float = 0.0,
    max_tokens: int = 512,
    return_metadata: bool = False,
    min_why: Optional[str] = None,
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
    """Return indices into the ORIGINAL candidates list (legacy interface).

    Candidates are split by tag type. General tags are sent with
    SELECT_SYSTEM_TEMPLATE. Character tags are first pre-filtered by
    name/alias match against the description, then confirmed with
    ENTITY_SYSTEM_TEMPLATE. Picks from every call are unioned per tag, keeping
    the highest derived score. They are then optionally thresholded by
    ``min_why``, ordered by (score, count) descending, capped at ``max_pick``
    when it is a positive int, and mapped back to each tag's first index in
    ``candidates``.

    Args:
        query_text: image description (original prompt) shown to the LLM.
        candidates: ``Candidate`` objects, tag strings, or ``(tag, score)``
            rows; the first element's type decides how all are read.
            Duplicate tags are collapsed to their first occurrence.
        max_pick: legacy cap applied after union + ordering; ignored if <= 0.
        log: callable taking one string, or a falsy value to disable logging.
        retries: extra attempts per LLM call after an error or an empty
            selection.
        mode: "single_shot" (one call per tag group) or "chunked_map_union"
            (one call per ``chunk_size`` candidates).
        chunk_size: candidates per call in chunked mode; must be positive.
        per_phrase_k: per-call budget = per_phrase_k * distinct source phrases
            in the call (at least 1), or per_phrase_k when no sources exist.
        temperature, max_tokens: forwarded to the chat model.
        return_metadata: also return ``{tag: why}`` for the returned tags.
        min_why: if set, only keep tags whose 'why' is at or above this
            confidence level. E.g. min_why="explicit" keeps only explicit
            matches; min_why="strong_implied" keeps explicit + strong_implied.

    Raises:
        ValueError: invalid ``mode``, ``chunk_size`` or ``min_why``, or a
            malformed ``(tag, score)`` row.
        RuntimeError: OPENROUTER_API_KEY is unset (only checked when there is
            at least one candidate).

    This implementation uses LangChain ONLY.
    NOTE: query_text is treated as the image description (original prompt).
    """
    image_description = query_text

    norm, tag_to_first_index, branch, cand0_type = _normalize_candidates(candidates)

    if log:
        if norm:
            log(
                "Stage3 input: "
                f"type0={cand0_type} "
                f"branch={branch} "
                f"norm0_score={norm[0].score_combined!r} "
                f"norm0_sources_empty={not bool(norm[0].sources)}"
            )
        else:
            log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)")

    if mode not in ("single_shot", "chunked_map_union"):
        raise ValueError(f"Invalid mode: {mode}")
    if mode == "chunked_map_union" and chunk_size <= 0:
        # range() would raise on 0 and silently yield no chunks for negatives.
        raise ValueError(f"Invalid chunk_size: {chunk_size} (must be > 0)")
    if min_why is not None and min_why not in WHY_RANK:
        # An unknown level would otherwise silently disable the threshold.
        raise ValueError(f"Invalid min_why: {min_why!r} (expected one of {WHY_ENUM})")

    if not norm:
        # Nothing to select from: no need to build an LLM client.
        return ([], {}) if return_metadata else []

    llm = _get_llm(
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=_build_response_format(),
    )
    # Read the model back from the client so logs always name what is called.
    model_name = llm.model_name
    parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse)

    # Global union: tag -> best (score, why)
    best: Dict[str, Tuple[float, str]] = {}
    # One prompt | llm | parser chain per system template, built on first use.
    chains: Dict[str, Any] = {}

    def chain_for(system_template: str) -> Any:
        """Return the cached chain for *system_template*, building it once."""
        chain = chains.get(system_template)
        if chain is None:
            prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system_template),
                    ("human", USER_TEMPLATE),
                ],
                template_format="f-string",
            )
            chain = prompt | llm | parser
            chains[system_template] = chain
        return chain

    def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
        """Ask the LLM to pick from one batch and merge its picks into ``best``.

        Each attempt that raises, or that yields no valid selections, consumes
        one of the ``retries + 1`` attempts; the first non-empty selection is
        merged and ends the loop.
        """
        chain = chain_for(system_template)

        ordered = _interleave_round_robin(call_cands)
        candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
        n_local = len(idx_to_tag)
        phrases = _phrases_in_call(call_cands)
        per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
        summary_logged = False

        if log:
            log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}")
            if phrases > 0:
                distinct_phrases = sorted({src for c in call_cands for src in c.sources})
                log(
                    f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} "
                    f"phrases={', '.join(distinct_phrases)}"
                )

        for att in range(retries + 1):
            try:
                if log:
                    log(
                        f"Stage3 {label}: "
                        f"model={model_name} "
                        f"N={n_local} "
                        f"phrases={phrases} "
                        f"per_call_budget={per_call_budget} "
                        f"response_healing=on"
                    )
                # Invoke LangChain chain (templating fills {N} and other vars)
                parsed = chain.invoke(
                    {
                        "N": n_local,
                        "image_description": image_description,
                        "candidate_lines": candidate_lines,
                        "per_call_budget": per_call_budget,
                    }
                )
                selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)

                if log:
                    log(f"Stage3 {label}: attempt {att+1} diag={diag}")
                    if not summary_logged and (selected or att == retries):
                        log(
                            f"Stage3 {label}: summary "
                            f"N={n_local} selected={len(selected)} per_call_budget={per_call_budget}"
                        )
                        summary_logged = True
                    if selected:
                        lines = [f"Stage3 {label} selections:"]
                        for s in selected:
                            cand = idx_to_candidate.get(s.i)
                            lines.append(
                                f' - i={s.i} tag="{s.tag}" '
                                f"why={s.why} score={s.score:.2f} "
                                f"sources={cand.sources if cand else []}"
                            )
                        log("\n".join(lines))
                    else:
                        log(f"Stage3 {label} selections: (none)")

                if selected:
                    for s in selected:
                        prev = best.get(s.tag)
                        if prev is None or s.score > prev[0]:
                            best[s.tag] = (s.score, s.why)
                    return
            except Exception as e:
                # Best-effort: any failure (network, parsing, validation)
                # just consumes one attempt.
                if log:
                    log(f"Stage3 {label}: attempt {att+1} error: {e}")

        if log:
            log(f"Stage3 {label}: gave up after {retries+1} attempts")

    def dispatch(cands: List[Candidate], prefix: str, system_template: str) -> None:
        """Run one call (single_shot) or one call per ``chunk_size`` slice."""
        if mode == "single_shot":
            run_call(cands, f"{prefix}_single_shot", system_template)
            return
        for start in range(0, len(cands), chunk_size):
            run_call(
                cands[start:start + chunk_size],
                f"{prefix}_chunk_{start // chunk_size}",
                system_template,
            )

    # Split candidates by type (general vs entity)
    general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)
    general_cands = [cand for _, cand in general_with_idx]
    entity_cands = [cand for _, cand in entity_with_idx]

    # Process general candidates (attributes, actions, species, etc.)
    if general_cands:
        dispatch(general_cands, "general", SELECT_SYSTEM_TEMPLATE)

    # Process entity candidates (characters only) with alias-based pre-filtering
    if entity_cands:
        filtered_entity_cands = _filter_entity_candidates(entity_cands, image_description, log)
        if filtered_entity_cands:
            dispatch(filtered_entity_cands, "entity", ENTITY_SYSTEM_TEMPLATE)

    # Apply why threshold: drop tags below the minimum confidence level.
    if min_why is not None:
        max_rank = WHY_RANK[min_why]
        before = len(best)
        best = {t: v for t, v in best.items() if WHY_RANK[v[1]] <= max_rank}
        if log:
            log(
                f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
                f"before={before} after={len(best)} dropped={before - len(best)}"
            )

    # Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
    count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
    ordered_tags = sorted(best, key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)

    # Legacy cap: apply AFTER union + ordering.
    if isinstance(max_pick, int) and max_pick > 0:
        ordered_tags = ordered_tags[:max_pick]

    # Map back to original indices
    out_idx: List[int] = []
    tag_why: Dict[str, str] = {}
    for t in ordered_tags:
        if t in tag_to_first_index:
            out_idx.append(tag_to_first_index[t])
            tag_why[t] = best[t][1]  # why string

    if return_metadata:
        return out_idx, tag_why
    return out_idx