import os
import shutil
import json
import numpy as np
import gradio as gr
import spaces
import time
from gradio_litmodel3d import LitModel3D
from pipelines.render import run_rendering
from pipelines.recon import run_recon
from pipelines.estimate import run_estimate
from pipelines.sds import run_sds
from typing import *
from omegaconf import OmegaConf
partnet_dir = f'datasets/PartNet'
real_world_dir = f'examples'
multi_joint_dir = f'datasets/multi_joint'
with open(f"configs/partnet.json") as f:
data_info = json.load(f)
cfg = OmegaConf.load('configs/default.yaml')
labels = {
'Cabinet (real-world)': 'cabinet',
'Cabinet2 (real-world)': 'cabinet2',
'Box (100247)': '100247',
'Dishwasher (12614)': '12614',
'Laptop (10270)': '10270',
'Lighter (100309)': '100309',
'Microwave (7320)': '7320',
'Oven (102001)': '102001',
'Refrigerator (11231)': '11231',
'Safe (102301)': '102301',
'Stapler (103111)': '103111',
'StorageFurniture (47183)': '47183',
'Table (20411)': '20411',
'WashingMachine (100283)': '100283',
}
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
def start_session(req: gr.Request):
return req.session_hash
def end_session(req: gr.Request):
pass
def get_available_ids():
"""Get list of available object IDs for the dropdown"""
# TODO: Support multi-joints
total_list = list(labels.keys())
return total_list
def load_renderings_with_id(id_input: str, session_hash: str) -> Tuple[List[str], str]:
selected_id = labels[id_input]
print(f'[Selected Object] {selected_id}')
cfg.train_num_state = 6
rendering_paths = []
if 'cabinet' in selected_id:
base_dir = real_world_dir
elif selected_id in os.listdir(partnet_dir):
base_dir = partnet_dir
else:
base_dir = multi_joint_dir
for i in range(cfg.train_num_state):
rendering_paths.append(f'{base_dir}/{selected_id}/{i:02d}_seg.png')
return rendering_paths, selected_id
def handle_mesh_upload(files):
"""Handle uploaded mesh files and create dropdown choices"""
if not files or len(files) == 0:
return gr.Dropdown(choices=[], value=None), None
choices = [f"Mesh State {i:02d} - {os.path.basename(file.name)}" for i, file in enumerate(files)]
# yield gr.update(), gr.update(None)
return gr.Dropdown(choices=choices, value=choices[0] if choices else None), files[0].name
def switch_mesh_view(selected_file, files):
"""Switch the 3D viewer to show the selected mesh file"""
if not files or not selected_file:
return None
# Extract the file index from the selection
try:
file_index = int(selected_file.split(" ")[2])
if 0 <= file_index < len(files):
return files[file_index].name
except:
# Fallback: try to match by basename
try:
selected_base = os.path.basename(selected_file)
for idx, f in enumerate(files):
if os.path.basename(getattr(f, 'name', '')) == selected_base:
return f.name
except Exception:
pass
return None
def process_and_update_gallery(mesh_files, session_hash):
mesh_paths = [f.name for f in mesh_files] if mesh_files else []
cfg.train_num_state = len(mesh_paths)
rendering_paths = run_rendering(mesh_paths, output_dir=f'{TMP_DIR}/{session_hash}')
return rendering_paths
def format_meta_info_text(info_dict):
"""Format joint meta info into 4 rich-text lines (larger font; bold title for Joint Axis)."""
def to_str(v):
if isinstance(v, np.ndarray):
v = v.tolist()
if isinstance(v, (list, dict, tuple)):
try:
return json.dumps(v)
except Exception:
return str(v)
return "" if v is None else str(v)
axis_val = info_dict.get("joint_axis", "")
pos_val = info_dict.get("joint_position", "")
scale_val = info_dict.get("joint_scale", "")
qpos_val = info_dict.get("joint_qpos", "")
axis_str = to_str(axis_val)
pos_str = to_str(pos_val)
scale_str = to_str(-scale_val)
qpos_str = to_str(qpos_val)
# Use inline HTML for size and emphasis; rendered by gr.Markdown
return (
"
"
f"
Joint Axis: {axis_str}
"
f"
Joint Pivot: {pos_str}
"
f"
Max Scale: {scale_str}
"
f"
qpos: {qpos_str}
"
"
"
)
def image_gallery_state_change(p_state):
base_dir = f'outputs/{p_state.get("sel_id")}_{p_state.get("session_hash")}'
if p_state.get("step") == "running":
while True:
with open(f'{base_dir}/state_list.json', 'r') as f:
state_list = json.load(f)
if state_list[0] == "done":
break
time.sleep(0.5)
return gr.update(), gr.update(), \
gr.update(), gr.update(), \
gr.update(), gr.update(), \
gr.update(), gr.update(), gr.update(), \
gr.update(), gr.update(), gr.update()
if p_state.get("step") == "done":
return p_state.get("gallery", gr.update()), p_state.get("sel_id", gr.update()), \
p_state.get("cached_path", gr.update())[0], \
p_state.get("cached_path", gr.update())[1], \
p_state.get("cached_path", gr.update())[2], \
gr.update(value=None), gr.update(value=None), \
gr.update(value=None), gr.update(value=None), \
gr.update(value=None), gr.update(value=None), gr.update(value=None)
return gr.update(), gr.update(), \
gr.update(), gr.update(), \
gr.update(), gr.update(), \
gr.update(), gr.update(), gr.update(), \
gr.update(), gr.update(), gr.update()
def recon_state_change(p_state):
base_dir = f'outputs/{p_state.get("sel_id")}_{p_state.get("session_hash")}'
if p_state.get("step") == "running":
while True:
with open(f'{base_dir}/state_list.json', 'r') as f:
state_list = json.load(f)
if state_list[1] == "done":
break
time.sleep(0.5)
return gr.update(), gr.update()
if p_state.get("step") == "done":
return p_state.get("vox", gr.update()), p_state.get("recon", gr.update())
return gr.update(), gr.update()
def estimate_state_change(p_state):
base_dir = f'outputs/{p_state.get("sel_id")}_{p_state.get("session_hash")}'
if p_state.get("step") == "running":
while True:
with open(f'{base_dir}/state_list.json', 'r') as f:
state_list = json.load(f)
if state_list[2] == "done":
break
time.sleep(0.5)
return gr.update(), gr.update()
if p_state.get("step") == "done":
return p_state.get("match", gr.update()), p_state.get("html", gr.update())
return gr.update(), gr.update()
def sds_state_change(p_state):
base_dir = f'outputs/{p_state.get("sel_id")}_{p_state.get("session_hash")}'
if p_state.get("step") == "running":
while True:
with open(f'{base_dir}/state_list.json', 'r') as f:
state_list = json.load(f)
if state_list[3] == "done":
break
time.sleep(0.5)
return gr.update(), gr.update(), gr.update()
if p_state.get("step") == "done":
return p_state.get("full", gr.update()), p_state.get("fixed", gr.update()), p_state.get("art", gr.update())
return gr.update(), gr.update(), gr.update()
def save_gallery_to_renderings(gallery_value, selected_id, session_hash):
"""Persist gallery images into outputs//renderings with SDS-expected names.
Ensures there are cfg.train_num_state frames by padding with the last image if needed.
Also writes the required 'rendering_pure_joint_00_state_{T-1}.png' using the last frame.
"""
try:
image_paths = []
if gallery_value:
for item in gallery_value:
path = item[0] if isinstance(item, (list, tuple)) and len(item) > 0 else item
if isinstance(path, str) and os.path.exists(path):
image_paths.append(path)
if len(image_paths) == 0:
return
cfg.train_num_state = len(image_paths)
base_dir = f'outputs/{selected_id}_{session_hash}'
if "uploaded" in base_dir:
os.system(f"rm -rf {base_dir}/*")
rendering_dir = f'{base_dir}/renderings'
os.makedirs(rendering_dir, exist_ok=True)
# Normalize count to T frames
for i in range(cfg.train_num_state):
src = image_paths[i] if i < len(image_paths) else image_paths[-1]
dst = f'{rendering_dir}/rendering_joint_00_state_{i:02d}.png'
shutil.copy(src, dst)
# Write the 'pure' image as the last uploaded (or last padded) frame
if os.path.exists(image_paths[-1].replace('_seg.png', '_pure.png')):
pure_src = image_paths[-1].replace('_seg.png', '_pure.png')
else:
pure_src = image_paths[-1]
pure_dst = f'{rendering_dir}/rendering_pure_joint_00_state_{(cfg.train_num_state - 1):02d}.png'
shutil.copy(pure_src, pure_dst)
except Exception as e:
print(f"[Gallery->Renderings] Failed to persist images: {e}")
return
@spaces.GPU(duration=200)
def run_recon_trellis(image_paths, output_dir):
return run_recon(image_paths, output_dir=output_dir, app=True)
def recon_trellis_meshes(input_image_gallery, output_dir):
# Carefully extract file paths from Gradio Gallery elements, which may be dicts, file objects, or strings.
image_paths = [img[0] for img in input_image_gallery]
glb_paths, rendering_dir, recon_voxel_paths = run_recon_trellis(image_paths, output_dir=output_dir)
recon_mesh_paths = run_rendering(glb_paths, rendering_dir, recon=True)
return recon_voxel_paths, recon_mesh_paths
@spaces.GPU(duration=120)
def estimate_initial_joints(input_image_gallery, joint_type, output_dir):
image_paths = [img[0] for img in input_image_gallery]
print(image_paths)
matching_examples, info_dict = run_estimate(image_paths, output_dir=output_dir, cfg=cfg, joint_type=joint_type)
return matching_examples, info_dict
@spaces.GPU(duration=800)
def articulated_generation_sds(selected_id, session_hash):
if 'cabinet' in selected_id:
input_dir = real_world_dir
elif selected_id in os.listdir(partnet_dir):
input_dir = partnet_dir
else:
input_dir = multi_joint_dir
base_dir = f'outputs/{selected_id}_{session_hash}'
rendering_dir = f'{base_dir}/renderings'
os.makedirs(rendering_dir, exist_ok=True)
if os.path.exists(f'{input_dir}/{selected_id}/05_seg.png'):
for i in range(6):
shutil.copy(f'{input_dir}/{selected_id}/{i:02d}_seg.png', f'{rendering_dir}/rendering_joint_00_state_{i:02d}.png')
shutil.copy(f'{input_dir}/{selected_id}/05_pure.png', f'{rendering_dir}/rendering_pure_joint_00_state_05.png')
full_mesh, fixed_part, articulated_part = run_sds(base_dir, 1, cfg)
return full_mesh, fixed_part, articulated_part
def pipeline(id_value, mesh_files, uploaded_image_gallery, joint_type_in, session_hash, run_idx):
"""Server-side pipeline that yields state only; UI updates are applied via a mapper."""
# 1) Prepare inputs
if id_value is not None:
sel_id = f"{labels[id_value]}_{run_idx:02d}"
elif mesh_files is not None:
sel_id = f"mesh_uploading_{run_idx:02d}"
elif uploaded_image_gallery is not None:
sel_id = f"image_uploading_{run_idx:02d}"
base_dir = f'outputs/{sel_id}_{session_hash}'
os.makedirs(base_dir, exist_ok=True)
state_list_output_path = f'{base_dir}/state_list.json'
state_list = ["running", "prepare", "prepare", "prepare"]
with open(state_list_output_path, 'w') as f:
json.dump(state_list, f)
yield (
{"step": state_list[0], "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[1], "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[2], "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[3], "sel_id": sel_id, "session_hash": session_hash}, run_idx, \
f"Preparing input images...
(~3 min)
"
)
if id_value is not None:
rendering_paths, _ = load_renderings_with_id(id_value, session_hash)
elif mesh_files is not None:
rendering_paths = process_and_update_gallery(mesh_files, session_hash)
elif uploaded_image_gallery is not None:
rendering_paths = [img[0] for img in uploaded_image_gallery]
gallery = [[p, None] for p in rendering_paths]
cached_path = [f'datasets/cached_results/{labels[id_value]}/full.glb',
f'datasets/cached_results/{labels[id_value]}/fixed.glb',
f'datasets/cached_results/{labels[id_value]}/art.glb']
save_gallery_to_renderings(gallery, sel_id, session_hash)
state_list = ["done", "running", "prepare", "prepare"]
with open(state_list_output_path, 'w') as f:
json.dump(state_list, f)
yield (
{"step": state_list[0], "gallery": gallery, "sel_id": sel_id, "cached_path": cached_path, "session_hash": session_hash},
{"step": state_list[1], "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[2], "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[3], "sel_id": sel_id, "session_hash": session_hash}, run_idx, \
f"Stage 1: Initial reconstruction... (~3 min)
"
)
# 3) Reconstruction
vox, recon = recon_trellis_meshes(gallery, f"{base_dir}/recon")
state_list = ["done", "done", "running", "running"]
with open(state_list_output_path, 'w') as f:
json.dump(state_list, f)
yield (
{"step": state_list[0], "gallery": gallery, "sel_id": sel_id, "cached_path": cached_path, "session_hash": session_hash},
{"step": state_list[1], "vox": vox, "recon": recon, "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[2], "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[3], "sel_id": sel_id, "session_hash": session_hash}, run_idx, \
f"Stage 2: Joint initialization... (~1 min)
"
)
# 4) Joint estimation
if joint_type_in is not None:
joint_type = joint_type_in
else:
if 'cabinet' in sel_id:
joint_type = 'revolute' if 'cabinet2' in sel_id else 'prismatic'
else:
joint_type = 'prismatic' if (sel_id in data_info['prismatic']['obj_ids']) else 'revolute'
print(f"[Joint Type] {joint_type}")
match, meta = estimate_initial_joints(gallery, joint_type, base_dir)
# 5) Format meta info
html = format_meta_info_text(meta)
state_list = ["done", "done", "done", "running"]
with open(state_list_output_path, 'w') as f:
json.dump(state_list, f)
yield (
{"step": state_list[0], "gallery": gallery, "sel_id": sel_id, "cached_path": cached_path, "session_hash": session_hash},
{"step": state_list[1], "vox": vox, "recon": recon, "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[2], "match": match, "html": html, "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[3], "sel_id": sel_id, "session_hash": session_hash}, run_idx, \
f"Stage 3: SDS optimization... (~10 min)
"
)
# 6) SDS refinement
full, fixed, art = articulated_generation_sds(sel_id, session_hash)
state_list = ["done", "done", "done", "done"]
with open(state_list_output_path, 'w') as f:
json.dump(state_list, f)
yield (
{"step": state_list[0], "gallery": gallery, "sel_id": sel_id, "cached_path": cached_path,"session_hash": session_hash},
{"step": state_list[1], "vox": vox, "recon": recon, "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[2], "match": match, "html": html, "sel_id": sel_id, "session_hash": session_hash},
{"step": state_list[3], "full": full, "fixed": fixed, "art": art, "sel_id": sel_id, "session_hash": session_hash},
run_idx+1, \
f"Done!
"
)
with gr.Blocks(delete_cache=(600, 600), css="""
.gallery-container {
overflow-x: auto !important;
overflow-y: auto !important;
scrollbar-width: thin;
scrollbar-color: #888 #f1f1f1;
}
/* Ensure inner wrappers can scroll vertically */
.gallery-container > div {
max-height: 100% !important;
overflow-y: auto !important;
}
.gallery-container .grid,
.gallery-container .grid-wrap,
.gallery-container .thumbnail-grid,
.gallery-container .gallery,
.gallery-container .container {
max-height: 100% !important;
overflow-y: auto !important;
}
.gallery-container::-webkit-scrollbar {
height: 8px;
width: 8px;
}
.gallery-container::-webkit-scrollbar-track {
background: #f1f1f1;
border-radius: 4px;
}
.gallery-container::-webkit-scrollbar-thumb {
background: #888;
border-radius: 4px;
}
.gallery-container::-webkit-scrollbar-thumb:hover {
background: #555;
}
/* Ensure images are fully visible in galleries (no cropping) */
.gallery-container img {
object-fit: contain !important;
object-position: center center;
}
""") as demo:
gr.Markdown("""
## FreeArt3D Demo
### You can use the following three methods to generate articulated 3D models.
1. Choose an example object (from PartNet-Mobility or real-world).
2. Upload multiple mesh files (e.g., .obj, .ply, .stl, .glb, .gltf) at different articulation states (images will be rendered automatically).
3. Upload multiple input images with **a segmented object on the disk** at **different articulation states** (you can use a fixed view but have to contain ** at least 2 states**).
### The total generation time is *~10 min*. Be patient :)
### Currently it's difficult to run the full pipeline with ZeroGPU due to the long running time. We are applying for an exclusive GPU. Stay Tuned! (We add the cached results now for example objects for you to have a quick check)
""")
with gr.Row(equal_height=True):
with gr.Tab("Load by Example Object"):
with gr.Row():
id_input = gr.Dropdown(
choices=get_available_ids(),
value='Cabinet (real-world)',
label="Example Object",
info='Select a real-world object or an object from the PartNet-Mobility dataset. (e.g., Box (100214))',
)
with gr.Row():
id_input_btn = gr.Button("Generate Articulated Object", variant="primary")
with gr.Tab("Upload Mesh Files"):
with gr.Row(equal_height=True):
with gr.Column():
mesh_files = gr.File(label="Upload Mesh Files", file_count="multiple", file_types=[".obj", ".ply", ".stl", ".glb", ".gltf"], height=300)
mesh_selector = gr.Dropdown(label="Select Mesh to View", choices=[], value=None, interactive=True)
with gr.Column():
mesh_viewer = LitModel3D(label="Selected Mesh", exposure=10.0, height=300)
with gr.Row():
joint_type_in_mesh = gr.Dropdown(label="joint_type", choices=["revolute", "prismatic"], value=None)
with gr.Row():
mesh_input_btn = gr.Button("Generate Articulated Object", variant="primary")
with gr.Tab("Upload Image Files"):
with gr.Row(equal_height=True):
uploaded_image_gallery = gr.Gallery(label="Input Images", show_label=True, elem_id="gallery", height=300, columns=10, allow_preview=True, elem_classes=["gallery-container"])
with gr.Row():
joint_type_in_image = gr.Dropdown(label="joint_type", choices=["revolute", "prismatic"], value=None)
with gr.Row():
image_input_btn = gr.Button("Generate Articulated Object", variant="primary")
with gr.Row():
with gr.Accordion("Current Stage", open=True):
current_stage = gr.HTML(value="Waiting for input...
")
with gr.Row(equal_height=True):
with gr.Column():
input_image_gallery = gr.Gallery(label="Input Images", show_label=True, elem_id="gallery", height=300, columns=3, allow_preview=True, interactive=False, elem_classes=["gallery-container"])
with gr.Column():
voxel_gallery = gr.Gallery(label="Reconstructed Voxels", show_label=True, elem_id="gallery", height=360, columns=3, allow_preview=True, interactive=False, elem_classes=["gallery-container"])
with gr.Column():
recon_mesh_gallery = gr.Gallery(label="Reconstructed Meshes", show_label=True, elem_id="gallery", height=360, columns=3, allow_preview=True, interactive=False, elem_classes=["gallery-container"])
with gr.Row():
matching_examples = gr.Gallery(label="Correspondences", show_label=True, elem_id="gallery", height=400, interactive=False, columns=5, allow_preview=True, elem_classes=["gallery-container"])
with gr.Row():
meta_info = gr.State(
value={
"joint_axis": "",
"joint_position": "",
"joint_scale": "",
"joint_qpos": ""
}
)
with gr.Accordion("Joint Information", open=True):
joint_info_html = gr.HTML(value="")
with gr.Row():
with gr.Column():
full_mesh = LitModel3D(label="Full Mesh", exposure=10.0, height=300)
with gr.Column():
fixed_part = LitModel3D(label="Fixed Part", exposure=10.0, height=300)
with gr.Column():
articulated_part = LitModel3D(label="Articulated Part", exposure=10.0, height=300)
with gr.Row():
with gr.Column():
full_mesh_cached = LitModel3D(label="Full Mesh (Cached)", exposure=10.0, height=300)
with gr.Column():
fixed_part_cached = LitModel3D(label="Fixed Part (Cached)", exposure=10.0, height=300)
with gr.Column():
articulated_part_cached = LitModel3D(label="Articulated Part (Cached)", exposure=10.0, height=300)
selected_id_state = gr.State(value="")
session_hash_state = gr.State(value="")
run_idx_state = gr.State(value=0)
image_gallery_state = gr.State(value="prepare")
recon_state = gr.State(value="prepare")
estimate_state = gr.State(value="prepare")
sds_state = gr.State(value="prepare")
none_val = gr.State(value=None)
# Handlers
demo.load(start_session, outputs=session_hash_state)
demo.unload(end_session)
id_input_btn.click(
start_session,
outputs=session_hash_state,
).then(
pipeline,
inputs=[id_input, none_val, none_val, none_val, session_hash_state, run_idx_state],
outputs=[image_gallery_state, recon_state, estimate_state, sds_state, run_idx_state, current_stage],
)
mesh_input_btn.click(
start_session,
outputs=session_hash_state,
).then(
pipeline,
inputs=[none_val, mesh_files, none_val, joint_type_in_mesh, session_hash_state, run_idx_state],
outputs=[image_gallery_state, recon_state, estimate_state, sds_state, run_idx_state, current_stage],
)
image_input_btn.click(
start_session,
outputs=session_hash_state,
).then(
pipeline,
inputs=[none_val, none_val, uploaded_image_gallery, joint_type_in_image, session_hash_state, run_idx_state],
outputs=[image_gallery_state, recon_state, estimate_state, sds_state, run_idx_state, current_stage],
)
# Watch state and update only relevant components per step
image_gallery_state.change(
image_gallery_state_change,
inputs=image_gallery_state,
outputs=[input_image_gallery, selected_id_state, full_mesh_cached, fixed_part_cached, articulated_part_cached,
voxel_gallery, recon_mesh_gallery,
matching_examples, joint_info_html,
full_mesh, fixed_part, articulated_part],
)
recon_state.change(
recon_state_change,
inputs=recon_state,
outputs=[voxel_gallery, recon_mesh_gallery],
)
estimate_state.change(
estimate_state_change,
inputs=estimate_state,
outputs=[matching_examples, joint_info_html],
)
sds_state.change(
sds_state_change,
inputs=sds_state,
outputs=[full_mesh, fixed_part, articulated_part],
)
# Handle mesh file uploads
mesh_files.change(
handle_mesh_upload,
inputs=mesh_files,
outputs=[mesh_selector, mesh_viewer]
)
# Handle dropdown selection for mesh view
mesh_selector.change(
switch_mesh_view,
inputs=[mesh_selector, mesh_files],
outputs=mesh_viewer
)
# Ensure downstream uses the generic 'uploaded' namespace for output paths
uploaded_image_gallery.change(
lambda *_: "uploaded",
inputs=None,
outputs=selected_id_state
)
# Launch the Gradio app
if __name__ == "__main__":
demo.launch()