File size: 39,089 Bytes
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
684cf99
c6be992
 
 
 
 
 
 
 
 
 
 
349b999
 
 
 
 
 
 
 
 
 
 
 
 
c6be992
 
 
09a248d
 
 
 
 
 
 
 
 
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349b999
 
c6be992
 
 
 
349b999
 
 
 
 
 
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349b999
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349b999
4968635
349b999
c6be992
 
09a248d
 
 
4968635
09a248d
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09a248d
 
 
 
 
 
 
 
 
c6be992
 
 
 
 
 
 
 
 
 
349b999
c6be992
 
 
349b999
 
 
 
c6be992
 
a16e111
 
 
684cf99
a16e111
684cf99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4f71fe
 
684cf99
f4f71fe
684cf99
f4f71fe
 
684cf99
f4f71fe
 
684cf99
 
 
 
 
 
 
 
 
 
 
 
 
 
a16e111
684cf99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a16e111
46fe384
 
684cf99
 
 
f4f71fe
 
a16e111
684cf99
46fe384
a16e111
46fe384
684cf99
 
 
a16e111
46fe384
 
 
a16e111
 
46fe384
 
a16e111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684cf99
 
 
 
 
 
 
 
 
 
 
a16e111
 
 
 
 
684cf99
a16e111
 
684cf99
a16e111
684cf99
 
 
a16e111
684cf99
a16e111
 
684cf99
a16e111
684cf99
 
 
a16e111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684cf99
 
a16e111
 
 
 
 
 
 
 
 
 
 
 
 
 
684cf99
a16e111
 
 
 
684cf99
a16e111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
# psq_rag/llm/select.py
# Stage 3: Closed-Set Selection (LangChain-only implementation)
#
# This module intentionally uses LangChain for:
# - prompt templating (including {N})
# - LLM call orchestration
# - JSON parsing
#
# There is NO fallback path. If LangChain dependencies are missing, this module
# should fail loudly so you install them.

import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field, SecretStr
from rapidfuzz import fuzz

from psq_rag.retrieval.psq_retrieval import Candidate  # Candidate(tag, score_*, count, sources)
from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases

# Character-typed tags that are generic categories, not actual named characters.
# These leak through the alias filter because they match common words in captions.
# _split_candidates_by_type() checks membership here and sends these tags to the
# general selection list instead of the entity (character) list.
# Entries are canonical tags (underscore form); frozenset gives O(1) membership
# tests and prevents accidental mutation at runtime.
_GENERIC_CHARACTER_TAGS = frozenset({
    "fan_character",
    "background_character",
    "unnamed_character",
    "unknown_character",
    "anonymous_character",
    "viewer",
    "original_character",
})


# Allowed values for the "why" rationale code returned by the LLM.
# The same five strings appear in the JSON schema, in WhyLiteral and in the
# system prompts; keep all of them in sync when editing this list.
WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]

# Ordinal rank: lower = more confident.  Used for threshold filtering
# (a selection is kept when its rank is <= the rank of the `min_why` threshold).
WHY_RANK: Dict[str, int] = {
    "explicit": 0,
    "strong_implied": 1,
    "weak_implied": 2,
    "style_or_meta": 3,
    "other": 4,
}

# Deterministic mapping: ordinal "why" -> numeric score for ordering/debug.
# Scores decrease monotonically with WHY_RANK, so sorting by score descending
# is equivalent to sorting by confidence.
WHY_TO_SCORE: Dict[str, float] = {
    "explicit": 0.90,
    "strong_implied": 0.70,
    "weak_implied": 0.45,
    "style_or_meta": 0.35,
    "other": 0.25,
}


# System prompt for the GENERAL (non-character) selection calls.
#
# IMPORTANT ABOUT TEMPLATING:
# - This string is rendered by LangChain's f-string template engine.
# - Literal JSON braces must be escaped as {{ and }}.
# - {N} is a real template variable and MUST be provided.
# - The "why" values listed in the prompt text mirror WHY_ENUM; update both together.
SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags.

Select the tags that correspond to content that would be visible or depicted in the described image.

The list contains only valid tags; many of them are irrelevant to the image.

Return JSON ONLY matching this schema:

{{
  \"selections\": [
    {{\"i\": <int>, \"why\": \"<one of: explicit|strong_implied|weak_implied|style_or_meta|other>\"}},
    ...
  ]
}}

Rules:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Do NOT output any keys other than \"selections\", and inside each item only the item index \"i\" and \"why\".
- Do select both a general tag and a more specific tag when both apply (for example, \"shirt\" and \"grey shirt\").

Define \"why\" as:
- explicit: directly stated in the image description
- strong_implied: very likely given the description, even if not literally stated
- weak_implied: plausible but not strongly supported by the description
- style_or_meta: stylistic or presentation-related tags only if clearly indicated
- other: fallback category; use sparingly
"""


# System prompt for the ENTITY (character) selection calls.
# Rendered by the same f-string template engine as SELECT_SYSTEM_TEMPLATE:
# literal braces are doubled and {N} is the only template variable.
# The candidates shown with this prompt have already passed the alias-based
# pre-filter, and the prompt instructs the model to always answer "explicit".
ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags.

These character tags have already been pre-filtered to only include characters whose names
(or known aliases) appear in the image description. Your job is to confirm which of these
pre-filtered candidates are the correct match for the character mentioned by the user.

Return JSON ONLY matching this schema:

{{
  \"selections\": [
    {{\"i\": <int>, \"why\": \"explicit\"}},
    ...
  ]
}}

Rules for character selection:
- Choose ONLY from indices 1..{N}.
- Do NOT output tag text.
- Always use \"why\": \"explicit\" for all selections.
- Select the tag that best represents the character as described.
- If the user described a specific variant (e.g. \"pikachu libre\", \"detective pikachu\"),
  select that specific variant tag.
- If the user described only the base character (e.g. just \"pikachu\"), select only the
  base/default tag, NOT costume or variant tags.
- When uncertain between variants, prefer the simplest/most general tag.
"""


# Human-message template shared by the general and entity calls.
# Template variables (all must be supplied at invoke time):
#   {image_description} - the original prompt text
#   {candidate_lines}   - numbered candidate list, one "<index>. <tag>" per line
#   {per_call_budget}   - maximum number of indices the model is asked to return
USER_TEMPLATE = """IMAGE DESCRIPTION:
{image_description}

CANDIDATES (choose by index only):
{candidate_lines}

Select up to {per_call_budget} indices. Output fewer if uncertain.
"""


@dataclass(frozen=True)
class Selected:
    """One validated selection from a single LLM call (immutable)."""

    i: int  # 1-based index local to the call's candidate list
    tag: str  # canonical tag (underscore form)
    why: str  # rationale code; one of WHY_ENUM
    score: float  # numeric score derived from `why` via WHY_TO_SCORE


# Static-typing counterpart of WHY_ENUM, used by the pydantic response models.
# Literal needs the values spelled out, so keep this list identical to WHY_ENUM.
WhyLiteral = Literal["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]


# One entry of the LLM's "selections" array: a candidate index plus a rationale code.
# (Documented with comments rather than a class docstring so the generated
# pydantic schema is left unchanged.)
class Stage3SelectionItem(BaseModel):
    # Index refers to the numbered candidate list shown in the prompt.
    i: int = Field(..., description="1-based index into the candidate list.")
    # Restricted to the WhyLiteral values; anything else fails validation.
    why: WhyLiteral = Field(..., description="Rationale code from the allowed set.")


# Top-level shape of the LLM response; this is the model handed to PydanticOutputParser.
class Stage3SelectionResponse(BaseModel):
    # Defaults to an empty list, so a response without "selections" still validates.
    selections: List[Stage3SelectionItem] = Field(default_factory=list)


def _build_response_format() -> Dict[str, Any]:
    """Build the strict JSON-Schema ``response_format`` payload for Stage 3.

    The schema describes an object with a single required ``selections``
    array whose items carry exactly two required keys: an integer ``i`` and a
    string ``why`` restricted to ``WHY_ENUM``. Extra keys are disallowed at
    both levels.
    """
    # Shape of a single selection entry.
    selection_item: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
            "why": {"type": "string", "enum": WHY_ENUM},
        },
        "required": ["i", "why"],
        "additionalProperties": False,
    }

    # Top-level envelope wrapping the array of entries.
    envelope: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "selections": {"type": "array", "items": selection_item},
        },
        "required": ["selections"],
        "additionalProperties": False,
    }

    json_schema_spec: Dict[str, Any] = {
        "name": "stage3_selection",
        "strict": True,
        "schema": envelope,
    }
    return {"type": "json_schema", "json_schema": json_schema_spec}


def _get_llm(*, temperature: float, max_tokens: int, response_format: Dict[str, Any]) -> ChatOpenAI:
    """Create a ``ChatOpenAI`` client pointed at OpenRouter's OpenAI-compatible API.

    Configuration is read from the environment:
      - ``OPENROUTER_API_KEY`` (required)
      - ``OPENROUTER_MODEL`` (optional; defaults to a Llama 3.1 8B instruct model)
      - ``OPENROUTER_HTTP_REFERER`` / ``OPENROUTER_X_TITLE`` (optional headers)

    Args:
        temperature: Sampling temperature forwarded to the client.
        max_tokens: Forwarded as ``max_completion_tokens``.
        response_format: Structured-output spec sent in the request body.

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is unset or empty.
    """
    raw_key = os.getenv("OPENROUTER_API_KEY")
    if not raw_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set.\n"
            "Set it in your environment before running Stage 3."
        )

    # Optional headers: only sent when the matching env var is non-empty.
    optional_headers = (
        ("HTTP-Referer", "OPENROUTER_HTTP_REFERER"),
        ("X-Title", "OPENROUTER_X_TITLE"),
    )
    headers: Dict[str, str] = {}
    for header_name, env_name in optional_headers:
        env_value = os.getenv(env_name)
        if env_value:
            headers[header_name] = env_value

    # Provider-specific request body fields (OpenAI-compatible).
    # Response Healing plugin reduces malformed-JSON failures (syntax only).
    request_extras = {
        "response_format": response_format,
        "plugins": [{"id": "response-healing"}],
    }

    # OpenRouter OpenAI-compatible endpoint.
    return ChatOpenAI(
        model=os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct"),
        base_url="https://openrouter.ai/api/v1",
        api_key=SecretStr(cast(str, raw_key)),
        temperature=temperature,
        max_completion_tokens=max_tokens,
        default_headers=headers,
        extra_body=request_extras,
    )


def _phrase_key_for_candidate(c: Candidate) -> str:
    """Return the candidate's deterministic "primary phrase" used for grouping.

    This is the smallest entry of ``c.sources`` in sort order, or ``""`` when
    the candidate has no sources.
    """
    phrases = c.sources
    return min(phrases) if phrases else ""


def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
    """Round-robin interleave by primary source phrase.

    Candidates are grouped by ``_phrase_key_for_candidate``. Within each group
    they are ordered by ``score_combined`` descending, then ``count``
    descending (a missing count ranks below every real count, including 0).
    Groups are visited in sorted key order, taking the best remaining
    candidate of each group per round, until every group is exhausted.

    NOTE: counts are used only for ordering; they are NOT shown to the LLM.

    Args:
        cands: Candidates to reorder; the input sequence is not modified.

    Returns:
        A new list containing every input candidate exactly once.
    """
    groups: Dict[str, List[Candidate]] = {}
    for c in cands:
        groups.setdefault(_phrase_key_for_candidate(c), []).append(c)

    def _rank(c: Candidate) -> Tuple[float, int]:
        # Only a missing count maps to -1; `count or -1` would also demote a
        # genuine count of 0. This matches the final ordering in
        # llm_select_indices, which uses the same `is not None` rule.
        return (c.score_combined, c.count if c.count is not None else -1)

    # One best-first bucket per phrase, in deterministic (sorted-key) order.
    buckets = [sorted(groups[k], key=_rank, reverse=True) for k in sorted(groups)]
    rounds = max((len(bucket) for bucket in buckets), default=0)

    # Round r takes the r-th best candidate from every bucket that still has one.
    return [
        bucket[r]
        for r in range(rounds)
        for bucket in buckets
        if r < len(bucket)
    ]


def _display_tag(tag: str) -> str:
    # Display tags with spaces for the LLM, but keep canonical underscores internally.
    return tag.replace("_", " ")


