EyeeSEE / checkpoint_loader.py
Nj-1111's picture
Upload 9 files
f9b628d verified
Raw
History Blame
9.74 kB
"""
Phase 3 - Universal Checkpoint Loader
Supports: .pth / .pt (PyTorch) and .safetensors formats
"""
import os
import json
from pathlib import Path
from typing import Optional, Tuple
import torch
from huggingface_hub import hf_hub_download, list_repo_files
def detect_checkpoint_format(checkpoint_path: str) -> str:
    """
    Detect the format of a single checkpoint file from its extension.

    The extension is compared case-insensitively, so ``MODEL.PTH`` and
    ``model.pth`` are classified the same way.

    Args:
        checkpoint_path: Path (or bare file name) of the checkpoint file.
            Only the final suffix is inspected; the file is not opened.

    Returns:
        'safetensors' for ``.safetensors``, 'pytorch' for ``.pth`` / ``.pt``,
        and 'unknown' for anything else (including names with no extension).
    """
    suffix = Path(checkpoint_path).suffix.lower()
    if suffix == '.safetensors':
        return 'safetensors'
    if suffix in ('.pth', '.pt'):
        return 'pytorch'
    return 'unknown'
def load_state_dict_from_file(file_path: str, device: torch.device) -> dict:
    """
    Load the contents of a .safetensors or .pth/.pt checkpoint file.

    The format is chosen by ``detect_checkpoint_format`` from the file
    extension.

    Args:
        file_path: Path to the checkpoint file.
        device: Device the loaded tensors are mapped onto.

    Returns:
        The object stored in the file, unmodified. For .pth/.pt files this is
        whatever ``torch.load`` returns, which may be a raw state dict or a
        wrapped checkpoint dict.

    Raises:
        ImportError: The file is .safetensors but the ``safetensors`` package
            is not installed.
        ValueError: The file extension is not a recognised checkpoint format.
    """
    fmt = detect_checkpoint_format(file_path)
    if fmt == 'safetensors':
        # Only the import sits inside the try: an ImportError raised while
        # actually loading the file must not be reported as "not installed".
        try:
            import safetensors.torch
        except ImportError as err:
            raise ImportError("Install safetensors: pip install safetensors") from err
        return safetensors.torch.load_file(file_path, device=str(device))
    if fmt == 'pytorch':
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects. Only load checkpoint files from trusted sources.
        return torch.load(file_path, map_location=device, weights_only=False)
    raise ValueError(f"Unrecognised file format: {file_path}")
def load_model_weights(model, checkpoint_path: str, device: torch.device) -> None:
    """
    Populate ``model`` with the weights stored in one checkpoint file.

    The file may be .pth/.pt or .safetensors. Two payload shapes are accepted:
    a bare state dict, or a wrapped checkpoint dict that keeps the weights
    under the ``'model_state_dict'`` key (the layout the original docstring
    attributes to train.py).

    After the weights are loaded the model is moved to ``device`` and switched
    to eval mode. Nothing is returned; ``model`` is modified in place.
    """
    loaded = load_state_dict_from_file(checkpoint_path, device)

    # Unwrap training checkpoints; pass bare state dicts through unchanged.
    is_wrapped = isinstance(loaded, dict) and 'model_state_dict' in loaded
    weights = loaded['model_state_dict'] if is_wrapped else loaded

    model.load_state_dict(weights)
    model.to(device)
    model.eval()
