Spaces:
Sleeping
Sleeping
Add model size dropdown (small/large), lazy-load LLM sessions, infer KV geometry from ONNX shapes
Browse files
app.py
CHANGED
|
@@ -22,10 +22,6 @@ COND_OFFSET = 7 # NUM_RESERVED + 1, added to every conditioning integ
|
|
| 22 |
COND_LEN = 144 # 12 style + 128 notes + 1 drum + 3 cfg
|
| 23 |
VOCAB_SIZE = NUM_RESERVED + NUM_CB * CODEBOOK # 12294
|
| 24 |
|
| 25 |
-
# mrt2_small KV geometry
|
| 26 |
-
T_LAYERS, T_W, T_H, T_HD = 12, 41, 8, 128 # temporal
|
| 27 |
-
D_LAYERS, D_W, D_H, D_HD = 2, 12, 6, 128 # depth
|
| 28 |
-
|
| 29 |
# note states
|
| 30 |
NOTE_MASKED, NOTE_OFF, NOTE_ON = -1, 0, 3
|
| 31 |
DRUM_MASKED = -1
|
|
@@ -117,25 +113,72 @@ def _sess(path: str) -> ort.InferenceSession:
|
|
| 117 |
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 118 |
return ort.InferenceSession(path, opts, providers=["CPUExecutionProvider"])
|
| 119 |
|
|
|
|
| 120 |
text_enc_s = _sess(f"{MODEL_PATH}/musiccoca/text_encoder.onnx")
|
| 121 |
mapper_s = _sess(f"{MODEL_PATH}/musiccoca/mapper.onnx")
|
| 122 |
vq_s = _sess(f"{MODEL_PATH}/musiccoca/pretrained_vector_quantizer.onnx")
|
| 123 |
-
enc_s = _sess(f"{MODEL_PATH}/mrt2_small/onnx/encoder.onnx")
|
| 124 |
-
temp_s = _sess(f"{MODEL_PATH}/mrt2_small/onnx/temporal_step.onnx")
|
| 125 |
-
depth_s = _sess(f"{MODEL_PATH}/mrt2_small/onnx/depth_step.onnx")
|
| 126 |
-
embed_s = _sess(f"{MODEL_PATH}/mrt2_small/onnx/embed.onnx")
|
| 127 |
dec_s = _sess(f"{MODEL_PATH}/spectrostream/decoder.onnx")
|
| 128 |
sp = spm_lib.SentencePieceProcessor(model_file=f"{MODEL_PATH}/musiccoca/spm.model")
|
| 129 |
|
| 130 |
-
|
| 131 |
-
for name, s in [("text_enc", text_enc_s), ("mapper", mapper_s), ("vq", vq_s),
|
| 132 |
-
("enc", enc_s), ("temporal", temp_s), ("depth", depth_s),
|
| 133 |
-
("embed", embed_s), ("decoder", dec_s)]:
|
| 134 |
ins = {i.name: i.shape for i in s.get_inputs()}
|
| 135 |
outs = {o.name: o.shape for o in s.get_outputs()}
|
| 136 |
print(f"[{name}] inputs: {ins}")
|
| 137 |
print(f"[{name}] outputs: {outs}")
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
# ---------- helper: discretize CFG ----------
|
| 140 |
def _disc_cfg(v: float, step: float, max_bin: int) -> int:
|
| 141 |
c = max(-1.0, min(7.0, v))
|
|
@@ -143,7 +186,7 @@ def _disc_cfg(v: float, step: float, max_bin: int) -> int:
|
|
| 143 |
|
| 144 |
# ---------- conditioning vector ----------
|
| 145 |
def build_cond(style_tokens: list, notes: list, cfg_mcc=CFG_MCC, cfg_notes=CFG_NOTES, cfg_drums=CFG_DRUMS) -> np.ndarray:
|
| 146 |
-
"""Build 144-length cond vector, shifted by COND_OFFSET, shape [1,1,144]
|
| 147 |
out = [0] * COND_LEN
|
| 148 |
k = 0
|
| 149 |
for i in range(NUM_CB):
|
|
@@ -169,7 +212,6 @@ def encode_text(prompt: str) -> list:
|
|
| 169 |
pad_mask = np.ones((1, 128), dtype=np.float32)
|
| 170 |
pad_mask[0, :len(ids_raw)+1] = 0.0
|
| 171 |
|
| 172 |
-
# text encoder - feed by position (name detection fallback)
|
| 173 |
enc_inputs = text_enc_s.get_inputs()
|
| 174 |
feed_enc = {}
|
| 175 |
for inp in enc_inputs:
|
|
@@ -179,7 +221,6 @@ def encode_text(prompt: str) -> list:
|
|
| 179 |
feed_enc[inp.name] = ids
|
| 180 |
emb = text_enc_s.run(None, feed_enc)[0] # [1, 768]
|
| 181 |
|
| 182 |
-
# mapper - feed by position (args_0=emb, args_1=noise)
|
| 183 |
map_inputs = mapper_s.get_inputs()
|
| 184 |
feed_map = {}
|
| 185 |
for inp in map_inputs:
|
|
@@ -189,13 +230,11 @@ def encode_text(prompt: str) -> list:
|
|
| 189 |
feed_map[inp.name] = emb
|
| 190 |
mapped = mapper_s.run(None, feed_map)[0] # [1, 768]
|
| 191 |
|
| 192 |
-
# L2 normalize
|
| 193 |
norm = np.linalg.norm(mapped)
|
| 194 |
if norm > 1e-8:
|
| 195 |
mapped = mapped / norm
|
| 196 |
|
| 197 |
-
|
| 198 |
-
style_tokens = vq_s.run(None, {vq_s.get_inputs()[0].name: mapped})[0] # [1, 12]
|
| 199 |
return style_tokens.reshape(-1).tolist()
|
| 200 |
|
| 201 |
# ---------- sampling ----------
|
|
@@ -225,6 +264,7 @@ def generate(
|
|
| 225 |
n_seconds: float,
|
| 226 |
temperature: float,
|
| 227 |
cfg_mcc: float,
|
|
|
|
| 228 |
progress=gr.Progress(track_tqdm=True),
|
| 229 |
) -> tuple:
|
| 230 |
import json
|
|
@@ -236,17 +276,21 @@ def generate(
|
|
| 236 |
|
| 237 |
n_frames = max(4, int(n_seconds * 25))
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
style_tokens = encode_text(prompt)
|
| 242 |
|
| 243 |
-
# Conditioning vector
|
| 244 |
cond = build_cond(style_tokens, held_notes, cfg_mcc=cfg_mcc)
|
| 245 |
-
|
| 246 |
-
# Encode conditioning -> enc_out [1,1,256]
|
| 247 |
enc_out_arr = enc_s.run(None, {enc_s.get_inputs()[0].name: cond})[0]
|
| 248 |
|
| 249 |
-
# Init temporal KV cache (zeros)
|
| 250 |
psk = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
|
| 251 |
psv = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
|
| 252 |
pck = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
|
|
@@ -254,16 +298,14 @@ def generate(
|
|
| 254 |
prev_codes = np.zeros((1, NUM_CB), dtype=np.int64)
|
| 255 |
cache_pos = 0
|
| 256 |
|
| 257 |
-
# Cache output name lists (same every frame)
|
| 258 |
t_out_names = [o.name for o in temp_s.get_outputs()]
|
| 259 |
d_out_names = [o.name for o in depth_s.get_outputs()]
|
| 260 |
|
| 261 |
all_codec_frames = []
|
| 262 |
|
| 263 |
for f in range(n_frames):
|
| 264 |
-
progress((f + 1) / (n_frames + 1), desc=f"Generating frame {f+1}/{n_frames}
|
| 265 |
|
| 266 |
-
# Temporal step
|
| 267 |
feed_t = {
|
| 268 |
"prev_codes": prev_codes,
|
| 269 |
"enc_out": enc_out_arr,
|
|
@@ -276,17 +318,16 @@ def generate(
|
|
| 276 |
feed_t[f"past_cross_v.{i}"] = pcv[i]
|
| 277 |
|
| 278 |
t_out_dict = dict(zip(t_out_names, temp_s.run(None, feed_t)))
|
| 279 |
-
temporal_out = t_out_dict["temporal_out"]
|
| 280 |
for i in range(T_LAYERS):
|
| 281 |
psk[i] = t_out_dict[f"present_self_k.{i}"]
|
| 282 |
psv[i] = t_out_dict[f"present_self_v.{i}"]
|
| 283 |
pck[i] = t_out_dict[f"present_cross_k.{i}"]
|
| 284 |
pcv[i] = t_out_dict[f"present_cross_v.{i}"]
|
| 285 |
|
| 286 |
-
# Depth loop: generate 12 unique-scheme tokens per frame
|
| 287 |
dk = [np.zeros((1, D_W, D_H, D_HD), np.float32) for _ in range(D_LAYERS)]
|
| 288 |
dv = [np.zeros((1, D_W, D_H, D_HD), np.float32) for _ in range(D_LAYERS)]
|
| 289 |
-
depth_in = temporal_out
|
| 290 |
unique_codes = []
|
| 291 |
|
| 292 |
for level in range(NUM_CB):
|
|
@@ -299,7 +340,7 @@ def generate(
|
|
| 299 |
feed_d[f"past_v.{i}"] = dv[i]
|
| 300 |
|
| 301 |
d_out_dict = dict(zip(d_out_names, depth_s.run(None, feed_d)))
|
| 302 |
-
logits = d_out_dict["logits"]
|
| 303 |
for i in range(D_LAYERS):
|
| 304 |
dk[i] = d_out_dict[f"present_k.{i}"]
|
| 305 |
dv[i] = d_out_dict[f"present_v.{i}"]
|
|
@@ -311,7 +352,7 @@ def generate(
|
|
| 311 |
|
| 312 |
if level < NUM_CB - 1:
|
| 313 |
e_out = embed_s.run(None, {"token": np.array([token], dtype=np.int64)})
|
| 314 |
-
depth_in = e_out[0]
|
| 315 |
|
| 316 |
codec_frame = to_codec(unique_codes)
|
| 317 |
all_codec_frames.append(codec_frame)
|
|
@@ -321,13 +362,11 @@ def generate(
|
|
| 321 |
if len(all_codec_frames) < 2:
|
| 322 |
return (SAMPLE_RATE, np.zeros((FRAME_SAMPLES * 2, 2), dtype=np.float32))
|
| 323 |
|
| 324 |
-
# SpectroStream batch decode: pass all T frames, get (T-1)*1920 stereo samples
|
| 325 |
progress(0.98, desc="Decoding audio...")
|
| 326 |
codes_arr = np.array(all_codec_frames, dtype=np.int32).reshape(1, len(all_codec_frames), NUM_CB)
|
| 327 |
audio_raw = dec_s.run(None, {"codes": codes_arr})[0] # [1, (T-1)*1920, 2]
|
| 328 |
-
audio = audio_raw.squeeze(0)
|
| 329 |
|
| 330 |
-
# Clamp
|
| 331 |
audio = np.clip(audio, -1.0, 1.0).astype(np.float32)
|
| 332 |
progress(1.0, desc="Done!")
|
| 333 |
return (SAMPLE_RATE, audio)
|
|
@@ -398,8 +437,6 @@ PIANO_HTML = """
|
|
| 398 |
{midi:79,n:'G5'},{midi:81,n:'A5'},{midi:83,n:'B5'},
|
| 399 |
{midi:84,n:'C6'}
|
| 400 |
];
|
| 401 |
-
// Each entry: MIDI -> [white-key-index-of-left-neighbor, 0]
|
| 402 |
-
// C#/Db is right of C (index 0,7,14), D# right of D (1,8,15), etc.
|
| 403 |
const BLACK_POSITIONS = {
|
| 404 |
49:[0,0], 51:[1,0], 54:[3,0], 56:[4,0], 58:[5,0],
|
| 405 |
61:[7,0], 63:[8,0], 66:[10,0], 68:[11,0], 70:[12,0],
|
|
@@ -410,7 +447,6 @@ PIANO_HTML = """
|
|
| 410 |
let held = new Set();
|
| 411 |
const piano = document.getElementById('piano');
|
| 412 |
|
| 413 |
-
// Draw white keys
|
| 414 |
WHITE_NOTES.forEach((wk, idx) => {
|
| 415 |
const el = document.createElement('div');
|
| 416 |
el.className = 'white-key';
|
|
@@ -420,12 +456,9 @@ PIANO_HTML = """
|
|
| 420 |
piano.appendChild(el);
|
| 421 |
});
|
| 422 |
|
| 423 |
-
// Draw black keys
|
| 424 |
const whiteKeys = piano.querySelectorAll('.white-key');
|
| 425 |
Object.entries(BLACK_POSITIONS).forEach(([midi, [wIdx, _]]) => {
|
| 426 |
if (wIdx >= whiteKeys.length) return;
|
| 427 |
-
const ref = whiteKeys[wIdx].getBoundingClientRect
|
| 428 |
-
? whiteKeys[wIdx] : null;
|
| 429 |
const el = document.createElement('div');
|
| 430 |
el.className = 'black-key';
|
| 431 |
el.dataset.midi = midi;
|
|
@@ -476,10 +509,10 @@ PIANO_HTML = """
|
|
| 476 |
</script>
|
| 477 |
"""
|
| 478 |
|
| 479 |
-
def _generate_wrapper(prompt, notes_json, n_seconds, temperature, cfg_mcc, progress=gr.Progress()):
|
| 480 |
if not prompt.strip():
|
| 481 |
prompt = "smooth jazz piano"
|
| 482 |
-
return generate(prompt, notes_json, n_seconds, temperature, cfg_mcc, progress)
|
| 483 |
|
| 484 |
with gr.Blocks(title="Magenta RT2 - Piano (CPU)") as demo:
|
| 485 |
gr.HTML("""
|
|
@@ -510,6 +543,12 @@ with gr.Blocks(title="Magenta RT2 - Piano (CPU)") as demo:
|
|
| 510 |
value="smooth jazz piano, warm, relaxed",
|
| 511 |
lines=2
|
| 512 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
n_seconds = gr.Slider(1, 20, value=5, step=1, label="Duration (seconds)")
|
| 514 |
temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.05, label="Temperature (creativity)")
|
| 515 |
cfg_mcc = gr.Slider(0.0, 6.0, value=1.6, step=0.1, label="Style guidance strength")
|
|
@@ -522,13 +561,14 @@ with gr.Blocks(title="Magenta RT2 - Piano (CPU)") as demo:
|
|
| 522 |
<b style="color:#aaa;">How to use:</b>
|
| 523 |
Click piano keys to hold notes (click again to release) - they steer the melody.
|
| 524 |
Type a style prompt, set duration, then hit Generate.
|
| 525 |
-
<br>CPU generation: ~1-3 min for 5s audio. No MIDI device needed.
|
|
|
|
| 526 |
</div>
|
| 527 |
""")
|
| 528 |
|
| 529 |
gen_btn.click(
|
| 530 |
fn=_generate_wrapper,
|
| 531 |
-
inputs=[prompt_in, notes_state, n_seconds, temperature, cfg_mcc],
|
| 532 |
outputs=[audio_out],
|
| 533 |
)
|
| 534 |
|
|
|
|
| 22 |
COND_LEN = 144 # 12 style + 128 notes + 1 drum + 3 cfg
|
| 23 |
VOCAB_SIZE = NUM_RESERVED + NUM_CB * CODEBOOK # 12294
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# note states
|
| 26 |
NOTE_MASKED, NOTE_OFF, NOTE_ON = -1, 0, 3
|
| 27 |
DRUM_MASKED = -1
|
|
|
|
| 113 |
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 114 |
return ort.InferenceSession(path, opts, providers=["CPUExecutionProvider"])
|
| 115 |
|
| 116 |
+
# Shared sessions (same for both model sizes)
|
| 117 |
text_enc_s = _sess(f"{MODEL_PATH}/musiccoca/text_encoder.onnx")
|
| 118 |
mapper_s = _sess(f"{MODEL_PATH}/musiccoca/mapper.onnx")
|
| 119 |
vq_s = _sess(f"{MODEL_PATH}/musiccoca/pretrained_vector_quantizer.onnx")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
dec_s = _sess(f"{MODEL_PATH}/spectrostream/decoder.onnx")
|
| 121 |
sp = spm_lib.SentencePieceProcessor(model_file=f"{MODEL_PATH}/musiccoca/spm.model")
|
| 122 |
|
| 123 |
+
for name, s in [("text_enc", text_enc_s), ("mapper", mapper_s), ("vq", vq_s), ("decoder", dec_s)]:
|
|
|
|
|
|
|
|
|
|
| 124 |
ins = {i.name: i.shape for i in s.get_inputs()}
|
| 125 |
outs = {o.name: o.shape for o in s.get_outputs()}
|
| 126 |
print(f"[{name}] inputs: {ins}")
|
| 127 |
print(f"[{name}] outputs: {outs}")
|
| 128 |
|
| 129 |
+
# ---------- LLM variant loading (lazy, cached) ----------
|
| 130 |
+
_llm_cache: dict = {}
|
| 131 |
+
|
| 132 |
+
def _infer_geometry(temp_sess, depth_sess) -> tuple:
|
| 133 |
+
"""Read KV cache geometry from ONNX input shapes - works for any model size."""
|
| 134 |
+
t_in = {i.name: i.shape for i in temp_sess.get_inputs()}
|
| 135 |
+
T_LAYERS = sum(1 for k in t_in if k.startswith("past_self_k."))
|
| 136 |
+
s = t_in["past_self_k.0"] # ['B', window, heads, head_dim]
|
| 137 |
+
T_W, T_H, T_HD = int(s[1]), int(s[2]), int(s[3])
|
| 138 |
+
|
| 139 |
+
d_in = {i.name: i.shape for i in depth_sess.get_inputs()}
|
| 140 |
+
D_LAYERS = sum(1 for k in d_in if k.startswith("past_k."))
|
| 141 |
+
s = d_in["past_k.0"] # ['B', window, heads, head_dim]
|
| 142 |
+
D_W, D_H, D_HD = int(s[1]), int(s[2]), int(s[3])
|
| 143 |
+
|
| 144 |
+
return T_LAYERS, T_W, T_H, T_HD, D_LAYERS, D_W, D_H, D_HD
|
| 145 |
+
|
| 146 |
+
def _load_llm(size: str) -> dict:
|
| 147 |
+
if size in _llm_cache:
|
| 148 |
+
return _llm_cache[size]
|
| 149 |
+
base = f"{MODEL_PATH}/mrt2_{size}/onnx"
|
| 150 |
+
if not os.path.isdir(base):
|
| 151 |
+
raise FileNotFoundError(f"Model variant '{size}' not found at {base}")
|
| 152 |
+
print(f"Loading LLM sessions for mrt2_{size}...")
|
| 153 |
+
enc = _sess(f"{base}/encoder.onnx")
|
| 154 |
+
temp = _sess(f"{base}/temporal_step.onnx")
|
| 155 |
+
depth = _sess(f"{base}/depth_step.onnx")
|
| 156 |
+
embed = _sess(f"{base}/embed.onnx")
|
| 157 |
+
T_LAYERS, T_W, T_H, T_HD, D_LAYERS, D_W, D_H, D_HD = _infer_geometry(temp, depth)
|
| 158 |
+
print(f" mrt2_{size}: T_LAYERS={T_LAYERS} T_W={T_W} T_H={T_H} T_HD={T_HD} "
|
| 159 |
+
f"D_LAYERS={D_LAYERS} D_W={D_W} D_H={D_H} D_HD={D_HD}")
|
| 160 |
+
for name, s in [("enc", enc), ("temporal", temp), ("depth", depth), ("embed", embed)]:
|
| 161 |
+
ins = {i.name: i.shape for i in s.get_inputs()}
|
| 162 |
+
outs = {o.name: o.shape for o in s.get_outputs()}
|
| 163 |
+
print(f"[{name}] inputs: {ins}")
|
| 164 |
+
print(f"[{name}] outputs: {outs}")
|
| 165 |
+
result = dict(enc=enc, temp=temp, depth=depth, embed=embed,
|
| 166 |
+
T_LAYERS=T_LAYERS, T_W=T_W, T_H=T_H, T_HD=T_HD,
|
| 167 |
+
D_LAYERS=D_LAYERS, D_W=D_W, D_H=D_H, D_HD=D_HD)
|
| 168 |
+
_llm_cache[size] = result
|
| 169 |
+
return result
|
| 170 |
+
|
| 171 |
+
# Pre-warm small at startup
|
| 172 |
+
_load_llm("small")
|
| 173 |
+
|
| 174 |
+
# Detect available sizes
|
| 175 |
+
_SIZES_AVAILABLE = ["small"]
|
| 176 |
+
if os.path.isdir(f"{MODEL_PATH}/mrt2_large"):
|
| 177 |
+
_SIZES_AVAILABLE.append("large")
|
| 178 |
+
print("Large model variant detected - will be available in UI")
|
| 179 |
+
else:
|
| 180 |
+
print("No mrt2_large directory found - only small available")
|
| 181 |
+
|
| 182 |
# ---------- helper: discretize CFG ----------
|
| 183 |
def _disc_cfg(v: float, step: float, max_bin: int) -> int:
|
| 184 |
c = max(-1.0, min(7.0, v))
|
|
|
|
| 186 |
|
| 187 |
# ---------- conditioning vector ----------
|
| 188 |
def build_cond(style_tokens: list, notes: list, cfg_mcc=CFG_MCC, cfg_notes=CFG_NOTES, cfg_drums=CFG_DRUMS) -> np.ndarray:
|
| 189 |
+
"""Build 144-length cond vector, shifted by COND_OFFSET, shape [1,1,144] int64."""
|
| 190 |
out = [0] * COND_LEN
|
| 191 |
k = 0
|
| 192 |
for i in range(NUM_CB):
|
|
|
|
| 212 |
pad_mask = np.ones((1, 128), dtype=np.float32)
|
| 213 |
pad_mask[0, :len(ids_raw)+1] = 0.0
|
| 214 |
|
|
|
|
| 215 |
enc_inputs = text_enc_s.get_inputs()
|
| 216 |
feed_enc = {}
|
| 217 |
for inp in enc_inputs:
|
|
|
|
| 221 |
feed_enc[inp.name] = ids
|
| 222 |
emb = text_enc_s.run(None, feed_enc)[0] # [1, 768]
|
| 223 |
|
|
|
|
| 224 |
map_inputs = mapper_s.get_inputs()
|
| 225 |
feed_map = {}
|
| 226 |
for inp in map_inputs:
|
|
|
|
| 230 |
feed_map[inp.name] = emb
|
| 231 |
mapped = mapper_s.run(None, feed_map)[0] # [1, 768]
|
| 232 |
|
|
|
|
| 233 |
norm = np.linalg.norm(mapped)
|
| 234 |
if norm > 1e-8:
|
| 235 |
mapped = mapped / norm
|
| 236 |
|
| 237 |
+
style_tokens = vq_s.run(None, {vq_s.get_inputs()[0].name: mapped})[0]
|
|
|
|
| 238 |
return style_tokens.reshape(-1).tolist()
|
| 239 |
|
| 240 |
# ---------- sampling ----------
|
|
|
|
| 264 |
n_seconds: float,
|
| 265 |
temperature: float,
|
| 266 |
cfg_mcc: float,
|
| 267 |
+
model_size: str,
|
| 268 |
progress=gr.Progress(track_tqdm=True),
|
| 269 |
) -> tuple:
|
| 270 |
import json
|
|
|
|
| 276 |
|
| 277 |
n_frames = max(4, int(n_seconds * 25))
|
| 278 |
|
| 279 |
+
progress(0.0, desc=f"Loading {model_size} model sessions...")
|
| 280 |
+
m = _load_llm(model_size)
|
| 281 |
+
enc_s = m["enc"]
|
| 282 |
+
temp_s = m["temp"]
|
| 283 |
+
depth_s = m["depth"]
|
| 284 |
+
embed_s = m["embed"]
|
| 285 |
+
T_LAYERS, T_W, T_H, T_HD = m["T_LAYERS"], m["T_W"], m["T_H"], m["T_HD"]
|
| 286 |
+
D_LAYERS, D_W, D_H, D_HD = m["D_LAYERS"], m["D_W"], m["D_H"], m["D_HD"]
|
| 287 |
+
|
| 288 |
+
progress(0.02, desc="Encoding prompt...")
|
| 289 |
style_tokens = encode_text(prompt)
|
| 290 |
|
|
|
|
| 291 |
cond = build_cond(style_tokens, held_notes, cfg_mcc=cfg_mcc)
|
|
|
|
|
|
|
| 292 |
enc_out_arr = enc_s.run(None, {enc_s.get_inputs()[0].name: cond})[0]
|
| 293 |
|
|
|
|
| 294 |
psk = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
|
| 295 |
psv = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
|
| 296 |
pck = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
|
|
|
|
| 298 |
prev_codes = np.zeros((1, NUM_CB), dtype=np.int64)
|
| 299 |
cache_pos = 0
|
| 300 |
|
|
|
|
| 301 |
t_out_names = [o.name for o in temp_s.get_outputs()]
|
| 302 |
d_out_names = [o.name for o in depth_s.get_outputs()]
|
| 303 |
|
| 304 |
all_codec_frames = []
|
| 305 |
|
| 306 |
for f in range(n_frames):
|
| 307 |
+
progress((f + 1) / (n_frames + 1), desc=f"Generating frame {f+1}/{n_frames} [{model_size}]...")
|
| 308 |
|
|
|
|
| 309 |
feed_t = {
|
| 310 |
"prev_codes": prev_codes,
|
| 311 |
"enc_out": enc_out_arr,
|
|
|
|
| 318 |
feed_t[f"past_cross_v.{i}"] = pcv[i]
|
| 319 |
|
| 320 |
t_out_dict = dict(zip(t_out_names, temp_s.run(None, feed_t)))
|
| 321 |
+
temporal_out = t_out_dict["temporal_out"]
|
| 322 |
for i in range(T_LAYERS):
|
| 323 |
psk[i] = t_out_dict[f"present_self_k.{i}"]
|
| 324 |
psv[i] = t_out_dict[f"present_self_v.{i}"]
|
| 325 |
pck[i] = t_out_dict[f"present_cross_k.{i}"]
|
| 326 |
pcv[i] = t_out_dict[f"present_cross_v.{i}"]
|
| 327 |
|
|
|
|
| 328 |
dk = [np.zeros((1, D_W, D_H, D_HD), np.float32) for _ in range(D_LAYERS)]
|
| 329 |
dv = [np.zeros((1, D_W, D_H, D_HD), np.float32) for _ in range(D_LAYERS)]
|
| 330 |
+
depth_in = temporal_out
|
| 331 |
unique_codes = []
|
| 332 |
|
| 333 |
for level in range(NUM_CB):
|
|
|
|
| 340 |
feed_d[f"past_v.{i}"] = dv[i]
|
| 341 |
|
| 342 |
d_out_dict = dict(zip(d_out_names, depth_s.run(None, feed_d)))
|
| 343 |
+
logits = d_out_dict["logits"]
|
| 344 |
for i in range(D_LAYERS):
|
| 345 |
dk[i] = d_out_dict[f"present_k.{i}"]
|
| 346 |
dv[i] = d_out_dict[f"present_v.{i}"]
|
|
|
|
| 352 |
|
| 353 |
if level < NUM_CB - 1:
|
| 354 |
e_out = embed_s.run(None, {"token": np.array([token], dtype=np.int64)})
|
| 355 |
+
depth_in = e_out[0]
|
| 356 |
|
| 357 |
codec_frame = to_codec(unique_codes)
|
| 358 |
all_codec_frames.append(codec_frame)
|
|
|
|
| 362 |
if len(all_codec_frames) < 2:
|
| 363 |
return (SAMPLE_RATE, np.zeros((FRAME_SAMPLES * 2, 2), dtype=np.float32))
|
| 364 |
|
|
|
|
| 365 |
progress(0.98, desc="Decoding audio...")
|
| 366 |
codes_arr = np.array(all_codec_frames, dtype=np.int32).reshape(1, len(all_codec_frames), NUM_CB)
|
| 367 |
audio_raw = dec_s.run(None, {"codes": codes_arr})[0] # [1, (T-1)*1920, 2]
|
| 368 |
+
audio = audio_raw.squeeze(0)
|
| 369 |
|
|
|
|
| 370 |
audio = np.clip(audio, -1.0, 1.0).astype(np.float32)
|
| 371 |
progress(1.0, desc="Done!")
|
| 372 |
return (SAMPLE_RATE, audio)
|
|
|
|
| 437 |
{midi:79,n:'G5'},{midi:81,n:'A5'},{midi:83,n:'B5'},
|
| 438 |
{midi:84,n:'C6'}
|
| 439 |
];
|
|
|
|
|
|
|
| 440 |
const BLACK_POSITIONS = {
|
| 441 |
49:[0,0], 51:[1,0], 54:[3,0], 56:[4,0], 58:[5,0],
|
| 442 |
61:[7,0], 63:[8,0], 66:[10,0], 68:[11,0], 70:[12,0],
|
|
|
|
| 447 |
let held = new Set();
|
| 448 |
const piano = document.getElementById('piano');
|
| 449 |
|
|
|
|
| 450 |
WHITE_NOTES.forEach((wk, idx) => {
|
| 451 |
const el = document.createElement('div');
|
| 452 |
el.className = 'white-key';
|
|
|
|
| 456 |
piano.appendChild(el);
|
| 457 |
});
|
| 458 |
|
|
|
|
| 459 |
const whiteKeys = piano.querySelectorAll('.white-key');
|
| 460 |
Object.entries(BLACK_POSITIONS).forEach(([midi, [wIdx, _]]) => {
|
| 461 |
if (wIdx >= whiteKeys.length) return;
|
|
|
|
|
|
|
| 462 |
const el = document.createElement('div');
|
| 463 |
el.className = 'black-key';
|
| 464 |
el.dataset.midi = midi;
|
|
|
|
| 509 |
</script>
|
| 510 |
"""
|
| 511 |
|
| 512 |
+
def _generate_wrapper(prompt, notes_json, n_seconds, temperature, cfg_mcc, model_size, progress=gr.Progress()):
|
| 513 |
if not prompt.strip():
|
| 514 |
prompt = "smooth jazz piano"
|
| 515 |
+
return generate(prompt, notes_json, n_seconds, temperature, cfg_mcc, model_size, progress)
|
| 516 |
|
| 517 |
with gr.Blocks(title="Magenta RT2 - Piano (CPU)") as demo:
|
| 518 |
gr.HTML("""
|
|
|
|
| 543 |
value="smooth jazz piano, warm, relaxed",
|
| 544 |
lines=2
|
| 545 |
)
|
| 546 |
+
model_size_dd = gr.Dropdown(
|
| 547 |
+
choices=_SIZES_AVAILABLE,
|
| 548 |
+
value="small",
|
| 549 |
+
label="Model size (large = slower but higher quality, loads on first use)",
|
| 550 |
+
interactive=len(_SIZES_AVAILABLE) > 1,
|
| 551 |
+
)
|
| 552 |
n_seconds = gr.Slider(1, 20, value=5, step=1, label="Duration (seconds)")
|
| 553 |
temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.05, label="Temperature (creativity)")
|
| 554 |
cfg_mcc = gr.Slider(0.0, 6.0, value=1.6, step=0.1, label="Style guidance strength")
|
|
|
|
| 561 |
<b style="color:#aaa;">How to use:</b>
|
| 562 |
Click piano keys to hold notes (click again to release) - they steer the melody.
|
| 563 |
Type a style prompt, set duration, then hit Generate.
|
| 564 |
+
<br>CPU generation: ~1-3 min for 5s audio (small). No MIDI device needed.
|
| 565 |
+
Large model loads its sessions on first use (extra ~30s), then stays cached.
|
| 566 |
</div>
|
| 567 |
""")
|
| 568 |
|
| 569 |
gen_btn.click(
|
| 570 |
fn=_generate_wrapper,
|
| 571 |
+
inputs=[prompt_in, notes_state, n_seconds, temperature, cfg_mcc, model_size_dd],
|
| 572 |
outputs=[audio_out],
|
| 573 |
)
|
| 574 |
|