Claude committed on
Commit
a16e111
·
1 Parent(s): 4968635

Add structural tag inference (Stage 3s) and compact eval output

Browse files

Stage 3s: New LLM step that infers structural tags (solo/duo/male/female/
anthro/biped/feral/humanoid/quadruped/intersex/ambiguous_gender/zero_pictured/
trio/group) via natural-language statement agreement instead of retrieval.
These tags are almost never stated in captions but are structurally obvious.

Compact eval output: Eval pipeline now writes two files:
- Compact metrics JSONL (tracked in git, small) with counts + diff sets
- Full detail JSONL (gitignored, large) with complete tag lists for analysis

Also: .gitignore updated to exclude *_detail.jsonl files.

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG

Files changed (4) hide show
  1. .gitignore +2 -0
  2. app.py +13 -2
  3. psq_rag/llm/select.py +175 -0
  4. scripts/eval_pipeline.py +102 -39
.gitignore CHANGED
@@ -10,3 +10,5 @@ tf_idf_files_420.joblib
10
  e621FastTextModel010Replacement_small.bin
11
  tfidf_hnsw_artists.bin
12
  tfidf_hnsw_tags.bin
 
 
 
10
  e621FastTextModel010Replacement_small.bin
11
  tfidf_hnsw_artists.bin
12
  tfidf_hnsw_tags.bin
13
+ # Full detail eval files (large) — only compact metrics tracked in git
14
+ *_detail.jsonl
app.py CHANGED
@@ -8,7 +8,7 @@ from typing import List
8
  from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
9
  from psq_rag.llm.rewrite import llm_rewrite_prompt
10
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
11
- from psq_rag.llm.select import llm_select_indices
12
  from psq_rag.retrieval.state import expand_tags_via_implications
13
 
14
 
@@ -224,7 +224,18 @@ def rag_pipeline_ui(user_prompt: str):
224
 
225
  selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []
226
 
227
- log("Step 3b: Expand via tag implications")
 
 
 
 
 
 
 
 
 
 
 
228
  tag_set = set(selected_tags)
229
  expanded, implied_only = expand_tags_via_implications(tag_set)
230
  if implied_only:
 
8
  from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
9
  from psq_rag.llm.rewrite import llm_rewrite_prompt
10
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
11
+ from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags
12
  from psq_rag.retrieval.state import expand_tags_via_implications
13
 
14
 
 
224
 
225
  selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []
226
 
227
+ log("Step 3b: Structural tag inference (solo/duo/gender/body plan)")
228
+ structural_tags = llm_infer_structural_tags(prompt_in, log=log)
229
+ if structural_tags:
230
+ # Add structural tags that aren't already selected
231
+ existing = {t for t in selected_tags}
232
+ new_structural = [t for t in structural_tags if t not in existing]
233
+ selected_tags.extend(new_structural)
234
+ log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
235
+ else:
236
+ log(" No structural tags inferred")
237
+
238
+ log("Step 3c: Expand via tag implications")
239
  tag_set = set(selected_tags)
240
  expanded, implied_only = expand_tags_via_implications(tag_set)
241
  if implied_only:
psq_rag/llm/select.py CHANGED
@@ -760,3 +760,178 @@ def llm_select_indices(
760
  return out_idx, tag_why
761
 
762
  return out_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
760
  return out_idx, tag_why
761
 
762
  return out_idx
763
+
764
+
765
+ # ---------------------------------------------------------------------------
766
+ # Stage 3s: Structural tag inference (solo/duo/male/female/anthro/biped …)
767
+ # ---------------------------------------------------------------------------
768
+
769
+ # Each statement maps to exactly one tag. The LLM picks statement numbers.
770
+ _STRUCTURAL_STATEMENTS: List[Tuple[str, str]] = [
771
+ # Character count
772
+ ("The image contains zero characters (no living beings depicted)", "zero_pictured"),
773
+ ("The image contains exactly one character", "solo"),
774
+ ("The image contains exactly two characters", "duo"),
775
+ ("The image contains exactly three characters", "trio"),
776
+ ("The image contains four or more characters", "group"),
777
+ # Body plan
778
+ ("A character is a regular (non-anthropomorphic) animal", "feral"),
779
+ ("A character is an anthropomorphic animal (walks upright, has human-like posture)", "anthro"),
780
+ ("A character is a human or human-like being", "humanoid"),
781
+ ("A character stands or walks on two legs", "biped"),
782
+ ("A character stands or walks on four legs", "quadruped"),
783
+ # Gender
784
+ ("The image contains a male character", "male"),
785
+ ("The image contains a female character", "female"),
786
+ ("A character's gender is ambiguous or unspecified", "ambiguous_gender"),
787
+ ("The image contains an intersex character", "intersex"),
788
+ ]
789
+
790
+ STRUCTURAL_SYSTEM_TEMPLATE = """You are given a description of an image and a numbered list of statements.
791
+
792
+ Select EVERY statement that is true about the described image.
793
+
794
+ Return JSON ONLY matching this schema:
795
+
796
+ {{
797
+ "selections": [
798
+ {{"i": <int>}},
799
+ ...
800
+ ]
801
+ }}
802
+
803
+ Rules:
804
+ - Choose ONLY from indices 1..{N}.
805
+ - A statement is true if the description clearly supports it OR it is very strongly implied.
806
+ - Select ALL true statements, not just one per category.
807
+ - For example, if the image has two anthropomorphic characters (one male, one female),
808
+ you would select: "exactly two characters", "anthropomorphic animal", "stands on two legs",
809
+ "male character", and "female character".
810
+ - When no characters are visible, select only the "zero characters" statement.
811
+ - Do NOT guess when the description provides no evidence.
812
+ """
813
+
814
+ STRUCTURAL_USER_TEMPLATE = """IMAGE DESCRIPTION:
815
+ {image_description}
816
+
817
+ STATEMENTS (select all that are true by index):
818
+ {statement_lines}
819
+ """
820
+
821
+
822
+ class StructuralSelectionItem(BaseModel):
823
+ i: int = Field(..., description="1-based index into the statement list.")
824
+
825
+
826
+ class StructuralSelectionResponse(BaseModel):
827
+ selections: List[StructuralSelectionItem] = Field(default_factory=list)
828
+
829
+
830
+ def _build_structural_response_format() -> Dict[str, Any]:
831
+ schema = {
832
+ "type": "object",
833
+ "properties": {
834
+ "selections": {
835
+ "type": "array",
836
+ "items": {
837
+ "type": "object",
838
+ "properties": {
839
+ "i": {"type": "integer"},
840
+ },
841
+ "required": ["i"],
842
+ "additionalProperties": False,
843
+ },
844
+ }
845
+ },
846
+ "required": ["selections"],
847
+ "additionalProperties": False,
848
+ }
849
+ return {
850
+ "type": "json_schema",
851
+ "json_schema": {
852
+ "name": "structural_selection",
853
+ "strict": True,
854
+ "schema": schema,
855
+ },
856
+ }
857
+
858
+
859
+ def llm_infer_structural_tags(
860
+ query_text: str,
861
+ log=None,
862
+ *,
863
+ temperature: float = 0.0,
864
+ max_tokens: int = 256,
865
+ retries: int = 2,
866
+ ) -> List[str]:
867
+ """Infer structural tags (solo/duo/male/female/anthro/biped/…) via LLM.
868
+
869
+ Instead of retrieving these from a candidate list, we ask the LLM to agree
870
+ with natural-language statements about the image. This handles tags that
871
+ are almost never stated in captions but are visually/structurally obvious.
872
+
873
+ Returns a list of e621 tag strings (e.g. ["solo", "anthro", "male", "biped"]).
874
+ """
875
+ if log:
876
+ log("Stage3s (structural): inferring structural tags via statement agreement")
877
+
878
+ statements = _STRUCTURAL_STATEMENTS
879
+ lines = [f"{j}. {stmt}" for j, (stmt, _tag) in enumerate(statements, 1)]
880
+ statement_lines = "\n".join(lines)
881
+ N = len(statements)
882
+
883
+ response_format = _build_structural_response_format()
884
+ llm = _get_llm(temperature=temperature, max_tokens=max_tokens,
885
+ response_format=response_format)
886
+ model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
887
+
888
+ parser = PydanticOutputParser(pydantic_object=StructuralSelectionResponse)
889
+
890
+ prompt = ChatPromptTemplate.from_messages(
891
+ [
892
+ ("system", STRUCTURAL_SYSTEM_TEMPLATE),
893
+ ("human", STRUCTURAL_USER_TEMPLATE),
894
+ ],
895
+ template_format="f-string",
896
+ )
897
+ chain = prompt | llm | parser
898
+
899
+ if log:
900
+ log(f"Stage3s: model={model_name} statements={N}")
901
+
902
+ for att in range(retries + 1):
903
+ try:
904
+ parsed = chain.invoke({
905
+ "N": N,
906
+ "image_description": query_text,
907
+ "statement_lines": statement_lines,
908
+ })
909
+
910
+ if isinstance(parsed, BaseModel):
911
+ parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()
912
+
913
+ sels = parsed.get("selections", []) if isinstance(parsed, dict) else []
914
+ chosen_tags: List[str] = []
915
+ seen = set()
916
+ for item in sels:
917
+ idx = item.get("i") if isinstance(item, dict) else None
918
+ if not isinstance(idx, int) or idx < 1 or idx > N:
919
+ continue
920
+ tag = statements[idx - 1][1]
921
+ if tag not in seen:
922
+ chosen_tags.append(tag)
923
+ seen.add(tag)
924
+
925
+ if log:
926
+ tag_str = ", ".join(chosen_tags) if chosen_tags else "(none)"
927
+ log(f"Stage3s: attempt {att+1} selected {len(chosen_tags)} tags: {tag_str}")
928
+
929
+ return chosen_tags
930
+
931
+ except Exception as e:
932
+ if log:
933
+ log(f"Stage3s: attempt {att+1} error: {e}")
934
+
935
+ if log:
936
+ log(f"Stage3s: gave up after {retries+1} attempts")
937
+ return []
scripts/eval_pipeline.py CHANGED
@@ -151,6 +151,8 @@ class SampleResult:
151
  why_counts: Dict[str, int] = field(default_factory=dict)
152
  # Tag implications
153
  implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
 
 
154
  # Leaf-only metrics (strips implied ancestors from both sides)
155
  leaf_precision: float = 0.0
156
  leaf_recall: float = 0.0
@@ -161,6 +163,7 @@ class SampleResult:
161
  stage1_time: float = 0.0
162
  stage2_time: float = 0.0
163
  stage3_time: float = 0.0
 
164
  # Errors
165
  error: Optional[str] = None
166
 
@@ -196,11 +199,12 @@ def _process_one_sample(
196
  print_lock: threading.Lock,
197
  min_why: Optional[str] = None,
198
  expand_implications: bool = False,
 
199
  ) -> SampleResult:
200
  """Process a single eval sample through the full pipeline. Thread-safe."""
201
  from psq_rag.llm.rewrite import llm_rewrite_prompt
202
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
203
- from psq_rag.llm.select import llm_select_indices
204
  from psq_rag.retrieval.state import get_tag_type_name, expand_tags_via_implications, get_leaf_tags
205
 
206
  def log(msg: str) -> None:
@@ -288,6 +292,19 @@ def _process_one_sample(
288
  why_counts[w] = why_counts.get(w, 0) + 1
289
  result.why_counts = why_counts
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  # Tag implication expansion (post-Stage 3)
292
  if expand_implications and result.selected_tags:
293
  expanded, implied_only = expand_tags_via_implications(result.selected_tags)
@@ -355,11 +372,12 @@ def _process_one_sample(
355
  if gt_char:
356
  char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
357
  impl_info = f" (+{len(result.implied_tags)} implied)" if result.implied_tags else ""
 
358
  with print_lock:
359
  print(
360
  f" [{index+1}] retrieval_recall={result.retrieval_recall:.3f} "
361
  f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
362
- f"selected={len(result.selected_tags)}{impl_info}{char_info} "
363
  f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
364
  )
365
 
@@ -404,6 +422,7 @@ def run_eval(
404
  workers: int = 1,
405
  min_why: Optional[str] = "strong_implied",
406
  expand_implications: bool = False,
 
407
  ) -> List[SampleResult]:
408
 
409
  # Load eval samples — prefer expanded file, fall back to raw
@@ -469,6 +488,7 @@ def run_eval(
469
  skip_rewrite, allow_nsfw, mode, chunk_size,
470
  per_phrase_k, temperature, max_tokens, verbose,
471
  print_lock, min_why, expand_implications,
 
472
  )
473
  results.append(result)
474
  else:
@@ -485,6 +505,7 @@ def run_eval(
485
  skip_rewrite, allow_nsfw, mode, chunk_size,
486
  per_phrase_k, temperature, max_tokens, verbose,
487
  print_lock, min_why, expand_implications,
 
488
  ): i
489
  for i, sample in enumerate(samples)
490
  }
@@ -553,6 +574,7 @@ def print_summary(results: List[SampleResult]) -> None:
553
  avg_over_sel = _safe_avg([r.over_selection_ratio for r in valid])
554
 
555
  avg_implied = sum(len(r.implied_tags) for r in valid) / n
 
556
 
557
  print()
558
  print("Stage 3 - Selection (ALL tags):")
@@ -562,6 +584,8 @@ def print_summary(results: List[SampleResult]) -> None:
562
  print(f" Avg selected tags: {avg_selected:.1f}")
563
  if avg_implied > 0:
564
  print(f" Avg implied tags: {avg_implied:.1f} (added via tag implications)")
 
 
565
  print(f" Avg ground-truth tags:{avg_gt:.1f}")
566
 
567
  # Leaf-only metrics
@@ -673,11 +697,14 @@ def print_summary(results: List[SampleResult]) -> None:
673
 
674
  print()
675
  print("-" * 70)
 
676
  print("Timing (avg per sample):")
677
  print(f" Stage 1 (rewrite): {avg_t1:.2f}s")
678
  print(f" Stage 2 (retrieval): {avg_t2:.2f}s")
679
  print(f" Stage 3 (selection): {avg_t3:.2f}s")
680
- print(f" Total: {avg_t1 + avg_t2 + avg_t3:.2f}s")
 
 
681
  print()
682
 
683
  # Show worst and best F1 samples
@@ -739,6 +766,8 @@ def main(argv=None) -> int:
739
  help="Minimum 'why' confidence to keep (default: strong_implied). Use 'none' to disable filtering.")
740
  ap.add_argument("--expand-implications", action="store_true", default=False,
741
  help="Expand selected tags via tag implication chains (e.g. fox→canine→canid→mammal)")
 
 
742
 
743
  args = ap.parse_args(list(argv) if argv is not None else None)
744
 
@@ -761,18 +790,24 @@ def main(argv=None) -> int:
761
  workers=args.workers,
762
  min_why=min_why_val,
763
  expand_implications=args.expand_implications,
 
764
  )