def load_full_checkpoint_dir(
    model,
    trainer,
    checkpoint_dir: str,
    device: torch.device
) -> int:
    """
    Load a full training checkpoint (model + optimizer + scheduler) from a
    directory and report which epoch to resume from.

    Directory layout (new SafeTensors format from train_fixed.py):
        checkpoint_dir/
            model.safetensors
            optimizer.safetensors
            scheduler.safetensors
            metadata.json

    Also handles the legacy single-file .pth layout:
        checkpoint_epoch_N.pth (contains model + optimizer + scheduler)

    The new layout is used whenever ``metadata.json`` exists. Otherwise the
    legacy file with the highest epoch number N is loaded.

    Args:
        model: Model whose ``load_state_dict`` receives the weights; it is
            moved to ``device`` afterwards.
        trainer: Object exposing ``optimizer``, ``scheduler`` and
            ``best_val_loss``; the first two are restored via
            ``load_state_dict`` and the last is overwritten.
        checkpoint_dir: Directory containing the checkpoint files.
        device: Device tensors are mapped onto while loading.

    Returns:
        next_epoch (int): epoch to resume from (stored epoch + 1).

    Raises:
        FileNotFoundError: Neither layout is present in ``checkpoint_dir``.
    """
    import re

    checkpoint_dir = Path(checkpoint_dir)

    # ── New multi-file SafeTensors layout ──
    metadata_path = checkpoint_dir / 'metadata.json'
    if metadata_path.exists():
        with open(metadata_path, encoding='utf-8') as f:
            metadata = json.load(f)

        targets = (
            ('model', model),
            ('optimizer', trainer.optimizer),
            ('scheduler', trainer.scheduler),
        )
        for key, target in targets:
            # First matching extension wins. A component with no file on
            # disk is left untouched rather than treated as an error.
            for ext in ('safetensors', 'pt'):
                fp = checkpoint_dir / f'{key}.{ext}'
                if fp.exists():
                    state = load_state_dict_from_file(str(fp), device)
                    target.load_state_dict(state)
                    if key == 'model':
                        model.to(device)
                    break

        trainer.best_val_loss = metadata.get('best_val_loss', float('inf'))
        last_epoch = metadata['epoch']
        next_epoch = last_epoch + 1
        metrics = metadata.get('metrics', {})
        print(f"Checkpoint loaded (epoch {last_epoch})")
        print(f" Val Loss : {metrics.get('loss', 'N/A')}")
        print(f" Disc Dice: {metrics.get('disc_dice', 'N/A')}")
        print(f" Cup Dice : {metrics.get('cup_dice', 'N/A')}")
        print(f"Resuming from epoch {next_epoch}")
        return next_epoch

    # ── Legacy single-file .pth layout ──
    pth_files = list(checkpoint_dir.glob('checkpoint_epoch_*.pth'))
    if pth_files:
        # Choose by numeric epoch: a plain lexicographic sort would rank
        # checkpoint_epoch_9.pth above checkpoint_epoch_10.pth.
        numbered = []
        for path in pth_files:
            m = re.fullmatch(r'checkpoint_epoch_(\d+)\.pth', path.name)
            if m:
                numbered.append((int(m.group(1)), path))
        if numbered:
            pth_file = max(numbered, key=lambda item: item[0])[1]
        else:
            # No file carries a numeric epoch; keep the original ordering.
            pth_file = sorted(pth_files)[-1]

        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from trusted sources.
        ckpt = torch.load(str(pth_file), map_location=device, weights_only=False)
        model.load_state_dict(ckpt['model_state_dict'])
        model.to(device)
        if 'optimizer_state_dict' in ckpt:
            trainer.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        if 'scheduler_state_dict' in ckpt:
            trainer.scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        # Fall back to the recorded validation loss when no explicit best
        # value was saved with the checkpoint.
        trainer.best_val_loss = ckpt.get(
            'best_val_loss', ckpt.get('metrics', {}).get('loss', float('inf'))
        )
        last_epoch = ckpt['epoch']
        next_epoch = last_epoch + 1
        print(f"Legacy checkpoint loaded (epoch {last_epoch})")
        print(f"Resuming from epoch {next_epoch}")
        return next_epoch

    raise FileNotFoundError(f"No valid checkpoint found in: {checkpoint_dir}")
# ─────────────────────────────────────────────────────────────────────────────
# HuggingFace helpers
# ─────────────────────────────────────────────────────────────────────────────
def list_hf_checkpoints(repo_id: str, token: Optional[str] = None) -> list:
    """
    List the checkpoints stored under ``checkpoints/`` in a HuggingFace repo.

    Works for both:
        checkpoints/epoch_050/...             (new directory format)
        checkpoints/checkpoint_epoch_50.pth   (legacy single-file format)

    Args:
        repo_id: HF repo ID to inspect.
        token: HF token. When omitted, the HF_TOKEN_2 and then HF_TOKEN
            environment variables are tried.

    Returns:
        A list of ``(epoch, repo_path)`` tuples sorted by epoch (ascending).
        ``repo_path`` is the checkpoint folder (no trailing slash) for the
        directory format, or the file path for the legacy format. If an epoch
        appears in both formats, the entry seen last in the repo listing wins.
    """
    import re

    token = token or os.getenv('HF_TOKEN_2') or os.getenv('HF_TOKEN')
    files = list_repo_files(repo_id=repo_id, token=token)

    dir_pattern = re.compile(r'checkpoints/epoch_(\d+)/')
    file_pattern = re.compile(r'checkpoints/checkpoint_epoch_(\d+)\.pth$')

    epoch_map = {}
    for f in files:
        # New directory format
        m = dir_pattern.search(f)
        if m:
            # Keep the folder exactly as spelled in the repo (minus the
            # trailing slash). Rebuilding it with a fixed zero-padding width
            # would produce a non-existent path for e.g. "epoch_50".
            epoch_map[int(m.group(1))] = f[:m.end() - 1]
            continue
        # Legacy single-file format
        m = file_pattern.search(f)
        if m:
            epoch_map[int(m.group(1))] = f
    return sorted(epoch_map.items())
