Spaces:
Running
Running
| # psq_rag/llm/select.py | |
| # Stage 3: Closed-Set Selection (LangChain-only implementation) | |
| # | |
| # This module intentionally uses LangChain for: | |
| # - prompt templating (including {N}) | |
| # - LLM call orchestration | |
| # - JSON parsing | |
| # | |
| # There is NO fallback path. If LangChain dependencies are missing, this module | |
| # should fail loudly so you install them. | |
| import os | |
| import re | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import PydanticOutputParser | |
| from pydantic import BaseModel, Field, SecretStr | |
| from rapidfuzz import fuzz | |
| from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources) | |
| from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases | |
# Character-typed tags that are generic categories, not actual named characters.
# These leak through the alias filter because they match common words in captions.
# They are excluded from the entity pipeline and instead routed to general selection.
# Consulted by _split_candidates_by_type via a membership test on the candidate tag.
_GENERIC_CHARACTER_TAGS = frozenset({
    "fan_character",
    "background_character",
    "unnamed_character",
    "unknown_character",
    "anonymous_character",
    "viewer",
    "original_character",
})
# Closed set of rationale codes the LLM may return. Used both as the JSON-schema
# enum for structured output and as the validation set when parsing responses.
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
# Ordinal rank: lower = more confident. Used for threshold filtering.
WHY_RANK: Dict[str, int] = {
    "explicit": 0,
    "strong_implied": 1,
    "weak_implied": 2,
    "style_or_meta": 3,
    "other": 4,
}
# Deterministic mapping: ordinal "why" -> numeric score for ordering/debug.
# Scores decrease in the same order as WHY_RANK increases.
WHY_TO_SCORE: Dict[str, float] = {
    "explicit": 0.90,
    "strong_implied": 0.70,
    "weak_implied": 0.45,
    "style_or_meta": 0.35,
    "other": 0.25,
}
# System prompt for the general (non-character) selection calls.
# IMPORTANT ABOUT TEMPLATING:
# - This string is rendered by LangChain's f-string template engine.
# - Literal JSON braces must be escaped as {{ and }}.
# - {N} is a real template variable and MUST be provided.
SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags.
Select the tags that correspond to content that would be visible or depicted in the described image.
The list contains only valid tags; many of them are irrelevant to the image.
Return JSON ONLY matching this schema:
{{
\"selections\": [
{{\"i\": <int>, \"why\": \"<one of: explicit|strong_implied|weak_implied|style_or_meta|other>\"}},
...
]
}}
Rules:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Do NOT output any keys other than \"selections\", and inside each item only the item index \"i\" and \"why\".
- Do select both a general tag and a more specific tag when both apply (for example, \"shirt\" and \"grey shirt\").
Define \"why\" as:
- explicit: directly stated in the image description
- strong_implied: very likely given the description, even if not literally stated
- weak_implied: plausible but not strongly supported by the description
- style_or_meta: stylistic or presentation-related tags only if clearly indicated
- other: fallback category; use sparingly
"""
# System prompt for the character (entity) selection calls. Like the general
# template it is rendered as an f-string template: {{ }} are literal braces and
# {N} is a required variable. It instructs the model to always answer with
# "why": "explicit".
ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags.
These character tags have already been pre-filtered to only include characters whose names
(or known aliases) appear in the image description. Your job is to confirm which of these
pre-filtered candidates are the correct match for the character mentioned by the user.
Return JSON ONLY matching this schema:
{{
\"selections\": [
{{\"i\": <int>, \"why\": \"explicit\"}},
...
]
}}
Rules for character selection:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Always use \"why\": \"explicit\" for all selections.
- Select the tag that best represents the character as described.
- If the user described a specific variant (e.g. \"pikachu libre\", \"detective pikachu\"),
select that specific variant tag.
- If the user described only the base character (e.g. just \"pikachu\"), select only the
base/default tag, NOT costume or variant tags.
- When uncertain between variants, prefer the simplest/most general tag.
"""
# Human-message template shared by the general and entity selection calls.
# Template variables: {image_description}, {candidate_lines}, {per_call_budget}.
USER_TEMPLATE = """IMAGE DESCRIPTION:
{image_description}
CANDIDATES (choose by index only):
{candidate_lines}
Select up to {per_call_budget} indices. Output fewer if uncertain.
"""
@dataclass
class Selected:
    """One validated LLM selection, mapped back to its canonical tag.

    The ``@dataclass`` decorator is required: the parser constructs instances
    with keyword arguments (``Selected(i=..., tag=..., why=..., score=...)``),
    which a plain annotated class does not accept.
    """

    i: int  # 1-based index into the candidate list shown to the LLM
    tag: str  # canonical tag (underscore form)
    why: str  # rationale code returned by the LLM
    score: float  # numeric score derived from the rationale code
# Rationale codes as a Literal type; lists the same five values as WHY_ENUM.
WhyLiteral = Literal["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]


# Pydantic model for a single item of the LLM's "selections" array:
# a candidate index plus a rationale code restricted to WhyLiteral.
class Stage3SelectionItem(BaseModel):
    i: int = Field(..., description="1-based index into the candidate list.")
    why: WhyLiteral = Field(..., description="Rationale code from the allowed set.")
# Pydantic model for the whole LLM response; used as the target of
# PydanticOutputParser. A missing "selections" key defaults to an empty list.
class Stage3SelectionResponse(BaseModel):
    selections: List[Stage3SelectionItem] = Field(default_factory=list)
def _build_response_format() -> Dict[str, Any]:
    """Build the ``response_format`` payload for Stage 3 selection calls.

    Returns:
        A dict of type ``"json_schema"`` wrapping a strict JSON Schema named
        ``"stage3_selection"``. The schema requires a top-level ``selections``
        array whose items carry exactly an integer ``i`` and a ``why`` string
        restricted to ``WHY_ENUM``.
    """
    # Schema of one element of the "selections" array.
    selection_item = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
            "why": {"type": "string", "enum": WHY_ENUM},
        },
        "required": ["i", "why"],
        "additionalProperties": False,
    }
    # Top-level object: nothing but the "selections" array is allowed.
    root = {
        "type": "object",
        "properties": {
            "selections": {"type": "array", "items": selection_item},
        },
        "required": ["selections"],
        "additionalProperties": False,
    }
    json_schema = {"name": "stage3_selection", "strict": True, "schema": root}
    return {"type": "json_schema", "json_schema": json_schema}
def _get_llm(*, temperature: float, max_tokens: int, response_format: Dict[str, Any]) -> ChatOpenAI:
    """Create a ``ChatOpenAI`` client pointed at the OpenRouter endpoint.

    Configuration is read from the environment:
      - ``OPENROUTER_API_KEY`` (required)
      - ``OPENROUTER_MODEL`` (optional, defaults to a Llama 3.1 8B instruct model)
      - ``OPENROUTER_HTTP_REFERER`` / ``OPENROUTER_X_TITLE`` (optional headers)

    Args:
        temperature: Sampling temperature passed to the client.
        max_tokens: Passed to the client as ``max_completion_tokens``.
        response_format: Structured-output spec sent in the request body.

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is unset or empty.
    """
    raw_key = os.getenv("OPENROUTER_API_KEY")
    if not raw_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set.\n"
            "Set it in your environment before running Stage 3."
        )

    # Optional headers: each is only sent when its environment variable is non-empty.
    optional_headers = (
        ("HTTP-Referer", "OPENROUTER_HTTP_REFERER"),
        ("X-Title", "OPENROUTER_X_TITLE"),
    )
    headers: Dict[str, str] = {}
    for header_name, env_name in optional_headers:
        value = os.getenv(env_name)
        if value:
            headers[header_name] = value

    # Provider-specific request body fields (OpenAI-compatible).
    # Response Healing plugin reduces malformed-JSON failures (syntax only).
    request_extras = {
        "response_format": response_format,
        "plugins": [{"id": "response-healing"}],
    }

    # OpenRouter OpenAI-compatible endpoint.
    return ChatOpenAI(
        model=os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct"),
        base_url="https://openrouter.ai/api/v1",
        api_key=SecretStr(cast(str, raw_key)),
        temperature=temperature,
        max_completion_tokens=max_tokens,
        default_headers=headers,
        extra_body=request_extras,
    )
