# ============================================================================= # ENVIRONMENT SETUP # ============================================================================= import spaces import shutil import os os.environ["HF_HOME"] = "/tmp/hf" os.environ["HF_HUB_CACHE"] = "/tmp/hf/hub" os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" def safe_rmtree(path): """Delete folder and confirm deletion.""" try: shutil.rmtree(path, ignore_errors=True) except Exception as e: print(f"❌ Failed to delete {path}: {e}") if not os.path.exists(path): print(f"✅ Successfully deleted {path}") else: print(f"⚠️ Still exists: {path}") safe_rmtree("/tmp/hf") # ============================================================================= # IMPORTS # ============================================================================= import gc import random import tempfile import torch print(f"PyTorch: {torch.__version__}") print(f"CUDA: {torch.cuda.is_available()}") import gradio as gr import aoti import sys from PIL import Image from torch.nn.attention import sdpa_kernel, SDPBackend from huggingface_hub import snapshot_download, hf_hub_download from diffusers import WanImageToVideoPipeline, WanTransformer3DModel from diffusers.utils import export_to_video from diffusers import DPMSolverMultistepScheduler, AutoencoderKLWan from transformers import UMT5EncoderModel from torchao.quantization import quantize_, Int8WeightOnlyConfig import logging from datetime import datetime from pathlib import Path current_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.join(current_dir, "MMAudio")) from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video, setup_eval_logging) from mmaudio.model.flow_matching import FlowMatching from mmaudio.model.networks import MMAudio, get_my_mmaudio from mmaudio.model.sequence_config import SequenceConfig from mmaudio.model.utils.features_utils import FeaturesUtils # 
# =============================================================================
# AUTH & CONSTANTS
# =============================================================================
num_blocks = 8
HF_TOKEN = os.getenv("HF_TOKEN")
MY_VAULT_REPO = "ibyteohdear/mmaudio-weights-vault"
MODEL_ID = "ibyteohdear/Wan2.2-I2V-14B-Lightning"
AOTI_REPO = "ibyteohdear/Wan2.2-AOTI-Weights-zero-lora-v1"
SAFE_REPO = "ibyteohdear/Wan2.2-I2V-14B-Lightning-Safe"
STEP_REPO = "ibyteohdear/wan2.2-i2v-lightx2v-260412"
LORA_REPO = "ibyteohdear/Wan-2.2-LoRA"
device = "cuda"
dtype = torch.bfloat16


# =============================================================================
# PIPELINE
# =============================================================================
def _fetch_component(repo_id, folder):
    """Download only ``folder/*`` from ``repo_id``; return the snapshot path."""
    return snapshot_download(
        repo_id=repo_id,
        token=HF_TOKEN,
        allow_patterns=[f"{folder}/*"],
    )


# Components are fetched and loaded one at a time, with a garbage-collection
# pass between them, before the pipeline object is assembled.
print("📥 Downloading transformer...")
CORE_PATH = _fetch_component(SAFE_REPO, "transformer")
transformer = WanTransformer3DModel.from_pretrained(
    CORE_PATH,
    subfolder="transformer",
    torch_dtype=dtype,
)
gc.collect()

print("📥 Downloading transformer 2...")
CORE_PATH_2 = _fetch_component(SAFE_REPO, "transformer_2")
transformer_2 = WanTransformer3DModel.from_pretrained(
    CORE_PATH_2,
    subfolder="transformer_2",
    torch_dtype=dtype,
)
gc.collect()

print("📥 Downloading text encoder...")
CORE_PATH_3 = _fetch_component(MODEL_ID, "text_encoder")
text_encoder = UMT5EncoderModel.from_pretrained(
    CORE_PATH_3,
    subfolder="text_encoder",
    torch_dtype=dtype,
)
gc.collect()

print("📥 Downloading vae...")
CORE_PATH_4 = _fetch_component(MODEL_ID, "vae")
vae = AutoencoderKLWan.from_pretrained(
    CORE_PATH_4,
    subfolder="vae",
    torch_dtype=dtype,
)

# All four components are loaded at this point; clear the download cache
# before the remaining pipeline files are fetched.
safe_rmtree("/tmp/hf")
gc.collect()

pipeline = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    transformer=transformer,
    transformer_2=transformer_2,
    text_encoder=text_encoder,
    vae=vae,
    torch_dtype=dtype,
)
torch.cuda.empty_cache()
gc.collect()
pipeline.to(device)