def _format_candidates_local(
    cands: Sequence[Candidate],
) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]:
    """Number the candidates 1..N and build the prompt text plus lookup maps.

    Returns:
        A ``(text, idx_to_tag, idx_to_candidate)`` triple where ``text`` has
        one ``"<index>. <display tag>"`` line per candidate, and both dicts
        are keyed by the same 1-based local index.
    """
    numbered = list(enumerate(cands, start=1))
    idx_to_candidate: Dict[int, Candidate] = dict(numbered)
    idx_to_tag: Dict[int, str] = {pos: cand.tag for pos, cand in numbered}
    text = "\n".join(f"{pos}. {_display_tag(cand.tag)}" for pos, cand in numbered)
    return text, idx_to_tag, idx_to_candidate


def _phrases_in_call(cands: Sequence[Candidate]) -> int:
    """Count the distinct source phrases across all given candidates."""
    return len({phrase for cand in cands for phrase in cand.sources})


def _parse_validate_map(
    parsed: Any,
    idx_to_tag: Dict[int, str],
    per_call_budget: int,
) -> Tuple[List[Selected], Dict[str, Any]]:
    """Validate a parsed LLM response and map its indices to tags.

    Args:
        parsed: Parser output; either a pydantic model or a plain dict with a
            ``"selections"`` list of ``{"i": int, "why": str}`` items.
        idx_to_tag: Valid 1-based local indices mapped to canonical tags.
        per_call_budget: Maximum number of selections to keep; processing
            stops as soon as this many have been accepted.

    Returns:
        ``(selected, diag)``. ``selected`` holds the accepted items in
        response order. ``diag`` counts rejected items (``invalid_items``,
        ``oob_indices``, ``dupe_indices``), records ``kept``, and reports
        ``parse_ok`` (False when the payload or its ``selections`` value has
        the wrong type).
    """
    # Pydantic models are flattened to plain dicts (v2 API first, v1 fallback).
    if isinstance(parsed, BaseModel):
        payload = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()
    else:
        payload = parsed

    stats: Dict[str, Any] = {
        "parse_ok": isinstance(payload, dict),
        "invalid_items": 0,
        "oob_indices": 0,
        "dupe_indices": 0,
        "kept": 0,
    }
    if not stats["parse_ok"]:
        return [], stats

    raw_items = payload.get("selections", [])
    if not isinstance(raw_items, list):
        stats["parse_ok"] = False
        return [], stats

    accepted: List[Selected] = []
    used_indices = set()

    for entry in raw_items:
        if len(accepted) >= per_call_budget:
            break

        # Each check below bumps exactly one counter and skips the entry.
        if not isinstance(entry, dict):
            stats["invalid_items"] += 1
            continue

        index = entry.get("i")
        reason = entry.get("why")

        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(index, bool) or not isinstance(index, int):
            stats["invalid_items"] += 1
        elif index in used_indices:
            stats["dupe_indices"] += 1
        elif index not in idx_to_tag:
            stats["oob_indices"] += 1
        elif not isinstance(reason, str) or reason not in WHY_ENUM:
            stats["invalid_items"] += 1
        else:
            # Only fully valid entries reserve their index.
            used_indices.add(index)
            accepted.append(
                Selected(
                    i=index,
                    tag=idx_to_tag[index],
                    why=reason,
                    score=WHY_TO_SCORE[reason],
                )
            )

    stats["kept"] = len(accepted)
    return accepted, stats


def _split_candidates_by_type(
    candidates: List[Candidate],
    log,
) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]:
    """Partition candidates into general and entity (character) lists.

    Routing by the tag's type name:
        - ``character``: entity list, unless the tag is one of
          ``_GENERIC_CHARACTER_TAGS``, in which case it goes to general.
        - ``copyright``: dropped (too broad for image generation).
        - ``general`` / ``artist`` / ``species`` / ``meta``: general list.
        - anything else (including ``None``): general list, counted as unknown.

    Args:
        candidates: Candidates to split.
        log: Optional callable taking one string; skipped when falsy.

    Returns:
        ``(general_list, entity_list)``; every element is an
        ``(original_index, candidate)`` pair, in input order.
    """
    known_general_types = ("general", "artist", "species", "meta")

    general_with_idx: List[Tuple[int, Candidate]] = []
    entity_with_idx: List[Tuple[int, Candidate]] = []
    dropped_copyright = 0
    rerouted_generic = 0
    untyped = 0

    for position, candidate in enumerate(candidates):
        kind = get_tag_type_name(candidate.tag)

        if kind == "copyright":
            # Filter out copyright/series tags - too broad for image generation
            dropped_copyright += 1
            continue

        if kind == "character" and candidate.tag not in _GENERIC_CHARACTER_TAGS:
            entity_with_idx.append((position, candidate))
            continue

        # Everything left is handled by the general selection pass.
        if kind == "character":
            rerouted_generic += 1
        elif kind not in known_general_types:
            untyped += 1
        general_with_idx.append((position, candidate))

    if log:
        log(
            f"Stage3 split: "
            f"general={len(general_with_idx)} "
            f"entity={len(entity_with_idx)} "
            f"copyright_filtered={dropped_copyright} "
            f"generic_char_to_general={rerouted_generic} "
            f"unknown_type={untyped}"
        )

    return general_with_idx, entity_with_idx


# Regex to strip series/franchise suffixes from aliases, e.g. _(sonic), _(mlp), _(character)
_SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$")


def _normalize_for_matching(text: str) -> str:
    """Lowercase, replace underscores with spaces, strip series suffixes."""
    text = text.lower().strip()
    text = _SERIES_SUFFIX_RE.sub("", text)
    text = text.replace("_", " ")
    return text


def _query_words(query: str) -> Set[str]:
    """Return the set of whitespace-separated words of the normalized query."""
    tokens = _normalize_for_matching(query).split()
    return set(tokens)


def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str,
                         fuzzy_threshold: int = 85) -> bool:
    """Check if an alias matches the user query.

    Matching logic:
    1. Exact substring: alias appears as a substring of the query
    2. Word subset: all words in the alias appear in the query words
    3. Fuzzy: alias is close to a word in the query (handles typos)
    """
    # Exact substring match
    if alias_norm in query_norm:
        return True

    alias_words = alias_norm.split()
    if not alias_words:
        return False

    # Word subset match: all alias words must appear in query
    if all(w in query_words for w in alias_words):
        return True

    # For single-word aliases, try fuzzy matching against each query word
    if len(alias_words) == 1:
        for qw in query_words:
            if fuzz.ratio(alias_words[0], qw) >= fuzzy_threshold:
                return True

    # For multi-word aliases, try fuzzy partial ratio against whole query
    if len(alias_words) > 1:
        if fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold:
            return True

    return False


