# Scraped from the Hugging Face file viewer (uploader: Nekochu, commit 69afa51, 24.3 kB).
# Commit message: "Add model size dropdown (small/large), lazy-load LLM sessions,
# infer KV geometry from ONNX shapes"
import base64
import io
import os
import numpy as np
import onnxruntime as ort
import sentencepiece as spm_lib
import soundfile as sf
import gradio as gr
from huggingface_hub import snapshot_download
# Request 2 OpenMP threads and disable tokenizers parallelism via environment variables.
# NOTE(review): these are assigned after numpy/onnxruntime have already been imported
# above; some native libraries read OMP_NUM_THREADS only when first loaded - confirm
# the setting actually takes effect in this order.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ---------- constants (from blanchon/magenta-realtime-2-demo/src/lib/mrt2/constants.ts) ----------
SAMPLE_RATE = 48000  # sample rate returned alongside the generated audio
FRAME_SAMPLES = 1920  # 40ms @ 48kHz
NUM_CB = 12  # RVQ codebooks per frame
CODEBOOK = 1024  # per-codebook vocab
NUM_RESERVED = 6  # reserved token ids before the codebook tokens
COND_OFFSET = 7  # NUM_RESERVED + 1, added to every conditioning integer
COND_LEN = 144  # 12 style + 128 notes + 1 drum + 3 cfg
VOCAB_SIZE = NUM_RESERVED + NUM_CB * CODEBOOK  # 12294
# note states (NOTE_OFF is not referenced anywhere in the visible code)
NOTE_MASKED, NOTE_OFF, NOTE_ON = -1, 0, 3
DRUM_MASKED = -1
# CFG defaults, used as the default arguments of build_cond()
CFG_MCC, CFG_NOTES, CFG_DRUMS = 1.6, 2.4, 4.0
# fixed mapper noise - np.random.RandomState(0).randn(768), little-endian f32, base64
_NOISE_B64 = (
"eMzhP2jhzD6Tjno/y2oPQCQM7z/iLnq//zhzP2L9Gr5oZNO9+DnSPiiAEz6iJbo/XtNCP8Aw"
"+T0LQuM+XdeqPvw9vz8CFVK+aUqgPgWmWr8vZCPAjFMnP7FLXT+H/j2/qUMRQKgour9IbTs9"
"IK0/vhwyxD/zE7w/iqoePoWewT7tRWO/vYr9v4shsr7yGSA+KnqdP5XnmT+zT8a+bceavvw2"
"hr8mw7W/EGfavwKz+T+ReAK/RkvgvplboL+cCUc/NJTOv5fYWb5MPWW/FhjGPiDEAr/1Hpe/"
"a97mvFFO2z4uOog9md2aPu9iIr82ubm+XiYsv1oXuL5bKlC/1PbcvzOvNT47ts2+V6rQv8zx"
"7D61RGi/ssRUPa6lOj8ZFAQ+4teRP8YOnr+5/80+t08vv5DsXr9+LxS/0IOfvqENZj2hI5W/"
"kZxmP09r7j6io8S/DH++P3+s8j9A4pY/Nz44vmwOib9G+IY/NW3OvhR5nD8JRlU+BAV6P6h1"
"tj774TQ/RwgsPGiX5D8+9QE+jdHNPhUL8T9eg6y/QZ+iv2IqeD/oKJa/lMj4P97F074zWT+/"
"9yL2P4KBvT8sDO8/i/JnP0l5XL8CffQ/vTeJvshtTT8bf3I/97oevk40HT+9FWw/2brAPiq5"
"jL+tspg+A8epPzPPMb/MORm+cszevq207D+CGyw/1p7QPjgZRb88DAo/EaEsv8JgAj3PxiK/"
"uyotP3WbEz9FTFW+ZMHKPnHpi7+H4b6/8/fgPnWsKj5skSI/coUYQGjJcT+4rmm/ZPqOP6dv"
"qL/RVOy+QcKLvdBO2z9BqD6/epFTv3qhyb232Sm/mzWQPzI7ir9B4JK/8yngvhz+/r7o+vY/"
"Pg1zPxFOsz0S25y/LChYPw4HgL8Pu8W/XBGYP01Goj5nvWs/RTCjPkBZWz+dqSa/EmKEv/p8"
"Lj9BrE2/VoYwv4476b50MI88sT61vmf+r78txCS/PUwOwCsPID86EM2/b1yNvw2rVT0AVD2/"
"gYHFP1Z8pb/kuog+BecgvRaElb929QU/16kvvhGURT8r0VI/dXIKQFkTqz9nBb2+0R91vqXB"
"jD9dvyc/qd8jP2r4zr+VR8e8mO88v0dSjz4SA8m9fAFpP21qoj7KTEk/fM7uvjvHcb8J8tG+"
"ZW6LvC0gwj6FmBBA1hUtvdC4dL+GJLG+dFztvr2E9j7WOMW/gY+BPUBDID7ewG0+tekYv8Gh"
"c76hR7a/bJT8vvj4Cr+DBNU+yf2Tv5n8Rz9FS78/onoEwJY+2j7YSS0/Ey8jvzZny77ZEQi+"
"DHiYvvM2nr5Lh9a/mn+TP/Ewij+kOFC/y7O7v4JkBT/XZhO/LFwRPgR/o76vCDE/FNsxP8DA"
"Ob8SErG/up3Kv9NBHD+KLJi/t74BvwmoGL/OUFe9BNj3vy1PQT65HQY/pBa1PXksn769ecc9"
"zU/MPilyMcBWW/o/ULrHPkAEJ78KK8i+ucv8PufH7b289gHApyAEQCRj4r0FlYI/LioxvwGo"
"xD+km5I+Md0bP93Khb/PBps/7JcwP+aipj9ZyiC/MEn2vl9zE0CZroe/ZjYLvqiFkT8HJMg9"
"dDwVP5WEzL73d70+Rjynv6A91D/+//G9KSAuvxWYKj934+u++8iqvz1hrL8emzE/OGcjvhDp"
"CL6C84k/1DuQv7INO7/3DsW+aDvBPfm7LL3h4pK+92t8vd7C273+Nzi/TyBQv2iNjD4DE2S/"
"OCSUv8Xkn752cyG+KG4QQD1nNL+JeXE/vEc/P1kvmL/o80U/Z4mXv+EvKsDCNxs/BsHgv+Lg"
"5j5XGy+/KWzUP+rEiD8vIui+IBYwv+Nmm7+cwOG+wYqPvh25ur76diA+/BkUP+kFsz7wnkO/"
"jQm4v/uorj++fzC/t/wmv6psBb+06eu/Arn0vnKV9b7Nzx4/Fs4yPwUhdzudjW4/5A+uPsZ3"
"gLxbyiQ+qzpDvrcpyr5fFIm+rWKQvw6Wjz5ZPX6/JnVXPxJyf75Au0o9Ldj8PkKwJD8wCsm/"
"j95TvmhTYT+IW9m/oEnGPipbEMCB4YK/EjsePT4P1L9vSny/F2W8vxb20j9SKyg+8DkRP/EE"
"ZL4C9bS+oOjOv7Vrlb4n8UK/56BbP6APkj/auLs/2EBaP2JBGb+21Y6/CkREPwNstj54X+K/"
"tgG2Pl+EUD/1W3E9tn49vg3CTr8NKLm/VOBMP0BEnr7iEW++z8ndP3c7Lz/G3L0+pngRPjGP"
"wj+BG9w/DPRtP6wMFT/6DQbA6mH9PcI6Bb6NasA9eGtxP99WL8BvvhG/5zCKPmEG774uXbW/"
"ZHRePyjCjT5Pmni/uC+hPnFTUj/vba070fFMP99GoD10W8q+5GeUv3j8r7269EY+kzZgP3e9"
"671hMuo+0PB2v2JaSL/JE+K9Ef6Gv8P7UT9rH+0+pOWOPteErT7HWAFA+A7wvmrkDMBaFUw+"
"qUVPvSF8BL+YlHq/c93gvsiwOT6YuAC/pGUaQJ3jdb+9CUu/wHgSwJHCgD7ODAHAsxkKv7Ak"
"jb67sDW/YZPeP6GQfj962ag/M+Zhv8V1kD/W8/0+3HpFP6fEgz+1pGi/KUDZvhjTXD+q9SnA"
"vLTBPxCaDT8UNDu9wsxhPuvUg7/HK7O+HtaMP5Ylpj/vjixA0WWXve2WKL/WpAO/Mk+Cv1By"
"n72B9cM+okEMvRhVjD9E1m++DeWxvgPOFL8r+tC/nazIv6bulr8ylaY/xy9lP9P+rz/phaq/"
"5fv7v3P5KL/iCTQ+VVT/PvQjhj8bjZE+xQ/fP77yY76Pv2m/KTLXv6CTY7/F7Xc+LINjv1vO"
"bz8nx7Q/UKcXwIgyXT+sVQ/ASZHNPo/InD9H04Q928yjv9LeFb9k9oW+YJ42vjDET745CuG9"
"hJpaPouymr8M1He+YlbCPz/wxL58PuO+XwKKP67JI8BqN5c/csQhv+TcJz4iRMU9l0VxPy4C"
"ib4Zky2/0B+mP6BOF8Dfk6Y80oisv3n2Qr9uuABAsqk2vVrARz5ACuS/rKI6v1hGST7NorU+"
"R+wdPwhcDTy/6QY/GlboPu806r9Qkxc9QZVEP10CFz+0S7q+ij1Ov9gkj78GMwa+wwiRP7jU"
"+b+q7ii/DOWRv/rySD885w2/a/fwvgcoXr6WCuQ+NufIvgL0QsB9Fgs/PcrgPl3PYL62wYq/"
"hhy0Pikrwj4mqPC+2+5dvr0ebr8P4Da+eHTGv9Cq1T4iwnG/UNFzPpj2s78FDhe/RUjivdCR"
"1L+m0us9oR/CvocF37+p0Ka/JukaPyhDZT8PEwe+8TzPPj83ZT5YxKg+IJukP1PlwL+ILC0/"
"rpbDviKkZb56wJq+SBPAvv/znL9FvTs+duHVP73rZb1TirW61PIvv3+W8L1ere4+a5C9vgFZ"
"6L6xeM4+XAJrvz1HgT6cAFI/yxKuPzQaub1tDa8/i2eEP8sHf79p5Zu/MiScvim0gz82C5S9"
"ssQZv+ivxj8l5ZI+noQUwOFioj5iIQU/9QVnPqpA5j7Lx4m9MsGov+rMvb7gE3K/HMhuvzms"
"ob+mrOc+xn3IPe515b4DOya/0OG/vA4jij8SRQDA9vXAPsizC78cOvG/zAz5vy6sab8dx2A+"
"iz/JPhlhcL++LYI/UyS2P9zLyj4qZhe/+OyPP51hQT9qDl4/AQ4ov1dpNcCBeQdAQzHOv4uB"
"Er0iXhhAW0GpPtEBcz+ITsC/l4rjvzZfCL+wnYs/nEexvkltS7/wt0o+2nyKP83zuL8T85q/"
"OuZJvxwdjD8OdXA+NHUIQOi6bz/2vw+9Eu6hP6ySWD66dTS/1RIuP3dCMr/urpS+yfSpP6ts"
"z72tmk2/q73tvgnKgj9Ocw2/8BPGvoyiAr/3Vjw+6l7FvvcIzb9KHmO/Q8tuvxclnz9oC1A/"
"oVYWPypfAb+311C/rOwBvwKkhr8i0h9AWrMPwN1iED82bKS/CrLVvbLtfL+MvJa/9PGRv2Oj"
"4D8eLgi+DwVEvw5IDj8skCk8IlQ4Pz6B6b/5cZs+VM9FP0Gv1L/aeeU+ehzZP7ptc7ypR1I/"
"gaorPxgfNb9y4iI9SJPIvzER575BCIg+HR05P16fyTzbUDg/CCyNv6lG0L3N7508"
)
def _decode_mapper_noise() -> np.ndarray:
    """Decode the embedded base64 mapper noise into a ``(1, 768)`` float32 array.

    Returns:
        A float32 copy of the little-endian f32 payload stored in ``_NOISE_B64``,
        reshaped to one row of 768 values.

    Raises:
        ValueError: If the payload does not decode to exactly 768 values.
    """
    raw = base64.b64decode(_NOISE_B64)
    arr = np.frombuffer(raw, dtype="<f4")
    # Explicit raise instead of ``assert`` so the check is not stripped by ``python -O``.
    if arr.size != 768:
        raise ValueError(f"noise size mismatch: {arr.size}")
    # ``astype`` copies, so the result does not alias the decoded bytes buffer.
    return arr.reshape(1, 768).astype(np.float32)
