#!/usr/bin/env python3
"""
Diagnostic script to test structural inference for clothing state tags.
Tests with hand-crafted captions that explicitly mention clothing to identify
why the LLM is systematically failing to infer clothing state tags.
"""
import sys
import os
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from psq_rag.llm.select import llm_infer_structural_tags, _get_structural_groups, _build_structural_prompt
# Hand-crafted captions that state each character's clothing situation
# outright, paired with the tags structural inference is expected to return.
# The clothing-state subset of "expected" is what the diagnostic scores.
TEST_CASES = [
    dict(
        name="Explicit clothed - formal wear",
        caption="A male wolf wearing a black suit, white shirt, and red tie standing in an office.",
        expected=["solo", "anthro", "male", "clothed"],
    ),
    dict(
        name="Explicit clothed - casual wear",
        caption="An anthropomorphic fox in blue jeans and a t-shirt walking down a street.",
        expected=["solo", "anthro", "clothed"],
    ),
    dict(
        name="Explicit nude",
        caption="A naked female cat sitting on a beach, no clothing visible.",
        expected=["solo", "anthro", "female", "nude"],
    ),
    dict(
        name="Explicit topless",
        caption="A shirtless male dragon wearing pants, showing his muscular chest.",
        expected=["solo", "anthro", "male", "topless"],
    ),
    dict(
        name="Explicit bottomless",
        caption="A female rabbit wearing only a hoodie on her upper body, with her lower half uncovered.",
        expected=["solo", "anthro", "female", "bottomless"],
    ),
    dict(
        name="Multiple characters with clothing",
        caption="Two male dogs wearing police uniforms standing side by side.",
        expected=["duo", "anthro", "male", "clothed"],
    ),
    dict(
        name="Clothing mentioned in middle of description",
        # Clothing sits in the middle sentence so the model has to read past
        # the opening description to find it.
        caption=(
            "A muscular male wolf with red fur stands in a forest. "
            "He wears a black leather jacket and torn jeans. "
            "His eyes glow blue in the darkness."
        ),
        expected=["solo", "anthro", "male", "clothed"],
    ),
]
def print_structural_prompt():
    """Print the structural-inference statements and the tag index table.

    The statements and the ``(tag, definition)`` pairs come from
    ``_build_structural_prompt(_get_structural_groups())``. Tags are listed
    with a 1-based index. Each definition is cut to 60 characters, and "..."
    is appended only when text was actually removed.
    """
    groups = _get_structural_groups()
    statement_lines, flat_tags = _build_structural_prompt(groups)
    rule = "=" * 80

    print(rule)
    print("STRUCTURAL INFERENCE STATEMENTS")
    print(rule)
    print(statement_lines)

    print("\n" + rule)
    print("TAG MAPPING (1-based index)")
    print(rule)
    for i, (tag, defn) in enumerate(flat_tags, 1):
        # NOTE(review): defensive - render a missing (None) definition as
        # empty instead of crashing on slicing; confirm whether it can occur.
        text = "" if defn is None else str(defn)
        # Only mark the definition as truncated when it really was.
        shown = text if len(text) <= 60 else text[:60] + "..."
        print(f"{i:2d}. {tag:20s} | {shown}")
    print(rule + "\n")
# Tags describing a character's clothing state; the diagnostic scores only
# this subset of the model's selected tags.
_CLOTHING_TAGS = frozenset({"clothed", "nude", "topless", "bottomless"})


def _print_inference_log(msg):
    """Forward a log message from ``llm_infer_structural_tags`` to stdout."""
    print(f" [LOG] {msg}")


def _shorten(text, width=60):
    """Return *text* cut to *width* characters, adding "..." only if cut."""
    return text if len(text) <= width else text[:width] + "..."