def _character_matches_via_aliases(
    tag: str,
    query: str,
    tag2aliases: Dict[str, List[str]],
    query_words: Set[str],
    query_norm: str,
    fuzzy_threshold: int = 85,
) -> bool:
    """Decide whether a character tag is referenced by the user query.

    The tag matches when its own normalized name matches the query, or when
    any of its registered aliases does. Aliases that normalize to an empty
    string are skipped. A tag with no alias entry is judged on its name alone.

    Args:
        tag: Canonical character tag.
        query: Raw query text (not read here; matching uses the
            pre-computed ``query_words`` / ``query_norm``).
        tag2aliases: Mapping from tag to its list of aliases.
        query_words: Words of the normalized query.
        query_norm: Normalized query string.
        fuzzy_threshold: Passed through to ``_alias_matches_query``.
    """
    def _hits(normalized_name: str) -> bool:
        return _alias_matches_query(normalized_name, query_words, query_norm, fuzzy_threshold)

    # The tag's own name is tried before any alias.
    if _hits(_normalize_for_matching(tag)):
        return True

    normalized_aliases = (_normalize_for_matching(a) for a in tag2aliases.get(tag, []))
    return any(name and _hits(name) for name in normalized_aliases)


def llm_select_indices(
    query_text: str,                 # kept for compatibility; treated as IMAGE DESCRIPTION
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
    max_pick: int,                         # legacy param; applied after union + ordering (optional)
    log,
    retries: int = 2,
    *,
    mode: str = "chunked_map_union",       # "single_shot" or "chunked_map_union"
    chunk_size: int = 60,
    per_phrase_k: int = 2,                 # per-call budget = per_phrase_k * phrases_in_call
    temperature: float = 0.0,
    max_tokens: int = 512,
    return_metadata: bool = False,
    min_why: Optional[str] = "strong_implied",
) -> Union[List[int], Tuple[List[int], Dict[str, str]]]:
    """Return indices into the ORIGINAL candidates list (legacy interface).

    Pipeline:
      1. Normalize ``candidates`` to ``Candidate`` objects, de-duplicating by
         tag (the first occurrence's index is remembered).
      2. Split into general vs. character ("entity") candidates; copyright
         tags are dropped by ``_split_candidates_by_type``.
      3. Run the LLM over the general candidates (one call, or one call per
         chunk of ``chunk_size``).
      4. Keep only character candidates whose name or alias matches the
         description, then run the LLM over those with the entity prompt.
      5. Union the results per tag (highest score wins), apply the ``min_why``
         threshold, order, cap at ``max_pick`` and map back to input indices.

    This implementation uses LangChain ONLY.

    NOTE: query_text is treated as the image description (original prompt).

    Args:
        query_text: The image description shown to the LLM.
        candidates: ``Candidate`` objects, tag strings, or ``(tag, score)``
            rows. The representation is detected from the first element.
        max_pick: If a positive int, keep at most this many tags after
            ordering; any other value disables the cap.
        log: Optional callable taking one string; skipped when falsy.
        retries: Extra attempts per LLM call (``retries + 1`` attempts total).
        mode: ``"single_shot"`` or ``"chunked_map_union"``.
        chunk_size: Candidates per call in chunked mode.
        per_phrase_k: Per-call budget is ``per_phrase_k`` times the number of
            distinct source phrases in the call (at least 1); when the call's
            candidates carry no sources, the budget is ``per_phrase_k``.
        temperature: Sampling temperature for the LLM.
        max_tokens: Completion token limit for the LLM.
        return_metadata: If True, also return a ``{tag: why}`` dict.
        min_why: if set, only keep tags whose 'why' is at or above this
            confidence level.  E.g. min_why="explicit" keeps only explicit
            matches; min_why="strong_implied" keeps explicit + strong_implied.
            Default: "strong_implied".  A value not present in ``WHY_RANK``
            is treated as rank 4, i.e. nothing is filtered out.

    Returns:
        Indices into ``candidates`` ordered by derived score descending, ties
        broken by count descending. With ``return_metadata=True`` a tuple
        ``(indices, tag_why)`` is returned instead.

    Raises:
        ValueError: If ``mode`` is invalid, or a legacy row is not a
            sequence of at least two items.
        RuntimeError: From ``_get_llm`` when ``OPENROUTER_API_KEY`` is unset.
            The client is built before candidates are inspected, so this
            applies even when ``candidates`` is empty.
    """

    image_description = query_text

    # Normalize candidates:
    # - preferred: List[Candidate]
    # - legacy: List[(tag, sim)] (count/sources unavailable)
    norm: List[Candidate] = []
    # tag -> index of its first occurrence in the caller's list; used at the
    # end to translate selected tags back into original indices.
    tag_to_first_index: Dict[str, int] = {}

    # `branch` and `cand0_type` exist only for the diagnostic log line below.
    branch = "empty"
    cand0_type = type(candidates[0]).__name__ if candidates else "none"

    if candidates and isinstance(candidates[0], Candidate):
        branch = "candidate"
        typed_candidates = cast(Sequence[Candidate], candidates)
        for idx, c in enumerate(typed_candidates):
            if c.tag not in tag_to_first_index:
                tag_to_first_index[c.tag] = idx
                norm.append(c)
    elif candidates and isinstance(candidates[0], str):
        branch = "string"
        typed_candidates = cast(Sequence[str], candidates)
        for idx, tag in enumerate(typed_candidates):
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
                # Bare tags carry no retrieval information: neutral score, no sources.
                norm.append(
                    Candidate(
                        tag=tag,
                        score_combined=0.0,
                        score_fasttext=None,
                        score_context=None,
                        count=None,
                        sources=[],
                    )
                )
    else:
        # Reached for (tag, score) rows and also for an empty input, in which
        # case the loop body never runs and `branch` stays "empty".
        if candidates:
            branch = "tuple"
        typed_candidates = cast(Sequence[Tuple[str, float]], candidates)
        for idx, row in enumerate(typed_candidates):
            if not isinstance(row, (list, tuple)) or len(row) < 2:
                raise ValueError("Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples.")
            tag, sim = row[0], row[1]
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
                norm.append(
                    Candidate(
                        tag=tag,
                        score_combined=float(sim),
                        score_fasttext=None,
                        score_context=None,
                        count=None,
                        sources=[],
                    )
                )

    if log:
        if norm:
            log(
                "Stage3 input: "
                f"type0={cand0_type} "
                f"branch={branch} "
                f"norm0_score={norm[0].score_combined!r} "
                f"norm0_sources_empty={not bool(norm[0].sources)}"
            )
        else:
            log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)")

    if mode not in ("single_shot", "chunked_map_union"):
        raise ValueError(f"Invalid mode: {mode}")

    # NOTE(review): the client is built unconditionally, so a missing API key
    # raises even when there is nothing to select from - confirm this is intended.
    response_format = _build_response_format()
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens, response_format=response_format)
    # Read again here purely for logging; _get_llm resolves the model the same way.
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")

    parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse)

    # Global union: tag -> best (score, why)
    best: Dict[str, Tuple[float, str]] = {}

    def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
        """Run one LLM selection call (with retries) and merge results into `best`.

        Errors raised while invoking the chain are logged and retried; after
        the final attempt the call is abandoned without raising.
        """
        # Create chain with the provided system template
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", USER_TEMPLATE),
            ],
            template_format="f-string",
        )
        chain = prompt | llm | parser

        # Indices shown to the LLM are local to this call (1..N_local) and
        # follow the interleaved order, not the caller's order.
        ordered = _interleave_round_robin(call_cands)
        candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
        N_local = len(idx_to_tag)

        phrases = _phrases_in_call(call_cands)
        # Without source phrases (legacy inputs) the budget falls back to per_phrase_k.
        per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
        # Ensures the one-line summary is logged at most once per call.
        summary_logged = False

        if log:
            log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}")
            if phrases > 0:
                distinct_phrases = sorted({src for c in call_cands for src in c.sources})
                log(
                    f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} "
                    f"phrases={', '.join(distinct_phrases)}"
                )

        # Invoke LangChain chain (templating fills {N} and other vars)
        for att in range(retries + 1):
            try:
                if log:
                    log(
                        f"Stage3 {label}: "
                        f"model={model_name} "
                        f"N={N_local} "
                        f"phrases={phrases} "
                        f"per_call_budget={per_call_budget} "
                        f"response_healing=on"
                    )

                parsed = chain.invoke(
                    {
                        "N": N_local,
                        "image_description": image_description,
                        "candidate_lines": candidate_lines,
                        "per_call_budget": per_call_budget,
                    }
                )
                selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
                if log:
                    log(f"Stage3 {label}: attempt {att+1} diag={diag}")
                    # Summarize on the first successful attempt, or on the last attempt regardless.
                    if not summary_logged and (selected or att == retries):
                        log(
                            f"Stage3 {label}: summary "
                            f"N={N_local} selected={len(selected)} per_call_budget={per_call_budget}"
                        )
                        summary_logged = True
                    if selected:
                        lines = [
                            f"Stage3 {label} selections:",
                            *[
                                (
                                    f'  - i={s.i} tag="{s.tag}" '
                                    f"why={s.why} score={s.score:.2f} "
                                    f"sources={idx_to_candidate.get(s.i).sources if idx_to_candidate.get(s.i) else []}"
                                )
                                for s in selected
                            ],
                        ]
                        log("\n".join(lines))
                    else:
                        log(f"Stage3 {label} selections: (none)")

                # NOTE(review): a well-formed but empty selection list does not
                # return here, so it is retried like a failure - confirm intended.
                if selected:
                    for s in selected:
                        prev = best.get(s.tag)
                        # Keep the highest-scoring rationale seen for each tag across calls.
                        if prev is None or s.score > prev[0]:
                            best[s.tag] = (s.score, s.why)
                    return

            except Exception as e:
                # Best-effort: log and fall through to the next attempt.
                if log:
                    log(f"Stage3 {label}: attempt {att+1} error: {e}")

        if log:
            log(f"Stage3 {label}: gave up after {retries+1} attempts")

    # Split candidates by type (general vs entity)
    general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)

    # Extract just the candidates for LLM calls
    general_cands = [cand for _, cand in general_with_idx]
    entity_cands = [cand for _, cand in entity_with_idx]

    # Process general candidates (attributes, actions, species, etc.)
    if general_cands:
        if mode == "single_shot":
            run_call(general_cands, "general_single_shot", SELECT_SYSTEM_TEMPLATE)
        else:
            for start in range(0, len(general_cands), chunk_size):
                run_call(
                    general_cands[start:start + chunk_size],
                    f"general_chunk_{start//chunk_size}",
                    SELECT_SYSTEM_TEMPLATE
                )

    # Process entity candidates (characters only) with alias-based pre-filtering
    if entity_cands:
        tag2aliases = get_tag2aliases()
        # Normalize the description once; both forms are reused for every candidate.
        qwords = _query_words(image_description)
        qnorm = _normalize_for_matching(image_description)

        filtered_entity_cands: List[Candidate] = []
        filtered_out: List[str] = []

        for cand in entity_cands:
            if _character_matches_via_aliases(
                cand.tag, image_description, tag2aliases, qwords, qnorm
            ):
                filtered_entity_cands.append(cand)
            else:
                filtered_out.append(cand.tag)

        if log:
            log(
                f"Stage3 entity alias filter: "
                f"before={len(entity_cands)} "
                f"after={len(filtered_entity_cands)} "
                f"removed={len(filtered_out)}"
            )
            if filtered_out:
                # Only the first 20 removed tags are logged to keep the line short.
                log(f"Stage3 entity alias filter removed: {filtered_out[:20]}")

        if filtered_entity_cands:
            if mode == "single_shot":
                run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE)
            else:
                for start in range(0, len(filtered_entity_cands), chunk_size):
                    run_call(
                        filtered_entity_cands[start:start + chunk_size],
                        f"entity_chunk_{start//chunk_size}",
                        ENTITY_SYSTEM_TEMPLATE
                    )

    # Apply why threshold: drop tags below the minimum confidence level.
    if min_why is not None:
        # Unknown min_why values fall back to rank 4, which keeps everything.
        max_rank = WHY_RANK.get(min_why, 4)
        before = len(best)
        best = {t: v for t, v in best.items() if WHY_RANK.get(v[1], 4) <= max_rank}
        if log:
            log(f"Stage3 why filter: min_why={min_why} (rank<={max_rank}), "
                f"before={before} after={len(best)} dropped={before - len(best)}")

    # Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
    count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
    ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)

    # Legacy cap: apply AFTER union + ordering.
    if isinstance(max_pick, int) and max_pick > 0:
        ordered_tags = ordered_tags[:max_pick]

    # Map back to original indices
    out_idx: List[int] = []
    tag_why: Dict[str, str] = {}
    for t in ordered_tags:
        if t in tag_to_first_index:
            out_idx.append(tag_to_first_index[t])
            tag_why[t] = best[t][1]  # why string

    if return_metadata:
        return out_idx, tag_why

    return out_idx


