MSGEncrypted commited on
Commit
9341111
·
1 Parent(s): 8ccf67b
Files changed (1) hide show
  1. notebook/gemma-finetune.ipynb +98 -2
notebook/gemma-finetune.ipynb CHANGED
@@ -238,6 +238,103 @@
238
  "execution_count": null,
239
  "outputs": []
240
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  {
242
  "cell_type": "code",
243
  "metadata": {
@@ -364,8 +461,7 @@
364
  "# Option B — wrap manually, omit peft_config from SFTTrainer:\n",
365
  "# tuned_model = get_peft_model(tuned_model, lora_config)\n",
366
  "# trainer = SFTTrainer(model=tuned_model, ...) # no peft_config\n",
367
- "\n",
368
- ""
369
  ],
370
  "execution_count": null,
371
  "outputs": []
 
238
  "execution_count": null,
239
  "outputs": []
240
  },
241
+ {
242
+ "cell_type": "code",
243
+ "metadata": {},
244
+ "source": [
245
+ "\n",
246
+ "# Prepare model for k-bit training\n",
247
+ "tuned_model = prepare_model_for_kbit_training(tuned_model)\n",
248
+ "\n",
249
+ "# --- 2. Configure LoRA ---\n",
250
+ "lora_config = LoraConfig(\n",
251
+ " r=16, # LoRA attention dimension\n",
252
+ " lora_alpha=16, # Alpha parameter for LoRA scaling\n",
253
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], # Target all linear layers\n",
254
+ " lora_dropout=0.05, # Dropout probability for LoRA layers\n",
255
+ " bias=\"none\", # Only add bias to the LoRA layers\n",
256
+ " task_type=\"CAUSAL_LM\", # Task type for causal language modeling\n",
257
+ ")\n",
258
+ "\n",
259
+ "# Do NOT call get_peft_model() here — SFTTrainer wraps the model when peft_config is passed.\n",
260
+ "# tuned_model = get_peft_model(tuned_model, lora_config)\n",
261
+ "\n",
262
+ "# --- 3. Prepare a Sample Dataset ---\n",
263
+ "# For a real-world scenario, you would load your own dataset using `load_dataset`\n",
264
+ "# from the `datasets` library and format it appropriately.\n",
265
+ "# This is a simple dummy dataset for demonstration.\n",
266
+ "\n",
267
+ "# Example instruction tuning dataset format\n",
268
+ "data = {\n",
269
+ " \"text\": [\n",
270
+ " \"<start_of_turn>user\\nWhat is the capital of France?<end_of_turn>\\n<start_of_turn>model\\nParis is the capital of France.<end_of_turn>\",\n",
271
+ " \"<start_of_turn>user\\nSuggest a healthy snack.\\n<end_of_turn>\\n<start_of_turn>model\\nAlmonds or a piece of fruit like an apple are great healthy snack options.<end_of_turn>\",\n",
272
+ " \"<start_of_turn>user\\nExplain the concept of photosynthesis.\\n<end_of_turn>\\n<start_of_turn>model\\nPhotosynthesis is the process by which green plants and some other organisms convert light energy into chemical energy.<end_of_turn>\"\n",
273
+ " ]\n",
274
+ "}\n",
275
+ "\n",
276
+ "dataset = Dataset.from_dict(data)\n",
277
+ "\n",
278
+ "# --- 4. Define Training Arguments ---\n",
279
+ "from transformers import TrainingArguments\n",
280
+ "\n",
281
+ "training_args = TrainingArguments(\n",
282
+ " output_dir=\"./gemma_finetuned\", # Output directory for checkpoints and logs\n",
283
+ " num_train_epochs=1, # Number of training epochs\n",
284
+ " per_device_train_batch_size=2, # Batch size per GPU/CPU for training\n",
285
+ " gradient_accumulation_steps=2, # Number of updates steps to accumulate before performing a backward/update pass\n",
286
+ " optim=\"paged_adamw_8bit\", # Optimizer to use\n",
287
+ " save_steps=100, # Save checkpoint every X updates steps\n",
288
+ " logging_steps=10, # Log every X updates steps\n",
289
+ " learning_rate=2e-4, # Initial learning rate for AdamW optimizer\n",
290
+ " weight_decay=0.001, # Weight decay for AdamW\n",
291
+ " fp16=False, # Must match bnb_4bit_compute_dtype (bf16 below)\n",
292
+ " bf16=True, # Use bf16 when bnb_4bit_compute_dtype=torch.bfloat16\n",
293
+ " max_grad_norm=0.3, # Max gradient norm\n",
294
+ " max_steps=-1, # Don't limit training by steps, use epochs\n",
295
+ " warmup_ratio=0.03, # Ratio of total steps for a linear warmup from 0 to learning_rate\n",
296
+ " # group_by_length=True, # Group sequences of roughly the same length together to speed up training\n",
297
+ " lr_scheduler_type=\"constant\", # Learning rate scheduler type\n",
298
+ " report_to=\"none\" # Disable reporting to any tracking service\n",
299
+ ")\n",
300
+ "\n",
301
+ "# --- 5. Initialize and Run SFTTrainer ---\n",
302
+ "\n",
303
+ "trainer = SFTTrainer(\n",
304
+ " model=tuned_model, # plain (non-PEFT) base model\n",
305
+ " train_dataset=dataset,\n",
306
+ " peft_config=lora_config, # SFTTrainer applies LoRA internally\n",
307
+ " # dataset_text_field=\"text\", # Name of the column containing the text data\n",
308
+ " # tokenizer=tokenizer,\n",
309
+ " args=training_args,\n",
310
+ " # packing=False, # Whether to pack multiple short examples into one longer sequence to improve efficiency\n",
311
+ " # max_seq_length=512, # Max sequence length to use for training\n",
312
+ ")\n",
313
+ "\n",
314
+ "print(\"Starting finetuning...\")\n",
315
+ "trainer.train()\n",
316
+ "print(\"Finetuning complete!\")\n",
317
+ "\n",
318
+ "# --- 6. Save the LoRA adapter ---\n",
319
+ "trainer.model.save_pretrained(\"./gemma_finetuned_model\")\n",
320
+ "tokenizer.save_pretrained(\"./gemma_finetuned_model\")\n",
321
+ "\n",
322
+ "# --- 7. (Optional) Merge LoRA adapters for inference ---\n",
323
+ "# Merge in-memory from the trained model (avoids AutoPeft reload + torchao version issues).\n",
324
+ "merged_model = trainer.model.merge_and_unload()\n",
325
+ "merged_model.save_pretrained(\"gemma_merged_model\", safe_serialization=True)\n",
326
+ "tokenizer.save_pretrained(\"gemma_merged_model\")\n",
327
+ "\n",
328
+ "# If you need to reload the adapter from disk later instead, upgrade torchao first:\n",
329
+ "# !pip install -U \"torchao>=0.16.0\"\n",
330
+ "# from peft import PeftModel\n",
331
+ "# base_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map=\"auto\")\n",
332
+ "# peft_model = PeftModel.from_pretrained(base_model, \"./gemma_finetuned_model\")\n",
333
+ ""
334
+ ],
335
+ "execution_count": null,
336
+ "outputs": []
337
+ },
338
  {
339
  "cell_type": "code",
340
  "metadata": {
 
461
  "# Option B — wrap manually, omit peft_config from SFTTrainer:\n",
462
  "# tuned_model = get_peft_model(tuned_model, lora_config)\n",
463
  "# trainer = SFTTrainer(model=tuned_model, ...) # no peft_config\n",
464
+ "\n"
 
465
  ],
466
  "execution_count": null,
467
  "outputs": []