cavargas10 commited on
Commit
945c4e6
·
verified ·
1 Parent(s): 6b439c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +179 -114
app.py CHANGED
@@ -3,9 +3,6 @@ import spaces
3
  from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
6
- os.environ['SPCONV_ALGO'] = 'native'
7
- from typing import *
8
- import torch
9
  import numpy as np
10
  import imageio
11
  from easydict import EasyDict as edict
@@ -13,59 +10,88 @@ from PIL import Image, ImageOps
13
  from trellis.pipelines import TrellisImageTo3DPipeline
14
  from trellis.representations import Gaussian, MeshExtractResult
15
  from trellis.utils import render_utils, postprocessing_utils
 
16
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
17
- from diffusers import EulerAncestralDiscreteScheduler
18
- from pathlib import Path
 
 
 
 
19
 
 
20
  style_list = [
21
- {"name": "(No style)", "prompt": "{prompt}", "negative_prompt": ""},
22
- {"name": "Cinematic", "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"},
23
- {"name": "3D Model", "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting", "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting"},
 
 
 
 
 
 
 
 
24
  ]
25
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
26
  STYLE_NAMES = list(styles.keys())
27
  DEFAULT_STYLE_NAME = "(No style)"
28
- MAX_SEED = np.iinfo(np.int32).max
29
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
30
- os.makedirs(TMP_DIR, exist_ok=True)
31
 
32
- def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
33
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
34
- return p.replace("{prompt}", positive), n + negative
35
 
36
  def start_session(req: gr.Request):
37
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
38
  os.makedirs(user_dir, exist_ok=True)
39
-
40
  def end_session(req: gr.Request):
41
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
42
  shutil.rmtree(user_dir)
43
 
44
- @spaces.GPU
45
- def preprocess_image(image: Image.Image,
46
- prompt: str,
47
- negative_prompt: str,
48
- style_name: str,
49
- num_steps: int,
50
- guidance_scale: float,
51
- controlnet_conditioning_scale: float) -> Image.Image:
52
- width, height = image.size
53
- ratio = np.sqrt(1024 * 1024 / (width * height))
54
- new_size = (int(width * ratio), int(height * ratio))
55
- image = image.resize(new_size)
56
-
57
- prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- output = pipe_control(
 
60
  prompt=prompt,
61
  negative_prompt=negative_prompt,
62
- image=image,
63
  num_inference_steps=num_steps,
64
  controlnet_conditioning_scale=controlnet_conditioning_scale,
65
  guidance_scale=guidance_scale,
 
 
66
  ).images[0]
67
 
68
- return pipeline.preprocess_image(output)
69
 
70
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
71
  return {
@@ -83,7 +109,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
83
  },
84
  }
85
 
