Claude
Rewrite structural inference prompt for better Llama 3.1 8B performance
46fe384
Raw
History Blame
33.2 kB
# psq_rag/llm/select.py
# Stage 3: Closed-Set Selection (LangChain-only implementation)
#
# This module intentionally uses LangChain for:
# - prompt templating (including {N})
# - LLM call orchestration
# - JSON parsing
#
# There is NO fallback path. If LangChain dependencies are missing, this module
# should fail loudly so you install them.
import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field, SecretStr
from rapidfuzz import fuzz
from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources)
from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases
# Tags typed as "character" that name a generic category rather than an actual
# named character. They leak through the alias filter because they match common
# caption words, so _split_candidates_by_type keeps them out of the entity
# pipeline and routes them to general selection instead.
_GENERIC_CHARACTER_TAGS = frozenset(
    (
        "anonymous_character",
        "background_character",
        "fan_character",
        "original_character",
        "unknown_character",
        "unnamed_character",
        "viewer",
    )
)
# Closed set of "why" rationale codes the LLM may return, listed from most to
# least confident. The list order is the single source of truth for ranking.
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
# Ordinal rank used for threshold filtering: lower = more confident.
# Derived from WHY_ENUM's order so the two cannot drift apart.
WHY_RANK: Dict[str, int] = {why: rank for rank, why in enumerate(WHY_ENUM)}
# Deterministic "why" -> numeric score used for ordering/debug output; the LLM
# itself only ever returns the rationale code, never a number.
WHY_TO_SCORE: Dict[str, float] = dict(zip(WHY_ENUM, (0.90, 0.70, 0.45, 0.35, 0.25)))
# IMPORTANT ABOUT TEMPLATING:
# - This string is rendered by LangChain's f-string template engine
#   (ChatPromptTemplate.from_messages(..., template_format="f-string") in
#   llm_select_indices).
# - Literal JSON braces must be escaped as {{ and }}.
# - {N} is a real template variable and MUST be provided.
#
# System prompt for the general (non-character) selection calls. The "why"
# labels spelled out below must stay in sync with WHY_ENUM, which is also the
# enum enforced by the strict JSON schema from _build_response_format.
SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags.
Select the tags that correspond to content that would be visible or depicted in the described image.
The list contains only valid tags; many of them are irrelevant to the image.
Return JSON ONLY matching this schema:
{{
\"selections\": [
{{\"i\": <int>, \"why\": \"<one of: explicit|strong_implied|weak_implied|style_or_meta|other>\"}},
...
]
}}
Rules:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Do NOT output any keys other than \"selections\", and inside each item only the item index \"i\" and \"why\".
- Do select both a general tag and a more specific tag when both apply (for example, \"shirt\" and \"grey shirt\").
Define \"why\" as:
- explicit: directly stated in the image description
- strong_implied: very likely given the description, even if not literally stated
- weak_implied: plausible but not strongly supported by the description
- style_or_meta: stylistic or presentation-related tags only if clearly indicated
- other: fallback category; use sparingly
"""
# System prompt for character (entity) candidates that survived the alias
# pre-filter in llm_select_indices. Same f-string templating rules as
# SELECT_SYSTEM_TEMPLATE ({N} required, literal braces doubled).
# The model is told to always answer why="explicit"; the call still uses the
# same strict schema (any WHY_ENUM value) and min_why filtering still applies.
ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags.
These character tags have already been pre-filtered to only include characters whose names
(or known aliases) appear in the image description. Your job is to confirm which of these
pre-filtered candidates are the correct match for the character mentioned by the user.
Return JSON ONLY matching this schema:
{{
\"selections\": [
{{\"i\": <int>, \"why\": \"explicit\"}},
...
]
}}
Rules for character selection:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Always use \"why\": \"explicit\" for all selections.
- Select the tag that best represents the character as described.
- If the user described a specific variant (e.g. \"pikachu libre\", \"detective pikachu\"),
select that specific variant tag.
- If the user described only the base character (e.g. just \"pikachu\"), select only the
base/default tag, NOT costume or variant tags.
- When uncertain between variants, prefer the simplest/most general tag.
"""
# Human-message template shared by the general and entity selection calls.
# Variables are filled by run_call in llm_select_indices. The budget stated
# here is advisory to the model; _parse_validate_map also truncates the
# parsed selections to per_call_budget.
USER_TEMPLATE = """IMAGE DESCRIPTION:
{image_description}
CANDIDATES (choose by index only):
{candidate_lines}
Select up to {per_call_budget} indices. Output fewer if uncertain.
"""
@dataclass(frozen=True)
class Selected:
    """One validated LLM pick, resolved from its per-call index to a tag.

    Immutable (frozen dataclass); instances compare by field values.
    """

    i: int  # 1-based index into the candidate list shown in that call
    tag: str  # canonical tag (underscore form)
    why: str  # rationale code, one of WHY_ENUM
    score: float  # WHY_TO_SCORE[why]; used for ordering
# Static-typing mirror of WHY_ENUM; keep the two lists identical.
WhyLiteral = Literal["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
# One selection as returned by the LLM: a 1-based candidate index plus its
# rationale code. Mirrors the item schema built in _build_response_format.
# (Documented with comments rather than a docstring, which pydantic would
# fold into the model's JSON schema.)
class Stage3SelectionItem(BaseModel):
    i: int = Field(..., description="1-based index into the candidate list.")
    why: WhyLiteral = Field(..., description="Rationale code from the allowed set.")
# Top-level Stage 3 response; the target model of the PydanticOutputParser
# used by llm_select_indices. A missing "selections" key parses as [].
class Stage3SelectionResponse(BaseModel):
    selections: List[Stage3SelectionItem] = Field(default_factory=list)
def _build_response_format() -> Dict[str, Any]:
    """Return the OpenAI-style ``response_format`` for Stage 3 selection calls.

    The strict JSON Schema admits exactly
    ``{"selections": [{"i": <integer>, "why": <one of WHY_ENUM>}, ...]}``
    with no additional keys at either level. A fresh dict is built per call.
    """
    selection_item = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
            "why": {"type": "string", "enum": WHY_ENUM},
        },
        "required": ["i", "why"],
        "additionalProperties": False,
    }
    envelope = {
        "type": "object",
        "properties": {"selections": {"type": "array", "items": selection_item}},
        "required": ["selections"],
        "additionalProperties": False,
    }
    return {
        "type": "json_schema",
        "json_schema": {"name": "stage3_selection", "strict": True, "schema": envelope},
    }
def _get_llm(*, temperature: float, max_tokens: int, response_format: Dict[str, Any]) -> ChatOpenAI:
    """Build a ChatOpenAI client that talks to OpenRouter's OpenAI-compatible API.

    Environment:
        OPENROUTER_API_KEY: required.
        OPENROUTER_MODEL: model id, default "meta-llama/llama-3.1-8b-instruct".
        OPENROUTER_HTTP_REFERER / OPENROUTER_X_TITLE: optional; sent as the
            "HTTP-Referer" / "X-Title" headers when set and non-empty.

    The request body carries ``response_format`` plus OpenRouter's
    "response-healing" plugin, which repairs malformed JSON syntax only.

    Raises:
        RuntimeError: if OPENROUTER_API_KEY is unset or empty.
    """
    raw_key = os.getenv("OPENROUTER_API_KEY")
    if not raw_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set.\nSet it in your environment before running Stage 3."
        )

    optional_headers = {
        "HTTP-Referer": os.getenv("OPENROUTER_HTTP_REFERER"),
        "X-Title": os.getenv("OPENROUTER_X_TITLE"),
    }
    headers: Dict[str, str] = {name: value for name, value in optional_headers.items() if value}

    return ChatOpenAI(
        model=os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct"),
        base_url="https://openrouter.ai/api/v1",
        api_key=SecretStr(raw_key),
        temperature=temperature,
        max_completion_tokens=max_tokens,
        default_headers=headers,
        # Provider-specific request body fields (OpenAI-compatible).
        extra_body={
            "response_format": response_format,
            "plugins": [{"id": "response-healing"}],
        },
    )
