Daryl Lim
feat: update browser tab title to MADLAD-400 Translate
1414eac
Raw
History Blame
4.77 kB
"""
Translation interface using the MADLAD-400 3B model.
Translates English text to 22 production-ready languages from the MADLAD-400 paper.
"""
import warnings
from collections.abc import Generator
from functools import lru_cache
import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from langmap.langid_mapping import langid_to_language
# Hugging Face model identifier passed to `from_pretrained` for both the
# tokenizer and the seq2seq model loaded in this file.
MODEL_NAME = "google/madlad400-3b-mt"
def _get_device() -> torch.device:
    """Choose the torch device used for inference.

    Returns:
        The ``"cuda"`` device when ``torch.cuda.is_available()`` is true,
        otherwise the ``"cpu"`` device (a warning is emitted in that case).
    """
    if not torch.cuda.is_available():
        # stacklevel=2 attributes the warning to the caller of this helper.
        warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
        return torch.device("cpu")
    return torch.device("cuda")
@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Load the fast tokenizer for ``MODEL_NAME``.

    The result is memoised by ``lru_cache``, so ``from_pretrained`` is only
    invoked on the first call.

    Raises:
        RuntimeError: If ``from_pretrained`` returns ``None``.
    """
    loaded = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Load the seq2seq model for ``MODEL_NAME`` and move it to the chosen device.

    Weights are requested in float16 when the device from ``_get_device()``
    is CUDA and in float32 otherwise. The result is memoised by ``lru_cache``.
    """
    target_device = _get_device()
    if target_device.type == "cuda":
        precision = torch.float16
    else:
        precision = torch.float32
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=precision)
    return loaded_model.to(target_device)
@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the display-name lookup for the languages the tokenizer knows.

    Only entries of ``langid_to_language`` whose code appears in the
    tokenizer vocabulary are kept.

    Returns:
        A pair ``(name_to_code, names)`` where ``name_to_code`` maps a label
        such as ``"<name> (<locale>)"`` to its language code and ``names`` is
        the sorted list of those labels.
    """
    vocabulary = _load_tokenizer().get_vocab()
    # Codes look like "<2fr>"; dropping the first two characters and the
    # last one leaves the bare locale ("fr") used in the label.
    label_to_code = {
        f"{language} ({tag[2:-1]})": tag
        for tag, language in langid_to_language.items()
        if tag in vocabulary
    }
    return label_to_code, sorted(label_to_code)
@spaces.GPU
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int = 512,
    num_beams: int = 1,
    temperature: float = 1.0,
) -> str:
    """Translate ``text`` into the language labelled ``target_language_name``.

    Args:
        text: Source text. Empty, whitespace-only or ``None`` input yields
            an empty string without running the model.
        target_language_name: A display label produced by
            ``_build_language_mappings()``.
        max_new_tokens: Forwarded to ``model.generate``.
        num_beams: Forwarded to ``model.generate``. When it is 1, sampling
            is enabled (``do_sample=True``) at the given ``temperature``.
        temperature: Sampling temperature; only passed to ``generate`` when
            ``num_beams == 1``.

    Returns:
        The decoded first generated sequence with special tokens removed,
        or ``""`` when there is nothing to translate.

    Raises:
        ValueError: If ``target_language_name`` is not a known label.
        TypeError: If the tokenizer's ``decode`` does not return a ``str``.
    """
    # Validate the request before touching the model so bad input fails fast.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # Nothing to translate: feeding only the language tag to the model would
    # make it generate text from no source content.
    if not text or not text.strip():
        return ""

    if num_beams > 1 and temperature != 1.0:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    tokenizer = _load_tokenizer()
    model = _load_model()
    device = model.device

    # The target language is selected by prefixing its tag to the input.
    input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(device)
    generate_kwargs: dict = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
    }
    if num_beams == 1:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature

    outputs = model.generate(**generate_kwargs)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result
def _translate_with_loading(
    text: str,
    target_language_name: str,
) -> Generator[tuple, None, None]:
    """Run ``translate`` while toggling the button's loading state.

    Yields ``(button_update, output_update)`` pairs:

    1. The button is relabelled "Translating..." and disabled; the output is
       left untouched.
    2. The button is restored to "Translate"/enabled and the output is set to
       the translation.

    If ``translate`` raises, the button is restored (output untouched) before
    the exception is re-raised, so the UI is not left stuck in the disabled
    "Translating..." state.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        result = translate(text, target_language_name)
    except Exception:
        # Undo the loading state first; the error itself still propagates.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    yield gr.update(value="Translate", interactive=True), result
def _build_demo() -> gr.Blocks:
    """Assemble the translation UI and wire up its events.

    Layout: a row with two columns — a fixed "English" source selector above
    the input textbox, and the target-language selector above the read-only
    output textbox — followed by the "Translate" button.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    _, language_names = _build_language_mappings()

    # Prefer French as the initial target, but only if it is actually one of
    # the available choices; otherwise fall back to the first choice so the
    # dropdown never starts on a value that `translate` would reject.
    preferred_default = "French (fr)"
    if preferred_default in language_names:
        default_language = preferred_default
    else:
        default_language = language_names[0] if language_names else None

    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row(equal_height=True):
            with gr.Column():
                # Source language is fixed, so this dropdown is display-only.
                gr.Dropdown(
                    choices=["English"],
                    value="English",
                    show_label=False,
                    interactive=False,
                )
                input_text = gr.Textbox(
                    lines=6,
                    max_length=2000,
                    show_label=False,
                    buttons=["clear"],
                )
            with gr.Column():
                target_language = gr.Dropdown(
                    choices=language_names,
                    value=default_language,
                    show_label=False,
                    filterable=True,
                )
                output_text = gr.Textbox(
                    placeholder="Translation",
                    lines=6,
                    show_label=False,
                    interactive=False,
                    buttons=["copy"],
                )
        translate_btn = gr.Button("Translate", variant="primary")

        # Button click goes through the wrapper that shows a loading state.
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=[input_text, target_language],
            outputs=[translate_btn, output_text],
        )
        # NOTE(review): the textbox submit event calls `translate` directly,
        # so it does not toggle the button's loading state — confirm this
        # difference from the click path is intentional.
        input_text.submit(
            fn=translate,
            inputs=[input_text, target_language],
            outputs=output_text,
        )
    return demo
# Module-level app instance, built at import time; `main()` launches it.
# NOTE(review): presumably kept at module scope so the hosting runtime can
# discover it by name — confirm before moving it.
demo = _build_demo()
def main() -> None:
    """Launch the module-level ``demo`` app with Gradio's Ocean theme."""
    ocean_theme = gr.themes.Ocean()
    demo.launch(theme=ocean_theme)
# Script entry point: start the app only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()