File size: 16,237 Bytes
1f39aa6
0dc9d0c
7633dda
1f39aa6
3fdc7b6
e45a74c
2ba774e
 
0dc9d0c
6edcaed
0dc9d0c
 
1f39aa6
43b39f0
1f39aa6
607cea3
c571a45
0dc9d0c
c571a45
0dc9d0c
43b39f0
cf70957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f39aa6
0dc9d0c
 
 
de4ab0d
0dc9d0c
43b39f0
c571a45
0dc9d0c
607cea3
0dc9d0c
 
 
 
3fdc7b6
1f39aa6
0dc9d0c
607cea3
0dc9d0c
9ba4f71
 
 
 
 
ca97bb4
 
 
 
151fa6d
1f39aa6
 
e0dc6e1
0dc9d0c
 
 
d12366f
 
2ffda3e
 
d12366f
2ffda3e
d12366f
 
 
 
 
 
1f39aa6
3fdc7b6
2ba774e
 
 
 
 
 
 
 
 
 
e4538dd
 
 
 
 
 
 
e45a74c
 
 
 
 
 
 
e4538dd
 
 
e45a74c
 
 
 
e4538dd
 
 
 
 
e45a74c
 
2ba774e
 
 
e45a74c
 
 
2ba774e
e45a74c
 
 
 
 
 
 
 
2ba774e
 
 
 
8d0070f
 
 
e45a74c
 
 
8d0070f
cf70957
 
 
e45a74c
 
 
cf70957
e45a74c
 
 
 
 
 
 
0dc9d0c
 
e0dc6e1
0dc9d0c
 
 
 
1f39aa6
43b39f0
e45a74c
8d0070f
 
7d583b2
8d0070f
 
1df06c4
 
e45a74c
8d0070f
 
 
2ba774e
8d0070f
2ba774e
 
 
ca97bb4
 
 
 
2ba774e
0dc9d0c
 
6edcaed
 
 
e45a74c
 
 
1df06c4
6edcaed
e45a74c
 
cf70957
 
 
 
 
 
 
6edcaed
 
4453ebb
 
e45a74c
 
 
 
 
 
 
 
 
 
 
 
 
 
4453ebb
 
613b55a
0dc9d0c
 
1414eac
4453ebb
 
 
 
 
 
cf70957
 
 
4453ebb
 
 
 
 
 
 
cf70957
 
4453ebb
 
2ffda3e
4453ebb
 
 
 
f8a5dc6
cf70957
4453ebb
 
 
 
 
 
 
cf70957
4453ebb
c077bcb
8ff0bfd
 
cf70957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e45a74c
cf70957
 
f8a5dc6
4453ebb
 
 
 
f8a5dc6
4453ebb
879879e
 
c077bcb
6edcaed
cf70957
6edcaed
f8a5dc6
cf70957
c077bcb
879879e
 
 
e45a74c
 
2ffda3e
 
879879e
2ffda3e
f8a5dc6
cf70957
2ffda3e
3872c0b
613b55a
 
 
 
2ba774e
613b55a
 
 
0758e3a
0dc9d0c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
"""
Translation interface using the MADLAD-400 3B model.
Translates between 418 languages from the MADLAD-400 paper.
"""

import math
import os
import time
import warnings
from collections.abc import Generator
from functools import lru_cache

import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, T5TokenizerFast

from langmap.langid_mapping import langid_to_language

MODEL_NAME = "google/madlad400-3b-mt"  # Hugging Face Hub id of the translation checkpoint

