| import gradio as gr |
| from core.generation_logic import generate_image_wrapper |
|
|
# Ordered registry of list-valued "chain" inputs (LoRA, ControlNet, IP-Adapter, ...).
# Each entry is (stem of the ui_components key holding the flat component list,
# key under which the sliced values are handed to generate_image_wrapper).
# The ui_components key is f'{stem}_{prefix}'. This single table drives both
# the flattening of the Gradio inputs and the unpacking of the positional
# values, so the two can no longer fall out of step.
_CHAIN_SPECS = (
    ('all_lora_components_flat', 'lora_data'),
    ('all_controlnet_components_flat', 'controlnet_data'),
    ('all_anima_controlnet_lllite_components_flat', 'anima_controlnet_lllite_data'),
    ('all_diffsynth_controlnet_components_flat', 'diffsynth_controlnet_data'),
    ('all_ipadapter_components_flat', 'ipadapter_data'),
    # NOTE(review): the only chain key not ending in '_data'; kept verbatim
    # because generate_image_wrapper presumably reads it under this name — confirm.
    ('all_sd3_ipadapter_components_flat', 'sd3_ipadapter_chain'),
    ('all_flux1_ipadapter_components_flat', 'flux1_ipadapter_data'),
    ('all_style_components_flat', 'style_data'),
    ('all_embedding_components_flat', 'embedding_data'),
    ('all_conditioning_components_flat', 'conditioning_data'),
    ('all_reference_latent_components_flat', 'reference_latent_data'),
    ('all_hidream_o1_reference_components_flat', 'hidream_o1_reference_data'),
)

# Extra single-component inputs per task type:
# task_type -> ((input key, ui_components key stem), ...); the component key
# is f'{stem}_{prefix}' and is only wired when present in ui_components.
_TASK_SPECIFIC_INPUTS = {
    'img2img': (
        ('img2img_image', 'input_image'),
        ('img2img_denoise', 'denoise'),
    ),
    'inpaint': (
        ('inpaint_image_dict', 'input_image_dict'),
        ('grow_mask_by', 'grow_mask_by'),
        ('inpaint_denoise', 'denoise'),
    ),
    'outpaint': (
        ('outpaint_image', 'input_image'),
        ('left', 'left'),
        ('top', 'top'),
        ('right', 'right'),
        ('bottom', 'bottom'),
        ('feathering', 'feathering'),
    ),
    'hires_fix': (
        ('hires_image', 'input_image'),
        ('hires_upscaler', 'hires_upscaler'),
        ('hires_scale_by', 'hires_scale_by'),
        ('hires_denoise', 'denoise'),
    ),
}

# Task types for which no width/height components are wired.
_TASKS_WITHOUT_SIZE = ('img2img', 'inpaint')


def _first_component(ui_components, *names):
    """Return the first truthy ``ui_components`` entry among *names*.

    Equivalent to ``get(a) or get(b) or ...``: when no entry is truthy, the
    value looked up for the last name (possibly None) is returned.
    """
    component = None
    for name in names:
        component = ui_components.get(name)
        if component:
            return component
    return component