def _phrase_key_for_candidate(c: Candidate) -> str:
    """Return the deterministic grouping key for ``c``.

    The key is the lexicographically smallest source phrase, or "" when the
    candidate has no sources.
    """
    return min(c.sources) if c.sources else ""
def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
    """Interleave candidates round-robin across their primary source phrase.

    Candidates are bucketed by _phrase_key_for_candidate. Each bucket is sorted
    by (score_combined, count) descending; a missing or zero count sorts as -1.
    The output then takes the 1st item of every bucket (buckets in key order),
    then every 2nd item, and so on, so no single phrase dominates the head of
    the list.

    NOTE: counts only influence ordering; they are NOT shown to the LLM.
    """
    by_phrase: Dict[str, List[Candidate]] = {}
    for cand in cands:
        by_phrase.setdefault(_phrase_key_for_candidate(cand), []).append(cand)

    for bucket in by_phrase.values():
        bucket.sort(key=lambda x: (x.score_combined, (x.count or -1)), reverse=True)

    phrase_order = sorted(by_phrase)
    depth = max((len(bucket) for bucket in by_phrase.values()), default=0)
    return [
        by_phrase[phrase][level]
        for level in range(depth)
        for phrase in phrase_order
        if level < len(by_phrase[phrase])
    ]
def _display_tag(tag: str) -> str:
    """Render a canonical tag for the LLM: every underscore becomes a space.

    Canonical (underscore) tags remain the internal representation.
    """
    return " ".join(tag.split("_"))
