# =================================================================================
# Chords Progressions Transformer -- Gradio app
# https://huggingface.co/spaces/projectlosangeles/Chords-Progressions-Transformer
# =================================================================================
#
# Start-up section: sets environment flags, configures torch float32-matmul and
# scaled-dot-product-attention backends, imports dependencies and loads the
# chords texturing model (tex_model) onto the GPU.

_RULE = '=' * 70

print(_RULE)
print('Chords Progressions Transformer Gradio App')
print(_RULE)

import os

# Set before the imports below so that libraries which read these variables
# see them at import time.
# NOTE(review): presumably read by huggingface_hub / the bundled x_transformer
# module respectively -- confirm.
os.environ.update({
    "HF_HUB_ENABLE_HF_TRANSFER": "1",
    'USE_FLASH_ATTENTION': '1',
})

import time as reqtime
from pytz import timezone

import torch

# Allow reduced-precision (TF32) float32 matmuls on CUDA.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Turn on every scaled-dot-product-attention backend (same order as before).
for _enable_sdp_backend in (
    torch.backends.cuda.enable_mem_efficient_sdp,
    torch.backends.cuda.enable_math_sdp,
    torch.backends.cuda.enable_flash_sdp,
    torch.backends.cuda.enable_cudnn_sdp,
):
    _enable_sdp_backend(True)

import spaces
import gradio as gr

# Star import: TransformerWrapper, Decoder and AutoregressiveWrapper used below
# are not imported anywhere else, so they come from this module.
from x_transformer_2_3_1 import *

import datetime
import random
import tqdm

from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX

import matplotlib.pyplot as plt

from huggingface_hub import hf_hub_download

# =================================================================================================

print(_RULE)
print('Loading models...')
print(_RULE)
print('Loading chords texturing model...')
print(_RULE)

SEQ_LEN = 2048   # context window of the texturing model
PAD_IDX = 721    # padding id; the vocabulary size is PAD_IDX + 1
DEVICE = 'cuda'

_tex_decoder = Decoder(dim=2048, depth=12, heads=16, rotary_pos_emb=True, attn_flash=True)
_tex_network = TransformerWrapper(num_tokens=PAD_IDX + 1,
                                  max_seq_len=SEQ_LEN,
                                  attn_layers=_tex_decoder)

tex_model = AutoregressiveWrapper(_tex_network, ignore_index=PAD_IDX)
tex_model.to(DEVICE)

print(_RULE)
print('Loading model checkpoint...')

checkpoint = hf_hub_download(
    repo_id='asigalov61/Chordified-Piano-Transformer',
    filename='Chordified_Piano_Transformer_Texturing_Trained_Model_18092_steps_0.7058_loss_0.7977_acc.pth',
)

# Load straight into the model (not via a named global) so no second copy of
# the weights stays alive on the device.
tex_model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))
tex_model.eval()
tex_model = torch.compile(tex_model)

print(_RULE)
print('Done!')
print('=' * 70)

# =================================================================================================

print('Loading chords progressions model...')
print('=' * 70)

# SEQ_LEN / PAD_IDX / DEVICE are reassigned per model at module level; the
# generation code below uses the explicit per-model names instead.
SEQ_LEN = 380
PAD_IDX = 324
DEVICE = 'cuda'

PRG_SEQ_LEN = SEQ_LEN   # context window of the chords progressions model
TEX_SEQ_LEN = 2048      # context window of the texturing model (must match tex_model)

prg_model = TransformerWrapper(
    num_tokens=PAD_IDX + 1,
    max_seq_len=SEQ_LEN,
    attn_layers=Decoder(dim=2048,
                        depth=6,
                        heads=16,
                        rotary_pos_emb=True,
                        attn_flash=True
                        )
)

prg_model = AutoregressiveWrapper(prg_model, ignore_index=PAD_IDX)

prg_model.to(DEVICE)
print('=' * 70)

print('Loading model checkpoint...')

