cavargas10 commited on
Commit
a65fb48
Β·
verified Β·
1 Parent(s): ea3c5de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -51
app.py CHANGED
@@ -1,35 +1,60 @@
1
- import gradio as gr
2
  import spaces
 
 
3
  import os
4
  import shutil
5
- import random
6
  os.environ['SPCONV_ALGO'] = 'native'
7
  from typing import *
8
  import torch
9
  import numpy as np
10
  import imageio
11
- import cv2
12
- import torchvision.transforms.functional as TF
13
- from gradio_litmodel3d import LitModel3D
14
  from easydict import EasyDict as edict
15
  from PIL import Image, ImageOps
16
  from trellis.pipelines import TrellisImageTo3DPipeline
17
  from trellis.representations import Gaussian, MeshExtractResult
18
  from trellis.utils import render_utils, postprocessing_utils
 
 
 
 
 
 
19
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
20
  from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
21
  from controlnet_aux import PidiNetDetector, HEDdetector
22
  from diffusers.utils import load_image
23
  from huggingface_hub import HfApi
24
  from pathlib import Path
 
 
 
 
 
 
25
  from gradio_imageslider import ImageSlider
26
 
27
  style_list = [
 
 
 
 
 
 
 
 
 
 
28
  {
29
  "name": "3D Model",
30
  "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
31
  "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
32
  },
 
 
 
 
 
33
  ]
34
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
35
  STYLE_NAMES = list(styles.keys())
@@ -50,36 +75,31 @@ def start_session(req: gr.Request):
50
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
51
  os.makedirs(user_dir, exist_ok=True)
52
 
 
53
  def end_session(req: gr.Request):
54
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
55
  shutil.rmtree(user_dir)
56
 
57
  @spaces.GPU
58
- def preprocess_image(
59
- image: Image.Image,
60
- prompt: str = "",
61
- negative_prompt: str = "",
62
- style_name: str = "",
63
- num_steps: int = 25,
64
- guidance_scale: float = 5,
65
- controlnet_conditioning_scale: float = 1.0,
66
- req: gr.Request = None
67
- ) -> Tuple[Image.Image, Image.Image]:
68
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
69
- os.makedirs(user_dir, exist_ok=True)
70
-
71
- width, height = image['composite'].size
72
  ratio = np.sqrt(1024. * 1024. / (width * height))
73
  new_width, new_height = int(width * ratio), int(height * ratio)
74
  image = image['composite'].resize((new_width, new_height))
75
  image = ImageOps.invert(image)
76
 
77
- print("image:", type(image))
78
 
79
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
80
-
81
  print("params:", prompt, negative_prompt, style_name, num_steps, guidance_scale, controlnet_conditioning_scale)