def _format_candidates_local(
    cands: Sequence[Candidate],
) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]:
    """Number candidates from 1 and render them as prompt lines.

    Returns:
        (candidate_lines, idx_to_tag, idx_to_candidate): the newline-joined
        "<j>. <display tag>" lines, plus lookups from each 1-based local index
        to the canonical tag and to the Candidate itself.
    """
    numbered = list(enumerate(cands, start=1))
    rendered = "\n".join(f"{j}. {_display_tag(c.tag)}" for j, c in numbered)
    return rendered, {j: c.tag for j, c in numbered}, dict(numbered)
def _phrases_in_call(cands: Sequence[Candidate]) -> int:
    """Count the distinct source phrases across all ``cands``."""
    return len({src for cand in cands for src in cand.sources})
def _parse_validate_map(
    parsed: Any,
    idx_to_tag: Dict[int, str],
    per_call_budget: int,
) -> Tuple[List[Selected], Dict[str, Any]]:
    """Validate a parsed LLM response and map its local indices to tags.

    Accepts a dict or a pydantic model (converted with model_dump()/dict()).
    Items are examined in order until ``per_call_budget`` have been kept.

    Returns:
        (selected, diag). ``diag`` holds:
        parse_ok (response was a dict with a list under "selections"),
        invalid_items (non-dict item, non-int/bool "i", or unknown "why"),
        oob_indices ("i" not in idx_to_tag), dupe_indices ("i" already kept),
        and kept (len(selected)).
    """
    diag: Dict[str, Any] = {
        "parse_ok": isinstance(parsed, dict),
        "invalid_items": 0,
        "oob_indices": 0,
        "dupe_indices": 0,
        "kept": 0,
    }
    if isinstance(parsed, BaseModel):
        parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()
        diag["parse_ok"] = isinstance(parsed, dict)
    if not isinstance(parsed, dict):
        return [], diag

    selections = parsed.get("selections", [])
    if not isinstance(selections, list):
        diag["parse_ok"] = False
        return [], diag

    kept: List[Selected] = []
    accepted: Set[int] = set()
    for item in selections:
        if len(kept) >= per_call_budget:
            break
        if not isinstance(item, dict):
            diag["invalid_items"] += 1
            continue

        i, why = item.get("i"), item.get("why")
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(i, bool) or not isinstance(i, int):
            reject = "invalid_items"
        elif i in accepted:
            reject = "dupe_indices"
        elif i not in idx_to_tag:
            reject = "oob_indices"
        elif not isinstance(why, str) or why not in WHY_ENUM:
            reject = "invalid_items"
        else:
            reject = None

        if reject is not None:
            diag[reject] += 1
            continue
        accepted.add(i)
        kept.append(Selected(i=i, tag=idx_to_tag[i], why=why, score=WHY_TO_SCORE[why]))

    diag["kept"] = len(kept)
    return kept, diag
def _split_candidates_by_type(
    candidates: List[Candidate],
    log,
) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]:
    """Partition candidates into general vs entity (character-only) lists.

    Returns:
        (general_list, entity_list), each a list of (original_index, candidate)
        preserving input order.

    Routing by tag type name (from get_tag_type_name):
        - "character": entity list, except _GENERIC_CHARACTER_TAGS, which go
          to the general list.
        - "copyright": dropped (too broad for image generation).
        - "general", "artist", "species", "meta": general list.
        - anything else (unknown or None): general list, counted as unknown.

    Type ids, for reference: general 0, artist 1, copyright 3, character 4,
    species 5, meta 7.
    """
    general: List[Tuple[int, Candidate]] = []
    entity: List[Tuple[int, Candidate]] = []
    n_unknown = n_copyright = n_generic_char = 0

    for position, cand in enumerate(candidates):
        kind = get_tag_type_name(cand.tag)
        if kind == "copyright":
            n_copyright += 1
            continue
        if kind == "character" and cand.tag not in _GENERIC_CHARACTER_TAGS:
            entity.append((position, cand))
            continue
        # Everything else lands in general selection; tally why for the log.
        general.append((position, cand))
        if kind == "character":
            n_generic_char += 1
        elif kind not in ("general", "artist", "species", "meta"):
            n_unknown += 1

    if log:
        log(
            f"Stage3 split: general={len(general)} entity={len(entity)} "
            f"copyright_filtered={n_copyright} generic_char_to_general={n_generic_char} "
            f"unknown_type={n_unknown}"
        )
    return general, entity
# Regex to strip series/franchise suffixes from aliases, e.g. _(sonic), _(mlp), _(character)
_SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$")

# Sentence punctuation that commonly clings to words in free-text descriptions
# ("Pikachu," / "(Eevee)"). It is stripped from query-word edges so that the
# word-subset rule in _alias_matches_query can still see the bare word.
_QUERY_EDGE_PUNCT = ".,;:!?\"'()[]{}"


