| |
| |
| |
|
|
# Startup banner for the application log: rule, title, rule.
print('=' * 70, 'Chords Progressions Transformer Gradio App', '=' * 70, sep='\n')
|
|
| import os |
|
|
# Set before torch / huggingface_hub are imported further down in this file.
os.environ.update(
    HF_HUB_ENABLE_HF_TRANSFER='1',
    USE_FLASH_ATTENTION='1',
)
|
|
| import time as reqtime |
| from pytz import timezone |
|
|
| import torch |
|
|
| torch.set_float32_matmul_precision('high') |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
| torch.backends.cuda.enable_mem_efficient_sdp(True) |
| torch.backends.cuda.enable_math_sdp(True) |
| torch.backends.cuda.enable_flash_sdp(True) |
| torch.backends.cuda.enable_cudnn_sdp(True) |
|
|
| import spaces |
| import gradio as gr |
|
|
| from x_transformer_2_3_1 import * |
|
|
| import datetime |
| import random |
| import tqdm |
|
|
| from midi_to_colab_audio import midi_to_colab_audio |
| import TMIDIX |
|
|
| import matplotlib.pyplot as plt |
|
|
| from huggingface_hub import hf_hub_download |
| |
| |
|
|
print('=' * 70, 'Loading models...', '=' * 70, sep='\n')
print('Loading chords texturing model...')
print('=' * 70)

# Texturing-model settings. These module globals are reassigned further down
# when the progressions model is configured.
SEQ_LEN = 2048
PAD_IDX = 721
DEVICE = 'cuda'

# Decoder-only transformer wrapped for autoregressive sampling; the vocabulary
# is sized so that PAD_IDX is its last id.
tex_model = AutoregressiveWrapper(
    TransformerWrapper(
        num_tokens=PAD_IDX + 1,
        max_seq_len=SEQ_LEN,
        attn_layers=Decoder(
            dim=2048,
            depth=12,
            heads=16,
            rotary_pos_emb=True,
            attn_flash=True,
        ),
    ),
    ignore_index=PAD_IDX,
)

tex_model.to(DEVICE)
print('=' * 70)

print('Loading model checkpoint...')

checkpoint = hf_hub_download(
    repo_id='asigalov61/Chordified-Piano-Transformer',
    filename='Chordified_Piano_Transformer_Texturing_Trained_Model_18092_steps_0.7058_loss_0.7977_acc.pth',
)

# weights_only=True restricts unpickling to tensor data.
tex_model.load_state_dict(
    torch.load(checkpoint, map_location=DEVICE, weights_only=True)
)
tex_model.eval()

# The compiled module replaces the eager one under the same name.
tex_model = torch.compile(tex_model)

print('=' * 70, 'Done!', '=' * 70, sep='\n')
|
|
| |
|
|
print('Loading chords progressions model...')
print('=' * 70)

# Progressions-model settings; these rebind the globals used for the texturing
# model above.
SEQ_LEN = 380
PAD_IDX = 324
DEVICE = 'cuda'

# Same architecture family as the texturing model, with a shallower decoder.
prg_model = AutoregressiveWrapper(
    TransformerWrapper(
        num_tokens=PAD_IDX + 1,
        max_seq_len=SEQ_LEN,
        attn_layers=Decoder(
            dim=2048,
            depth=6,
            heads=16,
            rotary_pos_emb=True,
            attn_flash=True,
        ),
    ),
    ignore_index=PAD_IDX,
)

prg_model.to(DEVICE)
print('=' * 70)

print('Loading model checkpoint...')

checkpoint = hf_hub_download(
    repo_id='asigalov61/Chordified-Piano-Transformer',
    filename='Chordified_Piano_Transformer_Chords_Progressions_Trained_Model_3569_steps_1.8604_loss_0.4727_acc.pth',
)

prg_model.load_state_dict(
    torch.load(checkpoint, map_location=DEVICE, weights_only=True)
)
prg_model.eval()

prg_model = torch.compile(prg_model)

print('=' * 70)

# Shared bfloat16 autocast context entered around every generate() call.
dtype = torch.bfloat16
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

print('Done!', '=' * 70, sep='\n')
|
|
| |
|
|
print('Loading SoundFont...')

