-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_policy.py
More file actions
68 lines (50 loc) · 2.64 KB
/
train_policy.py
File metadata and controls
68 lines (50 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
'''train_policy.py
Program to train new policies
(note that this is limited to stable_baseline's CnnLstmPolicy)
Default is the training environment, changing p_chest to .1
will train on the test environment
'''
import os
# TF_CPP_MIN_LOG_LEVEL is read by TensorFlow's C++ backend when the module is
# imported, so it must be set BEFORE `import tensorflow` to actually suppress
# the C++-side log spam ('3' = errors only).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from pathlib import Path

import gym
import fire
import numpy as np
import tensorflow as tf
from stable_baselines import PPO2
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common.policies import CnnLstmPolicy

from maze import Maze

# Silence the Python-side TF logger as well (deprecation warnings etc.).
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
def main(maze_size=12, n_timesteps=30000000, p_hole=.4, p_chest=.5, p_key=.1,
         max_ticks=130, grid_scale=3, n_cpu=4):
    '''Trains a model with the given parameters
    maze_size   int          Width and height of the maze grid
    n_timesteps int          Number of timesteps to train for
    p_hole      float, 0..1  How easy it is for an agent to navigate:
                             At 0 the grid is a maze generated by Prim's algorithm
                             At 1 it is an open space
    p_chest     float, 0..1  Roughly how much of the environment is filled with chests
    p_key       float, 0..1  Roughly how much of the environment is filled with keys
    max_ticks   int          Number of ticks per episode
    grid_scale  int          Each 1x1 cell of the grid is expanded
                             by this width and height before passed
                             to the convolutional net
    n_cpu       int          Number of parallel worker environments

    The train env is given by maze_size=12, p_hole = .4, p_chest = .5, p_key = .1,
    max_ticks=130, grid_scale=3
    The test env is given by maze_size=12, p_hole = .4, p_chest = .1, p_key = .1,
    max_ticks=130, grid_scale=3)
    '''
    # Encode the hyperparameters in the model name so separate runs get
    # distinct tensorboard/checkpoint directories.
    model_name = f"{p_hole}p_hole-{p_chest}p_chest-{p_key}p_key-{max_ticks}max_ticks-{n_timesteps}timesteps"

    def make_env():
        # Factory for one headless maze environment; SubprocVecEnv calls it
        # once per worker process, so rendering stays disabled.
        return Maze(maze_size, maze_size, p_hole=p_hole, p_chest=p_chest, p_key=p_key,
                    max_ticks=max_ticks, render_enabled=False, grid_scale=grid_scale)

    env = SubprocVecEnv([make_env for _ in range(n_cpu)])
    # gamma=1.0 means undiscounted returns (episodes are bounded by max_ticks).
    model = PPO2(CnnLstmPolicy, env, verbose=1, tensorboard_log='./models',
                 gamma=1.0, ent_coef=.01)
    # Make sure the save directory exists before the (long) training run.
    Path(f"./models/{model_name}").mkdir(parents=True, exist_ok=True)
    model.learn(total_timesteps=n_timesteps, tb_log_name=f'{model_name}/runs/',
                reset_num_timesteps=False)
    print("Saving model...")
    model.save(f"./models/{model_name}/final_model")
    env.close()
# CLI entry point: fire exposes main()'s keyword arguments as command-line
# flags, e.g. `python train_policy.py --p_chest=.1` trains on the test env.
if __name__ == '__main__':
    fire.Fire(main)