File size: 4,768 Bytes
1f39aa6
0dc9d0c
67a163c
1f39aa6
3fdc7b6
0dc9d0c
6edcaed
0dc9d0c
 
1f39aa6
43b39f0
1f39aa6
0dc9d0c
c571a45
0dc9d0c
c571a45
0dc9d0c
43b39f0
1f39aa6
0dc9d0c
 
 
de4ab0d
0dc9d0c
43b39f0
c571a45
0dc9d0c
 
 
 
 
 
3fdc7b6
1f39aa6
0dc9d0c
 
 
de4ab0d
151fa6d
1f39aa6
 
e0dc6e1
0dc9d0c
 
 
2ffda3e
 
 
 
 
 
0dc9d0c
1f39aa6
3fdc7b6
1f39aa6
8d0070f
 
 
 
 
 
 
0dc9d0c
 
e0dc6e1
0dc9d0c
 
 
 
1f39aa6
43b39f0
8d0070f
 
 
7d583b2
8d0070f
 
 
 
 
 
 
0dc9d0c
 
 
 
 
 
6edcaed
 
 
 
 
 
 
 
 
613b55a
0dc9d0c
 
1414eac
2ffda3e
 
57a1d31
 
 
 
 
 
2ffda3e
12ff0ac
badac0a
2ffda3e
ea7372a
2ffda3e
 
 
 
 
 
 
 
 
 
12ff0ac
2ffda3e
 
 
 
c077bcb
8ff0bfd
 
c077bcb
6edcaed
c077bcb
6edcaed
c077bcb
2ffda3e
 
 
 
 
3872c0b
613b55a
 
 
 
 
 
 
0758e3a
0dc9d0c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Translation interface using the MADLAD-400 3B model.
Translates English text to 22 production-ready languages from the MADLAD-400 paper.
"""

import warnings
from collections.abc import Generator
from functools import lru_cache

import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from langmap.langid_mapping import langid_to_language

# Hugging Face Hub id of the MADLAD-400 3B machine-translation checkpoint; used
# by both _load_tokenizer() and _load_model().
MODEL_NAME = "google/madlad400-3b-mt"


def _get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
    return torch.device("cpu")


@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Load the fast tokenizer for ``MODEL_NAME`` once and return it.

    The result is memoized, so later calls reuse the same tokenizer object.

    Raises:
        RuntimeError: if ``from_pretrained`` yields ``None``.
    """
    loaded = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")


@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Load the seq2seq model once and move it to the selected device.

    Weights are loaded in float16 on CUDA and in float32 on CPU. The result is
    memoized, so later calls reuse the same model instance.
    """
    target = _get_device()
    precision = torch.float32
    if target.type == "cuda":
        precision = torch.float16
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=precision)
    return model.to(target)


@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Map UI display names to the model's target-language tokens.

    Only ids from ``langid_to_language`` that appear in the tokenizer
    vocabulary are kept. Each display name has the form
    ``"<language name> (<locale>)"``. The locale is the id with its first two
    characters and its last character removed (``<2fr>`` -> ``fr``).

    Returns:
        A ``(name_to_code, sorted_display_names)`` tuple.
    """
    known_tokens = _load_tokenizer().get_vocab()
    mapping = {
        f"{language} ({token[2:-1]})": token
        for token, language in langid_to_language.items()
        if token in known_tokens
    }
    return mapping, sorted(mapping)


