Prompt_Squirrel_RAG / scripts /diagnose_structural_clothing.py
Claude
Fix Windows encoding issues in diagnostic script
a69f12b
Raw
History Blame Contribute Delete
7.06 kB
#!/usr/bin/env python3
"""
Diagnostic script to test structural inference for clothing state tags.
Tests with hand-crafted captions that explicitly mention clothing to identify
why the LLM is systematically failing to infer clothing state tags.
"""
import sys
import os
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from psq_rag.llm.select import llm_infer_structural_tags, _get_structural_groups, _build_structural_prompt
# Hand-written captions that spell out the clothing situation, each paired
# with the structural tags the model is expected to choose for it.
# Every row is (name, caption, expected tags); the comprehension expands the
# rows into the {"name", "caption", "expected"} dicts the runner reads.
TEST_CASES = [
    {"name": title, "caption": text, "expected": tags}
    for title, text, tags in (
        (
            "Explicit clothed - formal wear",
            "A male wolf wearing a black suit, white shirt, and red tie standing in an office.",
            ["solo", "anthro", "male", "clothed"],
        ),
        (
            "Explicit clothed - casual wear",
            "An anthropomorphic fox in blue jeans and a t-shirt walking down a street.",
            ["solo", "anthro", "clothed"],
        ),
        (
            "Explicit nude",
            "A naked female cat sitting on a beach, no clothing visible.",
            ["solo", "anthro", "female", "nude"],
        ),
        (
            "Explicit topless",
            "A shirtless male dragon wearing pants, showing his muscular chest.",
            ["solo", "anthro", "male", "topless"],
        ),
        (
            "Explicit bottomless",
            "A female rabbit wearing only a hoodie on her upper body, with her lower half uncovered.",
            ["solo", "anthro", "female", "bottomless"],
        ),
        (
            "Multiple characters with clothing",
            "Two male dogs wearing police uniforms standing side by side.",
            ["duo", "anthro", "male", "clothed"],
        ),
        (
            "Clothing mentioned in middle of description",
            "A muscular male wolf with red fur stands in a forest. He wears a black leather jacket and torn jeans. His eyes glow blue in the darkness.",
            ["solo", "anthro", "male", "clothed"],
        ),
    )
]
def print_structural_prompt():
    """Print the structural-inference statements and the tag table.

    Fetches the structural groups with ``_get_structural_groups()``, turns
    them into prompt material with ``_build_structural_prompt()``, and prints
    two sections to stdout:

    1. the statement text returned by the prompt builder, unmodified;
    2. one line per ``(tag, definition)`` pair, numbered from 1.

    Each definition is previewed at up to 60 characters. The trailing
    ``...`` is added only when the definition was actually cut, so a short
    definition is not mistaken for a truncated one.

    Returns:
        None. Output goes to stdout only.
    """
    rule = "=" * 80
    preview_width = 60

    groups = _get_structural_groups()
    statement_lines, flat_tags = _build_structural_prompt(groups)

    print(rule)
    print("STRUCTURAL INFERENCE STATEMENTS")
    print(rule)
    print(statement_lines)

    print("\n" + rule)
    print("TAG MAPPING (1-based index)")
    print(rule)
    for i, (tag, defn) in enumerate(flat_tags, 1):
        # Mark truncation only when text was really dropped.
        if len(defn) > preview_width:
            preview = defn[:preview_width] + "..."
        else:
            preview = defn
        print(f"{i:2d}. {tag:20s} | {preview}")
    print(rule + "\n")
