# --------------------------
# Wan-AI/Wan2.2-I2V-A14B-Diffusers
# --------------------------
# Script overview: load the Wan2.2 I2V pipeline components, int8-quantize the
# modules that will NOT be exported, run one probe inference that captures the
# inputs of the first `num_blocks` blocks of one transformer (`subfolder`),
# then torch.export + AOTI-compile each of those blocks and upload the .pt2
# packages to AOTI_REPO.
import spaces  # kept first, ahead of torch/CUDA imports (ZeroGPU convention)

import os
import gc
import shutil

from PIL import Image

import torch

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

from huggingface_hub import HfApi, login, snapshot_download
from diffusers import (
    WanTransformer3DModel,
    WanImageToVideoPipeline,
    DPMSolverMultistepScheduler,
    AutoencoderKLWan,
)
from transformers import UMT5EncoderModel
from torch._inductor import aoti_compile_and_package
from torch._inductor import config
from torch.export import Dim
from torchao.quantization import (
    quantize_,
    Int8WeightOnlyConfig,
)

# ---------------------------------------------------
# Config
# ---------------------------------------------------
MODEL_ID = "ibyteohdear/Wan2.2-I2V-14B-Lightning"  # text encoder, VAE, rest of pipeline
SAFE_REPO = "ibyteohdear/Wan2.2-I2V-14B-Lightning-Safe"  # transformer weights
AOTI_REPO = "ibyteohdear/Wan2.2-AOTI-Weights-zero-lora-v1"  # upload target for .pt2 files

device = "cuda"
dtype = torch.bfloat16

# Number of leading transformer blocks that get exported/AOTI-compiled; the
# remaining blocks (blocks[num_blocks:]) are int8-quantized instead.
num_blocks = 8
# Which transformer's blocks are exported: "transformer" or "transformer_2".
# Also used as the prefix of the uploaded .pt2 filenames.
subfolder = "transformer"

# Inductor / matmul settings in effect while the blocks are compiled.
config.max_autotune_gemm = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Probe resolution in pixels.
width = 784
height = 480


# ---------------------------------------------------
# Helpers
# ---------------------------------------------------
def safe_rmtree(path):
    """Best-effort recursive delete of *path*, printing the outcome.

    Never raises: deletion errors are reported, not propagated. A path that
    does not exist is reported as such instead of as "deleted".
    """
    if not os.path.lexists(path):
        print(f"Nothing to delete: {path}")
        return

    try:
        shutil.rmtree(path, ignore_errors=True)
    except Exception as e:  # deliberate best-effort cleanup
        print(f"āŒ Failed to delete {path}: {e}")

    if not os.path.exists(path):
        print(f"āœ… Successfully deleted {path}")
    else:
        print(f"āš ļø Still exists: {path}")


