""" Translation interface using the MADLAD-400 3B model. Translates English text to 22 production-ready languages from the MADLAD-400 paper. """ import warnings from functools import lru_cache import gradio as gr import spaces import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from langmap.langid_mapping import langid_to_language MODEL_NAME = "google/madlad400-3b-mt" def _get_device() -> torch.device: if torch.cuda.is_available(): return torch.device("cuda") warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2) return torch.device("cpu") @lru_cache(maxsize=1) def _load_tokenizer() -> AutoTokenizer: tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True) if tokenizer is None: raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}") return tokenizer @lru_cache(maxsize=1) def _load_model() -> AutoModelForSeq2SeqLM: device = _get_device() dtype = torch.float16 if device.type == "cuda" else torch.float32 return AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=dtype).to(device) @lru_cache(maxsize=1) def _build_language_mappings() -> tuple[dict[str, str], list[str]]: tokenizer = _load_tokenizer() vocab = tokenizer.get_vocab() name_to_code = {} for code, name in langid_to_language.items(): if code in vocab: locale = code[2:-1] # <2fr> → fr display_name = f"{name} ({locale})" name_to_code[display_name] = code return name_to_code, sorted(name_to_code.keys()) @spaces.GPU def translate( text: str, target_language_name: str, max_new_tokens: int = 512, num_beams: int = 1, temperature: float = 1.0, ) -> str: tokenizer = _load_tokenizer() model = _load_model() device = model.device name_to_code, _ = _build_language_mappings() target_code = name_to_code.get(target_language_name) if target_code is None: raise ValueError(f"Unsupported language: {target_language_name}") if num_beams > 1 and temperature != 1.0: gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).") input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(device) generate_kwargs: dict = {"input_ids": input_ids, "max_new_tokens": max_new_tokens, "num_beams": num_beams} if num_beams == 1: generate_kwargs["do_sample"] = True generate_kwargs["temperature"] = temperature outputs = model.generate(**generate_kwargs) result = tokenizer.decode(outputs[0], skip_special_tokens=True) if not isinstance(result, str): raise TypeError(f"Expected str from decode, got {type(result)}") return result def _update_char_count(text: str) -> str: return f"{len(text)} characters" def _build_demo() -> gr.Blocks: _, language_names = _build_language_mappings() with gr.Blocks(title="MADLAD-400 Translation") as demo: with gr.Row(equal_height=True): with gr.Column(): gr.Markdown("**English**") input_text = gr.Textbox( placeholder="Enter English text here", lines=4, show_label=False, ) char_count = gr.Markdown("0 characters") with gr.Column(): target_language = gr.Dropdown( choices=language_names, value="French (fr)", show_label=False, filterable=True, ) output_text = gr.Textbox( placeholder="Translation", lines=4, show_label=False, interactive=False, buttons=["copy"], ) with gr.Row(): clear_btn = gr.Button("Clear") translate_btn = gr.Button("Translate", variant="primary") translate_btn.click( fn=translate, inputs=[input_text, target_language], outputs=output_text, ) input_text.submit( fn=translate, inputs=[input_text, target_language], outputs=output_text, ) input_text.input( fn=_update_char_count, inputs=input_text, outputs=char_count, ) clear_btn.click( fn=lambda: ("", "", "0 characters"), outputs=[input_text, output_text, char_count], ) return demo demo = _build_demo() def main() -> None: demo.launch(theme=gr.themes.Soft()) if __name__ == "__main__": main()