Any-to-Any
Transformers
English
text-to-image
image-to-image
text-and-image-to-image
multimodal
unified-model
thumbnail-generation
vlm
Instructions to use asats/thumbnail-vlm-janus-pro with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use asats/thumbnail-vlm-janus-pro with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("asats/thumbnail-vlm-janus-pro", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| OmniGen-v1 LoRA Fine-Tuning for Thumbnail Generation | |
| Model: Shitao/OmniGen-v1 (3.8B, Phi-3 based) | |
| Method: LoRA (rank=8) fine-tuning via accelerate | |
| Dataset: PosterCraft/Poster100K + synthetic thumbnail prompts | |
| Output: Image generation model for thumbnails | |
| Input modes supported: | |
| - Text only → Thumbnail image | |
| - Image only → Thumbnail image | |
| - Text + Image → Thumbnail image | |
| Based on OmniGen official fine-tuning recipe: | |
| https://github.com/VectorSpaceLab/OmniGen | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import math | |
| import random | |
| import logging | |
| import argparse | |
| from pathlib import Path | |
| from typing import Optional, List, Dict, Any | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image | |
| from tqdm import tqdm | |
| # OmniGen imports | |
| from OmniGen import OmniGenPipeline | |
| from OmniGen.model import OmniGen | |
| from OmniGen.processor import OmniGenProcessor | |
| from OmniGen.scheduler import OmniGenScheduler | |
| from diffusers import AutoencoderKL | |
| from transformers import get_cosine_schedule_with_warmup | |
| from peft import LoraConfig, get_peft_model | |
| from accelerate import Accelerator | |
| from accelerate.utils import set_seed | |
| import trackio | |
logger = logging.getLogger(__name__)


class ThumbnailDataset(Dataset):
    """Dataset for thumbnail generation training.

    Reads one JSON object per non-blank line of ``jsonl_path``. Each object must
    provide ``"instruction"`` (prompt text) and ``"output_image"`` (target image
    filename, resolved against ``image_dir``); ``"input_images"`` (a list of
    conditioning-image filenames) is optional. Items are returned as PIL images,
    so a collate function that keeps them as Python lists is required
    (``collate_fn`` in this module does that).
    """

    # Maximum number of random replacement draws __getitem__ makes when the
    # requested target image is missing or unreadable. A bounded loop is used
    # instead of recursion so datasets with many missing files cannot exhaust
    # Python's recursion limit.
    max_resample_attempts: int = 16

    def __init__(
        self,
        jsonl_path: str,
        image_dir: str,
        processor: OmniGenProcessor,
        max_image_size: int = 1024,
        max_input_length_limit: int = 18000,
        keep_raw_resolution: bool = True,
        condition_dropout_prob: float = 0.01,
    ):
        """Load every JSONL record into memory.

        Args:
            jsonl_path: Path to the UTF-8 JSONL annotation file.
            image_dir: Directory that image filenames in the records are relative to.
            processor: OmniGen processor; stored on the instance, not used by this class.
            max_image_size: Stored as an attribute only; not read by this class.
            max_input_length_limit: Stored as an attribute only; not read by this class.
            keep_raw_resolution: Stored as an attribute only; not read by this class.
            condition_dropout_prob: Probability that an item's instruction text is
                replaced by the empty string (condition dropout for CFG training).

        Raises:
            OSError: if ``jsonl_path`` cannot be opened.
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        self.image_dir = image_dir
        self.processor = processor
        self.max_image_size = max_image_size
        self.max_input_length_limit = max_input_length_limit
        self.keep_raw_resolution = keep_raw_resolution
        self.condition_dropout_prob = condition_dropout_prob

        # Load JSONL entries (JSON text is UTF-8; don't depend on the locale).
        self.entries: List[Dict[str, Any]] = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    self.entries.append(json.loads(line))
        logger.info(f"Loaded {len(self.entries)} training samples from {jsonl_path}")

    def __len__(self):
        return len(self.entries)

    def _load_image(self, filename: str) -> Optional[Image.Image]:
        """Load ``image_dir/filename`` as an RGB PIL image.

        Returns ``None`` if the file does not exist or cannot be decoded (the
        latter is logged as a warning).
        """
        path = os.path.join(self.image_dir, filename)
        if not os.path.exists(path):
            return None
        try:
            # The context manager closes the underlying file handle; convert()
            # returns a new, fully loaded image that does not depend on it.
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception as e:
            logger.warning(f"Failed to load image {path}: {e}")
            return None

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one training example.

        Returns:
            dict with ``"instruction"`` (str, possibly emptied by condition
            dropout), ``"output_image"`` (RGB PIL image) and ``"input_images"``
            (list of RGB PIL images, or ``None`` when none could be loaded).

        If the target image for ``idx`` is missing or unreadable, a random other
        index is tried instead, up to ``max_resample_attempts`` extra times.

        Raises:
            RuntimeError: if no loadable target image was found in that budget.
        """
        for _ in range(self.max_resample_attempts + 1):
            entry = self.entries[idx]
            output_image = self._load_image(entry["output_image"])
            if output_image is not None:
                break
            # Fall back to a random other sample instead of failing the batch.
            idx = random.randint(0, len(self) - 1)
        else:
            raise RuntimeError(
                f"No loadable target image found after {self.max_resample_attempts + 1} "
                f"attempts; check that images exist under {self.image_dir!r}"
            )

        instruction = entry["instruction"]
        # Condition dropout for CFG training: occasionally drop the text prompt.
        if random.random() < self.condition_dropout_prob:
            instruction = ""

        # `or []` also tolerates an explicit JSON null for "input_images".
        loaded = (self._load_image(name) for name in entry.get("input_images") or [])
        input_images = [img for img in loaded if img is not None]

        return {
            "instruction": instruction,
            "output_image": output_image,
            "input_images": input_images if input_images else None,
        }
def collate_fn(batch):
    """Collate dataset items into a dict of parallel lists, leaving PIL images untouched.

    Args:
        batch: sequence of dicts as produced by ``ThumbnailDataset.__getitem__``.

    Returns:
        dict with ``"instructions"``, ``"output_images"`` and ``"input_images"``,
        each a list aligned with the order of ``batch``.
    """
    # (collated key, per-item key) pairs, in output order.
    field_pairs = (
        ("instructions", "instruction"),
        ("output_images", "output_image"),
        ("input_images", "input_images"),
    )
    return {dst: [item[src] for item in batch] for dst, src in field_pairs}
def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line options for OmniGen LoRA fine-tuning.

    Args:
        argv: Argument list to parse. ``None`` (the default) means ``sys.argv[1:]``,
            so the existing no-argument call keeps working.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    The boolean options ``bf16``, ``use_lora`` and ``push_to_hub`` default to
    True. ``--bf16`` / ``--use_lora`` / ``--push_to_hub`` are kept for
    compatibility, and ``--no_bf16`` / ``--no_lora`` / ``--no_push_to_hub``
    turn them off. A ``store_true`` flag whose default is already True could
    never be disabled.

    Exits through ``parser.error`` (SystemExit, status 2) when an option that is
    used as a batch size or as a divisor/modulus during training is < 1.
    """
    parser = argparse.ArgumentParser(description="OmniGen LoRA Fine-Tuning for Thumbnails")

    # Model
    parser.add_argument("--model_name_or_path", type=str, default="Shitao/OmniGen-v1")

    # Data
    parser.add_argument("--json_file", type=str, required=True, help="Path to JSONL training data")
    parser.add_argument("--image_path", type=str, required=True, help="Root dir for images")
    parser.add_argument("--max_image_size", type=int, default=1024)
    parser.add_argument("--max_input_length_limit", type=int, default=18000)
    parser.add_argument("--keep_raw_resolution", action="store_true")

    # Training
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size_per_device", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--condition_dropout_prob", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--bf16", dest="bf16", action="store_true",
                        help="Train with bf16 mixed precision (default: on)")
    parser.add_argument("--no_bf16", dest="bf16", action="store_false",
                        help="Disable bf16 mixed precision")

    # LoRA
    parser.add_argument("--use_lora", dest="use_lora", action="store_true",
                        help="Fine-tune LoRA adapters (default: on)")
    parser.add_argument("--no_lora", dest="use_lora", action="store_false",
                        help="Fine-tune all model weights instead of LoRA adapters")
    parser.add_argument("--lora_rank", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.05)

    # Output
    parser.add_argument("--results_dir", type=str, default="./results/thumbnail_lora")
    parser.add_argument("--ckpt_every", type=int, default=500)
    parser.add_argument("--log_every", type=int, default=10)
    parser.add_argument("--push_to_hub", dest="push_to_hub", action="store_true",
                        help="Push the final LoRA adapters to the Hub (default: on)")
    parser.add_argument("--no_push_to_hub", dest="push_to_hub", action="store_false",
                        help="Do not push anything to the Hub")
    parser.add_argument("--hub_model_id", type=str, default="asats/thumbnail-vlm-omnigen-lora")

    # The three boolean switches are on unless explicitly turned off.
    parser.set_defaults(bf16=True, use_lora=True, push_to_hub=True)

    args = parser.parse_args(argv)

    # These are used as batch size / divisors / modulus operands in training;
    # reject values that would crash later with an opaque error.
    for name in ("batch_size_per_device", "gradient_accumulation_steps", "ckpt_every", "log_every"):
        value = getattr(args, name)
        if value < 1:
            parser.error(f"--{name} must be >= 1 (got {value})")
    return args
def _build_target_transform(image_size: int):
    """Build the transform that maps a PIL target image to a normalised tensor.

    The image is resized to a square ``image_size`` x ``image_size`` (aspect ratio
    is not preserved), converted with ``ToTensor`` (values in [0, 1]) and then
    normalised with mean 0.5 / std 0.5, i.e. mapped to [-1, 1].
    """
    # Imported lazily: torchvision is not a module-level import of this script.
    from torchvision import transforms

    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


def _save_model(accelerator, model, out_dir: str, use_lora: bool):
    """Save *model* into *out_dir* and return the unwrapped model.

    With *use_lora*, the PEFT adapters are written via ``save_pretrained``;
    otherwise the full ``state_dict`` is written to ``out_dir/model.pt``.
    """
    os.makedirs(out_dir, exist_ok=True)
    unwrapped_model = accelerator.unwrap_model(model)
    if use_lora:
        unwrapped_model.save_pretrained(out_dir)
    else:
        torch.save(unwrapped_model.state_dict(), os.path.join(out_dir, "model.pt"))
    return unwrapped_model


def main():
    """Fine-tune OmniGen (LoRA by default) on a thumbnail JSONL dataset.

    Steps: parse CLI args, then create the Accelerator and the trackio run
    (main process). Load the OmniGen pipeline and freeze its VAE, and optionally
    wrap the model with LoRA. Train with a flow-matching MSE objective, saving
    periodic, best-epoch and final weights under ``--results_dir``. With
    ``--push_to_hub``, push the LoRA adapters to the Hugging Face Hub.
    """
    args = parse_args()

    # Accelerate handles device placement, mixed precision and gradient
    # accumulation. Metrics go to trackio directly; accelerate's own trackers
    # are never initialised here (there is no init_trackers call).
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16" if args.bf16 else "no",
        log_with="all",
    )

    # trackio is initialised on the main process only, so every other trackio
    # call below is guarded by accelerator.is_main_process as well.
    if accelerator.is_main_process:
        trackio.init(
            project="thumbnail-vlm",
            name="omnigen-lora-finetune",
        )

    set_seed(args.seed)
    weight_dtype = torch.bfloat16 if args.bf16 else torch.float32

    logger.info(f"Loading OmniGen model from {args.model_name_or_path}...")
    pipe = OmniGenPipeline.from_pretrained(args.model_name_or_path)
    model = pipe.model
    processor = pipe.processor
    vae = pipe.vae

    # The VAE only encodes target images; keep it frozen and in eval mode.
    vae.requires_grad_(False)
    vae.eval()

    if args.use_lora:
        logger.info(f"Applying LoRA (rank={args.lora_rank}, alpha={args.lora_alpha})...")
        # LoRA on the transformer backbone's attention and MLP projections.
        # No task_type on purpose: task_type="CAUSAL_LM" wraps the model in
        # PeftModelForCausalLM, which assumes a Hugging Face causal-LM interface
        # and adds LM-specific kwargs (e.g. `labels`) to every forward call.
        # The generic PeftModel (no task_type) forwards the keyword arguments
        # used in the training loop unchanged.
        lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=["qkv_proj", "o_proj", "gate_up_proj", "down_proj"],
            bias="none",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    else:
        model.train()

    # Dataset
    logger.info(f"Loading dataset from {args.json_file}...")
    dataset = ThumbnailDataset(
        jsonl_path=args.json_file,
        image_dir=args.image_path,
        processor=processor,
        max_image_size=args.max_image_size,
        max_input_length_limit=args.max_input_length_limit,
        keep_raw_resolution=args.keep_raw_resolution,
        condition_dropout_prob=args.condition_dropout_prob,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size_per_device,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True,
    )

    # Only parameters that require grad (the LoRA adapters when use_lora) are optimised.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.999),
    )

    # Scheduler horizon in optimizer steps, computed from the dataloader length
    # before accelerator.prepare() wraps it.
    num_training_steps = len(dataloader) * args.epochs // args.gradient_accumulation_steps
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps,
    )

    model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, dataloader, lr_scheduler
    )

    # The frozen VAE is not passed through prepare(); move it manually.
    vae = vae.to(accelerator.device, dtype=weight_dtype)
    # Built once here instead of once per sample inside the training loop.
    # NOTE(review): args.keep_raw_resolution is not consulted; every target is
    # resized to a square of side args.max_image_size.
    target_transform = _build_target_transform(args.max_image_size)

    os.makedirs(args.results_dir, exist_ok=True)
    logger.info("=" * 60)
    logger.info("Training Configuration:")
    logger.info(f"  Model: {args.model_name_or_path}")
    logger.info(f"  LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}")
    logger.info(f"  Dataset: {len(dataset)} samples")
    logger.info(f"  Epochs: {args.epochs}")
    logger.info(f"  Batch size: {args.batch_size_per_device}")
    logger.info(f"  Grad accum: {args.gradient_accumulation_steps}")
    logger.info(f"  Effective batch: {args.batch_size_per_device * args.gradient_accumulation_steps * accelerator.num_processes}")
    logger.info(f"  LR: {args.lr}")
    logger.info(f"  Total steps: {num_training_steps}")
    logger.info(f"  Hub model: {args.hub_model_id}")
    logger.info("=" * 60)

    # global_step counts micro-batches that produced at least one valid sample
    # (it is not divided by gradient_accumulation_steps), so log_every and
    # ckpt_every are measured in micro-batches.
    global_step = 0
    best_loss = float("inf")
    # Samples this process trained on successfully over the whole run.
    total_valid_samples = 0

    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            with accelerator.accumulate(model):
                sample_losses = []
                samples = zip(batch["instructions"], batch["output_images"], batch["input_images"])
                for i, (instruction, output_img, input_imgs) in enumerate(samples):
                    try:
                        # Encode the target image into scaled VAE latents (no gradients).
                        target_tensor = target_transform(output_img).unsqueeze(0).to(
                            accelerator.device, dtype=weight_dtype
                        )
                        with torch.no_grad():
                            latents = vae.encode(target_tensor).latent_dist.sample()
                            latents = latents * vae.config.scaling_factor

                        # Flow matching: x_t = (1 - t) * x0 + t * noise with t ~ U[0, 1),
                        # so x_t is pure noise at t = 1.
                        # NOTE(review): confirm this time direction matches OmniGen's sampler.
                        noise = torch.randn_like(latents)
                        timesteps = torch.rand(1, device=accelerator.device)
                        noisy_latents = (1 - timesteps) * latents + timesteps * noise

                        # OmniGen processes text+images together through the Phi-3 backbone.
                        input_data = processor(
                            instruction,
                            input_images=input_imgs,
                            height=args.max_image_size,
                            width=args.max_image_size,
                        )

                        # NOTE(review): confirm the processor's output keys and OmniGen's
                        # forward() keyword names / return type (e.g. `t` vs `timestep`,
                        # tensor vs tuple). A mismatch raises here and is caught by the
                        # except below, which would then skip every sample.
                        model_output = model(
                            input_ids=input_data["input_ids"].to(accelerator.device),
                            input_img_latents=input_data.get("input_img_latents"),
                            input_image_sizes=input_data.get("input_image_sizes"),
                            attention_mask=input_data["attention_mask"].to(accelerator.device),
                            position_ids=input_data["position_ids"].to(accelerator.device),
                            x=noisy_latents,
                            t=timesteps,
                        )

                        # Velocity target d x_t / dt = noise - x0. The MSE is computed
                        # in float32 so the reduction is not done in bf16.
                        target = noise - latents
                        loss = F.mse_loss(model_output.float(), target.float())
                        sample_losses.append(loss)
                    except Exception as e:
                        # Best-effort: skip a bad sample rather than abort the run.
                        logger.warning(f"Error processing sample {i}: {e}")

                # NOTE(review): when no sample in a micro-batch succeeds, backward is
                # skipped on this process only; confirm this cannot stall DDP
                # gradient synchronisation in multi-GPU runs.
                if sample_losses:
                    avg_loss = torch.stack(sample_losses).mean()
                    accelerator.backward(avg_loss)
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model.parameters(), 1.0)
                    # Inside accumulate(), the prepared optimizer/scheduler only
                    # update on gradient-sync steps.
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                    loss_value = avg_loss.item()
                    epoch_loss += loss_value
                    num_batches += 1
                    global_step += 1
                    total_valid_samples += len(sample_losses)

                    # Logging
                    if global_step % args.log_every == 0 and accelerator.is_main_process:
                        avg_epoch_loss = epoch_loss / max(num_batches, 1)
                        current_lr = lr_scheduler.get_last_lr()[0]
                        print(f"step={global_step}, epoch={epoch+1}/{args.epochs}, "
                              f"loss={loss_value:.4f}, avg_loss={avg_epoch_loss:.4f}, "
                              f"lr={current_lr:.2e}")
                        trackio.log({
                            "train/loss": loss_value,
                            "train/avg_loss": avg_epoch_loss,
                            "train/lr": current_lr,
                            "train/epoch": epoch + 1,
                            "train/step": global_step,
                        })

                    # Checkpoint
                    if global_step % args.ckpt_every == 0 and accelerator.is_main_process:
                        ckpt_dir = os.path.join(args.results_dir, f"checkpoint-{global_step}")
                        _save_model(accelerator, model, ckpt_dir, args.use_lora)
                        logger.info(f"Saved checkpoint to {ckpt_dir}")

        # End of epoch logging
        if accelerator.is_main_process:
            avg_epoch_loss = epoch_loss / max(num_batches, 1)
            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{args.epochs} complete. Avg loss: {avg_epoch_loss:.4f}")
            print(f"{'='*60}\n")
            # An epoch with no successful batch reports 0.0, which must not be
            # mistaken for a new best loss.
            if num_batches > 0 and avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                _save_model(accelerator, model, os.path.join(args.results_dir, "best"), args.use_lora)
                logger.info(f"New best model saved (loss={best_loss:.4f})")

    # Make sure every process has finished training before the main process
    # reads the final weights.
    accelerator.wait_for_everyone()

    # Final save and push to hub
    if accelerator.is_main_process:
        final_dir = os.path.join(args.results_dir, "final")
        unwrapped_model = _save_model(accelerator, model, final_dir, args.use_lora)

        pushed = False
        if args.push_to_hub:
            if not args.use_lora:
                logger.warning("--push_to_hub is only implemented for LoRA adapters; skipping push.")
            elif total_valid_samples == 0:
                logger.error("No training sample was processed successfully on this process; "
                             "skipping push of untrained adapters.")
            else:
                logger.info(f"Pushing LoRA adapters to {args.hub_model_id}...")
                unwrapped_model.push_to_hub(args.hub_model_id, token=os.environ.get("HF_TOKEN"))
                pushed = True

        logger.info(f"Training complete! Final model saved to {final_dir}")
        logger.info(f"Best loss: {best_loss:.4f}")
        if pushed:
            print(f"\nModel pushed to: https://huggingface.co/{args.hub_model_id}")

        trackio.finish()
if __name__ == "__main__":
    # Configure the root logger before main() runs so that the module-level
    # logger's info() messages are emitted with timestamps and level names.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    main()