Spaces:
Running on Zero
Running on Zero
File size: 16,237 Bytes
1f39aa6 0dc9d0c 7633dda 1f39aa6 3fdc7b6 e45a74c 2ba774e 0dc9d0c 6edcaed 0dc9d0c 1f39aa6 43b39f0 1f39aa6 607cea3 c571a45 0dc9d0c c571a45 0dc9d0c 43b39f0 cf70957 1f39aa6 0dc9d0c de4ab0d 0dc9d0c 43b39f0 c571a45 0dc9d0c 607cea3 0dc9d0c 3fdc7b6 1f39aa6 0dc9d0c 607cea3 0dc9d0c 9ba4f71 ca97bb4 151fa6d 1f39aa6 e0dc6e1 0dc9d0c d12366f 2ffda3e d12366f 2ffda3e d12366f 1f39aa6 3fdc7b6 2ba774e e4538dd e45a74c e4538dd e45a74c e4538dd e45a74c 2ba774e e45a74c 2ba774e e45a74c 2ba774e 8d0070f e45a74c 8d0070f cf70957 e45a74c cf70957 e45a74c 0dc9d0c e0dc6e1 0dc9d0c 1f39aa6 43b39f0 e45a74c 8d0070f 7d583b2 8d0070f 1df06c4 e45a74c 8d0070f 2ba774e 8d0070f 2ba774e ca97bb4 2ba774e 0dc9d0c 6edcaed e45a74c 1df06c4 6edcaed e45a74c cf70957 6edcaed 4453ebb e45a74c 4453ebb 613b55a 0dc9d0c 1414eac 4453ebb cf70957 4453ebb cf70957 4453ebb 2ffda3e 4453ebb f8a5dc6 cf70957 4453ebb cf70957 4453ebb c077bcb 8ff0bfd cf70957 e45a74c cf70957 f8a5dc6 4453ebb f8a5dc6 4453ebb 879879e c077bcb 6edcaed cf70957 6edcaed f8a5dc6 cf70957 c077bcb 879879e e45a74c 2ffda3e 879879e 2ffda3e f8a5dc6 cf70957 2ffda3e 3872c0b 613b55a 2ba774e 613b55a 0758e3a 0dc9d0c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 
247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 | """
Translation interface using the MADLAD-400 3B model.
Translates between 418 languages from the MADLAD-400 paper.
"""
import math
import os
import time
import warnings
from collections.abc import Generator
from functools import lru_cache
import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, T5TokenizerFast
from langmap.langid_mapping import langid_to_language
MODEL_NAME = "google/madlad400-3b-mt"

# MADLAD target tokens (``<2xx>``) for languages written right to left. The UI uses this set
# to flip a textbox's text direction. It is listed by hand because the langmap only records a
# geographic "region", which does not track script direction. The "Middle East & North Africa"
# region also contains Latin-script languages (Turkish <2tr>, Kurmanji <2ku>, Zaza <2zza>,
# Takestani <2tks>), while Dhivehi <2dv> (Thaana, RTL) is filed under "South Asia".
# For Kurdish, only Sorani <2ckb> (Arabic script) is RTL; Kurmanji <2ku> is Latin and stays LTR.
RTL_CODES: frozenset[str] = frozenset(
    f"<2{locale}>"
    for locale in (
        "ar",  # Arabic
        "mey",  # Hassaniyya (Arabic script)
        "he",  # Hebrew
        "ks",  # Kashmiri (Perso-Arabic)
        "ckb",  # Kurdish, Sorani (Arabic script)
        "lrc",  # Northern Luri (Perso-Arabic)
        "luz",  # Southern Luri (Perso-Arabic)
        "bgp",  # Eastern Baluchi (Perso-Arabic)
        "ps",  # Pashto
        "fa",  # Persian
        "skr",  # Saraiki (Perso-Arabic)
        "sd",  # Sindhi (Perso-Arabic)
        "syr",  # Syriac
        "ur",  # Urdu
        "ug",  # Uyghur (Arabic script)
        "yi",  # Yiddish (Hebrew script)
        "dv",  # Dhivehi (Thaana)
    )
)
def _get_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda")
warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
return torch.device("cpu")
@lru_cache(maxsize=1)
def _load_tokenizer() -> T5TokenizerFast:
    """Load the fast tokenizer for ``MODEL_NAME`` once and cache it for later calls.

    Raises:
        RuntimeError: if ``AutoTokenizer.from_pretrained`` hands back ``None``.
    """
    loaded = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