def _phrase_key_for_candidate(c: Candidate) -> str:
    """Return the candidate's grouping key: its smallest source phrase, or ""."""
    # min() picks the same element that sorting and taking the first would,
    # so the key is deterministic regardless of the order of c.sources.
    return min(c.sources) if c.sources else ""
def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
    """Round-robin interleave candidates by their primary source phrase.

    Candidates are bucketed by ``_phrase_key_for_candidate``. Each bucket is
    sorted by ``(score_combined, count)`` descending, and the buckets are then
    visited in sorted-key order, taking one candidate per bucket per round.

    NOTE: counts are used only for ordering; they are NOT shown to the LLM.

    Args:
        cands: Candidates to reorder.

    Returns:
        A new list containing every input candidate exactly once.
    """
    groups: Dict[str, List[Candidate]] = {}
    for cand in cands:
        groups.setdefault(_phrase_key_for_candidate(cand), []).append(cand)

    def _rank(c: Candidate) -> Tuple[Any, int]:
        # Use an explicit None check rather than `count or -1`: a real count of
        # 0 is falsy and would otherwise be treated the same as a missing count.
        return (c.score_combined, c.count if c.count is not None else -1)

    for bucket in groups.values():
        bucket.sort(key=_rank, reverse=True)

    buckets = [groups[key] for key in sorted(groups)]
    rounds = max((len(bucket) for bucket in buckets), default=0)

    out: List[Candidate] = []
    for level in range(rounds):
        for bucket in buckets:
            if level < len(bucket):
                out.append(bucket[level])
    return out
| def _display_tag(tag: str) -> str: | |
| # Display tags with spaces for the LLM, but keep canonical underscores internally. | |
| return tag.replace("_", " ") | |
def _format_candidates_local(
    cands: Sequence[Candidate],
) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]:
    """Number candidates from 1 and render them as prompt lines.

    Returns:
        ``(text, idx_to_tag, idx_to_candidate)`` where ``text`` holds one
        ``"<n>. <tag with spaces>"`` line per candidate, and both dicts are
        keyed by the same 1-based number used in ``text``.
    """
    numbered = list(enumerate(cands, start=1))
    idx_to_candidate: Dict[int, Candidate] = dict(numbered)
    idx_to_tag: Dict[int, str] = {n: cand.tag for n, cand in numbered}
    text = "\n".join(f"{n}. {_display_tag(cand.tag)}" for n, cand in numbered)
    return text, idx_to_tag, idx_to_candidate
