File size: 14,757 Bytes
706c1fb
 
 
 
1a51bfe
 
 
 
706c1fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31492b9
 
706c1fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1423669
706c1fb
 
dfb15c8
 
1423669
706c1fb
1423669
706c1fb
 
 
dfb15c8
706c1fb
 
 
 
 
 
1423669
706c1fb
1423669
706c1fb
 
 
 
 
 
dfb15c8
706c1fb
 
1423669
706c1fb
1423669
 
 
706c1fb
 
 
 
 
 
 
1423669
 
 
 
 
6cb77e7
1423669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cb77e7
1423669
 
 
 
 
 
 
 
 
 
 
 
 
706c1fb
 
 
 
 
 
 
 
76f055e
 
 
 
 
 
 
 
 
 
 
 
706c1fb
8dbc50f
 
 
 
fef24f4
 
6596f05
706c1fb
b62a82c
 
163205e
8dbc50f
 
6596f05
 
 
 
 
 
 
 
 
706c1fb
fef24f4
 
 
 
 
eb8537f
8e4f3e6
6596f05
fef24f4
1ea1cdf
fef24f4
 
1a51bfe
fd2b6ab
38ad335
8e6911a
fd2b6ab
 
 
fef24f4
6596f05
b62a82c
fef24f4
6596f05
b62a82c
fef24f4
4aaab5c
b62a82c
 
d8c7429
b62a82c
 
 
 
 
 
 
 
 
 
 
 
 
fef24f4
b62a82c
706c1fb
 
 
 
 
 
 
 
 
 
 
 
 
 
8dbc50f
706c1fb
8dbc50f
 
706c1fb
cc932e3
706c1fb
8dbc50f
 
 
 
cc932e3
8dbc50f
706c1fb
 
 
 
 
515978a
 
8dbc50f
 
 
706c1fb
 
 
 
 
 
515978a
8dbc50f
706c1fb
 
 
 
515978a
 
 
 
 
 
 
 
 
 
706c1fb
515978a
 
 
 
 
b62a82c
302653e
 
 
 
 
 
 
 
 
 
b62a82c
 
 
 
302653e
b62a82c
 
3833312
 
 
 
 
 
 
 
706c1fb
302653e
 
 
44cdadb
76f055e
 
 
 
 
 
 
44cdadb
302653e
 
 
 
 
 
 
 
 
 
 
706c1fb
 
302653e
76f055e
302653e
 
 
 
 
 
 
 
 
9459d9a
302653e
 
 
 
 
 
 
 
 
9459d9a
302653e
 
 
 
 
 
 
 
 
 
9459d9a
706c1fb
 
 
302653e
706c1fb
302653e
 
6596f05
302653e
 
706c1fb
 
515978a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706c1fb
 
 
 
 
 
 
 
 
 
 
 
 
 
515978a
706c1fb
515978a
 
d763387
50a57cf
515978a
 
302653e
 
3507848
8e6911a
 
302653e
39d051c
 
706c1fb
 
 
 
 
302653e
706c1fb
 
 
 
302653e
515978a
 
706c1fb
 
 
302653e
706c1fb
9459d9a
 
706c1fb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
#=================================================================================
# https://huggingface.co/spaces/projectlosangeles/Chords-Progressions-Transformer
#=================================================================================

print('=' * 70)
print('Chords Progressions Transformer Gradio App')
print('=' * 70)

import os

# These environment variables are set before huggingface_hub and
# x_transformer_2_3_1 are imported further down, presumably because they are
# read at import time — TODO confirm.
# NOTE(review): HF_HUB_ENABLE_HF_TRANSFER presumably needs the hf_transfer
# package installed in the Space — verify.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# NOTE(review): not read anywhere in this file; presumably consumed by an
# imported module (e.g. the x_transformer copy) — confirm.
os.environ['USE_FLASH_ATTENTION'] = '1'

import time as reqtime
from pytz import timezone

import torch

# Global torch performance settings applied before the models are built:
# allow TF32 / reduced-precision float32 matmuls, and enable every
# scaled-dot-product-attention backend (mem-efficient, math, flash, cuDNN)
# so PyTorch may choose among them.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)

import spaces
import gradio as gr

from x_transformer_2_3_1 import *

import datetime
import random
import tqdm

