Spaces:
Running
Running
Claude committed on
Commit ·
a16e111
1
Parent(s): 4968635
Add structural tag inference (Stage 3s) and compact eval output
Browse files
Stage 3s: New LLM step that infers structural tags (solo/duo/male/female/
anthro/biped/feral/humanoid/quadruped/intersex/ambiguous_gender/zero_pictured/
trio/group) via natural-language statement agreement instead of retrieval.
These tags are almost never stated in captions but are structurally obvious.
Compact eval output: Eval pipeline now writes two files:
- Compact metrics JSONL (tracked in git, small) with counts + diff sets
- Full detail JSONL (gitignored, large) with complete tag lists for analysis
Also: .gitignore updated to exclude *_detail.jsonl files.
https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG
- .gitignore +2 -0
- app.py +13 -2
- psq_rag/llm/select.py +175 -0
- scripts/eval_pipeline.py +102 -39
.gitignore
CHANGED
|
@@ -10,3 +10,5 @@ tf_idf_files_420.joblib
|
|
| 10 |
e621FastTextModel010Replacement_small.bin
|
| 11 |
tfidf_hnsw_artists.bin
|
| 12 |
tfidf_hnsw_tags.bin
|
|
|
|
|
|
|
|
|
| 10 |
e621FastTextModel010Replacement_small.bin
|
| 11 |
tfidf_hnsw_artists.bin
|
| 12 |
tfidf_hnsw_tags.bin
|
| 13 |
+
# Full detail eval files (large) — only compact metrics tracked in git
|
| 14 |
+
*_detail.jsonl
|
app.py
CHANGED
|
@@ -8,7 +8,7 @@ from typing import List
|
|
| 8 |
from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
|
| 9 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 10 |
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
|
| 11 |
-
from psq_rag.llm.select import llm_select_indices
|
| 12 |
from psq_rag.retrieval.state import expand_tags_via_implications
|
| 13 |
|
| 14 |
|
|
@@ -224,7 +224,18 @@ def rag_pipeline_ui(user_prompt: str):
|
|
| 224 |
|
| 225 |
selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []
|
| 226 |
|
| 227 |
-
log("Step 3b:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
tag_set = set(selected_tags)
|
| 229 |
expanded, implied_only = expand_tags_via_implications(tag_set)
|
| 230 |
if implied_only:
|
|
|
|
| 8 |
from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
|
| 9 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 10 |
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
|
| 11 |
+
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags
|
| 12 |
from psq_rag.retrieval.state import expand_tags_via_implications
|
| 13 |
|
| 14 |
|
|
|
|
| 224 |
|
| 225 |
selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []
|
| 226 |
|
| 227 |
+
log("Step 3b: Structural tag inference (solo/duo/gender/body plan)")
|
| 228 |
+
structural_tags = llm_infer_structural_tags(prompt_in, log=log)
|
| 229 |
+
if structural_tags:
|
| 230 |
+
# Add structural tags that aren't already selected
|
| 231 |
+
existing = {t for t in selected_tags}
|
| 232 |
+
new_structural = [t for t in structural_tags if t not in existing]
|
| 233 |
+
selected_tags.extend(new_structural)
|
| 234 |
+
log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
|
| 235 |
+
else:
|
| 236 |
+
log(" No structural tags inferred")
|
| 237 |
+
|
| 238 |
+
log("Step 3c: Expand via tag implications")
|
| 239 |
tag_set = set(selected_tags)
|
| 240 |
expanded, implied_only = expand_tags_via_implications(tag_set)
|
| 241 |
if implied_only:
|
psq_rag/llm/select.py
CHANGED
|
@@ -760,3 +760,178 @@ def llm_select_indices(
|
|
| 760 |
return out_idx, tag_why
|
| 761 |
|
| 762 |
return out_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 760 |
return out_idx, tag_why
|
| 761 |
|
| 762 |
return out_idx
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
# ---------------------------------------------------------------------------
|
| 766 |
+
# Stage 3s: Structural tag inference (solo/duo/male/female/anthro/biped …)
|
| 767 |
+
# ---------------------------------------------------------------------------
|
| 768 |
+
|
| 769 |
+
# Each statement maps to exactly one tag. The LLM picks statement numbers.
|
| 770 |
+
_STRUCTURAL_STATEMENTS: List[Tuple[str, str]] = [
|
| 771 |
+
# Character count
|
| 772 |
+
("The image contains zero characters (no living beings depicted)", "zero_pictured"),
|
| 773 |
+
("The image contains exactly one character", "solo"),
|
| 774 |
+
("The image contains exactly two characters", "duo"),
|
| 775 |
+
("The image contains exactly three characters", "trio"),
|
| 776 |
+
("The image contains four or more characters", "group"),
|
| 777 |
+
# Body plan
|
| 778 |
+
("A character is a regular (non-anthropomorphic) animal", "feral"),
|
| 779 |
+
("A character is an anthropomorphic animal (walks upright, has human-like posture)", "anthro"),
|
| 780 |
+
("A character is a human or human-like being", "humanoid"),
|
| 781 |
+
("A character stands or walks on two legs", "biped"),
|
| 782 |
+
("A character stands or walks on four legs", "quadruped"),
|
| 783 |
+
# Gender
|
| 784 |
+
("The image contains a male character", "male"),
|
| 785 |
+
("The image contains a female character", "female"),
|
| 786 |
+
("A character's gender is ambiguous or unspecified", "ambiguous_gender"),
|
| 787 |
+
("The image contains an intersex character", "intersex"),
|
| 788 |
+
]
|
| 789 |
+
|
| 790 |
+
STRUCTURAL_SYSTEM_TEMPLATE = """You are given a description of an image and a numbered list of statements.
|
| 791 |
+
|
| 792 |
+
Select EVERY statement that is true about the described image.
|
| 793 |
+
|
| 794 |
+
Return JSON ONLY matching this schema:
|
| 795 |
+
|
| 796 |
+
{{
|
| 797 |
+
"selections": [
|
| 798 |
+
{{"i": <int>}},
|
| 799 |
+
...
|
| 800 |
+
]
|
| 801 |
+
}}
|
| 802 |
+
|
| 803 |
+
Rules:
|
| 804 |
+
- Choose ONLY from indices 1..{N}.
|
| 805 |
+
- A statement is true if the description clearly supports it OR it is very strongly implied.
|
| 806 |
+
- Select ALL true statements, not just one per category.
|
| 807 |
+
- For example, if the image has two anthropomorphic characters (one male, one female),
|
| 808 |
+
you would select: "exactly two characters", "anthropomorphic animal", "stands on two legs",
|
| 809 |
+
"male character", and "female character".
|
| 810 |
+
- When no characters are visible, select only the "zero characters" statement.
|
| 811 |
+
- Do NOT guess when the description provides no evidence.
|
| 812 |
+
"""
|
| 813 |
+
|
| 814 |
+
STRUCTURAL_USER_TEMPLATE = """IMAGE DESCRIPTION:
|
| 815 |
+
{image_description}
|
| 816 |
+
|
| 817 |
+
STATEMENTS (select all that are true by index):
|
| 818 |
+
{statement_lines}
|
| 819 |
+
"""
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
class StructuralSelectionItem(BaseModel):
|
| 823 |
+
i: int = Field(..., description="1-based index into the statement list.")
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
class StructuralSelectionResponse(BaseModel):
|
| 827 |
+
selections: List[StructuralSelectionItem] = Field(default_factory=list)
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
def _build_structural_response_format() -> Dict[str, Any]:
|
| 831 |
+
schema = {
|
| 832 |
+
"type": "object",
|
| 833 |
+
"properties": {
|
| 834 |
+
"selections": {
|
| 835 |
+
"type": "array",
|
| 836 |
+
"items": {
|
| 837 |
+
"type": "object",
|
| 838 |
+
"properties": {
|
| 839 |
+
"i": {"type": "integer"},
|
| 840 |
+
},
|
| 841 |
+
"required": ["i"],
|
| 842 |
+
"additionalProperties": False,
|
| 843 |
+
},
|
| 844 |
+
}
|
| 845 |
+
},
|
| 846 |
+
"required": ["selections"],
|
| 847 |
+
"additionalProperties": False,
|
| 848 |
+
}
|
| 849 |
+
return {
|
| 850 |
+
"type": "json_schema",
|
| 851 |
+
"json_schema": {
|
| 852 |
+
"name": "structural_selection",
|
| 853 |
+
"strict": True,
|
| 854 |
+
"schema": schema,
|
| 855 |
+
},
|
| 856 |
+
}
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
def llm_infer_structural_tags(
|
| 860 |
+
query_text: str,
|
| 861 |
+
log=None,
|
| 862 |
+
*,
|
| 863 |
+
temperature: float = 0.0,
|
| 864 |
+
max_tokens: int = 256,
|
| 865 |
+
retries: int = 2,
|
| 866 |
+
) -> List[str]:
|
| 867 |
+
"""Infer structural tags (solo/duo/male/female/anthro/biped/…) via LLM.
|
| 868 |
+
|
| 869 |
+
Instead of retrieving these from a candidate list, we ask the LLM to agree
|
| 870 |
+
with natural-language statements about the image. This handles tags that
|
| 871 |
+
are almost never stated in captions but are visually/structurally obvious.
|
| 872 |
+
|
| 873 |
+
Returns a list of e621 tag strings (e.g. ["solo", "anthro", "male", "biped"]).
|
| 874 |
+
"""
|
| 875 |
+
if log:
|
| 876 |
+
log("Stage3s (structural): inferring structural tags via statement agreement")
|
| 877 |
+
|
| 878 |
+
statements = _STRUCTURAL_STATEMENTS
|
| 879 |
+
lines = [f"{j}. {stmt}" for j, (stmt, _tag) in enumerate(statements, 1)]
|
| 880 |
+
statement_lines = "\n".join(lines)
|
| 881 |
+
N = len(statements)
|
| 882 |
+
|
| 883 |
+
response_format = _build_structural_response_format()
|
| 884 |
+
llm = _get_llm(temperature=temperature, max_tokens=max_tokens,
|
| 885 |
+
response_format=response_format)
|
| 886 |
+
model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
|
| 887 |
+
|
| 888 |
+
parser = PydanticOutputParser(pydantic_object=StructuralSelectionResponse)
|
| 889 |
+
|
| 890 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 891 |
+
[
|
| 892 |
+
("system", STRUCTURAL_SYSTEM_TEMPLATE),
|
| 893 |
+
("human", STRUCTURAL_USER_TEMPLATE),
|
| 894 |
+
],
|
| 895 |
+
template_format="f-string",
|
| 896 |
+
)
|
| 897 |
+
chain = prompt | llm | parser
|
| 898 |
+
|
| 899 |
+
if log:
|
| 900 |
+
log(f"Stage3s: model={model_name} statements={N}")
|
| 901 |
+
|
| 902 |
+
for att in range(retries + 1):
|
| 903 |
+
try:
|
| 904 |
+
parsed = chain.invoke({
|
| 905 |
+
"N": N,
|
| 906 |
+
"image_description": query_text,
|
| 907 |
+
"statement_lines": statement_lines,
|
| 908 |
+
})
|
| 909 |
+
|
| 910 |
+
if isinstance(parsed, BaseModel):
|
| 911 |
+
parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()
|
| 912 |
+
|
| 913 |
+
sels = parsed.get("selections", []) if isinstance(parsed, dict) else []
|
| 914 |
+
chosen_tags: List[str] = []
|
| 915 |
+
seen = set()
|
| 916 |
+
for item in sels:
|
| 917 |
+
idx = item.get("i") if isinstance(item, dict) else None
|
| 918 |
+
if not isinstance(idx, int) or idx < 1 or idx > N:
|
| 919 |
+
continue
|
| 920 |
+
tag = statements[idx - 1][1]
|
| 921 |
+
if tag not in seen:
|
| 922 |
+
chosen_tags.append(tag)
|
| 923 |
+
seen.add(tag)
|
| 924 |
+
|
| 925 |
+
if log:
|
| 926 |
+
tag_str = ", ".join(chosen_tags) if chosen_tags else "(none)"
|
| 927 |
+
log(f"Stage3s: attempt {att+1} selected {len(chosen_tags)} tags: {tag_str}")
|
| 928 |
+
|
| 929 |
+
return chosen_tags
|
| 930 |
+
|
| 931 |
+
except Exception as e:
|
| 932 |
+
if log:
|
| 933 |
+
log(f"Stage3s: attempt {att+1} error: {e}")
|
| 934 |
+
|
| 935 |
+
if log:
|
| 936 |
+
log(f"Stage3s: gave up after {retries+1} attempts")
|
| 937 |
+
return []
|
scripts/eval_pipeline.py
CHANGED
|
@@ -151,6 +151,8 @@ class SampleResult:
|
|
| 151 |
why_counts: Dict[str, int] = field(default_factory=dict)
|
| 152 |
# Tag implications
|
| 153 |
implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
|
|
|
|
|
|
|
| 154 |
# Leaf-only metrics (strips implied ancestors from both sides)
|
| 155 |
leaf_precision: float = 0.0
|
| 156 |
leaf_recall: float = 0.0
|
|
@@ -161,6 +163,7 @@ class SampleResult:
|
|
| 161 |
stage1_time: float = 0.0
|
| 162 |
stage2_time: float = 0.0
|
| 163 |
stage3_time: float = 0.0
|
|
|
|
| 164 |
# Errors
|
| 165 |
error: Optional[str] = None
|
| 166 |
|
|
@@ -196,11 +199,12 @@ def _process_one_sample(
|
|
| 196 |
print_lock: threading.Lock,
|
| 197 |
min_why: Optional[str] = None,
|
| 198 |
expand_implications: bool = False,
|
|
|
|
| 199 |
) -> SampleResult:
|
| 200 |
"""Process a single eval sample through the full pipeline. Thread-safe."""
|
| 201 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 202 |
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
|
| 203 |
-
from psq_rag.llm.select import llm_select_indices
|
| 204 |
from psq_rag.retrieval.state import get_tag_type_name, expand_tags_via_implications, get_leaf_tags
|
| 205 |
|
| 206 |
def log(msg: str) -> None:
|
|
@@ -288,6 +292,19 @@ def _process_one_sample(
|
|
| 288 |
why_counts[w] = why_counts.get(w, 0) + 1
|
| 289 |
result.why_counts = why_counts
|
| 290 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
# Tag implication expansion (post-Stage 3)
|
| 292 |
if expand_implications and result.selected_tags:
|
| 293 |
expanded, implied_only = expand_tags_via_implications(result.selected_tags)
|
|
@@ -355,11 +372,12 @@ def _process_one_sample(
|
|
| 355 |
if gt_char:
|
| 356 |
char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
|
| 357 |
impl_info = f" (+{len(result.implied_tags)} implied)" if result.implied_tags else ""
|
|
|
|
| 358 |
with print_lock:
|
| 359 |
print(
|
| 360 |
f" [{index+1}] retrieval_recall={result.retrieval_recall:.3f} "
|
| 361 |
f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
|
| 362 |
-
f"selected={len(result.selected_tags)}{impl_info}{char_info} "
|
| 363 |
f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
|
| 364 |
)
|
| 365 |
|
|
@@ -404,6 +422,7 @@ def run_eval(
|
|
| 404 |
workers: int = 1,
|
| 405 |
min_why: Optional[str] = "strong_implied",
|
| 406 |
expand_implications: bool = False,
|
|
|
|
| 407 |
) -> List[SampleResult]:
|
| 408 |
|
| 409 |
# Load eval samples — prefer expanded file, fall back to raw
|
|
@@ -469,6 +488,7 @@ def run_eval(
|
|
| 469 |
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 470 |
per_phrase_k, temperature, max_tokens, verbose,
|
| 471 |
print_lock, min_why, expand_implications,
|
|
|
|
| 472 |
)
|
| 473 |
results.append(result)
|
| 474 |
else:
|
|
@@ -485,6 +505,7 @@ def run_eval(
|
|
| 485 |
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 486 |
per_phrase_k, temperature, max_tokens, verbose,
|
| 487 |
print_lock, min_why, expand_implications,
|
|
|
|
| 488 |
): i
|
| 489 |
for i, sample in enumerate(samples)
|
| 490 |
}
|
|
@@ -553,6 +574,7 @@ def print_summary(results: List[SampleResult]) -> None:
|
|
| 553 |
avg_over_sel = _safe_avg([r.over_selection_ratio for r in valid])
|
| 554 |
|
| 555 |
avg_implied = sum(len(r.implied_tags) for r in valid) / n
|
|
|
|
| 556 |
|
| 557 |
print()
|
| 558 |
print("Stage 3 - Selection (ALL tags):")
|
|
@@ -562,6 +584,8 @@ def print_summary(results: List[SampleResult]) -> None:
|
|
| 562 |
print(f" Avg selected tags: {avg_selected:.1f}")
|
| 563 |
if avg_implied > 0:
|
| 564 |
print(f" Avg implied tags: {avg_implied:.1f} (added via tag implications)")
|
|
|
|
|
|
|
| 565 |
print(f" Avg ground-truth tags:{avg_gt:.1f}")
|
| 566 |
|
| 567 |
# Leaf-only metrics
|
|
@@ -673,11 +697,14 @@ def print_summary(results: List[SampleResult]) -> None:
|
|
| 673 |
|
| 674 |
print()
|
| 675 |
print("-" * 70)
|
|
|
|
| 676 |
print("Timing (avg per sample):")
|
| 677 |
print(f" Stage 1 (rewrite): {avg_t1:.2f}s")
|
| 678 |
print(f" Stage 2 (retrieval): {avg_t2:.2f}s")
|
| 679 |
print(f" Stage 3 (selection): {avg_t3:.2f}s")
|
| 680 |
-
|
|
|
|
|
|
|
| 681 |
print()
|
| 682 |
|
| 683 |
# Show worst and best F1 samples
|
|
@@ -739,6 +766,8 @@ def main(argv=None) -> int:
|
|
| 739 |
help="Minimum 'why' confidence to keep (default: strong_implied). Use 'none' to disable filtering.")
|
| 740 |
ap.add_argument("--expand-implications", action="store_true", default=False,
|
| 741 |
help="Expand selected tags via tag implication chains (e.g. fox→canine→canid→mammal)")
|
|
|
|
|
|
|
| 742 |
|
| 743 |
args = ap.parse_args(list(argv) if argv is not None else None)
|
| 744 |
|
|
@@ -761,18 +790,24 @@ def main(argv=None) -> int:
|
|
| 761 |
workers=args.workers,
|
| 762 |
min_why=min_why_val,
|
| 763 |
expand_implications=args.expand_implications,
|
|
|
|
| 764 |
)
|
| 765 |
|
| 766 |
print_summary(results)
|
| 767 |
|
| 768 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 769 |
if args.output:
|
| 770 |
out_path = Path(args.output)
|
| 771 |
else:
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 775 |
-
out_path = results_dir / f"eval_{args.caption_field}_n{args.n}_seed{args.seed}_{timestamp}.jsonl"
|
| 776 |
|
| 777 |
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 778 |
|
|
@@ -793,10 +828,65 @@ def main(argv=None) -> int:
|
|
| 793 |
"workers": args.workers,
|
| 794 |
"min_why": args.min_why,
|
| 795 |
"expand_implications": args.expand_implications,
|
|
|
|
| 796 |
"n_errors": sum(1 for r in results if r.error),
|
| 797 |
}
|
| 798 |
|
| 799 |
with out_path.open("w", encoding="utf-8") as f:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 800 |
f.write(json.dumps(meta, ensure_ascii=False) + "\n")
|
| 801 |
for r in results:
|
| 802 |
row = {
|
|
@@ -806,44 +896,17 @@ def main(argv=None) -> int:
|
|
| 806 |
"rewrite_phrases": r.rewrite_phrases,
|
| 807 |
"retrieved_tags": sorted(r.retrieved_tags),
|
| 808 |
"selected_tags": sorted(r.selected_tags),
|
| 809 |
-
"
|
| 810 |
-
"
|
| 811 |
-
"
|
| 812 |
-
"selection_f1": round(r.selection_f1, 4),
|
| 813 |
-
# Character tag breakdown
|
| 814 |
"gt_character_tags": sorted(r.gt_character_tags),
|
| 815 |
"selected_character_tags": sorted(r.selected_character_tags),
|
| 816 |
-
"retrieved_character_tags": sorted(r.retrieved_character_tags),
|
| 817 |
-
"char_retrieval_recall": round(r.char_retrieval_recall, 4),
|
| 818 |
-
"char_precision": round(r.char_precision, 4),
|
| 819 |
-
"char_recall": round(r.char_recall, 4),
|
| 820 |
-
"char_f1": round(r.char_f1, 4),
|
| 821 |
-
# General tag breakdown
|
| 822 |
"gt_general_tags": sorted(r.gt_general_tags),
|
| 823 |
"selected_general_tags": sorted(r.selected_general_tags),
|
| 824 |
-
"general_precision": round(r.general_precision, 4),
|
| 825 |
-
"general_recall": round(r.general_recall, 4),
|
| 826 |
-
"general_f1": round(r.general_f1, 4),
|
| 827 |
-
# Diagnostic metrics
|
| 828 |
-
"retrieval_precision": round(r.retrieval_precision, 4),
|
| 829 |
-
"selection_given_retrieval": round(r.selection_given_retrieval, 4),
|
| 830 |
-
"over_selection_ratio": round(r.over_selection_ratio, 2),
|
| 831 |
-
"why_counts": r.why_counts,
|
| 832 |
-
"implied_tags": sorted(r.implied_tags),
|
| 833 |
-
# Leaf metrics
|
| 834 |
-
"leaf_precision": round(r.leaf_precision, 4),
|
| 835 |
-
"leaf_recall": round(r.leaf_recall, 4),
|
| 836 |
-
"leaf_f1": round(r.leaf_f1, 4),
|
| 837 |
-
"leaf_selected_count": r.leaf_selected_count,
|
| 838 |
-
"leaf_gt_count": r.leaf_gt_count,
|
| 839 |
-
# Timing
|
| 840 |
-
"stage1_time": round(r.stage1_time, 3),
|
| 841 |
-
"stage2_time": round(r.stage2_time, 3),
|
| 842 |
-
"stage3_time": round(r.stage3_time, 3),
|
| 843 |
"error": r.error,
|
| 844 |
}
|
| 845 |
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 846 |
-
print(f"
|
| 847 |
|
| 848 |
return 0
|
| 849 |
|
|
|
|
| 151 |
why_counts: Dict[str, int] = field(default_factory=dict)
|
| 152 |
# Tag implications
|
| 153 |
implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
|
| 154 |
+
# Structural inference tags (solo/duo/male/female/anthro/biped etc.)
|
| 155 |
+
structural_tags: List[str] = field(default_factory=list)
|
| 156 |
# Leaf-only metrics (strips implied ancestors from both sides)
|
| 157 |
leaf_precision: float = 0.0
|
| 158 |
leaf_recall: float = 0.0
|
|
|
|
| 163 |
stage1_time: float = 0.0
|
| 164 |
stage2_time: float = 0.0
|
| 165 |
stage3_time: float = 0.0
|
| 166 |
+
stage3s_time: float = 0.0
|
| 167 |
# Errors
|
| 168 |
error: Optional[str] = None
|
| 169 |
|
|
|
|
| 199 |
print_lock: threading.Lock,
|
| 200 |
min_why: Optional[str] = None,
|
| 201 |
expand_implications: bool = False,
|
| 202 |
+
infer_structural: bool = False,
|
| 203 |
) -> SampleResult:
|
| 204 |
"""Process a single eval sample through the full pipeline. Thread-safe."""
|
| 205 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 206 |
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
|
| 207 |
+
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags
|
| 208 |
from psq_rag.retrieval.state import get_tag_type_name, expand_tags_via_implications, get_leaf_tags
|
| 209 |
|
| 210 |
def log(msg: str) -> None:
|
|
|
|
| 292 |
why_counts[w] = why_counts.get(w, 0) + 1
|
| 293 |
result.why_counts = why_counts
|
| 294 |
|
| 295 |
+
# Structural tag inference (solo/duo/male/female/anthro/biped etc.)
|
| 296 |
+
if infer_structural:
|
| 297 |
+
t0s = time.time()
|
| 298 |
+
structural = llm_infer_structural_tags(
|
| 299 |
+
caption, log=log, temperature=temperature,
|
| 300 |
+
)
|
| 301 |
+
result.stage3s_time = time.time() - t0s
|
| 302 |
+
result.structural_tags = structural
|
| 303 |
+
# Add structural tags not already selected
|
| 304 |
+
for st in structural:
|
| 305 |
+
result.selected_tags.add(st)
|
| 306 |
+
log(f"Structural: {structural}")
|
| 307 |
+
|
| 308 |
# Tag implication expansion (post-Stage 3)
|
| 309 |
if expand_implications and result.selected_tags:
|
| 310 |
expanded, implied_only = expand_tags_via_implications(result.selected_tags)
|
|
|
|
| 372 |
if gt_char:
|
| 373 |
char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
|
| 374 |
impl_info = f" (+{len(result.implied_tags)} implied)" if result.implied_tags else ""
|
| 375 |
+
struct_info = f" (+{len(result.structural_tags)} structural)" if result.structural_tags else ""
|
| 376 |
with print_lock:
|
| 377 |
print(
|
| 378 |
f" [{index+1}] retrieval_recall={result.retrieval_recall:.3f} "
|
| 379 |
f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
|
| 380 |
+
f"selected={len(result.selected_tags)}{impl_info}{struct_info}{char_info} "
|
| 381 |
f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
|
| 382 |
)
|
| 383 |
|
|
|
|
| 422 |
workers: int = 1,
|
| 423 |
min_why: Optional[str] = "strong_implied",
|
| 424 |
expand_implications: bool = False,
|
| 425 |
+
infer_structural: bool = False,
|
| 426 |
) -> List[SampleResult]:
|
| 427 |
|
| 428 |
# Load eval samples — prefer expanded file, fall back to raw
|
|
|
|
| 488 |
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 489 |
per_phrase_k, temperature, max_tokens, verbose,
|
| 490 |
print_lock, min_why, expand_implications,
|
| 491 |
+
infer_structural,
|
| 492 |
)
|
| 493 |
results.append(result)
|
| 494 |
else:
|
|
|
|
| 505 |
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 506 |
per_phrase_k, temperature, max_tokens, verbose,
|
| 507 |
print_lock, min_why, expand_implications,
|
| 508 |
+
infer_structural,
|
| 509 |
): i
|
| 510 |
for i, sample in enumerate(samples)
|
| 511 |
}
|
|
|
|
| 574 |
avg_over_sel = _safe_avg([r.over_selection_ratio for r in valid])
|
| 575 |
|
| 576 |
avg_implied = sum(len(r.implied_tags) for r in valid) / n
|
| 577 |
+
avg_structural = sum(len(r.structural_tags) for r in valid) / n
|
| 578 |
|
| 579 |
print()
|
| 580 |
print("Stage 3 - Selection (ALL tags):")
|
|
|
|
| 584 |
print(f" Avg selected tags: {avg_selected:.1f}")
|
| 585 |
if avg_implied > 0:
|
| 586 |
print(f" Avg implied tags: {avg_implied:.1f} (added via tag implications)")
|
| 587 |
+
if avg_structural > 0:
|
| 588 |
+
print(f" Avg structural tags: {avg_structural:.1f} (inferred via statement agreement)")
|
| 589 |
print(f" Avg ground-truth tags:{avg_gt:.1f}")
|
| 590 |
|
| 591 |
# Leaf-only metrics
|
|
|
|
| 697 |
|
| 698 |
print()
|
| 699 |
print("-" * 70)
|
| 700 |
+
avg_t3s = sum(r.stage3s_time for r in valid) / n
|
| 701 |
print("Timing (avg per sample):")
|
| 702 |
print(f" Stage 1 (rewrite): {avg_t1:.2f}s")
|
| 703 |
print(f" Stage 2 (retrieval): {avg_t2:.2f}s")
|
| 704 |
print(f" Stage 3 (selection): {avg_t3:.2f}s")
|
| 705 |
+
if avg_t3s > 0:
|
| 706 |
+
print(f" Stage 3s (structural):{avg_t3s:.2f}s")
|
| 707 |
+
print(f" Total: {avg_t1 + avg_t2 + avg_t3 + avg_t3s:.2f}s")
|
| 708 |
print()
|
| 709 |
|
| 710 |
# Show worst and best F1 samples
|
|
|
|
| 766 |
help="Minimum 'why' confidence to keep (default: strong_implied). Use 'none' to disable filtering.")
|
| 767 |
ap.add_argument("--expand-implications", action="store_true", default=False,
|
| 768 |
help="Expand selected tags via tag implication chains (e.g. fox→canine→canid→mammal)")
|
| 769 |
+
ap.add_argument("--infer-structural", action="store_true", default=False,
|
| 770 |
+
help="Infer structural tags (solo/duo/male/female/anthro/biped) via LLM statement agreement")
|
| 771 |
|
| 772 |
args = ap.parse_args(list(argv) if argv is not None else None)
|
| 773 |
|
|
|
|
| 790 |
workers=args.workers,
|
| 791 |
min_why=min_why_val,
|
| 792 |
expand_implications=args.expand_implications,
|
| 793 |
+
infer_structural=args.infer_structural,
|
| 794 |
)
|
| 795 |
|
| 796 |
print_summary(results)
|
| 797 |
|
| 798 |
+
# Save results in two formats:
|
| 799 |
+
# 1. Compact metrics JSONL (small, for git / LLM reading)
|
| 800 |
+
# 2. Full detail JSONL (large, for analysis scripts, gitignored)
|
| 801 |
+
results_dir = _REPO_ROOT / "data" / "eval_results"
|
| 802 |
+
results_dir.mkdir(parents=True, exist_ok=True)
|
| 803 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 804 |
+
base_name = f"eval_{args.caption_field}_n{args.n}_seed{args.seed}_{timestamp}"
|
| 805 |
+
|
| 806 |
if args.output:
|
| 807 |
out_path = Path(args.output)
|
| 808 |
else:
|
| 809 |
+
out_path = results_dir / f"{base_name}.jsonl"
|
| 810 |
+
detail_path = results_dir / f"{base_name}_detail.jsonl"
|
|
|
|
|
|
|
| 811 |
|
| 812 |
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 813 |
|
|
|
|
| 828 |
"workers": args.workers,
|
| 829 |
"min_why": args.min_why,
|
| 830 |
"expand_implications": args.expand_implications,
|
| 831 |
+
"infer_structural": args.infer_structural,
|
| 832 |
"n_errors": sum(1 for r in results if r.error),
|
| 833 |
}
|
| 834 |
|
| 835 |
with out_path.open("w", encoding="utf-8") as f:
|
| 836 |
+
f.write(json.dumps(meta, ensure_ascii=False) + "\n")
|
| 837 |
+
for r in results:
|
| 838 |
+
# Compact format: metrics + counts + small diff sets (not full tag lists)
|
| 839 |
+
missed_tags = sorted(r.ground_truth_tags - r.selected_tags)
|
| 840 |
+
extra_tags = sorted(r.selected_tags - r.ground_truth_tags)
|
| 841 |
+
row = {
|
| 842 |
+
"id": r.sample_id,
|
| 843 |
+
# Counts (not full lists)
|
| 844 |
+
"n_gt": len(r.ground_truth_tags),
|
| 845 |
+
"n_retrieved": len(r.retrieved_tags),
|
| 846 |
+
"n_selected": len(r.selected_tags),
|
| 847 |
+
"n_implied": len(r.implied_tags),
|
| 848 |
+
"n_structural": len(r.structural_tags),
|
| 849 |
+
# Overall metrics
|
| 850 |
+
"ret_R": round(r.retrieval_recall, 4),
|
| 851 |
+
"P": round(r.selection_precision, 4),
|
| 852 |
+
"R": round(r.selection_recall, 4),
|
| 853 |
+
"F1": round(r.selection_f1, 4),
|
| 854 |
+
# Leaf metrics
|
| 855 |
+
"leaf_P": round(r.leaf_precision, 4),
|
| 856 |
+
"leaf_R": round(r.leaf_recall, 4),
|
| 857 |
+
"leaf_F1": round(r.leaf_f1, 4),
|
| 858 |
+
"n_leaf_sel": r.leaf_selected_count,
|
| 859 |
+
"n_leaf_gt": r.leaf_gt_count,
|
| 860 |
+
# Diagnostic
|
| 861 |
+
"ret_P": round(r.retrieval_precision, 4),
|
| 862 |
+
"sel_given_ret": round(r.selection_given_retrieval, 4),
|
| 863 |
+
"over_sel": round(r.over_selection_ratio, 2),
|
| 864 |
+
"why": r.why_counts,
|
| 865 |
+
# Character metrics (compact)
|
| 866 |
+
"n_gt_char": len(r.gt_character_tags),
|
| 867 |
+
"n_sel_char": len(r.selected_character_tags),
|
| 868 |
+
"char_F1": round(r.char_f1, 4),
|
| 869 |
+
# General metrics (compact)
|
| 870 |
+
"gen_P": round(r.general_precision, 4),
|
| 871 |
+
"gen_R": round(r.general_recall, 4),
|
| 872 |
+
"gen_F1": round(r.general_f1, 4),
|
| 873 |
+
# Diff sets (small — only the errors, not the full lists)
|
| 874 |
+
"missed": missed_tags,
|
| 875 |
+
"extra": extra_tags,
|
| 876 |
+
# Structural tags inferred
|
| 877 |
+
"structural": r.structural_tags,
|
| 878 |
+
# Timing
|
| 879 |
+
"t1": round(r.stage1_time, 2),
|
| 880 |
+
"t2": round(r.stage2_time, 2),
|
| 881 |
+
"t3": round(r.stage3_time, 2),
|
| 882 |
+
"t3s": round(r.stage3s_time, 2),
|
| 883 |
+
"err": r.error,
|
| 884 |
+
}
|
| 885 |
+
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 886 |
+
print(f"\nCompact results saved to: {out_path}")
|
| 887 |
+
|
| 888 |
+
# Write full detail file (for analysis scripts)
|
| 889 |
+
with detail_path.open("w", encoding="utf-8") as f:
|
| 890 |
f.write(json.dumps(meta, ensure_ascii=False) + "\n")
|
| 891 |
for r in results:
|
| 892 |
row = {
|
|
|
|
| 896 |
"rewrite_phrases": r.rewrite_phrases,
|
| 897 |
"retrieved_tags": sorted(r.retrieved_tags),
|
| 898 |
"selected_tags": sorted(r.selected_tags),
|
| 899 |
+
"implied_tags": sorted(r.implied_tags),
|
| 900 |
+
"structural_tags": r.structural_tags,
|
| 901 |
+
"why_counts": r.why_counts,
|
|
|
|
|
|
|
| 902 |
"gt_character_tags": sorted(r.gt_character_tags),
|
| 903 |
"selected_character_tags": sorted(r.selected_character_tags),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 904 |
"gt_general_tags": sorted(r.gt_general_tags),
|
| 905 |
"selected_general_tags": sorted(r.selected_general_tags),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 906 |
"error": r.error,
|
| 907 |
}
|
| 908 |
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 909 |
+
print(f"Detail results saved to: {detail_path}")
|
| 910 |
|
| 911 |
return 0
|
| 912 |
|