"""OmniGen-v1 LoRA Fine-Tuning for Thumbnail Generation.

Model:   Shitao/OmniGen-v1 (3.8B, Phi-3 based)
Method:  LoRA (rank=8) fine-tuning via accelerate
Dataset: PosterCraft/Poster100K + synthetic thumbnail prompts
Output:  Image generation model for thumbnails

Input modes supported:
  - Text only    → Thumbnail image
  - Image only   → Thumbnail image
  - Text + Image → Thumbnail image

Based on OmniGen official fine-tuning recipe:
https://github.com/VectorSpaceLab/OmniGen
"""

import os
import sys
import json
import math
import random
import logging
import argparse
from pathlib import Path
from typing import Optional, List, Dict, Any

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

# OmniGen imports
from OmniGen import OmniGenPipeline
from OmniGen.model import OmniGen
from OmniGen.processor import OmniGenProcessor
from OmniGen.scheduler import OmniGenScheduler
from diffusers import AutoencoderKL
from transformers import get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from accelerate.utils import set_seed
import trackio

logger = logging.getLogger(__name__)


class ThumbnailDataset(Dataset):
    """Dataset for thumbnail generation training.

    Every non-empty line of the JSONL file is parsed as one sample. A sample
    is read through the keys ``instruction`` (text prompt), ``output_image``
    (target file name) and, optionally, ``input_images`` (list of conditioning
    file names). File names are resolved relative to ``image_dir``.
    """

    # Upper bound on how many samples are tried when target images cannot be
    # loaded; prevents unbounded recursion/looping on a broken dataset.
    _MAX_LOAD_ATTEMPTS = 10

    def __init__(
        self,
        jsonl_path: str,
        image_dir: str,
        processor: OmniGenProcessor,
        max_image_size: int = 1024,
        max_input_length_limit: int = 18000,
        keep_raw_resolution: bool = True,
        condition_dropout_prob: float = 0.01,
    ):
        """Read all JSONL entries into memory and store the configuration.

        Only ``image_dir`` and ``condition_dropout_prob`` are consulted by this
        class itself; the remaining arguments are stored as attributes.
        """
        self.image_dir = image_dir
        self.processor = processor
        self.max_image_size = max_image_size
        self.max_input_length_limit = max_input_length_limit
        self.keep_raw_resolution = keep_raw_resolution
        self.condition_dropout_prob = condition_dropout_prob

        # Load JSONL entries (explicit UTF-8 so prompts decode identically on
        # every platform, independent of the locale default).
        self.entries = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    self.entries.append(json.loads(line))

        logger.info(f"Loaded {len(self.entries)} training samples from {jsonl_path}")

    def __len__(self):
        """Return the number of loaded JSONL entries."""
        return len(self.entries)

    def _load_image(self, filename: str) -> Optional[Image.Image]:
        """Load ``filename`` from the image directory as an RGB image.

        Returns ``None`` (after logging a warning) when the file is missing or
        cannot be decoded.
        """
        path = os.path.join(self.image_dir, filename)
        if not os.path.exists(path):
            logger.warning(f"Image not found: {path}")
            return None
        try:
            return Image.open(path).convert("RGB")
        except Exception as e:
            logger.warning(f"Failed to load image {path}: {e}")
            return None

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one sample as a dict of instruction and PIL images.

        If the target image of ``idx`` cannot be loaded, a random other sample
        is tried instead, up to ``_MAX_LOAD_ATTEMPTS`` attempts in total.

        Raises:
            RuntimeError: if no loadable target image was found within the
                attempt budget.
        """
        entry = self.entries[idx]
        output_image = self._load_image(entry["output_image"])
        attempts = 1
        while output_image is None:
            if attempts >= self._MAX_LOAD_ATTEMPTS:
                raise RuntimeError(
                    f"Could not load a target image after {attempts} attempts; "
                    f"check that the image files exist under {self.image_dir}"
                )
            # Fall back to a random other sample when the target is missing.
            entry = self.entries[random.randint(0, len(self) - 1)]
            output_image = self._load_image(entry["output_image"])
            attempts += 1

        instruction = entry["instruction"]
        input_image_names = entry.get("input_images", [])

        # Condition dropout for CFG training: drop *all* conditioning (text and
        # images) so the sample is truly unconditional.
        # NOTE(review): the upstream OmniGen recipe appears to use a dedicated
        # '<cfg>' instruction token here rather than "" — confirm.
        if random.random() < self.condition_dropout_prob:
            instruction = ""
            input_image_names = []

        # Load input images if any; unreadable ones are skipped.
        input_images = []
        for img_name in input_image_names:
            img = self._load_image(img_name)
            if img is not None:
                input_images.append(img)

        return {
            "instruction": instruction,
            "output_image": output_image,
            "input_images": input_images if input_images else None,
        }


def collate_fn(batch):
    """Custom collate that keeps PIL images.

    Transposes a list of sample dicts into a dict of parallel lists without
    any tensor conversion.
    """
    return {
        "instructions": [item["instruction"] for item in batch],
        "output_images": [item["output_image"] for item in batch],
        "input_images": [item["input_images"] for item in batch],
    }


def _add_toggle(parser: argparse.ArgumentParser, name: str, default: bool, help_text: str) -> None:
    """Register ``--<name>`` and ``--no_<name>`` flags sharing one destination."""
    parser.add_argument(
        f"--{name}", dest=name, action="store_true", default=default,
        help=f"{help_text} (default: {'on' if default else 'off'})",
    )
    parser.add_argument(
        f"--no_{name}", dest=name, action="store_false",
        help=f"Disable --{name}",
    )


def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line options.

    Args:
        argv: Optional explicit argument list; ``None`` reads ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="OmniGen LoRA Fine-Tuning for Thumbnails")

    # Model
    parser.add_argument("--model_name_or_path", type=str, default="Shitao/OmniGen-v1")

    # Data
    parser.add_argument("--json_file", type=str, required=True, help="Path to JSONL training data")
    parser.add_argument("--image_path", type=str, required=True, help="Root dir for images")
    parser.add_argument("--max_image_size", type=int, default=1024)
    parser.add_argument("--max_input_length_limit", type=int, default=18000)
    parser.add_argument("--keep_raw_resolution", action="store_true")

    # Training
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size_per_device", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--condition_dropout_prob", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=42)
    # These default to on; a plain store_true flag could never switch them
    # off, so each also gets a --no_<flag> counterpart.
    _add_toggle(parser, "bf16", True, "Train with bf16 mixed precision")

    # LoRA
    _add_toggle(parser, "use_lora", True, "Fine-tune LoRA adapters instead of full weights")
    parser.add_argument("--lora_rank", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.05)

    # Output
    parser.add_argument("--results_dir", type=str, default="./results/thumbnail_lora")
    parser.add_argument("--ckpt_every", type=int, default=500)
    parser.add_argument("--log_every", type=int, default=10)
    _add_toggle(parser, "push_to_hub", True, "Push the final LoRA adapters to the Hub")
    parser.add_argument("--hub_model_id", type=str, default="asats/thumbnail-vlm-omnigen-lora")

    return parser.parse_args(argv)


def _save_model(accelerator, model, out_dir: str, use_lora: bool):
    """Save adapters (LoRA) or the full state dict to ``out_dir``; return the unwrapped model."""
    os.makedirs(out_dir, exist_ok=True)
    unwrapped_model = accelerator.unwrap_model(model)
    if use_lora:
        unwrapped_model.save_pretrained(out_dir)
    else:
        torch.save(unwrapped_model.state_dict(), os.path.join(out_dir, "model.pt"))
    return unwrapped_model


def _compute_sample_loss(model, vae, processor, transform, instruction, output_img,
                         input_imgs, image_size, device, dtype):
    """Compute the flow-matching MSE loss for a single training sample."""
    # Encode target image with VAE
    target_tensor = transform(output_img).unsqueeze(0).to(device, dtype=dtype)

    # Get VAE latents (VAE is frozen, so no graph is needed)
    with torch.no_grad():
        latents = vae.encode(target_tensor).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

    # Add noise (flow matching)
    # NOTE(review): here t=1 is pure noise and the target is (noise - latents).
    # The pretrained OmniGen checkpoint may use the opposite convention
    # (t=1 is data, target = data - noise) — confirm against the upstream loss.
    noise = torch.randn_like(latents)
    timesteps = torch.rand(1, device=device)
    noisy_latents = (1 - timesteps) * latents + timesteps * noise

    # Process input through model
    # OmniGen processes text+images together through the Phi-3 backbone
    input_data = processor(
        instruction,
        input_images=input_imgs,
        height=image_size,
        width=image_size,
    )

    # Forward pass
    # NOTE(review): keyword names (notably ``t``) are assumed to match the
    # model's forward signature — verify against OmniGen.model.OmniGen.forward.
    model_output = model(
        input_ids=input_data["input_ids"].to(device),
        input_img_latents=input_data.get("input_img_latents"),
        input_image_sizes=input_data.get("input_image_sizes"),
        attention_mask=input_data["attention_mask"].to(device),
        position_ids=input_data["position_ids"].to(device),
        x=noisy_latents,
        t=timesteps,
    )

    # Flow matching loss: MSE between predicted velocity and target
    target = noise - latents  # velocity target for rectified flow
    return F.mse_loss(model_output, target)


def main():
    """Run LoRA (or full) fine-tuning end to end: load, train, checkpoint, push."""
    args = parse_args()

    # Initialize accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16" if args.bf16 else "no",
        log_with="all",
    )

    # Initialize trackio monitoring
    if accelerator.is_main_process:
        trackio.init(
            project="thumbnail-vlm",
            name="omnigen-lora-finetune",
        )

    set_seed(args.seed)

    logger.info(f"Loading OmniGen model from {args.model_name_or_path}...")

    # Load the OmniGen pipeline components
    pipe = OmniGenPipeline.from_pretrained(args.model_name_or_path)
    model = pipe.model
    processor = pipe.processor
    vae = pipe.vae

    # Freeze VAE
    vae.requires_grad_(False)
    vae.eval()

    if args.use_lora:
        logger.info(f"Applying LoRA (rank={args.lora_rank}, alpha={args.lora_alpha})...")
        # Apply LoRA to the transformer backbone
        lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=["qkv_proj", "o_proj", "gate_up_proj", "down_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    else:
        model.train()

    # Dataset
    logger.info(f"Loading dataset from {args.json_file}...")
    dataset = ThumbnailDataset(
        jsonl_path=args.json_file,
        image_dir=args.image_path,
        processor=processor,
        max_image_size=args.max_image_size,
        max_input_length_limit=args.max_input_length_limit,
        keep_raw_resolution=args.keep_raw_resolution,
        condition_dropout_prob=args.condition_dropout_prob,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size_per_device,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True,
    )

    # Optimizer
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.999),
    )

    # Scheduler. Accelerate also syncs gradients on the last batch of every
    # epoch, so the number of optimizer steps per epoch is a ceiling division;
    # floor division would make the cosine schedule end too early.
    steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps)
    num_training_steps = max(1, steps_per_epoch * args.epochs)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps,
    )

    # Prepare with accelerator
    model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, dataloader, lr_scheduler
    )

    # Move VAE to device
    weight_dtype = torch.bfloat16 if args.bf16 else torch.float32
    vae = vae.to(accelerator.device, dtype=weight_dtype)

    # Target-image preprocessing is identical for every sample: build it once.
    transform = transforms.Compose([
        transforms.Resize((args.max_image_size, args.max_image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    os.makedirs(args.results_dir, exist_ok=True)

    logger.info("=" * 60)
    logger.info("Training Configuration:")
    logger.info(f"  Model: {args.model_name_or_path}")
    logger.info(f"  LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}")
    logger.info(f"  Dataset: {len(dataset)} samples")
    logger.info(f"  Epochs: {args.epochs}")
    logger.info(f"  Batch size: {args.batch_size_per_device}")
    logger.info(f"  Grad accum: {args.gradient_accumulation_steps}")
    logger.info(f"  Effective batch: {args.batch_size_per_device * args.gradient_accumulation_steps * accelerator.num_processes}")
    logger.info(f"  LR: {args.lr}")
    logger.info(f"  Total steps: {num_training_steps}")
    logger.info(f"  Hub model: {args.hub_model_id}")
    logger.info("=" * 60)

    # Training loop
    global_step = 0  # counts successfully processed batches (micro-steps)
    best_loss = float("inf")

    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        skipped_samples = 0

        for batch in dataloader:
            with accelerator.accumulate(model):
                # Process each sample in the batch
                sample_losses = []
                samples = zip(batch["instructions"], batch["output_images"], batch["input_images"])
                for i, (instruction, output_img, input_imgs) in enumerate(samples):
                    try:
                        sample_losses.append(
                            _compute_sample_loss(
                                model, vae, processor, transform,
                                instruction, output_img, input_imgs,
                                args.max_image_size, accelerator.device, weight_dtype,
                            )
                        )
                    except Exception as e:
                        # Best effort: one bad sample must not abort the run,
                        # but it is counted so systematic failures are visible.
                        skipped_samples += 1
                        logger.warning(f"Error processing sample {i}: {e}")

                if not sample_losses:
                    # Nothing to back-propagate for this batch.
                    continue

                avg_loss = torch.stack(sample_losses).float().mean()
                accelerator.backward(avg_loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                # Under accelerator.accumulate these are no-ops on
                # non-sync micro-steps.
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                loss_value = avg_loss.item()
                epoch_loss += loss_value
                num_batches += 1
                global_step += 1

                # Logging
                if global_step % args.log_every == 0 and accelerator.is_main_process:
                    avg_epoch_loss = epoch_loss / max(num_batches, 1)
                    current_lr = lr_scheduler.get_last_lr()[0]
                    print(f"step={global_step}, epoch={epoch+1}/{args.epochs}, "
                          f"loss={loss_value:.4f}, avg_loss={avg_epoch_loss:.4f}, "
                          f"lr={current_lr:.2e}")
                    trackio.log({
                        "train/loss": loss_value,
                        "train/avg_loss": avg_epoch_loss,
                        "train/lr": current_lr,
                        "train/epoch": epoch + 1,
                        "train/step": global_step,
                    })

                # Checkpoint
                if global_step % args.ckpt_every == 0 and accelerator.is_main_process:
                    ckpt_dir = os.path.join(args.results_dir, f"checkpoint-{global_step}")
                    _save_model(accelerator, model, ckpt_dir, args.use_lora)
                    logger.info(f"Saved checkpoint to {ckpt_dir}")

        if skipped_samples:
            logger.warning(f"Epoch {epoch+1}: skipped {skipped_samples} samples due to errors")

        # End of epoch logging
        if accelerator.is_main_process:
            if num_batches == 0:
                # A zero average here would be an artifact of having no data,
                # not a real loss; never record it as the best model.
                logger.error(
                    f"Epoch {epoch+1}/{args.epochs}: no batch was processed successfully; "
                    "skipping best-model update"
                )
                continue

            avg_epoch_loss = epoch_loss / num_batches
            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{args.epochs} complete. Avg loss: {avg_epoch_loss:.4f}")
            print(f"{'='*60}\n")

            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                best_dir = os.path.join(args.results_dir, "best")
                _save_model(accelerator, model, best_dir, args.use_lora)
                logger.info(f"New best model saved (loss={best_loss:.4f})")

    # Final save and push to hub
    if accelerator.is_main_process:
        final_dir = os.path.join(args.results_dir, "final")
        unwrapped_model = _save_model(accelerator, model, final_dir, args.use_lora)

        pushed = False
        if args.push_to_hub:
            if args.use_lora:
                logger.info(f"Pushing LoRA adapters to {args.hub_model_id}...")
                unwrapped_model.push_to_hub(args.hub_model_id, token=os.environ.get("HF_TOKEN"))
                pushed = True
            else:
                logger.warning("push_to_hub is only performed for LoRA adapters; nothing was pushed")

        logger.info(f"Training complete! Final model saved to {final_dir}")
        logger.info(f"Best loss: {best_loss:.4f}")

        # Only claim a push when one actually happened.
        if pushed:
            print(f"\nModel pushed to: https://huggingface.co/{args.hub_model_id}")

        # trackio was initialised on the main process only, so finish it there.
        trackio.finish()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    main()