siik commited on
Commit
7a4e54a
·
verified ·
1 Parent(s): ed430ba

Upload SegFace hair segmentation model bundle

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/epoch_010.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/training_curves.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: pytorch
3
+ pipeline_tag: image-segmentation
4
+ tags:
5
+ - pytorch
6
+ - image-segmentation
7
+ - hair-segmentation
8
+ - segface
9
+ - korean-hairstyle
10
+ - custom-code
11
+ language:
12
+ - ko
13
+ ---
14
+
15
+ # SegFace_k-hair
16
+
17
+ Personal Hugging Face model repository for a custom PyTorch hair-only segmentation checkpoint trained on a K-Hairstyle based AIHub subset.
18
+
19
+ ## Model Summary
20
+
21
+ - Backbone: `swin_base`
22
+ - Input size: `512x512`
23
+ - Freeze backbone: `True`
24
+ - LoRA: `rank=8`, `alpha=16.0`, `dropout=0.05`
25
+ - Threshold used during validation: `0.5`
26
+ - Train / Val split used for this run: `50,000` / `5,000`
27
+
28
+ ## Validation Metrics
29
+
30
+ These are validation metrics from the training run. A separate independent hold-out test split has not been populated yet, so treat these as validation-only results.
31
+
32
+ | Metric | Value |
33
+ | --- | ---: |
34
+ | Best epoch | 7 |
35
+ | Val IoU | 0.9487 |
36
+ | Val Dice | 0.9736 |
37
+ | Val Precision | 0.9723 |
38
+ | Val Recall | 0.9751 |
39
+ | Epochs completed | 10 |
40
+ | Avg epoch time (sec) | 3546.45 |
41
+
42
+ ## Bundle Contents
43
+
44
+ - `best.pt`: inference checkpoint
45
+ - `config.json`: training-time model config
46
+ - `training_run_summary.json`: run summary and validation metrics
47
+ - `inference.py`: local / Hub inference example
48
+ - `requirements.txt`: minimal runtime dependencies
49
+ - `hair_mask_dataset/`, `models/`: custom model code required to load the checkpoint
50
+
51
+ ## Inference
52
+
53
+ Run locally from the root of this model bundle:
54
+
55
+ ```bash
56
+ python inference.py \
57
+ --checkpoint best.pt \
58
+ --input path/to/input.jpg \
59
+ --output-mask output_mask.png \
60
+ --output-overlay output_overlay.png
61
+ ```
62
+
63
+ You can also load directly from the Hugging Face Hub after uploading:
64
+
65
+ ```bash
66
+ python inference.py \
67
+ --repo-id your-username/SegFace_k-hair \
68
+ --input path/to/input.jpg \
69
+ --output-mask output_mask.png \
70
+ --output-overlay output_overlay.png
71
+ ```
72
+
73
+ ## Notes
74
+
75
+ - This repo contains custom code and a raw PyTorch checkpoint, not a Transformers-format model.
76
+ - Preprocessing expects RGB input, resize to `512`, ImageNet normalization, and sigmoid threshold `0.5`.
77
+ - Before making the repository public, verify whether your AIHub / K-Hairstyle data usage terms allow public redistribution of derived model weights.
78
+
79
+ ## Training Artifacts
80
+
81
+ ![Training Curve](assets/training_curves.png)
82
+
83
+ ![Preview](assets/epoch_010.png)
assets/epoch_010.png ADDED

Git LFS Details

  • SHA256: 5013944b2a0d2a7004783b1136fdeaa9fb68c431d132c98c2130a8d827764555
  • Pointer size: 131 Bytes
  • Size of remote file: 451 kB
assets/training_curves.png ADDED

Git LFS Details

  • SHA256: 1706a4c7ae19bcfafd9516affb624ad166360d5a9cd0628d8dc4ddd147ceea44
  • Pointer size: 131 Bytes
  • Size of remote file: 200 kB
best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6cb2cfb6fad666414dadfe1d61ff1a521b392c6319f7534baf341fcafa78fdb
3
+ size 417548126
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "prepared_root": "/workspace/runpod_upload_ready/data/aihub_hairmask_hq_budget_50k",
3
+ "raw_root": "/workspace/runpod_upload_ready/data/aihub_korean_hairstyle_hq_raw",
4
+ "image_size": 512,
5
+ "model_name": "swin_base",
6
+ "run_dir": "/workspace/runpod_upload_ready/hair_mask_dataset/runs/segface_hair_budget_4090",
7
+ "epochs": 10,
8
+ "batch_size": 2,
9
+ "accumulation_steps": 2,
10
+ "num_workers": 6,
11
+ "lr": 0.0001,
12
+ "weight_decay": 0.0001,
13
+ "amp": true,
14
+ "threshold": 0.5,
15
+ "seed": 42,
16
+ "save_every": 1,
17
+ "freeze_backbone": true,
18
+ "lora_rank": 8,
19
+ "lora_alpha": 16.0,
20
+ "lora_dropout": 0.05,
21
+ "lora_targets": [
22
+ "attn.qkv",
23
+ "attn.proj",
24
+ "mlp.0",
25
+ "mlp.3"
26
+ ],
27
+ "compile_model": true,
28
+ "channels_last": true,
29
+ "trainable_params": 5804672,
30
+ "total_params": 92547896
31
+ }
hair_mask_dataset/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .dataset import AIHubHairMaskDataset
2
+
3
+ __all__ = ["AIHubHairMaskDataset"]
hair_mask_dataset/segface_hair_model.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Dict, Iterable, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from huggingface_hub import hf_hub_download
10
+ from torchvision.models.swin_transformer import ShiftedWindowAttention
11
+
12
+ from models.segface.models.segface_celeb import SegFaceCeleb
13
+
14
+
15
+ HAIR_CLASS_INDEX = 14
16
+ DEFAULT_LORA_TARGETS = ("attn.qkv", "attn.proj", "mlp.0", "mlp.3")
17
+
18
+
19
def load_segface_pretrained(model: nn.Module) -> None:
    """Load the published SegFace Swin-B CelebA weights into ``model`` in place.

    The checkpoint file is fetched with ``hf_hub_download`` from the
    ``kartiknarayan/SegFace`` repository and deserialized on the CPU with
    ``weights_only=True``.  Loading is non-strict, so keys that do not line up
    with ``model`` are skipped without raising.

    Args:
        model: Module that receives the downloaded state dict.
    """
    weights_file = hf_hub_download(
        repo_id="kartiknarayan/SegFace",
        filename="swinb_celeba_512/model_299.pt",
    )
    payload = torch.load(weights_file, map_location="cpu", weights_only=True)

    # Use the nested "state_dict_backbone" entry when the payload has one;
    # otherwise the payload itself is treated as the state dict.
    if "state_dict_backbone" in payload:
        weights = payload["state_dict_backbone"]
    else:
        weights = payload

    model.load_state_dict(weights, strict=False)
