import os
import time
import tempfile
import traceback
import librosa
import torch
import argparse
import soundfile as sf
import cn2an
import requests
import re
import numpy as np
import onnxruntime as ort
import axengine as axe
from model import SinusoidalPositionEncoder
from utils.ax_model_bin import AX_SenseVoiceSmall
from utils.ax_vad_bin import AX_Fsmn_vad
from utils.vad_utils import merge_vad
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
from libmelotts.python.split_utils import split_sentence
from libmelotts.python.text import cleaned_text_to_sequence
from libmelotts.python.text.cleaner import clean_text
from libmelotts.python.symbols import LANG_TO_SYMBOL_MAP

# Configuration
TTS_MODEL_DIR = "libmelotts/models"
TTS_MODEL_FILES = {
    "g": "g-zh_mix_en.bin",
    "encoder": "encoder-zh.onnx",
    "decoder": "decoder-zh.axmodel",
}
QWEN_API_URL = ""

# Output samples produced by the TTS decoder per z_p frame; the slicing and
# overlap-trimming math in SpeechTranslationPipeline relies on this factor.
TTS_SAMPLES_PER_FRAME = 512

# IPA / stress / diacritic symbols that are always removed before phones are
# mapped to ids, regardless of whether the symbol table contains them.
_DROPPED_PHONES = frozenset({
    'ɛ', 'æ', 'ʌ', 'ʊ', 'ɔ', 'ɪ', 'ɝ', 'ɚ', 'ɑ', 'ʒ', 'θ', 'ð', 'ŋ', 'ʃ',
    'ʧ', 'ʤ', 'ː', 'ˈ', 'ˌ', 'ʰ', 'ʲ', 'ʷ', 'ʔ', 'ɾ', 'ɹ', 'ɫ', 'ɡ',
})


