""" Translation interface using the MADLAD-400 3B model. Translates English text to evaluated languages from the MADLAD-400 paper. """ import warnings from functools import lru_cache import gradio as gr import spaces import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from langmap.langid_mapping import langid_to_language MODEL_NAME = "google/madlad400-3b-mt" def _get_device() -> torch.device: if torch.cuda.is_available(): return torch.device("cuda") warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2) return torch.device("cpu") @lru_cache(maxsize=1) def _load_tokenizer() -> AutoTokenizer: tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True) if tokenizer is None: raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}") return tokenizer @lru_cache(maxsize=1) def _load_model() -> AutoModelForSeq2SeqLM: device = _get_device() dtype = torch.float16 if device.type == "cuda" else torch.float32 return AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=dtype).to(device) @lru_cache(maxsize=1) def _build_language_mappings() -> tuple[dict[str, str], list[str]]: tokenizer = _load_tokenizer() vocab = tokenizer.get_vocab() name_to_code = {v: k for k, v in langid_to_language.items() if k in vocab} return name_to_code, sorted(name_to_code.keys()) @spaces.GPU def translate(text: str, target_language_name: str) -> str: tokenizer = _load_tokenizer() model = _load_model() device = model.device name_to_code, _ = _build_language_mappings() target_code = name_to_code.get(target_language_name) if target_code is None: raise ValueError(f"Unsupported language: {target_language_name}") input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(device) outputs = model.generate(input_ids=input_ids, max_new_tokens=512) result = tokenizer.decode(outputs[0], skip_special_tokens=True) if not isinstance(result, str): raise TypeError(f"Expected str from decode, got {type(result)}") return result def _build_demo() -> gr.Blocks: _, language_names = _build_language_mappings() with gr.Blocks(title="MADLAD-400 Translation") as demo: gr.Markdown( "# MADLAD-400 Translation\n" "Translate English into 183 languages using Google's MADLAD-400 3B model. " "[Paper](https://arxiv.org/pdf/2309.04662)" ) target_language = gr.Dropdown( choices=language_names, value="French", label="Target language", ) with gr.Row(): input_text = gr.Textbox( label="English", placeholder="Enter English text here", lines=4, ) output_text = gr.Textbox( label="Translation", lines=4, buttons=["copy"], interactive=False, ) translate_btn = gr.Button("Translate", variant="primary") translate_btn.click( fn=translate, inputs=[input_text, target_language], outputs=output_text, ) gr.Examples( examples=[ ["Hello, how are you today?", "French"], ["The weather is beautiful.", "Japanese"], ["Thank you very much.", "Swahili"], ["Where is the train station?", "Hindi"], ], inputs=[input_text, target_language], ) return demo demo = _build_demo() def main() -> None: demo.launch(theme=gr.themes.Soft()) if __name__ == "__main__": main()