"""
End-to-End Thumbnail VLM Training Pipeline

This script:
1. Downloads and prepares training data (PosterCraft/Poster100K)
2. Fine-tunes Janus-Pro-7B for thumbnail generation
3. Evaluates the model with sample generations
4. Pushes to HuggingFace Hub

Run with: python run_training.py
"""

import os
import sys
import json
import random
import logging
import io
import base64
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
import trackio

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ─── Configuration ──────────────────────────────────────────────────────────
MODEL_PATH = "deepseek-ai/Janus-Pro-7B"
HUB_MODEL_ID = "asats/thumbnail-vlm-janus-pro"
OUTPUT_DIR = "/app/results/janus_thumbnail"
DATA_DIR = "/app/thumbnail_data"
IMAGE_DIR = os.path.join(DATA_DIR, "images")
JSONL_PATH = os.path.join(DATA_DIR, "train.jsonl")

# Training hyperparameters (Janus-4o recipe)
EPOCHS = 3
BATCH_SIZE = 1  # per device (single GPU)
GRADIENT_ACCUMULATION = 16  # effective batch = 16
LR = 5e-6
IMAGE_SIZE = 384
IMAGE_TOKEN_NUM = 576  # (IMAGE_SIZE // 16) ** 2 positions per image
VQ_CODEBOOK_SIZE = 16384
CFG_MASK_PROB = 0.10
SAVE_EVERY = 200
LOG_EVERY = 5
SEED = 42

