"""LoRA / full-weight adapter modules and the network that attaches them to diffusion models."""
import copy
import json
import math
import weakref
import os
import re
import sys
from typing import List, Optional, Dict, Type, Union

import torch
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel, WanTransformer3DModel
from transformers import CLIPTextModel

from toolkit.models.lokr import LokrModule

from .config_modules import NetworkConfig
from .lorm import count_parameters
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin

from toolkit.kohya_lora import LoRANetwork
from toolkit.models.DoRA import DoRAModule

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")

# diffusers specific stuff: class names that are treated as linear / conv layers when attaching adapters
LINEAR_MODULES = [
    'Linear',
    'LoRACompatibleLinear',
    'QLinear',
    # 'GroupNorm',
]
CONV_MODULES = [
    'Conv2d',
    'LoRACompatibleConv',
    'QConv2d',
]


def _contains_any(words, *names) -> bool:
    """Return True if any string in ``words`` is a substring of any string in ``names``."""
    return any(word in name for word in words for name in names)


class IdentityModule(torch.nn.Module):
    """Pass-through module; used as ``lora_up`` for full-rank adapters, which have no up projection."""

    def forward(self, x):
        return x


class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
    """
    Low-rank adapter for a Linear or Conv2d layer.

    Replaces the forward method of the original module (see ``apply_to``) instead of replacing the
    original module itself. For ``network_type == "fullrank"`` the down projection maps straight to the
    output dimension and the up projection is an identity.
    """

    def __init__(
            self,
            lora_name,
            org_module: torch.nn.Module,
            multiplier=1.0,
            lora_dim=4,
            alpha=1,
            dropout=None,
            rank_dropout=None,
            module_dropout=None,
            network: 'LoRASpecialNetwork' = None,
            use_bias: bool = False,
            is_ara: bool = False,
            **kwargs
    ):
        """
        Args:
            lora_name: unique name of this adapter (used as the state dict prefix).
            org_module: the Linear/Conv2d layer being adapted.
            multiplier: output multiplier for the adapter.
            lora_dim: rank of the decomposition.
            alpha: scaling numerator; if 0 or None, alpha is the rank (no scaling).
            dropout / rank_dropout / module_dropout: stored for the forward path in the mixin.
            network: owning network. Required in practice: ``network.network_type`` is read here.
            use_bias: add a bias to ``lora_up``; forced off when ``org_module`` has no bias.
            is_ara: mark the wrapped layer with a weak back-reference for the memory manager.
        """
        self.can_merge_in = True
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
        self.lora_name = lora_name
        self.orig_module_ref = weakref.ref(org_module)
        self.scalar = torch.tensor(1.0, device=org_module.weight.device)
        # if is ara lora module, mark it on the layer so memory manager can handle it
        if is_ara:
            org_module.ara_lora_ref = weakref.ref(self)
        # check if parent has bias. if not force use_bias to False
        if org_module.bias is None:
            use_bias = False

        is_conv = org_module.__class__.__name__ in CONV_MODULES
        if is_conv:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim
        self.full_rank = network.network_type.lower() == "fullrank"

        if is_conv:
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            if self.full_rank:
                self.lora_down = torch.nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False)
                self.lora_up = IdentityModule()
            else:
                self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
                self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
        else:
            if self.full_rank:
                self.lora_down = torch.nn.Linear(in_dim, out_dim, bias=False)
                self.lora_up = IdentityModule()
            else:
                self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
                self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)

        if isinstance(alpha, torch.Tensor):
            # without casting, bf16 causes error; .cpu() lets a CUDA-resident alpha convert as well
            alpha = alpha.detach().float().cpu().numpy()
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        if not self.full_rank:
            # zero up projection so an untrained adapter is a no-op
            torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier: Union[float, List[float]] = multiplier
        # wrap the original module in a list so it is not registered as a submodule
        # (and so doesn't get its weights updated / saved with the adapter)
        self.org_module = [org_module]
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.is_checkpointing = False

    def apply_to(self):
        """Route the wrapped module's forward through this adapter, keeping the original forward."""
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward


def _is_quantized_tensor(t) -> bool:
    """Return True for torchao-quantized tensors that expose ``.dequantize()``."""
    # torchao stores quantized weights as tensor subclasses (e.g. AffineQuantizedTensor) under torchao.*
    # that are still nn.Parameter instances and expose .dequantize(). (quanto is intentionally not handled.)
    return 'torchao' in type(t).__module__ and hasattr(t, 'dequantize')


