"""Script to evaluate a pretrained SpeechBrain ASR model from the 🤗 Hub and report its WER and RTFx.

Authors
 * Adel Moumen 2023 <adel.moumen@univ-avignon.fr>
 * Sanchit Gandhi 2024 <sanchit@huggingface.co>
"""
| import argparse |
| import time |
|
|
| import evaluate |
| from normalizer import data_utils |
| from tqdm import tqdm |
| import torch |
| import speechbrain.inference.ASR as ASR |
| from speechbrain.utils.data_utils import batch_pad_right |
| import os |
|
|
def get_model(
    speechbrain_repository: str,
    speechbrain_pretrained_class_name: str,
    beam_size: int = None,
    ctc_weight_decode: float = None,
    **kwargs,
):
    """Fetch a pretrained SpeechBrain model from the SpeechBrain 🤗 Hub.

    Arguments
    ---------
    speechbrain_repository : str
        The name of the SpeechBrain repository to fetch the pretrained model from. E.g. `asr-crdnn-rnnlm-librispeech`.
        The model files are cached under ``pretrained_models/<speechbrain_repository>``.
    speechbrain_pretrained_class_name : str
        The name of the SpeechBrain pretrained class to fetch from ``speechbrain.inference.ASR``. E.g. `EncoderASR`.
        See: https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/inference/ASR.py
    beam_size : int, optional
        Size of the beam for decoding. When falsy (``None`` or ``0``) the model's own
        ``test_beam_size`` hyperparameter is left untouched.
    ctc_weight_decode : float, optional
        Weight of the CTC prob for decoding with joint CTC/Attn. When ``None`` the model's
        own value is kept. A value of exactly ``0.0`` additionally overrides ``scorer`` to ``null``.
    **kwargs
        Additional keyword arguments that override the default run options of the pretrained
        model (e.g. ``device="cuda:0"`` or ``precision="fp32"``). They take precedence over
        the defaults defined below.

    Returns
    -------
    SpeechBrain pretrained model
        The Pretrained model.

    Raises
    ------
    AttributeError
        If ``speechbrain_pretrained_class_name`` is not a class of ``speechbrain.inference.ASR``.

    Example
    -------
    >>> from open_asr_leaderboard.speechbrain.run_eval import get_model
    >>> model = get_model("asr-crdnn-rnnlm-librispeech", "EncoderASR", device="cuda:0")  # doctest: +SKIP
    """
    run_opt_defaults = {
        "device": "cuda",
        "data_parallel_count": -1,
        "data_parallel_backend": False,
        "distributed_launch": False,
        "distributed_backend": "nccl",
        "jit_module_keys": None,
        "precision": "fp16",
    }
    # Caller-supplied run options (notably ``device``) must win over the defaults;
    # otherwise the model would always be placed on the default "cuda" device.
    run_opts = {**run_opt_defaults, **kwargs}

    # Hyperparameter overrides for the model's hparams file; only values the
    # caller actually asked for are overridden.
    overrides_dict = {}
    if beam_size:
        overrides_dict["test_beam_size"] = beam_size
    if ctc_weight_decode is not None:
        overrides_dict["ctc_weight_decode"] = ctc_weight_decode

    override_lines = []
    if ctc_weight_decode == 0.0:  # ``None == 0.0`` is False, so None never gets here
        # With a zero CTC weight the ``scorer`` hyperparameter is also nulled out.
        # NOTE(review): assumes the model's hparams define ``scorer`` — confirm per model.
        override_lines.append("scorer: null")
    override_lines.extend(f"{key}: {value}" for key, value in overrides_dict.items())
    # Overrides are passed as YAML text: one "key: value" entry per line.
    overrides_str = "\n".join(override_lines) if override_lines else None

    from_hparams_kwargs = {
        "source": str(speechbrain_repository),
        "savedir": f"pretrained_models/{speechbrain_repository}",
        "run_opts": run_opts,
    }
    if overrides_str:
        from_hparams_kwargs["overrides"] = overrides_str

    try:
        model_class = getattr(ASR, speechbrain_pretrained_class_name)
    except AttributeError as err:
        raise AttributeError(
            f"SpeechBrain Pretrained class: {speechbrain_pretrained_class_name} "
            "not found in speechbrain.inference.ASR"
        ) from err

    return model_class.from_hparams(**from_hparams_kwargs)