# Decoded once at import time; encode_text() feeds it to the mapper model.
MAPPER_NOISE = _decode_mapper_noise()
# ---------- model loading ----------
MODEL_ID = "blanchon/magenta-realtime-2-onnx"  # repo id handed to snapshot_download
CACHE_DIR = "/tmp/mrt2_onnx"  # cache_dir handed to snapshot_download
# NOTE: the download is a module-level call, so importing this file blocks until it returns.
print("Downloading model (first run ~2.7 GB)...")
MODEL_PATH = snapshot_download(MODEL_ID, cache_dir=CACHE_DIR)
print(f"Model path: {MODEL_PATH}")
def _sess(path: str) -> ort.InferenceSession:
    """Open the ONNX model at ``path`` as a CPU-only inference session.

    The session uses 2 intra-op threads, 2 inter-op threads and the
    ``ORT_ENABLE_ALL`` graph optimization level.
    """
    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session_options.inter_op_num_threads = 2
    session_options.intra_op_num_threads = 2
    cpu_only = ["CPUExecutionProvider"]
    return ort.InferenceSession(path, session_options, providers=cpu_only)
# Shared sessions (same for both model sizes)
# All four are created at import time from files inside the downloaded snapshot.
text_enc_s = _sess(f"{MODEL_PATH}/musiccoca/text_encoder.onnx")
mapper_s = _sess(f"{MODEL_PATH}/musiccoca/mapper.onnx")
vq_s = _sess(f"{MODEL_PATH}/musiccoca/pretrained_vector_quantizer.onnx")
dec_s = _sess(f"{MODEL_PATH}/spectrostream/decoder.onnx")
# SentencePiece tokenizer used by encode_text().
sp = spm_lib.SentencePieceProcessor(model_file=f"{MODEL_PATH}/musiccoca/spm.model")
# Print each shared model's declared input/output names and shapes for debugging.
# The loop variables (name, s, ins, outs) remain defined at module level afterwards.
for name, s in [("text_enc", text_enc_s), ("mapper", mapper_s), ("vq", vq_s), ("decoder", dec_s)]:
    ins = {i.name: i.shape for i in s.get_inputs()}
    outs = {o.name: o.shape for o in s.get_outputs()}
    print(f"[{name}] inputs: {ins}")
    print(f"[{name}] outputs: {outs}")
