# app.py
# --- IMPORTS: `spaces` must be imported first ---
import spaces  # required by ZeroGPU: import before torch

import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Allocator configuration must be in the environment before the first CUDA
# allocation. Recent PyTorch releases read PYTORCH_ALLOC_CONF; older releases
# only read PYTORCH_CUDA_ALLOC_CONF, so the legacy name is set as well (without
# overriding a value already provided by the environment).
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import gc
import copy
import random
import tempfile
import warnings
import uuid
import time

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr

from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    SASolverScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepInverseScheduler,
    UniPCMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
)
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.utils.export_utils import export_to_video
from torchao.quantization import (
    quantize_,
    Float8DynamicActivationFloat8WeightConfig,
    Int8WeightOnlyConfig,
)

import aoti

warnings.filterwarnings("ignore")

# --- CONSTANTS ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_DIM = 768  # target for the longer side of non-square inputs (pixels)
MIN_DIM = 480  # lower clamp applied to either output side (pixels)
SQUARE_DIM = 640  # output side used for square inputs (pixels)
MULTIPLE_OF = 16  # output sides are snapped to a multiple of this
FIXED_FPS = 16  # converts duration to frames, and is the export frame rate
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 96
MAX_SEED = np.iinfo(np.int32).max
DEFAULT_PROMPT = "make this image come alive, cinematic motion, smooth animation"


