aaannnlll committed on
Commit
fbf7153
·
verified ·
1 Parent(s): ae4ef90

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +424 -0
  2. requirements.txt +13 -0
app.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import spaces
7
+ import torch
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ from irodori_tts.inference_runtime import (
11
+ InferenceRuntime,
12
+ RuntimeKey,
13
+ SamplingRequest,
14
+ )
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Configuration
18
+ # ---------------------------------------------------------------------------
19
+
20
# Hub repository holding the TTS checkpoint; overridable via the environment.
MODEL_REPO = os.getenv("MODEL_REPO", "Aratako/Irodori-TTS-500M-v3")
# Hub repository passed to the runtime as the codec source.
CODEC_REPO = "Aratako/Semantic-DACVAE-Japanese-32dim"
# Upper bound on candidates per request; also the number of output audio slots.
MAX_GRADIO_CANDIDATES = int(os.getenv("MAX_GRADIO_CANDIDATES", "32"))
# Number of output audio widgets placed in each UI row.
GRADIO_AUDIO_COLS_PER_ROW = 8

# Process-wide runtime, populated lazily by load_models().
_runtime: InferenceRuntime | None = None
27
+
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Helpers
31
+ # ---------------------------------------------------------------------------
32
+
33
+
34
+ def _parse_optional_float(raw: str | None, label: str) -> float | None:
35
+ if raw is None:
36
+ return None
37
+ text = str(raw).strip()
38
+ if text == "" or text.lower() == "none":
39
+ return None
40
+ try:
41
+ return float(text)
42
+ except ValueError as exc:
43
+ raise ValueError(f"{label} must be a float or blank.") from exc
44
+
45
+
46
+ def _parse_optional_int(raw: str | None, label: str) -> int | None:
47
+ if raw is None:
48
+ return None
49
+ text = str(raw).strip()
50
+ if text == "" or text.lower() == "none":
51
+ return None
52
+ try:
53
+ return int(text)
54
+ except ValueError as exc:
55
+ raise ValueError(f"{label} must be an int or blank.") from exc
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Model Loading
60
+ # ---------------------------------------------------------------------------
61
+
62
+
63
def load_models():
    """Populate the module-level ``_runtime`` on first use.

    Downloads ``model.safetensors`` from ``MODEL_REPO``, selects the device
    and precision from CUDA availability, and builds an ``InferenceRuntime``
    from the resulting ``RuntimeKey``. Later calls do nothing once
    ``_runtime`` is set.
    """
    global _runtime

    if _runtime is None:
        print(f"[Info] Downloading checkpoint from {MODEL_REPO}...")
        ckpt_file = hf_hub_download(repo_id=MODEL_REPO, filename="model.safetensors")

        # bf16 on GPU, fp32 on CPU; model and codec share both settings.
        if torch.cuda.is_available():
            target_device, target_precision = "cuda", "bf16"
        else:
            target_device, target_precision = "cpu", "fp32"

        runtime_key = RuntimeKey(
            checkpoint=ckpt_file,
            model_device=target_device,
            codec_repo=CODEC_REPO,
            model_precision=target_precision,
            codec_device=target_device,
            codec_precision=target_precision,
        )

        print("[Info] Building runtime...")
        _runtime = InferenceRuntime.from_key(runtime_key)
        print("[Info] All models loaded successfully.")