def _phrases_in_call(cands: Sequence[Candidate]) -> int:
    """Count the distinct source phrases across all given candidates."""
    distinct = {phrase for cand in cands for phrase in cand.sources}
    return len(distinct)
def _parse_validate_map(
    parsed: Any,
    idx_to_tag: Dict[int, str],
    per_call_budget: int,
) -> Tuple[List[Selected], Dict[str, Any]]:
    """Validate a parsed LLM response and map accepted indices to tags.

    Args:
        parsed: Parser output; either a pydantic model or a plain dict.
        idx_to_tag: 1-based candidate index -> canonical tag for this call.
        per_call_budget: Maximum number of selections to accept.

    Returns:
        ``(selected, diag)``. ``selected`` holds the accepted items in response
        order. ``diag`` counts what happened: ``parse_ok``, ``invalid_items``,
        ``oob_indices``, ``dupe_indices`` and ``kept``.
    """
    # Flatten pydantic models to plain dicts (model_dump if present, else dict()).
    if isinstance(parsed, BaseModel):
        parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()

    diag: Dict[str, Any] = {
        "parse_ok": isinstance(parsed, dict),
        "invalid_items": 0,
        "oob_indices": 0,
        "dupe_indices": 0,
        "kept": 0,
    }
    if not diag["parse_ok"]:
        return [], diag

    raw_items = parsed.get("selections", [])
    if not isinstance(raw_items, list):
        diag["parse_ok"] = False
        return [], diag

    accepted: List[Selected] = []
    used: Set[int] = set()
    for entry in raw_items:
        # Stop as soon as the budget is filled; later items are not inspected.
        if len(accepted) >= per_call_budget:
            break
        if not isinstance(entry, dict):
            diag["invalid_items"] += 1
            continue

        index = entry.get("i")
        reason = entry.get("why")
        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(index, bool) or not isinstance(index, int):
            diag["invalid_items"] += 1
        elif index in used:
            diag["dupe_indices"] += 1
        elif index not in idx_to_tag:
            diag["oob_indices"] += 1
        elif not isinstance(reason, str) or reason not in WHY_ENUM:
            diag["invalid_items"] += 1
        else:
            used.add(index)
            accepted.append(
                Selected(i=index, tag=idx_to_tag[index], why=reason, score=WHY_TO_SCORE[reason])
            )

    diag["kept"] = len(accepted)
    return accepted, diag
def _split_candidates_by_type(
    candidates: List[Candidate],
    log,
) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]:
    """Partition candidates into general and entity (character) lists.

    Routing, based on ``get_tag_type_name(cand.tag)``:
      - "copyright": dropped (too broad for image generation).
      - "character": entity list, unless the tag is in
        ``_GENERIC_CHARACTER_TAGS``, in which case it goes to the general list.
      - anything else (including unknown or None): general list.

    Args:
        candidates: Candidates to split.
        log: Optional callable taking one message string; skipped when falsy.

    Returns:
        ``(general_list, entity_list)``; each element is
        ``(original_index, candidate)``.
    """
    known_general_types = ("general", "artist", "species", "meta")
    general_with_idx: List[Tuple[int, Candidate]] = []
    entity_with_idx: List[Tuple[int, Candidate]] = []
    dropped_copyright = 0
    rerouted_generic = 0
    untyped = 0

    for position, cand in enumerate(candidates):
        kind = get_tag_type_name(cand.tag)

        if kind == "copyright":
            # Filter out copyright/series tags - too broad for image generation
            dropped_copyright += 1
            continue

        if kind == "character" and cand.tag not in _GENERIC_CHARACTER_TAGS:
            entity_with_idx.append((position, cand))
            continue

        # Everything left goes to general selection; the counters only record why.
        general_with_idx.append((position, cand))
        if kind == "character":
            rerouted_generic += 1
        elif kind not in known_general_types:
            untyped += 1

    if log:
        log(
            f"Stage3 split: "
            f"general={len(general_with_idx)} "
            f"entity={len(entity_with_idx)} "
            f"copyright_filtered={dropped_copyright} "
            f"generic_char_to_general={rerouted_generic} "
            f"unknown_type={untyped}"
        )
    return general_with_idx, entity_with_idx
| # Regex to strip series/franchise suffixes from aliases, e.g. _(sonic), _(mlp), _(character) | |
| _SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$") | |
| def _normalize_for_matching(text: str) -> str: | |
| """Lowercase, replace underscores with spaces, strip series suffixes.""" | |
| text = text.lower().strip() | |
| text = _SERIES_SUFFIX_RE.sub("", text) | |
| text = text.replace("_", " ") | |
| return text | |
def _query_words(query: str) -> Set[str]:
    """Return the set of whitespace-separated words of the normalized query."""
    normalized = _normalize_for_matching(query)
    tokens = normalized.split()
    return set(tokens)
