"""Translation interface using the MADLAD-400 3B model.

Translates English text to nearly 400 languages.
"""

import warnings
from functools import lru_cache

import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from langmap.langid_mapping import langid_to_language

MODEL_NAME = "google/madlad400-3b-mt"

# Language pre-selected in the UI when it is among the supported languages.
DEFAULT_TARGET_LANGUAGE = "Hawaiian"


def _get_device() -> torch.device:
    """Return the torch device to run on.

    MPS is checked first, then CUDA. When neither is available a warning
    is emitted and the CPU device is returned.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    warnings.warn(
        "No GPU available. Running on CPU — translation will be very slow.",
        stacklevel=2,
    )
    return torch.device("cpu")


@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Load the tokenizer for ``MODEL_NAME`` (cached after the first call).

    Raises:
        RuntimeError: If ``from_pretrained`` returns ``None``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if tokenizer is None:
        raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
    return tokenizer


@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Load the model in float16 and move it to the selected device.

    The result is cached, so the weights are loaded only once per process.
    """
    device = _get_device()
    # NOTE(review): float16 is also used on the CPU fallback; confirm this
    # is acceptable for the torch version in use.
    return AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME, torch_dtype=torch.float16
    ).to(device)


@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the language-name lookup used by the UI and by ``translate``.

    Only entries of ``langid_to_language`` whose key is present in the
    tokenizer vocabulary are kept.

    The result is cached: ``translate`` calls this on every request, and
    rebuilding the vocabulary dict each time is wasted work. The returned
    objects are shared between callers and must not be mutated.

    Returns:
        A pair ``(name_to_code, names)`` where ``name_to_code`` maps a
        language name to its key in ``langid_to_language`` and ``names``
        is the sorted list of those language names.
    """
    tokenizer = _load_tokenizer()
    vocab = tokenizer.get_vocab()
    name_to_code = {
        name: code for code, name in langid_to_language.items() if code in vocab
    }
    return name_to_code, sorted(name_to_code)


@spaces.GPU
def translate(text: str, target_language_name: str) -> str:
    """Translate ``text`` into the language named ``target_language_name``.

    Args:
        text: Source text to translate.
        target_language_name: A language name as returned by
            ``_build_language_mappings``.

    Returns:
        The decoded model output, or ``""`` when ``text`` is empty or
        contains only whitespace.

    Raises:
        ValueError: If ``target_language_name`` is not a supported language.
        TypeError: If the tokenizer's ``decode`` does not return a ``str``.
    """
    # Validate the language first so an unsupported name is reported before
    # the (expensive) model load.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # With nothing to translate the model would receive only the language
    # prefix and generate unrelated text; return an empty result instead.
    if not text or not text.strip():
        return ""

    tokenizer = _load_tokenizer()
    model = _load_model()
    device = _get_device()

    # The target language is selected by prefixing its code to the input.
    input_ids = tokenizer(target_code + text, return_tensors="pt").input_ids.to(device)
    outputs = model.generate(input_ids=input_ids, max_new_tokens=512)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result


def main() -> None:
    """Build the Gradio interface and launch it."""
    _, language_names = _build_language_mappings()

    # The choices are filtered by the tokenizer vocabulary, so make sure the
    # pre-selected value is actually one of them.
    if DEFAULT_TARGET_LANGUAGE in language_names:
        default_language = DEFAULT_TARGET_LANGUAGE
    elif language_names:
        default_language = language_names[0]
    else:
        default_language = None

    demo = gr.Interface(
        fn=translate,
        inputs=[
            gr.Textbox(label="Text", placeholder="Enter text here"),
            gr.Dropdown(
                choices=language_names,
                value=default_language,
                label="Target language",
            ),
        ],
        outputs=gr.Textbox(label="Translation"),
        title="MADLAD-400 Translation",
        description="Translation from English to (almost) 400 languages based on "
        "[research](https://arxiv.org/pdf/2309.04662) by Google DeepMind and Google Research.",
    )
    demo.launch()


if __name__ == "__main__":
    main()