llmrnn-grpo-phase1-v2judge-run02-global-step-340-merged
GRPO-trained memory-update head for the LLM-as-RNN clinical pipeline
(reproducing arXiv 2601.13352). The LoRA adapter has been folded into the base
model with PEFT's `merge_and_unload()`, so this is a stand-alone HF model:
load it like any other Llama-3.2-3B variant; no PEFT runtime is needed.
Role: this is not a general-purpose chat model. It is fine-tuned to rewrite a natural-language "evolving summary" (the hidden state
`h_t` in the LLM-as-RNN setup) after each patient visit. Use it as the `memory_model` in the 4-model LLM-as-RNN pipeline.
1. Run identity
| field | value |
|---|---|
| experiment | phase1_v2judge_run02 |
| source checkpoint | global_step_340 (latest of run02) |
| launched | 2026-05-13 04:50:26 UTC, host c301-001.ls6.tacc.utexas.edu (TACC Lonestar6) |
| git SHA at launch | aefe978 (clean) |
| container | verlai/verl-vllm017 (verl 0.7.1, torch 2.10, vllm 0.17) |
| GPU layout | 2× A100 40GB actor + rollout (cuda:0,1, TP=2), 1× A100 40GB judge (cuda:2) |
| total optimizer steps | 340 (≈ 20 epochs over the train parquet) |
2. Base model + LoRA
- Base: `meta-llama/Llama-3.2-3B-Instruct`
- LoRA: r=16, α=32, dropout=0, bias=none
- Target modules: `q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj` (every linear layer inside attention + MLP, i.e. a full-coverage adapter)
- Merged: yes. `lora_alpha = 32` was preserved when folding the adapter back into the base. Pure HF weights in this repo; no `lora_adapter/` subdir.
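For reference, standard LoRA merging (PEFT's default, non-rsLoRA mode) adds the low-rank update scaled by `lora_alpha / r`, here 32 / 16 = 2, to every targeted weight matrix. The pure-Python sketch below shows that arithmetic on toy 2×2 matrices; `merge_lora` and the matrices are illustrative, not PEFT code, and the sketch ignores details such as dtype casting and `fan_in_fan_out`.

```python
def matmul(a, b):
    """Plain-Python matrix product of a (m x k) and b (k x n)."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def merge_lora(w, lora_a, lora_b, lora_alpha, r):
    """Return W' = W + (lora_alpha / r) * B @ A, the standard LoRA merge rule.

    w:      base weight, shape (out, in)
    lora_a: A, shape (r, in)
    lora_b: B, shape (out, r)
    """
    scaling = lora_alpha / r
    delta = matmul(lora_b, lora_a)  # low-rank update, shape (out, in)
    return [[w[i][j] + scaling * delta[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]


# Toy example with r=1 (the real adapter uses r=16, alpha=32, i.e. scaling 2.0).
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]      # shape (r=1, in=2)
B = [[0.5], [0.25]]   # shape (out=2, r=1)

merged = merge_lora(W, A, B, lora_alpha=2, r=1)      # scaling 2.0
unchanged = merge_lora(W, A, B, lora_alpha=0, r=1)   # scaling 0.0: adapter has no effect
print(merged)     # [[2.0, 2.0], [0.5, 2.0]]
print(unchanged)  # [[1.0, 0.0], [0.0, 1.0]]
```

With `lora_alpha=0` the scaling is zero and the merged weights equal the base weights, which is why the `lora_alpha=0` rewrite noted under Limitations matters.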
3. GRPO setup
| GRPO knob | value |
|---|---|
| LR | 3e-6 (AdamW, no LR schedule) |
| KL β | 0.001, applied inside the actor loss (`use_kl_loss=True`, `use_kl_in_reward=False`) |
| KL estimator | low_var_kl (k3) |
| entropy coef / grad clip | 0 / 1.0 |
| group size G | 8 rollouts per prompt |
| train_batch / ppo_mini_batch / micro_batch_per_gpu | 16 / 16 / 4 → 1 mini-batch per step |
| rollout temperature / top-p | 0.8 / 0.95 |
| max prompt / response length | 3072 / 512 tokens |
| epochs / save_freq / test_freq | 20 / 20 / 5 |
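The knobs above combine as follows. This sketch assumes verl's convention that `ppo_mini_batch_size` counts prompts (it is multiplied by the rollout count G and then split across data-parallel ranks) while `ppo_micro_batch_size_per_gpu` counts individual responses; the dataset-size estimate additionally assumes the dataloader drops an incomplete final batch. Treat it as back-of-the-envelope arithmetic, not verl code; `grpo_step_shape` is an illustrative name.

```python
def grpo_step_shape(train_batch, group_size, ppo_mini_batch, micro_per_gpu, n_gpus):
    """Per-optimizer-step counts under the assumptions stated in the lead-in."""
    responses_per_step = train_batch * group_size          # rollouts generated per step
    mini_batches = train_batch // ppo_mini_batch           # optimizer updates per step
    mini_per_gpu = ppo_mini_batch * group_size // n_gpus   # responses per rank per mini-batch
    grad_accum = mini_per_gpu // micro_per_gpu             # micro-batches accumulated per update
    return responses_per_step, mini_batches, mini_per_gpu, grad_accum


responses, minis, per_gpu, accum = grpo_step_shape(
    train_batch=16, group_size=8, ppo_mini_batch=16, micro_per_gpu=4, n_gpus=2
)
print(responses, minis, per_gpu, accum)  # 128 1 64 16

# 340 steps over ~20 epochs -> 17 steps per epoch -> roughly 17 * 16 = 272
# prompts (parquet rows) consumed per epoch, if incomplete batches are dropped.
steps_per_epoch = 340 // 20
print(steps_per_epoch, steps_per_epoch * 16)  # 17 272
```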
Algorithm: standard verl 0.7.1 GRPO, i.e. group-relative advantage (no critic;
the 8 candidates per prompt share a common baseline), a PPO-clipped surrogate,
and a KL anchor to the frozen base via `model.disable_adapter()`.
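The group-relative advantage can be sketched in a few lines. This assumes verl's default GRPO normalization (reward minus group mean, divided by group standard deviation plus a small epsilon); `group_relative_advantages` is an illustrative name and the rewards are made up.

```python
from statistics import mean, stdev


def group_relative_advantages(rewards, eps=1e-6):
    """GRPO outcome advantage for one prompt's group of G rollouts.

    Each rollout's scalar reward is centred on the group mean (the shared
    baseline, so no critic is needed) and scaled by the group's sample
    standard deviation. sigma falls back to 1.0 for a single-sample group.
    """
    mu = mean(rewards)
    sigma = stdev(rewards) if len(rewards) > 1 else 1.0
    return [(r - mu) / (sigma + eps) for r in rewards]


# One hypothetical group of G=8 judge rewards (Score: N -> N/9).
group = [n / 9 for n in (6, 4, 7, 6, 5, 9, 3, 6)]
adv = group_relative_advantages(group)
print([round(a, 2) for a in adv])

# A collapsed group (all rollouts scored identically) yields zero advantage,
# so it contributes no policy-gradient signal.
print(group_relative_advantages([0.5] * 8))  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

The zero-advantage case is the practical reason reward variance within a group matters: if the judge always returned the same score, every group would look like the second example.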
Training was run with verl's stock trainer
(`verl.trainer.main_ppo`, `algorithm.adv_estimator=grpo`); only the reward
function is custom (see §4).
4. Reward: rubric v2 (outcome-aware scalar judge)
The reward comes from a scalar judge call made on every rollout; the judge returns
`Score: N` with N ∈ {1..9}, which is mapped to reward N/9 ∈ (0, 1].
- Judge model: `OpenRubrics/RubricARM-8B-Judge` (forced into scalar mode via a custom prompt template: RubricARM is natively pairwise, but our prompt asks for a single 1-9 digit and the parser keys off the `Score: N` line)
- Judge serving: vLLM on `localhost:8001`, T=0, `max_tokens=1500`, `max_model_len=12000`, `gpu_memory_utilization=0.9` on cuda:2.
- Source-of-truth YAML: `training/configs/rubric_v2_rubricARM_scalar.yaml` (SHA `1e2f89fec…`). This single file controls rubric content + judge model + serving config, so the reward function and `start_judge_server.sh` cannot drift.
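A minimal sketch of a single judge request, assuming the vLLM server exposes its standard OpenAI-compatible `/v1/completions` route on the port listed above and serves the model under its Hub ID. `build_judge_payload`, `extract_judge_text`, and `call_judge` are hypothetical helpers, not functions from `rubric_reward.py`. `max_model_len` and `gpu_memory_utilization` are server launch options, so they do not appear in the request body.

```python
import json
import urllib.request

JUDGE_URL = "http://localhost:8001/v1/completions"  # assumed route on the judge port
JUDGE_MODEL = "OpenRubrics/RubricARM-8B-Judge"      # assumed served model name


def build_judge_payload(judge_prompt: str) -> dict:
    """Request body for an OpenAI-compatible completions endpoint."""
    return {
        "model": JUDGE_MODEL,
        "prompt": judge_prompt,
        "temperature": 0.0,  # T=0: deterministic grading
        "max_tokens": 1500,
    }


def extract_judge_text(response_json: dict) -> str:
    """Pull the generated text out of a completions response."""
    return response_json["choices"][0]["text"]


def call_judge(judge_prompt: str, timeout: float = 120.0) -> str:
    """POST the prompt to the judge server and return its raw text (needs a running server)."""
    req = urllib.request.Request(
        JUDGE_URL,
        data=json.dumps(build_judge_payload(judge_prompt)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return extract_judge_text(json.loads(resp.read()))


payload = build_judge_payload("<rendered rubric_v2 prompt>")
print(payload["temperature"], payload["max_tokens"])  # 0.0 1500
```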
Why "outcome-aware": unlike a blind judge, this rubric shows the judge the next visit's prediction, ground truth, and evaluation alongside the candidate memory. The judge then grades the memory by whether it actually moved the next-visit primary-diagnosis ranking up or down, not by abstract "summary quality". The intent is that reward tracks downstream accuracy rather than stylistic polish.
Edge handling (`training/verl_adapter/rubric_reward.py`):
| condition | reward |
|---|---|
| policy output is not valid JSON | 0.05 (gradient still flows; no NaN) |
| judge response has a `Score: N` line | N/9 |
| judge response is short (<20 chars) with a digit | last digit / 9 |
| judge response is long and has no `Score: N` line | 0.5 (refuse to parse; avoids the rubric-v1 bug where the parser grabbed dimension numbers out of CoT text) |
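The table can be read as a small decision procedure. The sketch below is a hypothetical re-implementation for illustration, not the code in `rubric_reward.py`; the `Score: N` regex, taking the last `Score:` line when several appear, and returning 0.5 for a short reply with no digit are assumptions the table does not pin down.

```python
import json
import re

SCORE_LINE = re.compile(r"Score:\s*([1-9])\b")  # assumed format of the judge's verdict line


def reward_from_outputs(policy_output: str, judge_response: str) -> float:
    """Map (policy output, judge text) to a scalar reward, following the table row by row."""
    # Row 1: malformed policy JSON gets a small, finite reward instead of an exception/NaN.
    try:
        json.loads(policy_output)
    except json.JSONDecodeError:
        return 0.05

    # Row 2: an explicit "Score: N" line wins (assumption: the last one if several appear).
    matches = SCORE_LINE.findall(judge_response)
    if matches:
        return int(matches[-1]) / 9

    text = judge_response.strip()
    # Row 3: a very short answer such as "7" is read via its last digit.
    if len(text) < 20:
        digits = re.findall(r"\d", text)
        if digits:
            return int(digits[-1]) / 9

    # Row 4 (and, by assumption, short replies without any digit): refuse to parse.
    return 0.5


good = '{"evolving_summary": "..."}'
print(round(reward_from_outputs(good, "...reasoning...\nScore: 6"), 3))  # 0.667
print(reward_from_outputs("not json", "Score: 9"))                       # 0.05
```

Returning 0.5 rather than scraping digits out of a long reply is the design choice the table's last row describes: a long chain of thought is full of numbers (dimension indices, sub-scores) that are not the verdict.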
Audit trail: every reward call is appended to
`$SCRATCH/llm_as_rnn/rubric_audit/calls_<pid>.jsonl`, one file per worker
PID, recording the full rollout text, judge prompt, judge response, and reward.
The log is used to spot judge format-compliance bugs early.
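A sketch of the per-PID JSONL logging pattern, assuming one JSON object per line; `append_audit_record` and the record field names are illustrative, not the repo's actual schema. Writing to a separate file per process avoids interleaved writes when several reward workers run concurrently.

```python
import json
import os
import tempfile
from pathlib import Path


def append_audit_record(audit_dir, rollout_text, judge_prompt, judge_response, reward):
    """Append one reward call as a JSON line to calls_<pid>.jsonl and return the file path."""
    path = Path(audit_dir) / f"calls_{os.getpid()}.jsonl"
    path.parent.mkdir(parents=True, exist_ok=True)
    record = {  # hypothetical field names
        "rollout": rollout_text,
        "judge_prompt": judge_prompt,
        "judge_response": judge_response,
        "reward": reward,
    }
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return path


# Demo: write two records into a throwaway directory.
demo_dir = tempfile.mkdtemp()
demo_path = append_audit_record(demo_dir, "{...}", "<judge prompt>", "Score: 6", 6 / 9)
append_audit_record(demo_dir, "{...}", "<judge prompt>", "Score: 9", 1.0)
demo_lines = demo_path.read_text(encoding="utf-8").splitlines()
print(demo_path.name, len(demo_lines))
```

In the training setup the directory would be `$SCRATCH/llm_as_rnn/rubric_audit`, e.g. resolved with `os.path.expandvars`.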
5. Data
- Source: `data/splits/cleaned_df_train_100.json` (100 MIMIC-IV patients, randomly sampled with seed 42 from the cleaned 6488-patient pool).
- Dataset is the training split, not the test split. Generalization to the 6288-patient held-out test set is not measured here.
- Cleaning: any patient with at least one visit that has an empty
`discharge_diagnosis` was dropped before splitting.
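The cleaning rule and the seeded split can be sketched as below, assuming each patient is a dict with a `patient_id` and a list of `visits` carrying a `discharge_diagnosis` string (a hypothetical layout). The card does not say which RNG or sampling call produced the real split, so this will not reproduce `cleaned_df_train_100.json`; `clean_patients` and `sample_train_split` are illustrative names.

```python
import random


def clean_patients(patients):
    """Drop any patient with at least one visit whose discharge_diagnosis is empty or missing."""
    return [
        p for p in patients
        if all((v.get("discharge_diagnosis") or "").strip() for v in p["visits"])
    ]


def sample_train_split(patients, n=100, seed=42):
    """Seeded random sample of n patients; everyone else is held out."""
    rng = random.Random(seed)
    train = rng.sample(patients, n)
    train_ids = {p["patient_id"] for p in train}
    held_out = [p for p in patients if p["patient_id"] not in train_ids]
    return train, held_out


# Toy pool using the hypothetical layout.
toy = [
    {"patient_id": 1, "visits": [{"discharge_diagnosis": "sepsis"}]},
    {"patient_id": 2, "visits": [{"discharge_diagnosis": "CHF"}, {"discharge_diagnosis": ""}]},
    {"patient_id": 3, "visits": [{"discharge_diagnosis": "COPD"}, {"discharge_diagnosis": None}]},
    {"patient_id": 4, "visits": [{"discharge_diagnosis": "AKI"}, {"discharge_diagnosis": "pneumonia"}]},
]
kept = clean_patients(toy)
train, held_out = sample_train_split(kept, n=1, seed=42)
print([p["patient_id"] for p in kept])  # [1, 4]
```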
From patients to training samples (Phase B)
`training/verl_adapter/dump_trajectories_to_parquet.py` runs the full
4-model LLM-as-RNN forward pass over each patient and emits one parquet
row per (patient, visit_index ∈ [0, n-3]), where n is the patient's number of visits:
- A patient needs ≥ 3 visits to produce any training row, because the
rubric needs `next_prediction` + `next_evaluation` at visit `i+1`.
- Each row = one memory-update training instance: given
`h_{t-1}`, `x_t`, `y_hat_t`, `e_t`, rewrite the memory into `h_t`.
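The row-count rule reduces to a one-line range. `training_visit_indices` is an illustrative helper, not a function from the dumper; it simply encodes visit_index ∈ [0, n-3] for a patient with n visits.

```python
def training_visit_indices(n_visits: int) -> list:
    """visit_index values that yield a training row: 0 .. n_visits - 3 inclusive."""
    return list(range(max(n_visits - 2, 0)))


# Patients with 2, 3 and 6 visits contribute 0, 1 and 4 rows respectively.
visit_counts = {"a": 2, "b": 3, "c": 6}
rows_per_patient = {pid: len(training_visit_indices(n)) for pid, n in visit_counts.items()}
print(rows_per_patient)  # {'a': 0, 'b': 1, 'c': 4}
```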
Parquet row schema (matches verl's GSM8K example):
```
{
  "prompt": [{"role": "user", "content": "<rendered prompt_update template>"}],
  "data_source": "llm_as_rnn_memory_update",   // reward router uses this
  "ability": "clinical_summary",
  "reward_model": {
    "style": "rubric",
    "ground_truth": "<JSON-serialized TrajectoryStep>"   // contains x_{t+1}, y_hat_{t+1}, e_{t+1}
  },
  "extra_info": {"split", "patient_id", "visit_index", "visit_id"}
}
```
The `reward_model.ground_truth` field deserializes back into a full
`TrajectoryStep` (see `training/types.py`) at reward time; that's how the
judge sees the next-visit outcome.
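A sketch of building one row and recovering the step at reward time. `TrajectoryStep` here is a hypothetical stand-in whose field names are assumptions (the real dataclass lives in `training/types.py`), `make_row` and `step_from_row` are illustrative helpers, and writing rows to parquet (e.g. with pandas' `DataFrame.to_parquet`) is left out to keep the example dependency-free.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class TrajectoryStep:
    """Hypothetical stand-in for training/types.py; real field names may differ."""
    patient_id: str
    visit_index: int
    next_input: str       # x_{t+1}
    next_prediction: str  # y_hat_{t+1}
    next_evaluation: str  # e_{t+1}


def make_row(prompt_text: str, step: TrajectoryStep, visit_id: str, split: str = "train") -> dict:
    """Build one row in the GSM8K-style layout of the schema above."""
    return {
        "prompt": [{"role": "user", "content": prompt_text}],
        "data_source": "llm_as_rnn_memory_update",
        "ability": "clinical_summary",
        "reward_model": {"style": "rubric", "ground_truth": json.dumps(asdict(step))},
        "extra_info": {
            "split": split,
            "patient_id": step.patient_id,
            "visit_index": step.visit_index,
            "visit_id": visit_id,
        },
    }


def step_from_row(row: dict) -> TrajectoryStep:
    """Reward-time inverse: deserialize reward_model.ground_truth into a TrajectoryStep."""
    return TrajectoryStep(**json.loads(row["reward_model"]["ground_truth"]))


step = TrajectoryStep("p001", 0, "<visit 1 notes>", "<predicted diagnoses>", "<evaluation>")
row = make_row("<rendered prompt_update template>", step, visit_id="v1")
print(step_from_row(row) == step)  # True
```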
6. Workflow (end to end)
```text
data/splits/cleaned_df_train_100.json
        │
        │  [Phase B] dump_trajectories_to_parquet.py
        │   └── per (patient, visit_index ∈ [0..n-3]): emit one TrajectoryStep
        │
        ▼
$SCRATCH/data/llm_as_rnn/train.parquet
        │
   ┌────┴───────────────────┐
   │                        │
[judge server]         [verl trainer]
RubricARM-8B           verl.trainer.main_ppo
on cuda:2              actor FSDP cuda:0,1 TP=2
   │                        │
   └── HTTP /v1/completions ┘
                │
                ▼
GRPO inner loop per optimizer step:
  (a) sample 16 prompts from parquet
  (b) actor rollout (vLLM) → G=8 candidates per prompt
  (c) for each candidate:
        rubric_reward.compute_score()
          ├── parse evolving_summary from policy output
          ├── render rubric_v2 template
          │     with {memory_prompt, response, next_prediction,
          │           next_ground_truth, next_evaluation}
          └── HTTP call to judge → parse "Score: N" → N/9
  (d) GRPO group-relative advantage A (G=8 share a baseline)
        loss = -E[min(ρ·A, clip(ρ, 1-ε, 1+ε)·A)] + β · KL(π ‖ π_ref),  ρ = π/π_old
  (e) PPO mini-batch (16/16, micro=4/GPU) → optimizer step
        → FSDP sync → vLLM rollout worker pulls latest LoRA
                │
                ▼  (every save_freq=20 steps)
checkpoints/llmrnn_grpo/phase1_v2judge_run02/global_step_*/
                │
                ▼  (this artifact: step 340)
[merge_and_upload.sh]
  (1) verl.model_merger merge → gathers FSDP shards into a flat HF base dir
  (2) PEFT merge_and_unload() → base + lora_adapter (α=32) folded together
  (3) huggingface_hub upload → THIS REPO
```
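Step (d) can be sketched per token as below. The k3 (`low_var_kl`) form and β = 0.001 come from §3; the clip range ε = 0.2 is an assumption (the card does not list it), and verl's actual loss adds details omitted here, for example how per-token losses are aggregated and an extra dual-clip bound for negative advantages. `k3_kl` and `token_loss` are illustrative names.

```python
import math


def k3_kl(logp: float, ref_logp: float) -> float:
    """'low_var_kl' (k3) estimate of KL(pi || pi_ref) from one sampled token.

    With d = ref_logp - logp, k3 = exp(d) - d - 1, which is always >= 0.
    """
    d = ref_logp - logp
    return math.exp(d) - d - 1.0


def token_loss(logp, old_logp, ref_logp, adv, clip_eps=0.2, kl_beta=0.001):
    """PPO clipped surrogate plus beta-weighted k3 KL penalty for one token (to be minimized)."""
    ratio = math.exp(logp - old_logp)                            # rho = pi / pi_old
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)    # clip(rho, 1-eps, 1+eps)
    surrogate = min(ratio * adv, clipped * adv)                  # pessimistic PPO objective
    return -surrogate + kl_beta * k3_kl(logp, ref_logp)


# On-policy token with policy equal to the reference: ratio 1, KL 0.
on_policy = token_loss(logp=-1.0, old_logp=-1.0, ref_logp=-1.0, adv=1.5)
# Policy moved far: ratio e^0.5 ~ 1.65 is clipped to 1.2, plus a small KL term.
too_far = token_loss(logp=-0.5, old_logp=-1.0, ref_logp=-1.0, adv=1.5)
print(on_policy, round(too_far, 4))  # -1.5 -1.7999
```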
7. Reward distribution (sanity)
Early in training (first ~30 reward calls inspected from audit logs),
rewards spanned {0.22, 0.33, 0.44, 0.56, 0.67, 0.78, 0.89, 1.0}, i.e.
the judge produced real variance across `Score: 2` … `9`. Reward did not
collapse to a constant value, which was the failure mode we hit during the
rubric-v1 → v2 migration.
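A sketch of this sanity check over the audit logs, assuming each audit line is a JSON object with a `reward` field (an assumption about the log schema). `reward_values` and `looks_collapsed` are hypothetical helpers; `expected` lists what `Score: 2` … `9` map to after N/9 and rounding to two decimals.

```python
import glob
import json
import os
import tempfile


def reward_values(audit_dir):
    """Collect every logged reward from calls_*.jsonl files in audit_dir."""
    rewards = []
    for path in sorted(glob.glob(os.path.join(audit_dir, "calls_*.jsonl"))):
        with open(path, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    rewards.append(json.loads(line)["reward"])
    return rewards


def looks_collapsed(rewards, decimals=3):
    """True if all logged rewards round to a single value (reward collapse)."""
    return len({round(r, decimals) for r in rewards}) <= 1


# Score: N -> N/9, so Scores 2..9 appear as these rounded values.
expected = [round(n / 9, 2) for n in range(2, 10)]
print(expected)  # [0.22, 0.33, 0.44, 0.56, 0.67, 0.78, 0.89, 1.0]

# Demo on a throwaway audit directory.
demo_dir = tempfile.mkdtemp()
with open(os.path.join(demo_dir, "calls_123.jsonl"), "w", encoding="utf-8") as f:
    for r in (6 / 9, 3 / 9, 1.0):
        f.write(json.dumps({"reward": r}) + "\n")
demo_rewards = reward_values(demo_dir)
print(looks_collapsed(demo_rewards))  # False
```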
A representative judge call (audited in
`$SCRATCH/llm_as_rnn/rubric_audit/calls_<pid>.jsonl`) for patient
14573633, visit_index 0:
- candidate `h_t` = JSON with three sections (clinical history / current focus / future considerations)
- next visit ground truth = "post-bronchoscopy pneumothorax, hypertensive emergency"
- next prediction (run with this candidate memory) primary = "Pulmonary Embolism" (wrong)
- judge reasoned that the memory preserved a useful anemia rule but added
distracting cardiac focus that hurt the next primary → `Score: 6` → reward `0.667`.
8. Intended use
Drop this model into the LLM-as-RNN inference pipeline as the `memory_model`:
```yaml
# In configs/experiment*.yaml under memory_model:
memory_model:
  backend: vllm
  model_name: jinrui123/llmrnn-grpo-phase1-v2judge-run02-global-step-340-merged
  temperature: 0.0
  max_tokens: 1024
```
It expects the exact `prompt_update` template defined in
`configs/experiment1_seperateJudge.yaml` (the one in the LLM-as-RNN repo).
Using a different prompt structure (e.g. open-ended chat) is unsupported
and likely produces malformed JSON output.
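When wiring the model into the pipeline, a cheap guard on its output catches the malformed-JSON failure mode early. This sketch assumes the memory model returns a JSON object with an `evolving_summary` field (the field the reward function parses); `parse_memory_update` and `is_valid_memory` are hypothetical helpers, not part of the LLM-as-RNN repo.

```python
import json


def parse_memory_update(raw: str) -> str:
    """Return the evolving summary from the memory model's raw output, or raise ValueError."""
    try:
        obj = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError(f"memory model returned malformed JSON: {exc}") from exc
    if not isinstance(obj, dict) or "evolving_summary" not in obj:
        raise ValueError("memory model output has no 'evolving_summary' field")
    summary = obj["evolving_summary"]
    # The summary may itself be structured (e.g. a multi-section JSON object).
    return summary if isinstance(summary, str) else json.dumps(summary, ensure_ascii=False)


def is_valid_memory(raw: str) -> bool:
    """Convenience predicate: True if parse_memory_update accepts the output."""
    try:
        parse_memory_update(raw)
    except ValueError:
        return False
    return True


print(is_valid_memory('{"evolving_summary": "stable CKD, new anemia workup"}'))  # True
print(is_valid_memory("Sure! Here is the updated summary: ..."))               # False
```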
9. Limitations and caveats (read before using)
- No held-out validation during training. `val_parquet` was not dumped before launch, so the launcher silently fell back to `val = train` and disabled validation. Reward curves are train-side only.
- Small training set. Only 100 patients (`cleaned_df_train_100`). Behavior on the 6288-patient test split is untested.
- Judge is coerced into scalar mode. `RubricARM-8B-Judge` is natively a pairwise model (Response A vs B). We force a scalar response via prompt engineering and a `Score: N` parser. Absolute reward values are therefore not directly comparable across rubric versions (v1 vs v2, etc.).
- English MIMIC-IV discharge-summary text only. Research artifact, no clinical-deployment guarantees. Do not use for real patient care.
- verl's `model_merger merge` does not actually merge LoRA. It pops the adapter into a side-by-side `lora_adapter/` subdir and rewrites `lora_alpha=0`. We work around this with a real PEFT `merge_and_unload()` step using the original adapter (alpha=32); see `training/verl_adapter/apptainer/merge_and_upload.sh`. This artifact is the post-PEFT-merge version, so the LoRA is applied.
- No safety alignment, no PHI scrubbing. MIMIC-IV is already de-identified upstream, but no additional safety filtering was applied during GRPO.
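A pre-merge guard for the `lora_alpha=0` problem can be sketched by reading PEFT's `adapter_config.json` (which stores `r` and `lora_alpha`) before calling `merge_and_unload()`. The helper names are hypothetical and are not taken from `merge_and_upload.sh`; rsLoRA adapters, whose scaling is `lora_alpha / sqrt(r)`, are deliberately rejected rather than handled.

```python
import json
import tempfile
from pathlib import Path


def lora_scaling(adapter_dir) -> float:
    """Read adapter_config.json and return lora_alpha / r (standard, non-rsLoRA scaling)."""
    cfg = json.loads((Path(adapter_dir) / "adapter_config.json").read_text(encoding="utf-8"))
    if cfg.get("use_rslora"):
        raise ValueError("rsLoRA uses lora_alpha / sqrt(r); not handled in this sketch")
    return cfg["lora_alpha"] / cfg["r"]


def check_adapter_effective(adapter_dir) -> float:
    """Refuse to merge an adapter whose scaling is zero (e.g. lora_alpha rewritten to 0)."""
    scaling = lora_scaling(adapter_dir)
    if scaling == 0:
        raise RuntimeError(f"{adapter_dir}: lora_alpha is 0, merging would leave base weights unchanged")
    return scaling


# Demo: an intact adapter config and one with lora_alpha rewritten to 0.
good_dir, bad_dir = tempfile.mkdtemp(), tempfile.mkdtemp()
Path(good_dir, "adapter_config.json").write_text(json.dumps({"r": 16, "lora_alpha": 32}), encoding="utf-8")
Path(bad_dir, "adapter_config.json").write_text(json.dumps({"r": 16, "lora_alpha": 0}), encoding="utf-8")

good_scaling = check_adapter_effective(good_dir)
try:
    check_adapter_effective(bad_dir)
    bad_rejected = False
except RuntimeError:
    bad_rejected = True
print(good_scaling, bad_rejected)  # 2.0 True
```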
10. Reproducibility
Source code for re-deriving everything in this card:
| component | path (in jinrui/LLMasRNN at SHA aefe978) |
|---|---|
| launcher | training/verl_adapter/apptainer/run_phase1.sh |
| rubric (judge prompt + model + serving) | training/configs/rubric_v2_rubricARM_scalar.yaml |
| reward function | training/verl_adapter/rubric_reward.py |
| parquet dumper | training/verl_adapter/dump_trajectories_to_parquet.py |
| trajectory dataclass | training/types.py |
| RLN forward pass | models/rln_core.py |
| inference YAML (consumes this model) | configs/experiment1_seperateJudge.yaml |
| merge + upload | training/verl_adapter/apptainer/merge_and_upload.sh |
Auto-generated companion files:
- `logs/MANIFEST_phase1_v2judge_run02.txt`: machine-readable manifest (git SHA, judge model, all GRPO knobs)
- `logs/TRAINING_REPORT_phase1_v2judge_run01.md`: human-facing report for the predecessor run; setup is identical to run02.