87
+
88
+
89
+ # Load models at startup
90
+ load_models()
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # GPU-decorated Inference
95
+ # ---------------------------------------------------------------------------
96
+
97
+
98
@spaces.GPU(duration=120)
def run_inference_gpu(
    text: str,
    uploaded_audio: str | None,
    num_steps: int,
    num_candidates: int,
    seed_raw: str,
    seconds_raw: str,
    duration_scale: float,
    cfg_guidance_mode: str,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_scale_raw: str,
    cfg_min_t: float,
    cfg_max_t: float,
    context_kv_cache: bool,
    truncation_factor_raw: str,
    rescale_k_raw: str,
    rescale_sigma_raw: str,
    speaker_kv_scale_raw: str,
    speaker_kv_min_t_raw: str,
    speaker_kv_max_layers_raw: str,
) -> tuple[list[tuple[int, np.ndarray]], str]:
    """Run one synthesis request and return the audio candidates plus a log.

    Args:
        text: Text to synthesize; must contain non-whitespace characters.
        uploaded_audio: Path to a reference audio file. ``None`` or a blank
            string selects no-reference mode.
        num_steps: Number of sampling steps (coerced with ``int``).
        num_candidates: Candidates to generate, between 1 and
            ``MAX_GRADIO_CANDIDATES`` inclusive.
        seed_raw: Optional integer as text; blank means no fixed seed.
        seconds_raw: Optional float as text; blank means automatic duration.
        duration_scale: Duration multiplier forwarded to the runtime.
        cfg_guidance_mode: Guidance mode name forwarded to the runtime.
        cfg_scale_text: Forwarded as ``cfg_scale_text``.
        cfg_scale_speaker: Forwarded as ``cfg_scale_speaker``.
        cfg_scale_raw: Optional float as text, forwarded as ``cfg_scale``.
        cfg_min_t: Forwarded as ``cfg_min_t``.
        cfg_max_t: Forwarded as ``cfg_max_t``.
        context_kv_cache: Forwarded as ``context_kv_cache``.
        truncation_factor_raw: Optional float as text.
        rescale_k_raw: Optional float as text.
        rescale_sigma_raw: Optional float as text.
        speaker_kv_scale_raw: Optional float as text.
        speaker_kv_min_t_raw: Optional float as text.
        speaker_kv_max_layers_raw: Optional integer as text.

    Returns:
        A pair ``(audio_results, log_text)``. ``audio_results`` holds one
        ``(sample_rate, waveform)`` tuple per generated candidate and
        ``log_text`` is everything written through the request logger.

    Raises:
        gr.Error: If ``text`` is blank or ``num_candidates`` is out of range.
        ValueError: If one of the optional text fields cannot be parsed.
    """
    load_models()

    log_buffer = io.StringIO()

    def stdout_log(msg: str) -> None:
        # Mirror every message to stdout and to the log returned to the UI.
        print(msg, flush=True)
        log_buffer.write(msg + "\n")

    if not str(text).strip():
        raise gr.Error("Please enter text to synthesize.")

    # Parsing order is kept stable so the first invalid field is the one reported.
    cfg_scale = _parse_optional_float(cfg_scale_raw, "cfg_scale")
    truncation_factor = _parse_optional_float(truncation_factor_raw, "truncation_factor")
    rescale_k = _parse_optional_float(rescale_k_raw, "rescale_k")
    rescale_sigma = _parse_optional_float(rescale_sigma_raw, "rescale_sigma")
    speaker_kv_scale = _parse_optional_float(speaker_kv_scale_raw, "speaker_kv_scale")
    speaker_kv_min_t = _parse_optional_float(speaker_kv_min_t_raw, "speaker_kv_min_t")
    speaker_kv_max_layers = _parse_optional_int(speaker_kv_max_layers_raw, "speaker_kv_max_layers")
    seed = _parse_optional_int(seed_raw, "seed")
    manual_seconds = _parse_optional_float(seconds_raw, "seconds")

    requested_candidates = int(num_candidates)
    if requested_candidates <= 0:
        raise gr.Error("num_candidates must be >= 1.")
    if requested_candidates > MAX_GRADIO_CANDIDATES:
        raise gr.Error(f"num_candidates must be <= {MAX_GRADIO_CANDIDATES}.")

    # A missing or blank upload path means no-reference mode.
    has_reference = uploaded_audio is not None and str(uploaded_audio).strip() != ""
    ref_wav: str | None = str(uploaded_audio) if has_reference else None
    no_ref = not has_reference

    seconds_label = "auto" if manual_seconds is None else manual_seconds
    seed_label = "random" if seed is None else seed
    stdout_log(
        f"[Info] request: mode={cfg_guidance_mode} seconds={seconds_label} "
        f"duration_scale={float(duration_scale)} "
        f"steps={int(num_steps)} seed={seed_label} no_ref={no_ref} "
        f"candidates={requested_candidates}"
    )

    request = SamplingRequest(
        text=str(text),
        ref_wav=ref_wav,
        ref_latent=None,
        no_ref=no_ref,
        ref_normalize_db=-16.0,
        ref_ensure_max=True,
        num_candidates=requested_candidates,
        decode_mode="sequential",
        seconds=manual_seconds,
        duration_scale=float(duration_scale),
        max_ref_seconds=30.0,
        max_text_len=None,
        num_steps=int(num_steps),
        seed=seed,
        cfg_guidance_mode=str(cfg_guidance_mode),
        cfg_scale_text=float(cfg_scale_text),
        cfg_scale_speaker=float(cfg_scale_speaker),
        cfg_scale=cfg_scale,
        cfg_min_t=float(cfg_min_t),
        cfg_max_t=float(cfg_max_t),
        truncation_factor=truncation_factor,
        rescale_k=rescale_k,
        rescale_sigma=rescale_sigma,
        context_kv_cache=bool(context_kv_cache),
        speaker_kv_scale=speaker_kv_scale,
        speaker_kv_min_t=speaker_kv_min_t,
        speaker_kv_max_layers=speaker_kv_max_layers,
        trim_tail=True,
    )
    result = _runtime.synthesize(request, log_fn=stdout_log)

    sample_rate = result.sample_rate
    # Tensor.numpy() only accepts CPU tensors that are detached from autograd.
    # detach()/cpu() are no-ops for tensors already in that state, so this is
    # safe whichever device the runtime returns its audio on.
    audio_results: list[tuple[int, np.ndarray]] = [
        (sample_rate, audio.squeeze(0).detach().cpu().float().numpy())
        for audio in result.audios
    ]
    stdout_log(f"[Info] seed_used: {result.used_seed}")
    stdout_log(f"[Info] candidates: {len(result.audios)}")
    return audio_results, log_buffer.getvalue()