def run_diagnostic(test_cases=None):
    """Run the clothing-state diagnostic and print a report to stdout.

    Prints the structural prompt first, then sends each case's caption to
    ``llm_infer_structural_tags`` (``temperature=0.0``, ``max_tokens=512``)
    and compares the returned tags with the expected ones. A case passes the
    clothing check when the clothing-state tags (``clothed``, ``nude``,
    ``topless``, ``bottomless``) among the selected tags are exactly those
    among the expected tags. A summary, a list of failures and a short
    diagnosis are printed at the end.

    Args:
        test_cases: Optional iterable of dicts with ``"name"``,
            ``"caption"`` and ``"expected"`` keys. When omitted, the
            module-level ``TEST_CASES`` is used.

    Returns:
        A list with one dict per case, holding ``name``, ``caption``,
        ``expected``, ``selected`` (as returned by the LLM helper),
        ``clothing_correct`` (bool), and ``missed`` / ``extra`` as sorted
        lists of tags.
    """
    cases = list(TEST_CASES if test_cases is None else test_cases)
    # Built once: the same set drives the per-case check and the failure list.
    clothing_tags = {"clothed", "nude", "topless", "bottomless"}
    heavy_rule = "=" * 80
    light_rule = "-" * 80
    preview_width = 60

    def log_fn(msg):
        print(f" [LOG] {msg}")

    print_structural_prompt()
    print("\n" + heavy_rule)
    print("RUNNING DIAGNOSTIC TESTS")
    print(heavy_rule + "\n")

    results = []
    for i, test_case in enumerate(cases, 1):
        name = test_case["name"]
        caption = test_case["caption"]
        expected = test_case["expected"]

        print(f"\n{light_rule}")
        print(f"TEST {i}/{len(cases)}: {name}")
        print(light_rule)
        print(f"Caption: {caption}")
        print(f"Expected tags: {expected}")
        # Flush so the progress line is visible while the LLM call is pending.
        print("\nCalling LLM...", flush=True)

        selected = llm_infer_structural_tags(
            caption,
            log=log_fn,
            temperature=0.0,
            max_tokens=512,
        )
        print(f"\nSelected tags: {selected}")

        # Compare as sets: order and duplicates in the LLM output do not matter.
        expected_set = set(expected)
        selected_set = set(selected)
        expected_clothing = expected_set & clothing_tags
        selected_clothing = selected_set & clothing_tags

        # Sorted so the printed report and the returned records are
        # deterministic rather than following set iteration order.
        missed = sorted(expected_set - selected_set)
        extra = sorted(selected_set - expected_set)
        correct = sorted(expected_set & selected_set)
        clothing_correct = expected_clothing == selected_clothing

        print(f"\n[OK] Correct: {correct}")
        if missed:
            print(f"[X] Missed: {missed}")
        if extra:
            print(f"[!] Extra: {extra}")
        print(f"\nClothing state inference: {'[OK] PASS' if clothing_correct else '[X] FAIL'}")
        if expected_clothing:
            print(f" Expected: {sorted(expected_clothing)}")
            print(f" Selected: {sorted(selected_clothing) if selected_clothing else '(none)'}")

        results.append({
            "name": name,
            "caption": caption,
            "expected": expected,
            "selected": selected,
            "clothing_correct": clothing_correct,
            "missed": missed,
            "extra": extra,
        })

    # Summary
    print("\n\n" + heavy_rule)
    print("SUMMARY")
    print(heavy_rule)

    total_tests = len(results)
    if total_tests == 0:
        # Nothing ran: avoid dividing by zero below and avoid reporting a
        # vacuous "all tests passed".
        print("\nNo test cases were run.")
        return results

    clothing_pass = sum(1 for r in results if r["clothing_correct"])
    clothing_fail = total_tests - clothing_pass

    print(f"\nTotal tests: {total_tests}")
    print("Clothing state inference:")
    print(f" [OK] Pass: {clothing_pass}/{total_tests} ({100*clothing_pass/total_tests:.0f}%)")
    print(f" [X] Fail: {clothing_fail}/{total_tests} ({100*clothing_fail/total_tests:.0f}%)")

    if clothing_fail > 0:
        print(f"\n{light_rule}")
        print("FAILURES:")
        print(light_rule)
        for r in results:
            if r["clothing_correct"]:
                continue
            caption = r["caption"]
            # Add the ellipsis only when the caption was really shortened.
            if len(caption) > preview_width:
                caption = caption[:preview_width] + "..."
            exp_clothing = set(r["expected"]) & clothing_tags
            sel_clothing = set(r["selected"]) & clothing_tags
            print(f"\n* {r['name']}")
            print(f" Caption: {caption}")
            print(f" Expected: {sorted(exp_clothing)}")
            print(f" Selected: {sorted(sel_clothing) if sel_clothing else '(none)'}")

    # Overall assessment
    print(f"\n{heavy_rule}")
    print("DIAGNOSIS")
    print(heavy_rule)
    if clothing_pass == total_tests:
        print("\n[OK] All tests passed! Clothing inference is working correctly.")
    elif clothing_pass == 0:
        print("\n[X] ALL tests failed! The LLM is completely ignoring the clothing state group.")
        print("\nPossible causes:")
        print("1. Prompt design issue - clothing group not salient enough")
        print("2. Model capability issue - Llama 3.1 8B cannot handle this task")
        print("3. Response parsing issue - LLM is selecting but parser is missing it")
    else:
        print(f"\n[!] Partial failure! {clothing_fail}/{total_tests} tests failed.")
        print("\nThe LLM is sometimes inferring clothing state but inconsistently.")

    return results
if __name__ == "__main__":
    # Script entry point: run every diagnostic case and keep the per-case
    # records bound to a module-level name.
    results = run_diagnostic()