File size: 5,099 Bytes
ffdc43e
 
 
 
 
 
 
 
 
 
 
ab061ab
ffdc43e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab061ab
ffdc43e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import torch
import datetime
import gc
from pathlib import Path
from pydub import AudioSegment
from nemo.collections.asr.models import ASRModel

# Global model cache to avoid reloading.
# Keyed by model name; each value is the (model, device) tuple stored by get_model().
_MODEL_CACHE = {}

def get_model(model_name="nvidia/parakeet-tdt-0.6b-v3"):
    """Return the cached ``(model, device)`` pair for *model_name*.

    On the first request for a given name the pretrained ASR model is
    loaded, switched to eval mode, and stored in the module-level cache
    together with the device string ("cuda" when available, else "cpu").
    """
    # Fast path: this model has already been loaded.
    if model_name in _MODEL_CACHE:
        return _MODEL_CACHE[model_name]

    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    asr_model = ASRModel.from_pretrained(model_name=model_name)
    asr_model.eval()

    entry = (asr_model, target_device)
    _MODEL_CACHE[model_name] = entry
    return entry

def format_srt_time(seconds: float) -> str:
    """Render a duration in seconds as an SRT timestamp (``HH:MM:SS,mmm``).

    Negative inputs are clamped to zero; the millisecond field is truncated,
    not rounded.
    """
    span = datetime.timedelta(seconds=max(0.0, seconds))

    # Split the whole-second count into hours / minutes / seconds.
    whole_seconds = int(span.total_seconds())
    hh, leftover = divmod(whole_seconds, 3600)
    mm, ss = divmod(leftover, 60)

    # The sub-second part is kept by timedelta as microseconds.
    ms = span.microseconds // 1000

    return f"{hh:02d}:{mm:02d}:{ss:02d},{ms:03d}"

def generate_srt_content(segment_timestamps: list) -> str:
    """Build the text of an SRT file from a list of segment entries.

    Each entry is read through its 'start', 'end' and 'segment' keys and
    becomes a numbered cue (starting at 1) followed by an empty line.
    """
    lines = []
    for cue_number, entry in enumerate(segment_timestamps, start=1):
        begin = format_srt_time(entry['start'])
        finish = format_srt_time(entry['end'])
        caption = entry['segment']
        # Cue index, time range, caption text, then the blank separator line.
        lines.extend((str(cue_number), f"{begin} --> {finish}", caption, ""))
    return "\n".join(lines)

def transcribe_audio_to_srt(audio_path: str, output_srt_path: str, model_name="nvidia/parakeet-tdt-0.6b-v3") -> str:
    """
    Transcribe an audio file and save it as an SRT file.

    The audio is loaded with pydub. If it is not already 16 kHz mono it is
    converted and exported to a uniquely named temporary WAV file in the
    directory of ``output_srt_path``; that file is removed before returning.
    For audio longer than 8 minutes the model is switched to local attention
    for the duration of the call and reverted afterwards.

    Args:
        audio_path: Path to the input audio file.
        output_srt_path: Path where the SRT file will be saved.
        model_name: The NeMo ASR model name to use.

    Returns:
        Path to the generated SRT file.

    Raises:
        FileNotFoundError: If ``audio_path`` is empty or does not exist.
        RuntimeError: If transcription returns no usable output or the
            output carries no segment timestamps.
    """
    # Function-scope import keeps this block independent of the file's import header.
    import tempfile

    if not audio_path or not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    model, device = get_model(model_name)

    processed_audio_path = None
    long_audio_settings_applied = False
    try:
        # Load audio to check duration and preprocess if needed
        audio = AudioSegment.from_file(audio_path)
        duration_sec = audio.duration_seconds

        target_sr = 16000
        needs_export = False
        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)
            needs_export = True
        if audio.channels > 1:
            audio = audio.set_channels(1)
            needs_export = True

        if needs_export:
            # Create a temporary WAV file for NeMo. mkstemp guarantees a fresh,
            # unique name, so an existing "<stem>_processed.wav" belonging to
            # the user is never overwritten or deleted by the cleanup below.
            temp_dir = os.path.dirname(output_srt_path) or "."
            fd, processed_audio_path = tempfile.mkstemp(
                prefix=f"{Path(audio_path).stem}_processed_",
                suffix=".wav",
                dir=temp_dir,
            )
            # Close the OS-level handle; pydub reopens the path itself.
            os.close(fd)
            audio.export(processed_audio_path, format="wav")
            transcribe_path = processed_audio_path
        else:
            transcribe_path = audio_path

        # Transcription logic
        model.to(device)
        model.to(torch.float32)

        if duration_sec > 480:  # 8 minutes
            try:
                model.change_attention_model("rel_pos_local_attn", [256, 256])
                model.change_subsampling_conv_chunking_factor(1)
                long_audio_settings_applied = True
            except Exception as e:
                # Best effort: fall back to the default attention settings.
                print(f"Warning: Failed to apply long audio settings: {e}")

        model.to(torch.bfloat16)
        output = model.transcribe([transcribe_path], timestamps=True)

        if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp'):
            raise RuntimeError("Transcription failed or produced unexpected output format.")

        try:
            segment_timestamps = output[0].timestamp['segment']
        except (KeyError, TypeError, IndexError) as err:
            # Report a missing/ill-typed timestamp container the same way as
            # the other "unexpected output" cases instead of a bare KeyError.
            raise RuntimeError("Transcription failed or produced unexpected output format.") from err

        srt_content = generate_srt_content(segment_timestamps)

        with open(output_srt_path, "w", encoding="utf-8") as f:
            f.write(srt_content)

        return output_srt_path

    finally:
        # Cleanup of the temporary WAV (best effort, but no longer silent).
        if processed_audio_path and os.path.exists(processed_audio_path):
            try:
                os.remove(processed_audio_path)
            except OSError as e:
                print(f"Warning: Failed to remove temporary file {processed_audio_path}: {e}")

        # Revert model settings if needed. The model object is cached and
        # shared between calls, so a failed revert is worth reporting.
        if long_audio_settings_applied:
            try:
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
            except Exception as e:
                print(f"Warning: Failed to revert long audio settings: {e}")

        # Release unreferenced objects and cached GPU memory; the model itself
        # stays on its device.
        gc.collect()
        if device == 'cuda':
            torch.cuda.empty_cache()