Spaces:
Running on Zero
| import gradio as gr | |
| from core.generation_logic import generate_image_wrapper | |
def create_run_event(prefix: str, task_type: str, ui_components: dict):
    """Register the click handler that runs generation for one UI tab.

    Collects this tab's Gradio components from ``ui_components`` and flattens
    them into the positional ``inputs`` list that ``run_btn.click`` expects.
    Each click rebuilds a keyword dict from the positional values and passes
    it, together with a progress tracker, to ``generate_image_wrapper``.

    Most components are looked up as ``'<name>_{prefix}'`` first, falling back
    to ``'{prefix}_<name>'`` when the first lookup is missing or falsy.

    Args:
        prefix: Tab identifier embedded in every component key.
        task_type: Task name, forwarded to the generator as a ``gr.State``.
            ``'img2img'`` and ``'inpaint'`` do not receive width/height.
            ``'img2img'``, ``'inpaint'``, ``'outpaint'`` and ``'hires_fix'``
            additionally pull in task-specific components when present.
        ui_components: Mapping of component key -> Gradio component.
            ``f'base_model_{prefix}'`` is required; every other key is optional.

    Raises:
        KeyError: if ``f'base_model_{prefix}'`` is not in ``ui_components``.

    Returns nothing. If either the run button or the result gallery cannot be
    found, no event is registered.
    """

    def _component(name, fallback_name=None):
        # Same semantics as `get('<name>_{prefix}') or get('{prefix}_<fallback>')`:
        # the fallback is consulted whenever the first lookup is falsy.
        component = ui_components.get(f'{name}_{prefix}')
        if fallback_name is not None:
            component = component or ui_components.get(f'{prefix}_{fallback_name}')
        return component

    # Scalar inputs. Dict insertion order is the positional order handed to
    # Gradio, so entries must not be reshuffled.
    run_inputs_map = {
        'model_display_name': ui_components[f'base_model_{prefix}'],
        'positive_prompt': _component('prompt', 'positive_prompt'),
        'negative_prompt': _component('neg_prompt', 'negative_prompt'),
        'seed': _component('seed', 'seed'),
        'batch_size': _component('batch_size', 'batch_size'),
        'guidance_scale': _component('cfg', 'cfg'),
        'num_inference_steps': _component('steps', 'steps'),
        'sampler': _component('sampler', 'sampler_name'),
        'scheduler': _component('scheduler', 'scheduler'),
        'zero_gpu_duration': _component('zero_gpu'),
        'clip_skip': _component('clip_skip'),
        'guidance': _component('guidance'),
        'task_type': gr.State(task_type),
    }

    # img2img/inpaint get no width/height inputs (presumably the size comes
    # from the input image -- confirm in generate_image_wrapper).
    if task_type not in ['img2img', 'inpaint']:
        run_inputs_map.update({
            'width': _component('width', 'width'),
            'height': _component('height', 'height'),
        })

    # Extra inputs per task: generator kwarg -> ui_components key.
    task_specific_map = {
        'img2img': {
            'img2img_image': f'input_image_{prefix}',
            'img2img_denoise': f'denoise_{prefix}',
        },
        'inpaint': {
            'inpaint_image_dict': f'input_image_dict_{prefix}',
            'grow_mask_by': f'grow_mask_by_{prefix}',
        },
        'outpaint': {
            'outpaint_image': f'input_image_{prefix}',
            'left': f'left_{prefix}',
            'top': f'top_{prefix}',
            'right': f'right_{prefix}',
            'bottom': f'bottom_{prefix}',
            'feathering': f'feathering_{prefix}',
        },
        'hires_fix': {
            'hires_image': f'input_image_{prefix}',
            'hires_upscaler': f'hires_upscaler_{prefix}',
            'hires_scale_by': f'hires_scale_by_{prefix}',
            'hires_denoise': f'denoise_{prefix}',
        },
    }
    for key, comp_name in task_specific_map.get(task_type, {}).items():
        if comp_name in ui_components:
            run_inputs_map[key] = ui_components[comp_name]

    run_inputs_map['vae_source'] = _component('vae_source')
    run_inputs_map['vae_id'] = _component('vae_id')
    run_inputs_map['vae_file'] = _component('vae_file')

    # Variable-length component groups ("chains"). Each entry is
    # (stem of 'all_<stem>_components_flat_{prefix}', key in the dict passed
    # to generate_image_wrapper). This single table fixes both the order in
    # which chains are appended to the Gradio inputs and the order in which
    # they are sliced back out, so the two can no longer drift apart.
    # NOTE(review): 'sd3_ipadapter_chain' breaks the '*_data' pattern; kept
    # as-is since the generator presumably reads it under this name -- confirm.
    chain_specs = (
        ('lora', 'lora_data'),
        ('controlnet', 'controlnet_data'),
        ('anima_controlnet_lllite', 'anima_controlnet_lllite_data'),
        ('diffsynth_controlnet', 'diffsynth_controlnet_data'),
        ('ipadapter', 'ipadapter_data'),
        ('sd3_ipadapter', 'sd3_ipadapter_chain'),
        ('flux1_ipadapter', 'flux1_ipadapter_data'),
        ('style', 'style_data'),
        ('embedding', 'embedding_data'),
        ('conditioning', 'conditioning_data'),
        ('reference_latent', 'reference_latent_data'),
        ('hidream_o1_reference', 'hidream_o1_reference_data'),
    )
    chain_slots = []  # (data key, number of components registered)
    chain_components = []
    for stem, data_key in chain_specs:
        components = ui_components.get(f'all_{stem}_components_flat_{prefix}', [])
        if components:
            # Snapshot the length now: it must match what is registered as
            # Gradio inputs, even if the source list is mutated later.
            chain_slots.append((data_key, len(components)))
            chain_components.extend(components)

    # Missing (None) scalar components are not registered as inputs; compute
    # their keys once here instead of on every click.
    scalar_keys = [key for key, component in run_inputs_map.items() if component is not None]
    scalar_count = len(scalar_keys)
    input_list_flat = [run_inputs_map[key] for key in scalar_keys] + chain_components

    def create_ui_inputs_dict(*args):
        # Gradio passes input values positionally, in input_list_flat order.
        ui_dict = dict(zip(scalar_keys, args[:scalar_count]))
        offset = scalar_count
        for data_key, count in chain_slots:
            ui_dict[data_key] = list(args[offset:offset + count])
            offset += count
        return ui_dict

    run_btn = _component('run', 'run_button')
    res_gal = _component('result', 'output_gallery')
    if not (run_btn and res_gal):
        # Tabs lacking either component are silently left unwired.
        return

    # Kept as a lambda so the handler name (which Gradio may use for the
    # default API endpoint name) is unchanged. The `progress` default is what
    # makes Gradio inject a progress tracker.
    run_btn.click(
        fn=lambda *args, progress=gr.Progress(track_tqdm=True): generate_image_wrapper(
            create_ui_inputs_dict(*args), progress
        ),
        inputs=input_list_flat,
        outputs=[res_gal],
    )