yichuan-huang commited on
Commit
8a7a521
·
1 Parent(s): 19d1ea6

Add CLI support for microscopy image enhancement and update README

Browse files
Files changed (3) hide show
  1. README.md +41 -1
  2. app.py +143 -72
  3. enhance_cli.py +369 -0
README.md CHANGED
@@ -14,6 +14,10 @@ license: apache-2.0
14
 
15
  An AI-powered microscopy image enhancement tool using the FLUX.2 model. This application provides intelligent image enhancement while preserving cellular structures and fine details.
16
 
 
 
 
 
17
  ## ✨ Features
18
 
19
  - **Batch Processing**: Process multiple images at once or entire archived folders
@@ -91,10 +95,46 @@ cd flux-image-enhance
91
  # Install dependencies
92
  pip install -r requirements.txt
93
 
94
- # Run the application
95
  python app.py
 
 
 
96
  ```
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  ## 📋 Requirements
99
 
100
  - Python 3.8+
 
14
 
15
  An AI-powered microscopy image enhancement tool using the FLUX.2 model. This application provides intelligent image enhancement while preserving cellular structures and fine details.
16
 
17
+ **Available in two versions:**
18
+ - 🌐 **Web UI** (Gradio): Interactive web interface with drag-and-drop support
19
+ - ⌨️ **CLI** (Command Line): Batch processing tool for automation and scripting
20
+
21
  ## ✨ Features
22
 
23
  - **Batch Processing**: Process multiple images at once or entire archived folders
 
95
  # Install dependencies
96
  pip install -r requirements.txt
97
 
98
+ # Run the web UI
99
  python app.py
100
+
101
+ # Or use the CLI version
102
+ python enhance_cli.py -i input.jpg -o output/
103
  ```
104
 
105
+ ## ⌨️ CLI Usage
106
+
107
+ For batch processing and automation, use the command-line interface:
108
+
109
+ ```bash
110
+ # Basic usage
111
+ python enhance_cli.py -i input.jpg -o output/
112
+
113
+ # Process multiple images
114
+ python enhance_cli.py -i img1.jpg img2.png img3.tif -o output/
115
+
116
+ # Process entire directory (recursive)
117
+ python enhance_cli.py -i images_folder/ -o output/
118
+
119
+ # Custom parameters
120
+ python enhance_cli.py -i input.jpg -o output/ \
121
+ --guidance-scale 3.0 \
122
+ --steps 40 \
123
+ --prompt "enhance cellular structure"
124
+
125
+ # Quiet mode (minimal output)
126
+ python enhance_cli.py -i input.jpg -o output/ --quiet
127
+ ```
128
+
129
+ ### CLI Arguments
130
+
131
+ - `-i, --input`: Input path(s) - image files or directories (required)
132
+ - `-o, --output`: Output directory for enhanced images (required)
133
+ - `-p, --prompt`: Enhancement prompt (optional)
134
+ - `-g, --guidance-scale`: Guidance scale 1.0-5.0 (default: 2.0)
135
+ - `-s, --steps`: Inference steps 10-50 (default: 30)
136
+ - `-q, --quiet`: Quiet mode - minimal output
137
+
138
  ## 📋 Requirements
139
 
140
  - Python 3.8+
app.py CHANGED
@@ -1,21 +1,23 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
- import spaces
 
4
 
5
  from diffusers import Flux2Pipeline
6
  from diffusers.utils import load_image
7
- from pathlib import Path
8
 
9
- from PIL import Image
10
- import os
11
- import zipfile
12
  import py7zr
13
- import tempfile
14
- import shutil
15
- import numpy as np
16
  from skimage.metrics import peak_signal_noise_ratio, structural_similarity
17
  from skimage.util import img_as_float
18
 
 
19
  # =========================
20
  # Config
21
  # =========================
@@ -36,8 +38,47 @@ IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
36
  MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
37
  TORCH_DTYPE = torch.bfloat16
38
 
 
39
  # =========================
40
- # Global cached pipeline (created only inside GPU runtime)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # =========================
42
  _pipe = None
43
 
@@ -92,38 +133,43 @@ def find_images(directory: str):
92
 
93
  def _get_pipe():
94
  """
95
- Lazy-load the pipeline (ONLY call this inside a @spaces.GPU function).
 
 
96
  """
97
  global _pipe
98
 
99
  if _pipe is None:
100
  if not torch.cuda.is_available():
101
- # On ZeroGPU startup this is False; inside @spaces.GPU it should become True
102
  raise RuntimeError(
103
- "No GPU found in current runtime. "
104
- "On HF ZeroGPU you must only load the bnb-4bit model inside a @spaces.GPU function."
105
  )