# ---------- LLM variant loading (lazy, cached) ----------
# Maps a size name (e.g. "small") to the dict of sessions + KV geometry built by _load_llm().
_llm_cache: dict = {}
def _infer_geometry(temp_sess, depth_sess) -> tuple:
    """Derive the KV-cache geometry of both step models from their ONNX inputs.

    For the temporal model the inputs named ``past_self_k.<n>`` are inspected,
    for the depth model the inputs named ``past_k.<n>``. The layer count is the
    number of such inputs; window, heads and head_dim are read from dimensions
    1..3 of the layer-0 input shape.

    Returns:
        ``(T_LAYERS, T_W, T_H, T_HD, D_LAYERS, D_W, D_H, D_HD)``.
    """

    def _describe(session, prefix: str) -> tuple:
        # Collect name -> declared shape for every input of the session.
        shapes = {}
        for node in session.get_inputs():
            shapes[node.name] = node.shape
        layer_count = len([key for key in shapes if key.startswith(prefix)])
        # Layer 0 is taken as representative: ['B', window, heads, head_dim]
        first = shapes[prefix + "0"]
        return layer_count, int(first[1]), int(first[2]), int(first[3])

    return _describe(temp_sess, "past_self_k.") + _describe(depth_sess, "past_k.")
def _load_llm(size: str) -> dict:
    """Return the ONNX sessions and KV geometry for one model size, loading once.

    The result is memoized in ``_llm_cache``. It holds the keys ``enc``,
    ``temp``, ``depth`` and ``embed`` (sessions) followed by the eight geometry
    values reported by ``_infer_geometry``.

    Raises:
        FileNotFoundError: If ``<MODEL_PATH>/mrt2_<size>/onnx`` is not a directory.
    """
    try:
        return _llm_cache[size]
    except KeyError:
        pass

    base = f"{MODEL_PATH}/mrt2_{size}/onnx"
    if not os.path.isdir(base):
        raise FileNotFoundError(f"Model variant '{size}' not found at {base}")
    print(f"Loading LLM sessions for mrt2_{size}...")

    # (result key, label used in the log lines, file name) - loaded in this order.
    model_files = (
        ("enc", "enc", "encoder.onnx"),
        ("temp", "temporal", "temporal_step.onnx"),
        ("depth", "depth", "depth_step.onnx"),
        ("embed", "embed", "embed.onnx"),
    )
    sessions = {}
    for key, _label, filename in model_files:
        sessions[key] = _sess(f"{base}/{filename}")

    geometry_keys = ("T_LAYERS", "T_W", "T_H", "T_HD", "D_LAYERS", "D_W", "D_H", "D_HD")
    geometry = dict(zip(geometry_keys, _infer_geometry(sessions["temp"], sessions["depth"])))
    summary = " ".join(f"{key}={value}" for key, value in geometry.items())
    print(f" mrt2_{size}: {summary}")

    for key, label, _filename in model_files:
        session = sessions[key]
        ins = {node.name: node.shape for node in session.get_inputs()}
        outs = {node.name: node.shape for node in session.get_outputs()}
        print(f"[{label}] inputs: {ins}")
        print(f"[{label}] outputs: {outs}")

    result = {**sessions, **geometry}
    _llm_cache[size] = result
    return result
