"""
Data Preparation Script for VLM Thumbnail Generation Training

Converts PosterCraft/Poster100K and ShareGPT-4o-Image datasets into
OmniGen-compatible JSONL format for fine-tuning.

Output format (JSONL):
    Text-to-Image:
        {"instruction": "...", "output_image": "path.jpg"}
    Image+Text-to-Image:
        {"instruction": "... <|image_1|> ...", "input_images": ["path.jpg"], "output_image": "out.jpg"}

Image names written to the JSONL are bare file names; files that this script
actually saves are placed in ``IMAGE_DIR``.
"""

import os
import json
import io
import base64
import random
import hashlib
from pathlib import Path

from PIL import Image
from datasets import load_dataset
from tqdm import tqdm

OUTPUT_DIR = "/app/thumbnail_training_data"
IMAGE_DIR = os.path.join(OUTPUT_DIR, "images")
JSONL_PATH = os.path.join(OUTPUT_DIR, "train.jsonl")

# Seed used for every random draw made by main() (template choice and shuffle).
RANDOM_SEED = 42

os.makedirs(IMAGE_DIR, exist_ok=True)

# Thumbnail-specific prompt templates for T2I
T2I_TEMPLATES = [
    "Generate a professional thumbnail image: {caption}",
    "Create an eye-catching thumbnail with the following description: {caption}",
    "Design a visually compelling thumbnail: {caption}",
    "Generate a thumbnail image that captures attention: {caption}",
    "Create a high-quality thumbnail: {caption}",
]

# Image+Text-to-Image templates (for image editing/conditioning tasks)
I2I_TEMPLATES = [
    "Transform this image <|image_1|> into a professional thumbnail. {instruction}",
    "Based on this reference image <|image_1|>, create a thumbnail. {instruction}",
    "Use this image <|image_1|> as inspiration to generate a thumbnail. {instruction}",
    "Redesign this image <|image_1|> as an engaging thumbnail. {instruction}",
]


def save_image_from_bytes(image_bytes, filename, max_size=1024, quality=95):
    """Save an image as an RGB JPEG inside ``IMAGE_DIR``.

    Args:
        image_bytes: The image, given as encoded binary data (``bytes`` or
            ``bytearray``), a base64-encoded string, or a ``PIL.Image.Image``.
        filename: File name (not a full path) to write inside ``IMAGE_DIR``.
        max_size: Largest allowed width/height in pixels. Bigger images are
            downscaled with the aspect ratio preserved.
        quality: JPEG quality passed to Pillow.

    Returns:
        ``filename`` unchanged, so it can be stored directly in a JSONL entry.

    Raises:
        ValueError: If ``image_bytes`` is of an unsupported type.
        Any error raised by base64 decoding or Pillow while opening or
        writing the image is propagated to the caller.
    """
    filepath = os.path.join(IMAGE_DIR, filename)

    if isinstance(image_bytes, (bytes, bytearray)):
        img = Image.open(io.BytesIO(image_bytes))
    elif isinstance(image_bytes, str):
        # Strings are treated as base64-encoded image data.
        img = Image.open(io.BytesIO(base64.b64decode(image_bytes)))
    elif isinstance(image_bytes, Image.Image):
        img = image_bytes
    else:
        raise ValueError(f"Unknown image type: {type(image_bytes)}")

    # Downscale so the longest side is at most max_size, keeping aspect ratio.
    w, h = img.size
    longest_side = max(w, h)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # Clamp to 1 px: truncation can yield 0 for very elongated images,
        # which Pillow rejects.
        new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
        img = img.resize(new_size, Image.LANCZOS)

    # JPEG has no alpha/palette support, so normalize the mode first.
    img = img.convert("RGB")
    img.save(filepath, "JPEG", quality=quality)
    return filename


