# SRT-Processing-Tool / tools/audio_transcriber.py
# Last change by BiliSakura (commit ab061ab):
#   Update Parakeet TDT model to version 0.6b-v3 and enhance audio transcription tool
import os
import torch
import datetime
import gc
from pathlib import Path
from pydub import AudioSegment
from nemo.collections.asr.models import ASRModel
# Global model cache to avoid reloading
_MODEL_CACHE = {}
def get_model(model_name="nvidia/parakeet-tdt-0.6b-v3"):
"""Get or load the ASR model."""
if model_name not in _MODEL_CACHE:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ASRModel.from_pretrained(model_name=model_name)
model.eval()
_MODEL_CACHE[model_name] = (model, device)
return _MODEL_CACHE[model_name]
def format_srt_time(seconds: float) -> str:
    """Render an offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Negative inputs are clamped to zero. The value is first converted to a
    ``datetime.timedelta`` (which rounds to whole microseconds), and the
    millisecond field is the truncated microsecond remainder. Hours are not
    wrapped at 24, so long offsets render as e.g. ``25:01:01,250``.
    """
    span = datetime.timedelta(seconds=max(0.0, seconds))
    total_minutes, secs = divmod(int(span.total_seconds()), 60)
    hrs, mins = divmod(total_minutes, 60)
    millis = span.microseconds // 1000
    return f"{hrs:02d}:{mins:02d}:{secs:02d},{millis:03d}"
def generate_srt_content(segment_timestamps: list) -> str:
    """Build the text of an SRT file from segment timestamps.

    Each item must provide ``'start'`` and ``'end'`` (seconds, formatted with
    ``format_srt_time``) and ``'segment'`` (the cue text). Every cue is emitted
    as its 1-based index, a ``start --> end`` line, the text and a blank line,
    all joined with ``"\\n"``. An empty input yields an empty string.
    """
    lines = [
        line
        for cue_number, seg in enumerate(segment_timestamps, start=1)
        for line in (
            str(cue_number),
            f"{format_srt_time(seg['start'])} --> {format_srt_time(seg['end'])}",
            seg['segment'],
            "",
        )
    ]
    return "\n".join(lines)
def transcribe_audio_to_srt(audio_path: str, output_srt_path: str, model_name="nvidia/parakeet-tdt-0.6b-v3") -> str:
    """
    Transcribe an audio file and save it as an SRT file.

    The audio is resampled to 16 kHz and/or downmixed to mono when needed; in
    that case a uniquely named temporary WAV is written in the directory of
    ``output_srt_path`` (or the current directory) and removed afterwards.
    Recordings longer than 8 minutes temporarily switch the cached model to
    local attention with subsampling chunking; those settings are reverted
    before returning.

    Args:
        audio_path: Path to the input audio file.
        output_srt_path: Path where the SRT file will be saved.
        model_name: The NeMo ASR model name to use.

    Returns:
        Path to the generated SRT file (``output_srt_path``).

    Raises:
        FileNotFoundError: If ``audio_path`` is empty or does not exist.
        RuntimeError: If the model output has no segment timestamps.
    """
    if not audio_path or not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    model, device = get_model(model_name)
    processed_audio_path = None
    long_audio_settings_applied = False
    try:
        # Load audio to check duration and preprocess if needed
        audio = AudioSegment.from_file(audio_path)
        duration_sec = audio.duration_seconds

        target_sr = 16000
        needs_conversion = False
        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)
            needs_conversion = True
        if audio.channels > 1:
            audio = audio.set_channels(1)
            needs_conversion = True

        if needs_conversion:
            import tempfile

            # Create a unique, freshly created file instead of the fixed name
            # "<stem>_processed.wav": a fixed name could overwrite (and the
            # cleanup below would then delete) an unrelated file of that name,
            # and concurrent calls on same-named inputs would collide.
            # The path is recorded before exporting so that a failed export
            # is still removed in the finally-block.
            fd, processed_audio_path = tempfile.mkstemp(
                prefix=f"{Path(audio_path).stem}_processed_",
                suffix=".wav",
                dir=os.path.dirname(output_srt_path) or os.curdir,
            )
            os.close(fd)
            audio.export(processed_audio_path, format="wav")
            transcribe_path = processed_audio_path
        else:
            transcribe_path = audio_path

        # Transcription logic
        model.to(device)
        # The cached model is left in bfloat16 by any previous call (it is never
        # cast back), so start from float32 before reconfiguring it.
        model.to(torch.float32)

        if duration_sec > 480:  # 8 minutes
            try:
                # Long-recording configuration; reverted in the finally-block.
                model.change_attention_model("rel_pos_local_attn", [256, 256])
                model.change_subsampling_conv_chunking_factor(1)
                long_audio_settings_applied = True
            except Exception as e:
                print(f"Warning: Failed to apply long audio settings: {e}")

        # NOTE(review): bfloat16 is applied on CPU as well as CUDA; confirm this
        # is intended for the CPU fallback selected in get_model().
        model.to(torch.bfloat16)
        output = model.transcribe([transcribe_path], timestamps=True)

        if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp'):
            raise RuntimeError("Transcription failed or produced unexpected output format.")
        timestamps = output[0].timestamp
        # Report a missing segment list as the documented RuntimeError rather
        # than letting a bare KeyError escape.
        if not timestamps or 'segment' not in timestamps:
            raise RuntimeError("Transcription failed or produced unexpected output format.")

        srt_content = generate_srt_content(timestamps['segment'])
        with open(output_srt_path, "w", encoding="utf-8") as f:
            f.write(srt_content)
        return output_srt_path
    finally:
        # Cleanup: best-effort removal of the temporary WAV, but say so if it fails.
        if processed_audio_path and os.path.exists(processed_audio_path):
            try:
                os.remove(processed_audio_path)
            except OSError as e:
                print(f"Warning: Failed to remove temporary audio file {processed_audio_path}: {e}")

        # Revert model settings if needed. The model is cached and reused by later
        # calls, so a failed revert must be reported instead of silently ignored.
        if long_audio_settings_applied:
            try:
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
            except Exception as e:
                print(f"Warning: Failed to revert long audio settings: {e}")

        # The model is left on `device` between calls; call model.cpu() here if
        # GPU memory must be released after each transcription.
        gc.collect()
        if device == 'cuda':
            torch.cuda.empty_cache()