def _normalize_for_matching(text: str) -> str:
    """Normalize a tag, alias, or description for character matching.

    Steps:
      1. Lowercase and trim.
      2. Drop a trailing series suffix such as ``_(pokemon)``.
      3. Turn underscores into spaces.
      4. Collapse every whitespace run (including newlines) into one space,
         with no leading or trailing space.

    Args:
        text: Canonical tag/alias (underscore form) or a free-text description.

    Returns:
        The normalized string, or "" if only separators remain.
    """
    text = _SERIES_SUFFIX_RE.sub("", text.lower().strip())
    # Collapsing whitespace keeps substring matching stable across line breaks
    # and doubled separators ("pikachu\nlibre" contains "pikachu libre"). It
    # also makes underscore-only names normalize to "" instead of to
    # whitespace, which would otherwise be a substring of most descriptions.
    return " ".join(text.replace("_", " ").split())


def _query_words(query: str) -> Set[str]:
    """Extract the word set of the user query for alias matching.

    Each whitespace-separated token of the normalized query is included as-is.
    Its variant with surrounding sentence punctuation stripped is also included
    (e.g. both "pikachu." and "pikachu"), so multi-word aliases still match
    when punctuation is attached. Keeping the raw token means this set is a
    superset of the plain split.
    """
    words: Set[str] = set()
    for token in _normalize_for_matching(query).split():
        words.add(token)
        bare = token.strip(_QUERY_EDGE_PUNCT)
        if bare:
            words.add(bare)
    return words
def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str,
                         fuzzy_threshold: int = 85) -> bool:
    """Check whether a normalized alias matches the user query.

    Matching logic, in order:
      0. Empty/whitespace-only alias: never matches.
      1. Exact substring: alias appears as a substring of the query.
      2. Word subset: all words of the alias appear in ``query_words``.
      3. Fuzzy: a single-word alias is compared with each query word using
         rapidfuzz ``fuzz.ratio``. A multi-word alias is compared with the
         whole query using ``fuzz.partial_ratio``. Either matches when the
         score is >= ``fuzzy_threshold``; this handles typos.

    Args:
        alias_norm: Alias already passed through _normalize_for_matching.
        query_words: Word set of the query (see _query_words).
        query_norm: Normalized query text.
        fuzzy_threshold: Minimum rapidfuzz score (0-100) for rule 3.
    """
    alias_words = alias_norm.split()
    # Guard before the substring test: "" (or whitespace) is a substring of
    # almost any string, so an empty alias would otherwise match every query.
    if not alias_words:
        return False

    # Exact substring match
    if alias_norm in query_norm:
        return True

    # Word subset match: all alias words must appear in the query
    if all(w in query_words for w in alias_words):
        return True

    # Fuzzy fallback
    if len(alias_words) == 1:
        target = alias_words[0]
        return any(fuzz.ratio(target, qw) >= fuzzy_threshold for qw in query_words)
    return fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold
def _character_matches_via_aliases(
    tag: str,
    query: str,
    tag2aliases: Dict[str, List[str]],
    query_words: Set[str],
    query_norm: str,
    fuzzy_threshold: int = 85,
) -> bool:
    """Check whether a character tag matches the user query by name or alias.

    The tag name itself is tried first, then every alias registered for it in
    ``tag2aliases``; a missing (or None) entry means no known aliases. Each
    name goes through _normalize_for_matching. Names that normalize to "" are
    skipped, for the tag name as well as for aliases, so they can never match.

    Args:
        tag: Canonical character tag (underscore form).
        query: Raw query text. Not used for matching (the pre-computed
            ``query_words`` / ``query_norm`` are used instead); kept for
            signature compatibility.
        tag2aliases: Mapping from tag to its alias strings.
        query_words: Word set of the query (see _query_words).
        query_norm: Normalized query (see _normalize_for_matching).
        fuzzy_threshold: Minimum rapidfuzz score (0-100) for fuzzy matches.

    Returns:
        True on the first matching name, else False.
    """
    for name in (tag, *(tag2aliases.get(tag) or [])):
        name_norm = _normalize_for_matching(name)
        if name_norm and _alias_matches_query(
            name_norm, query_words, query_norm, fuzzy_threshold
        ):
            return True
    return False