27
+
28
+
29
class LoRALinear(nn.Module):
    """Frozen ``nn.Linear`` plus a trainable low-rank residual branch.

    The output is ``base(x) + lora_up(lora_down(dropout(x))) * (alpha / rank)``.
    ``lora_up`` is initialized to zero, so a freshly wrapped layer produces the
    same output as ``base``.
    """

    def __init__(self, base: nn.Linear, *, rank: int, alpha: float, dropout: float) -> None:
        """Wrap ``base`` and freeze its weight (and bias, when present).

        Args:
            base: Linear layer to wrap; its parameters stop requiring grad.
            rank: Width of the low-rank bottleneck.
            alpha: Numerator of the scaling factor ``alpha / rank``.
            dropout: Dropout probability applied to the low-rank branch input;
                values ``<= 0`` disable dropout.

        Raises:
            ValueError: If ``rank`` is not positive.
        """
        super().__init__()
        if rank <= 0:
            raise ValueError("LoRA rank must be positive.")

        self.base = base
        self.rank = rank
        self.scaling = alpha / rank
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        # Both projections are created before either is re-initialized so the
        # sequence of random draws stays fixed for a given seed.
        in_width, out_width = base.in_features, base.out_features
        self.lora_down = nn.Linear(in_width, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_width, bias=False)
        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)

        # Only the low-rank branch stays trainable.
        for frozen in (self.base.weight, self.base.bias):
            if frozen is not None:
                frozen.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen projection plus the scaled low-rank update."""
        frozen_out = self.base(x)
        update = self.lora_up(self.lora_down(self.dropout(x)))
        return frozen_out + update * self.scaling
53
+
54
+
55
def linear_bias(module: nn.Module) -> torch.Tensor | None:
    """Return the bias of a plain or LoRA-wrapped linear layer.

    Args:
        module: An ``nn.Linear`` or a ``LoRALinear`` wrapping one.

    Returns:
        The bias of the underlying linear layer, or ``None`` when it has none.

    Raises:
        TypeError: If ``module`` is neither supported type.
    """
    if isinstance(module, LoRALinear):
        underlying = module.base
    elif isinstance(module, nn.Linear):
        underlying = module
    else:
        raise TypeError(f"Unsupported linear module type: {type(module)!r}")
    return underlying.bias
61
+
62
+
63
def linear_with_lora(module: nn.Module, x: torch.Tensor, *, bias_override: torch.Tensor | None = None) -> torch.Tensor:
    """Apply a plain or LoRA-wrapped linear layer with an optional bias swap.

    Args:
        module: An ``nn.Linear`` or a ``LoRALinear``.
        x: Input passed to ``F.linear``.
        bias_override: Bias used for the base projection instead of the
            layer's own bias. ``None`` means "use the layer's bias".

    Returns:
        The base projection, plus the scaled low-rank update when ``module``
        is a ``LoRALinear``.

    Raises:
        TypeError: If ``module`` is neither supported type.
    """
    if isinstance(module, LoRALinear):
        frozen = module.base
        chosen_bias = frozen.bias if bias_override is None else bias_override
        projected = F.linear(x, frozen.weight, chosen_bias)
        # The low-rank branch never carries a bias of its own.
        bottleneck = F.linear(module.dropout(x), module.lora_down.weight, None)
        update = F.linear(bottleneck, module.lora_up.weight, None) * module.scaling
        return projected + update

    if not isinstance(module, nn.Linear):
        raise TypeError(f"Unsupported linear module type: {type(module)!r}")

    chosen_bias = module.bias if bias_override is None else bias_override
    return F.linear(x, module.weight, chosen_bias)
74
+
75
+
76
def shifted_window_attention_with_modules(
    input: torch.Tensor,
    qkv_module: nn.Module,
    proj_module: nn.Module,
    relative_position_bias: torch.Tensor,
    window_size: list[int],
    num_heads: int,
    shift_size: list[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    logit_scale: torch.Tensor | None = None,
    training: bool = True,
) -> torch.Tensor:
    """Windowed multi-head self-attention that takes modules instead of raw weights.

    The QKV and output projections are applied through ``linear_with_lora``,
    so ``qkv_module`` / ``proj_module`` may be either ``nn.Linear`` or
    ``LoRALinear``.
    NOTE(review): the structure appears to follow torchvision's
    ``shifted_window_attention`` -- confirm against the installed version.

    Args:
        input: Tensor unpacked as ``(B, H, W, C)``.
        qkv_module: Linear-like module producing the fused Q/K/V projection.
        proj_module: Linear-like module applied to the attention output.
        relative_position_bias: Tensor added to the attention logits.
        window_size: ``[height, width]`` of an attention window.
        num_heads: Number of attention heads; ``C`` is split evenly across them.
        shift_size: ``[vertical, horizontal]`` cyclic shift; copied, not mutated.
        attention_dropout: Dropout probability on the attention weights.
        dropout: Dropout probability on the projected output.
        logit_scale: When given, scaled cosine attention is used instead of
            scaled dot-product attention.
        training: Forwarded to ``F.dropout``.

    Returns:
        Contiguous tensor with the same ``(B, H, W, C)`` extent as ``input``.
    """
    B, H, W, C = input.shape
    # Pad the bottom/right edges so H and W become multiples of the window size.
    pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
    pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
    x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
    _, pad_H, pad_W, _ = x.shape

    # Work on a copy so the caller's list is untouched; disable the shift along
    # any axis where one window already covers the whole padded extent.
    shift_size = shift_size.copy()
    if window_size[0] >= pad_H:
        shift_size[0] = 0
    if window_size[1] >= pad_W:
        shift_size[1] = 0

    # Cyclic shift.
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))

    # Partition into windows: (B * num_windows, tokens_per_window, C).
    num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
    x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C)

    # In the cosine-attention path the middle third of the fused bias (the
    # slice that ends up in ``k`` after the reshape below) is zeroed on a clone,
    # leaving the module's own parameter unchanged.
    qkv_bias = linear_bias(qkv_module)
    if logit_scale is not None and qkv_bias is not None:
        qkv_bias = qkv_bias.clone()
        length = qkv_bias.numel() // 3
        qkv_bias[length : 2 * length].zero_()

    qkv = linear_with_lora(qkv_module, x, bias_override=qkv_bias)
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]

    if logit_scale is not None:
        # Cosine attention with the learned scale clamped to at most log(100).
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
        attn = attn * logit_scale
    else:
        # Scaled dot-product attention.
        q = q * (C // num_heads) ** -0.5
        attn = q.matmul(k.transpose(-2, -1))

    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        # Label the padded map with up to nine region ids (3 row bands x 3
        # column bands); token pairs inside a window whose ids differ receive
        # a -100 logit offset so they effectively do not attend to each other.
        attn_mask = x.new_zeros((pad_H, pad_W))
        h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
        w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
        count = 0
        for h in h_slices:
            for w in w_slices:
                attn_mask[h[0] : h[1], w[0] : w[1]] = count
                count += 1
        attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
        attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
        attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        # Expose the per-window axis so the mask broadcasts over batch and heads.
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))

    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=attention_dropout, training=training)

    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
    x = linear_with_lora(proj_module, x)
    x = F.dropout(x, p=dropout, training=training)

    # Merge the windows back into a (B, pad_H, pad_W, C) map.
    x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)

    # Undo the cyclic shift.
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

    # Drop the padding added at the top of the function.
    return x[:, :H, :W, :].contiguous()
159
+
160
+
161
def lora_compatible_swin_attention_forward(self: ShiftedWindowAttention, x: torch.Tensor) -> torch.Tensor:
    """Replacement ``forward`` for ``ShiftedWindowAttention`` that accepts wrapped projections.

    Delegates to ``shifted_window_attention_with_modules``, handing over the
    ``qkv`` / ``proj`` submodules themselves together with the layer's window
    configuration and dropout settings.  ``logit_scale`` is forwarded only when
    the layer defines that attribute; otherwise ``None`` is passed.
    """
    position_bias = self.get_relative_position_bias()
    scale = getattr(self, "logit_scale", None)
    return shifted_window_attention_with_modules(
        input=x,
        qkv_module=self.qkv,
        proj_module=self.proj,
        relative_position_bias=position_bias,
        window_size=self.window_size,
        num_heads=self.num_heads,
        shift_size=self.shift_size,
        attention_dropout=self.attention_dropout,
        dropout=self.dropout,
        logit_scale=scale,
        training=self.training,
    )
176
+
177
+
178
def patch_swin_attention_for_lora(module: nn.Module) -> int:
    """Rebind ``forward`` on every ``ShiftedWindowAttention`` found under ``module``.

    Each matching submodule gets ``lora_compatible_swin_attention_forward``
    bound as its ``forward`` and is tagged with ``_lora_forward_patched`` so a
    later call leaves it alone.
    NOTE(review): presumably needed because the stock forward reads weights
    directly off ``qkv`` / ``proj`` -- confirm against the installed torchvision.

    Returns:
        Number of submodules patched by this call.
    """
    newly_patched = 0
    for submodule in module.modules():
        if not isinstance(submodule, ShiftedWindowAttention):
            continue
        if getattr(submodule, "_lora_forward_patched", False):
            continue
        bound_forward = lora_compatible_swin_attention_forward.__get__(submodule, type(submodule))
        submodule.forward = bound_forward
        submodule._lora_forward_patched = True
        newly_patched += 1
    return newly_patched
186
+
187
+
188
def freeze_module(module: nn.Module) -> None:
    """Turn off gradient tracking for every parameter reachable from ``module``."""
    for weight in module.parameters():
        weight.requires_grad_(False)
191
+
192
+
193
def apply_lora(
    module: nn.Module,
    *,
    rank: int,
    alpha: float,
    dropout: float,
    target_patterns: Iterable[str],
    prefix: str = "",
) -> int:
    """Recursively replace matching ``nn.Linear`` children with ``LoRALinear``.

    A child is wrapped when it is an ``nn.Linear`` and any non-empty pattern
    occurs as a substring of its dotted name (relative to the module passed to
    the outermost call).  Children that are already ``LoRALinear`` are skipped,
    so calling this function again on the same module is a no-op rather than
    wrapping ``<name>.base`` in a second adapter.

    Args:
        module: Root module, modified in place.
        rank: LoRA bottleneck width passed to ``LoRALinear``.
        alpha: LoRA scaling numerator passed to ``LoRALinear``.
        dropout: Dropout probability passed to ``LoRALinear``.
        target_patterns: Substrings to look for in dotted child names; empty
            strings are ignored.
        prefix: Dotted name of ``module`` itself; used by the recursion.

    Returns:
        Number of linear layers wrapped by this call (including recursion).
    """
    patterns = tuple(pattern for pattern in target_patterns if pattern)
    replaced = 0

    # Snapshot the children because setattr below mutates the registry.
    for child_name, child in list(module.named_children()):
        if isinstance(child, LoRALinear):
            # Already adapted: descending would match "<name>.base" against the
            # same pattern and stack a second adapter on the frozen layer.
            continue

        full_name = f"{prefix}.{child_name}" if prefix else child_name
        if isinstance(child, nn.Linear) and any(pattern in full_name for pattern in patterns):
            setattr(module, child_name, LoRALinear(child, rank=rank, alpha=alpha, dropout=dropout))
            replaced += 1
            continue

        replaced += apply_lora(
            child,
            rank=rank,
            alpha=alpha,
            dropout=dropout,
            target_patterns=patterns,
            prefix=full_name,
        )

    return replaced
221
+
222
+
223
class SegFaceHairModel(nn.Module):
    """``SegFaceCeleb`` wrapper that also exposes the hair channel on its own.

    Optionally loads the published SegFace weights, freezes the backbone, and
    injects LoRA adapters into backbone linear layers.
    """

    def __init__(
        self,
        *,
        input_resolution: int = 512,
        model_name: str = "swin_base",
        load_pretrained: bool = True,
        freeze_backbone: bool = False,
        lora_rank: int = 0,
        lora_alpha: float = 16.0,
        lora_dropout: float = 0.0,
        lora_targets: Iterable[str] = DEFAULT_LORA_TARGETS,
    ) -> None:
        """Build the wrapped network and apply the requested adaptations.

        Args:
            input_resolution: Forwarded to ``SegFaceCeleb``.
            model_name: Backbone identifier forwarded to ``SegFaceCeleb``.
            load_pretrained: Whether to call ``load_segface_pretrained``.
            freeze_backbone: Whether to freeze every backbone parameter.
            lora_rank: LoRA rank; values ``<= 0`` skip LoRA entirely.
            lora_alpha: LoRA scaling numerator.
            lora_dropout: Dropout probability inside each LoRA branch.
            lora_targets: Name substrings selecting the layers to adapt.
        """
        super().__init__()
        self.segface = SegFaceCeleb(input_resolution=input_resolution, model=model_name)

        if load_pretrained:
            load_segface_pretrained(self.segface)
        if freeze_backbone:
            # Freezing happens before LoRA injection, so adapter weights
            # created afterwards keep requiring grad.
            freeze_module(self.segface.backbone)

        self.lora_target_patterns: Tuple[str, ...] = tuple(filter(None, lora_targets))
        self.lora_replaced = 0
        self.swin_attention_patched = 0

        if not lora_rank > 0:
            return

        self.lora_replaced = apply_lora(
            self.segface.backbone,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            target_patterns=self.lora_target_patterns,
        )
        if model_name.startswith("swin"):
            self.swin_attention_patched = patch_swin_attention_for_lora(self.segface.backbone)

    def forward(self, images: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the wrapped model and return hair-only and full logits.

        Returns:
            Dict with ``"hair_logits"`` (channel ``HAIR_CLASS_INDEX`` kept as a
            single-channel slice of dim 1) and ``"all_logits"`` (the unmodified
            model output).
        """
        all_logits = self.segface(images, None, None)
        hair_channel = slice(HAIR_CLASS_INDEX, HAIR_CLASS_INDEX + 1)
        return {
            "hair_logits": all_logits[:, hair_channel],
            "all_logits": all_logits,
        }
