Claude
Improve topless/bottomless definitions to prevent confusion
f4f71fe
Raw
History Blame
39.1 kB
# psq_rag/llm/select.py
# Stage 3: Closed-Set Selection (LangChain-only implementation)
#
# This module intentionally uses LangChain for:
# - prompt templating (including {N})
# - LLM call orchestration
# - JSON parsing
#
# There is NO fallback path. If LangChain dependencies are missing, this module
# should fail loudly so you install them.
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field, SecretStr
from rapidfuzz import fuzz
from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources)
from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases
# Character-typed tags that name a generic category rather than an actual
# named character. They slip past the alias filter because they match common
# words in captions, so they are kept out of the entity pipeline and handled
# by general selection instead.
_GENERIC_CHARACTER_TAGS = frozenset(
    (
        "fan_character",
        "background_character",
        "unnamed_character",
        "unknown_character",
        "anonymous_character",
        "viewer",
        "original_character",
    )
)

# Allowed "why" codes, listed from most confident to least confident.
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]

# Ordinal rank per code (lower = more confident); used for threshold filtering.
# The rank is simply the code's position in WHY_ENUM.
WHY_RANK: Dict[str, int] = {code: rank for rank, code in enumerate(WHY_ENUM)}

# Deterministic numeric score per code, used for ordering and debug output.
WHY_TO_SCORE: Dict[str, float] = dict(
    zip(WHY_ENUM, (0.90, 0.70, 0.45, 0.35, 0.25))
)
# IMPORTANT ABOUT TEMPLATING:
# - These strings are rendered by LangChain's f-string template engine.
# - Literal JSON braces must be escaped as {{ and }}.
# - {N} is a real template variable and MUST be provided.

# System prompt for general (non-character) tag selection. The model answers
# with candidate indices plus a "why" code drawn from WHY_ENUM.
SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags.
Select the tags that correspond to content that would be visible or depicted in the described image.
The list contains only valid tags; many of them are irrelevant to the image.
Return JSON ONLY matching this schema:
{{
\"selections\": [
{{\"i\": <int>, \"why\": \"<one of: explicit|strong_implied|weak_implied|style_or_meta|other>\"}},
...
]
}}
Rules:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Do NOT output any keys other than \"selections\", and inside each item only the item index \"i\" and \"why\".
- Do select both a general tag and a more specific tag when both apply (for example, \"shirt\" and \"grey shirt\").
Define \"why\" as:
- explicit: directly stated in the image description
- strong_implied: very likely given the description, even if not literally stated
- weak_implied: plausible but not strongly supported by the description
- style_or_meta: stylistic or presentation-related tags only if clearly indicated
- other: fallback category; use sparingly
"""

# System prompt for character tags that already passed the alias pre-filter.
# It instructs the model to always answer with "why": "explicit".
ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags.
These character tags have already been pre-filtered to only include characters whose names
(or known aliases) appear in the image description. Your job is to confirm which of these
pre-filtered candidates are the correct match for the character mentioned by the user.
Return JSON ONLY matching this schema:
{{
\"selections\": [
{{\"i\": <int>, \"why\": \"explicit\"}},
...
]
}}
Rules for character selection:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Always use \"why\": \"explicit\" for all selections.
- Select the tag that best represents the character as described.
- If the user described a specific variant (e.g. \"pikachu libre\", \"detective pikachu\"),
select that specific variant tag.
- If the user described only the base character (e.g. just \"pikachu\"), select only the
base/default tag, NOT costume or variant tags.
- When uncertain between variants, prefer the simplest/most general tag.
"""

# Human message shared by both system prompts. Template variables:
# {image_description}, {candidate_lines} and {per_call_budget}.
USER_TEMPLATE = """IMAGE DESCRIPTION:
{image_description}
CANDIDATES (choose by index only):
{candidate_lines}
Select up to {per_call_budget} indices. Output fewer if uncertain.
"""
@dataclass(frozen=True)
class Selected:
    """One validated pick from a single LLM call, resolved to its tag."""

    i: int  # candidate index as returned by the LLM for that call
    tag: str  # canonical tag (underscore form)
    why: str  # rationale code; one of WHY_ENUM
    score: float  # numeric score looked up from WHY_TO_SCORE for `why`
# Static-typing mirror of WHY_ENUM, used to validate the "why" field below.
WhyLiteral = Literal["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]


class Stage3SelectionItem(BaseModel):
    """A single selection in the model's JSON answer: an index plus a rationale code."""

    i: int = Field(..., description="1-based index into the candidate list.")
    why: WhyLiteral = Field(..., description="Rationale code from the allowed set.")
class Stage3SelectionResponse(BaseModel):
    """Top-level shape of the model's JSON answer for a selection call."""

    # Defaults to an empty list when the key is absent.
    selections: List[Stage3SelectionItem] = Field(default_factory=list)
def _build_response_format() -> Dict[str, Any]:
    """Return the strict JSON-schema ``response_format`` for selection calls.

    The schema describes an object with one required ``selections`` array whose
    items carry exactly an integer ``i`` and a ``why`` string limited to
    ``WHY_ENUM``. Extra keys are rejected at both levels.
    """
    # Shape of one element of "selections".
    item_schema = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
            "why": {"type": "string", "enum": WHY_ENUM},
        },
        "required": ["i", "why"],
        "additionalProperties": False,
    }
    # Shape of the whole answer.
    root_schema = {
        "type": "object",
        "properties": {
            "selections": {"type": "array", "items": item_schema},
        },
        "required": ["selections"],
        "additionalProperties": False,
    }
    json_schema_spec = {
        "name": "stage3_selection",
        "strict": True,
        "schema": root_schema,
    }
    return {"type": "json_schema", "json_schema": json_schema_spec}
def _get_llm(*, temperature: float, max_tokens: int, response_format: Dict[str, Any]) -> ChatOpenAI:
    """Build a ``ChatOpenAI`` client pointed at OpenRouter's OpenAI-compatible API.

    Configuration comes from the environment:
      - ``OPENROUTER_API_KEY`` (required)
      - ``OPENROUTER_MODEL`` (optional, defaults to ``meta-llama/llama-3.1-8b-instruct``)
      - ``OPENROUTER_HTTP_REFERER`` / ``OPENROUTER_X_TITLE`` (optional request headers)

    Args:
        temperature: Sampling temperature passed to the client.
        max_tokens: Completion token cap passed as ``max_completion_tokens``.
        response_format: Structured-output spec forwarded in the request body.

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is unset or empty.
    """
    raw_key = os.getenv("OPENROUTER_API_KEY")
    if not raw_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set.\n"
            "Set it in your environment before running Stage 3."
        )

    # Optional headers: only sent when the matching env var is non-empty.
    optional_headers = (
        ("HTTP-Referer", "OPENROUTER_HTTP_REFERER"),
        ("X-Title", "OPENROUTER_X_TITLE"),
    )
    headers: Dict[str, str] = {
        header: value
        for header, env_name in optional_headers
        if (value := os.getenv(env_name))
    }

    # Provider-specific request body fields (OpenAI-compatible). The
    # response-healing plugin reduces malformed-JSON failures (syntax only).
    request_extras = {
        "response_format": response_format,
        "plugins": [{"id": "response-healing"}],
    }

    return ChatOpenAI(
        model=os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct"),
        base_url="https://openrouter.ai/api/v1",
        api_key=SecretStr(cast(str, raw_key)),
        temperature=temperature,
        max_completion_tokens=max_tokens,
        default_headers=headers,
        extra_body=request_extras,
    )
