PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
nraptisss committed on
Commit
f198a32
·
verified ·
1 Parent(s): 0e7b293

Upload scripts/run_all_baselines.sh

Browse files
Files changed (1) hide show
  1. scripts/run_all_baselines.sh +160 -0
scripts/run_all_baselines.sh ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env bash
set -euo pipefail

# Run all baseline evaluations for publication comparison.
# Run this on your RTX 6000 Ada server.
#
# Prerequisites:
#   - HF_TOKEN set
#   - OPENAI_API_KEY set (for GPT-4o-mini; that step is skipped when unset)
#   - .venv created at the project root with dependencies from
#     requirements.txt (this script activates it itself)
#
# Usage:
#   bash scripts/run_all_baselines.sh

# Resolve the project root from this script's own location so the relative
# paths used below work no matter where the script is invoked from.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$PROJECT_ROOT"

if [[ ! -f .venv/bin/activate ]]; then
  printf 'error: %s/.venv/bin/activate not found; create the virtualenv first\n' \
    "$PROJECT_ROOT" >&2
  exit 1
fi

# Some activate scripts read variables that may be unset, which would abort
# the script under `set -u`; relax nounset only while sourcing.
set +u
# shellcheck disable=SC1091
source .venv/bin/activate
set -u

export PYTHONPATH="$PROJECT_ROOT/src:${PYTHONPATH:-}"
export TOKENIZERS_PARALLELISM=false

# Listed as a prerequisite above; warn instead of aborting because cached
# credentials may still allow the model downloads to succeed.
if [[ -z "${HF_TOKEN:-}" ]]; then
  printf 'warning: HF_TOKEN is not set; downloads that need authentication may fail\n' >&2
fi

# ─── Configuration ──────────────────────────────────────────────────────────
# Per-split sample cap passed as --max_samples_per_split. Intended to be left
# empty (MAX_SAMPLES=) for a full evaluation; the literal word "null" is not
# a valid value.
MAX_SAMPLES=200
BATCH_SIZE_LOCAL=4   # batch size for the locally hosted models
BATCH_SIZE_API=1     # batch size for the API-backed model
SAVE_EVERY=25        # forwarded as --save_every

#######################################
# Evaluate one baseline model, then normalize the metrics it produced.
# Globals:   MAX_SAMPLES (read) - per-split sample cap; when empty or unset
#                                 the --max_samples_per_split flag is omitted
#                                 instead of being passed without a value
#            SAVE_EVERY (read)  - forwarded as --save_every
# Arguments: $1   - model identifier passed to --model
#            $2   - output directory for this baseline
#            $3   - batch size
#            $4.. - extra arguments forwarded to baseline_eval.py
#                   (placed right after --model)
# Returns:   status of the normalization step; under `set -e` a failing
#            python command aborts the script.
#######################################
run_baseline() {
  local model=$1 out_dir=$2 batch_size=$3
  shift 3

  local -a eval_args=(
    --model "$model"
    "$@"
    --output_dir "$out_dir"
    --batch_size "$batch_size"
  )
  # An empty cap must drop the flag entirely; otherwise the option would
  # swallow the following --save_every as its value.
  if [[ -n "${MAX_SAMPLES:-}" ]]; then
    eval_args+=(--max_samples_per_split "$MAX_SAMPLES")
  fi
  eval_args+=(--save_every "$SAVE_EVERY")

  python scripts/baseline_eval.py "${eval_args[@]}"
  python scripts/normalize_eval_metrics.py --eval_dir "$out_dir"
}

# ─── 1. Llama-3.1-8B-Instruct (local, zero-shot) ────────────────────────────
echo "========================================"
echo "1. Llama-3.1-8B-Instruct zero-shot"
echo "========================================"

run_baseline meta-llama/Llama-3.1-8B-Instruct \
  outputs/baselines/llama-3.1-8b-instruct "$BATCH_SIZE_LOCAL"

# ─── 2. Qwen2.5-7B-Instruct (local, zero-shot) ──────────────────────────────
echo "========================================"
echo "2. Qwen2.5-7B-Instruct zero-shot"
echo "========================================"

run_baseline Qwen/Qwen2.5-7B-Instruct \
  outputs/baselines/qwen2.5-7b-instruct "$BATCH_SIZE_LOCAL"

# ─── 3. GPT-4o-mini (API, zero-shot) ────────────────────────────────────────
# Only runs when an OpenAI key is available; otherwise the step is skipped.
if [[ -n "${OPENAI_API_KEY:-}" ]]; then
  echo "========================================"
  echo "3. GPT-4o-mini zero-shot (API)"
  echo "========================================"

  run_baseline gpt-4o-mini outputs/baselines/gpt-4o-mini "$BATCH_SIZE_API" \
    --api_provider openai
else
  echo "Skipping GPT-4o-mini (OPENAI_API_KEY not set)"
fi

# ─── 4. Package comparison results ──────────────────────────────────────────
echo "========================================"
echo "4. Packaging comparison results"
echo "========================================"

