asigalov61 committed on
Commit
3302b00
·
verified ·
1 Parent(s): f7bfafe

Upload 2 files

Browse files
Files changed (2) hide show
  1. TMIDIX.py +386 -1
  2. x_transformer_2_3_1.py +1204 -2
TMIDIX.py CHANGED
@@ -51,7 +51,7 @@ r'''############################################################################
51
 
52
  ###################################################################################
53
 
54
- __version__ = "26.2.16"
55
 
56
  print('=' * 70)
57
  print('TMIDIX Python module')
@@ -16094,6 +16094,391 @@ def squash_monophonic_escore_notes_pitches(escore_notes,
16094
 
16095
  ###################################################################################
16096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16097
  print('Module loaded!')
16098
  print('=' * 70)
16099
  print('Enjoy! :)')
 
51
 
52
  ###################################################################################
53
 
54
+ __version__ = "26.2.17"
55
 
56
  print('=' * 70)
57
  print('TMIDIX Python module')
 
16094
 
16095
  ###################################################################################
16096
 
16097
def humanize_velocities_in_escore_notes(escore_notes):
    """
    Return a time-sorted copy of the notes with rule-based "humanized" velocities.

    Each note is read by position: index 1 is the onset time, index 3 the
    channel, index 4 the pitch, index 5 the velocity (overwritten here) and
    index 6 the instrument/patch number. Index 2 is compared against the
    estimated grid and is presumably the duration -- TODO confirm.

    The new velocity depends on the instrument family (drums, bass, keys,
    other melodic), the metric position of the onset on an estimated beat
    grid, a global arc over the whole piece, a per-phrase arc and a small
    Gaussian jitter. Because of the jitter the result is not deterministic
    unless the caller seeds the `random` module.

    Parameters
    ----------
    escore_notes : list
        Notes as indexable sequences with at least 7 fields (see above).

    Returns
    -------
    list
        New list of notes (each a fresh `list`), sorted by (time, pitch),
        with integer velocities. The input notes are not modified.
        An empty input yields an empty list.
    """

    if not escore_notes:
        return []

    # Copy every note so the caller's data is left untouched.
    notes = [list(n) for n in escore_notes]
    notes.sort(key=lambda x: (x[1], x[4]))  # Sort by time, then pitch

    # -----------------------------------------------------------
    # 1. GLOBAL ANALYSIS (Grid, Phrases, Arcs)
    # -----------------------------------------------------------

    onset_times = sorted({n[1] for n in notes})

    # --- Grid Estimation ---
    # The grid is the most frequent onset-to-onset gap (rounded to 10 ticks),
    # ignoring gaps of 960 ticks or more.
    estimated_grid = 120
    if len(onset_times) > 1:
        intervals = [b - a for a, b in zip(onset_times, onset_times[1:])]
        rounded_intervals = [round(i / 10) * 10 for i in intervals if 0 < i < 960]
        valid_intervals = [r for r in rounded_intervals if r > 0]
        if valid_intervals:
            # Count once instead of calling list.count() for every candidate.
            interval_counts = defaultdict(int)
            for r in valid_intervals:
                interval_counts[r] += 1
            estimated_grid = max(set(valid_intervals), key=interval_counts.__getitem__)

    # --- Beat Grid Calculation ---
    if estimated_grid <= 120:
        beat_grid = estimated_grid * 4
    elif estimated_grid <= 240:
        beat_grid = estimated_grid * 2
    else:
        beat_grid = estimated_grid

    half_beat = beat_grid // 2

    def metric_level(t):
        # 0 = on the beat grid, 1 = exactly half way, 2 = anywhere else.
        position = t % beat_grid
        if position == 0:
            return 0
        if position == half_beat:
            return 1
        return 2

    drum_metric_mods = (0, -2, -6)
    bass_metric_mods = (4, 1, -2)
    melodic_metric_mods = (6, 2, -3)

    # --- Phrase Detection ---
    # Phrases are runs of non-drum onsets separated by gaps longer than
    # three beats; if there are no non-drum onsets, all onsets are used.
    melodic_onsets = sorted({n[1] for n in notes if n[3] != 9})
    if not melodic_onsets:
        melodic_onsets = onset_times

    phrase_gap_threshold = beat_grid * 3

    phrases = []
    current_phrase = [melodic_onsets[0]]
    for prev_t, t in zip(melodic_onsets, melodic_onsets[1:]):
        if t - prev_t > phrase_gap_threshold:
            phrases.append(current_phrase)
            current_phrase = [t]
        else:
            current_phrase.append(t)
    phrases.append(current_phrase)

    # Phrases partition the onsets, so each onset maps to exactly one
    # (phrase start, phrase end) pair; this replaces a linear search per slice.
    phrase_bounds = {}
    for phrase in phrases:
        bounds = (phrase[0], phrase[-1])
        for t in phrase:
            phrase_bounds[t] = bounds

    def phrase_progress(t):
        # Position of t inside its phrase in [0, 1], or None when t is in no
        # phrase or the phrase has zero length.
        bounds = phrase_bounds.get(t)
        if bounds is None:
            return None
        p_len = bounds[1] - bounds[0]
        if p_len > 0:
            return (t - bounds[0]) / p_len
        return None

    # --- Global Arc Calculation ---
    first_onset = onset_times[0]
    total_duration = onset_times[-1] - first_onset
    global_progress_map = {}
    for t in onset_times:
        progress = (t - first_onset) / total_duration if total_duration > 0 else 0
        # cos over [-pi/2, +pi/2]: 0 at both ends, peaking at +6 mid-piece.
        global_progress_map[t] = math.cos((progress - 0.5) * math.pi) * 6

    # -----------------------------------------------------------
    # 2. INSTRUMENT SEPARATION & PROCESSING
    # -----------------------------------------------------------

    instrument_tracks = defaultdict(list)
    for n in notes:
        instrument_tracks[(n[3], n[6])].append(n)

    for (channel, inst_num), track_notes in instrument_tracks.items():

        # --- Instrument Classification ---
        is_drum = (channel == 9) or (inst_num == 128)
        is_bass = (32 <= inst_num <= 39)
        is_keys = (inst_num <= 7) or (16 <= inst_num <= 23)
        is_plucked = (24 <= inst_num <= 31)
        is_solo_string = (40 <= inst_num <= 47)
        is_ensemble = (48 <= inst_num <= 55)
        is_synth_lead = (80 <= inst_num <= 87)
        is_synth_pad = (88 <= inst_num <= 95)

        track_notes.sort(key=lambda x: (x[1], x[4]))

        # ==========================================================
        # LOGIC A: DRUMS (UNIFORM & SEPARATED)
        # ==========================================================
        if is_drum:
            for n in track_notes:
                t, pitch = n[1], n[4]

                if pitch in (35, 36):            # kick
                    target_vel = 115
                elif pitch in (38, 40):          # snare
                    target_vel = 112
                elif pitch in (42, 44, 46):      # hi-hat
                    target_vel = 90
                elif 41 <= pitch <= 50:          # remaining pitches in the tom range
                    target_vel = 105
                else:
                    target_vel = 100

                metric_mod = drum_metric_mods[metric_level(t)]

                # Every note's onset is a key of the map, so a direct lookup
                # equals the former nearest-onset scan.
                global_mod = global_progress_map.get(t, 0)

                final_vel = target_vel + metric_mod + global_mod + random.gauss(0, 2.0)
                n[5] = max(30, min(127, int(final_vel)))

            continue

        # ==========================================================
        # LOGIC B: BASS
        # ==========================================================
        if is_bass:
            last_vel = 100
            for n in track_notes:
                target = 100 + bass_metric_mods[metric_level(n[1])]

                # Exponential smoothing keeps the bass line even.
                smoothed = (target * 0.3) + (last_vel * 0.7)
                last_vel = smoothed

                final_vel = smoothed + random.gauss(0, 2)
                n[5] = max(30, min(115, int(final_vel)))

            continue

        # ==========================================================
        # LOGIC C: MELODIC INSTRUMENTS
        # ==========================================================

        # --- Role Detection (Base Velocity Target) ---
        base_vel_target = 90  # Default

        if is_solo_string or is_synth_lead:
            base_vel_target = 105  # Lead
        elif is_ensemble or is_synth_pad:
            base_vel_target = 80  # Pad/Backing
        elif is_plucked:
            # Average notes per onset: near-monophonic reads as a solo line.
            unique_onsets = set(nn[1] for nn in track_notes)
            avg_poly = len(track_notes) / len(unique_onsets) if unique_onsets else 1
            base_vel_target = 105 if avg_poly < 1.3 else 95
        elif is_keys:
            # A track holding more than 80% of all notes is treated as solo.
            if len(track_notes) > (len(notes) * 0.8):
                base_vel_target = 100
            else:
                base_vel_target = 90

        # Group the track by onset time. track_notes is sorted by
        # (time, pitch), so every slice is already in ascending pitch order.
        time_slices = defaultdict(list)
        for n in track_notes:
            time_slices[n[1]].append(n)
        sorted_times = sorted(time_slices.keys())

        # CASE 1: PIANO / KEYS
        if is_keys:
            last_vel_lh = base_vel_target
            last_vel_rh = base_vel_target
            prev_melody_pitch = None

            for t in sorted_times:
                slice_notes = time_slices[t]

                # Context Mods
                phrase_arc_mod = 0
                progress = phrase_progress(t)
                if progress is not None:
                    phrase_arc_mod = math.cos((1.0 - progress) * math.pi) * 0.5 + 0.5
                    phrase_arc_mod = (phrase_arc_mod - 0.5) * 10

                global_arc = global_progress_map.get(t, 0)
                metric_strength = melodic_metric_mods[metric_level(t)]

                # Hand Splitting Logic: split at the widest pitch gap.
                pitches = [n[4] for n in slice_notes]
                split_idx = 0
                max_gap = 0

                for i, (low, high) in enumerate(zip(pitches, pitches[1:]), start=1):
                    gap = high - low
                    if gap > max_gap:
                        max_gap = gap
                        split_idx = i

                # NOTE(review): indentation was lost in the source rendering;
                # this check is assumed to apply to single-note slices too
                # (they go to one hand by register) -- confirm.
                if max_gap < 12:
                    avg_pitch = sum(pitches) / len(pitches)
                    split_idx = 0 if avg_pitch >= 60 else len(pitches)

                lh_notes = slice_notes[:split_idx]
                rh_notes = slice_notes[split_idx:]

                # LH Processing
                if lh_notes:
                    target_lh = base_vel_target + metric_strength + global_arc + phrase_arc_mod - 3
                    smoothed_lh = (target_lh * 0.25) + (last_vel_lh * 0.75)
                    last_vel_lh = smoothed_lh

                    for i, n in enumerate(lh_notes):
                        offset = 0 if i == 0 else -5
                        if n[2] < estimated_grid / 2:
                            offset += 2
                        vel = smoothed_lh + offset + random.gauss(0, 2)
                        n[5] = max(30, min(115, int(vel)))

                # RH Processing
                if rh_notes:
                    target_rh = base_vel_target + metric_strength + global_arc + phrase_arc_mod + 3
                    smoothed_rh = (target_rh * 0.35) + (last_vel_rh * 0.65)
                    last_vel_rh = smoothed_rh

                    num_rh = len(rh_notes)
                    is_dense = num_rh >= 3

                    for i, n in enumerate(rh_notes):
                        pitch = n[4]

                        if i == num_rh - 1:  # Top note melody
                            offset = 12 if is_dense else 5
                            if prev_melody_pitch is not None:
                                diff = pitch - prev_melody_pitch
                                if diff > 4:
                                    offset += 6
                                elif diff > 0:
                                    offset += 3
                                elif diff < -4:
                                    offset -= 4
                            # NOTE(review): assumed to be independent of the
                            # previous-pitch check (nesting lost in the source
                            # rendering) -- confirm.
                            if pitch > 72:
                                offset += (pitch - 72) * 0.2
                            prev_melody_pitch = pitch
                        elif i == 0:
                            offset = -1
                        else:
                            offset = -8

                        if n[2] < estimated_grid / 2:
                            offset += 2
                        vel = smoothed_rh + offset + random.gauss(0, 2)
                        n[5] = max(35, min(120, int(vel)))

        # CASE 2: OTHER MELODIC (Generic Role-Based Logic)
        else:
            last_vel = base_vel_target
            prev_melody_pitch = None

            for t in sorted_times:
                slice_notes = time_slices[t]

                # Context
                phrase_arc_mod = 0
                progress = phrase_progress(t)
                if progress is not None:
                    phrase_arc_mod = math.sin(progress * math.pi) * 6

                global_arc = global_progress_map.get(t, 0)
                metric_strength = melodic_metric_mods[metric_level(t)]

                core_vel = base_vel_target + metric_strength + global_arc + phrase_arc_mod
                smoothed = (core_vel * 0.35) + (last_vel * 0.65)
                last_vel = smoothed

                num_notes = len(slice_notes)

                for i, n in enumerate(slice_notes):
                    pitch = n[4]

                    if i == num_notes - 1:
                        # Top voice is emphasized, more so on upward leaps.
                        offset = 10
                        if prev_melody_pitch is not None:
                            diff = pitch - prev_melody_pitch
                            if diff > 4:
                                offset += 5
                        prev_melody_pitch = pitch
                    elif i == 0:
                        offset = 0
                    else:
                        offset = -6

                    final_vel = smoothed + offset + random.gauss(0, 3)
                    n[5] = max(30, min(127, int(final_vel)))

    # -----------------------------------------------------------
    # 3. FINAL EXPRESSIVE SCALING
    # -----------------------------------------------------------
    # Widen the dynamic range of everything not on channel 9 by 10%
    # around velocity 95.
    center = 95
    for n in notes:
        if n[3] == 9:
            continue

        deviation = n[5] - center
        final_v = center + (deviation * 1.1)
        final_v += random.randint(-1, 1)

        n[5] = max(20, min(127, int(final_v)))

    return notes
16439
+
16440
+ ###################################################################################
16441
+
16442
def most_common_ordered_set(values, top_k):
    """
    Return the top_k most frequent values, in order of first appearance.

    Parameters
    ----------
    values : sequence of hashable items
        Items to rank by frequency.
    top_k : int
        How many of the most frequent distinct items to keep.

    Returns
    -------
    list
        The kept items, de-duplicated, ordered by where each first occurs
        in `values` (not by frequency).
    """

    freq = Counter(values)

    top_vals = {v for v, _ in freq.most_common(top_k)}

    seen = set()

    # set.add() returns None, so the `or` clause records v as seen
    # while keeping the condition true for first occurrences.
    return [v for v in values
            if v in top_vals and not (v in seen or seen.add(v))]
16457
+
16458
+ ###################################################################################
16459
+
16460
def escore_notes_velocities(escore_notes, chan_idx=3, vels_idx=5):
    """
    Summarize note velocities overall and per channel.

    Parameters
    ----------
    escore_notes : list
        Notes as indexable sequences.
    chan_idx : int
        Index of the channel field inside each note.
    vels_idx : int
        Index of the velocity field inside each note.

    Returns
    -------
    list
        Rows of [channel, min, average, max, span]. The first row uses
        channel -1 and covers all notes; the remaining rows cover one
        channel each, in ascending channel order. An empty input yields
        an empty list (previously this raised ZeroDivisionError).
    """

    if not escore_notes:
        return []

    def vel_stats(label, vels):
        # One summary row: [label, min, mean, max, max - min].
        lowest, highest = min(vels), max(vels)
        return [label, lowest, sum(vels) / len(vels), highest, highest - lowest]

    output_list = [vel_stats(-1, [e[vels_idx] for e in escore_notes])]

    # groupby only merges adjacent items, hence the sort on the same key.
    by_channel = sorted(escore_notes, key=lambda x: x[chan_idx])

    for cha, group in groupby(by_channel, key=lambda x: x[chan_idx]):
        output_list.append(vel_stats(cha, [e[vels_idx] for e in group]))

    return output_list
16479
+
16480
+ ###################################################################################
16481
+
16482
  print('Module loaded!')
16483
  print('=' * 70)
16484
  print('Enjoy! :)')
x_transformer_2_3_1.py CHANGED
@@ -4,7 +4,7 @@
4
  #
5
  # Partial x-transformers code With useful modifications as a stand-alone Python module
6
  #
7
- # Version 7.0
8
  #
9
  # Original source code courtesy of lucidrains
10
  # https://github.com/lucidrains/x-transformers
@@ -5190,7 +5190,9 @@ def build_cls_model(num_tokens=18819,
5190
  squeeze_out_last_dim = squeeze_out_last_dim,
5191
  attn_layers=Encoder(dim=dim,
5192
  depth=depth,
5193
- heads=heads
 
 
5194
  )
5195
  )
5196
 
@@ -6291,6 +6293,1206 @@ def calculate_training_run_eta(
6291
 
6292
  return eta
6293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6294
  #=================================================================================================================================
6295
  # This is the end of x_transformer_2_3_1 Python module
6296
  #=================================================================================================================================
 
4
  #
5
  # Partial x-transformers code With useful modifications as a stand-alone Python module
6
  #
7
+ # Version 8.0
8
  #
9
  # Original source code courtesy of lucidrains
10
  # https://github.com/lucidrains/x-transformers
 
5190
  squeeze_out_last_dim = squeeze_out_last_dim,
5191
  attn_layers=Encoder(dim=dim,
5192
  depth=depth,
5193
+ heads=heads,
5194
+ attn_flash=True,
5195
+ rotary_pos_emb=True
5196
  )
5197
  )
5198
 
 
6293
 
6294
  return eta
6295
 
6296
+ #=================================================================================================================================
6297
+ # Autoregressive embeddings retrieval functions
6298
+ #=================================================================================================================================
6299
+
6300
+ import torch
6301
+ import torch.nn.functional as F
6302
+ from typing import Optional, Dict, List, Union, Set
6303
+ import numpy as np
6304
+ from tqdm import tqdm
6305
+ from contextlib import nullcontext
6306
+
6307
+ #===================================================================================================================
6308
+ # Advanced Embeddings Retrieval Function for Autoregressive X-Transformers
6309
+ #===================================================================================================================
6310
+
6311
def get_embeddings(
    model,
    inputs: torch.Tensor,
    pooling: str = 'mean',
    mask: Optional[torch.Tensor] = None,
    token_ids: Optional[List[int]] = None,
    token_weights: Optional[Dict[int, float]] = None,
    layer_index: int = -1,
    normalize: bool = False,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.bfloat16,
    pad_idx: int = 18819,
    use_amp: bool = True,
    verbose: bool = True,
    _max_concat_tokens: Optional[int] = None,
) -> np.ndarray:

    """
    Get embeddings for a single batch of inputs.

    Parameters
    ----------
    model : AutoregressiveWrapper
        Your trained transformer model
    inputs : torch.Tensor
        Input token sequences of shape (batch, seq_len); a 1-D tensor is
        treated as a batch of one
    pooling : str
        Pooling strategy: 'mean' or 'concat'
    mask : Optional[torch.Tensor]
        Boolean mask, True for valid tokens. Auto-generated if None
    token_ids : Optional[List[int]]
        Token IDs to include. Works independently or with token_weights.
    token_weights : Optional[Dict[int, float]]
        Token ID to weight/priority mapping:
        - 'mean': weights for weighted average
        - 'concat': priority scores for selection when limiting count
        - If provided WITHOUT token_ids: keys become the filter
        - If provided WITH token_ids: only tokens in BOTH are used (intersection)
    layer_index : int
        Which layer's hidden states to use (-1 for last)
    normalize : bool
        L2-normalize output embeddings
    device : Optional[torch.device]
        Device for inference (a device string such as 'cuda' is accepted too).
        Defaults to the device of the model's first parameter
    dtype : torch.dtype
        Dtype for autocast
    pad_idx : int
        Padding token index
    use_amp : bool
        Use automatic mixed precision (only takes effect on CUDA devices)
    verbose : bool
        Print warnings and info
    _max_concat_tokens : Optional[int]
        Internal: pre-computed max tokens for concat mode

    Returns
    -------
    np.ndarray
        Embeddings array:
        - 'mean': (batch, dim)
        - 'concat': (batch, max_tokens * dim)

    Raises
    ------
    ValueError
        If `pooling` is neither 'mean' nor 'concat'.
    """

    model.eval()

    if device is None:
        device = next(model.parameters()).device
    else:
        # Accept plain strings as well; `.type` below needs a torch.device.
        device = torch.device(device)

    inputs = inputs.to(device)

    if inputs.ndim == 1:
        inputs = inputs.unsqueeze(0)

    if mask is None:
        mask = (inputs != pad_idx)
    else:
        mask = mask.to(device)

    if mask.dtype != torch.bool:
        mask = mask.bool()

    # Wrapped models expose the underlying network as `.net`.
    net_model = model.net if hasattr(model, 'net') else model

    if use_amp and device.type == 'cuda':
        amp_ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
    else:
        amp_ctx = nullcontext()

    # Best effort: if the forward pass or hidden-state extraction fails for
    # any reason, fall back to the raw token embeddings instead of aborting.
    try:
        with torch.no_grad():
            with amp_ctx:
                output = net_model(
                    inputs,
                    mask=mask if mask.ndim == 2 else mask.squeeze(),
                    return_intermediates=True,
                )

            if isinstance(output, tuple) and len(output) == 2:
                _, intermediates = output
            else:
                intermediates = None

            hidden = _extract_hidden_states(intermediates, layer_index, verbose=verbose)

            if hidden is None:
                raise ValueError("Could not extract hidden states")

    except Exception as e:
        if verbose:
            print(f"Warning: Could not extract hidden states, using token embeddings. Error: {e}")
        hidden = _get_token_embeddings(net_model, inputs)

    # NOTE(review): pooling always uses the pad-derived mask; a caller-supplied
    # `mask` only reaches the model's forward pass -- confirm this is intended.
    seq_mask = (inputs != pad_idx)
    hidden = hidden * seq_mask.unsqueeze(-1).float()

    # Resolve token_ids / token_weights into the single filter set used for pooling.
    effective_token_ids = _compute_effective_token_ids(token_ids, token_weights)

    if pooling == 'mean':
        emb = _mean_pooling(hidden, inputs, seq_mask, effective_token_ids, token_weights, verbose=verbose)
    elif pooling == 'concat':
        emb = _concat_pooling(hidden, inputs, seq_mask, effective_token_ids, token_weights,
                              max_tokens=_max_concat_tokens, verbose=verbose)
    else:
        raise ValueError(f"Unknown pooling strategy: {pooling}. Use 'mean' or 'concat'")

    if normalize:
        emb = F.normalize(emb, p=2, dim=-1)

    return emb.cpu().detach().numpy()
6447
+
6448
+
6449
+ #===================================================================================================================
6450
+ # Batched Processing Function
6451
+ #===================================================================================================================
6452
+
6453
def get_embeddings_batched(
    model,
    sequences: List[List[int]],
    pooling: str = 'mean',
    token_ids: Optional[List[int]] = None,
    token_weights: Optional[Dict[int, float]] = None,
    max_seq_len: int = 8192,
    pad_idx: int = 18819,
    batch_size: int = 8,
    use_amp: bool = True,
    dtype: torch.dtype = torch.bfloat16,
    verbose: bool = True,
    show_progress: bool = True,
    normalize: bool = False,
) -> np.ndarray:

    """
    Process multiple sequences in TRUE batches for memory efficiency.

    Parameters
    ----------
    model : AutoregressiveWrapper
        Your trained transformer model
    sequences : List[List[int]]
        List of token sequences (list of lists)
    pooling : str
        Pooling strategy: 'mean' or 'concat'
    token_ids : Optional[List[int]]
        Token IDs to include
    token_weights : Optional[Dict[int, float]]
        Token ID to weight/priority mapping
    max_seq_len : int
        Maximum sequence length; longer sequences are truncated
    pad_idx : int
        Padding token index
    batch_size : int
        Batch size for processing
    use_amp : bool
        Use automatic mixed precision
    dtype : torch.dtype
        Dtype for autocast
    verbose : bool
        Print messages
    show_progress : bool
        Show tqdm progress bar (only when `verbose` is also True)
    normalize : bool
        L2-normalize output embeddings

    Returns
    -------
    np.ndarray
        Embeddings array with consistent dimensions, one row per sequence.
        An empty `sequences` list yields an empty array of shape (0, 0).
    """

    model.eval()

    num_sequences = len(sequences)

    if verbose:
        print(f"Processing {num_sequences} sequences in batches of {batch_size}...")

    # Nothing to embed: np.concatenate would raise on an empty list.
    if num_sequences == 0:
        return np.empty((0, 0), dtype=np.float32)

    # For concat mode: pre-scan to find max matching tokens across ALL sequences
    # so that every batch is padded to the same output width.
    max_concat_tokens = None
    if pooling == 'concat':
        effective_token_ids = _compute_effective_token_ids(token_ids, token_weights)
        max_concat_tokens = _scan_max_matching_tokens(sequences, effective_token_ids, pad_idx, max_seq_len)
        if max_concat_tokens == 0:
            # Apply the placeholder regardless of `verbose`; only the message is optional.
            if verbose:
                print("Warning: No sequences contain matching token IDs, using 1 token placeholder")
            max_concat_tokens = 1
        elif verbose:
            print(f"Auto-detected max matching tokens: {max_concat_tokens}")

    all_embeddings = []
    num_batches = (num_sequences + batch_size - 1) // batch_size

    if show_progress and verbose:
        batch_iterator = tqdm(range(num_batches), desc="Extracting embeddings")
    else:
        batch_iterator = range(num_batches)

    for batch_idx in batch_iterator:
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, num_sequences)

        batch_sequences = sequences[start_idx:end_idx]
        # Pad only to the longest sequence in this batch (capped at max_seq_len).
        max_len_in_batch = min(max_seq_len, max(len(seq) for seq in batch_sequences))

        padded_batch = []
        for seq in batch_sequences:
            if len(seq) > max_len_in_batch:
                seq = seq[:max_len_in_batch]
            else:
                seq = seq + [pad_idx] * (max_len_in_batch - len(seq))
            padded_batch.append(seq)

        batch_inputs = torch.tensor(padded_batch, dtype=torch.long)

        batch_embeddings = get_embeddings(
            model,
            batch_inputs,
            pooling=pooling,
            token_ids=token_ids,
            token_weights=token_weights,
            pad_idx=pad_idx,
            use_amp=use_amp,
            dtype=dtype,
            verbose=verbose and batch_idx == 0,  # warn once, not per batch
            normalize=normalize,
            _max_concat_tokens=max_concat_tokens,
        )

        all_embeddings.append(batch_embeddings)

    final_embeddings = np.concatenate(all_embeddings, axis=0)

    if verbose:
        print(f"Final embeddings shape: {final_embeddings.shape}")

    return final_embeddings
6569
+
6570
+ #===================================================================================================================
6571
+ # Helper Functions
6572
+ #===================================================================================================================
6573
+
6574
def _compute_effective_token_ids(token_ids: Optional[List[int]], token_weights: Optional[Dict[int, float]]) -> Optional[Set[int]]:
    """
    Compute effective token IDs with INTUITIVE logic:

    - token_ids=None,  token_weights=None  → None (all valid tokens)
    - token_ids=[...], token_weights=None  → token_ids
    - token_ids=None,  token_weights={...} → keys from token_weights
    - token_ids=[...], token_weights={...} → INTERSECTION (only tokens in BOTH)

    This ensures token_weights acts as a filter when provided, not just weights.
    If both are given but share no token, a warning is printed and token_ids
    alone is used.
    """
    if token_ids is None and token_weights is None:
        return None

    if token_ids is None:
        # Only token_weights provided: use its keys as filter
        return set(token_weights.keys())

    token_ids_set = set(token_ids)

    if token_weights is None:
        # Only token_ids provided: use token_ids as filter
        return token_ids_set

    # Both provided: INTERSECTION (only tokens in BOTH lists)
    intersection = token_ids_set & set(token_weights.keys())
    if not intersection:
        # Warn but fall back to token_ids (more permissive)
        print("Warning: token_ids and token_weights have no overlap. Using token_ids only.")
        return token_ids_set

    return intersection
6608
+
6609
def _scan_max_matching_tokens(sequences: List[List[int]],
                              token_ids: Optional[Set[int]],
                              pad_idx: int,
                              max_seq_len: int) -> int:
    """
    Scan all sequences to find maximum number of tokens matching token_ids.

    Each sequence is truncated to max_seq_len first. With token_ids=None the
    result is the longest truncated length; otherwise it is the largest
    per-sequence count of tokens that are in token_ids and are not pad_idx.
    Returns 0 when `sequences` is empty.
    """
    if token_ids is None:
        return max((min(len(seq), max_seq_len) for seq in sequences), default=0)

    # Drop the pad token from the filter once rather than per token.
    wanted = set(token_ids) - {pad_idx}

    return max(
        (sum(1 for tok in seq[:max_seq_len] if tok in wanted) for seq in sequences),
        default=0,
    )
6626
+
6627
def _extract_hidden_states(intermediates, layer_index: int = -1, verbose: bool = True):
    """
    Extract hidden states from LayerIntermediates object.

    The first non-empty source wins, tried in this order:
    `layer_hiddens[layer_index]`, `hiddens[layer_index]`, then the `values`
    attribute of `attn_intermediates[layer_index]` (if it is not None).

    Returns the selected entry, or None when `intermediates` is None or no
    source yields a value. Warnings are printed only when `verbose` is True.
    """
    if intermediates is None:
        if verbose:
            print("Warning: intermediates is None")
        return None

    for attr_name in ('layer_hiddens', 'hiddens'):
        states = getattr(intermediates, attr_name, None)
        if states is not None and len(states) > 0:
            return states[layer_index]

    attn_states = getattr(intermediates, 'attn_intermediates', None)
    if attn_states is not None and len(attn_states) > 0:
        attn_values = getattr(attn_states[layer_index], 'values', None)
        if attn_values is not None:
            return attn_values

    if verbose:
        print("Warning: Could not find layer_hiddens in intermediates")

    return None
6652
+
6653
+
6654
+ def _get_token_embeddings(net_model, inputs: torch.Tensor):
6655
+ """Get token embeddings directly from embedding layer."""
6656
+ if hasattr(net_model, 'token_emb'):
6657
+ if hasattr(net_model.token_emb, 'emb'):
6658
+ return net_model.token_emb.emb(inputs)
6659
+ else:
6660
+ return net_model.token_emb(inputs)
6661
+ elif hasattr(net_model, 'emb'):
6662
+ return net_model.emb(inputs)
6663
+ else:
6664
+ raise ValueError("Could not find embedding layer in model")
6665
+
6666
+ def _mean_pooling(
6667
+ hidden: torch.Tensor,
6668
+ inputs: torch.Tensor,
6669
+ mask: torch.Tensor,
6670
+ token_ids: Optional[Set[int]],
6671
+ token_weights: Optional[Dict[int, float]],
6672
+ verbose: bool = True
6673
+ ) -> torch.Tensor:
6674
+ """
6675
+ Mean pooling with token ID filtering and weighted averaging.
6676
+ """
6677
+ batch_size, seq_len, dim = hidden.shape
6678
+ device = hidden.device
6679
+
6680
+ if mask.ndim > 2:
6681
+ mask = mask.squeeze()
6682
+
6683
+ effective_mask = mask.clone()
6684
+
6685
+ if token_ids is not None:
6686
+ token_mask = torch.zeros_like(mask, dtype=torch.bool, device=device)
6687
+ for tid in token_ids:
6688
+ token_mask = token_mask | (inputs == tid)
6689
+ effective_mask = effective_mask & token_mask
6690
+
6691
+ if verbose and effective_mask.sum() == 0:
6692
+ print(f"Warning: No tokens match filter, falling back to all valid tokens")
6693
+ effective_mask = mask
6694
+
6695
+ if token_weights is not None:
6696
+ weights = torch.zeros_like(effective_mask, dtype=torch.float32, device=device)
6697
+
6698
+ for token_id, weight in token_weights.items():
6699
+ id_mask = (inputs == token_id) & effective_mask
6700
+ weights = weights.masked_fill(id_mask, float(weight))
6701
+
6702
+ weights = weights.masked_fill(effective_mask & (weights == 0), 1.0)
6703
+
6704
+ weighted_hidden = hidden * weights.unsqueeze(-1)
6705
+ sum_weighted = weighted_hidden.sum(dim=1)
6706
+ sum_weights = weights.sum(dim=1, keepdim=True).clamp(min=1e-9)
6707
+ return sum_weighted / sum_weights
6708
+ else:
6709
+ masked_hidden = hidden * effective_mask.unsqueeze(-1).float()
6710
+ sum_hidden = masked_hidden.sum(dim=1)
6711
+ count = effective_mask.sum(dim=1, keepdim=True).clamp(min=1e-9)
6712
+ return sum_hidden / count
6713
+
6714
+ def _concat_pooling(
6715
+ hidden: torch.Tensor,
6716
+ inputs: torch.Tensor,
6717
+ mask: torch.Tensor,
6718
+ token_ids: Optional[Set[int]],
6719
+ token_weights: Optional[Dict[int, float]],
6720
+ max_tokens: Optional[int],
6721
+ verbose: bool = True
6722
+ ) -> torch.Tensor:
6723
+ """
6724
+ Concat pooling with token ID filtering and weight-based priority selection.
6725
+ """
6726
+ batch_size, seq_len, dim = hidden.shape
6727
+ device = hidden.device
6728
+
6729
+ if max_tokens is None:
6730
+ max_tokens = 1
6731
+
6732
+ output_dim = max_tokens * dim
6733
+
6734
+ all_token_embs = []
6735
+
6736
+ for i in range(batch_size):
6737
+ seq_mask = mask[i]
6738
+ seq_inputs = inputs[i]
6739
+
6740
+ if token_ids is not None:
6741
+ matching_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
6742
+ for tid in token_ids:
6743
+ matching_mask = matching_mask | ((seq_inputs == tid) & seq_mask)
6744
+ valid_indices = matching_mask.nonzero(as_tuple=True)[0]
6745
+ else:
6746
+ valid_indices = seq_mask.nonzero(as_tuple=True)[0]
6747
+
6748
+ if len(valid_indices) == 0:
6749
+ emb = torch.zeros(dim, device=device)
6750
+ emb = F.pad(emb, (0, output_dim - dim))
6751
+ all_token_embs.append(emb)
6752
+ continue
6753
+
6754
+ matching_embs = hidden[i, valid_indices, :]
6755
+
6756
+ if token_weights is not None and len(valid_indices) > max_tokens:
6757
+ weights_list = []
6758
+ for idx in valid_indices:
6759
+ tok_id = seq_inputs[idx].item()
6760
+ weights_list.append(token_weights.get(tok_id, 1.0))
6761
+
6762
+ sorted_pairs = sorted(zip(range(len(valid_indices)), weights_list),
6763
+ key=lambda x: x[1], reverse=True)
6764
+ top_indices = [valid_indices[p[0]] for p in sorted_pairs[:max_tokens]]
6765
+ matching_embs = hidden[i, torch.tensor(top_indices, device=device), :]
6766
+ elif len(valid_indices) > max_tokens:
6767
+ matching_embs = matching_embs[:max_tokens]
6768
+
6769
+ if len(valid_indices) < max_tokens:
6770
+ padding_needed = max_tokens - len(valid_indices)
6771
+ padding = torch.zeros(padding_needed, dim, device=device)
6772
+ matching_embs = torch.cat([matching_embs, padding], dim=0)
6773
+
6774
+ emb = matching_embs.reshape(-1)
6775
+ all_token_embs.append(emb)
6776
+
6777
+ return torch.stack(all_token_embs, dim=0)
6778
+
6779
+ #=================================================================================================================================
6780
+ # Non-Autoregressive Encoder Embeddings Retrieval Functions
6781
+ #=================================================================================================================================
6782
+
6783
def get_enc_embeddings(
    model,
    sequences: List[List[int]],
    seq_len: Optional[int] = 3072,
    seq_pad_idx: int = 385,
    batch_size: int = 64,
    save_every_num_batches: int = -1,
    save_file_path: str = "saved_embeddings.npy",
    device: Optional[torch.device] = None,
    normalize: bool = False,
    pooling: str = "auto",                                           # "auto" | "mean" | "weighted_mean"
    token_type_weights: Optional[Tuple[float, float, float]] = None, # (onset_w, duration_w, pitch_w)
    use_bfloat16: bool = True,                                       # enable bfloat16 autocast when possible
    return_dtype: str = "float32",                                   # "float32" or "float16" for returned embeddings
    return_numpy: bool = False,
    verbose: bool = True,
    show_progress_bar: bool = True
) -> Union[Tensor, np.ndarray]:

    """
    Compute one embedding per token sequence with a PyTorch encoder model.

    Sequences are processed in batches: each batch is padded/truncated with
    `pad_and_mask`, run through `model(x, return_embeddings=True, mask=mask)`
    under `torch.inference_mode()` (and autocast when available), pooled if the
    model returns per-token states, optionally L2-normalized, and collected on
    the CPU.

    The model may return either a 2-D tensor `(B, D)` (used as is) or a 3-D
    tensor `(B, L, D)` (pooled according to `pooling`).

    Args:
        model (torch.nn.Module):
            Model used for inference. It is moved to the target device and put
            into `eval()` mode.
        sequences (List[List[int]]):
            Token id sequences. An empty list yields an empty `(0, 0)` result.
        seq_len (Optional[int], default=3072):
            Maximum sequence length passed to `pad_and_mask`.
        seq_pad_idx (int, default=385):
            Token id used for padding positions.
        batch_size (int, default=64):
            Number of sequences per forward pass. Must be positive.
        save_every_num_batches (int, default=-1):
            If > 0, all embeddings collected so far are written with `np.save`
            to `save_file_path` every that many batches.
        save_file_path (str, default="saved_embeddings.npy"):
            Destination for periodic saves.
        device (Optional[torch.device], default=None):
            Target device (a `torch.device` or a device string). If None, the
            device of the model's first parameter is used.
        normalize (bool, default=False):
            If True, L2-normalize each embedding (in float32).
        pooling (str, default="auto"):
            Pooling for 3-D model outputs: "auto"/"mean" use `masked_mean_pool`,
            "weighted_mean" uses `masked_weighted_mean_pool`.
        token_type_weights (Optional[Tuple[float, float, float]], default=None):
            `(onset_w, duration_w, pitch_w)` forwarded to
            `masked_weighted_mean_pool` when `pooling="weighted_mean"`.
        use_bfloat16 (bool, default=True):
            If True, try bfloat16 autocast first; otherwise (or on failure) the
            device's default autocast is tried, then no autocast.
        return_dtype (str, default="float32"):
            "float32" or "float16"; the cast happens after normalization.
        return_numpy (bool, default=False):
            Return a NumPy array instead of a CPU tensor.
        verbose (bool, default=True):
            Emit diagnostic messages via `tqdm.write`.
        show_progress_bar (bool, default=True):
            Show the tqdm progress bar over batches.

    Returns:
        Union[torch.Tensor, numpy.ndarray]:
            Array of shape `(N, D)` on the CPU, or `(0, 0)` if nothing was
            produced.

    Raises:
        AssertionError: if `return_dtype` is not "float32" or "float16".
        ValueError: if `batch_size` is not positive, if the model output is
            neither 2-D nor 3-D, or if `pooling` is unsupported.
        RuntimeError: if the model returns None.

    Notes:
        - Failures while saving intermediate results are reported (when
          `verbose`) but never abort the run.
        - Log lines keep their historical "[get_embeddings_bf16]" prefix.

    Example:
        >>> embs = get_enc_embeddings(model, sequences, seq_len=1024, batch_size=32, pooling="mean",
        ...                           normalize=True, return_dtype="float32", return_numpy=False)
    """

    # Local stdlib import: lets both the autocast and plain paths share one forward call.
    from contextlib import nullcontext

    assert return_dtype in ("float32", "float16"), "return_dtype must be 'float32' or 'float16'"

    # A non-positive batch size would either crash inside range() or silently
    # produce no embeddings at all; fail loudly instead.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer; got {batch_size}")

    # Normalise `device` so that a plain string such as "cuda" also works;
    # otherwise `.type` below would fail and autocast would be silently skipped.
    model_device = next(model.parameters()).device if device is None else torch.device(device)
    model.to(model_device)
    model.eval()

    all_embs: List[Tensor] = []
    total_batches = math.ceil(len(sequences) / batch_size)

    if verbose:
        tqdm.write(
            f"[get_embeddings_bf16] sequences={len(sequences)}, batch_size={batch_size}, "
            f"batches={total_batches}, device={model_device}, seq_len={seq_len}, pooling={pooling}"
        )

    # Prepare autocast context using torch.amp.autocast: preferred configuration
    # first, then the device default. Construction is best-effort by design.
    autocast_candidates = [{"dtype": torch.bfloat16}, {}] if use_bfloat16 else [{}]
    autocast_ctx = None
    for autocast_kwargs in autocast_candidates:
        try:
            autocast_ctx = torch.amp.autocast(device_type=model_device.type, **autocast_kwargs)
            break
        except Exception:
            autocast_ctx = None

    out_torch_dtype = torch.float16 if return_dtype == "float16" else torch.float32

    with torch.inference_mode():
        batch_iter = range(0, len(sequences), batch_size)
        pbar = tqdm(batch_iter, disable=not show_progress_bar, total=total_batches, desc="Embedding batches")
        for batch_idx, i in enumerate(pbar):
            batch_seqs = sequences[i : i + batch_size]
            # x: (B, L) LongTensor token ids, mask: (B, L) boolean
            x, mask = pad_and_mask(batch_seqs, pad_idx=seq_pad_idx, seq_len=seq_len, device=model_device, verbose=verbose)

            # Run forward under autocast if available
            with (autocast_ctx if autocast_ctx is not None else nullcontext()):
                out = model(x, return_embeddings=True, mask=mask)

            if out is None:
                raise RuntimeError("model returned None for embeddings. Check forward flags.")

            # Handle shapes
            if out.dim() == 2:
                # already pooled: (B, D)
                emb = out
            elif out.dim() == 3:
                # per-token embeddings: (B, L, D)
                if pooling in ("mean", "auto"):
                    emb = masked_mean_pool(out, mask, dim=1, verbose=verbose)
                elif pooling == "weighted_mean":
                    emb = masked_weighted_mean_pool(out, mask, token_ids=x, token_type_weights=token_type_weights, dim=1, verbose=verbose)
                else:
                    raise ValueError(f"unsupported pooling: {pooling}")
            else:
                raise ValueError(f"unexpected embedding tensor shape: {out.shape}")

            # Ensure embeddings are float32 for stable normalization/indexing
            if emb.dtype != torch.float32:
                emb = emb.float()

            # L2 normalize in float32
            if normalize:
                emb = F.normalize(emb, p=2, dim=-1)

            # Optionally cast to float16 for return/storage
            if return_dtype == "float16":
                emb = emb.half()

            all_embs.append(emb.cpu())

            # Update progress bar postfix with shapes and dtype
            if verbose:
                pbar.set_postfix({"batch": batch_idx + 1, "emb_shape": f"{emb.shape}", "dtype": str(emb.dtype)})

            # Save intermediate results periodically (after every N-th batch)
            if save_every_num_batches > 0 and (batch_idx + 1) % save_every_num_batches == 0:
                try:
                    concatenated = torch.cat(all_embs, dim=0).numpy()
                    np.save(save_file_path, concatenated)
                    if verbose:
                        tqdm.write(f"[get_embeddings_bf16] saved {concatenated.shape[0]} embeddings to {save_file_path}")
                except Exception as e:
                    # Do not crash the whole run for a save failure; report if verbose
                    if verbose:
                        tqdm.write(f"[get_embeddings_bf16] warning: failed to save embeddings: {e}")

    if len(all_embs) == 0:
        # return empty tensor/array with shape (0, 0)
        empty = torch.empty((0, 0), dtype=out_torch_dtype)
        if verbose:
            tqdm.write("[get_embeddings_bf16] no embeddings were produced; returning empty tensor")
        return empty.numpy() if return_numpy else empty

    result = torch.cat(all_embs, dim=0)  # (N, D) on CPU

    if verbose:
        tqdm.write(f"[get_embeddings_bf16] finished: total_embeddings={result.shape[0]}, dim={result.shape[1]}, dtype={result.dtype}")

    if return_numpy:
        return result.numpy()

    return result
7022
+
7023
+ ###################################################################################
7024
+
7025
def masked_mean_pool(
    token_embeddings: Tensor,
    mask: Tensor,
    dim: int = 1,
    eps: float = 1e-9,
    verbose: bool = True,
) -> Tensor:

    """
    Compute a masked mean pooling over a specified dimension.

    The mean of `token_embeddings` is taken along `dim`, ignoring positions
    where `mask` is False. For float16/bfloat16 inputs the sum and the count
    are accumulated in float32: in float16 the default `eps` underflows to
    zero (fully masked rows would become NaN) and long sequences can overflow,
    while bfloat16 cannot represent large counts exactly. The result is cast
    back so the returned dtype still matches the input.

    Args:
        token_embeddings: Embeddings with the feature dimension last, e.g.
            (B, L, D).
        mask: Boolean tensor matching the leading dimensions, e.g. (B, L).
            True marks valid tokens; False marks padding.
        dim: Dimension along which to pool (default: 1, the sequence length).
        eps: Lower bound for the token count, avoiding division by zero when a
            row has no valid tokens (such rows pool to zeros).
        verbose: If True, writes a short summary via `tqdm.write`.

    Returns:
        Pooled embeddings with `dim` removed, typically (B, D). For floating
        inputs the dtype matches `token_embeddings.dtype`.
    """

    in_dtype = token_embeddings.dtype
    upcast = in_dtype in (torch.float16, torch.bfloat16)
    work = token_embeddings.float() if upcast else token_embeddings

    mask_f = mask.to(work.dtype)                                   # (B, L)
    summed = (work * mask_f.unsqueeze(-1)).sum(dim=dim)            # (B, D)
    counts = mask_f.sum(dim=dim).clamp_min(eps).unsqueeze(-1)      # (B, 1)
    pooled = summed / counts                                       # (B, D)

    if upcast:
        pooled = pooled.to(in_dtype)

    if verbose:
        # Use tqdm.write so it doesn't interfere with progress bars
        valid_counts = counts.squeeze(-1)
        tqdm.write(
            f"[masked_mean_pool] pooled shape={pooled.shape}, "
            f"counts min={valid_counts.min().item():.3f}, max={valid_counts.max().item():.3f}"
        )

    return pooled
7069
+
7070
+ ###################################################################################
7071
+
7072
def masked_weighted_mean_pool(
    token_embs: Tensor,
    valid_mask: Tensor,
    token_ids: Optional[Tensor] = None,
    token_type_weights: Optional[Tuple[float, float, float]] = None,
    dim: int = 1,
    verbose: bool = False,
) -> Tensor:

    """
    Weighted mean pooling across tokens, weighting each token by its type.

    Token types are inferred from the token id:
        - onset:    token_id in [0, 127]
        - duration: token_id in [128, 255]
        - pitch:    token_id in [256, 383]
    Valid tokens outside these ranges keep weight 1.0; padding gets weight 0.

    Each token embedding is multiplied by its scalar weight and the sum is
    divided by the sum of weights of the valid tokens of that sequence. For
    float16/bfloat16 inputs the arithmetic is carried out in float32 (large
    weight sums overflow float16 and are rounded in bfloat16) and the result is
    cast back to the input dtype.

    Args:
        token_embs: Per-token embeddings of shape (B, L, D).
        valid_mask: Boolean mask of shape (B, L); True marks real tokens.
        token_ids: Token ids of shape (B, L). If None, falls back to
            `masked_mean_pool`.
        token_type_weights: (onset_w, duration_w, pitch_w). If None, defaults
            to (1.0, 1.0, 1.0).
        dim: Sequence dimension to reduce (weights are built as (B, L), so this
            is expected to be 1).
        verbose: If True, reports the fallback via `tqdm.write`.

    Returns:
        Pooled embeddings of shape (B, D) in the dtype of `token_embs`.
    """

    B, L, D = token_embs.shape
    device = token_embs.device
    in_dtype = token_embs.dtype

    if token_ids is None:
        # No token-level ids available: fallback to simple masked mean
        if verbose:
            tqdm.write("[masked_weighted_mean_pool] token_ids is None, falling back to masked_mean_pool")
        return masked_mean_pool(token_embs, valid_mask, dim=dim, verbose=verbose)

    # Accumulate reduced-precision inputs in float32.
    upcast = in_dtype in (torch.float16, torch.bfloat16)
    work = token_embs.float() if upcast else token_embs
    dtype = work.dtype

    # Default weights
    if token_type_weights is None:
        onset_w, duration_w, pitch_w = 1.0, 1.0, 1.0
    else:
        onset_w, duration_w, pitch_w = token_type_weights

    # Build per-token scalar weight tensor (B, L); each type owns a 128-id range.
    w = torch.ones((B, L), device=device, dtype=dtype)
    for range_start, type_weight in ((0, onset_w), (128, duration_w), (256, pitch_w)):
        if type_weight != 1.0:
            type_mask = (token_ids >= range_start) & (token_ids < range_start + 128) & valid_mask
            w = torch.where(type_mask, torch.tensor(type_weight, device=device, dtype=dtype), w)

    # Zero out weights for padding positions
    w = w * valid_mask.to(dtype)                                   # (B, L)

    # Weighted sum and normalization
    denom = w.sum(dim=1, keepdim=True).clamp(min=1e-6)             # (B, 1)
    summed = (work * w.unsqueeze(-1)).sum(dim=dim)                 # (B, D)
    pooled = summed / denom                                        # (B, D)

    if upcast:
        pooled = pooled.to(in_dtype)

    return pooled
7138
+
7139
+ ###################################################################################
7140
+
7141
def pad_and_mask(
    sequences: List[List[int]],
    pad_idx: int = 385,
    seq_len: Optional[int] = None,
    device: Optional[torch.device] = None,
    verbose: bool = False,
) -> Tuple[Tensor, Tensor]:

    """
    Turn variable-length token id lists into a padded batch plus validity mask.

    The batch width is the longest sequence in `sequences`, capped at `seq_len`
    when one is given (so the width never exceeds what the batch needs).
    Longer sequences are truncated; shorter ones are right-padded with
    `pad_idx`.

    Args:
        sequences: List of token id sequences (each a list of ints).
        pad_idx: Token id written into padding positions (default: 385).
        seq_len: Optional upper bound on the batch width.
        device: Device for the returned tensors (default device if None).
        verbose: If True, iterates under a tqdm progress bar and writes a
            one-line summary at the end.

    Returns:
        A tuple (x, mask):
            - x: LongTensor of shape (B, T) containing padded token ids.
            - mask: BoolTensor of shape (B, T), True where x holds a real token.
        An empty `sequences` gives two (0, 0) tensors.
    """

    # Nothing to pad.
    if not sequences:
        return (
            torch.empty((0, 0), dtype=torch.long, device=device),
            torch.empty((0, 0), dtype=torch.bool, device=device),
        )

    lengths = [len(s) for s in sequences]
    longest = max(lengths)
    # seq_len only caps the width; a shorter batch keeps its own maximum.
    target_len = longest if seq_len is None else min(seq_len, longest)

    b = len(sequences)
    x = torch.full((b, target_len), pad_idx, dtype=torch.long, device=device)
    mask = torch.zeros((b, target_len), dtype=torch.bool, device=device)

    # Zero-width batch: return before iterating or logging.
    if target_len == 0:
        return x, mask

    source = tqdm(sequences, disable=not verbose, desc="Pad & mask") if verbose else sequences

    for row, seq in enumerate(source):
        if not seq:
            continue
        keep = min(len(seq), target_len)
        # Slice before building the tensor so only the kept tokens are copied.
        x[row, :keep] = torch.tensor(seq[:keep], dtype=torch.long, device=device)
        mask[row, :keep] = True

    if verbose:
        tqdm.write(
            f"[pad_and_mask] batch_size={b}, target_len={target_len}, "
            f"min_len={min(lengths)}, max_len={max(lengths)}"
        )

    return x, mask
7228
+
7229
+ #=================================================================================================================================
7230
+ # Embeddings similarity comparison functions
7231
+ #=================================================================================================================================
7232
+
7233
+ import torch
7234
+ import torch.nn.functional as F
7235
+ from tqdm import tqdm
7236
+ from typing import Optional, Union, Tuple
7237
+
7238
+ def topk_cosine_neighbors(embeddings: torch.Tensor,
7239
+ k: int = 10,
7240
+ key_embeddings: Optional[torch.Tensor] = None,
7241
+ row_batch: Optional[int] = None,
7242
+ col_batch: Optional[int] = None,
7243
+ device: Optional[Union[str, torch.device]] = None,
7244
+ normalize: bool = True,
7245
+ dtype: Optional[torch.dtype] = None,
7246
+ show_progress: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
7247
+
7248
+ """
7249
+ For each query embedding, find the indices and similarities of its top-k neighbors
7250
+ from a set of key embeddings, sorted by descending similarity.
7251
+
7252
+ Supports both self-similarity (single array, excludes self) and pairwise
7253
+ retrieval (two arrays, no exclusion).
7254
+
7255
+ Optimized for maximum speed and memory efficiency across CPU, CUDA, and MPS.
7256
+ Uses a streaming batched approach to handle datasets larger than GPU memory.
7257
+
7258
+ Args:
7259
+ embeddings (torch.Tensor): Query embeddings, shape (N_q, D).
7260
+ k (int): How many neighbors to return.
7261
+ key_embeddings (torch.Tensor, optional): Database/Key embeddings, shape (N_k, D).
7262
+ If None, defaults to 'embeddings' (self-search).
7263
+ row_batch (int, optional): Number of query rows to process at once. Auto-tuned if None.
7264
+ col_batch (int, optional): Number of key columns to process at once. Auto-tuned if None.
7265
+ device (str or torch.device, optional): Target device. If None, uses embeddings.device.
7266
+ normalize (bool): If True, L2-normalize embeddings. Skip if already normalized.
7267
+ dtype (torch.dtype, optional): Compute dtype (e.g., torch.float16, torch.bfloat16).
7268
+ If None, uses embeddings.dtype.
7269
+ show_progress (bool): Show tqdm progress bar.
7270
+
7271
+ Returns:
7272
+ top_idx (torch.Tensor): shape (N_q, k), int32 indices of nearest neighbors (indices into key_embeddings).
7273
+ top_sim (torch.Tensor): shape (N_q, k), float32 cosine similarities.
7274
+ """
7275
+
7276
+ # 1. Determine Search Mode (Self vs. Pairwise)
7277
+ is_self_search = (key_embeddings is None)
7278
+ if is_self_search:
7279
+ key_embeddings = embeddings
7280
+
7281
+ # 2. Device & Dtype Setup
7282
+ if device is None:
7283
+ device = embeddings.device
7284
+ else:
7285
+ device = torch.device(device)
7286
+
7287
+ # Determine compute dtype
7288
+ if dtype is None:
7289
+ dtype = embeddings.dtype
7290
+ else:
7291
+ assert dtype.is_floating_point, "dtype must be a floating point type"
7292
+
7293
+ # Move and cast embeddings
7294
+ # Ensure contiguous for efficient matmul
7295
+ query_embeddings = embeddings.to(device=device, dtype=dtype).contiguous()
7296
+ key_embeddings = key_embeddings.to(device=device, dtype=dtype).contiguous()
7297
+
7298
+ N_q, D = query_embeddings.shape
7299
+ N_k, D_k = key_embeddings.shape
7300
+
7301
+ if D != D_k:
7302
+ raise ValueError(f"Query and Key embeddings must have same dimension. Got {D} and {D_k}")
7303
+
7304
+ # Validation
7305
+ if k < 1:
7306
+ raise ValueError(f"k must be >= 1; got {k}")
7307
+
7308
+ if is_self_search:
7309
+ if k >= N_q:
7310
+ raise ValueError(f"For self-search, k must be < N (to exclude self). Got N={N_q}, k={k}")
7311
+ else:
7312
+ if k > N_k:
7313
+ raise ValueError(f"For pairwise search, k must be <= N_k. Got N_k={N_k}, k={k}")
7314
+
7315
+ # 3. Auto-tune batch sizes based on device and memory
7316
+ # Heuristics adjusted for potentially different N_q and N_k
7317
+ if row_batch is None:
7318
+ if device.type == 'cuda':
7319
+ row_batch = 16384
7320
+ elif device.type == 'mps':
7321
+ row_batch = 8192
7322
+ else:
7323
+ row_batch = 4096 # CPU
7324
+
7325
+ if col_batch is None:
7326
+ if device.type == 'cuda':
7327
+ col_batch = 16384
7328
+ elif device.type == 'mps':
7329
+ col_batch = 8192
7330
+ else:
7331
+ col_batch = 4096 # CPU
7332
+
7333
+ # Clamp batch sizes to actual dimensions
7334
+ row_batch = min(row_batch, N_q)
7335
+ col_batch = min(col_batch, N_k)
7336
+
7337
+ # 4. Optional Normalization
7338
+ if normalize:
7339
+ # Normalize in-place if possible, or reassign
7340
+ query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
7341
+ # Only normalize keys if they are distinct from queries to avoid redundant work
7342
+ # in self-search case (already normalized above)
7343
+ if not is_self_search:
7344
+ key_embeddings = F.normalize(key_embeddings, p=2, dim=1)
7345
+
7346
+ # 5. Initialize Result Tensors (always float32 for precision in output)
7347
+ top_sim = torch.empty((N_q, k), dtype=torch.float32, device=device)
7348
+ top_idx = torch.empty((N_q, k), dtype=torch.int, device=device)
7349
+
7350
+ # Pre-allocate reusable buffers for inner loop (memory efficiency)
7351
+ # Buffers for top-k merge (size 2k)
7352
+ merge_sim_buffer = torch.empty((row_batch, 2 * k), dtype=dtype, device=device)
7353
+ merge_idx_buffer = torch.empty((row_batch, 2 * k), dtype=torch.int, device=device)
7354
+
7355
+ # Buffer for column batch similarities
7356
+ sim_buffer = torch.empty((row_batch, col_batch), dtype=dtype, device=device)
7357
+
7358
+ # Value for masking (minimum possible float for the dtype)
7359
+ min_val = -torch.finfo(dtype).max
7360
+
7361
+ # 6. Inference Context
7362
+ with torch.no_grad():
7363
+ iterator = range(0, N_q, row_batch)
7364
+ if show_progress:
7365
+ desc = "Query Batches" if not is_self_search else "Row Batches"
7366
+ iterator = tqdm(iterator, desc=desc, leave=True)
7367
+
7368
+ for i in iterator:
7369
+ i_end = min(i + row_batch, N_q)
7370
+ rb = i_end - i
7371
+
7372
+ rows = query_embeddings[i:i_end] # (rb, D)
7373
+
7374
+ # Initialize current batch top-k
7375
+ # Use a tensor that persists across column batches for the current row batch
7376
+ curr_sim = torch.full((rb, k), min_val, dtype=dtype, device=device)
7377
+ curr_idx = torch.full((rb, k), -1, dtype=torch.int, device=device)
7378
+
7379
+ for j in range(0, N_k, col_batch):
7380
+ j_end = min(j + col_batch, N_k)
7381
+ cb = j_end - j
7382
+
7383
+ cols = key_embeddings[j:j_end] # (cb, D)
7384
+
7385
+ # Compute similarities in-place into buffer
7386
+ # sim_block shape: (rb, cb)
7387
+ sim_block = sim_buffer[:rb, :cb]
7388
+ torch.matmul(rows, cols.T, out=sim_block)
7389
+
7390
+ # Mask self-similarity ONLY if self-search
7391
+ if is_self_search:
7392
+ offset = i - j
7393
+ r_start = max(0, -offset)
7394
+ r_end = min(rb, cb - offset)
7395
+
7396
+ if r_start < r_end:
7397
+ # Vectorized masking of the diagonal
7398
+ r_range = torch.arange(r_start, r_end, dtype=torch.long, device=device)
7399
+ c_range = r_range + offset
7400
+ sim_block[r_range, c_range] = min_val
7401
+
7402
+ # Top-k in block
7403
+ if cb >= k:
7404
+ blk_s, blk_p = torch.topk(sim_block, k, dim=1, largest=True, sorted=True)
7405
+ blk_i = blk_p + j
7406
+ else:
7407
+ # Pad block to k if remaining keys are fewer than k
7408
+ pad_size = k - cb
7409
+ pad_vals = torch.full((rb, pad_size), min_val, dtype=dtype, device=device)
7410
+ sims_padded = torch.cat([sim_block, pad_vals], dim=1)
7411
+ blk_s, blk_p = torch.topk(sims_padded, k, dim=1, largest=True, sorted=True)
7412
+ blk_i = blk_p + j
7413
+ # Invalidate padded indices
7414
+ blk_i[blk_s == min_val] = -1
7415
+
7416
+ # Merge with current best
7417
+ # Layout: [curr_sim (k), blk_s (k)] -> topk(2k) -> keep k
7418
+ merge_sim_buffer[:rb, :k] = curr_sim
7419
+ merge_sim_buffer[:rb, k:2*k] = blk_s
7420
+ merge_idx_buffer[:rb, :k] = curr_idx
7421
+ merge_idx_buffer[:rb, k:2*k] = blk_i
7422
+
7423
+ curr_sim, top_p = torch.topk(merge_sim_buffer[:rb, :2*k], k, dim=1, largest=True, sorted=True)
7424
+ curr_idx = torch.gather(merge_idx_buffer[:rb, :2*k], dim=1, index=top_p)
7425
+
7426
+ # Write results (convert to float32 for consistency)
7427
+ top_sim[i:i_end] = curr_sim.to(torch.float32)
7428
+ top_idx[i:i_end] = curr_idx
7429
+
7430
+ # 7. Post-processing return format
7431
+ if k == 1:
7432
+ return top_idx.view(-1), top_sim.view(-1)
7433
+
7434
+ return top_idx, top_sim
7435
+
7436
+ #=================================================================================================================================
7437
+ # Embeddings visualization functions
7438
+ #=================================================================================================================================
7439
+
7440
+ try:
7441
+ import numpy as np
7442
+ import matplotlib.pyplot as plt
7443
+ from sklearn.metrics import pairwise_distances
7444
+
7445
+ except:
7446
+ pass
7447
+
7448
def plot_emb_cosine_similarity(embeddings,
                               clip=2.0,
                               gamma=0.55,
                               cmap="inferno",
                               figsize=(20, 20),
                               dpi=300,
                               output_fname='embeddings_similarity_plot.png',
                               return_sims=False
                               ):

    """
    Plots a high-contrast heatmap of pairwise cosine similarities between embeddings.

    The similarity matrix is gamma-corrected (sign-preserving) and the color
    scale is clipped to the [clip, 100 - clip] percentiles of the corrected values.
    The figure is saved to output_fname and then displayed with plt.show().

    ----------
    Parameters
    ----------

    embeddings: array-like passed directly to sklearn pairwise_distances
                (presumably shape (n_embeddings, dim) - confirm against caller)
    clip: percentile clipped from each end of the color scale (1–5 recommended)
    gamma: exponent applied to abs(similarity) for nonlinear contrast (0.4–0.8 recommended)
    cmap: matplotlib colormap name
    figsize: figure size forwarded to plt.figure
    dpi: figure resolution forwarded to plt.figure
    output_fname: path the figure is written to with plt.savefig
    return_sims: if True, return the gamma-corrected similarity matrix

    -------
    Returns
    -------

    The gamma-corrected similarity matrix (numpy array) if return_sims is True,
    otherwise None.

    -----------
    Use Example
    -----------

    tok_emb = model.net.token_emb.emb.weight.detach().cpu()

    plot_emb_cosine_similarity(tok_emb)
    """

    # 1. Compute cosine similarity from cosine distance (similarity = 1 - distance)
    cos_dist = pairwise_distances(embeddings, metric="cosine")
    cos_sim = 1 - cos_dist

    # 2. Gamma correction for contrast
    #    The sign is re-applied so negative similarities stay negative
    sim = np.sign(cos_sim) * (np.abs(cos_sim) ** gamma)

    # 3. Percentile clipping to remove flat tails
    vmin, vmax = np.percentile(sim, [clip, 100 - clip])

    # 4. Plot
    plt.figure(figsize=figsize, dpi=dpi)
    plt.imshow(sim, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.title("Embeddings Pairwise Cosine Similarity")
    plt.xlabel("Embedding Index")
    plt.ylabel("Embeddings Index")
    plt.tight_layout()
    plt.savefig(output_fname)
    plt.show()

    if return_sims:
        return sim
7495
+
7496
  #=================================================================================================================================
7497
  # This is the end of x_transformer_2_3_1 Python module
7498
  #=================================================================================================================================