#!/usr/bin/env python3
"""
Diagnostic script to test structural inference for clothing state tags.

Tests with hand-crafted captions that explicitly mention clothing to identify
why the LLM is systematically failing to infer clothing state tags.
"""

import sys
import os
from pathlib import Path

# Add project root to path so that ``psq_rag`` resolves when this script is
# run directly from its own directory.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from psq_rag.llm.select import (
    llm_infer_structural_tags,
    _get_structural_groups,
    _build_structural_prompt,
)

# The clothing-state tags this diagnostic scores; every other tag is reported
# but does not affect the pass/fail verdict of a test case.
CLOTHING_TAGS = frozenset({'clothed', 'nude', 'topless', 'bottomless'})

# Test cases with explicit clothing mentions
TEST_CASES = [
    {
        "name": "Explicit clothed - formal wear",
        "caption": "A male wolf wearing a black suit, white shirt, and red tie standing in an office.",
        "expected": ["solo", "anthro", "male", "clothed"],
    },
    {
        "name": "Explicit clothed - casual wear",
        "caption": "An anthropomorphic fox in blue jeans and a t-shirt walking down a street.",
        "expected": ["solo", "anthro", "clothed"],
    },
    {
        "name": "Explicit nude",
        "caption": "A naked female cat sitting on a beach, no clothing visible.",
        "expected": ["solo", "anthro", "female", "nude"],
    },
    {
        "name": "Explicit topless",
        "caption": "A shirtless male dragon wearing pants, showing his muscular chest.",
        "expected": ["solo", "anthro", "male", "topless"],
    },
    {
        "name": "Explicit bottomless",
        "caption": "A female rabbit wearing only a hoodie on her upper body, with her lower half uncovered.",
        "expected": ["solo", "anthro", "female", "bottomless"],
    },
    {
        "name": "Multiple characters with clothing",
        "caption": "Two male dogs wearing police uniforms standing side by side.",
        "expected": ["duo", "anthro", "male", "clothed"],
    },
    {
        "name": "Clothing mentioned in middle of description",
        "caption": "A muscular male wolf with red fur stands in a forest. He wears a black leather jacket and torn jeans. His eyes glow blue in the darkness.",
        "expected": ["solo", "anthro", "male", "clothed"],
    },
]


def print_structural_prompt():
    """Print the actual statements the LLM sees.

    Fetches the structural groups, builds the prompt from them, and prints
    the statement text followed by a 1-based listing of the flat tag table.
    """
    groups = _get_structural_groups()
    statement_lines, flat_tags = _build_structural_prompt(groups)

    print("=" * 80)
    print("STRUCTURAL INFERENCE STATEMENTS")
    print("=" * 80)
    print(statement_lines)

    print("\n" + "=" * 80)
    print("TAG MAPPING (1-based index)")
    print("=" * 80)
    # Each entry is unpacked as a (tag, defn) pair; defn is truncated to its
    # first 60 characters to keep the table on one line per tag.
    for i, (tag, defn) in enumerate(flat_tags, 1):
        print(f"{i:2d}. {tag:20s} | {defn[:60]}...")
    print("=" * 80 + "\n")


def _clothing_subset(tags):
    """Return the clothing-state tags contained in *tags* as a set."""
    return set(tags) & CLOTHING_TAGS


def _format_tags(tags):
    """Return *tags* sorted, or the placeholder '(none)' when it is empty."""
    return sorted(tags) if tags else '(none)'


def _log_line(msg):
    """Print a log message emitted by the structural inference call."""
    print(f" [LOG] {msg}")


