PEFT
Safetensors
grpo
trl
lora
vision-language-model
topological-reasoning
curvebench
AmirMohseni commited on
Commit
6c8cbf7
·
verified ·
1 Parent(s): 8523ecf

Fix LoRA rank to r=4, remove VRAM note, clean up usage sections

Browse files
Files changed (1) hide show
  1. README.md +33 -15
README.md CHANGED
@@ -28,28 +28,22 @@ It corresponds to **model-c** in the [CurveBench paper](https://arxiv.org/abs/26
28
 
29
  ## Usage
30
 
31
- This is a LoRA adapter you need the base model alongside it. The recommended way is to serve with **vLLM**, which loads the base model once and applies the adapter on-the-fly.
32
 
33
- ### 1. Serve with vLLM
34
-
35
- ```bash
36
- pip install vllm
37
- ```
38
 
39
  ```bash
40
  vllm serve google/gemma-3-12b-it \
41
  --enable-lora \
42
  --lora-modules grpo-region-tree=AmirMohseni/curvebench-gemma-3-12b \
43
- --max-lora-rank 8 \
44
  --max-model-len 32768 \
45
  --gpu-memory-utilization 0.90 \
46
  --dtype bfloat16 \
47
  --trust-remote-code
48
  ```
49
 
50
- > Requires ~32 GB VRAM (single A100-80GB).
51
-
52
- ### 2. Query the server (Python)
53
 
54
  ```python
55
  from openai import OpenAI
@@ -85,21 +79,45 @@ response = client.chat.completions.create(
85
  print(response.choices[0].message.content)
86
  ```
87
 
88
- ### 3. Merge and load with 🤗 Transformers (offline)
89
 
90
- If you prefer to merge the adapter into the base weights without a server:
91
 
92
  ```python
93
  from peft import PeftModel
94
  from transformers import AutoModelForCausalLM, AutoProcessor
 
95
 
96
  base_id = "google/gemma-3-12b-it"
97
  adapter_id = "AmirMohseni/curvebench-gemma-3-12b"
98
 
99
  processor = AutoProcessor.from_pretrained(base_id, trust_remote_code=True)
100
- model = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype="auto", device_map="auto", trust_remote_code=True)
 
 
101
  model = PeftModel.from_pretrained(model, adapter_id)
102
- model = model.merge_and_unload() # optional: fuse weights for faster inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  ```
104
 
105
  ---
@@ -114,7 +132,7 @@ Trained with GRPO using a fork of TRL with multimodal support: [AmirTuring/trl @
114
  - **Base model:** [google/gemma-3-12b-it](https://huggingface.co/google/gemma-3-12b-it)
115
  - **Training split:** `total_train` (210 images) from CurveBench-Easy
116
  - **Reward:** tree isomorphism (0.7) + node count (0.3)
117
- - **Adapter rank:** 8
118
 
119
  ### Framework versions
120
 
 
28
 
29
  ## Usage
30
 
31
+ ### Option 1vLLM (recommended for serving)
32
 
33
+ Start the server with the LoRA adapter loaded on top of the base model:
 
 
 
 
34
 
35
  ```bash
36
  vllm serve google/gemma-3-12b-it \
37
  --enable-lora \
38
  --lora-modules grpo-region-tree=AmirMohseni/curvebench-gemma-3-12b \
39
+ --max-lora-rank 4 \
40
  --max-model-len 32768 \
41
  --gpu-memory-utilization 0.90 \
42
  --dtype bfloat16 \
43
  --trust-remote-code
44
  ```
45
 
46
+ Then query it with the OpenAI-compatible API:
 
 
47
 
48
  ```python
49
  from openai import OpenAI
 
79
  print(response.choices[0].message.content)
80
  ```
81
 
82
+ ### Option 2 PEFT + Transformers (offline)
83
 
84
+ Load the base model and apply the LoRA adapter directly:
85
 
86
  ```python
87
  from peft import PeftModel
88
  from transformers import AutoModelForCausalLM, AutoProcessor
89
+ import torch
90
 
91
  base_id = "google/gemma-3-12b-it"
92
  adapter_id = "AmirMohseni/curvebench-gemma-3-12b"
93
 
94
  processor = AutoProcessor.from_pretrained(base_id, trust_remote_code=True)
95
+ model = AutoModelForCausalLM.from_pretrained(
96
+ base_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
97
+ )
98
  model = PeftModel.from_pretrained(model, adapter_id)
99
+
100
+ prompt = (
101
+ "The image shows a set of pairwise non-intersecting closed curves drawn on a plane. "
102
+ "Each curve creates a boundary between an interior region and its surroundings. "
103
+ "Output the containment tree of the regions as a list of edges in the format: "
104
+ "[(parent, child), ...] where 0 is the outermost (unbounded) region."
105
+ )
106
+
107
+ from PIL import Image
108
+ image = Image.open("curves.png")
109
+
110
+ inputs = processor(
111
+ text=processor.apply_chat_template(
112
+ [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}],
113
+ add_generation_prompt=True,
114
+ ),
115
+ images=[image],
116
+ return_tensors="pt",
117
+ ).to(model.device)
118
+
119
+ output = model.generate(**inputs, max_new_tokens=512)
120
+ print(processor.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
121
  ```
122
 
123
  ---
 
132
  - **Base model:** [google/gemma-3-12b-it](https://huggingface.co/google/gemma-3-12b-it)
133
  - **Training split:** `total_train` (210 images) from CurveBench-Easy
134
  - **Reward:** tree isomorphism (0.7) + node count (0.3)
135
+ - **LoRA rank (r):** 4 | **LoRA alpha:** 8
136
 
137
  ### Framework versions
138