yichuan-huang's picture
refactor: align dummy spaces.GPU signature with official API
e13a657
raw
history blame
14.4 kB
import os
import sys
# =========================
# CRITICAL: Import spaces FIRST before any CUDA-related packages
# ZeroGPU requires spaces to be imported before torch/diffusers
# =========================
def _is_hf_spaces_env() -> bool:
"""Detect if running in Hugging Face Spaces environment"""
return (os.getenv("SYSTEM", "").lower() == "spaces") or bool(
os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID")
)
IS_HF_SPACES = _is_hf_spaces_env()
try:
    import spaces

    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False
    # Create a stand-in ``spaces`` module for the local environment so the
    # ``@spaces.GPU`` decorators in this file keep working without the package.
    from types import ModuleType

    class _SpacesDummy:
        """Holder for the no-op replacement of ``spaces.GPU``."""

        @staticmethod
        def GPU(task=None, *, duration: int = 60, **kwargs):
            """No-op decorator for local runtime.

            Accepts both decorator forms:
            - bare ``@spaces.GPU``: the function arrives as ``task`` and is
              returned unchanged;
            - ``@spaces.GPU(duration=...)``: a pass-through decorator is returned.

            ``duration`` and any extra keyword arguments are ignored. A
            non-callable positional argument (the old positional ``duration``)
            is ignored as well, so existing call sites keep working.
            """
            if callable(task):
                # Bare-decorator usage: nothing to wrap locally.
                return task

            def decorator(fn):
                return fn

            return decorator

    spaces = ModuleType("spaces")
    spaces.GPU = _SpacesDummy.GPU
    # Register the stand-in so later ``import spaces`` statements resolve to it.
    sys.modules["spaces"] = spaces
# =========================
# Now import CUDA-related packages
# =========================
import zipfile
import tempfile
import shutil
import uuid
from pathlib import Path
import gradio as gr
import torch
import numpy as np
from PIL import Image
from diffusers import Flux2Pipeline
from diffusers.utils import load_image
import py7zr
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.util import img_as_float
# =========================
# Config
# =========================
# Quantized 4-bit model (requires GPU); loaded lazily by _get_pipe().
MODEL_ID: str = "diffusers/FLUX.2-dev-bnb-4bit"
TORCH_DTYPE = torch.bfloat16

# Prompt used as the textbox default and whenever the user submits an empty one.
DEFAULT_PROMPT: str = (
    "enhance microscopy image with subtle improvements, gently increase cellular boundary clarity, "
    "preserve original morphological structure, maintain authentic texture patterns, "
    "minimal noise reduction while keeping fine details intact"
)

# Initial slider values for the generation parameters.
GUIDANCE_SCALE: float = 2.0
NUM_INFERENCE_STEPS: int = 30

# Lower-cased file suffixes treated as images (compared against Path.suffix.lower()).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}

# Global cached pipeline; stays None until _get_pipe() builds it.
_pipe = None
def calculate_psnr_ssim(original: Image.Image, enhanced: Image.Image):
    """Calculate PSNR and SSIM between original and enhanced images.

    The comparison happens at the original image's resolution. If ``enhanced``
    has a different size, it is resized (bicubic) to ``original.size`` first.
    This keeps pixels spatially aligned. Cropping both images to a shared
    top-left corner would compare unrelated content whenever the generator
    rescaled its input.

    Both images are converted with ``img_as_float`` and compared with
    ``data_range=1.0``. SSIM is computed per channel (``channel_axis=-1``)
    for 3-D arrays. Images with different channel counts (e.g. "L" vs
    "RGB") are not reconciled here, and the metric functions will reject them.

    Args:
        original: Reference image.
        enhanced: Image to evaluate against ``original``.

    Returns:
        ``(psnr, ssim)`` as Python floats.
    """
    if enhanced.size != original.size:
        # The pipeline may return a different resolution than it was given;
        # map the output back onto the original pixel grid before comparing.
        enhanced = enhanced.resize(original.size, Image.Resampling.BICUBIC)

    orig_float = img_as_float(np.array(original))
    enh_float = img_as_float(np.array(enhanced))

    psnr = peak_signal_noise_ratio(orig_float, enh_float, data_range=1.0)
    if orig_float.ndim == 3:
        ssim = structural_similarity(
            orig_float, enh_float, data_range=1.0, channel_axis=-1
        )
    else:
        ssim = structural_similarity(orig_float, enh_float, data_range=1.0)
    return float(psnr), float(ssim)
