# psq_rag/llm/select.py # Stage 3: Closed-Set Selection (LangChain-only implementation) # # This module intentionally uses LangChain for: # - prompt templating (including {N}) # - LLM call orchestration # - JSON parsing # # There is NO fallback path. If LangChain dependencies are missing, this module # should fail loudly so you install them. import os import re import csv from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Mapping from langchain_openai import ChatOpenAI from langchain_core.prompts import ChatPromptTemplate from langchain_core.output_parsers import PydanticOutputParser from pydantic import BaseModel, Field, SecretStr from rapidfuzz import fuzz from psq_rag.retrieval.psq_retrieval import Candidate, _norm_tag_for_lookup # Candidate(tag, score_*, count, sources) from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases, get_tag_counts, get_fasttext_model def _redact_sensitive_error_text(err: Any) -> str: """Redact provider/user identifiers and token-like substrings from error text.""" text = str(err or "") # OpenRouter-style user identifier. text = re.sub(r"\buser_(?!id\b)[A-Za-z0-9]{6,}\b", "user_", text) # Bearer tokens / API-key-like strings if surfaced by an upstream exception. text = re.sub(r"(?i)\b(bearer)\s+[A-Za-z0-9._:-]+\b", r"\1 ", text) text = re.sub(r"\b(sk|or)-[A-Za-z0-9._-]+\b", r"\1-", text) return text def _parse_openrouter_error_facts(err: Any) -> Dict[str, Any]: """Extract coarse error facts from OpenRouter/LangChain exception text.""" text = str(err or "") text_l = text.lower() code: Optional[int] = None m_code = re.search(r"Error code:\s*(\d+)", text, flags=re.IGNORECASE) if not m_code: m_code = re.search(r"['\"]code['\"]\s*:\s*(\d+)", text) if m_code: try: code = int(m_code.group(1)) except Exception: code = None provider_name: Optional[str] = None m_provider = re.search(r"['\"]provider_name['\"]\s*:\s*['\"]([^'\"]+)['\"]", text) if m_provider: provider_name = m_provider.group(1).strip() or None is_upstream_throttle = ( (code == 429) or ("rate-limited upstream" in text_l) or ("too many requests" in text_l) ) is_credit_or_quota = ( (code == 402) or ("insufficient credits" in text_l) or ("payment required" in text_l) or ("quota exceeded" in text_l) ) is_transient_server = (code in {500, 502, 503, 504}) or ("timeout" in text_l) retryable = bool(is_upstream_throttle or is_transient_server) kind = "other" if is_upstream_throttle: kind = "upstream_throttle" elif is_credit_or_quota: kind = "credit_or_quota" elif is_transient_server: kind = "transient_server" return { "code": code, "provider_name": provider_name, "kind": kind, "retryable": retryable, } # Character-typed tags that are generic categories, not actual named characters. # These leak through the alias filter because they match common words in captions. # They are excluded from the entity pipeline and instead routed to general selection. _GENERIC_CHARACTER_TAGS = frozenset({ "fan_character", "background_character", "unnamed_character", "unknown_character", "anonymous_character", "viewer", "original_character", }) # IMPORTANT ABOUT TEMPLATING: # - This string is rendered by LangChain's f-string template engine. # - Literal JSON braces must be escaped as {{ and }}. # - {N} is a real template variable and MUST be provided. SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags. Select tags ONLY when they are explicitly stated in the image description text. Do NOT select tags based on implication, plausibility, style assumptions, or world knowledge. If a tag is not directly supported by explicit wording in the description, do not select it. Return JSON ONLY matching this schema: {{ \"selections\": [ {{\"i\": }}, ... ] }} Rules: - Choose ONLY from indices 1..{N}. - Do NOT output tag text. - Do NOT output any keys other than \"selections\", and inside each item only the item index \"i\". - Do select both a general tag and a more specific tag when both apply (for example, \"shirt\" and \"grey shirt\"). """ def _get_select_system_template() -> str: """Return Stage 3 selection prompt text.""" return SELECT_SYSTEM_TEMPLATE ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags. These character tags have already been pre-filtered to only include characters whose names (or known aliases) appear in the image description. Your job is to confirm which of these pre-filtered candidates are the correct match for the character mentioned by the user. Return JSON ONLY matching this schema: {{ \"selections\": [ {{\"i\": }}, ... ] }} Rules for character selection: - Choose ONLY from indices 1..{N}. - Do NOT output tag text. - Select the tag that best represents the character as described. - If the user described a specific variant (e.g. \"pikachu libre\", \"detective pikachu\"), select that specific variant tag. - If the user described only the base character (e.g. just \"pikachu\"), select only the base/default tag, NOT costume or variant tags. - When uncertain between variants, prefer the simplest/most general tag. """ USER_TEMPLATE = """IMAGE DESCRIPTION: {image_description} CANDIDATES (choose by index only): {candidate_lines} Select only indices for tags that are directly and explicitly stated by the image description text. If uncertain, select nothing for that tag. It is valid to return an empty selections list. Select up to {per_call_budget} indices. """ PROBE_USER_TEMPLATE_PRECISION = """IMAGE DESCRIPTION: {image_description} CANDIDATES (choose by index only): {candidate_lines} Select up to {per_call_budget} indices of tags that describe the image's visible contents. Prefer precision over recall: if uncertain, skip. For species-family tags (for example canid, felid, bird, bear), only select when the description clearly names that family/species (dog/wolf/fox/cat/tiger/bird/bear/etc.) or gives unmistakable species-specific evidence; do not infer from generic words like animal, mammal, or anthro. """ PROBE_USER_TEMPLATE_LEXICAL = """IMAGE DESCRIPTION: {image_description} CANDIDATES (choose by index only): {candidate_lines} Select up to {per_call_budget} indices. Select a tag only when the description text explicitly states it or directly describes it. Precision first: false positives are worse than misses. Return an empty selections list when unsure. """ def _get_probe_user_template() -> str: variant = (os.environ.get("PSQ_PROBE_PROMPT_VARIANT", "baseline") or "baseline").strip().lower() if variant == "precision": return PROBE_USER_TEMPLATE_PRECISION if variant == "lexical": return PROBE_USER_TEMPLATE_LEXICAL return USER_TEMPLATE @dataclass(frozen=True) class Selected: i: int tag: str # canonical tag (underscore form) class Stage3SelectionItem(BaseModel): i: int = Field(..., description="1-based index into the candidate list.") class Stage3SelectionResponse(BaseModel): selections: List[Stage3SelectionItem] = Field(default_factory=list) def _build_response_format() -> Dict[str, Any]: # Strict JSON Schema structured output. schema = { "type": "object", "properties": { "selections": { "type": "array", "items": { "type": "object", "properties": { "i": {"type": "integer"}, }, "required": ["i"], "additionalProperties": False, }, } }, "required": ["selections"], "additionalProperties": False, } return { "type": "json_schema", "json_schema": { "name": "stage3_selection", "strict": True, "schema": schema, }, } def _get_llm( *, temperature: float, max_tokens: int, response_format: Dict[str, Any], model: Optional[str] = None, ) -> ChatOpenAI: api_key = os.getenv("OPENROUTER_API_KEY") if not api_key: raise RuntimeError( "OPENROUTER_API_KEY is not set.\n" "Set it in your environment before running Stage 3." ) api_key = SecretStr(cast(str, api_key)) default_model = "mistralai/mistral-small-24b-instruct-2501" resolved_model = (model or os.getenv("OPENROUTER_MODEL", default_model) or default_model).strip() or default_model headers: Dict[str, str] = {} if referer := os.getenv("OPENROUTER_HTTP_REFERER"): headers["HTTP-Referer"] = referer if title := os.getenv("OPENROUTER_X_TITLE"): headers["X-Title"] = title # OpenRouter OpenAI-compatible endpoint. return ChatOpenAI( model=resolved_model, base_url="https://openrouter.ai/api/v1", api_key=api_key, temperature=temperature, max_completion_tokens=max_tokens, default_headers=headers, # Provider-specific request body fields (OpenAI-compatible). # Response Healing plugin reduces malformed-JSON failures (syntax only). extra_body={ "response_format": response_format, "plugins": [{"id": "response-healing"}], }, ) def _resolve_fallback_model( *, primary_model: str, explicit_fallback: Optional[str], env_key: str, ) -> Optional[str]: fallback_default = "meta-llama/llama-3.1-8b-instruct" raw = ( (explicit_fallback or "").strip() or (os.environ.get(env_key, "") or "").strip() or (os.environ.get("PSQ_OPENROUTER_FALLBACK_MODEL", "") or "").strip() or fallback_default ) if not raw: return None if raw.lower() in {"none", "off", "false", "0"}: return None if raw == primary_model: return None return raw def _phrase_key_for_candidate(c: Candidate) -> str: # Deterministic "primary phrase" for grouping. if c.sources: return sorted(c.sources)[0] return "" def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]: """Round-robin interleave by primary source phrase. NOTE: counts are used only for ordering; they are NOT shown to the LLM. """ groups: Dict[str, List[Candidate]] = {} for c in cands: k = _phrase_key_for_candidate(c) groups.setdefault(k, []).append(c) for k in groups: groups[k].sort(key=lambda x: (x.score_combined, (x.count or -1)), reverse=True) keys = sorted(groups.keys()) out: List[Candidate] = [] idx = 0 while True: progressed = False for k in keys: if idx < len(groups[k]): out.append(groups[k][idx]) progressed = True if not progressed: break idx += 1 return out def _build_chunks(cands: Sequence[Candidate], chunk_size: int) -> List[List[Candidate]]: if chunk_size <= 0: raise ValueError(f"chunk_size must be > 0, got {chunk_size}") ordered = _interleave_round_robin(cands) return [ordered[i:i + chunk_size] for i in range(0, len(ordered), chunk_size)] def _display_tag(tag: str) -> str: # Display tags with spaces for the LLM, but keep canonical underscores internally. return tag.replace("_", " ") def _format_candidates_local( cands: Sequence[Candidate], candidate_display: Optional[Mapping[str, str]] = None, ) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]: lines: List[str] = [] idx_to_tag: Dict[int, str] = {} idx_to_candidate: Dict[int, Candidate] = {} for j, c in enumerate(cands, start=1): idx_to_tag[j] = c.tag idx_to_candidate[j] = c display = candidate_display.get(c.tag) if candidate_display else None if not display: display = _display_tag(c.tag) lines.append(f"{j}. {display}") return "\n".join(lines), idx_to_tag, idx_to_candidate def _phrases_in_call(cands: Sequence[Candidate]) -> int: s = set() for c in cands: for src in c.sources: s.add(src) return len(s) def _parse_validate_map( parsed: Any, idx_to_tag: Dict[int, str], per_call_budget: int, ) -> Tuple[List[Selected], Dict[str, Any]]: diag = { "parse_ok": isinstance(parsed, dict), "invalid_items": 0, "oob_indices": 0, "dupe_indices": 0, "kept": 0, } if isinstance(parsed, BaseModel): parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict() diag["parse_ok"] = isinstance(parsed, dict) if not isinstance(parsed, dict): return [], diag selections = parsed.get("selections", []) if not isinstance(selections, list): diag["parse_ok"] = False return [], diag out: List[Selected] = [] seen_i = set() for item in selections: if len(out) >= per_call_budget: break if not isinstance(item, dict): diag["invalid_items"] += 1 continue i = item.get("i") if isinstance(i, bool) or not isinstance(i, int): diag["invalid_items"] += 1 continue if i in seen_i: diag["dupe_indices"] += 1 continue if i not in idx_to_tag: diag["oob_indices"] += 1 continue extra_keys = set(item.keys()) - {"i"} if extra_keys: diag["invalid_items"] += 1 continue seen_i.add(i) tag = idx_to_tag[i] out.append(Selected(i=i, tag=tag)) diag["kept"] = len(out) return out, diag def _split_candidates_by_type( candidates: List[Candidate], log, ) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]: """Split candidates into general vs entity (character only) lists. Returns: (general_list, entity_list) where each item is (original_index, candidate) Tag types: - General: 0 (general), 1 (artist), 5 (species), 7 (meta) - Entity: 4 (character) only - Filtered: 3 (copyright) - too broad for image generation """ general_with_idx: List[Tuple[int, Candidate]] = [] entity_with_idx: List[Tuple[int, Candidate]] = [] unknown_count = 0 copyright_count = 0 generic_char_count = 0 for idx, cand in enumerate(candidates): type_name = get_tag_type_name(cand.tag) if type_name == "character": if cand.tag in _GENERIC_CHARACTER_TAGS: # Route generic character-category tags to general selection general_with_idx.append((idx, cand)) generic_char_count += 1 else: entity_with_idx.append((idx, cand)) elif type_name == "copyright": # Filter out copyright/series tags - too broad for image generation copyright_count += 1 elif type_name in ("general", "artist", "species", "meta"): general_with_idx.append((idx, cand)) else: # Unknown or None - treat as general by default general_with_idx.append((idx, cand)) unknown_count += 1 if log: log( f"Stage3 split: " f"general={len(general_with_idx)} " f"entity={len(entity_with_idx)} " f"copyright_filtered={copyright_count} " f"generic_char_to_general={generic_char_count} " f"unknown_type={unknown_count}" ) return general_with_idx, entity_with_idx # Regex to strip series/franchise suffixes from aliases, e.g. _(sonic), _(mlp), _(character) _SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$") def _normalize_for_matching(text: str) -> str: """Lowercase, replace underscores with spaces, strip series suffixes.""" text = text.lower().strip() text = _SERIES_SUFFIX_RE.sub("", text) text = text.replace("_", " ") return text def _query_words(query: str) -> Set[str]: """Extract individual words from the user query for matching.""" return set(_normalize_for_matching(query).split()) def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str, fuzzy_threshold: int = 85) -> bool: """Check if an alias matches the user query. Matching logic: 1. Exact substring: alias appears as a substring of the query 2. Word subset: all words in the alias appear in the query words 3. Fuzzy: alias is close to a word in the query (handles typos) """ # Exact substring match if alias_norm in query_norm: return True alias_words = alias_norm.split() if not alias_words: return False # Word subset match: all alias words must appear in query if all(w in query_words for w in alias_words): return True # For single-word aliases, try fuzzy matching against each query word if len(alias_words) == 1: for qw in query_words: if fuzz.ratio(alias_words[0], qw) >= fuzzy_threshold: return True # For multi-word aliases, try fuzzy partial ratio against whole query if len(alias_words) > 1: if fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold: return True return False def _character_matches_via_aliases( tag: str, query: str, tag2aliases: Dict[str, List[str]], query_words: Set[str], query_norm: str, fuzzy_threshold: int = 85, ) -> bool: """Check if a character tag matches the user query via its aliases. For a character tag to match: - The tag name itself (normalized) must match, OR - At least one of its registered aliases must match. Empty aliases list means no known aliases; still check the tag name itself. """ # Check the tag name itself tag_norm = _normalize_for_matching(tag) if _alias_matches_query(tag_norm, query_words, query_norm, fuzzy_threshold): return True # Check all registered aliases aliases = tag2aliases.get(tag, []) for alias in aliases: alias_norm = _normalize_for_matching(alias) if not alias_norm: continue if _alias_matches_query(alias_norm, query_words, query_norm, fuzzy_threshold): return True return False def llm_select_indices( query_text: str, # kept for compatibility; treated as IMAGE DESCRIPTION candidates: Union[ Sequence[Candidate], Sequence[str], Sequence[Tuple[str, float]], ], max_pick: int, # legacy param; applied after union + ordering (optional) log, retries: int = 2, *, mode: str = "chunked_map_union", # "single_shot" or "chunked_map_union" chunk_size: int = 60, per_phrase_k: int = 2, # per-call budget = per_phrase_k * phrases_in_call temperature: float = 0.0, max_tokens: int = 512, return_metadata: bool = False, return_diagnostics: bool = False, min_why: Optional[str] = None, candidate_display: Optional[Mapping[str, str]] = None, user_template: Optional[str] = None, model_override: Optional[str] = None, fallback_model_override: Optional[str] = None, ) -> Union[ List[int], Tuple[List[int], Dict[str, str]], Tuple[List[int], Dict[str, str], Dict[str, Any]], ]: """Return indices into the ORIGINAL candidates list (legacy interface). min_why: legacy compatibility argument; ignored in explicit-only mode. This implementation uses LangChain ONLY. NOTE: query_text is treated as the image description (original prompt). """ image_description = query_text # Normalize candidates: # - preferred: List[Candidate] # - legacy: List[(tag, sim)] (count/sources unavailable) norm: List[Candidate] = [] tag_to_first_index: Dict[str, int] = {} branch = "empty" cand0_type = type(candidates[0]).__name__ if candidates else "none" if candidates and isinstance(candidates[0], Candidate): branch = "candidate" typed_candidates = cast(Sequence[Candidate], candidates) for idx, c in enumerate(typed_candidates): if c.tag not in tag_to_first_index: tag_to_first_index[c.tag] = idx norm.append(c) elif candidates and isinstance(candidates[0], str): branch = "string" typed_candidates = cast(Sequence[str], candidates) for idx, tag in enumerate(typed_candidates): if tag not in tag_to_first_index: tag_to_first_index[tag] = idx norm.append( Candidate( tag=tag, score_combined=0.0, score_fasttext=None, score_context=None, count=None, sources=[], ) ) else: if candidates: branch = "tuple" typed_candidates = cast(Sequence[Tuple[str, float]], candidates) for idx, row in enumerate(typed_candidates): if not isinstance(row, (list, tuple)) or len(row) < 2: raise ValueError("Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples.") tag, sim = row[0], row[1] if tag not in tag_to_first_index: tag_to_first_index[tag] = idx norm.append( Candidate( tag=tag, score_combined=float(sim), score_fasttext=None, score_context=None, count=None, sources=[], ) ) if log: if norm: log( "Stage3 input: " f"type0={cand0_type} " f"branch={branch} " f"norm0_score={norm[0].score_combined!r} " f"norm0_sources_empty={not bool(norm[0].sources)}" ) else: log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)") if mode not in ("single_shot", "chunked_map_union"): raise ValueError(f"Invalid mode: {mode}") response_format = _build_response_format() default_model = "mistralai/mistral-small-24b-instruct-2501" model_name = (model_override or os.getenv("OPENROUTER_MODEL", default_model) or default_model).strip() or default_model fallback_model_name = _resolve_fallback_model( primary_model=model_name, explicit_fallback=fallback_model_override, env_key="PSQ_SELECT_OPENROUTER_FALLBACK_MODEL", ) parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse) select_system_template = _get_select_system_template() human_template = user_template or USER_TEMPLATE # Global union of selected tags across calls. best_tags: Set[str] = set() diagnostics: Dict[str, Any] = { "mode": mode, "chunk_strategy": "interleave", "chunk_passes": 1, "chunk_shuffle_within_call": False, "calls_total": 0, "calls_with_selection": 0, "calls_exhausted_retries": 0, "attempts_total": 0, "attempt_errors": 0, "attempt_parse_fail": 0, "attempt_parse_ok": 0, "invalid_items_total": 0, "oob_indices_total": 0, "dupe_indices_total": 0, "kept_total": 0, "attempts_by_n_local": {}, "upstream_throttle_errors": 0, "credit_or_quota_errors": 0, "errors_by_code": {}, "errors_by_provider": {}, "fallback_activations": 0, } def _record_attempt_for_n(n_local: int, *, parse_ok: bool, error: bool) -> None: by_n = diagnostics["attempts_by_n_local"] key = str(n_local) if key not in by_n: by_n[key] = { "attempts": 0, "parse_ok": 0, "parse_fail": 0, "errors": 0, } by_n[key]["attempts"] += 1 if error: by_n[key]["errors"] += 1 elif parse_ok: by_n[key]["parse_ok"] += 1 else: by_n[key]["parse_fail"] += 1 def _record_error_facts(facts: Dict[str, Any]) -> None: code = facts.get("code") provider_name = (facts.get("provider_name") or "").strip() or "unknown" kind = facts.get("kind") if kind == "upstream_throttle": diagnostics["upstream_throttle_errors"] += 1 if kind == "credit_or_quota": diagnostics["credit_or_quota_errors"] += 1 if code is not None: code_key = str(code) diagnostics["errors_by_code"][code_key] = int(diagnostics["errors_by_code"].get(code_key, 0)) + 1 diagnostics["errors_by_provider"][provider_name] = int( diagnostics["errors_by_provider"].get(provider_name, 0) ) + 1 def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None: # Create prompt once; model may fail over on upstream throttling. prompt = ChatPromptTemplate.from_messages( [ ("system", system_template), ("human", human_template), ], template_format="f-string", ) ordered = _interleave_round_robin(call_cands) if mode == "single_shot" else list(call_cands) candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local( ordered, candidate_display=candidate_display, ) N_local = len(idx_to_tag) diagnostics["calls_total"] += 1 phrases = _phrases_in_call(call_cands) per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k summary_logged = False if log: log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}") if phrases > 0: distinct_phrases = sorted({src for c in call_cands for src in c.sources}) log( f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} " f"phrases={', '.join(distinct_phrases)}" ) # Do not retry the same model repeatedly. On upstream 429, try fallback once. model_attempts: List[Tuple[str, str]] = [("primary", model_name)] if fallback_model_name: model_attempts.append(("fallback", fallback_model_name)) # Invoke LangChain chain (templating fills {N} and other vars) for role, active_model in model_attempts: try: diagnostics["attempts_total"] += 1 llm = _get_llm( temperature=temperature, max_tokens=max_tokens, response_format=response_format, model=active_model, ) chain = prompt | llm | parser if role == "fallback": diagnostics["fallback_activations"] += 1 if log: log( "!!! Stage3 fallback model activated: " f"reason=upstream_throttle primary={model_name} fallback={active_model}" ) if log: log( f"Stage3 {label}: " f"model={active_model} " f"N={N_local} " f"phrases={phrases} " f"per_call_budget={per_call_budget} " f"response_healing=on" ) parsed = chain.invoke( { "N": N_local, "image_description": image_description, "candidate_lines": candidate_lines, "per_call_budget": per_call_budget, } ) selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget) diagnostics["invalid_items_total"] += int(diag.get("invalid_items", 0)) diagnostics["oob_indices_total"] += int(diag.get("oob_indices", 0)) diagnostics["dupe_indices_total"] += int(diag.get("dupe_indices", 0)) diagnostics["kept_total"] += int(diag.get("kept", 0)) if bool(diag.get("parse_ok", False)): diagnostics["attempt_parse_ok"] += 1 _record_attempt_for_n(N_local, parse_ok=True, error=False) else: diagnostics["attempt_parse_fail"] += 1 _record_attempt_for_n(N_local, parse_ok=False, error=False) if log: log(f"Stage3 {label}: attempt model={active_model} diag={diag}") if not summary_logged and selected: log( f"Stage3 {label}: summary " f"N={N_local} selected={len(selected)} per_call_budget={per_call_budget}" ) summary_logged = True if selected: lines = [ f"Stage3 {label} selections:", *[ ( f' - i={s.i} tag="{s.tag}" ' f"sources={idx_to_candidate.get(s.i).sources if idx_to_candidate.get(s.i) else []}" ) for s in selected ], ] log("\n".join(lines)) else: log(f"Stage3 {label} selections: (none)") if selected: diagnostics["calls_with_selection"] += 1 for s in selected: best_tags.add(s.tag) return except Exception as e: diagnostics["attempt_errors"] += 1 _record_attempt_for_n(N_local, parse_ok=False, error=True) facts = _parse_openrouter_error_facts(e) _record_error_facts(facts) if log: log(f"Stage3 {label}: attempt model={active_model} error: {_redact_sensitive_error_text(e)}") if facts.get("kind") == "upstream_throttle": p = facts.get("provider_name") or "unknown" c = facts.get("code") log(f"Stage3 {label}: upstream throttle detected (provider={p}, code={c})") elif facts.get("kind") == "credit_or_quota": c = facts.get("code") log(f"Stage3 {label}: credit/quota error detected (code={c})") should_fallback = ( role == "primary" and fallback_model_name is not None and facts.get("kind") == "upstream_throttle" ) if should_fallback: continue break if log: log(f"Stage3 {label}: gave up after model attempts") diagnostics["calls_exhausted_retries"] += 1 # Split candidates by type (general vs entity) general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log) # Extract just the candidates for LLM calls general_cands = [cand for _, cand in general_with_idx] entity_cands = [cand for _, cand in entity_with_idx] # Process general candidates (attributes, actions, species, etc.) if general_cands: if mode == "single_shot": run_call(general_cands, "general_single_shot", select_system_template) else: base_chunks = _build_chunks(general_cands, chunk_size) for chunk_idx, chunk in enumerate(base_chunks): run_call(chunk, f"general_chunk_{chunk_idx}", select_system_template) # Process entity candidates (characters only) with alias-based pre-filtering if entity_cands: tag2aliases = get_tag2aliases() qwords = _query_words(image_description) qnorm = _normalize_for_matching(image_description) filtered_entity_cands: List[Candidate] = [] filtered_out: List[str] = [] for cand in entity_cands: if _character_matches_via_aliases( cand.tag, image_description, tag2aliases, qwords, qnorm ): filtered_entity_cands.append(cand) else: filtered_out.append(cand.tag) if log: log( f"Stage3 entity alias filter: " f"before={len(entity_cands)} " f"after={len(filtered_entity_cands)} " f"removed={len(filtered_out)}" ) if filtered_out: log(f"Stage3 entity alias filter removed: {filtered_out[:20]}") if filtered_entity_cands: if mode == "single_shot": run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE) else: base_chunks = _build_chunks(filtered_entity_cands, chunk_size) for chunk_idx, chunk in enumerate(base_chunks): run_call(chunk, f"entity_chunk_{chunk_idx}", ENTITY_SYSTEM_TEMPLATE) if min_why is not None and log: log("Stage3: min_why is ignored in explicit-only no-why mode") # Deterministic ordering: count desc (count not shown to LLM), then tag. count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm} ordered_tags = sorted(best_tags, key=lambda t: (count_by_tag.get(t, -1), t), reverse=True) # Legacy cap: apply AFTER union + ordering. if isinstance(max_pick, int) and max_pick > 0: ordered_tags = ordered_tags[:max_pick] # Map back to original indices out_idx: List[int] = [] tag_why: Dict[str, str] = {} for t in ordered_tags: if t in tag_to_first_index: out_idx.append(tag_to_first_index[t]) # Why labels removed in explicit-only no-why mode. if diagnostics["attempts_total"] > 0: diagnostics["attempt_failure_rate"] = ( diagnostics["attempt_parse_fail"] + diagnostics["attempt_errors"] ) / diagnostics["attempts_total"] else: diagnostics["attempt_failure_rate"] = 0.0 if diagnostics["calls_total"] > 0: diagnostics["call_exhaustion_rate"] = ( diagnostics["calls_exhausted_retries"] / diagnostics["calls_total"] ) else: diagnostics["call_exhaustion_rate"] = 0.0 if log and (diagnostics["upstream_throttle_errors"] or diagnostics["credit_or_quota_errors"]): log( "Stage3 diagnostics: " f"upstream_throttle_errors={diagnostics['upstream_throttle_errors']} " f"credit_or_quota_errors={diagnostics['credit_or_quota_errors']} " f"fallback_activations={diagnostics['fallback_activations']} " f"errors_by_code={diagnostics['errors_by_code']} " f"errors_by_provider={diagnostics['errors_by_provider']}" ) if return_metadata: if return_diagnostics: return out_idx, tag_why, diagnostics return out_idx, tag_why return out_idx # --------------------------------------------------------------------------- # Stage 3s: Structural tag inference (solo/duo/male/female/anthro/… ) # --------------------------------------------------------------------------- # Group-based approach: tags are organized into semantic groups loaded from # tag_groups.json / tag_wiki_defs.json where possible, with curated fallback # definitions for tags whose wiki entries are only thumbnail references. # # Each group specifies a constraint mode: # "exclusive" = pick exactly one (e.g. character count) # "multi" = pick all that apply (e.g. body type, gender) import json as _json @dataclass class StructuralGroup: """One category of structural tags to probe.""" name: str constraint: str # "exclusive" or "multi" tags: List[Tuple[str, str]] # (tag, definition) pairs def _load_structural_groups_from_csv() -> List[StructuralGroup]: """Load structural groups from data/structural_tag_definitions.csv.""" data_dir = Path(__file__).resolve().parents[2] / "data" csv_path = data_dir / "structural_tag_definitions.csv" if not csv_path.is_file(): return [] groups_by_name: Dict[str, List[Tuple[str, str]]] = {} constraints_by_name: Dict[str, str] = {} with csv_path.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) for row in reader: enabled = (row.get("enabled") or "1").strip().lower() if enabled in {"0", "false", "no"}: continue group_name = (row.get("group_name") or "").strip() constraint = (row.get("constraint") or "multi").strip().lower() tag = (row.get("tag") or "").strip() definition = " ".join((row.get("definition") or "").split()) if not group_name or not tag or not definition: continue if constraint not in {"exclusive", "multi"}: constraint = "multi" if group_name not in groups_by_name: groups_by_name[group_name] = [] constraints_by_name[group_name] = constraint groups_by_name[group_name].append((tag, definition)) out: List[StructuralGroup] = [] for group_name, tags in groups_by_name.items(): if not tags: continue out.append( StructuralGroup( name=group_name, constraint=constraints_by_name.get(group_name, "multi"), tags=tags, ) ) return out def _load_structural_groups() -> List[StructuralGroup]: """Build structural groups from local config file with legacy fallback. Preferred source: data/structural_tag_definitions.csv Fallback: tag_wiki_defs.json + curated hardcoded defaults """ csv_groups = _load_structural_groups_from_csv() if csv_groups: return csv_groups data_dir = Path(__file__).resolve().parents[2] / "data" # Load wiki definitions (may not exist yet) wiki_defs: Dict[str, str] = {} wiki_path = data_dir / "tag_wiki_defs.json" if wiki_path.is_file(): with wiki_path.open("r", encoding="utf-8") as f: wiki_defs = _json.load(f) def _def(tag: str, fallback: str) -> str: """Get wiki definition if it's real text, otherwise use fallback.""" d = wiki_defs.get(tag, "") # Skip thumbnail-only definitions if not d or d.startswith("thumb ") or len(d) < 15: return fallback return d[:200] # cap length for prompt groups: List[StructuralGroup] = [] # ── Group A: Character Count (exclusive) ── groups.append(StructuralGroup( name="character_count", constraint="exclusive", tags=[ ("zero_pictured", _def("zero_pictured", "No characters or living beings appear in the image")), ("solo", _def("solo", "Exactly one character appears in the image")), ("duo", _def("duo", "Exactly two characters appear in the image")), ("trio", _def("trio", "Exactly three characters appear in the image")), ("group", _def("group", "Four or more characters appear in the image")), ], )) # ── Group B: Body Type (multi — per character) ── # Key distinction the LLM must learn: # anthro = ANIMAL with human body shape (upright, hands) # humanoid = HUMAN or near-human (elf, dwarf) with NO animal features # feral = normal animal shape, on all fours groups.append(StructuralGroup( name="body_type", constraint="multi", tags=[ ("anthro", _def("anthro", "An animal character with a human-like body: walks upright on two legs, " "has arms and hands. Examples: a wolf-person, a fox standing up. " "Still has animal features like fur, tail, muzzle")), ("feral", _def("feral", "A regular animal in its natural body shape. Walks on all fours (or " "flies/swims naturally). NOT standing upright, NOT humanized")), ("humanoid", _def("humanoid", "A human or human-like character with NO animal features. Includes " "humans, elves, dwarves, and fantasy races that look human. " "Does NOT include animal-people — those are anthro")), ("taur", _def("taur", "A centaur-like body: human or anthro upper body attached to a " "four-legged animal lower body")), ], )) # ── Group C: Gender (multi — per character) ── groups.append(StructuralGroup( name="gender", constraint="multi", tags=[ ("male", _def("male", "A character described as male, a boy, or with he/him pronouns")), ("female", _def("female", "A character described as female, a girl, or with she/her pronouns")), ("ambiguous_gender", _def("ambiguous_gender", "A character whose gender is not stated or cannot be determined")), ("intersex", _def("intersex", "A character explicitly described as intersex or hermaphrodite")), ], )) # ── Group D: Clothing State (multi) ── groups.append(StructuralGroup( name="clothing_state", constraint="multi", tags=[ ("clothed", _def("clothed", "Wearing clothes on BOTH chest/torso AND legs/waist. " "Examples: shirt and pants, dress, full outfit")), ("nude", _def("nude", "Wearing NO clothes at all. Completely naked, no shirt and no pants")), ("topless", _def("topless", "NO shirt/top (bare chest), BUT wearing pants/bottoms. " "Upper body exposed, lower body covered")), ("bottomless", _def("bottomless", "Wearing shirt/top on chest, BUT NO pants/bottoms. " "Upper body covered, lower body exposed")), ], )) # ── Group E: Common Visual Elements (multi) ── groups.append(StructuralGroup( name="visual_elements", constraint="multi", tags=[ ("looking_at_viewer", _def("looking_at_viewer", "A character is looking directly at the camera or viewer")), ("text", _def("text", "The image contains visible writing, words, or lettering")), ], )) return groups def _build_structural_prompt(groups: List[StructuralGroup]) -> Tuple[str, List[Tuple[str, str]]]: """Build numbered statement list from structural groups. Returns (formatted_text, flat_list_of_(tag, definition)_pairs). The flat list maps 1-based statement numbers to tags. """ lines: List[str] = [] flat: List[Tuple[str, str]] = [] idx = 1 for g in groups: constraint_label = "pick EXACTLY ONE" if g.constraint == "exclusive" else "pick ALL that apply" group_header = f"--- {g.name.replace('_', ' ').upper()} ({constraint_label}) ---" lines.append(group_header) for tag, defn in g.tags: lines.append(f"{idx}. {defn}") flat.append((tag, defn)) idx += 1 lines.append("") # blank line between groups return "\n".join(lines), flat STRUCTURAL_SYSTEM_TEMPLATE = """You classify image descriptions by selecting true statements from a numbered list. The statements are organized into GROUPS. Each group header tells you how many to pick: - "pick EXACTLY ONE" = choose the single best match in that group - "pick ALL that apply" = choose every statement that is true IMPORTANT RULES: 1. ONLY select a statement if the description directly says it or makes it very obvious. 2. Do NOT guess or assume things the description does not mention. 3. For body type: "anthro" means an ANIMAL with a human-shaped body (walks upright, has hands, but still has fur/tail/muzzle). "humanoid" means HUMAN or human-like with NO animal features. A wolf standing on two legs = anthro, NOT humanoid. 4. For gender: only select male/female/intersex when there is explicit textual evidence (such as gender words or pronouns). Do not infer gender from species, body shape, clothing, or style. If no reliable gender cue is present, do not select male/female/intersex; use ambiguous_gender instead. 5. For clothing state: READ CAREFULLY! "topless" = bare chest, wearing pants. "bottomless" = wearing shirt, no pants. If unsure, re-read the description. 6. If clothing is not mentioned, do NOT pick any clothing statement. Return JSON ONLY: {{"selections": [{{"i": 1}}, {{"i": 5}}]}} EXAMPLE: Description: "A muscular male wolf standing in a forest, wearing jeans, giving a thumbs up" Answer: {{"selections": [{{"i": 2}}, {{"i": 6}}, {{"i": 10}}, {{"i": 14}}]}} Why: One character = solo (2). Wolf standing upright with hands = anthro (6), NOT humanoid because it is a wolf. Male (10). Wearing jeans = clothed (14).""" STRUCTURAL_USER_TEMPLATE = """Read this image description and select which statements are true. IMAGE DESCRIPTION: {image_description} STATEMENTS (pick by number): {statement_lines}""" class StructuralSelectionItem(BaseModel): i: int = Field(..., description="1-based index into the statement list.") class StructuralSelectionResponse(BaseModel): selections: List[StructuralSelectionItem] = Field(default_factory=list) def _build_structural_response_format() -> Dict[str, Any]: schema = { "type": "object", "properties": { "selections": { "type": "array", "items": { "type": "object", "properties": { "i": {"type": "integer"}, }, "required": ["i"], "additionalProperties": False, }, } }, "required": ["selections"], "additionalProperties": False, } return { "type": "json_schema", "json_schema": { "name": "structural_selection", "strict": True, "schema": schema, }, } # Cache the loaded groups so we only read JSON files once per process. _cached_structural_groups: Optional[List[StructuralGroup]] = None def _get_structural_groups() -> List[StructuralGroup]: global _cached_structural_groups if _cached_structural_groups is None: _cached_structural_groups = _load_structural_groups() return _cached_structural_groups def _postprocess_structural_tags(tags: Sequence[str], log=None) -> List[str]: """Apply deterministic structural-tag mapping rules. The character-count prompt distinguishes `trio` (exactly 3) from `group` (4+), but runtime tagging convention expects `group` whenever 3+ characters are present. Therefore `trio` implies `group`. """ out: List[str] = [] seen: Set[str] = set() for tag in tags: if tag and tag not in seen: out.append(tag) seen.add(tag) if "trio" in seen and "group" not in seen: out.append("group") seen.add("group") if log: log("Stage3s: postprocess added group because trio implies group") return out def llm_infer_structural_tags( query_text: str, log=None, *, temperature: float = 0.0, max_tokens: int = 512, retries: int = 2, model_override: Optional[str] = None, fallback_model_override: Optional[str] = None, ) -> List[str]: """Infer structural tags via LLM using group-based statement agreement. Probes multiple semantic groups (character count, body type, gender, clothing state, visual elements) with definitions loaded from wiki data where available. Returns a list of e621 tag strings (e.g. ["solo", "anthro", "male", "clothed"]). """ if log: log("Stage3s (structural): inferring structural tags via group-based statement agreement") groups = _get_structural_groups() statement_lines, flat_tags = _build_structural_prompt(groups) N = len(flat_tags) response_format = _build_structural_response_format() struct_model_override = ( (model_override or "").strip() or (os.environ.get("PSQ_STRUCT_OPENROUTER_MODEL", "") or "").strip() or None ) default_model = "mistralai/mistral-small-24b-instruct-2501" model_name = ( struct_model_override or os.getenv("OPENROUTER_MODEL", default_model) or default_model ).strip() or default_model fallback_model_name = _resolve_fallback_model( primary_model=model_name, explicit_fallback=fallback_model_override, env_key="PSQ_STRUCT_OPENROUTER_FALLBACK_MODEL", ) parser = PydanticOutputParser(pydantic_object=StructuralSelectionResponse) prompt = ChatPromptTemplate.from_messages( [ ("system", STRUCTURAL_SYSTEM_TEMPLATE), ("human", STRUCTURAL_USER_TEMPLATE), ], template_format="f-string", ) if log: group_summary = ", ".join(f"{g.name}({len(g.tags)})" for g in groups) log(f"Stage3s: model={model_name} groups=[{group_summary}] total_statements={N}") throttle_errors = 0 credit_or_quota_errors = 0 errors_by_code: Dict[str, int] = {} errors_by_provider: Dict[str, int] = {} model_attempts: List[Tuple[str, str]] = [("primary", model_name)] if fallback_model_name: model_attempts.append(("fallback", fallback_model_name)) for role, active_model in model_attempts: try: llm = _get_llm( temperature=temperature, max_tokens=max_tokens, response_format=response_format, model=active_model, ) chain = prompt | llm | parser if role == "fallback" and log: log( "!!! Stage3s fallback model activated: " f"reason=upstream_throttle primary={model_name} fallback={active_model}" ) parsed = chain.invoke({ "N": N, "image_description": query_text, "statement_lines": statement_lines, }) if isinstance(parsed, BaseModel): parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict() sels = parsed.get("selections", []) if isinstance(parsed, dict) else [] chosen_tags: List[str] = [] seen: Set[str] = set() for item in sels: idx = item.get("i") if isinstance(item, dict) else None if not isinstance(idx, int) or idx < 1 or idx > N: continue tag = flat_tags[idx - 1][0] if tag not in seen: chosen_tags.append(tag) seen.add(tag) chosen_tags = _postprocess_structural_tags(chosen_tags, log=log) if log: tag_str = ", ".join(chosen_tags) if chosen_tags else "(none)" log(f"Stage3s: attempt model={active_model} selected {len(chosen_tags)} tags: {tag_str}") return chosen_tags except Exception as e: facts = _parse_openrouter_error_facts(e) code = facts.get("code") provider_name = (facts.get("provider_name") or "").strip() or "unknown" if code is not None: code_key = str(code) errors_by_code[code_key] = int(errors_by_code.get(code_key, 0)) + 1 errors_by_provider[provider_name] = int(errors_by_provider.get(provider_name, 0)) + 1 if facts.get("kind") == "upstream_throttle": throttle_errors += 1 elif facts.get("kind") == "credit_or_quota": credit_or_quota_errors += 1 if log: log(f"Stage3s: attempt model={active_model} error: {_redact_sensitive_error_text(e)}") if facts.get("kind") == "upstream_throttle": log(f"Stage3s: upstream throttle detected (provider={provider_name}, code={code})") elif facts.get("kind") == "credit_or_quota": log(f"Stage3s: credit/quota error detected (code={code})") should_fallback = ( role == "primary" and fallback_model_name is not None and facts.get("kind") == "upstream_throttle" ) if should_fallback: continue break if log: if throttle_errors or credit_or_quota_errors: log( "Stage3s diagnostics: " f"upstream_throttle_errors={throttle_errors} " f"credit_or_quota_errors={credit_or_quota_errors} " f"errors_by_code={errors_by_code} " f"errors_by_provider={errors_by_provider}" ) log(f"Stage3s: gave up after model attempts") return [] # --------------------------------------------------------------------------- # Stage 3p: Simplified high-precision probe tags # --------------------------------------------------------------------------- _cached_runtime_probe_tags: Optional[List[str]] = None _cached_runtime_probe_rows: Optional[List[Dict[str, str]]] = None _cached_probe_definition_rows: Optional[List[Dict[str, str]]] = None _cached_runtime_probe_rows_path: Optional[str] = None _cached_probe_definition_rows_path: Optional[str] = None _SPECIES_ANCHOR_RULES: Dict[str, Set[str]] = { "canid": {"canine", "wolf", "fox"}, "felid": {"feline", "tiger", "lion"}, } def _norm_text_for_probe(text: str) -> str: return " ".join((text or "").lower().replace("_", " ").split()) def _contains_any_token(text_norm: str, tokens: Sequence[str]) -> bool: for tok in tokens: t = _norm_text_for_probe(tok) if not t: continue if re.search(rf"(? List[str]: """Optionally derive family tags from simpler anchor tags (no extra LLM calls). Controlled by env var: PSQ_PROBE_SPECIES_PROXY_MODE = off | anchors | anchors_lexical - off: no mapping - anchors: add families from selected anchors only - anchors_lexical: also require lexical cue in query text """ mode = (os.environ.get("PSQ_PROBE_SPECIES_PROXY_MODE", "off") or "off").strip().lower() if mode in {"", "0", "false", "off", "none"}: return selected out: List[str] = list(selected) out_set: Set[str] = set(out) query_norm = _norm_text_for_probe(query_text) def _add(tag: str, reason: str) -> None: if tag in out_set: return out.append(tag) out_set.add(tag) if log: log(f"Stage3p: species-anchor mapped {tag} ({reason})") # canid / felid from common anchors for family, anchors in _SPECIES_ANCHOR_RULES.items(): hits = sorted([a for a in anchors if a in out_set]) if not hits: continue if mode == "anchors_lexical": lexical = _contains_any_token(query_norm, [family, *anchors]) if not lexical: continue _add(family, f"anchors={hits}") # Bird is handled more conservatively to avoid broad false positives. # Require both beak+feathers, or one of them plus explicit bird wording. has_beak = "beak" in out_set has_feathers = "feathers" in out_set has_bird_word = _contains_any_token(query_norm, ["bird", "avian"]) bird_ok = (has_beak and has_feathers) or ((has_beak or has_feathers) and has_bird_word) if bird_ok: if mode != "anchors_lexical" or has_bird_word or (has_beak and has_feathers): _add("bird", "beak/feathers evidence") return out def _load_runtime_probe_rows(log=None) -> List[Dict[str, str]]: global _cached_runtime_probe_rows, _cached_runtime_probe_rows_path repo_dir = Path(__file__).resolve().parents[2] data_dir = Path(__file__).resolve().parents[2] / "data" env_path = (os.environ.get("PSQ_PROBE_TAGS_CSV", "") or "").strip() if env_path: csv_path = Path(env_path) if not csv_path.is_absolute() and not csv_path.is_file(): csv_path = repo_dir / env_path if not csv_path.is_absolute() and not csv_path.is_file(): csv_path = data_dir / env_path else: csv_path = data_dir / "simplified_probe_tags.csv" if not csv_path.is_file(): # Legacy fallback for older checkouts. csv_path = data_dir / "analysis" / "simplified_probe_tags.csv" csv_path_key = str(csv_path.resolve()) if csv_path.exists() else str(csv_path) if _cached_runtime_probe_rows is not None and _cached_runtime_probe_rows_path == csv_path_key: return _cached_runtime_probe_rows rows: List[Dict[str, str]] = [] if not csv_path.is_file(): if log: log(f"Stage3p: probe CSV not found at {csv_path}; skipping probe step") _cached_runtime_probe_rows_path = csv_path_key _cached_runtime_probe_rows = rows return rows try: with csv_path.open("r", encoding="utf-8", newline="") as f: rows = list(csv.DictReader(f)) except Exception as e: if log: log(f"Stage3p: failed reading probe CSV: {_redact_sensitive_error_text(e)}") rows = [] _cached_runtime_probe_rows = rows _cached_runtime_probe_rows_path = csv_path_key return rows def _load_probe_definition_rows(log=None) -> List[Dict[str, str]]: global _cached_probe_definition_rows, _cached_probe_definition_rows_path repo_dir = Path(__file__).resolve().parents[2] data_dir = Path(__file__).resolve().parents[2] / "data" env_path = (os.environ.get("PSQ_PROBE_DEFINITIONS_CSV", "") or "").strip() if env_path: csv_path = Path(env_path) if not csv_path.is_absolute() and not csv_path.is_file(): csv_path = repo_dir / env_path if not csv_path.is_absolute() and not csv_path.is_file(): csv_path = data_dir / env_path else: csv_path = data_dir / "probe_tag_definitions.csv" csv_path_key = str(csv_path.resolve()) if csv_path.exists() else str(csv_path) if _cached_probe_definition_rows is not None and _cached_probe_definition_rows_path == csv_path_key: return _cached_probe_definition_rows rows: List[Dict[str, str]] = [] if not csv_path.is_file(): if log: log(f"Stage3p: probe definition CSV not found at {csv_path}; using bare tag labels") _cached_probe_definition_rows_path = csv_path_key _cached_probe_definition_rows = rows return rows try: with csv_path.open("r", encoding="utf-8", newline="") as f: rows = list(csv.DictReader(f)) except Exception as e: if log: log(f"Stage3p: failed reading probe definition CSV: {_redact_sensitive_error_text(e)}") rows = [] _cached_probe_definition_rows = rows _cached_probe_definition_rows_path = csv_path_key return rows def _load_runtime_probe_tags(log=None) -> List[str]: """Load runtime probe tags from analysis output. Preference order: 1) selected_final=1 (reliability-gated list) 2) selected_initial=1 (fallback if reliability file not built) """ global _cached_runtime_probe_tags if _cached_runtime_probe_tags is not None: return _cached_runtime_probe_tags rows = _load_runtime_probe_rows(log=log) tags: List[str] = [] def _is_on(v: str) -> bool: return (v or "").strip() in {"1", "true", "True"} final = [r.get("tag", "").strip() for r in rows if _is_on(r.get("selected_final", ""))] initial = [r.get("tag", "").strip() for r in rows if _is_on(r.get("selected_initial", ""))] tags = [t for t in (final if final else initial) if t] _cached_runtime_probe_tags = tags if log and tags: log(f"Stage3p: loaded {len(tags)} probe tags") return tags def _clean_glossary_text(text: str) -> str: t = " ".join((text or "").replace("\n", " ").replace("\r", " ").split()) if len(t) > 160: t = t[:157].rstrip() + "..." return t def _build_probe_candidate_display(probe_tags: Sequence[str], log=None) -> Dict[str, str]: rows = _load_probe_definition_rows(log=log) rows_by_tag = {r.get("tag", "").strip(): r for r in rows} display: Dict[str, str] = {} for tag in probe_tags: base = _display_tag(tag) row = rows_by_tag.get(tag, {}) # `notes` is intentionally ignored by runtime; only `definition` affects prompt text. gloss = _clean_glossary_text(row.get("definition", "")) display[tag] = f"{base} - {gloss}" if gloss else base return display def _probe_tokenize(text: str) -> List[str]: return [tok for tok in re.findall(r"[a-z0-9]+", _norm_text_for_probe(text)) if len(tok) >= 2] def _query_mentions_phrase(query_norm: str, phrase: str) -> bool: p = _norm_text_for_probe(phrase) if not p: return False # Word-boundary match for plain terms; fallback to substring for symbol tags (e.g. "<3"). if re.fullmatch(r"[a-z0-9 ]+", p): return re.search(rf"(? List[str]: """Order probe tags by likely relevance for this query (lexical + FastText). No TF-IDF is used here. This is intentionally lightweight and per-request. """ if not probe_tags: return [] enabled = os.environ.get("PSQ_PROBE_DYNAMIC_ORDER", "1").strip().lower() not in {"0", "false", "no"} if not enabled: return list(probe_tags) query_norm = _norm_text_for_probe(query_text) query_tokens = _probe_tokenize(query_text)[:64] try: ft_model = get_fasttext_model() except Exception as e: if log: log(f"Stage3p: dynamic ordering disabled (fasttext unavailable: {_redact_sensitive_error_text(e)})") return list(probe_tags) tag2aliases = get_tag2aliases() scored: List[Tuple[int, float, int, str]] = [] for i, tag in enumerate(probe_tags): primary = _display_tag(tag) term_pool = [primary] for alias in tag2aliases.get(tag, []): norm_alias = _norm_text_for_probe(alias) if norm_alias and len(norm_alias) <= 48: term_pool.append(norm_alias) # Deduplicate while preserving order. term_pool = list(dict.fromkeys(term_pool))[:12] lexical_hit = 0 for term in term_pool: if _query_mentions_phrase(query_norm, term): lexical_hit = 1 break # FastText relevance score: best token/phrase similarity to any term variant. ft_best = -1.0 for term in term_pool: try: s_phrase = float(ft_model.similarity(query_norm, term)) except Exception: s_phrase = -1.0 s_token = -1.0 for tok in query_tokens: try: s = float(ft_model.similarity(tok, term)) except Exception: s = -1.0 if s > s_token: s_token = s ft_term = max(s_phrase, s_token) if ft_term > ft_best: ft_best = ft_term scored.append((lexical_hit, ft_best, -i, tag)) scored.sort(reverse=True) ordered = [tag for _, _, _, tag in scored] if log: preview = ", ".join(f"{t}" for t in ordered[: min(8, len(ordered))]) log(f"Stage3p: dynamic order top={preview}") return ordered def _probe_bundle_map(log=None) -> Dict[str, str]: rows = _load_runtime_probe_rows(log=log) out: Dict[str, str] = {} for r in rows: t = (r.get("tag") or "").strip() if not t: continue b = (r.get("bundle") or "").strip() out[t] = b if b else "__uncategorized__" return out def _split_probe_tags_by_bundle( probe_tags: Sequence[str], n_splits: int, log=None, ) -> List[List[str]]: if n_splits <= 1: return [list(probe_tags)] bundle_by_tag = _probe_bundle_map(log=log) groups: Dict[str, List[str]] = {} for t in probe_tags: b = bundle_by_tag.get(t, f"__{t}__") groups.setdefault(b, []).append(t) # Greedy pack bundles into balanced chunks by total tag count. bundle_keys = sorted(groups.keys()) chunks: List[List[str]] = [[] for _ in range(n_splits)] sizes: List[int] = [0 for _ in range(n_splits)] for b in bundle_keys: j = min(range(n_splits), key=lambda idx: (sizes[idx], idx)) chunks[j].extend(groups[b]) sizes[j] += len(groups[b]) out = [c for c in chunks if c] if log: parts: List[str] = [] for i, c in enumerate(out, start=1): bundles = {bundle_by_tag.get(t, f"__{t}__") for t in c} parts.append(f"chunk{i}:tags={len(c)} bundles={len(bundles)}") log(f"Stage3p: split probe into {len(out)} calls ({', '.join(parts)})") return out def llm_infer_probe_tags( query_text: str, log=None, *, temperature: float = 0.0, max_tokens: int = 512, retries: int = 2, min_why: Optional[str] = None, model_override: Optional[str] = None, fallback_model_override: Optional[str] = None, ) -> List[str]: """Infer probe tags from a fixed reliability-gated tag list.""" probe_tags = _load_runtime_probe_tags(log=log) if not probe_tags: return [] try: min_tag_count = max(0, int(os.environ.get("PSQ_MIN_TAG_COUNT", "100"))) except Exception: min_tag_count = 100 if min_tag_count > 0: tag_counts = get_tag_counts() probe_tags_filtered = [ t for t in probe_tags if int(tag_counts.get(_norm_tag_for_lookup(t), 0) or 0) >= min_tag_count ] kept_set = set(probe_tags_filtered) removed = [t for t in probe_tags if t not in kept_set] probe_tags = probe_tags_filtered if log and removed: log( f"Stage3p: filtered {len(removed)} probe tags below count {min_tag_count}: " f"{', '.join(removed)}" ) if not probe_tags: if log: log("Stage3p: no probe tags remain after min-count filtering") return [] probe_tags = _rank_probe_tags_for_query(probe_tags, query_text=query_text, log=log) probe_model_override = ( (model_override or "").strip() or (os.environ.get("PSQ_PROBE_OPENROUTER_MODEL", "") or "").strip() or None ) probe_default_model = "mistralai/mistral-small-24b-instruct-2501" probe_effective_model = ( probe_model_override or os.getenv("OPENROUTER_MODEL", probe_default_model) or probe_default_model ).strip() or probe_default_model probe_fallback_model_override = _resolve_fallback_model( primary_model=probe_effective_model, explicit_fallback=fallback_model_override, env_key="PSQ_PROBE_OPENROUTER_FALLBACK_MODEL", ) if log: log(f"Stage3p: probing {len(probe_tags)} tags") probe_variant = (os.environ.get("PSQ_PROBE_PROMPT_VARIANT", "baseline") or "baseline").strip() log(f"Stage3p: prompt_variant={probe_variant}") log(f"Stage3p: model={probe_effective_model}") candidate_display = _build_probe_candidate_display(probe_tags, log=log) try: split_calls = max(1, int((os.environ.get("PSQ_PROBE_SPLIT_CALLS", "1") or "1").strip())) except Exception: split_calls = 1 try: # Default probe cap is 2. # Rationale (caption-evident n=10 evals, 2026-03-12): # - max1 reduced pollution but hurt recall too much # - max3 increased pollution and reduced overall F1 # - max2 provided the best observed precision/recall tradeoff and highest overall F1 max_pick_override = int((os.environ.get("PSQ_PROBE_MAX_PICK_OVERRIDE", "2") or "2").strip()) except Exception: max_pick_override = 0 if max_pick_override < 0: max_pick_override = 0 def _call_chunks( chunks: Sequence[Sequence[str]], *, label: str, ) -> List[str]: out_tags: List[str] = [] out_seen: Set[str] = set() for chunk_idx, chunk_tags_seq in enumerate(chunks, start=1): chunk_tags = list(chunk_tags_seq) if not chunk_tags: continue if len(chunks) > 1 and log: log(f"Stage3p {label}: call {chunk_idx}/{len(chunks)} with {len(chunk_tags)} probe tags") per_call_budget = len(chunk_tags) if max_pick_override > 0: per_call_budget = min(per_call_budget, max_pick_override) if per_call_budget <= 0: if log: log(f"Stage3p {label}: skipping call with budget=0") continue out = llm_select_indices( query_text=query_text, candidates=chunk_tags, max_pick=per_call_budget, log=log, retries=retries, mode="single_shot", chunk_size=max(1, len(chunk_tags)), per_phrase_k=max(1, per_call_budget), temperature=temperature, max_tokens=max_tokens, return_metadata=False, return_diagnostics=False, min_why=None, candidate_display=candidate_display, user_template=_get_probe_user_template(), model_override=probe_model_override, fallback_model_override=probe_fallback_model_override, ) for i in out: if 0 <= i < len(chunk_tags): t = chunk_tags[i] if t not in out_seen: out_seen.add(t) out_tags.append(t) return out_tags probe_chunks = _split_probe_tags_by_bundle(probe_tags, split_calls, log=log) selected: List[str] = _call_chunks(probe_chunks, label="all") selected = _apply_species_anchor_mapping(selected, query_text=query_text, log=log) if log: shown = ", ".join(selected) if selected else "(none)" log(f"Stage3p: selected {len(selected)} probe tags: {shown}") return selected