File size: 26,457 Bytes
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349b999
 
 
 
 
 
 
 
 
 
 
 
 
c6be992
 
 
09a248d
 
 
 
 
 
 
 
 
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349b999
 
c6be992
 
 
 
349b999
 
 
 
 
 
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349b999
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349b999
09a248d
349b999
c6be992
 
09a248d
 
 
 
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09a248d
 
 
 
 
 
 
 
 
c6be992
 
 
 
 
 
 
 
 
 
349b999
c6be992
 
 
349b999
 
 
 
c6be992
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
# psq_rag/llm/select.py
# Stage 3: Closed-Set Selection (LangChain-only implementation)
#
# This module intentionally uses LangChain for:
# - prompt templating (including {N})
# - LLM call orchestration
# - JSON parsing
#
# There is NO fallback path. If LangChain dependencies are missing, this module
# should fail loudly so you install them.

import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field, SecretStr
from rapidfuzz import fuzz

from psq_rag.retrieval.psq_retrieval import Candidate  # Candidate(tag, score_*, count, sources)
from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases

# Tags whose type is "character" but which name a generic category rather than
# a specific, named character.
# Per the original note, they slip past the alias pre-filter because their
# words are common in captions, so _split_candidates_by_type() keeps them out of
# the entity (character) pass and routes them to the general selection pass.
_GENERIC_CHARACTER_TAGS = frozenset({
    "fan_character",
    "background_character",
    "unnamed_character",
    "unknown_character",
    "anonymous_character",
    "viewer",
    "original_character",
})


# Closed set of rationale codes the LLM may attach to a selection. It is
# embedded in the JSON schema sent to the model (_build_response_format) and
# used to validate each returned item (_parse_validate_map).
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]

# Ordinal rank: lower = more confident.  Used for threshold filtering
# (the min_why argument of llm_select_indices keeps ranks <= the chosen one).
WHY_RANK: Dict[str, int] = {
    "explicit": 0,
    "strong_implied": 1,
    "weak_implied": 2,
    "style_or_meta": 3,
    "other": 4,
}

# Deterministic mapping: ordinal "why" -> numeric score for ordering/debug.
# The score is what the final result list is sorted by (descending).
WHY_TO_SCORE: Dict[str, float] = {
    "explicit": 0.90,
    "strong_implied": 0.70,
    "weak_implied": 0.45,
    "style_or_meta": 0.35,
    "other": 0.25,
}


# System prompt for the general-tag selection pass.
#
# IMPORTANT ABOUT TEMPLATING:
# - This string is rendered by LangChain's f-string template engine.
# - Literal JSON braces must be escaped as {{ and }}.
# - {N} is a real template variable and MUST be provided
#   (run_call passes the number of candidates listed in that call).
SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags.

Select the tags that correspond to content that would be visible or depicted in the described image.

The list contains only valid tags; many of them are irrelevant to the image.

Return JSON ONLY matching this schema:

{{
  \"selections\": [
    {{\"i\": <int>, \"why\": \"<one of: explicit|strong_implied|weak_implied|style_or_meta|other>\"}},
    ...
  ]
}}

Rules:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Do NOT output any keys other than \"selections\", and inside each item only the item index \"i\" and \"why\".
- Do select both a general tag and a more specific tag when both apply (for example, \"shirt\" and \"grey shirt\").

Define \"why\" as:
- explicit: directly stated in the image description
- strong_implied: very likely given the description, even if not literally stated
- weak_implied: plausible but not strongly supported by the description
- style_or_meta: stylistic or presentation-related tags only if clearly indicated
- other: fallback category; use sparingly
"""


# System prompt for the character (entity) pass. It is used only for the
# character candidates that survived the alias pre-filter in llm_select_indices.
# Same templating rules as the general prompt: {N} is a template variable and
# literal JSON braces are doubled ({{ and }}).
ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags.

These character tags have already been pre-filtered to only include characters whose names
(or known aliases) appear in the image description. Your job is to confirm which of these
pre-filtered candidates are the correct match for the character mentioned by the user.

Return JSON ONLY matching this schema:

{{
  \"selections\": [
    {{\"i\": <int>, \"why\": \"explicit\"}},
    ...
  ]
}}

Rules for character selection:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Always use \"why\": \"explicit\" for all selections.
- Select the tag that best represents the character as described.
- If the user described a specific variant (e.g. \"pikachu libre\", \"detective pikachu\"),
  select that specific variant tag.
- If the user described only the base character (e.g. just \"pikachu\"), select only the
  base/default tag, NOT costume or variant tags.
- When uncertain between variants, prefer the simplest/most general tag.
"""


# Human-message template shared by the general and entity passes.
# Template variables filled per call: {image_description}, {candidate_lines}
# (the numbered "index. tag" listing) and {per_call_budget}.
USER_TEMPLATE = """IMAGE DESCRIPTION:
{image_description}

CANDIDATES (choose by index only):
{candidate_lines}

Select up to {per_call_budget} indices. Output fewer if uncertain.
"""


@dataclass(frozen=True)
class Selected:
    """One validated LLM selection, resolved from a per-call index to its tag."""

    i: int  # 1-based index within the candidate listing shown to the LLM for that call
    tag: str  # canonical tag (underscore form)
    why: str  # rationale code; validated against WHY_ENUM before construction
    score: float  # numeric score looked up from WHY_TO_SCORE[why]


# Static-typing counterpart of WHY_ENUM (same five codes, same order); the two
# must be kept in sync by hand.
WhyLiteral = Literal["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]


# Pydantic shape of a single entry in the LLM's "selections" array.
class Stage3SelectionItem(BaseModel):
    # Index the model picked from the numbered candidate listing.
    i: int = Field(..., description="1-based index into the candidate list.")
    # Rationale code; restricted to the WhyLiteral values.
    why: WhyLiteral = Field(..., description="Rationale code from the allowed set.")


# Top-level shape of the LLM reply; this is the model handed to
# PydanticOutputParser in llm_select_indices.
class Stage3SelectionResponse(BaseModel):
    # Defaults to an empty list when the key is absent.
    selections: List[Stage3SelectionItem] = Field(default_factory=list)


def _build_response_format() -> Dict[str, Any]:
    """Build the ``response_format`` payload asking for strict JSON-schema output.

    The schema describes an object with a single required ``selections`` array
    whose items carry exactly an integer ``i`` and a ``why`` string drawn from
    ``WHY_ENUM``. Extra properties are disallowed at both levels.

    Returns:
        A dict of the form ``{"type": "json_schema", "json_schema": {...}}``.
    """
    # Shape of one element of the "selections" array.
    selection_item = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
            "why": {"type": "string", "enum": WHY_ENUM},
        },
        "required": ["i", "why"],
        "additionalProperties": False,
    }

    # Top-level reply object wrapping the array.
    envelope = {
        "type": "object",
        "properties": {
            "selections": {
                "type": "array",
                "items": selection_item,
            }
        },
        "required": ["selections"],
        "additionalProperties": False,
    }

    named_schema = {
        "name": "stage3_selection",
        "strict": True,
        "schema": envelope,
    }
    return {"type": "json_schema", "json_schema": named_schema}


def _get_llm(*, temperature: float, max_tokens: int, response_format: Dict[str, Any]) -> ChatOpenAI:
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set.\n"
            "Set it in your environment before running Stage 3."
        )
    api_key = SecretStr(cast(str, api_key))

    model = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
    headers: Dict[str, str] = {}
    if referer := os.getenv("OPENROUTER_HTTP_REFERER"):
        headers["HTTP-Referer"] = referer
    if title := os.getenv("OPENROUTER_X_TITLE"):
        headers["X-Title"] = title

    # OpenRouter OpenAI-compatible endpoint.
    return ChatOpenAI(
        model=model,
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        temperature=temperature,
        max_completion_tokens=max_tokens,
        default_headers=headers,
        # Provider-specific request body fields (OpenAI-compatible).
        # Response Healing plugin reduces malformed-JSON failures (syntax only).
        extra_body={
            "response_format": response_format,
            "plugins": [{"id": "response-healing"}],
        },
    )


def _phrase_key_for_candidate(c: Candidate) -> str:
    # Deterministic "primary phrase" for grouping.
    if c.sources:
        return sorted(c.sources)[0]
    return ""


