import os import yaml import random def load_pid_config(): project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) pid_path = os.path.join(project_root, 'yaml', 'pid.yaml') with open(pid_path, 'r', encoding='utf-8') as f: return yaml.safe_load(f) or {} def load_model_config(): project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) model_list_path = os.path.join(project_root, 'yaml', 'model_list.yaml') with open(model_list_path, 'r', encoding='utf-8') as f: return yaml.safe_load(f) or {} def inject(assembler, chain_definition, chain_items): if not chain_items: return pid_config = {} try: pid_config = load_pid_config() or {} except Exception as e: print(f"Error loading PiD config: {e}") pid_items = pid_config.get("PiD", []) architectures_settings = {} default_settings = {"unet_name": "pid_flux1_1024_to_4096_4step_mxfp8.safetensors", "latent_format": "flux"} for item in pid_items: unet_name = item.get("filepath") latent_format = item.get("latent_format") archs = item.get("architectures", []) for arch in archs: architectures_settings[arch] = { "unet_name": unet_name, "latent_format": latent_format } if arch == "flux1": default_settings = { "unet_name": unet_name, "latent_format": latent_format } ksampler_name = chain_definition.get('ksampler_node', 'ksampler') if ksampler_name not in assembler.node_map: print(f"Warning: [PiD Injector] KSampler node '{ksampler_name}' not found. Skipping.") return original_ksampler_id = assembler.node_map[ksampler_name] original_vae_loader_id = assembler.node_map.get('vae_loader') original_vae_decode_id = assembler.node_map.get('vae_decode') original_pos_prompt_id = assembler.node_map.get('pos_prompt') original_neg_prompt_id = assembler.node_map.get('neg_prompt') if not original_vae_loader_id: for node_id, node_data in assembler.workflow.items(): if node_data.get('class_type') == 'VAELoader': original_vae_loader_id = node_id break if not original_vae_decode_id: for node_id, node_data in assembler.workflow.items(): if node_data.get('class_type') == 'VAEDecode': original_vae_decode_id = node_id break if not original_pos_prompt_id or not original_neg_prompt_id: for node_id, node_data in assembler.workflow.items(): if node_data.get('class_type') == 'CLIPTextEncode': title = node_data.get('_meta', {}).get('title', '') if 'Positive' in title: if not original_pos_prompt_id: original_pos_prompt_id = node_id elif 'Negative' in title: if not original_neg_prompt_id: original_neg_prompt_id = node_id pos_text = "" if original_pos_prompt_id and original_pos_prompt_id in assembler.workflow: pos_text = assembler.workflow[original_pos_prompt_id]['inputs'].get('text', '') neg_text = "" if original_neg_prompt_id and original_neg_prompt_id in assembler.workflow: neg_text = assembler.workflow[original_neg_prompt_id]['inputs'].get('text', '') clip_loader_id = assembler._get_unique_id() clip_loader_node = assembler._get_node_template("CLIPLoader") clip_loader_node['inputs']['clip_name'] = "gemma_2_2b_it_elm_fp8_scaled.safetensors" clip_loader_node['inputs']['type'] = "pixeldit" clip_loader_node['inputs']['device'] = "default" assembler.workflow[clip_loader_id] = clip_loader_node pos_text_encode_id = assembler._get_unique_id() pos_text_encode_node = assembler._get_node_template("CLIPTextEncode") pos_text_encode_node['inputs']['text'] = pos_text pos_text_encode_node['inputs']['clip'] = [clip_loader_id, 0] assembler.workflow[pos_text_encode_id] = pos_text_encode_node neg_text_encode_id = assembler._get_unique_id() neg_text_encode_node = assembler._get_node_template("CLIPTextEncode") neg_text_encode_node['inputs']['text'] = neg_text neg_text_encode_node['inputs']['clip'] = [clip_loader_id, 0] assembler.workflow[neg_text_encode_id] = neg_text_encode_node active_model_file = None for node_id, node_data in assembler.workflow.items(): class_type = node_data.get('class_type') if class_type == 'UNETLoader': active_model_file = node_data.get('inputs', {}).get('unet_name') if active_model_file: break elif class_type == 'CheckpointLoaderSimple': active_model_file = node_data.get('inputs', {}).get('ckpt_name') if active_model_file: break architecture = None if active_model_file: try: model_config = load_model_config() checkpoints = model_config.get("Checkpoints", {}) for arch_name, arch_data in checkpoints.items(): models_list = arch_data.get("models", []) for model_entry in models_list: if model_entry.get('path') == active_model_file: architecture = arch_name break components_dict = model_entry.get('components', {}) if active_model_file in components_dict.values(): architecture = arch_name break if architecture: break except Exception as e: print(f"Error looking up model architecture in PiD injector: {e}") if architecture: architecture = architecture.lower().replace(" ", "-").replace(".", "") else: file_lower = active_model_file.lower().replace("-", "").replace("_", "").replace(".", "") for arch in sorted(architectures_settings.keys(), key=len, reverse=True): candidates = [arch] if "-image" in arch: candidates.append(arch.replace("-image", "")) if "-i1" in arch: candidates.append(arch.replace("-i1", "")) if "-kv" in arch: candidates.append(arch.replace("-kv", "")) matched = False for cand in candidates: if cand.replace("-", "").replace(".", "") in file_lower: architecture = arch matched = True break if matched: break unet_name = default_settings.get("unet_name") latent_format = default_settings.get("latent_format") if architecture in architectures_settings: arch_config = architectures_settings[architecture] unet_name = arch_config.get("unet_name", unet_name) latent_format = arch_config.get("latent_format", latent_format) else: print(f"[PiD Injector] Warning: Model architecture '{architecture}' (file: '{active_model_file}') not explicitly mapped. Using default settings.") pid_pos_id = assembler._get_unique_id() pid_pos_node = assembler._get_node_template("PiDConditioning") pid_pos_node['inputs']['latent_format'] = latent_format pid_pos_node['inputs']['degrade_sigma'] = 0 pid_pos_node['inputs']['positive'] = [pos_text_encode_id, 0] pid_pos_node['inputs']['latent'] = [original_ksampler_id, 0] assembler.workflow[pid_pos_id] = pid_pos_node pid_neg_id = assembler._get_unique_id() pid_neg_node = assembler._get_node_template("PiDConditioning") pid_neg_node['inputs']['latent_format'] = latent_format pid_neg_node['inputs']['degrade_sigma'] = 0 pid_neg_node['inputs']['positive'] = [neg_text_encode_id, 0] pid_neg_node['inputs']['latent'] = [original_ksampler_id, 0] assembler.workflow[pid_neg_id] = pid_neg_node pid_unet_loader_id = assembler._get_unique_id() pid_unet_loader_node = assembler._get_node_template("UNETLoader") pid_unet_loader_node['inputs']['unet_name'] = unet_name pid_unet_loader_node['inputs']['weight_dtype'] = "default" assembler.workflow[pid_unet_loader_id] = pid_unet_loader_node orig_width = 1024 orig_height = 1024 original_latent_source_id = assembler.node_map.get('latent_source') if original_latent_source_id in assembler.workflow: node_inputs = assembler.workflow[original_latent_source_id].get('inputs', {}) if 'width' in node_inputs and 'height' in node_inputs: orig_width = node_inputs['width'] orig_height = node_inputs['height'] else: for node_data in assembler.workflow.values(): inputs = node_data.get('inputs', {}) if 'width' in inputs and 'height' in inputs and isinstance(inputs['width'], (int, float)) and isinstance(inputs['height'], (int, float)): if 256 <= inputs['width'] <= 4096 and 256 <= inputs['height'] <= 4096: orig_width = inputs['width'] orig_height = inputs['height'] break else: for node_data in assembler.workflow.values(): inputs = node_data.get('inputs', {}) if 'width' in inputs and 'height' in inputs and isinstance(inputs['width'], (int, float)) and isinstance(inputs['height'], (int, float)): if 256 <= inputs['width'] <= 4096 and 256 <= inputs['height'] <= 4096: orig_width = inputs['width'] orig_height = inputs['height'] break empty_latent_id = assembler._get_unique_id() empty_latent_node = assembler._get_node_template("EmptyChromaRadianceLatentImage") empty_latent_node['inputs']['width'] = int(orig_width) * 4 empty_latent_node['inputs']['height'] = int(orig_height) * 4 empty_latent_node['inputs']['batch_size'] = 1 if original_latent_source_id in assembler.workflow: orig_batch_size = assembler.workflow[original_latent_source_id]['inputs'].get('batch_size') or assembler.workflow[original_latent_source_id]['inputs'].get('amount') if orig_batch_size: empty_latent_node['inputs']['batch_size'] = orig_batch_size assembler.workflow[empty_latent_id] = empty_latent_node orig_seed = 0 if original_ksampler_id in assembler.workflow: orig_seed = assembler.workflow[original_ksampler_id]['inputs'].get('seed', 0) if orig_seed == -1: orig_seed = random.randint(0, 2**32 - 1) else: orig_seed = (orig_seed + 1) % (2**32) new_ksampler_id = assembler._get_unique_id() new_ksampler_node = assembler._get_node_template("KSampler") new_ksampler_node['inputs']['seed'] = orig_seed new_ksampler_node['inputs']['steps'] = 4 new_ksampler_node['inputs']['cfg'] = 1 new_ksampler_node['inputs']['sampler_name'] = "lcm" new_ksampler_node['inputs']['scheduler'] = "simple" new_ksampler_node['inputs']['denoise'] = 1.0 new_ksampler_node['inputs']['model'] = [pid_unet_loader_id, 0] new_ksampler_node['inputs']['positive'] = [pid_pos_id, 0] new_ksampler_node['inputs']['negative'] = [pid_neg_id, 0] new_ksampler_node['inputs']['latent_image'] = [empty_latent_id, 0] assembler.workflow[new_ksampler_id] = new_ksampler_node pid_vae_loader_id = assembler._get_unique_id() pid_vae_loader_node = assembler._get_node_template("VAELoader") pid_vae_loader_node['inputs']['vae_name'] = "pixel_space" assembler.workflow[pid_vae_loader_id] = pid_vae_loader_node pid_vae_decode_id = assembler._get_unique_id() pid_vae_decode_node = assembler._get_node_template("VAEDecode") pid_vae_decode_node['inputs']['samples'] = [new_ksampler_id, 0] pid_vae_decode_node['inputs']['vae'] = [pid_vae_loader_id, 0] assembler.workflow[pid_vae_decode_id] = pid_vae_decode_node if original_vae_decode_id: for node_id, node_data in assembler.workflow.items(): if 'inputs' in node_data: for input_name, input_val in list(node_data['inputs'].items()): if isinstance(input_val, list) and len(input_val) == 2: if input_val[0] == original_vae_decode_id: node_data['inputs'][input_name] = [pid_vae_decode_id, 0] if original_vae_loader_id in assembler.workflow: del assembler.workflow[original_vae_loader_id] if original_vae_decode_id in assembler.workflow: del assembler.workflow[original_vae_decode_id] print("[PiD Injector] Successfully injected PiD pipeline and replaced VAE decode/loader.")