765
 
766
  print_summary(results)
767
 
768
- # Always save detailed results
 
 
 
 
 
 
 
769
  if args.output:
770
  out_path = Path(args.output)
771
  else:
772
- results_dir = _REPO_ROOT / "data" / "eval_results"
773
- results_dir.mkdir(parents=True, exist_ok=True)
774
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
775
- out_path = results_dir / f"eval_{args.caption_field}_n{args.n}_seed{args.seed}_{timestamp}.jsonl"
776
 
777
  out_path.parent.mkdir(parents=True, exist_ok=True)
778
 
@@ -793,10 +828,65 @@ def main(argv=None) -> int:
793
  "workers": args.workers,
794
  "min_why": args.min_why,
795
  "expand_implications": args.expand_implications,
 
796
  "n_errors": sum(1 for r in results if r.error),
797
  }
798
 
799
  with out_path.open("w", encoding="utf-8") as f:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
800
  f.write(json.dumps(meta, ensure_ascii=False) + "\n")
801
  for r in results:
802
  row = {
@@ -806,44 +896,17 @@ def main(argv=None) -> int:
806
  "rewrite_phrases": r.rewrite_phrases,
807
  "retrieved_tags": sorted(r.retrieved_tags),
808
  "selected_tags": sorted(r.selected_tags),
809
- "retrieval_recall": round(r.retrieval_recall, 4),
810
- "selection_precision": round(r.selection_precision, 4),
811
- "selection_recall": round(r.selection_recall, 4),
812
- "selection_f1": round(r.selection_f1, 4),
813
- # Character tag breakdown
814
  "gt_character_tags": sorted(r.gt_character_tags),
815
  "selected_character_tags": sorted(r.selected_character_tags),
816
- "retrieved_character_tags": sorted(r.retrieved_character_tags),
817
- "char_retrieval_recall": round(r.char_retrieval_recall, 4),
818
- "char_precision": round(r.char_precision, 4),
819
- "char_recall": round(r.char_recall, 4),
820
- "char_f1": round(r.char_f1, 4),
821
- # General tag breakdown
822
  "gt_general_tags": sorted(r.gt_general_tags),
823
  "selected_general_tags": sorted(r.selected_general_tags),
824
- "general_precision": round(r.general_precision, 4),
825
- "general_recall": round(r.general_recall, 4),
826
- "general_f1": round(r.general_f1, 4),
827
- # Diagnostic metrics
828
- "retrieval_precision": round(r.retrieval_precision, 4),
829
- "selection_given_retrieval": round(r.selection_given_retrieval, 4),
830
- "over_selection_ratio": round(r.over_selection_ratio, 2),
831
- "why_counts": r.why_counts,
832
- "implied_tags": sorted(r.implied_tags),
833
- # Leaf metrics
834
- "leaf_precision": round(r.leaf_precision, 4),
835
- "leaf_recall": round(r.leaf_recall, 4),
836
- "leaf_f1": round(r.leaf_f1, 4),
837
- "leaf_selected_count": r.leaf_selected_count,
838
- "leaf_gt_count": r.leaf_gt_count,
839
- # Timing
840
- "stage1_time": round(r.stage1_time, 3),
841
- "stage2_time": round(r.stage2_time, 3),
842
- "stage3_time": round(r.stage3_time, 3),
843
  "error": r.error,
844
  }
