cavargas10 commited on
Commit
6980cac
Β·
verified Β·
1 Parent(s): ed6baf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -43
app.py CHANGED
@@ -3,19 +3,46 @@ import spaces
3
  from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
6
- import torch
7
  import numpy as np
 
8
  import imageio
9
  from easydict import EasyDict as edict
10
- from PIL import Image
11
  from trellis.pipelines import TrellisImageTo3DPipeline
12
  from trellis.representations import Gaussian, MeshExtractResult
13
  from trellis.utils import render_utils, postprocessing_utils
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
 
 
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
17
  os.makedirs(TMP_DIR, exist_ok=True)
18
 
 
 
 
 
19
  def start_session(req: gr.Request):
20
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
21
  os.makedirs(user_dir, exist_ok=True)
@@ -24,8 +51,35 @@ def end_session(req: gr.Request):
24
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
25
  shutil.rmtree(user_dir, ignore_errors=True)
26
 
27
- def preprocess_image(image: Image.Image) -> Image.Image:
28
- return pipeline.preprocess_image(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
31
  return {
@@ -45,12 +99,6 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
45
 
46
  def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
47
  gs = Gaussian(**state['gaussian'])
48
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
49
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
50
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
51
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
52
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
53
-
54
  mesh = edict(
55
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
56
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
@@ -61,65 +109,169 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
61
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
62
 
63
  @spaces.GPU
64
- def image_to_3d(image: Image.Image, seed: int, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int, req: gr.Request) -> Tuple[dict, str]:
 
 
 
 
 
 
 
 
 
65
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
66
  outputs = pipeline.run(
67
  image,
68
  seed=seed,
69
  formats=["gaussian", "mesh"],
70
  preprocess_image=False,
71
- sparse_structure_sampler_params={"steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength},
72
- slat_sampler_params={"steps": slat_sampling_steps, "cfg_strength": slat_guidance_strength},
 
 
 
 
 
 
73
  )
74
 
75
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
76
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
77
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
78
  video_path = os.path.join(user_dir, 'sample.mp4')
79
  imageio.mimsave(video_path, video, fps=15)
80
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
81
- torch.cuda.empty_cache()
82
- return state, video_path
83
 
84
  @spaces.GPU(duration=90)
85
- def extract_glb(state: dict, mesh_simplify: float, texture_size: int, req: gr.Request) -> Tuple[str, str]:
 
 
 
 
 
 
86
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
87
  gs, mesh = unpack_state(state)
88
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
 
 
 
 
 
 
 
 
89
  glb_path = os.path.join(user_dir, 'sample.glb')
90
  glb.export(glb_path)
91
- torch.cuda.empty_cache()
92
- return glb_path, glb_path
93
 
94
  with gr.Blocks() as demo:
95
- gr.Markdown("""
96
- # ConversiΓ³n de ImΓ‘gen a 3D
97
- """)
98
-
99
- image_prompt = gr.Image(label="Input Image", type="pil")
100
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0)
101
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
102
- generate_btn = gr.Button("Generate 3D Asset")
103
- video_output = gr.Video(label="3D Preview")
104
- model_output = LitModel3D(label="3D Model Viewer")
105
- extract_glb_btn = gr.Button("Export GLB")
106
- download_glb = gr.DownloadButton(label="Download GLB")
107
- output_buf = gr.State()
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  generate_btn.click(
110
- get_seed, inputs=[randomize_seed, seed], outputs=[seed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  ).then(
112
- image_to_3d, inputs=[image_prompt, seed, 7.5, 12, 3.0, 12], outputs=[output_buf, video_output]
 
 
 
 
 
 
 
 
 
113
  ).then(
114
- lambda: gr.Button(interactive=True), outputs=[extract_glb_btn]
 
115
  )
116
-
117
- extract_glb_btn.click(
118
- extract_glb, inputs=[output_buf, 0.95, 1024], outputs=[model_output, download_glb]
 
 
119
  ).then(
120
- lambda: gr.Button(interactive=True), outputs=[download_glb]
 
121
  )
122
 
123
  if __name__ == "__main__":
124
- pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS").cuda()
125
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
 
6
  import numpy as np
7
+ import torch
8
  import imageio
9
  from easydict import EasyDict as edict
10
+ from PIL import Image, ImageOps
11
  from trellis.pipelines import TrellisImageTo3DPipeline
12
  from trellis.representations import Gaussian, MeshExtractResult
13
  from trellis.utils import render_utils, postprocessing_utils
14
+ from typing import List, Tuple, Literal
15
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
16
+ from diffusers import EulerAncestralDiscreteScheduler
17
+ from controlnet_aux import PidiNetDetector, HEDdetector
18
+
19
+ os.environ['SPCONV_ALGO'] = 'native'
20
+
21
+ style_list = [
22
+ {
23
+ "name": "(No style)",
24
+ "prompt": "{prompt}",
25
+ "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
26
+ },
27
+ {
28
+ "name": "Cinematic",
29
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
30
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
31
+ },
32
+ # ... (otros estilos de la lista original)
33
+ ]
34
 
35
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
36
+ STYLE_NAMES = list(styles.keys())
37
+ DEFAULT_STYLE_NAME = "(No style)"
38
  MAX_SEED = np.iinfo(np.int32).max
39
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
40
  os.makedirs(TMP_DIR, exist_ok=True)
41
 
42
+ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
43
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
44
+ return p.replace("{prompt}", positive), n + negative
45
+
46
  def start_session(req: gr.Request):
47
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
48
  os.makedirs(user_dir, exist_ok=True)
 
51
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
52
  shutil.rmtree(user_dir, ignore_errors=True)
53
 
54
+ @spaces.GPU
55
+ def preprocess_image(image: Image.Image,
56
+ prompt: str = "",
57
+ negative_prompt: str = "",
58
+ style_name: str = DEFAULT_STYLE_NAME,
59
+ num_steps: int = 25,
60
+ guidance_scale: float = 5,
61
+ controlnet_conditioning_scale: float = 1.0) -> Image.Image:
62
+
63
+ width, height = image.size
64
+ ratio = np.sqrt(1024 * 1024 / (width * height))
65
+ new_size = (int(width * ratio), int(height * ratio))
66
+ image = image.resize(new_size)
67
+ image = ImageOps.invert(image.convert("L")).convert("RGB")
68
+
69
+ prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
70
+
71
+ output = pipe_control(
72
+ prompt=prompt,
73
+ negative_prompt=negative_prompt,
74
+ image=image,
75
+ num_inference_steps=num_steps,
76
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
77
+ guidance_scale=guidance_scale,
78
+ width=new_size[0],
79
+ height=new_size[1]
80
+ ).images[0]
81
+
82
+ return pipeline.preprocess_image(output)
83
 
84
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
85
  return {
 
99
 
100
  def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
101
  gs = Gaussian(**state['gaussian'])
 
 
 
 
 
 
102
  mesh = edict(
103
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
104
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
 
109
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
110
 
111
  @spaces.GPU
112
+ def image_to_3d(
113
+ image: Image.Image,
114
+ seed: int,
115
+ ss_guidance_strength: float,
116
+ ss_sampling_steps: int,
117
+ slat_guidance_strength: float,
118
+ slat_sampling_steps: int,
119
+ req: gr.Request,
120
+ ) -> Tuple[dict, str]:
121
+
122
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
123
  outputs = pipeline.run(
124
  image,
125
  seed=seed,
126
  formats=["gaussian", "mesh"],
127
  preprocess_image=False,
128
+ sparse_structure_sampler_params={
129
+ "steps": ss_sampling_steps,
130
+ "cfg_strength": ss_guidance_strength,
131
+ },
132
+ slat_sampler_params={
133
+ "steps": slat_sampling_steps,
134
+ "cfg_strength": slat_guidance_strength,
135
+ },
136
  )
137
 
138
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
139
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
140
+ video = [np.concatenate([v, g], axis=1) for v, g in zip(video, video_geo)]
141
+
142
  video_path = os.path.join(user_dir, 'sample.mp4')
143
  imageio.mimsave(video_path, video, fps=15)
144
+
145
+ return pack_state(outputs['gaussian'][0], outputs['mesh'][0]), video_path
 
146
 
147
  @spaces.GPU(duration=90)
148
+ def extract_glb(
149
+ state: dict,
150
+ mesh_simplify: float,
151
+ texture_size: int,
152
+ req: gr.Request,
153
+ ) -> str:
154
+
155
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
156
  gs, mesh = unpack_state(state)
157
+
158
+ glb = postprocessing_utils.to_glb(
159
+ gs,
160
+ mesh,
161
+ simplify=mesh_simplify,
162
+ texture_size=texture_size,
163
+ verbose=False
164
+ )
165
+
166
  glb_path = os.path.join(user_dir, 'sample.glb')
167
  glb.export(glb_path)
168
+ return glb_path
 
169
 
170
  with gr.Blocks() as demo:
171
+ gr.Markdown("# Sketch-to-3D con TRELLIS")
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
+ with gr.Row():
174
+ with gr.Column():
175
+ image_prompt = gr.Image(
176
+ label="Boceto",
177
+ type="pil",
178
+ image_mode="RGBA",
179
+ height=512,
180
+ tool="sketch"
181
+ )
182
+
183
+ with gr.Accordion("Ajustes de GeneraciΓ³n", open=False):
184
+ prompt = gr.Textbox(label="Prompt")
185
+ style = gr.Dropdown(label="Estilo", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
186
+ negative_prompt = gr.Textbox(label="Negative Prompt")
187
+
188
+ with gr.Group():
189
+ gr.Markdown("#### Etapa 1: Estructura")
190
+ ss_guidance_strength = gr.Slider(0.0, 10.0, 7.5, step=0.1, label="Guidance")
191
+ ss_sampling_steps = gr.Slider(1, 50, 12, step=1, label="Pasos")
192
+
193
+ with gr.Group():
194
+ gr.Markdown("#### Etapa 2: Detalle")
195
+ slat_guidance_strength = gr.Slider(0.0, 10.0, 3.0, step=0.1, label="Guidance")
196
+ slat_sampling_steps = gr.Slider(1, 50, 12, step=1, label="Pasos")
197
+
198
+ generate_btn = gr.Button("Generar 3D", variant="primary")
199
+
200
+ with gr.Accordion("Exportar GLB", open=False):
201
+ mesh_simplify = gr.Slider(0.9, 0.98, 0.95, step=0.01, label="Simplificar")
202
+ texture_size = gr.Slider(512, 2048, 1024, step=512, label="TamaΓ±o Textura")
203
+ export_glb_btn = gr.Button("Exportar GLB", interactive=False)
204
+
205
+ with gr.Column():
206
+ video_output = gr.Video(label="Vista 3D", autoplay=True, loop=True, height=300)
207
+ model_viewer = LitModel3D(label="Visor 3D", height=400)
208
+ download_glb = gr.DownloadButton("Descargar GLB", interactive=False)
209
+
210
+ output_state = gr.State()
211
+
212
  generate_btn.click(
213
+ get_seed,
214
+ inputs=[gr.Checkbox(value=True, label="Semilla Aleatoria"), gr.Number(0, visible=False)],
215
+ outputs=[gr.Number(0, visible=False)]
216
+ ).then(
217
+ preprocess_image,
218
+ inputs=[
219
+ image_prompt,
220
+ prompt,
221
+ negative_prompt,
222
+ style,
223
+ gr.Slider(1, 20, 8, step=1, label="Pasos SDXL"),
224
+ gr.Slider(0.1, 10.0, 5.0, step=0.1, label="Guidance SDXL"),
225
+ gr.Slider(0.5, 5.0, 0.85, step=0.01, label="ControlNet Strength")
226
+ ],
227
+ outputs=[gr.Image(label="Imagen Procesada")]
228
  ).then(
229
+ image_to_3d,
230
+ inputs=[
231
+ gr.Image(visible=False),
232
+ gr.Number(0),
233
+ ss_guidance_strength,
234
+ ss_sampling_steps,
235
+ slat_guidance_strength,
236
+ slat_sampling_steps
237
+ ],
238
+ outputs=[output_state, video_output]
239
  ).then(
240
+ lambda: gr.update(interactive=True),
241
+ outputs=[export_glb_btn]
242
  )
243
+
244
+ export_glb_btn.click(
245
+ extract_glb,
246
+ inputs=[output_state, mesh_simplify, texture_size],
247
+ outputs=[model_viewer, download_glb]
248
  ).then(
249
+ lambda: gr.update(interactive=True),
250
+ outputs=[download_glb]
251
  )
252
 
253
  if __name__ == "__main__":
254
+ # Inicializar pipelines
255
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large").cuda()
256
+
257
+ # ControlNet SDXL
258
+ controlnet = ControlNetModel.from_pretrained(
259
+ "xinsir/controlnet-scribble-sdxl-1.0",
260
+ torch_dtype=torch.float16
261
+ ).cuda()
262
+
263
+ vae = AutoencoderKL.from_pretrained(
264
+ "madebyollin/sdxl-vae-fp16-fix",
265
+ torch_dtype=torch.float16
266
+ ).cuda()
267
+
268
+ pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
269
+ "sd-community/sdxl-flash",
270
+ controlnet=controlnet,
271
+ vae=vae,
272
+ torch_dtype=torch.float16
273
+ ).cuda()
274
+
275
+ pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
276
+
277
+ demo.launch()