106
 
107
  _pipe = Flux2Pipeline.from_pretrained(
108
  MODEL_ID,
109
  torch_dtype=TORCH_DTYPE,
110
- # removed fix_mistral_regex=True (it was ignored for Flux2Pipeline)
111
  )
112
  _pipe.to("cuda")
113
 
114
  return _pipe
115
 
116
 
117
- @spaces.GPU(duration=180)
118
  def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progress()):
119
  """
120
  Process uploaded files (images or archives) and return:
121
  - gallery preview (first 10 pairs)
122
- - path to output zip
 
123
  - summary text
124
  """
125
  if not files:
126
- return None, None, "Please upload at least one file."
 
 
 
 
 
127
 
128
  if not prompt or prompt.strip() == "":
129
  prompt = DEFAULT_PROMPT
@@ -131,9 +177,16 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
131
  guidance_scale = float(guidance_scale)
132
  num_steps = int(num_steps)
133
 
134
- # Temp dirs
135
  temp_dir = tempfile.mkdtemp(prefix="flux_in_")
136
- output_dir = tempfile.mkdtemp(prefix="flux_out_")
 
 
 
 
 
 
 
137
 
138
  try:
139
  progress(0.0, desc="Preparing files...")
@@ -144,6 +197,7 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
144
  file_ext = Path(file_path).suffix.lower()
145
 
146
  if file_ext in [".zip", ".7z"]:
 
147
  progress(0.05, desc=f"Extracting: {Path(file_path).name} ...")
148
  extract_dir = os.path.join(temp_dir, Path(file_path).stem)
149
  os.makedirs(extract_dir, exist_ok=True)
@@ -158,12 +212,17 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
158
  all_images.append((file_path, Path(file_path).name, None))
159
 
160
  if not all_images:
161
- return None, None, "No valid images found in uploaded files."
 
 
 
 
 
162
 
163
  total_images = len(all_images)
164
  progress(0.10, desc=f"Found {total_images} images. Loading model...")
165
 
166
- # Load pipeline INSIDE GPU runtime
167
  pipe = _get_pipe()
168
 
169
  results = []
@@ -177,10 +236,8 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
177
  desc=f"Processing {idx+1}/{total_images}: {Path(img_path).name}",
178
  )
179
 
180
- # Load input image
181
  input_image = load_image(img_path)
182
 
183
- # Run inference
184
  enhanced_image = pipe(
185
  image=input_image,
186
  prompt=prompt,
@@ -188,14 +245,10 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
188
  num_inference_steps=num_steps,
189
  ).images[0]
190
 
191
- # Metrics
192
  psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)
193
 
194
- # Output path (preserve structure if from archive)
195
- if base_dir:
196
- output_rel_path = rel_path
197
- else:
198
- output_rel_path = rel_path
199
 
200
  out_path = os.path.join(output_dir, output_rel_path)
201
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
@@ -216,35 +269,25 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
216
  "ssim": ssim,
217
  }
218
  )
 
219
  metrics_lines.append(
220
  f"{output_rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}"
221
  )
222
 
223
- progress(0.92, desc="Packaging ZIP...")
224
-
225
- # Create output zip
226
- output_zip_path = os.path.join(
227
- tempfile.gettempdir(), "enhanced_images_flux.zip"
228
- )
229
- with zipfile.ZipFile(output_zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
230
- for root, _, fs in os.walk(output_dir):
231
- for f in fs:
232
- fp = os.path.join(root, f)
233
- arcname = os.path.relpath(fp, output_dir)
234
- zipf.write(fp, arcname)
235
-
236
  avg_psnr = float(np.mean([r["psnr"] for r in results])) if results else 0.0
237
  avg_ssim = float(np.mean([r["ssim"] for r in results])) if results else 0.0
238
 
239
  summary = (
240
  "✅ Processing completed!\n\n"
 
 
241
  f"Total images processed: {total_images}\n"
242
  f"Average PSNR: {avg_psnr:.2f} dB\n"
243
  f"Average SSIM: {avg_ssim:.4f}\n\n"
244
  "Individual metrics:\n" + "\n".join(metrics_lines)
245
  )
246
 
247
- # Build gallery preview (first 10 results -> 20 images: original+enhanced)
248
  gallery_images = []
249
  for r in results[:10]:
250
  gallery_images.append((r["original"], f"Original: {r['filename']}"))
@@ -255,18 +298,53 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
255
  )
256
  )
257
 
258
- progress(1.0, desc="Done!")
259
- return gallery_images, output_zip_path, summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
  except Exception as e:
262
- return None, None, f"Error during processing: {str(e)}"
 
 
 
 
 
263
 
264
  finally:
265
- # Cleanup temp dirs
 
266
  if os.path.exists(temp_dir):
267
  shutil.rmtree(temp_dir, ignore_errors=True)
268
- if os.path.exists(output_dir):
269
- shutil.rmtree(output_dir, ignore_errors=True)
270
 
271
 
272
  # =========================
@@ -274,7 +352,7 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
274
  # =========================
275
  with gr.Blocks(title="Flux Microscopy Image Enhancement") as demo:
276
  gr.Markdown(
277
- """
278
  # 🔬 Flux Microscopy Image Enhancement
279
 
280
  Upload microscopy images (individual files or compressed archives) for AI-powered enhancement.
@@ -283,11 +361,12 @@ Upload microscopy images (individual files or compressed archives) for AI-powere
283
  - Images: JPG, PNG, BMP, TIFF
284
  - Archives: ZIP, 7Z (will process all images inside)
285
 
286
- **Features:**
287
- - Batch processing support
288
- - Custom enhancement prompts
289
- - Quality metrics (PSNR & SSIM)
290
- - Download results as ZIP with `_flux` suffix
 
291
  """
292
  )
293
 
@@ -328,19 +407,6 @@ Upload microscopy images (individual files or compressed archives) for AI-powere
328
 
329
  process_btn = gr.Button("🚀 Enhance Images", variant="primary", size="lg")
330
 
331
- gr.Markdown("### Example")
332
- gr.Examples(
333
- examples=[
334
- [None, DEFAULT_PROMPT, GUIDANCE_SCALE, NUM_INFERENCE_STEPS],
335
- ],
336
- inputs=[
337
- file_input,
338
- prompt_input,
339
- guidance_scale_input,
340
- num_steps_input,
341
- ],
342
- )
343
-
344
  with gr.Column(scale=2):
345
  gallery_output = gr.Gallery(
346
  label="Results Preview (Original vs Enhanced)",
@@ -350,7 +416,12 @@ Upload microscopy images (individual files or compressed archives) for AI-powere
350
  object_fit="contain",
351
  )
352
 
353
- download_output = gr.File(label="📥 Download All Enhanced Images (ZIP)")
 
 
 
 
 
354
 
355
  summary_output = gr.Textbox(
356
  label="Processing Summary & Metrics",
@@ -361,7 +432,7 @@ Upload microscopy images (individual files or compressed archives) for AI-powere
361
  process_btn.click(
362
  fn=process_images,
363
  inputs=[file_input, prompt_input, guidance_scale_input, num_steps_input],
364
- outputs=[gallery_output, download_output, summary_output],
365
  )
366
 
367
  gr.Markdown(
@@ -372,8 +443,8 @@ Upload microscopy images (individual files or compressed archives) for AI-powere
372
  - **Inference Steps**: 30 (balanced quality and speed)
373
 
374
  ### Quality Metrics
375
- - **PSNR** (Peak Signal-to-Noise Ratio): Higher is better (>30 dB is good)
376
- - **SSIM** (Structural Similarity Index): Closer to 1.0 is better (>0.9 is excellent)
377
  """
378
  )
379
 
 
1
+ import os
2
+ import zipfile
3
+ import tempfile
4
+ import shutil
5
+ import uuid
6
+ from pathlib import Path
7
+
8
  import gradio as gr
9
  import torch
10
+ import numpy as np
11
+ from PIL import Image
12
 
13
  from diffusers import Flux2Pipeline
14
  from diffusers.utils import load_image
 
15
 
 
 
 
16
  import py7zr
 
 
 
17
  from skimage.metrics import peak_signal_noise_ratio, structural_similarity
18
  from skimage.util import img_as_float
19
 
20
+
21
  # =========================
22
  # Config
23
  # =========================
 
38
  MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
39
  TORCH_DTYPE = torch.bfloat16
40
 
41
+
42
  # =========================
43
+ # HF Spaces env detection + safe "spaces" decorator
44
+ # =========================
45
+ def _is_hf_space_env() -> bool:
46
+ """
47
+ Detect whether we're running on Hugging Face Spaces runtime.
48
+ Common env vars present in Spaces:
49
+ - SPACE_ID
50
+ - HF_SPACE_ID
51
+ - SYSTEM=spaces
52
+ """
53
+ return any(os.getenv(k) for k in ("SPACE_ID", "HF_SPACE_ID")) or (
54
+ os.getenv("SYSTEM", "").lower() == "spaces"
55
+ )
56
+
57
+
58
+ IS_HF_SPACES = _is_hf_space_env()
59
+
60
+ try:
61
+ import spaces # available on HF Spaces
62
+ except Exception:
63
+ spaces = None
64
+
65
+
66
+ def gpu_decorator(duration: int = 180):
67
+ """
68
+ If on HF Spaces and spaces.GPU exists -> use it.
69
+ Else -> no-op decorator (local runs normally, using local GPU if available).
70
+ """
71
+ if IS_HF_SPACES and (spaces is not None) and hasattr(spaces, "GPU"):
72
+ return spaces.GPU(duration=duration)
73
+
74
+ def _noop(fn):
75
+ return fn
76
+
77
+ return _noop
78
+
79
+
80
+ # =========================
81
+ # Global cached pipeline
82
  # =========================
83
  _pipe = None
84
 
 
133
 
134
  def _get_pipe():
135
  """
136
+ Lazy-load the pipeline.
137
+ - On HF Spaces ZeroGPU: MUST be called inside a @spaces.GPU runtime (gpu_decorator handles this).
138
+ - Locally: will use local GPU.
139
  """
140
  global _pipe
141
 
142
  if _pipe is None:
143
  if not torch.cuda.is_available():
 
144
  raise RuntimeError(
145
+ "No CUDA GPU detected. This 4-bit bnb model requires GPU to load/run."
 
146
  )
147
 
148
  _pipe = Flux2Pipeline.from_pretrained(
149
  MODEL_ID,
150
  torch_dtype=TORCH_DTYPE,
 
151
  )
152
  _pipe.to("cuda")
153
 
154
  return _pipe
155
 
156
 
157
+ @gpu_decorator(duration=180)
158
  def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progress()):
159
  """
160
  Process uploaded files (images or archives) and return:
161
  - gallery preview (first 10 pairs)
162
+ - files download (when only images uploaded)
163
+ - zip download (when archive uploaded)
164
  - summary text
165
  """
166
  if not files:
167
+ return (
168
+ None,
169
+ gr.update(value=None, visible=False),
170
+ gr.update(value=None, visible=False),
171
+ "Please upload at least one file.",
172
+ )
173
 
174
  if not prompt or prompt.strip() == "":
175
  prompt = DEFAULT_PROMPT
 
177
  guidance_scale = float(guidance_scale)
178
  num_steps = int(num_steps)
179
 
180
+ # Temp for extraction / staging input
181
  temp_dir = tempfile.mkdtemp(prefix="flux_in_")
182
+
183
+ # IMPORTANT:
184
+ # Result files MUST remain on disk for Gradio download.
185
+ # So we create a persistent temp dir and DO NOT delete it in finally.
186
+ run_id = uuid.uuid4().hex[:10]
187
+ output_dir = tempfile.mkdtemp(prefix=f"flux_results_{run_id}_")
188
+
189
+ has_archive = False
190
 
191
  try:
192
  progress(0.0, desc="Preparing files...")
 
197
  file_ext = Path(file_path).suffix.lower()
198
 
199
  if file_ext in [".zip", ".7z"]:
200
+ has_archive = True
201
  progress(0.05, desc=f"Extracting: {Path(file_path).name} ...")
202
  extract_dir = os.path.join(temp_dir, Path(file_path).stem)
203
  os.makedirs(extract_dir, exist_ok=True)
 
212
  all_images.append((file_path, Path(file_path).name, None))
213
 
214
  if not all_images:
215
+ return (
216
+ None,
217
+ gr.update(value=None, visible=False),
218
+ gr.update(value=None, visible=False),
219
+ "No valid images found in uploaded files.",
220
+ )
221
 
222
  total_images = len(all_images)
223
  progress(0.10, desc=f"Found {total_images} images. Loading model...")
224
 
225
+ # Load pipeline (inside GPU runtime on Spaces; local GPU otherwise)
226
  pipe = _get_pipe()
227
 
228
  results = []
 
236
  desc=f"Processing {idx+1}/{total_images}: {Path(img_path).name}",
237
  )
238
 
 
239
  input_image = load_image(img_path)
240
 
 
241
  enhanced_image = pipe(
242
  image=input_image,
243
  prompt=prompt,
 
245
  num_inference_steps=num_steps,
246
  ).images[0]
247
 
 
248
  psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)
249
 
250
+ # Preserve structure if from archive
251
+ output_rel_path = rel_path
 
 
 
252
 
253
  out_path = os.path.join(output_dir, output_rel_path)
254
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
 
269
  "ssim": ssim,
270
  }
271
  )
272
+
273
  metrics_lines.append(
274
  f"{output_rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}"
275
  )
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  avg_psnr = float(np.mean([r["psnr"] for r in results])) if results else 0.0
278
  avg_ssim = float(np.mean([r["ssim"] for r in results])) if results else 0.0
279
 
280
  summary = (
281
  "✅ Processing completed!\n\n"
282
+ f"Environment: {'Hugging Face Spaces' if IS_HF_SPACES else 'Local'}\n"
283
+ f"GPU available: {torch.cuda.is_available()}\n\n"
284
  f"Total images processed: {total_images}\n"
285
  f"Average PSNR: {avg_psnr:.2f} dB\n"
286
  f"Average SSIM: {avg_ssim:.4f}\n\n"
287
  "Individual metrics:\n" + "\n".join(metrics_lines)
288
  )
289
 
290
+ # Build gallery preview (first 10 results -> original+enhanced)
291
  gallery_images = []
292
  for r in results[:10]:
293
  gallery_images.append((r["original"], f"Original: {r['filename']}"))
 
298
  )
299
  )
300
 
301
+ # Decide download output:
302
+ # - If user uploaded any archive -> provide ZIP
303
+ # - Else -> provide enhanced files directly
304
+ if has_archive:
305
+ progress(0.92, desc="Packaging ZIP...")
306
+
307
+ output_zip_path = os.path.join(
308
+ tempfile.gettempdir(), f"enhanced_images_flux_{run_id}.zip"
309
+ )
310
+
311
+ with zipfile.ZipFile(output_zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
312
+ for root, _, fs in os.walk(output_dir):
313
+ for f in fs:
314
+ fp = os.path.join(root, f)
315
+ arcname = os.path.relpath(fp, output_dir)
316
+ zipf.write(fp, arcname)
317
+
318
+ progress(1.0, desc="Done!")
319
+ return (
320
+ gallery_images,
321
+ gr.update(value=None, visible=False), # files hidden
322
+ gr.update(value=output_zip_path, visible=True), # zip shown
323
+ summary,
324
+ )
325
+ else:
326
+ enhanced_paths = [r["output_path"] for r in results]
327
+ progress(1.0, desc="Done!")
328
+ return (
329
+ gallery_images,
330
+ gr.update(value=enhanced_paths, visible=True), # files shown
331
+ gr.update(value=None, visible=False), # zip hidden
332
+ summary,
333
+ )
334
 
335
  except Exception as e:
336
+ return (
337
+ None,
338
+ gr.update(value=None, visible=False),
339
+ gr.update(value=None, visible=False),
340
+ f"Error during processing: {str(e)}",
341
+ )
342
 
343
  finally:
344
+ # Cleanup ONLY input/extraction temp dir.
345
+ # DO NOT delete output_dir because Gradio downloads need the files to remain.
346
  if os.path.exists(temp_dir):
347
  shutil.rmtree(temp_dir, ignore_errors=True)
 
 
348
 
349
 
350
  # =========================
 
352
  # =========================
353
  with gr.Blocks(title="Flux Microscopy Image Enhancement") as demo:
354
  gr.Markdown(
355
+ f"""
356
  # 🔬 Flux Microscopy Image Enhancement
357
 
358
  Upload microscopy images (individual files or compressed archives) for AI-powered enhancement.
 
361
  - Images: JPG, PNG, BMP, TIFF
362
  - Archives: ZIP, 7Z (will process all images inside)
363
 
364
+ **Download behavior:**
365
+ - If you upload **only images** → you can download the enhanced **image files directly** (`*_flux` suffix)
366
+ - If you upload **a ZIP/7Z** → you can download **one ZIP** (images inside use `*_flux` suffix)
367
+
368
+ **Runtime detection:**
369
+ - Detected environment: **{"Hugging Face Spaces" if IS_HF_SPACES else "Local"}**
370
  """
371
  )
372
 
 
407
 
408
  process_btn = gr.Button("🚀 Enhance Images", variant="primary", size="lg")
409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  with gr.Column(scale=2):
411
  gallery_output = gr.Gallery(
412
  label="Results Preview (Original vs Enhanced)",
 
416
  object_fit="contain",
417
  )
418
 
419
+ files_output = gr.Files(
420
+ label="📥 Download Enhanced Images (Files)", visible=False
421
+ )
422
+ zip_output = gr.File(
423
+ label="📥 Download Enhanced Images (ZIP)", visible=False
424
+ )
425
 
426
  summary_output = gr.Textbox(
427
  label="Processing Summary & Metrics",
 
432
  process_btn.click(
433
  fn=process_images,
434
  inputs=[file_input, prompt_input, guidance_scale_input, num_steps_input],
435
+ outputs=[gallery_output, files_output, zip_output, summary_output],
436
  )
437
 
438
  gr.Markdown(
 
443
  - **Inference Steps**: 30 (balanced quality and speed)
444
 
445
  ### Quality Metrics
446
+ - **PSNR** (Peak Signal-to-Noise Ratio): Higher is better
447
+ - **SSIM** (Structural Similarity Index): Closer to 1.0 is better
448
  """
449
  )
450
 
enhance_cli.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flux Microscopy Image Enhancement - Command Line Interface
4
+ Process microscopy images with AI-powered enhancement using argparse for all parameters
5
+ """
6
+
7
+ import argparse
8
+ import torch
9
+ from diffusers import Flux2Pipeline
10
+ from diffusers.utils import load_image
11
+ from pathlib import Path
12
+ from PIL import Image
13
+ import os
14
+ import shutil
15
+ import numpy as np
16
+ from skimage.metrics import peak_signal_noise_ratio, structural_similarity
17
+ from skimage.util import img_as_float
18
+ import sys
19
+ from typing import List, Tuple
20
+
21
+ # =========================
22
+ # Config
23
+ # =========================
24
+ DEFAULT_PROMPT = (
25
+ "enhance microscopy image with subtle improvements, gently increase cellular boundary clarity, "
26
+ "preserve original morphological structure, maintain authentic texture patterns, "
27
+ "minimal noise reduction while keeping fine details intact"
28
+ )
29
+
30
+ GUIDANCE_SCALE = 2.0
31
+ NUM_INFERENCE_STEPS = 30
32
+
33
+ IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
34
+
35
+ MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
36
+ TORCH_DTYPE = torch.bfloat16
37
+
38
+ # =========================
39
+ # Global cached pipeline
40
+ # =========================
41
+ _pipe = None
42
+
43
+
44
+ def calculate_psnr_ssim(original: Image.Image, enhanced: Image.Image):
45
+ """Calculate PSNR and SSIM between original and enhanced images."""
46
+ orig_float = img_as_float(np.array(original))
47
+ enh_float = img_as_float(np.array(enhanced))
48
+
49
+ # Ensure same shape (crop to min overlap)
50
+ if orig_float.shape != enh_float.shape:
51
+ min_h = min(orig_float.shape[0], enh_float.shape[0])
52
+ min_w = min(orig_float.shape[1], enh_float.shape[1])
53
+ orig_float = orig_float[:min_h, :min_w]
54
+ enh_float = enh_float[:min_h, :min_w]
55
+
56
+ psnr = peak_signal_noise_ratio(orig_float, enh_float, data_range=1.0)
57
+
58
+ if orig_float.ndim == 3:
59
+ ssim = structural_similarity(
60
+ orig_float, enh_float, data_range=1.0, channel_axis=-1
61
+ )
62
+ else:
63
+ ssim = structural_similarity(orig_float, enh_float, data_range=1.0)
64
+
65
+ return float(psnr), float(ssim)
66
+
67
+
68
+ def find_images(directory: str) -> List[str]:
69
+ """Recursively find all images in a directory."""
70
+ image_files = []
71
+ for root, _, files in os.walk(directory):
72
+ for f in files:
73
+ if Path(f).suffix.lower() in IMAGE_EXTENSIONS:
74
+ image_files.append(os.path.join(root, f))
75
+ return image_files
76
+
77
+
78
+ def _get_pipe():
79
+ """
80
+ Lazy-load the pipeline on local GPU.
81
+ """
82
+ global _pipe
83
+
84
+ if _pipe is None:
85
+ if not torch.cuda.is_available():
86
+ raise RuntimeError(
87
+ "No GPU found. This tool requires a CUDA-compatible GPU to run."
88
+ )
89
+
90
+ print("Loading Flux model...")
91
+ _pipe = Flux2Pipeline.from_pretrained(
92
+ MODEL_ID,
93
+ torch_dtype=TORCH_DTYPE,
94
+ )
95
+ _pipe.to("cuda")
96
+ print(f"Model loaded successfully on GPU: {torch.cuda.get_device_name(0)}")
97
+
98
+ return _pipe
99
+
100
+
101
+ def process_images_cli(
102
+ input_paths: List[str],
103
+ output_dir: str,
104
+ prompt: str,
105
+ guidance_scale: float,
106
+ num_steps: int,
107
+ verbose: bool = True,
108
+ ) -> Tuple[int, List[dict]]:
109
+ """
110
+ Process images from input paths and save to output directory.
111
+
112
+ Args:
113
+ input_paths: List of file/directory paths (images or folders)
114
+ output_dir: Output directory for enhanced images
115
+ prompt: Enhancement prompt
116
+ guidance_scale: Guidance scale for inference
117
+ num_steps: Number of inference steps
118
+ verbose: Whether to print progress messages
119
+
120
+ Returns:
121
+ Tuple of (total_images_processed, results_list)
122
+ """
123
+ if not input_paths:
124
+ raise ValueError("No input files provided")
125
+
126
+ if not prompt or prompt.strip() == "":
127
+ prompt = DEFAULT_PROMPT
128
+
129
+ # Create output directory
130
+ os.makedirs(output_dir, exist_ok=True)
131
+
132
+ try:
133
+ if verbose:
134
+ print("=" * 60)
135
+ print("Flux Microscopy Image Enhancement - CLI")
136
+ print("=" * 60)
137
+ print(f"Output directory: {output_dir}")
138
+ print(f"Prompt: {prompt}")
139
+ print(f"Guidance scale: {guidance_scale}")
140
+ print(f"Inference steps: {num_steps}")
141
+ print("=" * 60)
142
+
143
+ all_images = [] # list of tuples: (img_path, rel_path, base_dir_for_rel)
144
+
145
+ # Process each input path
146
+ for input_path in input_paths:
147
+ if not os.path.exists(input_path):
148
+ print(f"Warning: Path not found: {input_path}")
149
+ continue
150
+
151
+ # Check if it's a directory
152
+ if os.path.isdir(input_path):
153
+ if verbose:
154
+ print(f"\n[Scanning] Directory: {input_path} ...")
155
+
156
+ images = find_images(input_path)
157
+ for img_path in images:
158
+ rel_path = os.path.relpath(img_path, input_path)
159
+ all_images.append((img_path, rel_path, input_path))
160
+
161
+ if verbose:
162
+ print(f" Found {len(images)} images in directory")
163
+
164
+ # Check if it's an image file
165
+ elif os.path.isfile(input_path):
166
+ file_ext = Path(input_path).suffix.lower()
167
+ if file_ext in IMAGE_EXTENSIONS:
168
+ all_images.append((input_path, Path(input_path).name, None))
169
+ else:
170
+ print(f"Warning: Unsupported file format: {input_path}")
171
+ else:
172
+ print(f"Warning: Invalid path: {input_path}")
173
+
174
+ if not all_images:
175
+ raise ValueError("No valid images found in input files")
176
+
177
+ total_images = len(all_images)
178
+ if verbose:
179
+ print(f"\n[Processing] Total images to enhance: {total_images}")
180
+ print("-" * 60)
181
+
182
+ # Load pipeline
183
+ pipe = _get_pipe()
184
+
185
+ results = []
186
+ metrics_lines = []
187
+
188
+ for idx, (img_path, rel_path, base_dir) in enumerate(all_images, 1):
189
+ if verbose:
190
+ print(f"\n[{idx}/{total_images}] Processing: {Path(img_path).name}")
191
+
192
+ # Load input image
193
+ input_image = load_image(img_path)
194
+
195
+ # Run inference
196
+ enhanced_image = pipe(
197
+ image=input_image,
198
+ prompt=prompt,
199
+ guidance_scale=guidance_scale,
200
+ num_inference_steps=num_steps,
201
+ ).images[0]
202
+
203
+ # Calculate metrics
204
+ psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)
205
+
206
+ if verbose:
207
+ print(f" PSNR: {psnr:.2f} dB | SSIM: {ssim:.4f}")
208
+
209
+ # Determine output path (preserve structure if from directory)
210
+ if base_dir:
211
+ output_rel_path = rel_path
212
+ else:
213
+ output_rel_path = rel_path
214
+
215
+ out_path = os.path.join(output_dir, output_rel_path)
216
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
217
+
218
+ # Add _flux suffix
219
+ out_name = Path(out_path).stem + "_flux" + Path(out_path).suffix
220
+ out_path = os.path.join(os.path.dirname(out_path), out_name)
221
+
222
+ # Save enhanced image
223
+ enhanced_image.save(out_path)
224
+
225
+ if verbose:
226
+ print(f" Saved to: {out_path}")
227
+
228
+ results.append(
229
+ {
230
+ "original_path": img_path,
231
+ "output_path": out_path,
232
+ "filename": output_rel_path,
233
+ "psnr": psnr,
234
+ "ssim": ssim,
235
+ }
236
+ )
237
+
238
+ metrics_lines.append(
239
+ f"{output_rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}"
240
+ )
241
+
242
+ # Print summary
243
+ if verbose:
244
+ avg_psnr = float(np.mean([r["psnr"] for r in results])) if results else 0.0
245
+ avg_ssim = float(np.mean([r["ssim"] for r in results])) if results else 0.0
246
+
247
+ print("\n" + "=" * 60)
248
+ print("SUMMARY")
249
+ print("=" * 60)
250
+ print(f"Total images processed: {total_images}")
251
+ print(f"Average PSNR: {avg_psnr:.2f} dB")
252
+ print(f"Average SSIM: {avg_ssim:.4f}")
253
+ print("\nIndividual metrics:")
254
+ for line in metrics_lines:
255
+ print(f" {line}")
256
+ print("=" * 60)
257
+
258
+ return total_images, results
259
+
260
+ except Exception as e:
261
+ print(f"\nError during processing: {str(e)}", file=sys.stderr)
262
+ raise
263
+
264
+
265
+ def main():
266
+ """Main entry point for CLI."""
267
+ parser = argparse.ArgumentParser(
268
+ description="Flux Microscopy Image Enhancement - Command Line Interface",
269
+ formatter_class=argparse.RawDescriptionHelpFormatter,
270
+ epilog="""
271
+ Examples:
272
+ # Enhance a single image
273
+ python enhance_cli.py -i input.jpg -o output/
274
+
275
+ # Enhance multiple images with custom parameters
276
+ python enhance_cli.py -i image1.jpg image2.png -o output/ --guidance-scale 3.0 --steps 40
277
+
278
+ # Process all images in a directory
279
+ python enhance_cli.py -i images_folder/ -o output/
280
+
281
+ # Process with custom prompt
282
+ python enhance_cli.py -i input.jpg -o output/ --prompt "enhance cellular structure"
283
+
284
+ # Quiet mode (less verbose output)
285
+ python enhance_cli.py -i input.jpg -o output/ --quiet
286
+
287
+ Supported formats:
288
+ - Images: JPG, JPEG, PNG, BMP, TIFF, TIF
289
+ - Directories: Will recursively process all images inside
290
+ """,
291
+ )
292
+
293
+ # Required arguments
294
+ parser.add_argument(
295
+ "-i",
296
+ "--input",
297
+ nargs="+",
298
+ required=True,
299
+ help="Input path(s) - image files or directories. Multiple paths supported.",
300
+ )
301
+
302
+ parser.add_argument(
303
+ "-o", "--output", required=True, help="Output directory for enhanced images"
304
+ )
305
+
306
+ # Optional arguments
307
+ parser.add_argument(
308
+ "-p",
309
+ "--prompt",
310
+ default=DEFAULT_PROMPT,
311
+ help=f"Enhancement prompt (default: '{DEFAULT_PROMPT[:50]}...')",
312
+ )
313
+
314
+ parser.add_argument(
315
+ "-g",
316
+ "--guidance-scale",
317
+ type=float,
318
+ default=GUIDANCE_SCALE,
319
+ help=f"Guidance scale (1.0-5.0, lower=conservative, higher=creative, default: {GUIDANCE_SCALE})",
320
+ )
321
+
322
+ parser.add_argument(
323
+ "-s",
324
+ "--steps",
325
+ type=int,
326
+ default=NUM_INFERENCE_STEPS,
327
+ help=f"Number of inference steps (10-50, more=better quality but slower, default: {NUM_INFERENCE_STEPS})",
328
+ )
329
+
330
+ parser.add_argument(
331
+ "-q",
332
+ "--quiet",
333
+ action="store_true",
334
+ help="Quiet mode - reduce output verbosity",
335
+ )
336
+
337
+ args = parser.parse_args()
338
+
339
+ # Validate arguments
340
+ if args.guidance_scale < 1.0 or args.guidance_scale > 5.0:
341
+ parser.error("guidance-scale must be between 1.0 and 5.0")
342
+
343
+ if args.steps < 10 or args.steps > 50:
344
+ parser.error("steps must be between 10 and 50")
345
+
346
+ # Process images
347
+ try:
348
+ total, results = process_images_cli(
349
+ input_paths=args.input,
350
+ output_dir=args.output,
351
+ prompt=args.prompt,
352
+ guidance_scale=args.guidance_scale,
353
+ num_steps=args.steps,
354
+ verbose=not args.quiet,
355
+ )
356
+
357
+ if not args.quiet:
358
+ print("\n✅ Enhancement completed successfully!")
359
+ print(f"📁 Output directory: {args.output}")
360
+
361
+ sys.exit(0)
362
+
363
+ except Exception as e:
364
+ print(f"\n❌ Error: {str(e)}", file=sys.stderr)
365
+ sys.exit(1)
366
+
367
+
368
+ if __name__ == "__main__":
369
+ main()