# Local path of the SoundFont file; passed to midi_to_colab_audio when
# rendering audio.
SOUNDFONT_PATH = hf_hub_download(
    repo_id='projectlosangeles/soundfonts4u',
    repo_type='dataset',
    filename='SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2',
)

print('Done!', '=' * 70, sep='\n')
|
|
| |
|
|
@spaces.GPU
def generate_chords(chords,
                    input_temperature,
                    input_top_p_value
                    ):
    """Generate a chords progression from seed chords, then texture it.

    Stage 1 wraps ``chords`` in tokens 321/322, samples 256 continuations
    from ``prg_model`` in one batch, and keeps the candidate with the most
    distinct tokens. Candidates with at least as many distinct tokens as the
    prompt is long are preferred; if none qualify, the best of all
    candidates is used.

    Stage 2 remaps the chosen chord ids into the texturing vocabulary, wraps
    them in tokens 718/719, and samples a score from ``tex_model``.

    Args:
        chords: List of integer chord ids (the caller passes indices into
            ``TMIDIX.ALL_CHORDS_SORTED``).
        input_temperature: Sampling temperature used for both models.
        input_top_p_value: Threshold passed to the ``top_p`` logits filter.

    Returns:
        A ``(cho_prg, score)`` tuple: the selected progression as a list of
        chord ids in ``[0, 321)``, and the textured output as a flat list of
        tokens.
    """
    # Context sizes of the two models. Kept local because the module-level
    # SEQ_LEN is rebound while the models are loaded.
    prg_max_len = 380
    tex_max_len = 2048
    num_candidates = 256

    print('*' * 70)
    print('Generating chords progression...')

    prg_prompt = [321] + chords + [322]

    x = torch.LongTensor([prg_prompt] * num_candidates).to(DEVICE)

    with ctx:
        out = prg_model.generate(x,
                                 prg_max_len - len(prg_prompt),
                                 temperature=input_temperature,
                                 filter_logits_fn=top_p,
                                 filter_kwargs={'thres': input_top_p_value},
                                 eos_token=323,
                                 return_prime=False,
                                 verbose=True
                                 )

    candidates = out.tolist()

    good_outs = [o for o in candidates if len(set(o)) >= len(prg_prompt)]

    # max() returns the first maximal element, matching the stable
    # sort-descending-then-take-first selection it replaces.
    cho_prg = max(good_outs or candidates, key=lambda o: len(set(o)))

    # Drop everything that is not a chord id (start/end/eos/pad tokens).
    cho_prg = [c for c in cho_prg if 0 <= c < 321]

    # Chord ids listed here map to texturing tokens 128..139 by position;
    # every other chord id maps to id + 140.
    # NOTE(review): presumably the twelve single-pitch chords - confirm.
    ncho = [0, 89, 178, 233, 267, 288, 301, 309, 314, 317, 319, 320]
    ncho_tokens = {c: i + 128 for i, c in enumerate(ncho)}

    inp_cho_prg = [ncho_tokens.get(c, c + 140) for c in cho_prg]

    print('Done!')
    print('*' * 70)
    print('Number of good chords progressions:', len(good_outs))
    print('*' * 70)
    print('Texturing selected generated chords progression...')

    tex_prompt = [718] + inp_cho_prg + [719]

    x = torch.LongTensor(tex_prompt).to(DEVICE)

    with ctx:
        out = tex_model.generate(x,
                                 # Subtract the whole prompt from the budget.
                                 # The old `2048-len(cho_prg)+2` added the two
                                 # framing tokens instead of subtracting them,
                                 # overshooting the model context by 4 tokens.
                                 tex_max_len - len(tex_prompt),
                                 temperature=input_temperature,
                                 filter_logits_fn=top_p,
                                 filter_kwargs={'thres': input_top_p_value},
                                 eos_token=720,
                                 return_prime=False,
                                 verbose=True
                                 )

    score = out.tolist()

    print('Done!')
    print('=' * 70)

    return cho_prg, score
