"""Whisper local STT backend (openai-whisper or faster-whisper)."""
from __future__ import annotations

import asyncio
import contextlib
import importlib
import os
import tempfile
import threading
import time
from typing import Any
class WhisperBackend:
name = "whisper"
def __init__(
self,
model_size: str = "base",
device: str = "auto",
) -> None:
self._model_size = model_size
self._device = device
self._model: Any = None
self._backend_lib: str | None = None # "openai_whisper" or "faster_whisper"
def _resolve_device(self) -> str:
if self._device != "auto":
return self._device
try:
import torch
return "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
return "cpu"
def health(self) -> dict:
# Prefer faster_whisper, fall back to openai whisper
try:
import faster_whisper # noqa: F401
return {
"backend": self.name,
"status": "ok",
"lib": "faster_whisper",
"model": self._model_size,
}
except ImportError:
pass
try:
import whisper # noqa: F401
return {
"backend": self.name,
"status": "ok",
"lib": "openai_whisper",
"model": self._model_size,
}
except ImportError:
pass
return {
"backend": self.name,
"status": "unavailable",
"reason": "Neither openai-whisper nor faster-whisper is installed",
}
def _load_model_sync(self) -> None:
device = self._resolve_device()
try:
from faster_whisper import WhisperModel # type: ignore[import]
self._model = WhisperModel(self._model_size, device=device)
self._backend_lib = "faster_whisper"
return
except ImportError:
pass
import whisper # type: ignore[import]
self._model = whisper.load_model(self._model_size, device=device)
self._backend_lib = "openai_whisper"
async def _ensure_loaded(self) -> None:
if self._model is None:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._load_model_sync)
async def transcribe(
self,
audio_bytes: bytes,
language: str | None = None,
translate_to_en: bool = False,
) -> Any:
from hearthnet.services.speech.backends.base import SttResult
await self._ensure_loaded()
t0 = time.monotonic()
loop = asyncio.get_running_loop()
segments, detected_lang = await loop.run_in_executor(
None, self._transcribe_sync, audio_bytes, language, translate_to_en
)
ms = int((time.monotonic() - t0) * 1000)
full_text = " ".join(s.text for s in segments)
return SttResult(
segments=segments,
full_text=full_text,
detected_language=detected_lang or "unknown",
backend=self.name,
ms=ms,
)
def _transcribe_sync(
self,
audio_bytes: bytes,
language: str | None,
translate_to_en: bool,
) -> tuple[list[Any], str | None]:
from hearthnet.services.speech.backends.base import SttSegment
# Write to temp file because whisper expects file path
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(audio_bytes)
tmp_path = tmp.name
segments_out: list[SttSegment] = []
detected: str | None = None
try:
if self._backend_lib == "faster_whisper":
task = "translate" if translate_to_en else "transcribe"
segs, info = self._model.transcribe(
tmp_path,
language=language,
task=task,
)
detected = info.language
segments_out.extend(
SttSegment(
start_seconds=seg.start,
end_seconds=seg.end,
text=seg.text.strip(),
language=detected,
confidence=None,
)
for seg in segs
)
else:
# openai-whisper
task = "translate" if translate_to_en else "transcribe"
kwargs: dict = {"task": task}
if language:
kwargs["language"] = language
result = self._model.transcribe(tmp_path, **kwargs)
detected = result.get("language")
segments_out.extend(
SttSegment(
start_seconds=float(seg["start"]),
end_seconds=float(seg["end"]),
text=str(seg["text"]).strip(),
language=detected,
confidence=None,
)
for seg in result.get("segments", [])
)
finally:
import os
with contextlib.suppress(OSError):
os.unlink(tmp_path)
return segments_out, detected
|