| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import torch |
| import soundfile as sf |
| import logging |
| import gradio as gr |
| import platform |
| import numpy as np |
| from pathlib import Path |
| from datetime import datetime |
| import tempfile |
|
|
| |
| from transformers import AutoProcessor, AutoModel |
|
|
| |
# Application-wide logging: INFO level, timestamped records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hugging Face Hub repository of the TTS checkpoint, and the directory that
# from_pretrained uses as its download cache for processor and model.
model_id = "DragonLineageAI/Vi-Spark-TTS-0.5B-v2"
cache_dir = "model_cache"

# Translates the integer position of the UI pitch/speed sliders (1..5) into
# the level names that run_voice_creation_tts forwards to the processor.
LEVELS_MAP_UI = dict(
    enumerate(("very_low", "low", "moderate", "high", "very_high"), start=1)
)
|
|
| |
def load_model_and_processor(model_id, cache_dir):
    """Load the processor and model with Transformers and choose a torch device.

    Args:
        model_id: Hugging Face Hub repo id (or local path) of the checkpoint.
        cache_dir: Directory passed to ``from_pretrained`` as its download cache.

    Returns:
        A ``(processor, model, device)`` tuple. The model is in eval mode and has
        been moved to ``device``. ``processor.model`` references the model.

    Raises:
        Whatever ``AutoProcessor.from_pretrained`` / ``AutoModel.from_pretrained``
        raise. The error is logged together with its traceback, then re-raised.
    """
    logging.info(f"Loading processor from: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
    except Exception as e:
        # logging.exception keeps the traceback; the UI directs users to the logs
        # when loading fails, so the stack trace is what they need to see.
        logging.exception(f"Error loading processor: {e}")
        raise
    else:
        logging.info("Processor loaded successfully.")

    logging.info(f"Loading model from: {model_id}")
    try:
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        model.eval()
    except Exception as e:
        logging.exception(f"Error loading model: {e}")
        raise
    else:
        logging.info("Model loaded successfully.")

    # The processor's decode step is given access to the model through this
    # attribute (set for the remote-code processor; see checkpoint docs).
    processor.model = model
    logging.info("Model reference set in processor.")

    # Keep the processor's sampling rate in sync with the model config. Only act
    # on a real (non-None) model value, and tolerate a processor that does not
    # define sampling_rate yet instead of raising AttributeError.
    model_sr = getattr(model.config, "sample_rate", None)
    processor_sr = getattr(processor, "sampling_rate", None)
    if model_sr is not None and processor_sr != model_sr:
        logging.warning(f"Processor SR ({processor_sr}) != Model Config SR ({model_sr}). Updating processor.")
        processor.sampling_rate = model_sr

    # Device preference: CUDA, then Apple MPS (macOS only), then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif platform.system() == "Darwin" and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    logging.info(f"Selected device: {device}")
    model.to(device)
    logging.info(f"Model moved to device: {device}")

    return processor, model, device
|
|
| |
# Load once at import time so every request reuses the same model. If loading
# fails, the app still starts (MODEL_LOADED gates generation), and the
# module-level names are bound to None. Code that later passes them along then
# gets None instead of a NameError.
try:
    processor, model, device = load_model_and_processor(model_id, cache_dir)
    MODEL_LOADED = True
except Exception as e:
    MODEL_LOADED = False
    processor = model = device = None
    logging.error(f"Failed to load model/processor: {e}")
| |
|
|
| |
|
|
def run_voice_clone_tts(
    text,
    prompt_speech_path,
    prompt_text,
    processor,
    model,
    device,
):
    """Synthesize ``text`` in the voice of a reference (prompt) audio clip.

    Args:
        text: Text to synthesize; lower-cased before being given to the processor.
        prompt_speech_path: Path to the reference audio file.
        prompt_text: Optional transcript of the reference audio. Treated as absent
            when empty or shorter than 2 characters after stripping; otherwise
            stripped and lower-cased.
        processor: Processor returned by ``load_model_and_processor``.
        model: Model returned by ``load_model_and_processor``.
        device: torch device; the processor outputs are moved there.

    Returns:
        ``(output_path, None)`` on success. ``output_path`` is a ``.wav`` file in
        the system temp directory that is not deleted automatically. Returns
        ``(None, error_message)`` on invalid input or any inference failure.
    """
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    # Reject whitespace-only text too, instead of sending it to the model.
    if not text or not text.strip():
        return None, "Error: Please provide text to synthesize."
    if not prompt_speech_path:
        return None, "Error: Please provide a prompt audio file (upload or record)."

    logging.info("Starting voice cloning inference...")
    logging.info(f"Inputs - Text: '{text}', Prompt Audio: {prompt_speech_path}, Prompt Text: '{prompt_text}'")

    try:
        # A transcript shorter than 2 characters is treated as not provided.
        prompt_text_clean = None if not prompt_text or len(prompt_text.strip()) < 2 else prompt_text.strip()

        inputs = processor(
            text=text.lower(),
            prompt_speech_path=prompt_speech_path,
            prompt_text=prompt_text_clean.lower() if prompt_text_clean else prompt_text_clean,
            return_tensors="pt",
        ).to(device)

        # Taken out of the kwargs forwarded to generate() and handed to decode().
        global_tokens_prompt = inputs.pop("global_token_ids_prompt", None)
        if global_tokens_prompt is None:
            logging.warning("global_token_ids_prompt not found in processor output. Decoding might be affected.")

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=processor.tokenizer.pad_token_id,
            )

        # input_ids_len lets decode() tell prompt tokens apart from generated ones.
        output_clone = processor.decode(
            generated_ids=output_ids,
            global_token_ids_prompt=global_tokens_prompt,
            input_ids_len=inputs["input_ids"].shape[-1],
        )

        # Create the temp file, close our handle, then write by path. Writing by
        # name while a NamedTemporaryFile is still open fails on Windows.
        fd, output_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            sf.write(output_path, output_clone["audio"], output_clone["sampling_rate"])
        except Exception:
            # Best-effort cleanup so failed writes don't leave stray files.
            try:
                os.remove(output_path)
            except OSError:
                pass
            raise

        logging.info(f"Voice cloning successful. Audio saved temporarily at: {output_path}")
        return output_path, None

    except Exception as e:
        logging.error(f"Error during voice cloning inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
|
|
|
|
def run_voice_creation_tts(
    text,
    gender,
    pitch_level,
    speed_level,
    processor,
    model,
    device,
):
    """Synthesize ``text`` with a new voice described by gender, pitch and speed.

    Args:
        text: Text to synthesize; lower-cased before being given to the processor.
        gender: Gender value forwarded to the processor (the UI offers
            ``"male"`` / ``"female"``).
        pitch_level: Integer level looked up in ``LEVELS_MAP_UI``. Unknown values
            fall back to ``"moderate"``.
        speed_level: Integer level looked up in ``LEVELS_MAP_UI``. Unknown values
            fall back to ``"moderate"``.
        processor: Processor returned by ``load_model_and_processor``.
        model: Model returned by ``load_model_and_processor``.
        device: torch device; the processor outputs are moved there.

    Returns:
        ``(output_path, None)`` on success. ``output_path`` is a ``.wav`` file in
        the system temp directory that is not deleted automatically. Returns
        ``(None, error_message)`` on invalid input or any inference failure.
    """
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    # Reject whitespace-only text too, instead of sending it to the model.
    if not text or not text.strip():
        return None, "Error: Please provide text to synthesize."

    pitch_str = LEVELS_MAP_UI.get(pitch_level, "moderate")
    speed_str = LEVELS_MAP_UI.get(speed_level, "moderate")

    logging.info("Starting voice creation inference...")
    logging.info(f"Inputs - Text: '{text}', Gender: {gender}, Pitch: {pitch_str} (Level {pitch_level}), Speed: {speed_str} (Level {speed_level})")

    try:
        inputs = processor(
            text=text.lower(),
            gender=gender,
            pitch=pitch_str,
            speed=speed_str,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=processor.tokenizer.pad_token_id,
            )

        # input_ids_len lets decode() tell prompt tokens apart from generated ones.
        output_create = processor.decode(
            generated_ids=output_ids,
            input_ids_len=inputs["input_ids"].shape[-1],
        )

        # Create the temp file, close our handle, then write by path. Writing by
        # name while a NamedTemporaryFile is still open fails on Windows.
        fd, output_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            sf.write(output_path, output_create["audio"], output_create["sampling_rate"])
        except Exception:
            # Best-effort cleanup so failed writes don't leave stray files.
            try:
                os.remove(output_path)
            except OSError:
                pass
            raise

        logging.info(f"Voice creation successful. Audio saved temporarily at: {output_path}")
        return output_path, None

    except Exception as e:
        logging.error(f"Error during voice creation inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
|
|
|
|
| |
def _examples_with_existing_audio(rows, audio_columns):
    """Return the example rows whose audio-file entries all exist on disk.

    Args:
        rows: Example rows, each a list aligned with a ``gr.Examples`` inputs list.
        audio_columns: Indices within each row that hold an audio path or ``None``.

    Returns:
        A new list with the rows whose non-``None`` audio paths are existing files.
        Each dropped row is logged as a warning.
    """
    kept = []
    for row in rows:
        paths = [row[i] for i in audio_columns if row[i] is not None]
        if all(os.path.isfile(path) for path in paths):
            kept.append(row)
        else:
            logging.warning(f"Skipping example with missing audio file(s): {paths}")
    return kept


def build_ui():
    """Build the Gradio Blocks app with "Voice Clone" and "Voice Creation" tabs.

    The generate buttons are non-interactive when ``MODEL_LOADED`` is False.
    Clone examples whose audio files are missing are left out, and the clone
    Examples table is omitted entirely when none remain. Referencing
    nonexistent files in ``gr.Examples`` can break app construction, depending
    on the Gradio version.

    Returns:
        The (not yet launched) ``gr.Blocks`` instance.
    """
    with gr.Blocks() as demo:
        gr.HTML('<h1 style="text-align: center;">Spark-TTS Demo (Transformers)</h1>')
        gr.Markdown(
            "Powered by [DragonLineageAI/Vi-Spark-TTS-0.5B-v2](https://huggingface.co/DragonLineageAI/Vi-Spark-TTS-0.5B-v2). "
            "Choose a tab for Voice Cloning or Voice Creation."
        )

        if not MODEL_LOADED:
            gr.Markdown("## ⚠️ Error: Model failed to load. Please check the Space logs.")

        with gr.Tabs():
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload Reference Audio or Record"
                )
                gr.Markdown(
                    "Provide a short audio clip (5-20 seconds) of the voice you want to clone. "
                    "Optionally, provide the transcript of that audio for better results, especially if the language is the same as the text you want to synthesize."
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload Prompt Audio File (WAV/MP3)",
                    )
                    prompt_wav_record = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Or Record Prompt Audio",
                    )

                with gr.Row():
                    text_input_clone = gr.Textbox(
                        label="Text to Synthesize",
                        lines=4,
                        placeholder="Enter text here...",
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of Prompt Speech (Optional)",
                        lines=2,
                        placeholder="Enter the transcript of the prompt audio (if available).",
                        info="Recommended for cloning in the same language.",
                    )

                audio_output_clone = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_clone = gr.Textbox(label="Status", interactive=False)

                generate_button_clone = gr.Button("Generate Cloned Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_clone_callback(text, prompt_text, audio_upload, audio_record):
                    # Prefer the uploaded file; fall back to the microphone recording.
                    prompt_speech = audio_upload if audio_upload else audio_record
                    if not prompt_speech:
                        return None, "Error: Please upload or record a reference audio."

                    output_path, error_msg = run_voice_clone_tts(
                        text,
                        prompt_speech,
                        prompt_text,
                        processor,
                        model,
                        device,
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                generate_button_clone.click(
                    voice_clone_callback,
                    inputs=[
                        text_input_clone,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output_clone, status_clone],
                )

                # Columns 2 and 3 hold the upload / record audio paths.
                clone_examples = _examples_with_existing_audio(
                    [
                        ["Hello, this is a test of voice cloning.", "I am a sample reference voice.", "examples/sample_prompt.wav", None],
                        ["You can experiment with different voices and texts.", None, None, "examples/sample_record.wav"],
                        ["The quality of the clone depends on the reference audio.", "This is the reference text.", "examples/another_prompt.wav", None],
                    ],
                    audio_columns=(2, 3),
                )
                if clone_examples:
                    gr.Examples(
                        examples=clone_examples,
                        inputs=[text_input_clone, prompt_text_input, prompt_wav_upload, prompt_wav_record],
                        outputs=[audio_output_clone, status_clone],
                        fn=voice_clone_callback,
                        cache_examples=False,
                        label="Clone Examples",
                    )

            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create Your Own Voice Based on the Following Parameters"
                )
                gr.Markdown(
                    "Select gender, adjust pitch and speed to generate a new synthetic voice."
                )

                with gr.Row():
                    with gr.Column(scale=1):
                        gender = gr.Radio(
                            choices=["male", "female"], value="female", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch (1=Lowest, 5=Highest)"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed (1=Slowest, 5=Fastest)"
                        )
                    with gr.Column(scale=2):
                        text_input_creation = gr.Textbox(
                            label="Text to Synthesize",
                            lines=5,
                            placeholder="Enter text here...",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )

                audio_output_creation = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_create = gr.Textbox(label="Status", interactive=False)

                create_button = gr.Button("Create New Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_creation_callback(text, gender, pitch_val, speed_val):
                    # int(): LEVELS_MAP_UI is keyed by int, and slider values may
                    # arrive as floats.
                    output_path, error_msg = run_voice_creation_tts(
                        text,
                        gender,
                        int(pitch_val),
                        int(speed_val),
                        processor,
                        model,
                        device,
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                create_button.click(
                    voice_creation_callback,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                )

                gr.Examples(
                    examples=[
                        ["This is a female voice with average pitch and speed.", "female", 3, 3],
                        ["This is a male voice, speaking quickly with a slightly higher pitch.", "male", 4, 4],
                        ["A deep and slow female voice.", "female", 1, 2],
                        ["A very high-pitched and fast male voice.", "male", 5, 5],
                    ],
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                    fn=voice_creation_callback,
                    cache_examples=False,
                    label="Creation Examples",
                )
    return demo
|
|
| |
if __name__ == "__main__":
    # Script entry point: assemble the Gradio Blocks app, then start serving it.
    demo = build_ui()
    demo.launch()