def _dequantize_if_needed(t):
    """Return a dequantized copy of ``t`` if it is torchao-quantized, otherwise ``t`` unchanged."""
    return t.dequantize() if _is_quantized_tensor(t) else t


class FullModule(ToolkitModuleMixin, torch.nn.Module):
    """
    Full weight "lora" for layers that have no sensible low rank decomposition (norm layers,
    embeddings, stray biases, etc).

    It has no up/down projection. It holds a trainable delta that is added to the original weight
    (and bias) of the wrapped module. On save it emits `.diff` (and `.diff_b` for bias), which ComfyUI
    applies as `weight += strength * diff`. The delta therefore merges directly into the model weights
    without any extra adapter.

    If the wrapped module's weight is torchao-quantized, the delta is kept in full precision. The
    original weight is dequantized on the fly in the forward pass, and the original quantized tensor is
    left untouched.
    """

    def __init__(
            self,
            lora_name,
            org_module: torch.nn.Module,
            multiplier=1.0,
            network: 'LoRASpecialNetwork' = None,
            **kwargs
    ):
        """
        Args:
            lora_name: unique name of this module (used as the state dict prefix).
            org_module: any module with a ``weight`` parameter (and optionally ``bias``).
            multiplier: stored for the mixin; the forward pass reads the network's multiplier.
            network: owning network.
            **kwargs: ignored (e.g. ``parent``), accepted for signature compatibility with LoRAModule.
        """
        self.can_merge_in = True
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
        self.lora_name = lora_name
        # keep the original module out of our state_dict (list hides it from nn.Module registration)
        self.org_module = [org_module]
        self.orig_module_ref = weakref.ref(org_module)
        self.multiplier: Union[float, List[float]] = multiplier

        # these are unused for full modules but the mixin/forward path expects them to exist
        self.dropout = None
        self.rank_dropout = None
        self.module_dropout = None
        self.is_checkpointing = False

        # trainable delta, zero initialized so an untrained layer is a no-op (zero diff)
        # dequantize first so the delta is full precision and shaped like the real (unpacked) weight
        self.weight_is_quantized = _is_quantized_tensor(org_module.weight)
        ref_weight = _dequantize_if_needed(org_module.weight)
        self.diff = torch.nn.Parameter(torch.zeros_like(ref_weight))
        # some modules (e.g. Embedding) have no bias attribute at all
        org_bias = getattr(org_module, 'bias', None)
        if org_bias is not None:
            self.diff_b = torch.nn.Parameter(torch.zeros_like(org_bias))
        else:
            self.diff_b = None

    def apply_to(self):
        """Route the wrapped module's forward through this module, keeping the original forward."""
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def forward(self, x, *args, **kwargs):
        """Run the original forward with ``weight + diff * multiplier`` (and bias likewise) swapped in."""
        # NOTE(review): network_ref is presumably a weakref set up by ToolkitModuleMixin.__init__
        network: 'LoRASpecialNetwork' = self.network_ref()
        skip = (not network.is_active) or network.is_merged_in or network._multiplier == 0 or network.is_lorm
        if skip:
            return self.org_forward(x, *args, **kwargs)

        om = self.org_module[0]
        multiplier = network.torch_multiplier
        # weight space application can't be done per sample, so use the mean (same as the DoRA path)
        mult = multiplier.mean() if isinstance(multiplier, torch.Tensor) else multiplier

        orig_weight = om._parameters['weight']
        # dequantize quantized weights to full precision so the delta can be added (the original
        # quantized tensor is restored in the finally block below)
        base_weight = _dequantize_if_needed(orig_weight)
        eff_weight = base_weight + (self.diff.to(base_weight.device) * mult).to(base_weight.dtype)

        has_bias = self.diff_b is not None and om._parameters.get('bias', None) is not None
        if has_bias:
            orig_bias = om._parameters['bias']
            eff_bias = orig_bias + (self.diff_b.to(orig_bias.device) * mult).to(orig_bias.dtype)

        # temporarily swap in the effective weights so the original forward (norm/linear/etc) uses them.
        # this keeps autograd flowing into our delta while supporting any layer type.
        om._parameters['weight'] = eff_weight
        if has_bias:
            om._parameters['bias'] = eff_bias
        try:
            out = self.org_forward(x, *args, **kwargs)
        finally:
            om._parameters['weight'] = orig_weight
            if has_bias:
                om._parameters['bias'] = orig_bias
        return out

    @torch.no_grad()
    def merge_in(self: 'FullModule', merge_weight=1.0):
        """Fold ``merge_weight * diff`` (and ``diff_b``) permanently into the wrapped module's weights."""
        if not self.can_merge_in:
            return
        om = self.org_module[0]
        if 'weight._data' in om.state_dict():
            # quanto quantized weight, can't merge
            return
        org_weight = om.weight
        orig_dtype = org_weight.dtype
        # dequantize torchao weights so we can fold the full precision delta in
        merged_weight = _dequantize_if_needed(org_weight).float() + merge_weight * self.diff.float().to(org_weight.device)
        if self.weight_is_quantized:
            # re-quantize so the model stays quantized across continuous merge/reset cycles
            from toolkit.util.quantize import get_torchao_config, requantize_module_weight
            requantize_module_weight(om, merged_weight, orig_dtype, get_torchao_config(self._get_base_qtype()))
        else:
            om.weight.data = merged_weight.to(org_weight.device, orig_dtype)
        # bias is never quantized
        if self.diff_b is not None and getattr(om, 'bias', None) is not None:
            om.bias.data = (om.bias.data.float() + merge_weight * self.diff_b.float().to(om.bias.device)).to(om.bias.dtype)

    def reset_weights(self: 'FullModule'):
        """Zero the trainable deltas so the module becomes a no-op again."""
        with torch.no_grad():
            self.diff.zero_()
            if self.diff_b is not None:
                self.diff_b.zero_()


