"""Gradio app that trains a LoRA adapter for the Z-Image Turbo pipeline."""

import os
import gc
import re
import shutil
import json
import tempfile
from pathlib import Path

import torch
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download, HfApi

# Training imports
from diffusers import AutoencoderKL
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from tqdm.auto import tqdm

# Global state
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Fallback file stem used when the user leaves "Output Name" empty.
DEFAULT_OUTPUT_NAME = "z_image_lora"
# Smallest dataset the trainer accepts.
MIN_TRAINING_IMAGES = 3


def check_gpu() -> str:
    """Return a one-line description of the first CUDA device, or an upgrade hint."""
    if not torch.cuda.is_available():
        return "No GPU detected - please upgrade to L4 GPU for training"
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    return f"GPU: {gpu_name} ({gpu_mem:.1f}GB)"


def _open_as_rgb(item):
    """Return an RGB PIL image for one uploaded item, or None if it is unusable.

    Accepts a file path, a numpy array, a PIL image, or a (media, caption)
    pair wrapping any of those.
    """
    # Gradio's Gallery may deliver each entry as a (media, caption) pair;
    # only the media part is relevant here.
    if isinstance(item, (tuple, list)):
        item = item[0] if item else None
    if item is None:
        return None
    if isinstance(item, (str, os.PathLike)):
        try:
            # convert() loads the pixel data, so the file can be closed right away.
            with Image.open(item) as handle:
                return handle.convert("RGB")
        except OSError:
            # Missing, truncated or non-image file: skip it instead of aborting.
            return None
    if isinstance(item, np.ndarray):
        return Image.fromarray(item).convert("RGB")
    return item if item.mode == "RGB" else item.convert("RGB")


def prepare_dataset(images, trigger_word, output_dir):
    """Write the uploaded images and their captions into ``output_dir/dataset``.

    Each usable image is saved as ``image_NNNN.jpg`` (NNNN is its position in
    ``images``) next to an ``image_NNNN.txt`` caption holding ``trigger_word``.
    Entries that are ``None`` or cannot be opened are skipped.

    Returns:
        A ``(image_paths, dataset_dir)`` tuple: the saved JPEG paths as
        strings and the dataset directory as a string.
    """
    dataset_dir = Path(output_dir) / "dataset"
    dataset_dir.mkdir(parents=True, exist_ok=True)

    image_paths = []
    for i, item in enumerate(images or []):
        img_pil = _open_as_rgb(item)
        if img_pil is None:
            continue

        stem = f"image_{i:04d}"
        img_path = dataset_dir / f"{stem}.jpg"
        img_pil.save(img_path, quality=95)
        (dataset_dir / f"{stem}.txt").write_text(f"{trigger_word}", encoding="utf-8")
        image_paths.append(str(img_path))

    return image_paths, str(dataset_dir)


def _sanitize_output_name(name) -> str:
    """Turn a user-supplied name into a safe, lowercase file stem."""
    cleaned = (name or "").strip().replace(" ", "_").lower()
    # Path separators and other unexpected characters become underscores so the
    # name can never point outside the output directory.
    cleaned = re.sub(r"[^a-z0-9._-]", "_", cleaned).strip("._")
    return cleaned or DEFAULT_OUTPUT_NAME


def _load_image_tensor(path, resolution: int) -> torch.Tensor:
    """Load ``path`` as a float CHW tensor in [0, 1], resized to a square."""
    with Image.open(path) as handle:
        img = handle.convert("RGB").resize((resolution, resolution), Image.LANCZOS)
    return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0


