Spaces:
Running on Zero
Running on Zero
Commit ·
9341111
1
Parent(s): 8ccf67b
fix
Browse files
notebook/gemma-finetune.ipynb
CHANGED
|
@@ -238,6 +238,103 @@
|
|
| 238 |
"execution_count": null,
|
| 239 |
"outputs": []
|
| 240 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
{
|
| 242 |
"cell_type": "code",
|
| 243 |
"metadata": {
|
|
@@ -364,8 +461,7 @@
|
|
| 364 |
"# Option B — wrap manually, omit peft_config from SFTTrainer:\n",
|
| 365 |
"# tuned_model = get_peft_model(tuned_model, lora_config)\n",
|
| 366 |
"# trainer = SFTTrainer(model=tuned_model, ...) # no peft_config\n",
|
| 367 |
-
"\n"
|
| 368 |
-
""
|
| 369 |
],
|
| 370 |
"execution_count": null,
|
| 371 |
"outputs": []
|
|
|
|
| 238 |
"execution_count": null,
|
| 239 |
"outputs": []
|
| 240 |
},
|
| 241 |
+
{
|
| 242 |
+
"cell_type": "code",
|
| 243 |
+
"metadata": {},
|
| 244 |
+
"source": [
|
| 245 |
+
"\n",
|
| 246 |
+
"# Prepare model for k-bit training\n",
|
| 247 |
+
"tuned_model = prepare_model_for_kbit_training(tuned_model)\n",
|
| 248 |
+
"\n",
|
| 249 |
+
"# --- 2. Configure LoRA ---\n",
|
| 250 |
+
"lora_config = LoraConfig(\n",
|
| 251 |
+
" r=16, # LoRA attention dimension\n",
|
| 252 |
+
" lora_alpha=16, # Alpha parameter for LoRA scaling\n",
|
| 253 |
+
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], # Target all linear layers\n",
|
| 254 |
+
" lora_dropout=0.05, # Dropout probability for LoRA layers\n",
|
| 255 |
+
"    bias=\"none\", # Do not train any bias parameters (alternatives: \"all\", \"lora_only\")\n",
|
| 256 |
+
" task_type=\"CAUSAL_LM\", # Task type for causal language modeling\n",
|
| 257 |
+
")\n",
|
| 258 |
+
"\n",
|
| 259 |
+
"# Do NOT call get_peft_model() here — SFTTrainer wraps the model when peft_config is passed.\n",
|
| 260 |
+
"# tuned_model = get_peft_model(tuned_model, lora_config)\n",
|
| 261 |
+
"\n",
|
| 262 |
+
"# --- 3. Prepare a Sample Dataset ---\n",
|
| 263 |
+
"# For a real-world scenario, you would load your own dataset using `load_dataset`\n",
|
| 264 |
+
"# from the `datasets` library and format it appropriately.\n",
|
| 265 |
+
"# This is a simple dummy dataset for demonstration.\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"# Example instruction tuning dataset format\n",
|
| 268 |
+
"data = {\n",
|
| 269 |
+
" \"text\": [\n",
|
| 270 |
+
" \"<start_of_turn>user\\nWhat is the capital of France?<end_of_turn>\\n<start_of_turn>model\\nParis is the capital of France.<end_of_turn>\",\n",
|
| 271 |
+
" \"<start_of_turn>user\\nSuggest a healthy snack.\\n<end_of_turn>\\n<start_of_turn>model\\nAlmonds or a piece of fruit like an apple are great healthy snack options.<end_of_turn>\",\n",
|
| 272 |
+
" \"<start_of_turn>user\\nExplain the concept of photosynthesis.\\n<end_of_turn>\\n<start_of_turn>model\\nPhotosynthesis is the process by which green plants and some other organisms convert light energy into chemical energy.<end_of_turn>\"\n",
|
| 273 |
+
" ]\n",
|
| 274 |
+
"}\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"dataset = Dataset.from_dict(data)\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"# --- 4. Define Training Arguments ---\n",
|
| 279 |
+
"from transformers import TrainingArguments\n",
|
| 280 |
+
"\n",
|
| 281 |
+
"training_args = TrainingArguments(\n",
|
| 282 |
+
" output_dir=\"./gemma_finetuned\", # Output directory for checkpoints and logs\n",
|
| 283 |
+
" num_train_epochs=1, # Number of training epochs\n",
|
| 284 |
+
" per_device_train_batch_size=2, # Batch size per GPU/CPU for training\n",
|
| 285 |
+
"    gradient_accumulation_steps=2, # Number of batches whose gradients are accumulated before each optimizer update\n",
|
| 286 |
+
" optim=\"paged_adamw_8bit\", # Optimizer to use\n",
|
| 287 |
+
"    save_steps=100, # Save a checkpoint every X update steps\n",
|
| 288 |
+
"    logging_steps=10, # Log every X update steps\n",
|
| 289 |
+
" learning_rate=2e-4, # Initial learning rate for AdamW optimizer\n",
|
| 290 |
+
" weight_decay=0.001, # Weight decay for AdamW\n",
|
| 291 |
+
" fp16=False, # Must match bnb_4bit_compute_dtype (bf16 below)\n",
|
| 292 |
+
" bf16=True, # Use bf16 when bnb_4bit_compute_dtype=torch.bfloat16\n",
|
| 293 |
+
" max_grad_norm=0.3, # Max gradient norm\n",
|
| 294 |
+
" max_steps=-1, # Don't limit training by steps, use epochs\n",
|
| 295 |
+
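"    warmup_ratio=0.03, # Ratio of total steps for a linear warmup from 0 to learning_rate; NOTE: has no effect with lr_scheduler_type=\"constant\" (use \"constant_with_warmup\" to enable warmup)\n",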
|
| 296 |
+
" # group_by_length=True, # Group sequences of roughly the same length together to speed up training\n",
|
| 297 |
+
" lr_scheduler_type=\"constant\", # Learning rate scheduler type\n",
|
| 298 |
+
" report_to=\"none\" # Disable reporting to any tracking service\n",
|
| 299 |
+
")\n",
|
| 300 |
+
"\n",
|
| 301 |
+
"# --- 5. Initialize and Run SFTTrainer ---\n",
|
| 302 |
+
"\n",
|
| 303 |
+
"trainer = SFTTrainer(\n",
|
| 304 |
+
" model=tuned_model, # plain (non-PEFT) base model\n",
|
| 305 |
+
" train_dataset=dataset,\n",
|
| 306 |
+
" peft_config=lora_config, # SFTTrainer applies LoRA internally\n",
|
| 307 |
+
" # dataset_text_field=\"text\", # Name of the column containing the text data\n",
|
| 308 |
+
" # tokenizer=tokenizer,\n",
|
| 309 |
+
" args=training_args,\n",
|
| 310 |
+
" # packing=False, # Whether to pack multiple short examples into one longer sequence to improve efficiency\n",
|
| 311 |
+
" # max_seq_length=512, # Max sequence length to use for training\n",
|
| 312 |
+
")\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"print(\"Starting finetuning...\")\n",
|
| 315 |
+
"trainer.train()\n",
|
| 316 |
+
"print(\"Finetuning complete!\")\n",
|
| 317 |
+
"\n",
|
| 318 |
+
"# --- 6. Save the LoRA adapter ---\n",
|
| 319 |
+
"trainer.model.save_pretrained(\"./gemma_finetuned_model\")\n",
|
| 320 |
+
"tokenizer.save_pretrained(\"./gemma_finetuned_model\")\n",
|
| 321 |
+
"\n",
|
| 322 |
+
"# --- 7. (Optional) Merge LoRA adapters for inference ---\n",
|
| 323 |
+
"# Merge in-memory from the trained model (avoids AutoPeft reload + torchao version issues).\n",
|
| 324 |
+
"merged_model = trainer.model.merge_and_unload()\n",
|
| 325 |
+
"merged_model.save_pretrained(\"gemma_merged_model\", safe_serialization=True)\n",
|
| 326 |
+
"tokenizer.save_pretrained(\"gemma_merged_model\")\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"# If you need to reload the adapter from disk later instead, upgrade torchao first:\n",
|
| 329 |
+
"# !pip install -U \"torchao>=0.16.0\"\n",
|
| 330 |
+
"# from peft import PeftModel\n",
|
| 331 |
+
"# base_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map=\"auto\")\n",
|
| 332 |
+
"# peft_model = PeftModel.from_pretrained(base_model, \"./gemma_finetuned_model\")\n",
|
| 333 |
+
""
|
| 334 |
+
],
|
| 335 |
+
"execution_count": null,
|
| 336 |
+
"outputs": []
|
| 337 |
+
},
|
| 338 |
{
|
| 339 |
"cell_type": "code",
|
| 340 |
"metadata": {
|
|
|
|
| 461 |
"# Option B — wrap manually, omit peft_config from SFTTrainer:\n",
|
| 462 |
"# tuned_model = get_peft_model(tuned_model, lora_config)\n",
|
| 463 |
"# trainer = SFTTrainer(model=tuned_model, ...) # no peft_config\n",
|
| 464 |
+
"\n"
|
|
|
|
| 465 |
],
|
| 466 |
"execution_count": null,
|
| 467 |
"outputs": []
|