| import torch |
| from torch.utils.data import Dataset, DataLoader |
| import numpy as np |
| import os |
| import cv2 |
| import matplotlib.pyplot as plt |
| import math |
| from torch.nn.modules.batchnorm import _BatchNorm |
| from collections import OrderedDict |
| from torch.optim.lr_scheduler import LambdaLR |
| import torch.nn as nn |
| from torch.nn import functional as F |
| import h5py |
| import fnmatch |
| from torchvision import transforms |
| import pickle |
| from tqdm import tqdm |
| _UINT8_MAX_F = float(torch.iinfo(torch.uint8).max) |
|
|
def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    """Plot each tracked metric (train vs. validation) and save one PNG per metric.

    Files are written to ``ckpt_dir`` as ``train_val_{key}_seed_{seed}.png``.

    Args:
        train_history: Sequence of summary dicts. Every value must expose
            ``.item()`` (presumably scalar tensors or numpy scalars -- confirm
            against the caller).
        validation_history: Sequence of summary dicts containing at least the
            keys of ``train_history[0]``.
        num_epochs: Number of epochs; the x-axis spans ``[0, num_epochs - 1]``
            regardless of how many summaries were recorded.
        ckpt_dir: Directory the figures are saved into.
        seed: Value embedded in the output file names.

    Returns:
        None. Returns early without plotting when ``train_history`` is empty.
    """
    if not train_history:
        # No summary was recorded, so there is no metric key to iterate over.
        return

    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f"train_val_{key}_seed_{seed}.png")
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]

        fig = plt.figure()
        try:
            plt.plot(
                np.linspace(0, num_epochs - 1, len(train_history)),
                train_values,
                label="train",
            )
            plt.plot(
                np.linspace(0, num_epochs - 1, len(validation_history)),
                val_values,
                label="validation",
            )
            plt.legend()
            plt.title(key)
            # Run the layout pass last so it accounts for the legend and title.
            plt.tight_layout()
            plt.savefig(plot_path)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this, one figure per metric leaks on every call.
            plt.close(fig)
    print(f"Saved plots to {ckpt_dir}")
