# Command-line script: load an HTR model from a Hugging Face artifact
# directory, export it to ONNX with dynamic batch/width axes, and copy the
# accompanying alphabet.json (when present) next to the exported model.
from __future__ import annotations

import argparse
import json
import os

import torch

from configuration_htr import HTRConfig
from modeling_htr import HTRConvTextModel


class InferenceWrapper(torch.nn.Module):
    """Expose the wrapped model as a single-tensor-in, single-tensor-out module.

    The exporter traces ``forward``; this wrapper narrows the model call to
    one positional input (``image``) and returns only the ``logits`` field of
    the model output.
    """

    def __init__(self, model: HTRConvTextModel):
        super().__init__()
        self.model = model

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run the model on ``image`` and return its ``logits`` tensor.

        ``image`` is forwarded as ``pixel_values``. The export in ``main``
        feeds a ``(1, 1, image_height, width)`` dummy tensor; other layouts
        are presumably unsupported -- TODO confirm against the model.
        """
        return self.model(pixel_values=image, return_dict=True).logits


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the export command line."""
    parser = argparse.ArgumentParser(description="Export HF model to ONNX.")
    parser.add_argument(
        "--hf-model-dir", required=True, help="Directory with HF artifacts."
    )
    parser.add_argument("--output-dir", default="onnx", help="ONNX output directory.")
    parser.add_argument(
        "--onnx-name", default="model.onnx", help="Output ONNX model filename."
    )
    parser.add_argument(
        "--dummy-width", type=int, default=3072, help="Dummy input width for export."
    )
    return parser


def _copy_alphabet(src_dir: str, dst_dir: str) -> None:
    """Re-serialize ``alphabet.json`` from ``src_dir`` into ``dst_dir`` if it exists.

    The file is parsed and dumped again (UTF-8, non-ASCII preserved,
    2-space indent) rather than copied byte-for-byte. A missing source file
    is silently skipped.
    """
    alphabet_path = os.path.join(src_dir, "alphabet.json")
    if not os.path.isfile(alphabet_path):
        return

    with open(alphabet_path, "r", encoding="utf-8") as f:
        alph = json.load(f)
    with open(os.path.join(dst_dir, "alphabet.json"), "w", encoding="utf-8") as f:
        json.dump(alph, f, ensure_ascii=False, indent=2)


def main() -> None:
    """Parse CLI arguments, export the model to ONNX, and copy the alphabet.

    Exits with an argparse usage error when ``--dummy-width`` is not a
    positive integer.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Fail early with a clear message instead of an opaque tensor/exporter
    # error caused by a zero or negative dummy width.
    if args.dummy_width <= 0:
        parser.error("--dummy-width must be a positive integer")

    os.makedirs(args.output_dir, exist_ok=True)

    # NOTE(review): trust_remote_code=True is passed through to the loader;
    # only point --hf-model-dir at directories you trust.
    model = HTRConvTextModel.from_pretrained(
        args.hf_model_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=False,
        torch_dtype=torch.float32,
    )
    model.eval()

    # The config supplies the fixed input height; width stays dynamic in the
    # exported graph, so the dummy width only shapes the example input.
    cfg = HTRConfig.from_pretrained(args.hf_model_dir, trust_remote_code=True)
    dummy = torch.randn(1, 1, cfg.image_height, args.dummy_width)

    wrapped = InferenceWrapper(model)
    onnx_path = os.path.join(args.output_dir, args.onnx_name)

    torch.onnx.export(
        wrapped,
        dummy,
        onnx_path,
        input_names=["image"],
        output_names=["logits"],
        dynamic_axes={
            "image": {0: "batch", 3: "width"},
            "logits": {0: "batch", 1: "timesteps"},
        },
        opset_version=18,
        do_constant_folding=True,
        export_params=True,
    )

    _copy_alphabet(args.hf_model_dir, args.output_dir)

    print(f"ONNX exported to: {onnx_path}")


if __name__ == "__main__":
    main()