bezzam HF Staff commited on
Commit
3cf412c
·
verified ·
1 Parent(s): a9dd49d

Update run_eval.py

Browse files
Files changed (1) hide show
  1. run_eval.py +201 -163
run_eval.py CHANGED
@@ -1,207 +1,224 @@
 
 
 
 
 
 
1
  import argparse
 
2
 
3
- import io
4
- import os
5
- import torch
6
  import evaluate
7
- import soundfile
8
-
9
- from tqdm import tqdm
10
  from normalizer import data_utils
11
- import numpy as np
12
-
13
- from nemo.collections.asr.models import ASRModel
14
- import time
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- wer_metric = evaluate.load("wer")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def main(args):
21
-
22
- data_cache_root = args.data_cache_root if args.data_cache_root is not None else os.getcwd()
23
- DATA_CACHE_DIR = os.path.join(data_cache_root, "audio_cache")
24
- DATASET_NAME = args.dataset
25
- SPLIT_NAME = args.split
26
-
27
- CACHE_DIR = os.path.join(DATA_CACHE_DIR, DATASET_NAME, SPLIT_NAME)
28
- if not os.path.exists(CACHE_DIR):
29
- os.makedirs(CACHE_DIR)
30
-
31
- if args.device >= 0:
32
- device = torch.device(f"cuda:{args.device}")
33
- compute_dtype=torch.bfloat16
34
  else:
35
- device = torch.device("cpu")
36
- compute_dtype=torch.float32
37
-
38
- if args.model_id.endswith(".nemo"):
39
- asr_model = ASRModel.restore_from(args.model_id, map_location=device)
40
- else:
41
- asr_model = ASRModel.from_pretrained(args.model_id, map_location=device) # type: ASRModel
42
-
43
- asr_model.to(compute_dtype)
44
- asr_model.eval()
45
- print(f"Model size: {sum(p.numel() for p in asr_model.parameters()) / 1e9:.2f}B parameters")
46
-
47
- dataset = data_utils.load_data(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- if args.max_eval_samples is not None and args.max_eval_samples > 0:
50
- print(f"Subsampling dataset to first {args.max_eval_samples} samples !")
51
- dataset = dataset.take(args.max_eval_samples)
52
 
53
- # Prepare data FIRST - this casts audio to proper format with "array" and "sampling_rate" keys
54
- dataset = data_utils.prepare_data(dataset)
 
 
55
 
56
- def download_audio_files(batch):
57
 
58
- # download audio files and write the paths, transcriptions and durations to a manifest file
59
- audio_paths = []
60
- original_audio_paths = []
61
- durations = []
62
- file_names = batch.get("file_name", [None] * len(batch["audio"]))
63
 
64
- # Use 'id' column if available, otherwise generate sequential IDs
65
- if "id" in batch:
66
- ids = batch["id"]
67
  else:
68
- # Generate IDs based on index
69
- start_idx = len([f for f in os.listdir(CACHE_DIR) if f.endswith('.wav')]) if os.path.exists(CACHE_DIR) else 0
70
- ids = [f"sample_{start_idx + i}" for i in range(len(batch["audio"]))]
71
-
72
- for id, file_name, audio_sample in zip(ids, file_names, batch["audio"]):
73
-
74
- # first step added here to make ID and wav filenames unique
75
- # several datasets like earnings22 have a hierarchical structure
76
- # for eg. earnings22/test/4432298/281.wav, earnings22/test/4450488/281.wav
77
- # lhotse uses the filename (281.wav) here as unique ID to create and name cuts
78
- # ref: https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/collation.py#L186
79
- original_id = id # preserve before sanitization for use as audio_filepath
80
- id = id.replace('/', '_').removesuffix('.wav')
81
-
82
- audio_path = os.path.join(CACHE_DIR, f"{id}.wav")
83
- audio_array = np.float32(audio_sample["array"])
84
- sample_rate = audio_sample["sampling_rate"]
85
-
86
- if not os.path.exists(audio_path):
87
- os.makedirs(os.path.dirname(audio_path), exist_ok=True)
88
- soundfile.write(audio_path, audio_array, sample_rate)
89
-
90
- audio_paths.append(audio_path)
91
- # Prefer the original file_name from the dataset; fall back to the
92
- # sample id (before path-sanitization) so audio_filepath in the
93
- # JSONL is always a meaningful identifier rather than "sample_N".
94
- if file_name is not None:
95
- original_audio_paths.append(os.path.basename(str(file_name)))
96
- else:
97
- original_audio_paths.append(original_id)
98
- durations.append(len(audio_array) / sample_rate)
99
-
100
-
101
- batch["references"] = batch["norm_text"]
102
- batch["audio_filepaths"] = audio_paths
103
- batch["original_audio_filepaths"] = original_audio_paths
104
- batch["durations"] = durations
105
 
106
- return batch
 
107
 
108
- if asr_model.cfg.decoding.strategy != "beam":
109
- asr_model.cfg.decoding.strategy = "greedy_batch"
110
- asr_model.change_decoding_strategy(asr_model.cfg.decoding)
111
-
112
- # prepraing the offline dataset
113
- dataset = dataset.map(download_audio_files, batch_size=args.batch_size, batched=True, remove_columns=["audio"])
 
 
114
 
115
- # Write manifest from daraset batch using json and keys audio_filepath, duration, text
 
 
116
 
117
- all_data = {
118
- "audio_filepaths": [],
119
- "original_audio_filepaths": [],
120
- "durations": [],
121
  "references": [],
 
122
  }
123
-
124
- data_itr = iter(dataset)
125
- for data in tqdm(data_itr, desc="Downloading Samples"):
126
- for key in all_data:
127
- all_data[key].append(data[key])
128
-
129
- # Sort audio_filepaths and references based on durations values
130
- sorted_indices = sorted(range(len(all_data["durations"])), key=lambda k: all_data["durations"][k], reverse=True)
131
- all_data["audio_filepaths"] = [all_data["audio_filepaths"][i] for i in sorted_indices]
132
- all_data["original_audio_filepaths"] = [all_data["original_audio_filepaths"][i] for i in sorted_indices]
133
- all_data["references"] = [all_data["references"][i] for i in sorted_indices]
134
- all_data["durations"] = [all_data["durations"][i] for i in sorted_indices]
135
-
136
-
137
- total_time = 0
138
- for _ in range(2): # warmup once and calculate rtf
139
- if _ == 0:
140
- audio_files = all_data["audio_filepaths"][:args.batch_size * 4] # warmup with 4 batches
141
- else:
142
- audio_files = all_data["audio_filepaths"]
143
- start_time = time.time()
144
- with torch.inference_mode(), torch.no_grad():
145
-
146
- if 'canary' in args.model_id and 'v2' not in args.model_id:
147
- pnc = 'nopnc'
148
- else:
149
- pnc = 'pnc'
150
-
151
- if 'canary' in args.model_id:
152
- transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, pnc=pnc, num_workers=1)
153
- else:
154
- transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, num_workers=1)
155
- end_time = time.time()
156
- if _ == 1:
157
- total_time += end_time - start_time
158
- total_time = total_time
159
-
160
- # normalize transcriptions with English normalizer
161
- if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
162
- transcriptions = transcriptions[0]
163
- predictions = [data_utils.normalizer(pred.text) for pred in transcriptions]
164
-
165
- avg_time = total_time / len(all_data["audio_filepaths"])
166
 
167
  # Write manifest results (WER and RTFX)
168
  manifest_path = data_utils.write_manifest(
169
- all_data["references"],
170
- predictions,
171
- args.model_id,
172
  args.dataset_path,
173
  args.dataset,
174
  args.split,
175
- audio_length=all_data["durations"],
176
- transcription_time=[avg_time] * len(all_data["audio_filepaths"]),
177
- audio_filepaths=all_data["original_audio_filepaths"],
178
  )