class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
    """Network that attaches LoRA / DoRA / LoKr / full-weight modules to a UNet or transformer and text encoders."""

    NUM_OF_BLOCKS = 12  # number of up/down blocks in the equivalent full model

    UNET_TARGET_REPLACE_MODULE = ["UNet2DConditionModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    PEFT_PREFIX_UNET = "unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"

    def __init__(
            self,
            text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
            unet,
            multiplier: float = 1.0,
            lora_dim: int = 4,
            alpha: float = 1,
            dropout: Optional[float] = None,
            rank_dropout: Optional[float] = None,
            module_dropout: Optional[float] = None,
            conv_lora_dim: Optional[int] = None,
            conv_alpha: Optional[float] = None,
            block_dims: Optional[List[int]] = None,
            block_alphas: Optional[List[float]] = None,
            conv_block_dims: Optional[List[int]] = None,
            conv_block_alphas: Optional[List[float]] = None,
            modules_dim: Optional[Dict[str, int]] = None,
            modules_alpha: Optional[Dict[str, int]] = None,
            module_class: Type[object] = LoRAModule,
            varbose: Optional[bool] = False,
            train_text_encoder: Optional[bool] = True,
            use_text_encoder_1: bool = True,
            use_text_encoder_2: bool = True,
            train_unet: Optional[bool] = True,
            is_sdxl=False,
            is_v2=False,
            is_v3=False,
            is_pixart: bool = False,
            is_auraflow: bool = False,
            is_flux: bool = False,
            is_lumina2: bool = False,
            use_bias: bool = False,
            is_lorm: bool = False,
            ignore_if_contains=None,
            only_if_contains=None,
            full_if_contains=None,
            parameter_threshold: float = 0.0,
            attn_only: bool = False,
            target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
            target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
            network_type: str = "lora",
            full_train_in_out: bool = False,
            transformer_only: bool = False,
            peft_format: bool = False,
            is_assistant_adapter: bool = False,
            is_transformer: bool = False,
            base_model: 'StableDiffusion' = None,
            is_ara: bool = False,
            **kwargs
    ) -> None:
        """
        LoRA network. There are many arguments, but the supported patterns are:
        1. specify lora_dim and alpha
        2. specify lora_dim, alpha, conv_lora_dim and conv_alpha
        3. specify block_dims and block_alphas: not applied to Conv2d 3x3
        4. specify block_dims, block_alphas, conv_block_dims and conv_block_alphas: also applied to Conv2d 3x3
        5. specify modules_dim and modules_alpha (for inference)

        Name filters (``ignore_if_contains``, ``only_if_contains``, ``full_if_contains``) are substring
        matches against the dotted module path and/or the mangled lora name.
        ``target_lin_modules`` / ``target_conv_modules`` are never mutated.
        """
        # call the parent of the parent we are replacing (LoRANetwork) init
        torch.nn.Module.__init__(self)
        ToolkitNetworkMixin.__init__(
            self,
            train_text_encoder=train_text_encoder,
            train_unet=train_unet,
            is_sdxl=is_sdxl,
            is_v2=is_v2,
            is_lorm=is_lorm,
            **kwargs
        )
        if ignore_if_contains is None:
            ignore_if_contains = []
        self.ignore_if_contains = ignore_if_contains

        # full_if_contains: any layer (even linear/conv) whose name matches becomes a full weight
        # module instead of a normal lora module
        if full_if_contains is None:
            full_if_contains = []
        elif isinstance(full_if_contains, str):
            full_if_contains = [full_if_contains]
        self.full_if_contains = full_if_contains

        self.transformer_only = transformer_only
        self.base_model_ref = None
        if base_model is not None:
            self.base_model_ref = weakref.ref(base_model)

        self.only_if_contains: Union[List, None] = only_if_contains

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.is_checkpointing = False
        self._multiplier: float = 1.0
        self.is_active: bool = False
        self.torch_multiplier = None
        # triggers the state updates
        self.multiplier = multiplier
        self.is_sdxl = is_sdxl
        self.is_v2 = is_v2
        self.is_v3 = is_v3
        self.is_pixart = is_pixart
        self.is_auraflow = is_auraflow
        self.is_flux = is_flux
        self.is_lumina2 = is_lumina2
        self.network_type = network_type
        self.is_assistant_adapter = is_assistant_adapter
        self.full_rank = network_type.lower() == "fullrank"
        self.is_ara = is_ara
        if self.network_type.lower() == "dora":
            self.module_class = DoRAModule
            module_class = DoRAModule
        elif self.network_type.lower() == "lokr":
            self.module_class = LokrModule
            module_class = LokrModule
        self.network_config: NetworkConfig = kwargs.get("network_config", None)

        self.peft_format = peft_format
        self.is_transformer = is_transformer

        # use the old format for older models unless the user has specified otherwise
        self.use_old_lokr_format = False
        if self.network_config is not None and hasattr(self.network_config, 'old_lokr_format'):
            self.use_old_lokr_format = self.network_config.old_lokr_format
        # also allow a false from the model itself
        if base_model is not None and not base_model.use_old_lokr_format:
            self.use_old_lokr_format = False

        # always do peft for flux only for now
        if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer:
            # don't do peft format for lokr if using old format
            if self.network_type.lower() != "lokr" or not self.use_old_lokr_format:
                self.peft_format = True

        if self.peft_format:
            # no alpha for peft
            self.alpha = self.lora_dim
            alpha = self.alpha
            self.conv_alpha = self.conv_lora_dim
            conv_alpha = self.conv_alpha

        self.full_train_in_out = full_train_in_out

        if modules_dim is not None:
            print(f"create LoRA network from weights")
        elif block_dims is not None:
            print(f"create LoRA network from block_dims")
            print(
                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
            print(f"block_dims: {block_dims}")
            print(f"block_alphas: {block_alphas}")
            if conv_block_dims is not None:
                print(f"conv_block_dims: {conv_block_dims}")
                print(f"conv_block_alphas: {conv_block_alphas}")
        else:
            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
            print(
                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
            if self.conv_lora_dim is not None:
                print(
                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

        # create module instances
        def create_modules(
                is_unet: bool,
                text_encoder_idx: Optional[int],  # None, 1, 2
                root_module: torch.nn.Module,
                target_replace_modules: List[torch.nn.Module],
        ) -> "tuple[List[torch.nn.Module], List[str]]":
            """Create adapter modules for ``root_module``; returns ``(modules, skipped_lora_names)``."""
            unet_prefix = self.LORA_PREFIX_UNET
            if self.peft_format:
                unet_prefix = self.PEFT_PREFIX_UNET
            if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer:
                unet_prefix = f"lora_transformer"
                if self.peft_format:
                    unet_prefix = "transformer"

            if is_unet:
                prefix = unet_prefix
            elif text_encoder_idx is None:
                prefix = self.LORA_PREFIX_TEXT_ENCODER
            elif text_encoder_idx == 1:
                prefix = self.LORA_PREFIX_TEXT_ENCODER1
            else:
                prefix = self.LORA_PREFIX_TEXT_ENCODER2

            all_layers = self.network_config is not None and getattr(self.network_config, 'all_layers', False)

            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue
                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ in LINEAR_MODULES
                    is_conv2d = child_module.__class__.__name__ in CONV_MODULES
                    is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                    # drop empty components (e.g. the root module's own "" child name)
                    lora_name = ".".join(part for part in (prefix, name, child_name) if part)
                    lora_name = lora_name.replace("..", ".")
                    clean_name = lora_name
                    if self.peft_format:
                        # we replace this on saving
                        lora_name = lora_name.replace(".", "$$")
                    else:
                        lora_name = lora_name.replace(".", "_")

                    # decide if this should be a full weight module instead of a normal lora.
                    # - all_layers: every remaining weight bearing leaf that isn't linear/conv
                    #   (norm layers, embeddings, stray biases, etc)
                    # - full_if_contains: any matching layer, INCLUDING linear/conv, overriding the
                    #   normal lora for it
                    is_leaf_with_weight = (
                        next(child_module.children(), None) is None
                        and isinstance(getattr(child_module, 'weight', None), torch.nn.Parameter)
                    )
                    matches_full_if_contains = _contains_any(self.full_if_contains, clean_name, lora_name)
                    is_full_layer = is_leaf_with_weight and (
                        matches_full_if_contains or (all_layers and not is_linear and not is_conv2d)
                    )

                    skip = _contains_any(self.ignore_if_contains, clean_name)

                    # see if it is over threshold
                    if count_parameters(child_module) < parameter_threshold:
                        skip = True

                    if self.transformer_only and is_unet:
                        transformer_block_names = None
                        if base_model is not None:
                            transformer_block_names = base_model.get_transformer_block_names()

                        if transformer_block_names is not None:
                            # match against clean_name (dotted) so block names can be
                            # dotted paths (e.g. "model.language_model.layers"); lora_name
                            # has dots replaced with "$$"/"_" and wouldn't match.
                            if not _contains_any(transformer_block_names, clean_name):
                                skip = True
                        else:
                            if self.is_pixart and "transformer_blocks" not in lora_name:
                                skip = True
                            if self.is_flux and "transformer_blocks" not in lora_name:
                                skip = True
                            if self.is_lumina2:
                                if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name:
                                    skip = True
                            if self.is_v3 and "transformer_blocks" not in lora_name:
                                skip = True

                            # handle custom models
                            if hasattr(root_module, 'transformer_blocks') and "transformer_blocks" not in lora_name:
                                skip = True
                            if hasattr(root_module, 'blocks') and "blocks" not in lora_name:
                                skip = True
                            if hasattr(root_module, 'single_blocks'):
                                if "single_blocks" not in lora_name and "double_blocks" not in lora_name:
                                    skip = True

                    wants_lora = (is_linear or is_conv2d) and not skip and not is_full_layer
                    wants_full = is_full_layer and not skip
                    if not (wants_lora or wants_full):
                        continue

                    if self.only_if_contains is not None and not _contains_any(self.only_if_contains, clean_name, lora_name):
                        continue

                    if wants_full:
                        loras.append(FullModule(
                            lora_name,
                            child_module,
                            self.multiplier,
                            network=self,
                            parent=module,
                        ))
                        continue

                    dim = None
                    alpha = None
                    if modules_dim is not None:
                        # modules explicitly specified (loading from weights)
                        if lora_name in modules_dim:
                            dim = modules_dim[lora_name]
                            alpha = modules_alpha[lora_name]
                    else:
                        # normally, target everything
                        if is_linear or is_conv2d_1x1:
                            dim = self.lora_dim
                            alpha = self.alpha
                        elif self.conv_lora_dim is not None:
                            dim = self.conv_lora_dim
                            alpha = self.conv_alpha

                    if dim is None or dim == 0:
                        # record skipped modules so they can be reported
                        if is_linear or is_conv2d_1x1 or (
                                self.conv_lora_dim is not None or conv_block_dims is not None):
                            skipped.append(lora_name)
                        continue

                    module_kwargs = {}
                    if self.network_type.lower() == "lokr":
                        module_kwargs["factor"] = self.network_config.lokr_factor
                    if self.is_ara:
                        module_kwargs["is_ara"] = True

                    loras.append(module_class(
                        lora_name,
                        child_module,
                        self.multiplier,
                        dim,
                        alpha,
                        dropout=dropout,
                        rank_dropout=rank_dropout,
                        module_dropout=module_dropout,
                        network=self,
                        parent=module,
                        use_bias=use_bias,
                        **module_kwargs
                    ))
            return loras, skipped

        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

        # create LoRA for text encoder
        # creating every module each time is wasteful; needs revisiting
        self.text_encoder_loras = []
        skipped_te = []
        if train_text_encoder:
            for i, text_encoder in enumerate(text_encoders):
                if not use_text_encoder_1 and i == 0:
                    continue
                if not use_text_encoder_2 and i == 1:
                    continue
                if len(text_encoders) > 1:
                    index = i + 1
                    print(f"create LoRA for Text Encoder {index}:")
                else:
                    index = None
                    print(f"create LoRA for Text Encoder:")

                replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
                if self.is_pixart:
                    replace_modules = ["T5EncoderModel"]

                text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules)
                self.text_encoder_loras.extend(text_encoder_loras)
                skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights.
        # copy first: target_lin_modules defaults to a shared class attribute and must not be mutated.
        target_modules = list(target_lin_modules)
        if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
            target_modules = target_modules + list(target_conv_modules)

        if is_v3:
            target_modules = ["SD3Transformer2DModel"]
        if is_pixart:
            target_modules = ["PixArtTransformer2DModel"]
        if is_auraflow:
            target_modules = ["AuraFlowTransformer2DModel"]
        if is_flux:
            target_modules = ["FluxTransformer2DModel"]
        if is_lumina2:
            target_modules = ["Lumina2Transformer2DModel"]

        if train_unet:
            self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
        else:
            self.unet_loras = []
            skipped_un = []
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        skipped = skipped_te + skipped_un
        if varbose and len(skipped) > 0:
            print(
                f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
            )
            for skipped_name in skipped:
                print(f"\t{skipped_name}")

        self.up_lr_weight: Optional[List[float]] = None
        self.down_lr_weight: Optional[List[float]] = None
        self.mid_lr_weight: Optional[float] = None
        self.block_lr = False

        # every adapter name must be unique (it becomes the state dict prefix)
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

        if self.full_train_in_out:
            print("full train in out")
            # we are going to retrain the main in out layers for VAE change usually
            if self.is_pixart:
                transformer: PixArtTransformer2DModel = unet
                self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
                self.transformer_proj_out = copy.deepcopy(transformer.proj_out)

                transformer.pos_embed = self.transformer_pos_embed
                transformer.proj_out = self.transformer_proj_out

            elif self.is_auraflow:
                transformer: AuraFlowTransformer2DModel = unet
                self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
                self.transformer_proj_out = copy.deepcopy(transformer.proj_out)

                transformer.pos_embed = self.transformer_pos_embed
                transformer.proj_out = self.transformer_proj_out

            elif base_model is not None and base_model.arch == "wan21":
                transformer: WanTransformer3DModel = unet
                self.transformer_pos_embed = copy.deepcopy(transformer.patch_embedding)
                self.transformer_proj_out = copy.deepcopy(transformer.proj_out)

                transformer.patch_embedding = self.transformer_pos_embed
                transformer.proj_out = self.transformer_proj_out

            else:
                unet: UNet2DConditionModel = unet
                unet_conv_in: torch.nn.Conv2d = unet.conv_in
                unet_conv_out: torch.nn.Conv2d = unet.conv_out

                # clone these and replace their forwards with ours
                self.unet_conv_in = copy.deepcopy(unet_conv_in)
                self.unet_conv_out = copy.deepcopy(unet_conv_out)
                unet.conv_in = self.unet_conv_in
                unet.conv_out = self.unet_conv_out

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        """Return optimizer param groups, adding the cloned in/out layers when ``full_train_in_out`` is set."""
        # call Lora prepare_optimizer_params
        all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)

        if self.full_train_in_out:
            # key off what __init__ actually built rather than re-deriving the model type
            # (the base model weakref may be dead by now)
            if getattr(self, 'transformer_pos_embed', None) is not None:
                all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
                all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
            else:
                all_params.append({"lr": unet_lr, "params": list(self.unet_conv_in.parameters())})
                all_params.append({"lr": unet_lr, "params": list(self.unet_conv_out.parameters())})

        return all_params