""" Vocence PromptTTS engine: MOSS-TTS (HF snapshot, ``trust_remote_code``). Aligned with chute validator ``POST .../speak`` JSON ``{"text": "", "instruction": "gender: male | pitch: mid | ..."}`` and the same trait split as ``_prompt_to_speak_payload()`` in ``generation.py``: first ``|`` separates transcription from traits when using a combined payload string. Loads weights from ``path_hf_repo`` if given; otherwise from the directory that contains this ``miner.py`` file (the MOSS-TTS bundle root). Optional ``vocence_config.yaml`` adjusts limits, generation kwargs, and ``runtime.strict_instruction_traits``. """ from __future__ import annotations import importlib.util import inspect import os import tempfile import urllib.request from pathlib import Path from typing import Any import numpy as np import soundfile as sf import torch import torchaudio from transformers import AutoModel, AutoProcessor def _patch_torchaudio_load_bypass_torchcodec() -> None: """MOSS ``encode_audios_from_path`` may use ``torchaudio.load``; use soundfile.""" def load( filepath, frame_offset=0, num_frames=-1, normalize=True, channels_first=True, **kwargs, ): path_str = os.fspath(filepath) if path_str.startswith(("http://", "https://")): suffix = Path(path_str.split("?", 1)[0]).suffix or ".wav" tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix) tmp_path = tmp.name tmp.close() try: urllib.request.urlretrieve(path_str, tmp_path) data, sr = sf.read(tmp_path, always_2d=True, dtype="float32") finally: Path(tmp_path).unlink(missing_ok=True) else: data, sr = sf.read(path_str, always_2d=True, dtype="float32") wav = torch.from_numpy(data.copy()).transpose(0, 1).contiguous() if frame_offset > 0: wav = wav[:, frame_offset:] if num_frames is not None and num_frames >= 0: wav = wav[:, :num_frames] if not channels_first: wav = wav.transpose(0, 1) return wav, int(sr) torchaudio.load = load # type: ignore[method-assign] _patch_torchaudio_load_bypass_torchcodec() 
def _configure_sdp_backends() -> None:
    """Set process-wide CUDA scaled-dot-product-attention backend toggles.

    Disables the cuDNN SDPA backend and enables the flash, memory-efficient
    and math backends, in that order. Each toggle is looked up with
    ``getattr`` so that importing this module does not fail on torch builds
    that lack a given toggle.
    """
    for toggle_name, enabled in (
        ("enable_cudnn_sdp", False),
        ("enable_flash_sdp", True),
        ("enable_mem_efficient_sdp", True),
        ("enable_math_sdp", True),
    ):
        toggle = getattr(torch.backends.cuda, toggle_name, None)
        if toggle is not None:
            toggle(enabled)


_configure_sdp_backends()

# Enumerated trait values (validator); other keys pass through unchanged.
_TRAIT_ENUMS: dict[str, frozenset[str]] = {
    "gender": frozenset({"male", "female", "neutral"}),
    "pitch": frozenset({"low", "mid", "high"}),
    "speed": frozenset({"slow", "normal", "fast"}),
    "age_group": frozenset({"child", "young_adult", "adult", "senior"}),
}

# ``runtime.dtype`` names that may select the model dtype (see _resolve_dtype).
_NAMED_DTYPES: dict[str, torch.dtype] = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}


def split_speak_payload(combined: str) -> tuple[str, str | None]:
    """
    Split like ``_prompt_to_speak_payload``: first ``|`` → (text, instruction_or_none).

    ``"hello world | gender: male | pitch: mid"`` → ``("hello world", "gender: male | pitch: mid")``.
    Without a ``|``, or with only whitespace after it, the instruction is ``None``.
    """
    s = combined.strip()
    if "|" not in s:
        return s, None
    left, right = s.split("|", 1)
    inst = right.strip()
    return left.strip(), inst if inst else None


def validate_instruction_traits(instruction: str, *, strict: bool = False) -> list[str]:
    """
    Check pipe-delimited ``key: value`` segments against ``_TRAIT_ENUMS``.

    Keys and values are lower-cased and spaces become underscores before
    comparison (``"age group: Young Adult"`` → ``age_group`` / ``young_adult``).
    Segments without ``:`` and keys not in ``_TRAIT_ENUMS`` are ignored.

    Returns:
        One message per enumerated trait whose value is not allowed.

    Raises:
        ValueError: on the first disallowed value when ``strict`` is true.
    """
    warnings_out: list[str] = []
    for raw in instruction.split("|"):
        part = raw.strip()
        if not part or ":" not in part:
            continue
        key, _, val = part.partition(":")
        key_norm = key.strip().lower().replace(" ", "_")
        val_norm = val.strip().lower().replace(" ", "_")
        allowed = _TRAIT_ENUMS.get(key_norm)
        if allowed is not None and val_norm not in allowed:
            msg = f"trait {key_norm!r} value {val_norm!r} not in allowed set {sorted(allowed)}"
            if strict:
                raise ValueError(msg)
            warnings_out.append(msg)
    return warnings_out


