-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy path main.py
More file actions
333 lines (269 loc) · 12.4 KB
/
main.py
File metadata and controls
333 lines (269 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
import glob, tqdm, wandb, os, json, random, time, jax
from absl import app, flags
from ml_collections import config_flags
from log_utils import setup_wandb, get_exp_name, get_flag_dict, CsvLogger
from envs.env_utils import make_env_and_datasets
from envs.ogbench_utils import make_ogbench_env_and_datasets
from envs.robomimic_utils import is_robomimic_env
from utils.flax_utils import save_agent
from utils.datasets import Dataset, ReplayBuffer
from evaluation import evaluate
from agents import agents
import numpy as np
# When CUDA_VISIBLE_DEVICES is set, copy its value verbatim into the EGL device
# variables (presumably so offscreen rendering uses the same GPU as JAX -- confirm).
if 'CUDA_VISIBLE_DEVICES' in os.environ:
    os.environ.update(
        EGL_DEVICE_ID=os.environ['CUDA_VISIBLE_DEVICES'],
        MUJOCO_EGL_DEVICE_ID=os.environ['CUDA_VISIBLE_DEVICES'],
    )
FLAGS = flags.FLAGS

# --- run bookkeeping ---
flags.DEFINE_string('run_group', 'Debug', 'Run group.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_string('env_name', 'cube-triple-play-singletask-task2-v0', 'Environment (dataset) name.')
flags.DEFINE_string('save_dir', 'exp/', 'Save directory.')

# --- training schedule ---
# Number of gradient updates on the offline dataset before online interaction begins.
flags.DEFINE_integer('offline_steps', 1000000, 'Number of offline steps.')
# Number of environment steps taken during the online phase.
flags.DEFINE_integer('online_steps', 1000000, 'Number of online steps.')
flags.DEFINE_integer('buffer_size', 2000000, 'Replay buffer size.')
flags.DEFINE_integer('log_interval', 5000, 'Logging interval.')
# 0 disables periodic evaluation; the final step of each phase is still evaluated.
flags.DEFINE_integer('eval_interval', 100000, 'Evaluation interval.')
# Values <= 0 disable periodic checkpointing.
flags.DEFINE_integer('save_interval', -1, 'Save interval.')
# Online step index from which replay-buffer updates start.
flags.DEFINE_integer('start_training', 5000, 'when does training start')
# Number of batches consumed per online environment step (one batch_update call).
flags.DEFINE_integer('utd_ratio', 1, "update to data ratio")
flags.DEFINE_float('discount', 0.99, 'discount factor')

# --- evaluation / video ---
flags.DEFINE_integer('eval_episodes', 50, 'Number of evaluation episodes.')
flags.DEFINE_integer('video_episodes', 0, 'Number of video episodes for each task.')
flags.DEFINE_integer('video_frame_skip', 3, 'Frame skip for videos.')

# --- agent and data ---
config_flags.DEFINE_config_file('agent', 'agents/acfql.py', lock_config=False)
flags.DEFINE_float('dataset_proportion', 1.0, "Proportion of the dataset to use")
# Offline steps between rotating to the next OGBench shard (only with --ogbench_dataset_dir).
flags.DEFINE_integer('dataset_replace_interval', 1000, 'Dataset replace interval, used for large datasets because of memory constraints')
flags.DEFINE_string('ogbench_dataset_dir', None, 'OGBench dataset directory')
# Used both as the agent's config["horizon_length"] and as the sampled sequence length.
flags.DEFINE_integer('horizon_length', 5, 'action chunking length.')
flags.DEFINE_bool('sparse', False, "make the task sparse reward")
# Records per-step obs/qpos/qvel (and button_states if present) into data.npz.
flags.DEFINE_bool('save_all_online_states', False, "save all online states to data.npz")
class LoggingHelper:
    """Send each metrics dict to its prefix's CSV logger and, namespaced, to wandb.

    Args:
        csv_loggers: mapping from a prefix string to an object exposing
            ``log(data, step=...)``.
        wandb_logger: object exposing ``log(dict, step=...)``; the training
            script passes the ``wandb`` module itself.
    """

    def __init__(self, csv_loggers, wandb_logger):
        self.csv_loggers = csv_loggers
        self.wandb_logger = wandb_logger
        # Wall-clock timestamps taken at construction; not read by this class itself.
        self.first_time = time.time()
        self.last_time = time.time()

    def log(self, data, prefix, step):
        """Record ``data`` under ``prefix`` at ``step``.

        The CSV logger receives the keys unchanged, while wandb receives them
        renamed to ``"<prefix>/<key>"``. ``prefix`` must be one of the
        configured CSV prefixes (enforced with ``assert``).
        """
        assert prefix in self.csv_loggers, prefix
        self.csv_loggers[prefix].log(data, step=step)
        wandb_payload = dict(
            (f'{prefix}/{name}', metric) for name, metric in data.items()
        )
        self.wandb_logger.log(wandb_payload, step=step)
def _should_evaluate(step, total_steps):
    """Return True on the last step of a phase or on every ``eval_interval``-th step.

    ``step`` is 1-based and runs up to ``total_steps`` inclusive, so the final
    step of the phase is ``step == total_steps``. An ``eval_interval`` of 0
    disables the periodic evaluations, but not the final one.
    """
    if step == total_steps:
        return True
    return FLAGS.eval_interval != 0 and step % FLAGS.eval_interval == 0


def _run_evaluation(agent, eval_env, action_dim, logger, log_step):
    """Evaluate ``agent`` on ``eval_env`` and log the metrics under the "eval" prefix.

    During evaluation the action chunk is executed fully. Episode counts and
    video settings come from the command-line flags.
    """
    eval_info, _, _ = evaluate(
        agent=agent,
        env=eval_env,
        action_dim=action_dim,
        num_eval_episodes=FLAGS.eval_episodes,
        num_video_episodes=FLAGS.video_episodes,
        video_frame_skip=FLAGS.video_frame_skip,
    )
    logger.log(eval_info, "eval", step=log_step)


def main(_):
    """Run offline training followed by online fine-tuning for a single seed.

    Phases:
      1. Set up wandb and the save directory, and dump the flags to flags.json.
      2. Load the env and dataset (optionally rotating through OGBench shards).
      3. Do ``FLAGS.offline_steps`` updates on action-chunked offline sequences.
      4. Seed a replay buffer with the final offline dataset. Take
         ``FLAGS.online_steps`` env steps, executing each sampled action chunk
         in full, and train from the buffer once ``FLAGS.start_training`` is
         reached.
      5. Close the CSV loggers, optionally dump the recorded online states to
         data.npz, and write the wandb run URL to token.tk.

    Args:
        _: argv forwarded by ``absl.app.run`` (unused).
    """
    from collections import defaultdict, deque

    exp_name = get_exp_name(FLAGS.seed)
    run = setup_wandb(project='qc', group=FLAGS.run_group, name=exp_name)

    FLAGS.save_dir = os.path.join(FLAGS.save_dir, wandb.run.project, FLAGS.run_group, FLAGS.env_name, exp_name)
    os.makedirs(FLAGS.save_dir, exist_ok=True)
    flag_dict = get_flag_dict()
    with open(os.path.join(FLAGS.save_dir, 'flags.json'), 'w') as f:
        json.dump(flag_dict, f)

    config = FLAGS.agent

    # data loading
    if FLAGS.ogbench_dataset_dir is not None:
        # custom ogbench dataset: a directory of .npz shards (validation shards excluded)
        assert FLAGS.dataset_replace_interval != 0
        assert FLAGS.dataset_proportion == 1.0
        dataset_idx = 0
        dataset_paths = [
            file for file in sorted(glob.glob(f"{FLAGS.ogbench_dataset_dir}/*.npz")) if '-val.npz' not in file
        ]
        if not dataset_paths:
            raise FileNotFoundError(
                f"No training .npz files found in {FLAGS.ogbench_dataset_dir!r}"
            )
        env, eval_env, train_dataset, val_dataset = make_ogbench_env_and_datasets(
            FLAGS.env_name,
            dataset_path=dataset_paths[dataset_idx],
            compact_dataset=False,
        )
    else:
        env, eval_env, train_dataset, val_dataset = make_env_and_datasets(FLAGS.env_name)

    # house keeping
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    online_rng, _ = jax.random.split(jax.random.PRNGKey(FLAGS.seed), 2)
    log_step = 0

    discount = FLAGS.discount
    config["horizon_length"] = FLAGS.horizon_length

    # handle dataset
    def process_train_dataset(ds):
        """
        Process the train dataset to
        - handle dataset proportion
        - handle sparse reward
        - convert to action chunked dataset
        """
        ds = Dataset.create(**ds)
        if FLAGS.dataset_proportion < 1.0:
            new_size = int(len(ds['masks']) * FLAGS.dataset_proportion)
            ds = Dataset.create(
                **{k: v[:new_size] for k, v in ds.items()}
            )

        if is_robomimic_env(FLAGS.env_name):
            # Shift robomimic rewards by -1 (matches the online adjustment below).
            penalty_rewards = ds["rewards"] - 1.0
            ds_dict = {k: v for k, v in ds.items()}
            ds_dict["rewards"] = penalty_rewards
            ds = Dataset.create(**ds_dict)

        if FLAGS.sparse:
            # Create a new dataset with modified rewards instead of trying to modify the frozen one
            sparse_rewards = (ds["rewards"] != 0.0) * -1.0
            ds_dict = {k: v for k, v in ds.items()}
            ds_dict["rewards"] = sparse_rewards
            ds = Dataset.create(**ds_dict)

        return ds

    train_dataset = process_train_dataset(train_dataset)
    example_batch = train_dataset.sample(())
    action_dim = example_batch["actions"].shape[-1]

    agent_class = agents[config['agent_name']]
    agent = agent_class.create(
        FLAGS.seed,
        example_batch['observations'],
        example_batch['actions'],
        config,
    )

    # Setup logging: one CSV file per prefix, only for phases that will run.
    prefixes = ["eval", "env"]
    if FLAGS.offline_steps > 0:
        prefixes.append("offline_agent")
    if FLAGS.online_steps > 0:
        prefixes.append("online_agent")

    logger = LoggingHelper(
        csv_loggers={prefix: CsvLogger(os.path.join(FLAGS.save_dir, f"{prefix}.csv"))
                     for prefix in prefixes},
        wandb_logger=wandb,
    )

    offline_init_time = time.time()
    # Offline RL
    for i in tqdm.tqdm(range(1, FLAGS.offline_steps + 1)):
        log_step += 1

        # Periodically rotate to the next dataset shard (memory constraints).
        if FLAGS.ogbench_dataset_dir is not None and FLAGS.dataset_replace_interval != 0 and i % FLAGS.dataset_replace_interval == 0:
            dataset_idx = (dataset_idx + 1) % len(dataset_paths)
            print(f"Using new dataset: {dataset_paths[dataset_idx]}", flush=True)
            train_dataset, val_dataset = make_ogbench_env_and_datasets(
                FLAGS.env_name,
                dataset_path=dataset_paths[dataset_idx],
                compact_dataset=False,
                dataset_only=True,
                cur_env=env,
            )
            train_dataset = process_train_dataset(train_dataset)

        batch = train_dataset.sample_sequence(config['batch_size'], sequence_length=FLAGS.horizon_length, discount=discount)
        agent, offline_info = agent.update(batch)

        if i % FLAGS.log_interval == 0:
            logger.log(offline_info, "offline_agent", step=log_step)

        # saving
        if FLAGS.save_interval > 0 and i % FLAGS.save_interval == 0:
            save_agent(agent, FLAGS.save_dir, log_step)

        # eval (the loop is 1-based, so the final step is i == offline_steps)
        if _should_evaluate(i, FLAGS.offline_steps):
            _run_evaluation(agent, eval_env, action_dim, logger, log_step)

    # transition from offline to online: seed the buffer with the current offline data
    replay_buffer = ReplayBuffer.create_from_initial_dataset(
        dict(train_dataset), size=max(FLAGS.buffer_size, train_dataset.size + 1)
    )

    ob, _ = env.reset()
    # Pending per-step actions from the most recently sampled chunk.
    action_queue = deque()

    # Online RL
    update_info = {}
    data = defaultdict(list)
    online_init_time = time.time()
    for i in tqdm.tqdm(range(1, FLAGS.online_steps + 1)):
        log_step += 1
        online_rng, action_key = jax.random.split(online_rng)

        # during online rl, the action chunk is executed fully
        if not action_queue:
            sampled = agent.sample_actions(observations=ob, rng=action_key)
            action_chunk = np.array(sampled).reshape(-1, action_dim)
            action_queue.extend(action_chunk)  # one row per env step
        action = action_queue.popleft()

        next_ob, int_reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        if FLAGS.save_all_online_states:
            state = env.get_state()
            data["steps"].append(i)
            data["obs"].append(np.copy(next_ob))
            data["qpos"].append(np.copy(state["qpos"]))
            data["qvel"].append(np.copy(state["qvel"]))
            if "button_states" in state:
                data["button_states"].append(np.copy(state["button_states"]))

        # logging useful metrics from info dict
        env_info = {
            info_key: value
            for info_key, value in info.items()
            if info_key.startswith("distance")
        }
        # always log this at every step
        logger.log(env_info, "env", step=log_step)

        if 'antmaze' in FLAGS.env_name and (
            'diverse' in FLAGS.env_name or 'play' in FLAGS.env_name or 'umaze' in FLAGS.env_name
        ):
            # Adjust reward for D4RL antmaze.
            int_reward = int_reward - 1.0
        elif is_robomimic_env(FLAGS.env_name):
            # Adjust online (0, 1) reward for robomimic
            int_reward = int_reward - 1.0

        if FLAGS.sparse:
            assert int_reward <= 0.0
            int_reward = (int_reward != 0.0) * -1.0

        transition = dict(
            observations=ob,
            actions=action,
            rewards=int_reward,
            terminals=float(done),
            # Only true termination cuts bootstrapping; truncation keeps mask = 1.
            masks=1.0 - terminated,
            next_observations=next_ob,
        )
        replay_buffer.add_transition(transition)

        # done
        if done:
            ob, _ = env.reset()
            action_queue.clear()  # reset the action queue
        else:
            ob = next_ob

        if i >= FLAGS.start_training:
            # Sample utd_ratio batches at once, then split them along a new leading axis.
            batch = replay_buffer.sample_sequence(config['batch_size'] * FLAGS.utd_ratio,
                                                  sequence_length=FLAGS.horizon_length, discount=discount)
            batch = jax.tree.map(lambda x: x.reshape((
                FLAGS.utd_ratio, config["batch_size"]) + x.shape[1:]), batch)

            agent, update_info["online_agent"] = agent.batch_update(batch)

        if i % FLAGS.log_interval == 0:
            for prefix, agent_info in update_info.items():
                logger.log(agent_info, prefix, step=log_step)
            update_info = {}

        # eval (the loop is 1-based, so the final step is i == online_steps)
        if _should_evaluate(i, FLAGS.online_steps):
            _run_evaluation(agent, eval_env, action_dim, logger, log_step)

        # saving
        if FLAGS.save_interval > 0 and i % FLAGS.save_interval == 0:
            save_agent(agent, FLAGS.save_dir, log_step)

    end_time = time.time()

    for csv_logger in logger.csv_loggers.values():
        csv_logger.close()

    if FLAGS.save_all_online_states:
        if not data["steps"]:
            # np.stack raises on an empty sequence; nothing was recorded (online_steps == 0).
            print("No online states were recorded; skipping data.npz", flush=True)
        else:
            c_data = {"steps": np.array(data["steps"]),
                      "qpos": np.stack(data["qpos"], axis=0),
                      "qvel": np.stack(data["qvel"], axis=0),
                      "obs": np.stack(data["obs"], axis=0),
                      "offline_time": online_init_time - offline_init_time,
                      "online_time": end_time - online_init_time,
                      }
            if len(data["button_states"]) != 0:
                c_data["button_states"] = np.stack(data["button_states"], axis=0)
            np.savez(os.path.join(FLAGS.save_dir, "data.npz"), **c_data)

    with open(os.path.join(FLAGS.save_dir, 'token.tk'), 'w') as f:
        f.write(run.url)


if __name__ == '__main__':
    app.run(main)