quantshah commited on
Commit
bbbd567
·
verified ·
1 Parent(s): 659c708

Add full 3D-mesh demo (demo_3d.py) + run instructions

Browse files
Files changed (1) hide show
  1. demo_3d.py +101 -0
demo_3d.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """End-to-end 3D human-mesh demo using the Embedl INT8 backbone.
3
+
4
+ Our quantized DINOv3 backbone (this repo) provides the image features; the
5
+ upstream SAM-3D-Body decoder + MHR mesh head turn them into a 3D body mesh.
6
+ This script runs the full pipeline and renders the result with matplotlib
7
+ (no OpenGL needed).
8
+
9
+ Prerequisites
10
+ -------------
11
+ # 1. upstream pipeline (you must have accepted the gated upstream license)
12
+ git clone https://github.com/facebookresearch/sam-3d-body
13
+ pip install -e sam-3d-body # + its deps (see its INSTALL.md)
14
+ pip install torch matplotlib pillow numpy imageio huggingface_hub
15
+ # 2. gated checkpoint (facebook/sam-3d-body-dinov3): model.ckpt, model_config.yaml,
16
+ # assets/mhr_model.pt -> download with `hf download` after accepting the license
17
+ # 3. this repo's backbone: embedl_sam3dbody_int8.pt2
18
+
19
+ Run
20
+ ---
21
+ python demo_3d.py --image person.jpg --ckpt-dir ./sam3d_ckpt \
22
+ --pt2 embedl_sam3dbody_int8.pt2 --bbox 180 210 700 950 --out mesh_demo.png
23
+ """
24
+ import argparse, types, numpy as np, cv2, torch
25
+ import matplotlib; matplotlib.use("Agg")
26
+ import matplotlib.pyplot as plt
27
+ from matplotlib.collections import PolyCollection
28
+ import imageio.v2 as imageio
29
+ from sam_3d_body import load_sam_3d_body, SAM3DBodyEstimator # upstream repo
30
+
31
+ LIGHT = np.array([0.3, 0.5, 1.0]); LIGHT /= np.linalg.norm(LIGHT)
32
+ SKIN = np.array([0.80, 0.78, 0.72])
33
+
34
+
35
+ def recover_mesh(image, ckpt_dir, pt2, bbox):
36
+ dev = "cuda" if torch.cuda.is_available() else "cpu"
37
+ model, cfg = load_sam_3d_body(f"{ckpt_dir}/model.ckpt", device=dev,
38
+ mhr_path=f"{ckpt_dir}/assets/mhr_model.pt")
39
+ # swap in the Embedl INT8 backbone (same I/O as the DINOv3 encoder; pipeline is bf16)
40
+ qb = torch.export.load(pt2).module().to(dev)
41
+ def backbone(self, x, *a, **k):
42
+ return torch.cat([qb(x[i:i + 1].float()) for i in range(x.shape[0])], 0).to(x.dtype)
43
+ model.backbone.forward = types.MethodType(backbone, model.backbone)
44
+
45
+ est = SAM3DBodyEstimator(model, cfg) # no detector: pass a bbox
46
+ h, w = cv2.imread(image).shape[:2]
47
+ box = np.array([bbox if bbox else [0, 0, w, h]], dtype=np.float32)
48
+ out = est.process_one_image(image, bboxes=box, use_mask=False)[0]
49
+ return out["pred_vertices"], est.faces, out["pred_cam_t"], float(out["focal_length"])
50
+
51
+
52
+ def _shade(v, f):
53
+ n = np.cross(v[f][:, 1] - v[f][:, 0], v[f][:, 2] - v[f][:, 0])
54
+ n /= (np.linalg.norm(n, axis=1, keepdims=True) + 1e-9)
55
+ lam = np.clip(np.abs(n @ LIGHT), 0, 1)[:, None]
56
+ return np.clip(0.25 + 0.75 * lam * SKIN, 0, 1)
57
+
58
+
59
+ def _view(ax, V, F, deg, title):
60
+ Vc = V - V.mean(0); th = np.radians(deg)
61
+ R = np.array([[np.cos(th), 0, np.sin(th)], [0, 1, 0], [-np.sin(th), 0, np.cos(th)]])
62
+ Vr = Vc @ R.T; p = Vr[:, :2] * [1, -1]; o = np.argsort(Vr[F].mean(1)[:, 2])
63
+ ax.add_collection(PolyCollection(p[F][o], facecolors=_shade(Vr, F)[o], edgecolors="none"))
64
+ ax.set_xlim(p[:, 0].min(), p[:, 0].max()); ax.set_ylim(p[:, 1].min(), p[:, 1].max())
65
+ ax.set_aspect("equal"); ax.axis("off"); ax.set_title(title, fontsize=11)
66
+
67
+
68
+ def render(image, V, F, cam_t, focal, bbox, out):
69
+ img = cv2.cvtColor(cv2.imread(image), cv2.COLOR_BGR2RGB)
70
+ x1, y1, x2, y2 = bbox if bbox else [0, 0, img.shape[1], img.shape[0]]
71
+ crop = cv2.resize(img[y1:y2, x1:x2], (512, 512))
72
+ fig, ax = plt.subplots(1, 4, figsize=(15, 6)); fig.patch.set_facecolor("white")
73
+ ax[0].imshow(img); ax[0].axis("off"); ax[0].set_title("Input", fontsize=11)
74
+ Vc = V + cam_t; z = np.clip(Vc[:, 2], 1e-3, None)
75
+ p = np.stack([focal * Vc[:, 0] / z + 256, focal * Vc[:, 1] / z + 256], 1)
76
+ o = np.argsort(-Vc[F].mean(1)[:, 2])
77
+ ax[1].imshow(crop)
78
+ ax[1].add_collection(PolyCollection(p[F][o], facecolors=_shade(Vc, F)[o], edgecolors="none", alpha=0.8))
79
+ ax[1].set_xlim(0, 512); ax[1].set_ylim(512, 0); ax[1].axis("off"); ax[1].set_title("Mesh overlay", fontsize=11)
80
+ _view(ax[2], V, F, 20, "¾ view"); _view(ax[3], V, F, 90, "side view")
81
+ plt.tight_layout(); plt.savefig(out, dpi=160, bbox_inches="tight"); plt.close()
82
+ frames = []
83
+ for a in range(0, 360, 15):
84
+ fig, axx = plt.subplots(figsize=(4, 6)); fig.patch.set_facecolor("white"); _view(axx, V, F, a, "")
85
+ fig.canvas.draw()
86
+ frames.append(np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()); plt.close()
87
+ imageio.mimsave(out.rsplit(".", 1)[0] + "_spin.gif", frames, duration=0.1, loop=0)
88
+ print(f"wrote {out} and {out.rsplit('.', 1)[0]}_spin.gif")
89
+
90
+
91
+ if __name__ == "__main__":
92
+ ap = argparse.ArgumentParser()
93
+ ap.add_argument("--image", required=True)
94
+ ap.add_argument("--ckpt-dir", required=True, help="dir with model.ckpt, model_config.yaml, assets/mhr_model.pt")
95
+ ap.add_argument("--pt2", default="embedl_sam3dbody_int8.pt2")
96
+ ap.add_argument("--bbox", type=int, nargs=4, default=None, metavar=("x1", "y1", "x2", "y2"))
97
+ ap.add_argument("--out", default="mesh_demo.png")
98
+ a = ap.parse_args()
99
+ V, F, cam_t, focal = recover_mesh(a.image, a.ckpt_dir, a.pt2, a.bbox)
100
+ print(f"recovered mesh: {V.shape[0]} vertices")
101
+ render(a.image, V, F, cam_t, focal, a.bbox, a.out)