Spaces:
Running on Zero
Running on Zero
| """ | |
| Translation interface using the MADLAD-400 3B model. | |
| Translates English text to nearly 400 languages. | |
| """ | |
| import warnings | |
| from functools import lru_cache | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from langmap.langid_mapping import langid_to_language | |
# Model identifier passed to both AutoTokenizer.from_pretrained and
# AutoModelForSeq2SeqLM.from_pretrained, so tokenizer and model always match.
MODEL_NAME = "google/madlad400-3b-mt"
def _get_device() -> torch.device:
    """Pick the accelerator to run on: Apple MPS first, then CUDA, else CPU.

    Falling back to CPU emits a warning, since a 3B-parameter model is very
    slow without a GPU.
    """
    # Probed lazily and in priority order: CUDA is only queried when MPS is absent.
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend, available in candidates:
        if available():
            return torch.device(backend)
    warnings.warn("No GPU available. Running on CPU — translation will be very slow.", stacklevel=2)
    return torch.device("cpu")
@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Load the fast tokenizer for MODEL_NAME once and reuse it afterwards.

    The tokenizer is needed both to build the language list and on every
    translation request. Caching it avoids re-reading it on each call. A failed
    load is not cached (lru_cache does not memoize exceptions), so a later call
    will retry.

    Returns:
        The tokenizer for ``MODEL_NAME``.

    Raises:
        RuntimeError: If ``from_pretrained`` returns ``None``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    # Defensive guard; from_pretrained is expected to raise rather than return None.
    if tokenizer is None:
        raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
    return tokenizer
@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Load the MADLAD-400 model onto the preferred device, once per process.

    The 3B-parameter checkpoint is far too expensive to reload per request, so
    the loaded model is cached and shared by every caller.

    Accelerators get float16 weights, as before. On the CPU fallback the model
    is kept in float32, because float16 kernels are missing from some PyTorch
    CPU builds and very slow where present.

    Returns:
        The seq2seq model, already moved to the device chosen by ``_get_device``.
    """
    device = _get_device()
    dtype = torch.float32 if device.type == "cpu" else torch.float16
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=dtype)
    return model.to(device)
@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Map human-readable language names to the model's target-language tokens.

    Only entries of ``langid_to_language`` whose key (the language tag) is present
    in the tokenizer vocabulary are kept. If several tags share the same display
    name, the last one in ``langid_to_language`` iteration order wins.

    The result is cached because it is consulted on every translation request,
    and ``get_vocab()`` materializes a dict of the whole (large) vocabulary.

    Returns:
        ``(name_to_code, sorted_names)``. Because of the cache, both objects are
        shared between callers and must be treated as read-only.
    """
    vocab = _load_tokenizer().get_vocab()
    name_to_code = {name: code for code, name in langid_to_language.items() if code in vocab}
    return name_to_code, sorted(name_to_code)
def translate(text: str, target_language_name: str) -> str:
    """Translate *text* into the language whose display name is *target_language_name*.

    The prompt is the language's tag token (looked up via ``langid_to_language``),
    a space, then the source text. This follows the ``"<2xx> text"`` format
    shown on the MADLAD-400 model card.

    Args:
        text: Source text to translate.
        target_language_name: A display name as listed in the language dropdown.

    Returns:
        The decoded translation. Empty or whitespace-only input returns ``""``
        without running the model.

    Raises:
        ValueError: If *target_language_name* is not a supported language.
        TypeError: If decoding produces something other than a ``str``.
    """
    # NOTE(review): this Space appears to run on ZeroGPU ("Running on Zero") and
    # `spaces` is imported but unused; ZeroGPU normally requires GPU work to run
    # inside a `@spaces.GPU`-decorated function — confirm and decorate if so.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    source = text.strip() if text else ""
    if not source:
        # With no source text the model would only see the language tag and emit
        # arbitrary output; skip generation entirely.
        return ""

    tokenizer = _load_tokenizer()
    model = _load_model()
    prompt = f"{target_code} {source}"
    # Send inputs to wherever the model's weights actually live.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(input_ids=input_ids, max_new_tokens=512)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result
def main() -> None:
    """Build the Gradio translation interface and start serving it.

    The dropdown offers every language name kept by ``_build_language_mappings``
    and starts on "Hawaiian"; submissions are handled by ``translate``.
    """
    language_names = _build_language_mappings()[1]
    source_box = gr.Textbox(label="Text", placeholder="Enter text here")
    language_picker = gr.Dropdown(choices=language_names, value="Hawaiian", label="Target language")
    blurb = (
        "Translation from English to (almost) 400 languages based on "
        "[research](https://arxiv.org/pdf/2309.04662) by Google DeepMind and Google Research."
    )
    gr.Interface(
        fn=translate,
        inputs=[source_box, language_picker],
        outputs=gr.Textbox(label="Translation"),
        title="MADLAD-400 Translation",
        description=blurb,
    ).launch()
# Start the web UI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()