yichuan-huang commited on
Commit
889677a
·
1 Parent(s): d7faea4

Implement robust @spaces.GPU handling for ZeroGPU support; streamline HF Spaces detection and improve process_images function clarity

Browse files
Files changed (1) hide show
  1. app.py +60 -74
app.py CHANGED
@@ -18,6 +18,35 @@ from skimage.metrics import peak_signal_noise_ratio, structural_similarity
18
  from skimage.util import img_as_float
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # =========================
22
  # Config
23
  # =========================
@@ -36,29 +65,7 @@ IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
36
  MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
37
  TORCH_DTYPE = torch.bfloat16
38
 
39
-
40
- # =========================
41
- # HF Spaces detection (robust)
42
- # =========================
43
- try:
44
- import spaces # Only available on Hugging Face Spaces
45
-
46
- SPACES_AVAILABLE = True
47
- except Exception:
48
- spaces = None
49
- SPACES_AVAILABLE = False
50
-
51
- # Extra friendly label for UI (not used for logic)
52
- IS_HF_SPACES = (
53
- SPACES_AVAILABLE
54
- or (os.getenv("SYSTEM", "").lower() == "spaces")
55
- or bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"))
56
- )
57
-
58
-
59
- # =========================
60
  # Global cached pipeline
61
- # =========================
62
  _pipe = None
63
 
64
 
@@ -89,7 +96,6 @@ def calculate_psnr_ssim(original: Image.Image, enhanced: Image.Image):
89
  def extract_archive(archive_path: str, extract_to: str):
90
  """Extract zip or 7z archive."""
91
  file_ext = Path(archive_path).suffix.lower()
92
-
93
  if file_ext == ".zip":
94
  with zipfile.ZipFile(archive_path, "r") as z:
95
  z.extractall(extract_to)
@@ -113,8 +119,9 @@ def find_images(directory: str):
113
  def _get_pipe():
114
  """
115
  Lazy-load the pipeline.
116
- - On HF ZeroGPU: this MUST be called inside a @spaces.GPU function runtime.
117
- - Locally: uses local CUDA GPU.
 
118
  """
119
  global _pipe
120
 
@@ -123,7 +130,7 @@ def _get_pipe():
123
  raise RuntimeError(
124
  "CUDA GPU is not available in the current runtime. "
125
  "This bnb-4bit model requires GPU. "
126
- "On HF ZeroGPU, ensure the function is decorated with @spaces.GPU."
127
  )
128
 
129
  _pipe = Flux2Pipeline.from_pretrained(
@@ -139,9 +146,7 @@ def _process_images_impl(
139
  files, prompt, guidance_scale, num_steps, progress=gr.Progress()
140
  ):
141
  """
142
- Shared implementation used by both:
143
- - HF Spaces GPU wrapper
144
- - Local runtime
145
  Returns 4 outputs:
146
  gallery, files_download, zip_download, summary
147
  """
@@ -162,8 +167,7 @@ def _process_images_impl(
162
  # Temp for extraction/staging input
163
  temp_dir = tempfile.mkdtemp(prefix="flux_in_")
164
 
165
- # IMPORTANT:
166
- # Output files MUST remain on disk for Gradio downloads.
167
  run_id = uuid.uuid4().hex[:10]
168
  output_dir = tempfile.mkdtemp(prefix=f"flux_results_{run_id}_")
169
 
@@ -172,7 +176,7 @@ def _process_images_impl(
172
  try:
173
  progress(0.0, desc="Preparing files...")
174
 
175
- all_images = [] # list of tuples: (img_path, rel_path, base_dir_for_rel)
176
  for file_obj in files:
177
  file_path = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
178
  file_ext = Path(file_path).suffix.lower()
@@ -180,6 +184,7 @@ def _process_images_impl(
180
  if file_ext in [".zip", ".7z"]:
181
  has_archive = True
182
  progress(0.05, desc=f"Extracting: {Path(file_path).name} ...")
 
183
  extract_dir = os.path.join(temp_dir, Path(file_path).stem)
184
  os.makedirs(extract_dir, exist_ok=True)
185
  extract_archive(file_path, extract_dir)
@@ -187,10 +192,10 @@ def _process_images_impl(
187
  images = find_images(extract_dir)
188
  for img_path in images:
189
  rel_path = os.path.relpath(img_path, extract_dir)
190
- all_images.append((img_path, rel_path, extract_dir))
191
 
192
  elif file_ext in IMAGE_EXTENSIONS:
193
- all_images.append((file_path, Path(file_path).name, None))
194
 
195
  if not all_images:
196
  return (
@@ -203,14 +208,14 @@ def _process_images_impl(
203
  total_images = len(all_images)
204
  progress(0.10, desc=f"Found {total_images} images. Loading model...")
205
 
206
- pipe = _get_pipe() # IMPORTANT: must be inside GPU runtime on Spaces
207
 
208
  results = []
209
  metrics_lines = []
210
 
211
  progress(0.15, desc="Enhancing images...")
212
 
213
- for idx, (img_path, rel_path, base_dir) in enumerate(all_images):
214
  progress(
215
  0.15 + 0.75 * (idx / max(1, total_images)),
216
  desc=f"Processing {idx+1}/{total_images}: {Path(img_path).name}",
@@ -227,12 +232,10 @@ def _process_images_impl(
227
 
228
  psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)
229
 
230
- # Preserve structure if from archive
231
- output_rel_path = rel_path
232
- out_path = os.path.join(output_dir, output_rel_path)
233
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
234
 
235
- # Add _flux suffix
236
  out_name = Path(out_path).stem + "_flux" + Path(out_path).suffix
237
  out_path = os.path.join(os.path.dirname(out_path), out_name)
238
 
@@ -242,23 +245,21 @@ def _process_images_impl(
242
  {
243
  "original": input_image,
244
  "enhanced": enhanced_image,
245
- "filename": output_rel_path,
246
  "output_path": out_path,
247
  "psnr": psnr,
248
  "ssim": ssim,
249
  }
250
  )
251
 
252
- metrics_lines.append(
253
- f"{output_rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}"
254
- )
255
 
256
  avg_psnr = float(np.mean([r["psnr"] for r in results])) if results else 0.0
257
  avg_ssim = float(np.mean([r["ssim"] for r in results])) if results else 0.0
258
 
259
  summary = (
260
  "✅ Processing completed!\n\n"
261
- f"Environment: {'Hugging Face Spaces' if IS_HF_SPACES else 'Local'}\n"
262
  f"GPU available: {torch.cuda.is_available()}\n\n"
263
  f"Total images processed: {total_images}\n"
264
  f"Average PSNR: {avg_psnr:.2f} dB\n"
@@ -277,16 +278,12 @@ def _process_images_impl(
277
  )
278
  )
279
 
280
- # Download behavior:
281
- # - If any archive uploaded -> zip
282
- # - Else -> direct files list
283
  if has_archive:
284
  progress(0.92, desc="Packaging ZIP...")
285
 
286
  output_zip_path = os.path.join(
287
  tempfile.gettempdir(), f"enhanced_images_flux_{run_id}.zip"
288
  )
289
-
290
  with zipfile.ZipFile(output_zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
291
  for root, _, fs in os.walk(output_dir):
292
  for f in fs:
@@ -297,8 +294,8 @@ def _process_images_impl(
297
  progress(1.0, desc="Done!")
298
  return (
299
  gallery_images,
300
- gr.update(value=None, visible=False), # files hidden
301
- gr.update(value=output_zip_path, visible=True), # zip shown
302
  summary,
303
  )
304
  else:
@@ -306,8 +303,8 @@ def _process_images_impl(
306
  progress(1.0, desc="Done!")
307
  return (
308
  gallery_images,
309
- gr.update(value=enhanced_paths, visible=True), # files shown
310
- gr.update(value=None, visible=False), # zip hidden
311
  summary,
312
  )
313
 
@@ -320,30 +317,19 @@ def _process_images_impl(
320
  )
321
 
322
  finally:
323
- # Clean ONLY input extraction temp.
324
  if os.path.exists(temp_dir):
325
  shutil.rmtree(temp_dir, ignore_errors=True)
326
- # Do NOT delete output_dir; downloads need it to exist.
327
 
328
 
329
  # =========================
330
- # IMPORTANT: Define a real @spaces.GPU function at import-time on Spaces
331
- # (This fixes: "No @spaces.GPU function detected during startup")
332
  # =========================
333
- if SPACES_AVAILABLE:
334
-
335
- @spaces.GPU(duration=180)
336
- def process_images(
337
- files, prompt, guidance_scale, num_steps, progress=gr.Progress()
338
- ):
339
- return _process_images_impl(files, prompt, guidance_scale, num_steps, progress)
340
-
341
- else:
342
-
343
- def process_images(
344
- files, prompt, guidance_scale, num_steps, progress=gr.Progress()
345
- ):
346
- return _process_images_impl(files, prompt, guidance_scale, num_steps, progress)
347
 
348
 
349
  # =========================
@@ -364,9 +350,9 @@ Upload microscopy images (individual files or compressed archives) for AI-powere
364
  - Upload **only images** → download enhanced **image files directly** (`*_flux` suffix)
365
  - Upload **ZIP/7Z** → download **one ZIP** (images inside use `*_flux` suffix)
366
 
367
- **Runtime detection:**
368
- - Detected environment: **{"Hugging Face Spaces" if IS_HF_SPACES else "Local"}**
369
- - spaces module available: **{SPACES_AVAILABLE}**
370
  """
371
  )
372
 
 
18
  from skimage.util import img_as_float
19
 
20
 
21
+ # =========================
22
+ # Make @spaces.GPU ALWAYS exist (critical for ZeroGPU startup scan)
23
+ # =========================
24
+ try:
25
+ import spaces # Hugging Face Spaces provides this
26
+ except Exception:
27
+
28
+ class _DummySpaces:
29
+ @staticmethod
30
+ def GPU(duration: int = 180, **kwargs):
31
+ # No-op decorator for local runtime
32
+ def _decorator(fn):
33
+ return fn
34
+
35
+ return _decorator
36
+
37
+ spaces = _DummySpaces()
38
+
39
+
40
+ def _is_hf_spaces_env() -> bool:
41
+ # Best-effort env detection (informational only)
42
+ return (os.getenv("SYSTEM", "").lower() == "spaces") or bool(
43
+ os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID")
44
+ )
45
+
46
+
47
+ IS_HF_SPACES = _is_hf_spaces_env()
48
+
49
+
50
  # =========================
51
  # Config
52
  # =========================
 
65
  MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
66
  TORCH_DTYPE = torch.bfloat16
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # Global cached pipeline
 
69
  _pipe = None
70
 
71
 
 
96
  def extract_archive(archive_path: str, extract_to: str):
97
  """Extract zip or 7z archive."""
98
  file_ext = Path(archive_path).suffix.lower()
 
99
  if file_ext == ".zip":
100
  with zipfile.ZipFile(archive_path, "r") as z:
101
  z.extractall(extract_to)
 
119
  def _get_pipe():
120
  """
121
  Lazy-load the pipeline.
122
+ IMPORTANT:
123
+ - On HF ZeroGPU: must be called inside a @spaces.GPU-decorated function runtime.
124
+ - Locally: uses local CUDA.
125
  """
126
  global _pipe
127
 
 
130
  raise RuntimeError(
131
  "CUDA GPU is not available in the current runtime. "
132
  "This bnb-4bit model requires GPU. "
133
+ "On HF ZeroGPU, ensure inference is inside a @spaces.GPU function."
134
  )
135
 
136
  _pipe = Flux2Pipeline.from_pretrained(
 
146
  files, prompt, guidance_scale, num_steps, progress=gr.Progress()
147
  ):
148
  """
149
+ Shared implementation.
 
 
150
  Returns 4 outputs:
151
  gallery, files_download, zip_download, summary
152
  """
 
167
  # Temp for extraction/staging input
168
  temp_dir = tempfile.mkdtemp(prefix="flux_in_")
169
 
170
+ # Output MUST remain for Gradio downloads
 
171
  run_id = uuid.uuid4().hex[:10]
172
  output_dir = tempfile.mkdtemp(prefix=f"flux_results_{run_id}_")
173
 
 
176
  try:
177
  progress(0.0, desc="Preparing files...")
178
 
179
+ all_images = [] # (img_path, rel_path)
180
  for file_obj in files:
181
  file_path = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
182
  file_ext = Path(file_path).suffix.lower()
 
184
  if file_ext in [".zip", ".7z"]:
185
  has_archive = True
186
  progress(0.05, desc=f"Extracting: {Path(file_path).name} ...")
187
+
188
  extract_dir = os.path.join(temp_dir, Path(file_path).stem)
189
  os.makedirs(extract_dir, exist_ok=True)
190
  extract_archive(file_path, extract_dir)
 
192
  images = find_images(extract_dir)
193
  for img_path in images:
194
  rel_path = os.path.relpath(img_path, extract_dir)
195
+ all_images.append((img_path, rel_path))
196
 
197
  elif file_ext in IMAGE_EXTENSIONS:
198
+ all_images.append((file_path, Path(file_path).name))
199
 
200
  if not all_images:
201
  return (
 
208
  total_images = len(all_images)
209
  progress(0.10, desc=f"Found {total_images} images. Loading model...")
210
 
211
+ pipe = _get_pipe()
212
 
213
  results = []
214
  metrics_lines = []
215
 
216
  progress(0.15, desc="Enhancing images...")
217
 
218
+ for idx, (img_path, rel_path) in enumerate(all_images):
219
  progress(
220
  0.15 + 0.75 * (idx / max(1, total_images)),
221
  desc=f"Processing {idx+1}/{total_images}: {Path(img_path).name}",
 
232
 
233
  psnr, ssim = calculate_psnr_ssim(input_image, enhanced_image)
234
 
235
+ out_path = os.path.join(output_dir, rel_path)
 
 
236
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
237
 
238
+ # add _flux suffix
239
  out_name = Path(out_path).stem + "_flux" + Path(out_path).suffix
240
  out_path = os.path.join(os.path.dirname(out_path), out_name)
241
 
 
245
  {
246
  "original": input_image,
247
  "enhanced": enhanced_image,
248
+ "filename": rel_path,
249
  "output_path": out_path,
250
  "psnr": psnr,
251
  "ssim": ssim,
252
  }
253
  )
254
 
255
+ metrics_lines.append(f"{rel_path}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}")
 
 
256
 
257
  avg_psnr = float(np.mean([r["psnr"] for r in results])) if results else 0.0
258
  avg_ssim = float(np.mean([r["ssim"] for r in results])) if results else 0.0
259
 
260
  summary = (
261
  "✅ Processing completed!\n\n"
262
+ f"Environment: {'Hugging Face Spaces' if IS_HF_SPACES else 'Local/Unknown'}\n"
263
  f"GPU available: {torch.cuda.is_available()}\n\n"
264
  f"Total images processed: {total_images}\n"
265
  f"Average PSNR: {avg_psnr:.2f} dB\n"
 
278
  )
279
  )
280
 
 
 
 
281
  if has_archive:
282
  progress(0.92, desc="Packaging ZIP...")
283
 
284
  output_zip_path = os.path.join(
285
  tempfile.gettempdir(), f"enhanced_images_flux_{run_id}.zip"
286
  )
 
287
  with zipfile.ZipFile(output_zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
288
  for root, _, fs in os.walk(output_dir):
289
  for f in fs:
 
294
  progress(1.0, desc="Done!")
295
  return (
296
  gallery_images,
297
+ gr.update(value=None, visible=False),
298
+ gr.update(value=output_zip_path, visible=True),
299
  summary,
300
  )
301
  else:
 
303
  progress(1.0, desc="Done!")
304
  return (
305
  gallery_images,
306
+ gr.update(value=enhanced_paths, visible=True),
307
+ gr.update(value=None, visible=False),
308
  summary,
309
  )
310
 
 
317
  )
318
 
319
  finally:
320
+ # Cleanup input temp only
321
  if os.path.exists(temp_dir):
322
  shutil.rmtree(temp_dir, ignore_errors=True)
323
+ # DO NOT delete output_dir; needed for downloads
324
 
325
 
326
  # =========================
327
+ # CRITICAL: Always define a @spaces.GPU function at top-level
328
+ # (ZeroGPU startup scanner will now ALWAYS find it)
329
  # =========================
330
+ @spaces.GPU(duration=180)
331
+ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progress()):
332
+ return _process_images_impl(files, prompt, guidance_scale, num_steps, progress)
 
 
 
 
 
 
 
 
 
 
 
333
 
334
 
335
  # =========================
 
350
  - Upload **only images** → download enhanced **image files directly** (`*_flux` suffix)
351
  - Upload **ZIP/7Z** → download **one ZIP** (images inside use `*_flux` suffix)
352
 
353
+ **Runtime info (informational):**
354
+ - Detected environment: **{"Hugging Face Spaces" if IS_HF_SPACES else "Local/Unknown"}**
355
+ - CUDA visible at startup: **{torch.cuda.is_available()}**
356
  """
357
  )
358