Prompt_Squirrel_RAG / scripts /analyze_probe_informativeness.py
Food Desert
Consolidate probe configs and eval artifacts on main
6e50f4d
Raw
History Blame Contribute Delete
17.9 kB
"""Rank candidate probe tags by informativeness before any LLM queries.
This is an offline metric pass combining:
- entropy / information gain from sample co-occurrence,
- lift against active groups/categories,
- reduced TF-IDF semantic focus against group centroids.
Compact outputs (overwrite in place):
- data/analysis/probe_informativeness.csv
- data/analysis/probe_informativeness_summary.json
"""
from __future__ import annotations
import csv
import json
import math
from collections import Counter
from pathlib import Path
from typing import Dict, List, Set, Tuple
import numpy as np
from psq_rag.retrieval.state import get_tfidf_tag_vectors
# Repository root: two levels above this file (the script sits one directory
# below the root).
REPO = Path(__file__).resolve().parents[1]

# ---- Inputs -----------------------------------------------------------------
# Tag frequency table consumed by load_counts (tag in column 0, count in column 2).
COUNTS_CSV = REPO.joinpath("fluffyrock_3m.csv")
# Per-image ground-truth tags consumed by load_image_tags.
SAMPLE_JSONL = REPO.joinpath(
    "data", "eval_samples", "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
)
WIKI_GROUPS_JSON = REPO.joinpath("data", "tag_groups.json")
REGISTRY_CSV = REPO.joinpath("data", "category_registry.csv")
# Policy file whose enabled 'ignored_*' rows exclude wiki groups.
CATEGORY_TAG_GROUP_MAP_CSV = REPO.joinpath("data", "analysis", "category_tag_group_map.csv")

# ---- Outputs (overwritten in place) ------------------------------------------
OUT_CSV = REPO.joinpath("data", "analysis", "probe_informativeness.csv")
OUT_SUMMARY = REPO.joinpath("data", "analysis", "probe_informativeness_summary.json")

# ---- Thresholds and hyper-parameters -----------------------------------------
MIN_COUNT = 200  # sample tags with a COUNTS_CSV count below this are dropped
MIN_PROBE_IMAGES = 5  # a tag must appear in at least this many sample images to be scored
MIN_GROUP_IMAGES = 20  # a group must cover at least this many sample images to be active
SOFTMAX_TAU = 0.15  # softmax temperature over tag-to-centroid similarities
MMR_LAMBDA = 0.35  # weight of the mutual-information redundancy penalty in the shortlist
MMR_TOP_POOL = 120  # how many top combined-score probes the shortlist draws from
MMR_K = 15  # length of the diversified shortlist

# Tags that needs_glossary always flags as needing a glossary entry.
DOMAIN_JARGON = {
    "solo", "duo", "trio", "anthro", "feral",
    "gynomorph", "andromorph", "maleherm",
    "topwear", "bottomwear", "legwear",
    "handwear", "headwear", "footwear",
    "leporid", "canid", "canis", "felid",
    "felis", "equid", "haplorhine",
    "zero_pictured", "male/female", "male/male", "female/female",
}
def load_counts(path: Path) -> Dict[str, int]:
    """Read a tag-count CSV into ``{tag: count}``.

    The tag is taken from column 0 and the count from column 2. Rows with
    fewer than three columns are skipped. An empty or non-integer count
    (including a header row, if present) is stored as 0. When a tag repeats,
    the last row wins.
    """
    tag_counts: Dict[str, int] = {}
    with path.open("r", encoding="utf-8", newline="") as handle:
        for record in csv.reader(handle):
            if len(record) >= 3:
                tag, raw_count = record[0], record[2]
                try:
                    tag_counts[tag] = int(raw_count) if raw_count else 0
                except ValueError:
                    tag_counts[tag] = 0
    return tag_counts
