returnzeros commited on
Commit
6921fb0
·
verified ·
1 Parent(s): 2367d16

Update demo_deploy.py

Browse files
Files changed (1) hide show
  1. demo_deploy.py +2 -2
demo_deploy.py CHANGED
@@ -59,10 +59,10 @@ def get_model(config_file, ckpt_file, device):
59
  print(f"Loading policy config from {config_file}")
60
  policy = ACTDiffusionPolicy(policy_config)
61
  print(f"Loading model from {ckpt_file}")
62
- policy.load_state_dict(convert_weight(torch.load(ckpt_file)["state_dict"]))
63
  policy.to(device)
64
  policy.eval()
65
- stats = torch.load(ckpt_file)["stats"]
66
  print('Resetting observation normalization stats')
67
  policy.reset_obs(stats, norm_type = policy_config["norm_type"])
68
  camera_names = policy_config["camera_names"]
 
59
  print(f"Loading policy config from {config_file}")
60
  policy = ACTDiffusionPolicy(policy_config)
61
  print(f"Loading model from {ckpt_file}")
62
+ policy.load_state_dict(convert_weight(torch.load(ckpt_file, weights_only=False)["state_dict"]))
63
  policy.to(device)
64
  policy.eval()
65
+ stats = torch.load(ckpt_file, weights_only=False)["stats"]
66
  print('Resetting observation normalization stats')
67
  policy.reset_obs(stats, norm_type = policy_config["norm_type"])
68
  camera_names = policy_config["camera_names"]