"""
Translation interface using the MADLAD-400 3B model.

Translates between 418 languages from the MADLAD-400 paper.
"""

import os
import time
import warnings
from collections.abc import Generator
from functools import lru_cache

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5TokenizerFast,
)

from langmap.langid_mapping import langid_to_language

MODEL_NAME = "google/madlad400-3b-mt"


def _get_device() -> torch.device:
    """Return the CUDA device when available, otherwise warn and fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
    return torch.device("cpu")


@lru_cache(maxsize=1)
def _load_tokenizer() -> T5TokenizerFast:
    """Load (once, then cache) the fast tokenizer for MODEL_NAME.

    Raises:
        RuntimeError: if ``AutoTokenizer.from_pretrained`` returns None.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if tokenizer is None:
        raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
    return tokenizer


@lru_cache(maxsize=1)
def _load_model() -> T5ForConditionalGeneration:
    """Load (once, then cache) the MADLAD-400 model onto the selected device."""
    device = _get_device()
    # T5/MADLAD was trained in bfloat16 and is numerically unstable in float16: its
    # activations overflow fp16's narrow range, risking inf/NaN (garbage) output. Use
    # bfloat16 on CUDA (same 2-byte footprint, native on the Blackwell backing GPU,
    # fp32-range exponent) and float32 on CPU.
    dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
    # MADLAD-400 is a T5 model: its relative-position-bias attention is eager-only
    # (transformers rejects sdpa / flash_attention_2 / flash_attention_3), so the usual
    # ZeroGPU attention speedups don't apply. AoTI targets fixed-shape forward passes
    # (e.g. diffusion), not T5's autoregressive .generate(), so it isn't wired up either.
    return AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=dtype).to(device)


@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the display-name -> language-tag mapping for the UI.

    Only languages whose ``<2xx>`` tag is present in the tokenizer vocabulary are
    included.

    Returns:
        A ``(name_to_code, sorted_names)`` pair. ``name_to_code`` maps e.g.
        ``"French (fr)"`` to ``"<2fr>"``. ``sorted_names`` holds the display names
        ordered by region, then alphabetically within each region.
    """
    tokenizer = _load_tokenizer()
    vocab = tokenizer.get_vocab()

    name_to_code: dict[str, str] = {}
    for code, info in langid_to_language.items():
        if code in vocab:
            locale = code[2:-1]  # <2fr> → fr
            display_name = f"{info['name']} ({locale})"
            name_to_code[display_name] = code

    # Sort by region, then alphabetically within each region
    sorted_names = sorted(
        name_to_code.keys(),
        key=lambda n: (langid_to_language[name_to_code[n]]["region"], n),
    )

    return name_to_code, sorted_names


def _maybe_eager_load() -> None:
    """On ZeroGPU, place the model at module scope so the ``spaces`` hijack can pack
    weights to disk at startup and stream them into each worker's VRAM (fast cold
    starts). Off-ZeroGPU (local, tests, cpu-basic) this is a no-op, so importing the
    app never downloads the model. ``SPACES_ZERO_GPU`` is set only on ZeroGPU."""
    if os.environ.get("SPACES_ZERO_GPU") == "1":
        _load_tokenizer()
        _load_model()


def _estimate_duration(
    text: str,
    target_language_name: str,
    max_new_tokens: int = 512,
    num_beams: int = 1,
    temperature: float = 1.0,
) -> int:
    """Reserve GPU time scaled to the worst case: generation cost grows with the number
    of tokens generated and the beam width. Mirrors translate()'s signature (ZeroGPU
    calls the duration callable with the decorated function's args). Conservative and
    capped at 120s; calibrate from the perf_counter log in translate() (zerogpu.md
    'Sizing duration')."""
    del text, target_language_name, temperature  # only token/beam counts drive runtime
    return min(120, 30 + (max_new_tokens * num_beams) // 8)


@spaces.GPU(duration=_estimate_duration)
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int = 512,
    num_beams: int = 1,
    temperature: float = 1.0,
) -> str:
    """Translate ``text`` into the language named ``target_language_name``.

    Only the target language is given to the model, as a ``<2xx>`` prefix token. No
    source language is supplied.

    Args:
        text: Text to translate. Empty or whitespace-only input returns ``""``
            without running the model.
        target_language_name: A display name produced by
            ``_build_language_mappings``, e.g. ``"French (fr)"``.
        max_new_tokens: Upper bound on generated tokens.
        num_beams: Beam width; ``1`` means greedy or sampling.
        temperature: When it differs from ``1.0`` and ``num_beams == 1``, sampling
            is enabled with this temperature. Otherwise it is ignored.

    Returns:
        The decoded translation, with special tokens stripped.

    Raises:
        ValueError: if ``target_language_name`` is not a supported language.
    """
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # Nothing to translate: don't feed a bare "<2xx> " prompt to the model, which
    # would otherwise generate text with no source content.
    if not (text and text.strip()):
        return ""

    tokenizer = _load_tokenizer()
    model = _load_model()
    device = model.device

    if num_beams > 1 and temperature != 1.0:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(device)

    generate_kwargs: dict = {"input_ids": input_ids, "max_new_tokens": max_new_tokens, "num_beams": num_beams}
    # Greedy by default (deterministic, higher-quality MT). Only sample when the user
    # explicitly sets a non-default temperature; beam search (num_beams > 1) ignores it.
    if num_beams == 1 and temperature != 1.0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature

    start = time.perf_counter()
    outputs = model.generate(**generate_kwargs)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    elapsed = time.perf_counter() - start
    print(f"[translate] max_new_tokens={max_new_tokens} num_beams={num_beams} took {elapsed:.1f}s")

    # @spaces.GPU return values cross the worker -> main process boundary via pickle.
    # Return the decoded str, never the raw `outputs` CUDA tensor: unpickling a CUDA
    # tensor in the main process triggers torch.cuda._lazy_init(), which ZeroGPU blocks
    # and hangs the call. CPU tensors, numpy arrays, and plain str/objects are safe.
    return result


def _translate_with_loading(
    text: str,
    target_language_name: str,
) -> Generator[tuple[object, object], None, None]:
    """UI wrapper around ``translate`` that shows a busy state on the button.

    Yields ``(button_update, output_update)`` pairs. The first pair disables the
    button and the last pair re-enables it. If ``translate`` raises, the button is
    re-enabled before the exception propagates, so the UI is not left stuck on
    "Translating...".
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        result = translate(text, target_language_name)
    except Exception:
        # Restore the button first, then let Gradio surface the original error.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    yield gr.update(value="Translate", interactive=True), result


