#!/usr/bin/env python3
"""
Flux Microscopy Image Enhancement - Command Line Interface
Process microscopy images with AI-powered enhancement using argparse for all parameters
"""

import argparse
import os
import shutil
import sys
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch
from diffusers import Flux2Pipeline
from diffusers.utils import load_image
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.util import img_as_float

# =========================
# Config
# =========================
DEFAULT_PROMPT = (
    "enhance microscopy image with subtle improvements, gently increase cellular boundary clarity, "
    "preserve original morphological structure, maintain authentic texture patterns, "
    "minimal noise reduction while keeping fine details intact"
)
GUIDANCE_SCALE = 2.0
NUM_INFERENCE_STEPS = 30
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
TORCH_DTYPE = torch.bfloat16

# =========================
# Global cached pipeline
# =========================
_pipe = None


def calculate_psnr_ssim(
    original: Image.Image, enhanced: Image.Image
) -> Tuple[float, float]:
    """Calculate PSNR and SSIM between original and enhanced images.

    Both images are converted to float arrays via ``img_as_float`` and compared
    with ``data_range=1.0``. If the two arrays differ in size, both are cropped
    to their common top-left region before comparison.

    Args:
        original: Reference image.
        enhanced: Image to compare against the reference.

    Returns:
        Tuple of (psnr, ssim) as plain Python floats.
    """
    orig_arr = np.array(original)
    enh_arr = np.array(enhanced)

    # Arrays of different rank or channel count (e.g. grayscale vs RGB, or RGB
    # vs RGBA) cannot be compared directly and the spatial crop below would not
    # reconcile them, so bring both images to a common RGB layout first.
    layouts_differ = orig_arr.ndim != enh_arr.ndim or (
        orig_arr.ndim == 3 and orig_arr.shape[2] != enh_arr.shape[2]
    )
    if layouts_differ:
        orig_arr = np.array(original.convert("RGB"))
        enh_arr = np.array(enhanced.convert("RGB"))

    orig_float = img_as_float(orig_arr)
    enh_float = img_as_float(enh_arr)

    # Ensure same shape (crop to min overlap)
    if orig_float.shape != enh_float.shape:
        min_h = min(orig_float.shape[0], enh_float.shape[0])
        min_w = min(orig_float.shape[1], enh_float.shape[1])
        orig_float = orig_float[:min_h, :min_w]
        enh_float = enh_float[:min_h, :min_w]

    psnr = peak_signal_noise_ratio(orig_float, enh_float, data_range=1.0)

    if orig_float.ndim == 3:
        # Multi-channel image: the last axis holds the channels.
        ssim = structural_similarity(
            orig_float, enh_float, data_range=1.0, channel_axis=-1
        )
    else:
        ssim = structural_similarity(orig_float, enh_float, data_range=1.0)

    return float(psnr), float(ssim)


def find_images(directory: str) -> List[str]:
    """Recursively find all images in a directory.

    A file counts as an image when its lower-cased suffix is in
    ``IMAGE_EXTENSIONS``.

    Args:
        directory: Root directory to walk.

    Returns:
        Paths of all matching files, in a deterministic (sorted) traversal
        order.
    """
    image_files = []
    for root, dirs, files in os.walk(directory):
        # os.walk yields entries in arbitrary order; sorting in place makes the
        # traversal (and therefore processing order and reports) reproducible.
        dirs.sort()
        image_files.extend(
            os.path.join(root, name)
            for name in sorted(files)
            if Path(name).suffix.lower() in IMAGE_EXTENSIONS
        )
    return image_files


def _get_pipe():
    """
    Lazy-load the pipeline on local GPU.

    The pipeline is created on first call and cached in the module-level
    ``_pipe`` so later calls reuse it.

    Returns:
        The cached ``Flux2Pipeline`` instance.

    Raises:
        RuntimeError: If ``torch.cuda.is_available()`` is false.
    """
    global _pipe

    if _pipe is None:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "No GPU found. This tool requires a CUDA-compatible GPU to run."
            )

        print("Loading Flux model...")
        _pipe = Flux2Pipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=TORCH_DTYPE,
        )
        _pipe.to("cuda")
        print(f"Model loaded successfully on GPU: {torch.cuda.get_device_name(0)}")

    return _pipe


