Spaces:
Build error
Build error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -34,13 +34,18 @@ import gradio as gr
|
|
| 34 |
|
| 35 |
# Model identifiers used by the app.
SD_INPAINT_ID = "runwayml/stable-diffusion-inpainting"
# NOTE(review): the trailing note appears to suggest swapping in
# hr16/ControlNet-HandRefiner-pruned for best results — confirm with the author.
CONTROLNET_ID = "lllyasviel/control_v11f1p_sd15_depth" # -> hr16/ControlNet-HandRefiner-pruned for best
# NOTE(review): named for MeshGraphormer but the value is a ControlNet repo id — confirm intent.
MESHGRAPHORMER_ID = "hr16/ControlNet-HandRefiner-pruned"
# Longest image side allowed by `_fit` before downscaling.
MAX_SIDE = 768
DEFAULT_PROMPT = "a detailed, anatomically correct hand with five fingers, natural proportions, same art style and lighting"
NEG = "extra fingers, fused fingers, missing fingers, deformed, mutated, blurry, low quality"

# Module-level caches, initially empty.
# NOTE(review): presumably populated by `_load()`, whose body is not visible here — confirm.
_PIPE = None
_MESH = None
|
|
|
|
| 44 |
|
| 45 |
def _load():
|
| 46 |
"""Load on CPU at import time. Models are moved to GPU inside the @spaces.GPU call,
|
|
@@ -74,6 +79,32 @@ try:
|
|
| 74 |
except Exception as _e:
|
| 75 |
print("[load] preload deferred:", _e, flush=True)
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
def _fit(img):
|
| 78 |
w, h = img.size
|
| 79 |
s = min(1.0, MAX_SIDE / max(w, h))
|
|
@@ -126,19 +157,66 @@ def fix_hands(image, mask_layers, prompt, strength):
|
|
| 126 |
print("[fix] ERROR:\n" + traceback.format_exc(), flush=True)
|
| 127 |
raise gr.Error(f"Fix failed: {e}")
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
# Script entry point: enable the request queue, then start the Gradio server.
if __name__ == "__main__":
    demo.queue().launch()
|
|
|
|
| 34 |
|
| 35 |
# Model identifiers used by the app. TILE_CN_ID and SD_BASE_ID are passed to
# `from_pretrained` in `_load_detail`.
SD_INPAINT_ID = "runwayml/stable-diffusion-inpainting"
# NOTE(review): the trailing note appears to suggest swapping in
# hr16/ControlNet-HandRefiner-pruned for best results — confirm with the author.
CONTROLNET_ID = "lllyasviel/control_v11f1p_sd15_depth" # -> hr16/ControlNet-HandRefiner-pruned for best
TILE_CN_ID = "lllyasviel/control_v11f1e_sd15_tile" # detail-regeneration ControlNet
SD_BASE_ID = "runwayml/stable-diffusion-v1-5" # base SD for img2img detail pass
# NOTE(review): named for MeshGraphormer but the value is a ControlNet repo id — confirm intent.
MESHGRAPHORMER_ID = "hr16/ControlNet-HandRefiner-pruned"
# Longest image side allowed by `_fit` before downscaling.
MAX_SIDE = 768
# Longest image side allowed in `detail_pass` (passed to `_fit_to`).
DETAIL_MAX_SIDE = 1280 # detail pass can work larger since it's tiled-friendly
DEFAULT_PROMPT = "a detailed, anatomically correct hand with five fingers, natural proportions, same art style and lighting"
NEG = "extra fingers, fused fingers, missing fingers, deformed, mutated, blurry, low quality"
# Negative prompt supplied to the detail pipeline in `detail_pass`.
DETAIL_NEG = "blurry, soft, out of focus, jpeg artifacts, low quality, smudged, messy lines"

# Module-level caches, initially empty.
# NOTE(review): _PIPE/_MESH are presumably populated by `_load()`, whose body is
# not visible here — confirm. _DETAIL is populated by `_load_detail()`.
_PIPE = None
_MESH = None
_DETAIL = None
|
| 49 |
|
| 50 |
def _load():
|
| 51 |
"""Load on CPU at import time. Models are moved to GPU inside the @spaces.GPU call,
|
|
|
|
| 79 |
except Exception as _e:
|
| 80 |
print("[load] preload deferred:", _e, flush=True)
|
| 81 |
|
| 82 |
+
def _load_detail():
    """Build the Tile-ControlNet img2img pipeline once and cache it in ``_DETAIL``.

    Returns immediately if ``_DETAIL`` is already set. Otherwise it loads the
    tile ControlNet and the base Stable Diffusion weights in float16, swaps in
    a UniPC multistep scheduler, and tries to enable attention slicing and VAE
    tiling. A failure of either optional step is logged and ignored. No device
    move happens here; the caller moves the pipeline to the GPU.
    """
    global _DETAIL
    if _DETAIL is not None:
        return

    import time
    from diffusers import (
        StableDiffusionControlNetImg2ImgPipeline,
        ControlNetModel,
        UniPCMultistepScheduler,
    )

    started = time.time()
    print("[load] detail pipeline (tile CN) on CPU…", flush=True)

    controlnet = ControlNetModel.from_pretrained(TILE_CN_ID, torch_dtype=torch.float16)
    detail_pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        SD_BASE_ID,
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    detail_pipe.scheduler = UniPCMultistepScheduler.from_config(detail_pipe.scheduler.config)

    # Best-effort memory savers: log and carry on if this pipeline lacks one.
    optional_steps = (
        ("enable_attention_slicing", "[load] attn-slicing skip:"),
        ("enable_vae_tiling", "[load] vae-tiling skip:"),
    )
    for method_name, skip_msg in optional_steps:
        try:
            getattr(detail_pipe, method_name)()
        except Exception as exc:
            print(skip_msg, exc, flush=True)

    _DETAIL = detail_pipe
    elapsed = time.time() - started
    print(f"[load] detail pipeline ready ({elapsed:.0f}s)", flush=True)