# Pre-warm small at startup
# (_load_llm raises FileNotFoundError if the mrt2_small/onnx directory is missing)
_load_llm("small")
# Detect available sizes
# "small" is always offered; "large" is added only when its directory exists.
# NOTE(review): this checks mrt2_large while _load_llm checks mrt2_large/onnx - confirm
# the two cannot disagree for the published snapshot.
_SIZES_AVAILABLE = ["small"]
if os.path.isdir(f"{MODEL_PATH}/mrt2_large"):
    _SIZES_AVAILABLE.append("large")
    print("Large model variant detected - will be available in UI")
else:
    print("No mrt2_large directory found - only small available")
# ---------- helper: discretize CFG ----------
def _disc_cfg(v: float, step: float, max_bin: int) -> int:
    """Quantize a CFG strength into an integer bin index.

    ``v`` is clamped to [-1.0, 7.0], shifted so that -1.0 maps to bin 0,
    divided by ``step`` and rounded with the built-in ``round``; the result is
    finally clamped to [0, max_bin].
    """
    clamped = v if v < 7.0 else 7.0
    if not clamped > -1.0:
        clamped = -1.0
    bin_index = round((clamped + 1.0) / step)
    if not bin_index < max_bin:
        bin_index = max_bin
    return bin_index if bin_index > 0 else 0