def _run_training(
    image_paths,
    trigger_word,
    output_name,
    num_steps,
    learning_rate,
    lora_rank,
    resolution,
    batch_size,
    workdir,
    progress,
):
    """Load the pipeline, train the LoRA and save it; return the output path.

    Kept separate from ``train_lora`` so that every large object created here
    (pipeline, optimizer, latents) is unreferenced by the time the caller
    runs its memory cleanup.
    """
    # Download training adapter from ostris
    progress(0.15, desc="Downloading Z-Image training adapter...")
    adapter_path = hf_hub_download(
        repo_id="ostris/zimage_turbo_training_adapter",
        filename="training_adapter_v1.safetensors",
        local_dir=workdir,
    )

    # Load model components
    progress(0.2, desc="Loading Z-Image Turbo model...")
    from diffusers import ZImagePipeline

    pipe = ZImagePipeline.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo",
        torch_dtype=DTYPE,
    )

    # Configure LoRA and wrap the transformer with it
    progress(0.3, desc="Configuring LoRA...")
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_rank,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        lora_dropout=0.0,
    )
    pipe.transformer = get_peft_model(pipe.transformer, lora_config)
    pipe.transformer.print_trainable_parameters()

    pipe.to(DEVICE)

    # Gradient checkpointing for memory efficiency
    pipe.transformer.enable_gradient_checkpointing()

    # Only hand the optimizer the parameters that will actually receive gradients.
    trainable_params = [p for p in pipe.transformer.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=0.01)

    progress(0.35, desc="Loading training adapter...")
    from safetensors.torch import load_file, save_file

    # NOTE(review): the adapter weights are loaded but never applied to the
    # transformer below — confirm how they are meant to be merged for training.
    adapter_weights = load_file(adapter_path)

    # The prompt is the same for every step, so encode it once up front.
    text_inputs = pipe.tokenizer(
        [trigger_word] * batch_size,
        padding="max_length",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)
    with torch.no_grad():
        text_embeddings = pipe.text_encoder(**text_inputs)[0]

    progress(0.4, desc="Starting training...")
    pipe.transformer.train()

    for step in range(num_steps):
        # Draw a random batch (with replacement) from the dataset.
        batch_paths = np.random.choice(image_paths, size=batch_size)
        pixels = torch.stack([_load_image_tensor(p, resolution) for p in batch_paths])
        pixels = pixels.to(DEVICE, dtype=DTYPE)
        pixels = 2.0 * pixels - 1.0  # Normalize to [-1, 1]

        with torch.no_grad():
            latents = pipe.vae.encode(pixels).latent_dist.sample()
            latents = latents * pipe.vae.config.scaling_factor

        # One random timestep per sample, then noise the latents accordingly.
        timesteps = torch.randint(0, 1000, (batch_size,), device=DEVICE).long()
        noise = torch.randn_like(latents)
        # NOTE(review): this simplified step assumes the scheduler exposes
        # add_noise() and that the transformer accepts these arguments and
        # returns an object with `.sample` — confirm against the installed
        # diffusers Z-Image API.
        noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

        with torch.autocast(device_type="cuda", dtype=DTYPE):
            noise_pred = pipe.transformer(
                noisy_latents,
                encoder_hidden_states=text_embeddings,
                timestep=timesteps,
            ).sample

        loss = torch.nn.functional.mse_loss(noise_pred, noise)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 50 == 0:
            progress(
                0.4 + 0.5 * (step / num_steps),
                desc=f"Training step {step}/{num_steps} - Loss: {loss.item():.4f}",
            )

        # Clear cache periodically
        if step % 100 == 0:
            torch.cuda.empty_cache()

    progress(0.95, desc="Saving LoRA weights...")
    # The result must outlive the temporary working directory, so it goes
    # directly into the system temp dir.
    final_output = str(Path(tempfile.gettempdir()) / f"{output_name}.safetensors")
    lora_state_dict = {
        name: param.detach().cpu().contiguous()
        for name, param in pipe.transformer.named_parameters()
        if "lora" in name.lower()
    }
    save_file(lora_state_dict, final_output)
    return final_output


def train_lora(
    images,
    trigger_word,
    output_name,
    num_steps,
    learning_rate,
    lora_rank,
    resolution,
    batch_size,
    progress=gr.Progress()
):
    """Train a LoRA for Z-Image Turbo.

    Args:
        images: Uploaded training images (see ``prepare_dataset``).
        trigger_word: Caption / prompt used for every training image.
        output_name: Desired file stem for the saved LoRA; sanitized before use.
        num_steps: Number of optimizer steps.
        learning_rate: AdamW learning rate.
        lora_rank: LoRA rank (also used as ``lora_alpha``).
        resolution: Square training resolution in pixels.
        batch_size: Number of images per training step.
        progress: Gradio progress callback.

    Returns:
        ``(path, message)``: the path of the saved ``.safetensors`` file and a
        status message, or ``(None, error_message)`` on any failure.
    """
    if not torch.cuda.is_available():
        return None, "Error: No GPU available. Please upgrade to L4 GPU ($0.80/hr) in Space settings."

    if not images or len(images) < MIN_TRAINING_IMAGES:
        return None, "Error: Please upload at least 3 training images."

    trigger_word = (trigger_word or "").strip()
    if not trigger_word:
        return None, "Error: Please specify a trigger word."

    output_name = _sanitize_output_name(output_name)

    # UI sliders may deliver floats; range(), LoRA rank and image sizes need ints.
    num_steps = int(num_steps)
    lora_rank = int(lora_rank)
    resolution = int(resolution)
    batch_size = max(1, int(batch_size))
    learning_rate = float(learning_rate)

    progress(0, desc="Initializing...")

    with tempfile.TemporaryDirectory() as tmpdir:
        try:
            progress(0.05, desc="Preparing dataset...")
            image_paths, _dataset_dir = prepare_dataset(images, trigger_word, tmpdir)

            if len(image_paths) < MIN_TRAINING_IMAGES:
                return None, "Error: Not enough valid images. Please upload at least 3 images."

            progress(0.1, desc=f"Dataset prepared: {len(image_paths)} images")

            final_output = _run_training(
                image_paths,
                trigger_word,
                output_name,
                num_steps,
                learning_rate,
                lora_rank,
                resolution,
                batch_size,
                tmpdir,
                progress,
            )
        except Exception as e:
            return None, f"Error during training: {str(e)}"
        finally:
            # Runs on success and failure alike, after the training helper's
            # locals have gone out of scope.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    progress(1.0, desc="Training complete!")
    return final_output, f"Training complete! LoRA saved as {output_name}.safetensors\n\nUse trigger word: {trigger_word}"


