"""Gradio app that enhances microscopy images with a FLUX.2 pipeline.

Users upload individual images and/or ZIP/7Z archives. Every image is run
through the diffusion pipeline, PSNR/SSIM against the original is computed,
and the results are offered for download (plain files for image-only
uploads, a single ZIP when any archive was uploaded).
"""

import logging
import os
import shutil
import tempfile
import uuid
import zipfile
from pathlib import Path

import gradio as gr
import numpy as np
import py7zr
import torch
from diffusers import Flux2Pipeline
from diffusers.utils import load_image
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.util import img_as_float

logger = logging.getLogger(__name__)

# =========================
# Config
# =========================
DEFAULT_PROMPT = (
    "enhance microscopy image with subtle improvements, gently increase cellular boundary clarity, "
    "preserve original morphological structure, maintain authentic texture patterns, "
    "minimal noise reduction while keeping fine details intact"
)

GUIDANCE_SCALE = 2.0
NUM_INFERENCE_STEPS = 30

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
ARCHIVE_EXTENSIONS = {".zip", ".7z"}

# Number of original/enhanced pairs shown in the gallery preview.
MAX_PREVIEW_PAIRS = 10

# NOTE:
# - This is the quantized 4-bit (bitsandbytes) model, which REQUIRES GPU at load time.
# - On HF Spaces ZeroGPU, you MUST only load it inside a @spaces.GPU function.
MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
TORCH_DTYPE = torch.bfloat16


# =========================
# HF Spaces env detection + safe "spaces" decorator
# =========================
def _is_hf_space_env() -> bool:
    """
    Detect whether we're running on Hugging Face Spaces runtime.

    Checked env vars:
      - SPACE_ID
      - HF_SPACE_ID
      - SYSTEM=spaces
    """
    return any(os.getenv(k) for k in ("SPACE_ID", "HF_SPACE_ID")) or (
        os.getenv("SYSTEM", "").lower() == "spaces"
    )


IS_HF_SPACES = _is_hf_space_env()
ENV_LABEL = "Hugging Face Spaces" if IS_HF_SPACES else "Local"

try:
    import spaces  # available on HF Spaces
except Exception:
    spaces = None


def gpu_decorator(duration: int = 180):
    """
    Return the decorator to wrap GPU-bound entry points with.

    If on HF Spaces and spaces.GPU exists -> use it.
    Else -> no-op decorator (local runs normally, using local GPU if available).
    """
    if IS_HF_SPACES and (spaces is not None) and hasattr(spaces, "GPU"):
        return spaces.GPU(duration=duration)

    def _noop(fn):
        return fn

    return _noop


# =========================
# Global cached pipeline
# =========================
_pipe = None


def calculate_psnr_ssim(original: Image.Image, enhanced: Image.Image):
    """
    Calculate PSNR and SSIM between original and enhanced images.

    When the two arrays differ in shape, both are cropped (from the top-left
    corner) to their common height/width before comparison.

    Returns:
        (psnr, ssim) as plain floats.
    """
    orig_float = img_as_float(np.array(original))
    enh_float = img_as_float(np.array(enhanced))

    # Ensure same shape (crop to min overlap)
    if orig_float.shape != enh_float.shape:
        min_h = min(orig_float.shape[0], enh_float.shape[0])
        min_w = min(orig_float.shape[1], enh_float.shape[1])
        orig_float = orig_float[:min_h, :min_w]
        enh_float = enh_float[:min_h, :min_w]

    psnr = peak_signal_noise_ratio(orig_float, enh_float, data_range=1.0)

    if orig_float.ndim == 3:
        # Colour image: the last axis holds the channels.
        ssim = structural_similarity(
            orig_float, enh_float, data_range=1.0, channel_axis=-1
        )
    else:
        ssim = structural_similarity(orig_float, enh_float, data_range=1.0)

    return float(psnr), float(ssim)


def _ensure_within(base_dir: str, member_name: str) -> None:
    """Raise ValueError if *member_name* would resolve outside *base_dir*."""
    base = os.path.realpath(base_dir)
    target = os.path.realpath(os.path.join(base, member_name))
    if os.path.commonpath([base, target]) != base:
        raise ValueError(f"Unsafe path in archive: {member_name}")


def extract_archive(archive_path: str, extract_to: str):
    """
    Extract a zip or 7z archive into *extract_to*.

    Raises:
        ValueError: if the extension is neither .zip nor .7z, or a 7z member
            path would escape *extract_to*.
    """
    file_ext = Path(archive_path).suffix.lower()
    if file_ext == ".zip":
        # zipfile strips absolute prefixes and ".." components from member
        # names itself, so no extra path validation is needed here.
        with zipfile.ZipFile(archive_path, "r") as z:
            z.extractall(extract_to)
    elif file_ext == ".7z":
        with py7zr.SevenZipFile(archive_path, mode="r") as a:
            # Archives are user uploads: reject path-traversal members up
            # front instead of relying on the installed py7zr version.
            for member_name in a.getnames():
                _ensure_within(extract_to, member_name)
            a.extractall(path=extract_to)
    else:
        raise ValueError(f"Unsupported archive format: {file_ext}")


def find_images(directory: str):
    """
    Recursively find all images in a directory.

    Files are matched by extension against IMAGE_EXTENSIONS. macOS archive
    metadata (``__MACOSX`` folders and ``._*`` AppleDouble files) is skipped:
    those entries carry image extensions but are not decodable images.
    Results are returned in a deterministic (sorted) traversal order.
    """
    image_files = []
    for root, dirs, files in os.walk(directory):
        # Prune in place so os.walk does not descend into metadata folders.
        dirs[:] = sorted(d for d in dirs if d != "__MACOSX")
        for name in sorted(files):
            if name.startswith("._"):
                continue
            if Path(name).suffix.lower() in IMAGE_EXTENSIONS:
                image_files.append(os.path.join(root, name))
    return image_files


def _get_pipe():
    """
    Lazy-load the pipeline and cache it in the module-level ``_pipe``.

    - On HF Spaces ZeroGPU: MUST be called inside a @spaces.GPU runtime
      (gpu_decorator handles this).
    - Locally: will use local GPU.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    global _pipe
    if _pipe is None:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "No CUDA GPU detected. This 4-bit bnb model requires GPU to load/run."
            )
        _pipe = Flux2Pipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=TORCH_DTYPE,
        )
        _pipe.to("cuda")
    return _pipe


def _failure_outputs(message: str):
    """Build the 4-tuple of UI outputs for a run that produced no results."""
    return (
        None,
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        message,
    )


def _collect_inputs(files, temp_dir: str, progress):
    """
    Resolve uploads into a list of ``(image_path, relative_output_path)``.

    Archives are extracted under *temp_dir*; plain images are used in place.
    Returns ``(images, has_archive)``.
    """
    images = []
    has_archive = False

    for file_obj in files:
        file_path = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
        file_ext = Path(file_path).suffix.lower()

        if file_ext in ARCHIVE_EXTENSIONS:
            has_archive = True
            progress(0.05, desc=f"Extracting: {Path(file_path).name} ...")

            # A unique directory per archive: two archives sharing a stem
            # (a.zip / a.7z) must not be merged, or the images of the first
            # would be discovered -- and processed -- a second time.
            extract_dir = tempfile.mkdtemp(
                prefix=f"{Path(file_path).stem}_", dir=temp_dir
            )
            extract_archive(file_path, extract_dir)

            for img_path in find_images(extract_dir):
                images.append((img_path, os.path.relpath(img_path, extract_dir)))

        elif file_ext in IMAGE_EXTENSIONS:
            images.append((file_path, Path(file_path).name))

    return images, has_archive


def _unique_path(path: Path) -> Path:
    """Return *path*, or a ``_<n>``-suffixed sibling if it already exists."""
    candidate = path
    counter = 1
    while candidate.exists():
        candidate = path.with_name(f"{path.stem}_{counter}{path.suffix}")
        counter += 1
    return candidate


def _output_path_for(output_dir: str, rel_path: str) -> str:
    """
    Map an input's relative path to its ``*_flux`` output path.

    Creates the parent directory and avoids overwriting an earlier result
    that happens to share the same relative name.
    """
    target = Path(output_dir) / rel_path
    target.parent.mkdir(parents=True, exist_ok=True)
    flux_target = target.with_name(f"{target.stem}_flux{target.suffix}")
    return str(_unique_path(flux_target))


@gpu_decorator(duration=180)
def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progress()):
    """
    Process uploaded files (images or archives).

    Args:
        files: uploaded file objects or paths (images and/or .zip/.7z).
        prompt: enhancement prompt; blank falls back to DEFAULT_PROMPT.
        guidance_scale: guidance scale passed to the pipeline.
        num_steps: number of inference steps passed to the pipeline.
        progress: Gradio progress tracker.

    Returns:
        A 4-tuple matching the UI outputs:
          - gallery preview (first 10 pairs)
          - files download update (shown when only images were uploaded)
          - zip download update (shown when any archive was uploaded)
          - summary text (or an error message)
    """
    if not files:
        return _failure_outputs("Please upload at least one file.")

    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT

    guidance_scale = float(guidance_scale)
    num_steps = int(num_steps)

    # Temp for extraction / staging input
    temp_dir = tempfile.mkdtemp(prefix="flux_in_")

    # Result files returned to Gradio for download must remain on disk, so
    # output_dir is only removed when nothing inside it is handed to the UI.
    run_id = uuid.uuid4().hex[:10]
    output_dir = tempfile.mkdtemp(prefix=f"flux_results_{run_id}_")
    keep_output_dir = False

    try:
        progress(0.0, desc="Preparing files...")
        all_images, has_archive = _collect_inputs(files, temp_dir, progress)

        if not all_images:
            return _failure_outputs("No valid images found in uploaded files.")

        total_images = len(all_images)
        progress(0.10, desc=f"Found {total_images} images. Loading model...")

        # Load pipeline (inside GPU runtime on Spaces; local GPU otherwise)
        pipe = _get_pipe()

        gallery_images = []
        enhanced_paths = []
        psnr_values = []
        ssim_values = []
        metrics_lines = []

        progress(0.15, desc="Enhancing images...")

        for idx, (img_path, rel_path) in enumerate(all_images):
            progress(
                0.15 + 0.75 * (idx / max(1, total_images)),
                desc=f"Processing {idx+1}/{total_images}: {Path(img_path).name}",
            )

            input_image = load_image(img_path)
            enhanced_image = pipe(
                image=input_image,
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
            ).images[0]

            psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)

            # Preserve structure if from archive; add _flux suffix
            out_path = _output_path_for(output_dir, rel_path)
            enhanced_image.save(out_path)

            enhanced_paths.append(out_path)
            psnr_values.append(psnr)
            ssim_values.append(ssim)
            metrics_lines.append(f"{rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}")

            # Only the previewed pairs keep their images in memory.
            if idx < MAX_PREVIEW_PAIRS:
                gallery_images.append((input_image, f"Original: {rel_path}"))
                gallery_images.append(
                    (
                        enhanced_image,
                        f"Enhanced: {rel_path}\nPSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}",
                    )
                )

        avg_psnr = float(np.mean(psnr_values)) if psnr_values else 0.0
        avg_ssim = float(np.mean(ssim_values)) if ssim_values else 0.0

        summary = (
            "✅ Processing completed!\n\n"
            f"Environment: {ENV_LABEL}\n"
            f"GPU available: {torch.cuda.is_available()}\n\n"
            f"Total images processed: {total_images}\n"
            f"Average PSNR: {avg_psnr:.2f} dB\n"
            f"Average SSIM: {avg_ssim:.4f}\n\n"
            "Individual metrics:\n" + "\n".join(metrics_lines)
        )

        # Decide download output:
        # - If user uploaded any archive -> provide ZIP
        # - Else -> provide enhanced files directly
        if has_archive:
            progress(0.92, desc="Packaging ZIP...")
            # The ZIP lives outside output_dir, so output_dir itself can be
            # cleaned up once the archive has been written.
            output_zip_path = os.path.join(
                tempfile.gettempdir(), f"enhanced_images_flux_{run_id}.zip"
            )
            with zipfile.ZipFile(output_zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                for root, _, fs in os.walk(output_dir):
                    for f in fs:
                        fp = os.path.join(root, f)
                        zipf.write(fp, os.path.relpath(fp, output_dir))

            progress(1.0, desc="Done!")
            return (
                gallery_images,
                gr.update(value=None, visible=False),  # files hidden
                gr.update(value=output_zip_path, visible=True),  # zip shown
                summary,
            )

        # Files inside output_dir are returned directly: keep them on disk.
        keep_output_dir = True
        progress(1.0, desc="Done!")
        return (
            gallery_images,
            gr.update(value=enhanced_paths, visible=True),  # files shown
            gr.update(value=None, visible=False),  # zip hidden
            summary,
        )

    except Exception as e:
        # Top-level UI boundary: report the error to the user, but keep the
        # traceback in the logs instead of discarding it.
        logger.exception("Image enhancement run failed")
        return _failure_outputs(f"Error during processing: {e}")

    finally:
        # Always clean the input/extraction temp dir; clean the output dir
        # only when none of its files were returned for download.
        shutil.rmtree(temp_dir, ignore_errors=True)
        if not keep_output_dir:
            shutil.rmtree(output_dir, ignore_errors=True)


# =========================
# Gradio UI
# =========================
HEADER_MARKDOWN = f"""
# 🔬 Flux Microscopy Image Enhancement

Upload microscopy images (individual files or compressed archives) for AI-powered enhancement.

**Supported formats:**
- Images: JPG, PNG, BMP, TIFF
- Archives: ZIP, 7Z (will process all images inside)

**Download behavior:**
- If you upload **only images** → you can download the enhanced **image files directly** (`*_flux` suffix)
- If you upload **a ZIP/7Z** → you can download **one ZIP** (images inside use `*_flux` suffix)

**Runtime detection:**
- Detected environment: **{ENV_LABEL}**
"""

FOOTER_MARKDOWN = """
---
### Default Parameters
- **Guidance Scale**: 2.0 (conservative for natural enhancement)
- **Inference Steps**: 30 (balanced quality and speed)

### Quality Metrics
- **PSNR** (Peak Signal-to-Noise Ratio): Higher is better
- **SSIM** (Structural Similarity Index): Closer to 1.0 is better
"""

with gr.Blocks(title="Flux Microscopy Image Enhancement") as demo:
    gr.Markdown(HEADER_MARKDOWN)

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload Images or Archive (ZIP/7Z)",
                file_count="multiple",
                file_types=["image", ".zip", ".7z"],
            )

            prompt_input = gr.Textbox(
                label="Enhancement Prompt",
                placeholder="Enter custom prompt or leave empty for default",
                value=DEFAULT_PROMPT,
                lines=3,
            )

            gr.Markdown("### Enhancement Parameters")

            guidance_scale_input = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=GUIDANCE_SCALE,
                step=0.1,
                label="Guidance Scale",
                info="Lower = more conservative, higher = more creative",
            )

            num_steps_input = gr.Slider(
                minimum=10,
                maximum=50,
                value=NUM_INFERENCE_STEPS,
                step=1,
                label="Inference Steps",
                info="More steps = better quality but slower",
            )

            process_btn = gr.Button("🚀 Enhance Images", variant="primary", size="lg")

        with gr.Column(scale=2):
            gallery_output = gr.Gallery(
                label="Results Preview (Original vs Enhanced)",
                columns=2,
                rows=2,
                height="auto",
                object_fit="contain",
            )

            files_output = gr.Files(
                label="📥 Download Enhanced Images (Files)", visible=False
            )
            zip_output = gr.File(
                label="📥 Download Enhanced Images (ZIP)", visible=False
            )

            summary_output = gr.Textbox(
                label="Processing Summary & Metrics",
                lines=10,
                max_lines=20,
            )

    process_btn.click(
        fn=process_images,
        inputs=[file_input, prompt_input, guidance_scale_input, num_steps_input],
        outputs=[gallery_output, files_output, zip_output, summary_output],
    )

    gr.Markdown(FOOTER_MARKDOWN)


if __name__ == "__main__":
    demo.queue()  # recommended for Spaces
    demo.launch(share=False)