---
language:
  - th
license: apache-2.0
library_name: transformers
tags:
  - llm
  - thai
  - mathematics
  - reasoning
  - lora
  - grpo
pipeline_tag: text-generation
base_model: google/gemma-3-4b-it
---

# Gemma-3-4B-IT GRPO Thai

This model is **Gemma-3-4B-IT** fine-tuned with **LoRA adapters** using **GRPO (Group Relative Policy Optimization)** on the **GSM8K-Thai** dataset.  
The model is trained to **solve math word problems in Thai** step-by-step, producing structured reasoning in `<think>…</think>` followed by the final answer in `<answer>…</answer>`.

---

## Model Details

- **Base model:** [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)  
- **Technique:** LoRA fine-tuning + GRPO reinforcement learning  
- **Languages:** Thai (primary)  
- **Task:** Math reasoning, step-by-step explanation, final numeric answer  
- **License:** Apache-2.0  
- **Author:** Thanayot (SuperAI Engineer SS5, KMUTT)  
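
LoRA keeps the base weight `W` frozen and learns a low-rank update `B·A` that is scaled by `alpha / r`. With the rank 16 and alpha 32 listed under Hyperparameters, that scaling factor is 2.0. The pure-Python sketch below uses toy matrices to show the forward pass and why the adapter can be merged back into `W`. It is illustrative only and is not the PEFT implementation. The helper names and values are hypothetical, and dropout is omitted.

```python
def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, alpha):
    """y = W x + (alpha / r) * B (A x); W: d_out x d_in, A: r x d_in, B: d_out x r."""
    r = len(A)                      # LoRA rank = number of rows in A
    scaling = alpha / r
    base = matvec(W, x)             # frozen base weight
    delta = matvec(B, matvec(A, x)) # low-rank adapter branch
    return [b + scaling * d for b, d in zip(base, delta)]


def merge_lora(W, A, B, alpha):
    """Fold the adapter into the base weight: W' = W + (alpha / r) * B A."""
    r = len(A)
    scaling = alpha / r
    return [
        [W[i][j] + scaling * sum(B[i][k] * A[k][j] for k in range(r)) for j in range(len(W[0]))]
        for i in range(len(W))
    ]


# Toy example: d_in = d_out = 2, rank r = 1, alpha = 2 (same alpha / r = 2.0 as 32 / 16).
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]
B = [[1.0], [0.0]]
x = [1.0, 2.0]
print(lora_forward(W, A, B, x, alpha=2))        # [7.0, 2.0]
print(matvec(merge_lora(W, A, B, alpha=2), x))  # [7.0, 2.0]
```

The update is linear, so the merged weight gives the same output as running the adapter branch separately. During training, `lora_dropout` (0.05 here) would be applied to the input of the adapter branch. `B` is conventionally initialised to zero, so the adapted model starts out identical to the base model.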

---

## Intended Uses

### Direct Use
- Educational use: tutoring in math reasoning in Thai  
- Research on reinforcement-learning fine-tuning methods (such as GRPO) for LLMs  
- Experimentation with structured reasoning outputs (`<think>…</think><answer>…</answer>`)

### Out-of-Scope Use
- High-stakes decision making (finance, medical, legal)  
- Problems requiring formal proofs or very advanced mathematics  
- Any malicious or harmful generation in Thai or other languages  

---

## Training Details

### Dataset
- **[VISAI-AI/gsm8k-thai](https://huggingface.co/datasets/VISAI-AI/gsm8k-thai)**  
  Thai translations of the GSM8K math word problems

### Procedure
- Reward shaping:  
  - **Format reward:** enforces `<think>…</think><answer>…</answer>`  
  - **Accuracy reward:** compares predicted numeric answer to ground truth via [`math_verify`](https://pypi.org/project/math-verify/)  
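
The two rewards above can be sketched as plain Python functions. This is a hypothetical re-implementation, not the training code for this model. The regular expressions, the last-number extraction, and the 0/1 scores are all assumptions. `accuracy_reward` also swaps `math_verify` for a plain numeric comparison, so unlike `math_verify` it will not recognise equivalent forms such as a fraction versus a decimal.

```python
import re

# Assumed patterns; the original training code may define these differently.
FORMAT_RE = re.compile(r"^\s*<think>.*?</think>\s*<answer>.*?</answer>\s*$", re.DOTALL)
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
NUMBER_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def format_reward(completion: str) -> float:
    """1.0 if the whole completion is <think>…</think><answer>…</answer>, else 0.0."""
    return 1.0 if FORMAT_RE.match(completion) else 0.0


def last_number(text: str):
    """Return the last number in `text` as a float (thousands commas stripped), or None."""
    matches = NUMBER_RE.findall(text)
    return float(matches[-1].replace(",", "")) if matches else None


def accuracy_reward(completion: str, ground_truth: str, tol: float = 1e-6) -> float:
    """1.0 if the number inside <answer> equals the ground-truth number, else 0.0.

    Simplified stand-in for math_verify: numeric comparison only,
    no symbolic or LaTeX equivalence checking.
    """
    match = ANSWER_RE.search(completion)
    if match is None:
        return 0.0
    pred, gold = last_number(match.group(1)), last_number(ground_truth)
    if pred is None or gold is None:
        return 0.0
    return 1.0 if abs(pred - gold) <= tol else 0.0
```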

### Hyperparameters
- **LoRA rank:** 16  
- **LoRA alpha:** 32  
- **LoRA dropout:** 0.05  
- **Target modules:** q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj  
- **Learning rate:** 5e-5  
- **Batch size:** 1 (with gradient_accumulation_steps=8)  
- **Num generations per prompt:** 4  
- **Beta (KL penalty):** 0.01  
- **Precision:** bfloat16  
- **Max prompt length:** 256  
- **Max completion length:** 160  
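
GRPO, introduced in the DeepSeekMath paper, scores each prompt's sampled completions relative to one another instead of learning a value model. The simplified scalar sketch below shows how the 4 generations per prompt become group-normalised advantages, and how the KL penalty (beta = 0.01) enters the per-token loss. Several choices here are assumptions: the clip range of 0.2, the population standard deviation (some implementations use the sample estimate), and scoring each completion as format reward plus accuracy reward. This is not the training code used for this model.

```python
import math
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-4):
    """A_i = (r_i - mean(r)) / (std(r) + eps) over one prompt's G completions."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def kl_estimate(logp_policy, logp_ref):
    """Per-token KL estimate exp(d) - d - 1, d = log pi_ref - log pi_theta (always >= 0)."""
    d = logp_ref - logp_policy
    return math.exp(d) - d - 1.0


def token_loss(logp_policy, logp_old, logp_ref, advantage, beta=0.01, clip_eps=0.2):
    """Negative clipped surrogate plus beta-weighted KL penalty for a single token."""
    ratio = math.exp(logp_policy - logp_old)
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    surrogate = min(ratio * advantage, clipped * advantage)
    return -(surrogate - beta * kl_estimate(logp_policy, logp_ref))


# Example: 4 completions for one prompt, each scored as format reward + accuracy reward (assumed).
rewards = [2.0, 1.0, 1.0, 0.0]
advantages = group_advantages(rewards)
print([round(a, 3) for a in advantages])  # [1.414, 0.0, 0.0, -1.414]
```

If every completion in a group earns the same reward, all advantages are zero and only the KL term produces a gradient. With a single gradient step per batch, `logp_old` equals `logp_policy`, so the ratio is 1.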

---

## Evaluation Results

Below are values logged at selected training steps. The author reports the policy loss as a proxy for the reward:

| Step | Policy loss (proxy for reward) |
|------|--------------------------------|
| 100  | 0.0030 |
| 200  | 0.0040 |
| 280  | 0.0042 |

- The reward proxy rises steadily during the early phase of training (step 100 → 200 → 280), which indicates that the model is adapting to the reward function.
- The gain is slowing: +0.0010 between steps 100 and 200, versus +0.0002 between steps 200 and 280. The model appears to be approaching convergence but has not yet plateaued. With further training, the author expects the value to level off slightly higher (≈0.0048–0.0050).

---

## How to Use

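```python
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM
from peft import PeftModel

model_id = "google/gemma-3-4b-it"

tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token  # guard against a missing pad token during generation
tok.padding_side = "left"

base_model = Gemma3ForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # or torch.float16, depending on the GPU
)
# Load the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(base_model, "zoeythanayot/gemma3-it-grpo-thai")

# Build an example prompt (kept in Thai because the model is trained on Thai).
# System prompt, in English: "You are an assistant that solves mathematical reasoning
# problems step by step in Thai, and uses <think></think><answer></answer> to mark
# the reasoning process and the final answer."
SYSTEM_PROMPT = (
    "คุณเป็นผู้ช่วยแก้ปัญหาคณิตศาสตร์เชิงเหตุผล ทีละขั้นเป็นภาษาไทย "
    "และใช้ <think></think><answer></answer> เพื่อบ่งบอกกระบวนการคิดและคำตอบสุดท้าย"
)
# User prompt, in English: "Problem: If 15 candies are shared equally among
# 3 friends, how many candies does each friend get?"
USER_PROMPT = "โจทย์: ถ้ามีลูกอม 15 เม็ด แบ่งให้เพื่อน 3 คนเท่า ๆ กัน แต่ละคนจะได้กี่เม็ด?"

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": USER_PROMPT},
]

# Apply the tokenizer's chat template; add_generation_prompt=True appends the
# model-turn marker so the model answers instead of continuing the user turn.
inputs = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

# Generate the answer (do_sample=True is needed for temperature/top_p to take effect)
with torch.inference_mode():
    output_ids = model.generate(
        inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

# Decode only the newly generated tokens
input_length = inputs.shape[1]
new_tokens = output_ids[0, input_length:]
resp = tok.decode(new_tokens, skip_special_tokens=True)
print(resp.strip())
```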
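
A small regex helper is enough to pull the reasoning and the final answer out of `resp`. This sketch assumes the model follows the `<think>…</think><answer>…</answer>` format and that the tags survive `skip_special_tokens=True` as plain text. `parse_response` is a hypothetical helper name. It returns `None` for a missing part, for example when generation stops at `max_new_tokens` before `</answer>`. In practice you would call `parse_response(resp)`.

```python
import re

# Non-greedy, DOTALL patterns so multi-line reasoning is captured.
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def parse_response(text):
    """Hypothetical helper: return (reasoning, answer), with None for any missing part."""
    think = THINK_RE.search(text)
    answer = ANSWER_RE.search(text)
    return (
        think.group(1).strip() if think else None,
        answer.group(1).strip() if answer else None,
    )


reasoning, answer = parse_response("<think>15 ÷ 3 = 5</think>\n<answer>5</answer>")
print(answer)  # 5
```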
---

## Bias, Risks, and Limitations

- May produce plausible but incorrect answers
- Trained only on translated Thai data, so biases or errors introduced during translation may carry over
- Limited to short reasoning problems (GSM8K style)

---

## Citation

```bibtex
@misc{thanayot2025gemmathai,
  title = {Gemma-3-4B-IT GRPO Thai: LoRA Fine-Tuned Math Reasoning Model},
  author = {Thanayot},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {Model on Hugging Face Hub},
}
```

---

## Contact

For questions or collaboration: **Thanayot @ KMUTT** (SuperAI Engineer SS5)