#!/usr/bin/env python3 from __future__ import annotations import inspect import os import re import sys import threading from pathlib import Path from typing import Any, Optional import gradio as gr from PIL import Image, ImageDraw, UnidentifiedImageError # HF Spaces 默认开启 SSR(GRADIO_SSR_MODE=True),易导致首屏布局/滚动异常;与 launch(ssr_mode=False) 对齐。 os.environ["GRADIO_SSR_MODE"] = "False" try: import spaces except ImportError: class _SpacesCompat: @staticmethod def GPU(*_args: Any, **_kwargs: Any): def decorator(func): return func return decorator spaces = _SpacesCompat() APP_DIR = Path(__file__).resolve().parent IMAGE_DIR = APP_DIR / "imgs" VALID_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp"} MODEL_ID_ENV = "MODEL_ID" HF_TOKEN_ENV = "HF_TOKEN" DEFAULT_MODEL_ID = "RichardChenZH/MedForge-Reasoner" DEFAULT_SYSTEM = ( "You are an expert in medical image forensics. Analyze the provided image to " "determine if it is a deepfake or authentic. First, perform a step-by-step " "examination of the image content, looking for artifacts, inconsistencies, or " "biological implausibilities. Use tags to articulate your reasoning " "process. If you identify manipulated regions, localize them using bounding " "boxes within your reasoning. Conclude your analysis with a final classification." ) DEFAULT_USER = "Is this image deepfake or real?" 
INFER_SIZE = (1024, 1024)
STATUS_DETECTED = "Detected deepfake region"
STATUS_DETECTED_NO_BBOX = "Detected deepfake region (no bbox)"
STATUS_NO_REGION = "No deepfake region detected"
STATUS_UNKNOWN = "Unknown result, manual review required"
STATUS_INVALID_INPUT = "Invalid input"
STATUS_SYSTEM_ERROR = "System error"
STATUS_MODEL_ERROR = "Model error"
GENERATION_TEMPERATURE = 0.0

# Lazily-populated model cache, guarded by _MODEL_LOCK in ensure_model_loaded().
_MODEL = None
_PROCESSOR = None
_TORCH = None
# First load failure; once set, later load attempts re-raise it instead of retrying.
_LOAD_ERROR: Optional[Exception] = None
_MODEL_LOCK = threading.Lock()


class BBoxError(RuntimeError):
    """Raised when a parsed bounding box is geometrically invalid."""


class InputError(ValueError):
    """Raised for invalid user input (image or prompt settings)."""


def _describe_import_error(exc: ImportError) -> str:
    """Map an ImportError to an actionable hint.

    The original exception message is preserved so the failure is not generically
    blamed on torch/transformers.
    """
    name = getattr(exc, "name", None)
    msg = str(exc).strip()
    if "torchvision" in msg.lower():
        return (
            "需要安装 torchvision,且版本须与 torch 匹配(例如 torch 2.9.1 与 torchvision 0.24.1;"
            "见 requirements.txt)。"
            f" 原始异常: {msg}"
        )
    if name:
        return (
            f"无法导入 Python 模块「{name}」(该包未安装或其传递依赖缺失)。"
            f"原始异常: {msg}"
        )
    return f"导入失败: {msg}"