# ---------- conditioning vector ----------
def build_cond(style_tokens: list, notes: list, cfg_mcc=CFG_MCC, cfg_notes=CFG_NOTES, cfg_drums=CFG_DRUMS) -> np.ndarray:
    """Assemble the conditioning vector as an int64 array of shape [1, 1, 144].

    Layout (every entry is shifted by ``COND_OFFSET``):
      * ``NUM_CB`` style slots - missing style tokens are filled with the masked value;
      * 128 pitch slots - ``NOTE_ON`` for pitches listed in ``notes``, masked otherwise;
      * 1 drum slot - always masked;
      * 3 CFG bins for ``cfg_mcc``, ``cfg_notes`` and ``cfg_drums``.
    """
    masked_style = NOTE_MASKED + COND_OFFSET
    style_part = [
        style_tokens[slot] + COND_OFFSET if slot < len(style_tokens) else masked_style
        for slot in range(NUM_CB)
    ]
    note_part = [
        (NOTE_ON if pitch in notes else NOTE_MASKED) + COND_OFFSET
        for pitch in range(128)
    ]
    drum_part = [DRUM_MASKED + COND_OFFSET]  # drums
    cfg_part = [
        _disc_cfg(cfg_mcc, 0.2, 40) + COND_OFFSET,
        _disc_cfg(cfg_notes, 0.2, 40) + COND_OFFSET,
        _disc_cfg(cfg_drums, 1.0, 8) + COND_OFFSET,
    ]
    values = style_part + note_part + drum_part + cfg_part
    return np.array(values, dtype=np.int64).reshape(1, 1, COND_LEN)
# ---------- MusicCoCa pipeline ----------
def encode_text(prompt: str) -> list:
    """Turn a text prompt into a flat list of style tokens.

    Pipeline: lower-case + SentencePiece encode (at most 127 pieces after a
    BOS id of 1, zero-padded to 128) -> text encoder -> mapper (fed the fixed
    ``MAPPER_NOISE``) -> L2 normalisation -> vector quantizer.
    """
    piece_ids = sp.encode(prompt.lower(), out_type=int)[:127]
    used = len(piece_ids) + 1  # BOS plus the kept pieces

    ids = np.zeros((1, 128), dtype=np.int32)
    ids[0, 0] = 1  # BOS
    ids[0, 1:used] = piece_ids
    # 0.0 on the occupied positions, 1.0 on the padding.
    pad_mask = (np.arange(128) >= used).astype(np.float32).reshape(1, 128)

    # Any encoder input whose name mentions "padding" gets the mask, the rest get the ids.
    feed_enc = {
        inp.name: pad_mask if "padding" in inp.name.lower() else ids
        for inp in text_enc_s.get_inputs()
    }
    emb = text_enc_s.run(None, feed_enc)[0]  # [1, 768]

    # Mapper inputs with a "1" in their name receive the noise, the others the embedding.
    feed_map = {
        inp.name: MAPPER_NOISE if "1" in inp.name else emb
        for inp in mapper_s.get_inputs()
    }
    mapped = mapper_s.run(None, feed_map)[0]  # [1, 768]

    norm = np.linalg.norm(mapped)
    if norm > 1e-8:
        mapped = mapped / norm

    vq_input_name = vq_s.get_inputs()[0].name
    style_tokens = vq_s.run(None, {vq_input_name: mapped})[0]
    return style_tokens.reshape(-1).tolist()
# ---------- sampling ----------
def _topk_sample(logits: np.ndarray, lo: int, hi: int, temperature: float, top_k: int = 20) -> int:
    """Draw one token id from the codebook slice [lo, hi) of the last logits row.

    ``logits`` may be [B, T, vocab] or [B, vocab]; only the final row is used.
    When ``0 < top_k < slice length`` everything below the k-th largest logit
    is pushed to -1e9 before a temperature-scaled softmax. The draw uses the
    global ``np.random`` state. The returned id is offset by ``lo``.
    """
    vocab = logits.shape[-1]
    last_step = logits.reshape(-1, vocab)[-1]  # [vocab_size]
    window = last_step[lo:hi].copy()
    if 0 < top_k < len(window):
        kth_largest = np.sort(window)[-top_k]
        window[window < kth_largest] = -1e9
    scaled = window / max(temperature, 1e-6)
    scaled -= scaled.max()  # shift so the largest exponent is 0
    weights = np.exp(scaled)
    weights /= weights.sum()
    choice = np.random.choice(len(weights), p=weights)
    return lo + int(choice)
# ---------- unique-scheme -> codec codes ----------
def to_codec(unique_codes: list) -> list:
    """Map unique-scheme token ids to per-codebook SpectroStream codes (0-1023)."""
    codec = []
    for token in unique_codes:
        # Python's % is already non-negative for a positive modulus, so one
        # reduction gives the same value as the add-and-reduce-again form.
        codec.append((token - NUM_RESERVED) % CODEBOOK)
    return codec
# ---------- main generation ----------
def _parse_held_notes(notes_json: str) -> list:
    """Parse the piano widget's JSON array of note numbers; return [] if malformed."""
    import json

    try:
        return [int(x) for x in json.loads(notes_json or "[]")]
    except (ValueError, TypeError, OverflowError) as err:
        # Best effort: json.JSONDecodeError is a ValueError; TypeError covers
        # non-iterable payloads / non-numeric items; OverflowError covers
        # Infinity. Generation proceeds without note steering.
        print(f"Ignoring malformed held-notes payload: {err!r}")
        return []