inference.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from huggingface_hub import hf_hub_download
10
+ from PIL import Image
11
+ from torchvision.transforms import InterpolationMode
12
+ from torchvision.transforms import functional as TF
13
+
14
+ from hair_mask_dataset.segface_hair_model import SegFaceHairModel
15
+
16
+
17
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
18
+ IMAGENET_STD = (0.229, 0.224, 0.225)
19
+
20
+
21
+ def parse_args() -> argparse.Namespace:
22
+ parser = argparse.ArgumentParser(description="Run hair segmentation inference.")
23
+ parser.add_argument("--input", required=True, help="Path to the input image.")
24
+ parser.add_argument("--output-mask", required=True, help="Where to save the predicted binary mask.")
25
+ parser.add_argument("--output-overlay", default="", help="Optional overlay output path.")
26
+ parser.add_argument("--checkpoint", default="best.pt", help="Local checkpoint path.")
27
+ parser.add_argument("--config", default="config.json", help="Local config path.")
28
+ parser.add_argument("--repo-id", default="", help="Optional Hugging Face repo id to download best.pt/config.json from.")
29
+ parser.add_argument("--revision", default="main", help="Hub revision to download from when using --repo-id.")
30
+ parser.add_argument("--threshold", type=float, default=None, help="Override sigmoid threshold.")
31
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Inference device.")
32
+ return parser.parse_args()
33
+
34
+
35
+ def resolve_artifacts(args: argparse.Namespace) -> tuple[Path, Path]:
36
+ if args.repo_id:
37
+ checkpoint_path = Path(
38
+ hf_hub_download(repo_id=args.repo_id, filename="best.pt", revision=args.revision)
39
+ )
40
+ config_path = Path(
41
+ hf_hub_download(repo_id=args.repo_id, filename="config.json", revision=args.revision)
42
+ )
43
+ return checkpoint_path, config_path
44
+ return Path(args.checkpoint), Path(args.config)
45
+
46
+
47
+ def load_model(checkpoint_path: Path, config_path: Path, device: torch.device) -> tuple[torch.nn.Module, dict]:
48
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
49
+ config = checkpoint.get("config")
50
+ if config is None:
51
+ config = json.loads(config_path.read_text(encoding="utf-8"))
52
+
53
+ model = SegFaceHairModel(
54
+ input_resolution=config["image_size"],
55
+ model_name=config["model_name"],
56
+ load_pretrained=False,
57
+ freeze_backbone=config["freeze_backbone"],
58
+ lora_rank=config["lora_rank"],
59
+ lora_alpha=config["lora_alpha"],
60
+ lora_dropout=config["lora_dropout"],
61
+ lora_targets=config["lora_targets"],
62
+ )
63
+ model.load_state_dict(checkpoint["model_state"], strict=False)
64
+ model.to(device)
65
+ model.eval()
66
+ return model, config
67
+
68
+
69
+ def preprocess(image: Image.Image, image_size: int) -> torch.Tensor:
70
+ resized = TF.resize(image, [image_size, image_size], interpolation=InterpolationMode.BILINEAR)
71
+ tensor = TF.to_tensor(resized)
72
+ tensor = TF.normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
73
+ return tensor.unsqueeze(0)
74
+
75
+
76
+ def build_overlay(image: Image.Image, mask_u8: np.ndarray) -> Image.Image:
77
+ image_np = np.asarray(image.convert("RGB"), dtype=np.uint8).copy()
78
+ overlay = image_np.copy()
79
+ overlay[mask_u8 > 127] = (overlay[mask_u8 > 127] * 0.4 + np.array([64, 255, 64]) * 0.6).astype(np.uint8)
80
+ return Image.fromarray(overlay)
81
+
82
+
83
+ def main() -> None:
84
+ args = parse_args()
85
+ checkpoint_path, config_path = resolve_artifacts(args)
86
+ device = torch.device(args.device)
87
+ model, config = load_model(checkpoint_path, config_path, device)
88
+ threshold = args.threshold if args.threshold is not None else config.get("threshold", 0.5)
89
+
90
+ image_path = Path(args.input)
91
+ output_mask_path = Path(args.output_mask)
92
+ output_mask_path.parent.mkdir(parents=True, exist_ok=True)
93
+
94
+ image = Image.open(image_path).convert("RGB")
95
+ original_size = image.size
96
+ inputs = preprocess(image, int(config["image_size"])).to(device)
97
+
98
+ with torch.no_grad():
99
+ logits = model(inputs)["hair_logits"]
100
+ probs = torch.sigmoid(logits)[0, 0].cpu().numpy()
101
+
102
+ mask_small = (probs >= threshold).astype(np.uint8) * 255
103
+ mask_image = Image.fromarray(mask_small, mode="L").resize(original_size, resample=Image.NEAREST)
104
+ mask_image.save(output_mask_path)
105
+
106
+ if args.output_overlay:
107
+ output_overlay_path = Path(args.output_overlay)
108
+ output_overlay_path.parent.mkdir(parents=True, exist_ok=True)
109
+ overlay = build_overlay(image, np.asarray(mask_image, dtype=np.uint8))
110
+ overlay.save(output_overlay_path)
111
+
112
+ print(f"Saved mask to {output_mask_path}")
113
+ if args.output_overlay:
114
+ print(f"Saved overlay to {args.output_overlay}")
115
+
116
+
117
+ if __name__ == "__main__":
118
+ main()
models/__init__.py ADDED
File without changes
models/segface/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .models import SegFaceLapa, SegFaceCeleb, SegFaceHelen
2
+
3
+
4
+ def get_model(backbone, input_resolution, model):
5
+ if backbone == "segface_lapa":
6
+ model = SegFaceLapa(input_resolution, model)
7
+ elif backbone == "segface_celeb":
8
+ model = SegFaceCeleb(input_resolution, model)
9
+ elif backbone == "segface_helen":
10
+ model = SegFaceHelen(input_resolution, model)
11
+ else:
12
+ raise ValueError("Backbone not implemented")
13
+ return model
models/segface/models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .segface_lapa import SegFaceLapa
2
+ from .segface_celeb import SegFaceCeleb
3
+ from .segface_helen import SegFaceHelen
models/segface/models/segface_celeb.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+
6
+ from typing import Any, Optional, Tuple, Type
7
+ from torchvision.models import convnext_large, convnext_base, convnext_small, convnext_tiny, swin_b, swin_v2_b, swin_v2_s, swin_v2_t, mobilenet_v3_large, efficientnet_v2_m
8
+ import pdb
9
+ import numpy as np
10
+ import sys
11
+ import os
12
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
13
+ from models.segface.models.transformer import *
14
+ from models.segface.models.utils_models import *
15
+
16
class MLP(nn.Module):
    """Stack of linear layers with ReLU between them and an optional final sigmoid."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        """Create the layer stack.

        Args:
            input_dim: Width of the input features.
            hidden_dim: Width of every intermediate layer.
            output_dim: Width of the final layer.
            num_layers: Number of linear layers.
            sigmoid_output: Apply a sigmoid to the final output when true.
        """
        super().__init__()
        self.num_layers = num_layers
        # Consecutive entries of ``widths`` are the (in, out) sizes of each layer.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        """Apply the layers in order; ReLU follows every layer except the last."""
        final_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < final_index:
                x = F.relu(x)
        if self.sigmoid_output:
            return F.sigmoid(x)
        return x
39
+
40
class FaceDecoder(nn.Module):
    """Token-based decoder that produces one segmentation logit map per class token.

    Nineteen learnable class tokens are run through ``transformer`` together
    with the image embedding; the returned tokens are mapped by a shared MLP to
    per-class weights that are multiplied with an upscaled feature map.
    The token attribute names and their creation order are left untouched
    because parameter names follow from them.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """Create the class tokens, the upscaling head and the hyper-network MLP.

        Args:
            transformer_dim: Channel width of tokens and image embeddings.
            transformer: Module called as ``transformer(src, pos_src, tokens)``.
            activation: Activation class used inside the upscaling head.
        """

        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        # One learnable token per class; forward() concatenates them in this
        # same order, so the position here is the class channel index.
        self.background_token = nn.Embedding(1, transformer_dim)
        self.neck_token = nn.Embedding(1, transformer_dim)
        self.face_token = nn.Embedding(1, transformer_dim)
        self.cloth_token = nn.Embedding(1, transformer_dim)
        self.rightear_token = nn.Embedding(1, transformer_dim)
        self.leftear_token = nn.Embedding(1, transformer_dim)
        self.rightbro_token = nn.Embedding(1, transformer_dim)
        self.leftbro_token = nn.Embedding(1, transformer_dim)
        self.righteye_token = nn.Embedding(1, transformer_dim)
        self.lefteye_token = nn.Embedding(1, transformer_dim)
        self.nose_token = nn.Embedding(1, transformer_dim)
        self.innermouth_token = nn.Embedding(1, transformer_dim)
        self.lowerlip_token = nn.Embedding(1, transformer_dim)
        self.upperlip_token = nn.Embedding(1, transformer_dim)
        self.hair_token = nn.Embedding(1, transformer_dim)
        self.glass_token = nn.Embedding(1, transformer_dim)
        self.hat_token = nn.Embedding(1, transformer_dim)
        self.earring_token = nn.Embedding(1, transformer_dim)
        self.necklace_token = nn.Embedding(1, transformer_dim)


        # Two stride-2 transposed convolutions: 4x spatial upsampling while the
        # channel count goes transformer_dim -> //4 -> //8.
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )

        # Shared 3-layer MLP mapping each output token to transformer_dim // 8
        # weights, matching the channel count of the upscaled feature map.
        self.output_hypernetwork_mlps = MLP(
            transformer_dim, transformer_dim, transformer_dim // 8, 3
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
    ) -> torch.Tensor:
        """Decode per-class segmentation logits.

        Args:
            image_embeddings: Tensor unpacked as ``(b, c, h, w)``.
            image_pe: Positional encoding, expanded along dim 0 to the batch
                size of ``image_embeddings``.

        Returns:
            Tensor viewed as ``(b, num_tokens, 4h, 4w)``.
            NOTE(review): assumes the transformer returns one output token per
            input token (19 here) -- confirm in ``TwoWayTransformer``.
        """
        # Concatenation order defines the class index of each output channel.
        output_tokens = torch.cat([
            self.background_token.weight, self.neck_token.weight, self.face_token.weight, self.cloth_token.weight,
            self.rightear_token.weight, self.leftear_token.weight, self.rightbro_token.weight, self.leftbro_token.weight,
            self.righteye_token.weight, self.lefteye_token.weight, self.nose_token.weight, self.innermouth_token.weight,
            self.lowerlip_token.weight, self.upperlip_token.weight, self.hair_token.weight, self.glass_token.weight,
            self.hat_token.weight, self.earring_token.weight, self.necklace_token.weight], dim=0)

        tokens = output_tokens.unsqueeze(0).expand(image_embeddings.size(0), -1, -1)  # (b, 19, transformer_dim)

        src = image_embeddings
        pos_src = image_pe.expand(image_embeddings.size(0), -1, -1, -1)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)  # presumably hs: (b, tokens, c), src: (b, h*w, c) -- verify
        mask_token_out = hs[:, :, :]

        src = src.transpose(1, 2).view(b, c, h, w)  # back to a (b, c, h, w) feature map
        upscaled_embedding = self.output_upscaling(src)  # (b, c // 8, 4h, 4w)
        hyper_in = self.output_hypernetwork_mlps(mask_token_out)  # last dim becomes c // 8
        b, c, h, w = upscaled_embedding.shape
        # Per-token dot product with every pixel's feature vector.
        seg_output = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        return seg_output