def _format_model_load_error(exc: BaseException, model_id: str) -> str:
    """Translate a model/processor loading exception into a user-facing message.

    Hub-specific exception classes are matched when ``huggingface_hub`` is importable.
    Otherwise the function falls back to substring checks on the exception text.
    """
    try:
        from huggingface_hub.errors import (
            GatedRepoError,
            HfHubHTTPError,
            LocalEntryNotFoundError,
            RepositoryNotFoundError,
            RevisionNotFoundError,
        )
    except ImportError:
        GatedRepoError = HfHubHTTPError = LocalEntryNotFoundError = None
        RepositoryNotFoundError = RevisionNotFoundError = None

    if isinstance(exc, ImportError):
        return _describe_import_error(exc)
    if isinstance(exc, RuntimeError) and "accelerate" in str(exc).lower():
        return (
            "缺少 accelerate 依赖,但当前模型加载配置使用了 device_map=\"auto\"。"
            "请在 Space 依赖中安装 accelerate。"
        )
    # The specific Hub errors are checked before the generic HfHubHTTPError branch.
    if RepositoryNotFoundError and isinstance(exc, RepositoryNotFoundError):
        return f"模型仓库不存在: {model_id}"
    if RevisionNotFoundError and isinstance(exc, RevisionNotFoundError):
        return f"模型版本不存在: {model_id}"
    if GatedRepoError and isinstance(exc, GatedRepoError):
        return f"模型仓库需要访问授权: {model_id}。请检查仓库权限或 HF_TOKEN。"
    if LocalEntryNotFoundError and isinstance(exc, LocalEntryNotFoundError):
        return (
            f"模型下载失败: 无法从 Hugging Face Hub 获取 {model_id},"
            "请检查 Space 网络连通性。"
        )
    if HfHubHTTPError and isinstance(exc, HfHubHTTPError):
        status_code = getattr(getattr(exc, "response", None), "status_code", None)
        if status_code == 401:
            return f"模型仓库无访问权限: {model_id}。请检查 HF_TOKEN。"
        if status_code == 403:
            return f"模型仓库被拒绝访问: {model_id}。请检查仓库权限或 HF_TOKEN。"
        if status_code == 404:
            return f"模型仓库不存在: {model_id}"
        return f"模型下载失败: Hub 返回 HTTP {status_code or 'unknown'}。"

    message = str(exc)
    if "Qwen3VLForConditionalGeneration" in message or "AutoModelForImageTextToText" in message:
        return (
            "当前 transformers 版本不兼容,缺少 Qwen3VLForConditionalGeneration / "
            "AutoModelForImageTextToText。"
        )
    if "trust_remote_code" in message:
        return "模型依赖 remote code 加载失败。"
    lowered_msg = message.lower()
    if "torchvision" in lowered_msg and (
        "not found" in lowered_msg or "requires" in lowered_msg
    ):
        return (
            "缺少 torchvision(须与 torch 版本匹配,例如 torch 2.9.1 对应 torchvision 0.24.1;"
            "已在 requirements.txt 中声明)。"
            f" 原文: {message}"
        )
    if isinstance(exc, OSError):
        # Fallback heuristics on the message text for OSErrors raised by the Hub client.
        if "401" in lowered_msg or "403" in lowered_msg:
            return f"模型仓库无访问权限: {model_id}。请检查 HF_TOKEN。"
        if "404" in lowered_msg or "not found" in lowered_msg:
            return f"模型仓库不存在: {model_id}"
        if (
            "network" in lowered_msg
            or "connection" in lowered_msg
            or "offline" in lowered_msg
        ):
            return f"模型下载失败: 无法连接 Hugging Face Hub 获取 {model_id}。"
    return f"模型加载失败: {exc}"


def list_example_paths() -> list[Path]:
    """Return the sorted image files in IMAGE_DIR with a supported suffix.

    Returns an empty list when the directory does not exist.
    """
    if not IMAGE_DIR.is_dir():
        return []
    paths = []
    for path in sorted(IMAGE_DIR.iterdir()):
        if not path.is_file():
            continue
        if path.suffix.lower() not in VALID_IMAGE_SUFFIXES:
            continue
        paths.append(path)
    return paths


EXAMPLE_PATHS = list_example_paths()
# Gallery captions are the file stem with hyphens replaced by spaces.
EXAMPLE_GALLERY_ITEMS = [(str(path), path.stem.replace("-", " ")) for path in EXAMPLE_PATHS]


def _load_image_from_path(path: Path) -> Image.Image:
    """Open an image file and return a fully decoded RGB copy.

    Raises:
        InputError: if Pillow cannot identify the file as an image. InputError is a
            ValueError subclass, so existing ``except ValueError`` callers still work,
            while run_inference now reports it as invalid input rather than a model error.
    """
    try:
        # The context manager guarantees the file handle is released even on errors.
        with Image.open(path) as image:
            image.load()
            return image.convert("RGB")
    except UnidentifiedImageError as exc:
        raise InputError(f"示例图无法识别: {path.name}") from exc


def _normalize_input_image(image: Image.Image | str | Path | None) -> Image.Image:
    """Coerce the UI input (PIL image, file path, or None) into an RGB PIL image.

    Raises:
        InputError: when no image is given, the path does not exist, the file is not a
            recognisable image, or the input type is unsupported.
    """
    if image is None:
        raise InputError("请先选择示例图或上传图片。")
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, (str, Path)):
        path = Path(image)
        if not path.is_file():
            raise InputError("输入图片不存在。")
        return _load_image_from_path(path)
    raise InputError("不支持的图片输入类型。")


def _resize_to_infer_size(image: Image.Image) -> Image.Image:
    """Resize to INFER_SIZE with Lanczos resampling (aspect ratio is not preserved)."""
    try:
        resample = Image.Resampling.LANCZOS
    except AttributeError:
        # Older Pillow versions expose the filter directly on Image.
        resample = Image.LANCZOS
    return image.resize(INFER_SIZE, resample)


def _extract_deepfake_bbox_1024(text: str) -> tuple[float, float, float, float]:
    """Parse the first "deepfake" bounding box from raw model output.

    Coordinates are returned as floats in INFER_SIZE pixel space.

    Raises:
        ValueError: if no box annotation is found or a coordinate is not numeric.
    """
    pattern = (
        r'<\|object_ref_start\|>"deepfake"<\|object_ref_end\|>'
        r"\s*<\|box_start\|>\s*"
        r'x1="([^"]+)"\s*y1="([^"]+)"\s*x2="([^"]+)"\s*y2="([^"]+)"'
        r"\s*<\|box_end\|>"
    )
    match = re.search(pattern, text)
    if match is None:
        raise ValueError("未在模型输出中找到 deepfake 的 bbox 标注")
    return tuple(float(v) for v in match.groups())


def _scale_bbox_to_original(
    bbox_1024: tuple[float, float, float, float], width: int, height: int
) -> tuple[int, int, int, int]:
    """Map a box from INFER_SIZE space back onto a ``width`` x ``height`` image.

    Coordinates are rounded, validated, then clamped to the image bounds.

    Raises:
        BBoxError: if the rounded box has non-positive width or height.
    """
    x1, y1, x2, y2 = bbox_1024
    sx = width / INFER_SIZE[0]
    sy = height / INFER_SIZE[1]
    ox1 = int(round(x1 * sx))
    oy1 = int(round(y1 * sy))
    ox2 = int(round(x2 * sx))
    oy2 = int(round(y2 * sy))
    if ox1 >= ox2 or oy1 >= oy2:
        raise BBoxError(
            f"反算后的 bbox 非法: ({ox1}, {oy1}, {ox2}, {oy2}),原始 bbox={bbox_1024}"
        )
    ox1 = max(0, min(ox1, width - 1))
    oy1 = max(0, min(oy1, height - 1))
    ox2 = max(0, min(ox2, width - 1))
    oy2 = max(0, min(oy2, height - 1))
    return ox1, oy1, ox2, oy2


def _draw_bbox(image: Image.Image, bbox: tuple[int, int, int, int]) -> Image.Image:
    """Return a copy of ``image`` with a red box and a "deepfake" label drawn on it."""
    output = image.copy()
    draw = ImageDraw.Draw(output)
    draw.rectangle(bbox, outline=(255, 0, 0), width=4)
    x1, y1, _, _ = bbox
    # Place the label just above the box, but never above the top edge.
    label_top = max(0, y1 - 18)
    draw.text((x1, label_top), "deepfake", fill=(255, 0, 0))
    return output


def _draw_status_watermark(
    image: Image.Image,
    text: str,
    color: tuple[int, int, int],
) -> Image.Image:
    """Return a copy of ``image`` with ``text`` drawn in the top-left corner.

    The text sits on a black backing box when Pillow supports ``textbbox``.
    """
    output = image.copy()
    draw = ImageDraw.Draw(output)
    x, y = 12, 12
    pad_x, pad_y = 8, 6
    try:
        left, top, right, bottom = draw.textbbox((x, y), text)
        box = (
            left - pad_x,
            top - pad_y,
            right + pad_x,
            bottom + pad_y,
        )
        draw.rectangle(box, fill=(0, 0, 0))
    except AttributeError:
        # Older Pillow without ImageDraw.textbbox: draw the text without a backing box.
        pass
    draw.text((x, y), text, fill=color)
    return output


def _normalize_last_sentence(text: str) -> str:
    """Return the last non-empty line of ``text``, lowercased, for verdict matching.

    Trailing ``<|im_end|>`` / ``<|endoftext|>`` markers are removed. Decoding keeps
    special tokens, so either may end the output. Trailing punctuation is stripped and
    internal whitespace is collapsed.
    """
    cleaned = re.sub(
        r"(?:<\|(?:im_end|endoftext)\|>\s*)+$",
        "",
        text.strip(),
        flags=re.IGNORECASE,
    ).strip()
    if not cleaned:
        return ""
    lines = [line.strip() for line in re.split(r"\n+", cleaned) if line.strip()]
    last_line = lines[-1] if lines else cleaned
    normalized = re.sub(r"[.!?。!?\s]+$", "", last_line).strip().lower()
    normalized = re.sub(r"\s+", " ", normalized)
    return normalized


def analyze_model_output(text: str) -> dict[str, Any]:
    """Classify raw model output.

    Returns a dict with keys ``label`` ("deepfake" / "real" / "unknown"), ``bbox``
    (INFER_SIZE-space box or None) and ``reason``.

    A parseable deepfake box wins outright. Otherwise the last line is searched for the
    whole words "deepfake" / "real"; both present counts as a conflict.
    """
    try:
        bbox = _extract_deepfake_bbox_1024(text)
        return {"label": "deepfake", "bbox": bbox, "reason": "bbox_detected"}
    except ValueError:
        pass

    last_sentence = _normalize_last_sentence(text)
    if not last_sentence:
        return {"label": "unknown", "bbox": None, "reason": "unparseable_verdict"}

    has_deepfake = re.search(r"\bdeepfake\b", last_sentence) is not None
    has_real = re.search(r"\breal\b", last_sentence) is not None
    if has_deepfake and has_real:
        return {"label": "unknown", "bbox": None, "reason": "conflict"}
    if has_deepfake:
        return {"label": "deepfake", "bbox": None, "reason": "verdict_parsed"}
    if has_real:
        return {"label": "real", "bbox": None, "reason": "verdict_parsed"}
    # Substring fallback: both words present only as parts of longer words.
    if "deepfake" in last_sentence and "real" in last_sentence:
        return {"label": "unknown", "bbox": None, "reason": "conflict"}
    return {"label": "unknown", "bbox": None, "reason": "unparseable_verdict"}


def _get_model_id() -> str:
    """Return the MODEL_ID env var (stripped) or DEFAULT_MODEL_ID when unset/blank."""
    model_id = os.getenv(MODEL_ID_ENV, "").strip()
    if model_id:
        return model_id
    return DEFAULT_MODEL_ID


def _get_hf_token() -> Optional[str]:
    """Return the HF_TOKEN env var (stripped), or None when unset/blank."""
    token = os.getenv(HF_TOKEN_ENV, "").strip()
    return token or None


def ensure_model_loaded() -> tuple[Any, Any, Any]:
    """Lazily load and cache the model, processor and torch module.

    Returns:
        ``(model, processor, torch)``.

    Raises:
        RuntimeError: if loading fails. The first failure is stored in ``_LOAD_ERROR``
            and re-raised on every later call without retrying.
    """
    global _MODEL, _PROCESSOR, _TORCH, _LOAD_ERROR
    # Lock-free fast path. _MODEL is published last (see the end of this function), so
    # seeing it non-None implies _PROCESSOR and _TORCH are already set.
    if _MODEL is not None and _PROCESSOR is not None:
        return _MODEL, _PROCESSOR, _TORCH

    with _MODEL_LOCK:
        # Re-check under the lock: another thread may have finished loading meanwhile.
        if _MODEL is not None and _PROCESSOR is not None:
            return _MODEL, _PROCESSOR, _TORCH
        if _LOAD_ERROR is not None:
            # Keep the actionable wording a cached ImportError had on its first raise.
            if isinstance(_LOAD_ERROR, ImportError):
                cached_message = _describe_import_error(_LOAD_ERROR)
            else:
                cached_message = str(_LOAD_ERROR)
            raise RuntimeError(cached_message) from _LOAD_ERROR

        try:
            import torch
        except ImportError as exc:
            _LOAD_ERROR = exc
            raise RuntimeError(_describe_import_error(exc)) from exc

        try:
            import transformers
            from transformers import AutoProcessor
        except ImportError as exc:
            _LOAD_ERROR = exc
            raise RuntimeError(_describe_import_error(exc)) from exc

        if not transformers.__version__:
            _LOAD_ERROR = RuntimeError("transformers 未正确安装。")
            raise _LOAD_ERROR

        model_id = _get_model_id()
        token = _get_hf_token()

        # device_map="auto" (below) requires accelerate; fail early with a clear message.
        try:
            import accelerate  # noqa: F401
        except ImportError as exc:
            error = RuntimeError(
                f"{_describe_import_error(exc)} "
                '(当前加载使用 device_map="auto",需安装 accelerate。)'
            )
            _LOAD_ERROR = error
            raise error from exc

        # Prefer the dedicated Qwen3-VL class; fall back to the generic auto class.
        if hasattr(transformers, "Qwen3VLForConditionalGeneration"):
            model_cls = transformers.Qwen3VLForConditionalGeneration
        elif hasattr(transformers, "AutoModelForImageTextToText"):
            model_cls = transformers.AutoModelForImageTextToText
        else:
            _LOAD_ERROR = RuntimeError(
                "当前 transformers 版本缺少 Qwen3VLForConditionalGeneration / "
                "AutoModelForImageTextToText。"
            )
            raise _LOAD_ERROR

        load_kwargs: dict[str, Any] = {
            "device_map": "auto",
            "trust_remote_code": True,
        }
        if token is not None:
            load_kwargs["token"] = token
        # Newer transformers take ``dtype``; older ones only accept ``torch_dtype``.
        from_pretrained_params = inspect.signature(model_cls.from_pretrained).parameters
        if "dtype" in from_pretrained_params:
            load_kwargs["dtype"] = torch.bfloat16
        else:
            load_kwargs["torch_dtype"] = torch.bfloat16

        try:
            model = model_cls.from_pretrained(model_id, **load_kwargs)
            processor = AutoProcessor.from_pretrained(
                model_id,
                trust_remote_code=True,
                token=token,
            )
        except Exception as exc:
            error = RuntimeError(_format_model_load_error(exc, model_id))
            _LOAD_ERROR = error
            raise error from exc

        # Publish _MODEL last so the lock-free fast path never sees a partial cache
        # (e.g. _MODEL set while _TORCH is still None).
        _TORCH = torch
        _PROCESSOR = processor
        _MODEL = model
        return _MODEL, _PROCESSOR, _TORCH


@spaces.GPU(duration=120)
def _generate_text(
    image_for_infer: Image.Image,
    system_prompt: str,
    user_prompt: str,
    max_tokens: int,
) -> str:
    """Run one chat-style generation for a single image and return the decoded text.

    Only the newly generated tokens are decoded. Special tokens are kept because the
    bbox markup parsed downstream relies on them.
    """
    model, processor, torch = ensure_model_loaded()
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_for_infer},
                {"type": "text", "text": user_prompt},
            ],
        },
    ]
    model_inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    model_inputs = {
        key: value.to(model.device) if hasattr(value, "to") else value
        for key, value in model_inputs.items()
    }

    # Clear sampling defaults from the model's generation config so greedy decoding
    # does not warn about unused sampling parameters.
    if getattr(model, "generation_config", None) is not None:
        model.generation_config.do_sample = False
        model.generation_config.temperature = None
        model.generation_config.top_p = None
        model.generation_config.top_k = None

    generate_kwargs: dict[str, Any] = {
        "max_new_tokens": max_tokens,
        "do_sample": False,
    }
    if GENERATION_TEMPERATURE > 0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = GENERATION_TEMPERATURE

    with torch.inference_mode():
        generated_ids = model.generate(
            **model_inputs,
            **generate_kwargs,
        )

    # Drop the prompt tokens so only the model's continuation is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(model_inputs["input_ids"], generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False,
    )[0]


