daKhosa committed on
Commit
2fc70fd
·
1 Parent(s): 3901f61

Replace exec stub with IC-Light app

Browse files
Files changed (4) hide show
  1. README.md +28 -1
  2. app.py +497 -1
  3. briarmbg.py +462 -0
  4. requirements.txt +3 -1
README.md CHANGED
@@ -8,6 +8,33 @@ sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: false
10
  license: other
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  app_file: app.py
9
  pinned: false
10
  license: other
11
+ suggested_hardware: zero-a10g
12
  ---
13
 
14
+ # IC-Light Relighting
15
+
16
+ This Space replaces the previous `EXEC` environment-variable stub with a real Gradio app and a ZeroGPU-decorated inference function.
17
+
18
+ The implementation is based on the public `lllyasviel/ic-light` SD1.5 foreground-conditioned IC-Light model:
19
+
20
+ - code reference: https://github.com/lllyasviel/IC-Light
21
+ - weights: https://huggingface.co/lllyasviel/ic-light
22
+ - model file used here: `iclight_sd15_fc.safetensors`
23
+
24
+ ## V2-Vary provenance
25
+
26
+ `lllyasviel` announced IC-Light V2-Vary in GitHub discussion #109 as an alternative model for stronger illumination variations:
27
+
28
+ https://github.com/lllyasviel/IC-Light/discussions/109
29
+
30
+ The linked official Space is:
31
+
32
+ https://huggingface.co/spaces/lllyasviel/iclight-v2-vary
33
+
34
+ As of this repo update, that Space's public git tree contains only:
35
+
36
+ ```python
37
+ import os; exec(os.getenv('EXEC'))
38
+ ```
39
+
40
+ The public `lllyasviel/ic-light` model repository only exposes the SD1.5 IC-Light weights, not Flux/V2-Vary weights. Because the V2-Vary app source and weights are not public in the Space git tree or the public model repo, this Space uses the available upstream IC-Light integration rather than preserving the unsafe hidden-exec stub.
app.py CHANGED
@@ -1 +1,497 @@
1
- import os; exec(os.getenv('EXEC'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ from enum import Enum
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import safetensors.torch as sf
8
+ import torch
9
+ from diffusers import (
10
+ AutoencoderKL,
11
+ DPMSolverMultistepScheduler,
12
+ StableDiffusionImg2ImgPipeline,
13
+ StableDiffusionPipeline,
14
+ UNet2DConditionModel,
15
+ )
16
+ from diffusers.models.attention_processor import AttnProcessor2_0
17
+ from huggingface_hub import hf_hub_download
18
+ from PIL import Image
19
+ from transformers import CLIPTextModel, CLIPTokenizer
20
+
21
+ from briarmbg import BriaRMBG
22
+
23
try:
    import spaces
except ImportError:
    class spaces:
        """Minimal stand-in used when the ``spaces`` package is not installed.

        Only the ``GPU`` decorator is provided, and it is a no-op: the
        decorated function is returned unchanged.
        """

        @staticmethod
        def GPU(duration=120):
            """Return the decorated function untouched.

            Works both as ``@spaces.GPU(duration=...)`` and as a bare
            ``@spaces.GPU``. In the bare form the function itself arrives in
            the ``duration`` slot, so it is handed straight back.
            """
            if callable(duration):
                return duration

            def decorator(fn):
                return fn

            return decorator


# Hub repository the tokenizer, text encoder, VAE and UNet are loaded from.
BASE_MODEL = "stablediffusionapi/realistic-vision-v51"
# Hub repository and file name of the IC-Light weights merged into the UNet.
ICLIGHT_REPO = "lllyasviel/ic-light"
MODEL_FILE = "iclight_sd15_fc.safetensors"
# Fixed negative prompt and the suffix appended to every user prompt.
NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, cropped, worst quality"
ADDED_PROMPT = "best quality"

# Lazily created process-wide engine instance (see get_engine).
_ENGINE = None
42
+
43
+
44
class BGSource(Enum):
    """Lighting preference choices.

    Each member's value is the label shown in the UI; the same string is
    used to look the member up again with ``BGSource(label)``.
    """

    # No initial background.
    NONE = "None"

    # Horizontal ramps.
    LEFT = "Left Light"
    RIGHT = "Right Light"

    # Vertical ramps.
    TOP = "Top Light"
    BOTTOM = "Bottom Light"
50
+
51
+
52
def ensure_rgb(image):
    """Return ``image`` as a three-channel ``uint8`` NumPy array.

    Accepts a PIL image or a NumPy array. Two-dimensional arrays are
    replicated across three channels and four-channel arrays are converted
    to RGB through PIL.

    Raises:
        gr.Error: if ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Upload an image first.")

    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))

    array = image
    if array.ndim == 2:
        # Grayscale: repeat the single plane for R, G and B.
        array = np.stack([array] * 3, axis=-1)

    if array.shape[-1] == 4:
        as_rgb = Image.fromarray(array).convert("RGB")
        array = np.array(as_rgb)

    rgb = array[:, :, :3]
    return rgb.astype(np.uint8)
66
+
67
+
68
def resize_and_center_crop(image, target_width, target_height):
    """Scale ``image`` to cover the target size, then crop it around the centre.

    The scale factor is the larger of the two axis ratios, so the resized
    image is at least as large as the target in both directions before the
    centred crop box is applied.

    Returns:
        The cropped image as a NumPy array.
    """
    source = Image.fromarray(image)
    src_w, src_h = source.size

    scale = max(target_width / src_w, target_height / src_h)
    new_w = int(round(src_w * scale))
    new_h = int(round(src_h * scale))
    scaled = source.resize((new_w, new_h), Image.LANCZOS)

    # (left, top, right, bottom) of a target-sized box centred in the image.
    box = (
        (new_w - target_width) / 2,
        (new_h - target_height) / 2,
        (new_w + target_width) / 2,
        (new_h + target_height) / 2,
    )
    return np.array(scaled.crop(box))
80
+
81
+
82
def resize_without_crop(image, target_width, target_height):
    """Resize ``image`` to exactly the target size with Lanczos resampling.

    The aspect ratio is not preserved. Returns a NumPy array.
    """
    size = (target_width, target_height)
    resized = Image.fromarray(image).resize(size, Image.LANCZOS)
    return np.array(resized)
84
+
85
+
86
def numpy2pytorch(imgs):
    """Stack image arrays into one float tensor with the last axis moved to position 1.

    Values are divided by 127.0 and shifted by -1.0, so an input value of
    127 maps to exactly 0.0.
    """
    batch = np.stack(imgs, axis=0)
    tensor = torch.from_numpy(batch).float()
    tensor = tensor / 127.0 - 1.0
    return tensor.movedim(-1, 1)
89
+
90
+
91
def pytorch2numpy(imgs):
    """Convert each tensor in ``imgs`` into a ``uint8`` array with axis 0 moved last.

    Values are mapped with ``x * 127.5 + 127.5`` and clipped to [0, 255]
    before the cast.
    """

    def _to_uint8(tensor):
        channels_last = tensor.movedim(0, -1)
        scaled = channels_last * 127.5 + 127.5
        array = scaled.detach().float().cpu().numpy()
        return array.clip(0, 255).astype(np.uint8)

    return [_to_uint8(x) for x in imgs]
99
+
100
+
101
class ICLightEngine:
    """Hold the models and diffusion pipelines used for IC-Light relighting.

    Construction requires CUDA. It loads the tokenizer, text encoder, VAE and
    UNet from ``BASE_MODEL`` plus the RMBG matting network, widens the UNet
    input to 8 channels, adds the IC-Light offset weights to the UNet, moves
    the modules to the GPU and builds text-to-image and image-to-image
    pipelines that share those modules.
    """

    def __init__(self):
        """Load every model and prepare the pipelines.

        Raises:
            gr.Error: if no CUDA device is available.
        """
        if not torch.cuda.is_available():
            raise gr.Error("IC-Light inference requires a CUDA GPU. On Hugging Face, enable ZeroGPU hardware.")

        self.device = torch.device("cuda")
        self.tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder")
        self.vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet")
        self.rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")

        # The UNet input is widened before the offsets are merged
        # (presumably the offset file stores an 8-channel conv_in -- confirm).
        self._patch_unet_input()
        self._load_iclight_weights()
        self._move_to_gpu()
        self._build_pipelines()

    def _patch_unet_input(self):
        """Widen the UNet's first convolution to 8 input channels and hook ``forward``.

        The first 4 input channels keep the original weights; the extra 4 are
        zero-initialised. The hooked ``forward`` reads ``concat_conds`` from
        ``cross_attention_kwargs``, repeats it to the sample batch size and
        concatenates it to the sample along the channel axis.
        """
        with torch.no_grad():
            old_conv_in = self.unet.conv_in
            new_conv_in = torch.nn.Conv2d(
                8,
                old_conv_in.out_channels,
                old_conv_in.kernel_size,
                old_conv_in.stride,
                old_conv_in.padding,
            )
            new_conv_in.weight.zero_()
            new_conv_in.weight[:, :4, :, :].copy_(old_conv_in.weight)
            new_conv_in.bias = old_conv_in.bias
            self.unet.conv_in = new_conv_in

        unet_original_forward = self.unet.forward

        def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
            c_concat = kwargs["cross_attention_kwargs"]["concat_conds"].to(sample)
            # The sample batch may be a multiple of the condition batch
            # (e.g. several images per prompt); tile the condition to match.
            c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
            new_sample = torch.cat([sample, c_concat], dim=1)
            # The condition was smuggled through cross_attention_kwargs; clear
            # it so it is not forwarded to the original forward.
            kwargs["cross_attention_kwargs"] = {}
            return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)

        self.unet.forward = hooked_unet_forward

    def _load_iclight_weights(self):
        """Download the IC-Light file and add its tensors to the UNet weights.

        Every key of the current UNet state dict must exist in the downloaded
        file; the file's tensors are treated as offsets and summed with the
        existing weights.
        """
        model_path = hf_hub_download(ICLIGHT_REPO, MODEL_FILE)
        sd_offset = sf.load_file(model_path, device="cpu")
        sd_origin = self.unet.state_dict()
        sd_merged = {
            key: sd_origin[key] + sd_offset[key].to(dtype=sd_origin[key].dtype)
            for key in sd_origin.keys()
        }
        self.unet.load_state_dict(sd_merged, strict=True)
        del sd_offset, sd_origin, sd_merged

    def _move_to_gpu(self):
        """Move all modules to the GPU and select the attention processor.

        The text encoder and UNet run in float16, the VAE in bfloat16 and the
        matting network in float32.
        """
        self.text_encoder = self.text_encoder.to(device=self.device, dtype=torch.float16)
        self.vae = self.vae.to(device=self.device, dtype=torch.bfloat16)
        self.unet = self.unet.to(device=self.device, dtype=torch.float16)
        self.rmbg = self.rmbg.to(device=self.device, dtype=torch.float32)
        self.unet.set_attn_processor(AttnProcessor2_0())
        self.vae.set_attn_processor(AttnProcessor2_0())

    def _build_pipelines(self):
        """Create the text-to-image and image-to-image pipelines.

        Both pipelines share the same modules and scheduler and have the
        safety checker disabled.
        """
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            algorithm_type="sde-dpmsolver++",
            use_karras_sigmas=True,
            steps_offset=1,
        )
        pipe_kwargs = dict(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            scheduler=scheduler,
            safety_checker=None,
            requires_safety_checker=False,
            feature_extractor=None,
            image_encoder=None,
        )
        self.t2i_pipe = StableDiffusionPipeline(**pipe_kwargs)
        self.i2i_pipe = StableDiffusionImg2ImgPipeline(**pipe_kwargs)

    @torch.inference_mode()
    def _encode_images(self, images):
        """Encode a list of image arrays into scaled VAE latents (distribution mode)."""
        pixels = numpy2pytorch(images).to(device=self.vae.device, dtype=self.vae.dtype)
        return self.vae.encode(pixels).latent_dist.mode() * self.vae.config.scaling_factor

    @torch.inference_mode()
    def _sample_latents(self, pipe, **call_kwargs):
        """Run ``pipe`` with latent output and undo the VAE scaling for decoding."""
        latents = pipe(output_type="latent", **call_kwargs).images
        return latents.to(self.vae.dtype) / self.vae.config.scaling_factor

    @torch.inference_mode()
    def encode_prompt_inner(self, txt):
        """Encode ``txt`` with the text encoder, one row per token chunk.

        The prompt is tokenised without truncation and split into chunks of
        ``model_max_length - 2`` tokens. Each chunk is wrapped in BOS/EOS and
        padded with EOS up to ``model_max_length``. A prompt that yields no
        tokens is still encoded as a single BOS/EOS chunk.

        Returns:
            The text encoder's ``last_hidden_state`` for the chunk batch.
        """
        max_length = self.tokenizer.model_max_length
        chunk_length = self.tokenizer.model_max_length - 2
        id_start = self.tokenizer.bos_token_id
        id_end = self.tokenizer.eos_token_id
        id_pad = id_end

        def pad(x, p, i):
            return x[:i] if len(x) >= i else x + [p] * (i - len(x))

        tokens = self.tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
        # max(..., 1) guarantees at least one chunk, so an empty token list
        # does not turn into an empty tensor.
        chunks = [
            [id_start] + tokens[i: i + chunk_length] + [id_end]
            for i in range(0, max(len(tokens), 1), chunk_length)
        ]
        chunks = [pad(chunk, id_pad, max_length) for chunk in chunks]

        token_ids = torch.tensor(chunks).to(device=self.device, dtype=torch.int64)
        return self.text_encoder(token_ids).last_hidden_state

    @torch.inference_mode()
    def encode_prompt_pair(self, positive_prompt, negative_prompt):
        """Encode both prompts and make their embeddings the same length.

        The shorter prompt's chunks are repeated (and truncated) until both
        prompts have the same number of chunks. The chunks of each prompt are
        then concatenated along the sequence axis into a single batch entry.

        Returns:
            ``(positive_embeds, negative_embeds)``.
        """
        c = self.encode_prompt_inner(positive_prompt)
        uc = self.encode_prompt_inner(negative_prompt)

        c_len = float(len(c))
        uc_len = float(len(uc))
        max_count = max(c_len, uc_len)
        c_repeat = int(math.ceil(max_count / c_len))
        uc_repeat = int(math.ceil(max_count / uc_len))
        max_chunk = max(len(c), len(uc))

        c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
        uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]

        c = torch.cat([p[None, ...] for p in c], dim=1)
        uc = torch.cat([p[None, ...] for p in uc], dim=1)

        return c, uc

    @torch.inference_mode()
    def run_rmbg(self, img):
        """Blend ``img`` towards mid-grey (127) wherever the matting alpha is low.

        The matting network is fed a resized copy whose sides are multiples of
        64; the predicted alpha is interpolated back to the input size.

        Args:
            img: ``(height, width, 3)`` array.

        Returns:
            ``uint8`` array of the same size: ``127 + (img - 127) * alpha``.

        Raises:
            gr.Error: if the last axis does not have 3 channels.
        """
        height, width, channels = img.shape
        if channels != 3:
            raise gr.Error("Input image must be RGB.")

        k = (256.0 / float(height * width)) ** 0.5
        # For extreme aspect ratios one side can round down to 0, which the
        # resize rejects; clamp each side to one 64-pixel step.
        feed_width = max(64, int(64 * round(width * k)))
        feed_height = max(64, int(64 * round(height * k)))
        feed = resize_without_crop(img, feed_width, feed_height)
        feed = numpy2pytorch([feed]).to(device=self.device, dtype=torch.float32)
        alpha = self.rmbg(feed)[0][0]
        alpha = torch.nn.functional.interpolate(alpha, size=(height, width), mode="bilinear")
        alpha = alpha.movedim(1, -1)[0]
        alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
        result = 127 + (img.astype(np.float32) - 127) * alpha
        return result.clip(0, 255).astype(np.uint8)

    def make_initial_background(self, bg_source, image_width, image_height):
        """Build the grey ramp used as the initial background, if any.

        Args:
            bg_source: a ``BGSource`` member or its string value.
            image_width: ramp width in pixels.
            image_height: ramp height in pixels.

        Returns:
            ``None`` for ``BGSource.NONE``; otherwise an
            ``(image_height, image_width, 3)`` ``uint8`` array whose bright
            edge (255) is on the named side and which fades to 0 on the
            opposite side.

        Raises:
            gr.Error: if ``bg_source`` is not a valid lighting preference.
        """
        try:
            bg_source = BGSource(bg_source)
        except ValueError:
            # The enum lookup rejects unknown values before any branch below
            # could; report it as a user-facing error instead.
            raise gr.Error("Invalid lighting preference.") from None

        if bg_source == BGSource.NONE:
            return None

        if bg_source in (BGSource.LEFT, BGSource.RIGHT):
            start, stop = (255, 0) if bg_source == BGSource.LEFT else (0, 255)
            gradient = np.linspace(start, stop, image_width)
            image = np.tile(gradient, (image_height, 1))
        else:
            start, stop = (255, 0) if bg_source == BGSource.TOP else (0, 255)
            gradient = np.linspace(start, stop, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))

        return np.stack((image,) * 3, axis=-1).astype(np.uint8)

    @torch.inference_mode()
    def relight(
        self,
        input_fg,
        prompt,
        image_width,
        image_height,
        num_samples,
        seed,
        steps,
        cfg,
        highres_scale,
        highres_denoise,
        lowres_denoise,
        bg_source,
    ):
        """Relight the foreground of ``input_fg`` according to ``prompt``.

        Runs in two passes: a low-resolution pass (text-to-image when there is
        no initial background, image-to-image from the grey ramp otherwise),
        followed by an image-to-image pass on the upscaled result. Both passes
        are conditioned on the VAE latents of the matted foreground.

        Args:
            input_fg: uploaded image (PIL image or NumPy array).
            prompt: text prompt; ``ADDED_PROMPT`` is appended to it.
            image_width, image_height: size of the low-resolution pass.
            num_samples: number of images to generate.
            seed: RNG seed; ``None`` or a negative value picks a random seed.
            steps: number of denoising steps per pass.
            cfg: guidance scale.
            highres_scale: upscaling factor before the second pass (the
                resulting size is rounded to a multiple of 64).
            highres_denoise: image-to-image strength of the second pass.
            lowres_denoise: image-to-image strength of the first pass when an
                initial background is used.
            bg_source: lighting preference (``BGSource`` member or value).

        Returns:
            ``(foreground, images)``: the matted foreground array and a list
            of generated ``uint8`` image arrays.
        """
        input_fg = ensure_rgb(input_fg)
        input_fg = self.run_rmbg(input_fg)
        input_bg = self.make_initial_background(bg_source, image_width, image_height)

        if seed is None or int(seed) < 0:
            seed = random.randint(0, 2**31 - 1)

        rng = torch.Generator(device=self.device).manual_seed(int(seed))

        fg = resize_and_center_crop(input_fg, image_width, image_height)
        concat_conds = self._encode_images([fg])

        conds, unconds = self.encode_prompt_pair(
            positive_prompt=f"{prompt}, {ADDED_PROMPT}",
            negative_prompt=NEGATIVE_PROMPT,
        )

        # Arguments common to every pipeline call below.
        shared_kwargs = dict(
            prompt_embeds=conds,
            negative_prompt_embeds=unconds,
            num_images_per_prompt=num_samples,
            generator=rng,
            guidance_scale=cfg,
        )

        # ---- low-resolution pass ----
        if input_bg is None:
            latents = self._sample_latents(
                self.t2i_pipe,
                width=image_width,
                height=image_height,
                num_inference_steps=steps,
                cross_attention_kwargs={"concat_conds": concat_conds},
                **shared_kwargs,
            )
        else:
            bg = resize_and_center_crop(input_bg, image_width, image_height)
            bg_latent = self._encode_images([bg])
            latents = self._sample_latents(
                self.i2i_pipe,
                image=bg_latent,
                strength=lowres_denoise,
                width=image_width,
                height=image_height,
                # Requested steps are scaled by 1/strength before being
                # passed to the image-to-image pipeline.
                num_inference_steps=int(round(steps / lowres_denoise)),
                cross_attention_kwargs={"concat_conds": concat_conds},
                **shared_kwargs,
            )

        # ---- upscale in pixel space and re-encode ----
        pixels = pytorch2numpy(self.vae.decode(latents).sample)
        highres_width = int(round(image_width * highres_scale / 64.0) * 64)
        highres_height = int(round(image_height * highres_scale / 64.0) * 64)
        pixels = [
            resize_without_crop(image=p, target_width=highres_width, target_height=highres_height)
            for p in pixels
        ]
        latents = self._encode_images(pixels)
        latents = latents.to(device=self.unet.device, dtype=self.unet.dtype)

        # ---- high-resolution pass ----
        # Pixel size is derived from the latent shape (factor 8 per side).
        image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
        fg = resize_and_center_crop(input_fg, image_width, image_height)
        concat_conds = self._encode_images([fg])

        latents = self._sample_latents(
            self.i2i_pipe,
            image=latents,
            strength=highres_denoise,
            width=image_width,
            height=image_height,
            num_inference_steps=int(round(steps / highres_denoise)),
            cross_attention_kwargs={"concat_conds": concat_conds},
            **shared_kwargs,
        )

        pixels = self.vae.decode(latents).sample
        return input_fg, pytorch2numpy(pixels)
364
+
365
+
366
def get_engine():
    """Return the shared ``ICLightEngine``, creating it on first use.

    The instance is stored in the module-level ``_ENGINE`` so later calls
    reuse the already loaded models.
    """
    global _ENGINE
    if _ENGINE is not None:
        return _ENGINE

    _ENGINE = ICLightEngine()
    return _ENGINE
371
+
372
+
373
+ @spaces.GPU(duration=180)
374
+ def generate(
375
+ image,
376
+ prompt,
377
+ lighting,
378
+ width,
379
+ height,
380
+ samples,
381
+ seed,
382
+ steps,
383
+ cfg,
384
+ highres_scale,
385
+ highres_denoise,
386
+ lowres_denoise,
387
+ ):
388
+ if not prompt or not prompt.strip():
389
+ raise gr.Error("Enter a prompt.")
390
+
391
+ engine = get_engine()
392
+ return engine.relight(
393
+ image,
394
+ prompt.strip(),
395
+ int(width),
396
+ int(height),
397
+ int(samples),
398
+ int(seed),
399
+ int(steps),
400
+ float(cfg),
401
+ float(highres_scale),
402
+ float(highres_denoise),
403
+ float(lowres_denoise),
404
+ lighting,
405
+ )
406
+
407
+
408
# Prompt suggestions for the quick list; each sample is a one-element row.
quick_prompts = [
    [text]
    for text in (
        "beautiful woman, detailed face, sunshine from window",
        "handsome man, detailed face, neon light, city",
        "portrait, cinematic lighting",
        "product photo, soft studio lighting",
        "character art, dramatic light and shadow",
    )
]


with gr.Blocks(title="IC-Light Relighting") as demo:
    gr.Markdown("## IC-Light Relighting")

    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column():
            input_image = gr.Image(sources=["upload"], type="numpy", label="Image", height=440)
            prompt = gr.Textbox(label="Prompt", value="portrait, cinematic lighting")
            lighting = gr.Radio(
                choices=[source.value for source in BGSource],
                value=BGSource.NONE.value,
                label="Lighting Preference",
            )

            prompt_examples = gr.Dataset(
                samples=quick_prompts,
                label="Prompt Quick List",
                components=[prompt],
                samples_per_page=20,
            )
            # Clicking a sample copies its single column into the prompt box.
            prompt_examples.click(
                lambda sample: sample[0],
                inputs=prompt_examples,
                outputs=prompt,
                show_progress=False,
                queue=False,
            )

            run_button = gr.Button("Relight", variant="primary")

            with gr.Row():
                samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
                seed = gr.Number(label="Seed", value=12345, precision=0)

            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                height = gr.Slider(label="Height", minimum=256, maximum=1024, value=640, step=64)

            with gr.Accordion("Advanced", open=False):
                steps = gr.Slider(label="Steps", minimum=1, maximum=80, value=25, step=1)
                cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=16.0, value=2.0, step=0.1)
                lowres_denoise = gr.Slider(
                    label="Lowres Denoise", minimum=0.1, maximum=1.0, value=0.9, step=0.01
                )
                highres_scale = gr.Slider(
                    label="Highres Scale", minimum=1.0, maximum=2.0, value=1.5, step=0.05
                )
                highres_denoise = gr.Slider(
                    label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01
                )

        # Right column: results.
        with gr.Column():
            foreground = gr.Image(type="numpy", label="Preprocessed Foreground", height=360)
            gallery = gr.Gallery(label="Outputs", height=720, object_fit="contain")

    # Order must match the positional parameters of generate().
    inputs = [
        input_image,
        prompt,
        lighting,
        width,
        height,
        samples,
        seed,
        steps,
        cfg,
        highres_scale,
        highres_denoise,
        lowres_denoise,
    ]
    run_button.click(fn=generate, inputs=inputs, outputs=[foreground, gallery])


if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0")
briarmbg.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RMBG1.4 (diffusers implementation)
2
+ # Found on huggingface space of several projects
3
+ # Not sure which project is the source of this file
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+
11
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalisation and ReLU.

    ``dirate`` is used as both the dilation and the padding of the
    convolution.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch,
            out_ch,
            3,
            stride=stride,
            padding=1 * dirate,
            dilation=1 * dirate,
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU in that order."""
        features = self.conv_s1(x)
        features = self.bn_s1(features)
        return self.relu_s1(features)
26
+
27
+
28
+ def _upsample_like(src, tar):
29
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
30
+ return src
31
+
32
+
33
+ ### RSU-7 ###
34
class RSU7(nn.Module):
    """Residual U-block with five pooling stages and a dilated bottom layer.

    The input is first projected to ``out_ch`` channels. That projection is
    run through a small encoder/decoder of ``mid_ch``-channel ``REBNCONV``
    layers with skip connections, and the decoder output is added back to
    the projection.

    Args:
        in_ch: number of input channels.
        mid_ch: channel width inside the U-structure.
        out_ch: number of output channels.
        img_size: accepted for signature compatibility; not used here.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super().__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        # Input projection (stride 1).
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: each level except the last is followed by 2x max pooling.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated, no further pooling.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each layer consumes [upsampled deeper output, skip].
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the U-block output plus the input projection (residual)."""
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(self.pool4(hx4))
        hx6 = self.rebnconv6(self.pool5(hx5))
        hx7 = self.rebnconv7(hx6)

        # Decoder path: upsample to the skip's size, concatenate on channels.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx5d = self.rebnconv5d(torch.cat((_upsample_like(hx6d, hx5), hx5), 1))
        hx4d = self.rebnconv4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.rebnconv3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        return hx1d + hxin
113
+
114
+
115
+ ### RSU-6 ###
116
class RSU6(nn.Module):
    """Residual U-block with four pooling stages and a dilated bottom layer.

    Same layout as the deeper RSU variants in this file: input projection,
    encoder with 2x max pooling, dilated bottom layer, decoder with skip
    connections, and a residual sum with the input projection.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom layer.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the U-block output plus the input projection (residual)."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))
        bottom = self.rebnconv6(enc5)

        dec5 = self.rebnconv5d(torch.cat((bottom, enc5), 1))
        dec4 = self.rebnconv4d(torch.cat((_upsample_like(dec5, enc4), enc4), 1))
        dec3 = self.rebnconv3d(torch.cat((_upsample_like(dec4, enc3), enc3), 1))
        dec2 = self.rebnconv2d(torch.cat((_upsample_like(dec3, enc2), enc2), 1))
        dec1 = self.rebnconv1d(torch.cat((_upsample_like(dec2, enc1), enc1), 1))

        return dec1 + residual
180
+
181
+
182
+ ### RSU-5 ###
183
class RSU5(nn.Module):
    """Residual U-block with three pooling stages and a dilated bottom layer.

    Input projection, encoder with 2x max pooling, dilated bottom layer,
    decoder with skip connections, and a residual sum with the projection.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom layer.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the U-block output plus the input projection (residual)."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        bottom = self.rebnconv5(enc4)

        dec4 = self.rebnconv4d(torch.cat((bottom, enc4), 1))
        dec3 = self.rebnconv3d(torch.cat((_upsample_like(dec4, enc3), enc3), 1))
        dec2 = self.rebnconv2d(torch.cat((_upsample_like(dec3, enc2), enc2), 1))
        dec1 = self.rebnconv1d(torch.cat((_upsample_like(dec2, enc1), enc1), 1))

        return dec1 + residual
237
+
238
+
239
+ ### RSU-4 ###
240
class RSU4(nn.Module):
    """Residual U-block with two pooling stages and a dilated bottom layer.

    Input projection, encoder with 2x max pooling, dilated bottom layer,
    decoder with skip connections, and a residual sum with the projection.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom layer.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the U-block output plus the input projection (residual)."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        bottom = self.rebnconv4(enc3)

        dec3 = self.rebnconv3d(torch.cat((bottom, enc3), 1))
        dec2 = self.rebnconv2d(torch.cat((_upsample_like(dec3, enc2), enc2), 1))
        dec1 = self.rebnconv1d(torch.cat((_upsample_like(dec2, enc1), enc1), 1))

        return dec1 + residual
284
+
285
+
286
+ ### RSU-4F ###
287
class RSU4F(nn.Module):
    """Residual U-block that uses dilation instead of pooling.

    The encoder layers use dilation rates 1, 2, 4 and 8; the decoder layers
    use 4, 2 and 1. No pooling or upsampling is performed, and the decoder
    output is added to the input projection.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder with growing dilation.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder with shrinking dilation.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output plus the input projection (residual)."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(enc1)
        enc3 = self.rebnconv3(enc2)
        bottom = self.rebnconv4(enc3)

        dec3 = self.rebnconv3d(torch.cat((bottom, enc3), 1))
        dec2 = self.rebnconv2d(torch.cat((dec3, enc2), 1))
        dec1 = self.rebnconv1d(torch.cat((dec2, enc1), 1))

        return dec1 + residual
319
+
320
+
321
class myrebnconv(nn.Module):
    """Configurable convolution followed by batch normalisation and ReLU."""

    def __init__(
        self,
        in_ch=3,
        out_ch=1,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    ):
        super().__init__()

        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_options)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU in that order."""
        features = self.conv(x)
        features = self.bn(features)
        return self.rl(features)
348
+
349
+
350
class BriaRMBG(nn.Module, PyTorchModelHubMixin):
    """U-shaped matting network built from RSU blocks.

    A strided convolution feeds a six-stage encoder; a five-stage decoder
    merges each upsampled deeper feature with the matching encoder feature.
    Six side heads turn the decoder features (and the deepest encoder
    feature) into single maps that are resized to the input size.

    Args:
        config: mapping with ``"in_ch"`` (input channels) and ``"out_ch"``
            (channels of every side output). The default dict is only read.
    """

    def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
        super().__init__()
        in_ch, out_ch = config["in_ch"], config["out_ch"]

        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # Constructed but not applied in forward().
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Encoder.
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # Decoder.
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side heads, one per decoder level plus the deepest encoder level.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        """Run the network.

        Returns:
            A pair ``(masks, features)``: ``masks`` is a list of six
            sigmoid-activated side outputs resized to the spatial size of
            ``x``; ``features`` is the list of the six feature maps the side
            heads were applied to.
        """
        stem = self.conv_in(x)

        # Encoder: pool between consecutive stages.
        enc1 = self.stage1(stem)
        enc2 = self.stage2(self.pool12(enc1))
        enc3 = self.stage3(self.pool23(enc2))
        enc4 = self.stage4(self.pool34(enc3))
        enc5 = self.stage5(self.pool45(enc4))
        enc6 = self.stage6(self.pool56(enc5))

        # Decoder: upsample the deeper feature to the skip's size and
        # concatenate on the channel axis.
        dec5 = self.stage5d(torch.cat((_upsample_like(enc6, enc5), enc5), 1))
        dec4 = self.stage4d(torch.cat((_upsample_like(dec5, enc4), enc4), 1))
        dec3 = self.stage3d(torch.cat((_upsample_like(dec4, enc3), enc3), 1))
        dec2 = self.stage2d(torch.cat((_upsample_like(dec3, enc2), enc2), 1))
        dec1 = self.stage1d(torch.cat((_upsample_like(dec2, enc1), enc1), 1))

        features = [dec1, dec2, dec3, dec4, dec5, enc6]
        heads = (self.side1, self.side2, self.side3, self.side4, self.side5, self.side6)
        masks = [
            F.sigmoid(_upsample_like(head(feature), x))
            for head, feature in zip(heads, features)
        ]
        return masks, features
requirements.txt CHANGED
@@ -10,5 +10,7 @@ safetensors
10
  pillow
11
  einops
12
  peft
13
- pyzipper
 
 
14
  python-multipart==0.0.12
 
10
  pillow
11
  einops
12
  peft
13
+ protobuf==3.20.*
14
+ huggingface_hub
15
+ spaces
16
  python-multipart==0.0.12