Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Diagnostic script to test structural inference for clothing state tags. | |
| Tests with hand-crafted captions that explicitly mention clothing to identify | |
| why the LLM is systematically failing to infer clothing state tags. | |
| """ | |
| import sys | |
| import os | |
| from pathlib import Path | |
| # Add project root to path | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| from psq_rag.llm.select import llm_infer_structural_tags, _get_structural_groups, _build_structural_prompt | |
# Hand-written captions that state the clothing situation outright, each paired
# with the structural tags a correct inference is expected to return. Every
# entry is a dict with the keys "name", "caption" and "expected" (in that order).
TEST_CASES = [
    {"name": case_name, "caption": case_caption, "expected": case_expected}
    for case_name, case_caption, case_expected in (
        (
            "Explicit clothed - formal wear",
            "A male wolf wearing a black suit, white shirt, and red tie "
            "standing in an office.",
            ["solo", "anthro", "male", "clothed"],
        ),
        (
            "Explicit clothed - casual wear",
            "An anthropomorphic fox in blue jeans and a t-shirt "
            "walking down a street.",
            ["solo", "anthro", "clothed"],
        ),
        (
            "Explicit nude",
            "A naked female cat sitting on a beach, no clothing visible.",
            ["solo", "anthro", "female", "nude"],
        ),
        (
            "Explicit topless",
            "A shirtless male dragon wearing pants, "
            "showing his muscular chest.",
            ["solo", "anthro", "male", "topless"],
        ),
        (
            "Explicit bottomless",
            "A female rabbit wearing only a hoodie on her upper body, "
            "with her lower half uncovered.",
            ["solo", "anthro", "female", "bottomless"],
        ),
        (
            "Multiple characters with clothing",
            "Two male dogs wearing police uniforms standing side by side.",
            ["duo", "anthro", "male", "clothed"],
        ),
        (
            "Clothing mentioned in middle of description",
            "A muscular male wolf with red fur stands in a forest. "
            "He wears a black leather jacket and torn jeans. "
            "His eyes glow blue in the darkness.",
            ["solo", "anthro", "male", "clothed"],
        ),
    )
]
def print_structural_prompt():
    """Print the structural-inference statements and tag mapping the LLM sees.

    Builds the prompt with ``_get_structural_groups()`` and
    ``_build_structural_prompt(groups)``, then prints:

    1. the statement text returned by the prompt builder, and
    2. a 1-based numbered table of ``(tag, definition)`` pairs, with each
       definition cut to 60 characters ("..." marks a definition that was
       actually shortened).
    """
    groups = _get_structural_groups()
    statement_lines, flat_tags = _build_structural_prompt(groups)

    print("=" * 80)
    print("STRUCTURAL INFERENCE STATEMENTS")
    print("=" * 80)
    # NOTE(review): the name suggests a sequence of lines; joining it prints
    # readable text instead of a list repr. A plain string is printed as-is.
    # Confirm the return type of _build_structural_prompt.
    if isinstance(statement_lines, (list, tuple)):
        print("\n".join(str(line) for line in statement_lines))
    else:
        print(statement_lines)

    print("\n" + "=" * 80)
    print("TAG MAPPING (1-based index)")
    print("=" * 80)
    for i, (tag, defn) in enumerate(flat_tags, 1):
        # Only append an ellipsis when text was really cut off, so short
        # definitions are not mislabelled as truncated.
        snippet = defn[:60] + "..." if len(defn) > 60 else defn
        print(f"{i:2d}. {tag:20s} | {snippet}")
    print("=" * 80 + "\n")
# Tags that make up the clothing-state group this script is diagnosing.
_CLOTHING_TAGS = frozenset({"clothed", "nude", "topless", "bottomless"})


def _truncate(text, width=60):
    """Return *text* cut to *width* characters, with "..." only if it was cut."""
    return text[:width] + "..." if len(text) > width else text


def _log_fn(msg):
    """Echo a log message emitted by the structural-inference call."""
    print(f" [LOG] {msg}")


def _run_case(index, total, test_case):
    """Run one test case through structural inference and print its report.

    Args:
        index: 1-based position of the case (used only for display).
        total: Number of cases in the run (used only for display).
        test_case: Dict with "name", "caption" and "expected" keys.

    Returns:
        The per-case result record described in ``run_diagnostic``.
    """
    name = test_case["name"]
    caption = test_case["caption"]
    expected = test_case["expected"]

    print(f"\n{'-' * 80}")
    print(f"TEST {index}/{total}: {name}")
    print(f"{'-' * 80}")
    print(f"Caption: {caption}")
    print(f"Expected tags: {expected}")
    print("\nCalling LLM...", flush=True)

    # One failing call (network error, malformed response, ...) should not
    # throw away every other case's result. Record the error and score the
    # case as a clothing failure instead of aborting the whole diagnostic.
    error = None
    try:
        selected = llm_infer_structural_tags(
            caption,
            log=_log_fn,
            temperature=0.0,
            max_tokens=512,
        )
    except Exception as exc:
        error = f"{type(exc).__name__}: {exc}"
        print(f"\n[X] LLM call failed: {error}")
        selected = []
    # NOTE(review): assumes a None return means "no tags selected". Confirm
    # against llm_infer_structural_tags.
    if selected is None:
        selected = []

    print(f"\nSelected tags: {selected}")

    expected_set = set(expected)
    selected_set = set(selected)
    expected_clothing = expected_set & _CLOTHING_TAGS
    selected_clothing = selected_set & _CLOTHING_TAGS
    missed = expected_set - selected_set
    extra = selected_set - expected_set
    correct = expected_set & selected_set
    # A case whose call raised never counts as a pass, even if no clothing
    # tag was expected.
    clothing_correct = error is None and expected_clothing == selected_clothing

    print(f"\n[OK] Correct: {sorted(correct)}")
    if missed:
        print(f"[X] Missed: {sorted(missed)}")
    if extra:
        print(f"[!] Extra: {sorted(extra)}")
    print(f"\nClothing state inference: {'[OK] PASS' if clothing_correct else '[X] FAIL'}")
    print(f" Expected: {sorted(expected_clothing) if expected_clothing else '(none)'}")
    print(f" Selected: {sorted(selected_clothing) if selected_clothing else '(none)'}")

    return {
        "name": name,
        "caption": caption,
        "expected": expected,
        "selected": selected,
        "clothing_correct": clothing_correct,
        # Sorted so repeated runs produce identical records.
        "missed": sorted(missed),
        "extra": sorted(extra),
        "error": error,
    }


def _print_summary(results):
    """Print pass/fail counts and details of each clothing-state failure.

    *results* must be non-empty (the percentages divide by its length).
    Returns the number of cases whose clothing state was inferred correctly.
    """
    total_tests = len(results)
    clothing_pass = sum(1 for r in results if r["clothing_correct"])
    clothing_fail = total_tests - clothing_pass

    print("\n\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"\nTotal tests: {total_tests}")
    print("Clothing state inference:")
    print(f" [OK] Pass: {clothing_pass}/{total_tests} ({100 * clothing_pass / total_tests:.0f}%)")
    print(f" [X] Fail: {clothing_fail}/{total_tests} ({100 * clothing_fail / total_tests:.0f}%)")

    failures = [r for r in results if not r["clothing_correct"]]
    if failures:
        print(f"\n{'-' * 80}")
        print("FAILURES:")
        print(f"{'-' * 80}")
        for r in failures:
            exp_clothing = set(r["expected"]) & _CLOTHING_TAGS
            sel_clothing = set(r["selected"]) & _CLOTHING_TAGS
            print(f"\n* {r['name']}")
            print(f" Caption: {_truncate(r['caption'])}")
            print(f" Expected: {sorted(exp_clothing)}")
            print(f" Selected: {sorted(sel_clothing) if sel_clothing else '(none)'}")
            if r.get("error"):
                print(f" Error: {r['error']}")

    return clothing_pass


def _print_diagnosis(clothing_pass, total_tests):
    """Print an overall verdict based on how many clothing checks passed."""
    clothing_fail = total_tests - clothing_pass

    print(f"\n{'=' * 80}")
    print("DIAGNOSIS")
    print(f"{'=' * 80}")
    if clothing_pass == total_tests:
        print("\n[OK] All tests passed! Clothing inference is working correctly.")
    elif clothing_pass == 0:
        print("\n[X] ALL tests failed! The LLM is completely ignoring the clothing state group.")
        print("\nPossible causes:")
        print("1. Prompt design issue - clothing group not salient enough")
        print("2. Model capability issue - Llama 3.1 8B cannot handle this task")
        print("3. Response parsing issue - LLM is selecting but parser is missing it")
    else:
        print(f"\n[!] Partial failure! {clothing_fail}/{total_tests} tests failed.")
        print("\nThe LLM is sometimes inferring clothing state but inconsistently.")


def run_diagnostic():
    """Run every entry in ``TEST_CASES`` through structural inference.

    Prints the prompt being used, a per-case report, a summary and a
    diagnosis.

    Returns:
        A list with one dict per test case. Each dict has these keys:

        - "name", "caption", "expected": copied from the test case.
        - "selected": the tags returned by ``llm_infer_structural_tags``
          (``[]`` if the call raised).
        - "clothing_correct": whether the clothing-state tags matched.
        - "missed", "extra": sorted lists of tag differences.
        - "error": the exception text, or ``None``.
    """
    print_structural_prompt()

    print("\n" + "=" * 80)
    print("RUNNING DIAGNOSTIC TESTS")
    print("=" * 80 + "\n")

    total = len(TEST_CASES)
    results = []
    for i, test_case in enumerate(TEST_CASES, 1):
        results.append(_run_case(i, total, test_case))

    if not results:
        # Avoids a ZeroDivisionError in the percentages and a vacuous
        # "all tests passed" verdict when there is nothing to run.
        print("\nNo test cases defined; nothing to diagnose.")
        return results

    clothing_pass = _print_summary(results)
    _print_diagnosis(clothing_pass, len(results))
    return results
if __name__ == "__main__":
    results = run_diagnostic()
    # Exit non-zero when any clothing-state check failed, so the script can
    # gate CI or shell pipelines. Status 0 means every case matched.
    sys.exit(0 if all(r["clothing_correct"] for r in results) else 1)