def _sample_frame_tokens(depth_s, embed_s, d_out_names, temporal_out, d_geometry, temperature) -> list:
    """Sample the NUM_CB codebook tokens of one frame with the depth model."""
    d_layers, d_w, d_h, d_hd = d_geometry
    # The depth KV cache starts from zeros for every frame.
    dk = [np.zeros((1, d_w, d_h, d_hd), np.float32) for _ in range(d_layers)]
    dv = [np.zeros((1, d_w, d_h, d_hd), np.float32) for _ in range(d_layers)]
    depth_in = temporal_out
    unique_codes = []
    for level in range(NUM_CB):
        feed_d = {
            "depth_in": depth_in,
            "level": np.array([level], dtype=np.int64),
        }
        for i in range(d_layers):
            feed_d[f"past_k.{i}"] = dk[i]
            feed_d[f"past_v.{i}"] = dv[i]
        d_out_dict = dict(zip(d_out_names, depth_s.run(None, feed_d)))
        for i in range(d_layers):
            dk[i] = d_out_dict[f"present_k.{i}"]
            dv[i] = d_out_dict[f"present_v.{i}"]
        # Each level samples only from its own 1024-wide slice of the vocabulary.
        lo = NUM_RESERVED + level * CODEBOOK
        token = _topk_sample(d_out_dict["logits"], lo, lo + CODEBOOK, temperature)
        unique_codes.append(token)
        # The sampled token is embedded as the next level's input; the last
        # level has no successor, so no embedding is needed.
        if level < NUM_CB - 1:
            depth_in = embed_s.run(None, {"token": np.array([token], dtype=np.int64)})[0]
    return unique_codes


def generate(
    prompt: str,
    notes_json: str,
    n_seconds: float,
    temperature: float,
    cfg_mcc: float,
    model_size: str,
    progress=gr.Progress(track_tqdm=True),
) -> tuple:
    """Generate audio for a text prompt and an optional set of held notes.

    Args:
        prompt: Style description passed to ``encode_text``.
        notes_json: JSON array of note numbers; malformed input is treated as
            "no notes" rather than aborting the request.
        n_seconds: Requested duration; converted at 25 frames per second with a
            minimum of 4 frames.
        temperature: Sampling temperature forwarded to ``_topk_sample``.
        cfg_mcc: Style guidance strength forwarded to ``build_cond``.
        model_size: Variant name handed to ``_load_llm``.
        progress: Gradio progress callback.

    Returns:
        ``(SAMPLE_RATE, audio)`` where ``audio`` is the decoder output with the
        leading batch axis removed, clipped to [-1, 1] as float32.

    Raises:
        FileNotFoundError: Propagated from ``_load_llm`` for an unknown size.
    """
    held_notes = _parse_held_notes(notes_json)
    n_frames = max(4, int(n_seconds * 25))

    progress(0.0, desc=f"Loading {model_size} model sessions...")
    m = _load_llm(model_size)
    enc_s = m["enc"]
    temp_s = m["temp"]
    depth_s = m["depth"]
    embed_s = m["embed"]
    T_LAYERS, T_W, T_H, T_HD = m["T_LAYERS"], m["T_W"], m["T_H"], m["T_HD"]
    d_geometry = (m["D_LAYERS"], m["D_W"], m["D_H"], m["D_HD"])

    progress(0.02, desc="Encoding prompt...")
    style_tokens = encode_text(prompt)
    cond = build_cond(style_tokens, held_notes, cfg_mcc=cfg_mcc)
    enc_out_arr = enc_s.run(None, {enc_s.get_inputs()[0].name: cond})[0]

    # Temporal KV caches start from zeros and are carried across frames.
    # NOTE(review): the cross-attention caches are allocated with the
    # self-attention geometry read from past_self_k.0 - confirm that the
    # past_cross_k/v inputs really share that shape.
    psk = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
    psv = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
    pck = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
    pcv = [np.zeros((1, T_W, T_H, T_HD), np.float32) for _ in range(T_LAYERS)]
    prev_codes = np.zeros((1, NUM_CB), dtype=np.int64)
    cache_pos = 0

    t_out_names = [o.name for o in temp_s.get_outputs()]
    d_out_names = [o.name for o in depth_s.get_outputs()]
    all_codec_frames = []

    for f in range(n_frames):
        progress((f + 1) / (n_frames + 1), desc=f"Generating frame {f+1}/{n_frames} [{model_size}]...")

        # One temporal step: previous frame's tokens + caches in, hidden state + caches out.
        feed_t = {
            "prev_codes": prev_codes,
            "enc_out": enc_out_arr,
            "cache_pos": np.array([cache_pos], dtype=np.int64),
        }
        for i in range(T_LAYERS):
            feed_t[f"past_self_k.{i}"] = psk[i]
            feed_t[f"past_self_v.{i}"] = psv[i]
            feed_t[f"past_cross_k.{i}"] = pck[i]
            feed_t[f"past_cross_v.{i}"] = pcv[i]
        t_out_dict = dict(zip(t_out_names, temp_s.run(None, feed_t)))
        temporal_out = t_out_dict["temporal_out"]
        for i in range(T_LAYERS):
            psk[i] = t_out_dict[f"present_self_k.{i}"]
            psv[i] = t_out_dict[f"present_self_v.{i}"]
            pck[i] = t_out_dict[f"present_cross_k.{i}"]
            pcv[i] = t_out_dict[f"present_cross_v.{i}"]

        unique_codes = _sample_frame_tokens(
            depth_s, embed_s, d_out_names, temporal_out, d_geometry, temperature
        )
        all_codec_frames.append(to_codec(unique_codes))
        # The unique-scheme ids (not the codec codes) are fed back to the temporal model.
        prev_codes = np.array([unique_codes], dtype=np.int64)
        cache_pos += 1

    # Defensive only: n_frames is at least 4, so this cannot trigger today.
    if len(all_codec_frames) < 2:
        return (SAMPLE_RATE, np.zeros((FRAME_SAMPLES * 2, 2), dtype=np.float32))

    progress(0.98, desc="Decoding audio...")
    codes_arr = np.array(all_codec_frames, dtype=np.int32).reshape(1, len(all_codec_frames), NUM_CB)
    audio_raw = dec_s.run(None, {"codes": codes_arr})[0]  # [1, (T-1)*1920, 2]
    audio = audio_raw.squeeze(0)
    audio = np.clip(audio, -1.0, 1.0).astype(np.float32)
    progress(1.0, desc="Done!")
    return (SAMPLE_RATE, audio)
