Spaces:
Running
Running
File size: 39,089 Bytes
c6be992 684cf99 c6be992 349b999 c6be992 09a248d c6be992 349b999 c6be992 349b999 c6be992 349b999 c6be992 349b999 4968635 349b999 c6be992 09a248d 4968635 09a248d c6be992 09a248d c6be992 349b999 c6be992 349b999 c6be992 a16e111 684cf99 a16e111 684cf99 f4f71fe 684cf99 f4f71fe 684cf99 f4f71fe 684cf99 f4f71fe 684cf99 a16e111 684cf99 a16e111 46fe384 684cf99 f4f71fe a16e111 684cf99 46fe384 a16e111 46fe384 684cf99 a16e111 46fe384 a16e111 46fe384 a16e111 684cf99 a16e111 684cf99 a16e111 684cf99 a16e111 684cf99 a16e111 684cf99 a16e111 684cf99 a16e111 684cf99 a16e111 684cf99 a16e111 684cf99 a16e111 684cf99 a16e111 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 
375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 
875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 | # psq_rag/llm/select.py
# Stage 3: Closed-Set Selection (LangChain-only implementation)
#
# This module intentionally uses LangChain for:
# - prompt templating (including {N})
# - LLM call orchestration
# - JSON parsing
#
# There is NO fallback path. If LangChain dependencies are missing, this module
# should fail loudly so you install them.
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field, SecretStr
from rapidfuzz import fuzz
from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources)
from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases
# Tags whose type is "character" but which name a generic category rather than
# a specific named character. Common caption words match them, so they slip
# through the alias filter; they are kept out of the entity pipeline and handled
# by general selection instead.
_GENERIC_CHARACTER_TAGS = frozenset(
    (
        "fan_character",
        "background_character",
        "unnamed_character",
        "unknown_character",
        "anonymous_character",
        "viewer",
        "original_character",
    )
)
# Allowed rationale codes, listed from most to least confident.
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]

# Ordinal rank per rationale code (lower = more confident); used for threshold
# filtering. The rank is simply the position of the code in WHY_ENUM.
WHY_RANK: Dict[str, int] = {label: rank for rank, label in enumerate(WHY_ENUM)}

# Deterministic rationale -> numeric score mapping, used for ordering and debug output.
WHY_TO_SCORE: Dict[str, float] = dict(
    explicit=0.90,
    strong_implied=0.70,
    weak_implied=0.45,
    style_or_meta=0.35,
    other=0.25,
)
# IMPORTANT ABOUT TEMPLATING:
# - These strings are rendered by LangChain's f-string template engine
#   (ChatPromptTemplate with template_format="f-string").
# - Literal JSON braces must be escaped as {{ and }}.
# - {N} is a real template variable and MUST be provided at invoke time.
#
# System prompt for general (non-character) tag selection. The model answers
# with candidate indices plus a "why" code; the allowed codes listed here must
# stay in sync with WHY_ENUM.
SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags.
Select the tags that correspond to content that would be visible or depicted in the described image.
The list contains only valid tags; many of them are irrelevant to the image.
Return JSON ONLY matching this schema:
{{
\"selections\": [
{{\"i\": <int>, \"why\": \"<one of: explicit|strong_implied|weak_implied|style_or_meta|other>\"}},
...
]
}}
Rules:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Do NOT output any keys other than \"selections\", and inside each item only the item index \"i\" and \"why\".
- Do select both a general tag and a more specific tag when both apply (for example, \"shirt\" and \"grey shirt\").
Define \"why\" as:
- explicit: directly stated in the image description
- strong_implied: very likely given the description, even if not literally stated
- weak_implied: plausible but not strongly supported by the description
- style_or_meta: stylistic or presentation-related tags only if clearly indicated
- other: fallback category; use sparingly
"""
# System prompt for character (entity) selection. It tells the model that the
# candidates were already pre-filtered by name/alias and asks it to always
# answer with why="explicit". Uses the same {N} variable and brace escaping as
# the general prompt.
ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags.
These character tags have already been pre-filtered to only include characters whose names
(or known aliases) appear in the image description. Your job is to confirm which of these
pre-filtered candidates are the correct match for the character mentioned by the user.
Return JSON ONLY matching this schema:
{{
\"selections\": [
{{\"i\": <int>, \"why\": \"explicit\"}},
...
]
}}
Rules for character selection:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Always use \"why\": \"explicit\" for all selections.
- Select the tag that best represents the character as described.
- If the user described a specific variant (e.g. \"pikachu libre\", \"detective pikachu\"),
select that specific variant tag.
- If the user described only the base character (e.g. just \"pikachu\"), select only the
base/default tag, NOT costume or variant tags.
- When uncertain between variants, prefer the simplest/most general tag.
"""
# Human-message template shared by both system prompts. Template variables:
# {image_description}, {candidate_lines} (the numbered candidate list) and
# {per_call_budget} (upper bound on how many indices the model may return).
USER_TEMPLATE = """IMAGE DESCRIPTION:
{image_description}
CANDIDATES (choose by index only):
{candidate_lines}
Select up to {per_call_budget} indices. Output fewer if uncertain.
"""
@dataclass(frozen=True)
class Selected:
    """One validated selection returned by the LLM for a single call."""

    # 1-based index into the candidate list shown to the LLM in that call.
    i: int
    # Canonical tag (underscore form).
    tag: str
    # Rationale code; one of WHY_ENUM.
    why: str
    # Numeric score derived from the rationale code.
    score: float
# Closed set of rationale codes accepted in a parsed response (mirrors WHY_ENUM).
WhyLiteral = Literal[
    "explicit",
    "strong_implied",
    "weak_implied",
    "style_or_meta",
    "other",
]


class Stage3SelectionItem(BaseModel):
    """A single item of the Stage 3 response: a candidate index and its rationale."""

    i: int = Field(..., description="1-based index into the candidate list.")
    why: WhyLiteral = Field(..., description="Rationale code from the allowed set.")
class Stage3SelectionResponse(BaseModel):
    """Top-level Stage 3 response object: a (possibly empty) list of selections."""

    selections: List[Stage3SelectionItem] = Field(default_factory=list)
def _build_response_format() -> Dict[str, Any]:
    """Return the strict JSON-Schema ``response_format`` payload for Stage 3.

    The schema describes an object with a single required ``selections`` array
    whose items carry exactly an integer ``i`` and a ``why`` string restricted
    to ``WHY_ENUM``; extra keys are disallowed at both levels.
    """
    selection_item = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
            "why": {"type": "string", "enum": WHY_ENUM},
        },
        "required": ["i", "why"],
        "additionalProperties": False,
    }
    root_schema = {
        "type": "object",
        "properties": {
            "selections": {"type": "array", "items": selection_item},
        },
        "required": ["selections"],
        "additionalProperties": False,
    }
    json_schema = {
        "name": "stage3_selection",
        "strict": True,
        "schema": root_schema,
    }
    return {"type": "json_schema", "json_schema": json_schema}
