import logging
import os
import shutil
import tempfile
import uuid
import zipfile
from pathlib import Path

import gradio as gr
import numpy as np
import py7zr
import torch
from diffusers import Flux2Pipeline
from diffusers.utils import load_image
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.util import img_as_float

logger = logging.getLogger(__name__)

# =========================
# Config
# =========================
DEFAULT_PROMPT = (
    "enhance microscopy image with subtle improvements, gently increase cellular boundary clarity, "
    "preserve original morphological structure, maintain authentic texture patterns, "
    "minimal noise reduction while keeping fine details intact"
)
GUIDANCE_SCALE = 2.0
NUM_INFERENCE_STEPS = 30
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
ARCHIVE_EXTENSIONS = {".zip", ".7z"}

# Folders added by macOS when zipping in Finder; their "._*" entries carry
# image extensions but are resource-fork metadata, not decodable images.
_IGNORED_DIR_NAMES = {"__MACOSX"}

# Quantized 4-bit model (requires GPU)
MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
TORCH_DTYPE = torch.bfloat16

# =========================
# HF Spaces detection (robust)
# =========================
try:
    import spaces  # Only available on Hugging Face Spaces

    SPACES_AVAILABLE = True
except Exception:  # deliberately broad: any import failure means "not on Spaces"
    spaces = None
    SPACES_AVAILABLE = False

# Extra friendly label for UI (not used for logic)
IS_HF_SPACES = (
    SPACES_AVAILABLE
    or (os.getenv("SYSTEM", "").lower() == "spaces")
    or bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"))
)
ENV_LABEL = "Hugging Face Spaces" if IS_HF_SPACES else "Local"

# =========================
# Global cached pipeline
# =========================
_pipe = None


def calculate_psnr_ssim(original: Image.Image, enhanced: Image.Image):
    """Calculate PSNR and SSIM between an original and an enhanced image.

    The enhanced image is first brought onto the original's pixel grid:
    if the modes differ both are converted to RGB, and if the sizes differ
    the enhanced image is resized to the original's size. Resizing keeps the
    two images spatially aligned; cropping a rescaled output to the overlap
    would compare unrelated regions.

    Args:
        original: Reference image.
        enhanced: Image produced by the enhancement pipeline.

    Returns:
        ``(psnr, ssim)`` as Python floats. PSNR is in dB (``inf`` for
        identical images); SSIM is computed per channel for colour images.
    """
    if enhanced.mode != original.mode:
        original = original.convert("RGB")
        enhanced = enhanced.convert("RGB")
    if enhanced.size != original.size:
        enhanced = enhanced.resize(original.size, Image.LANCZOS)

    orig_float = img_as_float(np.array(original))
    enh_float = img_as_float(np.array(enhanced))

    psnr = peak_signal_noise_ratio(orig_float, enh_float, data_range=1.0)
    if orig_float.ndim == 3:
        ssim = structural_similarity(
            orig_float, enh_float, data_range=1.0, channel_axis=-1
        )
    else:
        ssim = structural_similarity(orig_float, enh_float, data_range=1.0)
    return float(psnr), float(ssim)


def _ensure_members_within(names, extract_to: str) -> None:
    """Raise ValueError if any archive member would resolve outside *extract_to*."""
    base = os.path.realpath(extract_to)
    for name in names:
        target = os.path.realpath(os.path.join(base, name))
        if os.path.commonpath([base, target]) != base:
            raise ValueError(f"Unsafe path in archive: {name!r}")


def extract_archive(archive_path: str, extract_to: str):
    """Extract a .zip or .7z archive into *extract_to*.

    Raises:
        ValueError: for unsupported extensions, or a 7z member whose path
            would escape *extract_to* (path traversal in an uploaded file).
    """
    file_ext = Path(archive_path).suffix.lower()
    if file_ext == ".zip":
        # zipfile.extractall already strips absolute paths and ".." parts.
        with zipfile.ZipFile(archive_path, "r") as z:
            z.extractall(extract_to)
    elif file_ext == ".7z":
        with py7zr.SevenZipFile(archive_path, mode="r") as a:
            # Uploads are untrusted: validate member paths before writing.
            _ensure_members_within(a.getnames(), extract_to)
            a.extractall(path=extract_to)
    else:
        raise ValueError(f"Unsupported archive format: {file_ext}")


def find_images(directory: str):
    """Recursively find image files under *directory*, in sorted order.

    Files are matched by extension (see ``IMAGE_EXTENSIONS``). macOS archive
    metadata is skipped: ``__MACOSX`` folders and AppleDouble ``._*`` files,
    which carry image extensions but cannot be decoded as images.
    """
    image_files = []
    for root, dirs, files in os.walk(directory):
        # Prune (in place) so os.walk neither enters metadata folders nor
        # depends on filesystem listing order.
        dirs[:] = sorted(d for d in dirs if d not in _IGNORED_DIR_NAMES)
        for name in sorted(files):
            if name.startswith("._"):
                continue
            if Path(name).suffix.lower() in IMAGE_EXTENSIONS:
                image_files.append(os.path.join(root, name))
    return image_files


def _get_pipe():
    """
    Lazy-load the pipeline.
    - On HF ZeroGPU: this MUST be called inside a @spaces.GPU function runtime.
    - Locally: uses local CUDA GPU.
    """
    global _pipe
    if _pipe is None:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA GPU is not available in the current runtime. "
                "This bnb-4bit model requires GPU. "
                "On HF ZeroGPU, ensure the function is decorated with @spaces.GPU."
            )
        _pipe = Flux2Pipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=TORCH_DTYPE,
        )
        _pipe.to("cuda")
    return _pipe


def _error_outputs(message: str):
    """Build the 4 UI outputs for a failed run: no gallery, both downloads hidden."""
    return (
        None,
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        message,
    )


def _unique_path(path: str) -> str:
    """Return *path*, or a ``<stem>_<n><suffix>`` sibling if *path* already exists.

    Prevents inputs that share a relative name (same-named uploads, or the
    same name inside two archives) from overwriting each other's output.
    """
    base = Path(path)
    candidate = base
    counter = 1
    while candidate.exists():
        candidate = base.with_name(f"{base.stem}_{counter}{base.suffix}")
        counter += 1
    return str(candidate)


def _collect_images(files, temp_dir: str, progress):
    """Expand uploaded files into ``(image_path, rel_path)`` pairs.

    Archives are extracted into their own folder under *temp_dir*; their
    images keep the path relative to that folder. Loose images use their
    bare file name. Files with any other extension are ignored.

    Returns:
        ``(all_images, has_archive)``.
    """
    all_images = []
    has_archive = False
    for index, file_obj in enumerate(files):
        file_path = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
        file_ext = Path(file_path).suffix.lower()

        if file_ext in ARCHIVE_EXTENSIONS:
            has_archive = True
            progress(0.05, desc=f"Extracting: {Path(file_path).name} ...")
            # Index prefix keeps archives that share a stem (a.zip, a.7z) apart.
            extract_dir = os.path.join(temp_dir, f"{index}_{Path(file_path).stem}")
            os.makedirs(extract_dir, exist_ok=True)
            extract_archive(file_path, extract_dir)
            for img_path in find_images(extract_dir):
                all_images.append((img_path, os.path.relpath(img_path, extract_dir)))
        elif file_ext in IMAGE_EXTENSIONS:
            all_images.append((file_path, Path(file_path).name))
    return all_images, has_archive


def _build_gallery(results, limit: int = 10):
    """Return (image, caption) pairs: original then enhanced, for the first *limit* results."""
    gallery_images = []
    for r in results[:limit]:
        gallery_images.append((r["original"], f"Original: {r['filename']}"))
        gallery_images.append(
            (
                r["enhanced"],
                f"Enhanced: {r['filename']}\nPSNR: {r['psnr']:.2f} dB, SSIM: {r['ssim']:.4f}",
            )
        )
    return gallery_images


def _zip_directory(source_dir: str, zip_path: str) -> None:
    """Write every file under *source_dir* into *zip_path*, preserving relative paths."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _, names in os.walk(source_dir):
            for name in names:
                fp = os.path.join(root, name)
                zipf.write(fp, os.path.relpath(fp, source_dir))


def _process_images_impl(
    files, prompt, guidance_scale, num_steps, progress=gr.Progress()
):
    """
    Shared implementation used by both:
    - HF Spaces GPU wrapper
    - Local runtime

    Enhances every uploaded image (loose or inside ZIP/7Z archives), saves
    results with a ``_flux`` suffix, and computes PSNR/SSIM against the input.
    Unreadable image files are skipped and listed in the summary.

    Returns 4 outputs: gallery, files_download, zip_download, summary
    """
    if not files:
        return _error_outputs("Please upload at least one file.")

    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT

    guidance_scale = float(guidance_scale)
    num_steps = int(num_steps)

    # Temp for extraction/staging input
    temp_dir = tempfile.mkdtemp(prefix="flux_in_")

    # Output files served individually MUST remain on disk for Gradio downloads;
    # the directory is removed only when nothing is served from it.
    run_id = uuid.uuid4().hex[:10]
    output_dir = tempfile.mkdtemp(prefix=f"flux_results_{run_id}_")
    keep_output_dir = False

    try:
        progress(0.0, desc="Preparing files...")
        all_images, has_archive = _collect_images(files, temp_dir, progress)
        if not all_images:
            return _error_outputs("No valid images found in uploaded files.")

        total_images = len(all_images)
        progress(0.10, desc=f"Found {total_images} images. Loading model...")
        pipe = _get_pipe()  # IMPORTANT: must be inside GPU runtime on Spaces

        results = []
        metrics_lines = []
        skipped = []

        progress(0.15, desc="Enhancing images...")
        for idx, (img_path, rel_path) in enumerate(all_images):
            progress(
                0.15 + 0.75 * (idx / total_images),
                desc=f"Processing {idx+1}/{total_images}: {Path(img_path).name}",
            )

            try:
                input_image = load_image(img_path)
            except (OSError, ValueError) as e:
                # Corrupt/undecodable file: skip it rather than failing the batch.
                logger.warning("Skipping unreadable image %s: %s", img_path, e)
                skipped.append(f"{rel_path}: {e}")
                continue

            enhanced_image = pipe(
                image=input_image,
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
            ).images[0]

            psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)

            # Preserve archive folder structure and add the _flux suffix.
            target = Path(output_dir) / rel_path
            target.parent.mkdir(parents=True, exist_ok=True)
            out_path = _unique_path(
                str(target.with_name(f"{target.stem}_flux{target.suffix}"))
            )
            enhanced_image.save(out_path)

            results.append(
                {
                    "original": input_image,
                    "enhanced": enhanced_image,
                    "filename": rel_path,
                    "output_path": out_path,
                    "psnr": psnr,
                    "ssim": ssim,
                }
            )
            metrics_lines.append(f"{rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}")

        if not results:
            return _error_outputs(
                "None of the uploaded images could be read:\n" + "\n".join(skipped)
            )

        avg_psnr = float(np.mean([r["psnr"] for r in results]))
        avg_ssim = float(np.mean([r["ssim"] for r in results]))

        summary = (
            "✅ Processing completed!\n\n"
            f"Environment: {ENV_LABEL}\n"
            f"GPU available: {torch.cuda.is_available()}\n\n"
            f"Total images processed: {len(results)}\n"
            f"Average PSNR: {avg_psnr:.2f} dB\n"
            f"Average SSIM: {avg_ssim:.4f}\n\n"
            "Individual metrics:\n" + "\n".join(metrics_lines)
        )
        if skipped:
            summary += "\n\nSkipped (unreadable):\n" + "\n".join(skipped)

        gallery_images = _build_gallery(results)

        # Download behavior:
        # - If any archive uploaded -> zip
        # - Else -> direct files list
        if has_archive:
            progress(0.92, desc="Packaging ZIP...")
            output_zip_path = os.path.join(
                tempfile.gettempdir(), f"enhanced_images_flux_{run_id}.zip"
            )
            _zip_directory(output_dir, output_zip_path)
            progress(1.0, desc="Done!")
            return (
                gallery_images,
                gr.update(value=None, visible=False),  # files hidden
                gr.update(value=output_zip_path, visible=True),  # zip shown
                summary,
            )

        # Individual files are served straight from output_dir: keep it.
        keep_output_dir = True
        progress(1.0, desc="Done!")
        return (
            gallery_images,
            gr.update(value=[r["output_path"] for r in results], visible=True),  # files shown
            gr.update(value=None, visible=False),  # zip hidden
            summary,
        )

    except Exception as e:  # UI boundary: report to the user, keep the traceback in logs
        logger.exception("Error during processing")
        return _error_outputs(f"Error during processing: {str(e)}")
    finally:
        # Input staging is never needed after the run.
        shutil.rmtree(temp_dir, ignore_errors=True)
        # output_dir is only needed when its files are the download payload.
        if not keep_output_dir:
            shutil.rmtree(output_dir, ignore_errors=True)


# =========================
# IMPORTANT: Define a real @spaces.GPU function at import-time on Spaces
# (This fixes: "No @spaces.GPU function detected during startup")
# =========================
if SPACES_AVAILABLE:

    @spaces.GPU(duration=180)
    def process_images(
        files, prompt, guidance_scale, num_steps, progress=gr.Progress()
    ):
        """Gradio entry point on HF Spaces (runs inside a ZeroGPU allocation)."""
        return _process_images_impl(files, prompt, guidance_scale, num_steps, progress)

else:

    def process_images(
        files, prompt, guidance_scale, num_steps, progress=gr.Progress()
    ):
        """Gradio entry point for local runs (uses the local CUDA device)."""
        return _process_images_impl(files, prompt, guidance_scale, num_steps, progress)


# =========================
# Gradio UI
# =========================
with gr.Blocks(title="Flux Microscopy Image Enhancement") as demo:
    gr.Markdown(
        f"""
# 🔬 Flux Microscopy Image Enhancement
Upload microscopy images (individual files or compressed archives) for AI-powered enhancement.

**Supported formats:**
- Images: JPG, PNG, BMP, TIFF
- Archives: ZIP, 7Z (will process all images inside)

**Download behavior:**
- Upload **only images** → download enhanced **image files directly** (`*_flux` suffix)
- Upload **ZIP/7Z** → download **one ZIP** (images inside use `*_flux` suffix)

**Runtime detection:**
- Detected environment: **{ENV_LABEL}**
- spaces module available: **{SPACES_AVAILABLE}**
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload Images or Archive (ZIP/7Z)",
                file_count="multiple",
                file_types=["image", ".zip", ".7z"],
            )
            prompt_input = gr.Textbox(
                label="Enhancement Prompt",
                placeholder="Enter custom prompt or leave empty for default",
                value=DEFAULT_PROMPT,
                lines=3,
            )

            gr.Markdown("### Enhancement Parameters")
            guidance_scale_input = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=GUIDANCE_SCALE,
                step=0.1,
                label="Guidance Scale",
                info="Lower = more conservative, higher = more creative",
            )
            num_steps_input = gr.Slider(
                minimum=10,
                maximum=50,
                value=NUM_INFERENCE_STEPS,
                step=1,
                label="Inference Steps",
                info="More steps = better quality but slower",
            )

            process_btn = gr.Button("🚀 Enhance Images", variant="primary", size="lg")

        with gr.Column(scale=2):
            gallery_output = gr.Gallery(
                label="Results Preview (Original vs Enhanced)",
                columns=2,
                rows=2,
                height="auto",
                object_fit="contain",
            )
            files_output = gr.Files(
                label="📥 Download Enhanced Images (Files)", visible=False
            )
            zip_output = gr.File(
                label="📥 Download Enhanced Images (ZIP)", visible=False
            )
            summary_output = gr.Textbox(
                label="Processing Summary & Metrics",
                lines=10,
                max_lines=20,
            )

    process_btn.click(
        fn=process_images,
        inputs=[file_input, prompt_input, guidance_scale_input, num_steps_input],
        outputs=[gallery_output, files_output, zip_output, summary_output],
    )

    # Defaults are rendered from the config constants so the text cannot drift.
    gr.Markdown(
        f"""
---
### Default Parameters
- **Guidance Scale**: {GUIDANCE_SCALE} (conservative for natural enhancement)
- **Inference Steps**: {NUM_INFERENCE_STEPS} (balanced quality and speed)

### Quality Metrics
- **PSNR**: Higher is better
- **SSIM**: Closer to 1.0 is better
"""
    )


if __name__ == "__main__":
    demo.queue()
    demo.launch(share=False)