# Author: projectlosangeles
# Last commit: eb8537f ("Update app.py", verified)
#=================================================================================
# https://huggingface.co/spaces/projectlosangeles/Chords-Progressions-Transformer
#=================================================================================
# Startup banner.
print('=' * 70)
print('Chords Progressions Transformer Gradio App')
print('=' * 70)
import os
# Environment flags are set before torch / huggingface_hub are imported below.
# HF_HUB_ENABLE_HF_TRANSFER asks huggingface_hub to use the hf_transfer
# download backend (requires the hf_transfer package to be installed).
# NOTE(review): the consumer of USE_FLASH_ATTENTION is not visible in this
# file -- presumably read by the vendored x-transformers module; confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ['USE_FLASH_ATTENTION'] = '1'
# Aliased as `reqtime`; used for per-request timing in Generate_Chords.
import time as reqtime
from pytz import timezone
import torch
# Favour speed over full float32 precision: allow TF32 for matmuls and cuDNN,
# and enable all scaled-dot-product-attention backends so PyTorch may choose
# among them.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)
# Hugging Face Spaces helpers; provides the @spaces.GPU decorator used below.
import spaces
import gradio as gr
# Star import supplies TransformerWrapper, Decoder, AutoregressiveWrapper and
# top_p used below (presumably a vendored copy of x-transformers 2.3.1).
from x_transformer_2_3_1 import *
import datetime
import random
import tqdm
from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
# =================================================================================================
print('=' * 70)
print('Loading models...')
print('=' * 70)
print('Loading chords texturing model...')
print('=' * 70)
# Chords texturing model: generate_chords primes it with a chords progression
# and tokens_to_escore_notes decodes its output into notes.
# NOTE: SEQ_LEN, PAD_IDX and DEVICE are reassigned further down for the
# progressions model, so after module load these globals hold the
# progressions-model values, not the ones set here.
SEQ_LEN = 2048  # maximum context length of the texturing model
PAD_IDX = 721   # padding token; also the highest token id (vocab = PAD_IDX+1)
DEVICE = 'cuda'
tex_model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 2048,
                          depth = 12,
                          heads = 16,
                          rotary_pos_emb = True,
                          attn_flash = True
                          )
)
# Autoregressive wrapper provides .generate(); ignore_index is presumably the
# token excluded from the training loss.
tex_model = AutoregressiveWrapper(tex_model, ignore_index=PAD_IDX)
tex_model.to(DEVICE)
print('=' * 70)
print('Loading model checkpoint...')
# Fetch the pretrained texturing weights from the Hugging Face Hub (cached).
checkpoint = hf_hub_download(
    repo_id='asigalov61/Chordified-Piano-Transformer',
    filename='Chordified_Piano_Transformer_Texturing_Trained_Model_18092_steps_0.7058_loss_0.7977_acc.pth'
)
tex_model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))
tex_model.eval()
# Compile only after the weights are loaded; the model is used for inference.
tex_model = torch.compile(tex_model)
print('=' * 70)
print('Done!')
print('=' * 70)
# =================================================================================================
print('Loading chords progressions model...')
print('=' * 70)
# Chords progressions model. These assignments overwrite the texturing-model
# SEQ_LEN / PAD_IDX / DEVICE globals set above.
# Token ids as used by generate_chords: 0..320 are chord indices, 321 and 322
# bracket the prime, 323 is EOS, 324 is padding.
SEQ_LEN = 380
PAD_IDX = 324
DEVICE = 'cuda'
prg_model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 2048,
                          depth = 6,
                          heads = 16,
                          rotary_pos_emb = True,
                          attn_flash = True
                          )
)
prg_model = AutoregressiveWrapper(prg_model, ignore_index=PAD_IDX)
prg_model.to(DEVICE)
print('=' * 70)
print('Loading model checkpoint...')
# Fetch the pretrained progressions weights from the Hugging Face Hub (cached).
checkpoint = hf_hub_download(
    repo_id='asigalov61/Chordified-Piano-Transformer',
    filename='Chordified_Piano_Transformer_Chords_Progressions_Trained_Model_3569_steps_1.8604_loss_0.4727_acc.pth'
)
prg_model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))
prg_model.eval()
prg_model = torch.compile(prg_model)
print('=' * 70)
# =================================================================================================
# bfloat16 autocast context wrapped around both models' generate() calls in
# generate_chords.
dtype = torch.bfloat16
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
print('Done!')
print('=' * 70)
# =================================================================================================
print('Loading SoundFont...')
# Local path of the SoundFont passed to midi_to_colab_audio when rendering the
# generated MIDI file to audio in Generate_Chords.
SOUNDFONT_PATH = hf_hub_download(repo_id='projectlosangeles/soundfonts4u',
                                 repo_type='dataset',
                                 filename='SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
                                 )
print('Done!')
print('=' * 70)
# =================================================================================================
@spaces.GPU
def generate_chords(chords,
                    input_temperature,
                    input_top_p_value
                    ):
    """Generate a chords progression, then texture it into piano tokens.

    Two stages:
      1. ``prg_model`` samples a batch of candidate progressions from the
         prime ``[321] + chords + [322]``; the candidate with the most
         distinct tokens is kept.
      2. ``tex_model`` is primed with that progression (re-encoded into the
         texturing vocabulary) and samples the textured score.

    Args:
        chords: List of chord indices (0..320) into
            ``TMIDIX.ALL_CHORDS_SORTED`` to seed the progression.
        input_temperature: Sampling temperature used by both models.
        input_top_p_value: Nucleus-sampling (top-p) threshold for both models.

    Returns:
        ``(cho_prg, score)``: the selected progression as chord indices
        (0..320), and the texturing model's generated tokens (prime excluded)
        as a flat list, decodable with ``tokens_to_escore_notes``.
    """
    prg_seq_len = 380      # context length of prg_model (see its SEQ_LEN)
    tex_seq_len = 2048     # context length of tex_model (see its SEQ_LEN)
    num_candidates = 256   # progressions sampled in parallel

    print('*' * 70)
    print('Generating chords progression...')

    # 321 / 322 bracket the user's chords to form the progressions prime.
    prime = [321] + chords + [322]
    x = torch.LongTensor([prime] * num_candidates).to(DEVICE)

    with ctx:
        out = prg_model.generate(x,
                                 prg_seq_len - len(prime),
                                 temperature=input_temperature,
                                 filter_logits_fn=top_p,
                                 filter_kwargs={'thres': input_top_p_value},
                                 eos_token=323,
                                 return_prime=False,
                                 verbose=True
                                 )

    out = out.tolist()

    # Heuristic: a candidate is "good" when it has at least as many distinct
    # tokens as the prime has tokens.
    good_outs = [o for o in out if len(set(o)) >= len(prime)]

    # Keep the candidate with the most distinct tokens; max() returns the
    # first one on ties, like the stable descending sort it replaces.
    cho_prg = max(good_outs or out, key=lambda o: len(set(o)))

    # Drop control tokens (321..324), keeping only chord indices.
    cho_prg = [c for c in cho_prg if 0 <= c < 321]

    # Re-encode chords for the texturing vocabulary: the 12 chords listed in
    # ncho map to tokens 128..139, every other chord c maps to c + 140.
    # NOTE(review): ncho presumably lists the single-pitch-class chords -- confirm.
    ncho = [0, 89, 178, 233, 267, 288, 301, 309, 314, 317, 319, 320]
    inp_cho_prg = [c+140 if c not in ncho else ncho.index(c)+128 for c in cho_prg]

    print('Done!')
    print('*' * 70)
    print('Number of good chords progressions:', len(good_outs))
    print('*' * 70)
    print('Texturing selected generated chords progression...')

    # 718 / 719 bracket the progression; 720 terminates the texture.
    tex_prime = [718] + inp_cho_prg + [719]
    x = torch.LongTensor(tex_prime).to(DEVICE)

    with ctx:
        # Generate at most the room left in the context after the prime
        # (previously `2048-len(cho_prg)+2`, which overshot by 4 tokens).
        out = tex_model.generate(x,
                                 tex_seq_len - len(tex_prime),
                                 temperature=input_temperature,
                                 filter_logits_fn=top_p,
                                 filter_kwargs={'thres': input_top_p_value},
                                 eos_token=720,
                                 return_prime=False,
                                 verbose=True
                                 )

    score = out.tolist()

    print('Done!')
    print('=' * 70)

    return cho_prg, score