from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX

import matplotlib.pyplot as plt

from huggingface_hub import hf_hub_download
         
# =================================================================================================

print('=' * 70)
print('Loading models...')
print('=' * 70)
print('Loading chords texturing model...')
print('=' * 70)

# Texturing-model settings. NOTE: SEQ_LEN / PAD_IDX / DEVICE are module globals
# and are reassigned further down for the chords progressions model, so after
# import they hold that model's values, not these.
SEQ_LEN = 2048
PAD_IDX = 721
DEVICE = 'cuda'

# Decoder-only transformer over token ids 0..PAD_IDX (722 ids). generate_chords
# uses 718 / 719 to bracket the chord prime and 720 as EOS; 721 is the pad id.
tex_model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 2048,
                          depth = 12,
                          heads = 16,
                          rotary_pos_emb = True,
                          attn_flash = True
                         )
    )

# The autoregressive wrapper supplies .generate(); the pad id is passed as its
# ignore_index.
tex_model = AutoregressiveWrapper(tex_model, ignore_index=PAD_IDX)

tex_model.to(DEVICE)
print('=' * 70)

print('Loading model checkpoint...')

# Download the trained weights from the Hub (hf_hub_download reuses the local
# cache on subsequent starts).
checkpoint = hf_hub_download(
    repo_id='asigalov61/Chordified-Piano-Transformer',
    filename='Chordified_Piano_Transformer_Texturing_Trained_Model_18092_steps_0.7058_loss_0.7977_acc.pth'
)

# weights_only=True restricts unpickling to tensors / plain containers.
# Weights must be loaded into the uncompiled module: torch.compile below returns
# a wrapper module, so the order load -> eval -> compile matters.
tex_model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))

tex_model.eval()

tex_model = torch.compile(tex_model)

print('=' * 70)
print('Done!')
print('=' * 70)

# =================================================================================================

print('Loading chords progressions model...')
print('=' * 70)

# Progression-model settings. These REBIND the module globals that were set for
# the texturing model; generate_chords hard-codes 380 / 2048 instead of reading
# SEQ_LEN, so the rebinding does not affect generation lengths there.
SEQ_LEN = 380
PAD_IDX = 324
DEVICE = 'cuda'

# Decoder-only transformer over token ids 0..PAD_IDX (325 ids). generate_chords
# treats 0..320 as chord ids, uses 321 / 322 to bracket the seed chords and 323
# as EOS; 324 is the pad id.
prg_model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 2048,
                          depth = 6,
                          heads = 16,
                          rotary_pos_emb = True,
                          attn_flash = True
                         )
    )

# The autoregressive wrapper supplies .generate(); the pad id is passed as its
# ignore_index.
prg_model = AutoregressiveWrapper(prg_model, ignore_index=PAD_IDX)

prg_model.to(DEVICE)
print('=' * 70)

print('Loading model checkpoint...')

# Download the trained weights from the Hub (hf_hub_download reuses the local
# cache on subsequent starts).
checkpoint = hf_hub_download(
    repo_id='asigalov61/Chordified-Piano-Transformer',
    filename='Chordified_Piano_Transformer_Chords_Progressions_Trained_Model_3569_steps_1.8604_loss_0.4727_acc.pth'
)

# Load into the uncompiled module first; torch.compile returns a wrapper module,
# so the order load -> eval -> compile matters.
prg_model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))

prg_model.eval()

prg_model = torch.compile(prg_model)

print('=' * 70)

# =================================================================================================

# Mixed-precision setup shared by both models: generate_chords enters this one
# autocast context around each model.generate() call.
dtype = torch.bfloat16
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

print('Done!')
print('=' * 70)

# =================================================================================================

print('Loading SoundFont...')

# SoundFont used by midi_to_colab_audio to render the generated MIDI to audio;
# fetched from a Hub dataset repo (cached locally after the first download).
SOUNDFONT_PATH = hf_hub_download(
    repo_id='projectlosangeles/soundfonts4u',
    filename='SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2',
    repo_type='dataset',
)

print('Done!')
print('=' * 70)

# =================================================================================================

def _truncate_at_eos(tokens, eos_token):
    """Return ``tokens`` up to (not including) the first ``eos_token``.

    The list is returned unchanged when ``eos_token`` does not occur in it.
    """
    if eos_token in tokens:
        return tokens[:tokens.index(eos_token)]
    return tokens