def extract_archive(archive_path: str, extract_to: str):
"""Extract zip or 7z archive."""
file_ext = Path(archive_path).suffix.lower()
if file_ext == ".zip":
with zipfile.ZipFile(archive_path, "r") as z:
z.extractall(extract_to)
elif file_ext == ".7z":
with py7zr.SevenZipFile(archive_path, mode="r") as a:
a.extractall(path=extract_to)
else:
raise ValueError(f"Unsupported archive format: {file_ext}")
def find_images(directory: str):
    """Return paths of all files under ``directory`` with an image extension.

    The tree is walked recursively with ``os.walk``; results follow the walk
    order. A file matches when its lower-cased suffix is in ``IMAGE_EXTENSIONS``.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(directory)
        for filename in filenames
        if Path(filename).suffix.lower() in IMAGE_EXTENSIONS
    ]
def _get_pipe():
    """
    Lazy-load the pipeline and cache it in the module-level ``_pipe``.

    IMPORTANT:
    - On HF ZeroGPU: must be called inside a @spaces.GPU-decorated function runtime.
    - Locally: uses local CUDA.

    Returns:
        The cached ``Flux2Pipeline`` (``MODEL_ID``, ``TORCH_DTYPE``), moved to "cuda".

    Raises:
        RuntimeError: if CUDA is unavailable when the pipeline still has to be loaded.
    """
    global _pipe
    if _pipe is None:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA GPU is not available in the current runtime. "
                "This bnb-4bit model requires GPU. "
                "On HF ZeroGPU, ensure inference is inside a @spaces.GPU function."
            )
        # Build and move the pipeline in a local first and publish it to the
        # cache only once ``.to("cuda")`` succeeded. Otherwise a failure there
        # would leave a pipeline that was never moved to CUDA cached for all
        # later calls.
        pipe = Flux2Pipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=TORCH_DTYPE,
        )
        pipe.to("cuda")
        _pipe = pipe
    return _pipe
def _failure_outputs(message):
    """Build the 4 UI outputs for a run that produced nothing to download.

    Clears the gallery, hides both download components and shows ``message``.
    """
    return (
        None,
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        message,
    )


def _collect_images(files, temp_dir, progress):
    """Stage uploads and return ``(all_images, has_archive)``.

    ``all_images`` holds ``(img_path, rel_path)`` pairs. For archive members,
    ``rel_path`` is relative to the archive root; for plain images it is the
    file's base name. Files that are neither a ``.zip``/``.7z`` archive nor a
    supported image are skipped.
    """
    all_images = []
    has_archive = False
    for file_obj in files:
        file_path = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
        file_ext = Path(file_path).suffix.lower()
        if file_ext in (".zip", ".7z"):
            has_archive = True
            progress(0.05, desc=f"Extracting: {Path(file_path).name} ...")
            # A fresh directory per archive, so uploads sharing a stem
            # (e.g. ``data.zip`` and ``data.7z``) do not extract over each other.
            extract_dir = tempfile.mkdtemp(
                prefix=f"{Path(file_path).stem}_", dir=temp_dir
            )
            extract_archive(file_path, extract_dir)
            for img_path in find_images(extract_dir):
                all_images.append((img_path, os.path.relpath(img_path, extract_dir)))
        elif file_ext in IMAGE_EXTENSIONS:
            all_images.append((file_path, Path(file_path).name))
    return all_images, has_archive


def _unique_output_path(out_path, used_paths):
    """Return ``out_path``, or a ``_N``-suffixed variant not yet in ``used_paths``.

    Prevents inputs that share a relative path (the same file name inside two
    archives, or two uploads with the same name) from silently overwriting
    each other's output. The returned path is recorded in ``used_paths``.
    """
    stem, suffix = os.path.splitext(out_path)
    candidate = out_path
    counter = 1
    while candidate in used_paths:
        candidate = f"{stem}_{counter}{suffix}"
        counter += 1
    used_paths.add(candidate)
    return candidate


def _zip_directory(src_dir, zip_path):
    """Write every file under ``src_dir`` into a deflated ZIP at ``zip_path``."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _, fs in os.walk(src_dir):
            for f in fs:
                fp = os.path.join(root, f)
                zipf.write(fp, os.path.relpath(fp, src_dir))


