Nekochu committed on
Commit
69afa51
·
1 Parent(s): ec30eba

Add model size dropdown (small/large), lazy-load LLM sessions, infer KV geometry from ONNX shapes

Browse files
Files changed (1) hide show
  1. app.py +85 -45
app.py CHANGED
@@ -22,10 +22,6 @@ COND_OFFSET = 7 # NUM_RESERVED + 1, added to every conditioning integ
22
  COND_LEN = 144 # 12 style + 128 notes + 1 drum + 3 cfg
23
  VOCAB_SIZE = NUM_RESERVED + NUM_CB * CODEBOOK # 12294
24
 
25
- # mrt2_small KV geometry
26
- T_LAYERS, T_W, T_H, T_HD = 12, 41, 8, 128 # temporal
27
- D_LAYERS, D_W, D_H, D_HD = 2, 12, 6, 128 # depth
28
-
29
  # note states
30
  NOTE_MASKED, NOTE_OFF, NOTE_ON = -1, 0, 3
31
  DRUM_MASKED = -1
@@ -117,25 +113,72 @@ def _sess(path: str) -> ort.InferenceSession:
117
  opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
118
  return ort.InferenceSession(path, opts, providers=["CPUExecutionProvider"])
119
 
 
120
  text_enc_s = _sess(f"{MODEL_PATH}/musiccoca/text_encoder.onnx")
121
  mapper_s = _sess(f"{MODEL_PATH}/musiccoca/mapper.onnx")
122
  vq_s = _sess(f"{MODEL_PATH}/musiccoca/pretrained_vector_quantizer.onnx")
123
- enc_s = _sess(f"{MODEL_PATH}/mrt2_small/onnx/encoder.onnx")
124
- temp_s = _sess(f"{MODEL_PATH}/mrt2_small/onnx/temporal_step.onnx")
125
- depth_s = _sess(f"{MODEL_PATH}/mrt2_small/onnx/depth_step.onnx")
126
- embed_s = _sess(f"{MODEL_PATH}/mrt2_small/onnx/embed.onnx")
127
  dec_s = _sess(f"{MODEL_PATH}/spectrostream/decoder.onnx")
128
  sp = spm_lib.SentencePieceProcessor(model_file=f"{MODEL_PATH}/musiccoca/spm.model")
129
 
130
- # Log I/O names at startup for debugging
131
- for name, s in [("text_enc", text_enc_s), ("mapper", mapper_s), ("vq", vq_s),
132
- ("enc", enc_s), ("temporal", temp_s), ("depth", depth_s),
133
- ("embed", embed_s), ("decoder", dec_s)]:
134
  ins = {i.name: i.shape for i in s.get_inputs()}
135
  outs = {o.name: o.shape for o in s.get_outputs()}
136
  print(f"[{name}] inputs: {ins}")
137
  print(f"[{name}] outputs: {outs}")
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  # ---------- helper: discretize CFG ----------
140
  def _disc_cfg(v: float, step: float, max_bin: int) -> int:
141
  c = max(-1.0, min(7.0, v))
@@ -143,7 +186,7 @@ def _disc_cfg(v: float, step: float, max_bin: int) -> int:
143
 
144
  # ---------- conditioning vector ----------
145
  def build_cond(style_tokens: list, notes: list, cfg_mcc=CFG_MCC, cfg_notes=CFG_NOTES, cfg_drums=CFG_DRUMS) -> np.ndarray:
146
- """Build 144-length cond vector, shifted by COND_OFFSET, shape [1,1,144] int32."""
147
  out = [0] * COND_LEN
148
  k = 0
149
  for i in range(NUM_CB):
@@ -169,7 +212,6 @@ def encode_text(prompt: str) -> list:
169
  pad_mask = np.ones((1, 128), dtype=np.float32)
170
  pad_mask[0, :len(ids_raw)+1] = 0.0
171
 
172
- # text encoder - feed by position (name detection fallback)
173
  enc_inputs = text_enc_s.get_inputs()
174
  feed_enc = {}
175
  for inp in enc_inputs:
@@ -179,7 +221,6 @@ def encode_text(prompt: str) -> list:
179
  feed_enc[inp.name] = ids