# Target-language tokens whose script runs right-to-left; the output box flips direction for these.
# The langmap only records a geographic "region", which cannot stand in for script direction:
# "Middle East & North Africa" also contains Latin-script (LTR) languages such as Turkish <2tr>,
# Kurmanji <2ku>, Zaza <2zza> and Takestani <2tks>, while RTL Dhivehi <2dv> is filed under
# "South Asia". The RTL set is therefore listed by hand (alphabetically by token). Kurmanji <2ku>
# is written in Latin script and stays out; Sorani Kurdish <2ckb> is the Arabic-script variant.
RTL_CODES = frozenset(
    (
        "<2ar>",  # Arabic
        "<2bgp>",  # Eastern Baluchi (Perso-Arabic)
        "<2ckb>",  # Kurdish, Sorani (Arabic script)
        "<2dv>",  # Dhivehi (Thaana)
        "<2fa>",  # Persian
        "<2he>",  # Hebrew
        "<2ks>",  # Kashmiri (Perso-Arabic)
        "<2lrc>",  # Northern Luri (Perso-Arabic)
        "<2luz>",  # Southern Luri (Perso-Arabic)
        "<2mey>",  # Hassaniyya (Arabic script)
        "<2ps>",  # Pashto
        "<2sd>",  # Sindhi (Perso-Arabic)
        "<2skr>",  # Saraiki (Perso-Arabic)
        "<2syr>",  # Syriac
        "<2ug>",  # Uyghur (Arabic script)
        "<2ur>",  # Urdu
        "<2yi>",  # Yiddish (Hebrew script)
    )
)


def _get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
    return torch.device("cpu")


@lru_cache(maxsize=1)
def _load_tokenizer() -> T5TokenizerFast:
    """Load the fast tokenizer for ``MODEL_NAME`` once per process (memoized by ``lru_cache``).

    Raises:
        RuntimeError: if ``AutoTokenizer.from_pretrained`` hands back ``None`` instead of a tokenizer.
    """
    loaded = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")


@lru_cache(maxsize=1)
def _load_model() -> T5ForConditionalGeneration:
    """Load the MADLAD-400 seq2seq model once per process and move it to the chosen device."""
    target_device = _get_device()
    # Use bfloat16 on GPU and float32 on CPU. T5/MADLAD was trained in bfloat16. float16's narrow
    # exponent range lets its activations overflow into inf/NaN, which produces garbage output.
    # bfloat16 keeps the same 2-byte footprint but has float32's exponent range.
    weights_dtype = torch.float32 if target_device.type != "cuda" else torch.bfloat16
    # The usual ZeroGPU attention and AoTI speedups don't apply:
    # - T5's relative-position-bias attention is eager-only (transformers rejects sdpa /
    #   flash_attention_2 / flash_attention_3).
    # - AoTI targets fixed-shape forward passes, not the autoregressive .generate() loop.
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=weights_dtype)
    return model.to(target_device)


