# app.py (2.77 kB)
# Last commit 0dc9d0c by Daryl Lim: "Refactor app.py to lazy-load model and tokenizer"
"""
Translation interface using the MADLAD-400 3B model.
Translates English text to nearly 400 languages.
"""
import warnings
from functools import lru_cache
import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from langmap.langid_mapping import langid_to_language
# Checkpoint identifier passed to `from_pretrained` for both the tokenizer and the model.
MODEL_NAME = "google/madlad400-3b-mt"
def _get_device() -> torch.device:
    """Choose the torch device to run on.

    Backends are probed in order: MPS first, then CUDA. If neither is
    available a warning is emitted (attributed to the caller through
    ``stacklevel=2``) and the CPU device is returned.

    Returns:
        The selected ``torch.device``.
    """
    if torch.backends.mps.is_available():
        device_name = "mps"
    elif torch.cuda.is_available():
        device_name = "cuda"
    else:
        warnings.warn("No GPU available. Running on CPU — translation will be very slow.", stacklevel=2)
        device_name = "cpu"
    return torch.device(device_name)
@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Load the tokenizer for ``MODEL_NAME``, requesting the fast implementation.

    The result is cached, so ``from_pretrained`` runs only on the first call.

    Returns:
        The loaded tokenizer.

    Raises:
        RuntimeError: If ``from_pretrained`` hands back ``None``.
    """
    loaded = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Load the model for ``MODEL_NAME`` and move it to the selected device.

    The result is cached, so the checkpoint is loaded and transferred only on
    the first call.

    Half precision is used on the GPU backends (MPS/CUDA). When
    ``_get_device`` falls back to CPU the weights are kept in float32 instead:
    many PyTorch CPU kernels are unimplemented or very slow in float16, so
    forcing half precision there can fail outright or make generation even
    slower.

    Returns:
        The model, placed on the device chosen by ``_get_device``.
    """
    device = _get_device()
    dtype = torch.float32 if device.type == "cpu" else torch.float16
    return AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=dtype).to(device)
@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the language-name lookup used by the UI and by ``translate``.

    ``langid_to_language`` is inverted into a name -> code mapping, keeping
    only entries whose code is present in the tokenizer vocabulary. If two
    codes share the same name, the one appearing later in
    ``langid_to_language`` wins.

    The result is cached because it is requested on every translation and
    rebuilding it means materialising the whole tokenizer vocabulary each
    time. Callers share the same dict and list objects and should treat them
    as read-only.

    Returns:
        A ``(name_to_code, sorted_names)`` tuple.
    """
    vocab = _load_tokenizer().get_vocab()
    name_to_code = {
        name: code
        for code, name in langid_to_language.items()
        if code in vocab
    }
    return name_to_code, sorted(name_to_code)
@spaces.GPU
def translate(text: str, target_language_name: str) -> str:
    """Translate ``text`` into the language selected by its display name.

    Args:
        text: Source text to translate.
        target_language_name: Name of the target language; must be a key of
            the name -> code mapping from ``_build_language_mappings``.

    Returns:
        The decoded model output with special tokens removed, or an empty
        string when ``text`` is empty or whitespace-only.

    Raises:
        ValueError: If ``target_language_name`` has no known language code.
        TypeError: If the tokenizer's ``decode`` does not return a ``str``.
    """
    # Validate the cheap inputs first so that a bad request never triggers the
    # expensive model load.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # Nothing to translate: skip generation instead of letting the model
    # produce output from the bare language prefix.
    if not text or not text.strip():
        return ""

    tokenizer = _load_tokenizer()
    model = _load_model()

    # NOTE(review): the code is prepended with no separator; presumably it is a
    # `<2xx>`-style tag — confirm this tokenizes the same as "<2xx> text".
    encoded = tokenizer(target_code + text, return_tensors="pt")
    # Use the model's own device so inputs always live where the weights are.
    input_ids = encoded.input_ids.to(model.device)
    outputs = model.generate(input_ids=input_ids, max_new_tokens=512)

    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result
def main() -> None:
    """Assemble the Gradio translation interface and launch it.

    The language dropdown is filled with the sorted names from
    ``_build_language_mappings`` and preselects "Hawaiian"; submissions are
    handled by ``translate``.
    """
    language_names = _build_language_mappings()[1]

    # Components are created in the same order as they appear in the UI.
    text_input = gr.Textbox(label="Text", placeholder="Enter text here")
    language_input = gr.Dropdown(choices=language_names, value="Hawaiian", label="Target language")
    translation_output = gr.Textbox(label="Translation")

    description = (
        "Translation from English to (almost) 400 languages based on "
        "[research](https://arxiv.org/pdf/2309.04662) by Google DeepMind and Google Research."
    )

    interface = gr.Interface(
        fn=translate,
        inputs=[text_input, language_input],
        outputs=translation_output,
        title="MADLAD-400 Translation",
        description=description,
    )
    interface.launch()
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()