File size: 6,712 Bytes
e699279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46e2913
 
 
e699279
 
 
 
 
 
 
 
7298549
e699279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46e2913
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
from core.generation_logic import generate_image_wrapper

def create_run_event(prefix: str, task_type: str, ui_components: dict):
    """Register the click handler that runs generation for one UI tab.

    Collects the Gradio components belonging to tab *prefix* from
    *ui_components*, flattens them into the ``inputs`` list of the run
    button's ``click`` event, and, on each click, rebuilds a keyword dict
    from the positional values Gradio passes in before calling
    ``generate_image_wrapper(ui_dict, progress)``.

    Args:
        prefix: Tab identifier embedded in component keys. Most lookups accept
            either the ``<name>_<prefix>`` or the ``<prefix>_<name>`` scheme.
        task_type: Generation task. It selects the task-specific inputs, and
            it decides whether width/height inputs are added. It is forwarded
            to the handler through a ``gr.State``.
        ui_components: Mapping of component key -> Gradio component. It must
            contain ``base_model_<prefix>``; all other keys are optional.
            Components that are missing (or ``None``) are left out of the
            event inputs and therefore out of the dict passed to the handler.

    Raises:
        KeyError: If ``base_model_<prefix>`` is not in *ui_components*.

    If the run button or the result gallery cannot be found, no event is
    registered and the function returns silently.
    """

    def pick(*names):
        # Same semantics as ``get(a) or get(b) or ...``. The first truthy
        # lookup wins; otherwise the last lookup's value (normally None) is
        # returned.
        component = None
        for name in names:
            component = ui_components.get(name)
            if component:
                break
        return component

    run_inputs_map = {
        'model_display_name': ui_components[f'base_model_{prefix}'],
        'positive_prompt': pick(f'prompt_{prefix}', f'{prefix}_positive_prompt'),
        'negative_prompt': pick(f'neg_prompt_{prefix}', f'{prefix}_negative_prompt'),
        'seed': pick(f'seed_{prefix}', f'{prefix}_seed'),
        'batch_size': pick(f'batch_size_{prefix}', f'{prefix}_batch_size'),
        'guidance_scale': pick(f'cfg_{prefix}', f'{prefix}_cfg'),
        'num_inference_steps': pick(f'steps_{prefix}', f'{prefix}_steps'),
        'sampler': pick(f'sampler_{prefix}', f'{prefix}_sampler_name'),
        'scheduler': pick(f'scheduler_{prefix}', f'{prefix}_scheduler'),
        'zero_gpu_duration': ui_components.get(f'zero_gpu_{prefix}'),
        'clip_skip': ui_components.get(f'clip_skip_{prefix}'),
        'guidance': ui_components.get(f'guidance_{prefix}'),
        # Constant per tab; gr.State lets it travel through the event inputs.
        'task_type': gr.State(task_type),
    }

    if ui_components.get(f'pid_settings_{prefix}'):
        run_inputs_map['pid_settings'] = ui_components[f'pid_settings_{prefix}']

    # img2img / inpaint get no explicit width/height inputs (presumably the
    # size comes from the input image -- TODO confirm in generation_logic).
    if task_type not in ('img2img', 'inpaint'):
        run_inputs_map['width'] = pick(f'width_{prefix}', f'{prefix}_width')
        run_inputs_map['height'] = pick(f'height_{prefix}', f'{prefix}_height')

    # Handler key -> component key, per task. Keys whose component is absent
    # from ui_components are skipped entirely.
    task_specific_map = {
        'img2img': {'img2img_image': f'input_image_{prefix}', 'img2img_denoise': f'denoise_{prefix}'},
        'inpaint': {'inpaint_image_dict': f'input_image_dict_{prefix}', 'grow_mask_by': f'grow_mask_by_{prefix}', 'inpaint_denoise': f'denoise_{prefix}'},
        'outpaint': {'outpaint_image': f'input_image_{prefix}', 'left': f'left_{prefix}', 'top': f'top_{prefix}', 'right': f'right_{prefix}', 'bottom': f'bottom_{prefix}', 'feathering': f'feathering_{prefix}'},
        'hires_fix': {'hires_image': f'input_image_{prefix}', 'hires_upscaler': f'hires_upscaler_{prefix}', 'hires_scale_by': f'hires_scale_by_{prefix}', 'hires_denoise': f'denoise_{prefix}'},
    }
    for key, comp_name in task_specific_map.get(task_type, {}).items():
        if comp_name in ui_components:
            run_inputs_map[key] = ui_components[comp_name]

    run_inputs_map['vae_source'] = ui_components.get(f'vae_source_{prefix}')
    run_inputs_map['vae_id'] = ui_components.get(f'vae_id_{prefix}')
    run_inputs_map['vae_file'] = ui_components.get(f'vae_file_{prefix}')

    # Keys that actually have a component. Their values form the head of the
    # event inputs, and the same keys are zipped back onto the first
    # positional args in the handler.
    scalar_keys = [key for key, comp in run_inputs_map.items() if comp is not None]
    input_list_flat = [run_inputs_map[key] for key in scalar_keys]
    n_scalars = len(scalar_keys)

    # Single source of truth for the optional component chains:
    # (handler dict key, stem of the 'all_<stem>_components_flat_<prefix>' key).
    # This one ordering both appends the chains to the inputs and slices the
    # args back apart, so the two cannot drift out of sync.
    chain_specs = (
        ('lora_data', 'lora'),
        ('controlnet_data', 'controlnet'),
        ('anima_controlnet_lllite_data', 'anima_controlnet_lllite'),
        ('diffsynth_controlnet_data', 'diffsynth_controlnet'),
        ('ipadapter_data', 'ipadapter'),
        # NOTE: '_chain' suffix differs from the other keys; kept as-is since
        # it is presumably what generate_image_wrapper reads -- verify before renaming.
        ('sd3_ipadapter_chain', 'sd3_ipadapter'),
        ('flux1_ipadapter_data', 'flux1_ipadapter'),
        ('style_data', 'style'),
        ('embedding_data', 'embedding'),
        ('conditioning_data', 'conditioning'),
        ('reference_latent_data', 'reference_latent'),
        ('hidream_o1_reference_data', 'hidream_o1_reference'),
    )
    chains = []
    for data_key, stem in chain_specs:
        components = ui_components.get(f'all_{stem}_components_flat_{prefix}', [])
        # Empty or missing chains contribute neither inputs nor a dict key.
        if components:
            chains.append((data_key, components))
            input_list_flat.extend(components)

    def create_ui_inputs_dict(*args):
        # Rebuild the keyword dict from Gradio's positional values, in the
        # same order input_list_flat was assembled.
        ui_dict = dict(zip(scalar_keys, args[:n_scalars]))
        offset = n_scalars
        for data_key, components in chains:
            ui_dict[data_key] = list(args[offset:offset + len(components)])
            offset += len(components)
        return ui_dict

    run_btn = pick(f'run_{prefix}', f'{prefix}_run_button')
    res_gal = pick(f'result_{prefix}', f'{prefix}_output_gallery')
    if run_btn and res_gal:
        run_btn.click(
            # Gradio injects a tracker for the gr.Progress default parameter.
            fn=lambda *args, progress=gr.Progress(track_tqdm=True): generate_image_wrapper(
                create_ui_inputs_dict(*args), progress
            ),
            inputs=input_list_flat,
            outputs=[res_gal],
        )