checkpoint = hf_hub_download(
    repo_id='asigalov61/Chordified-Piano-Transformer',
    filename='Chordified_Piano_Transformer_Chords_Progressions_Trained_Model_3569_steps_1.8604_loss_0.4727_acc.pth'
)

prg_model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))
prg_model.eval()
prg_model = torch.compile(prg_model)

print('=' * 70)

# =================================================================================================

dtype = torch.bfloat16

# Mixed-precision context used around both models' generate() calls.
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

print('Done!')
print('=' * 70)

# =================================================================================================

print('Loading SoundFont...')

SOUNDFONT_PATH = hf_hub_download(repo_id='projectlosangeles/soundfonts4u',
                                 repo_type='dataset',
                                 filename='SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
                                 )

print('Done!')
print('=' * 70)

# =================================================================================================
# Special token ids, as used by generate_chords().

# Chords progressions model.
PRG_PRIME_START = 321   # prepended to the priming chord indices
PRG_PRIME_END = 322     # appended to the priming chord indices
PRG_EOS = 323           # passed to generate() as eos_token
PRG_NUM_CHORDS = 321    # generated ids in [0, PRG_NUM_CHORDS) are chord indices

# Chords texturing model.
TEX_PRIME_START = 718
TEX_PRIME_END = 719
TEX_EOS = 720

# Chord indices the texturing model encodes as 128 + (position in this list);
# every other chord index c is encoded as c + 140.
TEX_SPECIAL_CHORDS = [0, 89, 178, 233, 267, 288, 301, 309, 314, 317, 319, 320]
_TEX_SPECIAL_POS = {c: i for i, c in enumerate(TEX_SPECIAL_CHORDS)}

# Sampled candidate progressions per request; the most varied one is kept.
PRG_BATCH_SIZE = 256

# =================================================================================================

@spaces.GPU
def generate_chords(chords,
                    input_temperature,
                    input_top_p_value
                    ):
    """Generate a chords progression from priming chords, then texture it.

    Args:
        chords: chord indices (into TMIDIX.ALL_CHORDS_SORTED) used as the prime.
        input_temperature: sampling temperature for both models.
        input_top_p_value: top-p threshold passed to the ``top_p`` logits filter.

    Returns:
        (cho_prg, score): the selected progression as chord indices in
        [0, PRG_NUM_CHORDS), and the texturing model's generated tokens
        (decoded by tokens_to_escore_notes).
    """
    print('*' * 70)
    print('Generating chords progression...')

    prime = [PRG_PRIME_START] + chords + [PRG_PRIME_END]

    # The same prime repeated across a batch: sample many candidates at once.
    x = torch.LongTensor([prime] * PRG_BATCH_SIZE).to(DEVICE)

    with ctx:
        out = prg_model.generate(x,
                                 PRG_SEQ_LEN - len(prime),
                                 temperature=input_temperature,
                                 filter_logits_fn=top_p,
                                 filter_kwargs={'thres': input_top_p_value},
                                 eos_token=PRG_EOS,
                                 return_prime=False,
                                 verbose=True
                                 )

    candidates = out.tolist()

    # "Good" candidates contain at least as many distinct tokens as the prime.
    good_outs = [o for o in candidates if len(set(o)) >= len(prime)]

    # Keep the candidate with the most distinct tokens (first one on ties).
    pool = good_outs if good_outs else candidates
    cho_prg = max(pool, key=lambda o: len(set(o)))

    cho_prg = [c for c in cho_prg if 0 <= c < PRG_NUM_CHORDS]

    # Re-encode chord indices into the texturing model's vocabulary.
    inp_cho_prg = [_TEX_SPECIAL_POS[c] + 128 if c in _TEX_SPECIAL_POS else c + 140
                   for c in cho_prg]

    print('Done!')
    print('*' * 70)
    print('Number of good chords progressions:', len(good_outs))
    print('*' * 70)

    print('Texturing selected generated chords progression...')

    tex_prime = [TEX_PRIME_START] + inp_cho_prg + [TEX_PRIME_END]
    x = torch.LongTensor(tex_prime).to(DEVICE)

    with ctx:
        # Budget = window minus the whole prime (both markers included), so
        # prime + continuation never exceeds the texturing model's context.
        out = tex_model.generate(x,
                                 TEX_SEQ_LEN - len(tex_prime),
                                 temperature=input_temperature,
                                 filter_logits_fn=top_p,
                                 filter_kwargs={'thres': input_top_p_value},
                                 eos_token=TEX_EOS,
                                 return_prime=False,
                                 verbose=True
                                 )

    score = out.tolist()

    print('Done!')
    print('=' * 70)

    return cho_prg, score

