File size: 2,769 Bytes
1f39aa6
0dc9d0c
 
1f39aa6
3fdc7b6
0dc9d0c
 
 
1f39aa6
43b39f0
1f39aa6
0dc9d0c
c571a45
0dc9d0c
c571a45
0dc9d0c
43b39f0
1f39aa6
0dc9d0c
 
 
 
 
 
 
43b39f0
c571a45
0dc9d0c
 
 
 
 
 
3fdc7b6
1f39aa6
0dc9d0c
 
 
 
1f39aa6
 
0dc9d0c
 
 
 
 
1f39aa6
3fdc7b6
1f39aa6
0dc9d0c
 
 
 
 
 
 
 
1f39aa6
43b39f0
0dc9d0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Translation interface using the MADLAD-400 3B model.
Translates English text to nearly 400 languages.
"""

import warnings
from functools import lru_cache

import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from langmap.langid_mapping import langid_to_language

# Checkpoint identifier handed to ``from_pretrained`` by the tokenizer and
# model loaders in this module.
MODEL_NAME = "google/madlad400-3b-mt"


def _get_device() -> torch.device:
    """Pick the torch device to run on.

    Backends are probed in preference order: MPS first, then CUDA. When
    neither reports itself as available a warning is issued (attributed to
    the caller via ``stacklevel=2``) and the CPU device is returned.

    Returns:
        A ``torch.device`` for ``"mps"``, ``"cuda"`` or ``"cpu"``.
    """
    accelerator_probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend_name, is_available in accelerator_probes:
        if is_available():
            return torch.device(backend_name)

    # No accelerator found: make the slow path visible instead of silent.
    warnings.warn("No GPU available. Running on CPU — translation will be very slow.", stacklevel=2)
    return torch.device("cpu")


@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Return the fast tokenizer for ``MODEL_NAME``.

    The function is wrapped in ``lru_cache(maxsize=1)``, so a successful
    load is reused by every later call.

    Returns:
        The tokenizer produced by ``AutoTokenizer.from_pretrained``.

    Raises:
        RuntimeError: If ``from_pretrained`` hands back ``None``.
    """
    loaded = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")


@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Return the seq2seq model for ``MODEL_NAME``, placed on the chosen device.

    The weights are requested as ``torch.float16`` and the model is moved to
    whatever ``_get_device()`` selects. ``lru_cache(maxsize=1)`` keeps the
    loaded model around so later calls reuse it.

    Returns:
        The loaded model after ``.to(device)``.
    """
    # Resolve the device first so its CPU-fallback warning precedes the load.
    target_device = _get_device()
    # NOTE(review): float16 is used even when the device falls back to CPU —
    # confirm the installed torch supports half-precision generation there.
    half_precision_model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
    )
    return half_precision_model.to(target_device)


@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the language-name lookup used by the UI and by ``translate``.

    Entries of ``langid_to_language`` are kept only when their key is present
    in the tokenizer vocabulary, and the mapping is inverted so that it is
    keyed by language name.

    The result depends only on the tokenizer and the static language table,
    so it is memoised: the vocabulary dict is materialised once instead of on
    every translation request. Every call returns the same objects, so
    callers must treat them as read-only.

    Returns:
        A ``(name_to_code, names)`` pair, where ``name_to_code`` maps each
        language name to its vocabulary key and ``names`` is the sorted list
        of those language names.
    """
    vocab = _load_tokenizer().get_vocab()
    name_to_code = {
        name: code
        for code, name in langid_to_language.items()
        if code in vocab
    }
    return name_to_code, sorted(name_to_code)


@spaces.GPU
def translate(text: str, target_language_name: str) -> str:
    """Translate ``text`` into the language selected by name.

    Args:
        text: The source text to translate.
        target_language_name: A language name as listed by
            ``_build_language_mappings``.

    Returns:
        The decoded model output with special tokens stripped, or an empty
        string when ``text`` is empty or contains only whitespace.

    Raises:
        ValueError: If ``target_language_name`` is not a known language name.
        TypeError: If the tokenizer's ``decode`` does not return a ``str``.
    """
    # Validate the request before touching the model, so an unsupported
    # language is rejected without loading the weights.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # Nothing to translate: feeding only the language tag to the model would
    # make it generate arbitrary text, so short-circuit instead.
    if not text or not text.strip():
        return ""

    tokenizer = _load_tokenizer()
    model = _load_model()

    # NOTE(review): the language code is prepended with no separator; the
    # upstream model card example appears to put a space after the tag —
    # confirm which form the tokenizer expects.
    encoded = tokenizer(target_code + text, return_tensors="pt")
    # Put the inputs on the device the model actually lives on.
    input_ids = encoded.input_ids.to(model.device)

    outputs = model.generate(input_ids=input_ids, max_new_tokens=512)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result


def main() -> None:
    """Assemble the Gradio translation interface and launch it.

    The UI has a free-text input, a dropdown of the language names returned
    by ``_build_language_mappings`` (with "Hawaiian" preselected) and a text
    output; submissions are handled by ``translate``.
    """
    language_names = _build_language_mappings()[1]

    # Components are created in the same order as they appear in the UI.
    source_box = gr.Textbox(label="Text", placeholder="Enter text here")
    language_picker = gr.Dropdown(choices=language_names, value="Hawaiian", label="Target language")
    translation_box = gr.Textbox(label="Translation")

    blurb = (
        "Translation from English to (almost) 400 languages based on "
        "[research](https://arxiv.org/pdf/2309.04662) by Google DeepMind and Google Research."
    )

    interface = gr.Interface(
        fn=translate,
        inputs=[source_box, language_picker],
        outputs=translation_box,
        title="MADLAD-400 Translation",
        description=blurb,
    )
    interface.launch()


# Start the UI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()