180
  emb = text_enc_s.run(None, feed_enc)[0] # [1, 768]
181
 
182
- # mapper - feed by position (args_0=emb, args_1=noise)
183
  map_inputs = mapper_s.get_inputs()
184
  feed_map = {}
185
  for inp in map_inputs:
@@ -189,13 +230,11 @@ def encode_text(prompt: str) -> list:
189
  feed_map[inp.name] = emb
190
  mapped = mapper_s.run(None, feed_map)[0] # [1, 768]
191
 
192
- # L2 normalize
193
  norm = np.linalg.norm(mapped)
194
  if norm > 1e-8:
195
  mapped = mapped / norm
196
 
197
- # VQ quantizer
198
- style_tokens = vq_s.run(None, {vq_s.get_inputs()[0].name: mapped})[0] # [1, 12]
199
  return style_tokens.reshape(-1).tolist()
200
 
201
  # ---------- sampling ----------
@@ -225,6 +264,7 @@ def generate(
225
  n_seconds: float,
226
  temperature: float,
227
  cfg_mcc: float,
 
228
  progress=gr.Progress(track_tqdm=True),
229
  ) -> tuple:
230
  import json
@@ -236,17 +276,21 @@ def generate(
236
 
237
  n_frames = max(4, int(n_seconds * 25))
238
 
239
- # Style tokens from text
240
- progress(0.0, desc="Encoding prompt...")
 
 
 
 
 
 
 
 
241
  style_tokens = encode_text(prompt)
242
 
243
- # Conditioning vector
244
  cond = build_cond(style_tokens, held_notes, cfg_mcc=cfg_mcc)
245
-
246
- # Encode conditioning -> enc_out [1,1,256]
247
  enc_out_arr = enc_s.run(None, {enc_s.get_inputs()[0].name: cond})[0]
248
 
249
- # Init temporal KV cache (zeros)
250
  psk = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
251
  psv = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
252
  pck = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
@@ -254,16 +298,14 @@ def generate(
254
  prev_codes = np.zeros((1, NUM_CB), dtype=np.int64)
255
  cache_pos = 0
256
 
257
- # Cache output name lists (same every frame)
258
  t_out_names = [o.name for o in temp_s.get_outputs()]
259
  d_out_names = [o.name for o in depth_s.get_outputs()]
260
 
261
  all_codec_frames = []
262
 
263
  for f in range(n_frames):
264
- progress((f + 1) / (n_frames + 1), desc=f"Generating frame {f+1}/{n_frames} (this takes a while on CPU...)")
265
 
266
- # Temporal step
267
  feed_t = {
268
  "prev_codes": prev_codes,
269
  "enc_out": enc_out_arr,
@@ -276,17 +318,16 @@ def generate(
276
  feed_t[f"past_cross_v.{i}"] = pcv[i]
277
 
278
  t_out_dict = dict(zip(t_out_names, temp_s.run(None, feed_t)))
279
- temporal_out = t_out_dict["temporal_out"] # [1,1,1024]
280
  for i in range(T_LAYERS):
281
  psk[i] = t_out_dict[f"present_self_k.{i}"]
282
  psv[i] = t_out_dict[f"present_self_v.{i}"]
283
  pck[i] = t_out_dict[f"present_cross_k.{i}"]
284
  pcv[i] = t_out_dict[f"present_cross_v.{i}"]
285
 
286
- # Depth loop: generate 12 unique-scheme tokens per frame
287
  dk = [np.zeros((1, D_W, D_H, D_HD), np.float32) for _ in range(D_LAYERS)]
288
  dv = [np.zeros((1, D_W, D_H, D_HD), np.float32) for _ in range(D_LAYERS)]
289
- depth_in = temporal_out # [1,1,1024]
290
  unique_codes = []
291
 
292
  for level in range(NUM_CB):
@@ -299,7 +340,7 @@ def generate(
299
  feed_d[f"past_v.{i}"] = dv[i]
300
 
301
  d_out_dict = dict(zip(d_out_names, depth_s.run(None, feed_d)))
302
- logits = d_out_dict["logits"] # [1, VOCAB_SIZE]
303
  for i in range(D_LAYERS):
304
  dk[i] = d_out_dict[f"present_k.{i}"]
305
  dv[i] = d_out_dict[f"present_v.{i}"]
@@ -311,7 +352,7 @@ def generate(
311
 
312
  if level < NUM_CB - 1:
313
  e_out = embed_s.run(None, {"token": np.array([token], dtype=np.int64)})
314
- depth_in = e_out[0] # [1,1,1024]
315
 
316
  codec_frame = to_codec(unique_codes)
317
  all_codec_frames.append(codec_frame)
@@ -321,13 +362,11 @@ def generate(
321
  if len(all_codec_frames) < 2:
322
  return (SAMPLE_RATE, np.zeros((FRAME_SAMPLES * 2, 2), dtype=np.float32))
323
 
324
- # SpectroStream batch decode: pass all T frames, get (T-1)*1920 stereo samples
325
  progress(0.98, desc="Decoding audio...")
326
  codes_arr = np.array(all_codec_frames, dtype=np.int32).reshape(1, len(all_codec_frames), NUM_CB)
327
  audio_raw = dec_s.run(None, {"codes": codes_arr})[0] # [1, (T-1)*1920, 2]
328
- audio = audio_raw.squeeze(0) # [(T-1)*1920, 2]
329
 
330
- # Clamp
331
  audio = np.clip(audio, -1.0, 1.0).astype(np.float32)
332
  progress(1.0, desc="Done!")
333
  return (SAMPLE_RATE, audio)
@@ -398,8 +437,6 @@ PIANO_HTML = """
398
  {midi:79,n:'G5'},{midi:81,n:'A5'},{midi:83,n:'B5'},
399
  {midi:84,n:'C6'}
400
  ];
401
- // Each entry: MIDI -> [white-key-index-of-left-neighbor, 0]
402
- // C#/Db is right of C (index 0,7,14), D# right of D (1,8,15), etc.
403
  const BLACK_POSITIONS = {
404
  49:[0,0], 51:[1,0], 54:[3,0], 56:[4,0], 58:[5,0],
405
  61:[7,0], 63:[8,0], 66:[10,0], 68:[11,0], 70:[12,0],
@@ -410,7 +447,6 @@ PIANO_HTML = """
410
  let held = new Set();
411
  const piano = document.getElementById('piano');
412
 
413
- // Draw white keys
414
  WHITE_NOTES.forEach((wk, idx) => {
415
  const el = document.createElement('div');
416
  el.className = 'white-key';
@@ -420,12 +456,9 @@ PIANO_HTML = """
420
  piano.appendChild(el);
421
  });
422
 
423
- // Draw black keys
424
  const whiteKeys = piano.querySelectorAll('.white-key');
425
  Object.entries(BLACK_POSITIONS).forEach(([midi, [wIdx, _]]) => {
426
  if (wIdx >= whiteKeys.length) return;
427
- const ref = whiteKeys[wIdx].getBoundingClientRect
428
- ? whiteKeys[wIdx] : null;
429
  const el = document.createElement('div');
430
  el.className = 'black-key';
431
  el.dataset.midi = midi;
@@ -476,10 +509,10 @@ PIANO_HTML = """
476
  </script>
477
  """
478
 
479
- def _generate_wrapper(prompt, notes_json, n_seconds, temperature, cfg_mcc, progress=gr.Progress()):
480
  if not prompt.strip():
481
  prompt = "smooth jazz piano"
482
- return generate(prompt, notes_json, n_seconds, temperature, cfg_mcc, progress)
483
 
484
  with gr.Blocks(title="Magenta RT2 - Piano (CPU)") as demo:
485
  gr.HTML("""
@@ -510,6 +543,12 @@ with gr.Blocks(title="Magenta RT2 - Piano (CPU)") as demo:
510
  value="smooth jazz piano, warm, relaxed",
511
  lines=2
512
  )
 
 
 
 
 
 
513
  n_seconds = gr.Slider(1, 20, value=5, step=1, label="Duration (seconds)")
514
  temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.05, label="Temperature (creativity)")
515
  cfg_mcc = gr.Slider(0.0, 6.0, value=1.6, step=0.1, label="Style guidance strength")
@@ -522,13 +561,14 @@ with gr.Blocks(title="Magenta RT2 - Piano (CPU)") as demo:
522
  <b style="color:#aaa;">How to use:</b>
523
  Click piano keys to hold notes (click again to release) - they steer the melody.
524
  Type a style prompt, set duration, then hit Generate.
525
- <br>CPU generation: ~1-3 min for 5s audio. No MIDI device needed.
 
526
  </div>
527
  """)
528
 
529
  gen_btn.click(
530
  fn=_generate_wrapper,
531
- inputs=[prompt_in, notes_state, n_seconds, temperature, cfg_mcc],
532
  outputs=[audio_out],
533
  )
534
 
 
22
  COND_LEN = 144 # 12 style + 128 notes + 1 drum + 3 cfg
23
  VOCAB_SIZE = NUM_RESERVED + NUM_CB * CODEBOOK # 12294
24
 
 
 
 
 
25
  # note states
26
  NOTE_MASKED, NOTE_OFF, NOTE_ON = -1, 0, 3
27
  DRUM_MASKED = -1
 
113
  opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
114
  return ort.InferenceSession(path, opts, providers=["CPUExecutionProvider"])
115
 
116
# Sessions that do not depend on the chosen model size.
text_enc_s = _sess(f"{MODEL_PATH}/musiccoca/text_encoder.onnx")
mapper_s = _sess(f"{MODEL_PATH}/musiccoca/mapper.onnx")
vq_s = _sess(f"{MODEL_PATH}/musiccoca/pretrained_vector_quantizer.onnx")
dec_s = _sess(f"{MODEL_PATH}/spectrostream/decoder.onnx")
sp = spm_lib.SentencePieceProcessor(model_file=f"{MODEL_PATH}/musiccoca/spm.model")

# Print every shared session's input/output names and shapes once at startup.
for name, s in {
    "text_enc": text_enc_s,
    "mapper": mapper_s,
    "vq": vq_s,
    "decoder": dec_s,
}.items():
    ins = dict((i.name, i.shape) for i in s.get_inputs())
    outs = dict((o.name, o.shape) for o in s.get_outputs())
    print(f"[{name}] inputs: {ins}")
    print(f"[{name}] outputs: {outs}")
128
 
129
+ # ---------- LLM variant loading (lazy, cached) ----------
130
+ _llm_cache: dict = {}
131
+
132
def _infer_geometry(temp_sess, depth_sess) -> tuple:
    """Read KV-cache geometry from the step models' declared input shapes.

    Nothing is hard-coded, so this works for any model size: the layer count
    is the number of inputs named ``past_self_k.<i>`` (temporal model) or
    ``past_k.<i>`` (depth model), and the remaining dimensions are taken from
    layer 0, whose shape is laid out as ``['B', window, heads, head_dim]``.

    Args:
        temp_sess: Temporal-step session; must provide ``get_inputs()``
            returning objects with ``name`` and ``shape`` attributes.
        depth_sess: Depth-step session, same protocol.

    Returns:
        ``(T_LAYERS, T_W, T_H, T_HD, D_LAYERS, D_W, D_H, D_HD)`` as ints.

    Raises:
        KeyError: If a model exposes no layer-0 KV input with the expected name.
        ValueError: If window/heads/head_dim of layer 0 are missing or are not
            fixed integers (e.g. a symbolic or unknown dimension).
    """

    def _kv_dims(sess, prefix: str) -> tuple:
        # One pass over the inputs gives both the layer count and the shapes.
        shapes = {inp.name: inp.shape for inp in sess.get_inputs()}
        first = f"{prefix}0"
        if first not in shapes:
            raise KeyError(f"no '{first}' input found; cannot infer KV geometry")
        n_layers = sum(1 for input_name in shapes if input_name.startswith(prefix))
        try:
            # Index 0 is the batch dimension; only dims 1..3 describe the cache.
            window, heads, head_dim = (int(d) for d in shapes[first][1:4])
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"input '{first}' has shape {shapes[first]!r}; expected fixed "
                "integer window/heads/head_dim at positions 1-3"
            ) from err
        return n_layers, window, heads, head_dim

    return (
        *_kv_dims(temp_sess, "past_self_k."),
        *_kv_dims(depth_sess, "past_k."),
    )
145
+
146
def _load_llm(size: str) -> dict:
    """Return the sessions and KV geometry for the ``mrt2_<size>`` variant.

    The first call for a given size creates the encoder, temporal-step,
    depth-step and embed sessions, infers the KV geometry, prints both, and
    stores the result in ``_llm_cache``; later calls return the cached dict.

    Returns:
        Dict with keys ``enc``, ``temp``, ``depth``, ``embed`` (sessions) and
        ``T_LAYERS``, ``T_W``, ``T_H``, ``T_HD``, ``D_LAYERS``, ``D_W``,
        ``D_H``, ``D_HD`` (geometry).

    Raises:
        FileNotFoundError: If ``{MODEL_PATH}/mrt2_<size>/onnx`` is not a directory.
    """
    cached = _llm_cache.get(size)
    if cached is not None:
        return cached

    onnx_dir = f"{MODEL_PATH}/mrt2_{size}/onnx"
    if not os.path.isdir(onnx_dir):
        raise FileNotFoundError(f"Model variant '{size}' not found at {onnx_dir}")
    print(f"Loading LLM sessions for mrt2_{size}...")

    # Sessions are created in the order encoder, temporal, depth, embed.
    sessions = {
        "enc": _sess(f"{onnx_dir}/encoder.onnx"),
        "temp": _sess(f"{onnx_dir}/temporal_step.onnx"),
        "depth": _sess(f"{onnx_dir}/depth_step.onnx"),
        "embed": _sess(f"{onnx_dir}/embed.onnx"),
    }

    t_layers, t_w, t_h, t_hd, d_layers, d_w, d_h, d_hd = _infer_geometry(
        sessions["temp"], sessions["depth"]
    )
    geometry = {
        "T_LAYERS": t_layers, "T_W": t_w, "T_H": t_h, "T_HD": t_hd,
        "D_LAYERS": d_layers, "D_W": d_w, "D_H": d_h, "D_HD": d_hd,
    }
    print(f" mrt2_{size}: " + " ".join(f"{key}={val}" for key, val in geometry.items()))

    # Log labels differ from the dict keys for the temporal session.
    labelled = (
        ("enc", sessions["enc"]),
        ("temporal", sessions["temp"]),
        ("depth", sessions["depth"]),
        ("embed", sessions["embed"]),
    )
    for label, session in labelled:
        in_shapes = {node.name: node.shape for node in session.get_inputs()}
        out_shapes = {node.name: node.shape for node in session.get_outputs()}
        print(f"[{label}] inputs: {in_shapes}")
        print(f"[{label}] outputs: {out_shapes}")

    entry = {**sessions, **geometry}
    _llm_cache[size] = entry
    return entry
170
+
171
# Pre-warm small at startup
_load_llm("small")

# Detect available sizes. Probe the same directory _load_llm() requires
# (mrt2_large/onnx) so "large" is only offered when it can actually be loaded.
_SIZES_AVAILABLE = ["small"]
if os.path.isdir(f"{MODEL_PATH}/mrt2_large/onnx"):
    _SIZES_AVAILABLE.append("large")
    print("Large model variant detected - will be available in UI")
else:
    print("No mrt2_large/onnx directory found - only small available")
181
+
182
  # ---------- helper: discretize CFG ----------
183
  def _disc_cfg(v: float, step: float, max_bin: int) -> int:
184
  c = max(-1.0, min(7.0, v))
 
186
 
187
  # ---------- conditioning vector ----------
188
  def build_cond(style_tokens: list, notes: list, cfg_mcc=CFG_MCC, cfg_notes=CFG_NOTES, cfg_drums=CFG_DRUMS) -> np.ndarray:
189
+ """Build 144-length cond vector, shifted by COND_OFFSET, shape [1,1,144] int64."""
190
  out = [0] * COND_LEN
191
  k = 0
192
  for i in range(NUM_CB):
 
212
  pad_mask = np.ones((1, 128), dtype=np.float32)
213
  pad_mask[0, :len(ids_raw)+1] = 0.0
214
 
 
215
  enc_inputs = text_enc_s.get_inputs()
216
  feed_enc = {}
217
  for inp in enc_inputs:
 
221
  feed_enc[inp.name] = ids
222
  emb = text_enc_s.run(None, feed_enc)[0] # [1, 768]
223
 
 
224
  map_inputs = mapper_s.get_inputs()
225
  feed_map = {}
226
  for inp in map_inputs:
 
230
  feed_map[inp.name] = emb
231
  mapped = mapper_s.run(None, feed_map)[0] # [1, 768]
232
 
 
233
  norm = np.linalg.norm(mapped)
234
  if norm > 1e-8:
235
  mapped = mapped / norm
236
 
237
+ style_tokens = vq_s.run(None, {vq_s.get_inputs()[0].name: mapped})[0]
 
238
  return style_tokens.reshape(-1).tolist()
239
 
240
  # ---------- sampling ----------
 
264
  n_seconds: float,
265
  temperature: float,
266
  cfg_mcc: float,
267
+ model_size: str,
268
  progress=gr.Progress(track_tqdm=True),
269
  ) -> tuple:
270
  import json
 
276
 
277
  n_frames = max(4, int(n_seconds * 25))
278
 
279
+ progress(0.0, desc=f"Loading {model_size} model sessions...")
280
+ m = _load_llm(model_size)
281
+ enc_s = m["enc"]
282
+ temp_s = m["temp"]
283
+ depth_s = m["depth"]
284
+ embed_s = m["embed"]
285
+ T_LAYERS, T_W, T_H, T_HD = m["T_LAYERS"], m["T_W"], m["T_H"], m["T_HD"]
286
+ D_LAYERS, D_W, D_H, D_HD = m["D_LAYERS"], m["D_W"], m["D_H"], m["D_HD"]
287
+
288
+ progress(0.02, desc="Encoding prompt...")
289
  style_tokens = encode_text(prompt)
290
 
 
291
  cond = build_cond(style_tokens, held_notes, cfg_mcc=cfg_mcc)
 
 
292
  enc_out_arr = enc_s.run(None, {enc_s.get_inputs()[0].name: cond})[0]
293
 
 
294
  psk = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
295
  psv = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
296
  pck = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
 
298
  prev_codes = np.zeros((1, NUM_CB), dtype=np.int64)
299
  cache_pos = 0
300
 
 
301
  t_out_names = [o.name for o in temp_s.get_outputs()]
302
  d_out_names = [o.name for o in depth_s.get_outputs()]
303
 
304
  all_codec_frames = []
305
 
306
  for f in range(n_frames):
307
+ progress((f + 1) / (n_frames + 1), desc=f"Generating frame {f+1}/{n_frames} [{model_size}]...")
308
 
 
309
  feed_t = {
310
  "prev_codes": prev_codes,
311
  "enc_out": enc_out_arr,
 
318
  feed_t[f"past_cross_v.{i}"] = pcv[i]
319
 
320
  t_out_dict = dict(zip(t_out_names, temp_s.run(None, feed_t)))
321
+ temporal_out = t_out_dict["temporal_out"]
322
  for i in range(T_LAYERS):
323
  psk[i] = t_out_dict[f"present_self_k.{i}"]
324
  psv[i] = t_out_dict[f"present_self_v.{i}"]
325
  pck[i] = t_out_dict[f"present_cross_k.{i}"]
326
  pcv[i] = t_out_dict[f"present_cross_v.{i}"]
327
 
 
328
  dk = [np.zeros((1, D_W, D_H, D_HD), np.float32) for _ in range(D_LAYERS)]
329
  dv = [np.zeros((1, D_W, D_H, D_HD), np.float32) for _ in range(D_LAYERS)]
330
+ depth_in = temporal_out
331
  unique_codes = []
332
 
333
  for level in range(NUM_CB):
 
340
  feed_d[f"past_v.{i}"] = dv[i]
341
 
342
  d_out_dict = dict(zip(d_out_names, depth_s.run(None, feed_d)))
343
+ logits = d_out_dict["logits"]
344
  for i in range(D_LAYERS):
345
  dk[i] = d_out_dict[f"present_k.{i}"]
346
  dv[i] = d_out_dict[f"present_v.{i}"]
 
352
 
353
  if level < NUM_CB - 1:
354
  e_out = embed_s.run(None, {"token": np.array([token], dtype=np.int64)})
355
+ depth_in = e_out[0]
356
 
357
  codec_frame = to_codec(unique_codes)
358
  all_codec_frames.append(codec_frame)
 
362
  if len(all_codec_frames) < 2:
363
  return (SAMPLE_RATE, np.zeros((FRAME_SAMPLES * 2, 2), dtype=np.float32))
364
 
 
365
  progress(0.98, desc="Decoding audio...")
366
  codes_arr = np.array(all_codec_frames, dtype=np.int32).reshape(1, len(all_codec_frames), NUM_CB)
367
  audio_raw = dec_s.run(None, {"codes": codes_arr})[0] # [1, (T-1)*1920, 2]
368
+ audio = audio_raw.squeeze(0)
369
 
 
370
  audio = np.clip(audio, -1.0, 1.0).astype(np.float32)
371
  progress(1.0, desc="Done!")
372
  return (SAMPLE_RATE, audio)
 
437
  {midi:79,n:'G5'},{midi:81,n:'A5'},{midi:83,n:'B5'},
438
  {midi:84,n:'C6'}
439
  ];
 
 
440
  const BLACK_POSITIONS = {
441
  49:[0,0], 51:[1,0], 54:[3,0], 56:[4,0], 58:[5,0],
442
  61:[7,0], 63:[8,0], 66:[10,0], 68:[11,0], 70:[12,0],
 
447
  let held = new Set();
448
  const piano = document.getElementById('piano');
449
 
 
450
  WHITE_NOTES.forEach((wk, idx) => {
451
  const el = document.createElement('div');
452
  el.className = 'white-key';
 
456
  piano.appendChild(el);
457
  });
458
 
 
459
  const whiteKeys = piano.querySelectorAll('.white-key');
460
  Object.entries(BLACK_POSITIONS).forEach(([midi, [wIdx, _]]) => {
461
  if (wIdx >= whiteKeys.length) return;
 
 
462
  const el = document.createElement('div');
463
  el.className = 'black-key';
464
  el.dataset.midi = midi;
 
509
  </script>
510
  """
511
 
512
def _generate_wrapper(prompt, notes_json, n_seconds, temperature, cfg_mcc, model_size, progress=gr.Progress()):
    """Click handler for the Generate button: default the prompt, then call ``generate``.

    A prompt that is ``None``, empty, or whitespace-only is replaced with
    "smooth jazz piano"; every other argument is forwarded to ``generate``
    unchanged, and its return value is returned as-is.
    """
    # Guard against None as well as blank text so .strip() cannot raise.
    if not (prompt or "").strip():
        prompt = "smooth jazz piano"
    return generate(prompt, notes_json, n_seconds, temperature, cfg_mcc, model_size, progress)
516
 
517
  with gr.Blocks(title="Magenta RT2 - Piano (CPU)") as demo:
518
  gr.HTML("""
 
543
  value="smooth jazz piano, warm, relaxed",
544
  lines=2
545
  )
546
+ model_size_dd = gr.Dropdown(
547
+ choices=_SIZES_AVAILABLE,
548
+ value="small",
549
+ label="Model size (large = slower but higher quality, loads on first use)",
550
+ interactive=len(_SIZES_AVAILABLE) > 1,
551
+ )
552
  n_seconds = gr.Slider(1, 20, value=5, step=1, label="Duration (seconds)")
553
  temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.05, label="Temperature (creativity)")
554
  cfg_mcc = gr.Slider(0.0, 6.0, value=1.6, step=0.1, label="Style guidance strength")
 
561
  <b style="color:#aaa;">How to use:</b>
562
  Click piano keys to hold notes (click again to release) - they steer the melody.
563
  Type a style prompt, set duration, then hit Generate.
564
+ <br>CPU generation: ~1-3 min for 5s audio (small). No MIDI device needed.
565
+ Large model loads its sessions on first use (extra ~30s), then stays cached.
566
  </div>
567
  """)
568
 
569
  gen_btn.click(
570
  fn=_generate_wrapper,
571
+ inputs=[prompt_in, notes_state, n_seconds, temperature, cfg_mcc, model_size_dd],
572
  outputs=[audio_out],
573
  )
574