"""
Translation interface using the MADLAD-400 3B model.

Translates between 418 languages from the MADLAD-400 paper.
"""

import os
import time
import warnings
from collections.abc import Generator
from functools import lru_cache

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5TokenizerFast,
)

from langmap.langid_mapping import langid_to_language

MODEL_NAME = "google/madlad400-3b-mt"


def _get_device() -> torch.device:
    """Return the CUDA device when available, otherwise warn and return the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
    return torch.device("cpu")


@lru_cache(maxsize=1)
def _load_tokenizer() -> T5TokenizerFast:
    """Load the fast tokenizer for ``MODEL_NAME``; cached so it is loaded once per process.

    Raises:
        RuntimeError: if ``from_pretrained`` hands back ``None``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if tokenizer is None:
        raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
    return tokenizer


@lru_cache(maxsize=1)
def _load_model() -> T5ForConditionalGeneration:
    """Load ``MODEL_NAME`` onto the selected device; cached so it is loaded once per process.

    Weights are requested as float16 on CUDA and float32 on CPU.
    """
    device = _get_device()
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    return AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=dtype).to(device)


@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the display-name lookup for every language tag the tokenizer knows.

    Returns:
        A ``(name_to_code, sorted_names)`` pair: ``name_to_code`` maps a display
        name such as ``"French (fr)"`` to its tag (``"<2fr>"``), and
        ``sorted_names`` lists the display names ordered by region, then
        alphabetically within each region.

    The result is cached and shared between callers, so treat it as read-only.
    """
    tokenizer = _load_tokenizer()
    vocab = tokenizer.get_vocab()

    name_to_code: dict[str, str] = {}
    for code, info in langid_to_language.items():
        # Only offer languages whose tag is an actual token in the vocabulary.
        if code in vocab:
            locale = code[2:-1]  # <2fr> → fr
            display_name = f"{info['name']} ({locale})"
            name_to_code[display_name] = code

    # Sort by region, then alphabetically within each region
    sorted_names = sorted(
        name_to_code,
        key=lambda n: (langid_to_language[name_to_code[n]]["region"], n),
    )
    return name_to_code, sorted_names


def _maybe_eager_load() -> None:
    """Eagerly load the tokenizer and model when running on ZeroGPU.

    On ZeroGPU, place the model at module scope so the ``spaces`` hijack can
    pack weights to disk at startup and stream them into each worker's VRAM
    (fast cold starts).

    Off-ZeroGPU (local, tests, cpu-basic) this is a no-op, so importing the app
    never downloads the model weights here (the tokenizer is still loaded at
    import time, because building the UI needs the language list).
    ``SPACES_ZERO_GPU`` is set only on ZeroGPU.
    """
    if os.environ.get("SPACES_ZERO_GPU") == "1":
        _load_tokenizer()
        _load_model()


def _estimate_duration(
    text: str,
    target_language_name: str,
    max_new_tokens: int = 512,
    num_beams: int = 1,
    temperature: float = 1.0,
) -> int:
    """Return the number of GPU seconds to reserve for one ``translate()`` call.

    Reserve GPU time scaled to the worst case: generation cost grows with the
    number of tokens generated and the beam width. Mirrors translate()'s
    signature (ZeroGPU calls the duration callable with the decorated
    function's args). Conservative and capped at 120s; calibrate from the
    perf_counter log in translate() (zerogpu.md 'Sizing duration').
    """
    del text, target_language_name, temperature  # only token/beam counts drive runtime
    return min(120, 30 + (max_new_tokens * num_beams) // 8)


@spaces.GPU(duration=_estimate_duration)
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int = 512,
    num_beams: int = 1,
    temperature: float = 1.0,
) -> str:
    """Translate ``text`` into the language selected by ``target_language_name``.

    Args:
        text: Source text. Blank or whitespace-only input yields ``""`` without
            running the model.
        target_language_name: Display name as produced by
            ``_build_language_mappings()`` (e.g. ``"French (fr)"``).
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam width; ``1`` selects sampling, ``> 1`` selects beam search.
        temperature: Sampling temperature; only applied when ``num_beams == 1``.

    Returns:
        The decoded translation with special tokens stripped.

    Raises:
        ValueError: if ``target_language_name`` is not a known display name.
    """
    tokenizer = _load_tokenizer()
    model = _load_model()
    device = model.device

    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # Nothing to translate: feeding only the language tag to the model would
    # make it generate unrelated text, so short-circuit instead.
    if not text or not text.strip():
        return ""

    if num_beams > 1 and temperature != 1.0:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    # The model is steered by prefixing the source text with the target tag.
    input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(device)

    generate_kwargs: dict = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
    }
    if num_beams == 1:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature

    start = time.perf_counter()
    outputs = model.generate(**generate_kwargs)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    elapsed = time.perf_counter() - start
    # Timing log used to calibrate _estimate_duration().
    print(f"[translate] max_new_tokens={max_new_tokens} num_beams={num_beams} took {elapsed:.1f}s")
    return result


def _translate_with_loading(
    text: str,
    target_language_name: str,
) -> Generator[tuple, None, None]:
    """Run ``translate()`` while showing a busy state on the Translate button.

    Yields ``(button_update, output_update)`` pairs: first a disabled
    "Translating..." button, then the re-enabled button together with the
    translation.

    If ``translate()`` raises, the button is restored before the exception is
    re-raised; otherwise the last state the UI received would be the disabled
    "Translating..." button and it would stay unusable.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        result = translate(text, target_language_name)
    except Exception:
        # Re-enable the button, leave the output box untouched, then let the
        # original error propagate so it is still reported to the user.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    yield gr.update(value="Translate", interactive=True), result


def _swap_languages(
    source_lang: str, target_lang: str, source_text: str, target_text: str
) -> tuple[str, str, str, str]:
    """Swap source/target languages and their text."""
    return target_lang, source_lang, target_text, source_text


def _build_demo() -> gr.Blocks:
    """Assemble the Gradio UI and wire its event handlers.

    Layout: a row with the source/target language dropdowns and a swap button,
    a row with the input and (read-only) output text boxes, and a Translate
    button. Only the text box submit handler is exposed as a named API
    endpoint (``translate``).
    """
    _, language_names = _build_language_mappings()

    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row():
            source_language = gr.Dropdown(
                choices=language_names,
                value="English (en)",
                show_label=False,
                filterable=True,
            )
            swap_btn = gr.Button("⇄", scale=0, min_width=60)
            target_language = gr.Dropdown(
                choices=language_names,
                value="French (fr)",
                show_label=False,
                filterable=True,
            )

        with gr.Row(equal_height=True):
            input_text = gr.Textbox(
                lines=6,
                max_length=2000,
                show_label=False,
                autofocus=True,
            )
            output_text = gr.Textbox(
                placeholder="Translation",
                lines=6,
                show_label=False,
                interactive=False,
                buttons=["copy"],
            )

        translate_btn = gr.Button("Translate", variant="primary")

        # UI-only handlers: kept off the public API surface (private) so only /translate is exposed.
        swap_btn.click(
            fn=_swap_languages,
            inputs=[source_language, target_language, input_text, output_text],
            outputs=[source_language, target_language, input_text, output_text],
            api_visibility="private",
        )
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=[input_text, target_language],
            outputs=[translate_btn, output_text],
            api_visibility="private",
        )
        input_text.submit(
            fn=translate,
            inputs=[input_text, target_language],
            outputs=output_text,
            api_name="translate",
        )

    return demo


# Built at import time so hosting environments can pick up the module-level `demo`.
demo = _build_demo()
_maybe_eager_load()


def main() -> None:
    """Launch the Gradio app with the Ocean theme."""
    demo.launch(theme=gr.themes.Ocean())


if __name__ == "__main__":
    main()