forked from ac-rad/arnold
-
Notifications
You must be signed in to change notification settings - Fork 0
/
render_test.py
133 lines (102 loc) · 4.64 KB
/
render_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import hydra
from omegaconf import OmegaConf
import os
import torch
import torch.nn as nn
import sys
import numpy as np
from pathlib import Path
from environment.runner_utils import get_simulation
import matplotlib.pyplot as plt
import random
# Output root for re-rendered episodes and input root of the ARNOLD demo data.
# The defaults are the original author's machine paths; set ARNOLD_SAVE_DIR /
# ARNOLD_DATA_DIR in the environment to run elsewhere without editing the file.
SAVE_DIR = os.environ.get(
    'ARNOLD_SAVE_DIR',
    '/home/chemrobot/Documents/RichardHanxu2023/SRTACT_Eval/arnold_re_rendered',
)
DATA_DIR = os.environ.get(
    'ARNOLD_DATA_DIR',
    '/home/chemrobot/Documents/RichardHanxu2023/SRTACT_Eval/arnold_dataset/data',
)
# Default split name. Not read by main(), which loops over its own split list.
SPLIT = 'train'
def load_data(data_path, seed=0):
    """Load every demo file in *data_path* in a reproducible shuffled order.

    Args:
        data_path: Directory holding one file per demo episode. Sub-directories
            are ignored; every other entry is passed to ``np.load``.
        seed: Seed for the shuffle. The fixed default keeps the episode order
            identical across runs, so repeated or resumed rendering runs select
            the same episodes. Pass ``None`` for a non-reproducible order.

    Returns:
        ``(data, fnames)``: the ``np.load`` result for each file and the
        matching file paths (as strings), in the same shuffled order.
    """
    demo_paths = sorted(str(item) for item in Path(data_path).iterdir() if not item.is_dir())
    # Sorting first makes the shuffle depend only on the seed, not on the
    # filesystem's listing order. A private RNG neither disturbs nor is
    # disturbed by the global ``random`` state.
    random.Random(seed).shuffle(demo_paths)
    # allow_pickle: the 'gt' entry appears to be an object array of per-frame
    # dicts (main() indexes its elements by string key). Only load trusted files.
    data = [np.load(path, allow_pickle=True) for path in demo_paths]
    return data, demo_paths
def save_camera_renders(obs, gt_frame, obs_counter, out_dir='./image_out'):
    """Save each camera's RGB image of a rendered and a ground-truth frame as PNG.

    Images from ``obs['images']`` go to ``obs_{obs_counter}_cube_{i}.png`` and
    images from ``gt_frame['images']`` go to ``obs_{obs_counter}_base_{i}.png``,
    both inside *out_dir*. Each image entry must provide an ``'rgb'`` array.

    Args:
        obs: Observation whose ``'images'`` sequence is saved under ``cube``.
        gt_frame: Ground-truth frame whose ``'images'`` are saved under ``base``.
        obs_counter: Integer used to make file names unique per observation.
        out_dir: Target directory; created if it does not exist.
    """
    # plt.imsave does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    for i, image in enumerate(obs['images']):
        plt.imsave(os.path.join(out_dir, f'obs_{obs_counter}_cube_{i}.png'), image['rgb'])
    for i, image in enumerate(gt_frame['images']):
        plt.imsave(os.path.join(out_dir, f'obs_{obs_counter}_base_{i}.png'), image['rgb'])
def add_cube_to_observation(obs, gt_frame):
    """Return ``gt_frame['images']`` followed by ``obs['images']`` from index 5 on.

    The result is a single NumPy array built with ``np.concatenate``.
    """
    extra_views = obs['images'][5:]
    return np.concatenate([gt_frame['images'], extra_views])
def save_observation_np(anno, gt_frames, path):
    """Write *gt_frames* (key ``gt``) and ``anno['info']`` (key ``info``) to *path*.

    Uses ``np.savez``, which appends ``.npz`` when *path* lacks that suffix.
    """
    payload = dict(gt=gt_frames, info=anno['info'])
    np.savez(path, **payload)
