""" Translation interface using the MADLAD-400 3B model. Translates between 418 languages from the MADLAD-400 paper. """ import os import time import warnings from collections.abc import Generator from functools import lru_cache import gradio as gr import spaces import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, T5TokenizerFast from langmap.langid_mapping import langid_to_language MODEL_NAME = "google/madlad400-3b-mt" # Target tokens whose script is right-to-left, used to flip the output box's direction. # The langmap stores a geographic "region", not script direction, and region is NOT a usable # proxy: the "Middle East & North Africa" region also holds LTR/Latin languages (Turkish <2tr>, # Kurmanji <2ku>, Zaza <2zza>, Takestani <2tks>) while RTL Dhivehi <2dv> sits in "South Asia". # So enumerate the RTL tokens explicitly. Kurmanji <2ku> (Latin) is excluded; only Sorani # Kurdish <2ckb> (Arabic script) is RTL. RTL_CODES = frozenset( { "<2ar>", # Arabic "<2mey>", # Hassaniyya (Arabic script) "<2he>", # Hebrew "<2ks>", # Kashmiri (Perso-Arabic) "<2ckb>", # Kurdish, Sorani (Arabic script) "<2lrc>", # Northern Luri (Perso-Arabic) "<2luz>", # Southern Luri (Perso-Arabic) "<2bgp>", # Eastern Baluchi (Perso-Arabic) "<2ps>", # Pashto "<2fa>", # Persian "<2skr>", # Saraiki (Perso-Arabic) "<2sd>", # Sindhi (Perso-Arabic) "<2syr>", # Syriac "<2ur>", # Urdu "<2ug>", # Uyghur (Arabic script) "<2yi>", # Yiddish (Hebrew script) "<2dv>", # Dhivehi (Thaana) } ) def _get_device() -> torch.device: if torch.cuda.is_available(): return torch.device("cuda") warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2) return torch.device("cpu") @lru_cache(maxsize=1) def _load_tokenizer() -> T5TokenizerFast: tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True) if tokenizer is None: raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}") return tokenizer @lru_cache(maxsize=1) def _load_model() -> T5ForConditionalGeneration: device = _get_device() # T5/MADLAD was trained in bfloat16 and is numerically unstable in float16: its # activations overflow fp16's narrow range, risking inf/NaN (garbage) output. Use # bfloat16 on CUDA (same 2-byte footprint, native on the Blackwell backing GPU, # fp32-range exponent) and float32 on CPU. dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 # MADLAD-400 is a T5 model: its relative-position-bias attention is eager-only # (transformers rejects sdpa / flash_attention_2 / flash_attention_3), so the usual # ZeroGPU attention speedups don't apply. AoTI targets fixed-shape forward passes # (e.g. diffusion), not T5's autoregressive .generate(), so it isn't wired up either. return AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=dtype).to(device) @lru_cache(maxsize=1) def _build_language_mappings() -> tuple[dict[str, str], list[str]]: tokenizer = _load_tokenizer() vocab = tokenizer.get_vocab() name_to_code: dict[str, str] = {} for code, info in langid_to_language.items(): if code in vocab: locale = code[2:-1] # <2fr> → fr display_name = f"{info['name']} ({locale})" name_to_code[display_name] = code # Sort by region, then alphabetically within each region sorted_names = sorted( name_to_code.keys(), key=lambda n: (langid_to_language[name_to_code[n]]["region"], n), ) return name_to_code, sorted_names def _maybe_eager_load() -> None: """On ZeroGPU, place the model at module scope so the ``spaces`` hijack can pack weights to disk at startup and stream them into each worker's VRAM (fast cold starts). Off-ZeroGPU (local, tests, cpu-basic) this is a no-op, so importing the app never downloads the model. ``SPACES_ZERO_GPU`` is set only on ZeroGPU.""" if os.environ.get("SPACES_ZERO_GPU") == "1": _load_tokenizer() _load_model() def _estimate_duration( text: str, target_language_name: str, max_new_tokens: int = 512, num_beams: int = 1, temperature: float = 1.0, ) -> int: """Reserve GPU time scaled to the worst case: generation cost grows with the number of tokens generated and the beam width. Mirrors translate()'s signature (ZeroGPU calls the duration callable with the decorated function's args). Conservative and capped at 120s; calibrate from the perf_counter log in translate() (zerogpu.md 'Sizing duration').""" del text, target_language_name, temperature # only token/beam counts drive runtime return min(120, 30 + (max_new_tokens * num_beams) // 8) @spaces.GPU(duration=_estimate_duration) def translate( text: str, target_language_name: str, max_new_tokens: int = 512, num_beams: int = 1, temperature: float = 1.0, ) -> str: # No-op on empty/whitespace input: skip the model entirely rather than feeding a bare # "<2xx> " prompt (which would burn generation time and emit a stray token). Guard lives # here, not in _translate_with_loading, so the public /translate and Ctrl+Enter paths # (which call translate() directly) are covered too. Returns a str, so the contract holds. if not text.strip(): return "" tokenizer = _load_tokenizer() model = _load_model() device = model.device name_to_code, _ = _build_language_mappings() target_code = name_to_code.get(target_language_name) if target_code is None: raise ValueError(f"Unsupported language: {target_language_name}") if num_beams > 1 and temperature != 1.0: gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).") input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(device) generate_kwargs: dict = {"input_ids": input_ids, "max_new_tokens": max_new_tokens, "num_beams": num_beams} # Greedy by default (deterministic, higher-quality MT). Only sample when the user # explicitly sets a non-default temperature; beam search (num_beams > 1) ignores it. if num_beams == 1 and temperature != 1.0: generate_kwargs["do_sample"] = True generate_kwargs["temperature"] = temperature start = time.perf_counter() outputs = model.generate(**generate_kwargs) result = tokenizer.decode(outputs[0], skip_special_tokens=True) elapsed = time.perf_counter() - start print(f"[translate] max_new_tokens={max_new_tokens} num_beams={num_beams} took {elapsed:.1f}s") # @spaces.GPU return values cross the worker -> main process boundary via pickle. # Return the decoded str, never the raw `outputs` CUDA tensor: unpickling a CUDA # tensor in the main process triggers torch.cuda._lazy_init(), which ZeroGPU blocks # and hangs the call. CPU tensors, numpy arrays, and plain str/objects are safe. return result def _translate_with_loading( text: str, target_language_name: str, max_new_tokens: int = 512, num_beams: int = 1, temperature: float = 1.0, ) -> Generator[tuple[object, object], None, None]: yield gr.update(value="Translating...", interactive=False), gr.update() result = translate(text, target_language_name, int(max_new_tokens), int(num_beams), float(temperature)) # Flip the output box to RTL for right-to-left target scripts so the translation reads # correctly; reset to LTR otherwise (rtl is sticky across reruns). Only the button path # carries this — the public /translate endpoint stays a bare str to keep its API stable. name_to_code, _ = _build_language_mappings() is_rtl = name_to_code.get(target_language_name) in RTL_CODES output_update = gr.update(value=result, rtl=is_rtl, text_align="right" if is_rtl else "left") yield gr.update(value="Translate", interactive=True), output_update def _swap_languages( source_lang: str, target_lang: str, source_text: str, target_text: str ) -> tuple[str, str, str, str]: """Swap source/target languages and their text.""" return target_lang, source_lang, target_text, source_text def _build_demo() -> gr.Blocks: _, language_names = _build_language_mappings() with gr.Blocks(title="MADLAD-400 Translate") as demo: with gr.Row(): source_language = gr.Dropdown( choices=language_names, value="English (en)", show_label=False, filterable=True, # The model auto-detects the source from the text; this control only feeds the # swap button. Say so, so the symmetric layout doesn't imply it sets the source. info="Auto-detected from your text; used only for the swap button.", ) swap_btn = gr.Button("⇄", scale=0, min_width=60) target_language = gr.Dropdown( choices=language_names, value="French (fr)", show_label=False, filterable=True, # MADLAD-400 is a research model; the paper notes quality varies sharply by language. info="Quality varies by language; strongest for high-resource languages.", ) with gr.Row(equal_height=True): input_text = gr.Textbox( lines=6, max_length=2000, show_label=False, autofocus=True, info="Press Ctrl+Enter to translate.", ) output_text = gr.Textbox( placeholder="Translation", lines=6, show_label=False, interactive=False, buttons=["copy"], info="MADLAD-400 (google/madlad400-3b-mt) · arXiv:2309.04662 · Apache-2.0", ) translate_btn = gr.Button("Translate", variant="primary") # translate() already accepts these generation params; expose them here behind a # collapsed accordion (gr.Number, not gr.Slider — the UI is slider-free by design). # Defaults mirror translate()'s signature so the default surface stays greedy. with gr.Accordion("Advanced", open=False): max_new_tokens = gr.Number( value=512, label="Max new tokens", minimum=1, maximum=1024, precision=0, info="Caps the length of the translation.", ) num_beams = gr.Number( value=1, label="Beams", minimum=1, maximum=8, precision=0, info="Beam-search width; higher is slower and ignores temperature.", ) temperature = gr.Number( value=1.0, label="Temperature", minimum=0.1, maximum=2.0, step=0.1, info="No effect at 1.0 (greedy) or when Beams > 1; lower samples less randomly.", ) # UI-only handlers: kept off the public API surface (private) so only /translate is exposed. swap_btn.click( fn=_swap_languages, inputs=[source_language, target_language, input_text, output_text], outputs=[source_language, target_language, input_text, output_text], api_visibility="private", ) # Both translate handlers carry the advanced params: the button (via the loading-state # wrapper, which also applies the RTL output update) and the public /translate submit. translate_btn.click( fn=_translate_with_loading, inputs=[input_text, target_language, max_new_tokens, num_beams, temperature], outputs=[translate_btn, output_text], api_visibility="private", show_progress="minimal", ) # /translate exposes the advanced params too. They all have defaults, so existing # two-arg callers (text, target) keep working; wiring them here also makes Ctrl+Enter # honor the Advanced accordion, matching the Translate button. The endpoint returns a # bare str (RTL direction is a UI-only concern handled on the button path). input_text.submit( fn=translate, inputs=[input_text, target_language, max_new_tokens, num_beams, temperature], outputs=output_text, api_name="translate", show_progress="minimal", ) return demo demo = _build_demo() _maybe_eager_load() def main() -> None: demo.launch(theme=gr.themes.Ocean()) if __name__ == "__main__": main()