def _get_llm(*, temperature: float, max_tokens: int, response_format: Dict[str, Any]) -> ChatOpenAI:
    """Build a ``ChatOpenAI`` client that talks to OpenRouter.

    Configuration is read from the environment:
      - ``OPENROUTER_API_KEY`` (required)
      - ``OPENROUTER_MODEL`` (optional; defaults to a Llama 3.1 8B instruct model)
      - ``OPENROUTER_HTTP_REFERER`` / ``OPENROUTER_X_TITLE`` (optional headers)

    Args:
        temperature: Sampling temperature forwarded to the client.
        max_tokens: Completion token cap forwarded to the client.
        response_format: Structured-output payload sent in the request body.

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is unset or empty.
    """
    raw_key = os.getenv("OPENROUTER_API_KEY")
    if not raw_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set.\n"
            "Set it in your environment before running Stage 3."
        )

    # Optional attribution headers: only sent when the env var has a value.
    optional_headers = (
        ("HTTP-Referer", "OPENROUTER_HTTP_REFERER"),
        ("X-Title", "OPENROUTER_X_TITLE"),
    )
    headers: Dict[str, str] = {
        header: value
        for header, env_name in optional_headers
        if (value := os.getenv(env_name))
    }

    # Provider-specific request body fields (OpenAI-compatible).
    # The response-healing plugin reduces malformed-JSON failures (syntax only).
    request_extras = {
        "response_format": response_format,
        "plugins": [{"id": "response-healing"}],
    }

    # OpenRouter's OpenAI-compatible endpoint.
    return ChatOpenAI(
        model=os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct"),
        base_url="https://openrouter.ai/api/v1",
        api_key=SecretStr(raw_key),
        temperature=temperature,
        max_completion_tokens=max_tokens,
        default_headers=headers,
        extra_body=request_extras,
    )
def _phrase_key_for_candidate(c: Candidate) -> str:
    """Return the candidate's grouping key: its smallest source phrase, or ""."""
    # min() picks the same element as sorted(...)[0], keeping the key deterministic.
    return min(c.sources) if c.sources else ""
def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
    """Round-robin interleave candidates by their primary source phrase.

    Candidates are grouped by ``_phrase_key_for_candidate``. Inside each group
    they are ordered by ``score_combined`` and then ``count`` (both descending).
    The output takes the best remaining candidate of every group in turn,
    visiting groups in sorted key order, until all groups are exhausted.

    NOTE: counts are used only for ordering; they are NOT shown to the LLM.

    Args:
        cands: Candidates for one LLM call.

    Returns:
        A new list containing every input candidate exactly once.
    """
    groups: Dict[str, List[Candidate]] = {}
    for cand in cands:
        groups.setdefault(_phrase_key_for_candidate(cand), []).append(cand)

    def _rank(cand: Candidate) -> Tuple[Any, int]:
        # Only a *missing* count maps to the -1 sentinel. The previous
        # ``count or -1`` also collapsed a real count of 0 to -1.
        return cand.score_combined, (cand.count if cand.count is not None else -1)

    for members in groups.values():
        members.sort(key=_rank, reverse=True)

    keys = sorted(groups)
    depth = max((len(members) for members in groups.values()), default=0)

    out: List[Candidate] = []
    for position in range(depth):
        for key in keys:
            members = groups[key]
            if position < len(members):
                out.append(members[position])
    return out
def _display_tag(tag: str) -> str:
    """Return *tag* in its LLM-facing form, with every underscore shown as a space.

    Only the displayed text changes; callers keep the canonical underscore form.
    """
    return " ".join(tag.split("_"))
def _format_candidates_local(
    cands: Sequence[Candidate],
) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]:
    """Number candidates from 1 and render them as prompt lines.

    Returns:
        A tuple ``(text, idx_to_tag, idx_to_candidate)`` where ``text`` has one
        ``"<index>. <tag with spaces>"`` line per candidate and both dicts are
        keyed by the same 1-based local index.
    """
    idx_to_candidate: Dict[int, Candidate] = dict(enumerate(cands, start=1))
    idx_to_tag: Dict[int, str] = {
        number: cand.tag for number, cand in idx_to_candidate.items()
    }
    rendered = "\n".join(
        f"{number}. {_display_tag(tag)}" for number, tag in idx_to_tag.items()
    )
    return rendered, idx_to_tag, idx_to_candidate