def _phrase_key_for_candidate(c: Candidate) -> str:
    """Return the candidate's deterministic "primary phrase" used for grouping.

    This is the smallest of ``c.sources`` in sort order, or ``""`` when the
    candidate has no sources.
    """
    return min(c.sources) if c.sources else ""
def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
    """Round-robin interleave candidates by their primary source phrase.

    Candidates are grouped by ``_phrase_key_for_candidate``. Each group is
    ordered best-first by ``(score_combined, count)``, and the result takes the
    1st item of every group (groups visited in sorted key order), then the 2nd
    item of every group, and so on until all groups are exhausted.

    NOTE: counts are used only for ordering; they are NOT shown to the LLM.

    Args:
        cands: Candidates to reorder. The input sequence is not modified.

    Returns:
        A new list containing every input candidate exactly once.
    """
    groups: Dict[str, List[Candidate]] = {}
    for c in cands:
        groups.setdefault(_phrase_key_for_candidate(c), []).append(c)

    def _rank(c: Candidate) -> Tuple[float, int]:
        # Only a missing count falls back to -1. The previous `count or -1`
        # also demoted a genuine count of 0, which disagreed with the
        # `is not None` tie-break used for the final ordering.
        return (c.score_combined, c.count if c.count is not None else -1)

    for members in groups.values():
        members.sort(key=_rank, reverse=True)

    keys = sorted(groups)
    depth = max((len(groups[k]) for k in keys), default=0)

    out: List[Candidate] = []
    for position in range(depth):
        for k in keys:
            members = groups[k]
            if position < len(members):
                out.append(members[position])
    return out
def _display_tag(tag: str) -> str:
    """Render a canonical underscore tag with spaces for display to the LLM.

    The canonical underscore form is still what the rest of the module uses.
    """
    return " ".join(tag.split("_"))
def _format_candidates_local(
    cands: Sequence[Candidate],
) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]:
    """Number candidates from 1 and build the lookup tables for one LLM call.

    Returns:
        A 3-tuple of
        - the prompt text, one ``"<index>. <tag with spaces>"`` line per candidate,
        - a map from local index to canonical tag,
        - a map from local index to the candidate object.
    """
    numbered = list(enumerate(cands, start=1))
    idx_to_candidate: Dict[int, Candidate] = dict(numbered)
    idx_to_tag: Dict[int, str] = {j: c.tag for j, c in numbered}
    text = "\n".join(f"{j}. {_display_tag(c.tag)}" for j, c in numbered)
    return text, idx_to_tag, idx_to_candidate
def _phrases_in_call(cands: Sequence[Candidate]) -> int:
    """Count the distinct source phrases across all candidates of one call."""
    return len({src for c in cands for src in c.sources})