845
  f.write(json.dumps(row, ensure_ascii=False) + "\n")
846
- print(f"\nDetailed results saved to: {out_path}")
847
 
848
  return 0
849
 
 
151
  why_counts: Dict[str, int] = field(default_factory=dict)
152
  # Tag implications
153
  implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
154
+ # Structural inference tags (solo/duo/male/female/anthro/biped etc.)
155
+ structural_tags: List[str] = field(default_factory=list)
156
  # Leaf-only metrics (strips implied ancestors from both sides)
157
  leaf_precision: float = 0.0
158
  leaf_recall: float = 0.0
 
163
  stage1_time: float = 0.0
164
  stage2_time: float = 0.0
165
  stage3_time: float = 0.0
166
+ stage3s_time: float = 0.0
167
  # Errors
168
  error: Optional[str] = None
169
 
 
199
  print_lock: threading.Lock,
200
  min_why: Optional[str] = None,
201
  expand_implications: bool = False,
202
+ infer_structural: bool = False,
203
  ) -> SampleResult:
204
  """Process a single eval sample through the full pipeline. Thread-safe."""
205
  from psq_rag.llm.rewrite import llm_rewrite_prompt
206
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
207
+ from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags
208
  from psq_rag.retrieval.state import get_tag_type_name, expand_tags_via_implications, get_leaf_tags
