"""
Inference script for Janus-Pro Thumbnail Generation

Supports all 3 input modes:
1. Text only → Thumbnail
2. Image only → Thumbnail (auto-caption then generate)
3. Text + Image → Thumbnail

Usage:
    python inference_janus.py --mode text --prompt "Gaming video about Minecraft"
    python inference_janus.py --mode image --input_image photo.jpg
    python inference_janus.py --mode both --prompt "Make a cooking thumbnail" --input_image food.jpg
"""

import os
import sys
import argparse
from dataclasses import dataclass

import numpy as np
import torch
import PIL.Image
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor


# Geometry of the generated image. The number of image tokens sampled per
# picture is one per patch of the square grid handed to the decoder.
_IMG_SIZE = 384
_PATCH_SIZE = 16
_IMAGE_TOKEN_NUM = (_IMG_SIZE // _PATCH_SIZE) ** 2  # 24 * 24 = 576


@dataclass
class VLChatProcessorOutput:
    """Per-sample record handed to ``processor.batchify`` in text+image mode."""

    sft_format: str
    input_ids: torch.Tensor
    # ``None`` for the unconditional CFG row, which carries no input image.
    pixel_values: torch.Tensor
    # One entry per input image in the sample; empty when there is none.
    num_image_tokens: list

    def __len__(self):
        return len(self.input_ids)


def _open_rgb(path: str) -> PIL.Image.Image:
    """Read an image file as RGB and release the file handle before returning."""
    with PIL.Image.open(path) as img:
        # convert() returns a new in-memory image, so it stays usable after close.
        return img.convert("RGB")


def _check_sampling_args(temperature: float, parallel_size: int) -> None:
    """Reject sampling settings that would only fail later inside torch."""
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if parallel_size < 1:
        raise ValueError(f"parallel_size must be >= 1, got {parallel_size}")


def _sample_image_tokens(
    model,
    inputs_embeds,
    parallel_size: int,
    num_branches: int,
    guide,
    temperature: float,
    device: str,
):
    """Autoregressively sample ``_IMAGE_TOKEN_NUM`` image tokens per image.

    Args:
        model: Janus model exposing ``language_model.model``, ``gen_head`` and
            ``prepare_gen_img_embeds``.
        inputs_embeds: Prompt embeddings with ``parallel_size * num_branches``
            rows, the CFG branches of each image interleaved row by row.
        parallel_size: Number of images generated in parallel.
        num_branches: Number of CFG rows per image (2 or 3 in this file).
        guide: Callable mapping the raw logits of all rows to one guided
            logit row per image.
        temperature: Softmax temperature (must be > 0).
        device: Device on which the token buffer is allocated.

    Returns:
        Integer tensor of shape ``(parallel_size, _IMAGE_TOKEN_NUM)``.
    """
    generated_tokens = torch.zeros(
        (parallel_size, _IMAGE_TOKEN_NUM), dtype=torch.int
    ).to(device)
    past_key_values = None

    for step in range(_IMAGE_TOKEN_NUM):
        outputs = model.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=past_key_values,
        )
        past_key_values = outputs.past_key_values

        # Only the last position is needed to predict the next token.
        logits = model.gen_head(outputs.last_hidden_state[:, -1, :])
        probs = torch.softmax(guide(logits) / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, step] = next_token.squeeze(-1)

        # Every CFG branch of an image is fed the same sampled token, so
        # repeat each token once per branch, keeping rows interleaved.
        next_token_expanded = torch.cat(
            [next_token.unsqueeze(1)] * num_branches, dim=1
        ).view(-1)
        # With the KV cache active, only the new token's embedding is fed next.
        inputs_embeds = model.prepare_gen_img_embeds(next_token_expanded).unsqueeze(1)

    return generated_tokens


def _decode_vq_tokens(model, generated_tokens, parallel_size: int) -> list:
    """Decode sampled VQ tokens into a list of ``parallel_size`` PIL images."""
    grid = _IMG_SIZE // _PATCH_SIZE
    dec = model.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, grid, grid],
    )
    # Channels-first -> channels-last for PIL.
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    # Rescale from [-1, 1] to [0, 255], clipping anything out of range.
    dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
    return [PIL.Image.fromarray(dec[i]) for i in range(parallel_size)]


