"""Gradio demo that relights an image foreground with IC-Light.

The foreground is separated with BriaRMBG, then re-rendered by a Stable
Diffusion 1.5 UNet whose input convolution is widened to 8 channels so the
encoded foreground can be concatenated to the noisy latents. Generation runs
in two passes: a low-resolution pass followed by an upscaled img2img pass.
"""

import math
import random
from enum import Enum

import gradio as gr
import numpy as np
import safetensors.torch as sf
import torch
from diffusers import (
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.models.attention_processor import AttnProcessor2_0
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from briarmbg import BriaRMBG

try:
    import spaces
except ImportError:

    class spaces:
        """No-op stand-in used when the ``spaces`` package is not installed."""

        @staticmethod
        def GPU(duration=120):
            """Return a decorator that leaves the wrapped function unchanged."""

            def decorator(fn):
                return fn

            return decorator


BASE_MODEL = "stablediffusionapi/realistic-vision-v51"
ICLIGHT_REPO = "lllyasviel/ic-light"
MODEL_FILE = "iclight_sd15_fc.safetensors"
NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, cropped, worst quality"
ADDED_PROMPT = "best quality"

# Lazily created singleton; see get_engine().
_ENGINE = None


class BGSource(Enum):
    """Initial-lighting choices offered in the UI (values are the UI labels)."""

    NONE = "None"
    LEFT = "Left Light"
    RIGHT = "Right Light"
    TOP = "Top Light"
    BOTTOM = "Bottom Light"


def ensure_rgb(image):
    """Return *image* as an ``(H, W, 3)`` uint8 numpy array.

    Accepts a PIL image or a numpy array. 2-D (grayscale) arrays are
    replicated to three channels; any channels beyond the third (e.g. alpha)
    are dropped.

    Raises:
        gr.Error: if *image* is ``None``.
    """
    if image is None:
        raise gr.Error("Upload an image first.")
    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))
    if image.ndim == 2:
        image = np.stack([image, image, image], axis=-1)
    # Slicing drops the alpha channel directly, without a PIL round-trip that
    # would reject non-uint8 RGBA arrays.
    return image[:, :, :3].astype(np.uint8)


def resize_and_center_crop(image, target_width, target_height):
    """Scale *image* to cover the target size, then crop the centre.

    The aspect ratio is preserved: the image is resized (LANCZOS) by the
    larger of the two axis ratios, and the overflow is cropped evenly.

    Returns:
        A numpy array of exactly ``target_height`` x ``target_width`` pixels.
    """
    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)

    # Integer crop box: the result is always exactly the requested size,
    # instead of depending on how PIL rounds fractional coordinates.
    left = (resized_width - target_width) // 2
    top = (resized_height - target_height) // 2
    crop_box = (left, top, left + target_width, top + target_height)
    return np.array(resized_image.crop(crop_box))


def resize_without_crop(image, target_width, target_height):
    """Resize *image* to exactly the target size (LANCZOS), ignoring aspect ratio."""
    resized = Image.fromarray(image).resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized)


def numpy2pytorch(imgs):
    """Stack ``(H, W, C)`` arrays into one float ``(N, C, H, W)`` tensor.

    Pixel values are mapped with ``x / 127.0 - 1.0`` so that the value 127
    (the neutral gray used as fill by ``run_rmbg``) becomes exactly 0.0.
    """
    batch = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0
    return batch.movedim(-1, 1)


def pytorch2numpy(imgs):
    """Convert a batch of ``(C, H, W)`` tensors into a list of uint8 ``(H, W, C)`` arrays.

    Values are mapped with ``x * 127.5 + 127.5`` and clipped to ``[0, 255]``.
    """
    return [
        (tensor.movedim(0, -1) * 127.5 + 127.5)
        .detach()
        .float()
        .cpu()
        .numpy()
        .clip(0, 255)
        .astype(np.uint8)
        for tensor in imgs
    ]


