-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainDP.py
More file actions
205 lines (183 loc) · 8.06 KB
/
trainDP.py
File metadata and controls
205 lines (183 loc) · 8.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# local file import
from pushTimageEnv import PushTImageEnv
from pushTdataset import PushTImageDataset, gdown
from network import get_resnet, replace_bn_with_gn, ConditionalUnet1D
# diffusion policy import
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
# basic library
import torch
import torch.nn as nn
import numpy as np
from tqdm.auto import tqdm
import os
## Create Env
# One-off demonstration of the environment API. Nothing later in this script
# uses `env`, `obs`, `reward`, etc. for training.
# NOTE(review): the unpacking below follows the newer Gym (>=0.26) / Gymnasium
# API: reset() returns (obs, info) and step() returns a 5-tuple
# (obs, reward, terminated, truncated, info). This is NOT the Gym 0.21 API,
# which returned only obs from reset() and a 4-tuple from step().
# Confirm against PushTImageEnv.
# 0. create env object
env = PushTImageEnv()
# 1. seed env for initial state.
# Per the original note, seeds 0-200 were used to generate the demonstration
# dataset; presumably 1000 is chosen to stay outside that range.
env.seed(1000)
# 2. must reset before use
obs, info = env.reset()
# 3. sample a random action (per the original note, a 2D position in [0, 512])
action = env.action_space.sample()
# 4. take a single environment step with that action
obs, reward, terminated, truncated, info = env.step(action)
## Dataset
# Download the demonstration data from Google Drive unless a local copy exists.
dataset_path = "pusht_cchi_v7_replay.zarr.zip"
if not os.path.isfile(dataset_path):
    # Named `gdrive_file_id` (not `id`) so the builtin `id()` is not shadowed
    # at module level.
    gdrive_file_id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
    gdown.download(id=gdrive_file_id, output=dataset_path, quiet=False)
# parameters
#|o|o|                             observations: 2 (each = image + agent_pos)
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16
pred_horizon = 16    # T_p in Fig. 3 of the Diffusion Policy paper: number of predicted action steps
obs_horizon = 2      # T_o in Fig. 3: number of latest observations fed to the policy
action_horizon = 8   # T_a in the paper: number of predicted actions actually executed
# Gradient accumulation: gradients are summed over this many batches before
# each optimizer step (the loss is divided by this factor in the training loop).
# This allows a larger effective batch without more GPU memory. See the
# option of the same name in the original diffusion_policy repo; background:
# https://zhuanlan.zhihu.com/p/454876670
gradient_accumulate_every = 1
# Training-batch counter. The training loop uses it together with
# gradient_accumulate_every to decide when to step the optimizer.
global_train_step = 0
# Mini-batch size used by the dataloader below.
batch_size = 64

# create dataset from file
# NOTE(review): the original note says this dataset holds 24208 usable windows:
# 25650 total steps minus (action_horizon - 1) = 7 unpadded positions per
# episode. 25650 - 24208 = 1442 = 206 * 7, so that figure implies 206 episodes.
# The note said 204 episodes, which would give 24222 instead.
# Confirm against PushTImageDataset.
dataset = PushTImageDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# create dataloader
# drop_last defaults to False. With 24208 samples and batch_size=64, an epoch
# therefore has ceil(24208 / 64) = 379 batches (378 full plus one of 16), not 378.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker processes after each epoch
    persistent_workers=True
)
## Network
# Per-step observation features: a 512-d ResNet18 image embedding plus the
# 2-D agent position, i.e. 514 features per observation step.
vision_feature_dim = 512
lowdim_obs_dim = 2
obs_dim = vision_feature_dim + lowdim_obs_dim
# actions are 2-D
action_dim = 2

# ResNet18 image encoder with every BatchNorm swapped for GroupNorm. Per the
# original authors, BatchNorm does not work with EMA and performance tanks if
# this replacement is skipped.
vision_encoder = replace_bn_with_gn(get_resnet('resnet18'))

# 1-D conditional U-Net that predicts noise over the action sequence. It is
# conditioned on the flattened features of the last obs_horizon observations.
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_horizon * obs_dim,
)

# Bundle both parts so they share .to(device), .parameters() and state_dict().
nets = nn.ModuleDict({
    'vision_encoder': vision_encoder,
    'noise_pred_net': noise_pred_net,
})
# Noise scheduler: DDPM with 100 diffusion iterations.
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # The network is trained to predict the added noise (epsilon), not the
    # denoised action.
    prediction_type='epsilon',
    # Clip samples to [-1, 1] for numerical stability.
    clip_sample=True,
    # Per the original authors, the beta schedule has a large effect on
    # performance, and squared cosine worked best.
    beta_schedule='squaredcos_cap_v2',
)
# device transfer
# Prefer CUDA, but fall back to CPU so the script still runs on machines
# without a GPU. A hard-coded torch.device('cuda') fails there as soon as
# tensors are moved.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_ = nets.to(device)
## Training
num_epochs = 100
# Exponential moving average of the network weights. `ema` holds its own copy
# of the weights, which the optimizer does not update directly. Per the
# original authors, this speeds up training and improves stability.
ema = EMAModel(power=0.75, parameters=nets.parameters())
# AdamW (Adam with decoupled weight decay) over all network parameters.
optimizer = torch.optim.AdamW(
    params=nets.parameters(),
    weight_decay=1e-6,
    lr=1e-4,
)
# Cosine LR schedule with a 500-step linear warmup. The total is counted in
# optimizer steps: batches per epoch * epochs, divided by the accumulation factor.
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=(len(dataloader) * num_epochs) // gradient_accumulate_every,
    last_epoch=global_train_step - 1,
)
with tqdm(range(num_epochs), desc='Epoch') as tglobal:
    # epoch loop
    for epoch_idx in tglobal:
        epoch_loss = list()
        # batch loop
        with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
            for nbatch in tepoch:
                # Data is normalized by the dataset; move it to `device`.
                # Per the original note, PushTImageDataset already truncates
                # observations to obs_horizon, so this slice is a redundant
                # safeguard.
                nimage = nbatch['image'][:, :obs_horizon].to(device)          # e.g. [64, 2, 3, 96, 96]
                nagent_pos = nbatch['agent_pos'][:, :obs_horizon].to(device)  # e.g. [64, 2, 2]
                naction = nbatch['action'].to(device)                         # e.g. [64, 16, 2]
                B = nagent_pos.shape[0]

                # Encode all B * obs_horizon frames in one pass, then restore
                # the (B, obs_horizon, feature) layout.
                image_features = nets['vision_encoder'](nimage.flatten(end_dim=1))  # e.g. [128, 512]
                image_features = image_features.reshape(*nimage.shape[:2], -1)      # e.g. [64, 2, 512]

                # Concatenate vision features with the low-dim obs to get
                # (B, obs_horizon, obs_dim). Flatten that into the global
                # condition (B, obs_horizon * obs_dim).
                obs_features = torch.cat([image_features, nagent_pos], dim=-1)  # e.g. [64, 2, 514]
                obs_cond = obs_features.flatten(start_dim=1)                   # e.g. [64, 1028]

                # Forward diffusion: add noise to the clean action sequences,
                # using a random diffusion timestep for each sample.
                noise = torch.randn(naction.shape, device=device)
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps,
                    (B,), device=device
                ).long()
                noisy_actions = noise_scheduler.add_noise(
                    naction, noise, timesteps)

                # Predict the added noise, conditioned on the observations.
                noise_pred = nets['noise_pred_net'](
                    noisy_actions, timesteps, global_cond=obs_cond)

                # L2 loss. Divide by the accumulation factor so the summed
                # gradients average over the accumulation window.
                raw_loss = nn.functional.mse_loss(noise_pred, noise)
                loss = raw_loss / gradient_accumulate_every
                loss.backward()

                # Count this batch first, then step once every
                # `gradient_accumulate_every` batches. The counter is
                # incremented on every batch, including each epoch's last one,
                # so the accumulation window never goes out of alignment.
                global_train_step += 1
                if global_train_step % gradient_accumulate_every == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    lr_scheduler.step()
                    # Update the EMA once per actual optimizer update.
                    ema.step(nets.parameters())

                # Log the unscaled loss so the reported value does not depend
                # on gradient_accumulate_every.
                loss_cpu = raw_loss.item()
                epoch_loss.append(loss_cpu)
                tepoch.set_postfix(loss=loss_cpu)
        tglobal.set_postfix(loss=np.mean(epoch_loss))
print("Train End.")
# Inference uses the EMA weights. `ema_nets` is the same object as `nets`, so
# copy_to() overwrites the training weights in place. That is acceptable here
# only because training has finished.
ema_nets = nets
ema.copy_to(ema_nets.parameters())
# Save the (now EMA) parameters to a local checkpoint file.
torch.save(ema_nets.state_dict(), "simpledp.ckpt")
print("Model parameters saved to 'simpledp.ckpt'.")