cavargas10 commited on
Commit
6110d4e
Β·
verified Β·
1 Parent(s): 26d325e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -170
app.py CHANGED
@@ -1,46 +1,39 @@
1
- import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
6
- os.environ['SPCONV_ALGO'] = 'native'
7
- from typing import *
8
- import imageio
9
- from easydict import EasyDict as edict
10
  from PIL import Image, ImageOps
11
  from trellis.pipelines import TrellisImageTo3DPipeline
12
  from trellis.representations import Gaussian, MeshExtractResult
13
  from trellis.utils import render_utils, postprocessing_utils
14
- import random
15
- import torch
16
- import torchvision.transforms.functional as TF
17
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
18
- from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
19
- from controlnet_aux import PidiNetDetector, HEDdetector
20
- from diffusers.utils import load_image
21
- from huggingface_hub import HfApi
22
- from pathlib import Path
23
- import numpy as np
24
- import cv2
25
- from gradio_imageslider import ImageSlider
26
 
27
  style_list = [
28
  {
29
- "name": "3D Model",
30
- "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
31
- "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
32
- }
 
33
  ]
 
34
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
35
  STYLE_NAMES = list(styles.keys())
36
  DEFAULT_STYLE_NAME = "(No style)"
37
-
38
  MAX_SEED = np.iinfo(np.int32).max
39
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
40
  os.makedirs(TMP_DIR, exist_ok=True)
41
 
42
  def reset_canvas():
43
- return gr.update(value={"background":Image.new("RGB", (512, 512), (255, 255, 255)), "layers":[Image.new("RGB", (512, 512), (255, 255, 255))], "composite":Image.new("RGB", (512, 512), (255, 255, 255))})
 
 
44
 
45
  def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
46
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
@@ -49,60 +42,37 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str
49
  def start_session(req: gr.Request):
50
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
51
  os.makedirs(user_dir, exist_ok=True)
52
-
53
  def end_session(req: gr.Request):
54
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
55
  shutil.rmtree(user_dir)
56
 
57
  @spaces.GPU
58
- def preprocess_image(
59
- image: Image.Image,
60
- prompt: str = "",
61
- negative_prompt: str = "",
62
- image=image,
63
- num_inference_steps=num_steps,
64
- controlnet_conditioning_scale=controlnet_conditioning_scale,
65
- guidance_scale=guidance_scale,
66
- style_name: str = "",
67
- num_steps: int = 25,
68
- guidance_scale: float = 5,
69
- controlnet_conditioning_scale: float = 1.0,
70
- req: gr.Request = None # Agregamos el parΓ‘metro `req`
71
- ) -> Tuple[Image.Image, Image.Image]:
72
- # Crear un directorio ΓΊnico para el usuario basado en su sesiΓ³n
73
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
74
- os.makedirs(user_dir, exist_ok=True) # Asegurarse de que el directorio existe
75
-
76
- # Procesar las dimensiones de la imagen
77
- width, height = image['composite'].size
78
  ratio = np.sqrt(1024. * 1024. / (width * height))
79
  new_width, new_height = int(width * ratio), int(height * ratio)
80
  image = image['composite'].resize((new_width, new_height))
81
  image = ImageOps.invert(image)
82
-
83
- print("image:", type(image))
84
-
85
- # Aplicar estilo al prompt
86
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
87
-
88
- print("params:", prompt, negative_prompt, style_name, num_steps, guidance_scale, controlnet_conditioning_scale)
89
-
90
- # Generar la imagen procesada usando el pipeline
91
  output = pipe_control(
92
  prompt=prompt,
93
  negative_prompt=negative_prompt,
 
 
 
 
94
  width=new_width,
95
- height=new_height
96
- ).images[0]
97
-
98
- # Guardar la imagen procesada en el directorio del usuario
99
- processed_image_path = os.path.join(user_dir, 'processed_image.png')
100
- output.save(processed_image_path)
101
-
102
- # Preprocesar la imagen para el siguiente paso (si es necesario)
103
  processed_image = pipeline.preprocess_image(output)
