cavargas10 commited on
Commit
b05b1ea
·
verified ·
1 Parent(s): 70975ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -150
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
4
-
5
  import os
6
  import shutil
7
  os.environ['SPCONV_ALGO'] = 'native'
@@ -14,26 +13,11 @@ from PIL import Image, ImageOps
14
  from trellis.pipelines import TrellisImageTo3DPipeline
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
-
18
- import os
19
- import random
20
  import torch
21
  import torchvision.transforms.functional as TF
22
-
23
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
24
  from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
25
- from controlnet_aux import PidiNetDetector, HEDdetector
26
- from diffusers.utils import load_image
27
- from huggingface_hub import HfApi
28
  from pathlib import Path
29
- from PIL import Image, ImageOps
30
- import torch
31
- import numpy as np
32
- import cv2
33
- import os
34
- import random
35
- from gradio_imageslider import ImageSlider
36
-
37
  style_list = [
38
  {
39
  "name": "(No style)",
@@ -59,7 +43,6 @@ style_list = [
59
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
60
  STYLE_NAMES = list(styles.keys())
61
  DEFAULT_STYLE_NAME = "(No style)"
62
-
63
  MAX_SEED = np.iinfo(np.int32).max
64
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
65
  os.makedirs(TMP_DIR, exist_ok=True)
@@ -74,7 +57,7 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str
74
  def start_session(req: gr.Request):
75
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
76
  os.makedirs(user_dir, exist_ok=True)
77
-
78
  def end_session(req: gr.Request):
79
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
80
  shutil.rmtree(user_dir)
@@ -93,11 +76,8 @@ def preprocess_image(image: Image.Image,
93
  new_width, new_height = int(width * ratio), int(height * ratio)
94
  image = image['composite'].resize((new_width, new_height))
95
  image = ImageOps.invert(image)
96
-
97
  print("image:",type(image))
98
-
99
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
100
-
101
  print("params:", prompt, negative_prompt, style_name, num_steps, guidance_scale, controlnet_conditioning_scale)
102
  output = pipe_control(
103
  prompt=prompt,
@@ -108,15 +88,9 @@ def preprocess_image(image: Image.Image,
108
  guidance_scale=guidance_scale,
109
  width=new_width,
110
  height=new_height).images[0]
111
-
112
  processed_image = pipeline.preprocess_image(output)
113
  return (image, processed_image)
114
 
115
- def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
116
- images = [image[0] for image in images]
117
- processed_images = [pipeline.preprocess_image(image) for image in images]
118
- return processed_images
119
-
120
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
121
  return {
122
  'gaussian': {
@@ -132,7 +106,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
132
  'faces': mesh.faces.cpu().numpy(),
133
  },
134
  }
135
-
136
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
137
  gs = Gaussian(
138
  aabb=state['gaussian']['aabb'],
@@ -147,12 +121,10 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
147
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
148
  gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
149
  gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
150
-
151
  mesh = edict(
152
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
153
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
154
  )
155
-
156
  return gs, mesh
157
 
158
  def get_seed(randomize_seed: bool, seed: int) -> int:
@@ -161,48 +133,28 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
161
  @spaces.GPU
162
  def image_to_3d(
163
  image: Image.Image,
164
- multiimages: List[Tuple[Image.Image, str]],
165
- is_multiimage: bool,
166
  seed: int,
167
  ss_guidance_strength: float,
168
  ss_sampling_steps: int,
169
  slat_guidance_strength: float,
170
  slat_sampling_steps: int,
171
- multiimage_algo: Literal["multidiffusion", "stochastic"],
172
  req: gr.Request,
173
  ) -> Tuple[dict, str]:
174
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
175
- if not is_multiimage:
176
- outputs = pipeline.run(
177
- image[1],
178
- seed=seed,
179
- formats=["gaussian", "mesh"],
180
- preprocess_image=False,
181
- sparse_structure_sampler_params={
182
- "steps": ss_sampling_steps,
183
- "cfg_strength": ss_guidance_strength,
184
- },
185
- slat_sampler_params={
186
- "steps": slat_sampling_steps,
187
- "cfg_strength": slat_guidance_strength,
188
- },
189
- )
190
- else:
191
- outputs = pipeline.run_multi_image(
192
- [image[0] for image in multiimages],
193
- seed=seed,
194
- formats=["gaussian", "mesh"],
195
- preprocess_image=False,
196
- sparse_structure_sampler_params={
197
- "steps": ss_sampling_steps,
198
- "cfg_strength": ss_guidance_strength,
199
- },
200
- slat_sampler_params={
201
- "steps": slat_sampling_steps,
202
- "cfg_strength": slat_guidance_strength,
203
- },
204
- mode=multiimage_algo,
205
- )
206
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
207
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
208
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -227,10 +179,6 @@ def extract_glb(
227
  torch.cuda.empty_cache()
228
  return glb_path, glb_path
229
 
230
- def reset_do_preprocess():
231
- return True
232
-
233
- @spaces.GPU
234
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
235
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
236
  gs, _ = unpack_state(state)
@@ -239,30 +187,6 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
239
  torch.cuda.empty_cache()
240
  return gaussian_path, gaussian_path
241
 
242
- def prepare_multi_example() -> List[Image.Image]:
243
- multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
244
- images = []
245
- for case in multi_case:
246
- _images = []
247
- for i in range(1, 4):
248
- img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
249
- W, H = img.size
250
- img = img.resize((int(W / H * 512), 512))
251
- _images.append(np.array(img))
252
- images.append(Image.fromarray(np.concatenate(_images, axis=1)))
253
- return images
254
-
255
- def split_image(image: Image.Image) -> List[Image.Image]:
256
- image = np.array(image)
257
- alpha = image[..., 3]
258
- alpha = np.any(alpha>0, axis=0)
259
- start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
260
- end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
261
- images = []
262
- for s, e in zip(start_pos, end_pos):
263
- images.append(Image.fromarray(image[:, s:e+1]))
264
- return [preprocess_image(image) for image in images]
265
-
266
  with gr.Blocks(delete_cache=(600, 600)) as demo:
267
  gr.Markdown("""
268
  # UTPL - Conversión de Boceto a objetos 3D usando IA
@@ -271,7 +195,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
271
  **Base técnica:** Adaptación de TRELLIS (herramienta de código abierto para generación 3D)
272
  **Propósito educativo:** Demostraciones académicas e Investigación en modelado 3D automático
273
  """)
274
-
275
  with gr.Row():
276
  with gr.Column():
277
  with gr.Column():
@@ -282,11 +205,9 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
282
  with gr.Row():
283
  prompt = gr.Textbox(label="Prompt")
284
  style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
285
-
286
  with gr.Accordion(label="Generation Settings", open=False):
287
  with gr.Tab(label="sketch-to-image generation"):
288
  negative_prompt = gr.Textbox(label="Negative prompt")
289
-
290
  num_steps = gr.Slider(
291
  label="Number of steps",
292
  minimum=1,
@@ -319,71 +240,31 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
319
  with gr.Row():
320
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
321
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
322
- multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
323
-
324
- with gr.Tab(label="Multiple Images", id=1, visible=False) as multiimage_input_tab:
325
- multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
326
- gr.Markdown("""
327
- Input different views of the object in separate images.
328
-
329
- *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
330
- """)
331
-
332
  with gr.Accordion(label="GLB Extraction Settings", open=False):
333
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
334
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
335
-
336
  with gr.Row():
337
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
338
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
339
-
340
  with gr.Column():
341
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
342
- image_prompt_processed = ImageSlider(label="processed sketch", interactive=False, type="pil", height=512)
343
  model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
344
-
345
  with gr.Row():
346
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
347
  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
348
-
349
- is_multiimage = gr.State(False)
350
  do_preprocess = gr.State(True)
351
  output_buf = gr.State()
352
 
353
- with gr.Row(visible=False) as single_image_example:
354
- examples = gr.Examples(
355
- examples=[
356
- f'assets/example_image/{image}'
357
- for image in os.listdir("assets/example_image")
358
- ],
359
- inputs=[image_prompt],
360
- fn=preprocess_image,
361
- outputs=[image_prompt_processed],
362
- run_on_click=True,
363
- examples_per_page=64,
364
- )
365
- with gr.Row(visible=False) as multiimage_example:
366
- examples_multi = gr.Examples(
367
- examples=prepare_multi_example(),
368
- inputs=[image_prompt],
369
- fn=split_image,
370
- outputs=[multiimage_prompt],
371
- run_on_click=True,
372
- examples_per_page=8,
373
- )
374
-
375
  demo.load(start_session)
376
  demo.unload(end_session)
377
-
378
- multiimage_input_tab.select(
379
- lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
380
- outputs=[is_multiimage, single_image_example, multiimage_example]
381
- )
382
  image_prompt.clear(
383
  fn=reset_canvas,
384
  outputs = [image_prompt]
385
  )
386
-
387
  sketch_btn.click(
388
  get_seed,
389
  inputs=[randomize_seed, seed],
@@ -393,11 +274,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
393
  inputs=[image_prompt, prompt, negative_prompt, style, num_steps, guidance_scale, controlnet_conditioning_scale],
394
  outputs=[image_prompt_processed],
395
  )
396
- multiimage_prompt.upload(
397
- preprocess_images,
398
- inputs=[multiimage_prompt],
399
- outputs=[multiimage_prompt],
400
- )
401
 
402
  generate_btn.click(
403
  get_seed,
@@ -405,7 +281,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
405
  outputs=[seed],
406
  ).then(
407
  image_to_3d,
408
- inputs=[image_prompt_processed, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
409
  outputs=[output_buf, video_output],
410
  ).then(
411
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
@@ -425,7 +301,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
425
  lambda: gr.Button(interactive=True),
426
  outputs=[download_glb],
427
  )
428
-
429
  extract_gs_btn.click(
430
  extract_gaussian,
431
  inputs=[output_buf],
@@ -443,16 +319,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
443
  if __name__ == "__main__":
444
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
445
  pipeline.cuda()
446
-
447
  device = "cuda" if torch.cuda.is_available() else "cpu"
448
-
449
  #scribble controlnet
450
  controlnet = ControlNetModel.from_pretrained(
451
  "xinsir/controlnet-scribble-sdxl-1.0",
452
  torch_dtype=torch.float16
453
  )
454
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
455
-
456
  pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
457
  "sd-community/sdxl-flash",
458
  controlnet=controlnet,
@@ -461,7 +334,6 @@ if __name__ == "__main__":
461
  )
462
  pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
463
  pipe_control.to(device)
464
-
465
  try:
466
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
467
  except:
 
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
 
4
  import os
5
  import shutil
6
  os.environ['SPCONV_ALGO'] = 'native'
 
13
  from trellis.pipelines import TrellisImageTo3DPipeline
14
  from trellis.representations import Gaussian, MeshExtractResult
15
  from trellis.utils import render_utils, postprocessing_utils
 
 
 
16
  import torch
17
  import torchvision.transforms.functional as TF
 
18
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
19
  from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
 
 
 
20
  from pathlib import Path
 
 
 
 
 
 
 
 
21
  style_list = [
22
  {
23
  "name": "(No style)",
 
43
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
44
  STYLE_NAMES = list(styles.keys())
45
  DEFAULT_STYLE_NAME = "(No style)"
 
46
  MAX_SEED = np.iinfo(np.int32).max
47
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
48
  os.makedirs(TMP_DIR, exist_ok=True)
 
57
  def start_session(req: gr.Request):
58
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
59
  os.makedirs(user_dir, exist_ok=True)
60
+
61
  def end_session(req: gr.Request):
62
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
63
  shutil.rmtree(user_dir)
 
76
  new_width, new_height = int(width * ratio), int(height * ratio)
77
  image = image['composite'].resize((new_width, new_height))
78
  image = ImageOps.invert(image)
 
79
  print("image:",type(image))
 
80
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
 
81
  print("params:", prompt, negative_prompt, style_name, num_steps, guidance_scale, controlnet_conditioning_scale)
82
  output = pipe_control(
83
  prompt=prompt,
 
88
  guidance_scale=guidance_scale,
89
  width=new_width,
90
  height=new_height).images[0]
 
91
  processed_image = pipeline.preprocess_image(output)
92
  return (image, processed_image)
93
 
 
 
 
 
 
94
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
95
  return {
96
  'gaussian': {
 
106
  'faces': mesh.faces.cpu().numpy(),
107
  },
108
  }
109
+
110
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
111
  gs = Gaussian(
112
  aabb=state['gaussian']['aabb'],
 
121
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
122
  gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
123
  gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
 
124
  mesh = edict(
125
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
126
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
127
  )
 
128
  return gs, mesh
129
 
130
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
133
  @spaces.GPU
134
  def image_to_3d(
135
  image: Image.Image,
 
 
136
  seed: int,
137
  ss_guidance_strength: float,
138
  ss_sampling_steps: int,
139
  slat_guidance_strength: float,
140
  slat_sampling_steps: int,
 
141
  req: gr.Request,
142
  ) -> Tuple[dict, str]:
143
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
144
+ outputs = pipeline.run(
145
+ image[1],
146
+ seed=seed,
147
+ formats=["gaussian", "mesh"],
148
+ preprocess_image=False,
149
+ sparse_structure_sampler_params={
150
+ "steps": ss_sampling_steps,
151
+ "cfg_strength": ss_guidance_strength,
152
+ },
153
+ slat_sampler_params={
154
+ "steps": slat_sampling_steps,
155
+ "cfg_strength": slat_guidance_strength,
156
+ },
157
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
159
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
160
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
179
  torch.cuda.empty_cache()
180
  return glb_path, glb_path
181
 
 
 
 
 
182
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
183
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
184
  gs, _ = unpack_state(state)
 
187
  torch.cuda.empty_cache()
188
  return gaussian_path, gaussian_path
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  with gr.Blocks(delete_cache=(600, 600)) as demo:
191
  gr.Markdown("""
192
  # UTPL - Conversión de Boceto a objetos 3D usando IA
 
195
  **Base técnica:** Adaptación de TRELLIS (herramienta de código abierto para generación 3D)
196
  **Propósito educativo:** Demostraciones académicas e Investigación en modelado 3D automático
197
  """)
 
198
  with gr.Row():
199
  with gr.Column():
200
  with gr.Column():
 
205
  with gr.Row():
206
  prompt = gr.Textbox(label="Prompt")
207
  style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
 
208
  with gr.Accordion(label="Generation Settings", open=False):
209
  with gr.Tab(label="sketch-to-image generation"):
210
  negative_prompt = gr.Textbox(label="Negative prompt")
 
211
  num_steps = gr.Slider(
212
  label="Number of steps",
213
  minimum=1,
 
240
  with gr.Row():
241
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
242
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
 
 
 
 
 
 
 
 
 
243
  with gr.Accordion(label="GLB Extraction Settings", open=False):
244
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
245
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
 
246
  with gr.Row():
247
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
248
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
 
249
  with gr.Column():
250
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
251
+ image_prompt_processed = gr.Image(label="processed sketch", interactive=False, type="pil", height=512)
252
  model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
 
253
  with gr.Row():
254
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
255
  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
256
+
 
257
  do_preprocess = gr.State(True)
258
  output_buf = gr.State()
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  demo.load(start_session)
261
  demo.unload(end_session)
262
+
 
 
 
 
263
  image_prompt.clear(
264
  fn=reset_canvas,
265
  outputs = [image_prompt]
266
  )
267
+
268
  sketch_btn.click(
269
  get_seed,
270
  inputs=[randomize_seed, seed],
 
274
  inputs=[image_prompt, prompt, negative_prompt, style, num_steps, guidance_scale, controlnet_conditioning_scale],
275
  outputs=[image_prompt_processed],
276
  )
 
 
 
 
 
277
 
278
  generate_btn.click(
279
  get_seed,
 
281
  outputs=[seed],
282
  ).then(
283
  image_to_3d,
284
+ inputs=[image_prompt_processed, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
285
  outputs=[output_buf, video_output],
286
  ).then(
287
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
 
301
  lambda: gr.Button(interactive=True),
302
  outputs=[download_glb],
303
  )
304
+
305
  extract_gs_btn.click(
306
  extract_gaussian,
307
  inputs=[output_buf],
 
319
  if __name__ == "__main__":
320
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
321
  pipeline.cuda()
 
322
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
323
  #scribble controlnet
324
  controlnet = ControlNetModel.from_pretrained(
325
  "xinsir/controlnet-scribble-sdxl-1.0",
326
  torch_dtype=torch.float16
327
  )
328
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
 
329
  pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
330
  "sd-community/sdxl-flash",
331
  controlnet=controlnet,
 
334
  )
335
  pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
336
  pipe_control.to(device)
 
337
  try:
338
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
339
  except: