| |
| """ |
| Flux Microscopy Image Enhancement - Command Line Interface |
| Process microscopy images with AI-powered enhancement using argparse for all parameters |
| """ |
|
|
| import argparse |
| import torch |
| from diffusers import Flux2Pipeline |
| from diffusers.utils import load_image |
| from pathlib import Path |
| from PIL import Image |
| import os |
| import shutil |
| import numpy as np |
| from skimage.metrics import peak_signal_noise_ratio, structural_similarity |
| from skimage.util import img_as_float |
| import sys |
| from typing import List, Tuple |
|
|
| |
| |
| |
# --- Model -----------------------------------------------------------------
# Checkpoint identifier handed to Flux2Pipeline.from_pretrained and the torch
# dtype requested when loading it.
MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
TORCH_DTYPE = torch.bfloat16

# --- Sampling defaults -------------------------------------------------------
# Used as the argparse defaults for --guidance-scale and --steps.
GUIDANCE_SCALE = 2.0
NUM_INFERENCE_STEPS = 30

# --- Input discovery ---------------------------------------------------------
# Lower-case file suffixes (leading dot included) that count as images.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}

# --- Prompt ------------------------------------------------------------------
# Fallback prompt used when the caller supplies an empty/blank prompt.
DEFAULT_PROMPT = "".join(
    [
        "enhance microscopy image with subtle improvements, gently increase cellular boundary clarity, ",
        "preserve original morphological structure, maintain authentic texture patterns, ",
        "minimal noise reduction while keeping fine details intact",
    ]
)

# --- Runtime state -----------------------------------------------------------
# Process-wide cache slot for the loaded pipeline; filled lazily by _get_pipe().
_pipe = None
|
|
|
|
def calculate_psnr_ssim(
    original: Image.Image, enhanced: Image.Image
) -> Tuple[float, float]:
    """Compute PSNR and SSIM of ``enhanced`` measured against ``original``.

    Both metrics need two arrays of identical shape whose pixels correspond to
    each other, so ``enhanced`` is first aligned with ``original``:

    * if the two images use different PIL modes, both are converted to RGB so
      they end up with the same number of channels;
    * if their sizes differ, ``enhanced`` is resampled (Lanczos) to the size of
      ``original``.

    Args:
        original: Reference image.
        enhanced: Image to score against the reference.

    Returns:
        ``(psnr, ssim)`` as plain Python floats. Both are computed on
        ``img_as_float`` arrays with ``data_range=1.0``.
    """
    # Different modes (e.g. "L" vs "RGB", "RGBA" vs "RGB") produce arrays of
    # different rank / channel count, which the skimage metrics reject.
    if original.mode != enhanced.mode:
        original = original.convert("RGB")
        enhanced = enhanced.convert("RGB")

    # A size mismatch is assumed to come from the enhancement step rescaling
    # its output. Cropping both images to the common top-left corner would then
    # compare regions that do not correspond, so map the enhanced image back
    # onto the original pixel grid instead.
    if original.size != enhanced.size:
        enhanced = enhanced.resize(original.size, Image.LANCZOS)

    reference = img_as_float(np.array(original))
    candidate = img_as_float(np.array(enhanced))

    psnr = peak_signal_noise_ratio(reference, candidate, data_range=1.0)

    # Multi-channel arrays must tell SSIM which axis holds the channels.
    ssim_kwargs = {"channel_axis": -1} if reference.ndim == 3 else {}
    ssim = structural_similarity(
        reference, candidate, data_range=1.0, **ssim_kwargs
    )

    return float(psnr), float(ssim)
|
|
|
|
def find_images(directory: str, extensions=None) -> List[str]:
    """Recursively collect image files below ``directory`` in sorted order.

    Args:
        directory: Root folder to walk (sub-folders included).
        extensions: Optional iterable of file suffixes, leading dot included
            (e.g. ``{".png", ".tif"}``). Matching is case-insensitive. When
            omitted, the module-level ``IMAGE_EXTENSIONS`` set is used.

    Returns:
        Paths of all matching files, each built by joining the walked folder
        and the file name. The list is sorted so that processing order (and
        the order of any report built from it) does not depend on the
        directory-listing order of the file system.
    """
    if extensions is None:
        extensions = IMAGE_EXTENSIONS
    allowed = {ext.lower() for ext in extensions}

    matches = []
    for root, _, names in os.walk(directory):
        matches.extend(
            os.path.join(root, name)
            for name in names
            if Path(name).suffix.lower() in allowed
        )

    matches.sort()
    return matches
|
|
|
|
def _get_pipe():
    """Return the shared Flux pipeline, loading it onto the GPU on first use.

    The pipeline is created once from ``MODEL_ID`` with ``TORCH_DTYPE``, moved
    to ``"cuda"`` and then kept in the module-level ``_pipe`` slot; later calls
    return that cached object.

    Returns:
        The cached ``Flux2Pipeline`` instance.

    Raises:
        RuntimeError: If torch reports that no CUDA device is available.
    """
    global _pipe

    if _pipe is not None:
        return _pipe

    if not torch.cuda.is_available():
        raise RuntimeError(
            "No GPU found. This tool requires a CUDA-compatible GPU to run."
        )

    print("Loading Flux model...")
    pipe = Flux2Pipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
    )
    pipe.to("cuda")

    # Publish to the cache only after the device move succeeded; otherwise a
    # failed move would leave a pipeline that never reached the GPU cached and
    # every later call would silently hand it out.
    _pipe = pipe
    print(f"Model loaded successfully on GPU: {torch.cuda.get_device_name(0)}")

    return _pipe
|
|
|
|
def _collect_images(input_paths: List[str], verbose: bool) -> List[Tuple[str, str]]:
    """Expand files/directories into ``(image_path, relative_output_path)`` pairs."""
    collected = []

    for input_path in input_paths:
        if not os.path.exists(input_path):
            print(f"Warning: Path not found: {input_path}")
            continue

        if os.path.isdir(input_path):
            if verbose:
                print(f"\n[Scanning] Directory: {input_path} ...")

            images = find_images(input_path)
            # Paths are kept relative to the scanned directory so its
            # sub-folder layout is mirrored below the output directory.
            collected.extend(
                (img_path, os.path.relpath(img_path, input_path))
                for img_path in images
            )

            if verbose:
                print(f" Found {len(images)} images in directory")

        elif os.path.isfile(input_path):
            if Path(input_path).suffix.lower() in IMAGE_EXTENSIONS:
                # A directly named file lands in the top level of the output.
                collected.append((input_path, Path(input_path).name))
            else:
                print(f"Warning: Unsupported file format: {input_path}")
        else:
            print(f"Warning: Invalid path: {input_path}")

    return collected


def _reserve_output_path(output_dir: str, rel_path: str, taken: set) -> str:
    """Return an unused ``<stem>_flux<suffix>`` path under ``output_dir`` and mark it taken."""
    mirrored = os.path.join(output_dir, rel_path)
    folder = os.path.dirname(mirrored)
    stem = Path(mirrored).stem
    suffix = Path(mirrored).suffix

    candidate = os.path.join(folder, f"{stem}_flux{suffix}")
    attempt = 2
    # Two inputs can map to the same name (same basename from different
    # folders/roots). Number the later ones instead of overwriting the first.
    while os.path.normcase(os.path.normpath(candidate)) in taken:
        candidate = os.path.join(folder, f"{stem}_flux_{attempt}{suffix}")
        attempt += 1

    taken.add(os.path.normcase(os.path.normpath(candidate)))
    return candidate


def _print_summary(total_images: int, results: List[dict]) -> None:
    """Print the averaged and per-image PSNR/SSIM report."""
    avg_psnr = float(np.mean([r["psnr"] for r in results])) if results else 0.0
    avg_ssim = float(np.mean([r["ssim"] for r in results])) if results else 0.0

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Total images processed: {total_images}")
    print(f"Average PSNR: {avg_psnr:.2f} dB")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print("\nIndividual metrics:")
    for r in results:
        line = f"{r['filename']}: PSNR={r['psnr']:.2f} dB, SSIM={r['ssim']:.4f}"
        print(f" {line}")
    print("=" * 60)


def process_images_cli(
    input_paths: List[str],
    output_dir: str,
    prompt: str,
    guidance_scale: float,
    num_steps: int,
    verbose: bool = True,
) -> Tuple[int, List[dict]]:
    """Enhance every image reachable from ``input_paths`` and save the results.

    Each input image is run through the Flux pipeline, scored against its
    original with PSNR/SSIM, and written below ``output_dir`` as
    ``<stem>_flux<suffix>``. Images found by scanning a directory keep their
    path relative to that directory; directly named files go to the top level.
    If two inputs would produce the same output path, later ones get a numeric
    suffix (``_flux_2``, ``_flux_3``, ...) instead of overwriting.

    Args:
        input_paths: Image files and/or directories (scanned recursively).
        output_dir: Directory for enhanced images; created if missing.
        prompt: Enhancement prompt; empty or blank falls back to DEFAULT_PROMPT.
        guidance_scale: Guidance scale passed to the pipeline.
        num_steps: Number of inference steps passed to the pipeline.
        verbose: Whether to print the banner, progress lines and summary.
            Warnings about skipped inputs are printed regardless.

    Returns:
        ``(total_images_processed, results)`` where each result is a dict with
        the keys ``original_path``, ``output_path``, ``filename``, ``psnr`` and
        ``ssim``.

    Raises:
        ValueError: If ``input_paths`` is empty or yields no supported image.
        Exception: Anything raised while loading the model, running inference
            or saving propagates unchanged; reporting is left to the caller so
            an error is not printed twice.
    """
    if not input_paths:
        raise ValueError("No input files provided")

    if not prompt or prompt.strip() == "":
        prompt = DEFAULT_PROMPT

    os.makedirs(output_dir, exist_ok=True)

    if verbose:
        print("=" * 60)
        print("Flux Microscopy Image Enhancement - CLI")
        print("=" * 60)
        print(f"Output directory: {output_dir}")
        print(f"Prompt: {prompt}")
        print(f"Guidance scale: {guidance_scale}")
        print(f"Inference steps: {num_steps}")
        print("=" * 60)

    all_images = _collect_images(input_paths, verbose)
    if not all_images:
        raise ValueError("No valid images found in input files")

    total_images = len(all_images)
    if verbose:
        print(f"\n[Processing] Total images to enhance: {total_images}")
        print("-" * 60)

    # Load the model only once there is something to process.
    pipe = _get_pipe()

    results = []
    taken_outputs = set()

    for idx, (img_path, rel_path) in enumerate(all_images, 1):
        if verbose:
            print(f"\n[{idx}/{total_images}] Processing: {Path(img_path).name}")

        input_image = load_image(img_path)

        enhanced_image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
        ).images[0]

        psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)
        if verbose:
            print(f" PSNR: {psnr:.2f} dB | SSIM: {ssim:.4f}")

        out_path = _reserve_output_path(output_dir, rel_path, taken_outputs)
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        enhanced_image.save(out_path)

        if verbose:
            print(f" Saved to: {out_path}")

        results.append(
            {
                "original_path": img_path,
                "output_path": out_path,
                "filename": rel_path,
                "psnr": psnr,
                "ssim": ssim,
            }
        )

    if verbose:
        _print_summary(total_images, results)

    return total_images, results
|
|
|
|
# Usage examples appended to the --help output.
_CLI_EPILOG = """
Examples:
# Enhance a single image
python enhance_cli.py -i input.jpg -o output/

# Enhance multiple images with custom parameters
python enhance_cli.py -i image1.jpg image2.png -o output/ --guidance-scale 3.0 --steps 40

# Process all images in a directory
python enhance_cli.py -i images_folder/ -o output/

# Process with custom prompt
python enhance_cli.py -i input.jpg -o output/ --prompt "enhance cellular structure"

# Quiet mode (less verbose output)
python enhance_cli.py -i input.jpg -o output/ --quiet

Supported formats:
- Images: JPG, JPEG, PNG, BMP, TIFF, TIF
- Directories: Will recursively process all images inside
"""


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser describing every CLI option."""
    parser = argparse.ArgumentParser(
        description="Flux Microscopy Image Enhancement - Command Line Interface",
        # Keep the epilog's line breaks instead of re-wrapping it.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=_CLI_EPILOG,
    )

    parser.add_argument(
        "-i",
        "--input",
        nargs="+",
        required=True,
        help="Input path(s) - image files or directories. Multiple paths supported.",
    )
    parser.add_argument(
        "-o", "--output", required=True, help="Output directory for enhanced images"
    )
    parser.add_argument(
        "-p",
        "--prompt",
        default=DEFAULT_PROMPT,
        help=f"Enhancement prompt (default: '{DEFAULT_PROMPT[:50]}...')",
    )
    parser.add_argument(
        "-g",
        "--guidance-scale",
        type=float,
        default=GUIDANCE_SCALE,
        help=f"Guidance scale (1.0-5.0, lower=conservative, higher=creative, default: {GUIDANCE_SCALE})",
    )
    parser.add_argument(
        "-s",
        "--steps",
        type=int,
        default=NUM_INFERENCE_STEPS,
        help=f"Number of inference steps (10-50, more=better quality but slower, default: {NUM_INFERENCE_STEPS})",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Quiet mode - reduce output verbosity",
    )
    return parser


def main(argv=None):
    """Main entry point for CLI.

    Parses the command line, validates the numeric ranges, runs the
    enhancement and terminates the process via ``sys.exit``.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (the default when ``None``), which allows calling
            ``main`` programmatically.

    Exit status:
        0 on success, 1 if processing raised an exception, 2 (argparse's
        ``parser.error``) for invalid arguments.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Chained comparisons also reject NaN, which would slip through a pair of
    # "<" / ">" checks because every comparison with NaN is false.
    if not 1.0 <= args.guidance_scale <= 5.0:
        parser.error("guidance-scale must be between 1.0 and 5.0")

    if not 10 <= args.steps <= 50:
        parser.error("steps must be between 10 and 50")

    try:
        process_images_cli(
            input_paths=args.input,
            output_dir=args.output,
            prompt=args.prompt,
            guidance_scale=args.guidance_scale,
            num_steps=args.steps,
            verbose=not args.quiet,
        )
    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        sys.exit(1)
    else:
        if not args.quiet:
            print("\n✅ Enhancement completed successfully!")
            print(f"📁 Output directory: {args.output}")
        sys.exit(0)


if __name__ == "__main__":
    main()
|
|