import gradio as gr
from core.generation_logic import generate_image_wrapper


# Optional component chains, as (key in the dict handed to
# generate_image_wrapper, ui_components key stem). The full ui_components key
# is f'{stem}_{prefix}'. This single table drives BOTH the order in which chain
# components are appended to the Gradio `inputs` list and the order in which
# their values are sliced back out of the handler's positional args, so the two
# can never drift apart.
_CHAIN_SPECS = (
    ('lora_data', 'all_lora_components_flat'),
    ('controlnet_data', 'all_controlnet_components_flat'),
    ('anima_controlnet_lllite_data', 'all_anima_controlnet_lllite_components_flat'),
    ('diffsynth_controlnet_data', 'all_diffsynth_controlnet_components_flat'),
    ('ipadapter_data', 'all_ipadapter_components_flat'),
    # NOTE(review): every other chain key ends in '_data'; this one does not.
    # Kept as-is because the consumer may read 'sd3_ipadapter_chain' — confirm
    # against generate_image_wrapper before renaming.
    ('sd3_ipadapter_chain', 'all_sd3_ipadapter_components_flat'),
    ('flux1_ipadapter_data', 'all_flux1_ipadapter_components_flat'),
    ('style_data', 'all_style_components_flat'),
    ('embedding_data', 'all_embedding_components_flat'),
    ('conditioning_data', 'all_conditioning_components_flat'),
    ('reference_latent_data', 'all_reference_latent_components_flat'),
    ('hidream_o1_reference_data', 'all_hidream_o1_reference_components_flat'),
)


def _first_component(ui_components: dict, *names: str):
    """Return the first entry of ``ui_components`` found under one of ``names``.

    Names are tried in order; the first value that is not ``None`` wins, and
    ``None`` is returned when none of them resolves. An explicit ``is not None``
    test is used (rather than ``a or b``) so a component object is never skipped
    merely because it evaluates as falsy.
    """
    for name in names:
        component = ui_components.get(name)
        if component is not None:
            return component
    return None


def create_run_event(prefix: str, task_type: str, ui_components: dict):
    """Wire a tab's run button to ``generate_image_wrapper``.

    Gathers the tab's input components from ``ui_components`` (most of them are
    looked up under either a ``<name>_<prefix>`` or a ``<prefix>_<name>`` key),
    flattens them into the Gradio ``inputs`` list, and registers a click handler
    that turns the positional values Gradio passes back into a keyword dict
    before calling ``generate_image_wrapper(ui_dict, progress)``.

    Args:
        prefix: Tab identifier used to build every component key.
        task_type: Forwarded to the wrapper as ``'task_type'`` (via ``gr.State``).
            It also selects the task-specific inputs (``img2img``, ``inpaint``,
            ``outpaint``, ``hires_fix``) and, for anything other than
            ``img2img``/``inpaint``, adds ``width``/``height``.
        ui_components: Mapping of component key -> Gradio component.
            ``base_model_<prefix>`` is required; every other input is optional
            and is left out of the dict entirely when absent or ``None``.

    Raises:
        KeyError: If ``ui_components`` has no ``base_model_<prefix>`` entry.

    Returns:
        None. If no run button or no result gallery is found, nothing is wired.
    """
    def pick(*names):
        return _first_component(ui_components, *names)

    run_inputs_map = {
        'model_display_name': ui_components[f'base_model_{prefix}'],
        'positive_prompt': pick(f'prompt_{prefix}', f'{prefix}_positive_prompt'),
        'negative_prompt': pick(f'neg_prompt_{prefix}', f'{prefix}_negative_prompt'),
        'seed': pick(f'seed_{prefix}', f'{prefix}_seed'),
        'batch_size': pick(f'batch_size_{prefix}', f'{prefix}_batch_size'),
        'guidance_scale': pick(f'cfg_{prefix}', f'{prefix}_cfg'),
        'num_inference_steps': pick(f'steps_{prefix}', f'{prefix}_steps'),
        'sampler': pick(f'sampler_{prefix}', f'{prefix}_sampler_name'),
        'scheduler': pick(f'scheduler_{prefix}', f'{prefix}_scheduler'),
        'zero_gpu_duration': ui_components.get(f'zero_gpu_{prefix}'),
        'clip_skip': ui_components.get(f'clip_skip_{prefix}'),
        'guidance': ui_components.get(f'guidance_{prefix}'),
        'task_type': gr.State(task_type),
    }

    # img2img / inpaint take their size from the input image, so no width/height.
    if task_type not in ('img2img', 'inpaint'):
        run_inputs_map['width'] = pick(f'width_{prefix}', f'{prefix}_width')
        run_inputs_map['height'] = pick(f'height_{prefix}', f'{prefix}_height')

    task_specific_map = {
        'img2img': {
            'img2img_image': f'input_image_{prefix}',
            'img2img_denoise': f'denoise_{prefix}',
        },
        'inpaint': {
            'inpaint_image_dict': f'input_image_dict_{prefix}',
            'grow_mask_by': f'grow_mask_by_{prefix}',
        },
        'outpaint': {
            'outpaint_image': f'input_image_{prefix}',
            'left': f'left_{prefix}',
            'top': f'top_{prefix}',
            'right': f'right_{prefix}',
            'bottom': f'bottom_{prefix}',
            'feathering': f'feathering_{prefix}',
        },
        'hires_fix': {
            'hires_image': f'input_image_{prefix}',
            'hires_upscaler': f'hires_upscaler_{prefix}',
            'hires_scale_by': f'hires_scale_by_{prefix}',
            'hires_denoise': f'denoise_{prefix}',
        },
    }
    for key, comp_name in task_specific_map.get(task_type, {}).items():
        if comp_name in ui_components:
            run_inputs_map[key] = ui_components[comp_name]

    run_inputs_map['vae_source'] = ui_components.get(f'vae_source_{prefix}')
    run_inputs_map['vae_id'] = ui_components.get(f'vae_id_{prefix}')
    run_inputs_map['vae_file'] = ui_components.get(f'vae_file_{prefix}')

    # Snapshot, at wiring time, exactly which scalar keys are sent and in what
    # order; the handler must slice args in the same order as `inputs`.
    present_keys = [key for key, comp in run_inputs_map.items() if comp is not None]
    input_list_flat = [run_inputs_map[key] for key in present_keys]

    # Append each non-empty chain after the scalar inputs and remember its size,
    # so the handler can carve the matching run of positional args back out.
    chain_sizes = []
    for chain_key, stem in _CHAIN_SPECS:
        components = ui_components.get(f'{stem}_{prefix}', [])
        if components:
            input_list_flat.extend(components)
            chain_sizes.append((chain_key, len(components)))

    scalar_count = len(present_keys)

    def create_ui_inputs_dict(*args):
        """Rebuild the keyword dict from Gradio's positional input values."""
        ui_dict = dict(zip(present_keys, args[:scalar_count]))
        offset = scalar_count
        for chain_key, size in chain_sizes:
            ui_dict[chain_key] = list(args[offset:offset + size])
            offset += size
        return ui_dict

    run_btn = pick(f'run_{prefix}', f'{prefix}_run_button')
    res_gal = pick(f'result_{prefix}', f'{prefix}_output_gallery')
    if run_btn is not None and res_gal is not None:
        run_btn.click(
            fn=lambda *args, progress=gr.Progress(track_tqdm=True): generate_image_wrapper(
                create_ui_inputs_dict(*args), progress
            ),
            inputs=input_list_flat,
            outputs=[res_gal],
        )