PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
nraptisss commited on
Commit
77fad9d
·
verified ·
1 Parent(s): aaf8c59

Add qualitative failure example sampler

Browse files
Files changed (1) hide show
  1. scripts/sample_failure_examples.py +163 -0
scripts/sample_failure_examples.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Sample publication-friendly success/failure examples from evaluation predictions.
3
+
4
+ Reads raw predictions and normalized scored predictions from an eval directory, then writes:
5
+ - analysis/failure_examples.json
6
+ - analysis/failure_examples.md
7
+
8
+ Designed to support qualitative error analysis in a paper.
9
+ """
10
+ import argparse
11
+ import json
12
+ from pathlib import Path
13
+ from typing import Any, Dict, List
14
+
15
+ from tmf921_train.utils import parse_json, write_json
16
+
17
+ DEFAULT_LAYERS = ["o1_nrm", "a1_policy", "tmf921_lifecycle_report", "tmf921_lifecycle_monitor", "tmf921", "camara", "intent_3gpp", "adversarial_ambiguous", "adversarial_out_of_scope"]
18
+
19
+
20
+ def load_rows(eval_dir: Path, split: str) -> List[Dict[str, Any]]:
21
+ pred_path = eval_dir / split / "predictions.json"
22
+ norm_path = eval_dir / split / "normalized_predictions_scored.json"
23
+ if not pred_path.exists():
24
+ return []
25
+ pred = json.loads(pred_path.read_text())
26
+ if norm_path.exists():
27
+ norm = {r.get("id"): r for r in json.loads(norm_path.read_text())}
28
+ out = []
29
+ for r in pred:
30
+ nr = norm.get(r.get("id"), {})
31
+ merged = dict(r)
32
+ for k, v in nr.items():
33
+ if k not in merged:
34
+ merged[k] = v
35
+ out.append(merged)
36
+ return out
37
+ return pred
38
+
39
+
40
+ def summarize_text(text: str, max_chars: int = 1800) -> str:
41
+ if text is None:
42
+ return ""
43
+ text = str(text).strip()
44
+ if len(text) <= max_chars:
45
+ return text
46
+ return text[:max_chars] + "\n...<truncated>..."
47
+
48
+
49
+ def infer_error_label(row: Dict[str, Any]) -> str:
50
+ if not row.get("parse_json", False) or not row.get("norm_parse_json", True):
51
+ return "invalid_or_unparseable_json"
52
+ layer = row.get("target_layer")
53
+ nf1 = row.get("norm_field_f1", row.get("field_f1", 0.0)) or 0.0
54
+ kf1 = row.get("norm_key_f1", 0.0) or 0.0
55
+ if kf1 > 0.95 and nf1 < 0.5:
56
+ return "correct_structure_wrong_values"
57
+ if kf1 < 0.8:
58
+ return "structural_mismatch_or_extra_missing_keys"
59
+ if layer == "o1_nrm":
60
+ return "o1_value_fidelity_error"
61
+ if layer == "a1_policy":
62
+ return "a1_policy_value_error"
63
+ if "lifecycle_report" in str(layer):
64
+ return "lifecycle_report_measurement_mismatch"
65
+ if "lifecycle_monitor" in str(layer):
66
+ return "lifecycle_monitor_measurement_mismatch"
67
+ return "value_level_mismatch"
68
+
69
+
70
+ def choose_examples(rows: List[Dict[str, Any]], layer: str, n_fail: int, n_success: int) -> Dict[str, List[Dict[str, Any]]]:
71
+ layer_rows = [r for r in rows if r.get("target_layer") == layer]
72
+ if not layer_rows:
73
+ return {"failures": [], "successes": []}
74
+ failures = sorted(layer_rows, key=lambda r: (r.get("norm_field_f1", r.get("field_f1", 0.0)) or 0.0, r.get("exact_match", False)))[:n_fail]
75
+ successes = sorted(layer_rows, key=lambda r: (r.get("norm_field_f1", r.get("field_f1", 0.0)) or 0.0, r.get("exact_match", False)), reverse=True)[:n_success]
76
+ return {"failures": failures, "successes": successes}
77
+
78
+
79
+ def compact_row(row: Dict[str, Any], split: str, kind: str) -> Dict[str, Any]:
80
+ pred_obj, _ = parse_json(row.get("prediction", ""))
81
+ gold_obj, _ = parse_json(row.get("gold", ""))
82
+ return {
83
+ "split": split,
84
+ "kind": kind,
85
+ "id": row.get("id"),
86
+ "target_layer": row.get("target_layer"),
87
+ "slice_type": row.get("slice_type"),
88
+ "lifecycle_operation": row.get("lifecycle_operation"),
89
+ "parse_json": row.get("parse_json"),
90
+ "exact_match": row.get("exact_match"),
91
+ "field_f1": row.get("field_f1"),
92
+ "norm_field_f1": row.get("norm_field_f1"),
93
+ "norm_key_f1": row.get("norm_key_f1"),
94
+ "error_label": infer_error_label(row) if kind == "failure" else "success_or_high_scoring_example",
95
+ "gold_json_keys": list(gold_obj.keys()) if isinstance(gold_obj, dict) else None,
96
+ "prediction_json_keys": list(pred_obj.keys()) if isinstance(pred_obj, dict) else None,
97
+ "gold": summarize_text(row.get("gold", "")),
98
+ "prediction": summarize_text(row.get("prediction", "")),
99
+ }
100
+
101
+
102
+ def main():
103
+ ap = argparse.ArgumentParser()
104
+ ap.add_argument("--eval_dir", required=True, help="Eval dir containing split/predictions.json and normalized_predictions_scored.json")
105
+ ap.add_argument("--output_dir", default="analysis")
106
+ ap.add_argument("--splits", nargs="+", default=["test_in_distribution", "test_template_ood", "test_use_case_ood", "test_sector_ood", "test_adversarial"])
107
+ ap.add_argument("--layers", nargs="+", default=DEFAULT_LAYERS)
108
+ ap.add_argument("--failures_per_layer", type=int, default=3)
109
+ ap.add_argument("--successes_per_layer", type=int, default=1)
110
+ args = ap.parse_args()
111
+
112
+ eval_dir = Path(args.eval_dir)
113
+ out_dir = Path(args.output_dir)
114
+ out_dir.mkdir(parents=True, exist_ok=True)
115
+
116
+ examples: List[Dict[str, Any]] = []
117
+ for split in args.splits:
118
+ rows = load_rows(eval_dir, split)
119
+ for layer in args.layers:
120
+ picked = choose_examples(rows, layer, args.failures_per_layer, args.successes_per_layer)
121
+ for r in picked["failures"]:
122
+ examples.append(compact_row(r, split, "failure"))
123
+ for r in picked["successes"]:
124
+ examples.append(compact_row(r, split, "success"))
125
+
126
+ write_json(out_dir / "failure_examples.json", examples)
127
+
128
+ lines = []
129
+ A = lines.append
130
+ A("# Qualitative Success and Failure Examples")
131
+ A("")
132
+ A(f"Source eval dir: `{eval_dir}`")
133
+ A("")
134
+ A("These examples are sampled to support qualitative error analysis. Long JSON objects are truncated for readability; full examples are in `failure_examples.json`.")
135
+ A("")
136
+ for i, ex in enumerate(examples, start=1):
137
+ A(f"## Example {i}: {ex['kind']} — `{ex['target_layer']}` — `{ex['split']}`")
138
+ A("")
139
+ A(f"- id: `{ex['id']}`")
140
+ A(f"- slice type: `{ex.get('slice_type')}`")
141
+ A(f"- lifecycle: `{ex.get('lifecycle_operation')}`")
142
+ A(f"- error label: `{ex['error_label']}`")
143
+ A(f"- raw field F1: `{ex.get('field_f1')}`")
144
+ A(f"- normalized field F1: `{ex.get('norm_field_f1')}`")
145
+ A(f"- normalized key F1: `{ex.get('norm_key_f1')}`")
146
+ A("")
147
+ A("### Gold")
148
+ A("```json")
149
+ A(ex["gold"])
150
+ A("```")
151
+ A("")
152
+ A("### Prediction")
153
+ A("```json")
154
+ A(ex["prediction"])
155
+ A("```")
156
+ A("")
157
+ (out_dir / "failure_examples.md").write_text("\n".join(lines), encoding="utf-8")
158
+ print(out_dir / "failure_examples.md")
159
+ print(out_dir / "failure_examples.json")
160
+
161
+
162
+ if __name__ == "__main__":
163
+ main()