def _process_images_impl(
    files, prompt, guidance_scale, num_steps, progress=gr.Progress()
):
    """
    Shared implementation behind the ``process_images`` entry point.

    Args:
        files: Uploaded files (objects with ``.name`` or plain paths). Images and
            ``.zip``/``.7z`` archives are processed; anything else is ignored.
        prompt: Text prompt; empty or whitespace-only falls back to ``DEFAULT_PROMPT``.
        guidance_scale: Passed to the pipeline (coerced to ``float``).
        num_steps: Number of inference steps (coerced to ``int``).
        progress: Gradio progress tracker.

    Returns 4 outputs:
        gallery, files_download, zip_download, summary

    Errors during processing are reported in ``summary`` (the traceback is
    printed to stderr) instead of being raised.
    """
    if not files:
        return _failure_outputs("Please upload at least one file.")
    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT
    guidance_scale = float(guidance_scale)
    num_steps = int(num_steps)

    # Temp for extraction/staging input; always removed in ``finally``.
    temp_dir = tempfile.mkdtemp(prefix="flux_in_")
    # Per-run output directory. It must survive when its files are handed to
    # Gradio as downloads; in every other outcome it is removed in ``finally``.
    run_id = uuid.uuid4().hex[:10]
    output_dir = tempfile.mkdtemp(prefix=f"flux_results_{run_id}_")
    keep_output_dir = False
    try:
        progress(0.0, desc="Preparing files...")
        all_images, has_archive = _collect_images(files, temp_dir, progress)
        if not all_images:
            return _failure_outputs("No valid images found in uploaded files.")

        total_images = len(all_images)
        progress(0.10, desc=f"Found {total_images} images. Loading model...")
        pipe = _get_pipe()

        results = []
        metrics_lines = []
        used_out_paths = set()
        progress(0.15, desc="Enhancing images...")
        for idx, (img_path, rel_path) in enumerate(all_images):
            progress(
                0.15 + 0.75 * (idx / max(1, total_images)),
                desc=f"Processing {idx+1}/{total_images}: {Path(img_path).name}",
            )
            input_image = load_image(img_path)
            enhanced_image = pipe(
                image=input_image,
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
            ).images[0]
            psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)

            # Mirror the input's relative location and add a ``_flux`` suffix.
            out_path = os.path.join(output_dir, rel_path)
            out_dir = os.path.dirname(out_path)
            os.makedirs(out_dir, exist_ok=True)
            out_name = Path(out_path).stem + "_flux" + Path(out_path).suffix
            out_path = _unique_output_path(
                os.path.join(out_dir, out_name), used_out_paths
            )
            enhanced_image.save(out_path)
            results.append(
                {
                    "original": input_image,
                    "enhanced": enhanced_image,
                    "filename": rel_path,
                    "output_path": out_path,
                    "psnr": psnr,
                    "ssim": ssim,
                }
            )
            metrics_lines.append(f"{rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}")

        avg_psnr = float(np.mean([r["psnr"] for r in results])) if results else 0.0
        avg_ssim = float(np.mean([r["ssim"] for r in results])) if results else 0.0
        summary = (
            "✅ Processing completed!\n\n"
            f"Environment: {'Hugging Face Spaces' if IS_HF_SPACES else 'Local'}\n"
            f"Spaces module: {'Installed' if HAS_SPACES else 'Not installed'}\n"
            f"GPU used: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}\n\n"
            f"Total images processed: {total_images}\n"
            f"Average PSNR: {avg_psnr:.2f} dB\n"
            f"Average SSIM: {avg_ssim:.4f}\n\n"
            "Individual metrics:\n" + "\n".join(metrics_lines)
        )

        # Gallery preview: original/enhanced pairs for the first 10 results.
        gallery_images = []
        for r in results[:10]:
            gallery_images.append((r["original"], f"Original: {r['filename']}"))
            gallery_images.append(
                (
                    r["enhanced"],
                    f"Enhanced: {r['filename']}\nPSNR: {r['psnr']:.2f} dB, SSIM: {r['ssim']:.4f}",
                )
            )

        if has_archive:
            progress(0.92, desc="Packaging ZIP...")
            output_zip_path = os.path.join(
                tempfile.gettempdir(), f"enhanced_images_flux_{run_id}.zip"
            )
            _zip_directory(output_dir, output_zip_path)
            progress(1.0, desc="Done!")
            # Only the ZIP (written outside ``output_dir``) is offered for
            # download, so ``output_dir`` can be cleaned up in ``finally``.
            return (
                gallery_images,
                gr.update(value=None, visible=False),
                gr.update(value=output_zip_path, visible=True),
                summary,
            )

        # Individual result files are downloaded straight from ``output_dir``.
        keep_output_dir = True
        enhanced_paths = [r["output_path"] for r in results]
        progress(1.0, desc="Done!")
        return (
            gallery_images,
            gr.update(value=enhanced_paths, visible=True),
            gr.update(value=None, visible=False),
            summary,
        )
    except Exception as e:
        import traceback

        # Report the failure in the UI, but keep the full traceback in the logs.
        traceback.print_exc()
        return _failure_outputs(f"Error during processing: {str(e)}")
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
        if not keep_output_dir:
            shutil.rmtree(output_dir, ignore_errors=True)
