keysun89 committed on
Commit
9e554f6
·
verified ·
1 Parent(s): e7dee27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +404 -105
app.py CHANGED
@@ -1,126 +1,425 @@
1
- import os
 
 
 
 
 
2
  import torch
3
  import numpy as np
4
- import gradio as gr
5
  from PIL import Image
6
- from collections import OrderedDict
7
- from torchvision import transforms
8
- from transformers import CanineModel, CanineTokenizer
9
- from diffusers import AutoencoderKL, DDPMScheduler
10
-
11
- # Import your custom architectures
12
  from unet import UNetModel
13
- from feature_extractor import Mixed_Encoder
 
14
 
15
# ==========================================
# 1. SETUP
# ==========================================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ==========================================
# 2. MODEL LOADING
# ==========================================
print(f"🚀 Initializing on {DEVICE}...")

# Tokenizer and VAE
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)
tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
text_encoder = CanineModel.from_pretrained("google/canine-c").to(DEVICE)

# Style Encoder
style_encoder = Mixed_Encoder(model_name='mobilenetv2_100', num_classes=300).to(DEVICE)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
s_w = torch.load("mixed_hindi_mobilenetv2_100.pth", map_location=DEVICE)
# Drop the "module." fragment from the saved keys before loading.
style_encoder.load_state_dict(
    OrderedDict((k.replace("module.", ""), v) for k, v in s_w.items())
)
style_encoder.eval()

# UNet (1 ResBlock, 320 Context)
unet = UNetModel(
    image_size=(64, 256), in_channels=4, model_channels=320, out_channels=4,
    num_res_blocks=1, attention_resolutions=[4, 2, 1], channel_mult=[1, 1, 1, 1],
    context_dim=320, text_encoder=text_encoder
).to(DEVICE)

# Weight Loader for Super-Checkpoint
# The single checkpoint holds both UNet and text-encoder tensors. Keys that
# contain "text_encoder." are routed to the text encoder (keeping only the
# part after that marker); every other key goes to the UNet.
ckpt = torch.load("ema_ckpt.pt", map_location=DEVICE)
u_dict, t_dict = OrderedDict(), OrderedDict()
for k, v in ckpt.items():
    clean_k = k.replace("module.", "")
    if "text_encoder." in clean_k:
        t_dict[clean_k.split("text_encoder.")[-1]] = v
    else:
        u_dict[clean_k] = v

unet.load_state_dict(u_dict, strict=False)
try:
    text_encoder.load_state_dict(t_dict, strict=False)
except Exception as err:
    # Best effort: keep the pretrained encoder weights, but say so instead of
    # hiding the failure.
    print(f"⚠️ Could not load text encoder weights from checkpoint: {err}")

unet.eval()
text_encoder.eval()
scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
 
 
 
58
 