def _swap_languages(
    source_lang: str, target_lang: str, source_text: str, target_text: str
) -> tuple[str, str, str, str]:
    """Swap source/target languages and their text."""
    return target_lang, source_lang, target_text, source_text


def _build_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI. Only ``/translate`` is exposed as a public API."""
    _, language_names = _build_language_mappings()

    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row():
            source_language = gr.Dropdown(
                choices=language_names,
                value="English (en)",
                show_label=False,
                filterable=True,
            )
            swap_btn = gr.Button("⇄", scale=0, min_width=60)
            target_language = gr.Dropdown(
                choices=language_names,
                value="French (fr)",
                show_label=False,
                filterable=True,
            )

        with gr.Row(equal_height=True):
            input_text = gr.Textbox(
                lines=6,
                max_length=2000,
                show_label=False,
                autofocus=True,
            )
            output_text = gr.Textbox(
                placeholder="Translation",
                lines=6,
                show_label=False,
                interactive=False,
                buttons=["copy"],
            )

        translate_btn = gr.Button("Translate", variant="primary")

        # UI-only handlers: kept off the public API surface (private) so only /translate is exposed.
        swap_btn.click(
            fn=_swap_languages,
            inputs=[source_language, target_language, input_text, output_text],
            outputs=[source_language, target_language, input_text, output_text],
            api_visibility="private",
        )

        translate_btn.click(
            fn=_translate_with_loading,
            inputs=[input_text, target_language],
            outputs=[translate_btn, output_text],
            api_visibility="private",
        )

        input_text.submit(
            fn=translate,
            inputs=[input_text, target_language],
            outputs=output_text,
            api_name="translate",
        )

    return demo


demo = _build_demo()
_maybe_eager_load()


def main() -> None:
    """Launch the Gradio app."""
    demo.launch(theme=gr.themes.Ocean())


if __name__ == "__main__":
    main()