def main(cfg, *, save_renders=False):
    """Re-render ARNOLD demo episodes in the simulator and save the new frames.

    For every task and split, demo files are loaded from ``DATA_DIR``. Each
    episode's ground-truth actions are replayed in a freshly loaded
    environment. Every ground-truth frame's ``'images'`` is replaced with the
    simulator's observation images. The result is written to
    ``SAVE_DIR/<task>/<split>/<demo file name>``. Episodes whose output file
    already exists are skipped, so an interrupted run can be resumed. The
    simulation app is closed when the job finishes or fails.

    Args:
        cfg: Configuration forwarded to ``load_task``.
        save_renders: If true, also dump every camera image as PNG through
            ``save_camera_renders`` (debug aid).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'device is {device}')
    # task_list = ['close_cabinet', 'close_drawer', 'open_cabinet', 'pickup_object', 'reorient_object', 'pour_water']
    task_list = ['close_drawer']
    simulation_app, _, _ = get_simulation(headless=True)
    # Imported only after the simulation app exists, as in the original script;
    # presumably the task module needs the running simulator — TODO confirm.
    from arnold_dataset.tasks import load_task

    obs_counter = 0
    try:
        # A single pass: the old `while simulation_app.is_running()` wrapper
        # restarted the whole job forever, so close() was never reached.
        for task in task_list:
            for split in ('train', 'val', 'test'):
                if split == 'train' and task == 'transfer_water':
                    continue
                for anno, fname in _episodes_to_render(task, split):
                    if not simulation_app.is_running():
                        return
                    obs_counter = _render_episode(
                        load_task, cfg, task, split, anno, fname, obs_counter, save_renders)
    finally:
        simulation_app.close()


def _episodes_to_render(task, split):
    """Return the ``(anno, fname)`` pairs of *task*/*split* that still need rendering.

    Each split has a fixed episode budget. Files already present in the output
    directory count against it, and their episodes are not returned again.
    Resuming therefore depends on which files were saved, not on the order
    ``load_data`` happens to return.
    """
    limits = {'train': 50, 'val': 10, 'test': 10}
    if split not in limits:
        raise ValueError(f"Invalid Split! {split!r}")
    data, fnames = load_data(os.path.join(DATA_DIR, task, split))
    out_dir = Path(SAVE_DIR, task, split)
    out_dir.mkdir(parents=True, exist_ok=True)
    already_done = {entry.name for entry in out_dir.iterdir()}
    budget = max(limits[split] - len(already_done), 0)
    pending = [
        (anno, fname)
        for anno, fname in zip(data, fnames)
        if Path(fname).name not in already_done
    ][:budget]
    print(f"Rendering {len(pending)} episodes")
    return pending


def _render_episode(load_task, cfg, task, split, anno, fname, obs_counter, save_renders):
    """Replay one demo's ground-truth actions, save its re-rendered frames, return the counter.

    The environment created for the episode is always stopped, even on error.
    """
    gt_frames = anno['gt'].copy()
    robot_base = gt_frames[0]['robot_base']
    if 'water' in task:
        # Water tasks pair element [0] of frame 3 with element [1] of frame 4.
        final_action = (gt_frames[3]['position_rotation_world'][0],
                        gt_frames[4]['position_rotation_world'][1])
    else:
        final_action = gt_frames[3]['position_rotation_world']
    gt_actions = [
        gt_frames[1]['position_rotation_world'],
        gt_frames[2]['position_rotation_world'],
        final_action,
    ]
    env, object_parameters, robot_parameters, scene_parameters = load_task(
        '/home/chemrobot/Documents/RichardHanxu2023/SRTACT_Eval/arnold_dataset/assets',
        npz=anno, cfg=cfg)
    try:
        obs = env.reset(robot_parameters, scene_parameters, object_parameters,
                        robot_base=robot_base, gt_actions=gt_actions)
        gt_frames[0]['images'] = obs['images']
        if save_renders:
            save_camera_renders(obs, gt_frames[0], obs_counter)
        obs_counter += 1
        # Frame k (k >= 1) receives the observation after the k-th replayed action.
        for frame_idx in range(1, len(gt_actions) + 1):
            obs, _suc = env.step(act_pos=None, act_rot=None, render=True, use_gt=True)
            gt_frames[frame_idx]['images'] = obs['images']
            if save_renders:
                save_camera_renders(obs, gt_frames[frame_idx], obs_counter)
            obs_counter += 1
        save_observation_np(anno, gt_frames, Path(SAVE_DIR, task, split, Path(fname).name))
    finally:
        env.stop()
    return obs_counter
class DotDict:
    """Attribute-style read access to a mapping: ``DotDict(d).key`` is ``d['key']``.

    Only the top level is wrapped; values are returned exactly as stored.
    """

    def __init__(self, d):
        self.d = d

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. Fetch the backing
        # mapping through __dict__ rather than self.d. An instance created
        # without __init__ (copy.copy, pickle.loads) then raises AttributeError
        # instead of recursing self.d -> __getattr__('d') -> ... forever.
        try:
            d = self.__dict__['d']
        except KeyError:
            raise AttributeError(f"'DotDict' object has no attribute '{attr}'") from None
        if attr in d:
            return d[attr]
        raise AttributeError(f"'DotDict' object has no attribute '{attr}'")

    def __repr__(self):
        return f"{type(self).__name__}({self.d!r})"
def hydra_extract(config_path='/home/chemrobot/Documents/RichardHanxu2023/SRTACT_Eval/arnold_dataset/configs/default.yaml'):
    """Load a YAML config with OmegaConf and return its top-level entries as a dict.

    Despite the name, Hydra itself is not run; the file is read directly with
    ``OmegaConf.load``. ``dict(cfg)`` converts only the top level, so nested
    values are whatever OmegaConf returns on item access.

    Args:
        config_path: YAML file to load. Defaults to the original author's
            ARNOLD ``default.yaml`` location.
    """
    import omegaconf
    cfg = omegaconf.OmegaConf.load(config_path)
    return dict(cfg)
if __name__ == '__main__':
    # Load the YAML config, expose it with attribute access, and run the job.
    main(DotDict(hydra_extract()))