@spaces.GPU
def generate_chords(chords,
                    input_temperature,
                    input_top_p_value
                   ):
    """Generate a chords progression from seed chords, then texture it.

    Stage 1 samples 256 candidate progressions from ``prg_model`` and keeps
    the one with the most distinct tokens. Stage 2 feeds that progression to
    ``tex_model``, which produces the textured (note-level) token stream.

    Args:
        chords: List of seed chord ids (indices into TMIDIX.ALL_CHORDS_SORTED,
            as built by Generate_Chords). May be empty.
        input_temperature: Sampling temperature for both models.
        input_top_p_value: top-p threshold passed to the ``top_p`` logits filter.

    Returns:
        ``(cho_prg, score)``:
        - ``cho_prg``: the selected progression as chord ids in 0..320.
        - ``score``: the texturing model's generated tokens, without the prime,
          cut at its EOS token. Decode with tokens_to_escore_notes.
    """

    print('*' * 70)
    print('Generating chords progression...')

    # 321 / 322 bracket the seed chords; 323 is the progression model's EOS.
    prg_prime = [321] + chords + [322]

    # Sample 256 candidate continuations of the same prime in one batch.
    x = torch.LongTensor([prg_prime] * 256).to(DEVICE)

    with ctx:
        # Fill the progression model's 380-token context exactly.
        out = prg_model.generate(x,
                                 380-len(prg_prime),
                                 temperature=input_temperature,
                                 filter_logits_fn=top_p,
                                 filter_kwargs={'thres': input_top_p_value},
                                 eos_token=323,
                                 return_prime=False,
                                 verbose=True
                                )

        out = out.tolist()

    # Everything after EOS is filler. generate() may pad rows that finished
    # early, and a pad value of 0 would be indistinguishable from chord id 0.
    # Drop it before ranking candidates.
    out = [_truncate_at_eos(o, 323) for o in out]

    # Prefer candidates with at least as many distinct tokens as the prime.
    # Among those (or among all candidates if none qualify), take the first one
    # with the most distinct tokens.
    good_outs = [o for o in out if len(set(o)) >= len(prg_prime)]

    cho_prg = max(good_outs or out, key=lambda o: len(set(o)))

    # Keep only chord ids.
    cho_prg = [c for c in cho_prg if 0 <= c < 321]

    # The texturing vocabulary encodes these 12 chord ids as tokens 128..139
    # (in this order) and every other chord id c as token c + 140.
    ncho = [0, 89, 178, 233, 267, 288, 301, 309, 314, 317, 319, 320]
    ncho_tokens = {c: i + 128 for i, c in enumerate(ncho)}

    inp_cho_prg = [ncho_tokens.get(c, c + 140) for c in cho_prg]

    print('Done!')
    print('*' * 70)
    print('Number of good chords progressions:', len(good_outs))
    print('*' * 70)
    print('Texturing selected generated chords progression...')

    # 718 / 719 bracket the chords; 720 is the texturing model's EOS.
    tex_prime = [718] + inp_cho_prg + [719]

    x = torch.LongTensor(tex_prime).to(DEVICE)

    with ctx:
        # Cap prime + generated tokens at the texturing model's 2048-token
        # context. The prime is len(cho_prg) + 2 tokens long.
        out = tex_model.generate(x,
                                 2048-len(tex_prime),
                                 temperature=input_temperature,
                                 filter_logits_fn=top_p,
                                 filter_kwargs={'thres': input_top_p_value},
                                 eos_token=720,
                                 return_prime=False,
                                 verbose=True
                                )

        score = out.tolist()

    score = _truncate_at_eos(score, 720)

    print('Done!')
    print('=' * 70)

    return cho_prg, score
                       
# =================================================================================================

