"""Gradio app that enhances microscopy images with a FLUX.2 image pipeline.

Users upload individual images and/or ZIP/7Z archives. Every image is run
through the diffusion pipeline, PSNR/SSIM against the original are reported,
and the enhanced files are offered for download (individual files when only
images were uploaded, a single ZIP when any archive was uploaded).
"""

import logging
import os
import sys

logger = logging.getLogger(__name__)


# =========================
# CRITICAL: Import spaces FIRST before any CUDA-related packages
# ZeroGPU requires spaces to be imported before torch/diffusers
# =========================
def _is_hf_spaces_env() -> bool:
    """Detect if running in a Hugging Face Spaces environment.

    Returns:
        True when ``SYSTEM`` equals ``"spaces"`` (case-insensitive) or when
        either ``SPACE_ID`` or ``HF_SPACE_ID`` is set to a non-empty value.
    """
    if os.getenv("SYSTEM", "").lower() == "spaces":
        return True
    return bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"))


IS_HF_SPACES = _is_hf_spaces_env()

try:
    import spaces

    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    # Create dummy module for local environment
    from types import ModuleType

    class _SpacesDummy:
        """Local stand-in exposing a no-op ``GPU`` decorator."""

        @staticmethod
        def GPU(duration: int = 180, **kwargs):
            """No-op decorator for local runtime.

            Works both as ``@spaces.GPU(duration=...)`` and as a bare
            ``@spaces.GPU``; in the bare form the decorated function arrives
            in the first positional slot and is returned unchanged.
            """
            if callable(duration):
                # Bare usage: ``duration`` is actually the decorated function.
                return duration

            def decorator(fn):
                return fn

            return decorator

    spaces = ModuleType("spaces")
    spaces.GPU = _SpacesDummy.GPU
    # Register the dummy so later ``import spaces`` statements resolve to it.
    sys.modules["spaces"] = spaces

# =========================
# Now import CUDA-related packages
# =========================
import zipfile
import tempfile
import shutil
import uuid
from pathlib import Path

import gradio as gr
import torch
import numpy as np
from PIL import Image
from diffusers import Flux2Pipeline
from diffusers.utils import load_image
import py7zr
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.util import img_as_float

# =========================
# Config
# =========================
DEFAULT_PROMPT = (
    "enhance microscopy image with subtle improvements, gently increase cellular boundary clarity, "
    "preserve original morphological structure, maintain authentic texture patterns, "
    "minimal noise reduction while keeping fine details intact"
)
GUIDANCE_SCALE = 2.0
NUM_INFERENCE_STEPS = 30
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
ARCHIVE_EXTENSIONS = {".zip", ".7z"}

# Number of original/enhanced pairs shown in the gallery preview.
GALLERY_PREVIEW_LIMIT = 10

# Quantized 4-bit model (requires GPU)
MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
TORCH_DTYPE = torch.bfloat16

# Global cached pipeline
_pipe = None


def calculate_psnr_ssim(original: Image.Image, enhanced: Image.Image):
    """Calculate PSNR and SSIM between original and enhanced images.

    Both images are converted to float arrays. When their shapes differ, both
    are cropped (from the top-left corner) to the overlapping height/width
    before the metrics are computed.

    Args:
        original: Reference image.
        enhanced: Image to compare against the reference.

    Returns:
        ``(psnr, ssim)`` as Python floats.
    """
    orig_float = img_as_float(np.array(original))
    enh_float = img_as_float(np.array(enhanced))

    # Ensure same shape (crop to min overlap)
    if orig_float.shape != enh_float.shape:
        min_h = min(orig_float.shape[0], enh_float.shape[0])
        min_w = min(orig_float.shape[1], enh_float.shape[1])
        orig_float = orig_float[:min_h, :min_w]
        enh_float = enh_float[:min_h, :min_w]

    psnr = peak_signal_noise_ratio(orig_float, enh_float, data_range=1.0)

    ssim_kwargs = {"data_range": 1.0}
    if orig_float.ndim == 3:
        # Colour image: the last axis holds the channels.
        ssim_kwargs["channel_axis"] = -1
    ssim = structural_similarity(orig_float, enh_float, **ssim_kwargs)

    return float(psnr), float(ssim)


