| """ |
| Phase 3 - Universal Checkpoint Loader |
| Supports: .pth / .pt (PyTorch) and .safetensors formats |
| """ |
|
|
| import os |
| import json |
| from pathlib import Path |
| from typing import Optional, Tuple |
| import torch |
| from huggingface_hub import hf_hub_download, list_repo_files |
|
|
|
|
def detect_checkpoint_format(checkpoint_path: str) -> str:
    """
    Detect the format of a single checkpoint file from its extension.

    The extension comparison is case-insensitive, so 'model.PTH' and
    'model.SafeTensors' are recognised as well.

    Args:
        checkpoint_path: Path (str or os.PathLike) of the checkpoint file.

    Returns:
        'safetensors' for .safetensors, 'pytorch' for .pth / .pt,
        otherwise 'unknown'.
    """
    suffix = Path(checkpoint_path).suffix.lower()
    if suffix == '.safetensors':
        return 'safetensors'
    if suffix in ('.pth', '.pt'):
        return 'pytorch'
    return 'unknown'
|
|
|
|
def load_state_dict_from_file(file_path: str, device: torch.device) -> dict:
    """
    Load a state dict from either a .safetensors or a .pth/.pt file.

    Args:
        file_path: Checkpoint file; the format is chosen from its extension
            via detect_checkpoint_format.
        device: Device the tensors are mapped onto while loading.

    Returns:
        The raw object stored in the file. For .safetensors this is the dict
        returned by safetensors.torch.load_file; for .pth/.pt it is whatever
        was passed to torch.save (possibly a wrapper dict, e.g. one holding
        'model_state_dict').

    Raises:
        ImportError: a .safetensors file was given but the `safetensors`
            package is not installed.
        ValueError: the extension is not a recognised checkpoint format.
    """
    fmt = detect_checkpoint_format(file_path)

    if fmt == 'safetensors':
        # Only the import is guarded: an ImportError raised *inside* load_file
        # must not be misreported as "safetensors is not installed".
        try:
            import safetensors.torch
        except ImportError as err:
            raise ImportError("Install safetensors: pip install safetensors") from err
        return safetensors.torch.load_file(file_path, device=str(device))

    if fmt == 'pytorch':
        # SECURITY: weights_only=False unpickles arbitrary Python objects, which
        # can execute code. Only load checkpoints from a trusted source.
        return torch.load(file_path, map_location=device, weights_only=False)

    raise ValueError(f"Unrecognised file format: {file_path}")
|
|
|
|
def load_model_weights(model, checkpoint_path: str, device: torch.device) -> None:
    """
    Restore only the model weights from one .pth/.pt or .safetensors file.

    Accepts either a bare state dict or a wrapper dict (as written by train.py)
    that keeps the weights under 'model_state_dict'. Afterwards the model is
    moved to `device` and put in eval mode.

    Args:
        model: Module whose weights are replaced in place.
        checkpoint_path: Path of the checkpoint file.
        device: Device used while loading and as the model's final device.
    """
    loaded = load_state_dict_from_file(checkpoint_path, device)

    # Unwrap train.py-style checkpoints; anything else is used as the state dict.
    is_wrapped = isinstance(loaded, dict) and 'model_state_dict' in loaded
    weights = loaded['model_state_dict'] if is_wrapped else loaded

    model.load_state_dict(weights)
    model.to(device)
    model.eval()