# =================================================================================================
def tokens_to_escore_notes(tokens):
    """Decode texturing-model tokens into a list of escore note events.

    Token ranges handled (every other token is ignored):
      * 0..127   -- advance the current time by the token value;
      * 462..588 -- set the current pitch to ``token - 461`` (1..127);
      * 590..716 -- emit a note of duration ``token - 589`` (1..127) at the
        current time and pitch.

    Tokens 461 (pitch 0) and 589 (duration 0) fall outside the strict lower
    bounds and are therefore ignored.

    Args:
        tokens: Iterable of integer tokens.

    Returns:
        List of ``['note', time, dur, 0, pitch, max(40, pitch), 0]`` lists,
        presumably TMIDIX's ``[type, start, duration, channel, pitch,
        velocity, patch]`` layout. The pitch is 60 until the first pitch
        token is seen.
    """
    song_f = []
    time = 0
    pitch = 60

    for m in tokens:
        if 0 <= m < 128:
            time += m
        elif 461 < m < 589:
            pitch = m - 461
        elif 589 < m < 717:
            dur = m - 589
            # Velocity is derived from the pitch with a floor of 40; it is
            # later humanized by the caller.
            song_f.append(['note', time, dur, 0, pitch, max(40, pitch), 0])

    return song_f
# =================================================================================================
def Generate_Chords(input_example,
                    input_chords,
                    input_temperature,
                    input_top_p_value
                    ):
    """Gradio click handler: generate, texture and render a chords progression.

    Args:
        input_example: Name of the example progression chosen in the UI
            ('Blue Bird', 'Come To My Window'; anything else selects
            'Sharing The Night Together').
        input_chords: Custom chords chosen in the UI, each a '-'-joined string
            of ints matching an entry of ``TMIDIX.ALL_CHORDS_SORTED`` (see
            ``chords_labels``). When non-empty it overrides ``input_example``.
        input_temperature: Sampling temperature.
        input_top_p_value: Nucleus-sampling (top-p) threshold.

    Returns:
        ``(output_gen_chords, output_audio, output_plot, output_midi)``:
        the generated chords joined by newlines, a ``(16000, audio)`` tuple,
        the plot returned by ``TMIDIX.plot_ms_SONG``, and the written MIDI
        file name.
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)
    print('Input example:', input_example)
    print('Input chords:', input_chords)
    print('Req model temp:', input_temperature)
    print('Req top_p value:', input_top_p_value)
    print('=' * 70)

    # Gradio may pass [] (not None) once a multiselect has been cleared, so
    # test for emptiness: only a non-empty custom selection overrides the
    # example progression.
    if input_chords:
        chords = [TMIDIX.ALL_CHORDS_SORTED.index([int(t) for t in c.split('-')])
                  for c in input_chords]
    else:
        example_chords = {
            'Blue Bird': blue_bird,
            'Come To My Window': come_to_my_window,
        }
        chords = example_chords.get(input_example, sharing_the_night_together)

    print('There are', len(chords), 'chords')
    print('Sample chords:', chords[:5])
    print('=' * 70)

    #===============================================================================

    print('Sample chords', chords[:2])
    print('=' * 70)
    print('Generating...')

    cho_prg, score = generate_chords(chords,
                                     input_temperature,
                                     input_top_p_value
                                     )

    # Map chord indices back to chords and clean up the decoded notes.
    final_chords = [TMIDIX.ALL_CHORDS_SORTED[c] for c in cho_prg]
    final_score = tokens_to_escore_notes(score)
    final_score = TMIDIX.remove_duplicate_pitches_from_escore_notes(final_score)
    final_score = TMIDIX.fix_escore_notes_durations(final_score,
                                                    min_notes_gap=0
                                                    )
    final_score = TMIDIX.humanize_velocities_in_escore_notes(final_score)

    #===============================================================================

    print('Rendering results...')
    print('=' * 70)

    # Timestamped (to 1/10000 s) output name so concurrent requests do not
    # collide on the same file.
    now = datetime.datetime.now(PDT)
    ms4 = now.strftime("%f")[:4]
    fn1 = (
        'Chords-Progressions-Transformer-Composition-'
        + now.strftime(f"%Y-%m-%d-%H-%M-%S-{ms4}")
    )

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(final_score)

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                                              output_signature = 'Chords Progressions Transformer',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches,
                                                              timings_multiplier=32
                                                              )

    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=SOUNDFONT_PATH,
                                sample_rate=16000,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_gen_chords = '\n'.join(str(c) for c in final_chords)
    output_midi = str(new_fn)
    output_audio = (16000, audio)
    output_plot = TMIDIX.plot_ms_SONG(output_score,
                                      timings_multiplier=32,
                                      plot_title=output_midi,
                                      return_plt=True
                                      )

    print('Output gen chords:', output_gen_chords[:3])
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_gen_chords, output_audio, output_plot, output_midi
# =================================================================================================
# Dropdown labels for every chord in TMIDIX.ALL_CHORDS_SORTED: the chord's
# integer members joined with '-' (a chord [0, 4, 7] would become '0-4-7').
# Generate_Chords turns a selected label back into a chord by splitting on '-'.
chords_labels = ['-'.join(map(str, chord)) for chord in TMIDIX.ALL_CHORDS_SORTED]
# =================================================================================================
# Example progressions offered in the UI dropdown. Each is a list of chord
# indices into TMIDIX.ALL_CHORDS_SORTED (the same index space Generate_Chords
# builds for custom chords) and is passed to generate_chords as the prime.
blue_bird = [89, 301, 314, 317, 114, 280, 110, 318, 194, 221, 320, 187, 270, 191, 303, 162,
    298, 178, 181, 308, 267, 283, 94, 104, 96, 272, 215, 288, 296, 105, 102, 144,
    91, 284, 273, 227, 316, 164, 189, 147, 281, 268, 179, 186, 213, 159, 165, 92,
    188, 150, 218, 112, 309, 285, 302, 217, 290, 306, 148, 1, 310, 4, 289, 0, 307,
    119, 212, 117, 233, 254, 34, 3, 180, 319]
come_to_my_window = [16, 0, 13, 216, 309, 178, 194, 301, 192, 317, 320, 1, 191, 89, 319, 314, 288,
    267, 195, 282, 280, 183, 18, 14, 181, 179, 215, 303, 184, 213, 306, 37, 272,
    310, 34, 228, 212, 312, 227, 97, 308]
sharing_the_night_together = [267, 270, 281, 314, 148, 280, 146, 102, 233, 316, 283, 320, 10, 104, 235, 117,
    89, 91, 256, 285, 21, 159, 112, 264, 301, 238, 303, 0, 263, 94, 144, 110, 153,
    103, 106, 170, 95, 171, 119, 268, 90, 108, 114, 317, 306, 136, 134, 254, 307,
    302, 31, 284, 319, 258, 243, 272]
# =================================================================================================
# US/Pacific timezone used for the app/request timestamps and output file
# names (see Generate_Chords).
PDT = timezone('US/Pacific')
print('=' * 70)
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
print('=' * 70)
# NOTE(review): `soundfont` is not referenced elsewhere in this file; audio
# rendering uses SOUNDFONT_PATH. Candidate for removal.
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
# Gradio UI: example/custom chord inputs and sampling controls, a generate
# button, and the four outputs returned by Generate_Chords.
app = gr.Blocks()
with app:
    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Chords Progressions Transformer</h1>")
    gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Generate and texture unique chords progressions</h1>")
    gr.Markdown("## Select example chords progression or create your own")
    input_example = gr.Dropdown(label="Example chords progressions",
                                choices=['Blue Bird', 'Come To My Window', 'Sharing The Night Together'],
                                value='Blue Bird',
                                info='NOTE: Selecting custom chords below will override example selection'
                                )
    # Choices are the '-'-joined labels from chords_labels; Generate_Chords
    # parses the selected labels back into chord indices.
    input_chords = gr.Dropdown(label="Desired chords to generate",
                               choices=chords_labels,
                               value=None,
                               multiselect=True,
                               info='NOTE: Selected chords will be introduced into generated chords progression in order of selection'
                               )
    input_temperature = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Model temperature")
    input_top_p_value = gr.Slider(0.1, 1.0, value=0.96, step=0.01, label="Model sampling top_p value")
    run_btn = gr.Button("generate", variant="primary")
    gr.Markdown("## Generation results")
    output_gen_chords = gr.Textbox(label="Generated chords list", lines=7)
    output_audio = gr.Audio(label="Output MIDI audio", format="mp3", elem_id="midi_audio")
    output_plot = gr.Plot(label="Output MIDI score plot")
    output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
    # Input/output order must match Generate_Chords' parameters and return tuple.
    run_event = run_btn.click(Generate_Chords,
                              [input_example,
                               input_chords,
                               input_temperature,
                               input_top_p_value
                               ],
                              [output_gen_chords,
                               output_audio,
                               output_plot,
                               output_midi
                               ])
# mcp_server=True also exposes the app's functions through Gradio's MCP server.
app.launch(mcp_server=True)