Daryl Lim
fix: cap token×beam product so generation fits its GPU reservation
e4538dd
Raw
History Blame
16.2 kB
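"""
Translation interface using the MADLAD-400 3B model (google/madlad400-3b-mt).
Translates text into any target language that has both a langmap entry and a `<2xx>`
language token in the MADLAD-400 tokenizer; the source language is inferred from the input.
"""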
import math
import os
import time
import warnings
from collections.abc import Generator
from functools import lru_cache
import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, T5TokenizerFast
from langmap.langid_mapping import langid_to_language
MODEL_NAME = "google/madlad400-3b-mt"

# Language tokens whose script is written right-to-left; used to flip a textbox's direction.
# Listed explicitly because the langmap only records a geographic "region", which does not
# track script direction: "Middle East & North Africa" also contains Latin-script languages
# (Turkish <2tr>, Kurmanji <2ku>, Zaza <2zza>, Takestani <2tks>), while Thaana-script
# Dhivehi <2dv> is filed under "South Asia". Of the Kurdish variants only Sorani <2ckb>
# (Arabic script) is RTL; Kurmanji <2ku> is Latin and deliberately absent.
RTL_CODES = frozenset(
    (
        # Arabic / Perso-Arabic script
        "<2ar>",  # Arabic
        "<2mey>",  # Hassaniyya
        "<2ks>",  # Kashmiri
        "<2ckb>",  # Kurdish, Sorani
        "<2lrc>",  # Northern Luri
        "<2luz>",  # Southern Luri
        "<2bgp>",  # Eastern Baluchi
        "<2ps>",  # Pashto
        "<2fa>",  # Persian
        "<2skr>",  # Saraiki
        "<2sd>",  # Sindhi
        "<2ur>",  # Urdu
        "<2ug>",  # Uyghur
        # Hebrew script
        "<2he>",  # Hebrew
        "<2yi>",  # Yiddish
        # Syriac script
        "<2syr>",  # Syriac
        # Thaana script
        "<2dv>",  # Dhivehi
    )
)
def _get_device() -> torch.device:
    """Return the CUDA device when one is available, else the CPU device.

    Falling back to CPU emits a ``warnings.warn`` (attributed to the caller via
    ``stacklevel=2``) because translation there is much slower.
    """
    if not torch.cuda.is_available():
        warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
        return torch.device("cpu")
    return torch.device("cuda")
@lru_cache(maxsize=1)
def _load_tokenizer() -> T5TokenizerFast:
    """Load the fast tokenizer for ``MODEL_NAME`` once and reuse it on later calls.

    Raises:
        RuntimeError: If ``AutoTokenizer.from_pretrained`` hands back ``None``.
    """
    loaded = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
@lru_cache(maxsize=1)
def _load_model() -> T5ForConditionalGeneration:
    """Load the MADLAD-400 seq2seq model once, placed on the device ``_get_device`` picks.

    Weights are bfloat16 on CUDA and float32 on CPU. Later calls return the cached model.
    """
    target = _get_device()
    on_cuda = target.type == "cuda"
    # T5/MADLAD was trained in bfloat16 and is numerically unstable in float16: its
    # activations overflow fp16's narrow range, risking inf/NaN (garbage) output. bfloat16
    # keeps the same 2-byte footprint with an fp32-range exponent (native on the Blackwell
    # backing GPU); the CPU path stays in float32.
    weights_dtype = torch.bfloat16 if on_cuda else torch.float32
    # MADLAD-400 is a T5 model: its relative-position-bias attention is eager-only
    # (transformers rejects sdpa / flash_attention_2 / flash_attention_3), so the usual
    # ZeroGPU attention speedups don't apply. AoTI targets fixed-shape forward passes
    # (e.g. diffusion), not T5's autoregressive .generate(), so it isn't wired up either.
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=weights_dtype)
    return model.to(target)
@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Map UI display names to MADLAD ``<2xx>`` language tokens (computed once, cached).

    Only langmap codes that exist in the tokenizer vocabulary are kept. Each display name is
    ``"<name> (<locale>)"``, where the locale is the code with its leading ``<2`` and
    trailing ``>`` removed (``<2fr>`` -> ``fr``).

    Returns:
        ``(name_to_code, sorted_names)``: the display-name -> token mapping, and the display
        names ordered by langmap region, then alphabetically within a region.
    """
    vocab = _load_tokenizer().get_vocab()
    name_to_code: dict[str, str] = {
        f"{info['name']} ({code[2:-1]})": code
        for code, info in langid_to_language.items()
        if code in vocab
    }

    def _region_then_name(display: str) -> tuple[str, str]:
        return langid_to_language[name_to_code[display]]["region"], display

    sorted_names = sorted(name_to_code, key=_region_then_name)
    return name_to_code, sorted_names
def _maybe_eager_load() -> None:
    """Load the tokenizer and model at module scope when running on ZeroGPU.

    ZeroGPU's ``spaces`` hijack can then pack the weights to disk at startup and stream
    them into each worker's VRAM, which gives fast cold starts. ``SPACES_ZERO_GPU`` is set
    (to ``"1"``) only on ZeroGPU. Everywhere else (local runs, tests, cpu-basic) this
    returns immediately, so importing the app never pulls the model.
    """
    if os.environ.get("SPACES_ZERO_GPU") != "1":
        return
    _load_tokenizer()
    _load_model()
# Cap the token×beam product so even the largest request fits the GPU time _estimate_duration
# reserves. That estimate is 30 + product//8 (clamped at 120s), so a product <= 720 keeps it at
# <= 120s, meaning generation can't outlive its reservation and be killed mid-decode. The token
# count is trimmed to honour this (not the beam width, which drives translation quality).
_MAX_TOKEN_BEAM_PRODUCT = 720
def _normalize_params(
max_new_tokens: float | None, num_beams: float | None, temperature: float | None
) -> tuple[int, int, float]:
"""Coerce the advanced generation params to safe values. A cleared ``gr.Number`` arrives as
``None`` (Gradio skips its bounds check for ``None``) and the public ``/translate`` path
passes values uncast, so this is the single funnel every caller — button, submit, API,
direct, and the ZeroGPU duration callable — goes through. ``None``/``NaN`` fall back to the
defaults; values are clamped to the ranges the Advanced ``gr.Number`` controls advertise; and
the token×beam product is capped (``_MAX_TOKEN_BEAM_PRODUCT``) so generation stays within the
GPU time it reserved."""
def _num(value: float | None, default: float) -> float:
return default if value is None or math.isnan(value) else value
beams = int(max(1, min(8, _num(num_beams, 1))))
tokens = int(max(1, min(1024, _num(max_new_tokens, 512))))
tokens = min(tokens, _MAX_TOKEN_BEAM_PRODUCT // beams)
temperature = float(max(0.1, min(2.0, _num(temperature, 1.0))))
return tokens, beams, temperature
def _estimate_duration(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> int:
    """Seconds of GPU time to reserve for one ``translate()`` call: ``min(120, 30 + tokens*beams // 8)``.

    ZeroGPU invokes this duration callable with ``translate()``'s arguments *before*
    ``translate()`` runs. It therefore mirrors that signature and must tolerate the same
    cleared-field ``None``/NaN values, so the params go through ``_normalize_params`` first.
    The estimate is conservative, and its constants should be calibrated from the
    perf_counter log that ``translate()`` prints (zerogpu.md 'Sizing duration').
    """
    del text, target_language_name  # runtime is driven only by the token/beam counts
    tokens, beams, _unused_temperature = _normalize_params(max_new_tokens, num_beams, temperature)
    work = tokens * beams
    return min(120, 30 + work // 8)
@spaces.GPU(duration=_estimate_duration)
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> str:
    """Translate ``text`` into ``target_language_name`` and return the decoded translation.

    This is the public ``/translate`` endpoint and the single source of truth for generation.
    It normalizes the Advanced params itself, decodes greedily by default, and samples only
    when beams == 1 and the temperature differs from 1.0. With beams > 1 it runs beam search,
    ignores the temperature, and shows a ``gr.Info`` notice.

    Args:
        text: Source text. Empty, whitespace-only, or ``None`` returns ``""`` without
            touching the model.
        target_language_name: A display name from ``_build_language_mappings``
            (e.g. ``"French (fr)"``).
        max_new_tokens: Generation length cap; ``None``/NaN fall back to the default and
            out-of-range values are clamped.
        num_beams: Beam-search width; normalized the same way.
        temperature: Sampling temperature; normalized the same way.

    Returns:
        The decoded translation as a plain ``str``.

    Raises:
        ValueError: If ``target_language_name`` is not a known display name.
    """
    # Skip the model for blank input instead of feeding a bare "<2xx> " prompt, which would
    # burn generation time and emit a stray token. The guard lives here, not in the button
    # wrapper, so the public /translate and Ctrl+Enter paths are covered too. ``text or ""``
    # keeps an API caller that POSTs a null text field from crashing.
    if not (text or "").strip():
        return ""

    # Submit/API callers pass gr.Number values uncast, possibly None/NaN or out of range.
    # The duration callable normalizes identically, so the reservation matches the real work.
    max_new_tokens, num_beams, temperature = _normalize_params(max_new_tokens, num_beams, temperature)
    beam_search = num_beams > 1
    # Tolerance so float spinner drift (e.g. 0.1*9 = 0.999…) doesn't switch sampling on.
    non_default_temperature = abs(temperature - 1.0) > 1e-6

    tokenizer = _load_tokenizer()
    model = _load_model()
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")
    if beam_search and non_default_temperature:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    # The target language is selected by prefixing its "<2xx>" token to the input.
    prompt = f"{target_code} {text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    generate_kwargs: dict = {"input_ids": input_ids, "max_new_tokens": max_new_tokens, "num_beams": num_beams}
    if non_default_temperature and not beam_search:
        # Greedy is the default (deterministic, higher-quality MT); sample only on request.
        generate_kwargs.update(do_sample=True, temperature=temperature)

    started = time.perf_counter()
    generated = model.generate(**generate_kwargs)
    result = tokenizer.decode(generated[0], skip_special_tokens=True)
    elapsed = time.perf_counter() - started
    # This timing line is the calibration source for _estimate_duration.
    print(f"[translate] max_new_tokens={max_new_tokens} num_beams={num_beams} took {elapsed:.1f}s")

    # @spaces.GPU return values cross the worker -> main process boundary via pickle.
    # Return the decoded str, never the raw CUDA tensor: unpickling a CUDA tensor in the
    # main process triggers torch.cuda._lazy_init(), which ZeroGPU blocks, hanging the call.
    # CPU tensors, numpy arrays, and plain str/objects are safe.
    return result
def _translate_with_loading(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> Generator[tuple[object, object], None, None]:
    """Translate-button handler that shows a disabled loading state while ``translate()`` runs.

    Yields ``(button_update, output_update)`` pairs:
      1. The button switches to a disabled "Translating..." state; the output is untouched.
      2. The button is restored to an enabled "Translate". The output receives the
         translation, with its direction flipped to RTL for right-to-left target scripts and
         reset to LTR otherwise (rtl is sticky across reruns). Only this button path applies
         the flip; the public /translate endpoint stays a bare str to keep its API stable.

    Raises:
        Whatever ``translate()`` raises (e.g. ``ValueError`` for an unknown target language,
        or a ZeroGPU ``gr.Error``). The button is re-enabled first, so a failure no longer
        leaves the UI stuck on a disabled "Translating..." button.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        # translate() normalizes the params (None/NaN/clamp), so forward them as-is.
        result = translate(text, target_language_name, max_new_tokens, num_beams, temperature)
    except Exception:
        # Without this yield the generator would stop after the loading state, and the
        # button would stay disabled showing "Translating..." for the rest of the session.
        # Restore it (leave the output as it was), then re-raise so Gradio still reports
        # the error.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    name_to_code, _ = _build_language_mappings()
    is_rtl = name_to_code.get(target_language_name) in RTL_CODES
    output_update = gr.update(value=result, rtl=is_rtl, text_align="right" if is_rtl else "left")
    yield gr.update(value="Translate", interactive=True), output_update
def _swap_languages(
    source_lang: str, target_lang: str, source_text: str, target_text: str
) -> tuple[str, str, object, object]:
    """Swap the source/target languages and their texts, re-aiming each box's direction.

    After the swap, the input box holds the old target text, so its direction follows the
    old target language. The output box holds the old source text, so it follows the old
    source language. Direction is always set explicitly because rtl is sticky across reruns
    and a stale RTL flip from an earlier translation must be cleared.

    Returns:
        ``(new_source_lang, new_target_lang, input_box_update, output_box_update)``.
    """
    name_to_code, _ = _build_language_mappings()

    def _directed(value: str, language: str) -> object:
        # Textbox update whose direction matches the script of ``language``.
        rtl = name_to_code.get(language) in RTL_CODES
        return gr.update(value=value, rtl=rtl, text_align="right" if rtl else "left")

    input_update = _directed(target_text, target_lang)
    output_update = _directed(source_text, source_lang)
    return target_lang, source_lang, input_update, output_update
def _build_demo() -> gr.Blocks:
    """Build (but do not launch) the Gradio Blocks UI.

    Layout:
      - Source/target language dropdowns with a swap button between them.
      - Input/output textboxes.
      - A Translate button.
      - A collapsed "Advanced" accordion holding max_new_tokens / num_beams / temperature.

    Event wiring:
      - Swap button -> ``_swap_languages`` (private).
      - Translate button -> ``_translate_with_loading`` (private).
      - Textbox submit -> ``translate``, the only public endpoint (``api_name="translate"``).

    Calling this loads the tokenizer, via ``_build_language_mappings``, to populate the
    dropdowns.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    # Display names come pre-sorted (region, then name) from the cached mapping.
    _, language_names = _build_language_mappings()
    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        # Components render in creation order, so the statement order below is the layout.
        with gr.Row():
            source_language = gr.Dropdown(
                choices=language_names,
                value="English (en)",
                show_label=False,
                filterable=True,
                # The model auto-detects the source from the text; this control only feeds the
                # swap button. Say so, so the symmetric layout doesn't imply it sets the source.
                info="Auto-detected from your text; used only for the swap button.",
            )
            swap_btn = gr.Button("⇄", scale=0, min_width=60)
            target_language = gr.Dropdown(
                choices=language_names,
                value="French (fr)",
                show_label=False,
                filterable=True,
                # MADLAD-400 is a research model; the paper notes quality varies sharply by language.
                info="Quality varies by language; strongest for high-resource languages.",
            )
        with gr.Row(equal_height=True):
            input_text = gr.Textbox(
                lines=6,
                max_length=2000,
                show_label=False,
                autofocus=True,
                info="Press Ctrl+Enter to translate.",
            )
            # Read-only result box; its rtl/text_align are updated by the button and swap handlers.
            output_text = gr.Textbox(
                placeholder="Translation",
                lines=6,
                show_label=False,
                interactive=False,
                buttons=["copy"],
                info="MADLAD-400 (google/madlad400-3b-mt) · arXiv:2309.04662 · Apache-2.0",
            )
        translate_btn = gr.Button("Translate", variant="primary")
        # translate() already accepts these generation params; expose them here behind a
        # collapsed accordion (gr.Number, not gr.Slider — the UI is slider-free by design).
        # Defaults mirror translate()'s signature (512 / 1 / 1.0) so the default surface stays greedy.
        # The min/max bounds here match the clamps in _normalize_params, which still guards API
        # callers and cleared (None) fields.
        with gr.Accordion("Advanced", open=False):
            max_new_tokens = gr.Number(
                value=512,
                label="Max new tokens",
                minimum=1,
                maximum=1024,
                precision=0,
                info="Caps the length of the translation.",
            )
            num_beams = gr.Number(
                value=1,
                label="Beams",
                minimum=1,
                maximum=8,
                precision=0,
                info="Beam-search width; higher is slower and ignores temperature.",
            )
            temperature = gr.Number(
                value=1.0,
                label="Temperature",
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                info="No effect at 1.0 (greedy) or when Beams > 1; below 1.0 is more focused, above 1.0 more random.",
            )
        # UI-only handlers: kept off the public API surface (private) so only /translate is exposed.
        swap_btn.click(
            fn=_swap_languages,
            inputs=[source_language, target_language, input_text, output_text],
            outputs=[source_language, target_language, input_text, output_text],
            api_visibility="private",
        )
        # Both translate handlers carry the advanced params: the button (via the loading-state
        # wrapper, which also applies the RTL output update) and the public /translate submit.
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=[input_text, target_language, max_new_tokens, num_beams, temperature],
            outputs=[translate_btn, output_text],
            api_visibility="private",
            show_progress="minimal",
        )
        # /translate exposes the advanced params too. They all have defaults, so existing
        # two-arg callers (text, target) keep working; wiring them here also makes Ctrl+Enter
        # honor the Advanced accordion, matching the Translate button. The endpoint returns a
        # bare str, so an RTL target submitted via Ctrl+Enter is NOT direction-flipped — that
        # happens only on the Translate-button path (an accepted, documented UI divergence).
        input_text.submit(
            fn=translate,
            inputs=[input_text, target_language, max_new_tokens, num_beams, temperature],
            outputs=output_text,
            api_name="translate",
            show_progress="minimal",
        )
    return demo
# Module-level app object, built at import time (not inside main()) so it exists whenever
# this module is imported. Building it loads the tokenizer (for the dropdowns), but not the
# model.
demo = _build_demo()
# Also at module scope: on ZeroGPU, loading the model here lets the ``spaces`` hijack pack
# the weights at startup. Off ZeroGPU this is a no-op.
_maybe_eager_load()


def main() -> None:
    """Launch the Gradio app with the Ocean theme."""
    demo.launch(theme=gr.themes.Ocean())


if __name__ == "__main__":
    main()