Spaces:
Running on Zero
Running on Zero
| """ | |
| Translation interface using the MADLAD-400 3B model. | |
| Translates between 418 languages from the MADLAD-400 paper. | |
| """ | |
| import math | |
| import os | |
| import time | |
| import warnings | |
| from collections.abc import Generator | |
| from functools import lru_cache | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, T5TokenizerFast | |
| from langmap.langid_mapping import langid_to_language | |
MODEL_NAME = "google/madlad400-3b-mt"

# Target tokens written right-to-left; the output box is flipped to RTL for these.
# The langmap's "region" field is geographic, not a script direction, and cannot stand in
# for one: "Middle East & North Africa" also contains LTR/Latin-script languages (Turkish
# <2tr>, Kurmanji <2ku>, Zaza <2zza>, Takestani <2tks>), while RTL Dhivehi <2dv> is filed
# under "South Asia". The RTL targets are therefore listed explicitly, as bare language tags
# that are wrapped into the model's "<2xx>" token form. Kurmanji <2ku> (Latin) is left out on
# purpose; of the Kurdish varieties only Sorani <2ckb> (Arabic script) is RTL.
RTL_CODES = frozenset(
    f"<2{tag}>"
    for tag in (
        "ar",  # Arabic
        "mey",  # Hassaniyya (Arabic script)
        "he",  # Hebrew
        "ks",  # Kashmiri (Perso-Arabic)
        "ckb",  # Kurdish, Sorani (Arabic script)
        "lrc",  # Northern Luri (Perso-Arabic)
        "luz",  # Southern Luri (Perso-Arabic)
        "bgp",  # Eastern Baluchi (Perso-Arabic)
        "ps",  # Pashto
        "fa",  # Persian
        "skr",  # Saraiki (Perso-Arabic)
        "sd",  # Sindhi (Perso-Arabic)
        "syr",  # Syriac
        "ur",  # Urdu
        "ug",  # Uyghur (Arabic script)
        "yi",  # Yiddish (Hebrew script)
        "dv",  # Dhivehi (Thaana)
    )
)
def _get_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise warn and return the CPU device."""
    if not torch.cuda.is_available():
        # stacklevel=2 attributes the warning to the caller rather than to this helper.
        warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
        return torch.device("cpu")
    return torch.device("cuda")
@lru_cache(maxsize=1)
def _load_tokenizer() -> T5TokenizerFast:
    """Load the fast tokenizer for ``MODEL_NAME`` and memoise it for the life of the process.

    This helper is called from several places (language-mapping construction, every
    translation, the eager-load hook); the cache makes every call after the first return the
    same tokenizer object instead of running ``from_pretrained`` again. A failed load raises
    and is not cached, so a later call retries.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.

    Raises:
        RuntimeError: if ``from_pretrained`` returns ``None``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if tokenizer is None:
        raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
    return tokenizer
@lru_cache(maxsize=1)
def _load_model() -> T5ForConditionalGeneration:
    """Load the MADLAD-400 model onto the selected device and memoise it.

    The cache keeps a single model instance per process: without it every call would run
    ``from_pretrained`` and move the full set of weights to the device again, and an eager
    load performed at import time would simply be thrown away. A failed load raises and is
    not cached, so a later call retries.

    Returns:
        The model, in bfloat16 on CUDA or float32 on CPU, already moved to that device.
    """
    device = _get_device()
    # T5/MADLAD was trained in bfloat16 and is numerically unstable in float16: its
    # activations overflow fp16's narrow range, risking inf/NaN (garbage) output. Use
    # bfloat16 on CUDA (same 2-byte footprint, native on the Blackwell backing GPU,
    # fp32-range exponent) and float32 on CPU.
    dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
    # MADLAD-400 is a T5 model: its relative-position-bias attention is eager-only
    # (transformers rejects sdpa / flash_attention_2 / flash_attention_3), so the usual
    # ZeroGPU attention speedups don't apply. AoTI targets fixed-shape forward passes
    # (e.g. diffusion), not T5's autoregressive .generate(), so it isn't wired up either.
    return AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=dtype).to(device)
