From a5991f19c13f1257247c17ad6be8e7187c05709b Mon Sep 17 00:00:00 2001 From: jsrimr Date: Fri, 13 May 2022 04:00:35 +0000 Subject: [PATCH 01/10] diayn_with_controller first commit --- agent/diayn_with_controller.py | 87 ++++++++++++++++++++++++++++++++ agent/diayn_with_controller.yaml | 25 +++++++++ finetune.py | 4 +- 3 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 agent/diayn_with_controller.py create mode 100644 agent/diayn_with_controller.yaml diff --git a/agent/diayn_with_controller.py b/agent/diayn_with_controller.py new file mode 100644 index 0000000..c5a7e21 --- /dev/null +++ b/agent/diayn_with_controller.py @@ -0,0 +1,87 @@ +import math +from collections import OrderedDict + +import hydra +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from dm_env import specs + +import utils +from agent.ddpg import DDPGAgent + + +class DIAYN(nn.Module): + def __init__(self, obs_dim, skill_dim, hidden_dim): + super().__init__() + self.skill_pred_net = nn.Sequential(nn.Linear(obs_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, skill_dim)) + + self.apply(utils.weight_init) + + def forward(self, obs): + skill_pred = self.skill_pred_net(obs) + return skill_pred + + +class DIAYNAgent(DDPGAgent): + def __init__(self, update_skill_every_step, skill_dim, diayn_scale, + update_encoder, **kwargs): + self.skill_dim = skill_dim + self.update_skill_every_step = update_skill_every_step + self.diayn_scale = diayn_scale + self.update_encoder = update_encoder + # increase obs shape to include skill dim + kwargs["meta_dim"] = self.skill_dim + + # create actor and critic + super().__init__(**kwargs) + + # create diayn + self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim, + kwargs['hidden_dim']).to(kwargs['device']) + + # loss criterion + self.diayn_criterion = nn.CrossEntropyLoss() + # optimizers + self.diayn_opt = torch.optim.Adam(self.diayn.parameters(), lr=self.lr) + + self.diayn.train() + + def act(self, obs, meta, step, eval_mode): + obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) + h = self.encoder(obs) + # inputs = [h] + # for value in meta.values(): # skill + # value = torch.as_tensor(value, device=self.device).unsqueeze(0) + # inputs.append(value) + meta = self.diayn(obs) # (B, *obs) + inpt = torch.cat([h, meta], dim=-1) + #assert obs.shape[-1] == self.obs_shape[-1] + stddev = utils.schedule(self.stddev_schedule, step) + dist = self.actor(inpt, stddev) + if eval_mode: + action = dist.mean + else: + action = dist.sample(clip=None) + if step < self.num_expl_steps: + action.uniform_(-1.0, 1.0) + return action.cpu().numpy()[0] + + def init_from(self, other): + # copy parameters over + utils.hard_update_params(other.encoder, self.encoder) + utils.hard_update_params(other.actor, self.actor) + utils.hard_update_params(other.diayn, self.diayn) + if self.init_critic: + utils.hard_update_params(other.critic.trunk, self.critic.trunk) + + def init_meta(self): + return OrderedDict() + + def update_meta(self, meta, global_step, time_step, finetune=False): + return meta diff --git a/agent/diayn_with_controller.yaml b/agent/diayn_with_controller.yaml new file mode 100644 index 0000000..9b3f891 --- /dev/null +++ b/agent/diayn_with_controller.yaml @@ -0,0 +1,25 @@ +# @package agent +_target_: agent.diayn.DIAYNAgent +name: diayn +reward_free: ${reward_free} +obs_type: ??? # to be specified later +obs_shape: ??? # to be specified later +action_shape: ??? # to be specified later +device: ${device} +lr: 1e-4 +critic_target_tau: 0.01 +update_every_steps: 2 +use_tb: ${use_tb} +use_wandb: ${use_wandb} +num_expl_steps: ??? # to be specified later +hidden_dim: 1024 +feature_dim: 50 +stddev_schedule: 0.2 +stddev_clip: 0.3 +skill_dim: 16 +diayn_scale: 1.0 +update_skill_every_step: 50 +nstep: 3 +batch_size: 1024 +init_critic: true +update_encoder: ${update_encoder} diff --git a/finetune.py b/finetune.py index 373f816..5f9b65b 100644 --- a/finetune.py +++ b/finetune.py @@ -187,9 +187,9 @@ def train(self): self.global_frame) self.eval() - meta = self.agent.update_meta(meta, self.global_step, time_step) + # meta = self.agent.update_meta(meta, self.global_step, time_step) - if hasattr(self.agent, "regress_meta"): + if hasattr(self.agent, "regress_meta"): # aps 라는 알고리즘이 사용 repeat = self.cfg.action_repeat every = self.agent.update_task_every_step // repeat init_step = self.agent.num_init_steps From 172b9486fa33dc828de0cfd5e71ea15b64a9fbaa Mon Sep 17 00:00:00 2001 From: jsrimr Date: Fri, 13 May 2022 09:18:31 +0000 Subject: [PATCH 02/10] =?UTF-8?q?controller=20=EB=A1=9C=EC=A7=81=20?= =?UTF-8?q?=EA=B5=AC=ED=98=84=20-=20controller=20=EB=8A=94=201=EA=B0=9C?= =?UTF-8?q?=EC=9D=98=20skill=20=EC=9D=84=20=EC=83=98=ED=94=8C=ED=95=A8=20-?= =?UTF-8?q?=20weight=20=EB=8A=94=20diayn=20=EC=97=90=EC=84=9C=20=EA=B0=80?= =?UTF-8?q?=EC=A0=B8=EC=98=B4=20-=20=EA=B7=B8=EB=A6=AC=EA=B3=A0=20state=20?= =?UTF-8?q?=EC=97=90=20=EB=A7=9E=EB=8A=94=20skill=20=EC=9D=84=20=EB=A6=AC?= =?UTF-8?q?=ED=84=B4=20-=20controller=20=EC=97=85=EB=8D=B0=EC=9D=B4?= =?UTF-8?q?=ED=8A=B8=20=EB=B0=A9=EC=8B=9D=EC=9D=80=20REINFORCE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent/diayn_with_controller.py | 125 ++++++++++++++++++++++++++----- agent/diayn_with_controller.yaml | 4 +- 2 files changed, 109 insertions(+), 20 deletions(-) diff --git a/agent/diayn_with_controller.py b/agent/diayn_with_controller.py index c5a7e21..170164f 100644 --- a/agent/diayn_with_controller.py +++ b/agent/diayn_with_controller.py @@ -28,7 +28,7 @@ def forward(self, obs): return skill_pred -class DIAYNAgent(DDPGAgent): +class DIAYNwithController(DDPGAgent): def __init__(self, update_skill_every_step, skill_dim, diayn_scale, update_encoder, **kwargs): self.skill_dim = skill_dim @@ -45,24 +45,24 @@ def __init__(self, update_skill_every_step, skill_dim, diayn_scale, self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim, kwargs['hidden_dim']).to(kwargs['device']) - # loss criterion - self.diayn_criterion = nn.CrossEntropyLoss() # optimizers self.diayn_opt = torch.optim.Adam(self.diayn.parameters(), lr=self.lr) self.diayn.train() def act(self, obs, meta, step, eval_mode): - obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) - h = self.encoder(obs) - # inputs = [h] - # for value in meta.values(): # skill - # value = torch.as_tensor(value, device=self.device).unsqueeze(0) - # inputs.append(value) - meta = self.diayn(obs) # (B, *obs) - inpt = torch.cat([h, meta], dim=-1) - #assert obs.shape[-1] == self.obs_shape[-1] - stddev = utils.schedule(self.stddev_schedule, step) + + with torch.no_grad(): + obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) + h = self.encoder(obs) + x = self.diayn(obs) # (B, skill_dim) + promising_skills = torch.argmax(x, dim=1) # (B,) + meta = torch.zeros_like(x).scatter_(1, promising_skills.unsqueeze(1), 1.) + + inpt = torch.cat([h, meta], dim=-1) + #assert obs.shape[-1] == self.obs_shape[-1] + stddev = utils.schedule(self.stddev_schedule, step) + dist = self.actor(inpt, stddev) if eval_mode: action = dist.mean @@ -70,7 +70,8 @@ def act(self, obs, meta, step, eval_mode): action = dist.sample(clip=None) if step < self.num_expl_steps: action.uniform_(-1.0, 1.0) - return action.cpu().numpy()[0] + + return action.cpu().numpy()[0], promising_skills def init_from(self, other): # copy parameters over @@ -80,8 +81,96 @@ def init_from(self, other): if self.init_critic: utils.hard_update_params(other.critic.trunk, self.critic.trunk) - def init_meta(self): - return OrderedDict() + def get_meta_specs(self): + return (specs.Array((self.skill_dim,), np.float32, 'skill'),) + + + def update_actor_and_controller(self, obs, step): + metrics = dict() + # update controller + stddev = utils.schedule(self.stddev_schedule, step) + dist = self.actor(obs, stddev) + action = dist.sample(clip=self.stddev_clip) + log_prob = dist.log_prob(action).sum(-1, keepdim=True) + Q1, Q2 = self.critic(obs, action) + Q = torch.min(Q1, Q2) + + actor_loss = -Q.mean() + + # optimize actor + self.actor_opt.zero_grad(set_to_none=True) + actor_loss.backward() + # actor_loss.backward(retain_graph=True) + self.actor_opt.step() + + if self.use_tb or self.use_wandb: + metrics['actor_loss'] = actor_loss.item() + metrics['actor_logprob'] = log_prob.mean().item() + metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item() + + # update controller + skill_pred = self.diayn(obs[:, :-self.skill_dim]) + actual_skill = obs[:, -self.skill_dim:].argmax(dim=1) + loss = - Q.mean().detach() * skill_pred.gather(1, actual_skill) + + self.diayn_opt.zero_grad() + if self.encoder_opt is not None: + self.encoder_opt.zero_grad(set_to_none=True) + loss.backward() + self.diayn_opt.step() + if self.encoder_opt is not None: + self.encoder_opt.step() + + if self.use_tb or self.use_wandb: + metrics['diayn_loss'] = loss.item() + + return metrics + + + def update(self, replay_iter, step): + """ + only ext_reward logic + update controller + + """ + metrics = dict() + + if step % self.update_every_steps != 0: + return metrics + + batch = next(replay_iter) + + obs, action, extr_reward, discount, next_obs, skill = utils.to_torch( + batch, self.device) + + # augment and encode + obs = self.aug_and_encode(obs) + next_obs = self.aug_and_encode(next_obs) + reward = extr_reward + + if self.use_tb or self.use_wandb: + metrics['extr_reward'] = extr_reward.mean().item() + metrics['batch_reward'] = reward.mean().item() + + if not self.update_encoder: + obs = obs.detach() + next_obs = next_obs.detach() + + # extend observations with skill + obs = torch.cat([obs, skill], dim=1) + next_obs = torch.cat([next_obs, skill], dim=1) + + # update critic + metrics.update( + self.update_critic(obs.detach(), action, reward, discount, + next_obs.detach(), step)) + + # update actor and controller + metrics.update(self.update_actor_and_controller(obs.detach(), step)) + + + # update critic target + utils.soft_update_params(self.critic, self.critic_target, + self.critic_target_tau) - def update_meta(self, meta, global_step, time_step, finetune=False): - return meta + return metrics \ No newline at end of file diff --git a/agent/diayn_with_controller.yaml b/agent/diayn_with_controller.yaml index 9b3f891..4a3773b 100644 --- a/agent/diayn_with_controller.yaml +++ b/agent/diayn_with_controller.yaml @@ -1,6 +1,6 @@ # @package agent -_target_: agent.diayn.DIAYNAgent -name: diayn +_target_: agent.diayn_with_controller.DIAYNwithController +name: diayn_with_controller reward_free: ${reward_free} obs_type: ??? # to be specified later obs_shape: ??? # to be specified later From 46236f4e8b51434d6dda9e254d2ebf7110211007 Mon Sep 17 00:00:00 2001 From: jsrimr Date: Sat, 14 May 2022 06:07:27 +0000 Subject: [PATCH 03/10] =?UTF-8?q?controller=20=EC=97=90=EC=84=9C=20softmax?= =?UTF-8?q?=20=EB=A1=9C=20state=20=EB=A7=88=EB=8B=A4=20skill=20=EB=8C=80?= =?UTF-8?q?=EC=9D=91=EC=8B=9C=EC=BC=9C=EC=A3=BC=EA=B8=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent/diayn_with_controller.py | 43 +++++++++++++++++++++++----------- finetune.py | 3 +++ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/agent/diayn_with_controller.py b/agent/diayn_with_controller.py index 170164f..407812c 100644 --- a/agent/diayn_with_controller.py +++ b/agent/diayn_with_controller.py @@ -6,9 +6,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +import utils from dm_env import specs +from torch.distributions import Categorical -import utils from agent.ddpg import DDPGAgent @@ -37,6 +38,7 @@ def __init__(self, update_skill_every_step, skill_dim, diayn_scale, self.update_encoder = update_encoder # increase obs shape to include skill dim kwargs["meta_dim"] = self.skill_dim + self.current_meta = None # create actor and critic super().__init__(**kwargs) @@ -44,6 +46,7 @@ def __init__(self, update_skill_every_step, skill_dim, diayn_scale, # create diayn self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim, kwargs['hidden_dim']).to(kwargs['device']) + self.softmax = nn.Softmax(dim=1) # optimizers self.diayn_opt = torch.optim.Adam(self.diayn.parameters(), lr=self.lr) @@ -51,17 +54,22 @@ def __init__(self, update_skill_every_step, skill_dim, diayn_scale, self.diayn.train() def act(self, obs, meta, step, eval_mode): - + """ + meta from passed parameter is useless + """ with torch.no_grad(): obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) h = self.encoder(obs) x = self.diayn(obs) # (B, skill_dim) - promising_skills = torch.argmax(x, dim=1) # (B,) - meta = torch.zeros_like(x).scatter_(1, promising_skills.unsqueeze(1), 1.) - - inpt = torch.cat([h, meta], dim=-1) - #assert obs.shape[-1] == self.obs_shape[-1] - stddev = utils.schedule(self.stddev_schedule, step) + skill_dist = Categorical(self.softmax(x)) + skill = skill_dist.sample() # Or, skill = torch.argmax(x, dim=1) # (B,1) + meta = torch.zeros(self.skill_dim, device=self.device) + meta[skill] = 1.0 + self.current_meta = meta + + inpt = torch.cat([h, meta.unsqueeze(0)], dim=-1) + #assert obs.shape[-1] == self.obs_shape[-1] + stddev = utils.schedule(self.stddev_schedule, step) dist = self.actor(inpt, stddev) if eval_mode: @@ -71,7 +79,7 @@ def act(self, obs, meta, step, eval_mode): if step < self.num_expl_steps: action.uniform_(-1.0, 1.0) - return action.cpu().numpy()[0], promising_skills + return action.cpu().numpy()[0] def init_from(self, other): # copy parameters over @@ -84,6 +92,10 @@ def init_from(self, other): def get_meta_specs(self): return (specs.Array((self.skill_dim,), np.float32, 'skill'),) + def get_current_meta(self): + meta = OrderedDict() + meta['skill'] = self.current_meta.cpu().numpy() + return meta def update_actor_and_controller(self, obs, step): metrics = dict() @@ -109,9 +121,14 @@ def update_actor_and_controller(self, obs, step): metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item() # update controller - skill_pred = self.diayn(obs[:, :-self.skill_dim]) - actual_skill = obs[:, -self.skill_dim:].argmax(dim=1) - loss = - Q.mean().detach() * skill_pred.gather(1, actual_skill) + + skill_prob = self.softmax(self.diayn(obs[:, :-self.skill_dim])) + skill_dist = Categorical(skill_prob) + + skill_used = obs[:, -self.skill_dim:].argmax(dim=1) + skill_dist.log_prob(torch.tensor(1, device="cuda:0")).shape + # loss = - Q.mean().detach() * skill_pred.gather(1, skill_used.unsqueeze(0)).log_prob() + loss = (-Q.detach() * skill_dist.log_prob(skill_used).unsqueeze(1)).mean() self.diayn_opt.zero_grad() if self.encoder_opt is not None: @@ -126,7 +143,6 @@ def update_actor_and_controller(self, obs, step): return metrics - def update(self, replay_iter, step): """ only ext_reward logic @@ -168,7 +184,6 @@ def update(self, replay_iter, step): # update actor and controller metrics.update(self.update_actor_and_controller(obs.detach(), step)) - # update critic target utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau) diff --git a/finetune.py b/finetune.py index 5f9b65b..35a3950 100644 --- a/finetune.py +++ b/finetune.py @@ -205,6 +205,9 @@ def train(self): self.global_step, eval_mode=False) + if hasattr(self.agent, "get_current_meta"): # skill_controller 라는 알고리즘이 사용 + meta = self.agent.get_current_meta() + # try to update the agent if not seed_until_step(self.global_step): metrics = self.agent.update(self.replay_iter, self.global_step) From f48dafb94e615603ed0b405a1b8bdc593bed6d82 Mon Sep 17 00:00:00 2001 From: jsrimr Date: Sat, 14 May 2022 07:34:15 +0000 Subject: [PATCH 04/10] apply advantage learning --- agent/diayn_with_controller.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/agent/diayn_with_controller.py b/agent/diayn_with_controller.py index 407812c..f184663 100644 --- a/agent/diayn_with_controller.py +++ b/agent/diayn_with_controller.py @@ -128,7 +128,8 @@ def update_actor_and_controller(self, obs, step): skill_used = obs[:, -self.skill_dim:].argmax(dim=1) skill_dist.log_prob(torch.tensor(1, device="cuda:0")).shape # loss = - Q.mean().detach() * skill_pred.gather(1, skill_used.unsqueeze(0)).log_prob() - loss = (-Q.detach() * skill_dist.log_prob(skill_used).unsqueeze(1)).mean() + advantage = (Q - Q.mean(dim=1)).detach() + loss = -(advantage * skill_dist.log_prob(skill_used).unsqueeze(1)).mean() self.diayn_opt.zero_grad() if self.encoder_opt is not None: From 52f6501acbbca2b269d0e389b5d997df113ab7d8 Mon Sep 17 00:00:00 2001 From: jsrimr Date: Sat, 14 May 2022 11:02:59 +0000 Subject: [PATCH 05/10] =?UTF-8?q?Actor=20=EB=A5=BC=20=EC=83=88=EB=A1=9C=20?= =?UTF-8?q?=EB=A7=8C=EB=93=9C=EB=8A=94=EA=B2=8C=20=EB=A7=9E=EB=82=98=3F=20?= =?UTF-8?q?update=20=ED=95=A8=EC=88=98=EB=A5=BC=20=EC=98=A4=EB=B2=84?= =?UTF-8?q?=EB=9D=BC=EC=9D=B4=EB=93=9C=20=ED=95=98=EB=8A=94=EA=B2=8C=20?= =?UTF-8?q?=EB=A7=9E=EB=82=98=3F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent/diayn_as_importance_predictor.py | 134 +++++++++++++++++++++++ agent/diayn_as_importance_predictor.yaml | 25 +++++ 2 files changed, 159 insertions(+) create mode 100644 agent/diayn_as_importance_predictor.py create mode 100644 agent/diayn_as_importance_predictor.yaml diff --git a/agent/diayn_as_importance_predictor.py b/agent/diayn_as_importance_predictor.py new file mode 100644 index 0000000..ee384a5 --- /dev/null +++ b/agent/diayn_as_importance_predictor.py @@ -0,0 +1,134 @@ +import math +from collections import OrderedDict + +import hydra +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import utils +from dm_env import specs +from agent.ddpg import Actor, DDPGAgent + + +class DIAYN(nn.Module): + def __init__(self, obs_dim, skill_dim, hidden_dim): + super().__init__() + self.skill_pred_net = nn.Sequential(nn.Linear(obs_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, skill_dim)) + + self.apply(utils.weight_init) + + def forward(self, obs): + skill_pred = self.skill_pred_net(obs) + return skill_pred + +class SkillMixingActor(Actor): + def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim, pretrained_actor, diayn): + super().__init__(obs_type, obs_dim, action_dim, feature_dim, hidden_dim) + + self.pretrained_actor = pretrained_actor + self.diayn = diayn + + def forward(self, obs, std): + + return super().forward(obs, std) + + +class DIAYNasWeightPredictorAgent(DDPGAgent): + def __init__(self, update_skill_every_step, skill_dim, diayn_scale, + update_encoder, **kwargs): + self.skill_dim = skill_dim + self.update_skill_every_step = update_skill_every_step + self.diayn_scale = diayn_scale + self.update_encoder = update_encoder + # increase obs shape to include skill dim + kwargs["meta_dim"] = self.skill_dim + + # create actor and critic + super().__init__(**kwargs) + self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim, + kwargs['hidden_dim']).to(kwargs['device']) + self.diayn.train() + + self.pretrained_actor = self.actor + self.actor = SkillMixingActor(self.obs_type, self.obs_dim, self.action_dim, + self.feature_dim, self.hidden_dim, self.pretrained_actor, self.diayn).to(kwargs['device']) + + self.actor_opt = torch.optim.Adam(self.actor.parameters()) + list(self.diayn.parameters()), lr=self.lr) + + + def update_actor(self, obs, step): + metrics = dict() + + obs = torch.as_tensor(obs, device=self.device) # (B, obs_dim) + skill_weight = F.softmax(self.diayn(obs), dim=-1).unsqueeze(-1) # (B, skill_dim, 1) + + obs = obs.unsqueeze(1).repeat(1, self.skill_dim, 1) # (B, skill_dim, obs_dim) + skill_list = torch.eye(self.skill_dim, device=self.device).unsqueeze(0) # (1, skill_dim, skill_dim) + skill_list = skill_list.repeat(obs.shape[0], 1, 1) # (B, skill_dim, skill_dim) + obs_with_skill = torch.cat([obs, skill_list], dim=-1) # (B, skill_dim, skill_dim + obs_dim) + + inpt = obs_with_skill * skill_weight # (B, skill_dim, skill_dim + obs_dim) + h = self.encoder(inpt) + + stddev = utils.schedule(self.stddev_schedule, step) + dist = self.actor(h, stddev) + action = dist.sample(clip=self.stddev_clip) + log_prob = dist.log_prob(action).sum(-1, keepdim=True) + Q1, Q2 = self.critic(obs, action) + Q = torch.min(Q1, Q2) + + actor_loss = -Q.mean() + + # optimize actor + self.actor_opt.zero_grad(set_to_none=True) + actor_loss.backward() + # actor_loss.backward(retain_graph=True) + self.actor_opt.step() + + if self.use_tb or self.use_wandb: + metrics['actor_loss'] = actor_loss.item() + metrics['actor_logprob'] = log_prob.mean().item() + metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item() + + return metrics + + def act(self, obs, meta, step, eval_mode): + """ + meta from passed parameter is useless + """ + #assert obs.shape[-1] == self.obs_shape[-1] + with torch.no_grad(): + obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) # (1, obs_dim) + skill_weight = F.softmax(self.diayn(obs), dim=-1) # (1, skill_dim) + + obs = obs.repeat(self.skill_dim, 1) # (skill_dim, obs_dim) + skill_list = torch.eye(self.skill_dim, device=self.device) # (skill_dim, skill_dim) + obs_with_skill = torch.cat([obs, skill_list], dim=-1) # (skill_dim, skill_dim + obs_dim) + + inpt = obs_with_skill * skill_weight.view(self.skill_dim, 1) # (skill_dim, skill_dim + obs_dim) + + h = self.encoder(inpt) + stddev = utils.schedule(self.stddev_schedule, step) + dist = self.actor(h, stddev) + if eval_mode: + action = dist.mean + else: + action = dist.sample(clip=None) + if step < self.num_expl_steps: + action.uniform_(-1.0, 1.0) + + return action.cpu().numpy()[0] + + + def init_from(self, other): + # copy parameters over + utils.hard_update_params(other.encoder, self.encoder) + utils.hard_update_params(other.actor, self.pretrained_actor) + utils.hard_update_params(other.diayn, self.diayn) + if self.init_critic: + utils.hard_update_params(other.critic.trunk, self.critic.trunk) diff --git a/agent/diayn_as_importance_predictor.yaml b/agent/diayn_as_importance_predictor.yaml new file mode 100644 index 0000000..9207a0d --- /dev/null +++ b/agent/diayn_as_importance_predictor.yaml @@ -0,0 +1,25 @@ +# @package agent +_target_: agent.diayn_as_importance_predictor.DIAYNasWeightPredictorAgent +name: diayn_as_importance_predictor +reward_free: ${reward_free} +obs_type: ??? # to be specified later +obs_shape: ??? # to be specified later +action_shape: ??? # to be specified later +device: ${device} +lr: 1e-4 +critic_target_tau: 0.01 +update_every_steps: 2 +use_tb: ${use_tb} +use_wandb: ${use_wandb} +num_expl_steps: ??? # to be specified later +hidden_dim: 1024 +feature_dim: 50 +stddev_schedule: 0.2 +stddev_clip: 0.3 +skill_dim: 16 +diayn_scale: 1.0 +update_skill_every_step: 50 +nstep: 3 +batch_size: 1024 +init_critic: true +update_encoder: ${update_encoder} From ffa5bf839273bfa59dd287fefec2bfd2c87147ad Mon Sep 17 00:00:00 2001 From: jsrimr Date: Sat, 14 May 2022 12:43:02 +0000 Subject: [PATCH 06/10] implemented simple weight on skill strategy --- agent/diayn_as_importance_predictor.py | 139 +++++++++++++------------ 1 file changed, 74 insertions(+), 65 deletions(-) diff --git a/agent/diayn_as_importance_predictor.py b/agent/diayn_as_importance_predictor.py index ee384a5..534c177 100644 --- a/agent/diayn_as_importance_predictor.py +++ b/agent/diayn_as_importance_predictor.py @@ -6,9 +6,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +from zmq import device import utils from dm_env import specs -from agent.ddpg import Actor, DDPGAgent +from agent.ddpg import DDPGAgent class DIAYN(nn.Module): @@ -26,18 +27,8 @@ def forward(self, obs): skill_pred = self.skill_pred_net(obs) return skill_pred -class SkillMixingActor(Actor): - def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim, pretrained_actor, diayn): - super().__init__(obs_type, obs_dim, action_dim, feature_dim, hidden_dim) - self.pretrained_actor = pretrained_actor - self.diayn = diayn - def forward(self, obs, std): - - return super().forward(obs, std) - - class DIAYNasWeightPredictorAgent(DDPGAgent): def __init__(self, update_skill_every_step, skill_dim, diayn_scale, update_encoder, **kwargs): @@ -50,77 +41,95 @@ def __init__(self, update_skill_every_step, skill_dim, diayn_scale, # create actor and critic super().__init__(**kwargs) - self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim, - kwargs['hidden_dim']).to(kwargs['device']) - self.diayn.train() - - self.pretrained_actor = self.actor - self.actor = SkillMixingActor(self.obs_type, self.obs_dim, self.action_dim, - self.feature_dim, self.hidden_dim, self.pretrained_actor, self.diayn).to(kwargs['device']) + # self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim, + # kwargs['hidden_dim']).to(kwargs['device']) + # self.diayn.train() - self.actor_opt = torch.optim.Adam(self.actor.parameters()) + list(self.diayn.parameters()), lr=self.lr) + self.weight_param = nn.Parameter(torch.rand(self.skill_dim, device=kwargs['device'])) + self.actor_opt = torch.optim.Adam(list(self.actor.parameters()) + [self.weight_param], lr=self.lr) - def update_actor(self, obs, step): + def update(self, replay_iter, step): metrics = dict() - obs = torch.as_tensor(obs, device=self.device) # (B, obs_dim) - skill_weight = F.softmax(self.diayn(obs), dim=-1).unsqueeze(-1) # (B, skill_dim, 1) + if step % self.update_every_steps != 0: + return metrics - obs = obs.unsqueeze(1).repeat(1, self.skill_dim, 1) # (B, skill_dim, obs_dim) - skill_list = torch.eye(self.skill_dim, device=self.device).unsqueeze(0) # (1, skill_dim, skill_dim) - skill_list = skill_list.repeat(obs.shape[0], 1, 1) # (B, skill_dim, skill_dim) - obs_with_skill = torch.cat([obs, skill_list], dim=-1) # (B, skill_dim, skill_dim + obs_dim) + batch = next(replay_iter) - inpt = obs_with_skill * skill_weight # (B, skill_dim, skill_dim + obs_dim) - h = self.encoder(inpt) + obs, action, extr_reward, discount, next_obs = utils.to_torch( + batch, self.device) - stddev = utils.schedule(self.stddev_schedule, step) - dist = self.actor(h, stddev) - action = dist.sample(clip=self.stddev_clip) - log_prob = dist.log_prob(action).sum(-1, keepdim=True) - Q1, Q2 = self.critic(obs, action) - Q = torch.min(Q1, Q2) - - actor_loss = -Q.mean() + # augment and encode : state 일 땐 aug_and_encode 의미없음 + obs = self.aug_and_encode(obs) + next_obs = self.aug_and_encode(next_obs) - # optimize actor - self.actor_opt.zero_grad(set_to_none=True) - actor_loss.backward() - # actor_loss.backward(retain_graph=True) - self.actor_opt.step() + reward = extr_reward if self.use_tb or self.use_wandb: - metrics['actor_loss'] = actor_loss.item() - metrics['actor_logprob'] = log_prob.mean().item() - metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item() + metrics['extr_reward'] = extr_reward.mean().item() + metrics['batch_reward'] = reward.mean().item() + + if not self.update_encoder: + obs = obs.detach() + next_obs = next_obs.detach() + + # extend observations with skill + # obs = torch.cat([obs, skill], dim=1) + # next_obs = torch.cat([next_obs, skill], dim=1) + + obs = self.mix_skill_obs(obs) + next_obs = self.mix_skill_obs(next_obs) + + # update critic + metrics.update( + self.update_critic(obs.detach(), action, reward, discount, + next_obs.detach(), step)) + + # update actor + metrics.update(self.update_actor(obs.detach(), step)) + + # update critic target + utils.soft_update_params(self.critic, self.critic_target, + self.critic_target_tau) return metrics + def mix_skill_obs(self, obs): + """ + (B, obs_dim) => (B, skill_dim + obs_dim) + """ + obs = torch.as_tensor(obs, device=self.device).unsqueeze(1) # (B, 1, obs_dim) + obs = obs.repeat(1, self.skill_dim, 1) # (B, skill_dim, obs_dim) + + skill_list = torch.eye(self.skill_dim, device=self.device).unsqueeze(0) # (1, skill_dim, skill_dim) + skill_list = skill_list.repeat(obs.shape[0], 1, 1) # (B, skill_dim, skill_dim) + state_with_skill = torch.cat([obs, skill_list], dim=-1) # (B, skill_dim, skill_dim + obs_dim) + + # skill_weight = F.softmax(self.diayn(obs), dim=-1).unsqueeze(-1) # (B, skill_dim, 1) + skill_weight = F.softmax(self.weight_param, dim=0).unsqueeze(0).repeat(obs.shape[0],1).unsqueeze(-1) # (B, skill_dim, 1) + + processed = state_with_skill * skill_weight # (B, skill_dim, skill_dim + obs_dim) + + return processed.sum(dim=1) # (B, skill_dim + obs_dim) + def act(self, obs, meta, step, eval_mode): """ meta from passed parameter is useless """ #assert obs.shape[-1] == self.obs_shape[-1] - with torch.no_grad(): - obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) # (1, obs_dim) - skill_weight = F.softmax(self.diayn(obs), dim=-1) # (1, skill_dim) - - obs = obs.repeat(self.skill_dim, 1) # (skill_dim, obs_dim) - skill_list = torch.eye(self.skill_dim, device=self.device) # (skill_dim, skill_dim) - obs_with_skill = torch.cat([obs, skill_list], dim=-1) # (skill_dim, skill_dim + obs_dim) - - inpt = obs_with_skill * skill_weight.view(self.skill_dim, 1) # (skill_dim, skill_dim + obs_dim) - - h = self.encoder(inpt) - stddev = utils.schedule(self.stddev_schedule, step) - dist = self.actor(h, stddev) - if eval_mode: - action = dist.mean - else: - action = dist.sample(clip=None) - if step < self.num_expl_steps: - action.uniform_(-1.0, 1.0) + obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) # (1, obs_dim) + + inpt = self.mix_skill_obs(obs) + h = self.encoder(inpt) + stddev = utils.schedule(self.stddev_schedule, step) + dist = self.actor(h, stddev) + if eval_mode: + action = dist.mean + else: + action = dist.sample(clip=None) + if step < self.num_expl_steps: + action.uniform_(-1.0, 1.0) return action.cpu().numpy()[0] @@ -128,7 +137,7 @@ def act(self, obs, meta, step, eval_mode): def init_from(self, other): # copy parameters over utils.hard_update_params(other.encoder, self.encoder) - utils.hard_update_params(other.actor, self.pretrained_actor) - utils.hard_update_params(other.diayn, self.diayn) + utils.hard_update_params(other.actor, self.actor) + # utils.hard_update_params(other.diayn, self.diayn) if self.init_critic: utils.hard_update_params(other.critic.trunk, self.critic.trunk) From 87ff7710632cc1252316c7978255752bd8b35ad6 Mon Sep 17 00:00:00 2001 From: jsrimr Date: Thu, 19 May 2022 14:37:05 +0000 Subject: [PATCH 07/10] changed name: diayn_with_controller->diayn_simple_weight --- ...ce_predictor.py => diayn_simple_weight.py} | 0 ...redictor.yaml => diayn_simple_weight.yaml} | 2 +- agent/diayn_with_controller.py | 192 ------------------ agent/diayn_with_controller.yaml | 25 --- finetune.py | 26 ++- 5 files changed, 21 insertions(+), 224 deletions(-) rename agent/{diayn_as_importance_predictor.py => diayn_simple_weight.py} (100%) rename agent/{diayn_as_importance_predictor.yaml => diayn_simple_weight.yaml} (88%) delete mode 100644 agent/diayn_with_controller.py delete mode 100644 agent/diayn_with_controller.yaml diff --git a/agent/diayn_as_importance_predictor.py b/agent/diayn_simple_weight.py similarity index 100% rename from agent/diayn_as_importance_predictor.py rename to agent/diayn_simple_weight.py diff --git a/agent/diayn_as_importance_predictor.yaml b/agent/diayn_simple_weight.yaml similarity index 88% rename from agent/diayn_as_importance_predictor.yaml rename to agent/diayn_simple_weight.yaml index 9207a0d..1794303 100644 --- a/agent/diayn_as_importance_predictor.yaml +++ b/agent/diayn_simple_weight.yaml @@ -1,5 +1,5 @@ # @package agent -_target_: agent.diayn_as_importance_predictor.DIAYNasWeightPredictorAgent +_target_: agent.diayn_simple_weight.DIAYNasWeightPredictorAgent name: diayn_as_importance_predictor reward_free: ${reward_free} obs_type: ??? # to be specified later diff --git a/agent/diayn_with_controller.py b/agent/diayn_with_controller.py deleted file mode 100644 index f184663..0000000 --- a/agent/diayn_with_controller.py +++ /dev/null @@ -1,192 +0,0 @@ -import math -from collections import OrderedDict - -import hydra -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import utils -from dm_env import specs -from torch.distributions import Categorical - -from agent.ddpg import DDPGAgent - - -class DIAYN(nn.Module): - def __init__(self, obs_dim, skill_dim, hidden_dim): - super().__init__() - self.skill_pred_net = nn.Sequential(nn.Linear(obs_dim, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, skill_dim)) - - self.apply(utils.weight_init) - - def forward(self, obs): - skill_pred = self.skill_pred_net(obs) - return skill_pred - - -class DIAYNwithController(DDPGAgent): - def __init__(self, update_skill_every_step, skill_dim, diayn_scale, - update_encoder, **kwargs): - self.skill_dim = skill_dim - self.update_skill_every_step = update_skill_every_step - self.diayn_scale = diayn_scale - self.update_encoder = update_encoder - # increase obs shape to include skill dim - kwargs["meta_dim"] = self.skill_dim - self.current_meta = None - - # create actor and critic - super().__init__(**kwargs) - - # create diayn - self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim, - kwargs['hidden_dim']).to(kwargs['device']) - self.softmax = nn.Softmax(dim=1) - - # optimizers - self.diayn_opt = torch.optim.Adam(self.diayn.parameters(), lr=self.lr) - - self.diayn.train() - - def act(self, obs, meta, step, eval_mode): - """ - meta from passed parameter is useless - """ - with torch.no_grad(): - obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) - h = self.encoder(obs) - x = self.diayn(obs) # (B, skill_dim) - skill_dist = Categorical(self.softmax(x)) - skill = skill_dist.sample() # Or, skill = torch.argmax(x, dim=1) # (B,1) - meta = torch.zeros(self.skill_dim, device=self.device) - meta[skill] = 1.0 - self.current_meta = meta - - inpt = torch.cat([h, meta.unsqueeze(0)], dim=-1) - #assert obs.shape[-1] == self.obs_shape[-1] - stddev = utils.schedule(self.stddev_schedule, step) - - dist = self.actor(inpt, stddev) - if eval_mode: - action = dist.mean - else: - action = dist.sample(clip=None) - if step < self.num_expl_steps: - action.uniform_(-1.0, 1.0) - - return action.cpu().numpy()[0] - - def init_from(self, other): - # copy parameters over - utils.hard_update_params(other.encoder, self.encoder) - utils.hard_update_params(other.actor, self.actor) - utils.hard_update_params(other.diayn, self.diayn) - if self.init_critic: - utils.hard_update_params(other.critic.trunk, self.critic.trunk) - - def get_meta_specs(self): - return (specs.Array((self.skill_dim,), np.float32, 'skill'),) - - def get_current_meta(self): - meta = OrderedDict() - meta['skill'] = self.current_meta.cpu().numpy() - return meta - - def update_actor_and_controller(self, obs, step): - metrics = dict() - # update controller - stddev = utils.schedule(self.stddev_schedule, step) - dist = self.actor(obs, stddev) - action = dist.sample(clip=self.stddev_clip) - log_prob = dist.log_prob(action).sum(-1, keepdim=True) - Q1, Q2 = self.critic(obs, action) - Q = torch.min(Q1, Q2) - - actor_loss = -Q.mean() - - # optimize actor - self.actor_opt.zero_grad(set_to_none=True) - actor_loss.backward() - # actor_loss.backward(retain_graph=True) - self.actor_opt.step() - - if self.use_tb or self.use_wandb: - metrics['actor_loss'] = actor_loss.item() - metrics['actor_logprob'] = log_prob.mean().item() - metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item() - - # update controller - - skill_prob = self.softmax(self.diayn(obs[:, :-self.skill_dim])) - skill_dist = Categorical(skill_prob) - - skill_used = obs[:, -self.skill_dim:].argmax(dim=1) - skill_dist.log_prob(torch.tensor(1, device="cuda:0")).shape - # loss = - Q.mean().detach() * skill_pred.gather(1, skill_used.unsqueeze(0)).log_prob() - advantage = (Q - Q.mean(dim=1)).detach() - loss = -(advantage * skill_dist.log_prob(skill_used).unsqueeze(1)).mean() - - self.diayn_opt.zero_grad() - if self.encoder_opt is not None: - self.encoder_opt.zero_grad(set_to_none=True) - loss.backward() - self.diayn_opt.step() - if self.encoder_opt is not None: - self.encoder_opt.step() - - if self.use_tb or self.use_wandb: - metrics['diayn_loss'] = loss.item() - - return metrics - - def update(self, replay_iter, step): - """ - only ext_reward logic - update controller - - """ - metrics = dict() - - if step % self.update_every_steps != 0: - return metrics - - batch = next(replay_iter) - - obs, action, extr_reward, discount, next_obs, skill = utils.to_torch( - batch, self.device) - - # augment and encode - obs = self.aug_and_encode(obs) - next_obs = self.aug_and_encode(next_obs) - reward = extr_reward - - if self.use_tb or self.use_wandb: - metrics['extr_reward'] = extr_reward.mean().item() - metrics['batch_reward'] = reward.mean().item() - - if not self.update_encoder: - obs = obs.detach() - next_obs = next_obs.detach() - - # extend observations with skill - obs = torch.cat([obs, skill], dim=1) - next_obs = torch.cat([next_obs, skill], dim=1) - - # update critic - metrics.update( - self.update_critic(obs.detach(), action, reward, discount, - next_obs.detach(), step)) - - # update actor and controller - metrics.update(self.update_actor_and_controller(obs.detach(), step)) - - # update critic target - utils.soft_update_params(self.critic, self.critic_target, - self.critic_target_tau) - - return metrics \ No newline at end of file diff --git a/agent/diayn_with_controller.yaml b/agent/diayn_with_controller.yaml deleted file mode 100644 index 4a3773b..0000000 --- a/agent/diayn_with_controller.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# @package agent -_target_: agent.diayn_with_controller.DIAYNwithController -name: diayn_with_controller -reward_free: ${reward_free} -obs_type: ??? # to be specified later -obs_shape: ??? # to be specified later -action_shape: ??? # to be specified later -device: ${device} -lr: 1e-4 -critic_target_tau: 0.01 -update_every_steps: 2 -use_tb: ${use_tb} -use_wandb: ${use_wandb} -num_expl_steps: ??? # to be specified later -hidden_dim: 1024 -feature_dim: 50 -stddev_schedule: 0.2 -stddev_clip: 0.3 -skill_dim: 16 -diayn_scale: 1.0 -update_skill_every_step: 50 -nstep: 3 -batch_size: 1024 -init_critic: true -update_encoder: ${update_encoder} diff --git a/finetune.py b/finetune.py index 35a3950..3e0157a 100644 --- a/finetune.py +++ b/finetune.py @@ -138,6 +138,15 @@ def eval(self): log('episode', self.global_episode) log('step', self.global_step) + def save_snapshot(self): + snapshot_dir = self.work_dir / "snapshot" + snapshot_dir.mkdir(exist_ok=True, parents=True) + snapshot = snapshot_dir / f'snapshot_{self.global_frame}.pt' + keys_to_save = ['agent', '_global_step', '_global_episode'] + payload = {k: self.__dict__[k] for k in keys_to_save} + with snapshot.open('wb') as f: + torch.save(payload, f) + def train(self): # predicates train_until_step = utils.Until(self.cfg.num_train_frames, @@ -157,6 +166,7 @@ def train(self): if time_step.last(): self._global_episode += 1 self.train_video_recorder.save(f'{self.global_frame}.mp4') + # wait until all the metrics schema is populated if metrics is not None: # log stats @@ -177,6 +187,7 @@ def train(self): meta = self.agent.init_meta() self.replay_storage.add(time_step, meta) self.train_video_recorder.init(time_step.observation) + episode_step = 0 episode_reward = 0 @@ -186,8 +197,10 @@ def train(self): self.logger.log('eval_total_time', self.timer.total_time(), self.global_frame) self.eval() + # save_snapshot + self.save_snapshot() - # meta = self.agent.update_meta(meta, self.global_step, time_step) + meta = self.agent.update_meta(meta, self.global_step, time_step) if hasattr(self.agent, "regress_meta"): # aps 라는 알고리즘이 사용 repeat = self.cfg.action_repeat @@ -241,11 +254,12 @@ def try_load(seed): if payload is not None: return payload # otherwise try random seed - while True: - seed = np.random.randint(1, 11) - payload = try_load(seed) - if payload is not None: - return payload + # while True: + # seed = np.random.randint(1, 11) + # payload = try_load(seed) + # if payload is not None: + # return payload + print(f"failed to load from {snapshot_dir}") return None From 5027aa5f70917f34a0280a3acd2dc0e552f9e2fc Mon Sep 17 00:00:00 2001 From: jsrimr Date: Fri, 20 May 2022 17:03:58 +0000 Subject: [PATCH 08/10] use same weight --- agent/diayn_same_weight.py | 144 +++++++++++++++++++++++++++++++++++ agent/diayn_same_weight.yaml | 25 ++++++ 2 files changed, 169 insertions(+) create mode 100644 agent/diayn_same_weight.py create mode 100644 agent/diayn_same_weight.yaml diff --git a/agent/diayn_same_weight.py b/agent/diayn_same_weight.py new file mode 100644 index 0000000..5de3b51 --- /dev/null +++ b/agent/diayn_same_weight.py @@ -0,0 +1,144 @@ +import math +from collections import OrderedDict + +import hydra +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from zmq import device +import utils +from dm_env import specs +from agent.ddpg import DDPGAgent + + +class DIAYN(nn.Module): + def __init__(self, obs_dim, skill_dim, hidden_dim): + super().__init__() + self.skill_pred_net = nn.Sequential(nn.Linear(obs_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, skill_dim)) + + self.apply(utils.weight_init) + + def forward(self, obs): + skill_pred = self.skill_pred_net(obs) + return skill_pred + + + +class DIAYNasWeightPredictorAgent(DDPGAgent): + def __init__(self, update_skill_every_step, skill_dim, diayn_scale, + update_encoder, **kwargs): + self.skill_dim = skill_dim + self.update_skill_every_step = update_skill_every_step + self.diayn_scale = diayn_scale + self.update_encoder = update_encoder + # increase obs shape to include skill dim + kwargs["meta_dim"] = self.skill_dim + + # create actor and critic + super().__init__(**kwargs) + # self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim, + # kwargs['hidden_dim']).to(kwargs['device']) + # self.diayn.train() + + self.weight_param = nn.Parameter(torch.rand(self.skill_dim, device=kwargs['device'])) + self.actor_opt = torch.optim.Adam(list(self.actor.parameters()) + [self.weight_param], lr=self.lr) + + + def update(self, replay_iter, step): + metrics = dict() + + if step % self.update_every_steps != 0: + return metrics + + batch = next(replay_iter) + + obs, action, extr_reward, discount, next_obs = utils.to_torch( + batch, self.device) + + # augment and encode : state 일 땐 aug_and_encode 의미없음 + obs = self.aug_and_encode(obs) + next_obs = self.aug_and_encode(next_obs) + + reward = extr_reward + + if self.use_tb or self.use_wandb: + metrics['extr_reward'] = extr_reward.mean().item() + metrics['batch_reward'] = reward.mean().item() + + if not self.update_encoder: + obs = obs.detach() + next_obs = next_obs.detach() + + # extend observations with skill + # obs = torch.cat([obs, skill], dim=1) + # next_obs = torch.cat([next_obs, skill], dim=1) + + obs = self.mix_skill_obs(obs) + next_obs = self.mix_skill_obs(next_obs) + + # update critic + metrics.update( + self.update_critic(obs.detach(), action, reward, discount, + next_obs.detach(), step)) + + # update actor + metrics.update(self.update_actor(obs.detach(), step)) + + # update critic target + utils.soft_update_params(self.critic, self.critic_target, + self.critic_target_tau) + + return metrics + + def mix_skill_obs(self, obs): + """ + (B, obs_dim) => (B, skill_dim + obs_dim) + """ + obs = torch.as_tensor(obs, device=self.device).unsqueeze(1) # (B, 1, obs_dim) + obs = obs.repeat(1, self.skill_dim, 1) # (B, skill_dim, obs_dim) + + skill_list = torch.eye(self.skill_dim, device=self.device).unsqueeze(0) # (1, skill_dim, skill_dim) + skill_list = skill_list.repeat(obs.shape[0], 1, 1) # (B, skill_dim, skill_dim) + state_with_skill = torch.cat([obs, skill_list], dim=-1) # (B, skill_dim, skill_dim + obs_dim) + + return state_with_skill.mean(dim=1) # (B, skill_dim + obs_dim) + # skill_weight = F.softmax(self.diayn(obs), dim=-1).unsqueeze(-1) # (B, skill_dim, 1) + # skill_weight = F.softmax(self.weight_param, dim=0).unsqueeze(0).repeat(obs.shape[0],1).unsqueeze(-1) # (B, skill_dim, 1) + + # processed = state_with_skill * skill_weight # (B, skill_dim, skill_dim + obs_dim) + + # return processed.sum(dim=1) # (B, skill_dim + obs_dim) + + def act(self, obs, meta, step, eval_mode): + """ + meta from passed parameter is useless + """ + #assert obs.shape[-1] == self.obs_shape[-1] + obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) # (1, obs_dim) + + inpt = self.mix_skill_obs(obs) + h = self.encoder(inpt) + stddev = utils.schedule(self.stddev_schedule, step) + dist = self.actor(h, stddev) + if eval_mode: + action = dist.mean + else: + action = dist.sample(clip=None) + if step < self.num_expl_steps: + action.uniform_(-1.0, 1.0) + + return action.cpu().numpy()[0] + + + def init_from(self, other): + # copy parameters over + utils.hard_update_params(other.encoder, self.encoder) + utils.hard_update_params(other.actor, self.actor) + # utils.hard_update_params(other.diayn, self.diayn) + if self.init_critic: + utils.hard_update_params(other.critic.trunk, self.critic.trunk) diff --git a/agent/diayn_same_weight.yaml b/agent/diayn_same_weight.yaml new file mode 100644 index 0000000..ff52e22 --- /dev/null +++ b/agent/diayn_same_weight.yaml @@ -0,0 +1,25 @@ +# @package agent +_target_: agent.diayn_same_weight.DIAYNasWeightPredictorAgent +name: same_weight +reward_free: ${reward_free} +obs_type: ??? # to be specified later +obs_shape: ??? # to be specified later +action_shape: ??? # to be specified later +device: ${device} +lr: 1e-4 +critic_target_tau: 0.01 +update_every_steps: 2 +use_tb: ${use_tb} +use_wandb: ${use_wandb} +num_expl_steps: ??? # to be specified later +hidden_dim: 1024 +feature_dim: 50 +stddev_schedule: 0.2 +stddev_clip: 0.3 +skill_dim: 16 +diayn_scale: 1.0 +update_skill_every_step: 50 +nstep: 3 +batch_size: 1024 +init_critic: true +update_encoder: ${update_encoder} From 5a9d8551598f8a85e6d0aabf90f09b891f1dbc4c Mon Sep 17 00:00:00 2001 From: jsrimr Date: Mon, 23 May 2022 07:27:18 +0000 Subject: [PATCH 09/10] =?UTF-8?q?obs.detach()=20=EB=95=8C=EB=AC=B8?= =?UTF-8?q?=EC=97=90=20=EC=97=85=EB=8D=B0=EC=9D=B4=ED=8A=B8=20=EC=95=88?= =?UTF-8?q?=EB=90=98=EA=B3=A0=20=EC=9E=88=EC=97=88=EC=9D=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent/diayn_simple_weight.py | 3 ++- agent/diayn_simple_weight.yaml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/agent/diayn_simple_weight.py b/agent/diayn_simple_weight.py index 534c177..3d3c906 100644 --- a/agent/diayn_simple_weight.py +++ b/agent/diayn_simple_weight.py @@ -87,7 +87,8 @@ def update(self, replay_iter, step): next_obs.detach(), step)) # update actor - metrics.update(self.update_actor(obs.detach(), step)) + # metrics.update(self.update_actor(obs.detach(), step)) + metrics.update(self.update_actor(obs, step)) # update critic target utils.soft_update_params(self.critic, self.critic_target, diff --git a/agent/diayn_simple_weight.yaml b/agent/diayn_simple_weight.yaml index 1794303..a552f8e 100644 --- a/agent/diayn_simple_weight.yaml +++ b/agent/diayn_simple_weight.yaml @@ -1,6 +1,6 @@ # @package agent _target_: agent.diayn_simple_weight.DIAYNasWeightPredictorAgent -name: diayn_as_importance_predictor +name: diayn_simple_weight reward_free: ${reward_free} obs_type: ??? # to be specified later obs_shape: ??? # to be specified later From 4c06cf3f72bc05223d9db57250d49d998daf7da1 Mon Sep 17 00:00:00 2001 From: jsrimr Date: Mon, 23 May 2022 10:05:03 +0000 Subject: [PATCH 10/10] =?UTF-8?q?diayn=20as=20importance=20predictor=20:?= =?UTF-8?q?=20scratch=20=EC=99=80=20=EB=B9=84=EA=B5=90=ED=96=88=EC=9D=84?= =?UTF-8?q?=20=EB=95=8C=20=EC=9E=98=ED=95=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent/diayn_as_importance_predictor.py | 141 +++++++++++++++++++++++ agent/diayn_as_importance_predictor.yaml | 26 +++++ 2 files changed, 167 insertions(+) create mode 100644 agent/diayn_as_importance_predictor.py create mode 100644 agent/diayn_as_importance_predictor.yaml diff --git a/agent/diayn_as_importance_predictor.py b/agent/diayn_as_importance_predictor.py new file mode 100644 index 0000000..a7f0131 --- /dev/null +++ b/agent/diayn_as_importance_predictor.py @@ -0,0 +1,141 @@ +import math +from collections import OrderedDict + +import hydra +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import utils + +from agent.ddpg import DDPGAgent + + +class DIAYN(nn.Module): + def __init__(self, obs_dim, skill_dim, hidden_dim): + super().__init__() + self.skill_pred_net = nn.Sequential(nn.Linear(obs_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, skill_dim)) + + self.apply(utils.weight_init) + + def forward(self, obs): + skill_pred = self.skill_pred_net(obs) + return skill_pred + + + +class DIAYNasWeightPredictorAgent(DDPGAgent): + def __init__(self, update_skill_every_step, skill_dim, diayn_scale, + update_encoder, **kwargs): + self.skill_dim = skill_dim + self.update_skill_every_step = update_skill_every_step + self.diayn_scale = diayn_scale + self.update_encoder = update_encoder + # increase obs shape to include skill dim + kwargs["meta_dim"] = self.skill_dim + self.init_diayn = kwargs.pop('init_diayn') + + # create actor and critic + super().__init__(**kwargs) + self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim, + kwargs['hidden_dim']).to(kwargs['device']) + self.diayn.train() + + + self.actor_opt = torch.optim.Adam(list(self.actor.parameters()) + list(self.diayn.parameters()), lr=self.lr) + + + def update(self, replay_iter, step): + metrics = dict() + + if step % self.update_every_steps != 0: + return metrics + + batch = next(replay_iter) + + obs, action, extr_reward, discount, next_obs = utils.to_torch( + batch, self.device) + + # augment and encode : state 일 땐 aug_and_encode 의미없음 + obs = self.aug_and_encode(obs) + next_obs = self.aug_and_encode(next_obs) + + reward = extr_reward + + if self.use_tb or self.use_wandb: + metrics['extr_reward'] = extr_reward.mean().item() + metrics['batch_reward'] = reward.mean().item() + + if not self.update_encoder: + obs = obs.detach() + next_obs = next_obs.detach() + + obs = self.mix_skill_obs(obs) + next_obs = self.mix_skill_obs(next_obs) + + # update critic + metrics.update( + self.update_critic(obs.detach(), action, reward, discount, + next_obs.detach(), step)) + + # update actor + metrics.update(self.update_actor(obs, step)) + + # update critic target + utils.soft_update_params(self.critic, self.critic_target, + self.critic_target_tau) + + return metrics + + def mix_skill_obs(self, obs): + """ + (B, obs_dim) => (B, skill_dim + obs_dim) + """ + obs = torch.as_tensor(obs, device=self.device).unsqueeze(1) # (B, 1, obs_dim) + skill_weight = F.softmax(self.diayn(obs), dim=-1) # (B, 1, skill_dim) + + obs = obs.repeat(1, self.skill_dim, 1) # (B, skill_dim, obs_dim) + + skill_list = torch.eye(self.skill_dim, device=self.device).unsqueeze(0) # (1, skill_dim, skill_dim) + skill_list = skill_list.repeat(obs.shape[0], 1, 1) # (B, skill_dim, skill_dim) + state_with_skill = torch.cat([obs, skill_list], dim=-1) # (B, skill_dim, skill_dim + obs_dim) + + + processed = state_with_skill * skill_weight.permute(0,2,1) # (B, skill_dim, skill_dim + obs_dim) + + return processed.sum(dim=1) # (B, skill_dim + obs_dim) + + def act(self, obs, meta, step, eval_mode): + """ + meta from passed parameter is useless + """ + #assert obs.shape[-1] == self.obs_shape[-1] + obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) # (1, obs_dim) + + inpt = self.mix_skill_obs(obs) + h = self.encoder(inpt) + stddev = utils.schedule(self.stddev_schedule, step) + dist = self.actor(h, stddev) + if eval_mode: + action = dist.mean + else: + action = dist.sample(clip=None) + if step < self.num_expl_steps: + action.uniform_(-1.0, 1.0) + + return action.cpu().numpy()[0] + + + def init_from(self, other): + # copy parameters over + utils.hard_update_params(other.encoder, self.encoder) + utils.hard_update_params(other.actor, self.actor) + if self.init_diayn: + print("load DIAYN WEIGHT!!") + utils.hard_update_params(other.diayn, self.diayn) + if self.init_critic: + utils.hard_update_params(other.critic.trunk, self.critic.trunk) diff --git a/agent/diayn_as_importance_predictor.yaml b/agent/diayn_as_importance_predictor.yaml new file mode 100644 index 0000000..7a01eac --- /dev/null +++ b/agent/diayn_as_importance_predictor.yaml @@ -0,0 +1,26 @@ +# @package agent +_target_: agent.diayn_as_importance_predictor.DIAYNasWeightPredictorAgent +name: diayn_as_importance_predictor +reward_free: ${reward_free} +obs_type: ??? # to be specified later +obs_shape: ??? # to be specified later +action_shape: ??? # to be specified later +device: ${device} +lr: 1e-4 +critic_target_tau: 0.01 +update_every_steps: 2 +use_tb: ${use_tb} +use_wandb: ${use_wandb} +num_expl_steps: ??? # to be specified later +hidden_dim: 1024 +feature_dim: 50 +stddev_schedule: 0.2 +stddev_clip: 0.3 +skill_dim: 16 +diayn_scale: 1.0 +update_skill_every_step: 50 +nstep: 3 +batch_size: 1024 +init_diayn: ${init_diayn} +init_critic: true +update_encoder: ${update_encoder}