#====================================================================
# https://huggingface.co/spaces/asigalov61/Orpheus-Music-Transformer
#====================================================================

"""
Orpheus Music Transformer Gradio App - Single Model, Simplified Version
SOTA 8k multi-instrumental music transformer trained on 2.31M+ high-quality MIDIs
Using one large optimized model which was trained for 4 full epochs
"""

import os

# Set before huggingface_hub is imported so the flag is already in place
# when the hub client is loaded.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import datetime
import random
import time as reqtime

from pytz import timezone

import torch
import matplotlib.pyplot as plt

import gradio as gr
import spaces

from huggingface_hub import hf_hub_download

import TMIDIX

from midi_to_colab_audio import midi_to_colab_audio

from x_transformer_2_3_1 import TransformerWrapper, AutoregressiveWrapper, Decoder, top_p

# -----------------------------
# CONFIGURATION & GLOBALS
# -----------------------------
SEP = '=' * 70
PDT = timezone('US/Pacific')

MODEL_CHECKPOINT = 'Orpheus_Music_Transformer_Large_Trained_Model_31087_steps_0.6878_loss_0.7889_acc.pth'
SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
NUM_OUT_BATCHES = 10
PREVIEW_LENGTH = 120  # in tokens

SAMPLE_RATE = 16000       # sample rate used for every rendered audio preview
MAX_PRIME_TOKENS = 6656   # longest prompt fed to the model

# Special tokens (the note tokens themselves are described in save_midi).
SOS_TOKEN = 18816         # first token of every sequence
OUTRO_TOKEN = 18817       # appended when the user asks for an outro
EOS_TOKEN = 18818         # passed to model.generate as eos_token


# -----------------------------
# PRINT START-UP INFO
# -----------------------------
def print_sep():
    """Print the horizontal separator used throughout the console log."""
    print(SEP)


print_sep()
print("Orpheus Music Transformer Gradio App")
print_sep()
print("Loading modules...")

# -----------------------------
# ENVIRONMENT & PyTorch Settings
# -----------------------------
os.environ['USE_FLASH_ATTENTION'] = '1'

torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)

print_sep()
print("PyTorch version:", torch.__version__)
print("Done loading modules!")
print_sep()

# -----------------------------
# MODEL INITIALIZATION
# -----------------------------
print_sep()
print("Instantiating model...")

device_type = 'cuda'
dtype = 'bfloat16'
ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

SEQ_LEN = 8192
PAD_IDX = 18819

model = TransformerWrapper(
    num_tokens=PAD_IDX + 1,
    max_seq_len=SEQ_LEN,
    attn_layers=Decoder(
        dim=2048,
        depth=16,
        heads=16,
        rotary_pos_emb=True,
        attn_flash=True
    )
)

model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)

print_sep()
print("Loading model checkpoint...")

checkpoint = hf_hub_download(
    repo_id='asigalov61/Orpheus-Music-Transformer',
    filename=MODEL_CHECKPOINT
)

# NOTE(review): the state dict is loaded *after* torch.compile, so the
# checkpoint keys presumably match the compiled module's key names.
# Do not reorder these statements without confirming that.
model.cuda()
model = torch.compile(model)
model.load_state_dict(torch.load(checkpoint, map_location='cuda', weights_only=True))
model.eval()

print_sep()
print("Done!")
print("Model will use", dtype, "precision...")
print_sep()


# -----------------------------
# HELPER FUNCTIONS
# -----------------------------
def _last_note_time(midi_score):
    """Return the time field of the last score event in seconds (0 if the score is empty).

    save_midi builds event times in milliseconds, hence the division by 1000.
    An empty score (e.g. a seed with no decodable notes) used to raise
    IndexError at every call site; it now maps to 0.
    """
    return midi_score[-1][1] / 1000 if midi_score else 0


def _render_audio(midi_fname):
    """Synthesize '<midi_fname>.mid' and return a (sample_rate, samples) pair for gr.Audio."""
    midi_audio = midi_to_colab_audio(
        midi_fname + '.mid',
        soundfont_path=SOUNDFONT_PATH,
        sample_rate=SAMPLE_RATE,
        output_for_gradio=True
    )
    return (SAMPLE_RATE, midi_audio)