def _phrases_in_call(cands: Sequence[Candidate]) -> int:
    """Count the distinct source phrases across all candidates of one call."""
    return len({phrase for cand in cands for phrase in cand.sources})
def _parse_validate_map(
    parsed: Any,
    idx_to_tag: Dict[int, str],
    per_call_budget: int,
) -> Tuple[List[Selected], Dict[str, Any]]:
    """Validate a parsed LLM response and map its indices back to tags.

    Args:
        parsed: Parser output; either a pydantic model or a plain dict.
        idx_to_tag: Local 1-based index -> canonical tag for this call.
        per_call_budget: Maximum number of selections to keep.

    Returns:
        ``(selected, diag)``. ``selected`` holds the accepted items in response
        order. ``diag`` counts what was rejected: ``invalid_items`` (wrong
        shape/type or unknown ``why``), ``oob_indices`` (index not offered),
        ``dupe_indices`` (index already accepted), plus ``parse_ok`` and ``kept``.
    """
    if isinstance(parsed, BaseModel):
        # Prefer model_dump() when the model has it; otherwise fall back to dict().
        parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()

    diag: Dict[str, Any] = {
        "parse_ok": isinstance(parsed, dict),
        "invalid_items": 0,
        "oob_indices": 0,
        "dupe_indices": 0,
        "kept": 0,
    }
    if not diag["parse_ok"]:
        return [], diag

    selections = parsed.get("selections", [])
    if not isinstance(selections, list):
        diag["parse_ok"] = False
        return [], diag

    kept: List[Selected] = []
    accepted_indices = set()
    for entry in selections:
        # Stop as soon as the budget is filled; later items are not inspected.
        if len(kept) >= per_call_budget:
            break
        if not isinstance(entry, dict):
            diag["invalid_items"] += 1
            continue

        index, reason = entry.get("i"), entry.get("why")
        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(index, bool) or not isinstance(index, int):
            diag["invalid_items"] += 1
        elif index in accepted_indices:
            diag["dupe_indices"] += 1
        elif index not in idx_to_tag:
            diag["oob_indices"] += 1
        elif not isinstance(reason, str) or reason not in WHY_ENUM:
            diag["invalid_items"] += 1
        else:
            accepted_indices.add(index)
            kept.append(
                Selected(
                    i=index,
                    tag=idx_to_tag[index],
                    why=reason,
                    score=WHY_TO_SCORE[reason],
                )
            )

    diag["kept"] = len(kept)
    return kept, diag
def _split_candidates_by_type(
    candidates: List[Candidate],
    log,
) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]:
    """Partition candidates into general and entity (character) lists.

    Routing by the tag type name reported by ``get_tag_type_name``:
      - ``"character"``: entity list, unless the tag is one of
        ``_GENERIC_CHARACTER_TAGS``, which goes to the general list.
      - ``"copyright"``: dropped (series tags are too broad for image generation).
      - ``"general"``, ``"artist"``, ``"species"``, ``"meta"``: general list.
      - anything else (including ``None``): general list, counted as unknown.

    Args:
        candidates: Candidates to split.
        log: Optional callable receiving one summary line; falsy disables logging.

    Returns:
        ``(general_list, entity_list)`` where each item is
        ``(index_in_candidates, candidate)``.
    """
    known_general_types = ("general", "artist", "species", "meta")

    general_with_idx: List[Tuple[int, Candidate]] = []
    entity_with_idx: List[Tuple[int, Candidate]] = []
    tally = {"unknown": 0, "copyright": 0, "generic_char": 0}

    for position, cand in enumerate(candidates):
        kind = get_tag_type_name(cand.tag)

        if kind == "copyright":
            tally["copyright"] += 1
            continue

        if kind == "character" and cand.tag not in _GENERIC_CHARACTER_TAGS:
            entity_with_idx.append((position, cand))
            continue

        # Everything that reaches this point is handled by general selection.
        general_with_idx.append((position, cand))
        if kind == "character":
            tally["generic_char"] += 1
        elif kind not in known_general_types:
            tally["unknown"] += 1

    if log:
        log(
            f"Stage3 split: "
            f"general={len(general_with_idx)} "
            f"entity={len(entity_with_idx)} "
            f"copyright_filtered={tally['copyright']} "
            f"generic_char_to_general={tally['generic_char']} "
            f"unknown_type={tally['unknown']}"
        )
    return general_with_idx, entity_with_idx
# Matches a trailing series/franchise suffix such as _(sonic), _(mlp) or
# _(character); the underscore before the parenthesis is required.
_SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$")


def _normalize_for_matching(text: str) -> str:
    """Normalize a tag, alias or query for comparison.

    Steps, in order: lowercase and trim surrounding whitespace, drop a trailing
    ``_(series)`` suffix, then turn the remaining underscores into spaces.
    """
    lowered = text.lower().strip()
    without_suffix = _SERIES_SUFFIX_RE.sub("", lowered)
    return without_suffix.replace("_", " ")
def _query_words(query: str) -> Set[str]:
    """Return the set of whitespace-separated words of the normalized query."""
    normalized = _normalize_for_matching(query)
    return {word for word in normalized.split()}
