| import os | |
| import yaml | |
| import random | |
def load_pid_config():
    """Load the PiD model configuration file.

    The file is ``yaml/pid.yaml`` inside the project root, where the project
    root is the parent of the directory that contains this module.

    Returns:
        The document parsed by ``yaml.safe_load``, or ``{}`` when that result
        is falsy (for example an empty file).

    Raises:
        Whatever ``open`` / ``yaml.safe_load`` raise (e.g. a missing file).
    """
    this_file = os.path.abspath(__file__)
    root_dir = os.path.dirname(os.path.dirname(this_file))
    config_file = os.path.join(root_dir, 'yaml', 'pid.yaml')
    with open(config_file, 'r', encoding='utf-8') as handle:
        parsed = yaml.safe_load(handle)
    return parsed if parsed else {}
def load_model_config():
    """Load the model list configuration file.

    Reads ``yaml/model_list.yaml`` under the project root (the parent of this
    module's own directory).

    Returns:
        The parsed YAML document, or ``{}`` when the parse result is falsy.

    Raises:
        Whatever ``open`` / ``yaml.safe_load`` raise (e.g. a missing file).
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    yaml_dir = os.path.join(os.path.dirname(module_dir), 'yaml')
    with open(os.path.join(yaml_dir, 'model_list.yaml'), 'r', encoding='utf-8') as stream:
        document = yaml.safe_load(stream)
    return document or {}
def _first_node_of_class(workflow, class_type):
    """Return the id of the first node whose ``class_type`` is *class_type*, or None."""
    for node_id, node_data in workflow.items():
        if node_data.get('class_type') == class_type:
            return node_id
    return None


def _is_link(value):
    """Return True if *value* has the link shape this module reads and writes.

    That shape is a 2-item list whose first item is a node id, e.g.
    ``[node_id, 0]``.
    """
    return isinstance(value, list) and len(value) == 2


def _rewire_links(workflow, old_id, new_id):
    """Point every input linked to node *old_id* at output 0 of node *new_id*, in place."""
    for node_data in workflow.values():
        inputs = node_data.get('inputs')
        if not inputs:
            continue
        for input_name, input_val in list(inputs.items()):
            if _is_link(input_val) and input_val[0] == old_id:
                inputs[input_name] = [new_id, 0]


def _is_referenced(workflow, node_id):
    """Return True if any node input in *workflow* links to *node_id*."""
    return any(
        _is_link(input_val) and input_val[0] == node_id
        for node_data in workflow.values()
        for input_val in (node_data.get('inputs') or {}).values()
    )


def _load_pid_settings():
    """Read PiD model settings via ``load_pid_config``.

    Returns:
        ``(architectures_settings, default_settings)``.
        ``architectures_settings`` maps every architecture listed under a
        ``PiD`` item to ``{"unet_name": <filepath>, "latent_format": ...}``.
        ``default_settings`` is the ``flux1`` entry if one exists, else a
        built-in Flux default.

    A config that fails to load is reported and treated as empty.
    """
    pid_config = {}
    try:
        pid_config = load_pid_config() or {}
    except Exception as e:
        print(f"Error loading PiD config: {e}")
    if not isinstance(pid_config, dict):
        pid_config = {}
    architectures_settings = {}
    default_settings = {"unet_name": "pid_flux1_1024_to_4096_4step_mxfp8.safetensors", "latent_format": "flux"}
    # `or []` covers keys that are present in the YAML but have no value (None).
    for item in pid_config.get("PiD") or []:
        settings = {"unet_name": item.get("filepath"), "latent_format": item.get("latent_format")}
        for arch in item.get("architectures") or []:
            architectures_settings[arch] = dict(settings)
            if arch == "flux1":
                default_settings = dict(settings)
    return architectures_settings, default_settings


def _find_prompt_ids(assembler):
    """Return ``(positive_id, negative_id)`` for the original prompt encoders.

    ``node_map`` entries are used first. Missing ones are filled from
    CLIPTextEncode nodes whose ``_meta.title`` contains 'Positive' / 'Negative'.
    """
    pos_id = assembler.node_map.get('pos_prompt')
    neg_id = assembler.node_map.get('neg_prompt')
    if not pos_id or not neg_id:
        for node_id, node_data in assembler.workflow.items():
            if node_data.get('class_type') != 'CLIPTextEncode':
                continue
            title = node_data.get('_meta', {}).get('title', '')
            if 'Positive' in title:
                if not pos_id:
                    pos_id = node_id
            elif 'Negative' in title:
                if not neg_id:
                    neg_id = node_id
    return pos_id, neg_id


def _prompt_text(workflow, node_id):
    """Return the ``text`` input of node *node_id*, or "" if the node is unknown."""
    if node_id and node_id in workflow:
        return workflow[node_id].get('inputs', {}).get('text', '')
    return ""


def _add_text_encode(assembler, text, clip_id):
    """Add a CLIPTextEncode node fed by *clip_id* and return its id."""
    node_id = assembler._get_unique_id()
    node = assembler._get_node_template("CLIPTextEncode")
    node['inputs']['text'] = text
    node['inputs']['clip'] = [clip_id, 0]
    assembler.workflow[node_id] = node
    return node_id


def _add_pid_conditioning(assembler, conditioning_id, latent_id, latent_format):
    """Add a PiDConditioning node and return its id.

    The conditioning always goes into the input named 'positive'. The
    negative branch uses the same input name; presumably the node exposes
    only one conditioning input (confirm against the node definition).
    """
    node_id = assembler._get_unique_id()
    node = assembler._get_node_template("PiDConditioning")
    node['inputs']['latent_format'] = latent_format
    node['inputs']['degrade_sigma'] = 0
    node['inputs']['positive'] = [conditioning_id, 0]
    node['inputs']['latent'] = [latent_id, 0]
    assembler.workflow[node_id] = node
    return node_id


def _find_active_model_file(workflow):
    """Return the first non-empty UNETLoader/CheckpointLoaderSimple model name, or None."""
    for node_data in workflow.values():
        class_type = node_data.get('class_type')
        if class_type == 'UNETLoader':
            key = 'unet_name'
        elif class_type == 'CheckpointLoaderSimple':
            key = 'ckpt_name'
        else:
            continue
        model_file = node_data.get('inputs', {}).get(key)
        if model_file:
            return model_file
    return None


def _lookup_listed_architecture(model_file):
    """Return the model_list.yaml ``Checkpoints`` key whose entry references *model_file*.

    A match is either the entry's ``path`` or one of its ``components``
    values. Returns None if nothing matches. Errors propagate to the caller.
    """
    checkpoints = load_model_config().get("Checkpoints", {})
    for arch_name, arch_data in checkpoints.items():
        for model_entry in arch_data.get("models") or []:
            if model_entry.get('path') == model_file:
                return arch_name
            components = model_entry.get('components') or {}
            if model_file in components.values():
                return arch_name
    return None


def _guess_architecture_from_filename(model_file, architectures_settings):
    """Fuzzy-match *model_file* against the known PiD architecture names.

    Separators are ignored on both sides. Longer names are tried first, so
    that a more specific architecture wins. The '-image', '-i1' and '-kv'
    suffixes are optional when matching.
    """
    squashed = model_file.lower().replace("-", "").replace("_", "").replace(".", "")
    for arch in sorted(architectures_settings, key=len, reverse=True):
        candidates = [arch]
        for suffix in ("-image", "-i1", "-kv"):
            if suffix in arch:
                candidates.append(arch.replace(suffix, ""))
        if any(cand.replace("-", "").replace(".", "") in squashed for cand in candidates):
            return arch
    return None


def _detect_architecture(model_file, architectures_settings):
    """Resolve the PiD architecture key for *model_file*, or None if unknown."""
    if not model_file:
        return None
    architecture = None
    try:
        architecture = _lookup_listed_architecture(model_file)
    except Exception as e:
        print(f"Error looking up model architecture in PiD injector: {e}")
    if architecture:
        # Normalise e.g. "Flux.1" -> "flux1" to match pid.yaml architecture keys.
        return str(architecture).lower().replace(" ", "-").replace(".", "")
    if not isinstance(model_file, str):
        return None
    return _guess_architecture_from_filename(model_file, architectures_settings)


def _find_image_size(workflow, latent_source_id, default=(1024, 1024)):
    """Return the ``(width, height)`` of the original generation.

    The latent source node's numeric ``width``/``height`` are preferred.
    Linked (non-numeric) values are skipped, because the caller multiplies
    the result. Otherwise the first node with numeric width and height, both
    in [256, 4096], is used. Falls back to *default*.
    """
    numeric = (int, float)
    if latent_source_id in workflow:
        inputs = workflow[latent_source_id].get('inputs', {})
        width, height = inputs.get('width'), inputs.get('height')
        if isinstance(width, numeric) and isinstance(height, numeric):
            return width, height
    for node_data in workflow.values():
        inputs = node_data.get('inputs') or {}
        width, height = inputs.get('width'), inputs.get('height')
        if (isinstance(width, numeric) and isinstance(height, numeric)
                and 256 <= width <= 4096 and 256 <= height <= 4096):
            return width, height
    return default


def _next_seed(seed):
    """Derive the PiD sampler seed from the original sampler's seed.

    Rules:
    - -1 becomes a random 32-bit seed.
    - A numeric seed becomes ``(seed + 1) mod 2**32``.
    - A linked seed (list) is reused as a copy of the same link.
    - None counts as 0.
    """
    if isinstance(seed, list):
        return list(seed)
    seed = int(seed if seed is not None else 0)
    if seed == -1:
        return random.randint(0, 2**32 - 1)
    return (seed + 1) % (2**32)


def inject(assembler, chain_definition, chain_items):
    """Append a PiD upscaling stage after the workflow's KSampler, in place.

    The following nodes are added to ``assembler.workflow``:
    - a CLIPLoader (gemma text encoder, type "pixeldit");
    - positive and negative CLIPTextEncode nodes carrying the original prompt
      text;
    - two PiDConditioning nodes fed by the original KSampler's output;
    - a UNETLoader for the PiD model chosen for the detected architecture;
    - an EmptyChromaRadianceLatentImage at 4x the original width/height;
    - a 4-step "lcm" KSampler;
    - a "pixel_space" VAELoader and a VAEDecode.

    Every input that linked to the original VAEDecode is re-pointed at the
    new decode. The original VAEDecode is then removed. The original VAE
    loader is removed only if no remaining node still links to it.

    Args:
        assembler: object exposing ``workflow`` and ``node_map`` dicts plus
            ``_get_unique_id()`` and ``_get_node_template(class_type)``.
        chain_definition: mapping; ``'ksampler_node'`` names the ``node_map``
            key of the KSampler to extend (default ``'ksampler'``).
        chain_items: only its truthiness is checked; when empty nothing
            happens.
    """
    if not chain_items:
        return
    architectures_settings, default_settings = _load_pid_settings()

    ksampler_name = chain_definition.get('ksampler_node', 'ksampler')
    if ksampler_name not in assembler.node_map:
        print(f"Warning: [PiD Injector] KSampler node '{ksampler_name}' not found. Skipping.")
        return
    original_ksampler_id = assembler.node_map[ksampler_name]
    original_vae_loader_id = (assembler.node_map.get('vae_loader')
                              or _first_node_of_class(assembler.workflow, 'VAELoader'))
    original_vae_decode_id = (assembler.node_map.get('vae_decode')
                              or _first_node_of_class(assembler.workflow, 'VAEDecode'))
    pos_prompt_id, neg_prompt_id = _find_prompt_ids(assembler)
    pos_text = _prompt_text(assembler.workflow, pos_prompt_id)
    neg_text = _prompt_text(assembler.workflow, neg_prompt_id)

    # Text encoder used by the PiD model; the original prompts are re-encoded with it.
    clip_loader_id = assembler._get_unique_id()
    clip_loader_node = assembler._get_node_template("CLIPLoader")
    clip_loader_node['inputs']['clip_name'] = "gemma_2_2b_it_elm_fp8_scaled.safetensors"
    clip_loader_node['inputs']['type'] = "pixeldit"
    clip_loader_node['inputs']['device'] = "default"
    assembler.workflow[clip_loader_id] = clip_loader_node
    pos_text_encode_id = _add_text_encode(assembler, pos_text, clip_loader_id)
    neg_text_encode_id = _add_text_encode(assembler, neg_text, clip_loader_id)

    # Must run before the PiD UNETLoader is added, so only the original model is seen.
    active_model_file = _find_active_model_file(assembler.workflow)
    architecture = _detect_architecture(active_model_file, architectures_settings)
    unet_name = default_settings.get("unet_name")
    latent_format = default_settings.get("latent_format")
    if architecture in architectures_settings:
        arch_config = architectures_settings[architecture]
        # `or` so that an entry missing filepath/latent_format keeps the defaults.
        unet_name = arch_config.get("unet_name") or unet_name
        latent_format = arch_config.get("latent_format") or latent_format
    else:
        print(f"[PiD Injector] Warning: Model architecture '{architecture}' (file: '{active_model_file}') not explicitly mapped. Using default settings.")

    pid_pos_id = _add_pid_conditioning(assembler, pos_text_encode_id, original_ksampler_id, latent_format)
    pid_neg_id = _add_pid_conditioning(assembler, neg_text_encode_id, original_ksampler_id, latent_format)

    pid_unet_loader_id = assembler._get_unique_id()
    pid_unet_loader_node = assembler._get_node_template("UNETLoader")
    pid_unet_loader_node['inputs']['unet_name'] = unet_name
    pid_unet_loader_node['inputs']['weight_dtype'] = "default"
    assembler.workflow[pid_unet_loader_id] = pid_unet_loader_node

    original_latent_source_id = assembler.node_map.get('latent_source')
    orig_width, orig_height = _find_image_size(assembler.workflow, original_latent_source_id)
    empty_latent_id = assembler._get_unique_id()
    empty_latent_node = assembler._get_node_template("EmptyChromaRadianceLatentImage")
    empty_latent_node['inputs']['width'] = int(orig_width) * 4
    empty_latent_node['inputs']['height'] = int(orig_height) * 4
    empty_latent_node['inputs']['batch_size'] = 1
    if original_latent_source_id in assembler.workflow:
        source_inputs = assembler.workflow[original_latent_source_id].get('inputs', {})
        orig_batch_size = source_inputs.get('batch_size') or source_inputs.get('amount')
        if orig_batch_size:
            empty_latent_node['inputs']['batch_size'] = orig_batch_size
    assembler.workflow[empty_latent_id] = empty_latent_node

    orig_seed = 0
    if original_ksampler_id in assembler.workflow:
        orig_seed = assembler.workflow[original_ksampler_id].get('inputs', {}).get('seed', 0)
    new_ksampler_id = assembler._get_unique_id()
    new_ksampler_node = assembler._get_node_template("KSampler")
    new_ksampler_node['inputs']['seed'] = _next_seed(orig_seed)
    new_ksampler_node['inputs']['steps'] = 4
    new_ksampler_node['inputs']['cfg'] = 1
    new_ksampler_node['inputs']['sampler_name'] = "lcm"
    new_ksampler_node['inputs']['scheduler'] = "simple"
    new_ksampler_node['inputs']['denoise'] = 1.0
    new_ksampler_node['inputs']['model'] = [pid_unet_loader_id, 0]
    new_ksampler_node['inputs']['positive'] = [pid_pos_id, 0]
    new_ksampler_node['inputs']['negative'] = [pid_neg_id, 0]
    new_ksampler_node['inputs']['latent_image'] = [empty_latent_id, 0]
    assembler.workflow[new_ksampler_id] = new_ksampler_node

    pid_vae_loader_id = assembler._get_unique_id()
    pid_vae_loader_node = assembler._get_node_template("VAELoader")
    pid_vae_loader_node['inputs']['vae_name'] = "pixel_space"
    assembler.workflow[pid_vae_loader_id] = pid_vae_loader_node
    pid_vae_decode_id = assembler._get_unique_id()
    pid_vae_decode_node = assembler._get_node_template("VAEDecode")
    pid_vae_decode_node['inputs']['samples'] = [new_ksampler_id, 0]
    pid_vae_decode_node['inputs']['vae'] = [pid_vae_loader_id, 0]
    assembler.workflow[pid_vae_decode_id] = pid_vae_decode_node

    if original_vae_decode_id:
        _rewire_links(assembler.workflow, original_vae_decode_id, pid_vae_decode_id)
    # Delete the decode first so its own link to the loader does not count as a use.
    if original_vae_decode_id in assembler.workflow:
        del assembler.workflow[original_vae_decode_id]
    if original_vae_loader_id in assembler.workflow:
        # The "VAE loader" may be a checkpoint loader or be shared with e.g. a
        # VAEEncode; deleting a node still in use would leave dangling links.
        if _is_referenced(assembler.workflow, original_vae_loader_id):
            print(f"[PiD Injector] Keeping VAE loader node '{original_vae_loader_id}' because other nodes still use it.")
        else:
            del assembler.workflow[original_vae_loader_id]
    print("[PiD Injector] Successfully injected PiD pipeline and replaced VAE decode/loader.")