def _collect_images(input_paths: List[str], verbose: bool) -> List[Tuple[str, str]]:
    """Expand input paths into (image_path, relative_output_path) pairs.

    Directories are scanned recursively and keep their internal structure in
    the relative path; single files use just their file name. Missing paths and
    unsupported files produce a warning and are skipped.
    """
    collected = []

    for input_path in input_paths:
        if not os.path.exists(input_path):
            print(f"Warning: Path not found: {input_path}")
            continue

        if os.path.isdir(input_path):
            if verbose:
                print(f"\n[Scanning] Directory: {input_path} ...")
            images = find_images(input_path)
            collected.extend(
                (img_path, os.path.relpath(img_path, input_path))
                for img_path in images
            )
            if verbose:
                print(f"  Found {len(images)} images in directory")
        elif os.path.isfile(input_path):
            if Path(input_path).suffix.lower() in IMAGE_EXTENSIONS:
                collected.append((input_path, Path(input_path).name))
            else:
                print(f"Warning: Unsupported file format: {input_path}")
        else:
            print(f"Warning: Invalid path: {input_path}")

    return collected


def _build_output_path(output_dir: str, rel_path: str, taken: set) -> str:
    """Return the ``_flux``-suffixed output path for ``rel_path``, unique within this run.

    ``taken`` holds the normalized output paths already handed out; it is
    updated in place. The first image mapping to a given name keeps the plain
    ``<stem>_flux<ext>`` name; later ones get ``<stem>_flux_<n><ext>`` so that
    inputs sharing a relative name do not overwrite each other.
    """
    base = Path(os.path.join(output_dir, rel_path))
    parent = str(base.parent)
    stem, suffix = base.stem, base.suffix

    candidate = os.path.join(parent, f"{stem}_flux{suffix}")
    counter = 1
    while os.path.normcase(os.path.normpath(candidate)) in taken:
        candidate = os.path.join(parent, f"{stem}_flux_{counter}{suffix}")
        counter += 1

    taken.add(os.path.normcase(os.path.normpath(candidate)))
    return candidate


def process_images_cli(
    input_paths: List[str],
    output_dir: str,
    prompt: str,
    guidance_scale: float,
    num_steps: int,
    verbose: bool = True,
) -> Tuple[int, List[dict]]:
    """
    Process images from input paths and save to output directory.

    Args:
        input_paths: List of file/directory paths (images or folders)
        output_dir: Output directory for enhanced images
        prompt: Enhancement prompt (empty/blank falls back to DEFAULT_PROMPT)
        guidance_scale: Guidance scale for inference
        num_steps: Number of inference steps
        verbose: Whether to print progress messages

    Returns:
        Tuple of (total_images_processed, results_list). Each result is a dict
        with keys "original_path", "output_path", "filename", "psnr", "ssim".

    Raises:
        ValueError: If ``input_paths`` is empty or no valid images are found.
        Exception: Any error raised while loading the model or processing an
            image is reported on stderr and re-raised.
    """
    if not input_paths:
        raise ValueError("No input files provided")

    if not prompt or prompt.strip() == "":
        prompt = DEFAULT_PROMPT

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    try:
        if verbose:
            print("=" * 60)
            print("Flux Microscopy Image Enhancement - CLI")
            print("=" * 60)
            print(f"Output directory: {output_dir}")
            print(f"Prompt: {prompt}")
            print(f"Guidance scale: {guidance_scale}")
            print(f"Inference steps: {num_steps}")
            print("=" * 60)

        all_images = _collect_images(input_paths, verbose)
        if not all_images:
            raise ValueError("No valid images found in input files")

        total_images = len(all_images)
        if verbose:
            print(f"\n[Processing] Total images to enhance: {total_images}")
            print("-" * 60)

        # Load pipeline
        pipe = _get_pipe()

        results = []
        metrics_lines = []
        # Output paths already assigned in this run, used to avoid overwrites.
        taken_outputs = set()

        for idx, (img_path, rel_path) in enumerate(all_images, 1):
            if verbose:
                print(f"\n[{idx}/{total_images}] Processing: {Path(img_path).name}")

            # Load input image
            input_image = load_image(img_path)

            # Run inference
            enhanced_image = pipe(
                image=input_image,
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
            ).images[0]

            # Calculate metrics
            psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)
            if verbose:
                print(f"  PSNR: {psnr:.2f} dB | SSIM: {ssim:.4f}")

            # Determine output path (directory inputs keep their sub-folder
            # structure through rel_path) and add the _flux suffix.
            out_path = _build_output_path(output_dir, rel_path, taken_outputs)
            os.makedirs(os.path.dirname(out_path), exist_ok=True)

            # Save enhanced image
            enhanced_image.save(out_path)
            if verbose:
                print(f"  Saved to: {out_path}")

            results.append(
                {
                    "original_path": img_path,
                    "output_path": out_path,
                    "filename": rel_path,
                    "psnr": psnr,
                    "ssim": ssim,
                }
            )
            metrics_lines.append(f"{rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}")

        # Print summary
        if verbose:
            avg_psnr = float(np.mean([r["psnr"] for r in results])) if results else 0.0
            avg_ssim = float(np.mean([r["ssim"] for r in results])) if results else 0.0

            print("\n" + "=" * 60)
            print("SUMMARY")
            print("=" * 60)
            print(f"Total images processed: {total_images}")
            print(f"Average PSNR: {avg_psnr:.2f} dB")
            print(f"Average SSIM: {avg_ssim:.4f}")
            print("\nIndividual metrics:")
            for line in metrics_lines:
                print(f"  {line}")
            print("=" * 60)

        return total_images, results

    except Exception as e:
        print(f"\nError during processing: {str(e)}", file=sys.stderr)
        raise


def main():
    """Main entry point for CLI.

    Parses and validates the command-line arguments, runs the enhancement, and
    exits with status 0 on success or 1 if processing raised an exception.
    """
    parser = argparse.ArgumentParser(
        description="Flux Microscopy Image Enhancement - Command Line Interface",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Enhance a single image
  python enhance_cli.py -i input.jpg -o output/

  # Enhance multiple images with custom parameters
  python enhance_cli.py -i image1.jpg image2.png -o output/ --guidance-scale 3.0 --steps 40

  # Process all images in a directory
  python enhance_cli.py -i images_folder/ -o output/

  # Process with custom prompt
  python enhance_cli.py -i input.jpg -o output/ --prompt "enhance cellular structure"

  # Quiet mode (less verbose output)
  python enhance_cli.py -i input.jpg -o output/ --quiet

Supported formats:
  - Images: JPG, JPEG, PNG, BMP, TIFF, TIF
  - Directories: Will recursively process all images inside
        """,
    )

    # Required arguments
    parser.add_argument(
        "-i",
        "--input",
        nargs="+",
        required=True,
        help="Input path(s) - image files or directories. Multiple paths supported.",
    )
    parser.add_argument(
        "-o", "--output", required=True, help="Output directory for enhanced images"
    )

    # Optional arguments
    parser.add_argument(
        "-p",
        "--prompt",
        default=DEFAULT_PROMPT,
        help=f"Enhancement prompt (default: '{DEFAULT_PROMPT[:50]}...')",
    )
    parser.add_argument(
        "-g",
        "--guidance-scale",
        type=float,
        default=GUIDANCE_SCALE,
        help=f"Guidance scale (1.0-5.0, lower=conservative, higher=creative, default: {GUIDANCE_SCALE})",
    )
    parser.add_argument(
        "-s",
        "--steps",
        type=int,
        default=NUM_INFERENCE_STEPS,
        help=f"Number of inference steps (10-50, more=better quality but slower, default: {NUM_INFERENCE_STEPS})",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Quiet mode - reduce output verbosity",
    )

    args = parser.parse_args()

    # Validate arguments
    if args.guidance_scale < 1.0 or args.guidance_scale > 5.0:
        parser.error("guidance-scale must be between 1.0 and 5.0")
    if args.steps < 10 or args.steps > 50:
        parser.error("steps must be between 10 and 50")

    # Process images
    try:
        process_images_cli(
            input_paths=args.input,
            output_dir=args.output,
            prompt=args.prompt,
            guidance_scale=args.guidance_scale,
            num_steps=args.steps,
            verbose=not args.quiet,
        )
    except Exception as e:
        # Top-level boundary: report any failure and signal it via exit status.
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        sys.exit(1)
    else:
        if not args.quiet:
            print("\n✅ Enhancement completed successfully!")
            print(f"📁 Output directory: {args.output}")
        sys.exit(0)


if __name__ == "__main__":
    main()