# File: ImageGen5/core/pipelines/sd_image_pipeline.py
# Hugging Face file-viewer header (scrape residue, not part of the module):
#   author: RioShiina
#   commit: 5484594 (verified) - "Upload folder using huggingface_hub"
#   size: 12.8 kB
import os
import random
import shutil
import torch
import gradio as gr
from PIL import Image
from typing import List, Dict, Any
from .base_pipeline import BasePipeline
from core.settings import *
from utils.app_utils import sanitize_prompt
from core.workflow_assembler import WorkflowAssembler
from .workflow_executor import WorkflowExecutor
from .pipeline_input_processor import process_pipeline_inputs
class SdImagePipeline(BasePipeline):
    """Pipeline that assembles an image-generation workflow and executes it.

    ``run`` prepares the inputs, makes sure the required model files are
    downloaded, assembles the workflow from ``sd_unified_recipe.yaml``, hands
    the GPU part to ``_gpu_logic`` and finally writes the resulting images to
    ``OUTPUT_DIR`` as PNG files with embedded generation metadata.

    NOTE(review): ``self.model_manager`` and ``self._execute_gpu_logic`` are
    not defined in this class; presumably they come from ``BasePipeline``.
    """

    # Maximum number of ``gen_*.png`` files kept when pruning the output dir.
    _MAX_SAVED_OUTPUTS = 50

    # Latent type -> name of the empty-latent generator template. Any latent
    # type not listed here uses ``_DEFAULT_LATENT_GENERATOR_TEMPLATE``.
    _LATENT_GENERATOR_TEMPLATES = {
        'sd3_latent': "EmptySD3LatentImage",
        'chroma_radiance_latent': "EmptyChromaRadianceLatentImage",
        'hunyuan_latent': "EmptyHunyuanImageLatent",
    }
    _DEFAULT_LATENT_GENERATOR_TEMPLATE = "EmptyLatentImage"

    def get_required_models(self, model_display_name: str, **kwargs) -> List[str]:
        """Return the model identifiers needed for ``model_display_name``.

        Args:
            model_display_name: Key looked up in ``ALL_MODEL_MAP``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            For a component-based entry (second element of the map entry is a
            dict), every truthy component value except ``"pixel_space"``.
            Otherwise (unknown model or single-path entry) a one-element list
            containing ``model_display_name`` itself.
        """
        model_info = ALL_MODEL_MAP.get(model_display_name)
        if not model_info:
            return [model_display_name]

        path_or_components = model_info[1]
        if isinstance(path_or_components, dict):
            return [v for v in path_or_components.values() if v and v != "pixel_space"]
        return [model_display_name]

    @staticmethod
    def _format_parameters(ui_inputs: Dict, model_display_name: str, seed: int, loras_string: str) -> str:
        """Build the ``parameters`` metadata string for one generated image.

        Args:
            ui_inputs: UI input mapping; prompts, sampler settings, task type
                and optional width/height/clip_skip are read from it.
            model_display_name: Written as ``Base Model``.
            seed: Seed recorded for this particular image.
            loras_string: Pre-formatted LoRA summary; appended when non-empty.

        Returns:
            The stripped parameters string.
        """
        width_for_meta = ui_inputs.get('width', 'N/A')
        height_for_meta = ui_inputs.get('height', 'N/A')

        params_string = f"{ui_inputs['positive_prompt']}\nNegative prompt: {ui_inputs['negative_prompt']}\n"
        params_string += (
            f"Steps: {ui_inputs['num_inference_steps']}, Sampler: {ui_inputs['sampler']}, "
            f"Scheduler: {ui_inputs['scheduler']}, CFG scale: {ui_inputs['guidance_scale']}, "
            f"Seed: {seed}, Size: {width_for_meta}x{height_for_meta}, Base Model: {model_display_name}"
        )
        if ui_inputs['task_type'] != 'txt2img':
            params_string += f", Denoise: {ui_inputs['denoise']}"

        # ``run`` stores clip_skip negated (default -1), so compare the
        # magnitude: a plain ``!= 1`` check would be true for the default and
        # would always emit "Clip skip: 1".
        clip_skip = ui_inputs.get('clip_skip')
        if clip_skip and abs(clip_skip) != 1:
            params_string += f", Clip skip: {abs(clip_skip)}"

        if loras_string:
            params_string += f", {loras_string}"
        return params_string.strip()

    def _gpu_logic(self, ui_inputs: Dict, loras_string: str, workflow: Dict[str, Any], assembler: WorkflowAssembler, progress=gr.Progress(track_tqdm=True)):
        """Execute the assembled workflow and convert its output to PIL images.

        Args:
            ui_inputs: UI input mapping (already normalised by ``run``).
            loras_string: LoRA summary appended to each image's metadata.
            workflow: Assembled workflow passed to ``WorkflowExecutor``.
            assembler: Unused here; kept so the call signature stays stable.
            progress: Progress callback.

        Returns:
            A list of ``PIL.Image.Image`` objects, one per entry along the
            first dimension of the executor's output, each carrying a
            ``parameters`` entry in ``.info``.
        """
        model_display_name = ui_inputs['model_display_name']

        progress(0.4, desc="Executing workflow...")

        initial_objects = {}
        decoded_images_tensor = WorkflowExecutor.execute_workflow(workflow, initial_objects=initial_objects)

        start_seed = ui_inputs['seed'] if ui_inputs['seed'] != -1 else random.randint(0, 2**64 - 1)

        output_images = []
        for i in range(decoded_images_tensor.shape[0]):
            img_tensor = decoded_images_tensor[i]
            # Assumes float pixel values nominally in [0, 1] - TODO confirm.
            # Clamp before scaling: casting out-of-range floats to uint8 wraps
            # around (or is platform dependent) and shows up as pixel noise.
            pixels = img_tensor.cpu().clamp(0.0, 1.0).numpy() * 255.0
            pil_image = Image.fromarray(pixels.astype("uint8"))

            # Images in a batch are labelled with consecutive seeds.
            current_seed = start_seed + i
            pil_image.info = {
                'parameters': self._format_parameters(ui_inputs, model_display_name, current_seed, loras_string)
            }
            output_images.append(pil_image)

        return output_images

    @staticmethod
    def _add_pid_models(required_models: List[str], workflow_model_type: str) -> None:
        """Append the PiD model files to ``required_models`` in place.

        Reads ``yaml/pid.yaml`` (three directories above this file) and picks
        the ``filepath`` of the first ``PiD`` entry whose ``architectures``
        contains ``workflow_model_type``. If the config cannot be read or
        parsed, the error is printed and the built-in default name is used.
        """
        import yaml

        pid_config_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
            'yaml',
            'pid.yaml',
        )
        pid_unet_name = "pid_flux1_1024_to_4096_4step_mxfp8.safetensors"
        try:
            with open(pid_config_path, 'r', encoding='utf-8') as f:
                pid_config = yaml.safe_load(f) or {}
            pid_items = pid_config.get("PiD", [])
            for item in pid_items:
                archs = item.get("architectures", [])
                if workflow_model_type in archs:
                    pid_unet_name = item.get("filepath")
                    break
        except Exception as e:
            # Deliberately best-effort: a broken config falls back to the
            # default PiD model instead of aborting the generation.
            print(f"Error loading PiD config for download: {e}")

        if pid_unet_name not in required_models:
            required_models.append(pid_unet_name)
        if "gemma_2_2b_it_elm_fp8_scaled.safetensors" not in required_models:
            required_models.append("gemma_2_2b_it_elm_fp8_scaled.safetensors")

    def _save_results(self, results, workflow: Dict[str, Any]) -> list:
        """Write generated images to ``OUTPUT_DIR`` and return their paths.

        Before saving, the oldest ``gen_*.png`` files are removed so that at
        most ``_MAX_SAVED_OUTPUTS`` pre-existing files remain (pruning errors
        are printed, not raised). Each PIL image is saved as PNG with a
        ``prompt`` text chunk (the workflow as JSON) and, when present, its
        ``parameters`` string. Entries that are not PIL images are passed
        through unchanged.
        """
        import json
        import glob
        from PIL import PngImagePlugin

        prompt_json = json.dumps(workflow)
        out_dir = os.path.abspath(OUTPUT_DIR)
        os.makedirs(out_dir, exist_ok=True)

        try:
            existing_files = glob.glob(os.path.join(out_dir, "gen_*.png"))
            existing_files.sort(key=os.path.getmtime)
            excess = max(len(existing_files) - self._MAX_SAVED_OUTPUTS, 0)
            for stale_file in existing_files[:excess]:
                os.remove(stale_file)
        except Exception as e:
            print(f"Warning: Failed to cleanup output dir: {e}")

        final_results = []
        for img in results:
            if not isinstance(img, Image.Image):
                final_results.append(img)
                continue

            metadata = PngImagePlugin.PngInfo()
            params_string = img.info.get("parameters", "")
            if params_string:
                metadata.add_text("parameters", params_string)
            metadata.add_text("prompt", prompt_json)

            # Redraw the random name on collision so an existing output is
            # never silently overwritten.
            while True:
                filename = f"gen_{random.randint(1000000, 9999999)}.png"
                filepath = os.path.join(out_dir, filename)
                if not os.path.exists(filepath):
                    break

            img.save(filepath, "PNG", pnginfo=metadata)
            final_results.append(filepath)
        return final_results

    @staticmethod
    def _cleanup_temp_files(temp_files) -> None:
        """Remove temporary files, reporting (not raising) removal failures."""
        for temp_file in temp_files:
            if temp_file and os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                except OSError as e:
                    # Cleanup runs in ``finally``; raising here would mask the
                    # real error or throw away finished results.
                    print(f"Warning: Failed to remove temp file {temp_file}: {e}")
                else:
                    print(f"✅ Cleaned up temp file: {temp_file}")

    def run(self, ui_inputs: Dict, progress):
        """Run a full generation request.

        Args:
            ui_inputs: UI input mapping. It is mutated in place: prompts are
                sanitised, ``clip_skip`` is replaced by its negated integer
                value (``-1`` when absent/None) and a ``seed`` of ``-1`` is
                replaced by a random 32-bit value.
            progress: Progress callback, called as ``progress(fraction, desc=...)``.

        Returns:
            A list with one entry per result: the saved PNG file path for PIL
            images, or the original object for anything else.

        Raises:
            KeyError: If a required key is missing from ``ui_inputs``, from
                the processed inputs, or the model is not in ``ALL_MODEL_MAP``.
        """
        progress(0, desc="Preparing models...")

        task_type = ui_inputs['task_type']
        model_display_name = ui_inputs['model_display_name']

        model_type = MODEL_TYPE_MAP.get(model_display_name, 'sdxl')
        architectures_dict = ARCHITECTURES_CONFIG.get('architectures', {})
        # Fall back to a normalised form of the model type when the
        # architecture config has no explicit "model_type" entry.
        workflow_model_type = architectures_dict.get(model_type, {}).get(
            "model_type", model_type.lower().replace(" ", "").replace(".", "")
        )

        ui_inputs['positive_prompt'] = sanitize_prompt(ui_inputs.get('positive_prompt', ''))
        ui_inputs['negative_prompt'] = sanitize_prompt(ui_inputs.get('negative_prompt', ''))

        # clip_skip is stored as a negative integer from here on.
        if ui_inputs.get('clip_skip') is not None:
            ui_inputs['clip_skip'] = -int(ui_inputs['clip_skip'])
        else:
            ui_inputs['clip_skip'] = -1

        required_models = self.get_required_models(model_display_name=model_display_name)

        is_pid_enabled = (ui_inputs.get('pid_settings', 'OFF') == 'ON' and task_type == 'txt2img')
        if is_pid_enabled:
            self._add_pid_models(required_models, workflow_model_type)

        self.model_manager.ensure_models_downloaded(required_models, progress=progress)

        temp_files_to_clean = []
        try:
            processed = process_pipeline_inputs(ui_inputs, progress, workflow_model_type)
            temp_files_to_clean.extend(processed["temp_files_to_clean"])

            active_loras_for_gpu = processed["active_loras_for_gpu"]
            active_loras_for_meta = processed["active_loras_for_meta"]
            active_controlnets = processed["active_controlnets"]
            active_anima_controlnets = processed["active_anima_controlnets"]
            active_diffsynth_controlnets = processed["active_diffsynth_controlnets"]
            active_ipadapters = processed["active_ipadapters"]
            active_flux1_ipadapters = processed["active_flux1_ipadapters"]
            active_sd3_ipadapters = processed["active_sd3_ipadapters"]
            active_styles = processed["active_styles"]
            active_reference_latents = processed["active_reference_latents"]
            active_hidream_o1_reference = processed["active_hidream_o1_reference"]
            active_conditioning = processed["active_conditioning"]

            loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""

            progress(0.8, desc="Assembling workflow...")

            if ui_inputs.get('seed') == -1:
                ui_inputs['seed'] = random.randint(0, 2**32 - 1)

            model_info = ALL_MODEL_MAP[model_display_name]
            path_or_components = model_info[1]
            latent_type = model_info[3] if len(model_info) > 3 and model_info[3] else 'latent'
            latent_generator_template = self._LATENT_GENERATOR_TEMPLATES.get(
                latent_type, self._DEFAULT_LATENT_GENERATOR_TEMPLATE
            )

            dynamic_values = {
                'task_type': ui_inputs['task_type'],
                'model_type': workflow_model_type,
                'latent_type': latent_type,
                'latent_generator_template': latent_generator_template,
            }

            recipe_path = os.path.join(os.path.dirname(__file__), "workflow_recipes", "sd_unified_recipe.yaml")
            assembler = WorkflowAssembler(recipe_path, dynamic_values=dynamic_values)

            # A single empty entry enables the smoothing chain for this model.
            hidream_o1_smoothing_data = []
            if workflow_model_type == 'hidream-o1' and model_display_name == "HiDream-O1-Image":
                hidream_o1_smoothing_data.append({})

            workflow_inputs = {
                **ui_inputs,
                "positive_prompt": ui_inputs['positive_prompt'],
                "negative_prompt": ui_inputs['negative_prompt'],
                "seed": ui_inputs['seed'],
                "steps": ui_inputs['num_inference_steps'],
                "cfg": ui_inputs['guidance_scale'],
                "sampler_name": ui_inputs['sampler'],
                "scheduler": ui_inputs['scheduler'],
                "batch_size": ui_inputs['batch_size'],
                "clip_skip": ui_inputs['clip_skip'],
                "denoise": ui_inputs['denoise'],
                "vae_name": ui_inputs.get('vae_name'),
                "guidance": ui_inputs.get('guidance', 3.5),
                "lora_chain": active_loras_for_gpu,
                # Regular controlnets are dropped whenever Anima controlnets are active.
                "controlnet_chain": active_controlnets if not active_anima_controlnets else [],
                "anima_controlnet_lllite_chain": active_anima_controlnets,
                "diffsynth_controlnet_chain": active_diffsynth_controlnets,
                "ipadapter_chain": active_ipadapters,
                "flux1_ipadapter_chain": active_flux1_ipadapters,
                "sd3_ipadapter_chain": active_sd3_ipadapters,
                "style_chain": active_styles,
                "conditioning_chain": active_conditioning,
                "reference_latent_chain": active_reference_latents,
                "hidream_o1_reference_chain": active_hidream_o1_reference,
                "vae_chain": [ui_inputs.get('vae_name')] if ui_inputs.get('vae_name') else [],
                "hidream_o1_smoothing_chain": hidream_o1_smoothing_data,
                "pid_chain": [ui_inputs.get('pid_settings', 'OFF')] if is_pid_enabled else [],
                "scheduler_width": ui_inputs.get('width', 1024),
                "scheduler_height": ui_inputs.get('height', 1024),
            }

            if isinstance(path_or_components, dict):
                workflow_inputs.update({
                    'unet_name': path_or_components.get('unet'),
                    'unet_uncond_name': path_or_components.get('unet_uncond'),
                    # A VAE chosen in the UI takes precedence over the model's own.
                    'vae_name': ui_inputs.get('vae_name') or path_or_components.get('vae'),
                    'clip_name': path_or_components.get('clip'),
                    'clip1_name': path_or_components.get('clip1'),
                    'clip2_name': path_or_components.get('clip2'),
                    'clip3_name': path_or_components.get('clip3'),
                    'clip4_name': path_or_components.get('clip4'),
                    'lora_name': path_or_components.get('lora'),
                })
            else:
                workflow_inputs['model_name'] = path_or_components

            if task_type == 'txt2img':
                workflow_inputs['width'] = ui_inputs['width']
                workflow_inputs['height'] = ui_inputs['height']

            workflow = assembler.assemble(workflow_inputs)

            progress(1.0, desc="All models ready. Requesting GPU for generation...")

            results = self._execute_gpu_logic(
                self._gpu_logic,
                duration=ui_inputs['zero_gpu_duration'],
                default_duration=60,
                task_name=f"ImageGen ({task_type})",
                ui_inputs=ui_inputs,
                loras_string=loras_string,
                workflow=workflow,
                assembler=assembler,
                progress=progress
            )

            results = self._save_results(results, workflow)
        finally:
            self._cleanup_temp_files(temp_files_to_clean)

        return results