cohere-transcribe-diarize / vllm_diarized_patch.py
mhenrichsen's picture
Upload vllm_diarized_patch.py with huggingface_hub
52ffa47 verified
Raw
History Blame Contribute Delete
6.48 kB
"""Patch vLLM 0.19.0 to serve syv.ai's diarize+ts model.
Applies five edits across three files:
- protocol.py : add "diarized_json" to AudioResponseFormat
- protocol.py : force skip_special_tokens=False in to_sampling_params
(so <|spltoken*|> and <|t:*|> reach the response text)
- speech_to_text.py : accept "diarized_json" in the response_format validator
- speech_to_text.py : inject a parser that turns raw <|spltokenN|><|t:s|>text<|t:e|>
into an OpenAI-compatible {segments, speakers, …} JSON
- api_router.py : pass JSONResponse return values through unchanged
Idempotent — re-running is safe.
Usage:
python vllm_diarized_patch.py
"""
import sys
# Locate the installed vLLM package: every file patched below lives under it.
VLLM_ROOT = None
try:
    import vllm

    VLLM_ROOT = vllm.__path__[0]
except Exception as exc:
    # Deliberately broad so that any failure while importing vLLM ends in a
    # clean exit message instead of a traceback. The underlying error is
    # appended because "not installed" is only one possible cause, and the
    # bare install hint would otherwise hide the real one.
    sys.exit(
        "vLLM not importable — install vllm==0.19.0 first "
        f"({type(exc).__name__}: {exc})"
    )

# Paths of the three vLLM source files this script edits in place.
PROTO = f"{VLLM_ROOT}/entrypoints/openai/speech_to_text/protocol.py"
SVC = f"{VLLM_ROOT}/entrypoints/openai/speech_to_text/speech_to_text.py"
ROUTER = f"{VLLM_ROOT}/entrypoints/openai/speech_to_text/api_router.py"
def patch(path, old, new, label):
    """Replace the first occurrence of ``old`` with ``new`` in the file at ``path``.

    The edit is idempotent: if ``new`` is already present in the file the
    function reports that and leaves the file untouched.

    Args:
        path: Path of the text file to edit in place.
        old: Exact anchor text that must be present in the file.
        new: Replacement text; its presence marks the edit as already applied.
        label: Short human-readable name used in progress and error messages.

    Raises:
        SystemExit: If neither ``new`` nor ``old`` is found in the file.
    """
    # Read and write explicitly as UTF-8: the replacement text contains
    # non-ASCII characters, so relying on the locale default could mis-decode
    # the file or fail on write.
    with open(path, encoding="utf-8") as fh:
        source = fh.read()

    if new in source:
        print(f" · {label} (already applied)")
        return
    if old not in source:
        sys.exit(f"FAIL {label}: anchor not found in {path}")

    # Only the first match is replaced, so a repeated anchor further down the
    # file is left alone.
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(source.replace(old, new, 1))
    print(f" ✓ {label}")
# Edits 1-3 are one-line substitutions, applied in the order listed. Each entry
# is (target file, anchor text, replacement text, progress label).
_SINGLE_LINE_EDITS = (
    # 1. Extend the AudioResponseFormat Literal alias with "diarized_json".
    (
        PROTO,
        'AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]',
        'AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt", "diarized_json"]',
        "AudioResponseFormat += diarized_json",
    ),
    # 2. Pass skip_special_tokens=False so the diarization and timestamp
    #    tokens survive into the response text.
    (
        PROTO,
        ' extra_args=self.vllm_xargs,\n skip_clone=True, # Created fresh per request, safe to skip clone',
        ' extra_args=self.vllm_xargs,\n skip_special_tokens=False, # SYVAI: preserve <|spltoken*|> + <|t:*|>\n skip_clone=True, # Created fresh per request, safe to skip clone',
        "to_sampling_params(skip_special_tokens=False)",
    ),
    # 3. Let the response_format validator accept "diarized_json".
    (
        SVC,
        'if request.response_format not in ["text", "json", "verbose_json"]:',
        'if request.response_format not in ["text", "json", "verbose_json", "diarized_json"]:',
        "validator allows diarized_json",
    ),
)
for _target, _anchor, _replacement, _label in _SINGLE_LINE_EDITS:
    patch(_target, _anchor, _replacement, _label)