def _run_single_test(index, total, test_case):
    """Run one test case through structural inference and print its report.

    Args:
        index: 1-based position of the case, used only for display.
        total: number of cases in the run, used only for display.
        test_case: mapping with "name", "caption" and "expected" keys.

    Returns:
        dict with the case's name, caption, expected and selected tags,
        whether the clothing-state subset matched exactly, and sorted lists
        of missed and extra tags.
    """
    name = test_case["name"]
    caption = test_case["caption"]
    expected = test_case["expected"]

    print(f"\n{'-' * 80}")
    print(f"TEST {index}/{total}: {name}")
    print(f"{'-' * 80}")
    print(f"Caption: {caption}")
    print(f"Expected tags: {expected}")
    print("\nCalling LLM...", flush=True)

    selected = llm_infer_structural_tags(
        caption,
        log=_print_inference_log,
        temperature=0.0,
        max_tokens=512,
    )
    # NOTE(review): defensive - treat a None result as "nothing selected"
    # instead of crashing on set(None); confirm whether this can happen.
    if selected is None:
        selected = []
    print(f"\nSelected tags: {selected}")

    expected_set = set(expected)
    selected_set = set(selected)
    expected_clothing = expected_set & _CLOTHING_TAGS
    selected_clothing = selected_set & _CLOTHING_TAGS
    missed = expected_set - selected_set
    extra = selected_set - expected_set
    correct = expected_set & selected_set
    # A case passes only if the clothing-state subset matches exactly;
    # mismatches in the other tags are reported but not scored.
    clothing_correct = expected_clothing == selected_clothing

    print(f"\n[OK] Correct: {sorted(correct)}")
    if missed:
        print(f"[X] Missed: {sorted(missed)}")
    if extra:
        print(f"[!] Extra: {sorted(extra)}")
    print(f"\nClothing state inference: {'[OK] PASS' if clothing_correct else '[X] FAIL'}")
    if expected_clothing:
        print(f" Expected: {sorted(expected_clothing)}")
        print(f" Selected: {sorted(selected_clothing) if selected_clothing else '(none)'}")

    return {
        "name": name,
        "caption": caption,
        "expected": expected,
        "selected": selected,
        "clothing_correct": clothing_correct,
        # Sorted so stored results are reproducible: string-set iteration
        # order changes between runs under hash randomization.
        "missed": sorted(missed),
        "extra": sorted(extra),
    }


def _print_summary(results):
    """Print pass/fail counts and the clothing details of each failing case.

    Returns:
        (total_tests, clothing_pass, clothing_fail) counts.
    """
    print("\n\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)

    total_tests = len(results)
    clothing_pass = sum(1 for r in results if r["clothing_correct"])
    clothing_fail = total_tests - clothing_pass
    # Guard the percentages: an empty TEST_CASES list would otherwise divide by zero.
    pass_pct = 100 * clothing_pass / total_tests if total_tests else 0.0
    fail_pct = 100 * clothing_fail / total_tests if total_tests else 0.0

    print(f"\nTotal tests: {total_tests}")
    print("Clothing state inference:")
    print(f" [OK] Pass: {clothing_pass}/{total_tests} ({pass_pct:.0f}%)")
    print(f" [X] Fail: {clothing_fail}/{total_tests} ({fail_pct:.0f}%)")

    if clothing_fail > 0:
        print(f"\n{'-' * 80}")
        print("FAILURES:")
        print(f"{'-' * 80}")
        for r in results:
            if r["clothing_correct"]:
                continue
            exp_clothing = set(r["expected"]) & _CLOTHING_TAGS
            sel_clothing = set(r["selected"]) & _CLOTHING_TAGS
            print(f"\n* {r['name']}")
            print(f" Caption: {_shorten(r['caption'])}")
            print(f" Expected: {sorted(exp_clothing)}")
            print(f" Selected: {sorted(sel_clothing) if sel_clothing else '(none)'}")

    return total_tests, clothing_pass, clothing_fail


def _print_diagnosis(total_tests, clothing_pass, clothing_fail):
    """Print an overall verdict based on the clothing-state pass count."""
    print(f"\n{'=' * 80}")
    print("DIAGNOSIS")
    print(f"{'=' * 80}")
    if total_tests == 0:
        # Without this guard an empty run would be reported as "all passed".
        print("\n[!] No tests were run; nothing to diagnose.")
    elif clothing_pass == total_tests:
        print("\n[OK] All tests passed! Clothing inference is working correctly.")
    elif clothing_pass == 0:
        print("\n[X] ALL tests failed! The LLM is completely ignoring the clothing state group.")
        print("\nPossible causes:")
        print("1. Prompt design issue - clothing group not salient enough")
        print("2. Model capability issue - Llama 3.1 8B cannot handle this task")
        print("3. Response parsing issue - LLM is selecting but parser is missing it")
    else:
        print(f"\n[!] Partial failure! {clothing_fail}/{total_tests} tests failed.")
        print("\nThe LLM is sometimes inferring clothing state but inconsistently.")


def run_diagnostic():
    """Run every case in TEST_CASES through structural inference and report.

    Prints the structural prompt, then a report for each test case, a summary
    of clothing-state accuracy and a short diagnosis.

    Returns:
        list of per-test result dicts (see ``_run_single_test``).
    """
    print_structural_prompt()
    print("\n" + "=" * 80)
    print("RUNNING DIAGNOSTIC TESTS")
    print("=" * 80 + "\n")

    total = len(TEST_CASES)
    results = []
    for i, test_case in enumerate(TEST_CASES, 1):
        results.append(_run_single_test(i, total, test_case))

    total_tests, clothing_pass, clothing_fail = _print_summary(results)
    _print_diagnosis(total_tests, clothing_pass, clothing_fail)
    return results
def _main():
    """Script entry point: run the diagnostic and hand back its results."""
    return run_diagnostic()


if __name__ == "__main__":
    # Bind at module level so the results remain inspectable (e.g. `python -i`).
    results = _main()