Upload 2 files
Browse files- TMIDIX.py +386 -1
- x_transformer_2_3_1.py +1204 -2
TMIDIX.py
CHANGED
|
@@ -51,7 +51,7 @@ r'''############################################################################
|
|
| 51 |
|
| 52 |
###################################################################################
|
| 53 |
|
| 54 |
-
__version__ = "26.2.
|
| 55 |
|
| 56 |
print('=' * 70)
|
| 57 |
print('TMIDIX Python module')
|
|
@@ -16094,6 +16094,391 @@ def squash_monophonic_escore_notes_pitches(escore_notes,
|
|
| 16094 |
|
| 16095 |
###################################################################################
|
| 16096 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16097 |
print('Module loaded!')
|
| 16098 |
print('=' * 70)
|
| 16099 |
print('Enjoy! :)')
|
|
|
|
| 51 |
|
| 52 |
###################################################################################
|
| 53 |
|
| 54 |
+
__version__ = "26.2.17"
|
| 55 |
|
| 56 |
print('=' * 70)
|
| 57 |
print('TMIDIX Python module')
|
|
|
|
| 16094 |
|
| 16095 |
###################################################################################
|
| 16096 |
|
| 16097 |
+
def humanize_velocities_in_escore_notes(escore_notes):
|
| 16098 |
+
|
| 16099 |
+
if not escore_notes:
|
| 16100 |
+
return []
|
| 16101 |
+
|
| 16102 |
+
notes = [list(n) for n in escore_notes]
|
| 16103 |
+
notes.sort(key=lambda x: (x[1], x[4])) # Sort by time, then pitch
|
| 16104 |
+
|
| 16105 |
+
# -----------------------------------------------------------
|
| 16106 |
+
# 1. GLOBAL ANALYSIS (Grid, Phrases, Arcs)
|
| 16107 |
+
# -----------------------------------------------------------
|
| 16108 |
+
|
| 16109 |
+
onset_times = sorted(list(set([n[1] for n in notes])))
|
| 16110 |
+
|
| 16111 |
+
# --- Grid Estimation ---
|
| 16112 |
+
estimated_grid = 120
|
| 16113 |
+
if len(onset_times) > 1:
|
| 16114 |
+
intervals = [onset_times[i+1] - onset_times[i] for i in range(len(onset_times)-1)]
|
| 16115 |
+
active_intervals = [i for i in intervals if i > 0 and i < 960]
|
| 16116 |
+
if active_intervals:
|
| 16117 |
+
rounded_intervals = [round(i / 10) * 10 for i in active_intervals]
|
| 16118 |
+
valid_intervals = [r for r in rounded_intervals if r > 0]
|
| 16119 |
+
if valid_intervals:
|
| 16120 |
+
estimated_grid = max(set(valid_intervals), key=valid_intervals.count)
|
| 16121 |
+
|
| 16122 |
+
# --- Beat Grid Calculation ---
|
| 16123 |
+
if estimated_grid <= 120:
|
| 16124 |
+
beat_grid = estimated_grid * 4
|
| 16125 |
+
elif estimated_grid <= 240:
|
| 16126 |
+
beat_grid = estimated_grid * 2
|
| 16127 |
+
else:
|
| 16128 |
+
beat_grid = estimated_grid
|
| 16129 |
+
|
| 16130 |
+
# --- Phrase Detection ---
|
| 16131 |
+
melodic_onsets = sorted(list(set([n[1] for n in notes if n[3] != 9])))
|
| 16132 |
+
phrase_gap_threshold = beat_grid * 3
|
| 16133 |
+
|
| 16134 |
+
phrases = []
|
| 16135 |
+
current_phrase = []
|
| 16136 |
+
|
| 16137 |
+
if not melodic_onsets: melodic_onsets = onset_times
|
| 16138 |
+
|
| 16139 |
+
for i, t in enumerate(melodic_onsets):
|
| 16140 |
+
if not current_phrase:
|
| 16141 |
+
current_phrase.append(t)
|
| 16142 |
+
else:
|
| 16143 |
+
prev_t = melodic_onsets[i-1]
|
| 16144 |
+
if t - prev_t > phrase_gap_threshold:
|
| 16145 |
+
phrases.append(current_phrase)
|
| 16146 |
+
current_phrase = [t]
|
| 16147 |
+
else:
|
| 16148 |
+
current_phrase.append(t)
|
| 16149 |
+
if current_phrase:
|
| 16150 |
+
phrases.append(current_phrase)
|
| 16151 |
+
|
| 16152 |
+
# --- Global Arc Calculation ---
|
| 16153 |
+
total_duration = onset_times[-1] - onset_times[0] if onset_times else 0
|
| 16154 |
+
global_progress_map = {}
|
| 16155 |
+
for t in onset_times:
|
| 16156 |
+
progress = (t - onset_times[0]) / total_duration if total_duration > 0 else 0
|
| 16157 |
+
global_arc = math.cos((progress - 0.5) * math.pi) * 6 # Range -6 to +6
|
| 16158 |
+
global_progress_map[t] = global_arc
|
| 16159 |
+
|
| 16160 |
+
# -----------------------------------------------------------
|
| 16161 |
+
# 2. INSTRUMENT SEPARATION & PROCESSING
|
| 16162 |
+
# -----------------------------------------------------------
|
| 16163 |
+
|
| 16164 |
+
instrument_tracks = defaultdict(list)
|
| 16165 |
+
for n in notes:
|
| 16166 |
+
instrument_tracks[(n[3], n[6])].append(n)
|
| 16167 |
+
|
| 16168 |
+
# Main Loop
|
| 16169 |
+
for key, track_notes in instrument_tracks.items():
|
| 16170 |
+
channel, inst_num = key
|
| 16171 |
+
|
| 16172 |
+
# --- Instrument Classification ---
|
| 16173 |
+
is_drum = (channel == 9) or (inst_num == 128)
|
| 16174 |
+
is_bass = (32 <= inst_num <= 39)
|
| 16175 |
+
is_keys = (inst_num <= 7) or (16 <= inst_num <= 23)
|
| 16176 |
+
is_plucked = (24 <= inst_num <= 31)
|
| 16177 |
+
is_solo_string = (40 <= inst_num <= 47)
|
| 16178 |
+
is_ensemble = (48 <= inst_num <= 55)
|
| 16179 |
+
is_synth_lead = (80 <= inst_num <= 87)
|
| 16180 |
+
is_synth_pad = (88 <= inst_num <= 95)
|
| 16181 |
+
|
| 16182 |
+
track_notes.sort(key=lambda x: (x[1], x[4]))
|
| 16183 |
+
|
| 16184 |
+
# ==========================================================
|
| 16185 |
+
# LOGIC A: DRUMS (UNIFORM & SEPARATED)
|
| 16186 |
+
# ==========================================================
|
| 16187 |
+
if is_drum:
|
| 16188 |
+
for n in track_notes:
|
| 16189 |
+
t, pitch = n[1], n[4]
|
| 16190 |
+
|
| 16191 |
+
is_kick = pitch in [35, 36]
|
| 16192 |
+
is_snare = pitch in [38, 40]
|
| 16193 |
+
is_hat = pitch in [42, 44, 46]
|
| 16194 |
+
is_tom = 41 <= pitch <= 50 and not is_hat and not is_snare
|
| 16195 |
+
|
| 16196 |
+
if is_kick:
|
| 16197 |
+
target_vel = 115
|
| 16198 |
+
elif is_snare:
|
| 16199 |
+
target_vel = 112
|
| 16200 |
+
elif is_hat:
|
| 16201 |
+
target_vel = 90
|
| 16202 |
+
elif is_tom:
|
| 16203 |
+
target_vel = 105
|
| 16204 |
+
else:
|
| 16205 |
+
target_vel = 100
|
| 16206 |
+
|
| 16207 |
+
position_in_beat = t % beat_grid
|
| 16208 |
+
if position_in_beat == 0:
|
| 16209 |
+
metric_mod = 0
|
| 16210 |
+
elif position_in_beat == (beat_grid // 2):
|
| 16211 |
+
metric_mod = -2
|
| 16212 |
+
else:
|
| 16213 |
+
metric_mod = -6
|
| 16214 |
+
|
| 16215 |
+
closest_t = min(onset_times, key=lambda x: abs(x - t))
|
| 16216 |
+
global_mod = global_progress_map.get(closest_t, 0)
|
| 16217 |
+
|
| 16218 |
+
final_vel = target_vel + metric_mod + global_mod + random.gauss(0, 2.0)
|
| 16219 |
+
n[5] = max(30, min(127, int(final_vel)))
|
| 16220 |
+
|
| 16221 |
+
continue
|
| 16222 |
+
|
| 16223 |
+
# ==========================================================
|
| 16224 |
+
# LOGIC B: BASS
|
| 16225 |
+
# ==========================================================
|
| 16226 |
+
elif is_bass:
|
| 16227 |
+
last_vel = 100
|
| 16228 |
+
for n in track_notes:
|
| 16229 |
+
t = n[1]
|
| 16230 |
+
|
| 16231 |
+
position_in_grid = t % beat_grid
|
| 16232 |
+
metric_strength = 4 if position_in_grid == 0 else (1 if position_in_grid == (beat_grid // 2) else -2)
|
| 16233 |
+
|
| 16234 |
+
target = 100 + metric_strength
|
| 16235 |
+
|
| 16236 |
+
smoothed = (target * 0.3) + (last_vel * 0.7)
|
| 16237 |
+
last_vel = smoothed
|
| 16238 |
+
|
| 16239 |
+
final_vel = smoothed + random.gauss(0, 2)
|
| 16240 |
+
n[5] = max(30, min(115, int(final_vel)))
|
| 16241 |
+
continue
|
| 16242 |
+
|
| 16243 |
+
# ==========================================================
|
| 16244 |
+
# LOGIC C: MELODIC INSTRUMENTS
|
| 16245 |
+
# ==========================================================
|
| 16246 |
+
else:
|
| 16247 |
+
# --- Role Detection (Base Velocity Target) ---
|
| 16248 |
+
# Determine the 'floor' velocity for this track
|
| 16249 |
+
|
| 16250 |
+
base_vel_target = 90 # Default
|
| 16251 |
+
|
| 16252 |
+
if is_solo_string or is_synth_lead:
|
| 16253 |
+
base_vel_target = 105 # Lead
|
| 16254 |
+
elif is_ensemble or is_synth_pad:
|
| 16255 |
+
base_vel_target = 80 # Pad/Backing
|
| 16256 |
+
elif is_plucked:
|
| 16257 |
+
# Check density for Guitar Solo vs Rhythm
|
| 16258 |
+
unique_onsets = set(nn[1] for nn in track_notes)
|
| 16259 |
+
avg_poly = len(track_notes) / len(unique_onsets) if unique_onsets else 1
|
| 16260 |
+
base_vel_target = 105 if avg_poly < 1.3 else 95
|
| 16261 |
+
elif is_keys:
|
| 16262 |
+
# Check if it looks like a Solo Piano piece or just a Piano track
|
| 16263 |
+
# If the whole score is basically this track -> Solo
|
| 16264 |
+
if len(track_notes) > (len(notes) * 0.8):
|
| 16265 |
+
base_vel_target = 100
|
| 16266 |
+
else:
|
| 16267 |
+
base_vel_target = 90
|
| 16268 |
+
|
| 16269 |
+
# --- Dispatch to specific logic ---
|
| 16270 |
+
|
| 16271 |
+
# CASE 1: PIANO / KEYS (RESTORED ORIGINAL LOGIC)
|
| 16272 |
+
if is_keys:
|
| 16273 |
+
last_vel_lh = base_vel_target
|
| 16274 |
+
last_vel_rh = base_vel_target
|
| 16275 |
+
prev_melody_pitch = None
|
| 16276 |
+
|
| 16277 |
+
time_slices = defaultdict(list)
|
| 16278 |
+
for n in track_notes:
|
| 16279 |
+
time_slices[n[1]].append(n)
|
| 16280 |
+
sorted_times = sorted(time_slices.keys())
|
| 16281 |
+
|
| 16282 |
+
for t_idx, t in enumerate(sorted_times):
|
| 16283 |
+
slice_notes = time_slices[t]
|
| 16284 |
+
slice_notes.sort(key=lambda x: x[4])
|
| 16285 |
+
|
| 16286 |
+
# Context Mods
|
| 16287 |
+
phrase_arc_mod = 0
|
| 16288 |
+
for p_times in phrases:
|
| 16289 |
+
if t in p_times:
|
| 16290 |
+
p_len = p_times[-1] - p_times[0]
|
| 16291 |
+
if p_len > 0:
|
| 16292 |
+
progress = (t - p_times[0]) / p_len
|
| 16293 |
+
phrase_arc_mod = math.cos((1.0 - progress) * math.pi) * 0.5 + 0.5
|
| 16294 |
+
phrase_arc_mod = (phrase_arc_mod - 0.5) * 10
|
| 16295 |
+
break
|
| 16296 |
+
|
| 16297 |
+
global_arc = global_progress_map.get(t, 0)
|
| 16298 |
+
|
| 16299 |
+
position_in_grid = t % beat_grid
|
| 16300 |
+
if position_in_grid == 0: metric_strength = 6
|
| 16301 |
+
elif position_in_grid == (beat_grid // 2): metric_strength = 2
|
| 16302 |
+
else: metric_strength = -3
|
| 16303 |
+
|
| 16304 |
+
# Hand Splitting Logic
|
| 16305 |
+
pitches = [n[4] for n in slice_notes]
|
| 16306 |
+
split_idx = 0
|
| 16307 |
+
max_gap = 0
|
| 16308 |
+
|
| 16309 |
+
if len(pitches) > 1:
|
| 16310 |
+
for i in range(len(pitches)-1):
|
| 16311 |
+
gap = pitches[i+1] - pitches[i]
|
| 16312 |
+
if gap > max_gap:
|
| 16313 |
+
max_gap = gap
|
| 16314 |
+
split_idx = i + 1
|
| 16315 |
+
|
| 16316 |
+
if max_gap < 12:
|
| 16317 |
+
avg_pitch = sum(pitches) / len(pitches)
|
| 16318 |
+
split_idx = 0 if avg_pitch >= 60 else len(pitches)
|
| 16319 |
+
|
| 16320 |
+
lh_notes = slice_notes[:split_idx]
|
| 16321 |
+
rh_notes = slice_notes[split_idx:]
|
| 16322 |
+
|
| 16323 |
+
# LH Processing
|
| 16324 |
+
if lh_notes:
|
| 16325 |
+
target_lh = base_vel_target + metric_strength + global_arc + phrase_arc_mod - 3
|
| 16326 |
+
smoothed_lh = (target_lh * 0.25) + (last_vel_lh * 0.75)
|
| 16327 |
+
last_vel_lh = smoothed_lh
|
| 16328 |
+
|
| 16329 |
+
for i, n in enumerate(lh_notes):
|
| 16330 |
+
offset = 0 if i == 0 else -5
|
| 16331 |
+
if n[2] < estimated_grid / 2: offset += 2
|
| 16332 |
+
vel = smoothed_lh + offset + random.gauss(0, 2)
|
| 16333 |
+
n[5] = max(30, min(115, int(vel)))
|
| 16334 |
+
|
| 16335 |
+
# RH Processing
|
| 16336 |
+
if rh_notes:
|
| 16337 |
+
target_rh = base_vel_target + metric_strength + global_arc + phrase_arc_mod + 3
|
| 16338 |
+
smoothed_rh = (target_rh * 0.35) + (last_vel_rh * 0.65)
|
| 16339 |
+
last_vel_rh = smoothed_rh
|
| 16340 |
+
|
| 16341 |
+
current_top_pitch = rh_notes[-1][4]
|
| 16342 |
+
num_rh = len(rh_notes)
|
| 16343 |
+
is_dense = num_rh >= 3
|
| 16344 |
+
|
| 16345 |
+
for i, n in enumerate(rh_notes):
|
| 16346 |
+
pitch = n[4]
|
| 16347 |
+
offset = 0
|
| 16348 |
+
|
| 16349 |
+
if i == num_rh - 1: # Top note melody
|
| 16350 |
+
offset = 12 if is_dense else 5
|
| 16351 |
+
if prev_melody_pitch is not None:
|
| 16352 |
+
diff = pitch - prev_melody_pitch
|
| 16353 |
+
if diff > 4: offset += 6
|
| 16354 |
+
elif diff > 0: offset += 3
|
| 16355 |
+
elif diff < -4: offset -= 4
|
| 16356 |
+
if pitch > 72: offset += (pitch - 72) * 0.2
|
| 16357 |
+
prev_melody_pitch = pitch
|
| 16358 |
+
elif i == 0:
|
| 16359 |
+
offset = -1
|
| 16360 |
+
else:
|
| 16361 |
+
offset = -8
|
| 16362 |
+
|
| 16363 |
+
if n[2] < estimated_grid / 2: offset += 2
|
| 16364 |
+
vel = smoothed_rh + offset + random.gauss(0, 2)
|
| 16365 |
+
n[5] = max(35, min(120, int(vel)))
|
| 16366 |
+
|
| 16367 |
+
# CASE 2: OTHER MELODIC (Generic Role-Based Logic)
|
| 16368 |
+
else:
|
| 16369 |
+
last_vel = base_vel_target
|
| 16370 |
+
prev_melody_pitch = None
|
| 16371 |
+
|
| 16372 |
+
time_slices = defaultdict(list)
|
| 16373 |
+
for n in track_notes:
|
| 16374 |
+
time_slices[n[1]].append(n)
|
| 16375 |
+
sorted_times = sorted(time_slices.keys())
|
| 16376 |
+
|
| 16377 |
+
for t_idx, t in enumerate(sorted_times):
|
| 16378 |
+
slice_notes = time_slices[t]
|
| 16379 |
+
slice_notes.sort(key=lambda x: x[4])
|
| 16380 |
+
|
| 16381 |
+
# Context
|
| 16382 |
+
phrase_arc_mod = 0
|
| 16383 |
+
for p_times in phrases:
|
| 16384 |
+
if t in p_times:
|
| 16385 |
+
p_len = p_times[-1] - p_times[0]
|
| 16386 |
+
if p_len > 0:
|
| 16387 |
+
progress = (t - p_times[0]) / p_len
|
| 16388 |
+
phrase_arc_mod = math.sin(progress * math.pi) * 6
|
| 16389 |
+
break
|
| 16390 |
+
|
| 16391 |
+
global_arc = global_progress_map.get(t, 0)
|
| 16392 |
+
|
| 16393 |
+
position_in_grid = t % beat_grid
|
| 16394 |
+
if position_in_grid == 0: metric_strength = 6
|
| 16395 |
+
elif position_in_grid == (beat_grid // 2): metric_strength = 2
|
| 16396 |
+
else: metric_strength = -3
|
| 16397 |
+
|
| 16398 |
+
core_vel = base_vel_target + metric_strength + global_arc + phrase_arc_mod
|
| 16399 |
+
smoothed = (core_vel * 0.35) + (last_vel * 0.65)
|
| 16400 |
+
last_vel = smoothed
|
| 16401 |
+
|
| 16402 |
+
num_notes = len(slice_notes)
|
| 16403 |
+
|
| 16404 |
+
for i, n in enumerate(slice_notes):
|
| 16405 |
+
pitch = n[4]
|
| 16406 |
+
offset = 0
|
| 16407 |
+
|
| 16408 |
+
if i == num_notes - 1:
|
| 16409 |
+
offset = 10
|
| 16410 |
+
if prev_melody_pitch is not None:
|
| 16411 |
+
diff = pitch - prev_melody_pitch
|
| 16412 |
+
if diff > 4: offset += 5
|
| 16413 |
+
prev_melody_pitch = pitch
|
| 16414 |
+
elif i == 0:
|
| 16415 |
+
offset = 0
|
| 16416 |
+
else:
|
| 16417 |
+
offset = -6
|
| 16418 |
+
|
| 16419 |
+
final_vel = smoothed + offset + random.gauss(0, 3)
|
| 16420 |
+
n[5] = max(30, min(127, int(final_vel)))
|
| 16421 |
+
|
| 16422 |
+
# -----------------------------------------------------------
|
| 16423 |
+
# 3. FINAL EXPRESSIVE SCALING
|
| 16424 |
+
# -----------------------------------------------------------
|
| 16425 |
+
for n in notes:
|
| 16426 |
+
if n[3] == 9:
|
| 16427 |
+
continue
|
| 16428 |
+
|
| 16429 |
+
v = n[5]
|
| 16430 |
+
center = 95
|
| 16431 |
+
|
| 16432 |
+
deviation = v - center
|
| 16433 |
+
final_v = center + (deviation * 1.1)
|
| 16434 |
+
final_v += random.randint(-1, 1)
|
| 16435 |
+
|
| 16436 |
+
n[5] = max(20, min(127, int(final_v)))
|
| 16437 |
+
|
| 16438 |
+
return notes
|
| 16439 |
+
|
| 16440 |
+
###################################################################################
|
| 16441 |
+
|
| 16442 |
+
def most_common_ordered_set(values, top_k):
|
| 16443 |
+
|
| 16444 |
+
freq = Counter(values)
|
| 16445 |
+
|
| 16446 |
+
top_vals = {v for v, _ in freq.most_common(top_k)}
|
| 16447 |
+
|
| 16448 |
+
result = []
|
| 16449 |
+
seen = set()
|
| 16450 |
+
|
| 16451 |
+
for v in values:
|
| 16452 |
+
if v in top_vals and v not in seen:
|
| 16453 |
+
result.append(v)
|
| 16454 |
+
seen.add(v)
|
| 16455 |
+
|
| 16456 |
+
return result
|
| 16457 |
+
|
| 16458 |
+
###################################################################################
|
| 16459 |
+
|
| 16460 |
+
def escore_notes_velocities(escore_notes, chan_idx=3, vels_idx=5):
|
| 16461 |
+
|
| 16462 |
+
output_list = []
|
| 16463 |
+
|
| 16464 |
+
all_vels = [e[vels_idx] for e in escore_notes]
|
| 16465 |
+
avg_vel = sum(all_vels) / len(all_vels)
|
| 16466 |
+
vels_span = max(all_vels) - min(all_vels)
|
| 16467 |
+
|
| 16468 |
+
output_list.append([-1, min(all_vels), avg_vel, max(all_vels), vels_span])
|
| 16469 |
+
|
| 16470 |
+
chan_groups = groupby(sorted(escore_notes, key=lambda x: x[chan_idx]), key=lambda x: x[chan_idx])
|
| 16471 |
+
|
| 16472 |
+
for cha, group in chan_groups:
|
| 16473 |
+
all_vels = [e[vels_idx] for e in list(group)]
|
| 16474 |
+
avg_vel = sum(all_vels) / len(all_vels)
|
| 16475 |
+
vels_span = max(all_vels) - min(all_vels)
|
| 16476 |
+
output_list.append([cha, min(all_vels), avg_vel, max(all_vels), vels_span])
|
| 16477 |
+
|
| 16478 |
+
return output_list
|
| 16479 |
+
|
| 16480 |
+
###################################################################################
|
| 16481 |
+
|
| 16482 |
print('Module loaded!')
|
| 16483 |
print('=' * 70)
|
| 16484 |
print('Enjoy! :)')
|
x_transformer_2_3_1.py
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
#
|
| 5 |
# Partial x-transformers code With useful modifications as a stand-alone Python module
|
| 6 |
#
|
| 7 |
-
# Version
|
| 8 |
#
|
| 9 |
# Original source code courtesy of lucidrains
|
| 10 |
# https://github.com/lucidrains/x-transformers
|
|
@@ -5190,7 +5190,9 @@ def build_cls_model(num_tokens=18819,
|
|
| 5190 |
squeeze_out_last_dim = squeeze_out_last_dim,
|
| 5191 |
attn_layers=Encoder(dim=dim,
|
| 5192 |
depth=depth,
|
| 5193 |
-
heads=heads
|
|
|
|
|
|
|
| 5194 |
)
|
| 5195 |
)
|
| 5196 |
|
|
@@ -6291,6 +6293,1206 @@ def calculate_training_run_eta(
|
|
| 6291 |
|
| 6292 |
return eta
|
| 6293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6294 |
#=================================================================================================================================
|
| 6295 |
# This is the end of x_transformer_2_3_1 Python module
|
| 6296 |
#=================================================================================================================================
|
|
|
|
| 4 |
#
|
| 5 |
# Partial x-transformers code With useful modifications as a stand-alone Python module
|
| 6 |
#
|
| 7 |
+
# Version 8.0
|
| 8 |
#
|
| 9 |
# Original source code courtesy of lucidrains
|
| 10 |
# https://github.com/lucidrains/x-transformers
|
|
|
|
| 5190 |
squeeze_out_last_dim = squeeze_out_last_dim,
|
| 5191 |
attn_layers=Encoder(dim=dim,
|
| 5192 |
depth=depth,
|
| 5193 |
+
heads=heads,
|
| 5194 |
+
attn_flash=True,
|
| 5195 |
+
rotary_pos_emb=True
|
| 5196 |
)
|
| 5197 |
)
|
| 5198 |
|
|
|
|
| 6293 |
|
| 6294 |
return eta
|
| 6295 |
|
| 6296 |
+
#=================================================================================================================================
|
| 6297 |
+
# Autoregressive embeddings retrieval functions
|
| 6298 |
+
#=================================================================================================================================
|
| 6299 |
+
|
| 6300 |
+
import torch
|
| 6301 |
+
import torch.nn.functional as F
|
| 6302 |
+
from typing import Optional, Dict, List, Union, Set
|
| 6303 |
+
import numpy as np
|
| 6304 |
+
from tqdm import tqdm
|
| 6305 |
+
from contextlib import nullcontext
|
| 6306 |
+
|
| 6307 |
+
#===================================================================================================================
|
| 6308 |
+
# Advanced Embeddings Retrieval Function for Autoregressive X-Transformers
|
| 6309 |
+
#===================================================================================================================
|
| 6310 |
+
|
| 6311 |
+
def get_embeddings(
|
| 6312 |
+
model,
|
| 6313 |
+
inputs: torch.Tensor,
|
| 6314 |
+
pooling: str = 'mean',
|
| 6315 |
+
mask: Optional[torch.Tensor] = None,
|
| 6316 |
+
token_ids: Optional[List[int]] = None,
|
| 6317 |
+
token_weights: Optional[Dict[int, float]] = None,
|
| 6318 |
+
layer_index: int = -1,
|
| 6319 |
+
normalize: bool = False,
|
| 6320 |
+
device: Optional[torch.device] = None,
|
| 6321 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 6322 |
+
pad_idx: int = 18819,
|
| 6323 |
+
use_amp: bool = True,
|
| 6324 |
+
verbose: bool = True,
|
| 6325 |
+
_max_concat_tokens: Optional[int] = None,
|
| 6326 |
+
) -> np.ndarray:
|
| 6327 |
+
|
| 6328 |
+
"""
|
| 6329 |
+
Get embeddings for a single batch of inputs.
|
| 6330 |
+
|
| 6331 |
+
Parameters
|
| 6332 |
+
----------
|
| 6333 |
+
model : AutoregressiveWrapper
|
| 6334 |
+
Your trained transformer model
|
| 6335 |
+
inputs : torch.Tensor
|
| 6336 |
+
Input token sequences of shape (batch, seq_len)
|
| 6337 |
+
pooling : str
|
| 6338 |
+
Pooling strategy: 'mean' or 'concat'
|
| 6339 |
+
mask : Optional[torch.Tensor]
|
| 6340 |
+
Boolean mask, True for valid tokens. Auto-generated if None
|
| 6341 |
+
token_ids : Optional[List[int]]
|
| 6342 |
+
Token IDs to include. Works independently or with token_weights.
|
| 6343 |
+
token_weights : Optional[Dict[int, float]]
|
| 6344 |
+
Token ID to weight/priority mapping:
|
| 6345 |
+
- 'mean': weights for weighted average
|
| 6346 |
+
- 'concat': priority scores for selection when limiting count
|
| 6347 |
+
- If provided WITHOUT token_ids: keys become the filter
|
| 6348 |
+
- If provided WITH token_ids: only tokens in BOTH are used (intersection)
|
| 6349 |
+
layer_index : int
|
| 6350 |
+
Which layer's hidden states to use (-1 for last)
|
| 6351 |
+
normalize : bool
|
| 6352 |
+
L2-normalize output embeddings
|
| 6353 |
+
device : Optional[torch.device]
|
| 6354 |
+
Device for inference
|
| 6355 |
+
dtype : torch.dtype
|
| 6356 |
+
Dtype for autocast
|
| 6357 |
+
pad_idx : int
|
| 6358 |
+
Padding token index
|
| 6359 |
+
use_amp : bool
|
| 6360 |
+
Use automatic mixed precision
|
| 6361 |
+
verbose : bool
|
| 6362 |
+
Print warnings and info
|
| 6363 |
+
_max_concat_tokens : Optional[int]
|
| 6364 |
+
Internal: pre-computed max tokens for concat mode
|
| 6365 |
+
|
| 6366 |
+
Returns
|
| 6367 |
+
-------
|
| 6368 |
+
np.ndarray
|
| 6369 |
+
Embeddings array:
|
| 6370 |
+
- 'mean': (batch, dim)
|
| 6371 |
+
- 'concat': (batch, max_tokens * dim)
|
| 6372 |
+
"""
|
| 6373 |
+
|
| 6374 |
+
model.eval()
|
| 6375 |
+
|
| 6376 |
+
if device is None:
|
| 6377 |
+
device = next(model.parameters()).device
|
| 6378 |
+
|
| 6379 |
+
inputs = inputs.to(device)
|
| 6380 |
+
|
| 6381 |
+
if inputs.ndim == 1:
|
| 6382 |
+
inputs = inputs.unsqueeze(0)
|
| 6383 |
+
|
| 6384 |
+
batch_size, seq_len = inputs.shape
|
| 6385 |
+
|
| 6386 |
+
if mask is None:
|
| 6387 |
+
mask = (inputs != pad_idx)
|
| 6388 |
+
else:
|
| 6389 |
+
mask = mask.to(device)
|
| 6390 |
+
|
| 6391 |
+
if mask.dtype != torch.bool:
|
| 6392 |
+
mask = mask.bool()
|
| 6393 |
+
|
| 6394 |
+
if hasattr(model, 'net'):
|
| 6395 |
+
net_model = model.net
|
| 6396 |
+
else:
|
| 6397 |
+
net_model = model
|
| 6398 |
+
|
| 6399 |
+
if use_amp and device.type == 'cuda':
|
| 6400 |
+
ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
|
| 6401 |
+
else:
|
| 6402 |
+
ctx = nullcontext()
|
| 6403 |
+
|
| 6404 |
+
try:
|
| 6405 |
+
with torch.no_grad():
|
| 6406 |
+
with ctx if use_amp else nullcontext():
|
| 6407 |
+
output = net_model(
|
| 6408 |
+
inputs,
|
| 6409 |
+
mask=mask if mask.ndim == 2 else mask.squeeze(),
|
| 6410 |
+
return_intermediates=True,
|
| 6411 |
+
)
|
| 6412 |
+
|
| 6413 |
+
if isinstance(output, tuple) and len(output) == 2:
|
| 6414 |
+
_, intermediates = output
|
| 6415 |
+
else:
|
| 6416 |
+
intermediates = None
|
| 6417 |
+
|
| 6418 |
+
hidden = _extract_hidden_states(intermediates, layer_index, verbose=verbose)
|
| 6419 |
+
|
| 6420 |
+
if hidden is None:
|
| 6421 |
+
raise ValueError("Could not extract hidden states")
|
| 6422 |
+
|
| 6423 |
+
except Exception as e:
|
| 6424 |
+
if verbose:
|
| 6425 |
+
print(f"Warning: Could not extract hidden states, using token embeddings. Error: {e}")
|
| 6426 |
+
hidden = _get_token_embeddings(net_model, inputs)
|
| 6427 |
+
|
| 6428 |
+
seq_mask = (inputs != pad_idx)
|
| 6429 |
+
seq_mask_expanded = seq_mask.unsqueeze(-1)
|
| 6430 |
+
hidden = hidden * seq_mask_expanded.float()
|
| 6431 |
+
|
| 6432 |
+
# Compute effective token IDs with INTUITIVE logic
|
| 6433 |
+
effective_token_ids = _compute_effective_token_ids(token_ids, token_weights)
|
| 6434 |
+
|
| 6435 |
+
if pooling == 'mean':
|
| 6436 |
+
emb = _mean_pooling(hidden, inputs, seq_mask, effective_token_ids, token_weights, verbose=verbose)
|
| 6437 |
+
elif pooling == 'concat':
|
| 6438 |
+
emb = _concat_pooling(hidden, inputs, seq_mask, effective_token_ids, token_weights,
|
| 6439 |
+
max_tokens=_max_concat_tokens, verbose=verbose)
|
| 6440 |
+
else:
|
| 6441 |
+
raise ValueError(f"Unknown pooling strategy: {pooling}. Use 'mean' or 'concat'")
|
| 6442 |
+
|
| 6443 |
+
if normalize:
|
| 6444 |
+
emb = F.normalize(emb, p=2, dim=-1)
|
| 6445 |
+
|
| 6446 |
+
return emb.cpu().detach().numpy()
|
| 6447 |
+
|
| 6448 |
+
|
| 6449 |
+
#===================================================================================================================
|
| 6450 |
+
# Batched Processing Function
|
| 6451 |
+
#===================================================================================================================
|
| 6452 |
+
|
| 6453 |
+
def get_embeddings_batched(
|
| 6454 |
+
model,
|
| 6455 |
+
sequences: List[List[int]],
|
| 6456 |
+
pooling: str = 'mean',
|
| 6457 |
+
token_ids: Optional[List[int]] = None,
|
| 6458 |
+
token_weights: Optional[Dict[int, float]] = None,
|
| 6459 |
+
max_seq_len: int = 8192,
|
| 6460 |
+
pad_idx: int = 18819,
|
| 6461 |
+
batch_size: int = 8,
|
| 6462 |
+
use_amp: bool = True,
|
| 6463 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 6464 |
+
verbose: bool = True,
|
| 6465 |
+
show_progress: bool = True,
|
| 6466 |
+
normalize: bool = False,
|
| 6467 |
+
) -> np.ndarray:
|
| 6468 |
+
|
| 6469 |
+
"""
|
| 6470 |
+
Process multiple sequences in TRUE batches for memory efficiency.
|
| 6471 |
+
|
| 6472 |
+
Parameters
|
| 6473 |
+
----------
|
| 6474 |
+
model : AutoregressiveWrapper
|
| 6475 |
+
Your trained transformer model
|
| 6476 |
+
sequences : List[List[int]]
|
| 6477 |
+
List of token sequences (list of lists)
|
| 6478 |
+
pooling : str
|
| 6479 |
+
Pooling strategy: 'mean' or 'concat'
|
| 6480 |
+
token_ids : Optional[List[int]]
|
| 6481 |
+
Token IDs to include
|
| 6482 |
+
token_weights : Optional[Dict[int, float]]
|
| 6483 |
+
Token ID to weight/priority mapping
|
| 6484 |
+
max_seq_len : int
|
| 6485 |
+
Maximum sequence length
|
| 6486 |
+
pad_idx : int
|
| 6487 |
+
Padding token index
|
| 6488 |
+
batch_size : int
|
| 6489 |
+
Batch size for processing
|
| 6490 |
+
use_amp : bool
|
| 6491 |
+
Use automatic mixed precision
|
| 6492 |
+
dtype : torch.dtype
|
| 6493 |
+
Dtype for autocast
|
| 6494 |
+
verbose : bool
|
| 6495 |
+
Print messages
|
| 6496 |
+
show_progress : bool
|
| 6497 |
+
Show tqdm progress bar
|
| 6498 |
+
normalize : bool
|
| 6499 |
+
L2-normalize output embeddings
|
| 6500 |
+
|
| 6501 |
+
Returns
|
| 6502 |
+
-------
|
| 6503 |
+
np.ndarray
|
| 6504 |
+
Embeddings array with consistent dimensions
|
| 6505 |
+
"""
|
| 6506 |
+
|
| 6507 |
+
model.eval()
|
| 6508 |
+
|
| 6509 |
+
num_sequences = len(sequences)
|
| 6510 |
+
|
| 6511 |
+
if verbose:
|
| 6512 |
+
print(f"Processing {num_sequences} sequences in batches of {batch_size}...")
|
| 6513 |
+
|
| 6514 |
+
# For concat mode: pre-scan to find max matching tokens across ALL sequences
|
| 6515 |
+
max_concat_tokens = None
|
| 6516 |
+
if pooling == 'concat':
|
| 6517 |
+
effective_token_ids = _compute_effective_token_ids(token_ids, token_weights)
|
| 6518 |
+
max_concat_tokens = _scan_max_matching_tokens(sequences, effective_token_ids, pad_idx, max_seq_len)
|
| 6519 |
+
if verbose and max_concat_tokens is not None:
|
| 6520 |
+
print(f"Auto-detected max matching tokens: {max_concat_tokens}")
|
| 6521 |
+
elif verbose and max_concat_tokens == 0:
|
| 6522 |
+
print("Warning: No sequences contain matching token IDs, using 1 token placeholder")
|
| 6523 |
+
max_concat_tokens = 1
|
| 6524 |
+
|
| 6525 |
+
all_embeddings = []
|
| 6526 |
+
num_batches = (num_sequences + batch_size - 1) // batch_size
|
| 6527 |
+
|
| 6528 |
+
batch_iterator = tqdm(range(num_batches), desc="Extracting embeddings", disable=not (show_progress and verbose)) if show_progress and verbose else range(num_batches)
|
| 6529 |
+
|
| 6530 |
+
for batch_idx in batch_iterator:
|
| 6531 |
+
start_idx = batch_idx * batch_size
|
| 6532 |
+
end_idx = min((batch_idx + 1) * batch_size, num_sequences)
|
| 6533 |
+
|
| 6534 |
+
batch_sequences = sequences[start_idx:end_idx]
|
| 6535 |
+
max_len_in_batch = min(max_seq_len, max(len(seq) for seq in batch_sequences))
|
| 6536 |
+
|
| 6537 |
+
padded_batch = []
|
| 6538 |
+
for seq in batch_sequences:
|
| 6539 |
+
if len(seq) > max_len_in_batch:
|
| 6540 |
+
seq = seq[:max_len_in_batch]
|
| 6541 |
+
else:
|
| 6542 |
+
seq = seq + [pad_idx] * (max_len_in_batch - len(seq))
|
| 6543 |
+
padded_batch.append(seq)
|
| 6544 |
+
|
| 6545 |
+
batch_inputs = torch.tensor(padded_batch, dtype=torch.long)
|
| 6546 |
+
|
| 6547 |
+
batch_embeddings = get_embeddings(
|
| 6548 |
+
model,
|
| 6549 |
+
batch_inputs,
|
| 6550 |
+
pooling=pooling,
|
| 6551 |
+
token_ids=token_ids,
|
| 6552 |
+
token_weights=token_weights,
|
| 6553 |
+
pad_idx=pad_idx,
|
| 6554 |
+
use_amp=use_amp,
|
| 6555 |
+
dtype=dtype,
|
| 6556 |
+
verbose=verbose and batch_idx == 0,
|
| 6557 |
+
normalize=normalize,
|
| 6558 |
+
_max_concat_tokens=max_concat_tokens,
|
| 6559 |
+
)
|
| 6560 |
+
|
| 6561 |
+
all_embeddings.append(batch_embeddings)
|
| 6562 |
+
|
| 6563 |
+
final_embeddings = np.concatenate(all_embeddings, axis=0)
|
| 6564 |
+
|
| 6565 |
+
if verbose:
|
| 6566 |
+
print(f"Final embeddings shape: {final_embeddings.shape}")
|
| 6567 |
+
|
| 6568 |
+
return final_embeddings
|
| 6569 |
+
|
| 6570 |
+
#===================================================================================================================
|
| 6571 |
+
# Helper Functions
|
| 6572 |
+
#===================================================================================================================
|
| 6573 |
+
|
| 6574 |
+
def _compute_effective_token_ids(token_ids: Optional[List[int]], token_weights: Optional[Dict[int, float]]) -> Optional[Set[int]]:
|
| 6575 |
+
"""
|
| 6576 |
+
Compute effective token IDs with INTUITIVE logic:
|
| 6577 |
+
|
| 6578 |
+
- token_ids=None, token_weights=None → None (all valid tokens)
|
| 6579 |
+
- token_ids=[...], token_weights=None → token_ids
|
| 6580 |
+
- token_ids=None, token_weights={...} → keys from token_weights
|
| 6581 |
+
- token_ids=[...], token_weights={...} → INTERSECTION (only tokens in BOTH)
|
| 6582 |
+
|
| 6583 |
+
This ensures token_weights acts as a filter when provided, not just weights.
|
| 6584 |
+
"""
|
| 6585 |
+
if token_ids is None and token_weights is None:
|
| 6586 |
+
return None
|
| 6587 |
+
|
| 6588 |
+
token_ids_set = set(token_ids) if token_ids is not None else None
|
| 6589 |
+
weights_keys_set = set(token_weights.keys()) if token_weights is not None else None
|
| 6590 |
+
|
| 6591 |
+
if token_ids_set is None and weights_keys_set is not None:
|
| 6592 |
+
# Only token_weights provided: use its keys as filter
|
| 6593 |
+
return weights_keys_set
|
| 6594 |
+
elif token_ids_set is not None and weights_keys_set is None:
|
| 6595 |
+
# Only token_ids provided: use token_ids as filter
|
| 6596 |
+
return token_ids_set
|
| 6597 |
+
elif token_ids_set is not None and weights_keys_set is not None:
|
| 6598 |
+
# Both provided: INTERSECTION (only tokens in BOTH lists)
|
| 6599 |
+
# This is the key fix for intuitive behavior
|
| 6600 |
+
intersection = token_ids_set & weights_keys_set
|
| 6601 |
+
if len(intersection) == 0:
|
| 6602 |
+
# Warn but fall back to token_ids (more permissive)
|
| 6603 |
+
print(f"Warning: token_ids and token_weights have no overlap. Using token_ids only.")
|
| 6604 |
+
return token_ids_set
|
| 6605 |
+
return intersection
|
| 6606 |
+
else:
|
| 6607 |
+
return None
|
| 6608 |
+
|
| 6609 |
+
def _scan_max_matching_tokens(sequences: List[List[int]],
|
| 6610 |
+
token_ids: Optional[Set[int]],
|
| 6611 |
+
pad_idx: int,
|
| 6612 |
+
max_seq_len: int) -> int:
|
| 6613 |
+
"""
|
| 6614 |
+
Scan all sequences to find maximum number of tokens matching token_ids.
|
| 6615 |
+
"""
|
| 6616 |
+
if token_ids is None:
|
| 6617 |
+
return max(min(len(seq), max_seq_len) for seq in sequences) if sequences else 0
|
| 6618 |
+
|
| 6619 |
+
max_count = 0
|
| 6620 |
+
for seq in sequences:
|
| 6621 |
+
truncated = seq[:max_seq_len]
|
| 6622 |
+
count = sum(1 for tok in truncated if tok in token_ids and tok != pad_idx)
|
| 6623 |
+
max_count = max(max_count, count)
|
| 6624 |
+
|
| 6625 |
+
return max_count
|
| 6626 |
+
|
| 6627 |
+
def _extract_hidden_states(intermediates, layer_index: int = -1, verbose: bool = True):
|
| 6628 |
+
"""Extract hidden states from LayerIntermediates object."""
|
| 6629 |
+
if intermediates is None:
|
| 6630 |
+
if verbose:
|
| 6631 |
+
print("Warning: intermediates is None")
|
| 6632 |
+
return None
|
| 6633 |
+
|
| 6634 |
+
if hasattr(intermediates, 'layer_hiddens') and intermediates.layer_hiddens is not None:
|
| 6635 |
+
if len(intermediates.layer_hiddens) > 0:
|
| 6636 |
+
return intermediates.layer_hiddens[layer_index]
|
| 6637 |
+
|
| 6638 |
+
if hasattr(intermediates, 'hiddens') and intermediates.hiddens is not None:
|
| 6639 |
+
if len(intermediates.hiddens) > 0:
|
| 6640 |
+
return intermediates.hiddens[layer_index]
|
| 6641 |
+
|
| 6642 |
+
if hasattr(intermediates, 'attn_intermediates') and intermediates.attn_intermediates is not None:
|
| 6643 |
+
if len(intermediates.attn_intermediates) > 0:
|
| 6644 |
+
attn_int = intermediates.attn_intermediates[layer_index]
|
| 6645 |
+
if hasattr(attn_int, 'values') and attn_int.values is not None:
|
| 6646 |
+
return attn_int.values
|
| 6647 |
+
|
| 6648 |
+
if verbose:
|
| 6649 |
+
print("Warning: Could not find layer_hiddens in intermediates")
|
| 6650 |
+
|
| 6651 |
+
return None
|
| 6652 |
+
|
| 6653 |
+
|
| 6654 |
+
def _get_token_embeddings(net_model, inputs: torch.Tensor):
|
| 6655 |
+
"""Get token embeddings directly from embedding layer."""
|
| 6656 |
+
if hasattr(net_model, 'token_emb'):
|
| 6657 |
+
if hasattr(net_model.token_emb, 'emb'):
|
| 6658 |
+
return net_model.token_emb.emb(inputs)
|
| 6659 |
+
else:
|
| 6660 |
+
return net_model.token_emb(inputs)
|
| 6661 |
+
elif hasattr(net_model, 'emb'):
|
| 6662 |
+
return net_model.emb(inputs)
|
| 6663 |
+
else:
|
| 6664 |
+
raise ValueError("Could not find embedding layer in model")
|
| 6665 |
+
|
| 6666 |
+
def _mean_pooling(
|
| 6667 |
+
hidden: torch.Tensor,
|
| 6668 |
+
inputs: torch.Tensor,
|
| 6669 |
+
mask: torch.Tensor,
|
| 6670 |
+
token_ids: Optional[Set[int]],
|
| 6671 |
+
token_weights: Optional[Dict[int, float]],
|
| 6672 |
+
verbose: bool = True
|
| 6673 |
+
) -> torch.Tensor:
|
| 6674 |
+
"""
|
| 6675 |
+
Mean pooling with token ID filtering and weighted averaging.
|
| 6676 |
+
"""
|
| 6677 |
+
batch_size, seq_len, dim = hidden.shape
|
| 6678 |
+
device = hidden.device
|
| 6679 |
+
|
| 6680 |
+
if mask.ndim > 2:
|
| 6681 |
+
mask = mask.squeeze()
|
| 6682 |
+
|
| 6683 |
+
effective_mask = mask.clone()
|
| 6684 |
+
|
| 6685 |
+
if token_ids is not None:
|
| 6686 |
+
token_mask = torch.zeros_like(mask, dtype=torch.bool, device=device)
|
| 6687 |
+
for tid in token_ids:
|
| 6688 |
+
token_mask = token_mask | (inputs == tid)
|
| 6689 |
+
effective_mask = effective_mask & token_mask
|
| 6690 |
+
|
| 6691 |
+
if verbose and effective_mask.sum() == 0:
|
| 6692 |
+
print(f"Warning: No tokens match filter, falling back to all valid tokens")
|
| 6693 |
+
effective_mask = mask
|
| 6694 |
+
|
| 6695 |
+
if token_weights is not None:
|
| 6696 |
+
weights = torch.zeros_like(effective_mask, dtype=torch.float32, device=device)
|
| 6697 |
+
|
| 6698 |
+
for token_id, weight in token_weights.items():
|
| 6699 |
+
id_mask = (inputs == token_id) & effective_mask
|
| 6700 |
+
weights = weights.masked_fill(id_mask, float(weight))
|
| 6701 |
+
|
| 6702 |
+
weights = weights.masked_fill(effective_mask & (weights == 0), 1.0)
|
| 6703 |
+
|
| 6704 |
+
weighted_hidden = hidden * weights.unsqueeze(-1)
|
| 6705 |
+
sum_weighted = weighted_hidden.sum(dim=1)
|
| 6706 |
+
sum_weights = weights.sum(dim=1, keepdim=True).clamp(min=1e-9)
|
| 6707 |
+
return sum_weighted / sum_weights
|
| 6708 |
+
else:
|
| 6709 |
+
masked_hidden = hidden * effective_mask.unsqueeze(-1).float()
|
| 6710 |
+
sum_hidden = masked_hidden.sum(dim=1)
|
| 6711 |
+
count = effective_mask.sum(dim=1, keepdim=True).clamp(min=1e-9)
|
| 6712 |
+
return sum_hidden / count
|
| 6713 |
+
|
| 6714 |
+
def _concat_pooling(
|
| 6715 |
+
hidden: torch.Tensor,
|
| 6716 |
+
inputs: torch.Tensor,
|
| 6717 |
+
mask: torch.Tensor,
|
| 6718 |
+
token_ids: Optional[Set[int]],
|
| 6719 |
+
token_weights: Optional[Dict[int, float]],
|
| 6720 |
+
max_tokens: Optional[int],
|
| 6721 |
+
verbose: bool = True
|
| 6722 |
+
) -> torch.Tensor:
|
| 6723 |
+
"""
|
| 6724 |
+
Concat pooling with token ID filtering and weight-based priority selection.
|
| 6725 |
+
"""
|
| 6726 |
+
batch_size, seq_len, dim = hidden.shape
|
| 6727 |
+
device = hidden.device
|
| 6728 |
+
|
| 6729 |
+
if max_tokens is None:
|
| 6730 |
+
max_tokens = 1
|
| 6731 |
+
|
| 6732 |
+
output_dim = max_tokens * dim
|
| 6733 |
+
|
| 6734 |
+
all_token_embs = []
|
| 6735 |
+
|
| 6736 |
+
for i in range(batch_size):
|
| 6737 |
+
seq_mask = mask[i]
|
| 6738 |
+
seq_inputs = inputs[i]
|
| 6739 |
+
|
| 6740 |
+
if token_ids is not None:
|
| 6741 |
+
matching_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
|
| 6742 |
+
for tid in token_ids:
|
| 6743 |
+
matching_mask = matching_mask | ((seq_inputs == tid) & seq_mask)
|
| 6744 |
+
valid_indices = matching_mask.nonzero(as_tuple=True)[0]
|
| 6745 |
+
else:
|
| 6746 |
+
valid_indices = seq_mask.nonzero(as_tuple=True)[0]
|
| 6747 |
+
|
| 6748 |
+
if len(valid_indices) == 0:
|
| 6749 |
+
emb = torch.zeros(dim, device=device)
|
| 6750 |
+
emb = F.pad(emb, (0, output_dim - dim))
|
| 6751 |
+
all_token_embs.append(emb)
|
| 6752 |
+
continue
|
| 6753 |
+
|
| 6754 |
+
matching_embs = hidden[i, valid_indices, :]
|
| 6755 |
+
|
| 6756 |
+
if token_weights is not None and len(valid_indices) > max_tokens:
|
| 6757 |
+
weights_list = []
|
| 6758 |
+
for idx in valid_indices:
|
| 6759 |
+
tok_id = seq_inputs[idx].item()
|
| 6760 |
+
weights_list.append(token_weights.get(tok_id, 1.0))
|
| 6761 |
+
|
| 6762 |
+
sorted_pairs = sorted(zip(range(len(valid_indices)), weights_list),
|
| 6763 |
+
key=lambda x: x[1], reverse=True)
|
| 6764 |
+
top_indices = [valid_indices[p[0]] for p in sorted_pairs[:max_tokens]]
|
| 6765 |
+
matching_embs = hidden[i, torch.tensor(top_indices, device=device), :]
|
| 6766 |
+
elif len(valid_indices) > max_tokens:
|
| 6767 |
+
matching_embs = matching_embs[:max_tokens]
|
| 6768 |
+
|
| 6769 |
+
if len(valid_indices) < max_tokens:
|
| 6770 |
+
padding_needed = max_tokens - len(valid_indices)
|
| 6771 |
+
padding = torch.zeros(padding_needed, dim, device=device)
|
| 6772 |
+
matching_embs = torch.cat([matching_embs, padding], dim=0)
|
| 6773 |
+
|
| 6774 |
+
emb = matching_embs.reshape(-1)
|
| 6775 |
+
all_token_embs.append(emb)
|
| 6776 |
+
|
| 6777 |
+
return torch.stack(all_token_embs, dim=0)
|
| 6778 |
+
|
| 6779 |
+
#=================================================================================================================================
|
| 6780 |
+
# Non-Autoregressive Encoder Embeddings Retrieval Functions
|
| 6781 |
+
#=================================================================================================================================
|
| 6782 |
+
|
| 6783 |
+
def get_enc_embeddings(
|
| 6784 |
+
model,
|
| 6785 |
+
sequences: List[List[int]],
|
| 6786 |
+
seq_len: Optional[int] = 3072,
|
| 6787 |
+
seq_pad_idx: int = 385,
|
| 6788 |
+
batch_size: int = 64,
|
| 6789 |
+
save_every_num_batches: int = -1,
|
| 6790 |
+
save_file_path: str = "saved_embeddings.npy",
|
| 6791 |
+
device: Optional[torch.device] = None,
|
| 6792 |
+
normalize: bool = False,
|
| 6793 |
+
pooling: str = "auto", # "auto" | "mean" | "weighted_mean"
|
| 6794 |
+
token_type_weights: Optional[Tuple[float, float, float]] = None, # (onset_w, duration_w, pitch_w)
|
| 6795 |
+
use_bfloat16: bool = True, # enable bfloat16 autocast when possible
|
| 6796 |
+
return_dtype: str = "float32", # "float32" or "float16" for returned embeddings
|
| 6797 |
+
return_numpy: bool = False,
|
| 6798 |
+
verbose: bool = True,
|
| 6799 |
+
show_progress_bar: bool = True
|
| 6800 |
+
) -> Union[Tensor, np.ndarray]:
|
| 6801 |
+
|
| 6802 |
+
"""
|
| 6803 |
+
Compute embeddings for a list of token sequences using a PyTorch model with optional bfloat16/autocast,
|
| 6804 |
+
pooling, normalization, and periodic saving.
|
| 6805 |
+
|
| 6806 |
+
This function batches input token id sequences, pads/truncates them to a fixed length, runs the model
|
| 6807 |
+
in evaluation mode under `torch.no_grad()` and optional mixed-precision autocast, and returns a single
|
| 6808 |
+
tensor (or NumPy array) containing per-sequence embeddings. The model is expected to accept a LongTensor
|
| 6809 |
+
of token ids `x` and a boolean mask `mask` and to return either:
|
| 6810 |
+
- a 2-D tensor `(B, D)` of already-pooled embeddings, or
|
| 6811 |
+
- a 3-D tensor `(B, L, D)` of per-token embeddings (which will be pooled according to `pooling`).
|
| 6812 |
+
|
| 6813 |
+
Key behaviors:
|
| 6814 |
+
- Sequences are padded with `seq_pad_idx` and masked so padding does not affect pooling.
|
| 6815 |
+
- If `seq_len` is provided, sequences longer than `seq_len` are truncated; otherwise the batch max length is used.
|
| 6816 |
+
- Mixed-precision autocast is used when `use_bfloat16` is True and supported by the device; the function
|
| 6817 |
+
falls back to the default autocast or no autocast if unavailable.
|
| 6818 |
+
- Supports three pooling modes for per-token embeddings:
|
| 6819 |
+
- `"auto"` or `"mean"`: simple masked mean pooling across tokens.
|
| 6820 |
+
- `"weighted_mean"`: weighted mean pooling by token type (onset/duration/pitch) inferred from token ids;
|
| 6821 |
+
weights are provided via `token_type_weights` and padding tokens are ignored.
|
| 6822 |
+
- Optionally L2-normalizes embeddings (in float32) when `normalize=True`.
|
| 6823 |
+
- Returned embeddings can be cast to `float16` for storage/transfer via `return_dtype`.
|
| 6824 |
+
- Embeddings are collected on CPU; intermediate results can be periodically saved to `save_file_path`.
|
| 6825 |
+
- If `return_numpy=True`, a NumPy array is returned; otherwise a CPU `torch.Tensor` is returned.
|
| 6826 |
+
|
| 6827 |
+
Args:
|
| 6828 |
+
model (torch.nn.Module):
|
| 6829 |
+
PyTorch model used to compute embeddings. The model will be moved to `device` (or its current
|
| 6830 |
+
parameter device if `device` is None) and set to `eval()` for inference. The forward call must
|
| 6831 |
+
accept `x` (LongTensor) and `mask` (BoolTensor) and return embeddings when called with
|
| 6832 |
+
`return_embeddings=True`.
|
| 6833 |
+
sequences (List[List[int]]):
|
| 6834 |
+
Batch of token id sequences (each sequence is a list of ints). Can be empty; an empty result
|
| 6835 |
+
with shape `(0, 0)` will be returned in that case.
|
| 6836 |
+
seq_len (Optional[int], default=3072):
|
| 6837 |
+
Target sequence length for truncation/padding. If None, the maximum sequence length in the
|
| 6838 |
+
current batch is used.
|
| 6839 |
+
seq_pad_idx (int, default=385):
|
| 6840 |
+
Token id used for padding positions.
|
| 6841 |
+
batch_size (int, default=64):
|
| 6842 |
+
Number of sequences processed per forward pass.
|
| 6843 |
+
save_every_num_batches (int, default=-1):
|
| 6844 |
+
If > 0, the function will save accumulated embeddings to `save_file_path` every
|
| 6845 |
+
`save_every_num_batches` batches. A non-positive value disables periodic saving.
|
| 6846 |
+
save_file_path (str, default="saved_embeddings.npy"):
|
| 6847 |
+
File path used by `np.save` when periodic saving is enabled.
|
| 6848 |
+
device (Optional[torch.device], default=None):
|
| 6849 |
+
Device to run the model and tensors on. If None, the device of the model parameters is used.
|
| 6850 |
+
normalize (bool, default=False):
|
| 6851 |
+
If True, L2-normalize each embedding vector (done in float32 for numerical stability).
|
| 6852 |
+
pooling (str, default="auto"):
|
| 6853 |
+
Pooling strategy applied when model returns per-token embeddings:
|
| 6854 |
+
- "auto" or "mean": masked mean pooling.
|
| 6855 |
+
- "weighted_mean": weighted mean pooling by token type using `token_type_weights`.
|
| 6856 |
+
Any other value raises `ValueError`.
|
| 6857 |
+
token_type_weights (Optional[Tuple[float, float, float]], default=None):
|
| 6858 |
+
Per-token-type weights `(onset_w, duration_w, pitch_w)` used when `pooling="weighted_mean"`.
|
| 6859 |
+
If None, defaults to `(1.0, 1.0, 1.0)`. Token type ranges are inferred as:
|
| 6860 |
+
onset: token_id in [0, 127]
|
| 6861 |
+
duration:token_id in [128, 255]
|
| 6862 |
+
pitch: token_id in [256, 383]
|
| 6863 |
+
use_bfloat16 (bool, default=True):
|
| 6864 |
+
If True, attempts to use `torch.bfloat16` autocast for the device; falls back gracefully if not supported.
|
| 6865 |
+
return_dtype (str, default="float32"):
|
| 6866 |
+
Data type for returned embeddings: `"float32"` or `"float16"`. Internally embeddings are normalized
|
| 6867 |
+
in float32; casting to float16 happens just before collecting results if requested.
|
| 6868 |
+
return_numpy (bool, default=False):
|
| 6869 |
+
If True, the final result is returned as a NumPy array; otherwise a CPU `torch.Tensor` is returned.
|
| 6870 |
+
verbose (bool, default=True):
|
| 6871 |
+
If True, prints progress and short diagnostic messages via `tqdm`.
|
| 6872 |
+
show_progress_bar (bool, default=True)
|
| 6873 |
+
If True, displays tqdm progress bar.
|
| 6874 |
+
|
| 6875 |
+
Returns:
|
| 6876 |
+
Union[torch.Tensor, numpy.ndarray]:
|
| 6877 |
+
- If `return_numpy` is False: a CPU `torch.Tensor` of shape `(N, D)` and dtype `torch.float32`
|
| 6878 |
+
or `torch.float16` depending on `return_dtype`.
|
| 6879 |
+
- If `return_numpy` is True: a NumPy array of shape `(N, D)` and dtype `np.float32` or `np.float16`.
|
| 6880 |
+
`N` is the total number of input sequences and `D` is the embedding dimensionality produced by the model.
|
| 6881 |
+
|
| 6882 |
+
Raises:
|
| 6883 |
+
AssertionError:
|
| 6884 |
+
If `return_dtype` is not one of `"float32"` or `"float16"`.
|
| 6885 |
+
RuntimeError:
|
| 6886 |
+
If the model returns `None` for embeddings (indicates incorrect forward flags or model behavior).
|
| 6887 |
+
ValueError:
|
| 6888 |
+
If the model returns an embedding tensor with unexpected dimensionality or if `pooling` is unsupported.
|
| 6889 |
+
|
| 6890 |
+
Notes:
|
| 6891 |
+
- The function uses `pad_and_mask` to produce `x` (LongTensor) and `mask` (BoolTensor). Padding tokens
|
| 6892 |
+
are ignored by pooling operations.
|
| 6893 |
+
- When `pooling="weighted_mean"`, if `token_ids` are not available or the model returns a 2-D tensor,
|
| 6894 |
+
the function falls back to masked mean pooling.
|
| 6895 |
+
- Periodic saving concatenates all embeddings collected so far and writes them with `np.save`. Save
|
| 6896 |
+
failures are caught and reported when `verbose=True` but do not abort processing.
|
| 6897 |
+
- The function runs the model under `torch.no_grad()` and sets `model.eval()`; it will move the model
|
| 6898 |
+
to `device` if provided.
|
| 6899 |
+
- For reproducible numeric behavior across devices, ensure the model and device support the requested
|
| 6900 |
+
autocast dtype (bfloat16) and that any randomness is controlled externally.
|
| 6901 |
+
|
| 6902 |
+
Example:
|
| 6903 |
+
>>> # simple usage
|
| 6904 |
+
>>> embs = get_embeddings_bf16(model, sequences, seq_len=1024, batch_size=32, pooling="mean",
|
| 6905 |
+
... normalize=True, return_dtype="float32", return_numpy=False)
|
| 6906 |
+
"""
|
| 6907 |
+
|
| 6908 |
+
assert return_dtype in ("float32", "float16"), "return_dtype must be 'float32' or 'float16'"
|
| 6909 |
+
|
| 6910 |
+
model_device = next(model.parameters()).device if device is None else device
|
| 6911 |
+
model.to(model_device)
|
| 6912 |
+
model.eval()
|
| 6913 |
+
|
| 6914 |
+
all_embs: List[Tensor] = []
|
| 6915 |
+
total_batches = math.ceil(len(sequences) / batch_size) if batch_size > 0 else 0
|
| 6916 |
+
|
| 6917 |
+
if verbose:
|
| 6918 |
+
tqdm.write(
|
| 6919 |
+
f"[get_embeddings_bf16] sequences={len(sequences)}, batch_size={batch_size}, "
|
| 6920 |
+
f"batches={total_batches}, device={model_device}, seq_len={seq_len}, pooling={pooling}"
|
| 6921 |
+
)
|
| 6922 |
+
|
| 6923 |
+
# Prepare autocast context using torch.amp.autocast
|
| 6924 |
+
autocast_ctx = None
|
| 6925 |
+
if use_bfloat16:
|
| 6926 |
+
try:
|
| 6927 |
+
autocast_ctx = torch.amp.autocast(device_type=model_device.type, dtype=torch.bfloat16)
|
| 6928 |
+
except Exception:
|
| 6929 |
+
try:
|
| 6930 |
+
autocast_ctx = torch.amp.autocast(device_type=model_device.type)
|
| 6931 |
+
except Exception:
|
| 6932 |
+
autocast_ctx = None
|
| 6933 |
+
else:
|
| 6934 |
+
try:
|
| 6935 |
+
autocast_ctx = torch.amp.autocast(device_type=model_device.type)
|
| 6936 |
+
except Exception:
|
| 6937 |
+
autocast_ctx = None
|
| 6938 |
+
|
| 6939 |
+
with torch.inference_mode():
|
| 6940 |
+
batch_iter = range(0, len(sequences), batch_size)
|
| 6941 |
+
pbar = tqdm(batch_iter, disable=not show_progress_bar, total=total_batches, desc="Embedding batches")
|
| 6942 |
+
for batch_idx, i in enumerate(pbar):
|
| 6943 |
+
batch_seqs = sequences[i : i + batch_size]
|
| 6944 |
+
x, mask = pad_and_mask(batch_seqs, pad_idx=seq_pad_idx, seq_len=seq_len, device=model_device, verbose=verbose)
|
| 6945 |
+
# x: (B, L) LongTensor token ids, mask: (B, L) boolean
|
| 6946 |
+
|
| 6947 |
+
# Run forward under autocast if available
|
| 6948 |
+
if autocast_ctx is not None:
|
| 6949 |
+
with autocast_ctx:
|
| 6950 |
+
out = model(x, return_embeddings=True, mask=mask)
|
| 6951 |
+
else:
|
| 6952 |
+
out = model(x, return_embeddings=True, mask=mask)
|
| 6953 |
+
|
| 6954 |
+
if out is None:
|
| 6955 |
+
raise RuntimeError("model returned None for embeddings. Check forward flags.")
|
| 6956 |
+
|
| 6957 |
+
# Handle shapes
|
| 6958 |
+
if out.dim() == 2:
|
| 6959 |
+
# already pooled: (B, D)
|
| 6960 |
+
emb = out
|
| 6961 |
+
elif out.dim() == 3:
|
| 6962 |
+
# per-token embeddings: (B, L, D)
|
| 6963 |
+
if pooling in ("mean", "auto"):
|
| 6964 |
+
emb = masked_mean_pool(out, mask, dim=1, verbose=verbose)
|
| 6965 |
+
elif pooling == "weighted_mean":
|
| 6966 |
+
# Use token ids to compute per-token weights; fallback to mean if token ids missing
|
| 6967 |
+
emb = masked_weighted_mean_pool(out, mask, token_ids=x, token_type_weights=token_type_weights, dim=1, verbose=verbose)
|
| 6968 |
+
else:
|
| 6969 |
+
raise ValueError(f"unsupported pooling: {pooling}")
|
| 6970 |
+
else:
|
| 6971 |
+
raise ValueError(f"unexpected embedding tensor shape: {out.shape}")
|
| 6972 |
+
|
| 6973 |
+
# Ensure embeddings are float32 for stable normalization/indexing
|
| 6974 |
+
if emb.dtype != torch.float32:
|
| 6975 |
+
emb = emb.float()
|
| 6976 |
+
|
| 6977 |
+
# L2 normalize in float32
|
| 6978 |
+
if normalize:
|
| 6979 |
+
emb = F.normalize(emb, p=2, dim=-1)
|
| 6980 |
+
|
| 6981 |
+
# Optionally cast to float16 for return/storage
|
| 6982 |
+
if return_dtype == "float16":
|
| 6983 |
+
emb = emb.half()
|
| 6984 |
+
|
| 6985 |
+
all_embs.append(emb.cpu())
|
| 6986 |
+
|
| 6987 |
+
# Update progress bar postfix with shapes and dtype
|
| 6988 |
+
if verbose:
|
| 6989 |
+
pbar.set_postfix({"batch": batch_idx + 1, "emb_shape": f"{emb.shape}", "dtype": str(emb.dtype)})
|
| 6990 |
+
|
| 6991 |
+
# Save intermediate results periodically
|
| 6992 |
+
if save_every_num_batches > 0:
|
| 6993 |
+
# compute 0-based batch number
|
| 6994 |
+
bnum = batch_idx
|
| 6995 |
+
if (bnum + 1) % save_every_num_batches == 0:
|
| 6996 |
+
try:
|
| 6997 |
+
concatenated = torch.cat(all_embs, dim=0).numpy()
|
| 6998 |
+
np.save(save_file_path, concatenated)
|
| 6999 |
+
if verbose:
|
| 7000 |
+
tqdm.write(f"[get_embeddings_bf16] saved {concatenated.shape[0]} embeddings to {save_file_path}")
|
| 7001 |
+
except Exception as e:
|
| 7002 |
+
# Do not crash the whole run for a save failure; report if verbose
|
| 7003 |
+
if verbose:
|
| 7004 |
+
tqdm.write(f"[get_embeddings_bf16] warning: failed to save embeddings: {e}")
|
| 7005 |
+
|
| 7006 |
+
if len(all_embs) == 0:
|
| 7007 |
+
# return empty tensor/array with shape (0, 0)
|
| 7008 |
+
empty = torch.empty((0, 0), dtype=(torch.float16 if return_dtype == "float16" else torch.float32))
|
| 7009 |
+
if verbose:
|
| 7010 |
+
tqdm.write("[get_embeddings_bf16] no embeddings were produced; returning empty tensor")
|
| 7011 |
+
return empty.numpy() if return_numpy else empty
|
| 7012 |
+
|
| 7013 |
+
result = torch.cat(all_embs, dim=0) # (N, D) on CPU
|
| 7014 |
+
|
| 7015 |
+
if verbose:
|
| 7016 |
+
tqdm.write(f"[get_embeddings_bf16] finished: total_embeddings={result.shape[0]}, dim={result.shape[1]}, dtype={result.dtype}")
|
| 7017 |
+
|
| 7018 |
+
if return_numpy:
|
| 7019 |
+
return result.numpy()
|
| 7020 |
+
|
| 7021 |
+
return result
|
| 7022 |
+
|
| 7023 |
+
###################################################################################
|
| 7024 |
+
|
| 7025 |
+
def masked_mean_pool(
|
| 7026 |
+
token_embeddings: Tensor,
|
| 7027 |
+
mask: Tensor,
|
| 7028 |
+
dim: int = 1,
|
| 7029 |
+
eps: float = 1e-9,
|
| 7030 |
+
verbose: bool = True,
|
| 7031 |
+
) -> Tensor:
|
| 7032 |
+
|
| 7033 |
+
"""
|
| 7034 |
+
Compute a masked mean pooling over a specified dimension.
|
| 7035 |
+
|
| 7036 |
+
This function computes the mean of `token_embeddings` along `dim`, ignoring
|
| 7037 |
+
positions where `mask` is False. The mask is cast to the same dtype as the
|
| 7038 |
+
embeddings to allow safe multiplication. A small epsilon is used to avoid
|
| 7039 |
+
division by zero for sequences that are entirely masked out.
|
| 7040 |
+
|
| 7041 |
+
Args:
|
| 7042 |
+
token_embeddings: Tensor of shape (B, L, D) or similar where `dim` indexes
|
| 7043 |
+
the sequence length. Embeddings dtype can be float16/float32/bfloat16.
|
| 7044 |
+
mask: Boolean tensor of shape broadcastable to the sequence dimension
|
| 7045 |
+
(e.g., (B, L)). True indicates valid tokens; False indicates padding.
|
| 7046 |
+
dim: Dimension along which to pool (default: 1, the sequence length).
|
| 7047 |
+
eps: Small value to avoid division by zero when a row has zero valid tokens.
|
| 7048 |
+
verbose: If True, prints a short summary about the pooling operation.
|
| 7049 |
+
|
| 7050 |
+
Returns:
|
| 7051 |
+
Tensor of pooled embeddings with the sequence dimension removed, typically
|
| 7052 |
+
shape (B, D). The returned dtype matches `token_embeddings.dtype`.
|
| 7053 |
+
"""
|
| 7054 |
+
|
| 7055 |
+
mask_f = mask.to(token_embeddings.dtype) # (B, L)
|
| 7056 |
+
summed = (token_embeddings * mask_f.unsqueeze(-1)).sum(dim=dim) # (B, D)
|
| 7057 |
+
counts = mask_f.sum(dim=dim).clamp_min(eps).unsqueeze(-1) # (B, 1)
|
| 7058 |
+
pooled = summed / counts # (B, D)
|
| 7059 |
+
|
| 7060 |
+
if verbose:
|
| 7061 |
+
# Use tqdm.write so it doesn't interfere with progress bars
|
| 7062 |
+
valid_counts = counts.squeeze(-1)
|
| 7063 |
+
tqdm.write(
|
| 7064 |
+
f"[masked_mean_pool] pooled shape={pooled.shape}, "
|
| 7065 |
+
f"counts min={valid_counts.min().item():.3f}, max={valid_counts.max().item():.3f}"
|
| 7066 |
+
)
|
| 7067 |
+
|
| 7068 |
+
return pooled
|
| 7069 |
+
|
| 7070 |
+
###################################################################################
|
| 7071 |
+
|
| 7072 |
+
def masked_weighted_mean_pool(
|
| 7073 |
+
token_embs: Tensor,
|
| 7074 |
+
valid_mask: Tensor,
|
| 7075 |
+
token_ids: Optional[Tensor] = None,
|
| 7076 |
+
token_type_weights: Optional[Tuple[float, float, float]] = None,
|
| 7077 |
+
dim: int = 1,
|
| 7078 |
+
verbose: bool = False,
|
| 7079 |
+
) -> Tensor:
|
| 7080 |
+
|
| 7081 |
+
"""
|
| 7082 |
+
Weighted mean pooling across tokens. If token_ids is provided, token types are
|
| 7083 |
+
inferred using the same ranges as the reference code:
|
| 7084 |
+
- onset: token_id in [0, 127]
|
| 7085 |
+
- duration:token_id in [128, 255]
|
| 7086 |
+
- pitch: token_id in [256, 383]
|
| 7087 |
+
token_type_weights: (onset_w, duration_w, pitch_w). If None, defaults to (1.0,1.0,1.0)
|
| 7088 |
+
The function multiplies each token embedding by its scalar weight and divides
|
| 7089 |
+
by the sum of weights for valid tokens per sequence.
|
| 7090 |
+
"""
|
| 7091 |
+
|
| 7092 |
+
B, L, D = token_embs.shape
|
| 7093 |
+
device = token_embs.device
|
| 7094 |
+
dtype = token_embs.dtype
|
| 7095 |
+
|
| 7096 |
+
if token_ids is None:
|
| 7097 |
+
# No token-level ids available: fallback to simple masked mean
|
| 7098 |
+
if verbose:
|
| 7099 |
+
tqdm.write("[masked_weighted_mean_pool] token_ids is None, falling back to masked_mean_pool")
|
| 7100 |
+
return masked_mean_pool(token_embs, valid_mask, dim=dim, verbose=verbose)
|
| 7101 |
+
|
| 7102 |
+
# Default weights
|
| 7103 |
+
if token_type_weights is None:
|
| 7104 |
+
onset_w, duration_w, pitch_w = 1.0, 1.0, 1.0
|
| 7105 |
+
else:
|
| 7106 |
+
onset_w, duration_w, pitch_w = token_type_weights
|
| 7107 |
+
|
| 7108 |
+
# Build per-type boolean masks based on token id values (same ranges as reference)
|
| 7109 |
+
onset_mask = (token_ids >= 0) & (token_ids < 128)
|
| 7110 |
+
duration_mask = (token_ids >= 128) & (token_ids < 256)
|
| 7111 |
+
pitch_mask = (token_ids >= 256) & (token_ids < 384)
|
| 7112 |
+
|
| 7113 |
+
# Combine with valid_mask to ignore padding positions
|
| 7114 |
+
onset_mask = onset_mask & valid_mask
|
| 7115 |
+
duration_mask = duration_mask & valid_mask
|
| 7116 |
+
pitch_mask = pitch_mask & valid_mask
|
| 7117 |
+
|
| 7118 |
+
# Build per-token scalar weight tensor (B, L)
|
| 7119 |
+
w = torch.ones((B, L), device=device, dtype=dtype)
|
| 7120 |
+
if onset_w != 1.0:
|
| 7121 |
+
w = torch.where(onset_mask, torch.tensor(onset_w, device=device, dtype=dtype), w)
|
| 7122 |
+
if duration_w != 1.0:
|
| 7123 |
+
w = torch.where(duration_mask, torch.tensor(duration_w, device=device, dtype=dtype), w)
|
| 7124 |
+
if pitch_w != 1.0:
|
| 7125 |
+
w = torch.where(pitch_mask, torch.tensor(pitch_w, device=device, dtype=dtype), w)
|
| 7126 |
+
|
| 7127 |
+
# Zero out weights for padding positions
|
| 7128 |
+
valid_mask_f = valid_mask.to(dtype) # (B, L)
|
| 7129 |
+
w = w * valid_mask_f # (B, L)
|
| 7130 |
+
|
| 7131 |
+
# Weighted sum and normalization
|
| 7132 |
+
denom = w.sum(dim=1, keepdim=True).clamp(min=1e-6) # (B, 1)
|
| 7133 |
+
w_exp = w.unsqueeze(-1) # (B, L, 1)
|
| 7134 |
+
summed = (token_embs * w_exp).sum(dim=dim) # (B, D)
|
| 7135 |
+
pooled = summed / denom # (B, D)
|
| 7136 |
+
|
| 7137 |
+
return pooled
|
| 7138 |
+
|
| 7139 |
+
###################################################################################
|
| 7140 |
+
|
| 7141 |
+
def pad_and_mask(
|
| 7142 |
+
sequences: List[List[int]],
|
| 7143 |
+
pad_idx: int = 385,
|
| 7144 |
+
seq_len: Optional[int] = None,
|
| 7145 |
+
device: Optional[torch.device] = None,
|
| 7146 |
+
verbose: bool = False,
|
| 7147 |
+
) -> Tuple[Tensor, Tensor]:
|
| 7148 |
+
|
| 7149 |
+
"""
|
| 7150 |
+
Pad and create a boolean mask for a batch of integer token sequences.
|
| 7151 |
+
|
| 7152 |
+
This utility converts a list of variable-length integer sequences into a
|
| 7153 |
+
padded LongTensor and a corresponding boolean mask indicating valid token
|
| 7154 |
+
positions. Sequences longer than `seq_len` are truncated. If `seq_len` is
|
| 7155 |
+
None, the function uses the maximum sequence length in the batch.
|
| 7156 |
+
|
| 7157 |
+
Args:
|
| 7158 |
+
sequences: List of token id sequences (each a list of ints).
|
| 7159 |
+
pad_idx: Integer token id used for padding positions (default: 385).
|
| 7160 |
+
seq_len: Optional target sequence length. If provided, sequences are
|
| 7161 |
+
truncated or padded to this length. If None, the maximum length in
|
| 7162 |
+
`sequences` is used.
|
| 7163 |
+
device: Optional torch.device where the returned tensors will be placed.
|
| 7164 |
+
If None, tensors are created on the default device.
|
| 7165 |
+
verbose: If True, shows a small progress bar while processing sequences
|
| 7166 |
+
and prints a summary.
|
| 7167 |
+
|
| 7168 |
+
Returns:
|
| 7169 |
+
A tuple (x, mask):
|
| 7170 |
+
- x: LongTensor of shape (B, T) containing padded token ids.
|
| 7171 |
+
- mask: BoolTensor of shape (B, T) where True indicates a real token.
|
| 7172 |
+
"""
|
| 7173 |
+
|
| 7174 |
+
# Fast path for empty batch
|
| 7175 |
+
if not sequences:
|
| 7176 |
+
empty = torch.empty((0, 0), dtype=torch.long, device=device)
|
| 7177 |
+
empty_mask = torch.empty((0, 0), dtype=torch.bool, device=device)
|
| 7178 |
+
return empty, empty_mask
|
| 7179 |
+
|
| 7180 |
+
# Compute lengths and the batch maximum length
|
| 7181 |
+
lengths = [len(s) for s in sequences]
|
| 7182 |
+
batch_max = max(lengths)
|
| 7183 |
+
|
| 7184 |
+
# If seq_len is given, only use it to cap lengths; but if the batch max is smaller,
|
| 7185 |
+
# use the smaller value to avoid extra allocation/work.
|
| 7186 |
+
if seq_len is None:
|
| 7187 |
+
target_len = batch_max
|
| 7188 |
+
else:
|
| 7189 |
+
target_len = min(seq_len, batch_max)
|
| 7190 |
+
|
| 7191 |
+
b = len(sequences)
|
| 7192 |
+
if target_len == 0:
|
| 7193 |
+
x = torch.full((b, 0), pad_idx, dtype=torch.long, device=device)
|
| 7194 |
+
mask = torch.zeros((b, 0), dtype=torch.bool, device=device)
|
| 7195 |
+
return x, mask
|
| 7196 |
+
|
| 7197 |
+
x = torch.full((b, target_len), pad_idx, dtype=torch.long, device=device)
|
| 7198 |
+
mask = torch.zeros((b, target_len), dtype=torch.bool, device=device)
|
| 7199 |
+
|
| 7200 |
+
# iterate with optional progress display
|
| 7201 |
+
iterator = enumerate(sequences)
|
| 7202 |
+
if verbose:
|
| 7203 |
+
iterator = enumerate(tqdm(sequences, disable=not verbose, desc="Pad & mask"))
|
| 7204 |
+
|
| 7205 |
+
for i, seq in iterator:
|
| 7206 |
+
if not seq:
|
| 7207 |
+
continue
|
| 7208 |
+
# Only truncate if seq is longer than the chosen target_len
|
| 7209 |
+
L = len(seq)
|
| 7210 |
+
if L > target_len:
|
| 7211 |
+
L = target_len
|
| 7212 |
+
# slice once to avoid creating a larger tensor then slicing
|
| 7213 |
+
seq_slice = seq[:L]
|
| 7214 |
+
seq_tensor = torch.tensor(seq_slice, dtype=torch.long, device=device)
|
| 7215 |
+
else:
|
| 7216 |
+
seq_tensor = torch.tensor(seq, dtype=torch.long, device=device)
|
| 7217 |
+
|
| 7218 |
+
x[i, :L] = seq_tensor[:L]
|
| 7219 |
+
mask[i, :L] = True
|
| 7220 |
+
|
| 7221 |
+
if verbose:
|
| 7222 |
+
tqdm.write(
|
| 7223 |
+
f"[pad_and_mask] batch_size={b}, target_len={target_len}, "
|
| 7224 |
+
f"min_len={min(lengths)}, max_len={max(lengths)}"
|
| 7225 |
+
)
|
| 7226 |
+
|
| 7227 |
+
return x, mask
|
| 7228 |
+
|
| 7229 |
+
#=================================================================================================================================
|
| 7230 |
+
# Embeddings similarity comparison functions
|
| 7231 |
+
#=================================================================================================================================
|
| 7232 |
+
|
| 7233 |
+
import torch
|
| 7234 |
+
import torch.nn.functional as F
|
| 7235 |
+
from tqdm import tqdm
|
| 7236 |
+
from typing import Optional, Union, Tuple
|
| 7237 |
+
|
| 7238 |
+
def topk_cosine_neighbors(embeddings: torch.Tensor,
|
| 7239 |
+
k: int = 10,
|
| 7240 |
+
key_embeddings: Optional[torch.Tensor] = None,
|
| 7241 |
+
row_batch: Optional[int] = None,
|
| 7242 |
+
col_batch: Optional[int] = None,
|
| 7243 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 7244 |
+
normalize: bool = True,
|
| 7245 |
+
dtype: Optional[torch.dtype] = None,
|
| 7246 |
+
show_progress: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 7247 |
+
|
| 7248 |
+
"""
|
| 7249 |
+
For each query embedding, find the indices and similarities of its top-k neighbors
|
| 7250 |
+
from a set of key embeddings, sorted by descending similarity.
|
| 7251 |
+
|
| 7252 |
+
Supports both self-similarity (single array, excludes self) and pairwise
|
| 7253 |
+
retrieval (two arrays, no exclusion).
|
| 7254 |
+
|
| 7255 |
+
Optimized for maximum speed and memory efficiency across CPU, CUDA, and MPS.
|
| 7256 |
+
Uses a streaming batched approach to handle datasets larger than GPU memory.
|
| 7257 |
+
|
| 7258 |
+
Args:
|
| 7259 |
+
embeddings (torch.Tensor): Query embeddings, shape (N_q, D).
|
| 7260 |
+
k (int): How many neighbors to return.
|
| 7261 |
+
key_embeddings (torch.Tensor, optional): Database/Key embeddings, shape (N_k, D).
|
| 7262 |
+
If None, defaults to 'embeddings' (self-search).
|
| 7263 |
+
row_batch (int, optional): Number of query rows to process at once. Auto-tuned if None.
|
| 7264 |
+
col_batch (int, optional): Number of key columns to process at once. Auto-tuned if None.
|
| 7265 |
+
device (str or torch.device, optional): Target device. If None, uses embeddings.device.
|
| 7266 |
+
normalize (bool): If True, L2-normalize embeddings. Skip if already normalized.
|
| 7267 |
+
dtype (torch.dtype, optional): Compute dtype (e.g., torch.float16, torch.bfloat16).
|
| 7268 |
+
If None, uses embeddings.dtype.
|
| 7269 |
+
show_progress (bool): Show tqdm progress bar.
|
| 7270 |
+
|
| 7271 |
+
Returns:
|
| 7272 |
+
top_idx (torch.Tensor): shape (N_q, k), int32 indices of nearest neighbors (indices into key_embeddings).
|
| 7273 |
+
top_sim (torch.Tensor): shape (N_q, k), float32 cosine similarities.
|
| 7274 |
+
"""
|
| 7275 |
+
|
| 7276 |
+
# 1. Determine Search Mode (Self vs. Pairwise)
|
| 7277 |
+
is_self_search = (key_embeddings is None)
|
| 7278 |
+
if is_self_search:
|
| 7279 |
+
key_embeddings = embeddings
|
| 7280 |
+
|
| 7281 |
+
# 2. Device & Dtype Setup
|
| 7282 |
+
if device is None:
|
| 7283 |
+
device = embeddings.device
|
| 7284 |
+
else:
|
| 7285 |
+
device = torch.device(device)
|
| 7286 |
+
|
| 7287 |
+
# Determine compute dtype
|
| 7288 |
+
if dtype is None:
|
| 7289 |
+
dtype = embeddings.dtype
|
| 7290 |
+
else:
|
| 7291 |
+
assert dtype.is_floating_point, "dtype must be a floating point type"
|
| 7292 |
+
|
| 7293 |
+
# Move and cast embeddings
|
| 7294 |
+
# Ensure contiguous for efficient matmul
|
| 7295 |
+
query_embeddings = embeddings.to(device=device, dtype=dtype).contiguous()
|
| 7296 |
+
key_embeddings = key_embeddings.to(device=device, dtype=dtype).contiguous()
|
| 7297 |
+
|
| 7298 |
+
N_q, D = query_embeddings.shape
|
| 7299 |
+
N_k, D_k = key_embeddings.shape
|
| 7300 |
+
|
| 7301 |
+
if D != D_k:
|
| 7302 |
+
raise ValueError(f"Query and Key embeddings must have same dimension. Got {D} and {D_k}")
|
| 7303 |
+
|
| 7304 |
+
# Validation
|
| 7305 |
+
if k < 1:
|
| 7306 |
+
raise ValueError(f"k must be >= 1; got {k}")
|
| 7307 |
+
|
| 7308 |
+
if is_self_search:
|
| 7309 |
+
if k >= N_q:
|
| 7310 |
+
raise ValueError(f"For self-search, k must be < N (to exclude self). Got N={N_q}, k={k}")
|
| 7311 |
+
else:
|
| 7312 |
+
if k > N_k:
|
| 7313 |
+
raise ValueError(f"For pairwise search, k must be <= N_k. Got N_k={N_k}, k={k}")
|
| 7314 |
+
|
| 7315 |
+
# 3. Auto-tune batch sizes based on device and memory
|
| 7316 |
+
# Heuristics adjusted for potentially different N_q and N_k
|
| 7317 |
+
if row_batch is None:
|
| 7318 |
+
if device.type == 'cuda':
|
| 7319 |
+
row_batch = 16384
|
| 7320 |
+
elif device.type == 'mps':
|
| 7321 |
+
row_batch = 8192
|
| 7322 |
+
else:
|
| 7323 |
+
row_batch = 4096 # CPU
|
| 7324 |
+
|
| 7325 |
+
if col_batch is None:
|
| 7326 |
+
if device.type == 'cuda':
|
| 7327 |
+
col_batch = 16384
|
| 7328 |
+
elif device.type == 'mps':
|
| 7329 |
+
col_batch = 8192
|
| 7330 |
+
else:
|
| 7331 |
+
col_batch = 4096 # CPU
|
| 7332 |
+
|
| 7333 |
+
# Clamp batch sizes to actual dimensions
|
| 7334 |
+
row_batch = min(row_batch, N_q)
|
| 7335 |
+
col_batch = min(col_batch, N_k)
|
| 7336 |
+
|
| 7337 |
+
# 4. Optional Normalization
|
| 7338 |
+
if normalize:
|
| 7339 |
+
# Normalize in-place if possible, or reassign
|
| 7340 |
+
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
|
| 7341 |
+
# Only normalize keys if they are distinct from queries to avoid redundant work
|
| 7342 |
+
# in self-search case (already normalized above)
|
| 7343 |
+
if not is_self_search:
|
| 7344 |
+
key_embeddings = F.normalize(key_embeddings, p=2, dim=1)
|
| 7345 |
+
|
| 7346 |
+
# 5. Initialize Result Tensors (always float32 for precision in output)
|
| 7347 |
+
top_sim = torch.empty((N_q, k), dtype=torch.float32, device=device)
|
| 7348 |
+
top_idx = torch.empty((N_q, k), dtype=torch.int, device=device)
|
| 7349 |
+
|
| 7350 |
+
# Pre-allocate reusable buffers for inner loop (memory efficiency)
|
| 7351 |
+
# Buffers for top-k merge (size 2k)
|
| 7352 |
+
merge_sim_buffer = torch.empty((row_batch, 2 * k), dtype=dtype, device=device)
|
| 7353 |
+
merge_idx_buffer = torch.empty((row_batch, 2 * k), dtype=torch.int, device=device)
|
| 7354 |
+
|
| 7355 |
+
# Buffer for column batch similarities
|
| 7356 |
+
sim_buffer = torch.empty((row_batch, col_batch), dtype=dtype, device=device)
|
| 7357 |
+
|
| 7358 |
+
# Value for masking (minimum possible float for the dtype)
|
| 7359 |
+
min_val = -torch.finfo(dtype).max
|
| 7360 |
+
|
| 7361 |
+
# 6. Inference Context
|
| 7362 |
+
with torch.no_grad():
|
| 7363 |
+
iterator = range(0, N_q, row_batch)
|
| 7364 |
+
if show_progress:
|
| 7365 |
+
desc = "Query Batches" if not is_self_search else "Row Batches"
|
| 7366 |
+
iterator = tqdm(iterator, desc=desc, leave=True)
|
| 7367 |
+
|
| 7368 |
+
for i in iterator:
|
| 7369 |
+
i_end = min(i + row_batch, N_q)
|
| 7370 |
+
rb = i_end - i
|
| 7371 |
+
|
| 7372 |
+
rows = query_embeddings[i:i_end] # (rb, D)
|
| 7373 |
+
|
| 7374 |
+
# Initialize current batch top-k
|
| 7375 |
+
# Use a tensor that persists across column batches for the current row batch
|
| 7376 |
+
curr_sim = torch.full((rb, k), min_val, dtype=dtype, device=device)
|
| 7377 |
+
curr_idx = torch.full((rb, k), -1, dtype=torch.int, device=device)
|
| 7378 |
+
|
| 7379 |
+
for j in range(0, N_k, col_batch):
|
| 7380 |
+
j_end = min(j + col_batch, N_k)
|
| 7381 |
+
cb = j_end - j
|
| 7382 |
+
|
| 7383 |
+
cols = key_embeddings[j:j_end] # (cb, D)
|
| 7384 |
+
|
| 7385 |
+
# Compute similarities in-place into buffer
|
| 7386 |
+
# sim_block shape: (rb, cb)
|
| 7387 |
+
sim_block = sim_buffer[:rb, :cb]
|
| 7388 |
+
torch.matmul(rows, cols.T, out=sim_block)
|
| 7389 |
+
|
| 7390 |
+
# Mask self-similarity ONLY if self-search
|
| 7391 |
+
if is_self_search:
|
| 7392 |
+
offset = i - j
|
| 7393 |
+
r_start = max(0, -offset)
|
| 7394 |
+
r_end = min(rb, cb - offset)
|
| 7395 |
+
|
| 7396 |
+
if r_start < r_end:
|
| 7397 |
+
# Vectorized masking of the diagonal
|
| 7398 |
+
r_range = torch.arange(r_start, r_end, dtype=torch.long, device=device)
|
| 7399 |
+
c_range = r_range + offset
|
| 7400 |
+
sim_block[r_range, c_range] = min_val
|
| 7401 |
+
|
| 7402 |
+
# Top-k in block
|
| 7403 |
+
if cb >= k:
|
| 7404 |
+
blk_s, blk_p = torch.topk(sim_block, k, dim=1, largest=True, sorted=True)
|
| 7405 |
+
blk_i = blk_p + j
|
| 7406 |
+
else:
|
| 7407 |
+
# Pad block to k if remaining keys are fewer than k
|
| 7408 |
+
pad_size = k - cb
|
| 7409 |
+
pad_vals = torch.full((rb, pad_size), min_val, dtype=dtype, device=device)
|
| 7410 |
+
sims_padded = torch.cat([sim_block, pad_vals], dim=1)
|
| 7411 |
+
blk_s, blk_p = torch.topk(sims_padded, k, dim=1, largest=True, sorted=True)
|
| 7412 |
+
blk_i = blk_p + j
|
| 7413 |
+
# Invalidate padded indices
|
| 7414 |
+
blk_i[blk_s == min_val] = -1
|
| 7415 |
+
|
| 7416 |
+
# Merge with current best
|
| 7417 |
+
# Layout: [curr_sim (k), blk_s (k)] -> topk(2k) -> keep k
|
| 7418 |
+
merge_sim_buffer[:rb, :k] = curr_sim
|
| 7419 |
+
merge_sim_buffer[:rb, k:2*k] = blk_s
|
| 7420 |
+
merge_idx_buffer[:rb, :k] = curr_idx
|
| 7421 |
+
merge_idx_buffer[:rb, k:2*k] = blk_i
|
| 7422 |
+
|
| 7423 |
+
curr_sim, top_p = torch.topk(merge_sim_buffer[:rb, :2*k], k, dim=1, largest=True, sorted=True)
|
| 7424 |
+
curr_idx = torch.gather(merge_idx_buffer[:rb, :2*k], dim=1, index=top_p)
|
| 7425 |
+
|
| 7426 |
+
# Write results (convert to float32 for consistency)
|
| 7427 |
+
top_sim[i:i_end] = curr_sim.to(torch.float32)
|
| 7428 |
+
top_idx[i:i_end] = curr_idx
|
| 7429 |
+
|
| 7430 |
+
# 7. Post-processing return format
|
| 7431 |
+
if k == 1:
|
| 7432 |
+
return top_idx.view(-1), top_sim.view(-1)
|
| 7433 |
+
|
| 7434 |
+
return top_idx, top_sim
|
| 7435 |
+
|
| 7436 |
+
#=================================================================================================================================
|
| 7437 |
+
# Embeddings visualization functions
|
| 7438 |
+
#=================================================================================================================================
|
| 7439 |
+
|
| 7440 |
+
try:
|
| 7441 |
+
import numpy as np
|
| 7442 |
+
import matplotlib.pyplot as plt
|
| 7443 |
+
from sklearn.metrics import pairwise_distances
|
| 7444 |
+
|
| 7445 |
+
except:
|
| 7446 |
+
pass
|
| 7447 |
+
|
| 7448 |
+
def plot_emb_cosine_similarity(embeddings,
|
| 7449 |
+
clip=2.0,
|
| 7450 |
+
gamma=0.55,
|
| 7451 |
+
cmap="inferno",
|
| 7452 |
+
figsize=(20, 20),
|
| 7453 |
+
dpi=300,
|
| 7454 |
+
output_fname='embeddings_similarity_plot.png',
|
| 7455 |
+
return_sims=False
|
| 7456 |
+
):
|
| 7457 |
+
|
| 7458 |
+
"""
|
| 7459 |
+
Produces a crisp, high-contrast cosine similarity heatmap.
|
| 7460 |
+
- clip: percentile clipping (1–5 recommended)
|
| 7461 |
+
- gamma: nonlinear contrast (0.4–0.8 recommended)
|
| 7462 |
+
|
| 7463 |
+
-----------
|
| 7464 |
+
Use Example
|
| 7465 |
+
-----------
|
| 7466 |
+
|
| 7467 |
+
tok_emb = model.net.token_emb.emb.weight.detach().cpu()
|
| 7468 |
+
|
| 7469 |
+
plot_cosine_similarity(tok_emb)
|
| 7470 |
+
"""
|
| 7471 |
+
|
| 7472 |
+
# 1. Compute cosine similarity (not distance)
|
| 7473 |
+
cos_dist = pairwise_distances(embeddings, metric="cosine")
|
| 7474 |
+
cos_sim = 1 - cos_dist
|
| 7475 |
+
|
| 7476 |
+
# 2. Gamma correction for contrast
|
| 7477 |
+
sim = np.sign(cos_sim) * (np.abs(cos_sim) ** gamma)
|
| 7478 |
+
|
| 7479 |
+
# 3. Percentile clipping to remove flat tails
|
| 7480 |
+
vmin, vmax = np.percentile(sim, [clip, 100 - clip])
|
| 7481 |
+
|
| 7482 |
+
# 4. Plot
|
| 7483 |
+
plt.figure(figsize=figsize, dpi=dpi)
|
| 7484 |
+
plt.imshow(sim, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
|
| 7485 |
+
plt.colorbar(fraction=0.046, pad=0.04)
|
| 7486 |
+
plt.title("Embeddings Pairwise Cosine Similarity")
|
| 7487 |
+
plt.xlabel("Embedding Index")
|
| 7488 |
+
plt.ylabel("Embeddings Index")
|
| 7489 |
+
plt.tight_layout()
|
| 7490 |
+
plt.savefig(output_fname)
|
| 7491 |
+
plt.show()
|
| 7492 |
+
|
| 7493 |
+
if return_sims:
|
| 7494 |
+
return sim
|
| 7495 |
+
|
| 7496 |
#=================================================================================================================================
|
| 7497 |
# This is the end of x_transformer_2_3_1 Python module
|
| 7498 |
#=================================================================================================================================
|