Spaces:
Sleeping
Sleeping
File size: 5,099 Bytes
ffdc43e ab061ab ffdc43e ab061ab ffdc43e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | import os
import torch
import datetime
import gc
from pathlib import Path
from pydub import AudioSegment
from nemo.collections.asr.models import ASRModel
# Process-wide cache mapping a model name to its loaded ``(model, device)``
# pair, so each model is only loaded once.
_MODEL_CACHE = {}


def get_model(model_name="nvidia/parakeet-tdt-0.6b-v3"):
    """Return the cached ``(model, device)`` pair for *model_name*.

    On the first request for a given name the model is loaded with
    ``ASRModel.from_pretrained``, switched to eval mode and stored in the
    module-level cache together with the device string chosen for it
    (``"cuda"`` when CUDA is available, otherwise ``"cpu"``). The model is
    not moved to that device here; the caller does that.
    """
    # Fast path: already loaded.
    if model_name in _MODEL_CACHE:
        return _MODEL_CACHE[model_name]

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    asr_model = ASRModel.from_pretrained(model_name=model_name)
    asr_model.eval()

    _MODEL_CACHE[model_name] = (asr_model, target_device)
    return _MODEL_CACHE[model_name]
def format_srt_time(seconds: float) -> str:
    """Render a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Negative inputs are clamped to zero. The millisecond field is the
    timedelta's microsecond component floor-divided by 1000. The hour field
    is zero-padded to two digits and simply grows wider for durations of
    100 hours or more.
    """
    span = datetime.timedelta(seconds=max(0.0, seconds))
    whole_seconds = int(span.total_seconds())

    hh, leftover = divmod(whole_seconds, 3600)
    mm, ss = divmod(leftover, 60)
    millis = span.microseconds // 1000

    return f"{hh:02d}:{mm:02d}:{ss:02d},{millis:03d}"
def generate_srt_content(segment_timestamps: list) -> str:
    """Build the text of an SRT file from a list of timed segments.

    Each item must provide ``'start'`` and ``'end'`` (seconds) and
    ``'segment'`` (the caption text). Every cue is emitted as a 1-based
    index line, a ``start --> end`` line, the text, and a blank separator
    line; all lines are joined with ``"\\n"``. An empty input yields an
    empty string.
    """
    lines = []
    for cue_number, entry in enumerate(segment_timestamps, start=1):
        time_range = (
            f"{format_srt_time(entry['start'])} --> {format_srt_time(entry['end'])}"
        )
        lines.extend((str(cue_number), time_range, entry['segment'], ""))
    return "\n".join(lines)
def transcribe_audio_to_srt(audio_path: str, output_srt_path: str, model_name="nvidia/parakeet-tdt-0.6b-v3") -> str:
    """
    Transcribe an audio file and save it as an SRT file.

    The audio is loaded with pydub; if it is not already 16 kHz mono it is
    converted and exported to a uniquely named temporary WAV file next to the
    output file, which is removed again before returning. Recordings longer
    than 8 minutes switch the model to local attention with subsampling
    chunking for the duration of the call; the settings are reverted
    afterwards. Cleanup problems are reported as warnings and never mask the
    result or the original exception.

    Args:
        audio_path: Path to the input audio file.
        output_srt_path: Path where the SRT file will be saved.
        model_name: The NeMo ASR model name to use.

    Returns:
        Path to the generated SRT file.

    Raises:
        FileNotFoundError: If ``audio_path`` is empty or does not exist.
        RuntimeError: If transcription returns no usable timestamped output.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    if not audio_path or not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    model, device = get_model(model_name)
    processed_audio_path = None
    long_audio_settings_applied = False

    try:
        # Load audio to check duration and preprocess if needed.
        audio = AudioSegment.from_file(audio_path)
        duration_sec = audio.duration_seconds

        target_sr = 16000
        needs_export = False
        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)
            needs_export = True
        if audio.channels > 1:
            audio = audio.set_channels(1)
            needs_export = True

        if needs_export:
            # A unique temp name avoids overwriting (and later deleting) an
            # unrelated "<stem>_processed.wav" and avoids collisions between
            # concurrent calls. The file stays in the output directory, as
            # before ("" means the current directory).
            temp_dir = os.path.dirname(output_srt_path) or "."
            fd, processed_audio_path = tempfile.mkstemp(
                prefix=f"{Path(audio_path).stem}_processed_",
                suffix=".wav",
                dir=temp_dir,
            )
            # pydub reopens the path itself; release our handle first.
            os.close(fd)
            audio.export(processed_audio_path, format="wav")
            transcribe_path = processed_audio_path
        else:
            transcribe_path = audio_path

        # Transcription logic.
        model.to(device)
        # The cached model is left in bfloat16 by a previous call, so reset
        # the dtype before reconfiguring it.
        model.to(torch.float32)

        if duration_sec > 480:  # 8 minutes
            try:
                model.change_attention_model("rel_pos_local_attn", [256, 256])
                model.change_subsampling_conv_chunking_factor(1)
                long_audio_settings_applied = True
            except Exception as e:
                # Best effort: fall back to the default settings.
                print(f"Warning: Failed to apply long audio settings: {e}")

        # NOTE(review): bfloat16 is applied on CPU as well as CUDA — confirm
        # this is intended for CPU-only hosts.
        model.to(torch.bfloat16)
        output = model.transcribe([transcribe_path], timestamps=True)

        if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp'):
            raise RuntimeError("Transcription failed or produced unexpected output format.")

        segment_timestamps = output[0].timestamp['segment']
        srt_content = generate_srt_content(segment_timestamps)

        with open(output_srt_path, "w", encoding="utf-8") as f:
            f.write(srt_content)

        return output_srt_path
    finally:
        # Remove the temporary WAV; a missing file is fine, anything else is
        # reported but must not mask the real outcome.
        if processed_audio_path:
            try:
                os.remove(processed_audio_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                print(f"Warning: Failed to remove temporary file {processed_audio_path}: {e}")

        # Revert model settings so the cached model is back to its defaults
        # for the next call.
        if long_audio_settings_applied:
            try:
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
            except Exception as e:
                print(f"Warning: Failed to revert long audio settings: {e}")

        # The model intentionally stays on its device so the cached instance
        # can be reused; call model.cpu() here if GPU memory must be freed
        # between calls.
        gc.collect()
        if device == 'cuda':
            torch.cuda.empty_cache()
|