13 changes: 8 additions & 5 deletions phasic_policy_gradient/envs.py
@@ -2,19 +2,22 @@
 import gym3
 from procgen import ProcgenGym3Env
 
-def get_procgen_venv(*, env_id, num_envs, rendering=False, **env_kwargs):
+def get_procgen_venv(*, env_id, num_envs, distribution_mode, start_level, num_levels, rendering=False, **env_kwargs):
     if rendering:
         env_kwargs["render_human"] = True
 
-    env = ProcgenGym3Env(num=num_envs, env_name=env_id, **env_kwargs)
+    env = ProcgenGym3Env(num=num_envs, env_name=env_id, \
+        distribution_mode=distribution_mode, start_level=start_level, \
+        num_levels=num_levels, **env_kwargs)
 
     env = gym3.ExtractDictObWrapper(env, "rgb")
 
     if rendering:
         env = gym3.ViewerWrapper(env, info_key="rgb")
     return env
 
-def get_venv(num_envs, env_name, **env_kwargs):
-    venv = get_procgen_venv(num_envs=num_envs, env_id=env_name, **env_kwargs)
+def get_venv(num_envs, env_name, distribution_mode, start_level, num_levels, **env_kwargs):
+    venv = get_procgen_venv(num_envs=num_envs, env_id=env_name, \
+        distribution_mode=distribution_mode, start_level=start_level, \
+        num_levels=num_levels, **env_kwargs)
 
     return venv
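Usage sketch (not part of the diff): mirroring the call added to train.py later in this PR, a training venv restricted to a fixed level set can be paired with an evaluation venv over the full distribution. The num_levels=0 convention for "unrestricted levels" comes from Procgen itself, and the import path assumes the package is importable as phasic_policy_gradient.

from phasic_policy_gradient.envs import get_venv

# Train on 200 fixed "hard" coinrun levels, evaluate on the full level distribution.
train_venv = get_venv(num_envs=64, env_name="coinrun", distribution_mode="hard",
                      start_level=0, num_levels=200)
eval_venv = get_venv(num_envs=64, env_name="coinrun", distribution_mode="hard",
                     start_level=0, num_levels=0)
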
24 changes: 23 additions & 1 deletion phasic_policy_gradient/log_save_helper.py
@@ -53,6 +53,7 @@ def __init__(
         self.log_callbacks = log_callbacks
         self.log_new_eps = log_new_eps
         self.roller_stats = {}
+        self.eval_roller_stats = {}
 
     def __call__(self):
         self.total_interact_count += self.ic_per_step
@@ -85,6 +86,26 @@ def gather_roller_stats(self, roller):
                 }
             )
 
+    def gather_eval_roller_stats(self, roller):
+        self.eval_roller_stats = {
+            "EpRewMeanTest": self._nanmean([] if roller is None else roller.recent_eprets),
+            "EpLenMeanTest": self._nanmean([] if roller is None else roller.recent_eplens),
+        }
+        if roller is not None and self.log_new_eps:
+            assert roller.has_non_rolling_eps, "roller needs keep_non_rolling"
+            ret_n, ret_mean, ret_std = self._nanmoments(roller.non_rolling_eprets)
+            _len_n, len_mean, len_std = self._nanmoments(roller.non_rolling_eplens)
+            roller.clear_non_rolling_episode_buf()
+            self.eval_roller_stats.update(
+                {
+                    "NewEpNumTest": ret_n,
+                    "NewEpRewMeanTest": ret_mean,
+                    "NewEpRewStdTest": ret_std,
+                    "NewEpLenMeanTest": len_mean,
+                    "NewEpLenStdTest": len_std,
+                }
+            )
+
     def log(self):
         if self.log_callbacks is not None:
             for callback in self.log_callbacks:
@@ -93,7 +114,8 @@ def log(self):
         for k, v in self.roller_stats.items():
             logger.logkv(k, v)
 
-        logger.logkv("Misc/InteractCount", self.total_interact_count)
+        for k, v in self.eval_roller_stats.items():
+            logger.logkv(k, v)
         cur_time = time.time()
         Δtime = cur_time - self.last_time
         Δic = self.total_interact_count - self.last_ic
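The new gather_eval_roller_stats mirrors gather_roller_stats but writes "…Test"-suffixed keys. It leans on the class's nan-aware helpers, whose bodies are outside this diff; behavior along the following lines is assumed (a sketch of the presumed semantics, not the file's actual code):

import numpy as np

def _nanmean(xs):
    # Mean of the buffer, or NaN when no episodes have finished yet.
    return np.nan if len(xs) == 0 else np.nanmean(xs)

def _nanmoments(xs):
    # Count, mean, and std of the finite entries, tolerating an empty buffer.
    arr = np.asarray(list(xs), dtype=np.float64)
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        return 0, np.nan, np.nan
    return int(arr.size), float(arr.mean()), float(arr.std())
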
1 change: 1 addition & 0 deletions phasic_policy_gradient/logger.py
@@ -472,6 +472,7 @@ def configure(
     dir: "(str|None) Local directory to write to" = None,
     format_strs: "(str|None) list of formats" = None,
     comm: "(MPI communicator | None) average numerical stats over comm" = None,
+    suffix: "(str) suffix of the file to write to" = None,
 ):
     if dir is None:
         if os.getenv("OPENAI_LOGDIR"):
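Only the signature change is visible here; how the writers consume suffix is not part of this diff. A hypothetical call, assuming the suffix ends up in the output file names (e.g. to keep evaluation CSVs separate from training CSVs):

logger.configure(comm=comm, dir=log_dir, format_strs=["csv", "stdout"], suffix="_eval")
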
2 changes: 2 additions & 0 deletions phasic_policy_gradient/ppg.py
@@ -218,6 +218,7 @@ def learn(
     *,
     model,
     venv,
+    eval_venv,
     ppo_hps,
     aux_lr,
     aux_mbsize,
@@ -245,6 +246,7 @@
         # Policy phase
         ppo_state = ppo.learn(
             venv=venv,
+            eval_venv=eval_venv,
             model=model,
             learn_state=ppo_state,
             callbacks=[
16 changes: 15 additions & 1 deletion phasic_policy_gradient/ppo.py
@@ -114,7 +114,8 @@ def compute_losses(
 
 def learn(
     *,
-    venv: "(VecEnv) vectorized environment",
+    venv: "(VecEnv) vectorized train environment",
+    eval_venv: "(VecEnv) vectorized test environment",
     model: "(ppo.PpoModel)",
     interacts_total: "(float) total timesteps of interaction" = float("inf"),
     nstep: "(int) number of serial timesteps" = 256,
@@ -199,6 +200,14 @@ def train_pi_and_vf(**arrays):
         keep_non_rolling=log_save_opts.get("log_new_eps", False),
     )
 
+    eval_roller = learn_state.get("eval_roller") or Roller(
+        act_fn=model.act,
+        venv=eval_venv,
+        initial_state=model.initial_state(venv.num),
+        keep_buf=100,
+        keep_non_rolling=log_save_opts.get("log_new_eps", False),
+    )
+
     lsh = learn_state.get("lsh") or LogSaveHelper(
         ic_per_step=ic_per_step, model=model, comm=comm, **log_save_opts
     )
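Reviewer sketch (not in the PR): the two Roller constructions above differ only in the venv they wrap, so a small helper inside ppo.learn, where Roller, model, learn_state and log_save_opts are already in scope, would avoid the duplication and size each initial state from its own venv:

def _make_roller(v):
    # Same kwargs as the two constructions above, parameterized on the venv.
    return Roller(
        act_fn=model.act,
        venv=v,
        initial_state=model.initial_state(v.num),
        keep_buf=100,
        keep_non_rolling=log_save_opts.get("log_new_eps", False),
    )

roller = learn_state.get("roller") or _make_roller(venv)
eval_roller = learn_state.get("eval_roller") or _make_roller(eval_venv)
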
@@ -212,6 +221,10 @@
     while curr_interact_count < interacts_total and not callback_exit:
         seg = roller.multi_step(nstep)
         lsh.gather_roller_stats(roller)
+
+        eval_seg = eval_roller.multi_step(nstep)
+        lsh.gather_eval_roller_stats(eval_roller)
+
         if rnorm:
             seg["reward"] = reward_normalizer(seg["reward"], seg["first"])
         compute_advantage(model, seg, γ, λ, comm=comm)
Expand Down Expand Up @@ -257,6 +270,7 @@ def train_pi_and_vf(**arrays):
return dict(
opts=opts,
roller=roller,
eval_roller=eval_roller,
lsh=lsh,
reward_normalizer=reward_normalizer,
curr_interact_count=curr_interact_count,
12 changes: 11 additions & 1 deletion phasic_policy_gradient/train.py
@@ -8,6 +8,8 @@
 
 def train_fn(env_name="coinrun",
     distribution_mode="hard",
+    start_level=0,
+    num_levels=500,
     arch="dual", # 'shared', 'detach', or 'dual'
     # 'shared' = shared policy and value networks
     # 'dual' = separate policy and value networks
@@ -38,7 +40,10 @@ def train_fn(env_name="coinrun",
     format_strs = ['csv', 'stdout'] if comm.Get_rank() == 0 else []
     logger.configure(comm=comm, dir=log_dir, format_strs=format_strs)
 
-    venv = get_venv(num_envs=num_envs, env_name=env_name, distribution_mode=distribution_mode)
+    venv = get_venv(num_envs=num_envs, env_name=env_name, distribution_mode=distribution_mode, \
+        start_level=start_level, num_levels=num_levels)
+    eval_venv = get_venv(num_envs=num_envs, env_name=env_name, distribution_mode=distribution_mode, \
+        start_level=0, num_levels=0)
 
     enc_fn = lambda obtype: ImpalaEncoder(
         obtype.shape,
@@ -55,6 +60,7 @@
 
     ppg.learn(
         venv=venv,
+        eval_venv=eval_venv,
         model=model,
         interacts_total=interacts_total,
         ppo_hps=dict(
@@ -79,6 +85,8 @@
 def main():
     parser = argparse.ArgumentParser(description='Process PPG training arguments.')
     parser.add_argument('--env_name', type=str, default='coinrun')
+    parser.add_argument('--start_level', type=int, default=0)
+    parser.add_argument('--num_levels', type=int, default=200)
     parser.add_argument('--num_envs', type=int, default=64)
     parser.add_argument('--n_epoch_pi', type=int, default=1)
     parser.add_argument('--n_epoch_vf', type=int, default=1)
@@ -94,6 +102,8 @@
 
     train_fn(
         env_name=args.env_name,
+        start_level=args.start_level,
+        num_levels=args.num_levels,
         num_envs=args.num_envs,
         n_epoch_pi=args.n_epoch_pi,
         n_epoch_vf=args.n_epoch_vf,
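For reference, a minimal programmatic driver for the new options (a sketch: it assumes train_fn's remaining parameters keep their defaults and that an MPI communicator can be passed through a comm argument, which is not visible in this diff):

from mpi4py import MPI

from phasic_policy_gradient.train import train_fn

# Train on 200 fixed coinrun levels; the eval venv built inside train_fn
# (start_level=0, num_levels=0) samples the full level distribution.
train_fn(
    env_name="coinrun",
    start_level=0,
    num_levels=200,
    num_envs=64,
    comm=MPI.COMM_WORLD,
)
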