def download_checkpoint_for_inference(
    repo_id: str,
    epoch: Optional[int] = None,
    token: Optional[str] = None,
    local_dir: str = '/kaggle/working/ckpt_inference'
) -> Tuple[str, str]:
    """
    Download a checkpoint from HuggingFace for inference.

    Args:
        repo_id: HF repo ID
        epoch: Specific epoch to download. None → downloads latest.
        token: HF token. When omitted, the HF_TOKEN_2 and then HF_TOKEN
            environment variables are tried.
        local_dir: Where to save files. An ``epoch_NNN`` sub-directory is
            created inside it for the selected epoch.

    Returns:
        (local_path, fmt) where fmt is 'safetensors_dir' or 'pytorch_file'.
        For 'safetensors_dir', local_path is the directory that directly
        contains the downloaded checkpoint files. For 'pytorch_file' it is
        the path of the downloaded .pth file.

    Raises:
        FileNotFoundError: The repo has no checkpoints, or the selected
            checkpoint folder contains no files.
        ValueError: ``epoch`` was given but is not available in the repo.
    """
    token = token or os.getenv('HF_TOKEN_2') or os.getenv('HF_TOKEN')
    checkpoints = list_hf_checkpoints(repo_id, token)
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoints in {repo_id}")

    if epoch is None:
        ep, ckpt_ref = checkpoints[-1]
    else:
        matched = [(e, p) for e, p in checkpoints if e == epoch]
        if not matched:
            raise ValueError(f"Epoch {epoch} not found. Available: {[e for e, _ in checkpoints]}")
        ep, ckpt_ref = matched[0]

    local_dir = Path(local_dir) / f'epoch_{ep:03d}'
    local_dir.mkdir(parents=True, exist_ok=True)

    # SafeTensors directory format
    if not ckpt_ref.endswith('.pth'):
        folder = ckpt_ref.rstrip('/')
        # Match on "folder/" so that e.g. epoch_100 does not also pull in
        # the files of epoch_1000.
        prefix = folder + '/'
        all_files = list_repo_files(repo_id=repo_id, token=token)
        ckpt_files = [f for f in all_files if f.startswith(prefix)]
        if not ckpt_files:
            raise FileNotFoundError(f"No files found under {folder} in {repo_id}")
        for hf_filename in ckpt_files:
            hf_hub_download(
                repo_id=repo_id,
                filename=hf_filename,
                token=token,
                local_dir=str(local_dir)
            )
        # hf_hub_download keeps the repo-relative path below local_dir, so
        # the files end up in local_dir/<folder>, not in local_dir itself.
        ckpt_dir = local_dir / folder
        print(f"Downloaded checkpoint (epoch {ep}) → {ckpt_dir}")
        return str(ckpt_dir), 'safetensors_dir'

    # Legacy single .pth file
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=ckpt_ref,
        token=token,
        local_dir=str(local_dir)
    )
    print(f"Downloaded checkpoint (epoch {ep}) → {local_path}")
    return local_path, 'pytorch_file'
def load_model_for_inference(model, repo_id, epoch=None, device=None, token=None):
    """
    Download a legacy single-file checkpoint from HF and load its weights
    into ``model``.

    Only files named ``checkpoints/checkpoint_epoch_N.pt`` or ``.pth`` are
    considered; the directory-based SafeTensors layout is not handled here.

    Args:
        model: Model whose ``load_state_dict`` receives the weights.
        repo_id: HF repo ID to download from.
        epoch: Epoch to load. None selects the highest epoch available.
        device: Target device. None selects CUDA when available, else CPU.
        token: HF token. When omitted, the HF_TOKEN_2 and then HF_TOKEN
            environment variables are tried.

    Returns:
        The same ``model``, moved to ``device`` and set to eval mode.

    Raises:
        FileNotFoundError: No matching checkpoint file exists in the repo.
        ValueError: ``epoch`` was given but is not available.
    """
    import re

    token = token or os.getenv('HF_TOKEN_2') or os.getenv('HF_TOKEN')
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    files = list_repo_files(repo_id=repo_id, token=token)
    pattern = re.compile(r'checkpoints/checkpoint_epoch_(\d+)\.pth?$')
    epochs = {}
    for f in files:
        m = pattern.search(f)
        if m:
            epochs[int(m.group(1))] = f
    if not epochs:
        raise FileNotFoundError(f"No checkpoints found in {repo_id}")

    target_epoch = epoch if epoch is not None else max(epochs)
    if target_epoch not in epochs:
        raise ValueError(f"Epoch {target_epoch} not found. Available: {sorted(epochs)}")

    ckpt_file = hf_hub_download(
        repo_id=repo_id,
        filename=epochs[target_epoch],
        token=token
    )
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # use this with repos you trust.
    ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)

    # Accept both wrapped checkpoints and raw state dicts, matching
    # load_model_weights in this module.
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        state_dict = ckpt['model_state_dict']
    else:
        state_dict = ckpt

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print(f"Model loaded from epoch {target_epoch} on {device}")
    return model