os.makedirs(IMAGE_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Instruction wrappers applied to PosterCraft captions.
_T2I_TEMPLATES = [
    "Generate a professional thumbnail: {caption}",
    "Create an eye-catching visual thumbnail: {caption}",
    "Design a compelling thumbnail image: {caption}",
    "Generate a high-quality thumbnail: {caption}",
    "Create a visually striking thumbnail: {caption}",
]

# Synthetic prompt templates, grouped by video category.
_SYNTHETIC_TEMPLATES = {
    "tech": [
        "Professional tech review thumbnail: sleek {product} product shot with dramatic studio lighting, bold comparison graphics, modern gradient background, text overlay '{title}'",
        "Technology unboxing thumbnail: hands opening {product} box, excited reaction, bright colorful background, text '{title}'",
    ],
    "cooking": [
        "Appetizing cooking thumbnail: close-up of {dish} with steam rising, warm golden lighting, rustic kitchen background, text '{title}'",
        "Food recipe thumbnail: beautifully plated {dish}, overhead flat-lay shot, marble countertop, fresh ingredients scattered, text '{title}'",
    ],
    "gaming": [
        "Epic gaming thumbnail: dramatic gameplay scene from {game}, character in action pose, particle effects, neon glow, bold text '{title}'",
        "Gaming reaction thumbnail: intense player reaction face, {game} gameplay in background, fire effects, text '{title}'",
    ],
    "fitness": [
        "Fitness motivation thumbnail: athletic person doing {exercise}, dynamic studio lighting, energetic colors, bold text '{title}'",
        "Health transformation thumbnail: before/after split screen, clean modern design, motivational text '{title}'",
    ],
    "education": [
        "Educational explainer thumbnail: clean whiteboard graphic about {topic}, colorful diagrams, professional font, text '{title}'",
        "Science video thumbnail: fascinating visualization of {topic}, bright curiosity-inducing colors, text '{title}'",
    ],
    "travel": [
        "Travel vlog thumbnail: stunning panoramic view of {place}, golden hour lighting, wanderlust aesthetic, text '{title}'",
        "Adventure thumbnail: breathtaking landscape of {place}, person silhouette, dramatic sky, text '{title}'",
    ],
}

# Values substituted for the ``{placeholder}`` fields of the synthetic templates.
_SYNTHETIC_FILL_VALUES = {
    "product": ["iPhone 16 Pro", "MacBook Air M4", "PS5 Pro", "Tesla Model Y", "Quest 4", "Galaxy S25"],
    "dish": ["pasta carbonara", "ramen bowl", "sushi platter", "chocolate lava cake", "grilled salmon"],
    "game": ["Zelda Echoes", "GTA VI", "Minecraft", "Fortnite", "Elden Ring DLC", "Cyberpunk 2078"],
    "exercise": ["deadlifts", "yoga flow", "HIIT workout", "marathon training", "calisthenics"],
    "topic": ["quantum computing", "AI revolution", "black holes", "climate change", "human brain"],
    "place": ["Tokyo streets", "Santorini sunset", "Iceland northern lights", "Bali rice terraces", "Swiss Alps"],
    "title": [
        "You NEED To See This!",
        "GAME CHANGER!",
        "The Ultimate Guide",
        "Top 10 SECRETS",
        "I Tried This For 30 Days",
        "Watch Before You Buy!",
        "Is It Worth It?",
        "HONEST Review",
        "This Changed EVERYTHING",
        "Don't Make This Mistake!",
    ],
}


# ─── Shared helpers ─────────────────────────────────────────────────────────

def _select_device():
    """Return the CUDA device when available, otherwise the CPU device."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _build_generation_prompt(processor, instruction):
    """Format *instruction* as a single-turn chat prompt ending in the image-start tag."""
    conversation = [
        {"role": "<|User|>", "content": instruction},
        {"role": "<|Assistant|>", "content": ""},
    ]
    sft_format = processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=processor.sft_format,
        system_prompt="",
    )
    return sft_format + processor.image_start_tag


def _save_model(model, processor, out_dir):
    """Write model and processor to *out_dir*, creating the directory if needed."""
    os.makedirs(out_dir, exist_ok=True)
    model.save_pretrained(out_dir)
    processor.save_pretrained(out_dir)


# ─── Step 1: Prepare Training Data ──────────────────────────────────────────

def _decode_poster_image(image_data):
    """Turn a dataset ``image`` field into a PIL image.

    Accepts raw bytes, a base64-encoded string, or an existing PIL image.
    Returns None for any other type.
    """
    if isinstance(image_data, Image.Image):
        return image_data
    if isinstance(image_data, bytes):
        return Image.open(io.BytesIO(image_data))
    if isinstance(image_data, str):
        return Image.open(io.BytesIO(base64.b64decode(image_data)))
    return None


def _center_crop_resize(img, size=IMAGE_SIZE):
    """Convert *img* to RGB, center-crop it to a square and resize to size x size."""
    img = img.convert("RGB")
    width, height = img.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    img = img.crop((left, top, left + side, top + side))
    return img.resize((size, size), Image.LANCZOS)


def _fill_synthetic_prompt(template):
    """Replace each ``{key}`` placeholder in *template* with a random fill value."""
    prompt = template
    for key, values in _SYNTHETIC_FILL_VALUES.items():
        placeholder = "{" + key + "}"
        if placeholder in prompt:
            prompt = prompt.replace(placeholder, random.choice(values))
    return prompt


def _download_poster_entries(max_poster_samples):
    """Stream PosterCraft/Poster100K, save usable images and return their entries.

    Best effort: a failure of the dataset stream is logged and whatever was
    collected so far is returned; a failure on a single sample skips only that
    sample.
    """
    from datasets import load_dataset

    entries = []
    try:
        ds = load_dataset("PosterCraft/Poster100K", split="train", streaming=True)
        count = 0
        for sample in tqdm(ds, desc="Downloading posters", total=max_poster_samples):
            if count >= max_poster_samples:
                break

            caption = sample.get("caption", "")
            if not caption or len(caption) < 20:
                continue
            image_data = sample.get("image")
            if image_data is None:
                continue

            # The file index only advances on success, so names stay contiguous.
            fname = f"poster_{count:06d}.jpg"
            fpath = os.path.join(IMAGE_DIR, fname)
            try:
                img = _decode_poster_image(image_data)
                if img is None:
                    continue
                _center_crop_resize(img).save(fpath, "JPEG", quality=95)
            except Exception as e:
                # Previously swallowed silently; surface it but keep going.
                logger.warning(f"Skipping poster sample ({fname}): {e}")
                continue

            # Truncate long captions
            if len(caption) > 500:
                caption = caption[:500] + "..."

            template = random.choice(_T2I_TEMPLATES)
            entries.append({
                "instruction": template.format(caption=caption),
                "output_image": fname,
            })
            count += 1
            if count % 500 == 0:
                logger.info(f" Downloaded {count}/{max_poster_samples} posters")
    except Exception as e:
        logger.error(f"Error loading PosterCraft: {e}")
    return entries


def prepare_training_data(max_poster_samples=8000, max_synthetic=2000):
    """Download PosterCraft images and create the JSONL training file.

    If a non-empty JSONL file already exists it is reused as-is. Otherwise up
    to *max_poster_samples* captioned posters are saved under ``IMAGE_DIR`` and
    up to *max_synthetic* templated prompts are added, each paired with a
    randomly chosen poster image as its target.

    Args:
        max_poster_samples: Maximum number of poster images to download.
        max_synthetic: Number of synthetic prompts to generate.

    Returns:
        Number of training samples available (0 if nothing could be prepared,
        in which case no JSONL file is written so a later run can retry).
    """
    logger.info("=" * 60)
    logger.info("Step 1: Preparing Training Data")
    logger.info("=" * 60)

    if os.path.exists(JSONL_PATH):
        with open(JSONL_PATH, encoding="utf-8") as f:
            count = sum(1 for _ in f)
        if count > 0:
            logger.info(f"Training data already exists: {count} samples")
            return count
        # An empty file is a leftover from a failed run, not a usable cache.
        logger.warning(f"Existing {JSONL_PATH} is empty; rebuilding training data")

    # 1a. Download PosterCraft/Poster100K images
    logger.info(f"Downloading PosterCraft/Poster100K (up to {max_poster_samples} samples)...")
    entries = _download_poster_entries(max_poster_samples)
    logger.info(f" Downloaded {len(entries)} poster images")

    # 1b. Generate synthetic thumbnail prompts
    logger.info(f"Generating {max_synthetic} synthetic thumbnail prompts...")
    # Synthetic prompts have no image of their own; they exist for prompt
    # diversity and reuse a poster as target. Snapshot the poster names first
    # so every poster is equally likely, instead of sampling from a list that
    # grows with each synthetic entry.
    poster_images = [entry["output_image"] for entry in entries]
    if poster_images:
        category_names = list(_SYNTHETIC_TEMPLATES.keys())
        for _ in range(max_synthetic):
            cat_name = random.choice(category_names)
            template = random.choice(_SYNTHETIC_TEMPLATES[cat_name])
            prompt = _fill_synthetic_prompt(template)
            entries.append({
                "instruction": f"Generate a YouTube thumbnail: {prompt}",
                "output_image": random.choice(poster_images),  # reuse poster as target
            })
    else:
        logger.warning("No poster images available; skipping synthetic prompts")

    if not entries:
        logger.error("No training samples could be prepared")
        return 0

    # Shuffle
    random.seed(SEED)
    random.shuffle(entries)

    # Write JSONL
    with open(JSONL_PATH, "w", encoding="utf-8") as f:
        for entry in entries:
            f.write(json.dumps(entry) + "\n")

    logger.info(f"Total training samples: {len(entries)}")
    logger.info(f"JSONL saved to: {JSONL_PATH}")
    return len(entries)


# ─── Step 2: Train Model ────────────────────────────────────────────────────

def _load_entries(jsonl_path):
    """Read one JSON object per non-blank line of *jsonl_path*."""
    entries = []
    with open(jsonl_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                entries.append(json.loads(line))
    return entries


def _sample_loss(model, processor, entry, img_path, device, dtype):
    """Return the teacher-forced cross-entropy over one sample's image tokens.

    The target image is encoded to codebook indices by the (frozen) generation
    vision model, and the language model is trained to predict each index from
    the prompt plus the preceding image tokens.
    """
    with Image.open(img_path) as raw:
        # Ensure 384x384
        target_img = raw.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS)

    # Convert to tensor [-1, 1]
    arr = np.array(target_img).astype(np.float32) / 255.0 * 2.0 - 1.0
    target_tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)

    # Encode target → VQ tokens; the third element of `info` is used as the
    # per-position codebook indices.
    with torch.no_grad():
        _quant, _emb_loss, info = model.gen_vision_model.encode(target_tensor)
        target_tokens = info[2].detach().reshape(1, -1)  # [1, IMAGE_TOKEN_NUM]

    # Build prompt
    instruction = entry["instruction"]
    if random.random() < CFG_MASK_PROB:
        instruction = ""  # CFG training: mask 10% of prompts
    prompt = _build_generation_prompt(processor, instruction)

    # Tokenize
    input_ids = torch.LongTensor(processor.tokenizer.encode(prompt)).unsqueeze(0).to(device)

    # Get embeddings
    text_embeds = model.language_model.get_input_embeddings()(input_ids)
    img_embeds = model.prepare_gen_img_embeds(target_tokens.reshape(-1))
    img_embeds = img_embeds.reshape(1, IMAGE_TOKEN_NUM, -1)

    # Teacher forcing: [text | img[:-1]] → predict img
    full_embeds = torch.cat([text_embeds, img_embeds[:, :-1, :]], dim=1)

    with torch.amp.autocast(device.type, dtype=dtype):
        # No KV cache is needed for a single full-sequence training pass.
        outputs = model.language_model.model(inputs_embeds=full_embeds, use_cache=False)
        hidden = outputs.last_hidden_state

        # The last text position predicts image token 0, so the image logits
        # start at index text_len - 1 and span IMAGE_TOKEN_NUM positions.
        text_len = text_embeds.shape[1]
        image_hidden = hidden[:, text_len - 1:, :]
        logits = model.gen_head(image_hidden)  # [1, IMAGE_TOKEN_NUM, VQ_CODEBOOK_SIZE]

    # Compute the loss in float32: bf16 log-softmax over the whole codebook
    # loses precision.
    return F.cross_entropy(
        logits.reshape(-1, VQ_CODEBOOK_SIZE).float(),
        target_tokens.reshape(-1),
    )


def _run_training_loop(model, processor, entries, optimizer, scheduler, num_steps, device, dtype):
    """Run all epochs with gradient accumulation, logging and checkpointing.

    Samples are processed one at a time; an optimizer step is taken after
    every ``BATCH_SIZE * GRADIENT_ACCUMULATION`` successfully back-propagated
    samples. Counting successes (rather than loop indices) keeps every step's
    gradient built from the same number of samples even when some samples are
    missing or fail. A partial window left over after the last epoch is
    dropped, matching the floor division used for ``num_steps``.
    """
    samples_per_step = BATCH_SIZE * GRADIENT_ACCUMULATION
    global_step = 0
    best_loss = float("inf")
    accumulated = 0      # samples back-propagated since the last optimizer step
    window_loss = 0.0    # sum of per-sample losses since the last log line
    window_samples = 0

    optimizer.zero_grad()
    for epoch in range(EPOCHS):
        random.shuffle(entries)
        epoch_loss = 0.0
        epoch_samples = 0

        for sample_idx, entry in enumerate(entries):
            # Load target image
            img_path = os.path.join(IMAGE_DIR, entry["output_image"])
            if not os.path.exists(img_path):
                continue

            try:
                loss = _sample_loss(model, processor, entry, img_path, device, dtype)
                (loss / samples_per_step).backward()
                loss_value = loss.item()
            except Exception as e:
                logger.warning(f"Error at batch {sample_idx}: {e}")
                # A failure may leave partially accumulated gradients behind;
                # discard the window and start a fresh, full one.
                optimizer.zero_grad()
                accumulated = 0
                continue

            window_loss += loss_value
            window_samples += 1
            epoch_loss += loss_value
            epoch_samples += 1
            accumulated += 1
            if accumulated < samples_per_step:
                continue

            # Optimizer step
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            accumulated = 0
            global_step += 1

            # Logging
            if global_step % LOG_EVERY == 0:
                # Average over the samples actually seen in the window, not
                # over the number of optimizer steps.
                avg_loss = window_loss / max(window_samples, 1)
                current_lr = scheduler.get_last_lr()[0]
                print(f"step={global_step}/{num_steps}, epoch={epoch+1}/{EPOCHS}, "
                      f"loss={avg_loss:.4f}, lr={current_lr:.2e}")
                trackio.log({
                    "train/loss": avg_loss,
                    "train/lr": current_lr,
                    "train/epoch": epoch + 1,
                    "train/step": global_step,
                })
                window_loss = 0.0
                window_samples = 0

            # Save checkpoint
            if global_step % SAVE_EVERY == 0:
                ckpt_dir = os.path.join(OUTPUT_DIR, f"checkpoint-{global_step}")
                _save_model(model, processor, ckpt_dir)
                logger.info(f"Checkpoint saved: {ckpt_dir}")

                # "Best" is judged on the running mean loss of the current epoch.
                epoch_avg = epoch_loss / max(epoch_samples, 1)
                if epoch_avg < best_loss:
                    best_loss = epoch_avg
                    _save_model(model, processor, os.path.join(OUTPUT_DIR, "best"))
                    logger.info(f"New best model! loss={best_loss:.4f}")

        avg_epoch_loss = epoch_loss / max(epoch_samples, 1)
        logger.info(f"Epoch {epoch+1}/{EPOCHS} complete. Avg loss: {avg_epoch_loss:.4f}")


def train_model():
    """Fine-tune Janus-Pro-7B for thumbnail generation.

    Loads the samples listed in ``JSONL_PATH``, freezes the vision encoders,
    trains the remaining parameters, saves the result under
    ``OUTPUT_DIR/final`` and attempts to push it to ``HUB_MODEL_ID``
    (a failed push is logged, not raised).

    Returns:
        Tuple ``(model, processor)`` of the fine-tuned model and its processor.

    Raises:
        RuntimeError: If the training file contains no samples.
    """
    logger.info("=" * 60)
    logger.info("Step 2: Training Janus-Pro-7B")
    logger.info("=" * 60)

    from transformers import AutoModelForCausalLM, get_cosine_schedule_with_warmup
    # NOTE(review): MultiModalityCausalLM is unused by name; presumably the
    # import is kept so the architecture is registered with transformers —
    # confirm before removing.
    from janus.models import MultiModalityCausalLM, VLChatProcessor

    # Load data first so an empty dataset fails before the model is loaded,
    # instead of "training" nothing and pushing the base weights to the Hub.
    entries = _load_entries(JSONL_PATH)
    if not entries:
        raise RuntimeError(f"No training samples found in {JSONL_PATH}")
    logger.info(f"Training on {len(entries)} samples")

    device = _select_device()
    dtype = torch.bfloat16

    # Initialize tracking
    trackio.init(project="thumbnail-vlm", name="janus-pro-thumbnail")
    try:
        # Load model
        logger.info(f"Loading {MODEL_PATH}...")
        processor = VLChatProcessor.from_pretrained(MODEL_PATH)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH, trust_remote_code=True, torch_dtype=dtype
        )
        model = model.to(device)

        # Freeze encoders (SigLIP + VQ) — only train LLM backbone + gen_head + aligners.
        # 'vision_model' also matches 'gen_vision_model' parameter names.
        for name, param in model.named_parameters():
            if 'vision_model' in name:
                param.requires_grad = False

        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        logger.info(f"Parameters: {total/1e6:.1f}M total, {trainable/1e6:.1f}M trainable")

        model.train()

        # Optimizer
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=LR,
            betas=(0.9, 0.95),
            weight_decay=0.0,
        )

        num_steps = (len(entries) * EPOCHS) // (BATCH_SIZE * GRADIENT_ACCUMULATION)
        warmup_steps = int(num_steps * 0.03)
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, num_steps)

        logger.info(f"Training config: epochs={EPOCHS}, batch={BATCH_SIZE}x{GRADIENT_ACCUMULATION}={BATCH_SIZE*GRADIENT_ACCUMULATION}")
        logger.info(f"Total steps: {num_steps}, warmup: {warmup_steps}")

        # Training loop
        _run_training_loop(model, processor, entries, optimizer, scheduler, num_steps, device, dtype)

        # Final save
        _save_model(model, processor, os.path.join(OUTPUT_DIR, "final"))

        # Push to Hub
        logger.info(f"Pushing model to {HUB_MODEL_ID}...")
        try:
            model.push_to_hub(HUB_MODEL_ID, token=os.environ.get("HF_TOKEN"))
            processor.push_to_hub(HUB_MODEL_ID, token=os.environ.get("HF_TOKEN"))
            logger.info(f"Model pushed to: https://huggingface.co/{HUB_MODEL_ID}")
        except Exception as e:
            logger.error(f"Failed to push to hub: {e}")
    finally:
        # Close the tracking run even when training aborts.
        trackio.finish()

    return model, processor


# ─── Step 3: Evaluate ────────────────────────────────────────────────────────

def _generate_image(model, processor, prompt, device, cfg_weight):
    """Sample one image for *prompt* with classifier-free guidance.

    Two sequences are decoded in lockstep: row 0 carries the real prompt, row 1
    has every prompt token except the first and last replaced by the pad id.
    Each image token is sampled from the guided logits and fed back to both rows.
    """
    prompt_text = _build_generation_prompt(processor, prompt)

    with torch.inference_mode():
        input_ids = torch.LongTensor(processor.tokenizer.encode(prompt_text))

        tokens = torch.zeros((2, len(input_ids)), dtype=torch.int).to(device)
        tokens[0, :] = input_ids  # conditional
        tokens[1, :] = input_ids  # unconditional
        tokens[1, 1:-1] = processor.pad_id

        inputs_embeds = model.language_model.get_input_embeddings()(tokens)
        generated = torch.zeros((1, IMAGE_TOKEN_NUM), dtype=torch.int).to(device)

        past_kv = None
        for t in range(IMAGE_TOKEN_NUM):
            outputs = model.language_model.model(
                inputs_embeds=inputs_embeds, use_cache=True, past_key_values=past_kv
            )
            past_kv = outputs.past_key_values
            logits = model.gen_head(outputs.last_hidden_state[:, -1, :])

            logit_cond = logits[0:1]
            logit_uncond = logits[1:2]
            logits_guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

            # Sample in float32 so the distribution is not built from bf16 values.
            probs = torch.softmax(logits_guided.float(), dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            generated[:, t] = next_tok.squeeze(-1)

            # After the prompt, only the newly sampled token is fed; the cache
            # carries the rest. Both rows receive the same image token.
            next_tok_exp = torch.cat([next_tok, next_tok], dim=0)
            img_emb = model.prepare_gen_img_embeds(next_tok_exp.squeeze(-1))
            inputs_embeds = img_emb.unsqueeze(1)

        # NOTE(review): shape is presumably [batch, code channels, grid_h, grid_w]
        # for a 16-pixel patch grid — confirm against the VQ decoder.
        dec = model.gen_vision_model.decode_code(
            generated, shape=[1, 8, IMAGE_SIZE // 16, IMAGE_SIZE // 16]
        )
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

    # Map decoder output from [-1, 1] to 8-bit pixel values.
    dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(dec[0])


def evaluate_model(model=None, processor=None, cfg_weight=5.0):
    """Generate sample thumbnails to verify the model works.

    Args:
        model: Model to evaluate. When None, the model and processor are loaded
            from ``OUTPUT_DIR/final`` if it exists, else from ``MODEL_PATH``.
        processor: Matching chat processor. Loaded the same way when None.
        cfg_weight: Classifier-free guidance weight applied to the logits.

    The generated images are written to ``OUTPUT_DIR/eval_samples``.
    """
    logger.info("=" * 60)
    logger.info("Step 3: Evaluating Model")
    logger.info("=" * 60)

    if model is None or processor is None:
        from transformers import AutoModelForCausalLM
        # NOTE(review): see train_model — MultiModalityCausalLM is presumably
        # imported for its registration side effect; confirm.
        from janus.models import MultiModalityCausalLM, VLChatProcessor

        model_path = os.path.join(OUTPUT_DIR, "final")
        if not os.path.exists(model_path):
            model_path = MODEL_PATH

        if model is None:
            # Loading a model always reloads its processor from the same path.
            processor = VLChatProcessor.from_pretrained(model_path)
            model = AutoModelForCausalLM.from_pretrained(
                model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
            ).to(_select_device())
        else:
            processor = VLChatProcessor.from_pretrained(model_path)

    model.eval()
    device = next(model.parameters()).device

    eval_dir = os.path.join(OUTPUT_DIR, "eval_samples")
    os.makedirs(eval_dir, exist_ok=True)

    test_prompts = [
        "A professional tech review YouTube thumbnail showing a sleek smartphone with dramatic lighting and bold text 'BEST PHONE 2025'",
        "Cooking video thumbnail with a delicious pasta dish, steam rising, warm kitchen lighting, text 'EASY 15-MIN RECIPE'",
        "Gaming thumbnail with an epic battle scene, neon glow effects, excited gamer face, text 'INSANE GAMEPLAY'",
        "Travel vlog thumbnail showing a stunning sunset over Santorini, Greece, golden hour, text 'DREAM VACATION'",
        "Fitness transformation thumbnail with before/after split, motivational energy, text '30 DAY CHALLENGE'",
    ]

    for i, prompt in enumerate(test_prompts):
        logger.info(f"Generating thumbnail {i+1}/{len(test_prompts)}: {prompt[:50]}...")
        img = _generate_image(model, processor, prompt, device, cfg_weight)
        save_path = os.path.join(eval_dir, f"thumbnail_{i}.png")
        img.save(save_path)
        logger.info(f" Saved: {save_path}")

    logger.info(f"Evaluation complete! Samples saved to {eval_dir}/")


# ─── Main ────────────────────────────────────────────────────────────────────

def main():
    """Run data preparation, training and evaluation in sequence."""
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Step 1: Prepare data
    num_samples = prepare_training_data(max_poster_samples=8000, max_synthetic=2000)
    if num_samples == 0:
        logger.error("No training data available; aborting pipeline")
        sys.exit(1)

    # Step 2: Train
    model, processor = train_model()

    # Step 3: Evaluate
    evaluate_model(model, processor)

    logger.info("=" * 60)
    logger.info("PIPELINE COMPLETE!")
    logger.info(f" Model: https://huggingface.co/{HUB_MODEL_ID}")
    logger.info(f" Eval samples: {OUTPUT_DIR}/eval_samples/")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()