119
+
120
+
121
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding built from a fixed random Gaussian projection.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        """Draw the ``(2, num_pos_feats)`` projection matrix and store it as a buffer.

        A missing or non-positive ``scale`` is replaced by ``1.0``.
        """
        super().__init__()
        multiplier = 1.0 if (scale is None or scale <= 0.0) else scale
        projection = multiplier * torch.randn((2, num_pos_feats))
        self.register_buffer("positional_encoding_gaussian_matrix", projection)

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Encode coordinates given in the unit square; the last dim must be 2."""
        # Map [0, 1] to [-1, 1], project, then take sin/cos of the angles.
        centered = 2 * coords - 1
        angles = 2 * np.pi * (centered @ self.positional_encoding_gaussian_matrix)
        # Output has shape d_1 x ... x d_n x (2 * num_pos_feats).
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Return the encoding of a ``size = (h, w)`` grid of cell centers as C x H x W."""
        rows, cols = size
        buffer_device = self.positional_encoding_gaussian_matrix.device
        ones = torch.ones((rows, cols), device=buffer_device, dtype=torch.float32)
        # Running sums minus one half give cell-center coordinates 0.5, 1.5, ...
        y_centers = (ones.cumsum(dim=0) - 0.5) / rows
        x_centers = (ones.cumsum(dim=1) - 0.5) / cols

        encoded = self._pe_encoding(torch.stack([x_centers, y_centers], dim=-1))
        return encoded.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Encode pixel-space points; ``image_size`` is ``(height, width)``.

        The input tensor is cloned, so the caller's tensor is not modified.
        """
        normalized = coords_input.clone()
        height, width = image_size[0], image_size[1]
        normalized[:, :, 0] = normalized[:, :, 0] / width
        normalized[:, :, 1] = normalized[:, :, 1] / height
        return self._pe_encoding(normalized.to(torch.float))  # B x N x C
165
+
166
+
167
class SegfaceMLP(nn.Module):
    """
    Per-position linear projection of a feature map to 256 channels.
    """

    def __init__(self, input_dim):
        """Create the ``input_dim -> 256`` projection."""
        super().__init__()
        self.proj = nn.Linear(input_dim, 256)

    def forward(self, hidden_states: torch.Tensor):
        """Flatten dims 2+ into one token axis, move channels last, then project."""
        tokens = torch.flatten(hidden_states, start_dim=2)
        tokens = tokens.transpose(1, 2)
        return self.proj(tokens)
180
+
181
+ class SegFaceCeleb(nn.Module):
182
+ def __init__(self, input_resolution, model):
183
+ super(SegFaceCeleb, self).__init__()
184
+ self.input_resolution = input_resolution
185
+ self.model = model
186
+
187
+ if self.model == "swin_base":
188
+ swin_v2 = swin_b(weights='IMAGENET1K_V1')
189
+ self.backbone = torch.nn.Sequential(*(list(swin_v2.children())[:-1]))
190
+ self.target_layer_names = ['0.1', '0.3', '0.5', '0.7']
191
+ self.multi_scale_features = []
192
+
193
+ if self.model == "swinv2_base":
194
+ swin_v2 = swin_v2_b(weights='IMAGENET1K_V1')
195
+ self.backbone = torch.nn.Sequential(*(list(swin_v2.children())[:-1]))
196
+ self.target_layer_names = ['0.1', '0.3', '0.5', '0.7']
197
+ self.multi_scale_features = []
198
+
199
+ if self.model == "swinv2_small":
200
+ swin_v2 = swin_v2_s(weights='IMAGENET1K_V1')
201
+ self.backbone = torch.nn.Sequential(*(list(swin_v2.children())[:-1]))
202
+ self.target_layer_names = ['0.1', '0.3', '0.5', '0.7']
203
+ self.multi_scale_features = []
204
+
205
+ if self.model == "swinv2_tiny":
206
+ swin_v2 = swin_v2_t(weights='IMAGENET1K_V1')
207
+ self.backbone = torch.nn.Sequential(*(list(swin_v2.children())[:-1]))
208
+ self.target_layer_names = ['0.1', '0.3', '0.5', '0.7']
209
+ self.multi_scale_features = []
210
+
211
+ if self.model == "convnext_base":
212
+ convnext = convnext_base(pretrained=True)
213
+ self.backbone = torch.nn.Sequential(*(list(convnext.children())[:-1]))
214
+ self.target_layer_names = ['0.1', '0.3', '0.5', '0.7']
215
+ self.multi_scale_features = []
216
+
217
+ if self.model == "convnext_small":
218
+ convnext = convnext_small(pretrained=True)
219
+ self.backbone = torch.nn.Sequential(*(list(convnext.children())[:-1]))
220
+ self.target_layer_names = ['0.1', '0.3', '0.5', '0.7']
221
+ self.multi_scale_features = []
222
+
223
+ if self.model == "convnext_tiny":
224
+ convnext = convnext_small(pretrained=True)
225
+ self.backbone = torch.nn.Sequential(*(list(convnext.children())[:-1]))
226
+ self.target_layer_names = ['0.1', '0.3', '0.5', '0.7']
227
+ self.multi_scale_features = []
228
+
229
+ if self.model == "resnet":
230
+ resnet101 = models.resnet101(pretrained=True)
231
+ self.backbone = torch.nn.Sequential(*(list(resnet101.children())[:-1]))
232
+ self.target_layer_names = ['4', '5', '6', '7']
233
+ self.multi_scale_features = []
234
+
235
+ if self.model == "mobilenet":
236
+ mobilenet = mobilenet_v3_large(pretrained=True).features
237
+ self.backbone = mobilenet
238
+ self.target_layer_names = ['3', '6', '12', '16']
239
+ self.multi_scale_features = []
240
+
241
+ if self.model == "efficientnet":
242
+ efficientnet = efficientnet_v2_m(pretrained=True).features
243
+ self.backbone = efficientnet
244
+ self.target_layer_names = ['2', '3', '5', '8']
245
+ self.multi_scale_features = []
246
+
247
+ embed_dim = 1024
248
+ out_chans = 256
249
+
250
+ self.pe_layer = PositionEmbeddingRandom(out_chans // 2)
251
+
252
+ for name, module in self.backbone.named_modules():
253
+ if name in self.target_layer_names:
254
+ module.register_forward_hook(self.save_features_hook(name))
255
+
256
+ self.face_decoder = FaceDecoder(
257
+ transformer_dim=256,
258
+ transformer=TwoWayTransformer(
259
+ depth=2,
260
+ embedding_dim=256,
261
+ mlp_dim=2048,
262
+ num_heads=8,
263
+ ))
264
+
265
+ num_encoder_blocks = 4
266
+ if self.model in ["swin_base", "swinv2_base", "convnext_base"]:
267
+ hidden_sizes = [128, 256, 512, 1024] ### Swin Base and ConvNext Base
268
+ if self.model in ["resnet"]:
269
+ hidden_sizes = [256, 512, 1024, 2048] ### ResNet
270
+ if self.model in ["swinv2_small", "swinv2_tiny", "convnext_small", "convnext_tiny"]:
271
+ hidden_sizes = [96, 192, 384, 768] ### Swin Small/Tiny and ConvNext Small/Tiny
272
+ if self.model in ["mobilenet"]:
273
+ hidden_sizes = [24, 40, 112, 960] ### MobileNet
274
+ if self.model in ["efficientnet"]:
275
+ hidden_sizes = [48, 80, 176, 1280] ### EfficientNet
276
+ decoder_hidden_size = 256
277
+
278
+ mlps = []
279
+ for i in range(num_encoder_blocks):
280
+ mlp = SegfaceMLP(input_dim=hidden_sizes[i])
281
+ mlps.append(mlp)
282
+ self.linear_c = nn.ModuleList(mlps)
283
+
284
+ # The following 3 layers implement the ConvModule of the original implementation
285
+ self.linear_fuse = nn.Conv2d(
286
+ in_channels=decoder_hidden_size * num_encoder_blocks,
287
+ out_channels=decoder_hidden_size,
288
+ kernel_size=1,
289
+ bias=False,
290
+ )
291
+
292
+
293
def save_features_hook(self, name):
    """Build a forward hook that records the output of one backbone stage.

    Args:
        name: Name of the hooked backbone sub-module. The hook itself does
            not use it.

    Returns:
        A ``hook(module, input, output)`` callable that appends ``output``
        to ``self.multi_scale_features``.
    """
    # These backbones have their stage output permuted with (0, 3, 1, 2)
    # before it is stored (the original comment labels them Swin / Swinv2).
    permuted_models = ("swin_base", "swinv2_base", "swinv2_small", "swinv2_tiny")
    # "resnet" used to be missing here although ResNet was named in the
    # original comment, so a ResNet backbone never recorded any feature.
    passthrough_models = (
        "convnext_base", "convnext_small", "convnext_tiny",
        "resnet", "mobilenet", "efficientnet",
    )

    def hook(module, input, output):
        if self.model in permuted_models:
            self.multi_scale_features.append(output.permute(0, 3, 1, 2).contiguous())
        elif self.model in passthrough_models:
            self.multi_scale_features.append(output)

    return hook
300
+
301
def forward(self, x, labels, dataset):
    """Run the backbone, fuse the hooked multi-scale features and decode them.

    Args:
        x: Input batch, passed unchanged to ``self.backbone``.
        labels: Not used by this variant.
        dataset: Not used by this variant.

    Returns:
        The value returned by ``self.face_decoder`` for the fused features.

    Raises:
        RuntimeError: If no forward hook recorded a feature map.
    """
    # The hooks append to this list, so every call must start from empty.
    self.multi_scale_features.clear()

    # Only the hook side effects are needed; the backbone output is discarded.
    self.backbone(x)

    if not self.multi_scale_features:
        raise RuntimeError(
            "The backbone forward hooks captured no features; "
            "check the model name and the target layer names."
        )

    # Every scale is resized to the spatial size of the first hooked stage.
    target_size = self.multi_scale_features[0].size()[2:]
    batch_size = self.multi_scale_features[-1].shape[0]

    projected = []
    for feature_map, mlp in zip(self.multi_scale_features, self.linear_c):
        height, width = feature_map.shape[2], feature_map.shape[3]
        tokens = mlp(feature_map)
        # Move the projected channels back in front of the spatial axes.
        restored = tokens.permute(0, 2, 1).reshape(batch_size, -1, height, width)
        projected.append(
            nn.functional.interpolate(
                restored, size=target_size, mode="bilinear", align_corners=False
            )
        )

    # Concatenate from the last hooked stage to the first, as before.
    fused_states = self.linear_fuse(torch.cat(projected[::-1], dim=1))
    image_pe = self.pe_layer((fused_states.shape[2], fused_states.shape[3])).unsqueeze(0)
    return self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)
328
+
329
if __name__ == "__main__":
    # Smoke test: build the model and push one random batch through it.
    input_resolution = 512
    model_name = "swin_base"
    model = SegFaceCeleb(input_resolution, model_name)

    batch_size, num_channels, height, width = 4, 3, 512, 512
    x = torch.randn(batch_size, num_channels, height, width)
    labels = {"lnm_seg": torch.randn(batch_size, 5, 2)}
    dataset = torch.tensor([0, 0, 0, 0])

    seg_output = model(x, labels, dataset)
    print("Segmentation Output Shape:", seg_output.shape)
models/segface/models/segface_helen.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+
6
+ from typing import Any, Optional, Tuple, Type
7
+ from torchvision.models import convnext_large, convnext_base, convnext_small, convnext_tiny, swin_b, swin_v2_b, swin_v2_s, swin_v2_t, mobilenet_v3_large, efficientnet_v2_m
8
+ import pdb
9
+ import numpy as np
10
+ import sys
11
+ import os
12
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
13
+ from models.segface.models.transformer import *
14
+ from models.segface.models.utils_models import *
15
+
16
class MLP(nn.Module):
    """Stack of linear layers with ReLU after every layer except the last.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of every intermediate layer.
        output_dim: Width of the final layer.
        num_layers: Number of linear layers to build.
        sigmoid_output: When true, a sigmoid is applied to the final output.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        out_dims = in_dims[1:] + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(in_dims, out_dims)]
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last_index:
                x = F.relu(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
39
+
40
class FaceDecoder(nn.Module):
    """Turn a fused image embedding into one mask per learned class token.

    Eleven learned tokens (background, face, left/right brow, left/right eye,
    nose, upper lip, inner mouth, lower lip, hair) are passed through
    ``transformer`` together with the image embedding. The image features
    are then upscaled 4x by two stride-2 transposed convolutions, and each
    token, mapped by an MLP to ``transformer_dim // 8`` values, is multiplied
    with the upscaled features to give that class's mask.

    Args:
        transformer_dim: Channel width of the tokens and image embedding.
        transformer: Module called as ``transformer(src, pos_src, tokens)``
            and returning ``(token_outputs, image_outputs)``.
        activation: Activation class used inside the upscaling stack.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        # One embedding per class. Attribute names and creation order are kept
        # so existing checkpoints still load.
        self.background_token = nn.Embedding(1, transformer_dim)
        self.face_token = nn.Embedding(1, transformer_dim)
        self.leftbro_token = nn.Embedding(1, transformer_dim)
        self.rightbro_token = nn.Embedding(1, transformer_dim)
        self.lefteye_token = nn.Embedding(1, transformer_dim)
        self.righteye_token = nn.Embedding(1, transformer_dim)
        self.nose_token = nn.Embedding(1, transformer_dim)
        self.upperlip_token = nn.Embedding(1, transformer_dim)
        self.innermouth_token = nn.Embedding(1, transformer_dim)
        self.lowerlip_token = nn.Embedding(1, transformer_dim)
        self.hair_token = nn.Embedding(1, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )

        self.output_hypernetwork_mlps = MLP(
            transformer_dim, transformer_dim, transformer_dim // 8, 3
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
    ) -> torch.Tensor:
        """Predict the per-class masks.

        Args:
            image_embeddings: Fused features of shape (B, C, H, W).
            image_pe: Positional encoding; it is expanded to the batch size,
                so its leading dimension is expected to be 1.

        Returns:
            A tensor of shape (B, 11, 4 * H, 4 * W): one map per class token.
        """
        output_tokens = torch.cat(
            [
                self.background_token.weight,
                self.face_token.weight,
                self.leftbro_token.weight,
                self.rightbro_token.weight,
                self.lefteye_token.weight,
                self.righteye_token.weight,
                self.nose_token.weight,
                self.upperlip_token.weight,
                self.innermouth_token.weight,
                self.lowerlip_token.weight,
                self.hair_token.weight,
            ],
            dim=0,
        )
        batch = image_embeddings.size(0)
        tokens = output_tokens.unsqueeze(0).expand(batch, -1, -1)  # (B, 11, C)
        pos_src = image_pe.expand(batch, -1, -1, -1)
        b, c, h, w = image_embeddings.shape

        # Per the original notes: hs is (B, 11, C) and src is (B, H*W, C).
        hs, src = self.transformer(image_embeddings, pos_src, tokens)

        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)  # (B, C // 8, 4H, 4W)
        hyper_in = self.output_hypernetwork_mlps(hs)  # (B, 11, C // 8)
        b, c, h, w = upscaled_embedding.shape
        # Each token's vector weights the upscaled channels into one mask.
        return (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
108
+
109
+
110
+
111
class PositionEmbeddingRandom(nn.Module):
    """Positional encoding built from a fixed random Gaussian projection."""

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        # A missing or non-positive scale falls back to 1.0.
        effective_scale = 1.0 if (scale is None or scale <= 0.0) else scale
        gaussian = effective_scale * torch.randn((2, num_pos_feats))
        self.register_buffer("positional_encoding_gaussian_matrix", gaussian)

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Encode coordinates given in [0, 1]; the last axis must have size 2."""
        centered = 2 * coords - 1  # map [0, 1] to [-1, 1]
        projected = centered @ self.positional_encoding_gaussian_matrix
        phase = 2 * np.pi * projected
        # Result keeps the leading axes and ends in 2 * num_pos_feats channels.
        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Return the encoding of an ``(h, w)`` grid as a C x H x W tensor."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        ones = torch.ones((h, w), device=device, dtype=torch.float32)
        # Cell centres, normalised by the grid size.
        y_embed = (ones.cumsum(dim=0) - 0.5) / h
        x_embed = (ones.cumsum(dim=1) - 0.5) / w
        grid_xy = torch.stack([x_embed, y_embed], dim=-1)
        return self._pe_encoding(grid_xy).permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Encode points given in pixel units; returns B x N x C."""
        rows, cols = image_size[0], image_size[1]
        # Work on a copy so the caller's tensor is left untouched.
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / cols
        coords[:, :, 1] = coords[:, :, 1] / rows
        return self._pe_encoding(coords.to(torch.float))
155
+
156
+
157
class SegfaceMLP(nn.Module):
    """Linear embedding of a feature map into 256-channel tokens."""

    def __init__(self, input_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, 256)

    def forward(self, hidden_states: torch.Tensor):
        # Flatten everything after the channel axis, then put channels last.
        tokens = hidden_states.flatten(2).transpose(1, 2)
        return self.proj(tokens)
169
+
170
class SegFaceHelen(nn.Module):
    """SegFace model (Helen variant): hooked backbone, MLP fusion, FaceDecoder.

    Four intermediate backbone stages are captured with forward hooks, each
    is projected to 256 channels, resized to the first stage's spatial size,
    concatenated, fused by a 1x1 convolution and decoded by ``FaceDecoder``.

    Args:
        input_resolution: Side length used for the alignment/warp helpers.
        model: Backbone name; must be a key of ``_STAGES``.

    Raises:
        ValueError: If ``model`` is not a supported backbone name.
    """

    # Backbones whose hooked outputs are permuted with (0, 3, 1, 2) before use.
    _PERMUTED_MODELS = ("swin_base", "swinv2_base", "swinv2_small", "swinv2_tiny")

    # model name -> (names of the hooked backbone sub-modules, their channels)
    _STAGES = {
        "swin_base": (['0.1', '0.3', '0.5', '0.7'], [128, 256, 512, 1024]),
        "swinv2_base": (['0.1', '0.3', '0.5', '0.7'], [128, 256, 512, 1024]),
        "swinv2_small": (['0.1', '0.3', '0.5', '0.7'], [96, 192, 384, 768]),
        "swinv2_tiny": (['0.1', '0.3', '0.5', '0.7'], [96, 192, 384, 768]),
        "convnext_base": (['0.1', '0.3', '0.5', '0.7'], [128, 256, 512, 1024]),
        "convnext_small": (['0.1', '0.3', '0.5', '0.7'], [96, 192, 384, 768]),
        "convnext_tiny": (['0.1', '0.3', '0.5', '0.7'], [96, 192, 384, 768]),
        "resnet": (['4', '5', '6', '7'], [256, 512, 1024, 2048]),
        "mobilenet": (['3', '6', '12', '16'], [24, 40, 112, 960]),
        "efficientnet": (['2', '3', '5', '8'], [48, 80, 176, 1280]),
    }

    def __init__(self, input_resolution, model):
        super().__init__()
        # Imported locally: this file has no explicit ``import functools`` and
        # would otherwise rely on a star import re-exporting it.
        import functools

        if model not in self._STAGES:
            raise ValueError(
                f"Unsupported backbone {model!r}; expected one of {sorted(self._STAGES)}"
            )

        self.input_resolution = input_resolution
        self.model = model

        self.backbone = self._build_backbone(model)
        layer_names, hidden_sizes = self._STAGES[model]
        self.target_layer_names = list(layer_names)
        self.multi_scale_features = []

        out_chans = 256
        self.pe_layer = PositionEmbeddingRandom(out_chans // 2)

        # Alignment/warp helpers. This variant's forward() does not call them;
        # they stay as attributes in case external code reads them.
        resolution = (self.input_resolution, self.input_resolution)
        self.get_matrix_fn = functools.partial(
            get_face_align_matrix, target_shape=resolution, target_face_scale=1.0
        )
        self.warp_fn = functools.partial(
            make_tanh_warp_grid, warp_factor=0.8, warped_shape=resolution
        )
        self.inv_warp_fn = functools.partial(
            make_inverted_tanh_warp_grid, warp_factor=0.8, warped_shape=resolution
        )

        for name, module in self.backbone.named_modules():
            if name in self.target_layer_names:
                module.register_forward_hook(self.save_features_hook(name))

        self.face_decoder = FaceDecoder(
            transformer_dim=256,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=256,
                mlp_dim=2048,
                num_heads=8,
            ))

        decoder_hidden_size = 256
        # One projection per hooked stage, in hook order.
        self.linear_c = nn.ModuleList(
            [SegfaceMLP(input_dim=size) for size in hidden_sizes]
        )

        # 1x1 convolution that fuses the concatenated per-stage projections.
        self.linear_fuse = nn.Conv2d(
            in_channels=decoder_hidden_size * len(hidden_sizes),
            out_channels=decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )

    @staticmethod
    def _build_backbone(model):
        """Create the ImageNet-pretrained feature extractor for ``model``."""
        weights = "IMAGENET1K_V1"
        feature_only = {
            "mobilenet": mobilenet_v3_large,
            "efficientnet": efficientnet_v2_m,
        }
        if model in feature_only:
            return feature_only[model](weights=weights).features

        factories = {
            "swin_base": swin_b,
            "swinv2_base": swin_v2_b,
            "swinv2_small": swin_v2_s,
            "swinv2_tiny": swin_v2_t,
            "convnext_base": convnext_base,
            "convnext_small": convnext_small,
            # NOTE: this name used to build convnext_small by mistake, so
            # checkpoints saved under the old behaviour hold convnext_small
            # backbone weights.
            "convnext_tiny": convnext_tiny,
            "resnet": models.resnet101,
        }
        network = factories[model](weights=weights)
        # Drop the last child module and keep the rest as the backbone.
        return torch.nn.Sequential(*(list(network.children())[:-1]))

    def save_features_hook(self, name):
        """Return a forward hook that appends a stage output to the feature list.

        ``name`` identifies the hooked module and is not used by the hook.
        """
        def hook(module, input, output):
            if self.model in self._PERMUTED_MODELS:
                self.multi_scale_features.append(output.permute(0, 3, 1, 2).contiguous())
            else:
                # The constructor rejects unknown names, so this branch covers
                # ConvNeXt, ResNet, MobileNet and EfficientNet. ResNet was
                # previously skipped here and produced no features at all.
                self.multi_scale_features.append(output)
        return hook

    def forward(self, x, labels, dataset):
        """Segment a batch.

        Args:
            x: Input batch, passed unchanged to the backbone.
            labels: Not used by this variant.
            dataset: Not used by this variant.

        Returns:
            The value returned by ``self.face_decoder``.
        """
        # Hooks append to this list, so start every call from empty.
        self.multi_scale_features.clear()

        # Only the hook side effects are needed; the output is discarded.
        self.backbone(x)

        target_size = self.multi_scale_features[0].size()[2:]
        batch_size = self.multi_scale_features[-1].shape[0]
        all_hidden_states = []
        for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
            # Upsample every stage to the first stage's resolution.
            all_hidden_states.append(
                nn.functional.interpolate(
                    encoder_hidden_state, size=target_size, mode="bilinear", align_corners=False
                )
            )

        # Fuse and decode once; this stage used to be executed twice in a row.
        fused_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
        image_pe = self.pe_layer((fused_states.shape[2], fused_states.shape[3])).unsqueeze(0)
        return self.face_decoder(
            image_embeddings=fused_states,
            image_pe=image_pe
        )
327
+
328
if __name__ == "__main__":
    # Smoke test: build the model and run one random batch through it.
    model_name = "swin_base"
    input_resolution = 512
    model = SegFaceHelen(input_resolution, model_name)

    batch_size, num_channels = 4, 3
    height = width = input_resolution

    x = torch.randn(batch_size, num_channels, height, width)
    labels = {"lnm_seg": torch.randn(batch_size, 5, 2)}
    dataset = torch.tensor([2, 2, 2, 2])

    seg_output = model(x, labels, dataset)
    print("Segmentation Output Shape:", seg_output.shape)
models/segface/models/segface_lapa.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+
6
+ from typing import Any, Optional, Tuple, Type
7
+ from torchvision.models import convnext_large, convnext_base, convnext_small, convnext_tiny, swin_b, swin_v2_b, swin_v2_s, swin_v2_t, mobilenet_v3_large, efficientnet_v2_m
8
+ import pdb
9
+ import numpy as np
10
+ import sys
11
+ import os
12
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
13
+ from models.segface.models.transformer import *
14
+ from models.segface.models.utils_models import *
15
+
16
class MLP(nn.Module):
    """Feed-forward stack: linear layers with ReLU between them.

    ``num_layers`` linear layers are built; the intermediate ones are
    ``hidden_dim`` wide. If ``sigmoid_output`` is set, the final output is
    passed through a sigmoid.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        # Consecutive entries of ``widths`` are the (in, out) sizes of a layer.
        widths = [input_dim, *([hidden_dim] * (num_layers - 1)), output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(widths[i], widths[i + 1]) for i in range(len(widths) - 1)
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for depth, layer in enumerate(self.layers, start=1):
            x = layer(x)
            # No activation after the final layer.
            if depth < self.num_layers:
                x = F.relu(x)
        return F.sigmoid(x) if self.sigmoid_output else x
39
+
40
class FaceDecoder(nn.Module):
    """Decode a fused image embedding into one mask per learned class token.

    The class tokens listed in ``_TOKEN_NAMES`` are run through
    ``transformer`` with the image embedding. The returned image features are
    upscaled 4x by two stride-2 transposed convolutions, and every token,
    mapped by an MLP to ``transformer_dim // 8`` values, is multiplied with
    the upscaled features to produce that class's mask.

    Args:
        transformer_dim: Channel width of the tokens and image embedding.
        transformer: Module called as ``transformer(src, pos_src, tokens)``
            and returning ``(token_outputs, image_outputs)``.
        activation: Activation class used inside the upscaling stack.
    """

    # Attribute names of the class-token embeddings, in output-channel order.
    # The names and their order are kept so existing checkpoints still load.
    _TOKEN_NAMES = (
        "background_token",
        "face_token",
        "leftbro_token",
        "rightbro_token",
        "lefteye_token",
        "righteye_token",
        "nose_token",
        "upperlip_token",
        "innermouth_token",
        "lowerlip_token",
        "hair_token",
    )

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        for token_name in self._TOKEN_NAMES:
            setattr(self, token_name, nn.Embedding(1, transformer_dim))

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )

        self.output_hypernetwork_mlps = MLP(
            transformer_dim, transformer_dim, transformer_dim // 8, 3
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
    ) -> torch.Tensor:
        """Predict the per-class masks.

        Args:
            image_embeddings: Fused features of shape (B, C, H, W).
            image_pe: Positional encoding; it is expanded to the batch size,
                so its leading dimension is expected to be 1.

        Returns:
            A tensor of shape (B, num_tokens, 4 * H, 4 * W).
        """
        batch = image_embeddings.size(0)
        output_tokens = torch.cat(
            [getattr(self, token_name).weight for token_name in self._TOKEN_NAMES],
            dim=0,
        )
        tokens = output_tokens.unsqueeze(0).expand(batch, -1, -1)  # (B, 11, C)
        pos_src = image_pe.expand(batch, -1, -1, -1)
        b, c, h, w = image_embeddings.shape

        # Per the original notes: hs is (B, 11, C) and src is (B, H*W, C).
        hs, src = self.transformer(image_embeddings, pos_src, tokens)

        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)  # (B, C // 8, 4H, 4W)
        hyper_in = self.output_hypernetwork_mlps(hs)  # (B, 11, C // 8)
        b, c, h, w = upscaled_embedding.shape
        # Each token's vector weights the upscaled channels into one mask.
        seg_output = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
        return seg_output