|
|
|
|
def main(args):
    """Run the evaluation script.

    Loads the requested SpeechBrain model, optionally runs un-timed warm-up
    batches, then transcribes the (optionally subsampled) dataset batch by batch
    while timing each ``transcribe_batch`` call. The references, predictions and
    per-sample timings are written to a manifest with ``data_utils.write_manifest``.
    Finally, the WER (computed on ``data_utils.normalizer``-normalised text) and the
    RTFx (total audio seconds / total transcription seconds) are printed.

    Arguments
    ---------
    args : argparse.Namespace
        Parsed command-line arguments.

    Raises
    ------
    ValueError
        If no sample was transcribed, so WER and RTFx cannot be computed.
    """
    # -1 selects the CPU; any other value is used as a CUDA device index.
    device = "cpu" if args.device == -1 else f"cuda:{args.device}"
    # CUDA mixed precision only applies to GPU runs; on the CPU the CUDA autocast
    # context has no effect, so it is disabled there.
    use_cuda_autocast = device != "cpu"

    # ``device`` is passed through get_model's **kwargs (run options).
    model = get_model(
        args.source,
        args.speechbrain_pretrained_class_name,
        args.beam_size,
        args.ctc_weight_decode,
        device=device,
    )
    print(f"Model size: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters")

    def benchmark(batch):
        """Transcribe one batch and add timing, prediction and reference columns to it."""
        audios = [torch.from_numpy(sample["array"]) for sample in batch["audio"]]
        minibatch_size = len(audios)
        # NOTE(review): assumes every clip in the batch shares the first clip's sampling rate.
        sampling_rate = batch["audio"][0]["sampling_rate"]
        batch["audio_length_s"] = [len(sample["array"]) / sampling_rate for sample in batch["audio"]]
        batch["audio_filepath"] = data_utils.extract_audio_filepaths_from_batch(batch, minibatch_size)

        # Right-pad the clips to a common length; audio_lens carries each clip's
        # length as returned by batch_pad_right.
        audios, audio_lens = batch_pad_right(audios)
        audios = audios.to(device)
        audio_lens = audio_lens.to(device)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        with torch.autocast(device_type="cuda", enabled=use_cuda_autocast):
            predictions, _ = model.transcribe_batch(audios, audio_lens)
        runtime = time.perf_counter() - start_time

        # Attribute the batch runtime evenly to every sample of the batch.
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]
        batch["predictions"] = predictions
        batch["references"] = batch["original_text"]
        return batch

    # Warm-up: run a few batches whose timings are discarded.
    if args.warmup_steps is not None and args.warmup_steps > 0:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True))

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    # Timed evaluation run.
    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
        "audio_filepath": [],
    }
    # Iterating the (possibly lazy) mapped dataset is what triggers transcription.
    for result in tqdm(iter(dataset), desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.source,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
        audio_filepaths=all_results["audio_filepath"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    # Without any sample the RTFx below would divide by zero.
    if not all_results["predictions"]:
        raise ValueError("No samples were transcribed, so WER and RTFx cannot be computed.")

    wer_metric = evaluate.load("wer")
    norm_refs = [data_utils.normalizer(ref) for ref in all_results["references"]]
    norm_preds = [data_utils.normalizer(pred) for pred in all_results["predictions"]]
    wer = wer_metric.compute(references=norm_refs, predictions=norm_preds)
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Model selection.
    parser.add_argument(
        "--source",
        type=str,
        required=True,
        help="SpeechBrain model repository. E.g. `asr-crdnn-rnnlm-librispeech`",
    )
    parser.add_argument(
        "--speechbrain_pretrained_class_name",
        type=str,
        required=True,
        help="SpeechBrain pretrained class name. E.g. `EncoderASR`",
    )

    # Data selection.
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="hf-audio/open-asr-leaderboard",
        help="Dataset path. By default, it is `hf-audio/open-asr-leaderboard`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr'` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/hf-audio/open-asr-leaderboard`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation'` for the dev split, or `'test'` for the test split.",
    )

    # Runtime / benchmarking options.
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--streaming",
        action="store_true",
        help="Stream the dataset lazily over the network instead of downloading it in full before the evaluation. Off by default for reproducible benchmark timings.",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=2,
        help="Number of warm-up steps to run before launching the timed runs.",
    )

    # Decoding overrides forwarded to the model's hyperparameters.
    parser.add_argument(
        "--beam_size",
        type=int,
        default=None,
        help="Beam size for decoding. Leave unset to keep the model's default.",
    )
    parser.add_argument(
        "--ctc_weight_decode",
        type=float,
        default=None,
        help="Weight of CTC for joint CTC/Att. decoding. Only pass for models that support it (e.g. EncoderDecoderASR).",
    )
    args = parser.parse_args()

    main(args)
|
|