""" Translation interface using the MADLAD-400 3B model. Translates between 418 languages from the MADLAD-400 paper. """ import math import os import time import warnings from collections.abc import Generator from functools import lru_cache import gradio as gr import spaces import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, T5TokenizerFast from langmap.langid_mapping import langid_to_language MODEL_NAME = "google/madlad400-3b-mt" # Target tokens whose script is right-to-left, used to flip the output box's direction. # The langmap stores a geographic "region", not script direction, and region is NOT a usable # proxy: the "Middle East & North Africa" region also holds LTR/Latin languages (Turkish <2tr>, # Kurmanji <2ku>, Zaza <2zza>, Takestani <2tks>) while RTL Dhivehi <2dv> sits in "South Asia". # So enumerate the RTL tokens explicitly. Kurmanji <2ku> (Latin) is excluded; only Sorani # Kurdish <2ckb> (Arabic script) is RTL. RTL_CODES = frozenset( { "<2ar>", # Arabic "<2mey>", # Hassaniyya (Arabic script) "<2he>", # Hebrew "<2ks>", # Kashmiri (Perso-Arabic) "<2ckb>", # Kurdish, Sorani (Arabic script) "<2lrc>", # Northern Luri (Perso-Arabic) "<2luz>", # Southern Luri (Perso-Arabic) "<2bgp>", # Eastern Baluchi (Perso-Arabic) "<2ps>", # Pashto "<2fa>", # Persian "<2skr>", # Saraiki (Perso-Arabic) "<2sd>", # Sindhi (Perso-Arabic) "<2syr>", # Syriac "<2ur>", # Urdu "<2ug>", # Uyghur (Arabic script) "<2yi>", # Yiddish (Hebrew script) "<2dv>", # Dhivehi (Thaana) } ) def _get_device() -> torch.device: if torch.cuda.is_available(): return torch.device("cuda") warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2) return torch.device("cpu") @lru_cache(maxsize=1) def _load_tokenizer() -> T5TokenizerFast: tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True) if tokenizer is None: raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}") return tokenizer @lru_cache(maxsize=1) def _load_model() -> T5ForConditionalGeneration: device = _get_device() # T5/MADLAD was trained in bfloat16 and is numerically unstable in float16: its # activations overflow fp16's narrow range, risking inf/NaN (garbage) output. Use # bfloat16 on CUDA (same 2-byte footprint, native on the Blackwell backing GPU, # fp32-range exponent) and float32 on CPU. dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 # MADLAD-400 is a T5 model: its relative-position-bias attention is eager-only # (transformers rejects sdpa / flash_attention_2 / flash_attention_3), so the usual # ZeroGPU attention speedups don't apply. AoTI targets fixed-shape forward passes # (e.g. diffusion), not T5's autoregressive .generate(), so it isn't wired up either. return AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=dtype).to(device) @lru_cache(maxsize=1) def _build_language_mappings() -> tuple[dict[str, str], list[str]]: tokenizer = _load_tokenizer() vocab = tokenizer.get_vocab() name_to_code: dict[str, str] = {} for code, info in langid_to_language.items(): if code in vocab: locale = code[2:-1] # <2fr> → fr display_name = f"{info['name']} ({locale})" name_to_code[display_name] = code # Sort by region, then alphabetically within each region sorted_names = sorted( name_to_code.keys(), key=lambda n: (langid_to_language[name_to_code[n]]["region"], n), ) return name_to_code, sorted_names def _maybe_eager_load() -> None: """On ZeroGPU, place the model at module scope so the ``spaces`` hijack can pack weights to disk at startup and stream them into each worker's VRAM (fast cold starts). Off-ZeroGPU (local, tests, cpu-basic) this is a no-op, so importing the app never downloads the model. ``SPACES_ZERO_GPU`` is set only on ZeroGPU.""" if os.environ.get("SPACES_ZERO_GPU") == "1": _load_tokenizer() _load_model() # Cap the token×beam product so even the largest request fits the GPU time _estimate_duration # reserves. That estimate is 30 + product//8 (clamped at 120s), so a product <= 720 keeps it at # <= 120s, meaning generation can't outlive its reservation and be killed mid-decode. The token # count is trimmed to honour this (not the beam width, which drives translation quality). _MAX_TOKEN_BEAM_PRODUCT = 720 def _normalize_params( max_new_tokens: float | None, num_beams: float | None, temperature: float | None ) -> tuple[int, int, float]: """Coerce the advanced generation params to safe values. A cleared ``gr.Number`` arrives as ``None`` (Gradio skips its bounds check for ``None``) and the public ``/translate`` path passes values uncast, so this is the single funnel every caller — button, submit, API, direct, and the ZeroGPU duration callable — goes through. ``None``/``NaN`` fall back to the defaults; values are clamped to the ranges the Advanced ``gr.Number`` controls advertise; and the token×beam product is capped (``_MAX_TOKEN_BEAM_PRODUCT``) so generation stays within the GPU time it reserved.""" def _num(value: float | None, default: float) -> float: return default if value is None or math.isnan(value) else value beams = int(max(1, min(8, _num(num_beams, 1)))) tokens = int(max(1, min(1024, _num(max_new_tokens, 512)))) tokens = min(tokens, _MAX_TOKEN_BEAM_PRODUCT // beams) temperature = float(max(0.1, min(2.0, _num(temperature, 1.0)))) return tokens, beams, temperature def _estimate_duration( text: str, target_language_name: str, max_new_tokens: int | None = 512, num_beams: int | None = 1, temperature: float | None = 1.0, ) -> int: """Reserve GPU time scaled to the worst case: generation cost grows with the number of tokens generated and the beam width. Mirrors translate()'s signature (ZeroGPU calls the duration callable with the decorated function's args, and runs it *before* translate(), so it must tolerate the same cleared-field ``None`` values — normalize first). Conservative and capped at 120s; calibrate from the perf_counter log in translate() (zerogpu.md 'Sizing duration').""" del text, target_language_name # only token/beam counts drive runtime max_new_tokens, num_beams, _ = _normalize_params(max_new_tokens, num_beams, temperature) return min(120, 30 + (max_new_tokens * num_beams) // 8) @spaces.GPU(duration=_estimate_duration) def translate( text: str, target_language_name: str, max_new_tokens: int | None = 512, num_beams: int | None = 1, temperature: float | None = 1.0, ) -> str: # No-op on empty/whitespace input: skip the model entirely rather than feeding a bare # "<2xx> " prompt (which would burn generation time and emit a stray token). Guard lives # here, not in _translate_with_loading, so the public /translate and Ctrl+Enter paths # (which call translate() directly) are covered too. (text or "") stays None-safe for an # API caller that POSTs a null text field. Returns a str, so the contract holds. if not (text or "").strip(): return "" # Normalize the generation params here — translate() is the single source of truth. The # public submit path passes gr.Number values uncast, and a cleared field arrives as # None/NaN, so coerce and clamp before use (the duration callable normalizes identically). max_new_tokens, num_beams, temperature = _normalize_params(max_new_tokens, num_beams, temperature) # Compare with a tolerance so float spinner drift (e.g. 0.1*9 = 0.999…) doesn't trip sampling. sampling = abs(temperature - 1.0) > 1e-6 tokenizer = _load_tokenizer() model = _load_model() device = model.device name_to_code, _ = _build_language_mappings() target_code = name_to_code.get(target_language_name) if target_code is None: raise ValueError(f"Unsupported language: {target_language_name}") if num_beams > 1 and sampling: gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).") input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(device) generate_kwargs: dict = {"input_ids": input_ids, "max_new_tokens": max_new_tokens, "num_beams": num_beams} # Greedy by default (deterministic, higher-quality MT). Only sample when the user # explicitly sets a non-default temperature; beam search (num_beams > 1) ignores it. if num_beams == 1 and sampling: generate_kwargs["do_sample"] = True generate_kwargs["temperature"] = temperature start = time.perf_counter() outputs = model.generate(**generate_kwargs) result = tokenizer.decode(outputs[0], skip_special_tokens=True) elapsed = time.perf_counter() - start print(f"[translate] max_new_tokens={max_new_tokens} num_beams={num_beams} took {elapsed:.1f}s") # @spaces.GPU return values cross the worker -> main process boundary via pickle. # Return the decoded str, never the raw `outputs` CUDA tensor: unpickling a CUDA # tensor in the main process triggers torch.cuda._lazy_init(), which ZeroGPU blocks # and hangs the call. CPU tensors, numpy arrays, and plain str/objects are safe. return result def _translate_with_loading( text: str, target_language_name: str, max_new_tokens: int | None = 512, num_beams: int | None = 1, temperature: float | None = 1.0, ) -> Generator[tuple[object, object], None, None]: yield gr.update(value="Translating...", interactive=False), gr.update() # translate() normalizes the params (None/NaN/clamp), so forward them as-is. result = translate(text, target_language_name, max_new_tokens, num_beams, temperature) # Flip the output box to RTL for right-to-left target scripts so the translation reads # correctly; reset to LTR otherwise (rtl is sticky across reruns). Only the button path # carries this — the public /translate endpoint stays a bare str to keep its API stable. name_to_code, _ = _build_language_mappings() is_rtl = name_to_code.get(target_language_name) in RTL_CODES output_update = gr.update(value=result, rtl=is_rtl, text_align="right" if is_rtl else "left") yield gr.update(value="Translate", interactive=True), output_update def _swap_languages( source_lang: str, target_lang: str, source_text: str, target_text: str ) -> tuple[str, str, object, object]: """Swap source/target languages and their text, flipping each textbox's direction to follow the text that lands in it. rtl is sticky across reruns, so a stale RTL flip left by a prior translation must be reset. After the swap the input box holds the old target text and the output box holds the old source text.""" name_to_code, _ = _build_language_mappings() input_rtl = name_to_code.get(target_lang) in RTL_CODES # old target text moves into the input box output_rtl = name_to_code.get(source_lang) in RTL_CODES # old source text moves into the output box return ( target_lang, source_lang, gr.update(value=target_text, rtl=input_rtl, text_align="right" if input_rtl else "left"), gr.update(value=source_text, rtl=output_rtl, text_align="right" if output_rtl else "left"), ) def _build_demo() -> gr.Blocks: _, language_names = _build_language_mappings() with gr.Blocks(title="MADLAD-400 Translate") as demo: with gr.Row(): source_language = gr.Dropdown( choices=language_names, value="English (en)", show_label=False, filterable=True, # The model auto-detects the source from the text; this control only feeds the # swap button. Say so, so the symmetric layout doesn't imply it sets the source. info="Auto-detected from your text; used only for the swap button.", ) swap_btn = gr.Button("⇄", scale=0, min_width=60) target_language = gr.Dropdown( choices=language_names, value="French (fr)", show_label=False, filterable=True, # MADLAD-400 is a research model; the paper notes quality varies sharply by language. info="Quality varies by language; strongest for high-resource languages.", ) with gr.Row(equal_height=True): input_text = gr.Textbox( lines=6, max_length=2000, show_label=False, autofocus=True, info="Press Ctrl+Enter to translate.", ) output_text = gr.Textbox( placeholder="Translation", lines=6, show_label=False, interactive=False, buttons=["copy"], info="MADLAD-400 (google/madlad400-3b-mt) · arXiv:2309.04662 · Apache-2.0", ) translate_btn = gr.Button("Translate", variant="primary") # translate() already accepts these generation params; expose them here behind a # collapsed accordion (gr.Number, not gr.Slider — the UI is slider-free by design). # Defaults mirror translate()'s signature so the default surface stays greedy. with gr.Accordion("Advanced", open=False): max_new_tokens = gr.Number( value=512, label="Max new tokens", minimum=1, maximum=1024, precision=0, info="Caps the length of the translation.", ) num_beams = gr.Number( value=1, label="Beams", minimum=1, maximum=8, precision=0, info="Beam-search width; higher is slower and ignores temperature.", ) temperature = gr.Number( value=1.0, label="Temperature", minimum=0.1, maximum=2.0, step=0.1, info="No effect at 1.0 (greedy) or when Beams > 1; below 1.0 is more focused, above 1.0 more random.", ) # UI-only handlers: kept off the public API surface (private) so only /translate is exposed. swap_btn.click( fn=_swap_languages, inputs=[source_language, target_language, input_text, output_text], outputs=[source_language, target_language, input_text, output_text], api_visibility="private", ) # Both translate handlers carry the advanced params: the button (via the loading-state # wrapper, which also applies the RTL output update) and the public /translate submit. translate_btn.click( fn=_translate_with_loading, inputs=[input_text, target_language, max_new_tokens, num_beams, temperature], outputs=[translate_btn, output_text], api_visibility="private", show_progress="minimal", ) # /translate exposes the advanced params too. They all have defaults, so existing # two-arg callers (text, target) keep working; wiring them here also makes Ctrl+Enter # honor the Advanced accordion, matching the Translate button. The endpoint returns a # bare str, so an RTL target submitted via Ctrl+Enter is NOT direction-flipped — that # happens only on the Translate-button path (an accepted, documented UI divergence). input_text.submit( fn=translate, inputs=[input_text, target_language, max_new_tokens, num_beams, temperature], outputs=output_text, api_name="translate", show_progress="minimal", ) return demo demo = _build_demo() _maybe_eager_load() def main() -> None: demo.launch(theme=gr.themes.Ocean()) if __name__ == "__main__": main()