# ---------- Gradio UI ----------
# HTML + CSS + JS for the on-screen piano (white keys C3..C6 plus their black keys).
# Clicking a key toggles its MIDI number in a JS-side Set; the set is written as a JSON
# array into the `#notes-state` textarea and input/change events are dispatched on it.
# NOTE(review): whether a <script> injected through gr.HTML is executed, and whether the
# `#notes-state` textarea exists in the DOM while the Textbox is created with
# visible=False, depends on the Gradio version - confirm in the deployed environment.
PIANO_HTML = """
<style>
#piano-wrap {
display: flex; flex-direction: column; align-items: center;
background: #111; border-radius: 12px; padding: 20px; gap: 12px;
font-family: 'Segoe UI', sans-serif; color: #ddd;
}
#piano-label { font-size: 14px; color: #aaa; }
#piano {
display: flex; position: relative;
height: 120px; gap: 2px;
}
.white-key {
width: 34px; height: 120px;
background: linear-gradient(180deg, #e8e8e8 0%, #fff 100%);
border: 1px solid #555; border-radius: 0 0 6px 6px;
cursor: pointer; position: relative; flex-shrink: 0;
transition: background 0.08s;
box-shadow: 0 4px 6px rgba(0,0,0,0.5);
}
.white-key.active {
background: linear-gradient(180deg, #a0c4ff 0%, #7eb8ff 100%);
box-shadow: 0 2px 4px rgba(0,0,0,0.5);
}
.white-key .note-name {
position: absolute; bottom: 6px; left: 50%; transform: translateX(-50%);
font-size: 9px; color: #888; user-select: none;
}
.black-key {
width: 22px; height: 75px;
background: linear-gradient(180deg, #222 0%, #000 100%);
border: 1px solid #000; border-radius: 0 0 4px 4px;
cursor: pointer; position: absolute; z-index: 2;
transition: background 0.08s;
box-shadow: 0 4px 6px rgba(0,0,0,0.8);
}
.black-key.active {
background: linear-gradient(180deg, #4a90e2 0%, #2d6abf 100%);
}
#held-display {
font-size: 12px; color: #7eb8ff; min-height: 18px;
}
#clear-btn {
padding: 5px 14px; background: #333; color: #ccc; border: 1px solid #555;
border-radius: 6px; cursor: pointer; font-size: 12px;
}
#clear-btn:hover { background: #444; }
</style>
<div id="piano-wrap">
<div id="piano-label">Click keys to hold notes - click again to release</div>
<div id="piano" tabindex="0"></div>
<div id="held-display">No notes held</div>
<button id="clear-btn" onclick="clearNotes()">Clear Notes</button>
</div>
<script>
(function() {
const WHITE_NOTES = [
{midi:48,n:'C3'},{midi:50,n:'D3'},{midi:52,n:'E3'},{midi:53,n:'F3'},
{midi:55,n:'G3'},{midi:57,n:'A3'},{midi:59,n:'B3'},
{midi:60,n:'C4'},{midi:62,n:'D4'},{midi:64,n:'E4'},{midi:65,n:'F4'},
{midi:67,n:'G4'},{midi:69,n:'A4'},{midi:71,n:'B4'},
{midi:72,n:'C5'},{midi:74,n:'D5'},{midi:76,n:'E5'},{midi:77,n:'F5'},
{midi:79,n:'G5'},{midi:81,n:'A5'},{midi:83,n:'B5'},
{midi:84,n:'C6'}
];
const BLACK_POSITIONS = {
49:[0,0], 51:[1,0], 54:[3,0], 56:[4,0], 58:[5,0],
61:[7,0], 63:[8,0], 66:[10,0], 68:[11,0], 70:[12,0],
73:[14,0], 75:[15,0], 78:[17,0], 80:[18,0], 82:[19,0]
};
const KEY_WIDTH = 36;
const KEY_STEP = 38; // KEY_WIDTH + gap(2)
let held = new Set();
const piano = document.getElementById('piano');
WHITE_NOTES.forEach((wk, idx) => {
const el = document.createElement('div');
el.className = 'white-key';
el.dataset.midi = wk.midi;
el.innerHTML = `<span class="note-name">${wk.n}</span>`;
el.addEventListener('mousedown', (e) => { e.preventDefault(); toggleNote(wk.midi); });
piano.appendChild(el);
});
const whiteKeys = piano.querySelectorAll('.white-key');
Object.entries(BLACK_POSITIONS).forEach(([midi, [wIdx, _]]) => {
if (wIdx >= whiteKeys.length) return;
const el = document.createElement('div');
el.className = 'black-key';
el.dataset.midi = midi;
el.style.left = (wIdx * KEY_STEP + KEY_WIDTH * 0.65) + 'px';
el.addEventListener('mousedown', (e) => { e.preventDefault(); toggleNote(parseInt(midi)); });
piano.appendChild(el);
});
function toggleNote(midi) {
if (held.has(midi)) { held.delete(midi); } else { held.add(midi); }
updateUI();
pushNotes();
}
window.clearNotes = function() {
held.clear(); updateUI(); pushNotes();
};
function updateUI() {
piano.querySelectorAll('.white-key,.black-key').forEach(el => {
const m = parseInt(el.dataset.midi);
el.classList.toggle('active', held.has(m));
});
const disp = document.getElementById('held-display');
if (held.size === 0) {
disp.textContent = 'No notes held';
} else {
const names = [...held].sort((a,b)=>a-b).map(midiName).join(', ');
disp.textContent = 'Held: ' + names;
}
}
function midiName(m) {
const N = ['C','C#','D','D#','E','F','F#','G','G#','A','A#','B'];
return N[m % 12] + Math.floor(m/12 - 1);
}
function pushNotes() {
const json = JSON.stringify([...held]);
const tb = document.querySelector('#notes-state textarea');
if (tb) {
tb.value = json;
tb.dispatchEvent(new Event('input', { bubbles: true }));
tb.dispatchEvent(new Event('change', { bubbles: true }));
}
}
})();
</script>
"""
def _generate_wrapper(prompt, notes_json, n_seconds, temperature, cfg_mcc, model_size, progress=gr.Progress()):
    """Click handler: substitute a default prompt when none is given, then call ``generate``.

    A missing (``None``), empty or whitespace-only prompt is replaced with
    "smooth jazz piano". All other arguments are forwarded unchanged and the
    ``(sample_rate, audio)`` tuple from ``generate`` is returned.
    """
    # ``prompt or ""`` keeps a None value from raising AttributeError on .strip().
    if not (prompt or "").strip():
        prompt = "smooth jazz piano"
    return generate(prompt, notes_json, n_seconds, temperature, cfg_mcc, model_size, progress)
