ImageGen5 / ui /events /run_handlers.py
RioShiina's picture
Upload folder using huggingface_hub
7298549 verified
Raw
History Blame
6.71 kB
import gradio as gr
from core.generation_logic import generate_image_wrapper
def create_run_event(prefix: str, task_type: str, ui_components: dict):
    """Attach the click handler that launches generation for one UI tab.

    Collects the components registered in ``ui_components`` for ``prefix``,
    flattens them into the ``inputs`` list of the run button's click event,
    and, when the button fires, rebuilds a keyword dictionary from the
    positional values before passing it to ``generate_image_wrapper``.

    Args:
        prefix: Tag embedded in the component names of the tab being wired
            (both ``<name>_<prefix>`` and ``<prefix>_<name>`` spellings are
            looked up for most fields).
        task_type: Task identifier. It is forwarded to the handler through a
            ``gr.State`` and selects which task-specific inputs are added.
        ui_components: Mapping of component name to component object.

    Returns:
        None. The function only registers an event listener; if either the
        run button or the result gallery is missing, nothing is registered.

    Raises:
        KeyError: If ``base_model_<prefix>`` is absent from ``ui_components``.
    """

    def pick(primary: str, alternate: str):
        # A component may be registered under either naming scheme; prefer
        # the first spelling and fall back to the second.
        return ui_components.get(primary) or ui_components.get(alternate)

    # ---- Single-valued inputs (one component -> one keyword) --------------
    run_inputs_map = {
        'model_display_name': ui_components[f'base_model_{prefix}'],
        'positive_prompt': pick(f'prompt_{prefix}', f'{prefix}_positive_prompt'),
        'negative_prompt': pick(f'neg_prompt_{prefix}', f'{prefix}_negative_prompt'),
        'seed': pick(f'seed_{prefix}', f'{prefix}_seed'),
        'batch_size': pick(f'batch_size_{prefix}', f'{prefix}_batch_size'),
        'guidance_scale': pick(f'cfg_{prefix}', f'{prefix}_cfg'),
        'num_inference_steps': pick(f'steps_{prefix}', f'{prefix}_steps'),
        'sampler': pick(f'sampler_{prefix}', f'{prefix}_sampler_name'),
        'scheduler': pick(f'scheduler_{prefix}', f'{prefix}_scheduler'),
        'zero_gpu_duration': ui_components.get(f'zero_gpu_{prefix}'),
        'clip_skip': ui_components.get(f'clip_skip_{prefix}'),
        'guidance': ui_components.get(f'guidance_{prefix}'),
        'task_type': gr.State(task_type),
    }

    if ui_components.get(f'pid_settings_{prefix}'):
        run_inputs_map['pid_settings'] = ui_components[f'pid_settings_{prefix}']

    # img2img / inpaint do not take explicit dimensions from the UI.
    if task_type not in ['img2img', 'inpaint']:
        run_inputs_map['width'] = pick(f'width_{prefix}', f'{prefix}_width')
        run_inputs_map['height'] = pick(f'height_{prefix}', f'{prefix}_height')

    # Extra inputs that only exist for particular task types
    # (handler keyword -> component name).
    task_specific_map = {
        'img2img': {
            'img2img_image': f'input_image_{prefix}',
            'img2img_denoise': f'denoise_{prefix}',
        },
        'inpaint': {
            'inpaint_image_dict': f'input_image_dict_{prefix}',
            'grow_mask_by': f'grow_mask_by_{prefix}',
            'inpaint_denoise': f'denoise_{prefix}',
        },
        'outpaint': {
            'outpaint_image': f'input_image_{prefix}',
            'left': f'left_{prefix}',
            'top': f'top_{prefix}',
            'right': f'right_{prefix}',
            'bottom': f'bottom_{prefix}',
            'feathering': f'feathering_{prefix}',
        },
        'hires_fix': {
            'hires_image': f'input_image_{prefix}',
            'hires_upscaler': f'hires_upscaler_{prefix}',
            'hires_scale_by': f'hires_scale_by_{prefix}',
            'hires_denoise': f'denoise_{prefix}',
        },
    }
    for key, comp_name in task_specific_map.get(task_type, {}).items():
        if comp_name in ui_components:
            run_inputs_map[key] = ui_components[comp_name]

    run_inputs_map['vae_source'] = ui_components.get(f'vae_source_{prefix}')
    run_inputs_map['vae_id'] = ui_components.get(f'vae_id_{prefix}')
    run_inputs_map['vae_file'] = ui_components.get(f'vae_file_{prefix}')

    # Only components that actually exist are sent to the handler. The key
    # list and the component list are derived from the same filter so that
    # position i of the click arguments always belongs to scalar_keys[i].
    scalar_keys = [key for key, comp in run_inputs_map.items() if comp is not None]
    input_list_flat = [run_inputs_map[key] for key in scalar_keys]

    # ---- Multi-valued "chain" inputs (a list of components -> one keyword)
    # One ordered table drives both the flattening below and the unpacking in
    # create_ui_inputs_dict, so the two can never drift out of step.
    # Each entry: (keyword handed to the handler, stem of the
    # ``all_<stem>_components_flat_<prefix>`` entry in ui_components).
    # NOTE(review): 'sd3_ipadapter_chain' is the only keyword not ending in
    # '_data'; it is kept verbatim because its consumer is outside this
    # file -- confirm that the spelling is intentional.
    chain_specs = (
        ('lora_data', 'lora'),
        ('controlnet_data', 'controlnet'),
        ('anima_controlnet_lllite_data', 'anima_controlnet_lllite'),
        ('diffsynth_controlnet_data', 'diffsynth_controlnet'),
        ('ipadapter_data', 'ipadapter'),
        ('sd3_ipadapter_chain', 'sd3_ipadapter'),
        ('flux1_ipadapter_data', 'flux1_ipadapter'),
        ('style_data', 'style'),
        ('embedding_data', 'embedding'),
        ('conditioning_data', 'conditioning'),
        ('reference_latent_data', 'reference_latent'),
        ('hidream_o1_reference_data', 'hidream_o1_reference'),
    )

    # (keyword, start, stop) slices into the positional click arguments,
    # recorded at registration time because that is when the input list
    # handed to the click listener is fixed.
    chain_layout = []
    for dict_key, stem in chain_specs:
        components = ui_components.get(f'all_{stem}_components_flat_{prefix}', [])
        if not components:
            # Missing or empty chains contribute no inputs and no keyword.
            continue
        start = len(input_list_flat)
        input_list_flat.extend(components)
        chain_layout.append((dict_key, start, len(input_list_flat)))

    def create_ui_inputs_dict(*args):
        """Rebuild the keyword dictionary from the positional click values."""
        # zip stops at the shorter sequence, so only the leading scalar
        # values are consumed here; the rest belong to the chains.
        ui_dict = dict(zip(scalar_keys, args))
        for dict_key, start, stop in chain_layout:
            ui_dict[dict_key] = list(args[start:stop])
        return ui_dict

    run_btn = pick(f'run_{prefix}', f'{prefix}_run_button')
    res_gal = pick(f'result_{prefix}', f'{prefix}_output_gallery')
    if not (run_btn and res_gal):
        return

    # The handler is deliberately kept as a lambda with the same signature as
    # before so that whatever Gradio derives from the callable is unchanged.
    run_btn.click(
        fn=lambda *args, progress=gr.Progress(track_tqdm=True): generate_image_wrapper(create_ui_inputs_dict(*args), progress),
        inputs=input_list_flat,
        outputs=[res_gal]
    )