def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str,
                         fuzzy_threshold: int = 85) -> bool:
    """Check whether a normalized alias matches the user query.

    Matching is attempted in this order and stops at the first hit:
      1. Exact substring: the alias appears as a substring of the query.
      2. Word subset: every word of the alias is one of the query words.
      3. Fuzzy: a single-word alias is close to some query word
         (``fuzz.ratio``); a multi-word alias is close to part of the whole
         query (``fuzz.partial_ratio``). This tolerates typos.

    Args:
        alias_norm: Alias already passed through ``_normalize_for_matching``.
        query_words: Words of the normalized query.
        query_norm: The normalized query as one string.
        fuzzy_threshold: Minimum rapidfuzz score (0-100) for a fuzzy match.

    Returns:
        True if the alias matches. An empty or whitespace-only alias never matches.
    """
    alias_words = alias_norm.split()
    # This guard must come before the substring test: "" (and a lone space) is
    # a substring of almost any query, so a blank alias would match everything.
    if not alias_words:
        return False

    # 1. Exact substring match.
    if alias_norm in query_norm:
        return True

    # 2. Word subset match: all alias words must appear in the query.
    if all(word in query_words for word in alias_words):
        return True

    # 3. Fuzzy match.
    if len(alias_words) == 1:
        only_word = alias_words[0]
        return any(
            fuzz.ratio(only_word, query_word) >= fuzzy_threshold
            for query_word in query_words
        )
    return fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold
def _character_matches_via_aliases(
    tag: str,
    query: str,
    tag2aliases: Dict[str, List[str]],
    query_words: Set[str],
    query_norm: str,
    fuzzy_threshold: int = 85,
) -> bool:
    """Decide whether a character tag is mentioned in the user query.

    The tag matches when its own normalized name matches the query, or when at
    least one of its registered aliases does. A tag with no aliases is still
    checked by name.

    Args:
        tag: Canonical character tag.
        query: Raw query text; not read here, matching uses the two
            pre-normalized forms below.
        tag2aliases: Mapping from tag to its alias list.
        query_words: Words of the normalized query.
        query_norm: The normalized query as one string.
        fuzzy_threshold: Forwarded to ``_alias_matches_query``.
    """
    # The tag's own name is tried first.
    own_name = _normalize_for_matching(tag)
    if _alias_matches_query(own_name, query_words, query_norm, fuzzy_threshold):
        return True

    # Then every registered alias; aliases that normalize to "" are skipped.
    normalized_aliases = (
        _normalize_for_matching(alias) for alias in tag2aliases.get(tag, [])
    )
    return any(
        _alias_matches_query(alias_norm, query_words, query_norm, fuzzy_threshold)
        for alias_norm in normalized_aliases
        if alias_norm
    )
def _normalize_stage3_candidates(
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
) -> Tuple[List[Candidate], Dict[str, int], str]:
    """Convert any accepted candidate input shape into ``Candidate`` objects.

    The input shape is decided by the type of the first element.

    Returns:
        ``(norm, tag_to_first_index, branch)`` where ``norm`` has one Candidate
        per input row, ``tag_to_first_index`` maps each tag to the index of its
        first occurrence in the input, and ``branch`` is one of ``"empty"``,
        ``"candidate"``, ``"string"`` or ``"tuple"`` (used for logging only).

    Raises:
        ValueError: If a row of tuple-style input is not a sequence of length >= 2.
    """
    norm: List[Candidate] = []
    tag_to_first_index: Dict[str, int] = {}
    if not candidates:
        return norm, tag_to_first_index, "empty"

    first = candidates[0]
    if isinstance(first, Candidate):
        branch = "candidate"
    elif isinstance(first, str):
        branch = "string"
    else:
        branch = "tuple"

    for idx, row in enumerate(candidates):
        if branch == "candidate":
            cand = cast(Candidate, row)
        else:
            if branch == "string":
                # Legacy input: bare tags, no score/count/sources available.
                tag, score = row, 0.0
            else:
                if not isinstance(row, (list, tuple)) or len(row) < 2:
                    raise ValueError(
                        "Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples."
                    )
                # Legacy input: (tag, similarity) pairs.
                tag, score = row[0], float(row[1])
            cand = Candidate(
                tag=tag,
                score_combined=score,
                score_fasttext=None,
                score_context=None,
                count=None,
                sources=[],
            )
        tag_to_first_index.setdefault(cand.tag, idx)
        norm.append(cand)

    return norm, tag_to_first_index, branch


