# bezzam's picture
# bezzam HF Staff
# Update run_eval.py
# a708c8f verified
"""Script to evaluate a pretrained SpeechBrain model from the 🤗 Hub.
Authors
* Adel Moumen 2023 <adel.moumen@univ-avignon.fr>
* Sanchit Gandhi 2024 <sanchit@huggingface.co>
"""
import argparse
import time
import evaluate
from normalizer import data_utils
from tqdm import tqdm
import torch
import speechbrain.inference.ASR as ASR
from speechbrain.utils.data_utils import batch_pad_right
import os
def get_model(
    speechbrain_repository: str,
    speechbrain_pretrained_class_name: str,
    beam_size: int,
    ctc_weight_decode: float,
    **kwargs,
):
    """Fetch a pretrained SpeechBrain model from the SpeechBrain 🤗 Hub.

    Arguments
    ---------
    speechbrain_repository : str
        The name of the SpeechBrain repository to fetch the pretrained model from. E.g. `asr-crdnn-rnnlm-librispeech`.
    speechbrain_pretrained_class_name : str
        The name of the SpeechBrain pretrained class to fetch. E.g. `EncoderASR`.
        See: https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/pretrained/interfaces.py
    beam_size : int
        Size of the beam for decoding. A falsy value (`None` or `0`) adds no
        `test_beam_size` override.
    ctc_weight_decode : float
        Weight of the CTC prob for decoding with joint CTC/Attn. `None` adds no
        `ctc_weight_decode` override; a value of exactly `0.0` additionally
        overrides `scorer` with null.
    **kwargs
        Additional keyword arguments to pass to override the default run options of the pretrained model
        (e.g. `device="cuda:0"`). They take precedence over the defaults defined in this function.

    Returns
    -------
    SpeechBrain pretrained model
        The object returned by `from_hparams` of the selected class.

    Raises
    ------
    AttributeError
        If `speechbrain_pretrained_class_name` is not an attribute of the
        imported `ASR` module.

    Example
    -------
    >>> from open_asr_leaderboard.speechbrain.run_eval import get_model
    >>> model = get_model("asr-crdnn-rnnlm-librispeech", "EncoderASR", None, None, device="cuda:0")
    """
    run_opt_defaults = {
        "device": "cuda",
        "data_parallel_count": -1,
        "data_parallel_backend": False,
        "distributed_launch": False,
        "distributed_backend": "nccl",
        "jit_module_keys": None,
        "precision": "fp16",
    }
    # Caller-supplied run options must win over the defaults; otherwise the
    # device requested by the caller is ignored and the model always loads on
    # the default device.
    run_opts = {**run_opt_defaults, **kwargs}

    # Build overrides as a YAML string so hyperpyyaml applies them during
    # parsing (before class imports), preventing ImportError for missing classes.
    override_lines = []
    if ctc_weight_decode is not None and ctc_weight_decode == 0.0:
        # NOTE(review): presumably the scorer is unnecessary when the CTC
        # weight is zero -- confirm against the model's hyperparameter file.
        override_lines.append("scorer: null")
    if beam_size:
        override_lines.append(f"test_beam_size: {beam_size}")
    if ctc_weight_decode is not None:
        override_lines.append(f"ctc_weight_decode: {ctc_weight_decode}")

    # Use a separate name so the caller's **kwargs are not shadowed.
    from_hparams_kwargs = {
        "source": f"{speechbrain_repository}",
        "savedir": f"pretrained_models/{speechbrain_repository}",
        "run_opts": run_opts,
    }
    # Only pass `overrides` when there is something to override.
    if override_lines:
        from_hparams_kwargs["overrides"] = "\n".join(override_lines)

    try:
        model_class = getattr(ASR, speechbrain_pretrained_class_name)
    except AttributeError as err:
        raise AttributeError(
            f"SpeechBrain Pretrained class: {speechbrain_pretrained_class_name} not found in pretrained.py"
        ) from err

    return model_class.from_hparams(**from_hparams_kwargs)