# =============================================================================
# LORA
# =============================================================================
print("📥 Downloading loras...")


def _lora1(path_in_repo: str) -> str:
    """Download ``path_in_repo`` from STEP_REPO and return its local path."""
    local_file = hf_hub_download(
        repo_id=STEP_REPO,
        filename=path_in_repo,
        token=HF_TOKEN,
    )
    return local_file


def _lora2(path_in_repo: str) -> str:
    """Download ``path_in_repo`` from LORA_REPO and return its local path."""
    local_file = hf_hub_download(
        repo_id=LORA_REPO,
        filename=path_in_repo,
        token=HF_TOKEN,
    )
    return local_file


def _noise_pair(fetch, high_file, low_file):
    """Build one table entry: ``{"high": (path, name), "low": (path, name)}``.

    The high-noise file is downloaded before the low-noise one.
    """
    high_entry = (fetch(high_file), high_file)
    low_entry = (fetch(low_file), low_file)
    return {"high": high_entry, "low": low_entry}


# Each entry maps a LoRA name to its high-/low-noise files as
# (local path, file name) tuples.
LightX2V_LORAS = {
    "LightX2V": _noise_pair(
        _lora1,
        "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_720p_260412.safetensors",
        "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_720p_260412.safetensors",
    ),
}

# "None" is the placeholder choice meaning "no extra LoRA".
AVAILABLE_LORAS = {
    "None": None,
    "TiltUpOverhead": _noise_pair(
        _lora2,
        "highnoiseCamera20tiltUp20overhead.mxnq.safetensors",
        "lownoiseCamera20tiltUp20overhead.H9gl.safetensors",
    ),
    "TiltDownUndershot": _noise_pair(
        _lora2,
        "v2HighCamera20tiltDown20undershot.rVOE.safetensors",
        "v2Camera20tiltDown20undershotE5.alDA.safetensors",
    ),
    "AnimeStyle": _noise_pair(
        _lora2,
        "wan2.2_i2v_animestyle_v2_high.safetensors",
        "wan2.2_i2v_animestyle_v2_low.safetensors",
    ),
    "Walking": _noise_pair(
        _lora2,
        "wlkng high 260512.safetensors",
        "wlkng low 260512.safetensors",
    ),
}
print("💾 Preloading LoRAs...")
ALL_LORAS = LightX2V_LORAS | AVAILABLE_LORAS

# Register every LoRA once at startup under "<name>_high" / "<name>_low".
# NOTE(review): both variants go through the pipeline's default LoRA target.
# diffusers' Wan loader accepts `load_into_transformer_2=True` to target the
# second transformer; since apply_lora() activates the "*_low" adapters on
# pipeline.transformer_2, confirm they actually end up there.
for lora_name, variants in ALL_LORAS.items():
    if lora_name == "None":
        # Placeholder entry for "no extra LoRA"; there is nothing to load.
        continue
    for noise_level in ("high", "low"):
        local_file, file_name = variants[noise_level]
        pipeline.load_lora_weights(
            local_file,
            weight_name=file_name,
            adapter_name=f"{lora_name}_{noise_level}",
            low_cpu_mem_usage=True,
            torch_dtype=dtype,
        )
pipeline.disable_lora()

# =============================================================================
# QUANTIZE
# =============================================================================
print("💾 Quantize...")
# Int8 weight-only quantization for the text encoder and for every
# transformer block from index `num_blocks` onward.
quantize_(pipeline.text_encoder, Int8WeightOnlyConfig(version=2))
for denoiser in (pipeline.transformer, pipeline.transformer_2):
    quantize_(
        denoiser.blocks[num_blocks:],
        Int8WeightOnlyConfig(version=2),
    )

# =============================================================================
# AOTI
# =============================================================================
# Load the AOTI artifacts for the first `num_blocks` blocks of each
# transformer from the matching subfolder of AOTI_REPO.
for component_name in ("transformer", "transformer_2"):
    aoti.aoti_blocks_load(
        getattr(pipeline, component_name),
        AOTI_REPO,
        num_blocks=num_blocks,
        subfolder=component_name,
    )