def load_image_tags(path: Path, counts: Dict[str, int], min_count: int) -> List[Set[str]]:
    """Load one tag set per sample image from a JSONL file.

    Each non-blank line must be a JSON object. Its
    ``tags_ground_truth_categorized`` field is a mapping of
    ``category -> list of tags``, given either as a JSON-encoded string or as
    an already-decoded object.

    Only string tags with ``counts.get(tag, 0) >= min_count`` are kept. Images
    left with no tags are omitted from the result.

    Lines with a missing, empty, undecodable or non-mapping field are skipped.
    Blank lines are skipped too, rather than raising from ``json.loads``.
    A non-blank line that is not valid JSON still raises
    ``json.JSONDecodeError``.
    """
    rows: List[Set[str]] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Blank / trailing lines are common in JSONL; json.loads("") raises.
                continue
            obj = json.loads(line)
            raw = obj.get("tags_ground_truth_categorized", "")
            if not raw:
                continue
            if isinstance(raw, str):
                try:
                    d = json.loads(raw)
                except ValueError:  # json.JSONDecodeError subclasses ValueError
                    continue
            else:
                # Field already decoded (e.g. stored as a nested JSON object).
                d = raw
            if not isinstance(d, dict):
                continue
            tags: Set[str] = {
                t
                for vals in d.values()
                if isinstance(vals, list)
                for t in vals
                if isinstance(t, str) and counts.get(t, 0) >= min_count
            }
            if tags:
                rows.append(tags)
    return rows
def load_excluded_wiki_groups_from_policy(path: Path) -> Set[str]:
    """Return the wiki group names excluded by the tag-group map file.

    A row excludes its ``tag_group`` when all of the following hold:

    - ``enabled`` (stripped) is one of ``1`` / ``true`` / ``True``;
    - ``category_name`` (stripped, case-insensitive) starts with ``ignored_``;
    - ``tag_group`` (stripped) is non-empty.

    A missing file yields an empty set.
    """
    if not path.is_file():
        return set()
    truthy = {"1", "true", "True"}
    excluded: Set[str] = set()
    with path.open("r", encoding="utf-8", newline="") as handle:
        for record in csv.DictReader(handle):
            enabled = (record.get("enabled") or "").strip()
            category = (record.get("category_name") or "").strip().lower()
            group = (record.get("tag_group") or "").strip()
            if enabled in truthy and group and category.startswith("ignored_"):
                excluded.add(group)
    return excluded
def load_groups() -> Tuple[Dict[str, Set[str]], Set[str]]:
    """Collect the tag groups used as prediction targets.

    Returns ``(groups, excluded_wiki_groups)``:

    - ``groups["wiki:<name>"]`` holds the non-empty string tags of each
      list-valued entry in WIKI_GROUPS_JSON, skipping names excluded by
      CATEGORY_TAG_GROUP_MAP_CSV.
    - ``groups["cat:<name>"]`` holds the tags of every REGISTRY_CSV row whose
      ``category_enabled`` is ``1`` / ``true`` / ``True`` and whose category
      and tag are both non-empty.
    - ``excluded_wiki_groups`` is the exclusion set that was applied.
    """
    excluded_wiki_groups = load_excluded_wiki_groups_from_policy(CATEGORY_TAG_GROUP_MAP_CSV)
    with WIKI_GROUPS_JSON.open("r", encoding="utf-8") as handle:
        wiki = json.load(handle)
    groups: Dict[str, Set[str]] = {
        f"wiki:{name}": {t for t in members if isinstance(t, str) and t}
        for name, members in wiki.items()
        if name not in excluded_wiki_groups and isinstance(members, list)
    }
    truthy = {"1", "true", "True"}
    with REGISTRY_CSV.open("r", encoding="utf-8", newline="") as handle:
        for record in csv.DictReader(handle):
            if (record.get("category_enabled") or "").strip() not in truthy:
                continue
            category = (record.get("category_name") or "").strip()
            tag = (record.get("tag") or "").strip()
            if category and tag:
                groups.setdefault(f"cat:{category}", set()).add(tag)
    return groups, excluded_wiki_groups
def needs_glossary(tag: str) -> bool:
    """Heuristically decide whether ``tag`` should be glossed in a prompt.

    Returns True if any of the following hold:

    - the tag is in DOMAIN_JARGON;
    - it contains ``/``, ``(`` or ``)``;
    - it contains a digit;
    - it ends with ``id`` or ``ine``.
    """
    return (
        tag in DOMAIN_JARGON
        or any(ch in tag for ch in "/()")
        or any(ch.isdigit() for ch in tag)
        # Taxonomy-ish suffixes often need disambiguation in prompts.
        or tag.endswith(("id", "ine"))
    )