def _render_composition(midi_fname, midi_score, block_lines):
    """Plot and synthesize the full composition for the 'final' output widgets.

    The last entry of block_lines marks the end of the composition, so it is
    left out of the plotted block separators.

    Returns (audio, plot, midi_file_path).
    """
    midi_plot = TMIDIX.plot_ms_SONG(
        midi_score,
        plot_title='Orpheus Music Transformer Composition',
        block_lines_times_list=block_lines[:-1],
        return_plt=True
    )
    return _render_audio(midi_fname), midi_plot, midi_fname + '.mid'


def render_midi_output(final_composition):
    """Generate MIDI score, plot, and audio from final composition.

    Returns (audio, plot, midi_file_path, last_note_time_in_seconds).
    """
    fname, midi_score = save_midi(final_composition)
    time_val = _last_note_time(midi_score)  # seconds marker from last note

    midi_plot = TMIDIX.plot_ms_SONG(
        midi_score,
        plot_title='Orpheus Music Transformer Composition',
        block_lines_times_list=[],
        return_plt=True
    )

    return _render_audio(fname), midi_plot, fname + '.mid', time_val


# -----------------------------
# MIDI PROCESSING FUNCTIONS
# -----------------------------
def load_midi(input_midi,
              apply_sustains=True,
              remove_duplicate_pitches=True,
              remove_overlapping_durations=True
              ):
    """Process the input MIDI file and create a token sequence.

    Args:
        input_midi: uploaded file object; only its ``.name`` path is used.
        apply_sustains: forwarded to TMIDIX.advanced_score_processor.
        remove_duplicate_pitches: run TMIDIX duplicate-pitch removal.
        remove_overlapping_durations: run TMIDIX duration fixing.

    Returns:
        A list of ints starting with SOS_TOKEN. Each chord contributes one
        delta-time token followed by a (patch/pitch, duration/velocity)
        token pair per note. If the score yields no notes the result is
        just ``[SOS_TOKEN]``.
    """
    raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)

    escore_notes = TMIDIX.advanced_score_processor(
        raw_score,
        return_enhanced_score_notes=True,
        apply_sustain=apply_sustains
    )

    if not escore_notes:
        return [SOS_TOKEN]

    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0],
                                                       sort_drums_last=True
                                                       )

    if remove_duplicate_pitches:
        escore_notes = TMIDIX.remove_duplicate_pitches_from_escore_notes(escore_notes)

    if remove_overlapping_durations:
        escore_notes = TMIDIX.fix_escore_notes_durations(escore_notes,
                                                         min_notes_gap=0
                                                         )

    dscore = TMIDIX.delta_score_notes(escore_notes)

    dcscore = TMIDIX.chordify_score([d[1:] for d in dscore])

    melody_chords = [SOS_TOKEN]

    #=======================================================
    # MAIN PROCESSING CYCLE
    #=======================================================

    for chord in dcscore:

        # Delta-time tokens occupy 0..255; clamp so an out-of-range delta can
        # never be mistaken for a patch/pitch token (>= 256) when decoding.
        delta_time = max(0, min(255, chord[0][0]))

        melody_chords.append(delta_time)

        for e in chord:

            #=======================================================

            # Durations
            dur = max(1, min(255, e[1]))

            # Patches (128 is the drums patch)
            pat = max(0, min(128, e[5]))

            # Pitches
            ptc = max(1, min(127, e[3]))

            # Velocities
            # Calculating octo-velocity (8 buckets, 0..7)
            vel = max(8, min(127, e[4]))
            velocity = round(vel / 15) - 1

            #=======================================================
            # FINAL NOTE SEQ
            #=======================================================

            # Writing final note
            pat_ptc = (128 * pat) + ptc
            dur_vel = (8 * dur) + velocity

            melody_chords.extend([pat_ptc + 256, dur_vel + 16768])

    return melody_chords


