# Spaces: Running on Zero
# File size: 4,768 Bytes
"""
Translation interface using the MADLAD-400 3B model.
Translates English text to 22 production-ready languages from the MADLAD-400 paper.
"""
import warnings
from collections.abc import Generator
from functools import lru_cache
import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from langmap.langid_mapping import langid_to_language
# Checkpoint identifier passed to ``from_pretrained`` by both the tokenizer and model loaders.
MODEL_NAME = "google/madlad400-3b-mt"
def _get_device() -> torch.device:
    """Pick the torch device to run on, preferring CUDA.

    Returns:
        ``torch.device("cuda")`` when CUDA is available; otherwise a warning is
        emitted and ``torch.device("cpu")`` is returned.
    """
    if not torch.cuda.is_available():
        # stacklevel=2 attributes the warning to the caller rather than this helper.
        warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
        return torch.device("cpu")
    return torch.device("cuda")
@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Load the fast tokenizer for ``MODEL_NAME``.

    The result is memoized, so the tokenizer is only loaded on the first call.

    Raises:
        RuntimeError: If ``from_pretrained`` hands back ``None``.
    """
    loaded = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Load the seq2seq model for ``MODEL_NAME`` and move it to the active device.

    The device comes from ``_get_device``. Weights are requested as float16 on
    CUDA and float32 otherwise. The result is memoized after the first call.
    """
    target = _get_device()
    if target.type == "cuda":
        precision = torch.float16
    else:
        precision = torch.float32
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=precision)
    return model.to(target)
@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the display-name lookup for every language tag the tokenizer knows.

    Only entries of ``langid_to_language`` whose tag appears in the tokenizer
    vocabulary are kept.

    Returns:
        A pair ``(name_to_code, names)``: ``name_to_code`` maps a display name of
        the form ``"<name> (<locale>)"`` to its language tag, and ``names`` is
        the sorted list of those display names.
    """
    known_tokens = _load_tokenizer().get_vocab()
    name_to_code = {
        # code[2:-1] strips the tag wrapper: <2fr> → fr
        f"{name} ({code[2:-1]})": code
        for code, name in langid_to_language.items()
        if code in known_tokens
    }
    return name_to_code, sorted(name_to_code)
@spaces.GPU
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int = 512,
    num_beams: int = 1,
    temperature: float = 1.0,
) -> str:
    """Translate ``text`` into the language selected by ``target_language_name``.

    Args:
        text: Source text. Empty or whitespace-only input returns ``""``
            without loading or running the model.
        target_language_name: Display name in the ``"<name> (<locale>)"`` form
            produced by ``_build_language_mappings``.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam width; ``1`` disables beam search.
        temperature: Sampling temperature, used only when ``num_beams == 1``.
            Non-positive values select greedy decoding instead of sampling.

    Returns:
        The decoded translation with special tokens removed.

    Raises:
        ValueError: If ``target_language_name`` is not a known display name.
        TypeError: If the tokenizer's ``decode`` does not return a ``str``.
    """
    # Validate the request before paying for a model load.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # With nothing to translate the model would only see the language tag and
    # invent output, so short-circuit instead.
    if not text or not text.strip():
        return ""

    if num_beams > 1 and temperature != 1.0:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    tokenizer = _load_tokenizer()
    model = _load_model()

    # The language tag is prepended to the source text; forwarding the whole
    # encoding also hands generate() the attention mask.
    encoded = tokenizer(target_code + " " + text, return_tensors="pt").to(model.device)
    generate_kwargs: dict = {
        **encoded,
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
    }
    if num_beams == 1:
        if temperature > 0:
            generate_kwargs["do_sample"] = True
            # The sampling path expects a strictly positive float temperature.
            generate_kwargs["temperature"] = float(temperature)
        else:
            # A non-positive temperature cannot be sampled from; decode greedily.
            generate_kwargs["do_sample"] = False

    outputs = model.generate(**generate_kwargs)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result
def _translate_with_loading(
    text: str,
    target_language_name: str,
) -> Generator[tuple, None, None]:
    """Run ``translate`` while reflecting progress on the Translate button.

    Each yielded pair is ``(button_update, output_update)``:

    1. The button switches to a disabled "Translating..." state and the output
       is left untouched.
    2. After ``translate`` returns, the button is restored and the translation
       is emitted.

    If ``translate`` raises, the button is restored first and the original
    exception is then re-raised, so a failure does not leave the button stuck
    in its disabled state.

    Args:
        text: Source text forwarded to ``translate``.
        target_language_name: Target language display name forwarded to ``translate``.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        result = translate(text, target_language_name)
    except Exception:
        # Hand the button back before the error propagates.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    yield gr.update(value="Translate", interactive=True), result
def _build_demo() -> gr.Blocks:
    """Assemble the translation UI.

    Layout: a row with two columns. The left column has a fixed "English"
    source selector and the input box; the right column has the
    target-language selector and the read-only output box. A Translate button
    follows the row.

    The target language defaults to "French (fr)" when that entry exists among
    the available languages. Otherwise it falls back to the first available
    language, or to no selection when the list is empty.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    _, language_names = _build_language_mappings()

    # The available names are computed from the tokenizer vocabulary, so the
    # preferred default is not guaranteed to be among them.
    preferred_language = "French (fr)"
    if preferred_language in language_names:
        default_language = preferred_language
    else:
        default_language = language_names[0] if language_names else None

    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row(equal_height=True):
            with gr.Column():
                # Source language is fixed; shown only for visual symmetry.
                gr.Dropdown(
                    choices=["English"],
                    value="English",
                    show_label=False,
                    interactive=False,
                )
                input_text = gr.Textbox(
                    lines=6,
                    max_length=2000,
                    show_label=False,
                    buttons=["clear"],
                )
            with gr.Column():
                target_language = gr.Dropdown(
                    choices=language_names,
                    value=default_language,
                    show_label=False,
                    filterable=True,
                )
                output_text = gr.Textbox(
                    placeholder="Translation",
                    lines=6,
                    show_label=False,
                    interactive=False,
                    buttons=["copy"],
                )
        translate_btn = gr.Button("Translate", variant="primary")
        # Button click goes through the wrapper that toggles the button state.
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=[input_text, target_language],
            outputs=[translate_btn, output_text],
        )
        # Submitting from the textbox calls translate directly.
        input_text.submit(
            fn=translate,
            inputs=[input_text, target_language],
            outputs=output_text,
        )
    return demo
# Built at import time so the app exists as a module-level name; main() launches it.
demo = _build_demo()
def main() -> None:
    """Launch the module-level ``demo`` app with the Ocean theme."""
    ocean_theme = gr.themes.Ocean()
    demo.launch(theme=ocean_theme)
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|