def llm_select_indices(
    query_text: str,  # kept for compatibility; treated as IMAGE DESCRIPTION
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
    max_pick: int,  # legacy param; applied after union + ordering (optional)
    log,
    retries: int = 2,
    *,
    mode: str = "chunked_map_union",  # "single_shot" or "chunked_map_union"
    chunk_size: int = 60,
    per_phrase_k: int = 2,  # per-call budget = per_phrase_k * phrases_in_call
    temperature: float = 0.0,
    max_tokens: int = 512,
    return_metadata: bool = False,
    min_why: Optional[str] = "strong_implied",
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
    """Select tags with the LLM and return indices into the ORIGINAL candidates list.

    Candidates are split into general and character ("entity") tags. General
    tags go to the LLM with ``SELECT_SYSTEM_TEMPLATE``. Character tags are first
    filtered by name/alias match against the description and then sent with
    ``ENTITY_SYSTEM_TEMPLATE``. The per-call results are merged per tag (best
    score wins), filtered by ``min_why``, ordered by score then count, and
    capped by ``max_pick``.

    This implementation uses LangChain ONLY.

    Args:
        query_text: The image description (original prompt).
        candidates: ``Candidate`` objects (preferred), bare tag strings, or
            ``(tag, score)`` pairs.
        max_pick: If a positive int, keep at most this many tags after ordering.
        log: Optional callable taking one string; falsy disables logging.
        retries: Extra attempts per LLM call after an error or empty selection.
        mode: ``"single_shot"`` (one call per category) or
            ``"chunked_map_union"`` (one call per ``chunk_size`` candidates).
        chunk_size: Candidates per call in chunked mode; must be >= 1.
        per_phrase_k: Per-call budget is ``per_phrase_k`` times the number of
            distinct source phrases in that call.
        temperature: Sampling temperature for the LLM.
        max_tokens: Completion token cap for the LLM.
        return_metadata: If True, also return a ``{tag: why}`` mapping.
        min_why: If set, only keep tags whose ``why`` is at or above this
            confidence level. E.g. ``"explicit"`` keeps only explicit matches;
            ``"strong_implied"`` keeps explicit + strong_implied. An unknown
            value disables the filter (a warning is logged).

    Returns:
        A list of indices into ``candidates``, or ``(indices, tag_why)`` when
        ``return_metadata`` is True.

    Raises:
        ValueError: On an invalid ``mode``, a ``chunk_size`` below 1 in chunked
            mode, or malformed tuple-style candidates.
        RuntimeError: If ``OPENROUTER_API_KEY`` is not set.
    """
    image_description = query_text

    # Normalize candidates:
    # - preferred: List[Candidate]
    # - legacy: List[str] / List[(tag, sim)] (count/sources unavailable)
    cand0_type = type(candidates[0]).__name__ if candidates else "none"
    norm, tag_to_first_index, branch = _normalize_stage3_candidates(candidates)

    if log:
        if norm:
            log(
                "Stage3 input: "
                f"type0={cand0_type} "
                f"branch={branch} "
                f"norm0_score={norm[0].score_combined!r} "
                f"norm0_sources_empty={not bool(norm[0].sources)}"
            )
        else:
            log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)")

    if mode not in ("single_shot", "chunked_map_union"):
        raise ValueError(f"Invalid mode: {mode}")
    # A negative step would make range() yield nothing and silently skip every
    # LLM call; zero would surface as an unrelated range() error. Reject both.
    if mode == "chunked_map_union" and chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    response_format = _build_response_format()
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens, response_format=response_format)
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
    parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse)

    # Global union: tag -> best (score, why)
    best: Dict[str, Tuple[float, str]] = {}

    def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
        """Run one LLM selection call (with retries) and merge results into ``best``."""
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", USER_TEMPLATE),
            ],
            template_format="f-string",
        )
        chain = prompt | llm | parser

        ordered = _interleave_round_robin(call_cands)
        candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
        N_local = len(idx_to_tag)
        phrases = _phrases_in_call(call_cands)
        per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
        summary_logged = False

        if log:
            log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}")
            if phrases > 0:
                distinct_phrases = sorted({src for c in call_cands for src in c.sources})
                log(
                    f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} "
                    f"phrases={', '.join(distinct_phrases)}"
                )

        # Invoke the LangChain chain (templating fills {N} and the other variables).
        for att in range(retries + 1):
            try:
                if log:
                    log(
                        f"Stage3 {label}: "
                        f"model={model_name} "
                        f"N={N_local} "
                        f"phrases={phrases} "
                        f"per_call_budget={per_call_budget} "
                        f"response_healing=on"
                    )
                parsed = chain.invoke(
                    {
                        "N": N_local,
                        "image_description": image_description,
                        "candidate_lines": candidate_lines,
                        "per_call_budget": per_call_budget,
                    }
                )
                selected, diag = _parse_validate_map(
                    parsed, idx_to_tag, per_call_budget=per_call_budget
                )

                if log:
                    log(f"Stage3 {label}: attempt {att+1} diag={diag}")
                    if not summary_logged and (selected or att == retries):
                        log(
                            f"Stage3 {label}: summary "
                            f"N={N_local} selected={len(selected)} per_call_budget={per_call_budget}"
                        )
                        summary_logged = True
                    if selected:
                        detail_lines = [f"Stage3 {label} selections:"]
                        for s in selected:
                            cand = idx_to_candidate.get(s.i)
                            detail_lines.append(
                                f' - i={s.i} tag="{s.tag}" '
                                f"why={s.why} score={s.score:.2f} "
                                f"sources={cand.sources if cand else []}"
                            )
                        log("\n".join(detail_lines))
                    else:
                        log(f"Stage3 {label} selections: (none)")

                if selected:
                    for s in selected:
                        prev = best.get(s.tag)
                        if prev is None or s.score > prev[0]:
                            best[s.tag] = (s.score, s.why)
                    return
                # An empty selection falls through to the next attempt.
            except Exception as e:
                # Best effort: a failed attempt is logged and retried, never raised.
                if log:
                    log(f"Stage3 {label}: attempt {att+1} error: {e}")

        if log:
            log(f"Stage3 {label}: gave up after {retries+1} attempts")

    def run_category(cands: Sequence[Candidate], kind: str, system_template: str) -> None:
        """Dispatch one category either as a single call or as fixed-size chunks."""
        if mode == "single_shot":
            run_call(cands, f"{kind}_single_shot", system_template)
            return
        for start in range(0, len(cands), chunk_size):
            run_call(
                cands[start:start + chunk_size],
                f"{kind}_chunk_{start//chunk_size}",
                system_template,
            )

    # Split candidates by type (general vs entity); only the candidates are needed here.
    general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)
    general_cands = [cand for _, cand in general_with_idx]
    entity_cands = [cand for _, cand in entity_with_idx]

    # Process general candidates (attributes, actions, species, etc.)
    if general_cands:
        run_category(general_cands, "general", SELECT_SYSTEM_TEMPLATE)

    # Process entity candidates (characters only) with alias-based pre-filtering
    if entity_cands:
        tag2aliases = get_tag2aliases()
        qwords = _query_words(image_description)
        qnorm = _normalize_for_matching(image_description)

        filtered_entity_cands: List[Candidate] = []
        filtered_out: List[str] = []
        for cand in entity_cands:
            if _character_matches_via_aliases(
                cand.tag, image_description, tag2aliases, qwords, qnorm
            ):
                filtered_entity_cands.append(cand)
            else:
                filtered_out.append(cand.tag)

        if log:
            log(
                f"Stage3 entity alias filter: "
                f"before={len(entity_cands)} "
                f"after={len(filtered_entity_cands)} "
                f"removed={len(filtered_out)}"
            )
            if filtered_out:
                log(f"Stage3 entity alias filter removed: {filtered_out[:20]}")

        if filtered_entity_cands:
            run_category(filtered_entity_cands, "entity", ENTITY_SYSTEM_TEMPLATE)

    # Apply why threshold: drop tags below the minimum confidence level.
    if min_why is not None:
        if min_why not in WHY_RANK and log:
            # Unknown values fall back to the loosest rank, i.e. nothing is dropped.
            log(f"Stage3 why filter: unknown min_why={min_why!r}; no tags will be dropped")
        max_rank = WHY_RANK.get(min_why, 4)
        before = len(best)
        best = {t: v for t, v in best.items() if WHY_RANK.get(v[1], 4) <= max_rank}
        if log:
            log(f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
                f"before={before} after={len(best)} dropped={before - len(best)}")

    # Deterministic ordering: derived score desc, tie-break by count desc
    # (count is not shown to the LLM).
    count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
    ordered_tags = sorted(
        best.keys(),
        key=lambda t: (best[t][0], count_by_tag.get(t, -1)),
        reverse=True,
    )

    # Legacy cap: apply AFTER union + ordering.
    if isinstance(max_pick, int) and max_pick > 0:
        ordered_tags = ordered_tags[:max_pick]

    # Map back to original indices
    out_idx: List[int] = []
    tag_why: Dict[str, str] = {}
    for t in ordered_tags:
        if t in tag_to_first_index:
            out_idx.append(tag_to_first_index[t])
            tag_why[t] = best[t][1]  # why string

    if return_metadata:
        return out_idx, tag_why
    return out_idx