def save_midi(tokens):
    """Convert token sequence back to a MIDI score and write it using TMIDIX.

    Token layout, as decoded below:
        0..255        delta time, in steps of 16 ms
        256..16767    patch * 128 + pitch (patch 128 = drums)
        16768..18815  duration * 8 + velocity bucket; this token emits the note
    Any other token (SOS, outro, EOS, padding) is ignored.

    NOTE: every call writes to the same fixed file name, so concurrent
    requests overwrite each other's file.

    Returns:
        (fname, output_score) where fname is the file name without the
        '.mid' extension and output_score is the patched score that was
        written.
    """
    time = 0
    dur = 1
    vel = 90
    pitch = 60
    channel = 0
    patch = 0

    # patches[ch] is the patch assigned to MIDI channel ch (-1 = unassigned);
    # channels[ch] is 1 once the channel is taken. Channel 9 is reserved for drums.
    patches = [-1] * 16
    channels = [0] * 16
    channels[9] = 1

    song_f = []

    for ss in tokens:

        if 0 <= ss < 256:
            time += ss * 16

        if 256 <= ss < 16768:
            patch = (ss - 256) // 128

            if patch < 128:
                if patch not in patches:
                    if 0 in channels:
                        cha = channels.index(0)
                        channels[cha] = 1
                    else:
                        # No free channel left: reuse the last one.
                        cha = 15

                    patches[cha] = patch
                    channel = patches.index(patch)
                else:
                    channel = patches.index(patch)

            if patch == 128:
                channel = 9

            pitch = (ss - 256) % 128

        if 16768 <= ss < 18816:
            dur = ((ss - 16768) // 8) * 16
            vel = (((ss - 16768) % 8) + 1) * 15

            song_f.append(['note', time, dur, channel, pitch, vel, patch])

    song_f = TMIDIX.remove_duplicate_pitches_from_escore_notes(song_f)

    song_f = TMIDIX.fix_escore_notes_durations(song_f,
                                               min_notes_gap=0
                                               )

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)

    fname = "Orpheus-Music-Transformer-Composition"

    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
        output_score,
        output_signature='Orpheus Music Transformer',
        output_file_name=fname,
        track_name='Project Los Angeles',
        list_of_MIDI_patches=patches,
        verbose=False
    )

    return fname, output_score


# -----------------------------
# MUSIC GENERATION FUNCTION (Combined)
# -----------------------------
@spaces.GPU
def generate_music(prime, num_gen_tokens, num_gen_batches, model_temperature, model_top_p):
    """Generate music tokens given prime tokens and parameters.

    Args:
        prime: list of prompt tokens. Prompts of MAX_PRIME_TOKENS or more are
            cut to their last MAX_PRIME_TOKENS tokens with SOS_TOKEN prepended.
        num_gen_tokens: number of tokens to generate per batch.
        num_gen_batches: number of continuations sampled in parallel.
        model_temperature: sampling temperature.
        model_top_p: top-p threshold; values >= 1 skip the explicit top-p
            filter and use model.generate's default filtering.

    Returns:
        A list of num_gen_batches token lists (prompt not included).
    """
    if len(prime) >= MAX_PRIME_TOKENS:
        prime = [SOS_TOKEN] + prime[-MAX_PRIME_TOKENS:]

    print("Generating...")

    inp = torch.LongTensor([prime] * num_gen_batches).cuda()

    gen_kwargs = {
        'temperature': model_temperature,
        'eos_token': EOS_TOKEN,
        'return_prime': False,
        'verbose': False,
    }

    if model_top_p < 1:
        gen_kwargs['filter_logits_fn'] = top_p
        gen_kwargs['filter_kwargs'] = {'thres': model_top_p}

    with ctx:
        out = model.generate(inp, num_gen_tokens, **gen_kwargs)

    print("Done!")
    print_sep()

    return out.tolist()


