diff --git a/agent/diayn_as_importance_predictor.py b/agent/diayn_as_importance_predictor.py
new file mode 100644
index 0000000..a7f0131
--- /dev/null
+++ b/agent/diayn_as_importance_predictor.py
@@ -0,0 +1,141 @@
+import math
+from collections import OrderedDict
+
+import hydra
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import utils
+
+from agent.ddpg import DDPGAgent
+
+
+class DIAYN(nn.Module):
+    def __init__(self, obs_dim, skill_dim, hidden_dim):
+        super().__init__()
+        self.skill_pred_net = nn.Sequential(nn.Linear(obs_dim, hidden_dim),
+                                            nn.ReLU(),
+                                            nn.Linear(hidden_dim, hidden_dim),
+                                            nn.ReLU(),
+                                            nn.Linear(hidden_dim, skill_dim))
+
+        self.apply(utils.weight_init)
+
+    def forward(self, obs):
+        skill_pred = self.skill_pred_net(obs)
+        return skill_pred
+
+
+class DIAYNasWeightPredictorAgent(DDPGAgent):
+    def __init__(self, update_skill_every_step, skill_dim, diayn_scale,
+                 update_encoder, **kwargs):
+        self.skill_dim = skill_dim
+        self.update_skill_every_step = update_skill_every_step
+        self.diayn_scale = diayn_scale
+        self.update_encoder = update_encoder
+        # increase obs shape to include skill dim
+        kwargs["meta_dim"] = self.skill_dim
+        self.init_diayn = kwargs.pop('init_diayn')
+
+        # create actor and critic
+        super().__init__(**kwargs)
+        self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim,
+                           kwargs['hidden_dim']).to(kwargs['device'])
+        self.diayn.train()
+
+        # optimize the DIAYN weight predictor jointly with the actor
+        self.actor_opt = torch.optim.Adam(
+            list(self.actor.parameters()) + list(self.diayn.parameters()),
+            lr=self.lr)
+
+    def update(self, replay_iter, step):
+        metrics = dict()
+
+        if step % self.update_every_steps != 0:
+            return metrics
+
+        batch = next(replay_iter)
+
+        obs, action, extr_reward, discount, next_obs = utils.to_torch(
+            batch, self.device)
+
+        # augment and encode (a no-op for state-based observations)
+        obs = self.aug_and_encode(obs)
+        next_obs = self.aug_and_encode(next_obs)
+
+        reward = extr_reward
+
+        if self.use_tb or self.use_wandb:
+            metrics['extr_reward'] = extr_reward.mean().item()
+            metrics['batch_reward'] = reward.mean().item()
+
+        if not self.update_encoder:
+            obs = obs.detach()
+            next_obs = next_obs.detach()
+
+        obs = self.mix_skill_obs(obs)
+        next_obs = self.mix_skill_obs(next_obs)
+
+        # update critic
+        metrics.update(
+            self.update_critic(obs.detach(), action, reward, discount,
+                               next_obs.detach(), step))
+
+        # update actor (gradients flow into the DIAYN weight predictor here)
+        metrics.update(self.update_actor(obs, step))
+
+        # update critic target
+        utils.soft_update_params(self.critic, self.critic_target,
+                                 self.critic_target_tau)
+
+        return metrics
+
+    def mix_skill_obs(self, obs):
+        """
+        (B, obs_dim) => (B, skill_dim + obs_dim)
+        """
+        obs = torch.as_tensor(obs, device=self.device).unsqueeze(1)  # (B, 1, obs_dim)
+        skill_weight = F.softmax(self.diayn(obs), dim=-1)  # (B, 1, skill_dim)
+
+        obs = obs.repeat(1, self.skill_dim, 1)  # (B, skill_dim, obs_dim)
+
+        skill_list = torch.eye(self.skill_dim, device=self.device).unsqueeze(0)  # (1, skill_dim, skill_dim)
+        skill_list = skill_list.repeat(obs.shape[0], 1, 1)  # (B, skill_dim, skill_dim)
+        state_with_skill = torch.cat([obs, skill_list], dim=-1)  # (B, skill_dim, skill_dim + obs_dim)
+
+        # weight each skill-augmented copy by its predicted importance
+        processed = state_with_skill * skill_weight.permute(0, 2, 1)  # (B, skill_dim, skill_dim + obs_dim)
+
+        return processed.sum(dim=1)  # (B, skill_dim + obs_dim)
+
+    def act(self, obs, meta, step, eval_mode):
+        """
+        The `meta` argument is unused; the skill mixture comes from mix_skill_obs.
+        """
+        # assert obs.shape[-1] == self.obs_shape[-1]
+        obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)  # (1, obs_dim)
+
+        inpt = self.mix_skill_obs(obs)
+        h = self.encoder(inpt)
+        stddev = utils.schedule(self.stddev_schedule, step)
+        dist = self.actor(h, stddev)
+        if eval_mode:
+            action = dist.mean
+        else:
+            action = dist.sample(clip=None)
+            if step < self.num_expl_steps:
+                action.uniform_(-1.0, 1.0)
+        return action.cpu().numpy()[0]
+
+    def init_from(self, other):
+        # copy parameters over
+        utils.hard_update_params(other.encoder, self.encoder)
+        utils.hard_update_params(other.actor, self.actor)
+        if self.init_diayn:
+            print("loading pretrained DIAYN weights")
+            utils.hard_update_params(other.diayn, self.diayn)
+        if self.init_critic:
+            utils.hard_update_params(other.critic.trunk, self.critic.trunk)
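The weighted mixture in `mix_skill_obs` above has a simple closed form: because the softmax weights sum to one, the weighted sum of one-hot-augmented copies collapses to concatenating the observation with the predicted skill distribution. A small sanity-check sketch (not part of the diff; shapes are placeholders):

```python
# Verifies: sum_k p_k * [obs; e_k] == [obs; p] whenever sum_k p_k == 1.
import torch
import torch.nn.functional as F

B, D, K = 4, 24, 16                       # batch, obs_dim, skill_dim (placeholders)
obs = torch.randn(B, D)
p = F.softmax(torch.randn(B, K), dim=-1)  # stand-in for softmax(self.diayn(obs))

# explicit mixture, mirroring mix_skill_obs
eye = torch.eye(K).expand(B, K, K)                                # (B, K, K)
swk = torch.cat([obs.unsqueeze(1).expand(B, K, D), eye], dim=-1)  # (B, K, D + K)
mixed = (swk * p.unsqueeze(-1)).sum(dim=1)                        # (B, D + K)

# closed form: observation concatenated with the skill distribution
assert torch.allclose(mixed, torch.cat([obs, p], dim=-1), atol=1e-6)
```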
diff --git a/agent/diayn_as_importance_predictor.yaml b/agent/diayn_as_importance_predictor.yaml
new file mode 100644
index 0000000..7a01eac
--- /dev/null
+++ b/agent/diayn_as_importance_predictor.yaml
@@ -0,0 +1,26 @@
+# @package agent
+_target_: agent.diayn_as_importance_predictor.DIAYNasWeightPredictorAgent
+name: diayn_as_importance_predictor
+reward_free: ${reward_free}
+obs_type: ??? # to be specified later
+obs_shape: ??? # to be specified later
+action_shape: ??? # to be specified later
+device: ${device}
+lr: 1e-4
+critic_target_tau: 0.01
+update_every_steps: 2
+use_tb: ${use_tb}
+use_wandb: ${use_wandb}
+num_expl_steps: ??? # to be specified later
+hidden_dim: 1024
+feature_dim: 50
+stddev_schedule: 0.2
+stddev_clip: 0.3
+skill_dim: 16
+diayn_scale: 1.0
+update_skill_every_step: 50
+nstep: 3
+batch_size: 1024
+init_diayn: ${init_diayn}
+init_critic: true
+update_encoder: ${update_encoder}
diff --git a/agent/diayn_same_weight.py b/agent/diayn_same_weight.py
new file mode 100644
index 0000000..5de3b51
--- /dev/null
+++ b/agent/diayn_same_weight.py
@@ -0,0 +1,143 @@
+import math
+from collections import OrderedDict
+
+import hydra
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import utils
+from dm_env import specs
+from agent.ddpg import DDPGAgent
+
+
+class DIAYN(nn.Module):
+    def __init__(self, obs_dim, skill_dim, hidden_dim):
+        super().__init__()
+        self.skill_pred_net = nn.Sequential(nn.Linear(obs_dim, hidden_dim),
+                                            nn.ReLU(),
+                                            nn.Linear(hidden_dim, hidden_dim),
+                                            nn.ReLU(),
+                                            nn.Linear(hidden_dim, skill_dim))
+
+        self.apply(utils.weight_init)
+
+    def forward(self, obs):
+        skill_pred = self.skill_pred_net(obs)
+        return skill_pred
+
+
+class DIAYNasWeightPredictorAgent(DDPGAgent):
+    def __init__(self, update_skill_every_step, skill_dim, diayn_scale,
+                 update_encoder, **kwargs):
+        self.skill_dim = skill_dim
+        self.update_skill_every_step = update_skill_every_step
+        self.diayn_scale = diayn_scale
+        self.update_encoder = update_encoder
+        # increase obs shape to include skill dim
+        kwargs["meta_dim"] = self.skill_dim
+
+        # create actor and critic
+        super().__init__(**kwargs)
+        # self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim,
+        #                    kwargs['hidden_dim']).to(kwargs['device'])
+        # self.diayn.train()
+
+        # learnable skill weights (unused in this uniform-mixing variant;
+        # kept so the optimizer setup matches the weighted variants)
+        self.weight_param = nn.Parameter(torch.rand(self.skill_dim, device=kwargs['device']))
+        self.actor_opt = torch.optim.Adam(
+            list(self.actor.parameters()) + [self.weight_param], lr=self.lr)
+
+    def update(self, replay_iter, step):
+        metrics = dict()
+
+        if step % self.update_every_steps != 0:
+            return metrics
+
+        batch = next(replay_iter)
+
+        obs, action, extr_reward, discount, next_obs = utils.to_torch(
+            batch, self.device)
+
+        # augment and encode (a no-op for state-based observations)
+        obs = self.aug_and_encode(obs)
+        next_obs = self.aug_and_encode(next_obs)
+
+        reward = extr_reward
+
+        if self.use_tb or self.use_wandb:
+            metrics['extr_reward'] = extr_reward.mean().item()
+            metrics['batch_reward'] = reward.mean().item()
+
+        if not self.update_encoder:
+            obs = obs.detach()
+            next_obs = next_obs.detach()
+
+        # extend observations with skill
+        # obs = torch.cat([obs, skill], dim=1)
+        # next_obs = torch.cat([next_obs, skill], dim=1)
+
+        obs = self.mix_skill_obs(obs)
+        next_obs = self.mix_skill_obs(next_obs)
+
+        # update critic
+        metrics.update(
+            self.update_critic(obs.detach(), action, reward, discount,
+                               next_obs.detach(), step))
+
+        # update actor
+        metrics.update(self.update_actor(obs.detach(), step))
+
+        # update critic target
+        utils.soft_update_params(self.critic, self.critic_target,
+                                 self.critic_target_tau)
+
+        return metrics
+
+    def mix_skill_obs(self, obs):
+        """
+        (B, obs_dim) => (B, skill_dim + obs_dim)
+        """
+        obs = torch.as_tensor(obs, device=self.device).unsqueeze(1)  # (B, 1, obs_dim)
+        obs = obs.repeat(1, self.skill_dim, 1)  # (B, skill_dim, obs_dim)
+
+        skill_list = torch.eye(self.skill_dim, device=self.device).unsqueeze(0)  # (1, skill_dim, skill_dim)
+        skill_list = skill_list.repeat(obs.shape[0], 1, 1)  # (B, skill_dim, skill_dim)
+        state_with_skill = torch.cat([obs, skill_list], dim=-1)  # (B, skill_dim, skill_dim + obs_dim)
+
+        # uniform average over all skill-augmented copies
+        return state_with_skill.mean(dim=1)  # (B, skill_dim + obs_dim)
+        # weighted alternatives, kept for reference:
+        # skill_weight = F.softmax(self.diayn(obs), dim=-1).unsqueeze(-1)  # (B, skill_dim, 1)
+        # skill_weight = F.softmax(self.weight_param, dim=0).unsqueeze(0).repeat(obs.shape[0], 1).unsqueeze(-1)  # (B, skill_dim, 1)
+        # processed = state_with_skill * skill_weight  # (B, skill_dim, skill_dim + obs_dim)
+        # return processed.sum(dim=1)  # (B, skill_dim + obs_dim)
+
+    def act(self, obs, meta, step, eval_mode):
+        """
+        The `meta` argument is unused; the skill mixture comes from mix_skill_obs.
+        """
+        # assert obs.shape[-1] == self.obs_shape[-1]
+        obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)  # (1, obs_dim)
+
+        inpt = self.mix_skill_obs(obs)
+        h = self.encoder(inpt)
+        stddev = utils.schedule(self.stddev_schedule, step)
+        dist = self.actor(h, stddev)
+        if eval_mode:
+            action = dist.mean
+        else:
+            action = dist.sample(clip=None)
+            if step < self.num_expl_steps:
+                action.uniform_(-1.0, 1.0)
+        return action.cpu().numpy()[0]
+
+    def init_from(self, other):
+        # copy parameters over
+        utils.hard_update_params(other.encoder, self.encoder)
+        utils.hard_update_params(other.actor, self.actor)
+        # utils.hard_update_params(other.diayn, self.diayn)
+        if self.init_critic:
+            utils.hard_update_params(other.critic.trunk, self.critic.trunk)
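The uniform mean in this variant's `mix_skill_obs` simplifies the same way: averaging the one-hot copies just appends a constant `1/skill_dim` vector, so the actor always sees a fixed "average skill". A sketch (not part of the diff):

```python
# Verifies: mean_k [obs; e_k] == [obs; (1/K) * ones].
import torch

B, D, K = 4, 24, 16  # placeholders
obs = torch.randn(B, D)
eye = torch.eye(K).expand(B, K, K)
swk = torch.cat([obs.unsqueeze(1).expand(B, K, D), eye], dim=-1)

mixed = swk.mean(dim=1)
closed = torch.cat([obs, torch.full((B, K), 1.0 / K)], dim=-1)
assert torch.allclose(mixed, closed, atol=1e-6)
```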
diff --git a/agent/diayn_same_weight.yaml b/agent/diayn_same_weight.yaml
new file mode 100644
index 0000000..ff52e22
--- /dev/null
+++ b/agent/diayn_same_weight.yaml
@@ -0,0 +1,25 @@
+# @package agent
+_target_: agent.diayn_same_weight.DIAYNasWeightPredictorAgent
+name: same_weight
+reward_free: ${reward_free}
+obs_type: ??? # to be specified later
+obs_shape: ??? # to be specified later
+action_shape: ??? # to be specified later
+device: ${device}
+lr: 1e-4
+critic_target_tau: 0.01
+update_every_steps: 2
+use_tb: ${use_tb}
+use_wandb: ${use_wandb}
+num_expl_steps: ??? # to be specified later
+hidden_dim: 1024
+feature_dim: 50
+stddev_schedule: 0.2
+stddev_clip: 0.3
+skill_dim: 16
+diayn_scale: 1.0
+update_skill_every_step: 50
+nstep: 3
+batch_size: 1024
+init_critic: true
+update_encoder: ${update_encoder}
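These agent YAMLs all follow the same Hydra pattern: `_target_` names the agent class, `${...}` fields are resolved against the top-level config, and the `???` fields are filled in by the training script before instantiation. A sketch of that wiring (the function name and arguments are illustrative, not from this diff):

```python
import hydra

def make_agent(obs_type, obs_spec, action_spec, num_expl_steps, cfg):
    # fill in the ??? placeholders declared in the agent YAML
    cfg.obs_type = obs_type
    cfg.obs_shape = obs_spec.shape
    cfg.action_shape = action_spec.shape
    cfg.num_expl_steps = num_expl_steps
    # instantiate the class named by _target_
    return hydra.utils.instantiate(cfg)
```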
diff --git a/agent/diayn_simple_weight.py b/agent/diayn_simple_weight.py
new file mode 100644
index 0000000..3d3c906
--- /dev/null
+++ b/agent/diayn_simple_weight.py
@@ -0,0 +1,143 @@
+import math
+from collections import OrderedDict
+
+import hydra
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import utils
+from dm_env import specs
+from agent.ddpg import DDPGAgent
+
+
+class DIAYN(nn.Module):
+    def __init__(self, obs_dim, skill_dim, hidden_dim):
+        super().__init__()
+        self.skill_pred_net = nn.Sequential(nn.Linear(obs_dim, hidden_dim),
+                                            nn.ReLU(),
+                                            nn.Linear(hidden_dim, hidden_dim),
+                                            nn.ReLU(),
+                                            nn.Linear(hidden_dim, skill_dim))
+
+        self.apply(utils.weight_init)
+
+    def forward(self, obs):
+        skill_pred = self.skill_pred_net(obs)
+        return skill_pred
+
+
+class DIAYNasWeightPredictorAgent(DDPGAgent):
+    def __init__(self, update_skill_every_step, skill_dim, diayn_scale,
+                 update_encoder, **kwargs):
+        self.skill_dim = skill_dim
+        self.update_skill_every_step = update_skill_every_step
+        self.diayn_scale = diayn_scale
+        self.update_encoder = update_encoder
+        # increase obs shape to include skill dim
+        kwargs["meta_dim"] = self.skill_dim
+
+        # create actor and critic
+        super().__init__(**kwargs)
+        # self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim,
+        #                    kwargs['hidden_dim']).to(kwargs['device'])
+        # self.diayn.train()
+
+        # a single learnable weight vector shared across all states
+        self.weight_param = nn.Parameter(torch.rand(self.skill_dim, device=kwargs['device']))
+        self.actor_opt = torch.optim.Adam(
+            list(self.actor.parameters()) + [self.weight_param], lr=self.lr)
+
+    def update(self, replay_iter, step):
+        metrics = dict()
+
+        if step % self.update_every_steps != 0:
+            return metrics
+
+        batch = next(replay_iter)
+
+        obs, action, extr_reward, discount, next_obs = utils.to_torch(
+            batch, self.device)
+
+        # augment and encode (a no-op for state-based observations)
+        obs = self.aug_and_encode(obs)
+        next_obs = self.aug_and_encode(next_obs)
+
+        reward = extr_reward
+
+        if self.use_tb or self.use_wandb:
+            metrics['extr_reward'] = extr_reward.mean().item()
+            metrics['batch_reward'] = reward.mean().item()
+
+        if not self.update_encoder:
+            obs = obs.detach()
+            next_obs = next_obs.detach()
+
+        # extend observations with skill
+        # obs = torch.cat([obs, skill], dim=1)
+        # next_obs = torch.cat([next_obs, skill], dim=1)
+
+        obs = self.mix_skill_obs(obs)
+        next_obs = self.mix_skill_obs(next_obs)
+
+        # update critic
+        metrics.update(
+            self.update_critic(obs.detach(), action, reward, discount,
+                               next_obs.detach(), step))
+
+        # update actor (not detached: gradients flow into weight_param)
+        # metrics.update(self.update_actor(obs.detach(), step))
+        metrics.update(self.update_actor(obs, step))
+
+        # update critic target
+        utils.soft_update_params(self.critic, self.critic_target,
+                                 self.critic_target_tau)
+
+        return metrics
+
+    def mix_skill_obs(self, obs):
+        """
+        (B, obs_dim) => (B, skill_dim + obs_dim)
+        """
+        obs = torch.as_tensor(obs, device=self.device).unsqueeze(1)  # (B, 1, obs_dim)
+        obs = obs.repeat(1, self.skill_dim, 1)  # (B, skill_dim, obs_dim)
+
+        skill_list = torch.eye(self.skill_dim, device=self.device).unsqueeze(0)  # (1, skill_dim, skill_dim)
+        skill_list = skill_list.repeat(obs.shape[0], 1, 1)  # (B, skill_dim, skill_dim)
+        state_with_skill = torch.cat([obs, skill_list], dim=-1)  # (B, skill_dim, skill_dim + obs_dim)
+
+        # weight every skill-augmented copy by the learned, state-independent weights
+        # skill_weight = F.softmax(self.diayn(obs), dim=-1).unsqueeze(-1)  # (B, skill_dim, 1)
+        skill_weight = F.softmax(self.weight_param, dim=0).unsqueeze(0).repeat(obs.shape[0], 1).unsqueeze(-1)  # (B, skill_dim, 1)
+
+        processed = state_with_skill * skill_weight  # (B, skill_dim, skill_dim + obs_dim)
+
+        return processed.sum(dim=1)  # (B, skill_dim + obs_dim)
+
+    def act(self, obs, meta, step, eval_mode):
+        """
+        The `meta` argument is unused; the skill mixture comes from mix_skill_obs.
+        """
+        # assert obs.shape[-1] == self.obs_shape[-1]
+        obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)  # (1, obs_dim)
+
+        inpt = self.mix_skill_obs(obs)
+        h = self.encoder(inpt)
+        stddev = utils.schedule(self.stddev_schedule, step)
+        dist = self.actor(h, stddev)
+        if eval_mode:
+            action = dist.mean
+        else:
+            action = dist.sample(clip=None)
+            if step < self.num_expl_steps:
+                action.uniform_(-1.0, 1.0)
+        return action.cpu().numpy()[0]
+
+    def init_from(self, other):
+        # copy parameters over
+        utils.hard_update_params(other.encoder, self.encoder)
+        utils.hard_update_params(other.actor, self.actor)
+        # utils.hard_update_params(other.diayn, self.diayn)
+        if self.init_critic:
+            utils.hard_update_params(other.critic.trunk, self.critic.trunk)
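With a single learned weight vector shared across all states, the mixture again collapses: the actor input is just the observation concatenated with `softmax(weight_param)`, a state-independent learned skill distribution. A sketch (not part of the diff):

```python
# Verifies: sum_k softmax(w)_k * [obs; e_k] == [obs; softmax(w)].
import torch
import torch.nn.functional as F

B, D, K = 4, 24, 16  # placeholders
obs = torch.randn(B, D)
p = F.softmax(torch.rand(K), dim=0)  # stand-in for softmax(self.weight_param)

eye = torch.eye(K).expand(B, K, K)
swk = torch.cat([obs.unsqueeze(1).expand(B, K, D), eye], dim=-1)
mixed = (swk * p.view(1, K, 1)).sum(dim=1)
closed = torch.cat([obs, p.expand(B, K)], dim=-1)
assert torch.allclose(mixed, closed, atol=1e-6)
```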
diff --git a/agent/diayn_simple_weight.yaml b/agent/diayn_simple_weight.yaml
new file mode 100644
index 0000000..a552f8e
--- /dev/null
+++ b/agent/diayn_simple_weight.yaml
@@ -0,0 +1,25 @@
+# @package agent
+_target_: agent.diayn_simple_weight.DIAYNasWeightPredictorAgent
+name: diayn_simple_weight
+reward_free: ${reward_free}
+obs_type: ??? # to be specified later
+obs_shape: ??? # to be specified later
+action_shape: ??? # to be specified later
+device: ${device}
+lr: 1e-4
+critic_target_tau: 0.01
+update_every_steps: 2
+use_tb: ${use_tb}
+use_wandb: ${use_wandb}
+num_expl_steps: ??? # to be specified later
+hidden_dim: 1024
+feature_dim: 50
+stddev_schedule: 0.2
+stddev_clip: 0.3
+skill_dim: 16
+diayn_scale: 1.0
+update_skill_every_step: 50
+nstep: 3
+batch_size: 1024
+init_critic: true
+update_encoder: ${update_encoder}
diff --git a/finetune.py b/finetune.py
index 373f816..3e0157a 100644
--- a/finetune.py
+++ b/finetune.py
@@ -138,6 +138,15 @@ def eval(self):
             log('episode', self.global_episode)
             log('step', self.global_step)
 
+    def save_snapshot(self):
+        snapshot_dir = self.work_dir / "snapshot"
+        snapshot_dir.mkdir(exist_ok=True, parents=True)
+        snapshot = snapshot_dir / f'snapshot_{self.global_frame}.pt'
+        keys_to_save = ['agent', '_global_step', '_global_episode']
+        payload = {k: self.__dict__[k] for k in keys_to_save}
+        with snapshot.open('wb') as f:
+            torch.save(payload, f)
+
     def train(self):
         # predicates
         train_until_step = utils.Until(self.cfg.num_train_frames,
@@ -157,6 +166,7 @@ def train(self):
             if time_step.last():
                 self._global_episode += 1
                 self.train_video_recorder.save(f'{self.global_frame}.mp4')
+
                 # wait until all the metrics schema is populated
                 if metrics is not None:
                     # log stats
@@ -177,6 +187,7 @@ def train(self):
                 meta = self.agent.init_meta()
                 self.replay_storage.add(time_step, meta)
                 self.train_video_recorder.init(time_step.observation)
+
                 episode_step = 0
                 episode_reward = 0
 
@@ -186,10 +197,12 @@ def train(self):
                 self.logger.log('eval_total_time', self.timer.total_time(),
                                 self.global_frame)
                 self.eval()
+                # save a snapshot at every evaluation
+                self.save_snapshot()
 
             meta = self.agent.update_meta(meta, self.global_step, time_step)
 
-            if hasattr(self.agent, "regress_meta"):
+            if hasattr(self.agent, "regress_meta"):  # used by the APS algorithm
                 repeat = self.cfg.action_repeat
                 every = self.agent.update_task_every_step // repeat
                 init_step = self.agent.num_init_steps
@@ -205,6 +218,9 @@ def train(self):
                                         self.global_step,
                                         eval_mode=False)
 
+            if hasattr(self.agent, "get_current_meta"):  # used by the skill_controller algorithm
+                meta = self.agent.get_current_meta()
+
             # try to update the agent
             if not seed_until_step(self.global_step):
                 metrics = self.agent.update(self.replay_iter, self.global_step)
@@ -238,11 +254,12 @@ def try_load(seed):
         if payload is not None:
             return payload
         # otherwise try random seed
-        while True:
-            seed = np.random.randint(1, 11)
-            payload = try_load(seed)
-            if payload is not None:
-                return payload
+        # while True:
+        #     seed = np.random.randint(1, 11)
+        #     payload = try_load(seed)
+        #     if payload is not None:
+        #         return payload
+        print(f"failed to load from {snapshot_dir}")
         return None
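For reference, a minimal loading counterpart to the new `save_snapshot` (a sketch, not from this diff; the restore-side wiring is an assumption):

```python
import torch
from pathlib import Path

def load_snapshot(work_dir: Path, global_frame: int):
    # mirrors save_snapshot: the payload dict holds
    # 'agent', '_global_step', and '_global_episode'
    snapshot = work_dir / "snapshot" / f"snapshot_{global_frame}.pt"
    with snapshot.open('rb') as f:
        return torch.load(f)
```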