# ---------------------------------------------------------------------------
# Stage 3s: Structural tag inference (solo/duo/male/female/anthro/… )
# ---------------------------------------------------------------------------
# Group-based approach: tags are organized into semantic groups loaded from
# tag_groups.json / tag_wiki_defs.json where possible, with curated fallback
# definitions for tags whose wiki entries are only thumbnail references.
#
# Each group specifies a constraint mode:
#   "exclusive" = pick exactly one (e.g. character count)
#   "multi"     = pick all that apply (e.g. body type, gender)

import json as _json

@dataclass
class StructuralGroup:
    """One category of structural tags to probe.

    A group bundles related tags (for example all character-count tags)
    together with the selection rule that is printed in the group's header
    line of the structural prompt.
    """
    # Group identifier such as "character_count"; rendered upper-cased with
    # underscores replaced by spaces in the prompt header.
    name: str
    # Selection rule for the group: "exclusive" is rendered as
    # "pick EXACTLY ONE", any other value as "pick ALL that apply".
    constraint: str  # "exclusive" or "multi"
    # Ordered (tag, definition) pairs. The definition is the statement text
    # shown to the LLM; the tag is what a selected statement maps back to.
    tags: List[Tuple[str, str]]  # (tag, definition) pairs

def _load_structural_groups() -> List[StructuralGroup]:
    """Build the structural tag groups probed by Stage 3s.

    Group membership and ordering are curated here in code. Each tag's
    definition is taken from ``data/tag_wiki_defs.json`` when that file holds
    a usable text entry for the tag; otherwise the curated fallback written
    next to the tag is used.

    The wiki file is optional and handled as best-effort: if it is missing,
    every tag silently uses its fallback; if it exists but cannot be read, is
    not valid JSON, or is not a JSON object, a ``RuntimeWarning`` is issued
    and every tag uses its fallback instead of aborting the stage.

    The order of groups and of tags inside each group matters: statement
    numbers in the prompt are assigned sequentially from this order, and the
    worked example in ``STRUCTURAL_SYSTEM_TEMPLATE`` refers to them by number.

    Returns:
        The groups in prompt order: character_count, body_type, gender,
        clothing_state, visual_elements.
    """
    import warnings  # local import: only needed on the degraded-data path

    data_dir = Path(__file__).resolve().parents[2] / "data"

    # Load wiki definitions (may not exist yet)
    wiki_defs: Dict[str, str] = {}
    wiki_path = data_dir / "tag_wiki_defs.json"
    if wiki_path.is_file():
        try:
            with wiki_path.open("r", encoding="utf-8") as f:
                loaded = _json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            warnings.warn(
                f"Could not load {wiki_path} ({exc}); "
                "using curated structural tag definitions",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            if isinstance(loaded, dict):
                wiki_defs = loaded
            else:
                warnings.warn(
                    f"{wiki_path} does not contain a JSON object; "
                    "using curated structural tag definitions",
                    RuntimeWarning,
                    stacklevel=2,
                )

    def _def(tag: str, fallback: str) -> str:
        """Get wiki definition if it's real text, otherwise use fallback."""
        d = wiki_defs.get(tag, "")
        # A non-string entry (null, list, nested object) is unusable as text.
        if not isinstance(d, str):
            return fallback
        # Skip thumbnail-only definitions
        if not d or d.startswith("thumb ") or len(d) < 15:
            return fallback
        return d[:200]  # cap length for prompt

    groups: List[StructuralGroup] = []

    # ── Group A: Character Count (exclusive) ──
    groups.append(StructuralGroup(
        name="character_count",
        constraint="exclusive",
        tags=[
            ("zero_pictured", _def("zero_pictured",
                "No characters or living beings appear in the image")),
            ("solo", _def("solo",
                "Exactly one character appears in the image")),
            ("duo", _def("duo",
                "Exactly two characters appear in the image")),
            ("trio", _def("trio",
                "Exactly three characters appear in the image")),
            ("group", _def("group",
                "Four or more characters appear in the image")),
        ],
    ))

    # ── Group B: Body Type (multi — per character) ──
    # Key distinction the LLM must learn:
    #   anthro = ANIMAL with human body shape (upright, hands)
    #   humanoid = HUMAN or near-human (elf, dwarf) with NO animal features
    #   feral = normal animal shape, on all fours
    groups.append(StructuralGroup(
        name="body_type",
        constraint="multi",
        tags=[
            ("anthro", _def("anthro",
                "An animal character with a human-like body: walks upright on two legs, "
                "has arms and hands. Examples: a wolf-person, a fox standing up. "
                "Still has animal features like fur, tail, muzzle")),
            ("feral", _def("feral",
                "A regular animal in its natural body shape. Walks on all fours (or "
                "flies/swims naturally). NOT standing upright, NOT humanized")),
            ("humanoid", _def("humanoid",
                "A human or human-like character with NO animal features. Includes "
                "humans, elves, dwarves, and fantasy races that look human. "
                "Does NOT include animal-people — those are anthro")),
            ("taur", _def("taur",
                "A centaur-like body: human or anthro upper body attached to a "
                "four-legged animal lower body")),
        ],
    ))

    # ── Group C: Gender (multi — per character) ──
    groups.append(StructuralGroup(
        name="gender",
        constraint="multi",
        tags=[
            ("male", _def("male",
                "A character described as male, a boy, or with he/him pronouns")),
            ("female", _def("female",
                "A character described as female, a girl, or with she/her pronouns")),
            ("ambiguous_gender", _def("ambiguous_gender",
                "A character whose gender is not stated or cannot be determined")),
            ("intersex", _def("intersex",
                "A character explicitly described as intersex or hermaphrodite")),
        ],
    ))

    # ── Group D: Clothing State (multi) ──
    groups.append(StructuralGroup(
        name="clothing_state",
        constraint="multi",
        tags=[
            ("clothed", _def("clothed",
                "Wearing clothes on BOTH chest/torso AND legs/waist. "
                "Examples: shirt and pants, dress, full outfit")),
            ("nude", _def("nude",
                "Wearing NO clothes at all. Completely naked, no shirt and no pants")),
            ("topless", _def("topless",
                "NO shirt/top (bare chest), BUT wearing pants/bottoms. "
                "Upper body exposed, lower body covered")),
            ("bottomless", _def("bottomless",
                "Wearing shirt/top on chest, BUT NO pants/bottoms. "
                "Upper body covered, lower body exposed")),
        ],
    ))

    # ── Group E: Common Visual Elements (multi) ──
    groups.append(StructuralGroup(
        name="visual_elements",
        constraint="multi",
        tags=[
            ("looking_at_viewer", _def("looking_at_viewer",
                "A character is looking directly at the camera or viewer")),
            ("text", _def("text",
                "The image contains visible writing, words, or lettering")),
        ],
    ))

    return groups


def _build_structural_prompt(groups: List[StructuralGroup]) -> Tuple[str, List[Tuple[str, str]]]:
    """Render the groups as one numbered statement list for the LLM.

    Every group contributes a header line naming the group and its selection
    rule, then one numbered line per tag definition, then a blank separator
    line. Numbering is 1-based and continues across groups.

    Returns:
        ``(text, numbered)`` where ``text`` is the newline-joined prompt
        section and ``numbered[k - 1]`` is the ``(tag, definition)`` pair
        behind statement number ``k``.
    """
    numbered: List[Tuple[str, str]] = []
    rendered: List[str] = []

    for group in groups:
        if group.constraint == "exclusive":
            rule = "pick EXACTLY ONE"
        else:
            rule = "pick ALL that apply"
        title = group.name.replace("_", " ").upper()
        rendered.append(f"--- {title} ({rule}) ---")

        for tag, definition in group.tags:
            numbered.append((tag, definition))
            # The statement number is the pair's 1-based position in `numbered`.
            rendered.append(f"{len(numbered)}. {definition}")

        rendered.append("")  # blank separator after each group

    return "\n".join(rendered), numbered


# System prompt for Stage 3s structural tag inference.
#
# The text is rendered by ChatPromptTemplate with template_format="f-string"
# (see llm_infer_structural_tags), so literal JSON braces are doubled ({{ }}).
#
# NOTE: the statement numbers used in the EXAMPLE (2, 6, 10, 14) are hard-coded
# and match the numbering produced by the curated group/tag order in
# _load_structural_groups (2=solo, 6=anthro, 10=male, 14=clothed). Reordering
# or inserting tags there requires updating the example below.
# NOTE(review): the example maps "wearing jeans" to clothed, while the curated
# fallback definition of "clothed" requires both torso and leg coverage —
# confirm this is intentional.
STRUCTURAL_SYSTEM_TEMPLATE = """You classify image descriptions by selecting true statements from a numbered list.

The statements are organized into GROUPS. Each group header tells you how many to pick:
- "pick EXACTLY ONE" = choose the single best match in that group
- "pick ALL that apply" = choose every statement that is true

IMPORTANT RULES:
1. ONLY select a statement if the description directly says it or makes it very obvious.
2. Do NOT guess or assume things the description does not mention.
3. For body type: "anthro" means an ANIMAL with a human-shaped body (walks upright, has hands, but still has fur/tail/muzzle). "humanoid" means HUMAN or human-like with NO animal features. A wolf standing on two legs = anthro, NOT humanoid.
4. If the description never mentions gender, pick "gender cannot be determined".
5. For clothing state: READ CAREFULLY! "topless" = bare chest, wearing pants. "bottomless" = wearing shirt, no pants. If unsure, re-read the description.
6. If clothing is not mentioned, do NOT pick any clothing statement.

Return JSON ONLY:
{{"selections": [{{"i": 1}}, {{"i": 5}}]}}

EXAMPLE:
Description: "A muscular male wolf standing in a forest, wearing jeans, giving a thumbs up"
Answer: {{"selections": [{{"i": 2}}, {{"i": 6}}, {{"i": 10}}, {{"i": 14}}]}}
Why: One character = solo (2). Wolf standing upright with hands = anthro (6), NOT humanoid because it is a wolf. Male (10). Wearing jeans = clothed (14)."""

# User (human) message for Stage 3s. It has two f-string placeholders, both
# supplied by llm_infer_structural_tags when the chain is invoked:
#   {image_description} - the free-text description to classify
#   {statement_lines}   - the numbered statement list built by
#                         _build_structural_prompt
STRUCTURAL_USER_TEMPLATE = """Read this image description and select which statements are true.

IMAGE DESCRIPTION:
{image_description}

STATEMENTS (pick by number):
{statement_lines}"""


# A single statement picked by the LLM in the Stage 3s structural prompt.
# (Documented with comments rather than a class docstring so the model's
# generated JSON schema is left untouched.)
class StructuralSelectionItem(BaseModel):
    # 1-based statement number; llm_infer_structural_tags maps it back to a
    # tag and ignores values outside 1..N.
    i: int = Field(..., description="1-based index into the statement list.")


# Top-level structured response for Stage 3s: the list of picked statements.
# (Documented with comments rather than a class docstring so the model's
# generated JSON schema is left untouched.)
class StructuralSelectionResponse(BaseModel):
    # Falls back to an empty list when the key is absent (default_factory=list).
    selections: List[StructuralSelectionItem] = Field(default_factory=list)


def _build_structural_response_format() -> Dict[str, Any]:
    """Return the strict ``json_schema`` response format for Stage 3s.

    The schema describes an object with a single required ``selections``
    array whose items are objects holding one required integer ``i``; extra
    properties are disallowed at both levels. A fresh dict is built on every
    call.
    """
    # One picked statement: {"i": <int>}
    selection_item = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
        },
        "required": ["i"],
        "additionalProperties": False,
    }

    # Whole response: {"selections": [<selection_item>, ...]}
    root_schema = {
        "type": "object",
        "properties": {
            "selections": {
                "type": "array",
                "items": selection_item,
            }
        },
        "required": ["selections"],
        "additionalProperties": False,
    }

    json_schema = {
        "name": "structural_selection",
        "strict": True,
        "schema": root_schema,
    }
    return {"type": "json_schema", "json_schema": json_schema}


