# submission_svara_2 / inference.py
# fahin-one's picture
# Update inference.py
# d619cb8 verified
import torch
import torchaudio
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from snac import SNAC
# Language code -> lowercase language name. The name is substituted into the
# "<name> (Female)" voice label when the synthesis prompt is built.
CODE_TO_NAME = {
    "hi": "hindi",
    "bn": "bengali",
    "te": "telugu",
    "kn": "kannada",
    "bhj": "bhojpuri",
    "ch": "chhattisgarhi",
    "en": "english",
    "mth": "maithili",
    "mar": "marathi",
    "guj": "gujarati",
    "mgh": "magahi",
}
class SvaraEngine:
    """Text-to-speech engine: causal LM (+ optional LoRA adapters) and SNAC codec.

    The language model generates acoustic tokens; those tokens are grouped
    into 7-token frames, de-interleaved into three SNAC code levels and
    decoded to a waveform. A reference recording can be SNAC-encoded and
    prepended to the prompt for voice cloning.
    """

    # Sample rate of the "snac_24khz" codec; also returned by `synthesize`.
    SAMPLE_RATE = 24000
    # Width of each per-slot code range inside a frame.
    CODEBOOK_SIZE = 4096
    # Tokens per frame: [L0, L1, L2, L2, L1, L2, L2].
    FRAME_SIZE = 7
    # Offset that maps frame-local codes into the LM's token id space.
    AUDIO_CODE_BASE_OFFSET = 128266
    START_OF_SPEECH_TOKEN = 128257
    END_OF_SPEECH_TOKEN = 128258
    PAD_TOKEN = 128263

    def __init__(self, base_model_id, adapter_map, device: str = "cpu"):
        """Load tokenizer, base model, SNAC codec and any LoRA adapters.

        Args:
            base_model_id: Identifier/path passed to the tokenizer and
                causal-LM ``from_pretrained`` loaders.
            adapter_map: Mapping of adapter name (looked up later with the
                ``language_code`` given to `synthesize`) to adapter path.
                Falsy (``None``/empty) means the base model is used as is.
            device: ``"cpu"`` forces CPU; any other value selects CUDA when
                it is available and CPU otherwise.
        """
        if torch.cuda.is_available() and device != "cpu":
            self.device = "cuda"
        else:
            self.device = "cpu"
        print(f"🚀 Initializing SvaraEngine on {self.device} ...")

        # Tokenizer and base model (half precision only on GPU).
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch_dtype,
        ).to(self.device)

        # SNAC codec
        print("🔊 Loading SNAC codec ...")
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model = self.snac_model.eval().to(self.device)

        # LoRA adapters: the first one wraps the base model, the remaining
        # ones are attached to the resulting PEFT model under their own name.
        if adapter_map:
            print("🔗 Loading adapters ...")
            first_lang = next(iter(adapter_map))
            self.model = PeftModel.from_pretrained(
                self.base_model,
                adapter_map[first_lang],
                adapter_name=first_lang,
            )
            for lang, path in adapter_map.items():
                if lang != first_lang:
                    self.model.load_adapter(path, adapter_name=lang)
            print("✅ Loaded adapters:", list(adapter_map.keys()))
        else:
            self.model = self.base_model
            print("⚠️ No adapters provided. Running base model only.")

    # ----------------------------------------------------------------------
    # Reference audio → SNAC codes → LLM token space
    # ----------------------------------------------------------------------
    def preprocess_audio_prompt(self, audio_path: str):
        """Turn a reference recording into LM-space acoustic prompt tokens.

        The file is read with soundfile, resampled to 24 kHz, mixed down to
        mono, SNAC-encoded, interleaved into 7-token frames and shifted by
        ``AUDIO_CODE_BASE_OFFSET``.

        Args:
            audio_path: Path of the reference audio file.

        Returns:
            A list of int token ids. Empty when the file has no samples or
            when any step fails; this is deliberately best-effort so the
            caller can fall back to the default voice.
        """
        try:
            wav_np, sr = sf.read(audio_path)
            print(f"[DEBUG] sf.read -> shape={getattr(wav_np, 'shape', None)}, sr={sr}")
            if wav_np.size == 0:
                print("[DEBUG] wav_np.size == 0 → returning []")
                return []
            # To (C, T): a 1-D read is a single channel, a 2-D read is
            # transposed so that channels come first.
            wav = torch.tensor(wav_np, dtype=torch.float32)
            wav = wav.unsqueeze(0) if wav.ndim == 1 else wav.T
            return self._waveform_to_prompt_tokens(wav, sr)
        except Exception as e:
            print(f"⚠️ preprocess_audio_prompt error: {e}")
            return []

    def _waveform_to_prompt_tokens(self, wav, sr):
        """Encode a (C, T) float waveform at rate `sr` into prompt token ids."""
        wav = wav.to(self.device)
        # Resample to 24k
        if sr != self.SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sr, self.SAMPLE_RATE).to(
                self.device
            )
            wav = resampler(wav)
        # Force mono
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        # SNAC documents its input as (B, 1, T); add the batch dimension that
        # the (1, T) mono tensor is missing before encoding.
        with torch.no_grad():
            codes = self.snac_model.encode(wav.unsqueeze(0))
        print("[DEBUG] SNAC encode ok.")
        interleaved = self.interleave_codes(codes)
        print(f"[DEBUG] Interleaved tokens: {len(interleaved)}")
        # Add global base offset to move into LLM acoustic code range
        return [t + self.AUDIO_CODE_BASE_OFFSET for t in interleaved]

    # ----------------------------------------------------------------------
    # Interleave SNAC codes into 7-token pattern (with layer offsets)
    # ----------------------------------------------------------------------
    def interleave_codes(self, codes):
        """Interleave three SNAC code levels into flat 7-token frames.

        Frame ``i`` is ``[L0[i], L1[2i], L2[4i], L2[4i+1], L1[2i+1],
        L2[4i+2], L2[4i+3]]`` and slot ``k`` is shifted by ``k * 4096`` so
        that every slot occupies its own code range.

        Args:
            codes: Sequence of three tensors (levels 0, 1, 2); each is
                flattened before use.

        Returns:
            Flat list of ints, 7 per complete frame. A trailing frame for
            which a level runs out of codes is dropped.
        """
        l0 = codes[0].flatten().tolist()
        l1 = codes[1].flatten().tolist()
        l2 = codes[2].flatten().tolist()
        offsets = [slot * self.CODEBOOK_SIZE for slot in range(self.FRAME_SIZE)]
        # Only frames for which all three levels still have codes.
        n_frames = min(len(l0), len(l1) // 2, len(l2) // 4)
        interleaved = []
        for i in range(n_frames):
            frame_raw = (
                l0[i],
                l1[2 * i],
                l2[4 * i],
                l2[4 * i + 1],
                l1[2 * i + 1],
                l2[4 * i + 2],
                l2[4 * i + 3],
            )
            interleaved.extend(c + off for c, off in zip(frame_raw, offsets))
        return interleaved

    # ----------------------------------------------------------------------
    # Reverse interleave → SNAC decode (matching official notebook)
    # ----------------------------------------------------------------------
    def redistribute_codes(self, code_list):
        """De-interleave frame-local codes into SNAC levels and decode them.

        Args:
            code_list: Flat sequence of ints with ``AUDIO_CODE_BASE_OFFSET``
                already removed, 7 per frame. A trailing partial frame is
                ignored.

        Returns:
            The decoded audio as a squeezed NumPy array, or ``None`` when no
            usable frame remains.
        """
        size = self.CODEBOOK_SIZE
        width = self.FRAME_SIZE
        offsets = [slot * size for slot in range(width)]
        levels = ([], [], [])
        dropped = 0
        usable = len(code_list) - len(code_list) % width
        for start in range(0, usable, width):
            frame = [
                code_list[start + slot] - off for slot, off in enumerate(offsets)
            ]
            # A token that landed in the wrong slot yields an index outside
            # the codebook, which would make the SNAC decode fail; skip the
            # whole frame instead.
            if any(c < 0 or c >= size for c in frame):
                dropped += 1
                continue
            levels[0].append(frame[0])  # Level 0 (coarse)
            levels[1].extend((frame[1], frame[4]))  # Level 1 (medium)
            levels[2].extend((frame[2], frame[3], frame[5], frame[6]))  # Level 2 (fine)
        if dropped:
            print(f"⚠️ Dropped {dropped} frame(s) with out-of-range SNAC codes.")
        if not levels[0]:
            return None
        hierarchical_codes = [
            torch.tensor(lvl, dtype=torch.long, device=self.device).unsqueeze(0)
            for lvl in levels
        ]
        with torch.no_grad():
            audio_hat = self.snac_model.decode(hierarchical_codes)
        return audio_hat.cpu().squeeze().numpy()

    # ----------------------------------------------------------------------
    # Main synthesis entry point
    # ----------------------------------------------------------------------
    def synthesize(self, text: str, language_code: str, reference_audio_path: str):
        """Synthesize `text`, optionally cloning the voice of a reference file.

        Args:
            text: Text to speak.
            language_code: Adapter name to activate and key into
                ``CODE_TO_NAME`` for the voice label.
            reference_audio_path: Audio file used as the voice-cloning
                prompt; if it cannot be used, the default voice is used.

        Returns:
            ``(waveform, 24000)`` where ``waveform`` is a NumPy array, or
            ``None`` when generation produced no decodable audio.
        """
        # Switch adapter if available
        if hasattr(self.model, "set_adapter") and hasattr(self.model, "peft_config"):
            if language_code in self.model.peft_config:
                self.model.set_adapter(language_code)
            else:
                # NOTE(review): no adapter switch happens here, so whichever
                # adapter was active before this call stays selected.
                print(
                    f"⚠️ Adapter '{language_code}' not found in peft_config; using default."
                )

        # Build prompt as per official Svara notebook
        # NOTE(review): the fallback "Hindi" is capitalised while the values
        # in CODE_TO_NAME are lowercase — confirm which casing is expected.
        full_lang_name = CODE_TO_NAME.get(language_code, "Hindi")
        voice_label = f"{full_lang_name} (Female)"
        formatted_text = f"<|audio|> {voice_label}: {text}<|eot_id|>"
        prompt = "<custom_token_3>" + formatted_text + "<custom_token_4><custom_token_5>"
        prompt_ids = self.tokenizer(
            prompt, return_tensors="pt"
        ).input_ids.to(self.device)

        # Voice cloning → reference audio tokens
        audio_prompt_ids = self.preprocess_audio_prompt(reference_audio_path)
        if audio_prompt_ids:
            print(f"🎤 Cloning with {len(audio_prompt_ids)} tokens...")
        else:
            print("⚠️ No valid reference audio. Using default voice.")
        waveform = self._generate_waveform(
            self._build_input_ids(prompt_ids, audio_prompt_ids)
        )

        # Fallback: if cloning fails, retry default
        if waveform is None and audio_prompt_ids:
            print("⚠️ Cloning generation failed. Retrying without cloning...")
            waveform = self._generate_waveform(self._build_input_ids(prompt_ids))
        return waveform, self.SAMPLE_RATE

    def _build_input_ids(self, prompt_ids, audio_prompt_ids=None):
        """Concatenate start token, optional audio prompt, text prompt and end tokens."""
        start_token = torch.tensor([[128259]], dtype=torch.int64).to(self.device)
        end_tokens = torch.tensor(
            [[128009, 128260, 128261, self.START_OF_SPEECH_TOKEN]],
            dtype=torch.int64,
        ).to(self.device)
        parts = [start_token]
        if audio_prompt_ids:
            parts.append(
                torch.tensor([audio_prompt_ids], device=self.device, dtype=torch.long)
            )
        parts.extend((prompt_ids, end_tokens))
        return torch.cat(parts, dim=1)

    # ----------------------------------------------------------------------
    # LM generate → parse acoustic tokens → SNAC decode
    # ----------------------------------------------------------------------
    def _generate_waveform(self, input_ids):
        """Sample acoustic tokens from the LM and decode them to audio.

        Args:
            input_ids: Token id tensor of shape (1, N) as built for the
                prompt.

        Returns:
            Decoded audio as a NumPy array, or ``None`` when there is no
            start-of-speech token, no acoustic token survives filtering, or
            any step raises (the error is printed, not propagated).
        """
        try:
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=input_ids,
                    max_new_tokens=800,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    repetition_penalty=1.2,
                    num_return_sequences=1,
                    eos_token_id=self.END_OF_SPEECH_TOKEN,
                    attention_mask=torch.ones_like(input_ids),
                )
            base = self.AUDIO_CODE_BASE_OFFSET
            code_max = base + (self.FRAME_SIZE * self.CODEBOOK_SIZE) - 1
            row = generated_ids[0]

            # Find last START_OF_SPEECH token
            token_indices = (row == self.START_OF_SPEECH_TOKEN).nonzero(
                as_tuple=True
            )[0]
            if len(token_indices) == 0:
                print("⚠️ No START_OF_SPEECH_TOKEN in generated output.")
                return None
            audio_tokens = row[token_indices[-1].item() + 1:]

            # Keep only acoustic code tokens in range. END_OF_SPEECH and PAD
            # both lie below `base`, so this mask removes them as well.
            valid_mask = (audio_tokens >= base) & (audio_tokens <= code_max)
            snac_tokens = [t - base for t in audio_tokens[valid_mask].tolist()]

            # Trim to multiple of 7
            new_length = (len(snac_tokens) // self.FRAME_SIZE) * self.FRAME_SIZE
            snac_tokens = snac_tokens[:new_length]
            if not snac_tokens:
                print("⚠️ No valid acoustic tokens after filtering.")
                return None
            return self.redistribute_codes(snac_tokens)
        except Exception as e:
            print(f"Generation Error: {e}")
            return None