# File: hma/sim/example/genie_langtable_replay.py
# Commit 246c106 ("draft") by LeroyWaa, 3.74 kB.
# (Header recovered from a web-page capture; repository UI text removed.)
import json
import os

import cv2
import numpy as np

from sim.main import InteractiveDigitalWorld
from sim.policy import ReplayPolicy
from sim.simulator import GenieSimulator, ReplaySimulator
# NOTE: ad hoc
def normalize_frames(frames, size=256):
    """Resize every frame so its shorter side is `size`, then center-crop it square.

    Args:
        frames: Iterable of images whose first two axes are (height, width),
            e.g. an array shaped (N, H, W, C).
        size: Side length of the square output frames. Defaults to 256,
            the value this helper originally hard-coded.

    Returns:
        np.ndarray stacking the `size` x `size` crops along a new leading axis.

    Raises:
        ValueError: If `frames` contains no frames.
    """
    new_frames = []
    for frame in frames:
        H, W = frame.shape[:2]
        # Scale so the shorter side becomes exactly `size`; int() truncation
        # still leaves the longer side >= `size`, so the crop below always fits.
        if H < W:
            Hnew, Wnew = size, int(W * size / H)
        else:
            Hnew, Wnew = int(H * size / W), size
        frame = cv2.resize(frame, (Wnew, Hnew))  # cv2 expects (width, height)
        H, W = frame.shape[:2]
        Hstart = (H - size) // 2
        Wstart = (W - size) // 2
        new_frames.append(frame[Hstart:Hstart + size, Wstart:Wstart + size])
    if not new_frames:
        raise ValueError("normalize_frames() received no frames to normalize")
    return np.stack(new_frames, axis=0)
def load_dataset(dataset_dir):
    """Load the raw replay arrays stored under `dataset_dir`.

    Reads ``metadata.json`` (keys ``h``, ``w``, ``action_dim``, ``num_images``)
    plus the flat binary files ``actions/actions.bin`` (float32),
    ``video.bin`` (uint8) and ``segment_ids.bin`` (int32).

    Returns:
        Tuple ``(actions, frames, segment_ids)`` shaped
        (num_images, action_dim), (num_images, h, w, 3) and (num_images,).

    Raises:
        ValueError: If a binary file's size does not match the metadata.
    """
    # Context manager so the metadata file handle is closed promptly.
    with open(f"{dataset_dir}/metadata.json") as f:
        metadata = json.load(f)
    h, w = metadata['h'], metadata['w']
    action_dim = metadata['action_dim']
    num_images = metadata['num_images']
    actions = np.fromfile(f"{dataset_dir}/actions/actions.bin", dtype=np.float32).reshape(num_images, action_dim)
    frames = np.fromfile(f"{dataset_dir}/video.bin", dtype=np.uint8).reshape(num_images, h, w, 3)
    segment_ids = np.fromfile(f"{dataset_dir}/segment_ids.bin", dtype=np.int32)
    # actions/frames are validated by reshape; segment_ids must line up too,
    # otherwise chunk boundaries would index past (or short of) the other arrays.
    if segment_ids.shape != (num_images,):
        raise ValueError(
            f"segment_ids.bin holds {segment_ids.size} entries, expected num_images={num_images}"
        )
    return actions, frames, segment_ids


def find_chunks(segment_ids, min_len):
    """Return ``[start, end)`` index pairs of contiguous runs of equal segment ids.

    Only runs strictly longer than `min_len` are kept. A segment id that
    reappears later (non-contiguously) starts a new, separate run.

    Args:
        segment_ids: 1-D sequence/array of per-frame segment ids.
        min_len: Runs must contain more than this many frames to be kept.

    Returns:
        List of ``(start, end)`` tuples of Python ints, in order.

    Raises:
        ValueError: If `segment_ids` is not one-dimensional.
    """
    seg = np.asarray(segment_ids)
    if seg.ndim != 1:
        raise ValueError(f"segment_ids must be 1-D, got shape {seg.shape}")
    if seg.size == 0:
        return []
    # A new run starts wherever an id differs from its predecessor.
    boundaries = np.flatnonzero(seg[1:] != seg[:-1]) + 1
    starts = np.concatenate(([0], boundaries))
    ends = np.concatenate((boundaries, [seg.size]))
    return [(int(s), int(e)) for s, e in zip(starts, ends) if e - s > min_len]


if __name__ == '__main__':
    prompt_horizon = 11
    action_stride = 1
    dataset_dir = "data/langtable_raw_train"
    output_dir = "data/langtable_train_videos"

    actions, frames, segment_ids = load_dataset(dataset_dir)
    # frames = normalize_frames(frames)  # optional: center-crop frames to 256x256 first
    print(f"{actions.shape=}, {frames.shape=}, {segment_ids.shape=}")

    # get chunks' start and end, [start_index, end_index); skip segments with
    # no more than 2 * prompt_horizon frames
    chunks = find_chunks(segment_ids, prompt_horizon * 2)
    print(f"there're {len(chunks)} chunks")

    # Make sure the destination exists before any video is written.
    os.makedirs(output_dir, exist_ok=True)

    for eps_idx, (start_idx, end_idx) in enumerate(chunks):
        this_frames = frames[start_idx:end_idx]
        this_actions = actions[start_idx:end_idx]
        print(f"processing chunk {eps_idx} with {len(this_frames)} frames")

        replay_simulator = ReplaySimulator(frames=this_frames, prompt_horizon=prompt_horizon)
        replay_policy = ReplayPolicy(actions=this_actions, prompt_horizon=prompt_horizon, action_stride=action_stride)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(replay_policy) != len(replay_simulator):
            raise RuntimeError(
                f"chunk {eps_idx}: policy length {len(replay_policy)} != "
                f"simulator length {len(replay_simulator)}"
            )

        # NOTE(review): the model is rebuilt for every chunk, presumably because
        # physics_simulator is bound at construction; this may reload the
        # checkpoints each time -- consider reusing it if GenieSimulator allows.
        genie_simulator = GenieSimulator(
            image_encoder_type='temporalvae',
            image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
            quantize=False,
            backbone_type="stmar",
            backbone_ckpt="data/mar_ckpt/langtable",
            prompt_horizon=prompt_horizon,
            action_stride=action_stride,
            domain='language_table',
            physics_simulator=replay_simulator,
            compute_psnr=False,
            compute_delta_psnr=False,
            allow_external_prompt=True
        )
        # use whatever current state is as the initial state
        image_prompt = replay_simulator.prompt()
        action_prompt = replay_policy.prompt()
        genie_simulator.set_initial_state((image_prompt, action_prompt))

        playground = InteractiveDigitalWorld(
            simulator=genie_simulator,
            policy=replay_policy,
            offscreen=True,
            window_size=(512 * 2, 512)  # [genie image | GT image] side-by-side
        )
        # Always release the playground, even if stepping or saving fails.
        try:
            for _ in range(len(replay_policy)):
                playground.step()
            save_video_path = f'{output_dir}/{eps_idx}.mp4'
            print(f"Saving video to {save_video_path}")
            playground.save_video(save_path=save_video_path, as_gif=False)
        finally:
            playground.close()