Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import datetime | |
| import gc | |
| from pathlib import Path | |
| from pydub import AudioSegment | |
| from nemo.collections.asr.models import ASRModel | |
# Module-level cache of loaded ASR models, keyed by model name, so that a
# model is loaded at most once per process. get_model() stores and returns
# (model, device_name) tuples from this dict.
_MODEL_CACHE: dict = {}
def get_model(model_name="nvidia/parakeet-tdt-0.6b-v3"):
    """Return the ``(model, device)`` pair for ``model_name``, loading it on first use.

    A pair that is already present in ``_MODEL_CACHE`` is returned as-is.
    Otherwise the pretrained model is loaded, ``eval()`` is called on it, and
    it is stored in the cache together with the device name (``"cuda"`` when
    ``torch.cuda.is_available()`` is true, ``"cpu"`` otherwise).
    """
    # Fast path: this model name has been loaded before.
    if model_name in _MODEL_CACHE:
        return _MODEL_CACHE[model_name]

    # Only the device *name* is chosen here; this function does not move the
    # model onto that device.
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    asr_model = ASRModel.from_pretrained(model_name=model_name)
    asr_model.eval()

    cache_entry = (asr_model, target_device)
    _MODEL_CACHE[model_name] = cache_entry
    return cache_entry
def format_srt_time(seconds: float) -> str:
    """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Negative inputs are clamped to zero. The millisecond field is the
    truncated (not rounded) microsecond part of the duration, and the hour
    field grows beyond two digits when the duration needs it.
    """
    # timedelta splits the clamped value into whole seconds and microseconds.
    span = datetime.timedelta(seconds=max(0.0, seconds))
    whole_seconds = int(span.total_seconds())

    hh, leftover = divmod(whole_seconds, 3600)
    mm, ss = divmod(leftover, 60)
    millis = span.microseconds // 1000

    return f"{hh:02d}:{mm:02d}:{ss:02d},{millis:03d}"
def generate_srt_content(segment_timestamps: list) -> str:
    """Build the text of an SRT file from a list of segment dicts.

    Each item is read through its ``'start'``, ``'end'`` and ``'segment'``
    keys. Cues are numbered from 1 and each cue contributes four lines:
    its index, the ``start --> end`` time range, the text, and an empty
    separator line. An empty input yields an empty string.
    """
    lines = []
    for cue_number, entry in enumerate(segment_timestamps, start=1):
        begin = format_srt_time(entry['start'])
        finish = format_srt_time(entry['end'])
        lines.extend(
            (
                str(cue_number),
                f"{begin} --> {finish}",
                entry['segment'],
                "",  # blank line that terminates the cue
            )
        )
    return "\n".join(lines)
def transcribe_audio_to_srt(audio_path: str, output_srt_path: str, model_name="nvidia/parakeet-tdt-0.6b-v3") -> str:
    """Transcribe an audio file and save the result as an SRT file.

    The audio is converted to 16 kHz mono when it is not already in that
    format. The converted copy is written to a uniquely named temporary WAV
    file in the directory of ``output_srt_path`` and removed again before
    the function returns. For inputs longer than 8 minutes the model is
    switched to its ``"rel_pos_local_attn"`` attention model for the
    duration of the call and switched back afterwards, because the model
    object returned by ``get_model`` is cached and reused by later calls.

    Args:
        audio_path: Path to the input audio file.
        output_srt_path: Path where the SRT file will be saved.
        model_name: The NeMo ASR model name to use.

    Returns:
        Path to the generated SRT file (``output_srt_path``).

    Raises:
        FileNotFoundError: If ``audio_path`` is empty or does not exist.
        RuntimeError: If the model returns no result or a result without
            a ``timestamp`` attribute.
    """
    # Function-scope import: tempfile is only needed by this function.
    import tempfile

    if not audio_path or not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    model, device = get_model(model_name)
    processed_audio_path = None
    long_audio_settings_applied = False
    try:
        # Load the audio to check its duration and preprocess it if needed.
        audio = AudioSegment.from_file(audio_path)
        duration_sec = audio.duration_seconds

        target_sr = 16000
        needs_export = False
        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)
            needs_export = True
        if audio.channels > 1:
            audio = audio.set_channels(1)
            needs_export = True

        if needs_export:
            # Use a unique temporary file name so that an existing
            # "<stem>_processed.wav" is never overwritten/deleted and
            # concurrent calls with the same stem cannot collide.
            fd, processed_audio_path = tempfile.mkstemp(
                suffix=".wav",
                prefix=f"{Path(audio_path).stem}_processed_",
                dir=os.path.dirname(output_srt_path) or os.curdir,
            )
            # Only the path is needed; pydub opens the file itself.
            os.close(fd)
            audio.export(processed_audio_path, format="wav")
            transcribe_path = processed_audio_path
        else:
            transcribe_path = audio_path

        # Transcription logic
        model.to(device)
        model.to(torch.float32)

        if duration_sec > 480:  # 8 minutes
            # Set the flag *before* touching the model: if the first change
            # succeeds and the second one raises, the cleanup below must
            # still revert the shared, cached model.
            long_audio_settings_applied = True
            try:
                model.change_attention_model("rel_pos_local_attn", [256, 256])
                model.change_subsampling_conv_chunking_factor(1)
            except Exception as e:
                print(f"Warning: Failed to apply long audio settings: {e}")

        # NOTE(review): bfloat16 is applied for every input length and on
        # CPU as well - confirm this is intended.
        model.to(torch.bfloat16)
        output = model.transcribe([transcribe_path], timestamps=True)

        if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp'):
            raise RuntimeError("Transcription failed or produced unexpected output format.")

        segment_timestamps = output[0].timestamp['segment']
        srt_content = generate_srt_content(segment_timestamps)

        with open(output_srt_path, "w", encoding="utf-8") as f:
            f.write(srt_content)

        return output_srt_path
    finally:
        # Best-effort cleanup: failures are reported but never mask the
        # result or the original exception.
        if processed_audio_path and os.path.exists(processed_audio_path):
            try:
                os.remove(processed_audio_path)
            except OSError as e:
                print(f"Warning: Failed to remove temporary file {processed_audio_path}: {e}")

        # Revert the model settings so later calls on the cached model
        # start from the default configuration.
        if long_audio_settings_applied:
            try:
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
            except Exception as e:
                print(f"Warning: Failed to revert long audio settings: {e}")

        gc.collect()
        if device == 'cuda':
            torch.cuda.empty_cache()