treadon commited on
Commit
641e0cd
·
verified ·
1 Parent(s): 90dc3da

Upload generate_samples.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. generate_samples.py +65 -0
generate_samples.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate sample images for the repo."""
3
+ import sys, torch, gc, mlx.core as mx
4
+ sys.path.insert(0, ".")
5
+
6
+ from transformers import AutoProcessor, AutoModel
7
+ from nucleus_image.pipeline import NucleusImagePipeline
8
+
9
+ SYSTEM = "You are an image generation assistant."
10
+ MODEL = "NucleusAI/Nucleus-Image"
11
+
12
+ prompts = [
13
+ ("apple", "A red apple on a white table"),
14
+ ("puppy", "A golden retriever puppy playing in autumn leaves"),
15
+ ("city", "A futuristic city skyline at sunset with flying cars"),
16
+ ("coffee", "A steaming cup of coffee on a rainy windowsill"),
17
+ ("astronaut", "An astronaut riding a horse on the moon, digital art"),
18
+ ]
19
+
20
+ # Encode all prompts + negative
21
+ print("Loading text encoder...")
22
+ processor = AutoProcessor.from_pretrained(MODEL, subfolder="processor", trust_remote_code=True)
23
+ text_model = AutoModel.from_pretrained(MODEL, subfolder="text_encoder", dtype=torch.bfloat16, trust_remote_code=True).eval()
24
+
25
+ def encode(prompt):
26
+ messages = [{"role": "system", "content": SYSTEM}, {"role": "user", "content": [{"type": "text", "text": prompt}]}]
27
+ formatted = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
28
+ inputs = processor(text=[formatted], return_tensors="pt", padding=True)
29
+ with torch.no_grad():
30
+ out = text_model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"),
31
+ output_hidden_states=True, use_cache=False)
32
+ return mx.array(out.hidden_states[-8][0].cpu().float().numpy())
33
+
34
+ embeddings = {}
35
+ for name, prompt in prompts:
36
+ embeddings[name] = encode(prompt)
37
+ print(f" {name}: {embeddings[name].shape}")
38
+ neg_emb = encode("")
39
+ print(f" negative: {neg_emb.shape}")
40
+
41
+ del text_model, processor; gc.collect()
42
+
43
+ # Generate
44
+ pipe = NucleusImagePipeline.from_pretrained(quantize=4)
45
+ import os; os.makedirs("samples", exist_ok=True)
46
+
47
+ for i, (name, prompt) in enumerate(prompts):
48
+ print(f"\n[{i+1}/{len(prompts)}] {prompt}")
49
+ emb = embeddings[name]
50
+ # Pad neg to match
51
+ n = neg_emb
52
+ if n.shape[0] < emb.shape[0]:
53
+ n = mx.concatenate([n, mx.zeros((emb.shape[0] - n.shape[0], 4096))], axis=0)
54
+ elif n.shape[0] > emb.shape[0]:
55
+ n = n[:emb.shape[0]]
56
+
57
+ img = pipe.generate(
58
+ text_embeddings=mx.expand_dims(emb, 0),
59
+ neg_text_embeddings=mx.expand_dims(n, 0),
60
+ height=512, width=512, num_inference_steps=30, guidance_scale=4.0, seed=42 + i,
61
+ )
62
+ img.save(f"samples/{name}.png")
63
+ print(f" Saved samples/{name}.png")
64
+
65
+ print("\nDone!")