yichuan-huang's picture
Add CLI support for microscopy image enhancement and update README
8a7a521
raw
history blame
11.4 kB
#!/usr/bin/env python3
"""
Flux Microscopy Image Enhancement - Command Line Interface
Process microscopy images with AI-powered enhancement using argparse for all parameters
"""
import argparse
import torch
from diffusers import Flux2Pipeline
from diffusers.utils import load_image
from pathlib import Path
from PIL import Image
import os
import shutil
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.util import img_as_float
import sys
from typing import List, Tuple
# =========================
# Config
# =========================
# Prompt used when the caller passes an empty or whitespace-only prompt to
# process_images_cli(); also the default value of the --prompt CLI option.
DEFAULT_PROMPT = (
    "enhance microscopy image with subtle improvements, gently increase cellular boundary clarity, "
    "preserve original morphological structure, maintain authentic texture patterns, "
    "minimal noise reduction while keeping fine details intact"
)
# Default value of the --guidance-scale CLI option (main() accepts 1.0-5.0).
GUIDANCE_SCALE = 2.0
# Default value of the --steps CLI option (main() accepts 10-50).
NUM_INFERENCE_STEPS = 30
# File suffixes (compared lower-cased) that are treated as images when
# scanning directories and validating individual input files.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
# Model identifier handed to Flux2Pipeline.from_pretrained() in _get_pipe().
MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
# torch dtype passed as `torch_dtype` when the pipeline is loaded.
TORCH_DTYPE = torch.bfloat16
# =========================
# Global cached pipeline
# =========================
# Holds the loaded pipeline; stays None until _get_pipe() loads the model on
# first use, after which the same object is reused.
_pipe = None
def calculate_psnr_ssim(original: Image.Image, enhanced: Image.Image):
    """Calculate PSNR and SSIM between original and enhanced images.

    Both images are converted to floating-point arrays with ``img_as_float``
    and compared using ``data_range=1.0``.

    Args:
        original: Reference image.
        enhanced: Image compared against the reference.

    Returns:
        Tuple ``(psnr, ssim)`` as plain Python floats.

    Notes:
        * If the two images have different PIL modes, ``enhanced`` is
          converted to the mode of ``original`` before comparison so both
          arrays have the same rank and channel count.
        * If the pixel dimensions differ, both arrays are cropped to their
          common top-left region; no resampling is performed.
    """
    # A mode mismatch (e.g. "L" vs "RGB", or "RGB" vs "RGBA") yields arrays
    # with a different rank or channel count. The spatial crop below cannot
    # repair that, and the metric functions reject mismatched shapes.
    if enhanced.mode != original.mode:
        enhanced = enhanced.convert(original.mode)

    orig_float = img_as_float(np.array(original))
    enh_float = img_as_float(np.array(enhanced))

    # Ensure same shape (crop to min overlap)
    if orig_float.shape != enh_float.shape:
        min_h = min(orig_float.shape[0], enh_float.shape[0])
        min_w = min(orig_float.shape[1], enh_float.shape[1])
        orig_float = orig_float[:min_h, :min_w]
        enh_float = enh_float[:min_h, :min_w]

    psnr = peak_signal_noise_ratio(orig_float, enh_float, data_range=1.0)

    if orig_float.ndim == 3:
        # Multi-channel image: the last axis holds the channels.
        ssim = structural_similarity(
            orig_float, enh_float, data_range=1.0, channel_axis=-1
        )
    else:
        ssim = structural_similarity(orig_float, enh_float, data_range=1.0)

    return float(psnr), float(ssim)
def find_images(directory: str) -> List[str]:
    """Recursively find all images in a directory.

    A file counts as an image when its suffix, lower-cased, is a member of
    ``IMAGE_EXTENSIONS``.

    Args:
        directory: Root directory to walk.

    Returns:
        Sorted list of image paths, each formed by joining the walked
        directory with the file name. Empty if nothing matches.
    """
    image_files = []
    for root, _, files in os.walk(directory):
        image_files.extend(
            os.path.join(root, name)
            for name in files
            if Path(name).suffix.lower() in IMAGE_EXTENSIONS
        )
    # os.walk yields entries in an arbitrary, filesystem-dependent order;
    # sorting keeps batch processing order and reports reproducible.
    image_files.sort()
    return image_files
def _get_pipe():
    """
    Lazy-load the pipeline on local GPU.

    The first call loads ``MODEL_ID`` with ``TORCH_DTYPE``, moves the pipeline
    to CUDA and stores it in the module-level ``_pipe``; later calls return
    the stored object.

    Returns:
        The cached pipeline instance.

    Raises:
        RuntimeError: If ``torch.cuda.is_available()`` reports no GPU.
    """
    global _pipe
    if _pipe is not None:
        return _pipe

    if not torch.cuda.is_available():
        raise RuntimeError(
            "No GPU found. This tool requires a CUDA-compatible GPU to run."
        )

    print("Loading Flux model...")
    pipe = Flux2Pipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
    )
    pipe.to("cuda")
    # Publish to the cache only after the move to the GPU succeeded, so a
    # failure here (e.g. out of memory) does not leave a half-initialised
    # pipeline cached for subsequent calls.
    _pipe = pipe
    print(f"Model loaded successfully on GPU: {torch.cuda.get_device_name(0)}")
    return _pipe
def _collect_input_images(
    input_paths: List[str], verbose: bool
) -> List[Tuple[str, str]]:
    """Expand input paths into ``(image_path, relative_output_path)`` pairs."""
    collected = []
    for input_path in input_paths:
        if not os.path.exists(input_path):
            print(f"Warning: Path not found: {input_path}")
            continue

        if os.path.isdir(input_path):
            if verbose:
                print(f"\n[Scanning] Directory: {input_path} ...")
            images = find_images(input_path)
            # Keep the path relative to the scanned directory so the folder
            # structure is reproduced under the output directory.
            collected.extend(
                (img_path, os.path.relpath(img_path, input_path))
                for img_path in images
            )
            if verbose:
                print(f" Found {len(images)} images in directory")
        elif os.path.isfile(input_path):
            if Path(input_path).suffix.lower() in IMAGE_EXTENSIONS:
                # Single files are written directly into the output directory.
                collected.append((input_path, Path(input_path).name))
            else:
                print(f"Warning: Unsupported file format: {input_path}")
        else:
            print(f"Warning: Invalid path: {input_path}")
    return collected


def _reserve_output_path(output_dir: str, rel_path: str, used_paths: set) -> str:
    """Return the ``_flux`` output path for ``rel_path``, unique within this run."""
    target_dir = os.path.dirname(os.path.join(output_dir, rel_path))
    stem = Path(rel_path).stem
    suffix = Path(rel_path).suffix

    candidate = os.path.join(target_dir, f"{stem}_flux{suffix}")
    key = os.path.normcase(os.path.abspath(candidate))
    counter = 1
    # Different inputs can map to the same relative path (e.g. two folders
    # that both contain "img.png"); add a numeric suffix instead of silently
    # overwriting a file that was written earlier in this run.
    while key in used_paths:
        candidate = os.path.join(target_dir, f"{stem}_flux_{counter}{suffix}")
        key = os.path.normcase(os.path.abspath(candidate))
        counter += 1
    used_paths.add(key)
    return candidate


def process_images_cli(
    input_paths: List[str],
    output_dir: str,
    prompt: str,
    guidance_scale: float,
    num_steps: int,
    verbose: bool = True,
) -> Tuple[int, List[dict]]:
    """
    Process images from input paths and save to output directory.

    Each image is run through the cached pipeline, compared with its input
    via PSNR/SSIM, and saved under ``output_dir`` with ``_flux`` appended to
    the file stem. Images found inside an input directory keep their path
    relative to that directory. When two inputs would produce the same output
    file within one call, later ones get an extra ``_1``, ``_2``, ... suffix.

    Args:
        input_paths: List of file/directory paths (images or folders)
        output_dir: Output directory for enhanced images (created if missing)
        prompt: Enhancement prompt; empty/whitespace falls back to DEFAULT_PROMPT
        guidance_scale: Guidance scale for inference
        num_steps: Number of inference steps
        verbose: Whether to print progress messages

    Returns:
        Tuple of (total_images_processed, results_list). Each result is a
        dict with keys ``original_path``, ``output_path``, ``filename``,
        ``psnr`` and ``ssim``.

    Raises:
        ValueError: If ``input_paths`` is empty or contains no usable image.
        Exception: Any error raised while loading the model, running
            inference or saving is reported on stderr and re-raised.
    """
    if not input_paths:
        raise ValueError("No input files provided")

    if not prompt or prompt.strip() == "":
        prompt = DEFAULT_PROMPT

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    try:
        if verbose:
            print("=" * 60)
            print("Flux Microscopy Image Enhancement - CLI")
            print("=" * 60)
            print(f"Output directory: {output_dir}")
            print(f"Prompt: {prompt}")
            print(f"Guidance scale: {guidance_scale}")
            print(f"Inference steps: {num_steps}")
            print("=" * 60)

        all_images = _collect_input_images(input_paths, verbose)
        if not all_images:
            raise ValueError("No valid images found in input files")

        total_images = len(all_images)
        if verbose:
            print(f"\n[Processing] Total images to enhance: {total_images}")
            print("-" * 60)

        # Load pipeline
        pipe = _get_pipe()

        results = []
        metrics_lines = []
        used_output_paths = set()

        for idx, (img_path, output_rel_path) in enumerate(all_images, 1):
            if verbose:
                print(f"\n[{idx}/{total_images}] Processing: {Path(img_path).name}")

            # Load input image
            input_image = load_image(img_path)

            # Run inference
            enhanced_image = pipe(
                image=input_image,
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
            ).images[0]

            # Calculate metrics
            psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)
            if verbose:
                print(f" PSNR: {psnr:.2f} dB | SSIM: {ssim:.4f}")

            # Determine output path (adds the _flux suffix, keeps sub-folders)
            out_path = _reserve_output_path(
                output_dir, output_rel_path, used_output_paths
            )
            os.makedirs(os.path.dirname(out_path), exist_ok=True)

            # Save enhanced image
            enhanced_image.save(out_path)
            if verbose:
                print(f" Saved to: {out_path}")

            results.append(
                {
                    "original_path": img_path,
                    "output_path": out_path,
                    "filename": output_rel_path,
                    "psnr": psnr,
                    "ssim": ssim,
                }
            )
            metrics_lines.append(
                f"{output_rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}"
            )

        # Print summary
        if verbose:
            avg_psnr = float(np.mean([r["psnr"] for r in results])) if results else 0.0
            avg_ssim = float(np.mean([r["ssim"] for r in results])) if results else 0.0
            print("\n" + "=" * 60)
            print("SUMMARY")
            print("=" * 60)
            print(f"Total images processed: {total_images}")
            print(f"Average PSNR: {avg_psnr:.2f} dB")
            print(f"Average SSIM: {avg_ssim:.4f}")
            print("\nIndividual metrics:")
            for line in metrics_lines:
                print(f" {line}")
            print("=" * 60)

        return total_images, results

    except Exception as e:
        # Report on stderr for CLI users, then let the caller decide what to do.
        print(f"\nError during processing: {str(e)}", file=sys.stderr)
        raise
def _build_cli_parser():
    """Create the argument parser describing every CLI option."""
    cli = argparse.ArgumentParser(
        description="Flux Microscopy Image Enhancement - Command Line Interface",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Enhance a single image
python enhance_cli.py -i input.jpg -o output/
# Enhance multiple images with custom parameters
python enhance_cli.py -i image1.jpg image2.png -o output/ --guidance-scale 3.0 --steps 40
# Process all images in a directory
python enhance_cli.py -i images_folder/ -o output/
# Process with custom prompt
python enhance_cli.py -i input.jpg -o output/ --prompt "enhance cellular structure"
# Quiet mode (less verbose output)
python enhance_cli.py -i input.jpg -o output/ --quiet
Supported formats:
- Images: JPG, JPEG, PNG, BMP, TIFF, TIF
- Directories: Will recursively process all images inside
""",
    )

    # Mandatory: where to read from and where to write to.
    cli.add_argument(
        "-i",
        "--input",
        nargs="+",
        required=True,
        help="Input path(s) - image files or directories. Multiple paths supported.",
    )
    cli.add_argument(
        "-o",
        "--output",
        required=True,
        help="Output directory for enhanced images",
    )

    # Tunables, each falling back to the module-level default.
    cli.add_argument(
        "-p",
        "--prompt",
        default=DEFAULT_PROMPT,
        help=f"Enhancement prompt (default: '{DEFAULT_PROMPT[:50]}...')",
    )
    cli.add_argument(
        "-g",
        "--guidance-scale",
        type=float,
        default=GUIDANCE_SCALE,
        help=f"Guidance scale (1.0-5.0, lower=conservative, higher=creative, default: {GUIDANCE_SCALE})",
    )
    cli.add_argument(
        "-s",
        "--steps",
        type=int,
        default=NUM_INFERENCE_STEPS,
        help=f"Number of inference steps (10-50, more=better quality but slower, default: {NUM_INFERENCE_STEPS})",
    )
    cli.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Quiet mode - reduce output verbosity",
    )
    return cli


def main():
    """Parse the command line, validate ranges, run the enhancement and exit.

    Exits with status 0 on success and 1 when processing raises an exception;
    out-of-range option values are rejected through the parser's own error
    reporting.
    """
    cli = _build_cli_parser()
    opts = cli.parse_args()

    # Range checks that argparse's type conversion alone cannot express.
    if opts.guidance_scale > 5.0 or opts.guidance_scale < 1.0:
        cli.error("guidance-scale must be between 1.0 and 5.0")
    if opts.steps > 50 or opts.steps < 10:
        cli.error("steps must be between 10 and 50")

    chatty = not opts.quiet
    try:
        process_images_cli(
            input_paths=opts.input,
            output_dir=opts.output,
            prompt=opts.prompt,
            guidance_scale=opts.guidance_scale,
            num_steps=opts.steps,
            verbose=chatty,
        )
        if chatty:
            print("\n✅ Enhancement completed successfully!")
            print(f"📁 Output directory: {opts.output}")
        sys.exit(0)
    except Exception as exc:
        # SystemExit is not an Exception subclass, so the exit above passes through.
        print(f"\n❌ Error: {str(exc)}", file=sys.stderr)
        sys.exit(1)
if __name__ == "__main__":
main()