Spaces:
Running
Running
# psq_rag/llm/select.py
# Stage 3: Closed-Set Selection (LangChain-only implementation)
#
# This module intentionally uses LangChain for:
# - prompt templating (including the {N} variable)
# - LLM call orchestration
# - JSON parsing
#
# There is NO fallback path. If the LangChain dependencies are missing, this
# module should fail loudly so that you install them.
| import os | |
| import re | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import PydanticOutputParser | |
| from pydantic import BaseModel, Field, SecretStr | |
| from rapidfuzz import fuzz | |
| from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources) | |
| from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases | |
# Tags whose type is "character" but which name a generic category rather than
# a specific named character. They slip past the alias filter because they
# match ordinary words in captions, so they are kept out of the entity
# pipeline and handled by the general selection pass instead.
_GENERIC_CHARACTER_TAGS = frozenset(
    (
        "fan_character",
        "background_character",
        "unnamed_character",
        "unknown_character",
        "anonymous_character",
        "viewer",
        "original_character",
    )
)
# Allowed rationale codes, listed from most to least confident.
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
# Ordinal rank of each code (lower = more confident); used for threshold
# filtering. The rank is simply the code's position in WHY_ENUM.
WHY_RANK: Dict[str, int] = {code: rank for rank, code in enumerate(WHY_ENUM)}
# Deterministic mapping from rationale code to a numeric score used for
# ordering and debugging.
WHY_TO_SCORE: Dict[str, float] = dict(
    explicit=0.90,
    strong_implied=0.70,
    weak_implied=0.45,
    style_or_meta=0.35,
    other=0.25,
)
# TEMPLATING NOTES:
# - LangChain renders this string with its f-string template engine.
# - Literal JSON braces are therefore written doubled: {{ and }}.
# - {N} is a genuine template variable and MUST be supplied at invoke time.
# - Double quotes need no backslash inside a triple-quoted string.
SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags.
Select the tags that correspond to content that would be visible or depicted in the described image.
The list contains only valid tags; many of them are irrelevant to the image.
Return JSON ONLY matching this schema:
{{
"selections": [
{{"i": <int>, "why": "<one of: explicit|strong_implied|weak_implied|style_or_meta|other>"}},
...
]
}}
Rules:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Do NOT output any keys other than "selections", and inside each item only the item index "i" and "why".
- Do select both a general tag and a more specific tag when both apply (for example, "shirt" and "grey shirt").
Define "why" as:
- explicit: directly stated in the image description
- strong_implied: very likely given the description, even if not literally stated
- weak_implied: plausible but not strongly supported by the description
- style_or_meta: stylistic or presentation-related tags only if clearly indicated
- other: fallback category; use sparingly
"""
# System prompt for the character (entity) pass. Like the general system
# template it is rendered by LangChain's f-string engine, so literal braces are
# doubled and {N} must be supplied at invoke time.
ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags.
These character tags have already been pre-filtered to only include characters whose names
(or known aliases) appear in the image description. Your job is to confirm which of these
pre-filtered candidates are the correct match for the character mentioned by the user.
Return JSON ONLY matching this schema:
{{
"selections": [
{{"i": <int>, "why": "explicit"}},
...
]
}}
Rules for character selection:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Always use "why": "explicit" for all selections.
- Select the tag that best represents the character as described.
- If the user described a specific variant (e.g. "pikachu libre", "detective pikachu"),
select that specific variant tag.
- If the user described only the base character (e.g. just "pikachu"), select only the
base/default tag, NOT costume or variant tags.
- When uncertain between variants, prefer the simplest/most general tag.
"""
# Human-message template shared by both passes. The three placeholders are
# filled per call: the image description, the numbered candidate listing, and
# the maximum number of indices the model may return.
USER_TEMPLATE = (
    "IMAGE DESCRIPTION:\n"
    "{image_description}\n"
    "CANDIDATES (choose by index only):\n"
    "{candidate_lines}\n"
    "Select up to {per_call_budget} indices. Output fewer if uncertain.\n"
)
@dataclass
class Selected:
    """One validated LLM selection, mapped back to its canonical tag.

    Instances are built with keyword arguments by the response validator, so
    the class must be a dataclass: without the decorator the annotated fields
    generate no ``__init__`` and construction raises ``TypeError``.
    """

    i: int  # 1-based index the model chose within a single call
    tag: str  # canonical tag (underscore form)
    why: str  # rationale code returned by the model
    score: float  # numeric score derived from `why`
