ImageGen5 / chain_injectors /lora_injector.py
RioShiina's picture
Upload folder using huggingface_hub
3a02128 verified
Raw
History Blame Contribute Delete
3.09 kB
from copy import deepcopy
def inject(assembler, chain_definition, chain_items):
if not chain_items:
return
start_node_name = chain_definition.get('start')
start_node_id = None
if start_node_name:
if start_node_name not in assembler.node_map:
print(f"Warning: Start node '{start_node_name}' for dynamic LoRA chain not found. Skipping chain.")
return
start_node_id = assembler.node_map[start_node_name]
output_map = chain_definition.get('output_map', {})
current_connections = {}
for key, type_name in output_map.items():
if ':' in str(key):
node_name, idx_str = key.split(':')
if node_name not in assembler.node_map:
print(f"Warning: Node '{node_name}' in chain's output_map not found. Skipping.")
continue
node_id = assembler.node_map[node_name]
start_output_idx = int(idx_str)
current_connections[type_name] = [node_id, start_output_idx]
elif start_node_id:
start_output_idx = int(key)
current_connections[type_name] = [start_node_id, start_output_idx]
else:
print(f"Warning: LoRA chain has no 'start' node defined, and an output_map key '{key}' is not in 'node:index' format. Skipping this connection.")
input_map = chain_definition.get('input_map', {})
chain_output_map = chain_definition.get('template_output_map', { "0": "model", "1": "clip" })
for item_data in chain_items:
template_name = chain_definition['template']
template = assembler._get_node_template(template_name)
node_data = deepcopy(template)
for param_name, value in item_data.items():
if param_name in node_data['inputs']:
node_data['inputs'][param_name] = value
for type_name, input_name in input_map.items():
if type_name in current_connections:
node_data['inputs'][input_name] = current_connections[type_name]
new_node_id = assembler._get_unique_id()
assembler.workflow[new_node_id] = node_data
for idx_str, type_name in chain_output_map.items():
current_connections[type_name] = [new_node_id, int(idx_str)]
end_input_map = chain_definition.get('end_input_map', {})
for type_name, targets in end_input_map.items():
if type_name in current_connections:
if not isinstance(targets, list):
targets = [targets]
for target_str in targets:
end_node_name, end_input_name = target_str.split(':')
if end_node_name in assembler.node_map:
end_node_id = assembler.node_map[end_node_name]
assembler.workflow[end_node_id]['inputs'][end_input_name] = current_connections[type_name]
else:
print(f"Warning: End node '{end_node_name}' for dynamic chain not found. Skipping connection.")