keysun89 commited on
Commit
6621447
·
verified ·
1 Parent(s): bf59780

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -73
app.py CHANGED
@@ -7,7 +7,7 @@ from torchvision import transforms
7
  from transformers import CanineModel, CanineTokenizer
8
  from diffusers import AutoencoderKL, DDPMScheduler
9
 
10
- # Import your custom architectures
11
  from unet import UNetModel
12
  from feature_extractor import Mixed_Encoder
13
 
@@ -15,29 +15,37 @@ from feature_extractor import Mixed_Encoder
15
  # 1. SETUP & CONFIGURATION
16
  # ==========================================
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
- # RE-RE-CRITICAL: Fill this list in the EXACT order of your training folders!
19
- HINDI_VOCAB = ["क", "ख", "ग", "घ", "ङ", "च", "छ", "ज", "झ", "ञ"] # ... add all others
20
 
21
- # = :=========================================
22
- # 2. MODEL LOADING (Inference Optimized)
23
- # ==========================================
24
- print(f"🚀 Booting DiffusionPen on {DEVICE}...")
 
 
 
 
 
 
25
 
26
- # Load VAE (Directly via app.py as requested)
27
- vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)
 
 
28
 
29
- # Load Style Encoder (Mixed_Encoder)
30
  style_encoder = Mixed_Encoder(model_name='mobilenetv2_100', num_classes=300).to(DEVICE)
31
  style_encoder.load_state_dict(torch.load("weights/mixed_hindi_mobilenetv2_100.pth", map_location=DEVICE))
32
  style_encoder.eval()
33
 
34
- # Load Text Encoder (Canine)
35
  tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
36
  text_encoder = CanineModel.from_pretrained("google/canine-c").to(DEVICE)
37
  text_encoder.eval()
38
 