def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after every element.

    >>> intersperse([1, 2], 0)
    [0, 1, 0, 2, 0]
    """
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
    """Convert text into the id arrays consumed by the MeloTTS encoder.

    Phones listed in ``_DROPPED_PHONES`` or absent from ``symbol_to_id`` are
    filtered out so the phone, tone and word2ph arrays stay consistent.

    Args:
        text: Sentence to synthesize.
        language_str: MeloTTS language key (e.g. "ZH_MIX_EN").
        symbol_to_id: Mapping from phone symbol to integer id. Required,
            despite the ``None`` default kept for interface compatibility.

    Returns:
        ``(phone, tone, language, norm_text, word2ph)``. ``phone``, ``tone``
        and ``language`` are int32 arrays with a 0 interspersed around every
        token. ``word2ph`` is an int32 array whose counts are doubled, with
        one added to the first entry, to match that interspersed length.

    Raises:
        ValueError: if ``symbol_to_id`` is None.
        Any exception raised during text processing is logged and re-raised.
    """
    if symbol_to_id is None:
        raise ValueError("symbol_to_id is required")
    try:
        norm_text, phone, tone, word2ph = clean_text(text, language_str)

        processed_phone = []
        processed_tone = []
        removed_symbols = set()
        for p, t in zip(phone, tone):
            if p not in _DROPPED_PHONES and p in symbol_to_id:
                processed_phone.append(p)
                processed_tone.append(t)
            else:
                removed_symbols.add(p)

        if removed_symbols:
            print(f"[音素过滤] 删除了 {len(removed_symbols)} 个特殊音素")

        if not processed_phone:
            print("[警告] 没有有效音素,使用默认中文音素")
            processed_phone = ['ni', 'hao']
            # Tones must be ints like the ones clean_text returns;
            # cleaned_text_to_sequence presumably adds an integer offset to
            # each (as upstream MeloTTS does).
            # NOTE(review): 'ni'/'hao' must exist in symbol_to_id — confirm.
            processed_tone = [1, 3]
            word2ph = [1, 1]

        if len(processed_phone) != len(phone):
            # Word alignment is lost after filtering; fall back to one phone
            # per "word" so word2ph still sums to the phone count.
            word2ph = [1] * len(processed_phone)

        phone, tone, language = cleaned_text_to_sequence(
            processed_phone, processed_tone, language_str, symbol_to_id)

        phone = np.array(intersperse(phone, 0), dtype=np.int32)
        tone = np.array(intersperse(tone, 0), dtype=np.int32)
        language = np.array(intersperse(language, 0), dtype=np.int32)
        # Interspersing doubles every word's phone count and adds one leading
        # blank, which is credited to the first word.
        word2ph = np.array(word2ph, dtype=np.int32) * 2
        word2ph[0] += 1
        return phone, tone, language, norm_text, word2ph
    except Exception as e:
        print(f"[错误] 文本处理失败: {e}")
        traceback.print_exc()
        raise


def audio_numpy_concat(segment_data_list, sr, speed=1.):
    """Join audio segments with a silence gap between consecutive ones.

    The gap is ``int(sr * 0.05 / speed)`` samples (50 ms at speed 1.0).
    Segments of any shape are flattened before being copied.

    Returns:
        A 1-D float32 array; empty when ``segment_data_list`` is empty.
    """
    if not segment_data_list:
        return np.array([], dtype=np.float32)

    flat_segments = [np.asarray(seg).reshape(-1) for seg in segment_data_list]
    pause_samples = int((sr * 0.05) / speed)
    total_len = (sum(seg.size for seg in flat_segments)
                 + pause_samples * (len(flat_segments) - 1))

    audio_segments = np.zeros(total_len, dtype=np.float32)
    current_pos = 0
    for seg in flat_segments:
        audio_segments[current_pos:current_pos + seg.size] = seg
        # The buffer is zero-initialised, so skipping ahead leaves silence.
        current_pos += seg.size + pause_samples
    return audio_segments


def merge_sub_audio(sub_audio_list, pad_size, audio_len):
    """Concatenate decoded sub-segments and truncate to ``audio_len`` samples.

    When ``pad_size > 0``, the last ``pad_size`` samples of each segment are
    averaged with the first ``pad_size`` samples of the next one. That shared
    leading region is then dropped from every segment after the first.
    Elements of ``sub_audio_list`` are modified / replaced in place.
    """
    if pad_size > 0:
        for i in range(len(sub_audio_list) - 1):
            sub_audio_list[i][-pad_size:] += sub_audio_list[i + 1][:pad_size]
            sub_audio_list[i][-pad_size:] /= 2
            if i > 0:
                sub_audio_list[i] = sub_audio_list[i][pad_size:]
        # The loop above stops before the last index, but the last segment's
        # leading pad was also averaged into its predecessor: drop it too.
        if len(sub_audio_list) > 1:
            sub_audio_list[-1] = sub_audio_list[-1][pad_size:]
    sub_audio = np.concatenate(sub_audio_list, axis=-1)
    return sub_audio[:audio_len]


def calc_word2pronoun(word2ph, pronoun_lens):
    """Sum ``pronoun_lens`` over each word's phone span.

    ``word2ph[k]`` is the number of consecutive phones belonging to word ``k``.
    The result lists, per word, the total of ``pronoun_lens`` over those phones.
    """
    indice = [0]
    for ph in word2ph[:-1]:
        indice.append(indice[-1] + ph)
    return [np.sum(pronoun_lens[i:i + ph]) for i, ph in zip(indice, word2ph)]


def generate_slices(word2pronoun, dec_len):
    """Split a word sequence into windows that each fit into ``dec_len`` frames.

    Words are packed greedily. When starting a new window, if the previous
    window holds more than two words and its last two words plus the next word
    still fit in ``dec_len``, those two words are repeated at the start of the
    new window. Adjacent windows therefore overlap, and the caller trims the
    overlap after decoding.

    Returns:
        ``(pn_slices, zp_slices)``: per window, a slice into the word sequence
        and the matching slice into the frame sequence.

    Raises:
        ValueError: if a single word needs more than ``dec_len`` frames (it
            can never be placed, so the packing would not advance).
    """
    pn_start, pn_end = 0, 0
    zp_start, zp_end = 0, 0
    zp_len = 0
    pn_slices = []
    zp_slices = []
    while pn_end < len(word2pronoun):
        if (pn_end - pn_start > 2
                and np.sum(word2pronoun[pn_end - 2:pn_end + 1]) <= dec_len):
            # Re-use the last two words of the previous window as overlap.
            zp_len = np.sum(word2pronoun[pn_end - 2:pn_end])
            zp_start = zp_end - zp_len
            pn_start = pn_end - 2
        else:
            zp_len = 0
            zp_start = zp_end
            pn_start = pn_end

        prev_end = pn_end
        while pn_end < len(word2pronoun) and zp_len + word2pronoun[pn_end] <= dec_len:
            zp_len += word2pronoun[pn_end]
            pn_end += 1
        if pn_end == prev_end:
            raise ValueError(
                f"Word {pn_end} needs {word2pronoun[pn_end]} frames, "
                f"more than dec_len={dec_len}")

        zp_end = zp_start + zp_len
        pn_slices.append(slice(pn_start, pn_end))
        zp_slices.append(slice(zp_start, zp_end))
    return pn_slices, zp_slices


def lang_detect_with_regex(text):
    """Classify text as 'chinese', 'english' or 'unknown'.

    Digits are ignored. Any CJK unified ideograph (U+4E00-U+9FFF) yields
    'chinese'; otherwise any ASCII letter yields 'english'.
    """
    text_without_digits = re.sub(r'\d+', '', text)
    if not text_without_digits:
        return 'unknown'
    if re.search(r'[\u4e00-\u9fff]', text_without_digits):
        return 'chinese'
    if re.search(r'[a-zA-Z]', text_without_digits):
        return 'english'
    return 'unknown'


class QwenTranslationAPI:
    """Minimal HTTP client for the Qwen generation service.

    A request is started with POST ``{api_url}/api/generate``. Output is then
    polled from GET ``{api_url}/api/generate_provider`` until the response
    reports ``done``.
    """

    def __init__(self, api_url=QWEN_API_URL):
        self.api_url = api_url
        self.session_id = f"speech_translate_{int(time.time())}"

    def reset_context(self):
        """POST ``/api/reset``; return True on HTTP 200, otherwise False."""
        try:
            reset_url = f"{self.api_url}/api/reset"
            response = requests.post(reset_url, json={}, timeout=5)
            if response.status_code == 200:
                print("[API] 上下文重置成功")
                return True
            print(f"[API] 重置失败,状态码: {response.status_code}")
        except Exception as e:
            print(f"[API] 重置上下文失败: {e}")
        return False

    def translate(self, text_content, max_retries=3, timeout=120):
        """Send ``text_content`` to the model and collect the streamed answer.

        Args:
            text_content: User text to answer.
            max_retries: Number of generate/poll attempts.
            timeout: Seconds to keep polling within one attempt.

        Returns:
            The generated text on success; otherwise one of the status strings
            "输入文本为空", "失败: <error>" or "超时,请检查API服务状态".
        """
        if not text_content or text_content.strip() == "":
            return "输入文本为空"

        # The same prompt prefix is used regardless of the input language.
        prompt_f = "回答(限制在100个字以内)"
        prompt = f"{prompt_f}:{text_content}"
        print(f"[API] 发送请求: {prompt}")

        generate_url = f"{self.api_url}/api/generate"
        result_url = f"{self.api_url}/api/generate_provider"
        payload = {
            "prompt": prompt,
            "temperature": 0.1,
            "repetition_penalty": 1.0,
            "top-p": 0.9,
            "top-k": 40,
            "max_new_tokens": 512,
        }

        for attempt in range(max_retries):
            try:
                print(f"[API] 开始生成请求 (尝试 {attempt + 1}/{max_retries})")
                response = requests.post(generate_url, json=payload, timeout=30)
                response.raise_for_status()
                print("[API] 生成请求成功")

                start_time = time.time()
                full_translation = ""
                error_detected = False
                while time.time() - start_time < timeout:
                    try:
                        result_response = requests.get(result_url, timeout=10)
                        result_data = result_response.json()
                    except (requests.exceptions.RequestException, ValueError) as e:
                        print(f"[API] 轮询请求失败: {e}")
                        # Back off briefly instead of hammering an unreachable server.
                        time.sleep(0.2)
                        continue

                    current_chunk = result_data.get("response", "")
                    lowered = current_chunk.lower()
                    if "error:" in lowered or "setkvcache failed" in lowered:
                        print(f"[API] 检测到错误: {current_chunk}")
                        error_detected = True
                        self.reset_context()
                        break

                    full_translation += current_chunk
                    if result_data.get("done", False):
                        print(f"[API] 完成: {full_translation}")
                        return full_translation
                    time.sleep(0.05)

                if error_detected:
                    if attempt < max_retries - 1:
                        print("[API] 等待1秒后重试...")
                        time.sleep(1)
                    continue
                print(f"[API] 轮询超时,尝试第 {attempt + 1} 次重试")

            except requests.exceptions.RequestException as e:
                print(f"[API] 请求失败 (尝试 {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt
                    print(f"[API] 等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
                else:
                    return f"失败: {str(e)}"
            except Exception as e:
                print(f"[API] 过程出错: {e}")
                return f"失败: {str(e)}"
        return "超时,请检查API服务状态"


class SpeechTranslationPipeline:
    """VAD + ASR (SenseVoice) -> Qwen API -> TTS (MeloTTS) pipeline."""

    def __init__(self, tts_model_dir, tts_model_files, asr_model_dir="ax_model",
                 seq_len=132, tts_dec_len=128, sample_rate=44100, tts_speed=0.8,
                 qwen_api_url=QWEN_API_URL):
        self.tts_model_dir = tts_model_dir
        self.tts_model_files = tts_model_files
        self.asr_model_dir = asr_model_dir
        self.seq_len = seq_len
        self.tts_dec_len = tts_dec_len
        self.sample_rate = sample_rate
        self.tts_speed = tts_speed
        self.qwen_api_url = qwen_api_url

        # Validate before the slow model loading so a missing file fails fast
        # with a clear message instead of an error from the model loader.
        self._validate_files()
        self._init_asr_models()
        self._init_tts_models()
        self.translator = QwenTranslationAPI(api_url=qwen_api_url)

    def _init_asr_models(self):
        """Load the VAD, SenseVoice ASR model, position encoding and tokenizer."""
        print("Initializing SenseVoice models...")
        self.model_vad = AX_Fsmn_vad(self.asr_model_dir)
        self.embed = SinusoidalPositionEncoder()
        self.position_encoding = self.embed.get_position_encoding(
            torch.randn(1, self.seq_len, 560)).numpy()
        self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)
        tokenizer_path = os.path.join(self.asr_model_dir,
                                      "chn_jpn_yue_eng_ko_spectok.bpe.model")
        self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)
        print("SenseVoice models initialized successfully.")

    def _init_tts_models(self):
        """Load the TTS encoder/decoder, speaker embedding and symbol table."""
        print("Initializing MeloTTS models...")
        init_start = time.time()

        enc_model = os.path.join(self.tts_model_dir, self.tts_model_files["encoder"])
        dec_model = os.path.join(self.tts_model_dir, self.tts_model_files["decoder"])
        self.sess_enc = ort.InferenceSession(enc_model,
                                             providers=["CPUExecutionProvider"],
                                             sess_options=ort.SessionOptions())
        self.sess_dec = axe.InferenceSession(dec_model)

        g_file = os.path.join(self.tts_model_dir, self.tts_model_files["g"])
        self.tts_g = np.fromfile(g_file, dtype=np.float32).reshape(1, 256, 1)
        self.tts_language = "ZH_MIX_EN"
        self.symbol_to_id = {s: i for i, s in
                             enumerate(LANG_TO_SYMBOL_MAP[self.tts_language])}

        # Run text processing once up front so lazily loaded modules
        # (language models, tokenizers, ...) are loaded now rather than on the
        # first real request; this is the main cost of initialization.
        print(" Warming up TTS modules (loading language models, tokenizers, etc.)...")
        # Mixed Chinese/English warm-up sentence.
        try:
            warmup_start_mix = time.time()
            warmup_text_mix = "这是一个test测试。"
            get_text_for_tts_infer(warmup_text_mix, self.tts_language,
                                   symbol_to_id=self.symbol_to_id)
            print(f" Mixed ZH-EN warm-up: {(time.time() - warmup_start_mix)*1000:.2f}ms")
        except Exception as e:
            print(f" Warning: Mixed warm-up failed: {e}")

        total_init_time = (time.time() - init_start) * 1000
        print(f"MeloTTS models initialized successfully. Total init time: "
              f"{total_init_time:.2f}ms ({total_init_time/1000:.2f}s)")

    def _validate_files(self):
        """Check that all TTS model files exist and probe the Qwen API.

        Raises:
            FileNotFoundError: if any configured TTS model file is missing.
        An unreachable API only produces a warning.
        """
        for filename in self.tts_model_files.values():
            filepath = os.path.join(self.tts_model_dir, filename)
            if not os.path.exists(filepath):
                raise FileNotFoundError(f"TTS模型文件不存在: {filepath}")
        try:
            requests.get(f"{self.qwen_api_url}/api/generate_provider", timeout=5)
            print("[API检查] 千问API服务连接正常")
        except requests.exceptions.RequestException:
            print("[API警告] 无法连接到千问API服务")

    def speech_recognition(self, speech, fs):
        """Step 1: run VAD, then ASR on each voiced segment.

        Args:
            speech: 1-D waveform.
            fs: Sample rate of ``speech`` in Hz.

        Returns:
            Concatenated recognition text of all segments, stripped. Segments
            whose recognition raises are logged and skipped.
        """
        speech_lengths = len(speech)

        print("Running VAD...")
        vad_start_time = time.time()
        res_vad = self.model_vad(speech)[0]
        vad_segments = merge_vad(res_vad, 15 * 1000)
        vad_time_cost = time.time() - vad_start_time
        print(f"VAD processing time: {vad_time_cost:.2f} seconds")
        print(f"VAD segments detected: {len(vad_segments)}")

        print("Running ASR...")
        asr_start_time = time.time()
        segment_texts = []
        # A private temp dir avoids clobbering files in (or between runs
        # sharing) the current working directory.
        with tempfile.TemporaryDirectory(prefix="asr_segments_") as tmp_dir:
            for i, (segment_start, segment_end) in enumerate(vad_segments):
                # VAD boundaries are in milliseconds.
                start_sample = int(segment_start / 1000 * fs)
                end_sample = min(int(segment_end / 1000 * fs), speech_lengths)
                segment_speech = speech[start_sample:end_sample]
                if len(segment_speech) == 0:
                    print(f"Skipping empty segment {i}")
                    continue

                segment_filename = os.path.join(tmp_dir, f"temp_segment_{i}.wav")
                sf.write(segment_filename, segment_speech, fs)
                try:
                    segment_res = self.model_bin(
                        segment_filename,
                        "auto",
                        True,
                        self.position_encoding,
                        tokenizer=self.tokenizer,
                    )
                except Exception as e:
                    print(f"Error processing segment {i}: {e}")
                    continue
                finally:
                    if os.path.exists(segment_filename):
                        os.remove(segment_filename)
                segment_texts.append(segment_res)

        all_results = "".join(segment_texts)
        asr_time_cost = time.time() - asr_start_time
        print(f"ASR processing time: {asr_time_cost:.2f} seconds")
        print(f"ASR Result: {all_results}")
        return all_results.strip()

    def run_translation(self, text_content):
        """Step 2: send the recognized text to the Qwen API and return its answer."""
        print("Starting translation via API...")
        translation_start_time = time.time()
        translate_content = self.translator.translate(text_content)
        translation_time_cost = time.time() - translation_start_time
        print(f"Translation processing time: {translation_time_cost:.2f} seconds")
        print(f"Translation Result: {translate_content}")
        return translate_content

    def _synthesize_sentence(self, sentence, n):
        """Encode one sentence, decode it in ``tts_dec_len`` windows, return the waveform."""
        phones, tones, lang_ids, norm_text, word2ph = get_text_for_tts_infer(
            sentence, self.tts_language, symbol_to_id=self.symbol_to_id)

        encoder_start = time.time()
        z_p, pronoun_lens, audio_len = self.sess_enc.run(None, input_feed={
            'phone': phones, 'g': self.tts_g,
            'tone': tones, 'language': lang_ids,
            'noise_scale': np.array([0], dtype=np.float32),
            'length_scale': np.array([1.0 / self.tts_speed], dtype=np.float32),
            'noise_scale_w': np.array([0], dtype=np.float32),
            'sdp_ratio': np.array([0], dtype=np.float32)})
        encoder_time = time.time() - encoder_start
        print(f"Encoder run time: {encoder_time*1000:.2f}ms")

        word2pronoun = calc_word2pronoun(word2ph, pronoun_lens)
        pn_slices, zp_slices = generate_slices(word2pronoun, self.tts_dec_len)
        hop = TTS_SAMPLES_PER_FRAME

        sub_audio_list = []
        for i, (ps, zs) in enumerate(zip(pn_slices, zp_slices)):
            zp_slice = z_p[..., zs]
            sub_audio_len = hop * zp_slice.shape[-1]
            # The decoder expects exactly tts_dec_len frames: zero-pad short windows.
            pad_len = self.tts_dec_len - zp_slice.shape[-1]
            if pad_len > 0:
                zp_slice = np.concatenate(
                    (zp_slice, np.zeros((*zp_slice.shape[:-1], pad_len), dtype=np.float32)),
                    axis=-1)

            audio = self.sess_dec.run(
                None, input_feed={"z_p": zp_slice, "g": self.tts_g})[0].flatten()

            # generate_slices repeats words between neighbouring windows. Drop
            # the first shared word from this window and the last shared word
            # from the previous window, so each word is emitted exactly once.
            audio_start = 0
            if i > 0 and pn_slices[i - 1].stop > ps.start:
                audio_start = hop * word2pronoun[ps.start]
            audio_end = sub_audio_len
            if i < len(pn_slices) - 1 and ps.stop > pn_slices[i + 1].start:
                audio_end = sub_audio_len - hop * word2pronoun[ps.stop - 1]
            sub_audio_list.append(audio[audio_start:audio_end])

        merge_start = time.time()
        sub_audio = merge_sub_audio(sub_audio_list, 0, audio_len[0])
        merge_time = time.time() - merge_start
        print(f"Sentence[{n}] merge time: {merge_time*1000:.2f}ms")
        return sub_audio

    def run_tts(self, translate_content, output_dir, output_wav=None):
        """Step 3: synthesize ``translate_content`` to a wav file.

        Each sentence is also saved under ``<output_dir>/segments``.

        Args:
            translate_content: Text to speak.
            output_dir: Output directory; "." when None/empty.
            output_wav: Output file name; "tts_output.wav" when None/empty.

        Returns:
            Path of the written wav file.

        Raises:
            Any synthesis error is logged and re-raised.
        """
        output_dir = output_dir or "."
        output_wav = output_wav or "tts_output.wav"
        output_path = os.path.join(output_dir, output_wav)
        output_wav_name = os.path.splitext(output_wav)[0]

        try:
            if lang_detect_with_regex(translate_content) == "chinese":
                # Spell out Arabic numerals as Chinese numerals.
                translate_content = cn2an.transform(translate_content, "an2cn")
            print(f"TTS synthesis for text: {translate_content}")

            sens = split_sentence(translate_content, language_str=self.tts_language)
            print(f"Text split into {len(sens)} sentences")

            segments_dir = os.path.join(output_dir, "segments")
            os.makedirs(segments_dir, exist_ok=True)

            audio_list = []
            for n, se in enumerate(sens):
                if self.tts_language in ('EN', 'ZH_MIX_EN'):
                    # Split camelCase words ("helloWorld" -> "hello World").
                    se = re.sub(r'([a-z])([A-Z])', r'\1 \2', se)
                print(f"Processing sentence[{n}]: {se}")

                sub_audio = self._synthesize_sentence(se, n)

                segment_filename = os.path.join(
                    segments_dir, f"{output_wav_name}_sentence_{n:03d}.wav")
                sf.write(segment_filename, sub_audio, self.sample_rate)
                print(f"Saved segment audio: {segment_filename}")
                audio_list.append(sub_audio)

            concat_start = time.time()
            audio = audio_numpy_concat(audio_list, sr=self.sample_rate, speed=self.tts_speed)
            concat_time = time.time() - concat_start
            print(f"Audio concatenation time: {concat_time*1000:.2f}ms")

            sf.write(output_path, audio, self.sample_rate)
            print(f"TTS audio saved to {output_path}")
            return output_path
        except Exception as e:
            print(f"TTS synthesis failed: {e}")
            traceback.print_exc()
            raise

    def full_pipeline(self, speech, fs, output_dir=None, output_tts=None):
        """Run ASR -> Qwen -> TTS.

        Returns:
            dict with "original_text", "translated_text" and "audio_path".

        Raises:
            ValueError: if ASR produces no text.
        """
        print("\n----------------------VAD+ASR----------------------------\n")
        start_time = time.time()
        text_content = self.speech_recognition(speech, fs)
        asr_time = time.time() - start_time
        print(f"语音识别耗时: {asr_time:.2f} 秒")
        if not text_content or text_content.strip() == "":
            raise ValueError("ASR未能识别出有效文本")

        print("\n---------------------Qwen---------------------------\n")
        start_time = time.time()
        translate_content = self.run_translation(text_content)
        translate_time = time.time() - start_time
        print(f"qwen耗时: {translate_time:.2f} 秒")

        print("-------------------------TTS-------------------------------\n")
        start_time = time.time()
        output_path = self.run_tts(translate_content, output_dir, output_tts)
        tts_time = time.time() - start_time
        print(f"TTS合成耗时: {tts_time:.2f} 秒")

        return {
            "original_text": text_content,
            "translated_text": translate_content,
            "audio_path": output_path,
        }


def main():
    """Batch-process every .wav/.mp3 in --audio_dir and write a summary file."""
    parser = argparse.ArgumentParser(description="Speech Recognition, Translation and TTS Pipeline")
    parser.add_argument("--audio_dir", type=str, default="./input_question", help="Input audio directory path")
    parser.add_argument("--output_dir", type=str, default="./output_answer", help="Output directory")
    parser.add_argument("--api_url", type=str, default="http://10.126.29.158:8000", help="Qwen API server URL")
    args = parser.parse_args()

    print("-------------------START------------------------\n")
    os.makedirs(args.output_dir, exist_ok=True)

    if not os.path.exists(args.audio_dir):
        print(f"错误: 音频目录不存在: {args.audio_dir}")
        return

    audio_files = sorted(
        os.path.join(args.audio_dir, name)
        for name in os.listdir(args.audio_dir)
        if name.lower().endswith(('.wav', '.mp3'))
    )
    if not audio_files:
        print(f"错误: 在目录 {args.audio_dir} 中没有找到音频文件")
        return
    print(f"找到 {len(audio_files)} 个音频文件: {[os.path.basename(f) for f in audio_files]}")

    pipeline = SpeechTranslationPipeline(
        tts_model_dir=TTS_MODEL_DIR,
        tts_model_files=TTS_MODEL_FILES,
        asr_model_dir="ax_model",
        seq_len=132,
        tts_dec_len=128,
        sample_rate=44100,
        tts_speed=0.8,
        qwen_api_url=args.api_url,
    )

    all_results = []
    total_start_time = time.time()
    for i, audio_file in enumerate(audio_files):
        print(f"\n{'='*60}")
        print(f"处理第 {i+1}/{len(audio_files)} 个音频文件: {os.path.basename(audio_file)}")
        print(f"{'='*60}")

        file_start_time = time.time()
        try:
            speech, fs = librosa.load(audio_file, sr=None)
            if fs != 16000:
                print(f"重采样音频从 {fs}Hz 到 16000Hz")
                speech = librosa.resample(y=speech, orig_sr=fs, target_sr=16000)
                fs = 16000

            base_name = os.path.splitext(os.path.basename(audio_file))[0]
            output_tts = f"{base_name}_answer.wav"
            result = pipeline.full_pipeline(speech, fs, args.output_dir, output_tts)
            file_time_cost = time.time() - file_start_time

            # Read only the header for the duration instead of decoding the file.
            output_duration = sf.info(result["audio_path"]).duration
            # An empty synthesis would otherwise divide by zero.
            rtf = file_time_cost / output_duration if output_duration > 0 else float("inf")

            result.update({
                "audio_file": audio_file,
                "processing_time": file_time_cost,
                "output_duration": output_duration,
                "rtf": rtf,
            })
            all_results.append(result)

            print(f"\n文件处理完成: {os.path.basename(audio_file)}")
            print(f"原始文本: {result['original_text']}")
            print(f"回答文本: {result['translated_text']}")
            print(f"生成音频: {result['audio_path']}")
            print(f"处理时间: {file_time_cost:.2f} 秒")
            print(f"音频时长: {output_duration:.2f} 秒")
            print(f"RTF: {rtf:.2f}")
        except Exception as e:
            print(f"处理文件 {audio_file} 时出错: {e}")
            traceback.print_exc()
            continue

    total_time_cost = time.time() - total_start_time
    print(f"\n{'='*80}")
    print("所有文件处理完成!")
    print(f"{'='*80}")
    print(f"总共处理了 {len(all_results)} 个文件")
    print(f"总处理时间: {total_time_cost:.2f} 秒")

    summary_file = os.path.join(args.output_dir, "processing_summary.txt")
    with open(summary_file, 'w', encoding='utf-8') as f:
        f.write("批量处理结果汇总\n")
        f.write("=" * 50 + "\n\n")
        for i, result in enumerate(all_results):
            f.write(f"文件 {i+1}: {os.path.basename(result['audio_file'])}\n")
            f.write(f"  原始文本: {result['original_text']}\n")
            f.write(f"  回答结果: {result['translated_text']}\n")
            f.write(f"  合成音频: {os.path.basename(result['audio_path'])}\n")
            f.write(f"  处理时间: {result['processing_time']:.2f} 秒\n")
            f.write(f"  音频时长: {result['output_duration']:.2f} 秒\n")
            f.write(f"  RTF: {result['rtf']:.2f}\n")
            f.write("-" * 50 + "\n")
        f.write(f"\n总计: {len(all_results)} 个文件\n")
        f.write(f"总处理时间: {total_time_cost:.2f} 秒\n")
    print(f"详细结果已保存到: {summary_file}")


if __name__ == "__main__":
    main()