86
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
87
  gs = Gaussian(
88
  aabb=state['gaussian']['aabb'],
89
  sh_degree=state['gaussian']['sh_degree'],
@@ -105,6 +131,9 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
105
 
106
  return gs, mesh
107
 
 
 
 
108
  @spaces.GPU
109
  def image_to_3d(
110
  image: Image.Image,
@@ -113,9 +142,10 @@ def image_to_3d(
113
  ss_sampling_steps: int,
114
  slat_guidance_strength: float,
115
  slat_sampling_steps: int,
116
- req: gr.Request,
117
  ) -> Tuple[dict, str]:
118
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
119
  outputs = pipeline.run(
120
  image,
121
  seed=seed,
@@ -129,14 +159,18 @@ def image_to_3d(
129
  "cfg_strength": slat_guidance_strength,
130
  },
131
  )
132
-
 
133
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
134
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
135
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
136
  video_path = os.path.join(user_dir, 'sample.mp4')
137
- imageio.mimsave(video_path, video, fps=15)
 
 
138
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
139
  torch.cuda.empty_cache()
 
140
  return state, video_path
141
 
142
  @spaces.GPU(duration=90)
@@ -144,103 +178,134 @@ def extract_glb(
144
  state: dict,
145
  mesh_simplify: float,
146
  texture_size: int,
147
- req: gr.Request,
148
- ) -> Tuple[str, str]:
149
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
150
  gs, mesh = unpack_state(state)
151
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
152
- glb_path = os.path.join(user_dir, 'sample.glb')
 
 
 
 
 
 
153
  glb.export(glb_path)
154
  torch.cuda.empty_cache()
155
- return glb_path, glb_path
156
 
157
  with gr.Blocks() as demo:
158
- gr.Markdown("# Sketch to 3D with TRELLIS")
 
 
 
 
159
  with gr.Row():
160
- with gr.Column():
161
- image_prompt = gr.Image(label="Sketch Input", type="pil", image_mode="RGBA", height=512)
162
- prompt = gr.Textbox(label="Prompt", placeholder="Describe tu modelo 3D")
163
- style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
 
 
 
 
 
164
 
165
- with gr.Accordion("Generation Settings", open=False):
166
- num_steps = gr.Slider(1, 20, label="Steps", value=8, step=1)
167
- guidance_scale = gr.Slider(0.1, 10.0, label="Guidance Scale", value=5.0, step=0.1)
168
- controlnet_scale = gr.Slider(0.5, 5.0, label="ControlNet Scale", value=0.85, step=0.01)
169
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
170
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
 
171
 
172
- with gr.Group():
173
- gr.Markdown("#### Stage 1: Structure")
174
- ss_guidance = gr.Slider(0.0, 10.0, label="Guidance", value=7.5, step=0.1)
175
- ss_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
176
 
177
- with gr.Group():
178
- gr.Markdown("#### Stage 2: Detail")
179
- slat_guidance = gr.Slider(0.0, 10.0, label="Guidance", value=3.0, step=0.1)
180
- slat_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
181
-
182
- generate_btn = gr.Button("Generate 3D Model", variant="primary")
183
-
184
- with gr.Accordion("Export Settings", open=False):
185
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify Mesh", value=0.95, step=0.01)
186
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
187
-
188
- extract_btn = gr.Button("Export GLB", interactive=False)
189
-
190
- with gr.Column():
191
- video_output = gr.Video(label="3D Preview", autoplay=True, loop=True, height=300)
192
- model_viewer = LitModel3D(label="3D Model Viewer", height=400)
193
- download_btn = gr.DownloadButton("Download GLB", interactive=False)
194
 
195
- output_state = gr.State()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- demo.load(start_session)
198
- demo.unload(end_session)
199
-
 
 
200
  generate_btn.click(
201
- lambda rand, s: np.random.randint(0, MAX_SEED) if rand else s,
 
 
 
202
  inputs=[randomize_seed, seed],
203
- outputs=[seed],
204
  ).then(
205
- preprocess_image,
206
- inputs=[image_prompt, prompt, gr.Textbox(), style, num_steps, guidance_scale, controlnet_scale],
207
- outputs=[image_prompt],
 
 
 
 
 
 
 
208
  ).then(
209
- image_to_3d,
210
- inputs=[image_prompt, seed, ss_guidance, ss_steps, slat_guidance, slat_steps],
211
- outputs=[output_state, video_output],
 
 
 
 
 
 
 
212
  ).then(
213
- lambda: gr.Button(interactive=True),
214
- outputs=[extract_btn],
 
 
 
 
 
215
  )
216
 
217
- extract_btn.click(
218
- extract_glb,
219
- inputs=[output_buf, mesh_simplify, texture_size],
220
- outputs=[model_output, download_glb],
221
- ).then(
222
- lambda: gr.Button(interactive=True),
223
- outputs=[download_glb],
224
  )
225
 
226
  if __name__ == "__main__":
227
- # TRELLIS pipeline
228
- pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
229
- pipeline.cuda()
230
-
231
- # ControlNet y SDXL
232
- controlnet = ControlNetModel.from_pretrained(
233
- "xinsir/controlnet-scribble-sdxl-1.0",
234
- torch_dtype=torch.float16
235
- )
236
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
237
- pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
238
- "sd-community/sdxl-flash",
239
- controlnet=controlnet,
240
- vae=vae,
241
- torch_dtype=torch.float16,
242
- )
243
- pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
244
- pipe_control.to("cuda")
245
-
246
- demo.launch()
 
3
  from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
 
 
 
6
  import numpy as np
7
  import imageio
8
  from easydict import EasyDict as edict
 
10
  from trellis.pipelines import TrellisImageTo3DPipeline
11
  from trellis.representations import Gaussian, MeshExtractResult
12
  from trellis.utils import render_utils, postprocessing_utils
13
+ import torch
14
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
15
+ from controlnet_aux import PidiNetDetector, HEDdetector
16
+
17
+ os.environ['SPCONV_ALGO'] = 'native'
18
+ MAX_SEED = np.iinfo(np.int32).max
19
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
20
+ os.makedirs(TMP_DIR, exist_ok=True)
21
 
22
+ # Configuración de estilos
23
  style_list = [
24
+ {
25
+ "name": "(No style)",
26
+ "prompt": "{prompt}",
27
+ "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
28
+ },
29
+ {
30
+ "name": "3D Model",
31
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
32
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
33
+ },
34
+ # ... (otros estilos)
35
  ]
36
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
37
  STYLE_NAMES = list(styles.keys())
38
  DEFAULT_STYLE_NAME = "(No style)"
 
 
 
39
 
40
+ def apply_style(style_name: str, prompt: str, negative: str = "") -> tuple:
41
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
42
+ return p.replace("{prompt}", prompt), n + negative
43
 
44
  def start_session(req: gr.Request):
45
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
46
  os.makedirs(user_dir, exist_ok=True)
47
+
48
  def end_session(req: gr.Request):
49
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
50
  shutil.rmtree(user_dir)
51
 
52
+ # Inicialización de ControlNet
53
+ controlnet = ControlNetModel.from_pretrained(
54
+ "xinsir/controlnet-scribble-sdxl-1.0",
55
+ torch_dtype=torch.float16
56
+ )
57
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
58
+ pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
59
+ "sd-community/sdxl-flash",
60
+ controlnet=controlnet,
61
+ vae=vae,
62
+ torch_dtype=torch.float16,
63
+ )
64
+ pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
65
+ pipe_control.to("cuda")
66
+
67
+ # Inicialización de TRELLIS
68
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
69
+ pipeline.cuda()
70
+
71
+ def preprocess_image(
72
+ image: Image.Image,
73
+ prompt: str,
74
+ style_name: str,
75
+ num_steps: int = 20,
76
+ guidance_scale: float = 5,
77
+ controlnet_conditioning_scale: float = 0.85
78
+ ) -> Image.Image:
79
+ # Aplicar estilo
80
+ prompt, negative_prompt = apply_style(style_name, prompt)
81
 