|
|
|
|
def load_full_checkpoint_dir(
    model,
    trainer,
    checkpoint_dir: str,
    device: torch.device
) -> int:
    """
    Load a full training checkpoint from a directory.

    Directory layout (new SafeTensors format from train_fixed.py):
        checkpoint_dir/
            model.safetensors      (or model.pt)      - required
            optimizer.safetensors  (or optimizer.pt)  - optional
            scheduler.safetensors  (or scheduler.pt)  - optional
            metadata.json                             - must contain 'epoch'

    Also handles legacy single-file .pth layout:
        checkpoint_epoch_N.pth (contains model + optimizer + scheduler)
    If several legacy files exist, the one with the largest numeric N is used
    (numeric ordering, so epoch 10 is preferred over epoch 9).

    Args:
        model: Module whose weights are restored in place and moved to `device`.
        trainer: Object providing `optimizer` and `scheduler` (restored in place
            when their state is present) and `best_val_loss` (overwritten).
        checkpoint_dir: Directory containing the checkpoint.
        device: Device tensors are mapped onto while loading.

    Returns:
        next_epoch (int): epoch to resume from (saved epoch + 1)

    Raises:
        FileNotFoundError: neither layout is present, or metadata.json exists
            but there is no model.safetensors / model.pt next to it.
    """
    checkpoint_dir = Path(checkpoint_dir)

    # The new layout wins whenever metadata.json is present.
    if (checkpoint_dir / 'metadata.json').exists():
        return _resume_from_component_dir(model, trainer, checkpoint_dir, device)

    legacy_files = list(checkpoint_dir.glob('checkpoint_epoch_*.pth'))
    if legacy_files:
        # Lexicographic sorting would rank 'checkpoint_epoch_9' above
        # 'checkpoint_epoch_10'; pick the highest epoch numerically instead.
        latest = max(legacy_files, key=_legacy_epoch_number)
        return _resume_from_legacy_file(model, trainer, latest, device)

    raise FileNotFoundError(f"No valid checkpoint found in: {checkpoint_dir}")


def _legacy_epoch_number(pth_file: Path) -> int:
    """Return N parsed from 'checkpoint_epoch_N.pth', or -1 when N is not an integer."""
    tail = pth_file.stem[len('checkpoint_epoch_'):]
    return int(tail) if tail.isdecimal() else -1


def _find_component_file(checkpoint_dir: Path, name: str) -> Optional[Path]:
    """Return <name>.safetensors, else <name>.pt, in checkpoint_dir; None if neither exists."""
    for ext in ('safetensors', 'pt'):
        candidate = checkpoint_dir / f'{name}.{ext}'
        if candidate.exists():
            return candidate
    return None


def _resume_from_component_dir(model, trainer, checkpoint_dir: Path, device: torch.device) -> int:
    """Restore the per-component (metadata.json) layout and return the next epoch."""
    with open(checkpoint_dir / 'metadata.json') as f:
        metadata = json.load(f)

    model_file = _find_component_file(checkpoint_dir, 'model')
    if model_file is None:
        # Without this check training would "resume" from untrained weights.
        raise FileNotFoundError(
            f"metadata.json found but no model.safetensors / model.pt in: {checkpoint_dir}"
        )
    model.load_state_dict(load_state_dict_from_file(str(model_file), device))
    model.to(device)

    # Optimizer / scheduler state is optional: a missing file is skipped.
    # NOTE(review): a .safetensors file holds a flat name -> tensor mapping, while
    # torch optimizer/scheduler state dicts are nested; confirm train_fixed.py
    # writes these files in a form that load_state_dict accepts.
    for name, component in (('optimizer', trainer.optimizer), ('scheduler', trainer.scheduler)):
        component_file = _find_component_file(checkpoint_dir, name)
        if component_file is not None:
            component.load_state_dict(load_state_dict_from_file(str(component_file), device))

    trainer.best_val_loss = metadata.get('best_val_loss', float('inf'))
    last_epoch = metadata['epoch']
    next_epoch = last_epoch + 1

    metrics = metadata.get('metrics', {})
    print(f"Checkpoint loaded (epoch {last_epoch})")
    print(f" Val Loss : {metrics.get('loss', 'N/A')}")
    print(f" Disc Dice: {metrics.get('disc_dice', 'N/A')}")
    print(f" Cup Dice : {metrics.get('cup_dice', 'N/A')}")
    print(f"Resuming from epoch {next_epoch}")
    return next_epoch