39
- # Load UNet (Custom)
40
- # These parameters must match your training config
 
 
41
  unet = UNetModel(
42
  image_size=(64, 256),
43
  in_channels=4,
@@ -46,15 +54,16 @@ unet = UNetModel(
46
  num_res_blocks=2,
47
  attention_resolutions=[4, 2, 1],
48
  channel_mult=[1, 2, 4, 4],
49
- context_dim=768
50
  ).to(DEVICE)
51
  unet.load_state_dict(torch.load("weights/ema_ckpt.pt", map_location=DEVICE))
52
  unet.eval()
53
 
 
54
  scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
55
 
56
  # ==========================================
57
- # 3. PREPROCESSING UTILS
58
  # ==========================================
59
  style_transform = transforms.Compose([
60
  transforms.Resize((224, 224)),
@@ -62,62 +71,7 @@ style_transform = transforms.Compose([
62
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
63
  ])
64
 
65
- def predict(hindi_text, style_image):
66
  with torch.no_grad():
67
- # A. Process Style
68
- if style_image is not None:
69
- style_t = style_transform(style_image).unsqueeze(0).to(DEVICE)
70
- _, style_features = style_encoder(style_t)
71
- else:
72
- style_features = torch.zeros((1, 1280)).to(DEVICE)
73
-
74
- # B. Process Text
75
- inputs = tokenizer(hindi_text, padding="max_length", max_length=128, return_tensors="pt").to(DEVICE)
76
-
77
- # C. Diffusion Loop (Simplified DDPM)
78
- latents = torch.randn((1, 4, 8, 32)).to(DEVICE) # Latent size for 64x256
79
- scheduler.set_timesteps(50) # 50 steps for speed in demo
80
-
81
- for t in scheduler.timesteps:
82
- # Predict noise
83
- noise_pred = unet(latents, t.unsqueeze(0).to(DEVICE), context=inputs, style_extractor=style_features)
84
- # Step scheduler
85
- latents = scheduler.step(noise_pred, t, latents).prev_sample
86
-
87
- # D. Decode with VAE
88
- latents = 1 / 0.18215 * latents
89
- image = vae.decode(latents).sample
90
- image = (image / 2 + 0.5).clamp(0, 1)
91
- image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
92
- image = (image * 255).astype(np.uint8)
93
-
94
- return Image.fromarray(image)
95
-
96
- # ==========================================
97
- # 4. GRADIO INTERFACE (Resume Ready)
98
- # ==========================================
99
- description = """
100
- ### 🖋️ DiffusionPen: Hindi Handwriting Synthesis
101
- **Developed by Kishan Madlani | NIT Surat**
102
- This model uses a Latent Diffusion architecture to generate Hindi text in specific handwriting styles.
103
- It was trained on a custom dataset of 300+ writers using Triplet Loss and Cross-Attention.
104
- """
105
-
106
- demo = gr.Interface(
107
- fn=predict,
108
- inputs=[
109
- gr.Textbox(label="Input Hindi Text", placeholder="नमस्ते..."),
110
- gr.Image(label="Style Reference Image", type="pil")
111
- ],
112
- outputs=gr.Image(label="Generated Handwriting"),
113
- title="DiffusionPen - Hindi Style Transfer",
114
- description=description,
115
- theme="soft",
116
- examples=[
117
- ["भारत", None],
118
- ["शिक्षा", None]
119
- ]
120
- )
121
-
122
- if __name__ == "__main__":
123
- demo.launch()
 
7
  from transformers import CanineModel, CanineTokenizer
8
  from diffusers import AutoencoderKL, DDPMScheduler
9
 
10
+ # Import your custom architectures from your local files
11
  from unet import UNetModel
12
  from feature_extractor import Mixed_Encoder
13
 
 
15
  # 1. SETUP & CONFIGURATION
16
  # ==========================================
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
18
 
19
+ # ⚠️ CRITICAL: Fill this list in the EXACT alphabetical/folder order of your training data.
20
+ # This ensures "Ka" maps to the "Ka" vector, not "Kha".
21
+ HINDI_VOCAB = [
22
+ "अ", "आ", "इ", "ई", "उ", "ऊ", "ऋ", "ए", "ऐ", "ओ", "औ",
23
+ "क", "ख", "ग", "घ", "ङ", "च", "छ", "ज", "झ", "ञ",
24
+ "ट", "ठ", "ड", "ढ", "ण", "त", "थ", "द", "ध", "न",
25
+ "प", "फ", "ब", "भ", "म", "य", "र", "ल", "व", "श",
26
+ "ष", "स", "ह"
27
+ # ... Add any conjuncts or matras you trained on
28
+ ]
29
 
30
+ # ==========================================
31
+ # 2. MODEL INITIALIZATION
32
+ # ==========================================
33
+ print(f"📦 Loading models on {DEVICE}...")
34
 
35
+ # A. Style Encoder (Mixed_Encoder from your feature_extractor.py)
36
  style_encoder = Mixed_Encoder(model_name='mobilenetv2_100', num_classes=300).to(DEVICE)
37
  style_encoder.load_state_dict(torch.load("weights/mixed_hindi_mobilenetv2_100.pth", map_location=DEVICE))
38
  style_encoder.eval()
39
 
40
+ # B. Text Encoder (Canine)
41
  tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
42
  text_encoder = CanineModel.from_pretrained("google/canine-c").to(DEVICE)
43
  text_encoder.eval()
44
 
45
+ # C. VAE (MSE-tuned for sharp handwriting)
46
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)
47
+
48
+ # D. UNet (Your custom unet.py)
49
  unet = UNetModel(
50
  image_size=(64, 256),
51
  in_channels=4,
 
54
  num_res_blocks=2,
55
  attention_resolutions=[4, 2, 1],
56
  channel_mult=[1, 2, 4, 4],
57
+ context_dim=768 # Canine hidden size
58
  ).to(DEVICE)
59
  unet.load_state_dict(torch.load("weights/ema_ckpt.pt", map_location=DEVICE))
60
  unet.eval()
61
 
62
+ # E. Scheduler
63
  scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
64
 
65
  # ==========================================
66
+ # 3. INFERENCE LOGIC
67
  # ==========================================
68
  style_transform = transforms.Compose([
69
  transforms.Resize((224, 224)),
 
71
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
72
  ])
73
 
74
+ def generate_handwriting(hindi_text, s1, s2, s3, s4, s5):
75
  with torch.no_grad():
76
+ # 1. Few-Shot Style Extraction
77
+ style_images = [img for img in [s1, s2, s3