import gradio as gr
import librosa
import numpy as np
import scipy  # if needed for processing
import torch
from transformers import AutoTokenizer, VitsModel, pipeline

# -----------------------------------------------
# 1. ASR pipeline: English speech -> English text
# -----------------------------------------------
asr = pipeline(
    "automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h",
)

# -----------------------------------------------
# 2. Translation models: English -> one of three target languages
# -----------------------------------------------
# Target-language display name -> short code used in both the Helsinki-NLP
# checkpoint name and the transformers translation task name.
# NOTE(review): confirm "Helsinki-NLP/opus-mt-en-ja" exists on the Hub; the
# en-zh checkpoint may be multi-target and expect a leading ">>id<<" token
# (e.g. ">>cmn_Hans<<") per its model card — verify.
_TARGET_CODES = {
    "Spanish": "es",
    "Chinese": "zh",
    "Japanese": "ja",
}

translation_models = {
    lang: f"Helsinki-NLP/opus-mt-en-{code}" for lang, code in _TARGET_CODES.items()
}
translation_tasks = {
    lang: f"translation_en_to_{code}" for lang, code in _TARGET_CODES.items()
}

# -----------------------------------------------
# 3. TTS model configurations
#    Loaded manually later rather than via pipeline("text-to-speech").
#    "architecture" selects the loading/inference path.
# -----------------------------------------------
tts_config = {
    # MMS TTS checkpoint (VITS architecture).
    "Spanish": {"model_id": "facebook/mms-tts-spa", "architecture": "vits"},
    # NOTE(review): "che" is the ISO 639-3 code for Chechen, not Chinese, so this
    # checkpoint will not speak Chinese. Mandarin would be "cmn" — confirm an MMS
    # TTS checkpoint exists for it (or pick another Chinese TTS) before switching.
    "Chinese": {"model_id": "facebook/mms-tts-che", "architecture": "vits"},
    # SpeechT5-based checkpoint; handled by a separate code path.
    "Japanese": {"model_id": "esnya/japanese_speecht5_tts", "architecture": "speecht5"},
}

# -----------------------------------------------
# 4. Lazy-load caches (filled on first use per language)
# -----------------------------------------------
translator_cache = {}  # language -> translation pipeline
tts_model_cache = {}  # language -> (model, tokenizer, architecture)
# -----------------------------------------------
# 5. Translator Helper
# -----------------------------------------------
def get_translator(lang):
    """Return the English -> `lang` translation pipeline, creating it once.

    Raises KeyError if `lang` is missing from `translation_models` /
    `translation_tasks`.
    """
    if lang in translator_cache:
        return translator_cache[lang]
    model_name = translation_models[lang]
    task_name = translation_tasks[lang]
    translator = pipeline(task_name, model=model_name)
    translator_cache[lang] = translator
    return translator


# -----------------------------------------------
# 6. TTS Helper
# -----------------------------------------------
# HiFi-GAN vocoder(s) used by SpeechT5 checkpoints, keyed by model id.
_speecht5_vocoder_cache = {}


def _get_speecht5_vocoder(vocoder_id="microsoft/speecht5_hifigan"):
    """Load (once) the vocoder that turns SpeechT5 spectrograms into waveforms."""
    if vocoder_id not in _speecht5_vocoder_cache:
        from transformers import SpeechT5HifiGan

        _speecht5_vocoder_cache[vocoder_id] = SpeechT5HifiGan.from_pretrained(vocoder_id)
    return _speecht5_vocoder_cache[vocoder_id]


def get_tts_model(lang):
    """
    Load (model, tokenizer, architecture) for `lang` from Hugging Face once, then cache.

    Raises:
        ValueError: if `lang` has no entry in `tts_config`.
        RuntimeError: if loading fails, including an unknown architecture.
            The original error is chained.
    """
    if lang in tts_model_cache:
        return tts_model_cache[lang]

    config = tts_config.get(lang)
    if config is None:
        raise ValueError(f"No TTS config found for language: {lang}")

    model_id = config["model_id"]
    arch = config["architecture"]

    try:
        if arch == "vits":
            model = VitsModel.from_pretrained(model_id)
        elif arch == "speecht5":
            # A SpeechT5 checkpoint is not a VITS model. It predicts a spectrogram,
            # so it must be loaded as SpeechT5ForTextToSpeech and paired with a
            # vocoder at inference time (see run_tts_inference).
            from transformers import SpeechT5ForTextToSpeech

            model = SpeechT5ForTextToSpeech.from_pretrained(model_id)
        else:
            raise ValueError(f"Unknown TTS architecture: {arch}")
        # NOTE(review): some SpeechT5 checkpoints (Japanese ones in particular) ship
        # a custom tokenizer. Confirm AutoTokenizer can load the configured model_id.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        raise RuntimeError(f"Failed to load TTS model {model_id}: {e}") from e

    tts_model_cache[lang] = (model, tokenizer, arch)
    return tts_model_cache[lang]