def process_poster100k(max_samples=10000):
    """Process PosterCraft/Poster100K → T2I thumbnail training data.

    Streams the dataset's ``train`` split, keeps samples that have an image
    and a caption of at least 20 characters, saves each image to
    ``IMAGE_DIR`` and wraps the caption (truncated to 500 characters) in a
    randomly chosen T2I template.

    Args:
        max_samples: Maximum number of entries to produce.

    Returns:
        A list of ``{"instruction": ..., "output_image": ...}`` dicts. Loading
        is best-effort: if the dataset raises part-way through, the error is
        printed and the entries collected so far are returned.
    """
    print("=" * 60)
    print("Processing PosterCraft/Poster100K...")
    print("=" * 60)

    entries = []
    try:
        ds = load_dataset("PosterCraft/Poster100K", split="train", streaming=True)
        count = 0
        for sample in tqdm(ds, desc="PosterCraft", total=max_samples):
            if count >= max_samples:
                break

            caption = sample.get("caption", "")
            if not caption or len(caption) < 20:
                continue

            image_data = sample.get("image")
            if image_data is None:
                continue

            # File names are numbered by accepted samples, so a skipped image
            # does not leave a gap in the sequence.
            fname = f"poster_{count:06d}.jpg"
            try:
                save_image_from_bytes(image_data, fname)
            except Exception as e:
                print(f" Skipping image {count}: {e}")
                continue

            # Create T2I entry
            template = random.choice(T2I_TEMPLATES)
            # Truncate very long captions
            if len(caption) > 500:
                caption = caption[:500] + "..."
            instruction = template.format(caption=caption)

            entries.append({
                "instruction": instruction,
                "output_image": fname,
            })
            count += 1

            if count % 1000 == 0:
                print(f" Processed {count}/{max_samples} PosterCraft samples")
    except Exception as e:
        # Deliberate best-effort: keep whatever was collected before the failure.
        print(f"Error loading PosterCraft: {e}")

    print(f" Total PosterCraft entries: {len(entries)}")
    return entries


def process_sharegpt_t2i(max_samples=5000):
    """Process ShareGPT-4o-Image text-to-image config.

    Wraps each non-empty ``input_prompt`` in a randomly chosen T2I template.

    NOTE: this function does not write any image files; ``output_image`` is
    only a generated file name that must be populated separately.

    Args:
        max_samples: Maximum number of entries to produce.

    Returns:
        A list of ``{"instruction": ..., "output_image": ...}`` dicts. On a
        loading error the message is printed and the partial list is returned.
    """
    print("=" * 60)
    print("Processing ShareGPT-4o-Image (text-to-image)...")
    print("=" * 60)

    entries = []
    try:
        ds = load_dataset(
            "FreedomIntelligence/ShareGPT-4o-Image",
            "1_text_to_image",
            split="train",
            streaming=True
        )
        count = 0
        for sample in tqdm(ds, desc="ShareGPT-T2I", total=max_samples):
            if count >= max_samples:
                break

            prompt = sample.get("input_prompt", "")
            if not prompt:
                continue

            # This dataset has image paths, not actual images in parquet
            # We store the prompt with a thumbnail-generation framing
            fname = f"sgpt_t2i_{count:06d}.jpg"
            template = random.choice(T2I_TEMPLATES)
            instruction = template.format(caption=prompt)

            entries.append({
                "instruction": instruction,
                "output_image": fname,
            })
            count += 1
    except Exception as e:
        # Deliberate best-effort: keep whatever was collected before the failure.
        print(f"Error loading ShareGPT T2I: {e}")

    print(f" Total ShareGPT T2I entries: {len(entries)}")
    return entries


