Claude committed on
Commit
0ed7e94
·
1 Parent(s): 7261188

Add ranking metrics infrastructure to eval pipeline

Browse files

Implements Recall@K, Precision@K, and MRR for categorized suggestions.

Changes to eval_pipeline.py:
- Add categorized_suggestions field to SampleResult
- Generate categorized suggestions after LLM selection
- Store top-20 suggestions per category in detail results

Changes to eval_categorized.py:
- Compute ranking metrics from categorized_suggestions
- Show Recall@K, Precision@K, MRR alongside P/R/F1
- Track coverage (how many GT tags appear in suggestions)
- Gracefully handle old results without categorized_suggestions

Ranking metrics:
- Recall@K: Fraction of ground truth tags found in top-K suggestions
- Precision@K: Fraction of top-K suggestions that are correct
- MRR: Mean reciprocal rank of GT tags found in the suggestion list, averaged over samples with at least one hit
- Coverage: Total GT tags that appear anywhere in suggestions

To generate results with ranking metrics:
python scripts/eval_pipeline.py --n 50 --expand-implications

To evaluate with ranking metrics:
python scripts/eval_categorized.py --results <results.jsonl> --k 5

https://claude.ai/code/session_015ZwE7a5E6YVTrMpuB2pXX7

scripts/eval_categorized.py CHANGED
@@ -109,9 +109,11 @@ class CategoryMetrics:
109
 
110
  # Ranking metrics (for suggestions)
111
  total_gt_tags: int = 0 # Total ground truth tags across all samples
112
- recall_at_k: float = 0.0 # How many GT tags found in top-K
113
- precision_at_k: float = 0.0 # How many top-K are correct
 
114
  mrr: float = 0.0 # Mean reciprocal rank
 
115
 
116
  @property
117
  def precision(self) -> float:
@@ -217,13 +219,42 @@ def compute_category_metrics(
217
 
218
  cat_metric.total_gt_tags += len(gt_cat_tags)
219
 
220
- # TODO: Ranking metrics (Recall@K, Precision@K, MRR)
221
- # These require ranked suggestions per category in the eval results.
222
- # Current eval_pipeline.py only outputs binary predictions (selected_tags).
223
- # To add ranking metrics, we need to:
224
- # 1. Modify eval_pipeline to generate categorized suggestions
225
- # 2. Store top-K suggestions per category in results
226
- # 3. Compute rank of each GT tag in the suggestion list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  return metrics
229
 
@@ -231,9 +262,17 @@ def compute_category_metrics(
231
  def print_category_metrics(
232
  metrics: Dict[str, CategoryMetrics],
233
  categories: Dict[str, TagCategory],
 
 
234
  ):
235
  """
236
  Print metrics organized by importance.
 
 
 
 
 
 
237
  """
238
  # Group by importance level
239
  by_importance = defaultdict(list)
@@ -257,15 +296,24 @@ def print_category_metrics(
257
  print(f" Constraint: {category.constraint.value}")
258
  print(f" Ground truth tags: {cat_metric.total_gt_tags}")
259
 
260
- # For EXACTLY_ONE, show accuracy
261
  if category.constraint.value == "exactly_one":
262
  print(f" Accuracy: {cat_metric.accuracy:.3f}")
263
-
264
- # For others, show P/R/F1
265
  print(f" Precision: {cat_metric.precision:.3f}")
266
  print(f" Recall: {cat_metric.recall:.3f}")
267
  print(f" F1: {cat_metric.f1:.3f}")
268
 
 
 
 
 
 
 
 
 
 
 
 
269
  # Show raw counts for debugging
270
  print(f" (TP={cat_metric.tp}, FP={cat_metric.fp}, FN={cat_metric.fn}, TN={cat_metric.tn})")
271
 
@@ -374,7 +422,7 @@ def main():
374
  )
375
 
376
  # Print results
377
- print_category_metrics(metrics, categories)
378
 
379
 
380
  if __name__ == "__main__":
 
109
 
110
  # Ranking metrics (for suggestions)
111
  total_gt_tags: int = 0 # Total ground truth tags across all samples
112
+ found_in_suggestions: int = 0 # GT tags that appear anywhere in suggestions
113
+ recall_at_k: float = 0.0 # Fraction of GT tags found in top-K
114
+ precision_at_k: float = 0.0 # Fraction of top-K that are correct
115
  mrr: float = 0.0 # Mean reciprocal rank
116
+ mrr_count: int = 0 # Number of GT tags used for MRR calculation
117
 
118
  @property
119
  def precision(self) -> float:
 
219
 
220
  cat_metric.total_gt_tags += len(gt_cat_tags)
221
 
222
+ # Ranking metrics (if categorized_suggestions are available)
223
+ categorized_suggestions = result.get('categorized_suggestions', {})
224
+ cat_suggestions = categorized_suggestions.get(cat_name, [])
225
+
226
+ if cat_suggestions and gt_cat_tags:
227
+ # Convert to dict for easier lookup: {tag: rank}
228
+ # Suggestions are already sorted by score, so index = rank (0-indexed)
229
+ suggestion_ranks = {tag: rank for rank, (tag, score) in enumerate(cat_suggestions)}
230
+
231
+ # Count how many GT tags appear in suggestions (at any rank)
232
+ found_count = sum(1 for gt_tag in gt_cat_tags if gt_tag in suggestion_ranks)
233
+ cat_metric.found_in_suggestions += found_count
234
+
235
+ # Recall@K: fraction of GT tags in top-K
236
+ top_k_tags = {tag for tag, score in cat_suggestions[:k]}
237
+ recall_at_k_count = len(gt_cat_tags & top_k_tags)
238
+
239
+ # Precision@K: fraction of top-K that are in GT
240
+ if len(top_k_tags) > 0:
241
+ precision_at_k_count = len(top_k_tags & gt_cat_tags)
242
+ else:
243
+ precision_at_k_count = 0
244
+
245
+ # MRR: mean of 1/rank for each GT tag found in suggestions
246
+ reciprocal_ranks = []
247
+ for gt_tag in gt_cat_tags:
248
+ if gt_tag in suggestion_ranks:
249
+ rank = suggestion_ranks[gt_tag]
250
+ reciprocal_ranks.append(1.0 / (rank + 1)) # +1 because rank is 0-indexed
251
+
252
+ # Accumulate for averaging later
253
+ cat_metric.recall_at_k += recall_at_k_count / len(gt_cat_tags) if gt_cat_tags else 0
254
+ cat_metric.precision_at_k += precision_at_k_count / min(k, len(cat_suggestions)) if cat_suggestions else 0
255
+ if reciprocal_ranks:
256
+ cat_metric.mrr += sum(reciprocal_ranks) / len(reciprocal_ranks)
257
+ cat_metric.mrr_count += 1
258
 
259
  return metrics
260
 
 
262
  def print_category_metrics(
263
  metrics: Dict[str, CategoryMetrics],
264
  categories: Dict[str, TagCategory],
265
+ n_samples: int,
266
+ k: int,
267
  ):
268
  """
269
  Print metrics organized by importance.
270
+
271
+ Args:
272
+ metrics: Category metrics
273
+ categories: Category definitions
274
+ n_samples: Number of samples evaluated
275
+ k: Top-K for ranking metrics
276
  """
277
  # Group by importance level
278
  by_importance = defaultdict(list)
 
296
  print(f" Constraint: {category.constraint.value}")
297
  print(f" Ground truth tags: {cat_metric.total_gt_tags}")
298
 
299
+ # Binary prediction metrics
300
  if category.constraint.value == "exactly_one":
301
  print(f" Accuracy: {cat_metric.accuracy:.3f}")
 
 
302
  print(f" Precision: {cat_metric.precision:.3f}")
303
  print(f" Recall: {cat_metric.recall:.3f}")
304
  print(f" F1: {cat_metric.f1:.3f}")
305
 
306
+ # Ranking metrics (averaged across samples)
307
+ if cat_metric.mrr_count > 0:
308
+ avg_recall_at_k = cat_metric.recall_at_k / n_samples if n_samples > 0 else 0
309
+ avg_precision_at_k = cat_metric.precision_at_k / n_samples if n_samples > 0 else 0
310
+ avg_mrr = cat_metric.mrr / cat_metric.mrr_count
311
+
312
+ print(f" Recall@{k}: {avg_recall_at_k:.3f} (GT tags found in top-{k})")
313
+ print(f" Precision@{k}: {avg_precision_at_k:.3f} (top-{k} that are correct)")
314
+ print(f" MRR: {avg_mrr:.3f} (mean reciprocal rank)")
315
+ print(f" Coverage: {cat_metric.found_in_suggestions}/{cat_metric.total_gt_tags} (GT tags in suggestions)")
316
+
317
  # Show raw counts for debugging
318
  print(f" (TP={cat_metric.tp}, FP={cat_metric.fp}, FN={cat_metric.fn}, TN={cat_metric.tn})")
319
 
 
422
  )
423
 
424
  # Print results
425
+ print_category_metrics(metrics, categories, len(eval_results), args.k)
426
 
427
 
428
  if __name__ == "__main__":
scripts/eval_pipeline.py CHANGED
@@ -166,6 +166,8 @@ class SampleResult:
166
  stage2_time: float = 0.0
167
  stage3_time: float = 0.0
168
  stage3s_time: float = 0.0
 
 
169
  # Errors
170
  error: Optional[str] = None
171
 
@@ -327,6 +329,33 @@ def _process_one_sample(
327
  result.selected_tags = expanded
328
  log(f"Implications: +{len(implied_only)} tags")
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  # Remove eval-excluded tags from predictions before scoring
331
  result.selected_tags -= _EVAL_EXCLUDED_TAGS
332
  result.retrieved_tags -= _EVAL_EXCLUDED_TAGS
@@ -915,6 +944,7 @@ def main(argv=None) -> int:
915
  "selected_tags": sorted(r.selected_tags),
916
  "implied_tags": sorted(r.implied_tags),
917
  "structural_tags": r.structural_tags,
 
918
  "why_counts": r.why_counts,
919
  "tag_evidence": r.tag_evidence,
920
  "gt_character_tags": sorted(r.gt_character_tags),
 
166
  stage2_time: float = 0.0
167
  stage3_time: float = 0.0
168
  stage3s_time: float = 0.0
169
+ # Categorized suggestions (for ranking metrics)
170
+ categorized_suggestions: Dict[str, List[Tuple[str, float]]] = field(default_factory=dict)
171
  # Errors
172
  error: Optional[str] = None
173
 
 
329
  result.selected_tags = expanded
330
  log(f"Implications: +{len(implied_only)} tags")
331
 
332
+ # Generate categorized suggestions (for ranking metrics)
333
+ try:
334
+ from psq_rag.tagging.categorized_suggestions import (
335
+ generate_categorized_suggestions,
336
+ get_category_suggestions_dict,
337
+ )
338
+
339
+ # Use selected tags to generate category-wise ranked suggestions
340
+ categorized = generate_categorized_suggestions(
341
+ selected_tags=list(result.selected_tags),
342
+ allow_nsfw_tags=allow_nsfw,
343
+ top_n_per_category=20, # Get top 20 per category for eval
344
+ top_n_other=50,
345
+ )
346
+
347
+ # Convert to simple dict format: category -> [(tag, score), ...]
348
+ result.categorized_suggestions = {}
349
+ for cat_name, cat_sugg in categorized.by_category.items():
350
+ result.categorized_suggestions[cat_name] = cat_sugg.suggestions
351
+
352
+ # Also store "other" suggestions
353
+ result.categorized_suggestions['other'] = categorized.other_suggestions
354
+
355
+ log(f"Categorized: {len(result.categorized_suggestions)} categories")
356
+ except Exception as e:
357
+ log(f"Warning: Failed to generate categorized suggestions: {e}")
358
+
359
  # Remove eval-excluded tags from predictions before scoring
360
  result.selected_tags -= _EVAL_EXCLUDED_TAGS
361
  result.retrieved_tags -= _EVAL_EXCLUDED_TAGS
 
944
  "selected_tags": sorted(r.selected_tags),
945
  "implied_tags": sorted(r.implied_tags),
946
  "structural_tags": r.structural_tags,
947
+ "categorized_suggestions": r.categorized_suggestions,
948
  "why_counts": r.why_counts,
949
  "tag_evidence": r.tag_evidence,
950
  "gt_character_tags": sorted(r.gt_character_tags),