def tokens_to_escore_notes(tokens):
    """Decode texturing-model tokens into a list of escore notes.

    Token ranges handled (every other token, e.g. chord or control tokens,
    is ignored):

    * ``0..127``   - time shift: advance the running time by the token value.
    * ``461..588`` - pitch: set the current pitch to ``token - 461``.
    * ``589..716`` - duration: emit a note at the current time and pitch
      with duration ``token - 589``.

    Each note is emitted as ``['note', time, dur, 0, pitch, max(40, pitch), 0]``.
    This presumably follows TMIDIX's escore layout
    ``[.., channel, pitch, velocity, patch]``.

    Args:
        tokens: Iterable of integer tokens.

    Returns:
        List of note lists in token order (empty for empty input).
    """
    song_f = []

    time = 0
    pitch = 60  # used if a duration token arrives before any pitch token

    for m in tokens:
        if 0 <= m < 128:
            time += m

        # Lower bounds are inclusive. Tokens 461 and 589 encode pitch 0 and
        # duration 0. A strict bound would silently drop them, and the stale
        # previous pitch would then be reused.
        elif 461 <= m < 589:
            pitch = m - 461

        elif 589 <= m < 717:
            song_f.append(['note', time, m - 589, 0, pitch, max(40, pitch), 0])

    return song_f

# =================================================================================================