@lru_cache(maxsize=1)
def _load_model() -> T5ForConditionalGeneration:
    """Load the MADLAD-400 seq2seq model once, move it to the chosen device, and cache it.

    Weights are bfloat16 on CUDA and float32 on CPU. T5/MADLAD is numerically unstable in
    float16: activations can overflow fp16's narrow range into inf/NaN and produce garbage
    output. bfloat16 keeps fp32's exponent range at the same 2-byte footprint.

    No attention speedups are used. T5's relative-position-bias attention is eager-only
    (transformers rejects sdpa / flash_attention_2 / flash_attention_3). Ahead-of-time
    compilation targets fixed-shape forward passes, not T5's autoregressive ``.generate()``,
    so it is not wired up either.
    """
    target = _get_device()
    weights_dtype = torch.float32 if target.type != "cuda" else torch.bfloat16
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=weights_dtype)
    return model.to(target)
@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Map dropdown display names to MADLAD target tokens and order them for the UI.

    Only langmap entries whose token (e.g. ``<2fr>``) is in the tokenizer vocabulary are
    kept. Each display name is ``"<name> (<locale>)"``. The locale is the token with its
    ``<2`` prefix and ``>`` suffix stripped (``<2fr>`` -> ``fr``).

    Returns:
        ``(name_to_code, sorted_names)``: the display-name -> token dict, and the display
        names sorted by langmap region, then alphabetically within each region.
    """
    vocab = _load_tokenizer().get_vocab()
    name_to_code: dict[str, str] = {
        f"{meta['name']} ({code[2:-1]})": code
        for code, meta in langid_to_language.items()
        if code in vocab
    }

    def _region_then_name(display_name: str) -> tuple:
        # Primary key: the langmap region of the entry's token; secondary: the display name.
        return langid_to_language[name_to_code[display_name]]["region"], display_name

    sorted_names = sorted(name_to_code, key=_region_then_name)
    return name_to_code, sorted_names
def _maybe_eager_load() -> None:
"""On ZeroGPU, place the model at module scope so the ``spaces`` hijack can pack
weights to disk at startup and stream them into each worker's VRAM (fast cold
starts). Off-ZeroGPU (local, tests, cpu-basic) this is a no-op, so importing the
app never downloads the model. ``SPACES_ZERO_GPU`` is set only on ZeroGPU."""
if os.environ.get("SPACES_ZERO_GPU") == "1":
_load_tokenizer()
_load_model()
# Cap the token×beam product so even the largest request fits the GPU time _estimate_duration
# reserves. That estimate is 30 + product//8 (clamped at 120s), so a product <= 720 keeps it at
# <= 120s, meaning generation can't outlive its reservation and be killed mid-decode. The token
# count is trimmed to honour this (not the beam width, which drives translation quality).
_MAX_TOKEN_BEAM_PRODUCT = 720
def _normalize_params(
max_new_tokens: float | None, num_beams: float | None, temperature: float | None
) -> tuple[int, int, float]:
"""Coerce the advanced generation params to safe values. A cleared ``gr.Number`` arrives as
``None`` (Gradio skips its bounds check for ``None``) and the public ``/translate`` path
passes values uncast, so this is the single funnel every caller — button, submit, API,
direct, and the ZeroGPU duration callable — goes through. ``None``/``NaN`` fall back to the
defaults; values are clamped to the ranges the Advanced ``gr.Number`` controls advertise; and
the token×beam product is capped (``_MAX_TOKEN_BEAM_PRODUCT``) so generation stays within the
GPU time it reserved."""
def _num(value: float | None, default: float) -> float:
return default if value is None or math.isnan(value) else value
beams = int(max(1, min(8, _num(num_beams, 1))))
tokens = int(max(1, min(1024, _num(max_new_tokens, 512))))
tokens = min(tokens, _MAX_TOKEN_BEAM_PRODUCT // beams)
temperature = float(max(0.1, min(2.0, _num(temperature, 1.0))))
return tokens, beams, temperature
def _estimate_duration(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> int:
    """Return the seconds of GPU time to reserve for one translate() call.

    ZeroGPU calls this duration callable with translate()'s own arguments, before translate()
    runs. It therefore mirrors translate()'s signature and must tolerate the same cleared-field
    ``None`` values, which is why the params are normalized first. Cost scales with tokens
    generated times beam width: 30s base plus (tokens × beams) // 8, capped at 120s.
    Calibrate against the perf_counter log line printed by translate() (zerogpu.md
    'Sizing duration').
    """
    del text, target_language_name  # runtime depends only on the token and beam counts
    tokens, beams, _temperature = _normalize_params(max_new_tokens, num_beams, temperature)
    budget = 30 + tokens * beams // 8
    return 120 if budget > 120 else budget
@spaces.GPU(duration=_estimate_duration)
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> str:
    """Translate ``text`` into ``target_language_name`` on the GPU and return the string.

    This is both the public ``/translate`` endpoint (also the Ctrl+Enter submit path) and the
    worker behind the Translate button. Blank, whitespace-only or ``None`` text returns ``""``
    without touching the model.

    Generation params go through ``_normalize_params``, so ``None``/NaN fall back to defaults
    and out-of-range values are clamped. The ZeroGPU duration callable applies the same
    coercion.

    Decoding is greedy by default. Sampling is enabled only for single-beam decoding when the
    temperature differs from 1.0. Beam search ignores temperature, and a ``gr.Info`` toast says
    so.

    Raises:
        ValueError: if ``target_language_name`` is not a known display name.
    """
    # Guard lives here (not in the button wrapper) so the API and Ctrl+Enter paths are covered.
    # A bare "<2xx> " prompt would burn generation time and emit a stray token.
    # ``text or ""`` keeps an API caller's null text field from crashing.
    if not (text or "").strip():
        return ""

    max_new_tokens, num_beams, temperature = _normalize_params(max_new_tokens, num_beams, temperature)
    beam_search = num_beams > 1
    # Tolerance check so spinner float drift (e.g. 0.1*9 = 0.999...) still counts as 1.0.
    wants_sampling = abs(temperature - 1.0) > 1e-6

    tokenizer = _load_tokenizer()
    model = _load_model()
    device = model.device
    target_code = _build_language_mappings()[0].get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")
    if beam_search and wants_sampling:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    # MADLAD prompt format: "<2xx> " followed by the source text.
    prompt = target_code + " " + text
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generate_kwargs: dict = {"input_ids": input_ids, "max_new_tokens": max_new_tokens, "num_beams": num_beams}
    if not beam_search and wants_sampling:
        generate_kwargs.update(do_sample=True, temperature=temperature)

    started = time.perf_counter()
    output_ids = model.generate(**generate_kwargs)
    translation = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    elapsed = time.perf_counter() - started
    print(f"[translate] max_new_tokens={max_new_tokens} num_beams={num_beams} took {elapsed:.1f}s")
    # The result crosses the ZeroGPU worker -> main process boundary via pickle. Return the
    # decoded str, never the raw CUDA ``output_ids`` tensor: unpickling a CUDA tensor in the
    # main process triggers torch.cuda._lazy_init(), which ZeroGPU blocks, hanging the call.
    return translation
def _translate_with_loading(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> Generator[tuple[object, object], None, None]:
    """Translate-button handler: show a busy state, translate, then restore the button.

    Yields ``(button_update, output_update)`` pairs for ``[translate_btn, output_text]``:
      1. Disable the button and relabel it "Translating..." (output untouched).
      2. Re-enable the button as "Translate" and write the translation. The output box is
         flipped to RTL for right-to-left targets and reset to LTR otherwise.

    If translate() raises (unsupported language, GPU/quota failure, ...), the button is
    restored first and the exception is then re-raised unchanged, so Gradio still reports
    it. Without this, the error ends the generator after step 1 and leaves the button stuck
    disabled on "Translating..." until the page is reloaded.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        # translate() normalizes the params (None/NaN/clamp), so forward them as-is.
        result = translate(text, target_language_name, max_new_tokens, num_beams, temperature)
    except Exception:
        # Re-enable the button (leave the output box as it was), then surface the error.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    # rtl is sticky across reruns, so set it explicitly in both directions. Only the button
    # path carries this; the public /translate endpoint stays a bare str to keep its API stable.
    name_to_code, _ = _build_language_mappings()
    is_rtl = name_to_code.get(target_language_name) in RTL_CODES
    output_update = gr.update(value=result, rtl=is_rtl, text_align="right" if is_rtl else "left")
    yield gr.update(value="Translate", interactive=True), output_update