210
+
211
+
212
+ # ---------------------------------------------------------------------------
213
+ # Gradio UI
214
+ # ---------------------------------------------------------------------------
215
+
216
+
217
def build_demo():
    """Assemble the Gradio Blocks UI and wire the Generate button.

    The layout holds the text and reference-audio inputs, a "Sampling" and an
    "Advanced (Optional)" accordion, a grid of ``MAX_GRADIO_CANDIDATES``
    output audio slots (only the first is visible initially), and a run log.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    model_link = f"https://huggingface.co/{MODEL_REPO}"
    github_link = "https://github.com/Aratako/Irodori-TTS"

    heading_md = "# Irodori-TTS-500M-v3 Demo"
    # NOTE(review): the literal's lines are assumed to start at column 0.
    intro_md = f"""\
[Model]({model_link}) | [GitHub]({github_link})

Flow-matching based Japanese TTS model (500M parameters). \
Generates speech from text using rectified flow over DACVAE latents.

- **Reference audio**: Optional. Upload to condition the speaker voice. \
Leave blank for unconditional generation.
- **Duration**: By default, v3 predicts the output duration automatically. \
Use Duration Scale for small adjustments or Seconds for exact manual control.
"""

    with gr.Blocks() as blocks:
        gr.Markdown(heading_md)
        gr.Markdown(intro_md)

        # --- Primary inputs -------------------------------------------------
        text = gr.Textbox(label="Text", lines=4)
        uploaded_audio = gr.Audio(
            label="Reference Audio Upload (optional, blank = no-reference mode)",
            type="filepath",
        )

        # --- Sampling controls ----------------------------------------------
        with gr.Accordion("Sampling", open=True):
            with gr.Row():
                num_steps = gr.Slider(label="Num Steps", minimum=1, maximum=120, value=40, step=1)
                num_candidates = gr.Slider(
                    label="Num Candidates",
                    minimum=1,
                    maximum=MAX_GRADIO_CANDIDATES,
                    value=1,
                    step=1,
                )
                seed_raw = gr.Textbox(label="Seed (blank=random)", value="")
                seconds_raw = gr.Textbox(label="Seconds (blank=auto)", value="")
                duration_scale = gr.Slider(
                    label="Duration Scale", minimum=0.5, maximum=1.5, value=1.0, step=0.01
                )

            with gr.Row():
                cfg_guidance_mode = gr.Dropdown(
                    label="CFG Guidance Mode",
                    choices=["independent", "joint", "alternating"],
                    value="independent",
                )
                cfg_scale_text = gr.Slider(
                    label="CFG Scale Text", minimum=0.0, maximum=10.0, value=3.0, step=0.1
                )
                cfg_scale_speaker = gr.Slider(
                    label="CFG Scale Speaker", minimum=0.0, maximum=10.0, value=5.0, step=0.1
                )

        # --- Advanced controls ----------------------------------------------
        with gr.Accordion("Advanced (Optional)", open=False):
            cfg_scale_raw = gr.Textbox(label="CFG Scale Override (optional)", value="")
            with gr.Row():
                cfg_min_t = gr.Number(label="CFG Min t", value=0.5)
                cfg_max_t = gr.Number(label="CFG Max t", value=1.0)
                context_kv_cache = gr.Checkbox(label="Context KV Cache", value=True)
            with gr.Row():
                truncation_factor_raw = gr.Textbox(label="Truncation Factor (optional)", value="")
                rescale_k_raw = gr.Textbox(label="Rescale k (optional)", value="")
                rescale_sigma_raw = gr.Textbox(label="Rescale sigma (optional)", value="")
            with gr.Row():
                speaker_kv_scale_raw = gr.Textbox(label="Speaker KV Scale (optional)", value="")
                speaker_kv_min_t_raw = gr.Textbox(label="Speaker KV Min t (optional)", value="0.9")
                speaker_kv_max_layers_raw = gr.Textbox(
                    label="Speaker KV Max Layers (optional)", value=""
                )

        generate_btn = gr.Button("Generate", variant="primary")

        # --- Outputs --------------------------------------------------------
        # One slot per possible candidate, laid out in rows of
        # GRADIO_AUDIO_COLS_PER_ROW; the last row may be partially filled.
        out_audios: list[gr.Audio] = []
        with gr.Column():
            for row_start in range(0, MAX_GRADIO_CANDIDATES, GRADIO_AUDIO_COLS_PER_ROW):
                row_stop = min(row_start + GRADIO_AUDIO_COLS_PER_ROW, MAX_GRADIO_CANDIDATES)
                with gr.Row():
                    for slot in range(row_start, row_stop):
                        out_audios.append(
                            gr.Audio(
                                label=f"Generated Audio {slot + 1}",
                                type="numpy",
                                visible=(slot == 0),
                            )
                        )
        out_log = gr.Textbox(label="Run Log", lines=6)

        def gradio_inference(
            text,
            uploaded_audio,
            num_steps,
            num_candidates,
            seed_raw,
            seconds_raw,
            duration_scale,
            cfg_guidance_mode,
            cfg_scale_text,
            cfg_scale_speaker,
            cfg_scale_raw,
            cfg_min_t,
            cfg_max_t,
            context_kv_cache,
            truncation_factor_raw,
            rescale_k_raw,
            rescale_sigma_raw,
            speaker_kv_scale_raw,
            speaker_kv_min_t_raw,
            speaker_kv_max_layers_raw,
        ):
            """Run inference and map results onto the fixed set of output slots."""
            try:
                audio_results, log_text = run_inference_gpu(
                    text=text,
                    uploaded_audio=uploaded_audio,
                    num_steps=num_steps,
                    num_candidates=num_candidates,
                    seed_raw=seed_raw,
                    seconds_raw=seconds_raw,
                    duration_scale=duration_scale,
                    cfg_guidance_mode=cfg_guidance_mode,
                    cfg_scale_text=cfg_scale_text,
                    cfg_scale_speaker=cfg_scale_speaker,
                    cfg_scale_raw=cfg_scale_raw,
                    cfg_min_t=cfg_min_t,
                    cfg_max_t=cfg_max_t,
                    context_kv_cache=context_kv_cache,
                    truncation_factor_raw=truncation_factor_raw,
                    rescale_k_raw=rescale_k_raw,
                    rescale_sigma_raw=rescale_sigma_raw,
                    speaker_kv_scale_raw=speaker_kv_scale_raw,
                    speaker_kv_min_t_raw=speaker_kv_min_t_raw,
                    speaker_kv_max_layers_raw=speaker_kv_max_layers_raw,
                )
                produced = len(audio_results)
                # Show a slot for each produced candidate; clear and hide the rest.
                audio_updates: list[object] = [
                    gr.update(value=audio_results[slot], visible=True)
                    if slot < produced
                    else gr.update(value=None, visible=False)
                    for slot in range(MAX_GRADIO_CANDIDATES)
                ]
                return (*audio_updates, log_text)
            except Exception as exc:
                # Surface any failure in the UI as a Gradio error.
                raise gr.Error(str(exc)) from exc

        # Order must match gradio_inference's positional parameters.
        input_components = [
            text,
            uploaded_audio,
            num_steps,
            num_candidates,
            seed_raw,
            seconds_raw,
            duration_scale,
            cfg_guidance_mode,
            cfg_scale_text,
            cfg_scale_speaker,
            cfg_scale_raw,
            cfg_min_t,
            cfg_max_t,
            context_kv_cache,
            truncation_factor_raw,
            rescale_k_raw,
            rescale_sigma_raw,
            speaker_kv_scale_raw,
            speaker_kv_min_t_raw,
            speaker_kv_max_layers_raw,
        ]
        generate_btn.click(
            fn=gradio_inference,
            inputs=input_components,
            outputs=[*out_audios, out_log],
        )

    return blocks
419
+
420
+
421
if __name__ == "__main__":
    # Build the UI, enable the queue with a default concurrency limit of 1,
    # then start serving.
    demo = build_demo()
    demo.queue(default_concurrency_limit=1).launch()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.5.1
2
+ torchaudio>=2.5.1
3
+ transformers<5
4
+ sentencepiece>=0.1.99,<0.2
5
+ safetensors>=0.7.0
6
+ soundfile>=0.12.0
7
+ huggingface-hub>=0.34.0,<1.0
8
+ gradio>=5.0.0
9
+ numpy
10
+ peft>=0.18.0
11
+ dacvae @ git+https://github.com/facebookresearch/dacvae
12
+ torchcodec>=0.10.0
13
+ silentcipher @ git+https://github.com/SesameAILabs/silentcipher.git@d46d7d0893a583d8968ab3a6626e2289faec9152