from __future__ import annotations import argparse import json import os from configuration_htr import HTRConfig from modeling_htr import HTRConvTextModel from processing_htr import HTRProcessor def main() -> None: parser = argparse.ArgumentParser( description="Convert original HTR-ConvText checkpoint to Hugging Face artifacts." ) parser.add_argument("--checkpoint-path", required=True, help="Path to original .pth checkpoint.") parser.add_argument("--alphabet-path", required=True, help="Path to alphabet.json with characters.") parser.add_argument("--output-dir", required=True, help="Output directory for HF artifacts.") parser.add_argument("--image-height", type=int, default=64) parser.add_argument("--image-max-width", type=int, default=3072) parser.add_argument("--width-stride", type=int, default=32) args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) with open(args.alphabet_path, "r", encoding="utf-8") as f: alphabet_data = json.load(f) characters = alphabet_data["characters"] config = HTRConfig( vocab_size=len(characters) + 1, image_height=args.image_height, image_max_width=args.image_max_width, width_stride=args.width_stride, ) model = HTRConvTextModel.from_original_checkpoint( checkpoint_path=args.checkpoint_path, config=config, map_location="cpu", strict=True, ) model.save_pretrained(args.output_dir, safe_serialization=True) processor = HTRProcessor( characters=characters, image_height=args.image_height, image_max_width=args.image_max_width, width_stride=args.width_stride, resample="bilinear", ) processor.save_pretrained(args.output_dir) config_path = os.path.join(args.output_dir, "config.json") with open(config_path, "r", encoding="utf-8") as f: cfg = json.load(f) cfg["auto_map"] = { "AutoConfig": "configuration_htr.HTRConfig", "AutoModel": "modeling_htr.HTRConvTextModel", "AutoProcessor": "processing_htr.HTRProcessor", } with open(config_path, "w", encoding="utf-8") as f: json.dump(cfg, f, ensure_ascii=False, indent=2) print(f"HF artifacts created in: {args.output_dir}") if __name__ == "__main__": main()