| def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str, | |
| fuzzy_threshold: int = 85) -> bool: | |
| """Check if an alias matches the user query. | |
| Matching logic: | |
| 1. Exact substring: alias appears as a substring of the query | |
| 2. Word subset: all words in the alias appear in the query words | |
| 3. Fuzzy: alias is close to a word in the query (handles typos) | |
| """ | |
| # Exact substring match | |
| if alias_norm in query_norm: | |
| return True | |
| alias_words = alias_norm.split() | |
| if not alias_words: | |
| return False | |
| # Word subset match: all alias words must appear in query | |
| if all(w in query_words for w in alias_words): | |
| return True | |
| # For single-word aliases, try fuzzy matching against each query word | |
| if len(alias_words) == 1: | |
| for qw in query_words: | |
| if fuzz.ratio(alias_words[0], qw) >= fuzzy_threshold: | |
| return True | |
| # For multi-word aliases, try fuzzy partial ratio against whole query | |
| if len(alias_words) > 1: | |
| if fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold: | |
| return True | |
| return False | |
def _character_matches_via_aliases(
    tag: str,
    query: str,
    tag2aliases: Dict[str, List[str]],
    query_words: Set[str],
    query_norm: str,
    fuzzy_threshold: int = 85,
) -> bool:
    """Decide whether a character tag is referenced by the user query.

    The tag matches when its own normalized name matches the query, or when
    any of its registered aliases does. A tag with no alias entry is still
    checked by name.

    Args:
        tag: Canonical character tag.
        query: Raw query text. Not read by this function; matching uses the
            pre-normalized ``query_words`` / ``query_norm`` instead.
        tag2aliases: Mapping from tag to its registered aliases.
        query_words: Words of the normalized query.
        query_norm: The normalized query as one string.
        fuzzy_threshold: Passed through to ``_alias_matches_query``.
    """
    # The tag's own name is tried first.
    own_name = _normalize_for_matching(tag)
    if _alias_matches_query(own_name, query_words, query_norm, fuzzy_threshold):
        return True

    # Then every registered alias; aliases that normalize to "" are skipped.
    normalized_aliases = map(_normalize_for_matching, tag2aliases.get(tag, []))
    return any(
        alias_norm and _alias_matches_query(alias_norm, query_words, query_norm, fuzzy_threshold)
        for alias_norm in normalized_aliases
    )
