File size: 12,507 Bytes
dd1ad6f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 | import gradio as gr
import torch
from diffusers import Flux2Pipeline
from diffusers.utils import load_image
from PIL import Image
import os
import zipfile
import py7zr
import tempfile
import shutil
from pathlib import Path
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.util import img_as_float
import io
# Load the Flux 2 pipeline once at import time so every request reuses it.
print("Loading Flux model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = Flux2Pipeline.from_pretrained(
    "diffusers/FLUX.2-dev-bnb-4bit",
    torch_dtype=torch.bfloat16,
    fix_mistral_regex=True,
)
# NOTE(review): the 4-bit bitsandbytes checkpoint presumably requires CUDA;
# confirm the CPU fallback actually works before relying on it.
pipe.to(device)
print("Model loaded successfully")
# Prompt used when the user leaves the prompt box empty.
DEFAULT_PROMPT = (
    "enhance microscopy image with subtle improvements, "
    "gently increase cellular boundary clarity, "
    "preserve original morphological structure, "
    "maintain authentic texture patterns, "
    "minimal noise reduction while keeping fine details intact"
)

# Default Flux pipeline settings (also the initial slider values in the UI).
GUIDANCE_SCALE = 2.0
NUM_INFERENCE_STEPS = 30

# Lower-case file extensions (with leading dot) treated as images.
IMAGE_EXTENSIONS = {
    ".bmp",
    ".jpeg",
    ".jpg",
    ".png",
    ".tif",
    ".tiff",
}
def calculate_psnr_ssim(original, enhanced):
    """Calculate PSNR and SSIM between an original and an enhanced image.

    Both images are converted with ``np.array`` and scaled to floats in
    [0, 1] by ``img_as_float``, so both metrics use ``data_range=1.0``.

    Before comparison the images are brought to a common shape:

    * If both are PIL images of different sizes, ``enhanced`` is resized to
      ``original.size`` so pixels are compared at matching positions.
    * A grayscale (2-D) array is repeated across the other image's
      channels, and surplus channels (e.g. an alpha channel) are dropped.
    * Any remaining height/width difference is cropped to the common
      top-left region.

    Args:
        original: Reference image (PIL image or array-like).
        enhanced: Image to compare with ``original`` (PIL image or
            array-like).

    Returns:
        ``(psnr, ssim)``: PSNR in dB and the SSIM value from skimage.
    """
    # Resize rather than crop when the pipeline returned another
    # resolution: cropping a rescaled image compares unrelated pixels.
    if (
        isinstance(original, Image.Image)
        and isinstance(enhanced, Image.Image)
        and enhanced.size != original.size
    ):
        enhanced = enhanced.resize(original.size)

    orig_float = img_as_float(np.array(original))
    enhanced_float = img_as_float(np.array(enhanced))

    # Grayscale vs. colour: repeat the 2-D image across the other image's
    # channels, otherwise skimage rejects the mismatched shapes.
    if orig_float.ndim == 2 and enhanced_float.ndim == 3:
        orig_float = np.repeat(
            orig_float[..., np.newaxis], enhanced_float.shape[2], axis=2
        )
    elif enhanced_float.ndim == 2 and orig_float.ndim == 3:
        enhanced_float = np.repeat(
            enhanced_float[..., np.newaxis], orig_float.shape[2], axis=2
        )

    # Different channel counts (e.g. RGBA vs. RGB): keep the shared channels.
    if orig_float.ndim == 3 and enhanced_float.ndim == 3:
        n_channels = min(orig_float.shape[2], enhanced_float.shape[2])
        orig_float = orig_float[..., :n_channels]
        enhanced_float = enhanced_float[..., :n_channels]

    # Crop any residual spatial difference (e.g. for non-PIL inputs).
    min_h = min(orig_float.shape[0], enhanced_float.shape[0])
    min_w = min(orig_float.shape[1], enhanced_float.shape[1])
    orig_float = orig_float[:min_h, :min_w]
    enhanced_float = enhanced_float[:min_h, :min_w]

    psnr = peak_signal_noise_ratio(orig_float, enhanced_float, data_range=1.0)

    if orig_float.ndim == 3:  # colour: last axis holds the channels
        ssim = structural_similarity(
            orig_float, enhanced_float, data_range=1.0, channel_axis=-1
        )
    else:  # grayscale
        ssim = structural_similarity(orig_float, enhanced_float, data_range=1.0)
    return psnr, ssim
def extract_archive(archive_path, extract_to):
    """Extract a ``.zip`` or ``.7z`` archive into ``extract_to``.

    The format is chosen from the file extension, case-insensitively.

    Args:
        archive_path: Path of the archive to extract.
        extract_to: Destination directory.

    Raises:
        ValueError: If the extension is neither ``.zip`` nor ``.7z``, or if
            a ``.7z`` member path would resolve outside ``extract_to``.
    """
    file_ext = Path(archive_path).suffix.lower()
    if file_ext == ".zip":
        # zipfile itself strips absolute paths and ".." components.
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(extract_to)
    elif file_ext == ".7z":
        with py7zr.SevenZipFile(archive_path, mode="r") as archive:
            # Uploaded archives are untrusted: refuse any member whose path
            # ("../x", absolute paths) would escape the target directory.
            root = os.path.realpath(extract_to)
            for member in archive.getnames():
                target = os.path.realpath(os.path.join(root, member))
                if os.path.commonpath([root, target]) != root:
                    raise ValueError(f"Unsafe path in archive: {member}")
            archive.extractall(path=extract_to)
    else:
        raise ValueError(f"Unsupported archive format: {file_ext}")
def find_images(directory, extensions=None):
    """Recursively find image files under ``directory``.

    macOS archive metadata is skipped: ``__MACOSX`` folders and ``._*``
    AppleDouble files carry image extensions but are not decodable images.

    Args:
        directory: Root directory to search.
        extensions: Optional iterable of extensions (with leading dot,
            matched case-insensitively). Defaults to ``IMAGE_EXTENSIONS``.

    Returns:
        Sorted list of matching file paths, giving a deterministic
        processing order.
    """
    if extensions is None:
        wanted = IMAGE_EXTENSIONS
    else:
        wanted = {ext.lower() for ext in extensions}

    image_files = []
    for root, dirs, files in os.walk(directory):
        # Prune in place so os.walk never descends into metadata folders.
        dirs[:] = [d for d in dirs if d != "__MACOSX"]
        for name in files:
            if name.startswith("._"):
                continue
            if Path(name).suffix.lower() in wanted:
                image_files.append(os.path.join(root, name))
    return sorted(image_files)
def enhance_single_image(image, prompt, guidance_scale, num_steps):
    """Run the Flux pipeline on a single image.

    Args:
        image: Either a string (loaded via ``load_image``) or an
            already-loaded image, which is passed to the pipeline unchanged.
        prompt: Text prompt describing the enhancement.
        guidance_scale: Forwarded as the pipeline's ``guidance_scale``.
        num_steps: Forwarded as ``num_inference_steps``.

    Returns:
        ``(input_image, enhanced_image)``: the image given to the pipeline
        and the first image of the pipeline's output.
    """
    source = load_image(image) if isinstance(image, str) else image
    result = pipe(
        prompt=prompt,
        image=source,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
    )
    return source, result.images[0]
def _collect_images(files, temp_dir, progress):
    """Expand uploads into ``(image_path, relative_output_path)`` pairs.

    Archives are extracted under ``temp_dir`` and each image inside is
    paired with its path relative to the archive root; standalone images
    are paired with their bare file name. Other file types are ignored.
    """
    collected = []
    for upload_idx, file_obj in enumerate(files):
        # Uploads may be path strings or objects exposing the path as .name.
        file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
        upload = Path(file_path)
        file_ext = upload.suffix.lower()
        if file_ext in (".zip", ".7z"):
            progress(0.1, desc=f"Extracting archive: {upload.name}...")
            # Prefix with the upload index so archives sharing a stem
            # (e.g. data.zip and data.7z) extract into separate folders.
            extract_dir = os.path.join(temp_dir, f"{upload_idx}_{upload.stem}")
            extract_archive(file_path, extract_dir)
            for img_path in find_images(extract_dir):
                collected.append((img_path, os.path.relpath(img_path, extract_dir)))
        elif file_ext in IMAGE_EXTENSIONS:
            collected.append((file_path, upload.name))
    return collected


def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progress()):
    """Enhance uploaded images (standalone files and/or ZIP/7Z archives).

    Args:
        files: Uploaded files; each is a path string or an object exposing
            the path as ``.name``.
        prompt: Enhancement prompt; empty/blank falls back to DEFAULT_PROMPT.
        guidance_scale: Forwarded to the Flux pipeline.
        num_steps: Number of inference steps for the Flux pipeline.
        progress: Gradio progress tracker.

    Returns:
        ``(gallery_images, zip_path, summary)`` for the gallery, download
        and summary outputs. On empty input or failure the first two items
        are None and ``summary`` carries the message.
    """
    if not files:
        return None, None, "Please upload at least one file."
    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT

    preview_limit = 10  # number of original/enhanced pairs shown in the gallery

    # Scratch space for extracted archives and enhanced images; both are
    # deleted in ``finally`` once the results have been zipped.
    temp_dir = tempfile.mkdtemp()
    output_dir = tempfile.mkdtemp()
    try:
        progress(0, desc="Processing files...")
        all_images = _collect_images(files, temp_dir, progress)
        if not all_images:
            return None, None, "No valid images found in uploaded files."

        total_images = len(all_images)
        progress(0.2, desc=f"Found {total_images} images. Starting enhancement...")

        gallery_images = []
        metrics_summary = []
        psnr_values = []
        ssim_values = []
        for idx, (img_path, rel_path) in enumerate(all_images):
            progress(
                0.2 + 0.7 * idx / total_images,
                desc=f"Processing {idx + 1}/{total_images}: {Path(img_path).name}...",
            )
            original, enhanced = enhance_single_image(
                img_path, prompt, guidance_scale, num_steps
            )
            psnr, ssim = calculate_psnr_ssim(original, enhanced)

            # Mirror the input layout (archive-relative path, or bare name
            # for standalone uploads) and tag the file name with "_flux".
            rel = Path(rel_path)
            output_path = Path(output_dir) / rel.parent / f"{rel.stem}_flux{rel.suffix}"
            output_path.parent.mkdir(parents=True, exist_ok=True)
            enhanced.save(output_path)

            psnr_values.append(psnr)
            ssim_values.append(ssim)
            metrics_summary.append(f"{rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}")

            # Only the previewed pairs are kept in memory.
            if idx < preview_limit:
                gallery_images.append((original, f"Original: {rel_path}"))
                gallery_images.append(
                    (
                        enhanced,
                        f"Enhanced: {rel_path}\nPSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}",
                    )
                )

        progress(0.9, desc="Creating output package...")
        # A fresh directory per request stops concurrent sessions from
        # overwriting each other's download. It is not removed here because
        # the returned path must still exist when the file is served.
        zip_dir = tempfile.mkdtemp()
        output_zip_path = os.path.join(zip_dir, "enhanced_images_flux.zip")
        with zipfile.ZipFile(output_zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, names in os.walk(output_dir):
                for name in names:
                    file_path = os.path.join(root, name)
                    zipf.write(file_path, os.path.relpath(file_path, output_dir))

        summary = (
            "✅ Processing completed!\n\n"
            f"Total images processed: {total_images}\n"
            f"Average PSNR: {np.mean(psnr_values):.2f} dB\n"
            f"Average SSIM: {np.mean(ssim_values):.4f}\n\n"
            "Individual metrics:\n"
        ) + "\n".join(metrics_summary)

        progress(1.0, desc="Done!")
        return gallery_images, output_zip_path, summary
    except Exception as e:
        # UI boundary: show the error in the summary box, but keep the
        # traceback in the server log for debugging.
        import traceback

        traceback.print_exc()
        return None, None, f"Error during processing: {str(e)}"
    finally:
        # ignore_errors: a cleanup failure must not mask the real result.
        for scratch_dir in (temp_dir, output_dir):
            shutil.rmtree(scratch_dir, ignore_errors=True)
# Create Gradio interface
with gr.Blocks(title="Flux Microscopy Image Enhancement") as demo:
    gr.Markdown(
        """
# 🔬 Flux Microscopy Image Enhancement

Upload microscopy images (individual files or compressed archives) for AI-powered enhancement.

**Supported formats:**
- Images: JPG, PNG, BMP, TIFF
- Archives: ZIP, 7Z (will process all images inside)

**Features:**
- Batch processing support
- Custom enhancement prompts
- Quality metrics (PSNR & SSIM)
- Download results as ZIP with `_flux` suffix
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            # File upload
            file_input = gr.File(
                label="Upload Images or Archive (ZIP/7Z)",
                file_count="multiple",
                file_types=["image", ".zip", ".7z"],
            )

            # Prompt input (process_images falls back to DEFAULT_PROMPT when blank)
            prompt_input = gr.Textbox(
                label="Enhancement Prompt",
                placeholder="Enter custom prompt or leave empty for default",
                value=DEFAULT_PROMPT,
                lines=3,
            )

            # Parameter controls, initialised from the module-level defaults
            gr.Markdown("### Enhancement Parameters")
            guidance_scale_input = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=GUIDANCE_SCALE,
                step=0.1,
                label="Guidance Scale",
                info="Controls enhancement strength (lower = more conservative, higher = more creative)",
            )
            num_steps_input = gr.Slider(
                minimum=10,
                maximum=50,
                value=NUM_INFERENCE_STEPS,
                step=1,
                label="Inference Steps",
                info="Number of processing steps (more steps = better quality but slower)",
            )

            # Process button
            process_btn = gr.Button("🚀 Enhance Images", variant="primary", size="lg")

            # Example (no file; fills in the default prompt and parameters)
            gr.Markdown("### Example")
            gr.Examples(
                examples=[
                    [None, DEFAULT_PROMPT, GUIDANCE_SCALE, NUM_INFERENCE_STEPS],
                ],
                inputs=[
                    file_input,
                    prompt_input,
                    guidance_scale_input,
                    num_steps_input,
                ],
            )

        with gr.Column(scale=2):
            # Gallery for results (original/enhanced pairs)
            gallery_output = gr.Gallery(
                label="Results Preview (Original vs Enhanced)",
                columns=2,
                rows=2,
                height="auto",
                object_fit="contain",
            )

            # Download button
            download_output = gr.File(label="📥 Download All Enhanced Images (ZIP)")

            # Metrics summary
            summary_output = gr.Textbox(
                label="Processing Summary & Metrics", lines=10, max_lines=20
            )

    # Process button click: outputs map to process_images' return tuple
    process_btn.click(
        fn=process_images,
        inputs=[file_input, prompt_input, guidance_scale_input, num_steps_input],
        outputs=[gallery_output, download_output, summary_output],
    )

    # Defaults are interpolated from the constants so this text cannot drift
    # from the values actually used by the sliders.
    gr.Markdown(
        f"""
---
### Default Parameters
- **Guidance Scale**: {GUIDANCE_SCALE} (conservative for natural enhancement)
- **Inference Steps**: {NUM_INFERENCE_STEPS} (balanced quality and speed)

💡 You can adjust these parameters above to customize the enhancement process.

### Quality Metrics
Both metrics compare each enhanced image with its original, so they measure how faithfully the original was preserved, not how much it was improved.
- **PSNR** (Peak Signal-to-Noise Ratio): higher means closer to the original (>30 dB indicates only subtle changes)
- **SSIM** (Structural Similarity Index): closer to 1.0 means structure is better preserved (>0.9 indicates very high structural similarity)
"""
    )
# Launch the app
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    # share=False: no public share link is requested.
    demo.launch(share=False)
|