cavargas10 commited on
Commit
7ea9149
Β·
verified Β·
1 Parent(s): 6110d4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -71
app.py CHANGED
@@ -3,16 +3,20 @@ import spaces
3
  from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
6
- import numpy as np
 
7
  import torch
 
 
 
8
  from PIL import Image, ImageOps
9
  from trellis.pipelines import TrellisImageTo3DPipeline
10
- from trellis.representations import Gaussian, MeshExtractResult
11
  from trellis.utils import render_utils, postprocessing_utils
12
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
13
  from diffusers import EulerAncestralDiscreteScheduler
14
-
15
- os.environ['SPCONV_ALGO'] = 'native'
16
 
17
  style_list = [
18
  {
@@ -20,9 +24,52 @@ style_list = [
20
  "prompt": "{prompt}",
21
  "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
22
  },
23
- # ... (resto de los estilos sin cambios)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  ]
25
-
26
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
27
  STYLE_NAMES = list(styles.keys())
28
  DEFAULT_STYLE_NAME = "(No style)"
@@ -31,9 +78,7 @@ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
31
  os.makedirs(TMP_DIR, exist_ok=True)
32
 
33
  def reset_canvas():
34
- return gr.update(value={"background": Image.new("RGB", (512, 512), (255, 255, 255)),
35
- "layers": [Image.new("RGB", (512, 512), (255, 255, 255))],
36
- "composite": Image.new("RGB", (512, 512), (255, 255, 255))})
37
 
38
  def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
39
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
@@ -74,41 +119,19 @@ def preprocess_image(image: Image.Image,
74
  processed_image = pipeline.preprocess_image(output)
75
  return (image, processed_image)
76
 
77
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
78
  return {
79
- 'gaussian': {
80
- **gs.init_params,
81
- '_xyz': gs._xyz.cpu().numpy(),
82
- '_features_dc': gs._features_dc.cpu().numpy(),
83
- '_scaling': gs._scaling.cpu().numpy(),
84
- '_rotation': gs._rotation.cpu().numpy(),
85
- '_opacity': gs._opacity.cpu().numpy(),
86
- },
87
  'mesh': {
88
  'vertices': mesh.vertices.cpu().numpy(),
89
  'faces': mesh.faces.cpu().numpy(),
90
  },
91
  }
92
 
93
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
94
- gs = Gaussian(
95
- aabb=state['gaussian']['aabb'],
96
- sh_degree=state['gaussian']['sh_degree'],
97
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
98
- scaling_bias=state['gaussian']['scaling_bias'],
99
- opacity_bias=state['gaussian']['opacity_bias'],
100
- scaling_activation=state['gaussian']['scaling_activation'],
101
- )
102
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
103
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
104
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
105
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
106
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
107
- mesh = edict(
108
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
109
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
110
  )
111
- return gs, mesh
112
 
113
  def get_seed(randomize_seed: bool, seed: int) -> int:
114
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
@@ -127,7 +150,7 @@ def image_to_3d(
127
  outputs = pipeline.run(
128
  image[1],
129
  seed=seed,
130
- formats=["gaussian", "mesh"],
131
  preprocess_image=False,
132
  sparse_structure_sampler_params={
133
  "steps": ss_sampling_steps,
@@ -138,12 +161,10 @@ def image_to_3d(
138
  "cfg_strength": slat_guidance_strength,
139
  },
140
  )
141
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
142
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
143
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
144
  video_path = os.path.join(user_dir, 'sample.mp4')
145
  imageio.mimsave(video_path, video, fps=15)
146
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
147
  torch.cuda.empty_cache()
148
  return state, video_path
149
 
@@ -153,75 +174,68 @@ def extract_glb(
153
  mesh_simplify: float,
154
  texture_size: int,
155
  req: gr.Request,
156
- ) -> Tuple[str, str]:
157
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
158
- gs, mesh = unpack_state(state)
159
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
160
  glb_path = os.path.join(user_dir, 'sample.glb')
161
  glb.export(glb_path)
162
  torch.cuda.empty_cache()
163
- return glb_path, glb_path
 
 
 
164
 
165
  with gr.Blocks(delete_cache=(600, 600)) as demo:
166
  gr.Markdown("""
167
  ## Sketch to 3D with TRELLIS
168
  1. Fast sketch to image with SDXL Flash
169
  2. Scalable and versatile image to 3D generation using [TRELLIS](https://trellis3d.github.io/)
170
- ### 🎨 Draw or upload a sketch and click "Generate" to create a 3D asset πŸ’Ž
171
  """)
172
-
173
  with gr.Row():
174
  with gr.Column():
175
- image_prompt = gr.ImageMask(label="Input sketch", type="pil", image_mode="RGB", height=512,
176
- value={"background": Image.new("RGB", (512, 512), (255, 255, 255)),
177
- "layers": [Image.new("RGB", (512, 512), (255, 255, 255))],
178
- "composite": Image.new("RGB", (512, 512), (255, 255, 255))})
179
-
180
  with gr.Row():
181
  sketch_btn = gr.Button("Process sketch")
182
  generate_btn = gr.Button("Generate 3D")
183
-
184
  with gr.Row():
185
  prompt = gr.Textbox(label="Prompt")
186
  style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
187
-
188
  with gr.Accordion(label="Generation Settings", open=False):
189
  with gr.Tab(label="sketch-to-image generation"):
190
  negative_prompt = gr.Textbox(label="Negative prompt")
191
  num_steps = gr.Slider(1, 20, label="Number of steps", value=8, step=1)
192
  guidance_scale = gr.Slider(0.1, 10.0, label="Guidance scale", value=5, step=0.1)
193
  controlnet_conditioning_scale = gr.Slider(0.5, 5.0, label="ControlNet conditioning scale", value=0.85, step=0.01)
194
-
195
  with gr.Tab(label="3D generation"):
196
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
197
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
198
-
199
  gr.Markdown("Stage 1: Sparse Structure Generation")
200
  with gr.Row():
201
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
202
  ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
203
-
204
  gr.Markdown("Stage 2: Structured Latent Generation")
205
  with gr.Row():
206
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
207
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
208
-
209
  with gr.Accordion(label="GLB Extraction Settings", open=False):
210
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
211
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
212
-
213
  with gr.Row():
214
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
215
-
 
216
  with gr.Column():
217
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
218
  image_prompt_processed = gr.Image(label="Processed sketch", interactive=False, type="pil", height=512)
219
  model_output = LitModel3D(label="Extracted GLB", exposure=10.0, height=300)
220
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
221
-
222
  output_buf = gr.State()
223
-
224
- with gr.Row():
 
225
  examples = gr.Examples(
226
  examples=[f'assets/example_image/{image}' for image in os.listdir("assets/example_image")],
227
  inputs=[image_prompt],
@@ -233,9 +247,9 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
233
 
234
  demo.load(start_session)
235
  demo.unload(end_session)
236
-
237
  image_prompt.clear(reset_canvas, outputs=[image_prompt])
238
-
239
  sketch_btn.click(
240
  get_seed,
241
  inputs=[randomize_seed, seed],
@@ -245,7 +259,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
245
  inputs=[image_prompt, prompt, negative_prompt, style, num_steps, guidance_scale, controlnet_conditioning_scale],
246
  outputs=[image_prompt_processed],
247
  )
248
-
249
  generate_btn.click(
250
  get_seed,
251
  inputs=[randomize_seed, seed],
@@ -258,12 +272,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
258
  lambda: gr.Button(interactive=True),
259
  outputs=[extract_glb_btn],
260
  )
261
-
262
  video_output.clear(
263
  lambda: gr.Button(interactive=False),
264
  outputs=[extract_glb_btn],
265
  )
266
-
267
  extract_glb_btn.click(
268
  extract_glb,
269
  inputs=[output_buf, mesh_simplify, texture_size],
@@ -272,7 +286,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
272
  lambda: gr.Button(interactive=True),
273
  outputs=[download_glb],
274
  )
275
-
276
  model_output.clear(
277
  lambda: gr.Button(interactive=False),
278
  outputs=[download_glb],
@@ -281,23 +295,19 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
281
  if __name__ == "__main__":
282
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
283
  pipeline.cuda()
284
-
285
  device = "cuda" if torch.cuda.is_available() else "cpu"
286
 
287
  controlnet = ControlNetModel.from_pretrained(
288
  "xinsir/controlnet-scribble-sdxl-1.0",
289
  torch_dtype=torch.float16
290
  )
291
-
292
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
293
-
294
  pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
295
  "sd-community/sdxl-flash",
296
  controlnet=controlnet,
297
  vae=vae,
298
  torch_dtype=torch.float16,
299
  )
300
-
301
  pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
302
  pipe_control.to(device)
303
 
 
3
  from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
6
+ os.environ['SPCONV_ALGO'] = 'native'
7
+ from typing import *
8
  import torch
9
+ import numpy as np
10
+ import imageio
11
+ from easydict import EasyDict as edict
12
  from PIL import Image, ImageOps
13
  from trellis.pipelines import TrellisImageTo3DPipeline
14
+ from trellis.representations import MeshExtractResult
15
  from trellis.utils import render_utils, postprocessing_utils
16
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
17
  from diffusers import EulerAncestralDiscreteScheduler
18
+ from huggingface_hub import HfApi
19
+ from pathlib import Path
20
 
21
  style_list = [
22
  {
 
24
  "prompt": "{prompt}",
25
  "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
26
  },
27
+ {
28
+ "name": "Cinematic",
29
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
30
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
31
+ },
32
+ {
33
+ "name": "3D Model",
34
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
35
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
36
+ },
37
+ {
38
+ "name": "Anime",
39
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
40
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
41
+ },
42
+ {
43
+ "name": "Digital Art",
44
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
45
+ "negative_prompt": "photo, photorealistic, realism, ugly",
46
+ },
47
+ {
48
+ "name": "Photographic",
49
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
50
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
51
+ },
52
+ {
53
+ "name": "Pixel art",
54
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
55
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
56
+ },
57
+ {
58
+ "name": "Fantasy art",
59
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
60
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
61
+ },
62
+ {
63
+ "name": "Neonpunk",
64
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
65
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
66
+ },
67
+ {
68
+ "name": "Manga",
69
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
70
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
71
+ },
72
  ]
 
73
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
74
  STYLE_NAMES = list(styles.keys())
75
  DEFAULT_STYLE_NAME = "(No style)"
 
78
  os.makedirs(TMP_DIR, exist_ok=True)
79
 
80
  def reset_canvas():
81
+ return gr.update(value={"background":Image.new("RGB", (512, 512), (255, 255, 255)), "layers":[Image.new("RGB", (512, 512), (255, 255, 255))], "composite":Image.new("RGB", (512, 512), (255, 255, 255))})
 
 
82
 
83
  def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
84
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
 
119
  processed_image = pipeline.preprocess_image(output)
120
  return (image, processed_image)
121
 
122
+ def pack_state(mesh: MeshExtractResult) -> dict:
123
  return {
 
 
 
 
 
 
 
 
124
  'mesh': {
125
  'vertices': mesh.vertices.cpu().numpy(),
126
  'faces': mesh.faces.cpu().numpy(),
127
  },
128
  }
129
 
130
+ def unpack_state(state: dict) -> edict:
131
+ return edict(
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
133
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
134
  )
 
135
 
136
  def get_seed(randomize_seed: bool, seed: int) -> int:
137
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
 
150
  outputs = pipeline.run(
151
  image[1],
152
  seed=seed,
153
+ formats=["mesh"],
154
  preprocess_image=False,
155
  sparse_structure_sampler_params={
156
  "steps": ss_sampling_steps,
 
161
  "cfg_strength": slat_guidance_strength,
162
  },
163
  )
164
+ video = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
 
 
165
  video_path = os.path.join(user_dir, 'sample.mp4')
166
  imageio.mimsave(video_path, video, fps=15)
167
+ state = pack_state(outputs['mesh'][0])
168
  torch.cuda.empty_cache()
169
  return state, video_path
170
 
 
174
  mesh_simplify: float,
175
  texture_size: int,
176
  req: gr.Request,
177
+ ) -> str:
178
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
179
+ mesh = unpack_state(state)
180
+ glb = postprocessing_utils.to_glb(mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
181
  glb_path = os.path.join(user_dir, 'sample.glb')
182
  glb.export(glb_path)
183
  torch.cuda.empty_cache()
184
+ return glb_path
185
+
186
+ def reset_do_preprocess():
187
+ return True
188
 
189
  with gr.Blocks(delete_cache=(600, 600)) as demo:
190
  gr.Markdown("""
191
  ## Sketch to 3D with TRELLIS
192
  1. Fast sketch to image with SDXL Flash
193
  2. Scalable and versatile image to 3D generation using [TRELLIS](https://trellis3d.github.io/)
194
+ ### 🎨 Draw or upload a sketch and click "Generate" to create a 3D asset ✨
195
  """)
 
196
  with gr.Row():
197
  with gr.Column():
198
+ image_prompt = gr.ImageMask(label="Input sketch", type="pil", image_mode="RGB", height=512, value=reset_canvas())
 
 
 
 
199
  with gr.Row():
200
  sketch_btn = gr.Button("Process sketch")
201
  generate_btn = gr.Button("Generate 3D")
 
202
  with gr.Row():
203
  prompt = gr.Textbox(label="Prompt")
204
  style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
 
205
  with gr.Accordion(label="Generation Settings", open=False):
206
  with gr.Tab(label="sketch-to-image generation"):
207
  negative_prompt = gr.Textbox(label="Negative prompt")
208
  num_steps = gr.Slider(1, 20, label="Number of steps", value=8, step=1)
209
  guidance_scale = gr.Slider(0.1, 10.0, label="Guidance scale", value=5, step=0.1)
210
  controlnet_conditioning_scale = gr.Slider(0.5, 5.0, label="ControlNet conditioning scale", value=0.85, step=0.01)
 
211
  with gr.Tab(label="3D generation"):
212
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
213
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
214
  gr.Markdown("Stage 1: Sparse Structure Generation")
215
  with gr.Row():
216
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
217
  ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
218
  gr.Markdown("Stage 2: Structured Latent Generation")
219
  with gr.Row():
220
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
221
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
222
  with gr.Accordion(label="GLB Extraction Settings", open=False):
223
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
224
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
 
225
  with gr.Row():
226
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
227
+ gr.Markdown("")
228
+
229
  with gr.Column():
230
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
231
  image_prompt_processed = gr.Image(label="Processed sketch", interactive=False, type="pil", height=512)
232
  model_output = LitModel3D(label="Extracted GLB", exposure=10.0, height=300)
233
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
234
+
235
  output_buf = gr.State()
236
+ do_preprocess = gr.State(True)
237
+
238
+ with gr.Row(visible=False) as single_image_example:
239
  examples = gr.Examples(
240
  examples=[f'assets/example_image/{image}' for image in os.listdir("assets/example_image")],
241
  inputs=[image_prompt],
 
247
 
248
  demo.load(start_session)
249
  demo.unload(end_session)
250
+
251
  image_prompt.clear(reset_canvas, outputs=[image_prompt])
252
+
253
  sketch_btn.click(
254
  get_seed,
255
  inputs=[randomize_seed, seed],
 
259
  inputs=[image_prompt, prompt, negative_prompt, style, num_steps, guidance_scale, controlnet_conditioning_scale],
260
  outputs=[image_prompt_processed],
261
  )
262
+
263
  generate_btn.click(
264
  get_seed,
265
  inputs=[randomize_seed, seed],
 
272
  lambda: gr.Button(interactive=True),
273
  outputs=[extract_glb_btn],
274
  )
275
+
276
  video_output.clear(
277
  lambda: gr.Button(interactive=False),
278
  outputs=[extract_glb_btn],
279
  )
280
+
281
  extract_glb_btn.click(
282
  extract_glb,
283
  inputs=[output_buf, mesh_simplify, texture_size],
 
286
  lambda: gr.Button(interactive=True),
287
  outputs=[download_glb],
288
  )
289
+
290
  model_output.clear(
291
  lambda: gr.Button(interactive=False),
292
  outputs=[download_glb],
 
295
  if __name__ == "__main__":
296
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
297
  pipeline.cuda()
 
298
  device = "cuda" if torch.cuda.is_available() else "cpu"
299
 
300
  controlnet = ControlNetModel.from_pretrained(
301
  "xinsir/controlnet-scribble-sdxl-1.0",
302
  torch_dtype=torch.float16
303
  )
 
304
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
 
305
  pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
306
  "sd-community/sdxl-flash",
307
  controlnet=controlnet,
308
  vae=vae,
309
  torch_dtype=torch.float16,
310
  )
 
311
  pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
312
  pipe_control.to(device)
313