106
+
107
+
108
+
109
class PositionEmbeddingRandom(nn.Module):
    """
    Sine/cosine positional encoding over a random Gaussian projection.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None:
            scale = 1.0
        elif scale <= 0.0:
            # Non-positive scales are replaced by the default.
            scale = 1.0
        projection = torch.randn((2, num_pos_feats))
        self.register_buffer("positional_encoding_gaussian_matrix", scale * projection)

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Encode points in the unit square; the last axis holds the 2 coordinates."""
        matrix = self.positional_encoding_gaussian_matrix
        # Shift [0, 1] to [-1, 1], project, then scale by 2*pi.
        angles = 2 * np.pi * ((2 * coords - 1) @ matrix)
        sines = torch.sin(angles)
        cosines = torch.cos(angles)
        return torch.cat([sines, cosines], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Encode every cell of an ``(h, w)`` grid; output is C x H x W."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # Running sums give 1..h down the rows and 1..w along the columns.
        row_pos = grid.cumsum(dim=0) - 0.5
        col_pos = grid.cumsum(dim=1) - 0.5
        normalised = torch.stack([col_pos / w, row_pos / h], dim=-1)
        encoded = self._pe_encoding(normalised)
        return encoded.permute(2, 0, 1)

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Encode pixel-unit points after dividing by the image size (B x N x C)."""
        coords = coords_input.clone()  # leave the caller's tensor untouched
        # Channel 0 is divided by image_size[1], channel 1 by image_size[0].
        for channel, extent in ((0, image_size[1]), (1, image_size[0])):
            coords[:, :, channel] = coords[:, :, channel] / extent
        return self._pe_encoding(coords.to(torch.float))
153
+
154
+
155
class SegfaceMLP(nn.Module):
    """
    Linear embedding: maps each position's channels to 256 features.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, 256)

    def forward(self, hidden_states: torch.Tensor):
        # Collapse all axes after the channel axis and move channels last.
        flattened = hidden_states.flatten(2)
        return self.proj(flattened.transpose(1, 2))
168
+
169
class SegFaceLapa(nn.Module):
    """SegFace model (LaPa variant) with a warp before and after the network.

    Samples whose ``dataset`` id equals 1 are warped with ``self.warp_fn``
    before the backbone, and their predictions are warped back with
    ``self.inv_warp_fn``. Only the predictions of those samples are returned.

    Args:
        input_resolution: Side length used as the target/warped shape of the
            alignment and warp helpers.
        model: Backbone name; must be a key of ``_TARGET_LAYERS``.

    Raises:
        ValueError: If ``model`` is not a supported backbone name.
    """

    # Backbones whose hooked outputs are permuted with (0, 3, 1, 2) before use.
    _PERMUTED_MODELS = ("swin_base", "swinv2_base", "swinv2_small", "swinv2_tiny")

    # Names, within ``backbone.named_modules()``, of the four hooked stages.
    _TARGET_LAYERS = {
        "swin_base": ['0.1', '0.3', '0.5', '0.7'],
        "swinv2_base": ['0.1', '0.3', '0.5', '0.7'],
        "swinv2_small": ['0.1', '0.3', '0.5', '0.7'],
        "swinv2_tiny": ['0.1', '0.3', '0.5', '0.7'],
        "convnext_base": ['0.1', '0.3', '0.5', '0.7'],
        "convnext_small": ['0.1', '0.3', '0.5', '0.7'],
        "convnext_tiny": ['0.1', '0.3', '0.5', '0.7'],
        "resnet": ['4', '5', '6', '7'],
        "mobilenet": ['3', '6', '12', '16'],
        "efficientnet": ['2', '3', '5', '8'],
    }

    # Channel count of each hooked stage, in the same order.
    _HIDDEN_SIZES = {
        "swin_base": [128, 256, 512, 1024],
        "swinv2_base": [128, 256, 512, 1024],
        "convnext_base": [128, 256, 512, 1024],
        "resnet": [256, 512, 1024, 2048],
        "swinv2_small": [96, 192, 384, 768],
        "swinv2_tiny": [96, 192, 384, 768],
        "convnext_small": [96, 192, 384, 768],
        "convnext_tiny": [96, 192, 384, 768],
        "mobilenet": [24, 40, 112, 960],
        "efficientnet": [48, 80, 176, 1280],
    }

    def __init__(self, input_resolution, model):
        super().__init__()
        # Imported locally: this file has no explicit ``import functools`` and
        # would otherwise rely on a star import re-exporting it.
        import functools

        if model not in self._TARGET_LAYERS:
            raise ValueError(
                f"Unsupported backbone {model!r}; expected one of {sorted(self._TARGET_LAYERS)}"
            )

        self.input_resolution = input_resolution
        self.model = model

        self.backbone = self._make_backbone(model)
        self.target_layer_names = list(self._TARGET_LAYERS[model])
        self.multi_scale_features = []

        out_chans = 256
        self.pe_layer = PositionEmbeddingRandom(out_chans // 2)

        side = self.input_resolution
        self.get_matrix_fn = functools.partial(
            get_face_align_matrix, target_shape=(side, side), target_face_scale=1.0
        )
        self.warp_fn = functools.partial(
            make_tanh_warp_grid, warp_factor=0.8, warped_shape=(side, side)
        )
        self.inv_warp_fn = functools.partial(
            make_inverted_tanh_warp_grid, warp_factor=0.8, warped_shape=(side, side)
        )

        for name, module in self.backbone.named_modules():
            if name in self.target_layer_names:
                module.register_forward_hook(self.save_features_hook(name))

        self.face_decoder = FaceDecoder(
            transformer_dim=256,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=256,
                mlp_dim=2048,
                num_heads=8,
            ))

        hidden_sizes = self._HIDDEN_SIZES[model]
        decoder_hidden_size = 256

        # One projection per hooked stage, in hook order.
        self.linear_c = nn.ModuleList(
            [SegfaceMLP(input_dim=channels) for channels in hidden_sizes]
        )

        # 1x1 convolution that fuses the concatenated per-stage projections.
        self.linear_fuse = nn.Conv2d(
            in_channels=decoder_hidden_size * len(hidden_sizes),
            out_channels=decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )

    @staticmethod
    def _make_backbone(model):
        """Create the ImageNet-pretrained feature extractor for ``model``."""
        weights = "IMAGENET1K_V1"
        if model == "mobilenet":
            return mobilenet_v3_large(weights=weights).features
        if model == "efficientnet":
            return efficientnet_v2_m(weights=weights).features

        if model == "swin_base":
            network = swin_b(weights=weights)
        elif model == "swinv2_base":
            network = swin_v2_b(weights=weights)
        elif model == "swinv2_small":
            network = swin_v2_s(weights=weights)
        elif model == "swinv2_tiny":
            network = swin_v2_t(weights=weights)
        elif model == "convnext_base":
            network = convnext_base(weights=weights)
        elif model == "convnext_small":
            network = convnext_small(weights=weights)
        elif model == "convnext_tiny":
            # NOTE: this name used to build convnext_small by mistake, so
            # checkpoints saved under the old behaviour hold convnext_small
            # backbone weights.
            network = convnext_tiny(weights=weights)
        else:  # "resnet": the constructor has already validated the name
            network = models.resnet101(weights=weights)
        # Drop the last child module and keep the rest as the backbone.
        return torch.nn.Sequential(*(list(network.children())[:-1]))

    def save_features_hook(self, name):
        """Return a forward hook that appends a stage output to the feature list.

        ``name`` identifies the hooked module and is not used by the hook.
        """
        def hook(module, input, output):
            if self.model in self._PERMUTED_MODELS:
                self.multi_scale_features.append(output.permute(0, 3, 1, 2).contiguous())
            else:
                # ConvNeXt, ResNet, MobileNet and EfficientNet outputs are
                # stored as-is. ResNet was previously skipped here.
                self.multi_scale_features.append(output)
        return hook

    def _decode_features(self):
        """Project, resize, fuse and decode the feature maps captured by the hooks."""
        target_size = self.multi_scale_features[0].size()[2:]
        batch_size = self.multi_scale_features[-1].shape[0]
        all_hidden_states = []
        for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
            # Upsample every stage to the first stage's resolution.
            all_hidden_states.append(
                nn.functional.interpolate(
                    encoder_hidden_state, size=target_size, mode="bilinear", align_corners=False
                )
            )

        fused_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
        image_pe = self.pe_layer((fused_states.shape[2], fused_states.shape[3])).unsqueeze(0)
        return self.face_decoder(
            image_embeddings=fused_states,
            image_pe=image_pe
        )

    def forward(self, x, labels, dataset):
        """Segment the samples of the batch whose dataset id is 1.

        Args:
            x: Input batch of shape (B, C, H, W). It is no longer modified.
            labels: Mapping whose ``"lnm_seg"`` entry is indexed per sample
                and passed to ``self.get_matrix_fn``.
            dataset: Per-sample dataset ids; entries equal to 1 are selected.

        Returns:
            Decoder output for the selected samples only, resampled with the
            inverse warp grid.
        """
        # Hooks append to this list, so start every call from empty.
        self.multi_scale_features.clear()

        _, _, h, w = x.shape
        mask = dataset == 1
        lnd = labels["lnm_seg"][mask]
        matrix = self.get_matrix_fn(lnd)
        grid = self.warp_fn(matrix=matrix, orig_shape=(h, w))
        inv_grid = self.inv_warp_fn(matrix=matrix, orig_shape=(h, w))

        # Warp on a copy: writing into ``x`` itself silently altered the
        # caller's batch.
        x = x.clone()
        x[mask] = F.grid_sample(x[mask], grid, mode='bilinear', align_corners=False)

        # Only the hook side effects are needed; the output is discarded.
        self.backbone(x)
        seg_output = self._decode_features()

        # Un-warp and return only the selected samples. This equals the old
        # write-back into ``seg_output`` followed by selecting the same rows.
        return F.grid_sample(seg_output[mask], inv_grid, mode='bilinear', align_corners=False)
333
+
334
if __name__ == "__main__":
    # Smoke test: half of the random batch carries dataset id 1, so only
    # those two samples appear in the printed output shape.
    input_resolution = 512
    model_name = "swin_base"
    model = SegFaceLapa(input_resolution, model_name)

    batch_size = 4
    num_channels = 3
    height = width = input_resolution

    x = torch.randn(batch_size, num_channels, height, width)
    labels = {"lnm_seg": torch.randn(batch_size, 5, 2)}
    dataset = torch.tensor([0, 0, 1, 1])

    seg_output = model(x, labels, dataset)
    print("Segmentation Output Shape:", seg_output.shape)
models/segface/models/transformer.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+
10
+ import math
11
+ from typing import Tuple, Type
12
+
13
+
14
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    Args:
        embedding_dim: size of the input and output feature dimension.
        mlp_dim: size of the hidden feature dimension.
        act: activation module class, instantiated with no arguments.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # Expand to the hidden width, then project back to the input width.
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to the last dimension of ``x``."""
        hidden = self.lin1(x)
        activated = self.act(hidden)
        return self.lin2(activated)
28
+
29
+
30
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
31
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
32
class LayerNorm2d(nn.Module):
    """Normalize a tensor over dimension 1 and apply a per-channel affine.

    The learnable ``weight`` and ``bias`` have ``num_channels`` entries and are
    broadcast over two trailing dimensions, so inputs are expected to be laid
    out as (batch, channels, height, width).

    Args:
        num_channels: number of entries in ``weight`` and ``bias``.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalized along dimension 1, then scaled and shifted."""
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        # Biased variance (plain mean of squared deviations) along dimension 1.
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
45
+
46
+
47
class TwoWayTransformer(nn.Module):
    """Stack of two-way attention blocks followed by a final token-to-image attention."""

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        Build a transformer decoder whose queries attend to an image embedding
        and whose positional embedding is supplied by the caller.

        Args:
          depth (int): number of two-way attention blocks
          embedding_dim (int): channel dimension of the input embeddings
          num_heads (int): number of attention heads
          mlp_dim (int): hidden channel dimension of each block's MLP
          activation (nn.Module): activation class used in each block's MLP
          attention_downsample_rate (int): downsample rate passed to the
            cross-attention layers and to the final attention layer
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        # Only the first block is built with skip_first_layer_pe=True.
        blocks = [
            TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                activation=activation,
                attention_downsample_rate=attention_downsample_rate,
                skip_first_layer_pe=(index == 0),
            )
            for index in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Run the query tokens and the image tokens through every block.

        Args:
          image_embedding (torch.Tensor): image to attend to, shaped
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): positional encoding that is added to the
            flattened image tokens; it is flattened the same way as
            image_embedding, so its shape must be addable to it.
          point_embedding (torch.Tensor): query tokens, also reused as their
            positional embedding. Shaped B x N_points x embedding_dim.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding (as B x h*w x C tokens)
        """
        # Unpacking requires a 4-D image embedding (B x C x H x W).
        _batch, _channels, _height, _width = image_embedding.shape

        # B x C x H x W -> B x HW x C, i.e. one token per spatial location.
        image_tokens = image_embedding.flatten(2).permute(0, 2, 1)
        pe_tokens = image_pe.flatten(2).permute(0, 2, 1)

        queries, keys = point_embedding, image_tokens
        for block in self.layers:
            queries, keys = block(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=pe_tokens,
            )

        # Last attention pass from the query tokens to the image tokens.
        attn_out = self.final_attn_token_to_image(
            q=queries + point_embedding,
            k=keys + pe_tokens,
            v=keys,
        )
        queries = self.norm_final_attn(queries + attn_out)

        return queries, keys
138
+
139
+
140
class TwoWayAttentionBlock(nn.Module):
    """One decoder block that updates query tokens and image tokens in turn."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        The block has four stages: (1) self-attention over the query tokens,
        (2) cross-attention from the query tokens to the image tokens, (3) an
        MLP on the query tokens, and (4) cross-attention from the image tokens
        to the query tokens. Each stage is followed by a LayerNorm.

        Arguments:
          embedding_dim (int): channel dimension of the embeddings
          num_heads (int): number of heads in the attention layers
          mlp_dim (int): hidden dimension of the MLP
          activation (nn.Module): activation class of the MLP
          attention_downsample_rate (int): downsample rate of both
            cross-attention layers
          skip_first_layer_pe (bool): when True, the self-attention stage uses
            the queries without positional embedding and without a residual
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Return the updated ``(queries, keys)`` pair after the four stages."""
        # Stage 1: self-attention over the query tokens.
        if not self.skip_first_layer_pe:
            pos_queries = queries + query_pe
            queries = queries + self.self_attn(q=pos_queries, k=pos_queries, v=queries)
        else:
            # No positional embedding and no residual connection in this mode.
            queries = self.self_attn(q=queries, k=queries, v=queries)
        queries = self.norm1(queries)

        # Stage 2: query tokens attend to the image tokens.
        token_update = self.cross_attn_token_to_image(
            q=queries + query_pe,
            k=keys + key_pe,
            v=keys,
        )
        queries = self.norm2(queries + token_update)

        # Stage 3: MLP on the query tokens.
        queries = self.norm3(queries + self.mlp(queries))

        # Stage 4: image tokens attend to the query tokens (roles swapped).
        token_q = queries + query_pe
        image_k = keys + key_pe
        image_update = self.cross_attn_image_to_token(q=image_k, k=token_q, v=queries)
        keys = self.norm4(keys + image_update)

        return queries, keys
214
+
215
+
216
class Attention(nn.Module):
    """
    Multi-head attention whose internal width can be smaller than the
    embedding width: queries, keys and values are projected to
    ``embedding_dim // downsample_rate`` channels and projected back afterwards.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Split B x N_tokens x C into B x N_heads x N_tokens x C_per_head."""
        batch, tokens, channels = x.shape
        per_head = channels // num_heads
        split = x.reshape(batch, tokens, num_heads, per_head)
        return split.transpose(1, 2)

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Merge B x N_heads x N_tokens x C_per_head back into B x N_tokens x C."""
        batch, heads, tokens, per_head = x.shape
        token_major = x.transpose(1, 2)
        return token_major.reshape(batch, tokens, heads * per_head)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Return the attention output for queries ``q`` over ``k`` / ``v``."""
        # Project the inputs and split each projection into heads.
        q_heads = self._separate_heads(self.q_proj(q), self.num_heads)
        k_heads = self._separate_heads(self.k_proj(k), self.num_heads)
        v_heads = self._separate_heads(self.v_proj(v), self.num_heads)

        # Scaled dot-product scores: B x N_heads x N_query_tokens x N_key_tokens.
        per_head = q_heads.shape[-1]
        scores = q_heads @ k_heads.transpose(-2, -1)
        scores = scores / math.sqrt(per_head)
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of the values, merged across heads and projected back.
        merged = self._recombine_heads(weights @ v_heads)
        return self.out_proj(merged)
models/segface/models/utils_models.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Callable, Tuple, Optional
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import functools
5
+ import numpy as np
6
+
7
@functools.lru_cache(maxsize=128)
def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return float row-index and column-index maps for an ``h`` x ``w`` grid.

    Results are memoized per ``(h, w)``, so the returned tensors are shared
    between calls and must not be modified in place.

    Returns:
        Tuple ``(yy, xx)`` of ``h x w`` float32 tensors where ``yy[i, j] == i``
        and ``xx[i, j] == j``.
    """
    row_ids = torch.arange(h).float()
    col_ids = torch.arange(w).float()
    grid_y, grid_x = torch.meshgrid(row_ids, col_ids, indexing='ij')
    return grid_y, grid_x
13
+
14
+
15
def _forge_grid(batch_size: int, device: torch.device,
                output_shape: Tuple[int, int],
                fn: Callable[[torch.Tensor], torch.Tensor]
                ) -> torch.Tensor:
    """ Forge a transform map with a given function `fn`.

    Every pixel of an `h x w` grid is expressed as an (x, y) coordinate, the
    coordinates of the whole batch are passed through `fn` in one call, and
    the result is reshaped back onto the grid.

    Args:
        batch_size (int): number of grids to build.
        device (torch.device): device the input coordinates are moved to.
        output_shape (tuple): (h, w, ...). Only the first two entries are used.
        fn (Callable[[torch.Tensor], torch.Tensor]): The function that accepts
            a b x (h*w) x 2 tensor and outputs the transformed b x (h*w) x 2
            tensor. Both input and output store (x, y) coordinates.

    Returns:
        torch.Tensor: b x h x w x 2 map `G` where, for each pixel (y, x),
            `G[b, y, x] = fn([x, y])`.
    """
    h, w, *_ = output_shape
    yy, xx = _meshgrid(h, w)  # h x w each
    # The cached maps are shared; broadcast_to/to never modify them in place.
    yy = yy.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)
    xx = xx.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)

    # Last axis is ordered (x, y): column index first, row index second.
    in_xxyy = torch.stack(
        [xx, yy], dim=-1).reshape([batch_size, h * w, 2])  # b x (h*w) x 2
    out_xxyy: torch.Tensor = fn(in_xxyy)  # b x (h*w) x 2
    return out_xxyy.reshape(batch_size, h, w, 2)
44
+
45
def inverted_tanh_warp_transform(coords: torch.Tensor, matrix: torch.Tensor,
                                 warp_factor: float, warped_shape: Tuple[int, int]):
    """ Inverted tanh-warp function.

    Undoes the tanh warp (when `warp_factor > 0`) and then applies the inverse
    of `matrix`, mapping warped coordinates back to original coordinates.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The original coordinates.
    """
    height, width, *_ = warped_shape
    extent = torch.tensor([[width, height]]).to(coords)

    if warp_factor > 0:
        # Map pixel coordinates into the normalized [-1, +1] range.
        unit = coords / extent * 2 - 1

        # Values beyond +/-(1 - warp_factor) lie in the non-linear tails.
        in_upper_tail = unit > 1.0 - warp_factor
        in_lower_tail = unit < -1.0 + warp_factor

        # Both tails are evaluated everywhere; the masks select where they apply.
        upper_unwarped = _safe_arctanh(
            (unit - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        lower_unwarped = _safe_arctanh(
            (unit + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        lower_or_linear = torch.where(in_lower_tail, lower_unwarped, unit)
        unit = torch.where(in_upper_tail, upper_unwarped, lower_or_linear)

        # Back to pixel coordinates.
        coords = (unit + 1) / 2 * extent

    # Homogeneous coordinates: append a column of ones -> b x n x 3.
    ones = torch.ones_like(coords[:, :, [0]])
    homogeneous = torch.cat([coords, ones], dim=-1)

    inverse = torch.linalg.inv(matrix)  # b x 3 x 3
    projected = torch.bmm(homogeneous, inverse.permute(0, 2, 1))  # b x n x 3

    # Divide x and y by the homogeneous component.
    return projected[:, :, :2] / projected[:, :, [2, 2]]
96
+
97
+
98
def tanh_warp_transform(
        coords: torch.Tensor, matrix: torch.Tensor,
        warp_factor: float, warped_shape: Tuple[int, int]):
    """ Tanh-warp function.

    Applies `matrix` to the coordinates and then, when `warp_factor > 0`,
    compresses the parts beyond +/-(1 - warp_factor) in normalized space with
    a tanh curve.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The original coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The transformed coordinates.
    """
    height, width, *_ = warped_shape
    extent = torch.tensor([[width, height]]).to(coords)

    # Homogeneous coordinates: append a column of ones -> b x n x 3.
    ones = torch.ones_like(coords[:, :, [0]])
    homogeneous = torch.cat([coords, ones], dim=-1)

    projected = torch.bmm(homogeneous, matrix.transpose(2, 1))  # b x n x 3
    coords = (projected[:, :, :2] / projected[:, :, [2, 2]])  # b x n x 2

    if warp_factor > 0:
        # Map pixel coordinates into the normalized [-1, +1] range.
        unit = coords / extent * 2 - 1

        in_upper_tail = unit > 1.0 - warp_factor
        in_lower_tail = unit < -1.0 + warp_factor

        # Both tails are evaluated everywhere; the masks select where they apply.
        upper_warped = torch.tanh(
            (unit - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        lower_warped = torch.tanh(
            (unit + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        lower_or_linear = torch.where(in_lower_tail, lower_warped, unit)
        unit = torch.where(in_upper_tail, upper_warped, lower_or_linear)

        # Back to pixel coordinates.
        coords = (unit + 1) / 2 * extent

    return coords
148
+
149
def make_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                        warped_shape: Tuple[int, int],
                        orig_shape: Tuple[int, int]):
    """
    Build a grid over the warped image whose entries are the corresponding
    original-image coordinates, rescaled by the original size as
    `coord / size * 2 - 1`.

    Args:
        matrix: bx3x3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla Tanh-warping,
           `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape; its height and width are used
           to normalize the returned coordinates.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y).
    """
    orig_h, orig_w, *_ = orig_shape
    extent = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)

    # Maps warped-image pixel coordinates back to original-image coordinates.
    to_original = functools.partial(inverted_tanh_warp_transform,
                                    matrix=matrix,
                                    warp_factor=warp_factor,
                                    warped_shape=warped_shape)
    pixel_grid = _forge_grid(
        matrix.size(0), matrix.device, warped_shape, to_original)

    return pixel_grid / extent * 2 - 1
171
+
172
+
173
def make_inverted_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                                 warped_shape: Tuple[int, int],
                                 orig_shape: Tuple[int, int]):
    """
    Build a grid over the original image whose entries are the corresponding
    warped-image coordinates, rescaled by the warped size as
    `coord / size * 2 - 1`.

    Args:
        matrix: bx3x3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla Tanh-warping,
           `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y).
    """
    warped_h, warped_w, *_ = warped_shape
    extent = torch.tensor([warped_w, warped_h]).to(matrix).reshape(1, 1, 1, 2)

    # Maps original-image pixel coordinates to warped-image coordinates.
    to_warped = functools.partial(tanh_warp_transform,
                                  matrix=matrix,
                                  warp_factor=warp_factor,
                                  warped_shape=warped_shape)
    pixel_grid = _forge_grid(
        matrix.size(0), matrix.device, orig_shape, to_warped)

    return pixel_grid / extent * 2 - 1
196
+
197
def _safe_arctanh(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    """Return ``arctanh`` of ``x`` after clamping it to ``[-1 + eps, 1 - eps]``.

    The clamp keeps the argument strictly inside (-1, 1), where arctanh is
    finite.
    """
    lower, upper = -1+eps, 1-eps
    bounded = torch.clamp(x, lower, upper)
    return bounded.arctanh()
199
+
200
def get_similarity_transform_matrix(
        from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
    """
    Fit, per batch item, a matrix of the form `[[a, b, dx], [-b, a, dy], [0, 0, 1]]`
    from the centered point sets and their centroids.

    Args:
        from_pts, to_pts: b x n x 2

    Returns:
        torch.Tensor: b x 3 x 3
    """
    from_center = from_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    to_center = to_pts.mean(dim=1, keepdim=True)  # b x 1 x 2

    from_delta = from_pts - from_center
    to_delta = to_pts - to_center

    # Sum of squared deviations of the source points (shared denominator).
    denom = from_delta.square().sum([1, 2], keepdim=False)  # b
    # Sum of per-point dot products between the centered sets.
    dot = (to_delta * from_delta).sum([1, 2], keepdim=False)  # b
    # Sum of per-point 2-D cross terms between the centered sets.
    cross = (to_delta[:, :, 0] * from_delta[:, :, 1]
             - to_delta[:, :, 1] * from_delta[:, :, 0]).sum([1], keepdim=False)  # b

    a = dot / denom
    b = cross / denom

    from_x, from_y = from_center[:, 0, 0], from_center[:, 0, 1]
    dx = to_center[:, 0, 0] - a * from_x - b * from_y  # b
    dy = to_center[:, 0, 1] + b * from_x - a * from_y  # b

    one = torch.ones_like(denom)
    zero = torch.zeros_like(denom)

    # Nine entries in row-major order, reshaped to one 3 x 3 matrix per item.
    entries = [
        a, b, dx,
        -b, a, dy,
        zero, zero, one,
    ]
    flat = torch.stack(entries, dim=-1)
    return flat.reshape(-1, 3, 3)
233
+
234
+
235
@functools.lru_cache()
def _standard_face_pts():
    """Return the 5 x 2 template of face points, rescaled as ``v / 256 - 1``.

    The result is memoized, so every call returns the same tensor object;
    callers must not modify it in place.
    """
    # One (x, y) pair per row, given on a 512-pixel scale before rescaling.
    template = torch.tensor([
        [196.0, 226.0],
        [316.0, 226.0],
        [256.0, 286.0],
        [220.0, 360.4],
        [292.0, 360.4],
    ], dtype=torch.float32)
    return template / 256.0 - 1.0
244
+
245
+
246
def get_face_align_matrix(
        face_pts: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, offset_xy: Optional[Tuple[float, float]] = None,
        target_pts: Optional[torch.Tensor] = None):
    """
    Compute the matrices that align `face_pts` to a set of target points.

    Args:
        face_pts: b x n x 2 face points to align from.
        target_shape: (height, width, ...) used to scale the standard face
            template; only read when `target_pts` is None.
        target_face_scale: factor applied to the standard template before it
            is scaled to pixels; only read when `target_pts` is None.
        offset_xy: optional (x, y) shift added to the template points; only
            read when `target_pts` is None.
        target_pts: optional explicit target points, either n x 2 or
            b x n x 2. When None, the standard face template is used.

    Returns:
        torch.Tensor: b x 3 x 3, the result of
            `get_similarity_transform_matrix(face_pts, target_pts)`.
    """
    if target_pts is not None:
        # Match dtype and device of the face points.
        target_pts = target_pts.to(face_pts)
    else:
        with torch.no_grad():
            template = _standard_face_pts().to(face_pts)  # values stored on a [-1, 1] scale
            height, width, *_ = target_shape
            # Rescale the template to pixel coordinates of the target image.
            target_pts = (template * target_face_scale + 1) * \
                torch.tensor([width-1, height-1]).to(face_pts) / 2.0
            if offset_xy is not None:
                target_pts[:, 0] += offset_xy[0]
                target_pts[:, 1] += offset_xy[1]

    # Promote a single n x 2 point set to a batch and share it across the batch.
    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.4
2
+ torchvision>=0.19
3
+ numpy>=1.26
4
+ Pillow>=10.0
5
+ huggingface_hub>=0.30
training_run_summary.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "run_dir": "hair_mask_dataset/runs/segface_hair_budget_4090",
3
+ "model_name": "swin_base",
4
+ "prepared_root": "/workspace/runpod_upload_ready/data/aihub_hairmask_hq_budget_50k",
5
+ "raw_root": "/workspace/runpod_upload_ready/data/aihub_korean_hairstyle_hq_raw",
6
+ "epochs_completed": 10,
7
+ "best_epoch": 7,
8
+ "best_val_iou": 0.9486894006725745,
9
+ "best_val_dice": 0.9735556454363521,
10
+ "best_val_precision": 0.9723250788834037,
11
+ "best_val_recall": 0.9751487422222148,
12
+ "last_epoch": 10,
13
+ "last_train_loss": 0.028126267597526313,
14
+ "last_val_loss": 0.028694584750384094,
15
+ "last_val_iou": 0.9486362328742441,
16
+ "last_val_dice": 0.9735264129781702,
17
+ "last_val_precision": 0.9721686440444964,
18
+ "last_val_recall": 0.9752568952148782,
19
+ "avg_epoch_sec": 3546.4520416259766,
20
+ "train_count": 50000,
21
+ "val_count": 5000,
22
+ "test_count": 0,
23
+ "checkpoint_files": [
24
+ "best.pt",
25
+ "epoch_001.pt",
26
+ "epoch_002.pt",
27
+ "epoch_003.pt",
28
+ "epoch_004.pt",
29
+ "epoch_005.pt",
30
+ "epoch_006.pt",
31
+ "epoch_007.pt",
32
+ "epoch_008.pt",
33
+ "epoch_009.pt",
34
+ "epoch_010.pt",
35
+ "last.pt"
36
+ ],
37
+ "plot_path": "hair_mask_dataset/runs/segface_hair_budget_4090/plots/training_curves.png",
38
+ "latest_preview_path": "hair_mask_dataset/runs/segface_hair_budget_4090/previews/epoch_010.png",
39
+ "submit_date": "2026-03-17",
40
+ "github_url": "https://github.com/skn-ai22-251029/SKN22-Final-1Team-AI",
41
+ "team_members": [
42
+ "이병재",
43
+ "장완식",
44
+ "최정환",
45
+ "문승준"
46
+ ]
47
+ }