llmrnn-grpo-phase1-v2judge-run02-global-step-340-merged

GRPO-trained memory-update head for the LLM-as-RNN clinical pipeline (reproducing arXiv 2601.13352). The LoRA adapter has been folded into the base model with peft.merge_and_unload(), so this is a stand-alone HF model: load it like any other Llama-3.2-3B variant, with no PEFT runtime needed.

Role: this is not a general-purpose chat model. It is fine-tuned to rewrite a natural-language "evolving summary" (the hidden state h_t in the LLM-as-RNN setup) after each patient visit. Use it as the memory_model in the 4-model LLM-as-RNN pipeline.


1. Run identity

  • experiment: phase1_v2judge_run02
  • source checkpoint: global_step_340 (latest checkpoint of run02)
  • launched: 2026-05-13 04:50:26 UTC on host c301-001.ls6.tacc.utexas.edu (TACC Lonestar6)
  • git SHA at launch: aefe978 (clean)
  • container: verlai/verl-vllm017 (verl 0.7.1, torch 2.10, vllm 0.17)
  • GPU layout: 2× A100 40GB for actor + rollout (cuda:0,1, TP=2); 1× A100 40GB for the judge (cuda:2)
  • total optimizer steps: 340 (≈ 20 epochs over the train parquet)

2. Base model + LoRA

  • Base: meta-llama/Llama-3.2-3B-Instruct
  • LoRA: r=16, α=32, dropout=0, bias=none
  • Target modules: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj (every linear layer inside attention and MLP, i.e. a full-coverage adapter)
  • Merged: yes. lora_alpha = 32 was preserved when folding the adapter back into the base. Pure HF weights in this repo; no lora_adapter/ subdir.
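Folding a LoRA adapter into a linear layer adds the scaled low-rank product to the frozen weight, W' = W + (α/r)·B·A; with r=16 and α=32 the scale is 2. The stdlib sketch below (hypothetical helper names, plain lists standing in for tensors) illustrates why the merge must keep α=32: if lora_alpha is rewritten to 0, as described for verl's model_merger in §9, the update vanishes and the "merged" weights equal the base weights.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product (stand-in for a tensor op)."""
    inner = len(b)
    assert all(len(row) == inner for row in a), "shape mismatch"
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora_weight(w: Matrix, lora_a: Matrix, lora_b: Matrix, alpha: float, r: int) -> Matrix:
    """Hypothetical helper: return W + (alpha / r) * B @ A (standard, non-rsLoRA scaling).

    Shapes: w is (out, in), lora_a is (r, in), lora_b is (out, r).
    """
    scale = alpha / r
    delta = matmul(lora_b, lora_a)
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


CARD_SCALE = 32 / 16  # this card: alpha=32, r=16 -> scale 2.0

# Toy 2x2 layer with a rank-1 adapter; alpha=2, r=1 gives the same scale (2.0) as the card.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]    # (r=1, in=2)
B = [[0.5], [0.0]]  # (out=2, r=1)
merged = merge_lora_weight(W, A, B, alpha=2.0, r=1)  # adapter actually applied
no_op = merge_lora_weight(W, A, B, alpha=0.0, r=1)   # alpha=0: base weights come back unchanged
```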

3. GRPO setup

  • LR: 3e-6 (AdamW, no LR schedule)
  • KL β: 0.001, applied inside the actor loss (use_kl_loss=True, use_kl_in_reward=False)
  • KL estimator: low_var_kl (k3)
  • entropy coef / grad clip: 0 / 1.0
  • group size G: 8 rollouts per prompt
  • train_batch / ppo_mini_batch / micro_batch_per_gpu: 16 / 16 / 4 → 1 mini-batch per step
  • rollout temperature / top-p: 0.8 / 0.95
  • max prompt / response length: 3072 / 512 tokens
  • epochs / save_freq / test_freq: 20 / 20 / 5
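The batch knobs above interact multiplicatively. The arithmetic below is a sketch that assumes verl counts ppo_mini_batch in prompts, expands it by G, and splits the resulting sequences across the actor's two data-parallel ranks (cuda:0,1); under that assumption each optimizer step consumes 128 sequences with 16 gradient-accumulation micro-steps per GPU, and save_freq=20 makes step 340 the 17th checkpoint.

```python
# Knobs copied from the tables in §1 and §3.
TRAIN_BATCH_PROMPTS = 16      # train_batch
PPO_MINI_BATCH_PROMPTS = 16   # ppo_mini_batch (assumption: counted in prompts)
MICRO_BATCH_PER_GPU = 4       # micro_batch_per_gpu (counted in sequences)
GROUP_SIZE = 8                # G rollouts per prompt
ACTOR_DP_WORLD_SIZE = 2       # assumption: actor FSDP data-parallel over cuda:0,1
TOTAL_STEPS = 340
EPOCHS = 20
SAVE_FREQ = 20

sequences_per_step = TRAIN_BATCH_PROMPTS * GROUP_SIZE                    # 128
mini_batches_per_step = TRAIN_BATCH_PROMPTS // PPO_MINI_BATCH_PROMPTS     # 1
sequences_per_gpu_per_mini_batch = (
    PPO_MINI_BATCH_PROMPTS * GROUP_SIZE // ACTOR_DP_WORLD_SIZE            # 64
)
grad_accum_micro_steps = sequences_per_gpu_per_mini_batch // MICRO_BATCH_PER_GPU  # 16
steps_per_epoch = TOTAL_STEPS / EPOCHS        # 17.0; the card says "≈ 20 epochs"
approx_prompts_per_epoch = round(steps_per_epoch) * TRAIN_BATCH_PROMPTS  # roughly 272
checkpoints_written = TOTAL_STEPS // SAVE_FREQ                           # 17
```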

Algorithm: standard verl 0.7.1 GRPO, i.e. a group-relative advantage (no critic; the 8 candidates per prompt share a common baseline), a PPO-clipped surrogate, and a KL anchor to the frozen base model computed via model.disable_adapter().
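The group-relative advantage can be sketched in a few lines of stdlib Python. This assumes the common GRPO normalization (subtract the group mean, divide by the group standard deviation plus a small epsilon) and uses the sample standard deviation; verl's exact epsilon and edge-case handling are not confirmed here, and grpo_advantages is a hypothetical name.

```python
import statistics
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence


def grpo_advantages(
    rewards: Sequence[float], group_ids: Sequence[Hashable], eps: float = 1e-6
) -> List[float]:
    """Per-rollout advantage: (r - group mean) / (group std + eps).

    Rollouts sharing a group_id (the G=8 candidates for one prompt) share a baseline,
    so no learned critic is needed. The scalar is later broadcast to every response token.
    """
    if len(rewards) != len(group_ids):
        raise ValueError("rewards and group_ids must have the same length")
    groups: Dict[Hashable, List[float]] = defaultdict(list)
    for reward, gid in zip(rewards, group_ids):
        groups[gid].append(reward)

    stats = {}
    for gid, group_rewards in groups.items():
        if len(group_rewards) < 2:
            raise ValueError(f"group {gid!r} needs at least 2 rollouts for a std")
        stats[gid] = (statistics.fmean(group_rewards), statistics.stdev(group_rewards))

    return [
        (reward - stats[gid][0]) / (stats[gid][1] + eps)
        for reward, gid in zip(rewards, group_ids)
    ]


# Hypothetical judge scores for the 8 candidates of one prompt, mapped to N/9.
scores = [6, 6, 7, 5, 9, 3, 6, 6]
adv = grpo_advantages([s / 9 for s in scores], ["prompt-0"] * 8)
```

Because every candidate is compared only with its siblings, a prompt on which all 8 candidates score the same contributes zero advantage and hence no policy-gradient signal.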

Training was run with verl's stock trainer (verl.trainer.main_ppo, algorithm.adv_estimator=grpo); only the reward function is custom (see §4).


4. Reward: rubric v2 (outcome-aware scalar judge)

The reward is a scalar judge call made on every rollout, returning Score: N with N ∈ {1..9}, which is mapped to reward N/9 ∈ (0, 1].

  • Judge model: OpenRubrics/RubricARM-8B-Judge, forced into scalar mode via a custom prompt template. RubricARM is natively pairwise, but our prompt asks for a single 1-9 digit and the parser keys off the Score: N line.
  • Judge serving: vLLM on localhost:8001, T=0, max_tokens=1500, max_model_len=12000, gpu_memory_utilization=0.9 on cuda:2.
  • Source-of-truth YAML: training/configs/rubric_v2_rubricARM_scalar.yaml (SHA1 e2f89fec…). This single file controls rubric content, judge model, and serving config, so the reward function and start_judge_server.sh cannot drift apart.
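Each reward call is an HTTP request to the judge's /v1/completions endpoint. The sketch below builds such a request with the serving settings listed above, using only the standard library. It assumes vLLM's OpenAI-compatible server is running and serves the judge under its Hugging Face id; build_judge_request and judge_text are hypothetical names, not the functions in rubric_reward.py.

```python
import json
import urllib.request

JUDGE_URL = "http://localhost:8001/v1/completions"  # from the serving config above
JUDGE_MODEL = "OpenRubrics/RubricARM-8B-Judge"      # assumption: served under its HF id


def build_judge_request(judge_prompt: str) -> urllib.request.Request:
    """Build (but do not send) a POST to vLLM's OpenAI-compatible completions endpoint."""
    body = {
        "model": JUDGE_MODEL,
        "prompt": judge_prompt,
        "temperature": 0,    # T=0: deterministic grading
        "max_tokens": 1500,  # room for the judge's reasoning before "Score: N"
    }
    return urllib.request.Request(
        JUDGE_URL,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def judge_text(request: urllib.request.Request, timeout_s: float = 120.0) -> str:
    """Send the request and return the first completion's text (needs a live server)."""
    with urllib.request.urlopen(request, timeout=timeout_s) as resp:
        payload = json.load(resp)
    return payload["choices"][0]["text"]
```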

Why "outcome-aware": unlike a blind judge, this rubric shows the judge the next visit's prediction, ground truth, and evaluation alongside the candidate memory. The judge then grades the memory by whether it actually moved the next-visit primary-diagnosis ranking up or down, not by abstract "summary quality". The intent is for reward to track downstream accuracy rather than stylistic polish.

Edge handling (training/verl_adapter/rubric_reward.py):

  • policy output is not valid JSON → reward 0.05 (gradient still flows; no NaN)
  • judge response has a Score: N line → reward N/9
  • judge response is short (<20 chars) and contains a digit → reward (last digit)/9
  • judge response is long and has no Score: N line → reward 0.5 (refuse to parse; this avoids the rubric-v1 bug where the parser grabbed dimension numbers out of CoT text)
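The edge-handling rules above can be expressed as a small pure function. This is a sketch under assumptions, not the code in rubric_reward.py: reward_from_outputs is a hypothetical name, it only checks that the policy output parses as JSON (it does not extract evolving_summary), it takes the last Score: N line if there are several, and it falls back to 0.5 for a short response with no digit, a case the table does not cover.

```python
import json
import re
from typing import Optional

INVALID_JSON_REWARD = 0.05  # small floor: gradient still flows, no NaN
UNPARSEABLE_REWARD = 0.5    # neutral value when the judge output cannot be trusted
SCORE_LINE = re.compile(r"^\s*Score:\s*([1-9])\b", re.MULTILINE)


def reward_from_outputs(policy_output: str, judge_response: Optional[str]) -> float:
    """Map (policy output, judge response) to a scalar reward per the edge-handling rules."""
    try:
        json.loads(policy_output)
    except json.JSONDecodeError:
        return INVALID_JSON_REWARD  # the judge call can be skipped entirely
    text = judge_response or ""
    matches = SCORE_LINE.findall(text)
    if matches:
        return int(matches[-1]) / 9  # assumption: the last "Score: N" line wins
    digits = re.findall(r"[1-9]", text)
    if len(text.strip()) < 20 and digits:
        return int(digits[-1]) / 9  # bare short answer such as "7"
    # Long CoT without a Score line (or short with no digit): do not guess.
    return UNPARSEABLE_REWARD
```

Returning a neutral 0.5 instead of scraping digits out of long text is the design choice that avoids the rubric-v1 failure, where dimension numbers in the judge's reasoning were mistaken for scores.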

Audit trail: every reward call is appended to $SCRATCH/llm_as_rnn/rubric_audit/calls_<pid>.jsonl (one file per worker PID), recording the full rollout text, judge prompt, judge response, and reward. These logs are used to spot judge format-compliance bugs early.
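A per-PID JSON Lines log can be sketched with the standard library. Writing one file per worker process presumably keeps concurrent reward workers from interleaving partial lines in a shared file. append_audit_record and the record field names are hypothetical, and the demo writes to a temporary directory instead of $SCRATCH.

```python
import json
import os
import tempfile
from pathlib import Path


def append_audit_record(audit_dir: str, record: dict) -> Path:
    """Append one reward call as a single JSON line to calls_<pid>.jsonl."""
    path = Path(audit_dir) / f"calls_{os.getpid()}.jsonl"
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return path


# Usage: one hypothetical record written to a throwaway directory.
demo_path = append_audit_record(
    tempfile.mkdtemp(),
    {"rollout": "{...}", "judge_prompt": "...", "judge_response": "Score: 6", "reward": 6 / 9},
)
```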


5. Data

  • Source: data/splits/cleaned_df_train_100.json (100 MIMIC-IV patients, randomly sampled with seed 42 from the cleaned 6488-patient pool).
  • This is the training split, not the test split. Generalization to the 6288-patient held-out test set is not measured here.
  • Cleaning: any patient with at least one visit that has an empty discharge_diagnosis was dropped before splitting.

From patients to training samples (Phase B)

training/verl_adapter/dump_trajectories_to_parquet.py runs the full 4-model LLM-as-RNN forward pass over each patient and emits one parquet row per (patient, visit_index ∈ [0, n-3]), where n is the patient's number of visits:

  • A patient needs ≥ 3 visits to produce any training row, because the rubric needs next_prediction + next_evaluation at visit i+1.
  • Each row = one memory-update training instance: given h_{t-1}, x_t, y_hat_t, e_t, rewrite the memory into h_t.
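The row-emission rule above reduces to a one-line range. A minimal sketch (training_visit_indices is a hypothetical name): a patient with n visits contributes max(0, n-2) rows, one per visit_index from 0 to n-3 inclusive.

```python
from typing import List


def training_visit_indices(n_visits: int) -> List[int]:
    """visit_index values that yield a training row: 0 .. n-3 inclusive (none if n < 3)."""
    return list(range(max(0, n_visits - 2)))


# Rows contributed by hypothetical patients with 1, 2, 3, and 5 visits.
rows_per_patient = {n: len(training_visit_indices(n)) for n in (1, 2, 3, 5)}
```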

Parquet row schema (matches verl's GSM8K example):

{
  "prompt": [{"role": "user", "content": "<rendered prompt_update template>"}],
  "data_source": "llm_as_rnn_memory_update",          // reward router uses this
  "ability": "clinical_summary",
  "reward_model": {
    "style": "rubric",
    "ground_truth": "<JSON-serialized TrajectoryStep>" // contains x_{t+1}, y_hat_{t+1}, e_{t+1}
  },
  "extra_info": {"split": "<split>", "patient_id": "<id>", "visit_index": "<int>", "visit_id": "<id>"}
}

The reward_model.ground_truth field deserializes back into a full TrajectoryStep (see training/types.py) at reward time; this is how the judge sees the next-visit outcome.
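The serialize-then-deserialize round trip for reward_model.ground_truth can be sketched with a dataclass and json. The real TrajectoryStep lives in training/types.py and is not reproduced in this card, so NextVisitSketch and its field names are hypothetical stand-ins, as are make_row, load_step, and the extra_info values.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class NextVisitSketch:
    """Hypothetical stand-in for TrajectoryStep; the real fields may differ."""
    next_input: str         # x_{t+1}
    next_prediction: str    # y_hat_{t+1}
    next_evaluation: str    # e_{t+1}
    next_ground_truth: str


def make_row(prompt_text: str, step: NextVisitSketch, patient_id: str,
             visit_index: int, visit_id: str) -> dict:
    """Build one parquet row in the schema above; ground_truth is stored as a JSON string."""
    return {
        "prompt": [{"role": "user", "content": prompt_text}],
        "data_source": "llm_as_rnn_memory_update",  # reward router keys off this
        "ability": "clinical_summary",
        "reward_model": {"style": "rubric", "ground_truth": json.dumps(asdict(step))},
        "extra_info": {"split": "train", "patient_id": patient_id,
                       "visit_index": visit_index, "visit_id": visit_id},
    }


def load_step(row: dict) -> NextVisitSketch:
    """Reward-time inverse: rebuild the dataclass from reward_model.ground_truth."""
    return NextVisitSketch(**json.loads(row["reward_model"]["ground_truth"]))
```

Storing the step as a JSON string keeps the parquet column a plain string, so the reward function only needs json.loads to recover every next-visit field.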


6. Workflow (end to end)

data/splits/cleaned_df_train_100.json
            │
[Phase B] dump_trajectories_to_parquet.py
            └── per (patient, visit_index ∈ [0..n-3]): emit one TrajectoryStep
            │
            ▼
$SCRATCH/data/llm_as_rnn/train.parquet
            │
   ┌────────┴─────────┐
   │                  │
[judge server]   [verl trainer]
RubricARM-8B     verl.trainer.main_ppo
on cuda:2        actor FSDP cuda:0,1 TP=2
   │                  │
   └─ HTTP /v1/completions ─┘
            │
            ▼
GRPO inner loop per optimizer step:
  (a) sample 16 prompts from parquet
  (b) actor rollout (vLLM) → G=8 candidates per prompt
  (c) for each candidate:
        rubric_reward.compute_score()
          ├── parse evolving_summary from policy output
          ├── render rubric_v2 template
          │     with {memory_prompt, response, next_prediction,
          │           next_ground_truth, next_evaluation}
          └── HTTP call to judge → parse "Score: N" → N/9
  (d) GRPO group-relative advantage (G=8 share a baseline)
        loss = -E[min(ρ · adv, clip(ρ, 1-ε, 1+ε) · adv)] + β · KL(π ‖ π_ref),  ρ = π/π_old
  (e) PPO mini-batch (16/16, micro=4/GPU) → optimizer step
        → FSDP sync → vLLM rollout worker pulls latest LoRA
            │
            ▼ (every save_freq=20 steps)
checkpoints/llmrnn_grpo/phase1_v2judge_run02/global_step_*/
            │
            ▼ (this artifact: step 340)
[merge_and_upload.sh]
  (1) verl.model_merger merge   → gathers FSDP shards into a flat HF base dir
  (2) PEFT merge_and_unload()   → base + lora_adapter(α=32) folded together
  (3) huggingface_hub upload    → THIS REPO
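Step (d) compresses the actor objective into one line. Written out, and assuming the usual std-normalized GRPO advantage and per-token averaging (both assumptions; neither is stated in this card), the loss looks roughly like this. Here ε is the PPO clip ratio, whose value is not listed in this card, and k3 is the low_var_kl estimator named in §3, i.e. r - log r - 1 with r = π_ref/π_θ.

```latex
\rho_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\mathrm{old}}}(o_{i,t} \mid q, o_{i,<t})},
\qquad
\hat{A}_i = \frac{R_i - \operatorname{mean}(R_1,\dots,R_G)}{\operatorname{std}(R_1,\dots,R_G) + \epsilon_{\mathrm{std}}},
\qquad
R_i = N_i / 9

k_3(i,t) = \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}
         - \log \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1

\mathcal{L}(\theta) = -\,\mathbb{E}_{i,t}\!\left[\min\!\left(\rho_{i,t}\hat{A}_i,\;
    \operatorname{clip}(\rho_{i,t}, 1-\varepsilon, 1+\varepsilon)\,\hat{A}_i\right)\right]
  + \beta\,\mathbb{E}_{i,t}\!\left[k_3(i,t)\right],
\qquad G = 8,\quad \beta = 0.001
```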

7. Reward distribution (sanity)

Early in training (the first ~30 reward calls inspected in the audit logs), rewards spanned {0.22, 0.33, 0.44, 0.56, 0.67, 0.78, 0.89, 1.0}, i.e. the judge produced real variance across Score: 2 … 9. Reward did not collapse to a constant value, which was the failure mode we hit during the rubric-v1 → v2 migration.

A representative judge call (audited in $SCRATCH/llm_as_rnn/rubric_audit/calls_<pid>.jsonl) for patient 14573633, visit_index 0:

  • candidate h_t = JSON with three sections (clinical history / current focus / future considerations)
  • next visit ground truth = "post-bronchoscopy pneumothorax, hypertensive emergency"
  • next prediction (run with this candidate memory) primary = "Pulmonary Embolism" → wrong
  • judge reasoned that the memory preserved a useful anemia rule but added distracting cardiac focus that hurt the next primary → Score: 6 → reward 0.667.

8. Intended use

Drop this model into the LLM-as-RNN inference pipeline as the memory_model:

# In configs/experiment*.yaml under memory_model:
memory_model:
  backend: vllm
  model_name: jinrui123/llmrnn-grpo-phase1-v2judge-run02-global-step-340-merged
  temperature: 0.0
  max_tokens: 1024

It expects the exact prompt_update template defined in configs/experiment1_seperateJudge.yaml (the one in the LLM-as-RNN repo). Using a different prompt structure (e.g. open-ended chat) is unsupported and likely produces malformed JSON output.


9. Limitations and caveats (read before using)

  • No held-out validation during training. val_parquet was not dumped before launch, so the launcher silently fell back to val = train and disabled validation. Reward curves are train-side only.
  • Small training set. Only 100 patients (cleaned_df_train_100). Behavior on the 6288-patient test split is untested.
  • Judge is coerced into scalar mode. RubricARM-8B-Judge is natively a pairwise model (Response A vs B). We force a scalar response via prompt engineering and a Score: N parser. Absolute reward values are therefore not directly comparable across rubric versions (v1 vs v2, etc.).
  • English MIMIC-IV discharge-summary text only. Research artifact, no clinical-deployment guarantees. Do not use for real patient care.
  • verl's model_merger merge does not actually merge LoRA. It saves the adapter into a side-by-side lora_adapter/ subdir and rewrites lora_alpha=0. We work around this with a real PEFT merge_and_unload() step using the original adapter (alpha=32); see training/verl_adapter/apptainer/merge_and_upload.sh. This artifact is the post-PEFT-merge version, so the LoRA update is already folded into the weights.
  • No safety alignment, no PHI scrubbing. MIMIC-IV is already de-identified upstream, but no additional safety filtering was applied during GRPO.

10. Reproducibility

Source code for re-deriving everything in this card:

All paths are in jinrui/LLMasRNN at SHA aefe978:

  • launcher: training/verl_adapter/apptainer/run_phase1.sh
  • rubric (judge prompt + model + serving): training/configs/rubric_v2_rubricARM_scalar.yaml
  • reward function: training/verl_adapter/rubric_reward.py
  • parquet dumper: training/verl_adapter/dump_trajectories_to_parquet.py
  • trajectory dataclass: training/types.py
  • RLN forward pass: models/rln_core.py
  • inference YAML (consumes this model): configs/experiment1_seperateJudge.yaml
  • merge + upload: training/verl_adapter/apptainer/merge_and_upload.sh

Auto-generated companion files:

  • logs/MANIFEST_phase1_v2judge_run02.txt: machine-readable manifest (git SHA, judge model, all GRPO knobs)
  • logs/TRAINING_REPORT_phase1_v2judge_run01.md: human-facing report for the predecessor run, whose setup is identical to run02.

Weights: Safetensors, 3B params, tensor type BF16.