|
| 102 |
+
|
| 103 |
+
def _fit_to(img, max_side):
    """Downscale ``img`` so its longest side is at most ``max_side``.

    The image is never enlarged (the scale factor is capped at 1.0). Each
    output dimension is snapped to the nearest multiple of 8, with a floor
    of 8, and resampled with Lanczos.

    Returns:
        A ``(resized_image, (orig_w, orig_h))`` tuple, where the second item
        is the size of ``img`` before resizing.
    """
    orig_w, orig_h = img.size
    factor = min(1.0, max_side / max(orig_w, orig_h))

    def _snap(length):
        # Nearest multiple of 8 to the scaled length, never below 8.
        return max(8, int(round(length * factor / 8)) * 8)

    resized = img.resize((_snap(orig_w), _snap(orig_h)), Image.LANCZOS)
    return resized, (orig_w, orig_h)
|
| 107 |
+
|
| 108 |
def _fit(img):
|
| 109 |
w, h = img.size
|
| 110 |
s = min(1.0, MAX_SIDE / max(w, h))
|
|
|
|
| 157 |
print("[fix] ERROR:\n" + traceback.format_exc(), flush=True)
|
| 158 |
raise gr.Error(f"Fix failed: {e}")
|
| 159 |
|
| 160 |
+
@spaces.GPU(duration=120)
def detail_pass(image, strength, scale):
    """Recover detail/lineart with Tile-ControlNet img2img at low denoise.

    The positive prompt is left empty and only ``DETAIL_NEG`` is supplied as
    the negative prompt. The intent is to sharpen the image without redrawing
    the subject. The input image is used as both the img2img source and the
    ControlNet control image.

    Args:
        image: A PIL image, or a dict whose ``"background"`` entry holds one.
        strength: Denoise strength forwarded to the pipeline (cast to float).
        scale: Pre-enlargement factor; values above 1.01 trigger a Lanczos
            upscale before the image is fitted to ``DETAIL_MAX_SIDE``.

    Returns:
        The processed PIL image, resized back to the (possibly enlarged)
        source size if the working size differed.

    Raises:
        gr.Error: If no image was supplied, or if any step of the pass fails
            (the original exception is chained as the cause).
    """
    import time
    import traceback

    # Validate up front so a missing image gives a clear message rather than
    # an AttributeError/KeyError wrapped by the generic handler below.
    src = image.get("background") if isinstance(image, dict) else image
    if src is None:
        raise gr.Error("Upload an image first.")
    try:
        t0 = time.time()
        _load_detail()
        _DETAIL.to("cuda")
        src = src.convert("RGB")
        # optionally enlarge first (Lanczos) — the model then fills in real detail at the higher res
        scale = float(scale)
        if scale > 1.01:
            src = src.resize((int(src.width * scale), int(src.height * scale)), Image.LANCZOS)
        # (ow, oh) is the size after the optional enlargement, i.e. the final output size.
        work, (ow, oh) = _fit_to(src, DETAIL_MAX_SIDE)
        print(f"[detail] working at {work.size}, denoise={strength}", flush=True)
        # tile controlnet uses the image itself as the control signal
        out = _DETAIL(
            prompt="", negative_prompt=DETAIL_NEG,
            image=work, control_image=work,
            num_inference_steps=30, strength=float(strength),
            guidance_scale=6.0, controlnet_conditioning_scale=1.0,
        ).images[0]
        if out.size != (ow, oh):
            out = out.resize((ow, oh), Image.LANCZOS)
        print(f"[detail] done, total {time.time()-t0:.0f}s", flush=True)
        return out
    except Exception as e:
        print("[detail] ERROR:\n" + traceback.format_exc(), flush=True)
        # Chain the cause so the original traceback stays attached.
        raise gr.Error(f"Detail pass failed: {e}") from e
