Masaaki Kawata committed on
Commit
aaffb76
·
1 Parent(s): 6209837

Remove Faster Irodori TTS runtime and related files

Browse files
Dockerfile CHANGED
@@ -34,7 +34,6 @@ RUN python -m pip install --upgrade pip setuptools wheel \
34
 
35
  COPY app.py .
36
  COPY faster_qwen3_tts ./faster_qwen3_tts
37
- COPY faster_irodori_tts ./faster_irodori_tts
38
  COPY qwen_tts ./qwen_tts
39
  COPY irodori_tts ./irodori_tts
40
 
 
34
 
35
  COPY app.py .
36
  COPY faster_qwen3_tts ./faster_qwen3_tts
 
37
  COPY qwen_tts ./qwen_tts
38
  COPY irodori_tts ./irodori_tts
39
 
app.py CHANGED
@@ -22,15 +22,14 @@ from huggingface_hub import hf_hub_download
22
  #from huggingface_hub import snapshot_download
23
  #from qwen_tts import Qwen3TTSModel
24
  from faster_qwen3_tts import FasterQwen3TTS
25
- #from irodori_tts.inference_runtime import InferenceRuntime, RuntimeKey, SamplingRequest
26
- from faster_irodori_tts import FasterIrodoriTTSRuntime, RuntimeKey, SamplingRequest
27
 
28
 
29
  load_dotenv(verbose=False)
30
 
31
  #TTS_MODEL = Qwen3TTSModel.from_pretrained(snapshot_download('Qwen/Qwen3-TTS-12Hz-1.7B-Base', token=os.environ['HF_TOKEN']), device_map=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), dtype=torch.bfloat16, token=os.environ['HF_TOKEN'], attn_implementation='kernels-community/flash-attn3')
32
  TTS_MODEL = FasterQwen3TTS.from_pretrained('Qwen/Qwen3-TTS-12Hz-1.7B-Base')
33
- IRODORI_TTS_RUNTIME: Optional[FasterIrodoriTTSRuntime] = None
34
  WHISPER_MODEL = whisper.load_model('turbo', device='cpu', download_root=os.environ.get('WHISPER_CACHE_DIR'))
35
  REFERENCE_AUDIO_TRANSCRIPTION_CACHE: dict[str, tuple[float, str, str]] = {}
36
  REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LOCK = threading.Lock()