209
 
210
  def log(msg: str) -> None:
 
292
  why_counts[w] = why_counts.get(w, 0) + 1
293
  result.why_counts = why_counts
294
 
295
+ # Structural tag inference (solo/duo/male/female/anthro/biped etc.)
296
+ if infer_structural:
297
+ t0s = time.time()
298
+ structural = llm_infer_structural_tags(
299
+ caption, log=log, temperature=temperature,
300
+ )
301
+ result.stage3s_time = time.time() - t0s
302
+ result.structural_tags = structural
303
+ # Add structural tags not already selected
304
+ for st in structural:
305
+ result.selected_tags.add(st)
306
+ log(f"Structural: {structural}")
307
+
308
  # Tag implication expansion (post-Stage 3)
309
  if expand_implications and result.selected_tags:
310
  expanded, implied_only = expand_tags_via_implications(result.selected_tags)
 
372
  if gt_char:
373
  char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
374
  impl_info = f" (+{len(result.implied_tags)} implied)" if result.implied_tags else ""
375
+ struct_info = f" (+{len(result.structural_tags)} structural)" if result.structural_tags else ""
376
  with print_lock:
377
  print(
378
  f" [{index+1}] retrieval_recall={result.retrieval_recall:.3f} "
379
  f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
380
+ f"selected={len(result.selected_tags)}{impl_info}{struct_info}{char_info} "
381
  f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
382
  )