def llm_select_indices(
    query_text: str,  # kept for compatibility; treated as IMAGE DESCRIPTION
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
    max_pick: int,  # legacy param; applied after union + ordering (optional)
    log,
    retries: int = 2,
    *,
    mode: str = "chunked_map_union",  # "single_shot" or "chunked_map_union"
    chunk_size: int = 60,
    per_phrase_k: int = 2,  # per-call budget = per_phrase_k * phrases_in_call
    temperature: float = 0.0,
    max_tokens: int = 512,
    return_metadata: bool = False,
    min_why: Optional[str] = "strong_implied",
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
    """Return indices into the ORIGINAL candidates list (legacy interface).

    Pipeline:
      1. Normalize ``candidates`` (Candidate objects, tag strings, or
         ``(tag, score)`` rows) into Candidate objects, remembering the first
         original index at which each tag appears.
      2. Split the candidates into general and character (entity) lists.
      3. Run the LLM over the general list (one call, or one call per chunk of
         ``chunk_size``), then over the character list after alias-based
         pre-filtering against the image description.
      4. Union the results per tag, keeping the highest score seen.
      5. Apply the ``min_why`` threshold, order deterministically, apply the
         ``max_pick`` cap, and map tags back to original indices.

    Args:
        query_text: Treated as the image description (original prompt).
        candidates: Candidate pool; the element type is detected from the
            first element.
        max_pick: If a positive int, keep at most this many tags after ordering.
        log: Optional callable taking one message string; skipped when falsy.
        retries: Extra attempts per LLM call (``retries + 1`` attempts total).
            An attempt that yields no valid selection does not end the loop.
        mode: "single_shot" or "chunked_map_union".
        chunk_size: Candidates per call in chunked mode.
        per_phrase_k: Per-call selection budget per distinct source phrase.
        temperature: Passed to the LLM client.
        max_tokens: Passed to the LLM client.
        return_metadata: If True, also return a tag -> why mapping.
        min_why: if set, only keep tags whose 'why' is at or above this
            confidence level. E.g. min_why="explicit" keeps only explicit
            matches; min_why="strong_implied" keeps explicit + strong_implied.
            Default: "strong_implied". A value not present in WHY_RANK is
            treated as rank 4, which keeps every tag.

    Returns:
        The list of original indices, or ``(indices, tag_why)`` when
        ``return_metadata`` is True.

    Raises:
        ValueError: If ``mode`` is invalid, or a legacy row is not a
            list/tuple of length >= 2.
        RuntimeError: Raised by ``_get_llm`` when OPENROUTER_API_KEY is unset.

    This implementation uses LangChain ONLY.
    """
    image_description = query_text
    # Normalize candidates:
    # - preferred: List[Candidate]
    # - legacy: List[(tag, sim)] (count/sources unavailable)
    norm: List[Candidate] = []
    tag_to_first_index: Dict[str, int] = {}
    branch = "empty"
    cand0_type = type(candidates[0]).__name__ if candidates else "none"
    if candidates and isinstance(candidates[0], Candidate):
        branch = "candidate"
        typed_candidates = cast(Sequence[Candidate], candidates)
        for idx, c in enumerate(typed_candidates):
            if c.tag not in tag_to_first_index:
                tag_to_first_index[c.tag] = idx
            norm.append(c)
    elif candidates and isinstance(candidates[0], str):
        branch = "string"
        typed_candidates = cast(Sequence[str], candidates)
        for idx, tag in enumerate(typed_candidates):
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
            # Bare tags carry no retrieval information, so use neutral placeholders.
            norm.append(
                Candidate(
                    tag=tag,
                    score_combined=0.0,
                    score_fasttext=None,
                    score_context=None,
                    count=None,
                    sources=[],
                )
            )
    else:
        if candidates:
            branch = "tuple"
            typed_candidates = cast(Sequence[Tuple[str, float]], candidates)
            for idx, row in enumerate(typed_candidates):
                if not isinstance(row, (list, tuple)) or len(row) < 2:
                    raise ValueError("Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples.")
                tag, sim = row[0], row[1]
                if tag not in tag_to_first_index:
                    tag_to_first_index[tag] = idx
                norm.append(
                    Candidate(
                        tag=tag,
                        score_combined=float(sim),
                        score_fasttext=None,
                        score_context=None,
                        count=None,
                        sources=[],
                    )
                )
    if log:
        if norm:
            log(
                "Stage3 input: "
                f"type0={cand0_type} "
                f"branch={branch} "
                f"norm0_score={norm[0].score_combined!r} "
                f"norm0_sources_empty={not bool(norm[0].sources)}"
            )
        else:
            log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)")
    if mode not in ("single_shot", "chunked_map_union"):
        raise ValueError(f"Invalid mode: {mode}")
    response_format = _build_response_format()
    # NOTE: the client is created even when there are no candidates, so a
    # missing API key raises here regardless of input size.
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens, response_format=response_format)
    # Read again only so the model name can be included in log lines.
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
    parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse)
    # Global union: tag -> best (score, why)
    best: Dict[str, Tuple[float, str]] = {}

    def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
        """Run one LLM selection call (with retries) and merge its picks into `best`."""
        # Create chain with the provided system template
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", USER_TEMPLATE),
            ],
            template_format="f-string",
        )
        chain = prompt | llm | parser
        ordered = _interleave_round_robin(call_cands)
        candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
        N_local = len(idx_to_tag)
        phrases = _phrases_in_call(call_cands)
        # Budget scales with the number of distinct source phrases; when no
        # candidate has sources, it falls back to per_phrase_k.
        per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
        summary_logged = False
        if log:
            log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}")
            if phrases > 0:
                distinct_phrases = sorted({src for c in call_cands for src in c.sources})
                log(
                    f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} "
                    f"phrases={', '.join(distinct_phrases)}"
                )
        # Invoke LangChain chain (templating fills {N} and other vars)
        for att in range(retries + 1):
            try:
                if log:
                    log(
                        f"Stage3 {label}: "
                        f"model={model_name} "
                        f"N={N_local} "
                        f"phrases={phrases} "
                        f"per_call_budget={per_call_budget} "
                        f"response_healing=on"
                    )
                parsed = chain.invoke(
                    {
                        "N": N_local,
                        "image_description": image_description,
                        "candidate_lines": candidate_lines,
                        "per_call_budget": per_call_budget,
                    }
                )
                selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
                if log:
                    log(f"Stage3 {label}: attempt {att+1} diag={diag}")
                    # Summary is emitted once: on the first non-empty result or on the last attempt.
                    if not summary_logged and (selected or att == retries):
                        log(
                            f"Stage3 {label}: summary "
                            f"N={N_local} selected={len(selected)} per_call_budget={per_call_budget}"
                        )
                        summary_logged = True
                        if selected:
                            lines = [
                                f"Stage3 {label} selections:",
                                *[
                                    (
                                        f' - i={s.i} tag="{s.tag}" '
                                        f"why={s.why} score={s.score:.2f} "
                                        f"sources={idx_to_candidate.get(s.i).sources if idx_to_candidate.get(s.i) else []}"
                                    )
                                    for s in selected
                                ],
                            ]
                            log("\n".join(lines))
                        else:
                            log(f"Stage3 {label} selections: (none)")
                if selected:
                    # Keep the highest score seen for each tag across all calls.
                    for s in selected:
                        prev = best.get(s.tag)
                        if prev is None or s.score > prev[0]:
                            best[s.tag] = (s.score, s.why)
                    return
                # An empty selection falls through to the next attempt.
            except Exception as e:
                # Any failure in the call or parsing is logged and retried.
                if log:
                    log(f"Stage3 {label}: attempt {att+1} error: {e}")
        if log:
            log(f"Stage3 {label}: gave up after {retries+1} attempts")

    # Split candidates by type (general vs entity)
    general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)
    # Extract just the candidates for LLM calls
    general_cands = [cand for _, cand in general_with_idx]
    entity_cands = [cand for _, cand in entity_with_idx]
    # Process general candidates (attributes, actions, species, etc.)
    if general_cands:
        if mode == "single_shot":
            run_call(general_cands, "general_single_shot", SELECT_SYSTEM_TEMPLATE)
        else:
            for start in range(0, len(general_cands), chunk_size):
                run_call(
                    general_cands[start:start + chunk_size],
                    f"general_chunk_{start//chunk_size}",
                    SELECT_SYSTEM_TEMPLATE
                )
    # Process entity candidates (characters only) with alias-based pre-filtering
    if entity_cands:
        tag2aliases = get_tag2aliases()
        # Normalize the description once and reuse it for every candidate.
        qwords = _query_words(image_description)
        qnorm = _normalize_for_matching(image_description)
        filtered_entity_cands: List[Candidate] = []
        filtered_out: List[str] = []
        for cand in entity_cands:
            if _character_matches_via_aliases(
                cand.tag, image_description, tag2aliases, qwords, qnorm
            ):
                filtered_entity_cands.append(cand)
            else:
                filtered_out.append(cand.tag)
        if log:
            log(
                f"Stage3 entity alias filter: "
                f"before={len(entity_cands)} "
                f"after={len(filtered_entity_cands)} "
                f"removed={len(filtered_out)}"
            )
            if filtered_out:
                # Only the first 20 removed tags are logged.
                log(f"Stage3 entity alias filter removed: {filtered_out[:20]}")
        if filtered_entity_cands:
            if mode == "single_shot":
                run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE)
            else:
                for start in range(0, len(filtered_entity_cands), chunk_size):
                    run_call(
                        filtered_entity_cands[start:start + chunk_size],
                        f"entity_chunk_{start//chunk_size}",
                        ENTITY_SYSTEM_TEMPLATE
                    )
    # Apply why threshold: drop tags below the minimum confidence level.
    if min_why is not None:
        # Unknown min_why values map to rank 4, i.e. nothing is dropped.
        max_rank = WHY_RANK.get(min_why, 4)
        before = len(best)
        best = {t: v for t, v in best.items() if WHY_RANK.get(v[1], 4) <= max_rank}
        if log:
            log(f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
                f"before={before} after={len(best)} dropped={before - len(best)}")
    # Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
    count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
    ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)
    # Legacy cap: apply AFTER union + ordering.
    if isinstance(max_pick, int) and max_pick > 0:
        ordered_tags = ordered_tags[:max_pick]
    # Map back to original indices
    out_idx: List[int] = []
    tag_why: Dict[str, str] = {}
    for t in ordered_tags:
        if t in tag_to_first_index:
            out_idx.append(tag_to_first_index[t])
            tag_why[t] = best[t][1]  # why string
    if return_metadata:
        return out_idx, tag_why
    return out_idx
