import cProfile
import io
import os
import pstats
import time
from typing import Iterator

import numba
import numpy as np
import torch
import torch as th
import tqdm
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_

from rocket_learn.agent.actor_critic_agent import ActorCriticAgent
from rocket_learn.experience_buffer import ExperienceBuffer
from rocket_learn.rollout_generator.base_rollout_generator import BaseRolloutGenerator

class PPO:
    """
    Proximal Policy Optimization algorithm (PPO)

    :param rollout_generator: Function that will generate the rollouts
    :param agent: An ActorCriticAgent
    :param n_steps: The number of steps to run per update
    :param gamma: Discount factor
    :param batch_size: batch size to break experience data into for training
    :param epochs: Number of epochs when optimizing the loss
    :param minibatch_size: size to break batches into (helps combat VRAM issues)
    :param clip_range: PPO clipping parameter for the policy objective
    :param ent_coef: Entropy coefficient for the loss calculation
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: optional clip_grad_norm value
    :param logger: wandb logger to store run results
    :param device: torch device
    :param reward_clip: clip range for normalized rewards
    """
    def __init__(
            self,
            rollout_generator: BaseRolloutGenerator,
            agent: ActorCriticAgent,
            n_steps=4096,
            gamma=0.99,
            batch_size=512,
            epochs=10,
            # reuse=2,
            minibatch_size=None,
            clip_range=0.2,
            ent_coef=0.01,
            gae_lambda=0.95,
            vf_coef=1,
            max_grad_norm=0.5,
            logger=None,
            device="cuda",
            reward_clip=10,
    ):
        self.rollout_generator = rollout_generator

        # TODO let users choose their own agent
        # TODO move agent to rollout generator
        self.agent = agent.to(device)
        self.device = device
        self.starting_iteration = 0

        # hyperparameters
        self.epochs = epochs
        self.gamma = gamma
        # assert n_steps % batch_size == 0
        # self.reuse = reuse
        self.n_steps = n_steps
        self.gae_lambda = gae_lambda
        self.batch_size = batch_size
        self.minibatch_size = minibatch_size or batch_size
        assert self.batch_size % self.minibatch_size == 0
        self.clip_range = clip_range
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm

        # running statistics for reward normalization
        self.running_rew_mean = 0
        self.running_rew_std = 1
        self.running_rew_var = 1
        self.running_rew_count = 1e-4
        self.ema_reward_discount = 0.9
        self.reward_clip = reward_clip

        self.total_steps = 0
        self.logger = logger
        self.logger.watch((self.agent.actor, self.agent.critic))
        self.timer = time.time_ns() // 1_000_000
        self.jit_tracer = None

    def update_reward_norm(self, rewards: np.ndarray) -> np.ndarray:
        # TODO needs to be fixed; the original `assert "Not Implemented ..."` always passed silently
        raise NotImplementedError("update_reward_norm")

        # Welford-style running mean/variance update (currently unreachable)
        batch_mean = np.mean(rewards)
        batch_var = np.var(rewards)
        batch_count = rewards.shape[0]

        delta = batch_mean - self.running_rew_mean
        tot_count = self.running_rew_count + batch_count

        new_mean = self.running_rew_mean + delta * batch_count / tot_count
        m_a = self.running_rew_var * self.running_rew_count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.running_rew_count * batch_count / (
                self.running_rew_count + batch_count)
        new_var = m_2 / (self.running_rew_count + batch_count)

        new_count = batch_count + self.running_rew_count

        self.running_rew_mean = new_mean
        self.running_rew_var = new_var
        self.running_rew_count = new_count

        return (rewards - self.running_rew_mean) / np.sqrt(self.running_rew_var + 1e-8)  # TODO normalize before update?

    def _update_ema_reward_norm(self, ep_rewards: np.ndarray) -> None:
        self.running_rew_mean = self.running_rew_mean * self.ema_reward_discount + ep_rewards.mean() * \
                                (1 - self.ema_reward_discount)
        diff = ep_rewards.mean() - self.running_rew_mean
        incr = (1 - self.ema_reward_discount) * diff
        self.running_rew_var = self.ema_reward_discount * (self.running_rew_var + diff * incr)
        self.running_rew_std = np.sqrt(self.running_rew_var)
        return

    def _normalize_reward(self, rewards: np.ndarray) -> np.ndarray:
        return np.clip(rewards / (self.running_rew_std + 1e-32), -self.reward_clip, self.reward_clip)
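
    # Reward-normalization scheme used during training: `_update_ema_reward_norm`
    # keeps an exponential moving average of the per-episode reward mean and
    # variance, and `_normalize_reward` divides raw rewards by the EMA standard
    # deviation and clips the result to [-reward_clip, reward_clip]. The running
    # mean is tracked but never subtracted, so the sign of each reward is preserved.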

    def run(self, iterations_per_save=10, save_dir=None, save_jit=False):
        """
        Generate rollout data and train

        :param iterations_per_save: number of iterations between checkpoint saves
        :param save_dir: where to save
        :param save_jit: also save a jit-traced copy of the actor at each checkpoint
        """
        if save_dir:
            current_run_dir = os.path.join(save_dir, self.logger.project + "_" + str(time.time()))
            os.makedirs(current_run_dir)
        elif iterations_per_save and not save_dir:
            print("Warning: no save directory specified.")
            print("Checkpoints will not be saved.")

        iteration = self.starting_iteration
        rollout_gen = self.rollout_generator.generate_rollouts()
        self.rollout_generator.update_parameters(self.agent.actor)

        while True:
            # pr = cProfile.Profile()
            # pr.enable()
            t0 = time.time()

            def _iter():
                size = 0
                print(f"Collecting rollouts ({iteration})...")
                # progress = tqdm.tqdm(desc=f"Collecting rollouts ({iteration})", total=self.n_steps, position=0, leave=True)
                while size < self.n_steps:
                    try:
                        rollout = next(rollout_gen)
                        if rollout.size() > 0:
                            size += rollout.size()
                            # progress.update(rollout.size())
                            yield rollout
                    except StopIteration:
                        return

            self.calculate(_iter(), iteration)
            iteration += 1

            if save_dir and iteration % iterations_per_save == 0:
                self.save(current_run_dir, iteration, save_jit)  # noqa

            self.rollout_generator.update_parameters(self.agent.actor)
            self.total_steps += self.n_steps  # size
            t1 = time.time()
            self.logger.log({"fps": self.n_steps / (t1 - t0), "total_timesteps": self.total_steps})

            # pr.disable()
            # s = io.StringIO()
            # sortby = pstats.SortKey.CUMULATIVE
            # ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
            # ps.dump_stats(f"profile_{self.total_steps}")

    def set_logger(self, logger):
        self.logger = logger

    def evaluate_actions(self, observations, actions):
        """
        Calculate Log Probability and Entropy of actions
        """
        dist = self.agent.actor.get_action_distribution(observations)
        # indices = self.agent.get_action_indices(dists)

        log_prob = self.agent.actor.log_prob(dist, actions)
        entropy = self.agent.actor.entropy(dist, actions)
        # Negate so the returned value can be added directly to the loss as an entropy term
        entropy = -torch.mean(entropy)

        return log_prob, entropy

    @staticmethod
    @numba.njit
    def _calculate_advantages_numba(rewards, values, gamma, gae_lambda):
        advantages = np.zeros_like(rewards)
        # v_targets = np.zeros_like(rewards)
        dones = np.zeros_like(rewards)
        dones[-1] = 1.
        episode_starts = np.zeros_like(rewards)
        episode_starts[0] = 1.
        last_values = values[-1]
        last_gae_lam = 0

        size = len(advantages)
        for step in range(size - 1, -1, -1):
            if step == size - 1:
                next_non_terminal = 1.0 - dones[-1].item()
                next_values = last_values
            else:
                next_non_terminal = 1.0 - episode_starts[step + 1].item()
                next_values = values[step + 1]
            v_target = rewards[step] + gamma * next_values * next_non_terminal
            delta = v_target - values[step]
            last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
            advantages[step] = last_gae_lam
            # v_targets[step] = v_target
        return advantages  # , v_targets
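
    # `_calculate_advantages_numba` implements Generalized Advantage Estimation over
    # a single buffer, treating it as one episode that terminates at its final step:
    #
    #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    #
    # `calculate` then uses R_t = A_t + V(s_t) as the value-function target.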

    def calculate(self, buffers: Iterator[ExperienceBuffer], iteration):
        """
        Calculate loss and update network
        """
        obs_tensors = []
        act_tensors = []
        # value_tensors = []
        log_prob_tensors = []
        # advantage_tensors = []
        returns_tensors = []
        rewards_tensors = []
        ep_rewards = []
        ep_std_rewards = []
        rewards_before = []
        rewards_after = []
        ep_raw_rewards = []
        ep_steps = []
        n = 0

        for buffer in buffers:  # Do discounts for each ExperienceBuffer individually
            if isinstance(buffer.observations[0], (tuple, list)):
                transposed = tuple(zip(*buffer.observations))
                obs_tensor = tuple(torch.from_numpy(np.vstack(t)).float() for t in transposed)
            else:
                obs_tensor = th.from_numpy(np.vstack(buffer.observations)).float()

            with th.no_grad():
                if isinstance(obs_tensor, tuple):
                    x = tuple(o.to(self.device) for o in obs_tensor)
                else:
                    x = obs_tensor.to(self.device)
                values = self.agent.critic(x).detach().cpu().numpy().flatten()  # No batching?

            actions = np.stack(buffer.actions)
            log_probs = np.stack(buffer.log_probs)
            rewards = np.stack(buffer.rewards)
            dones = np.stack(buffer.dones)
            size = rewards.shape[0]

            episode_starts = np.roll(dones, 1)
            episode_starts[0] = 1.

            # TODO testing
            rewards_before.append(rewards.sum())

            # normalize before update
            ep_raw_rewards.append(rewards.sum())
            rewards = self._normalize_reward(rewards)

            # TODO testing
            rewards_after.append(rewards.sum())

            advantages = self._calculate_advantages_numba(rewards, values, self.gamma, self.gae_lambda)
            returns = advantages + values

            obs_tensors.append(obs_tensor)
            act_tensors.append(th.from_numpy(actions))
            log_prob_tensors.append(th.from_numpy(log_probs))
            returns_tensors.append(th.from_numpy(returns))
            rewards_tensors.append(th.from_numpy(rewards))
            ep_rewards.append(rewards.sum())
            ep_steps.append(size)
            n += 1

        ep_rewards = np.array(ep_rewards)
        self._update_ema_reward_norm(ep_rewards)
        ep_steps = np.array(ep_steps)
        ep_raw_rewards = np.array(ep_raw_rewards)

        # TODO testing
        rewards_before = np.array(rewards_before)
        rewards_after = np.array(rewards_after)
        print(f"before_std={rewards_before.std()} - after_std={rewards_after.std()}")
        print(f"run_mean={self.running_rew_mean} - run_std={self.running_rew_std}")

        # mean rewards per step
        step_rewards = ep_rewards.mean() / ep_steps.mean()

        self.logger.log({
            "ep_raw_reward_mean": ep_raw_rewards.mean(),
            "ep_raw_reward_std": ep_raw_rewards.std(),
            "per_step_reward_mean": step_rewards,
            "ep_reward_mean": ep_rewards.mean(),
            "ep_reward_std": ep_rewards.std(),
            "ep_len_mean": ep_steps.mean(),
        }, step=iteration, commit=False)

        if isinstance(obs_tensors[0], tuple):
            transposed = zip(*obs_tensors)
            obs_tensor = tuple(th.cat(t).float() for t in transposed)
        else:
            obs_tensor = th.cat(obs_tensors).float()
        act_tensor = th.cat(act_tensors)
        log_prob_tensor = th.cat(log_prob_tensors).float()
        # advantages_tensor = th.cat(advantage_tensors)
        returns_tensor = th.cat(returns_tensors).float()

        tot_loss = 0
        tot_policy_loss = 0
        tot_entropy_loss = 0
        tot_value_loss = 0
        total_kl_div = 0
        tot_clipped = 0
        n = 0

        print("Training network...")
        precompute = torch.cat([param.view(-1) for param in self.agent.actor.parameters()])
        t0 = time.perf_counter_ns()

        self.agent.optimizer.zero_grad()
        for e in range(self.epochs):
            # this is mostly pulled from sb3
            indices = torch.randperm(returns_tensor.shape[0])[:self.batch_size]

            if isinstance(obs_tensor, tuple):
                obs_batch = tuple(o[indices] for o in obs_tensor)
            else:
                obs_batch = obs_tensor[indices]
            act_batch = act_tensor[indices]
            log_prob_batch = log_prob_tensor[indices]
            # advantages_batch = advantages_tensor[indices]
            returns_batch = returns_tensor[indices]

            for i in range(0, self.batch_size, self.minibatch_size):
                # Note: Will cut off final few samples
                if isinstance(obs_tensor, tuple):
                    obs = tuple(o[i: i + self.minibatch_size].to(self.device) for o in obs_batch)
                else:
                    obs = obs_batch[i: i + self.minibatch_size].to(self.device)
                act = act_batch[i: i + self.minibatch_size].to(self.device)
                # adv = advantages_batch[i:i + self.minibatch_size].to(self.device)
                ret = returns_batch[i: i + self.minibatch_size].to(self.device)
                old_log_prob = log_prob_batch[i: i + self.minibatch_size].to(self.device)

                # TODO optimization: use forward_actor_critic instead of separate in case shared, also use GPU
                log_prob, entropy = self.evaluate_actions(obs, act)  # Assuming obs and actions as input
                ratio = torch.exp(log_prob - old_log_prob)

                values_pred = self.agent.critic(obs)
                values_pred = th.squeeze(values_pred)
                adv = ret - values_pred
                # adv = (adv - th.mean(adv)) / (th.std(adv) + 1e-8)
                adv = adv / (th.std(adv) + 1e-8)  # updated to remove the mean subtraction

                # clipped surrogate loss
                policy_loss_1 = adv * ratio
                policy_loss_2 = adv * th.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range)
                policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

                # **If we want value clipping, add it here**
                value_loss = F.mse_loss(ret, values_pred)

                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = entropy

                loss = ((policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss)
                        / (self.batch_size / self.minibatch_size))
                loss.backward()

                # Unbiased low variance KL div estimator from http://joschu.net/blog/kl-approx.html
                total_kl_div += th.mean((ratio - 1) - (log_prob - old_log_prob)).item()

                tot_loss += loss.item()
                tot_policy_loss += policy_loss.item()
                tot_entropy_loss += entropy_loss.item()
                tot_value_loss += value_loss.item()
                tot_clipped += th.mean((th.abs(ratio - 1) > self.clip_range).float()).item()
                n += 1
                # pb.update(self.minibatch_size)

            # Clip grad norm
            if self.max_grad_norm is not None:
                clip_grad_norm_(self.agent.actor.parameters(), self.max_grad_norm)

            self.agent.optimizer.step()
            self.agent.optimizer.zero_grad()

        t1 = time.perf_counter_ns()
        postcompute = torch.cat([param.view(-1) for param in self.agent.actor.parameters()])

        self.logger.log({
            "loss": tot_loss / n,
            "policy_loss": tot_policy_loss / n,
            "entropy_loss": tot_entropy_loss / n,
            "value_loss": tot_value_loss / n,
            "mean_kl": total_kl_div / n,
            "clip_fraction": tot_clipped / n,
            "epoch_time": (t1 - t0) / (1e6 * self.epochs),
            "update_magnitude": th.dist(precompute, postcompute, p=2),
        }, step=iteration, commit=False)  # Is committed after when calculating fps
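
    # Note on the update loop above: each epoch samples one batch of size
    # `batch_size`, the per-minibatch loss is divided by
    # (batch_size / minibatch_size) so gradients accumulate to a batch average,
    # and the optimizer steps once per epoch after optional gradient clipping.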

    def load(self, load_location, continue_iterations=True):
        """
        load the model weights, optimizer values, and metadata

        :param load_location: checkpoint folder to read
        :param continue_iterations: keep the same training steps
        """
        checkpoint = torch.load(load_location)
        self.agent.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.agent.critic.load_state_dict(checkpoint['critic_state_dict'])
        # self.agent.shared.load_state_dict(checkpoint['shared_state_dict'])
        self.agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.running_rew_mean = checkpoint['reward_norm_mean']
        # self.running_rew_std = checkpoint['reward_norm_std']
        self.running_rew_var = checkpoint['reward_norm_var']
        self.running_rew_count = checkpoint['reward_norm_count']
        self.ema_reward_discount = checkpoint['ema_reward_discount']

        if continue_iterations:
            self.starting_iteration = checkpoint['epoch']
            self.total_steps = checkpoint["total_steps"]
            print("Continuing training at iteration " + str(self.starting_iteration))

    def save(self, save_location, current_step, save_actor_jit=False):
        """
        Save the model weights, optimizer values, and metadata

        :param save_location: where to save
        :param current_step: the current iteration when saved. Use to later continue training
        :param save_actor_jit: also save a jit-traced copy of the actor
        """
        version_str = str(self.logger.project) + "_" + str(current_step)
        version_dir = os.path.join(save_location, version_str)
        os.makedirs(version_dir)

        torch.save({
            'epoch': current_step,
            "total_steps": self.total_steps,
            'actor_state_dict': self.agent.actor.state_dict(),
            'critic_state_dict': self.agent.critic.state_dict(),
            # 'shared_state_dict': self.agent.shared.state_dict(),
            'optimizer_state_dict': self.agent.optimizer.state_dict(),
            'reward_norm_mean': self.running_rew_mean,
            # 'reward_norm_std': self.running_rew_std,
            'reward_norm_var': self.running_rew_var,
            'reward_norm_count': self.running_rew_count,
            'ema_reward_discount': self.ema_reward_discount,
        }, os.path.join(version_dir, "checkpoint.pt"))

        if save_actor_jit:
            torch.save(th.jit.trace(self.agent.actor, self.jit_tracer), os.path.join(version_dir, "policy.jit"))
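

# ---------------------------------------------------------------------------
# Minimal usage sketch (kept commented out; not part of the class above).
# `MyRolloutGenerator` and `build_agent()` are hypothetical placeholders for
# whichever BaseRolloutGenerator and ActorCriticAgent a project actually uses;
# only the PPO constructor arguments and `run()` call mirror this file.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     import wandb
#
#     rollout_gen = MyRolloutGenerator(...)   # hypothetical rollout generator
#     agent = build_agent()                   # hypothetical ActorCriticAgent factory
#     logger = wandb.init(project="ppo_with_norm")
#
#     ppo = PPO(
#         rollout_gen,
#         agent,
#         n_steps=4096,
#         batch_size=512,
#         epochs=10,
#         ent_coef=0.01,
#         logger=logger,
#         device="cuda",
#     )
#     ppo.run(iterations_per_save=10, save_dir="checkpoints", save_jit=False)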