signsur4739379373 commited on
Commit
717607c
·
verified ·
1 Parent(s): 8ed76e5
Files changed (1) hide show
  1. custom_nodes/bernini_chunk/nodes.py +135 -0
custom_nodes/bernini_chunk/nodes.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Bernini-R chunk nodes — state caching for chunked ZeroGPU generation.
3
+ Dropped into custom_nodes/bernini_chunk/ of the Bernini-R-Lightning space.
4
+
5
+ Works with KSamplerAdvanced (Kijai) instead of SamplerCustomAdvanced+Rudra fork.
6
+ No relay mask — Bernini-R doesn't need it. Pure save/load + sigma step control.
7
+ """
8
+ from __future__ import annotations
9
+ import os
10
+ import torch
11
+
12
+ _TMP = "/tmp"
13
+
14
+
15
+ def _lat_path(session_id: str) -> str:
16
+ return os.path.join(_TMP, f"bernini_lat_{session_id}.pt")
17
+
18
+
19
+ def _cond_path(session_id: str) -> str:
20
+ return os.path.join(_TMP, f"bernini_cond_{session_id}.pt")
21
+
22
+
23
+ class BerniniChunkSaveLatent:
24
+ """Save KSamplerAdvanced output latent to /tmp for the next chunk."""
25
+
26
+ @classmethod
27
+ def INPUT_TYPES(cls):
28
+ return {"required": {
29
+ "samples": ("LATENT",),
30
+ "session_id": ("STRING", {"default": ""}),
31
+ }}
32
+
33
+ RETURN_TYPES = ("LATENT",)
34
+ RETURN_NAMES = ("samples",)
35
+ FUNCTION = "save"
36
+ CATEGORY = "Bernini/Chunk"
37
+ OUTPUT_NODE = True
38
+
39
+ def save(self, samples, session_id):
40
+ if session_id:
41
+ path = _lat_path(session_id)
42
+ torch.save({"samples": samples["samples"].cpu()}, path)
43
+ print(f"[BerniniChunk] saved latent -> {path}", flush=True)
44
+ return (samples,)
45
+
46
+
47
+ class BerniniChunkLoadLatent:
48
+ """Load previously saved latent. Falls back to a provided default if none exists."""
49
+
50
+ @classmethod
51
+ def INPUT_TYPES(cls):
52
+ return {"required": {
53
+ "fallback": ("LATENT",),
54
+ "session_id": ("STRING", {"default": ""}),
55
+ }}
56
+
57
+ RETURN_TYPES = ("LATENT",)
58
+ RETURN_NAMES = ("samples",)
59
+ FUNCTION = "load"
60
+ CATEGORY = "Bernini/Chunk"
61
+
62
+ def load(self, fallback, session_id):
63
+ if session_id:
64
+ path = _lat_path(session_id)
65
+ if os.path.exists(path):
66
+ data = torch.load(path, map_location="cpu", weights_only=False)
67
+ print(f"[BerniniChunk] loaded latent <- {path}", flush=True)
68
+ return ({"samples": data["samples"]},)
69
+ return (fallback,)
70
+
71
+
72
+ class BerniniChunkSaveCond:
73
+ """Save positive/negative conditioning after step 0 (planning/text-encode)."""
74
+
75
+ @classmethod
76
+ def INPUT_TYPES(cls):
77
+ return {"required": {
78
+ "positive": ("CONDITIONING",),
79
+ "negative": ("CONDITIONING",),
80
+ "session_id": ("STRING", {"default": ""}),
81
+ }}
82
+
83
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
84
+ RETURN_NAMES = ("positive", "negative")
85
+ FUNCTION = "save"
86
+ CATEGORY = "Bernini/Chunk"
87
+ OUTPUT_NODE = True
88
+
89
+ def save(self, positive, negative, session_id):
90
+ if session_id:
91
+ path = _cond_path(session_id)
92
+ torch.save({"positive": positive, "negative": negative}, path)
93
+ print(f"[BerniniChunk] saved conditioning -> {path}", flush=True)
94
+ return (positive, negative)
95
+
96
+
97
+ class BerniniChunkLoadCond:
98
+ """Load saved conditioning. Falls back to fresh if no saved file exists."""
99
+
100
+ @classmethod
101
+ def INPUT_TYPES(cls):
102
+ return {"required": {
103
+ "fallback_pos": ("CONDITIONING",),
104
+ "fallback_neg": ("CONDITIONING",),
105
+ "session_id": ("STRING", {"default": ""}),
106
+ }}
107
+
108
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
109
+ RETURN_NAMES = ("positive", "negative")
110
+ FUNCTION = "load"
111
+ CATEGORY = "Bernini/Chunk"
112
+
113
+ def load(self, fallback_pos, fallback_neg, session_id):
114
+ if session_id:
115
+ path = _cond_path(session_id)
116
+ if os.path.exists(path):
117
+ data = torch.load(path, map_location="cpu", weights_only=False)
118
+ print(f"[BerniniChunk] loaded conditioning <- {path}", flush=True)
119
+ return (data["positive"], data["negative"])
120
+ return (fallback_pos, fallback_neg)
121
+
122
+
123
+ NODE_CLASS_MAPPINGS = {
124
+ "BerniniChunkSaveLatent": BerniniChunkSaveLatent,
125
+ "BerniniChunkLoadLatent": BerniniChunkLoadLatent,
126
+ "BerniniChunkSaveCond": BerniniChunkSaveCond,
127
+ "BerniniChunkLoadCond": BerniniChunkLoadCond,
128
+ }
129
+
130
+ NODE_DISPLAY_NAME_MAPPINGS = {
131
+ "BerniniChunkSaveLatent": "Bernini · Chunk Save Latent",
132
+ "BerniniChunkLoadLatent": "Bernini · Chunk Load Latent",
133
+ "BerniniChunkSaveCond": "Bernini · Chunk Save Conditioning",
134
+ "BerniniChunkLoadCond": "Bernini · Chunk Load Conditioning",
135
+ }