@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the display-name -> target-token mapping and the ordered dropdown choices.

    Only language tokens present in the tokenizer's vocabulary are kept. The result is
    memoised because it is requested on every translate/swap event and is identical each
    time; callers share the cached objects and must treat them as read-only.

    Returns:
        ``(name_to_code, sorted_names)`` where ``name_to_code`` maps a display name such as
        ``"French (fr)"`` to its token (``"<2fr>"``) and ``sorted_names`` lists the display
        names ordered by langmap region, then alphabetically within each region.
    """
    vocab = _load_tokenizer().get_vocab()
    # code[2:-1] strips the "<2" prefix and ">" suffix: <2fr> -> fr
    name_to_code = {
        f"{info['name']} ({code[2:-1]})": code
        for code, info in langid_to_language.items()
        if code in vocab
    }
    # Sort by region, then alphabetically within each region
    sorted_names = sorted(
        name_to_code,
        key=lambda name: (langid_to_language[name_to_code[name]]["region"], name),
    )
    return name_to_code, sorted_names
def _maybe_eager_load() -> None:
    """Load the tokenizer and model up front, but only when running on ZeroGPU.

    On ZeroGPU the model is placed at module scope so the ``spaces`` hijack can pack the
    weights to disk at startup and stream them into each worker's VRAM (fast cold starts).
    Everywhere else (local runs, tests, cpu-basic) this does nothing, so importing the app
    never downloads the model. ``SPACES_ZERO_GPU`` is set only on ZeroGPU.
    """
    if os.environ.get("SPACES_ZERO_GPU") != "1":
        return
    _load_tokenizer()
    _load_model()
| # Cap the token×beam product so even the largest request fits the GPU time _estimate_duration | |
| # reserves. That estimate is 30 + product//8 (clamped at 120s), so a product <= 720 keeps it at | |
| # <= 120s, meaning generation can't outlive its reservation and be killed mid-decode. The token | |
| # count is trimmed to honour this (not the beam width, which drives translation quality). | |
| _MAX_TOKEN_BEAM_PRODUCT = 720 | |
| def _normalize_params( | |
| max_new_tokens: float | None, num_beams: float | None, temperature: float | None | |
| ) -> tuple[int, int, float]: | |
| """Coerce the advanced generation params to safe values. A cleared ``gr.Number`` arrives as | |
| ``None`` (Gradio skips its bounds check for ``None``) and the public ``/translate`` path | |
| passes values uncast, so this is the single funnel every caller — button, submit, API, | |
| direct, and the ZeroGPU duration callable — goes through. ``None``/``NaN`` fall back to the | |
| defaults; values are clamped to the ranges the Advanced ``gr.Number`` controls advertise; and | |
| the token×beam product is capped (``_MAX_TOKEN_BEAM_PRODUCT``) so generation stays within the | |
| GPU time it reserved.""" | |
| def _num(value: float | None, default: float) -> float: | |
| return default if value is None or math.isnan(value) else value | |
| beams = int(max(1, min(8, _num(num_beams, 1)))) | |
| tokens = int(max(1, min(1024, _num(max_new_tokens, 512)))) | |
| tokens = min(tokens, _MAX_TOKEN_BEAM_PRODUCT // beams) | |
| temperature = float(max(0.1, min(2.0, _num(temperature, 1.0)))) | |
| return tokens, beams, temperature | |
def _estimate_duration(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> int:
    """Return the number of GPU seconds to reserve for one translate() call.

    The estimate covers the worst case: generation cost grows with the number of tokens
    generated and with the beam width. The signature mirrors translate() because ZeroGPU
    calls the duration callable with the decorated function's args, and does so *before*
    translate() runs — so the same cleared-field ``None`` values can arrive here and are
    normalized first. ``text`` and ``target_language_name`` are accepted only for that
    signature match; they do not influence the result. The estimate is conservative and
    capped at 120s; calibrate it from the perf_counter log in translate() (zerogpu.md
    'Sizing duration').
    """
    tokens, beams, _ = _normalize_params(max_new_tokens, num_beams, temperature)
    # 30s base plus one second per 8 token-beam units, never more than 120s.
    seconds = 30 + (tokens * beams) // 8
    return seconds if seconds < 120 else 120
@spaces.GPU(duration=_estimate_duration)
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> str:
    """Translate ``text`` into the language selected by ``target_language_name``.

    Decorated with ``spaces.GPU`` so that, on ZeroGPU, the call runs in a GPU worker whose
    time budget comes from ``_estimate_duration`` (called with these same arguments).

    Args:
        text: source text; ``None``, empty or whitespace-only input returns ``""`` without
            touching the model.
        target_language_name: a display name from the language mappings, e.g. ``"French (fr)"``.
        max_new_tokens: generation length cap (normalized by ``_normalize_params``).
        num_beams: beam-search width (normalized by ``_normalize_params``).
        temperature: sampling temperature; only used when ``num_beams == 1`` and the value
            differs from 1.0.

    Returns:
        The decoded translation as a plain ``str``.

    Raises:
        ValueError: if ``target_language_name`` is not a known display name.
    """
    # No-op on empty/whitespace input: skip the model entirely rather than feeding a bare
    # "<2xx> " prompt (which would burn generation time and emit a stray token). Guard lives
    # here, not in _translate_with_loading, so the public /translate and Ctrl+Enter paths
    # (which call translate() directly) are covered too. (text or "") stays None-safe for an
    # API caller that POSTs a null text field. Returns a str, so the contract holds.
    if not (text or "").strip():
        return ""

    # Normalize the generation params here — translate() is the single source of truth. The
    # public submit path passes gr.Number values uncast, and a cleared field arrives as
    # None/NaN, so coerce and clamp before use (the duration callable normalizes identically).
    max_new_tokens, num_beams, temperature = _normalize_params(max_new_tokens, num_beams, temperature)
    # Compare with a tolerance so float spinner drift (e.g. 0.1*9 = 0.999…) doesn't trip sampling.
    sampling = abs(temperature - 1.0) > 1e-6

    # Resolve and validate the target before loading the model, so an unsupported language
    # fails fast instead of paying for a model load first.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    tokenizer = _load_tokenizer()
    model = _load_model()
    device = model.device

    if num_beams > 1 and sampling:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    # MADLAD selects the output language with a "<2xx>" token prefixed to the source text.
    input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(device)
    generate_kwargs: dict = {"input_ids": input_ids, "max_new_tokens": max_new_tokens, "num_beams": num_beams}
    # Greedy by default (deterministic, higher-quality MT). Only sample when the user
    # explicitly sets a non-default temperature; beam search (num_beams > 1) ignores it.
    if num_beams == 1 and sampling:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature

    start = time.perf_counter()
    outputs = model.generate(**generate_kwargs)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    elapsed = time.perf_counter() - start
    print(f"[translate] max_new_tokens={max_new_tokens} num_beams={num_beams} took {elapsed:.1f}s")

    # @spaces.GPU return values cross the worker -> main process boundary via pickle.
    # Return the decoded str, never the raw `outputs` CUDA tensor: unpickling a CUDA
    # tensor in the main process triggers torch.cuda._lazy_init(), which ZeroGPU blocks
    # and hangs the call. CPU tensors, numpy arrays, and plain str/objects are safe.
    return result
def _translate_with_loading(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> Generator[tuple[object, object], None, None]:
    """Translate-button handler: show a loading state, translate, then publish the result.

    Yields ``(button_update, output_update)`` pairs: first the button is relabelled
    "Translating..." and disabled, then it is restored alongside the translation. If
    ``translate()`` raises, the button is restored before the exception propagates, so a
    failed request does not leave the button disabled.

    Args:
        text: source text from the input box.
        target_language_name: display name of the target language.
        max_new_tokens: forwarded unchanged to ``translate()``.
        num_beams: forwarded unchanged to ``translate()``.
        temperature: forwarded unchanged to ``translate()``.

    Raises:
        Exception: whatever ``translate()`` raised, re-raised after the button is restored.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        # translate() normalizes the params (None/NaN/clamp), so forward them as-is.
        result = translate(text, target_language_name, max_new_tokens, num_beams, temperature)
    except Exception:
        # Undo the loading state (leaving the output box untouched), then let the error
        # surface to the caller exactly as before.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    # Flip the output box to RTL for right-to-left target scripts so the translation reads
    # correctly; reset to LTR otherwise (rtl is sticky across reruns). Only the button path
    # carries this — the public /translate endpoint stays a bare str to keep its API stable.
    name_to_code, _ = _build_language_mappings()
    is_rtl = name_to_code.get(target_language_name) in RTL_CODES
    output_update = gr.update(value=result, rtl=is_rtl, text_align="right" if is_rtl else "left")
    yield gr.update(value="Translate", interactive=True), output_update