@spaces.GPU
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int = 512,
    num_beams: int = 1,
    temperature: float = 1.0,
) -> str:
    """Translate ``text`` into the language named by ``target_language_name``.

    Args:
        text: Source text. Empty or whitespace-only input returns ``""``
            without loading or running the model.
        target_language_name: A display name produced by
            ``_build_language_mappings`` (for example ``"French (fr)"``).
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam width. ``1`` disables beam search.
        temperature: Sampling temperature, used only when ``num_beams == 1``.
            A value ``<= 0`` selects greedy decoding instead of sampling.

    Returns:
        The decoded translation.

    Raises:
        ValueError: if ``target_language_name`` is not a supported language.
        TypeError: if the tokenizer's ``decode`` does not return a ``str``.
    """
    # Validate the language first. This only needs the tokenizer, so the
    # (large) model is never loaded for a request that would fail anyway.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # Nothing to translate: don't run generation on a bare language-tag prompt.
    if not text or not text.strip():
        return ""

    if num_beams > 1 and temperature != 1.0:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    tokenizer = _load_tokenizer()
    model = _load_model()

    # The target-language token (e.g. "<2fr>") is prefixed to the source text.
    input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(model.device)

    generate_kwargs: dict = {"input_ids": input_ids, "max_new_tokens": max_new_tokens, "num_beams": num_beams}
    if num_beams == 1 and temperature > 0:
        # transformers rejects non-positive sampling temperatures. In that case
        # we leave sampling off, which gives plain greedy decoding.
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature

    outputs = model.generate(**generate_kwargs)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result


def _translate_with_loading(
    text: str,
    target_language_name: str,
) -> Generator[tuple, None, None]:
    """Run ``translate`` while the Translate button shows a busy state.

    Yields ``(button_update, output_update)`` pairs:

    1. A disabled "Translating..." button, with the output left unchanged.
    2. The re-enabled button together with the translation.

    If ``translate`` raises, the button is re-enabled before the exception
    propagates, so a failed request does not leave the button stuck disabled.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        result = translate(text, target_language_name)
    except Exception:
        # Restore the button first, then re-raise so Gradio still reports the error.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    yield gr.update(value="Translate", interactive=True), result


def _build_demo() -> gr.Blocks:
    """Assemble the translation UI.

    Layout:
    - Left column: a fixed English source dropdown and an input box.
    - Right column: a filterable target-language dropdown and a read-only
      output box.
    - Below both: a Translate button.

    Clicking the button and submitting the input box are wired to the same
    handler and outputs, so both paths show the busy-button state.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    _, language_names = _build_language_mappings()
    # Prefer French as the initial target, but only if it survived the vocab
    # filter. Otherwise fall back to the first available language (or None),
    # so the dropdown's value is always one of its choices.
    preferred = "French (fr)"
    if preferred in language_names:
        default_language = preferred
    else:
        default_language = language_names[0] if language_names else None

    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row(equal_height=True):
            with gr.Column():
                # Source language is fixed to English; this dropdown is display-only.
                gr.Dropdown(
                    choices=["English"],
                    value="English",
                    show_label=False,
                    interactive=False,
                )
                input_text = gr.Textbox(
                    lines=6,
                    max_length=2000,
                    show_label=False,
                    buttons=["clear"],
                )
            with gr.Column():
                target_language = gr.Dropdown(
                    choices=language_names,
                    value=default_language,
                    show_label=False,
                    filterable=True,
                )
                output_text = gr.Textbox(
                    placeholder="Translation",
                    lines=6,
                    show_label=False,
                    interactive=False,
                    buttons=["copy"],
                )

        translate_btn = gr.Button("Translate", variant="primary")

        # Both triggers share one handler and one set of outputs, so the busy
        # button state and result handling are the same for click and Enter.
        event_inputs = [input_text, target_language]
        event_outputs = [translate_btn, output_text]
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=event_inputs,
            outputs=event_outputs,
        )
        input_text.submit(
            fn=_translate_with_loading,
            inputs=event_inputs,
            outputs=event_outputs,
        )

    return demo


# Build the UI once at import time. This goes through _build_language_mappings(),
# which loads the tokenizer, so importing this module loads (and may download) it.
# NOTE(review): a module-level ``demo`` is presumably what the hosting runner
# (e.g. a Gradio / HF Spaces launcher) looks for — confirm before moving it.
demo = _build_demo()


def main() -> None:
    """Launch the module-level Gradio app using the Ocean theme."""
    ocean_theme = gr.themes.Ocean()
    demo.launch(theme=ocean_theme)


# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()