def Generate_Chords(input_example,
                    input_chords,
                    input_temperature,
                    input_top_p_value
                   ):
    """Gradio callback: generate, texture, render and plot a chords progression.

    Args:
        input_example: Name selected in the "Example chords progressions"
            dropdown. Used only when no custom chords are selected.
        input_chords: Chord labels from the multiselect dropdown (entries of
            ``chords_labels``, i.e. '-'-joined ints), or None / an empty list
            when nothing is selected.
        input_temperature: Sampling temperature forwarded to both models.
        input_top_p_value: top-p threshold forwarded to both models.

    Returns:
        ``(output_gen_chords, output_audio, output_plot, output_midi)``:
        newline-joined generated chords, ``(16000, audio)`` for gr.Audio, the
        plot returned by TMIDIX.plot_ms_SONG, and the written .mid file name.
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()

    print('=' * 70)
    print('Input example:', input_example)
    print('Input chords:', input_chords)
    print('Req model temp:', input_temperature)
    print('Req top_p value:', input_top_p_value)
    print('=' * 70)

    # A cleared multiselect dropdown can deliver [] instead of None. Test
    # truthiness so that only a non-empty selection overrides the example.
    if input_chords:
        chords = [TMIDIX.ALL_CHORDS_SORTED.index([int(t) for t in c.split('-')])
                  for c in input_chords]

    elif input_example == 'Blue Bird':
        chords = blue_bird

    elif input_example == 'Come To My Window':
        chords = come_to_my_window

    else:
        chords = sharing_the_night_together

    print('There are', len(chords), 'chords')
    print('Sample chords:', chords[:5])
    print('=' * 70)

    #===============================================================================

    print('Generating...')

    cho_prg, score = generate_chords(chords,
                                     input_temperature,
                                     input_top_p_value
                                    )

    final_chords = [TMIDIX.ALL_CHORDS_SORTED[c] for c in cho_prg]

    final_score = tokens_to_escore_notes(score)
    final_score = TMIDIX.remove_duplicate_pitches_from_escore_notes(final_score)
    final_score = TMIDIX.fix_escore_notes_durations(final_score,
                                                    min_notes_gap=0
                                                   )
    final_score = TMIDIX.humanize_velocities_in_escore_notes(final_score)

    #===============================================================================
    print('Rendering results...')
    print('=' * 70)

    # Timestamped (to 1/10 ms) base name so concurrent requests don't collide.
    now = datetime.datetime.now(PDT)
    ms4 = now.strftime("%f")[:4]

    fn1 = (
        'Chords-Progressions-Transformer-Composition-'
        + now.strftime(f"%Y-%m-%d-%H-%M-%S-{ms4}")
    )

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(final_score)

    # Writes fn1 + '.mid' to the working directory.
    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                             output_signature = 'Chords Progressions Transformer',
                                             output_file_name = fn1,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches,
                                             timings_multiplier=32
                                            )

    new_fn = fn1 + '.mid'

    sample_rate = 16000

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=SOUNDFONT_PATH,
                                sample_rate=sample_rate,
                                output_for_gradio=True
                               )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_gen_chords = '\n'.join(str(c) for c in final_chords)
    output_midi = str(new_fn)
    output_audio = (sample_rate, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score,
                                      timings_multiplier=32,
                                      plot_title=output_midi,
                                      return_plt=True
                                     )

    # Log the first three chords (not the first three characters of the text).
    print('Output gen chords:', final_chords[:3])
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_gen_chords, output_audio, output_plot, output_midi
        
# =================================================================================================

# Dropdown labels: each chord in TMIDIX.ALL_CHORDS_SORTED (a list of ints)
# rendered as its values joined by '-'. Generate_Chords parses them back
# with split('-').
chords_labels = ['-'.join(map(str, chord)) for chord in TMIDIX.ALL_CHORDS_SORTED]

# =================================================================================================

# Example seed progressions for the "Example chords progressions" dropdown.
# Values are chord ids, i.e. indices into TMIDIX.ALL_CHORDS_SORTED, the same
# form Generate_Chords builds from custom selections. They stay lists because
# generate_chords concatenates them with other lists.

blue_bird = [
    89, 301, 314, 317, 114, 280, 110, 318, 194, 221, 320, 187, 270, 191, 303, 162,
    298, 178, 181, 308, 267, 283, 94, 104, 96, 272, 215, 288, 296, 105, 102, 144,
    91, 284, 273, 227, 316, 164, 189, 147, 281, 268, 179, 186, 213, 159, 165, 92,
    188, 150, 218, 112, 309, 285, 302, 217, 290, 306, 148, 1, 310, 4, 289, 0, 307,
    119, 212, 117, 233, 254, 34, 3, 180, 319,
]

come_to_my_window = [
    16, 0, 13, 216, 309, 178, 194, 301, 192, 317, 320, 1, 191, 89, 319, 314, 288,
    267, 195, 282, 280, 183, 18, 14, 181, 179, 215, 303, 184, 213, 306, 37, 272,
    310, 34, 228, 212, 312, 227, 97, 308,
]

sharing_the_night_together = [
    267, 270, 281, 314, 148, 280, 146, 102, 233, 316, 283, 320, 10, 104, 235, 117,
    89, 91, 256, 285, 21, 159, 112, 264, 301, 238, 303, 0, 263, 94, 144, 110, 153,
    103, 106, 170, 95, 171, 119, 268, 90, 108, 114, 317, 306, 136, 134, 254, 307,
    302, 31, 284, 319, 258, 243, 272,
]

# =================================================================================================
  
# US Pacific timezone used for request/app timestamps and output file names.
PDT = timezone('US/Pacific')

print('=' * 70)
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
print('=' * 70)

# NOTE(review): not referenced elsewhere in this file (audio is rendered with
# SOUNDFONT_PATH); kept so any external reference keeps working.
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

app = gr.Blocks()
with app:
    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Chords Progressions Transformer</h1>")
    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Generate and texture unique chords progressions</h1>")

    gr.Markdown("## Select example chords progression or create your own")

    input_example = gr.Dropdown(label="Example chords progressions",
                                choices=['Blue Bird', 'Come To My Window', 'Sharing The Night Together'],
                                value='Blue Bird',
                                info='NOTE: Selecting custom chords below will override example selection'
                               )

    input_chords = gr.Dropdown(label="Desired chords to generate",
                               choices=chords_labels,
                               value=None,
                               multiselect=True,
                               info='NOTE: Selected chords will be introduced into generated chords progression in order of selection'
                              )
    input_temperature = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Model temperature")
    input_top_p_value = gr.Slider(0.1, 1.0, value=0.96, step=0.01, label="Model sampling top_p value")

    run_btn = gr.Button("generate", variant="primary")

    gr.Markdown("## Generation results")

    output_gen_chords = gr.Textbox(label="Generated chords list", lines=7)
    output_audio = gr.Audio(label="Output MIDI audio", format="mp3", elem_id="midi_audio")
    output_plot = gr.Plot(label="Output MIDI score plot")
    output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

    # Inputs/outputs are positional and must match Generate_Chords' parameters
    # and return tuple order.
    run_event = run_btn.click(Generate_Chords,
                              [input_example,
                               input_chords,
                               input_temperature,
                               input_top_p_value
                              ],
                              [output_gen_chords,
                               output_audio,
                               output_plot,
                               output_midi
                              ])

# Launch after the `with app:` context has exited, so the whole layout and event
# wiring above are finalized first (the standard Gradio Blocks pattern).
app.launch(mcp_server=True)