llmrnn-grpo-phase1-v2judge-step340-merged

GRPO-trained memory-update head for the LLM-as-RNN (RLN) system, applied to discharge-diagnosis prediction on MIMIC-IV. Base = meta-llama/Llama-3.2-3B-Instruct; LoRA (r=16, α=32) trained with verl 0.7.1, then merged back into the base weights so the model loads directly in vLLM or HF transformers.

This is the memory_model slot in the four-model RLN architecture (prediction_model + memory_model + evaluation_model + final_judge_model). It rewrites the natural-language hidden state h_t from (h_{t-1}, current_visit, prediction, eval_feedback).
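The recurrence described above can be sketched as a plain loop. This is a minimal illustration, not the repository's implementation: the three callables, their argument order, and the predict → evaluate → update ordering are assumptions, and the final_judge_model is left out.

```python
from typing import Callable, List


def roll_memory(
    visits: List[str],
    predict: Callable[[str, str], str],             # hypothetical: (h_prev, visit) -> prediction
    evaluate: Callable[[str, str], str],            # hypothetical: (prediction, visit) -> eval_feedback
    update_memory: Callable[[str, str, str, str], str],  # the memory_model slot
    h0: str = "",
) -> List[str]:
    """Return the trace of natural-language hidden states [h_1, ..., h_T]."""
    h = h0
    trace = []
    for visit in visits:
        prediction = predict(h, visit)
        feedback = evaluate(prediction, visit)
        # h_t is rewritten from (h_{t-1}, current_visit, prediction, eval_feedback).
        h = update_memory(h, visit, prediction, feedback)
        trace.append(h)
    return trace
```

In this sketch the model in this repository would sit behind `update_memory`; the other two callables stand in for the prediction and evaluation models.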

Run identity

| field | value |
| --- | --- |
| project / experiment | llmrnn_grpo / phase1_v2judge_run01 |
| checkpoint step | global_step_340 |
| repo git SHA at launch | 3cca246 (clean) |
| launched | 2026-05-09T16:16:07Z |
| host | c304-001.ls6.tacc.utexas.edu (Lonestar6) |
| launcher | training/verl_adapter/apptainer/run_phase1.sh |

Training setup

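Algorithm: GRPO (Group Relative Policy Optimization, a critic-free PPO variant that scores each rollout relative to the other rollouts for the same prompt) via verl release/v0.7.1, on a single node with 2×A100-40GB on TACC Lonestar6, using the Apptainer image verlai/verl-vllm017.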
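Because GRPO has no critic, the baseline for each rollout comes from the other rollouts sampled for the same prompt. A minimal sketch of that group-relative advantage follows. The function name and `eps` value are illustrative, and this version uses the population standard deviation; a given implementation may use the sample standard deviation instead.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Normalize the rewards of one prompt's rollout group to zero mean, unit scale."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption of this sketch
    # eps keeps the division finite when every rollout gets the same reward.
    return [(r - mu) / (sigma + eps) for r in rewards]
```

With the group size of 8 used in this run, each prompt would contribute eight such advantages per step. A group whose rollouts all receive the same reward yields all-zero advantages and so carries no policy-gradient signal.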

Policy backbone: meta-llama/Llama-3.2-3B-Instruct + LoRA, target modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj.

| GRPO knob | value |
| --- | --- |
| LR | 3e-6 |
| KL anchor coefficient (β) | 0.001 |
| KL estimator | low_var_kl (k3) |
| KL location | actor loss only (use_kl_loss=True, use_kl_in_reward=False) |
| Entropy coef | 0 |
| Grad clip | 1.0 |
| Group size G | 8 rollouts/prompt |
| Train batch | 16 prompts/step |
| PPO mini-batch | 16 (= 1 mini-batch per step) |
| PPO micro-batch / GPU | 4 |
| Rollout temp / top-p | 0.8 / 0.95 |
| Rollout TP | 2 |
| Rollout GPU mem util | 0.6 |
| Max prompt / response len | 3072 / 512 tokens |
| Total epochs | 20 (save every 10, test every 5) |
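The k3 estimator named in the KL row is the low-variance, always non-negative estimator of KL(π ‖ π_ref) computed from samples of the current policy. The sketch below shows the per-token formula and a β-weighted mean penalty. The function names are hypothetical, and the mean-over-tokens aggregation is an assumption rather than a statement about how verl reduces the loss.

```python
import math
from typing import Sequence


def k3_kl(logp_policy: float, logp_ref: float) -> float:
    """Per-token k3 estimate: (r - 1) - log r, with r = pi_ref / pi_policy."""
    log_ratio = logp_ref - logp_policy
    return math.exp(log_ratio) - log_ratio - 1.0


def kl_penalty(
    logps_policy: Sequence[float],
    logps_ref: Sequence[float],
    beta: float = 0.001,  # the KL anchor coefficient from the table
) -> float:
    """Beta-weighted mean of the per-token k3 estimates (aggregation is assumed)."""
    per_token = [k3_kl(p, q) for p, q in zip(logps_policy, logps_ref)]
    return beta * sum(per_token) / len(per_token)
```

The estimate is zero when the two log-probabilities match and grows as the policy drifts from the reference in either direction, which is what makes it usable as an anchor term in the actor loss.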

LoRA: r=16, α=32, dropout=0.0, bias=none. Merged back into base for this artifact.
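For reference, merging folds the low-rank update into each targeted weight matrix. Assuming the standard LoRA scaling of α/r (not the rank-stabilized variant), with A and B denoting the trained low-rank factors and W_0 the base weight, the merged weight is:

```latex
W_{\text{merged}} = W_0 + \frac{\alpha}{r}\, B A
                  = W_0 + \frac{32}{16}\, B A
                  = W_0 + 2\, B A,
\qquad B \in \mathbb{R}^{d_{\text{out}} \times 16},\;
       A \in \mathbb{R}^{16 \times d_{\text{in}}}
```

After this step the checkpoint contains only dense weights of the same shape as the base model, which is why no adapter files are needed at load time.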

Reward function

The reward is produced by training/verl_adapter/rubric_reward.py calling a vLLM-served scalar rubric judge over HTTP.

| field | value |
| --- | --- |
| rubric YAML | training/configs/rubric_v2_rubricARM_scalar.yaml |
| rubric YAML SHA1 | e2f89fecfbcce51b98fcaad84a4b83128cb5c64d |
| judge model | OpenRubrics/RubricARM-8B-Judge |
| judge temperature | 0.0 (deterministic) |
| judge max_tokens | 1500 (CoT-style budget) |
| judge endpoint (during run) | http://localhost:8001 |
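A judge call with these settings might look like the sketch below. It assumes the judge is exposed through vLLM's OpenAI-compatible chat completions route and that the judge prompt is sent as a single user message; neither detail is confirmed by this card, and `build_judge_request` and `call_judge` are hypothetical names, not functions from rubric_reward.py.

```python
import json
import urllib.request

# Assumed route on the judge endpoint listed above.
JUDGE_URL = "http://localhost:8001/v1/chat/completions"


def build_judge_request(judge_prompt: str) -> dict:
    """Build the JSON body using the judge settings from the table."""
    return {
        "model": "OpenRubrics/RubricARM-8B-Judge",
        "messages": [{"role": "user", "content": judge_prompt}],
        "temperature": 0.0,   # deterministic judging
        "max_tokens": 1500,   # room for chain-of-thought before the score
    }


def call_judge(judge_prompt: str, url: str = JUDGE_URL, timeout: float = 120.0) -> str:
    """POST the request and return the judge's raw text output."""
    body = json.dumps(build_judge_request(judge_prompt)).encode("utf-8")
    req = urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        payload = json.load(resp)
    return payload["choices"][0]["message"]["content"]
```

`call_judge` needs a running judge server; `build_judge_request` can be inspected on its own to check that the temperature and token budget match the run configuration.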

v2 = outcome-aware scalar rubric: the judge reads the candidate evolving-summary together with the next visit's prediction + evaluation, so the reward measures whether the new memory actually helps downstream, not just whether it sounds good in isolation.

Caveat carried over from the earlier rubric-v1 work: the reward parser refuses to pull digits out of long, verbose judge output and instead expects a strict Score: N line (see CLAUDE.md > Phase C smoke lesson). The same parser is used here.
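The behaviour described in this caveat can be illustrated with a small strict parser. This is a sketch under stated assumptions, not the parser in rubric_reward.py: the regular expression, the choice to take the last matching line, and returning `None` when no line matches are all hypothetical.

```python
import re
from typing import Optional

# A score counts only if it sits alone on a line of the form "Score: N".
_SCORE_LINE = re.compile(r"^\s*Score:\s*(\d+(?:\.\d+)?)\s*$", re.MULTILINE)


def parse_score(judge_output: str) -> Optional[float]:
    """Return the last strict 'Score: N' value, or None if there is none."""
    matches = _SCORE_LINE.findall(judge_output)
    if not matches:
        # Do not guess from stray digits inside the judge's reasoning.
        return None
    return float(matches[-1])
```

Returning `None` rather than scraping the first number in the text keeps digits that appear in the judge's reasoning (dates, lab values, visit counts) from being mistaken for the score.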

Data

| split | path | notes |
| --- | --- | --- |
| train | $SCRATCH/data/llm_as_rnn/train.parquet | dumped via training/verl_adapter/dump_trajectories_to_parquet.py from MIMIC-IV cleaned_df_train_100.json (100 patients, sequence of visits per patient → one parquet row per memory-update step) |
| val | train.parquet (placeholder) | val parquet was not dumped for this run; validation was disabled by setting test_freq=-1 upstream of the launcher |

Each parquet row carries the chat-format prompt for the memory update plus the serialized TrajectoryStep as reward_model.ground_truth, which the rubric reward unpacks to render the judge prompt.
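A single row could be pictured roughly as below. This is an illustrative sketch only: JSON as the serialization format, the `data_source` tag, and every field inside the example trajectory step are assumptions, and the real dumper may lay things out differently.

```python
import json
from typing import Any, Dict


def make_row(system_prompt: str, user_prompt: str, trajectory_step: Dict[str, Any]) -> Dict[str, Any]:
    """Build one memory-update row: chat-format prompt plus serialized ground truth."""
    return {
        "data_source": "llm_as_rnn_memory_update",  # hypothetical tag
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "reward_model": {
            # Serialized TrajectoryStep; JSON is an assumption of this sketch.
            "ground_truth": json.dumps(trajectory_step),
        },
    }


def unpack_ground_truth(row: Dict[str, Any]) -> Dict[str, Any]:
    """What the reward function would do first: recover the step to render the judge prompt."""
    return json.loads(row["reward_model"]["ground_truth"])
```

Keeping the whole step inside `reward_model.ground_truth` means the reward function needs nothing beyond the row itself to reconstruct the judge prompt.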

How to use

Plug into the RLN config (drop-in for memory_model)

# configs/experimentN_phase1v2judge.yaml
memory_model:
  backend: "vllm"
  model_name: "jinrui123/llmrnn-grpo-phase1-v2judge-step340-merged"
  temperature: 0.7
  top_p: 0.9
  max_tokens: 4096
  gpu_memory_utilization: 0.9
  tensor_parallel_size: 1
  max_model_len: 8192

Then:

python run_experiment.py --config configs/experimentN_phase1v2judge.yaml

The model is fully merged, so no PEFT/LoRA wiring is needed — VLLMInterface loads it the same way it loads the base Llama.

Direct inference (transformers)

from transformers import AutoModelForCausalLM, AutoTokenizer

mid = "jinrui123/llmrnn-grpo-phase1-v2judge-step340-merged"
tok = AutoTokenizer.from_pretrained(mid)
model = AutoModelForCausalLM.from_pretrained(mid, torch_dtype="auto", device_map="auto")

# Use the prompt_update template from configs/experiment1_seperateJudge.yaml.

Limitations / known issues

  • Trained on only 100 patients (cleaned_df_train_100.json). Sample diversity is small, so the model may have overfit to the summary style of the training set; expect degraded behaviour on out-of-distribution patient timelines.
  • No held-out validation during training: the val parquet was not produced before launch, so the loss and reward curves come from the training set only. Treat the reward trajectory as an optimistic upper bound on held-out performance.
  • Judge dependency: the reward signal comes from RubricARM-8B-Judge, which is itself a pairwise-trained model coerced into scalar mode via a prompt template. Absolute reward values are not directly comparable across rubric versions.
  • English only, MIMIC-IV style discharge-summary text. No safety / clinical-deployment guarantees — research artifact.

Citation

Paper this reproduces:

LLM-as-RNN, arXiv:2601.13352.

Training stack:

verl: Volcano Engine Reinforcement Learning for LLMs, https://github.com/volcengine/verl

Reproducibility pointers

Inside the upstream repo (feat/grpo_verl branch):

  • Launcher (all GRPO knobs at top): training/verl_adapter/apptainer/run_phase1.sh
  • Reward function: training/verl_adapter/rubric_reward.py
  • Rubric config (single source of truth for judge + reward + serve): training/configs/rubric_v2_rubricARM_scalar.yaml
  • Trajectory → parquet dumper: training/verl_adapter/dump_trajectories_to_parquet.py
  • Run manifest (auto-generated by launcher): logs/MANIFEST_phase1_v2judge_run01.txt