# =================================================================================================

def tokens_to_escore_notes(tokens):
    """Decode texturing-model tokens into a list of 'note' events.

    Token ranges handled:
      * 0..127   -> advance the current time by the token value;
      * 462..588 -> set the current pitch to token - 461;
      * 590..716 -> set the duration to token - 589 and emit a note.
    Other tokens are ignored.

    Each event is ``['note', time, dur, 0, pitch, max(40, pitch), 0]``.
    NOTE(review): presumably TMIDIX's escore layout
    [type, start, duration, channel, pitch, velocity, patch] -- confirm.
    """
    song_f = []

    time = 0
    dur = 1
    pitch = 60

    for m in tokens:
        if 0 <= m < 128:
            time += m
        elif 461 < m < 589:
            pitch = m - 461
        elif 589 < m < 717:
            dur = m - 589
            song_f.append(['note', time, dur, 0, pitch, max(40, pitch), 0])

    return song_f

# =================================================================================================

def _resolve_input_chords(input_example, input_chords):
    """Return the priming chord indices for a request.

    Custom chords chosen in the UI win; an empty or missing selection falls
    back to the selected example progression.
    """
    # Truthiness: gradio's multiselect may pass [] as well as None.
    if input_chords:
        return [TMIDIX.ALL_CHORDS_SORTED.index([int(t) for t in c.split('-')])
                for c in input_chords]

    examples = {
        'Blue Bird': blue_bird,
        'Come To My Window': come_to_my_window,
    }

    return examples.get(input_example, sharing_the_night_together)


def Generate_Chords(input_example,
                    input_chords,
                    input_temperature,
                    input_top_p_value
                    ):
    """Generate, texture and render a chords progression.

    Args:
        input_example: name of an example progression used when no custom chords are selected.
        input_chords: list of chord labels such as "0-4-7", or empty/None.
        input_temperature: model sampling temperature.
        input_top_p_value: model sampling top_p value.

    Returns:
        Tuple of (generated chords as text, (16000, audio), score plot, MIDI file path).
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()

    print('=' * 70)
    print('Input example:', input_example)
    print('Input chords:', input_chords)
    print('Req model temp:', input_temperature)
    print('Req top_p value:', input_top_p_value)
    print('=' * 70)

    chords = _resolve_input_chords(input_example, input_chords)

    print('There are', len(chords), 'chords')
    print('Sample chords:', chords[:5])
    print('=' * 70)

    #===============================================================================

    print('Generating...')

    cho_prg, score = generate_chords(chords,
                                     input_temperature,
                                     input_top_p_value
                                     )

    final_chords = [TMIDIX.ALL_CHORDS_SORTED[c] for c in cho_prg]

    final_score = tokens_to_escore_notes(score)
    final_score = TMIDIX.remove_duplicate_pitches_from_escore_notes(final_score)
    final_score = TMIDIX.fix_escore_notes_durations(final_score, min_notes_gap=0)
    final_score = TMIDIX.humanize_velocities_in_escore_notes(final_score)

    #===============================================================================

    print('Rendering results...')
    print('=' * 70)

    now = datetime.datetime.now(PDT)
    ms4 = now.strftime("%f")[:4]

    fn1 = (
        'Chords-Progressions-Transformer-Composition-'
        + now.strftime(f"%Y-%m-%d-%H-%M-%S-{ms4}")
    )

    output_score, patches, _overflow_patches = TMIDIX.patch_enhanced_score_notes(final_score)

    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                             output_signature='Chords Progressions Transformer',
                                             output_file_name=fn1,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches,
                                             timings_multiplier=32
                                             )

    new_fn = fn1 + '.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=SOUNDFONT_PATH,
                                sample_rate=16000,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_gen_chords = '\n'.join(str(c) for c in final_chords)
    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score,
                                      timings_multiplier=32,
                                      plot_title=output_midi,
                                      return_plt=True
                                      )

    # Log the first three chords (not the first three characters of the text).
    print('Output gen chords:', final_chords[:3])
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_gen_chords, output_audio, output_plot, output_midi

# =================================================================================================

# Dropdown labels: each chord's pitch list joined with '-', e.g. "0-4-7".
chords_labels = ['-'.join(str(t) for t in c) for c in TMIDIX.ALL_CHORDS_SORTED]

# =================================================================================================

blue_bird = [89, 301, 314, 317, 114, 280, 110, 318, 194, 221, 320, 187, 270, 191, 303, 162, 298, 178, 181, 308,
             267, 283, 94, 104, 96, 272, 215, 288, 296, 105, 102, 144, 91, 284, 273, 227, 316, 164, 189, 147,
             281, 268, 179, 186, 213, 159, 165, 92, 188, 150, 218, 112, 309, 285, 302, 217, 290, 306, 148, 1,
             310, 4, 289, 0, 307, 119, 212, 117, 233, 254, 34, 3, 180, 319]

come_to_my_window = [16, 0, 13, 216, 309, 178, 194, 301, 192, 317, 320, 1, 191, 89, 319, 314, 288, 267, 195, 282,
                     280, 183, 18, 14, 181, 179, 215, 303, 184, 213, 306, 37, 272, 310, 34, 228, 212, 312, 227, 97,
                     308]

sharing_the_night_together = [267, 270, 281, 314, 148, 280, 146, 102, 233, 316, 283, 320, 10, 104, 235, 117, 89, 91,
                              256, 285, 21, 159, 112, 264, 301, 238, 303, 0, 263, 94, 144, 110, 153, 103, 106, 170,
                              95, 171, 119, 268, 90, 108, 114, 317, 306, 136, 134, 254, 307, 302, 31, 284, 319, 258,
                              243, 272]

# =================================================================================================

PDT = timezone('US/Pacific')

print('=' * 70)
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
print('=' * 70)

soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

app = gr.Blocks()

with app:
    # NOTE(review): the title/subtitle literals in the source spanned several
    # physical lines (a syntax error; original markup lost) -- restored as
    # Markdown headings.
    gr.Markdown("# Chords Progressions Transformer")
    gr.Markdown("## Generate and texture unique chords progressions")

    gr.Markdown("## Select example chords progression or create your own")

    input_example = gr.Dropdown(label="Example chords progressions",
                                choices=['Blue Bird',
                                         'Come To My Window',
                                         'Sharing The Night Together'
                                         ],
                                value='Blue Bird',
                                info='NOTE: Selecting custom chords below will override example selection'
                                )

    input_chords = gr.Dropdown(label="Desired chords to generate",
                               choices=chords_labels,
                               value=None,
                               multiselect=True,
                               info='NOTE: Selected chords will be introduced into generated chords progression in order of selection'
                               )

    input_temperature = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Model temperature")
    input_top_p_value = gr.Slider(0.1, 1.0, value=0.96, step=0.01, label="Model sampling top_p value")

    run_btn = gr.Button("generate", variant="primary")

    gr.Markdown("## Generation results")

    output_gen_chords = gr.Textbox(label="Generated chords list", lines=7)
    output_audio = gr.Audio(label="Output MIDI audio", format="mp3", elem_id="midi_audio")
    output_plot = gr.Plot(label="Output MIDI score plot")
    output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

    run_event = run_btn.click(Generate_Chords,
                              [input_example,
                               input_chords,
                               input_temperature,
                               input_top_p_value
                               ],
                              [output_gen_chords,
                               output_audio,
                               output_plot,
                               output_midi
                               ])

app.launch(mcp_server=True)