# Closed set of rationale codes accepted in a model response; mirrors WHY_ENUM.
WhyLiteral = Literal["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
# Pydantic model for one entry of the response's "selections" array.
class Stage3SelectionItem(BaseModel):
    # Candidate chosen by the model, identified by its 1-based list position.
    i: int = Field(..., description="1-based index into the candidate list.")
    # Rationale code; validation rejects anything outside WhyLiteral.
    why: WhyLiteral = Field(..., description="Rationale code from the allowed set.")
# Pydantic model for the whole Stage 3 response; used as the output parser's target type.
class Stage3SelectionResponse(BaseModel):
    # Defaults to an empty list, so a response without selections still validates.
    selections: List[Stage3SelectionItem] = Field(default_factory=list)
def _build_response_format() -> Dict[str, Any]:
    """Return the strict JSON-Schema ``response_format`` payload for Stage 3.

    The schema describes an object with a single required ``selections``
    array; each element is an object with exactly an integer ``i`` and a
    ``why`` string restricted to ``WHY_ENUM``. Extra keys are disallowed at
    both levels.
    """
    selection_item = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
            "why": {"type": "string", "enum": WHY_ENUM},
        },
        "required": ["i", "why"],
        "additionalProperties": False,
    }
    root_schema = {
        "type": "object",
        "properties": {
            "selections": {"type": "array", "items": selection_item},
        },
        "required": ["selections"],
        "additionalProperties": False,
    }
    json_schema = {"name": "stage3_selection", "strict": True, "schema": root_schema}
    return {"type": "json_schema", "json_schema": json_schema}
def _get_llm(*, temperature: float, max_tokens: int, response_format: Dict[str, Any]) -> ChatOpenAI:
    """Build the chat client that talks to OpenRouter's OpenAI-compatible API.

    Configuration is read from the environment:
      - OPENROUTER_API_KEY (required)
      - OPENROUTER_MODEL (optional; defaults to a Llama 3.1 8B instruct model)
      - OPENROUTER_HTTP_REFERER / OPENROUTER_X_TITLE (optional request headers)

    Raises:
        RuntimeError: if OPENROUTER_API_KEY is unset or empty.
    """
    raw_key = os.getenv("OPENROUTER_API_KEY")
    if not raw_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set.\n"
            "Set it in your environment before running Stage 3."
        )

    # Only forward the optional headers that are actually configured.
    optional_headers = {
        "HTTP-Referer": os.getenv("OPENROUTER_HTTP_REFERER"),
        "X-Title": os.getenv("OPENROUTER_X_TITLE"),
    }
    headers: Dict[str, str] = {
        name: value for name, value in optional_headers.items() if value
    }

    # Provider-specific request body fields (OpenAI-compatible). The
    # response-healing plugin reduces malformed-JSON failures (syntax only).
    request_extras = {
        "response_format": response_format,
        "plugins": [{"id": "response-healing"}],
    }

    return ChatOpenAI(
        model=os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct"),
        base_url="https://openrouter.ai/api/v1",
        api_key=SecretStr(raw_key),
        temperature=temperature,
        max_completion_tokens=max_tokens,
        default_headers=headers,
        extra_body=request_extras,
    )
def _phrase_key_for_candidate(c: Candidate) -> str:
    """Return the candidate's deterministic grouping key.

    The key is the smallest source phrase in sort order, or an empty string
    when the candidate has no sources.
    """
    return min(c.sources) if c.sources else ""
