import torch import torchaudio import soundfile as sf from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel from snac import SNAC CODE_TO_NAME = { "hi": "hindi", "bn": "bengali", "te": "telugu", "kn": "kannada", "bhj": "bhojpuri", "ch": "chhattisgarhi", "en": "english", "mth": "maithili", "mar": "marathi", "guj": "gujarati", "mgh": "magahi" } class SvaraEngine: def __init__(self, base_model_id, adapter_map, device: str = "cpu"): # Decide device if torch.cuda.is_available() and device != "cpu": self.device = "cuda" else: self.device = "cpu" print(f"🚀 Initializing SvaraEngine on {self.device} ...") # Tokenizer and base model self.tokenizer = AutoTokenizer.from_pretrained(base_model_id) torch_dtype = torch.float16 if self.device == "cuda" else torch.float32 self.base_model = AutoModelForCausalLM.from_pretrained( base_model_id, torch_dtype=torch_dtype, ).to(self.device) # SNAC codec print("🔊 Loading SNAC codec ...") self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz") self.snac_model = self.snac_model.eval().to(self.device) # Load LoRA adapters if adapter_map: print("🔗 Loading adapters ...") first_lang = list(adapter_map.keys())[0] self.model = PeftModel.from_pretrained( self.base_model, adapter_map[first_lang], adapter_name=first_lang, ) for lang, path in adapter_map.items(): if lang == first_lang: continue self.model.load_adapter(path, adapter_name=lang) print("✅ Loaded adapters:", list(adapter_map.keys())) else: self.model = self.base_model print("⚠️ No adapters provided. Running base model only.") # ---------------------------------------------------------------------- # Reference audio → SNAC codes → LLM token space # ---------------------------------------------------------------------- def preprocess_audio_prompt(self, audio_path: str): """ Load reference WAV, resample to 24kHz mono, SNAC-encode, interleave with offsets, and add AUDIO_CODE_BASE_OFFSET. """ try: wav_np, sr = sf.read(audio_path) print(f"[DEBUG] sf.read -> shape={getattr(wav_np, 'shape', None)}, sr={sr}") if wav_np.size == 0: print("[DEBUG] wav_np.size == 0 → returning []") return [] # To (C, T) wav = torch.tensor(wav_np, dtype=torch.float32) if wav.ndim == 1: wav = wav.unsqueeze(0) # (1, T) else: wav = wav.T # (C, T) wav = wav.to(self.device) # Resample to 24k if sr != 24000: resampler = torchaudio.transforms.Resample(sr, 24000).to(self.device) wav = resampler(wav) # Force mono if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True) # SNAC encode with torch.no_grad(): codes = self.snac_model.encode(wav) # list/tuple of 3 tensors print(f"[DEBUG] SNAC encode ok.") # Interleave + offsets interleaved = self.interleave_codes(codes) print(f"[DEBUG] Interleaved tokens: {len(interleaved)}") # Add global base offset to move into LLM acoustic code range AUDIO_CODE_BASE_OFFSET = 128266 audio_prompt_tokens = [t + AUDIO_CODE_BASE_OFFSET for t in interleaved] return audio_prompt_tokens except Exception as e: print(f"⚠️ preprocess_audio_prompt error: {e}") return [] # ---------------------------------------------------------------------- # Interleave SNAC codes into 7-token pattern (with layer offsets) # ---------------------------------------------------------------------- def interleave_codes(self, codes): """ Interleave SNAC output into 7-token pattern: [L0, L1, L2, L2, L1, L2, L2] with distinct codebook offsets. """ l0 = codes[0].flatten().tolist() l1 = codes[1].flatten().tolist() l2 = codes[2].flatten().tolist() layer_offsets = [ 0, # pos 0 (L0) 4096, # pos 1 (L1) 8192, # pos 2 (L2) 12288, # pos 3 (L2) 16384, # pos 4 (L1) 20480, # pos 5 (L2) 24576, # pos 6 (L2) ] interleaved = [] for i in range(len(l0)): try: frame_raw = [ l0[i], l1[2 * i], l2[4 * i], l2[4 * i + 1], l1[2 * i + 1], l2[4 * i + 2], l2[4 * i + 3], ] frame_shifted = [ c + off for c, off in zip(frame_raw, layer_offsets) ] interleaved.extend(frame_shifted) except IndexError: # Reached end where one of the lists runs out break return interleaved # ---------------------------------------------------------------------- # Reverse interleave → SNAC decode (matching official notebook) # ---------------------------------------------------------------------- def redistribute_codes(self, code_list): """ De-interleave SNAC tokens into 3 hierarchical levels, using official llm_codebook_offsets. """ codes_lvl = [[] for _ in range(3)] llm_codebook_offsets = [i * 4096 for i in range(7)] for i in range(0, len(code_list), 7): if i + 6 >= len(code_list): break # Level 0 (coarse) codes_lvl[0].append(code_list[i] - llm_codebook_offsets[0]) # Level 1 (medium) codes_lvl[1].append(code_list[i + 1] - llm_codebook_offsets[1]) codes_lvl[1].append(code_list[i + 4] - llm_codebook_offsets[4]) # Level 2 (fine) codes_lvl[2].append(code_list[i + 2] - llm_codebook_offsets[2]) codes_lvl[2].append(code_list[i + 3] - llm_codebook_offsets[3]) codes_lvl[2].append(code_list[i + 5] - llm_codebook_offsets[5]) codes_lvl[2].append(code_list[i + 6] - llm_codebook_offsets[6]) hierarchical_codes = [] for lvl_codes in codes_lvl: if not lvl_codes: return None tensor = torch.tensor( lvl_codes, dtype=torch.long, device=self.device ).unsqueeze(0) hierarchical_codes.append(tensor) with torch.no_grad(): audio_hat = self.snac_model.decode(hierarchical_codes) return audio_hat.cpu().squeeze().numpy() # ---------------------------------------------------------------------- # Main synthesis entry point # ---------------------------------------------------------------------- def synthesize(self, text: str, language_code: str, reference_audio_path: str): # Switch adapter if available if hasattr(self.model, "set_adapter") and hasattr(self.model, "peft_config"): if language_code in self.model.peft_config: self.model.set_adapter(language_code) else: print( f"⚠️ Adapter '{language_code}' not found in peft_config; using default." ) # Build prompt as per official Svara notebook full_lang_name = CODE_TO_NAME.get(language_code, "Hindi") voice_label = f"{full_lang_name} (Female)" formatted_text = f"<|audio|> {voice_label}: {text}<|eot_id|>" prompt = "" + formatted_text + "" prompt_ids = self.tokenizer( prompt, return_tensors="pt" ).input_ids.to(self.device) start_token = torch.tensor([[128259]], dtype=torch.int64).to(self.device) end_tokens = torch.tensor( [[128009, 128260, 128261, 128257]], dtype=torch.int64, ).to(self.device) # Voice cloning → reference audio tokens audio_prompt_ids = self.preprocess_audio_prompt(reference_audio_path) if audio_prompt_ids: print(f"🎤 Cloning with {len(audio_prompt_ids)} tokens...") audio_tensor = torch.tensor( [audio_prompt_ids], device=self.device, dtype=torch.long ) input_ids = torch.cat( [start_token, audio_tensor, prompt_ids, end_tokens], dim=1, ) else: print("⚠️ No valid reference audio. Using default voice.") input_ids = torch.cat( [start_token, prompt_ids, end_tokens], dim=1, ) waveform = self._generate_waveform(input_ids) # Fallback: if cloning fails, retry default if waveform is None and audio_prompt_ids: print("⚠️ Cloning generation failed. Retrying without cloning...") input_ids = torch.cat( [start_token, prompt_ids, end_tokens], dim=1, ) waveform = self._generate_waveform(input_ids) return waveform, 24000 # ---------------------------------------------------------------------- # LM generate → parse acoustic tokens → SNAC decode # ---------------------------------------------------------------------- def _generate_waveform(self, input_ids): try: with torch.no_grad(): generated_ids = self.model.generate( input_ids=input_ids, max_new_tokens=800, do_sample=True, temperature=0.7, top_p=0.95, repetition_penalty=1.2, num_return_sequences=1, eos_token_id=128258, attention_mask=torch.ones_like(input_ids), ) START_OF_SPEECH_TOKEN = 128257 END_OF_SPEECH_TOKEN = 128258 AUDIO_CODE_BASE_OFFSET = 128266 AUDIO_CODE_MAX = AUDIO_CODE_BASE_OFFSET + (7 * 4096) - 1 PAD_TOKEN = 128263 row = generated_ids[0] # Find last START_OF_SPEECH token token_indices = (row == START_OF_SPEECH_TOKEN).nonzero(as_tuple=True)[0] if len(token_indices) == 0: print("⚠️ No START_OF_SPEECH_TOKEN in generated output.") return None start_idx = token_indices[-1].item() + 1 audio_tokens = row[start_idx:] # Remove END_OF_SPEECH and PAD audio_tokens = audio_tokens[audio_tokens != END_OF_SPEECH_TOKEN] audio_tokens = audio_tokens[audio_tokens != PAD_TOKEN] # Keep only acoustic code tokens in range valid_mask = (audio_tokens >= AUDIO_CODE_BASE_OFFSET) & ( audio_tokens <= AUDIO_CODE_MAX ) audio_tokens = audio_tokens[valid_mask] snac_tokens = [t - AUDIO_CODE_BASE_OFFSET for t in audio_tokens.tolist()] # Trim to multiple of 7 new_length = (len(snac_tokens) // 7) * 7 snac_tokens = snac_tokens[:new_length] if not snac_tokens: print("⚠️ No valid acoustic tokens after filtering.") return None return self.redistribute_codes(snac_tokens) except Exception as e: print(f"Generation Error: {e}") return None