Claude committed on
Commit
684cf99
·
1 Parent(s): c14936f

Redesign structural inference as group-based system with wiki data

Browse files

- Organize structural tags into semantic groups (character count, body type,
gender, clothing state, visual elements) with explicit constraints
- Load definitions from tag_wiki_defs.json where text exists, fall back to
curated definitions for thumbnail-only wiki entries
- Add clothing state group (clothed/nude/topless/bottomless) and visual
elements group (looking_at_viewer/text) to address top misses
- Improve anthro vs humanoid distinction with clearer definitions and example
- Add taur to body type group
- Fix extract_wiki_data.py: filter "top" navigation artifacts, skip
thumbnail-only definitions, deduplicate group members
- Update analyze_compact_eval.py structural tag set for new groups

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG

psq_rag/llm/select.py CHANGED
@@ -12,6 +12,7 @@
12
  import os
13
  import re
14
  from dataclasses import dataclass
 
15
  from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal
16
 
17
  from langchain_openai import ChatOpenAI
@@ -763,47 +764,184 @@ def llm_select_indices(
763
 
764
 
765
  # ---------------------------------------------------------------------------
766
- # Stage 3s: Structural tag inference (solo/duo/male/female/anthro/biped …)
767
  # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
768
 
769
- # Each statement maps to exactly one tag. The LLM picks statement numbers.
770
- _STRUCTURAL_STATEMENTS: List[Tuple[str, str]] = [
771
- # Character count — exactly one should be picked
772
- ("No characters or living beings appear in the image", "zero_pictured"),
773
- ("There is exactly one character in the image", "solo"),
774
- ("There are exactly two characters in the image", "duo"),
775
- ("There are exactly three characters in the image", "trio"),
776
- ("There are four or more characters in the image", "group"),
777
- # Body plan — pick all that apply across characters
778
- ("A character is a normal animal walking on all fours, not humanized", "feral"),
779
- ("A character is an animal with a human-like body (standing upright on two legs, with hands)", "anthro"),
780
- ("A character is a human or looks fully human", "humanoid"),
781
- # Gender — pick all that apply across characters
782
- ("A male character is shown", "male"),
783
- ("A female character is shown", "female"),
784
- ("A character's gender cannot be determined from the description", "ambiguous_gender"),
785
- ("An intersex or hermaphrodite character is shown", "intersex"),
786
- ]
787
-
788
- STRUCTURAL_SYSTEM_TEMPLATE = """You classify image descriptions. You will read a description of an image, then select which numbered statements are true about it.
 
 
 
 
 
 
 
 
 
 
 
789
 
790
  IMPORTANT RULES:
791
  1. ONLY select a statement if the description directly says it or makes it very obvious.
792
- 2. Do NOT guess or assume anything the description does not say.
793
- 3. Select exactly ONE statement from the character count group (statements about how many characters there are).
794
- 4. Select ALL statements that apply from the body type and gender groups.
795
- 5. If the description does not mention gender at all, select the "gender cannot be determined" statement.
796
 
797
- Return JSON matching this exact format — nothing else:
798
  {{"selections": [{{"i": 1}}, {{"i": 5}}]}}
799
 
800
- where each "i" is a statement number from 1 to {N}.
801
-
802
  EXAMPLE:
803
- Description: "A muscular male wolf standing in a forest, giving a thumbs up"
804
- Statements: 1. No characters 2. Exactly one character 3. Exactly two 4. Exactly three 5. Four or more 6. Normal animal on all fours 7. Animal with human-like body 8. Human 9. Male shown 10. Female shown 11. Gender unknown 12. Intersex shown
805
- Correct answer: {{"selections": [{{"i": 2}}, {{"i": 7}}, {{"i": 9}}]}}
806
- Reasoning: One character (2), wolf standing upright with hands giving thumbs up = animal with human body (7), described as male (9)."""
807
 
808
  STRUCTURAL_USER_TEMPLATE = """Read this image description and select which statements are true.
809
 
@@ -851,29 +989,39 @@ def _build_structural_response_format() -> Dict[str, Any]:
851
  }
852
 
853
 
 
 
 
 
 
 
 
 
 
 
 
854
  def llm_infer_structural_tags(
855
  query_text: str,
856
  log=None,
857
  *,
858
  temperature: float = 0.0,
859
- max_tokens: int = 256,
860
  retries: int = 2,
861
  ) -> List[str]:
862
- """Infer structural tags (solo/duo/male/female/anthro/biped/…) via LLM.
863
 
864
- Instead of retrieving these from a candidate list, we ask the LLM to agree
865
- with natural-language statements about the image. This handles tags that
866
- are almost never stated in captions but are visually/structurally obvious.
867
 
868
- Returns a list of e621 tag strings (e.g. ["solo", "anthro", "male", "biped"]).
869
  """
870
  if log:
871
- log("Stage3s (structural): inferring structural tags via statement agreement")
872
 
873
- statements = _STRUCTURAL_STATEMENTS
874
- lines = [f"{j}. {stmt}" for j, (stmt, _tag) in enumerate(statements, 1)]
875
- statement_lines = "\n".join(lines)
876
- N = len(statements)
877
 
878
  response_format = _build_structural_response_format()
879
  llm = _get_llm(temperature=temperature, max_tokens=max_tokens,
@@ -892,7 +1040,8 @@ def llm_infer_structural_tags(
892
  chain = prompt | llm | parser
893
 
894
  if log:
895
- log(f"Stage3s: model={model_name} statements={N}")
 
896
 
897
  for att in range(retries + 1):
898
  try:
@@ -907,12 +1056,12 @@ def llm_infer_structural_tags(
907
 
908
  sels = parsed.get("selections", []) if isinstance(parsed, dict) else []
909
  chosen_tags: List[str] = []
910
- seen = set()
911
  for item in sels:
912
  idx = item.get("i") if isinstance(item, dict) else None
913
  if not isinstance(idx, int) or idx < 1 or idx > N:
914
  continue
915
- tag = statements[idx - 1][1]
916
  if tag not in seen:
917
  chosen_tags.append(tag)
918
  seen.add(tag)
 
12
  import os
13
  import re
14
  from dataclasses import dataclass
15
+ from pathlib import Path
16
  from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal
17
 
18
  from langchain_openai import ChatOpenAI
 
764
 
765
 
766
  # ---------------------------------------------------------------------------
767
+ # Stage 3s: Structural tag inference (solo/duo/male/female/anthro/… )
768
  # ---------------------------------------------------------------------------
769
+ # Group-based approach: tags are organized into semantic groups loaded from
770
+ # tag_groups.json / tag_wiki_defs.json where possible, with curated fallback
771
+ # definitions for tags whose wiki entries are only thumbnail references.
772
+ #
773
+ # Each group specifies a constraint mode:
774
+ # "exclusive" = pick exactly one (e.g. character count)
775
+ # "multi" = pick all that apply (e.g. body type, gender)
776
+
777
+ import json as _json
778
+
779
@dataclass
class StructuralGroup:
    """One category of structural tags to probe.

    Raises:
        ValueError: if ``constraint`` is not ``"exclusive"`` or ``"multi"``.
    """
    # Group identifier, e.g. "character_count".
    name: str
    # "exclusive" (pick exactly one) or "multi" (pick all that apply).
    constraint: str
    # (tag, definition) pairs, in the order they are presented.
    tags: List[Tuple[str, str]]

    def __post_init__(self) -> None:
        # The prompt builder treats every value other than "exclusive" as
        # "pick ALL that apply", so a misspelled constraint would silently
        # relax an exclusive group. Fail loudly at construction instead.
        if self.constraint not in ("exclusive", "multi"):
            raise ValueError(
                f"StructuralGroup {self.name!r}: constraint must be "
                f"'exclusive' or 'multi', got {self.constraint!r}"
            )
785
+
786
def _load_structural_groups() -> List[StructuralGroup]:
    """Build the structural tag groups used for statement-agreement inference.

    Group membership and ordering are curated here in code. For each tag the
    definition text comes from ``data/tag_wiki_defs.json`` when that file
    holds a usable text entry for the tag; otherwise the curated fallback
    sentence is used. The wiki file is optional: if it is missing, unreadable,
    or not a JSON object, every tag gets its curated fallback.

    Returns:
        The groups in presentation order (character count, body type, gender,
        clothing state, visual elements).
    """
    import warnings

    data_dir = Path(__file__).resolve().parents[2] / "data"

    # Load wiki definitions (the file may not exist yet).
    wiki_defs: Dict[str, Any] = {}
    wiki_path = data_dir / "tag_wiki_defs.json"
    if wiki_path.is_file():
        try:
            with wiki_path.open("r", encoding="utf-8") as f:
                loaded = _json.load(f)
        except (OSError, ValueError) as exc:
            # JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            # Curated fallbacks exist for every tag, so degrade instead of
            # taking structural inference down with a corrupt data file.
            warnings.warn(f"Ignoring unreadable {wiki_path}: {exc}")
        else:
            if isinstance(loaded, dict):
                wiki_defs = loaded
            else:
                warnings.warn(f"Ignoring {wiki_path}: expected a JSON object")

    def _def(tag: str, fallback: str) -> str:
        """Return the wiki definition if it is real text, else ``fallback``."""
        raw = wiki_defs.get(tag)
        if not isinstance(raw, str):
            return fallback
        # Collapse all whitespace so the definition fits on one prompt line
        # and the checks below are not defeated by leading spaces/newlines.
        text = " ".join(raw.split())
        # Skip empty, too-short, and thumbnail-only definitions.
        if len(text) < 15 or text.startswith("thumb ") or re.match(r"thumb\s*#\d+", text):
            return fallback
        return text[:200]  # cap length for prompt

    # NOTE: group order and tag order determine the statement numbers shown to
    # the LLM. The worked example in STRUCTURAL_SYSTEM_TEMPLATE hard-codes
    # statements 2, 6, 10 and 14 (solo, anthro, male, clothed); keep them in
    # sync when reordering or inserting tags.
    return [
        # ── Group A: Character Count (exclusive) ──
        StructuralGroup(
            name="character_count",
            constraint="exclusive",
            tags=[
                ("zero_pictured", _def("zero_pictured",
                    "No characters or living beings appear in the image")),
                ("solo", _def("solo",
                    "Exactly one character appears in the image")),
                ("duo", _def("duo",
                    "Exactly two characters appear in the image")),
                ("trio", _def("trio",
                    "Exactly three characters appear in the image")),
                ("group", _def("group",
                    "Four or more characters appear in the image")),
            ],
        ),
        # ── Group B: Body Type (multi — per character) ──
        # Key distinction the LLM must learn:
        #   anthro   = ANIMAL with human body shape (upright, hands)
        #   humanoid = HUMAN or near-human (elf, dwarf) with NO animal features
        #   feral    = normal animal shape, on all fours
        StructuralGroup(
            name="body_type",
            constraint="multi",
            tags=[
                ("anthro", _def("anthro",
                    "An animal character with a human-like body: walks upright on two legs, "
                    "has arms and hands. Examples: a wolf-person, a fox standing up. "
                    "Still has animal features like fur, tail, muzzle")),
                ("feral", _def("feral",
                    "A regular animal in its natural body shape. Walks on all fours (or "
                    "flies/swims naturally). NOT standing upright, NOT humanized")),
                ("humanoid", _def("humanoid",
                    "A human or human-like character with NO animal features. Includes "
                    "humans, elves, dwarves, and fantasy races that look human. "
                    "Does NOT include animal-people — those are anthro")),
                ("taur", _def("taur",
                    "A centaur-like body: human or anthro upper body attached to a "
                    "four-legged animal lower body")),
            ],
        ),
        # ── Group C: Gender (multi — per character) ──
        StructuralGroup(
            name="gender",
            constraint="multi",
            tags=[
                ("male", _def("male",
                    "A character described as male, a boy, or with he/him pronouns")),
                ("female", _def("female",
                    "A character described as female, a girl, or with she/her pronouns")),
                ("ambiguous_gender", _def("ambiguous_gender",
                    "A character whose gender is not stated or cannot be determined")),
                ("intersex", _def("intersex",
                    "A character explicitly described as intersex or hermaphrodite")),
            ],
        ),
        # ── Group D: Clothing State (multi) ──
        StructuralGroup(
            name="clothing_state",
            constraint="multi",
            tags=[
                ("clothed", _def("clothed",
                    "A character is wearing clothes on both upper and lower body")),
                ("nude", _def("nude",
                    "A character is wearing no clothes at all")),
                ("topless", _def("topless",
                    "A character's upper body is uncovered but lower body has clothing")),
                ("bottomless", _def("bottomless",
                    "A character wears clothing on upper body but lower body is uncovered")),
            ],
        ),
        # ── Group E: Common Visual Elements (multi) ──
        StructuralGroup(
            name="visual_elements",
            constraint="multi",
            tags=[
                ("looking_at_viewer", _def("looking_at_viewer",
                    "A character is looking directly at the camera or viewer")),
                ("text", _def("text",
                    "The image contains visible writing, words, or lettering")),
            ],
        ),
    ]
900
+
901
+
902
def _build_structural_prompt(groups: List[StructuralGroup]) -> Tuple[str, List[Tuple[str, str]]]:
    """Build the numbered statement list shown to the LLM.

    Each group contributes a header line naming the group and its constraint
    ("pick EXACTLY ONE" for ``"exclusive"``, "pick ALL that apply" for any
    other value), followed by one numbered line per tag definition and a
    blank separator line. Numbering is 1-based and runs continuously across
    groups.

    Args:
        groups: Structural groups, in the order they should be presented.

    Returns:
        ``(formatted_text, flat)`` where ``flat`` is the list of
        ``(tag, definition)`` pairs such that ``flat[i - 1]`` corresponds to
        statement number ``i`` in ``formatted_text``.
    """
    lines: List[str] = []
    flat: List[Tuple[str, str]] = []

    for g in groups:
        constraint_label = "pick EXACTLY ONE" if g.constraint == "exclusive" else "pick ALL that apply"
        lines.append(f"--- {g.name.replace('_', ' ').upper()} ({constraint_label}) ---")
        for tag, defn in g.tags:
            # Definitions may come from wiki text containing newlines or runs
            # of spaces; collapse them so every statement occupies exactly one
            # line and the numbers the LLM answers with stay unambiguous.
            statement = " ".join(defn.split())
            flat.append((tag, statement))
            lines.append(f"{len(flat)}. {statement}")
        lines.append("")  # blank line between groups

    return "\n".join(lines), flat
923
+
924
+
925
+ STRUCTURAL_SYSTEM_TEMPLATE = """You classify image descriptions by selecting true statements from a numbered list.
926
+
927
+ The statements are organized into GROUPS. Each group header tells you how many to pick:
928
+ - "pick EXACTLY ONE" = choose the single best match in that group
929
+ - "pick ALL that apply" = choose every statement that is true
930
 
931
  IMPORTANT RULES:
932
  1. ONLY select a statement if the description directly says it or makes it very obvious.
933
+ 2. Do NOT guess or assume things the description does not mention.
934
+ 3. For body type: "anthro" means an ANIMAL with a human-shaped body (walks upright, has hands, but still has fur/tail/muzzle). "humanoid" means HUMAN or human-like with NO animal features. A wolf standing on two legs = anthro, NOT humanoid.
935
+ 4. If the description never mentions gender, pick "gender cannot be determined".
936
+ 5. If clothing is not mentioned, do NOT pick any clothing statement.
937
 
938
+ Return JSON ONLY:
939
  {{"selections": [{{"i": 1}}, {{"i": 5}}]}}
940
 
 
 
941
  EXAMPLE:
942
+ Description: "A muscular male wolf standing in a forest, wearing jeans, giving a thumbs up"
943
+ Answer: {{"selections": [{{"i": 2}}, {{"i": 6}}, {{"i": 10}}, {{"i": 14}}]}}
944
+ Why: One character = solo (2). Wolf standing upright with hands = anthro (6), NOT humanoid because it is a wolf. Male (10). Wearing jeans = clothed (14)."""
 
945
 
946
  STRUCTURAL_USER_TEMPLATE = """Read this image description and select which statements are true.
947
 
 
989
  }
