Spaces:
Sleeping
Sleeping
add logic test script
Browse files- test_logic.py +140 -0
test_logic.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test conditioning logic, noise decode, sampling - no model download needed."""
|
| 2 |
+
import base64
|
| 3 |
+
import sys
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
# --- noise decode test ---
|
| 7 |
+
_NOISE_B64 = (
|
| 8 |
+
"eMzhP2jhzD6Tjno/y2oPQCQM7z/iLnq//zhzP2L9Gr5oZNO9+DnSPiiAEz6iJbo/XtNCP8Aw"
|
| 9 |
+
"+T0LQuM+XdeqPvw9vz8CFVK+aUqgPgWmWr8vZCPAjFMnP7FLXT+H/j2/qUMRQKgour9IbTs9"
|
| 10 |
+
"IK0/vhwyxD/zE7w/iqoePoWewT7tRWO/vYr9v4shsr7yGSA+KnqdP5XnmT+zT8a+bceavvw2"
|
| 11 |
+
"hr8mw7W/EGfavwKz+T+ReAK/RkvgvplboL+cCUc/NJTOv5fYWb5MPWW/FhjGPiDEAr/1Hpe/"
|
| 12 |
+
"a97mvFFO2z4uOog9md2aPu9iIr82ubm+XiYsv1oXuL5bKlC/1PbcvzOvNT47ts2+V6rQv8zx"
|
| 13 |
+
"7D61RGi/ssRUPa6lOj8ZFAQ+4teRP8YOnr+5/80+t08vv5DsXr9+LxS/0IOfvqENZj2hI5W/"
|
| 14 |
+
"kZxmP09r7j6io8S/DH++P3+s8j9A4pY/Nz44vmwOib9G+IY/NW3OvhR5nD8JRlU+BAV6P6h1"
|
| 15 |
+
"tj774TQ/RwgsPGiX5D8+9QE+jdHNPhUL8T9eg6y/QZ+iv2IqeD/oKJa/lMj4P97F074zWT+/"
|
| 16 |
+
"9yL2P4KBvT8sDO8/i/JnP0l5XL8CffQ/vTeJvshtTT8bf3I/97oevk40HT+9FWw/2brAPiq5"
|
| 17 |
+
"jL+tspg+A8epPzPPMb/MORm+cszevq207D+CGyw/1p7QPjgZRb88DAo/EaEsv8JgAj3PxiK/"
|
| 18 |
+
"uyotP3WbEz9FTFW+ZMHKPnHpi7+H4b6/8/fgPnWsKj5skSI/coUYQGjJcT+4rmm/ZPqOP6dv"
|
| 19 |
+
"qL/RVOy+QcKLvdBO2z9BqD6/epFTv3qhyb232Sm/mzWQPzI7ir9B4JK/8yngvhz+/r7o+vY/"
|
| 20 |
+
"Pg1zPxFOsz0S25y/LChYPw4HgL8Pu8W/XBGYP01Goj5nvWs/RTCjPkBZWz+dqSa/EmKEv/p8"
|
| 21 |
+
"Lj9BrE2/VoYwv4476b50MI88sT61vmf+r78txCS/PUwOwCsPID86EM2/b1yNvw2rVT0AVD2/"
|
| 22 |
+
"gYHFP1Z8pb/kuog+BecgvRaElb929QU/16kvvhGURT8r0VI/dXIKQFkTqz9nBb2+0R91vqXB"
|
| 23 |
+
"jD9dvyc/qd8jP2r4zr+VR8e8mO88v0dSjz4SA8m9fAFpP21qoj7KTEk/fM7uvjvHcb8J8tG+"
|
| 24 |
+
"ZW6LvC0gwj6FmBBA1hUtvdC4dL+GJLG+dFztvr2E9j7WOMW/gY+BPUBDID7ewG0+tekYv8Gh"
|
| 25 |
+
"c76hR7a/bJT8vvj4Cr+DBNU+yf2Tv5n8Rz9FS78/onoEwJY+2j7YSS0/Ey8jvzZny77ZEQi+"
|
| 26 |
+
"DHiYvvM2nr5Lh9a/mn+TP/Ewij+kOFC/y7O7v4JkBT/XZhO/LFwRPgR/o76vCDE/FNsxP8DA"
|
| 27 |
+
"Ob8SErG/up3Kv9NBHD+KLJi/t74BvwmoGL/OUFe9BNj3vy1PQT65HQY/pBa1PXksn769ecc9"
|
| 28 |
+
"zU/MPilyMcBWW/o/ULrHPkAEJ78KK8i+ucv8PufH7b289gHApyAEQCRj4r0FlYI/LioxvwGo"
|
| 29 |
+
"xD+km5I+Md0bP93Khb/PBps/7JcwP+aipj9ZyiC/MEn2vl9zE0CZroe/ZjYLvqiFkT8HJMg9"
|
| 30 |
+
"dDwVP5WEzL73d70+Rjynv6A91D/+//G9KSAuvxWYKj934+u++8iqvz1hrL8emzE/OGcjvhDp"
|
| 31 |
+
"CL6C84k/1DuQv7INO7/3DsW+aDvBPfm7LL3h4pK+92t8vd7C273+Nzi/TyBQv2iNjD4DE2S/"
|
| 32 |
+
"OCSUv8Xkn752cyG+KG4QQD1nNL+JeXE/vEc/P1kvmL/o80U/Z4mXv+EvKsDCNxs/BsHgv+Lg"
|
| 33 |
+
"5j5XGy+/KWzUP+rEiD8vIui+IBYwv+Nmm7+cwOG+wYqPvh25ur76diA+/BkUP+kFsz7wnkO/"
|
| 34 |
+
"jQm4v/uorj++fzC/t/wmv6psBb+06eu/Arn0vnKV9b7Nzx4/Fs4yPwUhdzudjW4/5A+uPsZ3"
|
| 35 |
+
"gLxbyiQ+qzpDvrcpyr5fFIm+rWKQvw6Wjz5ZPX6/JnVXPxJyf75Au0o9Ldj8PkKwJD8wCsm/"
|
| 36 |
+
"j95TvmhTYT+IW9m/oEnGPipbEMCB4YK/EjsePT4P1L9vSny/F2W8vxb20j9SKyg+8DkRP/EE"
|
| 37 |
+
"ZL4C9bS+oOjOv7Vrlb4n8UK/56BbP6APkj/auLs/2EBaP2JBGb+21Y6/CkREPwNstj54X+K/"
|
| 38 |
+
"tgG2Pl+EUD/1W3E9tn49vg3CTr8NKLm/VOBMP0BEnr7iEW++z8ndP3c7Lz/G3L0+pngRPjGP"
|
| 39 |
+
"wj+BG9w/DPRtP6wMFT/6DQbA6mH9PcI6Bb6NasA9eGtxP99WL8BvvhG/5zCKPmEG774uXbW/"
|
| 40 |
+
"ZHRePyjCjT5Pmni/uC+hPnFTUj/vba070fFMP99GoD10W8q+5GeUv3j8r7269EY+kzZgP3e9"
|
| 41 |
+
"671hMuo+0PB2v2JaSL/JE+K9Ef6Gv8P7UT9rH+0+pOWOPteErT7HWAFA+A7wvmrkDMBaFUw+"
|
| 42 |
+
"qUVPvSF8BL+YlHq/c93gvsiwOT6YuAC/pGUaQJ3jdb+9CUu/wHgSwJHCgD7ODAHAsxkKv7Ak"
|
| 43 |
+
"jb67sDW/YZPeP6GQfj962ag/M+Zhv8V1kD/W8/0+3HpFP6fEgz+1pGi/KUDZvhjTXD+q9SnA"
|
| 44 |
+
"vLTBPxCaDT8UNDu9wsxhPuvUg7/HK7O+HtaMP5Ylpj/vjixA0WWXve2WKL/WpAO/Mk+Cv1By"
|
| 45 |
+
"n72B9cM+okEMvRhVjD9E1m++DeWxvgPOFL8r+tC/nazIv6bulr8ylaY/xy9lP9P+rz/phaq/"
|
| 46 |
+
"5fv7v3P5KL/iCTQ+VVT/PvQjhj8bjZE+xQ/fP77yY76Pv2m/KTLXv6CTY7/F7Xc+LINjv1vO"
|
| 47 |
+
"bz8nx7Q/UKcXwIgyXT+sVQ/ASZHNPo/InD9H04Q928yjv9LeFb9k9oW+YJ42vjDET745CuG9"
|
| 48 |
+
"hJpaPouymr8M1He+YlbCPz/wxL58PuO+XwKKP67JI8BqN5c/csQhv+TcJz4iRMU9l0VxPy4C"
|
| 49 |
+
"ib4Zky2/0B+mP6BOF8Dfk6Y80oisv3n2Qr9uuABAsqk2vVrARz5ACuS/rKI6v1hGST7NorU+"
|
| 50 |
+
"R+wdPwhcDTy/6QY/GlboPu806r9Qkxc9QZVEP10CFz+0S7q+ij1Ov9gkj78GMwa+wwiRP7jU"
|
| 51 |
+
"+b+q7ii/DOWRv/rySD885w2/a/fwvgcoXr6WCuQ+NufIvgL0QsB9Fgs/PcrgPl3PYL62wYq/"
|
| 52 |
+
"hhy0Pikrwj4mqPC+2+5dvr0ebr8P4Da+eHTGv9Cq1T4iwnG/UNFzPpj2s78FDhe/RUjivdCR"
|
| 53 |
+
"1L+m0us9oR/CvocF37+p0Ka/JukaPyhDZT8PEwe+8TzPPj83ZT5YxKg+IJukP1PlwL+ILC0/"
|
| 54 |
+
"rpbDviKkZb56wJq+SBPAvv/znL9FvTs+duHVP73rZb1TirW61PIvv3+W8L1ere4+a5C9vgFZ"
|
| 55 |
+
"6L6xeM4+XAJrvz1HgT6cAFI/yxKuPzQaub1tDa8/i2eEP8sHf79p5Zu/MiScvim0gz82C5S9"
|
| 56 |
+
"ssQZv+ivxj8l5ZI+noQUwOFioj5iIQU/9QVnPqpA5j7Lx4m9MsGov+rMvb7gE3K/HMhuvzms"
|
| 57 |
+
"ob+mrOc+xn3IPe515b4DOya/0OG/vA4jij8SRQDA9vXAPsizC78cOvG/zAz5vy6sab8dx2A+"
|
| 58 |
+
"iz/JPhlhcL++LYI/UyS2P9zLyj4qZhe/+OyPP51hQT9qDl4/AQ4ov1dpNcCBeQdAQzHOv4uB"
|
| 59 |
+
"Er0iXhhAW0GpPtEBcz+ITsC/l4rjvzZfCL+wnYs/nEexvkltS7/wt0o+2nyKP83zuL8T85q/"
|
| 60 |
+
"OuZJvxwdjD8OdXA+NHUIQOi6bz/2vw+9Eu6hP6ySWD66dTS/1RIuP3dCMr/urpS+yfSpP6ts"
|
| 61 |
+
"z72tmk2/q73tvgnKgj9Ocw2/8BPGvoyiAr/3Vjw+6l7FvvcIzb9KHmO/Q8tuvxclnz9oC1A/"
|
| 62 |
+
"oVYWPypfAb+311C/rOwBvwKkhr8i0h9AWrMPwN1iED82bKS/CrLVvbLtfL+MvJa/9PGRv2Oj"
|
| 63 |
+
"4D8eLgi+DwVEvw5IDj8skCk8IlQ4Pz6B6b/5cZs+VM9FP0Gv1L/aeeU+ehzZP7ptc7ypR1I/"
|
| 64 |
+
"gaorPxgfNb9y4iI9SJPIvzER575BCIg+HR05P16fyTzbUDg/CCyNv6lG0L3N7508"
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
def test_noise():
|
| 68 |
+
raw = base64.b64decode(_NOISE_B64)
|
| 69 |
+
arr = np.frombuffer(raw, dtype="<f4")
|
| 70 |
+
assert arr.size == 768, f"Expected 768 got {arr.size}"
|
| 71 |
+
# Cross-check with numpy RandomState(0)
|
| 72 |
+
ref = np.random.RandomState(0).randn(768).astype(np.float32)
|
| 73 |
+
max_diff = np.abs(arr - ref).max()
|
| 74 |
+
print(f"Noise decode: size={arr.size}, max_diff_vs_ref={max_diff:.6f}")
|
| 75 |
+
# Small diff expected: b64 encodes exact f32 bytes; numpy default is f64->f32 rounding
|
| 76 |
+
assert max_diff < 0.01, f"Noise too far from ref! max diff = {max_diff}"
|
| 77 |
+
print(f" PASS: size 768, within tolerance of np.random.RandomState(0).randn(768) (diff={max_diff:.6f} ok)")
|
| 78 |
+
|
| 79 |
+
def test_cond():
|
| 80 |
+
COND_OFFSET = 7
|
| 81 |
+
style = list(range(12))
|
| 82 |
+
notes = {60, 64, 67}
|
| 83 |
+
cond = [0] * 144
|
| 84 |
+
k = 0
|
| 85 |
+
for i in range(12):
|
| 86 |
+
cond[k] = style[i] + COND_OFFSET; k += 1
|
| 87 |
+
for i in range(128):
|
| 88 |
+
cond[k] = (3 if i in notes else -1) + COND_OFFSET; k += 1
|
| 89 |
+
cond[k] = -1 + COND_OFFSET; k += 1 # drum masked
|
| 90 |
+
# CFG tokens
|
| 91 |
+
def disc(v, step, mb):
|
| 92 |
+
c = max(-1.0, min(7.0, v))
|
| 93 |
+
return max(0, min(mb, round((c + 1.0) / step)))
|
| 94 |
+
cond[k] = disc(1.6, 0.2, 40) + COND_OFFSET; k += 1
|
| 95 |
+
cond[k] = disc(2.4, 0.2, 40) + COND_OFFSET; k += 1
|
| 96 |
+
cond[k] = disc(4.0, 1.0, 8) + COND_OFFSET
|
| 97 |
+
assert len(cond) == 144
|
| 98 |
+
print(f"Cond test: style[0]={cond[0]} (expect {0+7}=7), note C4 at idx {12+60}: {cond[12+60]} (expect {3+7}=10)")
|
| 99 |
+
assert cond[0] == 7
|
| 100 |
+
assert cond[12+60] == 10 # C4 held = NOTE_ON(3) + offset(7)
|
| 101 |
+
assert cond[12+61] == 6 # C#4 not held = MASKED(-1) + offset(7)
|
| 102 |
+
print(" PASS: conditioning vector correct")
|
| 103 |
+
|
| 104 |
+
def test_codec():
|
| 105 |
+
NUM_RESERVED = 6
|
| 106 |
+
CODEBOOK = 1024
|
| 107 |
+
unique = [6, 7, 1029, NUM_RESERVED + 11*CODEBOOK + 500]
|
| 108 |
+
codec = [((t - NUM_RESERVED) % CODEBOOK + CODEBOOK) % CODEBOOK for t in unique]
|
| 109 |
+
# unique[0]=6 -> (6-6)%1024=0; unique[1]=7 -> 1; unique[2]=1029 -> 1023; unique[3]=11770 -> 500
|
| 110 |
+
expected = [0, 1, 1023, 500]
|
| 111 |
+
for u, c, e in zip(unique, codec, expected):
|
| 112 |
+
assert c == e, f"unique={u} -> codec={c}, expected {e}"
|
| 113 |
+
print(f"Codec convert test: {unique} -> {codec} PASS")
|
| 114 |
+
|
| 115 |
+
def test_topk():
|
| 116 |
+
logits = np.random.randn(1, 12294).astype(np.float32)
|
| 117 |
+
NUM_RESERVED = 6
|
| 118 |
+
CODEBOOK = 1024
|
| 119 |
+
lo = NUM_RESERVED
|
| 120 |
+
hi = lo + CODEBOOK
|
| 121 |
+
# Mask to codebook slice
|
| 122 |
+
sliced = logits[0, lo:hi].copy()
|
| 123 |
+
top_k = 20
|
| 124 |
+
threshold = np.partition(sliced, -top_k)[-top_k]
|
| 125 |
+
sliced[sliced < threshold] = -1e9
|
| 126 |
+
sliced /= 0.9
|
| 127 |
+
sliced -= sliced.max()
|
| 128 |
+
probs = np.exp(sliced)
|
| 129 |
+
probs /= probs.sum()
|
| 130 |
+
token = lo + int(np.random.choice(len(probs), p=probs))
|
| 131 |
+
assert lo <= token < hi, f"Token {token} outside [{lo},{hi})"
|
| 132 |
+
print(f"TopK sampling: token={token} in [{lo},{hi}) PASS")
|
| 133 |
+
|
| 134 |
+
if __name__ == "__main__":
|
| 135 |
+
print("=== Logic Tests ===")
|
| 136 |
+
test_noise()
|
| 137 |
+
test_cond()
|
| 138 |
+
test_codec()
|
| 139 |
+
test_topk()
|
| 140 |
+
print("\nAll tests passed.")
|