cavargas10 committed on
Commit
ea3c5de
·
verified ·
1 Parent(s): 410cd67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -52
app.py CHANGED
@@ -1,23 +1,21 @@
1
- import gradio as gr
2
  import spaces
3
- from gradio_litmodel3d import LitModel3D
4
-
5
  import os
6
  import shutil
 
7
  os.environ['SPCONV_ALGO'] = 'native'
8
  from typing import *
9
  import torch
10
- import torchvision.transforms.functional as TF
11
  import numpy as np
12
- import random
13
  import imageio
14
  import cv2
 
 
15
  from easydict import EasyDict as edict
16
  from PIL import Image, ImageOps
17
  from trellis.pipelines import TrellisImageTo3DPipeline
18
  from trellis.representations import Gaussian, MeshExtractResult
19
  from trellis.utils import render_utils, postprocessing_utils
20
-
21
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
22
  from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
23
  from controlnet_aux import PidiNetDetector, HEDdetector
@@ -27,26 +25,16 @@ from pathlib import Path
27
  from gradio_imageslider import ImageSlider
28
 
29
  style_list = [
30
- {
31
- "name": "(No style)",
32
- "prompt": "{prompt}",
33
- "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
34
- },
35
- {
36
- "name": "Cinematic",
37
- "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
38
- "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
39
- },
40
  {
41
  "name": "3D Model",
42
  "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
43
  "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
44
  },
45
  ]
46
-
47
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
48
  STYLE_NAMES = list(styles.keys())
49
  DEFAULT_STYLE_NAME = "(No style)"
 
50
  MAX_SEED = np.iinfo(np.int32).max
51
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
52
  os.makedirs(TMP_DIR, exist_ok=True)
@@ -61,26 +49,37 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str
61
  def start_session(req: gr.Request):
62
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
63
  os.makedirs(user_dir, exist_ok=True)
64
-
65
  def end_session(req: gr.Request):
66
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
67
  shutil.rmtree(user_dir)
68
 
69
  @spaces.GPU
70
- def preprocess_image(image: Image.Image,
71
- prompt: str = "",
72
- negative_prompt: str = "",
73
- style_name: str = "",
74
- num_steps: int = 25,
75
- guidance_scale: float = 5,
76
- controlnet_conditioning_scale: float = 1.0,
77
- ) -> Image.Image:
78
- width, height = image['composite'].size
 
 
 
 
 
79
  ratio = np.sqrt(1024. * 1024. / (width * height))
80
  new_width, new_height = int(width * ratio), int(height * ratio)
81
  image = image['composite'].resize((new_width, new_height))
82
  image = ImageOps.invert(image)
 
 
 
83
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
 
 
 
84
  output = pipe_control(
85
  prompt=prompt,
86
  negative_prompt=negative_prompt,
@@ -89,9 +88,15 @@ def preprocess_image(image: Image.Image,
89
  controlnet_conditioning_scale=controlnet_conditioning_scale,
90
  guidance_scale=guidance_scale,
91
  width=new_width,
92
- height=new_height).images[0]
 
 
 
 
 
93
  processed_image = pipeline.preprocess_image(output)
94
- return (image, processed_image)
 
95
 
96
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
97
  return {
@@ -108,7 +113,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
108
  'faces': mesh.faces.cpu().numpy(),
109
  },
110
  }
