keysun89 commited on
Commit
bf59780
·
verified ·
1 Parent(s): bdc783f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from transformers import CanineModel, CanineTokenizer
8
+ from diffusers import AutoencoderKL, DDPMScheduler
9
+
10
+ # Import your custom architectures
11
+ from unet import UNetModel
12
+ from feature_extractor import Mixed_Encoder
13
+
14
+ # ==========================================
15
+ # 1. SETUP & CONFIGURATION
16
+ # ==========================================
17
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+ # RE-RE-CRITICAL: Fill this list in the EXACT order of your training folders!
19
+ HINDI_VOCAB = ["क", "ख", "ग", "घ", "ङ", "च", "छ", "ज", "झ", "ञ"] # ... add all others
20
+
21
+ # = :=========================================
22
+ # 2. MODEL LOADING (Inference Optimized)
23
+ # ==========================================
24
+ print(f"🚀 Booting DiffusionPen on {DEVICE}...")
25
+
26
+ # Load VAE (Directly via app.py as requested)
27
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)
28
+
29
+ # Load Style Encoder (Mixed_Encoder)
30
+ style_encoder = Mixed_Encoder(model_name='mobilenetv2_100', num_classes=300).to(DEVICE)
31
+ style_encoder.load_state_dict(torch.load("weights/mixed_hindi_mobilenetv2_100.pth", map_location=DEVICE))
32
+ style_encoder.eval()
33
+
34
+ # Load Text Encoder (Canine)
35
+ tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
36
+ text_encoder = CanineModel.from_pretrained("google/canine-c").to(DEVICE)
37
+ text_encoder.eval()
38
+
39
+ # Load UNet (Custom)
40
+ # These parameters must match your training config
41
+ unet = UNetModel(
42
+ image_size=(64, 256),
43
+ in_channels=4,
44
+ model_channels=320,
45
+ out_channels=4,
46
+ num_res_blocks=2,
47
+ attention_resolutions=[4, 2, 1],
48
+ channel_mult=[1, 2, 4, 4],
49
+ context_dim=768
50
+ ).to(DEVICE)
51
+ unet.load_state_dict(torch.load("weights/ema_ckpt.pt", map_location=DEVICE))
52
+ unet.eval()
53
+
54
+ scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
55
+
56
+ # ==========================================
57
+ # 3. PREPROCESSING UTILS
58
+ # ==========================================
59
+ style_transform = transforms.Compose([
60
+ transforms.Resize((224, 224)),
61
+ transforms.ToTensor(),
62
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
63
+ ])
64
+
65
+ def predict(hindi_text, style_image):
66
+ with torch.no_grad():
67
+ # A. Process Style
68
+ if style_image is not None:
69
+ style_t = style_transform(style_image).unsqueeze(0).to(DEVICE)
70
+ _, style_features = style_encoder(style_t)
71
+ else:
72
+ style_features = torch.zeros((1, 1280)).to(DEVICE)
73
+
74
+ # B. Process Text
75
+ inputs = tokenizer(hindi_text, padding="max_length", max_length=128, return_tensors="pt").to(DEVICE)
76
+
77
+ # C. Diffusion Loop (Simplified DDPM)
78
+ latents = torch.randn((1, 4, 8, 32)).to(DEVICE) # Latent size for 64x256
79
+ scheduler.set_timesteps(50) # 50 steps for speed in demo
80
+
81
+ for t in scheduler.timesteps:
82
+ # Predict noise
83
+ noise_pred = unet(latents, t.unsqueeze(0).to(DEVICE), context=inputs, style_extractor=style_features)
84
+ # Step scheduler
85
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
86
+
87
+ # D. Decode with VAE
88
+ latents = 1 / 0.18215 * latents
89
+ image = vae.decode(latents).sample
90
+ image = (image / 2 + 0.5).clamp(0, 1)
91
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
92
+ image = (image * 255).astype(np.uint8)
93
+
94
+ return Image.fromarray(image)
95
+
96
+ # ==========================================
97
+ # 4. GRADIO INTERFACE (Resume Ready)
98
+ # ==========================================
99
+ description = """
100
+ ### 🖋️ DiffusionPen: Hindi Handwriting Synthesis
101
+ **Developed by Kishan Madlani | NIT Surat**
102
+ This model uses a Latent Diffusion architecture to generate Hindi text in specific handwriting styles.
103
+ It was trained on a custom dataset of 300+ writers using Triplet Loss and Cross-Attention.
104
+ """
105
+
106
+ demo = gr.Interface(
107
+ fn=predict,
108
+ inputs=[
109
+ gr.Textbox(label="Input Hindi Text", placeholder="नमस्ते..."),
110
+ gr.Image(label="Style Reference Image", type="pil")
111
+ ],
112
+ outputs=gr.Image(label="Generated Handwriting"),
113
+ title="DiffusionPen - Hindi Style Transfer",
114
+ description=description,
115
+ theme="soft",
116
+ examples=[
117
+ ["भारत", None],
118
+ ["शिक्षा", None]
119
+ ]
120
+ )
121
+
122
+ if __name__ == "__main__":
123
+ demo.launch()