383
 
 
422
  workers: int = 1,
423
  min_why: Optional[str] = "strong_implied",
424
  expand_implications: bool = False,
425
+ infer_structural: bool = False,
426
  ) -> List[SampleResult]:
427
 
428
  # Load eval samples — prefer expanded file, fall back to raw
 
488
  skip_rewrite, allow_nsfw, mode, chunk_size,
489
  per_phrase_k, temperature, max_tokens, verbose,
490
  print_lock, min_why, expand_implications,
491
+ infer_structural,
492
  )
493
  results.append(result)
494
  else:
 
505
  skip_rewrite, allow_nsfw, mode, chunk_size,
506
  per_phrase_k, temperature, max_tokens, verbose,
507
  print_lock, min_why, expand_implications,
508
+ infer_structural,
509
  ): i
510
  for i, sample in enumerate(samples)
511
  }
 
574
  avg_over_sel = _safe_avg([r.over_selection_ratio for r in valid])
575
 
576
  avg_implied = sum(len(r.implied_tags) for r in valid) / n
577
+ avg_structural = sum(len(r.structural_tags) for r in valid) / n
578
 
579
  print()
580
  print("Stage 3 - Selection (ALL tags):")
 
584
  print(f" Avg selected tags: {avg_selected:.1f}")
585
  if avg_implied > 0:
586
  print(f" Avg implied tags: {avg_implied:.1f} (added via tag implications)")
587
+ if avg_structural > 0:
588
+ print(f" Avg structural tags: {avg_structural:.1f} (inferred via statement agreement)")
589
  print(f" Avg ground-truth tags:{avg_gt:.1f}")
590
 
591
  # Leaf-only metrics
 
697
 
698
  print()
699
  print("-" * 70)
700
+ avg_t3s = sum(r.stage3s_time for r in valid) / n
701
  print("Timing (avg per sample):")
702
  print(f" Stage 1 (rewrite): {avg_t1:.2f}s")
703
  print(f" Stage 2 (retrieval): {avg_t2:.2f}s")
704
  print(f" Stage 3 (selection): {avg_t3:.2f}s")
705
+ if avg_t3s > 0:
706
+ print(f" Stage 3s (structural):{avg_t3s:.2f}s")
707
+ print(f" Total: {avg_t1 + avg_t2 + avg_t3 + avg_t3s:.2f}s")
708
  print()
709
 
710
  # Show worst and best F1 samples
 
766
  help="Minimum 'why' confidence to keep (default: strong_implied). Use 'none' to disable filtering.")
767
  ap.add_argument("--expand-implications", action="store_true", default=False,
768
  help="Expand selected tags via tag implication chains (e.g. fox→canine→canid→mammal)")
769
+ ap.add_argument("--infer-structural", action="store_true", default=False,
770
+ help="Infer structural tags (solo/duo/male/female/anthro/biped) via LLM statement agreement")
771
 
772
  args = ap.parse_args(list(argv) if argv is not None else None)
773
 
 
790
  workers=args.workers,
791
  min_why=min_why_val,
792
  expand_implications=args.expand_implications,
793
+ infer_structural=args.infer_structural,
794
  )
795
 
796
  print_summary(results)
797
 
798
+ # Save results in two formats:
799
+ # 1. Compact metrics JSONL (small, for git / LLM reading)
800
+ # 2. Full detail JSONL (large, for analysis scripts, gitignored)
801
+ results_dir = _REPO_ROOT / "data" / "eval_results"
802
+ results_dir.mkdir(parents=True, exist_ok=True)
803
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
804
+ base_name = f"eval_{args.caption_field}_n{args.n}_seed{args.seed}_{timestamp}"
805
+
806
  if args.output:
807
  out_path = Path(args.output)
808
  else:
809
+ out_path = results_dir / f"{base_name}.jsonl"
810
+ detail_path = results_dir / f"{base_name}_detail.jsonl"
 
 
811
 
