File size: 1,747 Bytes
f43db1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""Tiny test: 128x128, 4 steps, no CFG, with text embeddings.

Encodes a fixed prompt with the Nucleus-Image text encoder (PyTorch /
transformers), converts one of its hidden states to an MLX array, then runs
the 4-bit-quantized MLX pipeline on it and saves the image as a PNG.

Run from the directory that contains the ``nucleus_image`` package: the
script puts ``"."`` on ``sys.path`` and writes its output relative to the
current working directory.
"""
import gc
import sys
import time
from pathlib import Path

import torch
import numpy as np  # noqa: F401  (kept from the original script; currently unused)
import mlx.core as mx
from transformers import AutoProcessor, AutoModel

sys.path.insert(0, ".")

MODEL_ID = "NucleusAI/Nucleus-Image"
PROMPT = "A red apple on a white table"
SYSTEM = "You are an image generation assistant."
# Entry of ``outputs.hidden_states`` used as the conditioning signal (counted
# from the end). NOTE(review): presumably this must match the layer the image
# pipeline was trained against — confirm before changing.
HIDDEN_STATE_INDEX = -8
# Relative to the CWD, consistent with the ``sys.path.insert(0, ".")`` above.
OUTPUT_PATH = Path("test-output") / "tiny_test.png"


def _encode_prompt(prompt, system):
    """Run the text encoder on *prompt* and return one hidden state as an MLX array.

    Everything PyTorch-side (processor, model, tokenized inputs, and the
    model outputs, whose ``hidden_states`` keeps every layer's activations)
    is local to this function, so it all becomes unreachable on return and
    can be reclaimed before the image pipeline is loaded.
    """
    processor = AutoProcessor.from_pretrained(
        MODEL_ID, subfolder="processor", trust_remote_code=True
    )
    text_model = AutoModel.from_pretrained(
        MODEL_ID,
        subfolder="text_encoder",
        dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    text_model.eval()

    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]
    formatted = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=[formatted], return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = text_model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            output_hidden_states=True,
            use_cache=False,
        )
        # Batch of one prompt: take sample 0.
        hidden = outputs.hidden_states[HIDDEN_STATE_INDEX][0]
    # NumPy has no bfloat16, so upcast to float32 before handing off to MLX.
    return mx.array(hidden.cpu().float().numpy())


# Extract text embeddings
text_emb = _encode_prompt(PROMPT, SYSTEM)
print(f"Text: {text_emb.shape}")
# The encoder's torch objects went out of scope with _encode_prompt; collect
# now so their memory is released before the image pipeline is loaded.
gc.collect()

# Generate (imported only now, after "." is on sys.path and the encoder is gone)
from nucleus_image.pipeline import NucleusImagePipeline

pipe = NucleusImagePipeline.from_pretrained(quantize=4)

print("Generating 128x128, 4 steps, no CFG...")
t0 = time.perf_counter()  # monotonic clock for measuring elapsed time
img = pipe.generate(
    text_embeddings=mx.expand_dims(text_emb, 0),  # prepend a batch axis
    height=128,
    width=128,
    num_inference_steps=4,
    guidance_scale=1.0,  # no CFG for speed
    seed=42,
)
print(f"Done in {time.perf_counter() - t0:.1f}s")

# Create the output directory if needed; saving into a missing dir would fail.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
img.save(str(OUTPUT_PATH))
print(f"Saved! {img.size}")