yichuan-huang commited on
Commit
386592b
·
1 Parent(s): 8a7a521

Refactor HF Spaces detection and GPU handling; streamline process_images function for better clarity and maintainability

Browse files
Files changed (1) hide show
  1. app.py +60 -60
app.py CHANGED
@@ -32,49 +32,28 @@ NUM_INFERENCE_STEPS = 30
32
 
33
  IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
34
 
35
- # NOTE:
36
- # - This is the quantized 4-bit (bitsandbytes) model, which REQUIRES GPU at load time.
37
- # - On HF Spaces ZeroGPU, you MUST only load it inside a @spaces.GPU function.
38
  MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
39
  TORCH_DTYPE = torch.bfloat16
40
 
41
 
42
  # =========================
43
- # HF Spaces env detection + safe "spaces" decorator
44
  # =========================
45
- def _is_hf_space_env() -> bool:
46
- """
47
- Detect whether we're running on Hugging Face Spaces runtime.
48
- Common env vars present in Spaces:
49
- - SPACE_ID
50
- - HF_SPACE_ID
51
- - SYSTEM=spaces
52
- """
53
- return any(os.getenv(k) for k in ("SPACE_ID", "HF_SPACE_ID")) or (
54
- os.getenv("SYSTEM", "").lower() == "spaces"
55
- )
56
-
57
-
58
- IS_HF_SPACES = _is_hf_space_env()
59
-
60
  try:
61
- import spaces # available on HF Spaces
 
 
62
  except Exception:
63
  spaces = None
 
64
 
65
-
66
- def gpu_decorator(duration: int = 180):
67
- """
68
- If on HF Spaces and spaces.GPU exists -> use it.
69
- Else -> no-op decorator (local runs normally, using local GPU if available).
70
- """
71
- if IS_HF_SPACES and (spaces is not None) and hasattr(spaces, "GPU"):
72
- return spaces.GPU(duration=duration)
73
-
74
- def _noop(fn):
75
- return fn
76
-
77
- return _noop
78
 
79
 
80
  # =========================
@@ -134,15 +113,17 @@ def find_images(directory: str):
134
  def _get_pipe():
135
  """
136
  Lazy-load the pipeline.
137
- - On HF Spaces ZeroGPU: MUST be called inside a @spaces.GPU runtime (gpu_decorator handles this).
138
- - Locally: will use local GPU.
139
  """
140
  global _pipe
141
 
142
  if _pipe is None:
143
  if not torch.cuda.is_available():
144
  raise RuntimeError(
145
- "No CUDA GPU detected. This 4-bit bnb model requires GPU to load/run."
 
 
146
  )
147
 
148
  _pipe = Flux2Pipeline.from_pretrained(
@@ -154,14 +135,15 @@ def _get_pipe():
154
  return _pipe
155
 
156
 
157
- @gpu_decorator(duration=180)
158
- def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progress()):
 
159
  """
160
- Process uploaded files (images or archives) and return:
161
- - gallery preview (first 10 pairs)
162
- - files download (when only images uploaded)
163
- - zip download (when archive uploaded)
164
- - summary text
165
  """
166
  if not files:
167
  return (
@@ -177,12 +159,11 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
177
  guidance_scale = float(guidance_scale)
178
  num_steps = int(num_steps)
179
 
180
- # Temp for extraction / staging input
181
  temp_dir = tempfile.mkdtemp(prefix="flux_in_")
182
 
183
  # IMPORTANT:
184
- # Result files MUST remain on disk for Gradio download.
185
- # So we create a persistent temp dir and DO NOT delete it in finally.
186
  run_id = uuid.uuid4().hex[:10]
187
  output_dir = tempfile.mkdtemp(prefix=f"flux_results_{run_id}_")
188
 
@@ -222,8 +203,7 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
222
  total_images = len(all_images)
223
  progress(0.10, desc=f"Found {total_images} images. Loading model...")
224
 
225
- # Load pipeline (inside GPU runtime on Spaces; local GPU otherwise)
226
- pipe = _get_pipe()
227
 
228
  results = []
229
  metrics_lines = []
@@ -249,7 +229,6 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
249
 
250
  # Preserve structure if from archive
251
  output_rel_path = rel_path
252
-
253
  out_path = os.path.join(output_dir, output_rel_path)
254
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
255
 
@@ -287,7 +266,7 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
287
  "Individual metrics:\n" + "\n".join(metrics_lines)
288
  )
289
 
290
- # Build gallery preview (first 10 results -> original+enhanced)
291
  gallery_images = []
292
  for r in results[:10]:
293
  gallery_images.append((r["original"], f"Original: {r['filename']}"))
@@ -298,9 +277,9 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
298
  )
299
  )
300
 
301
- # Decide download output:
302
- # - If user uploaded any archive -> provide ZIP
303
- # - Else -> provide enhanced files directly
304
  if has_archive:
305
  progress(0.92, desc="Packaging ZIP...")
306
 
@@ -341,10 +320,30 @@ def process_images(files, prompt, guidance_scale, num_steps, progress=gr.Progres
341
  )
342
 
343
  finally:
344
- # Cleanup ONLY input/extraction temp dir.
345
- # DO NOT delete output_dir because Gradio downloads need the files to remain.
346
  if os.path.exists(temp_dir):
347
  shutil.rmtree(temp_dir, ignore_errors=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
 
350
  # =========================
@@ -362,11 +361,12 @@ Upload microscopy images (individual files or compressed archives) for AI-powere
362
  - Archives: ZIP, 7Z (will process all images inside)
363
 
364
  **Download behavior:**
365
- - If you upload **only images** → you can download the enhanced **image files directly** (`*_flux` suffix)
366
- - If you upload **a ZIP/7Z** → you can download **one ZIP** (images inside use `*_flux` suffix)
367
 
368
  **Runtime detection:**
369
  - Detected environment: **{"Hugging Face Spaces" if IS_HF_SPACES else "Local"}**
 
370
  """
371
  )
372
 
@@ -443,11 +443,11 @@ Upload microscopy images (individual files or compressed archives) for AI-powere
443
  - **Inference Steps**: 30 (balanced quality and speed)
444
 
445
  ### Quality Metrics
446
- - **PSNR** (Peak Signal-to-Noise Ratio): Higher is better
447
- - **SSIM** (Structural Similarity Index): Closer to 1.0 is better
448
  """
449
  )
450
 
451
  if __name__ == "__main__":
452
- demo.queue() # recommended for Spaces
453
  demo.launch(share=False)
 
32
 
33
  IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
34
 
35
+ # Quantized 4-bit model (requires GPU)
 
 
36
  MODEL_ID = "diffusers/FLUX.2-dev-bnb-4bit"
37
  TORCH_DTYPE = torch.bfloat16
38
 
39
 
40
  # =========================
41
+ # HF Spaces detection (robust)
42
  # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  try:
44
+ import spaces # Only available on Hugging Face Spaces
45
+
46
+ SPACES_AVAILABLE = True
47
  except Exception:
48
  spaces = None
49
+ SPACES_AVAILABLE = False
50
 
51
+ # Extra friendly label for UI (not used for logic)
52
+ IS_HF_SPACES = (
53
+ SPACES_AVAILABLE
54
+ or (os.getenv("SYSTEM", "").lower() == "spaces")
55
+ or bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"))
56
+ )
 
 
 
 
 
 
 
57
 
58
 
59
  # =========================
 
113
  def _get_pipe():
114
  """
115
  Lazy-load the pipeline.
116
+ - On HF ZeroGPU: this MUST be called inside a @spaces.GPU function runtime.
117
+ - Locally: uses local CUDA GPU.
118
  """
119
  global _pipe
120
 
121
  if _pipe is None:
122
  if not torch.cuda.is_available():
123
  raise RuntimeError(
124
+ "CUDA GPU is not available in the current runtime. "
125
+ "This bnb-4bit model requires GPU. "
126
+ "On HF ZeroGPU, ensure the function is decorated with @spaces.GPU."
127
  )
128
 
129
  _pipe = Flux2Pipeline.from_pretrained(
 
135
  return _pipe
136
 
137
 
138
+ def _process_images_impl(
139
+ files, prompt, guidance_scale, num_steps, progress=gr.Progress()
140
+ ):
141
  """
142
+ Shared implementation used by both:
143
+ - HF Spaces GPU wrapper
144
+ - Local runtime
145
+ Returns 4 outputs:
146
+ gallery, files_download, zip_download, summary
147
  """
148
  if not files:
149
  return (
 
159
  guidance_scale = float(guidance_scale)
160
  num_steps = int(num_steps)
161
 
162
+ # Temp for extraction/staging input
163
  temp_dir = tempfile.mkdtemp(prefix="flux_in_")
164
 
165
  # IMPORTANT:
166
+ # Output files MUST remain on disk for Gradio downloads.
 
167
  run_id = uuid.uuid4().hex[:10]
168
  output_dir = tempfile.mkdtemp(prefix=f"flux_results_{run_id}_")
169
 
 
203
  total_images = len(all_images)
204
  progress(0.10, desc=f"Found {total_images} images. Loading model...")
205
 
206
+ pipe = _get_pipe() # IMPORTANT: must be inside GPU runtime on Spaces
 
207
 
208
  results = []
209
  metrics_lines = []
 
229
 
230
  # Preserve structure if from archive
231
  output_rel_path = rel_path
 
232
  out_path = os.path.join(output_dir, output_rel_path)
233
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
234
 
 
266
  "Individual metrics:\n" + "\n".join(metrics_lines)
267
  )
268
 
269
+ # Gallery preview
270
  gallery_images = []
271
  for r in results[:10]:
272
  gallery_images.append((r["original"], f"Original: {r['filename']}"))
 
277
  )
278
  )
279
 
280
+ # Download behavior:
281
+ # - If any archive uploaded -> zip
282
+ # - Else -> direct files list
283
  if has_archive:
284
  progress(0.92, desc="Packaging ZIP...")
285
 
 
320
  )
321
 
322
  finally:
323
+ # Clean ONLY input extraction temp.
 
324
  if os.path.exists(temp_dir):
325
  shutil.rmtree(temp_dir, ignore_errors=True)
326
+ # Do NOT delete output_dir; downloads need it to exist.
327
+
328
+
329
+ # =========================
330
+ # IMPORTANT: Define a real @spaces.GPU function at import-time on Spaces
331
+ # (This fixes: "No @spaces.GPU function detected during startup")
332
+ # =========================
333
+ if SPACES_AVAILABLE:
334
+
335
+ @spaces.GPU(duration=180)
336
+ def process_images(
337
+ files, prompt, guidance_scale, num_steps, progress=gr.Progress()
338
+ ):
339
+ return _process_images_impl(files, prompt, guidance_scale, num_steps, progress)
340
+
341
+ else:
342
+
343
+ def process_images(
344
+ files, prompt, guidance_scale, num_steps, progress=gr.Progress()
345
+ ):
346
+ return _process_images_impl(files, prompt, guidance_scale, num_steps, progress)
347
 
348
 
349
  # =========================
 
361
  - Archives: ZIP, 7Z (will process all images inside)
362
 
363
  **Download behavior:**
364
+ - Upload **only images** → download enhanced **image files directly** (`*_flux` suffix)
365
+ - Upload **ZIP/7Z** → download **one ZIP** (images inside use `*_flux` suffix)
366
 
367
  **Runtime detection:**
368
  - Detected environment: **{"Hugging Face Spaces" if IS_HF_SPACES else "Local"}**
369
+ - spaces module available: **{SPACES_AVAILABLE}**
370
  """
371
  )
372
 
 
443
  - **Inference Steps**: 30 (balanced quality and speed)
444
 
445
  ### Quality Metrics
446
+ - **PSNR**: Higher is better
447
+ - **SSIM**: Closer to 1.0 is better
448
  """
449
  )
450
 
451
  if __name__ == "__main__":
452
+ demo.queue()
453
  demo.launch(share=False)