#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MOSS Transcribe Diarize Gradio Demo (Remote API)
================================================

Provides a web interface for audio/video upload and transcription via the
official API.
"""

import base64
import argparse
import json
import math
import os
import re
import subprocess
from pathlib import Path
from typing import Any, Tuple

import gradio as gr
import requests

DEFAULT_API_URL = os.getenv("MOSS_API_URL", "https://studio.mosi.cn/v1/audio/transcriptions")
DEFAULT_AUTH_TOKEN = os.getenv("MOSS_API_KEY", os.getenv("MOSS_API_AUTH_TOKEN", ""))
DEFAULT_MODEL = os.getenv("MOSS_MODEL", "moss-transcribe-diarize")

MAX_AUDIO_DURATION = 60 * 30  # seconds; longer inputs are truncated before upload
MAX_FILE_SIZE = 1024 * 1024 * 100  # 100MB, passed to Gradio's launch(max_file_size=...)

# Accepted upload suffixes. Despite the name this also lists video containers:
# every file is decoded by ffmpeg into a WAV track before being uploaded.
AUDIO_SUFFIXES: Tuple[str, ...] = (
    ".wav", ".mp3", ".flac", ".aac", ".m4a", ".ogg", ".wma",
    ".mp4", ".mov", ".mkv", ".avi", ".wmv", ".webm",
)

# Parsed CLI arguments; replaced by parse_args() in main(). The helpers below
# read their runtime configuration (API URL, token, sampling params, ...) here.
APP_ARGS: argparse.Namespace = argparse.Namespace()

# Regexes shared by the response formatters, compiled once at import time.
_TAG_RE = re.compile(r"<[^>]+>")  # any <...> markup tag in model output
_BRACKET_RE = re.compile(r"\[([^\]]+)\]")  # any [token]
_TIME_TOKEN_RE = re.compile(r"\d+(?:\.\d+)?")  # seconds: 0.00 / 12 / 12.3 / 12.34
_SPEAKER_TOKEN_RE = re.compile(r"S\d{1,3}")  # S01 / S1 / S001
_NUMERIC_TIMESTAMP_RE = re.compile(r"\[(?:\d+(?:\.\d+)?)\]")  # [12.34]
_SPEAKER_TAG_SPLIT_RE = re.compile(r"(\[S\d{1,3}\])")  # capturing, for re.split

# Keys probed (in order) for a plain-text transcript in an API response dict.
# "raw_text" is the wrapper _call_remote_asr uses for non-JSON bodies.
_TEXT_RESPONSE_KEYS: Tuple[str, ...] = (
    "text", "result", "transcription", "output", "generated_text", "raw_text",
)


# --- Time Formatting Helper ---
def _sec_to_hhmmss_cs(sec: float) -> str:
    """Convert seconds to a compact timestamp with centisecond precision.

    Returns ``SS.ss`` below one minute, ``MM:SS.ss`` below one hour and
    ``HH:MM:SS.ss`` otherwise. Negative input is clamped to zero.
    """
    # Round to whole centiseconds *before* splitting into fields. Splitting
    # first and rounding only at format time renders e.g. 59.999 as "60.00"
    # (and 3599.999 as "59:60.00") instead of carrying into the next field.
    total_cs = max(0, int(round(float(sec) * 100)))
    hh, rem_cs = divmod(total_cs, 360_000)
    mm, ss_cs = divmod(rem_cs, 6_000)
    ss = f"{ss_cs // 100:02d}.{ss_cs % 100:02d}"
    if hh > 0:
        return f"{hh:02d}:{mm:02d}:{ss}"
    if mm > 0:
        return f"{mm:02d}:{ss}"
    return ss


# --- I18N Configuration (For UI Elements) ---
i18n = gr.I18n(
    en={
        "header": "## 🎤 MOSS Transcribe Diarize: Accurate Transcription with Speaker Diarization",
        "tips": (
            f"> **💡 Note**: This demo currently supports ASR with speaker recognition for audio clips up to **{MAX_AUDIO_DURATION}s**. \n"
            "> ✅ Long-audio transcription with full timestamps is now available via API. Try it here: [API Docs](https://studio.mosi.cn/docs/moss-transcribe-diarize) \n"
            "> **🔗 Links**: [paper](https://arxiv.org/abs/2601.01554) · [model page](https://mosi.cn/models/moss-transcribe-diarize)"
        ),
        "audio_tab": "🎵 Audio",
        "audio_label": "📥 Upload / Record Audio",
        "video_tab": "🎬 Video",
        "video_tip": "💡 **Note**: Uploading a video will extract the audio for transcription.",
        "video_label": "📥 Upload Video",
        "run_btn": "🚀 Start Transcription",
        "output_label": "📝 Transcription Result",
    },
    # "zh-CN" is not a valid Python identifier, so it is passed via ** unpacking.
    **{"zh-CN": {
        "header": "## 🎤 MOSS Transcribe Diarize: 精准转写与说话人识别",
        "tips": (
            f"> **💡 说明**:本演示版本仅支持短音频(最长 **{MAX_AUDIO_DURATION}s**)的文本转写及说话人识别。 \n"
            "> ✅ 已提供支持完整时间戳的长音频转写 API,可在此试用:[API 文档](https://studio.mosi.cn/docs/moss-transcribe-diarize) \n"
            "> **🔗 链接**:[paper](https://arxiv.org/abs/2601.01554) · [model page](https://mosi.cn/models/moss-transcribe-diarize)"
        ),
        "audio_tab": "🎵 音频",
        "audio_label": "📥 上传/录制音频",
        "video_tab": "🎬 视频",
        "video_tip": "💡 **提示**:上传视频将提取其中的音频进行转录。",
        "video_label": "📥 上传视频",
        "run_btn": "🚀 开始转写",
        "output_label": "📝 转写结果",
    }}
)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the demo server."""
    parser = argparse.ArgumentParser(description="MOSS Transcribe Diarize Gradio Demo (Remote API)")
    parser.add_argument("--api_url", default=DEFAULT_API_URL, help="Remote inference service URL")
    parser.add_argument(
        "--auth_token",
        default=DEFAULT_AUTH_TOKEN,
        help="MOSI API key or full Authorization header (env: MOSS_API_KEY / MOSS_API_AUTH_TOKEN)",
    )
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or snapshot")
    parser.add_argument("--timeout", type=int, default=120, help="HTTP request timeout (seconds)")
    parser.add_argument("--max_new_tokens", type=int, default=16384, help="Max new tokens")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
    parser.add_argument("--top_k", type=int, default=-1, help="Sampling top_k")
    parser.add_argument("--top_p", type=float, default=1.0, help="Sampling top_p")
    parser.add_argument("--target_sample_rate", type=int, default=16000, help="Resample to this rate (0 to disable)")
    parser.add_argument("--keep_channels", action="store_true", help="Keep multiple channels (default: downmix to mono)")
    parser.add_argument("--share", action="store_true", help="Whether to generate a public link")
    parser.add_argument("--server_name", default="0.0.0.0", help="Gradio server name")
    parser.add_argument("--server_port", type=int, default=int(os.getenv("GRADIO_SERVER_PORT", "7860")), help="Gradio server port")
    return parser.parse_args()