# Gradio layout: header, a row with the piano (left) and the controls (right), then the
# audio output and a usage note.
# NOTE(review): the original indentation was lost; the nesting below (audio output, help
# text and click wiring directly under Blocks, after the Row) is a reconstruction - confirm.
with gr.Blocks(title="Magenta RT2 - Piano (CPU)") as demo:
    gr.HTML("""
<div style="text-align:center; padding: 16px 0 8px 0; background:#0a0a0a;">
<h1 style="color:#7eb8ff; font-size:1.8em; margin:0;">
Magenta RealTime 2 - Piano (CPU)
</h1>
<p style="color:#888; margin:6px 0 0 0; font-size:0.9em;">
Real-time music generation steered by text + piano notes - running on CPU via ONNX
(first generation takes a few minutes)
</p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML(PIANO_HTML)
            # Hidden textbox that the piano script targets via '#notes-state textarea'.
            notes_state = gr.Textbox(
                value="[]",
                elem_id="notes-state",
                visible=False,
                label="held notes json"
            )
        with gr.Column(scale=1):
            prompt_in = gr.Textbox(
                label="Music Prompt",
                value="smooth jazz piano, warm, relaxed",
                lines=2
            )
            # Only interactive when more than one model size was detected at startup.
            model_size_dd = gr.Dropdown(
                choices=_SIZES_AVAILABLE,
                value="small",
                label="Model size (large = slower but higher quality, loads on first use)",
                interactive=len(_SIZES_AVAILABLE) > 1,
            )
            n_seconds = gr.Slider(1, 20, value=5, step=1, label="Duration (seconds)")
            temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.05, label="Temperature (creativity)")
            cfg_mcc = gr.Slider(0.0, 6.0, value=1.6, step=0.1, label="Style guidance strength")
            gen_btn = gr.Button("Generate Music", variant="primary")
    audio_out = gr.Audio(label="Generated Audio", autoplay=False)
    gr.HTML("""
<div style="background:#111; border-radius:8px; padding:12px; margin-top:8px; color:#888; font-size:0.82em;">
<b style="color:#aaa;">How to use:</b>
Click piano keys to hold notes (click again to release) - they steer the melody.
Type a style prompt, set duration, then hit Generate.
<br>CPU generation: ~1-3 min for 5s audio (small). No MIDI device needed.
Large model loads its sessions on first use (extra ~30s), then stays cached.
</div>
""")
    # The input order must match the positional parameters of _generate_wrapper.
    gen_btn.click(
        fn=_generate_wrapper,
        inputs=[prompt_in, notes_state, n_seconds, temperature, cfg_mcc, model_size_dd],
        outputs=[audio_out],
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()