# 4. Inject a diarized_json branch into speech_to_text.py, placed after the
#    `usage` dict is built and before the existing
#    `if request.response_format != "verbose_json":` check.
#
# `old_branch` is the anchor and `new_branch` is the same text with the new
# branch spliced in. patch() locates the anchor by exact substring match, so
# whitespace inside these literals is significant.
# NOTE(review): confirm the indentation inside both literals against vLLM
# 0.19.0's speech_to_text.py — a mismatch makes patch() exit with
# "anchor not found".
old_branch = ''' text = "".join(text_parts)
if self.task_type == "transcribe":
final_response: ResponseType
# add usage in TranscriptionResponse.
usage = {
"type": "duration",
# rounded up as per openAI specs
"seconds": int(math.ceil(duration_s)),
}
if request.response_format != "verbose_json":'''
# What the injected branch does when response_format == "diarized_json":
#   * SEG_RE matches `<|spltokenN|>` followed by `<|t:START|>text<|t:END|>`;
#     both timestamps must contain a decimal point (`\d+\.\d+`) to match.
#   * Each match becomes a segment with speaker "SPEAKER_NN", start, end and
#     the text with any remaining `<|...|>` tokens stripped; an end that is
#     not after its start is replaced by start + 0.05.
#   * "speakers" lists SPEAKER_00 up to the highest speaker index seen (empty
#     when no segment matched); "text" is the whole output with tokens removed.
#   * The payload is returned directly as a fastapi JSONResponse.
# Backslashes are doubled because this literal is not a raw string; the text
# written to the target file therefore contains single backslashes inside its
# own r"..." patterns.
new_branch = ''' text = "".join(text_parts)
if self.task_type == "transcribe":
final_response: ResponseType
# add usage in TranscriptionResponse.
usage = {
"type": "duration",
# rounded up as per openAI specs
"seconds": int(math.ceil(duration_s)),
}
# SYVAI: diarized_json — parse <|spltokenN|><|t:s|>text<|t:e|>
if request.response_format == "diarized_json":
import re as _re
SEG_RE = _re.compile(
r"<\\|spltoken(\\d+)\\|>\\s*<\\|t:(\\d+\\.\\d+)\\|>(.*?)<\\|t:(\\d+\\.\\d+)\\|>",
_re.DOTALL,
)
TOK_STRIP = _re.compile(r"<\\|[^|]+\\|>")
segs = []
last_spk = 0
for m in SEG_RE.finditer(text):
spk = int(m.group(1))
st = float(m.group(2))
ed = float(m.group(4))
if ed <= st: ed = st + 0.05
clean = TOK_STRIP.sub("", m.group(3)).strip()
segs.append({
"speaker": f"SPEAKER_{spk:02d}",
"start": st,
"end": ed,
"text": clean,
})
last_spk = max(last_spk, spk)
plain_text = TOK_STRIP.sub("", text).strip()
payload = {
"task": "transcribe",
"language": request.language,
"duration": duration_s,
"text": plain_text,
"segments": segs,
"speakers": [f"SPEAKER_{i:02d}" for i in range(last_spk + 1)] if segs else [],
"usage": usage,
}
from fastapi.responses import JSONResponse
return JSONResponse(content=payload)
if request.response_format != "verbose_json":'''
# Apply the splice; a re-run is a no-op once `new_branch` is present in the file.
patch(SVC, old_branch, new_branch, "diarized_json response builder")
# 5. api_router: return a JSONResponse untouched, so the value produced by the
#    diarized_json branch does not fall through to the StreamingResponse line.
#
# The anchor is head + tail; the replacement inserts one extra `elif` between
# them, which keeps the actual change visible at a glance.
_router_head = (
    ' if isinstance(generator, ErrorResponse):\n'
    ' return JSONResponse(\n'
    ' content=generator.model_dump(), status_code=generator.error.code\n'
    ' )\n'
    '\n'
)
_router_passthrough = (
    ' elif isinstance(generator, JSONResponse): # SYVAI: diarized_json passthrough\n'
    ' return generator\n'
    '\n'
)
_router_tail = (
    ' elif isinstance(generator, TranscriptionResponseVariant):\n'
    ' return JSONResponse(content=generator.model_dump())\n'
    '\n'
    ' return StreamingResponse(content=generator, media_type="text/event-stream")\n'
)
patch(
    ROUTER,
    _router_head + _router_tail,
    _router_head + _router_passthrough + _router_tail,
    "api_router JSONResponse passthrough",
)

# Reached only when every edit above was applied or already in place.
print("\nAll patches applied. Restart your vllm serve process for them to take effect.")