def _run_case(index, test_case):
    """Run one test case through the LLM, print its analysis, return a record.

    Args:
        index: 1-based position of the case, used only for the header.
        test_case: dict with "name", "caption" and "expected" keys.

    Returns:
        A dict with the case's name, caption, expected and selected tags,
        the clothing verdict, and sorted lists of missed and extra tags.
    """
    name = test_case["name"]
    caption = test_case["caption"]
    expected = test_case["expected"]

    print(f"\n{'-' * 80}")
    print(f"TEST {index}/{len(TEST_CASES)}: {name}")
    print(f"{'-' * 80}")
    print(f"Caption: {caption}")
    print(f"Expected tags: {expected}")
    # Flush so the notice is visible before the (potentially slow) call.
    print("\nCalling LLM...", flush=True)

    # Call structural inference
    selected = llm_infer_structural_tags(
        caption,
        log=_log_line,
        temperature=0.0,
        max_tokens=512,
    )
    print(f"\nSelected tags: {selected}")

    # Analyze results
    expected_set = set(expected)
    selected_set = set(selected)
    expected_clothing = _clothing_subset(expected_set)
    selected_clothing = _clothing_subset(selected_set)

    missed = expected_set - selected_set
    extra = selected_set - expected_set
    correct = expected_set & selected_set
    # Pass only when the selected clothing tags match the expected ones
    # exactly; a spurious extra clothing tag counts as a failure too.
    clothing_correct = expected_clothing == selected_clothing

    print(f"\n[OK] Correct: {sorted(correct)}")
    if missed:
        print(f"[X] Missed: {sorted(missed)}")
    if extra:
        print(f"[!] Extra: {sorted(extra)}")

    print(f"\nClothing state inference: {'[OK] PASS' if clothing_correct else '[X] FAIL'}")
    if expected_clothing:
        print(f" Expected: {sorted(expected_clothing)}")
    print(f" Selected: {_format_tags(selected_clothing)}")

    return {
        "name": name,
        "caption": caption,
        "expected": expected,
        "selected": selected,
        "clothing_correct": clothing_correct,
        # Sorted so the returned records are deterministic across runs.
        "missed": sorted(missed),
        "extra": sorted(extra),
    }


def _print_summary(results):
    """Print pass/fail totals, the failing cases, and an overall diagnosis.

    Callers must pass a non-empty list; the percentages divide by its length.
    """
    print("\n\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)

    total_tests = len(results)
    failures = [r for r in results if not r["clothing_correct"]]
    clothing_fail = len(failures)
    clothing_pass = total_tests - clothing_fail

    print(f"\nTotal tests: {total_tests}")
    print("Clothing state inference:")
    print(f" [OK] Pass: {clothing_pass}/{total_tests} ({100*clothing_pass/total_tests:.0f}%)")
    print(f" [X] Fail: {clothing_fail}/{total_tests} ({100*clothing_fail/total_tests:.0f}%)")

    if failures:
        print(f"\n{'-' * 80}")
        print("FAILURES:")
        print(f"{'-' * 80}")
        for r in failures:
            print(f"\n* {r['name']}")
            print(f" Caption: {r['caption'][:60]}...")
            print(f" Expected: {sorted(_clothing_subset(r['expected']))}")
            print(f" Selected: {_format_tags(_clothing_subset(r['selected']))}")

    # Overall assessment
    print(f"\n{'=' * 80}")
    print("DIAGNOSIS")
    print(f"{'=' * 80}")

    if clothing_pass == total_tests:
        print("\n[OK] All tests passed! Clothing inference is working correctly.")
    elif clothing_pass == 0:
        print("\n[X] ALL tests failed! The LLM is completely ignoring the clothing state group.")
        print("\nPossible causes:")
        print("1. Prompt design issue - clothing group not salient enough")
        print("2. Model capability issue - Llama 3.1 8B cannot handle this task")
        print("3. Response parsing issue - LLM is selecting but parser is missing it")
    else:
        print(f"\n[!] Partial failure! {clothing_fail}/{total_tests} tests failed.")
        print("\nThe LLM is sometimes inferring clothing state but inconsistently.")


def run_diagnostic():
    """Run diagnostic tests on structural clothing inference.

    Prints the structural prompt, runs every entry of ``TEST_CASES`` through
    ``llm_infer_structural_tags``, and prints a summary and diagnosis.

    Returns:
        A list with one result dict per test case (see ``_run_case``); the
        list is empty when ``TEST_CASES`` is empty.
    """
    print_structural_prompt()

    print("\n" + "=" * 80)
    print("RUNNING DIAGNOSTIC TESTS")
    print("=" * 80 + "\n")

    results = [
        _run_case(index, test_case)
        for index, test_case in enumerate(TEST_CASES, 1)
    ]

    if not results:
        # Without this guard the percentage computation divides by zero and
        # the diagnosis would claim that all (zero) tests passed.
        print("\nNo test cases were run; nothing to summarize.")
        return results

    _print_summary(results)
    return results


if __name__ == "__main__":
    results = run_diagnostic()