def _resolve_attn_implementation(device: str, dtype: torch.dtype) -> str:
    """Pick the ``attn_implementation`` passed to ``AutoModel.from_pretrained``.

    ``flash_attention_2`` requires CUDA, an importable ``flash_attn`` package,
    a half-precision dtype and a GPU of compute capability >= 8. Otherwise the
    choice is ``sdpa`` on CUDA and ``eager`` on CPU.
    """
    if (
        device == "cuda"
        and importlib.util.find_spec("flash_attn") is not None
        and dtype in (torch.float16, torch.bfloat16)
    ):
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            return "flash_attention_2"
    if device == "cuda":
        return "sdpa"
    return "eager"


def _resolve_dtype(device: str, name: Any) -> torch.dtype:
    """Pick the model dtype from the ``runtime.dtype`` config value.

    On CPU the model is always float32. On CUDA, ``"float32"``, ``"float16"``
    and ``"bfloat16"`` (case-insensitive) are honoured; anything else,
    including an unset value, falls back to bfloat16.
    """
    if device != "cuda":
        return torch.float32
    requested = _NAMED_DTYPES.get(str(name or "").strip().lower())
    return requested if requested is not None else torch.bfloat16


def _load_yaml_config(repo: Path) -> dict[str, Any]:
    """Return the mapping in ``repo / "vocence_config.yaml"``, or ``{}``.

    A missing file yields ``{}`` silently. If the file cannot be read or
    parsed (or PyYAML is not installed), a warning is printed and ``{}`` is
    returned so the miner still starts with defaults. A top-level YAML value
    that is not a mapping is ignored.
    """
    path = repo / "vocence_config.yaml"
    if not path.is_file():
        return {}
    try:
        import yaml

        with path.open(encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except Exception as exc:  # best-effort: fall back to defaults, but say so
        print(f"[warn] ignoring {path}: {type(exc).__name__}: {exc}")
        return {}
    return data if isinstance(data, dict) else {}


def default_hf_repo_root() -> Path:
    """Directory containing ``miner.py`` (expected HF snapshot layout: config, weights, …)."""
    return Path(__file__).resolve().parent


class Miner:
    """MOSS-TTS: ``generate_wav(instruction, text)`` → mono float32 PCM + sample rate.

    The processor and model are loaded from an HF snapshot directory with
    ``trust_remote_code=True``. Optional overrides come from the
    ``vocence_config.yaml`` sections ``runtime``, ``generation``, ``limits``
    and ``reference``.
    """

    def __init__(self, path_hf_repo: Path | str | os.PathLike[str] | None = None) -> None:
        """Load the config, processor and model.

        Args:
            path_hf_repo: Snapshot directory. Defaults to the directory that
                contains this file.
        """
        self._repo_path = (
            Path(path_hf_repo).resolve() if path_hf_repo is not None else default_hf_repo_root()
        )
        self._cfg = _load_yaml_config(self._repo_path)
        self._runtime = self._cfg.get("runtime") or {}
        self._gen = self._cfg.get("generation") or {}
        self._limits = self._cfg.get("limits") or {}
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._dtype = _resolve_dtype(self._device, self._runtime.get("dtype", ""))
        # Either config location makes disallowed trait values a hard error.
        self._strict_traits = bool(
            self._runtime.get("strict_instruction_traits") or self._gen.get("strict_traits")
        )

        attn = _resolve_attn_implementation(self._device, self._dtype)
        repo = str(self._repo_path)
        self._processor = AutoProcessor.from_pretrained(
            repo,
            trust_remote_code=True,
        )
        # Move the processor's audio tokenizer onto the model's device.
        self._processor.audio_tokenizer = self._processor.audio_tokenizer.to(self._device)
        self._model = AutoModel.from_pretrained(
            repo,
            trust_remote_code=True,
            attn_implementation=attn,
            torch_dtype=self._dtype,
        ).to(self._device)
        self._model.eval()

    def _truncate(self, s: str | None, key: str, default_max: int) -> str:
        """Strip ``s`` (``None`` counts as ``""``) and cut it to ``limits[key]`` chars."""
        cap = int(self._limits.get(key, default_max))
        return (s or "").strip()[:cap]

    def _merge_instruction_text(self, instruction: str, text: str) -> str:
        """Combine instruction and text into one string.

        Uses ``generation.instruction_text_template`` (``str.format`` with
        ``{instruction}`` and ``{text}``) when set; otherwise joins with a newline.
        """
        tpl = self._gen.get("instruction_text_template")
        if tpl and isinstance(tpl, str):
            return tpl.format(instruction=instruction, text=text)
        return f"{instruction}\n{text}"

    def _build_user_message(self, instruction: str | None, text: str) -> Any:
        """Match MOSS runner: ``build_user_message(text=..., instruction=...)`` when supported.

        The processor's ``build_user_message`` signature is inspected and the
        first matching keyword pair is used: ``text``+``instruction``, then
        ``text``+``prompt``, then ``text``+``description``. Failing those, the
        instruction is merged into ``text``. ``tokens`` (from
        ``generation.tokens``) and ``reference`` (from ``reference.ref_audio``,
        relative to the repo, only if the file exists) are added when the
        signature names them. Keywords the signature does not accept are
        dropped, unless it takes ``**kwargs``.
        """
        proc = self._processor
        params = inspect.signature(proc.build_user_message).parameters
        names = set(params)
        accepts_var_kw = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        text_s = text.strip()
        inst_s = instruction.strip() if instruction else ""
        inst_arg: str | None = inst_s if inst_s else None

        kw: dict[str, Any] = {}
        tok = self._gen.get("tokens")
        if tok is not None and "tokens" in names:
            kw["tokens"] = int(tok)

        if "instruction" in names and "text" in names:
            kw["text"] = text_s
            kw["instruction"] = inst_arg
        elif "prompt" in names and "text" in names:
            kw["prompt"] = inst_s or ""
            kw["text"] = text_s
        elif "description" in names and "text" in names:
            kw["description"] = inst_s or ""
            kw["text"] = text_s
        elif inst_s:
            kw["text"] = self._merge_instruction_text(inst_s, text_s)
        else:
            kw["text"] = text_s

        ref_cfg = self._cfg.get("reference") or {}
        ref_name = ref_cfg.get("ref_audio")
        if ref_name and "reference" in names:
            ref_path = self._repo_path / str(ref_name)
            if ref_path.is_file():
                kw["reference"] = [str(ref_path)]

        if not accepts_var_kw:
            # Drop anything the signature cannot take.
            kw = {k: v for k, v in kw.items() if k in names}
        return proc.build_user_message(**kw)

    def _generate_kwargs(self) -> dict[str, Any]:
        """Keyword arguments for ``model.generate`` from the ``generation`` config."""
        out: dict[str, Any] = {
            "max_new_tokens": int(self._gen.get("max_new_tokens", 4096)),
        }
        for key, cast in (
            ("audio_temperature", float),
            ("audio_top_p", float),
            ("audio_top_k", int),
        ):
            if key in self._gen:
                out[key] = cast(self._gen[key])
        return out

    def warmup(self) -> None:
        """Run one short synthesis with a valid trait set."""
        _ = self.generate_wav(
            instruction="gender: male | pitch: mid | speed: normal | age_group: adult",
            text="Warmup.",
        )

    def generate_wav(self, instruction: str | None, text: str) -> tuple[np.ndarray, int]:
        """Synthesize ``text`` with the traits in ``instruction``.

        Args:
            instruction: Pipe-delimited ``key: value`` traits. ``None`` or blank
                means no instruction (as ``split_speak_payload`` may return).
            text: Transcription to speak.

        Returns:
            ``(wav, sample_rate)``, where ``wav`` is a flat float32 array built
            from the first decoded message that carries audio.

        Raises:
            ValueError: if ``text`` is empty after stripping/truncation, or if
                strict trait checking rejects the instruction.
            RuntimeError: if no decoded message contains audio.
        """
        instruction = self._truncate(instruction, "max_instruction_chars", 600)
        text = self._truncate(text, "max_text_chars", 2000)
        if not text:
            raise ValueError("text must be non-empty")
        if instruction:
            for w in validate_instruction_traits(instruction, strict=self._strict_traits):
                print(f"[warn] instruction: {w}")

        # _truncate already stripped, so an empty instruction means "none".
        user = self._build_user_message(instruction or None, text)
        batch = self._processor([[user]], mode="generation")
        input_ids = batch["input_ids"].to(self._device)
        attention_mask = batch["attention_mask"].to(self._device)

        with torch.no_grad():
            outputs = self._model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **self._generate_kwargs(),
            )

        for message in self._processor.decode(outputs):
            codes = getattr(message, "audio_codes_list", None)
            if codes is None or len(codes) == 0:
                # Skip messages that carry no audio instead of raising IndexError.
                continue
            wav = codes[0].detach().cpu().float().reshape(-1).numpy().astype(np.float32, copy=False)
            sr = int(self._processor.model_config.sampling_rate)
            return wav, sr
        raise RuntimeError("No audio decoded from MOSS-TTS outputs.")