990
 
991
 
992
# Process-wide memo for the structural groups: filled on first use so the
# group definitions are built only once per process.
_cached_structural_groups: Optional[List[StructuralGroup]] = None


def _get_structural_groups() -> List[StructuralGroup]:
    """Return the structural groups, building them on the first call only."""
    global _cached_structural_groups
    cached = _cached_structural_groups
    if cached is not None:
        return cached
    cached = _load_structural_groups()
    _cached_structural_groups = cached
    return cached
1001
+
1002
+
1003
  def llm_infer_structural_tags(
1004
  query_text: str,
1005
  log=None,
1006
  *,
1007
  temperature: float = 0.0,
1008
+ max_tokens: int = 512,
1009
  retries: int = 2,
1010
  ) -> List[str]:
1011
+ """Infer structural tags via LLM using group-based statement agreement.
1012
 
1013
+ Probes multiple semantic groups (character count, body type, gender,
1014
+ clothing state, visual elements) with definitions loaded from wiki data
1015
+ where available.
1016
 
1017
+ Returns a list of e621 tag strings (e.g. ["solo", "anthro", "male", "clothed"]).
1018
  """
1019
  if log:
1020
+ log("Stage3s (structural): inferring structural tags via group-based statement agreement")
1021
 
1022
+ groups = _get_structural_groups()
1023
+ statement_lines, flat_tags = _build_structural_prompt(groups)
1024
+ N = len(flat_tags)
 