def run_tts_inference(lang, text):
    """
    Generate a waveform for `text` with the TTS model configured for `lang`.

    Returns (sample_rate, np_array).

    Raises:
        RuntimeError: if `text` is empty or tokenizes to nothing, or if a VITS
            model output lacks a `waveform` attribute.
    """
    if not text or not text.strip():
        raise RuntimeError("No text to synthesize.")

    model, tokenizer, arch = get_tts_model(lang)
    inputs = tokenizer(text, return_tensors="pt")
    if inputs["input_ids"].numel() == 0:
        raise RuntimeError("The TTS tokenizer produced no tokens for this text.")

    with torch.no_grad():
        if arch == "speecht5":
            vocoder = _get_speecht5_vocoder()
            # generate_speech needs a speaker x-vector. A zero vector is only a
            # neutral placeholder, and voice quality may be poor. Supply a real
            # speaker embedding for better output.
            speaker_dim = getattr(model.config, "speaker_embedding_dim", 512)
            speaker_embeddings = torch.zeros((1, speaker_dim))
            waveform_tensor = model.generate_speech(
                inputs["input_ids"], speaker_embeddings, vocoder=vocoder
            )
            sample_rate = getattr(vocoder.config, "sampling_rate", 16000)
        else:
            output = model(**inputs)
            # VitsModel returns its audio in `.waveform`.
            if not hasattr(output, "waveform"):
                raise RuntimeError("The TTS model output doesn't have 'waveform' attribute.")
            waveform_tensor = output.waveform
            # Use the checkpoint's own rate rather than assuming 16 kHz.
            sample_rate = getattr(model.config, "sampling_rate", 16000)

    waveform = waveform_tensor.squeeze().cpu().numpy()
    return (sample_rate, waveform)


# -----------------------------------------------
# 7. Prediction Function
# -----------------------------------------------
def _prepare_asr_audio(sample_rate, audio_data):
    """Convert a (sample_rate, ndarray) recording to mono float32 samples at 16 kHz."""
    audio_data = np.asarray(audio_data)
    if np.issubdtype(audio_data.dtype, np.integer):
        # Integer PCM must be scaled into [-1, 1), not just cast. Unsigned
        # formats are offset-binary, so they are re-centred first.
        info = np.iinfo(audio_data.dtype)
        midpoint = (int(info.max) + int(info.min) + 1) / 2  # 0 for int16, 128 for uint8
        scale = (int(info.max) - int(info.min) + 1) / 2  # 32768 for int16
        audio_data = (audio_data.astype(np.float64) - midpoint) / scale
    audio_data = audio_data.astype(np.float32)

    # Axis 1 is treated as the channel axis. Collapse every 2-D array, including
    # a single (N, 1) channel, so resampling and ASR see 1-D samples.
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)

    if sample_rate != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
    return audio_data


def predict(audio, text, target_language):
    """
    1. If text is provided, use it directly as English text.
       Else, if audio is provided, run ASR.
    2. Translate English -> target_language.
    3. Run TTS with the correct approach for that language.

    Returns (english_text, translated_text, audio). `audio` is a
    (sample_rate, ndarray) tuple, or None when no speech could be produced.
    """
    # Step 1: English text
    if text and text.strip():
        english_text = text.strip()
    elif audio is not None:
        sample_rate, audio_data = audio
        if np.size(audio_data) == 0:
            return "No input provided.", "", None
        audio_data = _prepare_asr_audio(sample_rate, audio_data)
        asr_result = asr({"array": audio_data, "sampling_rate": 16000})
        english_text = asr_result["text"].strip()
        if not english_text:
            return "No speech recognized.", "", None
    else:
        return "No input provided.", "", None

    # Step 2: Translation. Model loading is inside the try, so download or
    # lookup failures are reported instead of crashing the request.
    try:
        translator = get_translator(target_language)
        translation_result = translator(english_text)
        translated_text = translation_result[0]["translation_text"]
    except Exception as e:
        return english_text, f"Translation error: {e}", None

    # Step 3: TTS
    try:
        sample_rate, waveform = run_tts_inference(target_language, translated_text)
    except Exception as e:
        # The third output feeds gr.Audio, which expects audio data or a file
        # path. Report the failure in the text output and leave audio empty.
        return english_text, f"{translated_text}\n\nTTS error: {e}", None

    return english_text, translated_text, (sample_rate, waveform)


# -----------------------------------------------
# 8. Gradio Interface
# -----------------------------------------------
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
        gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
        gr.Dropdown(choices=["Spanish", "Chinese", "Japanese"], value="Spanish", label="Target Language"),
    ],
    outputs=[
        gr.Textbox(label="English Transcription"),
        gr.Textbox(label="Translation (Target Language)"),
        gr.Audio(label="Synthesized Speech in Target Language"),
    ],
    title="Multimodal Language Learning Aid (VITS-based TTS)",
    description=(
        "This app:\n"
        "1. Transcribes English speech (via ASR) or accepts English text.\n"
        "2. Translates to Spanish, Chinese, or Japanese.\n"
        "3. Synthesizes speech with VITS-based or SpeechT5-based models.\n\n"
        "Note: Some models are experimental and may produce errors or poor quality.\n"
        "Either upload/record English audio or enter text, then select a target language."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()