from __future__ import annotations import argparse import json from pathlib import Path import numpy as np import torch from huggingface_hub import hf_hub_download from PIL import Image from torchvision.transforms import InterpolationMode from torchvision.transforms import functional as TF from hair_mask_dataset.segface_hair_model import SegFaceHairModel IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run hair segmentation inference.") parser.add_argument("--input", required=True, help="Path to the input image.") parser.add_argument("--output-mask", required=True, help="Where to save the predicted binary mask.") parser.add_argument("--output-overlay", default="", help="Optional overlay output path.") parser.add_argument("--checkpoint", default="best.pt", help="Local checkpoint path.") parser.add_argument("--config", default="config.json", help="Local config path.") parser.add_argument("--repo-id", default="", help="Optional Hugging Face repo id to download best.pt/config.json from.") parser.add_argument("--revision", default="main", help="Hub revision to download from when using --repo-id.") parser.add_argument("--threshold", type=float, default=None, help="Override sigmoid threshold.") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Inference device.") return parser.parse_args() def resolve_artifacts(args: argparse.Namespace) -> tuple[Path, Path]: if args.repo_id: checkpoint_path = Path( hf_hub_download(repo_id=args.repo_id, filename="best.pt", revision=args.revision) ) config_path = Path( hf_hub_download(repo_id=args.repo_id, filename="config.json", revision=args.revision) ) return checkpoint_path, config_path return Path(args.checkpoint), Path(args.config) def load_model(checkpoint_path: Path, config_path: Path, device: torch.device) -> tuple[torch.nn.Module, dict]: checkpoint = torch.load(checkpoint_path, map_location="cpu") config = checkpoint.get("config") if config is None: config = json.loads(config_path.read_text(encoding="utf-8")) model = SegFaceHairModel( input_resolution=config["image_size"], model_name=config["model_name"], load_pretrained=False, freeze_backbone=config["freeze_backbone"], lora_rank=config["lora_rank"], lora_alpha=config["lora_alpha"], lora_dropout=config["lora_dropout"], lora_targets=config["lora_targets"], ) model.load_state_dict(checkpoint["model_state"], strict=False) model.to(device) model.eval() return model, config def preprocess(image: Image.Image, image_size: int) -> torch.Tensor: resized = TF.resize(image, [image_size, image_size], interpolation=InterpolationMode.BILINEAR) tensor = TF.to_tensor(resized) tensor = TF.normalize(tensor, IMAGENET_MEAN, IMAGENET_STD) return tensor.unsqueeze(0) def build_overlay(image: Image.Image, mask_u8: np.ndarray) -> Image.Image: image_np = np.asarray(image.convert("RGB"), dtype=np.uint8).copy() overlay = image_np.copy() overlay[mask_u8 > 127] = (overlay[mask_u8 > 127] * 0.4 + np.array([64, 255, 64]) * 0.6).astype(np.uint8) return Image.fromarray(overlay) def main() -> None: args = parse_args() checkpoint_path, config_path = resolve_artifacts(args) device = torch.device(args.device) model, config = load_model(checkpoint_path, config_path, device) threshold = args.threshold if args.threshold is not None else config.get("threshold", 0.5) image_path = Path(args.input) output_mask_path = Path(args.output_mask) output_mask_path.parent.mkdir(parents=True, exist_ok=True) image = Image.open(image_path).convert("RGB") original_size = image.size inputs = preprocess(image, int(config["image_size"])).to(device) with torch.no_grad(): logits = model(inputs)["hair_logits"] probs = torch.sigmoid(logits)[0, 0].cpu().numpy() mask_small = (probs >= threshold).astype(np.uint8) * 255 mask_image = Image.fromarray(mask_small, mode="L").resize(original_size, resample=Image.NEAREST) mask_image.save(output_mask_path) if args.output_overlay: output_overlay_path = Path(args.output_overlay) output_overlay_path.parent.mkdir(parents=True, exist_ok=True) overlay = build_overlay(image, np.asarray(mask_image, dtype=np.uint8)) overlay.save(output_overlay_path) print(f"Saved mask to {output_mask_path}") if args.output_overlay: print(f"Saved overlay to {args.output_overlay}") if __name__ == "__main__": main()