# Source: thumbnail-vlm-janus-pro/scripts/train_omnigen.py
# Uploaded by asats ("Upload complete train_omnigen.py", commit a56eb73, 17.7 kB).
"""
OmniGen-v1 LoRA Fine-Tuning for Thumbnail Generation
Model: Shitao/OmniGen-v1 (3.8B, Phi-3 based)
Method: LoRA (rank=8) fine-tuning via accelerate
Dataset: PosterCraft/Poster100K + synthetic thumbnail prompts
Output: Image generation model for thumbnails
Input modes supported:
- Text only → Thumbnail image
- Image only → Thumbnail image
- Text + Image → Thumbnail image
Based on OmniGen official fine-tuning recipe:
https://github.com/VectorSpaceLab/OmniGen
"""
import os
import sys
import json
import math
import random
import logging
import argparse
from pathlib import Path
from typing import Optional, List, Dict, Any
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
# OmniGen imports
from OmniGen import OmniGenPipeline
from OmniGen.model import OmniGen
from OmniGen.processor import OmniGenProcessor
from OmniGen.scheduler import OmniGenScheduler
from diffusers import AutoencoderKL
from transformers import get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from accelerate.utils import set_seed
import trackio
logger = logging.getLogger(__name__)
class ThumbnailDataset(Dataset):
    """Map-style dataset yielding an instruction plus PIL images for fine-tuning.

    Each non-blank line of ``jsonl_path`` must be a JSON object with:
      * ``"instruction"``  -- the text prompt (str);
      * ``"output_image"`` -- target image filename, relative to ``image_dir``;
      * ``"input_images"`` -- optional list of conditioning image filenames,
        relative to ``image_dir`` (a missing key or ``null`` means "none").

    Items are dicts with keys ``"instruction"``, ``"output_image"`` (an RGB
    ``PIL.Image``) and ``"input_images"`` (a non-empty list of RGB images, or
    ``None``). Resizing and tensor conversion are left to the training loop.

    ``processor``, ``max_image_size``, ``max_input_length_limit`` and
    ``keep_raw_resolution`` are stored as attributes but not used by this class.
    """

    # Maximum number of random replacement samples ``__getitem__`` tries when a
    # target image is missing or unreadable. Bounded so that a dataset full of
    # broken paths fails with a clear error instead of exhausting the recursion
    # limit (the previous implementation recursed once per failure).
    max_resample_attempts: int = 100

    def __init__(
        self,
        jsonl_path: str,
        image_dir: str,
        processor: OmniGenProcessor,
        max_image_size: int = 1024,
        max_input_length_limit: int = 18000,
        keep_raw_resolution: bool = True,
        condition_dropout_prob: float = 0.01,
    ):
        """Read every JSONL entry into memory.

        Args:
            jsonl_path: Path to the UTF-8 JSONL manifest.
            image_dir: Directory that image filenames are resolved against.
            processor: OmniGen processor (stored only).
            max_image_size: Stored only.
            max_input_length_limit: Stored only.
            keep_raw_resolution: Stored only.
            condition_dropout_prob: Probability of replacing the instruction
                with ``""`` in ``__getitem__`` (for classifier-free guidance).

        Raises:
            ValueError: if a line is not valid JSON; the message names the file
                and line number.
        """
        self.image_dir = image_dir
        self.processor = processor
        self.max_image_size = max_image_size
        self.max_input_length_limit = max_input_length_limit
        self.keep_raw_resolution = keep_raw_resolution
        self.condition_dropout_prob = condition_dropout_prob

        self.entries: List[Dict[str, Any]] = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    self.entries.append(json.loads(line))
                except json.JSONDecodeError as err:
                    # JSONDecodeError subclasses ValueError, so existing handlers still match.
                    raise ValueError(f"{jsonl_path}:{line_no}: invalid JSON ({err})") from err
        logger.info(f"Loaded {len(self.entries)} training samples from {jsonl_path}")

    def __len__(self) -> int:
        return len(self.entries)

    def _load_image(self, filename: str) -> Optional[Image.Image]:
        """Return ``image_dir/filename`` as an RGB image, or ``None`` if absent/unreadable.

        A missing file returns ``None`` silently. A file that exists but fails to
        decode is logged as a warning and also returns ``None``.
        """
        path = os.path.join(self.image_dir, filename)
        if not os.path.exists(path):
            return None
        try:
            # The context manager closes the file handle; convert() returns a new,
            # fully loaded image that does not depend on it.
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception as e:
            logger.warning(f"Failed to load image {path}: {e}")
            return None

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return sample ``idx``, substituting a random sample if its target is unusable.

        With probability ``condition_dropout_prob`` the instruction is replaced by
        ``""``; input images are kept. Input images that fail to load are dropped.

        Raises:
            IndexError: if ``idx`` is out of range (including an empty dataset).
            RuntimeError: if no loadable target image is found within
                ``max_resample_attempts`` tries.
        """
        for _ in range(self.max_resample_attempts):
            entry = self.entries[idx]
            output_image = self._load_image(entry["output_image"])
            if output_image is not None:
                break
            logger.warning(f"Target image {entry['output_image']!r} unavailable; resampling")
            idx = random.randint(0, len(self) - 1)
        else:
            raise RuntimeError(
                f"No loadable target image after {self.max_resample_attempts} attempts; "
                f"check image_dir={self.image_dir!r}"
            )

        instruction = entry["instruction"]
        # Condition dropout for classifier-free-guidance training.
        if random.random() < self.condition_dropout_prob:
            instruction = ""

        # `or []` also covers an explicit JSON null.
        input_images = []
        for img_name in entry.get("input_images") or []:
            img = self._load_image(img_name)
            if img is not None:
                input_images.append(img)

        return {
            "instruction": instruction,
            "output_image": output_image,
            "input_images": input_images if input_images else None,
        }
def collate_fn(batch):
    """Transpose a list of dataset items into a dict of parallel lists.

    PIL images (and ``None`` placeholders) are passed through untouched, because
    the default collate would try to stack them into tensors.
    """
    # (batched key, per-item key) pairs, in the order the output dict is built.
    field_pairs = (
        ("instructions", "instruction"),
        ("output_images", "output_image"),
        ("input_images", "input_images"),
    )
    return {
        batched_key: [sample[item_key] for sample in batch]
        for batched_key, item_key in field_pairs
    }
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for OmniGen LoRA fine-tuning.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            so the original zero-argument call keeps working.

    Returns:
        The parsed ``argparse.Namespace``.

    ``--bf16``, ``--use_lora`` and ``--push_to_hub`` still default to on. Each
    now has a ``--no_<name>`` counterpart; previously they were ``store_true``
    flags with ``default=True`` and therefore impossible to disable.
    """

    def add_switch(parser: argparse.ArgumentParser, name: str, default: bool, help_text: str) -> None:
        # Register --name / --no_name writing the same boolean destination.
        parser.add_argument(f"--{name}", dest=name, action="store_true",
                            help=f"{help_text} (default: {'on' if default else 'off'})")
        parser.add_argument(f"--no_{name}", dest=name, action="store_false",
                            help=f"Disable --{name}.")
        parser.set_defaults(**{name: default})

    parser = argparse.ArgumentParser(description="OmniGen LoRA Fine-Tuning for Thumbnails")
    # Model
    parser.add_argument("--model_name_or_path", type=str, default="Shitao/OmniGen-v1")
    # Data
    parser.add_argument("--json_file", type=str, required=True, help="Path to JSONL training data")
    parser.add_argument("--image_path", type=str, required=True, help="Root dir for images")
    parser.add_argument("--max_image_size", type=int, default=1024)
    parser.add_argument("--max_input_length_limit", type=int, default=18000)
    parser.add_argument("--keep_raw_resolution", action="store_true")
    # Training
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size_per_device", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--condition_dropout_prob", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=42)
    add_switch(parser, "bf16", True, "Train with bf16 mixed precision")
    # LoRA
    add_switch(parser, "use_lora", True, "Train LoRA adapters instead of full weights")
    parser.add_argument("--lora_rank", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    # Output
    parser.add_argument("--results_dir", type=str, default="./results/thumbnail_lora")
    parser.add_argument("--ckpt_every", type=int, default=500)
    parser.add_argument("--log_every", type=int, default=10)
    add_switch(parser, "push_to_hub", True, "Push the final LoRA adapter to the Hugging Face Hub")
    parser.add_argument("--hub_model_id", type=str, default="asats/thumbnail-vlm-omnigen-lora")
    return parser.parse_args(argv)
# Abort training when this many samples in a row raise; a systematic error
# (e.g. a model/processor API mismatch) must not silently produce an untrained
# adapter that is then saved as "best"/"final" and pushed to the Hub.
_MAX_CONSECUTIVE_SAMPLE_FAILURES = 20


def _save_model(accelerator, model, out_dir: str, use_lora: bool):
    """Save the unwrapped model into ``out_dir`` and return the unwrapped model.

    LoRA runs write a PEFT adapter via ``save_pretrained``; full fine-tuning
    writes ``model.pt`` containing the state dict.
    """
    os.makedirs(out_dir, exist_ok=True)
    unwrapped_model = accelerator.unwrap_model(model)
    if use_lora:
        unwrapped_model.save_pretrained(out_dir)
    else:
        torch.save(unwrapped_model.state_dict(), os.path.join(out_dir, "model.pt"))
    return unwrapped_model


def _build_target_transform(image_size: int):
    """Return the PIL -> tensor transform for target images (square resize, values in [-1, 1])."""
    from torchvision import transforms  # only needed for training; keep import local

    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


def _encode_latents(vae, pixels):
    """VAE-encode ``pixels`` and apply the VAE's shift/scaling factors, returning float32 latents."""
    with torch.no_grad():
        latents = vae.encode(pixels).latent_dist.sample()
    # Some VAE configs define shift_factor; when it is absent/None this is a no-op.
    shift_factor = getattr(vae.config, "shift_factor", None)
    if shift_factor is not None:
        latents = latents - shift_factor
    return (latents * vae.config.scaling_factor).float()


def _sample_loss(model, processor, vae, transform, instruction, output_img, input_imgs,
                 image_size, device, weight_dtype):
    """Return the flow-matching MSE loss for a single (instruction, target, inputs) sample."""
    target_pixels = transform(output_img).unsqueeze(0).to(device, dtype=weight_dtype)
    latents = _encode_latents(vae, target_pixels)

    # Rectified-flow interpolation. NOTE(review): OmniGen's OmniGenScheduler starts
    # from Gaussian noise at t=0 and integrates z += (t_next - t) * pred up to the
    # image at t=1, i.e. x_t = t * image + (1 - t) * noise with velocity
    # (image - noise). The previous code used the reverse convention (noise at
    # t=1, target noise - image), which fights the pretrained model. Confirm
    # against the installed OmniGen version's scheduler / train_helper.
    noise = torch.randn_like(latents)
    timesteps = torch.rand(1, device=device)
    noisy_latents = timesteps * latents + (1 - timesteps) * noise
    target = latents - noise

    # NOTE(review): this is OmniGen's inference-time processor. It is presumably
    # expected to return classifier-free-guidance copies of the prompt and
    # un-encoded input pixel values rather than "input_img_latents". Confirm it
    # yields a single conditional row here before relying on multi-modal
    # (image-conditioned) samples.
    input_data = processor(
        instruction,
        input_images=input_imgs,
        height=image_size,
        width=image_size,
    )

    # x and the timestep are passed positionally, so this works whether the
    # backbone names its time parameter `t` or `timestep`.
    model_output = model(
        noisy_latents,
        timesteps,
        input_ids=input_data["input_ids"].to(device),
        input_img_latents=input_data.get("input_img_latents"),
        input_image_sizes=input_data.get("input_image_sizes"),
        attention_mask=input_data["attention_mask"].to(device),
        position_ids=input_data["position_ids"].to(device),
    )
    # Some forward variants return (prediction, past_key_values).
    if isinstance(model_output, tuple):
        model_output = model_output[0]
    # Compute the loss in float32: the prediction may be bf16 under autocast.
    return F.mse_loss(model_output.float(), target)


def main():
    """Fine-tune OmniGen (LoRA by default) on the JSONL thumbnail dataset given on the CLI.

    Saves checkpoints, a "best" (lowest epoch-average training loss) and a
    "final" model under ``--results_dir``. With LoRA and ``--push_to_hub`` it
    also pushes the final adapter.

    Raises:
        RuntimeError: if samples keep failing (see ``_MAX_CONSECUTIVE_SAMPLE_FAILURES``)
            or an epoch yields no trainable batch.
    """
    args = parse_args()

    # Trackio is driven directly below, so no accelerate trackers are configured.
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16" if args.bf16 else "no",
    )
    if accelerator.is_main_process:
        trackio.init(
            project="thumbnail-vlm",
            name="omnigen-lora-finetune",
        )
    set_seed(args.seed)

    logger.info(f"Loading OmniGen model from {args.model_name_or_path}...")
    pipe = OmniGenPipeline.from_pretrained(args.model_name_or_path)
    model = pipe.model
    processor = pipe.processor
    vae = pipe.vae

    # The VAE is only used to encode targets; keep it frozen.
    vae.requires_grad_(False)
    vae.eval()

    if args.use_lora:
        logger.info(f"Applying LoRA (rank={args.lora_rank}, alpha={args.lora_alpha})...")
        # No task_type: "CAUSAL_LM" would wrap this diffusion transformer in
        # PeftModelForCausalLM, whose forward injects LM-only kwargs
        # (inputs_embeds, labels, ...). The plain PeftModel forwards args as-is.
        lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=["qkv_proj", "o_proj", "gate_up_proj", "down_proj"],
            bias="none",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    else:
        model.train()

    logger.info(f"Loading dataset from {args.json_file}...")
    dataset = ThumbnailDataset(
        jsonl_path=args.json_file,
        image_dir=args.image_path,
        processor=processor,
        max_image_size=args.max_image_size,
        max_input_length_limit=args.max_input_length_limit,
        keep_raw_resolution=args.keep_raw_resolution,
        condition_dropout_prob=args.condition_dropout_prob,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size_per_device,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True,
    )

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.999),
    )

    # Accelerate also syncs on the last (possibly partial) accumulation window of
    # each epoch, hence the ceil per epoch.
    num_training_steps = math.ceil(len(dataloader) / args.gradient_accumulation_steps) * args.epochs
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps,
    )

    model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, dataloader, lr_scheduler
    )

    weight_dtype = torch.bfloat16 if args.bf16 else torch.float32
    vae = vae.to(accelerator.device, dtype=weight_dtype)
    transform = _build_target_transform(args.max_image_size)

    os.makedirs(args.results_dir, exist_ok=True)
    logger.info("=" * 60)
    logger.info("Training Configuration:")
    logger.info(f" Model: {args.model_name_or_path}")
    logger.info(f" LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}")
    logger.info(f" Dataset: {len(dataset)} samples")
    logger.info(f" Epochs: {args.epochs}")
    logger.info(f" Batch size: {args.batch_size_per_device}")
    logger.info(f" Grad accum: {args.gradient_accumulation_steps}")
    logger.info(f" Effective batch: {args.batch_size_per_device * args.gradient_accumulation_steps * accelerator.num_processes}")
    logger.info(f" LR: {args.lr}")
    logger.info(f" Total steps: {num_training_steps}")
    logger.info(f" Hub model: {args.hub_model_id}")
    logger.info("=" * 60)

    global_step = 0  # counts optimizer updates, not micro-batches
    best_loss = float("inf")
    consecutive_failures = 0
    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        last_loss = None
        for batch in dataloader:
            with accelerator.accumulate(model):
                sample_losses = []
                for instruction, output_img, input_imgs in zip(
                    batch["instructions"], batch["output_images"], batch["input_images"]
                ):
                    try:
                        loss = _sample_loss(
                            model, processor, vae, transform, instruction, output_img,
                            input_imgs, args.max_image_size, accelerator.device, weight_dtype,
                        )
                    except Exception as err:
                        consecutive_failures += 1
                        logger.exception(
                            f"Error processing sample ({consecutive_failures} consecutive failures)"
                        )
                        if consecutive_failures >= _MAX_CONSECUTIVE_SAMPLE_FAILURES:
                            raise RuntimeError(
                                f"{consecutive_failures} consecutive samples failed; aborting "
                                "training (see logged tracebacks)"
                            ) from err
                        continue
                    consecutive_failures = 0
                    sample_losses.append(loss)

                if sample_losses:
                    avg_loss = torch.stack(sample_losses).mean()
                    # NOTE(review): under multi-process DDP, a rank that skips
                    # backward here while others run it may hang the all-reduce.
                    accelerator.backward(avg_loss)
                    last_loss = avg_loss.item()
                    epoch_loss += last_loss
                    num_batches += 1

                # Call these every micro-batch (accelerate no-ops them until the
                # accumulation window closes) so the LR schedule stays aligned
                # with num_training_steps even if a micro-batch failed.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if not accelerator.sync_gradients:
                continue
            global_step += 1

            if global_step % args.log_every == 0 and accelerator.is_main_process and last_loss is not None:
                avg_epoch_loss = epoch_loss / max(num_batches, 1)
                current_lr = lr_scheduler.get_last_lr()[0]
                print(f"step={global_step}, epoch={epoch+1}/{args.epochs}, "
                      f"loss={last_loss:.4f}, avg_loss={avg_epoch_loss:.4f}, "
                      f"lr={current_lr:.2e}")
                trackio.log({
                    "train/loss": last_loss,
                    "train/avg_loss": avg_epoch_loss,
                    "train/lr": current_lr,
                    "train/epoch": epoch + 1,
                    "train/step": global_step,
                })

            if global_step % args.ckpt_every == 0 and accelerator.is_main_process:
                ckpt_dir = os.path.join(args.results_dir, f"checkpoint-{global_step}")
                _save_model(accelerator, model, ckpt_dir, args.use_lora)
                logger.info(f"Saved checkpoint to {ckpt_dir}")

        if num_batches == 0:
            # Otherwise an untrained model would be recorded as "best" with loss 0.0.
            raise RuntimeError(
                f"Epoch {epoch + 1} produced no trainable batches; see logged errors above"
            )

        if accelerator.is_main_process:
            avg_epoch_loss = epoch_loss / num_batches
            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{args.epochs} complete. Avg loss: {avg_epoch_loss:.4f}")
            print(f"{'='*60}\n")
            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                _save_model(accelerator, model, os.path.join(args.results_dir, "best"), args.use_lora)
                logger.info(f"New best model saved (loss={best_loss:.4f})")

    # Make sure every process has finished its last update before saving.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        final_dir = os.path.join(args.results_dir, "final")
        unwrapped_model = _save_model(accelerator, model, final_dir, args.use_lora)
        pushed = False
        if args.push_to_hub:
            if args.use_lora:
                logger.info(f"Pushing LoRA adapters to {args.hub_model_id}...")
                unwrapped_model.push_to_hub(args.hub_model_id, token=os.environ.get("HF_TOKEN"))
                pushed = True
            else:
                logger.warning("push_to_hub is only implemented for LoRA adapters; skipping push")
        logger.info(f"Training complete! Final model saved to {final_dir}")
        logger.info(f"Best loss: {best_loss:.4f}")
        if pushed:
            print(f"\nModel pushed to: https://huggingface.co/{args.hub_model_id}")
        # trackio.init ran only on the main process, so finish only here too.
        trackio.finish()
if __name__ == "__main__":
    # Script entry point: configure root logging once, then run training.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    main()