def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
    """Round-robin interleave candidates across their primary source phrases.

    Candidates are grouped by ``_phrase_key_for_candidate``; each group is
    ordered best-first by ``(score_combined, count)``; the result then takes
    the 1st item of every group (groups visited in sorted key order), then the
    2nd of every group, and so on until all groups are exhausted.

    NOTE: counts are used only for ordering; they are NOT shown to the LLM.

    Args:
        cands: Candidates for one LLM call.

    Returns:
        A new list containing every input candidate exactly once.
    """
    groups: Dict[str, List[Candidate]] = {}
    for c in cands:
        groups.setdefault(_phrase_key_for_candidate(c), []).append(c)

    def _rank(c: Candidate) -> Tuple[float, int]:
        # Only a *missing* count maps to -1. A genuine count of 0 stays 0
        # (``c.count or -1`` would wrongly collapse it), matching the
        # ``is not None`` handling used for the final ordering.
        return (c.score_combined, c.count if c.count is not None else -1)

    for members in groups.values():
        members.sort(key=_rank, reverse=True)

    keys = sorted(groups)
    depth = max((len(members) for members in groups.values()), default=0)

    out: List[Candidate] = []
    for position in range(depth):
        for k in keys:
            members = groups[k]
            if position < len(members):
                out.append(members[position])

    return out


def _display_tag(tag: str) -> str:
    # Display tags with spaces for the LLM, but keep canonical underscores internally.
    return tag.replace("_", " ")


def _format_candidates_local(
    cands: Sequence[Candidate],
) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]:
    """Number the candidates from 1 and render the listing shown to the LLM.

    Returns:
        A 3-tuple of
        - the newline-joined text, one ``"<index>. <tag with spaces>"`` per line,
        - a map from local index to canonical tag,
        - a map from local index to the candidate object.
    """
    numbered = list(enumerate(cands, start=1))

    idx_to_tag: Dict[int, str] = {n: cand.tag for n, cand in numbered}
    idx_to_candidate: Dict[int, Candidate] = dict(numbered)
    listing = "\n".join(f"{n}. {_display_tag(cand.tag)}" for n, cand in numbered)

    return listing, idx_to_tag, idx_to_candidate


def _phrases_in_call(cands: Sequence[Candidate]) -> int:
    s = set()
    for c in cands:
        for src in c.sources:
            s.add(src)
    return len(s)


def _parse_validate_map(
    parsed: Any,
    idx_to_tag: Dict[int, str],
    per_call_budget: int,
) -> Tuple[List[Selected], Dict[str, Any]]:
    """Validate a parsed LLM reply and map its indices back to tags.

    Args:
        parsed: Parser output; either a pydantic model or a plain dict with a
            ``"selections"`` list of ``{"i": int, "why": str}`` items.
        idx_to_tag: Local 1-based index -> canonical tag for this call.
        per_call_budget: Maximum number of selections to keep; later items are
            ignored once this many have been accepted.

    Returns:
        ``(selected, diag)`` where ``selected`` holds the accepted items in
        reply order and ``diag`` counts what happened: ``parse_ok``,
        ``invalid_items`` (non-dict item, non-int index, or unknown ``why``),
        ``oob_indices``, ``dupe_indices`` and ``kept``.
    """
    # Flatten pydantic results to a dict (v2 ``model_dump``; v1 ``dict`` fallback).
    if isinstance(parsed, BaseModel):
        if hasattr(parsed, "model_dump"):
            parsed = parsed.model_dump()
        else:
            parsed = parsed.dict()

    is_mapping = isinstance(parsed, dict)
    # Key order matters: this dict is printed verbatim in the debug log.
    diag: Dict[str, Any] = {
        "parse_ok": is_mapping,
        "invalid_items": 0,
        "oob_indices": 0,
        "dupe_indices": 0,
        "kept": 0,
    }
    if not is_mapping:
        return [], diag

    raw_items = parsed.get("selections", [])
    if not isinstance(raw_items, list):
        diag["parse_ok"] = False
        return [], diag

    accepted: List[Selected] = []
    used_indices = set()

    for entry in raw_items:
        if len(accepted) >= per_call_budget:
            break

        # Work out which diagnostic counter (if any) this entry trips.
        # The checks run in a fixed order: shape, index type, duplicate,
        # out-of-range, rationale.
        problem: Optional[str] = None
        index = reason = None
        if not isinstance(entry, dict):
            problem = "invalid_items"
        else:
            index = entry.get("i")
            reason = entry.get("why")
            # bool is a subclass of int, so it has to be rejected explicitly.
            if isinstance(index, bool) or not isinstance(index, int):
                problem = "invalid_items"
            elif index in used_indices:
                problem = "dupe_indices"
            elif index not in idx_to_tag:
                problem = "oob_indices"
            elif not isinstance(reason, str) or reason not in WHY_ENUM:
                problem = "invalid_items"

        if problem is not None:
            diag[problem] += 1
            continue

        used_indices.add(index)
        accepted.append(
            Selected(
                i=index,
                tag=idx_to_tag[index],
                why=reason,
                score=WHY_TO_SCORE[reason],
            )
        )

    diag["kept"] = len(accepted)
    return accepted, diag