# ---------------------------------------------------------------------------
# Stage 3s: Structural tag inference (solo/duo/male/female/anthro/… )
# ---------------------------------------------------------------------------
# Group-based approach: tags are organized into semantic groups loaded from
# tag_groups.json / tag_wiki_defs.json where possible, with curated fallback
# definitions for tags whose wiki entries are only thumbnail references.
#
# Each group specifies a constraint mode:
# "exclusive" = pick exactly one (e.g. character count)
# "multi" = pick all that apply (e.g. body type, gender)
import json as _json
@dataclass
class StructuralGroup:
    """A named category of structural tags that is probed together."""

    # Group identifier, e.g. "character_count".
    name: str
    # Selection mode: "exclusive" or "multi".
    constraint: str
    # Ordered (tag, definition) pairs belonging to the group.
    tags: List[Tuple[str, str]]
def _load_structural_groups() -> List[StructuralGroup]:
    """Build the structural tag groups from curated config plus wiki definitions.

    Group membership and order are curated in this function. For each tag the
    definition comes from ``data/tag_wiki_defs.json`` when that file exists and
    holds a usable text definition; otherwise the curated fallback is used.

    Returns:
        The groups in prompt order: character count, body type, gender,
        clothing state, visual elements.
    """
    data_dir = Path(__file__).resolve().parents[2] / "data"

    # Load wiki definitions (the file may not exist yet).
    wiki_defs: Dict[str, Any] = {}
    wiki_path = data_dir / "tag_wiki_defs.json"
    if wiki_path.is_file():
        with wiki_path.open("r", encoding="utf-8") as f:
            loaded = _json.load(f)
        # Anything other than a JSON object cannot be used as a tag -> text map.
        if isinstance(loaded, dict):
            wiki_defs = loaded

    def _def(tag: str, fallback: str) -> str:
        """Get the wiki definition if it's real text, otherwise use the fallback."""
        d = wiki_defs.get(tag, "")
        # Non-string values (e.g. JSON null) count as missing instead of crashing.
        if not isinstance(d, str):
            return fallback
        # Skip empty, thumbnail-only and very short definitions.
        if not d or d.startswith("thumb ") or len(d) < 15:
            return fallback
        return d[:200]  # cap length for prompt

    # (group name, constraint, [(tag, curated fallback definition), ...])
    specs: List[Tuple[str, str, List[Tuple[str, str]]]] = [
        # ── Group A: Character Count (exclusive) ──
        ("character_count", "exclusive", [
            ("zero_pictured",
             "No characters or living beings appear in the image"),
            ("solo",
             "Exactly one character appears in the image"),
            ("duo",
             "Exactly two characters appear in the image"),
            ("trio",
             "Exactly three characters appear in the image"),
            ("group",
             "Four or more characters appear in the image"),
        ]),
        # ── Group B: Body Type (multi — per character) ──
        # Key distinction the LLM must learn:
        #   anthro   = ANIMAL with human body shape (upright, hands)
        #   humanoid = HUMAN or near-human (elf, dwarf) with NO animal features
        #   feral    = normal animal shape, on all fours
        ("body_type", "multi", [
            ("anthro",
             "An animal character with a human-like body: walks upright on two legs, "
             "has arms and hands. Examples: a wolf-person, a fox standing up. "
             "Still has animal features like fur, tail, muzzle"),
            ("feral",
             "A regular animal in its natural body shape. Walks on all fours (or "
             "flies/swims naturally). NOT standing upright, NOT humanized"),
            ("humanoid",
             "A human or human-like character with NO animal features. Includes "
             "humans, elves, dwarves, and fantasy races that look human. "
             "Does NOT include animal-people — those are anthro"),
            ("taur",
             "A centaur-like body: human or anthro upper body attached to a "
             "four-legged animal lower body"),
        ]),
        # ── Group C: Gender (multi — per character) ──
        ("gender", "multi", [
            ("male",
             "A character described as male, a boy, or with he/him pronouns"),
            ("female",
             "A character described as female, a girl, or with she/her pronouns"),
            ("ambiguous_gender",
             "A character whose gender is not stated or cannot be determined"),
            ("intersex",
             "A character explicitly described as intersex or hermaphrodite"),
        ]),
        # ── Group D: Clothing State (multi) ──
        ("clothing_state", "multi", [
            ("clothed",
             "Wearing clothes on BOTH chest/torso AND legs/waist. "
             "Examples: shirt and pants, dress, full outfit"),
            ("nude",
             "Wearing NO clothes at all. Completely naked, no shirt and no pants"),
            ("topless",
             "NO shirt/top (bare chest), BUT wearing pants/bottoms. "
             "Upper body exposed, lower body covered"),
            ("bottomless",
             "Wearing shirt/top on chest, BUT NO pants/bottoms. "
             "Upper body covered, lower body exposed"),
        ]),
        # ── Group E: Common Visual Elements (multi) ──
        ("visual_elements", "multi", [
            ("looking_at_viewer",
             "A character is looking directly at the camera or viewer"),
            ("text",
             "The image contains visible writing, words, or lettering"),
        ]),
    ]

    return [
        StructuralGroup(
            name=group_name,
            constraint=constraint,
            tags=[(tag, _def(tag, fallback)) for tag, fallback in members],
        )
        for group_name, constraint, members in specs
    ]
