import os
from torch import nn
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import (
    ZeroGPUCompiledModel,
    ZeroGPUWeights,
)
from torch._functorch._aot_autograd.subclass_parametrization import (
    unwrap_tensor_subclass_parameters,
)

# Read once at import time; None when the environment variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Dotted-path segments added by PEFT/LoRA wrapping that were not present in
# the parameter names the AOTI blocks were exported with.
_PEFT_WRAPPER_SEGMENTS = frozenset({"base_layer", "original_module"})
# Any path segment starting with this prefix belongs to a LoRA adapter tensor.
_LORA_SEGMENT_PREFIX = "lora_"


def _clean_peft_state_dict(state_dict):
    """Map a possibly PEFT-wrapped state dict back to export-time names.

    Entries whose dotted name contains a segment starting with ``lora_``
    (the adapter tensors themselves) are dropped. Wrapper segments such as
    ``base_layer`` / ``original_module`` are removed from the remaining
    keys, e.g. ``attn.to_q.base_layer.weight`` -> ``attn.to_q.weight``.

    Matching is done on whole dotted path segments rather than substrings,
    so names that merely contain ``.base_layer`` (for example
    ``.base_layer_norm``) are left intact.

    Args:
        state_dict: Mapping of dotted parameter/buffer names to tensors.

    Returns:
        A new dict with cleaned names mapping to the same tensor objects.

    Raises:
        ValueError: If two source keys clean to the same name; silently
            keeping only one of them would feed the wrong weight to the
            compiled model.
    """
    cleaned = {}
    for key, value in state_dict.items():
        segments = key.split(".")
        # Skip actual LoRA tensors.
        if any(seg.startswith(_LORA_SEGMENT_PREFIX) for seg in segments):
            continue
        # Restore original export-time names.
        new_key = ".".join(
            seg for seg in segments if seg not in _PEFT_WRAPPER_SEGMENTS
        )
        if new_key in cleaned:
            raise ValueError(
                f"PEFT name cleanup maps multiple entries to {new_key!r} "
                f"(latest source key: {key!r})"
            )
        cleaned[new_key] = value
    return cleaned


class AOTIWrappedBlock(nn.Module):
    """Stand-in for a transformer block that delegates to a compiled model.

    Args:
        block: The eager block being replaced. It is accepted for call-site
            compatibility but not stored, so its parameters are not
            registered under this wrapper. In ``aoti_blocks_load`` the
            block's tensors are instead handed to ``compiled_model`` via
            its weights.
        compiled_model: Callable invoked with the block inputs, with
            ``rotary_emb`` split into two positional arguments.
    """

    def __init__(self, block, compiled_model):
        super().__init__()
        self.compiled = compiled_model

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb,
    ):
        """Run the compiled block on the given inputs.

        ``rotary_emb`` must unpack into exactly two items. They are passed
        to the compiled model as separate positional arguments. This
        presumably mirrors the input signature the block was exported
        with; confirm against the export script.
        """
        rotary_emb_0, rotary_emb_1 = rotary_emb
        return self.compiled(
            hidden_states,
            encoder_hidden_states,
            temb,
            rotary_emb_0,
            rotary_emb_1,
        )


def aoti_blocks_load(
    transformer_module,
    repo_id,
    num_blocks,
    subfolder="transformer",
):
    """Replace the first ``num_blocks`` transformer blocks with AOTI versions.

    For each index ``i``, the package ``f"{subfolder}_block_{i}.pt2"`` is
    downloaded from ``repo_id`` on the Hugging Face Hub. It is paired with
    the block's own PEFT-cleaned weights and installed as an
    ``AOTIWrappedBlock`` at ``transformer_module.blocks[i]``.

    Loading is best-effort per block. Any exception raised while preparing
    a block is printed, and that block keeps its eager implementation.

    Note that ``subfolder`` is used as a filename prefix, not as a
    directory inside the repository.

    NOTE(review): if ``ZeroGPUCompiledModel`` defers loading the archive
    and weights until its first call, a weight-name mismatch would surface
    at inference time instead of being caught by the fallback here. Confirm
    before relying on the fallback for that case.

    Args:
        transformer_module: Module exposing an indexable, assignable
            ``blocks`` container with at least ``num_blocks`` entries.
        repo_id: Hub repository holding the per-block ``.pt2`` packages.
        num_blocks: Number of leading blocks to replace.
        subfolder: Filename prefix of the per-block packages.
    """
    for i in range(num_blocks):
        block = transformer_module.blocks[i]

        try:
            print(f"🔄 Loading AOTI block {i}...")

            aoti_file = hf_hub_download(
                repo_id=repo_id,
                filename=f"{subfolder}_block_{i}.pt2",
                token=HF_TOKEN,
            )

            # Called for its in-place effect on ``block`` (the return value
            # is unused). If a later step fails, the block kept as the
            # fallback is this already-unwrapped module.
            unwrap_tensor_subclass_parameters(block)

            # Clean PEFT / LoRA wrapped parameter names so they match the
            # names the AOTI package was exported with.
            base_state = _clean_peft_state_dict(block.state_dict())

            weights = ZeroGPUWeights(base_state)
            compiled = ZeroGPUCompiledModel(
                aoti_file,
                weights,
            )

            transformer_module.blocks[i] = AOTIWrappedBlock(
                block,
                compiled,
            )
            print(f"✅ AOTI enabled for block {i}")

        except Exception as e:
            # Deliberately broad: a failure for one block (download, name
            # cleanup, wrapping) must not abort loading of the others.
            print(
                f"⚠️ Failed to load AOTI for block {i}. "
                f"Using original block. Error: {e}"
            )
            # Guarantee the eager block is installed even if the failure
            # happened after the replacement assignment above.
            transformer_module.blocks[i] = block