def _swap_languages(
    source_lang: str, target_lang: str, source_text: str, target_text: str
) -> tuple[str, str, object, object]:
    """Exchange the two languages and the two texts, re-deriving each textbox's direction.

    After the swap the input box shows the former target text and the output box shows the
    former source text, so each box takes its direction from the language of the text that
    now lands in it. Both boxes are always given an explicit direction because rtl is sticky
    across reruns: a stale RTL flip left by an earlier translation has to be reset.

    Returns:
        ``(new_source_lang, new_target_lang, input_box_update, output_box_update)``.
    """
    name_to_code, _ = _build_language_mappings()

    def _box_update(box_text: str, language_name: str) -> object:
        # Direction follows the language the text was written in.
        rtl = name_to_code.get(language_name) in RTL_CODES
        return gr.update(value=box_text, rtl=rtl, text_align="right" if rtl else "left")

    return (
        target_lang,
        source_lang,
        _box_update(target_text, target_lang),  # old target text moves into the input box
        _box_update(source_text, source_lang),  # old source text moves into the output box
    )
def _build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire up its three event handlers.

    Layout, top to bottom: a row with the source dropdown, swap button and target dropdown;
    a row with the input and output textboxes; the Translate button; a collapsed "Advanced"
    accordion holding the generation params. Only the input box's submit event is published
    on the API (as ``/translate``); the swap and button handlers are private.
    """
    _, language_names = _build_language_mappings()
    # Settings the two language pickers have in common.
    dropdown_kwargs = {"choices": language_names, "show_label": False, "filterable": True}

    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row():
            source_language = gr.Dropdown(
                value="English (en)",
                # The model auto-detects the source from the text; this control only feeds the
                # swap button. Say so, so the symmetric layout doesn't imply it sets the source.
                info="Auto-detected from your text; used only for the swap button.",
                **dropdown_kwargs,
            )
            swap_btn = gr.Button("⇄", scale=0, min_width=60)
            target_language = gr.Dropdown(
                value="French (fr)",
                # MADLAD-400 is a research model; the paper notes quality varies sharply by language.
                info="Quality varies by language; strongest for high-resource languages.",
                **dropdown_kwargs,
            )

        with gr.Row(equal_height=True):
            input_text = gr.Textbox(
                lines=6,
                max_length=2000,
                show_label=False,
                autofocus=True,
                info="Press Ctrl+Enter to translate.",
            )
            output_text = gr.Textbox(
                placeholder="Translation",
                lines=6,
                show_label=False,
                interactive=False,
                buttons=["copy"],
                info="MADLAD-400 (google/madlad400-3b-mt) · arXiv:2309.04662 · Apache-2.0",
            )

        translate_btn = gr.Button("Translate", variant="primary")

        # translate() already accepts these generation params; they are surfaced behind a
        # collapsed accordion (gr.Number, not gr.Slider — the UI is slider-free by design).
        # The defaults mirror translate()'s signature so the default surface stays greedy.
        with gr.Accordion("Advanced", open=False):
            max_new_tokens = gr.Number(
                value=512,
                label="Max new tokens",
                minimum=1,
                maximum=1024,
                precision=0,
                info="Caps the length of the translation.",
            )
            num_beams = gr.Number(
                value=1,
                label="Beams",
                minimum=1,
                maximum=8,
                precision=0,
                info="Beam-search width; higher is slower and ignores temperature.",
            )
            temperature = gr.Number(
                value=1.0,
                label="Temperature",
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                info="No effect at 1.0 (greedy) or when Beams > 1; below 1.0 is more focused, above 1.0 more random.",
            )

        # Component groups shared by the handlers below; each handler gets its own list copy.
        swap_fields = (source_language, target_language, input_text, output_text)
        translate_inputs = (input_text, target_language, max_new_tokens, num_beams, temperature)

        # UI-only handlers: kept off the public API surface (private) so only /translate is exposed.
        swap_btn.click(
            fn=_swap_languages,
            inputs=[*swap_fields],
            outputs=[*swap_fields],
            api_visibility="private",
        )

        # Both translate handlers carry the advanced params: the button (via the loading-state
        # wrapper, which also applies the RTL output update) and the public /translate submit.
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=[*translate_inputs],
            outputs=[translate_btn, output_text],
            api_visibility="private",
            show_progress="minimal",
        )

        # /translate exposes the advanced params too. They all have defaults, so existing
        # two-arg callers (text, target) keep working; wiring them here also makes Ctrl+Enter
        # honor the Advanced accordion, matching the Translate button. The endpoint returns a
        # bare str, so an RTL target submitted via Ctrl+Enter is NOT direction-flipped — that
        # happens only on the Translate-button path (an accepted, documented UI divergence).
        input_text.submit(
            fn=translate,
            inputs=[*translate_inputs],
            outputs=output_text,
            api_name="translate",
            show_progress="minimal",
        )

    return demo
# Build the UI at import time so ``demo`` exists as a module-level name (main() launches it).
demo = _build_demo()
# Runs after the UI is built; does nothing unless SPACES_ZERO_GPU == "1".
_maybe_eager_load()
def main() -> None:
    """Launch the module-level ``demo`` app using Gradio's Ocean theme."""
    ocean_theme = gr.themes.Ocean()
    demo.launch(theme=ocean_theme)


# Script entry point: launch only when executed directly, not on import.
if __name__ == "__main__":
    main()