"""
Translation interface using the MADLAD-400 3B model.
Translates English text to nearly 400 languages.
"""
import warnings
from functools import lru_cache
import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from langmap.langid_mapping import langid_to_language
# Checkpoint identifier passed to ``from_pretrained`` for both the tokenizer
# and the model, so the two always come from the same checkpoint.
MODEL_NAME = "google/madlad400-3b-mt"
def _get_device() -> torch.device:
    """Pick the torch device to run on.

    Backends are probed in a fixed order of preference: Apple MPS first, then
    CUDA. If neither reports itself available, a warning is emitted and the
    CPU device is returned.

    Returns:
        The ``torch.device`` for the first available backend, or the CPU
        device as a fallback.
    """
    # Ordered by preference; the first backend that reports availability wins.
    accelerator_probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend_name, is_available in accelerator_probes:
        if is_available():
            return torch.device(backend_name)

    # stacklevel=2 attributes the warning to whoever asked for the device.
    warnings.warn("No GPU available. Running on CPU — translation will be very slow.", stacklevel=2)
    return torch.device("cpu")
@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Load the tokenizer for ``MODEL_NAME``, requesting the fast implementation.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    tokenizer is loaded on the first call and the same object is returned on
    every later call.

    Returns:
        The tokenizer loaded by ``AutoTokenizer.from_pretrained``.

    Raises:
        RuntimeError: If ``from_pretrained`` returns ``None``.
    """
    loaded = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Load the ``MODEL_NAME`` seq2seq model in float16 on the selected device.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    weights are loaded and moved to the device only on the first call.

    Returns:
        The model, after ``.to()`` has moved it to the device chosen by
        ``_get_device``.
    """
    # Resolve the device before loading so any "no GPU" warning is raised
    # ahead of the model load.
    target_device = _get_device()
    half_precision_model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
    )
    return half_precision_model.to(target_device)
@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the lookup tables for the languages the tokenizer supports.

    ``langid_to_language`` is inverted into a name -> code mapping, keeping
    only the entries whose key appears in the tokenizer vocabulary.

    The result is memoized. ``translate`` calls this function on every
    request, and the tokenizer is itself a cached singleton, so rebuilding
    the vocabulary dictionary and the mapping each time would be repeated
    work with an identical outcome. Callers share the cached objects and must
    treat them as read-only.

    Returns:
        A ``(name_to_code, language_names)`` tuple: ``name_to_code`` maps each
        language name to its code, and ``language_names`` is the sorted list
        of those names.
    """
    vocab = _load_tokenizer().get_vocab()
    name_to_code = {
        name: code
        for code, name in langid_to_language.items()
        if code in vocab
    }
    return name_to_code, sorted(name_to_code)
@spaces.GPU
def translate(text: str, target_language_name: str) -> str:
    """Translate ``text`` into the language named ``target_language_name``.

    Args:
        text: The text to translate.
        target_language_name: A language name as listed by
            ``_build_language_mappings``.

    Returns:
        The decoded model output with special tokens removed, or an empty
        string when ``text`` is empty, whitespace-only or ``None``.

    Raises:
        ValueError: If ``target_language_name`` is not a supported language.
        TypeError: If the tokenizer's ``decode`` does not return a ``str``.
    """
    # Validate the request before touching the model: loading the model is
    # the expensive step and a bad language name does not need it.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # Nothing to translate. Without this guard the model would be asked to
    # generate from the bare language code and return arbitrary text.
    if not text or not text.strip():
        return ""

    tokenizer = _load_tokenizer()
    model = _load_model()

    # The language code is prepended to the text to select the target language.
    # NOTE(review): the MADLAD-400 model card example puts a space between the
    # tag and the text ("<2pt> I love pizza!"); here they are concatenated
    # directly. Confirm both forms tokenize equivalently.
    # The inputs are sent to the device the model actually lives on, rather
    # than re-probing the hardware, so the two can never disagree.
    input_ids = tokenizer(target_code + text, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(input_ids=input_ids, max_new_tokens=512)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result
def main() -> None:
    """Assemble the Gradio translation interface and launch it.

    The interface has a free-text input, a dropdown of the supported target
    languages (pre-selected to "Hawaiian"), and a text output, all wired to
    ``translate``.
    """
    language_names = _build_language_mappings()[1]

    source_box = gr.Textbox(label="Text", placeholder="Enter text here")
    language_picker = gr.Dropdown(choices=language_names, value="Hawaiian", label="Target language")
    translation_box = gr.Textbox(label="Translation")

    blurb = (
        "Translation from English to (almost) 400 languages based on "
        "[research](https://arxiv.org/pdf/2309.04662) by Google DeepMind and Google Research."
    )

    interface = gr.Interface(
        fn=translate,
        inputs=[source_box, language_picker],
        outputs=translation_box,
        title="MADLAD-400 Translation",
        description=blurb,
    )
    interface.launch()
# Start the UI only when this file is executed as a script, so importing the
# module has no side effects.
if __name__ == "__main__":
    main()
|