82
+ # Procesar con ControlNet
83
+ processed_image = pipe_control(
84
  prompt=prompt,
85
  negative_prompt=negative_prompt,
86
+ image=image.convert("RGB"),
87
  num_inference_steps=num_steps,
88
  controlnet_conditioning_scale=controlnet_conditioning_scale,
89
  guidance_scale=guidance_scale,
90
+ width=512,
91
+ height=512
92
  ).images[0]
93
 
94
+ return processed_image
95
 
96
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
97
  return {
 
109
  },
110
  }
111
 
112
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
113
  gs = Gaussian(
114
  aabb=state['gaussian']['aabb'],
115
  sh_degree=state['gaussian']['sh_degree'],
 
131
 
132
  return gs, mesh
133
 
134
+ def get_seed(randomize_seed: bool, seed: int) -> int:
135
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
136
+
137
  @spaces.GPU
138
  def image_to_3d(
139
  image: Image.Image,
 
142
  ss_sampling_steps: int,
143
  slat_guidance_strength: float,
144
  slat_sampling_steps: int,
145
+ req: gr.Request
146
  ) -> Tuple[dict, str]:
147
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
148
+
149
  outputs = pipeline.run(
150
  image,
151
  seed=seed,
 
159
  "cfg_strength": slat_guidance_strength,
160
  },
161
  )
162
+
163
+ # Renderizar video
164
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
165
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
166
+ combined_video = [np.concatenate([frame, geo], axis=1) for frame, geo in zip(video, video_geo)]
167
  video_path = os.path.join(user_dir, 'sample.mp4')
168
+ imageio.mimsave(video_path, combined_video, fps=15)
169
+
170
+ # Empaquetar estado
171
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
172
  torch.cuda.empty_cache()
173
+
174
  return state, video_path
175
 
176
  @spaces.GPU(duration=90)
 
178
  state: dict,
179
  mesh_simplify: float,
180
  texture_size: int,
181
+ req: gr.Request
182
+ ) -> str:
183
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
184
  gs, mesh = unpack_state(state)
185
+ glb = postprocessing_utils.to_glb(
186
+ gs,
187
+ mesh,
188
+ simplify=mesh_simplify,
189
+ texture_size=texture_size,
190
+ verbose=False
191
+ )
192
+ glb_path = os.path.join(user_dir, 'model.glb')
193
  glb.export(glb_path)
