| import math |
| import random |
| from enum import Enum |
|
|
| import gradio as gr |
| import numpy as np |
| import safetensors.torch as sf |
| import torch |
| from diffusers import ( |
| AutoencoderKL, |
| DPMSolverMultistepScheduler, |
| StableDiffusionImg2ImgPipeline, |
| StableDiffusionPipeline, |
| UNet2DConditionModel, |
| ) |
| from diffusers.models.attention_processor import AttnProcessor2_0 |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
| from transformers import CLIPTextModel, CLIPTokenizer |
|
|
| from briarmbg import BriaRMBG |
|
|
try:
    import spaces
except ImportError:
    class spaces:  # noqa: N801 - stands in for the ``spaces`` module name
        """No-op fallback used when Hugging Face's ``spaces`` package is absent.

        Only ``GPU`` is provided; it returns the decorated function unchanged,
        so the app runs locally without ZeroGPU.
        """

        @staticmethod
        def GPU(func=None, duration=120, **kwargs):
            """Mimic ``spaces.GPU`` in both decorator forms, doing nothing.

            Supports bare ``@spaces.GPU`` as well as ``@spaces.GPU()`` /
            ``@spaces.GPU(duration=...)``. ``duration`` and any other keyword
            arguments are accepted only for signature compatibility and are
            ignored.

            Returns:
                ``func`` itself for bare usage, otherwise an identity decorator.
            """

            def decorator(fn):
                return fn

            # Bare ``@spaces.GPU`` passes the decorated function as the first
            # positional argument; a non-callable (e.g. ``GPU(120)``) is the
            # legacy positional duration.
            if callable(func):
                return func
            return decorator
|
|
|
|
# Hugging Face repo providing the SD1.5 tokenizer, text encoder, VAE and UNet.
BASE_MODEL = "stablediffusionapi/realistic-vision-v51"

# Repo and file holding the IC-Light weight offsets that are added to the UNet.
ICLIGHT_REPO = "lllyasviel/ic-light"
MODEL_FILE = "iclight_sd15_fc.safetensors"

# Fixed negative prompt, and the suffix appended to every user prompt.
NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, cropped, worst quality"
ADDED_PROMPT = "best quality"

# Lazily created engine instance; populated on first use by get_engine().
_ENGINE = None
|
|
|
|
class BGSource(Enum):
    """Lighting direction choices offered in the UI.

    Each value is the exact label shown in the Radio control. ``NONE`` means
    no initial gradient background (pure text-to-image first pass). Member
    order defines the order of the UI choices.
    """

    NONE = "None"  # no directional light hint
    LEFT = "Left Light"  # bright on the left edge
    RIGHT = "Right Light"  # bright on the right edge
    TOP = "Top Light"  # bright at the top edge
    BOTTOM = "Bottom Light"  # bright at the bottom edge
|
|
|
|
def ensure_rgb(image):
    """Normalize an uploaded image to an ``(H, W, 3)`` ``uint8`` RGB array.

    Accepts a PIL image or an array-like that is grayscale ``(H, W)``,
    single-channel ``(H, W, 1)``, gray+alpha ``(H, W, 2)``, RGB ``(H, W, 3)``
    or RGBA ``(H, W, 4)``. Alpha is dropped without compositing, which is
    what PIL's ``RGBA -> RGB`` conversion does as well.

    Raises:
        gr.Error: if ``image`` is None or does not have an image shape.
    """
    if image is None:
        raise gr.Error("Upload an image first.")

    if not isinstance(image, np.ndarray):
        if isinstance(image, Image.Image):
            return np.array(image.convert("RGB"))
        image = np.asarray(image)

    if image.ndim == 2:
        image = image[:, :, None]
    if image.ndim != 3 or image.shape[-1] not in (1, 2, 3, 4):
        raise gr.Error(f"Unsupported image shape {image.shape}.")

    if image.shape[-1] in (1, 2):
        # Grayscale (optionally with alpha): replicate luminance into R, G, B.
        image = np.repeat(image[:, :, :1], 3, axis=-1)

    # Slicing drops alpha for RGBA and is a no-op for RGB.
    return image[:, :, :3].astype(np.uint8)