def aligned_num_frames(duration, fps=16):
    """Return the largest frame count of the form 4k + 1 that does not exceed
    ``int(duration * fps)``, and never less than 1.

    The 4k + 1 form presumably matches the Wan VAE's 4x temporal compression
    (TODO confirm against the pipeline's vae_scale_factor_temporal).

    Args:
        duration: clip length in seconds.
        fps: frames per second.

    Returns:
        A positive int ``n`` with ``(n - 1) % 4 == 0``. For example, 5 s at
        16 fps gives 77.
    """
    # Clamp so very short durations yield 1 frame instead of a negative count
    # ((-1 // 4) * 4 + 1 == -3 for n == 0).
    n = max(int(duration * fps), 1)
    return ((n - 1) // 4) * 4 + 1


def aspect_correct_pad(image, target_width=width, target_height=height):
    """Letterbox *image* onto a black ``target_width x target_height`` RGB canvas.

    The image is scaled, preserving its aspect ratio, to the largest size that
    fits, then pasted centered.

    Args:
        image: a PIL image.
        target_width: canvas width in pixels (defaults to the module `width`).
        target_height: canvas height in pixels (defaults to the module `height`).

    Returns:
        A new RGB PIL image of the target size.
    """
    scale = min(target_width / image.width, target_height / image.height)
    # Clamp to 1 px: for extremely thin images int() can truncate a side to 0,
    # and PIL refuses to resize to a zero dimension.
    new_w = max(1, int(image.width * scale))
    new_h = max(1, int(image.height * scale))

    resized_img = image.resize((new_w, new_h), Image.LANCZOS)

    canvas = Image.new(
        "RGB",
        (target_width, target_height),
        (0, 0, 0),
    )
    paste_x = (target_width - new_w) // 2
    paste_y = (target_height - new_h) // 2
    canvas.paste(resized_img, (paste_x, paste_y))
    return canvas


# ---------------------------------------------------
# Cleanup
# ---------------------------------------------------
safe_rmtree("/tmp/hf")

# ---------------------------------------------------
# Login
# ---------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# ---------------------------------------------------
# HF API
# ---------------------------------------------------
api = HfApi()
api.create_repo(
    repo_id=AOTI_REPO,
    private=False,
    exist_ok=True,
)

# ---------------------------------------------------
# Download transformer
# ---------------------------------------------------
print("šŸ“„ Downloading transformer...")
CORE_PATH = snapshot_download(
    repo_id=SAFE_REPO,
    token=HF_TOKEN,
    allow_patterns=["transformer/*"],
)
transformer = WanTransformer3DModel.from_pretrained(
    CORE_PATH,
    subfolder="transformer",
    torch_dtype=dtype,
)
gc.collect()

# ---------------------------------------------------
# Download transformer_2
# ---------------------------------------------------
print("šŸ“„ Downloading transformer_2...")
CORE_PATH_2 = snapshot_download(
    repo_id=SAFE_REPO,
    token=HF_TOKEN,
    allow_patterns=["transformer_2/*"],
)
transformer_2 = WanTransformer3DModel.from_pretrained(
    CORE_PATH_2,
    subfolder="transformer_2",
    torch_dtype=dtype,
)
gc.collect()

# ---------------------------------------------------
# Download text encoder
# ---------------------------------------------------
print("šŸ“„ Downloading text encoder...")
CORE_PATH_3 = snapshot_download(
    repo_id=MODEL_ID,
    token=HF_TOKEN,
    allow_patterns=["text_encoder/*"],
)
text_encoder = UMT5EncoderModel.from_pretrained(
    CORE_PATH_3,
    subfolder="text_encoder",
    torch_dtype=dtype,
)
gc.collect()

# ---------------------------------------------------
# Download VAE
# ---------------------------------------------------
print("šŸ“„ Downloading VAE...")
CORE_PATH_4 = snapshot_download(
    repo_id=MODEL_ID,
    token=HF_TOKEN,
    allow_patterns=["vae/*"],
)
vae = AutoencoderKLWan.from_pretrained(CORE_PATH_4, subfolder="vae", torch_dtype=dtype)
gc.collect()

# All components are now loaded in memory; drop the temporary download dir.
safe_rmtree("/tmp/hf")

# ---------------------------------------------------
# Build pipeline
# ---------------------------------------------------
# Components loaded above are injected; everything else (tokenizer, scheduler
# config, ...) comes from MODEL_ID.
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    transformer=transformer,
    transformer_2=transformer_2,
    text_encoder=text_encoder,
    vae=vae,
)

# NOTE(review): this swaps the checkpoint's scheduler for DPM-Solver++ built
# from the same config. Confirm this is intended: the timestep schedule may
# affect which steps run `transformer` versus `transformer_2`.
_scheduler_config = pipe.scheduler.config
pipe.scheduler = DPMSolverMultistepScheduler.from_config(_scheduler_config)
del _scheduler_config

pipe.to(device)
torch.cuda.empty_cache()
gc.collect()

# ---------------------------------------------------
# Quantize ONLY non-exported blocks
# ---------------------------------------------------
# The first `num_blocks` blocks of each transformer stay in `dtype` because
# they are the ones exported; everything listed here gets int8 weight-only
# quantization.
print("šŸ’¾ Quantizing remaining blocks...")
for _quant_target in (
    pipe.text_encoder,
    pipe.transformer.blocks[num_blocks:],
    pipe.transformer_2.blocks[num_blocks:],
):
    quantize_(_quant_target, Int8WeightOnlyConfig(version=2))
del _quant_target


