treadon commited on
Commit
f43db1e
·
verified ·
1 Parent(s): 6791033

Upload test_tiny.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_tiny.py +43 -0
test_tiny.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tiny test: 128x128, 4 steps, no CFG, with text embeddings."""
2
+ import sys, time, torch, numpy as np
3
+ import mlx.core as mx
4
+ sys.path.insert(0, ".")
5
+
6
+ # Extract text embeddings
7
+ from transformers import AutoProcessor, AutoModel
8
+ PROMPT = "A red apple on a white table"
9
+ SYSTEM = "You are an image generation assistant."
10
+
11
+ processor = AutoProcessor.from_pretrained("NucleusAI/Nucleus-Image", subfolder="processor", trust_remote_code=True)
12
+ text_model = AutoModel.from_pretrained("NucleusAI/Nucleus-Image", subfolder="text_encoder", dtype=torch.bfloat16, trust_remote_code=True)
13
+ text_model.eval()
14
+
15
+ messages = [{"role": "system", "content": SYSTEM}, {"role": "user", "content": [{"type": "text", "text": PROMPT}]}]
16
+ formatted = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
17
+ inputs = processor(text=[formatted], return_tensors="pt", padding=True)
18
+
19
+ with torch.no_grad():
20
+ outputs = text_model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"), output_hidden_states=True, use_cache=False)
21
+ hidden = outputs.hidden_states[-8][0]
22
+ text_emb = mx.array(hidden.cpu().float().numpy())
23
+ print(f"Text: {text_emb.shape}")
24
+
25
+ del text_model, processor
26
+ import gc; gc.collect()
27
+
28
+ # Generate
29
+ from nucleus_image.pipeline import NucleusImagePipeline
30
+ pipe = NucleusImagePipeline.from_pretrained(quantize=4)
31
+
32
+ print("Generating 128x128, 4 steps, no CFG...")
33
+ t0 = time.time()
34
+ img = pipe.generate(
35
+ text_embeddings=mx.expand_dims(text_emb, 0),
36
+ height=128, width=128,
37
+ num_inference_steps=4,
38
+ guidance_scale=1.0, # no CFG for speed
39
+ seed=42,
40
+ )
41
+ print(f"Done in {time.time()-t0:.1f}s")
42
+ img.save("/Users/ritesh/Dev/model-training/nucleus-image/mlx/test-output/tiny_test.png")
43
+ print(f"Saved! {img.size}")