def startup_status() -> str:
    """Return the Markdown shown in the status box before any inference runs."""
    return (
        f"**Model Configured**  \n"
        f"- Source: `{_get_model_id()}`  \n"
        f"- Examples: `{len(EXAMPLE_PATHS)}`  \n"
        f"- Runtime: `ZeroGPU lazy load`"
    )


def select_example(evt: gr.SelectData) -> Image.Image:
    """Load the gallery example selected in ``evt`` as an RGB image.

    Raises:
        ValueError: if the selection index is missing, not an int, or out of range.
    """
    index = evt.index
    # Some components report the index as a sequence; a list can also arrive from the
    # frontend, so unwrap both lists and tuples.
    if isinstance(index, (list, tuple)):
        index = index[0] if index else None
    if not isinstance(index, int) or index < 0 or index >= len(EXAMPLE_PATHS):
        raise ValueError("示例选择无效。")
    return _load_image_from_path(EXAMPLE_PATHS[index])


def run_inference(
    image: Image.Image | str | Path | None,
    system_prompt: str,
    user_prompt: str,
    max_tokens: int,
    print_raw_output_for_test: bool,
) -> tuple[Optional[Image.Image], str, str, str]:
    """Run detection for the UI.

    Returns:
        ``(annotated_image, raw_text, bbox_text, status)``. On any error the image is
        None and the status names the error category; exceptions are not propagated.
    """
    raw_text = ""
    try:
        original_image = _normalize_input_image(image)
        system_prompt = system_prompt.strip()
        user_prompt = user_prompt.strip()
        if not system_prompt:
            raise InputError("system prompt 不能为空。")
        # NOTE(review): the placeholder literal below is an empty string, so this check
        # never fires and replace() is a no-op; the intended placeholder token seems to
        # be missing. Left unchanged so the current DEFAULT_USER keeps working.
        if "" not in user_prompt:
            raise InputError("user prompt 必须包含 占位符。")
        prompt_text = user_prompt.replace("", "").strip()
        if not prompt_text:
            raise InputError("user prompt 去掉 后不能为空。")
        if max_tokens <= 0:
            raise InputError("max_tokens 必须大于 0。")

        image_for_infer = _resize_to_infer_size(original_image)
        text = _generate_text(image_for_infer, system_prompt, prompt_text, int(max_tokens))
        raw_text = text
        if print_raw_output_for_test:
            sys.stdout.write("[TEST][RAW_MODEL_OUTPUT_BEGIN]\n")
            sys.stdout.write(text)
            if not text.endswith("\n"):
                sys.stdout.write("\n")
            sys.stdout.write("[TEST][RAW_MODEL_OUTPUT_END]\n")
            sys.stdout.flush()

        analysis = analyze_model_output(text)
        label = analysis["label"]
        bbox_1024 = analysis["bbox"]

        if label == "deepfake" and bbox_1024 is not None:
            bbox_original = _scale_bbox_to_original(
                bbox_1024,
                original_image.width,
                original_image.height,
            )
            annotated = _draw_bbox(original_image, bbox_original)
            bbox_text = (
                f"x1={bbox_original[0]}, y1={bbox_original[1]}, "
                f"x2={bbox_original[2]}, y2={bbox_original[3]}"
            )
            return annotated, text, bbox_text, STATUS_DETECTED

        if label == "deepfake":
            annotated = _draw_status_watermark(
                original_image,
                "deepfake (no bbox)",
                (255, 0, 0),
            )
            return annotated, text, "-", STATUS_DETECTED_NO_BBOX

        if label == "real":
            annotated = _draw_status_watermark(
                original_image,
                "real",
                (50, 205, 50),
            )
            return annotated, text, "-", STATUS_NO_REGION

        annotated = _draw_status_watermark(
            original_image,
            "unknown",
            (255, 191, 0),
        )
        return annotated, text, "-", STATUS_UNKNOWN
    # Order matters: InputError subclasses ValueError and BBoxError subclasses
    # RuntimeError, so the specific handlers must come first.
    except InputError as exc:
        return None, "", "-", f"{STATUS_INVALID_INPUT}: {exc}"
    except BBoxError as exc:
        return None, raw_text, "-", f"{STATUS_SYSTEM_ERROR}: {exc}"
    except ValueError as exc:
        return None, raw_text, "-", f"{STATUS_MODEL_ERROR}: {exc}"
    except RuntimeError as exc:
        return None, raw_text, "-", f"{STATUS_MODEL_ERROR}: {exc}"
    except Exception as exc:
        return None, raw_text, "-", f"{STATUS_SYSTEM_ERROR}: {exc}"