|
|
|
|
def resize_and_center_crop(image, target_width, target_height):
    """Scale ``image`` to cover the target size, then center-crop to it.

    The image is resized with Lanczos filtering by the smallest uniform factor
    that makes both sides at least the target size. The excess is then trimmed
    evenly from both sides.

    Args:
        image: array accepted by ``PIL.Image.fromarray`` (e.g. HWC uint8).
        target_width: output width in pixels.
        target_height: output height in pixels.

    Returns:
        NumPy array exactly ``target_height`` x ``target_width`` pixels.
    """
    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
    # Use an integer box of exactly the target size. A float box is rounded
    # half-to-even by PIL independently per edge. When the target side is odd
    # and the excess is odd, that yields a crop one pixel too large.
    # round() here matches PIL's rounding of the left/top edge, so even target
    # sizes crop exactly as before.
    left = round((resized_width - target_width) / 2)
    top = round((resized_height - target_height) / 2)
    box = (left, top, left + target_width, top + target_height)
    return np.array(resized_image.crop(box))
|
|
|
|
def resize_without_crop(image, target_width, target_height):
    """Resize ``image`` to exactly ``target_width`` x ``target_height``.

    Uses Lanczos resampling and does not preserve the aspect ratio.
    Returns a NumPy array.
    """
    source = Image.fromarray(image)
    resized = source.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized)
|
|
|
|
def numpy2pytorch(imgs):
    """Stack HWC images into an NCHW float32 tensor centered on pixel value 127.

    Values are mapped with ``v / 127 - 1``, so 127 becomes exactly 0.0 and 255
    becomes ~1.008. Note this is slightly asymmetric versus
    ``pytorch2numpy``, which uses 127.5. 127 is also the gray fill used for
    removed backgrounds in this file.
    """
    batch = np.stack(imgs, axis=0)
    tensor = torch.from_numpy(batch).float()
    tensor = tensor / 127.0 - 1.0
    # Channels-last (N, H, W, C) -> channels-first (N, C, H, W).
    return tensor.movedim(-1, 1)
|
|
|
|
def pytorch2numpy(imgs):
    """Convert an iterable of CHW tensors in ~[-1, 1] to HWC ``uint8`` arrays.

    Each tensor is moved channels-last and mapped with ``v * 127.5 + 127.5`` in
    its own dtype. It is then converted to float32 on the CPU, clipped to
    [0, 255] and truncated to ``uint8``.
    """

    def to_uint8(chw):
        scaled = chw.movedim(0, -1) * 127.5 + 127.5
        as_array = scaled.detach().float().cpu().numpy()
        return as_array.clip(0, 255).astype(np.uint8)

    return [to_uint8(chw) for chw in imgs]