# ---------------------------------------------------
# Export helper
# ---------------------------------------------------
def export_block(block, block_idx, saved_inputs):
    """Export one transformer block with torch.export, AOTI-compile it, upload it.

    Args:
        block: the module to export; it is traced as ``block(*saved_inputs)``.
        block_idx: index used in log lines and in the package filename.
        saved_inputs: positional example inputs captured from a real forward
            call, in the order (hidden_states, encoder_hidden_states, temb,
            rotary_emb) implied by the dynamic-shape spec below.

    Side effects:
        Writes ``{subfolder}_block_{block_idx}.pt2`` in the working directory,
        uploads it to AOTI_REPO under the same path, then deletes the local
        file.
    """
    print(f"\nšŸš€ Exporting block {block_idx}")

    # Dim 1 of the video-token, text-token and rotary inputs is left to
    # torch.export to decide (Dim.AUTO: dynamic if possible, else static).
    # temb is fully static.
    auto = Dim.AUTO
    shape_spec = dict(
        hidden_states={1: auto},
        encoder_hidden_states={1: auto},
        temb=None,
        rotary_emb=({1: auto}, {1: auto}),
    )

    print("šŸ“¦ Exporting graph...")
    with torch.no_grad():
        program = torch.export.export(
            block,
            args=saved_inputs,
            dynamic_shapes=shape_spec,
        )

    print("āš™ļø Compiling AOTI...")
    package_file = f"{subfolder}_block_{block_idx}.pt2"
    aoti_compile_and_package(program, package_path=package_file)

    print("ā˜ļø Uploading...")
    api.upload_file(
        path_or_fileobj=package_file,
        path_in_repo=package_file,
        repo_id=AOTI_REPO,
    )
    os.remove(package_file)

    print(f"āœ… Finished block {block_idx}")
# ---------------------------------------------------
# Single-pass probe
# ---------------------------------------------------
def _snapshot_input(value):
    """Detach+clone tensors (recursing into plain tuples/lists); pass other values through."""
    if torch.is_tensor(value):
        return value.detach().clone()
    if type(value) in (tuple, list):
        return type(value)(_snapshot_input(v) for v in value)
    return value


def _make_capture_hook(store, idx):
    """Return a forward hook that records block *idx*'s positional inputs in *store* once."""

    def hook_fn(module, inputs, output):
        # A block runs on every denoising step of the probe, so the hook fires
        # many times. Only the first call is kept: the export only needs one
        # example, and re-cloning the large activations on every call is
        # wasted GPU time and memory churn.
        if idx in store:
            return
        store[idx] = tuple(_snapshot_input(x) for x in inputs)
        print(f"šŸ“Œ Captured block {idx}")

    return hook_fn


@spaces.GPU(duration=1000)
def probe():
    """Capture example inputs of the first `num_blocks` blocks, then export them.

    Selects the transformer named by the module-level `subfolder`. It hooks
    its first `num_blocks` blocks, runs one latent-only inference on a blank
    image and captures each block's positional inputs. It then calls
    `export_block` for each block.

    Raises:
        ValueError: if `subfolder` is neither "transformer" nor "transformer_2".
        RuntimeError: if some hooked block never ran during the probe
            inference, so there are no example inputs to export it with.
    """
    if subfolder == "transformer":
        active_transformer = pipe.transformer
    elif subfolder == "transformer_2":
        active_transformer = pipe.transformer_2
    else:
        raise ValueError(f"Unknown subfolder: {subfolder}")

    print("\nšŸ” Capturing block inputs...")

    all_saved_inputs = {}
    handles = [
        active_transformer.blocks[block_idx].register_forward_hook(
            _make_capture_hook(all_saved_inputs, block_idx)
        )
        for block_idx in range(num_blocks)
    ]

    # The blank image is created at the target size, so the pad is effectively
    # a same-size resize.
    dummy_image = Image.new("RGB", (width, height), "white")
    image = aspect_correct_pad(dummy_image, width, height)

    fps = 16
    duration = 5
    num_frames = aligned_num_frames(duration, fps)

    print("\nšŸš€ Running single probe inference...")
    try:
        with torch.no_grad():
            pipe(
                image=image,
                width=width,
                height=height,
                prompt="probe",
                num_frames=num_frames,
                num_inference_steps=16,
                guidance_scale=3,
                output_type="latent",
            )
    finally:
        # Always detach the hooks, even if inference fails. Otherwise a retry
        # in the same process would stack a second set on the same blocks.
        for handle in handles:
            handle.remove()

    print("\nāœ… Finished capture")

    missing = [idx for idx in range(num_blocks) if idx not in all_saved_inputs]
    if missing:
        raise RuntimeError(
            f"No inputs captured for {subfolder} blocks {missing}: "
            "they were never called during the probe inference."
        )

    for block_idx in range(num_blocks):
        export_block(
            active_transformer.blocks[block_idx],
            block_idx,
            all_saved_inputs[block_idx],
        )


if __name__ == "__main__":
    probe()