|
| 194 |
+
|
| 195 |
+
# Gradio UI: one Blocks app with two tabs, each wiring a button to a handler.
with gr.Blocks(title="DARKROOM", theme=gr.themes.Base()) as demo:
    gr.Markdown("## 🎨 DARKROOM\nAI-art repair on GPU. **Fix hands** regenerates malformed hands "
                "with correct geometry. **Add detail** uses Tile-ControlNet img2img to recover real "
                "sharpness and clean lineart while keeping your original style.")
    # Tab 1: hand repair, handled by `fix_hands`.
    with gr.Tab("Fix hands"):
        with gr.Row():
            with gr.Column():
                inp = gr.ImageMask(type="pil", label="Image (optionally paint over the bad hand)")
                prompt = gr.Textbox(value=DEFAULT_PROMPT, label="Prompt", lines=2)
                strength = gr.Slider(0.3, 1.0, value=0.75, step=0.05, label="Fix strength (denoise)")
                btn = gr.Button("Fix hands", variant="primary")
            with gr.Column():
                out = gr.Image(type="pil", label="Result")
        # `inp` is passed twice: fix_hands takes (image, mask_layers, prompt, strength),
        # and the same ImageMask component supplies both of the first two arguments.
        btn.click(fix_hands, inputs=[inp, inp, prompt, strength], outputs=out, api_name="fix_hands")
    # Tab 2: detail recovery, handled by `detail_pass`.
    with gr.Tab("Add detail"):
        with gr.Row():
            with gr.Column():
                dinp = gr.Image(type="pil", label="Image to sharpen / add detail")
                dstrength = gr.Slider(0.15, 0.6, value=0.3, step=0.05,
                                      label="Detail strength (low = safe & on-style, high = more new detail / more drift)")
                # The slider offers 1.0, 1.5 or 2.0; detail_pass only enlarges when the value exceeds 1.01.
                dscale = gr.Slider(1.0, 2.0, value=1.0, step=0.5, label="Enlarge first (×)")
                dbtn = gr.Button("Add detail", variant="primary")
            with gr.Column():
                dout = gr.Image(type="pil", label="Result")
        dbtn.click(detail_pass, inputs=[dinp, dstrength, dscale], outputs=dout, api_name="detail_pass")

# Script entry point: enable the request queue, then start the Gradio server.
if __name__ == "__main__":
    demo.queue().launch()
|