def extract_archive(archive_path: str, extract_to: str):
    """Extract a zip or 7z archive into ``extract_to``.

    Args:
        archive_path: Path to the archive; the format is chosen from its
            (case-insensitive) file extension.
        extract_to: Destination directory.

    Raises:
        ValueError: If the extension is neither ``.zip`` nor ``.7z``.
    """
    file_ext = Path(archive_path).suffix.lower()
    if file_ext == ".zip":
        with zipfile.ZipFile(archive_path, "r") as z:
            z.extractall(extract_to)
    elif file_ext == ".7z":
        with py7zr.SevenZipFile(archive_path, mode="r") as a:
            a.extractall(path=extract_to)
    else:
        raise ValueError(f"Unsupported archive format: {file_ext}")


def find_images(directory: str):
    """Recursively find all images in a directory.

    Files are matched by extension against ``IMAGE_EXTENSIONS``. macOS
    archive metadata (``__MACOSX`` folders and ``._*`` AppleDouble files) is
    skipped: those entries carry image extensions but are not decodable
    images and would otherwise abort the whole batch. Results are returned in
    a deterministic (sorted) traversal order.

    Args:
        directory: Root directory to scan.

    Returns:
        List of image file paths.
    """
    image_files = []
    for root, dirs, files in os.walk(directory):
        # Prune in place so os.walk neither descends into metadata folders
        # nor visits sub-directories in filesystem-dependent order.
        dirs[:] = sorted(d for d in dirs if d != "__MACOSX")
        for f in sorted(files):
            if f.startswith("._"):
                continue
            if Path(f).suffix.lower() in IMAGE_EXTENSIONS:
                image_files.append(os.path.join(root, f))
    return image_files


def _unique_rel_path(rel_path: str, used: set) -> str:
    """Return ``rel_path``, or a ``_N``-suffixed variant not yet in ``used``.

    Keeps two inputs that share a relative name (same basename uploaded
    twice, or identical paths inside two archives) from overwriting each
    other in the output directory. The chosen name is recorded in ``used``.
    """
    stem, suffix = os.path.splitext(rel_path)
    candidate = rel_path
    counter = 1
    while os.path.normcase(os.path.normpath(candidate)) in used:
        candidate = f"{stem}_{counter}{suffix}"
        counter += 1
    used.add(os.path.normcase(os.path.normpath(candidate)))
    return candidate


def _get_pipe():
    """
    Lazy-load the pipeline.

    IMPORTANT:
    - On HF ZeroGPU: must be called inside a @spaces.GPU-decorated function runtime.
    - Locally: uses local CUDA.

    Returns:
        The cached ``Flux2Pipeline`` moved to CUDA.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    global _pipe
    if _pipe is None:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA GPU is not available in the current runtime. "
                "This bnb-4bit model requires GPU. "
                "On HF ZeroGPU, ensure inference is inside a @spaces.GPU function."
            )
        pipe = Flux2Pipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=TORCH_DTYPE,
        )
        pipe.to("cuda")
        # Cache only after the device move succeeded, so a failed load is
        # retried on the next call instead of returning a CPU-resident pipe.
        _pipe = pipe
    return _pipe


def _process_images_impl(
    files, prompt, guidance_scale, num_steps, progress=gr.Progress()
):
    """
    Shared implementation.
    Returns 4 outputs: gallery, files_download, zip_download, summary

    Args:
        files: Uploaded items; each is either an object with a ``name``
            attribute holding a path, or something whose ``str()`` is a path.
        prompt: Enhancement prompt; empty/blank falls back to DEFAULT_PROMPT.
        guidance_scale: Guidance scale, coerced to ``float``.
        num_steps: Number of inference steps, coerced to ``int``.
        progress: Gradio progress reporter.

    Returns:
        Tuple ``(gallery_images, files_update, zip_update, summary_text)``.
        When any archive was uploaded the results are delivered as one ZIP,
        otherwise as a list of individual files. On failure the gallery is
        ``None``, both download components are hidden and the summary holds
        the error message.
    """
    if not files:
        return (
            None,
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            "Please upload at least one file.",
        )

    if not prompt or prompt.strip() == "":
        prompt = DEFAULT_PROMPT

    guidance_scale = float(guidance_scale)
    num_steps = int(num_steps)

    # Temp for extraction/staging input
    temp_dir = tempfile.mkdtemp(prefix="flux_in_")

    # Output MUST remain for Gradio downloads (direct-files mode)
    run_id = uuid.uuid4().hex[:10]
    output_dir = tempfile.mkdtemp(prefix=f"flux_results_{run_id}_")

    has_archive = False
    # Set only when the returned outputs reference files inside output_dir.
    keep_output = False

    try:
        progress(0.0, desc="Preparing files...")

        all_images = []  # (img_path, rel_path)
        used_rel_paths = set()

        for file_idx, file_obj in enumerate(files):
            file_path = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
            file_ext = Path(file_path).suffix.lower()

            if file_ext in ARCHIVE_EXTENSIONS:
                has_archive = True
                progress(0.05, desc=f"Extracting: {Path(file_path).name} ...")
                # Index-prefixed so archives sharing a stem never extract
                # into the same directory (which would re-process images).
                extract_dir = os.path.join(
                    temp_dir, f"{file_idx}_{Path(file_path).stem}"
                )
                os.makedirs(extract_dir, exist_ok=True)
                extract_archive(file_path, extract_dir)

                for img_path in find_images(extract_dir):
                    rel_path = os.path.relpath(img_path, extract_dir)
                    all_images.append(
                        (img_path, _unique_rel_path(rel_path, used_rel_paths))
                    )

            elif file_ext in IMAGE_EXTENSIONS:
                all_images.append(
                    (file_path, _unique_rel_path(Path(file_path).name, used_rel_paths))
                )

        if not all_images:
            return (
                None,
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False),
                "No valid images found in uploaded files.",
            )

        total_images = len(all_images)
        progress(0.10, desc=f"Found {total_images} images. Loading model...")

        pipe = _get_pipe()

        enhanced_paths = []
        psnr_values = []
        ssim_values = []
        metrics_lines = []
        gallery_images = []

        progress(0.15, desc="Enhancing images...")

        for idx, (img_path, rel_path) in enumerate(all_images):
            progress(
                0.15 + 0.75 * (idx / max(1, total_images)),
                desc=f"Processing {idx+1}/{total_images}: {Path(img_path).name}",
            )

            input_image = load_image(img_path)

            enhanced_image = pipe(
                image=input_image,
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
            ).images[0]

            psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)

            # Mirror the relative layout under output_dir, adding _flux suffix
            rel = Path(rel_path)
            out_path = os.path.join(
                output_dir, str(rel.with_name(rel.stem + "_flux" + rel.suffix))
            )
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            enhanced_image.save(out_path)

            enhanced_paths.append(out_path)
            psnr_values.append(psnr)
            ssim_values.append(ssim)
            metrics_lines.append(f"{rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}")

            # Gallery preview: only the previewed pairs stay in memory; the
            # other images are released once saved to disk.
            if idx < GALLERY_PREVIEW_LIMIT:
                gallery_images.append((input_image, f"Original: {rel_path}"))
                gallery_images.append(
                    (
                        enhanced_image,
                        f"Enhanced: {rel_path}\nPSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}",
                    )
                )

        avg_psnr = float(np.mean(psnr_values)) if psnr_values else 0.0
        avg_ssim = float(np.mean(ssim_values)) if ssim_values else 0.0

        summary = (
            "✅ Processing completed!\n\n"
            f"Environment: {'Hugging Face Spaces' if IS_HF_SPACES else 'Local'}\n"
            f"Spaces module: {'Installed' if HAS_SPACES else 'Not installed'}\n"
            f"GPU used: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}\n\n"
            f"Total images processed: {total_images}\n"
            f"Average PSNR: {avg_psnr:.2f} dB\n"
            f"Average SSIM: {avg_ssim:.4f}\n\n"
            "Individual metrics:\n" + "\n".join(metrics_lines)
        )

        if has_archive:
            progress(0.92, desc="Packaging ZIP...")
            output_zip_path = os.path.join(
                tempfile.gettempdir(), f"enhanced_images_flux_{run_id}.zip"
            )
            with zipfile.ZipFile(output_zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                for root, _, fs in os.walk(output_dir):
                    for f in fs:
                        fp = os.path.join(root, f)
                        arcname = os.path.relpath(fp, output_dir)
                        zipf.write(fp, arcname)

            # The ZIP lives outside output_dir, so output_dir is no longer
            # referenced by any output and is removed in ``finally``.
            progress(1.0, desc="Done!")
            return (
                gallery_images,
                gr.update(value=None, visible=False),
                gr.update(value=output_zip_path, visible=True),
                summary,
            )

        # Direct-files mode: the download component points into output_dir.
        keep_output = True
        progress(1.0, desc="Done!")
        return (
            gallery_images,
            gr.update(value=enhanced_paths, visible=True),
            gr.update(value=None, visible=False),
            summary,
        )

    except Exception as e:
        # UI boundary: report the failure to the user, keep the traceback in logs.
        logger.exception("Image enhancement failed")
        return (
            None,
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            f"Error during processing: {str(e)}",
        )
    finally:
        # Always clean up the staged input.
        shutil.rmtree(temp_dir, ignore_errors=True)
        # Keep output_dir only when its files are being served for download.
        if not keep_output:
            shutil.rmtree(output_dir, ignore_errors=True)


# =========================
# CRITICAL: Always define a @spaces.GPU function at top-level
# (ZeroGPU startup scanner will now ALWAYS find it)
# =========================
@spaces.GPU(duration=180)
def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progress()):
    """GPU-decorated entry point wired to the UI; delegates to ``_process_images_impl``."""
    return _process_images_impl(files, prompt, guidance_scale, num_steps, progress)


# =========================
# Gradio UI
# =========================
with gr.Blocks(title="Flux Microscopy Image Enhancement") as demo:
    gr.Markdown(
        f"""
        # 🔬 Flux Microscopy Image Enhancement

        Upload microscopy images (individual files or compressed archives) for AI-powered enhancement.

        **Supported formats:**
        - Images: JPG, PNG, BMP, TIFF
        - Archives: ZIP, 7Z (will process all images inside)

        **Download behavior:**
        - Upload **only images** → download enhanced **image files directly** (`*_flux` suffix)
        - Upload **ZIP/7Z** → download **one ZIP** (images inside use `*_flux` suffix)

        **Runtime info:**
        - Environment: **{"Hugging Face Spaces" if IS_HF_SPACES else "Local"}**
        - Spaces module: **{"Installed" if HAS_SPACES else "Not installed (using dummy)"}**
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload Images or Archive (ZIP/7Z)",
                file_count="multiple",
                file_types=["image", ".zip", ".7z"],
            )

            prompt_input = gr.Textbox(
                label="Enhancement Prompt",
                placeholder="Enter custom prompt or leave empty for default",
                value=DEFAULT_PROMPT,
                lines=3,
            )

            gr.Markdown("### Enhancement Parameters")

            guidance_scale_input = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=GUIDANCE_SCALE,
                step=0.1,
                label="Guidance Scale",
                info="Lower = more conservative, higher = more creative",
            )

            num_steps_input = gr.Slider(
                minimum=10,
                maximum=50,
                value=NUM_INFERENCE_STEPS,
                step=1,
                label="Inference Steps",
                info="More steps = better quality but slower",
            )

            process_btn = gr.Button("🚀 Enhance Images", variant="primary", size="lg")

        with gr.Column(scale=2):
            gallery_output = gr.Gallery(
                label="Results Preview (Original vs Enhanced)",
                columns=2,
                rows=2,
                height="auto",
                object_fit="contain",
            )

            files_output = gr.Files(
                label="📥 Download Enhanced Images (Files)", visible=False
            )
            zip_output = gr.File(
                label="📥 Download Enhanced Images (ZIP)", visible=False
            )

            summary_output = gr.Textbox(
                label="Processing Summary & Metrics",
                lines=10,
                max_lines=20,
            )

    process_btn.click(
        fn=process_images,
        inputs=[file_input, prompt_input, guidance_scale_input, num_steps_input],
        outputs=[gallery_output, files_output, zip_output, summary_output],
    )

    gr.Markdown(
        """
        ---
        ### Default Parameters
        - **Guidance Scale**: 2.0 (conservative for natural enhancement)
        - **Inference Steps**: 30 (balanced quality and speed)

        ### Quality Metrics
        - **PSNR**: Higher is better
        - **SSIM**: Closer to 1.0 is better
        """
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch(share=False)