def main(args):
    """Run the evaluation script.

    Loads the model named by ``args``, optionally performs warm-up batches,
    transcribes the selected dataset split, writes a manifest with the raw
    predictions/references, and prints the word error rate (in percent) and
    RTFx (total audio seconds divided by total transcription seconds).

    Arguments
    ---------
    args : argparse.Namespace
        Parsed command-line options; the attributes read here are ``device``,
        ``source``, ``speechbrain_pretrained_class_name``, ``beam_size``,
        ``ctc_weight_decode``, ``warmup_steps``, ``batch_size``, ``streaming``,
        ``max_eval_samples``, ``dataset_path``, ``dataset`` and ``split``.
        The whole namespace is also handed to ``data_utils.load_data``.
    """
    # -1 selects the CPU; any other value is used as a CUDA device index.
    device = "cpu" if args.device == -1 else f"cuda:{args.device}"

    model = get_model(
        args.source,
        args.speechbrain_pretrained_class_name,
        args.beam_size,
        args.ctc_weight_decode,
        device=device,
    )
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model size: {n_params / 1e9:.2f}B parameters")

    def benchmark(batch):
        """Transcribe one batch and attach lengths, timing, predictions and references."""
        samples = batch["audio"]
        # Load audio inputs
        waveforms = [torch.from_numpy(sample["array"]) for sample in samples]
        n_items = len(waveforms)

        rate = samples[0]["sampling_rate"]
        batch["audio_length_s"] = [len(sample["array"]) / rate for sample in samples]
        batch["audio_filepath"] = data_utils.extract_audio_filepaths_from_batch(batch, n_items)

        padded, lengths = batch_pad_right(waveforms)
        padded = padded.to(device)
        lengths = lengths.to(device)

        # Only the transcription call itself is timed.
        tic = time.time()
        with torch.autocast(device_type="cuda"):
            hypotheses, _ = model.transcribe_batch(padded, lengths)
        elapsed = time.time() - tic

        # The batch runtime is shared equally between its samples.
        batch["transcription_time_s"] = [elapsed / n_items] * n_items
        batch["predictions"] = hypotheses  # raw; normalization applied at scoring time
        batch["references"] = batch["original_text"]  # raw; normalization applied at scoring time
        return batch

    def first_n(ds, count):
        """Return the first ``count`` samples of ``ds`` for streaming or in-memory data."""
        if args.streaming:
            return ds.take(count)
        return ds.select(range(min(count, len(ds))))

    if args.warmup_steps is not None:
        warmup_source = data_utils.prepare_data(data_utils.load_data(args))
        warmup_subset = first_n(warmup_source, args.warmup_steps * args.batch_size)
        warmup_runs = warmup_subset.map(benchmark, batch_size=args.batch_size, batched=True)
        # Iterating the mapped warm-up data is what runs the batches; outputs are discarded.
        for _ in tqdm(iter(warmup_runs), desc="Warming up..."):
            pass

    dataset = data_utils.load_data(args)
    limit = args.max_eval_samples
    if limit is not None and limit > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        dataset = first_n(dataset, limit)
    dataset = data_utils.prepare_data(dataset)
    dataset = dataset.map(
        benchmark,
        batch_size=args.batch_size,
        batched=True,
        remove_columns=["audio"],
    )

    result_keys = (
        "audio_length_s",
        "transcription_time_s",
        "predictions",
        "references",
        "audio_filepath",
    )
    all_results = {key: [] for key in result_keys}
    for row in tqdm(iter(dataset), desc="Samples..."):
        for key, column in all_results.items():
            column.append(row[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.source,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
        audio_filepaths=all_results["audio_filepath"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    # Scoring is done on normalized text; the manifest above keeps the raw text.
    wer_metric = evaluate.load("wer")
    norm_refs = [data_utils.normalizer(ref) for ref in all_results["references"]]
    norm_preds = [data_utils.normalizer(pred) for pred in all_results["predictions"]]
    wer = round(100 * wer_metric.compute(references=norm_refs, predictions=norm_preds), 2)

    total_audio_s = sum(all_results["audio_length_s"])
    total_transcription_s = sum(all_results["transcription_time_s"])
    rtfx = round(total_audio_s / total_transcription_s, 2)
    print("WER:", wer, "%", "RTFx:", rtfx)
if __name__ == "__main__":
    # Declarative table of command-line options: (flag, argparse settings).
    # Options are registered in this order, which is also the --help order.
    cli_options = (
        (
            "--source",
            {
                "type": str,
                "required": True,
                "help": "SpeechBrain model repository. E.g. `asr-crdnn-rnnlm-librispeech`",
            },
        ),
        (
            "--speechbrain_pretrained_class_name",
            {
                "type": str,
                "required": True,
                "help": "SpeechBrain pretrained class name. E.g. `EncoderASR`",
            },
        ),
        (
            "--dataset_path",
            {
                "type": str,
                "default": "hf-audio/open-asr-leaderboard",
                "help": "Dataset path. By default, it is `hf-audio/open-asr-leaderboard`",
            },
        ),
        (
            "--dataset",
            {
                "type": str,
                "required": True,
                "help": "Dataset name. *E.g.* `'librispeech_asr` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
                "can be found at `https://huggingface.co/datasets/hf-audio/open-asr-leaderboard`",
            },
        ),
        (
            "--split",
            {
                "type": str,
                "default": "test",
                "help": "Split of the dataset. *E.g.* `'validation`' for the dev split, or `'test'` for the test split.",
            },
        ),
        (
            "--device",
            {
                "type": int,
                "default": -1,
                "help": "The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
            },
        ),
        (
            "--batch_size",
            {
                "type": int,
                "default": 16,
                "help": "Number of samples to go through each streamed batch.",
            },
        ),
        (
            "--max_eval_samples",
            {
                "type": int,
                "default": None,
                "help": "Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
            },
        ),
        (
            "--streaming",
            {
                "action": "store_true",
                "help": "Stream the dataset lazily over the network instead of downloading it in full before the evaluation. Off by default for reproducible benchmark timings.",
            },
        ),
        (
            "--warmup_steps",
            {
                "type": int,
                "default": 2,
                "help": "Number of warm-up steps to run before launching the timed runs.",
            },
        ),
        (
            "--beam_size",
            {
                "type": int,
                "default": None,
                "help": "Beam size for decoding",
            },
        ),
        (
            "--ctc_weight_decode",
            {
                "type": float,
                "default": None,
                "help": "Weight of CTC for joint CTC/Att. decoding. Only pass for models that support it (e.g. EncoderDecoderASR).",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in cli_options:
        cli.add_argument(flag, **settings)

    args = cli.parse_args()
    main(args)