import datetime
import gc
import os
import tempfile
from pathlib import Path

import torch
from pydub import AudioSegment
from nemo.collections.asr.models import ASRModel

# Global model cache to avoid reloading: model name -> (model, device string).
_MODEL_CACHE = {}

# Sample rate the audio is resampled to before being handed to the model.
_TARGET_SAMPLE_RATE = 16000

# Audio longer than this many seconds gets local attention + chunked
# subsampling applied for the duration of the call.
_LONG_AUDIO_THRESHOLD_SEC = 480  # 8 minutes


def get_model(model_name="nvidia/parakeet-tdt-0.6b-v3"):
    """Return ``(model, device)`` for *model_name*, loading and caching it on first use.

    The device is "cuda" when CUDA is available, otherwise "cpu". The model is
    put in eval mode but is not moved to the device here; callers do that.
    """
    if model_name not in _MODEL_CACHE:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = ASRModel.from_pretrained(model_name=model_name)
        model.eval()
        _MODEL_CACHE[model_name] = (model, device)
    return _MODEL_CACHE[model_name]


def format_srt_time(seconds: float) -> str:
    """Convert seconds to the SRT timestamp format ``HH:MM:SS,mmm``.

    Negative inputs are clamped to zero; milliseconds are truncated, not rounded.
    """
    delta = datetime.timedelta(seconds=max(0.0, seconds))
    # Take whole seconds from the timedelta's own integer fields so they always
    # agree with ``delta.microseconds`` (no float round-trip).
    total_seconds = delta.days * 86400 + delta.seconds
    hours, remainder = divmod(total_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    milliseconds = delta.microseconds // 1000
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"


def generate_srt_content(segment_timestamps: list) -> str:
    """Build an SRT document from segment dicts with 'start', 'end' and 'segment' keys.

    Each cue is numbered from 1 and followed by a blank line.
    """
    srt_content = []
    for index, ts in enumerate(segment_timestamps, start=1):
        start_time = format_srt_time(ts['start'])
        end_time = format_srt_time(ts['end'])
        srt_content.append(str(index))
        srt_content.append(f"{start_time} --> {end_time}")
        srt_content.append(ts['segment'])
        srt_content.append("")
    return "\n".join(srt_content)


def _capture_encoder_settings(model):
    """Snapshot the encoder settings that the long-audio path modifies.

    Returns ``(self_attention_model, att_context_size, subsampling_conv_chunking_factor)``.
    NOTE(review): the attribute names assume NeMo's Conformer-style encoder
    (``model.encoder`` / ``model.encoder.pre_encode``); when an attribute is
    absent the values this module previously reverted to are used instead
    ("rel_pos", full context [-1, -1], chunking factor -1).
    """
    encoder = getattr(model, "encoder", None)
    attention_model = getattr(encoder, "self_attention_model", None) or "rel_pos"
    context = getattr(encoder, "att_context_size", None)
    att_context_size = list(context) if context is not None else [-1, -1]
    pre_encode = getattr(encoder, "pre_encode", None)
    chunking_factor = getattr(pre_encode, "subsampling_conv_chunking_factor", -1)
    return attention_model, att_context_size, chunking_factor


def _restore_encoder_settings(model, settings) -> None:
    """Re-apply a snapshot produced by ``_capture_encoder_settings``."""
    attention_model, att_context_size, chunking_factor = settings
    # The context size must be passed explicitly: changing only the attention
    # type would keep the local-attention window in the attention mask.
    model.change_attention_model(attention_model, att_context_size)
    model.change_subsampling_conv_chunking_factor(chunking_factor)


def transcribe_audio_to_srt(
    audio_path: str,
    output_srt_path: str,
    model_name="nvidia/parakeet-tdt-0.6b-v3",
    long_audio_threshold_sec: float = _LONG_AUDIO_THRESHOLD_SEC,
) -> str:
    """
    Transcribe an audio file and save it as an SRT file.

    Args:
        audio_path: Path to the input audio file.
        output_srt_path: Path where the SRT file will be saved.
        model_name: The NeMo ASR model name to use.
        long_audio_threshold_sec: Audio longer than this (seconds) is transcribed
            with local attention and chunked subsampling; the cached model's
            original settings are restored afterwards.

    Returns:
        Path to the generated SRT file.

    Raises:
        FileNotFoundError: If *audio_path* is empty or does not exist.
        RuntimeError: If the model output lacks segment timestamps.
    """
    if not audio_path or not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    model, device = get_model(model_name)
    processed_audio_path = None
    # Snapshot of the encoder settings; non-None means they may have been
    # modified and must be restored before the shared model is reused.
    original_settings = None

    try:
        # Load audio to check duration and preprocess if needed.
        audio = AudioSegment.from_file(audio_path)
        duration_sec = audio.duration_seconds

        resampled = audio.frame_rate != _TARGET_SAMPLE_RATE
        if resampled:
            audio = audio.set_frame_rate(_TARGET_SAMPLE_RATE)
        downmixed = audio.channels > 1
        if downmixed:
            audio = audio.set_channels(1)

        if resampled or downmixed:
            # Uniquely named temporary WAV next to the output, so an existing
            # "<stem>_processed.wav" is never overwritten and then deleted, and
            # concurrent calls do not collide.
            output_dir = os.path.dirname(output_srt_path) or "."
            fd, processed_audio_path = tempfile.mkstemp(
                prefix=f"{Path(audio_path).stem}_processed_",
                suffix=".wav",
                dir=output_dir,
            )
            os.close(fd)
            audio.export(processed_audio_path, format="wav")
            transcribe_path = processed_audio_path
        else:
            transcribe_path = audio_path
        # Only the duration is needed from here on; release the decoded samples
        # before inference to lower peak memory on long files.
        del audio

        model.to(device)
        # Switch to float32 before possibly swapping encoder modules so all
        # layers share one dtype; the model is cast to bfloat16 afterwards.
        model.to(torch.float32)
        if duration_sec > long_audio_threshold_sec:
            # Record the snapshot before changing anything so that even a
            # partially applied change is reverted in ``finally``.
            original_settings = _capture_encoder_settings(model)
            try:
                model.change_attention_model("rel_pos_local_attn", [256, 256])
                model.change_subsampling_conv_chunking_factor(1)
            except Exception as e:
                print(f"Warning: Failed to apply long audio settings: {e}")
        model.to(torch.bfloat16)

        output = model.transcribe([transcribe_path], timestamps=True)
        if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp'):
            raise RuntimeError("Transcription failed or produced unexpected output format.")
        try:
            segment_timestamps = output[0].timestamp['segment']
        except (KeyError, TypeError) as exc:
            raise RuntimeError("Transcription failed or produced unexpected output format.") from exc

        srt_content = generate_srt_content(segment_timestamps)
        with open(output_srt_path, "w", encoding="utf-8") as f:
            f.write(srt_content)
        return output_srt_path
    finally:
        if processed_audio_path and os.path.exists(processed_audio_path):
            try:
                os.remove(processed_audio_path)
            except OSError as exc:
                print(f"Warning: Failed to remove temporary file {processed_audio_path}: {exc}")

        if original_settings is not None:
            try:
                # Mirror the apply path: swap modules while the model is float32.
                model.to(torch.float32)
                _restore_encoder_settings(model, original_settings)
            except Exception as exc:
                # A half-restored model must not stay in the shared cache;
                # evict it so the next call reloads a clean copy.
                print(f"Warning: Failed to restore model settings, evicting cached model: {exc}")
                _MODEL_CACHE.pop(model_name, None)

        gc.collect()
        if device == 'cuda':
            torch.cuda.empty_cache()