| # --------------------------------------------------------------------------- | |
| # Stage 3s: Structural tag inference (solo/duo/male/female/anthro/… ) | |
| # --------------------------------------------------------------------------- | |
| # Group-based approach: tags are organized into semantic groups loaded from | |
| # tag_groups.json / tag_wiki_defs.json where possible, with curated fallback | |
| # definitions for tags whose wiki entries are only thumbnail references. | |
| # | |
| # Each group specifies a constraint mode: | |
| # "exclusive" = pick exactly one (e.g. character count) | |
| # "multi" = pick all that apply (e.g. body type, gender) | |
| import json as _json | |
@dataclass
class StructuralGroup:
    """One category of structural tags to probe.

    The ``@dataclass`` decorator is required: groups are constructed with
    keyword arguments (``StructuralGroup(name=..., constraint=..., tags=...)``),
    which a plain annotated class does not accept.
    """

    name: str  # group identifier, e.g. "character_count"
    constraint: str  # "exclusive" or "multi"
    tags: List[Tuple[str, str]]  # (tag, definition) pairs
def _load_structural_groups() -> List[StructuralGroup]:
    """Build the list of structural tag groups.

    Group membership and fallback definitions are curated inline. For each
    tag, a definition from ``data/tag_wiki_defs.json`` (located two directory
    levels above this file) is preferred when the file exists and the entry
    is usable text; otherwise the curated fallback is used.

    Returns:
        Five groups in fixed order: character_count (exclusive), then
        body_type, gender, clothing_state and visual_elements (all multi).
    """
    data_dir = Path(__file__).resolve().parents[2] / "data"

    # Wiki definitions are optional; the file may not exist yet.
    wiki_path = data_dir / "tag_wiki_defs.json"
    wiki_defs: Dict[str, str] = {}
    if wiki_path.is_file():
        wiki_defs = _json.loads(wiki_path.read_text(encoding="utf-8"))

    def _definition(tag: str, fallback: str) -> str:
        """Return the wiki text for *tag* if usable, else *fallback*."""
        text = wiki_defs.get(tag, "")
        # Usable = non-empty, not a thumbnail-only entry, and at least 15 chars.
        usable = bool(text) and not text.startswith("thumb ") and len(text) >= 15
        # Cap the length so the prompt stays short.
        return text[:200] if usable else fallback

    # Group A: character count (pick exactly one).
    character_count = StructuralGroup(
        name="character_count",
        constraint="exclusive",
        tags=[
            ("zero_pictured", _definition(
                "zero_pictured",
                "No characters or living beings appear in the image")),
            ("solo", _definition(
                "solo",
                "Exactly one character appears in the image")),
            ("duo", _definition(
                "duo",
                "Exactly two characters appear in the image")),
            ("trio", _definition(
                "trio",
                "Exactly three characters appear in the image")),
            ("group", _definition(
                "group",
                "Four or more characters appear in the image")),
        ],
    )

    # Group B: body type (multi, per character).
    # Key distinction the LLM must learn:
    #   anthro   = ANIMAL with human body shape (upright, hands)
    #   humanoid = HUMAN or near-human (elf, dwarf) with NO animal features
    #   feral    = normal animal shape, on all fours
    body_type = StructuralGroup(
        name="body_type",
        constraint="multi",
        tags=[
            ("anthro", _definition(
                "anthro",
                "An animal character with a human-like body: walks upright on two legs, "
                "has arms and hands. Examples: a wolf-person, a fox standing up. "
                "Still has animal features like fur, tail, muzzle")),
            ("feral", _definition(
                "feral",
                "A regular animal in its natural body shape. Walks on all fours (or "
                "flies/swims naturally). NOT standing upright, NOT humanized")),
            ("humanoid", _definition(
                "humanoid",
                "A human or human-like character with NO animal features. Includes "
                "humans, elves, dwarves, and fantasy races that look human. "
                "Does NOT include animal-people — those are anthro")),
            ("taur", _definition(
                "taur",
                "A centaur-like body: human or anthro upper body attached to a "
                "four-legged animal lower body")),
        ],
    )

    # Group C: gender (multi, per character).
    gender = StructuralGroup(
        name="gender",
        constraint="multi",
        tags=[
            ("male", _definition(
                "male",
                "A character described as male, a boy, or with he/him pronouns")),
            ("female", _definition(
                "female",
                "A character described as female, a girl, or with she/her pronouns")),
            ("ambiguous_gender", _definition(
                "ambiguous_gender",
                "A character whose gender is not stated or cannot be determined")),
            ("intersex", _definition(
                "intersex",
                "A character explicitly described as intersex or hermaphrodite")),
        ],
    )

    # Group D: clothing state (multi).
    clothing_state = StructuralGroup(
        name="clothing_state",
        constraint="multi",
        tags=[
            ("clothed", _definition(
                "clothed",
                "Wearing clothes on BOTH chest/torso AND legs/waist. "
                "Examples: shirt and pants, dress, full outfit")),
            ("nude", _definition(
                "nude",
                "Wearing NO clothes at all. Completely naked, no shirt and no pants")),
            ("topless", _definition(
                "topless",
                "NO shirt/top (bare chest), BUT wearing pants/bottoms. "
                "Upper body exposed, lower body covered")),
            ("bottomless", _definition(
                "bottomless",
                "Wearing shirt/top on chest, BUT NO pants/bottoms. "
                "Upper body covered, lower body exposed")),
        ],
    )

    # Group E: common visual elements (multi).
    visual_elements = StructuralGroup(
        name="visual_elements",
        constraint="multi",
        tags=[
            ("looking_at_viewer", _definition(
                "looking_at_viewer",
                "A character is looking directly at the camera or viewer")),
            ("text", _definition(
                "text",
                "The image contains visible writing, words, or lettering")),
        ],
    )

    return [character_count, body_type, gender, clothing_state, visual_elements]
def _build_structural_prompt(groups: List[StructuralGroup]) -> Tuple[str, List[Tuple[str, str]]]:
    """Render structural groups as one numbered statement list.

    Each group contributes a header line naming the group and its constraint,
    one numbered line per definition, and a trailing blank line. Numbering is
    continuous across groups and starts at 1.

    Returns:
        ``(formatted_text, flat)`` where ``flat[k - 1]`` is the
        ``(tag, definition)`` pair shown as statement number ``k``.
    """
    rendered: List[str] = []
    flat: List[Tuple[str, str]] = []
    for group in groups:
        if group.constraint == "exclusive":
            how_many = "pick EXACTLY ONE"
        else:
            how_many = "pick ALL that apply"
        title = group.name.replace("_", " ").upper()
        rendered.append(f"--- {title} ({how_many}) ---")
        for tag, defn in group.tags:
            flat.append((tag, defn))
            # After the append, len(flat) is this statement's 1-based number.
            rendered.append(f"{len(flat)}. {defn}")
        rendered.append("")  # blank line between groups
    return "\n".join(rendered), flat
# System prompt for structural tag inference. Literal JSON braces are escaped
# as {{ }}. The statement numbers in the worked example (2, 6, 10, 14)
# correspond to the numbering produced from the groups built in
# _load_structural_groups (solo, anthro, male, clothed); keep the two in sync.
STRUCTURAL_SYSTEM_TEMPLATE = """You classify image descriptions by selecting true statements from a numbered list.
The statements are organized into GROUPS. Each group header tells you how many to pick:
- "pick EXACTLY ONE" = choose the single best match in that group
- "pick ALL that apply" = choose every statement that is true
IMPORTANT RULES:
1. ONLY select a statement if the description directly says it or makes it very obvious.
2. Do NOT guess or assume things the description does not mention.
3. For body type: "anthro" means an ANIMAL with a human-shaped body (walks upright, has hands, but still has fur/tail/muzzle). "humanoid" means HUMAN or human-like with NO animal features. A wolf standing on two legs = anthro, NOT humanoid.
4. If the description never mentions gender, pick "gender cannot be determined".
5. For clothing state: READ CAREFULLY! "topless" = bare chest, wearing pants. "bottomless" = wearing shirt, no pants. If unsure, re-read the description.
6. If clothing is not mentioned, do NOT pick any clothing statement.
Return JSON ONLY:
{{"selections": [{{"i": 1}}, {{"i": 5}}]}}
EXAMPLE:
Description: "A muscular male wolf standing in a forest, wearing jeans, giving a thumbs up"
Answer: {{"selections": [{{"i": 2}}, {{"i": 6}}, {{"i": 10}}, {{"i": 14}}]}}
Why: One character = solo (2). Wolf standing upright with hands = anthro (6), NOT humanoid because it is a wolf. Male (10). Wearing jeans = clothed (14)."""
# User-turn prompt for the structural pass. It has two placeholders:
# {image_description} and {statement_lines}.
STRUCTURAL_USER_TEMPLATE = "\n".join((
    "Read this image description and select which statements are true.",
    "IMAGE DESCRIPTION:",
    "{image_description}",
    "STATEMENTS (pick by number):",
    "{statement_lines}",
))
class StructuralSelectionItem(BaseModel):
    # One entry of the model's answer: the number of a statement it selected.
    # Documented with comments rather than a docstring so the model's
    # generated JSON schema stays exactly as before.
    i: int = Field(
        ...,
        description="1-based index into the statement list.",
    )