def _split_candidates_by_type(
    candidates: List[Candidate],
    log,
) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]:
    """Partition candidates into the general pass and the entity (character) pass.

    Routing, by the type name reported by ``get_tag_type_name``:
        - "character": entity list, unless the tag is in
          ``_GENERIC_CHARACTER_TAGS``, in which case it goes to the general list.
        - "copyright": dropped (too broad for image generation).
        - "general", "artist", "species", "meta": general list.
        - anything else (including None): general list, counted as unknown.

    Args:
        candidates: Candidates to split.
        log: Optional callable taking one string; skipped when falsy.

    Returns:
        ``(general_list, entity_list)``; every element is
        ``(index_in_candidates, candidate)``.
    """
    general_bucket: List[Tuple[int, Candidate]] = []
    entity_bucket: List[Tuple[int, Candidate]] = []
    n_copyright = 0
    n_generic_char = 0
    n_unknown = 0
    plain_general_types = ("general", "artist", "species", "meta")

    for position, candidate in enumerate(candidates):
        kind = get_tag_type_name(candidate.tag)
        entry = (position, candidate)

        if kind == "copyright":
            # Filter out copyright/series tags - too broad for image generation
            n_copyright += 1
            continue

        if kind == "character":
            if candidate.tag not in _GENERIC_CHARACTER_TAGS:
                entity_bucket.append(entry)
                continue
            # Generic character-category tag: falls through to the general list.
            n_generic_char += 1
        elif kind not in plain_general_types:
            # Unknown or None - treated as general by default.
            n_unknown += 1

        general_bucket.append(entry)

    if log:
        log(
            f"Stage3 split: "
            f"general={len(general_bucket)} "
            f"entity={len(entity_bucket)} "
            f"copyright_filtered={n_copyright} "
            f"generic_char_to_general={n_generic_char} "
            f"unknown_type={n_unknown}"
        )

    return general_bucket, entity_bucket


# Regex to strip series/franchise suffixes from aliases, e.g. _(sonic), _(mlp), _(character)
_SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$")


def _normalize_for_matching(text: str) -> str:
    """Lowercase, replace underscores with spaces, strip series suffixes."""
    text = text.lower().strip()
    text = _SERIES_SUFFIX_RE.sub("", text)
    text = text.replace("_", " ")
    return text


def _query_words(query: str) -> Set[str]:
    """Return the set of whitespace-separated words in the normalised query."""
    tokens = _normalize_for_matching(query).split()
    return set(tokens)


def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str,
                         fuzzy_threshold: int = 85) -> bool:
    """Check if an alias matches the user query.

    Matching logic:
    1. Exact substring: alias appears as a substring of the query
    2. Word subset: all words in the alias appear in the query words
    3. Fuzzy: alias is close to a word in the query (handles typos)
    """
    # Exact substring match
    if alias_norm in query_norm:
        return True

    alias_words = alias_norm.split()
    if not alias_words:
        return False

    # Word subset match: all alias words must appear in query
    if all(w in query_words for w in alias_words):
        return True

    # For single-word aliases, try fuzzy matching against each query word
    if len(alias_words) == 1:
        for qw in query_words:
            if fuzz.ratio(alias_words[0], qw) >= fuzzy_threshold:
                return True

    # For multi-word aliases, try fuzzy partial ratio against whole query
    if len(alias_words) > 1:
        if fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold:
            return True

    return False


def _character_matches_via_aliases(
    tag: str,
    query: str,
    tag2aliases: Dict[str, List[str]],
    query_words: Set[str],
    query_norm: str,
    fuzzy_threshold: int = 85,
) -> bool:
    """Check if a character tag matches the user query via its name or aliases.

    For a character tag to match:
    - The tag name itself (normalized) must match, OR
    - At least one of its registered aliases must match.

    Names that normalise to an empty string are skipped. This now applies to
    the tag name as well as to aliases, so a tag consisting only of a series
    suffix can no longer match every query. A tag with no registered aliases
    is still checked by its own name.

    Args:
        tag: Canonical character tag.
        query: Raw query text. Unused here (matching relies on the
            pre-normalised ``query_words`` / ``query_norm``); kept so existing
            callers keep working.
        tag2aliases: Map from tag to its list of alias strings.
        query_words: Words of the normalised query.
        query_norm: The normalised query string.
        fuzzy_threshold: Passed through to ``_alias_matches_query``.

    Returns:
        True if the tag name or any alias matches the query.
    """
    def _hits(name: str) -> bool:
        normalized = _normalize_for_matching(name)
        if not normalized:
            return False
        return _alias_matches_query(normalized, query_words, query_norm, fuzzy_threshold)

    # Check the tag name itself first; aliases are only looked up if it misses.
    if _hits(tag):
        return True

    return any(_hits(alias) for alias in tag2aliases.get(tag, []))