def _build_structural_prompt(groups: List[StructuralGroup]) -> Tuple[str, List[Tuple[str, str]]]:
    """Render the structural groups as one numbered statement list.

    Each group contributes a header line naming the group and its pick rule,
    one numbered line per definition, and a trailing blank line. Numbering is
    1-based and continues across groups.

    Returns:
        ``(formatted_text, flat)`` where ``flat[k - 1]`` is the
        ``(tag, definition)`` pair behind statement number ``k``.
    """
    flat: List[Tuple[str, str]] = []
    rendered: List[str] = []

    for group in groups:
        if group.constraint == "exclusive":
            pick_rule = "pick EXACTLY ONE"
        else:
            pick_rule = "pick ALL that apply"
        title = group.name.replace("_", " ").upper()
        rendered.append(f"--- {title} ({pick_rule}) ---")

        for tag, definition in group.tags:
            flat.append((tag, definition))
            # The statement number is the pair's 1-based position in ``flat``.
            rendered.append(f"{len(flat)}. {definition}")

        rendered.append("")  # blank line between groups

    return "\n".join(rendered), flat
# System prompt for structural tag inference. Braces are doubled ({{ }}) so the
# JSON examples survive f-string-style template rendering — presumably the same
# LangChain templating used for the Stage 3 prompts; confirm at the call site.
# NOTE: the worked example hard-codes statement numbers (2, 6, 10, 14). They
# match the position of solo / anthro / male / clothed in the list built from
# _load_structural_groups(); keep them in sync if groups are added or reordered.
STRUCTURAL_SYSTEM_TEMPLATE = """You classify image descriptions by selecting true statements from a numbered list.
The statements are organized into GROUPS. Each group header tells you how many to pick:
- "pick EXACTLY ONE" = choose the single best match in that group
- "pick ALL that apply" = choose every statement that is true
IMPORTANT RULES:
1. ONLY select a statement if the description directly says it or makes it very obvious.
2. Do NOT guess or assume things the description does not mention.
3. For body type: "anthro" means an ANIMAL with a human-shaped body (walks upright, has hands, but still has fur/tail/muzzle). "humanoid" means HUMAN or human-like with NO animal features. A wolf standing on two legs = anthro, NOT humanoid.
4. If the description never mentions gender, pick "gender cannot be determined".
5. For clothing state: READ CAREFULLY! "topless" = bare chest, wearing pants. "bottomless" = wearing shirt, no pants. If unsure, re-read the description.
6. If clothing is not mentioned, do NOT pick any clothing statement.
Return JSON ONLY:
{{"selections": [{{"i": 1}}, {{"i": 5}}]}}
EXAMPLE:
Description: "A muscular male wolf standing in a forest, wearing jeans, giving a thumbs up"
Answer: {{"selections": [{{"i": 2}}, {{"i": 6}}, {{"i": 10}}, {{"i": 14}}]}}
Why: One character = solo (2). Wolf standing upright with hands = anthro (6), NOT humanoid because it is a wolf. Male (10). Wearing jeans = clothed (14)."""
# Human-message template for structural inference. Template variables:
# {image_description} and {statement_lines} (the numbered statement list).
STRUCTURAL_USER_TEMPLATE = """Read this image description and select which statements are true.
IMAGE DESCRIPTION:
{image_description}
STATEMENTS (pick by number):
{statement_lines}"""
class StructuralSelectionItem(BaseModel):
    # One statement picked by the model, identified by its position in the
    # numbered statement list. (Kept as a comment rather than a docstring so
    # the generated model schema is unchanged.)
    i: int = Field(
        ...,
        description="1-based index into the statement list.",
    )