# Process-wide memo of the structural groups: None means "not loaded yet";
# once a load succeeds the same list is reused so the data files are not
# read again.
_cached_structural_groups: Optional[List[StructuralGroup]] = None


def _get_structural_groups() -> List[StructuralGroup]:
    """Return the structural groups, loading them on first use.

    The result of ``_load_structural_groups()`` is stored in the module-level
    ``_cached_structural_groups`` and the same list object is handed back on
    later calls. Only ``None`` triggers a load; an empty list counts as loaded.
    """
    global _cached_structural_groups
    groups = _cached_structural_groups
    if groups is None:
        groups = _cached_structural_groups = _load_structural_groups()
    return groups


def _structural_tags_from_selections(
    selections: Any,
    flat_tags: List[Tuple[str, str]],
    exclusive_group_of: Dict[str, str],
) -> Tuple[List[str], List[str]]:
    """Map raw LLM selections to tags, enforcing exclusive groups.

    Args:
        selections: Items shaped like ``{"i": <1-based statement number>}``;
            anything else is skipped.
        flat_tags: ``(tag, definition)`` pairs where position ``k - 1`` belongs
            to statement number ``k``.
        exclusive_group_of: Maps each tag of an "exclusive" group to its group
            name; tags of "multi" groups are absent.

    Returns:
        ``(kept, dropped)``. ``kept`` holds the unique tags in selection order.
        ``dropped`` holds tags discarded because an earlier selection already
        filled their exclusive group.
    """
    kept: List[str] = []
    dropped: List[str] = []
    seen: Set[str] = set()
    filled_groups: Set[str] = set()

    for item in selections or []:
        idx = item.get("i") if isinstance(item, dict) else None
        # bool is a subclass of int; True must not be read as statement 1.
        if isinstance(idx, bool) or not isinstance(idx, int):
            continue
        if idx < 1 or idx > len(flat_tags):
            continue
        tag = flat_tags[idx - 1][0]
        if tag in seen:
            continue
        seen.add(tag)

        group = exclusive_group_of.get(tag)
        if group is not None:
            if group in filled_groups:
                # A second pick from a "pick EXACTLY ONE" group would yield
                # contradictory tags (e.g. solo + duo); keep only the first.
                dropped.append(tag)
                continue
            filled_groups.add(group)
        kept.append(tag)

    return kept, dropped


