"""
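Data Preparation Script for VLM Thumbnail Generation Training
Converts PosterCraft/Poster100K images and synthetic thumbnail prompts into
OmniGen-compatible JSONL format for fine-tuning. Converters for the
ShareGPT-4o-Image dataset are also defined but are not called by main().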
Output format (JSONL, one object per line; image paths are file names
relative to the images/ directory):
Text-to-Image: {"instruction": "...", "output_image": "path.jpg"}
Image+Text-to-Image: {"instruction": "... <img><|image_1|></img> ...", "input_images": ["path.jpg"], "output_image": "out.jpg"}
"""
import os
import json
import io
import base64
import random
import hashlib
from pathlib import Path
from PIL import Image
from datasets import load_dataset
from tqdm import tqdm
# Every generated artifact lives under a single root: the resized JPEGs go in
# images/ and the training manifest sits next to that folder.
OUTPUT_DIR = "/app/thumbnail_training_data"
IMAGE_DIR, JSONL_PATH = (
    os.path.join(OUTPUT_DIR, leaf) for leaf in ("images", "train.jsonl")
)
# NOTE: executed at import time, so simply importing this module creates the
# directory that save_image_from_bytes writes into.
os.makedirs(IMAGE_DIR, exist_ok=True)
# Thumbnail-specific prompt templates for text-to-image (T2I) samples.
# Each one is filled with str.format(caption=...).
T2I_TEMPLATES = [
    "Generate a professional thumbnail image: {caption}",
    "Create an eye-catching thumbnail with the following description: {caption}",
    "Design a visually compelling thumbnail: {caption}",
    "Generate a thumbnail image that captures attention: {caption}",
    "Create a high-quality thumbnail: {caption}",
]
# Image+Text-to-Image templates (for image editing/conditioning tasks).
# Each one is filled with str.format(instruction=...). "<img><|image_1|></img>"
# refers to the first path in the sample's "input_images" list (OmniGen's
# <img><|image_N|></img> placeholder convention). Each template must stay a
# single-line literal: a raw line break inside a "..." string is a SyntaxError.
I2I_TEMPLATES = [
    "Transform this image <img><|image_1|></img> into a professional thumbnail. {instruction}",
    "Based on this reference image <img><|image_1|></img>, create a thumbnail. {instruction}",
    "Use this image <img><|image_1|></img> as inspiration to generate a thumbnail. {instruction}",
    "Redesign this image <img><|image_1|></img> as an engaging thumbnail. {instruction}",
]
def save_image_from_bytes(image_bytes, filename, image_dir=None, max_size=1024):
    """Decode an image, downscale it if needed, and save it as an RGB JPEG.

    Args:
        image_bytes: The image to save (name kept for backward compatibility).
            Accepted forms: raw encoded ``bytes``/``bytearray``, a
            base64-encoded ``str``, a ``PIL.Image.Image``, or a dict with a
            non-None ``"bytes"`` key (presumably the undecoded form of a
            Hugging Face ``datasets`` Image feature -- TODO confirm).
        filename: File name (not a path) to write inside ``image_dir``.
        image_dir: Destination directory; defaults to the module-level
            ``IMAGE_DIR``.
        max_size: Longest allowed side in pixels. Larger images are scaled
            down preserving aspect ratio; smaller ones are never upscaled.

    Returns:
        ``filename`` unchanged, so callers can store it in a JSONL entry.

    Raises:
        ValueError: If ``image_bytes`` is none of the accepted forms.
        Decoding errors from base64 / PIL propagate unchanged.
    """
    if image_dir is None:
        image_dir = IMAGE_DIR
    filepath = os.path.join(image_dir, filename)

    # Unwrap a {"bytes": ..., ...} payload into its raw encoded bytes.
    if isinstance(image_bytes, dict) and image_bytes.get("bytes") is not None:
        image_bytes = image_bytes["bytes"]

    if isinstance(image_bytes, (bytes, bytearray)):
        img = Image.open(io.BytesIO(image_bytes))
    elif isinstance(image_bytes, str):
        # base64 encoded
        img = Image.open(io.BytesIO(base64.b64decode(image_bytes)))
    elif isinstance(image_bytes, Image.Image):
        img = image_bytes
    else:
        raise ValueError(f"Unknown image type: {type(image_bytes)}")

    # Convert before resizing: Pillow always resamples "P" and "1" images with
    # NEAREST regardless of the filter requested, so converting first is what
    # actually gets LANCZOS quality for palette images.
    img = img.convert("RGB")

    w, h = img.size
    if max(w, h) > max_size:
        ratio = max_size / max(w, h)
        # Clamp to 1px: very thin images would otherwise truncate the short
        # side to 0, which Image.resize rejects.
        new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
        img = img.resize(new_size, Image.LANCZOS)

    img.save(filepath, "JPEG", quality=95)
    return filename
def process_poster100k(max_samples=10000, min_caption_len=20, max_caption_len=500):
    """Process PosterCraft/Poster100K into T2I thumbnail training entries.

    Streams the ``train`` split and skips rows where any of these hold:
    the caption is empty, not a string, or shorter than ``min_caption_len``
    characters; or the image is missing. Each remaining image is saved with
    ``save_image_from_bytes`` as ``poster_NNNNNN.jpg``, and its caption is
    wrapped in a randomly chosen T2I template.

    Args:
        max_samples: Maximum number of entries (i.e. saved images) to produce.
        min_caption_len: Captions shorter than this are skipped.
        max_caption_len: Longer captions are cut to this many characters and
            suffixed with "...".

    Returns:
        List of ``{"instruction": str, "output_image": str}`` dicts.
        Errors while loading or streaming the dataset are printed, and the
        entries gathered so far are returned (best effort). A failure to save
        one image only skips that row.
    """
    print("=" * 60)
    print("Processing PosterCraft/Poster100K...")
    print("=" * 60)
    entries = []
    try:
        ds = load_dataset("PosterCraft/Poster100K", split="train", streaming=True)
        # Advance the bar once per accepted entry. max_samples caps entries,
        # not scanned rows, so ticking per raw row (skips included) overshot
        # the bar's total.
        with tqdm(total=max_samples, desc="PosterCraft") as pbar:
            for sample in ds:
                count = len(entries)
                if count >= max_samples:
                    break
                caption = sample.get("caption", "")
                if (
                    not caption
                    or not isinstance(caption, str)
                    or len(caption) < min_caption_len
                ):
                    continue
                image_data = sample.get("image")
                if image_data is None:
                    continue
                # count only advances after a successful save, so a failed
                # row's file name is reused by the next accepted row.
                fname = f"poster_{count:06d}.jpg"
                try:
                    save_image_from_bytes(image_data, fname)
                except Exception as e:
                    # tqdm.write prints above the bar instead of garbling it.
                    tqdm.write(f" Skipping image {count}: {e}")
                    continue
                template = random.choice(T2I_TEMPLATES)
                # Truncate very long captions
                if len(caption) > max_caption_len:
                    caption = caption[:max_caption_len] + "..."
                entries.append({
                    "instruction": template.format(caption=caption),
                    "output_image": fname,
                })
                pbar.update(1)
                if len(entries) % 1000 == 0:
                    tqdm.write(f" Processed {len(entries)}/{max_samples} PosterCraft samples")
    except Exception as e:
        print(f"Error loading PosterCraft: {e}")
    print(f" Total PosterCraft entries: {len(entries)}")
    return entries
def process_sharegpt_t2i(max_samples=5000):
    """Build text-to-image entries from ShareGPT-4o-Image ("1_text_to_image").

    Only the ``input_prompt`` text is used. Each non-empty prompt is wrapped
    in a random T2I template and paired with the file name
    ``sgpt_t2i_NNNNNN.jpg``.

    NOTE(review): nothing here writes that image file, so the returned
    entries reference images that do not exist unless produced elsewhere.
    main() as written does not call this function.

    Args:
        max_samples: Maximum number of entries to return.

    Returns:
        List of ``{"instruction", "output_image"}`` dicts. On a load or
        streaming error the message is printed and the partial list returned.
    """
    banner = "=" * 60
    for line in (banner, "Processing ShareGPT-4o-Image (text-to-image)...", banner):
        print(line)
    entries = []
    try:
        stream = load_dataset(
            "FreedomIntelligence/ShareGPT-4o-Image",
            "1_text_to_image",
            split="train",
            streaming=True,
        )
        for row in tqdm(stream, desc="ShareGPT-T2I", total=max_samples):
            idx = len(entries)
            if idx >= max_samples:
                break
            prompt = row.get("input_prompt", "")
            if not prompt:
                continue
            # Rows are assumed to reference images by path rather than embed
            # pixels, so only the prompt text is reused, with thumbnail framing.
            entries.append({
                "instruction": random.choice(T2I_TEMPLATES).format(caption=prompt),
                "output_image": f"sgpt_t2i_{idx:06d}.jpg",
            })
    except Exception as e:
        print(f"Error loading ShareGPT T2I: {e}")
    print(f" Total ShareGPT T2I entries: {len(entries)}")
    return entries
def process_sharegpt_ti2i(max_samples=5000):
    """Build text+image-to-image entries from ShareGPT-4o-Image.

    Reads the "2_text_and_image_to_image" config. Each non-empty
    ``input_prompt`` is wrapped in a random I2I template and paired with one
    input file name (``sgpt_ti2i_input_NNNNNN.jpg``) and one output file name
    (``sgpt_ti2i_output_NNNNNN.jpg``).

    NOTE(review): neither image file is written here, so the entries
    reference images that do not exist unless produced elsewhere. main() as
    written does not call this function.

    Args:
        max_samples: Maximum number of entries to return.

    Returns:
        List of ``{"instruction", "input_images", "output_image"}`` dicts.
        On a load or streaming error the message is printed and the partial
        list returned.
    """
    banner = "=" * 60
    for line in (banner, "Processing ShareGPT-4o-Image (text+image-to-image)...", banner):
        print(line)
    entries = []
    try:
        stream = load_dataset(
            "FreedomIntelligence/ShareGPT-4o-Image",
            "2_text_and_image_to_image",
            split="train",
            streaming=True,
        )
        for row in tqdm(stream, desc="ShareGPT-TI2I", total=max_samples):
            idx = len(entries)
            if idx >= max_samples:
                break
            prompt = row.get("input_prompt", "")
            if not prompt:
                continue
            chosen = random.choice(I2I_TEMPLATES)
            entries.append({
                "instruction": chosen.format(instruction=prompt),
                "input_images": [f"sgpt_ti2i_input_{idx:06d}.jpg"],
                "output_image": f"sgpt_ti2i_output_{idx:06d}.jpg",
            })
    except Exception as e:
        print(f"Error loading ShareGPT TI2I: {e}")
    print(f" Total ShareGPT TI2I entries: {len(entries)}")
    return entries
def create_synthetic_thumbnail_prompts(n=2000, seed=None):
    """Create synthetic YouTube-style thumbnail prompts (text only).

    Each entry picks a random category and one of its templates. Every
    ``{placeholder}`` in the template is then filled with a random value
    from the matching pool; repeated placeholders get the same value.

    NOTE: no image is generated here. ``output_image`` is only the file name
    (``synth_NNNNNN.jpg``) the entry expects.

    Args:
        n: Number of entries to create (``n <= 0`` yields an empty list).
        seed: Optional seed. When given, sampling uses a private
            ``random.Random(seed)``: output is reproducible and the global
            ``random`` state is untouched. When ``None`` (the default), the
            module-level ``random`` functions are used, as before.

    Returns:
        List of ``{"instruction": str, "output_image": str}`` dicts.
    """
    print("=" * 60)
    print(f"Generating {n} synthetic thumbnail prompts...")
    print("=" * 60)
    # The random module and a Random instance expose the same .choice API.
    rng = random if seed is None else random.Random(seed)
    categories = [
        # YouTube-style thumbnails
        ("tech review", [
            "A sleek tech review thumbnail showing {product} with dramatic lighting, bold text overlay saying '{title}', modern gradient background",
            "Professional tech thumbnail: {product} product shot with comparison graphics, rating stars, and the text '{title}'",
        ]),
        ("cooking", [
            "Appetizing cooking thumbnail: close-up of {dish} with steam rising, warm golden lighting, text overlay '{title}' in bold font",
            "Food tutorial thumbnail: beautiful plated {dish}, overhead shot, rustic wooden background, text '{title}'",
        ]),
        ("gaming", [
            "Epic gaming thumbnail: dramatic scene from {game} with character in action pose, glowing effects, bold text '{title}'",
            "Gaming content thumbnail: split-screen reaction shot with {game} gameplay, neon accents, text '{title}'",
        ]),
        ("fitness", [
            "Fitness motivation thumbnail: athletic figure doing {exercise}, dynamic lighting, energetic colors, text '{title}'",
            "Health and fitness thumbnail: before/after transformation graphic, clean design, text '{title}'",
        ]),
        ("education", [
            "Educational content thumbnail: clean whiteboard-style graphic explaining {topic}, colorful diagrams, text '{title}'",
            "Learning video thumbnail: engaging infographic about {topic}, modern flat design, text '{title}'",
        ]),
        ("vlog", [
            "Travel vlog thumbnail: stunning panoramic view of {place}, warm color grading, bold title '{title}'",
            "Daily vlog thumbnail: candid lifestyle shot, bright and airy, playful text '{title}'",
        ]),
        ("music", [
            "Music video thumbnail: artistic portrait with {style} aesthetic, moody lighting, song title '{title}'",
            "Music content thumbnail: abstract sound wave visualization, vibrant colors, artist name and '{title}'",
        ]),
        ("business", [
            "Business advice thumbnail: professional portrait with speech bubble, clean corporate design, text '{title}'",
            "Entrepreneurship thumbnail: rising graph graphic, motivational pose, bold text '{title}'",
        ]),
    ]
    products = ["iPhone 16", "MacBook Pro", "PS5", "Nintendo Switch", "Tesla Model S", "AirPods Pro"]
    dishes = ["pasta carbonara", "sushi rolls", "chocolate cake", "grilled steak", "avocado toast"]
    games = ["Zelda", "Elden Ring", "GTA VI", "Minecraft", "Fortnite", "Call of Duty"]
    exercises = ["deadlifts", "yoga", "HIIT training", "pull-ups", "running"]
    topics = ["quantum physics", "machine learning", "history", "economics", "psychology"]
    places = ["Tokyo", "Paris", "Bali", "New York", "Iceland", "Santorini"]
    styles = ["synthwave", "lo-fi", "rock concert", "jazz club", "EDM festival"]
    titles = [
        "You Won't Believe This!", "GAME CHANGER", "The Ultimate Guide",
        "Top 10 Secrets", "I Tried This for 30 Days", "Watch Before You Buy",
        "Is It Worth It?", "My Honest Review", "This Changed Everything",
        "The Truth About...", "How I Made $10K", "Best of 2025",
    ]
    fill_map = {
        "product": products, "dish": dishes, "game": games,
        "exercise": exercises, "topic": topics, "place": places,
        "style": styles, "title": titles,
    }
    entries = []
    for i in range(n):
        _, templates = rng.choice(categories)
        prompt = rng.choice(templates)
        # Fill only the placeholders this template uses, in fill_map order,
        # so the sequence of random draws matches the original implementation.
        for key, values in fill_map.items():
            placeholder = "{" + key + "}"
            if placeholder in prompt:
                prompt = prompt.replace(placeholder, rng.choice(values))
        entries.append({
            "instruction": f"Generate a professional YouTube thumbnail: {prompt}",
            "output_image": f"synth_{i:06d}.jpg",
        })
    print(f" Total synthetic entries: {len(entries)}")
    return entries
def main(seed=42):
    """Build the thumbnail fine-tuning set and write it to ``JSONL_PATH``.

    Steps:
        1. Collect PosterCraft text-to-image entries (images saved under
           ``IMAGE_DIR``) and synthetic text-only prompts.
        2. Shuffle them and write one JSON object per line.
        3. Print summary statistics, including a warning for entries whose
           ``output_image`` file is absent from ``IMAGE_DIR``, plus a few
           sample entries.

    Args:
        seed: Seed for the module-level ``random`` generator. It is applied
            before any sampling, so template choices, synthetic prompts and
            the final shuffle are all reproducible.
    """
    # Seed before any sampling. The old code seeded only right before the
    # shuffle, which left every random.choice made while collecting entries
    # unseeded, so reruns produced different data.
    random.seed(seed)

    print("=" * 60)
    print("VLM Thumbnail Training Data Preparation")
    print("=" * 60)

    all_entries = []
    # 1. PosterCraft/Poster100K (primary visual data)
    all_entries.extend(process_poster100k(max_samples=10000))
    # 2. Synthetic thumbnail prompts (domain-specific text)
    all_entries.extend(create_synthetic_thumbnail_prompts(n=3000))

    random.shuffle(all_entries)

    # Write JSONL: one compact JSON object per line.
    print(f"\nWriting {len(all_entries)} entries to {JSONL_PATH}")
    with open(JSONL_PATH, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(entry) + "\n" for entry in all_entries)

    ti2i_count = sum(1 for e in all_entries if "input_images" in e)
    t2i_count = len(all_entries) - ti2i_count
    # Text-only entries (e.g. the synthetic prompts) name an output_image that
    # this script never writes. Count them so the dangling references are
    # noticed now rather than when a trainer tries to open those files.
    missing_images = sum(
        1
        for e in all_entries
        if not os.path.isfile(os.path.join(IMAGE_DIR, e["output_image"]))
    )

    print("\nDataset Statistics:")
    print(f" Total samples: {len(all_entries)}")
    print(f" Text-to-Image: {t2i_count}")
    print(f" Text+Image-to-Image: {ti2i_count}")
    print(f" Images saved to: {IMAGE_DIR}")
    print(f" JSONL saved to: {JSONL_PATH}")
    if missing_images:
        print(
            f" WARNING: {missing_images} entries reference an output_image "
            f"not found in {IMAGE_DIR}"
        )

    # Show sample entries
    print("\nSample entries:")
    for i, entry in enumerate(all_entries[:3]):
        print(f" [{i}] {json.dumps(entry, indent=2)[:200]}...")


if __name__ == "__main__":
    main()