def build_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI.

    Theme and CSS are attached as ``demo.theme`` / ``demo.css`` so the entry point can
    pass them to ``launch()``.
    """
    theme = gr.themes.Soft(
        primary_hue="cyan",
        secondary_hue="slate",
        neutral_hue="slate",
    )
    css = """
    :root {
      --panel: linear-gradient(180deg, rgba(240,248,250,0.94), rgba(248,250,252,0.98));
      --line: rgba(15, 23, 42, 0.08);
      --accent: #0f766e;
      --text: #10212b;
      --muted: #52606d;
    }
    body, .gradio-container {
      font-family: "IBM Plex Sans", "Avenir Next", "Segoe UI", sans-serif;
      background:
        radial-gradient(circle at top left, rgba(14,165,233,0.10), transparent 28%),
        radial-gradient(circle at top right, rgba(16,185,129,0.10), transparent 22%),
        linear-gradient(180deg, #eef5f7 0%, #f7fafb 100%);
      color: var(--text);
    }
    .app-shell {
      max-width: 1220px;
      margin: 0 auto;
    }
    .hero {
      padding: 24px 28px 18px;
      border: 1px solid var(--line);
      border-radius: 24px;
      background: linear-gradient(135deg, rgba(255,255,255,0.96), rgba(236,253,245,0.90));
      box-shadow: 0 18px 48px rgba(15, 23, 42, 0.08);
    }
    .hero h1 {
      margin: 0;
      font-size: 2rem;
      letter-spacing: -0.03em;
    }
    .hero p {
      margin: 10px 0 0;
      color: var(--muted);
      line-height: 1.65;
    }
    .hero .input-notice {
      margin-top: 14px;
      padding: 12px 14px;
      border-left: 4px solid #d97706;
      border-radius: 0 12px 12px 0;
      background: rgba(254, 243, 199, 0.55);
      color: #422006;
      line-height: 1.6;
      font-size: 0.95rem;
    }
    .hero .input-notice strong {
      color: #78350f;
    }
    .hero-paper-promo {
      margin: 14px 0 0;
      padding: 14px 16px;
      border-radius: 14px;
      border: 1px solid rgba(15, 118, 110, 0.22);
      background: linear-gradient(135deg, rgba(240, 253, 250, 0.95), rgba(236, 254, 255, 0.88));
      font-size: 1.05rem;
      color: var(--text);
      line-height: 1.55;
    }
    .hero-paper-promo .acl {
      font-weight: 700;
      color: #0f766e;
      letter-spacing: 0.02em;
    }
    .hero-repo-note {
      margin: 14px 0 0;
      font-size: 0.86rem;
      color: var(--muted);
      line-height: 1.5;
    }
    .hero-repo-note a {
      color: #0e7490;
      text-decoration: underline;
      text-underline-offset: 2px;
    }
    .card {
      border: 1px solid var(--line);
      border-radius: 22px;
      background: var(--panel);
      box-shadow: 0 10px 30px rgba(15, 23, 42, 0.06);
    }
    .gallery-note {
      color: var(--muted);
      font-size: 0.95rem;
      margin-top: -8px;
      margin-bottom: 6px;
    }
    .status-box {
      border: 1px solid rgba(15, 118, 110, 0.18);
      background: rgba(240, 253, 250, 0.72);
      border-radius: 18px;
      padding: 14px 16px;
      min-height: 5.5rem;
    }
    html, body {
      height: auto;
      min-height: 100%;
    }
    body {
      overflow-y: auto !important;
    }
    .gradio-container {
      min-height: 100vh;
      overflow-y: auto !important;
    }
    """

    with gr.Blocks(title="Medical Deepfake Detector") as demo:
        with gr.Column(elem_classes="app-shell"):
            # The markup mirrors the .hero / .input-notice / .hero-paper-promo /
            # .hero-repo-note selectors in the CSS above; without it the paragraphs
            # collapse into one run of text and those styles never apply.
            gr.HTML(
                """
<div class="hero">
  <h1>Medical Deepfake Detector</h1>
  <p class="hero-paper-promo">
    MedForge — presented at <span class="acl">ACL 2026 Main</span>. This demo is part of our work on medical-image deepfake detection; thanks for trying it.
  </p>
  <p>
    Upload one image or choose an example below. The app keeps the original inference flow, parses the model-generated deepfake box, and renders the annotated result directly on the original image.
  </p>
  <div class="input-notice">
    <strong>Input quality:</strong> Use full-resolution medical scan exports (e.g. from PACS, DICOM viewers, or your clinical workflow)—not thumbnails saved from websites. Website thumbnails compress and downsample the image; in practice this degrades detection reliability. <strong>File format:</strong> upload JPEG (.jpg / .jpeg) or PNG exports from your workflow.
  </div>
  <p class="hero-repo-note">
    Open-source code and resources for the paper: <a href="https://github.com/richardChenzhihui/ACL2026-MedForge" target="_blank" rel="noopener noreferrer">github.com/richardChenzhihui/ACL2026-MedForge</a>.
  </p>
</div>
"""
            )
            status = gr.Markdown(value=startup_status(), elem_classes="status-box")

            with gr.Row():
                with gr.Column(scale=5, elem_classes="card"):
                    image_input = gr.Image(
                        type="pil",
                        label="Input Image",
                        image_mode="RGB",
                        height=420,
                    )
                    run_button = gr.Button("Run Detection", variant="primary", size="lg")
                    with gr.Accordion("Advanced Settings", open=False):
                        system_prompt = gr.Textbox(
                            value=DEFAULT_SYSTEM,
                            label="System Prompt",
                            lines=6,
                        )
                        user_prompt = gr.Textbox(
                            value=DEFAULT_USER,
                            label="User Prompt",
                            lines=2,
                        )
                        max_tokens = gr.Slider(
                            minimum=128,
                            maximum=4096,
                            step=64,
                            value=2048,
                            label="Max Tokens",
                        )
                        print_raw_output_for_test = gr.Checkbox(
                            value=False,
                            label="测试模式:打印模型原始输出到终端日志",
                        )

                with gr.Column(scale=5, elem_classes="card"):
                    annotated_image = gr.Image(
                        type="pil",
                        label="Annotated Result",
                        image_mode="RGB",
                        height=420,
                    )
                    bbox_text = gr.Textbox(label="BBox", interactive=False)
                    raw_response = gr.Textbox(
                        label="Raw Model Output",
                        lines=14,
                        interactive=False,
                    )

            gr.Markdown("### Example Gallery")
            gr.Markdown(
                "Click any example to load it into the input panel. Uploading a new image will replace the current selection.",
                elem_classes="gallery-note",
            )
            examples_gallery = gr.Gallery(
                value=EXAMPLE_GALLERY_ITEMS,
                label="Examples",
                columns=3,
                rows=1,
                height=260,
                allow_preview=True,
                preview=False,
                object_fit="cover",
            )

        run_button.click(
            fn=run_inference,
            inputs=[
                image_input,
                system_prompt,
                user_prompt,
                max_tokens,
                print_raw_output_for_test,
            ],
            outputs=[annotated_image, raw_response, bbox_text, status],
            show_progress="full",
        )
        examples_gallery.select(
            fn=select_example,
            outputs=image_input,
            show_progress="hidden",
        )

    demo.theme = theme
    demo.css = css
    return demo


demo = build_demo()
demo.queue()


if __name__ == "__main__":
    demo.launch(theme=demo.theme, css=demo.css, ssr_mode=False)