def process_sharegpt_ti2i(max_samples=5000):
    """Process ShareGPT-4o-Image text+image-to-image config.

    Wraps each non-empty ``input_prompt`` in a randomly chosen I2I template.

    NOTE: this function does not write any image files; the ``input_images``
    and ``output_image`` values are only generated file names that must be
    populated separately.

    Args:
        max_samples: Maximum number of entries to produce.

    Returns:
        A list of ``{"instruction": ..., "input_images": [...],
        "output_image": ...}`` dicts. On a loading error the message is
        printed and the partial list is returned.
    """
    print("=" * 60)
    print("Processing ShareGPT-4o-Image (text+image-to-image)...")
    print("=" * 60)

    entries = []
    try:
        ds = load_dataset(
            "FreedomIntelligence/ShareGPT-4o-Image",
            "2_text_and_image_to_image",
            split="train",
            streaming=True
        )
        count = 0
        for sample in tqdm(ds, desc="ShareGPT-TI2I", total=max_samples):
            if count >= max_samples:
                break

            prompt = sample.get("input_prompt", "")
            if not prompt:
                continue

            input_fname = f"sgpt_ti2i_input_{count:06d}.jpg"
            output_fname = f"sgpt_ti2i_output_{count:06d}.jpg"

            template = random.choice(I2I_TEMPLATES)
            instruction = template.format(instruction=prompt)

            entries.append({
                "instruction": instruction,
                "input_images": [input_fname],
                "output_image": output_fname,
            })
            count += 1
    except Exception as e:
        # Deliberate best-effort: keep whatever was collected before the failure.
        print(f"Error loading ShareGPT TI2I: {e}")

    print(f" Total ShareGPT TI2I entries: {len(entries)}")
    return entries


def create_synthetic_thumbnail_prompts(n=2000):
    """Create synthetic thumbnail generation prompts for diverse training.

    Each prompt is built by picking a random category, a random template
    from it, and filling every ``{placeholder}`` present with a random value.

    NOTE: this function does not write any image files; ``output_image`` is
    only a generated file name (``synth_NNNNNN.jpg``) that must be populated
    separately.

    Args:
        n: Number of entries to generate.

    Returns:
        A list of ``n`` ``{"instruction": ..., "output_image": ...}`` dicts.
    """
    print("=" * 60)
    print(f"Generating {n} synthetic thumbnail prompts...")
    print("=" * 60)

    categories = [
        # YouTube-style thumbnails
        ("tech review", [
            "A sleek tech review thumbnail showing {product} with dramatic lighting, bold text overlay saying '{title}', modern gradient background",
            "Professional tech thumbnail: {product} product shot with comparison graphics, rating stars, and the text '{title}'",
        ]),
        ("cooking", [
            "Appetizing cooking thumbnail: close-up of {dish} with steam rising, warm golden lighting, text overlay '{title}' in bold font",
            "Food tutorial thumbnail: beautiful plated {dish}, overhead shot, rustic wooden background, text '{title}'",
        ]),
        ("gaming", [
            "Epic gaming thumbnail: dramatic scene from {game} with character in action pose, glowing effects, bold text '{title}'",
            "Gaming content thumbnail: split-screen reaction shot with {game} gameplay, neon accents, text '{title}'",
        ]),
        ("fitness", [
            "Fitness motivation thumbnail: athletic figure doing {exercise}, dynamic lighting, energetic colors, text '{title}'",
            "Health and fitness thumbnail: before/after transformation graphic, clean design, text '{title}'",
        ]),
        ("education", [
            "Educational content thumbnail: clean whiteboard-style graphic explaining {topic}, colorful diagrams, text '{title}'",
            "Learning video thumbnail: engaging infographic about {topic}, modern flat design, text '{title}'",
        ]),
        ("vlog", [
            "Travel vlog thumbnail: stunning panoramic view of {place}, warm color grading, bold title '{title}'",
            "Daily vlog thumbnail: candid lifestyle shot, bright and airy, playful text '{title}'",
        ]),
        ("music", [
            "Music video thumbnail: artistic portrait with {style} aesthetic, moody lighting, song title '{title}'",
            "Music content thumbnail: abstract sound wave visualization, vibrant colors, artist name and '{title}'",
        ]),
        ("business", [
            "Business advice thumbnail: professional portrait with speech bubble, clean corporate design, text '{title}'",
            "Entrepreneurship thumbnail: rising graph graphic, motivational pose, bold text '{title}'",
        ]),
    ]

    products = ["iPhone 16", "MacBook Pro", "PS5", "Nintendo Switch", "Tesla Model S", "AirPods Pro"]
    dishes = ["pasta carbonara", "sushi rolls", "chocolate cake", "grilled steak", "avocado toast"]
    games = ["Zelda", "Elden Ring", "GTA VI", "Minecraft", "Fortnite", "Call of Duty"]
    exercises = ["deadlifts", "yoga", "HIIT training", "pull-ups", "running"]
    topics = ["quantum physics", "machine learning", "history", "economics", "psychology"]
    places = ["Tokyo", "Paris", "Bali", "New York", "Iceland", "Santorini"]
    styles = ["synthwave", "lo-fi", "rock concert", "jazz club", "EDM festival"]
    titles = [
        "You Won't Believe This!", "GAME CHANGER", "The Ultimate Guide",
        "Top 10 Secrets", "I Tried This for 30 Days", "Watch Before You Buy",
        "Is It Worth It?", "My Honest Review", "This Changed Everything",
        "The Truth About...", "How I Made $10K", "Best of 2025"
    ]

    fill_map = {
        "product": products,
        "dish": dishes,
        "game": games,
        "exercise": exercises,
        "topic": topics,
        "place": places,
        "style": styles,
        "title": titles,
    }

    entries = []
    for i in range(n):
        _, templates = random.choice(categories)
        prompt = random.choice(templates)

        # Fill in placeholders. str.replace is used instead of str.format so
        # that templates may omit any of the keys.
        for key, values in fill_map.items():
            placeholder = "{" + key + "}"
            if placeholder in prompt:
                prompt = prompt.replace(placeholder, random.choice(values))

        entries.append({
            "instruction": f"Generate a professional YouTube thumbnail: {prompt}",
            "output_image": f"synth_{i:06d}.jpg",
        })

    print(f" Total synthetic entries: {len(entries)}")
    return entries