def _get_duration(file_path: str) -> float:
    """Get the duration of an audio/video file in seconds using ffprobe.

    Best-effort: returns 0.0 when ffprobe is missing, exits with an error, or
    reports a non-numeric duration (e.g. "N/A").
    """
    cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        file_path,
    ]
    try:
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True)
        return float(proc.stdout.strip())
    except (OSError, subprocess.CalledProcessError, ValueError):
        return 0.0


def _ffmpeg_to_wav_bytes(file_path: str, target_sample_rate: int, keep_channels: bool, duration_limit: float = 0.0) -> bytes:
    """Decode *file_path* with ffmpeg and return the audio as in-memory WAV bytes.

    Args:
        file_path: Any media file ffmpeg can read.
        target_sample_rate: Output sample rate in Hz; falsy or <= 0 keeps the source rate.
        keep_channels: If False, downmix to mono (``-ac 1``).
        duration_limit: If > 0, only the first ``duration_limit`` seconds are read.

    Raises:
        RuntimeError: ffmpeg is not installed or exits with a non-zero status.
    """
    cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-nostdin"]
    if duration_limit > 0:
        # Placed before -i so it acts as an input option: ffmpeg stops reading early.
        cmd += ["-t", str(duration_limit)]
    cmd += ["-i", file_path]
    if not keep_channels:
        cmd += ["-ac", "1"]
    if target_sample_rate and int(target_sample_rate) > 0:
        cmd += ["-ar", str(int(target_sample_rate))]
    cmd += ["-f", "wav", "pipe:1"]
    try:
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as e:
        # Raised only when the ffmpeg executable itself cannot be found.
        raise RuntimeError("ffmpeg executable not found; please install ffmpeg and add it to PATH.") from e
    if proc.returncode != 0:
        err = proc.stderr.decode(errors="ignore").strip()
        raise RuntimeError(f"ffmpeg transcoding failed: {err or 'unknown error'}")
    return proc.stdout