def _swap_languages(
    source_lang: str, target_lang: str, source_text: str, target_text: str
) -> tuple[str, str, object, object]:
    """Swap the source/target languages and their texts for the swap button.

    Returns ``(new_source_lang, new_target_lang, input_update, output_update)``. After the
    swap, the input box holds the old target text and the output box holds the old source
    text. Each box's direction is set from the language of the text now in it. rtl is sticky
    across reruns, so an RTL flip left over from an earlier translation is explicitly reset.
    """
    name_to_code, _ = _build_language_mappings()

    def _textbox_update(value: str, language: str) -> object:
        rtl = name_to_code.get(language) in RTL_CODES
        return gr.update(value=value, rtl=rtl, text_align="right" if rtl else "left")

    return (
        target_lang,
        source_lang,
        _textbox_update(target_text, target_lang),  # old target text -> input box
        _textbox_update(source_text, source_lang),  # old source text -> output box
    )
def _build_demo() -> gr.Blocks:
    """Build the translation UI and wire its event handlers.

    Layout, top to bottom:
      * a row with the source-language dropdown, a swap button and the target-language dropdown;
      * a row with the input and output textboxes;
      * the Translate button;
      * a collapsed "Advanced" accordion with the generation params.

    Event wiring: the swap-button and Translate-button handlers are private
    (``api_visibility="private"``). Only the input textbox's submit handler is exposed, as
    the ``/translate`` API endpoint.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    # Dropdown choices: display names sorted by region, then name.
    _, language_names = _build_language_mappings()
    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row():
            source_language = gr.Dropdown(
                choices=language_names,
                value="English (en)",
                show_label=False,
                filterable=True,
                # The model auto-detects the source from the text; this control only feeds the
                # swap button. Say so, so the symmetric layout doesn't imply it sets the source.
                info="Auto-detected from your text; used only for the swap button.",
            )
            swap_btn = gr.Button("⇄", scale=0, min_width=60)
            target_language = gr.Dropdown(
                choices=language_names,
                value="French (fr)",
                show_label=False,
                filterable=True,
                # MADLAD-400 is a research model; the paper notes quality varies sharply by language.
                info="Quality varies by language; strongest for high-resource languages.",
            )
        with gr.Row(equal_height=True):
            input_text = gr.Textbox(
                lines=6,
                max_length=2000,
                show_label=False,
                autofocus=True,
                info="Press Ctrl+Enter to translate.",
            )
            # Read-only output; its direction (rtl/text_align) is updated by the button and swap handlers.
            output_text = gr.Textbox(
                placeholder="Translation",
                lines=6,
                show_label=False,
                interactive=False,
                buttons=["copy"],
                info="MADLAD-400 (google/madlad400-3b-mt) · arXiv:2309.04662 · Apache-2.0",
            )
        translate_btn = gr.Button("Translate", variant="primary")
        # translate() already accepts these generation params; expose them here behind a
        # collapsed accordion (gr.Number, not gr.Slider — the UI is slider-free by design).
        # Defaults mirror translate()'s signature so the default surface stays greedy.
        with gr.Accordion("Advanced", open=False):
            # NOTE: the effective cap is lower than the advertised maximum of 1024.
            # _normalize_params trims tokens so tokens × beams <= 720 (720 tokens at 1 beam,
            # 90 at 8 beams).
            max_new_tokens = gr.Number(
                value=512,
                label="Max new tokens",
                minimum=1,
                maximum=1024,
                precision=0,
                info="Caps the length of the translation.",
            )
            num_beams = gr.Number(
                value=1,
                label="Beams",
                minimum=1,
                maximum=8,
                precision=0,
                info="Beam-search width; higher is slower and ignores temperature.",
            )
            temperature = gr.Number(
                value=1.0,
                label="Temperature",
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                info="No effect at 1.0 (greedy) or when Beams > 1; below 1.0 is more focused, above 1.0 more random.",
            )
        # UI-only handlers: kept off the public API surface (private) so only /translate is exposed.
        swap_btn.click(
            fn=_swap_languages,
            inputs=[source_language, target_language, input_text, output_text],
            outputs=[source_language, target_language, input_text, output_text],
            api_visibility="private",
        )
        # Both translate handlers carry the advanced params: the button (via the loading-state
        # wrapper, which also applies the RTL output update) and the public /translate submit.
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=[input_text, target_language, max_new_tokens, num_beams, temperature],
            outputs=[translate_btn, output_text],
            api_visibility="private",
            show_progress="minimal",
        )
        # /translate exposes the advanced params too. They all have defaults, so existing
        # two-arg callers (text, target) keep working. Wiring them here also makes Ctrl+Enter
        # honor the Advanced accordion, matching the Translate button. The endpoint returns a
        # bare str, so an RTL target submitted via Ctrl+Enter is NOT direction-flipped. That
        # happens only on the Translate-button path (an accepted, documented UI divergence).
        input_text.submit(
            fn=translate,
            inputs=[input_text, target_language, max_new_tokens, num_beams, temperature],
            outputs=output_text,
            api_name="translate",
            show_progress="minimal",
        )
    return demo
demo = _build_demo()
# Module-scope call: on ZeroGPU this loads the tokenizer and model at import; elsewhere it is a no-op.
_maybe_eager_load()


def main() -> None:
    """Launch the Gradio server for the module-level ``demo`` with the Ocean theme."""
    ocean_theme = gr.themes.Ocean()
    demo.launch(theme=ocean_theme)


if __name__ == "__main__":
    main()
|