def _normalize_stage3_candidates(
    candidates: Sequence[Any],
) -> Tuple[List[Candidate], Dict[str, int], str, str]:
    """Coerce Stage 3 input rows into de-duplicated Candidate objects.

    Accepts Candidate objects, bare tag strings, or (tag, score) rows. The type
    of the FIRST element decides how every row is read.

    Returns:
        (norm, tag_to_first_index, branch, cand0_type):
        - norm: the first occurrence of each tag, in input order.
        - tag_to_first_index: each tag mapped to that first position in
          ``candidates``.
        - branch, cand0_type: diagnostics used only for logging.

    Raises:
        ValueError: in the (tag, score) branch, for a row that is not a
            list/tuple of length >= 2.
    """
    norm: List[Candidate] = []
    tag_to_first_index: Dict[str, int] = {}
    if not candidates:
        return norm, tag_to_first_index, "empty", "none"

    first = candidates[0]
    cand0_type = type(first).__name__

    def legacy(tag: str, score: float) -> Candidate:
        # Legacy rows carry no count/sources.
        return Candidate(
            tag=tag,
            score_combined=score,
            score_fasttext=None,
            score_context=None,
            count=None,
            sources=[],
        )

    if isinstance(first, Candidate):
        branch = "candidate"
        for idx, c in enumerate(cast(Sequence[Candidate], candidates)):
            if c.tag not in tag_to_first_index:
                tag_to_first_index[c.tag] = idx
                norm.append(c)
    elif isinstance(first, str):
        branch = "string"
        for idx, tag in enumerate(cast(Sequence[str], candidates)):
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
                norm.append(legacy(tag, 0.0))
    else:
        branch = "tuple"
        for idx, row in enumerate(cast(Sequence[Tuple[str, float]], candidates)):
            if not isinstance(row, (list, tuple)) or len(row) < 2:
                raise ValueError("Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples.")
            tag, sim = row[0], row[1]
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
                norm.append(legacy(tag, float(sim)))
    return norm, tag_to_first_index, branch, cand0_type