1025
 
1026
  response_format = _build_structural_response_format()
1027
  llm = _get_llm(temperature=temperature, max_tokens=max_tokens,
 
1040
  chain = prompt | llm | parser
1041
 
1042
  if log:
1043
+ group_summary = ", ".join(f"{g.name}({len(g.tags)})" for g in groups)
1044
+ log(f"Stage3s: model={model_name} groups=[{group_summary}] total_statements={N}")
1045
 
1046
  for att in range(retries + 1):
1047
  try:
 
1056
 
1057
  sels = parsed.get("selections", []) if isinstance(parsed, dict) else []
1058
  chosen_tags: List[str] = []
1059
+ seen: Set[str] = set()
1060
  for item in sels:
1061
  idx = item.get("i") if isinstance(item, dict) else None
1062
  if not isinstance(idx, int) or idx < 1 or idx > N:
1063
  continue
1064
+ tag = flat_tags[idx - 1][0]
1065
  if tag not in seen:
1066
  chosen_tags.append(tag)
1067
  seen.add(tag)
scripts/analyze_compact_eval.py CHANGED
@@ -53,7 +53,18 @@ _TAXONOMY = frozenset({"mammal","canid","canine","canis","felid","feline","felis
53
  _BODY_PLAN = frozenset({"anthro","feral","biped","quadruped","taur","humanoid","semi-anthro","animatronic","robot","machine","plushie","kemono"})
54
  _POSE = frozenset({"solo","duo","group","trio","standing","sitting","lying","running","walking","flying","swimming","crouching","kneeling","jumping","looking_at_viewer","looking_away","looking_back","looking_up","looking_down","looking_aside","front_view","side_view","back_view","three-quarter_view","from_above","from_below","close-up","portrait","full-length_portrait","hand_on_hip","arms_crossed","all_fours","on_back","on_side","crossed_arms"})
55
  _COUNT_RE = re.compile(r"^\d+_(fingers|toes|horns|arms|legs|eyes|ears|wings|tails)")
56
- _STRUCTURAL = frozenset({"solo","duo","trio","group","zero_pictured","anthro","feral","humanoid","biped","quadruped","male","female","ambiguous_gender","intersex"})
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  def categorize(tag, tag_type):
59
  tid = tag_type.get(tag, -1)
 
53
  _BODY_PLAN = frozenset({"anthro","feral","biped","quadruped","taur","humanoid","semi-anthro","animatronic","robot","machine","plushie","kemono"})
54
  _POSE = frozenset({"solo","duo","group","trio","standing","sitting","lying","running","walking","flying","swimming","crouching","kneeling","jumping","looking_at_viewer","looking_away","looking_back","looking_up","looking_down","looking_aside","front_view","side_view","back_view","three-quarter_view","from_above","from_below","close-up","portrait","full-length_portrait","hand_on_hip","arms_crossed","all_fours","on_back","on_side","crossed_arms"})
55
  _COUNT_RE = re.compile(r"^\d+_(fingers|toes|horns|arms|legs|eyes|ears|wings|tails)")
56
+ _STRUCTURAL = frozenset({
57
+ # Character count
58
+ "solo","duo","trio","group","zero_pictured",
59
+ # Body type
60
+ "anthro","feral","humanoid","taur",
61
+ # Gender
62
+ "male","female","ambiguous_gender","intersex",
63
+ # Clothing state
64
+ "clothed","nude","topless","bottomless",
65
+ # Visual elements
66
+ "looking_at_viewer","text",
67
+ })
68
 
69
  def categorize(tag, tag_type):
70
  tid = tag_type.get(tag, -1)
scripts/extract_wiki_data.py CHANGED
@@ -24,16 +24,27 @@ def _extract_tag_links(body: str) -> List[str]:
24
  - * [[tagname|display]] — list items
25
  """
26
  tags = []
 
 
27
  # Anchor links: [[#tag_name|display_text]]
28
  for m in re.finditer(r'\[\[#([a-z0-9_]+)\|', body):
29
- tags.append(m.group(1))
 
 
30
  # If no anchor links found, try regular wiki links in list items
31
  if not tags:
32
  for m in re.finditer(r'\*\s*\[\[([a-z0-9_()]+?)(?:\||\]\])', body):
33
  tag = m.group(1)
34
- if not tag.startswith('tag_group:') and not tag.startswith('tag '):
35
  tags.append(tag)
36
- return tags
 
 
 
 
 
 
 
37
 
38
 
39
  def _first_sentence(body: str) -> str:
@@ -54,6 +65,15 @@ def _first_sentence(body: str) -> str:
54
  continue
55
  if len(line) < 10:
56
  continue
 
 
 
 
 
 
 
 
 
57
  # Truncate at first period if it's a real sentence
58
  period = line.find('. ')
59
  if period > 20:
 
24
  - * [[tagname|display]] — list items
25
  """
26
  tags = []
27
+ # Navigation/heading anchors to skip
28
+ _SKIP = {"top", "see_also", "related", "back", "contents", "toc"}
29
  # Anchor links: [[#tag_name|display_text]]
30
  for m in re.finditer(r'\[\[#([a-z0-9_]+)\|', body):
31
+ tag = m.group(1)
32
+ if tag not in _SKIP:
33
+ tags.append(tag)
34
  # If no anchor links found, try regular wiki links in list items
35
  if not tags:
36
  for m in re.finditer(r'\*\s*\[\[([a-z0-9_()]+?)(?:\||\]\])', body):
37
  tag = m.group(1)
38
+ if tag not in _SKIP and not tag.startswith('tag_group:') and not tag.startswith('tag '):
39
  tags.append(tag)
40
+ # Deduplicate while preserving order
41
+ seen = set()
42
+ deduped = []
43
+ for t in tags:
44
+ if t not in seen:
45
+ seen.add(t)
46
+ deduped.append(t)
47
+ return deduped
48
 
49
 
50
  def _first_sentence(body: str) -> str:
 
65
  continue
66
  if len(line) < 10:
67
  continue
68
+ # Skip lines that are just thumbnail references (e.g. "thumb #12345 thumb #67890")
69
+ if re.fullmatch(r'(thumb\s*#\d+\s*)+', line):
70
+ continue
71
+ # Skip lines that are mostly thumbnail references with little text
72
+ thumb_stripped = re.sub(r'thumb\s*#\d+', '', line).strip()
73
+ if len(thumb_stripped) < 10:
74
+ continue
75
+ # Use the thumb-stripped version for the definition
76
+ line = thumb_stripped
77
  # Truncate at first period if it's a real sentence
78
  period = line.find('. ')
79
  if period > 20: