"""
Janus-Pro-7B Fine-Tuning for Thumbnail Generation

Architecture: DeepSeek-LLM-7B + SigLIP (understanding) + VQ-16 (generation)
Method: Full SFT following Janus-4o recipe (arxiv:2506.18095)
Dataset: PosterCraft + ShareGPT-4o-Image + synthetic thumbnail prompts

Intended input modes:
  1. Text → Thumbnail (T2I)
  2. Image → Thumbnail (I2T2I via captioning + generation)
  3. Text + Image → Thumbnail (T&I2I)

NOTE: the training loop below currently optimizes only the T2I objective
(``train_step_t2i``). Input images are loaded by the dataset but are not yet
fed to the model, and ``input_image_mask_prob`` is not yet used.

Based on Janus-4o paper hyperparameters: lr=5e-6, epochs=3, batch=128,
full fine-tune.
"""

import os
import sys
import json
import math
import random
import logging
import argparse
from pathlib import Path
from typing import Optional, List, Dict, Any, Tuple
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from PIL import Image
from tqdm import tqdm
import trackio
from transformers import AutoModelForCausalLM, get_cosine_schedule_with_warmup

logger = logging.getLogger(__name__)

# ─────────────────────────────────────────────────────────────────────────────
# JANUS IMPORTS — requires `pip install -e .` from the Janus repo
# ─────────────────────────────────────────────────────────────────────────────
from janus.models import MultiModalityCausalLM, VLChatProcessor


@dataclass
class TrainingConfig:
    """Training configuration following the Janus-4o recipe."""

    model_path: str = "deepseek-ai/Janus-Pro-7B"

    # Data
    train_jsonl: str = ""
    image_dir: str = ""

    # Training hyperparameters (from Janus-4o paper §3.3)
    epochs: int = 3
    # Per-device micro-batch. Effective batch per optimizer step is
    # batch_size * gradient_accumulation (16 with the defaults; this script
    # runs as a single process, so reaching the paper's 128 needs more
    # accumulation or a distributed launcher).
    batch_size: int = 2
    gradient_accumulation: int = 8
    lr: float = 5e-6
    weight_decay: float = 0.0
    warmup_ratio: float = 0.03
    max_grad_norm: float = 1.0

    # CFG training
    prompt_mask_prob: float = 0.10  # 10% prompts masked for CFG
    input_image_mask_prob: float = 0.50  # 50% input VQ tokens masked (not yet used)

    # Model
    image_size: int = 384
    patch_size: int = 16
    image_token_num: int = 576  # 384/16 = 24, 24*24 = 576
    vq_codebook_size: int = 16384
    dtype: str = "bfloat16"

    # Output
    output_dir: str = "./results/janus_thumbnail"
    push_to_hub: bool = True
    hub_model_id: str = "asats/thumbnail-vlm-janus-pro"
    save_every: int = 500
    log_every: int = 10
    seed: int = 42


class ThumbnailJanusDataset(Dataset):
    """Dataset for Janus-Pro thumbnail fine-tuning.

    Reads a JSON-lines file; each non-empty line is an object with keys
    ``instruction`` (prompt text), ``output_image`` (filename relative to
    ``image_dir``) and optionally ``input_images`` (list of filenames).

    Each sample produces:
        - instruction: the prompt text
        - target_image: PIL Image (image_size x image_size) to be VQ-encoded
        - input_image: Optional PIL Image for T&I2I mode (first entry of
          ``input_images``), or None
        - mode: 't2i' or 'ti2i'
    """

    # How many random replacement entries to try when a target image cannot
    # be loaded, before giving up (prevents unbounded recursion/looping when
    # many files are missing).
    max_retries: int = 10

    def __init__(self, jsonl_path: str, image_dir: str, image_size: int = 384):
        self.image_dir = image_dir
        self.image_size = image_size
        self.entries: List[Dict[str, Any]] = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    self.entries.append(json.loads(line))
        logger.info(f"Loaded {len(self.entries)} samples")

    def __len__(self) -> int:
        return len(self.entries)

    def _load_and_resize(self, filename: str) -> Optional[Image.Image]:
        """Load an RGB image, center-crop it to a square and resize it.

        Returns None if the file is missing or cannot be decoded.
        """
        path = os.path.join(self.image_dir, filename)
        if not os.path.exists(path):
            return None
        try:
            # Context manager closes the underlying file handle; convert()
            # returns a fully loaded copy, so it stays valid afterwards.
            with Image.open(path) as raw:
                img = raw.convert("RGB")
            # Center crop to square, then resize to image_size x image_size
            w, h = img.size
            min_dim = min(w, h)
            left = (w - min_dim) // 2
            top = (h - min_dim) // 2
            img = img.crop((left, top, left + min_dim, top + min_dim))
            img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
            return img
        except Exception as e:
            logger.warning(f"Failed to load {path}: {e}")
            return None

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one sample; entries whose target image fails to load are
        replaced by a randomly chosen entry (bounded by ``max_retries``).

        Raises:
            RuntimeError: if no loadable target image is found.
        """
        for _ in range(self.max_retries + 1):
            entry = self.entries[idx]
            target_image = self._load_and_resize(entry["output_image"])
            if target_image is not None:
                break
            idx = random.randint(0, len(self) - 1)
        else:
            raise RuntimeError(
                f"Could not load a target image after {self.max_retries + 1} attempts"
            )

        instruction = entry["instruction"]
        input_image_names = entry.get("input_images", [])

        # Load input image if available
        input_image = None
        if input_image_names:
            input_image = self._load_and_resize(input_image_names[0])

        mode = "ti2i" if input_image is not None else "t2i"

        return {
            "instruction": instruction,
            "target_image": target_image,
            "input_image": input_image,
            "mode": mode,
        }


def collate_samples(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Collate dataset samples into a dict of per-key Python lists.

    torch's default collate cannot batch PIL images or ``None`` values, so the
    samples are kept as lists and processed one at a time by the training
    loop. Defined at module level so DataLoader workers can pickle it.
    """
    if not batch:
        return {}
    return {key: [sample[key] for sample in batch] for key in batch[0]}


def image_to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert PIL image to a [3, H, W] float tensor normalized to [-1, 1]."""
    arr = np.array(img).astype(np.float32) / 255.0
    arr = arr * 2.0 - 1.0  # [-1, 1]
    tensor = torch.from_numpy(arr).permute(2, 0, 1)  # [3, H, W]
    return tensor


def train_step_t2i(
    model: MultiModalityCausalLM,
    processor: VLChatProcessor,
    instruction: str,
    target_image: Image.Image,
    config: TrainingConfig,
    device: torch.device,
) -> torch.Tensor:
    """Forward pass for text-to-image thumbnail generation.

    1. Encode target image to VQ tokens (target) with the frozen VQ model
    2. Build text input embeddings (with CFG prompt dropout)
    3. Teacher-force: predict VQ tokens autoregressively
    4. Loss = CE on image token predictions

    Returns:
        Scalar cross-entropy loss over all image-token positions.
    """
    dtype = torch.bfloat16 if config.dtype == "bfloat16" else torch.float32

    # 1. Encode target image → VQ token ids (tokenizer is frozen; no grad)
    target_tensor = image_to_tensor(target_image).unsqueeze(0).to(device, dtype=dtype)
    with torch.no_grad():
        _, _, info = model.gen_vision_model.encode(target_tensor)
    target_tokens = info[2].detach().reshape(1, -1)  # [1, num_image_tokens]
    num_image_tokens = target_tokens.shape[1]

    # 2. Build conversation prompt
    conversation = [
        {"role": "<|User|>", "content": instruction},
        {"role": "<|Assistant|>", "content": ""},
    ]
    sft_format = processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=processor.sft_format,
        system_prompt="",
    )
    prompt = sft_format + processor.image_start_tag

    # 3. Tokenize
    input_ids = processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device)  # [1, seq_len]

    # CFG prompt dropout: with prob prompt_mask_prob, replace every token
    # between the first (BOS) and last (image-start tag) with pad_id. This
    # mirrors the unconditional branch built by the reference Janus generation
    # script, so the model learns the same "empty" context it sees at
    # inference time.
    if random.random() < config.prompt_mask_prob:
        input_ids[:, 1:-1] = processor.pad_id

    text_embeds = model.language_model.get_input_embeddings()(input_ids)  # [1, seq_len, D]

    # 4. Get image token embeddings (teacher forcing)
    img_embeds = model.prepare_gen_img_embeds(target_tokens.reshape(-1))
    img_embeds = img_embeds.reshape(1, num_image_tokens, -1)  # [1, N, D]

    # 5. Concat: [text | img_tokens[:-1]] → predict img_tokens[0:]
    full_embeds = torch.cat([text_embeds, img_embeds[:, :-1, :]], dim=1)  # [1, seq_len+N-1, D]

    # 6. Forward through LLM (no KV cache needed during training)
    outputs = model.language_model.model(inputs_embeds=full_embeds, use_cache=False)
    hidden = outputs.last_hidden_state  # [1, seq_len+N-1, D]

    # 7. Extract logits for image token positions only.
    # Position text_len-1 (the image-start tag) predicts image token 0; each
    # later position predicts the next image token, giving N predictions.
    text_len = text_embeds.shape[1]
    image_hidden = hidden[:, text_len - 1:, :]  # [1, N, D]
    logits = model.gen_head(image_hidden)  # [1, N, codebook_size]

    # 8. Cross-entropy loss (computed in fp32 for numerical stability)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]).float(),
        target_tokens.reshape(-1),
    )
    return loss


def main():
    """Parse CLI args, fine-tune Janus-Pro on the thumbnail dataset, then save
    (and optionally push) the resulting model and processor."""
    # Parse config
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="deepseek-ai/Janus-Pro-7B")
    parser.add_argument("--train_jsonl", type=str, required=True)
    parser.add_argument("--image_dir", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation", type=int, default=8)
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--output_dir", type=str, default="./results/janus_thumbnail")
    parser.add_argument("--hub_model_id", type=str, default="asats/thumbnail-vlm-janus-pro")
    # Pushing defaults to on; --no_push_to_hub turns it off (the original
    # store_true + default=True flag could never be disabled).
    parser.add_argument("--push_to_hub", dest="push_to_hub", action="store_true")
    parser.add_argument("--no_push_to_hub", dest="push_to_hub", action="store_false")
    parser.set_defaults(push_to_hub=True)
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--log_every", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    # Accepted for launcher compatibility; this script does not use it.
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()

    config = TrainingConfig(
        model_path=args.model_path,
        train_jsonl=args.train_jsonl,
        image_dir=args.image_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        gradient_accumulation=args.gradient_accumulation,
        lr=args.lr,
        output_dir=args.output_dir,
        hub_model_id=args.hub_model_id,
        push_to_hub=args.push_to_hub,
        save_every=args.save_every,
        log_every=args.log_every,
        seed=args.seed,
    )

    # Set seed
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)

    # Determine device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16 if config.dtype == "bfloat16" else torch.float32

    # Initialize trackio
    trackio.init(
        project="thumbnail-vlm",
        name="janus-pro-finetune",
    )

    # Load model
    logger.info(f"Loading Janus-Pro from {config.model_path}...")
    processor: VLChatProcessor = VLChatProcessor.from_pretrained(config.model_path)
    # NOTE(review): weights are held in `dtype` (bf16 by default) and updated
    # in place by AdamW. With lr=5e-6 many updates may fall below bf16
    # resolution and round away; consider fp32 master weights (e.g. a
    # mixed-precision optimizer wrapper) if memory allows.
    model: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        config.model_path,
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    model = model.to(device)
    model.train()

    # Freeze the vision encoder (SigLIP) — only train LLM + gen_head + gen_aligner
    # This follows common practice for generation fine-tuning
    if hasattr(model, "vision_model"):
        for param in model.vision_model.parameters():
            param.requires_grad = False
        model.vision_model.eval()  # frozen: keep in inference mode
        logger.info("Froze vision encoder (SigLIP)")

    # Freeze VQ tokenizer (gen_vision_model)
    if hasattr(model, "gen_vision_model"):
        for param in model.gen_vision_model.parameters():
            param.requires_grad = False
        model.gen_vision_model.eval()  # frozen: keep in inference mode
        logger.info("Froze VQ tokenizer")

    # Count trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total params: {total_params/1e6:.1f}M, Trainable: {trainable_params/1e6:.1f}M")

    # Dataset
    dataset = ThumbnailJanusDataset(
        jsonl_path=config.train_jsonl,
        image_dir=config.image_dir,
        image_size=config.image_size,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
        # Samples contain PIL images / None, which default_collate rejects.
        collate_fn=collate_samples,
    )
    if len(dataloader) == 0:
        raise ValueError(
            f"Dataset has {len(dataset)} samples, fewer than batch_size="
            f"{config.batch_size} with drop_last=True; nothing to train on."
        )

    # Optimizer (AdamW, matching Janus-4o)
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config.lr,
        betas=(0.9, 0.95),
        weight_decay=config.weight_decay,
    )

    # Scheduler: one optimizer step per `gradient_accumulation` micro-batches,
    # plus one flush step for any leftover micro-batches at epoch end.
    steps_per_epoch = math.ceil(len(dataloader) / config.gradient_accumulation)
    num_steps = steps_per_epoch * config.epochs
    warmup_steps = int(num_steps * config.warmup_ratio)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_steps,
    )

    # No GradScaler: loss scaling is only needed for float16. bfloat16 shares
    # float32's exponent range and float32 needs no scaling.
    use_autocast = dtype != torch.float32

    os.makedirs(config.output_dir, exist_ok=True)

    logger.info("=" * 60)
    logger.info("Janus-Pro Thumbnail Fine-Tuning")
    logger.info(f"  Model: {config.model_path}")
    logger.info(f"  Dataset: {len(dataset)} samples")
    logger.info(f"  Epochs: {config.epochs}")
    logger.info(f"  Batch: {config.batch_size} × {config.gradient_accumulation} = {config.batch_size * config.gradient_accumulation}")
    logger.info(f"  LR: {config.lr}")
    logger.info(f"  Total steps: {num_steps}")
    logger.info(f"  Trainable params: {trainable_params/1e6:.1f}M")
    logger.info("=" * 60)

    # Training loop
    global_step = 0
    running_loss = 0.0  # sum of per-micro-batch mean losses since last log
    running_count = 0  # number of micro-batches contributing to running_loss
    optimizer.zero_grad()

    for epoch in range(config.epochs):
        for step, batch in enumerate(dataloader):
            # Process each sample in the micro-batch individually (variable
            # prompt lengths, no padding); failed samples are skipped.
            sample_losses = []
            for i, (instruction, target_image) in enumerate(
                zip(batch["instruction"], batch["target_image"])
            ):
                try:
                    with torch.amp.autocast(device.type, dtype=dtype, enabled=use_autocast):
                        loss = train_step_t2i(
                            model=model,
                            processor=processor,
                            instruction=instruction,
                            target_image=target_image,
                            config=config,
                            device=device,
                        )
                    sample_losses.append(loss)
                except Exception as e:
                    logger.warning(f"Step {step}, sample {i} error: {e}")

            if sample_losses:
                # Mean over the micro-batch, scaled so that the gradients of
                # `gradient_accumulation` micro-batches sum to a mean.
                batch_loss = torch.stack(sample_losses).mean()
                (batch_loss / config.gradient_accumulation).backward()
                running_loss += batch_loss.item()
                running_count += 1

            # Step on every full accumulation window, and also at the end of
            # the epoch so a partial window does not leak into the next epoch.
            is_last_batch = step + 1 == len(dataloader)
            if (step + 1) % config.gradient_accumulation == 0 or is_last_batch:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # Logging
                if global_step % config.log_every == 0:
                    avg_loss = running_loss / max(running_count, 1)
                    current_lr = lr_scheduler.get_last_lr()[0]
                    print(f"step={global_step}/{num_steps}, epoch={epoch+1}/{config.epochs}, "
                          f"loss={avg_loss:.4f}, lr={current_lr:.2e}")
                    trackio.log({
                        "train/loss": avg_loss,
                        "train/lr": current_lr,
                        "train/epoch": epoch + 1,
                        "train/step": global_step,
                    })
                    running_loss = 0.0
                    running_count = 0

                # Save checkpoint
                if global_step % config.save_every == 0:
                    ckpt_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                    os.makedirs(ckpt_path, exist_ok=True)
                    model.save_pretrained(ckpt_path)
                    processor.save_pretrained(ckpt_path)
                    logger.info(f"Saved checkpoint: {ckpt_path}")

        # End of epoch
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{config.epochs} complete")
        print(f"{'='*60}\n")

    # Final save
    final_path = os.path.join(config.output_dir, "final")
    os.makedirs(final_path, exist_ok=True)
    model.save_pretrained(final_path)
    processor.save_pretrained(final_path)

    if config.push_to_hub:
        logger.info(f"Pushing to hub: {config.hub_model_id}")
        model.push_to_hub(config.hub_model_id, token=os.environ.get("HF_TOKEN"))
        processor.push_to_hub(config.hub_model_id, token=os.environ.get("HF_TOKEN"))
        print(f"\nModel pushed to: https://huggingface.co/{config.hub_model_id}")

    trackio.finish()
    logger.info("Training complete!")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    main()