194
  torch.cuda.empty_cache()
195
+ return glb_path
196
 
197
  with gr.Blocks() as demo:
198
+ gr.Markdown("""
199
+ # Conversor de Bocetos a 3D
200
+ ### Carga un boceto, ajusta parámetros y genera un modelo 3D
201
+ """)
202
+
203
  with gr.Row():
204
+ with gr.Column(scale=1):
205
+ # Entrada de boceto
206
+ image_prompt = gr.Image(
207
+ label="Boceto",
208
+ type="pil",
209
+ tool="sketch",
210
+ image_mode="RGBA",
211
+ height=512
212
+ )
213
 
214
+ # Parámetros
215
+ with gr.Accordion("Configuración", open=True):
216
+ prompt = gr.Textbox(label="Prompt", value="3D model")
217
+ style = gr.Dropdown(
218
+ label="Estilo",
219
+ choices=STYLE_NAMES,
220
+ value=DEFAULT_STYLE_NAME
221
+ )
222
 
223
+ with gr.Tab("ControlNet"):
224
+ num_steps = gr.Slider(5, 30, value=20, label="Pasos")
225
+ guidance_scale = gr.Slider(0.1, 10, value=5, label="Guidance Scale")
226
+ controlnet_scale = gr.Slider(0.5, 1.5, value=0.85, label="ControlNet Scale")
227
 
228
+ with gr.Tab("Generación 3D"):
229
+ seed = gr.Slider(0, MAX_SEED, value=42, label="Seed")
230
+ randomize_seed = gr.Checkbox(True, label="Randomizar Seed")
231
+
232
+ with gr.Group():
233
+ gr.Markdown("Estructura (Stage 1)")
234
+ ss_guidance = gr.Slider(0, 10, value=7.5, label="Guidance Strength")
235
+ ss_steps = gr.Slider(5, 20, value=12, label="Pasos")
236
+
237
+ with gr.Group():
238
+ gr.Markdown("Detalles (Stage 2)")
239
+ slat_guidance = gr.Slider(0, 10, value=3.0, label="Guidance Strength")
240
+ slat_steps = gr.Slider(5, 20, value=12, label="Pasos")
241
+
242
+ generate_btn = gr.Button("Generar 3D", variant="primary")
 
 
243
 
244
+ with gr.Column(scale=2):
245
+ video_output = gr.Video(
246
+ label="Vista 3D",
247
+ height=400,
248
+ interactive=False
249
+ )
250
+ model_output = LitModel3D(
251
+ label="Modelo 3D",
252
+ height=300,
253
+ exposure=10.0
254
+ )
255
+ download_btn = gr.Download(
256
+ label="Descargar GLB",
257
+ interactive=False
258
+ )
259
 
260
+ # Estado interno
261
+ output_buf = gr.State()
262
+ is_processing = gr.State(False)
263
+
264
+ # Eventos
265
  generate_btn.click(
266
+ fn=lambda: gr.update(interactive=False),
267
+ outputs=[generate_btn]
268
+ ).then(
269
+ fn=get_seed,
270
  inputs=[randomize_seed, seed],
271
+ outputs=[seed]
272
  ).then(
273
+ fn=preprocess_image,
274
+ inputs=[
275
+ image_prompt,
276
+ prompt,
277
+ style,
278
+ num_steps,
279
+ guidance_scale,
280
+ controlnet_scale
281
+ ],
282
+ outputs=image_prompt
283
  ).then(
284
+ fn=image_to_3d,
285
+ inputs=[
286
+ image_prompt,
287
+ seed,
288
+ ss_guidance,
289
+ ss_steps,
290
+ slat_guidance,
291
+ slat_steps
292
+ ],
293
+ outputs=[output_buf, video_output]
294
  ).then(
295
+ fn=lambda state: extract_glb(state, 0.95, 1024),
296
+ inputs=[output_buf],
297
+ outputs=download_btn,
298
+ show_progress=True
299
+ ).then(
300
+ fn=lambda: gr.update(interactive=True),
301
+ outputs=[generate_btn]
302
  )
303
 
304
+ # Eventos de limpieza
305
+ video_output.clear(
306
+ fn=lambda: gr.update(interactive=False),
307
+ outputs=[download_btn]
 
 
 
308
  )
309
 
310
  if __name__ == "__main__":
311
+ demo.queue(max_size=20).launch(share=True)