179
-
180
  print("Results saved at path:", os.path.abspath(manifest_path))
181
 
182
- wer = wer_metric.compute(references=all_data['references'], predictions=predictions)
 
 
 
183
  wer = round(100 * wer, 2)
184
-
185
- # transcription_time = sum(all_results["transcription_time"])
186
- audio_length = sum(all_data["durations"])
187
- rtfx = audio_length / total_time
188
- rtfx = round(rtfx, 2)
189
-
190
- print("RTFX:", rtfx)
191
- print("WER:", wer, "%")
192
 
193
 
194
  if __name__ == "__main__":
195
  parser = argparse.ArgumentParser()
196
 
197
  parser.add_argument(
198
- "--model_id", type=str, required=True, help="Model identifier. Should be loadable with NVIDIA NeMo.",
 
 
 
199
  )
 
200
  parser.add_argument(
201
- '--dataset_path', type=str, default='hf-audio/open-asr-leaderboard', help='Dataset path. By default, it is `hf-audio/open-asr-leaderboard`'
 
 
 
202
  )
 
203
  parser.add_argument(
204
- '--data_cache_root', type=str, default=None, help='Root directory for audio cache. By default, it is the current working directory.'
 
 
 
205
  )
206
  parser.add_argument(
207
  "--dataset",
@@ -223,7 +240,10 @@ if __name__ == "__main__":
223
  help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
224
  )
225
  parser.add_argument(
226
- "--batch_size", type=int, default=32, help="Number of samples to go through each streamed batch.",
 
 
 
227
  )