def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
    """Round-robin interleave candidates by their primary source phrase.

    Candidates are grouped by ``_phrase_key_for_candidate``. Each group is
    sorted best-first by ``(score_combined, count)``, and the result takes the
    best member of every group (groups visited in sorted key order), then the
    second-best of every group, and so on.

    A missing count (``None``) sorts as ``-1``. A real count of ``0`` keeps its
    value, matching the final ordering in ``llm_select_indices``.

    NOTE: counts are used only for ordering; they are NOT shown to the LLM.
    """
    groups: Dict[str, List[Candidate]] = {}
    for cand in cands:
        groups.setdefault(_phrase_key_for_candidate(cand), []).append(cand)

    for members in groups.values():
        members.sort(
            key=lambda x: (x.score_combined, x.count if x.count is not None else -1),
            reverse=True,
        )

    ordered_groups = [groups[key] for key in sorted(groups)]
    depth = max((len(members) for members in ordered_groups), default=0)
    # rank-major traversal: one pick per group per round, skipping exhausted groups.
    return [
        members[rank]
        for rank in range(depth)
        for members in ordered_groups
        if rank < len(members)
    ]
def _display_tag(tag: str) -> str:
    """Return the tag as shown to the LLM: underscores become spaces.

    The canonical underscore form is kept everywhere else internally.
    """
    return " ".join(tag.split("_"))
def _format_candidates_local(
    cands: Sequence[Candidate],
) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]:
    """Number the candidates from 1 and render the listing shown to the LLM.

    Returns:
        A 3-tuple of the newline-joined ``"<index>. <display tag>"`` listing,
        a map from local index to canonical tag, and a map from local index to
        the candidate object.
    """
    numbered = list(enumerate(cands, start=1))
    idx_to_candidate: Dict[int, Candidate] = dict(numbered)
    idx_to_tag: Dict[int, str] = {pos: cand.tag for pos, cand in numbered}
    listing = "\n".join(f"{pos}. {_display_tag(cand.tag)}" for pos, cand in numbered)
    return listing, idx_to_tag, idx_to_candidate
def _phrases_in_call(cands: Sequence[Candidate]) -> int:
    """Count the distinct source phrases across all candidates in one call."""
    return len({src for cand in cands for src in cand.sources})
def _parse_validate_map(
    parsed: Any,
    idx_to_tag: Dict[int, str],
    per_call_budget: int,
) -> Tuple[List[Selected], Dict[str, Any]]:
    """Validate a parsed LLM response and map its indices back to tags.

    Args:
        parsed: the parser output; either a dict or a pydantic model (which is
            dumped to a dict first).
        idx_to_tag: local 1-based index -> canonical tag for this call.
        per_call_budget: maximum number of selections to keep; scanning stops
            as soon as this many have been accepted.

    Returns:
        ``(selected, diag)`` where ``diag`` counts ``invalid_items``,
        ``oob_indices``, ``dupe_indices`` and ``kept``, and records whether
        the payload had the expected shape in ``parse_ok``.
    """
    diag: Dict[str, Any] = {
        "parse_ok": isinstance(parsed, dict),
        "invalid_items": 0,
        "oob_indices": 0,
        "dupe_indices": 0,
        "kept": 0,
    }

    if isinstance(parsed, BaseModel):
        # Support both the newer ``model_dump`` and the older ``dict`` API.
        if hasattr(parsed, "model_dump"):
            parsed = parsed.model_dump()
        else:
            parsed = parsed.dict()
        diag["parse_ok"] = isinstance(parsed, dict)

    if not isinstance(parsed, dict):
        return [], diag

    raw_items = parsed.get("selections", [])
    if not isinstance(raw_items, list):
        diag["parse_ok"] = False
        return [], diag

    kept: List[Selected] = []
    used_indices = set()
    for entry in raw_items:
        if len(kept) >= per_call_budget:
            break

        # Classify the entry; `problem` names the diag counter to bump, if any.
        # The check order matters because it decides which counter a bad entry
        # lands in.
        index = reason = None
        if not isinstance(entry, dict):
            problem = "invalid_items"
        else:
            index = entry.get("i")
            reason = entry.get("why")
            if isinstance(index, bool) or not isinstance(index, int):
                # bool is an int subclass, so it is rejected explicitly.
                problem = "invalid_items"
            elif index in used_indices:
                problem = "dupe_indices"
            elif index not in idx_to_tag:
                problem = "oob_indices"
            elif not isinstance(reason, str) or reason not in WHY_ENUM:
                problem = "invalid_items"
            else:
                problem = None

        if problem is not None:
            diag[problem] += 1
            continue

        used_indices.add(index)
        kept.append(
            Selected(i=index, tag=idx_to_tag[index], why=reason, score=WHY_TO_SCORE[reason])
        )

    diag["kept"] = len(kept)
    return kept, diag
