File size: 3,094 Bytes
f7edee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from copy import deepcopy

def _resolve_start_connections(assembler, chain_definition):
    """Build the initial ``type_name -> [node_id, output_index]`` map for a chain.

    Reads ``start`` and ``output_map`` from *chain_definition*.  An
    ``output_map`` key of the form ``"node:index"`` refers to a named node in
    ``assembler.node_map``; a plain key is an output index on the ``start``
    node.  Unresolvable entries print a warning and are skipped.

    Returns ``None`` when a ``start`` node is named but missing from
    ``assembler.node_map`` (the caller should then skip the whole chain).
    """
    start_node_name = chain_definition.get('start')
    start_node_id = None
    if start_node_name:
        if start_node_name not in assembler.node_map:
            print(f"Warning: Start node '{start_node_name}' for dynamic LoRA chain not found. Skipping chain.")
            return None
        start_node_id = assembler.node_map[start_node_name]

    connections = {}
    for key, type_name in chain_definition.get('output_map', {}).items():
        # Split on the LAST colon so node names that themselves contain ':'
        # still resolve instead of crashing the tuple unpack.
        node_name, sep, idx_str = str(key).rpartition(':')
        if sep:
            if node_name not in assembler.node_map:
                print(f"Warning: Node '{node_name}' in chain's output_map not found. Skipping.")
                continue
            source_id = assembler.node_map[node_name]
        elif start_node_id is not None:
            # `is not None` rather than truthiness: a node id of 0 is valid.
            source_id = start_node_id
            idx_str = key
        else:
            print(f"Warning: LoRA chain has no 'start' node defined, and an output_map key '{key}' is not in 'node:index' format. Skipping this connection.")
            continue
        try:
            output_idx = int(idx_str)
        except (TypeError, ValueError):
            print(f"Warning: output_map key '{key}' does not end in an integer output index. Skipping.")
            continue
        connections[type_name] = [source_id, output_idx]
    return connections


def _connect_chain_end(assembler, chain_definition, connections):
    """Wire the chain's final outputs into the targets listed in ``end_input_map``.

    Each ``end_input_map`` value is a ``"node:input"`` string or a list of
    them.  Types with no current connection are ignored; malformed targets
    and unknown nodes print a warning and are skipped.
    """
    for type_name, targets in chain_definition.get('end_input_map', {}).items():
        if type_name not in connections:
            continue
        if not isinstance(targets, list):
            targets = [targets]

        for target_str in targets:
            # Last colon separates node name from input name (node names may contain ':').
            end_node_name, sep, end_input_name = target_str.rpartition(':')
            if not sep:
                print(f"Warning: End target '{target_str}' for dynamic chain is not in 'node:input' format. Skipping connection.")
                continue
            if end_node_name not in assembler.node_map:
                print(f"Warning: End node '{end_node_name}' for dynamic chain not found. Skipping connection.")
                continue
            end_node_id = assembler.node_map[end_node_name]
            # Copy so several targets fed by one output don't share a single list object.
            assembler.workflow[end_node_id]['inputs'][end_input_name] = list(connections[type_name])


def inject(assembler, chain_definition, chain_items):
    """Insert a serial chain of template nodes (e.g. LoRA loaders) into a workflow.

    One node is created per entry of *chain_items*.  Each node is a deep copy
    of the template named by ``chain_definition['template']``.  Its inputs
    are overridden by the entry's values, but only for keys the template
    already declares.  Each node is then linked to the outputs of the
    previous node, or to the chain's starting connections for the first
    node.  After the last node, the chain's outputs are wired into the
    targets listed in ``end_input_map``.

    Args:
        assembler: Object providing ``node_map`` (name -> node id),
            ``workflow`` (node id -> node dict with an ``'inputs'`` dict),
            ``_get_node_template(name)`` and ``_get_unique_id()``.
        chain_definition: Mapping with a required ``'template'`` key and
            optional ``'start'``, ``'output_map'``, ``'input_map'``
            (type name -> template input name), ``'template_output_map'``
            (output index string -> type name; defaults to
            ``{"0": "model", "1": "clip"}``) and ``'end_input_map'``.
        chain_items: Sequence of dicts of per-node input overrides.  If it is
            empty, nothing is done.

    Returns:
        None.  ``assembler.workflow`` is modified in place; problems with the
        definition are reported as printed warnings.
    """
    if not chain_items:
        return

    current_connections = _resolve_start_connections(assembler, chain_definition)
    if current_connections is None:
        return

    input_map = chain_definition.get('input_map', {})
    chain_output_map = chain_definition.get('template_output_map', {"0": "model", "1": "clip"})
    # The template is identical for every link, so fetch it once; each link
    # still gets its own deep copy below so the template is never mutated.
    template = assembler._get_node_template(chain_definition['template'])

    for item_data in chain_items:
        node_data = deepcopy(template)

        # Only parameters the template declares are applied; unknown keys are ignored.
        for param_name, value in item_data.items():
            if param_name in node_data['inputs']:
                node_data['inputs'][param_name] = value

        # Feed this link from the previous link's (or the start's) outputs.
        for type_name, input_name in input_map.items():
            if type_name in current_connections:
                node_data['inputs'][input_name] = current_connections[type_name]

        new_node_id = assembler._get_unique_id()
        assembler.workflow[new_node_id] = node_data

        # This link's outputs become the upstream for the next link.
        for idx_str, type_name in chain_output_map.items():
            current_connections[type_name] = [new_node_id, int(idx_str)]

    _connect_chain_end(assembler, chain_definition, current_connections)