-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun.py
More file actions
61 lines (48 loc) · 1.53 KB
/
run.py
File metadata and controls
61 lines (48 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from env import Env
from agent import Agent
from master import Master
from arguments import argparser
import numpy as np
import sys
import utils
from visualization import visualize
def run():
    """Train the agents for ``args.n_episode`` episodes and report progress.

    Parses command-line arguments, creates a log directory from ``sys.argv``,
    starts the server via ``utils.start(args.http_port)``, and wires
    ``args.n_agent`` agents plus one environment into a ``Master``. Each
    episode then runs reset / start / train on the master.

    Every ``args.print_interval`` episodes, the mean success flag and mean
    episode duration for that window are printed and the window is cleared.
    Every ``args.checkpoint_interval`` episodes, agent checkpoints are saved
    under the log directory.

    ``utils.close()`` is always called once the server has been started,
    even if setup or training raises.
    """
    args = argparser()
    path = utils.create_log_dir(sys.argv)
    utils.start(args.http_port)
    # Everything after the server starts runs under try/finally so that
    # utils.close() is reached even when setup or training raises.
    try:
        env = Env(args)
        agents = [Agent(args) for _ in range(args.n_agent)]
        master = Master(args)
        for agent in agents:
            master.add_agent(agent)
        master.add_env(env)

        # Per-window statistics; cleared after every progress report.
        success_list = []
        time_list = []
        for idx in range(args.n_episode):
            episode = idx + 1  # 1-based episode number for display and interval checks
            print('=' * 80)
            print("Episode {}".format(episode))
            # Reset the server's stack and timer.
            print("서버를 초기화하는중...")
            master.reset(path)
            # Start the episode.
            master.start()
            # Train the agents.
            master.train()
            print('=' * 80)

            # NOTE(review): presumably master.infos is refreshed by the
            # calls above; "start_time"/"end_time" are assumed to be numeric
            # timestamps — confirm against Master.
            success_list.append(master.infos["is_success"])
            time_list.append(master.infos["end_time"] - master.infos["start_time"])

            if episode % args.print_interval == 0:
                print("=" * 80)
                # ".2f" = two decimal places. A bare ".2" means two significant
                # digits and would render e.g. 123.4 as "1.2e+02".
                print("EPISODE {}: Avg. Success Rate / Time: {:.2f} / {:.2f}"
                      .format(episode, np.mean(success_list), np.mean(time_list)))
                success_list.clear()
                time_list.clear()
                print("=" * 80)
            if episode % args.checkpoint_interval == 0:
                utils.save_checkpoints(path, agents, episode)

        # NOTE(review): the original indentation was lost, so it is unclear
        # whether visualization ran once after training or after every
        # episode; it is placed after the loop here — confirm against the repo.
        if args.visual:
            visualize(path, args)
        print("끝")
    finally:
        utils.close()
if __name__ == "__main__":
    # Only launch training when this file is executed as a script, not when imported.
    run()