# ICLight / app.py
# daKhosa
# Replace exec stub with IC-Light app
# 2fc70fd
# Raw
# History Blame Contribute Delete
# 17.9 kB
import math
import random
from enum import Enum
import gradio as gr
import numpy as np
import safetensors.torch as sf
import torch
from diffusers import (
AutoencoderKL,
DPMSolverMultistepScheduler,
StableDiffusionImg2ImgPipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.models.attention_processor import AttnProcessor2_0
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from briarmbg import BriaRMBG
try:
    import spaces
except ImportError:
    class spaces:
        """Stand-in used when the ``spaces`` package is not installed.

        Only the ``GPU`` decorator is provided, and it leaves the decorated
        function unchanged.
        """

        @staticmethod
        def GPU(func=None, duration=120, **_unused):
            """Return ``func`` unchanged, or a decorator that does so.

            Supports both ``@spaces.GPU`` (the function arrives as ``func``)
            and ``@spaces.GPU(duration=...)``. ``duration`` and any other
            keyword arguments are accepted and ignored.
            """
            def decorator(fn):
                return fn

            # Bare-decorator usage: the decorated function is passed directly.
            if callable(func):
                return decorator(func)
            return decorator
# Hub repo id from which the tokenizer, text encoder, VAE and UNet are loaded.
BASE_MODEL = "stablediffusionapi/realistic-vision-v51"
# Hub repo id and file name of the IC-Light weights fetched with hf_hub_download.
ICLIGHT_REPO = "lllyasviel/ic-light"
MODEL_FILE = "iclight_sd15_fc.safetensors"
# Negative prompt passed unchanged to every relight call.
NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, cropped, worst quality"
# Text appended to the user's prompt, joined as "<prompt>, <ADDED_PROMPT>".
ADDED_PROMPT = "best quality"
# Cached ICLightEngine instance; stays None until get_engine() first builds it.
_ENGINE = None
class BGSource(Enum):
    """Lighting-direction choices; each value is the label shown in the UI."""

    # No synthetic background: generation starts from pure noise.
    NONE = "None"
    # Gradient backgrounds, bright on the named side.
    LEFT = "Left Light"
    RIGHT = "Right Light"
    TOP = "Top Light"
    BOTTOM = "Bottom Light"
def ensure_rgb(image):
    """Return ``image`` as a 3-channel ``uint8`` NumPy array.

    Accepts a PIL image or a NumPy array. 2-D arrays are replicated onto
    three channels, 4-channel arrays are converted to RGB through PIL, and
    any remaining channels beyond the third are dropped.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Upload an image first.")

    if isinstance(image, Image.Image):
        rgb_pil = image.convert("RGB")
        return np.array(rgb_pil)

    if image.ndim == 2:
        # Grayscale: copy the single plane into three identical channels.
        image = np.stack((image, image, image), axis=-1)
    elif image.shape[-1] == 4:
        # Four channels: let PIL perform the conversion to RGB.
        rgb_pil = Image.fromarray(image).convert("RGB")
        image = np.array(rgb_pil)

    first_three = image[:, :, :3]
    return first_three.astype(np.uint8)
def resize_and_center_crop(image, target_width, target_height):
    """Scale ``image`` to cover the target size, then crop the centre.

    The array is resized with LANCZOS resampling by the larger of the two
    width/height ratios, so both target dimensions are covered, and the
    centred ``target_width`` x ``target_height`` window is returned as a
    NumPy array.
    """
    source = Image.fromarray(image)
    src_w, src_h = source.size

    # The larger ratio guarantees the scaled image covers the whole target.
    scale = max(target_width / src_w, target_height / src_h)
    scaled_w = int(round(src_w * scale))
    scaled_h = int(round(src_h * scale))
    scaled = source.resize((scaled_w, scaled_h), Image.LANCZOS)

    crop_box = (
        (scaled_w - target_width) / 2,
        (scaled_h - target_height) / 2,
        (scaled_w + target_width) / 2,
        (scaled_h + target_height) / 2,
    )
    return np.array(scaled.crop(crop_box))
def resize_without_crop(image, target_width, target_height):
    """Resize ``image`` to exactly the target size, ignoring aspect ratio.

    Uses LANCZOS resampling and returns the result as a NumPy array.
    """
    source = Image.fromarray(image)
    resized = source.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized)
def numpy2pytorch(imgs):
    """Stack channels-last image arrays into one channels-first float tensor.

    The arrays are stacked along a new leading axis, converted to float and
    mapped with ``x / 127.0 - 1.0`` (so the value 127 becomes exactly 0.0),
    then the last axis is moved to position 1.
    """
    stacked = np.stack(imgs, axis=0)
    tensor = torch.from_numpy(stacked).float()
    tensor = tensor / 127.0 - 1.0
    return tensor.movedim(-1, 1)
def pytorch2numpy(imgs):
    """Convert each channels-first tensor in ``imgs`` to a ``uint8`` array.

    Every item has its first axis moved to the end, is mapped with
    ``x * 127.5 + 127.5``, clipped to ``[0, 255]`` and cast to ``uint8``.
    Returns a list with one array per item.
    """
    return [
        (item.movedim(0, -1) * 127.5 + 127.5)
        .detach()
        .float()
        .cpu()
        .numpy()
        .clip(0, 255)
        .astype(np.uint8)
        for item in imgs
    ]
class ICLightEngine:
    """Owns the IC-Light models on the GPU and runs the relighting passes.

    Construction loads the tokenizer, text encoder, VAE and UNet from
    ``BASE_MODEL`` and the RMBG background remover, widens the UNet input to
    8 channels, merges the IC-Light weight offsets into the UNet, moves the
    modules to CUDA and builds a text-to-image and an image-to-image pipeline
    that share those modules.
    """

    def __init__(self):
        """Load every model and prepare the pipelines.

        Raises:
            gr.Error: If no CUDA device is available.
        """
        if not torch.cuda.is_available():
            raise gr.Error("IC-Light inference requires a CUDA GPU. On Hugging Face, enable ZeroGPU hardware.")
        self.device = torch.device("cuda")
        self.tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder")
        self.vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet")
        self.rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
        # Order matters: the UNet must have its 8-channel input before the
        # offsets are merged with strict key matching.
        self._patch_unet_input()
        self._load_iclight_weights()
        self._move_to_gpu()
        self._build_pipelines()

    def _patch_unet_input(self):
        """Give the UNet an 8-channel input and hook its forward pass.

        The new ``conv_in`` starts zeroed, receives the original weights in
        its first four input channels and reuses the original bias. The
        forward hook reads ``concat_conds`` from ``cross_attention_kwargs``,
        repeats it along the batch axis to match ``sample`` and concatenates
        it to ``sample`` on the channel axis before calling the original
        forward.
        """
        with torch.no_grad():
            new_conv_in = torch.nn.Conv2d(
                8,
                self.unet.conv_in.out_channels,
                self.unet.conv_in.kernel_size,
                self.unet.conv_in.stride,
                self.unet.conv_in.padding,
            )
            new_conv_in.weight.zero_()
            new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
            new_conv_in.bias = self.unet.conv_in.bias
            self.unet.conv_in = new_conv_in

        unet_original_forward = self.unet.forward

        def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
            c_concat = kwargs["cross_attention_kwargs"]["concat_conds"].to(sample)
            # Repeat the conditioning so its batch size matches the sample's.
            c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
            new_sample = torch.cat([sample, c_concat], dim=1)
            # The conditioning is consumed here; pass an empty dict onwards.
            kwargs["cross_attention_kwargs"] = {}
            return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)

        self.unet.forward = hooked_unet_forward

    def _load_iclight_weights(self):
        """Download the IC-Light file and add its tensors to the UNet weights.

        Each entry of the UNet state dict is summed with the same-named entry
        of the downloaded file (cast to the UNet entry's dtype), and the
        result is loaded back with ``strict=True``.
        """
        model_path = hf_hub_download(ICLIGHT_REPO, MODEL_FILE)
        sd_offset = sf.load_file(model_path, device="cpu")
        sd_origin = self.unet.state_dict()
        sd_merged = {
            key: sd_origin[key] + sd_offset[key].to(dtype=sd_origin[key].dtype)
            for key in sd_origin
        }
        self.unet.load_state_dict(sd_merged, strict=True)
        # Drop the temporary state dicts before the models move to the GPU.
        del sd_offset, sd_origin, sd_merged

    def _move_to_gpu(self):
        """Move all modules to ``self.device`` and set their dtypes."""
        self.text_encoder = self.text_encoder.to(device=self.device, dtype=torch.float16)
        self.vae = self.vae.to(device=self.device, dtype=torch.bfloat16)
        self.unet = self.unet.to(device=self.device, dtype=torch.float16)
        self.rmbg = self.rmbg.to(device=self.device, dtype=torch.float32)
        self.unet.set_attn_processor(AttnProcessor2_0())
        self.vae.set_attn_processor(AttnProcessor2_0())

    def _build_pipelines(self):
        """Create the text-to-image and image-to-image pipelines.

        Both pipelines share the same modules and one scheduler instance; the
        safety checker, feature extractor and image encoder are disabled.
        """
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            algorithm_type="sde-dpmsolver++",
            use_karras_sigmas=True,
            steps_offset=1,
        )
        pipe_kwargs = dict(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            scheduler=scheduler,
            safety_checker=None,
            requires_safety_checker=False,
            feature_extractor=None,
            image_encoder=None,
        )
        self.t2i_pipe = StableDiffusionPipeline(**pipe_kwargs)
        self.i2i_pipe = StableDiffusionImg2ImgPipeline(**pipe_kwargs)

    @torch.inference_mode()
    def encode_prompt_inner(self, txt):
        """Encode ``txt`` in tokenizer-sized chunks.

        The token ids are split into chunks of ``model_max_length - 2``,
        each wrapped in BOS/EOS and padded with EOS to ``model_max_length``.
        Returns the text encoder's ``last_hidden_state`` with one row per
        chunk. Text that yields no tokens still produces one chunk holding
        only the special tokens, so the encoder never sees an empty batch.
        """
        max_length = self.tokenizer.model_max_length
        chunk_length = self.tokenizer.model_max_length - 2
        id_start = self.tokenizer.bos_token_id
        id_end = self.tokenizer.eos_token_id
        id_pad = id_end

        def pad(x, p, i):
            return x[:i] if len(x) >= i else x + [p] * (i - len(x))

        tokens = self.tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
        # max(..., 1) keeps a single BOS/EOS-only chunk when there are no tokens.
        chunks = [
            [id_start] + tokens[i: i + chunk_length] + [id_end]
            for i in range(0, max(len(tokens), 1), chunk_length)
        ]
        chunks = [pad(chunk, id_pad, max_length) for chunk in chunks]
        token_ids = torch.tensor(chunks).to(device=self.device, dtype=torch.int64)
        return self.text_encoder(token_ids).last_hidden_state

    @torch.inference_mode()
    def encode_prompt_pair(self, positive_prompt, negative_prompt):
        """Encode both prompts so their embeddings have equal length.

        The prompt with fewer chunks is repeated until it has as many chunks
        as the other; each prompt's chunks are then concatenated along
        dimension 1. Returns ``(positive, negative)``.
        """
        c = self.encode_prompt_inner(positive_prompt)
        uc = self.encode_prompt_inner(negative_prompt)
        max_chunk = max(len(c), len(uc))
        c_repeat = math.ceil(max_chunk / len(c))
        uc_repeat = math.ceil(max_chunk / len(uc))
        c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
        uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
        c = torch.cat([p[None, ...] for p in c], dim=1)
        uc = torch.cat([p[None, ...] for p in uc], dim=1)
        return c, uc

    @torch.inference_mode()
    def run_rmbg(self, img):
        """Blend ``img`` towards grey 127 using the RMBG alpha matte.

        The image is resized for the matting model to multiples of 64 whose
        area is roughly ``256 * 64 * 64`` pixels, the predicted alpha is
        interpolated back to the input size, and the result is
        ``127 + (img - 127) * alpha`` as ``uint8``.

        Raises:
            gr.Error: If ``img`` does not have exactly three channels.
        """
        height, width, channels = img.shape
        if channels != 3:
            raise gr.Error("Input image must be RGB.")
        k = (256.0 / float(height * width)) ** 0.5
        # Clamp to one 64-pixel tile: for extreme aspect ratios the rounded
        # factor can be 0, which would request a zero-sized resize.
        feed_width = max(64, int(64 * round(width * k)))
        feed_height = max(64, int(64 * round(height * k)))
        feed = resize_without_crop(img, feed_width, feed_height)
        feed = numpy2pytorch([feed]).to(device=self.device, dtype=torch.float32)
        alpha = self.rmbg(feed)[0][0]
        alpha = torch.nn.functional.interpolate(alpha, size=(height, width), mode="bilinear")
        alpha = alpha.movedim(1, -1)[0]
        alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
        result = 127 + (img.astype(np.float32) - 127) * alpha
        return result.clip(0, 255).astype(np.uint8)

    def make_initial_background(self, bg_source, image_width, image_height):
        """Build the gradient background for the chosen lighting direction.

        ``bg_source`` may be a ``BGSource`` member or one of its string
        values. Returns ``None`` for ``BGSource.NONE``; otherwise an
        ``(image_height, image_width, 3)`` ``uint8`` array holding a linear
        255-to-0 ramp that is brightest on the named side.

        Raises:
            gr.Error: If ``bg_source`` is not a valid lighting choice.
        """
        try:
            bg_source = BGSource(bg_source)
        except ValueError as err:
            # Report an unknown choice as a UI error, not a raw ValueError.
            raise gr.Error("Invalid lighting preference.") from err

        if bg_source == BGSource.NONE:
            return None
        if bg_source == BGSource.LEFT:
            gradient = np.linspace(255, 0, image_width)
            image = np.tile(gradient, (image_height, 1))
        elif bg_source == BGSource.RIGHT:
            gradient = np.linspace(0, 255, image_width)
            image = np.tile(gradient, (image_height, 1))
        elif bg_source == BGSource.TOP:
            gradient = np.linspace(255, 0, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
        elif bg_source == BGSource.BOTTOM:
            gradient = np.linspace(0, 255, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
        else:
            raise gr.Error("Invalid lighting preference.")
        return np.stack((image,) * 3, axis=-1).astype(np.uint8)

    @torch.inference_mode()
    def _encode_images(self, images):
        """Encode a list of channels-last image arrays into scaled VAE latents."""
        pixels = numpy2pytorch(images).to(device=self.vae.device, dtype=self.vae.dtype)
        return self.vae.encode(pixels).latent_dist.mode() * self.vae.config.scaling_factor

    def _to_vae_latents(self, pipe_output):
        """Cast pipeline latents to the VAE dtype and undo the scaling factor."""
        return pipe_output.images.to(self.vae.dtype) / self.vae.config.scaling_factor

    @torch.inference_mode()
    def relight(
        self,
        input_fg,
        prompt,
        image_width,
        image_height,
        num_samples,
        seed,
        steps,
        cfg,
        highres_scale,
        highres_denoise,
        lowres_denoise,
        bg_source,
    ):
        """Relight the foreground of ``input_fg`` according to ``prompt``.

        Runs a low-resolution pass (text-to-image when ``bg_source`` is
        "None", otherwise image-to-image from a gradient background),
        upscales the decoded result by ``highres_scale`` (rounded to a
        multiple of 64) and refines it with a second image-to-image pass.

        Args:
            input_fg: Source image (PIL image or NumPy array).
            prompt: Text prompt; ``ADDED_PROMPT`` is appended to it.
            image_width: Width of the low-resolution pass.
            image_height: Height of the low-resolution pass.
            num_samples: Number of images generated per prompt.
            seed: Generator seed; ``None`` or a negative value picks a
                random seed.
            steps: Base number of inference steps.
            cfg: Guidance scale.
            highres_scale: Upscale factor applied before the second pass.
            highres_denoise: Strength of the second pass.
            lowres_denoise: Strength of the first pass when a background
                is used.
            bg_source: Lighting choice accepted by ``BGSource``.

        Returns:
            Tuple ``(foreground, images)``: the background-removed input and
            a list of ``uint8`` output arrays.
        """
        input_fg = ensure_rgb(input_fg)
        input_fg = self.run_rmbg(input_fg)
        input_bg = self.make_initial_background(bg_source, image_width, image_height)

        if seed is None or int(seed) < 0:
            seed = random.randint(0, 2**31 - 1)
        rng = torch.Generator(device=self.device).manual_seed(int(seed))

        fg = resize_and_center_crop(input_fg, image_width, image_height)
        concat_conds = self._encode_images([fg])
        conds, unconds = self.encode_prompt_pair(
            positive_prompt=f"{prompt}, {ADDED_PROMPT}",
            negative_prompt=NEGATIVE_PROMPT,
        )

        # Arguments that are identical for every pipeline call below.
        shared_kwargs = dict(
            prompt_embeds=conds,
            negative_prompt_embeds=unconds,
            num_images_per_prompt=num_samples,
            generator=rng,
            output_type="latent",
            guidance_scale=cfg,
        )

        if input_bg is None:
            output = self.t2i_pipe(
                width=image_width,
                height=image_height,
                num_inference_steps=steps,
                cross_attention_kwargs={"concat_conds": concat_conds},
                **shared_kwargs,
            )
        else:
            bg = resize_and_center_crop(input_bg, image_width, image_height)
            bg_latent = self._encode_images([bg])
            output = self.i2i_pipe(
                image=bg_latent,
                strength=lowres_denoise,
                width=image_width,
                height=image_height,
                # Requested step count is divided by the strength.
                num_inference_steps=int(round(steps / lowres_denoise)),
                cross_attention_kwargs={"concat_conds": concat_conds},
                **shared_kwargs,
            )
        latents = self._to_vae_latents(output)
        pixels = pytorch2numpy(self.vae.decode(latents).sample)

        # Upscale the decoded images in pixel space, then re-encode them.
        highres_width = int(round(image_width * highres_scale / 64.0) * 64)
        highres_height = int(round(image_height * highres_scale / 64.0) * 64)
        pixels = [
            resize_without_crop(image=p, target_width=highres_width, target_height=highres_height)
            for p in pixels
        ]
        latents = self._encode_images(pixels)
        latents = latents.to(device=self.unet.device, dtype=self.unet.dtype)

        # Take the working size from the latents and rebuild the foreground
        # conditioning at that size.
        image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
        fg = resize_and_center_crop(input_fg, image_width, image_height)
        concat_conds = self._encode_images([fg])

        output = self.i2i_pipe(
            image=latents,
            strength=highres_denoise,
            width=image_width,
            height=image_height,
            num_inference_steps=int(round(steps / highres_denoise)),
            cross_attention_kwargs={"concat_conds": concat_conds},
            **shared_kwargs,
        )
        latents = self._to_vae_latents(output)
        pixels = self.vae.decode(latents).sample
        return input_fg, pytorch2numpy(pixels)
def get_engine():
    """Return the shared ``ICLightEngine``, creating it on the first call.

    The instance is stored in the module-level ``_ENGINE`` so later calls
    reuse it without constructing the engine again.
    """
    global _ENGINE
    if _ENGINE is not None:
        return _ENGINE
    _ENGINE = ICLightEngine()
    return _ENGINE
@spaces.GPU(duration=180)
def generate(
    image,
    prompt,
    lighting,
    width,
    height,
    samples,
    seed,
    steps,
    cfg,
    highres_scale,
    highres_denoise,
    lowres_denoise,
):
    """Validate the UI inputs and run one relighting request.

    Values are coerced to the numeric types ``ICLightEngine.relight``
    expects and the prompt is stripped of surrounding whitespace.

    Returns:
        Whatever ``ICLightEngine.relight`` returns: the preprocessed
        foreground and the list of generated images.

    Raises:
        gr.Error: If the prompt is empty or no image was supplied.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Enter a prompt.")
    # Fail before get_engine(), which builds the engine on first use.
    if image is None:
        raise gr.Error("Upload an image first.")

    # A missing seed would make int(seed) raise TypeError; relight treats a
    # negative seed as "pick a random one".
    # NOTE(review): presumably an emptied gr.Number field arrives as None - confirm.
    seed = -1 if seed is None else int(seed)

    engine = get_engine()
    return engine.relight(
        image,
        prompt.strip(),
        int(width),
        int(height),
        int(samples),
        seed,
        int(steps),
        float(cfg),
        float(highres_scale),
        float(highres_denoise),
        float(lowres_denoise),
        lighting,
    )
# Example prompts for the "Prompt Quick List" dataset. Each sample is a
# one-element row because the dataset is bound to a single component.
quick_prompts = [
    [example]
    for example in (
        "beautiful woman, detailed face, sunshine from window",
        "handsome man, detailed face, neon light, city",
        "portrait, cinematic lighting",
        "product photo, soft studio lighting",
        "character art, dramatic light and shadow",
    )
]
# Gradio layout: controls in the left column, results in the right column.
with gr.Blocks(title="IC-Light Relighting") as demo:
    gr.Markdown("## IC-Light Relighting")
    with gr.Row():
        with gr.Column():
            # --- Main inputs ---
            input_image = gr.Image(sources=["upload"], type="numpy", label="Image", height=440)
            prompt = gr.Textbox(label="Prompt", value="portrait, cinematic lighting")
            lighting = gr.Radio(
                choices=[source.value for source in BGSource],
                value=BGSource.NONE.value,
                label="Lighting Preference",
            )

            # Clicking an example copies its text into the prompt box.
            prompt_examples = gr.Dataset(
                samples=quick_prompts,
                label="Prompt Quick List",
                components=[prompt],
                samples_per_page=20,
            )
            prompt_examples.click(
                lambda x: x[0],
                inputs=prompt_examples,
                outputs=prompt,
                show_progress=False,
                queue=False,
            )

            run_button = gr.Button("Relight", variant="primary")

            # --- Sampling controls ---
            with gr.Row():
                samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
                seed = gr.Number(label="Seed", value=12345, precision=0)
            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                height = gr.Slider(label="Height", minimum=256, maximum=1024, value=640, step=64)

            with gr.Accordion("Advanced", open=False):
                steps = gr.Slider(label="Steps", minimum=1, maximum=80, value=25, step=1)
                cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=16.0, value=2.0, step=0.1)
                lowres_denoise = gr.Slider(
                    label="Lowres Denoise",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.01,
                )
                highres_scale = gr.Slider(
                    label="Highres Scale",
                    minimum=1.0,
                    maximum=2.0,
                    value=1.5,
                    step=0.05,
                )
                highres_denoise = gr.Slider(
                    label="Highres Denoise",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                )

        with gr.Column():
            # --- Outputs ---
            foreground = gr.Image(type="numpy", label="Preprocessed Foreground", height=360)
            gallery = gr.Gallery(label="Outputs", height=720, object_fit="contain")

    # Order must match the positional parameters of generate().
    inputs = [
        input_image,
        prompt,
        lighting,
        width,
        height,
        samples,
        seed,
        steps,
        cfg,
        highres_scale,
        highres_denoise,
        lowres_denoise,
    ]
    run_button.click(fn=generate, inputs=inputs, outputs=[foreground, gallery])
if __name__ == "__main__":
    # Enable the request queue (at most 20 entries), then start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(server_name="0.0.0.0")