def llm_select_indices(
    query_text: str,  # kept for compatibility; treated as IMAGE DESCRIPTION
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
    max_pick: int,  # legacy param; applied after union + ordering (optional)
    log,
    retries: int = 2,
    *,
    mode: str = "chunked_map_union",  # "single_shot" or "chunked_map_union"
    chunk_size: int = 60,
    per_phrase_k: int = 2,  # per-call budget = per_phrase_k * phrases_in_call
    temperature: float = 0.0,
    max_tokens: int = 512,
    return_metadata: bool = False,
    min_why: Optional[str] = "strong_implied",
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
    """Select relevant candidate tags with the LLM (Stage 3, LangChain only).

    Returns indices into the ORIGINAL ``candidates`` sequence (legacy
    interface). They are ordered by the derived "why" score (desc), then by
    candidate count (desc; counts are never shown to the LLM).

    Routing:
      - Character tags that pass the alias pre-filter are selected with
        ENTITY_SYSTEM_TEMPLATE.
      - All other tags except copyright tags are selected with
        SELECT_SYSTEM_TEMPLATE.
      - Results of all calls are unioned per tag, keeping the highest-scoring
        "why".

    Args:
        query_text: Image description (the original prompt).
        candidates: Candidate objects, tag strings, or (tag, score) rows; the
            first element's type decides how all rows are read. Duplicate tags
            keep their first position.
        max_pick: If a positive int, cap on the number of returned indices,
            applied after union + ordering.
        log: Optional callable taking a string, used for diagnostics.
        retries: Extra attempts per LLM call after the first. A call is
            retried when it raises or yields no valid selections.
        mode: "single_shot" (one call per tag group) or "chunked_map_union"
            (one call per ``chunk_size`` slice of each group).
        chunk_size: Slice size for "chunked_map_union"; must be >= 1.
        per_phrase_k: Per-call budget per distinct source phrase. Calls without
            phrase info count as one phrase, and the budget is never below 1.
        temperature, max_tokens: Forwarded to _get_llm.
        return_metadata: If True, also return a {tag: why} mapping.
        min_why: Keep only tags whose "why" is at or above this confidence
            level. For example, "explicit" keeps only explicit matches, and
            "strong_implied" (the default) keeps explicit + strong_implied.
            None disables the filter.

    Returns:
        ``out_idx``, or ``(out_idx, tag_why)`` when ``return_metadata`` is True.

    Raises:
        ValueError: for an unknown ``mode`` or ``min_why``, a ``chunk_size``
            below 1 in chunked mode, or a malformed legacy row.
        RuntimeError: from _get_llm when OPENROUTER_API_KEY is not set (only
            reached when there is at least one candidate).
    """
    # Fail fast on bad configuration, before any candidate work or LLM calls.
    if mode not in ("single_shot", "chunked_map_union"):
        raise ValueError(f"Invalid mode: {mode}")
    if mode == "chunked_map_union" and chunk_size < 1:
        # A zero step makes range() raise; a negative one silently skips every call.
        raise ValueError(f"Invalid chunk_size: {chunk_size} (must be >= 1)")
    if min_why is not None and min_why not in WHY_RANK:
        # An unknown level would otherwise fall back to the lowest rank and
        # silently disable confidence filtering.
        raise ValueError(f"Invalid min_why: {min_why!r} (expected one of {WHY_ENUM} or None)")

    image_description = query_text
    norm, tag_to_first_index, branch, cand0_type = _normalize_stage3_candidates(candidates)

    if log:
        if norm:
            log(
                "Stage3 input: "
                f"type0={cand0_type} "
                f"branch={branch} "
                f"norm0_score={norm[0].score_combined!r} "
                f"norm0_sources_empty={not bool(norm[0].sources)}"
            )
        else:
            log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)")

    if not norm:
        # Nothing to select: don't build an LLM client (which requires
        # OPENROUTER_API_KEY) just to return an empty result.
        return ([], {}) if return_metadata else []

    response_format = _build_response_format()
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens, response_format=response_format)
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
    parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse)

    # Global union across all calls: tag -> best (score, why).
    best: Dict[str, Tuple[float, str]] = {}

    def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
        """Run one selection call over ``call_cands`` and merge hits into ``best``.

        Retries up to ``retries`` extra times when the call raises or yields no
        valid selections. After that it gives up, logging only.
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", USER_TEMPLATE),
            ],
            template_format="f-string",
        )
        chain = prompt | llm | parser

        ordered = _interleave_round_robin(call_cands)
        candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
        N_local = len(idx_to_tag)
        phrases = _phrases_in_call(call_cands)
        # per_phrase_k picks per distinct source phrase. Calls without phrase
        # info (e.g. legacy string/tuple inputs) count as one phrase, and the
        # budget never drops below 1 (a 0 budget would discard every pick).
        per_call_budget = max(1, per_phrase_k * max(phrases, 1))
        summary_logged = False

        if log:
            log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}")
            if phrases > 0:
                distinct_phrases = sorted({src for c in call_cands for src in c.sources})
                log(
                    f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} "
                    f"phrases={', '.join(distinct_phrases)}"
                )

        # Invoke the LangChain chain (templating fills {N} and the other vars).
        for att in range(retries + 1):
            try:
                if log:
                    log(
                        f"Stage3 {label}: "
                        f"model={model_name} "
                        f"N={N_local} "
                        f"phrases={phrases} "
                        f"per_call_budget={per_call_budget} "
                        f"response_healing=on"
                    )
                parsed = chain.invoke(
                    {
                        "N": N_local,
                        "image_description": image_description,
                        "candidate_lines": candidate_lines,
                        "per_call_budget": per_call_budget,
                    }
                )
                selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
                if log:
                    log(f"Stage3 {label}: attempt {att+1} diag={diag}")
                    if not summary_logged and (selected or att == retries):
                        log(
                            f"Stage3 {label}: summary "
                            f"N={N_local} selected={len(selected)} per_call_budget={per_call_budget}"
                        )
                        summary_logged = True
                    if selected:
                        lines = [
                            f"Stage3 {label} selections:",
                            *[
                                (
                                    f' - i={s.i} tag="{s.tag}" '
                                    f"why={s.why} score={s.score:.2f} "
                                    f"sources={idx_to_candidate.get(s.i).sources if idx_to_candidate.get(s.i) else []}"
                                )
                                for s in selected
                            ],
                        ]
                        log("\n".join(lines))
                    else:
                        log(f"Stage3 {label} selections: (none)")
                if selected:
                    for s in selected:
                        prev = best.get(s.tag)
                        if prev is None or s.score > prev[0]:
                            best[s.tag] = (s.score, s.why)
                    return
            except Exception as e:
                if log:
                    log(f"Stage3 {label}: attempt {att+1} error: {e}")
        if log:
            log(f"Stage3 {label}: gave up after {retries+1} attempts")

    def run_group(group: List[Candidate], prefix: str, system_template: str) -> None:
        # One call for the whole group, or one call per chunk_size slice.
        if mode == "single_shot":
            run_call(group, f"{prefix}_single_shot", system_template)
            return
        for start in range(0, len(group), chunk_size):
            run_call(
                group[start:start + chunk_size],
                f"{prefix}_chunk_{start // chunk_size}",
                system_template,
            )

    # Split candidates by type (general vs entity); copyright tags are dropped.
    general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)
    general_cands = [cand for _, cand in general_with_idx]
    entity_cands = [cand for _, cand in entity_with_idx]

    # General candidates (attributes, actions, species, etc.).
    if general_cands:
        run_group(general_cands, "general", SELECT_SYSTEM_TEMPLATE)

    # Entity candidates (characters only), pre-filtered by name/alias matching.
    if entity_cands:
        tag2aliases = get_tag2aliases()
        qwords = _query_words(image_description)
        qnorm = _normalize_for_matching(image_description)
        filtered_entity_cands: List[Candidate] = []
        filtered_out: List[str] = []
        for cand in entity_cands:
            if _character_matches_via_aliases(
                cand.tag, image_description, tag2aliases, qwords, qnorm
            ):
                filtered_entity_cands.append(cand)
            else:
                filtered_out.append(cand.tag)
        if log:
            log(
                f"Stage3 entity alias filter: "
                f"before={len(entity_cands)} "
                f"after={len(filtered_entity_cands)} "
                f"removed={len(filtered_out)}"
            )
            if filtered_out:
                log(f"Stage3 entity alias filter removed: {filtered_out[:20]}")
        if filtered_entity_cands:
            run_group(filtered_entity_cands, "entity", ENTITY_SYSTEM_TEMPLATE)

    # Apply the why threshold: drop tags below the minimum confidence level.
    if min_why is not None:
        max_rank = WHY_RANK[min_why]
        before = len(best)
        best = {t: v for t, v in best.items() if WHY_RANK.get(v[1], 4) <= max_rank}
        if log:
            log(f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
                f"before={before} after={len(best)} dropped={before - len(best)}")

    # Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
    count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
    ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)

    # Legacy cap: apply AFTER union + ordering.
    if isinstance(max_pick, int) and max_pick > 0:
        ordered_tags = ordered_tags[:max_pick]

    # Map back to original indices.
    out_idx: List[int] = []
    tag_why: Dict[str, str] = {}
    for t in ordered_tags:
        if t in tag_to_first_index:
            out_idx.append(tag_to_first_index[t])
            tag_why[t] = best[t][1]  # why string
    if return_metadata:
        return out_idx, tag_why
    return out_idx
# ---------------------------------------------------------------------------
# Stage 3s: Structural tag inference (character count, body plan, gender:
# solo/duo/male/female/anthro …)
# ---------------------------------------------------------------------------
# Each statement maps to exactly one tag. The LLM picks statement numbers,
# which are the 1-based positions in this list.
# NOTE: STRUCTURAL_SYSTEM_TEMPLATE's worked example hard-codes this numbering
# (1-12, in this order). Update it if statements are added, removed, or
# reordered.
_STRUCTURAL_STATEMENTS: List[Tuple[str, str]] = [
    # Character count — exactly one should be picked
    ("No characters or living beings appear in the image", "zero_pictured"),
    ("There is exactly one character in the image", "solo"),
    ("There are exactly two characters in the image", "duo"),
    ("There are exactly three characters in the image", "trio"),
    ("There are four or more characters in the image", "group"),
    # Body plan — pick all that apply across characters
    ("A character is a normal animal walking on all fours, not humanized", "feral"),
    ("A character is an animal with a human-like body (standing upright on two legs, with hands)", "anthro"),
    ("A character is a human or looks fully human", "humanoid"),
    # Gender — pick all that apply across characters
    ("A male character is shown", "male"),
    ("A female character is shown", "female"),
    ("A character's gender cannot be determined from the description", "ambiguous_gender"),
    ("An intersex or hermaphrodite character is shown", "intersex"),
]
# System prompt for structural inference (rendered with LangChain's f-string
# engine: {N} is a template variable, literal JSON braces are doubled).
# The format line uses an <int> placeholder, like the Stage 3 schemas. It
# previously showed a concrete answer selecting statements 1 and 5, two
# mutually exclusive character-count statements, which contradicted rule 3
# and could bias a small model toward that invalid pair.
# The worked example's numbering mirrors _STRUCTURAL_STATEMENTS order.
STRUCTURAL_SYSTEM_TEMPLATE = """You classify image descriptions. You will read a description of an image, then select which numbered statements are true about it.
IMPORTANT RULES:
1. ONLY select a statement if the description directly says it or makes it very obvious.
2. Do NOT guess or assume anything the description does not say.
3. Select exactly ONE statement from the character count group (statements about how many characters there are).
4. Select ALL statements that apply from the body type and gender groups.
5. If the description does not mention gender at all, select the "gender cannot be determined" statement.
Return JSON matching this exact format — nothing else:
{{"selections": [{{"i": <int>}}, ...]}}
where each "i" is a statement number from 1 to {N}.
EXAMPLE:
Description: "A muscular male wolf standing in a forest, giving a thumbs up"
Statements: 1. No characters 2. Exactly one character 3. Exactly two 4. Exactly three 5. Four or more 6. Normal animal on all fours 7. Animal with human-like body 8. Human 9. Male shown 10. Female shown 11. Gender unknown 12. Intersex shown
Correct answer: {{"selections": [{{"i": 2}}, {{"i": 7}}, {{"i": 9}}]}}
Reasoning: One character (2), wolf standing upright with hands giving thumbs up = animal with human body (7), described as male (9)."""
# Human-message template for structural inference; filled with the image
# description and the numbered statement lines by llm_infer_structural_tags.
STRUCTURAL_USER_TEMPLATE = """Read this image description and select which statements are true.
IMAGE DESCRIPTION:
{image_description}
STATEMENTS (pick by number):
{statement_lines}"""
# One structural pick: a 1-based statement number into _STRUCTURAL_STATEMENTS.
# Mirrors the item schema in _build_structural_response_format. Comments are
# used instead of docstrings, which pydantic would fold into the JSON schema.
class StructuralSelectionItem(BaseModel):
    i: int = Field(..., description="1-based index into the statement list.")
# Top-level structural response; the target model of the PydanticOutputParser
# used by llm_infer_structural_tags. A missing "selections" key parses as [].
class StructuralSelectionResponse(BaseModel):
    selections: List[StructuralSelectionItem] = Field(default_factory=list)
def _build_structural_response_format() -> Dict[str, Any]:
    """Return the strict ``response_format`` for the structural (Stage 3s) call.

    The schema admits exactly ``{"selections": [{"i": <integer>}, ...]}``, with
    no additional keys at either level. A new dict is built on every call.
    """
    statement_pick = {
        "type": "object",
        "properties": {"i": {"type": "integer"}},
        "required": ["i"],
        "additionalProperties": False,
    }
    envelope = {
        "type": "object",
        "properties": {"selections": {"type": "array", "items": statement_pick}},
        "required": ["selections"],
        "additionalProperties": False,
    }
    return {
        "type": "json_schema",
        "json_schema": {"name": "structural_selection", "strict": True, "schema": envelope},
    }
# Character-count tags among _STRUCTURAL_STATEMENTS. The prompt requires
# exactly one of them, so llm_infer_structural_tags keeps only the first one
# the model selects.
_STRUCTURAL_COUNT_TAGS = frozenset({"zero_pictured", "solo", "duo", "trio", "group"})


def llm_infer_structural_tags(
    query_text: str,
    log=None,
    *,
    temperature: float = 0.0,
    max_tokens: int = 256,
    retries: int = 2,
) -> List[str]:
    """Infer structural tags (character count, body plan, gender) via the LLM.

    Instead of retrieving these from a candidate list, the LLM is asked which
    natural-language statements in _STRUCTURAL_STATEMENTS are true of the
    image description. Each chosen statement maps to one e621 tag. This
    handles tags that are almost never stated in captions but are
    structurally obvious.

    Post-processing of the model's answer:
      - Numbers that are not ints (bools included) or fall outside 1..N are
        ignored.
      - Repeated tags are dropped, keeping first-selected order.
      - At most one character-count tag (zero_pictured/solo/duo/trio/group) is
        kept: the first one selected. The prompt demands exactly one, and
        contradictory counts such as solo + duo must not be emitted together.

    Args:
        query_text: Image description.
        log: Optional callable taking a string, used for diagnostics.
        temperature, max_tokens: Forwarded to _get_llm.
        retries: Extra attempts after the first when the call or parsing raises.

    Returns:
        Tags such as ["solo", "anthro", "male"], or [] if every attempt raised.

    Raises:
        RuntimeError: from _get_llm when OPENROUTER_API_KEY is not set.
    """
    if log:
        log("Stage3s (structural): inferring structural tags via statement agreement")

    statements = _STRUCTURAL_STATEMENTS
    N = len(statements)
    statement_lines = "\n".join(f"{j}. {stmt}" for j, (stmt, _tag) in enumerate(statements, 1))

    response_format = _build_structural_response_format()
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens,
                   response_format=response_format)
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
    parser = PydanticOutputParser(pydantic_object=StructuralSelectionResponse)
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", STRUCTURAL_SYSTEM_TEMPLATE),
            ("human", STRUCTURAL_USER_TEMPLATE),
        ],
        template_format="f-string",
    )
    chain = prompt | llm | parser
    if log:
        log(f"Stage3s: model={model_name} statements={N}")

    for att in range(retries + 1):
        try:
            parsed = chain.invoke({
                "N": N,
                "image_description": query_text,
                "statement_lines": statement_lines,
            })
            if isinstance(parsed, BaseModel):
                parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()
            sels = parsed.get("selections", []) if isinstance(parsed, dict) else []

            chosen_tags: List[str] = []
            seen: Set[str] = set()
            count_tag: Optional[str] = None
            dropped_counts: List[str] = []
            for item in sels:
                idx = item.get("i") if isinstance(item, dict) else None
                # bool is an int subclass; reject it as _parse_validate_map does.
                if isinstance(idx, bool) or not isinstance(idx, int) or not 1 <= idx <= N:
                    continue
                tag = statements[idx - 1][1]
                if tag in seen:
                    continue
                seen.add(tag)
                if tag in _STRUCTURAL_COUNT_TAGS:
                    if count_tag is not None:
                        dropped_counts.append(tag)
                        continue
                    count_tag = tag
                chosen_tags.append(tag)

            if log:
                if dropped_counts:
                    log(
                        f"Stage3s: attempt {att+1} dropped conflicting count tags "
                        f"{dropped_counts} (kept {count_tag})"
                    )
                tag_str = ", ".join(chosen_tags) if chosen_tags else "(none)"
                log(f"Stage3s: attempt {att+1} selected {len(chosen_tags)} tags: {tag_str}")
            return chosen_tags
        except Exception as e:
            if log:
                log(f"Stage3s: attempt {att+1} error: {e}")
    if log:
        log(f"Stage3s: gave up after {retries+1} attempts")
    return []