def _resume_from_legacy_file(model, trainer, pth_file: Path, device: torch.device) -> int:
    """Restore a single-file checkpoint_epoch_N.pth and return the next epoch."""
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Only load checkpoints from a trusted source.
    ckpt = torch.load(str(pth_file), map_location=device, weights_only=False)

    model.load_state_dict(ckpt['model_state_dict'])
    model.to(device)

    if 'optimizer_state_dict' in ckpt:
        trainer.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    if 'scheduler_state_dict' in ckpt:
        trainer.scheduler.load_state_dict(ckpt['scheduler_state_dict'])

    # Older checkpoints may lack 'best_val_loss'; fall back to the saved val loss.
    trainer.best_val_loss = ckpt.get(
        'best_val_loss', ckpt.get('metrics', {}).get('loss', float('inf'))
    )

    last_epoch = ckpt['epoch']
    next_epoch = last_epoch + 1
    print(f"Legacy checkpoint loaded (epoch {last_epoch})")
    print(f"Resuming from epoch {next_epoch}")
    return next_epoch
|
|
|
|
| |
| |
| |
|
|
def list_hf_checkpoints(repo_id: str, token: Optional[str] = None) -> list:
    """
    List the checkpoints stored under `checkpoints/` in a HuggingFace repo.

    Works for both:
        checkpoints/epoch_050/               (new format, one folder per epoch)
        checkpoints/checkpoint_epoch_50.pth  (legacy single-file format)

    Args:
        repo_id: HF repo ID (e.g. 'Nj-1111/EyeeSEE').
        token: HF token; falls back to $HF_TOKEN_2, then $HF_TOKEN.

    Returns:
        List of (epoch, path) tuples sorted by epoch ascending. `path` is the
        repo-relative folder (no trailing slash) for the new format, or the
        repo-relative .pth file for the legacy format. When an epoch exists in
        both formats, the folder is returned.
    """
    import re
    token = token or os.getenv('HF_TOKEN_2') or os.getenv('HF_TOKEN')

    files = list_repo_files(repo_id=repo_id, token=token)

    folder_pattern = re.compile(r'checkpoints/epoch_(\d+)/')
    legacy_pattern = re.compile(r'checkpoints/checkpoint_epoch_(\d+)\.pth$')

    folders = {}
    legacy = {}
    for f in files:
        m = folder_pattern.search(f)
        if m:
            # Keep the folder path exactly as it appears in the repo (minus the
            # trailing '/'); rebuilding it as epoch_{ep:03d} breaks for
            # unpadded or nested folder names and then no files match on download.
            folders[int(m.group(1))] = f[:m.end() - 1]
            continue

        m = legacy_pattern.search(f)
        if m:
            legacy[int(m.group(1))] = f

    # Folder checkpoints override legacy files for the same epoch, independent
    # of the order in which the repo lists its files.
    epoch_map = {**legacy, **folders}
    return sorted(epoch_map.items())