def llm_infer_structural_tags(
    query_text: str,
    log=None,
    *,
    temperature: float = 0.0,
    max_tokens: int = 512,
    retries: int = 2,
) -> List[str]:
    """Infer structural tags via LLM using group-based statement agreement.

    Probes multiple semantic groups (character count, body type, gender,
    clothing state, visual elements) with definitions loaded from wiki data
    where available. The LLM answers with statement numbers, which are mapped
    back to tags. Groups whose constraint is "exclusive" contribute at most
    one tag: if the model picks several statements from such a group, only
    the first one is kept and the others are logged and discarded.

    Args:
        query_text: Image description to classify.
        log: Optional callable taking one message string; used for progress
            and error reporting.
        temperature: Sampling temperature passed to the LLM factory.
        max_tokens: Completion token limit passed to the LLM factory.
        retries: Number of additional attempts after a failed call or parse
            (``retries + 1`` attempts in total).

    Returns:
        A list of e621 tag strings (e.g. ["solo", "anthro", "male", "clothed"])
        in selection order, or an empty list if every attempt failed.
    """
    if log:
        log("Stage3s (structural): inferring structural tags via group-based statement agreement")

    groups = _get_structural_groups()
    statement_lines, flat_tags = _build_structural_prompt(groups)
    N = len(flat_tags)

    # Tags that belong to a "pick exactly one" group, keyed to their group name.
    exclusive_group_of: Dict[str, str] = {
        tag: g.name
        for g in groups
        if g.constraint == "exclusive"
        for tag, _defn in g.tags
    }

    response_format = _build_structural_response_format()
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens,
                    response_format=response_format)
    # Only used for the log line below.
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")

    parser = PydanticOutputParser(pydantic_object=StructuralSelectionResponse)

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", STRUCTURAL_SYSTEM_TEMPLATE),
            ("human", STRUCTURAL_USER_TEMPLATE),
        ],
        template_format="f-string",
    )
    chain = prompt | llm | parser

    if log:
        group_summary = ", ".join(f"{g.name}({len(g.tags)})" for g in groups)
        log(f"Stage3s: model={model_name} groups=[{group_summary}] total_statements={N}")

    for att in range(retries + 1):
        try:
            parsed = chain.invoke({
                # NOTE(review): neither template references {N}; it is passed
                # through unchanged to keep the invocation payload as before.
                "N": N,
                "image_description": query_text,
                "statement_lines": statement_lines,
            })

            if isinstance(parsed, BaseModel):
                parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()

            sels = parsed.get("selections", []) if isinstance(parsed, dict) else []
            chosen_tags, dropped_tags = _structural_tags_from_selections(
                sels, flat_tags, exclusive_group_of
            )

            if log:
                if dropped_tags:
                    log(f"Stage3s: attempt {att+1} dropped extra picks from exclusive groups: "
                        f"{', '.join(dropped_tags)}")
                tag_str = ", ".join(chosen_tags) if chosen_tags else "(none)"
                log(f"Stage3s: attempt {att+1} selected {len(chosen_tags)} tags: {tag_str}")

            return chosen_tags

        except Exception as e:
            # Best-effort stage: any LLM/parse failure is logged and retried.
            if log:
                log(f"Stage3s: attempt {att+1} error: {e}")

    if log:
        log(f"Stage3s: gave up after {retries+1} attempts")
    return []