|
|
|
|
class ICLightEngine:
    """IC-Light relighting engine built on an SD1.5 checkpoint.

    Construction does the following:
      * loads the tokenizer, text encoder, VAE and UNet of ``BASE_MODEL``, plus
        the BriaRMBG background remover;
      * widens the UNet input convolution to 8 channels and merges the
        IC-Light weight offsets;
      * moves everything to CUDA;
      * builds text-to-image and image-to-image pipelines that share those
        modules.

    Construction is expensive; this file creates one instance lazily in
    ``get_engine``.

    Raises:
        gr.Error: if no CUDA device is available.
    """

    def __init__(self):
        if not torch.cuda.is_available():
            raise gr.Error("IC-Light inference requires a CUDA GPU. On Hugging Face, enable ZeroGPU hardware.")

        self.device = torch.device("cuda")
        self.tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder")
        self.vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet")
        self.rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")

        # Order matters. The offsets are added key-by-key to the UNet state
        # dict and loaded strictly, so conv_in must already be widened
        # (presumably the offset file carries an 8-channel conv_in). Both steps
        # run before the device/dtype move.
        self._patch_unet_input()
        self._load_iclight_weights()
        self._move_to_gpu()
        self._build_pipelines()

    def _patch_unet_input(self):
        """Widen ``unet.conv_in`` to 8 input channels and hook ``unet.forward``.

        The first 4 input channels reuse the pretrained conv weights and bias.
        The extra 4 start at zero. The hooked forward reads ``concat_conds`` from
        ``cross_attention_kwargs``, repeats it to the sample batch size,
        concatenates it to ``sample`` along the channel axis, and calls the
        original forward with empty ``cross_attention_kwargs``.
        """
        with torch.no_grad():
            new_conv_in = torch.nn.Conv2d(
                8,
                self.unet.conv_in.out_channels,
                self.unet.conv_in.kernel_size,
                self.unet.conv_in.stride,
                self.unet.conv_in.padding,
            )
            new_conv_in.weight.zero_()
            new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
            new_conv_in.bias = self.unet.conv_in.bias
            self.unet.conv_in = new_conv_in

        unet_original_forward = self.unet.forward

        def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
            c_concat = kwargs["cross_attention_kwargs"]["concat_conds"].to(sample)
            # With CFG the pipeline doubles the batch; tile the conditioning to match.
            c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
            new_sample = torch.cat([sample, c_concat], dim=1)
            kwargs["cross_attention_kwargs"] = {}
            return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)

        self.unet.forward = hooked_unet_forward

    def _load_iclight_weights(self):
        """Download the IC-Light offsets and add them onto the UNet weights.

        Every UNet parameter becomes ``original + offset``, with the offset cast
        to the parameter's dtype. Loading is strict, so a missing key in the
        offset file raises ``KeyError``.
        """
        model_path = hf_hub_download(ICLIGHT_REPO, MODEL_FILE)
        sd_offset = sf.load_file(model_path, device="cpu")
        sd_origin = self.unet.state_dict()
        sd_merged = {
            key: value + sd_offset[key].to(dtype=value.dtype)
            for key, value in sd_origin.items()
        }
        self.unet.load_state_dict(sd_merged, strict=True)
        del sd_offset, sd_origin, sd_merged

    def _move_to_gpu(self):
        """Move every model to ``self.device`` with its inference dtype.

        Also switches the UNet and VAE attention to ``AttnProcessor2_0``.
        """
        self.text_encoder = self.text_encoder.to(device=self.device, dtype=torch.float16)
        # NOTE(review): the VAE runs in bfloat16 while the UNet/text encoder use
        # float16, presumably for the VAE's numeric range. Confirm before changing.
        self.vae = self.vae.to(device=self.device, dtype=torch.bfloat16)
        self.unet = self.unet.to(device=self.device, dtype=torch.float16)
        self.rmbg = self.rmbg.to(device=self.device, dtype=torch.float32)
        self.unet.set_attn_processor(AttnProcessor2_0())
        self.vae.set_attn_processor(AttnProcessor2_0())

    def _build_pipelines(self):
        """Create the t2i and i2i pipelines around the shared modules.

        Both use one ``DPMSolverMultistepScheduler`` configured as
        ``sde-dpmsolver++`` with Karras sigmas. The safety checker is disabled.
        """
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            algorithm_type="sde-dpmsolver++",
            use_karras_sigmas=True,
            steps_offset=1,
        )
        pipe_kwargs = dict(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            scheduler=scheduler,
            safety_checker=None,
            requires_safety_checker=False,
            feature_extractor=None,
            image_encoder=None,
        )
        self.t2i_pipe = StableDiffusionPipeline(**pipe_kwargs)
        self.i2i_pipe = StableDiffusionImg2ImgPipeline(**pipe_kwargs)

    @torch.inference_mode()
    def encode_prompt_inner(self, txt):
        """Encode ``txt`` of any length as independently encoded CLIP chunks.

        The prompt is tokenized without truncation and split into chunks of
        ``model_max_length - 2`` tokens. Each chunk is wrapped in BOS/EOS,
        padded with EOS to ``model_max_length`` and encoded on its own.

        Returns:
            The text encoder's ``last_hidden_state``, one entry per chunk
            along dim 0.
        """
        max_length = self.tokenizer.model_max_length
        chunk_length = max_length - 2
        id_start = self.tokenizer.bos_token_id
        id_end = self.tokenizer.eos_token_id
        id_pad = id_end

        def pad(x, p, i):
            return x[:i] if len(x) >= i else x + [p] * (i - len(x))

        tokens = self.tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
        # An empty prompt would otherwise give zero chunks, i.e. an empty id
        # tensor the text encoder cannot consume. Emit one BOS/EOS-only chunk.
        starts = range(0, len(tokens), chunk_length) if tokens else [0]
        chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in starts]
        chunks = [pad(chunk, id_pad, max_length) for chunk in chunks]

        token_ids = torch.tensor(chunks).to(device=self.device, dtype=torch.int64)
        return self.text_encoder(token_ids).last_hidden_state

    @torch.inference_mode()
    def encode_prompt_pair(self, positive_prompt, negative_prompt):
        """Encode positive/negative prompts to equal-length embedding sequences.

        The side with fewer chunks is repeated cyclically until both have the
        same number of chunks. Chunks are then concatenated along the token
        axis, giving tensors with batch size 1.

        Returns:
            ``(conds, unconds)`` for ``prompt_embeds`` / ``negative_prompt_embeds``.
        """
        c = self.encode_prompt_inner(positive_prompt)
        uc = self.encode_prompt_inner(negative_prompt)

        c_len = float(len(c))
        uc_len = float(len(uc))
        max_count = max(c_len, uc_len)
        c_repeat = int(math.ceil(max_count / c_len))
        uc_repeat = int(math.ceil(max_count / uc_len))
        max_chunk = max(len(c), len(uc))

        c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
        uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]

        c = torch.cat([p[None, ...] for p in c], dim=1)
        uc = torch.cat([p[None, ...] for p in uc], dim=1)

        return c, uc

    @torch.inference_mode()
    def run_rmbg(self, img):
        """Fade the background of an HWC RGB ``uint8`` image to gray 127.

        RMBG runs on a downscaled copy of about 256 * 64 * 64 pixels with sides
        that are multiples of 64. Its alpha is upsampled back to full size and
        used to blend each pixel toward 127 (alpha 1 keeps the pixel, alpha 0
        gives gray).

        Raises:
            gr.Error: if the image does not have 3 channels.
        """
        height, width, channels = img.shape
        if channels != 3:
            raise gr.Error("Input image must be RGB.")

        k = (256.0 / float(height * width)) ** 0.5
        # Clamp to at least one 64-px tile. For extreme aspect ratios the
        # rounded short side would otherwise be 0 and the resize would fail.
        feed_width = 64 * max(1, int(round(width * k)))
        feed_height = 64 * max(1, int(round(height * k)))
        feed = resize_without_crop(img, feed_width, feed_height)
        feed = numpy2pytorch([feed]).to(device=self.device, dtype=torch.float32)
        # NOTE(review): [0][0] presumably selects the finest alpha map from
        # BriaRMBG's output list; confirm against briarmbg.
        alpha = self.rmbg(feed)[0][0]
        alpha = torch.nn.functional.interpolate(alpha, size=(height, width), mode="bilinear")
        alpha = alpha.movedim(1, -1)[0]
        alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
        result = 127 + (img.astype(np.float32) - 127) * alpha
        return result.clip(0, 255).astype(np.uint8)

    def make_initial_background(self, bg_source, image_width, image_height):
        """Build the gradient used as the i2i starting image for a light direction.

        Args:
            bg_source: a ``BGSource`` member or its string value.
            image_width: gradient width in pixels.
            image_height: gradient height in pixels.

        Returns:
            ``None`` for ``BGSource.NONE``. Otherwise an ``(H, W, 3)`` ``uint8``
            array, bright (255) on the lit side and fading linearly to 0.

        Raises:
            gr.Error: for an unknown lighting preference.
        """
        try:
            bg_source = BGSource(bg_source)
        except ValueError as err:
            # BGSource() rejects unknown values itself. Surface that as the
            # user-facing Gradio error rather than an opaque ValueError.
            raise gr.Error("Invalid lighting preference.") from err

        if bg_source == BGSource.NONE:
            return None
        if bg_source == BGSource.LEFT:
            gradient = np.linspace(255, 0, image_width)
            image = np.tile(gradient, (image_height, 1))
        elif bg_source == BGSource.RIGHT:
            gradient = np.linspace(0, 255, image_width)
            image = np.tile(gradient, (image_height, 1))
        elif bg_source == BGSource.TOP:
            gradient = np.linspace(255, 0, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
        elif bg_source == BGSource.BOTTOM:
            gradient = np.linspace(0, 255, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
        else:
            # Defensive: only reachable if BGSource gains an unhandled member.
            raise gr.Error("Invalid lighting preference.")

        return np.stack((image,) * 3, axis=-1).astype(np.uint8)

    @torch.inference_mode()
    def _vae_encode(self, images):
        """VAE-encode a list of HWC images to scaled latents (distribution mode)."""
        pixels = numpy2pytorch(images).to(device=self.vae.device, dtype=self.vae.dtype)
        return self.vae.encode(pixels).latent_dist.mode() * self.vae.config.scaling_factor

    def _unscale_latents(self, pipe_output):
        """Undo the VAE scaling on ``output_type="latent"`` pipeline results."""
        return pipe_output.images.to(self.vae.dtype) / self.vae.config.scaling_factor

    @torch.inference_mode()
    def relight(
        self,
        input_fg,
        prompt,
        image_width,
        image_height,
        num_samples,
        seed,
        steps,
        cfg,
        highres_scale,
        highres_denoise,
        lowres_denoise,
        bg_source,
    ):
        """Relight a foreground subject according to ``prompt`` and ``bg_source``.

        The pipeline runs in four stages:
          1. Remove the background (gray fill).
          2. Generate at ``image_width`` x ``image_height``: text-to-image, or
             image-to-image from a light gradient at strength
             ``lowres_denoise``.
          3. Upscale in pixel space by ``highres_scale``, snapped to multiples
             of 64.
          4. Refine with a second image-to-image pass at strength
             ``highres_denoise``.

        Both passes are conditioned on the VAE latent of the foreground.

        Args:
            input_fg: uploaded image (see ``ensure_rgb`` for accepted forms).
            prompt: user prompt; ``ADDED_PROMPT`` is appended.
            seed: RNG seed; ``None`` or a negative value picks a random seed.
            steps: approximate number of denoising steps per pass.
            bg_source: ``BGSource`` member or value.

        Returns:
            ``(foreground, images)``: the background-removed input and a list
            of ``num_samples`` HWC ``uint8`` arrays.
        """
        input_fg = ensure_rgb(input_fg)
        input_fg = self.run_rmbg(input_fg)
        input_bg = self.make_initial_background(bg_source, image_width, image_height)

        if seed is None or int(seed) < 0:
            seed = random.randint(0, 2**31 - 1)

        rng = torch.Generator(device=self.device).manual_seed(int(seed))

        # Low-resolution pass.
        fg = resize_and_center_crop(input_fg, image_width, image_height)
        concat_conds = self._vae_encode([fg])

        conds, unconds = self.encode_prompt_pair(
            positive_prompt=f"{prompt}, {ADDED_PROMPT}",
            negative_prompt=NEGATIVE_PROMPT,
        )

        if input_bg is None:
            latents = self._unscale_latents(self.t2i_pipe(
                prompt_embeds=conds,
                negative_prompt_embeds=unconds,
                width=image_width,
                height=image_height,
                num_inference_steps=steps,
                num_images_per_prompt=num_samples,
                generator=rng,
                output_type="latent",
                guidance_scale=cfg,
                cross_attention_kwargs={"concat_conds": concat_conds},
            ))
        else:
            bg = resize_and_center_crop(input_bg, image_width, image_height)
            bg_latent = self._vae_encode([bg])
            latents = self._unscale_latents(self.i2i_pipe(
                image=bg_latent,
                strength=lowres_denoise,
                prompt_embeds=conds,
                negative_prompt_embeds=unconds,
                width=image_width,
                height=image_height,
                # img2img only runs the last ~strength fraction of the schedule;
                # scale the total so roughly `steps` steps actually execute.
                num_inference_steps=int(round(steps / lowres_denoise)),
                num_images_per_prompt=num_samples,
                generator=rng,
                output_type="latent",
                guidance_scale=cfg,
                cross_attention_kwargs={"concat_conds": concat_conds},
            ))

        # Upscale in pixel space, then re-encode for the high-res pass.
        pixels = self.vae.decode(latents).sample
        pixels = pytorch2numpy(pixels)
        highres_width = int(round(image_width * highres_scale / 64.0) * 64)
        highres_height = int(round(image_height * highres_scale / 64.0) * 64)
        pixels = [
            resize_without_crop(image=p, target_width=highres_width, target_height=highres_height)
            for p in pixels
        ]

        latents = self._vae_encode(pixels)
        latents = latents.to(device=self.unet.device, dtype=self.unet.dtype)

        # Take the high-res size from the latent grid (8 px per latent cell)
        # and re-encode the foreground conditioning at that size.
        image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
        fg = resize_and_center_crop(input_fg, image_width, image_height)
        concat_conds = self._vae_encode([fg])

        latents = self._unscale_latents(self.i2i_pipe(
            image=latents,
            strength=highres_denoise,
            prompt_embeds=conds,
            negative_prompt_embeds=unconds,
            width=image_width,
            height=image_height,
            num_inference_steps=int(round(steps / highres_denoise)),
            num_images_per_prompt=num_samples,
            generator=rng,
            output_type="latent",
            guidance_scale=cfg,
            cross_attention_kwargs={"concat_conds": concat_conds},
        ))

        pixels = self.vae.decode(latents).sample
        return input_fg, pytorch2numpy(pixels)
|
|
|
|
def get_engine():
    """Return the process-wide ``ICLightEngine``, constructing it on first call.

    The instance is cached in the module global ``_ENGINE``, so models are
    downloaded and loaded only once per process.
    """
    global _ENGINE
    if _ENGINE is not None:
        return _ENGINE
    _ENGINE = ICLightEngine()
    return _ENGINE
|
|
|
|
@spaces.GPU(duration=180)
def generate(
    image,
    prompt,
    lighting,
    width,
    height,
    samples,
    seed,
    steps,
    cfg,
    highres_scale,
    highres_denoise,
    lowres_denoise,
):
    """Gradio click handler: validate UI values and run IC-Light relighting.

    Arguments arrive positionally in the order of the Relight button's
    ``inputs`` list. The function is wrapped in ``spaces.GPU(duration=180)``
    for ZeroGPU; that wrapper is a no-op when ``spaces`` is not installed.

    Returns:
        ``(foreground, images)`` as produced by ``ICLightEngine.relight``.

    Raises:
        gr.Error: when the prompt or image is missing, or on engine errors.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Enter a prompt.")
    # Fail fast, before the expensive first-time engine construction.
    if image is None:
        raise gr.Error("Upload an image first.")

    # gr.Number yields None when the field is cleared. relight() treats None
    # as "pick a random seed", so don't let int(None) raise a TypeError here.
    seed = None if seed is None else int(seed)

    engine = get_engine()
    return engine.relight(
        image,
        prompt.strip(),
        int(width),
        int(height),
        int(samples),
        seed,
        int(steps),
        float(cfg),
        float(highres_scale),
        float(highres_denoise),
        float(lowres_denoise),
        lighting,
    )
|
|
|
|
# Example prompts for the gr.Dataset quick list. Each sample is a one-element
# row because the dataset is bound to a single component (the prompt box).
quick_prompts = [
    [text]
    for text in (
        "beautiful woman, detailed face, sunshine from window",
        "handsome man, detailed face, neon light, city",
        "portrait, cinematic lighting",
        "product photo, soft studio lighting",
        "character art, dramatic light and shadow",
    )
]
|
|
|
|
# Gradio UI. The left column holds the inputs and settings; the right column
# shows the background-removed foreground and the generated images.
# Layout is defined purely by the order of statements inside the `with` blocks.
with gr.Blocks(title="IC-Light Relighting") as demo:
    gr.Markdown("## IC-Light Relighting")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=["upload"], type="numpy", label="Image", height=440)
            prompt = gr.Textbox(label="Prompt", value="portrait, cinematic lighting")
            # Choices are the BGSource values. generate() forwards the chosen
            # string, which the engine converts back with BGSource(...).
            lighting = gr.Radio(
                choices=[e.value for e in BGSource],
                value=BGSource.NONE.value,
                label="Lighting Preference",
            )
            prompt_examples = gr.Dataset(
                samples=quick_prompts,
                label="Prompt Quick List",
                components=[prompt],
                samples_per_page=20,
            )
            # Presumably receives the clicked sample row; copies its single
            # value into the prompt box (no queue, no progress indicator).
            prompt_examples.click(
                lambda x: x[0],
                inputs=prompt_examples,
                outputs=prompt,
                show_progress=False,
                queue=False,
            )
            run_button = gr.Button("Relight", variant="primary")

            with gr.Row():
                samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
                # A negative seed requests a random seed (handled in ICLightEngine.relight).
                seed = gr.Number(label="Seed", value=12345, precision=0)

            # Low-res output size; min 256 with step 64 keeps both sides multiples of 64.
            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                height = gr.Slider(label="Height", minimum=256, maximum=1024, value=640, step=64)

            with gr.Accordion("Advanced", open=False):
                steps = gr.Slider(label="Steps", minimum=1, maximum=80, value=25, step=1)
                cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=16.0, value=2.0, step=0.1)
                # Denoise strengths are also divisors of the step count in
                # relight(), so their minimum must stay above 0.
                lowres_denoise = gr.Slider(
                    label="Lowres Denoise",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.01,
                )
                highres_scale = gr.Slider(
                    label="Highres Scale",
                    minimum=1.0,
                    maximum=2.0,
                    value=1.5,
                    step=0.05,
                )
                highres_denoise = gr.Slider(
                    label="Highres Denoise",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                )
        with gr.Column():
            foreground = gr.Image(type="numpy", label="Preprocessed Foreground", height=360)
            gallery = gr.Gallery(label="Outputs", height=720, object_fit="contain")

    # Order must match generate()'s positional parameters exactly.
    inputs = [
        input_image,
        prompt,
        lighting,
        width,
        height,
        samples,
        seed,
        steps,
        cfg,
        highres_scale,
        highres_denoise,
        lowres_denoise,
    ]
    # generate() returns (foreground, images), matching these two outputs.
    run_button.click(fn=generate, inputs=inputs, outputs=[foreground, gallery])
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue (at most 20 waiting jobs), then serve on all
    # network interfaces.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(server_name="0.0.0.0")
|
|