def _build_scalar_inputs(prefix, task_type, ui_components):
    """Map generate_image_wrapper input keys to single Gradio components.

    Values are None when the corresponding component is absent; the caller
    filters those out. Insertion order defines the positional order of the
    click inputs, so new keys must be appended, not inserted.

    Raises:
        KeyError: if ``base_model_{prefix}`` is not in *ui_components*.
    """
    inputs = {
        'model_display_name': ui_components[f'base_model_{prefix}'],
        'positive_prompt': _first_component(ui_components, f'prompt_{prefix}', f'{prefix}_positive_prompt'),
        'negative_prompt': _first_component(ui_components, f'neg_prompt_{prefix}', f'{prefix}_negative_prompt'),
        'seed': _first_component(ui_components, f'seed_{prefix}', f'{prefix}_seed'),
        'batch_size': _first_component(ui_components, f'batch_size_{prefix}', f'{prefix}_batch_size'),
        'guidance_scale': _first_component(ui_components, f'cfg_{prefix}', f'{prefix}_cfg'),
        'num_inference_steps': _first_component(ui_components, f'steps_{prefix}', f'{prefix}_steps'),
        'sampler': _first_component(ui_components, f'sampler_{prefix}', f'{prefix}_sampler_name'),
        'scheduler': _first_component(ui_components, f'scheduler_{prefix}', f'{prefix}_scheduler'),
        'zero_gpu_duration': ui_components.get(f'zero_gpu_{prefix}'),
        'clip_skip': ui_components.get(f'clip_skip_{prefix}'),
        'guidance': ui_components.get(f'guidance_{prefix}'),
        # Passes task_type to the handler as a constant input value.
        'task_type': gr.State(task_type),
    }

    pid_settings = ui_components.get(f'pid_settings_{prefix}')
    if pid_settings:
        inputs['pid_settings'] = pid_settings

    if task_type not in _TASKS_WITHOUT_SIZE:
        inputs['width'] = _first_component(ui_components, f'width_{prefix}', f'{prefix}_width')
        inputs['height'] = _first_component(ui_components, f'height_{prefix}', f'{prefix}_height')

    for key, stem in _TASK_SPECIFIC_INPUTS.get(task_type, ()):
        component_name = f'{stem}_{prefix}'
        if component_name in ui_components:
            inputs[key] = ui_components[component_name]

    for key in ('vae_source', 'vae_id', 'vae_file'):
        inputs[key] = ui_components.get(f'{key}_{prefix}')

    return inputs


def create_run_event(prefix: str, task_type: str, ui_components: dict):
    """Wire one tab's run button to ``generate_image_wrapper``.

    Collects the tab's Gradio components from *ui_components* (looked up by
    keys built from *prefix*), flattens them into the positional ``inputs``
    list used by ``run_btn.click``, and registers a handler that rebuilds a
    keyword dict from those positional values before calling
    ``generate_image_wrapper(ui_dict, progress)``. Results go to the tab's
    result gallery.

    Scalar inputs that resolve to None are left out. Each non-empty chain in
    ``_CHAIN_SPECS`` is appended after the scalars, in table order, and is
    handed to the wrapper as a list. If either the run button or the result
    gallery is missing, no event is registered.

    Raises:
        KeyError: if ``base_model_{prefix}`` is not in *ui_components*.
    """
    run_inputs_map = _build_scalar_inputs(prefix, task_type, ui_components)

    # Keys and components stay paired: both come from the same dict, in
    # insertion order, filtered by the same None check.
    valid_keys = [key for key, comp in run_inputs_map.items() if comp is not None]
    input_list_flat = [run_inputs_map[key] for key in valid_keys]

    # Append each non-empty chain and record how many positional values it
    # contributes, so unpacking uses exactly the lengths that were flattened.
    chain_layout = []
    for stem, chain_key in _CHAIN_SPECS:
        chain_components = ui_components.get(f'{stem}_{prefix}', [])
        if chain_components:
            input_list_flat.extend(chain_components)
            chain_layout.append((chain_key, len(chain_components)))

    def create_ui_inputs_dict(*args):
        """Rebuild the keyword dict for generate_image_wrapper from positional values."""
        ui_dict = dict(zip(valid_keys, args))
        offset = len(valid_keys)
        for chain_key, count in chain_layout:
            ui_dict[chain_key] = list(args[offset:offset + count])
            offset += count
        return ui_dict

    run_btn = _first_component(ui_components, f'run_{prefix}', f'{prefix}_run_button')
    res_gal = _first_component(ui_components, f'result_{prefix}', f'{prefix}_output_gallery')
    if run_btn and res_gal:
        run_btn.click(
            # A parameter defaulting to gr.Progress lets Gradio supply a progress tracker.
            fn=lambda *args, progress=gr.Progress(track_tqdm=True): generate_image_wrapper(
                create_ui_inputs_dict(*args), progress
            ),
            inputs=input_list_flat,
            outputs=[res_gal],
        )