File size: 7,060 Bytes
998779f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a69f12b
998779f
a69f12b
998779f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a69f12b
998779f
a69f12b
998779f
a69f12b
998779f
a69f12b
998779f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a69f12b
 
998779f
 
a69f12b
998779f
a69f12b
998779f
 
a69f12b
998779f
 
 
 
 
 
 
 
 
 
 
 
 
a69f12b
998779f
a69f12b
998779f
 
 
 
 
a69f12b
998779f
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python3
"""

Diagnostic script to test structural inference for clothing state tags.



Tests with hand-crafted captions that explicitly mention clothing to identify

why the LLM is systematically failing to infer clothing state tags.

"""

import sys
import os
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from psq_rag.llm.select import llm_infer_structural_tags, _get_structural_groups, _build_structural_prompt


# Hand-written captions that spell out the clothing situation, each paired with
# the structural tags the LLM should select for it.
# Row layout: (name, caption, expected_tags). The rows are expanded below into
# the dict shape the diagnostic loop reads: "name", "caption", "expected".
_TEST_CASE_ROWS = (
    (
        "Explicit clothed - formal wear",
        "A male wolf wearing a black suit, white shirt, and red tie standing in an office.",
        ["solo", "anthro", "male", "clothed"],
    ),
    (
        "Explicit clothed - casual wear",
        "An anthropomorphic fox in blue jeans and a t-shirt walking down a street.",
        ["solo", "anthro", "clothed"],
    ),
    (
        "Explicit nude",
        "A naked female cat sitting on a beach, no clothing visible.",
        ["solo", "anthro", "female", "nude"],
    ),
    (
        "Explicit topless",
        "A shirtless male dragon wearing pants, showing his muscular chest.",
        ["solo", "anthro", "male", "topless"],
    ),
    (
        "Explicit bottomless",
        "A female rabbit wearing only a hoodie on her upper body, with her lower half uncovered.",
        ["solo", "anthro", "female", "bottomless"],
    ),
    (
        "Multiple characters with clothing",
        "Two male dogs wearing police uniforms standing side by side.",
        ["duo", "anthro", "male", "clothed"],
    ),
    (
        "Clothing mentioned in middle of description",
        "A muscular male wolf with red fur stands in a forest. He wears a black leather jacket and torn jeans. His eyes glow blue in the darkness.",
        ["solo", "anthro", "male", "clothed"],
    ),
)

TEST_CASES = [
    {"name": case_name, "caption": case_caption, "expected": case_expected}
    for case_name, case_caption, case_expected in _TEST_CASE_ROWS
]


def print_structural_prompt():
    """Print the structural-inference statements and their 1-based tag mapping.

    Builds the prompt with ``_get_structural_groups`` and
    ``_build_structural_prompt`` (both imported from ``psq_rag.llm.select``).
    NOTE(review): assumed to be the same prompt ``llm_infer_structural_tags``
    sends to the LLM; confirm against that module.

    Each mapping row shows the tag and up to 60 characters of its definition.
    "..." is appended only when the definition was actually truncated.
    """
    groups = _get_structural_groups()
    statement_lines, flat_tags = _build_structural_prompt(groups)

    rule = "=" * 80
    print(rule)
    print("STRUCTURAL INFERENCE STATEMENTS")
    print(rule)
    print(statement_lines)
    print("\n" + rule)
    print("TAG MAPPING (1-based index)")
    print(rule)
    for i, (tag, defn) in enumerate(flat_tags, 1):
        # Tolerate a missing definition instead of crashing on slicing None.
        text = defn or ""
        # Mark truncation only when text was cut, so short definitions aren't
        # misleadingly shown as clipped.
        preview = text if len(text) <= 60 else f"{text[:60]}..."
        print(f"{i:2d}. {tag:20s} | {preview}")
    print(rule + "\n")


# Clothing-state tags scored by this diagnostic. Defined once so the per-test
# check and the failure report cannot drift apart.
_CLOTHING_TAGS = frozenset({"clothed", "nude", "topless", "bottomless"})


def _log_to_stdout(msg):
    """Echo a log message from structural inference to stdout, indented."""
    print(f"  [LOG] {msg}")


