# app.py
"""Gradio demo: run SegFormer semantic segmentation on an uploaded image.

At import time the module loads the image processor and the model named by
``MODEL_ID``, builds a fixed colour palette with one colour per class, and
defines the Gradio ``demo`` UI. Running the file as a script launches the UI.
"""
import json
import logging
import os
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoConfig,
    AutoModelForSemanticSegmentation,
    SegformerImageProcessor,
)

logger = logging.getLogger(__name__)

# ===== Config =====
MODEL_ID = "Itbanque/fashion_segformer"
PROCESSOR_ID = MODEL_ID

# ===== Load processor =====
try:
    processor = SegformerImageProcessor.from_pretrained(PROCESSOR_ID)
except Exception:
    # Fallback: build the processor by hand when the repo has no
    # preprocessor_config.json. Log the cause instead of hiding it.
    logger.warning(
        "Could not load image processor from %r; using a manual 512x512 "
        "SegformerImageProcessor instead.",
        PROCESSOR_ID,
        exc_info=True,
    )
    processor = SegformerImageProcessor(
        size={"height": 512, "width": 512}, do_resize=True, do_normalize=True
    )

# ===== Load model =====
try:
    cfg = AutoConfig.from_pretrained(MODEL_ID)
    model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID, config=cfg)
except Exception:
    # Compatibility path for older directories that only stored weights.
    logger.warning(
        "Loading %r with an explicit config failed; retrying without it.",
        MODEL_ID,
        exc_info=True,
    )
    model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU inference stays in float32.
dtype = torch.float16 if (device == "cuda") else torch.float32
model.to(device=device, dtype=dtype)
model.eval()

# ===== id2label / palette =====
id2label = getattr(model.config, "id2label", None)
if isinstance(id2label, dict):
    # Normalise keys to int (they may be stored as strings in the config).
    id2label = {int(k): v for k, v in id2label.items()}
else:
    id2label = {i: str(i) for i in range(model.config.num_labels)}
NUM_CLASSES = len(id2label)


def make_palette(n: int) -> np.ndarray:
    """Return an ``(n, 3)`` uint8 colour table with a fixed random seed.

    The seed is fixed so colours are stable across runs. Row 0 is forced to
    black. For ``n == 0`` an empty ``(0, 3)`` array is returned.
    """
    rng = np.random.default_rng(0)
    # NOTE: the upper bound of ``integers`` is exclusive, so channel values
    # fall in [0, 254]. Kept as-is so existing colours do not change.
    colors = rng.integers(0, 255, size=(n, 3), dtype=np.uint8)
    if n > 0:
        # Class 0 is drawn in black.
        colors[0] = np.array([0, 0, 0], dtype=np.uint8)
    return colors


PALETTE = make_palette(NUM_CLASSES)


# ===== Inference =====
@torch.no_grad()
def predict(
    pil_img: Image.Image, alpha: float = 0.5, show_overlay: bool = True
) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
    """Segment ``pil_img`` and return ``(overlay, colored_mask)``.

    Args:
        pil_img: Input image; ``None`` yields ``(None, None)``.
        alpha: Blend factor passed to ``Image.blend`` (0 keeps the image,
            1 shows only the mask).
        show_overlay: When false, the overlay slot of the result is ``None``.

    Returns:
        A tuple ``(overlay, mask)``. ``mask`` is an RGB image with the
        original image's size, coloured via ``PALETTE``; ``overlay`` is the
        input blended with that mask, or ``None`` if ``show_overlay`` is false.
    """
    if pil_img is None:
        return None, None

    img = pil_img.convert("RGB")
    width, height = img.size

    inputs = processor(images=img, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device, dtype=dtype)
    logits = model(pixel_values=pixel_values).logits  # (1, C, h, w)

    # Upsample to the original image size first, then take the argmax.
    upsampled = torch.nn.functional.interpolate(
        logits, size=(height, width), mode="bilinear", align_corners=False
    )
    # uint8 can only hold class ids 0..255; widen the index type when the
    # model emits more channels so ids are not silently wrapped.
    index_dtype = torch.uint8 if logits.shape[1] <= 256 else torch.int32
    pred = upsampled.argmax(dim=1)[0].to(index_dtype).cpu().numpy()  # (H, W)

    # Build the coloured mask and, optionally, the overlay.
    color_mask = PALETTE[pred]  # (H, W, 3)
    # An (H, W, 3) uint8 array is inferred as RGB; the explicit ``mode``
    # argument is deprecated in recent Pillow releases.
    mask_pil = Image.fromarray(color_mask)

    if not show_overlay:
        return None, mask_pil
    overlay = Image.blend(img, mask_pil, float(alpha))
    return overlay, mask_pil


# ===== Gradio UI =====
with gr.Blocks(title="Fashion Segmentation") as demo:
    gr.Markdown("## Fashion Segmentation\nUpload an image and run SegFormer inference.")
    with gr.Row():
        inp = gr.Image(type="pil", label="Upload image")
        with gr.Column():
            alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay alpha")
            show_overlay = gr.Checkbox(value=True, label="Show overlay")
            btn = gr.Button("Predict", variant="primary")
    with gr.Row():
        out_overlay = gr.Image(label="Overlay", interactive=False)
        out_mask = gr.Image(label="Colored mask", interactive=False)

    def _run(image, a, show):
        """Adapter from the UI inputs to ``predict``'s keyword arguments."""
        return predict(image, alpha=a, show_overlay=show)

    btn.click(_run, inputs=[inp, alpha, show_overlay], outputs=[out_overlay, out_mask])


if __name__ == "__main__":
    # Local run: python app.py
    # HF Spaces: just name this file app.py
    demo.launch()