| |
| |
|
|
def tokens_to_escore_notes(tokens):
    """Decode a flat token sequence into a list of note events.

    Token ranges handled (all others are ignored):
        0..127    advance the running start time by the token value
        462..588  set the current pitch to ``token - 461``
        590..716  set the duration to ``token - 589`` and emit a note

    Each emitted note is
    ``['note', time, dur, 0, pitch, max(40, pitch), 0]``.
    NOTE(review): the layout presumably follows the TMIDIX escore convention
    (start, duration, channel, pitch, velocity, patch) - confirm.

    Args:
        tokens: Iterable of integer tokens.

    Returns:
        List of note events in the order their duration tokens appear. The
        pitch defaults to 60 until the first pitch token is seen.
    """
    song_f = []

    time = 0
    pitch = 60

    for m in tokens:

        if 0 <= m < 128:
            time += m

        elif 461 < m < 589:
            pitch = m - 461

        elif 589 < m < 717:
            dur = m - 589
            # The sixth field is derived from the pitch, floored at 40.
            song_f.append(['note', time, dur, 0, pitch, max(40, pitch), 0])

    return song_f
|
|
| |
|
|
def Generate_Chords(input_example,
                    input_chords,
                    input_temperature,
                    input_top_p_value
                    ):
    """Gradio handler: generate, texture and render a chords progression.

    Seed chords come from ``input_chords`` when any are selected; otherwise
    from the example progression named by ``input_example``. The generated
    score is cleaned up with TMIDIX, written to a ``.mid`` file in the
    working directory, rendered to audio and plotted.

    Args:
        input_example: Name of the example progression. Any value other than
            'Blue Bird' or 'Come To My Window' selects
            ``sharing_the_night_together``.
        input_chords: Chord labels of the form ``"p1-p2-..."`` as built in
            ``chords_labels``, or ``None``/empty to use the example.
        input_temperature: Sampling temperature forwarded to the models.
        input_top_p_value: top_p threshold forwarded to the models.

    Returns:
        ``(output_gen_chords, output_audio, output_plot, output_midi)``:
        newline-joined generated chords, a ``(16000, audio)`` tuple, the
        score plot, and the MIDI file name.

    Raises:
        ValueError: If a label in ``input_chords`` is not a chord present in
            ``TMIDIX.ALL_CHORDS_SORTED`` or is not dash-separated integers.
    """
    sample_rate = 16000

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()

    print('=' * 70)
    print('Input example:', input_example)
    print('Input chords:', input_chords)
    print('Req model temp:', input_temperature)
    print('Req top_p value:', input_top_p_value)
    print('=' * 70)

    # Truthiness rather than `is not None`: a multiselect dropdown with
    # nothing chosen may deliver an empty list, which must also fall back to
    # the example instead of generating from zero chords.
    if input_chords:
        chords = []

        for c in input_chords:
            cho = [int(t) for t in c.split('-')]
            chords.append(TMIDIX.ALL_CHORDS_SORTED.index(cho))

    else:
        examples = {
            'Blue Bird': blue_bird,
            'Come To My Window': come_to_my_window,
        }
        chords = examples.get(input_example, sharing_the_night_together)

        print('There are', len(chords), 'chords')
        print('Sample chords:', chords[:5])
        print('=' * 70)

    print('Sample chords', chords[:2])
    print('=' * 70)
    print('Generating...')

    cho_prg, score = generate_chords(chords,
                                     input_temperature,
                                     input_top_p_value
                                     )

    final_chords = [TMIDIX.ALL_CHORDS_SORTED[c] for c in cho_prg]
    final_score = tokens_to_escore_notes(score)

    final_score = TMIDIX.remove_duplicate_pitches_from_escore_notes(final_score)

    final_score = TMIDIX.fix_escore_notes_durations(final_score,
                                                    min_notes_gap=0
                                                    )

    final_score = TMIDIX.humanize_velocities_in_escore_notes(final_score)

    print('Rendering results...')
    print('=' * 70)

    # Timestamped base name; the first four microsecond digits are appended
    # after the seconds.
    now = datetime.datetime.now(PDT)
    ms4 = now.strftime("%f")[:4]

    fn1 = (
        'Chords-Progressions-Transformer-Composition-'
        + now.strftime(f"%Y-%m-%d-%H-%M-%S-{ms4}")
    )

    output_score, patches, _overflow_patches = TMIDIX.patch_enhanced_score_notes(final_score)

    # Called for its side effect of writing `fn1 + '.mid'`; the returned
    # stats are not used.
    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                             output_signature = 'Chords Progressions Transformer',
                                             output_file_name = fn1,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches,
                                             timings_multiplier=32
                                             )

    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=SOUNDFONT_PATH,
                                sample_rate=sample_rate,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    output_gen_chords = '\n'.join(str(c) for c in final_chords)
    output_midi = str(new_fn)
    output_audio = (sample_rate, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score,
                                      timings_multiplier=32,
                                      plot_title=output_midi,
                                      return_plt=True
                                      )

    # Log the first three chords; slicing the joined string would only show
    # its first three characters.
    print('Output gen chords:', final_chords[:3])
    print('=' * 70)

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_gen_chords, output_audio, output_plot, output_midi
| |
| |
|
|
# Dropdown labels: each chord in TMIDIX.ALL_CHORDS_SORTED rendered as its
# members joined with '-'. Generate_Chords parses labels of this form back
# into chord indices.
chords_labels = ['-'.join(map(str, chord)) for chord in TMIDIX.ALL_CHORDS_SORTED]
|
|
| |
|
|
# Built-in example progressions, stored as chord ids. Generate_Chords selects
# one by name when no custom chords are chosen.
blue_bird = [
    89, 301, 314, 317, 114, 280, 110, 318, 194, 221, 320, 187, 270, 191, 303, 162,
    298, 178, 181, 308, 267, 283, 94, 104, 96, 272, 215, 288, 296, 105, 102, 144,
    91, 284, 273, 227, 316, 164, 189, 147, 281, 268, 179, 186, 213, 159, 165, 92,
    188, 150, 218, 112, 309, 285, 302, 217, 290, 306, 148, 1, 310, 4, 289, 0, 307,
    119, 212, 117, 233, 254, 34, 3, 180, 319,
]

