train.py
from environments.SumoEnvMulti2 import SumoEnvMulti
import argparse
from ray import air, tune
from ray.tune.registry import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy import Policy
from utils.utils import *
from networks.custom_policy import CustomCNN
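
# Usage:
#   python train.py                      # train PPO with parameter sharing via Ray Tune
#   python train.py --test --gui --log   # evaluate a trained checkpoint in the SUMO GUI with statistics logging
#   -s/--scenario selects the SUMO scenario ('corridor', 'intersection', or the r_intersection fallback)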


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument('--test', action='store_true')
    ap.add_argument('--gui', action='store_true')
    ap.add_argument('-s', '--scenario', type=str, default='corridor')
    ap.add_argument('--log', action='store_true')
    args = ap.parse_args()
    if args.test:
        test_rllib(args)
    else:
        train_rllib()


def sumo_cmd(gui=False, scenario='corridor', log=False):
    if gui:
        sumoBinary = checkBinary('sumo-gui')
    else:
        sumoBinary = checkBinary('sumo')
    # choose scenario
    if scenario == 'intersection':
        data_path = 'scenarios/h_intersection'
    elif scenario == 'corridor':
        data_path = 'scenarios/h_corridor'
    else:
        data_path = 'scenarios/r_intersection'
    if log:
        cmd = [
            sumoBinary, "-c", f'{data_path}/h.sumocfg',
            '--no-warnings', '--random', '--no-step-log',
            "--duration-log.statistics",
        ]
    else:
        cmd = [
            sumoBinary, "-c", f'{data_path}/h.sumocfg',
            '--no-warnings', '--random', '--no-step-log',
        ]
    return cmd
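
# With the defaults (scenario='corridor', gui=False, log=False), the command above is roughly:
#   [<path to sumo binary>, '-c', 'scenarios/h_corridor/h.sumocfg', '--no-warnings', '--random', '--no-step-log']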


def env_creator(args):
    # Note: the env config passed in by RLlib is ignored here; the env is always
    # built with the default SUMO command (corridor scenario, no GUI).
    cmd = sumo_cmd()
    env = SumoEnvMulti(cmd)
    return ParallelPettingZooEnv(env)


def train_rllib():
    # Stop criterion; other options include timesteps_total or training_iteration
    stop = {'episodes_total': 1000}
    log_dir = '/home/ytj/PycharmProjects/MARL_TSC/logs/PPO_PS/850'
    env = env_creator({})
    register_env('sumo_env', env_creator)
    # Register the custom model used by the policy
    ModelCatalog.register_custom_model('CustomCNN', CustomCNN)
    config = (
        PPOConfig()
        .environment('sumo_env')
        .resources(num_gpus=1)
        .rollouts(
            num_rollout_workers=8,
        )
        # # Independent learning
        # .multi_agent(
        #     policies=env.get_agent_ids(),
        #     policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
        #     count_steps_by='agent_steps',
        # )
        # Parameter sharing
        .multi_agent(
            policies={'shared_policy'},
            policy_mapping_fn=(lambda agent_id, *args, **kwargs: 'shared_policy'),
            # count_steps_by='agent_steps',
        )
        .training(
            gamma=0.65,
            # gamma=tune.grid_search([0.6, 0.7]),
            model={'custom_model': 'CustomCNN'},
            lr_schedule=[[0, 0.001], [1e6, 0.0001]],
            # lr=0.0001,
            # lr=tune.grid_search([0.0001, 0.0003, 0.0005]),
            use_gae=True,
            lambda_=0.95,
            sgd_minibatch_size=256,
            # sgd_minibatch_size=tune.grid_search([256, 512]),
            num_sgd_iter=5,
            # num_sgd_iter=tune.grid_search([5, 10, 20]),
            vf_loss_coeff=0.5,
            entropy_coeff=0.01,
            clip_param=0.2,
            # clip_param=tune.grid_search([0.1, 0.2, 0.3]),
            grad_clip=0.5,
        )
        .framework(framework='torch')
    )
    tuner = tune.Tuner(
        'PPO',
        tune_config=tune.TuneConfig(
            metric='episode_reward_mean',
            mode='max',
        ),
        run_config=air.RunConfig(
            stop=stop,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=500,
                checkpoint_at_end=True,
            ),
            local_dir=log_dir,
        ),
        param_space=config.to_dict(),
    )
    result = tuner.fit()
    best_trial = result.get_best_result()
    # print(f"Best trial config: {best_trial.config}")
    # print(f"Best trial episode_reward_mean: {best_trial.last_result['episode_reward_mean']}")
    print(f"Best trial path: {best_trial.path}")


def test_rllib(args):
    gui = args.gui
    log = args.log
    # Forward the CLI scenario choice as well, so -s/--scenario takes effect during evaluation
    cmd = sumo_cmd(gui=gui, scenario=args.scenario, log=log)
    env = SumoEnvMulti(cmd)
    ModelCatalog.register_custom_model('CustomCNN', CustomCNN)
    checkpoint_path = ('logs/PPO_PS/850/PPO_2024-01-03_14-43-45/PPO_sumo_env_69149_00000_0_2024-01-03_14-43-47'
                       '/checkpoint_000000')
    # Independent policies
    # policies = {a: Policy.from_checkpoint(checkpoint_path)[a] for a in env.possible_agents}
    # Shared policy
    policy = Policy.from_checkpoint(checkpoint_path)['shared_policy']
    # print(policies)
    obs, _ = env.reset()
    terminations = {}
    # Step the environment until any agent reports termination
    while True not in terminations.values():
        actions = {}
        for agent_id, agent_obs in obs.items():
            # policy = policies[agent_id]  # Independent policies
            actions[agent_id] = policy.compute_single_action(agent_obs)[0]
        obs, rewards, terminations, truncations, infos = env.step(actions)
    env.close()


if __name__ == '__main__':
    main()