111
-
112
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
113
  gs = Gaussian(
114
  aabb=state['gaussian']['aabb'],
@@ -123,10 +128,12 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
123
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
124
  gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
125
  gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
 
126
  mesh = edict(
127
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
128
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
129
  )
 
130
  return gs, mesh
131
 
132
  def get_seed(randomize_seed: bool, seed: int) -> int:
@@ -143,11 +150,10 @@ def image_to_3d(
143
  req: gr.Request,
144
  ) -> Tuple[dict, str]:
145
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
146
- os.makedirs(user_dir, exist_ok=True)
147
  outputs = pipeline.run(
148
  image[1],
149
  seed=seed,
150
- formats=["mesh"],
151
  preprocess_image=False,
152
  sparse_structure_sampler_params={
153
  "steps": ss_sampling_steps,
@@ -158,7 +164,9 @@ def image_to_3d(
158
  "cfg_strength": slat_guidance_strength,
159
  },
160
  )
161
- video = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
 
 
162
  video_path = os.path.join(user_dir, 'sample.mp4')
163
  imageio.mimsave(video_path, video, fps=15)
164
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
@@ -182,14 +190,17 @@ def extract_glb(
182
 
183
  def reset_do_preprocess():
184
  return True
 
 
 
 
 
 
 
 
 
185
 
186
  with gr.Blocks(delete_cache=(600, 600)) as demo:
187
- gr.Markdown("""
188
- ## Sketch to 3D with TRELLIS
189
- 1. Fast sketch to image with SDXL Flash, using [@xinsir](https://huggingface.co/xinsir) [scribble sdxl controlnet](https://huggingface.co/xinsir/controlnet-scribble-sdxl-1.0) and [sdxl flash](https://huggingface.co/sd-community/sdxl-flash)
190
- 2. Scalable and versatile image to 3D generation using [TRELLIS](https://trellis3d.github.io/)
191
- ### 🎨🖌️ draw or upload a sketch and click "Generate" to create a 3D asset ✨
192
- """)
193
  with gr.Row():
194
  with gr.Column():
195
  with gr.Column():
@@ -200,9 +211,11 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
200
  with gr.Row():
201
  prompt = gr.Textbox(label="Prompt")
202
  style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
 
203
  with gr.Accordion(label="Generation Settings", open=False):
204
  with gr.Tab(label="sketch-to-image generation"):
205
  negative_prompt = gr.Textbox(label="Negative prompt")
 
206
  num_steps = gr.Slider(
207
  label="Number of steps",
208
  minimum=1,
@@ -232,32 +245,54 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
232
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
233
  ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
234
  gr.Markdown("Stage 2: Structured Latent Generation")
235
- with gr.Row():
236
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
237
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
238
  with gr.Accordion(label="GLB Extraction Settings", open=False):
239
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
240
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
 
241
  with gr.Row():
242
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
 
243
  gr.Markdown("""
244
- *NOTE: GLB file can be downloaded after extraction.*
245
  """)
 
246
  with gr.Column():
247
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
248
  image_prompt_processed = ImageSlider(label="processed sketch", interactive=False, type="pil", height=512)
249
- model_output = LitModel3D(label="Extracted GLB", exposure=10.0, height=300)
 
250
  with gr.Row():
251
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
252
-
 
 
253
  output_buf = gr.State()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  demo.load(start_session)
255
  demo.unload(end_session)
256
-
257
  image_prompt.clear(
258
  fn=reset_canvas,
259
  outputs = [image_prompt]
260
  )
 
261
  sketch_btn.click(
262
  get_seed,
263
  inputs=[randomize_seed, seed],
@@ -277,13 +312,15 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
277
  inputs=[image_prompt_processed, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
278
  outputs=[output_buf, video_output],
279
  ).then(
280
- lambda: gr.Button(interactive=True),
281
- outputs=[extract_glb_btn],
282
  )
 
283
  video_output.clear(
284
- lambda: gr.Button(interactive=False),
285
- outputs=[extract_glb_btn],
286
  )
 
287
  extract_glb_btn.click(
288
  extract_glb,
289
  inputs=[output_buf, mesh_simplify, texture_size],
@@ -292,21 +329,35 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
292
  lambda: gr.Button(interactive=True),
293
  outputs=[download_glb],
294
  )
 
 
 
 
 
 
 
 
 
 
295
  model_output.clear(
296
  lambda: gr.Button(interactive=False),
297
  outputs=[download_glb],
298
  )
299
-
 
300
  if __name__ == "__main__":
301
- pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
302
  pipeline.cuda()
 
303
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
304
  #scribble controlnet
305
  controlnet = ControlNetModel.from_pretrained(
306
  "xinsir/controlnet-scribble-sdxl-1.0",
307
  torch_dtype=torch.float16
308
  )
309
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
 
310
  pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
311
  "sd-community/sdxl-flash",
312
  controlnet=controlnet,
@@ -315,8 +366,9 @@ if __name__ == "__main__":
315
  )
316
  pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
317
  pipe_control.to(device)
 
318
  try:
319
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
320
  except:
321
  pass
322
- demo.launch()
 
1
+ import gradio as gr
2
  import spaces
 
 
3
  import os
4
  import shutil
5
+ import random
6
  os.environ['SPCONV_ALGO'] = 'native'
7
  from typing import *
8
  import torch
 
9
  import numpy as np
 
10
  import imageio
11
  import cv2
12
+ import torchvision.transforms.functional as TF
13
+ from gradio_litmodel3d import LitModel3D
14
  from easydict import EasyDict as edict
15
  from PIL import Image, ImageOps
16
  from trellis.pipelines import TrellisImageTo3DPipeline
17
  from trellis.representations import Gaussian, MeshExtractResult
18
  from trellis.utils import render_utils, postprocessing_utils
 
19
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
20
  from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
21
  from controlnet_aux import PidiNetDetector, HEDdetector
 
25
  from gradio_imageslider import ImageSlider
26
 
# Prompt templates selectable in the UI. Each entry wraps the user's prompt
# (substituted for "{prompt}") and supplies a matching negative prompt.
style_list = [
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
]

# Map style name -> (prompt template, negative prompt).
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
# The default must be one of the defined styles: it is used as the initial
# value of the Style dropdown, whose choices are STYLE_NAMES. The former
# "(No style)" entry no longer exists, so fall back to the first defined style.
DEFAULT_STYLE_NAME = STYLE_NAMES[0]
37
+
38
  MAX_SEED = np.iinfo(np.int32).max
39
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
40
  os.makedirs(TMP_DIR, exist_ok=True)
 
def start_session(req: gr.Request):
    """Create the scratch directory for the session that issued ``req``.

    The directory lives under ``TMP_DIR`` and is named after the request's
    session hash. Nothing happens if it already exists.
    """
    session_id = str(req.session_hash)
    os.makedirs(os.path.join(TMP_DIR, session_id), exist_ok=True)
52
+
def end_session(req: gr.Request):
    """Remove the scratch directory of the session that issued ``req``.

    Cleanup is best-effort: the directory may never have been created or may
    already have been removed, and that must not raise from the unload hook.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # ignore_errors=True: a missing or partially removed directory is fine here.
    shutil.rmtree(user_dir, ignore_errors=True)
56
 
57
  @spaces.GPU
58
+ def preprocess_image(
59
+ image: Image.Image,
60
+ prompt: str = "",
61
+ negative_prompt: str = "",
62
+ style_name: str = "",
63
+ num_steps: int = 25,
64
+ guidance_scale: float = 5,
65
+ controlnet_conditioning_scale: float = 1.0,
66
+ req: gr.Request = None
67
+ ) -> Tuple[Image.Image, Image.Image]:
68
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
69
+ os.makedirs(user_dir, exist_ok=True)
70
+
71
+ width, height = image['composite'].size
72
  ratio = np.sqrt(1024. * 1024. / (width * height))
73
  new_width, new_height = int(width * ratio), int(height * ratio)
74
  image = image['composite'].resize((new_width, new_height))
75
  image = ImageOps.invert(image)
76
+
77
+ print("image:", type(image))
78
+
79
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
80
+
81
+ print("params:", prompt, negative_prompt, style_name, num_steps, guidance_scale, controlnet_conditioning_scale)
82
+
83
  output = pipe_control(
84
  prompt=prompt,
85
  negative_prompt=negative_prompt,
 
88
  controlnet_conditioning_scale=controlnet_conditioning_scale,
89
  guidance_scale=guidance_scale,
90
  width=new_width,
91
+ height=new_height
92
+ ).images[0]
93
+
94
+ processed_image_path = os.path.join(user_dir, 'processed_image.png')
95
+ output.save(processed_image_path)
96
+
97
  processed_image = pipeline.preprocess_image(output)
98
+
99
+ return image, processed_image
100
 
101
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
102
  return {
 
113
  'faces': mesh.faces.cpu().numpy(),
114
  },
115
  }
116
+
117
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
118
  gs = Gaussian(
119
  aabb=state['gaussian']['aabb'],
 
128
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
129
  gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
130
  gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
131
+
132
  mesh = edict(
133
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
134
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
135
  )
136
+
137
  return gs, mesh
138
 
139
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
150
  req: gr.Request,
151
  ) -> Tuple[dict, str]:
152
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
153
  outputs = pipeline.run(
154
  image[1],
155
  seed=seed,
156
+ formats=["gaussian", "mesh"],
157
  preprocess_image=False,
158
  sparse_structure_sampler_params={
159
  "steps": ss_sampling_steps,
 
164
  "cfg_strength": slat_guidance_strength,
165
  },
166
  )
167
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
168
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
169
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
170
  video_path = os.path.join(user_dir, 'sample.mp4')
171
  imageio.mimsave(video_path, video, fps=15)
172
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
 
190
 
def reset_do_preprocess() -> bool:
    """Return ``True`` unconditionally.

    NOTE(review): presumably used as an event callback that resets the
    ``do_preprocess`` state; no caller is visible in this view.
    """
    return True
193
+
@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """Export the Gaussian stored in ``state`` as a PLY file for this session.

    Args:
        state: Packed generation state, as accepted by ``unpack_state``.
        req: Gradio request; its session hash selects the output directory.

    Returns:
        The path of the written ``sample.ply``, twice. The event wiring sends
        one copy to the 3D viewer and one to the download button.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Make sure the session directory exists before writing into it, as
    # preprocess_image does.
    os.makedirs(user_dir, exist_ok=True)
    gs, _ = unpack_state(state)
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    gs.save_ply(gaussian_path)
    # unpack_state put tensors on the GPU; release cached memory once the file is written.
    torch.cuda.empty_cache()
    return gaussian_path, gaussian_path
202
 
203
  with gr.Blocks(delete_cache=(600, 600)) as demo:
 
 
 
 
 
 
204
  with gr.Row():
205
  with gr.Column():
206
  with gr.Column():
 
211
  with gr.Row():
212
  prompt = gr.Textbox(label="Prompt")
213
  style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
214
+
215
  with gr.Accordion(label="Generation Settings", open=False):
216
  with gr.Tab(label="sketch-to-image generation"):
217
  negative_prompt = gr.Textbox(label="Negative prompt")
218
+
219
  num_steps = gr.Slider(
220
  label="Number of steps",
221
  minimum=1,
 
245
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
246
  ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
247
  gr.Markdown("Stage 2: Structured Latent Generation")
248
+ with gr.Row():
249
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
250
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
251
+
252
  with gr.Accordion(label="GLB Extraction Settings", open=False):
253
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
254
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
255
+
256
  with gr.Row():
257
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
258
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
259
  gr.Markdown("""
260
+ *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
261
  """)
262
+
263
  with gr.Column():
264
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
265
  image_prompt_processed = ImageSlider(label="processed sketch", interactive=False, type="pil", height=512)
266
+ model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
267
+
268
  with gr.Row():
269
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
270
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
271
+
272
+ do_preprocess = gr.State(True)
273
  output_buf = gr.State()
274
+
275
+ with gr.Row(visible=False) as single_image_example:
276
+ examples = gr.Examples(
277
+ examples=[
278
+ f'assets/example_image/{image}'
279
+ for image in os.listdir("assets/example_image")
280
+ ],
281
+ inputs=[image_prompt],
282
+ fn=preprocess_image,
283
+ outputs=[image_prompt_processed],
284
+ run_on_click=True,
285
+ examples_per_page=64,
286
+ )
287
+
288
  demo.load(start_session)
289
  demo.unload(end_session)
290
+
291
  image_prompt.clear(
292
  fn=reset_canvas,
293
  outputs = [image_prompt]
294
  )
295
+
296
  sketch_btn.click(
297
  get_seed,
298
  inputs=[randomize_seed, seed],
 
312
  inputs=[image_prompt_processed, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
313
  outputs=[output_buf, video_output],
314
  ).then(
315
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
316
+ outputs=[extract_glb_btn, extract_gs_btn],
317
  )
318
+
319
  video_output.clear(
320
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
321
+ outputs=[extract_glb_btn, extract_gs_btn],
322
  )
323
+
324
  extract_glb_btn.click(
325
  extract_glb,
326
  inputs=[output_buf, mesh_simplify, texture_size],
 
329
  lambda: gr.Button(interactive=True),
330
  outputs=[download_glb],
331
  )
332
+
333
+ extract_gs_btn.click(
334
+ extract_gaussian,
335
+ inputs=[output_buf],
336
+ outputs=[model_output, download_gs],
337
+ ).then(
338
+ lambda: gr.Button(interactive=True),
339
+ outputs=[download_gs],
340
+ )
341
+
342
  model_output.clear(
343
  lambda: gr.Button(interactive=False),
344
  outputs=[download_glb],
345
  )
346
+
347
+ # Launch the Gradio app
348
  if __name__ == "__main__":
349
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS")
350
  pipeline.cuda()
351
+
352
  device = "cuda" if torch.cuda.is_available() else "cpu"
353
+
354
  #scribble controlnet
355
  controlnet = ControlNetModel.from_pretrained(
356
  "xinsir/controlnet-scribble-sdxl-1.0",
357
  torch_dtype=torch.float16
358
  )
359
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
360
+
361
  pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
362
  "sd-community/sdxl-flash",
363
  controlnet=controlnet,
 
366
  )
367
  pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
368
  pipe_control.to(device)
369
+
370
  try:
371
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
372
  except:
373
  pass
374
+ demo.launch(show_error=True)