"""Hugging Face ``PreTrainedModel`` wrapper around the ``HTR_ConvText`` network."""

from __future__ import annotations

import logging
import os
import re
import sys
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional

# Ensure repo root (hub cache snapshot) is on path so configuration_htr and
# htr_convtext resolve.
_repo_dir = os.path.dirname(os.path.abspath(__file__))
if _repo_dir not in sys.path:
    sys.path.insert(0, _repo_dir)

import torch
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_htr import HTRConfig
from .htr_convtext import HTR_ConvText

logger = logging.getLogger(__name__)

# A single leading "module." prefix on state-dict keys (typically added when
# the network was wrapped, e.g. by DataParallel/DDP, during training).
_MODULE_PREFIX = re.compile(r"^module\.")


@dataclass
class HTRModelOutput(ModelOutput):
    """Output container returned by :meth:`HTRConvTextModel.forward`.

    Attributes:
        logits: Tensor returned by the wrapped ``HTR_ConvText`` module; its
            shape is defined by that module (TODO confirm layout).
    """

    logits: torch.FloatTensor


class HTRConvTextModel(PreTrainedModel):
    """Expose ``HTR_ConvText`` through the ``transformers`` model API.

    The wrapped network is stored as ``self.model`` (matching
    ``base_model_prefix``) and is built entirely from an :class:`HTRConfig`.
    """

    config_class = HTRConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"

    def __init__(self, config: HTRConfig) -> None:
        super().__init__(config)
        self.model = HTR_ConvText(
            nb_cls=config.vocab_size,
            # NOTE(review): width is passed before height — confirm this is
            # the order HTR_ConvText expects for img_size.
            img_size=[config.image_max_width, config.image_height],
            patch_size=config.patch_size,
            embed_dim=config.embed_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            mlp_ratio=config.mlp_ratio,
            conv_kernel_size=config.conv_kernel_size,
            dropout=config.dropout,
            drop_path=config.drop_path,
            down_after=config.down_after,
            up_after=config.up_after,
            ds_kernel=config.ds_kernel,
            max_seq_len=config.max_seq_len,
            upsample_mode=config.upsample_mode,
        )
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        use_masking: Optional[bool] = None,
        return_dict: bool = True,
        **_: object,
    ) -> HTRModelOutput | tuple[torch.FloatTensor]:
        """Run the wrapped network on ``pixel_values``.

        Args:
            pixel_values: Input batch, forwarded unchanged to ``HTR_ConvText``.
            use_masking: Forwarded to ``HTR_ConvText`` as ``use_masking``.
                ``None`` falls back to ``config.use_masking_default``; any
                other value is coerced with ``bool()``.
            return_dict: When ``False``, return the 1-tuple ``(logits,)``
                instead of an :class:`HTRModelOutput`.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            ``HTRModelOutput(logits=...)`` or ``(logits,)``.
        """
        logits = self.model(
            pixel_values,
            use_masking=(
                self.config.use_masking_default
                if use_masking is None
                else bool(use_masking)
            ),
        )
        if not return_dict:
            return (logits,)
        return HTRModelOutput(logits=logits)

    @staticmethod
    def _extract_state_dict(checkpoint: dict) -> OrderedDict:
        """Pull the network weights out of an original training checkpoint.

        The ``state_dict_ema`` entry (presumably EMA weights) is preferred
        when present; otherwise the ``model`` entry is used. One leading
        ``module.`` prefix is stripped from every key.

        Raises:
            KeyError: If neither ``state_dict_ema`` nor ``model`` is present.
        """
        if "state_dict_ema" in checkpoint:
            source_key = "state_dict_ema"
        elif "model" in checkpoint:
            source_key = "model"
        else:
            raise KeyError(
                "Checkpoint must contain either 'state_dict_ema' or "
                "'model' keys."
            )
        return OrderedDict(
            (_MODULE_PREFIX.sub("", key), value)
            for key, value in checkpoint[source_key].items()
        )

    @classmethod
    def from_original_checkpoint(
        cls,
        checkpoint_path: str,
        config: HTRConfig,
        map_location: str = "cpu",
        strict: bool = True,
    ) -> "HTRConvTextModel":
        """Build a model from a checkpoint saved by the original training code.

        Args:
            checkpoint_path: Path passed to ``torch.load``.
            config: Configuration used to construct the model.
            map_location: Forwarded to ``torch.load``.
            strict: When ``True``, any missing or unexpected key raises
                ``RuntimeError``. When ``False``, the matching weights are
                loaded and mismatched keys are reported via a logged warning.

        Returns:
            A newly constructed model with the checkpoint weights loaded into
            ``model.model``. No ``.eval()`` call is made here.

        Raises:
            KeyError: If the checkpoint has no ``state_dict_ema``/``model`` entry.
            RuntimeError: On key mismatches when ``strict`` is set, or on
                tensor size mismatches (raised by ``load_state_dict`` itself).
        """
        # SECURITY: weights_only=False unpickles arbitrary objects, which can
        # execute code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(
            checkpoint_path, map_location=map_location, weights_only=False
        )
        state_dict = cls._extract_state_dict(checkpoint)
        model = cls(config)
        # Load non-strictly and inspect the result here. With strict=True,
        # torch would raise before this report could run; with strict=False
        # the mismatches would otherwise be dropped silently.
        missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            if strict:
                raise RuntimeError(
                    f"Failed strict checkpoint load. Missing keys: {missing}; "
                    f"Unexpected keys: {unexpected}"
                )
            logger.warning(
                "Non-strict checkpoint load. Missing keys: %s; "
                "Unexpected keys: %s",
                missing,
                unexpected,
            )
        return model