def _parse_validate_map(
    parsed: Any,
    idx_to_tag: Dict[int, str],
    per_call_budget: int,
) -> Tuple[List[Selected], Dict[str, Any]]:
    """Validate one parsed LLM answer and map its indices back to tags.

    Args:
        parsed: The parser output, either a pydantic model or a plain dict.
        idx_to_tag: Local index -> canonical tag for the call being validated.
        per_call_budget: Maximum number of selections to keep. Items after
            the budget is reached are not inspected.

    Returns:
        ``(selected, diag)`` where ``selected`` holds the accepted picks in the
        order the model gave them and ``diag`` counts what was rejected
        (``invalid_items``, ``oob_indices``, ``dupe_indices``), how many were
        ``kept``, and whether the payload had the expected shape (``parse_ok``).
    """
    # Normalise pydantic models to plain dicts (v2 `model_dump`, v1 `dict`).
    if isinstance(parsed, BaseModel):
        parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()

    is_mapping = isinstance(parsed, dict)
    diag: Dict[str, Any] = {
        "parse_ok": is_mapping,
        "invalid_items": 0,
        "oob_indices": 0,
        "dupe_indices": 0,
        "kept": 0,
    }
    if not is_mapping:
        return [], diag

    selections = parsed.get("selections", [])
    if not isinstance(selections, list):
        diag["parse_ok"] = False
        return [], diag

    accepted: List[Selected] = []
    accepted_indices = set()
    for item in selections:
        if len(accepted) >= per_call_budget:
            break
        if not isinstance(item, dict):
            diag["invalid_items"] += 1
            continue

        i, why = item.get("i"), item.get("why")

        # Decide which counter (if any) this item is charged to. The order of
        # the checks matters for the reported counts.
        if isinstance(i, bool) or not isinstance(i, int):
            rejected_as = "invalid_items"
        elif i in accepted_indices:
            rejected_as = "dupe_indices"
        elif i not in idx_to_tag:
            rejected_as = "oob_indices"
        elif not isinstance(why, str) or why not in WHY_ENUM:
            rejected_as = "invalid_items"
        else:
            rejected_as = None

        if rejected_as is not None:
            diag[rejected_as] += 1
            continue

        accepted_indices.add(i)
        accepted.append(
            Selected(i=i, tag=idx_to_tag[i], why=why, score=WHY_TO_SCORE[why])
        )

    diag["kept"] = len(accepted)
    return accepted, diag
def _split_candidates_by_type(
    candidates: List[Candidate],
    log,
) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]:
    """Partition candidates into a general list and an entity (character) list.

    Routing by the type name reported by ``get_tag_type_name``:
      - ``"character"``: entity list, unless the tag is in
        ``_GENERIC_CHARACTER_TAGS``, in which case it goes to the general list.
      - ``"copyright"``: dropped (too broad for image generation).
      - ``"general"``, ``"artist"``, ``"species"``, ``"meta"``: general list.
      - anything else (including ``None``): general list, counted as unknown.

    Args:
        candidates: Candidates to partition.
        log: Optional callable taking one string; receives a one-line summary.

    Returns:
        ``(general_list, entity_list)``; every item is
        ``(original_index, candidate)``.
    """
    general_bucket: List[Tuple[int, Candidate]] = []
    entity_bucket: List[Tuple[int, Candidate]] = []
    tally = {"copyright": 0, "generic_char": 0, "unknown": 0}
    recognised_general = ("general", "artist", "species", "meta")

    for position, cand in enumerate(candidates):
        kind = get_tag_type_name(cand.tag)

        if kind == "copyright":
            # Copyright/series tags are filtered out entirely.
            tally["copyright"] += 1
            continue

        if kind == "character" and cand.tag not in _GENERIC_CHARACTER_TAGS:
            entity_bucket.append((position, cand))
            continue

        # Everything that reaches this point is handled by general selection.
        if kind == "character":
            tally["generic_char"] += 1
        elif kind not in recognised_general:
            tally["unknown"] += 1
        general_bucket.append((position, cand))

    if log:
        log(
            f"Stage3 split: "
            f"general={len(general_bucket)} "
            f"entity={len(entity_bucket)} "
            f"copyright_filtered={tally['copyright']} "
            f"generic_char_to_general={tally['generic_char']} "
            f"unknown_type={tally['unknown']}"
        )
    return general_bucket, entity_bucket
# Matches a trailing series/franchise suffix such as _(sonic), _(mlp) or
# _(character); anchored to the end of the string.
_SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$")


def _normalize_for_matching(text: str) -> str:
    """Normalise text for alias matching.

    Lower-cases and trims the input, removes one trailing ``_(...)`` suffix if
    present, then turns the remaining underscores into spaces.
    """
    trimmed = text.lower().strip()
    without_suffix = _SERIES_SUFFIX_RE.sub("", trimmed)
    return without_suffix.replace("_", " ")
def _query_words(query: str) -> Set[str]:
    """Return the set of whitespace-separated words of the normalised query."""
    normalised = _normalize_for_matching(query)
    return set(normalised.split())
def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str,
                         fuzzy_threshold: int = 85) -> bool:
    """Check whether a normalised alias matches the normalised user query.

    Checks, in order (first hit wins):
      1. Exact substring: the alias appears as a substring of the query.
      2. Word subset: every word of the alias appears among the query words.
      3. Fuzzy: a single-word alias is compared to each query word with
         ``fuzz.ratio``; a multi-word alias is compared to the whole query
         with ``fuzz.partial_ratio`` (handles typos).

    Args:
        alias_norm: Alias text, already normalised.
        query_words: Individual words of the normalised query.
        query_norm: The whole normalised query string.
        fuzzy_threshold: Minimum fuzzy score (inclusive) that counts as a match.

    Returns:
        True if any check matches. An empty or whitespace-only alias never
        matches.
    """
    alias_words = alias_norm.split()
    if not alias_words:
        # Must run BEFORE the substring test: "" is a substring of every
        # string, so an empty alias would otherwise match every query.
        return False

    # Exact substring match.
    if alias_norm in query_norm:
        return True

    # Word subset match: all alias words must appear in the query.
    if all(word in query_words for word in alias_words):
        return True

    if len(alias_words) == 1:
        # Single-word alias: fuzzy-compare against each query word.
        only_word = alias_words[0]
        return any(fuzz.ratio(only_word, qw) >= fuzzy_threshold for qw in query_words)

    # Multi-word alias: fuzzy partial match against the whole query.
    return fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold
def _character_matches_via_aliases(
    tag: str,
    query: str,
    tag2aliases: Dict[str, List[str]],
    query_words: Set[str],
    query_norm: str,
    fuzzy_threshold: int = 85,
) -> bool:
    """Decide whether a character tag is mentioned in the user query.

    The tag matches when its own normalised name matches the query, or when
    any of its registered aliases does. A tag with no registered aliases is
    judged on its name alone. Aliases that normalise to an empty string are
    skipped.

    Args:
        tag: Canonical character tag.
        query: Raw query text. Not read here; matching uses the
            pre-normalised ``query_words`` / ``query_norm``.
        tag2aliases: Map from tag to its registered aliases.
        query_words: Individual words of the normalised query.
        query_norm: The whole normalised query string.
        fuzzy_threshold: Forwarded to ``_alias_matches_query``.
    """
    # The tag name itself is tried first.
    if _alias_matches_query(
        _normalize_for_matching(tag), query_words, query_norm, fuzzy_threshold
    ):
        return True

    # Then every registered alias, lazily, stopping at the first hit.
    normalised_aliases = (
        _normalize_for_matching(alias) for alias in tag2aliases.get(tag, [])
    )
    return any(
        _alias_matches_query(alias_norm, query_words, query_norm, fuzzy_threshold)
        for alias_norm in normalised_aliases
        if alias_norm
    )