class ICLightEngine:
    """Loads every model once and exposes :meth:`relight`."""

    def __init__(self):
        """Load the models, patch the UNet, move everything to CUDA, build pipelines.

        Raises:
            gr.Error: if no CUDA device is available.
        """
        if not torch.cuda.is_available():
            raise gr.Error(
                "IC-Light inference requires a CUDA GPU. On Hugging Face, enable ZeroGPU hardware."
            )
        self.device = torch.device("cuda")
        self.tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder")
        self.vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet")
        self.rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
        # Order matters: the conv must be widened before the offset weights
        # are merged, because the merge is strict over the UNet state dict.
        self._patch_unet_input()
        self._load_iclight_weights()
        self._move_to_gpu()
        self._build_pipelines()

    def _patch_unet_input(self):
        """Widen ``conv_in`` to 8 input channels and hook ``forward``.

        The first four input channels keep the pretrained weights and the
        extra four start at zero. The hooked forward reads the conditioning
        latents from ``cross_attention_kwargs["concat_conds"]``, concatenates
        them to the sample along the channel axis, and clears
        ``cross_attention_kwargs`` before delegating to the original forward.
        """
        original_conv = self.unet.conv_in
        with torch.no_grad():
            widened_conv = torch.nn.Conv2d(
                8,
                original_conv.out_channels,
                original_conv.kernel_size,
                original_conv.stride,
                original_conv.padding,
            )
            widened_conv.weight.zero_()
            widened_conv.weight[:, :4, :, :].copy_(original_conv.weight)
            widened_conv.bias = original_conv.bias
            self.unet.conv_in = widened_conv

        unet_original_forward = self.unet.forward

        def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
            c_concat = kwargs["cross_attention_kwargs"]["concat_conds"].to(sample)
            # Repeat the conditioning to match the sample batch (which the
            # pipeline may have enlarged, e.g. for classifier-free guidance).
            repeats = sample.shape[0] // c_concat.shape[0]
            c_concat = torch.cat([c_concat] * repeats, dim=0)
            new_sample = torch.cat([sample, c_concat], dim=1)
            kwargs["cross_attention_kwargs"] = {}
            return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)

        self.unet.forward = hooked_unet_forward

    def _load_iclight_weights(self):
        """Download the IC-Light weights and add them onto the UNet weights.

        The downloaded tensors are treated as offsets: each one is cast to
        the dtype of the matching UNet tensor and summed with it.
        """
        model_path = hf_hub_download(ICLIGHT_REPO, MODEL_FILE)
        sd_offset = sf.load_file(model_path, device="cpu")
        sd_origin = self.unet.state_dict()
        sd_merged = {
            key: value + sd_offset[key].to(dtype=value.dtype)
            for key, value in sd_origin.items()
        }
        self.unet.load_state_dict(sd_merged, strict=True)
        del sd_offset, sd_origin, sd_merged

    def _move_to_gpu(self):
        """Move all models to the CUDA device and select attention processors."""
        self.text_encoder = self.text_encoder.to(device=self.device, dtype=torch.float16)
        self.vae = self.vae.to(device=self.device, dtype=torch.bfloat16)
        self.unet = self.unet.to(device=self.device, dtype=torch.float16)
        self.rmbg = self.rmbg.to(device=self.device, dtype=torch.float32)
        self.unet.set_attn_processor(AttnProcessor2_0())
        self.vae.set_attn_processor(AttnProcessor2_0())

    def _build_pipelines(self):
        """Create the text-to-image and image-to-image pipelines.

        Both pipelines share the same model objects and the same scheduler.
        """
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            algorithm_type="sde-dpmsolver++",
            use_karras_sigmas=True,
            steps_offset=1,
        )
        pipe_kwargs = dict(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            scheduler=scheduler,
            safety_checker=None,
            requires_safety_checker=False,
            feature_extractor=None,
            image_encoder=None,
        )
        self.t2i_pipe = StableDiffusionPipeline(**pipe_kwargs)
        self.i2i_pipe = StableDiffusionImg2ImgPipeline(**pipe_kwargs)

    @torch.inference_mode()
    def encode_prompt_inner(self, txt):
        """Encode *txt* with the text encoder, one row per token chunk.

        The prompt is tokenized without truncation and split into chunks of
        ``model_max_length - 2`` tokens. Each chunk is wrapped in BOS/EOS and
        padded with EOS to ``model_max_length``.

        Returns:
            The encoder's ``last_hidden_state`` with one batch row per chunk.
        """
        max_length = self.tokenizer.model_max_length
        chunk_length = max_length - 2
        id_start = self.tokenizer.bos_token_id
        id_end = self.tokenizer.eos_token_id
        id_pad = id_end

        def pad(x, p, i):
            return x[:i] if len(x) >= i else x + [p] * (i - len(x))

        tokens = self.tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
        # max(..., 1) guarantees one (BOS, EOS) chunk even for an empty prompt,
        # so the encoder never receives an empty batch.
        chunks = [
            pad([id_start] + tokens[i : i + chunk_length] + [id_end], id_pad, max_length)
            for i in range(0, max(len(tokens), 1), chunk_length)
        ]
        token_ids = torch.tensor(chunks).to(device=self.device, dtype=torch.int64)
        return self.text_encoder(token_ids).last_hidden_state

    @torch.inference_mode()
    def encode_prompt_pair(self, positive_prompt, negative_prompt):
        """Encode both prompts and make their embeddings the same length.

        The prompt with fewer chunks is tiled until it has as many chunks as
        the other. The chunks of each prompt are then concatenated along the
        sequence axis into a single batch row.

        Returns:
            ``(conds, unconds)`` tensors with identical shapes.
        """
        c = self.encode_prompt_inner(positive_prompt)
        uc = self.encode_prompt_inner(negative_prompt)

        max_chunk = max(len(c), len(uc))
        c_repeat = int(math.ceil(max_chunk / len(c)))
        uc_repeat = int(math.ceil(max_chunk / len(uc)))

        c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
        uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]

        c = torch.cat([p[None, ...] for p in c], dim=1)
        uc = torch.cat([p[None, ...] for p in uc], dim=1)
        return c, uc

    @torch.inference_mode()
    def run_rmbg(self, img):
        """Blend the background of an RGB image towards neutral gray (127).

        The image is resized to roughly 256 * 64 * 64 pixels (each side a
        multiple of 64) for the matting model. The predicted alpha is resized
        back and used to blend every pixel between 127 and its own colour.

        Raises:
            gr.Error: if *img* does not have exactly three channels.
        """
        height, width, channels = img.shape
        if channels != 3:
            raise gr.Error("Input image must be RGB.")
        k = (256.0 / float(height * width)) ** 0.5
        # Clamp to one 64 px unit so extreme aspect ratios cannot produce a
        # zero-sized resize target.
        feed_width = 64 * max(1, int(round(width * k)))
        feed_height = 64 * max(1, int(round(height * k)))
        feed = resize_without_crop(img, feed_width, feed_height)
        feed = numpy2pytorch([feed]).to(device=self.device, dtype=torch.float32)
        alpha = self.rmbg(feed)[0][0]
        alpha = torch.nn.functional.interpolate(alpha, size=(height, width), mode="bilinear")
        alpha = alpha.movedim(1, -1)[0]
        alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
        result = 127 + (img.astype(np.float32) - 127) * alpha
        return result.clip(0, 255).astype(np.uint8)

    def make_initial_background(self, bg_source, image_width, image_height):
        """Build the gradient image that seeds the first pass.

        Args:
            bg_source: a :class:`BGSource` member or one of its string values.
            image_width: gradient width in pixels.
            image_height: gradient height in pixels.

        Returns:
            ``None`` for ``BGSource.NONE``; otherwise an ``(H, W, 3)`` uint8
            gray ramp that is brightest (255) on the named side and 0 on the
            opposite side.

        Raises:
            gr.Error: if *bg_source* is not a valid lighting preference.
        """
        try:
            source = BGSource(bg_source)
        except ValueError:
            # Enum lookup fails for unknown or missing values; surface it as a
            # user-facing error instead of a raw ValueError.
            raise gr.Error("Invalid lighting preference.") from None

        if source is BGSource.NONE:
            return None

        if source in (BGSource.LEFT, BGSource.RIGHT):
            start, stop = (255, 0) if source is BGSource.LEFT else (0, 255)
            gradient = np.linspace(start, stop, image_width)
            image = np.tile(gradient, (image_height, 1))
        else:
            start, stop = (255, 0) if source is BGSource.TOP else (0, 255)
            gradient = np.linspace(start, stop, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))

        return np.stack((image,) * 3, axis=-1).astype(np.uint8)

    @torch.inference_mode()
    def _encode_images(self, images):
        """Encode a list of uint8 images into scaled VAE latents (distribution mode)."""
        batch = numpy2pytorch(images).to(device=self.vae.device, dtype=self.vae.dtype)
        return self.vae.encode(batch).latent_dist.mode() * self.vae.config.scaling_factor

    @torch.inference_mode()
    def _decode_latents(self, scaled_latents):
        """Undo the latent scaling and decode to a pixel tensor with the VAE."""
        latents = scaled_latents.to(self.vae.dtype) / self.vae.config.scaling_factor
        return self.vae.decode(latents).sample

    @torch.inference_mode()
    def relight(
        self,
        input_fg,
        prompt,
        image_width,
        image_height,
        num_samples,
        seed,
        steps,
        cfg,
        highres_scale,
        highres_denoise,
        lowres_denoise,
        bg_source,
    ):
        """Relight *input_fg* according to *prompt* in two diffusion passes.

        Pass 1 runs at ``image_width`` x ``image_height``: text-to-image when
        no lighting preference is selected, otherwise image-to-image starting
        from the gradient background with strength ``lowres_denoise``.
        Pass 2 upscales the result by ``highres_scale`` (snapped to multiples
        of 64) and refines it with image-to-image at ``highres_denoise``.

        Args:
            input_fg: PIL image or numpy array to relight.
            prompt: positive prompt; ``ADDED_PROMPT`` is appended to it.
            image_width: first-pass width in pixels.
            image_height: first-pass height in pixels.
            num_samples: number of images to generate.
            seed: RNG seed; ``None`` or a negative value picks a random seed.
            steps: base number of denoising steps.
            cfg: classifier-free guidance scale.
            highres_scale: upscale factor applied before the second pass.
            highres_denoise: img2img strength of the second pass.
            lowres_denoise: img2img strength of the first pass (only used
                when a lighting preference is selected).
            bg_source: lighting preference (see :class:`BGSource`).

        Returns:
            ``(foreground, images)``: the background-removed foreground at its
            original resolution and a list of uint8 result images.

        Raises:
            gr.Error: if the image is missing or the lighting preference is
                invalid.
        """
        input_fg = ensure_rgb(input_fg)
        input_fg = self.run_rmbg(input_fg)
        input_bg = self.make_initial_background(bg_source, image_width, image_height)

        if seed is None or int(seed) < 0:
            seed = random.randint(0, 2**31 - 1)
        rng = torch.Generator(device=self.device).manual_seed(int(seed))

        fg = resize_and_center_crop(input_fg, image_width, image_height)
        concat_conds = self._encode_images([fg])

        conds, unconds = self.encode_prompt_pair(
            positive_prompt=f"{prompt}, {ADDED_PROMPT}",
            negative_prompt=NEGATIVE_PROMPT,
        )

        # Arguments shared by every pipeline call below.
        shared_kwargs = dict(
            prompt_embeds=conds,
            negative_prompt_embeds=unconds,
            num_images_per_prompt=num_samples,
            generator=rng,
            output_type="latent",
            guidance_scale=cfg,
        )

        # ---- Pass 1: low resolution -------------------------------------
        if input_bg is None:
            lowres_latents = self.t2i_pipe(
                width=image_width,
                height=image_height,
                num_inference_steps=steps,
                cross_attention_kwargs={"concat_conds": concat_conds},
                **shared_kwargs,
            ).images
        else:
            bg = resize_and_center_crop(input_bg, image_width, image_height)
            bg_latent = self._encode_images([bg])
            lowres_latents = self.i2i_pipe(
                image=bg_latent,
                strength=lowres_denoise,
                width=image_width,
                height=image_height,
                # steps / strength is passed so that, for the usual img2img
                # step truncation, about `steps` denoising steps actually run.
                num_inference_steps=int(round(steps / lowres_denoise)),
                cross_attention_kwargs={"concat_conds": concat_conds},
                **shared_kwargs,
            ).images

        pixels = pytorch2numpy(self._decode_latents(lowres_latents))

        # ---- Pass 2: upscale and refine ---------------------------------
        highres_width = int(round(image_width * highres_scale / 64.0) * 64)
        highres_height = int(round(image_height * highres_scale / 64.0) * 64)
        pixels = [
            resize_without_crop(
                image=p, target_width=highres_width, target_height=highres_height
            )
            for p in pixels
        ]
        latents = self._encode_images(pixels)
        latents = latents.to(device=self.unet.device, dtype=self.unet.dtype)

        # Derive the working size from the latents (8 pixels per latent cell)
        # and re-encode the foreground conditioning at that resolution.
        image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
        fg = resize_and_center_crop(input_fg, image_width, image_height)
        concat_conds = self._encode_images([fg])

        highres_latents = self.i2i_pipe(
            image=latents,
            strength=highres_denoise,
            width=image_width,
            height=image_height,
            num_inference_steps=int(round(steps / highres_denoise)),
            cross_attention_kwargs={"concat_conds": concat_conds},
            **shared_kwargs,
        ).images

        pixels = self._decode_latents(highres_latents)
        return input_fg, pytorch2numpy(pixels)