82
-
83
  output = pipe_control(
84
  prompt=prompt,
85
  negative_prompt=negative_prompt,
@@ -88,15 +108,15 @@ def preprocess_image(
88
  controlnet_conditioning_scale=controlnet_conditioning_scale,
89
  guidance_scale=guidance_scale,
90
  width=new_width,
91
- height=new_height
92
- ).images[0]
93
-
94
- processed_image_path = os.path.join(user_dir, 'processed_image.png')
95
- output.save(processed_image_path)
96
-
97
  processed_image = pipeline.preprocess_image(output)
 
98
 
99
- return image, processed_image
 
 
 
100
 
101
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
102
  return {
@@ -142,28 +162,48 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
142
  @spaces.GPU
143
  def image_to_3d(
144
  image: Image.Image,
 
 
145
  seed: int,
146
  ss_guidance_strength: float,
147
  ss_sampling_steps: int,
148
  slat_guidance_strength: float,
149
  slat_sampling_steps: int,
 
150
  req: gr.Request,
151
  ) -> Tuple[dict, str]:
152
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
153
- outputs = pipeline.run(
154
- image[1],
155
- seed=seed,
156
- formats=["gaussian", "mesh"],
157
- preprocess_image=False,
158
- sparse_structure_sampler_params={
159
- "steps": ss_sampling_steps,
160
- "cfg_strength": ss_guidance_strength,
161
- },
162
- slat_sampler_params={
163
- "steps": slat_sampling_steps,
164
- "cfg_strength": slat_guidance_strength,
165
- },
166
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
168
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
169
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -200,7 +240,38 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
200
  torch.cuda.empty_cache()
201
  return gaussian_path, gaussian_path
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  with gr.Blocks(delete_cache=(600, 600)) as demo:
 
 
 
 
 
 
 
204
  with gr.Row():
205
  with gr.Column():
206
  with gr.Column():
@@ -245,10 +316,14 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
245
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
246
  ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
247
  gr.Markdown("Stage 2: Structured Latent Generation")
248
- with gr.Row():
249
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
250
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
251
-
 
 
 
 
252
  with gr.Accordion(label="GLB Extraction Settings", open=False):
253
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
254
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
@@ -269,6 +344,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
269
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
270
  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
271
 
 
272
  do_preprocess = gr.State(True)
273
  output_buf = gr.State()
274
 
@@ -284,10 +360,23 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
284
  run_on_click=True,
285
  examples_per_page=64,
286
  )
 
 
 
 
 
 
 
 
 
287
 
288
  demo.load(start_session)
289
  demo.unload(end_session)
290
 
 
 
 
 
291
  image_prompt.clear(
292
  fn=reset_canvas,
293
  outputs = [image_prompt]
@@ -302,6 +391,11 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
302
  inputs=[image_prompt, prompt, negative_prompt, style, num_steps, guidance_scale, controlnet_conditioning_scale],
303
  outputs=[image_prompt_processed],
304
  )
 
 
 
 
 
305
 
306
  generate_btn.click(
307
  get_seed,
@@ -309,7 +403,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
309
  outputs=[seed],
310
  ).then(
311
  image_to_3d,
312
- inputs=[image_prompt_processed, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
313
  outputs=[output_buf, video_output],
314
  ).then(
315
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
@@ -343,15 +437,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
343
  lambda: gr.Button(interactive=False),
344
  outputs=[download_glb],
345
  )
346
-
347
- # Launch the Gradio app
348
  if __name__ == "__main__":
349
- pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS")
350
  pipeline.cuda()
351
 
352
  device = "cuda" if torch.cuda.is_available() else "cpu"
353
 
354
- #scribble controlnet
355
  controlnet = ControlNetModel.from_pretrained(
356
  "xinsir/controlnet-scribble-sdxl-1.0",
357
  torch_dtype=torch.float16
@@ -371,4 +463,4 @@ if __name__ == "__main__":
371
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
372
  except:
373
  pass
374
- demo.launch(show_error=True)
 
1
+ import gradio as gr
2
  import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
+
5
  import os
6
  import shutil
 
7
  os.environ['SPCONV_ALGO'] = 'native'
8
  from typing import *
9
  import torch
10
  import numpy as np
11
  import imageio
 
 
 
12
  from easydict import EasyDict as edict
13
  from PIL import Image, ImageOps
14
  from trellis.pipelines import TrellisImageTo3DPipeline
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
+
18
+ import os
19
+ import random
20
+ import torch
21
+ import torchvision.transforms.functional as TF
22
+
23
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
24
  from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
25
  from controlnet_aux import PidiNetDetector, HEDdetector
26
  from diffusers.utils import load_image
27
  from huggingface_hub import HfApi
28
  from pathlib import Path
29
+ from PIL import Image, ImageOps
30
+ import torch
31
+ import numpy as np
32
+ import cv2
33
+ import os
34
+ import random
35
  from gradio_imageslider import ImageSlider
36
 
37
  style_list = [
38
+ {
39
+ "name": "(No style)",
40
+ "prompt": "{prompt}",
41
+ "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
42
+ },
43
+ {
44
+ "name": "Cinematic",
45
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
46
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
47
+ },
48
  {
49
  "name": "3D Model",
50
  "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
51
  "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
52
  },
53
+ {
54
+ "name": "Anime",
55
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
56
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
57
+ },
58
  ]
59
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
60
  STYLE_NAMES = list(styles.keys())
 
75
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
76
  os.makedirs(user_dir, exist_ok=True)
77
 
78
+
79
  def end_session(req: gr.Request):
80
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
81
  shutil.rmtree(user_dir)
82
 
83
  @spaces.GPU
84
+ def preprocess_image(image: Image.Image,
85
+ prompt: str = "",
86
+ negative_prompt: str = "",
87
+ style_name: str = "",
88
+ num_steps: int = 25,
89
+ guidance_scale: float = 5,
90
+ controlnet_conditioning_scale: float = 1.0,
91
+ ) -> Image.Image:
92
+ width, height = image['composite'].size
 
 
 
 
 
93
  ratio = np.sqrt(1024. * 1024. / (width * height))
94
  new_width, new_height = int(width * ratio), int(height * ratio)
95
  image = image['composite'].resize((new_width, new_height))
96
  image = ImageOps.invert(image)
97
 
98
+ print("image:",type(image))
99
 
100
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
101
+
102
  print("params:", prompt, negative_prompt, style_name, num_steps, guidance_scale, controlnet_conditioning_scale)
 
103
  output = pipe_control(
104
  prompt=prompt,
105
  negative_prompt=negative_prompt,
 
108
  controlnet_conditioning_scale=controlnet_conditioning_scale,
109
  guidance_scale=guidance_scale,
110
  width=new_width,
111
+ height=new_height).images[0]
112
+
 
 
 
 
113
  processed_image = pipeline.preprocess_image(output)
114
+ return (image, processed_image)
115
 
116
+ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
117
+ images = [image[0] for image in images]
118
+ processed_images = [pipeline.preprocess_image(image) for image in images]
119
+ return processed_images
120
 
121
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
122
  return {
 
162
  @spaces.GPU
163
  def image_to_3d(
164
  image: Image.Image,
165
+ multiimages: List[Tuple[Image.Image, str]],
166
+ is_multiimage: bool,
167
  seed: int,
168
  ss_guidance_strength: float,
169
  ss_sampling_steps: int,
170
  slat_guidance_strength: float,
171
  slat_sampling_steps: int,
172
+ multiimage_algo: Literal["multidiffusion", "stochastic"],
173
  req: gr.Request,
174
  ) -> Tuple[dict, str]:
175
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
176
+ if not is_multiimage:
177
+ outputs = pipeline.run(
178
+ image[1],
179
+ seed=seed,
180
+ formats=["gaussian", "mesh"],
181
+ preprocess_image=False,
182
+ sparse_structure_sampler_params={
183
+ "steps": ss_sampling_steps,
184
+ "cfg_strength": ss_guidance_strength,
185
+ },
186
+ slat_sampler_params={
187
+ "steps": slat_sampling_steps,
188
+ "cfg_strength": slat_guidance_strength,
189
+ },
190
+ )
191
+ else:
192
+ outputs = pipeline.run_multi_image(
193
+ [image[0] for image in multiimages],
194
+ seed=seed,
195
+ formats=["gaussian", "mesh"],
196
+ preprocess_image=False,
197
+ sparse_structure_sampler_params={
198
+ "steps": ss_sampling_steps,
199
+ "cfg_strength": ss_guidance_strength,
200
+ },
201
+ slat_sampler_params={
202
+ "steps": slat_sampling_steps,
203
+ "cfg_strength": slat_guidance_strength,
204
+ },
205
+ mode=multiimage_algo,
206
+ )
207
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
208
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
209
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
240
  torch.cuda.empty_cache()
241
  return gaussian_path, gaussian_path
242
 
243
+ def prepare_multi_example() -> List[Image.Image]:
244
+ multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
245
+ images = []
246
+ for case in multi_case:
247
+ _images = []
248
+ for i in range(1, 4):
249
+ img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
250
+ W, H = img.size
251
+ img = img.resize((int(W / H * 512), 512))
252
+ _images.append(np.array(img))
253
+ images.append(Image.fromarray(np.concatenate(_images, axis=1)))
254
+ return images
255
+
256
+ def split_image(image: Image.Image) -> List[Image.Image]:
257
+ image = np.array(image)
258
+ alpha = image[..., 3]
259
+ alpha = np.any(alpha>0, axis=0)
260
+ start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
261
+ end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
262
+ images = []
263
+ for s, e in zip(start_pos, end_pos):
264
+ images.append(Image.fromarray(image[:, s:e+1]))
265
+ return [preprocess_image(image) for image in images]
266
+
267
  with gr.Blocks(delete_cache=(600, 600)) as demo:
268
+ gr.Markdown("""
269
+ ## Sketch to 3D with TRELLIS
270
+ 1. Fast sketch to image with SDXL Flash, using [@xinsir](https://huggingface.co/xinsir) [scribble sdxl controlnet](https://huggingface.co/xinsir/controlnet-scribble-sdxl-1.0) and [sdxl flash](https://huggingface.co/sd-community/sdxl-flash)
271
+ 2. Scalable and versatile image to 3D generation using [TRELLIS](https://trellis3d.github.io/)
272
+ ### πŸŽ¨πŸ–ŒοΈ draw or upload a sketch and click "Generate" to create a 3D asset ✨
273
+
274
+ """)
275
  with gr.Row():
276
  with gr.Column():
277
  with gr.Column():
 
316
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
317
  ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
318
  gr.Markdown("Stage 2: Structured Latent Generation")
319
+ with gr.Row():
320
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
321
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
322
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
323
+
324
+ with gr.Tab(label="Multiple Images", id=1, visible=False) as multiimage_input_tab:
325
+ multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
326
+
327
  with gr.Accordion(label="GLB Extraction Settings", open=False):
328
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
329
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
 
344
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
345
  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
346
 
347
+ is_multiimage = gr.State(False)
348
  do_preprocess = gr.State(True)
349
  output_buf = gr.State()
350
 
 
360
  run_on_click=True,
361
  examples_per_page=64,
362
  )
363
+ with gr.Row(visible=False) as multiimage_example:
364
+ examples_multi = gr.Examples(
365
+ examples=prepare_multi_example(),
366
+ inputs=[image_prompt],
367
+ fn=split_image,
368
+ outputs=[multiimage_prompt],
369
+ run_on_click=True,
370
+ examples_per_page=8,
371
+ )
372
 
373
  demo.load(start_session)
374
  demo.unload(end_session)
375
 
376
+ multiimage_input_tab.select(
377
+ lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
378
+ outputs=[is_multiimage, single_image_example, multiimage_example]
379
+ )
380
  image_prompt.clear(
381
  fn=reset_canvas,
382
  outputs = [image_prompt]
 
391
  inputs=[image_prompt, prompt, negative_prompt, style, num_steps, guidance_scale, controlnet_conditioning_scale],
392
  outputs=[image_prompt_processed],
393
  )
394
+ multiimage_prompt.upload(
395
+ preprocess_images,
396
+ inputs=[multiimage_prompt],
397
+ outputs=[multiimage_prompt],
398
+ )
399
 
400
  generate_btn.click(
401
  get_seed,
 
403
  outputs=[seed],
404
  ).then(
405
  image_to_3d,
406
+ inputs=[image_prompt_processed, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
407
  outputs=[output_buf, video_output],
408
  ).then(
409
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
 
437
  lambda: gr.Button(interactive=False),
438
  outputs=[download_glb],
439
  )
440
+
 
441
  if __name__ == "__main__":
442
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
443
  pipeline.cuda()
444
 
445
  device = "cuda" if torch.cuda.is_available() else "cpu"
446
 
 
447
  controlnet = ControlNetModel.from_pretrained(
448
  "xinsir/controlnet-scribble-sdxl-1.0",
449
  torch_dtype=torch.float16
 
463
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
464
  except:
465
  pass
466
+ demo.launch()