-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathreplay_buffer.py
More file actions
54 lines (51 loc) · 2.94 KB
/
replay_buffer.py
File metadata and controls
54 lines (51 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import config
import numpy as np
class ReplayBuffer:
    """Fixed-size circular replay buffer holding whole padded episodes.

    Episodes of length ``config.episode_limit`` are stored in
    pre-allocated float32 numpy arrays (one slot per episode, oldest
    overwritten when full) and sampled uniformly at random as torch
    tensors on ``config.device``.
    """

    def __init__(self):
        # All dimensions come from the global config module.
        o_dim = config.obs_shape          # per-agent observation size
        s_dim = config.state_shape        # global state size
        max_epi_len = config.episode_limit
        n_agents = config.n_agents
        n_actions = config.n_actions
        capacity = config.buffer_size     # hoisted: read once, used 12x below
        # Layout: (episode slot, timestep, [agent,] feature)
        self.o_buf = np.zeros((capacity, max_epi_len, n_agents, o_dim), dtype=np.float32)
        self.u_buf = np.zeros((capacity, max_epi_len, n_agents, 1), dtype=np.float32)
        self.s_buf = np.zeros((capacity, max_epi_len, s_dim), dtype=np.float32)
        self.r_buf = np.zeros((capacity, max_epi_len, 1), dtype=np.float32)
        self.d_buf = np.zeros((capacity, max_epi_len, 1), dtype=np.float32)
        # pad flags presumably mark timesteps past the true episode end -- TODO confirm against caller
        self.pad_buf = np.zeros((capacity, max_epi_len, 1), dtype=np.float32)
        self.o2_buf = np.zeros((capacity, max_epi_len, n_agents, o_dim), dtype=np.float32)
        self.s2_buf = np.zeros((capacity, max_epi_len, s_dim), dtype=np.float32)
        self.avail_u_buf = np.zeros((capacity, max_epi_len, n_agents, n_actions), dtype=np.float32)
        self.avail_u2_buf = np.zeros((capacity, max_epi_len, n_agents, n_actions), dtype=np.float32)
        self.u_onehot_buf = np.zeros((capacity, max_epi_len, n_agents, n_actions), dtype=np.float32)
        # ptr: next slot to write (wraps around); size: episodes stored so far
        self.ptr, self.size, self.max_size = 0, 0, capacity

    def store_epi(self, episode):
        """Store one padded episode, overwriting the oldest slot when full.

        ``episode`` maps the keys 'o', 'u', 's', 'r', 'd', 'pad', 'o2',
        's2', 'avail_u', 'avail_u2', 'u_onehot' to arrays matching the
        corresponding buffers' shapes without the leading slot dimension.
        """
        self.o_buf[self.ptr] = episode['o']
        self.u_buf[self.ptr] = episode['u']
        self.s_buf[self.ptr] = episode['s']
        self.r_buf[self.ptr] = episode['r']
        self.d_buf[self.ptr] = episode['d']
        self.pad_buf[self.ptr] = episode['pad']
        self.o2_buf[self.ptr] = episode['o2']
        self.s2_buf[self.ptr] = episode['s2']
        self.avail_u_buf[self.ptr] = episode['avail_u']
        self.avail_u2_buf[self.ptr] = episode['avail_u2']
        self.u_onehot_buf[self.ptr] = episode['u_onehot']
        self.ptr = (self.ptr + 1) % self.max_size
        # BUG FIX: the original line ended in a stray backslash, splicing it
        # onto the following ``def`` line and producing a SyntaxError.
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, batch_size=32):
        """Sample ``batch_size`` episodes uniformly with replacement.

        Returns a dict of torch tensors on ``config.device``, keyed the
        same way as the episode dicts passed to :meth:`store_epi`.
        Requires at least one stored episode (``randint`` raises on an
        empty buffer).
        """
        idxs = np.random.randint(0, self.size, size=batch_size)
        batch = dict(o=self.o_buf[idxs],
                     u=self.u_buf[idxs],
                     s=self.s_buf[idxs],
                     r=self.r_buf[idxs],
                     d=self.d_buf[idxs],
                     pad=self.pad_buf[idxs],
                     o2=self.o2_buf[idxs],
                     s2=self.s2_buf[idxs],
                     avail_u=self.avail_u_buf[idxs],
                     avail_u2=self.avail_u2_buf[idxs],
                     u_onehot=self.u_onehot_buf[idxs])
        device = torch.device(config.device)  # hoisted out of the comprehension
        return {k: torch.as_tensor(v, device=device) for k, v in batch.items()}