|
|
|
|
def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Converts an image tensor to a channels-last uint8 image in range [0..255].

    Args:
      input_tensor: Input image tensor of Bx3xHxW layout. Any BxCx... layout is
        handled the same way: axis 1 is moved to the end.
      range_min: Lower bound of the input value range. With the default -1 the
        input is treated as [-1..1] and rescaled to [0..1] first; with any
        other value the input is used as-is and clamped to [0..1].
    Returns:
      A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
    """
    if range_min == -1:
        input_tensor = (input_tensor.float() + 1.0) / 2.0
    # detach() lets tensors that still require grad be converted; Tensor.numpy()
    # refuses to export a tensor that is attached to the autograd graph.
    output_image = input_tensor.detach().clamp(0, 1).cpu().numpy()
    # Move the channel axis (1) behind all spatial axes: (B, C, ...) -> (B, ..., C).
    channels_last = (0, *range(2, output_image.ndim), 1)
    output_image = output_image.transpose(channels_last)
    # Adding 0.5 before the truncating cast rounds to the nearest integer.
    return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)
|
|
|
|
def kl_divergence(mu, logvar):
    """Closed-form KL term ``-0.5 * (1 + logvar - mu^2 - exp(logvar))`` and its reductions.

    Args:
        mu: Mean tensor, 2-D ``(batch, dim)`` or 4-D. A 4-D tensor is viewed
            as ``(size(0), size(1))``, which only works when the trailing two
            dimensions are both 1.
        logvar: Log-variance tensor with the same layout rules as ``mu``.

    Returns:
        A tuple ``(total_kld, dimension_wise_kld, mean_kld)``:
        ``total_kld`` sums over dim 1 and averages over the batch (shape ``(1,)``),
        ``dimension_wise_kld`` averages over the batch only (shape ``(dim,)``),
        ``mean_kld`` averages over dim 1 and then the batch (shape ``(1,)``).
    """
    assert mu.size(0) != 0

    # Collapse (B, D, 1, 1)-style inputs to (B, D).
    if mu.dim() == 4:
        mu = mu.view(*mu.shape[:2])
    if logvar.dim() == 4:
        logvar = logvar.view(*logvar.shape[:2])

    per_element = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())

    summed_over_dims = per_element.sum(dim=1)
    averaged_over_dims = per_element.mean(dim=1)

    return (
        summed_over_dims.mean(dim=0, keepdim=True),
        per_element.mean(dim=0),
        averaged_over_dims.mean(dim=0, keepdim=True),
    )
|
|
|
|
class RandomShiftsAug(nn.Module):
    """Random translation augmentation built from replicate padding and ``grid_sample``.

    The input is padded by ``pad_h`` / ``pad_w`` pixels on each side, then a
    window of the original size is sampled at an integer offset drawn
    independently for every batch element.
    """

    def __init__(self, pad_h, pad_w):
        super().__init__()
        self.pad_h, self.pad_w = pad_h, pad_w
        print(f"RandomShiftsAug: pad_h {pad_h}, pad_w {pad_w}")

    def forward(self, x):
        """Shift ``x`` randomly; the output has exactly the input's shape.

        All dimensions between the first (batch) and the last two (height,
        width) are folded into the channel axis, so they share one shift.
        """
        input_shape = x.shape
        batch, height, width = input_shape[0], input_shape[-2], input_shape[-1]

        # F.pad orders the padding as (left, right, top, bottom).
        pad_spec = (self.pad_w, self.pad_w, self.pad_h, self.pad_h)
        x = F.pad(x.view(batch, -1, height, width), pad_spec, mode="replicate")

        padded_h = height + 2 * self.pad_h
        padded_w = width + 2 * self.pad_w

        def pixel_centers(padded_size, keep):
            # Normalized pixel-center coordinates of the padded axis, truncated
            # to the first `keep` entries (the un-shifted window).
            half_pixel = 1.0 / padded_size
            centers = torch.linspace(
                -1.0 + half_pixel,
                1.0 - half_pixel,
                padded_size,
                device=x.device,
                dtype=x.dtype,
            )
            return centers[:keep]

        rows = pixel_centers(padded_h, height).view(height, 1, 1).expand(height, width, 1)
        cols = pixel_centers(padded_w, width).view(1, width, 1).expand(height, width, 1)
        # grid_sample expects (x, y) ordering in the last dimension.
        window = torch.cat([cols, rows], dim=2).unsqueeze(0)

        # Draw the vertical offset first, then the horizontal one.
        offset_h = torch.randint(
            0, 2 * self.pad_h + 1, size=(batch, 1, 1, 1), device=x.device, dtype=x.dtype
        ).float()
        offset_w = torch.randint(
            0, 2 * self.pad_w + 1, size=(batch, 1, 1, 1), device=x.device, dtype=x.dtype
        ).float()
        # One pixel corresponds to 2 / padded_size in normalized coordinates.
        offset_h = offset_h * (2.0 / padded_h)
        offset_w = offset_w * (2.0 / padded_w)

        # Broadcasting (1, H, W, 2) + (B, 1, 1, 2) gives one grid per batch element.
        sampling_grid = window + torch.cat([offset_w, offset_h], dim=3)
        shifted = F.grid_sample(
            x, sampling_grid, padding_mode="zeros", align_corners=False
        )
        return shifted.view(input_shape)
|
|
|
|
def get_norm_stats(state, action):
    """Compute normalization statistics over the leading (sample) axis.

    Args:
        state: Array-like of states, reduced over axis 0.
        action: Array-like of actions, reduced over axis 0.

    Returns:
        Dict of numpy arrays with keys ``action_mean``, ``action_std``,
        ``action_max``, ``action_min``, ``qpos_mean`` and ``qpos_std``.
        Standard deviations are floored at 1e-2, and every array has its
        size-1 dimensions squeezed away.
    """
    qpos_tensor = torch.from_numpy(np.array(state))
    action_tensor = torch.from_numpy(np.array(action))

    def mean_and_floored_std(samples):
        # Floor the std so near-constant dimensions do not blow up when divided by it.
        mean = samples.mean(dim=[0], keepdim=True)
        std = torch.clip(samples.std(dim=[0], keepdim=True), 1e-2, np.inf)
        return mean, std

    action_mean, action_std = mean_and_floored_std(action_tensor)
    action_max = torch.amax(action_tensor, dim=[0], keepdim=True)
    action_min = torch.amin(action_tensor, dim=[0], keepdim=True)
    qpos_mean, qpos_std = mean_and_floored_std(qpos_tensor)

    tensors_by_name = {
        "action_mean": action_mean,
        "action_std": action_std,
        "action_max": action_max,
        "action_min": action_min,
        "qpos_mean": qpos_mean,
        "qpos_std": qpos_std,
    }
    return {name: value.numpy().squeeze() for name, value in tensors_by_name.items()}
|
|
class EpisodicDataset_Unified_Multiview(Dataset):
    """Dataset that draws one random timestep from an HDF5 episode per item.

    Each item holds the camera frames and state at a random timestep, the
    chunk of up to ``chunk_size`` following actions (zero-padded past the end
    of the episode), and a randomly chosen instruction embedding loaded from
    the ``instructions`` directory next to the episode file.
    """

    # Every episode is exposed this many times per epoch. Each access samples
    # a fresh random timestep, so repeats are not duplicates.
    SAMPLES_PER_EPISODE = 16

    def __init__(self, data_path_list, camera_names, chunk_size, stats, img_aug=False):
        """
        Args:
            data_path_list: Paths of the HDF5 episode files.
            camera_names: Keys under ``observations/images`` to load, in order.
            chunk_size: Number of action steps returned per item.
            stats: Mapping with at least ``qpos_mean`` and ``qpos_std``.
            img_aug: When true, color jitter is applied to 25% of the items.
        """
        super().__init__()
        self.data_path_list = data_path_list
        self.camera_names = camera_names
        self.chunk_size = chunk_size
        self.norm_stats = stats
        self.img_aug = img_aug
        self.ColorJitter = transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.01
        )

    def __len__(self):
        return len(self.data_path_list) * self.SAMPLES_PER_EPISODE

    @staticmethod
    def _load_random_instruction(example_path):
        """Load one randomly chosen ``*.pt`` file from the sibling ``instructions`` dir."""
        instruction_dir = os.path.join(os.path.dirname(example_path), "instructions")
        candidates = [
            name for name in os.listdir(instruction_dir) if fnmatch.fnmatch(name, "*.pt")
        ]
        instruction_file = os.path.join(instruction_dir, np.random.choice(candidates))
        # NOTE: weights_only=False unpickles arbitrary objects; only point this
        # at instruction files from a trusted source.
        return torch.load(instruction_file, weights_only=False)

    def _decode_camera_frames(self, episode, index, example_path):
        """Decode the JPEG frame at ``index`` for every camera and stack them on axis 0."""
        frames = []
        for camera_name in self.camera_names:
            jpeg_bytes = episode["observations"]["images"][camera_name][index]
            frame = cv2.imdecode(np.frombuffer(jpeg_bytes, np.uint8), cv2.IMREAD_COLOR)
            if frame is None:
                # imdecode signals failure by returning None; report the source
                # here instead of failing later inside np.stack.
                raise ValueError(
                    f"Failed to decode frame {index} of camera '{camera_name}' in {example_path}"
                )
            frames.append(frame)
        return np.stack(frames, axis=0)

    def __getitem__(self, path_index):
        """Return ``(image_data, qpos_data, action_data, is_pad, instruction_data)``.

        ``image_data`` is float in [0, 1] with layout
        ``(1, num_cameras, channels, H, W)``; channels stay in the order
        produced by ``cv2.imdecode``. ``qpos_data`` is the normalized state of
        the sampled step with a leading axis of length 1, ``action_data`` has
        ``chunk_size`` steps, and ``is_pad`` is a bool mask that is True for
        steps beyond the end of the episode.
        """
        # __len__ over-reports by SAMPLES_PER_EPISODE, so wrap back onto the list.
        path_index = path_index % len(self.data_path_list)
        example_path = self.data_path_list[path_index]

        # Sampling order (instruction first, then timestep) is kept fixed so
        # seeded runs draw the same random numbers as before.
        instruction = self._load_random_instruction(example_path)

        # A single open serves both the trajectory arrays and the image frames.
        with h5py.File(example_path, 'r') as episode:
            # NOTE(review): the action chunk is read from observations/qpos and
            # the observed state from 'action'. The statistics in
            # load_data_unified use the same mapping -- confirm it is intended.
            action = episode['observations']['qpos'][()]
            qpos = episode['action'][()]

            episode_len = action.shape[0]
            index = np.random.randint(0, episode_len)
            obs_img = self._decode_camera_frames(episode, index, example_path)

        obs_qpos = qpos[index:index + 1]

        # Copy as many future steps as the episode still has; the rest stay zero.
        action_len = min(self.chunk_size, episode_len - index)
        gt_action = np.zeros((self.chunk_size, *action.shape[1:]))
        gt_action[:action_len] = action[index:index + action_len]
        is_pad = np.zeros(self.chunk_size)
        is_pad[action_len:] = 1

        # (num_cameras, H, W, C) -> (1, num_cameras, C, H, W)
        image_data = torch.from_numpy(obs_img).unsqueeze(0).float()
        image_data = image_data.permute(0, 1, 4, 2, 3)
        qpos_data = torch.from_numpy(obs_qpos).float()
        action_data = torch.from_numpy(gt_action).float()
        is_pad = torch.from_numpy(is_pad).bool()
        instruction_data = instruction.mean(0).float()

        image_data = image_data / 255.0
        qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats[
            "qpos_std"
        ]
        if self.img_aug and random.random() < 0.25:
            # Jitter each frame separately so every camera gets its own parameters.
            for t in range(image_data.shape[0]):
                for i in range(image_data.shape[1]):
                    image_data[t, i] = self.ColorJitter(image_data[t, i])
        return image_data, qpos_data.float(), action_data, is_pad, instruction_data
|
|
def load_data_unified(
    data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
    camera_names=None,
    batch_size_train=32,
    chunk_size=100,
    img_aug=False,
    fintune=False,
):
    """Build the training DataLoader and normalization statistics for a directory of episodes.

    Args:
        data_dir: Root directory searched recursively (following symlinks) for
            ``*.hdf5`` episode files.
        camera_names: Camera keys to load. Defaults to
            ``['cam_high', 'cam_left_wrist', 'cam_right_wrist']``.
        batch_size_train: Batch size of the returned loader.
        chunk_size: Number of action steps per sample.
        img_aug: Forwarded to the dataset to enable color jitter.
        fintune: When true, statistics are loaded from the hard-coded
            pretraining pickle instead of being computed from ``data_dir``.

    Returns:
        ``(train_loader, None, None, stats)``. The two ``None`` slots are
        placeholders kept for callers that unpack four values.

    Raises:
        ValueError: If no ``.hdf5`` file is found under ``data_dir``.
    """
    if camera_names is None:
        # Built per call so no list object is shared between invocations.
        camera_names = ['cam_high', 'cam_left_wrist', 'cam_right_wrist']

    hdf5_paths = []
    for root, _, files in os.walk(data_dir, followlinks=True):
        for filename in files:
            if filename.endswith('.hdf5'):
                hdf5_paths.append(os.path.join(root, filename))
    # os.walk order depends on the filesystem; sorting makes the index -> file
    # mapping reproducible for seeded runs.
    hdf5_paths.sort()
    if not hdf5_paths:
        raise ValueError(f"No .hdf5 episodes found under {data_dir}")
    print(f"Loading data from {data_dir} with {len(hdf5_paths)} episodes and batch size {batch_size_train}")

    if fintune:
        # Pretraining statistics are reused, so the per-episode pass is skipped.
        pretrain_stats_path = '/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/ACT_DP_multitask/checkpoints/real_pretrain_50_2000/act_dp/dataset_stats.pkl'
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(pretrain_stats_path, 'rb') as f:
            stats = pickle.load(f)
        print(f"Loaded stats from {pretrain_stats_path}")
    else:
        state_list = []
        action_list = []
        for p in tqdm(hdf5_paths, desc="Data statics collection"):
            with h5py.File(p, 'r') as f:
                # NOTE(review): 'action' statistics come from observations/qpos
                # and 'qpos' statistics from 'action'. The dataset class uses
                # the same mapping -- confirm it is intended.
                action = f['observations']['qpos'][()]
                qpos = f['action'][()]
            state_list.append(qpos)
            action_list.append(action)
        states = np.concatenate(state_list, axis=0)
        actions = np.concatenate(action_list, axis=0)
        stats = get_norm_stats(states, actions)

    for key, value in stats.items():
        print(f"{key}: {value}")

    train_dataset = EpisodicDataset_Unified_Multiview(
        data_path_list=hdf5_paths,
        camera_names=camera_names,
        chunk_size=chunk_size,
        stats=stats,
        img_aug=img_aug,
    )

    train_data_loader = DataLoader(
        train_dataset,
        batch_size=batch_size_train,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )

    return train_data_loader, None, None, stats
|
|
def compute_dict_mean(epoch_dicts):
    """Average a list of dicts key by key.

    The keys of the first dict define the result; every other dict must
    contain them too. Values only need to support ``+=`` and division by an
    integer.
    """
    count = len(epoch_dicts)

    def mean_for(key):
        # Start from the integer 0 and accumulate in place.
        running_total = 0
        for summary in epoch_dicts:
            running_total += summary[key]
        return running_total / count

    return {key: mean_for(key) for key in epoch_dicts[0]}
|
|
|
|
def detach_dict(d):
    """Return a new dict with the same keys whose values are ``value.detach()``."""
    return {name: value.detach() for name, value in d.items()}
|
|
|
|
| |
| |
| |
| import random |
|
|
|
|
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """
    Build a LambdaLR schedule with a linear warmup followed by cosine decay.

    The multiplier rises linearly from 0 to 1 over the warmup steps, then
    follows a cosine from 1 towards 0 over the remaining training steps. It
    never goes below 0.

    Args:
        optimizer ([~torch.optim.Optimizer]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (int):
            Length of the linear warmup phase, in steps.
        num_training_steps (int):
            Total number of training steps.
        num_cycles (float, *optional*, defaults to 0.5):
            Number of cosine waves; the default is a single half-cosine from
            the maximum down to 0.
        last_epoch (int, *optional*, defaults to -1):
            Index of the last epoch when resuming training.

    Return:
        torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
    """
    # The divisors are guarded against zero-length phases.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))
    cycles = float(num_cycles)

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            cosine = math.cos(math.pi * cycles * 2.0 * progress)
            return max(0.0, 0.5 * (1.0 + cosine))
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
|
|
|
def get_constant_schedule(optimizer, last_epoch: int = -1) -> LambdaLR:
    """
    Build a schedule that keeps the optimizer's learning rate unchanged.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def keep_base_lr(_step):
        # A multiplier of 1 leaves the optimizer's own learning rate as it is.
        return 1

    return LambdaLR(optimizer, keep_base_lr, last_epoch=last_epoch)
|
|
|
|
def normalize_data(action_data, stats, norm_type, data_type="action"):
    """Normalize ``action_data`` with statistics taken from ``stats``.

    Args:
        action_data: Tensor to normalize; the statistics are moved to its device.
        stats: Mapping of numpy arrays keyed ``{data_type}_max`` / ``_min``
            (for ``"minmax"``) or ``{data_type}_mean`` / ``_std`` (for
            ``"gaussian"``).
        norm_type: ``"minmax"`` rescales to [-1, 1]; ``"gaussian"`` subtracts
            the mean and divides by the std. Any other value returns the input
            unchanged.
        data_type: Prefix of the statistic keys, ``"action"`` by default.

    Returns:
        The normalized tensor, or ``action_data`` itself for an unrecognized
        ``norm_type``.
    """

    def load_stat(suffix):
        return torch.from_numpy(stats[data_type + suffix]).float().to(action_data.device)

    if norm_type == "minmax":
        action_max = load_stat("_max")
        action_min = load_stat("_min")
        # A dimension that never varies has max == min. Flooring the range
        # avoids a 0/0 NaN there and maps such a dimension to -1.
        value_range = (action_max - action_min).clamp(min=1e-8)
        action_data = (action_data - action_min) / value_range * 2 - 1
    elif norm_type == "gaussian":
        action_mean = load_stat("_mean")
        action_std = load_stat("_std")
        action_data = (action_data - action_mean) / action_std
    return action_data
|
|
|
|
def convert_weight(obj):
    """Return an OrderedDict copy of ``obj`` with a leading ``"module."`` stripped from keys.

    Keys without that prefix are copied unchanged and the iteration order of
    ``obj`` is preserved.
    """
    prefix = "module."
    return OrderedDict(
        (name[len(prefix):] if name.startswith(prefix) else name, value)
        for name, value in obj.items()
    )
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the loader with augmentation enabled, walk through the
    # whole dataset once, and print shapes and maxima for the first batch only.
    train_dataloader, _, _, stats = load_data_unified(
        data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
        camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
        batch_size_train=32,
        chunk_size=100,
        img_aug=True,
    )

    progress = tqdm(train_dataloader, desc="Data loading")
    for i, (image_data, qpos_data, action_data, is_pad, instruction_data) in enumerate(progress):
        if i != 0:
            # Later batches are still loaded, which exercises the full pipeline.
            continue
        print(f"Batch {i}:")
        print(f"Image data shape: {image_data.shape} {image_data.max()}")
        print(f"Qpos data shape: {qpos_data.shape} {qpos_data.max()}")
        print(f"Action data shape: {action_data.shape} {action_data.max()}")
        print(f"Is pad shape: {is_pad.shape}")
        print(f"Instruction data shape: {instruction_data.shape}")
| |