def infer_probe_bundle(tag: str, semantic_top_group: str, strongest_group: str) -> str:
    """Map ``tag`` to a coarse probe-bundle label using ordered rules.

    Rules are tried in order and the first hit wins:

    1. exact membership of ``tag`` in small hand-picked sets;
    2. substring hits of keywords inside ``tag`` (e.g. ``"big_breasts"``
       contains ``"breast"``);
    3. case-insensitive substring hints in the two group names;
    4. ``"other"``.
    """
    exact_rules = (
        ("count_cardinality", {"solo", "duo", "trio", "group", "zero_pictured"}),
        ("body_type_presence", {"anthro", "feral", "humanoid", "biped", "quadruped"}),
        ("clothing_state", {"clothed", "clothing", "topless", "bottomless",
                            "nude", "barefoot", "topwear", "bottomwear"}),
        ("text_symbols", {"text", "dialogue", "<3"}),
    )
    for bundle, names in exact_rules:
        if tag in names:
            return bundle

    # NOTE(review): plain substring matching has false positives, e.g. "button"
    # contains "butt" and "texture" contains "text".
    substring_rules = (
        ("species_taxonomy", ("canid", "canis", "felid", "felis", "equid", "leporid",
                              "species", "mammal", "bird", "bear", "unicorn",
                              "reptile", "dragon")),
        ("body_shape_breasts", ("breast", "thigh", "hips", "curvy", "muscular",
                                "overweight", "chubby", "butt")),
        ("gaze_expression", ("look", "gaze", "eyes", "smile", "blush",
                             "open_mouth", "eyes_closed")),
        ("text_symbols", ("text", "dialogue", "logo", "symbol")),
        ("scene_pose", ("background", "outside", "inside", "indoors", "outdoors",
                        "standing", "sitting")),
    )
    for bundle, keywords in substring_rules:
        if any(keyword in tag for keyword in keywords):
            return bundle

    group_hint = f"{semantic_top_group} {strongest_group}".lower()
    if "cat:clothing" in group_hint or "wiki:clothes" in group_hint:
        return "clothing_state"
    if "cat:count" in group_hint:
        return "count_cardinality"
    return "other"
def entropy_binary(p: float) -> float:
    """Return the Shannon entropy, in bits, of a Bernoulli(``p``) variable.

    ``p`` is first clamped to ``[1e-12, 1 - 1e-12]`` so both logarithms stay
    finite at the endpoints.
    """
    eps = 1e-12
    prob = min(max(p, eps), 1 - eps)
    comp = 1 - prob
    weighted_logs = prob * math.log2(prob) + comp * math.log2(comp)
    return -weighted_logs
def softmax(x: np.ndarray, tau: float) -> np.ndarray:
    """Return the softmax of ``x / tau``.

    ``tau`` is floored at 1e-6 to avoid division by zero. The maximum is
    subtracted before exponentiating for numerical stability.
    """
    scaled = x / max(tau, 1e-6)
    weights = np.exp(scaled - np.max(scaled))
    total = max(np.sum(weights), 1e-12)
    return weights / total
def binary_mi(a_idx: Set[int], b_idx: Set[int], n: int) -> float:
    """Mutual information, in bits, between two binary indicator variables.

    Args:
        a_idx: indices of the items where variable A is 1.
        b_idx: indices of the items where variable B is 1.
        n: total number of items. Indices are expected to lie in
            ``range(n)``; the (0, 0) cell is derived as ``n`` minus the other
            three cells.

    Returns:
        The MI in bits, clamped at 0.0 to absorb floating-point noise. Returns
        0.0 when ``n <= 0`` instead of dividing by zero.
    """
    if n <= 0:
        return 0.0
    n11 = len(a_idx & b_idx)
    # |A - B| and |B - A| follow from the sizes; no need to build the sets.
    n10 = len(a_idx) - n11
    n01 = len(b_idx) - n11
    n00 = n - n11 - n10 - n01
    pa = (n11 + n10) / n
    pb = (n11 + n01) / n
    # (joint count, P(A = a), P(B = b)) for cells (1,1), (1,0), (0,1), (0,0).
    cells = (
        (n11, pa, pb),
        (n10, pa, 1 - pb),
        (n01, 1 - pa, pb),
        (n00, 1 - pa, 1 - pb),
    )
    mi = 0.0
    for count, qa, qb in cells:
        p = count / n
        if p <= 0:
            continue
        mi += p * math.log2(p / max(qa * qb, 1e-12))
    return max(mi, 0.0)
def _index_probe_images(image_tags: List[Set[str]]) -> Dict[str, Set[int]]:
    """Map every tag to the indices (into ``image_tags``) of images carrying it."""
    probe_to_images: Dict[str, Set[int]] = {}
    for i, tags in enumerate(image_tags):
        for t in tags:
            probe_to_images.setdefault(t, set()).add(i)
    return probe_to_images


def _index_group_images(
    groups_all: Dict[str, Set[str]],
    probe_to_images: Dict[str, Set[int]],
    min_group_images: int,
) -> Dict[str, Set[int]]:
    """Map each group to the images containing at least one of its member tags.

    The result is the union of the members' posting sets from
    ``probe_to_images``. This equals scanning every image for every group but
    avoids that O(groups x images) pass. Groups covering fewer than
    ``min_group_images`` images are dropped.
    """
    group_to_images: Dict[str, Set[int]] = {}
    for g, members in groups_all.items():
        idxs: Set[int] = set()
        for t in members:
            hits = probe_to_images.get(t)
            if hits:
                idxs |= hits  # idxs is a fresh set, so posting sets are not mutated
        if len(idxs) >= min_group_images:
            group_to_images[g] = idxs
    return group_to_images


def _semantic_centroids(
    groups_all: Dict[str, Set[str]],
    active_groups: List[str],
    mat,
    tag_to_row: Dict[str, int],
) -> Dict[str, np.ndarray]:
    """Return a unit-normalised mean embedding per active group.

    ``mat`` is the row-indexable matrix from ``get_tfidf_tag_vectors()``.
    Groups with fewer than two embedded members are skipped, as are groups
    whose mean vector has zero norm.
    """
    group_centroids: Dict[str, np.ndarray] = {}
    for g in active_groups:
        rows = [tag_to_row[t] for t in groups_all[g] if t in tag_to_row]
        if len(rows) < 2:
            continue
        c = mat[rows].mean(axis=0)
        norm = np.linalg.norm(c)
        if norm > 0:
            group_centroids[g] = c / norm
    return group_centroids


def _mmr_shortlist(
    ranked_tags: List[str],
    probe_scores: Dict[str, float],
    probe_to_images: Dict[str, Set[int]],
    n_images: int,
    k: int,
    lam: float,
) -> List[str]:
    """Greedily pick up to ``k`` tags from ``ranked_tags``, MMR-style.

    Each candidate's value is its relevance (``probe_scores``) minus ``lam``
    times its mean binary MI with the tags already selected.

    The MI of each remaining candidate against a newly selected tag is computed
    once and cached, in selection order. This gives the same values as
    recomputing against every selected tag on every round.
    """
    pool = list(ranked_tags)
    selected: List[str] = []
    mi_to_selected: Dict[str, List[float]] = {t: [] for t in pool}
    while len(selected) < k and pool:
        best_tag = None
        best_val = -1e9
        for t in pool:
            rel = probe_scores.get(t, 0.0)
            if selected:
                val = rel - lam * float(np.mean(mi_to_selected[t]))
            else:
                val = rel
            if val > best_val:
                best_val = val
                best_tag = t
        if best_tag is None:
            break
        selected.append(best_tag)
        pool.remove(best_tag)
        chosen_idxs = probe_to_images[best_tag]
        for t in pool:
            mi_to_selected[t].append(binary_mi(probe_to_images[t], chosen_idxs, n_images))
    return selected


def _bundle_scores(rows_out: List[Dict[str, str]]) -> List[Dict[str, object]]:
    """Score each suggested bundle by its top-5 actionable tags, best bundle first."""
    by_bundle: Dict[str, List[Dict[str, str]]] = {}
    for r in rows_out:
        by_bundle.setdefault(r["suggested_probe_bundle"], []).append(r)
    bundle_scores: List[Dict[str, object]] = []
    for b, items in by_bundle.items():
        top_items = sorted(items, key=lambda x: float(x["actionable_score"]), reverse=True)[:5]
        score = sum(float(x["actionable_score"]) for x in top_items)
        glossary_rate = (
            sum(1 for x in top_items if x["needs_glossary"] == "1") / len(top_items)
            if top_items
            else 0.0
        )
        bundle_scores.append(
            {
                "bundle": b,
                "bundle_score_top5_actionable": round(score, 6),
                "top_tags": [x["tag"] for x in top_items],
                "glossary_rate_top5": round(glossary_rate, 3),
            }
        )
    bundle_scores.sort(key=lambda x: x["bundle_score_top5_actionable"], reverse=True)
    return bundle_scores


def main() -> None:
    """Rank candidate probe tags and write OUT_CSV and OUT_SUMMARY.

    Pipeline:

    1. Load sample tag sets and groups.
    2. Index images per tag and per active group.
    3. Score every probe: information gain, lift and semantic focus.
    4. Pick a diversified MMR shortlist.
    5. Write the per-probe CSV and a JSON summary.

    Raises:
        RuntimeError: if no sample image keeps any tag, or if no group covers
            at least MIN_GROUP_IMAGES images.
    """
    counts = load_counts(COUNTS_CSV)
    image_tags = load_image_tags(SAMPLE_JSONL, counts, MIN_COUNT)
    n_images = len(image_tags)
    if n_images == 0:
        raise RuntimeError("No image tags loaded.")
    groups_all, excluded_wiki_groups = load_groups()

    probe_to_images = _index_probe_images(image_tags)
    group_to_images = _index_group_images(groups_all, probe_to_images, MIN_GROUP_IMAGES)
    active_groups = sorted(group_to_images.keys())
    if not active_groups:
        raise RuntimeError("No active groups after MIN_GROUP_IMAGES filter.")

    # Semantic centroids for active groups.
    vec = get_tfidf_tag_vectors()
    mat = vec["reduced_matrix_norm"]
    tag_to_row = vec["tag_to_row_index"]
    group_centroids = _semantic_centroids(groups_all, active_groups, mat, tag_to_row)
    semantic_groups = sorted(group_centroids.keys())
    C = np.stack([group_centroids[g] for g in semantic_groups], axis=0) if semantic_groups else None

    # Group sizes do not depend on the probe; compute them once.
    group_sizes = {g: len(group_to_images[g]) for g in active_groups}
    baseline_group_probs = {g: group_sizes[g] / n_images for g in active_groups}
    baseline_top5_mass = sum(sorted(baseline_group_probs.values(), reverse=True)[:5])

    rows_out: List[Dict[str, str]] = []
    probe_scores: Dict[str, float] = {}
    for p, p_idxs in probe_to_images.items():
        n_p = len(p_idxs)
        if n_p < MIN_PROBE_IMAGES:
            continue
        q = n_p / n_images
        if q <= 0.0 or q >= 1.0:
            continue
        n_absent = n_images - n_p  # images that do not carry the probe
        ig_sum = 0.0
        ig_vals: List[float] = []
        mean_abs_log_lift = 0.0
        lifts: Dict[str, float] = {}
        p1_group_probs: Dict[str, float] = {}
        for g in active_groups:
            n_g = group_sizes[g]
            n_pg = len(p_idxs & group_to_images[g])
            pg = n_g / n_images
            pg1 = n_pg / n_p
            # P(group | probe absent) from counts: |g| - |g & p| images are in the
            # group without the probe. Avoids rebuilding the O(n_images)
            # complement set for every probe x group pair.
            pg0 = (n_g - n_pg) / n_absent if n_absent > 0 else pg
            # Information gain (bits) about group membership from observing the probe.
            ig = entropy_binary(pg) - (q * entropy_binary(pg1) + (1 - q) * entropy_binary(pg0))
            ig = max(ig, 0.0)
            ig_vals.append(ig)
            ig_sum += ig
            lift = (pg1 + 1e-9) / (pg + 1e-9)
            lifts[g] = lift
            p1_group_probs[g] = pg1
            mean_abs_log_lift += abs(math.log2(lift + 1e-12))
        mean_abs_log_lift /= len(active_groups)
        ig_mean = float(np.mean(ig_vals)) if ig_vals else 0.0
        top5_mass_p1 = sum(sorted(p1_group_probs.values(), reverse=True)[:5])
        delta_top5_mass = top5_mass_p1 - baseline_top5_mass
        strongest_group_name, strongest_group_lift = max(
            lifts.items(), key=lambda kv: abs(math.log2(kv[1] + 1e-12))
        )

        semantic_top_group = ""
        semantic_margin = 0.0
        semantic_entropy_norm = 1.0
        if C is not None and p in tag_to_row:
            sims = C @ mat[tag_to_row[p]]
            order = np.argsort(sims)[::-1]
            i1 = int(order[0])
            i2 = int(order[1]) if len(order) > 1 else i1
            semantic_top_group = semantic_groups[i1]
            semantic_margin = float(sims[i1] - sims[i2])
            probs = softmax(sims, SOFTMAX_TAU)
            h = -float(np.sum(probs * np.log2(np.maximum(probs, 1e-12))))
            semantic_entropy_norm = h / math.log2(len(probs)) if len(probs) > 1 else 0.0

        prevalence_balance = math.sqrt(q * (1 - q))
        focus = max(0.0, 1.0 - semantic_entropy_norm)
        combined_score = ig_sum * prevalence_balance * (0.5 + 0.5 * focus)
        probe_scores[p] = combined_score
        rows_out.append(
            {
                "tag": p,
                "sample_occurrences": str(n_p),
                "fluffyrock_count": str(counts.get(p, 0)),
                "prevalence": f"{q:.6f}",
                "ig_sum_bits": f"{ig_sum:.6f}",
                "ig_mean_bits": f"{ig_mean:.6f}",
                "delta_top5_mass": f"{delta_top5_mass:.6f}",
                "mean_abs_log2_lift": f"{mean_abs_log_lift:.6f}",
                "semantic_top_group": semantic_top_group,
                "semantic_margin": f"{semantic_margin:.6f}",
                "semantic_entropy_norm": f"{semantic_entropy_norm:.6f}",
                "strongest_group_by_lift": strongest_group_name,
                "strongest_group_lift": f"{strongest_group_lift:.6f}",
                "suggested_probe_bundle": infer_probe_bundle(p, semantic_top_group, strongest_group_name),
                "needs_glossary": "1" if needs_glossary(p) else "0",
                "combined_score": f"{combined_score:.6f}",
            }
        )

    # Add an actionability score that downweights very common probes and favors
    # probes that noticeably reshape top-group mass.
    for r in rows_out:
        q = float(r["prevalence"])
        ig = float(r["ig_sum_bits"])
        delta_top5 = max(0.0, float(r["delta_top5_mass"]))
        semantic_focus = max(0.0, 1.0 - float(r["semantic_entropy_norm"]))
        prevalence_penalty = max(0.0, 1.0 - abs(2 * q - 1.0))
        actionable_score = ig * prevalence_penalty * delta_top5 * (0.5 + 0.5 * semantic_focus)
        r["actionable_score"] = f"{actionable_score:.6f}"
    rows_out.sort(key=lambda r: float(r["combined_score"]), reverse=True)

    # Diversified shortlist via MMR-like greedy on the top pool.
    selected = _mmr_shortlist(
        [r["tag"] for r in rows_out[:MMR_TOP_POOL]],
        probe_scores,
        probe_to_images,
        n_images,
        MMR_K,
        MMR_LAMBDA,
    )

    fieldnames = [
        "tag",
        "sample_occurrences",
        "fluffyrock_count",
        "prevalence",
        "ig_sum_bits",
        "ig_mean_bits",
        "delta_top5_mass",
        "mean_abs_log2_lift",
        "semantic_top_group",
        "semantic_margin",
        "semantic_entropy_norm",
        "strongest_group_by_lift",
        "strongest_group_lift",
        "suggested_probe_bundle",
        "needs_glossary",
        "combined_score",
        "actionable_score",
    ]
    OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    with OUT_CSV.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows_out)

    # Aggregate bundle-level utility using top actionable tags per bundle.
    bundle_scores = _bundle_scores(rows_out)
    top_actionable = sorted(rows_out, key=lambda r: float(r["actionable_score"]), reverse=True)
    top_mid_prevalence = [
        r for r in top_actionable if 0.03 <= float(r["prevalence"]) <= 0.35
    ][:40]
    summary = {
        "config": {
            "min_count": MIN_COUNT,
            "min_probe_images": MIN_PROBE_IMAGES,
            "min_group_images": MIN_GROUP_IMAGES,
            "softmax_tau": SOFTMAX_TAU,
            "mmr_lambda": MMR_LAMBDA,
            "mmr_top_pool": MMR_TOP_POOL,
            "mmr_k": MMR_K,
        },
        "n_images": n_images,
        "n_candidate_probes": len(rows_out),
        "n_active_groups": len(active_groups),
        "excluded_wiki_groups": sorted(excluded_wiki_groups),
        "top_probes_by_combined_score": rows_out[:25],
        "top_probes_by_actionable_score": top_actionable[:25],
        "top_actionable_mid_prevalence_for_manual_review": top_mid_prevalence,
        "bundle_scores": bundle_scores[:20],
        "diversified_probe_shortlist": selected,
        "outputs": {
            "csv": str(OUT_CSV),
            "summary_json": str(OUT_SUMMARY),
        },
    }
    OUT_SUMMARY.parent.mkdir(parents=True, exist_ok=True)
    with OUT_SUMMARY.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f"Images: {n_images}")
    print(f"Active groups: {len(active_groups)}")
    print(f"Candidate probes: {len(rows_out)}")
    print(f"Top probes: {[r['tag'] for r in rows_out[:10]]}")
    print(f"Diversified shortlist: {selected}")
    print(f"Outputs: {OUT_CSV}, {OUT_SUMMARY}")


if __name__ == "__main__":
    main()