def llm_select_indices(
    query_text: str,  # kept for compatibility; treated as IMAGE DESCRIPTION
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
    max_pick: int,  # legacy param; applied after union + ordering (optional)
    log,
    retries: int = 2,
    *,
    mode: str = "chunked_map_union",  # "single_shot" or "chunked_map_union"
    chunk_size: int = 60,
    per_phrase_k: int = 2,  # per-call budget = per_phrase_k * phrases_in_call
    temperature: float = 0.0,
    max_tokens: int = 512,
    return_metadata: bool = False,
    min_why: Optional[str] = "strong_implied",
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
    """Return indices into the ORIGINAL candidates list (legacy interface).

    This implementation uses LangChain ONLY.
    NOTE: query_text is treated as the image description (original prompt).

    Pipeline:
      1. Normalise ``candidates`` to ``Candidate`` objects, remembering the
         first position at which each tag appears.
      2. Split them into general and character (entity) candidates.
      3. Ask the LLM to select from the general candidates, in one call or in
         chunks of ``chunk_size`` depending on ``mode``.
      4. Pre-filter character candidates by name/alias match against the
         description, then ask the LLM to confirm the survivors.
      5. Union all picks (best score per tag), apply the ``min_why``
         threshold, order by score then count, and cap at ``max_pick``.

    Args:
        query_text: The image description.
        candidates: ``Candidate`` objects, tag strings, or ``(tag, score)``
            rows. The element type is decided from the first element.
        max_pick: Cap on the number of results; ignored unless a positive int.
        log: Optional callable taking one string; falsy disables logging.
        retries: Extra attempts per LLM call after the first one.
        mode: ``"single_shot"`` or ``"chunked_map_union"``.
        chunk_size: Candidates per call in chunked mode.
        per_phrase_k: Per-call budget multiplier (budget =
            ``per_phrase_k`` * distinct source phrases in the call).
        temperature: Sampling temperature for the LLM.
        max_tokens: Completion token cap for the LLM.
        return_metadata: When True, also return a tag -> why map.
        min_why: if set, only keep tags whose 'why' is at or above this
            confidence level. E.g. min_why="explicit" keeps only explicit
            matches; min_why="strong_implied" keeps explicit + strong_implied.
            Default: "strong_implied". A value not found in WHY_RANK is
            treated as rank 4, i.e. nothing is dropped.

    Returns:
        A list of indices into ``candidates``; with ``return_metadata=True``,
        a ``(indices, tag_why)`` tuple.

    Raises:
        ValueError: If ``mode`` is invalid, or a row in the tuple branch is
            not a list/tuple of length >= 2.
        RuntimeError: From ``_get_llm`` when ``OPENROUTER_API_KEY`` is unset.
    """
    image_description = query_text
    # Normalize candidates:
    # - preferred: List[Candidate]
    # - legacy: List[(tag, sim)] (count/sources unavailable)
    norm: List[Candidate] = []
    tag_to_first_index: Dict[str, int] = {}
    branch = "empty"
    cand0_type = type(candidates[0]).__name__ if candidates else "none"
    if candidates and isinstance(candidates[0], Candidate):
        branch = "candidate"
        typed_candidates = cast(Sequence[Candidate], candidates)
        for idx, c in enumerate(typed_candidates):
            # Remember only the first position of each tag.
            if c.tag not in tag_to_first_index:
                tag_to_first_index[c.tag] = idx
            norm.append(c)
    elif candidates and isinstance(candidates[0], str):
        branch = "string"
        typed_candidates = cast(Sequence[str], candidates)
        for idx, tag in enumerate(typed_candidates):
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
            # Bare tag strings carry no score, count or sources.
            norm.append(
                Candidate(
                    tag=tag,
                    score_combined=0.0,
                    score_fasttext=None,
                    score_context=None,
                    count=None,
                    sources=[],
                )
            )
    else:
        if candidates:
            branch = "tuple"
        typed_candidates = cast(Sequence[Tuple[str, float]], candidates)
        for idx, row in enumerate(typed_candidates):
            if not isinstance(row, (list, tuple)) or len(row) < 2:
                raise ValueError("Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples.")
            tag, sim = row[0], row[1]
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
            norm.append(
                Candidate(
                    tag=tag,
                    score_combined=float(sim),
                    score_fasttext=None,
                    score_context=None,
                    count=None,
                    sources=[],
                )
            )
    if log:
        if norm:
            log(
                "Stage3 input: "
                f"type0={cand0_type} "
                f"branch={branch} "
                f"norm0_score={norm[0].score_combined!r} "
                f"norm0_sources_empty={not bool(norm[0].sources)}"
            )
        else:
            log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)")
    if mode not in ("single_shot", "chunked_map_union"):
        raise ValueError(f"Invalid mode: {mode}")
    response_format = _build_response_format()
    # NOTE: the client is built (and the API key required) even when there
    # are no candidates to send.
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens, response_format=response_format)
    # Read again here for logging only; _get_llm resolves the model itself.
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
    parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse)
    # Global union: tag -> best (score, why)
    best: Dict[str, Tuple[float, str]] = {}

    def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
        """Run one LLM selection call (with retries) and merge picks into `best`."""
        # Create chain with the provided system template
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", USER_TEMPLATE),
            ],
            template_format="f-string",
        )
        chain = prompt | llm | parser
        ordered = _interleave_round_robin(call_cands)
        candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
        N_local = len(idx_to_tag)
        phrases = _phrases_in_call(call_cands)
        # With no source phrases at all, the budget is just per_phrase_k.
        per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
        summary_logged = False
        if log:
            log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}")
            if phrases > 0:
                distinct_phrases = sorted({src for c in call_cands for src in c.sources})
                log(
                    f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} "
                    f"phrases={', '.join(distinct_phrases)}"
                )
        # Invoke LangChain chain (templating fills {N} and other vars)
        for att in range(retries + 1):
            try:
                if log:
                    log(
                        f"Stage3 {label}: "
                        f"model={model_name} "
                        f"N={N_local} "
                        f"phrases={phrases} "
                        f"per_call_budget={per_call_budget} "
                        f"response_healing=on"
                    )
                parsed = chain.invoke(
                    {
                        "N": N_local,
                        "image_description": image_description,
                        "candidate_lines": candidate_lines,
                        "per_call_budget": per_call_budget,
                    }
                )
                selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
                if log:
                    log(f"Stage3 {label}: attempt {att+1} diag={diag}")
                    # Log the summary once: on the first attempt that
                    # yields picks, or on the final attempt.
                    if not summary_logged and (selected or att == retries):
                        log(
                            f"Stage3 {label}: summary "
                            f"N={N_local} selected={len(selected)} per_call_budget={per_call_budget}"
                        )
                        summary_logged = True
                    if selected:
                        lines = [
                            f"Stage3 {label} selections:",
                            *[
                                (
                                    f' - i={s.i} tag="{s.tag}" '
                                    f"why={s.why} score={s.score:.2f} "
                                    f"sources={idx_to_candidate.get(s.i).sources if idx_to_candidate.get(s.i) else []}"
                                )
                                for s in selected
                            ],
                        ]
                        log("\n".join(lines))
                    else:
                        log(f"Stage3 {label} selections: (none)")
                if selected:
                    # Keep the highest score seen for each tag across calls.
                    for s in selected:
                        prev = best.get(s.tag)
                        if prev is None or s.score > prev[0]:
                            best[s.tag] = (s.score, s.why)
                    # An empty selection does not return: the loop retries.
                    return
            except Exception as e:
                # Any failure (request, parsing, validation) counts as one
                # failed attempt; the loop moves on to the next one.
                if log:
                    log(f"Stage3 {label}: attempt {att+1} error: {e}")
        if log:
            log(f"Stage3 {label}: gave up after {retries+1} attempts")

    # Split candidates by type (general vs entity)
    general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)
    # Extract just the candidates for LLM calls
    general_cands = [cand for _, cand in general_with_idx]
    entity_cands = [cand for _, cand in entity_with_idx]
    # Process general candidates (attributes, actions, species, etc.)
    if general_cands:
        if mode == "single_shot":
            run_call(general_cands, "general_single_shot", SELECT_SYSTEM_TEMPLATE)
        else:
            for start in range(0, len(general_cands), chunk_size):
                run_call(
                    general_cands[start:start + chunk_size],
                    f"general_chunk_{start//chunk_size}",
                    SELECT_SYSTEM_TEMPLATE
                )
    # Process entity candidates (characters only) with alias-based pre-filtering
    if entity_cands:
        tag2aliases = get_tag2aliases()
        qwords = _query_words(image_description)
        qnorm = _normalize_for_matching(image_description)
        filtered_entity_cands: List[Candidate] = []
        filtered_out: List[str] = []
        for cand in entity_cands:
            if _character_matches_via_aliases(
                cand.tag, image_description, tag2aliases, qwords, qnorm
            ):
                filtered_entity_cands.append(cand)
            else:
                filtered_out.append(cand.tag)
        if log:
            log(
                f"Stage3 entity alias filter: "
                f"before={len(entity_cands)} "
                f"after={len(filtered_entity_cands)} "
                f"removed={len(filtered_out)}"
            )
            if filtered_out:
                # Only the first 20 removed tags are logged.
                log(f"Stage3 entity alias filter removed: {filtered_out[:20]}")
        if filtered_entity_cands:
            if mode == "single_shot":
                run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE)
            else:
                for start in range(0, len(filtered_entity_cands), chunk_size):
                    run_call(
                        filtered_entity_cands[start:start + chunk_size],
                        f"entity_chunk_{start//chunk_size}",
                        ENTITY_SYSTEM_TEMPLATE
                    )
    # Apply why threshold: drop tags below the minimum confidence level.
    if min_why is not None:
        max_rank = WHY_RANK.get(min_why, 4)
        before = len(best)
        best = {t: v for t, v in best.items() if WHY_RANK.get(v[1], 4) <= max_rank}
        if log:
            log(f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
                f"before={before} after={len(best)} dropped={before - len(best)}")
    # Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
    count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
    ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)
    # Legacy cap: apply AFTER union + ordering.
    if isinstance(max_pick, int) and max_pick > 0:
        ordered_tags = ordered_tags[:max_pick]
    # Map back to original indices
    out_idx: List[int] = []
    tag_why: Dict[str, str] = {}
    for t in ordered_tags:
        if t in tag_to_first_index:
            out_idx.append(tag_to_first_index[t])
            tag_why[t] = best[t][1]  # why string
    if return_metadata:
        return out_idx, tag_why
    return out_idx
