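# Page header captured with the source (not code; kept as comments so the module parses):
# yichuan-huang's picture
# Add CLI support for microscopy image enhancement and update README
# 8a7a521
# raw
# history blame
# 14.7 kB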
import os
import zipfile
import tempfile
import shutil
import uuid
from pathlib import Path
import gradio as gr
import torch
import numpy as np
from PIL import Image
from diffusers import Flux2Pipeline
from diffusers.utils import load_image
import py7zr
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.util import img_as_float
# =========================
# Config
# =========================
# Prompt used whenever the user leaves the prompt box empty.
DEFAULT_PROMPT = (
    "enhance microscopy image with subtle improvements, "
    "gently increase cellular boundary clarity, "
    "preserve original morphological structure, "
    "maintain authentic texture patterns, "
    "minimal noise reduction while keeping fine details intact"
)

# Default sampler settings; the UI sliders start from these values.
GUIDANCE_SCALE = 2.0
NUM_INFERENCE_STEPS = 30

# Lower-case file suffixes (with the dot) that are treated as images.
IMAGE_EXTENSIONS = {
    ".bmp",
    ".jpeg",
    ".jpg",
    ".png",
    ".tif",
    ".tiff",
}

# NOTE:
# - This is the quantized 4-bit (bitsandbytes) model, which REQUIRES GPU at load time.
# - On HF Spaces ZeroGPU, you MUST only load it inside a @spaces.GPU function.
MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
TORCH_DTYPE = torch.bfloat16
# =========================
# HF Spaces env detection + safe "spaces" decorator
# =========================
def _is_hf_space_env() -> bool:
    """Report whether the process looks like it runs on Hugging Face Spaces.

    Returns True when ``SPACE_ID`` or ``HF_SPACE_ID`` is set to a non-empty
    value, or when ``SYSTEM`` equals ``"spaces"`` (case-insensitive);
    otherwise False.
    """
    for var_name in ("SPACE_ID", "HF_SPACE_ID"):
        if os.getenv(var_name):
            return True
    system_value = os.getenv("SYSTEM", "")
    return system_value.lower() == "spaces"
# Detected once at import time and reused everywhere below.
IS_HF_SPACES: bool = _is_hf_space_env()

# ``spaces`` (ZeroGPU helpers) is optional: any failure to import it simply
# leaves ``spaces`` as None, which turns gpu_decorator into a no-op.
spaces = None
try:
    import spaces  # available on HF Spaces
except Exception:
    spaces = None
def gpu_decorator(duration: int = 180):
    """Build a decorator that requests a ZeroGPU slot only when on HF Spaces.

    When running on Spaces and the ``spaces`` module was imported and exposes
    ``GPU``, this returns ``spaces.GPU(duration=duration)``. In every other case
    it returns an identity decorator, so the function runs unchanged (locally
    it uses whatever GPU is available).
    """
    use_zerogpu = IS_HF_SPACES and spaces is not None and hasattr(spaces, "GPU")
    if not use_zerogpu:
        return lambda fn: fn
    return spaces.GPU(duration=duration)
# =========================
# Global cached pipeline
# =========================
# Filled lazily by ``_get_pipe()`` on first use, then returned by later calls.
_pipe: "Flux2Pipeline | None" = None
def calculate_psnr_ssim(original: Image.Image, enhanced: Image.Image):
    """Calculate PSNR and SSIM of *enhanced* measured against *original*.

    Both images are converted to floats in [0, 1] with ``img_as_float`` before
    comparison, so ``data_range=1.0`` is used for both metrics.

    Alignment rules:
    - If the PIL modes differ, both images are converted to ``"RGB"`` so the
      arrays have the same number of channels.
    - If the sizes differ, *enhanced* is resampled (bicubic) to the size of
      *original*. The pipeline may return an image with different dimensions
      than its input. Resampling keeps pixels spatially aligned, whereas the
      previous top-left crop compared unrelated regions of a rescaled image.

    Returns:
        ``(psnr, ssim)`` as Python floats. PSNR may be ``inf`` when the images
        are identical.
    """
    if original.mode != enhanced.mode:
        original = original.convert("RGB")
        enhanced = enhanced.convert("RGB")
    if enhanced.size != original.size:
        enhanced = enhanced.resize(original.size, Image.BICUBIC)

    orig_float = img_as_float(np.asarray(original))
    enh_float = img_as_float(np.asarray(enhanced))

    psnr = peak_signal_noise_ratio(orig_float, enh_float, data_range=1.0)
    # 3-D arrays are (H, W, C): tell SSIM the channels live on the last axis.
    channel_axis = -1 if orig_float.ndim == 3 else None
    ssim = structural_similarity(
        orig_float, enh_float, data_range=1.0, channel_axis=channel_axis
    )
    return float(psnr), float(ssim)
def _ensure_members_inside(names, extract_to: str):
    """Raise ValueError if any archive member name resolves outside *extract_to*."""
    root = os.path.realpath(extract_to)
    for name in names:
        target = os.path.realpath(os.path.join(root, name))
        if os.path.commonpath([root, target]) != root:
            raise ValueError(f"Unsafe path in archive: {name}")


def extract_archive(archive_path: str, extract_to: str):
    """Extract a ``.zip`` or ``.7z`` archive into *extract_to*.

    The format is chosen from the file extension (case-insensitive).
    Archives come from user uploads and are therefore untrusted.
    ``zipfile.extractall`` already strips absolute paths and ``..`` components
    from member names. For 7z, member names are checked explicitly before
    extraction, and any name resolving outside *extract_to* is rejected.

    Raises:
        ValueError: unsupported extension, or a 7z member path that would
            escape *extract_to*.
    """
    file_ext = Path(archive_path).suffix.lower()
    if file_ext == ".zip":
        with zipfile.ZipFile(archive_path, "r") as z:
            z.extractall(extract_to)
    elif file_ext == ".7z":
        with py7zr.SevenZipFile(archive_path, mode="r") as a:
            # Reject path-traversal members up front; do not rely on the
            # installed py7zr version to do it.
            _ensure_members_inside(a.getnames(), extract_to)
            a.extractall(path=extract_to)
    else:
        raise ValueError(f"Unsupported archive format: {file_ext}")
def find_images(directory: str, extensions=None):
    """Recursively find all images under *directory*.

    Args:
        directory: root folder to walk.
        extensions: optional iterable of suffixes (with the dot) to accept.
            Defaults to ``IMAGE_EXTENSIONS``. File suffixes are compared
            lower-cased.

    Returns:
        Sorted list of full file paths, so processing order is deterministic.
        macOS archive metadata is skipped: ``__MACOSX`` folders and ``._*``
        AppleDouble files. These keep the original file's suffix but are
        typically not decodable images.
    """
    wanted = (
        IMAGE_EXTENSIONS
        if extensions is None
        else {ext.lower() for ext in extensions}
    )
    image_files = []
    for root, dirs, files in os.walk(directory):
        # Pruning ``dirs`` in place stops os.walk from descending into them.
        dirs[:] = [d for d in dirs if d != "__MACOSX"]
        image_files.extend(
            os.path.join(root, name)
            for name in files
            if not name.startswith("._") and Path(name).suffix.lower() in wanted
        )
    return sorted(image_files)
def _get_pipe():
    """Return the process-wide Flux2 pipeline, loading it on first use.

    - On HF Spaces ZeroGPU: MUST be called inside a @spaces.GPU runtime
      (``gpu_decorator`` handles this for the caller).
    - Locally: uses the local CUDA GPU.

    The global cache is set only after the pipeline has been loaded AND moved
    to CUDA. If either step raises, nothing is cached and the next call
    retries, rather than handing back a half-initialised pipeline.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    global _pipe
    if _pipe is not None:
        return _pipe
    if not torch.cuda.is_available():
        raise RuntimeError(
            "No CUDA GPU detected. This 4-bit bnb model requires GPU to load/run."
        )
    pipe = Flux2Pipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
    )
    pipe.to("cuda")
    _pipe = pipe
    return _pipe
# Number of (original, enhanced) pairs shown in the preview gallery.
_GALLERY_PAIRS = 10


def _empty_response(message):
    """Return the 4-tuple for a run with no results: no gallery, both downloads hidden."""
    return (
        None,
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        message,
    )


def _collect_images(files, temp_dir, progress):
    """Stage the uploads and list every image to process.

    Each archive (.zip/.7z) is extracted into its own sub-directory of
    *temp_dir*. The folder name is prefixed with the upload's index, so two
    archives sharing a stem (e.g. ``data.zip`` and ``data.7z``) never share a
    folder. A shared folder would make the second scan pick up the first
    archive's images again. Uploads that are neither archives nor images are
    ignored.

    Returns:
        ``(all_images, has_archive)``. ``all_images`` is a list of
        ``(img_path, rel_path)`` tuples, where ``rel_path`` is the path inside
        the archive or the bare file name of a directly uploaded image.
    """
    all_images = []
    has_archive = False
    for upload_idx, file_obj in enumerate(files):
        file_path = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
        file_ext = Path(file_path).suffix.lower()
        if file_ext in (".zip", ".7z"):
            has_archive = True
            progress(0.05, desc=f"Extracting: {Path(file_path).name} ...")
            extract_dir = os.path.join(
                temp_dir, f"{upload_idx}_{Path(file_path).stem}"
            )
            os.makedirs(extract_dir, exist_ok=True)
            extract_archive(file_path, extract_dir)
            for img_path in find_images(extract_dir):
                all_images.append((img_path, os.path.relpath(img_path, extract_dir)))
        elif file_ext in IMAGE_EXTENSIONS:
            all_images.append((file_path, Path(file_path).name))
    return all_images, has_archive


def _claim_output_path(output_dir, rel_path, taken):
    """Return an output path for *rel_path* that no earlier result in this run used.

    The relative folder structure is preserved and ``_flux`` is appended to
    the stem (``sub/img.png`` -> ``sub/img_flux.png``). If two inputs map to
    the same path, a ``_1``, ``_2``, ... counter is appended so one result never
    overwrites another. The path is recorded in *taken* and its parent
    directory is created.
    """
    rel = Path(rel_path)
    base = Path(output_dir) / rel.parent / f"{rel.stem}_flux{rel.suffix}"
    candidate = base
    counter = 1
    while str(candidate) in taken:
        candidate = base.with_name(f"{base.stem}_{counter}{base.suffix}")
        counter += 1
    taken.add(str(candidate))
    candidate.parent.mkdir(parents=True, exist_ok=True)
    return str(candidate)


def _build_gallery(previews):
    """Interleave captioned original/enhanced images for ``gr.Gallery``."""
    gallery_images = []
    for original, enhanced, filename, psnr, ssim in previews:
        gallery_images.append((original, f"Original: {filename}"))
        gallery_images.append(
            (
                enhanced,
                f"Enhanced: {filename}\nPSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}",
            )
        )
    return gallery_images


def _zip_directory(src_dir, zip_path):
    """Write every file under *src_dir* into *zip_path*, with paths relative to *src_dir*."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _, names in os.walk(src_dir):
            for name in names:
                full_path = os.path.join(root, name)
                zipf.write(full_path, os.path.relpath(full_path, src_dir))


@gpu_decorator(duration=180)
def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progress()):
    """Enhance every uploaded image, including images inside uploaded archives.

    Args:
        files: uploads from ``gr.File``; each is a path string or an object
            whose ``.name`` is a path.
        prompt: text prompt. Empty or whitespace-only falls back to
            ``DEFAULT_PROMPT``.
        guidance_scale: passed to the pipeline (cast to float).
        num_steps: passed as ``num_inference_steps`` (cast to int).
        progress: Gradio progress tracker.

    Returns:
        A 4-tuple for (gallery, files download, zip download, summary):
        - gallery preview (first 10 pairs), or None
        - files download (shown when only images were uploaded)
        - zip download (shown when any archive was uploaded)
        - summary text, or an error message

    Exceptions are not raised to the caller; they are reported in the summary.
    Only files that are served for download stay on disk. All other staging
    directories are removed before returning.
    """
    if not files:
        return _empty_response("Please upload at least one file.")

    if not prompt or prompt.strip() == "":
        prompt = DEFAULT_PROMPT
    guidance_scale = float(guidance_scale)
    num_steps = int(num_steps)

    # Temp for extraction / staging input (always removed).
    temp_dir = tempfile.mkdtemp(prefix="flux_in_")
    run_id = uuid.uuid4().hex[:10]
    # Enhanced files are written here. When they are served directly, Gradio
    # needs them to stay on disk, so this directory is kept only in that case.
    output_dir = tempfile.mkdtemp(prefix=f"flux_results_{run_id}_")
    keep_output = False
    try:
        progress(0.0, desc="Preparing files...")
        all_images, has_archive = _collect_images(files, temp_dir, progress)
        if not all_images:
            return _empty_response("No valid images found in uploaded files.")

        total_images = len(all_images)
        progress(0.10, desc=f"Found {total_images} images. Loading model...")
        # Load pipeline (inside GPU runtime on Spaces; local GPU otherwise)
        pipe = _get_pipe()

        results = []  # lightweight per-image records (no pixel data)
        previews = []  # PIL images are retained only for the gallery pairs
        taken_paths = set()
        progress(0.15, desc="Enhancing images...")
        for idx, (img_path, rel_path) in enumerate(all_images):
            progress(
                0.15 + 0.75 * (idx / max(1, total_images)),
                desc=f"Processing {idx+1}/{total_images}: {Path(img_path).name}",
            )
            input_image = load_image(img_path)
            enhanced_image = pipe(
                image=input_image,
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
            ).images[0]
            psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)

            out_path = _claim_output_path(output_dir, rel_path, taken_paths)
            enhanced_image.save(out_path)

            results.append(
                {"filename": rel_path, "output_path": out_path, "psnr": psnr, "ssim": ssim}
            )
            if len(previews) < _GALLERY_PAIRS:
                previews.append((input_image, enhanced_image, rel_path, psnr, ssim))

        metrics_lines = [
            f"{r['filename']}: PSNR={r['psnr']:.2f} dB, SSIM={r['ssim']:.4f}"
            for r in results
        ]
        avg_psnr = float(np.mean([r["psnr"] for r in results])) if results else 0.0
        avg_ssim = float(np.mean([r["ssim"] for r in results])) if results else 0.0
        summary = (
            "✅ Processing completed!\n\n"
            f"Environment: {'Hugging Face Spaces' if IS_HF_SPACES else 'Local'}\n"
            f"GPU available: {torch.cuda.is_available()}\n\n"
            f"Total images processed: {total_images}\n"
            f"Average PSNR: {avg_psnr:.2f} dB\n"
            f"Average SSIM: {avg_ssim:.4f}\n\n"
            "Individual metrics:\n" + "\n".join(metrics_lines)
        )
        gallery_images = _build_gallery(previews)

        # Decide download output:
        # - If user uploaded any archive -> provide ZIP (output_dir not needed after)
        # - Else -> provide enhanced files directly (output_dir must survive)
        if has_archive:
            progress(0.92, desc="Packaging ZIP...")
            output_zip_path = os.path.join(
                tempfile.gettempdir(), f"enhanced_images_flux_{run_id}.zip"
            )
            _zip_directory(output_dir, output_zip_path)
            progress(1.0, desc="Done!")
            return (
                gallery_images,
                gr.update(value=None, visible=False),  # files hidden
                gr.update(value=output_zip_path, visible=True),  # zip shown
                summary,
            )

        progress(1.0, desc="Done!")
        keep_output = True
        return (
            gallery_images,
            gr.update(value=[r["output_path"] for r in results], visible=True),  # files shown
            gr.update(value=None, visible=False),  # zip hidden
            summary,
        )
    except Exception as e:
        return _empty_response(f"Error during processing: {str(e)}")
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
        if not keep_output:
            shutil.rmtree(output_dir, ignore_errors=True)
# =========================
# Gradio UI
# =========================
with gr.Blocks(title="Flux Microscopy Image Enhancement") as demo:
    # Blank lines separate the markdown blocks. Without them, a line that
    # follows a list item is folded into that item as a lazy continuation.
    gr.Markdown(
        "# 🔬 Flux Microscopy Image Enhancement\n\n"
        "Upload microscopy images (individual files or compressed archives) "
        "for AI-powered enhancement.\n\n"
        "**Supported formats:**\n\n"
        "- Images: JPG, PNG, BMP, TIFF\n"
        "- Archives: ZIP, 7Z (will process all images inside)\n\n"
        "**Download behavior:**\n\n"
        "- If you upload **only images** → you can download the enhanced "
        "**image files directly** (`*_flux` suffix)\n"
        "- If you upload **a ZIP/7Z** → you can download **one ZIP** "
        "(images inside use `*_flux` suffix)\n\n"
        "**Runtime detection:**\n\n"
        f"- Detected environment: **{'Hugging Face Spaces' if IS_HF_SPACES else 'Local'}**\n"
    )
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload Images or Archive (ZIP/7Z)",
                file_count="multiple",
                file_types=["image", ".zip", ".7z"],
            )
            prompt_input = gr.Textbox(
                label="Enhancement Prompt",
                placeholder="Enter custom prompt or leave empty for default",
                value=DEFAULT_PROMPT,
                lines=3,
            )
            gr.Markdown("### Enhancement Parameters")
            guidance_scale_input = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=GUIDANCE_SCALE,
                step=0.1,
                label="Guidance Scale",
                info="Lower = more conservative, higher = more creative",
            )
            num_steps_input = gr.Slider(
                minimum=10,
                maximum=50,
                value=NUM_INFERENCE_STEPS,
                step=1,
                label="Inference Steps",
                info="More steps = better quality but slower",
            )
            process_btn = gr.Button("🚀 Enhance Images", variant="primary", size="lg")
        with gr.Column(scale=2):
            gallery_output = gr.Gallery(
                label="Results Preview (Original vs Enhanced)",
                columns=2,
                rows=2,
                height="auto",
                object_fit="contain",
            )
            files_output = gr.Files(
                label="📥 Download Enhanced Images (Files)", visible=False
            )
            zip_output = gr.File(
                label="📥 Download Enhanced Images (ZIP)", visible=False
            )
            summary_output = gr.Textbox(
                label="Processing Summary & Metrics",
                lines=10,
                max_lines=20,
            )
    process_btn.click(
        fn=process_images,
        inputs=[file_input, prompt_input, guidance_scale_input, num_steps_input],
        outputs=[gallery_output, files_output, zip_output, summary_output],
    )
    # Defaults are read from the config constants so the text cannot drift
    # from the slider values.
    gr.Markdown(
        "---\n\n"
        "### Default Parameters\n\n"
        f"- **Guidance Scale**: {GUIDANCE_SCALE} (conservative for natural enhancement)\n"
        f"- **Inference Steps**: {NUM_INFERENCE_STEPS} (balanced quality and speed)\n\n"
        "### Quality Metrics\n\n"
        "- **PSNR** (Peak Signal-to-Noise Ratio): Higher is better\n"
        "- **SSIM** (Structural Similarity Index): Closer to 1.0 is better\n"
    )
if __name__ == "__main__":
    # Enable request queuing (recommended for Spaces), then serve locally
    # without a public share link. ``queue()`` returns the Blocks instance.
    demo.queue().launch(share=False)