# Gradio UI
with gr.Blocks(title="Z-Image LoRA Trainer") as demo:
    gr.Markdown("""
    # Z-Image LoRA Trainer

    Train custom LoRA models for Z-Image Turbo. Upload your training images and configure the settings below.

    **Important**: This Space requires an L4 GPU ($0.80/hr) for training. Upgrade in Settings before starting.
    """)

    with gr.Row():
        gpu_status = gr.Textbox(label="GPU Status", value=check_gpu(), interactive=False)
        refresh_btn = gr.Button("Refresh GPU Status")
        refresh_btn.click(fn=check_gpu, outputs=gpu_status)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Training Images")
            images = gr.Gallery(
                label="Upload 6-20 training images",
                columns=4,
                height=300,
                type="filepath",
                interactive=True
            )

            gr.Markdown("""
            **Tips for best results:**
            - Use 6-20 high quality images
            - Consistent lighting and style
            - Square or similar aspect ratios
            - Clear subjects, not too busy
            """)

        with gr.Column(scale=1):
            gr.Markdown("### Training Settings")

            trigger_word = gr.Textbox(
                label="Trigger Word",
                placeholder="e.g., sks_person",
                info="Use a unique word that doesn't exist in normal vocabulary"
            )

            output_name = gr.Textbox(
                label="Output Name",
                placeholder="my_lora",
                info="Name for your LoRA file"
            )

            with gr.Row():
                num_steps = gr.Slider(
                    minimum=500,
                    maximum=5000,
                    value=2000,
                    step=100,
                    label="Training Steps",
                    info="More steps = better quality but longer training"
                )

                learning_rate = gr.Slider(
                    minimum=1e-5,
                    maximum=5e-4,
                    value=1e-4,
                    step=1e-5,
                    label="Learning Rate"
                )

            with gr.Row():
                lora_rank = gr.Slider(
                    minimum=4,
                    maximum=32,
                    value=8,
                    step=4,
                    label="LoRA Rank",
                    info="Higher = more capacity but more VRAM"
                )

                resolution = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=128,
                    label="Training Resolution",
                    info="Lower = faster and less VRAM"
                )

            batch_size = gr.Slider(
                minimum=1,
                maximum=4,
                value=1,
                step=1,
                label="Batch Size",
                info="Keep at 1 for L4 GPU"
            )

    with gr.Row():
        train_btn = gr.Button("Start Training", variant="primary", size="lg")

    with gr.Row():
        output_file = gr.File(label="Download LoRA")
        output_log = gr.Textbox(label="Training Log", lines=5)

    train_btn.click(
        fn=train_lora,
        inputs=[images, trigger_word, output_name, num_steps, learning_rate, lora_rank, resolution, batch_size],
        outputs=[output_file, output_log]
    )

    gr.Markdown("""
    ---
    ### Usage Instructions

    1. **Upload Images**: Drag and drop 6-20 training images
    2. **Set Trigger Word**: Choose a unique trigger word (e.g., `sks_person`)
    3. **Configure Settings**: Adjust training parameters (defaults work well for most cases)
    4. **Start Training**: Click the button and wait (1-2 hours for 2000 steps)
    5. **Download**: Get your LoRA file when training completes

    ### Estimated Costs (L4 GPU at $0.80/hr)
    - 1000 steps: ~$0.40-0.60
    - 2000 steps: ~$0.80-1.20
    - 3000 steps: ~$1.20-1.80

    ### After Training
    Use your LoRA with ComfyUI or diffusers:
    ```python
    pipe.load_lora_weights("your_lora.safetensors")
    image = pipe("sks_person, your prompt here").images[0]
    ```
    """)

if __name__ == "__main__":
    demo.launch()