@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Map dropdown display names to MADLAD target tokens, keeping only tokens the tokenizer knows.

    Returns:
        ``(name_to_code, sorted_names)``.
        - ``name_to_code`` maps a display name such as ``"French (fr)"`` to its token ``"<2fr>"``.
        - ``sorted_names`` lists the display names ordered by the langmap ``region``, then by name.
    """
    known_tokens = _load_tokenizer().get_vocab()
    name_to_code: dict[str, str] = {
        f"{info['name']} ({code[2:-1]})": code  # locale is the token minus "<2" and ">"
        for code, info in langid_to_language.items()
        if code in known_tokens
    }

    def _order(display_name: str) -> tuple:
        return langid_to_language[name_to_code[display_name]]["region"], display_name

    return name_to_code, sorted(name_to_code, key=_order)


def _maybe_eager_load() -> None:
    """On ZeroGPU, place the model at module scope so the ``spaces`` hijack can pack
    weights to disk at startup and stream them into each worker's VRAM (fast cold
    starts). Off-ZeroGPU (local, tests, cpu-basic) this is a no-op, so importing the
    app never downloads the model. ``SPACES_ZERO_GPU`` is set only on ZeroGPU."""
    if os.environ.get("SPACES_ZERO_GPU") == "1":
        _load_tokenizer()
        _load_model()


# Cap the token×beam product so even the largest request fits the GPU time _estimate_duration
# reserves. That estimate is 30 + product//8 (clamped at 120s), so a product <= 720 keeps it at
# <= 120s, meaning generation can't outlive its reservation and be killed mid-decode. The token
# count is trimmed to honour this (not the beam width, which drives translation quality).
_MAX_TOKEN_BEAM_PRODUCT = 720


def _normalize_params(
    max_new_tokens: float | None, num_beams: float | None, temperature: float | None
) -> tuple[int, int, float]:
    """Coerce the advanced generation params to safe values. A cleared ``gr.Number`` arrives as
    ``None`` (Gradio skips its bounds check for ``None``) and the public ``/translate`` path
    passes values uncast, so this is the single funnel every caller — button, submit, API,
    direct, and the ZeroGPU duration callable — goes through. ``None``/``NaN`` fall back to the
    defaults; values are clamped to the ranges the Advanced ``gr.Number`` controls advertise; and
    the token×beam product is capped (``_MAX_TOKEN_BEAM_PRODUCT``) so generation stays within the
    GPU time it reserved."""

    def _num(value: float | None, default: float) -> float:
        return default if value is None or math.isnan(value) else value

    beams = int(max(1, min(8, _num(num_beams, 1))))
    tokens = int(max(1, min(1024, _num(max_new_tokens, 512))))
    tokens = min(tokens, _MAX_TOKEN_BEAM_PRODUCT // beams)
    temperature = float(max(0.1, min(2.0, _num(temperature, 1.0))))
    return tokens, beams, temperature


def _estimate_duration(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> int:
    """Tell ZeroGPU how many seconds of GPU time to reserve for one ``translate()`` call.

    ZeroGPU invokes this with the decorated ``translate()``'s arguments, *before* ``translate()``
    runs. It must therefore cope with the same cleared-field ``None``/NaN values, so it
    normalizes through ``_normalize_params`` first.

    The estimate is ``30 + max_new_tokens * num_beams // 8`` seconds, capped at 120. It is
    deliberately conservative; recalibrate it from the timing line ``translate()`` prints
    (zerogpu.md 'Sizing duration').
    """
    # Runtime depends only on the token budget and beam width; the text and target are ignored.
    del text, target_language_name
    token_budget, beam_count, _temperature = _normalize_params(max_new_tokens, num_beams, temperature)
    reserved = 30 + token_budget * beam_count // 8
    return 120 if reserved > 120 else reserved


@spaces.GPU(duration=_estimate_duration)
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> str:
    """Translate ``text`` into the language whose display name is ``target_language_name``.

    This is the public ``/translate`` endpoint. It runs under ``@spaces.GPU`` with a GPU-time
    reservation computed by ``_estimate_duration``. No source language is passed: the prompt is
    the target token (e.g. ``<2fr>``), a space, then the text.

    Args:
        text: Text to translate. ``None`` or whitespace-only input returns ``""`` without
            touching the model.
        target_language_name: A display name from ``_build_language_mappings``
            (e.g. ``"French (fr)"``).
        max_new_tokens, num_beams, temperature: Generation controls. ``None``/NaN fall back to
            defaults and values are clamped by ``_normalize_params``. Sampling happens only with
            one beam and a temperature other than 1.0; otherwise decoding is greedy or beam search.

    Returns:
        The decoded translation with special tokens stripped.

    Raises:
        ValueError: if ``target_language_name`` is not a known display name.
    """
    # Empty/whitespace input (or None, from an API caller POSTing a null text field) skips the
    # model. A bare "<2xx> " prompt would burn generation time and emit a stray token. The guard
    # lives here, not in _translate_with_loading, so the public /translate and submit paths,
    # which call translate() directly, are covered too.
    if not (text or "").strip():
        return ""
    # translate() is the single source of truth for the generation params. The submit path
    # passes gr.Number values uncast and a cleared field arrives as None/NaN, so coerce and
    # clamp before use (the duration callable normalizes identically).
    max_new_tokens, num_beams, temperature = _normalize_params(max_new_tokens, num_beams, temperature)
    # Compare with a tolerance so float spinner drift (e.g. 0.1*9 = 0.999…) doesn't trip sampling.
    sampling = abs(temperature - 1.0) > 1e-6

    # Resolve the target token before loading the model. The mapping needs only the tokenizer,
    # so an unknown name fails fast instead of after the 3B model has been loaded.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    tokenizer = _load_tokenizer()
    model = _load_model()

    if num_beams > 1 and sampling:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    # Pass the tokenizer's attention_mask along with input_ids rather than leaving generate()
    # to infer one.
    encoded = tokenizer(target_code + " " + text, return_tensors="pt").to(model.device)

    generate_kwargs: dict = {
        "input_ids": encoded.input_ids,
        "attention_mask": encoded.attention_mask,
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
    }
    # Greedy by default (deterministic, higher-quality MT). Only sample when the user
    # explicitly sets a non-default temperature; beam search (num_beams > 1) ignores it.
    if num_beams == 1 and sampling:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature

    start = time.perf_counter()
    outputs = model.generate(**generate_kwargs)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    elapsed = time.perf_counter() - start
    print(f"[translate] max_new_tokens={max_new_tokens} num_beams={num_beams} took {elapsed:.1f}s")
    # @spaces.GPU return values cross the worker -> main process boundary via pickle. Return the
    # decoded str, never the raw `outputs` CUDA tensor. Unpickling a CUDA tensor in the main
    # process triggers torch.cuda._lazy_init(), which ZeroGPU blocks, and the call hangs.
    # CPU tensors, numpy arrays, and plain str/objects are safe.
    return result


def _translate_with_loading(
    text: str,
    target_language_name: str,
    max_new_tokens: int | None = 512,
    num_beams: int | None = 1,
    temperature: float | None = 1.0,
) -> Generator[tuple[object, object], None, None]:
    """Drive the Translate button: show a busy state, translate, then restore the button.

    Yields ``(button_update, output_update)`` pairs for ``[translate_btn, output_text]``:
    - The first yield disables the button and relabels it "Translating...".
    - The last re-enables it and writes the translation, flipping the output box's direction
      for RTL targets.

    If ``translate()`` raises (e.g. an unsupported language, or generation outliving its GPU
    reservation), the button is restored first. The exception is then re-raised so Gradio
    still reports the error.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        # translate() normalizes the params (None/NaN/clamp), so forward them as-is.
        result = translate(text, target_language_name, max_new_tokens, num_beams, temperature)
    except Exception:
        # Without this, the busy state yielded above would outlive the failure and leave the
        # button disabled for the rest of the session. The output is left as-is for the error.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    # Flip the output box to RTL for right-to-left target scripts so the translation reads
    # correctly. Otherwise reset it to LTR, because rtl persists across reruns. Only the button
    # path does this: the public /translate endpoint stays a bare str to keep its API stable.
    name_to_code, _ = _build_language_mappings()
    is_rtl = name_to_code.get(target_language_name) in RTL_CODES
    output_update = gr.update(value=result, rtl=is_rtl, text_align="right" if is_rtl else "left")
    yield gr.update(value="Translate", interactive=True), output_update