# =============================================================================
# SETTINGS
# =============================================================================
# Left disabled on purpose: pipeline.enable_model_cpu_offload(),
# pipeline.vae.enable_tiling(), pipeline.vae.enable_slicing(),
# pipeline.enable_attention_slicing().
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config
)
torch.set_float32_matmul_precision("high")
# =============================================================================
# MMAUDIO
# =============================================================================
log = logging.getLogger()
device_mm = 'cuda'
dtype_mm = torch.bfloat16

model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')

setup_eval_logging()
output_dir.mkdir(exist_ok=True, parents=True)


def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the MMAudio network and its feature extractor.

    Both modules are moved to ``device_mm`` / ``dtype_mm`` and switched to
    eval mode; the network weights are read from ``model.model_path``.

    Returns:
        ``(network, feature_extractor, sequence_config)``.
    """
    network: MMAudio = get_my_mmaudio(model.model_name)
    network = network.to(device_mm, dtype_mm).eval()
    checkpoint = torch.load(model.model_path, map_location=device_mm, weights_only=True)
    network.load_weights(checkpoint)
    log.info(f'Loaded weights from {model.model_path}')

    extractor = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False,
    )
    extractor = extractor.to(device_mm, dtype_mm).eval()
    return network, extractor, model.seq_cfg


net, feature_utils, seq_cfg = get_model()


@torch.inference_mode()
def video_to_audio(video, prompt, negative_prompt, seed, num_steps, cfg_strength, duration):
    """Generate audio for ``video`` and write a new MP4 that includes it.

    Args:
        video: Path of the input video, passed to ``load_video``.
        prompt: Text describing the wanted audio.
        negative_prompt: Text describing the unwanted audio.
        seed: Generator seed; a negative value draws a random seed instead.
        num_steps: Number of Euler flow-matching steps.
        cfg_strength: Guidance strength forwarded to ``generate``.
        duration: Requested duration in seconds, passed to ``load_video``.

    Returns:
        Path of the written file: ``output_dir/<YYYYmmdd_HHMMSS>.mp4``.
    """
    noise_source = torch.Generator(device=device_mm)
    if seed >= 0:
        noise_source.manual_seed(seed)
    else:
        noise_source.seed()

    flow = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    clip = load_video(video, duration)
    clip_batch = clip.clip_frames.unsqueeze(0)
    sync_batch = clip.sync_frames.unsqueeze(0)

    # The shared sequence config follows the duration reported for the loaded
    # clip, and the network's sequence lengths are refreshed to match.
    seq_cfg.duration = clip.duration_sec
    net.update_seq_lengths(
        seq_cfg.latent_seq_len,
        seq_cfg.clip_seq_len,
        seq_cfg.sync_seq_len,
    )

    batch_audio = generate(
        clip_batch,
        sync_batch,
        [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils,
        net=net,
        fm=flow,
        rng=noise_source,
        cfg_strength=cfg_strength,
    )
    waveform = batch_audio.float().cpu()[0]

    output_dir.mkdir(exist_ok=True, parents=True)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    path = output_dir / f"{stamp}.mp4"
    make_video(clip, path, waveform, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return path
# =============================================================================
# GENERATION FUNCTION
# =============================================================================
def aligned_num_frames(duration, fps=16):
    """Return the largest frame count of the form ``4*k + 1`` within ``duration``.

    Args:
        duration: Clip length in seconds.
        fps: Frames per second.

    Returns:
        ``4*k + 1`` not exceeding ``int(duration * fps)``, and never below 1,
        so very short durations cannot produce a zero or negative count.
    """
    n = int(duration * fps)
    return max(1, ((n - 1) // 4) * 4 + 1)


def aspect_correct_pad(image, target_width=832, target_height=512):
    """Fit ``image`` into the target size without distortion, padding in black.

    The image is scaled by the single factor that makes it fit inside
    ``target_width`` x ``target_height`` and pasted centred on a black RGB
    canvas of exactly that size.

    Args:
        image: PIL image to fit.
        target_width: Canvas width in pixels.
        target_height: Canvas height in pixels.

    Returns:
        A new RGB PIL image of size ``(target_width, target_height)``.
    """
    scale = min(target_width / image.width, target_height / image.height)
    # Truncation can reach 0 for extreme aspect ratios, and resizing to a
    # zero-sized image fails, so keep each side at least one pixel.
    new_w = max(1, int(image.width * scale))
    new_h = max(1, int(image.height * scale))
    resized_img = image.resize((new_w, new_h), Image.LANCZOS)
    canvas = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_w) // 2
    paste_y = (target_height - new_h) // 2
    canvas.paste(resized_img, (paste_x, paste_y))
    return canvas


def apply_lora(
    pipe,
    lora_name,
    high_strength,
    low_strength,
):
    """Activate the LightX2V adapters plus an optional extra LoRA on ``pipe``.

    Both strengths are clamped to ``[0, 1]``. The LightX2V adapter weight is
    ``1 - strength`` with a floor of ``0.2``, so it is never switched off.
    "*_high" adapters are set on ``pipe.transformer`` and "*_low" adapters on
    ``pipe.transformer_2``.

    Args:
        pipe: Pipeline exposing ``enable_lora``, ``transformer`` and
            ``transformer_2``.
        lora_name: Name of the extra LoRA, or ``"None"`` for LightX2V only.
        high_strength: Weight of the extra LoRA's high-noise adapter.
        low_strength: Weight of the extra LoRA's low-noise adapter.
    """
    high_strength = min(max(high_strength, 0.0), 1.0)
    low_strength = min(max(low_strength, 0.0), 1.0)
    lightning_high = max(0.2, 1.0 - high_strength)
    lightning_low = max(0.2, 1.0 - low_strength)

    pipe.enable_lora()

    high_adapters = ["LightX2V_high"]
    high_weights = [lightning_high]
    low_adapters = ["LightX2V_low"]
    low_weights = [lightning_low]
    if lora_name != "None":
        high_adapters.append(f"{lora_name}_high")
        high_weights.append(high_strength)
        low_adapters.append(f"{lora_name}_low")
        low_weights.append(low_strength)

    pipe.transformer.set_adapters(
        high_adapters,
        high_weights,
    )
    pipe.transformer_2.set_adapters(
        low_adapters,
        low_weights,
    )


LORA_NAMES = list(AVAILABLE_LORAS.keys())


@spaces.GPU(duration=1000)
def infer(
    image,
    prompt,
    negative_prompt,
    prompt_audio,
    negative_prompt_audio,
    duration,
    guidance,
    randomize,
    lora_name,
    high_strength,
    low_strength,
    progress=gr.Progress(track_tqdm=True)
):
    """Turn one image into a video with generated audio.

    Args:
        image: Input PIL image; ``None`` raises ``gr.Error``.
        prompt: Video prompt.
        negative_prompt: Video negative prompt.
        prompt_audio: Audio prompt.
        negative_prompt_audio: Audio negative prompt.
        duration: Requested clip length in seconds.
        guidance: Guidance scale for the video pipeline.
        randomize: Draw a random seed when true, otherwise use 42.
        lora_name: Extra LoRA to apply, or ``"None"``.
        high_strength: High-noise weight for the extra LoRA.
        low_strength: Low-noise weight for the extra LoRA.
        progress: Gradio progress tracker.

    Returns:
        Path of the final MP4 produced by ``video_to_audio``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        raise gr.Error("Upload an image")

    progress(0, desc="Starting...")
    width, height = 832, 512
    image = aspect_correct_pad(image, width, height)

    FPS = 16
    num_frames = aligned_num_frames(duration, FPS)

    seed = random.randint(0, 1_000_000_000) if randomize else 42
    generator = torch.Generator(device=device).manual_seed(seed)

    apply_lora(
        pipeline,
        lora_name,
        high_strength,
        low_strength,
    )

    progress(0.05, desc="Performing diffusion... this could take a awhile..")
    with torch.no_grad(), sdpa_kernel(
        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
    ):
        frames = pipeline(
            image=image,
            width=width,
            height=height,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_frames=num_frames,
            num_inference_steps=16,
            generator=generator,
            guidance_scale=float(guidance),
        ).frames[0]

    progress(0.50, desc="Encoding video...")
    # tempfile.mktemp() only returns a name and is deprecated because another
    # process can claim that name first; create the file atomically instead
    # and keep it on disk (delete=False) for the encoder to overwrite.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
        out = tmp_file.name
    try:
        export_to_video(frames, out, fps=FPS)
        torch.cuda.empty_cache()
        gc.collect()
        path = video_to_audio(out, prompt_audio, negative_prompt_audio, 1, 25, 4.5, duration)
    finally:
        # The silent intermediate clip is no longer needed once the final
        # file has been written (or the run failed); remove it so repeated
        # requests do not pile up temp files. Cleanup is best effort.
        try:
            os.remove(out)
        except OSError:
            pass

    progress(1.0, desc="Done!")
    return path