def _count_missing_images(entries):
    """Count entries referencing at least one image file absent from IMAGE_DIR."""
    missing = 0
    for entry in entries:
        names = [entry["output_image"], *entry.get("input_images", [])]
        if not all(os.path.isfile(os.path.join(IMAGE_DIR, name)) for name in names):
            missing += 1
    return missing


def main():
    """Build the training set and write it to ``JSONL_PATH``.

    Collects PosterCraft entries (images are saved to ``IMAGE_DIR``) and
    synthetic prompt entries, shuffles them, writes one JSON object per line
    and prints summary statistics, including how many entries reference image
    files that do not exist on disk.
    """
    print("=" * 60)
    print("VLM Thumbnail Training Data Preparation")
    print("=" * 60)

    # Seed before any random draw so that template and filler selection is
    # reproducible, not just the final shuffle.
    random.seed(RANDOM_SEED)

    all_entries = []

    # 1. PosterCraft/Poster100K (primary visual data)
    poster_entries = process_poster100k(max_samples=10000)
    all_entries.extend(poster_entries)

    # 2. Synthetic thumbnail prompts (domain-specific text)
    synthetic_entries = create_synthetic_thumbnail_prompts(n=3000)
    all_entries.extend(synthetic_entries)

    # Shuffle. Re-seed first so the permutation depends only on the seed and
    # the list length, not on how many random draws happened above.
    random.seed(RANDOM_SEED)
    random.shuffle(all_entries)

    # Write JSONL
    print(f"\nWriting {len(all_entries)} entries to {JSONL_PATH}")
    with open(JSONL_PATH, "w", encoding="utf-8") as f:
        for entry in all_entries:
            f.write(json.dumps(entry) + "\n")

    # Print statistics
    ti2i_count = sum(1 for e in all_entries if "input_images" in e)
    t2i_count = len(all_entries) - ti2i_count
    missing_count = _count_missing_images(all_entries)

    print("\nDataset Statistics:")
    print(f" Total samples: {len(all_entries)}")
    print(f" Text-to-Image: {t2i_count}")
    print(f" Text+Image-to-Image: {ti2i_count}")
    print(f" Images saved to: {IMAGE_DIR}")
    print(f" JSONL saved to: {JSONL_PATH}")
    if missing_count:
        # Entries are kept (not dropped) so the prompts remain available, but
        # the gap is reported instead of being left silent.
        print(f" WARNING: {missing_count} entries reference image files not found in {IMAGE_DIR}")

    # Show sample entries
    print("\nSample entries:")
    for i, entry in enumerate(all_entries[:3]):
        print(f" [{i}] {json.dumps(entry, indent=2)[:200]}...")


if __name__ == "__main__":
    main()