def _swap_languages(
    source_lang: str, target_lang: str, source_text: str, target_text: str
) -> tuple[str, str, object, object]:
    """Exchange the two language dropdowns and the two text boxes.

    After the swap, the input box holds the old target text and the output box holds the old
    source text. Each box's direction is recomputed from the language of the text now in it.
    This also clears any stale RTL flip from an earlier translation, because ``rtl`` persists
    across reruns until explicitly reset.
    """
    name_to_code, _ = _build_language_mappings()

    def _textbox_update(value: str, language: str) -> object:
        right_to_left = name_to_code.get(language) in RTL_CODES
        return gr.update(value=value, rtl=right_to_left, text_align="right" if right_to_left else "left")

    new_input = _textbox_update(target_text, target_lang)
    new_output = _textbox_update(source_text, source_lang)
    return target_lang, source_lang, new_input, new_output


def _build_demo() -> gr.Blocks:
    """Assemble the Gradio UI and wire its events; nothing is launched here.

    Layout, top to bottom:
    - source dropdown, swap button, target dropdown;
    - input and output textboxes side by side;
    - the Translate button;
    - a collapsed "Advanced" accordion with the generation controls.

    Only ``input_text.submit`` is exposed as a named API endpoint (``/translate``). The swap
    and Translate-button handlers are registered with ``api_visibility="private"``.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    # Only the sorted display names are needed here; the handlers resolve names to codes.
    _, language_names = _build_language_mappings()

    # Components render in the order they are created inside each context manager, so the
    # statement order below *is* the page layout.
    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row():
            source_language = gr.Dropdown(
                choices=language_names,
                value="English (en)",
                show_label=False,
                filterable=True,
                # The model auto-detects the source from the text; this control only feeds the
                # swap button. Say so, so the symmetric layout doesn't imply it sets the source.
                info="Auto-detected from your text; used only for the swap button.",
            )
            swap_btn = gr.Button("⇄", scale=0, min_width=60)
            target_language = gr.Dropdown(
                choices=language_names,
                value="French (fr)",
                show_label=False,
                filterable=True,
                # MADLAD-400 is a research model; the paper notes quality varies sharply by language.
                info="Quality varies by language; strongest for high-resource languages.",
            )

        with gr.Row(equal_height=True):
            # NOTE(review): confirm which key combination Gradio's multi-line Textbox binds to
            # `submit`; the hint below assumes Ctrl+Enter. Also confirm whether max_length is
            # enforced server-side or only by the browser widget (API callers could bypass it).
            input_text = gr.Textbox(
                lines=6,
                max_length=2000,
                show_label=False,
                autofocus=True,
                info="Press Ctrl+Enter to translate.",
            )
            # Read-only result box; its rtl/text_align are updated by the button and swap handlers.
            output_text = gr.Textbox(
                placeholder="Translation",
                lines=6,
                show_label=False,
                interactive=False,
                buttons=["copy"],
                info="MADLAD-400 (google/madlad400-3b-mt) · arXiv:2309.04662 · Apache-2.0",
            )

        translate_btn = gr.Button("Translate", variant="primary")

        # translate() already accepts these generation params; expose them here behind a
        # collapsed accordion (gr.Number, not gr.Slider — the UI is slider-free by design).
        # Defaults mirror translate()'s signature so the default surface stays greedy. The
        # min/max bounds match the clamps applied by _normalize_params.
        with gr.Accordion("Advanced", open=False):
            max_new_tokens = gr.Number(
                value=512,
                label="Max new tokens",
                minimum=1,
                maximum=1024,
                precision=0,
                info="Caps the length of the translation.",
            )
            num_beams = gr.Number(
                value=1,
                label="Beams",
                minimum=1,
                maximum=8,
                precision=0,
                info="Beam-search width; higher is slower and ignores temperature.",
            )
            temperature = gr.Number(
                value=1.0,
                label="Temperature",
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                info="No effect at 1.0 (greedy) or when Beams > 1; below 1.0 is more focused, above 1.0 more random.",
            )

        # UI-only handlers are kept off the public API surface (private), so only /translate is exposed.
        swap_btn.click(
            fn=_swap_languages,
            inputs=[source_language, target_language, input_text, output_text],
            outputs=[source_language, target_language, input_text, output_text],
            api_visibility="private",
        )
        # Both translate handlers carry the advanced params. The button goes through the
        # loading-state wrapper, which also applies the RTL output update; the other handler
        # is the public /translate submit.
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=[input_text, target_language, max_new_tokens, num_beams, temperature],
            outputs=[translate_btn, output_text],
            api_visibility="private",
            show_progress="minimal",
        )
        # /translate exposes the advanced params too. They all have defaults, so existing
        # two-arg callers (text, target) keep working. Wiring them here also makes submit honor
        # the Advanced accordion, matching the Translate button. The endpoint returns a bare
        # str, so an RTL target submitted this way is NOT direction-flipped. Only the
        # Translate-button path flips it; this is an accepted, documented UI divergence.
        input_text.submit(
            fn=translate,
            inputs=[input_text, target_language, max_new_tokens, num_beams, temperature],
            outputs=output_text,
            api_name="translate",
            show_progress="minimal",
        )

    return demo


# Import-time side effects, in this order:
# 1. Build the UI at module scope; main() launches this `demo` object. Building it calls
#    _build_language_mappings(), which loads the tokenizer.
# 2. On ZeroGPU only, load the model at module scope for the `spaces` weight-packing path.
#    Elsewhere this is a no-op.
demo = _build_demo()
_maybe_eager_load()


def main() -> None:
    """Launch the module-level ``demo`` app with Gradio's Ocean theme."""
    ocean_theme = gr.themes.Ocean()
    demo.launch(theme=ocean_theme)


if __name__ == "__main__":
    main()