Any-to-Any
Transformers
English
text-to-image
image-to-image
text-and-image-to-image
multimodal
unified-model
thumbnail-generation
vlm
Instructions to use asats/thumbnail-vlm-janus-pro with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use asats/thumbnail-vlm-janus-pro with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("asats/thumbnail-vlm-janus-pro", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Janus-Pro-7B Fine-Tuning for Thumbnail Generation | |
| Architecture: DeepSeek-LLM-7B + SigLIP (understanding) + VQ-16 (generation) | |
| Method: Full SFT following Janus-4o recipe (arxiv:2506.18095) | |
| Dataset: PosterCraft + ShareGPT-4o-Image + synthetic thumbnail prompts | |
| Supports all 3 input modes: | |
| 1. Text → Thumbnail (T2I) | |
| 2. Image → Thumbnail (I2T2I via captioning + generation) | |
| 3. Text + Image → Thumbnail (T&I2I) | |
| Based on Janus-4o paper hyperparameters: | |
| lr=5e-6, epochs=3, batch=128, full fine-tune | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import math | |
| import random | |
| import logging | |
| import argparse | |
| from pathlib import Path | |
| from typing import Optional, List, Dict, Any, Tuple | |
| from dataclasses import dataclass | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader, DistributedSampler | |
| from PIL import Image | |
| from tqdm import tqdm | |
| import trackio | |
| from transformers import AutoModelForCausalLM, get_cosine_schedule_with_warmup | |
| logger = logging.getLogger(__name__) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # JANUS IMPORTS — requires: pip install -e . from the Janus repo | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| from janus.models import MultiModalityCausalLM, VLChatProcessor | |
@dataclass
class TrainingConfig:
    """Training configuration following the Janus-4o recipe.

    Every field has a default, so ``TrainingConfig()`` is valid and any subset
    of fields can be overridden by keyword argument (``main()`` builds one from
    the parsed command-line arguments this way).

    Raises:
        ValueError: if a count/interval field that the training loop uses as a
            divisor or batch size is smaller than 1, or if ``prompt_mask_prob``
            is outside [0, 1].
    """

    model_path: str = "deepseek-ai/Janus-Pro-7B"

    # Data
    train_jsonl: str = ""  # path to the JSONL training manifest
    image_dir: str = ""  # directory that image filenames in the manifest are relative to

    # Training hyperparameters (from Janus-4o paper §3.3)
    epochs: int = 3
    # Per-device micro-batch. The effective per-device batch is
    # batch_size * gradient_accumulation (2 * 8 = 16 with these defaults);
    # reaching the paper's batch of 128 additionally needs data parallelism.
    batch_size: int = 2
    gradient_accumulation: int = 8
    lr: float = 5e-6
    weight_decay: float = 0.0
    warmup_ratio: float = 0.03  # fraction of optimizer steps used for warmup
    max_grad_norm: float = 1.0

    # CFG training
    prompt_mask_prob: float = 0.10  # probability a prompt is dropped for CFG
    # 50% input VQ tokens masked; NOTE(review): not read by the T2I training step.
    input_image_mask_prob: float = 0.50

    # Model
    image_size: int = 384
    patch_size: int = 16
    image_token_num: int = 576  # (384 / 16) ** 2 = 24 * 24 = 576
    vq_codebook_size: int = 16384
    dtype: str = "bfloat16"  # "bfloat16"; any other value means float32 in this script

    # Output
    output_dir: str = "./results/janus_thumbnail"
    push_to_hub: bool = True
    hub_model_id: str = "asats/thumbnail-vlm-janus-pro"
    save_every: int = 500  # checkpoint interval, in optimizer steps
    log_every: int = 10  # logging interval, in optimizer steps
    seed: int = 42

    def __post_init__(self):
        # Fail fast with a clear message instead of a ZeroDivisionError /
        # DataLoader error deep inside the training loop.
        for name in ("batch_size", "gradient_accumulation", "save_every", "log_every"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        if not 0.0 <= self.prompt_mask_prob <= 1.0:
            raise ValueError(
                f"prompt_mask_prob must be in [0, 1], got {self.prompt_mask_prob}"
            )
class ThumbnailJanusDataset(Dataset):
    """Dataset for Janus-Pro thumbnail fine-tuning.

    Reads a JSONL file; each non-blank line is a JSON object with keys:
      - "instruction": the prompt text
      - "output_image": filename of the target image, relative to ``image_dir``
      - "input_images" (optional): list of filenames; only the first is used

    Each sample produces a dict with:
      - "instruction": the prompt text
      - "target_image": PIL Image (image_size x image_size) to be VQ-encoded
      - "input_image": optional PIL Image for T&I2I mode, else None
      - "mode": 't2i' or 'ti2i'

    When a target image is missing or unreadable, a randomly chosen other
    sample is returned instead (at most ``max_resample_attempts`` tries).
    """

    # Upper bound on the number of samples __getitem__ tries before giving up.
    max_resample_attempts: int = 100

    def __init__(self, jsonl_path: str, image_dir: str, image_size: int = 384):
        self.image_dir = image_dir
        self.image_size = image_size
        self.entries = []
        # Explicit UTF-8 so non-ASCII prompts don't depend on the platform locale.
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    self.entries.append(json.loads(line))
        logger.info(f"Loaded {len(self.entries)} samples")

    def __len__(self):
        return len(self.entries)

    def _load_and_resize(self, filename: str) -> Optional[Image.Image]:
        """Load ``image_dir/filename`` as RGB, center-crop to square and resize.

        Returns None when the file does not exist or cannot be decoded.
        """
        path = os.path.join(self.image_dir, filename)
        if not os.path.exists(path):
            return None
        try:
            # Context manager closes the underlying file; convert() returns a
            # new, fully loaded image that stays valid after the file is closed.
            with Image.open(path) as src:
                img = src.convert("RGB")
            # Center crop to square, then resize to image_size x image_size
            w, h = img.size
            min_dim = min(w, h)
            left = (w - min_dim) // 2
            top = (h - min_dim) // 2
            img = img.crop((left, top, left + min_dim, top + min_dim))
            return img.resize((self.image_size, self.image_size), Image.LANCZOS)
        except Exception as e:
            logger.warning(f"Failed to load {path}: {e}")
            return None

    def __getitem__(self, idx):
        """Return the sample at ``idx``, resampling randomly if its target image is unusable.

        Raises:
            RuntimeError: if no usable target image is found within
                ``max_resample_attempts`` tries.
        """
        # Bounded loop instead of recursion, so many broken paths cannot blow
        # the recursion limit and an all-broken dataset fails with a clear error.
        for _ in range(self.max_resample_attempts):
            entry = self.entries[idx]
            target_image = self._load_and_resize(entry["output_image"])
            if target_image is not None:
                break
            idx = random.randint(0, len(self) - 1)
        else:
            raise RuntimeError(
                f"No loadable target image found after {self.max_resample_attempts} "
                f"attempts (image_dir={self.image_dir!r})"
            )

        # Load input image if available
        input_image = None
        input_image_names = entry.get("input_images", [])
        if input_image_names:
            input_image = self._load_and_resize(input_image_names[0])

        return {
            "instruction": entry["instruction"],
            "target_image": target_image,
            "input_image": input_image,
            "mode": "ti2i" if input_image is not None else "t2i",
        }
def image_to_tensor(img: Image.Image) -> torch.Tensor:
    """Return ``img`` as a float32 channel-first tensor scaled to ``[-1, 1]``.

    Pixel values are divided by 255 and then mapped linearly so that 0 becomes
    -1.0 and 255 becomes 1.0. The NumPy height/width/channel layout is permuted
    to channel/height/width; the result is a permuted view, not a new copy.
    """
    unit = np.asarray(img, dtype=np.float32) / 255.0  # values in [0, 1]
    signed = unit * 2.0 - 1.0
    hwc = torch.from_numpy(signed)
    return hwc.permute(2, 0, 1)
def train_step_t2i(
    model: MultiModalityCausalLM,
    processor: VLChatProcessor,
    instruction: str,
    target_image: Image.Image,
    config: TrainingConfig,
    device: torch.device,
) -> torch.Tensor:
    """Compute the teacher-forced image-token loss for one text-to-image sample.

    1. Encode the target image to VQ token ids with the (frozen) VQ tokenizer.
    2. Build the generation prompt (chat template + image start tag) and embed it.
    3. Teacher-force: feed [prompt | image tokens[:-1]] through the LLM.
    4. Loss = cross-entropy of ``gen_head`` logits against the image token ids.

    Classifier-free-guidance dropout: with probability
    ``config.prompt_mask_prob`` the prompt is replaced by the unconditional
    sequence Janus uses at inference time, i.e. every prompt token except the
    first and the last (the image start tag) is set to ``processor.pad_id``.

    Args:
        model: Janus multimodal model (LLM, VQ tokenizer, gen head/embeddings).
        processor: Janus chat processor providing the template and tokenizer.
        instruction: prompt text describing the target image.
        target_image: image the model should learn to generate.
        config: training configuration (dtype and CFG dropout probability).
        device: device the inputs are moved to.

    Returns:
        Scalar cross-entropy loss averaged over all image token positions.
    """
    dtype = torch.bfloat16 if config.dtype == "bfloat16" else torch.float32

    # 1. Encode target image -> VQ token ids (no gradients through the tokenizer).
    target_tensor = image_to_tensor(target_image).unsqueeze(0).to(device, dtype=dtype)
    with torch.no_grad():
        _quant, _emb_loss, info = model.gen_vision_model.encode(target_tensor)
    target_tokens = info[2].detach().reshape(1, -1)  # [1, N]; N = 576 for 384px / 16
    # Derived from the actual encoding so other image/patch sizes also work.
    num_img_tokens = target_tokens.shape[1]

    # 2. Build the generation prompt: chat template followed by the image start tag.
    conversation = [
        {"role": "<|User|>", "content": instruction},
        {"role": "<|Assistant|>", "content": ""},
    ]
    sft_format = processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=processor.sft_format,
        system_prompt="",
    )
    prompt = sft_format + processor.image_start_tag

    # 3. Tokenize and embed the prompt.
    input_ids = torch.LongTensor(processor.tokenizer.encode(prompt)).unsqueeze(0).to(device)
    if random.random() < config.prompt_mask_prob:
        # Same unconditional layout as the CFG branch of Janus inference: keep
        # the first token and the trailing image start tag, pad everything else.
        input_ids[:, 1:-1] = processor.pad_id
    text_embeds = model.language_model.get_input_embeddings()(input_ids)  # [1, T, hidden]

    # 4. Embed the target image tokens for teacher forcing.
    img_embeds = model.prepare_gen_img_embeds(target_tokens.reshape(-1))
    img_embeds = img_embeds.reshape(1, num_img_tokens, -1)  # [1, N, hidden]

    # 5. Inputs are the prompt plus image tokens 0..N-2; the last image token is
    #    only ever a target, never an input.
    full_embeds = torch.cat([text_embeds, img_embeds[:, :-1, :]], dim=1)  # [1, T+N-1, hidden]

    # 6. Forward through the LLM backbone (hidden states; gen_head applied below).
    hidden = model.language_model.model(inputs_embeds=full_embeds).last_hidden_state

    # 7. Position T-1 (the image start tag) predicts image token 0, and position
    #    T-1+k predicts image token k, so the last N positions carry the predictions.
    text_len = text_embeds.shape[1]
    image_hidden = hidden[:, text_len - 1:, :]  # [1, N, hidden]
    logits = model.gen_head(image_hidden)  # [1, N, codebook]

    # 8. Cross-entropy over the codebook.
    return F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        target_tokens.reshape(-1),
    )
def _collate_samples(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Collate dataset samples into a dict mapping each key to a list of values.

    The default DataLoader collate function cannot batch PIL images or ``None``
    (``input_image`` is None for text-only samples), so values are kept as
    plain lists. The training loop then processes samples one at a time.
    This function is defined at module level so DataLoader workers can pickle it.
    """
    return {key: [sample[key] for sample in batch] for key in batch[0]}


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="deepseek-ai/Janus-Pro-7B")
    parser.add_argument("--train_jsonl", type=str, required=True)
    parser.add_argument("--image_dir", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation", type=int, default=8)
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--output_dir", type=str, default="./results/janus_thumbnail")
    parser.add_argument("--hub_model_id", type=str, default="asats/thumbnail-vlm-janus-pro")
    # `store_true` with default=True can never be switched off, so add an
    # explicit opt-out; `--push_to_hub` stays accepted for existing launch scripts.
    parser.add_argument("--push_to_hub", dest="push_to_hub", action="store_true", default=True)
    parser.add_argument(
        "--no_push_to_hub",
        dest="push_to_hub",
        action="store_false",
        help="Do not upload the final model to the Hugging Face Hub.",
    )
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--log_every", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    # Accepted so launchers that pass it do not fail; not used by this script.
    parser.add_argument("--local_rank", type=int, default=-1)
    return parser.parse_args()


def _save_checkpoint(model, processor, path: str) -> None:
    """Save model weights and processor files into ``path`` (created if missing)."""
    os.makedirs(path, exist_ok=True)
    model.save_pretrained(path)
    processor.save_pretrained(path)


def main():
    """Fine-tune Janus-Pro for text-to-thumbnail generation.

    Steps:
      1. Parse CLI arguments and load the model and processor.
      2. Freeze the SigLIP encoder and the VQ tokenizer.
      3. Train with AdamW, a cosine schedule with warmup, and gradient accumulation.
      4. Save periodic and final checkpoints.
      5. Optionally push the final model to the Hugging Face Hub.
    """
    args = _parse_args()
    config = TrainingConfig(
        model_path=args.model_path,
        train_jsonl=args.train_jsonl,
        image_dir=args.image_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        gradient_accumulation=args.gradient_accumulation,
        lr=args.lr,
        output_dir=args.output_dir,
        hub_model_id=args.hub_model_id,
        push_to_hub=args.push_to_hub,
        save_every=args.save_every,
        log_every=args.log_every,
        seed=args.seed,
    )

    # Set seed
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)

    # Determine device / precision
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16 if config.dtype == "bfloat16" else torch.float32
    # Autocast only matters for reduced-precision compute.
    use_autocast = dtype != torch.float32

    trackio.init(
        project="thumbnail-vlm",
        name="janus-pro-finetune",
    )

    # Load model
    logger.info(f"Loading Janus-Pro from {config.model_path}...")
    processor: VLChatProcessor = VLChatProcessor.from_pretrained(config.model_path)
    # NOTE(review): weights are loaded (and therefore updated by AdamW) in `dtype`.
    # With bfloat16 and lr=5e-6, many small updates may round away; consider
    # float32 master weights if memory allows — confirm with a loss curve.
    model: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        config.model_path,
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    model = model.to(device)
    model.train()

    # Freeze the vision encoder (SigLIP) and the VQ tokenizer (gen_vision_model);
    # all other parameters keep their requires_grad setting.
    if hasattr(model, "vision_model"):
        for param in model.vision_model.parameters():
            param.requires_grad = False
        logger.info("Froze vision encoder (SigLIP)")
    if hasattr(model, "gen_vision_model"):
        for param in model.gen_vision_model.parameters():
            param.requires_grad = False
        logger.info("Froze VQ tokenizer")

    trainable = [p for p in model.parameters() if p.requires_grad]
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in trainable)
    logger.info(f"Total params: {total_params/1e6:.1f}M, Trainable: {trainable_params/1e6:.1f}M")

    # Dataset
    dataset = ThumbnailJanusDataset(
        jsonl_path=config.train_jsonl,
        image_dir=config.image_dir,
        image_size=config.image_size,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
        collate_fn=_collate_samples,
    )
    micro_batches_per_epoch = len(dataloader)
    if micro_batches_per_epoch == 0:
        raise ValueError(
            f"Dataset has {len(dataset)} samples, fewer than batch_size="
            f"{config.batch_size}; nothing to train on (drop_last=True)."
        )

    # Optimizer (AdamW, matching Janus-4o)
    optimizer = torch.optim.AdamW(
        trainable,
        lr=config.lr,
        betas=(0.9, 0.95),
        weight_decay=config.weight_decay,
    )

    # One optimizer step per `gradient_accumulation` micro-batches, plus one for
    # any leftover micro-batches at the end of each epoch.
    steps_per_epoch = math.ceil(micro_batches_per_epoch / config.gradient_accumulation)
    num_steps = steps_per_epoch * config.epochs
    warmup_steps = int(num_steps * config.warmup_ratio)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_steps,
    )

    # Loss scaling is only needed for float16; for bf16/fp32 the scaler is
    # disabled and passes losses and optimizer steps straight through.
    scaler = torch.amp.GradScaler("cuda", enabled=(dtype == torch.float16))

    os.makedirs(config.output_dir, exist_ok=True)
    logger.info("=" * 60)
    logger.info("Janus-Pro Thumbnail Fine-Tuning")
    logger.info(f"  Model: {config.model_path}")
    logger.info(f"  Dataset: {len(dataset)} samples")
    logger.info(f"  Epochs: {config.epochs}")
    logger.info(f"  Batch: {config.batch_size} × {config.gradient_accumulation} = {config.batch_size * config.gradient_accumulation}")
    logger.info(f"  LR: {config.lr}")
    logger.info(f"  Total steps: {num_steps}")
    logger.info(f"  Trainable params: {trainable_params/1e6:.1f}M")
    logger.info("=" * 60)

    # Training loop
    global_step = 0
    # Sum and count of per-sample losses since the last log line, so the logged
    # value is the true mean per-sample loss.
    running_loss_sum = 0.0
    running_loss_count = 0

    for epoch in range(config.epochs):
        for step, batch in enumerate(dataloader):
            # train_step_t2i handles one image at a time; failing samples are
            # logged and skipped (best-effort, as before).
            sample_losses = []
            for i, (instruction, target_image) in enumerate(
                zip(batch["instruction"], batch["target_image"])
            ):
                try:
                    with torch.amp.autocast(device.type, dtype=dtype, enabled=use_autocast):
                        loss = train_step_t2i(
                            model=model,
                            processor=processor,
                            instruction=instruction,
                            target_image=target_image,
                            config=config,
                            device=device,
                        )
                except Exception as e:
                    logger.warning(f"Step {step}, sample {i} error: {e}")
                    continue
                sample_losses.append(loss)

            if sample_losses:
                losses = torch.stack(sample_losses)
                # Mean over this micro-batch divided by the accumulation factor,
                # so the accumulated gradient is the mean over the effective batch.
                scaler.scale(losses.mean() / config.gradient_accumulation).backward()
                running_loss_sum += losses.detach().sum().item()
                running_loss_count += losses.numel()

            at_accumulation_boundary = (step + 1) % config.gradient_accumulation == 0
            is_last_micro_batch = (step + 1) == micro_batches_per_epoch
            if not (at_accumulation_boundary or is_last_micro_batch):
                continue

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable, config.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            # Logging
            if global_step % config.log_every == 0:
                if running_loss_count > 0:
                    avg_loss = running_loss_sum / running_loss_count
                    current_lr = lr_scheduler.get_last_lr()[0]
                    print(f"step={global_step}/{num_steps}, epoch={epoch+1}/{config.epochs}, "
                          f"loss={avg_loss:.4f}, lr={current_lr:.2e}")
                    trackio.log({
                        "train/loss": avg_loss,
                        "train/lr": current_lr,
                        "train/epoch": epoch + 1,
                        "train/step": global_step,
                    })
                running_loss_sum = 0.0
                running_loss_count = 0

            # Save checkpoint
            if global_step % config.save_every == 0:
                ckpt_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                _save_checkpoint(model, processor, ckpt_path)
                logger.info(f"Saved checkpoint: {ckpt_path}")

        # End of epoch
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{config.epochs} complete")
        print(f"{'='*60}\n")

    # Final save
    final_path = os.path.join(config.output_dir, "final")
    _save_checkpoint(model, processor, final_path)

    if config.push_to_hub:
        logger.info(f"Pushing to hub: {config.hub_model_id}")
        model.push_to_hub(config.hub_model_id, token=os.environ.get("HF_TOKEN"))
        processor.push_to_hub(config.hub_model_id, token=os.environ.get("HF_TOKEN"))
        print(f"\nModel pushed to: https://huggingface.co/{config.hub_model_id}")

    trackio.finish()
    logger.info("Training complete!")
if __name__ == "__main__":
    # Configure the root logger before training starts so the module logger's
    # INFO-level progress messages are actually emitted.
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    main()