def load_model(model_path: str, device: str = "cuda"):
    """Load Janus model and processor.

    Args:
        model_path: Local directory or Hugging Face hub id.
        device: Device the model is moved to.

    Returns:
        ``(model, processor)``; the model is in bfloat16 and eval mode.
    """
    print(f"Loading model from {model_path}...")
    processor = VLChatProcessor.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    model = model.to(device).eval()
    return model, processor


def text_to_thumbnail(
    model: MultiModalityCausalLM,
    processor: VLChatProcessor,
    prompt: str,
    cfg_weight: float = 5.0,
    temperature: float = 1.0,
    parallel_size: int = 2,
    device: str = "cuda",
) -> list:
    """Generate thumbnail from text prompt.

    Args:
        model: Loaded Janus model.
        processor: Matching chat processor.
        prompt: Description of the thumbnail to generate.
        cfg_weight: Classifier-free guidance scale.
        temperature: Sampling temperature (must be > 0).
        parallel_size: Number of images to generate.
        device: Device on which generation runs.

    Returns:
        List of ``parallel_size`` PIL images.

    Raises:
        ValueError: If ``temperature <= 0`` or ``parallel_size < 1``.
    """
    _check_sampling_args(temperature, parallel_size)

    conversation = [
        {"role": "<|User|>", "content": f"Generate a professional YouTube thumbnail: {prompt}"},
        {"role": "<|Assistant|>", "content": ""},
    ]
    sft_format = processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=processor.sft_format,
        system_prompt="",
    )
    # The trailing image-start tag switches the model into image generation.
    prompt_text = sft_format + processor.image_start_tag

    def guide(logits):
        # Even rows are conditional, odd rows unconditional.
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        return logit_uncond + cfg_weight * (logit_cond - logit_uncond)

    with torch.inference_mode():
        input_ids = torch.LongTensor(processor.tokenizer.encode(prompt_text))

        # CFG: interleave conditional + unconditional rows. The unconditional
        # row keeps only the first and last token; the rest becomes padding.
        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
        tokens[:] = input_ids
        tokens[1::2, 1:-1] = processor.pad_id

        inputs_embeds = model.language_model.get_input_embeddings()(tokens)
        generated_tokens = _sample_image_tokens(
            model, inputs_embeds, parallel_size, 2, guide, temperature, device
        )

        # Decode VQ tokens → pixels
        return _decode_vq_tokens(model, generated_tokens, parallel_size)


def image_to_thumbnail(
    model: MultiModalityCausalLM,
    processor: VLChatProcessor,
    input_image_path: str,
    device: str = "cuda",
    cfg_weight: float = 5.0,
    temperature: float = 1.0,
    parallel_size: int = 2,
) -> list:
    """Generate thumbnail from input image.

    Step 1: Use model's understanding capability to caption the image
    Step 2: Use the caption to generate a thumbnail

    Args:
        model: Loaded Janus model.
        processor: Matching chat processor.
        input_image_path: Path of the image to caption.
        device: Device on which captioning and generation run.
        cfg_weight: Guidance scale forwarded to ``text_to_thumbnail``.
        temperature: Sampling temperature forwarded to ``text_to_thumbnail``.
        parallel_size: Number of images to generate.

    Returns:
        List of ``parallel_size`` PIL images.
    """
    # Step 1: Caption the image
    input_img = _open_rgb(input_image_path)

    conversation = [
        {
            "role": "<|User|>",
            # The image placeholder tag must appear in the user turn; the
            # processor is expected to expand it into the image-token span
            # (see the Janus multimodal-understanding example). Without it
            # the image passed below is not referenced by the prompt.
            "content": (
                f"{processor.image_tag}\n"
                "Describe this image in detail for creating an engaging YouTube thumbnail. "
                "Focus on the main subject, mood, colors, and any text that should appear."
            ),
        },
        {"role": "<|Assistant|>", "content": ""},
    ]

    prepare_inputs = processor(
        conversations=conversation,
        images=[input_img],
        force_batchify=True,
    ).to(device)

    with torch.inference_mode():
        inputs_embeds = model.prepare_inputs_embeds(**prepare_inputs)
        outputs = model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=processor.tokenizer.eos_token_id,
            bos_token_id=processor.tokenizer.bos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            max_new_tokens=256,
            do_sample=False,  # greedy decoding keeps the caption deterministic
        )

    caption = processor.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    print(f"Generated caption: {caption[:200]}...")

    # Step 2: Generate thumbnail from caption
    thumbnail_prompt = f"Create a professional YouTube thumbnail based on: {caption}"
    return text_to_thumbnail(
        model,
        processor,
        thumbnail_prompt,
        cfg_weight=cfg_weight,
        temperature=temperature,
        parallel_size=parallel_size,
        device=device,
    )


def text_and_image_to_thumbnail(
    model: MultiModalityCausalLM,
    processor: VLChatProcessor,
    prompt: str,
    input_image_path: str,
    cfg_weight: float = 5.0,
    cfg_weight2: float = 5.0,
    temperature: float = 1.0,
    parallel_size: int = 2,
    device: str = "cuda",
) -> list:
    """Generate thumbnail from both text and image input.

    Uses the T&I2I pathway from Janus-4o:
    - Input image → SigLIP embedding + VQ tokens
    - Text prompt → tokenized
    - CFG with 3-way: full conditional, partial, unconditional

    Args:
        model: Loaded Janus model.
        processor: Matching chat processor.
        prompt: Text instruction for the thumbnail.
        input_image_path: Path of the reference image.
        cfg_weight: Guidance scale between the blended conditional logits and
            the unconditional logits.
        cfg_weight2: Weight of the partial-conditional logits when blending
            them with the full-conditional logits.
        temperature: Sampling temperature (must be > 0).
        parallel_size: Number of images to generate.
        device: Device on which generation runs.

    Returns:
        List of ``parallel_size`` PIL images.

    Raises:
        ValueError: If ``temperature <= 0`` or ``parallel_size < 1``.
    """
    _check_sampling_args(temperature, parallel_size)
    input_img = _open_rgb(input_image_path)

    # Build the prompt with image tokens: one span of image placeholders
    # followed by one span of pad placeholders. The pad span is overwritten
    # with the image's VQ embeddings further down.
    input_img_tokens = (
        processor.image_start_tag
        + processor.image_tag * processor.num_image_tokens
        + processor.image_end_tag
        + processor.image_start_tag
        + processor.pad_tag * processor.num_image_tokens
        + processor.image_end_tag
    )
    conversation = [
        {
            "role": "<|User|>",
            "content": input_img_tokens + f"Generate a professional thumbnail: {prompt}",
        },
        {"role": "<|Assistant|>", "content": ""},
    ]
    sft_format = processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=processor.sft_format,
        system_prompt="",
    )
    # The trailing image-start tag opens the output image.
    sft_format = sft_format + processor.image_start_tag

    def guide(logits):
        # Rows cycle through: full conditional, partial, unconditional.
        logit_cond_full = logits[0::3, :]
        logit_cond_part = logits[1::3, :]
        logit_uncond = logits[2::3, :]
        logit_cond = (logit_cond_full + cfg_weight2 * logit_cond_part) / (1 + cfg_weight2)
        return logit_uncond + cfg_weight * (logit_cond - logit_uncond)

    with torch.inference_mode():
        # Preprocess the input image once; it feeds both the VQ encoder
        # (bfloat16) and the understanding encoder.
        pixel_values = processor.image_processor([input_img], return_tensors="pt")["pixel_values"]
        input_image_pixel_values = pixel_values.to(torch.bfloat16).to(device)
        encoder_pixel_values = pixel_values.to(device)

        # VQ encode input image
        _, _, info = model.gen_vision_model.encode(input_image_pixel_values)
        image_tokens_input = info[2].detach().reshape(1, -1)
        image_embeds_input = model.prepare_gen_img_embeds(image_tokens_input)

        input_ids = torch.LongTensor(processor.tokenizer.encode(sft_format))

        # 3-way CFG: full conditional, partial (no VQ tokens), unconditional.
        # Every third row keeps only its first and last token.
        tokens = torch.zeros((parallel_size * 3, len(input_ids)), dtype=torch.long)
        tokens[:] = input_ids
        tokens[2::3, 1:-1] = processor.pad_id

        pre_data = []
        for base in range(0, parallel_size * 3, 3):
            pre_data.append(VLChatProcessorOutput(
                sft_format=sft_format,
                pixel_values=encoder_pixel_values,
                input_ids=tokens[base],
                num_image_tokens=[processor.num_image_tokens],
            ))
            pre_data.append(VLChatProcessorOutput(
                sft_format=sft_format,
                pixel_values=encoder_pixel_values,
                input_ids=tokens[base + 1],
                num_image_tokens=[processor.num_image_tokens],
            ))
            # Unconditional row: no image at all.
            pre_data.append(VLChatProcessorOutput(
                sft_format=sft_format,
                pixel_values=None,
                input_ids=tokens[base + 2],
                num_image_tokens=[],
            ))

        prepare_inputs = processor.batchify(pre_data)
        inputs_embeds = model.prepare_inputs_embeds(
            input_ids=tokens.to(device),
            pixel_values=prepare_inputs['pixel_values'].to(torch.bfloat16).to(device),
            images_emb_mask=prepare_inputs['images_emb_mask'].to(device),
            images_seq_mask=prepare_inputs['images_seq_mask'].to(device),
        )

        # Insert VQ embeddings at correct positions. Each conditional row
        # contains two image-end tags and the unconditional row none, so
        # there are four hits per image; the first of every four is the first
        # image-end tag of the full-conditional row.
        image_gen_indices = (tokens == processor.image_end_id).nonzero()
        span = image_embeds_input.shape[1]
        for ii, ind in enumerate(image_gen_indices):
            if ii % 4 == 0:
                # Skip the image-end and the following image-start tag to
                # land on the first pad placeholder.
                offset = ind[1] + 2
                # There is a single input image, hence index 0.
                inputs_embeds[ind[0], offset:offset + span, :] = image_embeds_input[0]

        generated_tokens = _sample_image_tokens(
            model, inputs_embeds, parallel_size, 3, guide, temperature, device
        )

        # Decode
        return _decode_vq_tokens(model, generated_tokens, parallel_size)


def main():
    """Parse CLI arguments, generate thumbnails in the chosen mode, save PNGs."""
    parser = argparse.ArgumentParser(description="Thumbnail Generation Inference")
    parser.add_argument("--model_path", type=str, default="FreedomIntelligence/Janus-4o-7B",
                        help="Model path (local or HF hub)")
    parser.add_argument("--mode", type=str, choices=["text", "image", "both"], required=True)
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--input_image", type=str, default="")
    parser.add_argument("--output_dir", type=str, default="./generated_thumbnails")
    parser.add_argument("--num_images", type=int, default=2)
    parser.add_argument("--cfg_weight", type=float, default=5.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    args = parser.parse_args()

    # Validate before the (slow) model load. parser.error is used instead of
    # assert because asserts are stripped when Python runs with -O.
    if args.mode == "text" and not args.prompt:
        parser.error("Text mode requires --prompt")
    if args.mode == "image" and not args.input_image:
        parser.error("Image mode requires --input_image")
    if args.mode == "both" and not (args.prompt and args.input_image):
        parser.error("Both mode requires --prompt and --input_image")
    if args.mode in ("image", "both") and not os.path.isfile(args.input_image):
        parser.error(f"Input image not found: {args.input_image}")
    if args.num_images < 1:
        parser.error("--num_images must be >= 1")
    if args.temperature <= 0:
        parser.error("--temperature must be > 0")

    os.makedirs(args.output_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = load_model(args.model_path, device)

    sampling = {
        "cfg_weight": args.cfg_weight,
        "temperature": args.temperature,
        "parallel_size": args.num_images,
        "device": device,
    }

    if args.mode == "text":
        print(f"Generating thumbnail from text: '{args.prompt}'")
        images = text_to_thumbnail(model, processor, args.prompt, **sampling)
    elif args.mode == "image":
        print(f"Generating thumbnail from image: {args.input_image}")
        images = image_to_thumbnail(model, processor, args.input_image, **sampling)
    else:  # "both" — argparse `choices` rules out anything else
        print(f"Generating thumbnail from text+image: '{args.prompt}' + {args.input_image}")
        images = text_and_image_to_thumbnail(
            model, processor, args.prompt, args.input_image, **sampling
        )

    # Save results
    for i, img in enumerate(images):
        path = os.path.join(args.output_dir, f"thumbnail_{args.mode}_{i}.png")
        img.save(path)
        print(f"Saved: {path}")

    print(f"\nGenerated {len(images)} thumbnails in {args.output_dir}/")


if __name__ == "__main__":
    main()