-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvalue_iterator.py
More file actions
77 lines (64 loc) · 2.43 KB
/
value_iterator.py
File metadata and controls
77 lines (64 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from q_table import QTable
from transitions import Transitions
from utils import state_to_str, action_to_str
from v_table import VTable
from rewarder import Rewarder
from config import Config
from pprint import pprint_transition
class ValueIterator:
    """Tabular value iteration over discrete ``(i, j, o)`` states.

    ``i`` indexes ``Config.letters``, ``j`` indexes ``Config.numbers`` and
    ``o`` indexes ``Config.orientations``. Actions are integer indices into
    ``Config.actions``. A state is terminal when its ``(i, j)`` part equals
    ``target_position``, whatever its orientation.
    """

    def __init__(self, target_position):
        """Create the model, rewards and empty tables for one goal.

        :param target_position: ``(i, j)`` goal position; compared against the
            first two components of a state.
        """
        self.target_position = target_position
        # Maps (state, action) to a single next state, or a falsy value when
        # the move leads into a forbidden state.
        self._tran = Transitions()
        self._rewards = Rewarder(target_position)
        self._q_tab = QTable()
        self._v_tab = VTable()

    def update(self, debug=False):
        """Run one sweep of value iteration.

        For every non-terminal state and every action, the Q-value becomes the
        immediate reward plus ``Config.gamma`` times the V-value of the
        successor. A falsy successor (forbidden move) contributes no future
        value. The V-table is refreshed from the Q-table once the sweep ends.

        :param debug: if true, pass every evaluated transition to
            ``pprint_transition``.
        """
        gamma = Config.gamma
        n_actions = len(Config.actions)
        for s1 in self.all_states():
            for a in range(n_actions):
                s2 = self._tran.run(s1, a)
                rew = self._rewards[s1, s2]
                q = rew + gamma * self._v_tab[s2] if s2 else rew
                self._q_tab[s1, a] = q
                if debug:
                    pprint_transition(s1, a, s2, rew)
        # Backups inside the sweep read V from the previous sweep; V is
        # refreshed only after every Q-value has been recomputed.
        self._v_tab.update_from_q_table(self._q_tab)

    def all_states(self):
        """Yield every non-terminal state ``(i, j, o)``.

        Positions equal to ``target_position`` are skipped in all
        orientations, so :meth:`update` never writes Q-values for them.
        """
        for i in range(len(Config.letters)):
            for j in range(len(Config.numbers)):
                if (i, j) == self.target_position:
                    continue
                for o in range(len(Config.orientations)):
                    yield i, j, o

    def path(self, s0):
        """Follow the greedy policy from *s0* until the target is reached.

        :param s0: start state ``(i, j, o)`` (a tuple).
        :returns: alternating list ``[s0, a0, s1, a1, ..., sN]``. Each ``a`` is
            the greedy action taken in the preceding state, and ``sN`` lies on
            ``target_position``. If *s0* already lies on the target, the
            result is ``[s0]``.
        :raises ValueError: if a greedy action leads to a forbidden (falsy)
            state, or if the walk revisits a state.
        """
        target = self.target_position
        if (s0[0], s0[1]) == target:
            return [s0]
        result = [s0]
        visited = {s0}
        state = s0
        while True:
            a, _ = self._q_tab.get_best_action(state)
            nxt = self._tran.run(state, a)
            if not nxt:
                raise ValueError("Переход в запрещенное состояние: " + state_to_str(state) + "-"
                                 + action_to_str(a) + "-> None")
            result.append(a)
            result.append(nxt)
            if (nxt[0], nxt[1]) == target:
                return result
            # Presumably both get_best_action and Transitions.run are
            # deterministic (update() backs up a single successor). If so, a
            # repeated state means the walk would cycle forever, e.g. while
            # the Q-table has not converged yet.
            if nxt in visited:
                raise ValueError("Стратегия зациклилась в состоянии: " + state_to_str(nxt))
            visited.add(nxt)
            state = nxt
if __name__ == '__main__':
    # Demo: solve for a fixed goal, then print and draw the greedy path.
    from config import parse_position, parse_state
    from pprint import pprint_map, pformat_path

    n_sweeps = 15  # fixed sweep count; there is no convergence check
    target = parse_position('e4')
    vi = ValueIterator(target)
    # Compute the Q-table.
    for _ in range(n_sweeps):
        vi.update()
    # Compute the greedy path to the goal from the given start state.
    path = vi.path(parse_state('a3f'))
    print(pformat_path(path))
    print(pformat_path(path, include_state=False))
    dat = {target: '$'}
    for s in path:
        # The path alternates states (tuples) and actions; mark only states.
        if isinstance(s, tuple):
            # setdefault keeps the '$' goal marker. The last path state lies
            # on the target and would otherwise overwrite it with '*'.
            dat.setdefault((s[0], s[1]), '*')
    pprint_map(data=dat)
    print(path)
    print("done")