Spaces:
Running
Add ranking metrics infrastructure to eval pipeline
Browse files
Implements Recall@K, Precision@K, and MRR for categorized suggestions.
Changes to eval_pipeline.py:
- Add categorized_suggestions field to SampleResult
- Generate categorized suggestions after LLM selection
- Store top-20 suggestions per category (plus top-50 uncategorized "other" suggestions) in detail results
Changes to eval_categorized.py:
- Compute ranking metrics from categorized_suggestions
- Show Recall@K, Precision@K, MRR alongside P/R/F1
- Track coverage (how many GT tags appear in suggestions)
- Gracefully handle old results without categorized_suggestions
Ranking metrics:
- Recall@K: Fraction of ground truth tags found in top-K suggestions
- Precision@K: Fraction of top-K suggestions that are correct
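- MRR: Mean reciprocal rank (1/rank, 1-indexed) of the GT tags that appear in the suggestion list; GT tags missing from the suggestions are excluded, and the per-sample value is averaged only over samples where at least one GT tag was found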
- Coverage: Total GT tags that appear anywhere in suggestions
To generate results with ranking metrics:
python scripts/eval_pipeline.py --n 50 --expand-implications
To evaluate with ranking metrics:
python scripts/eval_categorized.py --results <results.jsonl> --k 5
https://claude.ai/code/session_015ZwE7a5E6YVTrMpuB2pXX7
- scripts/eval_categorized.py +61 -13
- scripts/eval_pipeline.py +30 -0
|
@@ -109,9 +109,11 @@ class CategoryMetrics:
|
|
| 109 |
|
| 110 |
# Ranking metrics (for suggestions)
|
| 111 |
total_gt_tags: int = 0 # Total ground truth tags across all samples
|
| 112 |
-
|
| 113 |
-
|
|
|
|
| 114 |
mrr: float = 0.0 # Mean reciprocal rank
|
|
|
|
| 115 |
|
| 116 |
@property
|
| 117 |
def precision(self) -> float:
|
|
@@ -217,13 +219,42 @@ def compute_category_metrics(
|
|
| 217 |
|
| 218 |
cat_metric.total_gt_tags += len(gt_cat_tags)
|
| 219 |
|
| 220 |
-
#
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
return metrics
|
| 229 |
|
|
@@ -231,9 +262,17 @@ def compute_category_metrics(
|
|
| 231 |
def print_category_metrics(
|
| 232 |
metrics: Dict[str, CategoryMetrics],
|
| 233 |
categories: Dict[str, TagCategory],
|
|
|
|
|
|
|
| 234 |
):
|
| 235 |
"""
|
| 236 |
Print metrics organized by importance.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
"""
|
| 238 |
# Group by importance level
|
| 239 |
by_importance = defaultdict(list)
|
|
@@ -257,15 +296,24 @@ def print_category_metrics(
|
|
| 257 |
print(f" Constraint: {category.constraint.value}")
|
| 258 |
print(f" Ground truth tags: {cat_metric.total_gt_tags}")
|
| 259 |
|
| 260 |
-
#
|
| 261 |
if category.constraint.value == "exactly_one":
|
| 262 |
print(f" Accuracy: {cat_metric.accuracy:.3f}")
|
| 263 |
-
|
| 264 |
-
# For others, show P/R/F1
|
| 265 |
print(f" Precision: {cat_metric.precision:.3f}")
|
| 266 |
print(f" Recall: {cat_metric.recall:.3f}")
|
| 267 |
print(f" F1: {cat_metric.f1:.3f}")
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
# Show raw counts for debugging
|
| 270 |
print(f" (TP={cat_metric.tp}, FP={cat_metric.fp}, FN={cat_metric.fn}, TN={cat_metric.tn})")
|
| 271 |
|
|
@@ -374,7 +422,7 @@ def main():
|
|
| 374 |
)
|
| 375 |
|
| 376 |
# Print results
|
| 377 |
-
print_category_metrics(metrics, categories)
|
| 378 |
|
| 379 |
|
| 380 |
if __name__ == "__main__":
|
|
|
|
| 109 |
|
| 110 |
# Ranking metrics (for suggestions)
|
| 111 |
total_gt_tags: int = 0 # Total ground truth tags across all samples
|
| 112 |
+
found_in_suggestions: int = 0 # GT tags that appear anywhere in suggestions
|
| 113 |
+
recall_at_k: float = 0.0 # Fraction of GT tags found in top-K
|
| 114 |
+
precision_at_k: float = 0.0 # Fraction of top-K that are correct
|
| 115 |
mrr: float = 0.0 # Mean reciprocal rank
|
| 116 |
+
mrr_count: int = 0 # Number of GT tags used for MRR calculation
|
| 117 |
|
| 118 |
@property
|
| 119 |
def precision(self) -> float:
|
|
|
|
| 219 |
|
| 220 |
cat_metric.total_gt_tags += len(gt_cat_tags)
|
| 221 |
|
| 222 |
+
# Ranking metrics (if categorized_suggestions are available)
|
| 223 |
+
categorized_suggestions = result.get('categorized_suggestions', {})
|
| 224 |
+
cat_suggestions = categorized_suggestions.get(cat_name, [])
|
| 225 |
+
|
| 226 |
+
if cat_suggestions and gt_cat_tags:
|
| 227 |
+
# Convert to dict for easier lookup: {tag: rank}
|
| 228 |
+
# Suggestions are already sorted by score, so index = rank (0-indexed)
|
| 229 |
+
suggestion_ranks = {tag: rank for rank, (tag, score) in enumerate(cat_suggestions)}
|
| 230 |
+
|
| 231 |
+
# Count how many GT tags appear in suggestions (at any rank)
|
| 232 |
+
found_count = sum(1 for gt_tag in gt_cat_tags if gt_tag in suggestion_ranks)
|
| 233 |
+
cat_metric.found_in_suggestions += found_count
|
| 234 |
+
|
| 235 |
+
# Recall@K: fraction of GT tags in top-K
|
| 236 |
+
top_k_tags = {tag for tag, score in cat_suggestions[:k]}
|
| 237 |
+
recall_at_k_count = len(gt_cat_tags & top_k_tags)
|
| 238 |
+
|
| 239 |
+
# Precision@K: fraction of top-K that are in GT
|
| 240 |
+
if len(top_k_tags) > 0:
|
| 241 |
+
precision_at_k_count = len(top_k_tags & gt_cat_tags)
|
| 242 |
+
else:
|
| 243 |
+
precision_at_k_count = 0
|
| 244 |
+
|
| 245 |
+
# MRR: mean of 1/rank for each GT tag found in suggestions
|
| 246 |
+
reciprocal_ranks = []
|
| 247 |
+
for gt_tag in gt_cat_tags:
|
| 248 |
+
if gt_tag in suggestion_ranks:
|
| 249 |
+
rank = suggestion_ranks[gt_tag]
|
| 250 |
+
reciprocal_ranks.append(1.0 / (rank + 1)) # +1 because rank is 0-indexed
|
| 251 |
+
|
| 252 |
+
# Accumulate for averaging later
|
| 253 |
+
cat_metric.recall_at_k += recall_at_k_count / len(gt_cat_tags) if gt_cat_tags else 0
|
| 254 |
+
cat_metric.precision_at_k += precision_at_k_count / min(k, len(cat_suggestions)) if cat_suggestions else 0
|
| 255 |
+
if reciprocal_ranks:
|
| 256 |
+
cat_metric.mrr += sum(reciprocal_ranks) / len(reciprocal_ranks)
|
| 257 |
+
cat_metric.mrr_count += 1
|
| 258 |
|
| 259 |
return metrics
|
| 260 |
|
|
|
|
| 262 |
def print_category_metrics(
|
| 263 |
metrics: Dict[str, CategoryMetrics],
|
| 264 |
categories: Dict[str, TagCategory],
|
| 265 |
+
n_samples: int,
|
| 266 |
+
k: int,
|
| 267 |
):
|
| 268 |
"""
|
| 269 |
Print metrics organized by importance.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
metrics: Category metrics
|
| 273 |
+
categories: Category definitions
|
| 274 |
+
n_samples: Number of samples evaluated
|
| 275 |
+
k: Top-K for ranking metrics
|
| 276 |
"""
|
| 277 |
# Group by importance level
|
| 278 |
by_importance = defaultdict(list)
|
|
|
|
| 296 |
print(f" Constraint: {category.constraint.value}")
|
| 297 |
print(f" Ground truth tags: {cat_metric.total_gt_tags}")
|
| 298 |
|
| 299 |
+
# Binary prediction metrics
|
| 300 |
if category.constraint.value == "exactly_one":
|
| 301 |
print(f" Accuracy: {cat_metric.accuracy:.3f}")
|
|
|
|
|
|
|
| 302 |
print(f" Precision: {cat_metric.precision:.3f}")
|
| 303 |
print(f" Recall: {cat_metric.recall:.3f}")
|
| 304 |
print(f" F1: {cat_metric.f1:.3f}")
|
| 305 |
|
| 306 |
+
# Ranking metrics (averaged across samples)
|
| 307 |
+
if cat_metric.mrr_count > 0:
|
| 308 |
+
avg_recall_at_k = cat_metric.recall_at_k / n_samples if n_samples > 0 else 0
|
| 309 |
+
avg_precision_at_k = cat_metric.precision_at_k / n_samples if n_samples > 0 else 0
|
| 310 |
+
avg_mrr = cat_metric.mrr / cat_metric.mrr_count
|
| 311 |
+
|
| 312 |
+
print(f" Recall@{k}: {avg_recall_at_k:.3f} (GT tags found in top-{k})")
|
| 313 |
+
print(f" Precision@{k}: {avg_precision_at_k:.3f} (top-{k} that are correct)")
|
| 314 |
+
print(f" MRR: {avg_mrr:.3f} (mean reciprocal rank)")
|
| 315 |
+
print(f" Coverage: {cat_metric.found_in_suggestions}/{cat_metric.total_gt_tags} (GT tags in suggestions)")
|
| 316 |
+
|
| 317 |
# Show raw counts for debugging
|
| 318 |
print(f" (TP={cat_metric.tp}, FP={cat_metric.fp}, FN={cat_metric.fn}, TN={cat_metric.tn})")
|
| 319 |
|
|
|
|
| 422 |
)
|
| 423 |
|
| 424 |
# Print results
|
| 425 |
+
print_category_metrics(metrics, categories, len(eval_results), args.k)
|
| 426 |
|
| 427 |
|
| 428 |
if __name__ == "__main__":
|
|
@@ -166,6 +166,8 @@ class SampleResult:
|
|
| 166 |
stage2_time: float = 0.0
|
| 167 |
stage3_time: float = 0.0
|
| 168 |
stage3s_time: float = 0.0
|
|
|
|
|
|
|
| 169 |
# Errors
|
| 170 |
error: Optional[str] = None
|
| 171 |
|
|
@@ -327,6 +329,33 @@ def _process_one_sample(
|
|
| 327 |
result.selected_tags = expanded
|
| 328 |
log(f"Implications: +{len(implied_only)} tags")
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
# Remove eval-excluded tags from predictions before scoring
|
| 331 |
result.selected_tags -= _EVAL_EXCLUDED_TAGS
|
| 332 |
result.retrieved_tags -= _EVAL_EXCLUDED_TAGS
|
|
@@ -915,6 +944,7 @@ def main(argv=None) -> int:
|
|
| 915 |
"selected_tags": sorted(r.selected_tags),
|
| 916 |
"implied_tags": sorted(r.implied_tags),
|
| 917 |
"structural_tags": r.structural_tags,
|
|
|
|
| 918 |
"why_counts": r.why_counts,
|
| 919 |
"tag_evidence": r.tag_evidence,
|
| 920 |
"gt_character_tags": sorted(r.gt_character_tags),
|
|
|
|
| 166 |
stage2_time: float = 0.0
|
| 167 |
stage3_time: float = 0.0
|
| 168 |
stage3s_time: float = 0.0
|
| 169 |
+
# Categorized suggestions (for ranking metrics)
|
| 170 |
+
categorized_suggestions: Dict[str, List[Tuple[str, float]]] = field(default_factory=dict)
|
| 171 |
# Errors
|
| 172 |
error: Optional[str] = None
|
| 173 |
|
|
|
|
| 329 |
result.selected_tags = expanded
|
| 330 |
log(f"Implications: +{len(implied_only)} tags")
|
| 331 |
|
| 332 |
+
# Generate categorized suggestions (for ranking metrics)
|
| 333 |
+
try:
|
| 334 |
+
from psq_rag.tagging.categorized_suggestions import (
|
| 335 |
+
generate_categorized_suggestions,
|
| 336 |
+
get_category_suggestions_dict,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
# Use selected tags to generate category-wise ranked suggestions
|
| 340 |
+
categorized = generate_categorized_suggestions(
|
| 341 |
+
selected_tags=list(result.selected_tags),
|
| 342 |
+
allow_nsfw_tags=allow_nsfw,
|
| 343 |
+
top_n_per_category=20, # Get top 20 per category for eval
|
| 344 |
+
top_n_other=50,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
# Convert to simple dict format: category -> [(tag, score), ...]
|
| 348 |
+
result.categorized_suggestions = {}
|
| 349 |
+
for cat_name, cat_sugg in categorized.by_category.items():
|
| 350 |
+
result.categorized_suggestions[cat_name] = cat_sugg.suggestions
|
| 351 |
+
|
| 352 |
+
# Also store "other" suggestions
|
| 353 |
+
result.categorized_suggestions['other'] = categorized.other_suggestions
|
| 354 |
+
|
| 355 |
+
log(f"Categorized: {len(result.categorized_suggestions)} categories")
|
| 356 |
+
except Exception as e:
|
| 357 |
+
log(f"Warning: Failed to generate categorized suggestions: {e}")
|
| 358 |
+
|
| 359 |
# Remove eval-excluded tags from predictions before scoring
|
| 360 |
result.selected_tags -= _EVAL_EXCLUDED_TAGS
|
| 361 |
result.retrieved_tags -= _EVAL_EXCLUDED_TAGS
|
|
|
|
| 944 |
"selected_tags": sorted(r.selected_tags),
|
| 945 |
"implied_tags": sorted(r.implied_tags),
|
| 946 |
"structural_tags": r.structural_tags,
|
| 947 |
+
"categorized_suggestions": r.categorized_suggestions,
|
| 948 |
"why_counts": r.why_counts,
|
| 949 |
"tag_evidence": r.tag_evidence,
|
| 950 |
"gt_character_tags": sorted(r.gt_character_tags),
|