GitHub Actions
fix: easter-egg ticker now properly works in Gradio app
d6ca3a2
Raw
History Blame
5.37 kB
"""Whisper local STT backend (openai-whisper or faster-whisper)."""
from __future__ import annotations
import asyncio
import contextlib
import tempfile
import time
from typing import Any
class WhisperBackend:
name = "whisper"
def __init__(
self,
model_size: str = "base",
device: str = "auto",
) -> None:
self._model_size = model_size
self._device = device
self._model: Any = None
self._backend_lib: str | None = None # "openai_whisper" or "faster_whisper"
def _resolve_device(self) -> str:
if self._device != "auto":
return self._device
try:
import torch
return "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
return "cpu"
def health(self) -> dict:
# Prefer faster_whisper, fall back to openai whisper
try:
import faster_whisper # noqa: F401
return {
"backend": self.name,
"status": "ok",
"lib": "faster_whisper",
"model": self._model_size,
}
except ImportError:
pass
try:
import whisper # noqa: F401
return {
"backend": self.name,
"status": "ok",
"lib": "openai_whisper",
"model": self._model_size,
}
except ImportError:
pass
return {
"backend": self.name,
"status": "unavailable",
"reason": "Neither openai-whisper nor faster-whisper is installed",
}
def _load_model_sync(self) -> None:
device = self._resolve_device()
try:
from faster_whisper import WhisperModel # type: ignore[import]
self._model = WhisperModel(self._model_size, device=device)
self._backend_lib = "faster_whisper"
return
except ImportError:
pass
import whisper # type: ignore[import]
self._model = whisper.load_model(self._model_size, device=device)
self._backend_lib = "openai_whisper"
async def _ensure_loaded(self) -> None:
if self._model is None:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._load_model_sync)
async def transcribe(
self,
audio_bytes: bytes,
language: str | None = None,
translate_to_en: bool = False,
) -> Any:
from hearthnet.services.speech.backends.base import SttResult
await self._ensure_loaded()
t0 = time.monotonic()
loop = asyncio.get_running_loop()
segments, detected_lang = await loop.run_in_executor(
None, self._transcribe_sync, audio_bytes, language, translate_to_en
)
ms = int((time.monotonic() - t0) * 1000)
full_text = " ".join(s.text for s in segments)
return SttResult(
segments=segments,
full_text=full_text,
detected_language=detected_lang or "unknown",
backend=self.name,
ms=ms,
)
def _transcribe_sync(
self,
audio_bytes: bytes,
language: str | None,
translate_to_en: bool,
) -> tuple[list[Any], str | None]:
from hearthnet.services.speech.backends.base import SttSegment
# Write to temp file because whisper expects file path
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(audio_bytes)
tmp_path = tmp.name
segments_out: list[SttSegment] = []
detected: str | None = None
try:
if self._backend_lib == "faster_whisper":
task = "translate" if translate_to_en else "transcribe"
segs, info = self._model.transcribe(
tmp_path,
language=language,
task=task,
)
detected = info.language
segments_out.extend(
SttSegment(
start_seconds=seg.start,
end_seconds=seg.end,
text=seg.text.strip(),
language=detected,
confidence=None,
)
for seg in segs
)
else:
# openai-whisper
task = "translate" if translate_to_en else "transcribe"
kwargs: dict = {"task": task}
if language:
kwargs["language"] = language
result = self._model.transcribe(tmp_path, **kwargs)
detected = result.get("language")
segments_out.extend(
SttSegment(
start_seconds=float(seg["start"]),
end_seconds=float(seg["end"]),
text=str(seg["text"]).strip(),
language=detected,
confidence=None,
)
for seg in result.get("segments", [])
)
finally:
import os
with contextlib.suppress(OSError):
os.unlink(tmp_path)
return segments_out, detected