class StructuralSelectionResponse(BaseModel):
    # Top-level reply payload: the list of selected statements. Defaults to an
    # empty list when the reply carries no "selections" key. (Kept as a comment
    # rather than a docstring so the generated model schema is unchanged.)
    selections: List[StructuralSelectionItem] = Field(
        default_factory=list,
    )
def _build_structural_response_format() -> Dict[str, Any]:
schema = {
"type": "object",
"properties": {
"selections": {
"type": "array",
"items": {
"type": "object",
"properties": {
"i": {"type": "integer"},
},
"required": ["i"],
"additionalProperties": False,
},
}
},
"required": ["selections"],
"additionalProperties": False,
}
return {
"type": "json_schema",
"json_schema": {
"name": "structural_selection",
"strict": True,
"schema": schema,
},
}
# Process-wide memo of the loaded groups; it stays None until the first lookup
# so the JSON files are only read once per process.
_cached_structural_groups: Optional[List[StructuralGroup]] = None


def _get_structural_groups() -> List[StructuralGroup]:
    """Return the structural groups, loading them on first use and reusing them afterwards."""
    global _cached_structural_groups

    groups = _cached_structural_groups
    if groups is not None:
        return groups

    # First call (or nothing cached yet): load and remember the result.
    groups = _load_structural_groups()
    _cached_structural_groups = groups
    return groups
def _structural_tags_from_selections(selections: List[Any], flat_tags: List[Any]) -> List[str]:
    """Map the model's 1-based statement picks to tags, de-duplicated in first-seen order.

    ``flat_tags`` is the flat list of ``(tag, definition)`` pairs in the same
    order as the numbered statements. Entries that are not dicts, or whose
    ``"i"`` is missing, non-integer, or outside ``1..len(flat_tags)``, are
    skipped.
    """
    total = len(flat_tags)
    chosen: List[str] = []
    seen: Set[str] = set()
    for item in selections:
        idx = item.get("i") if isinstance(item, dict) else None
        # bool is a subclass of int; never treat True/False as statement 1/0.
        if isinstance(idx, bool) or not isinstance(idx, int):
            continue
        if not 1 <= idx <= total:
            continue
        tag = flat_tags[idx - 1][0]
        if tag not in seen:
            seen.add(tag)
            chosen.append(tag)
    return chosen


def llm_infer_structural_tags(
    query_text: str,
    log=None,
    *,
    temperature: float = 0.0,
    max_tokens: int = 512,
    retries: int = 2,
) -> List[str]:
    """Infer structural tags via LLM using group-based statement agreement.

    Probes multiple semantic groups (character count, body type, gender,
    clothing state, visual elements) with definitions loaded from wiki data
    where available. The groups are rendered as one numbered statement list;
    the model replies with the numbers it considers true, and those numbers
    are mapped back to tags.

    Args:
        query_text: The image description to classify.
        log: Optional callable taking a single message string; when falsy,
            nothing is logged.
        temperature: Sampling temperature forwarded to ``_get_llm``.
        max_tokens: Completion token limit forwarded to ``_get_llm``.
        retries: Number of additional attempts after the first one fails.
            Negative values are treated as 0, so at least one attempt is
            always made.

    Returns:
        A list of e621 tag strings (e.g. ["solo", "anthro", "male", "clothed"]),
        de-duplicated in the order the model selected them. An empty list is
        returned when the model selects nothing valid or when every attempt
        raises.
    """
    if log:
        log("Stage3s (structural): inferring structural tags via group-based statement agreement")

    groups = _get_structural_groups()
    statement_lines, flat_tags = _build_structural_prompt(groups)

    llm = _get_llm(
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=_build_structural_response_format(),
    )
    parser = PydanticOutputParser(pydantic_object=StructuralSelectionResponse)
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", STRUCTURAL_SYSTEM_TEMPLATE),
            ("human", STRUCTURAL_USER_TEMPLATE),
        ],
        template_format="f-string",
    )
    chain = prompt | llm | parser

    if log:
        # Only needed for the log line, so it is resolved only when logging.
        # NOTE(review): presumably mirrors the model _get_llm selects -- confirm.
        model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
        group_summary = ", ".join(f"{g.name}({len(g.tags)})" for g in groups)
        log(
            f"Stage3s: model={model_name} groups=[{group_summary}] "
            f"total_statements={len(flat_tags)}"
        )

    # Only the variables the templates actually reference are supplied.
    prompt_inputs = {
        "image_description": query_text,
        "statement_lines": statement_lines,
    }

    # Always try at least once, even if the caller passes a negative retry count.
    attempts = max(1, retries + 1)
    for attempt in range(1, attempts + 1):
        try:
            parsed = chain.invoke(prompt_inputs)
            if isinstance(parsed, BaseModel):
                # Support both Pydantic v2 (model_dump) and v1 (dict).
                parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()
            selections = parsed.get("selections", []) if isinstance(parsed, dict) else []
            # NOTE(review): "pick EXACTLY ONE" groups are not enforced here; if
            # the model picks several statements from such a group, all of
            # their tags are returned.
            chosen_tags = _structural_tags_from_selections(selections, flat_tags)
        except Exception as e:
            # Deliberate best-effort: any LLM/parse failure is logged and the
            # next attempt (if any) is made instead of propagating.
            if log:
                log(f"Stage3s: attempt {attempt} error: {e}")
            continue

        if log:
            tag_str = ", ".join(chosen_tags) if chosen_tags else "(none)"
            log(f"Stage3s: attempt {attempt} selected {len(chosen_tags)} tags: {tag_str}")
        return chosen_tags

    if log:
        log(f"Stage3s: gave up after {attempts} attempts")
    return []
|