#######################################
# Print a per-split comparison table of every baseline whose normalized
# metrics file exists, alongside the hard-coded Qwen3-8B zero-shot and
# Stage 1 numbers, and write everything to
# outputs/baselines/comparison_results.json.
# Globals:   none
# Arguments: none
# Outputs:   comparison table on stdout; JSON file relative to the cwd
# Returns:   exit status of the python interpreter
#######################################
package_comparison_results() {
  # Quoted delimiter: the Python source is passed through literally, with no
  # shell expansion.
  python - <<'PYEOF'
import json
from pathlib import Path

BASELINE_ROOT = Path("outputs/baselines")
BASELINE_NAMES = ["llama-3.1-8b-instruct", "qwen2.5-7b-instruct", "gpt-4o-mini"]
SPLITS = [
    "test_in_distribution",
    "test_template_ood",
    "test_use_case_ood",
    "test_sector_ood",
    "test_adversarial",
]

# Load whichever baselines were actually evaluated; missing files are skipped.
# Each file is assumed to map split name -> {metric name -> number}
# (TODO confirm against normalize_eval_metrics.py).
results = {}
for name in BASELINE_NAMES:
    path = BASELINE_ROOT / name / "all_normalized_metrics.json"
    if path.exists():
        results[name] = json.loads(path.read_text())

# Stage 1 results for comparison
stage1 = {
    "test_in_distribution": {"parse_json": 1.0000, "norm_field_f1": 0.7956, "norm_key_f1": 0.9811},
    "test_template_ood": {"parse_json": 1.0000, "norm_field_f1": 0.7865, "norm_key_f1": 0.9801},
    "test_use_case_ood": {"parse_json": 0.9998, "norm_field_f1": 0.7907, "norm_key_f1": 0.9805},
    "test_sector_ood": {"parse_json": 1.0000, "norm_field_f1": 0.7697, "norm_key_f1": 0.9818},
    "test_adversarial": {"parse_json": 1.0000, "norm_field_f1": 0.9697, "norm_key_f1": 1.0000},
}

# Zero-shot Qwen3-8B from journal
qwen3_zero = {
    "test_in_distribution": {"parse_json": 0.335, "norm_field_f1": 0.0009, "norm_key_f1": 0.0169},
    "test_template_ood": {"parse_json": 0.340, "norm_field_f1": 0.0014, "norm_key_f1": 0.0172},
    "test_use_case_ood": {"parse_json": 0.325, "norm_field_f1": 0.0012, "norm_key_f1": 0.0198},
    "test_sector_ood": {"parse_json": 0.345, "norm_field_f1": 0.0008, "norm_key_f1": 0.0171},
    "test_adversarial": {"parse_json": 0.000, "norm_field_f1": 0.0000, "norm_key_f1": 0.0000},
}


def is_number(value):
    """True for real ints/floats (bool is excluded on purpose)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def cell(value, width, sign=""):
    """Right-aligned 4-decimal cell, or 'n/a' when the metric is missing."""
    if not is_number(value):
        return f"{'n/a':>{width}s}"
    return f"{value:{sign}{width}.4f}"


def print_row(label, metrics, reference, delta_text=None):
    """Print one table row; the delta is norm_field_f1 minus the Stage 1 value."""
    field = metrics.get("norm_field_f1")
    if delta_text is None:
        # A missing metric must not be shown as a real (zero-based) delta.
        delta = field - reference if is_number(field) else None
        delta_text = cell(delta, 12, "+")
    print(
        f"{label:<30s} {cell(metrics.get('parse_json'), 8)} "
        f"{cell(field, 14)} {cell(metrics.get('norm_key_f1'), 12)} {delta_text}"
    )


# Print comparison table
print("\n" + "=" * 100)
print("BASELINE COMPARISON: All Models vs Qwen3-8B QLoRA Stage 1")
print("=" * 100)

for split in SPLITS:
    reference = stage1[split]["norm_field_f1"]
    print(f"\n--- {split} ---")
    print(f"{'Model':<30s} {'Parse':>8s} {'Norm Field F1':>14s} {'Norm Key F1':>12s} {'vs Stage1 Δ':>12s}")
    print("-" * 80)

    for model_name, model_results in results.items():
        print_row(model_name, model_results.get(split) or {}, reference)

    # Zero-shot Qwen3-8B
    print_row("Qwen3-8B zero-shot", qwen3_zero[split], reference)

    # Stage 1 is the reference row, so it carries a label instead of a delta.
    print_row("Qwen3-8B-QLoRA (stage1)", stage1[split], reference,
              delta_text=f"{'(baseline)':>12s}")

# Save combined results
out_path = BASELINE_ROOT / "comparison_results.json"
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps({
    "baselines": results,
    "stage1": stage1,
    "qwen3_zero_shot": qwen3_zero,
}, indent=2))

print(f"\nCombined results saved to {out_path}")
PYEOF
}

package_comparison_results

echo "========================================"
echo "All baseline evaluations complete!"
echo "========================================"