def generate_music_and_state(input_midi,
                             apply_sustains,
                             remove_duplicate_pitches,
                             remove_overlapping_durations,
                             prime_instruments,
                             num_prime_tokens,
                             num_gen_tokens,
                             model_temperature,
                             model_top_p,
                             add_drums,
                             add_outro,
                             final_composition,
                             generated_batches,
                             block_lines
                             ):
    """
    Generate tokens using the model, update the composition state, and prepare outputs.
    This function combines seed loading, token generation, and UI output packaging.

    When there is no composition yet, the seed comes from the uploaded MIDI
    if one is given, otherwise from the selected prime instruments.

    Returns:
        [final_composition, generated_batches, block_lines] followed by an
        (audio, plot) pair for each of the NUM_OUT_BATCHES previews.
    """
    print_sep()
    print("Request start time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
    start_time = reqtime.time()

    print_sep()

    # Slider values are used for slicing and as a loop count, so make sure
    # they are plain ints.
    num_prime_tokens = int(num_prime_tokens)
    num_gen_tokens = int(num_gen_tokens)

    if input_midi is not None:
        fn = os.path.basename(input_midi.name)
        print('Input file name:', fn)
        print('Apply sustains:', apply_sustains)
        print('Remove duplicate pitches:', remove_duplicate_pitches)
        print('Remove overlapping durations', remove_overlapping_durations)

    print('Prime instruments:', prime_instruments)
    print('Num prime tokens:', num_prime_tokens)
    print('Num gen tokens:', num_gen_tokens)
    print('Model temp:', model_temperature)
    print('Model top p:', model_top_p)
    print('Add drums:', add_drums)
    print('Add outro:', add_outro)
    print_sep()

    # Load seed from MIDI if there is no existing composition.
    if not final_composition and input_midi is not None:
        final_composition = load_midi(input_midi,
                                      apply_sustains=apply_sustains,
                                      remove_duplicate_pitches=remove_duplicate_pitches,
                                      remove_overlapping_durations=remove_overlapping_durations
                                      )

        if num_prime_tokens < MAX_PRIME_TOKENS:
            final_composition = final_composition[:num_prime_tokens]

        midi_fname, midi_score = save_midi(final_composition)

        # Use the last note's time as a marker (0 when the seed has no notes).
        block_lines.append(_last_note_time(midi_score))

    # No composition and no MIDI: build a one-chord seed from the selected
    # instruments, one note per instrument, an octave apart going down from 72.
    if not final_composition and input_midi is None:
        final_composition = [SOS_TOKEN, 0]

        for i, instr in enumerate(prime_instruments or []):
            instr_num = patch2number[instr]

            final_composition.append((128 * instr_num) + (72 - (i * 12)) + 256)
            final_composition.append((8 * 16) + 5 + 16768)

    if final_composition:

        if add_outro:
            final_composition.append(OUTRO_TOKEN)  # Outro token

        if add_drums:
            drum_pitches = random.sample([35, 36, 41, 43, 45], k=1)
            for dp in drum_pitches:
                final_composition.extend([(128 * 128) + dp + 256])  # Drum patch/pitch token

    print_sep()
    print('Composition has', len(final_composition), 'tokens')
    print_sep()

    batched_gen_tokens = generate_music(final_composition,
                                        num_gen_tokens,
                                        NUM_OUT_BATCHES,
                                        model_temperature,
                                        model_top_p
                                        )

    # Each preview is the tail of the current composition plus one generated batch.
    preview_tokens = final_composition[-PREVIEW_LENGTH:]

    output_batches = []

    for i, tokens in enumerate(batched_gen_tokens):
        midi_fname, midi_score = save_midi(preview_tokens + tokens)

        plot_kwargs = {'plot_title': f'Batch # {i}', 'return_plt': True}

        if len(final_composition) > PREVIEW_LENGTH:
            # Count of patch/pitch tokens in the preview, i.e. its number of notes.
            plot_kwargs['preview_length_in_notes'] = len(
                [t for t in preview_tokens if 256 <= t < 16768]
            )

        midi_plot = TMIDIX.plot_ms_SONG(midi_score, **plot_kwargs)

        output_batches.append([_render_audio(midi_fname), midi_plot, tokens])

    # Update generated_batches (for use by add/remove functions)
    generated_batches = batched_gen_tokens

    # Flatten outputs: states then audio and plots for each batch.
    outputs_flat = []
    for batch in output_batches:
        outputs_flat.extend([batch[0], batch[1]])

    print("Request end time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
    print_sep()

    end_time = reqtime.time()
    execution_time = end_time - start_time

    print(f"Request execution time: {execution_time} seconds")
    print_sep()

    return [final_composition, generated_batches, block_lines] + outputs_flat


# -----------------------------
# BATCH HANDLING FUNCTIONS
# -----------------------------
def add_batch(batch_number, final_composition, generated_batches, block_lines):
    """Add tokens from the specified batch to the final composition and update outputs.

    Returns (audio, plot, midi_file_path, final_composition,
    generated_batches, block_lines). When nothing has been generated yet,
    the outputs are cleared and all three states come back empty.
    """
    if not generated_batches:
        return None, None, None, [], [], []

    # Coerce and clamp so a float or out-of-range value cannot break indexing.
    batch_idx = max(0, min(len(generated_batches) - 1, int(batch_number)))

    final_composition.extend(generated_batches[batch_idx])

    midi_fname, midi_score = save_midi(final_composition)

    block_lines.append(_last_note_time(midi_score))

    midi_audio, midi_plot, midi_path = _render_composition(midi_fname, midi_score, block_lines)

    print("Added batch #", batch_idx)
    print_sep()

    return midi_audio, midi_plot, midi_path, final_composition, generated_batches, block_lines


def remove_batch(batch_number, num_tokens, final_composition, generated_batches, block_lines):
    """Remove tokens from the final composition and update outputs.

    The last num_tokens tokens are dropped, together with the most recent
    block line. num_tokens is the current value of the "tokens to generate"
    slider, not the recorded length of the batch that was added.

    Returns (audio, plot, midi_file_path, final_composition,
    generated_batches, block_lines). If the composition is not longer than
    num_tokens, the outputs are cleared and all three states come back empty.
    """
    num_tokens = int(num_tokens)

    if not final_composition or len(final_composition) <= num_tokens:
        return None, None, None, [], [], []

    # Slice by explicit length: [:-0] would wipe the whole list.
    final_composition = final_composition[:len(final_composition) - num_tokens]

    if block_lines:
        block_lines.pop()

    midi_fname, midi_score = save_midi(final_composition)

    midi_audio, midi_plot, midi_path = _render_composition(midi_fname, midi_score, block_lines)

    print("Removed batch #", batch_number)
    print_sep()

    return midi_audio, midi_plot, midi_path, final_composition, generated_batches, block_lines


def clear():
    """Clear outputs and reset state.

    Returns empty values for the final audio, plot and file widgets, then
    for the final_composition and block_lines states.
    """
    print_sep()
    print('Clear batch...')
    print_sep()

    return None, None, None, [], []


def reset(final_composition=None, generated_batches=None, block_lines=None):
    """Reset composition state.

    The arguments are accepted only so the function can be wired to the
    state components; their values are ignored. Returns fresh empty lists
    for final_composition, generated_batches and block_lines.
    """
    print_sep()
    print('Reset composition...')
    print_sep()

    return [], [], []


# Instrument name -> patch number; 'Drums' gets patch 128.
patch2number = {v: k for k, v in TMIDIX.Number2patch.items()}
patch2number['Drums'] = 128

# -----------------------------
# GRADIO INTERFACE SETUP
# -----------------------------
with gr.Blocks() as demo:

    # NOTE(review): the heading and link markup in this header section was
    # reconstructed around the surviving text; confirm tags, links and any
    # badge images against the deployed Space.
    gr.Markdown("<h1><center>Orpheus Music Transformer</center></h1>")
    gr.Markdown("<h1><center>SOTA 8k multi-instrumental music transformer trained on 2.31M+ high-quality MIDIs</center></h1>")
    gr.Markdown("<h3><center>🔥[2026]🔥 Now featuring large optimized model!</center></h3>")

    gr.HTML("""
    Check out Godzilla MIDI Dataset on Hugging Face
    <p>
        <a href="https://huggingface.co/spaces/asigalov61/Orpheus-Music-Transformer?duplicate=true">
            Duplicate in Hugging Face
        </a>
    </p>
    for faster execution and endless generation!
    """)

    gr.HTML("""
    <p>Project Los Angeles · Orpheus Music Transformer</p>
    """)

    gr.Markdown("## Key Features")
    gr.Markdown(
        "- **Efficient Architecture with RoPE**: Large optimized 748M full attention autoregressive transformer with RoPE.\n"
        "- **Extended Sequence Length**: 8k tokens that comfortably fit most music compositions and facilitate long-term music structure generation.\n"
        "- **Premium Training Data**: Trained solely on the highest-quality MIDIs from the Godzilla MIDI dataset.\n"
        "- **Optimized MIDI Encoding**: Extremely efficient MIDI representation using only 3 tokens per note and 7 tokens per tri-chord.\n"
        "- **Distinct Encoding Order**: Features a unique duration/velocity last MIDI encoding order for refined musical expression.\n"
        "- **Full-Range Instrumental Learning**: True full-range MIDI instruments encoding enabling the model to learn each instrument separately.\n"
        "- **Natural Composition Endings**: Outro tokens that help generate smooth and natural musical conclusions.\n"
    )

    gr.Markdown(
        "## If you enjoyed Orpheus Music Transformer, please star and duplicate. It helps a lot! 🤗\n"
        "### [⭐ Star this Space](https://huggingface.co/spaces/asigalov61/Orpheus-Music-Transformer)\n"
        "### [🔁 Duplicate this Space](https://huggingface.co/spaces/asigalov61/Orpheus-Music-Transformer?duplicate=true)\n"
        "### [⭐ Star models repo](https://huggingface.co/asigalov61/Orpheus-Music-Transformer)\n"
    )

    # Global state variables for composition
    final_composition = gr.State([])
    generated_batches = gr.State([])
    block_lines = gr.State([])

    gr.Markdown("## Upload seed MIDI or click 'Generate' for random output")

    gr.Markdown("### PLEASE NOTE:")
    gr.Markdown("* Orpheus Music Transformer is a primarily music continuation/co-composition model!")
    gr.Markdown("* The model works best if given some music context to work with")
    gr.Markdown("* Random generation from SOS token/embeddings may not always produce good results")

    input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])

    # A new upload starts a new composition.
    input_midi.upload(reset,
                      [final_composition, generated_batches, block_lines],
                      [final_composition, generated_batches, block_lines]
                      )

    apply_sustains = gr.Checkbox(value=True, label="Apply sustains (if present)")
    remove_duplicate_pitches = gr.Checkbox(value=True, label="Remove duplicate pitches (if present)")
    remove_overlapping_durations = gr.Checkbox(value=True, label="Trim overlapping durations (if present)")

    gr.Markdown("## Generation options")

    prime_instruments = gr.Dropdown(label="Prime instruments (select up to 5)",
                                    choices=list(patch2number.keys()),
                                    multiselect=True,
                                    max_choices=5,
                                    type="value",
                                    info="Instruments are assigned from top to bottom in order of selection. Custom MIDI overrides prime instruments."
                                    )

    # Changing the instrument selection also starts a new composition.
    prime_instruments.input(reset,
                            [final_composition, generated_batches, block_lines],
                            [final_composition, generated_batches, block_lines]
                            )

    num_prime_tokens = gr.Slider(16, MAX_PRIME_TOKENS, value=MAX_PRIME_TOKENS, step=1, label="Number of prime tokens")
    num_gen_tokens = gr.Slider(16, 1024, value=512, step=1, label="Number of tokens to generate")
    model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
    model_top_p = gr.Slider(0.1, 1.0, value=0.96, step=0.01, label="Model sampling top p value")
    add_drums = gr.Checkbox(value=False, label="Add drums")
    add_outro = gr.Checkbox(value=False, label="Add an outro")

    generate_btn = gr.Button("Generate", variant="primary")

    gr.Markdown("## Batch Previews")

    outputs = [final_composition, generated_batches, block_lines]

    # Two outputs (audio and plot) for each batch
    for i in range(NUM_OUT_BATCHES):
        with gr.Tab(f"Batch # {i}"):
            audio_output = gr.Audio(label=f"Batch # {i} MIDI Audio", format="mp3")
            plot_output = gr.Plot(label=f"Batch # {i} MIDI Plot")
            outputs.extend([audio_output, plot_output])

    generate_btn.click(
        generate_music_and_state,
        [input_midi,
         apply_sustains,
         remove_duplicate_pitches,
         remove_overlapping_durations,
         prime_instruments,
         num_prime_tokens,
         num_gen_tokens,
         model_temperature,
         model_top_p,
         add_drums,
         add_outro,
         final_composition,
         generated_batches,
         block_lines
         ],
        outputs
    )

    gr.Markdown("## Add/Remove Batch")

    batch_number = gr.Slider(0, NUM_OUT_BATCHES - 1, value=0, step=1, label="Batch number to add/remove")

    add_btn = gr.Button("Add batch", variant="primary")
    remove_btn = gr.Button("Remove batch", variant="stop")
    clear_btn = gr.ClearButton()

    final_audio_output = gr.Audio(label="Final MIDI audio", format="mp3")
    final_plot_output = gr.Plot(label="Final MIDI plot")
    final_file_output = gr.File(label="Final MIDI file")

    add_btn.click(
        add_batch,
        [batch_number, final_composition, generated_batches, block_lines],
        [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines]
    )

    remove_btn.click(
        remove_batch,
        [batch_number, num_gen_tokens, final_composition, generated_batches, block_lines],
        [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines]
    )

    clear_btn.click(clear,
                    inputs=None,
                    outputs=[final_audio_output, final_plot_output, final_file_output, final_composition, block_lines]
                    )

demo.launch(mcp_server=True)