-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
181 lines (154 loc) · 7.4 KB
/
main.py
File metadata and controls
181 lines (154 loc) · 7.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os
import utils
import environments
import algorithms
import argparse
import importlib
from datetime import datetime
import gymnasium as gym
import wandb
import numpy as np
import copy
from gymnasium.wrappers import FrameStackObservation, FlattenObservation
class HistoryWrapper(FrameStackObservation):
    """Frame-stacking wrapper whose observations are plain ``np.ndarray``s.

    ``FrameStackObservation`` may return a lazy/stacked observation type;
    this subclass coerces it with ``np.array`` on both reset and step.
    """

    def reset(self, **kwargs):
        """Reset the wrapped env; return (ndarray observation, info dict)."""
        stacked, info = super().reset(**kwargs)
        return np.array(stacked), info

    def step(self, action):
        """Step the wrapped env; return the usual 5-tuple with an ndarray obs."""
        stacked, *rest = super().step(action)
        return (np.array(stacked), *rest)
def parse_args():
    '''
    Parse command line arguments for training/testing a model.

    Returns:
        train (bool): Whether to train (True) or test (False) the model.
        env_id (str): Name of the environment to be loaded by gym. `gym.make(env_id)`.
        env_kwargs (dict): Dictionary of environment parameters.
        algorithm_params (tuple): (algorithm_class, init_kwargs, train_kwargs), where
            algorithm_class is located in the `algorithms` folder.
        model_path (str): Path where model will be saved/loaded.
        num_test_episodes (int): Number of evaluation episodes when testing.
        history_len (int): Window size for observation history (1 means no history).
        flatten_observation (bool): Whether to flatten the observation space.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_id', type=str, default='Quadrotor-Fixed-v0', help='Gym environment name.')
    parser.add_argument('--test', action='store_true', help='Flag for loading a model and testing on a specific environment. Otherwise training is set to True.')
    parser.add_argument("--algorithm", choices=['PPO', 'SAC'], help='Model to use. PPO, SAC. Each algorithm requires additional arguments. See `algorithms/` for more info.', required=True)
    parser.add_argument("--seed", help='random seed to use.', type=int, default=42)
    parser.add_argument('--render', action='store_true', help='Flag for rendering the environment.')
    parser.add_argument('--debug', action='store_true', help='Flag for debugging the training process. Turns wandb logging off.')
    parser.add_argument("--num_test_episodes", help='Number of test episodes to use for model testing.', type=int, default=5)
    parser.add_argument('--history_len', help='Window size for observation history. Defaults to 1 (no history)', type=int, default=1)
    parser.add_argument('--flatten_observation', action='store_true', help='Flag for flattening the observation space. Useful for MLPs and when history is being used.')
    parser.add_argument('--model_path', type=str, default=None, help='Path to model file. If None, uses default path based on log name.')
    # Parse known arguments first so we can discover which algorithm/env-specific
    # arguments to register before the full parse below.
    args, _ = parser.parse_known_args()
    # set seed
    utils.set_seed(args.seed)
    train = not args.test
    # add algorithm-specific arguments (e.g. 'PPO' -> algorithms.PPOTrainer)
    algorithm_name = args.algorithm + 'Trainer'
    algorithm_module = importlib.import_module('algorithms')
    algorithm_class = getattr(algorithm_module, algorithm_name)
    algorithm_class.add_args(parser)
    # environment specific arguments
    env_id = args.env_id
    env_kwargs = {}
    # Initialized here so it is always defined; it is only populated for
    # custom environments but is read unconditionally when logging below.
    env_specific_kwargs = {}
    if args.render:
        env_kwargs['render_mode'] = 'human'
    if env_id in environments.CUSTOM_ENV_CLASSES:
        # Custom envs are registered as 'module.path:ClassName' strings.
        env_module_path, env_class_name = environments.CUSTOM_ENV_CLASSES[env_id].split(':')
        env_module = importlib.import_module(env_module_path)
        env_class = getattr(env_module, env_class_name)
        env_class.add_args(parser)
    # parse environment and algorithm specific arguments
    args = parser.parse_args()
    # parse training kwargs from algorithm
    algorithm_train_kwargs = algorithm_class.get_training_kwargs(args)
    # default log name on weights and biases
    default_log_name = args.algorithm + '_s_{}'.format(args.seed)
    log_name = algorithm_train_kwargs.get('log_name', default_log_name)
    if args.history_len > 1:
        log_name = log_name + '_obs_history_{}'.format(args.history_len)
    # add unique id to model name
    if train:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        log_name = "{}_{}".format(log_name, timestamp)
    # set algorithm log name to updated log_name
    algorithm_train_kwargs['log_name'] = log_name
    # parse optional environment kwargs
    if env_id in environments.CUSTOM_ENV_CLASSES:
        env_specific_kwargs = env_class.get_env_kwargs(args)
        env_kwargs.update(env_specific_kwargs)
    # logging setup: wandb is disabled when debugging or testing
    if args.debug or not train:
        logger = None
    else:
        config = copy.deepcopy(algorithm_train_kwargs)
        config['env_id'] = env_id
        # Safe even for non-custom envs: env_specific_kwargs defaults to {}.
        config[env_id] = env_specific_kwargs
        logger = init_wandb(config, log_name)
    # Model Path for loading model
    cur_path = utils.get_cur_path()
    if args.model_path is not None:
        model_path = args.model_path
    elif args.test:
        raise ValueError("--model_path must be specified when using --test")
    else:
        models_folder = os.path.join(cur_path, 'models', env_id)
        os.makedirs(models_folder, exist_ok=True)
        model_path = os.path.join(models_folder, log_name+'_best.zip')
    # parse algorithm init kwargs (algorithm dependent)
    algorithm_init_kwargs = algorithm_class.get_init_kwargs(args)
    # add to init kwargs (needed for base class)
    algorithm_init_kwargs['cur_path'] = cur_path
    algorithm_init_kwargs['logger'] = logger
    algorithm_params = (algorithm_class, algorithm_init_kwargs, algorithm_train_kwargs)
    return train, env_id, env_kwargs, algorithm_params, model_path, args.num_test_episodes, args.history_len, args.flatten_observation
def test_model(model, test_env, num_episodes=1):
    """Run evaluation episodes and print the mean +/- std episode reward.

    Args:
        model: Trained policy passed through to `utils.model_inference`.
        test_env: Environment to roll episodes out in.
        num_episodes (int): How many evaluation episodes to run.
    """
    episode_rewards = [
        utils.run_single_episode(utils.model_inference, test_env, model)
        for _ in range(num_episodes)
    ]
    print("Reward: {:.2f} +/- {:.2f}".format(np.mean(episode_rewards), np.std(episode_rewards)))
def init_wandb(config, log_name):
    """Start a Weights & Biases run for this training session.

    Args:
        config (dict): Hyperparameters / settings to record with the run.
        log_name (str): Display name of the run.

    Returns:
        The initialized wandb run object.
    """
    return wandb.init(
        project="quad_rl",
        name=log_name,
        config=config,
        sync_tensorboard=True,
    )
def main():
    """Entry point: build the environment and algorithm, then train or evaluate."""
    (train, env_id, env_kwargs, algorithm_params, model_path,
     num_test_episodes, history_len, flatten_obs) = parse_args()

    # environment initialization
    print("Loading environment...", end=' ')
    env = gym.make(env_id, **env_kwargs)
    print("Done.")

    if history_len > 1:
        # Stack the last `history_len` observations for sequence-aware policies.
        env = HistoryWrapper(env, history_len)
    if flatten_obs:
        env = FlattenObservation(env)

    # algorithm initialization
    trainer_cls, trainer_init_kwargs, trainer_train_kwargs = algorithm_params
    trainer = trainer_cls(**trainer_init_kwargs)

    if train:
        # all algorithms are expected to accept (env, model_path) positionally
        print("Training model...")
        model = trainer.train(env, model_path, **trainer_train_kwargs)
    else:
        # Load a previously-trained model and evaluate it.
        print("Loading model...", end=' ')
        model = trainer.load(model_path)
        print("Done.")
        print("Running {} test episode(s)...".format(num_test_episodes))
        test_model(model, env, num_episodes=num_test_episodes)
# Script entry guard: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()