from __future__ import annotations

import argparse
import json
from pathlib import Path

import cv2
import numpy as np
import onnxruntime as ort
import torch

from configuration_htr import HTRConfig
from modeling_htr import HTRConvTextModel
from processing_htr import HTRProcessor


def preprocess_cv2(img_path: str, max_w: int, max_h: int, stride: int) -> np.ndarray:
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Could not load image: {img_path}")

    h, w = img.shape
    scale = max_h / h
    new_w = max(1, int(w * scale))
    img = cv2.resize(img, (new_w, max_h), interpolation=cv2.INTER_AREA)

    if new_w > max_w:
        img = img[:, :max_w]
        new_w = max_w

    if new_w % stride != 0:
        aligned_w = ((new_w // stride) + 1) * stride
        pad_width = aligned_w - new_w
        img = np.pad(
            img, ((0, 0), (0, pad_width)), mode="constant", constant_values=0.0
        )

    arr = img.astype(np.float32) / 255.0
    arr = np.expand_dims(arr, axis=-1)
    return arr.transpose(2, 0, 1).astype(np.float32)


def cer(pred: str, gt: str) -> float:
    import editdistance

    if len(gt) == 0:
        return 0.0 if len(pred) == 0 else 1.0
    return editdistance.eval(pred, gt) / len(gt)


def wer(pred: str, gt: str) -> float:
    import editdistance

    pred_words = pred.split()
    gt_words = gt.split()
    if len(gt_words) == 0:
        return 0.0 if len(pred_words) == 0 else 1.0
    return editdistance.eval(pred_words, gt_words) / len(gt_words)


def build_image_list(images_dir: Path) -> list[Path]:
    paths: list[Path] = []
    for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"):
        paths.extend(images_dir.rglob(f"*{ext}"))
    return sorted(paths)


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Validate PIL migration and ONNX parity."
    )
    parser.add_argument(
        "--images-dir", required=True, help="Directory with evaluation images."
    )
    parser.add_argument(
        "--checkpoint-path", required=True, help="Original .pth checkpoint path."
    )
    parser.add_argument(
        "--onnx-path", default=None, help="Optional ONNX model path for parity."
    )
    parser.add_argument("--alphabet-path", required=True, help="Path to alphabet.json.")
    parser.add_argument("--max-width", type=int, default=3072)
    parser.add_argument("--image-height", type=int, default=64)
    parser.add_argument("--stride", type=int, default=32)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--max-images", type=int, default=200)
    parser.add_argument("--pixel-mean-diff-thr", type=float, default=0.02)
    parser.add_argument("--logit-cos-thr", type=float, default=0.995)
    parser.add_argument("--decode-match-thr", type=float, default=0.99)
    parser.add_argument("--cer-delta-thr", type=float, default=0.005)
    parser.add_argument("--wer-delta-thr", type=float, default=0.01)
    args = parser.parse_args()

    with open(args.alphabet_path, "r", encoding="utf-8") as f:
        alph = json.load(f)
    characters = alph["characters"]

    cfg = HTRConfig(
        vocab_size=len(characters) + 1,
        image_height=args.image_height,
        image_max_width=args.max_width,
        width_stride=args.stride,
    )
    model = HTRConvTextModel.from_original_checkpoint(
        checkpoint_path=args.checkpoint_path,
        config=cfg,
        map_location=args.device,
        strict=True,
    ).to(args.device)
    model.eval()

    processor = HTRProcessor(
        characters=characters,
        image_height=args.image_height,
        image_max_width=args.max_width,
        width_stride=args.stride,
        resample="bilinear",
    )

    image_paths = build_image_list(Path(args.images_dir))[: args.max_images]
    if not image_paths:
        raise ValueError("No images found.")

    pix_diffs: list[float] = []
    cos_sims: list[float] = []
    decode_matches = 0
    total = 0

    cer_cv_total = 0.0
    cer_pil_total = 0.0
    wer_cv_total = 0.0
    wer_pil_total = 0.0
    n_gt = 0

    ort_sess = None
    if args.onnx_path:
        ort_sess = ort.InferenceSession(
            args.onnx_path, providers=["CPUExecutionProvider"]
        )
    onnx_cos_sims: list[float] = []

    for path in image_paths:
        cv_arr = preprocess_cv2(
            str(path), args.max_width, args.image_height, args.stride
        )
        pil_arr = processor(str(path), return_tensors="np")["pixel_values"][0]
        pix_diffs.append(float(np.mean(np.abs(cv_arr - pil_arr))))

        cv_tensor = torch.from_numpy(cv_arr[None, ...]).to(args.device)
        pil_tensor = torch.from_numpy(pil_arr[None, ...]).to(args.device)
        with torch.no_grad():
            logits_cv = model(pixel_values=cv_tensor).logits.detach().cpu().numpy()
            logits_pil = model(pixel_values=pil_tensor).logits.detach().cpu().numpy()

        a = logits_cv.reshape(-1)
        b = logits_pil.reshape(-1)
        denom = (np.linalg.norm(a) * np.linalg.norm(b)) + 1e-8
        cos_sims.append(float(np.dot(a, b) / denom))

        pred_cv = processor.batch_decode(logits_cv, logit_layout="ntc")[0]
        pred_pil = processor.batch_decode(logits_pil, logit_layout="ntc")[0]
        decode_matches += int(pred_cv == pred_pil)
        total += 1
        print("Pred CV:", pred_cv)
        print("Pred PIL:", pred_pil)

        gt_path = path.with_suffix(".txt")
        if gt_path.exists():
            gt = gt_path.read_text(encoding="utf-8").strip().replace("\n", " ")
            cer_cv_total += cer(pred_cv, gt)
            cer_pil_total += cer(pred_pil, gt)
            wer_cv_total += wer(pred_cv, gt)
            wer_pil_total += wer(pred_pil, gt)
            n_gt += 1

        if ort_sess is not None:
            logits_onnx = ort_sess.run(None, {"image": pil_arr[None, ...]})[0]
            c = logits_onnx.reshape(-1)
            denom_po = (np.linalg.norm(b) * np.linalg.norm(c)) + 1e-8
            onnx_cos_sims.append(float(np.dot(b, c) / denom_po))

    decode_match_rate = decode_matches / max(total, 1)
    mean_pix_diff = float(np.mean(pix_diffs))
    mean_cos = float(np.mean(cos_sims))

    report: dict[str, float | int | bool] = {
        "n_images": total,
        "mean_pixel_abs_diff": mean_pix_diff,
        "mean_logit_cosine_similarity_cv2_vs_pil": mean_cos,
        "decode_exact_match_rate_cv2_vs_pil": decode_match_rate,
        "pixel_diff_ok": mean_pix_diff <= args.pixel_mean_diff_thr,
        "logit_cos_ok": mean_cos >= args.logit_cos_thr,
        "decode_match_ok": decode_match_rate >= args.decode_match_thr,
    }

    if n_gt > 0:
        cer_cv = cer_cv_total / n_gt
        cer_pil = cer_pil_total / n_gt
        wer_cv = wer_cv_total / n_gt
        wer_pil = wer_pil_total / n_gt
        report.update(
            {
                "n_with_ground_truth": n_gt,
                "cer_cv2": cer_cv,
                "cer_pil": cer_pil,
                "cer_delta_abs": abs(cer_pil - cer_cv),
                "wer_cv2": wer_cv,
                "wer_pil": wer_pil,
                "wer_delta_abs": abs(wer_pil - wer_cv),
                "cer_delta_ok": abs(cer_pil - cer_cv) <= args.cer_delta_thr,
                "wer_delta_ok": abs(wer_pil - wer_cv) <= args.wer_delta_thr,
            }
        )

    if ort_sess is not None:
        report["mean_logit_cosine_similarity_pytorch_vs_onnx"] = float(
            np.mean(onnx_cos_sims)
        )

    print(json.dumps(report, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()