59
# ==========================================
# 3. PREDICT FUNCTION
# ==========================================
# Preprocessing applied to every style reference image before the style encoder.
st_trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def predict(hindi_text, s1, s2):
    """Synthesize a handwriting image for ``hindi_text`` in the reference style.

    Args:
        hindi_text: Text to render.
        s1: First style reference image (PIL) or None.
        s2: Second style reference image (PIL) or None.

    Returns:
        A PIL image: the decoded sample on success, a white 256x64 placeholder
        when the text or all style images are missing, or a red 256x64
        placeholder when generation raises.
    """
    # Either sample alone is enough; the features of the given ones are averaged.
    imgs = [i for i in (s1, s2) if i is not None]
    if not hindi_text or not imgs:
        return Image.new('RGB', (256, 64), color='white')

    try:
        with torch.no_grad():
            # A. Extract Style
            # Force three channels so RGBA or grayscale uploads match the
            # 3-channel Normalize statistics in st_trans.
            feats = [
                style_encoder(st_trans(i.convert("RGB")).unsqueeze(0).to(DEVICE))[1]
                for i in imgs
            ]
            final_style_vec = torch.mean(torch.stack(feats), dim=0)

            # B. Process Text
            t_in = tokenizer(hindi_text, padding="max_length", max_length=128, return_tensors="pt")
            t_in = {k: v.to(DEVICE) for k, v in t_in.items()}

            # C. Diffusion Loop (8 steps for CPU safety)
            latents = torch.randn((1, 4, 8, 32)).to(DEVICE)
            scheduler.set_timesteps(8)

            for t in scheduler.timesteps:
                noise_pred = unet(
                    latents,
                    t.unsqueeze(0).to(DEVICE),
                    context=t_in,
                    style_extractor=final_style_vec,
                )
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            # D. Final Decode
            # Undo the latent scaling factor before decoding with the VAE.
            latents = 1 / 0.18215 * latents
            img = vae.decode(latents).sample
            # Map [-1, 1] to [0, 1] and move channels last for PIL.
            img = (img / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()[0]

            return Image.fromarray((img * 255).astype(np.uint8))

    except Exception as e:
        print(f" Error: {e}")
        # Return a simple error image instead of a broken icon
        return Image.new('RGB', (256, 64), color='red')
104
-
105
# ==========================================
# 4. GRADIO INTERFACE
# ==========================================
# Two-column page: inputs (text + two style samples) on the left, result on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖋️ DiffusionPen: Hindi Handwriting Synthesis")
    gr.Markdown("### Developed by Kishan Madlani")

    with gr.Row():
        with gr.Column():
            txt = gr.Textbox(
                label="Input Hindi Text",
                placeholder="नमस्ते...",
            )
            gr.Markdown("#### 📷 Style Reference Samples")
            im1 = gr.Image(type="pil", label="Sample 1")
            im2 = gr.Image(type="pil", label="Sample 2")
            btn = gr.Button("Synthesize Handwriting", variant="primary")

        with gr.Column():
            out = gr.Image(label="Generated Output")
            gr.Markdown("**Note:** Use small, cropped snippets for style samples for best quality.")

    # Wire the button to the synthesis callback.
    btn.click(fn=predict, inputs=[txt, im1, im2], outputs=out)

demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DiffusionPen: Hindi Handwriting Generation Demo
3
+ Inference-focused Gradio application with CANINE text encoding
4
+ """
5
+
6
+ import gradio as gr
7
  import torch
8
  import numpy as np
 
9
  from PIL import Image
 
 
 
 
 
 
10
  from unet import UNetModel
11
+ from transformers import CanineTokenizer, CanineModel
12
+ from pathlib import Path
13
 
 
 
 
 
14
 
15
class DiffusionPenDemo:
    """
    Hindi handwriting generation demo built around the DiffusionPen UNet.

    Features:
    - CANINE text encoder for character-level Hindi encoding
    - 339 different writer styles
    - Configurable diffusion steps and guidance
    - GPU/CPU automatic detection
    - Checkpoint loading support

    Constructing an instance loads every model immediately (see
    ``load_models``); initialization errors are printed and re-raised.
    """

    # Limits applied to user-supplied generation settings.
    NUM_WRITERS = 339
    MIN_STEPS, MAX_STEPS = 10, 100
    MIN_GUIDANCE, MAX_GUIDANCE = 1.0, 15.0
    # Side length of the square, single-channel sample tensor.
    IMAGE_SIZE = 64

    def __init__(self, checkpoint_path=None, device=None):
        """
        Args:
            checkpoint_path: Optional path to a saved UNet checkpoint.
            device: Device to run on; when falsy, CUDA is used if available,
                otherwise the CPU.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.checkpoint_path = checkpoint_path
        self.model = None
        self.text_encoder = None
        self.tokenizer = None
        self.checkpoint_loaded = False
        self.load_models()

    def load_models(self):
        """Load the CANINE tokenizer/encoder, build the UNet and load weights.

        Raises:
            Exception: any error raised while loading is printed and re-raised.
        """
        try:
            print(f"\n{'='*60}")
            print(f"🔧 DiffusionPen Initialization")
            print(f"{'='*60}")
            # str() so a torch.device instance works as well as a plain string.
            print(f"📱 Device: {str(self.device).upper()}")

            # Load CANINE text encoder
            print("\n📝 Loading CANINE text encoder...")
            self.tokenizer = CanineTokenizer.from_pretrained('google/canine-s')
            self.text_encoder = CanineModel.from_pretrained('google/canine-s').to(self.device)
            self.text_encoder.eval()
            print(" ✓ CANINE loaded (768-dim embeddings)")

            # Initialize UNet model
            print("\n🧠 Initializing UNet model...")

            class Args:
                # Minimal stand-in for the options object UNetModel receives.
                interpolation = False
                mix_rate = 0.5

            # NOTE(review): confirm against unet.py that attention_resolutions
            # [16, 8] are actually reached with a three-level channel_mult.
            self.model = UNetModel(
                image_size=self.IMAGE_SIZE,
                in_channels=1,
                model_channels=128,
                out_channels=1,
                num_res_blocks=2,
                attention_resolutions=[16, 8],
                dropout=0.1,
                channel_mult=(1, 2, 4),
                dims=2,
                num_classes=self.NUM_WRITERS,  # Hindi writer styles
                use_checkpoint=True,
                num_heads=8,
                num_head_channels=-1,
                use_scale_shift_norm=True,
                resblock_updown=False,
                use_spatial_transformer=True,
                transformer_depth=1,
                context_dim=768,
                text_encoder=self.text_encoder,
                args=Args()
            ).to(self.device)
            self.model.eval()

            # Count parameters
            total_params = sum(p.numel() for p in self.model.parameters())
            print(f" ✓ UNet initialized ({total_params/1e6:.1f}M parameters)")

            # Load checkpoint if available
            if self.checkpoint_path and Path(self.checkpoint_path).exists():
                self._load_checkpoint()
            else:
                print(f"\n⚠️ No checkpoint found at: {self.checkpoint_path}")
                print(" Using random initialization")

            print(f"\n{'='*60}")
            print(f"✅ Ready for inference!")
            print(f"{'='*60}\n")

        except Exception as e:
            print(f"\n❌ Error during initialization: {str(e)}")
            raise

    @staticmethod
    def _strip_module_prefix(state_dict, prefix="module."):
        """Return ``state_dict`` with a leading ``prefix`` removed from its keys.

        The mapping is returned unchanged when no key starts with ``prefix``.
        """
        if not any(key.startswith(prefix) for key in state_dict):
            return state_dict
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _load_checkpoint(self):
        """Load ``self.checkpoint_path`` into the UNet.

        Accepts a raw state dict or a dict wrapping it under
        ``model_state_dict`` / ``state_dict``. Failures are printed and
        recorded in ``self.checkpoint_loaded`` rather than raised.
        """
        try:
            print(f"\n📂 Loading checkpoint: {self.checkpoint_path}")
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device)

            # Handle different checkpoint formats
            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                    print(f" Format: Standard (model_state_dict)")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    print(f" Format: Alternative (state_dict)")
                else:
                    state_dict = checkpoint
                    print(f" Format: Raw state dict")
            else:
                state_dict = checkpoint
                print(f" Format: Direct tensor state")

            # Drop a leading "module." so such keys match the bare model.
            if isinstance(state_dict, dict):
                state_dict = self._strip_module_prefix(state_dict)

            # Load state dict with strict=False to handle minor mismatches
            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)

            if missing_keys:
                print(f" ⚠️ Missing keys: {len(missing_keys)}")
            if unexpected_keys:
                print(f" ⚠️ Unexpected keys: {len(unexpected_keys)}")

            self.checkpoint_loaded = True
            print(f" ✓ Checkpoint loaded successfully")

        except Exception as e:
            print(f" ❌ Failed to load checkpoint: {str(e)}")
            self.checkpoint_loaded = False

    def encode_text(self, text):
        """Tokenize ``text`` with the CANINE tokenizer.

        Returns:
            Dict of tokenizer tensors moved to ``self.device``, or ``None``
            when tokenization raises (the error is printed).
        """
        try:
            # CANINE handles character-level encoding natively
            inputs = self.tokenizer(
                text,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=512
            )
            return {k: v.to(self.device) for k, v in inputs.items()}
        except Exception as e:
            print(f"❌ Text encoding error: {e}")
            return None

    @staticmethod
    def _clamp(value, low, high):
        """Restrict ``value`` to the closed interval [low, high]."""
        return max(low, min(value, high))

    def _denoise(self, context, writer_id, num_steps, guidance_scale):
        """Run the iterative refinement loop and return the raw sample tensor.

        Starts from Gaussian noise and, for each step from ``num_steps - 1``
        down to 0, subtracts a scaled model prediction.

        NOTE(review): this is a hand-rolled linear update, not a scheduler
        step, and ``guidance_scale`` only scales the step size. Confirm it
        matches how the checkpoint was trained.
        """
        batch_size = 1

        # Initialize from noise
        x = torch.randn(batch_size, 1, self.IMAGE_SIZE, self.IMAGE_SIZE, device=self.device)
        # Writer conditioning is constant, so build it once outside the loop.
        y = torch.tensor([writer_id], dtype=torch.long, device=self.device)
        report_every = max(1, num_steps // 5)

        # Reverse diffusion process
        for step in range(num_steps - 1, -1, -1):
            t = torch.full((batch_size,), step, dtype=torch.long, device=self.device)

            # Model prediction
            noise_pred = self.model(x, timesteps=t, context=context, y=y)

            # Denoising step with adaptive scaling: the step grows as step -> 0.
            alpha_t = 1.0 - (step / num_steps)
            scale = guidance_scale * alpha_t
            x = x - 0.01 * scale * noise_pred

            # Progress indicator
            completed = num_steps - step
            if completed % report_every == 0:
                progress = (completed / num_steps) * 100
                print(f" Progress: {progress:.0f}%")

        return x

    @staticmethod
    def _tensor_to_image(x):
        """Convert a (1, 1, H, W) sample in [-1, 1] to an 8-bit grayscale image."""
        x = torch.clamp(x, -1, 1)
        x = (x + 1) / 2  # Normalize to [0, 1]
        pixels = x.squeeze(0).squeeze(0).cpu().numpy()
        # A 2-D uint8 array is read by PIL as mode "L" (8-bit grayscale).
        return Image.fromarray((pixels * 255).astype(np.uint8))

    @torch.no_grad()
    def generate(self, text, writer_id=0, num_steps=50, guidance_scale=7.5):
        """
        Generate Hindi handwriting from text.

        Args:
            text: Hindi text in Devanagari script
            writer_id: Writer style ID (clamped to 0-338)
            num_steps: Number of diffusion steps (clamped to 10-100)
            guidance_scale: Text guidance strength (clamped to 1.0-15.0)

        Returns:
            Tuple[PIL.Image, str]: Generated image and status message. The
            image is ``None`` whenever generation could not run; the message
            then explains why. This method does not raise.
        """
        if self.model is None:
            return None, "❌ Model not initialized"

        try:
            # Input validation (None and whitespace-only both count as empty)
            if not text or not text.strip():
                return None, "⚠️ Please enter Hindi text"

            writer_id = self._clamp(int(writer_id), 0, self.NUM_WRITERS - 1)
            num_steps = self._clamp(int(num_steps), self.MIN_STEPS, self.MAX_STEPS)
            guidance_scale = self._clamp(
                float(guidance_scale), self.MIN_GUIDANCE, self.MAX_GUIDANCE
            )

            print(f"\n🎨 Generating handwriting...")
            print(f" Text: '{text}'")
            print(f" Writer: {writer_id}/338")
            print(f" Steps: {num_steps}")
            print(f" Guidance: {guidance_scale}")

            # Encode text with CANINE
            context = self.encode_text(text)
            if context is None:
                return None, "❌ Text encoding failed"

            sample = self._denoise(context, writer_id, num_steps, guidance_scale)
            img = self._tensor_to_image(sample)

            status = f"✅ Generated with writer {writer_id}, {num_steps} steps"
            print(f" {status}\n")
            return img, status

        except Exception as e:
            error_msg = f"❌ Generation error: {str(e)}"
            print(f" {error_msg}")
            return None, error_msg
238
 
 
 
 
 
 
 
239
 
240
# ==============================================================================
# CONFIGURATION
# ==============================================================================

# Path to your trained checkpoint (edit this!)
CHECKPOINT_PATH = "./checkpoints/model.pt"

# Build the single shared demo object at import time. The device argument is
# left at its default (None), which lets the class auto-detect GPU/CPU.
print("\n🚀 Initializing DiffusionPen...")
demo_instance = DiffusionPenDemo(CHECKPOINT_PATH)
253
 
 
 
 
 
 
 
 
 
254
 
255
def gradio_generate(text, writer_id, num_steps, guidance_scale):
    """Gradio callback: pass the UI values to the shared demo instance.

    Returns the ``(image, status message)`` pair produced by
    ``demo_instance.generate``.
    """
    result = demo_instance.generate(text, writer_id, num_steps, guidance_scale)
    image, status = result
    return image, status
 
 
264
 
 
 
 
265
 
266
# ==============================================================================
# GRADIO INTERFACE
# ==============================================================================

# Page copy lives in module-level constants so the layout code below stays
# readable.
HEADER_MARKDOWN = """
# 🎨 DiffusionPen: Hindi Handwriting Generation

Generate authentic Hindi handwriting using diffusion models with CANINE text encoding.
"""

# NOTE(review): "resolutions 16×8" below presumably refers to the UNet's
# attention_resolutions=[16, 8]; confirm the wording before publishing.
ABOUT_MARKDOWN = """
---

### 📖 About This Demo

**Model Architecture:**
- **Base**: UNet with 128 channels, 3 levels
- **Attention**: Spatial transformers at resolutions 16×8
- **Text Encoding**: CANINE (768-dim, character-level)
- **Writer Styles**: 339 different writing styles
- **Input/Output**: 64×64 grayscale images

**CANINE Text Encoder:**
- ✓ Character-level (no subword tokenization)
- ✓ Native Devanagari support
- ✓ Pre-trained on 104 languages
- ✓ 768-dimensional contextual embeddings

**Performance:**
- CPU: ~2 minutes per image
- GPU: ~20 seconds per image
- Memory: 6-8 GB

### 💡 Tips
1. Keep text short (5-10 characters) for faster generation
2. Try different Writer IDs for style variation
3. Increase steps from 50→100 for better quality
4. Guidance scale 5-10 works best for most cases
5. Use CPU to generate demos, GPU for production

### 🔗 Resources
- [CANINE Paper](https://arxiv.org/abs/2103.06367)
- [Diffusion Models Course](https://huggingface.co/course)
- [UNet Architecture](https://en.wikipedia.org/wiki/U-Net)
"""

# Each row: [text, writer id, diffusion steps, guidance scale].
# The third entry previously contained a vowel sign directly after an
# independent vowel ("आी"), which is not a valid Devanagari sequence.
EXAMPLE_INPUTS = [
    ["नमस्ते", 0, 50, 7.5],
    ["हिंदी", 50, 50, 7.5],
    ["आईआईआईटी", 100, 50, 7.5],
    ["लिपि", 150, 50, 7.5],
    ["भाषा", 200, 50, 7.5],
    ["नई लिखावट", 250, 60, 7.5],
]

theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="amber",
)

with gr.Blocks(title="DiffusionPen - Hindi Handwriting Generation", theme=theme) as demo:

    # Header
    gr.Markdown(HEADER_MARKDOWN)

    # Main content
    with gr.Row():
        # Input panel
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### ✏️ Input Settings")

            text_input = gr.Textbox(
                label="Hindi Text (Devanagari)",
                placeholder="नमस्ते",
                lines=2,
                info="Enter text in Devanagari script"
            )

            writer_id = gr.Slider(
                label="Writer ID",
                minimum=0,
                maximum=338,
                value=0,
                step=1,
                info="0-338: Different writing styles"
            )

            num_steps = gr.Slider(
                label="Diffusion Steps",
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                info="10=fast, 100=quality"
            )

            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1.0,
                maximum=15.0,
                value=7.5,
                step=0.5,
                info="1=ignore text, 15=strict"
            )

            generate_btn = gr.Button(
                "✨ Generate Handwriting",
                variant="primary",
                size="lg"
            )

        # Output panel
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 📊 Output")

            output_image = gr.Image(
                label="Generated Handwriting",
                type='pil',
                interactive=False,
                show_download_button=True
            )

            status_text = gr.Textbox(
                label="Status",
                interactive=False,
                info="Generation progress and results"
            )

    # Examples (clicking one only fills the inputs; nothing is pre-computed)
    gr.Markdown("### 📚 Examples to Try")
    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=[text_input, writer_id, num_steps, guidance_scale],
        outputs=[output_image, status_text],
        fn=gradio_generate,
        cache_examples=False,
        run_on_click=False
    )

    # Information
    gr.Markdown(ABOUT_MARKDOWN)

    # Connect button
    generate_btn.click(
        fn=gradio_generate,
        inputs=[text_input, writer_id, num_steps, guidance_scale],
        outputs=[output_image, status_text],
        api_name="generate"
    )
409
 
 
410
 
411
if __name__ == "__main__":
    # Print a short startup summary, then serve the UI on all interfaces.
    banner = "=" * 60
    checkpoint_state = '✓ Loaded' if demo_instance.checkpoint_loaded else '✗ Not found'
    model_state = '✓ Ready' if demo_instance.model is not None else '✗ Error'

    print("\n" + banner)
    print("🚀 Starting DiffusionPen Gradio Demo")
    print(banner)
    print(f"Device: {demo_instance.device}")
    print(f"Checkpoint: {checkpoint_state}")
    print(f"Models: {model_state}")
    print(banner + "\n")

    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)