# --- VRAM CLEAR FUNCTION ---
def clear_vram():
    """Run the Python garbage collector and release PyTorch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


# --- LOAD PIPELINE ---
print("Loading Wan 2.2 I2V pipeline...")
pipe = WanImageToVideoPipeline.from_pretrained(
    "TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING",
    torch_dtype=torch.bfloat16,
)
# Untouched copy of the pipeline's default scheduler. Not referenced elsewhere
# in this file; presumably kept so a swapped scheduler can be restored.
original_scheduler = copy.deepcopy(pipe.scheduler)

# --- MEMORY OPTIMIZATIONS ---
# NOTE(review): confirm model CPU offload combines correctly with the ZeroGPU
# allocation and the AOTI-compiled transformer blocks loaded later; ZeroGPU
# examples typically move the whole pipeline with `.to("cuda")` instead.
pipe.enable_model_cpu_offload()
pipe.enable_attention_slicing("max")
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("xFormers enabled")
except Exception:
    # Best effort: keep the default attention implementation if xFormers
    # cannot be enabled.
    print("xFormers not installed, using default attention")

# --- QUANTIZATION ---
# Int8 weight-only quantization for the text encoder.
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
# FP8 dynamic-activation / FP8-weight quantization for both transformers.
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())

# --- AOTI BLOCKS LOAD ---
# Load ahead-of-time compiled blocks built for the "fp8da" variant, matching
# the quantization applied just above.
aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")

# Most extreme aspect ratios that fit inside [MIN_DIM, MAX_DIM] on both sides
# without stretching; anything beyond is center-cropped first.
MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM

# The diffusers Wan pipeline requires (num_frames - 1) to be divisible by the
# VAE temporal scale factor (4 for Wan) and rounds other values itself, with
# a warning.
VAE_TEMPORAL_MULTIPLE = 4

# Rough model of GPU time per run_inference call, used to size the ZeroGPU
# allocation: ~GPU_BASE_STEP_SECONDS per denoising step at the reference
# workload (81 frames at 832x624), scaled super-linearly with the number of
# frame-pixels, plus a fixed overhead (text encoding, VAE decode, offload
# transfers). These are estimates; tune them against measured run times.
GPU_BASE_FRAMES_PIXELS = 81 * 832 * 624
GPU_BASE_STEP_SECONDS = 15
GPU_FIXED_OVERHEAD_SECONDS = 30


# --- IMAGE RESIZE ---
def resize_image(image: Image.Image) -> Image.Image:
    """Resize an input image to a resolution the video model accepts.

    Square images become SQUARE_DIM x SQUARE_DIM. For other images the longer
    side is scaled to MAX_DIM, keeping the aspect ratio. Both sides are then
    snapped to MULTIPLE_OF and clamped to [MIN_DIM, MAX_DIM]. Images more
    extreme than MAX_DIM:MIN_DIM (or its inverse) are first center-cropped to
    that ratio, so the clamp never stretches them.

    Args:
        image: Source PIL image.

    Returns:
        A new, resized PIL image.
    """
    width, height = image.size
    if width == height:
        return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

    aspect_ratio = width / height
    if aspect_ratio > MAX_ASPECT_RATIO:
        # Too wide: trim the left/right edges around the center.
        crop_w = int(round(height * MAX_ASPECT_RATIO))
        left = (width - crop_w) // 2
        image = image.crop((left, 0, left + crop_w, height))
        target_w, target_h = MAX_DIM, MIN_DIM
    elif aspect_ratio < MIN_ASPECT_RATIO:
        # Too tall: trim the top/bottom edges around the center.
        crop_h = int(round(width / MIN_ASPECT_RATIO))
        top = (height - crop_h) // 2
        image = image.crop((0, top, width, top + crop_h))
        target_w, target_h = MIN_DIM, MAX_DIM
    elif width > height:
        target_w = MAX_DIM
        target_h = int(target_w / aspect_ratio)
    else:
        target_h = MAX_DIM
        target_w = int(target_h * aspect_ratio)

    final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
    final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
    final_w = max(MIN_DIM, min(MAX_DIM, final_w))
    final_h = max(MIN_DIM, min(MAX_DIM, final_h))
    return image.resize((final_w, final_h), Image.LANCZOS)


# --- NUMBER OF FRAMES ---
def get_num_frames(duration_seconds):
    """Return the number of frames to generate for a clip of the given length.

    The frame count at FIXED_FPS is clipped to [MIN_FRAMES_MODEL,
    MAX_FRAMES_MODEL] and rounded (half up) to a multiple of
    VAE_TEMPORAL_MULTIPLE, then one frame is added. The result has the 4k + 1
    form the pipeline expects, so it is used as-is rather than re-rounded
    inside the pipeline. This relies on both frame bounds being multiples of
    VAE_TEMPORAL_MULTIPLE.

    Args:
        duration_seconds: Requested clip length in seconds.

    Returns:
        Frame count as an int.
    """
    raw_frames = int(round(duration_seconds * FIXED_FPS))
    clipped = int(np.clip(raw_frames, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
    m = VAE_TEMPORAL_MULTIPLE
    snapped = (clipped + m // 2) // m * m
    return 1 + snapped


def _estimate_gpu_duration(resized_image, prompt, steps, num_frames, *_args, **_kwargs):
    """Estimate ZeroGPU seconds for one run_inference call.

    ``spaces.GPU`` calls this with the same arguments as run_inference, so the
    reservation scales with resolution, frame count and step count instead of
    being a flat 10 s.
    """
    frame_pixels = num_frames * resized_image.width * resized_image.height
    factor = frame_pixels / GPU_BASE_FRAMES_PIXELS
    step_seconds = GPU_BASE_STEP_SECONDS * factor ** 1.5
    return int(GPU_FIXED_OVERHEAD_SECONDS + int(steps) * step_seconds)


# --- RUN INFERENCE (ZeroGPU ready) ---
@spaces.GPU(duration=_estimate_gpu_duration)
def run_inference(resized_image, prompt, steps, num_frames, guidance_scale, guidance_scale_2, seed):
    """Generate a video from an already-resized image and write it to an MP4.

    Args:
        resized_image: PIL image; its size sets the output height/width.
        prompt: Text prompt passed to the pipeline.
        steps: Number of denoising steps (int).
        num_frames: Number of frames to generate.
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        guidance_scale_2: Passed to the pipeline as ``guidance_scale_2``.
        seed: Integer seed for the CUDA generator.

    Returns:
        Path to a temporary .mp4 file (not deleted automatically).
    """
    clear_vram()
    generator = torch.Generator(device="cuda").manual_seed(seed)
    try:
        result = pipe(
            image=resized_image,
            prompt=prompt,
            height=resized_image.height,
            width=resized_image.width,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            guidance_scale_2=guidance_scale_2,
            num_inference_steps=steps,
            generator=generator,
            output_type="np",
        )
        frames = result.frames[0]
        del result
    finally:
        # Free cached memory whether or not the pipeline call succeeded.
        clear_vram()

    # Reserve a unique path, then write after the handle is closed, so the
    # exporter never writes into a file that is still open (fails on Windows).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        video_path = tmp.name
    export_to_video(frames, video_path, fps=FIXED_FPS)
    return video_path


# --- GENERATE VIDEO ---
def generate_video(input_image, prompt=DEFAULT_PROMPT, duration_seconds=3.0, steps=6,
                   guidance_scale=1.0, guidance_scale_2=1.0, seed=42):
    """Gradio handler: validate inputs, resize, and run GPU inference.

    Raises:
        gr.Error: If no input image was provided.

    Returns:
        Path to the generated .mp4 file.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image")
    resized = resize_image(input_image)
    num_frames = get_num_frames(duration_seconds)
    # Sliders can deliver floats; the scheduler's step count and
    # torch.Generator.manual_seed both need integers.
    video_path = run_inference(
        resized,
        prompt,
        int(round(steps)),
        num_frames,
        guidance_scale,
        guidance_scale_2,
        int(seed),
    )
    return video_path


# --- GRADIO UI ---
with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.2 I2V - ZeroGPU 16GB VRAM Ready")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            prompt_input = gr.Textbox(value=DEFAULT_PROMPT, label="Prompt")
            duration_slider = gr.Slider(1, 6, value=3, label="Duration (s)")
            # step=1 keeps these integer-valued.
            steps_slider = gr.Slider(1, 12, value=6, step=1, label="Steps")
            g1_slider = gr.Slider(0, 5, value=1, label="Guidance 1")
            g2_slider = gr.Slider(0, 5, value=1, label="Guidance 2")
            seed_slider = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
            generate_btn = gr.Button("Generate Video")
        with gr.Column():
            video_output = gr.Video(label="Generated Video")

    generate_btn.click(
        generate_video,
        inputs=[image_input, prompt_input, duration_slider, steps_slider,
                g1_slider, g2_slider, seed_slider],
        outputs=video_output,
    )

# --- LAUNCH ---
if __name__ == "__main__":
    demo.queue().launch(show_error=True)