# ---------------------------------------------------------------------------
# Stage 3s: Structural tag inference (solo/duo/male/female/anthro/… )
# ---------------------------------------------------------------------------
# Group-based approach: tags are organized into curated semantic groups.
# Definitions are loaded from tag_wiki_defs.json where a real text entry
# exists, with curated fallback definitions for tags whose wiki entries are
# missing or only thumbnail references.
#
# Each group specifies a constraint mode:
# "exclusive" = pick exactly one (e.g. character count)
# "multi" = pick all that apply (e.g. body type, gender)
import json as _json
@dataclass
class StructuralGroup:
    """One category of structural tags to probe."""

    name: str  # group identifier, e.g. "character_count"
    constraint: str  # "exclusive" or "multi"
    tags: List[Tuple[str, str]]  # (tag, definition) pairs
def _load_structural_groups() -> List[StructuralGroup]:
    """Build the structural tag groups from the curated table below.

    Group membership and fallback definitions are hard-coded here. When
    ``<repo>/data/tag_wiki_defs.json`` exists (``<repo>`` being two directories
    above this file's directory), a tag's wiki text replaces its curated
    fallback, provided the wiki text is non-empty, does not start with
    ``"thumb "`` and is at least 15 characters long. Wiki text is truncated to
    200 characters to keep the prompt short.

    Returns:
        The groups in prompt order: character_count, body_type, gender,
        clothing_state, visual_elements.
    """
    wiki_path = Path(__file__).resolve().parents[2] / "data" / "tag_wiki_defs.json"

    # Wiki definitions are optional; the file may not exist yet.
    wiki_defs: Dict[str, str] = {}
    if wiki_path.is_file():
        with wiki_path.open("r", encoding="utf-8") as f:
            wiki_defs = _json.load(f)

    def _resolve(tag: str, fallback: str) -> str:
        """Pick the wiki definition when it is real text, else the fallback."""
        wiki_text = wiki_defs.get(tag, "")
        usable = (
            bool(wiki_text)
            and not wiki_text.startswith("thumb ")  # thumbnail-only entry
            and len(wiki_text) >= 15
        )
        return wiki_text[:200] if usable else fallback

    # (group name, constraint, ((tag, curated fallback definition), ...))
    #
    # Body-type distinction the LLM must learn:
    #   anthro   = ANIMAL with human body shape (upright, hands)
    #   humanoid = HUMAN or near-human (elf, dwarf) with NO animal features
    #   feral    = normal animal shape, on all fours
    curated = (
        # Character count: exactly one applies.
        ("character_count", "exclusive", (
            ("zero_pictured",
             "No characters or living beings appear in the image"),
            ("solo",
             "Exactly one character appears in the image"),
            ("duo",
             "Exactly two characters appear in the image"),
            ("trio",
             "Exactly three characters appear in the image"),
            ("group",
             "Four or more characters appear in the image"),
        )),
        # Body type: per character, several may apply.
        ("body_type", "multi", (
            ("anthro",
             "An animal character with a human-like body: walks upright on two legs, "
             "has arms and hands. Examples: a wolf-person, a fox standing up. "
             "Still has animal features like fur, tail, muzzle"),
            ("feral",
             "A regular animal in its natural body shape. Walks on all fours (or "
             "flies/swims naturally). NOT standing upright, NOT humanized"),
            ("humanoid",
             "A human or human-like character with NO animal features. Includes "
             "humans, elves, dwarves, and fantasy races that look human. "
             "Does NOT include animal-people — those are anthro"),
            ("taur",
             "A centaur-like body: human or anthro upper body attached to a "
             "four-legged animal lower body"),
        )),
        # Gender: per character, several may apply.
        ("gender", "multi", (
            ("male",
             "A character described as male, a boy, or with he/him pronouns"),
            ("female",
             "A character described as female, a girl, or with she/her pronouns"),
            ("ambiguous_gender",
             "A character whose gender is not stated or cannot be determined"),
            ("intersex",
             "A character explicitly described as intersex or hermaphrodite"),
        )),
        # Clothing state.
        ("clothing_state", "multi", (
            ("clothed",
             "Wearing clothes on BOTH chest/torso AND legs/waist. "
             "Examples: shirt and pants, dress, full outfit"),
            ("nude",
             "Wearing NO clothes at all. Completely naked, no shirt and no pants"),
            ("topless",
             "NO shirt/top (bare chest), BUT wearing pants/bottoms. "
             "Upper body exposed, lower body covered"),
            ("bottomless",
             "Wearing shirt/top on chest, BUT NO pants/bottoms. "
             "Upper body covered, lower body exposed"),
        )),
        # Common visual elements.
        ("visual_elements", "multi", (
            ("looking_at_viewer",
             "A character is looking directly at the camera or viewer"),
            ("text",
             "The image contains visible writing, words, or lettering"),
        )),
    )

    return [
        StructuralGroup(
            name=group_name,
            constraint=constraint,
            tags=[(tag, _resolve(tag, fallback)) for tag, fallback in entries],
        )
        for group_name, constraint, entries in curated
    ]