# =============================================================================
# GRADIO UI
# =============================================================================
# NOTE(review): the container nesting below was reconstructed from a
# whitespace-collapsed source — confirm the layout matches the intended UI.
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Wan 2.2 I2V 14B (AOTI Optimized Lightning MMaudio LoRAs)")
    gr.Markdown("Upload an image and enter a text prompt. Adjust settings below to control video generation.")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            gr.Markdown("Upload the starting image for the video. This will be animated according to your prompt.")
            prompt_txt_video = gr.Textbox(
                lines=3,
                label="Prompt Video",
            )
            gr.Markdown("Describe what you want in the video. Be as detailed as needed.")
            neg_prompt_txt_video = gr.Textbox(lines=3, label="Negative Prompt Video", value="low quality, deformed")
            gr.Markdown("Describe what you do not want in the video. Be as detailed as needed.")
            prompt_txt_audio = gr.Textbox(
                lines=3,
                label="Prompt Audio",
            )
            gr.Markdown("Describe what you want in the audio. Be as detailed as needed.")
            neg_prompt_txt_audio = gr.Textbox(lines=3, label="Negative Prompt Audio", value="music")
            gr.Markdown("Describe what you do not want in the audio. Be as detailed as needed.")
            with gr.Accordion("🎨 Motion LoRA", open=False):
                lora_dropdown = gr.Dropdown(
                    choices=LORA_NAMES,
                    value="None",
                    label="Motion Style",
                )
                high_strength = gr.Slider(
                    0.0,
                    1.0,
                    value=0.6,
                    step=0.05,
                    label="High Noise Strength",
                )
                low_strength = gr.Slider(
                    0.0,
                    1.0,
                    value=0.6,
                    step=0.05,
                    label="Low Noise Strength",
                )
            with gr.Accordion("Settings", open=True):
                dur_slider = gr.Slider(1, 5, value=5, step=0.1, label="Duration (seconds)")
                gr.Markdown("Controls the length of the video. Longer durations generate more frames and require more compute.")
                guidance_slider = gr.Slider(1.0, 6.0, value=3, step=0.1, label="Guidance Strength")
                gr.Markdown("How strongly the model follows the prompt. Lower values are more natural and fluid, higher values are more literal and stylized.")
                rand_check = gr.Checkbox(value=True, label="Randomize Seed")
                gr.Markdown("Toggle this to get a fresh variation every run. Turn it off to reuse the same motion and style.")
            gen_btn = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            output_vid = gr.Video(label="Generated Video")

    gr.Examples(
        examples=[
            "They all drink a beer, and cheer!",
            "Animate this cat playing",
        ],
        inputs=[prompt_txt_video],
        example_labels=[
            "They all drink a beer, and cheer!",
            "Animate this cat playing",
        ],
        label="💡 Try these:"
    )

    gen_btn.click(
        infer,
        inputs=[input_img, prompt_txt_video, neg_prompt_txt_video, prompt_txt_audio, neg_prompt_txt_audio, dur_slider, guidance_slider, rand_check, lora_dropdown, high_strength, low_strength],
        outputs=[output_vid],
    )

    # Kept at column 0 inside the literal: Markdown needs the heading and the
    # bullets to start their own unindented lines in order to render.
    gr.Markdown("""
# 🎥 Wan 2.2 Motion LoRA Pack

A collection of cinematic motion LoRAs for Wan 2.2 image-to-video workflows, featuring dramatic camera movements, anime-style enhancement, and natural walking animation.

- **Camera Tilt-up Overhead** — Smooth crane-up camera motion transitioning into a cinematic top-down overhead shot.
- **Camera Tilt-down Undershot** — Dramatic descending low-angle camera movement for powerful undershot perspectives.
- **Anime Style WAN 2.2 I2V** — Enhances anime visuals with cleaner motion, improved consistency, and stylized animation flow.
- **Walking Motion LoRA** — Adds natural walking movement and realistic body motion for character animation.

Ideal for cinematic anime videos, dynamic camera shots, and advanced Wan 2.2 I2V workflows.
""")

if __name__ == "__main__":
    demo.queue().launch(debug=True)