def _percent(count, total):
    """Return ``count`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
    return 100 * count / total if total else 0.0


def _preview(text, limit=60):
    """Return ``text`` cut to ``limit`` chars, adding "..." only if it was cut."""
    return text if len(text) <= limit else f"{text[:limit]}..."


def _run_single_test(index, total, test_case):
    """Run one caption through structural inference, print a report, return its record."""
    name = test_case["name"]
    caption = test_case["caption"]
    expected = test_case["expected"]

    print(f"\n{'-' * 80}")
    print(f"TEST {index}/{total}: {name}")
    print(f"{'-' * 80}")
    print(f"Caption: {caption}")
    print(f"Expected tags: {expected}")
    print("\nCalling LLM...", flush=True)

    selected = llm_infer_structural_tags(
        caption,
        log=_log_to_stdout,
        temperature=0.0,
        max_tokens=512,
    )

    print(f"\nSelected tags: {selected}")

    expected_set = set(expected)
    selected_set = set(selected)

    expected_clothing = expected_set & _CLOTHING_TAGS
    selected_clothing = selected_set & _CLOTHING_TAGS

    missed = expected_set - selected_set
    extra = selected_set - expected_set
    correct = expected_set & selected_set

    # Only the clothing-state subset decides pass/fail; other tag mismatches
    # are reported but not scored.
    clothing_correct = expected_clothing == selected_clothing

    print(f"\n[OK] Correct:  {sorted(correct)}")
    if missed:
        print(f"[X] Missed:   {sorted(missed)}")
    if extra:
        print(f"[!] Extra:    {sorted(extra)}")

    print(f"\nClothing state inference: {'[OK] PASS' if clothing_correct else '[X] FAIL'}")
    if expected_clothing:
        print(f"  Expected: {sorted(expected_clothing)}")
        print(f"  Selected: {sorted(selected_clothing) if selected_clothing else '(none)'}")

    return {
        "name": name,
        "caption": caption,
        "expected": expected,
        "selected": selected,
        "clothing_correct": clothing_correct,
        # Sorted so the record is deterministic (set iteration order is not).
        "missed": sorted(missed),
        "extra": sorted(extra),
    }


def _print_summary(results):
    """Print pass/fail totals and each clothing failure; return (passed, failed)."""
    total_tests = len(results)
    clothing_pass = sum(1 for r in results if r["clothing_correct"])
    clothing_fail = total_tests - clothing_pass

    print("\n\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)

    print(f"\nTotal tests: {total_tests}")
    print("Clothing state inference:")
    # _percent guards against ZeroDivisionError when no cases were run.
    print(f"  [OK] Pass: {clothing_pass}/{total_tests} ({_percent(clothing_pass, total_tests):.0f}%)")
    print(f"  [X] Fail: {clothing_fail}/{total_tests} ({_percent(clothing_fail, total_tests):.0f}%)")

    if clothing_fail > 0:
        print(f"\n{'-' * 80}")
        print("FAILURES:")
        print(f"{'-' * 80}")
        for r in results:
            if r["clothing_correct"]:
                continue
            exp_clothing = set(r["expected"]) & _CLOTHING_TAGS
            sel_clothing = set(r["selected"]) & _CLOTHING_TAGS
            print(f"\n* {r['name']}")
            print(f"  Caption: {_preview(r['caption'])}")
            print(f"  Expected: {sorted(exp_clothing)}")
            print(f"  Selected: {sorted(sel_clothing) if sel_clothing else '(none)'}")

    return clothing_pass, clothing_fail


def _print_diagnosis(clothing_pass, clothing_fail, total_tests):
    """Print an overall verdict based on how many clothing checks passed."""
    print(f"\n{'=' * 80}")
    print("DIAGNOSIS")
    print(f"{'=' * 80}")

    if total_tests == 0:
        # Without this branch an empty run would be reported as "All tests passed!".
        print("\n[!] No test cases were run.")
    elif clothing_pass == total_tests:
        print("\n[OK] All tests passed! Clothing inference is working correctly.")
    elif clothing_pass == 0:
        print("\n[X] ALL tests failed! The LLM is completely ignoring the clothing state group.")
        print("\nPossible causes:")
        print("1. Prompt design issue - clothing group not salient enough")
        print("2. Model capability issue - Llama 3.1 8B cannot handle this task")
        print("3. Response parsing issue - LLM is selecting but parser is missing it")
    else:
        print(f"\n[!] Partial failure! {clothing_fail}/{total_tests} tests failed.")
        print("\nThe LLM is sometimes inferring clothing state but inconsistently.")


def run_diagnostic(test_cases=None):
    """Run the clothing-state diagnostic and print per-test and overall reports.

    First prints the structural prompt. Then sends each caption through
    ``llm_infer_structural_tags`` (temperature 0.0, max_tokens 512) and
    compares the selected tags with the expected ones. Only the clothing-state
    tags are scored as pass/fail.

    Args:
        test_cases: Iterable of dicts with "name", "caption" and "expected"
            keys. Defaults to the module-level ``TEST_CASES``.

    Returns:
        A list with one dict per test case, with keys "name", "caption",
        "expected", "selected", "clothing_correct", "missed" and "extra".
        "missed" and "extra" are sorted lists.
    """
    # Materialize so len() works for any iterable passed in.
    cases = list(TEST_CASES if test_cases is None else test_cases)

    print_structural_prompt()

    print("\n" + "=" * 80)
    print("RUNNING DIAGNOSTIC TESTS")
    print("=" * 80 + "\n")

    results = []
    for i, test_case in enumerate(cases, 1):
        results.append(_run_single_test(i, len(cases), test_case))

    clothing_pass, clothing_fail = _print_summary(results)
    _print_diagnosis(clothing_pass, clothing_fail, len(results))

    return results


def main():
    """Script entry point: run the clothing diagnostic and return its results."""
    return run_diagnostic()


if __name__ == "__main__":
    results = main()