def get_engine():
    """Return the process-wide :class:`ICLightEngine`, creating it on first use."""
    global _ENGINE
    if _ENGINE is None:
        _ENGINE = ICLightEngine()
    return _ENGINE


@spaces.GPU(duration=180)
def generate(
    image,
    prompt,
    lighting,
    width,
    height,
    samples,
    seed,
    steps,
    cfg,
    highres_scale,
    highres_denoise,
    lowres_denoise,
):
    """Gradio callback: validate and coerce UI values, then run the engine.

    Returns:
        ``(foreground, images)`` as produced by :meth:`ICLightEngine.relight`.

    Raises:
        gr.Error: if the prompt is empty or whitespace only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Enter a prompt.")
    engine = get_engine()
    return engine.relight(
        image,
        prompt.strip(),
        int(width),
        int(height),
        int(samples),
        # A cleared Seed field arrives as None; pass it through so relight()
        # picks a random seed instead of int(None) raising TypeError.
        None if seed is None else int(seed),
        int(steps),
        float(cfg),
        float(highres_scale),
        float(highres_denoise),
        float(lowres_denoise),
        lighting,
    )


quick_prompts = [
    ["beautiful woman, detailed face, sunshine from window"],
    ["handsome man, detailed face, neon light, city"],
    ["portrait, cinematic lighting"],
    ["product photo, soft studio lighting"],
    ["character art, dramatic light and shadow"],
]

with gr.Blocks(title="IC-Light Relighting") as demo:
    gr.Markdown("## IC-Light Relighting")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=["upload"], type="numpy", label="Image", height=440)
            prompt = gr.Textbox(label="Prompt", value="portrait, cinematic lighting")
            lighting = gr.Radio(
                choices=[e.value for e in BGSource],
                value=BGSource.NONE.value,
                label="Lighting Preference",
            )
            prompt_examples = gr.Dataset(
                samples=quick_prompts,
                label="Prompt Quick List",
                components=[prompt],
                samples_per_page=20,
            )
            # Clicking a sample copies its single text field into the prompt box.
            prompt_examples.click(
                lambda x: x[0],
                inputs=prompt_examples,
                outputs=prompt,
                show_progress=False,
                queue=False,
            )
            run_button = gr.Button("Relight", variant="primary")
            with gr.Row():
                samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
                seed = gr.Number(label="Seed", value=12345, precision=0)
            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                height = gr.Slider(label="Height", minimum=256, maximum=1024, value=640, step=64)
            with gr.Accordion("Advanced", open=False):
                steps = gr.Slider(label="Steps", minimum=1, maximum=80, value=25, step=1)
                cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=16.0, value=2.0, step=0.1)
                lowres_denoise = gr.Slider(
                    label="Lowres Denoise",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.01,
                )
                highres_scale = gr.Slider(
                    label="Highres Scale",
                    minimum=1.0,
                    maximum=2.0,
                    value=1.5,
                    step=0.05,
                )
                highres_denoise = gr.Slider(
                    label="Highres Denoise",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                )
        with gr.Column():
            foreground = gr.Image(type="numpy", label="Preprocessed Foreground", height=360)
            gallery = gr.Gallery(label="Outputs", height=720, object_fit="contain")

    # Order must match the parameter order of generate().
    inputs = [
        input_image,
        prompt,
        lighting,
        width,
        height,
        samples,
        seed,
        steps,
        cfg,
        highres_scale,
        highres_denoise,
        lowres_denoise,
    ]
    run_button.click(fn=generate, inputs=inputs, outputs=[foreground, gallery])

if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0")