812
  out_path.parent.mkdir(parents=True, exist_ok=True)
813
 
 
828
  "workers": args.workers,
829
  "min_why": args.min_why,
830
  "expand_implications": args.expand_implications,
831
+ "infer_structural": args.infer_structural,
832
  "n_errors": sum(1 for r in results if r.error),
833
  }
834
 
835
  with out_path.open("w", encoding="utf-8") as f:
836
+ f.write(json.dumps(meta, ensure_ascii=False) + "\n")
837
+ for r in results:
838
+ # Compact format: metrics + counts + small diff sets (not full tag lists)
839
+ missed_tags = sorted(r.ground_truth_tags - r.selected_tags)
840
+ extra_tags = sorted(r.selected_tags - r.ground_truth_tags)
841
+ row = {
842
+ "id": r.sample_id,
843
+ # Counts (not full lists)
844
+ "n_gt": len(r.ground_truth_tags),
845
+ "n_retrieved": len(r.retrieved_tags),
846
+ "n_selected": len(r.selected_tags),
847
+ "n_implied": len(r.implied_tags),
848
+ "n_structural": len(r.structural_tags),
849
+ # Overall metrics
850
+ "ret_R": round(r.retrieval_recall, 4),
851
+ "P": round(r.selection_precision, 4),
852
+ "R": round(r.selection_recall, 4),
853
+ "F1": round(r.selection_f1, 4),
854
+ # Leaf metrics
855
+ "leaf_P": round(r.leaf_precision, 4),
856
+ "leaf_R": round(r.leaf_recall, 4),
857
+ "leaf_F1": round(r.leaf_f1, 4),
858
+ "n_leaf_sel": r.leaf_selected_count,
859
+ "n_leaf_gt": r.leaf_gt_count,
860
+ # Diagnostic
861
+ "ret_P": round(r.retrieval_precision, 4),
862
+ "sel_given_ret": round(r.selection_given_retrieval, 4),
863
+ "over_sel": round(r.over_selection_ratio, 2),
864
+ "why": r.why_counts,
865
+ # Character metrics (compact)
866
+ "n_gt_char": len(r.gt_character_tags),
867
+ "n_sel_char": len(r.selected_character_tags),
868
+ "char_F1": round(r.char_f1, 4),
869
+ # General metrics (compact)
870
+ "gen_P": round(r.general_precision, 4),
871
+ "gen_R": round(r.general_recall, 4),
872
+ "gen_F1": round(r.general_f1, 4),
873
+ # Diff sets (small — only the errors, not the full lists)
874
+ "missed": missed_tags,
875
+ "extra": extra_tags,
876
+ # Structural tags inferred
877
+ "structural": r.structural_tags,
878
+ # Timing
879
+ "t1": round(r.stage1_time, 2),
880
+ "t2": round(r.stage2_time, 2),
881
+ "t3": round(r.stage3_time, 2),
882
+ "t3s": round(r.stage3s_time, 2),
883
+ "err": r.error,
884
+ }
885
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
886
+ print(f"\nCompact results saved to: {out_path}")
887
+
888
+ # Write full detail file (for analysis scripts)
889
+ with detail_path.open("w", encoding="utf-8") as f:
890
  f.write(json.dumps(meta, ensure_ascii=False) + "\n")
891
  for r in results:
892
  row = {
 
896
  "rewrite_phrases": r.rewrite_phrases,
897
  "retrieved_tags": sorted(r.retrieved_tags),
898
  "selected_tags": sorted(r.selected_tags),
899
+ "implied_tags": sorted(r.implied_tags),
900
+ "structural_tags": r.structural_tags,
901
+ "why_counts": r.why_counts,
 
 
902
  "gt_character_tags": sorted(r.gt_character_tags),
903
  "selected_character_tags": sorted(r.selected_character_tags),
 
 
 
 
 
 
904
  "gt_general_tags": sorted(r.gt_general_tags),
905
  "selected_general_tags": sorted(r.selected_general_tags),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
906
  "error": r.error,
907
  }
908
  f.write(json.dumps(row, ensure_ascii=False) + "\n")
909
+ print(f"Detail results saved to: {detail_path}")
910
 
911
  return 0
912