come_to_my_window = [
    16, 0, 13, 216, 309, 178, 194, 301, 192, 317, 320, 1, 191, 89, 319, 314, 288,
    267, 195, 282, 280, 183, 18, 14, 181, 179, 215, 303, 184, 213, 306, 37, 272,
    310, 34, 228, 212, 312, 227, 97, 308,
]

sharing_the_night_together = [
    267, 270, 281, 314, 148, 280, 146, 102, 233, 316, 283, 320, 10, 104, 235, 117,
    89, 91, 256, 285, 21, 159, 112, 264, 301, 238, 303, 0, 263, 94, 144, 110, 153,
    103, 106, 170, 95, 171, 119, 268, 90, 108, 114, 317, 306, 136, 134, 254, 307,
    302, 31, 284, 319, 258, 243, 272,
]
|
|
| |
| |
# Timezone used for every timestamp this app logs and for output file names.
PDT = timezone('US/Pacific')

print(
    '=' * 70,
    'App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)),
    '=' * 70,
    sep='\n',
)

soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
|
|
# Gradio UI: inputs, a generate button, and the four outputs that
# Generate_Chords returns.
with gr.Blocks() as app:

    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Chords Progressions Transformer</h1>")
    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Generate and texture unique chords progressions</h1>")

    gr.Markdown("## Select example chords progression or create your own")

    input_example = gr.Dropdown(
        label="Example chords progressions",
        choices=['Blue Bird', 'Come To My Window', 'Sharing The Night Together'],
        value='Blue Bird',
        info='NOTE: Selecting custom chords below will override example selection',
    )

    input_chords = gr.Dropdown(
        label="Desired chords to generate",
        choices=chords_labels,
        value=None,
        multiselect=True,
        info='NOTE: Selected chords will be introduced into generated chords progression in order of selection',
    )

    input_temperature = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Model temperature")
    input_top_p_value = gr.Slider(0.1, 1.0, value=0.96, step=0.01, label="Model sampling top_p value")

    run_btn = gr.Button("generate", variant="primary")

    gr.Markdown("## Generation results")

    output_gen_chords = gr.Textbox(label="Generated chords list", lines=7)
    output_audio = gr.Audio(label="Output MIDI audio", format="mp3", elem_id="midi_audio")
    output_plot = gr.Plot(label="Output MIDI score plot")
    output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

    # Input and output order must match Generate_Chords' parameters and its
    # returned tuple.
    run_event = run_btn.click(
        Generate_Chords,
        [input_example, input_chords, input_temperature, input_top_p_value],
        [output_gen_chords, output_audio, output_plot, output_midi],
    )

# NOTE(review): the flattened source does not show whether launch() was nested
# inside the Blocks context; it is placed after the block here.
app.launch(mcp_server=True)