Claude commited on
Commit
16c5aa4
·
1 Parent(s): 3bb67c1

Add compact eval analysis script for new output format

Browse files

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG

Files changed (1) hide show
  1. scripts/analyze_compact_eval.py +262 -0
scripts/analyze_compact_eval.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Analyze compact eval results (n=50) for patterns in missed and extra tags.
2
+
3
+ Works with the new compact JSONL format (missed/extra diff sets, not full tag lists).
4
+ """
5
+ from __future__ import annotations
6
+ import csv, json, re, sys
7
+ from collections import Counter, defaultdict
8
+ from pathlib import Path
9
+ from typing import Dict, List, Set, Tuple
10
+
11
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
12
+ TYPE_ID_NAMES = {0: "general", 1: "artist", 3: "copyright", 4: "character", 5: "species", 7: "meta"}
13
+
14
+ def load_tag_db():
15
+ tag_type, tag_count = {}, {}
16
+ with (_REPO_ROOT / "fluffyrock_3m.csv").open("r", encoding="utf-8") as f:
17
+ for row in csv.reader(f):
18
+ if len(row) < 3: continue
19
+ tag = row[0].strip()
20
+ try: tid = int(row[1]) if row[1].strip() else -1
21
+ except ValueError: tid = -1
22
+ try: cnt = int(row[2]) if row[2].strip() else 0
23
+ except ValueError: cnt = 0
24
+ tag_type[tag] = tid
25
+ tag_count[tag] = cnt
26
+ return tag_type, tag_count
27
+
28
+ def load_implications():
29
+ impl = defaultdict(list)
30
+ p = _REPO_ROOT / "tag_implications-2023-07-20.csv"
31
+ if not p.is_file(): return impl
32
+ with p.open("r", encoding="utf-8") as f:
33
+ for row in csv.DictReader(f):
34
+ if row.get("status") == "active":
35
+ impl[row["antecedent_name"].strip()].append(row["consequent_name"].strip())
36
+ return dict(impl)
37
+
38
+ def get_leaf_tags(tags, impl):
39
+ non_leaves = set()
40
+ for tag in tags:
41
+ q = [tag]; vis = set()
42
+ while q:
43
+ t = q.pop()
44
+ for p in impl.get(t, []):
45
+ if p not in vis:
46
+ vis.add(p)
47
+ if p in tags: non_leaves.add(p)
48
+ q.append(p)
49
+ return tags - non_leaves
50
+
51
+ # ── Categorization ──
52
+ _TAXONOMY = frozenset({"mammal","canid","canine","canis","felid","feline","felis","ursine","cervid","bovid","equid","equine","mustelid","procyonid","reptile","scalie","avian","bird","fish","marine","arthropod","insect","arachnid","amphibian","primate","rodent","lagomorph","leporid","galliform","gallus_(genus)","phasianid","passerine","oscine","dinosaur","theropod","cetacean","pinniped","chiroptera","marsupial","monotreme","mephitid","suid","suina"})
53
+ _BODY_PLAN = frozenset({"anthro","feral","biped","quadruped","taur","humanoid","semi-anthro","animatronic","robot","machine","plushie","kemono"})
54
+ _POSE = frozenset({"solo","duo","group","trio","standing","sitting","lying","running","walking","flying","swimming","crouching","kneeling","jumping","looking_at_viewer","looking_away","looking_back","looking_up","looking_down","looking_aside","front_view","side_view","back_view","three-quarter_view","from_above","from_below","close-up","portrait","full-length_portrait","hand_on_hip","arms_crossed","all_fours","on_back","on_side","crossed_arms"})
55
+ _COUNT_RE = re.compile(r"^\d+_(fingers|toes|horns|arms|legs|eyes|ears|wings|tails)")
56
+ _STRUCTURAL = frozenset({"solo","duo","trio","group","zero_pictured","anthro","feral","humanoid","biped","quadruped","male","female","ambiguous_gender","intersex"})
57
+
58
+ def categorize(tag, tag_type):
59
+ tid = tag_type.get(tag, -1)
60
+ tn = TYPE_ID_NAMES.get(tid, "unknown")
61
+ if tn == "species": return "species"
62
+ if tn in ("artist","copyright","character","meta"): return tn
63
+ if tag in _TAXONOMY: return "taxonomy"
64
+ if tag in _BODY_PLAN: return "body_plan"
65
+ if tag in _POSE: return "pose/composition"
66
+ if _COUNT_RE.match(tag): return "count/anatomy"
67
+ if tag in ("male","female","intersex","ambiguous_gender","andromorph","gynomorph"): return "gender"
68
+ if any(k in tag for k in ("clothing","clothed","topwear","bottomwear","legwear","handwear","headwear","footwear","shirt","pants","shorts","dress","skirt","jacket","coat","hat","boots","shoes","gloves","socks","stockings","belt","collar","scarf","cape","armor","suit","uniform","costume","outfit")): return "clothing"
69
+ if any(tag.startswith(c+"_") for c in ("red","blue","green","yellow","orange","purple","pink","black","white","grey","gray","brown","tan","cream","gold","silver","teal","cyan","magenta")): return "color/marking"
70
+ if tag.endswith("_coloring") or tag.endswith("_markings") or tag == "markings": return "color/marking"
71
+ if "hair" in tag: return "hair"
72
+ if any(k in tag for k in ("muscle","belly","chest","abs","breast","butt","tail","wing","horn","ear","eye","teeth","fang","claw","paw","hoof","snout","muzzle","tongue","fur","scales","feather","tuft","fluff","mane")): return "body/anatomy"
73
+ if any(k in tag for k in ("smile","grin","frown","expression","blush","angry","happy","sad","crying","laughing","open_mouth","closed_eyes","wink")): return "expression"
74
+ return "other_general"
75
+
76
+ def main():
77
+ path = Path(sys.argv[1]) if len(sys.argv) > 1 else sorted((_REPO_ROOT/"data"/"eval_results").glob("eval_*.jsonl"))[-1]
78
+ tag_type, tag_count = load_tag_db()
79
+ impl = load_implications()
80
+
81
+ samples = []
82
+ with path.open("r", encoding="utf-8") as f:
83
+ for line in f:
84
+ row = json.loads(line)
85
+ if row.get("_meta"):
86
+ print(f"Config: min_why={row.get('min_why')}, expand_impl={row.get('expand_implications')}, "
87
+ f"structural={row.get('infer_structural')}, n={row.get('n_samples')}")
88
+ continue
89
+ if row.get("err"): continue
90
+ samples.append(row)
91
+
92
+ N = len(samples)
93
+ print(f"Analyzing {N} samples from {path.name}\n")
94
+
95
+ # ── 1. Missed tags (GT tags not in selected) ──
96
+ missed_counter = Counter()
97
+ extra_counter = Counter()
98
+ structural_results = []
99
+
100
+ for s in samples:
101
+ for t in s.get("missed", []): missed_counter[t] += 1
102
+ for t in s.get("extra", []): extra_counter[t] += 1
103
+ structural_results.append(s.get("structural", []))
104
+
105
+ # ── REPORT 1: Missed by category ──
106
+ print("=" * 70)
107
+ print(f"MISSED TAGS — GT tags not selected ({sum(missed_counter.values())} total misses, {len(missed_counter)} unique)")
108
+ print("=" * 70)
109
+
110
+ cat_missed = defaultdict(Counter)
111
+ for tag, cnt in missed_counter.items():
112
+ cat_missed[categorize(tag, tag_type)][tag] = cnt
113
+ cat_totals = {c: sum(v.values()) for c, v in cat_missed.items()}
114
+
115
+ for cat in sorted(cat_totals, key=cat_totals.get, reverse=True):
116
+ tags = cat_missed[cat]
117
+ total = cat_totals[cat]
118
+ # Is this category covered by structural inference?
119
+ struct_covered = sum(1 for t in tags if t in _STRUCTURAL)
120
+ struct_note = f" ({struct_covered} structural-coverable)" if struct_covered else ""
121
+ print(f"\n [{cat}] — {total} misses across {len(tags)} unique tags{struct_note}")
122
+ for tag, cnt in tags.most_common(10):
123
+ freq = tag_count.get(tag, 0)
124
+ struct_mark = " *STRUCTURAL*" if tag in _STRUCTURAL else ""
125
+ print(f" {tag:40s} missed {cnt:>2}/{N}{struct_mark} freq={freq:>9,}")
126
+
127
+ # ── REPORT 2: Missed tags that structural should catch ──
128
+ print("\n" + "=" * 70)
129
+ print("STRUCTURAL TAG ACCURACY")
130
+ print("=" * 70)
131
+
132
+ # Which structural tags are still being missed?
133
+ structural_missed = {t: c for t, c in missed_counter.items() if t in _STRUCTURAL}
134
+ if structural_missed:
135
+ print("\n Structural tags STILL missed (Stage 3s should catch these):")
136
+ for t, c in sorted(structural_missed.items(), key=lambda x: -x[1]):
137
+ print(f" {t:30s} missed {c}/{N}")
138
+ else:
139
+ print("\n All structural tags covered!")
140
+
141
+ # What structural tags are over-applied (false positives)?
142
+ structural_extra = {t: c for t, c in extra_counter.items() if t in _STRUCTURAL}
143
+ if structural_extra:
144
+ print(f"\n Structural tags wrongly added (false positives):")
145
+ for t, c in sorted(structural_extra.items(), key=lambda x: -x[1]):
146
+ print(f" {t:30s} extra {c}/{N}")
147
+
148
+ # Per-structural-tag stats from the structural field
149
+ struct_tag_counts = Counter()
150
+ for sl in structural_results:
151
+ for t in sl: struct_tag_counts[t] += 1
152
+ print(f"\n Structural tag selection frequency (how often Stage 3s picks each):")
153
+ for t, c in struct_tag_counts.most_common():
154
+ missed_c = structural_missed.get(t, 0)
155
+ extra_c = structural_extra.get(t, 0)
156
+ print(f" {t:30s} picked {c:>2}/{N} missed_in_GT={missed_c} false_pos={extra_c}")
157
+
158
+ # ── REPORT 3: Extra tags (false positives) by category ──
159
+ print("\n" + "=" * 70)
160
+ print(f"EXTRA TAGS — Selected but not in GT ({sum(extra_counter.values())} total, {len(extra_counter)} unique)")
161
+ print("=" * 70)
162
+
163
+ cat_extra = defaultdict(Counter)
164
+ for tag, cnt in extra_counter.items():
165
+ cat_extra[categorize(tag, tag_type)][tag] = cnt
166
+ cat_extra_totals = {c: sum(v.values()) for c, v in cat_extra.items()}
167
+
168
+ for cat in sorted(cat_extra_totals, key=cat_extra_totals.get, reverse=True):
169
+ tags = cat_extra[cat]
170
+ total = cat_extra_totals[cat]
171
+ print(f"\n [{cat}] — {total} false positives across {len(tags)} unique tags")
172
+ for tag, cnt in tags.most_common(8):
173
+ freq = tag_count.get(tag, 0)
174
+ print(f" {tag:40s} extra {cnt:>2}/{N} freq={freq:>9,}")
175
+
176
+ # ── REPORT 4: Leaf vs non-leaf in missed ──
177
+ print("\n" + "=" * 70)
178
+ print("MISSED: LEAF vs IMPLIED ANCESTORS")
179
+ print("=" * 70)
180
+ all_missed = set(missed_counter.keys())
181
+ leaf_missed = get_leaf_tags(all_missed, impl)
182
+ anc_missed = all_missed - leaf_missed
183
+ leaf_vol = sum(missed_counter[t] for t in leaf_missed)
184
+ anc_vol = sum(missed_counter[t] for t in anc_missed)
185
+ total_vol = leaf_vol + anc_vol
186
+ print(f"\n Unique missed: {len(all_missed)} tags")
187
+ print(f" Leaf: {len(leaf_missed)} ({len(leaf_missed)/max(1,len(all_missed))*100:.0f}%)")
188
+ print(f" Ancestor: {len(anc_missed)} ({len(anc_missed)/max(1,len(all_missed))*100:.0f}%)")
189
+ print(f" Miss volume: {total_vol}")
190
+ print(f" From leaf: {leaf_vol} ({leaf_vol/max(1,total_vol)*100:.0f}%)")
191
+ print(f" From ancestor: {anc_vol} ({anc_vol/max(1,total_vol)*100:.0f}%) — recoverable via implications")
192
+
193
+ # ── REPORT 5: Frequency distribution ──
194
+ print("\n" + "=" * 70)
195
+ print("FREQUENCY DISTRIBUTION OF MISSED TAGS")
196
+ print("=" * 70)
197
+ buckets = {"very_rare (<100)": 0, "rare (100-1k)": 0, "medium (1k-10k)": 0,
198
+ "common (10k-100k)": 0, "very_common (100k+)": 0, "not_in_db": 0}
199
+ for tag in missed_counter:
200
+ freq = tag_count.get(tag, -1)
201
+ if freq < 0: buckets["not_in_db"] += 1
202
+ elif freq < 100: buckets["very_rare (<100)"] += 1
203
+ elif freq < 1000: buckets["rare (100-1k)"] += 1
204
+ elif freq < 10000: buckets["medium (1k-10k)"] += 1
205
+ elif freq < 100000: buckets["common (10k-100k)"] += 1
206
+ else: buckets["very_common (100k+)"] += 1
207
+ for b, c in buckets.items():
208
+ print(f" {b:25s} {c:4d} unique tags ({c/max(1,len(missed_counter))*100:.0f}%)")
209
+
210
+ # ── REPORT 6: Over-selection analysis ──
211
+ print("\n" + "=" * 70)
212
+ print("OVER-SELECTION ANALYSIS")
213
+ print("=" * 70)
214
+ over_sels = [s["over_sel"] for s in samples]
215
+ over_sels.sort()
216
+ print(f"\n Avg over-selection ratio: {sum(over_sels)/N:.2f}x")
217
+ print(f" Median: {over_sels[N//2]:.2f}x")
218
+ print(f" Min: {over_sels[0]:.2f}x")
219
+ print(f" Max: {over_sels[-1]:.2f}x")
220
+ tight = sum(1 for x in over_sels if 0.8 <= x <= 1.5)
221
+ over = sum(1 for x in over_sels if x > 2.0)
222
+ under = sum(1 for x in over_sels if x < 0.5)
223
+ print(f" Tight (0.8-1.5x): {tight}/{N}")
224
+ print(f" Over (>2.0x): {over}/{N}")
225
+ print(f" Under (<0.5x): {under}/{N}")
226
+
227
+ # Worst over-selectors
228
+ worst = sorted(samples, key=lambda s: -s["over_sel"])[:5]
229
+ print(f"\n Worst over-selectors:")
230
+ for s in worst:
231
+ print(f" id={s['id']:>8} over_sel={s['over_sel']:.2f}x selected={s['n_selected']} gt={s['n_gt']} "
232
+ f"F1={s['F1']:.3f} n_extra={len(s.get('extra',[]))}")
233
+
234
+ # ── REPORT 7: Aggregate metrics ──
235
+ print("\n" + "=" * 70)
236
+ print("AGGREGATE METRICS")
237
+ print("=" * 70)
238
+ for metric, key in [("F1", "F1"), ("Precision", "P"), ("Recall", "R"),
239
+ ("Leaf F1", "leaf_F1"), ("Leaf P", "leaf_P"), ("Leaf R", "leaf_R"),
240
+ ("Retrieval Recall", "ret_R")]:
241
+ vals = [s[key] for s in samples]
242
+ avg = sum(vals)/N
243
+ vals.sort()
244
+ med = vals[N//2]
245
+ print(f" {metric:20s} avg={avg:.4f} median={med:.4f} min={vals[0]:.4f} max={vals[-1]:.4f}")
246
+
247
+ # ── REPORT 8: Samples sorted by F1 ──
248
+ print("\n" + "=" * 70)
249
+ print("WORST 10 SAMPLES BY F1")
250
+ print("=" * 70)
251
+ by_f1 = sorted(samples, key=lambda s: s["F1"])
252
+ for s in by_f1[:10]:
253
+ n_missed = len(s.get("missed", []))
254
+ n_extra = len(s.get("extra", []))
255
+ print(f" id={s['id']:>8} F1={s['F1']:.3f} P={s['P']:.3f} R={s['R']:.3f} "
256
+ f"gt={s['n_gt']} sel={s['n_selected']} missed={n_missed} extra={n_extra} "
257
+ f"structural={s.get('structural',[])} over_sel={s['over_sel']:.2f}x")
258
+
259
+ print()
260
+
261
+ if __name__ == "__main__":
262
+ main()