def _build_structural_prompt(groups: List[StructuralGroup]) -> Tuple[str, List[Tuple[str, str]]]:
    """Render the groups as one numbered statement list for the prompt.

    Each group contributes a header line naming the group and its constraint,
    one numbered line per definition, and a trailing blank line. Numbering is
    1-based and runs continuously across groups.

    Returns:
        ``(formatted_text, flat)`` where ``flat[k - 1]`` is the
        ``(tag, definition)`` pair behind statement number ``k``.
    """
    rendered: List[str] = []
    flat: List[Tuple[str, str]] = []
    for g in groups:
        if g.constraint == "exclusive":
            constraint_label = "pick EXACTLY ONE"
        else:
            constraint_label = "pick ALL that apply"
        title = g.name.replace("_", " ").upper()
        rendered.append(f"--- {title} ({constraint_label}) ---")
        for tag, defn in g.tags:
            flat.append((tag, defn))
            # The statement number is the pair's 1-based position in `flat`.
            rendered.append(f"{len(flat)}. {defn}")
        rendered.append("")  # blank line between groups
    return "\n".join(rendered), flat
# System prompt for structural tag inference. Rendered by LangChain's f-string
# template engine, so literal JSON braces are escaped as {{ and }}.
#
# The statement numbers in the worked example (2, 6, 10, 14) are hard-coded.
# They match the statement order produced from _load_structural_groups
# (solo=2, anthro=6, male=10, clothed=14) and must be updated if that order
# ever changes.
#
# NOTE(review): the example maps "wearing jeans" (no top mentioned) to
# "clothed", while the curated "clothed" definition asks for clothes on BOTH
# the torso and the legs. Confirm the example is what is intended.
STRUCTURAL_SYSTEM_TEMPLATE = """You classify image descriptions by selecting true statements from a numbered list.
The statements are organized into GROUPS. Each group header tells you how many to pick:
- "pick EXACTLY ONE" = choose the single best match in that group
- "pick ALL that apply" = choose every statement that is true
IMPORTANT RULES:
1. ONLY select a statement if the description directly says it or makes it very obvious.
2. Do NOT guess or assume things the description does not mention.
3. For body type: "anthro" means an ANIMAL with a human-shaped body (walks upright, has hands, but still has fur/tail/muzzle). "humanoid" means HUMAN or human-like with NO animal features. A wolf standing on two legs = anthro, NOT humanoid.
4. If the description never mentions gender, pick "gender cannot be determined".
5. For clothing state: READ CAREFULLY! "topless" = bare chest, wearing pants. "bottomless" = wearing shirt, no pants. If unsure, re-read the description.
6. If clothing is not mentioned, do NOT pick any clothing statement.
Return JSON ONLY:
{{"selections": [{{"i": 1}}, {{"i": 5}}]}}
EXAMPLE:
Description: "A muscular male wolf standing in a forest, wearing jeans, giving a thumbs up"
Answer: {{"selections": [{{"i": 2}}, {{"i": 6}}, {{"i": 10}}, {{"i": 14}}]}}
Why: One character = solo (2). Wolf standing upright with hands = anthro (6), NOT humanoid because it is a wolf. Male (10). Wearing jeans = clothed (14)."""

# Human message for structural inference. Template variables:
# {image_description} and {statement_lines}.
STRUCTURAL_USER_TEMPLATE = """Read this image description and select which statements are true.
IMAGE DESCRIPTION:
{image_description}
STATEMENTS (pick by number):
{statement_lines}"""
class StructuralSelectionItem(BaseModel):
    """A single selected statement in the model's JSON answer."""

    i: int = Field(..., description="1-based index into the statement list.")


class StructuralSelectionResponse(BaseModel):
    """Top-level shape of the model's JSON answer for structural inference."""

    # Defaults to an empty list when the key is absent.
    selections: List[StructuralSelectionItem] = Field(default_factory=list)