def _split_candidates_by_type(
    candidates: List[Candidate],
    log,
) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]:
    """Partition candidates into a general list and an entity (character) list.

    Routing by the type name reported for each tag:
      - "character": entity list, unless the tag is one of the generic
        character-category tags, which go to the general list.
      - "copyright": dropped (too broad for image generation).
      - "general", "artist", "species", "meta": general list.
      - anything else (including no type): general list, counted as unknown.

    Args:
        candidates: candidates to partition.
        log: optional callable taking one message string; falsy disables logging.

    Returns:
        ``(general_list, entity_list)``; every item is
        ``(original_index, candidate)`` in input order.
    """
    general: List[Tuple[int, Candidate]] = []
    entity: List[Tuple[int, Candidate]] = []
    n_unknown = 0
    n_copyright = 0
    n_generic_char = 0

    for position, cand in enumerate(candidates):
        kind = get_tag_type_name(cand.tag)

        if kind == "copyright":
            n_copyright += 1
            continue

        if kind == "character":
            if cand.tag not in _GENERIC_CHARACTER_TAGS:
                entity.append((position, cand))
                continue
            # Generic character-category tag: fall through to general.
            n_generic_char += 1
        elif kind not in ("general", "artist", "species", "meta"):
            # Unknown or missing type: treated as general by default.
            n_unknown += 1

        general.append((position, cand))

    if log:
        log(
            f"Stage3 split: "
            f"general={len(general)} "
            f"entity={len(entity)} "
            f"copyright_filtered={n_copyright} "
            f"generic_char_to_general={n_generic_char} "
            f"unknown_type={n_unknown}"
        )
    return general, entity
# Matches a trailing series/franchise suffix such as _(sonic), _(mlp) or
# _(character): an underscore followed by a parenthesised group at the end.
_SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$")


def _normalize_for_matching(text: str) -> str:
    """Normalise text for alias matching.

    Lowercases and trims the text, removes one trailing series suffix, then
    turns the remaining underscores into spaces.
    """
    lowered = text.lower().strip()
    without_suffix = _SERIES_SUFFIX_RE.sub("", lowered)
    return without_suffix.replace("_", " ")
def _query_words(query: str) -> Set[str]:
    """Return the set of whitespace-separated words of the normalised query."""
    normalized = _normalize_for_matching(query)
    return set(normalized.split())
def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str,
                         fuzzy_threshold: int = 85) -> bool:
    """Check whether a normalised alias matches the user query.

    An empty or whitespace-only alias never matches. This guard runs first
    because the empty string is a substring of every query, so checking the
    substring rule first would make such an alias match everything.

    Matching rules, tried in order:
      1. Exact substring: the alias appears as a substring of ``query_norm``.
      2. Word subset: every word of the alias is in ``query_words``.
      3. Fuzzy (handles typos): a single-word alias is compared with each
         query word via ``fuzz.ratio``; a multi-word alias is compared with
         the whole query via ``fuzz.partial_ratio``.

    Args:
        alias_norm: alias text, already normalised.
        query_words: individual words of the normalised query.
        query_norm: the whole normalised query.
        fuzzy_threshold: minimum fuzzy score (0-100) that counts as a match.
    """
    alias_words = alias_norm.split()
    if not alias_words:
        return False

    # NOTE(review): this is a raw substring test, not word-bounded, so a short
    # alias can match inside a longer query word — confirm that is intended.
    if alias_norm in query_norm:
        return True

    if all(word in query_words for word in alias_words):
        return True

    if len(alias_words) == 1:
        only_word = alias_words[0]
        return any(
            fuzz.ratio(only_word, query_word) >= fuzzy_threshold
            for query_word in query_words
        )

    return fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold
def _character_matches_via_aliases(
    tag: str,
    query: str,
    tag2aliases: Dict[str, List[str]],
    query_words: Set[str],
    query_norm: str,
    fuzzy_threshold: int = 85,
) -> bool:
    """Decide whether a character tag is mentioned in the user query.

    The tag matches when its own normalised name matches the query, or when at
    least one of its registered aliases does. A tag with no registered aliases
    is still checked by name. Aliases that normalise to an empty string are
    skipped.

    ``query`` is accepted but not read here; matching uses the pre-computed
    ``query_words`` and ``query_norm``.
    """
    own_name = _normalize_for_matching(tag)
    if _alias_matches_query(own_name, query_words, query_norm, fuzzy_threshold):
        return True

    # Lazy generator, so normalisation stops at the first matching alias.
    normalized_aliases = (
        _normalize_for_matching(alias) for alias in tag2aliases.get(tag, [])
    )
    return any(
        _alias_matches_query(alias_norm, query_words, query_norm, fuzzy_threshold)
        for alias_norm in normalized_aliases
        if alias_norm
    )
def llm_select_indices(
    query_text: str,  # kept for compatibility; treated as IMAGE DESCRIPTION
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
    max_pick: int,  # legacy param; applied after union + ordering (optional)
    log,
    retries: int = 2,
    *,
    mode: str = "chunked_map_union",  # "single_shot" or "chunked_map_union"
    chunk_size: int = 60,
    per_phrase_k: int = 2,  # per-call budget = per_phrase_k * phrases_in_call
    temperature: float = 0.0,
    max_tokens: int = 512,
    return_metadata: bool = False,
    min_why: Optional[str] = None,
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
    """Select relevant tags with the LLM; return indices into ``candidates``.

    This implementation uses LangChain ONLY. ``query_text`` is treated as the
    image description (original prompt).

    Args:
        query_text: the image description.
        candidates: ``Candidate`` objects (preferred), bare tag strings, or
            ``(tag, score)`` pairs. The input form is detected from the first
            element.
        max_pick: if a positive int, cap on the number of returned indices,
            applied after the union and ordering.
        log: optional callable taking one message string; falsy disables logging.
        retries: extra attempts per LLM call when an attempt raises or yields
            no valid selection.
        mode: ``"single_shot"`` (one call per pass) or ``"chunked_map_union"``
            (one call per ``chunk_size`` candidates, results unioned).
        chunk_size: candidates per call in chunked mode; must be positive.
        per_phrase_k: selections allowed per distinct source phrase in a call.
        temperature, max_tokens: forwarded to the LLM client.
        return_metadata: when true, also return a ``tag -> why`` mapping.
        min_why: if set, only keep tags whose 'why' is at or above this
            confidence level. E.g. ``"explicit"`` keeps only explicit matches;
            ``"strong_implied"`` keeps explicit + strong_implied. An
            unrecognised value filters nothing.

    Returns:
        Indices into the ORIGINAL ``candidates`` (first occurrence of each
        selected tag), ordered by derived score then count, both descending.
        With ``return_metadata`` a ``(indices, tag_to_why)`` tuple.

    Raises:
        ValueError: invalid ``mode``, non-positive ``chunk_size`` in chunked
            mode, or a candidate row that is not one of the supported forms.
        RuntimeError: raised by the client factory when the API key is missing
            (only reached when there is at least one candidate).
    """
    image_description = query_text

    # Normalize candidates:
    # - preferred: List[Candidate]
    # - legacy: List[str] or List[(tag, sim)] (count/sources unavailable)
    norm: List[Candidate] = []
    tag_to_first_index: Dict[str, int] = {}
    branch = "empty"
    cand0_type = type(candidates[0]).__name__ if candidates else "none"
    if candidates and isinstance(candidates[0], Candidate):
        branch = "candidate"
        for idx, c in enumerate(cast(Sequence[Candidate], candidates)):
            if c.tag not in tag_to_first_index:
                tag_to_first_index[c.tag] = idx
            norm.append(c)
    elif candidates and isinstance(candidates[0], str):
        branch = "string"
        for idx, tag in enumerate(cast(Sequence[str], candidates)):
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
            norm.append(
                Candidate(
                    tag=tag,
                    score_combined=0.0,
                    score_fasttext=None,
                    score_context=None,
                    count=None,
                    sources=[],
                )
            )
    elif candidates:
        branch = "tuple"
        for idx, row in enumerate(cast(Sequence[Tuple[str, float]], candidates)):
            if not isinstance(row, (list, tuple)) or len(row) < 2:
                raise ValueError("Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples.")
            tag, sim = row[0], row[1]
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
            norm.append(
                Candidate(
                    tag=tag,
                    score_combined=float(sim),
                    score_fasttext=None,
                    score_context=None,
                    count=None,
                    sources=[],
                )
            )

    if log:
        if norm:
            log(
                "Stage3 input: "
                f"type0={cand0_type} "
                f"branch={branch} "
                f"norm0_score={norm[0].score_combined!r} "
                f"norm0_sources_empty={not bool(norm[0].sources)}"
            )
        else:
            log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)")

    if mode not in ("single_shot", "chunked_map_union"):
        raise ValueError(f"Invalid mode: {mode}")
    # Reject a bad chunk size up front: zero would otherwise surface as an
    # opaque range() error, and a negative value would silently run no calls.
    if mode == "chunked_map_union" and chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    # Nothing to select from: skip client construction (and the API-key
    # requirement) entirely.
    if not norm:
        return ([], {}) if return_metadata else []

    response_format = _build_response_format()
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens, response_format=response_format)
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
    parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse)

    # Global union: tag -> best (score, why)
    best: Dict[str, Tuple[float, str]] = {}
    # One chain per system template, built on first use and reused across chunks.
    chains: Dict[str, Any] = {}

    def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
        """Run one LLM call (with retries) and merge its selections into ``best``."""
        chain = chains.get(system_template)
        if chain is None:
            prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system_template),
                    ("human", USER_TEMPLATE),
                ],
                template_format="f-string",
            )
            chain = chains[system_template] = prompt | llm | parser

        ordered = _interleave_round_robin(call_cands)
        candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
        N_local = len(idx_to_tag)
        phrases = _phrases_in_call(call_cands)
        per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
        summary_logged = False

        if log:
            log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}")
            if phrases > 0:
                distinct_phrases = sorted({src for c in call_cands for src in c.sources})
                log(
                    f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} "
                    f"phrases={', '.join(distinct_phrases)}"
                )

        # Invoke the LangChain chain (templating fills {N} and the other vars).
        for att in range(retries + 1):
            # Deliberately broad: any failure of a single attempt (network,
            # parsing, validation) is logged and retried, never propagated.
            try:
                if log:
                    log(
                        f"Stage3 {label}: "
                        f"model={model_name} "
                        f"N={N_local} "
                        f"phrases={phrases} "
                        f"per_call_budget={per_call_budget} "
                        f"response_healing=on"
                    )
                parsed = chain.invoke(
                    {
                        "N": N_local,
                        "image_description": image_description,
                        "candidate_lines": candidate_lines,
                        "per_call_budget": per_call_budget,
                    }
                )
                selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
                if log:
                    log(f"Stage3 {label}: attempt {att+1} diag={diag}")
                    # Summarise once: on the first attempt that selected
                    # something, or on the last attempt otherwise.
                    if not summary_logged and (selected or att == retries):
                        log(
                            f"Stage3 {label}: summary "
                            f"N={N_local} selected={len(selected)} per_call_budget={per_call_budget}"
                        )
                        summary_logged = True
                        if selected:
                            lines = [f"Stage3 {label} selections:"]
                            for s in selected:
                                picked = idx_to_candidate.get(s.i)
                                lines.append(
                                    f' - i={s.i} tag="{s.tag}" '
                                    f"why={s.why} score={s.score:.2f} "
                                    f"sources={picked.sources if picked else []}"
                                )
                            log("\n".join(lines))
                        else:
                            log(f"Stage3 {label} selections: (none)")
                if selected:
                    for s in selected:
                        prev = best.get(s.tag)
                        if prev is None or s.score > prev[0]:
                            best[s.tag] = (s.score, s.why)
                    return
                # An empty selection falls through and is retried.
            except Exception as e:
                if log:
                    log(f"Stage3 {label}: attempt {att+1} error: {e}")
        if log:
            log(f"Stage3 {label}: gave up after {retries+1} attempts")

    def run_pass(pass_cands: Sequence[Candidate], pass_name: str, system_template: str) -> None:
        """Run one pass over ``pass_cands``: a single call or one call per chunk."""
        if mode == "single_shot":
            run_call(pass_cands, f"{pass_name}_single_shot", system_template)
            return
        for start in range(0, len(pass_cands), chunk_size):
            run_call(
                pass_cands[start:start + chunk_size],
                f"{pass_name}_chunk_{start//chunk_size}",
                system_template,
            )

    # Split candidates by type (general vs entity); only the candidates
    # themselves are needed for the LLM calls.
    general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)
    general_cands = [cand for _, cand in general_with_idx]
    entity_cands = [cand for _, cand in entity_with_idx]

    # General pass (attributes, actions, species, etc.).
    if general_cands:
        run_pass(general_cands, "general", SELECT_SYSTEM_TEMPLATE)

    # Entity pass (characters only) with alias-based pre-filtering.
    if entity_cands:
        tag2aliases = get_tag2aliases()
        qwords = _query_words(image_description)
        qnorm = _normalize_for_matching(image_description)
        filtered_entity_cands: List[Candidate] = []
        filtered_out: List[str] = []
        for cand in entity_cands:
            if _character_matches_via_aliases(
                cand.tag, image_description, tag2aliases, qwords, qnorm
            ):
                filtered_entity_cands.append(cand)
            else:
                filtered_out.append(cand.tag)
        if log:
            log(
                f"Stage3 entity alias filter: "
                f"before={len(entity_cands)} "
                f"after={len(filtered_entity_cands)} "
                f"removed={len(filtered_out)}"
            )
            if filtered_out:
                log(f"Stage3 entity alias filter removed: {filtered_out[:20]}")
        if filtered_entity_cands:
            run_pass(filtered_entity_cands, "entity", ENTITY_SYSTEM_TEMPLATE)

    # Apply why threshold: drop tags below the minimum confidence level.
    if min_why is not None:
        max_rank = WHY_RANK.get(min_why, 4)
        before = len(best)
        best = {t: v for t, v in best.items() if WHY_RANK.get(v[1], 4) <= max_rank}
        if log:
            log(f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
                f"before={before} after={len(best)} dropped={before - len(best)}")

    # Deterministic ordering: derived score desc, tie-break by count desc
    # (count is not shown to the LLM).
    count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
    ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)

    # Legacy cap: apply AFTER union + ordering.
    if isinstance(max_pick, int) and max_pick > 0:
        ordered_tags = ordered_tags[:max_pick]

    # Map back to original indices.
    out_idx: List[int] = []
    tag_why: Dict[str, str] = {}
    for t in ordered_tags:
        if t in tag_to_first_index:
            out_idx.append(tag_to_first_index[t])
            tag_why[t] = best[t][1]  # why string

    if return_metadata:
        return out_idx, tag_why
    return out_idx