Spaces:
Runtime error
Runtime error
| import torch | |
| import torchaudio | |
| import soundfile as sf | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
| from snac import SNAC | |
# Maps the short language codes accepted by the synthesis API to the
# language names that are interpolated into the voice label of the prompt.
CODE_TO_NAME = dict(
    hi="hindi",
    bn="bengali",
    te="telugu",
    kn="kannada",
    bhj="bhojpuri",
    ch="chhattisgarhi",
    en="english",
    mth="maithili",
    mar="marathi",
    guj="gujarati",
    mgh="magahi",
)
class SvaraEngine:
    """Text-to-speech engine: causal LM (optionally with LoRA adapters) + SNAC codec.

    The LM generates acoustic tokens; those tokens are de-interleaved into the
    three SNAC code levels and decoded to a waveform. A reference recording can
    be SNAC-encoded and prepended to the prompt for voice cloning.
    """

    # Sample rate of the "hubertsiuzdak/snac_24khz" codec loaded in __init__.
    SAMPLE_RATE = 24000
    # Spacing between the per-position codebook offsets; also used as the
    # exclusive upper bound for a valid raw code.
    CODEBOOK_SIZE = 4096
    # One frame = [L0, L1, L2, L2, L1, L2, L2].
    TOKENS_PER_FRAME = 7
    # First LM vocabulary id used for acoustic codes.
    AUDIO_CODE_BASE_OFFSET = 128266
    START_OF_SPEECH_TOKEN = 128257
    END_OF_SPEECH_TOKEN = 128258

    def __init__(self, base_model_id, adapter_map, device: str = "cpu"):
        """Load tokenizer, base LM, SNAC codec and any LoRA adapters.

        Args:
            base_model_id: Identifier/path passed to ``from_pretrained`` for
                both the tokenizer and the causal LM.
            adapter_map: Mapping ``language_code -> adapter path``. The first
                entry creates the ``PeftModel``; the rest are loaded as
                additional named adapters. Falsy means "base model only".
            device: Requested device. CUDA is used only when it is available
                and ``device`` is not ``"cpu"``.
        """
        use_cuda = device != "cpu" and torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"
        print(f"🚀 Initializing SvaraEngine on {self.device} ...")

        # Tokenizer and base model (half precision only on GPU).
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch_dtype,
        ).to(self.device)

        # SNAC codec
        print("🔊 Loading SNAC codec ...")
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model = self.snac_model.eval().to(self.device)

        # LoRA adapters
        if adapter_map:
            print("🔗 Loading adapters ...")
            adapters = iter(adapter_map.items())
            first_lang, first_path = next(adapters)
            self.model = PeftModel.from_pretrained(
                self.base_model,
                first_path,
                adapter_name=first_lang,
            )
            for lang, path in adapters:
                self.model.load_adapter(path, adapter_name=lang)
            print("✅ Loaded adapters:", list(adapter_map.keys()))
        else:
            self.model = self.base_model
            print("⚠️ No adapters provided. Running base model only.")

    # ----------------------------------------------------------------------
    # Reference audio → SNAC codes → LLM token space
    # ----------------------------------------------------------------------
    def _prepare_waveform(self, wav_np, sr):
        """Turn decoded samples into a mono ``(1, 1, T)`` float32 tensor.

        Args:
            wav_np: Samples shaped ``(T,)`` or ``(T, channels)`` (the layout
                ``soundfile.read`` returns).
            sr: Sample rate of ``wav_np``; resampled to ``SAMPLE_RATE`` if
                different.
        """
        wav = torch.tensor(wav_np, dtype=torch.float32)
        # To (C, T)
        wav = wav.unsqueeze(0) if wav.ndim == 1 else wav.T
        wav = wav.to(self.device)

        if sr != self.SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sr, self.SAMPLE_RATE)
            wav = resampler.to(self.device)(wav)

        # Force mono
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        # SNAC works on batched input (B, 1, T); without the batch axis the
        # encoder fails and cloning is silently skipped by the caller.
        return wav.unsqueeze(0)

    def preprocess_audio_prompt(self, audio_path: str):
        """Encode a reference recording into LM acoustic token ids.

        Reads the file, converts it to 24 kHz mono, SNAC-encodes it,
        interleaves the codes and shifts them by ``AUDIO_CODE_BASE_OFFSET``.

        Returns:
            A list of token ids, or ``[]`` when the file is empty or any step
            fails (best-effort: the error is printed, never raised).
        """
        try:
            wav_np, sr = sf.read(audio_path)
            print(f"[DEBUG] sf.read -> shape={getattr(wav_np, 'shape', None)}, sr={sr}")
            if wav_np.size == 0:
                print("[DEBUG] wav_np.size == 0 → returning []")
                return []

            wav = self._prepare_waveform(wav_np, sr)

            with torch.no_grad():
                codes = self.snac_model.encode(wav)  # list/tuple of 3 tensors
            print("[DEBUG] SNAC encode ok.")

            interleaved = self.interleave_codes(codes)
            print(f"[DEBUG] Interleaved tokens: {len(interleaved)}")

            # Move into the LM's acoustic code range.
            base = self.AUDIO_CODE_BASE_OFFSET
            return [tok + base for tok in interleaved]
        except Exception as e:
            print(f"⚠️ preprocess_audio_prompt error: {e}")
            return []

    # ----------------------------------------------------------------------
    # Interleave SNAC codes into 7-token pattern (with layer offsets)
    # ----------------------------------------------------------------------
    def interleave_codes(self, codes):
        """Interleave three SNAC code levels into 7-token frames.

        Frame layout is ``[L0, L1, L2, L2, L1, L2, L2]``; the token at frame
        position ``p`` is shifted by ``p * CODEBOOK_SIZE``. Only complete
        frames are emitted (each needs 1 / 2 / 4 codes from L0 / L1 / L2).

        Args:
            codes: Sequence of three objects supporting ``.flatten().tolist()``.

        Returns:
            Flat list of shifted codes (without ``AUDIO_CODE_BASE_OFFSET``).
        """
        l0 = codes[0].flatten().tolist()
        l1 = codes[1].flatten().tolist()
        l2 = codes[2].flatten().tolist()

        # Number of frames for which every level still has enough codes.
        n_frames = min(len(l0), len(l1) // 2, len(l2) // 4)
        size = self.CODEBOOK_SIZE

        interleaved = []
        for i in range(n_frames):
            frame = (
                l0[i],
                l1[2 * i],
                l2[4 * i],
                l2[4 * i + 1],
                l1[2 * i + 1],
                l2[4 * i + 2],
                l2[4 * i + 3],
            )
            interleaved.extend(code + pos * size for pos, code in enumerate(frame))
        return interleaved

    # ----------------------------------------------------------------------
    # Reverse interleave → SNAC decode
    # ----------------------------------------------------------------------
    def _deinterleave(self, code_list):
        """Split 7-token frames back into ``[level0, level1, level2]`` lists.

        Trailing tokens that do not form a complete frame are ignored. A frame
        in which any code falls outside ``[0, CODEBOOK_SIZE)`` after removing
        its positional offset is dropped, so that a misaligned or stray token
        cannot reach the decoder's codebook lookup.
        """
        size = self.CODEBOOK_SIZE
        step = self.TOKENS_PER_FRAME
        levels = [[], [], []]
        dropped = 0

        for start in range(0, len(code_list) - step + 1, step):
            frame = [
                tok - pos * size
                for pos, tok in enumerate(code_list[start:start + step])
            ]
            if any(code < 0 or code >= size for code in frame):
                dropped += 1
                continue
            levels[0].append(frame[0])
            levels[1].extend((frame[1], frame[4]))
            levels[2].extend((frame[2], frame[3], frame[5], frame[6]))

        if dropped:
            print(f"⚠️ Dropped {dropped} frame(s) with out-of-range codes.")
        return levels

    def redistribute_codes(self, code_list):
        """De-interleave acoustic codes and decode them with SNAC.

        Args:
            code_list: Flat list of codes already shifted down by
                ``AUDIO_CODE_BASE_OFFSET`` (positional offsets still applied).

        Returns:
            The decoded audio as a NumPy array, or ``None`` when no valid
            complete frame is available.
        """
        levels = self._deinterleave(code_list)
        if not all(levels):
            return None

        hierarchical_codes = [
            torch.tensor(lvl, dtype=torch.long, device=self.device).unsqueeze(0)
            for lvl in levels
        ]
        with torch.no_grad():
            audio_hat = self.snac_model.decode(hierarchical_codes)
        return audio_hat.cpu().squeeze().numpy()

    # ----------------------------------------------------------------------
    # Main synthesis entry point
    # ----------------------------------------------------------------------
    def synthesize(self, text: str, language_code: str, reference_audio_path: str):
        """Synthesize ``text`` and return ``(waveform, sample_rate)``.

        Args:
            text: Text to speak.
            language_code: Key into ``CODE_TO_NAME``; also the adapter name to
                activate when the model is a PEFT model that has it.
            reference_audio_path: Path of a recording to clone the voice from.
                If it is falsy or cannot be encoded, the default voice is used.

        Returns:
            ``(waveform, 24000)``; ``waveform`` is ``None`` when generation
            produced nothing decodable, even after retrying without cloning.
        """
        # Switch adapter if available
        if hasattr(self.model, "set_adapter") and hasattr(self.model, "peft_config"):
            if language_code in self.model.peft_config:
                self.model.set_adapter(language_code)
            else:
                print(
                    f"⚠️ Adapter '{language_code}' not found in peft_config; using default."
                )

        # NOTE(review): CODE_TO_NAME values are lowercase while this fallback
        # is capitalised — confirm which casing the model was trained with.
        full_lang_name = CODE_TO_NAME.get(language_code, "Hindi")
        voice_label = f"{full_lang_name} (Female)"
        formatted_text = f"<|audio|> {voice_label}: {text}<|eot_id|>"
        prompt = "<custom_token_3>" + formatted_text + "<custom_token_4><custom_token_5>"
        prompt_ids = self.tokenizer(
            prompt, return_tensors="pt"
        ).input_ids.to(self.device)

        # Special-token framing around the text prompt; the last end token is
        # START_OF_SPEECH_TOKEN. Presumably the others are turn delimiters of
        # the underlying chat format — TODO confirm.
        start_token = torch.tensor([[128259]], dtype=torch.int64).to(self.device)
        end_tokens = torch.tensor(
            [[128009, 128260, 128261, self.START_OF_SPEECH_TOKEN]],
            dtype=torch.int64,
        ).to(self.device)
        default_input_ids = torch.cat([start_token, prompt_ids, end_tokens], dim=1)

        # Voice cloning → reference audio tokens
        audio_prompt_ids = (
            self.preprocess_audio_prompt(reference_audio_path)
            if reference_audio_path
            else []
        )
        if audio_prompt_ids:
            print(f"🎤 Cloning with {len(audio_prompt_ids)} tokens...")
            audio_tensor = torch.tensor(
                [audio_prompt_ids], device=self.device, dtype=torch.long
            )
            input_ids = torch.cat(
                [start_token, audio_tensor, prompt_ids, end_tokens],
                dim=1,
            )
        else:
            print("⚠️ No valid reference audio. Using default voice.")
            input_ids = default_input_ids

        waveform = self._generate_waveform(input_ids)

        # Fallback: if cloning fails, retry with the default voice.
        if waveform is None and audio_prompt_ids:
            print("⚠️ Cloning generation failed. Retrying without cloning...")
            waveform = self._generate_waveform(default_input_ids)

        return waveform, self.SAMPLE_RATE

    # ----------------------------------------------------------------------
    # LM generate → parse acoustic tokens → SNAC decode
    # ----------------------------------------------------------------------
    def _generate_waveform(self, input_ids):
        """Run sampling generation and decode the resulting acoustic tokens.

        Everything after the last ``START_OF_SPEECH_TOKEN`` in the generated
        sequence is treated as audio; tokens outside the acoustic id range are
        discarded and the rest is trimmed to whole 7-token frames.

        Returns:
            The decoded waveform, or ``None`` when there is no start marker,
            no usable acoustic token, or any step raises (the error is
            printed, never propagated).
        """
        try:
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    max_new_tokens=800,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    repetition_penalty=1.2,
                    num_return_sequences=1,
                    eos_token_id=self.END_OF_SPEECH_TOKEN,
                )

            row = generated_ids[0]

            # Find last START_OF_SPEECH token
            starts = (row == self.START_OF_SPEECH_TOKEN).nonzero(as_tuple=True)[0]
            if len(starts) == 0:
                print("⚠️ No START_OF_SPEECH_TOKEN in generated output.")
                return None
            audio_tokens = row[starts[-1].item() + 1:]

            # Keep only acoustic code tokens. END_OF_SPEECH (128258) and the
            # pad id (128263) are below the base offset, so this mask removes
            # them as well.
            low = self.AUDIO_CODE_BASE_OFFSET
            high = low + self.TOKENS_PER_FRAME * self.CODEBOOK_SIZE - 1
            audio_tokens = audio_tokens[(audio_tokens >= low) & (audio_tokens <= high)]
            snac_tokens = [tok - low for tok in audio_tokens.tolist()]

            # Trim to a whole number of frames.
            usable = (len(snac_tokens) // self.TOKENS_PER_FRAME) * self.TOKENS_PER_FRAME
            snac_tokens = snac_tokens[:usable]
            if not snac_tokens:
                print("⚠️ No valid acoustic tokens after filtering.")
                return None

            return self.redistribute_codes(snac_tokens)
        except Exception as e:
            print(f"Generation Error: {e}")
            return None