228
  parser.add_argument(
229
  "--max_eval_samples",
@@ -236,6 +256,24 @@ if __name__ == "__main__":
236
  action="store_true",
237
  help="Stream the dataset lazily over the network instead of downloading it in full before the evaluation. Off by default for reproducible benchmark timings.",
238
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  args = parser.parse_args()
240
 
241
  main(args)
 
1
+ """Script to evaluate a pretrained SpeechBrain model from the 🤗 Hub.
2
+
3
+ Authors
4
+ * Adel Moumen 2023 <adel.moumen@univ-avignon.fr>
5
+ * Sanchit Gandhi 2024 <sanchit@huggingface.co>
6
+ """
7
  import argparse
8
+ import time
9
 
 
 
 
10
  import evaluate
 
 
 
11
  from normalizer import data_utils
12
+ from tqdm import tqdm
13
+ import torch
14
+ import speechbrain.inference.ASR as ASR
15
+ from speechbrain.utils.data_utils import batch_pad_right
16
+ import os
17
 
18
+ def get_model(
19
+ speechbrain_repository: str,
20
+ speechbrain_pretrained_class_name: str,
21
+ beam_size: int,
22
+ ctc_weight_decode: float,
23
+ **kwargs,
24
+ ):
25
+ """Fetch a pretrained SpeechBrain model from the SpeechBrain 🤗 Hub.
26
+
27
+ Arguments
28
+ ---------
29
+ speechbrain_repository : str
30
+ The name of the SpeechBrain repository to fetch the pretrained model from. E.g. `asr-crdnn-rnnlm-librispeech`.
31
+ speechbrain_pretrained_class_name : str
32
+ The name of the SpeechBrain pretrained class to fetch. E.g. `EncoderASR`.
33
+ See: https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/pretrained/interfaces.py
34
+ beam_size : int
35
+ Size of the beam for decoding.
36
+ ctc_weight_decode : float
37
+ Weight of the CTC prob for decoding with joint CTC/Attn.
38
+ **kwargs
39
+ Additional keyword arguments to pass to override the default run options of the pretrained model.
40
+
41
+ Returns
42
+ -------
43
+ SpeechBrain pretrained model
44
+ The Pretrained model.
45
+
46
+ Example
47
+ -------
48
+ >>> from open_asr_leaderboard.speechbrain.run_eval import get_model
49
+ >>> model = get_model("asr-crdnn-rnnlm-librispeech", "EncoderASR", device="cuda:0")
50
+ """
51
+
52
+ run_opt_defaults = {
53
+ "device": "cuda",
54
+ "data_parallel_count": -1,
55
+ "data_parallel_backend": False,
56
+ "distributed_launch": False,
57
+ "distributed_backend": "nccl",
58
+ "jit_module_keys": None,
59
+ "precision": "fp16",
60
+ }
61
 
62
+ run_opts = {**run_opt_defaults}
63
+
64
+ overrides_dict = {}
65
+ if beam_size:
66
+ overrides_dict["test_beam_size"] = beam_size
67
+ if ctc_weight_decode is not None:
68
+ overrides_dict["ctc_weight_decode"] = ctc_weight_decode
69
+
70
+ # Build overrides as a YAML string so hyperpyyaml applies them during
71
+ # parsing (before class imports), preventing ImportError for missing classes.
72
+ override_lines = []
73
+ if ctc_weight_decode is not None and ctc_weight_decode == 0.0:
74
+ override_lines.append("scorer: null")
75
+ for k, v in overrides_dict.items():
76
+ override_lines.append(f"{k}: {v}")
77
+ overrides_str = "\n".join(override_lines) if override_lines else None
78
+
79
+ kwargs = {
80
+ "source": f"{speechbrain_repository}",
81
+ "savedir": f"pretrained_models/{speechbrain_repository}",
82
+ "run_opts": run_opts,
83
+ }
84
+ if overrides_str:
85
+ kwargs["overrides"] = overrides_str
86
+
87
+ try:
88
+ model_class = getattr(ASR, speechbrain_pretrained_class_name)
89
+ except AttributeError:
90
+ raise AttributeError(
91
+ f"SpeechBrain Pretrained class: {speechbrain_pretrained_class_name} not found in pretrained.py"
92
+ )
93
+
94
+ return model_class.from_hparams(**kwargs)
95
 
96
 
97
  def main(args):
98
+ """Run the evaluation script."""
99
+ if args.device == -1:
100
+ device = "cpu"
 
 
 
 
 
 
 
 
 
 
101
  else:
102
+ device = f"cuda:{args.device}"
103
+
104
+ model = get_model(
105
+ args.source,
106
+ args.speechbrain_pretrained_class_name,
107
+ args.beam_size,
108
+ args.ctc_weight_decode,
109
+ device=device
110
+ )
111
+ print(f"Model size: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters")
112
+
113
+ def benchmark(batch):
114
+ # Load audio inputs
115
+ audios = [torch.from_numpy(sample["array"]) for sample in batch["audio"]]
116
+ minibatch_size = len(audios)
117
+ sampling_rate = batch["audio"][0]["sampling_rate"]
118
+ batch["audio_length_s"] = [len(sample["array"]) / sampling_rate for sample in batch["audio"]]
119
+ batch["audio_filepath"] = data_utils.extract_audio_filepaths_from_batch(batch, minibatch_size)
120
+
121
+ audios, audio_lens = batch_pad_right(audios)
122
+ audios = audios.to(device)
123
+ audio_lens = audio_lens.to(device)
124
+
125
+ start_time = time.time()
126
+ with torch.autocast(device_type="cuda"):
127
+ predictions, _ = model.transcribe_batch(audios, audio_lens)
128
+ runtime = time.time() - start_time
129
 
130
+ batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]
 
 
131
 
132
+ # normalize transcriptions with English normalizer
133
+ batch["predictions"] = [data_utils.normalizer(pred) for pred in predictions]
134
+ batch["references"] = batch["norm_text"]
135
+ return batch
136
 
 
137
 
138
+ if args.warmup_steps is not None:
139
+ dataset = data_utils.load_data(args)
140
+ dataset = data_utils.prepare_data(dataset)
 
 
141
 
142
+ num_warmup_samples = args.warmup_steps * args.batch_size
143
+ if args.streaming:
144
+ warmup_dataset = dataset.take(num_warmup_samples)
145
  else:
146
+ warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
147
+ warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ for _ in tqdm(warmup_dataset, desc="Warming up..."):
150
+ continue
151
 
152
+ dataset = data_utils.load_data(args)
153
+ if args.max_eval_samples is not None and args.max_eval_samples > 0:
154
+ print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
155
+ if args.streaming:
156
+ dataset = dataset.take(args.max_eval_samples)
157
+ else:
158
+ dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
159
+ dataset = data_utils.prepare_data(dataset)
160
 
161
+ dataset = dataset.map(
162
+ benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
163
+ )
164
 
165
+ all_results = {
166
+ "audio_length_s": [],
167
+ "transcription_time_s": [],
168
+ "predictions": [],
169
  "references": [],
170
+ "audio_filepath": [],
171
  }
172
+ result_iter = iter(dataset)
173
+ for result in tqdm(result_iter, desc="Samples..."):
174
+ for key in all_results:
175
+ all_results[key].append(result[key])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  # Write manifest results (WER and RTFX)
178
  manifest_path = data_utils.write_manifest(
179
+ all_results["references"],
180
+ all_results["predictions"],
181
+ args.source,
182
  args.dataset_path,
183
  args.dataset,
184
  args.split,
185
+ audio_length=all_results["audio_length_s"],
186
+ transcription_time=all_results["transcription_time_s"],
187
+ audio_filepaths=all_results["audio_filepath"],
188
  )
 
189
  print("Results saved at path:", os.path.abspath(manifest_path))
190
 
191
+ wer_metric = evaluate.load("wer")
192
+ wer = wer_metric.compute(
193
+ references=all_results["references"], predictions=all_results["predictions"]
194
+ )
195
  wer = round(100 * wer, 2)
196
+ rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
197
+ print("WER:", wer, "%", "RTFx:", rtfx)
 
 
 
 
 
 
198
 
199
 
200
  if __name__ == "__main__":
201
  parser = argparse.ArgumentParser()
202
 
203
  parser.add_argument(
204
+ "--source",
205
+ type=str,
206
+ required=True,
207
+ help="SpeechBrain model repository. E.g. `asr-crdnn-rnnlm-librispeech`",
208
  )
209
+
210
  parser.add_argument(
211
+ "--speechbrain_pretrained_class_name",
212
+ type=str,
213
+ required=True,
214
+ help="SpeechBrain pretrained class name. E.g. `EncoderASR`",
215
  )
216
+
217
  parser.add_argument(
218
+ "--dataset_path",
219
+ type=str,
220
+ default="hf-audio/open-asr-leaderboard",
221
+ help="Dataset path. By default, it is `hf-audio/open-asr-leaderboard`",
222
  )
223
  parser.add_argument(
224
  "--dataset",
 
240
  help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
241
  )
242
  parser.add_argument(
243
+ "--batch_size",
244
+ type=int,
245
+ default=16,
246
+ help="Number of samples to go through each streamed batch.",
247
  )
248
  parser.add_argument(
249
  "--max_eval_samples",
 
256
  action="store_true",
257
  help="Stream the dataset lazily over the network instead of downloading it in full before the evaluation. Off by default for reproducible benchmark timings.",
258
  )
259
+ parser.add_argument(
260
+ "--warmup_steps",
261
+ type=int,
262
+ default=2,
263
+ help="Number of warm-up steps to run before launching the timed runs.",
264
+ )
265
+ parser.add_argument(
266
+ "--beam_size",
267
+ type=int,
268
+ default=None,
269
+ help="Beam size for decoding"
270
+ )
271
+ parser.add_argument(
272
+ "--ctc_weight_decode",
273
+ type=float,
274
+ default=None,
275
+ help="Weight of CTC for joint CTC/Att. decoding. Only pass for models that support it (e.g. EncoderDecoderASR)."
276
+ )
277
  args = parser.parse_args()
278
 
279
  main(args)