Update demo_deploy.py
Browse files- demo_deploy.py +2 -2
demo_deploy.py
CHANGED
|
@@ -59,10 +59,10 @@ def get_model(config_file, ckpt_file, device):
|
|
| 59 |
print(f"Loading policy config from {config_file}")
|
| 60 |
policy = ACTDiffusionPolicy(policy_config)
|
| 61 |
print(f"Loading model from {ckpt_file}")
|
| 62 |
-
policy.load_state_dict(convert_weight(torch.load(ckpt_file)["state_dict"]))
|
| 63 |
policy.to(device)
|
| 64 |
policy.eval()
|
| 65 |
-
stats = torch.load(ckpt_file)["stats"]
|
| 66 |
print('Resetting observation normalization stats')
|
| 67 |
policy.reset_obs(stats, norm_type = policy_config["norm_type"])
|
| 68 |
camera_names = policy_config["camera_names"]
|
|
|
|
| 59 |
print(f"Loading policy config from {config_file}")
|
| 60 |
policy = ACTDiffusionPolicy(policy_config)
|
| 61 |
print(f"Loading model from {ckpt_file}")
|
| 62 |
+
policy.load_state_dict(convert_weight(torch.load(ckpt_file, weights_only=False)["state_dict"]))
|
| 63 |
policy.to(device)
|
| 64 |
policy.eval()
|
| 65 |
+
stats = torch.load(ckpt_file, weights_only=False)["stats"]
|
| 66 |
print('Resetting observation normalization stats')
|
| 67 |
policy.reset_obs(stats, norm_type = policy_config["norm_type"])
|
| 68 |
camera_names = policy_config["camera_names"]
|