"""Gradio demo for DiffusionPen-style Hindi handwriting synthesis.

Loads a VAE, a style encoder, a CANINE tokenizer/text encoder, a custom UNet
and a DDPM scheduler at import time, then serves a single ``predict``
function through a Gradio interface.
"""

import os
from typing import Optional

import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL, DDPMScheduler
from PIL import Image
from torchvision import transforms
from transformers import CanineModel, CanineTokenizer

# Import your custom architectures
from feature_extractor import Mixed_Encoder
from unet import UNetModel

# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# CRITICAL: Fill this list in the EXACT order of your training folders!
HINDI_VOCAB = ["क", "ख", "ग", "घ", "ङ", "च", "छ", "ज", "झ", "ञ"]  # ... add all others

# Fixed token length fed to the model (padding AND truncation target).
MAX_TEXT_LENGTH = 128
# Number of denoising steps; kept low for demo responsiveness.
NUM_INFERENCE_STEPS = 50
# Latent tensor shape (batch, channels, height, width) for a 64x256 image.
LATENT_SHAPE = (1, 4, 8, 32)
# Width of the all-zero style vector used when no reference image is given.
STYLE_FEATURE_DIM = 1280
# Latents are divided by this factor before being decoded by the VAE.
VAE_SCALING_FACTOR = 0.18215

# ==========================================
# 2. MODEL LOADING (Inference Optimized)
# ==========================================
print(f"🚀 Booting DiffusionPen on {DEVICE}...")

# Load VAE (Directly via app.py as requested)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)

# Load Style Encoder (Mixed_Encoder)
# SECURITY NOTE: torch.load unpickles the checkpoint, so only load weight
# files you trust. If these files are plain state dicts, consider passing
# weights_only=True.
style_encoder = Mixed_Encoder(model_name='mobilenetv2_100', num_classes=300).to(DEVICE)
style_encoder.load_state_dict(
    torch.load("weights/mixed_hindi_mobilenetv2_100.pth", map_location=DEVICE)
)
style_encoder.eval()

# Load Text Encoder (Canine)
tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
text_encoder = CanineModel.from_pretrained("google/canine-c").to(DEVICE)
text_encoder.eval()

# Load UNet (Custom)
# These parameters must match your training config
unet = UNetModel(
    image_size=(64, 256),
    in_channels=4,
    model_channels=320,
    out_channels=4,
    num_res_blocks=2,
    attention_resolutions=[4, 2, 1],
    channel_mult=[1, 2, 4, 4],
    context_dim=768,
).to(DEVICE)
unet.load_state_dict(torch.load("weights/ema_ckpt.pt", map_location=DEVICE))
unet.eval()

scheduler = DDPMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)

# ==========================================
# 3. PREPROCESSING UTILS
# ==========================================
style_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(hindi_text: str, style_image: Optional[Image.Image]) -> Image.Image:
    """Generate a handwriting image for ``hindi_text``.

    Args:
        hindi_text: Text to render. Must contain at least one
            non-whitespace character.
        style_image: Optional PIL image used as the handwriting style
            reference. When ``None``, an all-zero style vector is used.

    Returns:
        The decoded image as a ``PIL.Image.Image`` built from a uint8 array.

    Raises:
        gr.Error: If ``hindi_text`` is empty, whitespace-only or ``None``.
    """
    # Fail fast with a user-visible message instead of running the whole
    # diffusion loop on padding (or crashing in the tokenizer on None).
    if not hindi_text or not hindi_text.strip():
        raise gr.Error("Please enter some Hindi text to generate.")

    with torch.no_grad():
        # A. Process Style
        if style_image is not None:
            # Normalize() above is configured for exactly 3 channels, so
            # force RGB in case the image arrives as RGBA / grayscale.
            rgb_image = style_image.convert("RGB")
            style_t = style_transform(rgb_image).unsqueeze(0).to(DEVICE)
            _, style_features = style_encoder(style_t)
        else:
            style_features = torch.zeros((1, STYLE_FEATURE_DIM), device=DEVICE)

        # B. Process Text
        # truncation=True makes max_length a hard cap; without it, longer
        # inputs would produce sequences longer than MAX_TEXT_LENGTH.
        inputs = tokenizer(
            hindi_text,
            padding="max_length",
            truncation=True,
            max_length=MAX_TEXT_LENGTH,
            return_tensors="pt",
        ).to(DEVICE)
        # NOTE(review): the tokenizer output (not text_encoder's hidden
        # states) is passed as `context`; presumably UNetModel encodes it
        # internally -- confirm against the UNet implementation.

        # C. Diffusion Loop (Simplified DDPM)
        latents = torch.randn(LATENT_SHAPE, device=DEVICE)
        scheduler.set_timesteps(NUM_INFERENCE_STEPS)

        for timestep in scheduler.timesteps:
            # Predict noise
            noise_pred = unet(
                latents,
                timestep.unsqueeze(0).to(DEVICE),
                context=inputs,
                style_extractor=style_features,
            )
            # Step scheduler
            latents = scheduler.step(noise_pred, timestep, latents).prev_sample

        # D. Decode with VAE
        latents = 1 / VAE_SCALING_FACTOR * latents
        image = vae.decode(latents).sample

        # Map from [-1, 1] to [0, 1], then to an HWC uint8 array.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).astype(np.uint8)

        return Image.fromarray(image)


# ==========================================
# 4. GRADIO INTERFACE (Resume Ready)
# ==========================================
description = """
### 🖋️ DiffusionPen: Hindi Handwriting Synthesis
**Developed by Kishan Madlani | NIT Surat**

This model uses a Latent Diffusion architecture to generate Hindi text in specific handwriting styles.
It was trained on a custom dataset of 300+ writers using Triplet Loss and Cross-Attention.
"""

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Input Hindi Text", placeholder="नमस्ते..."),
        gr.Image(label="Style Reference Image", type="pil"),
    ],
    outputs=gr.Image(label="Generated Handwriting"),
    title="DiffusionPen - Hindi Style Transfer",
    description=description,
    theme="soft",
    examples=[
        ["भारत", None],
        ["शिक्षा", None],
    ],
)

if __name__ == "__main__":
    demo.launch()