class StructuralSelectionResponse(BaseModel):
    # Top-level shape of the structural answer. When "selections" is absent
    # the field falls back to a fresh empty list.
    selections: List[StructuralSelectionItem] = Field(
        default_factory=list,
    )
def _build_structural_response_format() -> Dict[str, Any]:
    """Build the ``response_format`` payload for the structural selection call.

    The payload wraps a strict JSON schema named ``structural_selection``
    describing an object with a single required ``selections`` array, whose
    items are objects with a single required integer ``i``. Extra properties
    are disallowed at both levels.

    Returns:
        A freshly built dict on every call.
    """
    # Schema of one selection: {"i": <integer>} and nothing else.
    selection_item = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
        },
        "required": ["i"],
        "additionalProperties": False,
    }

    # Schema of the whole reply: {"selections": [<selection>, ...]}.
    root_schema = {
        "type": "object",
        "properties": {
            "selections": {
                "type": "array",
                "items": selection_item,
            }
        },
        "required": ["selections"],
        "additionalProperties": False,
    }

    return {
        "type": "json_schema",
        "json_schema": {
            "name": "structural_selection",
            "strict": True,
            "schema": root_schema,
        },
    }
# Process-wide memo for the structural groups, so the loader (which reads JSON
# files) only has to run once per process. Stays None until a load has
# produced a value.
_cached_structural_groups: Optional[List[StructuralGroup]] = None


def _get_structural_groups() -> List[StructuralGroup]:
    """Return the structural groups, loading them on first use.

    The result of ``_load_structural_groups()`` is stored in the module-level
    ``_cached_structural_groups`` and handed back on later calls. If the
    loader raises, nothing is stored and the next call tries again.
    """
    global _cached_structural_groups

    cached = _cached_structural_groups
    if cached is not None:
        return cached

    cached = _load_structural_groups()
    _cached_structural_groups = cached
    return cached
def llm_infer_structural_tags(
    query_text: str,
    log=None,
    *,
    temperature: float = 0.0,
    max_tokens: int = 512,
    retries: int = 2,
) -> List[str]:
    """Infer structural tags via LLM using group-based statement agreement.

    The structural groups (character count, body type, gender, clothing
    state, visual elements) are rendered as one numbered statement list.
    The model replies with 1-based statement numbers, which are mapped back
    to their tags here.

    Args:
        query_text: Image description inserted into the user prompt.
        log: Optional callable taking one message string; ignored when falsy.
        temperature: Forwarded to ``_get_llm``.
        max_tokens: Forwarded to ``_get_llm``.
        retries: Extra attempts after the first, so the chain is invoked at
            most ``retries + 1`` times.

    Returns:
        The tags of the selected statements (e.g. ``["solo", "anthro",
        "male", "clothed"]``), without duplicates, in the order the model
        listed them. Selections that are not integers in ``1..N`` are
        ignored. An empty list is returned when nothing valid was selected
        or when every attempt raised.
    """
    if log:
        log("Stage3s (structural): inferring structural tags via group-based statement agreement")

    groups = _get_structural_groups()
    statement_lines, flat_tags = _build_structural_prompt(groups)
    N = len(flat_tags)

    llm = _get_llm(
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=_build_structural_response_format(),
    )
    parser = PydanticOutputParser(pydantic_object=StructuralSelectionResponse)
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", STRUCTURAL_SYSTEM_TEMPLATE),
            ("human", STRUCTURAL_USER_TEMPLATE),
        ],
        template_format="f-string",
    )
    chain = prompt | llm | parser

    if log:
        # The model name is only needed for this diagnostic line.
        model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
        group_summary = ", ".join(f"{g.name}({len(g.tags)})" for g in groups)
        log(f"Stage3s: model={model_name} groups=[{group_summary}] total_statements={N}")

    total_attempts = retries + 1
    for attempt in range(1, total_attempts + 1):
        # Everything up to the return stays inside the try: any failure
        # (call, parsing, mapping, logging) counts as a failed attempt.
        try:
            result = chain.invoke(
                {
                    "N": N,
                    "image_description": query_text,
                    "statement_lines": statement_lines,
                }
            )

            # Normalise a pydantic result to a plain dict (v2 API first, v1 fallback).
            if isinstance(result, BaseModel):
                if hasattr(result, "model_dump"):
                    result = result.model_dump()
                else:
                    result = result.dict()
            raw_selections = result.get("selections", []) if isinstance(result, dict) else []

            # Dict keys keep first-seen order and drop repeated tags.
            picked: Dict[str, None] = {}
            for entry in raw_selections:
                number = entry.get("i") if isinstance(entry, dict) else None
                if isinstance(number, int) and 1 <= number <= N:
                    picked.setdefault(flat_tags[number - 1][0], None)
            chosen_tags = list(picked)

            if log:
                tag_str = ", ".join(chosen_tags) if chosen_tags else "(none)"
                log(f"Stage3s: attempt {attempt} selected {len(chosen_tags)} tags: {tag_str}")
            return chosen_tags
        except Exception as e:
            # Best-effort stage: report the failure and move on to the next attempt.
            if log:
                log(f"Stage3s: attempt {attempt} error: {e}")

    if log:
        log(f"Stage3s: gave up after {total_attempts} attempts")
    return []