@@ -235,7 +234,7 @@ def generate_voice_clone(model: str | None, input_text: str, language: str | Non
235
  if IRODORI_TTS_RUNTIME is None:
236
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
237
  precision = 'bf16' if device == 'cuda' else 'fp32'
238
- IRODORI_TTS_RUNTIME = FasterIrodoriTTSRuntime.from_key(RuntimeKey(
239
  checkpoint=hf_hub_download(repo_id='Aratako/Irodori-TTS-500M-v2', filename='model.safetensors'),
240
  model_device=device,
241
  codec_repo='Aratako/Semantic-DACVAE-Japanese-32dim',
@@ -245,6 +244,9 @@ def generate_voice_clone(model: str | None, input_text: str, language: str | Non
245
  enable_watermark=False,
246
  ))
247
 
 
 
 
248
  result = IRODORI_TTS_RUNTIME.synthesize(SamplingRequest(
249
  text=input_text,
250
  ref_wav=reference_audio,
 
22
  #from huggingface_hub import snapshot_download
23
  #from qwen_tts import Qwen3TTSModel
24
  from faster_qwen3_tts import FasterQwen3TTS
25
+ from irodori_tts.inference_runtime import InferenceRuntime, RuntimeKey, SamplingRequest
 
26
 
27
 
28
  load_dotenv(verbose=False)
29
 
30
  #TTS_MODEL = Qwen3TTSModel.from_pretrained(snapshot_download('Qwen/Qwen3-TTS-12Hz-1.7B-Base', token=os.environ['HF_TOKEN']), device_map=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), dtype=torch.bfloat16, token=os.environ['HF_TOKEN'], attn_implementation='kernels-community/flash-attn3')
31
  TTS_MODEL = FasterQwen3TTS.from_pretrained('Qwen/Qwen3-TTS-12Hz-1.7B-Base')
32
+ IRODORI_TTS_RUNTIME: Optional[InferenceRuntime] = None
33
  WHISPER_MODEL = whisper.load_model('turbo', device='cpu', download_root=os.environ.get('WHISPER_CACHE_DIR'))
34
  REFERENCE_AUDIO_TRANSCRIPTION_CACHE: dict[str, tuple[float, str, str]] = {}
35
  REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LOCK = threading.Lock()
 
234
  if IRODORI_TTS_RUNTIME is None:
235
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
236
  precision = 'bf16' if device == 'cuda' else 'fp32'
237
+ IRODORI_TTS_RUNTIME = InferenceRuntime.from_key(RuntimeKey(
238
  checkpoint=hf_hub_download(repo_id='Aratako/Irodori-TTS-500M-v2', filename='model.safetensors'),
239
  model_device=device,
240
  codec_repo='Aratako/Semantic-DACVAE-Japanese-32dim',
 
244
  enable_watermark=False,
245
  ))
246
 
247
+ if sample_rate != 48000:
248
+ reference_audio = (_resample(reference_audio[0], sample_rate, 48000), 48000)
249
+
250
  result = IRODORI_TTS_RUNTIME.synthesize(SamplingRequest(
251
  text=input_text,
252
  ref_wav=reference_audio,
faster_irodori_tts/__init__.py DELETED
@@ -1,17 +0,0 @@
1
- """CUDA Graph accelerated runtime helpers for Irodori-TTS."""
2
-
3
- from .runtime import (
4
- FasterInferenceRuntime,
5
- FasterIrodoriTTSRuntime,
6
- RuntimeKey,
7
- SamplingRequest,
8
- SamplingResult,
9
- )
10
-
11
- __all__ = [
12
- "FasterInferenceRuntime",
13
- "FasterIrodoriTTSRuntime",
14
- "RuntimeKey",
15
- "SamplingRequest",
16
- "SamplingResult",
17
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faster_irodori_tts/rf_graph.py DELETED
@@ -1,511 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from collections import OrderedDict
4
- from dataclasses import dataclass
5
-
6
- import torch
7
-
8
- from irodori_tts.rf import sample_euler_rf_cfg
9
-
10
-
11
- @dataclass(frozen=True)
12
- class RFGraphSignature:
13
- batch_size: int
14
- sequence_length: int
15
- latent_dim: int
16
- text_len: int
17
- speaker_len: int
18
- num_steps: int
19
- cfg_scale_text: float
20
- cfg_scale_speaker: float
21
- cfg_min_t: float
22
- cfg_max_t: float
23
- dtype: str
24
- device: str
25
-
26
-
27
- @dataclass
28
- class RFGraphSampleResult:
29
- latent: torch.Tensor
30
- graph_used: bool
31
- fallback_reason: str | None = None
32
-
33
-
34
- def _device_key(device: torch.device) -> str:
35
- index = 0 if device.index is None else int(device.index)
36
- return f"{device.type}:{index}"
37
-
38
-
39
- def _pad_reference_to_bucket(
40
- ref_latent: torch.Tensor,
41
- ref_mask: torch.Tensor,
42
- *,
43
- speaker_patch_size: int,
44
- bucket_multiple: int,
45
- ) -> tuple[torch.Tensor, torch.Tensor]:
46
- if bucket_multiple <= 1:
47
- return ref_latent, ref_mask
48
-
49
- patch = max(1, int(speaker_patch_size))
50
- current = int(ref_latent.shape[1])
51
- after_patch = max(1, (current + patch - 1) // patch)
52
- bucketed_after_patch = (
53
- (after_patch + int(bucket_multiple) - 1) // int(bucket_multiple)
54
- ) * int(bucket_multiple)
55
- target = bucketed_after_patch * patch
56
- if target <= current:
57
- return ref_latent, ref_mask
58
-
59
- pad_len = target - current
60
- latent_pad = torch.zeros(
61
- (ref_latent.shape[0], pad_len, ref_latent.shape[2]),
62
- device=ref_latent.device,
63
- dtype=ref_latent.dtype,
64
- )
65
- mask_pad = torch.zeros(
66
- (ref_mask.shape[0], pad_len),
67
- device=ref_mask.device,
68
- dtype=ref_mask.dtype,
69
- )
70
- return torch.cat([ref_latent, latent_pad], dim=1), torch.cat([ref_mask, mask_pad], dim=1)
71
-
72
-
73
- def _copy_context_kv(
74
- dst: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
75
- src: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
76
- ) -> None:
77
- if len(dst) != len(src):
78
- raise ValueError(f"Context KV layer count mismatch: graph={len(dst)} input={len(src)}")
79
- for dst_layer, src_layer in zip(dst, src):
80
- for dst_tensor, src_tensor in zip(dst_layer, src_layer):
81
- if tuple(dst_tensor.shape) != tuple(src_tensor.shape):
82
- raise ValueError(
83
- "Context KV shape mismatch: "
84
- f"graph={tuple(dst_tensor.shape)} input={tuple(src_tensor.shape)}"
85
- )
86
- dst_tensor.copy_(src_tensor)
87
-
88
-
89
- class IrodoriRFGraph:
90
- """Captured RF Euler sampler for one fixed Irodori-TTS shape/configuration."""
91
-
92
- def __init__(
93
- self,
94
- model,
95
- signature: RFGraphSignature,
96
- *,
97
- num_warmup: int = 2,
98
- ) -> None:
99
- self.model = model
100
- self.signature = signature
101
- self.device = model.device
102
- self.device_index = 0 if self.device.index is None else int(self.device.index)
103
- self.dtype = model.dtype
104
- self.num_warmup = int(num_warmup)
105
- self.cfg_batch_mult = 3
106
-
107
- bsz = signature.batch_size
108
- seq_len = signature.sequence_length
109
- latent_dim = signature.latent_dim
110
- cfg_bsz = bsz * self.cfg_batch_mult
111
- text_dim = model.cfg.text_dim
112
- speaker_dim = model.cfg.speaker_dim
113
-
114
- self.x_buf = torch.zeros((bsz, seq_len, latent_dim), device=self.device, dtype=self.dtype)
115
- self.x_cfg_buf = torch.zeros(
116
- (cfg_bsz, seq_len, latent_dim), device=self.device, dtype=self.dtype
117
- )
118
- self.v_buf = torch.zeros_like(self.x_buf)
119
- self.latent_mask = torch.ones((bsz, seq_len), device=self.device, dtype=torch.bool)
120
- self.latent_mask_cfg = torch.ones((cfg_bsz, seq_len), device=self.device, dtype=torch.bool)
121
-
122
- self.text_state_cond = torch.zeros(
123
- (bsz, signature.text_len, text_dim), device=self.device, dtype=self.dtype
124
- )
125
- self.text_mask_cond = torch.zeros(
126
- (bsz, signature.text_len), device=self.device, dtype=torch.bool
127
- )
128
- self.speaker_state_cond = torch.zeros(
129
- (bsz, signature.speaker_len, speaker_dim), device=self.device, dtype=self.dtype
130
- )
131
- self.speaker_mask_cond = torch.zeros(
132
- (bsz, signature.speaker_len), device=self.device, dtype=torch.bool
133
- )
134
-
135
- self.text_state_cfg = torch.zeros(
136
- (cfg_bsz, signature.text_len, text_dim), device=self.device, dtype=self.dtype
137
- )
138
- self.text_mask_cfg = torch.zeros(
139
- (cfg_bsz, signature.text_len), device=self.device, dtype=torch.bool
140
- )
141
- self.speaker_state_cfg = torch.zeros(
142
- (cfg_bsz, signature.speaker_len, speaker_dim), device=self.device, dtype=self.dtype
143
- )
144
- self.speaker_mask_cfg = torch.zeros(
145
- (cfg_bsz, signature.speaker_len), device=self.device, dtype=torch.bool
146
- )
147
-
148
- self.context_kv_cond = self._make_context_kv_buffers(bsz)
149
- self.context_kv_cfg = self._make_context_kv_buffers(cfg_bsz)
150
-
151
- init_scale = 0.999
152
- t_schedule = torch.linspace(
153
- 1.0,
154
- 0.0,
155
- signature.num_steps + 1,
156
- device=self.device,
157
- dtype=torch.float32,
158
- ) * init_scale
159
- self.t_cond = [torch.full((bsz,), t_schedule[i], device=self.device, dtype=self.dtype)
160
- for i in range(signature.num_steps)]
161
- self.t_cfg = [self.t_cond[i].repeat(self.cfg_batch_mult)
162
- for i in range(signature.num_steps)]
163
- self.deltas = [
164
- float((t_schedule[i + 1] - t_schedule[i]).detach().cpu())
165
- for i in range(signature.num_steps)
166
- ]
167
- self.use_cfg = [
168
- bool(signature.cfg_min_t <= float(t_schedule[i].detach().cpu()) <= signature.cfg_max_t)
169
- for i in range(signature.num_steps)
170
- ]
171
-
172
- self.graph: torch.cuda.CUDAGraph | None = None
173
- self.captured = False
174
-
175
- def _make_context_kv_buffers(
176
- self,
177
- batch_size: int,
178
- ) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
179
- buffers: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = []
180
- for block in self.model.blocks:
181
- attn = block.attention
182
- k_text = torch.zeros(
183
- (batch_size, self.signature.text_len, attn.heads, attn.head_dim),
184
- device=self.device,
185
- dtype=self.dtype,
186
- )
187
- v_text = torch.zeros_like(k_text)
188
- k_speaker = torch.zeros(
189
- (batch_size, self.signature.speaker_len, attn.heads, attn.head_dim),
190
- device=self.device,
191
- dtype=self.dtype,
192
- )
193
- v_speaker = torch.zeros_like(k_speaker)
194
- buffers.append((k_text, v_text, k_speaker, v_speaker))
195
- return buffers
196
-
197
- def _copy_cfg_x(self) -> None:
198
- bsz = self.signature.batch_size
199
- self.x_cfg_buf[:bsz].copy_(self.x_buf)
200
- self.x_cfg_buf[bsz : 2 * bsz].copy_(self.x_buf)
201
- self.x_cfg_buf[2 * bsz : 3 * bsz].copy_(self.x_buf)
202
-
203
- def _full_loop(self) -> None:
204
- bsz = self.signature.batch_size
205
- scale_text = float(self.signature.cfg_scale_text)
206
- scale_speaker = float(self.signature.cfg_scale_speaker)
207
- cond_weight = 1.0 + scale_text + scale_speaker
208
-
209
- for i in range(self.signature.num_steps):
210
- if self.use_cfg[i]:
211
- self._copy_cfg_x()
212
- v_out = self.model.forward_with_encoded_conditions(
213
- x_t=self.x_cfg_buf,
214
- t=self.t_cfg[i],
215
- text_state=self.text_state_cfg,
216
- text_mask=self.text_mask_cfg,
217
- speaker_state=self.speaker_state_cfg,
218
- speaker_mask=self.speaker_mask_cfg,
219
- latent_mask=self.latent_mask_cfg,
220
- context_kv_cache=self.context_kv_cfg,
221
- )
222
- v_cond = v_out[:bsz]
223
- v_uncond_text = v_out[bsz : 2 * bsz]
224
- v_uncond_speaker = v_out[2 * bsz : 3 * bsz]
225
- self.v_buf.copy_(v_cond)
226
- self.v_buf.mul_(cond_weight)
227
- self.v_buf.add_(v_uncond_text, alpha=-scale_text)
228
- self.v_buf.add_(v_uncond_speaker, alpha=-scale_speaker)
229
- else:
230
- v_out = self.model.forward_with_encoded_conditions(
231
- x_t=self.x_buf,
232
- t=self.t_cond[i],
233
- text_state=self.text_state_cond,
234
- text_mask=self.text_mask_cond,
235
- speaker_state=self.speaker_state_cond,
236
- speaker_mask=self.speaker_mask_cond,
237
- latent_mask=self.latent_mask,
238
- context_kv_cache=self.context_kv_cond,
239
- )
240
- self.v_buf.copy_(v_out)
241
-
242
- self.x_buf.add_(self.v_buf, alpha=self.deltas[i])
243
-
244
- @torch.inference_mode()
245
- def capture(self) -> None:
246
- if self.captured:
247
- return
248
-
249
- # Populate module-side RoPE caches and allocator pools before capture.
250
- for _ in range(max(1, self.num_warmup)):
251
- self._full_loop()
252
- torch.cuda.synchronize(self.device)
253
-
254
- with torch.cuda.device(self.device_index):
255
- stream = torch.cuda.Stream()
256
- stream.wait_stream(torch.cuda.current_stream())
257
- with torch.cuda.stream(stream):
258
- for _ in range(max(1, self.num_warmup)):
259
- self._full_loop()
260
- torch.cuda.synchronize(self.device)
261
-
262
- self.graph = torch.cuda.CUDAGraph()
263
- with torch.cuda.graph(self.graph):
264
- self._full_loop()
265
-
266
- torch.cuda.current_stream().wait_stream(stream)
267
- torch.cuda.synchronize(self.device)
268
- self.captured = True
269
-
270
- @torch.inference_mode()
271
- def run(
272
- self,
273
- *,
274
- x_t: torch.Tensor,
275
- text_state_cond: torch.Tensor,
276
- text_mask_cond: torch.Tensor,
277
- speaker_state_cond: torch.Tensor,
278
- speaker_mask_cond: torch.Tensor,
279
- text_state_cfg: torch.Tensor,
280
- text_mask_cfg: torch.Tensor,
281
- speaker_state_cfg: torch.Tensor,
282
- speaker_mask_cfg: torch.Tensor,
283
- context_kv_cond: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
284
- context_kv_cfg: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
285
- ) -> torch.Tensor:
286
- if not self.captured or self.graph is None:
287
- self.capture()
288
-
289
- self.x_buf.copy_(x_t)
290
- self.text_state_cond.copy_(text_state_cond)
291
- self.text_mask_cond.copy_(text_mask_cond)
292
- self.speaker_state_cond.copy_(speaker_state_cond)
293
- self.speaker_mask_cond.copy_(speaker_mask_cond)
294
- self.text_state_cfg.copy_(text_state_cfg)
295
- self.text_mask_cfg.copy_(text_mask_cfg)
296
- self.speaker_state_cfg.copy_(speaker_state_cfg)
297
- self.speaker_mask_cfg.copy_(speaker_mask_cfg)
298
- _copy_context_kv(self.context_kv_cond, context_kv_cond)
299
- _copy_context_kv(self.context_kv_cfg, context_kv_cfg)
300
-
301
- self.graph.replay()
302
- return self.x_buf.clone()
303
-
304
-
305
- class FasterIrodoriRFSampler:
306
- """Graph cache and safe fallback wrapper for Irodori RF sampling."""
307
-
308
- def __init__(
309
- self,
310
- *,
311
- max_graphs: int = 2,
312
- speaker_bucket_multiple: int = 64,
313
- num_warmup: int = 2,
314
- ) -> None:
315
- self.max_graphs = max(1, int(max_graphs))
316
- self.speaker_bucket_multiple = max(1, int(speaker_bucket_multiple))
317
- self.num_warmup = max(1, int(num_warmup))
318
- self._graphs: OrderedDict[RFGraphSignature, IrodoriRFGraph] = OrderedDict()
319
-
320
- def _unsupported_reason(
321
- self,
322
- *,
323
- model,
324
- cfg_guidance_mode: str,
325
- cfg_scale_text: float,
326
- cfg_scale_speaker: float,
327
- rescale_k: float | None,
328
- rescale_sigma: float | None,
329
- use_context_kv_cache: bool,
330
- speaker_kv_scale: float | None,
331
- ) -> str | None:
332
- if model.device.type != "cuda" or not torch.cuda.is_available():
333
- return "CUDA Graph requires a CUDA device"
334
- if str(cfg_guidance_mode).strip().lower() != "independent":
335
- return "only cfg_guidance_mode='independent' is currently graphed"
336
- if cfg_scale_text <= 0 or cfg_scale_speaker <= 0:
337
- return "graph path currently expects both text and speaker CFG scales to be > 0"
338
- if rescale_k is not None or rescale_sigma is not None:
339
- return "rescale_k/rescale_sigma path is not graph-enabled"
340
- if not use_context_kv_cache:
341
- return "context_kv_cache=False is not graph-enabled"
342
- if speaker_kv_scale is not None:
343
- return "speaker_kv_scale path is not graph-enabled"
344
- return None
345
-
346
- def _get_graph(self, model, signature: RFGraphSignature) -> IrodoriRFGraph:
347
- graph = self._graphs.get(signature)
348
- if graph is not None:
349
- self._graphs.move_to_end(signature)
350
- return graph
351
-
352
- graph = IrodoriRFGraph(model, signature, num_warmup=self.num_warmup)
353
- graph.capture()
354
- self._graphs[signature] = graph
355
- self._graphs.move_to_end(signature)
356
- while len(self._graphs) > self.max_graphs:
357
- self._graphs.popitem(last=False)
358
- return graph
359
-
360
- @torch.inference_mode()
361
- def sample(
362
- self,
363
- *,
364
- model,
365
- text_input_ids: torch.Tensor,
366
- text_mask: torch.Tensor,
367
- ref_latent: torch.Tensor,
368
- ref_mask: torch.Tensor,
369
- sequence_length: int,
370
- num_steps: int = 40,
371
- cfg_scale_text: float = 3.0,
372
- cfg_scale_speaker: float = 5.0,
373
- cfg_guidance_mode: str = "independent",
374
- cfg_min_t: float = 0.5,
375
- cfg_max_t: float = 1.0,
376
- seed: int = 0,
377
- truncation_factor: float | None = None,
378
- rescale_k: float | None = None,
379
- rescale_sigma: float | None = None,
380
- use_context_kv_cache: bool = True,
381
- speaker_kv_scale: float | None = None,
382
- speaker_kv_max_layers: int | None = None,
383
- speaker_kv_min_t: float | None = None,
384
- ) -> RFGraphSampleResult:
385
- def fallback(reason: str) -> RFGraphSampleResult:
386
- return RFGraphSampleResult(
387
- latent=sample_euler_rf_cfg(
388
- model=model,
389
- text_input_ids=text_input_ids,
390
- text_mask=text_mask,
391
- ref_latent=ref_latent,
392
- ref_mask=ref_mask,
393
- sequence_length=sequence_length,
394
- num_steps=num_steps,
395
- cfg_scale_text=cfg_scale_text,
396
- cfg_scale_speaker=cfg_scale_speaker,
397
- cfg_guidance_mode=cfg_guidance_mode,
398
- cfg_min_t=cfg_min_t,
399
- cfg_max_t=cfg_max_t,
400
- seed=seed,
401
- truncation_factor=truncation_factor,
402
- rescale_k=rescale_k,
403
- rescale_sigma=rescale_sigma,
404
- use_context_kv_cache=use_context_kv_cache,
405
- speaker_kv_scale=speaker_kv_scale,
406
- speaker_kv_max_layers=speaker_kv_max_layers,
407
- speaker_kv_min_t=speaker_kv_min_t,
408
- ),
409
- graph_used=False,
410
- fallback_reason=reason,
411
- )
412
-
413
- reason = self._unsupported_reason(
414
- model=model,
415
- cfg_guidance_mode=cfg_guidance_mode,
416
- cfg_scale_text=float(cfg_scale_text),
417
- cfg_scale_speaker=float(cfg_scale_speaker),
418
- rescale_k=rescale_k,
419
- rescale_sigma=rescale_sigma,
420
- use_context_kv_cache=bool(use_context_kv_cache),
421
- speaker_kv_scale=speaker_kv_scale,
422
- )
423
- if reason is not None:
424
- return fallback(reason)
425
-
426
- device = model.device
427
- dtype = model.dtype
428
- batch_size = int(text_input_ids.shape[0])
429
- latent_dim = model.cfg.patched_latent_dim
430
-
431
- ref_latent, ref_mask = _pad_reference_to_bucket(
432
- ref_latent,
433
- ref_mask,
434
- speaker_patch_size=model.cfg.speaker_patch_size,
435
- bucket_multiple=self.speaker_bucket_multiple,
436
- )
437
-
438
- rng = torch.Generator(device=device).manual_seed(int(seed))
439
- x_t = torch.randn(
440
- (batch_size, int(sequence_length), latent_dim),
441
- device=device,
442
- dtype=dtype,
443
- generator=rng,
444
- )
445
- if truncation_factor is not None:
446
- x_t = x_t * float(truncation_factor)
447
-
448
- text_state_cond, text_mask_cond, speaker_state_cond, speaker_mask_cond = (
449
- model.encode_conditions(
450
- text_input_ids=text_input_ids,
451
- text_mask=text_mask,
452
- ref_latent=ref_latent,
453
- ref_mask=ref_mask,
454
- )
455
- )
456
- text_state_uncond = torch.zeros_like(text_state_cond)
457
- text_mask_uncond = torch.zeros_like(text_mask_cond)
458
- speaker_state_uncond = torch.zeros_like(speaker_state_cond)
459
- speaker_mask_uncond = torch.zeros_like(speaker_mask_cond)
460
-
461
- text_state_cfg = torch.cat([text_state_cond, text_state_uncond, text_state_cond], dim=0)
462
- text_mask_cfg = torch.cat([text_mask_cond, text_mask_uncond, text_mask_cond], dim=0)
463
- speaker_state_cfg = torch.cat(
464
- [speaker_state_cond, speaker_state_cond, speaker_state_uncond], dim=0
465
- )
466
- speaker_mask_cfg = torch.cat(
467
- [speaker_mask_cond, speaker_mask_cond, speaker_mask_uncond], dim=0
468
- )
469
-
470
- context_kv_cond = model.build_context_kv_cache(
471
- text_state=text_state_cond,
472
- speaker_state=speaker_state_cond,
473
- )
474
- context_kv_cfg = model.build_context_kv_cache(
475
- text_state=text_state_cfg,
476
- speaker_state=speaker_state_cfg,
477
- )
478
-
479
- signature = RFGraphSignature(
480
- batch_size=batch_size,
481
- sequence_length=int(sequence_length),
482
- latent_dim=int(latent_dim),
483
- text_len=int(text_state_cond.shape[1]),
484
- speaker_len=int(speaker_state_cond.shape[1]),
485
- num_steps=int(num_steps),
486
- cfg_scale_text=float(cfg_scale_text),
487
- cfg_scale_speaker=float(cfg_scale_speaker),
488
- cfg_min_t=float(cfg_min_t),
489
- cfg_max_t=float(cfg_max_t),
490
- dtype=str(dtype),
491
- device=_device_key(device),
492
- )
493
- try:
494
- graph = self._get_graph(model, signature)
495
- latent = graph.run(
496
- x_t=x_t,
497
- text_state_cond=text_state_cond,
498
- text_mask_cond=text_mask_cond,
499
- speaker_state_cond=speaker_state_cond,
500
- speaker_mask_cond=speaker_mask_cond,
501
- text_state_cfg=text_state_cfg,
502
- text_mask_cfg=text_mask_cfg,
503
- speaker_state_cfg=speaker_state_cfg,
504
- speaker_mask_cfg=speaker_mask_cfg,
505
- context_kv_cond=context_kv_cond,
506
- context_kv_cfg=context_kv_cfg,
507
- )
508
- except Exception as exc:
509
- self._graphs.pop(signature, None)
510
- return fallback(f"CUDA Graph capture/replay failed: {exc}")
511
- return RFGraphSampleResult(latent=latent, graph_used=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faster_irodori_tts/runtime.py DELETED
@@ -1,290 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import math
4
- import secrets
5
- from collections.abc import Callable
6
-
7
- import torch
8
-
9
- from irodori_tts.codec import unpatchify_latent
10
- from irodori_tts.inference_runtime import (
11
- InferenceRuntime,
12
- RuntimeKey,
13
- SamplingRequest,
14
- SamplingResult,
15
- _measure_end,
16
- _measure_start,
17
- find_flattening_point,
18
- resolve_cfg_scales,
19
- )
20
- from irodori_tts.text_normalization import normalize_text
21
-
22
- from .rf_graph import FasterIrodoriRFSampler
23
-
24
-
25
- class FasterIrodoriTTSRuntime(InferenceRuntime):
26
- """Irodori runtime that uses CUDA Graphs for supported RF sampling requests."""
27
-
28
- def __init__(self, **kwargs) -> None:
29
- super().__init__(**kwargs)
30
- self.rf_sampler = FasterIrodoriRFSampler()
31
-
32
- def synthesize(
33
- self,
34
- req: SamplingRequest,
35
- *,
36
- log_fn: Callable[[str], None] | None = None,
37
- ) -> SamplingResult:
38
- def _log(msg: str) -> None:
39
- if log_fn is not None:
40
- log_fn(msg)
41
-
42
- messages: list[str] = []
43
- _log(
44
- (
45
- "[faster_runtime] start synthesize "
46
- "model_device={} model_precision={} codec_device={} codec_precision={} "
47
- "watermark={} mode={} seconds={} steps={} seed={} candidates={} decode_mode={}"
48
- ).format(
49
- self.key.model_device,
50
- self.key.model_precision,
51
- self.key.codec_device,
52
- self.key.codec_precision,
53
- self.codec.enable_watermark,
54
- req.cfg_guidance_mode,
55
- req.seconds,
56
- req.num_steps,
57
- "random" if req.seed is None else int(req.seed),
58
- req.num_candidates,
59
- req.decode_mode,
60
- )
61
- )
62
-
63
- if req.seconds <= 0:
64
- raise ValueError(f"seconds must be > 0, got {req.seconds}")
65
- num_candidates = int(req.num_candidates)
66
- if num_candidates <= 0:
67
- raise ValueError(f"num_candidates must be > 0, got {num_candidates}")
68
- decode_mode = str(req.decode_mode).strip().lower()
69
- if decode_mode not in {"sequential", "batch"}:
70
- raise ValueError(
71
- f"Unsupported decode_mode={req.decode_mode!r}. Expected one of: sequential, batch."
72
- )
73
-
74
- raw_text = str(req.text)
75
- normalized_text = normalize_text(raw_text).strip()
76
- if normalized_text == "":
77
- raise ValueError("text became empty after normalization.")
78
-
79
- text_max_len = (
80
- self.default_text_max_len if req.max_text_len is None else int(req.max_text_len)
81
- )
82
- if text_max_len <= 0:
83
- raise ValueError(f"max_text_len must be > 0, got {text_max_len}")
84
-
85
- truncation_factor = None if req.truncation_factor is None else float(req.truncation_factor)
86
- rescale_k = None if req.rescale_k is None else float(req.rescale_k)
87
- rescale_sigma = None if req.rescale_sigma is None else float(req.rescale_sigma)
88
- if truncation_factor is not None and truncation_factor <= 0:
89
- raise ValueError(f"truncation_factor must be > 0, got {truncation_factor}")
90
- if (rescale_k is None) != (rescale_sigma is None):
91
- raise ValueError("rescale_k and rescale_sigma must be set together.")
92
- if rescale_k is not None and rescale_k <= 0:
93
- raise ValueError(f"rescale_k must be > 0, got {rescale_k}")
94
- if rescale_sigma is not None and rescale_sigma <= 0:
95
- raise ValueError(f"rescale_sigma must be > 0, got {rescale_sigma}")
96
-
97
- speaker_kv_scale = None if req.speaker_kv_scale is None else float(req.speaker_kv_scale)
98
- speaker_kv_min_t = None
99
- speaker_kv_max_layers = (
100
- None if req.speaker_kv_max_layers is None else int(req.speaker_kv_max_layers)
101
- )
102
- if speaker_kv_scale is not None:
103
- if speaker_kv_scale <= 0:
104
- raise ValueError(f"speaker_kv_scale must be > 0, got {speaker_kv_scale}")
105
- speaker_kv_min_t = 0.9 if req.speaker_kv_min_t is None else float(req.speaker_kv_min_t)
106
- if not (0.0 <= speaker_kv_min_t <= 1.0):
107
- raise ValueError(f"speaker_kv_min_t must be in [0, 1], got {speaker_kv_min_t}")
108
- if speaker_kv_max_layers is not None and speaker_kv_max_layers < 0:
109
- raise ValueError(
110
- f"speaker_kv_max_layers must be >= 0 when specified, got {speaker_kv_max_layers}"
111
- )
112
-
113
- cfg_mode = str(req.cfg_guidance_mode).strip().lower()
114
- if cfg_mode not in {"independent", "joint", "alternating"}:
115
- raise ValueError(
116
- f"Unsupported cfg_guidance_mode={req.cfg_guidance_mode!r}. "
117
- "Expected one of: independent, joint, alternating."
118
- )
119
-
120
- cfg_scale_text, cfg_scale_speaker, scale_messages = resolve_cfg_scales(
121
- cfg_guidance_mode=cfg_mode,
122
- cfg_scale_text=req.cfg_scale_text,
123
- cfg_scale_speaker=req.cfg_scale_speaker,
124
- cfg_scale=req.cfg_scale,
125
- )
126
- messages.extend(scale_messages)
127
- for msg in scale_messages:
128
- _log(msg)
129
-
130
- stage_timings: list[tuple[str, float]] = []
131
- if req.seed is None:
132
- used_seed = int(secrets.randbits(63))
133
- msg = f"info: seed not specified; using random seed {used_seed}."
134
- messages.append(msg)
135
- _log(msg)
136
- else:
137
- used_seed = int(req.seed)
138
- _log(f"[faster_runtime] using seed: {used_seed}")
139
- post_load_t0 = _measure_start(self.model_device, self.codec_device)
140
-
141
- with self._infer_lock, torch.inference_mode():
142
- t0 = _measure_start(self.model_device)
143
- text_ids, text_mask = self.tokenizer.batch_encode(
144
- [normalized_text] * num_candidates,
145
- max_length=text_max_len,
146
- )
147
- stage_sec = _measure_end(self.model_device, t0)
148
- stage_timings.append(("tokenize_text", stage_sec))
149
- _log(f"[faster_runtime] tokenize_text: {stage_sec * 1000.0:.1f} ms")
150
- text_ids = text_ids.to(self.model_device)
151
- text_mask = text_mask.to(self.model_device)
152
-
153
- target_samples = int(float(req.seconds) * self.codec.sample_rate)
154
- latent_steps = math.ceil(target_samples / int(self.codec.model.hop_length))
155
- patched_steps = math.ceil(latent_steps / self.model_cfg.latent_patch_size)
156
-
157
- if isinstance(self.train_cfg, dict):
158
- fixed_steps = self.train_cfg.get("fixed_target_latent_steps")
159
- if isinstance(fixed_steps, int) and fixed_steps > 0 and latent_steps > fixed_steps:
160
- msg = (
161
- f"warning: requested latent length ({latent_steps}) exceeds fixed_target_latent_steps ({fixed_steps}) "
162
- "used in training. Long-tail stability may degrade."
163
- )
164
- messages.append(msg)
165
- _log(msg)
166
-
167
- t0 = _measure_start(self.model_device, self.codec_device)
168
- msg_count_before_ref = len(messages)
169
- ref_latent, ref_mask = self._load_reference_latent(
170
- req=req,
171
- batch_size=num_candidates,
172
- messages=messages,
173
- )
174
- stage_sec = _measure_end(self.model_device, t0, self.codec_device)
175
- stage_timings.append(("prepare_reference", stage_sec))
176
- for msg in messages[msg_count_before_ref:]:
177
- _log(msg)
178
- _log(f"[faster_runtime] prepare_reference: {stage_sec * 1000.0:.1f} ms")
179
-
180
- t0 = _measure_start(self.model_device)
181
- sample_result = self.rf_sampler.sample(
182
- model=self.model,
183
- text_input_ids=text_ids,
184
- text_mask=text_mask,
185
- ref_latent=ref_latent,
186
- ref_mask=ref_mask,
187
- sequence_length=patched_steps,
188
- num_steps=int(req.num_steps),
189
- cfg_scale_text=cfg_scale_text,
190
- cfg_scale_speaker=cfg_scale_speaker,
191
- cfg_guidance_mode=cfg_mode,
192
- cfg_min_t=float(req.cfg_min_t),
193
- cfg_max_t=float(req.cfg_max_t),
194
- seed=used_seed,
195
- truncation_factor=truncation_factor,
196
- rescale_k=rescale_k,
197
- rescale_sigma=rescale_sigma,
198
- use_context_kv_cache=bool(req.context_kv_cache),
199
- speaker_kv_scale=speaker_kv_scale,
200
- speaker_kv_max_layers=speaker_kv_max_layers,
201
- speaker_kv_min_t=speaker_kv_min_t,
202
- )
203
- z_patched = sample_result.latent
204
- stage_sec = _measure_end(self.model_device, t0)
205
- stage_timings.append(("sample_rf", stage_sec))
206
- if sample_result.graph_used:
207
- _log(f"[faster_runtime] sample_rf (cuda_graph): {stage_sec * 1000.0:.1f} ms")
208
- else:
209
- msg = f"info: RF CUDA Graph fallback: {sample_result.fallback_reason}"
210
- messages.append(msg)
211
- _log(msg)
212
- _log(f"[faster_runtime] sample_rf (fallback): {stage_sec * 1000.0:.1f} ms")
213
-
214
- t0 = _measure_start(self.model_device)
215
- z = unpatchify_latent(
216
- z_patched,
217
- patch_size=self.model_cfg.latent_patch_size,
218
- latent_dim=self.model_cfg.latent_dim,
219
- )
220
- stage_sec = _measure_end(self.model_device, t0)
221
- stage_timings.append(("unpatchify_latent", stage_sec))
222
- _log(f"[faster_runtime] unpatchify_latent: {stage_sec * 1000.0:.1f} ms")
223
- z = z[:, :latent_steps]
224
-
225
- t0 = _measure_start(self.model_device, self.codec_device)
226
- trimmed_audios: list[torch.Tensor] = []
227
- if decode_mode == "batch":
228
- audio_batch = self.codec.decode_latent(z).cpu()
229
- for i in range(num_candidates):
230
- audio_i = audio_batch[i]
231
- max_samples = target_samples
232
- if bool(req.trim_tail):
233
- flattening_point = find_flattening_point(
234
- z[i],
235
- window_size=max(1, int(req.tail_window_size)),
236
- std_threshold=float(req.tail_std_threshold),
237
- mean_threshold=float(req.tail_mean_threshold),
238
- )
239
- flattening_samples = int(
240
- flattening_point * int(self.codec.model.hop_length)
241
- )
242
- if flattening_samples > 0:
243
- max_samples = min(max_samples, flattening_samples)
244
- trimmed_audios.append(audio_i[:, :max_samples])
245
- else:
246
- for i in range(num_candidates):
247
- audio_i = self.codec.decode_latent(z[i : i + 1]).cpu()[0]
248
- max_samples = target_samples
249
- if bool(req.trim_tail):
250
- flattening_point = find_flattening_point(
251
- z[i],
252
- window_size=max(1, int(req.tail_window_size)),
253
- std_threshold=float(req.tail_std_threshold),
254
- mean_threshold=float(req.tail_mean_threshold),
255
- )
256
- flattening_samples = int(
257
- flattening_point * int(self.codec.model.hop_length)
258
- )
259
- if flattening_samples > 0:
260
- max_samples = min(max_samples, flattening_samples)
261
- trimmed_audios.append(audio_i[:, :max_samples])
262
- stage_sec = _measure_end(self.model_device, t0, self.codec_device)
263
- stage_timings.append(("decode_latent", stage_sec))
264
- _log(f"[faster_runtime] decode_latent ({decode_mode}): {stage_sec * 1000.0:.1f} ms")
265
-
266
- total_to_decode = _measure_end(self.model_device, post_load_t0, self.codec_device)
267
- _log(f"[faster_runtime] total_to_decode: {total_to_decode:.3f} s")
268
-
269
- _log("[faster_runtime] done synthesize")
270
- return SamplingResult(
271
- audio=trimmed_audios[0],
272
- audios=trimmed_audios,
273
- sample_rate=int(self.codec.sample_rate),
274
- stage_timings=stage_timings,
275
- total_to_decode=total_to_decode,
276
- used_seed=used_seed,
277
- messages=messages,
278
- )
279
-
280
-
281
- # Backward-friendly alias for callers that prefer an InferenceRuntime-like name.
282
- FasterInferenceRuntime = FasterIrodoriTTSRuntime
283
-
284
- __all__ = [
285
- "FasterIrodoriTTSRuntime",
286
- "FasterInferenceRuntime",
287
- "RuntimeKey",
288
- "SamplingRequest",
289
- "SamplingResult",
290
- ]