def llm_select_indices(
    query_text: str,                 # kept for compatibility; treated as IMAGE DESCRIPTION
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
    max_pick: int,                         # legacy param; applied after union + ordering (optional)
    log,
    retries: int = 2,
    *,
    mode: str = "chunked_map_union",       # "single_shot" or "chunked_map_union"
    chunk_size: int = 60,
    per_phrase_k: int = 2,                 # per-call budget = per_phrase_k * phrases_in_call
    temperature: float = 0.0,
    max_tokens: int = 512,
    return_metadata: bool = False,
    min_why: Optional[str] = None,
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
    """Return indices into the ORIGINAL candidates list (legacy interface).

    min_why: if set, only keep tags whose 'why' is at or above this confidence
             level.  E.g. min_why="explicit" keeps only explicit matches;
             min_why="strong_implied" keeps explicit + strong_implied.

    This implementation uses LangChain ONLY.

    NOTE: query_text is treated as the image description (original prompt).
    """

    image_description = query_text

    # Normalize candidates:
    # - preferred: List[Candidate]
    # - legacy: List[(tag, sim)] (count/sources unavailable)
    norm: List[Candidate] = []
    tag_to_first_index: Dict[str, int] = {}

    branch = "empty"
    cand0_type = type(candidates[0]).__name__ if candidates else "none"

    if candidates and isinstance(candidates[0], Candidate):
        branch = "candidate"
        typed_candidates = cast(Sequence[Candidate], candidates)
        for idx, c in enumerate(typed_candidates):
            if c.tag not in tag_to_first_index:
                tag_to_first_index[c.tag] = idx
                norm.append(c)
    elif candidates and isinstance(candidates[0], str):
        branch = "string"
        typed_candidates = cast(Sequence[str], candidates)
        for idx, tag in enumerate(typed_candidates):
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
                norm.append(
                    Candidate(
                        tag=tag,
                        score_combined=0.0,
                        score_fasttext=None,
                        score_context=None,
                        count=None,
                        sources=[],
                    )
                )
    else:
        if candidates:
            branch = "tuple"
        typed_candidates = cast(Sequence[Tuple[str, float]], candidates)
        for idx, row in enumerate(typed_candidates):
            if not isinstance(row, (list, tuple)) or len(row) < 2:
                raise ValueError("Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples.")
            tag, sim = row[0], row[1]
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
                norm.append(
                    Candidate(
                        tag=tag,
                        score_combined=float(sim),
                        score_fasttext=None,
                        score_context=None,
                        count=None,
                        sources=[],
                    )
                )

    if log:
        if norm:
            log(
                "Stage3 input: "
                f"type0={cand0_type} "
                f"branch={branch} "
                f"norm0_score={norm[0].score_combined!r} "
                f"norm0_sources_empty={not bool(norm[0].sources)}"
            )
        else:
            log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)")

    if mode not in ("single_shot", "chunked_map_union"):
        raise ValueError(f"Invalid mode: {mode}")

    response_format = _build_response_format()
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens, response_format=response_format)
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")

    parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse)

    # Global union: tag -> best (score, why)
    best: Dict[str, Tuple[float, str]] = {}

    def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
        # Create chain with the provided system template
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", USER_TEMPLATE),
            ],
            template_format="f-string",
        )
        chain = prompt | llm | parser

        ordered = _interleave_round_robin(call_cands)
        candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
        N_local = len(idx_to_tag)

        phrases = _phrases_in_call(call_cands)
        per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
        summary_logged = False

        if log:
            log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}")
            if phrases > 0:
                distinct_phrases = sorted({src for c in call_cands for src in c.sources})
                log(
                    f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} "
                    f"phrases={', '.join(distinct_phrases)}"
                )

        # Invoke LangChain chain (templating fills {N} and other vars)
        for att in range(retries + 1):
            try:
                if log:
                    log(
                        f"Stage3 {label}: "
                        f"model={model_name} "
                        f"N={N_local} "
                        f"phrases={phrases} "
                        f"per_call_budget={per_call_budget} "
                        f"response_healing=on"
                    )

                parsed = chain.invoke(
                    {
                        "N": N_local,
                        "image_description": image_description,
                        "candidate_lines": candidate_lines,
                        "per_call_budget": per_call_budget,
                    }
                )
                selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
                if log:
                    log(f"Stage3 {label}: attempt {att+1} diag={diag}")
                    if not summary_logged and (selected or att == retries):
                        log(
                            f"Stage3 {label}: summary "
                            f"N={N_local} selected={len(selected)} per_call_budget={per_call_budget}"
                        )
                        summary_logged = True
                    if selected:
                        lines = [
                            f"Stage3 {label} selections:",
                            *[
                                (
                                    f'  - i={s.i} tag="{s.tag}" '
                                    f"why={s.why} score={s.score:.2f} "
                                    f"sources={idx_to_candidate.get(s.i).sources if idx_to_candidate.get(s.i) else []}"
                                )
                                for s in selected
                            ],
                        ]
                        log("\n".join(lines))
                    else:
                        log(f"Stage3 {label} selections: (none)")

                if selected:
                    for s in selected:
                        prev = best.get(s.tag)
                        if prev is None or s.score > prev[0]:
                            best[s.tag] = (s.score, s.why)
                    return

            except Exception as e:
                if log:
                    log(f"Stage3 {label}: attempt {att+1} error: {e}")

        if log:
            log(f"Stage3 {label}: gave up after {retries+1} attempts")

    # Split candidates by type (general vs entity)
    general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)

    # Extract just the candidates for LLM calls
    general_cands = [cand for _, cand in general_with_idx]
    entity_cands = [cand for _, cand in entity_with_idx]

    # Process general candidates (attributes, actions, species, etc.)
    if general_cands:
        if mode == "single_shot":
            run_call(general_cands, "general_single_shot", SELECT_SYSTEM_TEMPLATE)
        else:
            for start in range(0, len(general_cands), chunk_size):
                run_call(
                    general_cands[start:start + chunk_size],
                    f"general_chunk_{start//chunk_size}",
                    SELECT_SYSTEM_TEMPLATE
                )

    # Process entity candidates (characters only) with alias-based pre-filtering
    if entity_cands:
        tag2aliases = get_tag2aliases()
        qwords = _query_words(image_description)
        qnorm = _normalize_for_matching(image_description)

        filtered_entity_cands: List[Candidate] = []
        filtered_out: List[str] = []

        for cand in entity_cands:
            if _character_matches_via_aliases(
                cand.tag, image_description, tag2aliases, qwords, qnorm
            ):
                filtered_entity_cands.append(cand)
            else:
                filtered_out.append(cand.tag)

        if log:
            log(
                f"Stage3 entity alias filter: "
                f"before={len(entity_cands)} "
                f"after={len(filtered_entity_cands)} "
                f"removed={len(filtered_out)}"
            )
            if filtered_out:
                log(f"Stage3 entity alias filter removed: {filtered_out[:20]}")

        if filtered_entity_cands:
            if mode == "single_shot":
                run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE)
            else:
                for start in range(0, len(filtered_entity_cands), chunk_size):
                    run_call(
                        filtered_entity_cands[start:start + chunk_size],
                        f"entity_chunk_{start//chunk_size}",
                        ENTITY_SYSTEM_TEMPLATE
                    )

    # Apply why threshold: drop tags below the minimum confidence level.
    if min_why is not None:
        max_rank = WHY_RANK.get(min_why, 4)
        before = len(best)
        best = {t: v for t, v in best.items() if WHY_RANK.get(v[1], 4) <= max_rank}
        if log:
            log(f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
                f"before={before} after={len(best)} dropped={before - len(best)}")

    # Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
    count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
    ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)

    # Legacy cap: apply AFTER union + ordering.
    if isinstance(max_pick, int) and max_pick > 0:
        ordered_tags = ordered_tags[:max_pick]

    # Map back to original indices
    out_idx: List[int] = []
    tag_why: Dict[str, str] = {}
    for t in ordered_tags:
        if t in tag_to_first_index:
            out_idx.append(tag_to_first_index[t])
            tag_why[t] = best[t][1]  # why string

    if return_metadata:
        return out_idx, tag_why

    return out_idx