Spaces:
Sleeping
Sleeping
| # app.py | |
| import os, json | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModelForSemanticSegmentation, | |
| SegformerImageProcessor, | |
| ) | |
| import gradio as gr | |
# ===== Config =====
# Hugging Face Hub repo holding the fine-tuned SegFormer checkpoint.
MODEL_ID = "Itbanque/fashion_segformer"
# Repo to read preprocessor_config.json from; same repo by default.
PROCESSOR_ID = MODEL_ID

# ===== Load processor =====
try:
    processor = SegformerImageProcessor.from_pretrained(PROCESSOR_ID)
except OSError as err:
    # Fallback when preprocessor_config.json is missing/unreadable: build the
    # processor by hand (explicit 512x512 resize with normalization; all other
    # settings are SegformerImageProcessor defaults). These may not match the
    # preprocessing the checkpoint was trained with, so make the fallback visible
    # instead of swallowing the error silently.
    import warnings

    warnings.warn(
        f"Could not load image processor from {PROCESSOR_ID!r} ({err}); "
        "falling back to a default 512x512 SegformerImageProcessor.",
        RuntimeWarning,
    )
    processor = SegformerImageProcessor(
        size={"height": 512, "width": 512},
        do_resize=True,
        do_normalize=True,
    )
# ===== Load model =====
# Single load, no fallback: AutoModelForSemanticSegmentation.from_pretrained
# also needs the repo's config, so retrying it without `config=` cannot rescue
# a failed AutoConfig load (e.g. a weights-only directory); errors propagate.
cfg = AutoConfig.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID, config=cfg)

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 weights on GPU to reduce memory/latency; float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
model.to(device=device, dtype=dtype)
model.eval()  # switch off training-only behaviour (e.g. dropout)
# ===== id2label / palette =====
# Class-id -> label-name mapping taken from the model config. JSON-serialised
# configs may carry string keys, so keys are normalised to int; when no dict
# is present, synthesize numeric names for each of the config's num_labels.
id2label = getattr(model.config, "id2label", None)
id2label = (
    {int(key): name for key, name in id2label.items()}
    if isinstance(id2label, dict)
    else {idx: str(idx) for idx in range(model.config.num_labels)}
)
NUM_CLASSES = len(id2label)
def make_palette(n: int, seed: int = 0) -> np.ndarray:
    """Return a deterministic RGB palette with one color per class.

    Colors come from a seeded generator so they are stable across runs, and
    class 0 is forced to black. Channel values lie in [0, 254] because the
    upper bound of ``Generator.integers`` is exclusive; this is kept so the
    existing colors do not change.

    Args:
        n: Number of classes (rows). ``0`` yields an empty (0, 3) palette.
        seed: Seed for ``np.random.default_rng``; the default 0 reproduces
            the original palette.

    Returns:
        Array of shape (n, 3), dtype uint8.
    """
    rng = np.random.default_rng(seed)
    colors = rng.integers(0, 255, size=(n, 3), dtype=np.uint8)
    if n > 0:  # guard: an empty palette has no row 0 to blacken
        colors[0] = 0
    return colors
PALETTE = make_palette(NUM_CLASSES)


# ===== Inference =====
def predict(pil_img: Image.Image, alpha: float = 0.5, show_overlay: bool = True):
    """Segment one image and render the predicted classes as colors.

    Args:
        pil_img: Input image (converted to RGB). ``None`` is accepted, e.g.
            when the button is pressed before anything was uploaded.
        alpha: Blend weight passed to ``Image.blend``: 0.0 shows only the
            original image, 1.0 only the colored mask.
        show_overlay: When False, the overlay is skipped and ``None`` is
            returned in its place.

    Returns:
        ``(overlay, mask)``. ``mask`` is an RGB image the size of the input
        in which each pixel is colored via ``PALETTE`` by its predicted class;
        ``overlay`` is the input blended with ``mask``, or ``None`` when
        ``show_overlay`` is False. ``(None, None)`` if ``pil_img`` is None.
    """
    if pil_img is None:
        return None, None

    img = pil_img.convert("RGB")
    width, height = img.size

    inputs = processor(images=img, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device, dtype=dtype)

    # Pure inference: model.eval() alone does not disable gradient tracking,
    # so without this autograd would record a graph for every request.
    with torch.inference_mode():
        logits = model(pixel_values=pixel_values).logits  # (1, C, h, w)
        # Upsample to the original image size before argmax so the mask
        # aligns pixel-for-pixel with the input.
        upsampled = torch.nn.functional.interpolate(
            logits, size=(height, width), mode="bilinear", align_corners=False
        )
        # Keep argmax's int64 class ids: casting to uint8 would wrap around
        # for models with more than 256 classes.
        pred = upsampled.argmax(dim=1)[0].cpu().numpy()  # (H, W)

    color_mask = PALETTE[pred]  # (H, W, 3) uint8
    # Mode is inferred as RGB from the (H, W, 3) uint8 array; the explicit
    # `mode=` argument is deprecated in recent Pillow releases.
    mask_pil = Image.fromarray(color_mask)

    if not show_overlay:
        return None, mask_pil
    overlay = Image.blend(img, mask_pil, float(alpha))
    return overlay, mask_pil
# ===== Gradio UI =====
def _run(image, a, show):
    # Adapter for the button: Gradio passes component values positionally in
    # the order of `inputs=` below (image, slider value, checkbox value).
    return predict(image, alpha=a, show_overlay=show)


with gr.Blocks(title="Fashion Segmentation") as demo:
    gr.Markdown(
        "## Fashion Segmentation\n"
        "Upload an image and run SegFormer inference."
    )

    # Top row: input image, with the controls stacked in a column beside it.
    with gr.Row():
        inp = gr.Image(type="pil", label="Upload image")
        with gr.Column():
            alpha = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                label="Overlay alpha",
            )
            show_overlay = gr.Checkbox(value=True, label="Show overlay")
            btn = gr.Button("Predict", variant="primary")

    # Bottom row: read-only result images.
    with gr.Row():
        out_overlay = gr.Image(label="Overlay", interactive=False)
        out_mask = gr.Image(label="Colored mask", interactive=False)

    btn.click(
        fn=_run,
        inputs=[inp, alpha, show_overlay],
        outputs=[out_overlay, out_mask],
    )

if __name__ == "__main__":
    # Local run: `python app.py`.
    # Hugging Face Spaces: the file only needs to be named app.py.
    demo.launch()