| import enum |
|
|
| import numpy as np |
| import torch |
| import torchaudio |
| import transformers |
| import wavlm_phoneme_fr_it |
|
|
# Sample rate (Hz) that audio is resampled to and reported to the processor.
SAMPLING_RATE = 16000
|
|
class Languages(enum.Enum):
    """Languages the phonemizer can be conditioned on.

    Member values coincide with the numeric language codes that
    ``prepare_model_inputs`` sends to the model (FR -> 0.0, IT -> 1.0).
    """

    FR = 0  # French
    IT = 1  # Italian
|
|
|
|
class Scoring(enum.Enum):
    """Named scoring modes.

    NOTE(review): nothing in this part of the file consumes these members;
    their exact semantics live wherever they are used -- confirm there.
    """

    NUMBER_CORRECT = 0
    PHONEME_DELETION = 1
|
|
|
|
def get_model(checkpoint="hugofara/wavlm-base-plus-phonemizer-fr-it"):
    """Load the phonemizer model and its processor from a pretrained checkpoint.

    Args:
        checkpoint: Hugging Face Hub id or local directory passed to both
            ``from_pretrained`` calls. Defaults to the FR/IT WavLM phonemizer
            checkpoint this module was written for.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # The special tokens are set explicitly instead of relying solely on the
    # checkpoint's tokenizer configuration.
    processor = transformers.AutoProcessor.from_pretrained(
        checkpoint,
        unk_token="[UNK]",
        pad_token="[PAD]",
        word_delimiter_token="|",
    )
    model = wavlm_phoneme_fr_it.WavLMPhonemeFrIt.from_pretrained(checkpoint)
    return model, processor
|
|
|
|
def preprocess_audio(audio_data, target_sample_rate=SAMPLING_RATE):
    """Convert raw audio into a mono, peak-normalised float32 array at the target rate.

    Args:
        audio_data: ``(sample_rate, samples)`` tuple, or ``None``. ``samples``
            is a NumPy array; when it has more than one dimension the channels
            are averaged along axis 1, i.e. the layout is assumed to be
            ``(n_samples, n_channels)`` -- TODO confirm against the caller.
        target_sample_rate: Sample rate to resample to when it differs from
            the input's rate.

    Returns:
        ``None`` if ``audio_data`` is ``None``. Otherwise a float32 array
        (1-D for 1-D or 2-D input) scaled so that its largest absolute value
        is 1.0. All-zero or empty audio is returned unscaled.
    """
    if audio_data is None:
        return None

    sample_rate, audio = audio_data

    # Downmix multi-channel audio to mono by averaging across axis 1.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Empty input: nothing to resample or normalise, and np.max would raise
    # "zero-size array to reduction operation" below.
    if audio.size == 0:
        return audio.astype(np.float32)

    if sample_rate != target_sample_rate:
        # torch.from_numpy rejects arrays with negative strides (e.g. reversed
        # views), so hand it a contiguous array.
        audio_tensor = torch.from_numpy(np.ascontiguousarray(audio)).float().unsqueeze(0)
        resampled = torchaudio.transforms.Resample(sample_rate, target_sample_rate)(audio_tensor)
        audio = resampled.squeeze(0).numpy()

    # Cast before abs() so integer input (e.g. int16 -32768) cannot overflow.
    audio = audio.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak
    return audio
|
|
|
|
def prepare_model_inputs(audio, processor, sampling_rate=SAMPLING_RATE, language=Languages.FR):
    """Build the keyword arguments for a forward pass of the model.

    Args:
        audio: Audio handed directly to ``processor`` (for example the output
            of ``preprocess_audio``).
        processor: Processor such as the one returned by ``get_model``.
        sampling_rate: Sample rate of ``audio``, forwarded to the processor.
        language: A ``Languages`` member, or its integer value.

    Returns:
        The processor output (padded PyTorch tensors) with an extra
        ``"language"`` entry: a float32 tensor of shape ``(1, 1)`` holding the
        language's code as a float (FR -> 0.0, IT -> 1.0).

    Raises:
        ValueError: If ``language`` is neither a ``Languages`` member nor one
            of their values. Previously such input was silently treated as
            Italian.
    """
    # Resolve the language before the (comparatively expensive) processor call
    # so invalid input fails fast. Languages(...) accepts a member or a value.
    language_code = float(Languages(language).value)

    inputs = processor(
        audio,
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding=True,
    )
    # NOTE(review): the (1, 1) shape assumes a single-utterance batch -- confirm
    # what the model expects if several clips are ever processed together.
    inputs["language"] = torch.tensor([[language_code]], dtype=torch.float32)
    return inputs
|
|
|
|
def run_inference(model, inputs):
    """Run the model without gradient tracking and return its argmax ids.

    Args:
        model: Callable invoked as ``model(**inputs)``; its result must expose
            a ``logits`` tensor.
        inputs: Mapping of keyword arguments for the model (for example from
            ``prepare_model_inputs``).

    Returns:
        ``(outputs, predicted_ids)``: the raw model output, and the index of
        the largest logit along the last dimension of ``outputs.logits``.
    """
    with torch.no_grad():
        model_output = model(**inputs)
        best_ids = model_output.logits.argmax(dim=-1)
    return model_output, best_ids
|
|
|
|
def decode_transcription(processor, predicted_ids):
    """Return the decoded text of the first sequence in ``predicted_ids``.

    ``processor.batch_decode`` decodes the whole batch; only its first entry
    is returned, and any further sequences are discarded.
    """
    decoded_batch = processor.batch_decode(predicted_ids)
    first_text = decoded_batch[0]
    return first_text
|
|
|
|
def compare_with_target(transcription, target_word):
    """Format a transcription and report whether a target string occurs in it.

    Args:
        transcription: Decoded model output (a string).
        target_word: Text to search for. ``None``, ``""`` or whitespace-only
            skips the comparison.

    Returns:
        A Markdown string with the transcription and, when a target is given,
        a match / no-match line. A match is a case-insensitive substring test
        in which literal ``[pad]`` tokens are ignored and runs of whitespace
        are treated as a single space.
    """
    result = f"**Transcription:** {transcription}\n\n"

    if not target_word or not target_word.strip():
        return result

    # Collapse whitespace on both sides: stripping "[pad]" tokens out of the
    # middle of the transcription (or extra spaces typed by the user) would
    # otherwise leave double spaces that turn a real match into a miss.
    target_clean = " ".join(target_word.lower().split())
    transcription_clean = " ".join(transcription.lower().replace("[pad]", "").split())

    if target_clean in transcription_clean:
        result += f"✅ **Match found!** The target word '{target_word}' appears in the transcription."
    else:
        result += f"❌ **No exact match.** The target word '{target_word}' was not found in the transcription."
    return result
|
|