# =========================
# CRITICAL: Always define a @spaces.GPU function at top level
# (so the ZeroGPU startup scanner can always find it)
# =========================
@spaces.GPU(duration=1500)
def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progress()):
    """GPU entry point wired to the "Enhance Images" button.

    Kept as a thin top-level wrapper so the ``@spaces.GPU`` decorator is applied
    at module level; all work is delegated to ``_process_images_impl``.
    """
    outputs = _process_images_impl(
        files,
        prompt,
        guidance_scale,
        num_steps,
        progress,
    )
    return outputs
# =========================
# Gradio UI
# =========================
with gr.Blocks(title="Flux Microscopy Image Enhancement") as demo:
    # Header: usage notes plus the runtime environment detected at import time.
    gr.Markdown(
        f"""
# 🔬 Flux Microscopy Image Enhancement
Upload microscopy images (individual files or compressed archives) for AI-powered enhancement.
**Supported formats:**
- Images: JPG, PNG, BMP, TIFF
- Archives: ZIP, 7Z (will process all images inside)
**Download behavior:**
- Upload **only images** → download enhanced **image files directly** (`*_flux` suffix)
- Upload **ZIP/7Z** → download **one ZIP** (images inside use `*_flux` suffix)
**Runtime info:**
- Environment: **{"Hugging Face Spaces" if IS_HF_SPACES else "Local"}**
- Spaces module: **{"Installed" if HAS_SPACES else "Not installed (using dummy)"}**
"""
    )

    with gr.Row():
        # Left column: uploads, prompt and generation parameters.
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload Images or Archive (ZIP/7Z)",
                file_count="multiple",
                file_types=["image", ".zip", ".7z"],
            )
            prompt_input = gr.Textbox(
                label="Enhancement Prompt",
                placeholder="Enter custom prompt or leave empty for default",
                value=DEFAULT_PROMPT,
                lines=3,
            )
            gr.Markdown("### Enhancement Parameters")
            guidance_scale_input = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=GUIDANCE_SCALE,
                step=0.1,
                label="Guidance Scale",
                info="Lower = more conservative, higher = more creative",
            )
            num_steps_input = gr.Slider(
                minimum=10,
                maximum=50,
                value=NUM_INFERENCE_STEPS,
                step=1,
                label="Inference Steps",
                info="More steps = better quality but slower",
            )
            process_btn = gr.Button("🚀 Enhance Images", variant="primary", size="lg")

        # Right column: preview, the two (initially hidden) download
        # components and the metrics summary.
        with gr.Column(scale=2):
            gallery_output = gr.Gallery(
                label="Results Preview (Original vs Enhanced)",
                columns=2,
                rows=2,
                height="auto",
                object_fit="contain",
            )
            files_output = gr.Files(
                label="📥 Download Enhanced Images (Files)", visible=False
            )
            zip_output = gr.File(
                label="📥 Download Enhanced Images (ZIP)", visible=False
            )
            summary_output = gr.Textbox(
                label="Processing Summary & Metrics",
                lines=10,
                max_lines=20,
            )

    process_btn.click(
        fn=process_images,
        inputs=[file_input, prompt_input, guidance_scale_input, num_steps_input],
        outputs=[gallery_output, files_output, zip_output, summary_output],
    )

    # Footer: defaults are rendered from the config constants so this text
    # cannot drift from the values the sliders actually start at.
    gr.Markdown(
        f"""
---
### Default Parameters
- **Guidance Scale**: {GUIDANCE_SCALE} (conservative for natural enhancement)
- **Inference Steps**: {NUM_INFERENCE_STEPS} (balanced quality and speed)
### Quality Metrics
- **PSNR**: Higher is better
- **SSIM**: Closer to 1.0 is better
"""
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server without a public share link.
    demo.queue().launch(share=False)