Nekochu commited on
Commit
e6d18b9
·
1 Parent(s): 19d857a

add logic test script

Browse files
Files changed (1) hide show
  1. test_logic.py +140 -0
test_logic.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test conditioning logic, noise decode, sampling - no model download needed."""
2
+ import base64
3
+ import sys
4
+ import numpy as np
5
+
6
+ # --- noise decode test ---
7
+ _NOISE_B64 = (
8
+ "eMzhP2jhzD6Tjno/y2oPQCQM7z/iLnq//zhzP2L9Gr5oZNO9+DnSPiiAEz6iJbo/XtNCP8Aw"
9
+ "+T0LQuM+XdeqPvw9vz8CFVK+aUqgPgWmWr8vZCPAjFMnP7FLXT+H/j2/qUMRQKgour9IbTs9"
10
+ "IK0/vhwyxD/zE7w/iqoePoWewT7tRWO/vYr9v4shsr7yGSA+KnqdP5XnmT+zT8a+bceavvw2"
11
+ "hr8mw7W/EGfavwKz+T+ReAK/RkvgvplboL+cCUc/NJTOv5fYWb5MPWW/FhjGPiDEAr/1Hpe/"
12
+ "a97mvFFO2z4uOog9md2aPu9iIr82ubm+XiYsv1oXuL5bKlC/1PbcvzOvNT47ts2+V6rQv8zx"
13
+ "7D61RGi/ssRUPa6lOj8ZFAQ+4teRP8YOnr+5/80+t08vv5DsXr9+LxS/0IOfvqENZj2hI5W/"
14
+ "kZxmP09r7j6io8S/DH++P3+s8j9A4pY/Nz44vmwOib9G+IY/NW3OvhR5nD8JRlU+BAV6P6h1"
15
+ "tj774TQ/RwgsPGiX5D8+9QE+jdHNPhUL8T9eg6y/QZ+iv2IqeD/oKJa/lMj4P97F074zWT+/"
16
+ "9yL2P4KBvT8sDO8/i/JnP0l5XL8CffQ/vTeJvshtTT8bf3I/97oevk40HT+9FWw/2brAPiq5"
17
+ "jL+tspg+A8epPzPPMb/MORm+cszevq207D+CGyw/1p7QPjgZRb88DAo/EaEsv8JgAj3PxiK/"
18
+ "uyotP3WbEz9FTFW+ZMHKPnHpi7+H4b6/8/fgPnWsKj5skSI/coUYQGjJcT+4rmm/ZPqOP6dv"
19
+ "qL/RVOy+QcKLvdBO2z9BqD6/epFTv3qhyb232Sm/mzWQPzI7ir9B4JK/8yngvhz+/r7o+vY/"
20
+ "Pg1zPxFOsz0S25y/LChYPw4HgL8Pu8W/XBGYP01Goj5nvWs/RTCjPkBZWz+dqSa/EmKEv/p8"
21
+ "Lj9BrE2/VoYwv4476b50MI88sT61vmf+r78txCS/PUwOwCsPID86EM2/b1yNvw2rVT0AVD2/"
22
+ "gYHFP1Z8pb/kuog+BecgvRaElb929QU/16kvvhGURT8r0VI/dXIKQFkTqz9nBb2+0R91vqXB"
23
+ "jD9dvyc/qd8jP2r4zr+VR8e8mO88v0dSjz4SA8m9fAFpP21qoj7KTEk/fM7uvjvHcb8J8tG+"
24
+ "ZW6LvC0gwj6FmBBA1hUtvdC4dL+GJLG+dFztvr2E9j7WOMW/gY+BPUBDID7ewG0+tekYv8Gh"
25
+ "c76hR7a/bJT8vvj4Cr+DBNU+yf2Tv5n8Rz9FS78/onoEwJY+2j7YSS0/Ey8jvzZny77ZEQi+"
26
+ "DHiYvvM2nr5Lh9a/mn+TP/Ewij+kOFC/y7O7v4JkBT/XZhO/LFwRPgR/o76vCDE/FNsxP8DA"
27
+ "Ob8SErG/up3Kv9NBHD+KLJi/t74BvwmoGL/OUFe9BNj3vy1PQT65HQY/pBa1PXksn769ecc9"
28
+ "zU/MPilyMcBWW/o/ULrHPkAEJ78KK8i+ucv8PufH7b289gHApyAEQCRj4r0FlYI/LioxvwGo"
29
+ "xD+km5I+Md0bP93Khb/PBps/7JcwP+aipj9ZyiC/MEn2vl9zE0CZroe/ZjYLvqiFkT8HJMg9"
30
+ "dDwVP5WEzL73d70+Rjynv6A91D/+//G9KSAuvxWYKj934+u++8iqvz1hrL8emzE/OGcjvhDp"
31
+ "CL6C84k/1DuQv7INO7/3DsW+aDvBPfm7LL3h4pK+92t8vd7C273+Nzi/TyBQv2iNjD4DE2S/"
32
+ "OCSUv8Xkn752cyG+KG4QQD1nNL+JeXE/vEc/P1kvmL/o80U/Z4mXv+EvKsDCNxs/BsHgv+Lg"
33
+ "5j5XGy+/KWzUP+rEiD8vIui+IBYwv+Nmm7+cwOG+wYqPvh25ur76diA+/BkUP+kFsz7wnkO/"
34
+ "jQm4v/uorj++fzC/t/wmv6psBb+06eu/Arn0vnKV9b7Nzx4/Fs4yPwUhdzudjW4/5A+uPsZ3"
35
+ "gLxbyiQ+qzpDvrcpyr5fFIm+rWKQvw6Wjz5ZPX6/JnVXPxJyf75Au0o9Ldj8PkKwJD8wCsm/"
36
+ "j95TvmhTYT+IW9m/oEnGPipbEMCB4YK/EjsePT4P1L9vSny/F2W8vxb20j9SKyg+8DkRP/EE"
37
+ "ZL4C9bS+oOjOv7Vrlb4n8UK/56BbP6APkj/auLs/2EBaP2JBGb+21Y6/CkREPwNstj54X+K/"
38
+ "tgG2Pl+EUD/1W3E9tn49vg3CTr8NKLm/VOBMP0BEnr7iEW++z8ndP3c7Lz/G3L0+pngRPjGP"
39
+ "wj+BG9w/DPRtP6wMFT/6DQbA6mH9PcI6Bb6NasA9eGtxP99WL8BvvhG/5zCKPmEG774uXbW/"
40
+ "ZHRePyjCjT5Pmni/uC+hPnFTUj/vba070fFMP99GoD10W8q+5GeUv3j8r7269EY+kzZgP3e9"
41
+ "671hMuo+0PB2v2JaSL/JE+K9Ef6Gv8P7UT9rH+0+pOWOPteErT7HWAFA+A7wvmrkDMBaFUw+"
42
+ "qUVPvSF8BL+YlHq/c93gvsiwOT6YuAC/pGUaQJ3jdb+9CUu/wHgSwJHCgD7ODAHAsxkKv7Ak"
43
+ "jb67sDW/YZPeP6GQfj962ag/M+Zhv8V1kD/W8/0+3HpFP6fEgz+1pGi/KUDZvhjTXD+q9SnA"
44
+ "vLTBPxCaDT8UNDu9wsxhPuvUg7/HK7O+HtaMP5Ylpj/vjixA0WWXve2WKL/WpAO/Mk+Cv1By"
45
+ "n72B9cM+okEMvRhVjD9E1m++DeWxvgPOFL8r+tC/nazIv6bulr8ylaY/xy9lP9P+rz/phaq/"
46
+ "5fv7v3P5KL/iCTQ+VVT/PvQjhj8bjZE+xQ/fP77yY76Pv2m/KTLXv6CTY7/F7Xc+LINjv1vO"
47
+ "bz8nx7Q/UKcXwIgyXT+sVQ/ASZHNPo/InD9H04Q928yjv9LeFb9k9oW+YJ42vjDET745CuG9"
48
+ "hJpaPouymr8M1He+YlbCPz/wxL58PuO+XwKKP67JI8BqN5c/csQhv+TcJz4iRMU9l0VxPy4C"
49
+ "ib4Zky2/0B+mP6BOF8Dfk6Y80oisv3n2Qr9uuABAsqk2vVrARz5ACuS/rKI6v1hGST7NorU+"
50
+ "R+wdPwhcDTy/6QY/GlboPu806r9Qkxc9QZVEP10CFz+0S7q+ij1Ov9gkj78GMwa+wwiRP7jU"
51
+ "+b+q7ii/DOWRv/rySD885w2/a/fwvgcoXr6WCuQ+NufIvgL0QsB9Fgs/PcrgPl3PYL62wYq/"
52
+ "hhy0Pikrwj4mqPC+2+5dvr0ebr8P4Da+eHTGv9Cq1T4iwnG/UNFzPpj2s78FDhe/RUjivdCR"
53
+ "1L+m0us9oR/CvocF37+p0Ka/JukaPyhDZT8PEwe+8TzPPj83ZT5YxKg+IJukP1PlwL+ILC0/"
54
+ "rpbDviKkZb56wJq+SBPAvv/znL9FvTs+duHVP73rZb1TirW61PIvv3+W8L1ere4+a5C9vgFZ"
55
+ "6L6xeM4+XAJrvz1HgT6cAFI/yxKuPzQaub1tDa8/i2eEP8sHf79p5Zu/MiScvim0gz82C5S9"
56
+ "ssQZv+ivxj8l5ZI+noQUwOFioj5iIQU/9QVnPqpA5j7Lx4m9MsGov+rMvb7gE3K/HMhuvzms"
57
+ "ob+mrOc+xn3IPe515b4DOya/0OG/vA4jij8SRQDA9vXAPsizC78cOvG/zAz5vy6sab8dx2A+"
58
+ "iz/JPhlhcL++LYI/UyS2P9zLyj4qZhe/+OyPP51hQT9qDl4/AQ4ov1dpNcCBeQdAQzHOv4uB"
59
+ "Er0iXhhAW0GpPtEBcz+ITsC/l4rjvzZfCL+wnYs/nEexvkltS7/wt0o+2nyKP83zuL8T85q/"
60
+ "OuZJvxwdjD8OdXA+NHUIQOi6bz/2vw+9Eu6hP6ySWD66dTS/1RIuP3dCMr/urpS+yfSpP6ts"
61
+ "z72tmk2/q73tvgnKgj9Ocw2/8BPGvoyiAr/3Vjw+6l7FvvcIzb9KHmO/Q8tuvxclnz9oC1A/"
62
+ "oVYWPypfAb+311C/rOwBvwKkhr8i0h9AWrMPwN1iED82bKS/CrLVvbLtfL+MvJa/9PGRv2Oj"
63
+ "4D8eLgi+DwVEvw5IDj8skCk8IlQ4Pz6B6b/5cZs+VM9FP0Gv1L/aeeU+ehzZP7ptc7ypR1I/"
64
+ "gaorPxgfNb9y4iI9SJPIvzER575BCIg+HR05P16fyTzbUDg/CCyNv6lG0L3N7508"
65
+ )
66
+
67
+ def test_noise():
68
+ raw = base64.b64decode(_NOISE_B64)
69
+ arr = np.frombuffer(raw, dtype="<f4")
70
+ assert arr.size == 768, f"Expected 768 got {arr.size}"
71
+ # Cross-check with numpy RandomState(0)
72
+ ref = np.random.RandomState(0).randn(768).astype(np.float32)
73
+ max_diff = np.abs(arr - ref).max()
74
+ print(f"Noise decode: size={arr.size}, max_diff_vs_ref={max_diff:.6f}")
75
+ # Small diff expected: b64 encodes exact f32 bytes; numpy default is f64->f32 rounding
76
+ assert max_diff < 0.01, f"Noise too far from ref! max diff = {max_diff}"
77
+ print(f" PASS: size 768, within tolerance of np.random.RandomState(0).randn(768) (diff={max_diff:.6f} ok)")
78
+
79
+ def test_cond():
80
+ COND_OFFSET = 7
81
+ style = list(range(12))
82
+ notes = {60, 64, 67}
83
+ cond = [0] * 144
84
+ k = 0
85
+ for i in range(12):
86
+ cond[k] = style[i] + COND_OFFSET; k += 1
87
+ for i in range(128):
88
+ cond[k] = (3 if i in notes else -1) + COND_OFFSET; k += 1
89
+ cond[k] = -1 + COND_OFFSET; k += 1 # drum masked
90
+ # CFG tokens
91
+ def disc(v, step, mb):
92
+ c = max(-1.0, min(7.0, v))
93
+ return max(0, min(mb, round((c + 1.0) / step)))
94
+ cond[k] = disc(1.6, 0.2, 40) + COND_OFFSET; k += 1
95
+ cond[k] = disc(2.4, 0.2, 40) + COND_OFFSET; k += 1
96
+ cond[k] = disc(4.0, 1.0, 8) + COND_OFFSET
97
+ assert len(cond) == 144
98
+ print(f"Cond test: style[0]={cond[0]} (expect {0+7}=7), note C4 at idx {12+60}: {cond[12+60]} (expect {3+7}=10)")
99
+ assert cond[0] == 7
100
+ assert cond[12+60] == 10 # C4 held = NOTE_ON(3) + offset(7)
101
+ assert cond[12+61] == 6 # C#4 not held = MASKED(-1) + offset(7)
102
+ print(" PASS: conditioning vector correct")
103
+
104
+ def test_codec():
105
+ NUM_RESERVED = 6
106
+ CODEBOOK = 1024
107
+ unique = [6, 7, 1029, NUM_RESERVED + 11*CODEBOOK + 500]
108
+ codec = [((t - NUM_RESERVED) % CODEBOOK + CODEBOOK) % CODEBOOK for t in unique]
109
+ # unique[0]=6 -> (6-6)%1024=0; unique[1]=7 -> 1; unique[2]=1029 -> 1023; unique[3]=11770 -> 500
110
+ expected = [0, 1, 1023, 500]
111
+ for u, c, e in zip(unique, codec, expected):
112
+ assert c == e, f"unique={u} -> codec={c}, expected {e}"
113
+ print(f"Codec convert test: {unique} -> {codec} PASS")
114
+
115
+ def test_topk():
116
+ logits = np.random.randn(1, 12294).astype(np.float32)
117
+ NUM_RESERVED = 6
118
+ CODEBOOK = 1024
119
+ lo = NUM_RESERVED
120
+ hi = lo + CODEBOOK
121
+ # Mask to codebook slice
122
+ sliced = logits[0, lo:hi].copy()
123
+ top_k = 20
124
+ threshold = np.partition(sliced, -top_k)[-top_k]
125
+ sliced[sliced < threshold] = -1e9
126
+ sliced /= 0.9
127
+ sliced -= sliced.max()
128
+ probs = np.exp(sliced)
129
+ probs /= probs.sum()
130
+ token = lo + int(np.random.choice(len(probs), p=probs))
131
+ assert lo <= token < hi, f"Token {token} outside [{lo},{hi})"
132
+ print(f"TopK sampling: token={token} in [{lo},{hi}) PASS")
133
+
134
+ if __name__ == "__main__":
135
+ print("=== Logic Tests ===")
136
+ test_noise()
137
+ test_cond()
138
+ test_codec()
139
+ test_topk()
140
+ print("\nAll tests passed.")