def _build_structural_response_format() -> Dict[str, Any]:
    """Return the strict JSON-schema ``response_format`` for structural calls.

    The schema describes an object with one required ``selections`` array whose
    items carry exactly one integer field ``i``. Extra keys are rejected at
    both levels.
    """
    # Shape of one element of "selections".
    item_schema = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
        },
        "required": ["i"],
        "additionalProperties": False,
    }
    # Shape of the whole answer.
    root_schema = {
        "type": "object",
        "properties": {
            "selections": {"type": "array", "items": item_schema},
        },
        "required": ["selections"],
        "additionalProperties": False,
    }
    json_schema_spec = {
        "name": "structural_selection",
        "strict": True,
        "schema": root_schema,
    }
    return {"type": "json_schema", "json_schema": json_schema_spec}
# Process-wide cache so the JSON definitions are read at most once.
_cached_structural_groups: Optional[List[StructuralGroup]] = None


def _get_structural_groups() -> List[StructuralGroup]:
    """Return the structural groups, loading them on the first call only."""
    global _cached_structural_groups
    groups = _cached_structural_groups
    if groups is not None:
        return groups
    groups = _load_structural_groups()
    _cached_structural_groups = groups
    return groups
def llm_infer_structural_tags(
    query_text: str,
    log=None,
    *,
    temperature: float = 0.0,
    max_tokens: int = 512,
    retries: int = 2,
) -> List[str]:
    """Infer structural tags via LLM using group-based statement agreement.

    Probes multiple semantic groups (character count, body type, gender,
    clothing state, visual elements). The model is shown one numbered
    statement per tag and answers with the numbers of the true statements.

    Validation applied to the model's answer:
      - items whose index is not an int in ``1..N`` are ignored (bools are
        rejected too, matching ``_parse_validate_map``);
      - a tag is returned at most once;
      - for a group whose constraint is ``"exclusive"``, only the first
        selected tag is kept, so contradictory tags (e.g. ``solo`` and
        ``duo``) are never returned together.

    Args:
        query_text: The image description.
        log: Optional callable taking one string; falsy disables logging.
        temperature: Sampling temperature for the LLM.
        max_tokens: Completion token cap for the LLM.
        retries: Extra attempts after the first one when a call raises.

    Returns:
        Tag strings in the order the model selected them (e.g.
        ``["solo", "anthro", "male", "clothed"]``), or ``[]`` when every
        attempt failed.

    Raises:
        RuntimeError: From ``_get_llm`` when ``OPENROUTER_API_KEY`` is unset.
    """
    if log:
        log("Stage3s (structural): inferring structural tags via group-based statement agreement")
    groups = _get_structural_groups()
    statement_lines, flat_tags = _build_structural_prompt(groups)
    N = len(flat_tags)

    # tag -> owning group name, for groups that allow exactly one pick.
    exclusive_group_of: Dict[str, str] = {
        tag: g.name
        for g in groups
        if g.constraint == "exclusive"
        for tag, _ in g.tags
    }

    response_format = _build_structural_response_format()
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens,
                   response_format=response_format)
    # Read again here for logging only; _get_llm resolves the model itself.
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
    parser = PydanticOutputParser(pydantic_object=StructuralSelectionResponse)
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", STRUCTURAL_SYSTEM_TEMPLATE),
            ("human", STRUCTURAL_USER_TEMPLATE),
        ],
        template_format="f-string",
    )
    chain = prompt | llm | parser
    if log:
        group_summary = ", ".join(f"{g.name}({len(g.tags)})" for g in groups)
        log(f"Stage3s: model={model_name} groups=[{group_summary}] total_statements={N}")

    for att in range(retries + 1):
        try:
            parsed = chain.invoke({
                "N": N,
                "image_description": query_text,
                "statement_lines": statement_lines,
            })
            # Normalise pydantic models to plain dicts (v2 `model_dump`, v1 `dict`).
            if isinstance(parsed, BaseModel):
                parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()
            sels = parsed.get("selections", []) if isinstance(parsed, dict) else []

            chosen_tags: List[str] = []
            seen: Set[str] = set()
            filled_exclusive: Set[str] = set()
            dropped_exclusive: List[str] = []
            for item in sels:
                idx = item.get("i") if isinstance(item, dict) else None
                # bool is a subclass of int; True must not be read as index 1.
                if isinstance(idx, bool) or not isinstance(idx, int):
                    continue
                if idx < 1 or idx > N:
                    continue
                tag = flat_tags[idx - 1][0]
                if tag in seen:
                    continue
                group_name = exclusive_group_of.get(tag)
                if group_name is not None:
                    if group_name in filled_exclusive:
                        # The group was declared "pick exactly one" and
                        # already has its pick; keep the first, drop the rest.
                        dropped_exclusive.append(tag)
                        continue
                    filled_exclusive.add(group_name)
                chosen_tags.append(tag)
                seen.add(tag)

            if log:
                if dropped_exclusive:
                    log(
                        f"Stage3s: attempt {att+1} dropped extra picks from exclusive groups: "
                        f"{', '.join(dropped_exclusive)}"
                    )
                tag_str = ", ".join(chosen_tags) if chosen_tags else "(none)"
                log(f"Stage3s: attempt {att+1} selected {len(chosen_tags)} tags: {tag_str}")
            return chosen_tags
        except Exception as e:
            # Any failure (request, parsing) counts as one failed attempt;
            # the loop moves on to the next one.
            if log:
                log(f"Stage3s: attempt {att+1} error: {e}")
    if log:
        log(f"Stage3s: gave up after {retries+1} attempts")
    return []