"""Full test: text encoder (PyTorch) → DiT (MLX) → VAE (MLX) → image. Loads text encoder first, extracts embeddings, deletes it, then loads DiT+VAE for generation. Avoids 50GB simultaneous memory. """ import sys import time import numpy as np import torch import mlx.core as mx sys.path.insert(0, ".") PROMPT = "A vibrant 4-panel manga comic strip about a cat discovering a tiny dragon" OUT = "/Users/ritesh/Dev/model-training/nucleus-image/mlx/test-output" # Step 1: Extract text embeddings print("Loading text encoder (PyTorch)...") from transformers import AutoProcessor, AutoModel processor = AutoProcessor.from_pretrained("NucleusAI/Nucleus-Image", subfolder="processor", trust_remote_code=True) text_model = AutoModel.from_pretrained( "NucleusAI/Nucleus-Image", subfolder="text_encoder", dtype=torch.bfloat16, trust_remote_code=True ) text_model.eval() print("Encoding text...") # Format with system prompt (matching diffusers pipeline) SYSTEM_PROMPT = "You are an image generation assistant. Follow the user's prompt literally. Pay careful attention to spatial layout: objects described as on the left must appear on the left, on the right on the right. Match exact object counts and assign colors to the correct objects." messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": [{"type": "text", "text": PROMPT}]}, ] formatted = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = processor(text=[formatted], return_tensors="pt", padding=True) with torch.no_grad(): outputs = text_model( input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"), output_hidden_states=True, use_cache=False, ) # Use -8 (8th from last) matching diffusers default_return_index hidden = outputs.hidden_states[-8][0] # [T, 4096] print(f" Hidden states: {len(outputs.hidden_states)} layers, using [-8]") emb_np = hidden.cpu().float().numpy() print(f"Text embedding: {emb_np.shape}") # Free text encoder del text_model, processor, inputs, outputs torch.mps.empty_cache() if torch.backends.mps.is_available() else None import gc; gc.collect() print("Text encoder freed.") # Step 2: Generate with MLX print("\nLoading DiT + VAE (MLX, 4-bit)...") from nucleus_image.pipeline import NucleusImagePipeline pipe = NucleusImagePipeline.from_pretrained(quantize=4) text_emb = mx.array(emb_np) print(f"\nGenerating 512x512, 20 steps, CFG 4.0...") img = pipe.generate( text_embeddings=mx.expand_dims(text_emb, 0), height=512, width=512, num_inference_steps=20, guidance_scale=4.0, seed=42, ) img.save(f"{OUT}/full_test_512.png") print(f"Saved! {img.size}")