treadon commited on
Commit
1a73726
·
verified ·
1 Parent(s): aa8297b

Upload generate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. generate.py +108 -0
generate.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate images with Nucleus-Image on Apple Silicon (MLX).
3
+
4
+ Usage:
5
+ python generate.py --prompt "A red apple on a white table"
6
+ python generate.py --prompt "A futuristic city at sunset" --steps 30 --seed 42 --output city.png
7
+ """
8
+
9
+ import argparse
10
+ import gc
11
+ import time
12
+
13
+ import mlx.core as mx
14
+ import torch
15
+ from transformers import AutoModel, AutoProcessor
16
+
17
+ from nucleus_image.pipeline import NucleusImagePipeline
18
+
19
+ SYSTEM_PROMPT = "You are an image generation assistant."
20
+ TEXT_MODEL_ID = "NucleusAI/Nucleus-Image"
21
+ HIDDEN_LAYER_INDEX = -8 # 8th from last hidden state
22
+
23
+
24
+ def encode_text(prompt: str, processor, text_model) -> mx.array:
25
+ """Encode a text prompt into embeddings using the chat template format."""
26
+ messages = [
27
+ {"role": "system", "content": SYSTEM_PROMPT},
28
+ {"role": "user", "content": [{"type": "text", "text": prompt}]},
29
+ ]
30
+ formatted = processor.apply_chat_template(
31
+ messages, tokenize=False, add_generation_prompt=True
32
+ )
33
+ inputs = processor(text=[formatted], return_tensors="pt", padding=True)
34
+ with torch.no_grad():
35
+ outputs = text_model(
36
+ input_ids=inputs["input_ids"],
37
+ attention_mask=inputs.get("attention_mask"),
38
+ output_hidden_states=True,
39
+ use_cache=False,
40
+ )
41
+ hidden = outputs.hidden_states[HIDDEN_LAYER_INDEX][0]
42
+ return mx.array(hidden.cpu().float().numpy())
43
+
44
+
45
+ def main():
46
+ parser = argparse.ArgumentParser(description="Generate images with MLX Nucleus-Image")
47
+ parser.add_argument("--prompt", type=str, required=True, help="Text prompt for image generation")
48
+ parser.add_argument("--height", type=int, default=512, help="Image height (default: 512)")
49
+ parser.add_argument("--width", type=int, default=512, help="Image width (default: 512)")
50
+ parser.add_argument("--steps", type=int, default=50, help="Number of inference steps (default: 50)")
51
+ parser.add_argument("--cfg", type=float, default=4.0, help="Classifier-free guidance scale (default: 4.0)")
52
+ parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility")
53
+ parser.add_argument("--output", type=str, default="output.png", help="Output file path (default: output.png)")
54
+ parser.add_argument("--quantize", type=int, default=4, choices=[4, 8, None], help="DiT quantization bits (default: 4)")
55
+ args = parser.parse_args()
56
+
57
+ t_total = time.time()
58
+
59
+ # Step 1: Load text encoder and encode prompt + negative (empty string)
60
+ print("Loading text encoder...")
61
+ t0 = time.time()
62
+ processor = AutoProcessor.from_pretrained(
63
+ TEXT_MODEL_ID, subfolder="processor", trust_remote_code=True
64
+ )
65
+ text_model = AutoModel.from_pretrained(
66
+ TEXT_MODEL_ID, subfolder="text_encoder",
67
+ dtype=torch.bfloat16, trust_remote_code=True,
68
+ )
69
+ text_model.eval()
70
+ print(f" Text encoder loaded in {time.time() - t0:.1f}s")
71
+
72
+ print("Encoding prompt...")
73
+ t0 = time.time()
74
+ text_emb = encode_text(args.prompt, processor, text_model)
75
+
76
+ print("Encoding negative embeddings...")
77
+ neg_emb = encode_text("", processor, text_model)
78
+ print(f" Text encoding done in {time.time() - t0:.1f}s")
79
+
80
+ # Free text encoder memory (~16GB)
81
+ del text_model, processor
82
+ gc.collect()
83
+
84
+ # Step 2: Load MLX pipeline (DiT + VAE)
85
+ print(f"Loading MLX pipeline (quantize={args.quantize})...")
86
+ t0 = time.time()
87
+ pipe = NucleusImagePipeline.from_pretrained(quantize=args.quantize)
88
+ print(f" Pipeline loaded in {time.time() - t0:.1f}s")
89
+
90
+ # Step 3: Generate image
91
+ print(f"Generating {args.height}x{args.width}, {args.steps} steps, CFG {args.cfg}...")
92
+ img = pipe.generate(
93
+ text_embeddings=mx.expand_dims(text_emb, 0),
94
+ neg_text_embeddings=mx.expand_dims(neg_emb, 0),
95
+ height=args.height,
96
+ width=args.width,
97
+ num_inference_steps=args.steps,
98
+ guidance_scale=args.cfg,
99
+ seed=args.seed,
100
+ )
101
+
102
+ img.save(args.output)
103
+ print(f"Saved to {args.output}")
104
+ print(f"Total time: {time.time() - t_total:.1f}s")
105
+
106
+
107
+ if __name__ == "__main__":
108
+ main()