def _file_to_wav_bytes(file_path: str, duration_limit: float = 0.0) -> bytes:
    """Transcode *file_path* to WAV using the sample-rate/channel settings in APP_ARGS."""
    return _ffmpeg_to_wav_bytes(file_path, APP_ARGS.target_sample_rate, APP_ARGS.keep_channels, duration_limit)


def _file_to_data_uri(file_path: str, duration_limit: float = 0.0) -> str:
    """Transcode *file_path* to WAV and wrap it as a base64 ``data:audio/wav`` URI."""
    wav_bytes = _file_to_wav_bytes(file_path, duration_limit)
    b64 = base64.b64encode(wav_bytes).decode("utf-8")
    return f"data:audio/wav;base64,{b64}"


def _call_remote_asr(
    audio_data_uri: str,
    model: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> Any:
    """POST the audio to the remote ASR endpoint and return the decoded response.

    Returns the parsed JSON body, or ``{"raw_text": <body>}`` when the body is
    not valid JSON.

    Raises:
        RuntimeError: missing API key, transport failure, or a non-200 status.
    """
    token = (APP_ARGS.auth_token or "").strip()
    if not token:
        raise RuntimeError("Missing API key. Set --auth_token or env MOSS_API_KEY.")
    # Accept either a bare key or a full "Bearer ..." header value.
    auth_header = token if token.lower().startswith("bearer ") else f"Bearer {token}"

    payload = {
        "model": model,
        "audio_data": audio_data_uri,
        "sampling_params": {
            "max_new_tokens": int(max_new_tokens),
            "temperature": float(temperature),
            "top_k": int(top_k),
            "top_p": float(top_p),
        },
        "meta_info": True,
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": auth_header,
    }

    try:
        resp = requests.post(APP_ARGS.api_url, headers=headers, json=payload, timeout=int(APP_ARGS.timeout))
    except requests.RequestException as e:
        raise RuntimeError(f"Request failed: {e}") from e

    if resp.status_code != 200:
        text = (resp.text or "").strip()
        if len(text) > 2000:
            text = text[:2000] + " ... (truncated)"
        raise RuntimeError(f"HTTP {resp.status_code}: {text}")

    try:
        return resp.json()
    except ValueError:
        # requests' JSONDecodeError subclasses ValueError.
        return {"raw_text": resp.text}


def _safe_float(value: Any) -> float | None:
    """Convert *value* to a finite float, or return None if that is not possible."""
    try:
        result = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return result if math.isfinite(result) else None


def _format_segment_time(value: Any) -> str | None:
    """Render one segment boundary: numbers as HH:MM:SS.ss, anything else verbatim.

    Returns None when *value* is None so callers can omit the time range.
    """
    if value is None:
        return None
    seconds = _safe_float(value)
    return _sec_to_hhmmss_cs(seconds) if seconds is not None else str(value)


def _format_segments(segments: Any) -> str:
    """Format a list of API segment dicts as ``[start-end] [speaker] text`` lines.

    Non-dict entries, entries without string ``text``, and entries whose text
    is empty after stripping ``<...>`` tags are skipped. Numeric ``start_s``/
    ``end_s`` values are formatted the same way as the full-text fallback.
    Returns "" when *segments* is not a list or yields no lines.
    """
    if not isinstance(segments, list):
        return ""
    lines: list[str] = []
    for seg in segments:
        if not isinstance(seg, dict):
            continue
        text = seg.get("text")
        if not isinstance(text, str):
            continue
        text = _TAG_RE.sub("", text).strip()
        if not text:
            continue

        start_fmt = _format_segment_time(seg.get("start_s"))
        end_fmt = _format_segment_time(seg.get("end_s"))
        speaker = seg.get("speaker")
        speaker_tag = f"[{speaker.strip()}]" if isinstance(speaker, str) and speaker.strip() else ""
        time_tag = f"[{start_fmt}-{end_fmt}]" if start_fmt is not None and end_fmt is not None else ""

        # Join only the non-empty parts so a missing speaker leaves no double space.
        lines.append(" ".join(part for part in (time_tag, speaker_tag, text) if part))
    return "\n".join(lines).strip()


def _parse_timestamped_segments(text: str) -> list[tuple[float, float, str, str]]:
    """Parse ``[t0][S01] text [t1][t2][S02] text [t3]`` into segment tuples.

    Returns a list of ``(start_sec, end_sec, "[Sxx]", text)``. Between two
    speaker tags, the first timestamp closes the previous segment and the last
    one opens the next. Text before the first speaker tag is dropped;
    non-numeric, non-speaker bracket tokens (e.g. ``[laughter]``) are kept in
    the segment text.

    Raises:
        ValueError: if the text does not follow this layout.
    """
    segments: list[tuple[float, float, str, str]] = []
    times_buffer: list[float] = []
    cur_speaker: str | None = None
    cur_start: float | None = None
    cur_text: list[str] = []
    idx = 0

    for m in _BRACKET_RE.finditer(text):
        between = text[idx:m.start()]
        if cur_speaker is not None and between:
            cur_text.append(between)
        idx = m.end()

        tok = (m.group(1) or "").strip()
        if not tok:
            continue

        # Time token: buffer it; it is assigned at the next speaker boundary.
        if _TIME_TOKEN_RE.fullmatch(tok):
            seconds = float(tok)
            if not math.isfinite(seconds):
                # e.g. a huge digit string overflowing to inf.
                raise ValueError("timestamp out of range")
            times_buffer.append(seconds)
            continue

        if _SPEAKER_TOKEN_RE.fullmatch(tok):
            speaker = f"[{tok}]"
            if cur_speaker is None:
                # First segment: its start is the most recent timestamp, if any.
                cur_start = times_buffer[-1] if times_buffer else None
            else:
                # Same speaker repeated with no timestamp in between: ignore the tag.
                if not times_buffer and speaker == cur_speaker:
                    continue
                if not times_buffer:
                    raise ValueError("no timestamp between speakers")
                if cur_start is None:
                    raise ValueError("segment without start time")
                # Close the previous segment with the first buffered time...
                segments.append((cur_start, times_buffer[0], cur_speaker, "".join(cur_text).strip()))
                # ...and open the new one with the last (the second of a pair).
                cur_start = times_buffer[-1]
            cur_speaker = speaker
            cur_text = []
            times_buffer = []
            continue

        # Any other bracketed content (e.g. [event]) is kept verbatim.
        if cur_speaker is not None:
            cur_text.append(f"[{tok}]")

    tail = text[idx:]
    if cur_speaker is not None and tail:
        cur_text.append(tail)

    # The final segment ends at the last buffered timestamp.
    if cur_speaker is not None:
        if not times_buffer or cur_start is None:
            raise ValueError("last segment missing timestamp")
        segments.append((cur_start, times_buffer[-1], cur_speaker, "".join(cur_text).strip()))

    if not segments:
        raise ValueError("no valid segments parsed")
    return segments


def _format_speaker_turns(text: str) -> str:
    """Group text into ``[Sxx] text`` lines, merging consecutive identical speakers.

    Text before the first speaker tag is dropped; turns with empty text are
    omitted. If there are no speaker tags at all, the stripped text is returned.
    """
    turns: list[str] = []
    current_speaker: str | None = None
    current_text: list[str] = []
    for part in _SPEAKER_TAG_SPLIT_RE.split(text):
        if not part:
            continue
        if _SPEAKER_TAG_SPLIT_RE.fullmatch(part):
            if part == current_speaker:
                continue
            if current_speaker is not None:
                txt = "".join(current_text).strip()
                if txt:
                    turns.append(f"{current_speaker} {txt}")
            current_speaker = part
            current_text = []
        else:
            current_text.append(part)

    if current_speaker is not None:
        txt = "".join(current_text).strip()
        if txt:
            turns.append(f"{current_speaker} {txt}")
    elif current_text:
        return "".join(current_text).strip()
    return "\n".join(turns).strip()


def _post_process_transcription(text: str) -> str:
    """Post-process raw model output into readable transcript lines.

    1. Remove ``<...>`` markup tags.
    2. If the text contains timestamped segments such as
       ``[0.00][S01]... [4.20][8.38][S02]...``, format each one as
       ``[start-end] [Sxx] text`` with times rendered by _sec_to_hhmmss_cs.
    3. Otherwise strip numeric ``[12.34]`` timestamps and fall back to
       speaker-only formatting (consecutive identical speakers merged).
    """
    text = _TAG_RE.sub("", text)
    try:
        segments = _parse_timestamped_segments(text)
    except ValueError:
        return _format_speaker_turns(_NUMERIC_TIMESTAMP_RE.sub("", text))

    formatted_lines = [
        f"[{_sec_to_hhmmss_cs(s_sec)}-{_sec_to_hhmmss_cs(e_sec)}] {spk} {txt}"
        for s_sec, e_sec, spk, txt in segments
    ]
    return "\n".join(formatted_lines).strip()


def _format_api_response(resp_obj: Any) -> str:
    """Turn the remote API response into display text.

    Preference order for dict responses: ``asr_transcription_result.segments``,
    then ``asr_transcription_result.full_text``, then the first non-blank string
    under one of _TEXT_RESPONSE_KEYS (including the ``raw_text`` wrapper used
    for non-JSON bodies), and finally a JSON dump of the whole object.
    """
    if isinstance(resp_obj, str):
        return _post_process_transcription(resp_obj)
    if not isinstance(resp_obj, dict):
        return _post_process_transcription(str(resp_obj))

    asr_result = resp_obj.get("asr_transcription_result")
    if isinstance(asr_result, dict):
        segments_text = _format_segments(asr_result.get("segments"))
        if segments_text:
            return segments_text
        full_text = asr_result.get("full_text")
        if isinstance(full_text, str) and full_text.strip():
            return _post_process_transcription(full_text)

    for key in _TEXT_RESPONSE_KEYS:
        value = resp_obj.get(key)
        if isinstance(value, str) and value.strip():
            return _post_process_transcription(value)

    return _post_process_transcription(json.dumps(resp_obj, ensure_ascii=False, indent=2))


def _normalize_path(file_obj) -> str:
    """Extract a filesystem path from a Gradio upload value (str, dict or file-like)."""
    if isinstance(file_obj, dict):
        name = file_obj.get("name")
        if isinstance(name, str):
            return name
    if isinstance(file_obj, str):
        return file_obj
    name = getattr(file_obj, "name", None)
    if isinstance(name, str):
        return name
    raise gr.Error("Unrecognized file object.")


def preprocess_file(audio_obj, video_obj) -> str:
    """Validate that exactly one upload was given and return its path.

    Raises:
        gr.Error: no upload, both uploads, an unrecognized upload value, or a
            suffix not listed in AUDIO_SUFFIXES.
    """
    provided = [obj for obj in (audio_obj, video_obj) if obj]
    if len(provided) == 0:
        raise gr.Error("Please upload an audio or video file.")
    if len(provided) > 1:
        raise gr.Error("Please select either audio or video, not both.")

    file_path = _normalize_path(provided[0])
    suffix = Path(file_path).suffix.lower()
    if suffix in AUDIO_SUFFIXES:
        return file_path
    raise gr.Error("Unsupported file format.")


def run_transcription(
    audio_obj,
    video_obj,
    progress=gr.Progress(track_tqdm=False),
) -> str:
    """Gradio click handler: transcode the upload, call the API, return formatted text.

    Inputs longer than MAX_AUDIO_DURATION are truncated (with a UI warning).
    Any failure is surfaced to the UI as ``gr.Error`` with the cause chained.
    """
    progress(0.15, "Processing file...")
    audio_path = preprocess_file(audio_obj, video_obj)

    duration = _get_duration(audio_path)
    actual_limit = 0.0
    if duration > MAX_AUDIO_DURATION + 0.1:
        actual_limit = float(MAX_AUDIO_DURATION)
        gr.Warning(f"File is too long ({duration:.1f}s). It has been truncated to the first {MAX_AUDIO_DURATION}s.")

    progress(0.3, "Decoding and transcoding to WAV (ffmpeg)...")
    try:
        audio_data_uri = _file_to_data_uri(audio_path, duration_limit=actual_limit)
    except Exception as e:
        raise gr.Error(str(e)) from e

    progress(0.6, "Requesting remote inference service...")
    try:
        resp_obj = _call_remote_asr(
            audio_data_uri,
            APP_ARGS.model,
            int(APP_ARGS.max_new_tokens),
            float(APP_ARGS.temperature),
            int(APP_ARGS.top_k),
            float(APP_ARGS.top_p),
        )
        result = _format_api_response(resp_obj)
    except Exception as e:
        raise gr.Error(str(e)) from e

    # Never show an empty box: fall back to the raw response.
    if not result.strip():
        result = json.dumps(resp_obj, ensure_ascii=False, indent=2)
    progress(1.0, "Done")
    return result


def build_demo() -> gr.Blocks:
    """Build the Gradio UI: audio/video input tabs, a run button and a result box."""
    with gr.Blocks(title="MOSS Transcribe Diarize", theme=gr.themes.Soft()) as demo:
        gr.Markdown(i18n("header"))
        gr.Markdown(i18n("tips"))

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Tabs():
                    with gr.Tab(i18n("audio_tab"), id="audio") as audio_tab:
                        audio_input = gr.Audio(
                            label=i18n("audio_label"),
                            sources=["upload", "microphone"],
                            type="filepath",
                            interactive=True,
                        )
                    with gr.Tab(i18n("video_tab"), id="video") as video_tab:
                        gr.Markdown(i18n("video_tip"))
                        video_input = gr.Video(
                            label=i18n("video_label"),
                            interactive=True,
                        )
                run_button = gr.Button(i18n("run_btn"), variant="primary")
            with gr.Column(scale=1):
                output_box = gr.Textbox(label=i18n("output_label"), lines=18)

        # Switching tabs clears the other tab's input so that
        # preprocess_file sees at most one upload.
        audio_tab.select(fn=lambda: None, outputs=video_input)
        video_tab.select(fn=lambda: None, outputs=audio_input)

        run_button.click(
            run_transcription,
            inputs=[
                audio_input,
                video_input,
            ],
            outputs=output_box,
        )
    return demo


def main() -> None:
    """Parse CLI arguments into APP_ARGS and launch the Gradio server."""
    global APP_ARGS
    APP_ARGS = parse_args()
    demo = build_demo()
    demo.queue().launch(
        share=APP_ARGS.share,
        server_name=APP_ARGS.server_name,
        server_port=APP_ARGS.server_port,
        max_file_size=MAX_FILE_SIZE,
        i18n=i18n,
    )


if __name__ == "__main__":
    main()