|
|
|
|
def download_checkpoint_for_inference(
    repo_id: str,
    epoch: Optional[int] = None,
    token: Optional[str] = None,
    local_dir: str = '/kaggle/working/ckpt_inference'
) -> Tuple[str, str]:
    """
    Download a checkpoint from HuggingFace for inference.

    Args:
        repo_id: HF repo ID
        epoch: Specific epoch to download. None → downloads latest.
        token: HF token; falls back to $HF_TOKEN_2, then $HF_TOKEN.
        local_dir: Root download directory; each epoch gets its own
            'epoch_NNN' sub-directory underneath it.

    Returns:
        (local_path, fmt) where fmt is 'safetensors_dir' or 'pytorch_file'.
        For 'safetensors_dir', local_path is the directory that directly
        contains the downloaded checkpoint files; for 'pytorch_file' it is the
        path of the downloaded .pth file.

    Raises:
        FileNotFoundError: the repo has no checkpoints, or the selected
            checkpoint folder contains no files.
        ValueError: `epoch` was given but is not available in the repo.
    """
    token = token or os.getenv('HF_TOKEN_2') or os.getenv('HF_TOKEN')
    checkpoints = list_hf_checkpoints(repo_id, token)

    if not checkpoints:
        raise FileNotFoundError(f"No checkpoints in {repo_id}")

    if epoch is None:
        ep, ckpt_ref = checkpoints[-1]
    else:
        matched = [(e, p) for e, p in checkpoints if e == epoch]
        if not matched:
            raise ValueError(f"Epoch {epoch} not found. Available: {[e for e, _ in checkpoints]}")
        ep, ckpt_ref = matched[0]

    target_root = Path(local_dir) / f'epoch_{ep:03d}'
    target_root.mkdir(parents=True, exist_ok=True)

    # Legacy single-file checkpoint.
    if ckpt_ref.endswith('.pth'):
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=ckpt_ref,
            token=token,
            local_dir=str(target_root)
        )
        print(f"Downloaded checkpoint (epoch {ep}) → {local_path}")
        return local_path, 'pytorch_file'

    # Folder checkpoint. Match on '<folder>/' so that e.g. 'epoch_100' does not
    # also pull in files from 'epoch_1000/'.
    folder = ckpt_ref.rstrip('/')
    prefix = folder + '/'
    ckpt_files = [
        f for f in list_repo_files(repo_id=repo_id, token=token)
        if f.startswith(prefix)
    ]
    if not ckpt_files:
        raise FileNotFoundError(f"No files under {prefix} in {repo_id}")

    for hf_filename in ckpt_files:
        hf_hub_download(
            repo_id=repo_id,
            filename=hf_filename,
            token=token,
            local_dir=str(target_root)
        )

    # hf_hub_download mirrors the repo path under local_dir, so the files land
    # in <target_root>/<folder>/, not directly in <target_root>.
    ckpt_dir = target_root / folder
    print(f"Downloaded checkpoint (epoch {ep}) → {ckpt_dir}")
    return str(ckpt_dir), 'safetensors_dir'
|
|
|
|
def load_model_for_inference(model, repo_id, epoch=None, device=None, token=None):
    """
    Download a single-file checkpoint from HF and load its weights into `model`.

    Only repo files matching 'checkpoints/checkpoint_epoch_N.pt' or '.pth' are
    considered; folder-style (safetensors) checkpoints are not handled here.

    Args:
        model: Module whose weights are replaced in place.
        repo_id: HF repo ID.
        epoch: Epoch N to load; None loads the highest available epoch.
        device: Target device; defaults to CUDA when available, else CPU.
        token: HF token; falls back to $HF_TOKEN_2, then $HF_TOKEN.

    Returns:
        The same `model`, moved to `device` and switched to eval mode.

    Raises:
        FileNotFoundError: no matching checkpoint files exist in the repo.
        ValueError: the requested epoch is not available.
    """
    import re

    token = token or os.getenv('HF_TOKEN_2') or os.getenv('HF_TOKEN')
    device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    pattern = re.compile(r'checkpoints/checkpoint_epoch_(\d+)\.pth?$')
    epochs = {}
    for f in list_repo_files(repo_id=repo_id, token=token):
        m = pattern.search(f)
        if m:
            epochs[int(m.group(1))] = f

    if not epochs:
        raise FileNotFoundError(f"No checkpoints found in {repo_id}")

    target_epoch = epoch if epoch is not None else max(epochs)
    if target_epoch not in epochs:
        raise ValueError(f"Epoch {target_epoch} not found. Available: {sorted(epochs)}")

    ckpt_file = hf_hub_download(
        repo_id=repo_id,
        filename=epochs[target_epoch],
        token=token
    )

    # SECURITY: weights_only=False unpickles arbitrary Python objects from a
    # downloaded file, which can execute code. Only use with trusted repos.
    ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)

    # Accept train.py-style wrapped checkpoints as well as bare state dicts,
    # matching load_model_weights.
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        state_dict = ckpt['model_state_dict']
    else:
        state_dict = ckpt

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print(f"Model loaded from epoch {target_epoch} on {device}")
    return model
|
|