104
-
105
- return image, processed_image
106
 
107
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
108
  return {
@@ -119,7 +89,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
119
  'faces': mesh.faces.cpu().numpy(),
120
  },
121
  }
122
-
123
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
124
  gs = Gaussian(
125
  aabb=state['gaussian']['aabb'],
@@ -134,12 +104,10 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
134
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
135
  gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
136
  gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
137
-
138
  mesh = edict(
139
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
140
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
141
  )
142
-
143
  return gs, mesh
144
 
145
  def get_seed(randomize_seed: bool, seed: int) -> int:
@@ -176,13 +144,13 @@ def image_to_3d(
176
  video_path = os.path.join(user_dir, 'sample.mp4')
177
  imageio.mimsave(video_path, video, fps=15)
178
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
179
-
180
  torch.cuda.empty_cache()
181
  return state, video_path
182
 
183
  @spaces.GPU(duration=90)
184
  def extract_glb(
185
  state: dict,
 
186
  texture_size: int,
187
  req: gr.Request,
188
  ) -> Tuple[str, str]:
@@ -194,94 +162,68 @@ def extract_glb(
194
  torch.cuda.empty_cache()
195
  return glb_path, glb_path
196
 
197
- def reset_do_preprocess():
198
- return True
199
-
200
- @spaces.GPU
201
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
202
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
203
- gs, _ = unpack_state(state)
204
- gaussian_path = os.path.join(user_dir, 'sample.ply')
205
- gs.save_ply(gaussian_path)
206
- torch.cuda.empty_cache()
207
- return gaussian_path, gaussian_path
208
-
209
  with gr.Blocks(delete_cache=(600, 600)) as demo:
 
 
 
 
 
 
 
210
  with gr.Row():
211
  with gr.Column():
212
- with gr.Column():
213
- image_prompt = gr.ImageMask(label="Input sketch", type="pil", image_mode="RGB", height=512, value={"background":Image.new("RGB", (512, 512), (255, 255, 255)), "layers":[Image.new("RGB", (512, 512), (255, 255, 255))], "composite":Image.new("RGB", (512, 512), (255, 255, 255))})
214
- with gr.Row():
215
- sketch_btn = gr.Button("process sketch")
216
- generate_btn = gr.Button("Generate 3D")
217
- with gr.Row():
218
- prompt = gr.Textbox(label="Prompt")
219
- style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
220
-
221
- with gr.Accordion(label="Generation Settings", open=False):
222
- with gr.Tab(label="sketch-to-image generation"):
223
- negative_prompt = gr.Textbox(label="Negative prompt")
224
-
225
- num_steps = gr.Slider(
226
- label="Number of steps",
227
- minimum=1,
228
- maximum=20,
229
- step=1,
230
- value=8,
231
- )
232
- guidance_scale = gr.Slider(
233
- label="Guidance scale",
234
- minimum=0.1,
235
- maximum=10.0,
236
- step=0.1,
237
- value=5,
238
- )
239
- controlnet_conditioning_scale = gr.Slider(
240
- label="controlnet conditioning scale",
241
- minimum=0.5,
242
- maximum=5.0,
243
- step=0.01,
244
- value=0.85,
245
- )
246
- with gr.Tab(label="3D generation"):
247
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
248
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
249
- gr.Markdown("Stage 1: Sparse Structure Generation")
250
- with gr.Row():
251
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
252
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
253
- gr.Markdown("Stage 2: Structured Latent Generation")
254
  with gr.Row():
255
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
256
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
257
-
258
  with gr.Accordion(label="GLB Extraction Settings", open=False):
259
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
260
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
261
 
262
  with gr.Row():
263
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
264
- extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
265
-
266
  with gr.Column():
267
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
268
- image_prompt_processed = ImageSlider(label="processed sketch", interactive=False, type="pil", height=512)
269
- model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
270
-
271
- with gr.Row():
272
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
273
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
274
-
275
- do_preprocess = gr.State(True)
276
  output_buf = gr.State()
277
-
278
- #Example images at the bottom of the page
279
- with gr.Row(visible=False) as single_image_example:
280
  examples = gr.Examples(
281
- examples=[
282
- f'assets/example_image/{image}'
283
- for image in os.listdir("assets/example_image")
284
- ],
285
  inputs=[image_prompt],
286
  fn=preprocess_image,
287
  outputs=[image_prompt_processed],
@@ -289,84 +231,79 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
289
  examples_per_page=64,
290
  )
291
 
292
- # Handlers
293
  demo.load(start_session)
294
  demo.unload(end_session)
295
 
296
- image_prompt.clear(
297
- fn=reset_canvas,
298
- outputs = [image_prompt]
299
- )
300
 
301
  sketch_btn.click(
302
  get_seed,
 
303
  outputs=[seed],
304
  ).then(
305
  preprocess_image,
306
  inputs=[image_prompt, prompt, negative_prompt, style, num_steps, guidance_scale, controlnet_conditioning_scale],
307
  outputs=[image_prompt_processed],
308
  )
309
-
310
  generate_btn.click(
311
  get_seed,
312
  inputs=[randomize_seed, seed],
 
 
 
313
  inputs=[image_prompt_processed, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
314
  outputs=[output_buf, video_output],
315
  ).then(
316
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
317
- outputs=[extract_glb_btn, extract_gs_btn],
318
  )
319
-
320
  video_output.clear(
321
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
322
- outputs=[extract_glb_btn, extract_gs_btn],
323
  )
324
-
325
  extract_glb_btn.click(
326
  extract_glb,
327
  inputs=[output_buf, mesh_simplify, texture_size],
 
 
328
  lambda: gr.Button(interactive=True),
329
  outputs=[download_glb],
330
  )
331
 
332
- extract_gs_btn.click(
333
- extract_gaussian,
334
- inputs=[output_buf],
335
- outputs=[model_output, download_gs],
336
- ).then(
337
- lambda: gr.Button(interactive=True),
338
- outputs=[download_gs],
339
- )
340
-
341
  model_output.clear(
342
  lambda: gr.Button(interactive=False),
343
  outputs=[download_glb],
344
  )
345
-
346
- # Launch the Gradio app
347
  if __name__ == "__main__":
348
- pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS")
349
  pipeline.cuda()
350
-
351
  device = "cuda" if torch.cuda.is_available() else "cpu"
352
 
353
- #scribble controlnet
354
  controlnet = ControlNetModel.from_pretrained(
355
- "xinsir/controlnet-scribble-sdxl-1.0",
356
- torch_dtype=torch.float16
357
- )
 
358
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
359
 
360
  pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
 
 
361
  vae=vae,
362
  torch_dtype=torch.float16,
363
  )
364
-
365
  pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
366
  pipe_control.to(device)
367
 
368
  try:
369
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
370
  except:
371
  pass
 
372
  demo.launch()
 
1
+ import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
6
+ import numpy as np
7
+ import torch
 
 
8
  from PIL import Image, ImageOps
9
  from trellis.pipelines import TrellisImageTo3DPipeline
10
  from trellis.representations import Gaussian, MeshExtractResult
11
  from trellis.utils import render_utils, postprocessing_utils
 
 
 
12
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
13
+ from diffusers import EulerAncestralDiscreteScheduler
14
+
15
+ os.environ['SPCONV_ALGO'] = 'native'
 
 
 
 
 
16
 
17
  style_list = [
18
  {
19
+ "name": "(No style)",
20
+ "prompt": "{prompt}",
21
+ "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
22
+ },
23
+ # ... (resto de los estilos sin cambios)
24
  ]
25
+
26
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
27
  STYLE_NAMES = list(styles.keys())
28
  DEFAULT_STYLE_NAME = "(No style)"
 
29
  MAX_SEED = np.iinfo(np.int32).max
30
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
31
  os.makedirs(TMP_DIR, exist_ok=True)
32
 
33
  def reset_canvas():
34
+ return gr.update(value={"background": Image.new("RGB", (512, 512), (255, 255, 255)),
35
+ "layers": [Image.new("RGB", (512, 512), (255, 255, 255))],
36
+ "composite": Image.new("RGB", (512, 512), (255, 255, 255))})
37
 
38
  def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
39
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
 
42
  def start_session(req: gr.Request):
43
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
44
  os.makedirs(user_dir, exist_ok=True)
45
+
46
  def end_session(req: gr.Request):
47
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
48
  shutil.rmtree(user_dir)
49
 
50
  @spaces.GPU
51
+ def preprocess_image(image: Image.Image,
52
+ prompt: str = "",
53
+ negative_prompt: str = "",
54
+ style_name: str = "",
55
+ num_steps: int = 25,
56
+ guidance_scale: float = 5,
57
+ controlnet_conditioning_scale: float = 1.0,
58
+ ) -> Image.Image:
59
+ width, height = image['composite'].size
 
 
 
 
 
 
 
 
 
 
 
60
  ratio = np.sqrt(1024. * 1024. / (width * height))
61
  new_width, new_height = int(width * ratio), int(height * ratio)
62
  image = image['composite'].resize((new_width, new_height))
63
  image = ImageOps.invert(image)
 
 
 
 
64
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
 
 
 
 
65
  output = pipe_control(
66
  prompt=prompt,
67
  negative_prompt=negative_prompt,
68
+ image=image,
69
+ num_inference_steps=num_steps,
70
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
71
+ guidance_scale=guidance_scale,
72
  width=new_width,
73
+ height=new_height).images[0]
 
 
 
 
 
 
 
74
  processed_image = pipeline.preprocess_image(output)
75
+ return (image, processed_image)
 
76
 
77
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
78
  return {
 
89
  'faces': mesh.faces.cpu().numpy(),
90
  },
91
  }
92
+
93
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
94
  gs = Gaussian(
95
  aabb=state['gaussian']['aabb'],
 
104
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
105
  gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
106
  gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
 
107
  mesh = edict(
108
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
109
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
110
  )
 
111
  return gs, mesh
112
 
113
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
144
  video_path = os.path.join(user_dir, 'sample.mp4')
145
  imageio.mimsave(video_path, video, fps=15)
146
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
 
147
  torch.cuda.empty_cache()
148
  return state, video_path
149
 
150
  @spaces.GPU(duration=90)
151
  def extract_glb(
152
  state: dict,
153
+ mesh_simplify: float,
154
  texture_size: int,
155
  req: gr.Request,
156
  ) -> Tuple[str, str]:
 
162
  torch.cuda.empty_cache()
163
  return glb_path, glb_path
164
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  with gr.Blocks(delete_cache=(600, 600)) as demo:
166
+ gr.Markdown("""
167
+ ## Sketch to 3D with TRELLIS
168
+ 1. Fast sketch to image with SDXL Flash
169
+ 2. Scalable and versatile image to 3D generation using [TRELLIS](https://trellis3d.github.io/)
170
+ ### 🎨 Draw or upload a sketch and click "Generate" to create a 3D asset πŸ’Ž
171
+ """)
172
+
173
  with gr.Row():
174
  with gr.Column():
175
+ image_prompt = gr.ImageMask(label="Input sketch", type="pil", image_mode="RGB", height=512,
176
+ value={"background": Image.new("RGB", (512, 512), (255, 255, 255)),
177
+ "layers": [Image.new("RGB", (512, 512), (255, 255, 255))],
178
+ "composite": Image.new("RGB", (512, 512), (255, 255, 255))})
179
+
180
+ with gr.Row():
181
+ sketch_btn = gr.Button("Process sketch")
182
+ generate_btn = gr.Button("Generate 3D")
183
+
184
+ with gr.Row():
185
+ prompt = gr.Textbox(label="Prompt")
186
+ style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
187
+
188
+ with gr.Accordion(label="Generation Settings", open=False):
189
+ with gr.Tab(label="sketch-to-image generation"):
190
+ negative_prompt = gr.Textbox(label="Negative prompt")
191
+ num_steps = gr.Slider(1, 20, label="Number of steps", value=8, step=1)
192
+ guidance_scale = gr.Slider(0.1, 10.0, label="Guidance scale", value=5, step=0.1)
193
+ controlnet_conditioning_scale = gr.Slider(0.5, 5.0, label="ControlNet conditioning scale", value=0.85, step=0.01)
194
+
195
+ with gr.Tab(label="3D generation"):
196
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
197
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
198
+
199
+ gr.Markdown("Stage 1: Sparse Structure Generation")
200
+ with gr.Row():
201
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
202
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
203
+
204
+ gr.Markdown("Stage 2: Structured Latent Generation")
 
 
 
 
 
 
 
 
 
 
 
 
205
  with gr.Row():
206
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
207
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
208
+
209
  with gr.Accordion(label="GLB Extraction Settings", open=False):
210
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
211
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
212
 
213
  with gr.Row():
214
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
215
+
 
216
  with gr.Column():
217
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
218
+ image_prompt_processed = gr.Image(label="Processed sketch", interactive=False, type="pil", height=512)
219
+ model_output = LitModel3D(label="Extracted GLB", exposure=10.0, height=300)
220
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
221
+
 
 
 
 
222
  output_buf = gr.State()
223
+
224
+ with gr.Row():
 
225
  examples = gr.Examples(
226
+ examples=[f'assets/example_image/{image}' for image in os.listdir("assets/example_image")],
 
 
 
227
  inputs=[image_prompt],
228
  fn=preprocess_image,
229
  outputs=[image_prompt_processed],
 
231
  examples_per_page=64,
232
  )
233
 
 
234
  demo.load(start_session)
235
  demo.unload(end_session)
236
 
237
+ image_prompt.clear(reset_canvas, outputs=[image_prompt])
 
 
 
238
 
239
  sketch_btn.click(
240
  get_seed,
241
+ inputs=[randomize_seed, seed],
242
  outputs=[seed],
243
  ).then(
244
  preprocess_image,
245
  inputs=[image_prompt, prompt, negative_prompt, style, num_steps, guidance_scale, controlnet_conditioning_scale],
246
  outputs=[image_prompt_processed],
247
  )
248
+
249
  generate_btn.click(
250
  get_seed,
251
  inputs=[randomize_seed, seed],
252
+ outputs=[seed],
253
+ ).then(
254
+ image_to_3d,
255
  inputs=[image_prompt_processed, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
256
  outputs=[output_buf, video_output],
257
  ).then(
258
+ lambda: gr.Button(interactive=True),
259
+ outputs=[extract_glb_btn],
260
  )
261
+
262
  video_output.clear(
263
+ lambda: gr.Button(interactive=False),
264
+ outputs=[extract_glb_btn],
265
  )
266
+
267
  extract_glb_btn.click(
268
  extract_glb,
269
  inputs=[output_buf, mesh_simplify, texture_size],
270
+ outputs=[model_output, download_glb],
271
+ ).then(
272
  lambda: gr.Button(interactive=True),
273
  outputs=[download_glb],
274
  )
275
 
 
 
 
 
 
 
 
 
 
276
  model_output.clear(
277
  lambda: gr.Button(interactive=False),
278
  outputs=[download_glb],
279
  )
280
+
 
281
  if __name__ == "__main__":
282
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
283
  pipeline.cuda()
284
+
285
  device = "cuda" if torch.cuda.is_available() else "cpu"
286
 
 
287
  controlnet = ControlNetModel.from_pretrained(
288
+ "xinsir/controlnet-scribble-sdxl-1.0",
289
+ torch_dtype=torch.float16
290
+ )
291
+
292
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
293
 
294
  pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
295
+ "sd-community/sdxl-flash",
296
+ controlnet=controlnet,
297
  vae=vae,
298
  torch_dtype=torch.float16,
299
  )
300
+
301
  pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
302
  pipe_control.to(device)
303
 
304
  try:
305
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
306
  except:
307
  pass
308
+
309
  demo.launch()